简要记录mmfewshot学习的一些体会
一 Dataset
在 mmfewshot/classfication/datasets/*
路径下包含了各类数据集的.py文件。mini_imagenet.py
tiered_imagenet.py
中定义了各自数据集的类对象, 且都继承自 BaseFweShotDataset
类。该类是 torch.utils.data.Dataset
的子类。
1 BaseFewshotDataset
BaseFewshotDataset
定义了基础的初始化方法、数据读取和分析方法。 部分方法是抽象定义的,部分类成员对象的预定义值为None
,所以BaseFewShotDataset
在实际使用中是无法实例化的,其作为不同客制化类对象的父类,提供代码规范并实现了基础方法,预定义了一些抽象类成员。
1.1 魔方方法
BaseFweShotDataset
定义了数据集类的初始化方法 __init__
。
def __init__(self,
data_prefix: str, # 数据集文件路径
pipeline: List[Dict], # 读取、数据增强、totensor 方式。
classes: Optional[Union[str, List[str]]] = None,
ann_file: Optional[str] = None) -> None: # annotation file 路径
super().__init__()
self.ann_file = ann_file
self.data_prefix = data_prefix
assert isinstance(pipeline, list), 'pipeline is type of list'
self.pipeline = Compose(pipeline)
self.CLASSES = self.get_classes(classes)
self.data_infos = self.load_annotations()
# build a look up dict for sampling instances by class id
self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
for idx, data_info in enumerate(self.data_infos):
self.data_infos_class_dict[data_info['gt_label'].item()].append(idx)
在 __init__
中, 并未显示的读入数据,成员对象记录了数据集文件和标签文件的路径、生成了包含每个image file信息 的列表、并保存了真实类名self.CLASSES
等。self.pipeline
指示了 data preprocess 的方式( load channel 3 RGB-> augmentations-> tensor)。self.data_infos_class_dict
建立了gt_label
(数字化标签 from 0) 到 imag idx 的映射。
较为重要的是 self.load_annotations
, 该方法在 BaseFweShotDataset
中仅定义了抽象方法。子类中(MiniImageNetDataset
TieredImageNetDataset
)必须根据数据集的结构进行客制化实现。 以MiniImageNetDataset
类为例,其 load_annotations
方法 return 1个data_infos
列表,以 dict
为元素。每个dict
涵盖了数据集中每个图片的全部必要信息,于MiniImageNet
中, dict
结构如下:
info = {
'img_prefix':osp.join(self.data_prefix, 'images', class_name),
'img_info': {'filename': filename}, # 前两个键值用于在__getitem__中load image。
'gt_label':np.array(gt_label, dtype=np.int64) # 每个imag 样本对应的标签(0~63 for trianing set)。
}
BaseFewShotDataset
实现了一些基础的 data 读取、解析方法。其中,__len__
、 __getitem__
方法依赖于 self.data_infos
列表,二者是通用的实现范式,子类中基本上不需要重写。不同的数据集可能需要的self.pipeline
不同。
def __len__(self) -> int:
"""Return length of the dataset."""
return len(self.data_infos)
# __getitem__ 是较为通用的实现范式,从self.data_infos 列表中根据 idx 获得某 image的文件信息,
# 再 call self.pipeline
def __getitem__(self, idx: int) -> Dict:
return self.prepare_data(idx)
def prepare_data(self, idx: int) -> Dict:
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)
1.2 classmethod
用@classmethod
装饰方法,可以在不实例化 dataset 类对象的前提下调用该方法。
BaseFewshotDataset
实现了get_classes(cls, classes)
, 可以便利化查看该 dataset 类对象的真实类名(wnid)。
一般 classes
参数传入None
,get_classes
在内部调用 预定义的类成员数据CLASSES
,返回classes_names
给实例成员self.CLASSES
。子类中,需要根据数据集来预定义CLASSES
。
1.3 实例方法
除了在__getiem__
中调用的 prepare_data
方法外,BaseFewShotDataset
实现了sample_shots_by_class_id
方法,根据class_id
随机选择 num_shots
个该类别image的idx,契合 few shot learning 的任务需求。
def sample_shots_by_class_id(self, class_id: int,
num_shots: int) -> List[int]:
"""Random sample shots of given class id."""
all_shot_ids = self.data_infos_class_dict[class_id]
return np.random.choice(
all_shot_ids, num_shots, replace=False).tolist()
1.4 property
使用@property
将方法装饰为成员变量。class_to_idx(self)
,返回 gt_label
到真实类名wnid
的字典映射。gt_label
根据self.CLASSES
类名列表的顺序生成。
@classmethod
def class_to_idx(self) -> Mapping:
"""Map mapping class name to class index.
Returns:
dict: mapping from class name to class index.
"""
return {_class: i for i, _class in enumerate(self.CLASSES)}
1.5 staticmethod
在BaseFewshotDataset
类中实现了 测试方法。该方法是统一的实现范式,子类不需要重写。
@staticmethod
evaluate(results: List,
gt_labels: np.array,
metric: Union[str, List[str]] = 'accuracy',
metric_options: Optional[dict] = None,
logger: Optional[object] = None) -> Dict
2 MiniImageNetDataset
为不同数据集实现的 “TypeDataset” 继承自 BaseFewshotDataset
,且必须重写load_annotations
、补全CLASSES
。根据数据集结构的具体需求,get_classes
方法一般也要客制化。在__init__
方法中,根据任务模式(train\val\test),引入str
对象self.subset
,根据任务模式,控制load_annotations
方法中读入不同的CLASSES
。
@DATASETS.register_module()
class MiniImageNetDataset(BaseFewShotDataset):
"""MiniImageNet dataset for few shot classification.
Args:
subset (str| list[str]): The classes of whole dataset are split into
three disjoint subset: train, val and test. If subset is a string,
only one subset data will be loaded. If subset is a list of
string, then all data of subset in list will be loaded.
Options: ['train', 'val', 'test']. Default: 'train'.
file_format (str): The file format of the image. Default: 'JPEG'
"""
resource = 'https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet' # noqa
TRAIN_CLASSES = TRAIN_CLASSES #
VAL_CLASSES = VAL_CLASSES
TEST_CLASSES = TEST_CLASSES
def __init__(self,
subset: Literal['train', 'test', 'val'] = 'train',
file_format: str = 'JPEG',
*args,
**kwargs):
if isinstance(subset, str):
subset = [subset]
for subset_ in subset:
assert subset_ in ['train', 'test', 'val']
self.subset = subset
self.file_format = file_format
super().__init__(*args, **kwargs)
对于MiniImageNetDaset
而言,其实只根据数据集结构重写了get_classes
load_annotations
方法,定义了类变量“CLASSES"。流程是很简单的。
TRAIN_CLASSES = TRAIN_CLASSES # 类变量
VAL_CLASSES = VAL_CLASSES
TEST_CLASSES = TEST_CLASSES
3 dataset_wrappers.py
在 mmfewshot/classification/datasets/*
目录下,还有一个dataset_wrappers.py
。其中,实现了EpisodicDataset
、MetaTestDataset
,是 MiniImageNetDataset
类的 wrapper
,在__init__
中用self.daaset
接收 MiniImageNetDataset
对象,加以包裹。
3.1 EpisodicDataset
EpisodicDataset
的目的是实现 情景训练 范式。其在__init__
方法中,接收 TypeDataset
对象、episodes总数、way数、shot数、query数、随机种子:
def __init__(self,
dataset: Dataset,
num_episodes: int,
num_ways: int,
num_shots: int,
num_queries: int,
episodes_seed: int | None = None) -> None:
self.dataset = dataset
self.num_ways = num_ways
self.num_shots = num_shots
self.num_queries = num_queries
self.num_episodes = num_episodes
self._len = len(self.dataset)
self.CLASSES = dataset.CLASSES
# using same episodes seed can generate same episodes for same dataset
# it is designed for the reproducibility of meta train or meta test
self.episodes_seed = episodes_seed
self.episode_idxes, self.episode_class_ids = self.generate_episodic_idxes()
EpisodicDataset
类并非是子类,其直接在初始化方法中接收MiniImageNetDataset
类作为类成员。在__init__
中,最为关键的是调用了self.generate_episodic_idxes
方法,该方法依据自定义的随机种子,按照情景训练范式,生成 2 个长度为 self.num_episodes
的列表,episode_idxes
, episode_class_ids
。
episode_class_idx
,该列表中的每个元素为列表,是某个episode
sample 出的self.num_ways
个类别的gt_label
,即数字化标签。
episode_idxes
,该列表中的每个元素为字典,以’support’、'query’为key,值为sample出的长度分别为self.num_shots*self.num_ways
、self.num_queries*self.num_ways
列表,其元素为 image 的 索引。
{ 'support': episodic_support_idx,'query': episodic_query_idx}
EpisodicDataset
类实现了自己的__len__
、__getitem__
方法。__getitem__
接收某episode的索引值,从episode_idxes
中获取该episode需要的 support、query样本对应的索引,并调用self.dataset.__getitem__
返回样本数据。
def __getitem__(self, idx: int) -> dict:
"""Return a episode data at the same time.
For `EpisodicDataset`, this function would return num_ways *
num_shots support images and num_ways * num_queries query image.
"""
return {
'support_data':
[self.dataset[i] for i in self.episode_idxes[idx]['support']],
'query_data':
[self.dataset[i] for i in self.episode_idxes[idx]['query']]
}
def __len__(self) -> int:
"""The length of the dataset is the number of generated episodes."""
return self.num_episodes
3.2 MetaTestDataset
继承自EpisodicDataset
, 引入了self._mode
,来更改__getitem__
返回样本数据的方式。
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._mode = 'test_set'
self._task_id = 0
self._with_cache_feats = False
def __getitem__(self, idx: int) -> dict:
"""Return data according to mode.
For mode `test_set`, this function would return single image as regular
dataset. For mode `support`, this function would return single support
image of current episode. For mode `query`, this function would return
single query image of current episode. If the dataset have cached the
extracted features from fixed backbone, then the features will be
return instead of image.
"""
if self._mode == 'test_set':
idx = idx
elif self._mode == 'support':
idx = self.episode_idxes[self._task_id]['support'][idx]
elif self._mode == 'query':
idx = self.episode_idxes[self._task_id]['query'][idx]
if self._with_cache_feats:
return {
'feats': self.dataset.data_infos[idx]['feats'],
'gt_label': self.dataset.data_infos[idx]['gt_label']
}
else:
return self.dataset[idx]