# Import packages (导包)
import torch
import torch.nn as nn
# Define the neural network model class (创建神经网络模型类)
class _netG(nn.Module):
    """Model consisting of a single fully connected (linear) layer.

    Args:
        in_features: size of the last dimension of the input (default 2048).
        out_features: size of the last dimension of the output (default 1024).
    """

    def __init__(self, in_features=2048, out_features=1024):
        super().__init__()
        # Linear transform (线性变换): in_features -> out_features.
        self.fc1 = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply ``fc1`` to ``x``, print the output shape, and return the output.

        Args:
            x: tensor whose last dimension equals ``in_features``.

        Returns:
            The output of ``fc1``; its last dimension is ``out_features``.
        """
        x_1 = self.fc1(x)
        # The demo script reads the resulting shape from stdout.
        print(x_1.shape)
        # Return the result so callers get the activation instead of None.
        return x_1
# Test code (测试代码)
if __name__ == '__main__':
    # Demo: build two (4, 1024) tensors, concatenate them to (4, 2048),
    # and run them through the model.
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gt = torch.rand(4, 1024, device=device)
    pre_end = torch.rand(4, 1024, device=device)
    print(gt.shape, pre_end.shape)
    # Concatenate along dim 1 (向量拼接): (4, 1024) + (4, 1024) -> (4, 2048).
    # Both inputs already live on `device`, so no extra transfer is needed.
    a = torch.cat((gt, pre_end), dim=1)
    print(a.shape)
    mynet = _netG().to(device)
    mynet(a)