
pytorch
文章平均质量分 63
cc__cc__
记录日常,欢迎交流
展开
-
DETR训练自己的数据集
github地址:https://github.com/facebookresearch/detr1.创建conda环境推荐通过conda创建虚拟环境,具体操作可见linux系统下创建anaconda新环境及问题解决2.clone代码并安装依赖库git clone https://github.com/facebookresearch/detr.gitconda install -c pytorch pytorch torchvisionconda install cython scipy.原创 2021-12-04 11:57:26 · 13381 阅读 · 14 评论 -
反卷积torch.nn.ConvTranspose2d详解(含转换成卷积运算的代码示例)
1.torch.nn.ConvTranspose2d参数介绍官方文档:https://pytorch.org/docs/master/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2dtorch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1,原创 2021-11-20 21:22:43 · 19075 阅读 · 2 评论 -
【Pytorch实战(五)】实现MNIST手写体识别
一、MNIST数据集MNIST数据集地址http://yann.lecun.com/exdb/mnist/该数据集包含6w张训练集图片和1w张测试集图片二、实现MNIST手写体识别1.借助torchvision下载数据集train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.tran原创 2020-05-17 19:16:39 · 609 阅读 · 0 评论 -
【Pytorch实战(四)】线性判决与逻辑回归
在【Pytorch实战(三)】线性回归中介绍了线性回归。线性判决与线性回归是类似的,区别在于线性回归的“标签”通常是较为连续的取值,而线性判决的“标签”往往只有几个可能的取值,例如【0,1】。在线性回归中,经常使用MSE损失torch.nn.MSELoss()、L1损失torch.nn.L1Loss和SmoothL1损失torch.nn.SmoothL1Loss等。相应地,在线性判决中也需使用合适的损失。对于二元判决常用BCE(Binary Cross Entropy)损失;对于多元则常用Cross E.原创 2020-05-17 12:06:31 · 317 阅读 · 0 评论 -
【Pytorch实战(三)】线性回归
一、使用函数求解最小二乘的权重最小二乘法是一种最基本的线性回归方法,使用torch.lstsq直接实现最小二乘法的代码如下:import torchx = torch.tensor([[1., 2., 1.], [2., 4., 1.], [3., 5., 1.], [4., 2., 1.], [4., 4., 1.]])y = torch.tensor([-12., 13., 15., 14., 18.])wr, _ = torch.lstsq(y, x) # 返回2个值w = wr[原创 2020-05-16 18:22:20 · 1277 阅读 · 0 评论 -
【Pytorch实战(二)】梯度及优化算法
一、计算梯度的简单示例import torchx = torch.tensor([1., 2.], requires_grad=True)y = x[0] ** 2 + x[1] ** 2print('y = {}'.format(y))y.backward()print('grad = {}'.format(x.grad))# 输出结果 # y = 5.0# grad = tensor([2., 4.])注意问题:构造张量x时应将参数requires_grad设置为True,这原创 2020-05-15 22:20:47 · 483 阅读 · 0 评论 -
【Pytorch实战(一)】张量
一、初步认识在Pytorch中张量是一种基本的数据类型,用类torch.Tensor实现import torchT = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 注意函数名tensor是小写的查看张量的几个性质张量的维度 T.dim()张量的大小 T.size()张量中的元素个数 T.numel张量中元素的类型 T.dtype 例如torch.float32,torch.int64等...原创 2020-05-14 22:14:30 · 394 阅读 · 0 评论 -
【mmdetection源码解读(二)】RPN网络
以下仅为个人理解,若有不正之处还请指出,欢迎交流!在two-stage目标检测方法中,通过骨干网络获得的特征图需要送进RPN网络产生区域建议候选框,下面就结合mmdetection中的源码详细解释这一过程。相关理论可结合faster rcnn进行理解Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Ne...原创 2020-04-13 14:41:51 · 2535 阅读 · 0 评论 -
【mmdetection源码解读(一)】骨干网络--ResNet
以下仅为个人理解,若有不正之处还请指出,欢迎交流!本文解读的源码为mmdet/models/backbones中的resnet.py首先附上ResNet原文地址Deep Residual Learning for Image Recognition其中,ResNet整体网络结构图如下:一、ResNet网络中的两种基本残差块由网络结构图可以看出,ResNet-18和ResNet-3...原创 2020-04-12 17:10:17 · 5671 阅读 · 2 评论 -
结合实例理解pytorch中交叉熵损失函数
以下仅为个人理解,如有不正之处还请指出,欢迎交流!1.pytorch中的交叉熵损失函数如下:torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')损失函数即反映了预测值(input...原创 2020-04-10 18:14:35 · 4466 阅读 · 0 评论 -
pytorch常用函数整理
一些常见函数,经常想不起函数本身的形参包含什么或者某实参对应的是什么含义,所以干脆随手记录在这里了(不断更新)……# 卷积torch.nn.Conv2d(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) # 插值/上采样torch.nn...原创 2020-03-27 09:09:41 · 521 阅读 · 0 评论