【Torch API】torch.unique()用法详解

部署运行你感兴趣的模型镜像

torch.unique()的功能类似于数学中的集合,就是挑出tensor中的独立不重复元素。

这个方法的参数在官方解释文档中有这么几个:torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)

input: 待处理的tensor

sorted:是否对返回的无重复张量按照数值进行排列,默认是生序排列的

return_inverse: 是否返回原始tensor中的每个元素在这个无重复张量中的索引

return_counts: 统计原始张量中每个独立元素的个数

dim: 值沿着哪个维度进行unique的处理,这个我试验后没有搞懂怎样的机理。如果处理的张量都是一维的,那么这个不需要理会。

import torch
 
x = torch.tensor([4,0,1,2,1,2,3])#生成一个tensor,作为实验输入
print(x)
 
out = torch.unique(x) #所有参数都设置为默认的
print(out)#将处理结果打印出来
#结果如下:
#tensor([0, 1, 2, 3, 4])   #将x中的不重复元素挑了出来,并且默认为生序排列
 
out = torch.unique(x,sorted=False)#将默认的生序排列改为False
print(out)
#输出结果如下:
#tensor([3, 2, 1, 0, 4])  #将x中的独立元素找了出来,就按照原始顺序输出
 
out = torch.unique(x,return_inverse=True)#将原始数据中的每个元素在新生成的独立元素张量中的索引输出
print(out)
#输出结果如下:
#(tensor([0, 1, 2, 3, 4]), tensor([4, 0, 1, 2, 1, 2, 3]))  #第一个张量是排序后输出的独立张量,第二个结果对应着原始数据中的每个元素在新的独立无重复张量中的索引,比如x[0]=4,在新的张量中的索引为4, x[1]=0,在新的张量中的索引为0,x[6]=3,在新的张量中的索引为3
 
out = torch.unique(x,return_counts=True) #返回每个独立元素的个数
print(out)
#输出结果如下
#(tensor([0, 1, 2, 3, 4]), tensor([1, 2, 2, 1, 1]))  #0这个元素在原始数据中的数量为1,1这个元素在原始数据中的数量为2

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

def get_samples(is_split=True): # 设置随机种子,保证结果可复现 random.seed(42) np.random.seed(42) # 1. 加载数据 #features = np.load(features_path) # 假设形状为 (n_samples, ...) #labels = np.load(labels_path) # 假设形状为 (n_samples,) #sns = np.load(sns_path) # 假设形状为 (n_samples,),每个元素是磁盘SN with h5py.File(samples_path, 'r') as file: features = np.array(file['features']) # 读取数据 labels = np.array(file['labels']) sns = np.array(file['sns']) if is_split == False: return features, labels # 2. 获取所有唯一的SN并打乱顺序 unique_sns = np.unique(sns) random.shuffle(unique_sns) # 随机打乱SN顺序 # 3. 按8:2比例划分SN(训练集80%,测试集20%) split_idx = int(len(unique_sns) * 0.8) train_sns = set(unique_sns[:split_idx]) # 训练集使用的SN test_sns = set(unique_sns[split_idx:]) # 测试集使用的SN # 4. 根据SN筛选样本 # 生成训练集和测试集的掩码(布尔数组) train_mask = np.array([sn in train_sns for sn in sns]) test_mask = np.array([sn in test_sns for sn in sns]) # 筛选样本 train_features = features[train_mask] train_labels = labels[train_mask] test_features = features[test_mask] test_labels = labels[test_mask] X_train = torch.tensor(train_features, dtype=torch.float32) X_test = torch.tensor(test_features, dtype=torch.float32) y_train = torch.tensor(train_labels, dtype=torch.float32) y_test = torch.tensor(test_labels, dtype=torch.float32) return X_train, X_test, y_train, y_test 将上面这个改成分批读取,但是也要确保整个train和val中sn互不重复
08-03
def load_image(filename): ext = splitext(filename)[1] if ext == '.npy': return Image.fromarray(np.load(filename)) elif ext in ['.pt', '.pth']: return Image.fromarray(torch.load(filename).numpy()) else: return Image.open(filename) def unique_mask_values(idx, mask_dir, mask_suffix): mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] mask = np.asarray(load_image(mask_file)) if mask.ndim == 2: return np.unique(mask) elif mask.ndim == 3: mask = mask.reshape(-1, mask.shape[-1]) return np.unique(mask, axis=0) else: raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') class BasicDataset(Dataset): def __init__(self, images_dir: str, mask_dir: str, img_newsize: int, mask_suffix: str = ''): self.images_dir = Path(images_dir) self.mask_dir = Path(mask_dir) self.img_newsize = img_newsize self.mask_suffix = mask_suffix self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] if not self.ids: raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') logging.info(f'Creating dataset with {len(self.ids)} examples') logging.info('Scanning mask files to determine unique values') with Pool() as p: unique = list(tqdm(p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), total=len(self.ids))) self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) logging.info(f'Unique mask values: {self.mask_values}') def __len__(self): return len(self.ids) @staticmethod def preprocess(mask_values, input_img, img_newsize, is_mask): w, h = input_img.shape[1], input_img.shape[0] newW, newH = int(img_newsize[1]), int(img_newsize[0]) img = cv2.resize(input_img, (newW, newH), interpolation=cv2.INTER_NEAREST if is_mask else cv2.INTER_CUBIC) # img = np.asarray(pil_img) if is_mask: mask = np.zeros((newH, newW), dtype=np.int64) for i, v in enumerate(mask_values): if img.ndim == 2: mask[img == v] = i else: mask[(img == v).all(-1)] = i return mask else: if img.ndim == 2: img = img[np.newaxis, ...] else: img = img.transpose((2, 0, 1)) if (img > 1).any(): img = img / 255.0 return img def __getitem__(self, idx): name = self.ids[idx] mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) img_file = list(self.images_dir.glob(name + '.*')) assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' mask = cv2.imread(str(mask_file[0]),0) # gray img = cv2.imread(str(img_file[0])) # gray -> bgr img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) # rgb # assert img.size == mask.size, f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' img = self.preprocess(self.mask_values, img, self.img_newsize, is_mask=False) mask = self.preprocess(self.mask_values, mask, self.img_newsize, is_mask=True) return {'image': torch.as_tensor(img.copy()).float().contiguous(), 'mask': torch.as_tensor(mask.copy()).long().contiguous()} 解析此段代码
06-10
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值