Phenology/Code/Supervised_learning/resnet/tests/test_train.py
2025-11-06 14:16:49 +01:00

94 lines
2.4 KiB
Python

"""Unit tests for training module."""
import pytest
import torch
import torch.nn as nn
from src.train import train_epoch, validate
from src.model import ResNetPhenologyClassifier
@pytest.fixture
def mock_model():
    """Provide a freshly constructed three-class ResNetPhenologyClassifier.

    Built with ``pretrained=False``; presumably this avoids loading
    pretrained weights (TODO confirm against src.model).
    """
    classifier = ResNetPhenologyClassifier(num_classes=3, pretrained=False)
    return classifier
@pytest.fixture
def mock_dataloader():
    """Return a DataLoader over a small, reproducible synthetic dataset.

    The dataset holds 8 samples: random float tensors of shape
    (3, 224, 224) paired with integer labels in ``[0, 3)``. They are served
    in two unshuffled batches of 4.

    A dedicated, seeded ``torch.Generator`` makes the data identical on
    every run, so failures can be reproduced. It also leaves the global RNG
    state that other code may rely on untouched.
    """
    generator = torch.Generator().manual_seed(0)
    images = torch.randn(8, 3, 224, 224, generator=generator)
    labels = torch.randint(0, 3, (8,), generator=generator)
    dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=4)
@pytest.mark.unit
def test_train_epoch(mock_model, mock_dataloader):
    """Run one training epoch and check its metrics and its effect on weights.

    The test sanity-checks the returned ``(loss, acc)`` pair. It also asserts
    that at least one parameter changed, so a ``train_epoch`` that never
    applies an optimizer step cannot pass.
    """
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Snapshot the weights; clone so in-place optimizer updates don't alias.
    before = [p.detach().clone() for p in model.parameters()]

    loss, acc = train_epoch(model, mock_dataloader, criterion, optimizer, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    # The chained comparison is False for NaN and rejects inf explicitly.
    assert 0 <= loss < float('inf')
    # acc is assumed to be a percentage, as the original bound implies.
    assert 0 <= acc <= 100
    assert any(
        not torch.equal(old, new.detach())
        for old, new in zip(before, model.parameters())
    ), "train_epoch did not update any model parameter"
@pytest.mark.unit
def test_validate(mock_model, mock_dataloader):
    """Run validation and check its metrics and that model state is untouched.

    Validation must not change parameters or buffers. Comparing the full
    ``state_dict`` also catches normalization running statistics being
    updated when the model is not switched to eval mode. This assumes the
    model has such layers, e.g. BatchNorm (TODO confirm in src.model).
    """
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Clone so later in-place changes to the live tensors are detectable.
    before = {k: v.detach().clone() for k, v in model.state_dict().items()}

    loss, acc = validate(model, mock_dataloader, criterion, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    # The chained comparison is False for NaN and rejects inf explicitly.
    assert 0 <= loss < float('inf')
    assert 0 <= acc <= 100
    after = model.state_dict()
    changed = [k for k, v in before.items() if not torch.equal(v, after[k])]
    assert not changed, f"validate modified model state: {changed}"
@pytest.mark.unit
def test_model_gradients(mock_model):
    """Check that a backward pass produces usable gradients.

    The test asserts three things:

    - at least one parameter receives a gradient;
    - every computed gradient is finite;
    - at least one gradient is non-zero.

    A bare ``p.grad is not None`` check would also pass for NaN or
    all-zero gradients.
    """
    device = 'cpu'
    model = mock_model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    # Seeded local generator: reproducible batch without touching global RNG.
    generator = torch.Generator().manual_seed(0)
    images = torch.randn(2, 3, 224, 224, generator=generator).to(device)
    labels = torch.randint(0, 3, (2,), generator=generator).to(device)

    # Clear any stale gradients, then run forward and backward passes.
    model.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()

    grads = [p.grad for p in model.parameters() if p.grad is not None]
    assert grads, "no parameter received a gradient"
    assert all(bool(torch.isfinite(g).all()) for g in grads), "non-finite gradient"
    assert any(bool(torch.count_nonzero(g) > 0) for g in grads), "all gradients are zero"
@pytest.mark.unit
def test_model_inference_mode(mock_model):
    """Check that the model switches between train and eval modes.

    ``eval()`` must put every submodule into eval mode, because layers such
    as BatchNorm or Dropout behave differently at inference time. Checking
    only the top-level flag would miss a submodule left in training mode.
    The train direction checks only the top-level flag, since a model may
    deliberately keep some submodules frozen during training.
    """
    model = mock_model

    # Train mode
    model.train()
    assert model.training

    # Eval mode: must propagate to the whole module tree.
    model.eval()
    assert not model.training
    still_training = [name for name, m in model.named_modules() if m.training]
    assert not still_training, f"submodules still in train mode: {still_training}"