# Phenology/Code/Supervised_learning/Nocciola/EfficientNetV2_Nocciola.py
# 2025-11-25 11:30:37 +01:00 - 873 lines - 33 KiB - Python
"""
EfficientNetV2 Transfer Learning for Phenological Phase Classification - Hazelnut/Artichoke
This code was originally developed in Google Colab and has been adapted for Visual Studio Code.
Dataset Path: C:/Users/sof12/Desktop/ML/Datasets/Nocciola/GBIF
Objective: Predict phenological phase R (reproductive)
"""
#----------------- IMPORTS -----------------
import os
import shutil
import random
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
import json
# ----------------- CONFIG -----------------
PROJECT_PATH = ""  # Project root; an empty string means paths are relative to the CWD
IMAGES_DIR = PROJECT_PATH  # The images are in the main project directory
CSV_PATH = os.path.join(PROJECT_PATH, 'tags.csv')  # Main CSV (must contain 'id_img' and the phase column)
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_efficientnetv2_fase_V_Combi(AV)')
# Input resolution fed to EfficientNetV2B0. Keras' default input size for B0 is
# 224x224; 480x480 is a deliberate upscaling that costs memory (hence the small batch).
IMG_SIZE = (480, 480)
BATCH_SIZE = 8  # Reduced due to larger image size
SEED = 42  # Seed for reproducibility
# Per-class split fractions; the test set receives whatever train/val leave over.
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
# NOTE(review): not read by the code visible here; re-splitting is controlled by
# the --force_split CLI flag (default False) instead.
FORCE_SPLIT = True
# ----------------- Utilities -----------------
def set_seed(seed=42):
    """Seed the global RNGs of Python's ``random``, NumPy and TensorFlow.

    Args:
        seed: Integer seed applied to all three generators.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
def analyze_class_distribution(df, column_name):
    """Print the class distribution of ``column_name`` and flag rare classes.

    Args:
        df: DataFrame with one row per sample.
        column_name: Label column to analyse. Missing values are skipped by
            ``value_counts`` but still count towards the total.

    Returns:
        List of the class labels that have at least ``min_samples`` (2)
        samples, ordered from most to least frequent.
    """
    print(f"\n Analyzing class distribution for column: '{column_name}'")
    min_samples = 2  # Recommended minimum threshold
    counts = df[column_name].value_counts()
    total = len(df)
    real_split = []
    print(f" Total samples: {total}")
    print(f" Number of classes: {len(counts)}")
    print(f" Class distribution:")
    for label, count in counts.items():
        percentage = (count / total) * 100
        if count >= min_samples:
            # Collect every sufficiently populated class.
            real_split.append(label)
        print(f" - {label}: {count} samples ({percentage:.1f}%)")
    # Report classes that fall below the threshold.
    small_classes = counts[counts < min_samples]
    if len(small_classes) > 0:
        print(f"\n Classes with less than {min_samples} samples:")
        for label, count in small_classes.items():
            print(f" - {label}: {count} samples")
        print(f"\n Recommendations:")
        print(f" 1. Consider collecting more data for these classes")
        print(f" 2. Or merge similar classes")
        print(f" 3. Or use specific data augmentation techniques")
    return real_split
def safe_read_csv(path):
    """Read a CSV file as UTF-8, falling back to Latin-1 if decoding fails.

    Args:
        path: Path to the CSV file.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f'CSV not found: {path}')
    try:
        return pd.read_csv(path, encoding='utf-8')
    except UnicodeDecodeError:
        # Latin-1 (a.k.a. ISO-8859-1) maps every byte to a code point, so this
        # decode cannot fail; any remaining error is a genuine parse problem
        # and is propagated instead of being retried with an identical codec.
        return pd.read_csv(path, encoding='latin-1')
