875 lines
33 KiB
Python
875 lines
33 KiB
Python
"""
ResNet50 Transfer Learning for Phenological Phase Classification - Hazelnut/Artichoke
This code was originally developed in Google Colab and has been adapted for Visual Studio Code.
Dataset Path: C:/Users/sof12/Desktop/ML/Datasets/Nocciola/GBIF
Objective: Predict phenological phase R (reproductive)
"""
|
|
|
|
#----------------- IMPORTS -----------------
|
|
import os
|
|
import shutil
|
|
import random
|
|
import argparse
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import tensorflow as tf
|
|
from tensorflow.keras import layers, models
|
|
from tensorflow.keras.applications import ResNet50
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
from sklearn.metrics import classification_report, confusion_matrix
|
|
from sklearn.utils import class_weight
|
|
import json
|
|
|
|
# ----------------- CONFIG -----------------
|
|
# Root folder of the project. With an empty string every path below is
# resolved relative to the current working directory.
PROJECT_PATH = ""

# The images sit directly in the project root.
IMAGES_DIR = PROJECT_PATH

# Main CSV with one row per image.
CSV_PATH = os.path.join(PROJECT_PATH, 'tags.csv')

# Folder that receives the data split, the models and the reports.
OUTPUT_DIR = os.path.join(PROJECT_PATH, 'results_resnet50_fase_V_Combi(AV)')

# Input resolution recommended for ResNet50.
IMG_SIZE = (224, 224)

# Standard batch size.
BATCH_SIZE = 16

# Seed for reproducibility.
SEED = 42

# Fraction of every class assigned to each subset.
SPLIT = dict(train=0.7, val=0.15, test=0.15)

# Whether to force re-creation of the data split.
# NOTE(review): the pipeline code in this file reads the --force_split CLI
# flag and never this constant.
FORCE_SPLIT = True
|
|
|
|
# ----------------- Utilities -----------------
|
|
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value.

    Args:
        seed: Integer passed unchanged to each of the three seeding functions.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
|
|
|
|
def analyze_class_distribution(df, column_name):
    """Print the class distribution of a column and flag tiny classes.

    Args:
        df: DataFrame holding the annotations.
        column_name: Name of the column whose values are the class labels.

    Returns:
        list: Labels of every class that has at least ``min_samples`` rows,
        in the order produced by ``value_counts`` (most frequent first).
        The list is empty when no class reaches the threshold.
    """
    print(f"\n Analyzing class distribution for column: '{column_name}'")

    min_samples = 2  # Recommended minimum threshold

    # value_counts() ignores missing labels, while ``total`` counts every row,
    # so the percentages below are relative to the whole table.
    counts = df[column_name].value_counts()
    total = len(df)
    real_split = []

    print(f" Total samples: {total}")
    print(f" Number of classes: {len(counts)}")
    print(f" Class distribution:")

    # Show detailed statistics and collect the classes that are large enough.
    for label, count in counts.items():
        percentage = (count / total) * 100
        if count >= min_samples:
            # Collect every qualifying label; plain assignment here used to
            # keep only the last one.
            real_split.append(label)
        print(f" - {label}: {count} samples ({percentage:.1f}%)")

    # Detect problematic classes
    small_classes = counts[counts < min_samples]
    if len(small_classes) > 0:
        print(f"\n Classes with less than {min_samples} samples:")
        for label, count in small_classes.items():
            print(f" - {label}: {count} samples")

        print(f"\n Recommendations:")
        print(f" 1. Consider collecting more data for these classes")
        print(f" 2. Or merge similar classes")
        print(f" 3. Or use specific data augmentation techniques")

    return real_split
