这是我之前用于训练网络的代码,请帮我修改:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One channel-classification head per signal type, 13 classes each.
model = MultiTaskSpectralCNN(
    num_wifi_channels=13,
    num_bt_channels=13,
    num_zigbee_channels=13,
    num_cw_channels=13,
)
model = model.to(device)

# Loss functions and optimizer.
# Presence is a multi-label problem, so it is scored on raw logits with BCE.
criterion_pres = nn.BCEWithLogitsLoss()
# Channel heads are single-label classifiers; targets equal to -1 are ignored.
criterion_ch = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#%% md
# 9. 训练
#%%
def compute_masked_loss(predictions, targets, criterion):
    """Apply ``criterion`` only to the entries whose target is not -1.

    Args:
        predictions: Tensor of model outputs, indexed along its first
            dimension by the same boolean mask as ``targets``.
        targets: Tensor of labels in which -1 marks an entry to skip.
        criterion: Callable invoked as ``criterion(predictions, targets)``
            on the retained entries.

    Returns:
        The value returned by ``criterion`` for the retained entries, or a
        scalar zero tensor (on ``predictions.device``, with
        ``requires_grad=True``) when every target is -1.
    """
    keep = targets != -1
    if keep.any():
        return criterion(predictions[keep], targets[keep])
    # Nothing to score in this batch: hand back a zero that can still take
    # part in a summed loss and in ``backward()``.
    return torch.tensor(0.0, device=predictions.device, requires_grad=True)
