GridMask是2020年arXiv上的一篇论文,可以认为是直接对标Hide_and_Seek方法。与之不同的是,GridMask采用了等间隔擦除patch的方式,有点类似空洞卷积,或许可以取名叫空洞擦除?
GridMask Data Augmentation
paper: https://arxiv.org/pdf/2001.04086
code: GitHub - dvlab-research/GridMask
核心操作如下图所示,看一下基本就能明白。
实现代码相比于其他方法要复杂一些,如下:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import pdb
import math
class Grid(object):
def __init__(self, d1, d2, rotate=1, ratio=0.5, mode=0, prob=1.):
self.d1 = d1
self.d2 = d2
self.rotate = rotate
self.ratio = ratio
self.mode = mode
self.st_prob = self.prob = prob
def set_prob(self, epoch, max_epoch):
self.prob = self.st_prob * min(1, epoch / max_epoch)
def __call__(self, img):
if np.random.rand() > self.prob:
return img
h = img.size(1)
w = img.size(2)
# 1.5 * h, 1.5 * w works fine with the squared images
# But w