Phenology/Code/Supervised_learning/train_resnet50.py
2025-11-06 14:16:49 +01:00

298 lines
9.9 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Script facilitado para entrenar ResNet50 - Nocciola
Comparación con MobileNetV2
"""
import os
import sys
import subprocess
import argparse
# Configuración de rutas
PROJECT_DIR = r'C:\Users\sof12\Desktop\ML'
RESNET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'ResNET.py')
MOBILENET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'MobileNetV1.py')
PYTHON_EXE = r'C:/Users/sof12/AppData/Local/Programs/Python/Python312/python.exe'
# Datasets disponibles
DATASETS = {
'original': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos.csv'),
'filtered': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos_filtered.csv'),
'assignments': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'assignments.csv')
}
def check_files():
"""Verificar que todos los archivos necesarios existen"""
print("🔍 === Verificando archivos ===")
files_to_check = {
'ResNet Script': RESNET_SCRIPT,
'MobileNet Script': MOBILENET_SCRIPT,
'Python Executable': PYTHON_EXE
}
all_ok = True
for name, path in files_to_check.items():
if os.path.exists(path):
print(f"{name}: {path}")
else:
print(f"{name}: {path}")
all_ok = False
# Verificar datasets
print(f"\n📊 === Datasets disponibles ===")
for name, path in DATASETS.items():
if os.path.exists(path):
print(f"{name}: {path}")
else:
print(f"{name}: {path}")
return all_ok
def train_resnet50(dataset_key='assignments', epochs=25, force_split=True):
"""Entrenar modelo ResNet50"""
dataset_path = DATASETS.get(dataset_key)
if not dataset_path or not os.path.exists(dataset_path):
print(f"❌ Dataset '{dataset_key}' no encontrado")
return False
print(f"\n🚀 === Entrenando ResNet50 ===")
print(f"📊 Dataset: {dataset_key} ({dataset_path})")
print(f"⏱️ Épocas: {epochs}")
# Comando de entrenamiento
cmd = [
PYTHON_EXE, RESNET_SCRIPT,
'--csv_path', dataset_path,
'--epochs', str(epochs)
]
if force_split:
cmd.append('--force_split')
print("🔄 Comando a ejecutar:")
print(" ".join(cmd))
try:
# Cambiar al directorio correcto
os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))
print(f"\n🏋️ Iniciando entrenamiento ResNet50...")
result = subprocess.run(cmd, check=True)
print("✅ Entrenamiento ResNet50 completado exitosamente!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Error durante el entrenamiento ResNet50: {e}")
return False
except KeyboardInterrupt:
print("\n⚠️ Entrenamiento ResNet50 interrumpido por el usuario")
return False
def train_mobilenet(dataset_key='assignments', epochs=25, force_split=True):
"""Entrenar modelo MobileNetV2 para comparación"""
dataset_path = DATASETS.get(dataset_key)
if not dataset_path or not os.path.exists(dataset_path):
print(f"❌ Dataset '{dataset_key}' no encontrado")
return False
print(f"\n🚀 === Entrenando MobileNetV2 (comparación) ===")
print(f"📊 Dataset: {dataset_key} ({dataset_path})")
print(f"⏱️ Épocas: {epochs}")
# Comando de entrenamiento
cmd = [
PYTHON_EXE, MOBILENET_SCRIPT,
'--csv_path', dataset_path,
'--epochs', str(epochs)
]
if force_split:
cmd.append('--force_split')
print("🔄 Comando a ejecutar:")
print(" ".join(cmd))
try:
# Cambiar al directorio correcto
os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))
print(f"\n🏋️ Iniciando entrenamiento MobileNetV2...")
result = subprocess.run(cmd, check=True)
print("✅ Entrenamiento MobileNetV2 completado exitosamente!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Error durante el entrenamiento MobileNetV2: {e}")
return False
except KeyboardInterrupt:
print("\n⚠️ Entrenamiento MobileNetV2 interrumpido por el usuario")
return False
def compare_models():
"""Mostrar comparación entre modelos"""
print("""
📊 === Comparación ResNet50 vs MobileNetV2 ===
RESNET50:
🟢 Ventajas:
- Mayor capacidad de aprendizaje (25M parámetros)
- Mejor para datasets complejos
- Residual connections mejoran gradiente
- Excelente para transfer learning
🟡 Desventajas:
- Más lento en entrenamiento e inferencia
- Requiere más memoria
- Overfitting en datasets pequeños
MOBILENETV2:
🟢 Ventajas:
- Rápido y eficiente (3.4M parámetros)
- Menos propenso a overfitting
- Ideal para datasets pequeños/medianos
- Menor uso de memoria
🟡 Desventajas:
- Menor capacidad para patrones complejos
- Puede underfittear en datos complejos
RECOMENDACIONES:
📋 Dataset Nocciola (pequeño): MobileNetV2 recomendado
📋 Dataset grande/complejo: ResNet50 recomendado
📋 Producción/móvil: MobileNetV2
📋 Investigación/precisión máxima: ResNet50
""")
def main():
parser = argparse.ArgumentParser(description='Entrenamiento ResNet50 vs MobileNetV2 para Nocciola')
parser.add_argument('--model', choices=['resnet50', 'mobilenet', 'both'], default='resnet50',
help='Modelo a entrenar')
parser.add_argument('--dataset', choices=['original', 'filtered', 'assignments'], default='assignments',
help='Dataset a usar')
parser.add_argument('--epochs', type=int, default=25,
help='Número de épocas')
parser.add_argument('--no_force_split', action='store_true',
help='No forzar recreación del split')
parser.add_argument('--compare', action='store_true',
help='Mostrar comparación de modelos')
parser.add_argument('--check_only', action='store_true',
help='Solo verificar archivos')
args = parser.parse_args()
print("🎯 === Entrenamiento ResNet50 para Nocciola ===")
print(f"📁 Directorio del proyecto: {PROJECT_DIR}")
print(f"🐍 Python ejecutable: {PYTHON_EXE}")
# Verificar archivos
if not check_files():
print("❌ Faltan archivos necesarios")
return False
if args.check_only:
print(" Solo verificación solicitada. Finalizando.")
return True
if args.compare:
compare_models()
return True
force_split = not args.no_force_split
success = True
# Entrenar modelos según selección
if args.model in ['resnet50', 'both']:
print(f"\n🤖 === Iniciando ResNet50 ===")
success &= train_resnet50(args.dataset, args.epochs, force_split)
if args.model in ['mobilenet', 'both']:
print(f"\n🤖 === Iniciando MobileNetV2 ===")
success &= train_mobilenet(args.dataset, args.epochs, force_split)
if success:
print(f"\n🎉 === Entrenamiento(s) Completado(s) ===")
# Mostrar ubicaciones de resultados
results_info = []
if args.model in ['resnet50', 'both']:
resnet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_resnet50_faseV')
results_info.append(f"📁 ResNet50: {resnet_results}")
if args.model in ['mobilenet', 'both']:
mobilenet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_mobilenet_faseV_V1')
results_info.append(f"📁 MobileNetV2: {mobilenet_results}")
for info in results_info:
print(info)
if args.model == 'both':
print(f"\n💡 Comparar resultados:")
print(f" - Revisar classification_report.txt en cada directorio")
print(f" - Comparar matrices de confusión")
print(f" - Evaluar tiempos de entrenamiento")
return True
return False
def show_help():
"""Mostrar ayuda detallada"""
print("""
📚 === Ayuda - ResNet50 Transfer Learning ===
NUEVO MODELO IMPLEMENTADO:
ResNet50 con transfer learning optimizado para el dataset Nocciola
CARACTERÍSTICAS PRINCIPALES:
🔧 Arquitectura: ResNet50 pre-entrenado (ImageNet)
⚙️ Optimizaciones: BatchNormalization, Dropout adaptativos
🎯 Fine-tuning: Descongelado selectivo de capas residuales
📊 Manejo robusto: Clases desbalanceadas automáticamente
DIFERENCIAS CON MOBILENET:
- Learning rate más conservador (5e-4 vs 1e-3)
- Data augmentation menos agresivo
- Fine-tuning más específico (conv5_x layers)
- Más épocas de fine-tuning (15 vs 10)
EJEMPLOS DE USO:
# Entrenar solo ResNet50
python train_resnet50.py --model resnet50 --epochs 25
# Comparar ambos modelos
python train_resnet50.py --model both --epochs 20
# Análisis de diferencias
python train_resnet50.py --compare
# Verificar setup
python train_resnet50.py --check_only
DATASETS SOPORTADOS:
- assignments: CSV con asignaciones de clases (recomendado)
- filtered: Dataset filtrado sin clases problemáticas
- original: Dataset completo con manejo robusto
OUTPUTS ESPECÍFICOS:
- best_resnet50_model.keras: Mejor modelo durante entrenamiento
- final_resnet50_model.keras: Modelo final
- resnet50_training_history.png: Gráficos específicos
- resnet50_confusion_matrix.png: Matriz de confusión
- resnet50_prediction_examples.png: Ejemplos de predicciones
""")
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] in ['--help', '-h', 'help']:
show_help()
else:
success = main()
if not success:
print("\n❌ Proceso terminado con errores")
print("💡 Usa --help para más información")
else:
print("\n🎉 ¡Proceso completado exitosamente!")