"""Unit tests for evaluation module.""" import pytest import torch import numpy as np from src.evaluate import evaluate_model, plot_confusion_matrix from src.model import ResNetPhenologyClassifier @pytest.fixture def mock_model(): """Create a small model for testing.""" return ResNetPhenologyClassifier(num_classes=3, pretrained=False) @pytest.fixture def mock_dataloader(): """Create a mock dataloader for testing.""" images = torch.randn(8, 3, 224, 224) labels = torch.randint(0, 3, (8,)) dataset = torch.utils.data.TensorDataset(images, labels) return torch.utils.data.DataLoader(dataset, batch_size=4) @pytest.mark.unit def test_evaluate_model(mock_model, mock_dataloader): """Test model evaluation.""" device = 'cpu' model = mock_model.to(device) class_names = ['class_0', 'class_1', 'class_2'] metrics, cm = evaluate_model(model, mock_dataloader, device, class_names) # Check metrics exist assert 'accuracy' in metrics assert 'recall_macro' in metrics assert 'f1_macro' in metrics assert 'per_class_recall' in metrics assert 'per_class_f1' in metrics assert 'confusion_matrix' in metrics # Check metric ranges assert 0 <= metrics['accuracy'] <= 1 assert 0 <= metrics['recall_macro'] <= 1 assert 0 <= metrics['f1_macro'] <= 1 # Check confusion matrix shape assert len(metrics['confusion_matrix']) == 3 assert len(metrics['confusion_matrix'][0]) == 3 assert isinstance(cm, np.ndarray) assert cm.shape == (3, 3) @pytest.mark.unit def test_confusion_matrix_values(mock_model, mock_dataloader): """Test confusion matrix values.""" device = 'cpu' model = mock_model.to(device) class_names = ['class_0', 'class_1', 'class_2'] _, cm = evaluate_model(model, mock_dataloader, device, class_names) # Confusion matrix should sum to total number of samples total_samples = len(mock_dataloader.dataset) assert cm.sum() == total_samples # All values should be non-negative assert np.all(cm >= 0) @pytest.mark.unit def test_per_class_metrics(mock_model, mock_dataloader): """Test per-class metrics.""" device = 'cpu' model = mock_model.to(device) class_names = ['class_0', 'class_1', 'class_2'] metrics, _ = evaluate_model(model, mock_dataloader, device, class_names) # Check per-class metrics have correct number of classes assert len(metrics['per_class_recall']) == 3 assert len(metrics['per_class_f1']) == 3 # Check all classes have metrics for cls in class_names: assert cls in metrics['per_class_recall'] assert cls in metrics['per_class_f1'] assert 0 <= metrics['per_class_recall'][cls] <= 1 assert 0 <= metrics['per_class_f1'][cls] <= 1