|
|
|
|
def safe_read_csv(path):
    """Read a CSV file, falling back to Latin-1 when it is not valid UTF-8.

    Args:
        path: Location of the CSV file.

    Returns:
        pandas.DataFrame: The parsed table.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f'CSV not found: {path}')
    try:
        return pd.read_csv(path, encoding='utf-8')
    except UnicodeDecodeError:
        # 'latin-1' and 'iso-8859-1' name the same codec, so a second retry
        # with the other alias could never succeed where this one fails.
        # Any error raised here (e.g. a malformed file) is a real problem and
        # is allowed to propagate instead of being swallowed.
        return pd.read_csv(path, encoding='latin-1')
|
|
|
|
def resolve_image_path(images_dir, img_id):
    """Locate the file that belongs to an image identifier.

    The identifier is first tried verbatim inside ``images_dir``; if that
    path does not exist, its extension is replaced by each of the common
    image extensions in turn.

    Args:
        images_dir: Folder that contains the images.
        img_id: Identifier from the CSV (file name, with or without extension).

    Returns:
        The first existing path, or ``None`` when the identifier is missing,
        blank, or no candidate exists.
    """
    if pd.isna(img_id):
        return None

    name = str(img_id).strip()
    if name == '':
        return None

    # Candidate file names, in lookup order: the id itself, then the stem
    # combined with every known extension.
    base = os.path.splitext(name)[0]
    candidates = [name]
    candidates.extend(
        base + suffix
        for suffix in ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
    )

    for candidate in candidates:
        full_path = os.path.join(images_dir, candidate)
        if os.path.exists(full_path):
            return full_path

    return None
|
|
|
|
def _split_sizes(n_class, split):
    """Return ``(n_train, n_val, n_test)`` for a class with ``n_class`` rows.

    Train and validation sizes are truncated products; the test set receives
    whatever remains, so the three sizes always add up to ``n_class``.
    """
    n_train = int(n_class * split['train'])
    n_val = int(n_class * split['val'])
    return n_train, n_val, n_class - n_train - n_val


def _shuffled(frame, seed):
    """Return ``frame`` in a seeded random order with a fresh 0..n-1 index."""
    return frame.sample(frac=1, random_state=seed).reset_index(drop=True)


def _split_class_standard(df_class, split, seed):
    """Shuffle one class and cut it into consecutive train/val/test slices."""
    rows = _shuffled(df_class, seed)
    n_train, n_val, _ = _split_sizes(len(rows), split)
    return (
        rows.iloc[:n_train],
        rows.iloc[n_train:n_train + n_val],
        rows.iloc[n_train + n_val:],
    )


def _split_class_av_priority(df_class, split, seed):
    """Split one class so that 'NocciolaAV' images fill the test set first.

    Rows whose ``id_img`` contains 'NocciolaAV' (case-insensitive) are used
    for the test set. If there are more of them than the test set needs, the
    surplus joins the other rows; if there are fewer, the test set is topped
    up from the other rows. Everything not used for testing is then divided
    into train (first ``n_train`` rows) and validation (the rest).
    """
    n_train, _, n_test = _split_sizes(len(df_class), split)

    is_av = df_class['id_img'].astype(str).str.contains('NocciolaAV', case=False, na=False)
    av_rows = _shuffled(df_class[is_av], seed)
    other_rows = _shuffled(df_class[~is_av], seed)

    if len(av_rows) > n_test:
        test_rows = av_rows.iloc[:n_test]
        # Re-shuffle so the surplus AV rows are not all placed in train.
        pool = _shuffled(
            pd.concat([av_rows.iloc[n_test:], other_rows], ignore_index=True),
            seed,
        )
    elif len(av_rows) == n_test:
        test_rows = av_rows
        pool = other_rows
    else:
        needed = n_test - len(av_rows)
        test_rows = _shuffled(
            pd.concat([av_rows, other_rows.iloc[:needed]], ignore_index=True),
            seed,
        )
        # Rows already moved to the test set must not be reused.
        pool = other_rows.iloc[needed:]

    # ``pool`` always holds exactly n_train + n_val rows at this point.
    return pool.iloc[:n_train], pool.iloc[n_train:], test_rows


def _copy_subset(subdf, subset_name, column_name, images_dir, out_dir):
    """Copy the images of one subset into ``out_dir/<subset>/<class>/``.

    A report listing unresolved ids and failed copies is written to
    ``missing_<subset>.txt`` in the current working directory.

    Returns:
        int: Number of images copied successfully.
    """
    copied, missing = 0, 0
    failed = []
    miss = []

    for img_id, phase in zip(subdf['id_img'], subdf[column_name]):
        src = resolve_image_path(images_dir, img_id)
        if not src:
            missing += 1
            miss.append(img_id)
            continue

        # Create the destination folder if it does not exist yet.
        dest_dir = os.path.join(out_dir, subset_name, str(phase))
        os.makedirs(dest_dir, exist_ok=True)

        # Keep the original file name (it carries the correct extension).
        original_filename = os.path.basename(src)
        dst = os.path.join(dest_dir, original_filename)

        # Guard against destination paths longer than 260 characters by
        # switching to a short generated name.
        if len(dst) > 260:
            print(f" Ruta demasiado larga ({len(dst)} caracteres): {dst}")
            ext = os.path.splitext(original_filename)[1]
            short_name = f"img_{copied:04d}{ext}"
            dst = os.path.join(dest_dir, short_name)
            print(f" Usando nombre corto: {short_name}")

        try:
            shutil.copy2(src, dst)
        except OSError as e:
            print(f" Error copying {src} to {dst}: {e}")
            print(f" - Source exists: {os.path.exists(src)}")
            print(f" - Dest dir exists: {os.path.exists(dest_dir)}")
            print(f" - Source path length: {len(src)}")
            print(f" - Dest path length: {len(dst)}")
            failed.append(src)
            missing += 1
        else:
            copied += 1

    miss_file_path = os.path.join(os.getcwd(), f'missing_{subset_name}.txt')
    # Explicit encoding so ids with non-ASCII characters can always be written.
    with open(miss_file_path, 'w', encoding='utf-8') as f:
        f.write(f"Missing images in {subset_name}:\n")
        for item in miss:
            f.write(f"{item}\n")
        f.write(f"\n\n Failed to copy:\n")
        for item in failed:
            f.write(f"{item}\n")

    print(f" {subset_name}: {copied} images copied, {missing} failed")
    return copied


def prepare_image_folders(df, images_dir, out_dir, split=SPLIT, seed=SEED,
                          column_name='fase R', test_AV=False, class_split=None):
    """Create the train/val/test folder structure for ``flow_from_directory``.

    Rows without a label or without an image on disk are discarded, classes
    with fewer than three images are dropped, every remaining class is split
    on its own (stratified split) and the images are copied to
    ``out_dir/<train|val|test>/<class>/``.

    Args:
        df: Annotation table with an ``id_img`` column and ``column_name``.
        images_dir: Folder that contains the source images.
        out_dir: Root folder of the split to create.
        split: Mapping with the ``'train'`` and ``'val'`` fractions; the test
            subset receives the remainder of each class.
        seed: Seed used for every shuffle.
        column_name: Column holding the class label.
        test_AV: When true, images whose id contains 'NocciolaAV' are
            assigned to the test subset first.
        class_split: Not used by this function; accepted so existing callers
            keep working.

    Returns:
        tuple: ``(train_df, val_df, test_df)`` with the rows of each subset.

    Raises:
        ValueError: If no image can be found on disk, or if no class is left
            after removing the classes with too few images.
    """
    set_seed(seed)

    print(f"Initial data: {len(df)} rows")

    # Keep rows with a non-blank label. ``astype(str)`` lets this work for
    # numeric label columns too, where the ``.str`` accessor would fail.
    df_valid = df.dropna(subset=[column_name]).copy()
    df_valid = df_valid[df_valid[column_name].astype(str).str.strip() != '']
    print(f"With valid phase: {len(df_valid)} rows")

    # Keep rows whose image exists on disk.
    resolved = df_valid['id_img'].map(
        lambda img_id: resolve_image_path(images_dir, img_id)
    )
    for img_id in df_valid.loc[resolved.isna(), 'id_img']:
        print(f"Image not found: {img_id}")

    df_final = df_valid[resolved.notna()].copy()
    if df_final.empty:
        raise ValueError("No valid images found")
    print(f"With existing images: {len(df_final)} rows")

    # Show class distribution
    fase_counts = df_final[column_name].value_counts()
    print(f"\n Distribution of phase:")
    for fase, count in fase_counts.items():
        print(f" - {fase}: {count} images")

    # Remove classes with very few samples (less than 3)
    min_samples = 3
    valid_phases = fase_counts[fase_counts >= min_samples].index.tolist()
    if len(valid_phases) < len(fase_counts):
        excluded = fase_counts[fase_counts < min_samples].index.tolist()
        print(f"Excluded phases with less than {min_samples} samples: {excluded}")
        df_final = df_final[df_final[column_name].isin(valid_phases)]
        print(f"After filtering: {len(df_final)} rows, {len(valid_phases)} classes")

    labels = df_final[column_name].unique().tolist()
    print(f"Final classes: {labels}")
    if not labels:
        # Without this check pd.concat() below fails with an opaque message.
        raise ValueError(
            f"No class has at least {min_samples} images; cannot create a split"
        )

    # Pick the per-class splitting strategy.
    if test_AV:
        print("\n=== Using test_AV mode: NocciolaAV images prioritized for test ===")
        split_one_class = _split_class_av_priority
    else:
        print("\n=== Standard stratified split ===")
        split_one_class = _split_class_standard

    # Split each class separately to maintain proportions.
    train_dfs, val_dfs, test_dfs = [], [], []
    for label in labels:
        df_class = df_final[df_final[column_name] == label]
        n_class = len(df_class)
        train_class, val_class, test_class = split_one_class(df_class, split, seed)

        train_dfs.append(train_class)
        val_dfs.append(val_class)
        test_dfs.append(test_class)

        print(f" Class '{label}': {n_class} total -> Train: {len(train_class)}, Val: {len(val_class)}, Test: {len(test_class)}")

    # Combine all classes for each split
    train_df = _shuffled(pd.concat(train_dfs, ignore_index=True), seed)
    val_df = _shuffled(pd.concat(val_dfs, ignore_index=True), seed)
    test_df = _shuffled(pd.concat(test_dfs, ignore_index=True), seed)

    print(f"\nFinal split (stratified):")
    print(f" - Training: {len(train_df)} images")
    print(f" - Validation: {len(val_df)} images")
    print(f" - Test: {len(test_df)} images")

    # Verify class distribution in each split
    print(f"\nClass distribution verification:")
    for subset_name, subset_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
        dist = subset_df[column_name].value_counts().sort_index()
        print(f" {subset_name}: {dict(dist)}")

    # Create the folder of every class in every subset, even when a subset
    # ends up with no image for that class.
    for part in ['train', 'val', 'test']:
        for label in labels:
            os.makedirs(os.path.join(out_dir, part, str(label)), exist_ok=True)

    # Copy images to the corresponding folders
    _copy_subset(train_df, 'train', column_name, images_dir, out_dir)
    _copy_subset(val_df, 'val', column_name, images_dir, out_dir)
    _copy_subset(test_df, 'test', column_name, images_dir, out_dir)

    return train_df, val_df, test_df
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser of the pipeline."""
    parser = argparse.ArgumentParser(description='ResNet50 Transfer Learning for Nocciola')
    parser.add_argument('--csv_path', type=str, default=CSV_PATH,
                        help='Path to the CSV file with image assignments')
    parser.add_argument('--images_dir', type=str, default=IMAGES_DIR,
                        help='Directory with the images')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR,
                        help='Output directory for results')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of training epochs')
    parser.add_argument('--force_split', action='store_true',
                        help='Force recreation of the data split')
    parser.add_argument('--phase', type=str, default='fase R',
                        help='Phase of the analysis (V or R)')
    parser.add_argument('--test', action='store_true', default=True,
                        help='Run model using AV dataset for testing purposes')
    # --test defaults to True and is a store_true flag, so on its own it can
    # never be switched off; --no_test is the way to disable it.
    parser.add_argument('--no_test', dest='test', action='store_false',
                        help='Use the standard stratified split instead of the AV-prioritized test set')
    return parser


