pytorch新手自学教程(四)

本教程通过PyTorch实现点集分类,演示了如何创建数据集、构建神经网络模型、训练模型并评估分类结果。详细介绍了监督学习在点集分类中的应用,包括数据生成、模型搭建、损失函数选择及训练过程。

pytorch新手自学教程(四)--点集分类

import头文件

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

点集分类

机器学习中的监督学习主要分为回归问题和分类问题,本文讲一下关于简单点集分类的神经网络解决方法。分类问题所预测的结果就是离散的类别c 这时 输入变量可以是离散的,也可以是连续的,而监督学习从数据中学习→个分类模型或 者分类决策函数,它被称为分类器( c1assifier )。 分类器对新的输入进行输出预测,这 个过程即称为分类( c1assification )。思路是先拟合决策边界(回归问题)(这里的决策边界 不局限于线性,还可以是多项式),在建立这个边界和分类概率的关系,从而得到二分 类情况下的概率。

·制造数据集:

data = torch.ones(100,2) #制造100个随机点(xi,yi)所需数据
point1 = torch.normal(1*data, 1) #点类1:分布在1附近的正态分布(标准差为1)
target1 = torch.zeros(100) #point1 100个点的标签设为0
point2 = torch.normal(-1*data, 1) #点类2:分布在-1附近的正态分布(标准差为1)
target2 = torch.ones(100) #point1 100个点的标签设为1
#将两类点和标签放在一起,方便训练
point = torch.cat((point1, point2), 0) #参数为0按行合并,为1按列合并
target = torch.cat((target1, target2), 0)
#打印数据
plt.scatter(point1[:,0], point1[:,1], c='red') #点1的散点图图,红点
plt.scatter(point2[:,0], point2[:,1], c='blue') #点2的散点图图,蓝点
plt.show()

dianjitu

·构建用于分类的神经网络:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(2, 1) #输入为一个点(xi,yi)二维,输出为判断点的标签(0或1)一维
        self.sm = nn.Sigmoid() #将输出经过sigmoid激活函数映射为概率,并取概率大的标签

    def forward(self, x):
        x = self.linear(x)
        x = self.sm(x)
        return x

net = Net()

使用print(net)可以查看搭建网络的结构
net

·进行训练:

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fuc = nn.BCELoss() #二分类用的损失函数

epochs = 100
#训练步骤与前面大致相同,不在细述
for epoch in range(epochs):
    out_target = net(point)
    loss = loss_fuc(out_target, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print('epoch: {} | loss: {:.3f}'.format(epoch, loss))

·以概率0.5为分界线标记网络划分点:

plt.scatter(point1[:, 0], point1[:, 1], c='red')  # 点1的散点图图,红点
plt.scatter(point2[:,0], point2[:,1], c='blue') #点2的散点图图,蓝点
#将数据中网络分为标签0的点标绿
out_target = net(point)
for i in range(len(out_target)):
    if out_target[i] >= 0.5: #将预测概率大于0.5的点标绿
        plt.scatter(point[i][0], point[i][1], c='green')
plt.show()

图分类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值