Phenology/Code/Supervised_learning/resnet/tests/test_data_loader.py
2025-11-06 14:16:49 +01:00

145 lines
4.2 KiB
Python

"""Unit tests for data loader module."""
import pytest
import torch
import pandas as pd
import os
from PIL import Image
import tempfile
import shutil
from src.data_loader import PhenologyDataset, get_data_loaders
@pytest.fixture
def sample_dataset():
    """Build a throwaway image dataset on disk and hand it to the test.

    Writes 30 single-colour 224x224 JPEG files (the three phases cycled ten
    times) into an ``images`` sub-directory of a temporary directory, plus
    a ``labels.csv`` with the columns ``image_path`` and ``phase``.

    Yields:
        tuple: ``(csv_path, img_dir, phases)`` -- path of the labels CSV,
        directory holding the images, and the list of phase names.
    """
    phases = ['vegetative', 'flowering', 'fruiting']

    # The context manager removes the directory even when setup fails
    # part-way (e.g. an image cannot be written). A bare mkdtemp() with
    # rmtree() placed after the yield would leak the directory in that case.
    with tempfile.TemporaryDirectory() as temp_dir:
        img_dir = os.path.join(temp_dir, 'images')
        os.makedirs(img_dir, exist_ok=True)

        # Create sample images
        rows = []
        for i, phase in enumerate(phases * 10):  # 30 images total
            # i * 8 tops out at 232, so every channel stays within 0-255.
            level = i * 8
            img_name = f'img_{i:03d}.jpg'
            img = Image.new('RGB', (224, 224), color=(level, level, level))
            img.save(os.path.join(img_dir, img_name))
            rows.append((img_name, phase))

        # Create CSV
        csv_path = os.path.join(temp_dir, 'labels.csv')
        df = pd.DataFrame(rows, columns=['image_path', 'phase'])
        df.to_csv(csv_path, index=False)

        yield csv_path, img_dir, phases
@pytest.mark.unit
def test_dataset_initialization(sample_dataset):
    """Check that a train-split dataset is non-empty and exposes the phases as classes."""
    labels_csv, images_root, expected_phases = sample_dataset

    train_set = PhenologyDataset(
        csv_file=labels_csv, root_dir=images_root, split='train'
    )

    # The split must contain at least one sample.
    assert len(train_set) > 0

    # The class list must match the phases written by the fixture,
    # regardless of ordering.
    assert len(train_set.classes) == len(expected_phases)
    assert set(train_set.classes) == set(expected_phases)
@pytest.mark.unit
def test_dataset_split_ratios(sample_dataset):
    """Check that the train/val/test sizes land near a 70/15/15 split."""
    labels_csv, images_root, _ = sample_dataset

    # Size of each split, built in train -> val -> test order.
    sizes = {}
    for split_name in ('train', 'val', 'test'):
        split_set = PhenologyDataset(labels_csv, images_root, split=split_name)
        sizes[split_name] = len(split_set)
    total = sizes['train'] + sizes['val'] + sizes['test']

    # Each observed fraction may deviate from its target by less than 0.1.
    targets = (('train', 0.7), ('val', 0.15), ('test', 0.15))
    for split_name, target in targets:
        assert abs(sizes[split_name] / total - target) < 0.1
@pytest.mark.unit
def test_dataset_getitem(sample_dataset):
    """Check that indexing the train split yields a valid integer label."""
    labels_csv, images_root, _ = sample_dataset
    train_set = PhenologyDataset(labels_csv, images_root, split='train')

    # Only the label half of the pair is verified here.
    # NOTE(review): the image is presumably a PIL Image when no transform is
    # given, but this test does not assert it -- confirm and add a check.
    _image, label = train_set[0]

    assert isinstance(label, int)
    assert 0 <= label < len(train_set.classes)
@pytest.mark.unit
def test_data_loaders(sample_dataset):
    """Check that get_data_loaders returns every split and well-shaped batches."""
    labels_csv, images_root, _ = sample_dataset
    batch_size = 4

    loaders, class_names, class_index = get_data_loaders(
        csv_file=labels_csv,
        root_dir=images_root,
        batch_size=batch_size,
        num_workers=0,  # Use 0 for testing
    )

    # All three splits must be present.
    for split_name in ('train', 'val', 'test'):
        assert split_name in loaders

    assert len(class_names) > 0
    assert len(class_index) == len(class_names)

    # Pull one training batch and check its layout.
    images, labels = next(iter(loaders['train']))
    assert images.shape[0] <= batch_size  # never more than one batch
    for axis, expected in ((1, 3), (2, 224), (3, 224)):  # RGB, height, width
        assert images.shape[axis] == expected
    assert len(labels) == images.shape[0]
@pytest.mark.unit
def test_class_weights(sample_dataset):
    """Test class weight calculation.

    Expects ``get_class_weights()`` on the train split to return one
    strictly positive weight per class, with the weights summing to 1.
    """
    csv_path, img_dir, _ = sample_dataset
    dataset = PhenologyDataset(csv_path, img_dir, split='train')

    weights = dataset.get_class_weights()

    assert weights.shape[0] == len(dataset.classes)
    assert torch.all(weights > 0)

    # Build the reference value with the dtype/device of the sum itself.
    # torch.isclose rejects operands of different dtypes, so comparing
    # against a bare torch.tensor(1.0) (float32) would raise instead of
    # comparing whenever the weights are not float32 (e.g. float64).
    total = weights.sum()
    one = torch.tensor(1.0, dtype=total.dtype, device=total.device)
    assert torch.isclose(total, one, atol=1e-6)
@pytest.mark.unit
def test_reproducibility(sample_dataset):
    """Check that two train datasets built with the same seed agree."""
    labels_csv, images_root, _ = sample_dataset
    seed = 42

    first = PhenologyDataset(labels_csv, images_root, split='train', random_seed=seed)
    second = PhenologyDataset(labels_csv, images_root, split='train', random_seed=seed)

    # Same seed must give the same number of samples.
    assert len(first) == len(second)

    # Spot-check the labels of the leading samples (at most five).
    n_checked = min(5, len(first))
    for idx in range(n_checked):
        _, label_a = first[idx]
        _, label_b = second[idx]
        assert label_a == label_b