import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import os
from tqdm import tqdm
import json
# Dynamic fusion module
class DynamicFusion(nn.Module):
    def __init__(self, input_dims, hidden_dim=128, lambda_val=0.5):
        """
        Dynamic multimodal fusion module.

        Parameters:
            input_dims (dict): input dimension per modality,
                e.g. {'audio': dim_a, 'visual': dim_v, 'text': dim_t}
            hidden_dim (int): dimension of the shared feature space
            lambda_val (float): balance factor between the multimodal boost
                term and the unimodal discriminative term
        """
        super(DynamicFusion, self).__init__()
        self.lambda_val = lambda_val
        self.modalities = list(input_dims.keys())
        self.num_modalities = len(self.modalities)
        # Per-modality projection layers (map each modality into the shared space)
        self.projections = nn.ModuleDict()
        for modality, dim in input_dims.items():
            self.projections[modality] = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim)
            )
        # Importance evaluator: one logit per modality from the concatenation
        self.importance_evaluator = nn.Sequential(
            nn.Linear(hidden_dim * self.num_modalities, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_modalities)
        )
        # Shared predictor head (single logit -> binary classification)
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, features, labels=None, return_importance=False):
        """
        Forward pass.

        Parameters:
            features (dict): per-modality feature tensors
                (assumed (batch, dim) each — TODO confirm against the dataset)
            labels (Tensor): ground-truth labels, shape (batch, 1);
                only used while training
            return_importance (bool): also return the modality importance weights

        Returns:
            In training with labels: total loss (scalar tensor).
            Otherwise: fused prediction logits, shape (batch, 1).
            With return_importance=True the importance weights (batch, M)
            are returned as a second value.
        """
        # 1. Project every modality into the shared feature space
        projected = {m: self.projections[m](features[m]) for m in self.modalities}
        # 2. Evaluate modality importance from the concatenated projections
        all_features = torch.cat([projected[m] for m in self.modalities], dim=1)
        importance_logits = self.importance_evaluator(all_features)
        importance_weights = F.softmax(importance_logits, dim=1)
        # 3. Importance-weighted fusion
        fused_feature = torch.zeros_like(projected[self.modalities[0]])
        for i, modality in enumerate(self.modalities):
            fused_feature += importance_weights[:, i].unsqueeze(1) * projected[modality]
        fused_output = self.predictor(fused_feature)
        if not self.training or labels is None:
            if return_importance:
                return fused_output, importance_weights
            return fused_output
        # 4. Training: build the supervision signal for modality importance
        # 4.1 Full-modality loss (kept per-sample, shape (B, 1))
        full_loss = F.binary_cross_entropy_with_logits(fused_output, labels, reduction='none')
        # 4.2 Single-modality losses
        single_losses = {}
        for modality in self.modalities:
            single_pred = self.predictor(projected[modality])
            single_losses[modality] = F.binary_cross_entropy_with_logits(single_pred, labels, reduction='none')
        # 4.3 Leave-one-out losses (drop one modality, renormalize the rest)
        remove_losses = {}
        for modality_to_remove in self.modalities:
            remaining_modalities = [m for m in self.modalities if m != modality_to_remove]
            remaining_indices = [self.modalities.index(m) for m in remaining_modalities]
            remaining_logits = importance_logits[:, remaining_indices]
            remaining_weights = F.softmax(remaining_logits, dim=1)
            fused_remove = torch.zeros_like(fused_feature)
            # FIX: use the enumerate index directly; the original did a
            # redundant O(n) `remaining_modalities.index(modality)` lookup
            # that always equals `i`.
            for i, modality in enumerate(remaining_modalities):
                fused_remove += remaining_weights[:, i].unsqueeze(1) * projected[modality]
            remove_pred = self.predictor(fused_remove)
            remove_losses[modality_to_remove] = F.binary_cross_entropy_with_logits(
                remove_pred, labels, reduction='none'
            )
        # 4.4 Importance labels.  Eq. (2-38): I_m = -[λ·(L - L̂_m) + (1-λ)·L_m]
        importance_labels = {}
        for modality in self.modalities:
            L_m = single_losses[modality]
            multimodal_boost = full_loss - remove_losses[modality]
            I_m = -(self.lambda_val * multimodal_boost + (1 - self.lambda_val) * L_m)
            importance_labels[modality] = I_m.detach()  # cut gradient flow to the targets
        # 4.5 Importance supervision: pairwise hinge ranking loss on the logits
        importance_loss = 0
        for i, modality_i in enumerate(self.modalities):
            for j, modality_j in enumerate(self.modalities):
                if i >= j:
                    continue
                # Predicted importance difference, shape (B,)
                score_diff = importance_logits[:, i] - importance_logits[:, j]
                # BUGFIX: importance labels have shape (B, 1); multiplying a
                # (B, 1) tensor by the (B,) score_diff broadcast to (B, B),
                # pairing unrelated batch samples.  Squeeze to (B,) so each
                # sample is only compared with itself.
                label_diff = (importance_labels[modality_i] - importance_labels[modality_j]).squeeze(-1)
                # Hinge loss with a margin of 0.1 for ranking optimization
                loss_pair = F.relu(-label_diff * score_diff + 0.1)
                importance_loss += loss_pair.mean()
        # 4.6 Total loss = main task loss + importance supervision loss
        main_loss = full_loss.mean()
        total_loss = main_loss + importance_loss
        if return_importance:
            return total_loss, importance_weights
        return total_loss
