import cv2
import torch
import torch.nn.functional as F
import numpy as np
import time
# Default compute device: the first CUDA GPU when one is available, otherwise the
# CPU, so machines without a GPU get a usable device instead of one that fails on
# first use.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def torch_inpaint(pil_img, mask, *, box_ks=15, verbose=True):
    """Fill a labelled region of an image with the local mean of nearby valid pixels.

    The region to fill is every pixel flagged in mask channel 0 but not in
    channel 1. It is grown by two 4-connected (cross-shaped) dilation steps.
    Each region pixel is replaced by the per-channel mean of the *valid* pixels
    inside a ``box_ks`` x ``box_ks`` window. Valid pixels are those outside the
    grown region and not flagged in channel 1. A region pixel with no valid pixel
    in its window gets a fallback value instead: the mean intensity, over all
    image channels, of the pixels flagged in mask channel 2.

    Args:
        pil_img: image tensor shaped (1, C, H, W). Only batch size 1 is
            supported. It is not modified.
        mask: label tensor shaped (1, 3, H, W). Any dtype is accepted; nonzero
            means "set". Channel 0 marks the region to fill, channel 1 marks
            pixels that are never used as averaging sources, and channel 2 marks
            the pixels whose mean is the fallback value.
        box_ks: side length of the square averaging window. It must be a
            positive odd integer so the output keeps the input's size.
        verbose: if true, print the elapsed wall-clock seconds.

    Returns:
        numpy ``uint8`` array of shape (H, W, C). Float values are truncated
        toward zero by ``astype``.

    Raises:
        ValueError: if ``box_ks`` is not a positive odd integer.
    """
    start = time.perf_counter()
    if box_ks < 1 or box_ks % 2 == 0:
        raise ValueError(f"box_ks must be a positive odd integer, got {box_ks}")

    dev = pil_img.device
    # Conv weights below are float32. This is a no-op for float32 input, and the
    # tensor is never written to.
    img = pil_img.to(torch.float32)
    # Boolean masks: uint8 masks are deprecated for indexing in PyTorch.
    labels = (mask != 0).to(dev)
    r = labels[:, 0:1]
    g = labels[:, 1:2]
    b = labels[:, 2:3]

    # Region to fill, grown by two cross-shaped dilation steps so its border is
    # replaced as well.
    cross = torch.tensor([[[[0, 1, 0], [1, 1, 1], [0, 1, 0]]]], dtype=torch.float32, device=dev)
    grown = (r & ~g).to(torch.float32)
    for _ in range(2):
        grown = F.conv2d(grown, cross, padding=1)
    region = grown > 0

    # Fallback fill value, taken from the unmodified image. An empty selection
    # would otherwise produce NaN.
    bg_values = img.masked_select(b)
    fallback = bg_values.mean() if bg_values.numel() else img.new_zeros(())

    # Windowed sum of valid pixels (invalid ones contribute 0) and count of valid
    # pixels. The count is the same for every channel, so compute it once and
    # broadcast it.
    valid = ~(region | g)
    sources = img.masked_fill(~valid, 0.0)
    channels = img.shape[1]
    pad = box_ks // 2
    box = torch.ones((channels, 1, box_ks, box_ks), dtype=torch.float32, device=dev)
    window_sum = F.conv2d(sources, box, padding=pad, groups=channels)
    window_count = F.conv2d(valid.to(torch.float32), box[:1], padding=pad)

    # Decide "no valid neighbour" from the count, not from a zero result, so
    # regions surrounded by genuinely black pixels stay black.
    local_mean = window_sum / window_count.clamp(min=1)
    filled = torch.where(window_count > 0, local_mean, fallback)
    result = torch.where(region, filled, img)

    # Index the batch dim explicitly; squeeze() would also drop H or W == 1.
    out = result[0].permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
    if verbose:
        print(time.perf_counter() - start)
    return out
if __name__ == '__main__':
    # Demo: inpaint sample image 331800 using its label map and write the result.
    img = cv2.imread("./331800.jpg")
    mask = cv2.imread("./331800.png", cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports a missing or unreadable file by returning None, not by raising.
    if img is None or mask is None:
        raise SystemExit("could not read ./331800.jpg or ./331800.png")
    if img.shape[:2] != mask.shape:
        raise SystemExit(f"image size {img.shape[:2]} and label map size {mask.shape} differ")
    # One-hot encode the label map into the channel layout torch_inpaint reads:
    # channel 0 <- label 2, channel 1 <- label 1, channel 2 <- label 0.
    mask_one_hot = np.stack([mask == 2, mask == 1, mask == 0], axis=-1).astype(np.uint8)
    # HWC numpy -> (1, C, H, W) tensors. from_numpy avoids the slow
    # list-of-ndarray construction of torch.tensor([...]).
    im_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=torch.float32)
    mask_tensor = torch.from_numpy(mask_one_hot).permute(2, 0, 1).unsqueeze(0).to(device)
    output_im = torch_inpaint(im_tensor, mask_tensor)
    if not cv2.imwrite("output_im.png", output_im):
        raise SystemExit("could not write output_im.png")