"""FreqFormerV7_full.py
Combined model (FreqFormerV7) and training script tuned for 2xRTX4090 (DDP, AMP, improved FFT, stronger transformer,
better fusion, balanced loss, scheduler restarts). Format follows your earlier scripts for easy swap-in.
Usage (single-node multi-gpu):
torchrun --nproc_per_node=2 FreqFormerV7_full.py --data_dir <...> --save_dir ./checkpoints_v7 --batch_size 8 --num_epochs 200
For single-GPU debugging (no torchrun; --nproc_per_node is a torchrun flag, not a script flag):
python FreqFormerV7_full.py --debug
"""
import os
import time
import argparse
import math
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
# -----------------------------
# Utilities: distributed helpers
# -----------------------------
def is_main_process():
return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
def setup_ddp(local_rank):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
else:
rank = 0
world_size = 1
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
return rank, world_size
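# Note: init_method='env://' reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from
# the environment; torchrun exports all of these (plus LOCAL_RANK), e.g. for rank 0 of 2:
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 RANK=0 WORLD_SIZE=2 LOCAL_RANK=0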
# -----------------------------
# Dice + Lovasz (helpers)
# -----------------------------
def one_hot(labels, num_classes):
y = torch.eye(num_classes, device=labels.device)[labels]
return y
def multiclass_dice_loss(probs, labels, eps=1e-6):
C = probs.shape[1]
mask = (labels >= 0)
if mask.sum() == 0:
return probs.new_tensor(0.)
probs = probs[mask]
labels = labels[mask]
gt = one_hot(labels, C)
intersection = (probs * gt).sum(dim=0)
cardinality = probs.sum(dim=0) + gt.sum(dim=0)
dice = (2. * intersection + eps) / (cardinality + eps)
loss = 1.0 - dice
return loss.mean()
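# Toy check of the dice term (2 classes, 3 points, hypothetical values):
#   probs  = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]], labels = [0, 1, 0]
#   dice_0 = (2*1.6)/(1.8+2) ~= 0.842, dice_1 = (2*0.8)/(1.2+1) ~= 0.727
#   loss   = mean(1 - dice) ~= 0.215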
def lovasz_grad(gt_sorted):
gts = gt_sorted.sum()
if gts == 0:
return torch.zeros_like(gt_sorted)
intersection = gts - gt_sorted.cumsum(0)
union = gts + (1 - gt_sorted).cumsum(0)
jaccard = 1. - intersection / union
if gt_sorted.numel() > 1:
jaccard[1:] = jaccard[1:] - jaccard[:-1]
return jaccard
def flatten_probas(probas, labels, ignore_index=-1):
    mask = (labels != ignore_index)
    if not mask.any():
        # keep the [P, C] / [P] contract even when everything is ignored
        return probas.new_zeros((0, probas.size(1))), labels.new_zeros((0,))
probas = probas[mask]
labels = labels[mask]
return probas, labels
def lovasz_softmax(probas, labels, classes='present', ignore_index=-1):
C = probas.size(1)
losses = []
probas, labels = flatten_probas(probas, labels, ignore_index)
if probas.numel() == 0:
return probas.new_tensor(0.)
for c in range(C):
fg = (labels == c).float()
if classes == 'present' and fg.sum() == 0:
continue
class_pred = probas[:, c]
errors = (fg - class_pred).abs()
perm = torch.argsort(errors, descending=True)
fg_sorted = fg[perm]
grad = lovasz_grad(fg_sorted)
loss_c = torch.dot(F.relu(errors[perm]), grad)
losses.append(loss_c)
if len(losses) == 0:
return probas.new_tensor(0.)
return sum(losses) / len(losses)
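# Usage sketch (shapes as used in the training loop below): probas is a [P, C]
# softmax output, labels is [P]; e.g.
#   lovasz_softmax(F.softmax(logits.view(-1, C), -1), labels.view(-1), ignore_index=-1)
# returns a scalar averaged over the classes present in the batch.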
# -----------------------------
# Dataset (S3DIS-like npy layout)
# -----------------------------
class S3DISDatasetAug(Dataset):
def __init__(self, data_dir, split='train', val_area='Area_5', num_points=2048, augment=True):
self.num_points = num_points
self.augment = augment and (split == 'train')
self.files = []
for f in sorted(os.listdir(data_dir)):
if not f.endswith('.npy'):
continue
if split == 'train' and val_area in f:
continue
if split == 'val' and val_area not in f:
continue
self.files.append(os.path.join(data_dir, f))
if len(self.files) == 0:
raise RuntimeError(f"No files found in {data_dir} (split={split})")
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
data = np.load(self.files[idx])
coords = data[:, :3].astype(np.float32)
# extra could be RGB or normal; ensure shape [N,3]
extra = data[:, 3:6].astype(np.float32)
labels = data[:, 6].astype(np.int64)
N = coords.shape[0]
if N >= self.num_points:
choice = np.random.choice(N, self.num_points, replace=False)
else:
choice = np.random.choice(N, self.num_points, replace=True)
coords = coords[choice]
extra = extra[choice]
labels = labels[choice]
if self.augment:
theta = np.random.uniform(0, 2 * np.pi)
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
coords = coords.dot(R.T)
scale = np.random.uniform(0.9, 1.1)
coords = coords * scale
coords = coords + np.random.normal(0, 0.02, coords.shape).astype(np.float32)
            # random flip along x
            if np.random.rand() > 0.5:
                coords[:, 0] = -coords[:, 0]
local_feat = np.concatenate([coords, extra], axis=1)
return {
'local_feat': torch.from_numpy(local_feat).float(),
'coords': torch.from_numpy(coords).float(),
'extra': torch.from_numpy(extra).float(),
'label': torch.from_numpy(labels).long()
}
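# Assumed on-disk layout per room file: float array of shape [N, 7] with columns
# x, y, z, r, g, b, label (columns 3:6 may equally hold normals; only column 6
# is read as the integer class id).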
# -----------------------------
# compute class weights
# -----------------------------
def compute_class_weights(file_list, num_classes, method='inv_sqrt'):
counts = np.zeros(num_classes, dtype=np.float64)
for p in file_list:
data = np.load(p, mmap_mode='r')
labels = data[:, 6].astype(np.int64)
        # single bincount pass instead of one full scan per class
        valid = labels[(labels >= 0) & (labels < num_classes)]
        counts += np.bincount(valid, minlength=num_classes)
counts = np.maximum(counts, 1.0)
if method == 'inv_freq':
weights = 1.0 / counts
elif method == 'inv_sqrt':
weights = 1.0 / np.sqrt(counts)
else:
weights = np.ones_like(counts)
weights = weights / weights.sum() * num_classes
return torch.from_numpy(weights.astype(np.float32))
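# Toy check, 3 classes with counts [100, 400, 900] and method='inv_sqrt':
#   raw = [0.100, 0.050, 0.033]; normalized to mean 1 -> ~[1.64, 0.82, 0.55],
# so the rarest class gets ~3x the weight of the most common one.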
# -----------------------------
# Model: FreqFormerV7
# -----------------------------
class FreqConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm1d(out_ch),
nn.GELU(),
nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm1d(out_ch),
nn.GELU(),
)
def forward(self, x):
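        # x: [B, N, C] -> transpose to [B, C, N] for Conv1d/BatchNorm1d, then back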
return self.net(x.transpose(1, 2)).transpose(1, 2)
class FreqChannelAttention(nn.Module):
def __init__(self, dim, reduction=8):
super().__init__()
self.fc1 = nn.Linear(dim, dim // reduction)
self.fc2 = nn.Linear(dim // reduction, dim)
self.act = nn.GELU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
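        # squeeze-and-excitation over points: mean over N gives a [B, C] descriptor,
        # squashed to a per-channel gate in (0, 1) that rescales every point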
attn = torch.mean(x, dim=1)
attn = self.fc2(self.act(self.fc1(attn)))
attn = self.sigmoid(attn).unsqueeze(1)
return x * attn
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(drop),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop)
)
def forward(self, x):
h = x
x = self.norm1(x)
x, _ = self.attn(x, x, x)
x = x + h
h = x
x = self.norm2(x)
x = x + self.mlp(x)
return x
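# Pre-norm residual blocks (LayerNorm before attention/MLP) are generally more
# stable at depth 8 than post-norm. The self-attention is dense over all N points,
# so compute grows O(N^2): fine at N=2048, expensive much beyond that.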
class CrossAttentionBlock(nn.Module):
"""Cross-attention between spatial and freq branches"""
def __init__(self, dim_q, dim_kv, num_heads, drop=0.1):
super().__init__()
self.norm_q = nn.LayerNorm(dim_q)
self.norm_kv = nn.LayerNorm(dim_kv)
self.attn = nn.MultiheadAttention(dim_q, num_heads, dropout=drop, batch_first=True)
# project kv to same dim as q
if dim_q != dim_kv:
self.kv_proj = nn.Linear(dim_kv, dim_q)
else:
self.kv_proj = nn.Identity()
self.ff = nn.Sequential(nn.LayerNorm(dim_q), nn.Linear(dim_q, dim_q * 4), nn.GELU(), nn.Linear(dim_q * 4, dim_q))
def forward(self, q, kv):
qn = self.norm_q(q)
kvn = self.kv_proj(self.norm_kv(kv))
attn_out, _ = self.attn(qn, kvn, kvn)
q = q + attn_out
q = q + self.ff(q)
return q
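# Shapes: q is [B, N, dim_q] (spatial tokens), kv is [B, N, dim_kv] (freq tokens);
# kv_proj maps kv to dim_q so MultiheadAttention sees matching embed dims.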
class FreqFormerV7(nn.Module):
def __init__(self, num_classes=13, embed_dim=384, freq_embed=192, depth=8, num_heads=8, drop=0.1, use_cross=True):
super().__init__()
self.embed_dim = embed_dim
# spatial: xyz + rgb/normals (6)
self.spatial_embed = nn.Linear(6, embed_dim)
# freq branch: uses xyz+feats (6) for FFT
        self.freq_proj = nn.Linear(12, freq_embed)  # real+imag concat doubles the channel dim: 6 -> 12
self.freq_conv = nn.Sequential(
FreqConvBlock(freq_embed, freq_embed),
FreqConvBlock(freq_embed, freq_embed)
)
self.fca = FreqChannelAttention(freq_embed)
# projection of freq -> embed_dim
self.freq_to_spatial = nn.Linear(freq_embed, embed_dim)
# fusion
self.fuse_proj = nn.Linear(embed_dim + embed_dim, embed_dim)
# transformer backbone
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads=num_heads, mlp_ratio=4.0, drop=drop)
for _ in range(depth)
])
# optional cross attention between spatial and freq (early)
self.use_cross = use_cross
if use_cross:
self.cross = CrossAttentionBlock(embed_dim, freq_embed, num_heads, drop=drop)
# classification head
self.cls_head = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Dropout(0.3),
nn.Linear(embed_dim, num_classes)
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
    def forward(self, coords, feats=None):
        # coords: [B,N,3]; feats: [B,N,3] (RGB or normals). spatial_embed expects 6
        # channels, so substitute zeros when no extra features are given.
        if feats is None:
            feats = torch.zeros_like(coords)
        x = torch.cat([coords, feats], dim=-1)  # [B,N,6]
B, N, _ = x.shape
# spatial branch
spatial_feat = self.spatial_embed(x) # [B,N,embed_dim]
# FFT branch: build input containing coords+feats then FFT along sequence dimension
fft_input = torch.cat([coords, feats], dim=-1) # [B,N,6]
        # FFT along the point dimension -> complex [B,N,6]. norm='ortho' scales by
        # 1/sqrt(N); without it magnitudes grow with N and can overflow fp16 under
        # autocast (a likely source of the NaN loss in the log at the bottom).
        fft_c = torch.fft.fft(fft_input, dim=1, norm='ortho')
fft_real = fft_c.real
fft_imag = fft_c.imag
fft_cat = torch.cat([fft_real, fft_imag], dim=-1) # [B,N,12]
freq_feat = self.freq_proj(fft_cat) # [B,N,freq_embed]
freq_feat = self.freq_conv(freq_feat)
freq_feat = self.fca(freq_feat)
# optionally cross-attend: let spatial query freq
if self.use_cross:
# project freq to same dim if needed inside cross
spatial_feat = self.cross(spatial_feat, freq_feat)
# project freq to embed and fuse
freq_to_spatial = self.freq_to_spatial(freq_feat)
fused = torch.cat([spatial_feat, freq_to_spatial], dim=-1)
fused = self.fuse_proj(fused)
# transformer backbone
for blk in self.blocks:
fused = blk(fused)
out = self.cls_head(fused)
return out
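# Smoke test (hypothetical sizes; run manually when debugging):
#   model = FreqFormerV7(num_classes=13)
#   logits = model(torch.randn(2, 1024, 3), torch.randn(2, 1024, 3))
#   assert logits.shape == (2, 1024, 13)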
# -----------------------------
# Helpers: confusion, iou
# -----------------------------
def compute_confusion_matrix(preds, gts, num_classes):
mask = (gts >= 0) & (gts < num_classes)
gt = gts[mask].astype(np.int64)
pred = preds[mask].astype(np.int64)
conf = np.bincount(gt * num_classes + pred, minlength=num_classes ** 2)
return conf.reshape((num_classes, num_classes))
def compute_iou_from_conf(conf):
inter = np.diag(conf)
gt_sum = conf.sum(axis=1)
pred_sum = conf.sum(axis=0)
union = gt_sum + pred_sum - inter
iou = inter / (union + 1e-10)
return iou
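# Toy check (2 classes): preds = [0, 1, 1, 0], gts = [0, 1, 0, 0]
#   conf = [[2, 1], [0, 1]] -> IoU_0 = 2/(3+2-2) = 2/3, IoU_1 = 1/(1+2-1) = 1/2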
# -----------------------------
# Training loop (DDP-ready, AMP)
# -----------------------------
def save_checkpoint(state, path):
torch.save(state, path)
def train_main():
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='/root/autodl-tmp/pointcloud_seg/data/S3DIS_new/processed_npy')
parser.add_argument('--save_dir', default='./checkpoints_v7')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--num_points', type=int, default=2048)
parser.add_argument('--num_classes', type=int, default=13)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--local_rank', type=int, default=int(os.environ.get('LOCAL_RANK', 0)))
parser.add_argument('--use_class_weights', action='store_true')
parser.add_argument('--use_lovasz', action='store_true')
parser.add_argument('--warmup_epochs', type=int, default=5)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--grad_clip', type=float, default=1.0)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--drop_last', action='store_true')
args = parser.parse_args()
# DDP setup
world_size = int(os.environ.get('WORLD_SIZE', 1))
use_ddp = world_size > 1
if use_ddp:
rank, ws = setup_ddp(args.local_rank)
else:
rank = 0
ws = 1
    device = torch.device('cuda', args.local_rank) if torch.cuda.is_available() else torch.device('cpu')
os.makedirs(args.save_dir, exist_ok=True)
# datasets & samplers
train_ds = S3DISDatasetAug(args.data_dir, split='train', num_points=args.num_points, augment=True)
val_ds = S3DISDatasetAug(args.data_dir, split='val', num_points=args.num_points, augment=False)
if use_ddp:
train_sampler = DistributedSampler(train_ds)
val_sampler = DistributedSampler(val_ds, shuffle=False)
else:
train_sampler = None
val_sampler = None
train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler,
shuffle=(train_sampler is None), num_workers=args.num_workers, drop_last=(not args.debug and args.drop_last))
val_loader = DataLoader(val_ds, batch_size=1, sampler=val_sampler, shuffle=False,
num_workers=max(1, args.num_workers // 2))
# class weights
class_weights = None
if args.use_class_weights:
if is_main_process():
print("Computing class weights...")
cw = compute_class_weights(train_ds.files, args.num_classes, method='inv_sqrt')
class_weights = cw.to(device)
if is_main_process():
print("class weights:", class_weights.cpu().numpy())
# model
model = FreqFormerV7(num_classes=args.num_classes, embed_dim=384, freq_embed=192, depth=8, num_heads=8, drop=0.1, use_cross=True)
model.to(device)
if use_ddp:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)
# optimizer, scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=2)
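    # T_0=30, T_mult=2 -> warm restarts after epochs 30 and 90 (cycle lengths 30, 60, 120)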
    scaler = torch.amp.GradScaler('cuda')
best_miou = 0.0
start_epoch = 0
if is_main_process():
print("Training config:", vars(args))
print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)
# training loop
def get_lr_factor(epoch):
if epoch < args.warmup_epochs:
return float(epoch + 1) / max(1.0, args.warmup_epochs)
return 1.0
for epoch in range(start_epoch, args.num_epochs):
if use_ddp:
train_sampler.set_epoch(epoch)
        model.train()
        # linear warmup: scale the scheduler-set LR once per epoch; overwriting
        # it every batch would fight CosineAnnealingWarmRestarts
        lr_mult = get_lr_factor(epoch)
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] * lr_mult
        t0 = time.time()
        running_loss = 0.0
        iters = 0
for batch in train_loader:
local_feat = batch['local_feat'].to(device)
coords = batch['coords'].to(device)
extra = batch['extra'].to(device)
labels = batch['label'].to(device)
optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
logits = model(coords, extra)
B, N, C = logits.shape
logits_flat = logits.view(-1, C)
labels_flat = labels.view(-1)
if class_weights is not None:
ce = F.cross_entropy(logits_flat, labels_flat, weight=class_weights, ignore_index=-1)
else:
ce = F.cross_entropy(logits_flat, labels_flat, ignore_index=-1)
probs = F.softmax(logits_flat, dim=-1)
dice = multiclass_dice_loss(probs, labels_flat)
if args.use_lovasz:
lov = lovasz_softmax(probs, labels_flat, ignore_index=-1)
else:
lov = logits_flat.new_tensor(0.0)
                # balanced combination of CE, Dice and (optional) Lovasz
                loss = 0.5 * ce + 0.8 * dice + 0.5 * lov
            scaler.scale(loss).backward()
            # unscale before clipping so the threshold applies to true gradient norms
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
running_loss += loss.item()
iters += 1
        # one scheduler step per epoch; don't swallow scheduler errors silently
        scheduler.step()
avg_loss = running_loss / max(1, iters)
t1 = time.time()
if is_main_process():
print(f"Epoch {epoch + 1}/{args.num_epochs} TrainLoss: {avg_loss:.4f} Time: {(t1 - t0):.1f}s LR: {optimizer.param_groups[0]['lr']:.6f}")
# validation every 5 epochs
if (epoch + 1) % 5 == 0 or (epoch + 1) == args.num_epochs:
model.eval()
conf = np.zeros((args.num_classes, args.num_classes), dtype=np.int64)
tot_loss = 0.0
cnt = 0
with torch.no_grad():
for batch in val_loader:
local_feat = batch['local_feat'].to(device)
coords = batch['coords'].to(device)
extra = batch['extra'].to(device)
labels = batch['label'].to(device)
logits = model(coords, extra)
B, N, C = logits.shape
logits_flat = logits.view(-1, C)
labels_flat = labels.view(-1)
if class_weights is not None:
loss_ce = F.cross_entropy(logits_flat, labels_flat, weight=class_weights, ignore_index=-1)
else:
loss_ce = F.cross_entropy(logits_flat, labels_flat, ignore_index=-1)
probs = F.softmax(logits_flat, dim=-1)
dice = multiclass_dice_loss(probs, labels_flat)
if args.use_lovasz:
lov = lovasz_softmax(probs, labels_flat, ignore_index=-1)
else:
lov = logits_flat.new_tensor(0.0)
loss = 0.5 * loss_ce + 0.8 * dice + 0.5 * lov
tot_loss += loss.item()
preds = logits.argmax(dim=-1).cpu().numpy().reshape(-1)
gts = labels.cpu().numpy().reshape(-1)
conf += compute_confusion_matrix(preds, gts, args.num_classes)
cnt += 1
mean_loss = tot_loss / max(1, cnt)
iou = compute_iou_from_conf(conf)
miou = np.nanmean(iou)
oa = np.diag(conf).sum() / (conf.sum() + 1e-12)
if is_main_process():
print(f"-- Validation Loss: {mean_loss:.4f} mIoU: {miou:.4f} OA: {oa:.4f}")
print("Per-class IoU:")
for cid, v in enumerate(iou):
print(f" class {cid:02d}: {v:.4f}")
# gather miou across ranks (optional) - here we assume main does saving
if is_main_process() and miou > best_miou:
best_miou = miou
path = os.path.join(args.save_dir, f'best_epoch_{epoch + 1:03d}_miou_{miou:.4f}.pth')
state = {'epoch': epoch + 1, 'best_miou': best_miou}
# unwrap DDP
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state['model_state_dict'] = model.module.state_dict()
else:
state['model_state_dict'] = model.state_dict()
save_checkpoint(state, path)
print("Saved best:", path)
# final save
if is_main_process():
final_path = os.path.join(args.save_dir, f'final_epoch_{args.num_epochs:03d}_miou_{best_miou:.4f}.pth')
state = {'epoch': args.num_epochs, 'best_miou': best_miou}
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state['model_state_dict'] = model.module.state_dict()
else:
state['model_state_dict'] = model.state_dict()
save_checkpoint(state, final_path)
print("Training finished. Final saved to:", final_path)
if __name__ == "__main__":
    train_main()

This code has a problem when run:

(base) root@autodl-container-cac742a9c6-f35b76d7:~# source /root/miniconda3/bin/activate pointcloud
(pointcloud) root@autodl-container-cac742a9c6-f35b76d7:~# /root/miniconda3/envs/pointcloud/bin/python /root/autodl-tmp/pointcloud_seg/freqformer_v8.py
/root/autodl-tmp/pointcloud_seg/freqformer_v8.py:460: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = torch.cuda.amp.GradScaler()
Training config: {'data_dir': '/root/autodl-tmp/pointcloud_seg/data/S3DIS_new/processed_npy', 'save_dir': './checkpoints_v7', 'batch_size': 8, 'num_epochs': 200, 'num_points': 2048, 'num_classes': 13, 'lr': 0.001, 'local_rank': 0, 'use_class_weights': False, 'use_lovasz': False, 'warmup_epochs': 5, 'num_workers': 8, 'grad_clip': 1.0, 'debug': False, 'drop_last': False}
Model params (M): 16.879141
/root/autodl-tmp/pointcloud_seg/freqformer_v8.py:489: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast():
Epoch 1/200 TrainLoss: nan Time: 7.9s LR: 0.000997
Epoch 2/200 TrainLoss: nan Time: 7.5s LR: 0.000989
Epoch 3/200 TrainLoss: nan Time: 7.2s LR: 0.000976
Epoch 4/200 TrainLoss: nan Time: 6.8s LR: 0.000957
Epoch 5/200 TrainLoss: nan Time: 7.0s LR: 0.000933
-- Validation Loss: nan mIoU: 0.0145 OA: 0.1889
Per-class IoU:
class 00: 0.1889
class 01: 0.0000
class 02: 0.0000
class 03: 0.0000
class 04: 0.0000
class 05: 0.0000
class 06: 0.0000
class 07: 0.0000
class 08: 0.0000
class 09: 0.0000
class 10: 0.0000
class 11: 0.0000
class 12: 0.0000
Saved best: ./checkpoints_v7/best_epoch_005_miou_0.0145.pth
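A plausible root cause for the NaN loss in the log: the unnormalized FFT branch.
The DC bin of torch.fft.fft is the plain column sum, so its magnitude scales with
the number of points, and under autocast the Linear layer consuming it runs in
fp16 (max ~65504); oversized activations overflow to inf and turn the BatchNorm
statistics into NaN. A minimal illustration, assuming room-scale input values:

    import torch
    x = torch.full((2048, 6), 10.0)                       # [N, C] toy coords/colors
    print(torch.fft.fft(x, dim=0)[0].real)                # DC bin: 20480 per channel
    print(torch.fft.fft(x, dim=0, norm='ortho')[0].real)  # ~452.5 after 1/sqrt(N)

Hence norm='ortho' in the forward pass above, and scaler.unscale_(optimizer)
before clip_grad_norm_ so clipping acts on real gradient norms rather than
scaled ones.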