145 lines
4.2 KiB
Python
145 lines
4.2 KiB
Python
"""Unit tests for data loader module."""
|
|
|
|
import pytest
|
|
import torch
|
|
import pandas as pd
|
|
import os
|
|
from PIL import Image
|
|
import tempfile
|
|
import shutil
|
|
from src.data_loader import PhenologyDataset, get_data_loaders
|
|
|
|
|
|
@pytest.fixture
def sample_dataset():
    """Create a temporary on-disk dataset for testing.

    Writes 30 solid-colour 224x224 RGB JPEGs (10 per phase, phases
    interleaved) into ``<tmp>/images`` and a ``labels.csv`` with columns
    ``image_path`` (filename relative to the image directory) and ``phase``.

    Yields:
        tuple: ``(csv_path, img_dir, phases)`` where ``phases`` is the list of
        phase names used as labels.

    The temporary directory is removed on teardown, and also if setup fails
    part-way through.
    """
    # TemporaryDirectory guarantees cleanup even if image/CSV creation raises
    # before the yield (a bare mkdtemp/rmtree pair would leak the directory
    # in that case); pytest resumes the generator after the test for the
    # normal teardown path.
    with tempfile.TemporaryDirectory() as temp_dir:
        img_dir = os.path.join(temp_dir, 'images')
        os.makedirs(img_dir, exist_ok=True)

        # Create sample images: cycling through the phases gives 10 per class.
        phases = ['vegetative', 'flowering', 'fruiting']
        image_paths = []
        for i, phase in enumerate(phases * 10):  # 30 images total
            # Distinct grey level per image (max 29 * 8 = 232, within 0-255).
            img = Image.new('RGB', (224, 224), color=(i * 8, i * 8, i * 8))
            img_path = f'img_{i:03d}.jpg'
            img.save(os.path.join(img_dir, img_path))
            image_paths.append((img_path, phase))

        # Create the labels CSV that the tests pass to PhenologyDataset.
        csv_path = os.path.join(temp_dir, 'labels.csv')
        df = pd.DataFrame(image_paths, columns=['image_path', 'phase'])
        df.to_csv(csv_path, index=False)

        yield csv_path, img_dir, phases
|
|
|
|
|
|
@pytest.mark.unit
def test_dataset_initialization(sample_dataset):
    """Check that a train split loads and exposes exactly the fixture's phases."""
    csv_path, img_dir, expected_phases = sample_dataset

    train_split = PhenologyDataset(
        csv_file=csv_path,
        root_dir=img_dir,
        split='train',
    )

    # The split must be non-empty ...
    assert len(train_split) > 0
    # ... and its class list must match the phases written by the fixture,
    # both in count and in membership.
    assert len(train_split.classes) == len(expected_phases)
    assert set(train_split.classes) == set(expected_phases)
|
|
|
|
|
|
@pytest.mark.unit
def test_dataset_split_ratios(sample_dataset):
    """Test that train/val/test splits follow the approximate 70/15/15 ratio.

    With the 30-sample fixture the ratios cannot be hit exactly, so each
    split's fraction of the combined size may deviate by less than 0.1.
    """
    csv_path, img_dir, _ = sample_dataset

    expected_ratios = {'train': 0.7, 'val': 0.15, 'test': 0.15}
    sizes = {
        split: len(PhenologyDataset(csv_path, img_dir, split=split))
        for split in expected_ratios
    }
    total = sum(sizes.values())

    # Fail with a clear message instead of a ZeroDivisionError below.
    assert total > 0, f"all splits are empty: {sizes}"

    # Check approximate split ratios (70/15/15).
    for split, ratio in expected_ratios.items():
        fraction = sizes[split] / total
        assert abs(fraction - ratio) < 0.1, (
            f"{split} fraction {sizes[split]}/{total}={fraction:.3f} "
            f"is not within 0.1 of {ratio}"
        )
|
|
|
|
|
|
@pytest.mark.unit
def test_dataset_getitem(sample_dataset):
    """Test that indexing the dataset yields an (image, label) pair."""
    csv_path, img_dir, _ = sample_dataset

    dataset = PhenologyDataset(csv_path, img_dir, split='train')
    image, label = dataset[0]

    # No transform is passed here, so a PIL image is expected; a tensor is
    # also accepted in case the dataset applies a default transform.
    # NOTE(review): confirm against PhenologyDataset's default ``transform``.
    assert isinstance(image, (Image.Image, torch.Tensor))

    # The label must be a plain integer class index within range.
    assert isinstance(label, int)
    assert 0 <= label < len(dataset.classes)
|
|
|
|
|
|
@pytest.mark.unit
def test_data_loaders(sample_dataset):
    """Build train/val/test loaders and sanity-check one training batch."""
    csv_path, img_dir, _ = sample_dataset

    loaders, class_names, name_to_index = get_data_loaders(
        csv_file=csv_path,
        root_dir=img_dir,
        batch_size=4,
        num_workers=0,  # keep loading in-process while testing
    )

    # Every split must have a loader, and the class mapping must be complete.
    for split in ('train', 'val', 'test'):
        assert split in loaders
    assert len(class_names) > 0
    assert len(name_to_index) == len(class_names)

    # Pull a single batch from the training loader.
    images, labels = next(iter(loaders['train']))

    n_images = images.shape[0]
    assert n_images <= 4            # never more than batch_size
    assert images.shape[1] == 3     # RGB channels
    assert images.shape[2] == 224   # height
    assert images.shape[3] == 224   # width
    assert len(labels) == n_images  # one label per image
|
|
|
|
|
|
@pytest.mark.unit
def test_class_weights(sample_dataset):
    """Test that class weights are positive, one per class, and sum to 1."""
    csv_path, img_dir, _ = sample_dataset

    dataset = PhenologyDataset(csv_path, img_dir, split='train')
    weights = dataset.get_class_weights()

    assert weights.shape[0] == len(dataset.classes)
    assert torch.all(weights > 0)

    # Build the reference with ones_like so it shares the weights' dtype and
    # device: torch.isclose rejects mismatched dtypes (e.g. float64 weights
    # against a default float32 torch.tensor(1.0)) instead of comparing them.
    total = weights.sum()
    assert torch.isclose(total, torch.ones_like(total), atol=1e-6)
|
|
|
|
|
|
@pytest.mark.unit
def test_reproducibility(sample_dataset):
    """Test that the same seed produces the same split.

    Two datasets built with ``random_seed=42`` must have the same length and
    the same label at every index.
    """
    csv_path, img_dir, _ = sample_dataset

    dataset1 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)
    dataset2 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)

    # Should have the same number of samples.
    assert len(dataset1) == len(dataset2)

    # Compare every label, not just the first few: with only three classes,
    # a handful of matching labels is weak evidence that the splits agree.
    # The fixture is small (30 images), so loading all samples is cheap.
    labels1 = [dataset1[i][1] for i in range(len(dataset1))]
    labels2 = [dataset2[i][1] for i in range(len(dataset2))]
    assert labels1 == labels2
|