"""Unit tests for training module.""" import pytest import torch import torch.nn as nn from src.train import train_epoch, validate from src.model import ResNetPhenologyClassifier @pytest.fixture def mock_model(): """Create a small model for testing.""" return ResNetPhenologyClassifier(num_classes=3, pretrained=False) @pytest.fixture def mock_dataloader(): """Create a mock dataloader for testing.""" # Create dummy data images = torch.randn(8, 3, 224, 224) labels = torch.randint(0, 3, (8,)) dataset = torch.utils.data.TensorDataset(images, labels) return torch.utils.data.DataLoader(dataset, batch_size=4) @pytest.mark.unit def test_train_epoch(mock_model, mock_dataloader): """Test training for one epoch.""" device = 'cpu' model = mock_model.to(device) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) loss, acc = train_epoch(model, mock_dataloader, criterion, optimizer, device) assert isinstance(loss, float) assert isinstance(acc, float) assert loss >= 0 assert 0 <= acc <= 100 @pytest.mark.unit def test_validate(mock_model, mock_dataloader): """Test validation.""" device = 'cpu' model = mock_model.to(device) criterion = nn.CrossEntropyLoss() loss, acc = validate(model, mock_dataloader, criterion, device) assert isinstance(loss, float) assert isinstance(acc, float) assert loss >= 0 assert 0 <= acc <= 100 @pytest.mark.unit def test_model_gradients(mock_model): """Test that gradients are computed during training.""" device = 'cpu' model = mock_model.to(device) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters()) # Create dummy batch images = torch.randn(2, 3, 224, 224) labels = torch.randint(0, 3, (2,)) # Forward pass outputs = model(images) loss = criterion(outputs, labels) # Backward pass optimizer.zero_grad() loss.backward() # Check gradients exist has_gradients = any(p.grad is not None for p in model.parameters()) assert has_gradients @pytest.mark.unit def test_model_inference_mode(mock_model): """Test that model can switch between train and eval modes.""" model = mock_model # Train mode model.train() assert model.training # Eval mode model.eval() assert not model.training