cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

在使用 cross_entropy_loss() 时遇到 "argument 'input' (position 1) must be Tensor, not NoneType" 错误。该错误通常由网络模型的返回值异常引起:可能是返回了 None、返回了多个变量,或者返回顺序错误。请检查并确保模型 return 语句的正确性。" 74370187,5686716,GitLab实践:Comment管理与重要性,"['git', 'gitlab', '版本管理', '注释', '代码维护']

cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

遇到这个错误的时候,需要检查自己的网络模型。

重点检查模型中 return 语句返回的内容是否存在异常:

1. 是否 return 了多个变量

2. 是否没有返回任何值(此时函数隐式返回 None)

3. 返回多个变量时,顺序是否正确

目录

cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

踩坑记录 | cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

1. 报错信息长什么样

2. 根本原因

3. 最小复现 & 现场调试

4. checklist:4 步定位法

5. 正确写法模板

✅ 单输出

✅ 多输出(tuple)

✅ 多输出(dict)

✅ 条件分支

6. 真实踩坑案例(2025 年还在发生)

7. 一句话总结


踩坑记录 | cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

"""Train script for FreqFormerV6.

Usage example:
    python train_freqformer_v6.py --data_dir <...> --batch_size 4 --num_epochs 150
"""
import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn  # required by FocalLoss(nn.Module); was missing
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Import your V6 model (adjust path if needed).
from freqformer_v6 import FreqFormerV6


# -----------------------
# Dice loss (multi-class) and helpers
# -----------------------
def one_hot(labels, num_classes):
    """Return a one-hot matrix [N, C] for integer class ids ``labels`` [N]."""
    return torch.eye(num_classes, device=labels.device)[labels]


def multiclass_dice_loss(probs, labels, eps=1e-6):
    """Return the mean per-class soft Dice loss.

    Args:
        probs: class probabilities, indexed as [num_points, num_classes].
        labels: integer class ids per point; negative ids are ignored.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor; 0 when every label is negative.
    """
    num_classes = probs.shape[1]
    mask = labels >= 0
    if mask.sum() == 0:
        return probs.new_tensor(0.)
    probs = probs[mask]
    labels = labels[mask]
    gt = one_hot(labels, num_classes)
    # Per-class Dice over all valid points.
    intersection = (probs * gt).sum(dim=0)
    cardinality = probs.sum(dim=0) + gt.sum(dim=0)
    dice = (2. * intersection + eps) / (cardinality + eps)
    return (1.0 - dice).mean()


