pytorch+numpy实现--获取每行中数据大于平均值的数据,将其提取出来转化为稀疏矩阵
代码
import torch
x = torch.randn(3,4)
print(x)
print(torch.mean(x,dim=1,keepdim=True))
mz_items = x[:,] > torch.mean(x,dim=1,keepdim=True)
print(mz_items)
import numpy as np
mz_items_np = mz_items.numpy()
index = np.argwhere( mz_items_np == True )
print(
原创
2021-11-21 21:54:16 ·
1635 阅读 ·
0 评论