351 lines
12 KiB
Python
351 lines
12 KiB
Python
"""
|
|
Calcular NMI y ARI desde Resultados de Clustering Previos
|
|
|
|
Este script recalcula las métricas de validación externa (NMI y ARI)
|
|
usando los archivos CSV ya generados, sin necesidad de re-ejecutar todo el pipeline.
|
|
|
|
Uso:
|
|
python Calculate_NMI_ARI.py
|
|
|
|
O con argumentos personalizados:
|
|
python Calculate_NMI_ARI.py --csv_path "ruta/a/clustering_results.csv"
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
|
|
from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from datetime import datetime
|
|
|
|
# =============================================================================
|
|
# CONFIGURACIÓN POR DEFECTO
|
|
# =============================================================================
|
|
|
|
DEFAULT_CSV_PATH = r'C:\Users\sof12\Desktop\ML\CResults\Carciofo_Roboflow\results_clustering_avanzado_8C\clustering_results.csv'
|
|
DEFAULT_OUTPUT_DIR = r'C:\Users\sof12\Desktop\ML\CResults\Carciofo_Roboflow\results_clustering_avanzado_8C'
|
|
|
|
# =============================================================================
|
|
# FUNCIONES AUXILIARES
|
|
# =============================================================================
|
|
|
|
def print_header(text):
|
|
"""Imprimir encabezado con formato"""
|
|
print("\n" + "="*80)
|
|
print(f" {text}")
|
|
print("="*80)
|
|
|
|
def print_section(text):
|
|
"""Imprimir sección con formato"""
|
|
print(f"\n{'─'*80}")
|
|
print(f" {text}")
|
|
print(f"{'─'*80}")
|
|
|
|
def safe_read_csv(path):
|
|
"""Leer CSV con manejo de encoding"""
|
|
if not os.path.exists(path):
|
|
raise FileNotFoundError(f' CSV no encontrado: {path}')
|
|
|
|
for encoding in ['utf-8', 'latin-1', 'cp1252']:
|
|
try:
|
|
df = pd.read_csv(path, encoding=encoding)
|
|
print(f" CSV leído correctamente con encoding: {encoding}")
|
|
return df
|
|
except UnicodeDecodeError:
|
|
continue
|
|
|
|
raise ValueError(" No se pudo leer el CSV con ningún encoding")
|
|
|
|
# =============================================================================
|
|
# CÁLCULO DE MÉTRICAS
|
|
# =============================================================================
|
|
|
|
def calculate_metrics(true_labels, pred_labels, method_name="Clustering"):
|
|
"""Calcular todas las métricas de validación externa"""
|
|
|
|
# Métricas principales
|
|
ari = adjusted_rand_score(true_labels, pred_labels)
|
|
nmi = normalized_mutual_info_score(true_labels, pred_labels)
|
|
|
|
|
|
metrics = {
|
|
'method': method_name,
|
|
'ari': ari,
|
|
'nmi': nmi
|
|
}
|
|
|
|
return metrics
|
|
|
|
def print_metrics(metrics):
|
|
"""Imprimir métricas en formato legible"""
|
|
print(f"\n📊 Métricas de Validación Externa - {metrics['method']}")
|
|
print(f"{'─'*60}")
|
|
print(f" 🎯 Adjusted Rand Index (ARI): {metrics['ari']:>7.4f}")
|
|
print(f" 🎯 Normalized Mutual Info (NMI): {metrics['nmi']:>7.4f}")
|
|
print(f"{'─'*60}")
|
|
|
|
# Interpretación
|
|
print(f"\n💡 Interpretación:")
|
|
|
|
# ARI
|
|
if metrics['ari'] > 0.75:
|
|
ari_interp = "Excelente concordancia"
|
|
elif metrics['ari'] > 0.5:
|
|
ari_interp = "Buena concordancia"
|
|
elif metrics['ari'] > 0.25:
|
|
ari_interp = "Concordancia moderada"
|
|
else:
|
|
ari_interp = "Concordancia baja"
|
|
print(f" ARI ({metrics['ari']:.4f}): {ari_interp}")
|
|
|
|
# NMI
|
|
if metrics['nmi'] > 0.75:
|
|
nmi_interp = "Alta información mutua"
|
|
elif metrics['nmi'] > 0.5:
|
|
nmi_interp = "Información mutua moderada"
|
|
elif metrics['nmi'] > 0.25:
|
|
nmi_interp = "Información mutua baja"
|
|
else:
|
|
nmi_interp = "Poca información mutua"
|
|
print(f" NMI ({metrics['nmi']:.4f}): {nmi_interp}")
|
|
|
|
|
|
# =============================================================================
|
|
# VISUALIZACIÓN
|
|
# =============================================================================
|
|
|
|
def create_metrics_comparison(all_metrics, output_dir):
|
|
"""Crear gráfico comparativo de métricas entre métodos"""
|
|
|
|
if len(all_metrics) < 2:
|
|
return
|
|
|
|
df_metrics = pd.DataFrame(all_metrics)
|
|
|
|
# Gráfico de barras
|
|
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
|
|
fig.suptitle('Comparación de Métricas de Validación Externa',
|
|
fontsize=16, fontweight='bold')
|
|
|
|
metrics_to_plot = ['ari', 'nmi']
|
|
metric_names = ['ARI', 'NMI']
|
|
|
|
for idx, (metric, name) in enumerate(zip(metrics_to_plot, metric_names)):
|
|
row = idx // 3
|
|
col = idx % 3
|
|
|
|
ax = axes[row, col]
|
|
|
|
bars = ax.bar(df_metrics['method'], df_metrics[metric], color='steelblue', alpha=0.7)
|
|
ax.set_ylabel(name, fontsize=11, fontweight='bold')
|
|
ax.set_title(f'{name} por Método', fontsize=12)
|
|
ax.set_ylim([0, 1])
|
|
ax.grid(axis='y', alpha=0.3)
|
|
|
|
# Añadir valores sobre las barras
|
|
for bar in bars:
|
|
height = bar.get_height()
|
|
ax.text(bar.get_x() + bar.get_width()/2., height,
|
|
f'{height:.3f}',
|
|
ha='center', va='bottom', fontsize=10)
|
|
|
|
# Rotar etiquetas si es necesario
|
|
ax.tick_params(axis='x', rotation=45)
|
|
|
|
# Ocultar el último subplot si es impar
|
|
axes[1, 2].axis('off')
|
|
|
|
plt.tight_layout()
|
|
comparison_path = os.path.join(output_dir, 'metrics_comparison.png')
|
|
plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
|
|
print(f"✅ Gráfico de comparación guardado: {comparison_path}")
|
|
plt.close()
|
|
|
|
# =============================================================================
|
|
# ANÁLISIS PRINCIPAL
|
|
# =============================================================================
|
|
|
|
def analyze_clustering_results(csv_path, output_dir):
|
|
"""Analizar resultados de clustering desde CSV"""
|
|
|
|
print_header("📊 ANÁLISIS DE MÉTRICAS DE CLUSTERING")
|
|
|
|
# Leer CSV
|
|
print_section("1. Cargando Datos")
|
|
print(f"📁 Ruta del CSV: {csv_path}")
|
|
|
|
df = safe_read_csv(csv_path)
|
|
print(f"📊 Total de filas: {len(df)}")
|
|
print(f"📋 Columnas disponibles: {list(df.columns)}")
|
|
|
|
# Identificar columnas
|
|
print_section("2. Identificando Columnas")
|
|
|
|
# Buscar columna de fase real
|
|
phase_col = None
|
|
for col in ['phase_P', 'fase_P', 'fase_P', 'fase', 'phase']:
|
|
if col in df.columns:
|
|
phase_col = col
|
|
break
|
|
|
|
if phase_col is None:
|
|
print(" No se encontró columna de fase real")
|
|
return
|
|
|
|
print(f" Columna de fase real: {phase_col}")
|
|
|
|
# Buscar columnas de clusters
|
|
cluster_cols = [col for col in df.columns if 'cluster' in col.lower()]
|
|
|
|
if len(cluster_cols) == 0:
|
|
print("❌ No se encontraron columnas de clusters")
|
|
return
|
|
|
|
print(f"✅ Columnas de clusters encontradas: {cluster_cols}")
|
|
|
|
# Análisis para cada método de clustering
|
|
print_section("3. Calculando Métricas")
|
|
|
|
all_metrics = []
|
|
all_contingencies = {}
|
|
|
|
for cluster_col in cluster_cols:
|
|
method_name = cluster_col.replace('cluster_', '').title()
|
|
|
|
print(f"\n🔍 Analizando método: {method_name}")
|
|
|
|
# Filtrar valores válidos (sin NaN)
|
|
valid_mask = (~df[phase_col].isna()) & (~df[cluster_col].isna())
|
|
|
|
if valid_mask.sum() == 0:
|
|
print(f" ⚠️ No hay datos válidos para {method_name}")
|
|
continue
|
|
|
|
print(f" 📊 Muestras válidas: {valid_mask.sum()} de {len(df)}")
|
|
|
|
# Obtener etiquetas
|
|
true_labels_raw = df.loc[valid_mask, phase_col]
|
|
pred_labels_raw = df.loc[valid_mask, cluster_col]
|
|
|
|
# Convertir a códigos numéricos
|
|
true_labels = pd.Categorical(true_labels_raw).codes
|
|
pred_labels = pred_labels_raw.astype(int).values
|
|
|
|
# Obtener nombres únicos
|
|
unique_true = sorted(true_labels_raw.unique())
|
|
unique_pred = sorted(pred_labels_raw.unique())
|
|
|
|
print(f" 🎯 Fases reales únicas: {len(unique_true)} - {unique_true}")
|
|
print(f" 🎯 Clusters únicos: {len(unique_pred)} - {list(unique_pred)}")
|
|
|
|
# Calcular métricas
|
|
metrics = calculate_metrics(true_labels, pred_labels, method_name)
|
|
all_metrics.append(metrics)
|
|
|
|
# Imprimir métricas
|
|
print_metrics(metrics)
|
|
|
|
# Comparación entre métodos
|
|
if len(all_metrics) > 1:
|
|
print_section("4. Comparación entre Métodos")
|
|
create_metrics_comparison(all_metrics, output_dir)
|
|
|
|
# Tabla comparativa
|
|
df_comparison = pd.DataFrame(all_metrics)
|
|
print(f"\n📊 Tabla Comparativa de Métricas:")
|
|
print(df_comparison.to_string(index=False))
|
|
|
|
# Guardar tabla
|
|
comparison_csv = os.path.join(output_dir, 'metrics_comparison.csv')
|
|
df_comparison.to_csv(comparison_csv, index=False)
|
|
print(f"\n✅ Tabla comparativa guardada: {comparison_csv}")
|
|
|
|
# Determinar mejor método
|
|
best_ari_idx = df_comparison['ari'].idxmax()
|
|
best_nmi_idx = df_comparison['nmi'].idxmax()
|
|
|
|
print(f"\n🏆 Mejor método por ARI: {df_comparison.loc[best_ari_idx, 'method']} ({df_comparison.loc[best_ari_idx, 'ari']:.4f})")
|
|
print(f"🏆 Mejor método por NMI: {df_comparison.loc[best_nmi_idx, 'method']} ({df_comparison.loc[best_nmi_idx, 'nmi']:.4f})")
|
|
|
|
# Guardar resumen JSON
|
|
print_section("5. Guardando Resumen")
|
|
|
|
summary = {
|
|
'fecha_analisis': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
|
'csv_analizado': csv_path,
|
|
'total_muestras': int(len(df)),
|
|
'fase_columna': phase_col,
|
|
'metricas_por_metodo': all_metrics
|
|
}
|
|
|
|
import json
|
|
summary_path = os.path.join(output_dir, 'nmi_ari_summary.json')
|
|
with open(summary_path, 'w') as f:
|
|
json.dump(summary, f, indent=2)
|
|
|
|
print(f"✅ Resumen JSON guardado: {summary_path}")
|
|
|
|
# Resumen final
|
|
print_header("✅ ANÁLISIS COMPLETADO")
|
|
print(f"\n📁 Resultados guardados en: {output_dir}")
|
|
print(f"\n📄 Archivos generados:")
|
|
print(f" - nmi_ari_summary.json (Resumen completo)")
|
|
print(f" - metrics_comparison.csv (Tabla comparativa)")
|
|
print(f" - metrics_comparison.png (Gráfico comparativo)")
|
|
print("="*80 + "\n")
|
|
|
|
return all_metrics
|
|
|
|
# =============================================================================
|
|
# MAIN
|
|
# =============================================================================
|
|
|
|
def main():
|
|
"""Función principal"""
|
|
|
|
parser = argparse.ArgumentParser(
|
|
description='Calcular NMI y ARI desde resultados de clustering previos'
|
|
)
|
|
parser.add_argument(
|
|
'--csv_path',
|
|
type=str,
|
|
default=DEFAULT_CSV_PATH,
|
|
help='Ruta al CSV con resultados de clustering'
|
|
)
|
|
parser.add_argument(
|
|
'--output_dir',
|
|
type=str,
|
|
default=None,
|
|
help='Directorio de salida (por defecto: mismo que CSV)'
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Determinar directorio de salida
|
|
if args.output_dir is None:
|
|
args.output_dir = os.path.dirname(args.csv_path)
|
|
|
|
# Verificar que el CSV existe
|
|
if not os.path.exists(args.csv_path):
|
|
print(f"\n❌ ERROR: El archivo CSV no existe:")
|
|
print(f" {args.csv_path}")
|
|
print(f"\n💡 Verifica la ruta o usa --csv_path para especificar otra ubicación")
|
|
sys.exit(1)
|
|
|
|
# Ejecutar análisis
|
|
try:
|
|
analyze_clustering_results(args.csv_path, args.output_dir)
|
|
except Exception as e:
|
|
print(f"\n❌ ERROR durante el análisis:")
|
|
print(f" {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
sys.exit(1)
|
|
|
|
if __name__ == '__main__':
|
|
main()
|