"""Unit tests for evaluation module."""
|
|
|
|
import pytest
|
|
import torch
|
|
import numpy as np
|
|
from src.evaluate import evaluate_model, plot_confusion_matrix
|
|
from src.model import ResNetPhenologyClassifier
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_model():
|
|
"""Create a small model for testing."""
|
|
return ResNetPhenologyClassifier(num_classes=3, pretrained=False)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_dataloader():
|
|
"""Create a mock dataloader for testing."""
|
|
images = torch.randn(8, 3, 224, 224)
|
|
labels = torch.randint(0, 3, (8,))
|
|
dataset = torch.utils.data.TensorDataset(images, labels)
|
|
return torch.utils.data.DataLoader(dataset, batch_size=4)
|
|
|
|
|
|
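

# Note: the fixtures above use unseeded random tensors. Every assertion in
# this module checks shapes, ranges, or totals rather than specific values,
# so the tests do not depend on any particular random draw.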
@pytest.mark.unit
def test_evaluate_model(mock_model, mock_dataloader):
    """Test model evaluation."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    metrics, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Check metrics exist
    assert 'accuracy' in metrics
    assert 'recall_macro' in metrics
    assert 'f1_macro' in metrics
    assert 'per_class_recall' in metrics
    assert 'per_class_f1' in metrics
    assert 'confusion_matrix' in metrics

    # Check metric ranges
    assert 0 <= metrics['accuracy'] <= 1
    assert 0 <= metrics['recall_macro'] <= 1
    assert 0 <= metrics['f1_macro'] <= 1

    # Check confusion matrix shape
    assert len(metrics['confusion_matrix']) == 3
    assert len(metrics['confusion_matrix'][0]) == 3
    assert isinstance(cm, np.ndarray)
    assert cm.shape == (3, 3)


@pytest.mark.unit
def test_confusion_matrix_values(mock_model, mock_dataloader):
    """Test confusion matrix values."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    _, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Confusion matrix should sum to total number of samples
    total_samples = len(mock_dataloader.dataset)
    assert cm.sum() == total_samples

    # All values should be non-negative
    assert np.all(cm >= 0)
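

# Consistency sketch: assuming evaluate_model derives 'accuracy' and cm from
# the same set of predictions, accuracy must equal trace(cm) / cm.sum(). This
# test name and that relationship are assumptions about src.evaluate, not
# facts given in this file.
@pytest.mark.unit
def test_accuracy_matches_confusion_matrix(mock_model, mock_dataloader):
    """Cross-check 'accuracy' against the confusion matrix (sketch)."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    metrics, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # The diagonal of cm counts correct predictions, so its trace over the
    # total should match the reported accuracy if both share one forward pass.
    expected_accuracy = np.trace(cm) / cm.sum()
    assert metrics['accuracy'] == pytest.approx(expected_accuracy)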


@pytest.mark.unit
def test_per_class_metrics(mock_model, mock_dataloader):
    """Test per-class metrics."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    metrics, _ = evaluate_model(model, mock_dataloader, device, class_names)

    # Check per-class metrics have correct number of classes
    assert len(metrics['per_class_recall']) == 3
    assert len(metrics['per_class_f1']) == 3

    # Check all classes have metrics
    for cls in class_names:
        assert cls in metrics['per_class_recall']
        assert cls in metrics['per_class_f1']
        assert 0 <= metrics['per_class_recall'][cls] <= 1
        assert 0 <= metrics['per_class_f1'][cls] <= 1
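

# plot_confusion_matrix is imported above but not otherwise exercised. Its
# exact signature is not shown in this file, so this smoke test is a sketch:
# it assumes plot_confusion_matrix(cm, class_names, save_path) writes a figure
# to save_path. Adjust the call to match the actual API in src.evaluate.
@pytest.mark.unit
def test_plot_confusion_matrix(tmp_path):
    """Smoke-test confusion matrix plotting (sketch, assumed signature)."""
    cm = np.array([[2, 1, 0], [0, 3, 1], [1, 0, 0]])
    class_names = ['class_0', 'class_1', 'class_2']
    save_path = tmp_path / 'confusion_matrix.png'

    plot_confusion_matrix(cm, class_names, save_path=save_path)

    # The figure file should exist at the requested location
    assert save_path.exists()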