# ResNet Phenology Classifier

A deep learning model that classifies plant images by phenological phase using the ResNet50 architecture.

## Features

- **Training**: Train a ResNet50 model on labeled plant images
- **Evaluation**: Comprehensive metrics, including accuracy, macro recall, macro F1, and a confusion matrix
- **Inference**: Classify new images with visual output
- **API**: REST API for single-image and batch classification
- **Reproducibility**: Random seed management and versioning

## Dataset

The model uses the dataset located at:

```
C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF
```

Dataset split: 70% training, 15% validation, 15% test.

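A 70/15/15 split is usually done per class, so that less common phenological phases appear in every subset. The sketch below shows one way to do such a stratified split with a fixed seed, using only the standard library. It is hypothetical: the helper name `stratified_split` is an assumption, and `src/data_loader.py` may split the data differently.

```python
import random
from collections import defaultdict


def stratified_split(samples, train_frac=0.70, val_frac=0.15, seed=42):
    """Split (image_path, phase) pairs into train/val/test lists, class by class.

    Hypothetical helper for illustration; not the project's data_loader code.
    """
    by_class = defaultdict(list)
    for path, phase in samples:
        by_class[phase].append((path, phase))

    rng = random.Random(seed)  # local RNG: unaffected by other uses of `random`
    train, val, test = [], [], []
    for phase in sorted(by_class):  # fixed class order keeps the result deterministic
        items = by_class[phase]
        rng.shuffle(items)
        n_train = round(len(items) * train_frac)
        n_val = round(len(items) * val_frac)
        train.extend(items[:n_train])
        val.extend(items[n_train:n_train + n_val])
        test.extend(items[n_train + n_val:])  # remainder goes to the test set
    return train, val, test
```

With 100 images in each of three phases, this gives 70/15/15 images per phase. A local `random.Random(seed)` keeps the split stable even when other code draws random numbers first.
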
## Installation

1. Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

2. Install dependencies:

```bash
pip install -r requirements.txt
```

## Usage

### Training

Train the model on your dataset:

```bash
python src/train.py \
  --data_dir "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF" \
  --csv_file "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv" \
  --output_dir models \
  --epochs 50 \
  --batch_size 32 \
  --lr 0.001
```

Quote Windows paths when running these commands from a bash-like shell. Otherwise bash treats the backslashes as escape characters.

**Arguments:**

- `--data_dir`: Directory containing the images
- `--csv_file`: Path to a CSV file with columns `image_path` and `phase`
- `--output_dir`: Directory where trained models are saved (default: `models`)
- `--epochs`: Number of training epochs (default: 50)
- `--batch_size`: Batch size (default: 32)
- `--lr`: Learning rate (default: 0.001)
- `--num_workers`: Number of data loader workers (default: 4)
- `--seed`: Random seed for reproducibility (default: 42)
- `--patience`: Early stopping patience (default: 10)

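The `--patience` flag implies an early-stopping rule. The minimal sketch below assumes that validation loss is the monitored quantity, which `train.py` may not actually use. `EarlyStopping` and `min_delta` are hypothetical names.

```python
class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` epochs.

    Hypothetical sketch: assumes lower is better (e.g. validation loss).
    """

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: reset the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, you would call `step(val_loss)` once per epoch and break out of the loop when it returns `True`.
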
### Evaluation

Evaluate a trained model:

```bash
python src/evaluate.py \
  --model_path models/best_model.pth \
  --data_dir "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF" \
  --csv_file "C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv" \
  --class_mapping models/class_mapping.json \
  --output_dir evaluation \
  --split test
```

**Arguments:**

- `--model_path`: Path to the trained model checkpoint
- `--data_dir`: Directory containing the images
- `--csv_file`: Path to the CSV file with labels
- `--class_mapping`: Path to the class mapping JSON file
- `--output_dir`: Directory where evaluation results are saved (default: `evaluation`)
- `--split`: Dataset split to evaluate (`val` or `test`; default: `test`)

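The documented `evaluate.py` flags map naturally onto `argparse`. The hypothetical sketch below mirrors the names and defaults listed above. It is an assumption which flags the real script requires, and `build_eval_parser` is an illustrative name.

```python
import argparse


def build_eval_parser():
    """Hypothetical parser mirroring the documented evaluate.py flags."""
    parser = argparse.ArgumentParser(description="Evaluate a trained phenology classifier.")
    # Marking these as required is an assumption; the README lists no defaults for them.
    parser.add_argument("--model_path", required=True, help="Path to trained model checkpoint")
    parser.add_argument("--data_dir", required=True, help="Directory containing images")
    parser.add_argument("--csv_file", required=True, help="CSV file with image_path,phase columns")
    parser.add_argument("--class_mapping", required=True, help="Path to class mapping JSON file")
    parser.add_argument("--output_dir", default="evaluation", help="Directory for evaluation results")
    parser.add_argument("--split", choices=["val", "test"], default="test",
                        help="Dataset split to evaluate")
    return parser
```

With `choices=["val", "test"]`, argparse rejects any other split with a usage error at startup. Without it, a bad value would only fail later in the data loader.
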
**Output:**

- Accuracy, macro recall, and macro F1
- Per-class metrics
- Confusion matrix (raw and normalized)
- Classification report

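To sanity-check the reported numbers, you can recompute the headline metrics from the true and predicted labels. This pure-Python sketch is an illustration, not the project's implementation. It builds the confusion matrix with true classes as rows and predicted classes as columns, and it treats undefined precision or recall as 0.

```python
def evaluate_predictions(y_true, y_pred, classes):
    """Accuracy, macro recall, macro F1, and raw + row-normalized confusion matrices."""
    idx = {c: i for i, c in enumerate(classes)}
    k = len(classes)
    cm = [[0] * k for _ in range(k)]  # cm[true][predicted]
    for t, p in zip(y_true, y_pred):
        cm[idx[t]][idx[p]] += 1

    recalls, f1s = [], []
    for i in range(k):
        tp = cm[i][i]
        fn = sum(cm[i]) - tp                       # true class i, predicted as another class
        fp = sum(cm[r][i] for r in range(k)) - tp  # predicted i, actually another class
        recall = tp / (tp + fn) if tp + fn else 0.0
        precision = tp / (tp + fp) if tp + fp else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        recalls.append(recall)
        f1s.append(f1)

    total = sum(map(sum, cm))
    cm_norm = []
    for row in cm:
        n = sum(row)
        cm_norm.append([v / n if n else 0.0 for v in row])  # each row sums to 1
    return {
        "accuracy": sum(cm[i][i] for i in range(k)) / total if total else 0.0,
        "recall_macro": sum(recalls) / k,  # unweighted mean over classes
        "f1_macro": sum(f1s) / k,
        "confusion_matrix": cm,
        "confusion_matrix_normalized": cm_norm,
    }
```

Macro averaging gives each phase equal weight. A rare phase with poor recall therefore lowers `recall_macro` and `f1_macro` even when overall accuracy stays high.
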
### Inference

Classify a single image:

```bash
python src/inference.py \
  --image_path path/to/image.jpg \
  --model_path models/best_model.pth \
  --class_mapping models/class_mapping.json \
  --output_dir results
```

**Arguments:**

- `--image_path`: Path to the image file
- `--model_path`: Path to the trained model checkpoint
- `--class_mapping`: Path to the class mapping JSON file
- `--output_dir`: Directory where results are saved (optional)
- `--no_visualize`: Skip creating the visualization

**Output:**

- Predicted phenological phase
- Confidence score
- Probabilities for all classes
- Visual output showing the prediction and probabilities

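A confidence score and per-class probabilities are typically obtained by applying a softmax to the model's output logits. This sketch assumes `inference.py` computes them that way, which is not confirmed here. `summarize_logits` is a hypothetical helper.

```python
import math


def summarize_logits(logits, classes):
    """Convert raw logits into a predicted phase, its confidence, and all class probabilities."""
    m = max(logits)  # subtract the max before exp() to avoid overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = {c: e / total for c, e in zip(classes, exps)}
    predicted = max(probs, key=probs.get)
    return {"phase": predicted, "confidence": probs[predicted], "probabilities": probs}
```

Subtracting the largest logit leaves the probabilities unchanged mathematically, but keeps `math.exp` from overflowing on large logits.
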
### API Server

Start the FastAPI server:

```bash
# Set environment variables
export MODEL_PATH=models/best_model.pth
export CLASS_MAPPING_PATH=models/class_mapping.json

# Start the server
python -m uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
```

The `--reload` option restarts the server when source files change. It is intended for development, so omit it in production.

**Endpoints:**

- `GET /`: API information
- `GET /health`: Health check
- `GET /classes`: List available classes
- `POST /classify`: Classify a single image
- `POST /classify/batch`: Classify multiple images (max. 10 per request)

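`POST /classify/batch` accepts at most 10 images, so a client with more images has to send several requests. The hypothetical helper below splits an iterable of image paths into batches that respect that limit.

```python
from itertools import islice


def chunks(items, size=10):
    """Yield consecutive lists of at most `size` items (default matches the 10-image limit)."""
    it = iter(items)
    while batch := list(islice(it, size)):  # stops when the iterator is exhausted
        yield batch
```

You can then send each yielded batch as one batch request. The form field name for batch uploads is not documented here, so check `src/api.py` before building the request.
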
**Example request:**

```bash
curl -X POST "http://localhost:8000/classify" \
  -F "file=@path/to/image.jpg"
```

With `-F`, curl sends a `multipart/form-data` request and generates the boundary itself, so you do not need to set the `Content-Type` header yourself.

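You can send the same request from Python using only the standard library. This sketch assumes the `file` form field from the `curl` example and a JSON response from `/classify`. `build_multipart` and `classify` are hypothetical helpers, not part of the project.

```python
import json
import mimetypes
import urllib.request
import uuid
from pathlib import Path


def build_multipart(field, filename, data, content_type="application/octet-stream", boundary=None):
    """Encode one file as a multipart/form-data body; return (body, Content-Type header value)."""
    boundary = boundary or uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()  # closing boundary ends the body
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def classify(image_path, url="http://localhost:8000/classify"):
    """POST one image to the running API and return the decoded JSON response (assumed format)."""
    path = Path(image_path)
    ctype = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
    body, header = build_multipart("file", path.name, path.read_bytes(), ctype)
    req = urllib.request.Request(url, data=body, headers={"Content-Type": header}, method="POST")
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

While the server is running, `classify("path/to/image.jpg")` returns the decoded response. A library such as `requests` (`requests.post(url, files={"file": f})`) makes this considerably shorter.
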
## Testing

Run all tests:

```bash
pytest tests/ -v
```

Run specific test categories:

```bash
pytest tests/ -v -m unit         # Unit tests only
pytest tests/ -v -m integration  # Integration tests only
```

## Project Structure

```
.
├── src/
│   ├── __init__.py
│   ├── data_loader.py    # Dataset and data loading
│   ├── model.py          # ResNet model definition
│   ├── train.py          # Training script
│   ├── evaluate.py       # Evaluation script
│   ├── inference.py      # Inference script
│   ├── api.py            # FastAPI application
│   └── utils.py          # Utility functions
├── tests/
│   ├── test_data_loader.py
│   ├── test_train.py
│   ├── test_evaluate.py
│   └── test_inference.py
├── data/                 # Dataset directory
├── models/               # Saved models
├── specs/                # Feature specifications
├── requirements.txt      # Python dependencies
├── pytest.ini            # Pytest configuration
└── README.md             # This file
```

## Model Architecture

- **Base**: ResNet50 pretrained on ImageNet
- **Classification Head**:
  - Dropout (0.5)
  - Linear (2048 → 512)
  - ReLU
  - Dropout (0.3)
  - Linear (512 → num_classes)

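As a rough size check, the head's trainable parameter count follows directly from the layer shapes listed above. Each `Linear(a → b)` layer has `a·b` weights plus `b` biases, and Dropout and ReLU have no parameters. The helper below is illustrative. The three-class example matches the phases in the sample CSV, but the real number of classes depends on the dataset.

```python
def head_param_count(num_classes, in_features=2048, hidden=512):
    """Trainable parameters in Linear(in_features -> hidden) + Linear(hidden -> num_classes)."""
    first = in_features * hidden + hidden     # weights + biases of the first Linear
    second = hidden * num_classes + num_classes  # weights + biases of the output Linear
    return first + second  # Dropout and ReLU contribute no parameters
```

For three classes, the head has 1,050,627 parameters, small compared with the roughly 23.5 million parameters in the ResNet50 convolutional backbone.
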
## Performance Requirements

- **Accuracy**: >90% on the test set
- **Training Time**: <1 hour for the standard dataset
- **Inference Time**: <1 second per image

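You can check the <1 second per image requirement with a simple wall-clock benchmark. This sketch times any callable after a few warm-up calls, for example a hypothetical function that wraps preprocessing plus a forward pass. `measure_latency` is an illustrative name.

```python
import statistics
import time


def measure_latency(fn, inputs, warmup=3):
    """Return (median, max) wall-clock seconds per call of fn over a list of inputs.

    On GPU, fn should call torch.cuda.synchronize() before returning, because
    CUDA kernels run asynchronously and would otherwise be under-timed.
    """
    for x in inputs[:warmup]:
        fn(x)  # warm-up calls (lazy initialization, caches) are not timed
    timings = []
    for x in inputs:
        start = time.perf_counter()
        fn(x)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings), max(timings)
```

The median is less sensitive to occasional outliers. The maximum shows the worst case against the 1-second budget.
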
## Data Format

The CSV file should have the following format:

```csv
image_path,phase
images/img001.jpg,vegetative
images/img002.jpg,flowering
images/img003.jpg,fruiting
```

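A minimal loader for this format, using only the standard library, might validate the header and derive a phase-to-index mapping. The mapping layout here (alphabetical order, `{phase: index}`) is an assumption for illustration. The project's `class_mapping.json` may be structured differently.

```python
import csv


def load_labels(f):
    """Read (image_path, phase) rows from an iterable of CSV lines with an image_path,phase header."""
    reader = csv.DictReader(f)
    if reader.fieldnames is None or not {"image_path", "phase"} <= set(reader.fieldnames):
        raise ValueError(f"expected columns image_path, phase; got {reader.fieldnames}")
    return [(row["image_path"].strip(), row["phase"].strip()) for row in reader]


def build_class_mapping(samples):
    """Map each phase to a stable integer index (alphabetical order is a hypothetical choice)."""
    return {phase: i for i, phase in enumerate(sorted({p for _, p in samples}))}
```

To load a file, open it with `open(csv_path, newline="", encoding="utf-8")` and pass the file object to `load_labels`. The `csv` module documentation recommends `newline=""` when reading.
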
## Reproducibility

The project ensures reproducibility through:

- Random seed management (default: 42)
- Deterministic training
- Model checkpointing
- Class mapping versioning

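Seed management typically covers every random number generator used in the pipeline. The sketch below is a hypothetical `set_seed` helper, not the project's code. It seeds Python's `random` module, seeds NumPy and PyTorch when they are installed, and enables cuDNN's deterministic mode.

```python
import os
import random


def set_seed(seed=42):
    """Seed the RNGs a PyTorch training run typically touches (hypothetical helper)."""
    random.seed(seed)
    # Only affects Python processes started after this point, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU
        # Trade some speed for repeatable cuDNN convolution results.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch not installed: nothing to seed
```

Even with these settings, some GPU operations remain nondeterministic. When `--num_workers` is greater than 0, PyTorch's reproducibility notes also recommend seeding DataLoader workers with `worker_init_fn` and a seeded `generator`.
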
## License

This project is part of a supervised learning implementation for phenology classification.

## Citation

If you use this code, please cite the project specifications in `specs/1-phenology-classifier/`.