def _configure_gpu(memory_limit_mb=4096):
    """Cap the memory of the first GPU; do nothing when no GPU is present."""
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        # Indexing gpus[0] on a CPU-only machine would raise IndexError.
        print("No GPU detected, running on CPU")
        return
    try:
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_mb)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)


def _create_and_save_split(df, args, split_dir, real_split):
    """Build the folder split and store the three split tables as CSV."""
    subsets = prepare_image_folders(
        df, args.images_dir, split_dir,
        column_name=args.phase, test_AV=args.test, class_split=real_split)
    for subset_name, subset_df in zip(('train', 'val', 'test'), subsets):
        subset_df.to_csv(
            os.path.join(args.output_dir, f'{subset_name}_split.csv'), index=False)


def _build_generators(split_dir):
    """Return the train, validation and test directory iterators."""
    # The Keras documentation states that each application expects its own
    # input preprocessing; for ResNet50 that is resnet50.preprocess_input,
    # not a plain 1/255 rescale.
    # NOTE(review): any separate inference code that feeds the saved model
    # must apply the same preprocessing - confirm.
    preprocess = tf.keras.applications.resnet50.preprocess_input

    # Data augmentation for training
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest'
    )
    # Only preprocessing for validation and test
    val_test_datagen = ImageDataGenerator(preprocessing_function=preprocess)

    train_gen = train_datagen.flow_from_directory(
        os.path.join(split_dir, 'train'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        seed=SEED
    )
    # shuffle=False keeps predictions aligned with ``.classes``/``.filepaths``.
    val_gen = val_test_datagen.flow_from_directory(
        os.path.join(split_dir, 'val'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )
    test_gen = val_test_datagen.flow_from_directory(
        os.path.join(split_dir, 'test'),
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=False
    )
    return train_gen, val_gen, test_gen


def _build_model(num_classes):
    """Return ``(model, base_model)``: a frozen ResNet50 plus a dense head."""
    base_model = ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=(*IMG_SIZE, 3)
    )
    base_model.trainable = False  # Freeze base model

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model, base_model


def _compute_class_weights(train_gen):
    """Return balanced class weights keyed by class index.

    The labels come from ``train_gen.classes``, so no image has to be loaded
    or augmented just to count the classes.
    """
    train_labels = np.asarray(train_gen.classes)
    present = np.unique(train_labels)
    weights = class_weight.compute_class_weight(
        'balanced',
        classes=present,
        y=train_labels
    )
    return {int(idx): float(weight) for idx, weight in zip(present, weights)}


def main():
    """Run the ResNet50 transfer-learning pipeline end to end.

    Steps: parse the command line, load the CSV, create or reuse the
    train/val/test folder split, train a classification head on a frozen
    ResNet50, fine-tune the last layers of the base network, evaluate on the
    test subset and write the models, reports and plots to the output
    directory.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If the CSV lacks the ``id_img`` or phase column.
    """
    args = _build_arg_parser().parse_args()

    print('\n === Start of the pipeline ===')
    print(f"Images directory: {args.images_dir}")
    print(f"CSV file: {args.csv_path}")
    print(f"Output directory: {args.output_dir}")

    _configure_gpu()

    # Set seed
    set_seed(SEED)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    print('\n === Loading data ===')
    df = safe_read_csv(args.csv_path)
    print(f'Total of records in CSV: {len(df)}')
    print(f'Available columns: {list(df.columns)}')

    # Check required columns
    required_cols = {'id_img', args.phase}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f'CSV must contain the columns: {missing}')

    # Analyze class distribution before processing
    real_split = analyze_class_distribution(df, args.phase)

    # Prepare folder structure
    split_dir = os.path.join(args.output_dir, 'data_split')

    if args.force_split and os.path.exists(split_dir):
        print("Eliminating existing split")
        shutil.rmtree(split_dir)

    if not os.path.exists(split_dir):
        print("\n=== Creating new data split ===")
        _create_and_save_split(df, args, split_dir, real_split)
    else:
        print("\n === Reuse existing split ===")
        # The split tables are only read to confirm that they are present and
        # parseable; the training below works from the folders.
        try:
            for subset_name in ('train', 'val', 'test'):
                pd.read_csv(os.path.join(args.output_dir, f'{subset_name}_split.csv'))
        except (OSError, ValueError):
            # Missing file (OSError) or empty/malformed CSV (pandas parser
            # errors derive from ValueError).
            print("Could not load split files, recreating")
            _create_and_save_split(df, args, split_dir, real_split)

    # Create data generators
    print("\n=== Creating data generators ===")
    train_gen, val_gen, test_gen = _build_generators(split_dir)

    # Save class mapping
    class_indices = train_gen.class_indices
    print(f'Class mapping: {class_indices}')

    with open(os.path.join(args.output_dir, 'class_indices.json'), 'w', encoding='utf-8') as f:
        json.dump(class_indices, f, indent=2)

    print(f"Samples per class:")
    print(f" - Training: {train_gen.samples}")
    print(f" - Validation: {val_gen.samples}")
    print(f" - Test: {test_gen.samples}")
    print(f" - Number of classes: {train_gen.num_classes}")

    # Create and train model
    print("\n === Model construction ===")
    model, base_model = _build_model(train_gen.num_classes)

    print(" Summary of the model:")
    model.summary()

    # Calculate class weights (best effort: training continues without them)
    print("\n === Calculating class weights ===")
    try:
        class_weight_dict = _compute_class_weights(train_gen)
        print(f"Class weights: {class_weight_dict}")
    except Exception as e:
        print(f"Error calculating class weights: {e}")
        class_weight_dict = None

    # Callbacks for training
    best_model_path = os.path.join(args.output_dir, 'best_model.keras')
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=7,
        restore_best_weights=True,
        verbose=1
    )
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        best_model_path,
        save_best_only=True,
        monitor='val_loss',
        verbose=1
    )
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-7,
        verbose=1
    )
    callbacks = [early_stopping, model_checkpoint, reduce_lr]

    # Initial training
    print(f"\n === Initial training ({args.epochs} epochs) ===")

    # Class weights actually in use; reset to None if training with them fails
    # so the fine-tuning phase stays consistent with the first phase.
    active_class_weight = class_weight_dict
    try:
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            class_weight=active_class_weight,
            verbose=1
        )
        print("Initial training completed successfully")
    except Exception as e:
        print(f"Error during training: {e}")
        # Train without class_weight if there are issues
        print("Trying training without class weights")
        active_class_weight = None
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=args.epochs,
            callbacks=callbacks,
            verbose=1
        )

    # Fine-tuning
    print("\n === Fine-tuning ===")

    # Unfreeze some layers of the base model
    base_model.trainable = True
    fine_tune_at = 143  # Unfreeze after conv5_block1 (ResNet50 has 175 layers)
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False

    # Recompile with a lower learning rate
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Continue training where the first phase stopped
    fine_tune_epochs = 10
    epochs_done = len(history.history['loss'])
    total_epochs = epochs_done + fine_tune_epochs

    try:
        history_fine = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=total_epochs,
            initial_epoch=epochs_done,
            callbacks=callbacks,
            class_weight=active_class_weight,
            verbose=1
        )
        print("Fine-tuning completed successfully")

        # Combine histories
        for key in history.history:
            if key in history_fine.history:
                history.history[key].extend(history_fine.history[key])
    except Exception as e:
        print(f"Error during fine-tuning: {e}")
        print("Continuing with initial training model")

    # Final evaluation
    print("\n === Evaluation on test set ===")

    # Load best model
    try:
        model.load_weights(best_model_path)
        print(" Loaded best saved model")
    except Exception as e:
        print(f" Could not load best checkpoint: {e}")
        print(" Using current model")

    # Save final model
    model.save(os.path.join(args.output_dir, 'final_model.keras'))
    print("Saved final model")

    # Predictions on test set
    test_gen.reset()
    y_pred_prob = model.predict(test_gen, verbose=1)
    y_pred = np.argmax(y_pred_prob, axis=1)
    y_true = test_gen.classes

    # Map indices to class names
    index_to_class = {v: k for k, v in class_indices.items()}

    # Classes that occur in the test labels or in the predictions
    unique_test_classes = np.unique(np.concatenate([y_true, y_pred]))
    test_class_names = [index_to_class[i] for i in unique_test_classes]

    print(f"Classes in test set: {len(unique_test_classes)}")
    print(f"All trained classes: {len(class_indices)}")
    print(f"Classes present in test: {test_class_names}")

    # Check for missing classes
    missing_classes = set(range(len(class_indices))) - set(unique_test_classes)
    missing_names = [index_to_class[i] for i in sorted(missing_classes)]
    if missing_classes:
        print(f"Classes without samples in test: {missing_names}")
    missing_summary = missing_names if missing_classes else 'None'

    report_path = os.path.join(args.output_dir, 'classification_report.txt')

    # Classification report with filtered classes
    print("\n === Classification Report ===")
    try:
        report = classification_report(
            y_true, y_pred,
            labels=unique_test_classes,  # Specify exact classes
            target_names=test_class_names,
            output_dict=False,
            zero_division=0  # Handle division by zero
        )
        print(report)

        # Save report
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing in test: {missing_summary}\n\n")
            f.write(report)
    except Exception as e:
        print(f"Error in classification_report: {e}")
        print("Generating alternative report")

        # Manual report if automatic fails
        from collections import Counter
        true_counts = Counter(y_true)
        pred_counts = Counter(y_pred)

        table_rows = []
        for class_idx in unique_test_classes:
            class_name = index_to_class[class_idx]
            true_count = true_counts.get(class_idx, 0)
            pred_count = pred_counts.get(class_idx, 0)
            table_rows.append(f"{class_name[:15]:15} | {true_count:10} | {pred_count:9}")

        print("\n Manual distribution:")
        print("Class | True | Predicted")
        print("-" * 35)
        for table_row in table_rows:
            print(table_row)

        # Calculate basic accuracy
        accuracy = np.mean(y_true == y_pred)
        print(f"\nOverall accuracy: {accuracy:.4f}")

        # Save manual report
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("MANUAL CLASSIFICATION REPORT\n")
            f.write("=" * 40 + "\n\n")
            f.write(f"Classes evaluated: {test_class_names}\n")
            f.write(f"Classes missing in test: {missing_summary}\n\n")
            f.write("Class distribution:\n")
            f.write("Class | True | Predicted\n")
            f.write("-" * 35 + "\n")
            for table_row in table_rows:
                f.write(table_row + "\n")
            f.write(f"\nOverall accuracy: {accuracy:.4f}\n")

    # Confusion matrix with filtered classes
    cm = confusion_matrix(y_true, y_pred, labels=unique_test_classes)
    print(f"\n Confusion Matrix ({len(unique_test_classes)} classes):")
    print(cm)

    np.savetxt(os.path.join(args.output_dir, 'confusion_matrix.csv'),
               cm, delimiter=',', fmt='%d')

    # Visualizations with filtered classes
    print("\n === Generating visualizations ===")

    # Training plot
    plot_training_history(history, args.output_dir)

    # Confusion matrix visualization with filtered classes
    plot_confusion_matrix(cm, test_class_names, args.output_dir)

    # Prediction examples with filtered classes
    plot_prediction_examples(test_gen, y_true, y_pred, test_class_names, args.output_dir, unique_test_classes)

    print(f"\n=== Pipeline completed ===")
    print(f" Results saved in: {args.output_dir}")
    print(f" Final test accuracy: {np.mean(y_true == y_pred):.4f}")
    print(f" Classes evaluated: {len(unique_test_classes)}/{len(class_indices)}")

    # Additional information about imbalanced classes
    if missing_classes:
        print(f"\n === Information about Imbalanced Classes ===")
        print(f" Classes without samples in test: {len(missing_classes)}")
        for missing_idx in sorted(missing_classes):
            missing_name = index_to_class[missing_idx]
            print(f" - {missing_name} (index {missing_idx})")
        print(f" Suggestion: Consider increasing the dataset or merging similar classes")