# -----------------------
# Focal Loss (addresses class imbalance)
# -----------------------
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * -log(p_t)``.

    Args:
        alpha: optional per-class weight tensor of length num_classes.
        gamma: focusing exponent.
        reduction: 'mean', 'sum', or anything else for the per-element loss.
        ignore_index: target value that contributes no loss. Defaults to -1
            to match the ``ignore_index=-1`` used elsewhere in this script.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the focal loss for ``inputs`` [M, C] logits and ``targets`` [M]."""
        # Unweighted CE equals -log(p_t), so exp(-ce) is the true p_t.
        # (Passing ``weight`` here would turn it into p_t ** alpha_t.)
        # Ignored targets yield exactly 0.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        valid = targets != self.ignore_index
        if self.alpha is not None:
            alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            # Replace ignored ids with 0 so the lookup stays in range; their
            # loss is already 0, so the chosen weight does not matter.
            safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
            focal_loss = alpha[safe_targets] * focal_loss

        if self.reduction == 'mean':
            # Average over valid targets only (equals .mean() when none are ignored).
            return focal_loss.sum() / valid.sum().clamp(min=1)
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


# -----------------------
# Lovasz (same as your v5 version) - copied minimal
# -----------------------
def lovasz_grad(gt_sorted):
    """Return the gradient of the Lovasz extension for sorted ground truth."""
    gts = gt_sorted.sum()
    if gts == 0:
        return torch.zeros_like(gt_sorted)
    intersection = gts - gt_sorted.cumsum(0)
    union = gts + (1 - gt_sorted).cumsum(0)
    jaccard = 1. - intersection / union
    if gt_sorted.numel() > 1:
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def flatten_probas(probas, labels, ignore_index=-1):
    """Drop the entries whose label equals ``ignore_index``."""
    mask = labels != ignore_index
    if not mask.any():
        return probas.new_empty(0), labels.new_empty(0)
    return probas[mask], labels[mask]


def lovasz_softmax(probas, labels, classes='present', ignore_index=-1):
    """Return the Lovasz-Softmax loss averaged over classes.

    With ``classes='present'`` only classes that occur in ``labels`` are
    averaged. Returns 0 when no valid label or no class remains.
    """
    num_classes = probas.size(1)
    losses = []
    probas, labels = flatten_probas(probas, labels, ignore_index)
    if probas.numel() == 0:
        return probas.new_tensor(0.)
    for c in range(num_classes):
        fg = (labels == c).float()
        if classes == 'present' and fg.sum() == 0:
            continue
        errors = (fg - probas[:, c]).abs()
        perm = torch.argsort(errors, descending=True)
        grad = lovasz_grad(fg[perm])
        losses.append(torch.dot(F.relu(errors[perm]), grad))
    if not losses:
        return probas.new_tensor(0.)
    return sum(losses) / len(losses)


def _combined_loss(ce, logits_flat, labels_flat, use_lovasz):
    """Return ``ce + 0.6 * Dice + 0.3 * Lovasz`` (weights chosen experimentally).

    Softmax probabilities and the Dice term are computed here so the training
    and validation loops cannot reference them before assignment.
    """
    probs = F.softmax(logits_flat, dim=-1)
    dice = multiclass_dice_loss(probs, labels_flat)
    if use_lovasz:
        lov = lovasz_softmax(probs, labels_flat, ignore_index=-1)
    else:
        lov = logits_flat.new_tensor(0.0)
    return ce + 0.6 * dice + 0.3 * lov


def _model_state_dict(model):
    """Return the state dict of ``model``, unwrapping DataParallel if present."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()


# -----------------------
# Dataset (S3DIS npy layout assumed)
# -----------------------
class S3DISDatasetAug(Dataset):
    """Point-cloud dataset over ``.npy`` files with optional augmentation.

    Each file is read as a 2-D array: columns 0-2 are coordinates, columns
    3-5 extra per-point features, column 6 the integer label. Files whose
    name contains ``val_area`` form the 'val' split; all others form 'train'.
    """

    def __init__(self, data_dir, split='train', val_area='Area_5',
                 num_points=1024, augment=True):
        self.num_points = num_points
        # Augmentation is only ever applied to the training split.
        self.augment = augment and (split == 'train')
        self.files = []
        for f in sorted(os.listdir(data_dir)):
            if not f.endswith('.npy'):
                continue
            if split == 'train' and val_area in f:
                continue
            if split == 'val' and val_area not in f:
                continue
            self.files.append(os.path.join(data_dir, f))
        if not self.files:
            raise RuntimeError(f"No files found in {data_dir} (split={split})")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        data = np.load(self.files[idx])
        coords = data[:, :3].astype(np.float32)
        extra = data[:, 3:6].astype(np.float32)
        labels = data[:, 6].astype(np.int64)

        # Sample exactly num_points points; repeat points only when too few.
        n = coords.shape[0]
        choice = np.random.choice(n, self.num_points, replace=n < self.num_points)
        coords = coords[choice]
        extra = extra[choice]
        labels = labels[choice]

        if self.augment:
            # Random rotation about the z axis, random scale, Gaussian jitter.
            theta = np.random.uniform(0, 2 * np.pi)
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
            coords = coords.dot(R.T)
            scale = np.random.uniform(0.9, 1.1)
            coords = coords * scale
            coords = coords + np.random.normal(0, 0.01, coords.shape).astype(np.float32)

        local_feat = np.concatenate([coords, extra], axis=1)
        return {
            'local_feat': torch.from_numpy(local_feat).float(),
            'coords': torch.from_numpy(coords).float(),
            'extra': torch.from_numpy(extra).float(),
            'label': torch.from_numpy(labels).long()
        }