#%%
def _run_multitask_epoch(model, loader, criterion_pres, criterion_ch, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(avg_loss, presence_acc, channel_accs)``.

    When ``optimizer`` is given, every batch is followed by a backward pass
    and an optimizer step; otherwise the pass only measures loss and
    accuracy.  The caller is responsible for ``model.train()`` /
    ``model.eval()`` and for wrapping evaluation in ``torch.no_grad()``.

    ``channel_accs`` lists the accuracies in the order WiFi, Bluetooth,
    ZigBee, CW.  Every ratio falls back to 0.0 when its denominator is zero,
    so an empty loader yields zeros instead of raising ZeroDivisionError.
    """
    # (output key, presence column) for each channel head, in label order.
    heads = (
        ('wifi_channel', 0),
        ('bt_channel', 1),
        ('zigbee_channel', 2),
        ('cw_channel', 3),
    )

    loss_sum = 0.0
    num_batches = 0
    num_samples = 0
    pres_correct = 0
    head_correct = [0] * len(heads)
    head_total = [0] * len(heads)

    for inputs, pres_labels, wifi_labels, bt_labels, zb_labels, cw_labels in loader:
        inputs = inputs.to(device)
        pres_labels = pres_labels.to(device)
        channel_labels = [
            labels.to(device)
            for labels in (wifi_labels, bt_labels, zb_labels, cw_labels)
        ]

        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(inputs)

        # Total loss = presence loss + one masked channel loss per head,
        # accumulated in the fixed order WiFi, Bluetooth, ZigBee, CW.
        batch_loss = criterion_pres(outputs['presence'], pres_labels)
        for (key, _), labels in zip(heads, channel_labels):
            batch_loss = batch_loss + compute_masked_loss(outputs[key], labels, criterion_ch)

        if optimizer is not None:
            batch_loss.backward()
            optimizer.step()

        loss_sum += batch_loss.item()
        num_batches += 1

        # Presence accuracy: a sample counts as correct only when every
        # presence flag (all columns of the row) is predicted correctly.
        pred_pres = (torch.sigmoid(outputs['presence']) > 0.5).int()
        pres_correct += (pred_pres == pres_labels).all(dim=1).sum().item()

        # Channel accuracy: scored only on samples whose presence flag for
        # that signal type is 1.
        for idx, ((key, column), labels) in enumerate(zip(heads, channel_labels)):
            present = pres_labels[:, column] == 1
            if present.any():
                _, pred_channel = outputs[key][present].max(1)
                head_correct[idx] += (pred_channel == labels[present]).sum().item()
                head_total[idx] += present.sum().item()

        num_samples += inputs.size(0)

    avg_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    presence_acc = pres_correct / num_samples if num_samples > 0 else 0.0
    channel_accs = [
        correct / total if total > 0 else 0.0
        for correct, total in zip(head_correct, head_total)
    ]
    return avg_loss, presence_acc, channel_accs


def train_model(model, train_loader, test_loader, criterion_pres, criterion_ch, optimizer, epochs=50):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Both loaders must yield 6-tuples
    ``(inputs, pres_labels, wifi_labels, bt_labels, zb_labels, cw_labels)``,
    and ``model(inputs)`` must return a mapping with the keys ``'presence'``,
    ``'wifi_channel'``, ``'bt_channel'``, ``'zigbee_channel'`` and
    ``'cw_channel'``.  Columns 0-3 of ``pres_labels`` are read as the
    presence flags for WiFi, Bluetooth, ZigBee and CW respectively.

    Args:
        model: Network to optimise; batches are moved to the device of its
            first parameter.
        train_loader: Iterable of training batches.
        test_loader: Iterable of evaluation batches.
        criterion_pres: Loss applied to the ``'presence'`` output.
        criterion_ch: Loss applied to each channel output through
            ``compute_masked_loss`` (targets equal to -1 are skipped).
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of epochs to run.

    Returns:
        A 12-tuple of per-epoch lists, in this order: train/test loss,
        train/test presence accuracy, train/test WiFi accuracy, train/test
        Bluetooth accuracy, train/test ZigBee accuracy, train/test CW
        accuracy.  Losses are averaged per batch.
    """
    train_losses, test_losses = [], []
    train_pres_accs, test_pres_accs = [], []
    train_wifi_accs, test_wifi_accs = [], []
    train_bt_accs, test_bt_accs = [], []
    train_zb_accs, test_zb_accs = [], []
    train_cw_accs, test_cw_accs = [], []

    # Use whatever device the model already lives on.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        # Training pass.
        model.train()
        avg_train_loss, train_pres_acc, train_ch_accs = _run_multitask_epoch(
            model, train_loader, criterion_pres, criterion_ch, device, optimizer
        )
        train_wifi_acc, train_bt_acc, train_zb_acc, train_cw_acc = train_ch_accs
        train_losses.append(avg_train_loss)
        train_pres_accs.append(train_pres_acc)
        train_wifi_accs.append(train_wifi_acc)
        train_bt_accs.append(train_bt_acc)
        train_zb_accs.append(train_zb_acc)
        train_cw_accs.append(train_cw_acc)

        # Evaluation pass: no optimizer, no gradient tracking.
        model.eval()
        with torch.no_grad():
            avg_test_loss, test_pres_acc, test_ch_accs = _run_multitask_epoch(
                model, test_loader, criterion_pres, criterion_ch, device
            )
        test_wifi_acc, test_bt_acc, test_zb_acc, test_cw_acc = test_ch_accs
        test_losses.append(avg_test_loss)
        test_pres_accs.append(test_pres_acc)
        test_wifi_accs.append(test_wifi_acc)
        test_bt_accs.append(test_bt_acc)
        test_zb_accs.append(test_zb_acc)
        test_cw_accs.append(test_cw_acc)

        # Log once per epoch.
        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f} "
              f"Pres Acc: {train_pres_accs[-1]:.4f}, Test Pres Acc: {test_pres_accs[-1]:.4f} "
              f"Wifi Ch Acc: {train_wifi_accs[-1]:.4f}, BT Ch Acc: {train_bt_accs[-1]:.4f}, ZB Ch Acc: {train_zb_accs[-1]:.4f}, CW Ch Acc: {train_cw_accs[-1]:.4f}")

    return (
        train_losses, test_losses,
        train_pres_accs, test_pres_accs,
        train_wifi_accs, test_wifi_accs,
        train_bt_accs, test_bt_accs,
        train_zb_accs, test_zb_accs,
        train_cw_accs, test_cw_accs,
    )