def resolve_image_path(images_dir, img_id):
    """Return the path of the image file for ``img_id`` inside ``images_dir``.

    The id is first tried verbatim. If that is not an existing file, its
    extension is stripped and common image extensions are tried in turn.

    Args:
        images_dir: Directory the id is resolved against.
        img_id: Image identifier (file name, with or without extension).

    Returns:
        The resolved file path, or ``None`` when ``img_id`` is missing/blank
        or no matching regular file exists.
    """
    if pd.isna(img_id) or str(img_id).strip() == '':
        return None
    img_id = str(img_id).strip()
    direct_path = os.path.join(images_dir, img_id)
    # isfile (not exists): a directory whose name matches the id is not an
    # image and would make the later shutil.copy2 fail.
    if os.path.isfile(direct_path):
        return direct_path
    stem = os.path.splitext(img_id)[0]
    for ext in ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'):
        candidate = os.path.join(images_dir, stem + ext)
        if os.path.isfile(candidate):
            return candidate
    return None
def _filter_labelled_images(df, images_dir, column_name, min_samples=3):
    """Keep rows with a non-blank label, an existing image and a populated class.

    Args:
        df: Input DataFrame with an ``id_img`` column and ``column_name``.
        images_dir: Directory used to resolve ``id_img``.
        column_name: Label column.
        min_samples: Classes with fewer rows than this are dropped.

    Returns:
        (df_final, labels): the filtered DataFrame and its labels in order of
        first appearance.

    Raises:
        ValueError: if no image can be found or no class keeps ``min_samples`` rows.
    """
    print(f"Initial data: {len(df)} rows")
    df_valid = df.dropna(subset=[column_name]).copy()
    # astype(str) so a numeric label column does not break the .str accessor.
    df_valid = df_valid[df_valid[column_name].astype(str).str.strip() != '']
    print(f"With valid phase: {len(df_valid)} rows")
    found = []
    for img_id in df_valid['id_img']:
        exists = resolve_image_path(images_dir, img_id) is not None
        if not exists:
            print(f"Image not found: {img_id}")
        found.append(exists)
    if not any(found):
        raise ValueError("No valid images found")
    df_final = df_valid.loc[found]
    print(f"With existing images: {len(df_final)} rows")
    fase_counts = df_final[column_name].value_counts()
    print(f"\n Distribution of phase:")
    for fase, count in fase_counts.items():
        print(f" - {fase}: {count} images")
    valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
    if len(valid_phases) < len(fase_counts):
        excluded = fase_counts[fase_counts < min_samples].index.tolist()
        print(f"Excluded phases with less than {min_samples} samples: {excluded}")
        df_final = df_final[df_final[column_name].isin(valid_phases)]
        print(f"After filtering: {len(df_final)} rows, {len(valid_phases)} classes")
    if not valid_phases:
        raise ValueError(f"No phase has at least {min_samples} samples")
    labels = df_final[column_name].unique().tolist()
    print(f"Final classes: {labels}")
    return df_final, labels


def _split_class_stratified(df_class, split, seed):
    """Shuffle one class and cut it into train / val / test (test = remainder)."""
    n_class = len(df_class)
    n_train = int(n_class * split['train'])
    n_val = int(n_class * split['val'])
    shuffled = df_class.sample(frac=1, random_state=seed).reset_index(drop=True)
    return (shuffled.iloc[:n_train],
            shuffled.iloc[n_train:n_train + n_val],
            shuffled.iloc[n_train + n_val:])


def _split_class_av_first(df_class, split, seed):
    """Split one class so that NocciolaAV images fill its test quota first.

    The test quota (the remainder after train/val) is taken from AV images and
    topped up with non-AV images if there are not enough. Every image not used
    for test is pooled, shuffled and divided into train and validation, so no
    sample is dropped.
    """
    n_class = len(df_class)
    n_train = int(n_class * split['train'])
    n_val = int(n_class * split['val'])
    n_test = n_class - n_train - n_val
    is_av = df_class['id_img'].astype(str).str.contains('NocciolaAV', case=False, na=False, regex=False)
    av = df_class[is_av].sample(frac=1, random_state=seed).reset_index(drop=True)
    non_av = df_class[~is_av].sample(frac=1, random_state=seed).reset_index(drop=True)
    if len(av) >= n_test:
        test = av.iloc[:n_test]
        pool = pd.concat([av.iloc[n_test:], non_av], ignore_index=True)
    else:
        needed = n_test - len(av)
        test = pd.concat([av, non_av.iloc[:needed]], ignore_index=True)
        pool = non_av.iloc[needed:]
    test = test.sample(frac=1, random_state=seed).reset_index(drop=True)
    pool = pool.sample(frac=1, random_state=seed).reset_index(drop=True)
    # len(pool) == n_train + n_val, so validation gets exactly n_val rows.
    return pool.iloc[:n_train], pool.iloc[n_train:], test


