"""Unit tests for data loader module.""" import pytest import torch import pandas as pd import os from PIL import Image import tempfile import shutil from src.data_loader import PhenologyDataset, get_data_loaders @pytest.fixture def sample_dataset(): """Create a temporary dataset for testing.""" temp_dir = tempfile.mkdtemp() img_dir = os.path.join(temp_dir, 'images') os.makedirs(img_dir, exist_ok=True) # Create sample images phases = ['vegetative', 'flowering', 'fruiting'] image_paths = [] for i, phase in enumerate(phases * 10): # 30 images total img = Image.new('RGB', (224, 224), color=(i*8, i*8, i*8)) img_path = f'img_{i:03d}.jpg' img.save(os.path.join(img_dir, img_path)) image_paths.append((img_path, phase)) # Create CSV csv_path = os.path.join(temp_dir, 'labels.csv') df = pd.DataFrame(image_paths, columns=['image_path', 'phase']) df.to_csv(csv_path, index=False) yield csv_path, img_dir, phases # Cleanup shutil.rmtree(temp_dir) @pytest.mark.unit def test_dataset_initialization(sample_dataset): """Test dataset initialization.""" csv_path, img_dir, phases = sample_dataset dataset = PhenologyDataset( csv_file=csv_path, root_dir=img_dir, split='train' ) assert len(dataset) > 0 assert len(dataset.classes) == len(phases) assert set(dataset.classes) == set(phases) @pytest.mark.unit def test_dataset_split_ratios(sample_dataset): """Test dataset splitting.""" csv_path, img_dir, _ = sample_dataset train_dataset = PhenologyDataset(csv_path, img_dir, split='train') val_dataset = PhenologyDataset(csv_path, img_dir, split='val') test_dataset = PhenologyDataset(csv_path, img_dir, split='test') total = len(train_dataset) + len(val_dataset) + len(test_dataset) # Check approximate split ratios (70/15/15) assert abs(len(train_dataset) / total - 0.7) < 0.1 assert abs(len(val_dataset) / total - 0.15) < 0.1 assert abs(len(test_dataset) / total - 0.15) < 0.1 @pytest.mark.unit def test_dataset_getitem(sample_dataset): """Test getting items from dataset.""" csv_path, img_dir, _ = sample_dataset dataset = PhenologyDataset(csv_path, img_dir, split='train') image, label = dataset[0] # Without transform, should return PIL Image assert isinstance(label, int) assert label >= 0 and label < len(dataset.classes) @pytest.mark.unit def test_data_loaders(sample_dataset): """Test data loader creation.""" csv_path, img_dir, _ = sample_dataset data_loaders, classes, class_to_idx = get_data_loaders( csv_file=csv_path, root_dir=img_dir, batch_size=4, num_workers=0 # Use 0 for testing ) assert 'train' in data_loaders assert 'val' in data_loaders assert 'test' in data_loaders assert len(classes) > 0 assert len(class_to_idx) == len(classes) # Test getting a batch batch = next(iter(data_loaders['train'])) images, labels = batch assert images.shape[0] <= 4 # Batch size assert images.shape[1] == 3 # RGB channels assert images.shape[2] == 224 # Height assert images.shape[3] == 224 # Width assert len(labels) == images.shape[0] @pytest.mark.unit def test_class_weights(sample_dataset): """Test class weight calculation.""" csv_path, img_dir, _ = sample_dataset dataset = PhenologyDataset(csv_path, img_dir, split='train') weights = dataset.get_class_weights() assert weights.shape[0] == len(dataset.classes) assert torch.all(weights > 0) assert torch.isclose(weights.sum(), torch.tensor(1.0), atol=1e-6) @pytest.mark.unit def test_reproducibility(sample_dataset): """Test that same seed produces same splits.""" csv_path, img_dir, _ = sample_dataset dataset1 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42) dataset2 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42) # Should have same samples assert len(dataset1) == len(dataset2) # Check first few samples are the same for i in range(min(5, len(dataset1))): _, label1 = dataset1[i] _, label2 = dataset2[i] assert label1 == label2