# -----------------------
# helpers: confusion & iou
# -----------------------
def compute_confusion_matrix(preds, gts, num_classes):
    """Return a [C, C] confusion matrix (rows: ground truth, cols: prediction)."""
    mask = (gts >= 0) & (gts < num_classes)
    gt = gts[mask].astype(np.int64)
    pred = preds[mask].astype(np.int64)
    conf = np.bincount(gt * num_classes + pred, minlength=num_classes**2)
    return conf.reshape((num_classes, num_classes))


def compute_iou_from_conf(conf):
    """Return the per-class IoU computed from a confusion matrix."""
    inter = np.diag(conf)
    gt_sum = conf.sum(axis=1)
    pred_sum = conf.sum(axis=0)
    union = gt_sum + pred_sum - inter
    return inter / (union + 1e-10)


# -----------------------
# compute class weights
# -----------------------
def compute_class_weights(file_list, num_classes, method='inv_sqrt'):
    """Return per-class weights from label frequencies in ``file_list``.

    ``method`` is 'inv_freq' (1/count), 'inv_sqrt' (1/sqrt(count)) or anything
    else for uniform weights. Weights are rescaled to sum to ``num_classes``.
    """
    counts = np.zeros(num_classes, dtype=np.float64)
    for p in file_list:
        data = np.load(p, mmap_mode='r')
        labels = data[:, 6].astype(np.int64)
        # One bincount pass instead of one full scan per class.
        valid = (labels >= 0) & (labels < num_classes)
        counts += np.bincount(labels[valid], minlength=num_classes)
    counts = np.maximum(counts, 1.0)
    if method == 'inv_freq':
        weights = 1.0 / counts
    elif method == 'inv_sqrt':
        weights = 1.0 / np.sqrt(counts)
    else:
        weights = np.ones_like(counts)
    weights = weights / weights.sum() * num_classes
    return torch.from_numpy(weights.astype(np.float32))