def _copy_split_subset(subdf, images_dir, out_dir, subset_name, column_name):
    """Copy the images of one split into ``out_dir/<subset_name>/<label>/``.

    Also writes ``missing_<subset_name>.txt`` in the current working directory
    listing ids that could not be resolved and sources that failed to copy.

    Returns:
        Number of images copied.
    """
    copied, missing = 0, 0
    failed = []
    miss = []
    used_destinations = set()
    for _, row in subdf.iterrows():
        src = resolve_image_path(images_dir, row['id_img'])
        if not src:
            missing += 1
            miss.append(row['id_img'])
            continue
        fase = str(row[column_name])
        # Create the destination folder if it does not exist (just in case).
        dest_dir = os.path.join(out_dir, subset_name, fase)
        os.makedirs(dest_dir, exist_ok=True)
        # Keep the original file name (with its real extension).
        original_filename = os.path.basename(src)
        dst = os.path.join(dest_dir, original_filename)
        # Guard against the classic Windows 260-character path limit.
        if len(dst) > 260:
            print(f" Ruta demasiado larga ({len(dst)} caracteres): {dst}")
            ext = os.path.splitext(original_filename)[1]
            short_name = f"img_{copied:04d}{ext}"
            dst = os.path.join(dest_dir, short_name)
            print(f" Usando nombre corto: {short_name}")
        # Distinct ids may share a basename (e.g. same name in two source
        # folders); suffix later ones instead of silently overwriting.
        base, ext = os.path.splitext(dst)
        suffix = 1
        while os.path.normcase(dst) in used_destinations:
            dst = f"{base}_{suffix}{ext}"
            suffix += 1
        used_destinations.add(os.path.normcase(dst))
        try:
            if not os.path.exists(src):
                print(f" Archivo origen no existe: {src}")
                missing += 1
                miss.append(row['id_img'])
                continue
            shutil.copy2(src, dst)
            copied += 1
        except OSError as e:
            print(f" Error copying {src} to {dst}: {e}")
            print(f" - Source exists: {os.path.exists(src)}")
            print(f" - Dest dir exists: {os.path.exists(dest_dir)}")
            print(f" - Source path length: {len(src)}")
            print(f" - Dest path length: {len(dst)}")
            failed.append(src)
            missing += 1
    miss_file_path = os.path.join(os.getcwd(), f'missing_{subset_name}.txt')
    with open(miss_file_path, 'w', encoding='utf-8') as f:
        f.write(f"Missing images in {subset_name}:\n")
        f.writelines(f"{item}\n" for item in miss)
        f.write(f"\n\n Failed to copy:\n")
        f.writelines(f"{item}\n" for item in failed)
    print(f" {subset_name}: {copied} images copied, {missing} failed")
    return copied


