pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息
pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现
然而正确的打开方式并不是这样
韩国一位大神写了一个pytorch的turorial,其中包含style transfer的一个代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
for stepin range(config.total_step):
# Extract multiple(5) conv feature vectors
target_features= vgg(target)# 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
content_features= vgg(Variable(content))
style_features= vgg(Variable(style))
style_loss= 0
content_loss= 0
for f1, f2, f3in zip(target_features, content_features, style_features):
# Compute content loss (target and content image)
content_loss+= torch.mean((f1- f2)**2)# square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作
# Reshape conv features
_, c, h, w= f1.size()# channel height width
f1= f1.view(c, h* w)# reshape a vector
f3= f3.view(c, h* w)# reshape a vector
# Compute gram matrix
f1= torch.mm(f1, f1.t())
f3= torch.mm(f3, f3.t())
# Compute style loss (target and style image)
style_loss+= torch.mean((f1- f3)**2)/ (c* h* w)# 总共元素的数目?
其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*-coding: utf-8 -*-
import torch
from torch.autogradimport Variable
# dtype = torch.FloatTensor
dtype= torch.cuda.FloatTensor# Uncomment this to run on GPU
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out= 64,1000,100,10
# Randomly initialize weights
w1= torch.randn(D_in, H).type(dtype)# 两个权重矩阵
w2= torch.randn(D_in, H).type(dtype)
# operate with +-*/ and **
w3= w1-2*w2
w4= w3**2
w5= w4/w1
# operate the Variable with +-*/ and **
w6= Variable(torch.randn(N, D_in).type(dtype))
w7= Variable(torch.randn(N, D_in).type(dtype))
w8= w6+ w7
w9= w6*w7
w10= w9**2
print(1)
基本上调试的结果与预期相符

所以,对于floattensor以及variable进行普通的+-×/以及**没毛病
以上这篇Pytorch基本变量类型FloatTensor与Variable用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.youkuaiyun.com/u013517182/article/details/93051322
本文介绍了PyTorch中的基本变量类型FloatTensor与Variable的使用方法。详细解释了Variable如何封装FloatTensor及其在深度学习任务如Style Transfer中的应用。通过实例展示了如何利用PyTorch进行高效的张量操作。
775

被折叠的 条评论
为什么被折叠?



