图神经网络、GCN实现图内点云分类任务(物体的部件分类)


本项目是一个简单的使用图中点分类代码,内涵完整的网络搭建、模型训练、模型保存、模型调用、可视化、的全过程。可以帮助初学者快速熟悉流程。帮助入门。

数据集下载

关键代码

项目使用了shapenet数据集中的飞机类数据集,在使用图神经网络飞机上进行部件分割,本项目写了一个自动下载数据集的方法,直接运行项目会自动下载数据集。关键部分代码如下。

def load_data():
    path = './data'
    if not os.path.exists(path):
        # 如果目录不存在,则创建它
        os.makedirs(path)
        print(f"目录 '{path}' 已创建。")
    else:
        print(f"目录 '{path}' 已存在。")
    train_data = ShapeNet(root=path, categories=['Airplane'], split='trainval', pre_transform=T.KNNGraph(k=6))
    test_data = ShapeNet(root=path, categories=['Airplane'], split='test', pre_transform=T.KNNGraph(k=6))
    return train_data, test_data
// A code block
这个代码先检查了存储路径是否存在,不存在创建一个,之后在直接下载数据集。
我们下载的时候选定了预处理pre_transform=T.KNNGraph(k=6),会预先使用
k临近算法给图数据生成边。这个方法下载了训练集和测试集。

数据集结构

 t, t1 = load_data()
 print(t)
 print(t[0])
 print(len(t))
// 输出
ShapeNet(2349, categories=['Airplane'])
# 这句的意思一共2349个图,都是飞机类
Data(x=[2518, 3], y=[2518], pos=[2518, 3], category=[1], 
edge_index=[2, 15108])
# 这里是选中了一张图看看里面的结构,x是点特征每个点是三维的,意思是一个
# 点用三个数表示,可以联想成空间里面点的x,y,z吧。pos是空间的点坐标也是
# 2518个。y是各个点的标签,用来表示这个点是属于那个部件的。category
# 表示整个图是什么类,这里只有飞机类,所以说只有一个数。edge_index是
# 邻接矩阵,指明那个点之间有连接
2349

网络模型

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码飞速跑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值