# Dataset class — uses the noisy feature files plus a single shared label file
class NoisyMultimodalDataset(Dataset):
    def __init__(self, base_dir, split='train'):
        """
        Multimodal dataset backed by noise-corrupted HDF5 feature files.

        Parameters:
            base_dir (str): dataset base directory
            split (str): dataset split name (train/val/test).
                NOTE(review): `split` is stored and printed but never used to
                filter samples — every split loads the same data; confirm
                whether per-split filtering is intended.
        """
        self.base_dir = base_dir
        self.split = split
        self.sample_ids = []
        # Paths of the per-modality HDF5 feature files (the *_noisy variants)
        self.modality_files = {
            'audio': os.path.join(base_dir, "A", "acoustic_noisy.h5"),
            'visual': os.path.join(base_dir, "V", "visual_noisy.h5"),
            'text': os.path.join(base_dir, "L", "deberta-v3-large_noisy.h5")
        }
        # Keep only the modalities whose feature file actually exists on disk
        self.available_modalities = []
        for modality, file_path in self.modality_files.items():
            if os.path.exists(file_path):
                self.available_modalities.append(modality)
                print(f"找到模态文件: {modality} -> {file_path}")
            else:
                print(f"警告: 模态文件不存在: {file_path}")
        if not self.available_modalities:
            raise FileNotFoundError("未找到任何模态文件!")
        # Collect the set of sample IDs (HDF5 keys) stored in each modality file
        sample_ids_sets = []
        for modality in self.available_modalities:
            file_path = self.modality_files[modality]
            try:
                with h5py.File(file_path, 'r') as f:
                    sample_ids = set(f.keys())
                    sample_ids_sets.append(sample_ids)
                    print(f"模态 {modality}: 找到 {len(sample_ids)} 个样本")
            except Exception as e:
                print(f"加载模态 {modality} 文件时出错: {str(e)}")
                sample_ids_sets.append(set())
        # Use only samples present in every available modality
        common_ids = set.intersection(*sample_ids_sets)
        self.sample_ids = sorted(list(common_ids))
        # Locate the label file: labels/all_labels.npy, falling back to base_dir
        label_file = os.path.join(base_dir, "labels", "all_labels.npy")
        if not os.path.exists(label_file):
            # Fall back to looking directly under the base directory
            label_file = os.path.join(base_dir, "all_labels.npy")
            if not os.path.exists(label_file):
                raise FileNotFoundError(f"标签文件不存在: {os.path.join(base_dir, 'labels', 'all_labels.npy')} 或 {label_file}")
        try:
            all_labels = np.load(label_file)
            print(f"加载标签文件: {label_file}, 总标签数: {len(all_labels)}")
        except Exception as e:
            raise IOError(f"加载标签文件时出错: {str(e)}")
        # Map sample ID -> row index in the label array.
        # NOTE(review): this assumes the label array is ordered by the key
        # order of the first modality's HDF5 file — verify this matches how
        # all_labels.npy was generated.
        id_to_index = {}
        try:
            # Use the first modality file's key order as the reference ordering
            first_modality = self.available_modalities[0]
            with h5py.File(self.modality_files[first_modality], 'r') as f:
                all_ids_in_file = list(f.keys())
                for idx, sample_id in enumerate(all_ids_in_file):
                    id_to_index[sample_id] = idx
            print(f"使用 {first_modality} 模态的顺序作为标签映射参考")
        except Exception as e:
            raise RuntimeError(f"创建ID映射时出错: {str(e)}")
        # Resolve labels for the common samples, dropping any without a label
        self.labels = []
        valid_sample_ids = []
        missing_samples = 0
        for sample_id in self.sample_ids:
            if sample_id in id_to_index:
                idx = id_to_index[sample_id]
                if idx < len(all_labels):
                    self.labels.append(all_labels[idx])
                    valid_sample_ids.append(sample_id)
                else:
                    print(f"警告: 样本 {sample_id} 的索引 {idx} 超出标签范围 (标签长度={len(all_labels)})")
                    missing_samples += 1
            else:
                print(f"警告: 样本 {sample_id} 未在标签映射中找到")
                missing_samples += 1
        if missing_samples > 0:
            print(f"警告: {missing_samples} 个样本缺少标签")
        if not valid_sample_ids:
            raise ValueError("没有有效的样本ID与标签匹配")
        self.sample_ids = valid_sample_ids
        self.labels = np.array(self.labels, dtype=np.float32)
        print(f"加载 {split} 数据集: {len(self)} 个样本, {len(self.available_modalities)} 个模态")
        print(f"样本示例: {self.sample_ids[:3]}, 标签示例: {self.labels[:3]}")

    def __len__(self):
        # Samples that have features in every available modality AND a label
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Return ({modality: float32 ndarray}, float32 label) for sample idx."""
        sample_id = self.sample_ids[idx]
        features = {}
        # Load this sample's feature array from each modality's HDF5 file
        for modality in self.available_modalities:
            file_path = self.modality_files[modality]
            try:
                with h5py.File(file_path, 'r') as f:
                    features[modality] = f[sample_id][:].astype(np.float32)
            except Exception as e:
                print(f"加载样本 {sample_id} 的 {modality} 特征时出错: {str(e)}")
                # If a feature fails to load, fall back to a zero vector of
                # the expected per-modality dimensionality
                if modality == 'audio':
                    features[modality] = np.zeros(74, dtype=np.float32)
                elif modality == 'visual':
                    features[modality] = np.zeros(47, dtype=np.float32)
                elif modality == 'text':
                    features[modality] = np.zeros(1024, dtype=np.float32)
                else:
                    features[modality] = np.zeros(128, dtype=np.float32)  # default
        label = self.labels[idx]
        return features, label
# Training function
def train(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for features, labels in tqdm(dataloader, desc="训练中"):
        # Move each modality tensor and the labels onto the target device;
        # labels go from (B,) to (B, 1) to match the model's output shape.
        features = {name: tensor.to(device) for name, tensor in features.items()}
        labels = labels.to(device).unsqueeze(1)
        # Standard optimization step: forward, backward, update.
        optimizer.zero_grad()
        loss = model(features, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
# Validation function
def validate(model, dataloader, device):
    """
    Evaluate the model on a dataloader.

    Returns:
        (avg_loss, accuracy): mean BCE-with-logits loss over batches and
        overall accuracy at a 0.5 probability threshold.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for features, labels in tqdm(dataloader, desc="验证中"):
            # Move data to the target device
            for modality in features.keys():
                features[modality] = features[modality].to(device)
            labels = labels.to(device).unsqueeze(1)
            # FIX: do not pass labels into forward() during evaluation.
            # forward() returns a loss (not predictions) whenever labels are
            # supplied in training mode, so passing them here would silently
            # break this function if the model were ever left in train mode.
            outputs = model(features)
            # Compute the validation loss explicitly
            loss = F.binary_cross_entropy_with_logits(outputs, labels)
            total_loss += loss.item()
            # Collect probabilities and labels for the accuracy computation
            preds = torch.sigmoid(outputs)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    # Aggregate metrics over the whole split
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    # Threshold probabilities at 0.5 for binary accuracy
    pred_labels = (all_preds > 0.5).float()
    accuracy = (pred_labels == all_labels).float().mean().item()
    avg_loss = total_loss / len(dataloader)
    return avg_loss, accuracy