|
|
|
|
def plot_training_history(history, output_dir):
    """Save side-by-side accuracy and loss curves as 'training_history.png'.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (as returned by ``model.fit``).
        output_dir: Folder that receives the PNG file.

    Any error is reported on stdout instead of being raised.
    """
    # (training key, validation key, panel title, y-axis label)
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Model Loss', 'Loss'),
    )

    try:
        plt.figure(figsize=(12, 4))

        for position, (train_key, val_key, title, y_label) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history.history[train_key], label='Training')
            # The validation curve is optional.
            if val_key in history.history:
                plt.plot(history.history[val_key], label='Validation')
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        target = os.path.join(output_dir, 'training_history.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
        print("Saved training history plot")

    except Exception as e:
        print(f"Error creating training history plot: {e}")
|
|
|
|
def plot_confusion_matrix(cm, class_names, output_dir):
    """Save an annotated heatmap of ``cm`` as 'confusion_matrix.png'.

    Args:
        cm: Confusion matrix of integer counts.
        class_names: Tick labels used on both axes.
        output_dir: Folder that receives the PNG file.

    Any error is reported on stdout instead of being raised.
    """
    try:
        plt.figure(figsize=(10, 8))

        heatmap_options = dict(
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
        )
        sns.heatmap(cm, **heatmap_options)

        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        target = os.path.join(output_dir, 'confusion_matrix.png')
        plt.savefig(target, dpi=300, bbox_inches='tight')
        plt.close()
        print("Saved confusion matrix plot")

    except Exception as e:
        print(f"Error creating confusion matrix plot: {e}")
|
|
|
|
def plot_prediction_examples(test_gen, y_true, y_pred, class_names, output_dir, unique_classes=None, n_examples=12):
    """Save a grid of correctly and incorrectly classified test images.

    Up to ``n_examples // 2`` correct and ``n_examples // 2`` incorrect
    predictions are drawn at random and written to 'prediction_examples.png'.

    Args:
        test_gen: Object whose ``filepaths`` list is aligned with ``y_true``.
        y_true: True class index of every test image.
        y_pred: Predicted class index of every test image.
        class_names: Display names of the classes.
        output_dir: Folder that receives the PNG file.
        unique_classes: Class indices that ``class_names`` correspond to, in
            the same order. When omitted, ``class_names[i]`` is taken to be
            the name of class index ``i``.
        n_examples: Maximum number of images in the grid.

    Any error is reported on stdout instead of being raised.
    """
    try:
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # ``class_names`` may be a filtered list, so a class index is not
        # necessarily a valid position in it; translate through a mapping.
        if unique_classes is not None:
            name_of = {int(idx): name for idx, name in zip(unique_classes, class_names)}
        else:
            name_of = dict(enumerate(class_names))

        # Get indices of correct and incorrect predictions
        correct_idx = np.flatnonzero(y_true == y_pred)
        incorrect_idx = np.flatnonzero(y_true != y_pred)

        def pick(indices, count):
            # An empty *integer* array keeps the concatenation below integer
            # typed; a plain [] would turn every index into a float.
            if count == 0:
                return np.empty(0, dtype=int)
            return np.random.choice(indices, count, replace=False)

        # Select examples
        selected_correct = pick(correct_idx, min(n_examples // 2, len(correct_idx)))
        selected_incorrect = pick(incorrect_idx, min(n_examples // 2, len(incorrect_idx)))
        selected_indices = np.concatenate([selected_correct, selected_incorrect]).astype(int)

        if len(selected_indices) == 0:
            print("There are no examples to show.")
            return

        # Create plot
        cols = 4
        rows = (len(selected_indices) + cols - 1) // cols

        fig = plt.figure(figsize=(15, 4 * rows))
        try:
            for i, idx in enumerate(selected_indices):
                plt.subplot(rows, cols, i + 1)

                # Read the image from disk; the generator only exposes paths.
                img = plt.imread(test_gen.filepaths[idx])
                plt.imshow(img)
                plt.axis('off')

                true_class = int(y_true[idx])
                pred_class = int(y_pred[idx])
                # Fall back to the raw index if a class has no display name.
                true_label = name_of.get(true_class, str(true_class))
                pred_label = name_of.get(pred_class, str(pred_class))

                color = 'green' if true_class == pred_class else 'red'
                plt.title(f'True: {true_label}\nPredicted: {pred_label}',
                          color=color, fontsize=10)

            plt.suptitle('Prediction Examples (Green=Correct, Red=Incorrect)', fontsize=14)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'prediction_examples.png'), dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
        print("Saved prediction examples plot")

    except Exception as e:
        print(f"Error creating prediction examples plot: {e}")
|
|
|
|
# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
|