import os
import sys
import argparse
import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score
# Add the project root directory to the system path
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(ROOT_DIR)
# Import custom modules
from data_utils import get_scanobjectnn_dataloader
from model import PointEfficientNetForClassification  # classification model
from utils import PointNetSetAbstraction  # used by the model
def parse_args():
parser = argparse.ArgumentParser('Point-EfficientNet Training for ScanObjectNN')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size during training')
parser.add_argument('--epoch', type=int, default=300, help='Number of epochs to run')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate')
parser.add_argument('--gpu', type=str, default='0', help='Specify GPU devices')
parser.add_argument('--optimizer', type=str, default='Adam', choices=['Adam', 'SGD'], help='Optimizer')
parser.add_argument('--log_dir', type=str, default=None, help='Log directory path')
parser.add_argument('--decay_rate', type=float, default=0.0001, help='Weight decay rate')
parser.add_argument('--npoint', type=int, default=2048, help='Number of points per point cloud')
parser.add_argument('--normal', action='store_true', default=False, help='Use normal vectors')
parser.add_argument('--step_size', type=int, default=20, help='Learning rate decay step')
parser.add_argument('--lr_decay', type=float, default=0.5, help='Learning rate decay rate')
parser.add_argument('--num_workers', type=int, default=8, help='Number of data loading workers')
parser.add_argument('--root', type=str, default='/Users/zhaoboyang/Model_Introduction/Classification/PENet_cls/data', help='ScanObjectNN dataset root directory')
return parser.parse_args()
def main():
args = parse_args()
    # Set the visible GPU devices
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Create log directories
timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
if args.log_dir is None:
experiment_dir = Path(f'./log/scanobjectnn/point_efficientnet_{timestr}')
else:
experiment_dir = Path(f'./log/scanobjectnn/{args.log_dir}')
experiment_dir.mkdir(parents=True, exist_ok=True)
checkpoints_dir = experiment_dir / 'checkpoints'
checkpoints_dir.mkdir(exist_ok=True)
log_dir = experiment_dir / 'logs'
log_dir.mkdir(exist_ok=True)
    # Configure logging
logger = logging.getLogger("PointEfficientNet")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(str(log_dir / 'train.log'))
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
def log_string(message):
logger.info(message)
print(message)
log_string(f'Using device: {device}')
log_string('PARAMETERS:')
log_string(str(args))
    # Create data loaders
log_string('Loading ScanObjectNN dataset...')
train_loader = get_scanobjectnn_dataloader(
root=args.root,
batch_size=args.batch_size,
npoints=args.npoint,
split='train',
normal_channel=args.normal,
num_workers=args.num_workers,
augment=True
)
test_loader = get_scanobjectnn_dataloader(
root=args.root,
batch_size=args.batch_size,
npoints=args.npoint,
split='test',
normal_channel=args.normal,
num_workers=args.num_workers,
augment=False,
shuffle=False
)
log_string(f"Number of training samples: {len(train_loader.dataset)}")
log_string(f"Number of test samples: {len(test_loader.dataset)}")
    # Create the model
    num_classes = 15  # ScanObjectNN has 15 classes
model = PointEfficientNetForClassification(
num_classes=num_classes,
normal_channel=args.normal
).to(device)
    # Initialize weights
model.init_weights()
    # Loss function
criterion = nn.CrossEntropyLoss().to(device)
    # Optimizer
if args.optimizer == 'Adam':
optimizer = optim.Adam(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.decay_rate
)
else: # SGD
optimizer = optim.SGD(
model.parameters(),
lr=args.learning_rate,
momentum=0.9,
weight_decay=args.decay_rate
)
    # Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=args.step_size,
gamma=args.lr_decay
)
    # Training state variables
best_test_acc = 0.0
best_test_acc_avg = 0.0
start_epoch = 0
    # Try to resume from an existing checkpoint
checkpoint_path = checkpoints_dir / 'best_model.pth'
if checkpoint_path.exists():
log_string('Loading checkpoint...')
        checkpoint = torch.load(str(checkpoint_path), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_test_acc = checkpoint['best_test_acc']
best_test_acc_avg = checkpoint['best_test_acc_avg']
log_string(f'Resuming training from epoch {start_epoch}')
    # Training loop
for epoch in range(start_epoch, args.epoch):
log_string(f'Epoch {epoch + 1}/{args.epoch}')
model.train()
train_loss = 0.0
correct = 0
total = 0
        # Train for one epoch
for i, (points, labels) in enumerate(tqdm(train_loader, desc='Training')):
points = points.float().to(device)
labels = labels.long().to(device)
            # Forward pass
outputs = model(points)
            # Compute the loss
loss = criterion(outputs, labels)
            # Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
            # Accumulate statistics
train_loss += loss.item() * points.size(0)
            # Compute accuracy
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
        # Update the learning rate
scheduler.step()
        # Compute training metrics
train_loss /= len(train_loader.dataset)
train_acc = 100. * correct / total
log_string(f'Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')
log_string(f'Current LR: {scheduler.get_last_lr()[0]:.6f}')
        # Evaluate the model
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
all_preds = []
all_labels = []
with torch.no_grad():
for points, labels in tqdm(test_loader, desc='Testing'):
points = points.float().to(device)
labels = labels.long().to(device)
                # Forward pass
outputs = model(points)
                # Compute the loss
loss = criterion(outputs, labels)
test_loss += loss.item() * points.size(0)
                # Compute accuracy
_, predicted = torch.max(outputs.data, 1)
test_total += labels.size(0)
test_correct += predicted.eq(labels).sum().item()
                # Collect predictions and ground-truth labels
all_preds.append(predicted.cpu().numpy())
all_labels.append(labels.cpu().numpy())
        # Compute test metrics
test_loss /= len(test_loader.dataset)
test_acc = 100. * test_correct / test_total
        # Compute the class-balanced average accuracy
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
test_acc_avg = 100. * balanced_accuracy_score(all_labels, all_preds)
log_string('Test Results:')
log_string(f'Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%, Avg Accuracy: {test_acc_avg:.2f}%')
        # Update the best metrics
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_test_acc_avg = test_acc_avg
            is_best = True
        else:
            is_best = False
        # Assemble the checkpoint state every epoch, so both the best-model save
        # and the periodic save below always use up-to-date values
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'test_loss': test_loss,
            'test_acc': test_acc,
            'test_acc_avg': test_acc_avg,
            'best_test_acc': best_test_acc,
            'best_test_acc_avg': best_test_acc_avg
        }
        # Save the best model
        if is_best:
            torch.save(state, str(checkpoint_path))
            log_string(f'Saved best model at epoch {epoch + 1} with Test Acc: {best_test_acc:.2f}%')
        # Save a periodic checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = checkpoints_dir / f'model_epoch_{epoch + 1}.pth'
            torch.save(state, str(checkpoint))
            log_string(f'Saved checkpoint at epoch {epoch + 1}')
    # Training finished
log_string('Training completed!')
log_string(f'Best Test Accuracy: {best_test_acc:.2f}%')
log_string(f'Best Test Avg Accuracy: {best_test_acc_avg:.2f}%')
if __name__ == '__main__':
main()
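# ===== data_utils.py: ScanObjectNN dataset and DataLoader (imported above as data_utils) =====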
# -*- coding: utf-8 -*-
import os
import json
import warnings
import numpy as np
import torch
import h5py
from torch.utils.data import Dataset
from torchvision import transforms
from utils import (
random_scale_point_cloud,
shift_point_cloud,
rotate_point_cloud,
jitter_point_cloud,
    pc_normalize  # normalization helper
)
warnings.filterwarnings('ignore')
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
# Picklable transform functions (defined at module level so DataLoader workers can pickle them)
def jitter_transform(x):
return jitter_point_cloud(x)
def rotate_transform(x):
return rotate_point_cloud(x)
def scale_transform(x):
return random_scale_point_cloud(x)
def shift_transform(x):
return shift_point_cloud(x)
def normalize_transform(x):
return pc_normalize(x)
# Load ScanObjectNN data
def load_scanobjectnn_data(split, root_dir):
    # Map split names to file name prefixes
    split_mapping = {
        'train': 'training',
        'val': 'training',  # ScanObjectNN has no validation set; reuse part of the training data
        'test': 'test'
    }
    if split not in split_mapping:
        raise ValueError(f"Invalid split: {split}. Options: {list(split_mapping.keys())}")
    # Build the absolute path to the h5 file
h5_name = os.path.join(
root_dir,
'h5_files',
'main_split',
f'{split_mapping[split]}_objectdataset_augmentedrot_scale75.h5'
)
if not os.path.exists(h5_name):
        raise FileNotFoundError(f"ScanObjectNN data file not found: {h5_name}")
with h5py.File(h5_name, 'r') as f:
data = f['data'][:].astype('float32')
labels = f['label'][:].astype('int64')
    # For the validation split, keep the first 10% of the training data
if split == 'val':
num_val = int(len(data) * 0.1)
data = data[:num_val]
labels = labels[:num_val]
return data, labels
class ScanObjectNNDataset(Dataset):
def __init__(self, root, npoints=2048, split='train',
normal_channel=False, augment=True):
"""
参数:
root: ScanObjectNN数据集根目录 (包含h5_files目录)
npoints: 每个点云采样的点数 (默认2048)
split: 数据集分割 (train/val/test) (默认train)
normal_channel: 是否使用法线信息 (默认False)
augment: 是否使用数据增强 (默认True)
"""
self.root = root
self.npoints = npoints
self.normal_channel = normal_channel
self.split = split
        self.augment = augment and (split == 'train')  # augment only during training
        # Load data
self.data, self.labels = load_scanobjectnn_data(split, root)
        # Number of classes
self.num_classes = 15
        # Set up data augmentation
if self.augment:
self.augment_transforms = transforms.Compose([
transforms.Lambda(scale_transform),
transforms.Lambda(shift_transform),
transforms.Lambda(rotate_transform),
transforms.Lambda(jitter_transform),
transforms.Lambda(normalize_transform)
])
else:
self.augment_transforms = transforms.Compose([
transforms.Lambda(normalize_transform)
])
print(f"成功初始化ScanObjectNN数据集: {split}分割, 样本数: {len(self.data)}")
print(f"数据集根目录: {self.root}")
print(f"使用法线信息: {'是' if self.normal_channel else '否'}")
print(f"数据增强: {'启用' if self.augment else '禁用'}")
def __len__(self):
return len(self.data)
def __getitem__(self, index):
        # Get the point cloud and label
point_set = self.data[index].copy()
label = self.labels[index]
        # Select channels
if not self.normal_channel:
point_set = point_set[:, :3]
else:
point_set = point_set[:, :6]
        # Sample a fixed number of points
        if point_set.shape[0] < self.npoints:
            # Sample with replacement when there are not enough points
indices = np.random.choice(point_set.shape[0], self.npoints, replace=True)
else:
indices = np.random.choice(point_set.shape[0], self.npoints, replace=False)
point_set = point_set[indices]
        # Apply data augmentation
point_set = self.augment_transforms(point_set)
        # Convert to the (C, N) layout expected by the model
point_set = point_set.transpose()
return point_set, label
# Picklable collate function (needed for multi-worker loading)
def custom_collate_fn(batch):
points, labels = zip(*batch)
    # Stack all samples into batch tensors
points = torch.tensor(np.stack(points)).float()
labels = torch.tensor(np.stack(labels)).long()
return points, labels
# DataLoader factory
def get_scanobjectnn_dataloader(root, batch_size=32, npoints=2048,
split='train', normal_channel=False,
num_workers=4, augment=True, shuffle=None):
"""
创建ScanObjectNN的数据加载器
参数:
root: ScanObjectNN数据集根目录 (包含h5_files目录)
batch_size: 批大小 (默认32)
npoints: 采样点数 (默认2048)
split: 数据集分割 (train/val/test) (默认train)
normal_channel: 是否包含法线信息 (默认False)
num_workers: 数据加载线程数 (默认4)
augment: 是否数据增强 (默认True)
shuffle: 是否打乱数据 (默认None, 自动根据split设置)
"""
if shuffle is None:
        shuffle = (split == 'train')  # shuffle the training split by default
dataset = ScanObjectNNDataset(
root=root,
npoints=npoints,
split=split,
normal_channel=normal_channel,
augment=augment
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
        collate_fn=custom_collate_fn,  # use the picklable collate function
pin_memory=True,
        drop_last=(split == 'train')  # drop the last incomplete batch during training
)
return dataloader
# Example usage
if __name__ == "__main__":
    # ScanObjectNN dataset root directory (containing the h5_files directory)
SCANOBJECTNN_ROOT = '/Users/zhaoboyang/Model_Introduction/Classification/PENet_cls/data'
    # Create the training data loader
train_loader = get_scanobjectnn_dataloader(
root=SCANOBJECTNN_ROOT,
batch_size=16,
npoints=1024,
split='train',
normal_channel=False
)
    # Create the test data loader
test_loader = get_scanobjectnn_dataloader(
root=SCANOBJECTNN_ROOT,
batch_size=16,
npoints=1024,
split='test',
normal_channel=False,
augment=False,
shuffle=False
)
    # Inspect the training loader
    print("\nTraining set example:")
    for i, (points, labels) in enumerate(train_loader):
        print(f"Batch {i + 1}:")
        print(f"  Point cloud shape: {points.shape}")  # (B, C, N)
        print(f"  Label shape: {labels.shape}")  # (B,)
        print(f"  Sample labels: {labels[:4].tolist()}")
        if i == 0:  # show only the first batch
            break
    # Inspect the test loader
    print("\nTest set example:")
    for i, (points, labels) in enumerate(test_loader):
        print(f"Batch {i + 1}:")
        print(f"  Point cloud shape: {points.shape}")
        print(f"  Label shape: {labels.shape}")
        print(f"  Sample labels: {labels[:4].tolist()}")
        if i == 0:  # show only the first batch
            break
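# ===== utils.py: shared point-cloud utilities and PointNet++ building blocks (imported above as utils) =====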
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def pc_normalize(pc):
"""归一化点云数据"""
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
计算点对之间的欧氏距离平方
参数:
src: 源点, [B, N, C]
dst: 目标点, [B, M, C]
返回:
dist: 每点的平方距离, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
索引点云数据
参数:
points: 输入点云, [B, N, C]
idx: 采样索引, [B, S]
返回:
new_points: 索引后的点, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint):
"""
最远点采样(FPS)
参数:
xyz: 点云数据, [B, N, 3]
npoint: 采样点数
返回:
centroids: 采样点索引, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
球查询
参数:
radius: 搜索半径
nsample: 最大采样数
xyz: 所有点, [B, N, 3]
new_xyz: 查询点, [B, S, 3]
返回:
group_idx: 分组点索引, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
采样和分组
参数:
npoint: 采样点数
radius: 搜索半径
nsample: 每组点数
xyz: 点云位置, [B, N, 3]
points: 点云特征, [B, N, D]
返回:
new_xyz: 采样点位置, [B, npoint, nsample, 3]
new_points: 采样点特征, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint)
new_xyz = index_points(xyz, fps_idx)
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx)
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
全局采样和分组
参数:
xyz: 点云位置, [B, N, 3]
points: 点云特征, [B, N, D]
返回:
new_xyz: 采样点位置, [B, 1, 3]
new_points: 采样点特征, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
# PointNetSetAbstraction class (the modified version in utils.py)
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
        # Key change: the MLP input also covers the 3 per-point coordinate-offset channels
        actual_in_channel = in_channel + 3  # 3 coordinate-offset channels
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
        # Use the actual number of input channels
last_channel = actual_in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
new_points = new_points.permute(0, 3, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
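# Example usage of PointNetSetAbstraction (a minimal sketch; the layer sizes below are
# illustrative only, not necessarily those used by PointEfficientNetForClassification):
#   sa = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
#                               in_channel=3, mlp=[64, 64, 128], group_all=False)
#   xyz = torch.rand(2, 3, 1024)    # (B, 3, N) coordinates
#   feat = torch.rand(2, 3, 1024)   # (B, D, N) features, with D == in_channel
#   new_xyz, new_feat = sa(xyz, feat)
#   # new_xyz: (2, 3, 512); new_feat: (2, 128, 512) after per-group max pooling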
# Data augmentation functions
def random_scale_point_cloud(point_set, scale_low=0.8, scale_high=1.2):
"""随机缩放点云 (默认缩放范围0.8-1.2)"""
scale = np.random.uniform(scale_low, scale_high)
point_set[:, :3] *= scale
return point_set
def shift_point_cloud(point_set, shift_range=0.1):
"""随机平移点云 (默认平移范围±0.1)"""
shifts = np.random.uniform(-shift_range, shift_range, 3)
point_set[:, :3] += shifts
return point_set
def rotate_point_cloud(point_set, normal_start_idx=3):
"""
随机旋转点云 (绕z轴旋转),支持任意向量信息
:param point_set: 点云数组 [N, C]
:param normal_start_idx: 法线/向量起始维度索引
"""
angle = np.random.uniform(0, 2 * np.pi)
cosval = np.cos(angle)
sinval = np.sin(angle)
rotation_matrix = np.array([
[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]
])
    # Rotate every block of 3 columns (coordinates and any normal vectors)
for i in range(0, point_set.shape[1], 3):
if i + 3 > point_set.shape[1]:
break
point_set[:, i:i + 3] = np.dot(point_set[:, i:i + 3], rotation_matrix)
return point_set
def jitter_point_cloud(point_set, sigma=0.01, clip=0.05):
"""添加噪声扰动 (默认σ=0.01, 裁剪±0.05)"""
noise = np.clip(sigma * np.random.randn(*point_set[:, :3].shape), -clip, clip)
point_set[:, :3] += noise
return point_set
# Wrapper functions for data augmentation
def random_scale_wrapper(x):
return random_scale_point_cloud(x, scale_low=0.85, scale_high=1.15)
def shift_wrapper(x):
return shift_point_cloud(x, shift_range=0.1)
def rotate_wrapper(x):
return rotate_point_cloud(x)
def jitter_wrapper(x):
return jitter_point_cloud(x, sigma=0.005, clip=0.02)
def normalize_wrapper(x):
return pc_normalize(x)
The three code listings above are, respectively, the training script, the data loading code, and the utility function code (utils) for the classification model mentioned earlier. Please check whether a bug in this code is what is causing the abnormally low training metrics.
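One quick way to rule out a data/label mismatch before digging into the training loop is a standalone sanity check along the following lines (a minimal sketch, assuming the modules above are importable as data_utils and model and that the dataset path is valid):

import torch
from data_utils import get_scanobjectnn_dataloader
from model import PointEfficientNetForClassification

loader = get_scanobjectnn_dataloader(
    root='/Users/zhaoboyang/Model_Introduction/Classification/PENet_cls/data',
    batch_size=4, npoints=2048, split='train', normal_channel=False,
    num_workers=0, augment=False)
points, labels = next(iter(loader))
print(points.shape, labels.shape)                # expected: (4, 3, 2048) and (4,)
print(labels.min().item(), labels.max().item())  # labels should lie in [0, 14]

model = PointEfficientNetForClassification(num_classes=15, normal_channel=False)
model.eval()
with torch.no_grad():
    out = model(points.float())
print(out.shape)  # expected: (4, 15); nn.CrossEntropyLoss expects raw logits, so if
                  # the model already applies log_softmax, nn.NLLLoss would be the
                  # matching criterion instead.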