# Main function
def main():
    """Entry point: build datasets, train the fusion model, save the best checkpoint."""
    # Device selection.
    # FIX: the original hard-coded "cuda:1", which raises on machines with a
    # single GPU even though torch.cuda.is_available() is True.  Fall back to
    # cuda:0 (or CPU) when a second GPU is not present.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print(f"使用设备: {device}")
    base_dir = "/home/msj_team/data/code/CIF-MMIN/data/dataset/mositext"
    # Verify the dataset base directory exists before doing anything else
    if not os.path.exists(base_dir):
        print(f"错误: 基础目录不存在: {base_dir}")
        print("请检查路径是否正确,或确保数据集已下载并放置在正确位置")
        return
    # Check the label directory (the dataset class repeats this check,
    # but failing fast here gives a clearer message)
    label_dir = os.path.join(base_dir, "labels")
    if not os.path.exists(label_dir):
        print(f"警告: 标签目录不存在: {label_dir}")
        print("尝试在基础目录下查找标签文件...")
    # Check the label file, falling back to the base directory
    label_file = os.path.join(label_dir, "all_labels.npy")
    if not os.path.exists(label_file):
        label_file = os.path.join(base_dir, "all_labels.npy")
        if not os.path.exists(label_file):
            print(f"错误: 标签文件不存在: {os.path.join(label_dir, 'all_labels.npy')} 或 {label_file}")
            return
    print(f"使用标签文件: {label_file}")
    # Feature dimensions matching the stored HDF5 features
    input_dims = {
        'audio': 74,    # audio feature dimension
        'visual': 47,   # visual feature dimension
        'text': 1024    # text feature dimension
    }
    # Training hyperparameters
    batch_size = 32
    learning_rate = 1e-4
    num_epochs = 20
    lambda_val = 0.7  # balance factor for the importance supervision
    try:
        # Build datasets and dataloaders
        print("\n初始化训练数据集...")
        train_dataset = NoisyMultimodalDataset(base_dir, split='train')
        print("\n初始化验证数据集...")
        val_dataset = NoisyMultimodalDataset(base_dir, split='val')
        print(f"\n训练集样本数: {len(train_dataset)}")
        print(f"验证集样本数: {len(val_dataset)}")
        # Guard against empty splits
        if len(train_dataset) == 0:
            raise ValueError("训练数据集为空!")
        if len(val_dataset) == 0:
            raise ValueError("验证数据集为空!")
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        # Initialize model and optimizer
        model = DynamicFusion(input_dims, lambda_val=lambda_val).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        # Training loop: keep the checkpoint with the best validation accuracy
        best_val_accuracy = 0.0
        print("\n开始训练...")
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            train_loss = train(model, train_loader, optimizer, device)
            val_loss, val_accuracy = validate(model, val_loader, device)
            print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}, 验证准确率: {val_accuracy:.4f}")
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                torch.save(model.state_dict(), "best_model_noisy.pth")
                print(f"保存最佳模型 (验证准确率: {val_accuracy:.4f})")
        print("\n训练完成!")
        print(f"最终验证准确率: {best_val_accuracy:.4f}")
    except Exception as e:
        # Top-level boundary: report the error with a full traceback
        print(f"\n发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
# NOTE(review): removed stray chat text ("这是我的代码,要怎样修改") that was
# accidentally pasted here — it is not Python and made the file unparsable.