Phenology/Code/Supervised_learning/ResNET.py
2025-11-06 14:16:49 +01:00

741 lines
28 KiB
Python

"""
ResNet50 Transfer Learning para Clasificación de Fases Fenológicas - Nocciola
Adaptado para Visual Studio Code basado en MobileNetV1.py exitoso
Dataset: Nocciola GBIF
Objetivo: Predecir fase (fenológica vegetativa)
"""
# =============================================================================
# LIBRARY IMPORTS
# =============================================================================
# System and file handling libraries
import os # Operating system interface (paths, directories)
import shutil # High-level file operations (copy, move, delete)
import argparse # Command-line argument parsing
import random # Random number generation for reproducibility
import pandas as pd # Data manipulation and analysis
import numpy as np # Numerical computing with multidimensional arrays
from pathlib import Path # Modern path handling
import json
# Data visualization libraries
import matplotlib.pyplot as plt # Plotting and visualization
import seaborn as sns # Statistical data visualization
# Deep learning libraries (TensorFlow and Keras)
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
# ----------------- CONFIG -----------------
PROJECT_PATH = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi'
IMAGES_DIR = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi' # Las imágenes están en el directorio principal
CSV_PATH = os.path.join(PROJECT_PATH, 'assignments.csv') # CSV principal
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_resnet50')
os.makedirs(OUTPUT_DIR, exist_ok=True)
IMG_SIZE = (224, 224) # Recomendado para ResNet50 (tamaño estándar ImageNet)
BATCH_SIZE = 16 # Reducido para ResNet50 (más pesado que MobileNet)
SEED = 42
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
FORCE_SPLIT = False
# ----------------- Utilities -----------------
def set_seed(seed=42):
"""Establecer semilla para reproducibilidad"""
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
def analyze_class_distribution(df, column_name='fase V'):
"""Analizar distribución de clases y detectar desbalances"""
print(f"\n📊 === Análisis de Distribución de Clases ===")
# Contar por clase
counts = df[column_name].value_counts()
total = len(df)
print(f"📊 Total de muestras: {total}")
print(f"📊 Número de clases: {len(counts)}")
print(f"📊 Distribución por clase:")
# Mostrar estadísticas detalladas
for clase, count in counts.items():
percentage = (count / total) * 100
print(f" - {clase}: {count} muestras ({percentage:.1f}%)")
# Detectar clases problemáticas
min_samples = 5 # Umbral mínimo recomendado
small_classes = counts[counts < min_samples]
if len(small_classes) > 0:
print(f"\n⚠️ Clases con menos de {min_samples} muestras:")
for clase, count in small_classes.items():
print(f" - {clase}: {count} muestras")
print(f"\n💡 Recomendaciones:")
print(f" 1. Considera recolectar más datos para estas clases")
print(f" 2. O fusionar clases similares")
print(f" 3. O usar técnicas de data augmentation específicas")
return counts, small_classes
def safe_read_csv(path):
"""Leer CSV con manejo de encoding"""
if not os.path.exists(path):
raise FileNotFoundError(f'CSV no encontrado: {path}')
try:
df = pd.read_csv(path, encoding='utf-8')
except UnicodeDecodeError:
try:
df = pd.read_csv(path, encoding='latin-1')
except:
df = pd.read_csv(path, encoding='cp1252')
return df
def resolve_image_path(images_dir, img_id):
"""Resolver la ruta completa de una imagen"""
if pd.isna(img_id) or str(img_id).strip() == '':
return None
img_id = str(img_id).strip()
# Verificar si ya incluye extensión y existe
direct_path = os.path.join(images_dir, img_id)
if os.path.exists(direct_path):
return direct_path
# Probar con extensiones comunes
stem = os.path.splitext(img_id)[0]
for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
img_path = os.path.join(images_dir, stem + ext)
if os.path.exists(img_path):
return img_path
return None
def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED):
"""Crear estructura de carpetas para flow_from_directory"""
set_seed(seed)
# Filtrar solo filas con fase V válida e imágenes existentes
print(f"📊 Datos iniciales: {len(df)} filas")
# Filtrar filas con fase V válida
df_valid = df.dropna(subset=['fase V']).copy()
df_valid = df_valid[df_valid['fase V'].str.strip() != '']
print(f"📊 Con fase V válida: {len(df_valid)} filas")
# Verificar existencia de imágenes
valid_rows = []
for _, row in df_valid.iterrows():
img_path = resolve_image_path(images_dir, row['id_img'])
if img_path:
valid_rows.append(row)
else:
print(f"⚠️ Imagen no encontrada: {row['id_img']}")
if not valid_rows:
raise ValueError("❌ No se encontraron imágenes válidas")
df_final = pd.DataFrame(valid_rows)
print(f"📊 Con imágenes existentes: {len(df_final)} filas")
# Mostrar distribución de clases
fase_counts = df_final['fase V'].value_counts()
print(f"\n📊 Distribución de fases:")
for fase, count in fase_counts.items():
print(f" - {fase}: {count} imágenes")
# Remover clases con muy pocas muestras (menos de 3)
min_samples = 3
valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
if len(valid_phases) < len(fase_counts):
excluded = fase_counts[fase_counts < min_samples].index.tolist()
print(f"⚠️ Excluyendo fases con menos de {min_samples} muestras: {excluded}")
df_final = df_final[df_final['fase V'].isin(valid_phases)]
print(f"📊 Después de filtrar: {len(df_final)} filas, {len(valid_phases)} clases")
labels = df_final['fase V'].unique().tolist()
print(f"📊 Clases finales: {labels}")
# Mezclar y dividir datos
df_shuffled = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
n = len(df_shuffled)
n_train = int(n * split['train'])
n_val = int(n * split['val'])
train_df = df_shuffled.iloc[:n_train]
val_df = df_shuffled.iloc[n_train:n_train + n_val]
test_df = df_shuffled.iloc[n_train + n_val:]
print(f"📊 División final:")
print(f" - Entrenamiento: {len(train_df)} imágenes")
print(f" - Validación: {len(val_df)} imágenes")
print(f" - Prueba: {len(test_df)} imágenes")
# Crear estructura de carpetas
for part in ['train', 'val', 'test']:
for label in labels:
label_dir = os.path.join(out_dir, part, str(label))
os.makedirs(label_dir, exist_ok=True)
# Función para copiar imágenes
def copy_subset(subdf, subset_name):
copied, missing = 0, 0
for _, row in subdf.iterrows():
src = resolve_image_path(images_dir, row['id_img'])
if src:
fase = str(row['fase V'])
dst = os.path.join(out_dir, subset_name, fase, f"{row['id_img']}.jpg")
try:
shutil.copy2(src, dst)
copied += 1
except Exception as e:
print(f"⚠️ Error copiando {src}: {e}")
missing += 1
else:
missing += 1
print(f"{subset_name}: {copied} imágenes copiadas, {missing} fallidas")
return copied
# Copiar imágenes a las carpetas correspondientes
copy_subset(train_df, 'train')
copy_subset(val_df, 'val')
copy_subset(test_df, 'test')
return train_df, val_df, test_df
def main():
"""Función principal del pipeline ResNet50"""
parser = argparse.ArgumentParser(description='ResNet50 Transfer Learning para Nocciola')
parser.add_argument('--csv_path', type=str, default=CSV_PATH,
help='Ruta al archivo CSV con metadatos')
parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
help='Directorio con las imágenes')
parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
help='Directorio de salida para resultados')
parser.add_argument('--epochs', type=int, default=25,
help='Número de épocas de entrenamiento')
parser.add_argument('--force_split', action='store_true',
help='Forzar recreación del split de datos')
args = parser.parse_args()
print('\n🚀 === Inicio del pipeline ResNet50 para Nocciola ===')
print(f"📁 Directorio de imágenes: {args.images_dir}")
print(f"📄 Archivo CSV: {args.csv_path}")
print(f"📂 Directorio de salida: {args.output_dir}")
# Establecer semilla
set_seed(SEED)
# Crear directorio de salida
os.makedirs(args.output_dir, exist_ok=True)
# Leer datos
print('\n📊 === Cargando datos ===')
df = safe_read_csv(args.csv_path)
print(f'📊 Total de registros en CSV: {len(df)}')
print(f'📊 Columnas disponibles: {list(df.columns)}')
# Verificar columnas requeridas
required_cols = {'id_img', 'fase V'}
if not required_cols.issubset(set(df.columns)):
missing = required_cols - set(df.columns)
raise ValueError(f'❌ CSV debe contener las columnas: {missing}')
# Analizar distribución de clases antes del procesamiento
analyze_class_distribution(df, 'fase V')
# Preparar estructura de carpetas
SPLIT_DIR = os.path.join(args.output_dir, 'data_split')
if args.force_split and os.path.exists(SPLIT_DIR):
print("🗑️ Eliminando split existente...")
shutil.rmtree(SPLIT_DIR)
if not os.path.exists(SPLIT_DIR):
print("\n📁 === Creando nueva división de datos ===")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Guardar información del split
train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False)
val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False)
test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False)
else:
print("\n♻️ === Reutilizando división existente ===")
# Cargar información del split si existe
try:
train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
except:
print("⚠️ No se pudieron cargar los archivos de split, recreando...")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Crear generadores de datos
print("\n🔄 === Creando generadores de datos ===")
# Data augmentation para entrenamiento (más conservador para ResNet50)
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=15, # Menos rotación que MobileNet
width_shift_range=0.08,
height_shift_range=0.08,
shear_range=0.08,
zoom_range=0.08,
horizontal_flip=True,
fill_mode='nearest'
)
# Solo normalización para validación y test
val_test_datagen = ImageDataGenerator(rescale=1./255)
# Crear generadores
train_gen = train_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'train'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
seed=SEED
)
val_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'val'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
test_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'test'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
# Guardar mapeo de clases
class_indices = train_gen.class_indices
print(f'🏷️ Mapeo de clases: {class_indices}')
with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
json.dump(class_indices, f, indent=2)
print(f"📊 Muestras por conjunto:")
print(f" - Entrenamiento: {train_gen.samples}")
print(f" - Validación: {val_gen.samples}")
print(f" - Prueba: {test_gen.samples}")
print(f" - Número de clases: {train_gen.num_classes}")
# Crear y entrenar modelo ResNet50
print("\n🤖 === Construcción del modelo ResNet50 ===")
# Modelo base ResNet50
base_model = ResNet50(
weights='imagenet',
include_top=False,
input_shape=(*IMG_SIZE, 3)
)
base_model.trainable = False # Congelar inicialmente
# Construir modelo secuencial optimizado para ResNet50
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.BatchNormalization(), # Importante para ResNet50
layers.Dropout(0.5), # Dropout más alto para ResNet50
layers.Dense(256, activation='relu'),
layers.BatchNormalization(),
layers.Dropout(0.3),
layers.Dense(train_gen.num_classes, activation='softmax')
])
# Compilar modelo con learning rate más bajo para ResNet50
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4), # LR más bajo
loss='categorical_crossentropy',
metrics=['accuracy']
)
print("📋 Resumen del modelo ResNet50:")
model.summary()
# Calcular pesos de clase
print("\n⚖️ === Calculando pesos de clase ===")
try:
# Obtener etiquetas de entrenamiento
train_labels = []
for i in range(len(train_gen)):
_, labels = train_gen[i]
train_labels.extend(np.argmax(labels, axis=1))
if len(train_labels) >= train_gen.samples:
break
# Calcular pesos balanceados
class_weights = class_weight.compute_class_weight(
'balanced',
classes=np.unique(train_labels),
y=train_labels
)
class_weight_dict = dict(zip(np.unique(train_labels), class_weights))
print(f"⚖️ Pesos de clase: {class_weight_dict}")
except Exception as e:
print(f"⚠️ Error calculando pesos de clase: {e}")
class_weight_dict = None
# Callbacks para entrenamiento ResNet50
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=10, # Más paciencia para ResNet50
restore_best_weights=True,
verbose=1
)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
os.path.join(args.output_dir, 'best_resnet50_model.keras'),
save_best_only=True,
monitor='val_loss',
verbose=1
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.3, # Reducción más agresiva
patience=5, # Menos paciencia para reducir LR
min_lr=1e-8,
verbose=1
)
callbacks = [early_stopping, model_checkpoint, reduce_lr]
# Entrenamiento inicial
print(f"\n🏋️ === Entrenamiento inicial ResNet50 ({args.epochs} épocas) ===")
try:
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
class_weight=class_weight_dict,
verbose=1
)
print("✅ Entrenamiento inicial completado")
except Exception as e:
print(f"❌ Error durante entrenamiento: {e}")
# Entrenar sin class_weight si hay problemas
print("🔄 Intentando entrenamiento sin pesos de clase...")
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
verbose=1
)
# Fine-tuning específico para ResNet50
print("\n🔧 === Fine-tuning ResNet50 ===")
# Descongelar capas específicas de ResNet50
base_model.trainable = True
# Para ResNet50, descongelar solo las últimas capas residuales
# ResNet50 tiene bloques conv5_x que son los más específicos
fine_tune_at = 140 # Aproximadamente desde conv5_block1
for layer in base_model.layers[:fine_tune_at]:
layer.trainable = False
print(f"🔧 Capas entrenables: {len([l for l in base_model.layers if l.trainable])}/{len(base_model.layers)}")
# Recompilar con learning rate muy bajo para fine-tuning
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6), # LR muy bajo para ResNet50
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Continuar entrenamiento
fine_tune_epochs = 15 # Más épocas para fine-tuning
total_epochs = len(history.history['loss']) + fine_tune_epochs
try:
history_fine = model.fit(
train_gen,
validation_data=val_gen,
epochs=total_epochs,
initial_epoch=len(history.history['loss']),
callbacks=callbacks,
verbose=1
)
print("✅ Fine-tuning completado")
# Combinar historiales
for key in history.history:
if key in history_fine.history:
history.history[key].extend(history_fine.history[key])
except Exception as e:
print(f"⚠️ Error durante fine-tuning: {e}")
print("Continuando con modelo del entrenamiento inicial...")
# Evaluación final
print("\n📊 === Evaluación en conjunto de prueba ===")
# Cargar mejor modelo
try:
model.load_weights(os.path.join(args.output_dir, 'best_resnet50_model.keras'))
print("✅ Cargado mejor modelo ResNet50 guardado")
except:
print("⚠️ Usando modelo actual")
# Guardar modelo final
model.save(os.path.join(args.output_dir, 'final_resnet50_model.keras'))
print("💾 Modelo ResNet50 final guardado")
# Predicciones en test
test_gen.reset()
y_pred_prob = model.predict(test_gen, verbose=1)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_gen.classes
# Mapeo de índices a nombres de clase
index_to_class = {v: k for k, v in class_indices.items()}
# Obtener solo las clases que realmente aparecen en el conjunto de test
unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
test_class_names = [index_to_class[i] for i in unique_test_classes]
print(f"📊 Clases en conjunto de test: {len(unique_test_classes)}")
print(f"📊 Todas las clases entrenadas: {len(class_indices)}")
print(f"📊 Clases presentes en test: {test_class_names}")
# Verificar si hay clases faltantes
all_classes = set(range(len(class_indices)))
test_classes = set(unique_test_classes)
missing_classes = all_classes - test_classes
if missing_classes:
missing_names = [index_to_class[i] for i in missing_classes]
print(f"⚠️ Clases sin muestras en test: {missing_names}")
# Reporte de clasificación con clases filtradas
print("\n📋 === Reporte de Clasificación ResNet50 ===")
try:
report = classification_report(
y_true, y_pred,
labels=unique_test_classes,
target_names=test_class_names,
output_dict=False,
zero_division=0
)
print(report)
# Guardar reporte
with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
f.write(f"Modelo: ResNet50\n")
f.write(f"Clases evaluadas: {test_class_names}\n")
f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
f.write(report)
except Exception as e:
print(f"❌ Error en classification_report: {e}")
print("📊 Generando reporte alternativo...")
# Reporte manual si falla el automático
from collections import Counter
true_counts = Counter(y_true)
pred_counts = Counter(y_pred)
print("\n📊 Distribución manual:")
print("Clase | Verdaderos | Predichos")
print("-" * 35)
for class_idx in unique_test_classes:
class_name = index_to_class[class_idx]
true_count = true_counts.get(class_idx, 0)
pred_count = pred_counts.get(class_idx, 0)
print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")
# Calcular accuracy básico
accuracy = np.mean(y_true == y_pred)
print(f"\n📊 Accuracy general ResNet50: {accuracy:.4f}")
# Guardar reporte manual
with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
f.write("REPORTE MANUAL DE CLASIFICACIÓN - ResNet50\n")
f.write("=" * 50 + "\n\n")
f.write(f"Clases evaluadas: {test_class_names}\n")
f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
f.write("Distribución por clase:\n")
f.write("Clase | Verdaderos | Predichos\n")
f.write("-" * 35 + "\n")
for class_idx in unique_test_classes:
class_name = index_to_class[class_idx]
true_count = true_counts.get(class_idx, 0)
pred_count = pred_counts.get(class_idx, 0)
f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n")
f.write(f"\nAccuracy general ResNet50: {accuracy:.4f}\n")
# Matriz de confusión con clases filtradas
cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
print(f"\n🔢 Matriz de Confusión ResNet50 ({len(unique_test_classes)} clases):")
print(cm)
np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
cm, delimiter=',', fmt='%d')
# Visualizaciones con clases filtradas
print("\n📈 === Generando visualizaciones ResNet50 ===")
# Gráfico de entrenamiento
plot_training_history(history, args.output_dir)
# Matriz de confusión visual con clases filtradas
plot_confusion_matrix(cm, test_class_names, args.output_dir)
# Ejemplos de predicciones con clases filtradas
plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)
print(f"\n🎉 === Pipeline ResNet50 completado ===")
print(f"📁 Resultados guardados en: {args.output_dir}")
print(f"📊 Precisión final en test: {np.mean(y_true == y_pred):.4f}")
print(f"📊 Clases evaluadas: {len(unique_test_classes)}/{len(class_indices)}")
# Información adicional sobre clases desbalanceadas
if missing_classes:
print(f"\n⚠️ === Información sobre Clases Desbalanceadas ===")
print(f"❌ Clases sin muestras en test: {len(missing_classes)}")
for missing_idx in missing_classes:
missing_name = index_to_class[missing_idx]
print(f" - {missing_name} (índice {missing_idx})")
print(f"💡 Sugerencia: Considera aumentar el dataset o fusionar clases similares")
def plot_training_history(history, output_dir):
"""Graficar historial de entrenamiento ResNet50"""
try:
plt.figure(figsize=(12, 4))
# Accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Entrenamiento', linewidth=2)
if 'val_accuracy' in history.history:
plt.plot(history.history['val_accuracy'], label='Validación', linewidth=2)
plt.title('Precisión del Modelo ResNet50')
plt.xlabel('Época')
plt.ylabel('Precisión')
plt.legend()
plt.grid(True, alpha=0.3)
# Loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Entrenamiento', linewidth=2)
if 'val_loss' in history.history:
plt.plot(history.history['val_loss'], label='Validación', linewidth=2)
plt.title('Pérdida del Modelo ResNet50')
plt.xlabel('Época')
plt.ylabel('Pérdida')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'resnet50_training_history.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Gráfico de entrenamiento ResNet50 guardado")
except Exception as e:
print(f"⚠️ Error creando gráfico de entrenamiento: {e}")
def plot_confusion_matrix(cm, class_names, output_dir):
"""Graficar matriz de confusión ResNet50"""
try:
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names,
cbar_kws={'label': 'Número de muestras'})
plt.title('Matriz de Confusión - ResNet50')
plt.ylabel('Etiqueta Verdadera')
plt.xlabel('Etiqueta Predicha')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'resnet50_confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Matriz de confusión ResNet50 guardada")
except Exception as e:
print(f"⚠️ Error creando matriz de confusión: {e}")
def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
"""Mostrar ejemplos de predicciones correctas e incorrectas para ResNet50"""
try:
# Obtener índices de predicciones correctas e incorrectas
correct_idx = np.where(y_true == y_pred)[0]
incorrect_idx = np.where(y_true != y_pred)[0]
# Seleccionar ejemplos
n_correct = min(n_examples // 2, len(correct_idx))
n_incorrect = min(n_examples // 2, len(incorrect_idx))
selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else []
selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else []
selected_indices = np.concatenate([selected_correct, selected_incorrect])
if len(selected_indices) == 0:
print("⚠️ No hay ejemplos para mostrar")
return
# Crear gráfico
n_show = len(selected_indices)
cols = 4
rows = (n_show + cols - 1) // cols
plt.figure(figsize=(15, 4 * rows))
for i, idx in enumerate(selected_indices):
plt.subplot(rows, cols, i + 1)
# Obtener imagen
try:
img_path = test_gen.filepaths[idx]
img = plt.imread(img_path)
plt.imshow(img)
plt.axis('off')
true_label = class_names[y_true[idx]]
pred_label = class_names[y_pred[idx]]
color = 'green' if y_true[idx] == y_pred[idx] else 'red'
plt.title(f'Real: {true_label}\nPredicción: {pred_label}',
color=color, fontsize=9, fontweight='bold')
except Exception as e:
plt.text(0.5, 0.5, f'Error cargando\nimagen {idx}',
ha='center', va='center', transform=plt.gca().transAxes)
plt.axis('off')
plt.suptitle('Ejemplos de Predicciones ResNet50 (Verde=Correcta, Rojo=Incorrecta)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'resnet50_prediction_examples.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Ejemplos de predicciones ResNet50 guardados")
except Exception as e:
print(f"⚠️ Error creando ejemplos de predicciones: {e}")
if __name__ == "__main__":
main()