Phenology/Code/Supervised_learning/resnet/tests/test_inference.py
2025-11-06 14:16:49 +01:00

114 lines
3.3 KiB
Python

"""Unit tests for inference module."""
import pytest
import torch
import json
import tempfile
import os
from PIL import Image
from src.inference import preprocess_image, predict, load_inference_model
from src.model import ResNetPhenologyClassifier
@pytest.fixture
def mock_model():
    """Provide an untrained ResNetPhenologyClassifier with three output classes.

    ``pretrained=False`` is passed explicitly; presumably this keeps the
    fixture from fetching pretrained weights (confirm in src.model).
    """
    classifier = ResNetPhenologyClassifier(num_classes=3, pretrained=False)
    return classifier
@pytest.fixture
def temp_image():
    """Yield the path of a temporary 224x224 solid-red JPEG, then delete it.

    The NamedTemporaryFile handle is closed before PIL writes to the path.
    Leaving it open (as before) leaked a file descriptor. On Windows it also
    prevents reopening the file by name and makes os.unlink fail.
    """
    # delete=False: the file must outlive the handle so PIL/tests can open it.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as handle:
        path = handle.name
    try:
        img = Image.new('RGB', (224, 224), color='red')
        img.save(path)
        yield path
    finally:
        # Also runs if img.save() raised before the yield, so nothing leaks.
        os.unlink(path)
@pytest.fixture
def temp_class_mapping():
    """Yield the path of a temporary JSON class-mapping file, then delete it.

    The file holds a ``classes`` list and a matching ``class_to_idx`` dict
    for the three phenology phases. It is fully written and closed before
    its path is handed out. The previous version only flushed and never
    closed the handle, which leaked it and breaks reopening or unlinking
    on Windows.
    """
    mapping = {
        'classes': ['vegetative', 'flowering', 'fruiting'],
        'class_to_idx': {'vegetative': 0, 'flowering': 1, 'fruiting': 2},
    }
    # delete=False: the file must survive closing so consumers can read it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
        json.dump(mapping, handle)
        path = handle.name
    yield path
    os.unlink(path)
@pytest.mark.unit
def test_preprocess_image(temp_image):
    """preprocess_image returns a batched tensor plus the loaded PIL image."""
    tensor, pil_image = preprocess_image(temp_image)

    # One image in the batch, three RGB channels, 224x224 spatial size.
    expected_shape = (1, 3, 224, 224)
    assert tensor.shape == expected_shape

    # The second return value is a PIL image matching the fixture's size.
    assert isinstance(pil_image, Image.Image)
    assert pil_image.size == (224, 224)
@pytest.mark.unit
def test_predict(mock_model, temp_image):
    """Test that predict returns a well-formed result.

    Expects a known phase, a confidence in [0, 1], and one probability per
    class name, each in [0, 1], together summing to ~1.
    """
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()
    classes = ['vegetative', 'flowering', 'fruiting']
    image_tensor, _ = preprocess_image(temp_image)
    result = predict(model, image_tensor, classes, device)

    # Check result structure
    assert 'phase' in result
    assert 'confidence' in result
    assert 'probabilities' in result

    # Check values
    assert result['phase'] in classes
    assert 0 <= result['confidence'] <= 1
    probabilities = result['probabilities']
    assert len(probabilities) == len(classes)
    # A length check alone would accept any keys (e.g. integer indices);
    # the probabilities must be keyed by exactly the given class names.
    assert set(probabilities) == set(classes)
    assert all(0 <= p <= 1 for p in probabilities.values())

    # Check probabilities sum to ~1
    prob_sum = sum(probabilities.values())
    assert abs(prob_sum - 1.0) < 1e-5
@pytest.mark.unit
def test_predict_confidence(mock_model, temp_image):
    """Test that prediction confidence matches max probability.

    Checks both halves of that claim:
    - the reported confidence equals the probability of the reported phase;
    - that probability is the largest one, i.e. the reported phase is the
      arg-max class.
    """
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()
    classes = ['vegetative', 'flowering', 'fruiting']
    image_tensor, _ = preprocess_image(temp_image)
    result = predict(model, image_tensor, classes, device)
    probabilities = result['probabilities']

    # Confidence should match the probability of the predicted class
    predicted_prob = probabilities[result['phase']]
    assert abs(result['confidence'] - predicted_prob) < 1e-6

    # ...and that probability must be the maximum, as the docstring promises.
    assert abs(max(probabilities.values()) - predicted_prob) < 1e-6
@pytest.mark.unit
def test_predict_consistency(mock_model, temp_image):
    """Two predict calls on the same input must agree on phase and confidence."""
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()
    classes = ['vegetative', 'flowering', 'fruiting']
    batch, _ = preprocess_image(temp_image)

    # Same model, same tensor, called back to back.
    first, second = [predict(model, batch, classes, device) for _ in range(2)]

    # Identical phase and (within tolerance) identical confidence expected.
    assert first['phase'] == second['phase']
    assert abs(first['confidence'] - second['confidence']) < 1e-6