Supervised Learning models

parent bde2959227
commit a1e046c1ae

Code/Supervised_learning/MobileNetV1.py · 944 lines · new file
@@ -0,0 +1,944 @@
"""
MobileNetV2 transfer learning for phenological phase classification - Nocciola
Adapted for Visual Studio Code
Dataset: Nocciola GBIF
Goal: predict the V phase (vegetative phenological phase)
"""

import os
import shutil
import random
import argparse
import json
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
# ----------------- CONFIG -----------------
PROJECT_PATH = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF'
IMAGES_DIR = PROJECT_PATH  # Images live in the main dataset directory
CSV_PATH = os.path.join(PROJECT_PATH, 'assignments.csv')  # Main metadata CSV
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_mobilenet_faseV_V1')
os.makedirs(OUTPUT_DIR, exist_ok=True)

IMG_SIZE = (224, 224)  # Recommended input size for MobileNetV2
BATCH_SIZE = 16        # Reduced for better stability
SEED = 42
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
FORCE_SPLIT = False
# ----------------- Utilities -----------------
def set_seed(seed=42):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

def analyze_class_distribution(df, column_name='fase V'):
    """Analyze the class distribution and flag imbalances."""
    print(f"\n📊 === Class Distribution Analysis ===")

    # Count samples per class
    counts = df[column_name].value_counts()
    total = len(df)

    print(f"📊 Total samples: {total}")
    print(f"📊 Number of classes: {len(counts)}")
    print(f"📊 Distribution per class:")

    # Detailed per-class statistics
    for clase, count in counts.items():
        percentage = (count / total) * 100
        print(f"   - {clase}: {count} samples ({percentage:.1f}%)")

    # Flag problematic classes
    min_samples = 5  # Recommended minimum threshold
    small_classes = counts[counts < min_samples]

    if len(small_classes) > 0:
        print(f"\n⚠️ Classes with fewer than {min_samples} samples:")
        for clase, count in small_classes.items():
            print(f"   - {clase}: {count} samples")

        print(f"\n💡 Recommendations:")
        print(f"   1. Consider collecting more data for these classes")
        print(f"   2. Or merge similar classes")
        print(f"   3. Or apply class-specific data augmentation")

    return counts, small_classes
def safe_read_csv(path):
    """Read a CSV file with encoding fallbacks."""
    if not os.path.exists(path):
        raise FileNotFoundError(f'CSV not found: {path}')
    try:
        df = pd.read_csv(path, encoding='utf-8')
    except UnicodeDecodeError:
        try:
            df = pd.read_csv(path, encoding='latin-1')
        except Exception:
            df = pd.read_csv(path, encoding='iso-8859-1')
    return df
def resolve_image_path(images_dir, img_id):
    """Resolve the full path of an image from its id.

    Example: 'abc' resolves to '<images_dir>/abc.jpg' if that file exists.
    """
    if pd.isna(img_id) or str(img_id).strip() == '':
        return None

    img_id = str(img_id).strip()

    # Check whether the id already includes an extension and exists
    direct_path = os.path.join(images_dir, img_id)
    if os.path.exists(direct_path):
        return direct_path

    # Try common extensions
    stem = os.path.splitext(img_id)[0]
    for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
        img_path = os.path.join(images_dir, stem + ext)
        if os.path.exists(img_path):
            return img_path

    return None
def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED):
    """Create the train/val/test folder structure expected by flow_from_directory."""
    set_seed(seed)

    # Keep only rows with a valid 'fase V' label and an existing image
    print(f"📊 Initial data: {len(df)} rows")

    # Rows with a valid 'fase V' label
    df_valid = df.dropna(subset=['fase V']).copy()
    df_valid = df_valid[df_valid['fase V'].str.strip() != '']
    print(f"📊 With valid 'fase V': {len(df_valid)} rows")

    # Check that the image files exist
    valid_rows = []
    for _, row in df_valid.iterrows():
        img_path = resolve_image_path(images_dir, row['id_img'])
        if img_path:
            valid_rows.append(row)
        else:
            print(f"⚠️ Image not found: {row['id_img']}")

    if not valid_rows:
        raise ValueError("❌ No valid images were found")

    df_final = pd.DataFrame(valid_rows)
    print(f"📊 With existing images: {len(df_final)} rows")

    # Show the class distribution
    fase_counts = df_final['fase V'].value_counts()
    print(f"\n📊 Phase distribution:")
    for fase, count in fase_counts.items():
        print(f"   - {fase}: {count} images")

    # Drop classes with very few samples (fewer than 3)
    min_samples = 3
    valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
    if len(valid_phases) < len(fase_counts):
        excluded = fase_counts[fase_counts < min_samples].index.tolist()
        print(f"⚠️ Excluding phases with fewer than {min_samples} samples: {excluded}")
        df_final = df_final[df_final['fase V'].isin(valid_phases)]
        print(f"📊 After filtering: {len(df_final)} rows, {len(valid_phases)} classes")

    labels = df_final['fase V'].unique().tolist()
    print(f"📊 Final classes: {labels}")

    # Shuffle and split the data
    df_shuffled = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
    n = len(df_shuffled)
    n_train = int(n * split['train'])
    n_val = int(n * split['val'])

    train_df = df_shuffled.iloc[:n_train]
    val_df = df_shuffled.iloc[n_train:n_train + n_val]
    test_df = df_shuffled.iloc[n_train + n_val:]

    print(f"📊 Final split:")
    print(f"   - Train: {len(train_df)} images")
    print(f"   - Validation: {len(val_df)} images")
    print(f"   - Test: {len(test_df)} images")

    # Create the folder structure
    for part in ['train', 'val', 'test']:
        for label in labels:
            label_dir = os.path.join(out_dir, part, str(label))
            os.makedirs(label_dir, exist_ok=True)

    # Helper to copy images into their subset/class folders
    def copy_subset(subdf, subset_name):
        copied, missing = 0, 0
        for _, row in subdf.iterrows():
            src = resolve_image_path(images_dir, row['id_img'])
            if src:
                fase = str(row['fase V'])
                # Use the id stem so ids that already carry an extension
                # do not produce names like 'x.jpg.jpg'
                stem = os.path.splitext(str(row['id_img']))[0]
                dst = os.path.join(out_dir, subset_name, fase, f"{stem}.jpg")
                try:
                    shutil.copy2(src, dst)
                    copied += 1
                except Exception as e:
                    print(f"⚠️ Error copying {src}: {e}")
                    missing += 1
            else:
                missing += 1

        print(f"✅ {subset_name}: {copied} images copied, {missing} failed")
        return copied

    # Copy the images into their folders
    copy_subset(train_df, 'train')
    copy_subset(val_df, 'val')
    copy_subset(test_df, 'test')

    return train_df, val_df, test_df
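# For reference, the layout produced above (class/id names illustrative) is
# exactly what flow_from_directory() expects:
#
#   <out_dir>/
#     train/<class name>/<id>.jpg
#     val/<class name>/<id>.jpg
#     test/<class name>/<id>.jpg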
def main():
    """Main pipeline entry point."""
    parser = argparse.ArgumentParser(description='MobileNetV2 transfer learning for Nocciola')
    parser.add_argument('--csv_path', type=str, default=CSV_PATH,
                        help='Path to the metadata CSV file')
    parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
                        help='Directory containing the images')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
                        help='Output directory for results')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of training epochs')
    parser.add_argument('--force_split', action='store_true',
                        help='Force recreation of the data split')

    args = parser.parse_args()

    print('\n🚀 === Start of the MobileNetV2 pipeline for Nocciola ===')
    print(f"📁 Images directory: {args.images_dir}")
    print(f"📄 CSV file: {args.csv_path}")
    print(f"📂 Output directory: {args.output_dir}")

    # Set the random seed
    set_seed(SEED)

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the data
    print('\n📊 === Loading data ===')
    df = safe_read_csv(args.csv_path)
    print(f'📊 Total records in CSV: {len(df)}')
    print(f'📊 Available columns: {list(df.columns)}')

    # Check required columns
    required_cols = {'id_img', 'fase V'}
    if not required_cols.issubset(set(df.columns)):
        missing = required_cols - set(df.columns)
        raise ValueError(f'❌ CSV must contain the columns: {missing}')

    # Analyze the class distribution before processing
    analyze_class_distribution(df, 'fase V')

    # Prepare the folder structure
    SPLIT_DIR = os.path.join(args.output_dir, 'data_split')

    if args.force_split and os.path.exists(SPLIT_DIR):
        print("🗑️ Removing existing split...")
        shutil.rmtree(SPLIT_DIR)

    if not os.path.exists(SPLIT_DIR):
        print("\n📁 === Creating a new data split ===")
        train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)

        # Save the split information
        train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False)
        val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False)
        test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False)

    else:
        print("\n♻️ === Reusing existing split ===")
        # Load the split information if available
        try:
            train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
            val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
            test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
        except Exception:
            print("⚠️ Could not load the split files, recreating...")
            train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
    # Create the data generators
    print("\n🔄 === Creating data generators ===")

    # Data augmentation for training
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest'
    )
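    # NOTE (editorial): rescale=1./255 maps pixels to [0, 1]. MobileNetV2's
    # ImageNet weights were trained on inputs scaled to [-1, 1] via
    # tf.keras.applications.mobilenet_v2.preprocess_input; [0, 1] still works
    # here because the classification head is retrained, but preprocess_input
    # is worth trying if accuracy underperforms.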
    # Only normalization for validation and test
    val_test_datagen = ImageDataGenerator(rescale=1./255)

    # Build the generators
    train_gen = train_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'train'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        seed=SEED
    )

    val_gen = val_test_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'val'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )

    test_gen = val_test_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'test'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )

    # Save the class mapping
    class_indices = train_gen.class_indices
    print(f'🏷️ Class mapping: {class_indices}')

    with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
        json.dump(class_indices, f, indent=2)
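    # class_indices maps class-folder names to integer indices in alphanumeric
    # folder order, e.g. (illustrative labels) {'V1': 0, 'V2': 1, 'V3': 2}.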
    print(f"📊 Samples per subset:")
    print(f"   - Train: {train_gen.samples}")
    print(f"   - Validation: {val_gen.samples}")
    print(f"   - Test: {test_gen.samples}")
    print(f"   - Number of classes: {train_gen.num_classes}")

    # Build and train the model
    print("\n🤖 === Building the model ===")

    # MobileNetV2 base model
    base_model = MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=(*IMG_SIZE, 3)
    )
    base_model.trainable = False  # Freeze initially

    # Build the sequential model
    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.3),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(train_gen.num_classes, activation='softmax')
    ])

    # Compile the model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    print("📋 Model summary:")
    model.summary()

    # Compute class weights
    print("\n⚖️ === Computing class weights ===")
    try:
        # Collect the training labels (one pass over the generator)
        train_labels = []
        for i in range(len(train_gen)):
            _, labels = train_gen[i]
            train_labels.extend(np.argmax(labels, axis=1))
            if len(train_labels) >= train_gen.samples:
                break

        # Compute balanced weights
        class_weights = class_weight.compute_class_weight(
            'balanced',
            classes=np.unique(train_labels),
            y=train_labels
        )
        class_weight_dict = dict(zip(np.unique(train_labels), class_weights))
        print(f"⚖️ Class weights: {class_weight_dict}")

    except Exception as e:
        print(f"⚠️ Error computing class weights: {e}")
        class_weight_dict = None
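    # NOTE (editorial sketch): the same weights can be derived without
    # iterating batches (and without re-running augmentation) from the
    # generator's stored labels:
    #
    #   labels = train_gen.classes
    #   w = class_weight.compute_class_weight('balanced',
    #                                         classes=np.unique(labels), y=labels)
    #   class_weight_dict = dict(zip(np.unique(labels), w))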
    # Training callbacks
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=7,
        restore_best_weights=True,
        verbose=1
    )

    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(args.output_dir, 'best_model.keras'),
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-7,
        verbose=1
    )

    callbacks = [early_stopping, model_checkpoint, reduce_lr]

    # Initial training
    print(f"\n🏋️ === Initial training ({args.epochs} epochs) ===")

    try:
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )

        print("✅ Initial training completed")

    except Exception as e:
        print(f"❌ Error during training: {e}")
        # Retry without class weights if something went wrong
        print("🔄 Retrying training without class weights...")
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            verbose=1
        )

    # Fine-tuning
    print("\n🔧 === Fine-tuning ===")

    # Unfreeze part of the base model
    base_model.trainable = True
    fine_tune_at = 100  # Keep the first 100 layers frozen

    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False

    # Recompile with a lower learning rate
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Continue training
    fine_tune_epochs = 10
    total_epochs = len(history.history['loss']) + fine_tune_epochs

    try:
        history_fine = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=total_epochs,
            initial_epoch=len(history.history['loss']),
            callbacks=callbacks,
            verbose=1
        )

        print("✅ Fine-tuning completed")

        # Merge the training histories
        for key in history.history:
            if key in history_fine.history:
                history.history[key].extend(history_fine.history[key])

    except Exception as e:
        print(f"⚠️ Error during fine-tuning: {e}")
        print("Continuing with the initially trained model...")

    # Final evaluation
    print("\n📊 === Evaluation on the test set ===")

    # Load the best saved model
    try:
        model.load_weights(os.path.join(args.output_dir, 'best_model.keras'))
        print("✅ Loaded best saved model")
    except Exception:
        print("⚠️ Using the current model")

    # Save the final model
    model.save(os.path.join(args.output_dir, 'final_model.keras'))
    print("💾 Final model saved")

    # Predictions on the test set
    test_gen.reset()
    y_pred_prob = model.predict(test_gen, verbose=1)
    y_pred = np.argmax(y_pred_prob, axis=1)
    y_true = test_gen.classes

    # Map indices back to class names
    index_to_class = {v: k for k, v in class_indices.items()}

    # Only the classes that actually appear in the test set
    unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
    test_class_names = [index_to_class[i] for i in unique_test_classes]

    print(f"📊 Classes in the test set: {len(unique_test_classes)}")
    print(f"📊 All trained classes: {len(class_indices)}")
    print(f"📊 Classes present in test: {test_class_names}")

    # Check for missing classes
    all_classes = set(range(len(class_indices)))
    test_classes = set(unique_test_classes)
    missing_classes = all_classes - test_classes

    if missing_classes:
        missing_names = [index_to_class[i] for i in missing_classes]
        print(f"⚠️ Classes with no test samples: {missing_names}")

    # Classification report restricted to the observed classes
    print("\n📋 === Classification Report ===")
    try:
        report = classification_report(
            y_true, y_pred,
            labels=unique_test_classes,  # restrict to the observed classes
            target_names=test_class_names,
            output_dict=False,
            zero_division=0  # handle divisions by zero
        )
        print(report)

        # Save the report
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write(f"Evaluated classes: {test_class_names}\n")
            f.write(f"Classes missing from test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n")
            f.write(report)

    except Exception as e:
        print(f"❌ Error in classification_report: {e}")
        print("📊 Generating a fallback report...")

        # Manual report if the automatic one fails
        from collections import Counter
        true_counts = Counter(y_true)
        pred_counts = Counter(y_pred)

        print("\n📊 Manual distribution:")
        print("Class | True | Predicted")
        print("-" * 35)
        for class_idx in unique_test_classes:
            class_name = index_to_class[class_idx]
            true_count = true_counts.get(class_idx, 0)
            pred_count = pred_counts.get(class_idx, 0)
            print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")

        # Basic accuracy
        accuracy = np.mean(y_true == y_pred)
        print(f"\n📊 Overall accuracy: {accuracy:.4f}")

        # Save the manual report
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write("MANUAL CLASSIFICATION REPORT\n")
            f.write("=" * 40 + "\n\n")
            f.write(f"Evaluated classes: {test_class_names}\n")
            f.write(f"Classes missing from test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n")
            f.write("Distribution per class:\n")
            f.write("Class | True | Predicted\n")
            f.write("-" * 35 + "\n")
            for class_idx in unique_test_classes:
                class_name = index_to_class[class_idx]
                true_count = true_counts.get(class_idx, 0)
                pred_count = pred_counts.get(class_idx, 0)
                f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n")
            f.write(f"\nOverall accuracy: {accuracy:.4f}\n")

    # Confusion matrix restricted to the observed classes
    cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
    print(f"\n🔢 Confusion matrix ({len(unique_test_classes)} classes):")
    print(cm)

    np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
               cm, delimiter=',', fmt='%d')

    # Visualizations for the observed classes
    print("\n📈 === Generating visualizations ===")

    # Training curves
    plot_training_history(history, args.output_dir)

    # Confusion matrix plot
    plot_confusion_matrix(cm, test_class_names, args.output_dir)

    # Example predictions
    plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)

    print(f"\n🎉 === Pipeline completed ===")
    print(f"📁 Results saved to: {args.output_dir}")
    print(f"📊 Final test accuracy: {np.mean(y_true == y_pred):.4f}")
    print(f"📊 Evaluated classes: {len(unique_test_classes)}/{len(class_indices)}")

    # Extra information about imbalanced classes
    if missing_classes:
        print(f"\n⚠️ === Class Imbalance Information ===")
        print(f"❌ Classes with no test samples: {len(missing_classes)}")
        for missing_idx in missing_classes:
            missing_name = index_to_class[missing_idx]
            print(f"   - {missing_name} (index {missing_idx})")
        print(f"💡 Suggestion: consider growing the dataset or merging similar classes")
def plot_training_history(history, output_dir):
    """Plot the training history."""
    try:
        plt.figure(figsize=(12, 4))

        # Accuracy
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'], label='Training')
        if 'val_accuracy' in history.history:
            plt.plot(history.history['val_accuracy'], label='Validation')
        plt.title('Model Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)

        # Loss
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'], label='Training')
        if 'val_loss' in history.history:
            plt.plot(history.history['val_loss'], label='Validation')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Training plot saved")

    except Exception as e:
        print(f"⚠️ Error creating the training plot: {e}")

def plot_confusion_matrix(cm, class_names, output_dir):
    """Plot the confusion matrix."""
    try:
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Confusion matrix saved")

    except Exception as e:
        print(f"⚠️ Error creating the confusion matrix: {e}")

def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
    """Show examples of correct and incorrect predictions."""
    try:
        # y_true/y_pred hold original class indices; when class_names is a
        # filtered list, map through unique_classes to recover the right names
        if unique_classes is not None:
            idx_to_name = {int(c): class_names[i] for i, c in enumerate(unique_classes)}
        else:
            idx_to_name = dict(enumerate(class_names))

        # Indices of correct and incorrect predictions
        correct_idx = np.where(y_true == y_pred)[0]
        incorrect_idx = np.where(y_true != y_pred)[0]

        # Pick examples
        n_correct = min(n_examples // 2, len(correct_idx))
        n_incorrect = min(n_examples // 2, len(incorrect_idx))

        selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else []
        selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else []

        selected_indices = np.concatenate([selected_correct, selected_incorrect])

        if len(selected_indices) == 0:
            print("⚠️ No examples to show")
            return

        # Build the figure
        n_show = len(selected_indices)
        cols = 4
        rows = (n_show + cols - 1) // cols

        plt.figure(figsize=(15, 4 * rows))

        for i, idx in enumerate(selected_indices):
            idx = int(idx)
            plt.subplot(rows, cols, i + 1)

            # Load the image from disk
            # Note: this is an approximation; ideally we would access the original images
            img_path = test_gen.filepaths[idx]
            img = plt.imread(img_path)

            plt.imshow(img)
            plt.axis('off')

            true_label = idx_to_name.get(int(y_true[idx]), str(y_true[idx]))
            pred_label = idx_to_name.get(int(y_pred[idx]), str(y_pred[idx]))

            color = 'green' if y_true[idx] == y_pred[idx] else 'red'
            plt.title(f'True: {true_label}\nPredicted: {pred_label}',
                      color=color, fontsize=10)

        plt.suptitle('Prediction Examples (Green = Correct, Red = Incorrect)', fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'prediction_examples.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ Prediction examples saved")

    except Exception as e:
        print(f"⚠️ Error creating prediction examples: {e}")

if __name__ == "__main__":
    main()
# ----------------- MAIN -----------------
# Legacy script-style pipeline. Note: this block sits at module level, so it
# also runs when the file is imported (and after main() when executed).
print('\n=== Start of the pipeline ===')
df = safe_read_csv(CSV_PATH)
print('Total registered images in the CSV:', len(df))

# Check columns
required_cols = {'id_img', 'fase V'}
if not required_cols.issubset(set(df.columns)):
    raise ValueError(f'CSV must contain the columns: {required_cols}')

# Prepare folders
SPLIT_DIR = os.path.join(PROJECT_PATH, 'results_nocc/split_fase V')
if FORCE_SPLIT:
    shutil.rmtree(SPLIT_DIR, ignore_errors=True)

if not os.path.exists(SPLIT_DIR):
    print("Creating a new split...")
    train_df, val_df, test_df = prepare_image_folders(df, IMAGES_DIR, SPLIT_DIR)
else:
    print("Reusing existing split...")
    # Rebuild the dataframes from the existing split directories
    train_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'train')).rglob('*.jpg')], columns=['id_img', 'fase'])
    val_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'val')).rglob('*.jpg')], columns=['id_img', 'fase'])
    test_df = pd.DataFrame([(f.name.split('.')[0], Path(f).parent.name) for f in Path(os.path.join(SPLIT_DIR, 'test')).rglob('*.jpg')], columns=['id_img', 'fase'])


# Data generators
train_datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=20,
                                   width_shift_range=0.1,
                                   height_shift_range=0.1,
                                   shear_range=0.1,
                                   zoom_range=0.1,
                                   horizontal_flip=True,
                                   fill_mode='nearest')

val_test_datagen = ImageDataGenerator(rescale=1./255)

train_gen = train_datagen.flow_from_directory(os.path.join(SPLIT_DIR, 'train'),
                                              target_size=IMG_SIZE,
                                              batch_size=BATCH_SIZE,
                                              class_mode='categorical',
                                              seed=SEED)

val_gen = val_test_datagen.flow_from_directory(os.path.join(SPLIT_DIR, 'val'),
                                               target_size=IMG_SIZE,
                                               batch_size=BATCH_SIZE,
                                               class_mode='categorical',
                                               shuffle=False)

test_gen = val_test_datagen.flow_from_directory(os.path.join(SPLIT_DIR, 'test'),
                                                target_size=IMG_SIZE,
                                                batch_size=BATCH_SIZE,
                                                class_mode='categorical',
                                                shuffle=False)

# Save the class name -> index mapping
class_indices = train_gen.class_indices
print('Class indices:', class_indices)
with open(os.path.join(OUTPUT_DIR, 'class_indices.txt'), 'w') as f:
    json.dump(class_indices, f)

# ----------------- Model (MobileNetV2 transfer learning) - retraining from scratch -----------------
print('\n=== Start of training from scratch ===')

# Define the MobileNetV2 base model
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(*IMG_SIZE, 3))

# Keep the base model frozen initially
base_model.trainable = False

# Build a new sequential model
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.3),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(train_gen.num_classes, activation='softmax')  # Use the correct number of classes
])

# Compile the new model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# Define callbacks
early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
chk = tf.keras.callbacks.ModelCheckpoint(os.path.join(OUTPUT_DIR, 'best_model.keras'), save_best_only=True)

# Compute class weights from the generator's own labels so the keys line up
# with the generator's class indices (the label column name in train_df
# differs between the fresh-split and reused-split branches above)
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_gen.classes),
    y=train_gen.classes
)
class_weights = dict(zip(np.unique(train_gen.classes), class_weights))
print("Class weights for training:", class_weights)

# Train the new model
EPOCHS = 45  # Original number of epochs for the first phase
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=[early, chk],
    class_weight=class_weights
)

# --- FINE-TUNING ---
print('\n=== Start of fine-tuning phase ===')

# Unfreeze the last layers of the base model
base_model.trainable = True
fine_tune_at = 100  # Adjustable; here the first 100 layers stay frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile the model with a lower learning rate
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# Continue training
fine_tune_epochs = 10
total_epochs = EPOCHS + fine_tune_epochs
history_fine_tune = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=total_epochs,
    initial_epoch=history.epoch[-1],
    callbacks=[early, chk]  # Reuse the same callbacks
)

# ----------------- Test-set evaluation -----------------
print('\n=== Evaluation on the test set ===')

# Load the best weights saved during training (initial or fine-tuning phase)
model.load_weights(os.path.join(OUTPUT_DIR, 'best_model.keras'))

# Predictions
y_pred_prob = model.predict(test_gen)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_gen.classes

# Load the full class-index mapping
with open(os.path.join(OUTPUT_DIR, 'class_indices.txt'), 'r') as f:
    full_class_indices = json.load(f)

# Class names for all classes from the loaded mapping
index_to_class = {v: k for k, v in full_class_indices.items()}
all_class_names = [index_to_class[i] for i in sorted(index_to_class.keys())]

# Unique class indices present in the test set
unique_test_indices = np.unique(y_true)

# Class names for the unique test indices
test_labels_filtered = [index_to_class[i] for i in unique_test_indices]


# Report
report = classification_report(y_true, y_pred, labels=unique_test_indices, target_names=test_labels_filtered)
cm = confusion_matrix(y_true, y_pred, labels=unique_test_indices)

print('\nClassification Report:\n', report)
print('\nConfusion Matrix:\n', cm)

with open(os.path.join(OUTPUT_DIR, 'classification_report.txt'), 'w') as f:
    f.write(report)
np.savetxt(os.path.join(OUTPUT_DIR, 'confusion_matrix.csv'), cm, delimiter=',', fmt='%d')

# ----------------- Visualizations -----------------

def show_examples(test_gen, y_true, y_pred, labels, n=6):
    filepaths = list(test_gen.filepaths)
    # Select examples
    correct_idx = [i for i, (a, b) in enumerate(zip(y_true, y_pred)) if a == b]
    wrong_idx = [i for i, (a, b) in enumerate(zip(y_true, y_pred)) if a != b]
    examples = correct_idx[:n // 2] + wrong_idx[:n // 2]

    plt.figure(figsize=(15, 8))
    for i, idx in enumerate(examples):
        img = plt.imread(filepaths[idx])
        plt.subplot(2, n // 2, i + 1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'True: {labels[y_true[idx]]}\nPred: {labels[y_pred[idx]]}')
    plt.suptitle('Examples: Right and Wrong')
    plt.show()

# Call the function with the variables from the evaluation step
show_examples(test_gen, y_true, y_pred, all_class_names, n=6)

# Accuracy curves
plt.figure(figsize=(8, 4))
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='val_acc')
# Include the fine-tuning history if it exists
if 'history_fine_tune' in locals():
    plt.plot(history_fine_tune.history['accuracy'], label='fine_tune_train_acc')
    plt.plot(history_fine_tune.history['val_accuracy'], label='fine_tune_val_acc')

plt.title('Accuracy during training')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid()
plt.show()

# Heatmap of the confusion matrix
plt.figure(figsize=(8, 6))
# Tick labels must follow the label order used to build cm
sns.heatmap(cm, annot=True, fmt='d', xticklabels=test_labels_filtered, yticklabels=test_labels_filtered, cmap='Blues')
plt.xlabel('Prediction')
plt.ylabel('True')
plt.title('Confusion matrix')
plt.show()
Code/Supervised_learning/README_ResNet50.md · 268 lines · new file
@@ -0,0 +1,268 @@
# ResNet50 Transfer Learning for Nocciola - Documentation

## 🚀 **Complete ResNet50 Implementation**

This project implements transfer learning with ResNet50 for classifying nocciola phenological phases, based on the proven structure of MobileNetV1.py but tuned specifically for ResNet50.

## 📁 **File Structure**

```
Code/Supervised_learning/
├── ResNET.py              # Main ResNet50 script ✅ NEW
├── train_resnet50.py      # Convenience training script ✅ NEW
├── MobileNetV1.py         # MobileNetV2 script (proven reference)
└── README_ResNet50.md     # This documentation ✅ NEW
```

## 🔧 **ResNet50-Specific Features**

### **Optimized Architecture:**
```python
# ResNet50 base model
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Custom layers tuned for ResNet50
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    BatchNormalization(),            # ✅ Important for ResNet50
    Dropout(0.5),                    # ✅ Higher dropout
    Dense(256, activation='relu'),   # ✅ More units
    BatchNormalization(),
    Dropout(0.3),
    Dense(num_classes, activation='softmax')
])
```

### **Specific Optimizations:**

1. **Conservative learning rate:**
   - Initial: `5e-4` (vs `1e-3` in MobileNet)
   - Fine-tuning: `1e-6` (very low, for stability)

2. **Reduced data augmentation:**
   - Rotation: 15° (vs 20° in MobileNet)
   - Shifts/zoom: 0.08 (vs 0.1 in MobileNet)
   - More conservative, to avoid overfitting

3. **Selective fine-tuning:**
   - Only the last residual stage (conv5_x)
   - Cut point: layer 140 (of 175 total)
   - Preserves generic features, adapts task-specific ones

4. **Adapted callbacks (see the sketch below):**
   - EarlyStopping patience: 10 (vs 7 in MobileNet)
   - More fine-tuning epochs: 15 (vs 10)
   - ReduceLR factor: 0.3 (more aggressive)
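A minimal sketch of the training-phase settings listed above, assuming the same `tf.keras` APIs used in `ResNET.py` (`output_dir` is a placeholder; the ReduceLR patience and `min_lr` are assumed, not stated above):

```python
import os
import tensorflow as tf

output_dir = 'results_resnet50'  # placeholder

# Conservative initial learning rate for the frozen-base phase
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

# Callbacks with the patience/factor values described above
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                     restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.3,
                                         patience=3, min_lr=1e-7),  # patience/min_lr assumed
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(output_dir, 'best_resnet50_model.keras'),
        monitor='val_loss', save_best_only=True),
]
```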
## 📊 **ResNet50 vs MobileNetV2 Comparison**

| Feature | ResNet50 | MobileNetV2 |
|---------|----------|-------------|
| **Parameters** | ~25M | ~3.4M |
| **Memory** | High | Low |
| **Speed** | Slow | Fast |
| **Accuracy** | High | Medium-high |
| **Overfitting** | More prone | Less prone |
| **Ideal dataset** | Large/complex | Small/medium |

## 🎯 **Recommended Use Cases**

### **Use ResNet50 when:**
- ✅ The dataset has enough samples (>1000)
- ✅ Images contain complex patterns
- ✅ Accuracy matters more than speed
- ✅ Robust hardware is available
- ✅ Research / detailed analysis

### **Use MobileNetV2 when:**
- ✅ The dataset is small (<500 samples)
- ✅ Speed matters
- ✅ Resources are limited
- ✅ Production / mobile deployment
- ✅ Rapid prototyping

## 🚀 **Running**

### **Option 1: Convenience Script (Recommended)**
```bash
# Train ResNet50 only
python train_resnet50.py --model resnet50 --epochs 25

# Compare both models
python train_resnet50.py --model both --epochs 20

# Show the detailed comparison
python train_resnet50.py --compare
```

### **Option 2: Direct Execution**
```bash
# Basic ResNet50
python ResNET.py --epochs 25 --force_split

# ResNet50 with a specific dataset
python ResNET.py --csv_path "assignments.csv" --epochs 30
```

### **Available Parameters:**
- `--csv_path`: Path to the CSV (default: assignments.csv)
- `--images_dir`: Images directory
- `--output_dir`: Results directory (default: results_resnet50_faseV)
- `--epochs`: Training epochs (default: 25)
- `--force_split`: Recreate the data split

## 📈 **ResNet50 Training Process**

### **Phase 1: Initial Training**
1. Base model frozen (ImageNet weights)
2. Only the custom head is trained
3. Conservative learning rate (5e-4)
4. Callbacks with increased patience

### **Phase 2: Selective Fine-tuning** (see the sketch below)
1. Unfreeze only the conv5_x layers (last residual stage)
2. Very low learning rate (1e-6)
3. 15 additional epochs
4. Strict overfitting monitoring

### **Robust Evaluation:**
1. Automatic handling of imbalanced classes
2. ResNet50-specific metrics
3. Model-specific visualizations
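A sketch of the selective fine-tuning step described above, under the stated assumption that the cut sits at layer 140 of ResNet50's ~175 layers (roughly where the conv5_x stage begins); the 5-class head is illustrative:

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

base_model = ResNet50(weights='imagenet', include_top=False,
                      input_shape=(224, 224, 3))
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(5, activation='softmax'),  # class count is illustrative
])

# Selective fine-tuning: unfreeze, then re-freeze everything below the cut
base_model.trainable = True
fine_tune_at = 140  # assumed cut point (conv5_x region) from the text above
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False  # keep conv1-conv4 frozen, adapt conv5_x only

# Recompile after changing trainability, with the very low fine-tuning LR
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),
              loss='categorical_crossentropy', metrics=['accuracy'])
```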
## 📂 **Generated Outputs**

### **Models:**
- `best_resnet50_model.keras`: Best model seen during training
- `final_resnet50_model.keras`: Final complete model

### **Reports:**
- `classification_report.txt`: Detailed report tagged as ResNet50
- `confusion_matrix.csv`: Numeric matrix

### **Visualizations:**
- `resnet50_training_history.png`: Training curves
- `resnet50_confusion_matrix.png`: Visual matrix
- `resnet50_prediction_examples.png`: Example predictions

### **Data Splits:**
- `train_split.csv`, `val_split.csv`, `test_split.csv`
- `class_indices.json`: Class mapping

## ⚙️ **Technical Configuration**

### **System Requirements:**
- RAM: at least 8 GB (16 GB recommended)
- GPU: optional but strongly recommended
- Disk: ~3 GB for the full set of results

### **Dependencies:**
```bash
tensorflow>=2.8.0
scikit-learn
pandas
numpy
matplotlib
seaborn
```

### **Internal Configuration:**
```python
IMG_SIZE = (224, 224)  # ImageNet standard
BATCH_SIZE = 16        # Reduced for ResNet50
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
```
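For a dataset of around 500 images (the size mentioned below for the current Nocciola data), the 70/15/15 split works out as in this small sketch; the remainder after integer truncation lands in the test subset, mirroring the logic in `prepare_image_folders`:

```python
n = 500                       # illustrative dataset size
n_train = int(n * 0.70)       # 350
n_val = int(n * 0.15)         # 75
n_test = n - n_train - n_val  # 75 (whatever is left over)
print(n_train, n_val, n_test)  # 350 75 75
```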
## 🔍 **Implemented Differences vs MobileNet**

### **1. Architecture:**
- Extra BatchNormalization layers
- More aggressive dropout (0.5 vs 0.3)
- Larger dense layer (256 vs 128)

### **2. Training:**
- More conservative learning rates
- More selective fine-tuning
- More fine-tuning epochs

### **3. Data Augmentation:**
- Smaller rotations (15° vs 20°)
- Reduced shifts (0.08 vs 0.1)
- Less aggressive overall for ResNet50

### **4. Callbacks:**
- Increased patience (10 vs 7)
- More aggressive ReduceLR factor (0.3 vs 0.2)
- ResNet-specific monitoring

### **5. Outputs:**
- Differentiated names (`resnet50_*`)
- Reports tagged with the model name
- Model-specific metrics

## 🎯 **Usage Recommendations**

### **For the Current Nocciola Dataset:**
Given that the dataset is relatively small (~500 samples):

1. **First choice:** MobileNetV2 (better suited)
2. **Second choice:** ResNet50 with strong regularization
3. **Comparison:** train both and compare the results

### **Recommended Command:**
```bash
# Compare both models with a small number of epochs
python train_resnet50.py --model both --epochs 15

# Analyze the results and pick the better model
```

## 🚨 **Troubleshooting**

### **Overfitting in ResNet50:**
- Reduce the number of training epochs
- Increase dropout
- Use the filtered dataset
- Add more data augmentation

### **Underfitting:**
- Increase epochs
- Reduce regularization
- Use a higher learning rate
- Unfreeze more layers

### **Memory problems:**
- Reduce BATCH_SIZE to 8 or 4
- Use gradient checkpointing
- Close other applications

## 📊 **Interpreting the Results**

### **Expected Metrics:**
- **ResNet50:** higher precision, possible overfitting
- **MobileNetV2:** more generalizable, less overfitting

### **Visual Comparison:**
- Smoother training curves for MobileNet
- Possible train/val gap for ResNet50
- Similar or better confusion matrix for ResNet50

### **Final Decision:**
Choose the model based on:
1. Accuracy on the test set
2. Train/validation gap
3. Production requirements
4. Interpretability of the errors

---

## 🎉 **ResNet50 Successfully Implemented!**

The ResNet50 model is fully implemented using the same robust structure as MobileNetV1.py, with ResNet50-specific optimizations and automatic handling of imbalanced classes.

**To get started:** `python train_resnet50.py --model resnet50 --epochs 20`
Code/Supervised_learning/ResNET.py · 741 lines · new file
@@ -0,0 +1,741 @@
"""
ResNet50 transfer learning for phenological phase classification - Nocciola
Adapted for Visual Studio Code, based on the proven MobileNetV1.py
Dataset: Nocciola GBIF
Goal: predict the V phase (vegetative phenological phase)
"""

# =============================================================================
# LIBRARY IMPORTS
# =============================================================================

# System and file handling libraries
import os        # Operating system interface (paths, directories)
import shutil    # High-level file operations (copy, move, delete)
import argparse  # Command-line argument parsing
import random    # Random number generation for reproducibility
import pandas as pd  # Data manipulation and analysis
import numpy as np   # Numerical computing with multidimensional arrays
from pathlib import Path  # Modern path handling
import json

# Data visualization libraries
import matplotlib.pyplot as plt  # Plotting and visualization
import seaborn as sns            # Statistical data visualization

# Deep learning libraries (TensorFlow and Keras)
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
# ----------------- CONFIG -----------------
PROJECT_PATH = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi'
IMAGES_DIR = r'C:\Users\sof12\Desktop\ML\Datasets\Nocciola\combi'  # Images live in the main dataset directory
CSV_PATH = os.path.join(PROJECT_PATH, 'assignments.csv')  # Main metadata CSV
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_resnet50')
os.makedirs(OUTPUT_DIR, exist_ok=True)

IMG_SIZE = (224, 224)  # Recommended for ResNet50 (standard ImageNet size)
BATCH_SIZE = 16        # Reduced for ResNet50 (heavier than MobileNet)
SEED = 42
SPLIT = {'train': 0.7, 'val': 0.15, 'test': 0.15}
FORCE_SPLIT = False
# ----------------- Utilities -----------------
def set_seed(seed=42):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

def analyze_class_distribution(df, column_name='fase V'):
    """Analyze the class distribution and flag imbalances."""
    print(f"\n📊 === Class Distribution Analysis ===")

    # Count samples per class
    counts = df[column_name].value_counts()
    total = len(df)

    print(f"📊 Total samples: {total}")
    print(f"📊 Number of classes: {len(counts)}")
    print(f"📊 Distribution per class:")

    # Detailed per-class statistics
    for clase, count in counts.items():
        percentage = (count / total) * 100
        print(f"   - {clase}: {count} samples ({percentage:.1f}%)")

    # Flag problematic classes
    min_samples = 5  # Recommended minimum threshold
    small_classes = counts[counts < min_samples]

    if len(small_classes) > 0:
        print(f"\n⚠️ Classes with fewer than {min_samples} samples:")
        for clase, count in small_classes.items():
            print(f"   - {clase}: {count} samples")

        print(f"\n💡 Recommendations:")
        print(f"   1. Consider collecting more data for these classes")
        print(f"   2. Or merge similar classes")
        print(f"   3. Or apply class-specific data augmentation")

    return counts, small_classes

def safe_read_csv(path):
    """Read a CSV file with encoding fallbacks."""
    if not os.path.exists(path):
        raise FileNotFoundError(f'CSV not found: {path}')
    try:
        df = pd.read_csv(path, encoding='utf-8')
    except UnicodeDecodeError:
        try:
            df = pd.read_csv(path, encoding='latin-1')
        except Exception:
            df = pd.read_csv(path, encoding='cp1252')

    return df

def resolve_image_path(images_dir, img_id):
    """Resolve the full path of an image from its id."""
    if pd.isna(img_id) or str(img_id).strip() == '':
        return None

    img_id = str(img_id).strip()

    # Check whether the id already includes an extension and exists
    direct_path = os.path.join(images_dir, img_id)
    if os.path.exists(direct_path):
        return direct_path

    # Try common extensions
    stem = os.path.splitext(img_id)[0]
    for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
        img_path = os.path.join(images_dir, stem + ext)
        if os.path.exists(img_path):
            return img_path

    return None
def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED):
    """Create the train/val/test folder structure expected by flow_from_directory."""
    set_seed(seed)

    # Keep only rows with a valid 'fase V' label and an existing image
    print(f"📊 Initial data: {len(df)} rows")

    # Rows with a valid 'fase V' label
    df_valid = df.dropna(subset=['fase V']).copy()
    df_valid = df_valid[df_valid['fase V'].str.strip() != '']
    print(f"📊 With valid 'fase V': {len(df_valid)} rows")

    # Check that the image files exist
    valid_rows = []
    for _, row in df_valid.iterrows():
        img_path = resolve_image_path(images_dir, row['id_img'])
        if img_path:
            valid_rows.append(row)
        else:
            print(f"⚠️ Image not found: {row['id_img']}")

    if not valid_rows:
        raise ValueError("❌ No valid images were found")

    df_final = pd.DataFrame(valid_rows)
    print(f"📊 With existing images: {len(df_final)} rows")

    # Show the class distribution
    fase_counts = df_final['fase V'].value_counts()
    print(f"\n📊 Phase distribution:")
    for fase, count in fase_counts.items():
        print(f"   - {fase}: {count} images")

    # Drop classes with very few samples (fewer than 3)
    min_samples = 3
    valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
    if len(valid_phases) < len(fase_counts):
        excluded = fase_counts[fase_counts < min_samples].index.tolist()
        print(f"⚠️ Excluding phases with fewer than {min_samples} samples: {excluded}")
        df_final = df_final[df_final['fase V'].isin(valid_phases)]
        print(f"📊 After filtering: {len(df_final)} rows, {len(valid_phases)} classes")

    labels = df_final['fase V'].unique().tolist()
    print(f"📊 Final classes: {labels}")

    # Shuffle and split the data
    df_shuffled = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
    n = len(df_shuffled)
    n_train = int(n * split['train'])
    n_val = int(n * split['val'])

    train_df = df_shuffled.iloc[:n_train]
    val_df = df_shuffled.iloc[n_train:n_train + n_val]
    test_df = df_shuffled.iloc[n_train + n_val:]

    print(f"📊 Final split:")
    print(f"   - Train: {len(train_df)} images")
    print(f"   - Validation: {len(val_df)} images")
    print(f"   - Test: {len(test_df)} images")

    # Create the folder structure
    for part in ['train', 'val', 'test']:
        for label in labels:
            label_dir = os.path.join(out_dir, part, str(label))
            os.makedirs(label_dir, exist_ok=True)

    # Helper to copy images into their subset/class folders
    def copy_subset(subdf, subset_name):
        copied, missing = 0, 0
        for _, row in subdf.iterrows():
            src = resolve_image_path(images_dir, row['id_img'])
            if src:
                fase = str(row['fase V'])
                # Use the id stem so ids that already carry an extension
                # do not produce names like 'x.jpg.jpg'
                stem = os.path.splitext(str(row['id_img']))[0]
                dst = os.path.join(out_dir, subset_name, fase, f"{stem}.jpg")
                try:
                    shutil.copy2(src, dst)
                    copied += 1
                except Exception as e:
                    print(f"⚠️ Error copying {src}: {e}")
                    missing += 1
            else:
                missing += 1

        print(f"✅ {subset_name}: {copied} images copied, {missing} failed")
        return copied

    # Copy the images into their folders
    copy_subset(train_df, 'train')
    copy_subset(val_df, 'val')
    copy_subset(test_df, 'test')

    return train_df, val_df, test_df
def main():
|
||||
"""Función principal del pipeline ResNet50"""
|
||||
parser = argparse.ArgumentParser(description='ResNet50 Transfer Learning para Nocciola')
|
||||
parser.add_argument('--csv_path', type=str, default=CSV_PATH,
|
||||
help='Ruta al archivo CSV con metadatos')
|
||||
parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
|
||||
help='Directorio con las imágenes')
|
||||
parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
|
||||
help='Directorio de salida para resultados')
|
||||
parser.add_argument('--epochs', type=int, default=25,
|
||||
help='Número de épocas de entrenamiento')
|
||||
parser.add_argument('--force_split', action='store_true',
|
||||
help='Forzar recreación del split de datos')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print('\n🚀 === Inicio del pipeline ResNet50 para Nocciola ===')
|
||||
print(f"📁 Directorio de imágenes: {args.images_dir}")
|
||||
print(f"📄 Archivo CSV: {args.csv_path}")
|
||||
print(f"📂 Directorio de salida: {args.output_dir}")
|
||||
|
||||
# Establecer semilla
|
||||
set_seed(SEED)
|
||||
|
||||
# Crear directorio de salida
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Leer datos
|
||||
print('\n📊 === Cargando datos ===')
|
||||
df = safe_read_csv(args.csv_path)
|
||||
print(f'📊 Total de registros en CSV: {len(df)}')
|
||||
print(f'📊 Columnas disponibles: {list(df.columns)}')
|
||||
|
||||
# Verificar columnas requeridas
|
||||
required_cols = {'id_img', 'fase V'}
|
||||
if not required_cols.issubset(set(df.columns)):
|
||||
missing = required_cols - set(df.columns)
|
||||
raise ValueError(f'❌ CSV debe contener las columnas: {missing}')
|
||||
|
||||
# Analizar distribución de clases antes del procesamiento
|
||||
analyze_class_distribution(df, 'fase V')
|
||||
|
||||
    # Prepare the folder structure
    SPLIT_DIR = os.path.join(args.output_dir, 'data_split')

    if args.force_split and os.path.exists(SPLIT_DIR):
        print("🗑️ Removing existing split...")
        shutil.rmtree(SPLIT_DIR)

    if not os.path.exists(SPLIT_DIR):
        print("\n📁 === Creating a new data split ===")
        train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)

        # Save the split information
        train_df.to_csv(os.path.join(args.output_dir, 'train_split.csv'), index=False)
        val_df.to_csv(os.path.join(args.output_dir, 'val_split.csv'), index=False)
        test_df.to_csv(os.path.join(args.output_dir, 'test_split.csv'), index=False)

    else:
        print("\n♻️ === Reusing the existing split ===")
        # Load the split information if it exists
        try:
            train_df = pd.read_csv(os.path.join(args.output_dir, 'train_split.csv'))
            val_df = pd.read_csv(os.path.join(args.output_dir, 'val_split.csv'))
            test_df = pd.read_csv(os.path.join(args.output_dir, 'test_split.csv'))
        except Exception:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
            print("⚠️ Could not load the split files, recreating...")
            train_df, val_df, test_df = prepare_image_folders(df, args.images_dir, SPLIT_DIR)
    # Create the data generators
    print("\n🔄 === Creating data generators ===")

    # Data augmentation for training (more conservative for ResNet50)
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=15,  # less rotation than for MobileNet
        width_shift_range=0.08,
        height_shift_range=0.08,
        shear_range=0.08,
        zoom_range=0.08,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # Only normalization for validation and test
    val_test_datagen = ImageDataGenerator(rescale=1./255)
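    # Note (assumption worth checking): ResNet50's ImageNet weights were
    # trained with tf.keras.applications.resnet50.preprocess_input rather
    # than plain 1/255 scaling, so a sketch of an alternative setup would be:
    #   from tensorflow.keras.applications.resnet50 import preprocess_input
    #   train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input, ...)
    # Whichever convention is chosen must be applied identically at training
    # and inference time.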
    # Create the generators
    train_gen = train_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'train'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        seed=SEED
    )

    val_gen = val_test_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'val'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )

    test_gen = val_test_datagen.flow_from_directory(
        os.path.join(SPLIT_DIR, 'test'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )
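    # Note: shuffle=False on the validation and test generators keeps the
    # sample order fixed, so predictions can later be compared element-wise
    # against each generator's .classes attribute.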
    # Save the class mapping
    class_indices = train_gen.class_indices
    print(f'🏷️ Class mapping: {class_indices}')

    with open(os.path.join(args.output_dir, 'class_indices.json'), 'w') as f:
        json.dump(class_indices, f, indent=2)

    print("📊 Samples per subset:")
    print(f"  - Training: {train_gen.samples}")
    print(f"  - Validation: {val_gen.samples}")
    print(f"  - Test: {test_gen.samples}")
    print(f"  - Number of classes: {train_gen.num_classes}")
    # Build and train the ResNet50 model
    print("\n🤖 === Building the ResNet50 model ===")

    # ResNet50 base model (imported here because it is not among the
    # module-level imports; without it the script raises a NameError)
    from tensorflow.keras.applications import ResNet50

    base_model = ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=(*IMG_SIZE, 3)
    )
    base_model.trainable = False  # freeze initially

    # Sequential classification head tuned for ResNet50
    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),  # important for ResNet50
        layers.Dropout(0.5),  # higher dropout for ResNet50
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(train_gen.num_classes, activation='softmax')
    ])

    # Compile the model with a lower learning rate for ResNet50
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),  # lower LR
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    print("📋 ResNet50 model summary:")
    model.summary()
    # Compute class weights
    print("\n⚖️ === Computing class weights ===")
    try:
        # Collect the training labels
        train_labels = []
        for i in range(len(train_gen)):
            _, labels = train_gen[i]
            train_labels.extend(np.argmax(labels, axis=1))
            if len(train_labels) >= train_gen.samples:
                break

        # Compute balanced weights
        class_weights = class_weight.compute_class_weight(
            'balanced',
            classes=np.unique(train_labels),
            y=train_labels
        )
        class_weight_dict = dict(zip(np.unique(train_labels), class_weights))
        print(f"⚖️ Class weights: {class_weight_dict}")

    except Exception as e:
        print(f"⚠️ Error computing class weights: {e}")
        class_weight_dict = None
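    # Note: for flow_from_directory generators, train_gen.classes already
    # holds every sample's integer label, so the batch loop above could be
    # reduced to a single line (sketch, same result without decoding
    # augmented batches):
    #   train_labels = train_gen.classes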
    # Callbacks for ResNet50 training
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,  # more patience for ResNet50
        restore_best_weights=True,
        verbose=1
    )

    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(args.output_dir, 'best_resnet50_model.keras'),
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.3,  # more aggressive reduction
        patience=5,  # less patience before lowering the LR
        min_lr=1e-8,
        verbose=1
    )

    callbacks = [early_stopping, model_checkpoint, reduce_lr]
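    # Note: with restore_best_weights=True the in-memory model already ends
    # at the best val_loss epoch after early stopping; the checkpoint file
    # adds a durable on-disk copy of that same best model.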
    # Initial training
    print(f"\n🏋️ === Initial ResNet50 training ({args.epochs} epochs) ===")

    try:
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            class_weight=class_weight_dict,
            verbose=1
        )

        print("✅ Initial training completed")

    except Exception as e:
        print(f"❌ Error during training: {e}")
        # Train without class_weight if there are problems
        print("🔄 Retrying training without class weights...")
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            verbose=1
        )
    # ResNet50-specific fine-tuning
    print("\n🔧 === Fine-tuning ResNet50 ===")

    # Unfreeze specific ResNet50 layers
    base_model.trainable = True

    # For ResNet50, unfreeze only the last residual blocks;
    # the conv5_x blocks are the most task-specific ones
    fine_tune_at = 140  # roughly from conv5_block1 onwards

    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False

    print(f"🔧 Trainable layers: {len([l for l in base_model.layers if l.trainable])}/{len(base_model.layers)}")
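    # Note (sketch, not verified against every Keras release): layer counts
    # can shift between versions, so the hard-coded index 140 above is
    # fragile. Locating the conv5 stage by layer name is more robust:
    #   fine_tune_at = next(i for i, l in enumerate(base_model.layers)
    #                       if l.name.startswith('conv5_block1'))
    # Also worth knowing: once base_model.trainable is True, the unfrozen
    # BatchNormalization layers update their statistics again, which can
    # destabilize fine-tuning on small datasets.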
    # Recompile with a very low learning rate for fine-tuning
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-6),  # very low LR for ResNet50
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Continue training
    fine_tune_epochs = 15  # extra epochs for fine-tuning
    total_epochs = len(history.history['loss']) + fine_tune_epochs
    try:
        history_fine = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=total_epochs,
            initial_epoch=len(history.history['loss']),
            callbacks=callbacks,
            verbose=1
        )

        print("✅ Fine-tuning completed")

        # Merge the training histories
        for key in history.history:
            if key in history_fine.history:
                history.history[key].extend(history_fine.history[key])

    except Exception as e:
        print(f"⚠️ Error during fine-tuning: {e}")
        print("Continuing with the model from the initial training...")
    # Final evaluation
    print("\n📊 === Evaluation on the test set ===")

    # Load the best model
    try:
        model.load_weights(os.path.join(args.output_dir, 'best_resnet50_model.keras'))
        print("✅ Loaded the best saved ResNet50 model")
    except Exception:  # narrower than a bare except:
        print("⚠️ Using the current model")

    # Save the final model
    model.save(os.path.join(args.output_dir, 'final_resnet50_model.keras'))
    print("💾 Final ResNet50 model saved")

    # Predictions on the test set
    test_gen.reset()
    y_pred_prob = model.predict(test_gen, verbose=1)
    y_pred = np.argmax(y_pred_prob, axis=1)
    y_true = test_gen.classes
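    # Note: y_true from test_gen.classes lines up with y_pred index-by-index
    # only because the test generator was built with shuffle=False.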
    # Map indices to class names
    index_to_class = {v: k for k, v in class_indices.items()}

    # Keep only the classes that actually appear in the test set
    unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
    test_class_names = [index_to_class[i] for i in unique_test_classes]

    print(f"📊 Classes in the test set: {len(unique_test_classes)}")
    print(f"📊 All trained classes: {len(class_indices)}")
    print(f"📊 Classes present in test: {test_class_names}")

    # Check whether any classes are missing
    all_classes = set(range(len(class_indices)))
    test_classes = set(unique_test_classes)
    missing_classes = all_classes - test_classes

    if missing_classes:
        missing_names = [index_to_class[i] for i in missing_classes]
        print(f"⚠️ Classes with no samples in test: {missing_names}")
    # Classification report restricted to the classes present in test
    print("\n📋 === ResNet50 Classification Report ===")
    try:
        report = classification_report(
            y_true, y_pred,
            labels=unique_test_classes,
            target_names=test_class_names,
            output_dict=False,
            zero_division=0
        )
        print(report)

        # Save the report
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write("Model: ResNet50\n")
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing from test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n")
            f.write(report)

    except Exception as e:
        print(f"❌ Error in classification_report: {e}")
        print("📊 Generating a fallback report...")

        # Manual report if the automatic one fails
        from collections import Counter
        true_counts = Counter(y_true)
        pred_counts = Counter(y_pred)

        print("\n📊 Manual distribution:")
        print("Class | True | Predicted")
        print("-" * 35)
        for class_idx in unique_test_classes:
            class_name = index_to_class[class_idx]
            true_count = true_counts.get(class_idx, 0)
            pred_count = pred_counts.get(class_idx, 0)
            print(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")

        # Basic accuracy
        accuracy = np.mean(y_true == y_pred)
        print(f"\n📊 Overall ResNet50 accuracy: {accuracy:.4f}")

        # Save the manual report
        with open(os.path.join(args.output_dir, 'classification_report.txt'), 'w') as f:
            f.write("MANUAL CLASSIFICATION REPORT - ResNet50\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing from test: {[index_to_class[i] for i in missing_classes] if missing_classes else 'None'}\n\n")
            f.write("Distribution per class:\n")
            f.write("Class | True | Predicted\n")
            f.write("-" * 35 + "\n")
            for class_idx in unique_test_classes:
                class_name = index_to_class[class_idx]
                true_count = true_counts.get(class_idx, 0)
                pred_count = pred_counts.get(class_idx, 0)
                f.write(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}\n")
            f.write(f"\nOverall ResNet50 accuracy: {accuracy:.4f}\n")
    # Confusion matrix restricted to the classes present in test
    cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
    print(f"\n🔢 ResNet50 Confusion Matrix ({len(unique_test_classes)} classes):")
    print(cm)

    np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
               cm, delimiter=',', fmt='%d')

    # Visualizations with the filtered classes
    print("\n📈 === Generating ResNet50 visualizations ===")

    # Training curves
    plot_training_history(history, args.output_dir)

    # Confusion matrix plot with the filtered classes
    plot_confusion_matrix(cm, test_class_names, args.output_dir)

    # Example predictions with the filtered classes
    plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)
print(f"\n🎉 === Pipeline ResNet50 completado ===")
|
||||
print(f"📁 Resultados guardados en: {args.output_dir}")
|
||||
print(f"📊 Precisión final en test: {np.mean(y_true == y_pred):.4f}")
|
||||
print(f"📊 Clases evaluadas: {len(unique_test_classes)}/{len(class_indices)}")
|
||||
|
||||
# Información adicional sobre clases desbalanceadas
|
||||
if missing_classes:
|
||||
print(f"\n⚠️ === Información sobre Clases Desbalanceadas ===")
|
||||
print(f"❌ Clases sin muestras en test: {len(missing_classes)}")
|
||||
for missing_idx in missing_classes:
|
||||
missing_name = index_to_class[missing_idx]
|
||||
print(f" - {missing_name} (índice {missing_idx})")
|
||||
print(f"💡 Sugerencia: Considera aumentar el dataset o fusionar clases similares")
|
||||
|
||||
def plot_training_history(history, output_dir):
    """Plot the ResNet50 training history"""
    try:
        plt.figure(figsize=(12, 4))

        # Accuracy
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'], label='Training', linewidth=2)
        if 'val_accuracy' in history.history:
            plt.plot(history.history['val_accuracy'], label='Validation', linewidth=2)
        plt.title('ResNet50 Model Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'], label='Training', linewidth=2)
        if 'val_loss' in history.history:
            plt.plot(history.history['val_loss'], label='Validation', linewidth=2)
        plt.title('ResNet50 Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_training_history.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ ResNet50 training plot saved")

    except Exception as e:
        print(f"⚠️ Error creating the training plot: {e}")
def plot_confusion_matrix(cm, class_names, output_dir):
    """Plot the ResNet50 confusion matrix"""
    try:
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    cbar_kws={'label': 'Number of samples'})
        plt.title('Confusion Matrix - ResNet50')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_confusion_matrix.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ ResNet50 confusion matrix saved")

    except Exception as e:
        print(f"⚠️ Error creating the confusion matrix: {e}")
def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
    """Show examples of correct and incorrect ResNet50 predictions"""
    try:
        # Indices of correct and incorrect predictions
        correct_idx = np.where(y_true == y_pred)[0]
        incorrect_idx = np.where(y_true != y_pred)[0]

        # Select examples
        n_correct = min(n_examples // 2, len(correct_idx))
        n_incorrect = min(n_examples // 2, len(incorrect_idx))

        selected_correct = np.random.choice(correct_idx, n_correct, replace=False) if len(correct_idx) > 0 else []
        selected_incorrect = np.random.choice(incorrect_idx, n_incorrect, replace=False) if len(incorrect_idx) > 0 else []

        selected_indices = np.concatenate([selected_correct, selected_incorrect])

        if len(selected_indices) == 0:
            print("⚠️ No examples to show")
            return

        # Map raw class indices to display names; class_names is aligned with
        # unique_classes (the filtered test classes), not with raw indices,
        # so indexing it directly with y_true/y_pred values would be wrong
        if unique_classes is not None:
            idx_to_name = {int(c): class_names[i] for i, c in enumerate(unique_classes)}
        else:
            idx_to_name = dict(enumerate(class_names))

        # Build the figure
        n_show = len(selected_indices)
        cols = 4
        rows = (n_show + cols - 1) // cols

        plt.figure(figsize=(15, 4 * rows))

        for i, idx in enumerate(selected_indices):
            idx = int(idx)  # np.concatenate can yield floats when one side is an empty list
            plt.subplot(rows, cols, i + 1)

            # Load the image
            try:
                img_path = test_gen.filepaths[idx]
                img = plt.imread(img_path)

                plt.imshow(img)
                plt.axis('off')

                true_label = idx_to_name[y_true[idx]]
                pred_label = idx_to_name[y_pred[idx]]

                color = 'green' if y_true[idx] == y_pred[idx] else 'red'
                plt.title(f'True: {true_label}\nPredicted: {pred_label}',
                          color=color, fontsize=9, fontweight='bold')
            except Exception as e:
                plt.text(0.5, 0.5, f'Error loading\nimage {idx}',
                         ha='center', va='center', transform=plt.gca().transAxes)
                plt.axis('off')

        plt.suptitle('ResNet50 Prediction Examples (Green=Correct, Red=Incorrect)', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'resnet50_prediction_examples.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ ResNet50 prediction examples saved")

    except Exception as e:
        print(f"⚠️ Error creating the prediction examples: {e}")


if __name__ == "__main__":
    main()
184
Code/Supervised_learning/resnet/.github/prompts/speckit.analyze.prompt.md
vendored
Normal file
@ -0,0 +1,184 @@
---
description: Perform a non-destructive cross-artifact consistency and quality analysis across spec.md, plan.md, and tasks.md after task generation.
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Goal

Identify inconsistencies, duplications, ambiguities, and underspecified items across the three core artifacts (`spec.md`, `plan.md`, `tasks.md`) before implementation. This command MUST run only after `/speckit.tasks` has successfully produced a complete `tasks.md`.

## Operating Constraints

**STRICTLY READ-ONLY**: Do **not** modify any files. Output a structured analysis report. Offer an optional remediation plan (user must explicitly approve before any follow-up editing commands would be invoked manually).

**Constitution Authority**: The project constitution (`.specify/memory/constitution.md`) is **non-negotiable** within this analysis scope. Constitution conflicts are automatically CRITICAL and require adjustment of the spec, plan, or tasks—not dilution, reinterpretation, or silent ignoring of the principle. If a principle itself needs to change, that must occur in a separate, explicit constitution update outside `/speckit.analyze`.

## Execution Steps

### 1. Initialize Analysis Context

Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks` once from repo root and parse JSON for FEATURE_DIR and AVAILABLE_DOCS. Derive absolute paths:

- SPEC = FEATURE_DIR/spec.md
- PLAN = FEATURE_DIR/plan.md
- TASKS = FEATURE_DIR/tasks.md

Abort with an error message if any required file is missing (instruct the user to run the missing prerequisite command).
For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

### 2. Load Artifacts (Progressive Disclosure)

Load only the minimal necessary context from each artifact:

**From spec.md:**

- Overview/Context
- Functional Requirements
- Non-Functional Requirements
- User Stories
- Edge Cases (if present)

**From plan.md:**

- Architecture/stack choices
- Data Model references
- Phases
- Technical constraints

**From tasks.md:**

- Task IDs
- Descriptions
- Phase grouping
- Parallel markers [P]
- Referenced file paths

**From constitution:**

- Load `.specify/memory/constitution.md` for principle validation

### 3. Build Semantic Models

Create internal representations (do not include raw artifacts in output):

- **Requirements inventory**: Each functional + non-functional requirement with a stable key (derive slug based on imperative phrase; e.g., "User can upload file" → `user-can-upload-file`)
- **User story/action inventory**: Discrete user actions with acceptance criteria
- **Task coverage mapping**: Map each task to one or more requirements or stories (inference by keyword / explicit reference patterns like IDs or key phrases)
- **Constitution rule set**: Extract principle names and MUST/SHOULD normative statements

### 4. Detection Passes (Token-Efficient Analysis)

Focus on high-signal findings. Limit to 50 findings total; aggregate the remainder in an overflow summary.

#### A. Duplication Detection

- Identify near-duplicate requirements
- Mark lower-quality phrasing for consolidation

#### B. Ambiguity Detection

- Flag vague adjectives (fast, scalable, secure, intuitive, robust) lacking measurable criteria
- Flag unresolved placeholders (TODO, TKTK, ???, `<placeholder>`, etc.)

#### C. Underspecification

- Requirements with verbs but missing object or measurable outcome
- User stories missing acceptance criteria alignment
- Tasks referencing files or components not defined in spec/plan

#### D. Constitution Alignment

- Any requirement or plan element conflicting with a MUST principle
- Missing mandated sections or quality gates from constitution

#### E. Coverage Gaps

- Requirements with zero associated tasks
- Tasks with no mapped requirement/story
- Non-functional requirements not reflected in tasks (e.g., performance, security)

#### F. Inconsistency

- Terminology drift (same concept named differently across files)
- Data entities referenced in plan but absent in spec (or vice versa)
- Task ordering contradictions (e.g., integration tasks before foundational setup tasks without a dependency note)
- Conflicting requirements (e.g., one requires Next.js while another specifies Vue)

### 5. Severity Assignment

Use this heuristic to prioritize findings:

- **CRITICAL**: Violates a constitution MUST, missing core spec artifact, or requirement with zero coverage that blocks baseline functionality
- **HIGH**: Duplicate or conflicting requirement, ambiguous security/performance attribute, untestable acceptance criterion
- **MEDIUM**: Terminology drift, missing non-functional task coverage, underspecified edge case
- **LOW**: Style/wording improvements, minor redundancy not affecting execution order

### 6. Produce Compact Analysis Report

Output a Markdown report (no file writes) with the following structure:

## Specification Analysis Report

| ID | Category | Severity | Location(s) | Summary | Recommendation |
|----|----------|----------|-------------|---------|----------------|
| A1 | Duplication | HIGH | spec.md:L120-134 | Two similar requirements ... | Merge phrasing; keep clearer version |

(Add one row per finding; generate stable IDs prefixed by category initial.)

**Coverage Summary Table:**

| Requirement Key | Has Task? | Task IDs | Notes |
|-----------------|-----------|----------|-------|

**Constitution Alignment Issues:** (if any)

**Unmapped Tasks:** (if any)

**Metrics:**

- Total Requirements
- Total Tasks
- Coverage % (requirements with >=1 task)
- Ambiguity Count
- Duplication Count
- Critical Issues Count

### 7. Provide Next Actions

At the end of the report, output a concise Next Actions block:

- If CRITICAL issues exist: Recommend resolving before `/speckit.implement`
- If only LOW/MEDIUM: User may proceed, but provide improvement suggestions
- Provide explicit command suggestions: e.g., "Run /speckit.specify with refinement", "Run /speckit.plan to adjust architecture", "Manually edit tasks.md to add coverage for 'performance-metrics'"

### 8. Offer Remediation

Ask the user: "Would you like me to suggest concrete remediation edits for the top N issues?" (Do NOT apply them automatically.)

## Operating Principles

### Context Efficiency

- **Minimal high-signal tokens**: Focus on actionable findings, not exhaustive documentation
- **Progressive disclosure**: Load artifacts incrementally; don't dump all content into analysis
- **Token-efficient output**: Limit the findings table to 50 rows; summarize overflow
- **Deterministic results**: Rerunning without changes should produce consistent IDs and counts

### Analysis Guidelines

- **NEVER modify files** (this is read-only analysis)
- **NEVER hallucinate missing sections** (if absent, report them accurately)
- **Prioritize constitution violations** (these are always CRITICAL)
- **Use examples over exhaustive rules** (cite specific instances, not generic patterns)
- **Report zero issues gracefully** (emit a success report with coverage statistics)

## Context

$ARGUMENTS
294
Code/Supervised_learning/resnet/.github/prompts/speckit.checklist.prompt.md
vendored
Normal file
@ -0,0 +1,294 @@
---
description: Generate a custom checklist for the current feature based on user requirements.
---

## Checklist Purpose: "Unit Tests for English"

**CRITICAL CONCEPT**: Checklists are **UNIT TESTS FOR REQUIREMENTS WRITING** - they validate the quality, clarity, and completeness of requirements in a given domain.

**NOT for verification/testing**:

- ❌ NOT "Verify the button clicks correctly"
- ❌ NOT "Test error handling works"
- ❌ NOT "Confirm the API returns 200"
- ❌ NOT checking if code/implementation matches the spec

**FOR requirements quality validation**:

- ✅ "Are visual hierarchy requirements defined for all card types?" (completeness)
- ✅ "Is 'prominent display' quantified with specific sizing/positioning?" (clarity)
- ✅ "Are hover state requirements consistent across all interactive elements?" (consistency)
- ✅ "Are accessibility requirements defined for keyboard navigation?" (coverage)
- ✅ "Does the spec define what happens when logo image fails to load?" (edge cases)

**Metaphor**: If your spec is code written in English, the checklist is its unit test suite. You're testing whether the requirements are well-written, complete, unambiguous, and ready for implementation - NOT whether the implementation works.

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Execution Steps

1. **Setup**: Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json` from repo root and parse JSON for FEATURE_DIR and the AVAILABLE_DOCS list.
   - All file paths must be absolute.
   - For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

2. **Clarify intent (dynamic)**: Derive up to THREE initial contextual clarifying questions (no pre-baked catalog). They MUST:
   - Be generated from the user's phrasing + extracted signals from spec/plan/tasks
   - Only ask about information that materially changes checklist content
   - Be skipped individually if already unambiguous in `$ARGUMENTS`
   - Prefer precision over breadth

   Generation algorithm:
   1. Extract signals: feature domain keywords (e.g., auth, latency, UX, API), risk indicators ("critical", "must", "compliance"), stakeholder hints ("QA", "review", "security team"), and explicit deliverables ("a11y", "rollback", "contracts").
   2. Cluster signals into candidate focus areas (max 4) ranked by relevance.
   3. Identify probable audience & timing (author, reviewer, QA, release) if not explicit.
   4. Detect missing dimensions: scope breadth, depth/rigor, risk emphasis, exclusion boundaries, measurable acceptance criteria.
   5. Formulate questions chosen from these archetypes:
      - Scope refinement (e.g., "Should this include integration touchpoints with X and Y or stay limited to local module correctness?")
      - Risk prioritization (e.g., "Which of these potential risk areas should receive mandatory gating checks?")
      - Depth calibration (e.g., "Is this a lightweight pre-commit sanity list or a formal release gate?")
      - Audience framing (e.g., "Will this be used by the author only or peers during PR review?")
      - Boundary exclusion (e.g., "Should we explicitly exclude performance tuning items this round?")
      - Scenario class gap (e.g., "No recovery flows detected—are rollback / partial failure paths in scope?")

   Question formatting rules:
   - If presenting options, generate a compact table with columns: Option | Candidate | Why It Matters
   - Limit to A–E options maximum; omit the table if a free-form answer is clearer
   - Never ask the user to restate what they already said
   - Avoid speculative categories (no hallucination). If uncertain, ask explicitly: "Confirm whether X belongs in scope."

   Defaults when interaction is impossible:
   - Depth: Standard
   - Audience: Reviewer (PR) if code-related; Author otherwise
   - Focus: Top 2 relevance clusters

   Output the questions (label Q1/Q2/Q3). After answers: if ≥2 scenario classes (Alternate / Exception / Recovery / Non-Functional domain) remain unclear, you MAY ask up to TWO more targeted follow‑ups (Q4/Q5) with a one-line justification each (e.g., "Unresolved recovery path risk"). Do not exceed five total questions. Skip escalation if the user explicitly declines more.

3. **Understand user request**: Combine `$ARGUMENTS` + clarifying answers:
   - Derive the checklist theme (e.g., security, review, deploy, ux)
   - Consolidate explicit must-have items mentioned by the user
   - Map focus selections to category scaffolding
   - Infer any missing context from spec/plan/tasks (do NOT hallucinate)

4. **Load feature context**: Read from FEATURE_DIR:
   - spec.md: Feature requirements and scope
   - plan.md (if it exists): Technical details, dependencies
   - tasks.md (if it exists): Implementation tasks

   **Context Loading Strategy**:
   - Load only the portions relevant to active focus areas (avoid full-file dumping)
   - Prefer summarizing long sections into concise scenario/requirement bullets
   - Use progressive disclosure: add follow-on retrieval only if gaps are detected
   - If source docs are large, generate interim summary items instead of embedding raw text

5. **Generate checklist** - Create "Unit Tests for Requirements":
   - Create the `FEATURE_DIR/checklists/` directory if it doesn't exist
   - Generate a unique checklist filename:
     - Use a short, descriptive name based on the domain (e.g., `ux.md`, `api.md`, `security.md`)
     - Format: `[domain].md`
     - If the file exists, append to the existing file
   - Number items sequentially starting from CHK001
   - Each `/speckit.checklist` run creates a NEW file (never overwrites existing checklists)

   **CORE PRINCIPLE - Test the Requirements, Not the Implementation**:
   Every checklist item MUST evaluate the REQUIREMENTS THEMSELVES for:
   - **Completeness**: Are all necessary requirements present?
   - **Clarity**: Are requirements unambiguous and specific?
   - **Consistency**: Do requirements align with each other?
   - **Measurability**: Can requirements be objectively verified?
   - **Coverage**: Are all scenarios/edge cases addressed?

   **Category Structure** - Group items by requirement quality dimensions:
   - **Requirement Completeness** (Are all necessary requirements documented?)
   - **Requirement Clarity** (Are requirements specific and unambiguous?)
   - **Requirement Consistency** (Do requirements align without conflicts?)
   - **Acceptance Criteria Quality** (Are success criteria measurable?)
   - **Scenario Coverage** (Are all flows/cases addressed?)
   - **Edge Case Coverage** (Are boundary conditions defined?)
   - **Non-Functional Requirements** (Performance, Security, Accessibility, etc. - are they specified?)
   - **Dependencies & Assumptions** (Are they documented and validated?)
   - **Ambiguities & Conflicts** (What needs clarification?)

   **HOW TO WRITE CHECKLIST ITEMS - "Unit Tests for English"**:

   ❌ **WRONG** (Testing implementation):
   - "Verify landing page displays 3 episode cards"
   - "Test hover states work on desktop"
   - "Confirm logo click navigates home"

   ✅ **CORRECT** (Testing requirements quality):
   - "Are the exact number and layout of featured episodes specified?" [Completeness]
   - "Is 'prominent display' quantified with specific sizing/positioning?" [Clarity]
   - "Are hover state requirements consistent across all interactive elements?" [Consistency]
   - "Are keyboard navigation requirements defined for all interactive UI?" [Coverage]
   - "Is the fallback behavior specified when logo image fails to load?" [Edge Cases]
   - "Are loading states defined for asynchronous episode data?" [Completeness]
   - "Does the spec define visual hierarchy for competing UI elements?" [Clarity]

   **ITEM STRUCTURE**:
   Each item should follow this pattern:
   - Question format asking about requirement quality
   - Focus on what's WRITTEN (or not written) in the spec/plan
   - Include the quality dimension in brackets [Completeness/Clarity/Consistency/etc.]
   - Reference spec section `[Spec §X.Y]` when checking existing requirements
   - Use the `[Gap]` marker when checking for missing requirements

   **EXAMPLES BY QUALITY DIMENSION**:

   Completeness:
   - "Are error handling requirements defined for all API failure modes? [Gap]"
   - "Are accessibility requirements specified for all interactive elements? [Completeness]"
   - "Are mobile breakpoint requirements defined for responsive layouts? [Gap]"

   Clarity:
   - "Is 'fast loading' quantified with specific timing thresholds? [Clarity, Spec §NFR-2]"
   - "Are 'related episodes' selection criteria explicitly defined? [Clarity, Spec §FR-5]"
   - "Is 'prominent' defined with measurable visual properties? [Ambiguity, Spec §FR-4]"

   Consistency:
   - "Do navigation requirements align across all pages? [Consistency, Spec §FR-10]"
   - "Are card component requirements consistent between landing and detail pages? [Consistency]"

   Coverage:
   - "Are requirements defined for zero-state scenarios (no episodes)? [Coverage, Edge Case]"
   - "Are concurrent user interaction scenarios addressed? [Coverage, Gap]"
   - "Are requirements specified for partial data loading failures? [Coverage, Exception Flow]"

   Measurability:
   - "Are visual hierarchy requirements measurable/testable? [Acceptance Criteria, Spec §FR-1]"
   - "Can 'balanced visual weight' be objectively verified? [Measurability, Spec §FR-2]"

   **Scenario Classification & Coverage** (Requirements Quality Focus):
   - Check if requirements exist for: Primary, Alternate, Exception/Error, Recovery, Non-Functional scenarios
   - For each scenario class, ask: "Are [scenario type] requirements complete, clear, and consistent?"
   - If a scenario class is missing: "Are [scenario type] requirements intentionally excluded or missing? [Gap]"
   - Include resilience/rollback when state mutation occurs: "Are rollback requirements defined for migration failures? [Gap]"

   **Traceability Requirements**:
   - MINIMUM: ≥80% of items MUST include at least one traceability reference
   - Each item should reference: spec section `[Spec §X.Y]`, or use markers: `[Gap]`, `[Ambiguity]`, `[Conflict]`, `[Assumption]`
   - If no ID system exists: "Is a requirement & acceptance criteria ID scheme established? [Traceability]"

   **Surface & Resolve Issues** (Requirements Quality Problems):
   Ask questions about the requirements themselves:
   - Ambiguities: "Is the term 'fast' quantified with specific metrics? [Ambiguity, Spec §NFR-1]"
   - Conflicts: "Do navigation requirements conflict between §FR-10 and §FR-10a? [Conflict]"
   - Assumptions: "Is the assumption of an 'always available podcast API' validated? [Assumption]"
   - Dependencies: "Are external podcast API requirements documented? [Dependency, Gap]"
   - Missing definitions: "Is 'visual hierarchy' defined with measurable criteria? [Gap]"

   **Content Consolidation**:
   - Soft cap: if raw candidate items > 40, prioritize by risk/impact
   - Merge near-duplicates checking the same requirement aspect
   - If >5 low-impact edge cases, create one item: "Are edge cases X, Y, Z addressed in requirements? [Coverage]"

   **🚫 ABSOLUTELY PROHIBITED** - These make it an implementation test, not a requirements test:
   - ❌ Any item starting with "Verify", "Test", "Confirm", "Check" + implementation behavior
   - ❌ References to code execution, user actions, system behavior
   - ❌ "Displays correctly", "works properly", "functions as expected"
   - ❌ "Click", "navigate", "render", "load", "execute"
   - ❌ Test cases, test plans, QA procedures
   - ❌ Implementation details (frameworks, APIs, algorithms)

   **✅ REQUIRED PATTERNS** - These test requirements quality:
   - ✅ "Are [requirement type] defined/specified/documented for [scenario]?"
   - ✅ "Is [vague term] quantified/clarified with specific criteria?"
   - ✅ "Are requirements consistent between [section A] and [section B]?"
   - ✅ "Can [requirement] be objectively measured/verified?"
   - ✅ "Are [edge cases/scenarios] addressed in requirements?"
   - ✅ "Does the spec define [missing aspect]?"

6. **Structure Reference**: Generate the checklist following the canonical template in `.specify/templates/checklist-template.md` for title, meta section, category headings, and ID formatting. If the template is unavailable, use: H1 title, purpose/created meta lines, `##` category sections containing `- [ ] CHK### <requirement item>` lines with globally incrementing IDs starting at CHK001.

7. **Report**: Output the full path to the created checklist, the item count, and remind the user that each run creates a new file. Summarize:
   - Focus areas selected
   - Depth level
   - Actor/timing
   - Any explicit user-specified must-have items incorporated

**Important**: Each `/speckit.checklist` command invocation creates a checklist file using short, descriptive names unless the file already exists. This allows:

- Multiple checklists of different types (e.g., `ux.md`, `test.md`, `security.md`)
- Simple, memorable filenames that indicate checklist purpose
- Easy identification and navigation in the `checklists/` folder

To avoid clutter, use descriptive types and clean up obsolete checklists when done.

## Example Checklist Types & Sample Items

**UX Requirements Quality:** `ux.md`

Sample items (testing the requirements, NOT the implementation):

- "Are visual hierarchy requirements defined with measurable criteria? [Clarity, Spec §FR-1]"
- "Is the number and positioning of UI elements explicitly specified? [Completeness, Spec §FR-1]"
- "Are interaction state requirements (hover, focus, active) consistently defined? [Consistency]"
- "Are accessibility requirements specified for all interactive elements? [Coverage, Gap]"
- "Is fallback behavior defined when images fail to load? [Edge Case, Gap]"
- "Can 'prominent display' be objectively measured? [Measurability, Spec §FR-4]"

**API Requirements Quality:** `api.md`

Sample items:

- "Are error response formats specified for all failure scenarios? [Completeness]"
- "Are rate limiting requirements quantified with specific thresholds? [Clarity]"
- "Are authentication requirements consistent across all endpoints? [Consistency]"
- "Are retry/timeout requirements defined for external dependencies? [Coverage, Gap]"
- "Is the versioning strategy documented in requirements? [Gap]"

**Performance Requirements Quality:** `performance.md`

Sample items:

- "Are performance requirements quantified with specific metrics? [Clarity]"
- "Are performance targets defined for all critical user journeys? [Coverage]"
- "Are performance requirements under different load conditions specified? [Completeness]"
- "Can performance requirements be objectively measured? [Measurability]"
- "Are degradation requirements defined for high-load scenarios? [Edge Case, Gap]"

**Security Requirements Quality:** `security.md`

Sample items:

- "Are authentication requirements specified for all protected resources? [Coverage]"
- "Are data protection requirements defined for sensitive information? [Completeness]"
- "Is the threat model documented and requirements aligned to it? [Traceability]"
- "Are security requirements consistent with compliance obligations? [Consistency]"
- "Are security failure/breach response requirements defined? [Gap, Exception Flow]"

## Anti-Examples: What NOT To Do

**❌ WRONG - These test implementation, not requirements:**

```markdown
- [ ] CHK001 - Verify landing page displays 3 episode cards [Spec §FR-001]
- [ ] CHK002 - Test hover states work correctly on desktop [Spec §FR-003]
- [ ] CHK003 - Confirm logo click navigates to home page [Spec §FR-010]
- [ ] CHK004 - Check that related episodes section shows 3-5 items [Spec §FR-005]
```

**✅ CORRECT - These test requirements quality:**

```markdown
- [ ] CHK001 - Are the number and layout of featured episodes explicitly specified? [Completeness, Spec §FR-001]
- [ ] CHK002 - Are hover state requirements consistently defined for all interactive elements? [Consistency, Spec §FR-003]
- [ ] CHK003 - Are navigation requirements clear for all clickable brand elements? [Clarity, Spec §FR-010]
- [ ] CHK004 - Is the selection criteria for related episodes documented? [Gap, Spec §FR-005]
- [ ] CHK005 - Are loading state requirements defined for asynchronous episode data? [Gap]
- [ ] CHK006 - Can "visual hierarchy" requirements be objectively measured? [Measurability, Spec §FR-001]
```

**Key Differences:**

- Wrong: Tests if the system works correctly
- Correct: Tests if the requirements are written correctly
- Wrong: Verification of behavior
- Correct: Validation of requirement quality
- Wrong: "Does it do X?"
- Correct: "Is X clearly specified?"
177
Code/Supervised_learning/resnet/.github/prompts/speckit.clarify.prompt.md
vendored
Normal file
@ -0,0 +1,177 @@
---
description: Identify underspecified areas in the current feature spec by asking up to 5 highly targeted clarification questions and encoding answers back into the spec.
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

Goal: Detect and reduce ambiguity or missing decision points in the active feature specification and record the clarifications directly in the spec file.

Note: This clarification workflow is expected to run (and be completed) BEFORE invoking `/speckit.plan`. If the user explicitly states they are skipping clarification (e.g., an exploratory spike), you may proceed, but must warn that downstream rework risk increases.

Execution steps:

1. Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -PathsOnly` from repo root **once** (combined `--json --paths-only` mode / `-Json -PathsOnly`). Parse minimal JSON payload fields:
   - `FEATURE_DIR`
   - `FEATURE_SPEC`
   - (Optionally capture `IMPL_PLAN`, `TASKS` for future chained flows.)
   - If JSON parsing fails, abort and instruct the user to re-run `/speckit.specify` or verify the feature branch environment.
   - For single quotes in args like "I'm Groot", use escape syntax: e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

2. Load the current spec file. Perform a structured ambiguity & coverage scan using this taxonomy. For each category, mark status: Clear / Partial / Missing. Produce an internal coverage map used for prioritization (do not output the raw map unless no questions will be asked).

   Functional Scope & Behavior:
   - Core user goals & success criteria
   - Explicit out-of-scope declarations
   - User roles / personas differentiation

   Domain & Data Model:
   - Entities, attributes, relationships
   - Identity & uniqueness rules
   - Lifecycle/state transitions
   - Data volume / scale assumptions

   Interaction & UX Flow:
   - Critical user journeys / sequences
   - Error/empty/loading states
   - Accessibility or localization notes

   Non-Functional Quality Attributes:
   - Performance (latency, throughput targets)
   - Scalability (horizontal/vertical, limits)
   - Reliability & availability (uptime, recovery expectations)
   - Observability (logging, metrics, tracing signals)
   - Security & privacy (authN/Z, data protection, threat assumptions)
   - Compliance / regulatory constraints (if any)

   Integration & External Dependencies:
   - External services/APIs and failure modes
   - Data import/export formats
   - Protocol/versioning assumptions

   Edge Cases & Failure Handling:
   - Negative scenarios
   - Rate limiting / throttling
   - Conflict resolution (e.g., concurrent edits)

   Constraints & Tradeoffs:
   - Technical constraints (language, storage, hosting)
   - Explicit tradeoffs or rejected alternatives

   Terminology & Consistency:
   - Canonical glossary terms
   - Avoided synonyms / deprecated terms

   Completion Signals:
   - Acceptance criteria testability
   - Measurable Definition of Done style indicators

   Misc / Placeholders:
   - TODO markers / unresolved decisions
   - Ambiguous adjectives ("robust", "intuitive") lacking quantification

   For each category with Partial or Missing status, add a candidate question opportunity unless:
   - Clarification would not materially change the implementation or validation strategy
   - Information is better deferred to the planning phase (note internally)

3. Generate (internally) a prioritized queue of candidate clarification questions (maximum 5). Do NOT output them all at once. Apply these constraints:
   - Maximum of 10 total questions across the whole session.
   - Each question must be answerable with EITHER:
     - A short multiple‑choice selection (2–5 distinct, mutually exclusive options), OR
     - A one-word / short‑phrase answer (explicitly constrain: "Answer in <=5 words").
   - Only include questions whose answers materially impact architecture, data modeling, task decomposition, test design, UX behavior, operational readiness, or compliance validation.
   - Ensure category coverage balance: attempt to cover the highest impact unresolved categories first; avoid asking two low-impact questions when a single high-impact area (e.g., security posture) is unresolved.
   - Exclude questions already answered, trivial stylistic preferences, or plan-level execution details (unless blocking correctness).
   - Favor clarifications that reduce downstream rework risk or prevent misaligned acceptance tests.
   - If more than 5 categories remain unresolved, select the top 5 by an (Impact * Uncertainty) heuristic.

4. Sequential questioning loop (interactive):
   - Present EXACTLY ONE question at a time.
   - For multiple‑choice questions:
     - **Analyze all options** and determine the **most suitable option** based on:
       - Best practices for the project type
       - Common patterns in similar implementations
       - Risk reduction (security, performance, maintainability)
       - Alignment with any explicit project goals or constraints visible in the spec
     - Present your **recommended option prominently** at the top with clear reasoning (1-2 sentences explaining why this is the best choice).
     - Format as: `**Recommended:** Option [X] - <reasoning>`
     - Then render all options as a Markdown table:

       | Option | Description |
       |--------|-------------|
       | A | <Option A description> |
       | B | <Option B description> |
       | C | <Option C description> (add D/E as needed up to 5) |
       | Short | Provide a different short answer (<=5 words) (Include only if a free-form alternative is appropriate) |

     - After the table, add: `You can reply with the option letter (e.g., "A"), accept the recommendation by saying "yes" or "recommended", or provide your own short answer.`
   - For short‑answer style (no meaningful discrete options):
     - Provide your **suggested answer** based on best practices and context.
     - Format as: `**Suggested:** <your proposed answer> - <brief reasoning>`
     - Then output: `Format: Short answer (<=5 words). You can accept the suggestion by saying "yes" or "suggested", or provide your own answer.`
   - After the user answers:
     - If the user replies with "yes", "recommended", or "suggested", use your previously stated recommendation/suggestion as the answer.
     - Otherwise, validate that the answer maps to one option or fits the <=5 word constraint.
     - If ambiguous, ask for a quick disambiguation (the count still belongs to the same question; do not advance).
     - Once satisfactory, record it in working memory (do not yet write to disk) and move to the next queued question.
   - Stop asking further questions when:
     - All critical ambiguities are resolved early (remaining queued items become unnecessary), OR
     - The user signals completion ("done", "good", "no more"), OR
     - You reach 5 asked questions.
   - Never reveal future queued questions in advance.
   - If no valid questions exist at the start, immediately report no critical ambiguities.

5. Integration after EACH accepted answer (incremental update approach):
   - Maintain an in-memory representation of the spec (loaded once at start) plus the raw file contents.
   - For the first integrated answer in this session:
     - Ensure a `## Clarifications` section exists (create it just after the highest-level contextual/overview section per the spec template if missing).
     - Under it, create (if not present) a `### Session YYYY-MM-DD` subheading for today (see the illustrative snippet below).
   - Append a bullet line immediately after acceptance: `- Q: <question> → A: <final answer>`.
   - Then immediately apply the clarification to the most appropriate section(s):
     - Functional ambiguity → Update or add a bullet in Functional Requirements.
     - User interaction / actor distinction → Update User Stories or Actors subsection (if present) with the clarified role, constraint, or scenario.
     - Data shape / entities → Update Data Model (add fields, types, relationships) preserving ordering; note added constraints succinctly.
     - Non-functional constraint → Add/modify measurable criteria in the Non-Functional / Quality Attributes section (convert the vague adjective to a metric or explicit target).
     - Edge case / negative flow → Add a new bullet under Edge Cases / Error Handling (or create such a subsection if the template provides a placeholder for it).
     - Terminology conflict → Normalize the term across the spec; retain the original only if necessary by adding `(formerly referred to as "X")` once.
   - If the clarification invalidates an earlier ambiguous statement, replace that statement instead of duplicating; leave no obsolete contradictory text.
   - Save the spec file AFTER each integration to minimize the risk of context loss (atomic overwrite).
   - Preserve formatting: do not reorder unrelated sections; keep the heading hierarchy intact.
   - Keep each inserted clarification minimal and testable (avoid narrative drift).
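To make the expected structure concrete, here is an illustrative snippet (the question, answer, and date are hypothetical) of what the spec should contain after one integrated answer:

```markdown
## Clarifications

### Session 2025-03-18

- Q: Should CSV export include archived records? → A: No, active records only
```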
6. Validation (performed after EACH write plus a final pass):
   - The Clarifications session contains exactly one bullet per accepted answer (no duplicates).
   - Total asked (accepted) questions ≤ 5.
   - Updated sections contain no lingering vague placeholders the new answer was meant to resolve.
   - No contradictory earlier statement remains (scan for now-invalid alternative choices removed).
   - Markdown structure valid; only allowed new headings: `## Clarifications`, `### Session YYYY-MM-DD`.
   - Terminology consistency: the same canonical term is used across all updated sections.

7. Write the updated spec back to `FEATURE_SPEC`.

8. Report completion (after the questioning loop ends or early termination):
   - Number of questions asked & answered.
   - Path to the updated spec.
   - Sections touched (list names).
   - Coverage summary table listing each taxonomy category with Status: Resolved (was Partial/Missing and addressed), Deferred (exceeds question quota or better suited for planning), Clear (already sufficient), Outstanding (still Partial/Missing but low impact).
   - If any Outstanding or Deferred remain, recommend whether to proceed to `/speckit.plan` or run `/speckit.clarify` again later post-plan.
   - Suggested next command.

Behavior rules:

- If no meaningful ambiguities are found (or all potential questions would be low-impact), respond: "No critical ambiguities detected worth formal clarification." and suggest proceeding.
- If the spec file is missing, instruct the user to run `/speckit.specify` first (do not create a new spec here).
- Never exceed 5 total asked questions (clarification retries for a single question do not count as new questions).
- Avoid speculative tech stack questions unless the absence blocks functional clarity.
- Respect user early termination signals ("stop", "done", "proceed").
- If no questions are asked due to full coverage, output a compact coverage summary (all categories Clear) then suggest advancing.
- If the quota is reached with unresolved high-impact categories remaining, explicitly flag them under Deferred with a rationale.

Context for prioritization: $ARGUMENTS
78
Code/Supervised_learning/resnet/.github/prompts/speckit.constitution.prompt.md
vendored
Normal file
@ -0,0 +1,78 @@
---
description: Create or update the project constitution from interactive or provided principle inputs, ensuring all dependent templates stay in sync
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

You are updating the project constitution at `.specify/memory/constitution.md`. This file is a TEMPLATE containing placeholder tokens in square brackets (e.g. `[PROJECT_NAME]`, `[PRINCIPLE_1_NAME]`). Your job is to (a) collect/derive concrete values, (b) fill the template precisely, and (c) propagate any amendments across dependent artifacts.

Follow this execution flow:

1. Load the existing constitution template at `.specify/memory/constitution.md`.
   - Identify every placeholder token of the form `[ALL_CAPS_IDENTIFIER]`.
   **IMPORTANT**: The user might require fewer or more principles than the template uses. If a number is specified, respect it while following the general template, and update the doc accordingly.

2. Collect/derive values for placeholders:
   - If user input (conversation) supplies a value, use it.
   - Otherwise infer from existing repo context (README, docs, prior constitution versions if embedded).
   - For governance dates: `RATIFICATION_DATE` is the original adoption date (if unknown, ask or mark TODO); `LAST_AMENDED_DATE` is today if changes are made, otherwise keep the previous value.
   - `CONSTITUTION_VERSION` must increment according to semantic versioning rules:
      - MAJOR: Backward incompatible governance/principle removals or redefinitions.
      - MINOR: New principle/section added or materially expanded guidance.
      - PATCH: Clarifications, wording, typo fixes, non-semantic refinements.
   - If the version bump type is ambiguous, propose reasoning before finalizing.
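
As a quick illustration of these rules, here is a minimal PowerShell sketch that derives the next version once the bump type has been decided (the variable names are illustrative, not part of the template):

```powershell
# Illustrative sketch: compute the next CONSTITUTION_VERSION from a chosen bump type
$current = [version]'1.4.2'   # hypothetical current version
$bump    = 'MINOR'            # MAJOR | MINOR | PATCH, per the rules above
$next = switch ($bump) {
    'MAJOR' { '{0}.0.0'     -f ($current.Major + 1) }
    'MINOR' { '{0}.{1}.0'   -f $current.Major, ($current.Minor + 1) }
    'PATCH' { '{0}.{1}.{2}' -f $current.Major, $current.Minor, ($current.Build + 1) }
}
# $next is '1.5.0' in this example
```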

3. Draft the updated constitution content:
   - Replace every placeholder with concrete text. No bracketed tokens may remain, except intentionally retained template slots the project has chosen not to define yet; explicitly justify any left.
   - Preserve heading hierarchy; comments can be removed once replaced unless they still add clarifying guidance.
   - Ensure each Principle section has a succinct name line, a paragraph (or bullet list) capturing non-negotiable rules, and an explicit rationale if not obvious.
   - Ensure the Governance section lists the amendment procedure, versioning policy, and compliance review expectations.

4. Consistency propagation checklist (convert prior checklist into active validations):
   - Read `.specify/templates/plan-template.md` and ensure any "Constitution Check" or rules align with updated principles.
   - Read `.specify/templates/spec-template.md` for scope/requirements alignment; update it if the constitution adds/removes mandatory sections or constraints.
   - Read `.specify/templates/tasks-template.md` and ensure task categorization reflects new or removed principle-driven task types (e.g., observability, versioning, testing discipline).
   - Read each command file in `.specify/templates/commands/*.md` (including this one) to verify no outdated references (agent-specific names like CLAUDE only) remain when generic guidance is required.
   - Read any runtime guidance docs (e.g., `README.md`, `docs/quickstart.md`, or agent-specific guidance files if present). Update references to changed principles.

5. Produce a Sync Impact Report (prepend as an HTML comment at the top of the constitution file after the update):
   - Version change: old → new
   - List of modified principles (old title → new title if renamed)
   - Added sections
   - Removed sections
   - Templates requiring updates (✅ updated / ⚠ pending) with file paths
   - Follow-up TODOs if any placeholders were intentionally deferred.

6. Validation before final output:
   - No remaining unexplained bracket tokens.
   - Version line matches the report.
   - Dates in ISO format YYYY-MM-DD.
   - Principles are declarative, testable, and free of vague language ("should" → replace with MUST/SHOULD plus rationale where appropriate).

7. Write the completed constitution back to `.specify/memory/constitution.md` (overwrite).

8. Output a final summary to the user with:
   - New version and bump rationale.
   - Any files flagged for manual follow-up.
   - Suggested commit message (e.g., `docs: amend constitution to vX.Y.Z (principle additions + governance update)`).

Formatting & Style Requirements:

- Use Markdown headings exactly as in the template (do not demote/promote levels).
- Wrap long rationale lines for readability (ideally <100 chars), but do not hard-enforce this with awkward breaks.
- Keep a single blank line between sections.
- Avoid trailing whitespace.

If the user supplies partial updates (e.g., only one principle revision), still perform the validation and version decision steps.

If critical info is missing (e.g., the ratification date is truly unknown), insert `TODO(<FIELD_NAME>): explanation` and include it in the Sync Impact Report under deferred items.

Do not create a new template; always operate on the existing `.specify/memory/constitution.md` file.
134
Code/Supervised_learning/resnet/.github/prompts/speckit.implement.prompt.md
vendored
Normal file
@ -0,0 +1,134 @@
---
description: Execute the implementation plan by processing and executing all tasks defined in tasks.md
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

1. Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks` from the repo root and parse FEATURE_DIR and the AVAILABLE_DOCS list. All paths must be absolute. For single quotes in args like "I'm Groot", use escape syntax, e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

2. **Check checklists status** (if FEATURE_DIR/checklists/ exists):
   - Scan all checklist files in the checklists/ directory
   - For each checklist, count:
      - Total items: All lines matching `- [ ]` or `- [X]` or `- [x]`
      - Completed items: Lines matching `- [X]` or `- [x]`
      - Incomplete items: Lines matching `- [ ]`
   - Create a status table:

   ```text
   | Checklist   | Total | Completed | Incomplete | Status |
   |-------------|-------|-----------|------------|--------|
   | ux.md       | 12    | 12        | 0          | ✓ PASS |
   | test.md     | 8     | 5         | 3          | ✗ FAIL |
   | security.md | 6     | 6         | 0          | ✓ PASS |
   ```

   - Calculate overall status:
      - **PASS**: All checklists have 0 incomplete items
      - **FAIL**: One or more checklists have incomplete items

   - **If any checklist is incomplete**:
      - Display the table with incomplete item counts
      - **STOP** and ask: "Some checklists are incomplete. Do you want to proceed with implementation anyway? (yes/no)"
      - Wait for the user response before continuing
      - If the user says "no", "wait", or "stop", halt execution
      - If the user says "yes", "proceed", or "continue", proceed to step 3

   - **If all checklists are complete**:
      - Display the table showing all checklists passed
      - Automatically proceed to step 3
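
A minimal PowerShell sketch of this counting pass, assuming `$featureDir` was parsed from step 1 (names and layout here are illustrative, not part of the script suite):

```powershell
# Count checklist items per file and derive PASS/FAIL, mirroring the table above
$checklistDir = Join-Path $featureDir 'checklists'
foreach ($file in Get-ChildItem -Path $checklistDir -Filter '*.md') {
    $lines      = Get-Content $file.FullName
    $total      = ($lines | Where-Object { $_ -match '^\s*- \[( |x|X)\]' }).Count
    $completed  = ($lines | Where-Object { $_ -match '^\s*- \[(x|X)\]' }).Count
    $incomplete = $total - $completed
    $status     = if ($incomplete -eq 0) { '✓ PASS' } else { '✗ FAIL' }
    Write-Output ('| {0} | {1} | {2} | {3} | {4} |' -f $file.Name, $total, $completed, $incomplete, $status)
}
```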

3. Load and analyze the implementation context:
   - **REQUIRED**: Read tasks.md for the complete task list and execution plan
   - **REQUIRED**: Read plan.md for tech stack, architecture, and file structure
   - **IF EXISTS**: Read data-model.md for entities and relationships
   - **IF EXISTS**: Read contracts/ for API specifications and test requirements
   - **IF EXISTS**: Read research.md for technical decisions and constraints
   - **IF EXISTS**: Read quickstart.md for integration scenarios

4. **Project Setup Verification**:
   - **REQUIRED**: Create/verify ignore files based on the actual project setup, as sketched below:

   **Detection & Creation Logic**:
   - Check if the following command succeeds to determine if the repository is a git repo (create/verify .gitignore if so):

   ```sh
   git rev-parse --git-dir 2>/dev/null
   ```

   - Check if Dockerfile* exists or Docker is in plan.md → create/verify .dockerignore
   - Check if .eslintrc* or eslint.config.* exists → create/verify .eslintignore
   - Check if .prettierrc* exists → create/verify .prettierignore
   - Check if .npmrc or package.json exists → create/verify .npmignore (if publishing)
   - Check if terraform files (*.tf) exist → create/verify .terraformignore
   - Check if .helmignore is needed (helm charts present) → create/verify .helmignore

   **If ignore file already exists**: Verify it contains essential patterns; append missing critical patterns only
   **If ignore file missing**: Create it with the full pattern set for the detected technology
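
A hedged sketch of this create-or-append behavior (the file name and pattern list are placeholders; the real patterns come from the tables below):

```powershell
# Ensure an ignore file exists and contains the critical patterns, appending only what is missing
$ignorePath = '.gitignore'                          # assumed target for a git repo
$critical   = @('node_modules/', 'dist/', '.env*')  # illustrative subset of the tech-stack patterns
$existing   = if (Test-Path $ignorePath) { Get-Content $ignorePath } else { @() }
$missing    = $critical | Where-Object { $existing -notcontains $_ }
if (-not (Test-Path $ignorePath)) { New-Item -ItemType File -Path $ignorePath | Out-Null }
if ($missing) { Add-Content -Path $ignorePath -Value $missing }
```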

   **Common Patterns by Technology** (from plan.md tech stack):
   - **Node.js/JavaScript/TypeScript**: `node_modules/`, `dist/`, `build/`, `*.log`, `.env*`
   - **Python**: `__pycache__/`, `*.pyc`, `.venv/`, `venv/`, `dist/`, `*.egg-info/`
   - **Java**: `target/`, `*.class`, `*.jar`, `.gradle/`, `build/`
   - **C#/.NET**: `bin/`, `obj/`, `*.user`, `*.suo`, `packages/`
   - **Go**: `*.exe`, `*.test`, `vendor/`, `*.out`
   - **Ruby**: `.bundle/`, `log/`, `tmp/`, `*.gem`, `vendor/bundle/`
   - **PHP**: `vendor/`, `*.log`, `*.cache`, `*.env`
   - **Rust**: `target/`, `debug/`, `release/`, `*.rs.bk`, `*.rlib`, `*.prof*`, `.idea/`, `*.log`, `.env*`
   - **Kotlin**: `build/`, `out/`, `.gradle/`, `.idea/`, `*.class`, `*.jar`, `*.iml`, `*.log`, `.env*`
   - **C++**: `build/`, `bin/`, `obj/`, `out/`, `*.o`, `*.so`, `*.a`, `*.exe`, `*.dll`, `.idea/`, `*.log`, `.env*`
   - **C**: `build/`, `bin/`, `obj/`, `out/`, `*.o`, `*.a`, `*.so`, `*.exe`, `Makefile`, `config.log`, `.idea/`, `*.log`, `.env*`
   - **Swift**: `.build/`, `DerivedData/`, `*.swiftpm/`, `Packages/`
   - **R**: `.Rproj.user/`, `.Rhistory`, `.RData`, `.Ruserdata`, `*.Rproj`, `packrat/`, `renv/`
   - **Universal**: `.DS_Store`, `Thumbs.db`, `*.tmp`, `*.swp`, `.vscode/`, `.idea/`

   **Tool-Specific Patterns**:
   - **Docker**: `node_modules/`, `.git/`, `Dockerfile*`, `.dockerignore`, `*.log*`, `.env*`, `coverage/`
   - **ESLint**: `node_modules/`, `dist/`, `build/`, `coverage/`, `*.min.js`
   - **Prettier**: `node_modules/`, `dist/`, `build/`, `coverage/`, `package-lock.json`, `yarn.lock`, `pnpm-lock.yaml`
   - **Terraform**: `.terraform/`, `*.tfstate*`, `*.tfvars`, `.terraform.lock.hcl`
   - **Kubernetes/k8s**: `*.secret.yaml`, `secrets/`, `.kube/`, `kubeconfig*`, `*.key`, `*.crt`

5. Parse tasks.md structure and extract (a parsing sketch follows this list):
   - **Task phases**: Setup, Tests, Core, Integration, Polish
   - **Task dependencies**: Sequential vs parallel execution rules
   - **Task details**: ID, description, file paths, parallel markers [P]
   - **Execution flow**: Order and dependency requirements
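
As an illustration of the task-line format these fields come from, here is a minimal PowerShell sketch (the regex approximates the checklist format defined in the task generation rules; it is not part of the script suite):

```powershell
# Pull ID, [P] marker, [US#] story label, and description out of one tasks.md line
$line = '- [ ] T012 [P] [US1] Create User model in src/models/user.py'
if ($line -match '^- \[( |x|X)\] (T\d{3}) (\[P\] )?(\[US\d+\] )?(.+)$') {
    $taskId   = $Matches[2]        # 'T012'
    $parallel = [bool]$Matches[3]  # $true when the [P] marker is present
    $story    = if ($Matches[4]) { $Matches[4].Trim() } else { $null }  # '[US1]'
    $desc     = $Matches[5]        # description, including the file path
}
```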

6. Execute implementation following the task plan:
   - **Phase-by-phase execution**: Complete each phase before moving to the next
   - **Respect dependencies**: Run sequential tasks in order; parallel tasks [P] can run together
   - **Follow TDD approach**: Execute test tasks before their corresponding implementation tasks
   - **File-based coordination**: Tasks affecting the same files must run sequentially
   - **Validation checkpoints**: Verify each phase completion before proceeding

7. Implementation execution rules:
   - **Setup first**: Initialize project structure, dependencies, configuration
   - **Tests before code**: If tests were requested, write tests for contracts, entities, and integration scenarios before implementing them
   - **Core development**: Implement models, services, CLI commands, endpoints
   - **Integration work**: Database connections, middleware, logging, external services
   - **Polish and validation**: Unit tests, performance optimization, documentation

8. Progress tracking and error handling:
   - Report progress after each completed task
   - Halt execution if any non-parallel task fails
   - For parallel tasks [P], continue with successful tasks and report failed ones
   - Provide clear error messages with context for debugging
   - Suggest next steps if implementation cannot proceed
   - **IMPORTANT**: For completed tasks, make sure to mark the task off as [X] in the tasks file.

9. Completion validation:
   - Verify all required tasks are completed
   - Check that implemented features match the original specification
   - Validate that tests pass and coverage meets requirements
   - Confirm the implementation follows the technical plan
   - Report final status with a summary of completed work

Note: This command assumes a complete task breakdown exists in tasks.md. If tasks are incomplete or missing, suggest running `/speckit.tasks` first to regenerate the task list.
81
Code/Supervised_learning/resnet/.github/prompts/speckit.plan.prompt.md
vendored
Normal file
@ -0,0 +1,81 @@
---
description: Execute the implementation planning workflow using the plan template to generate design artifacts.
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

1. **Setup**: Run `.specify/scripts/powershell/setup-plan.ps1 -Json` from the repo root and parse the JSON for FEATURE_SPEC, IMPL_PLAN, SPECS_DIR, BRANCH. For single quotes in args like "I'm Groot", use escape syntax, e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

2. **Load context**: Read FEATURE_SPEC and `.specify/memory/constitution.md`. Load the IMPL_PLAN template (already copied).

3. **Execute plan workflow**: Follow the structure in the IMPL_PLAN template to:
   - Fill Technical Context (mark unknowns as "NEEDS CLARIFICATION")
   - Fill the Constitution Check section from the constitution
   - Evaluate gates (ERROR if violations are unjustified)
   - Phase 0: Generate research.md (resolve all NEEDS CLARIFICATION)
   - Phase 1: Generate data-model.md, contracts/, quickstart.md
   - Phase 1: Update agent context by running the agent script
   - Re-evaluate the Constitution Check post-design

4. **Stop and report**: The command ends after Phase 2 planning. Report the branch, IMPL_PLAN path, and generated artifacts.

## Phases

### Phase 0: Outline & Research

1. **Extract unknowns from Technical Context** above:
   - For each NEEDS CLARIFICATION → research task
   - For each dependency → best practices task
   - For each integration → patterns task

2. **Generate and dispatch research agents**:

   ```text
   For each unknown in Technical Context:
     Task: "Research {unknown} for {feature context}"
   For each technology choice:
     Task: "Find best practices for {tech} in {domain}"
   ```

3. **Consolidate findings** in `research.md` using the format:
   - Decision: [what was chosen]
   - Rationale: [why chosen]
   - Alternatives considered: [what else was evaluated]

**Output**: research.md with all NEEDS CLARIFICATION resolved

### Phase 1: Design & Contracts

**Prerequisites:** `research.md` complete

1. **Extract entities from feature spec** → `data-model.md`:
   - Entity name, fields, relationships
   - Validation rules from requirements
   - State transitions if applicable

2. **Generate API contracts** from functional requirements:
   - For each user action → endpoint
   - Use standard REST/GraphQL patterns
   - Output OpenAPI/GraphQL schema to `/contracts/`

3. **Agent context update**:
   - Run `.specify/scripts/powershell/update-agent-context.ps1 -AgentType copilot`
   - These scripts detect which AI agent is in use
   - Update the appropriate agent-specific context file
   - Add only new technology from the current plan
   - Preserve manual additions between markers

**Output**: data-model.md, /contracts/*, quickstart.md, agent-specific file

## Key rules

- Use absolute paths
- ERROR on gate failures or unresolved clarifications
249
Code/Supervised_learning/resnet/.github/prompts/speckit.specify.prompt.md
vendored
Normal file
@ -0,0 +1,249 @@
---
description: Create or update the feature specification from a natural language feature description.
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

The text the user typed after `/speckit.specify` in the triggering message **is** the feature description. Assume you always have it available in this conversation even if `$ARGUMENTS` appears literally below. Do not ask the user to repeat it unless they provided an empty command.

Given that feature description, do this:

1. **Generate a concise short name** (2-4 words) for the branch:
   - Analyze the feature description and extract the most meaningful keywords
   - Create a 2-4 word short name that captures the essence of the feature
   - Use action-noun format when possible (e.g., "add-user-auth", "fix-payment-bug")
   - Preserve technical terms and acronyms (OAuth2, API, JWT, etc.)
   - Keep it concise but descriptive enough to understand the feature at a glance
   - Examples:
      - "I want to add user authentication" → "user-auth"
      - "Implement OAuth2 integration for the API" → "oauth2-api-integration"
      - "Create a dashboard for analytics" → "analytics-dashboard"
      - "Fix payment processing timeout bug" → "fix-payment-timeout"

2. **Check for existing branches before creating a new one**:

   a. First, fetch all remote branches to ensure we have the latest information:

   ```bash
   git fetch --all --prune
   ```

   b. Find the highest feature number across all sources for the short-name:
      - Remote branches: `git ls-remote --heads origin | grep -E 'refs/heads/[0-9]+-<short-name>$'`
      - Local branches: `git branch | grep -E '^[* ]*[0-9]+-<short-name>$'`
      - Specs directories: Check for directories matching `specs/[0-9]+-<short-name>`

   c. Determine the next available number:
      - Extract all numbers from all three sources
      - Find the highest number N
      - Use N+1 for the new branch number (see the sketch after this step)

   d. Run the script `.specify/scripts/powershell/create-new-feature.ps1` with the calculated number and short-name:
      - Pass `-Number N+1` and `-ShortName "your-short-name"` along with the feature description
      - Example: `.specify/scripts/powershell/create-new-feature.ps1 -Json -Number 5 -ShortName "user-auth" "Add user authentication"` (the script takes these PowerShell-style parameters regardless of the calling shell)

   **IMPORTANT**:
   - Check all three sources (remote branches, local branches, specs directories) to find the highest number
   - Only match branches/directories with the exact short-name pattern
   - If no existing branches/directories are found with this short-name, start with number 1
   - You must only ever run this script once per feature
   - The JSON is provided in the terminal as output; always refer to it to get the actual content you're looking for
   - The JSON output will contain BRANCH_NAME and SPEC_FILE paths
   - For single quotes in args like "I'm Groot", use escape syntax, e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot")
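
A minimal sketch of the next-number computation referenced in step 2c, assuming names like `003-user-auth` were gathered from the three sources (the full logic lives in `create-new-feature.ps1`):

```powershell
# Derive N+1 from branch/directory names of the form '<number>-<short-name>'
$names = @('001-user-auth', '003-user-auth')   # illustrative results from the three sources
$nums  = $names | ForEach-Object { if ($_ -match '^(\d+)-') { [int]$Matches[1] } }
$next  = (($nums | Measure-Object -Maximum).Maximum) + 1   # 4 in this example
```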

3. Load `.specify/templates/spec-template.md` to understand the required sections.

4. Follow this execution flow:

   1. Parse the user description from Input
      If empty: ERROR "No feature description provided"
   2. Extract key concepts from the description
      Identify: actors, actions, data, constraints
   3. For unclear aspects:
      - Make informed guesses based on context and industry standards
      - Only mark with [NEEDS CLARIFICATION: specific question] if:
         - The choice significantly impacts feature scope or user experience
         - Multiple reasonable interpretations exist with different implications
         - No reasonable default exists
      - **LIMIT: Maximum 3 [NEEDS CLARIFICATION] markers total**
      - Prioritize clarifications by impact: scope > security/privacy > user experience > technical details
   4. Fill the User Scenarios & Testing section
      If no clear user flow: ERROR "Cannot determine user scenarios"
   5. Generate Functional Requirements
      Each requirement must be testable
      Use reasonable defaults for unspecified details (document assumptions in the Assumptions section)
   6. Define Success Criteria
      Create measurable, technology-agnostic outcomes
      Include both quantitative metrics (time, performance, volume) and qualitative measures (user satisfaction, task completion)
      Each criterion must be verifiable without implementation details
   7. Identify Key Entities (if data is involved)
   8. Return: SUCCESS (spec ready for planning)

5. Write the specification to SPEC_FILE using the template structure, replacing placeholders with concrete details derived from the feature description (arguments) while preserving section order and headings.

6. **Specification Quality Validation**: After writing the initial spec, validate it against quality criteria:

   a. **Create Spec Quality Checklist**: Generate a checklist file at `FEATURE_DIR/checklists/requirements.md` using the checklist template structure with these validation items:

   ```markdown
   # Specification Quality Checklist: [FEATURE NAME]

   **Purpose**: Validate specification completeness and quality before proceeding to planning
   **Created**: [DATE]
   **Feature**: [Link to spec.md]

   ## Content Quality

   - [ ] No implementation details (languages, frameworks, APIs)
   - [ ] Focused on user value and business needs
   - [ ] Written for non-technical stakeholders
   - [ ] All mandatory sections completed

   ## Requirement Completeness

   - [ ] No [NEEDS CLARIFICATION] markers remain
   - [ ] Requirements are testable and unambiguous
   - [ ] Success criteria are measurable
   - [ ] Success criteria are technology-agnostic (no implementation details)
   - [ ] All acceptance scenarios are defined
   - [ ] Edge cases are identified
   - [ ] Scope is clearly bounded
   - [ ] Dependencies and assumptions identified

   ## Feature Readiness

   - [ ] All functional requirements have clear acceptance criteria
   - [ ] User scenarios cover primary flows
   - [ ] Feature meets measurable outcomes defined in Success Criteria
   - [ ] No implementation details leak into specification

   ## Notes

   - Items marked incomplete require spec updates before `/speckit.clarify` or `/speckit.plan`
   ```

   b. **Run Validation Check**: Review the spec against each checklist item:
      - For each item, determine whether it passes or fails
      - Document specific issues found (quote relevant spec sections)

   c. **Handle Validation Results**:

      - **If all items pass**: Mark the checklist complete and proceed to step 7

      - **If items fail (excluding [NEEDS CLARIFICATION])**:
         1. List the failing items and specific issues
         2. Update the spec to address each issue
         3. Re-run validation until all items pass (max 3 iterations)
         4. If still failing after 3 iterations, document the remaining issues in the checklist notes and warn the user

      - **If [NEEDS CLARIFICATION] markers remain**:
         1. Extract all [NEEDS CLARIFICATION: ...] markers from the spec
         2. **LIMIT CHECK**: If more than 3 markers exist, keep only the 3 most critical (by scope/security/UX impact) and make informed guesses for the rest
         3. For each clarification needed (max 3), present options to the user in this format:

         ```markdown
         ## Question [N]: [Topic]

         **Context**: [Quote relevant spec section]

         **What we need to know**: [Specific question from NEEDS CLARIFICATION marker]

         **Suggested Answers**:

         | Option | Answer | Implications |
         |--------|--------|--------------|
         | A | [First suggested answer] | [What this means for the feature] |
         | B | [Second suggested answer] | [What this means for the feature] |
         | C | [Third suggested answer] | [What this means for the feature] |
         | Custom | Provide your own answer | [Explain how to provide custom input] |

         **Your choice**: _[Wait for user response]_
         ```

         4. **CRITICAL - Table Formatting**: Ensure markdown tables are properly formatted:
            - Use consistent spacing with pipes aligned
            - Each cell should have spaces around content: `| Content |` not `|Content|`
            - Header separator must have at least 3 dashes: `|--------|`
            - Test that the table renders correctly in markdown preview
         5. Number questions sequentially (Q1, Q2, Q3 - max 3 total)
         6. Present all questions together before waiting for responses
         7. Wait for the user to respond with their choices for all questions (e.g., "Q1: A, Q2: Custom - [details], Q3: B")
         8. Update the spec by replacing each [NEEDS CLARIFICATION] marker with the user's selected or provided answer
         9. Re-run validation after all clarifications are resolved

   d. **Update Checklist**: After each validation iteration, update the checklist file with the current pass/fail status

7. Report completion with the branch name, spec file path, checklist results, and readiness for the next phase (`/speckit.clarify` or `/speckit.plan`).

**NOTE:** The script creates and checks out the new branch and initializes the spec file before writing.

## General Guidelines

- Focus on **WHAT** users need and **WHY**.
- Avoid HOW to implement (no tech stack, APIs, code structure).
- Write for business stakeholders, not developers.
- DO NOT create any checklists that are embedded in the spec. That will be a separate command.

### Section Requirements

- **Mandatory sections**: Must be completed for every feature
- **Optional sections**: Include only when relevant to the feature
- When a section doesn't apply, remove it entirely (don't leave it as "N/A")

### For AI Generation

When creating this spec from a user prompt:

1. **Make informed guesses**: Use context, industry standards, and common patterns to fill gaps
2. **Document assumptions**: Record reasonable defaults in the Assumptions section
3. **Limit clarifications**: Maximum 3 [NEEDS CLARIFICATION] markers; use them only for critical decisions that:
   - Significantly impact feature scope or user experience
   - Have multiple reasonable interpretations with different implications
   - Lack any reasonable default
4. **Prioritize clarifications**: scope > security/privacy > user experience > technical details
5. **Think like a tester**: Every vague requirement should fail the "testable and unambiguous" checklist item
6. **Common areas needing clarification** (only if no reasonable default exists):
   - Feature scope and boundaries (include/exclude specific use cases)
   - User types and permissions (if multiple conflicting interpretations are possible)
   - Security/compliance requirements (when legally/financially significant)

**Examples of reasonable defaults** (don't ask about these):

- Data retention: Industry-standard practices for the domain
- Performance targets: Standard web/mobile app expectations unless specified
- Error handling: User-friendly messages with appropriate fallbacks
- Authentication method: Standard session-based or OAuth2 for web apps
- Integration patterns: RESTful APIs unless specified otherwise

### Success Criteria Guidelines

Success criteria must be:

1. **Measurable**: Include specific metrics (time, percentage, count, rate)
2. **Technology-agnostic**: No mention of frameworks, languages, databases, or tools
3. **User-focused**: Describe outcomes from the user/business perspective, not system internals
4. **Verifiable**: Can be tested/validated without knowing implementation details

**Good examples**:

- "Users can complete checkout in under 3 minutes"
- "System supports 10,000 concurrent users"
- "95% of searches return results in under 1 second"
- "Task completion rate improves by 40%"

**Bad examples** (implementation-focused):

- "API response time is under 200ms" (too technical; use "Users see results instantly")
- "Database can handle 1000 TPS" (implementation detail; use a user-facing metric)
- "React components render efficiently" (framework-specific)
- "Redis cache hit rate above 80%" (technology-specific)
128
Code/Supervised_learning/resnet/.github/prompts/speckit.tasks.prompt.md
vendored
Normal file
@ -0,0 +1,128 @@
---
description: Generate an actionable, dependency-ordered tasks.md for the feature based on available design artifacts.
---

## User Input

```text
$ARGUMENTS
```

You **MUST** consider the user input before proceeding (if not empty).

## Outline

1. **Setup**: Run `.specify/scripts/powershell/check-prerequisites.ps1 -Json` from the repo root and parse FEATURE_DIR and the AVAILABLE_DOCS list. All paths must be absolute. For single quotes in args like "I'm Groot", use escape syntax, e.g. 'I'\''m Groot' (or double-quote if possible: "I'm Groot").

2. **Load design documents**: Read from FEATURE_DIR:
   - **Required**: plan.md (tech stack, libraries, structure), spec.md (user stories with priorities)
   - **Optional**: data-model.md (entities), contracts/ (API endpoints), research.md (decisions), quickstart.md (test scenarios)
   - Note: Not all projects have all documents. Generate tasks based on what's available.

3. **Execute task generation workflow**:
   - Load plan.md and extract tech stack, libraries, and project structure
   - Load spec.md and extract user stories with their priorities (P1, P2, P3, etc.)
   - If data-model.md exists: Extract entities and map them to user stories
   - If contracts/ exists: Map endpoints to user stories
   - If research.md exists: Extract decisions for setup tasks
   - Generate tasks organized by user story (see Task Generation Rules below)
   - Generate a dependency graph showing user story completion order
   - Create parallel execution examples per user story
   - Validate task completeness (each user story has all needed tasks and is independently testable)

4. **Generate tasks.md**: Use `.specify/templates/tasks-template.md` as the structure and fill it with:
   - Correct feature name from plan.md
   - Phase 1: Setup tasks (project initialization)
   - Phase 2: Foundational tasks (blocking prerequisites for all user stories)
   - Phase 3+: One phase per user story (in priority order from spec.md)
   - Each phase includes: story goal, independent test criteria, tests (if requested), implementation tasks
   - Final Phase: Polish & cross-cutting concerns
   - All tasks must follow the strict checklist format (see Task Generation Rules below)
   - Clear file paths for each task
   - Dependencies section showing story completion order
   - Parallel execution examples per story
   - Implementation strategy section (MVP first, incremental delivery)

5. **Report**: Output the path to the generated tasks.md and a summary:
   - Total task count
   - Task count per user story
   - Parallel opportunities identified
   - Independent test criteria for each story
   - Suggested MVP scope (typically just User Story 1)
   - Format validation: Confirm ALL tasks follow the checklist format (checkbox, ID, labels, file paths)

Context for task generation: $ARGUMENTS

The tasks.md should be immediately executable - each task must be specific enough that an LLM can complete it without additional context.

## Task Generation Rules

**CRITICAL**: Tasks MUST be organized by user story to enable independent implementation and testing.

**Tests are OPTIONAL**: Only generate test tasks if explicitly requested in the feature specification or if the user requests a TDD approach.

### Checklist Format (REQUIRED)

Every task MUST strictly follow this format:

```text
- [ ] [TaskID] [P?] [Story?] Description with file path
```

**Format Components**:

1. **Checkbox**: ALWAYS start with `- [ ]` (markdown checkbox)
2. **Task ID**: Sequential number (T001, T002, T003...) in execution order
3. **[P] marker**: Include ONLY if the task is parallelizable (different files, no dependencies on incomplete tasks)
4. **[Story] label**: REQUIRED for user story phase tasks only
   - Format: [US1], [US2], [US3], etc. (maps to user stories from spec.md)
   - Setup phase: NO story label
   - Foundational phase: NO story label
   - User Story phases: MUST have story label
   - Polish phase: NO story label
5. **Description**: Clear action with exact file path

**Examples**:

- ✅ CORRECT: `- [ ] T001 Create project structure per implementation plan`
- ✅ CORRECT: `- [ ] T005 [P] Implement authentication middleware in src/middleware/auth.py`
- ✅ CORRECT: `- [ ] T012 [P] [US1] Create User model in src/models/user.py`
- ✅ CORRECT: `- [ ] T014 [US1] Implement UserService in src/services/user_service.py`
- ❌ WRONG: `- [ ] Create User model` (missing ID and Story label)
- ❌ WRONG: `T001 [US1] Create model` (missing checkbox)
- ❌ WRONG: `- [ ] [US1] Create User model` (missing Task ID)
- ❌ WRONG: `- [ ] T001 [US1] Create model` (missing file path)

### Task Organization

1. **From User Stories (spec.md)** - PRIMARY ORGANIZATION:
   - Each user story (P1, P2, P3...) gets its own phase
   - Map all related components to their story:
      - Models needed for that story
      - Services needed for that story
      - Endpoints/UI needed for that story
      - If tests requested: Tests specific to that story
   - Mark story dependencies (most stories should be independent)

2. **From Contracts**:
   - Map each contract/endpoint to the user story it serves
   - If tests requested: Each contract → contract test task [P] before implementation in that story's phase

3. **From Data Model**:
   - Map each entity to the user story(ies) that need it
   - If an entity serves multiple stories: Put it in the earliest story or the Setup phase
   - Relationships → service layer tasks in the appropriate story phase

4. **From Setup/Infrastructure**:
   - Shared infrastructure → Setup phase (Phase 1)
   - Foundational/blocking tasks → Foundational phase (Phase 2)
   - Story-specific setup → within that story's phase

### Phase Structure

- **Phase 1**: Setup (project initialization)
- **Phase 2**: Foundational (blocking prerequisites - MUST complete before user stories)
- **Phase 3+**: User Stories in priority order (P1, P2, P3...)
   - Within each story: Tests (if requested) → Models → Services → Endpoints → Integration
   - Each phase should be a complete, independently testable increment
- **Final Phase**: Polish & Cross-Cutting Concerns
31
Code/Supervised_learning/resnet/.gitignore
vendored
Normal file
@ -0,0 +1,31 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
*.egg
*.egg-info/
dist/
build/
.venv/
venv/
env/
ENV/
*.log
.env
.env.*
models/*.pth
models/*.pt
data/images/
.pytest_cache/
.coverage
htmlcov/
*.swp
*.swo
*~
.DS_Store
Thumbs.db
.vscode/
.idea/
*.iml
39
Code/Supervised_learning/resnet/.specify/memory/constitution.md
Normal file
@ -0,0 +1,39 @@
<!--
Sync Impact Report:
- Version change: none → 1.0.0
- Added sections: Core Principles (5 new), Additional Standards, Development Workflow, Governance
- Templates requiring updates: ✅ updated .specify/templates/plan-template.md (Constitution Check section)
- Follow-up TODOs: RATIFICATION_DATE placeholder
-->
# ResNet AI Model Constitution

## Core Principles

### Code Quality and Modularity
All code must follow PEP 8 standards, include type hints and comprehensive docstrings, and be organized into modular, reusable components. Rationale: Enhances readability, maintainability, and collaboration in complex AI projects.

### Rigorous Testing Standards
Implement unit tests for all utility functions, integration tests for data pipelines, and model validation tests including cross-validation and performance benchmarks. Use pytest or an equivalent framework. Rationale: Ensures reliability and catches errors early in AI model development, where failures can have significant impacts.

### Reproducibility and Versioning
Version all code, data, and models using Git and DVC. Set random seeds for reproducibility in experiments. Document all dependencies and environments. Rationale: Critical for scientific validation and debugging in machine learning.

### Model Evaluation and Validation
Evaluate models on independent test sets with multiple metrics (accuracy, F1-score, AUC, etc.). Perform error analysis and bias checks. Rationale: Prevents overfitting and ensures models are fair and effective.

### Continuous Integration and Quality Gates
All changes must pass linting, unit tests, and integration tests in CI/CD pipelines. Model performance regressions must be flagged. Rationale: Maintains code and model quality over time.

## Additional Standards

Security: Protect sensitive data and comply with privacy regulations. Ethics: Conduct bias audits and ensure responsible AI practices. Performance: Optimize for computational efficiency and scalability.

## Development Workflow

Code reviews are mandatory for all PRs. Model changes require peer review of evaluation results. Use issue tracking for bugs and features.

## Governance

The constitution supersedes other practices. Amendments require documentation and approval from project maintainers. Compliance is verified in code reviews.

**Version**: 1.0.0 | **Ratified**: TODO(RATIFICATION_DATE): Original adoption date unknown. | **Last Amended**: 2025-11-04
148
Code/Supervised_learning/resnet/.specify/scripts/powershell/check-prerequisites.ps1
Normal file
@ -0,0 +1,148 @@
#!/usr/bin/env pwsh

# Consolidated prerequisite checking script (PowerShell)
#
# This script provides unified prerequisite checking for the Spec-Driven Development workflow.
# It replaces the functionality previously spread across multiple scripts.
#
# Usage: ./check-prerequisites.ps1 [OPTIONS]
#
# OPTIONS:
#   -Json               Output in JSON format
#   -RequireTasks       Require tasks.md to exist (for implementation phase)
#   -IncludeTasks       Include tasks.md in AVAILABLE_DOCS list
#   -PathsOnly          Only output path variables (no validation)
#   -Help, -h           Show help message

[CmdletBinding()]
param(
    [switch]$Json,
    [switch]$RequireTasks,
    [switch]$IncludeTasks,
    [switch]$PathsOnly,
    [switch]$Help
)

$ErrorActionPreference = 'Stop'

# Show help if requested
if ($Help) {
    Write-Output @"
Usage: check-prerequisites.ps1 [OPTIONS]

Consolidated prerequisite checking for Spec-Driven Development workflow.

OPTIONS:
  -Json               Output in JSON format
  -RequireTasks       Require tasks.md to exist (for implementation phase)
  -IncludeTasks       Include tasks.md in AVAILABLE_DOCS list
  -PathsOnly          Only output path variables (no prerequisite validation)
  -Help, -h           Show this help message

EXAMPLES:
  # Check task prerequisites (plan.md required)
  .\check-prerequisites.ps1 -Json

  # Check implementation prerequisites (plan.md + tasks.md required)
  .\check-prerequisites.ps1 -Json -RequireTasks -IncludeTasks

  # Get feature paths only (no validation)
  .\check-prerequisites.ps1 -PathsOnly

"@
    exit 0
}

# Source common functions
. "$PSScriptRoot/common.ps1"

# Get feature paths and validate branch
$paths = Get-FeaturePathsEnv

if (-not (Test-FeatureBranch -Branch $paths.CURRENT_BRANCH -HasGit:$paths.HAS_GIT)) {
    exit 1
}

# If paths-only mode, output paths and exit (supports combined -Json -PathsOnly)
if ($PathsOnly) {
    if ($Json) {
        [PSCustomObject]@{
            REPO_ROOT    = $paths.REPO_ROOT
            BRANCH       = $paths.CURRENT_BRANCH
            FEATURE_DIR  = $paths.FEATURE_DIR
            FEATURE_SPEC = $paths.FEATURE_SPEC
            IMPL_PLAN    = $paths.IMPL_PLAN
            TASKS        = $paths.TASKS
        } | ConvertTo-Json -Compress
    } else {
        Write-Output "REPO_ROOT: $($paths.REPO_ROOT)"
        Write-Output "BRANCH: $($paths.CURRENT_BRANCH)"
        Write-Output "FEATURE_DIR: $($paths.FEATURE_DIR)"
        Write-Output "FEATURE_SPEC: $($paths.FEATURE_SPEC)"
        Write-Output "IMPL_PLAN: $($paths.IMPL_PLAN)"
        Write-Output "TASKS: $($paths.TASKS)"
    }
    exit 0
}

# Validate required directories and files
if (-not (Test-Path $paths.FEATURE_DIR -PathType Container)) {
    Write-Output "ERROR: Feature directory not found: $($paths.FEATURE_DIR)"
    Write-Output "Run /speckit.specify first to create the feature structure."
    exit 1
}

if (-not (Test-Path $paths.IMPL_PLAN -PathType Leaf)) {
    Write-Output "ERROR: plan.md not found in $($paths.FEATURE_DIR)"
    Write-Output "Run /speckit.plan first to create the implementation plan."
    exit 1
}

# Check for tasks.md if required
if ($RequireTasks -and -not (Test-Path $paths.TASKS -PathType Leaf)) {
    Write-Output "ERROR: tasks.md not found in $($paths.FEATURE_DIR)"
    Write-Output "Run /speckit.tasks first to create the task list."
    exit 1
}

# Build list of available documents
$docs = @()

# Always check these optional docs
if (Test-Path $paths.RESEARCH) { $docs += 'research.md' }
if (Test-Path $paths.DATA_MODEL) { $docs += 'data-model.md' }

# Check contracts directory (only if it exists and has files)
if ((Test-Path $paths.CONTRACTS_DIR) -and (Get-ChildItem -Path $paths.CONTRACTS_DIR -ErrorAction SilentlyContinue | Select-Object -First 1)) {
    $docs += 'contracts/'
}

if (Test-Path $paths.QUICKSTART) { $docs += 'quickstart.md' }

# Include tasks.md if requested and it exists
if ($IncludeTasks -and (Test-Path $paths.TASKS)) {
    $docs += 'tasks.md'
}

# Output results
if ($Json) {
    # JSON output
    [PSCustomObject]@{
        FEATURE_DIR    = $paths.FEATURE_DIR
        AVAILABLE_DOCS = $docs
    } | ConvertTo-Json -Compress
} else {
    # Text output
    Write-Output "FEATURE_DIR:$($paths.FEATURE_DIR)"
    Write-Output "AVAILABLE_DOCS:"

    # Show status of each potential document
    Test-FileExists -Path $paths.RESEARCH -Description 'research.md' | Out-Null
    Test-FileExists -Path $paths.DATA_MODEL -Description 'data-model.md' | Out-Null
    Test-DirHasFiles -Path $paths.CONTRACTS_DIR -Description 'contracts/' | Out-Null
    Test-FileExists -Path $paths.QUICKSTART -Description 'quickstart.md' | Out-Null

    if ($IncludeTasks) {
        Test-FileExists -Path $paths.TASKS -Description 'tasks.md' | Out-Null
    }
}
137
Code/Supervised_learning/resnet/.specify/scripts/powershell/common.ps1
Normal file
@ -0,0 +1,137 @@
#!/usr/bin/env pwsh
# Common PowerShell functions analogous to common.sh

function Get-RepoRoot {
    try {
        $result = git rev-parse --show-toplevel 2>$null
        if ($LASTEXITCODE -eq 0) {
            return $result
        }
    } catch {
        # Git command failed
    }

    # Fall back to script location for non-git repos
    return (Resolve-Path (Join-Path $PSScriptRoot "../../..")).Path
}

function Get-CurrentBranch {
    # First check if the SPECIFY_FEATURE environment variable is set
    if ($env:SPECIFY_FEATURE) {
        return $env:SPECIFY_FEATURE
    }

    # Then check git if available
    try {
        $result = git rev-parse --abbrev-ref HEAD 2>$null
        if ($LASTEXITCODE -eq 0) {
            return $result
        }
    } catch {
        # Git command failed
    }

    # For non-git repos, try to find the latest feature directory
    $repoRoot = Get-RepoRoot
    $specsDir = Join-Path $repoRoot "specs"

    if (Test-Path $specsDir) {
        $latestFeature = ""
        $highest = 0

        Get-ChildItem -Path $specsDir -Directory | ForEach-Object {
            if ($_.Name -match '^(\d{3})-') {
                $num = [int]$matches[1]
                if ($num -gt $highest) {
                    $highest = $num
                    $latestFeature = $_.Name
                }
            }
        }

        if ($latestFeature) {
            return $latestFeature
        }
    }

    # Final fallback
    return "main"
}

function Test-HasGit {
    try {
        git rev-parse --show-toplevel 2>$null | Out-Null
        return ($LASTEXITCODE -eq 0)
    } catch {
        return $false
    }
}

function Test-FeatureBranch {
    param(
        [string]$Branch,
        [bool]$HasGit = $true
    )

    # For non-git repos, we can't enforce branch naming but still provide output
    if (-not $HasGit) {
        Write-Warning "[specify] Warning: Git repository not detected; skipped branch validation"
        return $true
    }

    if ($Branch -notmatch '^[0-9]{3}-') {
        Write-Output "ERROR: Not on a feature branch. Current branch: $Branch"
        Write-Output "Feature branches should be named like: 001-feature-name"
        return $false
    }
    return $true
}

function Get-FeatureDir {
    param([string]$RepoRoot, [string]$Branch)
    Join-Path $RepoRoot "specs/$Branch"
}

function Get-FeaturePathsEnv {
    $repoRoot = Get-RepoRoot
    $currentBranch = Get-CurrentBranch
    $hasGit = Test-HasGit
    $featureDir = Get-FeatureDir -RepoRoot $repoRoot -Branch $currentBranch

    [PSCustomObject]@{
        REPO_ROOT      = $repoRoot
        CURRENT_BRANCH = $currentBranch
        HAS_GIT        = $hasGit
        FEATURE_DIR    = $featureDir
        FEATURE_SPEC   = Join-Path $featureDir 'spec.md'
        IMPL_PLAN      = Join-Path $featureDir 'plan.md'
        TASKS          = Join-Path $featureDir 'tasks.md'
        RESEARCH       = Join-Path $featureDir 'research.md'
        DATA_MODEL     = Join-Path $featureDir 'data-model.md'
        QUICKSTART     = Join-Path $featureDir 'quickstart.md'
        CONTRACTS_DIR  = Join-Path $featureDir 'contracts'
    }
}
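
# Usage example (illustrative annotation, not part of the original file): dot-source
# this module, then resolve all paths for the current feature branch.
#   . ./common.ps1
#   $paths = Get-FeaturePathsEnv
#   $paths.FEATURE_SPEC   # e.g. <repo>/specs/001-user-auth/spec.md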

function Test-FileExists {
    param([string]$Path, [string]$Description)
    if (Test-Path -Path $Path -PathType Leaf) {
        Write-Output "  ✓ $Description"
        return $true
    } else {
        Write-Output "  ✗ $Description"
        return $false
    }
}

function Test-DirHasFiles {
    param([string]$Path, [string]$Description)
    if ((Test-Path -Path $Path -PathType Container) -and (Get-ChildItem -Path $Path -ErrorAction SilentlyContinue | Where-Object { -not $_.PSIsContainer } | Select-Object -First 1)) {
        Write-Output "  ✓ $Description"
        return $true
    } else {
        Write-Output "  ✗ $Description"
        return $false
    }
}
290
Code/Supervised_learning/resnet/.specify/scripts/powershell/create-new-feature.ps1
Normal file
@ -0,0 +1,290 @@
#!/usr/bin/env pwsh
# Create a new feature
[CmdletBinding()]
param(
    [switch]$Json,
    [string]$ShortName,
    [int]$Number = 0,
    [switch]$Help,
    [Parameter(ValueFromRemainingArguments = $true)]
    [string[]]$FeatureDescription
)
$ErrorActionPreference = 'Stop'

# Show help if requested
if ($Help) {
    Write-Host "Usage: ./create-new-feature.ps1 [-Json] [-ShortName <name>] [-Number N] <feature description>"
    Write-Host ""
    Write-Host "Options:"
    Write-Host "  -Json              Output in JSON format"
    Write-Host "  -ShortName <name>  Provide a custom short name (2-4 words) for the branch"
    Write-Host "  -Number N          Specify branch number manually (overrides auto-detection)"
    Write-Host "  -Help              Show this help message"
    Write-Host ""
    Write-Host "Examples:"
    Write-Host "  ./create-new-feature.ps1 'Add user authentication system' -ShortName 'user-auth'"
    Write-Host "  ./create-new-feature.ps1 'Implement OAuth2 integration for API'"
    exit 0
}

# Check if a feature description was provided
if (-not $FeatureDescription -or $FeatureDescription.Count -eq 0) {
    Write-Error "Usage: ./create-new-feature.ps1 [-Json] [-ShortName <name>] <feature description>"
    exit 1
}

$featureDesc = ($FeatureDescription -join ' ').Trim()

# Resolve repository root. Prefer git information when available, but fall back
# to searching for repository markers so the workflow still functions in repositories
# that were initialized with --no-git.
function Find-RepositoryRoot {
    param(
        [string]$StartDir,
        [string[]]$Markers = @('.git', '.specify')
    )
    $current = Resolve-Path $StartDir
    while ($true) {
        foreach ($marker in $Markers) {
            if (Test-Path (Join-Path $current $marker)) {
                return $current
            }
        }
        $parent = Split-Path $current -Parent
        if ($parent -eq $current) {
            # Reached filesystem root without finding markers
            return $null
        }
        $current = $parent
    }
}

function Get-NextBranchNumber {
    param(
        [string]$ShortName,
        [string]$SpecsDir
    )

    # Fetch all remotes to get latest branch info (suppress errors if no remotes)
    try {
        git fetch --all --prune 2>$null | Out-Null
    } catch {
        # Ignore fetch errors
    }

    # Find remote branches matching the pattern using git ls-remote
    $remoteBranches = @()
    try {
        $remoteRefs = git ls-remote --heads origin 2>$null
        if ($remoteRefs) {
            $remoteBranches = $remoteRefs | Where-Object { $_ -match "refs/heads/(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
                if ($_ -match "refs/heads/(\d+)-") {
                    [int]$matches[1]
                }
            }
        }
    } catch {
        # Ignore errors
    }

    # Check local branches
    $localBranches = @()
    try {
        $allBranches = git branch 2>$null
        if ($allBranches) {
            $localBranches = $allBranches | Where-Object { $_ -match "^\*?\s*(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
                if ($_ -match "(\d+)-") {
                    [int]$matches[1]
                }
            }
        }
    } catch {
        # Ignore errors
    }

    # Check specs directory
    $specDirs = @()
    if (Test-Path $SpecsDir) {
        try {
            $specDirs = Get-ChildItem -Path $SpecsDir -Directory | Where-Object { $_.Name -match "^(\d+)-$([regex]::Escape($ShortName))$" } | ForEach-Object {
                if ($_.Name -match "^(\d+)-") {
                    [int]$matches[1]
                }
            }
        } catch {
            # Ignore errors
        }
    }

    # Combine all sources and get the highest number
    $maxNum = 0
    foreach ($num in ($remoteBranches + $localBranches + $specDirs)) {
        if ($num -gt $maxNum) {
            $maxNum = $num
        }
    }

    # Return next number
    return $maxNum + 1
}
$fallbackRoot = (Find-RepositoryRoot -StartDir $PSScriptRoot)
|
||||
if (-not $fallbackRoot) {
|
||||
Write-Error "Error: Could not determine repository root. Please run this script from within the repository."
|
||||
exit 1
|
||||
}
|
||||
|
||||
try {
|
||||
$repoRoot = git rev-parse --show-toplevel 2>$null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
$hasGit = $true
|
||||
} else {
|
||||
throw "Git not available"
|
||||
}
|
||||
} catch {
|
||||
$repoRoot = $fallbackRoot
|
||||
$hasGit = $false
|
||||
}
|
||||
|
||||
Set-Location $repoRoot
|
||||
|
||||
$specsDir = Join-Path $repoRoot 'specs'
|
||||
New-Item -ItemType Directory -Path $specsDir -Force | Out-Null
|
||||
|
||||
# Function to generate branch name with stop word filtering and length filtering
|
||||
function Get-BranchName {
|
||||
param([string]$Description)
|
||||
|
||||
# Common stop words to filter out
|
||||
$stopWords = @(
|
||||
'i', 'a', 'an', 'the', 'to', 'for', 'of', 'in', 'on', 'at', 'by', 'with', 'from',
|
||||
'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had',
|
||||
'do', 'does', 'did', 'will', 'would', 'should', 'could', 'can', 'may', 'might', 'must', 'shall',
|
||||
'this', 'that', 'these', 'those', 'my', 'your', 'our', 'their',
|
||||
'want', 'need', 'add', 'get', 'set'
|
||||
)
|
||||
|
||||
# Convert to lowercase and extract words (alphanumeric only)
|
||||
$cleanName = $Description.ToLower() -replace '[^a-z0-9\s]', ' '
|
||||
$words = $cleanName -split '\s+' | Where-Object { $_ }
|
||||
|
||||
# Filter words: remove stop words and words shorter than 3 chars (unless they're uppercase acronyms in original)
|
||||
$meaningfulWords = @()
|
||||
foreach ($word in $words) {
|
||||
# Skip stop words
|
||||
if ($stopWords -contains $word) { continue }
|
||||
|
||||
# Keep words that are length >= 3 OR appear as uppercase in original (likely acronyms)
|
||||
if ($word.Length -ge 3) {
|
||||
$meaningfulWords += $word
|
||||
} elseif ($Description -match "\b$($word.ToUpper())\b") {
|
||||
# Keep short words if they appear as uppercase in original (likely acronyms)
|
||||
$meaningfulWords += $word
|
||||
}
|
||||
}
|
||||
|
||||
# If we have meaningful words, use first 3-4 of them
|
||||
if ($meaningfulWords.Count -gt 0) {
|
||||
$maxWords = if ($meaningfulWords.Count -eq 4) { 4 } else { 3 }
|
||||
$result = ($meaningfulWords | Select-Object -First $maxWords) -join '-'
|
||||
return $result
|
||||
} else {
|
||||
# Fallback to original logic if no meaningful words found
|
||||
$result = $Description.ToLower() -replace '[^a-z0-9]', '-' -replace '-{2,}', '-' -replace '^-', '' -replace '-$', ''
|
||||
$fallbackWords = ($result -split '-') | Where-Object { $_ } | Select-Object -First 3
|
||||
return [string]::Join('-', $fallbackWords)
|
||||
}
|
||||
}
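
# Example (illustrative): Get-BranchName -Description 'I want to add user authentication'
# drops the stop words 'i', 'want', 'to', and 'add' and returns 'user-authentication'.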

# Generate branch name
if ($ShortName) {
    # Use provided short name, just clean it up
    $branchSuffix = $ShortName.ToLower() -replace '[^a-z0-9]', '-' -replace '-{2,}', '-' -replace '^-', '' -replace '-$', ''
} else {
    # Generate from description with smart filtering
    $branchSuffix = Get-BranchName -Description $featureDesc
}

# Determine branch number
if ($Number -eq 0) {
    if ($hasGit) {
        # Check existing branches on remotes
        $Number = Get-NextBranchNumber -ShortName $branchSuffix -SpecsDir $specsDir
    } else {
        # Fall back to local directory check
        $highest = 0
        if (Test-Path $specsDir) {
            Get-ChildItem -Path $specsDir -Directory | ForEach-Object {
                if ($_.Name -match '^(\d{3})') {
                    $num = [int]$matches[1]
                    if ($num -gt $highest) { $highest = $num }
                }
            }
        }
        $Number = $highest + 1
    }
}

$featureNum = ('{0:000}' -f $Number)
$branchName = "$featureNum-$branchSuffix"

# GitHub enforces a 244-byte limit on branch names
# Validate and truncate if necessary
$maxBranchLength = 244
if ($branchName.Length -gt $maxBranchLength) {
    # Calculate how much we need to trim from suffix
    # Account for: feature number (3) + hyphen (1) = 4 chars
    $maxSuffixLength = $maxBranchLength - 4

    # Truncate suffix
    $truncatedSuffix = $branchSuffix.Substring(0, [Math]::Min($branchSuffix.Length, $maxSuffixLength))
    # Remove trailing hyphen if truncation created one
    $truncatedSuffix = $truncatedSuffix -replace '-$', ''

    $originalBranchName = $branchName
    $branchName = "$featureNum-$truncatedSuffix"

    Write-Warning "[specify] Branch name exceeded GitHub's 244-byte limit"
    Write-Warning "[specify] Original: $originalBranchName ($($originalBranchName.Length) bytes)"
    Write-Warning "[specify] Truncated to: $branchName ($($branchName.Length) bytes)"
}

if ($hasGit) {
    try {
        git checkout -b $branchName | Out-Null
    } catch {
        Write-Warning "Failed to create git branch: $branchName"
    }
} else {
    Write-Warning "[specify] Warning: Git repository not detected; skipped branch creation for $branchName"
}

$featureDir = Join-Path $specsDir $branchName
New-Item -ItemType Directory -Path $featureDir -Force | Out-Null

$template = Join-Path $repoRoot '.specify/templates/spec-template.md'
$specFile = Join-Path $featureDir 'spec.md'
if (Test-Path $template) {
    Copy-Item $template $specFile -Force
} else {
    New-Item -ItemType File -Path $specFile | Out-Null
}

# Set the SPECIFY_FEATURE environment variable for the current session
$env:SPECIFY_FEATURE = $branchName

if ($Json) {
    $obj = [PSCustomObject]@{
        BRANCH_NAME = $branchName
        SPEC_FILE = $specFile
        FEATURE_NUM = $featureNum
        HAS_GIT = $hasGit
    }
    $obj | ConvertTo-Json -Compress
} else {
    Write-Output "BRANCH_NAME: $branchName"
    Write-Output "SPEC_FILE: $specFile"
    Write-Output "FEATURE_NUM: $featureNum"
    Write-Output "HAS_GIT: $hasGit"
    Write-Output "SPECIFY_FEATURE environment variable set to: $branchName"
}

@ -0,0 +1,62 @@
#!/usr/bin/env pwsh
# Setup implementation plan for a feature

[CmdletBinding()]
param(
    [switch]$Json,
    [switch]$Help
)

$ErrorActionPreference = 'Stop'

# Show help if requested
if ($Help) {
    Write-Output "Usage: ./setup-plan.ps1 [-Json] [-Help]"
    Write-Output " -Json Output results in JSON format"
    Write-Output " -Help Show this help message"
    exit 0
}

# Load common functions
. "$PSScriptRoot/common.ps1"

# Get all paths and variables from common functions
$paths = Get-FeaturePathsEnv

# Check if we're on a proper feature branch (only for git repos)
if (-not (Test-FeatureBranch -Branch $paths.CURRENT_BRANCH -HasGit $paths.HAS_GIT)) {
    exit 1
}

# Ensure the feature directory exists
New-Item -ItemType Directory -Path $paths.FEATURE_DIR -Force | Out-Null

# Copy plan template if it exists, otherwise note it or create empty file
$template = Join-Path $paths.REPO_ROOT '.specify/templates/plan-template.md'
if (Test-Path $template) {
    Copy-Item $template $paths.IMPL_PLAN -Force
    Write-Output "Copied plan template to $($paths.IMPL_PLAN)"
} else {
    Write-Warning "Plan template not found at $template"
    # Create a basic plan file if template doesn't exist
    New-Item -ItemType File -Path $paths.IMPL_PLAN -Force | Out-Null
}

# Output results
if ($Json) {
    $result = [PSCustomObject]@{
        FEATURE_SPEC = $paths.FEATURE_SPEC
        IMPL_PLAN = $paths.IMPL_PLAN
        SPECS_DIR = $paths.FEATURE_DIR
        BRANCH = $paths.CURRENT_BRANCH
        HAS_GIT = $paths.HAS_GIT
    }
    $result | ConvertTo-Json -Compress
} else {
    Write-Output "FEATURE_SPEC: $($paths.FEATURE_SPEC)"
    Write-Output "IMPL_PLAN: $($paths.IMPL_PLAN)"
    Write-Output "SPECS_DIR: $($paths.FEATURE_DIR)"
    Write-Output "BRANCH: $($paths.CURRENT_BRANCH)"
    Write-Output "HAS_GIT: $($paths.HAS_GIT)"
}

@ -0,0 +1,439 @@
#!/usr/bin/env pwsh
<#!
.SYNOPSIS
Update agent context files with information from plan.md (PowerShell version)

.DESCRIPTION
Mirrors the behavior of scripts/bash/update-agent-context.sh:
1. Environment Validation
2. Plan Data Extraction
3. Agent File Management (create from template or update existing)
4. Content Generation (technology stack, recent changes, timestamp)
5. Multi-Agent Support (claude, gemini, copilot, cursor-agent, qwen, opencode, codex, windsurf, kilocode, auggie, roo, amp, q)

.PARAMETER AgentType
Optional agent key to update a single agent. If omitted, updates all existing agent files (creating a default Claude file if none exist).

.EXAMPLE
./update-agent-context.ps1 -AgentType claude

.EXAMPLE
./update-agent-context.ps1  # Updates all existing agent files

.NOTES
Relies on common helper functions in common.ps1
#>
param(
    [Parameter(Position=0)]
    [ValidateSet('claude','gemini','copilot','cursor-agent','qwen','opencode','codex','windsurf','kilocode','auggie','roo','codebuddy','amp','q')]
    [string]$AgentType
)

$ErrorActionPreference = 'Stop'

# Import common helpers
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
. (Join-Path $ScriptDir 'common.ps1')

# Acquire environment paths
$envData = Get-FeaturePathsEnv
$REPO_ROOT = $envData.REPO_ROOT
$CURRENT_BRANCH = $envData.CURRENT_BRANCH
$HAS_GIT = $envData.HAS_GIT
$IMPL_PLAN = $envData.IMPL_PLAN
$NEW_PLAN = $IMPL_PLAN

# Agent file paths
$CLAUDE_FILE = Join-Path $REPO_ROOT 'CLAUDE.md'
$GEMINI_FILE = Join-Path $REPO_ROOT 'GEMINI.md'
$COPILOT_FILE = Join-Path $REPO_ROOT '.github/copilot-instructions.md'
$CURSOR_FILE = Join-Path $REPO_ROOT '.cursor/rules/specify-rules.mdc'
$QWEN_FILE = Join-Path $REPO_ROOT 'QWEN.md'
$AGENTS_FILE = Join-Path $REPO_ROOT 'AGENTS.md'
$WINDSURF_FILE = Join-Path $REPO_ROOT '.windsurf/rules/specify-rules.md'
$KILOCODE_FILE = Join-Path $REPO_ROOT '.kilocode/rules/specify-rules.md'
$AUGGIE_FILE = Join-Path $REPO_ROOT '.augment/rules/specify-rules.md'
$ROO_FILE = Join-Path $REPO_ROOT '.roo/rules/specify-rules.md'
$CODEBUDDY_FILE = Join-Path $REPO_ROOT 'CODEBUDDY.md'
$AMP_FILE = Join-Path $REPO_ROOT 'AGENTS.md'
$Q_FILE = Join-Path $REPO_ROOT 'AGENTS.md'

$TEMPLATE_FILE = Join-Path $REPO_ROOT '.specify/templates/agent-file-template.md'

# Parsed plan data placeholders
$script:NEW_LANG = ''
$script:NEW_FRAMEWORK = ''
$script:NEW_DB = ''
$script:NEW_PROJECT_TYPE = ''

function Write-Info {
    param(
        [Parameter(Mandatory=$true)]
        [string]$Message
    )
    Write-Host "INFO: $Message"
}

function Write-Success {
    param(
        [Parameter(Mandatory=$true)]
        [string]$Message
    )
    Write-Host "$([char]0x2713) $Message"
}

function Write-WarningMsg {
    param(
        [Parameter(Mandatory=$true)]
        [string]$Message
    )
    Write-Warning $Message
}

function Write-Err {
    param(
        [Parameter(Mandatory=$true)]
        [string]$Message
    )
    Write-Host "ERROR: $Message" -ForegroundColor Red
}

function Validate-Environment {
    if (-not $CURRENT_BRANCH) {
        Write-Err 'Unable to determine current feature'
        if ($HAS_GIT) { Write-Info "Make sure you're on a feature branch" } else { Write-Info 'Set SPECIFY_FEATURE environment variable or create a feature first' }
        exit 1
    }
    if (-not (Test-Path $NEW_PLAN)) {
        Write-Err "No plan.md found at $NEW_PLAN"
        Write-Info 'Ensure you are working on a feature with a corresponding spec directory'
        if (-not $HAS_GIT) { Write-Info 'Use: $env:SPECIFY_FEATURE=your-feature-name or create a new feature first' }
        exit 1
    }
    if (-not (Test-Path $TEMPLATE_FILE)) {
        Write-Err "Template file not found at $TEMPLATE_FILE"
        Write-Info 'Run specify init to scaffold .specify/templates, or add agent-file-template.md there.'
        exit 1
    }
}

function Extract-PlanField {
    param(
        [Parameter(Mandatory=$true)]
        [string]$FieldPattern,
        [Parameter(Mandatory=$true)]
        [string]$PlanFile
    )
    if (-not (Test-Path $PlanFile)) { return '' }
    # Lines like **Language/Version**: Python 3.12
    $regex = "^\*\*$([Regex]::Escape($FieldPattern))\*\*: (.+)$"
    Get-Content -LiteralPath $PlanFile -Encoding utf8 | ForEach-Object {
        if ($_ -match $regex) {
            $val = $Matches[1].Trim()
            if ($val -notin @('NEEDS CLARIFICATION','N/A')) { return $val }
        }
    } | Select-Object -First 1
}
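
# Example (illustrative): given a plan.md line '**Language/Version**: Python 3.12',
# Extract-PlanField -FieldPattern 'Language/Version' -PlanFile $NEW_PLAN returns 'Python 3.12';
# placeholder values 'NEEDS CLARIFICATION' and 'N/A' are skipped.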

function Parse-PlanData {
    param(
        [Parameter(Mandatory=$true)]
        [string]$PlanFile
    )
    if (-not (Test-Path $PlanFile)) { Write-Err "Plan file not found: $PlanFile"; return $false }
    Write-Info "Parsing plan data from $PlanFile"
    $script:NEW_LANG = Extract-PlanField -FieldPattern 'Language/Version' -PlanFile $PlanFile
    $script:NEW_FRAMEWORK = Extract-PlanField -FieldPattern 'Primary Dependencies' -PlanFile $PlanFile
    $script:NEW_DB = Extract-PlanField -FieldPattern 'Storage' -PlanFile $PlanFile
    $script:NEW_PROJECT_TYPE = Extract-PlanField -FieldPattern 'Project Type' -PlanFile $PlanFile

    if ($NEW_LANG) { Write-Info "Found language: $NEW_LANG" } else { Write-WarningMsg 'No language information found in plan' }
    if ($NEW_FRAMEWORK) { Write-Info "Found framework: $NEW_FRAMEWORK" }
    if ($NEW_DB -and $NEW_DB -ne 'N/A') { Write-Info "Found database: $NEW_DB" }
    if ($NEW_PROJECT_TYPE) { Write-Info "Found project type: $NEW_PROJECT_TYPE" }
    return $true
}

function Format-TechnologyStack {
    param(
        [Parameter(Mandatory=$false)]
        [string]$Lang,
        [Parameter(Mandatory=$false)]
        [string]$Framework
    )
    $parts = @()
    if ($Lang -and $Lang -ne 'NEEDS CLARIFICATION') { $parts += $Lang }
    if ($Framework -and $Framework -notin @('NEEDS CLARIFICATION','N/A')) { $parts += $Framework }
    if (-not $parts) { return '' }
    return ($parts -join ' + ')
}
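
# Example (illustrative): Format-TechnologyStack -Lang 'Python 3.11' -Framework 'FastAPI'
# returns 'Python 3.11 + FastAPI'; placeholder values are filtered out.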

function Get-ProjectStructure {
    param(
        [Parameter(Mandatory=$false)]
        [string]$ProjectType
    )
    if ($ProjectType -match 'web') { return "backend/`nfrontend/`ntests/" } else { return "src/`ntests/" }
}

function Get-CommandsForLanguage {
    param(
        [Parameter(Mandatory=$false)]
        [string]$Lang
    )
    switch -Regex ($Lang) {
        'Python' { return "cd src; pytest; ruff check ." }
        'Rust' { return "cargo test; cargo clippy" }
        'JavaScript|TypeScript' { return "npm test; npm run lint" }
        default { return "# Add commands for $Lang" }
    }
}
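
# Example (illustrative): Get-CommandsForLanguage -Lang 'Python 3.11' matches the 'Python'
# regex branch and returns "cd src; pytest; ruff check .".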

function Get-LanguageConventions {
    param(
        [Parameter(Mandatory=$false)]
        [string]$Lang
    )
    if ($Lang) { "${Lang}: Follow standard conventions" } else { 'General: Follow standard conventions' }
}

function New-AgentFile {
    param(
        [Parameter(Mandatory=$true)]
        [string]$TargetFile,
        [Parameter(Mandatory=$true)]
        [string]$ProjectName,
        [Parameter(Mandatory=$true)]
        [datetime]$Date
    )
    if (-not (Test-Path $TEMPLATE_FILE)) { Write-Err "Template not found at $TEMPLATE_FILE"; return $false }
    $temp = New-TemporaryFile
    Copy-Item -LiteralPath $TEMPLATE_FILE -Destination $temp -Force

    $projectStructure = Get-ProjectStructure -ProjectType $NEW_PROJECT_TYPE
    $commands = Get-CommandsForLanguage -Lang $NEW_LANG
    $languageConventions = Get-LanguageConventions -Lang $NEW_LANG

    $escaped_lang = $NEW_LANG
    $escaped_framework = $NEW_FRAMEWORK
    $escaped_branch = $CURRENT_BRANCH

    $content = Get-Content -LiteralPath $temp -Raw -Encoding utf8
    $content = $content -replace '\[PROJECT NAME\]',$ProjectName
    $content = $content -replace '\[DATE\]',$Date.ToString('yyyy-MM-dd')

    # Build the technology stack string safely
    $techStackForTemplate = ""
    if ($escaped_lang -and $escaped_framework) {
        $techStackForTemplate = "- $escaped_lang + $escaped_framework ($escaped_branch)"
    } elseif ($escaped_lang) {
        $techStackForTemplate = "- $escaped_lang ($escaped_branch)"
    } elseif ($escaped_framework) {
        $techStackForTemplate = "- $escaped_framework ($escaped_branch)"
    }

    $content = $content -replace '\[EXTRACTED FROM ALL PLAN.MD FILES\]',$techStackForTemplate
    # For project structure we manually embed (keep newlines)
    $escapedStructure = [Regex]::Escape($projectStructure)
    $content = $content -replace '\[ACTUAL STRUCTURE FROM PLANS\]',$escapedStructure
    # Replace escaped newlines placeholder after all replacements
    $content = $content -replace '\[ONLY COMMANDS FOR ACTIVE TECHNOLOGIES\]',$commands
    $content = $content -replace '\[LANGUAGE-SPECIFIC, ONLY FOR LANGUAGES IN USE\]',$languageConventions

    # Build the recent changes string safely
    $recentChangesForTemplate = ""
    if ($escaped_lang -and $escaped_framework) {
        $recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_lang} + ${escaped_framework}"
    } elseif ($escaped_lang) {
        $recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_lang}"
    } elseif ($escaped_framework) {
        $recentChangesForTemplate = "- ${escaped_branch}: Added ${escaped_framework}"
    }

    $content = $content -replace '\[LAST 3 FEATURES AND WHAT THEY ADDED\]',$recentChangesForTemplate
    # Convert literal \n sequences introduced by Escape to real newlines
    $content = $content -replace '\\n',[Environment]::NewLine

    $parent = Split-Path -Parent $TargetFile
    if (-not (Test-Path $parent)) { New-Item -ItemType Directory -Path $parent | Out-Null }
    Set-Content -LiteralPath $TargetFile -Value $content -NoNewline -Encoding utf8
    Remove-Item $temp -Force
    return $true
}
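
# Example (illustrative): for a plan declaring Python 3.11 + FastAPI on branch 001-demo,
# the placeholder '[EXTRACTED FROM ALL PLAN.MD FILES]' in the template is replaced with
# '- Python 3.11 + FastAPI (001-demo)'.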

function Update-ExistingAgentFile {
    param(
        [Parameter(Mandatory=$true)]
        [string]$TargetFile,
        [Parameter(Mandatory=$true)]
        [datetime]$Date
    )
    if (-not (Test-Path $TargetFile)) { return (New-AgentFile -TargetFile $TargetFile -ProjectName (Split-Path $REPO_ROOT -Leaf) -Date $Date) }

    $techStack = Format-TechnologyStack -Lang $NEW_LANG -Framework $NEW_FRAMEWORK
    $newTechEntries = @()
    if ($techStack) {
        $escapedTechStack = [Regex]::Escape($techStack)
        if (-not (Select-String -Pattern $escapedTechStack -Path $TargetFile -Quiet)) {
            $newTechEntries += "- $techStack ($CURRENT_BRANCH)"
        }
    }
    if ($NEW_DB -and $NEW_DB -notin @('N/A','NEEDS CLARIFICATION')) {
        $escapedDB = [Regex]::Escape($NEW_DB)
        if (-not (Select-String -Pattern $escapedDB -Path $TargetFile -Quiet)) {
            $newTechEntries += "- $NEW_DB ($CURRENT_BRANCH)"
        }
    }
    $newChangeEntry = ''
    if ($techStack) { $newChangeEntry = "- ${CURRENT_BRANCH}: Added ${techStack}" }
    elseif ($NEW_DB -and $NEW_DB -notin @('N/A','NEEDS CLARIFICATION')) { $newChangeEntry = "- ${CURRENT_BRANCH}: Added ${NEW_DB}" }

    $lines = Get-Content -LiteralPath $TargetFile -Encoding utf8
    $output = New-Object System.Collections.Generic.List[string]
    $inTech = $false; $inChanges = $false; $techAdded = $false; $changeAdded = $false; $existingChanges = 0

    for ($i=0; $i -lt $lines.Count; $i++) {
        $line = $lines[$i]
        if ($line -eq '## Active Technologies') {
            $output.Add($line)
            $inTech = $true
            continue
        }
        if ($inTech -and $line -match '^##\s') {
            if (-not $techAdded -and $newTechEntries.Count -gt 0) { $newTechEntries | ForEach-Object { $output.Add($_) }; $techAdded = $true }
            $output.Add($line); $inTech = $false; continue
        }
        if ($inTech -and [string]::IsNullOrWhiteSpace($line)) {
            if (-not $techAdded -and $newTechEntries.Count -gt 0) { $newTechEntries | ForEach-Object { $output.Add($_) }; $techAdded = $true }
            $output.Add($line); continue
        }
        if ($line -eq '## Recent Changes') {
            $output.Add($line)
            if ($newChangeEntry) { $output.Add($newChangeEntry); $changeAdded = $true }
            $inChanges = $true
            continue
        }
        if ($inChanges -and $line -match '^##\s') { $output.Add($line); $inChanges = $false; continue }
        if ($inChanges -and $line -match '^- ') {
            if ($existingChanges -lt 2) { $output.Add($line); $existingChanges++ }
            continue
        }
        if ($line -match '\*\*Last updated\*\*: .*\d{4}-\d{2}-\d{2}') {
            $output.Add(($line -replace '\d{4}-\d{2}-\d{2}',$Date.ToString('yyyy-MM-dd')))
            continue
        }
        $output.Add($line)
    }

    # Post-loop check: if we're still in the Active Technologies section and haven't added new entries
    if ($inTech -and -not $techAdded -and $newTechEntries.Count -gt 0) {
        $newTechEntries | ForEach-Object { $output.Add($_) }
    }

    Set-Content -LiteralPath $TargetFile -Value ($output -join [Environment]::NewLine) -Encoding utf8
    return $true
}
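
# Example (illustrative): under '## Recent Changes' the new entry is inserted directly after
# the heading and only the two most recent existing '- ' entries are carried over, so the
# section holds at most three change lines after an update.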

function Update-AgentFile {
    param(
        [Parameter(Mandatory=$true)]
        [string]$TargetFile,
        [Parameter(Mandatory=$true)]
        [string]$AgentName
    )
    if (-not $TargetFile -or -not $AgentName) { Write-Err 'Update-AgentFile requires TargetFile and AgentName'; return $false }
    Write-Info "Updating $AgentName context file: $TargetFile"
    $projectName = Split-Path $REPO_ROOT -Leaf
    $date = Get-Date

    $dir = Split-Path -Parent $TargetFile
    if (-not (Test-Path $dir)) { New-Item -ItemType Directory -Path $dir | Out-Null }

    if (-not (Test-Path $TargetFile)) {
        if (New-AgentFile -TargetFile $TargetFile -ProjectName $projectName -Date $date) { Write-Success "Created new $AgentName context file" } else { Write-Err 'Failed to create new agent file'; return $false }
    } else {
        try {
            if (Update-ExistingAgentFile -TargetFile $TargetFile -Date $date) { Write-Success "Updated existing $AgentName context file" } else { Write-Err 'Failed to update agent file'; return $false }
        } catch {
            Write-Err "Cannot access or update existing file: $TargetFile. $_"
            return $false
        }
    }
    return $true
}

function Update-SpecificAgent {
    param(
        [Parameter(Mandatory=$true)]
        [string]$Type
    )
    switch ($Type) {
        'claude' { Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code' }
        'gemini' { Update-AgentFile -TargetFile $GEMINI_FILE -AgentName 'Gemini CLI' }
        'copilot' { Update-AgentFile -TargetFile $COPILOT_FILE -AgentName 'GitHub Copilot' }
        'cursor-agent' { Update-AgentFile -TargetFile $CURSOR_FILE -AgentName 'Cursor IDE' }
        'qwen' { Update-AgentFile -TargetFile $QWEN_FILE -AgentName 'Qwen Code' }
        'opencode' { Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'opencode' }
        'codex' { Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'Codex CLI' }
        'windsurf' { Update-AgentFile -TargetFile $WINDSURF_FILE -AgentName 'Windsurf' }
        'kilocode' { Update-AgentFile -TargetFile $KILOCODE_FILE -AgentName 'Kilo Code' }
        'auggie' { Update-AgentFile -TargetFile $AUGGIE_FILE -AgentName 'Auggie CLI' }
        'roo' { Update-AgentFile -TargetFile $ROO_FILE -AgentName 'Roo Code' }
        'codebuddy' { Update-AgentFile -TargetFile $CODEBUDDY_FILE -AgentName 'CodeBuddy CLI' }
        'amp' { Update-AgentFile -TargetFile $AMP_FILE -AgentName 'Amp' }
        'q' { Update-AgentFile -TargetFile $Q_FILE -AgentName 'Amazon Q Developer CLI' }
        default { Write-Err "Unknown agent type '$Type'"; Write-Err 'Expected: claude|gemini|copilot|cursor-agent|qwen|opencode|codex|windsurf|kilocode|auggie|roo|codebuddy|amp|q'; return $false }
    }
}

function Update-AllExistingAgents {
    $found = $false
    $ok = $true
    if (Test-Path $CLAUDE_FILE) { if (-not (Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code')) { $ok = $false }; $found = $true }
    if (Test-Path $GEMINI_FILE) { if (-not (Update-AgentFile -TargetFile $GEMINI_FILE -AgentName 'Gemini CLI')) { $ok = $false }; $found = $true }
    if (Test-Path $COPILOT_FILE) { if (-not (Update-AgentFile -TargetFile $COPILOT_FILE -AgentName 'GitHub Copilot')) { $ok = $false }; $found = $true }
    if (Test-Path $CURSOR_FILE) { if (-not (Update-AgentFile -TargetFile $CURSOR_FILE -AgentName 'Cursor IDE')) { $ok = $false }; $found = $true }
    if (Test-Path $QWEN_FILE) { if (-not (Update-AgentFile -TargetFile $QWEN_FILE -AgentName 'Qwen Code')) { $ok = $false }; $found = $true }
    if (Test-Path $AGENTS_FILE) { if (-not (Update-AgentFile -TargetFile $AGENTS_FILE -AgentName 'Codex/opencode')) { $ok = $false }; $found = $true }
    if (Test-Path $WINDSURF_FILE) { if (-not (Update-AgentFile -TargetFile $WINDSURF_FILE -AgentName 'Windsurf')) { $ok = $false }; $found = $true }
    if (Test-Path $KILOCODE_FILE) { if (-not (Update-AgentFile -TargetFile $KILOCODE_FILE -AgentName 'Kilo Code')) { $ok = $false }; $found = $true }
    if (Test-Path $AUGGIE_FILE) { if (-not (Update-AgentFile -TargetFile $AUGGIE_FILE -AgentName 'Auggie CLI')) { $ok = $false }; $found = $true }
    if (Test-Path $ROO_FILE) { if (-not (Update-AgentFile -TargetFile $ROO_FILE -AgentName 'Roo Code')) { $ok = $false }; $found = $true }
    if (Test-Path $CODEBUDDY_FILE) { if (-not (Update-AgentFile -TargetFile $CODEBUDDY_FILE -AgentName 'CodeBuddy CLI')) { $ok = $false }; $found = $true }
    if (Test-Path $Q_FILE) { if (-not (Update-AgentFile -TargetFile $Q_FILE -AgentName 'Amazon Q Developer CLI')) { $ok = $false }; $found = $true }
    if (-not $found) {
        Write-Info 'No existing agent files found, creating default Claude file...'
        if (-not (Update-AgentFile -TargetFile $CLAUDE_FILE -AgentName 'Claude Code')) { $ok = $false }
    }
    return $ok
}

function Print-Summary {
    Write-Host ''
    Write-Info 'Summary of changes:'
    if ($NEW_LANG) { Write-Host " - Added language: $NEW_LANG" }
    if ($NEW_FRAMEWORK) { Write-Host " - Added framework: $NEW_FRAMEWORK" }
    if ($NEW_DB -and $NEW_DB -ne 'N/A') { Write-Host " - Added database: $NEW_DB" }
    Write-Host ''
    Write-Info 'Usage: ./update-agent-context.ps1 [-AgentType claude|gemini|copilot|cursor-agent|qwen|opencode|codex|windsurf|kilocode|auggie|roo|codebuddy|amp|q]'
}

function Main {
    Validate-Environment
    Write-Info "=== Updating agent context files for feature $CURRENT_BRANCH ==="
    if (-not (Parse-PlanData -PlanFile $NEW_PLAN)) { Write-Err 'Failed to parse plan data'; exit 1 }
    $success = $true
    if ($AgentType) {
        Write-Info "Updating specific agent: $AgentType"
        if (-not (Update-SpecificAgent -Type $AgentType)) { $success = $false }
    }
    else {
        Write-Info 'No agent specified, updating all existing agent files...'
        if (-not (Update-AllExistingAgents)) { $success = $false }
    }
    Print-Summary
    if ($success) { Write-Success 'Agent context update completed successfully'; exit 0 } else { Write-Err 'Agent context update completed with errors'; exit 1 }
}

Main

||||
@ -0,0 +1,28 @@
# [PROJECT NAME] Development Guidelines

Auto-generated from all feature plans. Last updated: [DATE]

## Active Technologies

[EXTRACTED FROM ALL PLAN.MD FILES]

## Project Structure

```text
[ACTUAL STRUCTURE FROM PLANS]
```

## Commands

[ONLY COMMANDS FOR ACTIVE TECHNOLOGIES]

## Code Style

[LANGUAGE-SPECIFIC, ONLY FOR LANGUAGES IN USE]

## Recent Changes

[LAST 3 FEATURES AND WHAT THEY ADDED]

<!-- MANUAL ADDITIONS START -->
<!-- MANUAL ADDITIONS END -->
@ -0,0 +1,40 @@
# [CHECKLIST TYPE] Checklist: [FEATURE NAME]

**Purpose**: [Brief description of what this checklist covers]
**Created**: [DATE]
**Feature**: [Link to spec.md or relevant documentation]

**Note**: This checklist is generated by the `/speckit.checklist` command based on feature context and requirements.

<!--
============================================================================
IMPORTANT: The checklist items below are SAMPLE ITEMS for illustration only.

The /speckit.checklist command MUST replace these with actual items based on:
- User's specific checklist request
- Feature requirements from spec.md
- Technical context from plan.md
- Implementation details from tasks.md

DO NOT keep these sample items in the generated checklist file.
============================================================================
-->

## [Category 1]

- [ ] CHK001 First checklist item with clear action
- [ ] CHK002 Second checklist item
- [ ] CHK003 Third checklist item

## [Category 2]

- [ ] CHK004 Another category item
- [ ] CHK005 Item with specific criteria
- [ ] CHK006 Final item in this category

## Notes

- Check items off as completed: `[x]`
- Add comments or findings inline
- Link to relevant resources or documentation
- Items are numbered sequentially for easy reference
@ -0,0 +1,108 @@
# Implementation Plan: [FEATURE]

**Branch**: `[###-feature-name]` | **Date**: [DATE] | **Spec**: [link]
**Input**: Feature specification from `/specs/[###-feature-name]/spec.md`

**Note**: This template is filled in by the `/speckit.plan` command. See `.specify/templates/commands/plan.md` for the execution workflow.

## Summary

[Extract from feature spec: primary requirement + technical approach from research]

## Technical Context

<!--
ACTION REQUIRED: Replace the content in this section with the technical details
for the project. The structure here is presented in advisory capacity to guide
the iteration process.
-->

**Language/Version**: [e.g., Python 3.11, Swift 5.9, Rust 1.75 or NEEDS CLARIFICATION]
**Primary Dependencies**: [e.g., FastAPI, UIKit, LLVM or NEEDS CLARIFICATION]
**Storage**: [if applicable, e.g., PostgreSQL, CoreData, files or N/A]
**Testing**: [e.g., pytest, XCTest, cargo test or NEEDS CLARIFICATION]
**Target Platform**: [e.g., Linux server, iOS 15+, WASM or NEEDS CLARIFICATION]
**Project Type**: [single/web/mobile - determines source structure]
**Performance Goals**: [domain-specific, e.g., 1000 req/s, 10k lines/sec, 60 fps or NEEDS CLARIFICATION]
**Constraints**: [domain-specific, e.g., <200ms p95, <100MB memory, offline-capable or NEEDS CLARIFICATION]
**Scale/Scope**: [domain-specific, e.g., 10k users, 1M LOC, 50 screens or NEEDS CLARIFICATION]

## Constitution Check

*GATE: Must pass before Phase 0 research. Re-check after Phase 1 design.*

- Code Quality and Modularity: Confirm adherence to PEP 8, type hints, docstrings, and modular design.
- Rigorous Testing Standards: Plan includes unit tests, integration tests, and model validation tests.
- Reproducibility and Versioning: Versioning strategy for code, data, and models; random seed management.
- Model Evaluation and Validation: Metrics and bias checks defined for model assessment.
- Continuous Integration and Quality Gates: CI/CD pipeline with linting, testing, and performance checks.

## Project Structure

### Documentation (this feature)

```text
specs/[###-feature]/
├── plan.md              # This file (/speckit.plan command output)
├── research.md          # Phase 0 output (/speckit.plan command)
├── data-model.md        # Phase 1 output (/speckit.plan command)
├── quickstart.md        # Phase 1 output (/speckit.plan command)
├── contracts/           # Phase 1 output (/speckit.plan command)
└── tasks.md             # Phase 2 output (/speckit.tasks command - NOT created by /speckit.plan)
```

### Source Code (repository root)
<!--
ACTION REQUIRED: Replace the placeholder tree below with the concrete layout
for this feature. Delete unused options and expand the chosen structure with
real paths (e.g., apps/admin, packages/something). The delivered plan must
not include Option labels.
-->

```text
# [REMOVE IF UNUSED] Option 1: Single project (DEFAULT)
src/
├── models/
├── services/
├── cli/
└── lib/

tests/
├── contract/
├── integration/
└── unit/

# [REMOVE IF UNUSED] Option 2: Web application (when "frontend" + "backend" detected)
backend/
├── src/
│   ├── models/
│   ├── services/
│   └── api/
└── tests/

frontend/
├── src/
│   ├── components/
│   ├── pages/
│   └── services/
└── tests/

# [REMOVE IF UNUSED] Option 3: Mobile + API (when "iOS/Android" detected)
api/
└── [same as backend above]

ios/ or android/
└── [platform-specific structure: feature modules, UI flows, platform tests]
```

**Structure Decision**: [Document the selected structure and reference the real
directories captured above]

## Complexity Tracking

> **Fill ONLY if Constitution Check has violations that must be justified**

| Violation | Why Needed | Simpler Alternative Rejected Because |
|-----------|------------|-------------------------------------|
| [e.g., 4th project] | [current need] | [why 3 projects insufficient] |
| [e.g., Repository pattern] | [specific problem] | [why direct DB access insufficient] |
@ -0,0 +1,115 @@
# Feature Specification: [FEATURE NAME]

**Feature Branch**: `[###-feature-name]`
**Created**: [DATE]
**Status**: Draft
**Input**: User description: "$ARGUMENTS"

## User Scenarios & Testing *(mandatory)*

<!--
IMPORTANT: User stories should be PRIORITIZED as user journeys ordered by importance.
Each user story/journey must be INDEPENDENTLY TESTABLE - meaning if you implement just ONE of them,
you should still have a viable MVP (Minimum Viable Product) that delivers value.

Assign priorities (P1, P2, P3, etc.) to each story, where P1 is the most critical.
Think of each story as a standalone slice of functionality that can be:
- Developed independently
- Tested independently
- Deployed independently
- Demonstrated to users independently
-->

### User Story 1 - [Brief Title] (Priority: P1)

[Describe this user journey in plain language]

**Why this priority**: [Explain the value and why it has this priority level]

**Independent Test**: [Describe how this can be tested independently - e.g., "Can be fully tested by [specific action] and delivers [specific value]"]

**Acceptance Scenarios**:

1. **Given** [initial state], **When** [action], **Then** [expected outcome]
2. **Given** [initial state], **When** [action], **Then** [expected outcome]

---

### User Story 2 - [Brief Title] (Priority: P2)

[Describe this user journey in plain language]

**Why this priority**: [Explain the value and why it has this priority level]

**Independent Test**: [Describe how this can be tested independently]

**Acceptance Scenarios**:

1. **Given** [initial state], **When** [action], **Then** [expected outcome]

---

### User Story 3 - [Brief Title] (Priority: P3)

[Describe this user journey in plain language]

**Why this priority**: [Explain the value and why it has this priority level]

**Independent Test**: [Describe how this can be tested independently]

**Acceptance Scenarios**:

1. **Given** [initial state], **When** [action], **Then** [expected outcome]

---

[Add more user stories as needed, each with an assigned priority]

### Edge Cases

<!--
ACTION REQUIRED: The content in this section represents placeholders.
Fill them out with the right edge cases.
-->

- What happens when [boundary condition]?
- How does system handle [error scenario]?

## Requirements *(mandatory)*

<!--
ACTION REQUIRED: The content in this section represents placeholders.
Fill them out with the right functional requirements.
-->

### Functional Requirements

- **FR-001**: System MUST [specific capability, e.g., "allow users to create accounts"]
- **FR-002**: System MUST [specific capability, e.g., "validate email addresses"]
- **FR-003**: Users MUST be able to [key interaction, e.g., "reset their password"]
- **FR-004**: System MUST [data requirement, e.g., "persist user preferences"]
- **FR-005**: System MUST [behavior, e.g., "log all security events"]

*Example of marking unclear requirements:*

- **FR-006**: System MUST authenticate users via [NEEDS CLARIFICATION: auth method not specified - email/password, SSO, OAuth?]
- **FR-007**: System MUST retain user data for [NEEDS CLARIFICATION: retention period not specified]

### Key Entities *(include if feature involves data)*

- **[Entity 1]**: [What it represents, key attributes without implementation]
- **[Entity 2]**: [What it represents, relationships to other entities]

## Success Criteria *(mandatory)*

<!--
ACTION REQUIRED: Define measurable success criteria.
These must be technology-agnostic and measurable.
-->

### Measurable Outcomes

- **SC-001**: [Measurable metric, e.g., "Users can complete account creation in under 2 minutes"]
- **SC-002**: [Measurable metric, e.g., "System handles 1000 concurrent users without degradation"]
- **SC-003**: [User satisfaction metric, e.g., "90% of users successfully complete primary task on first attempt"]
- **SC-004**: [Business metric, e.g., "Reduce support tickets related to [X] by 50%"]
@ -0,0 +1,251 @@
---

description: "Task list template for feature implementation"
---

# Tasks: [FEATURE NAME]

**Input**: Design documents from `/specs/[###-feature-name]/`
**Prerequisites**: plan.md (required), spec.md (required for user stories), research.md, data-model.md, contracts/

**Tests**: The examples below include test tasks. Tests are OPTIONAL - only include them if explicitly requested in the feature specification.

**Organization**: Tasks are grouped by user story to enable independent implementation and testing of each story.

## Format: `[ID] [P?] [Story] Description`

- **[P]**: Can run in parallel (different files, no dependencies)
- **[Story]**: Which user story this task belongs to (e.g., US1, US2, US3)
- Include exact file paths in descriptions

## Path Conventions

- **Single project**: `src/`, `tests/` at repository root
- **Web app**: `backend/src/`, `frontend/src/`
- **Mobile**: `api/src/`, `ios/src/` or `android/src/`
- Paths shown below assume single project - adjust based on plan.md structure

<!--
============================================================================
IMPORTANT: The tasks below are SAMPLE TASKS for illustration purposes only.

The /speckit.tasks command MUST replace these with actual tasks based on:
- User stories from spec.md (with their priorities P1, P2, P3...)
- Feature requirements from plan.md
- Entities from data-model.md
- Endpoints from contracts/

Tasks MUST be organized by user story so each story can be:
- Implemented independently
- Tested independently
- Delivered as an MVP increment

DO NOT keep these sample tasks in the generated tasks.md file.
============================================================================
-->

## Phase 1: Setup (Shared Infrastructure)

**Purpose**: Project initialization and basic structure

- [ ] T001 Create project structure per implementation plan
- [ ] T002 Initialize [language] project with [framework] dependencies
- [ ] T003 [P] Configure linting and formatting tools

---

## Phase 2: Foundational (Blocking Prerequisites)

**Purpose**: Core infrastructure that MUST be complete before ANY user story can be implemented

**⚠️ CRITICAL**: No user story work can begin until this phase is complete

Examples of foundational tasks (adjust based on your project):

- [ ] T004 Setup database schema and migrations framework
- [ ] T005 [P] Implement authentication/authorization framework
- [ ] T006 [P] Setup API routing and middleware structure
- [ ] T007 Create base models/entities that all stories depend on
- [ ] T008 Configure error handling and logging infrastructure
- [ ] T009 Setup environment configuration management

**Checkpoint**: Foundation ready - user story implementation can now begin in parallel

---

## Phase 3: User Story 1 - [Title] (Priority: P1) 🎯 MVP

**Goal**: [Brief description of what this story delivers]

**Independent Test**: [How to verify this story works on its own]

### Tests for User Story 1 (OPTIONAL - only if tests requested) ⚠️

> **NOTE: Write these tests FIRST, ensure they FAIL before implementation**

- [ ] T010 [P] [US1] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T011 [P] [US1] Integration test for [user journey] in tests/integration/test_[name].py

### Implementation for User Story 1

- [ ] T012 [P] [US1] Create [Entity1] model in src/models/[entity1].py
- [ ] T013 [P] [US1] Create [Entity2] model in src/models/[entity2].py
- [ ] T014 [US1] Implement [Service] in src/services/[service].py (depends on T012, T013)
- [ ] T015 [US1] Implement [endpoint/feature] in src/[location]/[file].py
- [ ] T016 [US1] Add validation and error handling
- [ ] T017 [US1] Add logging for user story 1 operations

**Checkpoint**: At this point, User Story 1 should be fully functional and testable independently

---

## Phase 4: User Story 2 - [Title] (Priority: P2)

**Goal**: [Brief description of what this story delivers]

**Independent Test**: [How to verify this story works on its own]

### Tests for User Story 2 (OPTIONAL - only if tests requested) ⚠️

- [ ] T018 [P] [US2] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T019 [P] [US2] Integration test for [user journey] in tests/integration/test_[name].py

### Implementation for User Story 2

- [ ] T020 [P] [US2] Create [Entity] model in src/models/[entity].py
- [ ] T021 [US2] Implement [Service] in src/services/[service].py
- [ ] T022 [US2] Implement [endpoint/feature] in src/[location]/[file].py
- [ ] T023 [US2] Integrate with User Story 1 components (if needed)

**Checkpoint**: At this point, User Stories 1 AND 2 should both work independently

---

## Phase 5: User Story 3 - [Title] (Priority: P3)

**Goal**: [Brief description of what this story delivers]

**Independent Test**: [How to verify this story works on its own]

### Tests for User Story 3 (OPTIONAL - only if tests requested) ⚠️

- [ ] T024 [P] [US3] Contract test for [endpoint] in tests/contract/test_[name].py
- [ ] T025 [P] [US3] Integration test for [user journey] in tests/integration/test_[name].py

### Implementation for User Story 3

- [ ] T026 [P] [US3] Create [Entity] model in src/models/[entity].py
- [ ] T027 [US3] Implement [Service] in src/services/[service].py
- [ ] T028 [US3] Implement [endpoint/feature] in src/[location]/[file].py

**Checkpoint**: All user stories should now be independently functional

---

[Add more user story phases as needed, following the same pattern]

---

## Phase N: Polish & Cross-Cutting Concerns

**Purpose**: Improvements that affect multiple user stories

- [ ] TXXX [P] Documentation updates in docs/
- [ ] TXXX Code cleanup and refactoring
- [ ] TXXX Performance optimization across all stories
- [ ] TXXX [P] Additional unit tests (if requested) in tests/unit/
- [ ] TXXX Security hardening
- [ ] TXXX Run quickstart.md validation

---

## Dependencies & Execution Order

### Phase Dependencies

- **Setup (Phase 1)**: No dependencies - can start immediately
- **Foundational (Phase 2)**: Depends on Setup completion - BLOCKS all user stories
- **User Stories (Phase 3+)**: All depend on Foundational phase completion
  - User stories can then proceed in parallel (if staffed)
  - Or sequentially in priority order (P1 → P2 → P3)
- **Polish (Final Phase)**: Depends on all desired user stories being complete

### User Story Dependencies

- **User Story 1 (P1)**: Can start after Foundational (Phase 2) - No dependencies on other stories
- **User Story 2 (P2)**: Can start after Foundational (Phase 2) - May integrate with US1 but should be independently testable
- **User Story 3 (P3)**: Can start after Foundational (Phase 2) - May integrate with US1/US2 but should be independently testable

### Within Each User Story

- Tests (if included) MUST be written and FAIL before implementation
- Models before services
- Services before endpoints
- Core implementation before integration
- Story complete before moving to next priority

### Parallel Opportunities

- All Setup tasks marked [P] can run in parallel
- All Foundational tasks marked [P] can run in parallel (within Phase 2)
- Once Foundational phase completes, all user stories can start in parallel (if team capacity allows)
- All tests for a user story marked [P] can run in parallel
- Models within a story marked [P] can run in parallel
- Different user stories can be worked on in parallel by different team members

---

## Parallel Example: User Story 1

```bash
# Launch all tests for User Story 1 together (if tests requested):
Task: "Contract test for [endpoint] in tests/contract/test_[name].py"
Task: "Integration test for [user journey] in tests/integration/test_[name].py"

# Launch all models for User Story 1 together:
Task: "Create [Entity1] model in src/models/[entity1].py"
Task: "Create [Entity2] model in src/models/[entity2].py"
```

---

## Implementation Strategy

### MVP First (User Story 1 Only)

1. Complete Phase 1: Setup
2. Complete Phase 2: Foundational (CRITICAL - blocks all stories)
3. Complete Phase 3: User Story 1
4. **STOP and VALIDATE**: Test User Story 1 independently
5. Deploy/demo if ready

### Incremental Delivery

1. Complete Setup + Foundational → Foundation ready
2. Add User Story 1 → Test independently → Deploy/Demo (MVP!)
3. Add User Story 2 → Test independently → Deploy/Demo
4. Add User Story 3 → Test independently → Deploy/Demo
5. Each story adds value without breaking previous stories

### Parallel Team Strategy

With multiple developers:

1. Team completes Setup + Foundational together
2. Once Foundational is done:
   - Developer A: User Story 1
   - Developer B: User Story 2
   - Developer C: User Story 3
3. Stories complete and integrate independently

---

## Notes

- [P] tasks = different files, no dependencies
- [Story] label maps task to specific user story for traceability
- Each user story should be independently completable and testable
- Verify tests fail before implementing
- Commit after each task or logical group
- Stop at any checkpoint to validate story independently
- Avoid: vague tasks, same file conflicts, cross-story dependencies that break independence
203
Code/Supervised_learning/resnet/CONTRIBUTING.md
Normal file
@ -0,0 +1,203 @@
# ResNet Phenology Classifier - Development Guide

## Development Setup

### Prerequisites
- Python 3.11+
- CUDA-capable GPU (recommended)
- 8GB+ RAM
- Git

### Environment Setup

1. Clone the repository:
```bash
git clone <repository_url>
cd resnet
```

2. Create virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate
```

3. Install dependencies:
```bash
pip install -r requirements.txt
```

4. Install development dependencies:
```bash
pip install black flake8 mypy pylint pytest-cov
```

## Code Quality Standards

### PEP 8 Compliance
All code must follow PEP 8 standards:
```bash
flake8 src/ tests/
```

### Type Hints
Use type hints for all functions:
```python
def train_model(epochs: int, lr: float) -> dict:
    ...
```

### Docstrings
All modules, classes, and functions must have docstrings:
```python
def function(arg: str) -> int:
    """
    Brief description.

    Args:
        arg: Description

    Returns:
        Description
    """
    pass
```

## Testing

### Running Tests
```bash
# All tests
pytest tests/ -v

# With coverage
pytest tests/ --cov=src --cov-report=html

# Specific markers
pytest tests/ -m unit
pytest tests/ -m integration
pytest tests/ -m slow
```

### Writing Tests
- Unit tests for all utility functions
- Integration tests for data pipelines
- Model validation tests
- Use fixtures for common setup

### Test Coverage
- Minimum 80% code coverage
- 100% coverage for critical paths

## Continuous Integration

### Pre-commit Checks
Before committing:
1. Run linter: `flake8 src/ tests/`
2. Run type checker: `mypy src/`
3. Run tests: `pytest tests/ -v`
4. Check formatting: `black --check src/ tests/`

### CI Pipeline
The CI/CD pipeline runs:
1. Linting (flake8, pylint)
2. Type checking (mypy)
3. Unit tests
4. Integration tests
5. Coverage report

## Model Development

### Training Best Practices
1. Always set random seed (see the seeding sketch after this list)
2. Use validation set for hyperparameter tuning
3. Save checkpoints regularly
4. Monitor training metrics
5. Use early stopping
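
A minimal seeding sketch for a PyTorch project like this one (the helper name and default `seed` value are illustrative, not the repo's actual API):

```python
import os
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch for reproducible training runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
    os.environ["PYTHONHASHSEED"] = str(seed)
```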

### Evaluation
- Evaluate on independent test set
- Report multiple metrics (accuracy, recall, F1)
- Analyze confusion matrix
- Check for bias

### Versioning
- Version models with timestamp
- Track hyperparameters
- Save class mappings
- Document training data

## Git Workflow

### Branching Strategy
- `master`: Production-ready code
- `1-phenology-classifier`: Feature branch
- Feature branches for new capabilities

### Commit Messages
Follow conventional commits:
```
feat: add confusion matrix visualization
fix: correct data loader split logic
docs: update README with API examples
test: add unit tests for inference
```

## Performance Optimization

### Training
- Use mixed precision training
- Optimize data loading (num_workers)
- Use GPU if available
- Batch size tuning

### Inference
- Model quantization
- Batch predictions
- Cache loaded models
- Optimize image preprocessing

## Troubleshooting

### Common Issues

**CUDA out of memory:**
- Reduce batch size
- Use gradient accumulation (see the sketch below)
- Clear cache: `torch.cuda.empty_cache()`
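
A minimal gradient-accumulation sketch; the tiny linear model and dummy batches here are illustrative stand-ins for this repo's actual training loop:

```python
import torch
from torch import nn

model = nn.Linear(10, 2)  # stand-in for the ResNet classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]  # dummy batches

accum_steps = 4  # effective batch size = 8 * 4 with the same memory footprint
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets) / accum_steps  # scale so accumulated grads average
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```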

**Slow data loading:**
- Increase num_workers
- Use SSD for dataset
- Preprocess images offline

**Poor accuracy:**
- Check data quality
- Increase training epochs
- Try different learning rates
- Use data augmentation

## Documentation

### Code Documentation
- Docstrings for all public APIs
- Inline comments for complex logic
- Type hints throughout

### Project Documentation
- Update README for new features
- Document API changes
- Maintain changelog

## Release Process

1. Update version number
2. Run full test suite
3. Build documentation
4. Create release notes
5. Tag release in git
6. Deploy to production

## Contact

For questions or issues, refer to the project specifications in `specs/1-phenology-classifier/`.
224
Code/Supervised_learning/resnet/README.md
Normal file
@ -0,0 +1,224 @@
# ResNet Phenology Classifier

A deep learning model for classifying plant images by phenological phase using the ResNet50 architecture.

## Features

- **Training**: Train a ResNet50 model on labeled plant images
- **Evaluation**: Comprehensive metrics including accuracy, recall, macro-F1, and confusion matrix
- **Inference**: Classify new images with visual output
- **API**: REST API for batch classification
- **Reproducibility**: Random seed management and versioning

## Dataset

The model uses the dataset located at:
```
C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF
```

Dataset split: 70% training, 15% validation, 15% testing

## Installation

1. Create a virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

## Usage

### Training

Train the model on your dataset:

```bash
python src/train.py \
  --data_dir C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF \
  --csv_file C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv \
  --output_dir models \
  --epochs 50 \
  --batch_size 32 \
  --lr 0.001
```

**Arguments:**
- `--data_dir`: Directory containing images
- `--csv_file`: Path to CSV file with columns: `image_path`, `phase`
- `--output_dir`: Directory to save trained models (default: `models`)
- `--epochs`: Number of training epochs (default: 50)
- `--batch_size`: Batch size (default: 32)
- `--lr`: Learning rate (default: 0.001)
- `--num_workers`: Data loader workers (default: 4)
- `--seed`: Random seed for reproducibility (default: 42)
- `--patience`: Early stopping patience (default: 10)

### Evaluation

Evaluate a trained model:

```bash
python src/evaluate.py \
  --model_path models/best_model.pth \
  --data_dir C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF \
  --csv_file C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF\labels.csv \
  --class_mapping models/class_mapping.json \
  --output_dir evaluation \
  --split test
```

**Arguments:**
- `--model_path`: Path to trained model checkpoint
- `--data_dir`: Directory containing images
- `--csv_file`: Path to CSV file with labels
- `--class_mapping`: Path to class mapping JSON file
- `--output_dir`: Directory to save evaluation results (default: `evaluation`)
- `--split`: Dataset split to evaluate (`val` or `test`, default: `test`)

**Output:**
- Accuracy, Recall (macro), F1 (macro)
- Per-class metrics
- Confusion matrix (normalized and raw)
- Classification report

### Inference

Classify a single image:

```bash
python src/inference.py \
  --image_path path/to/image.jpg \
  --model_path models/best_model.pth \
  --class_mapping models/class_mapping.json \
  --output_dir results
```

**Arguments:**
- `--image_path`: Path to image file
- `--model_path`: Path to trained model checkpoint
- `--class_mapping`: Path to class mapping JSON file
- `--output_dir`: Directory to save results (optional)
- `--no_visualize`: Skip creating visualization

**Output:**
- Predicted phenological phase
- Confidence score
- Probabilities for all classes
- Visual output showing prediction and probabilities

### API Server

Start the FastAPI server:

```bash
# Set environment variables
export MODEL_PATH=models/best_model.pth
export CLASS_MAPPING_PATH=models/class_mapping.json

# Start server
python -m uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload
```

**Endpoints:**
- `GET /`: API information
- `GET /health`: Health check
- `GET /classes`: List available classes
- `POST /classify`: Classify single image
- `POST /classify/batch`: Classify multiple images (max 10)

**Example request:**
```bash
curl -X POST "http://localhost:8000/classify" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@path/to/image.jpg"
```

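For the batch endpoint, each image goes in a separate `files` form field (FastAPI collects repeated fields into the list parameter); a request might look like this, with the file names as placeholders:

```bash
curl -X POST "http://localhost:8000/classify/batch" \
  -H "Content-Type: multipart/form-data" \
  -F "files=@image1.jpg" \
  -F "files=@image2.jpg"
```
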
## Testing

Run unit tests:

```bash
pytest tests/ -v
```

Run specific test categories:

```bash
pytest tests/ -v -m unit          # Unit tests only
pytest tests/ -v -m integration   # Integration tests only
```

## Project Structure

```
.
├── src/
│   ├── __init__.py
│   ├── data_loader.py    # Dataset and data loading
│   ├── model.py          # ResNet model definition
│   ├── train.py          # Training script
│   ├── evaluate.py       # Evaluation script
│   ├── inference.py      # Inference script
│   ├── api.py            # FastAPI application
│   └── utils.py          # Utility functions
├── tests/
│   ├── test_data_loader.py
│   ├── test_train.py
│   ├── test_evaluate.py
│   └── test_inference.py
├── data/                 # Dataset directory
├── models/               # Saved models
├── specs/                # Feature specifications
├── requirements.txt      # Python dependencies
├── pytest.ini            # Pytest configuration
└── README.md             # This file
```

## Model Architecture

- **Base**: ResNet50 pretrained on ImageNet
- **Classification Head** (see the sketch after this list):
  - Dropout (0.5)
  - Linear (2048 → 512)
  - ReLU
  - Dropout (0.3)
  - Linear (512 → num_classes)

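The head above replaces ResNet50's final fully connected layer; a minimal standalone sketch of the same stack (mirroring `src/model.py`; the class count is a placeholder) is:

```python
import torch.nn as nn
from torchvision import models

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
num_features = model.fc.in_features  # 2048 for ResNet50

# Two-layer head with dropout, as described above.
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 5),  # 5 = num_classes; set this from your dataset
)
```
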
## Performance Requirements

- **Accuracy**: >90% on test set
- **Training Time**: <1 hour for a standard dataset
- **Inference Time**: <1 second per image

## Data Format

The CSV file should have the following format:

```csv
image_path,phase
images/img001.jpg,vegetative
images/img002.jpg,flowering
images/img003.jpg,fruiting
```

## Reproducibility

The project ensures reproducibility through:
- Random seed management (default: 42; see the seeding sketch after this list)
- Deterministic training
- Model checkpointing
- Class mapping versioning

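Seeding is centralized in `src/utils.py`'s `set_seed`; a plausible minimal implementation (the exact body in `utils.py` may differ) seeds Python, NumPy, and PyTorch together:

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 42) -> None:
    """Seed all relevant RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
```
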
## License

This project is part of a supervised learning implementation for phenology classification.

## Citation

If you use this code, please cite the project specifications in `specs/1-phenology-classifier/`.

10
Code/Supervised_learning/resnet/pytest.ini
Normal file
@@ -0,0 +1,10 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
markers =
    unit: Unit tests
    integration: Integration tests
    slow: Slow running tests

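With `--strict-markers` in `addopts`, any marker used in a test must be one of the three declared above; a test opts into a category like this (the test body is illustrative):

```python
import pytest

@pytest.mark.unit
def test_class_mapping_is_zero_indexed():
    mapping = {"flowering": 0, "vegetative": 1}
    assert sorted(mapping.values()) == list(range(len(mapping)))
```
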
13
Code/Supervised_learning/resnet/requirements.txt
Normal file
@@ -0,0 +1,13 @@
torch>=2.0.0
torchvision>=0.15.0
pandas>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
seaborn>=0.12.0
Pillow>=10.0.0
numpy>=1.24.0
tqdm>=4.65.0
pytest>=7.4.0
fastapi>=0.100.0
uvicorn>=0.23.0
python-multipart>=0.0.6

@@ -0,0 +1,35 @@

# Specification Quality Checklist: ResNet Phenology Classifier

**Purpose**: Validate specification completeness and quality before proceeding to planning
**Created**: 2025-11-04
**Feature**: [Link to spec.md](specs/1-phenology-classifier/spec.md)

## Content Quality

- [x] No implementation details (languages, frameworks, APIs)
- [x] Focused on user value and business needs
- [x] Written for non-technical stakeholders
- [x] All mandatory sections completed

## Requirement Completeness

- [x] No [NEEDS CLARIFICATION] markers remain
- [x] Requirements are testable and unambiguous
- [x] Success criteria are measurable
- [x] Success criteria are technology-agnostic (no implementation details)
- [x] All acceptance scenarios are defined
- [x] Edge cases are identified
- [x] Scope is clearly bounded
- [x] Dependencies and assumptions identified

## Feature Readiness

- [x] All functional requirements have clear acceptance criteria
- [x] User scenarios cover primary flows
- [x] Feature meets measurable outcomes defined in Success Criteria
- [x] No implementation details leak into specification

## Notes

- Phenological phases are not specified; may need clarification.
- Assumed standard dataset format and phases based on common plant phenology.

@@ -0,0 +1,60 @@

openapi: 3.0.3
info:
  title: ResNet Phenology Classifier API
  version: 1.0.0
  description: API for classifying plant images by phenological phase using ResNet model

paths:
  /classify:
    post:
      summary: Classify a plant image
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              type: object
              properties:
                image:
                  type: string
                  format: binary
                  description: Plant image file (JPEG/PNG)
              required:
                - image
      responses:
        '200':
          description: Classification result
          content:
            application/json:
              schema:
                type: object
                properties:
                  phase:
                    type: string
                    description: Predicted phenological phase
                  confidence:
                    type: number
                    format: float
                    description: Confidence score (0-1)
                  probabilities:
                    type: object
                    description: Probabilities for all phases
                    additionalProperties:
                      type: number
                      format: float
        '400':
          description: Invalid input
        '500':
          description: Server error

components:
  schemas:
    ClassificationResult:
      type: object
      properties:
        phase:
          type: string
        confidence:
          type: number
        probabilities:
          type: object

@@ -0,0 +1,41 @@

# Data Model: ResNet Phenology Classifier

**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md

## Entities

### Plant Image
**Description**: Represents an image of a plant used for training or inference.

**Fields**:
- `path` (string): File path to the image (required)
- `phase` (string): Phenological phase label (required)
- `width` (int): Image width in pixels
- `height` (int): Image height in pixels

**Validation Rules** (a sketch follows this list):
- Path must exist and be a readable image file
- Phase must be one of the classes defined in the CSV
- Image dimensions must be at least 224x224 for ResNet input

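These rules translate directly into a small check; a minimal sketch (the function name and error messages are illustrative; it uses Pillow, which is already in `requirements.txt`):

```python
import os
from PIL import Image

def validate_plant_image(path: str, phase: str, known_phases: set[str]) -> None:
    """Raise ValueError if a dataset record violates the rules above."""
    if not os.path.isfile(path):
        raise ValueError(f"Image not found: {path}")
    if phase not in known_phases:
        raise ValueError(f"Unknown phase label: {phase}")
    with Image.open(path) as img:
        width, height = img.size
    if width < 224 or height < 224:
        raise ValueError(f"Image too small ({width}x{height}); need >= 224x224")
```
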
### Phenological Phase
**Description**: Represents a growth-stage classification for plants.

**Fields**:
- `name` (string): Phase identifier (e.g., "vegetative", "flowering") (required)
- `description` (string): Human-readable description of the phase
- `index` (int): Numerical index for model output (0-based)

**Validation Rules**:
- Name must be unique
- Index must be sequential, starting from 0

## Relationships

- Plant Image belongs to Phenological Phase (many-to-one)
- Phases are defined in the labels CSV file

## State Transitions

N/A - Static classification task

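Evaluation, inference, and the API all read the phase names from `models/class_mapping.json` through its `classes` key; a plausible way to generate the file at training time (the `class_to_idx` key is an assumption here — only `classes` is read in this commit) is:

```python
import json

classes = ["flowering", "fruiting", "vegetative"]  # sorted phase names
class_mapping = {
    "classes": classes,
    "class_to_idx": {cls: idx for idx, cls in enumerate(classes)},
}

with open("models/class_mapping.json", "w") as f:
    json.dump(class_mapping, f, indent=2)
```
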
@@ -0,0 +1,88 @@

# Implementation Plan: ResNet Phenology Classifier

**Branch**: `1-phenology-classifier` | **Date**: 2025-11-04 | **Spec**: specs/1-phenology-classifier/spec.md
**Input**: Feature specification from `/specs/1-phenology-classifier/spec.md`

**Note**: This template is filled in by the `/speckit.plan` command. See `.specify/templates/commands/plan.md` for the execution workflow.

## Summary

Build a ResNet model to classify plant images by phenological phase using labeled datasets. Technical approach: use PyTorch to implement the ResNet architecture, train on an image dataset with CSV labels, and evaluate performance metrics.

## Technical Context

**Language/Version**: Python 3.11
**Primary Dependencies**: PyTorch, torchvision, pandas, scikit-learn
**Storage**: Image files (JPEG/PNG) and CSV labels
**Testing**: pytest, torch.testing
**Target Platform**: Linux server with GPU support
**Project Type**: single (ML model training and inference script)
**Performance Goals**: >90% classification accuracy, <1 second inference time
**Constraints**: GPU memory availability, dataset size up to 10k images
**Scale/Scope**: Support for multiple phenological phases as defined in the CSV

## Constitution Check

*GATE: Must pass before Phase 0 research. Re-check after Phase 1 design.*

- Code Quality and Modularity: Confirm adherence to PEP 8, type hints, docstrings, and modular design.
- Rigorous Testing Standards: Plan includes unit tests, integration tests, and model validation tests.
- Reproducibility and Versioning: Versioning strategy for code, data, and models; random seed management.
- Model Evaluation and Validation: Metrics and bias checks defined for model assessment.
- Continuous Integration and Quality Gates: CI/CD pipeline with linting, testing, and performance checks.

## Project Structure

### Documentation (this feature)

```text
specs/1-phenology-classifier/
├── plan.md              # This file (/speckit.plan command output)
├── research.md          # Phase 0 output (/speckit.plan command)
├── data-model.md        # Phase 1 output (/speckit.plan command)
├── quickstart.md        # Phase 1 output (/speckit.plan command)
├── contracts/           # Phase 1 output (/speckit.plan command)
└── tasks.md             # Phase 2 output (/speckit.tasks command - NOT created by /speckit.plan)
```

### Source Code (repository root)

```text
src/
├── __init__.py
├── data_loader.py       # Load images and labels from CSV
├── model.py             # ResNet model definition
├── train.py             # Training script
├── evaluate.py          # Evaluation script
├── inference.py         # Inference on new images
└── utils.py             # Helper functions

tests/
├── test_data_loader.py
├── test_model.py
├── test_train.py
└── test_inference.py

data/
├── images/              # Plant images
└── labels.csv           # Image labels

models/                  # Saved trained models
requirements.txt         # Python dependencies
```

### Phase Execution

**Phase 0: Research** (resolve unknowns)
- Research the best ResNet variant for image classification
- Research data preprocessing for plant images
- Research evaluation metrics for multi-class classification

**Phase 1: Design** (create contracts)
- Define the data model for images and labels
- Design the API for model inference
- Create a quickstart guide

**Phase 2: Tasks** (implementation breakdown)
- Break down into specific coding tasks
- Assign priorities and dependencies

@@ -0,0 +1,68 @@

# Quickstart: ResNet Phenology Classifier

**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md

## Prerequisites

- Python 3.11+
- GPU with CUDA support (recommended)
- 4GB+ RAM
- Dataset with plant images and labels CSV

## Installation

1. Clone the repository and check out the feature branch:
```bash
git checkout 1-phenology-classifier
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Prepare your dataset:
   - Place images in `data/images/`
   - Create `data/labels.csv` with columns: `image_path`, `phase`

## Training

Run the training script:
```bash
python src/train.py --data_dir data/ --epochs 50 --batch_size 32
```

This will:
- Load the dataset
- Train ResNet50 on your data
- Save the model to `models/phenology_classifier.pth`

## Evaluation

Evaluate the trained model:
```bash
python src/evaluate.py --model_path models/phenology_classifier.pth --data_dir data/
```

This outputs accuracy, F1-score, and per-class metrics.

## Inference

Classify a new image:
```bash
python src/inference.py --model_path models/phenology_classifier.pth --image_path path/to/image.jpg
```

Or start the API server:
```bash
python -m uvicorn src.api:app --reload
```

Then POST to `http://localhost:8000/classify` with an image file.

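From Python, the same request can be made with the `requests` library (not in `requirements.txt`, so install it separately; the form field name `file` matches the endpoint's parameter):

```python
import requests

with open("path/to/image.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/classify",
        files={"file": ("image.jpg", f, "image/jpeg")},
    )

print(response.json())  # {'phase': ..., 'confidence': ..., 'probabilities': {...}}
```
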
## Expected Results

- Training time: ~30 minutes on GPU
- Accuracy: >90% on validation set
- Inference time: <1 second per image

@@ -0,0 +1,39 @@

# Research: ResNet Phenology Classifier

**Date**: 2025-11-04
**Feature**: specs/1-phenology-classifier/spec.md

## Research Tasks

1. Research the best ResNet variant for plant image classification
2. Research data preprocessing techniques for botanical images
3. Research evaluation metrics for multi-class phenological phase classification
4. Research reproducibility practices for ML experiments

## Findings & Decisions

### ResNet Variant
**Decision**: Use ResNet50 as the base architecture
**Rationale**: Provides a good balance between accuracy and computational efficiency. ResNet50 has proven effective for image classification tasks similar to ImageNet.
**Alternatives considered**: ResNet18 (faster but lower accuracy), ResNet101 (higher accuracy but more compute-intensive)

### Data Preprocessing
**Decision**: Use standard ImageNet preprocessing with augmentation
**Rationale**: Random cropping, horizontal flipping, and normalization to the ImageNet mean/std are standard for transfer learning with ResNet.
**Alternatives considered**: Custom augmentations for plant-specific features, but the standard pipeline works well for general classification.

### Evaluation Metrics
**Decision**: Primary: accuracy. Secondary: per-class F1-score, precision, recall.
**Rationale**: Accuracy for overall performance; F1-score to handle class imbalance in phenological phases.
**Alternatives considered**: AUC-ROC (better suited to binary problems); multi-class metrics are appropriate here.

### Reproducibility
**Decision**: Use random seeds, version data with DVC, log all hyperparameters
**Rationale**: Ensures experiments can be reproduced. DVC for data versioning; MLflow or similar for experiment tracking.
**Alternatives considered**: Manual logging, but automated tools are more reliable.

## Resolved Clarifications

- Dataset format: images in a directory, labels in a CSV with columns `image_path`, `phase`
- Model output: probabilities for each phase class
- Training hardware: GPU required for reasonable training time

@@ -0,0 +1,94 @@

# Feature Specification: ResNet Phenology Classifier

**Feature Branch**: `1-phenology-classifier`
**Created**: 2025-11-04
**Status**: Draft
**Input**: User description: "Build a ResNet model capable of classifying images of a plant by phenological phase. The images are provided in datasets and labeled according to their phase."

## User Scenarios & Testing *(mandatory)*

### User Story 1 - Train ResNet Model (Priority: P1)

As a researcher, I want to train a ResNet model on a dataset of labeled plant images to learn phenological phase classification.

**Why this priority**: This is the core functionality required to build the classifier.

**Independent Test**: Can be fully tested by training on a subset of the dataset and validating that the model learns to classify phases accurately.

**Acceptance Scenarios**:

1. **Given** a dataset of plant images labeled by phenological phase, **When** the model is trained, **Then** it achieves at least 90% accuracy on a held-out test set.
2. **Given** training data, **When** training completes, **Then** the model can be saved and loaded for inference.

---

### User Story 2 - Evaluate Model Performance (Priority: P2)

As a researcher, I want to evaluate the trained model's performance on unseen data to ensure reliability.

**Why this priority**: Evaluation is essential to validate the model's effectiveness before deployment.

**Independent Test**: Can be tested by running evaluation on test data and checking metrics like accuracy, precision, and recall.

**Acceptance Scenarios**:

1. **Given** a trained model and test dataset, **When** evaluation is run, **Then** detailed metrics are provided including accuracy, recall, macro-F1, and a confusion matrix.
2. **Given** evaluation results, **When** performance is below threshold, **Then** the model is flagged for retraining.

---

### User Story 3 - Classify New Images (Priority: P3)

As a user, I want to use the trained model to classify new plant images by phenological phase.

**Why this priority**: This enables practical use of the model for monitoring or analysis.

**Independent Test**: Can be tested by providing a new image and verifying the predicted phase matches expectations.

**Acceptance Scenarios**:

1. **Given** a new plant image, **When** classification is requested, **Then** the model returns the predicted phenological phase with visual output.
2. **Given** an image, **When** classified, **Then** response time is under 1 second.

---

## Clarifications

### Session 2025-11-04

- Q: How should the dataset be split for training, validation, and testing? → A: 70/15/15
- Q: What evaluation metrics must be included? → A: Accuracy, recall, macro-F1, confusion matrix; other metrics accepted
- Q: How should classification results be returned? → A: Visually

### Edge Cases

- What happens when an image is of poor quality or not a plant?
- How does the system handle images with multiple plants or unclear phases?
- What if the dataset has imbalanced classes for certain phases?

## Requirements *(mandatory)*

### Functional Requirements

- **FR-001**: System MUST load and preprocess labeled plant image datasets.
- **FR-002**: System MUST train a ResNet model on the dataset to classify the phenological phases specified in the dataset's .csv labels file.
- **FR-003**: System MUST evaluate the model using metrics including accuracy, recall, macro-F1, and confusion matrix. Other metrics are accepted.
- **FR-004**: System MUST provide inference capability to classify new images.
- **FR-005**: System MUST save and load trained models for reuse.
- **FR-006**: System MUST split the dataset located at C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF into 70% training, 15% validation, 15% testing.
- **FR-007**: System MUST provide visual output of classification results.

### Key Entities *(include if feature involves data)*

- **Dataset**: Located at C:\Users\sof12\Desktop\ML\Datasets\Nocciola\GBIF, containing plant images and labels.
- **Plant Image**: Represents an image of a plant, with attributes like image data and the associated phenological phase label.
- **Phenological Phase**: Represents a growth stage of the plant, with attributes like phase name and description.

## Success Criteria *(mandatory)*

### Measurable Outcomes

- **SC-001**: Model achieves >90% accuracy on a held-out test set.
- **SC-002**: Training process completes within 1 hour for a standard dataset size.
- **SC-003**: Inference on a single image takes less than 1 second.
- **SC-004**: System handles datasets with up to 10,000 images without performance degradation.

@@ -0,0 +1,97 @@

# Tasks: ResNet Phenology Classifier

**Input**: Design documents from `/specs/1-phenology-classifier/`
**Prerequisites**: plan.md (required), spec.md (required for user stories), research.md, data-model.md, contracts/

**Tests**: Included as per constitution testing standards.

**Organization**: Tasks are grouped by user story to enable independent implementation and testing of each story.

## Format: `[ID] [P?] [Story] Description`

- **[P]**: Can run in parallel (different files, no dependencies)
- **[Story]**: Which user story this task belongs to (e.g., US1, US2, US3)
- Include exact file paths in descriptions

## Path Conventions

- **Single project**: `src/`, `tests/` at repository root

## Dependencies

- US1 (Train Model) must complete before US2 (Evaluate) and US3 (Classify)
- US2 and US3 can run in parallel after US1

## Parallel Execution Examples

- Setup tasks: T001-T005 can run in parallel
- US1 tasks: T010-T015 can run in parallel, except training, which depends on data loading
- US2 and US3: can run in parallel after US1 completion

## Implementation Strategy

- MVP: Complete US1 for basic training capability
- Incremental: Add US2 for evaluation, then US3 for inference
- Each user story delivers independently testable value

## Phase 1: Setup (Shared Infrastructure)

**Purpose**: Project initialization and basic structure

- [x] T001 Create project directory structure per plan.md
- [x] T002 Create requirements.txt with PyTorch, torchvision, pandas, scikit-learn
- [x] T003 Create data/ directory and subdirectories for images and labels
- [x] T004 Create models/ directory for saved models
- [x] T005 Create src/__init__.py and basic module structure

## Phase 2: Foundational (Blocking Prerequisites)

**Purpose**: Core components needed by all user stories

- [x] T006 [P] Implement data loader in src/data_loader.py for CSV labels and image loading
- [x] T007 [P] Define ResNet50 model in src/model.py with classification head
- [x] T008 [P] Create utils.py with preprocessing and helper functions
- [x] T009 [P] Set up test framework with pytest configuration
- [x] T010 [P] Create unit tests for data loader in tests/test_data_loader.py

## Phase 3: US1 - Train ResNet Model

**Purpose**: Enable model training on labeled datasets
**Independent Test**: Train on a subset and verify the model learns (accuracy improves)

- [x] T011 [P] [US1] Implement training script in src/train.py with data loading and ResNet training loop
- [x] T012 [P] [US1] Add model saving functionality to train.py
- [x] T013 [P] [US1] Implement data augmentation in utils.py for training
- [x] T014 [P] [US1] Create unit tests for training components in tests/test_train.py
- [x] T015 [US1] Integrate training pipeline and test end-to-end training

## Phase 4: US2 - Evaluate Model Performance

**Purpose**: Provide evaluation metrics for trained models
**Independent Test**: Run evaluation on the test set and verify metrics output

- [x] T016 [P] [US2] Implement evaluation script in src/evaluate.py with accuracy and F1-score calculation
- [x] T017 [P] [US2] Add per-class metrics and confusion matrix in evaluate.py
- [x] T018 [P] [US2] Create unit tests for evaluation in tests/test_evaluate.py
- [x] T019 [US2] Integrate evaluation and test on trained model

## Phase 5: US3 - Classify New Images

**Purpose**: Enable inference on new plant images
**Independent Test**: Classify a sample image and verify the output format

- [x] T020 [P] [US3] Implement inference script in src/inference.py for single image classification
- [x] T021 [P] [US3] Create API endpoint in src/api.py using FastAPI for /classify POST
- [x] T022 [P] [US3] Add input validation and error handling in api.py
- [x] T023 [P] [US3] Create unit tests for inference in tests/test_inference.py
- [x] T024 [US3] Integrate API and test classification endpoint

## Final Phase: Polish & Cross-Cutting Concerns

**Purpose**: Quality assurance and production readiness

- [x] T025 Add logging and monitoring to all scripts
- [x] T026 Implement CI/CD pipeline with linting and testing
- [x] T027 Add comprehensive documentation and README updates
- [x] T028 Performance optimization and memory management
- [x] T029 Final integration testing and validation

3
Code/Supervised_learning/resnet/src/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""ResNet Phenology Classifier Package"""

__version__ = "1.0.0"

174
Code/Supervised_learning/resnet/src/api.py
Normal file
@@ -0,0 +1,174 @@
"""FastAPI application for phenology classification."""

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
from PIL import Image
import io
import json
import os
from typing import Dict

from src.inference import load_inference_model, preprocess_image, predict

# Initialize FastAPI app
app = FastAPI(
    title="ResNet Phenology Classifier API",
    description="API for classifying plant images by phenological phase",
    version="1.0.0"
)

# Global variables for model
model = None
classes = None
device = None

# Configuration
MODEL_PATH = os.getenv("MODEL_PATH", "models/best_model.pth")
CLASS_MAPPING_PATH = os.getenv("CLASS_MAPPING_PATH", "models/class_mapping.json")


@app.on_event("startup")
async def load_model():
    """Load model on startup."""
    global model, classes, device
    try:
        model, classes, device = load_inference_model(MODEL_PATH, CLASS_MAPPING_PATH)
        print(f"Model loaded successfully from {MODEL_PATH}")
        print(f"Classes: {classes}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "ResNet Phenology Classifier API",
        "version": "1.0.0",
        "endpoints": {
            "classify": "/classify",
            "health": "/health",
            "classes": "/classes"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}


@app.get("/classes")
async def get_classes():
    """Get list of available classes."""
    if classes is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"classes": classes, "num_classes": len(classes)}


@app.post("/classify")
async def classify_image(file: UploadFile = File(...)) -> Dict:
    """
    Classify a plant image.

    Args:
        file: Uploaded image file

    Returns:
        Dictionary with classification results
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate file type
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        # Read image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert('RGB')

        # Preprocess
        from torchvision import transforms
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image_tensor = transform(image).unsqueeze(0)

        # Predict
        result = predict(model, image_tensor, classes, device)

        return JSONResponse(content=result)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.post("/classify/batch")
async def classify_batch(files: list[UploadFile] = File(...)) -> Dict:
    """
    Classify multiple images.

    Args:
        files: List of uploaded image files

    Returns:
        Dictionary with classification results for each image
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if len(files) > 10:
        raise HTTPException(status_code=400, detail="Maximum 10 images per batch")

    results = []

    for idx, file in enumerate(files):
        if not file.content_type.startswith("image/"):
            results.append({
                "filename": file.filename,
                "error": "File must be an image"
            })
            continue

        try:
            # Read image
            contents = await file.read()
            image = Image.open(io.BytesIO(contents)).convert('RGB')

            # Preprocess
            from torchvision import transforms
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            image_tensor = transform(image).unsqueeze(0)

            # Predict
            result = predict(model, image_tensor, classes, device)
            result["filename"] = file.filename
            results.append(result)

        except Exception as e:
            results.append({
                "filename": file.filename,
                "error": str(e)
            })

    return JSONResponse(content={"results": results, "total": len(files)})


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

171
Code/Supervised_learning/resnet/src/data_loader.py
Normal file
@@ -0,0 +1,171 @@
"""Data loader module for loading and preprocessing plant images."""

import os
from typing import Tuple, List, Dict, Optional
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np


class PhenologyDataset(Dataset):
    """Dataset class for plant phenology images."""

    def __init__(
        self,
        csv_file: str,
        root_dir: str,
        transform: Optional[transforms.Compose] = None,
        split_ratio: Tuple[float, float, float] = (0.7, 0.15, 0.15),
        split: str = 'train',
        random_seed: int = 42
    ):
        """
        Initialize the dataset.

        Args:
            csv_file: Path to CSV file with image paths and labels
            root_dir: Root directory containing images
            transform: Optional transform to be applied on images
            split_ratio: Tuple of (train, val, test) ratios
            split: One of 'train', 'val', or 'test'
            random_seed: Random seed for reproducibility
        """
        self.root_dir = root_dir
        self.transform = transform
        self.split = split

        # Load CSV
        self.data_frame = pd.read_csv(csv_file)

        # Get unique classes and create mapping
        self.classes = sorted(self.data_frame['phase'].unique())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

        # Split data
        np.random.seed(random_seed)
        indices = np.random.permutation(len(self.data_frame))

        train_size = int(len(indices) * split_ratio[0])
        val_size = int(len(indices) * split_ratio[1])

        if split == 'train':
            self.indices = indices[:train_size]
        elif split == 'val':
            self.indices = indices[train_size:train_size + val_size]
        elif split == 'test':
            self.indices = indices[train_size + val_size:]
        else:
            raise ValueError(f"Invalid split: {split}. Must be 'train', 'val', or 'test'")

        self.data_frame = self.data_frame.iloc[self.indices].reset_index(drop=True)

    def __len__(self) -> int:
        """Return the size of the dataset."""
        return len(self.data_frame)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Get item at index.

        Args:
            idx: Index of item

        Returns:
            Tuple of (image tensor, label)
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx]['image_path'])
        image = Image.open(img_name).convert('RGB')
        label = self.class_to_idx[self.data_frame.iloc[idx]['phase']]

        if self.transform:
            image = self.transform(image)

        return image, label

    def get_class_weights(self) -> torch.Tensor:
        """
        Calculate class weights for handling imbalanced datasets.

        Returns:
            Tensor of class weights
        """
        class_counts = self.data_frame['phase'].value_counts()
        total = len(self.data_frame)
        weights = torch.tensor([total / class_counts[cls] for cls in self.classes])
        return weights / weights.sum()


def get_data_loaders(
    csv_file: str,
    root_dir: str,
    batch_size: int = 32,
    num_workers: int = 4,
    split_ratio: Tuple[float, float, float] = (0.7, 0.15, 0.15),
    random_seed: int = 42
) -> Tuple[Dict[str, DataLoader], List[str], Dict[str, int]]:
    """
    Create data loaders for train, validation, and test sets.

    Args:
        csv_file: Path to CSV file with image paths and labels
        root_dir: Root directory containing images
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes for data loading
        split_ratio: Tuple of (train, val, test) ratios
        random_seed: Random seed for reproducibility

    Returns:
        Tuple of (data loaders dict, class names, class-to-index mapping)
    """
    # Define transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = PhenologyDataset(
        csv_file, root_dir, train_transform, split_ratio, 'train', random_seed
    )
    val_dataset = PhenologyDataset(
        csv_file, root_dir, val_test_transform, split_ratio, 'val', random_seed
    )
    test_dataset = PhenologyDataset(
        csv_file, root_dir, val_test_transform, split_ratio, 'test', random_seed
    )

    # Create data loaders
    data_loaders = {
        'train': DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True
        ),
        'val': DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True
        ),
        'test': DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True
        )
    }

    return data_loaders, train_dataset.classes, train_dataset.class_to_idx

259
Code/Supervised_learning/resnet/src/evaluate.py
Normal file
@@ -0,0 +1,259 @@
"""Evaluation script for ResNet phenology classifier."""

import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, recall_score, f1_score, confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import json
import os
from typing import Tuple
from tqdm import tqdm

from src.model import create_model
from src.data_loader import get_data_loaders
from src.utils import load_checkpoint


def plot_confusion_matrix(
    cm: np.ndarray,
    class_names: list,
    save_path: str = None,
    normalize: bool = True
):
    """
    Plot confusion matrix.

    Args:
        cm: Confusion matrix
        class_names: List of class names
        save_path: Path to save plot
        normalize: Whether to normalize the confusion matrix
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
    else:
        fmt = 'd'

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
    )
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix' + (' (Normalized)' if normalize else ''))
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")

    plt.show()


def evaluate_model(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: str,
    class_names: list
) -> Tuple[dict, np.ndarray]:
    """
    Evaluate model and compute metrics.

    Args:
        model: Model to evaluate
        dataloader: Data loader
        device: Device to use
        class_names: List of class names

    Returns:
        Tuple of (metrics dictionary, confusion matrix)
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Evaluating')
        for images, labels in pbar:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds, average='macro')
    macro_f1 = f1_score(all_labels, all_preds, average='macro')

    # Per-class metrics
    per_class_recall = recall_score(all_labels, all_preds, average=None)
    per_class_f1 = f1_score(all_labels, all_preds, average=None)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Classification report
    report = classification_report(
        all_labels, all_preds,
        target_names=class_names,
        digits=4
    )

    metrics = {
        'accuracy': float(accuracy),
        'recall_macro': float(recall),
        'f1_macro': float(macro_f1),
        'per_class_recall': {class_names[i]: float(per_class_recall[i]) for i in range(len(class_names))},
        'per_class_f1': {class_names[i]: float(per_class_f1[i]) for i in range(len(class_names))},
        'confusion_matrix': cm.tolist(),
        'classification_report': report
    }

    return metrics, cm


def evaluate(
    model_path: str,
    data_dir: str,
    csv_file: str,
    class_mapping_path: str,
    output_dir: str = 'evaluation',
    batch_size: int = 32,
    num_workers: int = 4,
    split: str = 'test'
):
    """
    Main evaluation function.

    Args:
        model_path: Path to trained model checkpoint
        data_dir: Directory containing images
        csv_file: Path to CSV file with labels
        class_mapping_path: Path to class mapping JSON file
        output_dir: Directory to save evaluation results
        batch_size: Batch size
        num_workers: Number of data loader workers
        split: Which split to evaluate ('val' or 'test')
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Device configuration
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load class mapping
    print("\nLoading class mapping...")
    with open(class_mapping_path, 'r') as f:
        class_mapping = json.load(f)

    classes = class_mapping['classes']
    num_classes = len(classes)
    print(f"Number of classes: {num_classes}")
    print(f"Classes: {classes}")

    # Load data
    print(f"\nLoading {split} data...")
    data_loaders, _, _ = get_data_loaders(
        csv_file=csv_file,
        root_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers
    )

    dataloader = data_loaders[split]
    print(f"Number of {split} samples: {len(dataloader.dataset)}")

    # Load model
    print("\nLoading model...")
    model = create_model(num_classes, pretrained=False, device=device)
    load_checkpoint(model_path, model, device=device)

    # Evaluate
    print(f"\nEvaluating on {split} set...")
    metrics, cm = evaluate_model(model, dataloader, device, classes)

    # Print results
    print(f"\n{'='*50}")
    print(f"EVALUATION RESULTS ({split.upper()} SET)")
    print(f"{'='*50}")
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)")
    print(f"Recall (Macro): {metrics['recall_macro']:.4f}")
    print(f"F1 (Macro): {metrics['f1_macro']:.4f}")
    print(f"\nPer-Class Metrics:")
    print("-" * 50)
    for cls in classes:
        print(f"  {cls:20s} - Recall: {metrics['per_class_recall'][cls]:.4f}, F1: {metrics['per_class_f1'][cls]:.4f}")
    print(f"{'='*50}")

    print(f"\nClassification Report:")
    print(metrics['classification_report'])

    # Save metrics
    metrics_path = os.path.join(output_dir, f'{split}_metrics.json')
    with open(metrics_path, 'w') as f:
        # Don't save classification report in JSON (it's a string)
        metrics_to_save = {k: v for k, v in metrics.items() if k != 'classification_report'}
        json.dump(metrics_to_save, f, indent=2)
    print(f"\nMetrics saved to: {metrics_path}")

    # Save classification report
    report_path = os.path.join(output_dir, f'{split}_classification_report.txt')
    with open(report_path, 'w') as f:
        f.write(metrics['classification_report'])
    print(f"Classification report saved to: {report_path}")

    # Plot confusion matrix
    cm_path = os.path.join(output_dir, f'{split}_confusion_matrix.png')
    plot_confusion_matrix(cm, classes, save_path=cm_path, normalize=True)

    # Also save non-normalized version
    cm_raw_path = os.path.join(output_dir, f'{split}_confusion_matrix_raw.png')
    plot_confusion_matrix(cm, classes, save_path=cm_raw_path, normalize=False)

    # Check if model meets success criteria
    if metrics['accuracy'] >= 0.9:
        print(f"\n✓ Model meets success criteria (accuracy >= 90%)")
    else:
        print(f"\n✗ Model does not meet success criteria (accuracy < 90%)")
        print(f"  Current accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  Model should be flagged for retraining")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate ResNet phenology classifier')
    parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing images')
    parser.add_argument('--csv_file', type=str, required=True, help='Path to CSV file with labels')
    parser.add_argument('--class_mapping', type=str, required=True, help='Path to class mapping JSON')
    parser.add_argument('--output_dir', type=str, default='evaluation', help='Directory to save results')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers')
    parser.add_argument('--split', type=str, default='test', choices=['val', 'test'], help='Split to evaluate')

    args = parser.parse_args()

    evaluate(
        model_path=args.model_path,
        data_dir=args.data_dir,
        csv_file=args.csv_file,
        class_mapping_path=args.class_mapping,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        split=args.split
    )

226
Code/Supervised_learning/resnet/src/inference.py
Normal file
@@ -0,0 +1,226 @@
"""Inference script for ResNet phenology classifier."""

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import argparse
import json
import os
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np

from src.model import create_model
from src.utils import load_checkpoint


def load_inference_model(model_path: str, class_mapping_path: str, device: str = None):
    """
    Load model for inference.

    Args:
        model_path: Path to model checkpoint
        class_mapping_path: Path to class mapping JSON
        device: Device to use

    Returns:
        Tuple of (model, classes, device)
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load class mapping
    with open(class_mapping_path, 'r') as f:
        class_mapping = json.load(f)

    classes = class_mapping['classes']
    num_classes = len(classes)

    # Load model
    model = create_model(num_classes, pretrained=False, device=device)
    load_checkpoint(model_path, model, device=device)
    model.eval()

    return model, classes, device


def preprocess_image(image_path: str) -> Tuple[torch.Tensor, Image.Image]:
    """
    Preprocess image for inference.

    Args:
        image_path: Path to image file

    Returns:
        Tuple of (preprocessed image tensor, original PIL image)
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

    return image_tensor, image


def predict(
    model: torch.nn.Module,
    image_tensor: torch.Tensor,
    classes: list,
    device: str
) -> dict:
    """
    Make prediction on image.

    Args:
        model: Trained model
        image_tensor: Preprocessed image tensor
        classes: List of class names
        device: Device to use

    Returns:
        Dictionary with prediction results
    """
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    predicted_class = classes[predicted.item()]
    confidence_score = confidence.item()

    # Get probabilities for all classes
    all_probs = {classes[i]: probabilities[0][i].item() for i in range(len(classes))}

    result = {
        'phase': predicted_class,
        'confidence': confidence_score,
        'probabilities': all_probs
    }

    return result


def visualize_prediction(
    image: Image.Image,
    result: dict,
    save_path: str = None
):
    """
    Visualize prediction result.

    Args:
        image: Original PIL image
        result: Prediction result dictionary
        save_path: Optional path to save visualization
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Display image
    ax1.imshow(image)
    ax1.axis('off')
    ax1.set_title(f"Predicted: {result['phase']}\nConfidence: {result['confidence']:.2%}",
                  fontsize=14, fontweight='bold')

    # Display probabilities
    classes = list(result['probabilities'].keys())
    probs = list(result['probabilities'].values())

    colors = ['green' if cls == result['phase'] else 'lightblue' for cls in classes]

    ax2.barh(classes, probs, color=colors)
    ax2.set_xlabel('Probability', fontsize=12)
    ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax2.set_xlim(0, 1)

    # Add probability values on bars
    for i, (cls, prob) in enumerate(zip(classes, probs)):
        ax2.text(prob + 0.02, i, f'{prob:.2%}', va='center', fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")

    plt.show()


def classify_image(
    image_path: str,
    model_path: str,
    class_mapping_path: str,
    output_dir: str = None,
    visualize: bool = True
):
    """
    Classify a single image.

    Args:
        image_path: Path to image file
        model_path: Path to model checkpoint
        class_mapping_path: Path to class mapping JSON
        output_dir: Optional directory to save results
        visualize: Whether to create visualization
    """
    print(f"Loading model from {model_path}...")
    model, classes, device = load_inference_model(model_path, class_mapping_path)

    print(f"Processing image: {image_path}")
    image_tensor, original_image = preprocess_image(image_path)

    print("Making prediction...")
    result = predict(model, image_tensor, classes, device)

    # Print results
    print(f"\n{'='*50}")
    print(f"CLASSIFICATION RESULT")
    print(f"{'='*50}")
    print(f"Predicted Phase: {result['phase']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"\nAll Probabilities:")
    for cls, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
        print(f"  {cls:20s}: {prob:.2%}")
    print(f"{'='*50}")

    # Save results
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        # Save JSON result
        result_path = os.path.join(output_dir, 'prediction.json')
        with open(result_path, 'w') as f:
            json.dump(result, f, indent=2)
        print(f"\nPrediction saved to: {result_path}")

    # Visualize
    if visualize:
        save_path = os.path.join(output_dir, 'prediction_visualization.png') if output_dir else None
        visualize_prediction(original_image, result, save_path)

    return result


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classify plant image by phenological phase')
    parser.add_argument('--image_path', type=str, required=True, help='Path to image file')
    parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--class_mapping', type=str, required=True, help='Path to class mapping JSON')
    parser.add_argument('--output_dir', type=str, default=None, help='Directory to save results')
    parser.add_argument('--no_visualize', action='store_true', help='Do not create visualization')

    args = parser.parse_args()

    classify_image(
        image_path=args.image_path,
        model_path=args.model_path,
        class_mapping_path=args.class_mapping,
        output_dir=args.output_dir,
        visualize=not args.no_visualize
    )

106
Code/Supervised_learning/resnet/src/model.py
Normal file
@@ -0,0 +1,106 @@
"""ResNet model definition for phenology classification."""

import torch
import torch.nn as nn
from torchvision import models
from typing import Optional


class ResNetPhenologyClassifier(nn.Module):
    """ResNet50-based classifier for plant phenology phases."""

    def __init__(self, num_classes: int, pretrained: bool = True, freeze_backbone: bool = False):
        """
        Initialize the ResNet classifier.

        Args:
            num_classes: Number of phenology phase classes
            pretrained: Whether to use pretrained ImageNet weights
            freeze_backbone: Whether to freeze backbone layers
        """
        super(ResNetPhenologyClassifier, self).__init__()

        # Load pretrained ResNet50
        if pretrained:
            self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        else:
            self.resnet = models.resnet50(weights=None)

        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Replace final fully connected layer
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape (batch_size, 3, 224, 224)

        Returns:
            Output tensor of shape (batch_size, num_classes)
        """
        return self.resnet(x)

    def unfreeze_backbone(self):
        """Unfreeze all layers for fine-tuning."""
        for param in self.resnet.parameters():
            param.requires_grad = True

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features before the classification layer.

        Args:
            x: Input tensor

        Returns:
            Feature tensor
        """
        # Forward through all layers except fc
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = self.resnet.avgpool(x)
        x = torch.flatten(x, 1)

        return x


def create_model(num_classes: int, pretrained: bool = True, device: Optional[str] = None) -> ResNetPhenologyClassifier:
    """
    Create and initialize a ResNet model.

    Args:
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        device: Device to place model on ('cuda' or 'cpu')

    Returns:
        Initialized model
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = ResNetPhenologyClassifier(num_classes, pretrained=pretrained)
    model = model.to(device)

    return model
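# Minimal usage sketch (illustrative, not part of the original file): builds
# the classifier and runs a dummy forward pass; shapes follow the docstrings above.
#
#   model = create_model(num_classes=3, pretrained=False, device='cpu')
#   x = torch.randn(2, 3, 224, 224)   # batch of 2 RGB images
#   logits = model(x)                 # -> shape (2, 3)
#   feats = model.get_features(x)     # -> pooled ResNet50 features before fc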
276
Code/Supervised_learning/resnet/src/train.py
Normal file
@ -0,0 +1,276 @@
"""Training script for ResNet phenology classifier."""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import argparse
import os
from typing import Dict, Tuple
import json

from src.model import create_model
from src.data_loader import get_data_loaders
from src.utils import set_seed, save_checkpoint, plot_training_history


def train_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: str
) -> Tuple[float, float]:
    """
    Train for one epoch.

    Args:
        model: Model to train
        dataloader: Training data loader
        criterion: Loss function
        optimizer: Optimizer
        device: Device to train on

    Returns:
        Tuple of (average loss, accuracy)
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar
        pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def validate(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    device: str
) -> Tuple[float, float]:
    """
    Validate the model.

    Args:
        model: Model to validate
        dataloader: Validation data loader
        criterion: Loss function
        device: Device to validate on

    Returns:
        Tuple of (average loss, accuracy)
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def train(
    data_dir: str,
    csv_file: str,
    output_dir: str = 'models',
    epochs: int = 50,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    num_workers: int = 4,
    random_seed: int = 42,
    pretrained: bool = True,
    early_stopping_patience: int = 10
):
    """
    Main training function.

    Args:
        data_dir: Directory containing images
        csv_file: Path to CSV file with labels
        output_dir: Directory to save models
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        num_workers: Number of data loader workers
        random_seed: Random seed for reproducibility
        pretrained: Use pretrained weights
        early_stopping_patience: Patience for early stopping
    """
    # Set seed for reproducibility
    set_seed(random_seed)

    # Device configuration
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    print("\nLoading data...")
    data_loaders, classes, class_to_idx = get_data_loaders(
        csv_file=csv_file,
        root_dir=data_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        random_seed=random_seed
    )

    num_classes = len(classes)
    print(f"Number of classes: {num_classes}")
    print(f"Classes: {classes}")

    # Save class mapping
    class_mapping = {'classes': classes, 'class_to_idx': class_to_idx}
    with open(os.path.join(output_dir, 'class_mapping.json'), 'w') as f:
        json.dump(class_mapping, f, indent=2)

    # Create model
    print("\nCreating model...")
    model = create_model(num_classes, pretrained=pretrained, device=device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # Training history
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    best_val_acc = 0.0
    epochs_without_improvement = 0

    print("\nStarting training...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        print("-" * 50)

        # Train
        train_loss, train_acc = train_epoch(model, data_loaders['train'], criterion, optimizer, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

        # Validate
        val_loss, val_acc = validate(model, data_loaders['val'], criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Update learning rate
        scheduler.step(val_loss)

        # Save history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            save_checkpoint(
                model, optimizer, epoch, val_loss, val_acc,
                os.path.join(output_dir, 'best_model.pth')
            )
            print(f"✓ New best model saved! Val Acc: {val_acc:.2f}%")
        else:
            epochs_without_improvement += 1

        # Early stopping
        if epochs_without_improvement >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break

    # Save final model
    save_checkpoint(
        model, optimizer, epoch, val_loss, val_acc,
        os.path.join(output_dir, 'final_model.pth')
    )

    # Plot training history
    plot_training_history(
        train_losses, val_losses, train_accs, val_accs,
        save_path=os.path.join(output_dir, 'training_history.png')
    )

    print(f"\n{'='*50}")
    print("Training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Model saved to: {output_dir}")
    print(f"{'='*50}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train ResNet phenology classifier')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing images')
    parser.add_argument('--csv_file', type=str, required=True, help='Path to CSV file with labels')
    parser.add_argument('--output_dir', type=str, default='models', help='Directory to save models')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loader workers')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--no_pretrained', action='store_true', help='Do not use pretrained weights')
    parser.add_argument('--patience', type=int, default=10, help='Early stopping patience')

    args = parser.parse_args()

    train(
        data_dir=args.data_dir,
        csv_file=args.csv_file,
        output_dir=args.output_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        num_workers=args.num_workers,
        random_seed=args.seed,
        pretrained=not args.no_pretrained,
        early_stopping_patience=args.patience
    )
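# Example invocation (a sketch — it assumes the `src/` package layout at the
# repo root, so the script is run as a module, and a labels CSV with
# image_path/phase columns as exercised by the unit tests):
#
#   python -m src.train --data_dir data/images --csv_file data/labels.csv \
#       --output_dir models --epochs 50 --batch_size 32 --patience 10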
183
Code/Supervised_learning/resnet/src/utils.py
Normal file
@ -0,0 +1,183 @@
"""Utility functions for training and preprocessing."""

import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional
import os


def set_seed(seed: int = 42):
    """
    Set random seed for reproducibility.

    Args:
        seed: Random seed value
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    accuracy: float,
    filepath: str
):
    """
    Save model checkpoint.

    Args:
        model: Model to save
        optimizer: Optimizer state
        epoch: Current epoch
        loss: Current loss
        accuracy: Current accuracy
        filepath: Path to save checkpoint
    """
    # Guard against bare filenames: os.path.dirname('') would make makedirs fail
    os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': accuracy
    }

    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")


def load_checkpoint(
    filepath: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = 'cuda'
) -> dict:
    """
    Load model checkpoint.

    Args:
        filepath: Path to checkpoint file
        model: Model to load weights into
        optimizer: Optional optimizer to load state into
        device: Device to load model on

    Returns:
        Dictionary with checkpoint information
    """
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    print(f"Checkpoint loaded from {filepath}")
    print(f"Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f}, Accuracy: {checkpoint['accuracy']:.4f}")

    return checkpoint


def plot_training_history(
    train_losses: List[float],
    val_losses: List[float],
    train_accs: List[float],
    val_accs: List[float],
    save_path: Optional[str] = None
):
    """
    Plot training history.

    Args:
        train_losses: List of training losses
        val_losses: List of validation losses
        train_accs: List of training accuracies
        val_accs: List of validation accuracies
        save_path: Optional path to save plot
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot loss
    ax1.plot(train_losses, label='Train Loss', marker='o')
    ax1.plot(val_losses, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot accuracy
    ax2.plot(train_accs, label='Train Accuracy', marker='o')
    ax2.plot(val_accs, label='Val Accuracy', marker='s')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history plot saved to {save_path}")

    plt.show()


def visualize_predictions(
    images: torch.Tensor,
    predictions: torch.Tensor,
    labels: torch.Tensor,
    class_names: List[str],
    num_images: int = 6,
    save_path: Optional[str] = None
):
    """
    Visualize model predictions on images.

    Args:
        images: Batch of images
        predictions: Model predictions
        labels: True labels
        class_names: List of class names
        num_images: Number of images to display
        save_path: Optional path to save visualization
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.ravel()

    # Denormalize images (ImageNet mean/std)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for idx in range(min(num_images, len(images))):
        img = images[idx].cpu() * std + mean
        img = img.permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)

        pred_class = class_names[predictions[idx]]
        true_class = class_names[labels[idx]]

        axes[idx].imshow(img)
        axes[idx].axis('off')

        color = 'green' if pred_class == true_class else 'red'
        axes[idx].set_title(f'Pred: {pred_class}\nTrue: {true_class}', color=color, fontsize=12)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Predictions visualization saved to {save_path}")

    plt.show()
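# Resume sketch (illustrative, not part of the original file): restoring a run
# with load_checkpoint; `model` and `optimizer` must match the saved shapes.
#
#   ckpt = load_checkpoint('models/best_model.pth', model, optimizer, device='cpu')
#   start_epoch = ckpt['epoch'] + 1   # continue training from the next epoch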
144
Code/Supervised_learning/resnet/tests/test_data_loader.py
Normal file
@ -0,0 +1,144 @@
"""Unit tests for data loader module."""

import pytest
import torch
import pandas as pd
import os
from PIL import Image
import tempfile
import shutil
from src.data_loader import PhenologyDataset, get_data_loaders


@pytest.fixture
def sample_dataset():
    """Create a temporary dataset for testing."""
    temp_dir = tempfile.mkdtemp()
    img_dir = os.path.join(temp_dir, 'images')
    os.makedirs(img_dir, exist_ok=True)

    # Create sample images
    phases = ['vegetative', 'flowering', 'fruiting']
    image_paths = []

    for i, phase in enumerate(phases * 10):  # 30 images total
        img = Image.new('RGB', (224, 224), color=(i*8, i*8, i*8))
        img_path = f'img_{i:03d}.jpg'
        img.save(os.path.join(img_dir, img_path))
        image_paths.append((img_path, phase))

    # Create CSV
    csv_path = os.path.join(temp_dir, 'labels.csv')
    df = pd.DataFrame(image_paths, columns=['image_path', 'phase'])
    df.to_csv(csv_path, index=False)

    yield csv_path, img_dir, phases

    # Cleanup
    shutil.rmtree(temp_dir)


@pytest.mark.unit
def test_dataset_initialization(sample_dataset):
    """Test dataset initialization."""
    csv_path, img_dir, phases = sample_dataset

    dataset = PhenologyDataset(
        csv_file=csv_path,
        root_dir=img_dir,
        split='train'
    )

    assert len(dataset) > 0
    assert len(dataset.classes) == len(phases)
    assert set(dataset.classes) == set(phases)


@pytest.mark.unit
def test_dataset_split_ratios(sample_dataset):
    """Test dataset splitting."""
    csv_path, img_dir, _ = sample_dataset

    train_dataset = PhenologyDataset(csv_path, img_dir, split='train')
    val_dataset = PhenologyDataset(csv_path, img_dir, split='val')
    test_dataset = PhenologyDataset(csv_path, img_dir, split='test')

    total = len(train_dataset) + len(val_dataset) + len(test_dataset)

    # Check approximate split ratios (70/15/15)
    assert abs(len(train_dataset) / total - 0.7) < 0.1
    assert abs(len(val_dataset) / total - 0.15) < 0.1
    assert abs(len(test_dataset) / total - 0.15) < 0.1


@pytest.mark.unit
def test_dataset_getitem(sample_dataset):
    """Test getting items from the dataset."""
    csv_path, img_dir, _ = sample_dataset

    dataset = PhenologyDataset(csv_path, img_dir, split='train')
    image, label = dataset[0]

    # Without a transform, __getitem__ returns a PIL image and an int label
    assert isinstance(label, int)
    assert label >= 0 and label < len(dataset.classes)


@pytest.mark.unit
def test_data_loaders(sample_dataset):
    """Test data loader creation."""
    csv_path, img_dir, _ = sample_dataset

    data_loaders, classes, class_to_idx = get_data_loaders(
        csv_file=csv_path,
        root_dir=img_dir,
        batch_size=4,
        num_workers=0  # Use 0 for testing
    )

    assert 'train' in data_loaders
    assert 'val' in data_loaders
    assert 'test' in data_loaders
    assert len(classes) > 0
    assert len(class_to_idx) == len(classes)

    # Test getting a batch
    batch = next(iter(data_loaders['train']))
    images, labels = batch

    assert images.shape[0] <= 4  # Batch size
    assert images.shape[1] == 3  # RGB channels
    assert images.shape[2] == 224  # Height
    assert images.shape[3] == 224  # Width
    assert len(labels) == images.shape[0]


@pytest.mark.unit
def test_class_weights(sample_dataset):
    """Test class weight calculation."""
    csv_path, img_dir, _ = sample_dataset

    dataset = PhenologyDataset(csv_path, img_dir, split='train')
    weights = dataset.get_class_weights()

    assert weights.shape[0] == len(dataset.classes)
    assert torch.all(weights > 0)
    assert torch.isclose(weights.sum(), torch.tensor(1.0), atol=1e-6)


@pytest.mark.unit
def test_reproducibility(sample_dataset):
    """Test that the same seed produces the same splits."""
    csv_path, img_dir, _ = sample_dataset

    dataset1 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)
    dataset2 = PhenologyDataset(csv_path, img_dir, split='train', random_seed=42)

    # Should have the same samples
    assert len(dataset1) == len(dataset2)

    # Check the first few samples are the same
    for i in range(min(5, len(dataset1))):
        _, label1 = dataset1[i]
        _, label2 = dataset2[i]
        assert label1 == label2
89
Code/Supervised_learning/resnet/tests/test_evaluate.py
Normal file
@ -0,0 +1,89 @@
"""Unit tests for evaluation module."""

import pytest
import torch
import numpy as np
from src.evaluate import evaluate_model, plot_confusion_matrix
from src.model import ResNetPhenologyClassifier


@pytest.fixture
def mock_model():
    """Create a small model for testing."""
    return ResNetPhenologyClassifier(num_classes=3, pretrained=False)


@pytest.fixture
def mock_dataloader():
    """Create a mock dataloader for testing."""
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 3, (8,))
    dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=4)


@pytest.mark.unit
def test_evaluate_model(mock_model, mock_dataloader):
    """Test model evaluation."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    metrics, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Check metrics exist
    assert 'accuracy' in metrics
    assert 'recall_macro' in metrics
    assert 'f1_macro' in metrics
    assert 'per_class_recall' in metrics
    assert 'per_class_f1' in metrics
    assert 'confusion_matrix' in metrics

    # Check metric ranges
    assert 0 <= metrics['accuracy'] <= 1
    assert 0 <= metrics['recall_macro'] <= 1
    assert 0 <= metrics['f1_macro'] <= 1

    # Check confusion matrix shape
    assert len(metrics['confusion_matrix']) == 3
    assert len(metrics['confusion_matrix'][0]) == 3
    assert isinstance(cm, np.ndarray)
    assert cm.shape == (3, 3)


@pytest.mark.unit
def test_confusion_matrix_values(mock_model, mock_dataloader):
    """Test confusion matrix values."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    _, cm = evaluate_model(model, mock_dataloader, device, class_names)

    # Confusion matrix should sum to the total number of samples
    total_samples = len(mock_dataloader.dataset)
    assert cm.sum() == total_samples

    # All values should be non-negative
    assert np.all(cm >= 0)


@pytest.mark.unit
def test_per_class_metrics(mock_model, mock_dataloader):
    """Test per-class metrics."""
    device = 'cpu'
    model = mock_model.to(device)
    class_names = ['class_0', 'class_1', 'class_2']

    metrics, _ = evaluate_model(model, mock_dataloader, device, class_names)

    # Check per-class metrics have the correct number of classes
    assert len(metrics['per_class_recall']) == 3
    assert len(metrics['per_class_f1']) == 3

    # Check all classes have metrics
    for cls in class_names:
        assert cls in metrics['per_class_recall']
        assert cls in metrics['per_class_f1']
        assert 0 <= metrics['per_class_recall'][cls] <= 1
        assert 0 <= metrics['per_class_f1'][cls] <= 1
113
Code/Supervised_learning/resnet/tests/test_inference.py
Normal file
@ -0,0 +1,113 @@
"""Unit tests for inference module."""

import pytest
import torch
import json
import tempfile
import os
from PIL import Image
from src.inference import preprocess_image, predict, load_inference_model
from src.model import ResNetPhenologyClassifier


@pytest.fixture
def mock_model():
    """Create a small model for testing."""
    return ResNetPhenologyClassifier(num_classes=3, pretrained=False)


@pytest.fixture
def temp_image():
    """Create a temporary test image."""
    temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
    img = Image.new('RGB', (224, 224), color='red')
    img.save(temp_file.name)
    yield temp_file.name
    os.unlink(temp_file.name)


@pytest.fixture
def temp_class_mapping():
    """Create a temporary class mapping file."""
    temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
    mapping = {
        'classes': ['vegetative', 'flowering', 'fruiting'],
        'class_to_idx': {'vegetative': 0, 'flowering': 1, 'fruiting': 2}
    }
    json.dump(mapping, temp_file)
    temp_file.flush()
    yield temp_file.name
    os.unlink(temp_file.name)


@pytest.mark.unit
def test_preprocess_image(temp_image):
    """Test image preprocessing."""
    image_tensor, original_image = preprocess_image(temp_image)

    assert image_tensor.shape == (1, 3, 224, 224)  # Batch size 1, RGB, 224x224
    assert isinstance(original_image, Image.Image)
    assert original_image.size == (224, 224)


@pytest.mark.unit
def test_predict(mock_model, temp_image):
    """Test prediction."""
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()

    classes = ['vegetative', 'flowering', 'fruiting']
    image_tensor, _ = preprocess_image(temp_image)

    result = predict(model, image_tensor, classes, device)

    # Check result structure
    assert 'phase' in result
    assert 'confidence' in result
    assert 'probabilities' in result

    # Check values
    assert result['phase'] in classes
    assert 0 <= result['confidence'] <= 1
    assert len(result['probabilities']) == len(classes)

    # Check probabilities sum to ~1
    prob_sum = sum(result['probabilities'].values())
    assert abs(prob_sum - 1.0) < 1e-5


@pytest.mark.unit
def test_predict_confidence(mock_model, temp_image):
    """Test that prediction confidence matches the max probability."""
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()

    classes = ['vegetative', 'flowering', 'fruiting']
    image_tensor, _ = preprocess_image(temp_image)

    result = predict(model, image_tensor, classes, device)

    # Confidence should match the probability of the predicted class
    predicted_prob = result['probabilities'][result['phase']]
    assert abs(result['confidence'] - predicted_prob) < 1e-6


@pytest.mark.unit
def test_predict_consistency(mock_model, temp_image):
    """Test that predictions are consistent."""
    device = 'cpu'
    model = mock_model.to(device)
    model.eval()

    classes = ['vegetative', 'flowering', 'fruiting']
    image_tensor, _ = preprocess_image(temp_image)

    # Run prediction twice
    result1 = predict(model, image_tensor, classes, device)
    result2 = predict(model, image_tensor, classes, device)

    # Should get the same results
    assert result1['phase'] == result2['phase']
    assert abs(result1['confidence'] - result2['confidence']) < 1e-6
93
Code/Supervised_learning/resnet/tests/test_train.py
Normal file
@ -0,0 +1,93 @@
"""Unit tests for training module."""

import pytest
import torch
import torch.nn as nn
from src.train import train_epoch, validate
from src.model import ResNetPhenologyClassifier


@pytest.fixture
def mock_model():
    """Create a small model for testing."""
    return ResNetPhenologyClassifier(num_classes=3, pretrained=False)


@pytest.fixture
def mock_dataloader():
    """Create a mock dataloader for testing."""
    # Create dummy data
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 3, (8,))
    dataset = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=4)


@pytest.mark.unit
def test_train_epoch(mock_model, mock_dataloader):
    """Test training for one epoch."""
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    loss, acc = train_epoch(model, mock_dataloader, criterion, optimizer, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    assert loss >= 0
    assert 0 <= acc <= 100


@pytest.mark.unit
def test_validate(mock_model, mock_dataloader):
    """Test validation."""
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()

    loss, acc = validate(model, mock_dataloader, criterion, device)

    assert isinstance(loss, float)
    assert isinstance(acc, float)
    assert loss >= 0
    assert 0 <= acc <= 100


@pytest.mark.unit
def test_model_gradients(mock_model):
    """Test that gradients are computed during training."""
    device = 'cpu'
    model = mock_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    # Create a dummy batch
    images = torch.randn(2, 3, 224, 224)
    labels = torch.randint(0, 3, (2,))

    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Check gradients exist
    has_gradients = any(p.grad is not None for p in model.parameters())
    assert has_gradients


@pytest.mark.unit
def test_model_inference_mode(mock_model):
    """Test that the model can switch between train and eval modes."""
    model = mock_model

    # Train mode
    model.train()
    assert model.training

    # Eval mode
    model.eval()
    assert not model.training
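# Running the suite (a sketch — it assumes the `unit` marker is registered in
# pytest.ini/pyproject.toml and that tests are run from the resnet/ project root):
#
#   pytest tests/ -m unit -q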
298
Code/Supervised_learning/train_resnet50.py
Normal file
@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
Helper script to train ResNet50 - Nocciola
Comparison with MobileNetV2
"""

import os
import sys
import subprocess
import argparse

# Path configuration
PROJECT_DIR = r'C:\Users\sof12\Desktop\ML'
RESNET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'ResNET.py')
MOBILENET_SCRIPT = os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning', 'MobileNetV1.py')
PYTHON_EXE = r'C:/Users/sof12/AppData/Local/Programs/Python/Python312/python.exe'

# Available datasets
DATASETS = {
    'original': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos.csv'),
    'filtered': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'metadatos_unidos_filtered.csv'),
    'assignments': os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'assignments.csv')
}

def check_files():
    """Verify that all required files exist"""
    print("🔍 === Checking files ===")

    files_to_check = {
        'ResNet Script': RESNET_SCRIPT,
        'MobileNet Script': MOBILENET_SCRIPT,
        'Python Executable': PYTHON_EXE
    }

    all_ok = True
    for name, path in files_to_check.items():
        if os.path.exists(path):
            print(f"✅ {name}: {path}")
        else:
            print(f"❌ {name}: {path}")
            all_ok = False

    # Check datasets
    print("\n📊 === Available datasets ===")
    for name, path in DATASETS.items():
        if os.path.exists(path):
            print(f"✅ {name}: {path}")
        else:
            print(f"❌ {name}: {path}")

    return all_ok

def train_resnet50(dataset_key='assignments', epochs=25, force_split=True):
    """Train the ResNet50 model"""
    dataset_path = DATASETS.get(dataset_key)

    if not dataset_path or not os.path.exists(dataset_path):
        print(f"❌ Dataset '{dataset_key}' not found")
        return False

    print("\n🚀 === Training ResNet50 ===")
    print(f"📊 Dataset: {dataset_key} ({dataset_path})")
    print(f"⏱️ Epochs: {epochs}")

    # Training command
    cmd = [
        PYTHON_EXE, RESNET_SCRIPT,
        '--csv_path', dataset_path,
        '--epochs', str(epochs)
    ]

    if force_split:
        cmd.append('--force_split')

    print("🔄 Command to run:")
    print(" ".join(cmd))

    try:
        # Change to the right directory
        os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))

        print("\n🏋️ Starting ResNet50 training...")
        result = subprocess.run(cmd, check=True)
        print("✅ ResNet50 training completed successfully!")
        return True

    except subprocess.CalledProcessError as e:
        print(f"❌ Error during ResNet50 training: {e}")
        return False
    except KeyboardInterrupt:
        print("\n⚠️ ResNet50 training interrupted by the user")
        return False

def train_mobilenet(dataset_key='assignments', epochs=25, force_split=True):
    """Train the MobileNetV2 model for comparison"""
    dataset_path = DATASETS.get(dataset_key)

    if not dataset_path or not os.path.exists(dataset_path):
        print(f"❌ Dataset '{dataset_key}' not found")
        return False

    print("\n🚀 === Training MobileNetV2 (comparison) ===")
    print(f"📊 Dataset: {dataset_key} ({dataset_path})")
    print(f"⏱️ Epochs: {epochs}")

    # Training command
    cmd = [
        PYTHON_EXE, MOBILENET_SCRIPT,
        '--csv_path', dataset_path,
        '--epochs', str(epochs)
    ]

    if force_split:
        cmd.append('--force_split')

    print("🔄 Command to run:")
    print(" ".join(cmd))

    try:
        # Change to the right directory
        os.chdir(os.path.join(PROJECT_DIR, 'Code', 'Supervised_learning'))

        print("\n🏋️ Starting MobileNetV2 training...")
        result = subprocess.run(cmd, check=True)
        print("✅ MobileNetV2 training completed successfully!")
        return True

    except subprocess.CalledProcessError as e:
        print(f"❌ Error during MobileNetV2 training: {e}")
        return False
    except KeyboardInterrupt:
        print("\n⚠️ MobileNetV2 training interrupted by the user")
        return False

def compare_models():
    """Print a comparison between the models"""
    print("""
📊 === ResNet50 vs MobileNetV2 comparison ===

RESNET50:
🟢 Advantages:
   - Higher learning capacity (25M parameters)
   - Better for complex datasets
   - Residual connections improve gradient flow
   - Excellent for transfer learning

🟡 Disadvantages:
   - Slower training and inference
   - Requires more memory
   - Overfits on small datasets

MOBILENETV2:
🟢 Advantages:
   - Fast and efficient (3.4M parameters)
   - Less prone to overfitting
   - Ideal for small/medium datasets
   - Lower memory usage

🟡 Disadvantages:
   - Lower capacity for complex patterns
   - May underfit on complex data

RECOMMENDATIONS:
📋 Nocciola dataset (small): MobileNetV2 recommended
📋 Large/complex dataset: ResNet50 recommended
📋 Production/mobile: MobileNetV2
📋 Research/maximum accuracy: ResNet50
""")

def main():
    parser = argparse.ArgumentParser(description='ResNet50 vs MobileNetV2 training for Nocciola')
    parser.add_argument('--model', choices=['resnet50', 'mobilenet', 'both'], default='resnet50',
                        help='Model to train')
    parser.add_argument('--dataset', choices=['original', 'filtered', 'assignments'], default='assignments',
                        help='Dataset to use')
    parser.add_argument('--epochs', type=int, default=25,
                        help='Number of epochs')
    parser.add_argument('--no_force_split', action='store_true',
                        help='Do not force recreating the split')
    parser.add_argument('--compare', action='store_true',
                        help='Show model comparison')
    parser.add_argument('--check_only', action='store_true',
                        help='Only verify files')

    args = parser.parse_args()

    print("🎯 === ResNet50 training for Nocciola ===")
    print(f"📁 Project directory: {PROJECT_DIR}")
    print(f"🐍 Python executable: {PYTHON_EXE}")

    # Verify files
    if not check_files():
        print("❌ Required files are missing")
        return False

    if args.check_only:
        print("ℹ️ Check-only requested. Exiting.")
        return True

    if args.compare:
        compare_models()
        return True

    force_split = not args.no_force_split
    success = True

    # Train the selected models
    if args.model in ['resnet50', 'both']:
        print("\n🤖 === Starting ResNet50 ===")
        success &= train_resnet50(args.dataset, args.epochs, force_split)

    if args.model in ['mobilenet', 'both']:
        print("\n🤖 === Starting MobileNetV2 ===")
        success &= train_mobilenet(args.dataset, args.epochs, force_split)

    if success:
        print("\n🎉 === Training run(s) completed ===")

        # Show result locations
        results_info = []
        if args.model in ['resnet50', 'both']:
            resnet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_resnet50_faseV')
            results_info.append(f"📁 ResNet50: {resnet_results}")

        if args.model in ['mobilenet', 'both']:
            mobilenet_results = os.path.join(PROJECT_DIR, 'Datasets', 'Nocciola_GBIF', 'results_mobilenet_faseV_V1')
            results_info.append(f"📁 MobileNetV2: {mobilenet_results}")

        for info in results_info:
            print(info)

        if args.model == 'both':
            print("\n💡 To compare results:")
            print("   - Review classification_report.txt in each directory")
            print("   - Compare the confusion matrices")
            print("   - Evaluate training times")

        return True

    return False

def show_help():
    """Show detailed help"""
    print("""
📚 === Help - ResNet50 Transfer Learning ===

NEW MODEL IMPLEMENTED:
ResNet50 with transfer learning optimized for the Nocciola dataset

MAIN FEATURES:
🔧 Architecture: pretrained ResNet50 (ImageNet)
⚙️ Optimizations: BatchNormalization, adaptive Dropout
🎯 Fine-tuning: selective unfreezing of residual layers
📊 Robust handling: imbalanced classes handled automatically

DIFFERENCES FROM MOBILENET:
- More conservative learning rate (5e-4 vs 1e-3)
- Less aggressive data augmentation
- More targeted fine-tuning (conv5_x layers)
- More fine-tuning epochs (15 vs 10)

USAGE EXAMPLES:

# Train only ResNet50
python train_resnet50.py --model resnet50 --epochs 25

# Compare both models
python train_resnet50.py --model both --epochs 20

# Analysis of differences
python train_resnet50.py --compare

# Verify setup
python train_resnet50.py --check_only

SUPPORTED DATASETS:
- assignments: CSV with class assignments (recommended)
- filtered: filtered dataset without problematic classes
- original: full dataset with robust handling

SPECIFIC OUTPUTS:
- best_resnet50_model.keras: best model during training
- final_resnet50_model.keras: final model
- resnet50_training_history.png: training plots
- resnet50_confusion_matrix.png: confusion matrix
- resnet50_prediction_examples.png: prediction examples
""")

if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] in ['--help', '-h', 'help']:
        show_help()
    else:
        success = main()
        if not success:
            print("\n❌ Process finished with errors")
            print("💡 Use --help for more information")
        else:
            print("\n🎉 Process completed successfully!")
433
Code/Unsupervised_learning/PCA_V1_C.py
Normal file
@ -0,0 +1,433 @@
#!/usr/bin/env python3

import os
import re
import json
import argparse
import warnings
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering, MiniBatchKMeans
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

import matplotlib.pyplot as plt
import seaborn as sns
import joblib

import tensorflow as tf
from keras.applications import MobileNetV2, EfficientNetB0
from keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
from keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from keras import backend as K

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
K.set_image_data_format("channels_last")

# -----------------------------
# Utils
# -----------------------------

def set_seed(seed: int = 42):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)

def _read_csv_any(path: str) -> pd.DataFrame:
    for enc in ("utf-8", "utf-8-sig", "latin-1"):
        try:
            return pd.read_csv(path, encoding=enc)
        except UnicodeDecodeError:
            continue
    # pandas uses `encoding_errors` (not `errors`) for decode fallbacks
    return pd.read_csv(path, encoding="utf-8", encoding_errors="replace")

def _normalize_col_name(name: str) -> str:
    if not isinstance(name, str):
        return ""
    s = name.strip().lower()
    m = re.match(r"^(.*)_(a|b)$", s)
    if m:
        s = m.group(1)
    for ch in [" ", "_", "-", ".", "/"]:
        s = s.replace(ch, "")
    return s

def find_matching_cols(df: pd.DataFrame, aliases: List[str]) -> List[str]:
    tgt = {_normalize_col_name(a) for a in aliases}
    out = []
    for c in df.columns:
        if _normalize_col_name(c) in tgt:
            out.append(c)
    return out

def best_filename_from_row(row: pd.Series, img_ext: str = ".jpg") -> Optional[str]:
    for key in ["filename", "file_name", "image", "image_name", "New_Name_With_Date", "New_Name", "Nombre_Nuevo", "Old_Name"]:
        if key in row and pd.notna(row[key]) and str(row[key]).strip() != "":
            fname = str(row[key]).strip()
            if not os.path.splitext(fname)[1]:
                fname = fname + img_ext
            return fname
    for key in ["basename_final", "basename"]:
        if key in row and pd.notna(row[key]) and str(row[key]).strip() != "":
            return f"{row[key]}{img_ext}"
    return None

def attach_paths_single_csv(df: pd.DataFrame, images_dir: str, img_ext: str = ".jpg", search_subdirs: bool = False) -> pd.DataFrame:
    paths = []
    miss = 0
    for _, r in df.iterrows():
        fname = best_filename_from_row(r, img_ext)
        if not fname:
            paths.append((None, None))
            miss += 1
            continue
        p = os.path.join(images_dir, fname)
        if not os.path.exists(p) and search_subdirs:
            # search subdirectories
            found = None
            for root, _, files in os.walk(images_dir):
                if fname in files:
                    found = os.path.join(root, fname)
                    break
            p = found if found else p
        paths.append((fname, p if p and isinstance(p, str) and os.path.exists(p) else None))
        if paths[-1][1] is None:
            miss += 1
    if miss:
        warnings.warn(f"{miss} listed files do not exist on disk. They will be ignored.")
    out = df.copy()
    out["filename"] = [t[0] for t in paths]
    out["path"] = [t[1] for t in paths]
    out = out[pd.notna(out["path"])].reset_index(drop=True)
    return out

# -----------------------------
# Embeddings
# -----------------------------

def make_preprocess(backbone: str):
    return mobilenet_preprocess if backbone == "mobilenet" else efficientnet_preprocess

def make_backbone_model(img_size: int, backbone: str) -> tf.keras.Model:
    tf.keras.backend.clear_session()
    K.set_image_data_format("channels_last")
    input_shape = (img_size, img_size, 3)
    if backbone == "efficientnet":
        try:
            model = EfficientNetB0(include_top=False, weights="imagenet", input_shape=input_shape, pooling="avg")
        except Exception as e:
            warnings.warn(f"Could not load EfficientNetB0 with ImageNet weights ({e}). Random weights will be used.")
            model = EfficientNetB0(include_top=False, weights=None, input_shape=input_shape, pooling="avg")
    else:
        model = MobileNetV2(include_top=False, weights="imagenet", input_shape=input_shape, pooling="avg")
    model.trainable = False
    return model

def build_dataset(paths: List[str], img_size: int, preprocess_fn, batch_size: int = 64) -> tf.data.Dataset:
    ds = tf.data.Dataset.from_tensor_slices(paths)
    def _load_tf(p):
        x = tf.io.read_file(p)
        x = tf.image.decode_jpeg(x, channels=3)
        x = tf.image.resize(x, [img_size, img_size], method="bilinear", antialias=True)
        x = tf.cast(x, tf.float32)
        x = preprocess_fn(x)
        return x
    return ds.map(_load_tf, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)

def compute_embeddings(model: tf.keras.Model, ds: tf.data.Dataset) -> np.ndarray:
    return model.predict(ds, verbose=1)

# -----------------------------
# Reduction + clustering
# -----------------------------

def fit_reduction(train_emb: np.ndarray, n_pca: int = 50):
    scaler = StandardScaler()
    Xs = scaler.fit_transform(train_emb)
    pca = PCA(n_components=min(n_pca, Xs.shape[1]))
    Z = pca.fit_transform(Xs)
    return scaler, pca, Z

def transform_reduction(emb: np.ndarray, scaler: StandardScaler, pca: PCA) -> np.ndarray:
    return pca.transform(scaler.transform(emb))

def _centers_from_labels(X: np.ndarray, y: np.ndarray) -> Optional[np.ndarray]:
    cs = []
    for c in sorted(set(y)):
        if c == -1:
            continue
        cs.append(X[y == c].mean(axis=0))
    return np.array(cs) if cs else None

def tune_dbscan(train_feats: np.ndarray,
                metric: str = "euclidean",
                min_samples_grid=(3, 5, 10),
                quantiles=(0.6, 0.7, 0.8, 0.9)) -> Tuple[Optional[DBSCAN], Optional[np.ndarray], Optional[np.ndarray]]:
    best = {"score": -np.inf, "model": None, "labels": None}
    for ms in min_samples_grid:
        k = max(2, min(ms, len(train_feats)-1))
        nbrs = NearestNeighbors(n_neighbors=k, metric=metric).fit(train_feats)
        dists, _ = nbrs.kneighbors(train_feats)
        kth = np.sort(dists[:, -1])
        for q in quantiles:
            eps = float(np.quantile(kth, q))
            m = DBSCAN(eps=eps, min_samples=ms, metric=metric, n_jobs=-1)
            y = m.fit_predict(train_feats)
            valid = y[y != -1]
            if len(np.unique(valid)) < 2:
                continue
            try:
                score = silhouette_score(train_feats[y != -1], y[y != -1])
            except Exception:
                score = -np.inf
            if score > best["score"]:
                best = {"score": score, "model": m, "labels": y}
    if best["model"] is None:
        return None, None, None
    return best["model"], best["labels"], _centers_from_labels(train_feats, best["labels"])
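# Note on the eps search above (explanatory comment, not original code): for
# each candidate min_samples, the distance to the k-th nearest neighbour is
# computed for every point, and eps is swept over quantiles of those sorted
# distances — the usual "k-distance plot" rule of thumb. A tiny standalone
# example of the same idea:
#
#   from sklearn.neighbors import NearestNeighbors
#   import numpy as np
#   X = np.random.rand(100, 2)
#   kth = np.sort(NearestNeighbors(n_neighbors=5).fit(X).kneighbors(X)[0][:, -1])
#   eps = float(np.quantile(kth, 0.8))   # candidate eps at the 80th percentile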
|
||||
def fit_cluster_algo(kind: str,
|
||||
n_clusters: int,
|
||||
train_feats: np.ndarray,
|
||||
fast_kmeans: bool = True,
|
||||
dbscan_eps: float = 0.8,
|
||||
dbscan_min_samples: int = 5,
|
||||
dbscan_metric: str = "euclidean",
|
||||
dbscan_auto: bool = False):
|
||||
if kind == "kmeans":
|
||||
m = MiniBatchKMeans(n_clusters=n_clusters, batch_size=2048, n_init=10, random_state=42) if fast_kmeans \
|
||||
else KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
|
||||
y = m.fit_predict(train_feats)
|
||||
return m, y, getattr(m, "cluster_centers_", None)
|
||||
|
||||
if kind == "dbscan":
|
||||
if dbscan_auto:
|
||||
m, y, centers = tune_dbscan(train_feats, metric=dbscan_metric)
|
||||
if m is None:
|
||||
warnings.warn("DBSCAN(auto) no encontró ≥2 clusters. Fallback a KMeans.")
|
||||
km = MiniBatchKMeans(n_clusters=max(n_clusters, 2), batch_size=2048, n_init=10, random_state=42)
|
||||
y = km.fit_predict(train_feats)
|
||||
return km, y, km.cluster_centers_
|
||||
print(f"DBSCAN(auto) seleccionado (metric={dbscan_metric}).")
|
||||
return m, y, centers
|
||||
m = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples, metric=dbscan_metric, n_jobs=-1)
|
||||
y = m.fit_predict(train_feats)
|
||||
uniq = set(y) - {-1}
|
||||
if len(uniq) < 2:
|
||||
warnings.warn(f"DBSCAN devolvió {len(uniq)} cluster(s) válido(s). Considera ajustar eps/min_samples/metric o usar --dbscan_auto.")
|
||||
return m, y, _centers_from_labels(train_feats, y)
|
||||
|
||||
ag = AgglomerativeClustering(n_clusters=n_clusters)
|
||||
y = ag.fit_predict(train_feats)
|
||||
centers = _centers_from_labels(train_feats, y)
|
||||
return ag, y, centers
|
||||
|
||||
def assign_to_nearest_centroid(feats: np.ndarray, centers: Optional[np.ndarray]) -> np.ndarray:
|
||||
if centers is None or len(centers) == 0:
|
||||
return np.full((feats.shape[0],), -1, dtype=int)
|
||||
d = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
|
||||
return np.argmin(d, axis=1)
|
||||
|
||||
def internal_metrics(X: np.ndarray, y: np.ndarray) -> dict:
|
||||
m = y != -1
|
||||
if m.sum() > 1 and len(np.unique(y[m])) > 1:
|
||||
return {
|
||||
"silhouette": float(silhouette_score(X[m], y[m])),
|
||||
"calinski_harabasz": float(calinski_harabasz_score(X[m], y[m])),
|
||||
"davies_bouldin": float(davies_bouldin_score(X[m], y[m])),
|
||||
}
|
||||
return {"silhouette": None, "calinski_harabasz": None, "davies_bouldin": None}
|
||||
|
||||
# -----------------------------
|
||||
# Plot
|
||||
# -----------------------------
|
||||
|
||||
def plot_scatter_2d(X2d: np.ndarray, labels: np.ndarray, title: str, out_path: str):
|
||||
plt.figure(figsize=(8, 6))
|
||||
uniq = np.unique(labels)
|
||||
if len(uniq) <= 1:
|
||||
sns.scatterplot(x=X2d[:, 0], y=X2d[:, 1], s=12, linewidth=0, color="#1f77b4", legend=False)
|
||||
else:
|
||||
palette = sns.color_palette("tab20", n_colors=len(uniq))
|
||||
sns.scatterplot(x=X2d[:, 0], y=X2d[:, 1], hue=labels, palette=palette, s=12, linewidth=0, legend=False)
|
||||
plt.title(title)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=180)
|
||||
plt.close()
|
||||
|
||||
# -----------------------------
|
||||
# Main
|
||||
# -----------------------------
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(description="Unsupervised clustering for Carciofo (single CSV)")
|
||||
p.add_argument("--images_dir", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF", help="Carpeta que contiene las imágenes")
|
||||
p.add_argument("--csv_path", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\joined_metadata.csv")
|
||||
p.add_argument("--out_dir", default=r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\TrainingV2")
|
||||
p.add_argument("--img_ext", default=".jpg")
|
||||
p.add_argument("--img_size", type=int, default=224)
|
||||
p.add_argument("--batch_size", type=int, default=64)
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
p.add_argument("--sample", type=int, default=None)
|
||||
p.add_argument("--search_subdirs", action="store_true", help="Buscar archivos faltantes en subcarpetas")
|
||||
p.add_argument("--backbone", choices=["mobilenet", "efficientnet"], default="mobilenet")
|
||||
p.add_argument("--cluster", choices=["kmeans", "dbscan", "agglomerative"], default="kmeans")
|
||||
p.add_argument("--n_clusters", type=int, default=5)
|
||||
p.add_argument("--fast_kmeans", action="store_true")
|
||||
# DBSCAN
|
||||
p.add_argument("--dbscan_eps", type=float, default=0.8)
|
||||
p.add_argument("--dbscan_min_samples", type=int, default=5)
|
||||
p.add_argument("--dbscan_metric", choices=["euclidean", "cosine", "manhattan"], default="euclidean")
|
||||
p.add_argument("--dbscan_auto", action="store_true")
|
||||
return p.parse_args()
|
||||
|
||||
# ...existing code...
|
||||
def main():
|
||||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
ensure_dir(args.out_dir)
|
||||
|
||||
print("Loading CSV...")
|
||||
df = _read_csv_any(args.csv_path)
|
||||
|
||||
print("Resolving filenames and verifying files on disk...")
|
||||
df = attach_paths_single_csv(df, args.images_dir, img_ext=args.img_ext, search_subdirs=args.search_subdirs)
|
||||
if len(df) == 0:
|
||||
print("No images found. Check images_dir and csv_path.")
|
||||
return
|
||||
|
||||
    # --- Only 'fase' (Carciofo does not use 'fase V' / 'fase R') ---
    phase_cols = find_matching_cols(df, ["fase"])
    if phase_cols:
        ser_phase = None
        for c in phase_cols:
            ser_phase = df[c] if ser_phase is None else ser_phase.combine_first(df[c])
        df["fase"] = ser_phase
        print(f"Using column(s) for 'fase': {phase_cols}")
    else:
        warnings.warn("Column 'fase' not found in the CSV. It will not be included in the output.")
    # --- end 'fase' ---

    # Optional sampling
    if args.sample is not None and args.sample < len(df):
        df = df.sample(n=args.sample, random_state=args.seed).reset_index(drop=True)

    # Split indices
    print("Splitting train/val/test...")
    idx_all = np.arange(len(df))
    idx_train, idx_tmp = train_test_split(idx_all, test_size=0.30, random_state=args.seed, shuffle=True)
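    # The held-out 30% is split evenly into val and test, giving 70/15/15 overall.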
    idx_val, idx_test = train_test_split(idx_tmp, test_size=0.50, random_state=args.seed, shuffle=True)

    df_train = df.iloc[idx_train].reset_index(drop=True)
    df_val = df.iloc[idx_val].reset_index(drop=True)
    df_test = df.iloc[idx_test].reset_index(drop=True)

    # Embeddings in one pass
    print("Building embedding model...")
    preprocess_fn = make_preprocess(args.backbone)
    model = make_backbone_model(args.img_size, args.backbone)

    print("Computing embeddings (one pass)...")
    ds_all = build_dataset(df["path"].tolist(), args.img_size, preprocess_fn, args.batch_size)
    emb_all = compute_embeddings(model, ds_all)

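    # Slice the single embedding matrix by split indices instead of re-encoding each split.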
    emb_train = emb_all[idx_train]
    emb_val = emb_all[idx_val]
    emb_test = emb_all[idx_test]

    # PCA reduction
    print("Fitting PCA reduction (50D for clustering, 2D for plots)...")
    scaler, pca50, train_50 = fit_reduction(emb_train, n_pca=50)
    val_50 = transform_reduction(emb_val, scaler, pca50)
    test_50 = transform_reduction(emb_test, scaler, pca50)

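    # Separate 2D PCA, fit on train only, used purely for the scatter plots below.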
    pca2 = PCA(n_components=2).fit(scaler.transform(emb_train))
    train_2d = pca2.transform(scaler.transform(emb_train))
    val_2d = pca2.transform(scaler.transform(emb_val))
    test_2d = pca2.transform(scaler.transform(emb_test))

    # Clustering
    print(f"Clustering with {args.cluster}...")
    model_c, y_train, centers = fit_cluster_algo(
        args.cluster, args.n_clusters, train_50,
        fast_kmeans=args.fast_kmeans,
        dbscan_eps=args.dbscan_eps,
        dbscan_min_samples=args.dbscan_min_samples,
        dbscan_metric=args.dbscan_metric,
        dbscan_auto=args.dbscan_auto,
    )

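    # KMeans can score unseen points directly; DBSCAN/agglomerative cannot,
    # so val/test samples fall back to nearest-centroid assignment.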
    if args.cluster == "kmeans":
        y_val = model_c.predict(val_50)
        y_test = model_c.predict(test_50)
    else:
        y_val = assign_to_nearest_centroid(val_50, centers)
        y_test = assign_to_nearest_centroid(test_50, centers)

    # Metrics
    print("Computing internal metrics...")
    train_m = internal_metrics(train_50, y_train)
    val_m = internal_metrics(val_50, y_val)
    test_m = internal_metrics(test_50, y_test)

    # Save outputs (filename, fase, cluster, split)
    print("Saving outputs...")
    ensure_dir(args.out_dir)

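    # Keep only the columns that exist ('fase' may be absent), then tag cluster and split.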
    def pick_min(df_split: pd.DataFrame, y: np.ndarray, split: str) -> pd.DataFrame:
        cols = ["filename", "fase"]
        keep = [c for c in cols if c in df_split.columns]
        out = df_split[keep].copy()
        out["cluster"] = y
        out["split"] = split
        return out

    train_out = pick_min(df_train, y_train, "train")
    val_out = pick_min(df_val, y_val, "val")
    test_out = pick_min(df_test, y_test, "test")

    assignments = pd.concat([train_out, val_out, test_out], ignore_index=True)
    assignments.to_csv(os.path.join(args.out_dir, "assignments.csv"), index=False, encoding="utf-8")
    train_out.to_csv(os.path.join(args.out_dir, "train_assignments.csv"), index=False, encoding="utf-8")
    val_out.to_csv(os.path.join(args.out_dir, "val_assignments.csv"), index=False, encoding="utf-8")
    test_out.to_csv(os.path.join(args.out_dir, "test_assignments.csv"), index=False, encoding="utf-8")

    # Save models
    joblib.dump(scaler, os.path.join(args.out_dir, "scaler.joblib"))
    joblib.dump(pca50, os.path.join(args.out_dir, "pca50.joblib"))
    joblib.dump(pca2, os.path.join(args.out_dir, "pca2.joblib"))
    joblib.dump(model_c, os.path.join(args.out_dir, f"{args.cluster}.joblib"))

    # Plots
    plot_scatter_2d(train_2d, y_train, f"Train clusters ({args.cluster})", os.path.join(args.out_dir, "train_clusters_2d.png"))
    plot_scatter_2d(val_2d, y_val, f"Val clusters ({args.cluster})", os.path.join(args.out_dir, "val_clusters_2d.png"))
    plot_scatter_2d(test_2d, y_test, f"Test clusters ({args.cluster})", os.path.join(args.out_dir, "test_clusters_2d.png"))

    # Summary
    summary = {
        "counts": {"train": len(df_train), "val": len(df_val), "test": len(df_test)},
        "cluster": args.cluster, "n_clusters": args.n_clusters,
        "backbone": args.backbone, "img_size": args.img_size,
        "internal_metrics": {"train": train_m, "val": val_m, "test": test_m},
        "csv": os.path.join(args.out_dir, "assignments.csv"),
    }
    with open(os.path.join(args.out_dir, "summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    # Optional: save features
    np.save(os.path.join(args.out_dir, "features.npy"), emb_all)
    np.save(os.path.join(args.out_dir, "feature_paths.npy"), df["path"].to_numpy())

    print("Done. Results saved to:", args.out_dir)

if __name__ == "__main__":
    main()
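# Example invocations (the script name is illustrative; the flags are those defined in parse_args):
#   python carciofo_clustering.py --cluster kmeans --n_clusters 5 --backbone mobilenet
#   python carciofo_clustering.py --cluster dbscan --dbscan_auto --dbscan_metric cosine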
@@ -21,7 +21,7 @@ from sklearn.manifold import TSNE
 import umap
 
 # ========== CONFIG ==========
-ASSIGNMENTS_CSV = r"C:\Users\sof12\Desktop\ML\Datasets\Nocciola_GBIF\TrainingV7\assignments.csv"
+ASSIGNMENTS_CSV = r"C:\Users\sof12\Desktop\ML\Datasets\Carciofo_GBIF\TrainingV2\assignments.csv"
 OUT_DIR = os.path.dirname(ASSIGNMENTS_CSV)  # where the pipeline saved the joblibs / features
 METHOD = "umap"  # 'umap' or 'tsne'
 RANDOM_STATE = 42
@@ -30,7 +30,7 @@ UMAP_MIN_DIST = 0.1
 TSNE_PERPLEXITY = 30
 TSNE_ITER = 1000
 SAVE_PLOT = True
-PLOT_BY = ["cluster", "fase V", "fase R"]  # assignments.csv columns to color by (use the ones you have)
+PLOT_BY = ["cluster", "fase"]  # assignments.csv columns to color by (use the ones you have)
 # ============================
 
 def find_file(patterns, folder):