Phenology/Code/Supervised_learning/resnet/tests/test_evaluate.py

"""Unit tests for evaluation module."""
import pytest
import torch
import numpy as np
from src.evaluate import evaluate_model, plot_confusion_matrix
from src.model import ResNetPhenologyClassifier


@pytest.fixture
def mock_model():
    """Create a small model for testing."""
    return ResNetPhenologyClassifier(num_classes=3, pretrained=False)


@pytest.fixture
def mock_dataloader():
    """Create a mock dataloader for testing."""
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 3, (8,))
    dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=4)


@pytest.mark.unit
def test_evaluate_model(mock_model, mock_dataloader):
    """Test model evaluation."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']
    metrics, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Check metrics exist
    assert 'accuracy' in metrics
    assert 'recall_macro' in metrics
    assert 'f1_macro' in metrics
    assert 'per_class_recall' in metrics
    assert 'per_class_f1' in metrics
    assert 'confusion_matrix' in metrics

    # Check metric ranges
    assert 0 <= metrics['accuracy'] <= 1
    assert 0 <= metrics['recall_macro'] <= 1
    assert 0 <= metrics['f1_macro'] <= 1

    # Check confusion matrix shape
    assert len(metrics['confusion_matrix']) == 3
    assert len(metrics['confusion_matrix'][0]) == 3
    assert isinstance(cm, np.ndarray)
    assert cm.shape == (3, 3)


@pytest.mark.unit
def test_confusion_matrix_values(mock_model, mock_dataloader):
    """Test confusion matrix values."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']
    _, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Confusion matrix should sum to the total number of samples
    total_samples = len(mock_dataloader.dataset)
    assert cm.sum() == total_samples

    # All values should be non-negative
    assert np.all(cm >= 0)


@pytest.mark.unit
def test_per_class_metrics(mock_model, mock_dataloader):
    """Test per-class metrics."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']
    metrics, _ = evaluate_model(model, mock_dataloader, device, class_names)

    # Check per-class metrics have the correct number of classes
    assert len(metrics['per_class_recall']) == 3
    assert len(metrics['per_class_f1']) == 3

    # Check all classes have metrics within the valid range
    for cls in class_names:
        assert cls in metrics['per_class_recall']
        assert cls in metrics['per_class_f1']
        assert 0 <= metrics['per_class_recall'][cls] <= 1
        assert 0 <= metrics['per_class_f1'][cls] <= 1
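

# plot_confusion_matrix is imported above but not exercised by any test.
# The sketch below is a minimal smoke test under the assumption that the
# function accepts a confusion matrix, the class names, and an output path,
# and writes a figure to that path; the actual signature in src.evaluate
# may differ, so adjust the call accordingly.
@pytest.mark.unit
def test_plot_confusion_matrix(tmp_path):
    """Smoke test: plotting a confusion matrix should produce an output file."""
    cm = np.array([[2, 1, 0],
                   [0, 3, 1],
                   [1, 0, 2]])
    class_names = ['class_0', 'class_1', 'class_2']
    output_path = tmp_path / 'confusion_matrix.png'
    plot_confusion_matrix(cm, class_names, str(output_path))
    assert output_path.exists()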