94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
"""Unit tests for training module."""
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from src.train import train_epoch, validate
|
|
from src.model import ResNetPhenologyClassifier
|
|
|
|
|
|
@pytest.fixture
def mock_model():
    """Build an untrained 3-class ``ResNetPhenologyClassifier`` for the tests.

    ``pretrained=False`` is passed so the tests do not depend on pretrained
    weights (presumably this also avoids any download — confirm in src.model).
    """
    model = ResNetPhenologyClassifier(pretrained=False, num_classes=3)
    return model
|
|
|
|
|
|
@pytest.fixture
def mock_dataloader():
    """Return a DataLoader over a tiny, reproducible synthetic dataset.

    The dataset holds 8 random 3-channel ``(3, 224, 224)`` image tensors with
    integer labels in ``[0, 3)``, matching the 3-class ``mock_model`` fixture.
    A dedicated, seeded ``torch.Generator`` makes the data identical on every
    run without mutating the global RNG state other tests may rely on.

    Batches are yielded in order (``shuffle=False``) with ``batch_size=4``,
    i.e. two full batches.
    """
    generator = torch.Generator().manual_seed(0)
    images = torch.randn(8, 3, 224, 224, generator=generator)
    labels = torch.randint(0, 3, (8,), generator=generator)
    dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)
|
|
|
|
|
|
@pytest.mark.unit
def test_train_epoch(mock_model, mock_dataloader):
    """Run one training epoch and check the metrics and that weights changed.

    Besides sanity checks on the returned ``(loss, acc)`` pair, this verifies
    that ``train_epoch`` actually modified at least one trainable parameter,
    i.e. that the supplied optimizer was stepped.
    """
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Snapshot trainable weights so an optimizer step can be detected afterwards.
    params_before = [
        p.detach().clone() for p in model.parameters() if p.requires_grad
    ]

    loss, acc = train_epoch(model, mock_dataloader, criterion, optimizer, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    # Chained bound also rejects NaN (all comparisons False) and inf.
    assert 0 <= loss < float("inf")
    # Bounds allow accuracy reported on a 0-100 (percentage) scale.
    assert 0 <= acc <= 100

    params_after = [p.detach() for p in model.parameters() if p.requires_grad]
    assert any(
        not torch.equal(before, after)
        for before, after in zip(params_before, params_after)
    ), "train_epoch did not update any model parameters"
|
|
|
|
|
|
@pytest.mark.unit
def test_validate(mock_model, mock_dataloader):
    """Run validation, check its metrics, and ensure it does not alter weights."""
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Snapshot parameters; validation must be read-only with respect to them.
    params_before = [p.detach().clone() for p in model.parameters()]

    loss, acc = validate(model, mock_dataloader, criterion, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    # Chained bound also rejects NaN (all comparisons False) and inf.
    assert 0 <= loss < float("inf")
    # Bounds allow accuracy reported on a 0-100 (percentage) scale.
    assert 0 <= acc <= 100

    assert all(
        torch.equal(before, after.detach())
        for before, after in zip(params_before, model.parameters())
    ), "validate modified model parameters"
|
|
|
|
|
|
@pytest.mark.unit
def test_model_gradients(mock_model):
    """Check that one forward/backward pass yields finite, non-zero gradients."""
    device = 'cpu'
    model = mock_model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()

    # Dummy batch of 2 samples; labels in [0, 3) match the 3-class fixture.
    images = torch.randn(2, 3, 224, 224)
    labels = torch.randint(0, 3, (2,))

    # Forward pass: expect one row of class logits per sample.
    outputs = model(images)
    assert outputs.shape == (2, 3)
    loss = criterion(outputs, labels)

    # Backward pass (no optimizer needed just to clear gradients).
    model.zero_grad()
    loss.backward()

    # Gradients must exist, be finite, and not be identically zero — a bare
    # "some grad is not None" check would pass for NaN or all-zero gradients.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    assert grads, "no parameter received a gradient"
    assert all(torch.isfinite(g).all() for g in grads), "non-finite gradient"
    assert any(g.abs().sum() > 0 for g in grads), "all gradients are zero"
|
|
|
|
|
|
@pytest.mark.unit
def test_model_inference_mode(mock_model):
    """Check that train()/eval() switch the mode of the model and all submodules.

    Checking only the root ``model.training`` flag is trivially satisfied by
    ``nn.Module``; the meaningful property is that every submodule (e.g.
    normalization or dropout layers) follows the switch.
    """
    model = mock_model

    # Train mode: every module, including the root, reports training=True.
    model.train()
    assert all(module.training for module in model.modules())

    # Eval mode: no module may remain in training mode.
    model.eval()
    assert not any(module.training for module in model.modules())
|