现有数据集的格式如下,并且已有训练得到的 .pth 模型文件。需要将测试集的真实标签与预测结果分别输出到不同的 CSV 表格。训练代码如下:
X模态形状: torch.Size([1, 1, 35])
Y模态形状: torch.Size([1, 1, 35])
标签形状: torch.Size([1, 9])
训练代码:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from dataset import MultiModalDataset
from model import MultiModalNet, FusionModel
from opt import opt
# Training routine
def train_model(model, dataloader, val_loader, num_epochs=10, lr=0.001, save_dir='results',
                log_interval=1):
    """Train ``model`` with Adam + cross-entropy and validate after every epoch.

    Each batch is a mapping with the keys ``'x'``, ``'y'`` and ``'label'``;
    the model is called as ``model(batch['x'], batch['y'])``. Labels with more
    than one dimension are reduced with ``argmax(dim=1)`` for the accuracy
    computation; 1-D labels are used as class indices directly.

    Artifacts written under ``save_dir``:
      * ``logs/``            - TensorBoard scalars (loss / accuracy, train / val)
      * ``best_model.pth``   - state dict with the highest validation accuracy
      * ``final_model.pth``  - state dict after the last epoch
      * whatever ``save_training_results`` writes

    Args:
        model: Network to optimise in place.
        dataloader: Training batches; must support ``len()``.
        val_loader: Validation batches, forwarded to ``validate_model``.
        num_epochs: Number of passes over ``dataloader``.
        lr: Learning rate for ``optim.Adam``.
        save_dir: Output directory (created if missing).
        log_interval: Print the batch loss every ``log_interval`` batches.
            The default of 1 prints every batch; a value <= 0 disables the
            per-batch print.

    Returns:
        The same ``model`` object after training.

    Raises:
        ValueError: If training is requested (``num_epochs > 0``) but
            ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_epochs > 0 and num_batches == 0:
        raise ValueError('train_model received an empty training dataloader')

    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(save_dir, 'logs'))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Per-epoch history, used for the curves and the metrics summary.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0
    start_time = time.time()

    print(f"开始训练,共 {num_epochs} 个epoch...")
    print("-" * 50)

    # try/finally guarantees the TensorBoard writer is flushed and closed even
    # when training is interrupted by an exception.
    try:
        for epoch in range(num_epochs):
            # ---- training phase ----
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for i, batch in enumerate(dataloader):
                x_data = batch['x']
                y_data = batch['y']
                labels = batch['label']

                optimizer.zero_grad()
                outputs = model(x_data, y_data)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                predicted = outputs.detach().argmax(dim=1)
                true_labels = labels.argmax(dim=1) if labels.dim() > 1 else labels
                total += true_labels.size(0)
                correct += (predicted == true_labels).sum().item()

                if log_interval > 0 and (i + 1) % log_interval == 0:
                    print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}/{num_batches}], Loss: {batch_loss:.4f}')

            # Loss is the mean of the per-batch losses; accuracy is per sample.
            epoch_loss = running_loss / num_batches
            epoch_acc = correct / total
            train_losses.append(epoch_loss)
            train_accs.append(epoch_acc)

            # ---- validation phase ----
            val_loss, val_acc = validate_model(model, val_loader, criterion)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            writer.add_scalar('Loss/train', epoch_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', epoch_acc, epoch)
            writer.add_scalar('Accuracy/val', val_acc, epoch)

            print(f'\nEpoch [{epoch + 1}/{num_epochs}] 完成!')
            print(f'训练损失: {epoch_loss:.4f}, 训练准确率: {epoch_acc:.4f}')
            print(f'验证损失: {val_loss:.4f}, 验证准确率: {val_acc:.4f}')
            print("-" * 50)

            # Keep the checkpoint with the best validation accuracy so far.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
                print(f'保存最佳模型,验证准确率: {val_acc:.4f}')

            # NOTE(review): the original indentation was lost; this call is
            # assumed to run once per epoch (refreshing the curves as training
            # progresses) rather than only when a new best model is saved.
            save_training_results(train_losses, val_losses, train_accs, val_accs, save_dir)
    finally:
        writer.close()

    torch.save(model.state_dict(), os.path.join(save_dir, 'final_model.pth'))

    # With num_epochs == 0 there is no history to plot or summarise.
    if train_losses:
        save_training_results(train_losses, val_losses, train_accs, val_accs, save_dir)

    total_time = time.time() - start_time
    print(f'训练完成! 总耗时: {total_time // 60:.0f}分 {total_time % 60:.0f}秒')
    print(f'最佳验证准确率: {best_val_acc:.4f}')
    return model
def validate_model(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Each batch is a mapping with the keys ``'x'``, ``'y'`` and ``'label'``;
    the model is called as ``model(batch['x'], batch['y'])``. Labels with more
    than one dimension are reduced with ``argmax(dim=1)`` before being
    compared with the predicted class; 1-D labels are compared directly.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Callable network exposing ``eval()``.
        dataloader: Iterable of batches; must support ``len()``.
        criterion: Callable ``(outputs, labels) -> scalar tensor``.

    Returns:
        ``(val_loss, val_acc)`` where ``val_loss`` is the mean of the
        per-batch losses and ``val_acc`` is the fraction of correctly
        classified samples.

    Raises:
        ValueError: If ``dataloader`` has no batches or yields no samples
            (previously this surfaced as a bare ``ZeroDivisionError``).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('validate_model received an empty dataloader')

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['label']
            outputs = model(batch['x'], batch['y'])
            running_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            true_labels = labels.argmax(dim=1) if labels.dim() > 1 else labels
            total += true_labels.size(0)
            correct += (predicted == true_labels).sum().item()

    if total == 0:
        raise ValueError('validate_model dataloader yielded no samples')

    val_loss = running_loss / num_batches
    val_acc = correct / total
    return val_loss, val_acc
def save_training_results(train_losses, val_losses, train_accs, val_accs, save_dir):
    """Save loss/accuracy curves and a metrics summary under ``save_dir``.

    Writes two files:
      * ``training_results.png`` - side-by-side loss and accuracy curves
      * ``metrics.txt``          - final-epoch loss/accuracy and the best
        validation accuracy

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accs: Per-epoch training accuracies.
        val_accs: Per-epoch validation accuracies.
        save_dir: Output directory (created if missing).

    If any of the four histories is empty nothing is written, because the
    summary needs at least one recorded epoch.
    """
    histories = (train_losses, val_losses, train_accs, val_accs)
    if any(len(series) == 0 for series in histories):
        print('No training history to save; skipping.')
        return

    os.makedirs(save_dir, exist_ok=True)

    # Put 'Times New Roman' first exactly once. Prepending it to the global
    # rcParams on every call would grow the list each time this function runs.
    serif_fonts = ['Times New Roman'] + [
        name for name in plt.rcParams['font.serif'] if name != 'Times New Roman'
    ]
    style = {
        'font.family': 'serif',
        'font.serif': serif_fonts,
        'mathtext.fontset': 'stix',   # STIX math glyphs match the Times style
        'axes.unicode_minus': False,  # render the minus sign correctly
        'font.size': 14,              # base font size
        'axes.labelsize': 16,         # axis label font size
        'axes.titlesize': 18,         # title font size
        'xtick.labelsize': 14,        # x tick label font size
        'ytick.labelsize': 14,        # y tick label font size
        'legend.fontsize': 14,        # legend font size
    }

    # rc_context scopes the style to this figure and restores the previous
    # global settings afterwards.
    with plt.rc_context(style):
        fig = plt.figure(figsize=(12, 5))
        try:
            # Loss curves
            plt.subplot(1, 2, 1)
            plt.plot(train_losses, label='Training Loss')
            plt.plot(val_losses, label='Validation Loss')
            plt.title('Training and Validation Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()

            # Accuracy curves
            plt.subplot(1, 2, 2)
            plt.plot(train_accs, label='Training Accuracy')
            plt.plot(val_accs, label='Validation Accuracy')
            plt.title('Training and Validation Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()

            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, 'training_results.png'))
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    # Plain-text summary of the final epoch and the best validation accuracy.
    with open(os.path.join(save_dir, 'metrics.txt'), 'w', encoding='utf-8') as f:
        f.write(f'Final Training Loss: {train_losses[-1]:.6f}\n')
        f.write(f'Final Training Accuracy: {train_accs[-1]:.6f}\n')
        f.write(f'Final Validation Loss: {val_losses[-1]:.6f}\n')
        f.write(f'Final Validation Accuracy: {val_accs[-1]:.6f}\n')
        f.write(f'Best Validation Accuracy: {max(val_accs):.6f}\n')
    print(f'Training results saved to {save_dir}')
# Script entry point
if __name__ == "__main__":
    # Seed both RNGs that this script touches.
    torch.manual_seed(42)
    np.random.seed(42)

    # Train / validation splits come from the same root directory.
    dataset_kwargs = dict(root_dir=opt.dataset_name, num_classes=opt.n_class)
    train_dataset = MultiModalDataset(mode='train', **dataset_kwargs)
    val_dataset = MultiModalDataset(mode='val', **dataset_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2)
    # Validation runs one sample per batch, in dataset order.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2)

    print(f"训练集样本数: {len(train_dataset)}")
    print(f"验证集样本数: {len(val_dataset)}")

    # The number of classes is the count of distinct labels in the training samples.
    distinct_labels = set()
    for idx in range(len(train_dataset)):
        distinct_labels.add(train_dataset.samples[idx]['label'])
    num_classes = len(distinct_labels)

    # Every model name other than "FusionModel" selects MultiModalNet.
    model_cls = FusionModel if opt.which_model_net == "FusionModel" else MultiModalNet
    model = model_cls(input_dim=opt.size, num_classes=num_classes)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型参数数量: {trainable_params}")

    trained_model = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=opt.n_epochs,
        lr=opt.lr,
        save_dir=opt.exp_path,
    )