Supervised Learning models

SOFIA GARCIA s321387 2025-11-06 14:16:49 +01:00
parent bde2959227
commit a1e046c1ae
51 changed files with 8508 additions and 2 deletions


@ -0,0 +1,944 @@
"""
MobileNetV2 Transfer Learning para Clasificación de Fases Fenológicas - Nocciola
Adaptado para Visual Studio Code
Dataset: Nocciola GBIF
Objetivo: Predecir fase R (fenológica reproductive)
"""
import os
import shutil
import random
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
import json
# ----------------- CONFIG -----------------
PROJECT_PATH = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF'
IMAGES_DIR = PROJECT_PATH  # Images live in the main dataset directory
CSV_PATH = os.path.join(PROJECT_PATH, 'assignments.csv')  # Main metadata CSV
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_mobilenet_faseV_V1')
os.makedirs(OUTPUT_DIR, exist_ok=True)
IMG_SIZE = (224, 224)  # Recommended input size for MobileNetV2
BATCH_SIZE = 16  # Reduced for better stability
SEED = 42
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
FORCE_SPLIT = False
# ----------------- Utilities -----------------
def set_seed(seed=42):
"""Establecer semilla para reproducibilidad"""
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
def analyze_class_distribution(df, column_name='fase V'):
"""Analizar distribución de clases y detectar desbalances"""
print(f"\n📊 === Análisis de Distribución de Clases ===")
# Contar por clase
counts = df[column_name].value_counts()
total = len(df)
print(f"📊 Total de muestras: {total}")
print(f"📊 Número de clases: {len(counts)}")
print(f"📊 Distribución por clase:")
# Mostrar estadísticas detalladas
for clase, count in counts.items():
percentage = (count / total) * 100
print(f" - {clase}: {count} muestras ({percentage:.1f}%)")
# Detectar clases problemáticas
min_samples = 5 # Umbral mínimo recomendado
small_classes = counts[counts < min_samples]
if len(small_classes) > 0:
print(f"\n⚠️ Clases con menos de {min_samples} muestras:")
for clase, count in small_classes.items():
print(f" - {clase}: {count} muestras")
print(f"\n💡 Recomendaciones:")
print(f" 1. Considera recolectar más datos para estas clases")
print(f" 2. O fusionar clases similares")
print(f" 3. O usar técnicas de data augmentation específicas")
return counts, small_classes
def safe_read_csv(path):
"""Leer CSV con manejo de encoding"""
if not os.path.exists(path):
raise FileNotFoundError(f'CSV no encontrado: {path}')
try:
df = pd.read_csv(path, encoding='utf-8')
except UnicodeDecodeError:
try:
df = pd.read_csv(path, encoding='latin-1')
except:
df = pd.read_csv(path, encoding='iso-8859-1')
return df
def resolve_image_path(images_dir, img_id):
"""Resolver la ruta completa de una imagen"""
if pd.isna(img_id) or str(img_id).strip() == '':
return None
img_id = str(img_id).strip()
# Verificar si ya incluye extensión y existe
direct_path = os.path.join(images_dir, img_id)
if os.path.exists(direct_path):
return direct_path
# Probar con extensiones comunes
stem = os.path.splitext(img_id)[0]
for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
img_path = os.path.join(images_dir, stem + ext)
if os.path.exists(img_path):
return img_path
return None
def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED):
"""Crear estructura de carpetas para flow_from_directory"""
set_seed(seed)
    # Keep only rows with a valid phase V label and an existing image file
print(f"📊 Datos iniciales: {len(df)} filas")
# Filtrar filas con fase R válida
df_valid = df.dropna(subset=['fase V']).copy()
df_valid = df_valid[df_valid['fase V'].str.strip() != '']
print(f"📊 Con fase V válida: {len(df_valid)} filas")
# Verificar existencia de imágenes
valid_rows = []
for _, row in df_valid.iterrows():
img_path = resolve_image_path(images_dir, row['id_img'])
if img_path:
valid_rows.append(row)
else:
print(f"⚠️ Imagen no encontrada: {row['id_img']}")
if not valid_rows:
raise ValueError("❌ No se encontraron imágenes válidas")
df_final = pd.DataFrame(valid_rows)
print(f"📊 Con imágenes existentes: {len(df_final)} filas")
# Mostrar distribución de clases
fase_counts = df_final['fase V'].value_counts()
print(f"\n📊 Distribución de fases R:")
for fase, count in fase_counts.items():
print(f" - {fase}: {count} imágenes")
# Remover clases con muy pocas muestras (menos de 3)
min_samples = 3
valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
if len(valid_phases) < len(fase_counts):
excluded = fase_counts[fase_counts < min_samples].index.tolist()
print(f"⚠️ Excluyendo fases con menos de {min_samples} muestras: {excluded}")
df_final = df_final[df_final['fase V'].isin(valid_phases)]
print(f"📊 Después de filtrar: {len(df_final)} filas, {len(valid_phases)} clases")
labels = df_final['fase V'].unique().tolist()
print(f"📊 Clases finales: {labels}")
# Mezclar y dividir datos
df_shuffled = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
n = len(df_shuffled)
n_train = int(n * split['train'])
n_val = int(n * split['val'])
train_df = df_shuffled.iloc[:n_train]
val_df = df_shuffled.iloc[n_train:n_train + n_val]
test_df = df_shuffled.iloc[n_train + n_val:]
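    # Note: this is a plain random split, not stratified by 'fase V', so very
    # small classes can end up absent from val/test; the evaluation code later
    # handles that case explicitly.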
print(f"📊 División final:")
print(f" - Entrenamiento: {len(train_df)} imágenes")
print(f" - Validación: {len(val_df)} imágenes")
print(f" - Prueba: {len(test_df)} imágenes")
# Crear estructura de carpetas
for part in ['train', 'val', 'test']:
for label in labels:
label_dir = os.path.join(out_dir, part, str(label))
os.makedirs(label_dir, exist_ok=True)
# Función para copiar imágenes
def copy_subset(subdf, subset_name):
copied, missing = 0, 0
for _, row in subdf.iterrows():
src = resolve_image_path(images_dir, row['id_img'])
if src:
fase = str(row['fase V'])
                dst = os.path.join(out_dir, subset_name, fase,
                                   f"{os.path.splitext(str(row['id_img']))[0]}.jpg")
try:
shutil.copy2(src, dst)
copied += 1
except Exception as e:
print(f"⚠️ Error copiando {src}: {e}")
missing += 1
else:
missing += 1
print(f"{subset_name}: {copied} imágenes copiadas, {missing} fallidas")
return copied
# Copiar imágenes a las carpetas correspondientes
copy_subset(train_df, 'train')
copy_subset(val_df, 'val')
copy_subset(test_df, 'test')
return train_df, val_df, test_df
def main():
"""Función principal del pipeline"""
parser = argparse.ArgumentParser(description='MobileNetV2 Transfer Learning para Nocciola')
parser.add_argument('--csv_path', type=str, default=CSV_PATH,
help='Ruta al archivo CSV con metadatos')
parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
help='Directorio con las imágenes')
parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
help='Directorio de salida para resultados')
parser.add_argument('--epochs', type=int, default=30,
help='Número de épocas de entrenamiento')
parser.add_argument('--force_split', action='store_true',
help='Forzar recreación del split de datos')
args = parser.parse_args()
print('\n🚀 === Inicio del pipeline MobileNetV2 para Nocciola ===')
print(f"📁 Directorio de imágenes: {args.images_dir}")
print(f"📄 Archivo CSV: {args.csv_path}")
print(f"📂 Directorio de salida: {args.output_dir}")
# Establecer semilla
set_seed(SEED)
# Crear directorio de salida
os.makedirs(args.output_dir, exist_ok=True)
# Leer datos
print('\n📊 === Cargando datos ===')
df = safe_read_csv(args.csv_path)
print(f'📊 Total de registros en CSV: {len(df)}')
print(f'📊 Columnas disponibles: {list(df.columns)}')
# Verificar columnas requeridas
required_cols = {'id_img', 'fase V'}
if not required_cols.issubset(set(df.columns)):
missing = required_cols - set(df.columns)
raise ValueError(f'❌ CSV debe contener las columnas: {missing}')
# Analizar distribución de clases antes del procesamiento
analyze_class_distribution(df, 'fase V')
# Preparar estructura de carpetas
SPLIT_DIR = os.path.join(args.output_dir, 'data_split')
if args.force_split and os.path.exists(SPLIT_DIR):
print("🗑️ Eliminando split existente...")
shutil.rmtree(SPLIT_DIR)
if not os.path.exists(SPLIT_DIR):
print("\n📁 === Creando nueva división de datos ===")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Guardar información del split
train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False)
val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False)
test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False)
else:
print("\n♻️ === Reutilizando división existente ===")
# Cargar información del split si existe
try:
train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
except:
print("⚠️ No se pudieron cargar los archivos de split, recreando...")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Crear generadores de datos
print("\n🔄 === Creando generadores de datos ===")
# Data augmentation para entrenamiento
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.1,
zoom_range=0.1,
horizontal_flip=True,
fill_mode='nearest'
)
# Solo normalización para validación y test
val_test_datagen = ImageDataGenerator(rescale=1./255)
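    # Note: rescale=1./255 maps pixels to [0, 1]. The ImageNet weights were
    # trained with tf.keras.applications.mobilenet_v2.preprocess_input, which
    # scales inputs to [-1, 1]; switching to that preprocessing may improve
    # transfer performance.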
# Crear generadores
train_gen = train_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'train'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
seed=SEED
)
val_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'val'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
test_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'test'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
# Guardar mapeo de clases
class_indices = train_gen.class_indices
print(f'🏷️ Mapeo de clases: {class_indices}')
with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
json.dump(class_indices, f, indent=2)
print(f"📊 Muestras por conjunto:")
print(f" - Entrenamiento: {train_gen.samples}")
print(f" - Validación: {val_gen.samples}")
print(f" - Prueba: {test_gen.samples}")
print(f" - Número de clases: {train_gen.num_classes}")
# Crear y entrenar modelo
print("\n🤖 === Construcción del modelo ===")
# Modelo base MobileNetV2
base_model = MobileNetV2(
weights='imagenet',
include_top=False,
input_shape=(*IMG_SIZE, 3)
)
base_model.trainable = False # Congelar inicialmente
# Construir modelo secuencial
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dropout(0.3),
layers.Dense(128, activation='relu'),
layers.Dropout(0.3),
layers.Dense(train_gen.num_classes, activation='softmax')
])
# Compilar modelo
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss='categorical_crossentropy',
metrics=['accuracy']
)
print("📋 Resumen del modelo:")
model.summary()
# Calcular pesos de clase
print("\n⚖️ === Calculando pesos de clase ===")
try:
# Obtener etiquetas de entrenamiento
train_labels = []
for i in range(len(train_gen)):
_, labels = train_gen[i]
train_labels.extend(np.argmax(labels, axis=1))
if len(train_labels) >= train_gen.samples:
break
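        # Note: train_gen.classes already holds every sample's integer label,
        # so it could be passed to compute_class_weight directly instead of
        # iterating over augmented batches.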
# Calcular pesos balanceados
class_weights = class_weight.compute_class_weight(
'balanced',
classes=np.unique(train_labels),
y=train_labels
)
class_weight_dict = dict(zip(np.unique(train_labels), class_weights))
print(f"⚖️ Pesos de clase: {class_weight_dict}")
except Exception as e:
print(f"⚠️ Error calculando pesos de clase: {e}")
class_weight_dict = None
# Callbacks para entrenamiento
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=7,
restore_best_weights=True,
verbose=1
)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
os.path.join(args.output_dir, 'best_model.keras'),
save_best_only=True,
monitor='val_loss',
verbose=1
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.2,
patience=3,
min_lr=1e-7,
verbose=1
)
callbacks = [early_stopping, model_checkpoint, reduce_lr]
# Entrenamiento inicial
print(f"\n🏋️ === Entrenamiento inicial ({args.epochs} épocas) ===")
try:
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
class_weight=class_weight_dict,
verbose=1
)
print("✅ Entrenamiento inicial completado")
except Exception as e:
print(f"❌ Error durante entrenamiento: {e}")
# Entrenar sin class_weight si hay problemas
print("🔄 Intentando entrenamiento sin pesos de clase...")
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
verbose=1
)
# Fine-tuning
print("\n🔧 === Fine-tuning ===")
    # Unfreeze part of the base model for fine-tuning
    base_model.trainable = True
    fine_tune_at = 100  # Freeze the first 100 base layers; only the layers above this index are fine-tuned
for layer in base_model.layers[:fine_tune_at]:
layer.trainable = False
# Recompilar con learning rate más bajo
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Continuar entrenamiento
fine_tune_epochs = 10
total_epochs = len(history.history['loss']) + fine_tune_epochs
try:
history_fine = model.fit(
train_gen,
validation_data=val_gen,
epochs=total_epochs,
initial_epoch=len(history.history['loss']),
callbacks=callbacks,
verbose=1
)
print("✅ Fine-tuning completado")
# Combinar historiales
for key in history.history:
if key in history_fine.history:
history.history[key].extend(history_fine.history[key])
except Exception as e:
print(f"⚠️ Error durante fine-tuning: {e}")
print("Continuando con modelo del entrenamiento inicial...")
# Evaluación final
print("\n📊 === Evaluación en conjunto de prueba ===")
# Cargar mejor modelo
try:
model.load_weights(os.path.join(args.output_dir, 'best_model.keras'))
print("✅ Cargado mejor modelo guardado")
except:
print("⚠️ Usando modelo actual")
# Guardar modelo final
model.save(os.path.join(args.output_dir, 'final_model.keras'))
print("💾 Modelo final guardado")
# Predicciones en test
test_gen.reset()
y_pred_prob = model.predict(test_gen, verbose=1)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_gen.classes
# Mapeo de índices a nombres de clase
index_to_class = {v: k for k, v in class_indices.items()}
# Obtener solo las clases que realmente aparecen en el conjunto de test
unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
test_class_names = [index_to_class[i] for i in unique_test_classes]
print(f"📊 Clases en conjunto de test: {len(unique_test_classes)}")
print(f"📊 Todas las clases entrenadas: {len(class_indices)}")
print(f"📊 Clases presentes en test: {test_class_names}")
# Verificar si hay clases faltantes
all_classes = set(range(len(class_indices)))
test_classes = set(unique_test_classes)
missing_classes = all_classes - test_classes
if missing_classes:
missing_names = [index_to_class[i] for i in missing_classes]
print(f"⚠️ Clases sin muestras en test: {missing_names}")
# Reporte de clasificación con clases filtradas
print("\n📋 === Reporte de Clasificación ===")
try:
report = classification_report(
y_true, y_pred,
labels=unique_test_classes, # Especificar las clases exactas
target_names=test_class_names,
output_dict=False,
zero_division=0 # Manejar divisiones por cero
)
print(report)
# Guardar reporte
with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
f.write(f"Clases evaluadas: {test_class_names}\n")
f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
f.write(report)
except Exception as e:
print(f"❌ Error en classification_report: {e}")
print("📊 Generando reporte alternativo...")
# Reporte manual si falla el automático
from collections import Counter
true_counts = Counter(y_true)
pred_counts = Counter(y_pred)
print("\n📊 Distribución manual:")
print("Clase | Verdaderos | Predichos")
print("-" * 35)
for class_idx in unique_test_classes:
class_name = index_to_class[class_idx]
true_count = true_counts.get(class_idx, 0)
pred_count = pred_counts.get(class_idx, 0)
print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")
# Calcular accuracy básico
accuracy = np.mean(y_true == y_pred)
print(f"\n📊 Accuracy general: {accuracy:.4f}")
# Guardar reporte manual
with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
f.write("REPORTE MANUAL DE CLASIFICACIÓN\n")
f.write("=" * 40 + "\n\n")
f.write(f"Clases evaluadas: {test_class_names}\n")
f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
f.write("Distribución por clase:\n")
f.write("Clase | Verdaderos | Predichos\n")
f.write("-" * 35 + "\n")
for class_idx in unique_test_classes:
class_name = index_to_class[class_idx]
true_count = true_counts.get(class_idx, 0)
pred_count = pred_counts.get(class_idx, 0)
f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n")
f.write(f"\nAccuracy general: {accuracy:.4f}\n")
# Matriz de confusión con clases filtradas
cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
print(f"\n🔢 Matriz de Confusión ({len(unique_test_classes)} clases):")
print(cm)
np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
cm, delimiter=',', fmt='%d')
# Visualizaciones con clases filtradas
print("\n📈 === Generando visualizaciones ===")
# Gráfico de entrenamiento
plot_training_history(history, args.output_dir)
# Matriz de confusión visual con clases filtradas
plot_confusion_matrix(cm, test_class_names, args.output_dir)
# Ejemplos de predicciones con clases filtradas
plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)
print(f"\n🎉 === Pipeline completado ===")
print(f"📁 Resultados guardados en: {args.output_dir}")
print(f"📊 Precisión final en test: {np.mean(y_true == y_pred):.4f}")
print(f"📊 Clases evaluadas: {len(unique_test_classes)}/{len(class_indices)}")
# Información adicional sobre clases desbalanceadas
if missing_classes:
print(f"\n⚠️ === Información sobre Clases Desbalanceadas ===")
print(f"❌ Clases sin muestras en test: {len(missing_classes)}")
for missing_idx in missing_classes:
missing_name = index_to_class[missing_idx]
print(f" - {missing_name} (índice {missing_idx})")
print(f"💡 Sugerencia: Considera aumentar el dataset o fusionar clases similares")
def plot_training_history(history, output_dir):
"""Graficar historial de entrenamiento"""
try:
plt.figure(figsize=(12, 4))
# Accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Entrenamiento')
if 'val_accuracy' in history.history:
plt.plot(history.history['val_accuracy'], label='Validación')
plt.title('Precisión del Modelo')
plt.xlabel('Época')
plt.ylabel('Precisión')
plt.legend()
plt.grid(True)
# Loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Entrenamiento')
if 'val_loss' in history.history:
plt.plot(history.history['val_loss'], label='Validación')
plt.title('Pérdida del Modelo')
plt.xlabel('Época')
plt.ylabel('Pérdida')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Gráfico de entrenamiento guardado")
except Exception as e:
print(f"⚠️ Error creando gráfico de entrenamiento: {e}")
def plot_confusion_matrix(cm, class_names, output_dir):
"""Graficar matriz de confusión"""
try:
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.title('Matriz de Confusión')
plt.ylabel('Etiqueta Verdadera')
plt.xlabel('Etiqueta Predicha')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Matriz de confusión guardada")
except Exception as e:
print(f"⚠️ Error creando matriz de confusión: {e}")
def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
"""Mostrar ejemplos de predicciones correctas e incorrectas"""
try:
# Obtener índices de predicciones correctas e incorrectas
correct_idx = np.where(y_true == y_pred)[0]
incorrect_idx = np.where(y_true != y_pred)[0]
# Seleccionar ejemplos
n_correct = min(n_examples // 2, len(correct_idx))
n_incorrect = min(n_examples // 2, len(incorrect_idx))
selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else []
selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else []
selected_indices = np.concatenate([selected_correct, selected_incorrect])
if len(selected_indices) == 0:
print("⚠️ No hay ejemplos para mostrar")
return
# Crear gráfico
n_show = len(selected_indices)
cols = 4
rows = (n_show + cols - 1) // cols
plt.figure(figsize=(15, 4 * rows))
for i, idx in enumerate(selected_indices):
plt.subplot(rows, cols, i + 1)
# Obtener imagen
# Nota: esto es una aproximación, idealmente necesitaríamos acceder a las imágenes originales
img_path = test_gen.filepaths[idx]
img = plt.imread(img_path)
plt.imshow(img)
plt.axis('off')
true_label = class_names[y_true[idx]]
pred_label = class_names[y_pred[idx]]
color = 'green' if y_true[idx] == y_pred[idx] else 'red'
plt.title(f'Real: {true_label}\nPredicción: {pred_label}',
color=color, fontsize=10)
plt.suptitle('Ejemplos de Predicciones (Verde=Correcta, Rojo=Incorrecta)', fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'prediction_examples.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✅ Ejemplos de predicciones guardados")
except Exception as e:
print(f"⚠️ Error creando ejemplos de predicciones: {e}")
if __name__ == "__main__":
    main()
# ----------------- LEGACY MAIN (notebook-style) -----------------
# Everything below is an earlier, non-refactored version of the pipeline. It
# runs at module level (i.e. in addition to main() above) whenever this file
# is executed; keep it only for reference or guard it before running.
print('\n=== Start of the pipeline ===')
df = safe_read_csv(CSV_PATH)
print('Total registered images in the CSV:', len(df))
# Check columns
required_cols = {'id_img','fase V'}
if not required_cols.issubset(set(df.columns)):
raise ValueError(f'CSV must contain the columns: {required_cols}')
# Prepare folders
SPLIT_DIR = os.path.join(PROJECT_PATH, 'results_nocc/split_fase V')
if FORCE_SPLIT:
shutil.rmtree(SPLIT_DIR, ignore_errors=True)
if not os.path.exists(SPLIT_DIR):
print("Creating a new split...")
train_df, val_df, test_df = prepare_image_folders(df, IMAGES_DIR, SPLIT_DIR)
else:
print("Reusing existing split...")
# Load the dataframes from the created split directories
train_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'train')).rglob('*.jpg')], columns=['id_img', 'fase'])
val_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'val')).rglob('*.jpg')], columns=['id_img', 'fase'])
test_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'test')).rglob('*.jpg')], columns=['id_img', 'fase'])
# Data generators
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.1,
zoom_range=0.1,
horizontal_flip=True,
fill_mode='nearest')
val_test_datagen = ImageDataGenerator(rescale=1./255)
train_gen = train_datagen.flow_from_directory(os.path.join(SPLIT_DIR,'train'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
seed=SEED)
val_gen = val_test_datagen.flow_from_directory(os.path.join(SPLIT_DIR,'val'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False)
test_gen = val_test_datagen.flow_from_directory(os.path.join(SPLIT_DIR,'test'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False)
# Save class mapping->index
class_indices = train_gen.class_indices
print('Class indices:', class_indices)
import json
with open(os.path.join(OUTPUT_DIR,'class_indices.txt'),'w') as f:
json.dump(class_indices, f)
# ----------------- Model (Transfer Learning with MobileNetV2) -----------------
print('\n=== Start of training (frozen ImageNet base) ===')
# Define the MobileNetV2 base model
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(*IMG_SIZE,3))
# Set the base model to not be trainable initially
base_model.trainable = False
# Build a new sequential model
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dropout(0.3),
layers.Dense(128, activation='relu'),
layers.Dropout(0.3),
layers.Dense(train_gen.num_classes, activation='softmax') # Use the correct number of classes
])
# Compile the new model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# Define callbacks
early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
chk = tf.keras.callbacks.ModelCheckpoint(os.path.join(OUTPUT_DIR,'best_model.keras'), save_best_only=True)
# Calculate balanced class weights from the generator's integer labels so the
# weight keys always match train_gen.class_indices (train_df may carry either a
# 'fase' or a 'fase V' column depending on how the split was produced)
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_gen.classes),
    y=train_gen.classes
)
class_weights = dict(zip(np.unique(train_gen.classes), class_weights))
print("Class weights for training:", class_weights)
# Train the new model
EPOCHS = 45 # Use the original number of epochs for the first phase
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=EPOCHS,
callbacks=[early, chk],
class_weight=class_weights
)
# --- FINE-TUNING ---
print('\n=== Start of fine-tuning phase ===')
# Unfreeze the top of the base model for fine-tuning
base_model.trainable = True
fine_tune_at = 100  # Freeze the first 100 layers; only the layers above this index are fine-tuned
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# Recompiling the model with a lower learning rate
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
# Continuar entrenamiento
fine_tune_epochs = 10
total_epochs = EPOCHS + fine_tune_epochs
history_fine_tune = model.fit(
train_gen,
validation_data=val_gen,
epochs=total_epochs,
    initial_epoch=history.epoch[-1] + 1,  # resume right after the last completed epoch
callbacks=[early, chk] # Using the same callbacks
)
# ----------------- Evaluación en Test -----------------
print('\n=== Evaluation in a test set ===')
# Load the best weights saved during the training (from either initial or fine-tuning phase)
model.load_weights(os.path.join(OUTPUT_DIR,'best_model.keras'))
# Predictions
y_pred_prob = model.predict(test_gen)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_gen.classes
# Load the full class indices mapping
import json
with open(os.path.join(OUTPUT_DIR,'class_indices.txt'),'r') as f:
full_class_indices = json.load(f)
# Get the corresponding class names for all classes from the loaded class_indices
index_to_class = {v: k for k, v in full_class_indices.items()}
all_class_names = [index_to_class[i] for i in sorted(index_to_class.keys())]
# Get the unique class indices present in the test set
unique_test_indices = np.unique(y_true)
# Get the corresponding class names for the unique test indices
test_labels_filtered = [index_to_class[i] for i in unique_test_indices]
# Reporte
report = classification_report(y_true, y_pred, labels=unique_test_indices, target_names=test_labels_filtered) # Use unique_test_indices for labels and test_labels_filtered for target_names
cm = confusion_matrix(y_true, y_pred, labels=unique_test_indices) # Specify labels for confusion matrix
print('\nClassification Report:\n', report)
print('\nConfusion Matrix:\n', cm)
with open(os.path.join(OUTPUT_DIR,'classification_report.txt'),'w') as f:
f.write(report)
np.savetxt(os.path.join(OUTPUT_DIR,'confusion_matrix.csv'), cm, delimiter=',', fmt='%d')
# ----------------- Visualizations -----------------
def show_examples(test_gen, y_true, y_pred, labels, n=6):
filepaths = []
for i in range(len(test_gen.filepaths)):
filepaths.append(test_gen.filepaths[i])
# select examples
correct_idx = [i for i,(a,b) in enumerate(zip(y_true,y_pred)) if a==b]
wrong_idx = [i for i,(a,b) in enumerate(zip(y_true,y_pred)) if a!=b]
examples = (correct_idx[:n//2] if len(correct_idx)>0 else []) + (wrong_idx[:n//2] if len(wrong_idx)>0 else [])
plt.figure(figsize=(15,8))
for i, idx in enumerate(examples):
img = plt.imread(filepaths[idx])
plt.subplot(2, n//2, i+1)
plt.imshow(img)
plt.axis('off')
plt.title(f'True: {labels[y_true[idx]]}\nPred: {labels[y_pred[idx]]}')
plt.suptitle('Examples: Right and Wrong')
plt.show()
# Call the function with the correct variables from the evaluation step
show_examples(test_gen, y_true, y_pred, all_class_names, n=6)
#
plt.figure(figsize=(8,4))
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='val_acc')
# Include fine-tuning history if it exists
if 'history_fine_tune' in locals():
plt.plot(history_fine_tune.history['accuracy'], label='fine_tune_train_acc')
plt.plot(history_fine_tune.history['val_accuracy'], label='fine_tune_val_acc')
plt.title('Accuracy during training')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid()
plt.show()
# Heatmap of the confusion matrix
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=all_class_names, yticklabels=all_class_names, cmap='Blues')
plt.xlabel('Prediction')
plt.ylabel('True')
plt.title('Matriz de confusión')
plt.show()


@ -0,0 +1,268 @@
# ResNet50 Transfer Learning for Nocciola - Documentation
## 🚀 **Complete ResNet50 Implementation**
This project implements transfer learning with ResNet50 for classifying hazelnut (nocciola) phenological phases, based on the proven structure of MobileNetV1.py but tuned specifically for ResNet50.
## 📁 **File Structure**
```
Code/Supervised_learning/
├── ResNET.py              # Main ResNet50 script ✅ NEW
├── train_resnet50.py      # Convenience training script ✅ NEW
├── MobileNetV1.py         # MobileNetV2 script (proven reference)
└── README_ResNet50.md     # This documentation ✅ NEW
```
## 🔧 **ResNet50-Specific Features**
### **Optimized Architecture:**
```python
# ResNet50 base model
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
# Custom layers tuned for ResNet50
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    BatchNormalization(),            # ✅ Important for ResNet50
    Dropout(0.5),                    # ✅ Higher dropout
    Dense(256, activation='relu'),   # ✅ More units
    BatchNormalization(),
    Dropout(0.3),
    Dense(num_classes, activation='softmax')
])
```
### **Specific Optimizations:**
1. **Conservative Learning Rate:**
   - Initial: `5e-4` (vs `1e-3` in MobileNet)
   - Fine-tuning: `1e-6` (very low, for stability)
2. **Reduced Data Augmentation:**
   - Rotation: 15° (vs 20° in MobileNet)
   - Shifts/Zoom: 0.08 (vs 0.1 in MobileNet)
   - More conservative, to avoid overfitting
3. **Selective Fine-tuning:**
   - Only the last residual blocks (conv5_x)
   - Cut-off point: layer 140 (of 175 total)
   - Preserves generic features, adapts task-specific ones
4. **Adapted Callbacks:**
   - EarlyStopping patience: 10 (vs 7 in MobileNet)
   - More fine-tuning epochs: 15 (vs 10)
   - ReduceLR factor: 0.3 (more aggressive; see the sketch below)
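
Putting those settings together, a minimal sketch of the two training phases (it reuses the `base_model` and `model` objects from the architecture snippet above; the values mirror this list and are starting points, not a definitive recipe):
```python
import tensorflow as tf

# Phase 1: frozen base, conservative learning rate
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
              loss='categorical_crossentropy', metrics=['accuracy'])

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                     restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.3,
                                         patience=5, min_lr=1e-8),
]

# Phase 2: unfreeze from layer 140 (conv5_x blocks) and drop the learning rate
base_model.trainable = True
for layer in base_model.layers[:140]:
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
              loss='categorical_crossentropy', metrics=['accuracy'])
```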
## 📊 **ResNet50 vs MobileNetV2 Comparison**
| Feature | ResNet50 | MobileNetV2 |
|---------|----------|-------------|
| **Parameters** | ~25M | ~3.4M |
| **Memory** | High | Low |
| **Speed** | Slow | Fast |
| **Accuracy** | High | Medium-High |
| **Overfitting** | More prone | Less prone |
| **Ideal dataset** | Large/Complex | Small/Medium |
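
The parameter counts in the table can be checked directly with Keras (approximate; full classifiers with the ImageNet head, built without downloading weights):
```python
from tensorflow.keras.applications import MobileNetV2, ResNet50

for name, cls in [('MobileNetV2', MobileNetV2), ('ResNet50', ResNet50)]:
    m = cls(weights=None, include_top=True)  # architecture only, no download
    print(f'{name}: {m.count_params() / 1e6:.1f}M parameters')
```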
## 🎯 **Recommended Use Cases**
### **Use ResNet50 when:**
- ✅ The dataset has enough samples (>1000)
- ✅ Images contain complex patterns
- ✅ Accuracy matters more than speed
- ✅ Robust hardware is available
- ✅ Research / detailed analysis
### **Use MobileNetV2 when:**
- ✅ The dataset is small (<500 samples)
- ✅ Speed matters
- ✅ Resources are limited
- ✅ Production / mobile deployment
- ✅ Rapid prototyping
## 🚀 **Running the Scripts**
### **Option 1: Convenience Script (Recommended)**
```bash
# Train ResNet50 only
python train_resnet50.py --model resnet50 --epochs 25
# Compare both models
python train_resnet50.py --model both --epochs 20
# Show a detailed comparison
python train_resnet50.py --compare
```
### **Option 2: Direct Execution**
```bash
# Basic ResNet50 run
python ResNET.py --epochs 25 --force_split
# ResNet50 with a specific dataset
python ResNET.py --csv_path "assignments.csv" --epochs 30
```
### **Available Parameters:**
- `--csv_path`: Path to the CSV (default: assignments.csv)
- `--images_dir`: Image directory
- `--output_dir`: Results directory (default: results_resnet50_faseV)
- `--epochs`: Number of training epochs (default: 25)
- `--force_split`: Recreate the data split
## 📈 **ResNet50 Training Process**
### **Phase 1: Initial Training**
1. Base model frozen (ImageNet weights)
2. Only the custom head is trained
3. Conservative learning rate (5e-4)
4. Callbacks with increased patience
### **Phase 2: Selective Fine-tuning**
5. Unfreeze only the conv5_x layers (last residual blocks), as sketched below
6. Very low learning rate (1e-6)
7. 15 additional epochs
8. Strict overfitting monitoring
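
A minimal sketch of the selective unfreeze in Phase 2, assuming the standard layer names of the Keras ResNet50 (`conv5_block1_...`, `conv5_block2_...`, ...); the hard-coded index 140 used in ResNET.py approximates the same cut-off:
```python
from tensorflow.keras.applications import ResNet50

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True
for layer in base_model.layers:
    # Train only the conv5_x residual blocks; keep everything earlier frozen
    layer.trainable = layer.name.startswith('conv5_')
print(f"Trainable layers: {sum(l.trainable for l in base_model.layers)}/{len(base_model.layers)}")
```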
### **Robust Evaluation:**
9. Automatic handling of imbalanced classes
10. ResNet50-specific metrics
11. Differentiated visualizations
## 📂 **Generated Results**
### **Models:**
- `best_resnet50_model.keras`: Best model found during training
- `final_resnet50_model.keras`: Final complete model
### **Reports:**
- `classification_report.txt`: Detailed report tagged as ResNet50
- `confusion_matrix.csv`: Numeric confusion matrix
### **Visualizations:**
- `resnet50_training_history.png`: Training curves
- `resnet50_confusion_matrix.png`: Visual confusion matrix
- `resnet50_prediction_examples.png`: Prediction examples
### **Data Splits:**
- `train_split.csv`, `val_split.csv`, `test_split.csv`
- `class_indices.json`: Class mapping
## ⚙️ **Technical Configuration**
### **System Requirements:**
- RAM: 8GB minimum (16GB recommended)
- GPU: Optional but strongly recommended
- Disk: ~3GB for the full set of results
### **Dependencies:**
```bash
tensorflow>=2.8.0
scikit-learn
pandas
numpy
matplotlib
seaborn
```
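They can be installed in one step, for example:
```bash
pip install "tensorflow>=2.8.0" scikit-learn pandas numpy matplotlib seaborn
```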
### **Internal Configuration:**
```python
IMG_SIZE = (224, 224)   # ImageNet standard
BATCH_SIZE = 16         # Reduced for ResNet50
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
```
## 🔍 **Differences vs the MobileNet Implementation**
### **1. Architecture:**
- Additional BatchNormalization layers
- More aggressive dropout (0.5 vs 0.3)
- Larger Dense layer (256 vs 128 units)
### **2. Training:**
- More conservative learning rates
- More selective fine-tuning
- More fine-tuning epochs
### **3. Data Augmentation:**
- Smaller rotations (15° vs 20°)
- Reduced shifts (0.08 vs 0.1)
- Less aggressive overall, suited to ResNet50
### **4. Callbacks:**
- Increased patience (10 vs 7)
- More aggressive ReduceLR factor (0.3 vs 0.2)
- ResNet-specific monitoring
### **5. Outputs:**
- Differentiated file names (`resnet50_*`)
- Reports labelled with the model name
- Model-specific metrics
## 🎯 **Usage Recommendations**
### **For the Current Nocciola Dataset:**
Since the dataset is relatively small (~500 samples):
1. **First choice:** MobileNetV2 (better suited)
2. **Second choice:** ResNet50 with strong regularization
3. **Comparison:** Train both and compare the results
### **Recommended Command:**
```bash
# Compare both models with few epochs
python train_resnet50.py --model both --epochs 15
# Analyze the results and pick the better one
```
## 🚨 **Troubleshooting**
### **Overfitting with ResNet50:**
- Reduce the number of training epochs
- Increase dropout
- Use a filtered dataset
- Add more data augmentation
### **Underfitting:**
- Increase the number of epochs
- Reduce regularization
- Use a higher learning rate
- Unfreeze more layers
### **Memory problems:**
- Reduce BATCH_SIZE to 8 or 4 (see the sketch below)
- Use gradient checkpointing
- Close other applications
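
A small sketch of the batch-size reduction, plus on-demand GPU memory growth (a standard TensorFlow option not listed above, but often enough on its own):
```python
import tensorflow as tf

# Allocate GPU memory on demand instead of reserving it all up front
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

BATCH_SIZE = 8  # or 4; halving the batch size roughly halves activation memory
```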
## 📊 **Interpreting the Results**
### **Expected Metrics:**
- **ResNet50:** Higher accuracy, possible overfitting
- **MobileNetV2:** Better generalization, less overfitting
### **Visual Comparison:**
- Smoother training curves with MobileNet
- Possible train/validation gap with ResNet50
- Similar or better confusion matrix with ResNet50
### **Final Decision:**
Choose the model based on:
1. Accuracy on the test set
2. Train/validation gap
3. Production requirements
4. Interpretability of the errors
---
## 🎉 **ResNet50 Successfully Implemented!**
The ResNet50 model is fully implemented using the same robust structure as MobileNetV1.py, with ResNet50-specific optimizations and automatic handling of imbalanced classes.
**To get started:** `python train_resnet50.py --model resnet50 --epochs 20`


@ -0,0 +1,741 @@
"""
ResNet50 Transfer Learning para Clasificación de Fases Fenológicas - Nocciola
Adaptado para Visual Studio Code basado en MobileNetV1.py exitoso
Dataset: Nocciola GBIF
Objetivo: Predecir fase (fenológica vegetativa)
"""
# =============================================================================
# LIBRARY IMPORTS
# =============================================================================
# System and file handling libraries
import os # Operating system interface (paths, directories)
import shutil # High-level file operations (copy, move, delete)
import argparse # Command-line argument parsing
import random # Random number generation for reproducibility
import pandas as pd # Data manipulation and analysis
import numpy as np # Numerical computing with multidimensional arrays
from pathlib import Path # Modern path handling
import json
# Data visualization libraries
import matplotlib.pyplot as plt # Plotting and visualization
import seaborn as sns # Statistical data visualization
# Deep learning libraries (TensorFlow and Keras)
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
# ----------------- CONFIG -----------------
PROJECT_PATH = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi'
IMAGES_DIR = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi'  # Images live in the main dataset directory
CSV_PATH = os.path.join(PROJECT_PATH, 'assignments.csv')  # Main metadata CSV
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_resnet50')
os.makedirs(OUTPUT_DIR, exist_ok=True)
IMG_SIZE = (224, 224)  # Recommended for ResNet50 (standard ImageNet size)
BATCH_SIZE = 16  # Reduced for ResNet50 (heavier than MobileNet)
SEED = 42
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
FORCE_SPLIT = False
# ----------------- Utilities -----------------
def set_seed(seed=42):
"""Establecer semilla para reproducibilidad"""
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
def analyze_class_distribution(df, column_name='fase V'):
"""Analizar distribución de clases y detectar desbalances"""
print(f"\n📊 === Análisis de Distribución de Clases ===")
# Contar por clase
counts = df[column_name].value_counts()
total = len(df)
print(f"📊 Total de muestras: {total}")
print(f"📊 Número de clases: {len(counts)}")
print(f"📊 Distribución por clase:")
# Mostrar estadísticas detalladas
for clase, count in counts.items():
percentage = (count / total) * 100
print(f" - {clase}: {count} muestras ({percentage:.1f}%)")
# Detectar clases problemáticas
min_samples = 5 # Umbral mínimo recomendado
small_classes = counts[counts < min_samples]
if len(small_classes) > 0:
print(f"\n⚠️ Clases con menos de {min_samples} muestras:")
for clase, count in small_classes.items():
print(f" - {clase}: {count} muestras")
print(f"\n💡 Recomendaciones:")
print(f" 1. Considera recolectar más datos para estas clases")
print(f" 2. O fusionar clases similares")
print(f" 3. O usar técnicas de data augmentation específicas")
return counts, small_classes
def safe_read_csv(path):
"""Leer CSV con manejo de encoding"""
if not os.path.exists(path):
raise FileNotFoundError(f'CSV no encontrado: {path}')
try:
df = pd.read_csv(path, encoding='utf-8')
except UnicodeDecodeError:
try:
df = pd.read_csv(path, encoding='latin-1')
except:
df = pd.read_csv(path, encoding='cp1252')
return df
def resolve_image_path(images_dir, img_id):
"""Resolver la ruta completa de una imagen"""
if pd.isna(img_id) or str(img_id).strip() == '':
return None
img_id = str(img_id).strip()
# Verificar si ya incluye extensión y existe
direct_path = os.path.join(images_dir, img_id)
if os.path.exists(direct_path):
return direct_path
# Probar con extensiones comunes
stem = os.path.splitext(img_id)[0]
for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
img_path = os.path.join(images_dir, stem + ext)
if os.path.exists(img_path):
return img_path
return None
def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED):
"""Crear estructura de carpetas para flow_from_directory"""
set_seed(seed)
# Filtrar solo filas con fase V válida e imágenes existentes
print(f"📊 Datos iniciales: {len(df)} filas")
# Filtrar filas con fase V válida
df_valid = df.dropna(subset=['fase V']).copy()
df_valid = df_valid[df_valid['fase V'].str.strip() != '']
print(f"📊 Con fase V válida: {len(df_valid)} filas")
# Verificar existencia de imágenes
valid_rows = []
for _, row in df_valid.iterrows():
img_path = resolve_image_path(images_dir, row['id_img'])
if img_path:
valid_rows.append(row)
else:
print(f"⚠️ Imagen no encontrada: {row['id_img']}")
if not valid_rows:
raise ValueError("❌ No se encontraron imágenes válidas")
df_final = pd.DataFrame(valid_rows)
print(f"📊 Con imágenes existentes: {len(df_final)} filas")
# Mostrar distribución de clases
fase_counts = df_final['fase V'].value_counts()
print(f"\n📊 Distribución de fases:")
for fase, count in fase_counts.items():
print(f" - {fase}: {count} imágenes")
# Remover clases con muy pocas muestras (menos de 3)
min_samples = 3
valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
if len(valid_phases) < len(fase_counts):
excluded = fase_counts[fase_counts < min_samples].index.tolist()
print(f"⚠️ Excluyendo fases con menos de {min_samples} muestras: {excluded}")
df_final = df_final[df_final['fase V'].isin(valid_phases)]
print(f"📊 Después de filtrar: {len(df_final)} filas, {len(valid_phases)} clases")
labels = df_final['fase V'].unique().tolist()
print(f"📊 Clases finales: {labels}")
# Mezclar y dividir datos
df_shuffled = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
n = len(df_shuffled)
n_train = int(n * split['train'])
n_val = int(n * split['val'])
train_df = df_shuffled.iloc[:n_train]
val_df = df_shuffled.iloc[n_train:n_train + n_val]
test_df = df_shuffled.iloc[n_train + n_val:]
print(f"📊 División final:")
print(f" - Entrenamiento: {len(train_df)} imágenes")
print(f" - Validación: {len(val_df)} imágenes")
print(f" - Prueba: {len(test_df)} imágenes")
# Crear estructura de carpetas
for part in ['train', 'val', 'test']:
for label in labels:
label_dir = os.path.join(out_dir, part, str(label))
os.makedirs(label_dir, exist_ok=True)
# Función para copiar imágenes
def copy_subset(subdf, subset_name):
copied, missing = 0, 0
for _, row in subdf.iterrows():
src = resolve_image_path(images_dir, row['id_img'])
if src:
fase = str(row['fase V'])
                dst = os.path.join(out_dir, subset_name, fase,
                                   f"{os.path.splitext(str(row['id_img']))[0]}.jpg")
try:
shutil.copy2(src, dst)
copied += 1
except Exception as e:
print(f"⚠️ Error copiando {src}: {e}")
missing += 1
else:
missing += 1
print(f"{subset_name}: {copied} imágenes copiadas, {missing} fallidas")
return copied
# Copiar imágenes a las carpetas correspondientes
copy_subset(train_df, 'train')
copy_subset(val_df, 'val')
copy_subset(test_df, 'test')
return train_df, val_df, test_df
def main():
"""Función principal del pipeline ResNet50"""
parser = argparse.ArgumentParser(description='ResNet50 Transfer Learning para Nocciola')
parser.add_argument('--csv_path', type=str, default=CSV_PATH,
help='Ruta al archivo CSV con metadatos')
parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
help='Directorio con las imágenes')
parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
help='Directorio de salida para resultados')
parser.add_argument('--epochs', type=int, default=25,
help='Número de épocas de entrenamiento')
parser.add_argument('--force_split', action='store_true',
help='Forzar recreación del split de datos')
args = parser.parse_args()
print('\n🚀 === Inicio del pipeline ResNet50 para Nocciola ===')
print(f"📁 Directorio de imágenes: {args.images_dir}")
print(f"📄 Archivo CSV: {args.csv_path}")
print(f"📂 Directorio de salida: {args.output_dir}")
# Establecer semilla
set_seed(SEED)
# Crear directorio de salida
os.makedirs(args.output_dir, exist_ok=True)
# Leer datos
print('\n📊 === Cargando datos ===')
df = safe_read_csv(args.csv_path)
print(f'📊 Total de registros en CSV: {len(df)}')
print(f'📊 Columnas disponibles: {list(df.columns)}')
# Verificar columnas requeridas
required_cols = {'id_img', 'fase V'}
if not required_cols.issubset(set(df.columns)):
missing = required_cols - set(df.columns)
raise ValueError(f'❌ CSV debe contener las columnas: {missing}')
# Analizar distribución de clases antes del procesamiento
analyze_class_distribution(df, 'fase V')
# Preparar estructura de carpetas
SPLIT_DIR = os.path.join(args.output_dir, 'data_split')
if args.force_split and os.path.exists(SPLIT_DIR):
print("🗑️ Eliminando split existente...")
shutil.rmtree(SPLIT_DIR)
if not os.path.exists(SPLIT_DIR):
print("\n📁 === Creando nueva división de datos ===")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Guardar información del split
train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False)
val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False)
test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False)
else:
print("\n♻️ === Reutilizando división existente ===")
# Cargar información del split si existe
try:
train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
except:
print("⚠️ No se pudieron cargar los archivos de split, recreando...")
train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
# Crear generadores de datos
print("\n🔄 === Creando generadores de datos ===")
# Data augmentation para entrenamiento (más conservador para ResNet50)
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=15, # Menos rotación que MobileNet
width_shift_range=0.08,
height_shift_range=0.08,
shear_range=0.08,
zoom_range=0.08,
horizontal_flip=True,
fill_mode='nearest'
)
# Solo normalización para validación y test
val_test_datagen = ImageDataGenerator(rescale=1./255)
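    # Note: rescale=1./255 differs from tf.keras.applications.resnet50.preprocess_input,
    # which the ImageNet weights were trained with (RGB->BGR conversion plus
    # per-channel mean subtraction); using preprocess_input instead may improve
    # transfer performance.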
# Crear generadores
train_gen = train_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'train'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
seed=SEED
)
val_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'val'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
test_gen = val_test_datagen.flow_from_directory(
os.path.join(SPLIT_DIR, 'test'),
target_size=IMG_SIZE,
batch_size=BATCH_SIZE,
class_mode='categorical',
shuffle=False
)
# Guardar mapeo de clases
class_indices = train_gen.class_indices
print(f'🏷️ Mapeo de clases: {class_indices}')
with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
json.dump(class_indices, f, indent=2)
print(f"📊 Muestras por conjunto:")
print(f" - Entrenamiento: {train_gen.samples}")
print(f" - Validación: {val_gen.samples}")
print(f" - Prueba: {test_gen.samples}")
print(f" - Número de clases: {train_gen.num_classes}")
# Crear y entrenar modelo ResNet50
print("\n🤖 === Construcción del modelo ResNet50 ===")
# Modelo base ResNet50
base_model = ResNet50(
weights='imagenet',
include_top=False,
input_shape=(*IMG_SIZE, 3)
)
base_model.trainable = False # Congelar inicialmente
# Construir modelo secuencial optimizado para ResNet50
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.BatchNormalization(), # Importante para ResNet50
layers.Dropout(0.5), # Dropout más alto para ResNet50
layers.Dense(256, activation='relu'),
layers.BatchNormalization(),
layers.Dropout(0.3),
layers.Dense(train_gen.num_classes, activation='softmax')
])
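    # The extra BatchNormalization layers and the higher 0.5 dropout act as
    # stronger regularization for the larger 256-unit head used with ResNet50.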
# Compilar modelo con learning rate más bajo para ResNet50
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4), # LR más bajo
loss='categorical_crossentropy',
metrics=['accuracy']
)
print("📋 Resumen del modelo ResNet50:")
model.summary()
# Calcular pesos de clase
print("\n⚖️ === Calculando pesos de clase ===")
try:
# Obtener etiquetas de entrenamiento
train_labels = []
for i in range(len(train_gen)):
_, labels = train_gen[i]
train_labels.extend(np.argmax(labels, axis=1))
if len(train_labels) >= train_gen.samples:
break
# Calcular pesos balanceados
class_weights = class_weight.compute_class_weight(
'balanced',
classes=np.unique(train_labels),
y=train_labels
)
class_weight_dict = dict(zip(np.unique(train_labels), class_weights))
print(f"⚖️ Pesos de clase: {class_weight_dict}")
except Exception as e:
print(f"⚠️ Error calculando pesos de clase: {e}")
class_weight_dict = None
# Callbacks para entrenamiento ResNet50
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=10, # Más paciencia para ResNet50
restore_best_weights=True,
verbose=1
)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
os.path.join(args.output_dir, 'best_resnet50_model.keras'),
save_best_only=True,
monitor='val_loss',
verbose=1
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.3, # Reducción más agresiva
patience=5, # Menos paciencia para reducir LR
min_lr=1e-8,
verbose=1
)
callbacks = [early_stopping, model_checkpoint, reduce_lr]
# Entrenamiento inicial
print(f"\n🏋️ === Entrenamiento inicial ResNet50 ({args.epochs} épocas) ===")
try:
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
class_weight=class_weight_dict,
verbose=1
)
print("✅ Entrenamiento inicial completado")
except Exception as e:
print(f"❌ Error durante entrenamiento: {e}")
# Entrenar sin class_weight si hay problemas
print("🔄 Intentando entrenamiento sin pesos de clase...")
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=args.epochs,
callbacks=callbacks,
verbose=1
)
# Fine-tuning específico para ResNet50
print("\n🔧 === Fine-tuning ResNet50 ===")
    # Unfreeze only the last residual blocks of ResNet50
    base_model.trainable = True
    # The conv5_x blocks carry the most task-specific features, so only the
    # layers above the cut-off index are fine-tuned
    fine_tune_at = 140  # Roughly the start of conv5_block1
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False
print(f"🔧 Capas entrenables: {len([l for l in base_model.layers if l.trainable])}/{len(base_model.layers)}")
# Recompilar con learning rate muy bajo para fine-tuning
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6), # LR muy bajo para ResNet50
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Continuar entrenamiento
fine_tune_epochs = 15 # Más épocas para fine-tuning
total_epochs = len(history.history['loss']) + fine_tune_epochs
try:
history_fine = model.fit(
train_gen,
validation_data=val_gen,
epochs=total_epochs,
initial_epoch=len(history.history['loss']),
callbacks=callbacks,
verbose=1
)
print("✅ Fine-tuning completado")
# Combinar historiales
for key in history.history:
if key in history_fine.history:
history.history[key].extend(history_fine.history[key])
except Exception as e:
print(f"⚠️ Error durante fine-tuning: {e}")
print("Continuando con modelo del entrenamiento inicial...")
# Evaluación final
print("\n📊 === Evaluación en conjunto de prueba ===")
# Cargar mejor modelo
try:
model.load_weights(os.path.join(args.output_dir, 'best_resnet50_model.keras'))
print("✅ Cargado mejor modelo ResNet50 guardado")
except:
print("⚠️ Usando modelo actual")
# Guardar modelo final
model.save(os.path.join(args.output_dir, 'final_resnet50_model.keras'))
print("💾 Modelo ResNet50 final guardado")
# Predicciones en test
test_gen.reset()
y_pred_prob = model.predict(test_gen, verbose=1)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_gen.classes
# Mapeo de índices a nombres de clase
index_to_class = {v: k for k, v in class_indices.items()}
# Obtener solo las clases que realmente aparecen en el conjunto de test
unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
test_class_names = [index_to_class[i] for i in unique_test_classes]
print(f"📊 Clases en conjunto de test: {len(unique_test_classes)}")
print(f"📊 Todas las clases entrenadas: {len(class_indices)}")
print(f"📊 Clases presentes en test: {test_class_names}")
    # Verificar si hay clases faltantes
    all_classes = set(range(len(class_indices)))
    test_classes = set(unique_test_classes)
    missing_classes = all_classes - test_classes

    if missing_classes:
        missing_names = [index_to_class[i] for i in missing_classes]
        print(f"⚠️ Clases sin muestras en test: {missing_names}")

    # Reporte de clasificación con clases filtradas
    print("\n📋 === Reporte de Clasificación ResNet50 ===")

    try:
        report = classification_report(
            y_true, y_pred,
            labels=unique_test_classes,
            target_names=test_class_names,
            output_dict=False,
            zero_division=0
        )
        print(report)

        # Guardar reporte
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write(f"Modelo: ResNet50\n")
            f.write(f"Clases evaluadas: {test_class_names}\n")
            f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
            f.write(report)
    except Exception as e:
        print(f"❌ Error en classification_report: {e}")
        print("📊 Generando reporte alternativo...")

        # Reporte manual si falla el automático
        from collections import Counter
        true_counts = Counter(y_true)
        pred_counts = Counter(y_pred)

        print("\n📊 Distribución manual:")
        print("Clase | Verdaderos | Predichos")
        print("-" * 35)
        for class_idx in unique_test_classes:
            class_name = index_to_class[class_idx]
            true_count = true_counts.get(class_idx, 0)
            pred_count = pred_counts.get(class_idx, 0)
            print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")

        # Calcular accuracy básico
        accuracy = np.mean(y_true == y_pred)
        print(f"\n📊 Accuracy general ResNet50: {accuracy:.4f}")

        # Guardar reporte manual
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write("REPORTE MANUAL DE CLASIFICACIÓN - ResNet50\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Clases evaluadas: {test_class_names}\n")
            f.write(f"Clases faltantes en test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'Ninguna'}\n\n")
            f.write("Distribución por clase:\n")
            f.write("Clase | Verdaderos | Predichos\n")
            f.write("-" * 35 + "\n")
            for class_idx in unique_test_classes:
                class_name = index_to_class[class_idx]
                true_count = true_counts.get(class_idx, 0)
                pred_count = pred_counts.get(class_idx, 0)
                f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n")
            f.write(f"\nAccuracy general ResNet50: {accuracy:.4f}\n")

    # Matriz de confusión con clases filtradas
    cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
    print(f"\n🔢 Matriz de Confusión ResNet50 ({len(unique_test_classes)} clases):")
    print(cm)
    np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
               cm, delimiter=',', fmt='%d')

    # Visualizaciones con clases filtradas
    print("\n📈 === Generando visualizaciones ResNet50 ===")

    # Gráfico de entrenamiento
    plot_training_history(history, args.output_dir)

    # Matriz de confusión visual con clases filtradas
    plot_confusion_matrix(cm, test_class_names, args.output_dir)

    # Ejemplos de predicciones con clases filtradas
    plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)

    print(f"\n🎉 === Pipeline ResNet50 completado ===")
    print(f"📁 Resultados guardados en: {args.output_dir}")
    print(f"📊 Precisión final en test: {np.mean(y_true == y_pred):.4f}")
    print(f"📊 Clases evaluadas: {len(unique_test_classes)}/{len(class_indices)}")

    # Información adicional sobre clases desbalanceadas
    if missing_classes:
        print(f"\n⚠️ === Información sobre Clases Desbalanceadas ===")
        print(f"❌ Clases sin muestras en test: {len(missing_classes)}")
        for missing_idx in missing_classes:
            missing_name = index_to_class[missing_idx]
            print(f"  - {missing_name} (índice {missing_idx})")
        print(f"💡 Sugerencia: Considera aumentar el dataset o fusionar clases similares")

def plot_training_history(history, output_dir):
    """Graficar historial de entrenamiento ResNet50"""
    try:
        plt.figure(figsize=(12, 4))

        # Accuracy
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'], label='Entrenamiento', linewidth=2)
        if 'val_accuracy' in history.history:
            plt.plot(history.history['val_accuracy'], label='Validación', linewidth=2)
        plt.title('Precisión del Modelo ResNet50')
        plt.xlabel('Época')
        plt.ylabel('Precisión')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'], label='Entrenamiento', linewidth=2)
        if 'val_loss' in history.history:
            plt.plot(history.history['val_loss'], label='Validación', linewidth=2)
        plt.title('Pérdida del Modelo ResNet50')
        plt.xlabel('Época')
        plt.ylabel('Pérdida')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_training_history.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Gráfico de entrenamiento ResNet50 guardado")
    except Exception as e:
        print(f"⚠️ Error creando gráfico de entrenamiento: {e}")

def plot_confusion_matrix(cm, class_names, output_dir):
    """Graficar matriz de confusión ResNet50"""
    try:
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    cbar_kws={'label': 'Número de muestras'})
        plt.title('Matriz de Confusión - ResNet50')
        plt.ylabel('Etiqueta Verdadera')
        plt.xlabel('Etiqueta Predicha')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Matriz de confusión ResNet50 guardada")
    except Exception as e:
        print(f"⚠️ Error creando matriz de confusión: {e}")

def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
    """Mostrar ejemplos de predicciones correctas e incorrectas para ResNet50"""
    try:
        # Obtener índices de predicciones correctas e incorrectas
        correct_idx = np.where(y_true == y_pred)[0]
        incorrect_idx = np.where(y_true != y_pred)[0]

        # Seleccionar ejemplos
        n_correct = min(n_examples // 2, len(correct_idx))
        n_incorrect = min(n_examples // 2, len(incorrect_idx))

        selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else []
        selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else []
        # dtype entero explícito: una lista vacía convertiría el resultado a float y rompería la indexación
        selected_indices = np.concatenate([selected_correct, selected_incorrect]).astype(int)

        if len(selected_indices) == 0:
            print("⚠️ No hay ejemplos para mostrar")
            return

        # Crear gráfico
        n_show = len(selected_indices)
        cols = 4
        rows = (n_show + cols - 1) // cols

        plt.figure(figsize=(15, 4 * rows))
        for i, idx in enumerate(selected_indices):
            plt.subplot(rows, cols, i + 1)
            # Obtener imagen
            try:
                img_path = test_gen.filepaths[idx]
                img = plt.imread(img_path)
                plt.imshow(img)
                plt.axis('off')

                true_label = class_names[y_true[idx]]
                pred_label = class_names[y_pred[idx]]
                color = 'green' if y_true[idx] == y_pred[idx] else 'red'
                plt.title(f'Real: {true_label}\nPredicción: {pred_label}',
                          color=color, fontsize=9, fontweight='bold')
            except Exception as e:
                plt.text(0.5, 0.5, f'Error cargando\nimagen {idx}',
                         ha='center', va='center', transform=plt.gca().transAxes)
                plt.axis('off')

        plt.suptitle('Ejemplos de Predicciones ResNet50 (Verde=Correcta, Rojo=Incorrecta)', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_prediction_examples.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Ejemplos de predicciones ResNet50 guardados")
    except Exception as e:
        print(f"⚠️ Error creando ejemplos de predicciones: {e}")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,184 @@
---
description: Perform a non-destructive cross-artifact consistency and quality analysis across spec.md, plan.md, and tasks.md after task generation.
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Goal
Identify inconsistencies, duplications, ambiguities, and underspecified items across the three core artifacts (`spec.md`, `plan.md`, `tasks.md`) before implementation. This command MUST run only after `/speckit.tasks` has successfully produced a complete `tasks.md`.
## Operating Constraints
**STRICTLY READ-ONLY**: Do **not** modify any files. Output a structured analysis report. Offer an optional remediation plan (user must explicitly approve before any follow-up editing commands would be invoked manually).
**Constitution Authority**: The project constitution (`.specify/memory/constitution.md`) is **non-negotiable** within this analysis scope. Constitution conflicts are automatically CRITICAL and require adjustment of the spec, plan, or tasks—not dilution, reinterpretation, or silent ignoring of the principle. If a principle itself needs to change, that must occur in a separate, explicit constitution update outside `/speckit.analyze`.
## Execution Steps
### 1. Initialize Analysis Context
Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks` once from repo root and parse JSON for FEATURE_DIR and AVAILABLE_DOCS. Derive absolute paths:
- SPEC = FEATURE_DIR/spec.md
- PLAN = FEATURE_DIR/plan.md
- TASKS = FEATURE_DIR/tasks.md
Abort with an error message if any required file is missing (instruct the user to run the missing prerequisite command).
For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
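For illustration only, this initialization step could be wrapped roughly as below. This is a minimal sketch, assuming `pwsh` is on PATH and that the script emits a flat JSON object with a `FEATURE_DIR` field; the helper name is ours, not part of the command.

```python
import json
import subprocess
import sys
from pathlib import Path

def init_analysis_context():
    # Run the prerequisite check once and capture its JSON payload
    out = subprocess.run(
        ["pwsh", ".specify/scripts/powershell/check-prerequisites.ps1",
         "-Json", "-RequireTasks", "-IncludeTasks"],
        capture_output=True, text=True, check=True,
    ).stdout
    payload = json.loads(out)
    feature_dir = Path(payload["FEATURE_DIR"]).resolve()

    # Derive the three core artifact paths (SPEC, PLAN, TASKS)
    artifacts = {name: feature_dir / f"{name}.md" for name in ("spec", "plan", "tasks")}

    # Abort if any required file is missing
    missing = [p.name for p in artifacts.values() if not p.exists()]
    if missing:
        sys.exit(f"Missing required artifacts: {missing}. Run the missing prerequisite command first.")
    return artifacts
```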
### 2. Load Artifacts (Progressive Disclosure)
Load only the minimal necessary context from each artifact:
**From spec.md:**
- Overview/Context
- Functional Requirements
- Non-Functional Requirements
- User Stories
- Edge Cases (if present)
**From plan.md:**
- Architecture/stack choices
- Data Model references
- Phases
- Technical constraints
**From tasks.md:**
- Task IDs
- Descriptions
- Phase grouping
- Parallel markers [P]
- Referenced file paths
**From constitution:**
- Load `.specify/memory/constitution.md` for principle validation
### 3. Build Semantic Models
Create internal representations (do not include raw artifacts in output):
- **Requirements inventory**: Each functional + non-functional requirement with a stable key (derive slug based on imperative phrase; e.g., "User can upload file" → `user-can-upload-file`)
- **User story/action inventory**: Discrete user actions with acceptance criteria
- **Task coverage mapping**: Map each task to one or more requirements or stories (inference by keyword / explicit reference patterns like IDs or key phrases)
- **Constitution rule set**: Extract principle names and MUST/SHOULD normative statements
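As a concrete illustration of the stable-key rule in the requirements inventory above, a minimal slug sketch (the helper name is ours):

```python
import re

def requirement_key(imperative_phrase: str) -> str:
    """Derive a stable slug key, e.g. 'User can upload file' -> 'user-can-upload-file'."""
    slug = imperative_phrase.strip().lower()
    slug = re.sub(r"[^a-z0-9]+", "-", slug)  # runs of non-alphanumerics become single hyphens
    return slug.strip("-")

assert requirement_key("User can upload file") == "user-can-upload-file"
```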
### 4. Detection Passes (Token-Efficient Analysis)
Focus on high-signal findings. Limit to 50 findings total; aggregate remainder in overflow summary.
#### A. Duplication Detection
- Identify near-duplicate requirements
- Mark lower-quality phrasing for consolidation
#### B. Ambiguity Detection
- Flag vague adjectives (fast, scalable, secure, intuitive, robust) lacking measurable criteria
- Flag unresolved placeholders (TODO, TKTK, ???, `<placeholder>`, etc.)
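A minimal sketch of this detection pass; the adjective list and placeholder tokens mirror the bullets above, while the function name and output shape are illustrative:

```python
import re

VAGUE_ADJECTIVES = {"fast", "scalable", "secure", "intuitive", "robust"}
PLACEHOLDER_RE = re.compile(r"TODO|TKTK|\?\?\?|<placeholder>", re.IGNORECASE)

def ambiguity_findings(lines):
    """Yield (line_no, reason) pairs for vague adjectives and unresolved placeholders."""
    for no, line in enumerate(lines, start=1):
        words = {w.strip(".,;:()").lower() for w in line.split()}
        for adj in VAGUE_ADJECTIVES & words:
            yield no, f"vague adjective '{adj}' lacking measurable criteria"
        if PLACEHOLDER_RE.search(line):
            yield no, "unresolved placeholder"
```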
#### C. Underspecification
- Requirements with verbs but missing object or measurable outcome
- User stories missing acceptance criteria alignment
- Tasks referencing files or components not defined in spec/plan
#### D. Constitution Alignment
- Any requirement or plan element conflicting with a MUST principle
- Missing mandated sections or quality gates from constitution
#### E. Coverage Gaps
- Requirements with zero associated tasks
- Tasks with no mapped requirement/story
- Non-functional requirements not reflected in tasks (e.g., performance, security)
#### F. Inconsistency
- Terminology drift (same concept named differently across files)
- Data entities referenced in plan but absent in spec (or vice versa)
- Task ordering contradictions (e.g., integration tasks before foundational setup tasks without dependency note)
- Conflicting requirements (e.g., one requires Next.js while other specifies Vue)
### 5. Severity Assignment
Use this heuristic to prioritize findings:
- **CRITICAL**: Violates constitution MUST, missing core spec artifact, or requirement with zero coverage that blocks baseline functionality
- **HIGH**: Duplicate or conflicting requirement, ambiguous security/performance attribute, untestable acceptance criterion
- **MEDIUM**: Terminology drift, missing non-functional task coverage, underspecified edge case
- **LOW**: Style/wording improvements, minor redundancy not affecting execution order
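Encoded as code, the heuristic might look like this sketch; the finding fields are assumptions, not something this command defines:

```python
def assign_severity(finding: dict) -> str:
    """Map a finding to CRITICAL/HIGH/MEDIUM/LOW per the heuristic above."""
    if finding.get("constitution_must_violation") or finding.get("blocks_baseline_coverage"):
        return "CRITICAL"
    if finding.get("category") in {"duplication", "conflict"} or finding.get("security_or_performance_ambiguity"):
        return "HIGH"
    if finding.get("category") in {"terminology_drift", "missing_nfr_task", "underspecified_edge_case"}:
        return "MEDIUM"
    return "LOW"
```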
### 6. Produce Compact Analysis Report
Output a Markdown report (no file writes) with the following structure:
## Specification Analysis Report
| ID | Category | Severity | Location(s) | Summary | Recommendation |
|----|----------|----------|-------------|---------|----------------|
| A1 | Duplication | HIGH | spec.md:L120-134 | Two similar requirements ... | Merge phrasing; keep clearer version |
(Add one row per finding; generate stable IDs prefixed by category initial.)
**Coverage Summary Table:**
| Requirement Key | Has Task? | Task IDs | Notes |
|-----------------|-----------|----------|-------|
**Constitution Alignment Issues:** (if any)
**Unmapped Tasks:** (if any)
**Metrics:**
- Total Requirements
- Total Tasks
- Coverage % (requirements with >=1 task)
- Ambiguity Count
- Duplication Count
- Critical Issues Count
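A sketch of how these metrics could be computed from the coverage mapping built in step 3 (the data shapes below are assumptions):

```python
def coverage_metrics(requirements: dict, tasks: list, findings: list) -> dict:
    """requirements maps requirement key -> list of covering task IDs; findings carry 'category'/'severity'."""
    total = len(requirements)
    covered = sum(1 for task_ids in requirements.values() if task_ids)
    return {
        "Total Requirements": total,
        "Total Tasks": len(tasks),
        "Coverage %": round(100 * covered / total, 1) if total else 0.0,
        "Ambiguity Count": sum(f.get("category") == "ambiguity" for f in findings),
        "Duplication Count": sum(f.get("category") == "duplication" for f in findings),
        "Critical Issues Count": sum(f.get("severity") == "CRITICAL" for f in findings),
    }
```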
### 7. Provide Next Actions
At end of report, output a concise Next Actions block:
- If CRITICAL issues exist: Recommend resolving before `/speckit.implement`
- If only LOW/MEDIUM: User may proceed, but provide improvement suggestions
- Provide explicit command suggestions: e.g., "Run /speckit.specify with refinement", "Run /speckit.plan to adjust architecture", "Manually edit tasks.md to add coverage for 'performance-metrics'"
### 8. Offer Remediation
Ask the user: "Would you like me to suggest concrete remediation edits for the top N issues?" (Do NOT apply them automatically.)
## Operating Principles
### Context Efficiency
- **Minimal high-signal tokens**: Focus on actionable findings, not exhaustive documentation
- **Progressive disclosure**: Load artifacts incrementally; don't dump all content into analysis
- **Token-efficient output**: Limit findings table to 50 rows; summarize overflow
- **Deterministic results**: Rerunning without changes should produce consistent IDs and counts
### Analysis Guidelines
- **NEVER modify files** (this is read-only analysis)
- **NEVER hallucinate missing sections** (if absent, report them accurately)
- **Prioritize constitution violations** (these are always CRITICAL)
- **Use examples over exhaustive rules** (cite specific instances, not generic patterns)
- **Report zero issues gracefully** (emit success report with coverage statistics)
## Context
$ARGUMENTS

View File

@ -0,0 +1,294 @@
---
description: Generate a custom checklist for the current feature based on user requirements.
---
## Checklist Purpose: "Unit Tests for English"
**CRITICAL CONCEPT**: Checklists are **UNIT TESTS FOR REQUIREMENTS WRITING** - they validate the quality, clarity, and completeness of requirements in a given domain.
**NOT for verification/testing**:
- ❌ NOT "Verify the button clicks correctly"
- ❌ NOT "Test error handling works"
- ❌ NOT "Confirm the API returns 200"
- ❌ NOT checking if code/implementation matches the spec
**FOR requirements quality validation**:
- ✅ "Are visual hierarchy requirements defined for all card types?" (completeness)
- ✅ "Is 'prominent display' quantified with specific sizing/positioning?" (clarity)
- ✅ "Are hover state requirements consistent across all interactive elements?" (consistency)
- ✅ "Are accessibility requirements defined for keyboard navigation?" (coverage)
- ✅ "Does the spec define what happens when logo image fails to load?" (edge cases)
**Metaphor**: If your spec is code written in English, the checklist is its unit test suite. You're testing whether the requirements are well-written, complete, unambiguous, and ready for implementation - NOT whether the implementation works.
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Execution Steps
1. **Setup**: Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json` from repo root and parse JSON for FEATURE_DIR and AVAILABLE_DOCS list.
- All file paths must be absolute.
- For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
2. **Clarify intent (dynamic)**: Derive up to THREE initial contextual clarifying questions (no pre-baked catalog). They MUST:
- Be generated from the user's phrasing + extracted signals from spec/plan/tasks
- Only ask about information that materially changes checklist content
- Be skipped individually if already unambiguous in `$ARGUMENTS`
- Prefer precision over breadth
Generation algorithm:
1. Extract signals: feature domain keywords (e.g., auth, latency, UX, API), risk indicators ("critical", "must", "compliance"), stakeholder hints ("QA", "review", "security team"), and explicit deliverables ("a11y", "rollback", "contracts").
2. Cluster signals into candidate focus areas (max 4) ranked by relevance.
3. Identify probable audience & timing (author, reviewer, QA, release) if not explicit.
4. Detect missing dimensions: scope breadth, depth/rigor, risk emphasis, exclusion boundaries, measurable acceptance criteria.
5. Formulate questions chosen from these archetypes:
- Scope refinement (e.g., "Should this include integration touchpoints with X and Y or stay limited to local module correctness?")
- Risk prioritization (e.g., "Which of these potential risk areas should receive mandatory gating checks?")
- Depth calibration (e.g., "Is this a lightweight pre-commit sanity list or a formal release gate?")
- Audience framing (e.g., "Will this be used by the author only or peers during PR review?")
- Boundary exclusion (e.g., "Should we explicitly exclude performance tuning items this round?")
- Scenario class gap (e.g., "No recovery flows detected—are rollback / partial failure paths in scope?")
Question formatting rules:
- If presenting options, generate a compact table with columns: Option | Candidate | Why It Matters
- Limit to A–E options maximum; omit table if a free-form answer is clearer
- Never ask the user to restate what they already said
- Avoid speculative categories (no hallucination). If uncertain, ask explicitly: "Confirm whether X belongs in scope."
Defaults when interaction impossible:
- Depth: Standard
- Audience: Reviewer (PR) if code-related; Author otherwise
- Focus: Top 2 relevance clusters
Output the questions (label Q1/Q2/Q3). After answers: if ≥2 scenario classes (Alternate / Exception / Recovery / Non-Functional domain) remain unclear, you MAY ask up to TWO more targeted followups (Q4/Q5) with a one-line justification each (e.g., "Unresolved recovery path risk"). Do not exceed five total questions. Skip escalation if user explicitly declines more.
3. **Understand user request**: Combine `$ARGUMENTS` + clarifying answers:
- Derive checklist theme (e.g., security, review, deploy, ux)
- Consolidate explicit must-have items mentioned by user
- Map focus selections to category scaffolding
- Infer any missing context from spec/plan/tasks (do NOT hallucinate)
4. **Load feature context**: Read from FEATURE_DIR:
- spec.md: Feature requirements and scope
- plan.md (if exists): Technical details, dependencies
- tasks.md (if exists): Implementation tasks
**Context Loading Strategy**:
- Load only necessary portions relevant to active focus areas (avoid full-file dumping)
- Prefer summarizing long sections into concise scenario/requirement bullets
- Use progressive disclosure: add follow-on retrieval only if gaps detected
- If source docs are large, generate interim summary items instead of embedding raw text
5. **Generate checklist** - Create "Unit Tests for Requirements":
- Create `FEATURE_DIR/checklists/` directory if it doesn't exist
- Generate unique checklist filename:
- Use short, descriptive name based on domain (e.g., `ux.md`, `api.md`, `security.md`)
- Format: `[domain].md`
- If file exists, append to existing file
- Number items sequentially starting from CHK001
- Each `/speckit.checklist` run creates a NEW file (never overwrites existing checklists)
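A minimal sketch of the file-naming and CHK numbering rules above (paths and helper names are illustrative, not prescribed):

```python
import re
from pathlib import Path

def checklist_file(feature_dir: Path, domain: str) -> Path:
    """checklists/<domain>.md under the feature directory, creating the folder on demand."""
    folder = feature_dir / "checklists"
    folder.mkdir(parents=True, exist_ok=True)
    return folder / f"{domain}.md"

def next_chk_id(checklist_path: Path) -> str:
    """Continue numbering from the highest CHK### already present (CHK001 if the file is new)."""
    if not checklist_path.exists():
        return "CHK001"
    ids = [int(m) for m in re.findall(r"CHK(\d{3})", checklist_path.read_text(encoding="utf-8"))]
    return f"CHK{(max(ids) + 1 if ids else 1):03d}"
```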
**CORE PRINCIPLE - Test the Requirements, Not the Implementation**:
Every checklist item MUST evaluate the REQUIREMENTS THEMSELVES for:
- **Completeness**: Are all necessary requirements present?
- **Clarity**: Are requirements unambiguous and specific?
- **Consistency**: Do requirements align with each other?
- **Measurability**: Can requirements be objectively verified?
- **Coverage**: Are all scenarios/edge cases addressed?
**Category Structure** - Group items by requirement quality dimensions:
- **Requirement Completeness** (Are all necessary requirements documented?)
- **Requirement Clarity** (Are requirements specific and unambiguous?)
- **Requirement Consistency** (Do requirements align without conflicts?)
- **Acceptance Criteria Quality** (Are success criteria measurable?)
- **Scenario Coverage** (Are all flows/cases addressed?)
- **Edge Case Coverage** (Are boundary conditions defined?)
- **Non-Functional Requirements** (Performance, Security, Accessibility, etc. - are they specified?)
- **Dependencies & Assumptions** (Are they documented and validated?)
- **Ambiguities & Conflicts** (What needs clarification?)
**HOW TO WRITE CHECKLIST ITEMS - "Unit Tests for English"**:
**WRONG** (Testing implementation):
- "Verify landing page displays 3 episode cards"
- "Test hover states work on desktop"
- "Confirm logo click navigates home"
**CORRECT** (Testing requirements quality):
- "Are the exact number and layout of featured episodes specified?" [Completeness]
- "Is 'prominent display' quantified with specific sizing/positioning?" [Clarity]
- "Are hover state requirements consistent across all interactive elements?" [Consistency]
- "Are keyboard navigation requirements defined for all interactive UI?" [Coverage]
- "Is the fallback behavior specified when logo image fails to load?" [Edge Cases]
- "Are loading states defined for asynchronous episode data?" [Completeness]
- "Does the spec define visual hierarchy for competing UI elements?" [Clarity]
**ITEM STRUCTURE**:
Each item should follow this pattern:
- Question format asking about requirement quality
- Focus on what's WRITTEN (or not written) in the spec/plan
- Include quality dimension in brackets [Completeness/Clarity/Consistency/etc.]
- Reference spec section `[Spec §X.Y]` when checking existing requirements
- Use `[Gap]` marker when checking for missing requirements
**EXAMPLES BY QUALITY DIMENSION**:
Completeness:
- "Are error handling requirements defined for all API failure modes? [Gap]"
- "Are accessibility requirements specified for all interactive elements? [Completeness]"
- "Are mobile breakpoint requirements defined for responsive layouts? [Gap]"
Clarity:
- "Is 'fast loading' quantified with specific timing thresholds? [Clarity, Spec §NFR-2]"
- "Are 'related episodes' selection criteria explicitly defined? [Clarity, Spec §FR-5]"
- "Is 'prominent' defined with measurable visual properties? [Ambiguity, Spec §FR-4]"
Consistency:
- "Do navigation requirements align across all pages? [Consistency, Spec §FR-10]"
- "Are card component requirements consistent between landing and detail pages? [Consistency]"
Coverage:
- "Are requirements defined for zero-state scenarios (no episodes)? [Coverage, Edge Case]"
- "Are concurrent user interaction scenarios addressed? [Coverage, Gap]"
- "Are requirements specified for partial data loading failures? [Coverage, Exception Flow]"
Measurability:
- "Are visual hierarchy requirements measurable/testable? [Acceptance Criteria, Spec §FR-1]"
- "Can 'balanced visual weight' be objectively verified? [Measurability, Spec §FR-2]"
**Scenario Classification & Coverage** (Requirements Quality Focus):
- Check if requirements exist for: Primary, Alternate, Exception/Error, Recovery, Non-Functional scenarios
- For each scenario class, ask: "Are [scenario type] requirements complete, clear, and consistent?"
- If scenario class missing: "Are [scenario type] requirements intentionally excluded or missing? [Gap]"
- Include resilience/rollback when state mutation occurs: "Are rollback requirements defined for migration failures? [Gap]"
**Traceability Requirements**:
- MINIMUM: ≥80% of items MUST include at least one traceability reference
- Each item should reference: spec section `[Spec §X.Y]`, or use markers: `[Gap]`, `[Ambiguity]`, `[Conflict]`, `[Assumption]`
- If no ID system exists: "Is a requirement & acceptance criteria ID scheme established? [Traceability]"
**Surface & Resolve Issues** (Requirements Quality Problems):
Ask questions about the requirements themselves:
- Ambiguities: "Is the term 'fast' quantified with specific metrics? [Ambiguity, Spec §NFR-1]"
- Conflicts: "Do navigation requirements conflict between §FR-10 and §FR-10a? [Conflict]"
- Assumptions: "Is the assumption of 'always available podcast API' validated? [Assumption]"
- Dependencies: "Are external podcast API requirements documented? [Dependency, Gap]"
- Missing definitions: "Is 'visual hierarchy' defined with measurable criteria? [Gap]"
**Content Consolidation**:
- Soft cap: If raw candidate items > 40, prioritize by risk/impact
- Merge near-duplicates checking the same requirement aspect
- If >5 low-impact edge cases, create one item: "Are edge cases X, Y, Z addressed in requirements? [Coverage]"
**🚫 ABSOLUTELY PROHIBITED** - These make it an implementation test, not a requirements test:
- ❌ Any item starting with "Verify", "Test", "Confirm", "Check" + implementation behavior
- ❌ References to code execution, user actions, system behavior
- ❌ "Displays correctly", "works properly", "functions as expected"
- ❌ "Click", "navigate", "render", "load", "execute"
- ❌ Test cases, test plans, QA procedures
- ❌ Implementation details (frameworks, APIs, algorithms)
**✅ REQUIRED PATTERNS** - These test requirements quality:
- ✅ "Are [requirement type] defined/specified/documented for [scenario]?"
- ✅ "Is [vague term] quantified/clarified with specific criteria?"
- ✅ "Are requirements consistent between [section A] and [section B]?"
- ✅ "Can [requirement] be objectively measured/verified?"
- ✅ "Are [edge cases/scenarios] addressed in requirements?"
- ✅ "Does the spec define [missing aspect]?"
6. **Structure Reference**: Generate the checklist following the canonical template in `.specify/templates/checklist-template.md` for title, meta section, category headings, and ID formatting. If template is unavailable, use: H1 title, purpose/created meta lines, `##` category sections containing `- [ ] CHK### <requirement item>` lines with globally incrementing IDs starting at CHK001.
7. **Report**: Output full path to created checklist, item count, and remind user that each run creates a new file. Summarize:
- Focus areas selected
- Depth level
- Actor/timing
- Any explicit user-specified must-have items incorporated
**Important**: Each `/speckit.checklist` command invocation creates a checklist file using short, descriptive names unless the file already exists. This allows:
- Multiple checklists of different types (e.g., `ux.md`, `test.md`, `security.md`)
- Simple, memorable filenames that indicate checklist purpose
- Easy identification and navigation in the `checklists/` folder
To avoid clutter, use descriptive types and clean up obsolete checklists when done.
## Example Checklist Types & Sample Items
**UX Requirements Quality:** `ux.md`
Sample items (testing the requirements, NOT the implementation):
- "Are visual hierarchy requirements defined with measurable criteria? [Clarity, Spec §FR-1]"
- "Is the number and positioning of UI elements explicitly specified? [Completeness, Spec §FR-1]"
- "Are interaction state requirements (hover, focus, active) consistently defined? [Consistency]"
- "Are accessibility requirements specified for all interactive elements? [Coverage, Gap]"
- "Is fallback behavior defined when images fail to load? [Edge Case, Gap]"
- "Can 'prominent display' be objectively measured? [Measurability, Spec §FR-4]"
**API Requirements Quality:** `api.md`
Sample items:
- "Are error response formats specified for all failure scenarios? [Completeness]"
- "Are rate limiting requirements quantified with specific thresholds? [Clarity]"
- "Are authentication requirements consistent across all endpoints? [Consistency]"
- "Are retry/timeout requirements defined for external dependencies? [Coverage, Gap]"
- "Is versioning strategy documented in requirements? [Gap]"
**Performance Requirements Quality:** `performance.md`
Sample items:
- "Are performance requirements quantified with specific metrics? [Clarity]"
- "Are performance targets defined for all critical user journeys? [Coverage]"
- "Are performance requirements under different load conditions specified? [Completeness]"
- "Can performance requirements be objectively measured? [Measurability]"
- "Are degradation requirements defined for high-load scenarios? [Edge Case, Gap]"
**Security Requirements Quality:** `security.md`
Sample items:
- "Are authentication requirements specified for all protected resources? [Coverage]"
- "Are data protection requirements defined for sensitive information? [Completeness]"
- "Is the threat model documented and requirements aligned to it? [Traceability]"
- "Are security requirements consistent with compliance obligations? [Consistency]"
- "Are security failure/breach response requirements defined? [Gap, Exception Flow]"
## Anti-Examples: What NOT To Do
**❌ WRONG - These test implementation, not requirements:**
```markdown
- [ ] CHK001 - Verify landing page displays 3 episode cards [Spec §FR-001]
- [ ] CHK002 - Test hover states work correctly on desktop [Spec §FR-003]
- [ ] CHK003 - Confirm logo click navigates to home page [Spec §FR-010]
- [ ] CHK004 - Check that related episodes section shows 3-5 items [Spec §FR-005]
```
**✅ CORRECT - These test requirements quality:**
```markdown
- [ ] CHK001 - Are the number and layout of featured episodes explicitly specified? [Completeness, Spec §FR-001]
- [ ] CHK002 - Are hover state requirements consistently defined for all interactive elements? [Consistency, Spec §FR-003]
- [ ] CHK003 - Are navigation requirements clear for all clickable brand elements? [Clarity, Spec §FR-010]
- [ ] CHK004 - Is the selection criteria for related episodes documented? [Gap, Spec §FR-005]
- [ ] CHK005 - Are loading state requirements defined for asynchronous episode data? [Gap]
- [ ] CHK006 - Can "visual hierarchy" requirements be objectively measured? [Measurability, Spec §FR-001]
```
**Key Differences:**
- Wrong: Tests if the system works correctly
- Correct: Tests if the requirements are written correctly
- Wrong: Verification of behavior
- Correct: Validation of requirement quality
- Wrong: "Does it do X?"
- Correct: "Is X clearly specified?"

View File

@ -0,0 +1,177 @@
---
description: Identify underspecified areas in the current feature spec by asking up to 5 highly targeted clarification questions and encoding answers back into the spec.
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
Goal: Detect and reduce ambiguity or missing decision points in the active feature specification and record the clarifications directly in the spec file.
Note: This clarification workflow is expected to run (and be completed) BEFORE invoking `/speckit.plan`. If the user explicitly states they are skipping clarification (e.g., exploratory spike), you may proceed, but must warn that downstream rework risk increases.
Execution steps:
1. Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -PathsOnly` from repo root **once** (combined `--json --paths-only` mode / `-Json -PathsOnly`). Parse minimal JSON payload fields:
- `FEATURE_DIR`
- `FEATURE_SPEC`
- (Optionally capture `IMPL_PLAN`, `TASKS` for future chained flows.)
- If JSON parsing fails, abort and instruct user to re-run `/speckit.specify` or verify feature branch environment.
- For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
2. Load the current spec file. Perform a structured ambiguity & coverage scan using this taxonomy. For each category, mark status: Clear / Partial / Missing. Produce an internal coverage map used for prioritization (do not output raw map unless no questions will be asked).
Functional Scope & Behavior:
- Core user goals & success criteria
- Explicit out-of-scope declarations
- User roles / personas differentiation
Domain & Data Model:
- Entities, attributes, relationships
- Identity & uniqueness rules
- Lifecycle/state transitions
- Data volume / scale assumptions
Interaction & UX Flow:
- Critical user journeys / sequences
- Error/empty/loading states
- Accessibility or localization notes
Non-Functional Quality Attributes:
- Performance (latency, throughput targets)
- Scalability (horizontal/vertical, limits)
- Reliability & availability (uptime, recovery expectations)
- Observability (logging, metrics, tracing signals)
- Security & privacy (authN/Z, data protection, threat assumptions)
- Compliance / regulatory constraints (if any)
Integration & External Dependencies:
- External services/APIs and failure modes
- Data import/export formats
- Protocol/versioning assumptions
Edge Cases & Failure Handling:
- Negative scenarios
- Rate limiting / throttling
- Conflict resolution (e.g., concurrent edits)
Constraints & Tradeoffs:
- Technical constraints (language, storage, hosting)
- Explicit tradeoffs or rejected alternatives
Terminology & Consistency:
- Canonical glossary terms
- Avoided synonyms / deprecated terms
Completion Signals:
- Acceptance criteria testability
- Measurable Definition of Done style indicators
Misc / Placeholders:
- TODO markers / unresolved decisions
- Ambiguous adjectives ("robust", "intuitive") lacking quantification
For each category with Partial or Missing status, add a candidate question opportunity unless:
- Clarification would not materially change implementation or validation strategy
- Information is better deferred to planning phase (note internally)
3. Generate (internally) a prioritized queue of candidate clarification questions (maximum 5). Do NOT output them all at once. Apply these constraints:
- Maximum of 10 total questions across the whole session.
- Each question must be answerable with EITHER:
- A short multiple-choice selection (2–5 distinct, mutually exclusive options), OR
- A one-word / short-phrase answer (explicitly constrain: "Answer in <=5 words").
- Only include questions whose answers materially impact architecture, data modeling, task decomposition, test design, UX behavior, operational readiness, or compliance validation.
- Ensure category coverage balance: attempt to cover the highest impact unresolved categories first; avoid asking two low-impact questions when a single high-impact area (e.g., security posture) is unresolved.
- Exclude questions already answered, trivial stylistic preferences, or plan-level execution details (unless blocking correctness).
- Favor clarifications that reduce downstream rework risk or prevent misaligned acceptance tests.
- If more than 5 categories remain unresolved, select the top 5 by (Impact * Uncertainty) heuristic.
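The (Impact * Uncertainty) selection could be sketched as follows, assuming both factors are scored on a 1–5 scale (the scale itself is an assumption):

```python
def top_questions(categories: dict, limit: int = 5) -> list:
    """categories maps name -> (impact, uncertainty); returns the highest-scoring names, at most `limit`."""
    ranked = sorted(categories, key=lambda c: categories[c][0] * categories[c][1], reverse=True)
    return ranked[:limit]

# Example: an unresolved security posture outranks two low-impact wording questions
print(top_questions({"Security & privacy": (5, 4), "Terminology": (2, 3), "UX copy": (1, 2)}))
```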
4. Sequential questioning loop (interactive):
- Present EXACTLY ONE question at a time.
- For multiple-choice questions:
- **Analyze all options** and determine the **most suitable option** based on:
- Best practices for the project type
- Common patterns in similar implementations
- Risk reduction (security, performance, maintainability)
- Alignment with any explicit project goals or constraints visible in the spec
- Present your **recommended option prominently** at the top with clear reasoning (1-2 sentences explaining why this is the best choice).
- Format as: `**Recommended:** Option [X] - <reasoning>`
- Then render all options as a Markdown table:
| Option | Description |
|--------|-------------|
| A | <Option A description> |
| B | <Option B description> |
| C | <Option C description> (add D/E as needed up to 5) |
| Short | Provide a different short answer (<=5 words) (Include only if free-form alternative is appropriate) |
- After the table, add: `You can reply with the option letter (e.g., "A"), accept the recommendation by saying "yes" or "recommended", or provide your own short answer.`
- For short-answer style (no meaningful discrete options):
- Provide your **suggested answer** based on best practices and context.
- Format as: `**Suggested:** <your proposed answer> - <brief reasoning>`
- Then output: `Format: Short answer (<=5 words). You can accept the suggestion by saying "yes" or "suggested", or provide your own answer.`
- After the user answers:
- If the user replies with "yes", "recommended", or "suggested", use your previously stated recommendation/suggestion as the answer.
- Otherwise, validate the answer maps to one option or fits the <=5 word constraint.
- If ambiguous, ask for a quick disambiguation (count still belongs to same question; do not advance).
- Once satisfactory, record it in working memory (do not yet write to disk) and move to the next queued question.
- Stop asking further questions when:
- All critical ambiguities resolved early (remaining queued items become unnecessary), OR
- User signals completion ("done", "good", "no more"), OR
- You reach 5 asked questions.
- Never reveal future queued questions in advance.
- If no valid questions exist at start, immediately report no critical ambiguities.
5. Integration after EACH accepted answer (incremental update approach):
- Maintain in-memory representation of the spec (loaded once at start) plus the raw file contents.
- For the first integrated answer in this session:
- Ensure a `## Clarifications` section exists (create it just after the highest-level contextual/overview section per the spec template if missing).
- Under it, create (if not present) a `### Session YYYY-MM-DD` subheading for today.
- Append a bullet line immediately after acceptance: `- Q: <question> → A: <final answer>`.
- Then immediately apply the clarification to the most appropriate section(s):
- Functional ambiguity → Update or add a bullet in Functional Requirements.
- User interaction / actor distinction → Update User Stories or Actors subsection (if present) with clarified role, constraint, or scenario.
- Data shape / entities → Update Data Model (add fields, types, relationships) preserving ordering; note added constraints succinctly.
- Non-functional constraint → Add/modify measurable criteria in Non-Functional / Quality Attributes section (convert vague adjective to metric or explicit target).
- Edge case / negative flow → Add a new bullet under Edge Cases / Error Handling (or create such subsection if template provides placeholder for it).
- Terminology conflict → Normalize term across spec; retain original only if necessary by adding `(formerly referred to as "X")` once.
- If the clarification invalidates an earlier ambiguous statement, replace that statement instead of duplicating; leave no obsolete contradictory text.
- Save the spec file AFTER each integration to minimize risk of context loss (atomic overwrite).
- Preserve formatting: do not reorder unrelated sections; keep heading hierarchy intact.
- Keep each inserted clarification minimal and testable (avoid narrative drift).
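A simplified sketch of this incremental append, assuming the spec is plain Markdown; unlike the rule above, it adds `## Clarifications` at the end of the file rather than after the overview section, and inserts the newest bullet directly under the session heading:

```python
from datetime import date
from pathlib import Path

def record_clarification(spec_path: Path, question: str, answer: str) -> None:
    """Append '- Q: ... → A: ...' under today's session in '## Clarifications', creating both if needed."""
    text = spec_path.read_text(encoding="utf-8")
    if "## Clarifications" not in text:
        text += "\n## Clarifications\n"
    session = f"### Session {date.today():%Y-%m-%d}"
    if session not in text:
        text = text.replace("## Clarifications", f"## Clarifications\n\n{session}", 1)
    # Newest bullet goes right after the session heading in this sketch
    text = text.replace(session, f"{session}\n- Q: {question} → A: {answer}", 1)
    spec_path.write_text(text, encoding="utf-8")  # real code should write via a temp file for atomicity
```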
6. Validation (performed after EACH write plus final pass):
- Clarifications session contains exactly one bullet per accepted answer (no duplicates).
- Total asked (accepted) questions ≤ 5.
- Updated sections contain no lingering vague placeholders the new answer was meant to resolve.
- No contradictory earlier statement remains (scan for now-invalid alternative choices removed).
- Markdown structure valid; only allowed new headings: `## Clarifications`, `### Session YYYY-MM-DD`.
- Terminology consistency: same canonical term used across all updated sections.
7. Write the updated spec back to `FEATURE_SPEC`.
8. Report completion (after questioning loop ends or early termination):
- Number of questions asked & answered.
- Path to updated spec.
- Sections touched (list names).
- Coverage summary table listing each taxonomy category with Status: Resolved (was Partial/Missing and addressed), Deferred (exceeds question quota or better suited for planning), Clear (already sufficient), Outstanding (still Partial/Missing but low impact).
- If any Outstanding or Deferred remain, recommend whether to proceed to `/speckit.plan` or run `/speckit.clarify` again later post-plan.
- Suggested next command.
Behavior rules:
- If no meaningful ambiguities found (or all potential questions would be low-impact), respond: "No critical ambiguities detected worth formal clarification." and suggest proceeding.
- If spec file missing, instruct user to run `/speckit.specify` first (do not create a new spec here).
- Never exceed 5 total asked questions (clarification retries for a single question do not count as new questions).
- Avoid speculative tech stack questions unless the absence blocks functional clarity.
- Respect user early termination signals ("stop", "done", "proceed").
- If no questions asked due to full coverage, output a compact coverage summary (all categories Clear) then suggest advancing.
- If quota reached with unresolved high-impact categories remaining, explicitly flag them under Deferred with rationale.
Context for prioritization: $ARGUMENTS

View File

@ -0,0 +1,78 @@
---
description: Create or update the project constitution from interactive or provided principle inputs, ensuring all dependent templates stay in sync
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
You are updating the project constitution at `.specify/memory/constitution.md`. This file is a TEMPLATE containing placeholder tokens in square brackets (e.g. `[PROJECT_NAME]`, `[PRINCIPLE_1_NAME]`). Your job is to (a) collect/derive concrete values, (b) fill the template precisely, and (c) propagate any amendments across dependent artifacts.
Follow this execution flow:
1. Load the existing constitution template at `.specify/memory/constitution.md`.
- Identify every placeholder token of the form `[ALL_CAPS_IDENTIFIER]`.
**IMPORTANT**: The user might require fewer or more principles than the ones used in the template. If a number is specified, respect it and follow the general template. You will update the doc accordingly.
2. Collect/derive values for placeholders:
- If user input (conversation) supplies a value, use it.
- Otherwise infer from existing repo context (README, docs, prior constitution versions if embedded).
- For governance dates: `RATIFICATION_DATE` is the original adoption date (if unknown ask or mark TODO), `LAST_AMENDED_DATE` is today if changes are made, otherwise keep previous.
- `CONSTITUTION_VERSION` must increment according to semantic versioning rules:
- MAJOR: Backward incompatible governance/principle removals or redefinitions.
- MINOR: New principle/section added or materially expanded guidance.
- PATCH: Clarifications, wording, typo fixes, non-semantic refinements.
- If version bump type ambiguous, propose reasoning before finalizing.
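Once the bump type is decided per the criteria above, the version rule can be applied mechanically; a minimal sketch:

```python
def bump_version(current: str, change: str) -> str:
    """change is one of 'major', 'minor', 'patch'; current is 'X.Y.Z'."""
    major, minor, patch = (int(p) for p in current.split("."))
    if change == "major":
        return f"{major + 1}.0.0"
    if change == "minor":
        return f"{major}.{minor + 1}.0"
    return f"{major}.{minor}.{patch + 1}"

assert bump_version("2.1.3", "minor") == "2.2.0"
```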
3. Draft the updated constitution content:
- Replace every placeholder with concrete text (no bracketed tokens left except intentionally retained template slots that the project has chosen not to define yet—explicitly justify any left).
- Preserve heading hierarchy; comments can be removed once replaced unless they still add clarifying guidance.
- Ensure each Principle section: succinct name line, paragraph (or bullet list) capturing non-negotiable rules, explicit rationale if not obvious.
- Ensure Governance section lists amendment procedure, versioning policy, and compliance review expectations.
4. Consistency propagation checklist (convert prior checklist into active validations):
- Read `.specify/templates/plan-template.md` and ensure any "Constitution Check" or rules align with updated principles.
- Read `.specify/templates/spec-template.md` for scope/requirements alignment—update if constitution adds/removes mandatory sections or constraints.
- Read `.specify/templates/tasks-template.md` and ensure task categorization reflects new or removed principle-driven task types (e.g., observability, versioning, testing discipline).
- Read each command file in `.specify/templates/commands/*.md` (including this one) to verify no outdated references (agent-specific names like CLAUDE only) remain when generic guidance is required.
- Read any runtime guidance docs (e.g., `README.md`, `docs/quickstart.md`, or agent-specific guidance files if present). Update references to principles changed.
5. Produce a Sync Impact Report (prepend as an HTML comment at top of the constitution file after update):
- Version change: old → new
- List of modified principles (old title → new title if renamed)
- Added sections
- Removed sections
- Templates requiring updates (✅ updated / ⚠ pending) with file paths
- Follow-up TODOs if any placeholders intentionally deferred.
6. Validation before final output:
- No remaining unexplained bracket tokens.
- Version line matches report.
- Dates ISO format YYYY-MM-DD.
- Principles are declarative, testable, and free of vague language ("should" → replace with MUST/SHOULD rationale where appropriate).
7. Write the completed constitution back to `.specify/memory/constitution.md` (overwrite).
8. Output a final summary to the user with:
- New version and bump rationale.
- Any files flagged for manual follow-up.
- Suggested commit message (e.g., `docs: amend constitution to vX.Y.Z (principle additions + governance update)`).
Formatting & Style Requirements:
- Use Markdown headings exactly as in the template (do not demote/promote levels).
- Wrap long rationale lines to keep readability (<100 chars ideally) but do not hard enforce with awkward breaks.
- Keep a single blank line between sections.
- Avoid trailing whitespace.
If the user supplies partial updates (e.g., only one principle revision), still perform validation and version decision steps.
If critical info missing (e.g., ratification date truly unknown), insert `TODO(<FIELD_NAME>): explanation` and include in the Sync Impact Report under deferred items.
Do not create a new template; always operate on the existing `.specify/memory/constitution.md` file.

View File

@ -0,0 +1,134 @@
---
description: Execute the implementation plan by processing and executing all tasks defined in tasks.md
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
1. Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks` from repo root and parse FEATURE_DIR and AVAILABLE_DOCS list. All paths must be absolute. For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
2. **Check checklists status** (if FEATURE_DIR/checklists/ exists):
- Scan all checklist files in the checklists/ directory
- For each checklist, count:
- Total items: All lines matching `- [ ]` or `- [X]` or `- [x]`
- Completed items: Lines matching `- [X]` or `- [x]`
- Incomplete items: Lines matching `- [ ]`
- Create a status table:
```text
| Checklist | Total | Completed | Incomplete | Status |
|-----------|-------|-----------|------------|--------|
| ux.md | 12 | 12 | 0 | ✓ PASS |
| test.md | 8 | 5 | 3 | ✗ FAIL |
| security.md | 6 | 6 | 0 | ✓ PASS |
```
- Calculate overall status:
- **PASS**: All checklists have 0 incomplete items
- **FAIL**: One or more checklists have incomplete items
- **If any checklist is incomplete**:
- Display the table with incomplete item counts
- **STOP** and ask: "Some checklists are incomplete. Do you want to proceed with implementation anyway? (yes/no)"
- Wait for user response before continuing
- If user says "no" or "wait" or "stop", halt execution
- If user says "yes" or "proceed" or "continue", proceed to step 3
- **If all checklists are complete**:
- Display the table showing all checklists passed
- Automatically proceed to step 3
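A minimal sketch of the counting and pass/fail logic described in this step (checklist files are assumed to follow the `- [ ]` / `- [x]` item format above):

```python
import re
from pathlib import Path

def checklist_status(checklists_dir: Path) -> dict:
    """Return {file name: (total, completed, incomplete)} for every *.md checklist."""
    status = {}
    for path in sorted(checklists_dir.glob("*.md")):
        lines = path.read_text(encoding="utf-8").splitlines()
        completed = sum(bool(re.match(r"\s*- \[[xX]\]", ln)) for ln in lines)
        incomplete = sum(bool(re.match(r"\s*- \[ \]", ln)) for ln in lines)
        status[path.name] = (completed + incomplete, completed, incomplete)
    return status

def overall_pass(status: dict) -> bool:
    """PASS only when every checklist has zero incomplete items."""
    return all(incomplete == 0 for _, _, incomplete in status.values())
```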
3. Load and analyze the implementation context:
- **REQUIRED**: Read tasks.md for the complete task list and execution plan
- **REQUIRED**: Read plan.md for tech stack, architecture, and file structure
- **IF EXISTS**: Read data-model.md for entities and relationships
- **IF EXISTS**: Read contracts/ for API specifications and test requirements
- **IF EXISTS**: Read research.md for technical decisions and constraints
- **IF EXISTS**: Read quickstart.md for integration scenarios
4. **Project Setup Verification**:
- **REQUIRED**: Create/verify ignore files based on actual project setup:
**Detection & Creation Logic**:
- Check if the following command succeeds to determine if the repository is a git repo (create/verify .gitignore if so):
```sh
git rev-parse --git-dir 2>/dev/null
```
- Check if Dockerfile* exists or Docker in plan.md → create/verify .dockerignore
- Check if .eslintrc* or eslint.config.* exists → create/verify .eslintignore
- Check if .prettierrc* exists → create/verify .prettierignore
- Check if .npmrc or package.json exists → create/verify .npmignore (if publishing)
- Check if terraform files (*.tf) exist → create/verify .terraformignore
- Check if .helmignore needed (helm charts present) → create/verify .helmignore
**If ignore file already exists**: Verify it contains essential patterns, append missing critical patterns only
**If ignore file missing**: Create with full pattern set for detected technology
**Common Patterns by Technology** (from plan.md tech stack):
- **Node.js/JavaScript/TypeScript**: `node_modules/`, `dist/`, `build/`, `*.log`, `.env*`
- **Python**: `__pycache__/`, `*.pyc`, `.venv/`, `venv/`, `dist/`, `*.egg-info/`
- **Java**: `target/`, `*.class`, `*.jar`, `.gradle/`, `build/`
- **C#/.NET**: `bin/`, `obj/`, `*.user`, `*.suo`, `packages/`
- **Go**: `*.exe`, `*.test`, `vendor/`, `*.out`
- **Ruby**: `.bundle/`, `log/`, `tmp/`, `*.gem`, `vendor/bundle/`
- **PHP**: `vendor/`, `*.log`, `*.cache`, `*.env`
- **Rust**: `target/`, `debug/`, `release/`, `*.rs.bk`, `*.rlib`, `*.prof*`, `.idea/`, `*.log`, `.env*`
- **Kotlin**: `build/`, `out/`, `.gradle/`, `.idea/`, `*.class`, `*.jar`, `*.iml`, `*.log`, `.env*`
- **C++**: `build/`, `bin/`, `obj/`, `out/`, `*.o`, `*.so`, `*.a`, `*.exe`, `*.dll`, `.idea/`, `*.log`, `.env*`
- **C**: `build/`, `bin/`, `obj/`, `out/`, `*.o`, `*.a`, `*.so`, `*.exe`, `Makefile`, `config.log`, `.idea/`, `*.log`, `.env*`
- **Swift**: `.build/`, `DerivedData/`, `*.swiftpm/`, `Packages/`
- **R**: `.Rproj.user/`, `.Rhistory`, `.RData`, `.Ruserdata`, `*.Rproj`, `packrat/`, `renv/`
- **Universal**: `.DS_Store`, `Thumbs.db`, `*.tmp`, `*.swp`, `.vscode/`, `.idea/`
**Tool-Specific Patterns**:
- **Docker**: `node_modules/`, `.git/`, `Dockerfile*`, `.dockerignore`, `*.log*`, `.env*`, `coverage/`
- **ESLint**: `node_modules/`, `dist/`, `build/`, `coverage/`, `*.min.js`
- **Prettier**: `node_modules/`, `dist/`, `build/`, `coverage/`, `package-lock.json`, `yarn.lock`, `pnpm-lock.yaml`
- **Terraform**: `.terraform/`, `*.tfstate*`, `*.tfvars`, `.terraform.lock.hcl`
- **Kubernetes/k8s**: `*.secret.yaml`, `secrets/`, `.kube/`, `kubeconfig*`, `*.key`, `*.crt`
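The detection logic above, sketched as a lookup table; the glob patterns mirror the bullets, while the helm trigger (`Chart.yaml`) and the helper itself are assumptions:

```python
import subprocess
from pathlib import Path

IGNORE_FILE_TRIGGERS = {
    ".dockerignore": ["Dockerfile*"],
    ".eslintignore": [".eslintrc*", "eslint.config.*"],
    ".prettierignore": [".prettierrc*"],
    ".npmignore": [".npmrc", "package.json"],
    ".terraformignore": ["*.tf"],
    ".helmignore": ["Chart.yaml"],
}

def ignore_files_to_verify(root: Path) -> list:
    needed = []
    # .gitignore only when the directory is inside a git repository
    if subprocess.run(["git", "rev-parse", "--git-dir"], cwd=root,
                      capture_output=True).returncode == 0:
        needed.append(".gitignore")
    for ignore_file, patterns in IGNORE_FILE_TRIGGERS.items():
        if any(next(root.rglob(p), None) for p in patterns):
            needed.append(ignore_file)
    return needed
```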
5. Parse tasks.md structure and extract:
- **Task phases**: Setup, Tests, Core, Integration, Polish
- **Task dependencies**: Sequential vs parallel execution rules
- **Task details**: ID, description, file paths, parallel markers [P]
- **Execution flow**: Order and dependency requirements
6. Execute implementation following the task plan:
- **Phase-by-phase execution**: Complete each phase before moving to the next
- **Respect dependencies**: Run sequential tasks in order, parallel tasks [P] can run together
- **Follow TDD approach**: Execute test tasks before their corresponding implementation tasks
- **File-based coordination**: Tasks affecting the same files must run sequentially
- **Validation checkpoints**: Verify each phase completion before proceeding
7. Implementation execution rules:
- **Setup first**: Initialize project structure, dependencies, configuration
- **Tests before code**: If tests are needed, write them for contracts, entities, and integration scenarios before the corresponding implementation
- **Core development**: Implement models, services, CLI commands, endpoints
- **Integration work**: Database connections, middleware, logging, external services
- **Polish and validation**: Unit tests, performance optimization, documentation
8. Progress tracking and error handling:
- Report progress after each completed task
- Halt execution if any non-parallel task fails
- For parallel tasks [P], continue with successful tasks, report failed ones
- Provide clear error messages with context for debugging
- Suggest next steps if implementation cannot proceed
- **IMPORTANT** For completed tasks, make sure to mark the task off as [X] in the tasks file.
9. Completion validation:
- Verify all required tasks are completed
- Check that implemented features match the original specification
- Validate that tests pass and coverage meets requirements
- Confirm the implementation follows the technical plan
- Report final status with summary of completed work
Note: This command assumes a complete task breakdown exists in tasks.md. If tasks are incomplete or missing, suggest running `/speckit.tasks` first to regenerate the task list.

View File

@ -0,0 +1,81 @@
---
description: Execute the implementation planning workflow using the plan template to generate design artifacts.
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
1. **Setup**: Run `.specify/scripts/powershell/setup-plan.ps1 -Json` from repo root and parse JSON for FEATURE_SPEC, IMPL_PLAN, SPECS_DIR, BRANCH. For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
2. **Load context**: Read FEATURE_SPEC and `.specify/memory/constitution.md`. Load IMPL_PLAN template (already copied).
3. **Execute plan workflow**: Follow the structure in IMPL_PLAN template to:
- Fill Technical Context (mark unknowns as "NEEDS CLARIFICATION")
- Fill Constitution Check section from constitution
- Evaluate gates (ERROR if violations unjustified)
- Phase 0: Generate research.md (resolve all NEEDS CLARIFICATION)
- Phase 1: Generate data-model.md, contracts/, quickstart.md
- Phase 1: Update agent context by running the agent script
- Re-evaluate Constitution Check post-design
4. **Stop and report**: Command ends after Phase 2 planning. Report branch, IMPL_PLAN path, and generated artifacts.
## Phases
### Phase 0: Outline & Research
1. **Extract unknowns from Technical Context** above:
- For each NEEDS CLARIFICATION → research task
- For each dependency → best practices task
- For each integration → patterns task
2. **Generate and dispatch research agents**:
```text
For each unknown in Technical Context:
Task: "Research {unknown} for {feature context}"
For each technology choice:
Task: "Find best practices for {tech} in {domain}"
```
3. **Consolidate findings** in `research.md` using format:
- Decision: [what was chosen]
- Rationale: [why chosen]
- Alternatives considered: [what else evaluated]
**Output**: research.md with all NEEDS CLARIFICATION resolved
### Phase 1: Design & Contracts
**Prerequisites:** `research.md` complete
1. **Extract entities from feature spec** → `data-model.md`:
- Entity name, fields, relationships
- Validation rules from requirements
- State transitions if applicable
2. **Generate API contracts** from functional requirements:
- For each user action → endpoint
- Use standard REST/GraphQL patterns
- Output OpenAPI/GraphQL schema to `/contracts/`
3. **Agent context update**:
- Run `.specify/scripts/powershell/update-agent-context.ps1 -AgentType copilot`
- These scripts detect which AI agent is in use
- Update the appropriate agent-specific context file
- Add only new technology from current plan
- Preserve manual additions between markers
**Output**: data-model.md, /contracts/*, quickstart.md, agent-specific file
## Key rules
- Use absolute paths
- ERROR on gate failures or unresolved clarifications

View File

@ -0,0 +1,249 @@
---
description: Create or update the feature specification from a natural language feature description.
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
The text the user typed after `/speckit.specify` in the triggering message **is** the feature description. Assume you always have it available in this conversation even if `$ARGUMENTS` appears literally below. Do not ask the user to repeat it unless they provided an empty command.
Given that feature description, do this:
1. **Generate a concise short name** (2-4 words) for the branch:
- Analyze the feature description and extract the most meaningful keywords
- Create a 2-4 word short name that captures the essence of the feature
- Use action-noun format when possible (e.g., "add-user-auth", "fix-payment-bug")
- Preserve technical terms and acronyms (OAuth2, API, JWT, etc.)
- Keep it concise but descriptive enough to understand the feature at a glance
- Examples:
- "I want to add user authentication" → "user-auth"
- "Implement OAuth2 integration for the API" → "oauth2-api-integration"
- "Create a dashboard for analytics" → "analytics-dashboard"
- "Fix payment processing timeout bug" → "fix-payment-timeout"
2. **Check for existing branches before creating new one**:
a. First, fetch all remote branches to ensure we have the latest information:
```bash
git fetch --all --prune
```
b. Find the highest feature number across all sources for the short-name:
- Remote branches: `git ls-remote --heads origin | grep -E 'refs/heads/[0-9]+-<short-name>$'`
- Local branches: `git branch | grep -E '^[* ]*[0-9]+-<short-name>$'`
- Specs directories: Check for directories matching `specs/[0-9]+-<short-name>`
c. Determine the next available number:
- Extract all numbers from all three sources
- Find the highest number N
- Use N+1 for the new branch number
d. Run the script `.specify/scripts/powershell/create-new-feature.ps1 -Json "$ARGUMENTS"` with the calculated number and short-name:
- Pass `--number N+1` and `--short-name "your-short-name"` along with the feature description
- Bash example: `.specify/scripts/powershell/create-new-feature.ps1 --json --number 5 --short-name "user-auth" "Add user authentication"`
- PowerShell example: `.specify/scripts/powershell/create-new-feature.ps1 -Json -Number 5 -ShortName "user-auth" "Add user authentication"`
**IMPORTANT**:
- Check all three sources (remote branches, local branches, specs directories) to find the highest number
- Only match branches/directories with the exact short-name pattern
- If no existing branches/directories found with this short-name, start with number 1
- You must only ever run this script once per feature
- The JSON is provided in the terminal as output - always refer to it to get the actual content you're looking for
- The JSON output will contain BRANCH_NAME and SPEC_FILE paths
- For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot")
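A sketch of the numbering rule from step 2, assuming branches and spec directories follow the `<number>-<short-name>` pattern shown above (listing local branches via `--format` is just one way to do it):

```python
import re
import subprocess
from pathlib import Path

def next_feature_number(short_name: str, specs_dir: Path = Path("specs")) -> int:
    """Highest existing number for this short name across remote branches, local branches,
    and specs/ directories, plus one (1 if none exist)."""
    pattern = re.compile(rf"(\d+)-{re.escape(short_name)}$")
    candidates = []

    remote = subprocess.run(["git", "ls-remote", "--heads", "origin"],
                            capture_output=True, text=True).stdout.splitlines()
    local = subprocess.run(["git", "branch", "--format=%(refname:short)"],
                           capture_output=True, text=True).stdout.splitlines()
    dirs = [p.name for p in specs_dir.glob(f"*-{short_name}")] if specs_dir.exists() else []

    for name in remote + local + dirs:
        m = pattern.search(name.strip())
        if m:
            candidates.append(int(m.group(1)))
    return (max(candidates) + 1) if candidates else 1
```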
3. Load `.specify/templates/spec-template.md` to understand required sections.
4. Follow this execution flow:
1. Parse user description from Input
If empty: ERROR "No feature description provided"
2. Extract key concepts from description
Identify: actors, actions, data, constraints
3. For unclear aspects:
- Make informed guesses based on context and industry standards
- Only mark with [NEEDS CLARIFICATION: specific question] if:
- The choice significantly impacts feature scope or user experience
- Multiple reasonable interpretations exist with different implications
- No reasonable default exists
- **LIMIT: Maximum 3 [NEEDS CLARIFICATION] markers total**
- Prioritize clarifications by impact: scope > security/privacy > user experience > technical details
4. Fill User Scenarios & Testing section
If no clear user flow: ERROR "Cannot determine user scenarios"
5. Generate Functional Requirements
Each requirement must be testable
Use reasonable defaults for unspecified details (document assumptions in Assumptions section)
6. Define Success Criteria
Create measurable, technology-agnostic outcomes
Include both quantitative metrics (time, performance, volume) and qualitative measures (user satisfaction, task completion)
Each criterion must be verifiable without implementation details
7. Identify Key Entities (if data involved)
8. Return: SUCCESS (spec ready for planning)
5. Write the specification to SPEC_FILE using the template structure, replacing placeholders with concrete details derived from the feature description (arguments) while preserving section order and headings.
6. **Specification Quality Validation**: After writing the initial spec, validate it against quality criteria:
a. **Create Spec Quality Checklist**: Generate a checklist file at `FEATURE_DIR/checklists/requirements.md` using the checklist template structure with these validation items:
```markdown
# Specification Quality Checklist: [FEATURE NAME]
**Purpose**: Validate specification completeness and quality before proceeding to planning
**Created**: [DATE]
**Feature**: [Link to spec.md]
## Content Quality
- [ ] No implementation details (languages, frameworks, APIs)
- [ ] Focused on user value and business needs
- [ ] Written for non-technical stakeholders
- [ ] All mandatory sections completed
## Requirement Completeness
- [ ] No [NEEDS CLARIFICATION] markers remain
- [ ] Requirements are testable and unambiguous
- [ ] Success criteria are measurable
- [ ] Success criteria are technology-agnostic (no implementation details)
- [ ] All acceptance scenarios are defined
- [ ] Edge cases are identified
- [ ] Scope is clearly bounded
- [ ] Dependencies and assumptions identified
## Feature Readiness
- [ ] All functional requirements have clear acceptance criteria
- [ ] User scenarios cover primary flows
- [ ] Feature meets measurable outcomes defined in Success Criteria
- [ ] No implementation details leak into specification
## Notes
- Items marked incomplete require spec updates before `/speckit.clarify` or `/speckit.plan`
```
b. **Run Validation Check**: Review the spec against each checklist item:
- For each item, determine if it passes or fails
- Document specific issues found (quote relevant spec sections)
c. **Handle Validation Results**:
- **If all items pass**: Mark checklist complete and proceed to step 7
- **If items fail (excluding [NEEDS CLARIFICATION])**:
1. List the failing items and specific issues
2. Update the spec to address each issue
3. Re-run validation until all items pass (max 3 iterations)
4. If still failing after 3 iterations, document remaining issues in checklist notes and warn user
- **If [NEEDS CLARIFICATION] markers remain**:
1. Extract all [NEEDS CLARIFICATION: ...] markers from the spec
2. **LIMIT CHECK**: If more than 3 markers exist, keep only the 3 most critical (by scope/security/UX impact) and make informed guesses for the rest
3. For each clarification needed (max 3), present options to user in this format:
```markdown
## Question [N]: [Topic]
**Context**: [Quote relevant spec section]
**What we need to know**: [Specific question from NEEDS CLARIFICATION marker]
**Suggested Answers**:
| Option | Answer | Implications |
|--------|--------|--------------|
| A | [First suggested answer] | [What this means for the feature] |
| B | [Second suggested answer] | [What this means for the feature] |
| C | [Third suggested answer] | [What this means for the feature] |
| Custom | Provide your own answer | [Explain how to provide custom input] |
**Your choice**: _[Wait for user response]_
```
4. **CRITICAL - Table Formatting**: Ensure markdown tables are properly formatted:
- Use consistent spacing with pipes aligned
- Each cell should have spaces around content: `| Content |` not `|Content|`
- Header separator must have at least 3 dashes: `|--------|`
- Test that the table renders correctly in markdown preview
5. Number questions sequentially (Q1, Q2, Q3 - max 3 total)
6. Present all questions together before waiting for responses
7. Wait for user to respond with their choices for all questions (e.g., "Q1: A, Q2: Custom - [details], Q3: B")
8. Update the spec by replacing each [NEEDS CLARIFICATION] marker with the user's selected or provided answer
9. Re-run validation after all clarifications are resolved
d. **Update Checklist**: After each validation iteration, update the checklist file with current pass/fail status
7. Report completion with branch name, spec file path, checklist results, and readiness for the next phase (`/speckit.clarify` or `/speckit.plan`).
**NOTE:** The script creates and checks out the new branch and initializes the spec file before writing.
## General Guidelines
- Focus on **WHAT** users need and **WHY**.
- Avoid HOW to implement (no tech stack, APIs, code structure).
- Written for business stakeholders, not developers.
- DO NOT create any checklists that are embedded in the spec. That will be a separate command.
### Section Requirements
- **Mandatory sections**: Must be completed for every feature
- **Optional sections**: Include only when relevant to the feature
- When a section doesn't apply, remove it entirely (don't leave as "N/A")
### For AI Generation
When creating this spec from a user prompt:
1. **Make informed guesses**: Use context, industry standards, and common patterns to fill gaps
2. **Document assumptions**: Record reasonable defaults in the Assumptions section
3. **Limit clarifications**: Maximum 3 [NEEDS CLARIFICATION] markers - use only for critical decisions that:
- Significantly impact feature scope or user experience
- Have multiple reasonable interpretations with different implications
- Lack any reasonable default
4. **Prioritize clarifications**: scope > security/privacy > user experience > technical details
5. **Think like a tester**: Every vague requirement should fail the "testable and unambiguous" checklist item
6. **Common areas needing clarification** (only if no reasonable default exists):
- Feature scope and boundaries (include/exclude specific use cases)
- User types and permissions (if multiple conflicting interpretations possible)
- Security/compliance requirements (when legally/financially significant)
**Examples of reasonable defaults** (don't ask about these):
- Data retention: Industry-standard practices for the domain
- Performance targets: Standard web/mobile app expectations unless specified
- Error handling: User-friendly messages with appropriate fallbacks
- Authentication method: Standard session-based or OAuth2 for web apps
- Integration patterns: RESTful APIs unless specified otherwise
### Success Criteria Guidelines
Success criteria must be:
1. **Measurable**: Include specific metrics (time, percentage, count, rate)
2. **Technology-agnostic**: No mention of frameworks, languages, databases, or tools
3. **User-focused**: Describe outcomes from user/business perspective, not system internals
4. **Verifiable**: Can be tested/validated without knowing implementation details
**Good examples**:
- "Users can complete checkout in under 3 minutes"
- "System supports 10,000 concurrent users"
- "95% of searches return results in under 1 second"
- "Task completion rate improves by 40%"
**Bad examples** (implementation-focused):
- "API response time is under 200ms" (too technical, use "Users see results instantly")
- "Database can handle 1000 TPS" (implementation detail, use user-facing metric)
- "React components render efficiently" (framework-specific)
- "Redis cache hit rate above 80%" (technology-specific)

View File

@ -0,0 +1,128 @@
---
description: Generate an actionable, dependency-ordered tasks.md for the feature based on available design artifacts.
---
## User Input
```text
$ARGUMENTS
```
You **MUST** consider the user input before proceeding (if not empty).
## Outline
1. **Setup**: Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json` from repo root and parse FEATURE_DIR and AVAILABLE_DOCS list. All paths must be absolute. For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").
2. **Load design documents**: Read from FEATURE_DIR:
- **Required**: plan.md (tech stack, libraries, structure), spec.md (user stories with priorities)
- **Optional**: data-model.md (entities), contracts/ (API endpoints), research.md (decisions), quickstart.md (test scenarios)
- Note: Not all projects have all documents. Generate tasks based on what's available.
3. **Execute task generation workflow**:
- Load plan.md and extract tech stack, libraries, project structure
- Load spec.md and extract user stories with their priorities (P1, P2, P3, etc.)
- If data-model.md exists: Extract entities and map to user stories
- If contracts/ exists: Map endpoints to user stories
- If research.md exists: Extract decisions for setup tasks
- Generate tasks organized by user story (see Task Generation Rules below)
- Generate dependency graph showing user story completion order
- Create parallel execution examples per user story
- Validate task completeness (each user story has all needed tasks, independently testable)
4. **Generate tasks.md**: Use `.specify/templates/tasks-template.md` as structure, fill with:
- Correct feature name from plan.md
- Phase 1: Setup tasks (project initialization)
- Phase 2: Foundational tasks (blocking prerequisites for all user stories)
- Phase 3+: One phase per user story (in priority order from spec.md)
- Each phase includes: story goal, independent test criteria, tests (if requested), implementation tasks
- Final Phase: Polish & cross-cutting concerns
- All tasks must follow the strict checklist format (see Task Generation Rules below)
- Clear file paths for each task
- Dependencies section showing story completion order
- Parallel execution examples per story
- Implementation strategy section (MVP first, incremental delivery)
5. **Report**: Output path to generated tasks.md and summary:
- Total task count
- Task count per user story
- Parallel opportunities identified
- Independent test criteria for each story
- Suggested MVP scope (typically just User Story 1)
- Format validation: Confirm ALL tasks follow the checklist format (checkbox, ID, labels, file paths)
Context for task generation: $ARGUMENTS
The tasks.md should be immediately executable - each task must be specific enough that an LLM can complete it without additional context.
## Task Generation Rules
**CRITICAL**: Tasks MUST be organized by user story to enable independent implementation and testing.
**Tests are OPTIONAL**: Only generate test tasks if explicitly requested in the feature specification or if user requests TDD approach.
### Checklist Format (REQUIRED)
Every task MUST strictly follow this format:
```text
- [ ] [TaskID] [P?] [Story?] Description with file path
```
**Format Components**:
1. **Checkbox**: ALWAYS start with `- [ ]` (markdown checkbox)
2. **Task ID**: Sequential number (T001, T002, T003...) in execution order
3. **[P] marker**: Include ONLY if task is parallelizable (different files, no dependencies on incomplete tasks)
4. **[Story] label**: REQUIRED for user story phase tasks only
- Format: [US1], [US2], [US3], etc. (maps to user stories from spec.md)
- Setup phase: NO story label
- Foundational phase: NO story label
- User Story phases: MUST have story label
- Polish phase: NO story label
5. **Description**: Clear action with exact file path
**Examples**:
- ✅ CORRECT: `- [ ] T001 Create project structure per implementation plan`
- ✅ CORRECT: `- [ ] T005 [P] Implement authentication middleware in src/middleware/auth.py`
- ✅ CORRECT: `- [ ] T012 [P] [US1] Create User model in src/models/user.py`
- ✅ CORRECT: `- [ ] T014 [US1] Implement UserService in src/services/user_service.py`
- ❌ WRONG: `- [ ] Create User model` (missing ID and Story label)
- ❌ WRONG: `T001 [US1] Create model` (missing checkbox)
- ❌ WRONG: `- [ ] [US1] Create User model` (missing Task ID)
- ❌ WRONG: `- [ ] T001 [US1] Create model` (missing file path)
### Task Organization
1. **From User Stories (spec.md)** - PRIMARY ORGANIZATION:
- Each user story (P1, P2, P3...) gets its own phase
- Map all related components to their story:
- Models needed for that story
- Services needed for that story
- Endpoints/UI needed for that story
- If tests requested: Tests specific to that story
- Mark story dependencies (most stories should be independent)
2. **From Contracts**:
- Map each contract/endpoint → to the user story it serves
- If tests requested: Each contract → contract test task [P] before implementation in that story's phase
3. **From Data Model**:
- Map each entity to the user story(ies) that need it
- If entity serves multiple stories: Put in earliest story or Setup phase
- Relationships → service layer tasks in appropriate story phase
4. **From Setup/Infrastructure**:
- Shared infrastructure → Setup phase (Phase 1)
- Foundational/blocking tasks → Foundational phase (Phase 2)
- Story-specific setup → within that story's phase
### Phase Structure
- **Phase 1**: Setup (project initialization)
- **Phase 2**: Foundational (blocking prerequisites - MUST complete before user stories)
- **Phase 3+**: User Stories in priority order (P1, P2, P3...)
- Within each story: Tests (if requested) → Models → Services → Endpoints → Integration
- Each phase should be a complete, independently testable increment
- **Final Phase**: Polish & Cross-Cutting Concerns
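Taken together, a user story phase in the generated tasks.md might look like this sketch (the task IDs, story goal, and file paths are illustrative, not prescribed):

```text
## Phase 3: User Story 1 - Account Creation (Priority: P1)

**Goal**: Users can register an account with email and password
**Independent Test**: Register a new account and confirm it can sign in

- [ ] T010 [P] [US1] Create User model in src/models/user.py
- [ ] T011 [US1] Implement UserService.create_user in src/services/user_service.py
- [ ] T012 [US1] Add account registration endpoint in src/api/users.py
```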

View File

@ -0,0 +1,31 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info/
dist/
build/
.venv/
venv/
env/
ENV/
*.log
.env
.env.*
models/*.pth
models/*.pt
data/images/
.pytest_cache/
.coverage
htmlcov/
*.swp
*.swo
*~
.DS_Store
Thumbs.db
.vscode/
.idea/
*.iml

View File

@ -0,0 +1,39 @@
<!--
Sync Impact Report:
- Version change: none → 1.0.0
- Added sections: Core Principles (5 new), Additional Standards, Development Workflow, Governance
- Templates requiring updates: ✅ updated .specify/templates/plan-template.md (Constitution Check section)
- Follow-up TODOs: RATIFICATION_DATE placeholder
-->
# ResNet AI Model Constitution
## Core Principles
### Code Quality and Modularity
All code must follow PEP 8 standards, include type hints, comprehensive docstrings, and be organized into modular, reusable components. Rationale: Enhances readability, maintainability, and collaboration in complex AI projects.
### Rigorous Testing Standards
Implement unit tests for all utility functions, integration tests for data pipelines, and model validation tests including cross-validation and performance benchmarks. Use pytest or equivalent frameworks. Rationale: Ensures reliability and catches errors early in AI model development where failures can have significant impacts.
### Reproducibility and Versioning
Version all code, data, and models using Git and DVC. Set random seeds for reproducibility in experiments. Document all dependencies and environments. Rationale: Critical for scientific validation and debugging in machine learning.
### Model Evaluation and Validation
Evaluate models on independent test sets with multiple metrics (accuracy, F1-score, AUC, etc.). Perform error analysis and bias checks. Rationale: Prevents overfitting and ensures models are fair and effective.
### Continuous Integration and Quality Gates
All changes must pass linting, unit tests, and integration tests in CI/CD pipelines. Model performance regressions must be flagged. Rationale: Maintains code and model quality over time.
## Additional Standards
Security: Protect sensitive data, comply with privacy regulations. Ethics: Conduct bias audits, ensure responsible AI practices. Performance: Optimize for computational efficiency and scalability.
## Development Workflow
Code reviews mandatory for all PRs. Model changes require peer review of evaluation results. Use issue tracking for bugs and features.
## Governance
Constitution supersedes other practices. Amendments require documentation and approval from project maintainers. Compliance verified in code reviews.
**Version**: 1.0.0 | **Ratified**: TODO(RATIFICATION_DATE): Original adoption date unknown. | **Last Amended**: 2025-11-04

View File

@ -0,0 +1,148 @@
#!/usr/bin/env pwsh
# Consolidated prerequisite checking script (PowerShell)
#
# This script provides unified prerequisite checking for Spec-Driven Development workflow.
# It replaces the functionality previously spread across multiple scripts.
#
# Usage: ./check-prerequisites.ps1 [OPTIONS]
#
# OPTIONS:
# -Json Output in JSON format
# -RequireTasks Require tasks.md to exist (for implementation phase)
# -IncludeTasks Include tasks.md in AVAILABLE_DOCS list
# -PathsOnly Only output path variables (no validation)
# -Help, -h Show help message
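#
# JSON output example (illustrative paths; see the output section below for the exact fields):
#   {"FEATURE_DIR":"C:\\repo\\specs\\001-user-auth","AVAILABLE_DOCS":["research.md","data-model.md"]}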
[CmdletBinding()]
param(
[switch]$Json,
[switch]$RequireTasks,
[switch]$IncludeTasks,
[switch]$PathsOnly,
[switch]$Help
)
$ErrorActionPreference = 'Stop'
# Show help if requested
if ($Help) {
Write-Output @"
Usage: check-prerequisites.ps1 [OPTIONS]
Consolidated prerequisite checking for Spec-Driven Development workflow.
OPTIONS:
-Json Output in JSON format
-RequireTasks Require tasks.md to exist (for implementation phase)
-IncludeTasks Include tasks.md in AVAILABLE_DOCS list
-PathsOnly Only output path variables (no prerequisite validation)
-Help, -h Show this help message
EXAMPLES:
# Check task prerequisites (plan.md required)
.\check-prerequisites.ps1 -Json
# Check implementation prerequisites (plan.md + tasks.md required)
.\check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks
# Get feature paths only (no validation)
.\check-prerequisites.ps1 -PathsOnly
"@
exit 0
}
# Source common functions
. "$PSScriptRoot/common.ps1"
# Get feature paths and validate branch
$paths = Get-FeaturePathsEnv
if (-not (Test-FeatureBranch -Branch $paths.CURRENT_BRANCH -HasGit:$paths.HAS_GIT)) {
exit 1
}
# If paths-only mode, output paths and exit (support combined -Json -PathsOnly)
if ($PathsOnly) {
if ($Json) {
[PSCustomObject]@{
REPO_ROOT = $paths.REPO_ROOT
BRANCH = $paths.CURRENT_BRANCH
FEATURE_DIR = $paths.FEATURE_DIR
FEATURE_SPEC = $paths.FEATURE_SPEC
IMPL_PLAN = $paths.IMPL_PLAN
TASKS = $paths.TASKS
} | ConvertTo-Json -Compress
} else {
Write-Output "REPO_ROOT: $($paths.REPO_ROOT)"
Write-Output "BRANCH: $($paths.CURRENT_BRANCH)"
Write-Output "FEATURE_DIR: $($paths.FEATURE_DIR)"
Write-Output "FEATURE_SPEC: $($paths.FEATURE_SPEC)"
Write-Output "IMPL_PLAN: $($paths.IMPL_PLAN)"
Write-Output "TASKS: $($paths.TASKS)"
}
exit 0
}
# Validate required directories and files
if (-not (Test-Path $paths.FEATURE_DIR -PathType Container)) {
Write-Output "ERROR: Feature directory not found: $($paths.FEATURE_DIR)"
Write-Output "Run /speckit.specify first to create the feature structure."
exit 1
}
if (-not (Test-Path $paths.IMPL_PLAN -PathType Leaf)) {
Write-Output "ERROR: plan.md not found in $($paths.FEATURE_DIR)"
Write-Output "Run /speckit.plan first to create the implementation plan."
exit 1
}
# Check for tasks.md if required
if ($RequireTasks -and -not (Test-Path $paths.TASKS -PathType Leaf)) {
Write-Output "ERROR: tasks.md not found in $($paths.FEATURE_DIR)"
Write-Output "Run /speckit.tasks first to create the task list."
exit 1
}
# Build list of available documents
$docs = @()
# Always check these optional docs
if (Test-Path $paths.RESEARCH) { $docs += 'research.md' }
if (Test-Path $paths.DATA_MODEL) { $docs += 'data-model.md' }
# Check contracts directory (only if it exists and has files)
if ((Test-Path $paths.CONTRACTS_DIR) -and (Get-ChildItem -Path $paths.CONTRACTS_DIR -ErrorAction SilentlyContinue | Select-Object -First 1)) {
$docs += 'contracts/'
}
if (Test-Path $paths.QUICKSTART) { $docs += 'quickstart.md' }
# Include tasks.md if requested and it exists
if ($IncludeTasks -and (Test-Path $paths.TASKS)) {
$docs += 'tasks.md'
}
# Output results
if ($Json) {
# JSON output
[PSCustomObject]@{
FEATURE_DIR = $paths.FEATURE_DIR
AVAILABLE_DOCS = $docs
} | ConvertTo-Json -Compress
} else {
# Text output
Write-Output "FEATURE_DIR:$($paths.FEATURE_DIR)"
Write-Output "AVAILABLE_DOCS:"
# Show status of each potential document
Test-FileExists -Path $paths.RESEARCH -Description 'research.md' | Out-Null
Test-FileExists -Path $paths.DATA_MODEL -Description 'data-model.md' | Out-Null
Test-DirHasFiles -Path $paths.CONTRACTS_DIR -Description 'contracts/' | Out-Null
Test-FileExists -Path $paths.QUICKSTART -Description 'quickstart.md' | Out-Null
if ($IncludeTasks) {
Test-FileExists -Path $paths.TASKS -Description 'tasks.md' | Out-Null
}
}

View File

@ -0,0 +1,137 @@
#!/usr/bin/env pwsh
# Common PowerShell functions analogous to common.sh
function Get-RepoRoot {
try {
$result = git rev-parse --show-toplevel 2>$null
if ($LASTEXITCODE -eq 0) {
return $result
}
} catch {
# Git command failed
}
# Fall back to script location for non-git repos
return (Resolve-Path (Join-Path $PSScriptRoot "../../..")).Path
}
function Get-CurrentBranch {
# First check if SPECIFY_FEATURE environment variable is set
if ($env:SPECIFY_FEATURE) {
return $env:SPECIFY_FEATURE
}
# Then check git if available
try {
$result = git rev-parse --abbrev-ref HEAD 2>$null
if ($LASTEXITCODE -eq 0) {
return $result
}
} catch {
# Git command failed
}
# For non-git repos, try to find the latest feature directory
$repoRoot = Get-RepoRoot
$specsDir = Join-Path $repoRoot "specs"
if (Test-Path $specsDir) {
$latestFeature = ""
$highest = 0
Get-ChildItem -Path $specsDir -Directory | ForEach-Object {
if ($_.Name -match '^(\d{3})-') {
$num = [int]$matches[1]
if ($num -gt $highest) {
$highest = $num
$latestFeature = $_.Name
}
}
}
if ($latestFeature) {
return $latestFeature
}
}
# Final fallback
return "main"
}
function Test-HasGit {
try {
git rev-parse --show-toplevel 2>$null | Out-Null
return ($LASTEXITCODE -eq 0)
} catch {
return $false
}
}
function Test-FeatureBranch {
param(
[string]$Branch,
[bool]$HasGit = $true
)
# For non-git repos, we can't enforce branch naming but still provide output
if (-not $HasGit) {
Write-Warning "[specify] Warning: Git repository not detected; skipped branch validation"
return $true
}
if ($Branch -notmatch '^[0-9]{3}-') {
Write-Output "ERROR: Not on a feature branch. Current branch: $Branch"
Write-Output "Feature branches should be named like: 001-feature-name"
return $false
}
return $true
}
function Get-FeatureDir {
param([string]$RepoRoot, [string]$Branch)
Join-Path $RepoRoot "specs/$Branch"
}
function Get-FeaturePathsEnv {
$repoRoot = Get-RepoRoot
$currentBranch = Get-CurrentBranch
$hasGit = Test-HasGit
$featureDir = Get-FeatureDir -RepoRoot $repoRoot -Branch $currentBranch
[PSCustomObject]@{
REPO_ROOT = $repoRoot
CURRENT_BRANCH = $currentBranch
HAS_GIT = $hasGit
FEATURE_DIR = $featureDir
FEATURE_SPEC = Join-Path $featureDir 'spec.md'
IMPL_PLAN = Join-Path $featureDir 'plan.md'
TASKS = Join-Path $featureDir 'tasks.md'
RESEARCH = Join-Path $featureDir 'research.md'
DATA_MODEL = Join-Path $featureDir 'data-model.md'
QUICKSTART = Join-Path $featureDir 'quickstart.md'
CONTRACTS_DIR = Join-Path $featureDir 'contracts'
}
}
function Test-FileExists {
param([string]$Path, [string]$Description)
if (Test-Path -Path $Path -PathType Leaf) {
Write-Output "$Description"
return $true
} else {
Write-Output "$Description"
return $false
}
}
function Test-DirHasFiles {
param([string]$Path, [string]$Description)
if ((Test-Path -Path $Path -PathType Container) -and (Get-ChildItem -Path $Path -ErrorAction SilentlyContinue | Where-Object { -not $_.PSIsContainer } | Select-Object -First 1)) {
Write-Output "$Description"
return $true
} else {
Write-Output "$Description"
return $false
}
}

View File

@ -0,0 +1,290 @@
#!/usr/bin/env pwsh
# Create a new feature
[CmdletBinding()]
param(
[switch]$Json,
[string]$ShortName,
[int]$Number = 0,
[switch]$Help,
[Parameter(ValueFromRemainingArguments = $true)]
[string[]]$FeatureDescription
)
$ErrorActionPreference = 'Stop'
# Show help if requested
if ($Help) {
Write-Host "Usage: ./create-new-feature.ps1 [-Json] [-ShortName <name>] [-Number N] <feature description>"
Write-Host ""
Write-Host "Options:"
Write-Host " -Json Output in JSON format"
Write-Host " -ShortName <name> Provide a custom short name (2-4 words) for the branch"
Write-Host " -Number N Specify branch number manually (overrides auto-detection)"
Write-Host " -Help Show this help message"
Write-Host ""
Write-Host "Examples:"
Write-Host " ./create-new-feature.ps1 'Add user authentication system' -ShortName 'user-auth'"
Write-Host " ./create-new-feature.ps1 'Implement OAuth2 integration for API'"
exit 0
}
# Check if feature description provided
if (-not $FeatureDescription -or $FeatureDescription.Count -eq 0) {
Write-Error "Usage: ./create-new-feature.ps1 [-Json] [-ShortName <name>] <feature description>"
exit 1
}
$featureDesc = ($FeatureDescription -join ' ').Trim()
# Resolve repository root. Prefer git information when available, but fall back
# to searching for repository markers so the workflow still functions in repositories that
# were initialized with --no-git.
function Find-RepositoryRoot {
param(
[string]$StartDir,
[string[]]$Markers = @('.git', '.specify')
)
$current = Resolve-Path $StartDir
while ($true) {
foreach ($marker in $Markers) {
if (Test-Path (Join-Path $current $marker)) {
return $current
}
}
$parent = Split-Path $current -Parent
if ($parent -eq $current) {
# Reached filesystem root without finding markers
return $null
}
$current = $parent
}
}
function Get-NextBranchNumber {
param(
[string]$ShortName,
[string]$SpecsDir
)
# Fetch all remotes to get latest branch info (suppress errors if no remotes)
try {
git fetch --all --prune 2>$null | Out-Null
} catch {
# Ignore fetch errors
}
# Find remote branches matching the pattern using git ls-remote
$remoteBranches = @()
try {
$remoteRefs = git ls-remote --heads origin 2>$null
if ($remoteRefs) {
$remoteBranches = $remoteRefs | Where-Object { $_ -match "refs/heads/(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
if ($_ -match "refs/heads/(\d+)-") {
[int]$matches[1]
}
}
}
} catch {
# Ignore errors
}
# Check local branches
$localBranches = @()
try {
$allBranches = git branch 2>$null
if ($allBranches) {
$localBranches = $allBranches | Where-Object { $_ -match "^\*?\s*(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
if ($_ -match "(\d+)-") {
[int]$matches[1]
}
}
}
} catch {
# Ignore errors
}
# Check specs directory
$specDirs = @()
if (Test-Path $SpecsDir) {
try {
$specDirs = Get-ChildItem -Path $SpecsDir -Directory | Where-Object { $_.Name -match "^(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
if ($_.Name -match "^(\d+)-") {
[int]$matches[1]
}
}
} catch {
# Ignore errors
}
}
# Combine all sources and get the highest number
$maxNum = 0
foreach ($num in ($remoteBranches + $localBranches + $specDirs)) {
if ($num -gt $maxNum) {
$maxNum = $num
}
}
# Return next number
return $maxNum + 1
}
$fallbackRoot = (Find-RepositoryRoot -StartDir $PSScriptRoot)
if (-not $fallbackRoot) {
Write-Error "Error: Could not determine repository root. Please run this script from within the repository."
exit 1
}
try {
$repoRoot = git rev-parse --show-toplevel 2>$null
if ($LASTEXITCODE -eq 0) {
$hasGit = $true
} else {
throw "Git not available"
}
} catch {
$repoRoot = $fallbackRoot
$hasGit = $false
}
Set-Location $repoRoot
$specsDir = Join-Path $repoRoot 'specs'
New-Item -ItemType Directory -Path $specsDir -Force | Out-Null
# Function to generate branch name with stop word filtering and length filtering
function Get-BranchName {
param([string]$Description)
# Common stop words to filter out
$stopWords = @(
'i', 'a', 'an', 'the', 'to', 'for', 'of', 'in', 'on', 'at', 'by', 'with', 'from',
'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had',
'do', 'does', 'did', 'will', 'would', 'should', 'could', 'can', 'may', 'might', 'must', 'shall',
'this', 'that', 'these', 'those', 'my', 'your', 'our', 'their',
'want', 'need', 'add', 'get', 'set'
)
# Convert to lowercase and extract words (alphanumeric only)
$cleanName = $Description.ToLower() -replace '[^a-z0-9\s]', ' '
$words = $cleanName -split '\s+' | Where-Object { $_ }
# Filter words: remove stop words and words shorter than 3 chars (unless they're uppercase acronyms in original)
$meaningfulWords = @()
foreach ($word in $words) {
# Skip stop words
if ($stopWords -contains $word) { continue }
# Keep words that are length >= 3 OR appear as uppercase in original (likely acronyms)
if ($word.Length -ge 3) {
$meaningfulWords += $word
} elseif ($Description -cmatch "\b$($word.ToUpper())\b") {
# Keep short words if they appear as uppercase in original (likely acronyms)
$meaningfulWords += $word
}
}
# If we have meaningful words, use first 3-4 of them
if ($meaningfulWords.Count -gt 0) {
$maxWords = if ($meaningfulWords.Count -eq 4) { 4 } else { 3 }
$result = ($meaningfulWords | Select-Object -First $maxWords) -join '-'
return $result
} else {
# Fallback to original logic if no meaningful words found
$result = $Description.ToLower() -replace '[^a-z0-9]', '-' -replace '-{2,}', '-' -replace '^-', '' -replace '-$', ''
$fallbackWords = ($result -split '-') | Where-Object { $_ } | Select-Object -First 3
return [string]::Join('-', $fallbackWords)
}
}
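# Example (illustrative): Get-BranchName -Description "Fix payment processing timeout bug"
# returns "fix-payment-processing" (stop words removed, first three meaningful words joined)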
# Generate branch name
if ($ShortName) {
# Use provided short name, just clean it up
$branchSuffix = $ShortName.ToLower() -replace '[^a-z0-9]', '-' -replace '-{2,}', '-' -replace '^-', '' -replace '-$', ''
} else {
# Generate from description with smart filtering
$branchSuffix = Get-BranchName -Description $featureDesc
}
# Determine branch number
if ($Number -eq 0) {
if ($hasGit) {
# Check existing branches on remotes
$Number = Get-NextBranchNumber -ShortName $branchSuffix -SpecsDir $specsDir
} else {
# Fall back to local directory check
$highest = 0
if (Test-Path $specsDir) {
Get-ChildItem -Path $specsDir -Directory | ForEach-Object {
if ($_.Name -match '^(\d{3})') {
$num = [int]$matches[1]
if ($num -gt $highest) { $highest = $num }
}
}
}
$Number = $highest + 1
}
}
$featureNum = ('{0:000}' -f $Number)
$branchName = "$featureNum-$branchSuffix"
# GitHub enforces a 244-byte limit on branch names
# Validate and truncate if necessary
$maxBranchLength = 244
if ($branchName.Length -gt $maxBranchLength) {
# Calculate how much we need to trim from suffix
# Account for: feature number (3) + hyphen (1) = 4 chars
$maxSuffixLength = $maxBranchLength - 4
# Truncate suffix
$truncatedSuffix = $branchSuffix.Substring(0, [Math]::Min($branchSuffix.Length, $maxSuffixLength))
# Remove trailing hyphen if truncation created one
$truncatedSuffix = $truncatedSuffix -replace '-$', ''
$originalBranchName = $branchName
$branchName = "$featureNum-$truncatedSuffix"
Write-Warning "[specify] Branch name exceeded GitHub's 244-byte limit"
Write-Warning "[specify] Original: $originalBranchName ($($originalBranchName.Length) bytes)"
Write-Warning "[specify] Truncated to: $branchName ($($branchName.Length) bytes)"
}
if ($hasGit) {
try {
git checkout -b $branchName | Out-Null
} catch {
Write-Warning "Failed to create git branch: $branchName"
}
} else {
Write-Warning "[specify] Warning: Git repository not detected; skipped branch creation for $branchName"
}
$featureDir = Join-Path $specsDir $branchName
New-Item -ItemType Directory -Path $featureDir -Force | Out-Null
$template = Join-Path $repoRoot '.specify/templates/spec-template.md'
$specFile = Join-Path $featureDir 'spec.md'
if (Test-Path $template) {
Copy-Item $template $specFile -Force
} else {
New-Item -ItemType File -Path $specFile | Out-Null
}
# Set the SPECIFY_FEATURE environment variable for the current session
$env:SPECIFY_FEATURE = $branchName
if ($Json) {
$obj = [PSCustomObject]@{
BRANCH_NAME = $branchName
SPEC_FILE = $specFile
FEATURE_NUM = $featureNum
HAS_GIT = $hasGit
}
$obj | ConvertTo-Json -Compress
} else {
Write-Output "BRANCH_NAME: $branchName"
Write-Output "SPEC_FILE: $specFile"
Write-Output "FEATURE_NUM: $featureNum"
Write-Output "HAS_GIT: $hasGit"
Write-Output "SPECIFY_FEATURE environment variable set to: $branchName"
}

View File

@ -0,0 +1,62 @@
#!/usr/bin/env pwsh
# Setup implementation plan for a feature
[CmdletBinding()]
param(
[switch]$Json,
[switch]$Help
)
$ErrorActionPreference = 'Stop'
# Show help if requested
if ($Help) {
Write-Output "Usage: ./setup-plan.ps1 [-Json] [-Help]"
Write-Output " -Json Output results in JSON format"
Write-Output " -Help Show this help message"
exit 0
}
# Load common functions
. "$PSScriptRoot/common.ps1"
# Get all paths and variables from common functions
$paths = Get-FeaturePathsEnv
# Check if we're on a proper feature branch (only for git repos)
if (-not (Test-FeatureBranch -Branch $paths.CURRENT_BRANCH -HasGit $paths.HAS_GIT)) {
exit 1
}
# Ensure the feature directory exists
New-Item -ItemType Directory -Path $paths.FEATURE_DIR -Force | Out-Null
# Copy plan template if it exists, otherwise note it or create empty file
$template = Join-Path $paths.REPO_ROOT '.specify/templates/plan-template.md'
if (Test-Path $template) {
Copy-Item $template $paths.IMPL_PLAN -Force
Write-Output "Copied plan template to $($paths.IMPL_PLAN)"
} else {
Write-Warning "Plan template not found at $template"
# Create a basic plan file if template doesn't exist
New-Item -ItemType File -Path $paths.IMPL_PLAN -Force | Out-Null
}
# Output results
if ($Json) {
$result = [PSCustomObject]@{
FEATURE_SPEC = $paths.FEATURE_SPEC
IMPL_PLAN = $paths.IMPL_PLAN
SPECS_DIR = $paths.FEATURE_DIR
BRANCH = $paths.CURRENT_BRANCH
HAS_GIT = $paths.HAS_GIT
}
$result | ConvertTo-Json -Compress
} else {
Write-Output "FEATURE_SPEC: $($paths.FEATURE_SPEC)"
Write-Output "IMPL_PLAN: $($paths.IMPL_PLAN)"
Write-Output "SPECS_DIR: $($paths.FEATURE_DIR)"
Write-Output "BRANCH: $($paths.CURRENT_BRANCH)"
Write-Output "HAS_GIT: $($paths.HAS_GIT)"
}

View File

@ -0,0 +1,439 @@
#!/usr/bin/env pwsh
<#!
.SYNOPSIS
Update agent context files with information from plan.md (PowerShell version)
.DESCRIPTION
Mirrors the behavior of scripts/bash/update-agent-context.sh:
1. Environment Validation
2. Plan Data Extraction
3. Agent File Management (create from template or update existing)
4. Content Generation (technology stack, recent changes, timestamp)
5. Multi-Agent Support (claude, gemini, copilot, cursor-agent, qwen, opencode, codex, windsurf, kilocode, auggie, roo, codebuddy, amp, q)
.PARAMETER AgentType
Optional agent key to update a single agent. If omitted, updates all existing agent files (creating a default Claude file if none exist).
.EXAMPLE
./update-agent-context.ps1 -AgentType claude
.EXAMPLE
./update-agent-context.ps1 # Updates all existing agent files
.NOTES
Relies on common helper functions in common.ps1
#>
param(
[Parameter(Position=0)]
[ValidateSet('claude','gemini','copilot','cursor-agent','qwen','opencode','codex','windsurf','kilocode','auggie','roo','codebuddy','amp','q')]
[string]$AgentType
)
$ErrorActionPreference = 'Stop'
# Import common helpers
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
. (Join-Path $ScriptDir 'common.ps1')
# Acquire environment paths
$envData = Get-FeaturePathsEnv
$REPO_ROOT = $envData.REPO_ROOT
$CURRENT_BRANCH = $envData.CURRENT_BRANCH
$HAS_GIT = $envData.HAS_GIT
$IMPL_PLAN = $envData.IMPL_PLAN
$NEW_PLAN = $IMPL_PLAN
# Agent file paths
$CLAUDE_FILE = Join-Path $REPO_ROOT 'CLAUDE.md'
$GEMINI_FILE = Join-Path $REPO_ROOT 'GEMINI.md'
$COPILOT_FILE = Join-Path $REPO_ROOT '.github/copilot-instructions.md'
$CURSOR_FILE = Join-Path $REPO_ROOT '.cursor/rules/specify-rules.mdc'
$QWEN_FILE = Join-Path $REPO_ROOT 'QWEN.md'
$AGENTS_FILE = Join-Path $REPO_ROOT 'AGENTS.md'
$WINDSURF_FILE = Join-Path $REPO_ROOT '.windsurf/rules/specify-rules.md'
$KILOCODE_FILE = Join-Path $REPO_ROOT '.kilocode/rules/specify-rules.md'
$AUGGIE_FILE = Join-Path $REPO_ROOT '.augment/rules/specify-rules.md'
$ROO_FILE = Join-Path $REPO_ROOT '.roo/rules/specify-rules.md'
$CODEBUDDY_FILE = Join-Path $REPO_ROOT 'CODEBUDDY.md'
$AMP_FILE = Join-Path $REPO_ROOT 'AGENTS.md'
$Q_FILE = Join-Path $REPO_ROOT 'AGENTS.md'
$TEMPLATE_FILE = Join-Path $REPO_ROOT '.specify/templates/agent-file-template.md'
# Parsed plan data placeholders
$script:NEW_LANG = ''
$script:NEW_FRAMEWORK = ''
$script:NEW_DB = ''
$script:NEW_PROJECT_TYPE = ''
function Write-Info {
param(
[Parameter(Mandatory=$true)]
[string]$Message
)
Write-Host "INFO: $Message"
}
function Write-Success {
param(
[Parameter(Mandatory=$true)]
[string]$Message
)
Write-Host "$([char]0x2713) $Message"
}
function Write-WarningMsg {
param(
[Parameter(Mandatory=$true)]
[string]$Message
)
Write-Warning $Message
}
function Write-Err {
param(
[Parameter(Mandatory=$true)]
[string]$Message
)
Write-Host "ERROR: $Message" -ForegroundColor Red
}
function Validate-Environment {
if (-not $CURRENT_BRANCH) {
Write-Err 'Unable to determine current feature'
if ($HAS_GIT) { Write-Info "Make sure you're on a feature branch" } else { Write-Info 'Set SPECIFY_FEATURE environment variable or create a feature first' }
exit 1
}
if (-not (Test-Path $NEW_PLAN)) {
Write-Err "No plan.md found at $NEW_PLAN"
Write-Info 'Ensure you are working on a feature with a corresponding spec directory'
if (-not $HAS_GIT) { Write-Info 'Use: $env:SPECIFY_FEATURE=your-feature-name or create a new feature first' }
exit 1
}
if (-not (Test-Path $TEMPLATE_FILE)) {
Write-Err "Template file not found at $TEMPLATE_FILE"
Write-Info 'Run specify init to scaffold .specify/templates, or add agent-file-template.md there.'
exit 1
}
}
function Extract-PlanField {
param(
[Parameter(Mandatory=$true)]
[string]$FieldPattern,
[Parameter(Mandatory=$true)]
[string]$PlanFile
)
if (-not (Test-Path $PlanFile)) { return '' }
# Lines like **Language/Version**: Python 3.12
$regex = "^\*\*$([Regex]::Escape($FieldPattern))\*\*: (.+)$"
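# Note: 'return' inside the ForEach-Object script block only emits the value to the pipeline
# (it does not exit this function); Select-Object -First 1 then keeps the first match.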
Get-Content -LiteralPath $PlanFile -Encoding utf8 | ForEach-Object {
if ($_ -match $regex) {
$val = $Matches[1].Trim()
if ($val -notin @('NEEDS CLARIFICATION','N/A')) { return $val }
}
} | Select-Object -First 1
}
function Parse-PlanData {
param(
[Parameter(Mandatory=$true)]
[string]$PlanFile
)
if (-not (Test-Path $PlanFile)) { Write-Err "Plan file not found: $PlanFile"; return $false }
Write-Info "Parsing plan data from $PlanFile"
$script:NEW_LANG = Extract-PlanField -FieldPattern 'Language/Version' -PlanFile $PlanFile
$script:NEW_FRAMEWORK = Extract-PlanField -FieldPattern 'Primary Dependencies' -PlanFile $PlanFile
$script:NEW_DB = Extract-PlanField -FieldPattern 'Storage' -PlanFile $PlanFile
$script:NEW_PROJECT_TYPE = Extract-PlanField -FieldPattern 'Project Type' -PlanFile $PlanFile
if ($NEW_LANG) { Write-Info "Found language: $NEW_LANG" } else { Write-WarningMsg 'No language information found in plan' }
if ($NEW_FRAMEWORK) { Write-Info "Found framework: $NEW_FRAMEWORK" }
if ($NEW_DB -and $NEW_DB -ne 'N/A') { Write-Info "Found database: $NEW_DB" }
if ($NEW_PROJECT_TYPE) { Write-Info "Found project type: $NEW_PROJECT_TYPE" }
return $true
}
function Format-TechnologyStack {
param(
[Parameter(Mandatory=$false)]
[string]$Lang,
[Parameter(Mandatory=$false)]
[string]$Framework
)
$parts = @()
if ($Lang -and $Lang -ne 'NEEDS CLARIFICATION') { $parts += $Lang }
if ($Framework -and $Framework -notin @('NEEDS CLARIFICATION','N/A')) { $parts += $Framework }
if (-not $parts) { return '' }
return ($parts -join ' + ')
}
function Get-ProjectStructure {
param(
[Parameter(Mandatory=$false)]
[string]$ProjectType
)
if ($ProjectType -match 'web') { return "backend/`nfrontend/`ntests/" } else { return "src/`ntests/" }
}
function Get-CommandsForLanguage {
param(
[Parameter(Mandatory=$false)]
[string]$Lang
)
switch -Regex ($Lang) {
'Python' { return "cd src; pytest; ruff check ." }
'Rust' { return "cargo test; cargo clippy" }
'JavaScript|TypeScript' { return "npm test; npm run lint" }
default { return "# Add commands for $Lang" }
}
}
function Get-LanguageConventions {
param(
[Parameter(Mandatory=$false)]
[string]$Lang
)
if ($Lang) { "${Lang}: Follow standard conventions" } else { 'General: Follow standard conventions' }
}
function New-AgentFile {
param(
[Parameter(Mandatory=$true)]
[string]$TargetFile,
[Parameter(Mandatory=$true)]
[string]$ProjectName,
[Parameter(Mandatory=$true)]
[datetime]$Date
)
if (-not (Test-Path $TEMPLATE_FILE)) { Write-Err "Template not found at $TEMPLATE_FILE"; return $false }
$temp = New-TemporaryFile
Copy-Item -LiteralPath $TEMPLATE_FILE -Destination $temp -Force
$projectStructure = Get-ProjectStructure -ProjectType $NEW_PROJECT_TYPE
$commands = Get-CommandsForLanguage -Lang $NEW_LANG
$languageConventions = Get-LanguageConventions -Lang $NEW_LANG
$escaped_lang = $NEW_LANG
$escaped_framework = $NEW_FRAMEWORK
$escaped_branch = $CURRENT_BRANCH
$content = Get-Content -LiteralPath $temp -Raw -Encoding utf8
$content = $content -replace '\[PROJECT NAME\]',$ProjectName
$content = $content -replace '\[DATE\]',$Date.ToString('yyyy-MM-dd')
# Build the technology stack string safely
$techStackForTemplate = ""
if ($escaped_lang -and $escaped_framework) {
$techStackForTemplate = "- $escaped_lang + $escaped_framework ($escaped_branch)"
} elseif ($escaped_lang) {
$techStackForTemplate = "- $escaped_lang ($escaped_branch)"
} elseif ($escaped_framework) {
$techStackForTemplate = "- $escaped_framework ($escaped_branch)"
}
$content = $content -replace '\[EXTRACTED FROM ALL PLAN.MD FILES\]',$techStackForTemplate
# For project structure we manually embed (keep newlines)
$escapedStructure = [Regex]::Escape($projectStructure)
$content = $content -replace '\[ACTUAL STRUCTURE FROM PLANS\]',$escapedStructure
# Replace escaped newlines placeholder after all replacements
$content = $content -replace '\[ONLY COMMANDS FOR ACTIVE TECHNOLOGIES\]',$commands
$content = $content -replace '\[LANGUAGE-SPECIFIC, ONLY FOR LANGUAGES IN USE\]',$languageConventions
# Build the recent changes string safely
$recentChangesForTemplate = ""
if ($escaped_lang -and $escaped_framework) {
$recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_lang} + ${escaped_framework}"
} elseif ($escaped_lang) {
$recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_lang}"
} elseif ($escaped_framework) {
$recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_framework}"
}
$content = $content -replace '\[LAST 3 FEATURES AND WHAT THEY ADDED\]',$recentChangesForTemplate
# Convert literal \n sequences introduced by Escape to real newlines
$content = $content -replace '\\n',[Environment]::NewLine
$parent = Split-Path -Parent $TargetFile
if (-not (Test-Path $parent)) { New-Item -ItemType Directory -Path $parent | Out-Null }
Set-Content -LiteralPath $TargetFile -Value $content -NoNewline -Encoding utf8
Remove-Item $temp -Force
return $true
}
function Update-ExistingAgentFile {
param(
[Parameter(Mandatory=$true)]
[string]$TargetFile,
[Parameter(Mandatory=$true)]
[datetime]$Date
)
if (-not (Test-Path $TargetFile)) { return (New-AgentFile -TargetFile $TargetFile -ProjectName (Split-Path $REPO_ROOT -Leaf) -Date $Date) }
$techStack = Format-TechnologyStack -Lang $NEW_LANG -Framework $NEW_FRAMEWORK
$newTechEntries = @()
if ($techStack) {
$escapedTechStack = [Regex]::Escape($techStack)
if (-not (Select-String -Pattern $escapedTechStack -Path $TargetFile -Quiet)) {
$newTechEntries += "- $techStack ($CURRENT_BRANCH)"
}
}
if ($NEW_DB -and $NEW_DB -notin @('N/A','NEEDS CLARIFICATION')) {
$escapedDB = [Regex]::Escape($NEW_DB)
if (-not (Select-String -Pattern $escapedDB -Path $TargetFile -Quiet)) {
$newTechEntries += "- $NEW_DB ($CURRENT_BRANCH)"
}
}
$newChangeEntry = ''
if ($techStack) { $newChangeEntry = "- ${CURRENT_BRANCH}: Added ${techStack}" }
elseif ($NEW_DB -and $NEW_DB -notin @('N/A','NEEDS CLARIFICATION')) { $newChangeEntry = "- ${CURRENT_BRANCH}: Added ${NEW_DB}" }
$lines = Get-Content -LiteralPath $TargetFile -Encoding utf8
$output = New-Object System.Collections.Generic.List[string]
$inTech = $false; $inChanges = $false; $techAdded = $false; $changeAdded = $false; $existingChanges = 0
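# Walk the file line by line: append new tech entries at the end of the '## Active Technologies'
# section, insert the new change entry directly under '## Recent Changes' (keeping only the two
# most recent existing bullets), and refresh the 'Last updated' date.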
for ($i=0; $i -lt $lines.Count; $i++) {
$line = $lines[$i]
if ($line -eq '## Active Technologies') {
$output.Add($line)
$inTech = $true
continue
}
if ($inTech -and $line -match '^##\s') {
if (-not $techAdded -and $newTechEntries.Count -gt 0) { $newTechEntries | ForEach-Object { $output.Add($_) }; $techAdded = $true }
$output.Add($line); $inTech = $false; continue
}
if ($inTech -and [string]::IsNullOrWhiteSpace($line)) {
if (-not $techAdded -and $newTechEntries.Count -gt 0) { $newTechEntries | ForEach-Object { $output.Add($_) }; $techAdded = $true }
$output.Add($line); continue
}
if ($line -eq '## Recent Changes') {
$output.Add($line)
if ($newChangeEntry) { $output.Add($newChangeEntry); $changeAdded = $true }
$inChanges = $true
continue
}
if ($inChanges -and $line -match '^##\s') { $output.Add($line); $inChanges = $false; continue }
if ($inChanges -and $line -match '^- ') {
if ($existingChanges -lt 2) { $output.Add($line); $existingChanges++ }
continue
}
if ($line -match '\*\*Last updated\*\*: .*\d{4}-\d{2}-\d{2}') {
$output.Add(($line -replace '\d{4}-\d{2}-\d{2}',$Date.ToString('yyyy-MM-dd')))
continue
}
$output.Add($line)
}
# Post-loop check: if we're still in the Active Technologies section and haven't added new entries
if ($inTech -and -not $techAdded -and $newTechEntries.Count -gt 0) {
$newTechEntries | ForEach-Object { $output.Add($_) }
}
Set-Content -LiteralPath $TargetFile -Value ($output -join [Environment]::NewLine) -Encoding utf8
return $true
}
function Update-AgentFile {
param(
[Parameter(Mandatory=$true)]
[string]$TargetFile,
[Parameter(Mandatory=$true)]
[string]$AgentName
)
if (-not $TargetFile -or -not $AgentName) { Write-Err 'Update-AgentFile requires TargetFile and AgentName'; return $false }
Write-Info "Updating $AgentName context file: $TargetFile"
$projectName = Split-Path $REPO_ROOT -Leaf
$date = Get-Date
$dir = Split-Path -Parent $TargetFile
if (-not (Test-Path $dir)) { New-Item -ItemType Directory -Path $dir | Out-Null }
if (-not (Test-Path $TargetFile)) {
if (New-AgentFile -TargetFile $TargetFile -ProjectName $projectName -Date $date) { Write-Success "Created new $AgentName context file" } else { Write-Err 'Failed to create new agent file'; return $false }
} else {
try {
if (Update-ExistingAgentFile -TargetFile $TargetFile -Date $date) { Write-Success "Updated existing $AgentName context file" } else { Write-Err 'Failed to update agent file'; return $false }
} catch {
Write-Err "Cannot access or update existing file: $TargetFile. $_"
return $false
}
}
return $true
}
function Update-SpecificAgent {
param(
[Parameter(Mandatory=$true)]
[string]$Type
)
switch ($Type) {
'claude' { Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code' }
'gemini' { Update-AgentFile -TargetFile $GEMINI_FILE -AgentName 'Gemini CLI' }
'copilot' { Update-AgentFile -TargetFile $COPILOT_FILE -AgentName 'GitHub Copilot' }
'cursor-agent' { Update-AgentFile -TargetFile $CURSOR_FILE -AgentName 'Cursor IDE' }
'qwen' { Update-AgentFile -TargetFile $QWEN_FILE -AgentName 'Qwen Code' }
'opencode' { Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'opencode' }
'codex' { Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'Codex CLI' }
'windsurf' { Update-AgentFile -TargetFile $WINDSURF_FILE -AgentName 'Windsurf' }
'kilocode' { Update-AgentFile -TargetFile $KILOCODE_FILE -AgentName 'Kilo Code' }
'auggie' { Update-AgentFile -TargetFile $AUGGIE_FILE -AgentName 'Auggie CLI' }
'roo' { Update-AgentFile -TargetFile $ROO_FILE -AgentName 'Roo Code' }
'codebuddy' { Update-AgentFile -TargetFile $CODEBUDDY_FILE -AgentName 'CodeBuddy CLI' }
'amp' { Update-AgentFile -TargetFile $AMP_FILE -AgentName 'Amp' }
'q' { Update-AgentFile -TargetFile $Q_FILE -AgentName 'Amazon Q Developer CLI' }
default { Write-Err "Unknown agent type '$Type'"; Write-Err 'Expected: claude|gemini|copilot|cursor-agent|qwen|opencode|codex|windsurf|kilocode|auggie|roo|codebuddy|amp|q'; return $false }
}
}
function Update-AllExistingAgents {
$found = $false
$ok = $true
if (Test-Path $CLAUDE_FILE) { if (-not (Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code')) { $ok = $false }; $found = $true }
if (Test-Path $GEMINI_FILE) { if (-not (Update-AgentFile -TargetFile $GEMINI_FILE -AgentName 'Gemini CLI')) { $ok = $false }; $found = $true }
if (Test-Path $COPILOT_FILE) { if (-not (Update-AgentFile -TargetFile $COPILOT_FILE -AgentName 'GitHub Copilot')) { $ok = $false }; $found = $true }
if (Test-Path $CURSOR_FILE) { if (-not (Update-AgentFile -TargetFile $CURSOR_FILE -AgentName 'Cursor IDE')) { $ok = $false }; $found = $true }
if (Test-Path $QWEN_FILE) { if (-not (Update-AgentFile -TargetFile $QWEN_FILE -AgentName 'Qwen Code')) { $ok = $false }; $found = $true }
if (Test-Path $AGENTS_FILE) { if (-not (Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'Codex/opencode')) { $ok = $false }; $found = $true }
if (Test-Path $WINDSURF_FILE) { if (-not (Update-AgentFile -TargetFile $WINDSURF_FILE -AgentName 'Windsurf')) { $ok = $false }; $found = $true }
if (Test-Path $KILOCODE_FILE) { if (-not (Update-AgentFile -TargetFile $KILOCODE_FILE -AgentName 'Kilo Code')) { $ok = $false }; $found = $true }
if (Test-Path $AUGGIE_FILE) { if (-not (Update-AgentFile -TargetFile $AUGGIE_FILE -AgentName 'Auggie CLI')) { $ok = $false }; $found = $true }
if (Test-Path $ROO_FILE) { if (-not (Update-AgentFile -TargetFile $ROO_FILE -AgentName 'Roo Code')) { $ok = $false }; $found = $true }
if (Test-Path $CODEBUDDY_FILE) { if (-not (Update-AgentFile -TargetFile $CODEBUDDY_FILE -AgentName 'CodeBuddy CLI')) { $ok = $false }; $found = $true }
if (Test-Path $Q_FILE) { if (-not (Update-AgentFile -TargetFile $Q_FILE -AgentName 'Amazon Q Developer CLI')) { $ok = $false }; $found = $true }
if (-not $found) {
Write-Info 'No existing agent files found, creating default Claude file...'
if (-not (Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code')) { $ok = $false }
}
return $ok
}
function Print-Summary {
Write-Host ''
Write-Info 'Summary of changes:'
if ($NEW_LANG) { Write-Host " - Added language: $NEW_LANG" }
if ($NEW_FRAMEWORK) { Write-Host " - Added framework: $NEW_FRAMEWORK" }
if ($NEW_DB -and $NEW_DB -ne 'N/A') { Write-Host " - Added database: $NEW_DB" }
Write-Host ''
Write-Info 'Usage: ./update-agent-context.ps1 [-AgentType claude|gemini|copilot|cursor-agent|qwen|opencode|codex|windsurf|kilocode|auggie|roo|codebuddy|amp|q]'
}
function Main {
Validate-Environment
Write-Info "=== Updating agent context files for feature $CURRENT_BRANCH ==="
if (-not (Parse-PlanData -PlanFile $NEW_PLAN)) { Write-Err 'Failed to parse plan data'; exit 1 }
$success = $true
if ($AgentType) {
Write-Info "Updating specific agent: $AgentType"
if (-not (Update-SpecificAgent -Type $AgentType)) { $success = $false }
}
else {
Write-Info 'No agent specified, updating all existing agent files...'
if (-not (Update-AllExistingAgents)) { $success = $false }
}
Print-Summary
if ($success) { Write-Success 'Agent context update completed successfully'; exit 0 } else { Write-Err 'Agent context update completed with errors'; exit 1 }
}
Main

View File

@ -0,0 +1,28 @@
# [PROJECT NAME] Development Guidelines
Auto-generated from all feature plans. Last updated: [DATE]
## Active Technologies
[EXTRACTED FROM ALL PLAN.MD FILES]
## Project Structure
```text
[ACTUAL STRUCTURE FROM PLANS]
```
## Commands
[ONLY COMMANDS FOR ACTIVE TECHNOLOGIES]
## Code Style
[LANGUAGE-SPECIFIC, ONLY FOR LANGUAGES IN USE]
## Recent Changes
[LAST 3 FEATURES AND WHAT THEY ADDED]
<!-- MANUAL ADDITIONS START -->
<!-- MANUAL ADDITIONS END -->

View File

@ -0,0 +1,40 @@
# [CHECKLIST TYPE] Checklist: [FEATURE NAME]
**Purpose**: [Brief description of what this checklist covers]
**Created**: [DATE]
**Feature**: [Link to spec.md or relevant documentation]
**Note**: This checklist is generated by the `/speckit.checklist` command based on feature context and requirements.
<!--
============================================================================
IMPORTANT: The checklist items below are SAMPLE ITEMS for illustration only.
The /speckit.checklist command MUST replace these with actual items based on:
- User's specific checklist request
- Feature requirements from spec.md
- Technical context from plan.md
- Implementation details from tasks.md
DO NOT keep these sample items in the generated checklist file.
============================================================================
-->
## [Category 1]
- [ ] CHK001 First checklist item with clear action
- [ ] CHK002 Second checklist item
- [ ] CHK003 Third checklist item
## [Category 2]
- [ ] CHK004 Another category item
- [ ] CHK005 Item with specific criteria
- [ ] CHK006 Final item in this category
## Notes
- Check items off as completed: `[x]`
- Add comments or findings inline
- Link to relevant resources or documentation
- Items are numbered sequentially for easy reference

View File

@ -0,0 +1,108 @@
# Implementation Plan: [FEATURE]
**Branch**: `[###-feature-name]` | **Date**: [DATE] | **Spec**: [link]
**Input**: Feature specification from `/specs/[###-feature-name]/spec.md`
**Note**: This template is filled in by the `/speckit.plan` command. See `.specify/templates/commands/plan.md` for the execution workflow.
## Summary
[Extract from feature spec: primary requirement + technical approach from research]
## Technical Context
<!--
ACTION REQUIRED: Replace the content in this section with the technical details
for the project. The structure here is presented in advisory capacity to guide
the iteration process.
-->
**Language/Version**: [e.g., Python 3.11, Swift 5.9, Rust 1.75 or NEEDS CLARIFICATION]
**Primary Dependencies**: [e.g., FastAPI, UIKit, LLVM or NEEDS CLARIFICATION]
**Storage**: [if applicable, e.g., PostgreSQL, CoreData, files or N/A]
**Testing**: [e.g., pytest, XCTest, cargo test or NEEDS CLARIFICATION]
**Target Platform**: [e.g., Linux server, iOS 15+, WASM or NEEDS CLARIFICATION]
**Project Type**: [single/web/mobile - determines source structure]
**Performance Goals**: [domain-specific, e.g., 1000 req/s, 10k lines/sec, 60 fps or NEEDS CLARIFICATION]
**Constraints**: [domain-specific, e.g., <200ms p95, <100MB memory, offline-capable or NEEDS CLARIFICATION]
**Scale/Scope**: [domain-specific, e.g., 10k users, 1M LOC, 50 screens or NEEDS CLARIFICATION]
## Constitution Check
*GATE: Must pass before Phase 0 research. Re-check after Phase 1 design.*
- Code Quality and Modularity: Confirm adherence to PEP 8, type hints, docstrings, and modular design.
- Rigorous Testing Standards: Plan includes unit tests, integration tests, and model validation tests.
- Reproducibility and Versioning: Versioning strategy for code, data, and models; random seed management.
- Model Evaluation and Validation: Metrics and bias checks defined for model assessment.
- Continuous Integration and Quality Gates: CI/CD pipeline with linting, testing, and performance checks.
## Project Structure
### Documentation (this feature)
```text
specs/[###-feature]/
├── plan.md # This file (/speckit.plan command output)
├── research.md # Phase 0 output (/speckit.plan command)
├── data-model.md # Phase 1 output (/speckit.plan command)
├── quickstart.md # Phase 1 output (/speckit.plan command)
├── contracts/ # Phase 1 output (/speckit.plan command)
└── tasks.md # Phase 2 output (/speckit.tasks command - NOT created by /speckit.plan)
```
### Source Code (repository root)
<!--
ACTION REQUIRED: Replace the placeholder tree below with the concrete layout
for this feature. Delete unused options and expand the chosen structure with
real paths (e.g., apps/admin, packages/something). The delivered plan must
not include Option labels.
-->
```text
# [REMOVE IF UNUSED] Option 1: Single project (DEFAULT)
src/
├── models/
├── services/
├── cli/
└── lib/
tests/
├── contract/
├── integration/
└── unit/
# [REMOVE IF UNUSED] Option 2: Web application (when "frontend" + "backend" detected)
backend/
├── src/
│ ├── models/
│ ├── services/
│ └── api/
└── tests/
frontend/
├── src/
│ ├── components/
│ ├── pages/
│ └── services/
└── tests/
# [REMOVE IF UNUSED] Option 3: Mobile + API (when "iOS/Android" detected)
api/
└── [same as backend above]
ios/ or android/
└── [platform-specific structure: feature modules, UI flows, platform tests]
```
**Structure Decision**: [Document the selected structure and reference the real
directories captured above]
## Complexity Tracking
> **Fill ONLY if Constitution Check has violations that must be justified**
| Violation | Why Needed | Simpler Alternative Rejected Because |
|-----------|------------|-------------------------------------|
| [e.g., 4th project] | [current need] | [why 3 projects insufficient] |
| [e.g., Repository pattern] | [specific problem] | [why direct DB access insufficient] |

View File

@ -0,0 +1,115 @@
# Feature Specification: [FEATURE NAME]
**Feature Branch**: `[###-feature-name]`
**Created**: [DATE]
**Status**: Draft
**Input**: User description: "$ARGUMENTS"
## User Scenarios & Testing *(mandatory)*
<!--
IMPORTANT: User stories should be PRIORITIZED as user journeys ordered by importance.
Each user story/journey must be INDEPENDENTLY TESTABLE - meaning if you implement just ONE of them,
you should still have a viable MVP (Minimum Viable Product) that delivers value.
Assign priorities (P1, P2, P3, etc.) to each story, where P1 is the most critical.
Think of each story as a standalone slice of functionality that can be:
- Developed independently
- Tested independently
- Deployed independently
- Demonstrated to users independently
-->
### User Story 1 - [Brief Title] (Priority: P1)
[Describe this user journey in plain language]
**Why this priority**: [Explain the value and why it has this priority level]
**Independent Test**: [Describe how this can be tested independently - e.g., "Can be fully tested by [specific action] and delivers [specific value]"]
**Acceptance Scenarios**:
1. **Given** [initial state], **When** [action], **Then** [expected outcome]
2. **Given** [initial state], **When** [action], **Then** [expected outcome]
---
### User Story 2 - [Brief Title] (Priority: P2)
[Describe this user journey in plain language]
**Why this priority**: [Explain the value and why it has this priority level]
**Independent Test**: [Describe how this can be tested independently]
**Acceptance Scenarios**:
1. **Given** [initial state], **When** [action], **Then** [expected outcome]
---
### User Story 3 - [Brief Title] (Priority: P3)
[Describe this user journey in plain language]
**Why this priority**: [Explain the value and why it has this priority level]
**Independent Test**: [Describe how this can be tested independently]
**Acceptance Scenarios**:
1. **Given** [initial state], **When** [action], **Then** [expected outcome]
---
[Add more user stories as needed, each with an assigned priority]
### Edge Cases
<!--
ACTION REQUIRED: The content in this section represents placeholders.
Fill them out with the right edge cases.
-->
- What happens when [boundary condition]?
- How does system handle [error scenario]?
## Requirements *(mandatory)*
<!--
ACTION REQUIRED: The content in this section represents placeholders.
Fill them out with the right functional requirements.
-->
### Functional Requirements
- **FR-001**: System MUST [specific capability, e.g., "allow users to create accounts"]
- **FR-002**: System MUST [specific capability, e.g., "validate email addresses"]
- **FR-003**: Users MUST be able to [key interaction, e.g., "reset their password"]
- **FR-004**: System MUST [data requirement, e.g., "persist user preferences"]
- **FR-005**: System MUST [behavior, e.g., "log all security events"]
*Example of marking unclear requirements:*
- **FR-006**: System MUST authenticate users via [NEEDS CLARIFICATION: auth method not specified - email/password, SSO, OAuth?]
- **FR-007**: System MUST retain user data for [NEEDS CLARIFICATION: retention period not specified]
### Key Entities *(include if feature involves data)*
- **[Entity 1]**: [What it represents, key attributes without implementation]
- **[Entity 2]**: [What it represents, relationships to other entities]
## Success Criteria *(mandatory)*
<!--
ACTION REQUIRED: Define measurable success criteria.
These must be technology-agnostic and measurable.
-->
### Measurable Outcomes
- **SC-001**: [Measurable metric, e.g., "Users can complete account creation in under 2 minutes"]
- **SC-002**: [Measurable metric, e.g., "System handles 1000 concurrent users without degradation"]
- **SC-003**: [User satisfaction metric, e.g., "90% of users successfully complete primary task on first attempt"]
- **SC-004**: [Business metric, e.g., "Reduce support tickets related to [X] by 50%"]

View File

@ -0,0 +1,251 @@
---
description: "Task list template for feature implementation"
---
# Tasks: [FEATURE NAME]
**Input**: Design documents from `/specs/[###-feature-name]/`
**Prerequisites**: plan.md (required), spec.md (required for user stories), research.md, data-model.md, contracts/
**Tests**: The examples below include test tasks. Tests are OPTIONAL - only include them if explicitly requested in the feature specification.
**Organization**: Tasks are grouped by user story to enable independent implementation and testing of each story.
## Format: `[ID] [P?] [Story] Description`
- **[P]**: Can run in parallel (different files, no dependencies)
- **[Story]**: Which user story this task belongs to (e.g., US1, US2, US3)
- Include exact file paths in descriptions
## Path Conventions
- **Single project**: `src/`, `tests/` at repository root
- **Web app**: `backend/src/`, `frontend/src/`
- **Mobile**: `api/src/`, `ios/src/` or `android/src/`
- Paths shown below assume single project - adjust based on plan.md structure
<!--
============================================================================
IMPORTANT: The tasks below are SAMPLE TASKS for illustration purposes only.
The /speckit.tasks command MUST replace these with actual tasks based on:
- User stories from spec.md (with their priorities P1, P2, P3...)
- Feature requirements from plan.md
- Entities from data-model.md
- Endpoints from contracts/
Tasks MUST be organized by user story so each story can be:
- Implemented independently
- Tested independently
- Delivered as an MVP increment
DO NOT keep these sample tasks in the generated tasks.md file.
============================================================================
-->
## Phase 1: Setup (Shared Infrastructure)
**Purpose**: Project initialization and basic structure
- [ ] T001 Create project structure per implementation plan
- [ ] T002 Initialize [language] project with [framework] dependencies
- [ ] T003 [P] Configure linting and formatting tools
---
## Phase 2: Foundational (Blocking Prerequisites)
**Purpose**: Core infrastructure that MUST be complete before ANY user story can be implemented
**⚠️ CRITICAL**: No user story work can begin until this phase is complete
Examples of foundational tasks (adjust based on your project):
- [ ] T004 Setup database schema and migrations framework
- [ ] T005 [P] Implement authentication/authorization framework
- [ ] T006 [P] Setup API routing and middleware structure
- [ ] T007 Create base models/entities that all stories depend on
- [ ] T008 Configure error handling and logging infrastructure
- [ ] T009 Setup environment configuration management
**Checkpoint**: Foundation ready - user story implementation can now begin in parallel
---
## Phase 3: User Story 1 - [Title] (Priority: P1) 🎯 MVP
**Goal**: [Brief description of what this story delivers]
**Independent Test**: [How to verify this story works on its own]
### Tests for User Story 1 (OPTIONAL - only if tests requested) ⚠️
> **NOTE: Write these tests FIRST, ensure they FAIL before implementation**
- [ ] T010 [P] [US1] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T011 [P] [US1] Integration test for [user journey] in tests/integration/test_[name].py
### Implementation for User Story 1
- [ ] T012 [P] [US1] Create [Entity1] model in src/models/[entity1].py
- [ ] T013 [P] [US1] Create [Entity2] model in src/models/[entity2].py
- [ ] T014 [US1] Implement [Service] in src/services/[service].py (depends on T012, T013)
- [ ] T015 [US1] Implement [endpoint/feature] in src/[location]/[file].py
- [ ] T016 [US1] Add validation and error handling
- [ ] T017 [US1] Add logging for user story 1 operations
**Checkpoint**: At this point, User Story 1 should be fully functional and testable independently
---
## Phase 4: User Story 2 - [Title] (Priority: P2)
**Goal**: [Brief description of what this story delivers]
**Independent Test**: [How to verify this story works on its own]
### Tests for User Story 2 (OPTIONAL - only if tests requested) ⚠️
- [ ] T018 [P] [US2] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T019 [P] [US2] Integration test for [user journey] in tests/integration/test_[name].py
### Implementation for User Story 2
- [ ] T020 [P] [US2] Create [Entity] model in src/models/[entity].py
- [ ] T021 [US2] Implement [Service] in src/services/[service].py
- [ ] T022 [US2] Implement [endpoint/feature] in src/[location]/[file].py
- [ ] T023 [US2] Integrate with User Story 1 components (if needed)
**Checkpoint**: At this point, User Stories 1 AND 2 should both work independently
---
## Phase 5: User Story 3 - [Title] (Priority: P3)
**Goal**: [Brief description of what this story delivers]
**Independent Test**: [How to verify this story works on its own]
### Tests for User Story 3 (OPTIONAL - only if tests requested) ⚠️
- [ ] T024 [P] [US3] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T025 [P] [US3] Integration test for [user journey] in tests/integration/test_[name].py
### Implementation for User Story 3
- [ ] T026 [P] [US3] Create [Entity] model in src/models/[entity].py
- [ ] T027 [US3] Implement [Service] in src/services/[service].py
- [ ] T028 [US3] Implement [endpoint/feature] in src/[location]/[file].py
**Checkpoint**: All user stories should now be independently functional
---
[Add more user story phases as needed, following the same pattern]
---
## Phase N: Polish & Cross-Cutting Concerns
**Purpose**: Improvements that affect multiple user stories
- [ ] TXXX [P] Documentation updates in docs/
- [ ] TXXX Code cleanup and refactoring
- [ ] TXXX Performance optimization across all stories
- [ ] TXXX [P] Additional unit tests (if requested) in tests/unit/
- [ ] TXXX Security hardening
- [ ] TXXX Run quickstart.md validation
---
## Dependencies & Execution Order
### Phase Dependencies
- **Setup (Phase 1)**: No dependencies - can start immediately
- **Foundational (Phase 2)**: Depends on Setup completion - BLOCKS all user stories
- **User Stories (Phase 3+)**: All depend on Foundational phase completion
- User stories can then proceed in parallel (if staffed)
- Or sequentially in priority order (P1 → P2 → P3)
- **Polish (Final Phase)**: Depends on all desired user stories being complete
### User Story Dependencies
- **User Story 1 (P1)**: Can start after Foundational (Phase 2) - No dependencies on other stories
- **User Story 2 (P2)**: Can start after Foundational (Phase 2) - May integrate with US1 but should be independently testable
- **User Story 3 (P3)**: Can start after Foundational (Phase 2) - May integrate with US1/US2 but should be independently testable
### Within Each User Story
- Tests (if included) MUST be written and FAIL before implementation
- Models before services
- Services before endpoints
- Core implementation before integration
- Story complete before moving to next priority
### Parallel Opportunities
- All Setup tasks marked [P] can run in parallel
- All Foundational tasks marked [P] can run in parallel (within Phase 2)
- Once Foundational phase completes, all user stories can start in parallel (if team capacity allows)
- All tests for a user story marked [P] can run in parallel
- Models within a story marked [P] can run in parallel
- Different user stories can be worked on in parallel by different team members
---
## Parallel Example: User Story 1
```bash
# Launch all tests for User Story 1 together (if tests requested):
Task: "Contract test for [endpoint] in tests/contract/test_[name].py"
Task: "Integration test for [user journey] in tests/integration/test_[name].py"
# Launch all models for User Story 1 together:
Task: "Create [Entity1] model in src/models/[entity1].py"
Task: "Create [Entity2] model in src/models/[entity2].py"
```
---
## Implementation Strategy
### MVP First (User Story 1 Only)
1. Complete Phase 1: Setup
2. Complete Phase 2: Foundational (CRITICAL - blocks all stories)
3. Complete Phase 3: User Story 1
4. **STOP and VALIDATE**: Test User Story 1 independently
5. Deploy/demo if ready
### Incremental Delivery
1. Complete Setup + Foundational → Foundation ready
2. Add User Story 1 → Test independently → Deploy/Demo (MVP!)
3. Add User Story 2 → Test independently → Deploy/Demo
4. Add User Story 3 → Test independently → Deploy/Demo
5. Each story adds value without breaking previous stories
### Parallel Team Strategy
With multiple developers:
1. Team completes Setup + Foundational together
2. Once Foundational is done:
- Developer A: User Story 1
- Developer B: User Story 2
- Developer C: User Story 3
3. Stories complete and integrate independently
---
## Notes
- [P] tasks = different files, no dependencies
- [Story] label maps task to specific user story for traceability
- Each user story should be independently completable and testable
- Verify tests fail before implementing
- Commit after each task or logical group
- Stop at any checkpoint to validate story independently
- Avoid: vague tasks, same file conflicts, cross-story dependencies that break independence

View File

@ -0,0 +1,203 @@
# ResNet Phenology Classifier - Development Guide
## Development Setup
### Prerequisites
- Python 3.11+
- CUDA-capable GPU (recommended)
- 8GB+ RAM
- Git
### Environment Setup
1. Clone the repository:
```bash
git clone <repository_url>
cd resnet
```
2. Create virtual environment:
```bash
python -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
```
3. Install dependencies:
```bash
pip install -r requirements.txt
```
4. Install development dependencies:
```bash
pip install black flake8 mypy pylint pytest-cov
```
## Code Quality Standards
### PEP 8 Compliance
All code must follow PEP 8 standards:
```bash
flake8 src/ tests/
```
### Type Hints
Use type hints for all functions:
```python
def train_model(epochs: int, lr: float) -> dict:
...
```
### Docstrings
All modules, classes, and functions must have docstrings:
```python
def function(arg: str) -> int:
"""
Brief description.
Args:
arg: Description
Returns:
Description
"""
pass
```
## Testing
### Running Tests
```bash
# All tests
pytest tests/ -v
# With coverage
pytest tests/ --cov=src --cov-report=html
# Specific markers
pytest tests/ -m unit
pytest tests/ -m integration
pytest tests/ -m slow
```
### Writing Tests
- Unit tests for all utility functions
- Integration tests for data pipelines
- Model validation tests
- Use fixtures for common setup (see the example below)
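As an illustration (the real tests live in `tests/`), a marked unit test sharing a fixture might look like the sketch below; the `create_model` call assumes the factory signature used in `src/evaluate.py`, and the expected output shape is an assumption:
```python
import pytest
import torch

from src.model import create_model


@pytest.fixture
def dummy_batch() -> torch.Tensor:
    """A small fake batch reused across tests."""
    return torch.randn(2, 3, 224, 224)


@pytest.mark.unit
def test_model_output_shape(dummy_batch: torch.Tensor) -> None:
    # Assumed signature, mirroring the call in src/evaluate.py
    model = create_model(3, pretrained=False, device="cpu")
    with torch.no_grad():
        out = model(dummy_batch)
    assert out.shape == (2, 3)  # (batch_size, num_classes)
```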
### Test Coverage
- Minimum 80% code coverage
- 100% coverage for critical paths
## Continuous Integration
### Pre-commit Checks
Before committing:
1. Run linter: `flake8 src/ tests/`
2. Run type checker: `mypy src/`
3. Run tests: `pytest tests/ -v`
4. Check formatting: `black --check src/ tests/`
### CI Pipeline
The CI/CD pipeline runs:
1. Linting (flake8, pylint)
2. Type checking (mypy)
3. Unit tests
4. Integration tests
5. Coverage report
## Model Development
### Training Best Practices
1. Always set random seed
2. Use validation set for hyperparameter tuning
3. Save checkpoints regularly
4. Monitor training metrics
5. Use early stopping (see the sketch below)
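The sketch below covers points 1 and 5, assuming a generic PyTorch training loop; `set_seed` and `EarlyStopping` are illustrative names, not the project's actual helpers in `src/utils.py`:
```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class EarlyStopping:
    """Stop training once validation loss stops improving for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```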
### Evaluation
- Evaluate on independent test set
- Report multiple metrics (accuracy, recall, F1)
- Analyze confusion matrix
- Check for bias
### Versioning
- Version models with a timestamp (see the checkpoint sketch below)
- Track hyperparameters
- Save class mappings
- Document training data
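A sketch of what timestamped versioning could look like; the project's actual saving logic is in `src/train.py` and `src/utils.py` and may differ:
```python
import json
import time
from pathlib import Path

import torch


def save_versioned_checkpoint(model, classes, hyperparams: dict, out_dir: str = "models") -> Path:
    """Save weights, class mapping, and hyperparameters under a timestamped name."""
    stamp = time.strftime("%Y%m%d_%H%M%S")
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    ckpt_path = out / f"phenology_{stamp}.pth"
    torch.save({"model_state_dict": model.state_dict(), "hyperparams": hyperparams}, ckpt_path)
    with open(out / f"class_mapping_{stamp}.json", "w") as f:
        json.dump({"classes": list(classes)}, f, indent=2)
    return ckpt_path
```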
## Git Workflow
### Branching Strategy
- `master`: Production-ready code
- `1-phenology-classifier`: Feature branch
- Feature branches for new capabilities
### Commit Messages
Follow conventional commits:
```
feat: add confusion matrix visualization
fix: correct data loader split logic
docs: update README with API examples
test: add unit tests for inference
```
## Performance Optimization
### Training
- Use mixed precision training (see the AMP sketch below)
- Optimize data loading (num_workers)
- Use GPU if available
- Batch size tuning
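As a sketch of the mixed-precision point, assuming a standard PyTorch loop where `model`, `loader`, `optimizer`, and `criterion` are the project's usual training objects:
```python
import torch
from torch.cuda.amp import GradScaler, autocast


def train_one_epoch_amp(model, loader, optimizer, criterion, device: str = "cuda") -> float:
    """Run one epoch with automatic mixed precision and return the mean loss."""
    scaler = GradScaler()
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():  # forward pass in float16 where it is numerically safe
            loss = criterion(model(images), labels)
        scaler.scale(loss).backward()  # scale the loss to avoid gradient underflow
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)
```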
### Inference
- Model quantization
- Batch predictions
- Cache loaded models
- Optimize image preprocessing
## Troubleshooting
### Common Issues
**CUDA out of memory:**
- Reduce batch size
- Use gradient accumulation (see the sketch below)
- Clear cache: `torch.cuda.empty_cache()`
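A hypothetical gradient-accumulation loop (not the project's actual trainer): the effective batch size becomes `batch_size * accumulation_steps` while only one micro-batch sits in GPU memory at a time.
```python
def train_with_accumulation(model, loader, optimizer, criterion,
                            device: str = "cuda", accumulation_steps: int = 4) -> None:
    """Accumulate gradients over several micro-batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        loss = criterion(model(images), labels) / accumulation_steps
        loss.backward()  # gradients add up across micro-batches
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```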
**Slow data loading:**
- Increase num_workers
- Use SSD for dataset
- Preprocess images offline
**Poor accuracy:**
- Check data quality
- Increase training epochs
- Try different learning rates
- Use data augmentation
## Documentation
### Code Documentation
- Docstrings for all public APIs
- Inline comments for complex logic
- Type hints throughout
### Project Documentation
- Update README for new features
- Document API changes
- Maintain changelog
## Release Process
1. Update version number
2. Run full test suite
3. Build documentation
4. Create release notes
5. Tag release in git
6. Deploy to production
## Contact
For questions or issues, refer to the project specifications in `specs/1-phenology-classifier/`.

View File

@ -0,0 +1,224 @@
# ResNet Phenology Classifier
A deep learning model for classifying plant images by phenological phase using ResNet50 architecture.
## Features
- **Training**: Train ResNet50 model on labeled plant images
- **Evaluation**: Comprehensive metrics including accuracy, recall, macro-F1, and confusion matrix
- **Inference**: Classify new images with visual output
- **API**: REST API for batch classification
- **Reproducibility**: Random seed management and versioning
## Dataset
The model uses the dataset located at:
```
C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF
```
Dataset split: 70% training, 15% validation, 15% testing
## Installation
1. Create a virtual environment:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
## Usage
### Training
Train the model on your dataset:
```bash
python src/train.py \
--data_dir C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF \
--csv_file C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv \
--output_dir models \
--epochs 50 \
--batch_size 32 \
--lr 0.001
```
**Arguments:**
- `--data_dir`: Directory containing images
- `--csv_file`: Path to CSV file with columns: `image_path`, `phase`
- `--output_dir`: Directory to save trained models (default: `models`)
- `--epochs`: Number of training epochs (default: 50)
- `--batch_size`: Batch size (default: 32)
- `--lr`: Learning rate (default: 0.001)
- `--num_workers`: Data loader workers (default: 4)
- `--seed`: Random seed for reproducibility (default: 42)
- `--patience`: Early stopping patience (default: 10)
### Evaluation
Evaluate a trained model:
```bash
python src/evaluate.py \
--model_path models/best_model.pth \
--data_dir C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF \
--csv_file C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv \
--class_mapping models/class_mapping.json \
--output_dir evaluation \
--split test
```
**Arguments:**
- `--model_path`: Path to trained model checkpoint
- `--data_dir`: Directory containing images
- `--csv_file`: Path to CSV file with labels
- `--class_mapping`: Path to class mapping JSON file
- `--output_dir`: Directory to save evaluation results (default: `evaluation`)
- `--split`: Dataset split to evaluate (`val` or `test`, default: `test`)
**Output:**
- Accuracy, Recall (macro), F1 (macro)
- Per-class metrics
- Confusion matrix (normalized and raw)
- Classification report
### Inference
Classify a single image:
```bash
python src/inference.py \
--image_path path/to/image.jpg \
--model_path models/best_model.pth \
--class_mapping models/class_mapping.json \
--output_dir results
```
**Arguments:**
- `--image_path`: Path to image file
- `--model_path`: Path to trained model checkpoint
- `--class_mapping`: Path to class mapping JSON file
- `--output_dir`: Directory to save results (optional)
- `--no_visualize`: Skip creating visualization
**Output:**
- Predicted phenological phase
- Confidence score
- Probabilities for all classes
- Visual output showing prediction and probabilities
### API Server
Start the FastAPI server:
```bash
# Set environment variables
export MODEL_PATH=models/best_model.pth
export CLASS_MAPPING_PATH=models/class_mapping.json
# Start server
python -m uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
```
**Endpoints:**
- `GET /`: API information
- `GET /health`: Health check
- `GET /classes`: List available classes
- `POST /classify`: Classify single image
- `POST /classify/batch`: Classify multiple images (max 10)
**Example request:**
```bash
curl -X POST "http://localhost:8000/classify" \
-H "Content-Type: multipart/form-data" \
-F "file=@path/to/image.jpg"
```
## Testing
Run unit tests:
```bash
pytest tests/ -v
```
Run specific test categories:
```bash
pytest tests/ -v -m unit # Unit tests only
pytest tests/ -v -m integration # Integration tests only
```
## Project Structure
```
.
├── src/
│ ├── __init__.py
│ ├── data_loader.py # Dataset and data loading
│ ├── model.py # ResNet model definition
│ ├── train.py # Training script
│ ├── evaluate.py # Evaluation script
│ ├── inference.py # Inference script
│ ├── api.py # FastAPI application
│ └── utils.py # Utility functions
├── tests/
│ ├── test_data_loader.py
│ ├── test_train.py
│ ├── test_evaluate.py
│ └── test_inference.py
├── data/ # Dataset directory
├── models/ # Saved models
├── specs/ # Feature specifications
├── requirements.txt # Python dependencies
├── pytest.ini # Pytest configuration
└── README.md # This file
```
## Model Architecture
- **Base**: ResNet50 pretrained on ImageNet
- **Classification Head** (see the sketch below):
- Dropout (0.5)
- Linear (2048 → 512)
- ReLU
- Dropout (0.3)
- Linear (512 → num_classes)
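A sketch of this head in PyTorch; the authoritative definition lives in `src/model.py`, so names and defaults may differ:
```python
import torch.nn as nn
from torchvision import models


def build_resnet50_classifier(num_classes: int, pretrained: bool = True) -> nn.Module:
    """ResNet50 backbone with the dropout/linear head listed above."""
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    backbone = models.resnet50(weights=weights)
    backbone.fc = nn.Sequential(  # replace the original 2048 -> 1000 ImageNet head
        nn.Dropout(0.5),
        nn.Linear(2048, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    return backbone
```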
## Performance Requirements
- **Accuracy**: >90% on test set
- **Training Time**: <1 hour for standard dataset
- **Inference Time**: <1 second per image
## Data Format
The CSV file should have the following format:
```csv
image_path,phase
images/img001.jpg,vegetative
images/img002.jpg,flowering
images/img003.jpg,fruiting
```
## Reproducibility
The project ensures reproducibility through:
- Random seed management (default: 42)
- Deterministic training
- Model checkpointing
- Class mapping versioning
## License
This project is part of a supervised learning implementation for phenology classification.
## Citation
If you use this code, please cite the project specifications in `specs/1-phenology-classifier/`.

View File

@ -0,0 +1,10 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
markers =
unit: Unit tests
integration: Integration tests
slow: Slow running tests

View File

@ -0,0 +1,13 @@
torch>=2.0.0
torchvision>=0.15.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
seaborn>=0.12.0
Pillow>=10.0.0
numpy>=1.24.0
tqdm>=4.65.0
pytest>=7.4.0
fastapi>=0.100.0
uvicorn>=0.23.0
python-multipart>=0.0.6

View File

@ -0,0 +1,35 @@
# Specification Quality Checklist: ResNet Phenology Classifier
**Purpose**: Validate specification completeness and quality before proceeding to planning
**Created**: 2025-11-04
**Feature**: [Link to spec.md](specs/1-phenology-classifier/spec.md)
## Content Quality
- [x] No implementation details (languages, frameworks, APIs)
- [x] Focused on user value and business needs
- [x] Written for non-technical stakeholders
- [x] All mandatory sections completed
## Requirement Completeness
- [x] No [NEEDS CLARIFICATION] markers remain
- [x] Requirements are testable and unambiguous
- [x] Success criteria are measurable
- [x] Success criteria are technology-agnostic (no implementation details)
- [x] All acceptance scenarios are defined
- [x] Edge cases are identified
- [x] Scope is clearly bounded
- [x] Dependencies and assumptions identified
## Feature Readiness
- [x] All functional requirements have clear acceptance criteria
- [x] User scenarios cover primary flows
- [x] Feature meets measurable outcomes defined in Success Criteria
- [x] No implementation details leak into specification
## Notes
- Phenological phases are not specified; may need clarification.
- Assumed standard dataset format and phases based on common plant phenology.

View File

@ -0,0 +1,60 @@
openapi: 3.0.3
info:
title: ResNet Phenology Classifier API
version: 1.0.0
description: API for classifying plant images by phenological phase using ResNet model
paths:
/classify:
post:
summary: Classify a plant image
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
image:
type: string
format: binary
description: Plant image file (JPEG/PNG)
required:
- image
responses:
'200':
description: Classification result
content:
application/json:
schema:
type: object
properties:
phase:
type: string
description: Predicted phenological phase
confidence:
type: number
format: float
description: Confidence score (0-1)
probabilities:
type: object
description: Probabilities for all phases
additionalProperties:
type: number
format: float
'400':
description: Invalid input
'500':
description: Server error
components:
schemas:
ClassificationResult:
type: object
properties:
phase:
type: string
confidence:
type: number
probabilities:
type: object

View File

@ -0,0 +1,41 @@
# Data Model: ResNet Phenology Classifier
**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md
## Entities
### Plant Image
**Description**: Represents an image of a plant used for training or inference.
**Fields**:
- `path` (string): File path to the image (required)
- `phase` (string): Phenological phase label (required)
- `width` (int): Image width in pixels
- `height` (int): Image height in pixels
**Validation Rules** (see the sketch below):
- Path must exist and be readable image file
- Phase must be one of the defined classes in CSV
- Image dimensions must be at least 224x224 for ResNet input
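A minimal validation sketch for these rules (illustrative only; the data model itself stays implementation-agnostic, and `valid_phases` would come from the labels CSV):
```python
from pathlib import Path

from PIL import Image


def validate_plant_image(path: str, phase: str, valid_phases: set[str]) -> None:
    """Raise ValueError if a record violates the rules above."""
    p = Path(path)
    if not p.is_file():
        raise ValueError(f"Image path does not exist: {path}")
    if phase not in valid_phases:
        raise ValueError(f"Unknown phase: {phase}")
    with Image.open(p) as img:
        width, height = img.size
    if width < 224 or height < 224:
        raise ValueError(f"Image too small for ResNet input: {width}x{height}")
```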
### Phenological Phase
**Description**: Represents a growth stage classification for plants.
**Fields**:
- `name` (string): Phase identifier (e.g., "vegetative", "flowering") (required)
- `description` (string): Human-readable description of the phase
- `index` (int): Numerical index for model output (0-based)
**Validation Rules**:
- Name must be unique
- Index must be sequential starting from 0
## Relationships
- Plant Image belongs to Phenological Phase (many-to-one)
- Phases are defined in the labels CSV file
## State Transitions
N/A - Static classification task

View File

@ -0,0 +1,88 @@
# Implementation Plan: ResNet Phenology Classifier
**Branch**: `1-phenology-classifier` | **Date**: 2025-11-04 | **Spec**: specs/1-phenology-classifier/spec.md
**Input**: Feature specification from `/specs/1-phenology-classifier/spec.md`
**Note**: This template is filled in by the `/speckit.plan` command. See `.specify/templates/commands/plan.md` for the execution workflow.
## Summary
Build a ResNet model to classify plant images by phenological phase using labeled datasets. Technical approach: Use PyTorch to implement ResNet architecture, train on image dataset with CSV labels, evaluate performance metrics.
## Technical Context
**Language/Version**: Python 3.11
**Primary Dependencies**: PyTorch, torchvision, pandas, scikit-learn
**Storage**: Image files (JPEG/PNG) and CSV labels
**Testing**: pytest, torch.testing
**Target Platform**: Linux server with GPU support
**Project Type**: single (ML model training and inference script)
**Performance Goals**: >90% classification accuracy, <1 second inference time
**Constraints**: GPU memory availability, dataset size up to 10k images
**Scale/Scope**: Support for multiple phenological phases as defined in CSV
## Constitution Check
*GATE: Must pass before Phase 0 research. Re-check after Phase 1 design.*
- Code Quality and Modularity: Confirm adherence to PEP 8, type hints, docstrings, and modular design.
- Rigorous Testing Standards: Plan includes unit tests, integration tests, and model validation tests.
- Reproducibility and Versioning: Versioning strategy for code, data, and models; random seed management.
- Model Evaluation and Validation: Metrics and bias checks defined for model assessment.
- Continuous Integration and Quality Gates: CI/CD pipeline with linting, testing, and performance checks.
## Project Structure
### Documentation (this feature)
```text
specs/1-phenology-classifier/
├── plan.md # This file (/speckit.plan command output)
├── research.md # Phase 0 output (/speckit.plan command)
├── data-model.md # Phase 1 output (/speckit.plan command)
├── quickstart.md # Phase 1 output (/speckit.plan command)
├── contracts/ # Phase 1 output (/speckit.plan command)
└── tasks.md # Phase 2 output (/speckit.tasks command - NOT created by /speckit.plan)
```
### Source Code (repository root)
```text
src/
├── __init__.py
├── data_loader.py # Load images and labels from CSV
├── model.py # ResNet model definition
├── train.py # Training script
├── evaluate.py # Evaluation script
├── inference.py # Inference on new images
└── utils.py # Helper functions
tests/
├── test_data_loader.py
├── test_model.py
├── test_train.py
└── test_inference.py
data/
├── images/ # Plant images
└── labels.csv # Image labels
models/ # Saved trained models
requirements.txt # Python dependencies
```
### Phase Execution
**Phase 0: Research** (resolve unknowns)
- Research best ResNet variant for image classification
- Research data preprocessing for plant images
- Research evaluation metrics for multi-class classification
**Phase 1: Design** (create contracts)
- Define data model for images and labels
- Design API for model inference
- Create quickstart guide
**Phase 2: Tasks** (implementation breakdown)
- Break down into specific coding tasks
- Assign priorities and dependencies

View File

@ -0,0 +1,68 @@
# Quickstart: ResNet Phenology Classifier
**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md
## Prerequisites
- Python 3.11+
- GPU with CUDA support (recommended)
- 4GB+ RAM
- Dataset with plant images and labels CSV
## Installation
1. Clone the repository and checkout the feature branch:
```bash
git checkout 1-phenology-classifier
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Prepare your dataset:
- Place images in `data/images/`
- Create `data/labels.csv` with columns: `image_path`, `phase`
## Training
Run the training script:
```bash
python src/train.py --data_dir data/ --epochs 50 --batch_size 32
```
This will:
- Load the dataset
- Train ResNet50 on your data
- Save the model to `models/phenology_classifier.pth`
## Evaluation
Evaluate the trained model:
```bash
python src/evaluate.py --model_path models/phenology_classifier.pth --data_dir data/
```
This outputs accuracy, F1-score, and per-class metrics.
## Inference
Classify a new image:
```bash
python src/inference.py --model_path models/phenology_classifier.pth --image_path path/to/image.jpg
```
Or start the API server:
```bash
python -m uvicorn src.api:app --reload
```
Then POST an image file to `http://localhost:8000/classify`.
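For a quick check from Python, a request along these lines should work (the `requests` package is an extra dependency, not listed in `requirements.txt`):
```python
import requests

with open("path/to/image.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/classify",
        files={"file": ("image.jpg", f, "image/jpeg")},  # field name matches the API parameter
    )
print(response.json())  # {"phase": ..., "confidence": ..., "probabilities": {...}}
```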
## Expected Results
- Training time: ~30 minutes on GPU
- Accuracy: >90% on validation set
- Inference time: <1 second per image

View File

@ -0,0 +1,39 @@
# Research: ResNet Phenology Classifier
**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md
## Research Tasks
1. Research best ResNet variant for plant image classification
2. Research data preprocessing techniques for botanical images
3. Research evaluation metrics for multi-class phenological phase classification
4. Research reproducibility practices for ML experiments
## Findings & Decisions
### ResNet Variant
**Decision**: Use ResNet50 as the base architecture
**Rationale**: Provides good balance between accuracy and computational efficiency. ResNet50 has been proven effective for image classification tasks similar to ImageNet.
**Alternatives considered**: ResNet18 (faster but lower accuracy), ResNet101 (higher accuracy but more compute-intensive)
### Data Preprocessing
**Decision**: Use standard ImageNet preprocessing with augmentation
**Rationale**: Random cropping, horizontal flipping, normalization to ImageNet mean/std. This is standard for transfer learning with ResNet.
**Alternatives considered**: Custom augmentations tuned to plant-specific features; standard augmentation already works well for general classification.
### Evaluation Metrics
**Decision**: Primary: Accuracy, Secondary: F1-score per class, Precision, Recall
**Rationale**: Accuracy for overall performance, F1-score to handle class imbalance in phenological phases.
**Alternatives considered**: AUC-ROC (better suited to binary problems); multi-class metrics are more appropriate here.
### Reproducibility
**Decision**: Use random seeds, version data with DVC, log all hyperparameters
**Rationale**: Ensures experiments can be reproduced. DVC for data versioning, MLflow or similar for experiment tracking.
**Alternatives considered**: Manual logging, but automated tools are more reliable.
## Resolved Clarifications
- Dataset format: Images in directory, labels in CSV with columns: image_path, phase
- Model output: Probabilities for each phase class
- Training hardware: GPU required for reasonable training time

View File

@ -0,0 +1,94 @@
# Feature Specification: ResNet Phenology Classifier
**Feature Branch**: `1-phenology-classifier`
**Created**: 2025-11-04
**Status**: Draft
**Input**: User description: "Build a ResNet model capable of classifying plant images by phenological phase. The images are provided in datasets and labeled according to their phase"
## User Scenarios & Testing *(mandatory)*
### User Story 1 - Train ResNet Model (Priority: P1)
As a researcher, I want to train a ResNet model on a dataset of labeled plant images to learn phenological phase classification.
**Why this priority**: This is the core functionality required to build the classifier.
**Independent Test**: Can be fully tested by training on a subset of the dataset and validating that the model learns to classify phases accurately.
**Acceptance Scenarios**:
1. **Given** a dataset of plant images labeled by phenological phase, **When** the model is trained, **Then** it achieves at least 90% accuracy on a held-out test set.
2. **Given** training data, **When** training completes, **Then** the model can be saved and loaded for inference.
---
### User Story 2 - Evaluate Model Performance (Priority: P2)
As a researcher, I want to evaluate the trained model's performance on unseen data to ensure reliability.
**Why this priority**: Evaluation is essential to validate the model's effectiveness before deployment.
**Independent Test**: Can be tested by running evaluation on test data and checking metrics like accuracy, precision, and recall.
**Acceptance Scenarios**:
1. **Given** a trained model and test dataset, **When** evaluation is run, **Then** detailed metrics are provided including accuracy, recall, macro-f1, and confusion matrix.
2. **Given** evaluation results, **When** performance is below threshold, **Then** the model is flagged for retraining.
---
### User Story 3 - Classify New Images (Priority: P3)
As a user, I want to use the trained model to classify new plant images by phenological phase.
**Why this priority**: This enables practical use of the model for monitoring or analysis.
**Independent Test**: Can be tested by providing a new image and verifying the predicted phase matches expectations.
**Acceptance Scenarios**:
1. **Given** a new plant image, **When** classification is requested, **Then** the model returns the predicted phenological phase with visual output.
2. **Given** an image, **When** classified, **Then** response time is under 1 second.
---
## Clarifications
### Session 2025-11-04
- Q: How should the dataset be split for training, validation, and testing? → A: 70/15/15
- Q: What evaluation metrics must be included? → A: Accuracy, recall, macro-f1, confusion matrix; other metrics accepted
- Q: How should classification results be returned? → A: Visually
### Edge Cases
- What happens when an image is of poor quality or not a plant?
- How does the system handle images with multiple plants or unclear phases?
- What if the dataset has imbalanced classes for certain phases?
## Requirements *(mandatory)*
### Functional Requirements
- **FR-001**: System MUST load and preprocess labeled plant image datasets.
- **FR-002**: System MUST train a ResNet model on the dataset to classify phenological phases specified in the dataset's .csv labels file.
- **FR-003**: System MUST evaluate the model using metrics including accuracy, recall, macro-f1, and confusion matrix. Other metrics are accepted.
- **FR-004**: System MUST provide inference capability to classify new images.
- **FR-005**: System MUST save and load trained models for reuse.
- **FR-006**: System MUST split the dataset located at C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF into 70% training, 15% validation, 15% testing.
- **FR-007**: System MUST provide visual output of classification results.
### Key Entities *(include if feature involves data)*
- **Dataset**: Located at C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF, containing plant images and labels.
- **Plant Image**: Represents an image of a plant, with attributes like image data and associated phenological phase label.
- **Phenological Phase**: Represents a growth stage of the plant, with attributes like phase name and description.
## Success Criteria *(mandatory)*
### Measurable Outcomes
- **SC-001**: Model achieves >90% accuracy on a held-out test set.
- **SC-002**: Training process completes within 1 hour for a standard dataset size.
- **SC-003**: Inference on a single image takes less than 1 second.
- **SC-004**: System handles datasets with up to 10,000 images without performance degradation.

View File

@ -0,0 +1,97 @@
# Tasks: ResNet Phenology Classifier
**Input**: Design documents from `/specs/1-phenology-classifier/`
**Prerequisites**: plan.md (required), spec.md (required for user stories), research.md, data-model.md, contracts/
**Tests**: Included as per constitution testing standards.
**Organization**: Tasks are grouped by user story to enable independent implementation and testing of each story.
## Format: `[ID] [P?] [Story] Description`
- **[P]**: Can run in parallel (different files, no dependencies)
- **[Story]**: Which user story this task belongs to (e.g., US1, US2, US3)
- Include exact file paths in descriptions
## Path Conventions
- **Single project**: `src/`, `tests/` at repository root
## Dependencies
- US1 (Train Model) must complete before US2 (Evaluate) and US3 (Classify)
- US2 and US3 can run in parallel after US1
## Parallel Execution Examples
- Setup tasks: T001-T005 can run in parallel
- US1 tasks: T011-T014 can run in parallel; T015 (end-to-end training integration) depends on them
- US2 and US3: Can run in parallel after US1 completion
## Implementation Strategy
- MVP: Complete US1 for basic training capability
- Incremental: Add US2 for evaluation, then US3 for inference
- Each user story delivers independently testable value
## Phase 1: Setup (Shared Infrastructure)
**Purpose**: Project initialization and basic structure
- [x] T001 Create project directory structure per plan.md
- [x] T002 Create requirements.txt with PyTorch, torchvision, pandas, scikit-learn
- [x] T003 Create data/ directory and subdirectories for images and labels
- [x] T004 Create models/ directory for saved models
- [x] T005 Create src/__init__.py and basic module structure
## Phase 2: Foundational (Blocking Prerequisites)
**Purpose**: Core components needed by all user stories
- [x] T006 [P] Implement data loader in src/data_loader.py for CSV labels and image loading
- [x] T007 [P] Define ResNet50 model in src/model.py with classification head
- [x] T008 [P] Create utils.py with preprocessing and helper functions
- [x] T009 [P] Set up test framework with pytest configuration
- [x] T010 [P] Create unit tests for data loader in tests/test_data_loader.py
## Phase 3: US1 - Train ResNet Model
**Purpose**: Enable model training on labeled datasets
**Independent Test**: Train on subset and verify model learns (accuracy improves)
- [x] T011 [P] [US1] Implement training script in src/train.py with data loading and ResNet training loop
- [x] T012 [P] [US1] Add model saving functionality to train.py
- [x] T013 [P] [US1] Implement data augmentation in utils.py for training
- [x] T014 [P] [US1] Create unit tests for training components in tests/test_train.py
- [x] T015 [US1] Integrate training pipeline and test end-to-end training
## Phase 4: US2 - Evaluate Model Performance
**Purpose**: Provide evaluation metrics for trained models
**Independent Test**: Run evaluation on test set and verify metrics output
- [x] T016 [P] [US2] Implement evaluation script in src/evaluate.py with accuracy and F1-score calculation
- [x] T017 [P] [US2] Add per-class metrics and confusion matrix in evaluate.py
- [x] T018 [P] [US2] Create unit tests for evaluation in tests/test_evaluate.py
- [x] T019 [US2] Integrate evaluation and test on trained model
## Phase 5: US3 - Classify New Images
**Purpose**: Enable inference on new plant images
**Independent Test**: Classify sample image and verify output format
- [x] T020 [P] [US3] Implement inference script in src/inference.py for single image classification
- [x] T021 [P] [US3] Create API endpoint in src/api.py using FastAPI for /classify POST
- [x] T022 [P] [US3] Add input validation and error handling in api.py
- [x] T023 [P] [US3] Create unit tests for inference in tests/test_inference.py
- [x] T024 [US3] Integrate API and test classification endpoint
## Final Phase: Polish & Cross-Cutting Concerns
**Purpose**: Quality assurance and production readiness
- [x] T025 Add logging and monitoring to all scripts
- [x] T026 Implement CI/CD pipeline with linting and testing
- [x] T027 Add comprehensive documentation and README updates
- [x] T028 Performance optimization and memory management
- [x] T029 Final integration testing and validation

View File

@ -0,0 +1,3 @@
"""ResNet Phenology Classifier Package"""
__version__ = "1.0.0"

View File

@ -0,0 +1,174 @@
"""FastAPI application for phenology classification."""
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
from PIL import Image
import io
import json
import os
from typing import Dict
from src.inference import load_inference_model, predict
# Initialize FastAPI app
app = FastAPI(
title="ResNet Phenology Classifier API",
description="API for classifying plant images by phenological phase",
version="1.0.0"
)
# Global variables for model
model = None
classes = None
device = None
# Configuration
MODEL_PATH = os.getenv("MODEL_PATH", "models/best_model.pth")
CLASS_MAPPING_PATH = os.getenv("CLASS_MAPPING_PATH", "models/class_mapping.json")
@app.on_event("startup")
async def load_model():
"""Load model on startup."""
global model, classes, device
try:
model, classes, device = load_inference_model(MODEL_PATH, CLASS_MAPPING_PATH)
print(f"Model loaded successfully from {MODEL_PATH}")
print(f"Classes: {classes}")
except Exception as e:
print(f"Error loading model: {e}")
raise
@app.get("/")
async def root():
"""Root endpoint."""
return {
"message": "ResNet Phenology Classifier API",
"version": "1.0.0",
"endpoints": {
"classify": "/classify",
"health": "/health",
"classes": "/classes"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {"status": "healthy", "model_loaded": True}
@app.get("/classes")
async def get_classes():
"""Get list of available classes."""
if classes is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {"classes": classes, "num_classes": len(classes)}
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)) -> Dict:
"""
Classify a plant image.
Args:
file: Uploaded image file
Returns:
Dictionary with classification results
"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
# Validate file type
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
try:
# Read image
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert('RGB')
# Preprocess
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0)
# Predict
result = predict(model, image_tensor, classes, device)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/classify/batch")
async def classify_batch(files: list[UploadFile] = File(...)) -> Dict:
"""
Classify multiple images.
Args:
files: List of uploaded image files
Returns:
Dictionary with classification results for each image
"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
if len(files) > 10:
raise HTTPException(status_code=400, detail="Maximum 10 images per batch")
results = []
for idx, file in enumerate(files):
if not file.content_type.startswith("image/"):
results.append({
"filename": file.filename,
"error": "File must be an image"
})
continue
try:
# Read image
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert('RGB')
# Preprocess
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0)
# Predict
result = predict(model, image_tensor, classes, device)
result["filename"] = file.filename
results.append(result)
except Exception as e:
results.append({
"filename": file.filename,
"error": str(e)
})
return JSONResponse(content={"results": results, "total": len(files)})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -0,0 +1,171 @@
"""Data loader module for loading and preprocessing plant images."""
import os
from typing import Tuple, List, Dict
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
class PhenologyDataset(Dataset):
"""Dataset class for plant phenology images."""
def __init__(
self,
csv_file: str,
root_dir: str,
transform: transforms.Compose = None,
split_ratio: Tuple[float, float, float] = (0.7, 0.15, 0.15),
split: str = 'train',
random_seed: int = 42
):
"""
Initialize the dataset.
Args:
csv_file: Path to CSV file with image paths and labels
root_dir: Root directory containing images
transform: Optional transform to be applied on images
split_ratio: Tuple of (train, val, test) ratios
split: One of 'train', 'val', or 'test'
random_seed: Random seed for reproducibility
"""
self.root_dir = root_dir
self.transform = transform
self.split = split
# Load CSV
self.data_frame = pd.read_csv(csv_file)
# Get unique classes and create mapping
self.classes = sorted(self.data_frame['phase'].unique())
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
# Split data
np.random.seed(random_seed)
indices = np.random.permutation(len(self.data_frame))
train_size = int(len(indices) * split_ratio[0])
val_size = int(len(indices) * split_ratio[1])
if split == 'train':
self.indices = indices[:train_size]
elif split == 'val':
self.indices = indices[train_size:train_size + val_size]
elif split == 'test':
self.indices = indices[train_size + val_size:]
else:
raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")
self.data_frame = self.data_frame.iloc[self.indices].reset_index(drop=True)
def __len__(self) -> int:
"""Return the size of the dataset."""
return len(self.data_frame)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
"""
Get item at index.
Args:
idx: Index of item
Returns:
Tuple of (image tensor, label)
"""
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx]['image_path'])
image = Image.open(img_name).convert('RGB')
label = self.class_to_idx[self.data_frame.iloc[idx]['phase']]
if self.transform:
image = self.transform(image)
return image, label
def get_class_weights(self) -> torch.Tensor:
"""
Calculate class weights for handling imbalanced datasets.
Returns:
Tensor of class weights
"""
class_counts = self.data_frame['phase'].value_counts()
total = len(self.data_frame)
weights = torch.tensor([total / class_counts[cls] for cls in self.classes])
return weights / weights.sum()
def get_data_loaders(
csv_file: str,
root_dir: str,
batch_size: int = 32,
num_workers: int = 4,
split_ratio: Tuple[float, float, float] = (0.7, 0.15, 0.15),
random_seed: int = 42
) -> Tuple[Dict[str, DataLoader], List[str], Dict[str, int]]:
"""
Create data loaders for train, validation, and test sets.
Args:
csv_file: Path to CSV file with image paths and labels
root_dir: Root directory containing images
batch_size: Batch size for data loaders
num_workers: Number of worker processes for data loading
split_ratio: Tuple of (train, val, test) ratios
random_seed: Random seed for reproducibility
Returns:
Tuple of (dict of train/val/test data loaders, list of class names, class-to-index mapping)
"""
# Define transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Create datasets
train_dataset = PhenologyDataset(
csv_file, root_dir, train_transform, split_ratio, 'train', random_seed
)
val_dataset = PhenologyDataset(
csv_file, root_dir, val_test_transform, split_ratio, 'val', random_seed
)
test_dataset = PhenologyDataset(
csv_file, root_dir, val_test_transform, split_ratio, 'test', random_seed
)
# Create data loaders
data_loaders = {
'train': DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=num_workers, pin_memory=True
),
'val': DataLoader(
val_dataset, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True
),
'test': DataLoader(
test_dataset, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True
)
}
return data_loaders, train_dataset.classes, train_dataset.class_to_idx

View File

@ -0,0 +1,259 @@
"""Evaluation script for ResNet phenology classifier."""
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, recall_score, f1_score, confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import json
import os
from tqdm import tqdm
from src.model import create_model
from src.data_loader import get_data_loaders
from src.utils import load_checkpoint
def plot_confusion_matrix(
cm: np.ndarray,
class_names: list,
save_path: str = None,
normalize: bool = True
):
"""
Plot confusion matrix.
Args:
cm: Confusion matrix
class_names: List of class names
save_path: Path to save plot
normalize: Whether to normalize the confusion matrix
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fmt = '.2f'
else:
fmt = 'd'
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt=fmt,
cmap='Blues',
xticklabels=class_names,
yticklabels=class_names,
cbar_kws={'label': 'Proportion' if normalize else 'Count'}
)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix' + (' (Normalized)' if normalize else ''))
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Confusion matrix saved to {save_path}")
plt.show()
def evaluate_model(
model: nn.Module,
dataloader: torch.utils.data.DataLoader,
device: str,
class_names: list
) -> tuple[dict, np.ndarray]:
"""
Evaluate model and compute metrics.
Args:
model: Model to evaluate
dataloader: Data loader
device: Device to use
class_names: List of class names
Returns:
Tuple of (metrics dictionary, confusion matrix array)
"""
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
pbar = tqdm(dataloader, desc='Evaluating')
for images, labels in pbar:
images = images.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
all_preds.extend(predicted.cpu().numpy())
all_labels.extend(labels.numpy())
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
# Compute metrics
accuracy = accuracy_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds, average='macro')
macro_f1 = f1_score(all_labels, all_preds, average='macro')
# Per-class metrics
per_class_recall = recall_score(all_labels, all_preds, average=None)
per_class_f1 = f1_score(all_labels, all_preds, average=None)
# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
# Classification report
report = classification_report(
all_labels, all_preds,
target_names=class_names,
digits=4
)
metrics = {
'accuracy': float(accuracy),
'recall_macro': float(recall),
'f1_macro': float(macro_f1),
'per_class_recall': {class_names[i]: float(per_class_recall[i]) for i in range(len(class_names))},
'per_class_f1': {class_names[i]: float(per_class_f1[i]) for i in range(len(class_names))},
'confusion_matrix': cm.tolist(),
'classification_report': report
}
return metrics, cm
def evaluate(
model_path: str,
data_dir: str,
csv_file: str,
class_mapping_path: str,
output_dir: str = 'evaluation',
batch_size: int = 32,
num_workers: int = 4,
split: str = 'test'
):
"""
Main evaluation function.
Args:
model_path: Path to trained model checkpoint
data_dir: Directory containing images
csv_file: Path to CSV file with labels
class_mapping_path: Path to class mapping JSON file
output_dir: Directory to save evaluation results
batch_size: Batch size
num_workers: Number of data loader workers
split: Which split to evaluate ('val' or 'test')
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Load class mapping
print("\nLoading class mapping...")
with open(class_mapping_path, 'r') as f:
class_mapping = json.load(f)
classes = class_mapping['classes']
num_classes = len(classes)
print(f"Number of classes: {num_classes}")
print(f"Classes: {classes}")
# Load data
print(f"\nLoading {split} data...")
data_loaders, _, _ = get_data_loaders(
csv_file=csv_file,
root_dir=data_dir,
batch_size=batch_size,
num_workers=num_workers
)
dataloader = data_loaders[split]
print(f"Number of {split} samples: {len(dataloader.dataset)}")
# Load model
print("\nLoading model...")
model = create_model(num_classes, pretrained=False, device=device)
load_checkpoint(model_path, model, device=device)
# Evaluate
print(f"\nEvaluating on {split} set...")
metrics, cm = evaluate_model(model, dataloader, device, classes)
# Print results
print(f"\n{'='*50}")
print(f"EVALUATION RESULTS ({split.upper()} SET)")
print(f"{'='*50}")
print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)")
print(f"Recall (Macro): {metrics['recall_macro']:.4f}")
print(f"F1 (Macro): {metrics['f1_macro']:.4f}")
print(f"\nPer-Class Metrics:")
print("-" * 50)
for cls in classes:
print(f" {cls:20s} - Recall: {metrics['per_class_recall'][cls]:.4f}, F1: {metrics['per_class_f1'][cls]:.4f}")
print(f"{'='*50}")
print(f"\nClassification Report:")
print(metrics['classification_report'])
# Save metrics
metrics_path = os.path.join(output_dir, f'{split}_metrics.json')
with open(metrics_path, 'w') as f:
# Don't save classification report in JSON (it's a string)
metrics_to_save = {k: v for k, v in metrics.items() if k != 'classification_report'}
json.dump(metrics_to_save, f, indent=2)
print(f"\nMetrics saved to: {metrics_path}")
# Save classification report
report_path = os.path.join(output_dir, f'{split}_classification_report.txt')
with open(report_path, 'w') as f:
f.write(metrics['classification_report'])
print(f"Classification report saved to: {report_path}")
# Plot confusion matrix
cm_path = os.path.join(output_dir, f'{split}_confusion_matrix.png')
plot_confusion_matrix(cm, classes, save_path=cm_path, normalize=True)
# Also save non-normalized version
cm_raw_path = os.path.join(output_dir, f'{split}_confusion_matrix_raw.png')
plot_confusion_matrix(cm, classes, save_path=cm_raw_path, normalize=False)
# Check if model meets success criteria
if metrics['accuracy'] >= 0.9:
print(f"\n✓ Model meets success criteria (accuracy >= 90%)")
else:
print(f"\n✗ Model does not meet success criteria (accuracy < 90%)")
print(f" Current accuracy: {metrics['accuracy']*100:.2f}%")
print(f" Model should be flagged for retraining")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate ResNet phenology classifier')
parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
parser.add_argument('--data_dir', type=str, required=True, help='Directory containing images')
parser.add_argument('--csv_file', type=str, required=True, help='Path to CSV file with labels')
parser.add_argument('--class_mapping', type=str, required=True, help='Path to class mapping JSON')
parser.add_argument('--output_dir', type=str, default='evaluation', help='Directory to save results')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers')
parser.add_argument('--split', type=str, default='test', choices=['val', 'test'], help='Split to evaluate')
args = parser.parse_args()
evaluate(
model_path=args.model_path,
data_dir=args.data_dir,
csv_file=args.csv_file,
class_mapping_path=args.class_mapping,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_workers=args.num_workers,
split=args.split
)
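# --- Hedged usage sketch (module path and file locations are assumptions) ---
# Assuming this file lives at src/evaluate.py (the unit tests import src.evaluate)
# and that training saved best_model.pth / class_mapping.json under models/, the
# CLI defined above could be invoked as:
#
#   python -m src.evaluate --model_path models/best_model.pth \
#       --data_dir data/images --csv_file data/labels.csv \
#       --class_mapping models/class_mapping.json --split test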

View File

@ -0,0 +1,226 @@
"""Inference script for ResNet phenology classifier."""
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import argparse
import json
import os
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple
from src.model import create_model
from src.utils import load_checkpoint
def load_inference_model(model_path: str, class_mapping_path: str, device: str = None):
"""
Load model for inference.
Args:
model_path: Path to model checkpoint
class_mapping_path: Path to class mapping JSON
device: Device to use
Returns:
Tuple of (model, classes, device)
"""
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load class mapping
with open(class_mapping_path, 'r') as f:
class_mapping = json.load(f)
classes = class_mapping['classes']
num_classes = len(classes)
# Load model
model = create_model(num_classes, pretrained=False, device=device)
load_checkpoint(model_path, model, device=device)
model.eval()
return model, classes, device
def preprocess_image(image_path: str) -> Tuple[torch.Tensor, Image.Image]:
"""
Preprocess image for inference.
Args:
image_path: Path to image file
Returns:
Tuple of (preprocessed image tensor, original PIL image)
"""
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
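        # Standard ImageNet normalization statistics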
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0) # Add batch dimension
return image_tensor, image
def predict(
model: torch.nn.Module,
image_tensor: torch.Tensor,
classes: list,
device: str
) -> dict:
"""
Make prediction on image.
Args:
model: Trained model
image_tensor: Preprocessed image tensor
classes: List of class names
device: Device to use
Returns:
Dictionary with prediction results
"""
image_tensor = image_tensor.to(device)
with torch.no_grad():
outputs = model(image_tensor)
probabilities = F.softmax(outputs, dim=1)
confidence, predicted = torch.max(probabilities, 1)
predicted_class = classes[predicted.item()]
confidence_score = confidence.item()
# Get probabilities for all classes
all_probs = {classes[i]: probabilities[0][i].item() for i in range(len(classes))}
result = {
'phase': predicted_class,
'confidence': confidence_score,
'probabilities': all_probs
}
return result
def visualize_prediction(
image: Image.Image,
result: dict,
save_path: str = None
):
"""
Visualize prediction result.
Args:
image: Original PIL image
result: Prediction result dictionary
save_path: Optional path to save visualization
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# Display image
ax1.imshow(image)
ax1.axis('off')
ax1.set_title(f"Predicted: {result['phase']}\nConfidence: {result['confidence']:.2%}",
fontsize=14, fontweight='bold')
# Display probabilities
classes = list(result['probabilities'].keys())
probs = list(result['probabilities'].values())
colors = ['green' if cls == result['phase'] else 'lightblue' for cls in classes]
ax2.barh(classes, probs, color=colors)
ax2.set_xlabel('Probability', fontsize=12)
ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
ax2.set_xlim(0, 1)
# Add probability values on bars
for i, (cls, prob) in enumerate(zip(classes, probs)):
ax2.text(prob + 0.02, i, f'{prob:.2%}', va='center', fontsize=10)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Visualization saved to {save_path}")
plt.show()
def classify_image(
image_path: str,
model_path: str,
class_mapping_path: str,
output_dir: str = None,
visualize: bool = True
):
"""
Classify a single image.
Args:
image_path: Path to image file
model_path: Path to model checkpoint
class_mapping_path: Path to class mapping JSON
output_dir: Optional directory to save results
visualize: Whether to create visualization
"""
print(f"Loading model from {model_path}...")
model, classes, device = load_inference_model(model_path, class_mapping_path)
print(f"Processing image: {image_path}")
image_tensor, original_image = preprocess_image(image_path)
print("Making prediction...")
result = predict(model, image_tensor, classes, device)
# Print results
print(f"\n{'='*50}")
print(f"CLASSIFICATION RESULT")
print(f"{'='*50}")
print(f"Predicted Phase: {result['phase']}")
print(f"Confidence: {result['confidence']:.2%}")
print(f"\nAll Probabilities:")
for cls, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
print(f" {cls:20s}: {prob:.2%}")
print(f"{'='*50}")
# Save results
if output_dir:
os.makedirs(output_dir, exist_ok=True)
# Save JSON result
result_path = os.path.join(output_dir, 'prediction.json')
with open(result_path, 'w') as f:
json.dump(result, f, indent=2)
print(f"\nPrediction saved to: {result_path}")
# Visualize
if visualize:
save_path = os.path.join(output_dir, 'prediction_visualization.png') if output_dir else None
visualize_prediction(original_image, result, save_path)
return result
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classify plant image by phenological phase')
parser.add_argument('--image_path', type=str, required=True, help='Path to image file')
parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
parser.add_argument('--class_mapping', type=str, required=True, help='Path to class mapping JSON')
parser.add_argument('--output_dir', type=str, default=None, help='Directory to save results')
parser.add_argument('--no_visualize', action='store_true', help='Do not create visualization')
args = parser.parse_args()
classify_image(
image_path=args.image_path,
model_path=args.model_path,
class_mapping_path=args.class_mapping,
output_dir=args.output_dir,
visualize=not args.no_visualize
)
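# --- Hedged usage sketch (paths below are illustrative placeholders) ---
# The same functions can be used programmatically, bypassing the CLI:
#
#   model, classes, device = load_inference_model('models/best_model.pth',
#                                                  'models/class_mapping.json')
#   tensor, original = preprocess_image('leaf_photo.jpg')
#   result = predict(model, tensor, classes, device)
#   print(result['phase'], f"{result['confidence']:.2%}")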

View File

@ -0,0 +1,106 @@
"""ResNet model definition for phenology classification."""
import torch
import torch.nn as nn
from torchvision import models
from typing import Optional
class ResNetPhenologyClassifier(nn.Module):
"""ResNet50-based classifier for plant phenology phases."""
def __init__(self, num_classes: int, pretrained: bool = True, freeze_backbone: bool = False):
"""
Initialize the ResNet classifier.
Args:
num_classes: Number of phenology phase classes
pretrained: Whether to use pretrained ImageNet weights
freeze_backbone: Whether to freeze backbone layers
"""
super(ResNetPhenologyClassifier, self).__init__()
# Load pretrained ResNet50
if pretrained:
self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
else:
self.resnet = models.resnet50(weights=None)
# Freeze backbone if specified
if freeze_backbone:
for param in self.resnet.parameters():
param.requires_grad = False
# Replace final fully connected layer
num_features = self.resnet.fc.in_features
self.resnet.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(num_features, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network.
Args:
x: Input tensor of shape (batch_size, 3, 224, 224)
Returns:
Output tensor of shape (batch_size, num_classes)
"""
return self.resnet(x)
def unfreeze_backbone(self):
"""Unfreeze all layers for fine-tuning."""
for param in self.resnet.parameters():
param.requires_grad = True
def get_features(self, x: torch.Tensor) -> torch.Tensor:
"""
Extract features before classification layer.
Args:
x: Input tensor
Returns:
Feature tensor
"""
# Forward through all layers except fc
x = self.resnet.conv1(x)
x = self.resnet.bn1(x)
x = self.resnet.relu(x)
x = self.resnet.maxpool(x)
x = self.resnet.layer1(x)
x = self.resnet.layer2(x)
x = self.resnet.layer3(x)
x = self.resnet.layer4(x)
x = self.resnet.avgpool(x)
x = torch.flatten(x, 1)
return x
def create_model(num_classes: int, pretrained: bool = True, device: Optional[str] = None) -> ResNetPhenologyClassifier:
"""
Create and initialize a ResNet model.
Args:
num_classes: Number of output classes
pretrained: Whether to use pretrained weights
device: Device to place model on ('cuda' or 'cpu')
Returns:
Initialized model
"""
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNetPhenologyClassifier(num_classes, pretrained=pretrained)
model = model.to(device)
return model
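# --- Hedged usage sketch (not part of the original module) ---
# Minimal shape check for the classifier above; uses random weights so it runs
# offline. Expected: logits of shape (batch, num_classes) and 2048-d pooled
# features from the ResNet50 backbone.
def _smoke_test_model():
    model = create_model(num_classes=3, pretrained=False, device='cpu')
    dummy = torch.randn(2, 3, 224, 224)      # two fake RGB images
    logits = model(dummy)                    # expected shape: (2, 3)
    feats = model.get_features(dummy)        # expected shape: (2, 2048)
    assert logits.shape == (2, 3) and feats.shape == (2, 2048)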

View File

@ -0,0 +1,276 @@
"""Training script for ResNet phenology classifier."""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import argparse
import os
from typing import Dict, Tuple
import json
from src.model import create_model
from src.data_loader import get_data_loaders
from src.utils import set_seed, save_checkpoint, plot_training_history
def train_epoch(
model: nn.Module,
dataloader: torch.utils.data.DataLoader,
criterion: nn.Module,
optimizer: optim.Optimizer,
device: str
) -> Tuple[float, float]:
"""
Train for one epoch.
Args:
model: Model to train
dataloader: Training data loader
criterion: Loss function
optimizer: Optimizer
device: Device to train on
Returns:
Tuple of (average loss, accuracy)
"""
model.train()
running_loss = 0.0
correct = 0
total = 0
pbar = tqdm(dataloader, desc='Training')
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward pass
loss.backward()
optimizer.step()
# Statistics
running_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Update progress bar
pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})
epoch_loss = running_loss / total
epoch_acc = 100 * correct / total
return epoch_loss, epoch_acc
def validate(
model: nn.Module,
dataloader: torch.utils.data.DataLoader,
criterion: nn.Module,
device: str
) -> Tuple[float, float]:
"""
Validate the model.
Args:
model: Model to validate
dataloader: Validation data loader
criterion: Loss function
device: Device to validate on
Returns:
Tuple of (average loss, accuracy)
"""
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
pbar = tqdm(dataloader, desc='Validation')
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Statistics
running_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Update progress bar
pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})
epoch_loss = running_loss / total
epoch_acc = 100 * correct / total
return epoch_loss, epoch_acc
def train(
data_dir: str,
csv_file: str,
output_dir: str = 'models',
epochs: int = 50,
batch_size: int = 32,
learning_rate: float = 0.001,
num_workers: int = 4,
random_seed: int = 42,
pretrained: bool = True,
early_stopping_patience: int = 10
):
"""
Main training function.
Args:
data_dir: Directory containing images
csv_file: Path to CSV file with labels
output_dir: Directory to save models
epochs: Number of training epochs
batch_size: Batch size
learning_rate: Learning rate
num_workers: Number of data loader workers
random_seed: Random seed for reproducibility
pretrained: Use pretrained weights
early_stopping_patience: Patience for early stopping
"""
# Set seed for reproducibility
set_seed(random_seed)
# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Load data
print("\nLoading data...")
data_loaders, classes, class_to_idx = get_data_loaders(
csv_file=csv_file,
root_dir=data_dir,
batch_size=batch_size,
num_workers=num_workers,
random_seed=random_seed
)
num_classes = len(classes)
print(f"Number of classes: {num_classes}")
print(f"Classes: {classes}")
# Save class mapping
class_mapping = {'classes': classes, 'class_to_idx': class_to_idx}
with open(os.path.join(output_dir, 'class_mapping.json'), 'w') as f:
json.dump(class_mapping, f, indent=2)
# Create model
print("\nCreating model...")
model = create_model(num_classes, pretrained=pretrained, device=device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
# Training history
train_losses, val_losses = [], []
train_accs, val_accs = [], []
best_val_acc = 0.0
epochs_without_improvement = 0
print("\nStarting training...")
for epoch in range(epochs):
print(f"\nEpoch {epoch + 1}/{epochs}")
print("-" * 50)
# Train
train_loss, train_acc = train_epoch(model, data_loaders['train'], criterion, optimizer, device)
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
# Validate
val_loss, val_acc = validate(model, data_loaders['val'], criterion, device)
print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
# Update learning rate
scheduler.step(val_loss)
# Save history
train_losses.append(train_loss)
val_losses.append(val_loss)
train_accs.append(train_acc)
val_accs.append(val_acc)
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
epochs_without_improvement = 0
save_checkpoint(
model, optimizer, epoch, val_loss, val_acc,
os.path.join(output_dir, 'best_model.pth')
)
print(f"✓ New best model saved! Val Acc: {val_acc:.2f}%")
else:
epochs_without_improvement += 1
# Early stopping
if epochs_without_improvement >= early_stopping_patience:
print(f"\nEarly stopping triggered after {epoch + 1} epochs")
break
# Save final model
save_checkpoint(
model, optimizer, epoch, val_loss, val_acc,
os.path.join(output_dir, 'final_model.pth')
)
# Plot training history
plot_training_history(
train_losses, val_losses, train_accs, val_accs,
save_path=os.path.join(output_dir, 'training_history.png')
)
print(f"\n{'='*50}")
print(f"Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Model saved to: {output_dir}")
print(f"{'='*50}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train ResNet phenology classifier')
parser.add_argument('--data_dir', type=str, required=True, help='Directory containing images')
parser.add_argument('--csv_file', type=str, required=True, help='Path to CSV file with labels')
parser.add_argument('--output_dir', type=str, default='models', help='Directory to save models')
parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
parser.add_argument('--no_pretrained', action='store_true', help='Do not use pretrained weights')
parser.add_argument('--patience', type=int, default=10, help='Early stopping patience')
args = parser.parse_args()
train(
data_dir=args.data_dir,
csv_file=args.csv_file,
output_dir=args.output_dir,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.lr,
num_workers=args.num_workers,
random_seed=args.seed,
pretrained=not args.no_pretrained,
early_stopping_patience=args.patience
)
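# --- Hedged sketch: optional class-weighted loss (not used by train() above) ---
# The dataset exposes get_class_weights() (see the data-loader tests). Assuming the
# train loader wraps a PhenologyDataset directly, class imbalance could be handled
# by weighting the loss:
#
#   weights = data_loaders['train'].dataset.get_class_weights().to(device)
#   criterion = nn.CrossEntropyLoss(weight=weights)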

View File

@ -0,0 +1,183 @@
"""Utility functions for training and preprocessing."""
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional
import os
def set_seed(seed: int = 42):
"""
Set random seed for reproducibility.
Args:
seed: Random seed value
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def save_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
epoch: int,
loss: float,
accuracy: float,
filepath: str
):
"""
Save model checkpoint.
Args:
model: Model to save
optimizer: Optimizer state
epoch: Current epoch
loss: Current loss
accuracy: Current accuracy
filepath: Path to save checkpoint
"""
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)  # handle bare filenames
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'accuracy': accuracy
}
torch.save(checkpoint, filepath)
print(f"Checkpoint saved to {filepath}")
def load_checkpoint(
filepath: str,
model: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
device: str = 'cuda'
) -> dict:
"""
Load model checkpoint.
Args:
filepath: Path to checkpoint file
model: Model to load weights into
optimizer: Optional optimizer to load state into
device: Device to load model on
Returns:
Dictionary with checkpoint information
"""
checkpoint = torch.load(filepath, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"Checkpoint loaded from {filepath}")
print(f"Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f}, Accuracy: {checkpoint['accuracy']:.4f}")
return checkpoint
def plot_training_history(
train_losses: List[float],
val_losses: List[float],
train_accs: List[float],
val_accs: List[float],
save_path: Optional[str] = None
):
"""
Plot training history.
Args:
train_losses: List of training losses
val_losses: List of validation losses
train_accs: List of training accuracies
val_accs: List of validation accuracies
save_path: Optional path to save plot
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# Plot loss
ax1.plot(train_losses, label='Train Loss', marker='o')
ax1.plot(val_losses, label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)
# Plot accuracy
ax2.plot(train_accs, label='Train Accuracy', marker='o')
ax2.plot(val_accs, label='Val Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Training history plot saved to {save_path}")
plt.show()
def visualize_predictions(
images: torch.Tensor,
predictions: torch.Tensor,
labels: torch.Tensor,
class_names: List[str],
num_images: int = 6,
save_path: Optional[str] = None
):
"""
Visualize model predictions on images.
Args:
images: Batch of images
predictions: Model predictions
labels: True labels
class_names: List of class names
num_images: Number of images to display
save_path: Optional path to save visualization
"""
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()
# Denormalize images
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
for idx in range(min(num_images, len(images))):
img = images[idx].cpu() * std + mean
img = img.permute(1, 2, 0).numpy()
img = np.clip(img, 0, 1)
pred_class = class_names[predictions[idx]]
true_class = class_names[labels[idx]]
axes[idx].imshow(img)
axes[idx].axis('off')
color = 'green' if pred_class == true_class else 'red'
axes[idx].set_title(f'Pred: {pred_class}\nTrue: {true_class}', color=color, fontsize=12)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Predictions visualization saved to {save_path}")
plt.show()

View File

@ -0,0 +1,144 @@
"""Unit tests for data loader module."""
import pytest
import torch
import pandas as pd
import os
from PIL import Image
import tempfile
import shutil
from src.data_loader import PhenologyDataset, get_data_loaders
@pytest.fixture
def sample_dataset():
"""Create a temporary dataset for testing."""
temp_dir = tempfile.mkdtemp()
img_dir = os.path.join(temp_dir, 'images')
os.makedirs(img_dir, exist_ok=True)
# Create sample images
phases = ['vegetative', 'flowering', 'fruiting']
image_paths = []
for i, phase in enumerate(phases * 10): # 30 images total
img = Image.new('RGB', (224, 224), color=(i*8, i*8, i*8))
img_path = f'img_{i:03d}.jpg'
img.save(os.path.join(img_dir, img_path))
image_paths.append((img_path, phase))
# Create CSV
csv_path = os.path.join(temp_dir, 'labels.csv')
df = pd.DataFrame(image_paths, columns=['image_path', 'phase'])
df.to_csv(csv_path, index=False)
yield csv_path, img_dir, phases
# Cleanup
shutil.rmtree(temp_dir)
@pytest.mark.unit
def test_dataset_initialization(sample_dataset):
"""Test dataset initialization."""
csv_path, img_dir, phases = sample_dataset
dataset = PhenologyDataset(
csv_file=csv_path,
root_dir=img_dir,
split='train'
)
assert len(dataset) > 0
assert len(dataset.classes) == len(phases)
assert set(dataset.classes) == set(phases)
@pytest.mark.unit
def test_dataset_split_ratios(sample_dataset):
"""Test dataset splitting."""
csv_path, img_dir, _ = sample_dataset
train_dataset = PhenologyDataset(csv_path, img_dir, split='train')
val_dataset = PhenologyDataset(csv_path, img_dir, split='val')
test_dataset = PhenologyDataset(csv_path, img_dir, split='test')
total = len(train_dataset) + len(val_dataset) + len(test_dataset)
# Check approximate split ratios (70/15/15)
assert abs(len(train_dataset) / total - 0.7) < 0.1
assert abs(len(val_dataset) / total - 0.15) < 0.1
assert abs(len(test_dataset) / total - 0.15) < 0.1
@pytest.mark.unit
def test_dataset_getitem(sample_dataset):
"""Test getting items from dataset."""
csv_path, img_dir, _ = sample_dataset
dataset = PhenologyDataset(csv_path, img_dir, split='train')
image, label = dataset[0]
# Without transform, should return PIL Image
assert isinstance(label, int)
assert label >= 0 and label < len(dataset.classes)
@pytest.mark.unit
def test_data_loaders(sample_dataset):
"""Test data loader creation."""
csv_path, img_dir, _ = sample_dataset
data_loaders, classes, class_to_idx = get_data_loaders(
csv_file=csv_path,
root_dir=img_dir,
batch_size=4,
num_workers=0 # Use 0 for testing
)
assert 'train' in data_loaders
assert 'val' in data_loaders
assert 'test' in data_loaders
assert len(classes) > 0
assert len(class_to_idx) == len(classes)
# Test getting a batch
batch = next(iter(data_loaders['train']))
images, labels = batch
assert images.shape[0] <= 4 # Batch size
assert images.shape[1] == 3 # RGB channels
assert images.shape[2] == 224 # Height
assert images.shape[3] == 224 # Width
assert len(labels) == images.shape[0]
@pytest.mark.unit
def test_class_weights(sample_dataset):
"""Test class weight calculation."""
csv_path, img_dir, _ = sample_dataset
dataset = PhenologyDataset(csv_path, img_dir, split='train')
weights = dataset.get_class_weights()
assert weights.shape[0] == len(dataset.classes)
assert torch.all(weights > 0)
assert torch.isclose(weights.sum(), torch.tensor(1.0), atol=1e-6)
@pytest.mark.unit
def test_reproducibility(sample_dataset):
"""Test that same seed produces same splits."""
csv_path, img_dir, _ = sample_dataset
dataset1 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)
dataset2 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)
# Should have same samples
assert len(dataset1) == len(dataset2)
# Check first few samples are the same
for i in range(min(5, len(dataset1))):
_, label1 = dataset1[i]
_, label2 = dataset2[i]
assert label1 == label2

View File

@ -0,0 +1,89 @@
"""Unit tests for evaluation module."""
import pytest
import torch
import numpy as np
from src.evaluate import evaluate_model, plot_confusion_matrix
from src.model import ResNetPhenologyClassifier
@pytest.fixture
def mock_model():
"""Create a small model for testing."""
return ResNetPhenologyClassifier(num_classes=3, pretrained=False)
@pytest.fixture
def mock_dataloader():
"""Create a mock dataloader for testing."""
images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 3, (8,))
dataset = torch.utils.data.TensorDataset(images, labels)
return torch.utils.data.DataLoader(dataset, batch_size=4)
@pytest.mark.unit
def test_evaluate_model(mock_model, mock_dataloader):
"""Test model evaluation."""
device = 'cpu'
model = mock_model.to(device)
class_names = ['class_0', 'class_1', 'class_2']
metrics, cm = evaluate_model(model, mock_dataloader, device, class_names)
# Check metrics exist
assert 'accuracy' in metrics
assert 'recall_macro' in metrics
assert 'f1_macro' in metrics
assert 'per_class_recall' in metrics
assert 'per_class_f1' in metrics
assert 'confusion_matrix' in metrics
# Check metric ranges
assert 0 <= metrics['accuracy'] <= 1
assert 0 <= metrics['recall_macro'] <= 1
assert 0 <= metrics['f1_macro'] <= 1
# Check confusion matrix shape
assert len(metrics['confusion_matrix']) == 3
assert len(metrics['confusion_matrix'][0]) == 3
assert isinstance(cm, np.ndarray)
assert cm.shape == (3, 3)
@pytest.mark.unit
def test_confusion_matrix_values(mock_model, mock_dataloader):
"""Test confusion matrix values."""
device = 'cpu'
model = mock_model.to(device)
class_names = ['class_0', 'class_1', 'class_2']
_, cm = evaluate_model(model, mock_dataloader, device, class_names)
# Confusion matrix should sum to total number of samples
total_samples = len(mock_dataloader.dataset)
assert cm.sum() == total_samples
# All values should be non-negative
assert np.all(cm >= 0)
@pytest.mark.unit
def test_per_class_metrics(mock_model, mock_dataloader):
"""Test per-class metrics."""
device = 'cpu'
model = mock_model.to(device)
class_names = ['class_0', 'class_1', 'class_2']
metrics, _ = evaluate_model(model, mock_dataloader, device, class_names)
# Check per-class metrics have correct number of classes
assert len(metrics['per_class_recall']) == 3
assert len(metrics['per_class_f1']) == 3
# Check all classes have metrics
for cls in class_names:
assert cls in metrics['per_class_recall']
assert cls in metrics['per_class_f1']
assert 0 <= metrics['per_class_recall'][cls] <= 1
assert 0 <= metrics['per_class_f1'][cls] <= 1

View File

@ -0,0 +1,113 @@
"""Unit tests for inference module."""
import pytest
import torch
import json
import tempfile
import os
from PIL import Image
from src.inference import preprocess_image, predict, load_inference_model
from src.model import ResNetPhenologyClassifier
@pytest.fixture
def mock_model():
"""Create a small model for testing."""
return ResNetPhenologyClassifier(num_classes=3, pretrained=False)
@pytest.fixture
def temp_image():
"""Create a temporary test image."""
temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
img = Image.new('RGB', (224, 224), color='red')
img.save(temp_file.name)
yield temp_file.name
os.unlink(temp_file.name)
@pytest.fixture
def temp_class_mapping():
"""Create a temporary class mapping file."""
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
mapping = {
'classes': ['vegetative', 'flowering', 'fruiting'],
'class_to_idx': {'vegetative': 0, 'flowering': 1, 'fruiting': 2}
}
json.dump(mapping, temp_file)
temp_file.flush()
yield temp_file.name
os.unlink(temp_file.name)
@pytest.mark.unit
def test_preprocess_image(temp_image):
"""Test image preprocessing."""
image_tensor, original_image = preprocess_image(temp_image)
assert image_tensor.shape == (1, 3, 224, 224) # Batch size 1, RGB, 224x224
assert isinstance(original_image, Image.Image)
assert original_image.size == (224, 224)
@pytest.mark.unit
def test_predict(mock_model, temp_image):
"""Test prediction."""
device = 'cpu'
model = mock_model.to(device)
model.eval()
classes = ['vegetative', 'flowering', 'fruiting']
image_tensor, _ = preprocess_image(temp_image)
result = predict(model, image_tensor, classes, device)
# Check result structure
assert 'phase' in result
assert 'confidence' in result
assert 'probabilities' in result
# Check values
assert result['phase'] in classes
assert 0 <= result['confidence'] <= 1
assert len(result['probabilities']) == len(classes)
# Check probabilities sum to ~1
prob_sum = sum(result['probabilities'].values())
assert abs(prob_sum - 1.0) < 1e-5
@pytest.mark.unit
def test_predict_confidence(mock_model, temp_image):
"""Test that prediction confidence matches max probability."""
device = 'cpu'
model = mock_model.to(device)
model.eval()
classes = ['vegetative', 'flowering', 'fruiting']
image_tensor, _ = preprocess_image(temp_image)
result = predict(model, image_tensor, classes, device)
# Confidence should match the probability of the predicted class
predicted_prob = result['probabilities'][result['phase']]
assert abs(result['confidence'] - predicted_prob) < 1e-6
@pytest.mark.unit
def test_predict_consistency(mock_model, temp_image):
"""Test that predictions are consistent."""
device = 'cpu'
model = mock_model.to(device)
model.eval()
classes = ['vegetative', 'flowering', 'fruiting']
image_tensor, _ = preprocess_image(temp_image)
# Run prediction twice
result1 = predict(model, image_tensor, classes, device)
result2 = predict(model, image_tensor, classes, device)
# Should get same results
assert result1['phase'] == result2['phase']
assert abs(result1['confidence'] - result2['confidence']) < 1e-6

View File

@ -0,0 +1,93 @@
"""Unit tests for training module."""
import pytest
import torch
import torch.nn as nn
from src.train import train_epoch, validate
from src.model import ResNetPhenologyClassifier
@pytest.fixture
def mock_model():
"""Create a small model for testing."""
return ResNetPhenologyClassifier(num_classes=3, pretrained=False)
@pytest.fixture
def mock_dataloader():
"""Create a mock dataloader for testing."""
# Create dummy data
images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 3, (8,))
dataset = torch.utils.data.TensorDataset(images, labels)
return torch.utils.data.DataLoader(dataset, batch_size=4)
@pytest.mark.unit
def test_train_epoch(mock_model, mock_dataloader):
"""Test training for one epoch."""
device = 'cpu'
model = mock_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss, acc = train_epoch(model, mock_dataloader, criterion, optimizer, device)
assert isinstance(loss, float)
assert isinstance(acc, float)
assert loss >= 0
assert 0 <= acc <= 100
@pytest.mark.unit
def test_validate(mock_model, mock_dataloader):
"""Test validation."""
device = 'cpu'
model = mock_model.to(device)
criterion = nn.CrossEntropyLoss()
loss, acc = validate(model, mock_dataloader, criterion, device)
assert isinstance(loss, float)
assert isinstance(acc, float)
assert loss >= 0
assert 0 <= acc <= 100
@pytest.mark.unit
def test_model_gradients(mock_model):
"""Test that gradients are computed during training."""
device = 'cpu'
model = mock_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
# Create dummy batch
images = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 3, (2,))
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward pass
optimizer.zero_grad()
loss.backward()
# Check gradients exist
has_gradients = any(p.grad is not None for p in model.parameters())
assert has_gradients
@pytest.mark.unit
def test_model_inference_mode(mock_model):
"""Test that model can switch between train and eval modes."""
model = mock_model
# Train mode
model.train()
assert model.training
# Eval mode
model.eval()
assert not model.training

View File

@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
Helper script for training ResNet50 - Nocciola
Comparison with MobileNetV2
"""
import os
import sys
import subprocess
import argparse
# Path configuration
PROJECT_DIR = r'C:\Users\sof12\Desktop\ML'
RESNET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'ResNET.py')
MOBILENET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'MobileNetV1.py')
PYTHON_EXE = r'C:/Users/sof12/AppData/Local/Programs/Python/Python312/python.exe'
# Available datasets
DATASETS = {
'original': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos.csv'),
'filtered': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos_filtered.csv'),
'assignments': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'assignments.csv')
}
def check_files():
"""Verificar que todos los archivos necesarios existen"""
print("🔍 === Verificando archivos ===")
files_to_check = {
'ResNet Script': RESNET_SCRIPT,
'MobileNet Script': MOBILENET_SCRIPT,
'Python Executable': PYTHON_EXE
}
all_ok = True
    for name, path in files_to_check.items():
        if os.path.exists(path):
            print(f"✅ {name}: {path}")
        else:
            print(f"❌ {name}: {path}")
            all_ok = False
    # Check datasets
    print(f"\n📊 === Datasets disponibles ===")
    for name, path in DATASETS.items():
        if os.path.exists(path):
            print(f"✅ {name}: {path}")
        else:
            print(f"❌ {name}: {path}")
return all_ok
def train_resnet50(dataset_key='assignments', epochs=25, force_split=True):
"""Entrenar modelo ResNet50"""
dataset_path = DATASETS.get(dataset_key)
if not dataset_path or not os.path.exists(dataset_path):
print(f"❌ Dataset '{dataset_key}' no encontrado")
return False
print(f"\n🚀 === Entrenando ResNet50 ===")
print(f"📊 Dataset: {dataset_key} ({dataset_path})")
print(f"⏱️ Épocas: {epochs}")
# Training command
cmd = [
PYTHON_EXE, RESNET_SCRIPT,
'--csv_path', dataset_path,
'--epochs', str(epochs)
]
if force_split:
cmd.append('--force_split')
print("🔄 Comando a ejecutar:")
print(" ".join(cmd))
try:
# Change to the correct working directory
os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))
print(f"\n🏋️ Iniciando entrenamiento ResNet50...")
result = subprocess.run(cmd, check=True)
print("✅ Entrenamiento ResNet50 completado exitosamente!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Error durante el entrenamiento ResNet50: {e}")
return False
except KeyboardInterrupt:
print("\n⚠️ Entrenamiento ResNet50 interrumpido por el usuario")
return False
def train_mobilenet(dataset_key='assignments', epochs=25, force_split=True):
"""Entrenar modelo MobileNetV2 para comparación"""
dataset_path = DATASETS.get(dataset_key)
if not dataset_path or not os.path.exists(dataset_path):
print(f"❌ Dataset '{dataset_key}' no encontrado")
return False
print(f"\n🚀 === Entrenando MobileNetV2 (comparación) ===")
print(f"📊 Dataset: {dataset_key} ({dataset_path})")
print(f"⏱️ Épocas: {epochs}")
# Training command
cmd = [
PYTHON_EXE, MOBILENET_SCRIPT,
'--csv_path', dataset_path,
'--epochs', str(epochs)
]
if force_split:
cmd.append('--force_split')
print("🔄 Comando a ejecutar:")
print(" ".join(cmd))
try:
# Change to the correct working directory
os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))
print(f"\n🏋️ Iniciando entrenamiento MobileNetV2...")
result = subprocess.run(cmd, check=True)
print("✅ Entrenamiento MobileNetV2 completado exitosamente!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Error durante el entrenamiento MobileNetV2: {e}")
return False
except KeyboardInterrupt:
print("\n⚠️ Entrenamiento MobileNetV2 interrumpido por el usuario")
return False
def compare_models():
"""Mostrar comparación entre modelos"""
print("""
📊 === Comparación ResNet50 vs MobileNetV2 ===
RESNET50:
🟢 Ventajas:
- Mayor capacidad de aprendizaje (25M parámetros)
- Mejor para datasets complejos
- Residual connections mejoran gradiente
- Excelente para transfer learning
🟡 Desventajas:
- Más lento en entrenamiento e inferencia
- Requiere más memoria
- Overfitting en datasets pequeños
MOBILENETV2:
🟢 Ventajas:
- Rápido y eficiente (3.4M parámetros)
- Menos propenso a overfitting
- Ideal para datasets pequeños/medianos
- Menor uso de memoria
🟡 Desventajas:
- Menor capacidad para patrones complejos
- Puede underfittear en datos complejos
RECOMENDACIONES:
📋 Dataset Nocciola (pequeño): MobileNetV2 recomendado
📋 Dataset grande/complejo: ResNet50 recomendado
📋 Producción/móvil: MobileNetV2
📋 Investigación/precisión máxima: ResNet50
""")
def main():
parser = argparse.ArgumentParser(description='Entrenamiento ResNet50 vs MobileNetV2 para Nocciola')
parser.add_argument('--model', choices=['resnet50', 'mobilenet', 'both'], default='resnet50',
help='Modelo a entrenar')
parser.add_argument('--dataset', choices=['original', 'filtered', 'assignments'], default='assignments',
help='Dataset a usar')
parser.add_argument('--epochs', type=int, default=25,
help='Número de épocas')
parser.add_argument('--no_force_split', action='store_true',
help='No forzar recreación del split')
parser.add_argument('--compare', action='store_true',
help='Mostrar comparación de modelos')
parser.add_argument('--check_only', action='store_true',
help='Solo verificar archivos')
args = parser.parse_args()
print("🎯 === Entrenamiento ResNet50 para Nocciola ===")
print(f"📁 Directorio del proyecto: {PROJECT_DIR}")
print(f"🐍 Python ejecutable: {PYTHON_EXE}")
# Check required files
if not check_files():
print("❌ Faltan archivos necesarios")
return False
if args.check_only:
print(" Solo verificación solicitada. Finalizando.")
return True
if args.compare:
compare_models()
return True
force_split = not args.no_force_split
success = True
# Train the selected model(s)
if args.model in ['resnet50', 'both']:
print(f"\n🤖 === Iniciando ResNet50 ===")
success &= train_resnet50(args.dataset, args.epochs, force_split)
if args.model in ['mobilenet', 'both']:
print(f"\n🤖 === Iniciando MobileNetV2 ===")
success &= train_mobilenet(args.dataset, args.epochs, force_split)
if success:
print(f"\n🎉 === Entrenamiento(s) Completado(s) ===")
# Show where results were saved
results_info = []
if args.model in ['resnet50', 'both']:
resnet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_resnet50_faseV')
results_info.append(f"📁 ResNet50: {resnet_results}")
if args.model in ['mobilenet', 'both']:
mobilenet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_mobilenet_faseV_V1')
results_info.append(f"📁 MobileNetV2: {mobilenet_results}")
for info in results_info:
print(info)
if args.model == 'both':
print(f"\n💡 Comparar resultados:")
print(f" - Revisar classification_report.txt en cada directorio")
print(f" - Comparar matrices de confusión")
print(f" - Evaluar tiempos de entrenamiento")
return True
return False
def show_help():
"""Mostrar ayuda detallada"""
print("""
📚 === Ayuda - ResNet50 Transfer Learning ===
NUEVO MODELO IMPLEMENTADO:
ResNet50 con transfer learning optimizado para el dataset Nocciola
CARACTERÍSTICAS PRINCIPALES:
🔧 Arquitectura: ResNet50 pre-entrenado (ImageNet)
Optimizaciones: BatchNormalization, Dropout adaptativos
🎯 Fine-tuning: Descongelado selectivo de capas residuales
📊 Manejo robusto: Clases desbalanceadas automáticamente
DIFERENCIAS CON MOBILENET:
- Learning rate más conservador (5e-4 vs 1e-3)
- Data augmentation menos agresivo
- Fine-tuning más específico (conv5_x layers)
- Más épocas de fine-tuning (15 vs 10)
EJEMPLOS DE USO:
# Entrenar solo ResNet50
python train_resnet50.py --model resnet50 --epochs 25
# Comparar ambos modelos
python train_resnet50.py --model both --epochs 20
# Análisis de diferencias
python train_resnet50.py --compare
# Verificar setup
python train_resnet50.py --check_only
DATASETS SOPORTADOS:
- assignments: CSV con asignaciones de clases (recomendado)
- filtered: Dataset filtrado sin clases problemáticas
- original: Dataset completo con manejo robusto
OUTPUTS ESPECÍFICOS:
- best_resnet50_model.keras: Mejor modelo durante entrenamiento
- final_resnet50_model.keras: Modelo final
- resnet50_training_history.png: Gráficos específicos
- resnet50_confusion_matrix.png: Matriz de confusión
- resnet50_prediction_examples.png: Ejemplos de predicciones
""")
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] in ['--help', '-h', 'help']:
show_help()
else:
success = main()
if not success:
print("\n❌ Proceso terminado con errores")
print("💡 Usa --help para más información")
else:
print("\n🎉 ¡Proceso completado exitosamente!")

View File

@ -0,0 +1,433 @@
#!/usr/bin/env python3
import os
import re
import json
import argparse
import warnings
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering, MiniBatchKMeans
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import tensorflow as tf
from keras.applications import MobileNetV2, EfficientNetB0
from keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
from keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from keras import backend as K
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
K.set_image_data_format("channels_last")
# -----------------------------
# Utils
# -----------------------------
def set_seed(seed: int = 42):
np.random.seed(seed)
tf.random.set_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
def ensure_dir(path: str):
os.makedirs(path, exist_ok=True)
def _read_csv_any(path: str) -> pd.DataFrame:
for enc in ("utf-8", "utf-8-sig", "latin-1"):
try:
return pd.read_csv(path, encoding=enc)
except UnicodeDecodeError:
continue
    return pd.read_csv(path, encoding="utf-8", encoding_errors="replace")
def _normalize_col_name(name: str) -> str:
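    """Lower-case a column name, drop a trailing '_a'/'_b' suffix and strip separator characters."""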
if not isinstance(name, str):
return ""
s = name.strip().lower()
m = re.match(r"^(.*)_(a|b)$", s)
if m:
s = m.group(1)
for ch in [" ", "_", "-", ".", "/"]:
s = s.replace(ch, "")
return s
def find_matching_cols(df: pd.DataFrame, aliases: List[str]) -> List[str]:
tgt = {_normalize_col_name(a) for a in aliases}
out = []
for c in df.columns:
if _normalize_col_name(c) in tgt:
out.append(c)
return out
def best_filename_from_row(row: pd.Series, img_ext: str = ".jpg") -> Optional[str]:
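    """Return the image filename for a row, trying the known name columns in order and adding img_ext when missing."""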
for key in ["filename", "file_name", "image", "image_name", "New_Name_With_Date", "New_Name", "Nombre_Nuevo", "Old_Name"]:
if key in row and pd.notna(row[key]) and str(row[key]).strip() != "":
fname = str(row[key]).strip()
if not os.path.splitext(fname)[1]:
fname = fname + img_ext
return fname
for key in ["basename_final", "basename"]:
if key in row and pd.notna(row[key]) and str(row[key]).strip() != "":
return f"{row[key]}{img_ext}"
return None
def attach_paths_single_csv(df: pd.DataFrame, images_dir: str, img_ext: str = ".jpg", search_subdirs: bool = False) -> pd.DataFrame:
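    """Resolve each row's image path under images_dir (optionally searching subfolders) and keep only rows whose file exists on disk."""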
paths = []
miss = 0
for _, r in df.iterrows():
fname = best_filename_from_row(r, img_ext)
if not fname:
paths.append((None, None))
miss += 1
continue
p = os.path.join(images_dir, fname)
if not os.path.exists(p) and search_subdirs:
            # search subdirectories for the file
found = None
for root, _, files in os.walk(images_dir):
if fname in files:
found = os.path.join(root, fname)
break
p = found if found else p
paths.append((fname, p if p and isinstance(p, str) and os.path.exists(p) else None))
if paths[-1][1] is None:
miss += 1
if miss:
warnings.warn(f"{miss} archivos listados no existen en disco. Serán ignorados.")
out = df.copy()
out["filename"] = [t[0] for t in paths]
out["path"] = [t[1] for t in paths]
out = out[pd.notna(out["path"])].reset_index(drop=True)
return out
# -----------------------------
# Embeddings
# -----------------------------
def make_preprocess(backbone: str):
return mobilenet_preprocess if backbone == "mobilenet" else efficientnet_preprocess
def make_backbone_model(img_size: int, backbone: str) -> tf.keras.Model:
tf.keras.backend.clear_session()
K.set_image_data_format("channels_last")
input_shape = (img_size, img_size, 3)
if backbone == "efficientnet":
try:
model = EfficientNetB0(include_top=False, weights="imagenet", input_shape=input_shape, pooling="avg")
except Exception as e:
warnings.warn(f"No se pudo cargar EfficientNetB0 con pesos ImageNet ({e}). Se usarán pesos aleatorios.")
model = EfficientNetB0(include_top=False, weights=None, input_shape=input_shape, pooling="avg")
else:
model = MobileNetV2(include_top=False, weights="imagenet", input_shape=input_shape, pooling="avg")
model.trainable = False
return model
def build_dataset(paths: List[str], img_size: int, preprocess_fn, batch_size: int = 64) -> tf.data.Dataset:
ds = tf.data.Dataset.from_tensor_slices(paths)
def _load_tf(p):
x = tf.io.read_file(p)
x = tf.image.decode_jpeg(x, channels=3)
x = tf.image.resize(x, [img_size, img_size], method="bilinear", antialias=True)
x = tf.cast(x, tf.float32)
x = preprocess_fn(x)
return x
return ds.map(_load_tf, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)
def compute_embeddings(model: tf.keras.Model, ds: tf.data.Dataset) -> np.ndarray:
return model.predict(ds, verbose=1)
# -----------------------------
# Reduction + clustering
# -----------------------------
def fit_reduction(train_emb: np.ndarray, n_pca: int = 50):
scaler = StandardScaler()
Xs = scaler.fit_transform(train_emb)
pca = PCA(n_components=min(n_pca, Xs.shape[1]))
Z = pca.fit_transform(Xs)
return scaler, pca, Z
def transform_reduction(emb: np.ndarray, scaler: StandardScaler, pca: PCA) -> np.ndarray:
return pca.transform(scaler.transform(emb))
def _centers_from_labels(X: np.ndarray, y: np.ndarray) -> Optional[np.ndarray]:
cs = []
for c in sorted(set(y)):
if c == -1:
continue
cs.append(X[y == c].mean(axis=0))
return np.array(cs) if cs else None
def tune_dbscan(train_feats: np.ndarray,
metric: str = "euclidean",
min_samples_grid=(3, 5, 10),
quantiles=(0.6, 0.7, 0.8, 0.9)) -> Tuple[Optional[DBSCAN], Optional[np.ndarray], Optional[np.ndarray]]:
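    """
    Heuristic DBSCAN tuning: for each candidate min_samples, take every point's
    k-th nearest-neighbour distance and try several quantiles of that distribution
    as eps; keep the (eps, min_samples) pair with the best silhouette score on the
    non-noise points. Returns (None, None, None) if no setting yields >= 2 clusters.
    """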
best = {"score": -np.inf, "model": None, "labels": None}
for ms in min_samples_grid:
k = max(2, min(ms, len(train_feats)-1))
nbrs = NearestNeighbors(n_neighbors=k, metric=metric).fit(train_feats)
dists, _ = nbrs.kneighbors(train_feats)
kth = np.sort(dists[:, -1])
for q in quantiles:
eps = float(np.quantile(kth, q))
m = DBSCAN(eps=eps, min_samples=ms, metric=metric, n_jobs=-1)
y = m.fit_predict(train_feats)
valid = y[y != -1]
if len(np.unique(valid)) < 2:
continue
try:
score = silhouette_score(train_feats[y != -1], y[y != -1])
except Exception:
score = -np.inf
if score > best["score"]:
best = {"score": score, "model": m, "labels": y}
if best["model"] is None:
return None, None, None
return best["model"], best["labels"], _centers_from_labels(train_feats, best["labels"])
def fit_cluster_algo(kind: str,
n_clusters: int,
train_feats: np.ndarray,
fast_kmeans: bool = True,
dbscan_eps: float = 0.8,
dbscan_min_samples: int = 5,
dbscan_metric: str = "euclidean",
dbscan_auto: bool = False):
if kind == "kmeans":
m = MiniBatchKMeans(n_clusters=n_clusters, batch_size=2048, n_init=10, random_state=42) if fast_kmeans \
else KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
y = m.fit_predict(train_feats)
return m, y, getattr(m, "cluster_centers_", None)
if kind == "dbscan":
if dbscan_auto:
m, y, centers = tune_dbscan(train_feats, metric=dbscan_metric)
if m is None:
warnings.warn("DBSCAN(auto) no encontró ≥2 clusters. Fallback a KMeans.")
km = MiniBatchKMeans(n_clusters=max(n_clusters, 2), batch_size=2048, n_init=10, random_state=42)
y = km.fit_predict(train_feats)
return km, y, km.cluster_centers_
print(f"DBSCAN(auto) seleccionado (metric={dbscan_metric}).")
return m, y, centers
m = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples, metric=dbscan_metric, n_jobs=-1)
y = m.fit_predict(train_feats)
uniq = set(y) - {-1}
if len(uniq) < 2:
warnings.warn(f"DBSCAN devolvió {len(uniq)} cluster(s) válido(s). Considera ajustar eps/min_samples/metric o usar --dbscan_auto.")
return m, y, _centers_from_labels(train_feats, y)
ag = AgglomerativeClustering(n_clusters=n_clusters)
y = ag.fit_predict(train_feats)
centers = _centers_from_labels(train_feats, y)
return ag, y, centers
def assign_to_nearest_centroid(feats: np.ndarray, centers: Optional[np.ndarray]) -> np.ndarray:
if centers is None or len(centers) == 0:
return np.full((feats.shape[0],), -1, dtype=int)
d = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
return np.argmin(d, axis=1)
def internal_metrics(X: np.ndarray, y: np.ndarray) -> dict:
m = y != -1
if m.sum() > 1 and len(np.unique(y[m])) > 1:
return {
"silhouette": float(silhouette_score(X[m], y[m])),
"calinski_harabasz": float(calinski_harabasz_score(X[m], y[m])),
"davies_bouldin": float(davies_bouldin_score(X[m], y[m])),
}
return {"silhouette": None, "calinski_harabasz": None, "davies_bouldin": None}
# -----------------------------
# Plot
# -----------------------------
def plot_scatter_2d(X2d: np.ndarray, labels: np.ndarray, title: str, out_path: str):
plt.figure(figsize=(8, 6))
uniq = np.unique(labels)
if len(uniq) <= 1:
sns.scatterplot(x=X2d[:, 0], y=X2d[:, 1], s=12, linewidth=0, color="#1f77b4", legend=False)
else:
palette = sns.color_palette("tab20", n_colors=len(uniq))
sns.scatterplot(x=X2d[:, 0], y=X2d[:, 1], hue=labels, palette=palette, s=12, linewidth=0, legend=False)
plt.title(title)
plt.tight_layout()
plt.savefig(out_path, dpi=180)
plt.close()
# -----------------------------
# Main
# -----------------------------
def parse_args():
p = argparse.ArgumentParser(description="Unsupervised clustering for Carciofo (single CSV)")
p.add_argument("--images_dir", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF", help="Carpeta que contiene las imágenes")
p.add_argument("--csv_path", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\joined_metadata.csv")
p.add_argument("--out_dir", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\TrainingV2")
p.add_argument("--img_ext", default=".jpg")
p.add_argument("--img_size", type=int, default=224)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--sample", type=int, default=None)
p.add_argument("--search_subdirs", action="store_true", help="Buscar archivos faltantes en subcarpetas")
p.add_argument("--backbone", choices=["mobilenet", "efficientnet"], default="mobilenet")
p.add_argument("--cluster", choices=["kmeans", "dbscan", "agglomerative"], default="kmeans")
p.add_argument("--n_clusters", type=int, default=5)
p.add_argument("--fast_kmeans", action="store_true")
# DBSCAN
p.add_argument("--dbscan_eps", type=float, default=0.8)
p.add_argument("--dbscan_min_samples", type=int, default=5)
p.add_argument("--dbscan_metric", choices=["euclidean", "cosine", "manhattan"], default="euclidean")
p.add_argument("--dbscan_auto", action="store_true")
return p.parse_args()
def main():
args = parse_args()
set_seed(args.seed)
ensure_dir(args.out_dir)
print("Loading CSV...")
df = _read_csv_any(args.csv_path)
print("Resolving filenames and verifying files on disk...")
df = attach_paths_single_csv(df, args.images_dir, img_ext=args.img_ext, search_subdirs=args.search_subdirs)
if len(df) == 0:
print("No images found. Check images_dir and csv_path.")
return
# --- Only 'fase' (Carciofo does not use 'fase V' / 'fase R') ---
phase_cols = find_matching_cols(df, ["fase"])
if phase_cols:
ser_phase = None
for c in phase_cols:
ser_phase = df[c] if ser_phase is None else ser_phase.combine_first(df[c])
df["fase"] = ser_phase
print(f"Using column(s) for 'fase': {phase_cols}")
else:
warnings.warn("No se encontró columna 'fase' en el CSV. No se incluirá en el output.")
# --- end 'fase' handling ---
# Optional sampling
if args.sample is not None and args.sample < len(df):
df = df.sample(n=args.sample, random_state=args.seed).reset_index(drop=True)
# Split indices
print("Splitting train/val/test...")
idx_all = np.arange(len(df))
idx_train, idx_tmp = train_test_split(idx_all, test_size=0.30, random_state=args.seed, shuffle=True)
idx_val, idx_test = train_test_split(idx_tmp, test_size=0.50, random_state=args.seed, shuffle=True)
df_train = df.iloc[idx_train].reset_index(drop=True)
df_val = df.iloc[idx_val].reset_index(drop=True)
df_test = df.iloc[idx_test].reset_index(drop=True)
# Embeddings in one pass
print("Building embedding model...")
preprocess_fn = make_preprocess(args.backbone)
model = make_backbone_model(args.img_size, args.backbone)
print("Computing embeddings (one pass)...")
ds_all = build_dataset(df["path"].tolist(), args.img_size, preprocess_fn, args.batch_size)
emb_all = compute_embeddings(model, ds_all)
emb_train = emb_all[idx_train]
emb_val = emb_all[idx_val]
emb_test = emb_all[idx_test]
# PCA reduction
print("Fitting PCA reduction (50D for clustering, 2D for plots)...")
scaler, pca50, train_50 = fit_reduction(emb_train, n_pca=50)
val_50 = transform_reduction(emb_val, scaler, pca50)
test_50 = transform_reduction(emb_test, scaler, pca50)
pca2 = PCA(n_components=2).fit(scaler.transform(emb_train))
train_2d = pca2.transform(scaler.transform(emb_train))
val_2d = pca2.transform(scaler.transform(emb_val))
test_2d = pca2.transform(scaler.transform(emb_test))
# Clustering
print(f"Clustering with {args.cluster}...")
model_c, y_train, centers = fit_cluster_algo(
args.cluster, args.n_clusters, train_50,
fast_kmeans=args.fast_kmeans,
dbscan_eps=args.dbscan_eps,
dbscan_min_samples=args.dbscan_min_samples,
dbscan_metric=args.dbscan_metric,
dbscan_auto=args.dbscan_auto,
)
if args.cluster == "kmeans":
y_val = model_c.predict(val_50)
y_test = model_c.predict(test_50)
else:
y_val = assign_to_nearest_centroid(val_50, centers)
y_test = assign_to_nearest_centroid(test_50, centers)
# Metrics
print("Computing internal metrics...")
train_m = internal_metrics(train_50, y_train)
val_m = internal_metrics(val_50, y_val)
test_m = internal_metrics(test_50, y_test)
# Save outputs (filename, fase, cluster, split)
print("Saving outputs...")
ensure_dir(args.out_dir)
def pick_min(df_split: pd.DataFrame, y: np.ndarray, split: str) -> pd.DataFrame:
cols = ["filename", "fase"]
keep = [c for c in cols if c in df_split.columns]
out = df_split[keep].copy()
out["cluster"] = y
out["split"] = split
return out
train_out = pick_min(df_train, y_train, "train")
val_out = pick_min(df_val, y_val, "val")
test_out = pick_min(df_test, y_test, "test")
assignments = pd.concat([train_out, val_out, test_out], ignore_index=True)
assignments.to_csv(os.path.join(args.out_dir, "assignments.csv"), index=False, encoding="utf-8")
train_out.to_csv(os.path.join(args.out_dir, "train_assignments.csv"), index=False, encoding="utf-8")
val_out.to_csv(os.path.join(args.out_dir, "val_assignments.csv"), index=False, encoding="utf-8")
test_out.to_csv(os.path.join(args.out_dir, "test_assignments.csv"), index=False, encoding="utf-8")
# Save models
joblib.dump(scaler, os.path.join(args.out_dir, "scaler.joblib"))
joblib.dump(pca50, os.path.join(args.out_dir, "pca50.joblib"))
joblib.dump(pca2, os.path.join(args.out_dir, "pca2.joblib"))
joblib.dump(model_c, os.path.join(args.out_dir, f"{args.cluster}.joblib"))
# Plots
plot_scatter_2d(train_2d, y_train, f"Train clusters ({args.cluster})", os.path.join(args.out_dir, "train_clusters_2d.png"))
plot_scatter_2d(val_2d, y_val, f"Val clusters ({args.cluster})", os.path.join(args.out_dir, "val_clusters_2d.png"))
plot_scatter_2d(test_2d, y_test, f"Test clusters ({args.cluster})", os.path.join(args.out_dir, "test_clusters_2d.png"))
# Summary
summary = {
"counts": {"train": len(df_train), "val": len(df_val), "test": len(df_test)},
"cluster": args.cluster, "n_clusters": args.n_clusters,
"backbone": args.backbone, "img_size": args.img_size,
"internal_metrics": {"train": train_m, "val": val_m, "test": test_m},
"csv": os.path.join(args.out_dir, "assignments.csv"),
}
with open(os.path.join(args.out_dir, "summary.json"), "w", encoding="utf-8") as f:
json.dump(summary, f, indent=2, ensure_ascii=False)
# Optional: save features
np.save(os.path.join(args.out_dir, "features.npy"), emb_all)
np.save(os.path.join(args.out_dir, "feature_paths.npy"), df["path"].to_numpy())
print("Done. Results saved to:", args.out_dir)
if __name__ == "__main__":
main()
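# --- Hedged usage sketch (the script filename is a placeholder) ---
#   python carciofo_clustering.py --backbone mobilenet --cluster dbscan --dbscan_auto
#   python carciofo_clustering.py --cluster kmeans --n_clusters 5 --fast_kmeans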

View File

@ -21,7 +21,7 @@ from sklearn.manifold import TSNE
import umap
# ========== CONFIG ==========
ASSIGNMENTS_CSV = r"C:\Users\sof12\Desktop\ML\Datasets\Nocciola_GBIF\TrainingV7\assignments.csv"
ASSIGNMENTS_CSV = r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\TrainingV2\assignments.csv"
OUT_DIR = os.path.dirname(ASSIGNMENTS_CSV) # donde el pipeline guardó joblibs / features
METHOD = "umap" # 'umap' o 'tsne'
RANDOM_STATE = 42
@ -30,7 +30,7 @@ UMAP_MIN_DIST = 0.1
TSNE_PERPLEXITY = 30
TSNE_ITER = 1000
SAVE_PLOT = True
PLOT_BY = ["cluster", "fase V", "fase R"] # lista de columnas de assignments.csv para colorear (usa lo que tengas)
PLOT_BY = ["cluster", "fase"] # lista de columnas de assignments.csv para colorear (usa lo que tengas)
# ============================
def find_file(patterns, folder):