def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED, column_name='fase R', test_AV=False, class_split=None):
    """Split labelled images per class and copy them into a Keras folder tree.

    The layout created is ``out_dir/{train,val,test}/<label>/<image>``, as
    consumed by ``ImageDataGenerator.flow_from_directory``.

    Args:
        df: DataFrame with an ``id_img`` column and the label column.
        images_dir: Directory the image ids are resolved against.
        out_dir: Root of the folder tree to create.
        split: Fractions for ``'train'`` and ``'val'``; each class's test set
            receives the remainder.
        seed: Seed for the global RNGs and for every pandas shuffle.
        column_name: Label column.
        test_AV: If True, images whose id contains ``NocciolaAV`` are used
            first to fill each class's test quota.
        class_split: Accepted for backward compatibility; currently unused.

    Returns:
        (train_df, val_df, test_df) DataFrames describing the split.

    Raises:
        ValueError: if no usable images or classes remain after filtering.
    """
    set_seed(seed)
    df_final, labels = _filter_labelled_images(df, images_dir, column_name)

    if test_AV:
        print("\n=== Using test_AV mode: NocciolaAV images prioritized for test ===")
        split_one_class = _split_class_av_first
    else:
        print("\n=== Standard stratified split ===")
        split_one_class = _split_class_stratified

    # Split each class separately to preserve class proportions.
    train_dfs, val_dfs, test_dfs = [], [], []
    for label in labels:
        df_class = df_final[df_final[column_name] == label]
        train_class, val_class, test_class = split_one_class(df_class, split, seed)
        train_dfs.append(train_class)
        val_dfs.append(val_class)
        test_dfs.append(test_class)
        print(f" Class '{label}': {len(df_class)} total -> Train: {len(train_class)}, Val: {len(val_class)}, Test: {len(test_class)}")

    train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)
    val_df = pd.concat(val_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)
    test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True)
    print(f"\nFinal split (stratified):")
    print(f" - Training: {len(train_df)} images")
    print(f" - Validation: {len(val_df)} images")
    print(f" - Test: {len(test_df)} images")
    print(f"\nClass distribution verification:")
    for subset_name, subset_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
        dist = subset_df[column_name].value_counts().sort_index()
        print(f" {subset_name}: {dict(dist)}")

    # Create every label folder in every split, even if a split gets no
    # images for that label, so all generators see the same class set.
    for part in ['train', 'val', 'test']:
        for label in labels:
            os.makedirs(os.path.join(out_dir, part, str(label)), exist_ok=True)

    _copy_split_subset(train_df, images_dir, out_dir, 'train', column_name)
    _copy_split_subset(val_df, images_dir, out_dir, 'val', column_name)
    _copy_split_subset(test_df, images_dir, out_dir, 'test', column_name)
    return train_df, val_df, test_df
def _parse_pipeline_args():
    """Parse the command-line options of the pipeline."""
    parser = argparse.ArgumentParser(description='EfficientNetV2 Transfer Learning for Nocciola')
    parser.add_argument('--csv_path', type=str, default=CSV_PATH,
                        help='Path to the CSV file with image assignments')
    parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
                        help='Directory with the images')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
                        help='Output directory for results')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of training epochs')
    parser.add_argument('--force_split', action='store_true',
                        help='Force recreation of the data split')
    parser.add_argument('--phase', type=str, default='fase R',
                        help='Phase of the analysis (V or R)')
    # A store_true flag whose default is already True can never be turned off,
    # so --no_test writes False into the same destination.
    parser.add_argument('--test', dest='test', action='store_true', default=True,
                        help='Run model using AV dataset for testing purposes')
    parser.add_argument('--no_test', dest='test', action='store_false',
                        help='Use a standard stratified split instead of prioritising AV images for test')
    return parser.parse_args()


def _configure_gpu_memory(memory_limit_mb=4096):
    """Cap the first GPU's memory via a logical device; no-op without a GPU."""
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("No GPU found, running on CPU")
        return
    try:
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_mb)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Virtual devices must be set before GPUs have been initialized.
        print(e)


def _create_and_save_split(df, args, split_dir, real_split):
    """Build the folder split and persist the per-split CSVs in output_dir."""
    train_df, val_df, test_df = prepare_image_folders(
        df, args.images_dir, split_dir, column_name=args.phase,
        test_AV=args.test, class_split=real_split)
    for name, part in (('train', train_df), ('val', val_df), ('test', test_df)):
        part.to_csv(os.path.join(args.output_dir, f'{name}_split.csv'), index=False)
    return train_df, val_df, test_df


