Phenology/Code/Unsupervised_learning/Past_codes/Calculate_NMI_ARI.py
2025-11-25 11:30:37 +01:00

351 lines
12 KiB
Python

"""
Calcular NMI y ARI desde Resultados de Clustering Previos
Este script recalcula las métricas de validación externa (NMI y ARI)
usando los archivos CSV ya generados, sin necesidad de re-ejecutar todo el pipeline.
Uso:
python Calculate_NMI_ARI.py
O con argumentos personalizados:
python Calculate_NMI_ARI.py --csv_path "ruta/a/clustering_results.csv"
"""
import os
import sys
import argparse
import pandas as pd
import numpy as np
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
# =============================================================================
# CONFIGURACIÓN POR DEFECTO
# =============================================================================
DEFAULT_CSV_PATH = r'C:\Users\sof12\Desktop\ML\CResults\Carciofo_Roboflow\results_clustering_avanzado_8C\clustering_results.csv'
DEFAULT_OUTPUT_DIR = r'C:\Users\sof12\Desktop\ML\CResults\Carciofo_Roboflow\results_clustering_avanzado_8C'
# =============================================================================
# FUNCIONES AUXILIARES
# =============================================================================
def print_header(text):
"""Imprimir encabezado con formato"""
print("\n" + "="*80)
print(f" {text}")
print("="*80)
def print_section(text):
"""Imprimir sección con formato"""
print(f"\n{''*80}")
print(f" {text}")
print(f"{''*80}")
def safe_read_csv(path):
"""Leer CSV con manejo de encoding"""
if not os.path.exists(path):
raise FileNotFoundError(f' CSV no encontrado: {path}')
for encoding in ['utf-8', 'latin-1', 'cp1252']:
try:
df = pd.read_csv(path, encoding=encoding)
print(f" CSV leído correctamente con encoding: {encoding}")
return df
except UnicodeDecodeError:
continue
raise ValueError(" No se pudo leer el CSV con ningún encoding")
# =============================================================================
# CÁLCULO DE MÉTRICAS
# =============================================================================
def calculate_metrics(true_labels, pred_labels, method_name="Clustering"):
"""Calcular todas las métricas de validación externa"""
# Métricas principales
ari = adjusted_rand_score(true_labels, pred_labels)
nmi = normalized_mutual_info_score(true_labels, pred_labels)
metrics = {
'method': method_name,
'ari': ari,
'nmi': nmi
}
return metrics
def print_metrics(metrics):
"""Imprimir métricas en formato legible"""
print(f"\n📊 Métricas de Validación Externa - {metrics['method']}")
print(f"{''*60}")
print(f" 🎯 Adjusted Rand Index (ARI): {metrics['ari']:>7.4f}")
print(f" 🎯 Normalized Mutual Info (NMI): {metrics['nmi']:>7.4f}")
print(f"{''*60}")
# Interpretación
print(f"\n💡 Interpretación:")
# ARI
if metrics['ari'] > 0.75:
ari_interp = "Excelente concordancia"
elif metrics['ari'] > 0.5:
ari_interp = "Buena concordancia"
elif metrics['ari'] > 0.25:
ari_interp = "Concordancia moderada"
else:
ari_interp = "Concordancia baja"
print(f" ARI ({metrics['ari']:.4f}): {ari_interp}")
# NMI
if metrics['nmi'] > 0.75:
nmi_interp = "Alta información mutua"
elif metrics['nmi'] > 0.5:
nmi_interp = "Información mutua moderada"
elif metrics['nmi'] > 0.25:
nmi_interp = "Información mutua baja"
else:
nmi_interp = "Poca información mutua"
print(f" NMI ({metrics['nmi']:.4f}): {nmi_interp}")
# =============================================================================
# VISUALIZACIÓN
# =============================================================================
def create_metrics_comparison(all_metrics, output_dir):
"""Crear gráfico comparativo de métricas entre métodos"""
if len(all_metrics) < 2:
return
df_metrics = pd.DataFrame(all_metrics)
# Gráfico de barras
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Comparación de Métricas de Validación Externa',
fontsize=16, fontweight='bold')
metrics_to_plot = ['ari', 'nmi']
metric_names = ['ARI', 'NMI']
for idx, (metric, name) in enumerate(zip(metrics_to_plot, metric_names)):
row = idx // 3
col = idx % 3
ax = axes[row, col]
bars = ax.bar(df_metrics['method'], df_metrics[metric], color='steelblue', alpha=0.7)
ax.set_ylabel(name, fontsize=11, fontweight='bold')
ax.set_title(f'{name} por Método', fontsize=12)
ax.set_ylim([0, 1])
ax.grid(axis='y', alpha=0.3)
# Añadir valores sobre las barras
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.3f}',
ha='center', va='bottom', fontsize=10)
# Rotar etiquetas si es necesario
ax.tick_params(axis='x', rotation=45)
# Ocultar el último subplot si es impar
axes[1, 2].axis('off')
plt.tight_layout()
comparison_path = os.path.join(output_dir, 'metrics_comparison.png')
plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
print(f"✅ Gráfico de comparación guardado: {comparison_path}")
plt.close()
# =============================================================================
# ANÁLISIS PRINCIPAL
# =============================================================================
def analyze_clustering_results(csv_path, output_dir):
"""Analizar resultados de clustering desde CSV"""
print_header("📊 ANÁLISIS DE MÉTRICAS DE CLUSTERING")
# Leer CSV
print_section("1. Cargando Datos")
print(f"📁 Ruta del CSV: {csv_path}")
df = safe_read_csv(csv_path)
print(f"📊 Total de filas: {len(df)}")
print(f"📋 Columnas disponibles: {list(df.columns)}")
# Identificar columnas
print_section("2. Identificando Columnas")
# Buscar columna de fase real
phase_col = None
for col in ['phase_P', 'fase_P', 'fase_P', 'fase', 'phase']:
if col in df.columns:
phase_col = col
break
if phase_col is None:
print(" No se encontró columna de fase real")
return
print(f" Columna de fase real: {phase_col}")
# Buscar columnas de clusters
cluster_cols = [col for col in df.columns if 'cluster' in col.lower()]
if len(cluster_cols) == 0:
print("❌ No se encontraron columnas de clusters")
return
print(f"✅ Columnas de clusters encontradas: {cluster_cols}")
# Análisis para cada método de clustering
print_section("3. Calculando Métricas")
all_metrics = []
all_contingencies = {}
for cluster_col in cluster_cols:
method_name = cluster_col.replace('cluster_', '').title()
print(f"\n🔍 Analizando método: {method_name}")
# Filtrar valores válidos (sin NaN)
valid_mask = (~df[phase_col].isna()) & (~df[cluster_col].isna())
if valid_mask.sum() == 0:
print(f" ⚠️ No hay datos válidos para {method_name}")
continue
print(f" 📊 Muestras válidas: {valid_mask.sum()} de {len(df)}")
# Obtener etiquetas
true_labels_raw = df.loc[valid_mask, phase_col]
pred_labels_raw = df.loc[valid_mask, cluster_col]
# Convertir a códigos numéricos
true_labels = pd.Categorical(true_labels_raw).codes
pred_labels = pred_labels_raw.astype(int).values
# Obtener nombres únicos
unique_true = sorted(true_labels_raw.unique())
unique_pred = sorted(pred_labels_raw.unique())
print(f" 🎯 Fases reales únicas: {len(unique_true)} - {unique_true}")
print(f" 🎯 Clusters únicos: {len(unique_pred)} - {list(unique_pred)}")
# Calcular métricas
metrics = calculate_metrics(true_labels, pred_labels, method_name)
all_metrics.append(metrics)
# Imprimir métricas
print_metrics(metrics)
# Comparación entre métodos
if len(all_metrics) > 1:
print_section("4. Comparación entre Métodos")
create_metrics_comparison(all_metrics, output_dir)
# Tabla comparativa
df_comparison = pd.DataFrame(all_metrics)
print(f"\n📊 Tabla Comparativa de Métricas:")
print(df_comparison.to_string(index=False))
# Guardar tabla
comparison_csv = os.path.join(output_dir, 'metrics_comparison.csv')
df_comparison.to_csv(comparison_csv, index=False)
print(f"\n✅ Tabla comparativa guardada: {comparison_csv}")
# Determinar mejor método
best_ari_idx = df_comparison['ari'].idxmax()
best_nmi_idx = df_comparison['nmi'].idxmax()
print(f"\n🏆 Mejor método por ARI: {df_comparison.loc[best_ari_idx, 'method']} ({df_comparison.loc[best_ari_idx, 'ari']:.4f})")
print(f"🏆 Mejor método por NMI: {df_comparison.loc[best_nmi_idx, 'method']} ({df_comparison.loc[best_nmi_idx, 'nmi']:.4f})")
# Guardar resumen JSON
print_section("5. Guardando Resumen")
summary = {
'fecha_analisis': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'csv_analizado': csv_path,
'total_muestras': int(len(df)),
'fase_columna': phase_col,
'metricas_por_metodo': all_metrics
}
import json
summary_path = os.path.join(output_dir, 'nmi_ari_summary.json')
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2)
print(f"✅ Resumen JSON guardado: {summary_path}")
# Resumen final
print_header("✅ ANÁLISIS COMPLETADO")
print(f"\n📁 Resultados guardados en: {output_dir}")
print(f"\n📄 Archivos generados:")
print(f" - nmi_ari_summary.json (Resumen completo)")
print(f" - metrics_comparison.csv (Tabla comparativa)")
print(f" - metrics_comparison.png (Gráfico comparativo)")
print("="*80 + "\n")
return all_metrics
# =============================================================================
# MAIN
# =============================================================================
def main():
"""Función principal"""
parser = argparse.ArgumentParser(
description='Calcular NMI y ARI desde resultados de clustering previos'
)
parser.add_argument(
'--csv_path',
type=str,
default=DEFAULT_CSV_PATH,
help='Ruta al CSV con resultados de clustering'
)
parser.add_argument(
'--output_dir',
type=str,
default=None,
help='Directorio de salida (por defecto: mismo que CSV)'
)
args = parser.parse_args()
# Determinar directorio de salida
if args.output_dir is None:
args.output_dir = os.path.dirname(args.csv_path)
# Verificar que el CSV existe
if not os.path.exists(args.csv_path):
print(f"\n❌ ERROR: El archivo CSV no existe:")
print(f" {args.csv_path}")
print(f"\n💡 Verifica la ruta o usa --csv_path para especificar otra ubicación")
sys.exit(1)
# Ejecutar análisis
try:
analyze_clustering_results(args.csv_path, args.output_dir)
except Exception as e:
print(f"\n❌ ERROR durante el análisis:")
print(f" {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == '__main__':
main()