Notes on PyTorch: saving model parameters, one-hot labels, reading weights, etc. - 20200415

This post covers saving and loading models in PyTorch, including how to save only a model's learned parameters and how to load a model and correctly switch it to evaluation mode before inference. It also provides an implementation of one-hot encoding and shows how to read the weights and biases of a model's intermediate layers.

save:

torch.save(model.state_dict(), PATH)

When saving a model for inference in PyTorch, only the trained model's learned parameters need to be saved. A common PyTorch convention is to save models with a .pt or .pth file extension.
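
To see what actually gets serialized, here is a minimal sketch (the two-layer model is made up purely for illustration) that prints a toy model's state_dict entries and then saves them:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # toy model
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# prints: 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)
torch.save(net.state_dict(), 'net.pth')  # .pt / .pth is the usual convention
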
load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Remember that you must call model.eval() to switch dropout and batch-normalization layers to evaluation mode before running inference. Skipping this step will produce inconsistent inference results.
Note:
Note that load_state_dict() accepts a dictionary object, not a path to a saved object. This means you must deserialize the saved state_dict before passing it to load_state_dict().
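
A minimal sketch of that two-step flow, continuing the snippet above (map_location is optional; it is shown here only to make CPU loading explicit):

state_dict = torch.load(PATH, map_location='cpu')  # step 1: deserialize the file into a dict of tensors
model.load_state_dict(state_dict)                  # step 2: copy the tensors into the model's parameters
model.eval()
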
Reference: PyTorch之保存加载模型

one-hot:
Old version:

import torch

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num  # random class indices, shape [batch_size, 1]
one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)

New version: use torch.nn.functional.one_hot

import torch
import torch.nn.functional as F

F.one_hot(torch.arange(0, 5), num_classes=5)  # returns a [5, 5] int64 tensor (an identity matrix here)
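
A small usage sketch for a batch of labels (the values are made up); note that F.one_hot returns an int64 tensor, so a cast to float is usually needed before mixing it with model outputs:

import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 9, 4])                  # class indices, shape [4]
one_hot = F.one_hot(labels, num_classes=10).float()  # shape [4, 10]
print(one_hot)
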

Reference: PyTorch——Tensor_把索引标签转换成one-hot标签表示 (including the replies)

Reading an intermediate layer's weight and bias:

params = {}  # maps parameter name -> numpy array
for name, parameters in G.named_parameters():  # G is your trained model (an nn.Module)
    print(name, ':', parameters.size())
    params[name] = parameters.detach().cpu().numpy()
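
Besides iterating over named_parameters(), a single layer's tensors can also be read directly through attribute access; a minimal sketch with a toy model (the layer name fc is hypothetical):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)  # hypothetical layer

    def forward(self, x):
        return self.fc(x)

net = Net()
w = net.fc.weight.detach().cpu().numpy()  # shape (4, 8)
b = net.fc.bias.detach().cpu().numpy()    # shape (4,)
print(w.shape, b.shape)
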

Reference: 获取Pytorch中间某一层权重或者特征