# -----------------------
# main
# -----------------------
def main():
    """Parse CLI arguments, train FreqFormerV6 and save checkpoints."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='/root/autodl-tmp/pointcloud_seg/data/S3DIS_new/processed_npy')
    parser.add_argument('--save_dir', default='./checkpoints_v6')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_epochs', type=int, default=300)
    parser.add_argument('--num_points', type=int, default=1024)
    parser.add_argument('--num_classes', type=int, default=13)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--use_class_weights', action='store_true')
    parser.add_argument('--use_lovasz', action='store_true')
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--use_focal', action='store_true', help='Use Focal Loss instead of CrossEntropy')
    parser.add_argument('--focal_gamma', type=float, default=2.0, help='Gamma parameter for Focal Loss')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    device = torch.device(args.device)

    train_ds = S3DISDatasetAug(args.data_dir, split='train', num_points=args.num_points, augment=True)
    val_ds = S3DISDatasetAug(args.data_dir, split='val', num_points=args.num_points, augment=False)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            num_workers=max(1, args.num_workers // 2))

    class_weights = None
    if args.use_class_weights:
        print("Computing class weights...")
        class_weights = compute_class_weights(train_ds.files, args.num_classes, method='inv_sqrt').to(device)
        print("class weights:", class_weights.cpu().numpy())

    model = FreqFormerV6(num_classes=args.num_classes)
    if torch.cuda.device_count() > 1 and args.device.startswith('cuda'):
        print("Using DataParallel on devices:", list(range(torch.cuda.device_count())))
        model = torch.nn.DataParallel(model)
    model = model.to(device)

    focal_criterion = None
    if args.use_focal:
        print(f"Using Focal Loss with gamma={args.focal_gamma}")
        # Prefer the computed class weights as alpha. Clone them so the
        # in-place boost below does not also change ``class_weights``, which
        # the validation cross-entropy still uses.
        if class_weights is not None:
            alpha = class_weights.clone()
        else:
            alpha = torch.ones(args.num_classes).to(device)
        # Boost rare classes (example ids for S3DIS; adjust to your dataset).
        rare_classes = [3, 4, 6, 9, 11]
        for c in rare_classes:
            if c < len(alpha):
                alpha[c] *= 3.0
        focal_criterion = FocalLoss(alpha=alpha, gamma=args.focal_gamma).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Cosine decay over the epochs that follow the linear warmup.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, args.num_epochs - args.warmup_epochs))

    best_miou = 0.0
    start_epoch = 0

    print("Training config:", vars(args))
    print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

    # simple warmup schedule helper
    def get_lr_factor(epoch):
        if epoch < args.warmup_epochs:
            return float(epoch + 1) / max(1.0, args.warmup_epochs)
        return 1.0

    for epoch in range(start_epoch, args.num_epochs):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        iters = 0

        # Set the LR by hand only during warmup. Overriding it on every batch
        # of every epoch, as before, cancelled the cosine decay.
        in_warmup = epoch < args.warmup_epochs
        if in_warmup:
            lr_mult = get_lr_factor(epoch)
            for g in optimizer.param_groups:
                g['lr'] = args.lr * lr_mult
        epoch_lr = optimizer.param_groups[0]['lr']

        for batch in train_loader:
            coords = batch['coords'].to(device)   # [B,N,3]
            extra = batch['extra'].to(device)     # [B,N,3]
            labels = batch['label'].to(device)    # [B,N]

            optimizer.zero_grad()
            # model expects (coords, feats)
            logits = model(coords, extra)         # [B,N,C]
            logits_flat = logits.reshape(-1, logits.shape[-1])
            labels_flat = labels.reshape(-1)

            if focal_criterion is not None:
                ce = focal_criterion(logits_flat, labels_flat)
            else:
                # ``weight=None`` means unweighted cross-entropy.
                ce = F.cross_entropy(logits_flat, labels_flat,
                                     weight=class_weights, ignore_index=-1)

            loss = _combined_loss(ce, logits_flat, labels_flat, args.use_lovasz)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

            running_loss += loss.item()
            iters += 1

        # Cosine schedule advances once per epoch after warmup. Errors are no
        # longer swallowed, so a scheduler misconfiguration is visible.
        if not in_warmup:
            scheduler.step()

        avg_loss = running_loss / max(1, iters)
        t1 = time.time()
        print(f"Epoch {epoch+1}/{args.num_epochs} TrainLoss: {avg_loss:.4f} Time: {(t1-t0):.1f}s LR: {epoch_lr:.6f}")

        # validation every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == args.num_epochs:
            model.eval()
            conf = np.zeros((args.num_classes, args.num_classes), dtype=np.int64)
            tot_loss = 0.0
            cnt = 0
            with torch.no_grad():
                for batch in val_loader:
                    coords = batch['coords'].to(device)
                    extra = batch['extra'].to(device)
                    labels = batch['label'].to(device)

                    logits = model(coords, extra)
                    logits_flat = logits.reshape(-1, logits.shape[-1])
                    labels_flat = labels.reshape(-1)

                    loss_ce = F.cross_entropy(logits_flat, labels_flat,
                                              weight=class_weights, ignore_index=-1)
                    loss = _combined_loss(loss_ce, logits_flat, labels_flat, args.use_lovasz)
                    tot_loss += loss.item()

                    preds = logits.argmax(dim=-1).cpu().numpy().reshape(-1)
                    gts = labels.cpu().numpy().reshape(-1)
                    conf += compute_confusion_matrix(preds, gts, args.num_classes)
                    cnt += 1

            mean_loss = tot_loss / max(1, cnt)
            iou = compute_iou_from_conf(conf)
            miou = np.nanmean(iou)
            oa = np.diag(conf).sum() / (conf.sum() + 1e-12)
            print(f"-- Validation Loss: {mean_loss:.4f} mIoU: {miou:.4f} OA: {oa:.4f}")
            print("Per-class IoU:")
            for cid, v in enumerate(iou):
                print(f" class {cid:02d}: {v:.4f}")

            if miou > best_miou:
                best_miou = miou
                path = os.path.join(args.save_dir, f'best_epoch_{epoch+1:03d}_miou_{miou:.4f}.pth')
                state = {'epoch': epoch+1, 'best_miou': best_miou}
                state['model_state_dict'] = _model_state_dict(model)
                torch.save(state, path)
                print("Saved best:", path)

    final_path = os.path.join(args.save_dir, f'final_epoch_{args.num_epochs:03d}_miou_{best_miou:.4f}.pth')
    state = {'epoch': args.num_epochs, 'best_miou': best_miou}
    state['model_state_dict'] = _model_state_dict(model)
    torch.save(state, final_path)
    print("Training finished. Final saved to:", final_path)


if __name__ == "__main__":
    main()

# Reader's question that followed this paste on the page:
# "现在这样可以了吗?" (Is it OK like this now?)
10-23
# Reader's text preceding this paste on the page: "那你" (Then you ...)
"""Train script for FreqFormerV6.

