要求:求A=torch.rand[8,8]的张量中每个4*4块的均值和方差
逛了一圈优快云,发现没有关于求张量N*N分块后的均值和方差的代码,我就自己来写一个把
- 求张量的均值
def Trans2mean(image):#假设就是1通道张量
h,w=image.shape
block_h,block_w=h//4,w//4#假设4*4分块
#先按行切成 block_h块,每块4个
temp_h=torch.split(image,[4 for x in range(0,block_h)],0)
ALLmean=torch.zeros(block_h,block_w)
for i in range(0,block_h):
temp_w=temp_h[i].split([4 for y in range(0,block_h)],1)
for j in range(0,block_w):
ALLmean[i,j]=temp_w[j].mean()#算的是平方
return ALLmean
- 求张量的方差(其实是一样的)
def Trans2var(image):#假设就是1通道张量
h,w=image.shape
block_h,block_w=h//4,w//4#假设4*4分块
#先按行切成 block_h块,每块4个
temp_h=torch.split(image,[4 for x in range(0,block_h)],0)
ALLvar=torch.zeros(block_h,block_w)
for i in range(0,block_h):
temp_w=temp_h[i].split([4 for y in range(0,block_h)],1)
for j in range(0,block_w):
ALLvar[i,j]=temp_w[j].var()#算的是方差
return ALLvar
- 测试
#单元测试
if __name__ == '__main__':
A=torch.rand(8,8)
var_map=Trans2var(A)
mean_map=Trans2mean(A)
- 实验结果