ResNet Phenology Classifier

A deep learning model for classifying plant images by phenological phase using the ResNet50 architecture.

Features

  • Training: Train a ResNet50 model on labeled plant images
  • Evaluation: Metrics including accuracy, macro recall, macro F1, per-class metrics, and confusion matrices
  • Inference: Classify new images, with optional visual output
  • API: REST API (FastAPI) for single-image and batch classification
  • Reproducibility: Random seed management and class-mapping versioning

Dataset

The model uses the dataset located at:

C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF

Dataset split: 70% training, 15% validation, 15% testing
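
A minimal sketch of a seeded 70/15/15 split, assuming a plain random split over the CSV rows. The project's data_loader.py may split differently (for example, stratified by phase), and the function name split_70_15_15 is hypothetical.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_70_15_15(items: Sequence[T], seed: int = 42) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle a copy of `items` with a fixed seed and cut it into train/val/test."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local RNG: global random state is untouched
    n = len(shuffled)
    n_train = n * 70 // 100                # integer arithmetic avoids float rounding surprises
    n_val = n * 15 // 100
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]      # any rounding remainder goes to test
    return train, val, test


rows = [(f"images/img{i:03d}.jpg", "flowering") for i in range(100)]
train, val, test = split_70_15_15(rows)
print(len(train), len(val), len(test))  # 70 15 15
```

With few images per phase, applying the same cut separately within each phase (a stratified split) keeps class proportions similar across the three subsets.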

Installation

  1. Create a virtual environment:
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
  2. Install dependencies:
pip install -r requirements.txt

Usage

Training

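Train the model on your dataset (the trailing \ line continuations assume a POSIX shell such as bash; in PowerShell, use a backtick instead):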

python src/train.py \
    --data_dir "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF" \
    --csv_file "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv" \
    --output_dir models \
    --epochs 50 \
    --batch_size 32 \
    --lr 0.001

Arguments:

  • --data_dir: Directory containing images
  • --csv_file: Path to CSV file with columns: image_path, phase
  • --output_dir: Directory to save trained models (default: models)
  • --epochs: Number of training epochs (default: 50)
  • --batch_size: Batch size (default: 32)
  • --lr: Learning rate (default: 0.001)
  • --num_workers: Data loader workers (default: 4)
  • --seed: Random seed for reproducibility (default: 42)
  • --patience: Early stopping patience (default: 10)
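
The --patience flag points to standard early stopping. A minimal sketch, assuming the monitored quantity is validation loss (lower is better) and that training stops after `patience` consecutive epochs without improvement. The EarlyStopping class and its min_delta parameter are hypothetical and not taken from src/train.py.

```python
import math


class EarlyStopping:
    """Hypothetical early-stopping helper; a lower monitored value is better."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best = math.inf
        self.best_epoch = -1
        self.bad_epochs = 0
        self.improved = False         # True if the last step set a new best value

    def step(self, value: float, epoch: int) -> bool:
        """Record one epoch's validation value; return True when training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.best_epoch = epoch
            self.bad_epochs = 0
            self.improved = True      # a caller would save best_model.pth at this point
            return False
        self.improved = False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2)
    for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.96, 0.5]):
        if stopper.step(val_loss, epoch):
            print(f"stopping at epoch {epoch}; best epoch {stopper.best_epoch}")
            break
    # prints: stopping at epoch 3; best epoch 1
```

Saving the checkpoint only when `improved` is True means the file on disk always holds the best epoch seen so far, not the last one trained.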

Evaluation

Evaluate a trained model:

python src/evaluate.py \
    --model_path models/best_model.pth \
    --data_dir "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF" \
    --csv_file "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv" \
    --class_mapping models/class_mapping.json \
    --output_dir evaluation \
    --split test

Arguments:

  • --model_path: Path to trained model checkpoint
  • --data_dir: Directory containing images
  • --csv_file: Path to CSV file with labels
  • --class_mapping: Path to class mapping JSON file
  • --output_dir: Directory to save evaluation results (default: evaluation)
  • --split: Dataset split to evaluate (val or test, default: test)

Output:

  • Accuracy, Recall (macro), F1 (macro)
  • Per-class metrics
  • Confusion matrix (normalized and raw)
  • Classification report
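
To make the macro averages concrete, here is a dependency-free sketch of how accuracy, macro recall, macro F1, and a row-normalized confusion matrix can be computed from integer class labels. It follows the usual definitions (macro = unweighted mean over classes; a class with no support or no predictions contributes 0). src/evaluate.py may use scikit-learn instead, the row normalization is an assumption, and the function names are hypothetical.

```python
from typing import Dict, List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def summarize(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> Dict[str, object]:
    cm = confusion_matrix(y_true, y_pred, num_classes)
    total = sum(map(sum, cm))
    recalls, f1s = [], []
    for c in range(num_classes):
        tp = cm[c][c]
        support = sum(cm[c])                                   # true samples of class c
        predicted = sum(cm[r][c] for r in range(num_classes))  # samples predicted as c
        recall = tp / support if support else 0.0
        precision = tp / predicted if predicted else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        recalls.append(recall)
        f1s.append(f1)
    # Row-normalized matrix: each row shows how one true phase was distributed.
    normalized = [[v / sum(row) if sum(row) else 0.0 for v in row] for row in cm]
    return {
        "accuracy": sum(cm[i][i] for i in range(num_classes)) / total,
        "macro_recall": sum(recalls) / num_classes,
        "macro_f1": sum(f1s) / num_classes,
        "per_class_recall": recalls,
        "per_class_f1": f1s,
        "confusion_matrix": cm,
        "confusion_matrix_normalized": normalized,
    }


report = summarize([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)
print(round(report["accuracy"], 3), round(report["macro_f1"], 3))  # 0.667 0.656
```

Macro averaging weights every phase equally, so a rare phase that is often misclassified lowers macro F1 even when overall accuracy stays high.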

Inference

Classify a single image:

python src/inference.py \
    --image_path path/to/image.jpg \
    --model_path models/best_model.pth \
    --class_mapping models/class_mapping.json \
    --output_dir results

Arguments:

  • --image_path: Path to image file
  • --model_path: Path to trained model checkpoint
  • --class_mapping: Path to class mapping JSON file
  • --output_dir: Directory to save results (optional)
  • --no_visualize: Flag to skip creating the visualization

Output:

  • Predicted phenological phase
  • Confidence score
  • Probabilities for all classes
  • Visualization showing the prediction and class probabilities (unless --no_visualize is set)
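
The confidence score and per-class probabilities are typically obtained by applying softmax to the model's output logits. A dependency-free sketch, assuming class_mapping.json maps string indices to phase names (for example {"0": "flowering", ...}); the real file layout and the names softmax and decode_prediction are assumptions.

```python
import json
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_prediction(logits: Sequence[float], idx_to_class: Dict[str, str]) -> dict:
    """Turn raw logits into a phase name, a confidence score and per-class probabilities."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {
        "phase": idx_to_class[str(best)],  # JSON object keys are always strings
        "confidence": probs[best],
        "probabilities": {idx_to_class[str(i)]: p for i, p in enumerate(probs)},
    }


mapping = json.loads('{"0": "flowering", "1": "fruiting", "2": "vegetative"}')
print(decode_prediction([0.2, 2.1, -1.0], mapping)["phase"])  # fruiting
```

In the PyTorch pipeline the logits would come from the model in eval() mode, so that the Dropout layers in the classification head are disabled during inference.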

API Server

Start the FastAPI server:

# Set environment variables (POSIX shell; in PowerShell use $env:MODEL_PATH = "models/best_model.pth")
export MODEL_PATH=models/best_model.pth
export CLASS_MAPPING_PATH=models/class_mapping.json

# Start server
python -m uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload

Endpoints:

  • GET /: API information
  • GET /health: Health check
  • GET /classes: List available classes
  • POST /classify: Classify single image
  • POST /classify/batch: Classify multiple images (max 10)

Example request:

curl -X POST "http://localhost:8000/classify" \
    -H "Content-Type: multipart/form-data" \
    -F "file=@path/to/image.jpg"
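
For scripted use without curl, the same request can be sent with Python's standard library. A sketch, assuming the server reads the upload from a form field named file (as in the curl call) and replies with JSON; build_multipart and classify_image are hypothetical helpers, and the response fields are not specified here.

```python
import json
import mimetypes
import os
import urllib.request
import uuid
from typing import Tuple


def build_multipart(field: str, filename: str, content: bytes) -> Tuple[bytes, str]:
    """Encode one file as a multipart/form-data body; returns (body, Content-Type value)."""
    boundary = uuid.uuid4().hex
    mime = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {mime}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + content + tail, f"multipart/form-data; boundary={boundary}"


def classify_image(path: str, url: str = "http://localhost:8000/classify") -> dict:
    """POST one image to the running server and decode the JSON reply."""
    with open(path, "rb") as fh:
        body, content_type = build_multipart("file", os.path.basename(path), fh.read())
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": content_type}, method="POST"
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, `classify_image("path/to/image.jpg")` sends the same request as the curl example. The boundary is generated per request so it cannot collide with bytes inside the image.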

Testing

Run all tests:

pytest tests/ -v

Run specific test categories:

pytest tests/ -v -m unit           # Unit tests only
pytest tests/ -v -m integration    # Integration tests only

Project Structure

.
├── src/
│   ├── __init__.py
│   ├── data_loader.py      # Dataset and data loading
│   ├── model.py            # ResNet model definition
│   ├── train.py            # Training script
│   ├── evaluate.py         # Evaluation script
│   ├── inference.py        # Inference script
│   ├── api.py              # FastAPI application
│   └── utils.py            # Utility functions
├── tests/
│   ├── test_data_loader.py
│   ├── test_train.py
│   ├── test_evaluate.py
│   └── test_inference.py
├── data/                   # Dataset directory
├── models/                 # Saved models
├── specs/                  # Feature specifications
├── requirements.txt        # Python dependencies
├── pytest.ini              # Pytest configuration
└── README.md               # This file

Model Architecture

  • Base: ResNet50 pretrained on ImageNet
  • Classification Head:
    • Dropout (0.5)
    • Linear (2048 → 512)
    • ReLU
    • Dropout (0.3)
    • Linear (512 → num_classes)
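
The size of the classification head follows directly from the layer dimensions above: Dropout and ReLU have no parameters, and each Linear layer has in × out weights plus out biases. A small helper (hypothetical, for illustration only) makes the arithmetic explicit:

```python
def head_param_count(num_classes: int, in_features: int = 2048, hidden: int = 512) -> int:
    """Trainable parameters in the classification head (Dropout and ReLU add none)."""
    fc1 = in_features * hidden + hidden       # Linear(2048 -> 512): weights + biases
    fc2 = hidden * num_classes + num_classes  # Linear(512 -> num_classes): weights + biases
    return fc1 + fc2


print(head_param_count(3))  # 1050627 for three phases (vegetative, flowering, fruiting)
```

Almost all of the head's parameters sit in the first Linear layer, so the number of phases barely changes the head's size.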

Performance Requirements

  • Accuracy: >90% on test set
  • Training Time: <1 hour for standard dataset
  • Inference Time: <1 second per image
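
The inference-time target can be checked empirically by timing a prediction callable over sample inputs. A sketch, assuming any callable that classifies one already-loaded input; measure_latency is hypothetical. A few warm-up calls are excluded from timing because first calls (lazy initialization, CUDA setup) are usually slower.

```python
import math
import statistics
import time
from typing import Callable, Dict, Sequence


def measure_latency(
    predict: Callable[[object], object], inputs: Sequence[object], warmup: int = 3
) -> Dict[str, float]:
    """Time predict(x) for each input and report mean, p95 and max latency in seconds."""
    for x in list(inputs)[:warmup]:
        predict(x)                                     # warm-up calls are not timed
    timings = []
    for x in inputs:
        start = time.perf_counter()
        predict(x)
        timings.append(time.perf_counter() - start)
    ordered = sorted(timings)
    p95 = ordered[math.ceil(0.95 * len(ordered)) - 1]  # nearest-rank 95th percentile
    return {"mean_s": statistics.mean(timings), "p95_s": p95, "max_s": ordered[-1]}


stats = measure_latency(lambda x: sum(range(10_000)), range(20))
print(stats, "meets <1 s target:", stats["max_s"] < 1.0)
```

Reporting the maximum alongside the mean matters for a per-image target: a low mean can hide occasional slow calls.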

Data Format

The CSV file should have the following format:

image_path,phase
images/img001.jpg,vegetative
images/img002.jpg,flowering
images/img003.jpg,fruiting
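
A sketch of reading and validating this file with the standard csv module, assuming class indices are assigned in alphabetical order of phase names. How src/data_loader.py actually builds class_mapping.json is not specified here, and load_labels is a hypothetical name.

```python
import csv
import io
from typing import Dict, List, TextIO, Tuple

REQUIRED_COLUMNS = {"image_path", "phase"}


def load_labels(handle: TextIO) -> Tuple[List[Tuple[str, str]], Dict[str, int]]:
    """Read (image_path, phase) rows and build a phase -> index mapping."""
    reader = csv.DictReader(handle)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"CSV is missing columns: {sorted(missing)}")
    rows = []
    for line_no, row in enumerate(reader, start=2):  # line 1 is the header
        path = (row["image_path"] or "").strip()
        phase = (row["phase"] or "").strip()
        if not path or not phase:
            raise ValueError(f"line {line_no}: empty image_path or phase")
        rows.append((path, phase))
    # Assumption: indices follow the alphabetical order of phase names.
    class_to_idx = {name: i for i, name in enumerate(sorted({p for _, p in rows}))}
    return rows, class_to_idx


example = io.StringIO(
    "image_path,phase\n"
    "images/img001.jpg,vegetative\n"
    "images/img002.jpg,flowering\n"
    "images/img003.jpg,fruiting\n"
)
rows, class_to_idx = load_labels(example)
print(class_to_idx)  # {'flowering': 0, 'fruiting': 1, 'vegetative': 2}
```

Sorting the phase names makes the mapping independent of row order, so the same CSV content always produces the same indices.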

Reproducibility

The project supports reproducibility through:

  • Random seed management (default: 42)
  • Deterministic training
  • Model checkpointing
  • Class mapping versioning
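
Seed management like the one listed above is commonly implemented as a single helper called at the start of training. A sketch, assuming Python's random, NumPy, and PyTorch are the random number sources in use; NumPy and PyTorch are seeded only if installed, and set_seed is a hypothetical name. Even with these settings, some GPU operations can remain nondeterministic.

```python
import os
import random


def set_seed(seed: int = 42) -> None:
    """Seed every RNG the training pipeline is assumed to use."""
    random.seed(seed)
    # Only affects Python processes started after this call (e.g. spawned workers).
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)        # silently ignored when CUDA is unavailable
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # autotuning could pick different kernels per run
    except ImportError:
        pass


set_seed(42)
```

Calling this once before building the data splits and the model keeps the shuffling, weight initialization of the new head, and Dropout masks repeatable across runs on the same hardware.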

License

This project is part of a supervised learning implementation for phenology classification.

Citation

If you use this code, please cite the project specifications in specs/1-phenology-classifier/.