Usage example:
    python train_freqformer_v6.py --data_dir <...> --batch_size 4 --num_epochs 150
"""
import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn  # required by FocalLoss(nn.Module); was missing
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Import your V6 model (adjust path if needed).
from freqformer_v6 import FreqFormerV6


# -----------------------
# Dice loss (multi-class) and helpers
# -----------------------
def one_hot(labels, num_classes):
    """Return a one-hot matrix [N, C] for integer class ids ``labels`` [N]."""
    return torch.eye(num_classes, device=labels.device)[labels]


def multiclass_dice_loss(probs, labels, eps=1e-6):
    """Return the mean per-class soft Dice loss.

    Args:
        probs: class probabilities, indexed as [num_points, num_classes].
        labels: integer class ids per point; negative ids are ignored.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor; 0 when every label is negative.
    """
    num_classes = probs.shape[1]
    mask = labels >= 0
    if mask.sum() == 0:
        return probs.new_tensor(0.)
    probs = probs[mask]
    labels = labels[mask]
    gt = one_hot(labels, num_classes)
    # Per-class Dice over all valid points.
    intersection = (probs * gt).sum(dim=0)
    cardinality = probs.sum(dim=0) + gt.sum(dim=0)
    dice = (2. * intersection + eps) / (cardinality + eps)
    return (1.0 - dice).mean()


# -----------------------
# Focal Loss (addresses class imbalance)
# -----------------------
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * -log(p_t)``.

    Args:
        alpha: optional per-class weight tensor of length num_classes.
        gamma: focusing exponent.
        reduction: 'mean', 'sum', or anything else for the per-element loss.
        ignore_index: target value that contributes no loss. Defaults to -1
            to match the ``ignore_index=-1`` used elsewhere in this script.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the focal loss for ``inputs`` [M, C] logits and ``targets`` [M]."""
        # Unweighted CE equals -log(p_t), so exp(-ce) is the true p_t.
        # (Passing ``weight`` here would turn it into p_t ** alpha_t.)
        # Ignored targets yield exactly 0.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        valid = targets != self.ignore_index
        if self.alpha is not None:
            alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            # Replace ignored ids with 0 so the lookup stays in range; their
            # loss is already 0, so the chosen weight does not matter.
            safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
            focal_loss = alpha[safe_targets] * focal_loss

        if self.reduction == 'mean':
            # Average over valid targets only (equals .mean() when none are ignored).
            return focal_loss.sum() / valid.sum().clamp(min=1)
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


# -----------------------
# Lovasz (same as your v5 version) - copied minimal
# -----------------------
def lovasz_grad(gt_sorted):
    """Return the gradient of the Lovasz extension for sorted ground truth."""
    gts = gt_sorted.sum()
    if gts == 0:
        return torch.zeros_like(gt_sorted)
    intersection = gts - gt_sorted.cumsum(0)
    union = gts + (1 - gt_sorted).cumsum(0)
    jaccard = 1. - intersection / union
    if gt_sorted.numel() > 1:
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def flatten_probas(probas, labels, ignore_index=-1):
    """Drop the entries whose label equals ``ignore_index``."""
    mask = labels != ignore_index
    if not mask.any():
        return probas.new_empty(0), labels.new_empty(0)
    return probas[mask], labels[mask]


def lovasz_softmax(probas, labels, classes='present', ignore_index=-1):
    """Return the Lovasz-Softmax loss averaged over classes.

    With ``classes='present'`` only classes that occur in ``labels`` are
    averaged. Returns 0 when no valid label or no class remains.
    """
    num_classes = probas.size(1)
    losses = []
    probas, labels = flatten_probas(probas, labels, ignore_index)
    if probas.numel() == 0:
        return probas.new_tensor(0.)
    for c in range(num_classes):
        fg = (labels == c).float()
        if classes == 'present' and fg.sum() == 0:
            continue
        errors = (fg - probas[:, c]).abs()
        perm = torch.argsort(errors, descending=True)
        grad = lovasz_grad(fg[perm])
        losses.append(torch.dot(F.relu(errors[perm]), grad))
    if not losses:
        return probas.new_tensor(0.)
    return sum(losses) / len(losses)