def _load_or_create_split(df, args, split_dir, real_split):
    """Reuse the split in ``split_dir`` when possible, otherwise (re)create it."""
    if args.force_split and os.path.exists(split_dir):
        print("Eliminating existing split")
        shutil.rmtree(split_dir)
    if not os.path.exists(split_dir):
        print("\n=== Creating new data split ===")
        return _create_and_save_split(df, args, split_dir, real_split)
    print("\n === Reuse existing split ===")
    try:
        train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
        val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
        test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
        return train_df, val_df, test_df
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
        print("Could not load split files, recreating")
        # Remove the old tree so stale images do not mix with the new split.
        shutil.rmtree(split_dir)
        return _create_and_save_split(df, args, split_dir, real_split)


def _build_data_generators(split_dir):
    """Create train (augmented), validation and test directory iterators."""
    print("\n=== Creating data generators ===")
    # No rescale=1/255: Keras EfficientNetV2 models include their own
    # preprocessing (include_preprocessing=True by default) and expect raw
    # [0, 255] pixels; rescaling here too would squash inputs to ~[0, 0.004].
    train_datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest'
    )
    val_test_datagen = ImageDataGenerator()
    train_gen = train_datagen.flow_from_directory(
        os.path.join(split_dir, 'train'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        seed=SEED
    )
    # shuffle=False keeps predictions aligned with .classes / .filepaths.
    val_gen = val_test_datagen.flow_from_directory(
        os.path.join(split_dir, 'val'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )
    test_gen = val_test_datagen.flow_from_directory(
        os.path.join(split_dir, 'test'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )
    return train_gen, val_gen, test_gen


def _build_transfer_model(num_classes):
    """Build and compile a frozen EfficientNetV2B0 backbone with a new head.

    Returns:
        (model, base_model) so the backbone can be unfrozen for fine-tuning.
    """
    print("\n === Model construction ===")
    base_model = EfficientNetV2B0(
        weights='imagenet',
        include_top=False,
        input_shape=(*IMG_SIZE, 3)
    )
    base_model.trainable = False  # Freeze the backbone for the first phase
    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.4),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.4),
        layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    print(" Summary of the model:")
    model.summary()
    return model, base_model


def _compute_class_weight_dict(train_gen):
    """Return balanced class weights from the training labels, or None on error."""
    print("\n === Calculating class weights ===")
    try:
        # .classes holds the label index of every training file, so there is
        # no need to load (and augment) every image just to read labels.
        train_labels = np.asarray(train_gen.classes)
        present = np.unique(train_labels)
        weights = class_weight.compute_class_weight('balanced', classes=present, y=train_labels)
        class_weight_dict = {int(c): float(w) for c, w in zip(present, weights)}
        print(f"Class weights: {class_weight_dict}")
        return class_weight_dict
    except Exception as e:
        print(f"Error calculating class weights: {e}")
        return None


def _build_training_callbacks(output_dir):
    """Early stopping, best-model checkpoint and LR reduction on val_loss."""
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=7, restore_best_weights=True, verbose=1)
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(output_dir, 'best_model.keras'),
        save_best_only=True, monitor='val_loss', verbose=1)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7, verbose=1)
    return [early_stopping, model_checkpoint, reduce_lr]


def _run_initial_training(model, train_gen, val_gen, epochs, callbacks, class_weight_dict):
    """Train the head; retry without class weights if the first attempt fails.

    Returns:
        (history, class_weight_dict), where the dict becomes None when the
        unweighted fallback had to be used.
    """
    print(f"\n === Initial training ({epochs} epochs) ===")
    try:
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )
        print("Initial training completed successfully")
        return history, class_weight_dict
    except Exception as e:
        print(f"Error during training: {e}")
        print("Trying training without class weights")
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )
        return history, None


