import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Lambda, Input, Dropout
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, Callback
from sklearn.model_selection import train_test_split
# --- 1. Configuration and hyperparameters ---
IMAGE_SIZE = (380, 380)
BATCH_SIZE = 32
EMBEDDING_DIM = 256
EPOCHS = 50
MARGIN = 0.3
DATASET_PATH = "train"
MODEL_WEIGHTS_PATH = 'initial_best_pet_embedding_model.h5'  # weights file saved by stage 1
PATIENCE = 5
# --- 2. Utility function: load, augment, and preprocess images ---
def load_and_preprocess_image(image_path, target_size, apply_augmentation=False):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    if apply_augmentation:
        # Resize 10% larger than the target, then random-crop back down, so the
        # crop jitters the framing without discarding too much of the subject.
        img = tf.image.resize(img, [int(target_size[0] * 1.1), int(target_size[1] * 1.1)])
        img = tf.image.random_crop(img, size=[target_size[0], target_size[1], 3])
        img = tf.image.random_flip_left_right(img)
        # Note: pixels are 0-255 floats at this point, so an absolute brightness
        # delta of 0.2 is very subtle; scale it (e.g. 0.2 * 255) if a stronger
        # jitter is intended. Contrast and saturation are multiplicative and fine.
        img = tf.image.random_brightness(img, max_delta=0.2)
        img = tf.image.random_contrast(img, lower=0.8, upper=1.2)
        img = tf.image.random_saturation(img, lower=0.8, upper=1.2)
    else:
        img = tf.image.resize(img, target_size)
    # For EfficientNet in tf.keras this preprocess_input is a pass-through:
    # the model expects raw [0, 255] pixels and rescales them internally.
    img = tf.keras.applications.efficientnet.preprocess_input(img)
    return img
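# A minimal preview sketch (not part of the training flow; 'sample.jpg' is a
# hypothetical path and matplotlib is an assumed extra dependency): draws a few
# augmented variants of one image to eyeball the augmentation pipeline.
def _preview_augmentations(image_path='sample.jpg', n=4):
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3))
    for ax in axes:
        img = load_and_preprocess_image(image_path, IMAGE_SIZE, apply_augmentation=True)
        # preprocess_input is a pass-through for EfficientNet, so pixels stay 0-255.
        ax.imshow(tf.cast(tf.clip_by_value(img, 0.0, 255.0), tf.uint8).numpy())
        ax.axis('off')
    plt.show()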
# --- 3. Custom model saver ---
class CustomModelSaver(Callback):
def __init__(self, filepath, monitor='val_loss', mode='min', verbose=1):
super().__init__()
self.filepath = filepath
self.monitor = monitor
self.mode = mode
self.verbose = verbose
        self.best_value = np.inf if mode == 'min' else -np.inf  # np.Inf was removed in NumPy 2.0
self.embedding_model_ref = None
def on_train_begin(self, logs=None):
for layer in self.model.layers:
if isinstance(layer, Model) and 'embedding_model' in layer.name:
self.embedding_model_ref = layer
break
if not self.embedding_model_ref:
raise ValueError("CustomModelSaver failed to find the embedding_model.")
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
current_value = logs.get(self.monitor)
if current_value is None:
if self.verbose > 0:
print(f'\nWarning: CustomModelSaver could not find metric {self.monitor}. Skipping save.')
return
is_best = (self.mode == 'min' and current_value < self.best_value) or \
(self.mode == 'max' and current_value > self.best_value)
if is_best:
if self.verbose > 0:
print(
f'\nEpoch {epoch + 1}: {self.monitor} improved from {self.best_value:.4f} to {current_value:.4f}, saving model weights to {self.filepath}')
self.best_value = current_value
self.embedding_model_ref.save_weights(self.filepath)
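# A hedged usage sketch (defined here but never called): CustomModelSaver only
# writes the inner embedding_model's weights, so reloading them means rebuilding
# an identical embedding model and calling load_weights on it.
def _load_best_embedding_model():
    model, _ = build_embedding_model(input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
                                     embedding_dim=EMBEDDING_DIM)
    model.load_weights(MODEL_WEIGHTS_PATH)
    return model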
# --- 4. Fixed triplet loss (with in-batch negatives) ---
def triplet_loss(margin=0.5):
    def loss_function(y_true, y_pred):
        # y_pred is [anchor embeddings; positive embeddings] stacked along the
        # batch axis by the Concatenate layer, so the first half are anchors.
        batch_size = tf.shape(y_pred)[0] // 2
        anchor_embeddings = y_pred[:batch_size]
        positive_embeddings = y_pred[batch_size:]
        positive_dist = tf.reduce_sum(tf.square(anchor_embeddings - positive_embeddings), axis=1)
        # In-batch negative mining: every other pair's positive serves as a
        # negative. Caveat: if two pairs from the same pet land in one batch,
        # that "negative" is actually a positive.
        negatives = positive_embeddings
        dist_matrix = tf.reduce_sum(tf.square(tf.expand_dims(anchor_embeddings, 1) - tf.expand_dims(negatives, 0)),
                                    axis=2)
        # Mask the diagonal (each anchor's own positive) with a large constant
        # so it can never be chosen as the hardest negative.
        identity_matrix = tf.eye(batch_size, dtype=tf.bool)
        masked_dist_matrix = tf.where(identity_matrix, 1e12, dist_matrix)
        hardest_negative_dist = tf.reduce_min(masked_dist_matrix, axis=1)
        loss = tf.maximum(0.0, positive_dist - hardest_negative_dist + margin)
        return tf.reduce_mean(loss)
    return loss_function
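# A quick self-check sketch for triplet_loss (a hypothetical helper, not called
# during training): on random L2-normalized embeddings the loss should be a
# finite scalar, and never negative thanks to the tf.maximum clamp.
def _sanity_check_triplet_loss(batch_size=8, dim=EMBEDDING_DIM, margin=MARGIN):
    rng = tf.random.Generator.from_seed(0)
    emb = tf.math.l2_normalize(rng.normal((2 * batch_size, dim)), axis=1)
    loss = triplet_loss(margin)(None, emb)  # y_true is ignored by the loss
    print(f"triplet loss on random embeddings: {loss.numpy():.4f}")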
# --- 5. Build the model (fixes the EfficientNetB4 internal layer-count issue) ---
def build_embedding_model(input_shape, embedding_dim):
base_model = EfficientNetB4(weights='imagenet', include_top=False, input_shape=input_shape)
x = base_model.output
x = GlobalAveragePooling2D(name='global_average_pooling_2d')(x)
x = Dropout(0.2, name='dropout')(x)
x = Dense(embedding_dim, name='dense')(x)
x = Lambda(lambda x: K.l2_normalize(x, axis=1), name='l2_normalize')(x)
embedding_model = Model(inputs=base_model.input, outputs=x, name='embedding_model')
return embedding_model, base_model
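# A small inference-time sketch ('some_pet.jpg' is a hypothetical path): the
# final L2-normalize layer puts every embedding on the unit hypersphere, so
# squared Euclidean distance and cosine distance produce the same ranking.
def _embed_one_image(model, image_path='some_pet.jpg'):
    img = load_and_preprocess_image(image_path, IMAGE_SIZE, apply_augmentation=False)
    emb = model(tf.expand_dims(img, 0), training=False)[0]
    print("embedding norm:", float(tf.norm(emb)))  # should print ~1.0
    return emb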
# --- 6. Data loader (unchanged) ---
def create_dataset_from_paths(pet_ids, pet_data, batch_size, apply_augmentation=True):
    print(f"Building dataset for {len(pet_ids)} pet IDs...")
    all_pairs = []
    for pet_id in pet_ids:
        pet_images = pet_data.get(pet_id, [])  # default to [] so a missing ID cannot crash len()
        if len(pet_images) >= 2:
            # Enumerate every unordered (anchor, positive) combination for this pet.
            pairs = [(pet_images[i], pet_images[j]) for i in range(len(pet_images)) for j in
                     range(i + 1, len(pet_images))]
            all_pairs.extend(pairs)
    print(f"Total positive pairs generated: {len(all_pairs)}")
    if not all_pairs:
        raise ValueError("all_pairs is empty; cannot create a dataset.")
    dataset = tf.data.Dataset.from_tensor_slices(all_pairs)
    def load_ap_pair(paths):
        anchor = load_and_preprocess_image(paths[0], IMAGE_SIZE, apply_augmentation)
        positive = load_and_preprocess_image(paths[1], IMAGE_SIZE, apply_augmentation)
        # The label is a dummy value; triplet_loss ignores y_true.
        return (anchor, positive), tf.constant([0.0])
    # Shuffle before map so the buffer holds cheap path strings, not decoded images.
    dataset = dataset.shuffle(buffer_size=len(all_pairs)).map(load_ap_pair, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
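# A one-batch inspection sketch (assumes the dataset was built successfully;
# not part of the training flow): each batch should be
# ((anchors, positives), dummy_labels) with images of shape (batch, 380, 380, 3).
def _inspect_one_batch(dataset):
    (anchors, positives), labels = next(iter(dataset))
    print("anchors:", anchors.shape, "positives:", positives.shape, "labels:", labels.shape)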
# --- 7. Main training flow (stage 1) ---
if __name__ == '__main__':
    pet_data = {}
    for pet_id in os.listdir(DATASET_PATH):
        pet_path = os.path.join(DATASET_PATH, pet_id)
        if os.path.isdir(pet_path):
            # Keep only common image extensions so stray files (e.g. .DS_Store)
            # never reach tf.image.decode_jpeg.
            pet_data[pet_id] = [os.path.join(pet_path, f) for f in os.listdir(pet_path)
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    # A positive pair needs two images, so drop pets with fewer than two.
    pet_data_cleaned = {pid: paths for pid, paths in pet_data.items() if len(paths) >= 2}
    pet_ids_cleaned = list(pet_data_cleaned.keys())
    if len(pet_ids_cleaned) < 2:
        raise ValueError("Not enough qualifying classes: need at least 2 pet IDs, each with at least 2 images.")
train_ids, val_ids = train_test_split(pet_ids_cleaned, test_size=0.2, random_state=42)
print(f"合格宠物ID数量 (>=2张图片): {len(pet_ids_cleaned)}")
print(f"训练集宠物ID数量: {len(train_ids)}")
print(f"验证集宠物ID数量: {len(val_ids)}")
train_dataset = create_dataset_from_paths(train_ids, pet_data_cleaned, BATCH_SIZE, apply_augmentation=True)
val_dataset = create_dataset_from_paths(val_ids, pet_data_cleaned, BATCH_SIZE, apply_augmentation=False)
    # --- Core: build the model instance only once ---
embedding_model, base_model = build_embedding_model(input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
embedding_dim=EMBEDDING_DIM)
    # Two input layers that feed the same (shared) embedding model
input_anchor = Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
input_positive = Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
    # --- Stage 1: train the embedding head (base model frozen) ---
    print("\n--- Stage 1: training the embedding head ---")
base_model.trainable = False
    # Ensure only the newly added head layers are trainable
for layer in embedding_model.layers:
if layer.name in ['dropout', 'dense', 'l2_normalize']:
layer.trainable = True
combined_embeddings = tf.keras.layers.Concatenate(axis=0)(
[embedding_model(input_anchor), embedding_model(input_positive)])
siamese_model = Model(inputs=[input_anchor, input_positive], outputs=combined_embeddings,
name='siamese_model')
siamese_model.compile(optimizer=Adam(learning_rate=1e-4), loss=triplet_loss(margin=MARGIN))
early_stopping = EarlyStopping(monitor='val_loss', patience=PATIENCE, mode='min', restore_best_weights=True)
custom_saver = CustomModelSaver(filepath=MODEL_WEIGHTS_PATH, monitor='val_loss', mode='min', verbose=1)
siamese_model.fit(
train_dataset,
epochs=EPOCHS,
validation_data=val_dataset,
callbacks=[early_stopping, custom_saver]
)
print("\n--- 第一阶段训练完成,最佳权重已保存至:", MODEL_WEIGHTS_PATH, "---")