MMfewshot之datasets

简要记录mmfewshot学习的一些体会

一 Dataset

mmfewshot/classfication/datasets/* 路径下包含了各类数据集的.py文件。mini_imagenet.py tiered_imagenet.py 中定义了各自数据集的类对象, 且都继承自 BaseFweShotDataset 类。该类是 torch.utils.data.Dataset 的子类。

1 BaseFewshotDataset

BaseFewshotDataset 定义了基础的初始化方法、数据读取和分析方法。 部分方法是抽象定义的,部分类成员对象的预定义值为None所以BaseFewShotDataset在实际使用中是无法实例化的,其作为不同客制化类对象的父类,提供代码规范并实现了基础方法,预定义了一些抽象类成员。

1.1 魔方方法

BaseFweShotDataset 定义了数据集类的初始化方法 __init__

def __init__(self,
         data_prefix: str, # 数据集文件路径
         pipeline: List[Dict], # 读取、数据增强、totensor 方式。
         classes: Optional[Union[str, List[str]]] = None, 
         ann_file: Optional[str] = None) -> None: # annotation file 路径
    super().__init__()
    
    self.ann_file = ann_file
    self.data_prefix = data_prefix
    assert isinstance(pipeline, list), 'pipeline is type of list'
    self.pipeline = Compose(pipeline)
    self.CLASSES = self.get_classes(classes)
    self.data_infos = self.load_annotations()
    
    # build a look up dict for sampling instances by class id
    self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
    for idx, data_info in enumerate(self.data_infos):
    self.data_infos_class_dict[data_info['gt_label'].item()].append(idx)

__init__ 中, 并未显示的读入数据,成员对象记录了数据集文件和标签文件的路径、生成了包含每个image file信息 的列表、并保存了真实类名self.CLASSES等。self.pipeline指示了 data preprocess 的方式( load channel 3 RGB-> augmentations-> tensor)。self.data_infos_class_dict建立了gt_label(数字化标签 from 0) 到 imag idx 的映射。

较为重要的是 self.load_annotations , 该方法在 BaseFweShotDataset 中仅定义了抽象方法。子类中(MiniImageNetDataset TieredImageNetDataset)必须根据数据集的结构进行客制化实现。MiniImageNetDataset类为例,其 load_annotations方法 return 1个data_infos 列表,以 dict 为元素。每个dict涵盖了数据集中每个图片的全部必要信息,于MiniImageNet中, dict结构如下:

info = {
        'img_prefix':osp.join(self.data_prefix, 'images', class_name),
        'img_info': {'filename': filename}, # 前两个键值用于在__getitem__中load image。
        'gt_label':np.array(gt_label, dtype=np.int64) # 每个imag 样本对应的标签(0~63 for trianing set)。
        }

BaseFewShotDataset 实现了一些基础的 data 读取、解析方法。其中,__len____getitem__方法依赖于 self.data_infos 列表,二者是通用的实现范式,子类中基本上不需要重写。不同的数据集可能需要的self.pipeline不同。

def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.data_infos)
        
# __getitem__ 是较为通用的实现范式,从self.data_infos 列表中根据 idx 获得某 image的文件信息,
# 再 call self.pipeline
def __getitem__(self, idx: int) -> Dict:
    return self.prepare_data(idx)
    
def prepare_data(self, idx: int) -> Dict:
    results = copy.deepcopy(self.data_infos[idx])
    return self.pipeline(results)

1.2 classmethod

@classmethod装饰方法,可以在不实例化 dataset 类对象的前提下调用该方法。
BaseFewshotDataset 实现了get_classes(cls, classes), 可以便利化查看该 dataset 类对象的真实类名(wnid)。
一般 classes参数传入Noneget_classes 在内部调用 预定义的类成员数据CLASSES,返回classes_names 给实例成员self.CLASSES子类中,需要根据数据集来预定义CLASSES

1.3 实例方法

除了在__getiem__中调用的 prepare_data方法外,BaseFewShotDataset 实现了sample_shots_by_class_id方法,根据class_id随机选择 num_shots个该类别image的idx,契合 few shot learning 的任务需求。

def sample_shots_by_class_id(self, class_id: int,
                             num_shots: int) -> List[int]:
    """Random sample shots of given class id."""
    all_shot_ids = self.data_infos_class_dict[class_id]
    return np.random.choice(
        all_shot_ids, num_shots, replace=False).tolist()

1.4 property

使用@property将方法装饰为成员变量。class_to_idx(self),返回 gt_label 到真实类名wnid的字典映射。gt_label根据self.CLASSES类名列表的顺序生成。

@classmethod
def class_to_idx(self) -> Mapping:
    """Map mapping class name to class index.

    Returns:
        dict: mapping from class name to class index.
    """
    return {_class: i for i, _class in enumerate(self.CLASSES)}

1.5 staticmethod

BaseFewshotDataset类中实现了 测试方法。该方法是统一的实现范式,子类不需要重写。

@staticmethod
evaluate(results: List,
         gt_labels: np.array,
         metric: Union[str, List[str]] = 'accuracy',
         metric_options: Optional[dict] = None,
         logger: Optional[object] = None) -> Dict

2 MiniImageNetDataset

为不同数据集实现的 “TypeDataset” 继承自 BaseFewshotDataset且必须重写load_annotations补全CLASSES。根据数据集结构的具体需求,get_classes方法一般也要客制化。在__init__方法中,根据任务模式(train\val\test),引入str对象self.subset,根据任务模式,控制load_annotations方法中读入不同的CLASSES

@DATASETS.register_module()
class MiniImageNetDataset(BaseFewShotDataset):
    """MiniImageNet dataset for few shot classification.

    Args:
        subset (str| list[str]): The classes of whole dataset are split into
            three disjoint subset: train, val and test. If subset is a string,
            only one subset data will be loaded. If subset is a list of
            string, then all data of subset in list will be loaded.
            Options: ['train', 'val', 'test']. Default: 'train'.
        file_format (str): The file format of the image. Default: 'JPEG'
    """

    resource = 'https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet'  # noqa

    TRAIN_CLASSES = TRAIN_CLASSES  # 
    VAL_CLASSES = VAL_CLASSES
    TEST_CLASSES = TEST_CLASSES

    def __init__(self,
                 subset: Literal['train', 'test', 'val'] = 'train',
                 file_format: str = 'JPEG',
                 *args,
                 **kwargs):
        if isinstance(subset, str):
            subset = [subset]
        for subset_ in subset:
            assert subset_ in ['train', 'test', 'val']
        self.subset = subset
        self.file_format = file_format
        super().__init__(*args, **kwargs)

对于MiniImageNetDaset而言,其实只根据数据集结构重写了get_classes load_annotations方法,定义了类变量“CLASSES"。流程是很简单的。

TRAIN_CLASSES = TRAIN_CLASSES  # 类变量
VAL_CLASSES = VAL_CLASSES
TEST_CLASSES = TEST_CLASSES

3 dataset_wrappers.py

mmfewshot/classification/datasets/*目录下,还有一个dataset_wrappers.py。其中,实现了EpisodicDatasetMetaTestDataset,是 MiniImageNetDataset类的 wrapper,在__init__中用self.daaset接收 MiniImageNetDataset对象,加以包裹。

3.1 EpisodicDataset

EpisodicDataset的目的是实现 情景训练 范式。其在__init__方法中,接收 TypeDataset对象、episodes总数、way数、shot数、query数、随机种子:

def __init__(self,
             dataset: Dataset,
             num_episodes: int,
             num_ways: int,
             num_shots: int,
             num_queries: int,
             episodes_seed: int | None = None) -> None:
    self.dataset = dataset
    self.num_ways = num_ways
    self.num_shots = num_shots
    self.num_queries = num_queries
    self.num_episodes = num_episodes
    self._len = len(self.dataset)
    self.CLASSES = dataset.CLASSES
    # using same episodes seed can generate same episodes for same dataset
    # it is designed for the reproducibility of meta train or meta test
    self.episodes_seed = episodes_seed
    self.episode_idxes, self.episode_class_ids = self.generate_episodic_idxes()

EpisodicDataset类并非是子类,其直接在初始化方法中接收MiniImageNetDataset类作为类成员。在__init__中,最为关键的是调用了self.generate_episodic_idxes方法,该方法依据自定义的随机种子,按照情景训练范式,生成 2 个长度为 self.num_episodes的列表,episode_idxes, episode_class_ids

episode_class_idx,该列表中的每个元素为列表,是某个episode sample 出的self.num_ways个类别的gt_label,即数字化标签。

episode_idxes,该列表中的每个元素为字典,以’support’、'query’为key,值为sample出的长度分别为self.num_shots*self.num_waysself.num_queries*self.num_ways列表,其元素为 image 的 索引。

{ 'support': episodic_support_idx,'query': episodic_query_idx}

EpisodicDataset类实现了自己的__len____getitem__方法。__getitem__接收某episode的索引值,episode_idxes中获取该episode需要的 support、query样本对应的索引,并调用self.dataset.__getitem__返回样本数据。

def __getitem__(self, idx: int) -> dict:
    """Return a episode data at the same time.

    For `EpisodicDataset`, this function would return num_ways *
    num_shots support images and num_ways * num_queries query image.
    """

    return {
        'support_data':
        [self.dataset[i] for i in self.episode_idxes[idx]['support']],
        'query_data':
        [self.dataset[i] for i in self.episode_idxes[idx]['query']]
    }

def __len__(self) -> int:
    """The length of the dataset is the number of generated episodes."""
    return self.num_episodes

3.2 MetaTestDataset

继承自EpisodicDataset, 引入了self._mode,来更改__getitem__返回样本数据的方式。

 def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self._mode = 'test_set'
    self._task_id = 0
    self._with_cache_feats = False

def __getitem__(self, idx: int) -> dict:
    """Return data according to mode.

    For mode `test_set`, this function would return single image as regular
    dataset. For mode `support`, this function would return single support
    image of current episode. For mode `query`, this function would return
    single query image of current episode. If the dataset have cached the
    extracted features from fixed backbone, then the features will be
    return instead of image.
    """

    if self._mode == 'test_set':
        idx = idx
    elif self._mode == 'support':
        idx = self.episode_idxes[self._task_id]['support'][idx]
    elif self._mode == 'query':
        idx = self.episode_idxes[self._task_id]['query'][idx]

    if self._with_cache_feats:
        return {
            'feats': self.dataset.data_infos[idx]['feats'],
            'gt_label': self.dataset.data_infos[idx]['gt_label']
        }
    else:
        return self.dataset[idx]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值