下面是一个使用PyTorch的例子:
import torch
## 在这个函数中,我们首先将mask_tensor张量中值为1的位置转化为bool类型的掩码,然后使用这个掩码将data_tensor张量中对应位置的元素提取出来,并返回提取的元素。
def extract_elements_with_ones(mask_tensor, data_tensor):
# 将mask_tensor张量中值为1的位置转化为bool类型掩码
mask = mask_tensor == 1
# 使用掩码将data_tensor张量中对应位置的元素提取出来
extracted_data = data_tensor[mask]
retu