def _run_fine_tuning(model, base_model, train_gen, val_gen, history, callbacks,
                     class_weight_dict, fine_tune_at=200, fine_tune_epochs=10):
    """Unfreeze the top of the backbone and keep training at a low LR.

    The fine-tuning epochs are appended to ``history`` in place.
    """
    print("\n === Fine-tuning ===")
    base_model.trainable = True
    # Freeze the first `fine_tune_at` layers; only the layers after them
    # (the top of the backbone) are updated.
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False
    # Keep BatchNorm in inference mode: with small batches, updating its
    # statistics would destroy the pretrained ones.
    for layer in base_model.layers[fine_tune_at:]:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    initial_epoch = len(history.history['loss'])
    try:
        history_fine = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=initial_epoch + fine_tune_epochs,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )
        print("Fine-tuning completed successfully")
        for key, values in history.history.items():
            if key in history_fine.history:
                values.extend(history_fine.history[key])
    except Exception as e:
        print(f"Error during fine-tuning: {e}")
        print("Continuing with initial training model")


def _write_classification_report(y_true, y_pred, unique_test_classes, test_class_names,
                                 index_to_class, missing_names, output_dir):
    """Print and save a classification report (manual fallback on failure)."""
    print("\n === Classification Report ===")
    report_path = os.path.join(output_dir, 'classification_report.txt')
    missing_text = missing_names if missing_names else 'None'
    try:
        report = classification_report(
            y_true, y_pred,
            labels=unique_test_classes,
            target_names=test_class_names,
            output_dict=False,
            zero_division=0
        )
        print(report)
        with open(report_path, 'w') as f:
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing in test: {missing_text}\n\n")
            f.write(report)
    except Exception as e:
        print(f"Error in classification_report: {e}")
        print("Generating alternative report")
        from collections import Counter
        true_counts = Counter(np.asarray(y_true).tolist())
        pred_counts = Counter(np.asarray(y_pred).tolist())
        rows = []
        for class_idx in (int(c) for c in unique_test_classes):
            class_name = index_to_class[class_idx]
            rows.append(f"{class_name[:15]:15} | {true_counts.get(class_idx, 0):10} | {pred_counts.get(class_idx, 0):9}")
        accuracy = np.mean(y_true == y_pred)
        print("\n Manual distribution:")
        print("Class | True | Predicted")
        print("-" * 35)
        for row in rows:
            print(row)
        print(f"\nOverall accuracy: {accuracy:.4f}")
        with open(report_path, 'w') as f:
            f.write("MANUAL CLASSIFICATION REPORT\n")
            f.write("=" * 40 + "\n\n")
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing in test: {missing_text}\n\n")
            f.write("Class distribution:\n")
            f.write("Class | True | Predicted\n")
            f.write("-" * 35 + "\n")
            for row in rows:
                f.write(row + "\n")
            f.write(f"\nOverall accuracy: {accuracy:.4f}\n")


def _evaluate_on_test_set(model, test_gen, class_indices, history, output_dir):
    """Evaluate the best model on the test split and write reports and plots."""
    print("\n === Evaluation on test set ===")
    try:
        model.load_weights(os.path.join(output_dir, 'best_model.keras'))
        print(" Loaded best saved model")
    except Exception as e:  # missing or incompatible checkpoint
        print(f" Could not load best model: {e}")
        print(" Using current model")
    model.save(os.path.join(output_dir, 'final_model.keras'))
    print("Saved final model")

    if test_gen.samples == 0:
        print("Test set is empty; skipping test evaluation")
        print("\n === Generating visualizations ===")
        plot_training_history(history, output_dir)
        return

    test_gen.reset()
    y_pred_prob = model.predict(test_gen, verbose=1)
    y_pred = np.argmax(y_pred_prob, axis=1)
    y_true = test_gen.classes

    index_to_class = {v: k for k, v in class_indices.items()}
    # Only classes that occur in the ground truth or the predictions.
    unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
    test_class_names = [index_to_class[int(i)] for i in unique_test_classes]
    print(f"Classes in test set: {len(unique_test_classes)}")
    print(f"All trained classes: {len(class_indices)}")
    print(f"Classes present in test: {test_class_names}")

    missing_classes = sorted(set(range(len(class_indices))) - set(unique_test_classes.tolist()))
    missing_names = [index_to_class[i] for i in missing_classes]
    if missing_classes:
        print(f"Classes without samples in test: {missing_names}")

    _write_classification_report(y_true, y_pred, unique_test_classes, test_class_names,
                                 index_to_class, missing_names, output_dir)

    cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
    print(f"\n Confusion Matrix ({len(unique_test_classes)} classes):")
    print(cm)
    np.savetxt(os.path.join(output_dir, 'confusion_matrix.csv'),
               cm, delimiter=',', fmt='%d')

    print("\n === Generating visualizations ===")
    plot_training_history(history, output_dir)
    plot_confusion_matrix(cm, test_class_names, output_dir)
    # test_class_names[k] names class index unique_test_classes[k].
    plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, output_dir, unique_test_classes)

    print(f"\n=== Pipeline completed ===")
    print(f" Results saved in: {output_dir}")
    print(f" Final test accuracy: {np.mean(y_true == y_pred):.4f}")
    print(f" Classes evaluated: {len(unique_test_classes)}/{len(class_indices)}")
    if missing_classes:
        print(f"\n === Information about Imbalanced Classes ===")
        print(f" Classes without samples in test: {len(missing_classes)}")
        for missing_idx in missing_classes:
            missing_name = index_to_class[missing_idx]
            print(f" - {missing_name} (index {missing_idx})")
        print(f" Suggestion: Consider increasing the dataset or merging similar classes")


