Test code for converting a convolutional layer into a fully-connected layer:
import torch

def conv2fn(conv, input_shape, output_shape):
    # Assumes the conv layer uses padding=0 and dilation=1.
    sh, sw = conv.stride                 # row stride and col stride
    oc, ic, kh, kw = conv.weight.shape   # out_channels, in_channels, kernel rows, kernel cols
    ic, ih, iw = input_shape[-3:]        # (batch, channel, height, width) or (channel, height, width)
    oc, oh, ow = output_shape[-3:]
    W = torch.zeros((oc * oh * ow, ic * ih * iw), dtype=torch.float)
    # Input and output are flattened with tensor.view(-1), so the memory order is
    # channel -> row -> col (the col index varies fastest).
    for cha_out in range(oc):
        for row_out in range(oh):
            for col_out in range(ow):
                W_row = cha_out * oh * ow + row_out * ow + col_out
                row_in_start = sh * row_out
                col_in_start = sw * col_out
                for cha_in in range(ic):
                    for row_ker in range(kh):
                        # Copy one kernel row into the matching slice of the weight matrix.
                        W_col1 = cha_in * ih * iw + (row_in_start + row_ker) * iw + col_in_start
                        W_col2 = W_col1 + kw
                        W[W_row, W_col1:W_col2] = conv.weight[cha_out, cha_in, row_ker, :]
    return W
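A quick sanity check for conv2fn: flatten the input, multiply by W, and compare against the convolution's own flattened output. A minimal sketch, assuming conv2fn from above is in scope; bias is disabled because the function only encodes the weights, and the shapes below are arbitrary examples:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, kernel_size=(2, 2), stride=(1, 1), bias=False)  # padding=0, dilation=1
x = torch.randn(1, 3, 5, 5)              # batch of one, (N, C, H, W)
y = conv(x)                              # shape (1, 4, 4, 4)
W = conv2fn(conv, x.shape, y.shape)      # uses only the trailing (C, H, W) dims
y_fc = W @ x.reshape(-1)                 # fully-connected equivalent of the conv
print(torch.allclose(y.reshape(-1), y_fc, atol=1e-5))  # expected: True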
Convolution backpropagation code implementation. For a stride-1, padding-0 convolution, the gradient with respect to the input is a full convolution of the output gradient with the kernel rotated by 180 degrees and with its input/output channel axes swapped; the script below checks this against autograd:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.Logger import reproduc  # project-specific helper that seeds random/numpy/torch

if __name__ == "__main__":
    reproduc({'seed': 42, 'benchmark': False, 'deterministic': True})
    # Unbatched input of shape (C, H, W) = (3, 3, 3); recent PyTorch versions accept
    # unbatched input to nn.Conv2d / F.conv2d.
    x = [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
    ]
    x = torch.tensor(x, dtype=torch.float)
    x.requires_grad = True
    conv = nn.Conv2d(3, 10, kernel_size=(2, 2), stride=(1, 1), bias=True)
    y = conv(x)
    loss = y.sum()
    # create_graph=True also retains the graph, so the second grad call below is valid.
    x_grad = torch.autograd.grad(loss, x, grad_outputs=torch.ones_like(loss), create_graph=True)
    y_grad = torch.autograd.grad(loss, y, grad_outputs=torch.ones_like(loss), create_graph=True)
    print(x_grad[0].shape, y_grad[0].shape)
    print(x.shape, conv.weight.shape, y.shape)
    # Equivalent functional form of the forward pass:
    # y = F.conv2d(input=x, weight=conv.weight, bias=conv.bias, stride=conv.stride,
    #              padding=conv.padding, dilation=conv.dilation)

    # Gradient w.r.t. the input: a "full" convolution of the output gradient with the
    # kernel whose in/out channel axes are swapped and whose spatial dims are rotated
    # by 180 degrees; the full-convolution padding is kernel_size - 1 = 1 here.
    W = conv.weight.transpose(0, 1)    # (out_channels, in_channels, ...) -> (in_channels, out_channels, ...)
    W = torch.flip(W, dims=[-2, -1])   # rotate each 2x2 kernel by 180 degrees
    x1_grad = F.conv2d(input=y_grad[0], weight=W, padding=1)
    res = x_grad[0] - x1_grad
    print(abs(res).max(), abs(x_grad[0]).max(), abs(x1_grad).max())
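The flip-and-swap construction above is exactly what a transposed convolution packages up, so the same input gradient can also be checked with F.conv_transpose2d under the forward conv's stride and padding. A minimal sketch, meant to be appended to the end of the script above and reusing y_grad, conv, and x_grad from it:

    g = y_grad[0].unsqueeze(0)  # add a batch dim so older PyTorch versions accept the input
    x2_grad = F.conv_transpose2d(g, conv.weight, stride=conv.stride, padding=conv.padding)
    x2_grad = x2_grad.squeeze(0)
    print(abs(x_grad[0] - x2_grad).max())  # expected to be ~0, matching the flip-based result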