""" ResNet50 Transfer Learning for Phenological Phase Classification - Hazelnut/Artichoke This code was originally developed in Google Colab and has been adapted for Visual Studio Code. Dataset Path: C:/Users/sof12/Desktop/ML/Datasets/Nocciola/GBIF Objective: Predict phenological phase R (reproductive) """ #----------------- IMPORTS ----------------- import os import shutil import random import argparse from pathlib import Path import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import tensorflow as tf from tensorflow.keras import layers, models from tensorflow.keras.applications import ResNet50 from tensorflow.keras.preprocessing.image import ImageDataGenerator from sklearn.metrics import classification_report, confusion_matrix from sklearn.utils import class_weight import json # ----------------- CONFIG ----------------- PROJECT_PATH = "" IMAGES_DIR = PROJECT_PATH # The images are in the main project directory CSV_PATH = os.path.join(PROJECT_PATH, 'tags.csv') # Main CSV OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_resnet50_fase_V_Combi(AV)') IMG_SIZE = (224, 224) # Recommended for ResNet50 BATCH_SIZE = 16 # Standard batch size SEED = 42 # Seed for reproducibility SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15} FORCE_SPLIT = True # Whether to force re-creation of the data split # ----------------- Utilities ----------------- def set_seed(seed=42): """Set seed for reproducibility""" random.seed(seed) np.random.seed(seed) tf.random.set_seed(seed) def analyze_class_distribution(df, column_name): """Analyze class distribution and detect imbalances""" print(f"\n Analyzing class distribution for column: '{column_name}'") min_samples = 2 # Recommended minimum threshold # Count by class counts = df[column_name].value_counts() total = len(df) real_split = [] print(f" Total samples: {total}") print(f" Number of classes: {len(counts)}") print(f" Class distribution:") # Show detailed statistics for clase, count in counts.items(): percentage = (count / total) * 100 if count >= min_samples: real_split = clase print(f" - {clase}: {count} samples ({percentage:.1f}%)") # Detect problematic classes small_classes = counts[counts < min_samples] if len(small_classes) > 0: print(f"\n Classes with less than {min_samples} samples:") for clase, count in small_classes.items(): print(f" - {clase}: {count} samples") print(f"\n Recommendations:") print(f" 1. Consider collecting more data for these classes") print(f" 2. Or merge similar classes") print(f" 3. Or use specific data augmentation techniques") return real_split def safe_read_csv(path): """Read CSV with encoding handling""" if not os.path.exists(path): raise FileNotFoundError(f'CSV not found: {path}') try: df = pd.read_csv(path, encoding='utf-8') except UnicodeDecodeError: try: df = pd.read_csv(path, encoding='latin-1') except: df = pd.read_csv(path, encoding='iso-8859-1') return df def resolve_image_path(images_dir, img_id): """Resolve the full path of an image""" if pd.isna(img_id) or str(img_id).strip() == '': return None img_id = str(img_id).strip() # Verify if the image exists directly direct_path = os.path.join(images_dir, img_id) if os.path.exists(direct_path): return direct_path # Try common extensions stem = os.path.splitext(img_id)[0] for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']: img_path = os.path.join(images_dir, stem + ext) if os.path.exists(img_path): return img_path return None def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED, column_name='fase R', test_AV=False, class_split=[]): """Create folder structure for flow_from_directory""" set_seed(seed) # Filter only rows with valid phase and existing images print(f"Initial data: {len(df)} rows") # Filter rows with valid phase df_valid = df.dropna(subset=[column_name]).copy() df_valid = df_valid[df_valid[column_name].str.strip() != ''] print(f"With valid phase: {len(df_valid)} rows") # Verify existence of images valid_rows = [] for _, row in df_valid.iterrows(): img_path = resolve_image_path(images_dir, row['id_img']) if img_path: valid_rows.append(row) else: print(f"Image not found: {row['id_img']}") if not valid_rows: raise ValueError("No valid images found") df_final = pd.DataFrame(valid_rows) print(f"With existing images: {len(df_final)} rows") # Show class distribution fase_counts = df_final[column_name].value_counts() print(f"\n Distribution of phase:") for fase, count in fase_counts.items(): print(f" - {fase}: {count} images") # Remove classes with very few samples (less than 3) min_samples = 3 valid_phases = fase_counts[fase_counts >= min_samples].index.tolist() if len(valid_phases) < len(fase_counts): excluded = fase_counts[fase_counts < min_samples].index.tolist() print(f"Excluded phases with less than {min_samples} samples: {excluded}") df_final = df_final[df_final[column_name].isin(valid_phases)] print(f"After filtering: {len(df_final)} rows, {len(valid_phases)} classes") labels = df_final[column_name].unique().tolist() print(f"Final classes: {labels}") # Standard split - Stratified by class train_dfs = [] val_dfs = [] test_dfs = [] # Shuffle and split data if test_AV: print("\n=== Using test_AV mode: NocciolaAV images prioritized for test ===") # Split each class separately to maintain proportions for label in labels: # Filter data for this class df_class = df_final[df_final[column_name] == label].copy() n_class = len(df_class) # Filter for AV dataset df_class['is_av'] = df_class['id_img'].astype(str).str.contains('NocciolaAV', case=False, na=False) df_class_av = df_class[df_class['is_av']].copy() df_class_non_av = df_class[~df_class['is_av']].copy() # Shuffle this class df_class_av_shuffled = df_class_av.sample(frac=1, random_state=seed).reset_index(drop=True) df_class_non_av_shuffled = df_class_non_av.sample(frac=1, random_state=seed).reset_index(drop=True) # Remove the helper column df_class_av_shuffled = df_class_av_shuffled.drop(columns=['is_av'], errors='ignore') df_class_non_av_shuffled = df_class_non_av_shuffled.drop(columns=['is_av'], errors='ignore') # Calculate split sizes for this class n_train_class = int(n_class * split['train']) n_val_class = int(n_class * split['val']) n_test_class = n_class - n_train_class - n_val_class # Split this class if len(df_class_av_shuffled) > n_test_class: test_class = df_class_av_shuffled.iloc[:n_test_class] df_class_non_av_shuffled = pd.concat([df_class_av_shuffled.iloc[n_test_class:], df_class_non_av_shuffled], ignore_index=True) df_class_non_av_shuffled = df_class_non_av_shuffled.sample(frac=1, random_state=seed).reset_index(drop=True) elif len(df_class_av_shuffled) == n_test_class: test_class = df_class_av_shuffled else: needed = n_test_class - len(df_class_av_shuffled) test_class = pd.concat([df_class_av_shuffled, df_class_non_av_shuffled.iloc[:needed]], ignore_index=True) test_class = test_class.sample(frac=1, random_state=seed).reset_index(drop=True) if needed: train_class = df_class_non_av_shuffled.iloc[needed:n_train_class] else: train_class = df_class_non_av_shuffled.iloc[:n_train_class] val_class = df_class_non_av_shuffled.iloc[n_test_class + n_train_class:] # Store splits train_dfs.append(train_class) val_dfs.append(val_class) test_dfs.append(test_class) print(f" Class '{label}': {n_class} total -> Train: {len(train_class)}, Val: {len(val_class)}, Test: {len(test_class)}") else: print("\n=== Standard stratified split ===") # Split each class separately to maintain proportions for label in labels: # Filter data for this class df_class = df_final[df_final[column_name] == label].copy() n_class = len(df_class) # Shuffle this class df_class_shuffled = df_class.sample(frac=1, random_state=seed).reset_index(drop=True) # Calculate split sizes for this class n_train_class = int(n_class * split['train']) n_val_class = int(n_class * split['val']) #n_test_class = n_class - n_train_class - n_val_class # Split this class train_class = df_class_shuffled.iloc[:n_train_class] val_class = df_class_shuffled.iloc[n_train_class:n_train_class + n_val_class] test_class = df_class_shuffled.iloc[n_train_class + n_val_class:] # Store splits train_dfs.append(train_class) val_dfs.append(val_class) test_dfs.append(test_class) print(f" Class '{label}': {n_class} total -> Train: {len(train_class)}, Val: {len(val_class)}, Test: {len(test_class)}") # Combine all classes for each split train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True) val_df = pd.concat(val_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True) test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=seed).reset_index(drop=True) print(f"\nFinal split (stratified):") print(f" - Training: {len(train_df)} images") print(f" - Validation: {len(val_df)} images") print(f" - Test: {len(test_df)} images") # Verify class distribution in each split print(f"\nClass distribution verification:") for subset_name, subset_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]: dist = subset_df[column_name].value_counts().sort_index() print(f" {subset_name}: {dict(dist)}") # Create folder structure for part in ['train', 'val', 'test']: for label in labels: label_dir = os.path.join(out_dir, part, str(label)) os.makedirs(label_dir, exist_ok=True) # Function to copy images def copy_subset(subdf, subset_name, column_name): copied, missing = 0, 0 failed = [] miss = [] for _, row in subdf.iterrows(): src = resolve_image_path(images_dir, row['id_img']) if src: fase = str(row[column_name]) # Crear carpeta de destino si no existe (por si acaso) dest_dir = os.path.join(out_dir, subset_name, fase) os.makedirs(dest_dir, exist_ok=True) # Usar el nombre original del archivo (con extensión correcta) original_filename = os.path.basename(src) dst = os.path.join(dest_dir, original_filename) # Verificar que la ruta de destino no sea demasiado larga if len(dst) > 260: print(f" Ruta demasiado larga ({len(dst)} caracteres): {dst}") # Crear nombre más corto ext = os.path.splitext(original_filename)[1] short_name = f"img_{copied:04d}{ext}" dst = os.path.join(dest_dir, short_name) print(f" Usando nombre corto: {short_name}") try: # Verificar que origen y destino existen/son válidos if not os.path.exists(src): print(f" Archivo origen no existe: {src}") missing += 1 miss.append(row['id_img']) continue if not os.path.exists(dest_dir): print(f" Carpeta destino no existe: {dest_dir}") os.makedirs(dest_dir, exist_ok=True) shutil.copy2(src, dst) copied += 1 except Exception as e: print(f" Error copying {src} to {dst}: {e}") print(f" - Source exists: {os.path.exists(src)}") print(f" - Dest dir exists: {os.path.exists(dest_dir)}") print(f" - Source path length: {len(src)}") print(f" - Dest path length: {len(dst)}") failed.append(src) missing += 1 else: missing += 1 miss.append(row['id_img']) miss_file_path = os.path.join(os.getcwd(), f'missing_{subset_name}.txt') with open(miss_file_path, 'w') as f: f.write(f"Missing images in {subset_name}:\n") for item in miss: f.write(f"{item}\n") f.write(f"\n\n Failed to copy:\n") for item in failed: f.write(f"{item}\n") print(f" {subset_name}: {copied} images copied, {missing} failed") return copied # Copy images to the corresponding folders copy_subset(train_df, 'train', column_name) copy_subset(val_df, 'val', column_name) copy_subset(test_df, 'test', column_name) return train_df, val_df, test_df def main(): """Main function to run the ResNet50 transfer learning pipeline""" parser = argparse.ArgumentParser(description='ResNet50 Transfer Learning for Nocciola') parser.add_argument('--csv_path', type=str, default=CSV_PATH, help='Path to the CSV file with image assignments') parser.add_argument('--images_dir', type=str, default=IMAGES_DIR, help='Directory with the images') parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Output directory for results') parser.add_argument('--epochs', type=int, default=30, help='Number of training epochs') parser.add_argument('--force_split', action='store_true', help='Force recreation of the data split') parser.add_argument('--phase', type=str, default='fase R', help='Phase of the analysis (V or R)') parser.add_argument('--test', action='store_true', default=True, help='Run model using AV dataset for testing purposes') args = parser.parse_args() print('\n === Start of the pipeline ===') print(f"Images directory: {args.images_dir}") print(f"CSV file: {args.csv_path}") print(f"Output directory: {args.output_dir}") gpus = tf.config.list_physical_devices('GPU') try: tf.config.set_logical_device_configuration( gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=4096)]) logical_gpus = tf.config.list_logical_devices('GPU') print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") except RuntimeError as e: # Visible devices must be set before GPUs have been initialized print(e) # Set seed set_seed(SEED) # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Load data print('\n === Loading data ===') df = safe_read_csv(args.csv_path) print(f'Total of records in CSV: {len(df)}') print(f'Available columns: {list(df.columns)}') # Check required columns required_cols = {'id_img', args.phase} if not required_cols.issubset(set(df.columns)): missing = required_cols - set(df.columns) raise ValueError(f'CSV must contain the columns: {missing}') # Analyze class distribution before processing real_split = analyze_class_distribution(df, args.phase) # Prepare folder structure SPLIT_DIR = os.path.join(args.output_dir, 'data_split') if args.force_split and os.path.exists(SPLIT_DIR): print("Eliminating existing split") shutil.rmtree(SPLIT_DIR) if not os.path.exists(SPLIT_DIR): print("\n=== Creating new data split ===") train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR, column_name=args.phase, test_AV=args.test, class_split=real_split) # Save split information train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False) val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False) test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False) else: print("\n === Reuse existing split ===") # Load split information if it exists try: train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv')) val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv')) test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv')) except: print("Could not load split files, recreating") train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR, column_name=args.phase, test_AV=args.test, class_split=real_split) # Create data generators print("\n=== Creating data generators ===") # Data augmentation for training train_datagen = ImageDataGenerator( rescale=1./255, rotation_range=20, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.1, zoom_range=0.1, horizontal_flip=True, fill_mode='nearest' ) # Only normalization for validation and test val_test_datagen = ImageDataGenerator(rescale=1./255) # Create generators train_gen = train_datagen.flow_from_directory( os.path.join(SPLIT_DIR, 'train'), target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='categorical', seed=SEED ) val_gen = val_test_datagen.flow_from_directory( os.path.join(SPLIT_DIR, 'val'), target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='categorical', shuffle=False ) test_gen = val_test_datagen.flow_from_directory( os.path.join(SPLIT_DIR, 'test'), target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='categorical', shuffle=False ) # Save class mapping class_indices = train_gen.class_indices print(f'Class mapping: {class_indices}') with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f: json.dump(class_indices, f, indent=2) print(f"Samples per class:") print(f" - Training: {train_gen.samples}") print(f" - Validation: {val_gen.samples}") print(f" - Test: {test_gen.samples}") print(f" - Number of classes: {train_gen.num_classes}") # Create and train model print("\n === Model construction ===") # Base model ResNet50 base_model = ResNet50( weights='imagenet', include_top=False, input_shape=(*IMG_SIZE, 3) ) base_model.trainable = False # Freeze base model # Build sequential model model = models.Sequential([ base_model, layers.GlobalAveragePooling2D(), layers.BatchNormalization(), layers.Dropout(0.5), layers.Dense(512, activation='relu'), layers.BatchNormalization(), layers.Dropout(0.5), layers.Dense(256, activation='relu'), layers.Dropout(0.3), layers.Dense(train_gen.num_classes, activation='softmax') ]) # Compile model model.compile( optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss='categorical_crossentropy', metrics=['accuracy'] ) print(" Summary of the model:") model.summary() # Calculate class weights print("\n === Calculating class weights ===") try: # Get training labels train_labels = [] for i in range(len(train_gen)): _, labels = train_gen[i] train_labels.extend(np.argmax(labels, axis=1)) if len(train_labels) >= train_gen.samples: break # Calculate class weights class_weights = class_weight.compute_class_weight( 'balanced', classes=np.unique(train_labels), y=train_labels ) class_weight_dict = dict(zip(np.unique(train_labels), class_weights)) print(f"Class weights: {class_weight_dict}") except Exception as e: print(f"Error calculating class weights: {e}") class_weight_dict = None # Callbacks for training early_stopping = tf.keras.callbacks.EarlyStopping( monitor='val_loss', patience=7, restore_best_weights=True, verbose=1 ) model_checkpoint = tf.keras.callbacks.ModelCheckpoint( os.path.join(args.output_dir, 'best_model.keras'), save_best_only=True, monitor='val_loss', verbose=1 ) reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7, verbose=1 ) callbacks = [early_stopping, model_checkpoint, reduce_lr] # Initial training print(f"\n === Initial training ({args.epochs} epochs) ===") try: history = model.fit( train_gen, validation_data=val_gen, epochs=args.epochs, callbacks=callbacks, class_weight=class_weight_dict, verbose=1 ) print("Initial training completed successfully") except Exception as e: print(f"Error during training: {e}") # Train without class_weight if there are issues print("Trying training without class weights") history = model.fit( train_gen, validation_data=val_gen, epochs=args.epochs, callbacks=callbacks, verbose=1 ) # Fine-tuning print("\n === Fine-tuning ===") # Unfreeze some layers of the base model base_model.trainable = True fine_tune_at = 143 # Unfreeze after conv5_block1 (ResNet50 has 175 layers) for layer in base_model.layers[:fine_tune_at]: layer.trainable = False # Recompile with a lower learning rate model.compile( optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), loss='categorical_crossentropy', metrics=['accuracy'] ) # Continue training fine_tune_epochs = 10 total_epochs = len(history.history['loss']) + fine_tune_epochs try: history_fine = model.fit( train_gen, validation_data=val_gen, epochs=total_epochs, initial_epoch=len(history.history['loss']), callbacks=callbacks, verbose=1 ) print("Fine-tuning completed successfully") # Combine histories for key in history.history: if key in history_fine.history: history.history[key].extend(history_fine.history[key]) except Exception as e: print(f"Error during fine-tuning: {e}") print("Continuing with initial training model") # Final evaluation print("\n === Evaluation on test set ===") # Load best model try: model.load_weights(os.path.join(args.output_dir, 'best_model.keras')) print(" Loaded best saved model") except: print(" Using current model") # Save final model model.save(os.path.join(args.output_dir, 'final_model.keras')) print("Saved final model") # Predictions on test set test_gen.reset() y_pred_prob = model.predict(test_gen, verbose=1) y_pred = np.argmax(y_pred_prob, axis=1) y_true = test_gen.classes # Map indices to class names index_to_class = {v: k for k, v in class_indices.items()} # Get only the classes that actually appear in the test set unique_test_classes = np.unique(np.concatenate([y_true, y_pred])) test_class_names = [index_to_class[i] for i in unique_test_classes] print(f"Classes in test set: {len(unique_test_classes)}") print(f"All trained classes: {len(class_indices)}") print(f"Classes present in test: {test_class_names}") # Check for missing classes all_classes = set(range(len(class_indices))) test_classes = set(unique_test_classes) missing_classes = all_classes - test_classes if missing_classes: missing_names = [index_to_class[i] for i in missing_classes] print(f"Classes without samples in test: {missing_names}") # Classification report with filtered classes print("\n === Classification Report ===") try: report = classification_report( y_true, y_pred, labels=unique_test_classes, # Specify exact classes target_names=test_class_names, output_dict=False, zero_division=0 # Handle division by zero ) print(report) # Save report with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f: f.write(f"Classes evaluated: {test_class_names}\n") f.write(f"Classes missing in test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n") f.write(report) except Exception as e: print(f"Error in classification_report: {e}") print("Generating alternative report") # Manual report if automatic fails from collections import Counter true_counts = Counter(y_true) pred_counts = Counter(y_pred) print("\n Manual distribution:") print("Class | True | Predicted") print("-" * 35) for class_idx in unique_test_classes: class_name = index_to_class[class_idx] true_count = true_counts.get(class_idx, 0) pred_count = pred_counts.get(class_idx, 0) print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}") # Calculate basic accuracy accuracy = np.mean(y_true == y_pred) print(f"\nOverall accuracy: {accuracy:.4f}") # Save manual report with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f: f.write("MANUAL CLASSIFICATION REPORT\n") f.write("=" * 40 + "\n\n") f.write(f"Classes evaluated: {test_class_names}\n") f.write(f"Classes missing in test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n") f.write("Class distribution:\n") f.write("Class | True | Predicted\n") f.write("-" * 35 + "\n") for class_idx in unique_test_classes: class_name = index_to_class[class_idx] true_count = true_counts.get(class_idx, 0) pred_count = pred_counts.get(class_idx, 0) f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n") f.write(f"\nOverall accuracy: {accuracy:.4f}\n") # Confusion matrix with filtered classes cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes) print(f"\n Confusion Matrix ({len(unique_test_classes)} classes):") print(cm) np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'), cm, delimiter=',', fmt='%d') # Visualizations with filtered classes print("\n === Generating visualizations ===") # Training plot plot_training_history(history, args.output_dir) # Confusion matrix visualization with filtered classes plot_confusion_matrix(cm, test_class_names, args.output_dir) # Prediction examples with filtered classes plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes) print(f"\n=== Pipeline completed ===") print(f" Results saved in: {args.output_dir}") print(f" Final test accuracy: {np.mean(y_true == y_pred):.4f}") print(f" Classes evaluated: {len(unique_test_classes)}/{len(class_indices)}") # Additional information about imbalanced classes if missing_classes: print(f"\n === Information about Imbalanced Classes ===") print(f" Classes without samples in test: {len(missing_classes)}") for missing_idx in missing_classes: missing_name = index_to_class[missing_idx] print(f" - {missing_name} (index {missing_idx})") print(f" Suggestion: Consider increasing the dataset or merging similar classes") def plot_training_history(history, output_dir): """Plot training history""" try: plt.figure(figsize=(12, 4)) # Accuracy plt.subplot(1, 2, 1) plt.plot(history.history['accuracy'], label='Training') if 'val_accuracy' in history.history: plt.plot(history.history['val_accuracy'], label='Validation') plt.title('Model Accuracy') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.legend() plt.grid(True) # Loss plt.subplot(1, 2, 2) plt.plot(history.history['loss'], label='Training') if 'val_loss' in history.history: plt.plot(history.history['val_loss'], label='Validation') plt.title('Model Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight') plt.close() print("Saved training history plot") except Exception as e: print(f"Error creating training history plot: {e}") def plot_confusion_matrix(cm, class_names, output_dir): """Plot confusion matrix""" try: plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.title('Confusion Matrix') plt.ylabel('True Label') plt.xlabel('Predicted Label') plt.xticks(rotation=45, ha='right') plt.yticks(rotation=0) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight') plt.close() print("Saved confusion matrix plot") except Exception as e: print(f"Error creating confusion matrix plot: {e}") def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12): """Show examples of correct and incorrect predictions""" try: # Get indices of correct and incorrect predictions correct_idx = np.where(y_true == y_pred)[0] incorrect_idx = np.where(y_true != y_pred)[0] # Select examples n_correct = min(n_examples // 2, len(correct_idx)) n_incorrect = min(n_examples // 2, len(incorrect_idx)) selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else [] selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else [] selected_indices = np.concatenate([selected_correct, selected_incorrect]) if len(selected_indices) == 0: print("There are no examples to show.") return # Create plot n_show = len(selected_indices) cols = 4 rows = (n_show + cols - 1) // cols plt.figure(figsize=(15, 4 * rows)) for i, idx in enumerate(selected_indices): plt.subplot(rows, cols, i + 1) # Obtain the image from the generator # This is a workaround since we don't have direct access to the images img_path = test_gen.filepaths[idx] img = plt.imread(img_path) plt.imshow(img) plt.axis('off') true_label = class_names[y_true[idx]] pred_label = class_names[y_pred[idx]] color = 'green' if y_true[idx] == y_pred[idx] else 'red' plt.title(f'True: {true_label}\nPredicted: {pred_label}', color=color, fontsize=10) plt.suptitle('Prediction Examples (Green=Correct, Red=Incorrect)', fontsize=14) plt.tight_layout() plt.savefig(os.path.join(output_dir, 'prediction_examples.png'), dpi=300, bbox_inches='tight') plt.close() print("Saved prediction examples plot") except Exception as e: print(f"Error creating prediction examples plot: {e}") if __name__ == "__main__": main()