def main():
    """Run the EfficientNetV2 transfer-learning pipeline end to end.

    Parses CLI options, optionally caps GPU memory, loads and checks the tags
    CSV, creates or reuses the train/val/test folder split, trains a new head
    on a frozen EfficientNetV2B0, fine-tunes the top of the backbone, and
    writes the model, reports and plots to ``--output_dir``.

    Raises:
        ValueError: if the CSV lacks ``id_img`` or the phase column.
    """
    args = _parse_pipeline_args()
    print('\n === Start of the pipeline ===')
    print(f"Images directory: {args.images_dir}")
    print(f"CSV file: {args.csv_path}")
    print(f"Output directory: {args.output_dir}")
    _configure_gpu_memory()
    set_seed(SEED)
    os.makedirs(args.output_dir, exist_ok=True)

    print('\n === Loading data ===')
    df = safe_read_csv(args.csv_path)
    print(f'Total of records in CSV: {len(df)}')
    print(f'Available columns: {list(df.columns)}')
    missing = {'id_img', args.phase} - set(df.columns)
    if missing:
        raise ValueError(f'CSV must contain the columns: {missing}')
    real_split = analyze_class_distribution(df, args.phase)

    split_dir = os.path.join(args.output_dir, 'data_split')
    _load_or_create_split(df, args, split_dir, real_split)

    train_gen, val_gen, test_gen = _build_data_generators(split_dir)
    class_indices = train_gen.class_indices
    print(f'Class mapping: {class_indices}')
    with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
        json.dump(class_indices, f, indent=2)
    print(f"Samples per class:")
    print(f" - Training: {train_gen.samples}")
    print(f" - Validation: {val_gen.samples}")
    print(f" - Test: {test_gen.samples}")
    print(f" - Number of classes: {train_gen.num_classes}")

    model, base_model = _build_transfer_model(train_gen.num_classes)
    class_weight_dict = _compute_class_weight_dict(train_gen)
    callbacks = _build_training_callbacks(args.output_dir)
    history, class_weight_dict = _run_initial_training(
        model, train_gen, val_gen, args.epochs, callbacks, class_weight_dict)
    _run_fine_tuning(model, base_model, train_gen, val_gen, history, callbacks, class_weight_dict)
    _evaluate_on_test_set(model, test_gen, class_indices, history, args.output_dir)
