import torch
import torch.nn as nn
import torch.nn.functional as F
# Rare class IDs (must be defined at module level, outside the classes).
rare_classes = [3, 4, 6, 9, 11]  # module-level; FreqFormerV8 sizes its auxiliary head with len(rare_classes)
# -----------------------------
# 1. 增强型频域注意力 (EFA)
# -----------------------------
class EnhancedFreqAttention(nn.Module):
    """Gate a sequence with a fused local (conv) and global (pooled) attention map.

    Two sets of attention logits are computed from the input and summed before
    a sigmoid: a per-position map from a two-layer ``Conv1d`` bottleneck run
    along dimension 1, and a sequence-wide map from a two-layer MLP applied to
    the mean over dimension 1. The input is multiplied element-wise by the
    resulting gate.

    Args:
        dim: Size of the last (channel) dimension of the input.
        reduction: Bottleneck ratio; the hidden width is ``dim // reduction``.
    """

    def __init__(self, dim, reduction=4):
        super().__init__()
        hidden = dim // reduction
        local_layers = [
            nn.Conv1d(dim, hidden, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(hidden, dim, 3, padding=1),
        ]
        self.conv = nn.Sequential(*local_layers)
        global_layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        ]
        self.attn = nn.Sequential(*global_layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the fused attention gate (same shape as ``x``)."""
        # Conv1d expects channels on dim 1, so swap dims 1 and 2 around the conv.
        channels_first = x.transpose(1, 2)
        local_logits = self.conv(channels_first).transpose(1, 2)
        # keepdim=True leaves a singleton dim 1 so the global logits broadcast
        # against the per-position logits.
        pooled = x.mean(dim=1, keepdim=True)
        global_logits = self.attn(pooled)
        gate = self.sigmoid(local_logits + global_logits)
        return x * gate
# -----------------------------
# 2. 多尺度频域融合块
# -----------------------------
class MultiScaleFreqBlock(nn.Module):
    """Embed a sequence from its FFT at two resolutions and fuse the results.

    Branch 1 uses an FFT of the input's own length along dimension 1. Branch 2
    uses a zero-padded FFT of length ``max(fft_size, N)`` and keeps its first
    ``N`` bins. Real and imaginary parts are concatenated on the last
    dimension, each branch is projected to ``embed_dim // 2`` channels, and the
    concatenation is passed through a final linear layer.

    Args:
        embed_dim: Output channel count; must be even because each branch
            contributes ``embed_dim // 2`` channels.
        in_channels: Size of the input's last dimension (default 6).
        fft_size: Minimum length of the padded FFT in branch 2 (default 2048).

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """

    def __init__(self, embed_dim, in_channels=6, fft_size=2048):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even, got {embed_dim}")
        self.fft_size = fft_size
        spectral_dim = 2 * in_channels  # real part + imaginary part
        half_dim = embed_dim // 2
        self.branch1 = nn.Sequential(
            nn.Linear(spectral_dim, half_dim),
            nn.GELU(),
        )
        self.branch2 = nn.Sequential(
            nn.Linear(spectral_dim, half_dim),
            nn.GELU(),
        )
        self.fuse = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Map ``x`` of shape ``[B, N, in_channels]`` to ``[B, N, embed_dim]``."""
        _, num_points, _ = x.shape

        # Base spectrum at the sequence's own length.
        fft1 = torch.fft.fft(x, dim=1)
        feat1 = torch.cat([fft1.real, fft1.imag], dim=-1)
        b1 = self.branch1(feat1)

        # Padded spectrum. torch.fft.fft trims the input when n < N, which
        # would leave fewer than N rows and break the concatenation below, so
        # the FFT length is never allowed to drop under N.
        padded_len = max(self.fft_size, num_points)
        fft2 = torch.fft.fft(x, n=padded_len, dim=1)
        feat2 = torch.cat([fft2.real, fft2.imag], dim=-1)
        # branch2 acts independently on every row, so slicing first yields the
        # same values as slicing afterwards while skipping the discarded rows.
        b2 = self.branch2(feat2[:, :num_points, :])

        fused = torch.cat([b1, b2], dim=-1)
        return self.fuse(fused)
# -----------------------------
# 3. Transformer块 (补充定义)
# -----------------------------
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention then an MLP.

    Each sub-layer is applied to a LayerNorm of its input and the result is
    added back to the un-normalised input (residual connection).

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability used in attention and in the MLP.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=drop,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply both residual sub-layers; the output has the shape of ``x``."""
        # Residual 1: self-attention over the normalised tokens. The module
        # returns (output, weights); only the output is used.
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed)[0]
        x = attended + x

        # Residual 2: position-wise MLP over the normalised tokens.
        x = self.mlp(self.norm2(x)) + x
        return x
# -----------------------------
# 4. 主网络结构:FreqFormerV8
# -----------------------------
class FreqFormerV8(nn.Module):
    """Per-point classifier that fuses a spatial and a frequency embedding.

    The input (``coords`` concatenated with ``feats``) is embedded twice: by a
    linear layer (spatial branch) and by ``MultiScaleFreqBlock`` followed by
    ``EnhancedFreqAttention`` (frequency branch). A learned sigmoid gate blends
    the two, the result goes through ``depth`` Transformer blocks, and two
    heads produce the outputs.

    Args:
        num_classes: Number of outputs of the main head.
        embed_dim: Token width used throughout the network.
        depth: Number of ``TransformerBlock`` layers.
        num_heads: Attention heads per Transformer block.

    The auxiliary head has ``len(rare_classes)`` outputs, where
    ``rare_classes`` is the module-level list.
    """

    def __init__(self, num_classes=13, embed_dim=192, depth=6, num_heads=8):
        super().__init__()
        # Spatial embedding of the 6 concatenated input channels.
        self.spatial_embed = nn.Linear(6, embed_dim)

        # Frequency branch.
        self.freq_block = MultiScaleFreqBlock(embed_dim)
        self.efa = EnhancedFreqAttention(embed_dim)

        # Gate producing per-channel blend weights in (0, 1).
        gate_layers = [
            nn.Linear(2 * embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid(),
        ]
        self.fuse_gate = nn.Sequential(*gate_layers)

        # Transformer stack.
        encoder_layers = [
            TransformerBlock(embed_dim, num_heads=num_heads)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        # Main classification head.
        self.main_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )
        # Auxiliary head dedicated to the rare classes.
        self.aux_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(rare_classes)),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every ``nn.Linear`` weight; biases set to zero."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, coords, feats=None):
        """Return ``(main_out, aux_out)`` for the given points.

        Args:
            coords: Point coordinates; concatenated with ``feats`` on the last
                dimension.
            feats: Extra per-point features. When ``None``, zeros shaped like
                ``coords`` are used instead.
        """
        if feats is None:
            feats = torch.zeros_like(coords)
        x = torch.cat([coords, feats], dim=-1)

        # Spatial branch.
        spatial_feat = self.spatial_embed(x)

        # Frequency branch.
        freq_feat = self.efa(self.freq_block(x))

        # Gated fusion: gate -> spatial, (1 - gate) -> frequency.
        gate = self.fuse_gate(torch.cat([spatial_feat, freq_feat], dim=-1))
        fused = gate * spatial_feat + (1 - gate) * freq_feat

        for block in self.blocks:
            fused = block(fused)

        return self.main_head(fused), self.aux_head(fused)
if __name__ == "__main__":
    # Smoke test: push one random batch through the model and print shapes.
    model = FreqFormerV8(num_classes=13)
    coords = torch.randn(2, 1024, 3)
    feats = torch.randn(2, 1024, 3)
    outputs = model(coords, feats)
    main_out, aux_out = outputs
    print("Main output shape:", main_out.shape)  # expected: [2, 1024, 13]
    print("Aux output shape:", aux_out.shape)  # expected: [2, 1024, 5]
    # Note pasted with the source: "the model code has already been optimised".
    # The text that follows is a separate file, train_freqformer_v8.py.
"""
Train script for FreqFormerV8.
Usage example:
python train_freqformer_v8.py --data_dir <…> --batch_size 4 --num_epochs 150
"""
import os, time, argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
# -------------------------------
# Import the V8 model (adjust path if needed)
from freqformer_v8 import FreqFormerV8
# -------------------------------
# Dice loss (multi-class)
def one_hot(labels, num_classes):
    """One-hot encode integer class labels.

    Args:
        labels: Integer tensor of class indices, e.g. shape ``[N]``.
        num_classes: Number of classes ``C``.

    Returns:
        Float tensor of shape ``[N, C]`` on the same device as ``labels``.
    """
    # Row i of the identity matrix is the one-hot vector for class i.
    identity = torch.eye(num_classes, device=labels.device)
    return identity[labels]
def multiclass_dice_loss(probs, labels, eps=1e-6):
    """Soft Dice loss averaged over classes.

    Args:
        probs: Class probabilities of shape ``[BP, C]``.
        labels: Integer labels of shape ``[BP]``; negative entries are ignored.
        eps: Smoothing term added to the numerator and denominator.

    Returns:
        Scalar tensor ``mean_c(1 - dice_c)``, or 0 when every label is negative.
    """
    keep = labels >= 0
    if not keep.any():
        return probs.new_tensor(0.)

    num_classes = probs.shape[1]
    kept_probs = probs[keep]  # [M, C]
    kept_labels = labels[keep]
    # One-hot targets: row i of the identity matrix encodes class i.
    target = torch.eye(num_classes, device=kept_labels.device)[kept_labels]  # [M, C]

    overlap = (kept_probs * target).sum(dim=0)
    total = kept_probs.sum(dim=0) + target.sum(dim=0)
    per_class_dice = (2. * overlap + eps) / (total + eps)
    return (1.0 - per_class_dice).mean()
# -------------------------------
# Focal Loss
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha[t] * (1 - p_t) ** gamma * CE``.

    ``p_t`` is the predicted probability of the target class. It is derived
    from the *unweighted* cross-entropy; the class weight ``alpha`` is applied
    afterwards. Feeding ``alpha`` into the cross-entropy used for ``p_t`` would
    give ``p_t ** alpha[t]`` and distort the modulating factor.

    Args:
        alpha: Optional per-class weights (length ``C``), or ``None``.
        gamma: Focusing exponent.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        ignore_index: Target value that contributes no loss. The default
            (-100) matches ``F.cross_entropy``.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` ``[N, C]`` and ``targets`` ``[N]``.

        With ``reduction='mean'`` the sum is divided by the number of
        non-ignored targets (0 is returned when there are none).
        """
        # Ignored targets get a CE of 0 here, hence p_t = 1 and zero focal loss.
        ce_loss = F.cross_entropy(
            inputs, targets, reduction='none', ignore_index=self.ignore_index
        )
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        valid = targets != self.ignore_index
        if self.alpha is not None:
            alpha = torch.as_tensor(
                self.alpha, dtype=focal_loss.dtype, device=focal_loss.device
            )
            # Ignored targets may be out of range for indexing; point them at
            # class 0 (their loss is already 0, so the weight is irrelevant).
            safe_targets = targets.masked_fill(~valid, 0)
            focal_loss = focal_loss * alpha[safe_targets]

        if self.reduction == 'mean':
            return focal_loss.sum() / valid.sum().clamp(min=1)
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
# -------------------------------
# Lovasz softmax
def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension of the Jaccard loss.

    Args:
        gt_sorted: 1-D float tensor of 0/1 ground-truth values, ordered by
            decreasing prediction error.

    Returns:
        Tensor shaped like ``gt_sorted``; all zeros when ``gt_sorted`` has no
        positive entries.
    """
    positives = gt_sorted.sum()
    if positives == 0:
        return torch.zeros_like(gt_sorted)

    remaining = positives - gt_sorted.cumsum(0)
    union = positives + (1 - gt_sorted).cumsum(0)
    jaccard = 1. - remaining / union
    if gt_sorted.numel() <= 1:
        return jaccard
    # First-order differences; the first entry is kept as is.
    return torch.cat([jaccard[:1], jaccard[1:] - jaccard[:-1]])
def flatten_probas(probas, labels, ignore_index=-1):
    """Drop the rows whose label equals ``ignore_index``.

    Args:
        probas: Tensor whose first dimension lines up with ``labels``.
        labels: 1-D label tensor.
        ignore_index: Label value to filter out.

    Returns:
        ``(probas, labels)`` restricted to the kept rows, or a pair of empty
        tensors when no row is kept.
    """
    keep = labels != ignore_index
    if keep.sum() == 0:
        return probas.new(0), labels.new(0)
    return probas[keep], labels[keep]
def lovasz_softmax(probas, labels, classes='present', ignore_index=-1):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: Class probabilities of shape ``[P, C]``.
        labels: Integer labels of shape ``[P]``.
        classes: With ``'present'``, classes that have no label in the batch
            are skipped; any other value averages over all classes.
        ignore_index: Label value excluded from the loss.

    Returns:
        Scalar tensor: the mean of the per-class losses, or 0 when nothing
        remains after filtering.
    """
    num_classes = probas.size(1)
    probas, labels = flatten_probas(probas, labels, ignore_index)
    if probas.numel() == 0:
        return probas.new_tensor(0.)

    per_class = []
    for cls in range(num_classes):
        foreground = (labels == cls).float()
        if classes == 'present' and foreground.sum() == 0:
            continue
        errors = (foreground - probas[:, cls]).abs()
        # Largest errors first, as required by the Lovasz extension.
        order = torch.argsort(errors, descending=True)
        grad = lovasz_grad(foreground[order])
        per_class.append(torch.dot(F.relu(errors[order]), grad))

    if not per_class:
        return probas.new_tensor(0.)
    return sum(per_class) / len(per_class)
# -------------------------------
# Dataset (S3DIS npy layout assumed)
class S3DISDatasetAug(Dataset):
    """Point-cloud dataset over ``.npy`` files with optional augmentation.

    Every ``.npy`` file in ``data_dir`` is read as a 2-D array whose columns
    0-2 are coordinates, 3-5 are extra features and 6 is the class label.
    Files whose name contains ``val_area`` form the validation split; all
    others form the training split.

    When augmentation is active, ``__getitem__`` returns, with probability
    ``rare_prob``, a random file that contains a rare class instead of the
    requested index. The list of such files is built once here rather than on
    every access.

    Args:
        data_dir: Directory holding the ``.npy`` files.
        split: ``'train'`` or ``'val'``.
        val_area: Substring that identifies validation files.
        num_points: Number of points sampled per item.
        augment: Enable augmentation (only effective for ``split='train'``).
        rare_classes: Class IDs treated as rare for oversampling.
        rare_prob: Probability of substituting a rare-class file.

    Raises:
        RuntimeError: If no file matches the split.
    """

    def __init__(self, data_dir, split='train', val_area='Area_5', num_points=1024,
                 augment=True, rare_classes=(3, 4, 6, 9, 11), rare_prob=0.3):
        self.num_points = num_points
        self.augment = augment and (split == 'train')
        self.rare_classes = list(rare_classes)
        self.rare_prob = rare_prob

        self.files = []
        for f in sorted(os.listdir(data_dir)):
            if not f.endswith('.npy'):
                continue
            if split == 'train' and val_area in f:
                continue
            if split == 'val' and val_area not in f:
                continue
            self.files.append(os.path.join(data_dir, f))
        if len(self.files) == 0:
            raise RuntimeError(f"No files found in {data_dir} (split={split})")

        # Cache which classes each file contains. Files are memory-mapped and
        # only every 100th label is inspected, so this stays cheap; classes
        # present only between sampled rows are missed.
        self.file_class_map = {}
        for file_path in self.files:
            data = np.load(file_path, mmap_mode='r')
            labels_sample = data[::100, 6]
            self.file_class_map[file_path] = np.unique(labels_sample).tolist()

        # Precompute the oversampling pool once.
        rare = set(self.rare_classes)
        self.rare_files = [
            path for path in self.files
            if rare.intersection(self.file_class_map[path])
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (or for a random rare-class file)."""
        if self.augment and self.rare_files and np.random.rand() < self.rare_prob:
            pick = np.random.randint(len(self.rare_files))
            return self._load_file(self.rare_files[pick])
        return self._load_file(self.files[idx])

    def _load_file(self, file_path):
        """Load one file, sample ``num_points`` points and optionally augment.

        Returns:
            Dict with ``'coords'`` ``[P, 3]``, ``'extra'`` ``[P, 3]``,
            ``'local_feat'`` ``[P, 6]`` (coords and extra concatenated) and
            ``'label'`` ``[P]``, where ``P = num_points``.
        """
        data = np.load(file_path)
        coords = data[:, :3].astype(np.float32)
        extra = data[:, 3:6].astype(np.float32)
        labels = data[:, 6].astype(np.int64)

        # Sample with replacement only when the file has too few points.
        total = coords.shape[0]
        choice = np.random.choice(total, self.num_points, replace=total < self.num_points)
        coords = coords[choice]
        extra = extra[choice]
        labels = labels[choice]

        if self.augment:
            # Random rotation about the z axis, uniform scaling, then jitter.
            theta = np.random.uniform(0, 2 * np.pi)
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
            coords = coords.dot(R.T)
            scale = np.random.uniform(0.9, 1.1)
            coords = coords * scale
            coords = coords + np.random.normal(0, 0.01, coords.shape).astype(np.float32)

        local_feat = np.concatenate([coords, extra], axis=1)
        return {
            'coords': torch.from_numpy(coords).float(),
            'extra': torch.from_numpy(extra).float(),
            'local_feat': torch.from_numpy(local_feat).float(),
            'label': torch.from_numpy(labels).long()
        }
# -------------------------------
# Confusion matrix & IoU
def compute_confusion_matrix(preds, gts, num_classes):
    """Build a ``[num_classes, num_classes]`` confusion matrix.

    Rows index the ground-truth class and columns the predicted class. Points
    whose ground truth lies outside ``[0, num_classes)`` are skipped.

    Args:
        preds: NumPy array of predicted class IDs.
        gts: NumPy array of ground-truth class IDs, same shape as ``preds``.
        num_classes: Number of classes.
    """
    valid = (gts >= 0) & (gts < num_classes)
    true_ids = gts[valid].astype(np.int64)
    pred_ids = preds[valid].astype(np.int64)
    # Encode each (gt, pred) pair as one flat index and count occurrences.
    flat_index = true_ids * num_classes + pred_ids
    counts = np.bincount(flat_index, minlength=num_classes ** 2)
    return counts.reshape((num_classes, num_classes))
def compute_iou_from_conf(conf):
    """Per-class IoU from a confusion matrix (rows = ground truth, cols = prediction).

    A small constant in the denominator keeps classes with an empty union at
    an IoU of 0 rather than NaN.
    """
    true_positive = np.diag(conf)
    union = conf.sum(axis=1) + conf.sum(axis=0) - true_positive
    return true_positive / (union + 1e-10)
# -------------------------------
# Class weights
def compute_class_weights(file_list, num_classes, method='inv_sqrt'):
    """Derive per-class loss weights from label frequencies.

    Args:
        file_list: Paths of ``.npy`` files whose column 6 holds class labels.
        num_classes: Number of classes; labels outside ``[0, num_classes)``
            are not counted.
        method: ``'inv_freq'`` (1 / count), ``'inv_sqrt'`` (1 / sqrt(count));
            any other value gives uniform weights.

    Returns:
        Float32 tensor of ``num_classes`` weights, scaled to sum to
        ``num_classes``.
    """
    totals = np.zeros(num_classes, dtype=np.float64)
    for path in file_list:
        labels = np.load(path, mmap_mode='r')[:, 6].astype(np.int64)
        in_range = labels[(labels >= 0) & (labels < num_classes)]
        totals += np.bincount(in_range, minlength=num_classes)
    # Classes that never occur get a count of 1 to avoid division by zero.
    totals = np.maximum(totals, 1.0)

    if method == 'inv_freq':
        weights = 1.0 / totals
    elif method == 'inv_sqrt':
        weights = 1.0 / np.sqrt(totals)
    else:
        weights = np.ones_like(totals)

    weights = weights / weights.sum() * num_classes
    return torch.from_numpy(weights.astype(np.float32))
# -------------------------------
# Main training loop
def _main_logits(model_output):
    """Return the main segmentation logits from a model output.

    FreqFormerV8 returns ``(main_out, aux_out)``; a bare tensor is passed
    through unchanged so single-output models keep working.
    """
    if isinstance(model_output, (tuple, list)):
        return model_output[0]
    return model_output


def _segmentation_loss(logits_flat, labels_flat, class_weights, use_lovasz,
                       focal_criterion=None):
    """Combined loss ``0.5 * CE-or-focal + 2.0 * Dice + 0.3 * Lovasz``.

    Labels equal to -1 are ignored by every term.
    """
    probs = F.softmax(logits_flat, dim=-1)
    dice = multiclass_dice_loss(probs, labels_flat)
    if focal_criterion is not None:
        # Drop ignored points up front so the focal criterion never receives
        # a -1 target.
        valid = labels_flat >= 0
        if valid.any():
            ce = focal_criterion(logits_flat[valid], labels_flat[valid])
        else:
            ce = logits_flat.sum() * 0.0
    else:
        ce = F.cross_entropy(logits_flat, labels_flat, weight=class_weights, ignore_index=-1)
    lov = lovasz_softmax(probs, labels_flat) if use_lovasz else logits_flat.new_tensor(0.0)
    return 0.5 * ce + 2.0 * dice + 0.3 * lov


def main():
    """Train FreqFormerV8, validating every 5 epochs and saving checkpoints.

    Only the model's main output is supervised; the auxiliary output is
    currently unused by this script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='/root/autodl-tmp/pointcloud_seg/data/S3DIS_new/processed_npy')
    parser.add_argument('--save_dir', default='./checkpoints_v6')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_epochs', type=int, default=300)
    parser.add_argument('--num_points', type=int, default=1024)
    parser.add_argument('--num_classes', type=int, default=13)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--use_class_weights', action='store_true')
    parser.add_argument('--use_lovasz', action='store_true')
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--use_focal', action='store_true', help='Use Focal Loss instead of CrossEntropy')
    parser.add_argument('--focal_gamma', type=float, default=2.0, help='Gamma parameter for Focal Loss')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    device = torch.device(args.device)

    # Dataset
    train_ds = S3DISDatasetAug(args.data_dir, split='train', num_points=args.num_points, augment=True)
    val_ds = S3DISDatasetAug(args.data_dir, split='val', num_points=args.num_points, augment=False)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=max(1, args.num_workers//2))

    # Class weights
    class_weights = None
    if args.use_class_weights:
        print("Computing class weights...")
        class_weights = compute_class_weights(train_ds.files, args.num_classes).to(device)
        print("class weights:", class_weights.cpu().numpy())

    # Model
    model = FreqFormerV8(num_classes=args.num_classes)
    if torch.cuda.device_count() > 1 and args.device.startswith('cuda'):
        print("Using DataParallel on devices:", list(range(torch.cuda.device_count())))
        model = torch.nn.DataParallel(model)
    model = model.to(device)

    # Focal Loss
    focal_criterion = None
    if args.use_focal:
        print(f"Using Focal Loss with gamma={args.focal_gamma}")
        # Clone so the rare-class boost below does not also rewrite the
        # class_weights tensor used by the validation cross-entropy.
        if class_weights is not None:
            alpha = class_weights.clone()
        else:
            alpha = torch.ones(args.num_classes, device=device)
        rare_classes = [3, 4, 6, 9, 11]  # adjust as needed
        for c in rare_classes:
            if c < len(alpha):
                alpha[c] *= 25.0
        focal_criterion = FocalLoss(alpha=alpha, gamma=args.focal_gamma).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=1e-4
    )

    # LR schedule: linear warmup followed by cosine annealing. The warmup
    # stage is only built when it has at least one epoch to run.
    cosine_epochs = max(1, args.num_epochs - args.warmup_epochs)
    if args.warmup_epochs > 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.01,
            end_factor=1.0,
            total_iters=args.warmup_epochs
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cosine_epochs
        )
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup, cosine],
            milestones=[args.warmup_epochs]
        )
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cosine_epochs
        )

    best_miou = 0.0

    # Training loop
    for epoch in range(args.num_epochs):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        for batch in train_loader:
            coords = batch['coords'].to(device)
            extra = batch['extra'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            # The model returns (main_out, aux_out); train on the main logits.
            logits = _main_logits(model(coords, extra))
            logits_flat = logits.reshape(-1, logits.shape[-1])
            labels_flat = labels.reshape(-1)
            loss = _segmentation_loss(
                logits_flat, labels_flat, class_weights, args.use_lovasz,
                focal_criterion=focal_criterion if args.use_focal else None,
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        avg_loss = running_loss / max(1, len(train_loader))
        t1 = time.time()
        print(f"Epoch {epoch+1}/{args.num_epochs} TrainLoss: {avg_loss:.4f} Time: {(t1-t0):.1f}s LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Validation
        if (epoch+1) % 5 == 0 or (epoch+1) == args.num_epochs:
            model.eval()
            conf = np.zeros((args.num_classes, args.num_classes), dtype=np.int64)
            tot_loss = 0.0
            cnt = 0
            with torch.no_grad():
                for batch in val_loader:
                    coords = batch['coords'].to(device)
                    extra = batch['extra'].to(device)
                    labels = batch['label'].to(device)

                    logits = _main_logits(model(coords, extra))
                    logits_flat = logits.reshape(-1, logits.shape[-1])
                    labels_flat = labels.reshape(-1)
                    # Validation always reports cross-entropy (never focal).
                    loss_val = _segmentation_loss(
                        logits_flat, labels_flat, class_weights, args.use_lovasz
                    )
                    tot_loss += loss_val.item()

                    preds = logits.argmax(dim=-1).cpu().numpy().reshape(-1)
                    gts = labels.cpu().numpy().reshape(-1)
                    conf += compute_confusion_matrix(preds, gts, args.num_classes)
                    cnt += 1

            mean_loss = tot_loss / max(1, cnt)
            iou = compute_iou_from_conf(conf)
            miou = np.nanmean(iou)
            oa = np.diag(conf).sum() / (conf.sum() + 1e-12)
            print(f"-- Validation Loss: {mean_loss:.4f} mIoU: {miou:.4f} OA: {oa:.4f}")
            print("Per-class IoU:")
            for cid, v in enumerate(iou):
                print(f" class {cid:02d}: {v:.4f}")

            if miou > best_miou:
                best_miou = miou
                path = os.path.join(args.save_dir, f'best_epoch_{epoch+1:03d}miou.pth')
                state = {
                    'epoch': epoch+1,
                    'best_miou': best_miou,
                    'model_state_dict': model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
                }
                torch.save(state, path)
                print("Saved best:", path)

    # Final save
    final_path = os.path.join(args.save_dir, f'final_epoch_{args.num_epochs:03d}miou.pth')
    state = {
        'epoch': args.num_epochs,
        'best_miou': best_miou,
        'model_state_dict': model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
    }
    torch.save(state, final_path)
    print("Training finished. Final saved to:", final_path)
if __name__ == "__main__":
    main()

# Note pasted with the source (it was fused onto the ``main()`` call, which
# made that line a syntax error): "How else should the training code be
# optimised? It needs to genuinely improve the scores."
# Trailing page text: "最新发布" ("Latest release").