点云分类任务
本项目是一个简单的使用图中点分类代码,内涵完整的网络搭建、模型训练、模型保存、模型调用、可视化、的全过程。可以帮助初学者快速熟悉流程。帮助入门。
数据集下载
关键代码
项目使用了shapenet数据集中的飞机类数据集,在使用图神经网络飞机上进行部件分割,本项目写了一个自动下载数据集的方法,直接运行项目会自动下载数据集。关键部分代码如下。
def load_data():
path = './data'
if not os.path.exists(path):
# 如果目录不存在,则创建它
os.makedirs(path)
print(f"目录 '{path}' 已创建。")
else:
print(f"目录 '{path}' 已存在。")
train_data = ShapeNet(root=path, categories=['Airplane'], split='trainval', pre_transform=T.KNNGraph(k=6))
test_data = ShapeNet(root=path, categories=['Airplane'], split='test', pre_transform=T.KNNGraph(k=6))
return train_data, test_data
// A code block
这个代码先检查了存储路径是否存在,不存在创建一个,之后在直接下载数据集。
我们下载的时候选定了预处理pre_transform=T.KNNGraph(k=6),会预先使用
k临近算法给图数据生成边。这个方法下载了训练集和测试集。
数据集结构
t, t1 = load_data()
print(t)
print(t[0])
print(len(t))
// 输出
ShapeNet(2349, categories=['Airplane'])
# 这句的意思一共2349个图,都是飞机类
Data(x=[2518, 3], y=[2518], pos=[2518, 3], category=[1],
edge_index=[2, 15108])
# 这里是选中了一张图看看里面的结构,x是点特征每个点是三维的,意思是一个
# 点用三个数表示,可以联想成空间里面点的x,y,z吧。pos是空间的点坐标也是
# 2518个。y是各个点的标签,用来表示这个点是属于那个部件的。category
# 表示整个图是什么类,这里只有飞机类,所以说只有一个数。edge_index是
# 邻接矩阵,指明那个点之间有连接
2349
网络模型
class Net(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, num_classes):
super()