- 加载必要的包
fastai.vision 包含处理计算机视觉问题常用的类与方法:比如图像类image, 数据增强类transform, 通过预训练网络来快速搭建模型的create_cnn函数等等。
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
- 导入数据集摘要信息
数据集的摘要信息存储再以csv结尾的无格式文件种,第一列Image表示图片的名称,第二列Id表示图片的标签,如下所示。
path = Path('../data')
trn_imgs = pd.read_csv(path/'train.csv');
trn_imgs.head()

- 定义TripleImageItem
我们希望定义一个Item,其包含3张图像,其中第一张图像通常称为anchor, 第二张图像与第一张图像为同一类别,第三张图像与第一张图像为不同类别。
自定义Item时,最重要的是obj和data属性,obj属性描述这个item,并且依靠obj可以复制item;data是实际用于学习器学习的数据。也就是说对obj进行简单的预处理,然后将其赋值给data.
class TripleImageItem(ItemBase):
def __init__(self, img1, img2, img3, mean, std):
"""
img1, img2, img3 类型为Image, 且shape相同
"""
self.img1, self.img2, self.img3 = img1, img2, img3
self.mean, self.std = mean, std
self.obj = (img1, img2, img3)
self.data = self.normalize(img1.data, img2.data, img3.data