def _combined_loss(ce, logits_flat, labels_flat, use_lovasz):
    """Return ``ce + 0.6 * Dice + 0.3 * Lovasz`` (weights chosen experimentally).

    The softmax probabilities are needed by both the Dice and Lovasz terms,
    so they are computed once here for the training and validation loops.
    """
    probs = F.softmax(logits_flat, dim=-1)
    dice = multiclass_dice_loss(probs, labels_flat)
    if use_lovasz:
        lov = lovasz_softmax(probs, labels_flat, ignore_index=-1)
    else:
        lov = logits_flat.new_tensor(0.0)
    return ce + 0.6 * dice + 0.3 * lov


def _model_state_dict(model):
    """Return the state dict of ``model``, unwrapping DataParallel if present."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()


# -----------------------
# Dataset (S3DIS npy layout assumed)
# -----------------------
class S3DISDatasetAug(Dataset):
    """Point-cloud dataset over ``.npy`` files with optional augmentation.

    Each file is read as a 2-D array: columns 0-2 are coordinates, columns
    3-5 extra per-point features, column 6 the integer label. Files whose
    name contains ``val_area`` form the 'val' split; all others form 'train'.
    """

    def __init__(self, data_dir, split='train', val_area='Area_5',
                 num_points=1024, augment=True):
        self.num_points = num_points
        # Augmentation is only ever applied to the training split.
        self.augment = augment and (split == 'train')
        self.files = []
        for f in sorted(os.listdir(data_dir)):
            if not f.endswith('.npy'):
                continue
            if split == 'train' and val_area in f:
                continue
            if split == 'val' and val_area not in f:
                continue
            self.files.append(os.path.join(data_dir, f))
        if not self.files:
            raise RuntimeError(f"No files found in {data_dir} (split={split})")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        data = np.load(self.files[idx])
        coords = data[:, :3].astype(np.float32)
        extra = data[:, 3:6].astype(np.float32)
        labels = data[:, 6].astype(np.int64)

        # Sample exactly num_points points; repeat points only when too few.
        n = coords.shape[0]
        choice = np.random.choice(n, self.num_points, replace=n < self.num_points)
        coords = coords[choice]
        extra = extra[choice]
        labels = labels[choice]

        if self.augment:
            # Random rotation about the z axis, random scale, Gaussian jitter.
            theta = np.random.uniform(0, 2 * np.pi)
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
            coords = coords.dot(R.T)
            scale = np.random.uniform(0.9, 1.1)
            coords = coords * scale
            coords = coords + np.random.normal(0, 0.01, coords.shape).astype(np.float32)

        local_feat = np.concatenate([coords, extra], axis=1)
        return {
            'local_feat': torch.from_numpy(local_feat).float(),
            'coords': torch.from_numpy(coords).float(),
            'extra': torch.from_numpy(extra).float(),
            'label': torch.from_numpy(labels).long()
        }


# -----------------------
# helpers: confusion & iou
# -----------------------
def compute_confusion_matrix(preds, gts, num_classes):
    """Return a [C, C] confusion matrix (rows: ground truth, cols: prediction)."""
    mask = (gts >= 0) & (gts < num_classes)
    gt = gts[mask].astype(np.int64)
    pred = preds[mask].astype(np.int64)
    conf = np.bincount(gt * num_classes + pred, minlength=num_classes**2)
    return conf.reshape((num_classes, num_classes))


def compute_iou_from_conf(conf):
    """Return the per-class IoU computed from a confusion matrix."""
    inter = np.diag(conf)
    gt_sum = conf.sum(axis=1)
    pred_sum = conf.sum(axis=0)
    union = gt_sum + pred_sum - inter
    return inter / (union + 1e-10)


# -----------------------
# compute class weights
# -----------------------
def compute_class_weights(file_list, num_classes, method='inv_sqrt'):
    """Return per-class weights from label frequencies in ``file_list``.

    ``method`` is 'inv_freq' (1/count), 'inv_sqrt' (1/sqrt(count)) or anything
    else for uniform weights. Weights are rescaled to sum to ``num_classes``.
    """
    counts = np.zeros(num_classes, dtype=np.float64)
    for p in file_list:
        data = np.load(p, mmap_mode='r')
        labels = data[:, 6].astype(np.int64)
        # One bincount pass instead of one full scan per class.
        valid = (labels >= 0) & (labels < num_classes)
        counts += np.bincount(labels[valid], minlength=num_classes)
    counts = np.maximum(counts, 1.0)
    if method == 'inv_freq':
        weights = 1.0 / counts
    elif method == 'inv_sqrt':
        weights = 1.0 / np.sqrt(counts)
    else:
        weights = np.ones_like(counts)
    weights = weights / weights.sum() * num_classes
    return torch.from_numpy(weights.astype(np.float32))


# -----------------------
# main
# -----------------------
def main():
    """Parse CLI arguments, train FreqFormerV6 and save checkpoints."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='/root/autodl-tmp/pointcloud_seg/data/S3DIS_new/processed_npy')
    parser.add_argument('--save_dir', default='./checkpoints_v6')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_epochs', type=int, default=300)
    parser.add_argument('--num_points', type=int, default=1024)
    parser.add_argument('--num_classes', type=int, default=13)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--use_class_weights', action='store_true')
    parser.add_argument('--use_lovasz', action='store_true')
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--use_focal', action='store_true', help='Use Focal Loss instead of CrossEntropy')
    parser.add_argument('--focal_gamma', type=float, default=2.0, help='Gamma parameter for Focal Loss')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    device = torch.device(args.device)

    train_ds = S3DISDatasetAug(args.data_dir, split='train', num_points=args.num_points, augment=True)
    val_ds = S3DISDatasetAug(args.data_dir, split='val', num_points=args.num_points, augment=False)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            num_workers=max(1, args.num_workers // 2))

    class_weights = None
    if args.use_class_weights:
        print("Computing class weights...")
        class_weights = compute_class_weights(train_ds.files, args.num_classes, method='inv_sqrt').to(device)
        print("class weights:", class_weights.cpu().numpy())

    model = FreqFormerV6(num_classes=args.num_classes)
    if torch.cuda.device_count() > 1 and args.device.startswith('cuda'):
        print("Using DataParallel on devices:", list(range(torch.cuda.device_count())))
        model = torch.nn.DataParallel(model)
    model = model.to(device)

    focal_criterion = None
    if args.use_focal:
        print(f"Using Focal Loss with gamma={args.focal_gamma}")
        # Prefer the computed class weights as alpha. Clone them so the
        # in-place boost below does not also change ``class_weights``, which
        # the validation cross-entropy still uses.
        if class_weights is not None:
            alpha = class_weights.clone()
        else:
            alpha = torch.ones(args.num_classes).to(device)
        # Boost rare classes (example ids for S3DIS; adjust to your dataset).
        rare_classes = [3, 4, 6, 9, 11]
        for c in rare_classes:
            if c < len(alpha):
                alpha[c] *= 3.0
        focal_criterion = FocalLoss(alpha=alpha, gamma=args.focal_gamma).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Cosine decay over the epochs that follow the linear warmup.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, args.num_epochs - args.warmup_epochs))

    best_miou = 0.0
    start_epoch = 0

    print("Training config:", vars(args))
    print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

    # simple warmup schedule helper
    def get_lr_factor(epoch):
        if epoch < args.warmup_epochs:
            return float(epoch + 1) / max(1.0, args.warmup_epochs)
        return 1.0

    for epoch in range(start_epoch, args.num_epochs):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        iters = 0

        # Set the LR by hand only during warmup. Overriding it on every batch
        # of every epoch, as before, cancelled the cosine decay.
        in_warmup = epoch < args.warmup_epochs
        if in_warmup:
            lr_mult = get_lr_factor(epoch)
            for g in optimizer.param_groups:
                g['lr'] = args.lr * lr_mult
        epoch_lr = optimizer.param_groups[0]['lr']

        for batch in train_loader:
            coords = batch['coords'].to(device)   # [B,N,3]
            extra = batch['extra'].to(device)     # [B,N,3]
            labels = batch['label'].to(device)    # [B,N]

            optimizer.zero_grad()
            # model expects (coords, feats)
            logits = model(coords, extra)         # [B,N,C]
            logits_flat = logits.reshape(-1, logits.shape[-1])
            labels_flat = labels.reshape(-1)

            if focal_criterion is not None:
                ce = focal_criterion(logits_flat, labels_flat)
            else:
                # ``weight=None`` means unweighted cross-entropy.
                ce = F.cross_entropy(logits_flat, labels_flat,
                                     weight=class_weights, ignore_index=-1)

            # Key fix kept from the paste: softmax probabilities and the Dice
            # term are computed for every training batch (inside the helper).
            loss = _combined_loss(ce, logits_flat, labels_flat, args.use_lovasz)

            loss.backward()
            # clip_grad_norm_ returns the total L2 gradient norm measured
            # before clipping, so no separate loop over parameters is needed.
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            print(f"Gradient norm: {float(total_norm):.4f}")
            optimizer.step()

            running_loss += loss.item()
            iters += 1

        # Cosine schedule advances once per epoch after warmup. Errors are no
        # longer swallowed, so a scheduler misconfiguration is visible.
        if not in_warmup:
            scheduler.step()

        avg_loss = running_loss / max(1, iters)
        t1 = time.time()
        print(f"Epoch {epoch+1}/{args.num_epochs} TrainLoss: {avg_loss:.4f} Time: {(t1-t0):.1f}s LR: {epoch_lr:.6f}")

        # validation every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == args.num_epochs:
            model.eval()
            conf = np.zeros((args.num_classes, args.num_classes), dtype=np.int64)
            tot_loss = 0.0
            cnt = 0
            with torch.no_grad():
                for batch in val_loader:
                    coords = batch['coords'].to(device)
                    extra = batch['extra'].to(device)
                    labels = batch['label'].to(device)

                    logits = model(coords, extra)
                    logits_flat = logits.reshape(-1, logits.shape[-1])
                    labels_flat = labels.reshape(-1)

                    loss_ce = F.cross_entropy(logits_flat, labels_flat,
                                              weight=class_weights, ignore_index=-1)
                    loss = _combined_loss(loss_ce, logits_flat, labels_flat, args.use_lovasz)
                    tot_loss += loss.item()

                    preds = logits.argmax(dim=-1).cpu().numpy().reshape(-1)
                    gts = labels.cpu().numpy().reshape(-1)
                    conf += compute_confusion_matrix(preds, gts, args.num_classes)
                    cnt += 1

            mean_loss = tot_loss / max(1, cnt)
            iou = compute_iou_from_conf(conf)
            miou = np.nanmean(iou)
            oa = np.diag(conf).sum() / (conf.sum() + 1e-12)
            print(f"-- Validation Loss: {mean_loss:.4f} mIoU: {miou:.4f} OA: {oa:.4f}")
            print("Per-class IoU:")
            for cid, v in enumerate(iou):
                print(f" class {cid:02d}: {v:.4f}")

            if miou > best_miou:
                best_miou = miou
                path = os.path.join(args.save_dir, f'best_epoch_{epoch+1:03d}_miou_{miou:.4f}.pth')
                state = {'epoch': epoch+1, 'best_miou': best_miou}
                state['model_state_dict'] = _model_state_dict(model)
                torch.save(state, path)
                print("Saved best:", path)

    final_path = os.path.join(args.save_dir, f'final_epoch_{args.num_epochs:03d}_miou_{best_miou:.4f}.pth')
    state = {'epoch': args.num_epochs, 'best_miou': best_miou}
    state['model_state_dict'] = _model_state_dict(model)
    torch.save(state, final_path)
    print("Training finished. Final saved to:", final_path)


if __name__ == "__main__":
    main()

# Reader's request that followed this paste on the page:
# "把这个给我生成一份完整的代码吧" (Please generate a complete version of this code for me.)
最新发布
10-23
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

交通上的硅基思维

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值