import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, LayerNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Add, Activation
from tensorflow.keras.layers import Concatenate, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, GlobalAveragePooling1D, Dropout, \
LayerNormalization # 添加 GlobalAveragePooling1D
from tensorflow.keras.layers import Conv2D, Conv1D, MaxPooling2D, BatchNormalization, Add, Activation
import gc
import warnings
warnings.filterwarnings('ignore')
import tifffile
# GPU configuration: enable on-demand memory growth so TensorFlow does not
# reserve all device memory at start-up.  The setting is applied to every
# visible GPU (not only the first one) because TensorFlow requires memory
# growth to be configured identically on all GPUs.
try:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU加速已启用")
except Exception as e:
    # Best effort: a configuration failure is reported but does not stop the script.
    print(f"GPU配置错误: {str(e)}")
class MultiModalDataGenerator(Sequence):
    """Batch generator pairing 39-band TIFF images with 39-dimensional chemical vectors.

    Every batch is ``({"image_input": images, "chemical_input": chemical}, labels)``
    with ``images`` of shape ``(n, 39, 7, 4)``, ``chemical`` of shape ``(n, 39)``
    and ``labels`` of shape ``(n,)``.  ``n`` may be smaller than ``batch_size``
    (even 0): a sample whose image cannot be read, has an unsupported shape,
    contains NaN or is constant is dropped together with its chemical vector
    and label, so the three arrays always stay aligned.
    """

    # Channel-first layout.  This matches the shape declared by the model's
    # ``image_input`` layer in this file (``Input(shape=(39, 7, 4))``).
    IMAGE_SHAPE = (39, 7, 4)
    NUM_CHEMICAL_FEATURES = 39

    def __init__(self, image_paths, chemical_data, labels, batch_size=16, shuffle=True, **kwargs):
        """Store the sample sources and build the initial index order.

        Args:
            image_paths: Sequence of TIFF file paths, one per sample.
            chemical_data: ``numpy.ndarray`` or pandas ``DataFrame`` holding one
                row of chemical features per sample.
            labels: Integer class label per sample (indexable by position).
            batch_size: Maximum number of samples per batch.
            shuffle: Shuffle the sample order now and after every epoch.
            **kwargs: Forwarded unchanged to the Keras base class.
        """
        super().__init__(**kwargs)
        self.image_paths = image_paths
        self.chemical_data = chemical_data
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_paths))
        # Number of images skipped because all their pixels share one value.
        self.skip_same_value_count = 0
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def _load_image(self, path):
        """Read one TIFF and return it as float32 in ``IMAGE_SHAPE`` layout.

        Returns ``None`` (after printing a warning) when the sample has to be
        skipped: unreadable file, unsupported shape, NaN values or constant data.
        """
        try:
            img = tifffile.imread(path)
            if img.shape == (7, 4, 39):
                # Channel-last file: move the band axis to the front -> (39, 7, 4).
                img = np.moveaxis(img, -1, 0)
            elif img.shape == (39, 4, 7):
                # Band axis already first but the two spatial axes are swapped.
                # NOTE(review): assumes such files are a transposed (39, 7, 4) image.
                img = np.swapaxes(img, 1, 2)
            elif img.shape != self.IMAGE_SHAPE:
                # Any other shape cannot be stacked with the rest of the batch.
                print(f"警告: 图像 {path} 形状异常 {img.shape},跳过处理")
                return None
            # Type conversion only; pixel values are deliberately not rescaled.
            img = img.astype(np.float32)
            if np.isnan(img).any():
                print(f"警告: 图像 {path} 包含 nan 值,跳过处理")
                return None
            if img.max() == img.min():
                print(f"警告: 图像 {path} 数据无效(全相同值),跳过处理")
                self.skip_same_value_count += 1
                return None
            return img
        except Exception as e:
            # Best effort: one broken file must not abort the whole batch.
            print(f"加载图像失败 {path}: {str(e)}")
            return None

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Returns:
            tuple: ``({"image_input": float32 (n, 39, 7, 4),
            "chemical_input": float32 (n, 39)}, labels (n,))``.
        """
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        images = []
        chemicals = []
        targets = []
        for i in batch_indices:
            img = self._load_image(self.image_paths[i])
            if img is None:
                continue
            # Fetch the chemical row for the same sample.
            if isinstance(self.chemical_data, np.ndarray):
                chemical_features = self.chemical_data[i]
            else:
                chemical_features = self.chemical_data.iloc[i].values
            if chemical_features.shape != (self.NUM_CHEMICAL_FEATURES,):
                print(f"警告: 化学数据形状异常 {chemical_features.shape},调整为 (39,)")
                chemical_features = chemical_features.reshape(self.NUM_CHEMICAL_FEATURES,)
            images.append(img)
            chemicals.append(chemical_features)
            targets.append(self.labels[i])

        # The explicit reshape keeps the trailing dimensions correct even when
        # every sample of the batch was skipped (n == 0).
        batch_images = np.array(images, dtype=np.float32).reshape((-1,) + self.IMAGE_SHAPE)
        batch_chemical = np.array(chemicals, dtype=np.float32).reshape(
            -1, self.NUM_CHEMICAL_FEATURES)
        if targets:
            batch_labels = np.array(targets)
        else:
            # An empty list would become a float array, which np.bincount rejects.
            batch_labels = np.zeros((0,), dtype=np.int64)

        # Debug output for the first batch only.
        if idx == 0:
            print(f"批次数据形状: 图像={batch_images.shape}, 化学={batch_chemical.shape}, 标签={batch_labels.shape}")
            print(f"标签分布: {np.bincount(batch_labels)}")
        if np.isnan(batch_images).any() or np.isnan(batch_chemical).any():
            print(f"警告: 批次 {idx} 数据包含 nan 值!")
        print(f"image_input 形状: {batch_images.shape}")
        print(f"chemical_input 形状: {batch_chemical.shape}")
        # Dict keys match the names of the model's two input layers.
        return ({"image_input": batch_images, "chemical_input": batch_chemical}, batch_labels)

    def on_epoch_end(self):
        """Report the running skip count and reshuffle the sample order if enabled."""
        print(f"跳过处理的图像数量(全相同值): {self.skip_same_value_count}")
        if self.shuffle:
            np.random.shuffle(self.indices)
class MultiModalFusionModel:
    """Two-branch classifier that fuses multispectral image and chemical features.

    Typical use: ``load_data()`` -> ``build_model()`` -> ``train(...)`` ->
    ``evaluate(...)``.
    """

    def __init__(self, img_root="D:\\西北地区铜镍矿\\多模态测试\\图片训练",
                 data_path="D:\\西北地区铜镍矿\\数据\\训练数据.xlsx"):
        """Remember the data locations and create the (still unfitted) scaler.

        Args:
            img_root: Directory containing one sub-folder of images per class name.
            data_path: Excel file with the ``name``, ``class`` and chemical columns.
        """
        self.img_root = img_root
        self.data_path = data_path
        self.scaler = StandardScaler()  # fitted on the training split in train()
        self.model = None               # set by build_model()
        self.history = None             # set by train()

    def load_data(self):
        """Load image paths, chemical features and labels from the Excel sheet.

        Rows are kept only when ``name`` is a string, ``class`` is one of
        ``positive`` / ``neutral`` / ``negative`` and a matching image file
        exists under ``img_root/<class>/``.

        Returns:
            tuple: ``(image_paths, chemical_data, labels)`` where ``image_paths``
            is a list of file paths, ``chemical_data`` a 2-D array with one row
            per kept sample and ``labels`` an integer array.

        Raises:
            ValueError: If a required column is missing, the feature slice does
                not contain exactly 39 numeric columns, or no sample matches.
        """
        print("=== 加载多模态数据 ===")
        df = pd.read_excel(self.data_path)
        print(f"数据形状: {df.shape}")
        print(f"列名: {df.columns.tolist()}")
        # Required columns.
        if 'name' not in df.columns:
            raise ValueError("Excel数据中必须包含'name'列以匹配图像文件!")
        if 'class' not in df.columns:
            raise ValueError("Excel数据中必须包含'class'列作为标签!")

        # Chemical features: column positions 6..44, restricted to numeric dtypes.
        feature_cols = df.columns[6:45]
        chemical_data = df[feature_cols].select_dtypes(include=[np.number])
        print(f"化学数据形状: {chemical_data.shape}")
        # The chemical input layer is fixed at 39 features, so fail early with a
        # clear message instead of failing later inside the batch generator.
        if chemical_data.shape[1] != 39:
            raise ValueError(
                f"化学特征应为39个数值列,实际为{chemical_data.shape[1]}列")

        image_paths = []
        image_labels = []
        valid_positions = []
        label_map = {'positive': 0, 'neutral': 1, 'negative': 2}
        # Track the row *position* (not the index label) because the rows are
        # selected with ``iloc`` below.
        for position, (_, row) in enumerate(df.iterrows()):
            filename = row['name']
            class_label = row['class']
            if not isinstance(filename, str) or class_label not in label_map:
                print(f"跳过无效数据: filename={filename}, class={class_label}")
                continue
            # Images live in a folder named after the class; the name may be
            # stored with or without a TIFF extension.
            class_folder = os.path.join(self.img_root, class_label)
            candidates = (
                os.path.join(class_folder, filename),
                os.path.join(class_folder, f"{filename}.tif"),
                os.path.join(class_folder, f"{filename}.tiff"),
            )
            img_path = next((p for p in candidates if os.path.exists(p)), None)
            if img_path is None:
                print(f"警告: 图像文件未找到: {filename} in {class_label}")
                continue
            image_paths.append(img_path)
            image_labels.append(label_map[class_label])
            valid_positions.append(position)

        if not image_paths:
            raise ValueError("未找到任何匹配的图像-化学数据对!")
        # Keep only the chemical rows whose image was found.
        chemical_data = chemical_data.iloc[valid_positions].reset_index(drop=True)
        image_labels = np.array(image_labels)

        print(f"成功匹配样本数: {len(image_paths)}")
        print(f"标签分布详情:")
        for class_name, label in label_map.items():
            count = np.sum(image_labels == label)
            print(f" {class_name} (标签{label}): {count}个样本")
        # Show the first five samples as a sanity check.
        print(f"前5个样本:")
        for i in range(min(5, len(image_paths))):
            print(f" 样本{i}: {os.path.basename(image_paths[i])} -> 标签{image_labels[i]}")
        return image_paths, chemical_data.values, image_labels

    def build_model(self):
        """Build and compile the fusion network; store it in ``self.model``.

        Returns:
            The compiled Keras model with inputs ``image_input`` (39, 7, 4) and
            ``chemical_input`` (39,) and a 3-way softmax output.
        """
        print("=== 构建多模态融合模型 ===")
        # Image branch: each of the 39 bands is flattened to 28 values and the
        # band axis is treated as the sequence axis of the 1-D convolutions.
        img_input = Input(shape=(39, 7, 4), name='image_input')
        x = Reshape((39, 7 * 4))(img_input)
        x = Conv1D(64, 3, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Conv1D(128, 3, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = GlobalAveragePooling1D()(x)
        x = Dense(256, activation='relu')(x)
        x = Dropout(0.3)(x)
        img_features = Dense(128, activation='relu', name='img_features')(x)

        # Chemical branch: plain MLP over the 39 features.
        chem_input = Input(shape=(39,), name='chemical_input')
        y = Dense(128, activation='relu')(chem_input)
        y = BatchNormalization()(y)
        y = Dropout(0.3)(y)
        y = Dense(256, activation='relu')(y)
        y = Dropout(0.3)(y)
        chem_features = Dense(128, activation='relu', name='chem_features')(y)

        # Fusion by concatenation followed by a small classification head.
        fused = Concatenate(name='feature_fusion')([img_features, chem_features])
        fused = Dense(256, activation='relu')(fused)
        fused = Dropout(0.4)(fused)
        fused = Dense(128, activation='relu')(fused)
        fused = Dropout(0.3)(fused)
        output = Dense(3, activation='softmax', name='classification')(fused)

        model = Model(inputs=[img_input, chem_input], outputs=output)
        # clipvalue adds gradient clipping to the optimizer.
        optimizer = Adam(learning_rate=0.0001, clipvalue=1.0)
        model.compile(
            optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        self.model = model
        return model

    @staticmethod
    def _collect_batches(generator):
        """Concatenate every non-empty batch of ``generator`` into three arrays.

        ``generator`` only needs ``len()`` and integer indexing, each item being
        ``({"image_input": ..., "chemical_input": ...}, labels)``.

        Returns:
            tuple: ``(images, chemical, labels)`` concatenated along axis 0.

        Raises:
            ValueError: If no batch contains any sample.
        """
        images = []
        chemicals = []
        targets = []
        for batch_idx in range(len(generator)):
            batch_inputs, batch_targets = generator[batch_idx]
            if len(batch_targets) == 0:
                # Every sample of this batch was skipped by the generator.
                continue
            images.append(batch_inputs["image_input"])
            chemicals.append(batch_inputs["chemical_input"])
            targets.append(batch_targets)
        if not targets:
            raise ValueError("未找到任何有效样本!")
        return (np.concatenate(images, axis=0),
                np.concatenate(chemicals, axis=0),
                np.concatenate(targets, axis=0))

    def train(self, image_paths, chemical_data, labels,
              test_size=0.2, batch_size=16, epochs=100):
        """Train the model on a stratified train/validation split.

        Args:
            image_paths: Image file path per sample.
            chemical_data: Unscaled chemical features, one row per sample.
            labels: Integer class label per sample.
            test_size: Fraction of the samples held out for validation.
            batch_size: Batch size for both generators.
            epochs: Maximum number of epochs.

        Returns:
            The Keras ``History`` object (also stored in ``self.history``).
        """
        print("=== 开始训练 ===")
        if self.model is None:
            self.build_model()

        # Split first, then fit the scaler on the training rows only, so that no
        # statistics of the validation rows leak into the preprocessing.
        X_img_train, X_img_test, X_chem_train, X_chem_test, y_train, y_test = train_test_split(
            image_paths, chemical_data, labels,
            test_size=test_size, random_state=42, stratify=labels
        )
        X_chem_train = self.scaler.fit_transform(X_chem_train)
        X_chem_test = self.scaler.transform(X_chem_test)
        print(f"训练集大小: {len(X_img_train)}")
        print(f"测试集大小: {len(X_img_test)}")
        print(f"训练集标签分布: {np.bincount(y_train)}")
        print(f"测试集标签分布: {np.bincount(y_test)}")

        train_generator = MultiModalDataGenerator(
            X_img_train, X_chem_train, y_train,
            batch_size=batch_size, shuffle=True
        )
        val_generator = MultiModalDataGenerator(
            X_img_test, X_chem_test, y_test,
            batch_size=batch_size, shuffle=False
        )
        # Pull one batch up front as a smoke test of the generator output.
        sample_data, sample_labels = train_generator[0]
        print(f"调试生成器输出: 图像形状{sample_data['image_input'].shape}, 化学形状{sample_data['chemical_input'].shape}, 标签形状{sample_labels.shape}")

        callbacks = [
            EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=8, min_lr=1e-6),
            ModelCheckpoint('best_multimodal_model_v3.keras', save_best_only=True, monitor='val_accuracy')
        ]
        try:
            self.history = self.model.fit(
                train_generator,
                validation_data=val_generator,
                epochs=epochs,
                callbacks=callbacks,
                verbose=1
            )
        except Exception as e:
            # Deliberate fallback: if generator-based training fails, load
            # everything into memory and train from plain arrays instead.
            print(f"使用生成器训练失败,尝试加载所有数据到内存: {str(e)}")
            print("正在加载训练数据到内存...")
            X_img_train_array, X_chem_train_array, y_train_array = \
                self._collect_batches(train_generator)
            print("正在加载验证数据到内存...")
            X_img_val_array, X_chem_val_array, y_val_array = \
                self._collect_batches(val_generator)
            print(
                f"训练数据形状: 图像{X_img_train_array.shape}, 化学{X_chem_train_array.shape}, 标签{y_train_array.shape}")
            print(f"验证数据形状: 图像{X_img_val_array.shape}, 化学{X_chem_val_array.shape}, 标签{y_val_array.shape}")
            self.history = self.model.fit(
                [X_img_train_array, X_chem_train_array], y_train_array,
                validation_data=([X_img_val_array, X_chem_val_array], y_val_array),
                batch_size=batch_size,
                epochs=epochs,
                callbacks=callbacks,
                verbose=1
            )
        # Kept from the original: make sure eager function execution is off.
        tf.config.run_functions_eagerly(False)
        return self.history

    def evaluate(self, image_paths, chemical_data, labels):
        """Evaluate the trained model and save a confusion-matrix plot.

        Metrics are computed against the labels that come back from the batch
        generator, so samples the generator skips (unreadable or invalid images)
        are excluded from both the predictions and the reference labels.

        Args:
            image_paths: Image file path per sample.
            chemical_data: Unscaled chemical features; scaled with the fitted scaler.
            labels: Integer class label per sample.

        Returns:
            tuple: ``(accuracy, y_pred, predictions)`` where ``predictions`` holds
            the class probabilities and ``y_pred`` their argmax.
        """
        print("=== 模型评估 ===")
        chemical_data_scaled = self.scaler.transform(chemical_data)
        test_generator = MultiModalDataGenerator(
            image_paths, chemical_data_scaled, labels,
            batch_size=16, shuffle=False
        )
        images, chemical, y_true = self._collect_batches(test_generator)
        predictions = self.model.predict(
            {"image_input": images, "chemical_input": chemical},
            batch_size=16
        )
        y_pred = np.argmax(predictions, axis=1)

        accuracy = accuracy_score(y_true, y_pred)
        print(f"准确率: {accuracy:.4f}")

        class_names = ['positive', 'neutral', 'negative']
        # Explicit label ids keep the report and the matrix 3x3 even when a
        # class is absent from the evaluated samples.
        class_ids = list(range(len(class_names)))
        print("\n分类报告:")
        print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

        cm = confusion_matrix(y_true, y_pred, labels=class_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('混淆矩阵')
        plt.ylabel('真实标签')
        plt.xlabel('预测标签')
        plt.tight_layout()
        plt.savefig('confusion_matrix_v3.png', dpi=300, bbox_inches='tight')
        plt.show()
        return accuracy, y_pred, predictions
def main():
    """Run the whole pipeline: load data, build, train and evaluate the model."""
    fusion = MultiModalFusionModel()
    paths, chemical, targets = fusion.load_data()
    fusion.build_model()
    fusion.train(paths, chemical, targets, batch_size=24, epochs=10)
    # NOTE(review): evaluation uses the full dataset, i.e. it includes the
    # samples the model was trained on.
    final_accuracy, _, _ = fusion.evaluate(paths, chemical, targets)
    print(f"\n最终准确率: {final_accuracy:.4f}")


if __name__ == "__main__":
    main()
# 最新发布  (stray web-page footer text, "Latest release"; not part of the script)