def plot_training_history(history, output_dir):
    """Save accuracy and loss curves to ``output_dir/training_history.png``.

    Training curves are always drawn; validation curves only when the
    ``val_accuracy`` / ``val_loss`` keys exist in ``history.history``.
    Errors are printed and swallowed so plotting never aborts the pipeline,
    and the figure is closed even when plotting fails.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(12, 4))
        panels = (('accuracy', 'Model Accuracy', 'Accuracy'),
                  ('loss', 'Model Loss', 'Loss'))
        for position, (metric, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history.history[metric], label='Training')
            val_key = f'val_{metric}'
            if val_key in history.history:
                plt.plot(history.history[val_key], label='Validation')
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(ylabel)
            plt.legend()
            plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
        print("Saved training history plot")
    except Exception as e:
        print(f"Error creating training history plot: {e}")
    finally:
        if fig is not None:
            plt.close(fig)
def plot_confusion_matrix(cm, class_names, output_dir):
    """Save a heatmap of ``cm`` to ``output_dir/confusion_matrix.png``.

    Args:
        cm: Square integer confusion matrix (rows = true, columns = predicted).
        class_names: Tick labels, one per row/column of ``cm``.
        output_dir: Destination directory.

    Errors are printed and swallowed; the figure is always closed.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
        print("Saved confusion matrix plot")
    except Exception as e:
        print(f"Error creating confusion matrix plot: {e}")
    finally:
        if fig is not None:
            plt.close(fig)
def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
    """Save a grid of correct and incorrect test predictions.

    Up to ``n_examples // 2`` correct and ``n_examples // 2`` incorrect
    samples are drawn at random and read back from ``test_gen.filepaths``.

    Args:
        test_gen: Iterator exposing ``filepaths`` aligned with ``y_true``
            (i.e. created with ``shuffle=False``).
        y_true, y_pred: Integer class indices per test sample.
        class_names: Names to display. When ``unique_classes`` is given,
            ``class_names[k]`` is the name of class index ``unique_classes[k]``;
            otherwise ``class_names`` is indexed directly by class index.
        output_dir: Destination directory for ``prediction_examples.png``.
        unique_classes: Class indices that ``class_names`` refers to.
        n_examples: Maximum number of images in the grid.

    Errors are printed and swallowed; the figure is always closed.
    """
    fig = None
    try:
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        # Map class index -> display name; class_names may be a filtered list.
        if unique_classes is not None:
            name_of = {int(c): class_names[k] for k, c in enumerate(unique_classes)}
        else:
            name_of = dict(enumerate(class_names))

        correct_idx = np.flatnonzero(y_true == y_pred)
        incorrect_idx = np.flatnonzero(y_true != y_pred)
        n_correct = min(n_examples // 2, len(correct_idx))
        n_incorrect = min(n_examples // 2, len(incorrect_idx))
        empty = np.empty(0, dtype=int)
        selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else empty
        selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else empty
        # Integer dtype is essential: concatenating with a plain [] would
        # produce floats, which cannot index filepaths or the label arrays.
        selected_indices = np.concatenate([selected_correct, selected_incorrect]).astype(int)
        if len(selected_indices) == 0:
            print("There are no examples to show.")
            return

        n_show = len(selected_indices)
        cols = 4
        rows = (n_show + cols - 1) // cols
        fig = plt.figure(figsize=(15, 4 * rows))
        for i, idx in enumerate(selected_indices):
            plt.subplot(rows, cols, i + 1)
            # Read the image from disk: the iterator does not expose raw images.
            img = plt.imread(test_gen.filepaths[idx])
            plt.imshow(img)
            plt.axis('off')
            true_idx, pred_idx = int(y_true[idx]), int(y_pred[idx])
            true_label = name_of.get(true_idx, str(true_idx))
            pred_label = name_of.get(pred_idx, str(pred_idx))
            color = 'green' if true_idx == pred_idx else 'red'
            plt.title(f'True: {true_label}\nPredicted: {pred_label}',
                      color=color, fontsize=10)
        plt.suptitle('Prediction Examples (Green=Correct, Red=Incorrect)', fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'prediction_examples.png'), dpi=300, bbox_inches='tight')
        print("Saved prediction examples plot")
    except Exception as e:
        print(f"Error creating prediction examples plot: {e}")
    finally:
        if fig is not None:
            plt.close(fig)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()