pytorch复现RRU-Net

论文地址

https://openaccess.thecvf.com/content_CVPRW_2019/html/CV-COPS/Bi_RRU-Net_The_Ringed_Residual_U-Net_for_Image_Splicing_Forgery_Detection_CVPRW_2019_paper.html

开源代码

https://github.com/yelusaleng/RRU-Net

下面就是复现的代码

使用的数据集是NIST

项目结构

详细代码

./dataload/NIST_data.py

import torch
import torchvision.transforms as tfs
import os
from PIL import Image
import time

LINUX = 0

if(LINUX==0):
    prefix = "G:/data/"
else:
    prefix = "/home/yyl/data/"

HEIGHT = 256
WIDTH = 256

class NIST_DATA(object):
    def __init__(self, mode="train"):
        super(NIST_DATA, self).__init__()
        self.nist_data_root = prefix + "NIST/"
        self.image_name = []
        self.gt_name = []
        self.readImage()


        # 图像正则化参数
        self.mean = [0.40789654, 0.44719302, 0.47026115]
        self.std = [0.28863828, 0.27408164, 0.27809835]
        self.im_tfs = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize(self.mean, self.std)
        ])
        self.im_tfs_gt = tfs.Compose([tfs.ToTensor()])

        #self.filter()
        # 训练集配置前80%的图片,测试集配置后20%的图片
        if(mode == "train"):
            self.image_name = self.image_name[:round(len(self.image_name)*0.8)]
            self.gt_name = self.gt_name[:round(len(self.gt_name)*0.8)]
        else:
            self.image_name = self.image_name[round(len(self.image_name) * 0.8):]
            self.gt_name = self.gt_name[round(len(self.gt_name) * 0.8):]

        print("%s->一共加载了%d张图片"%(mode,len(self.image_name)))



    # 读取图片
    def readImage(self):
        image_root = self.nist_data_root + "1/"
        gt_root = self.nist_data_root + "1_mask/"
        filename = os.listdir(image_root)
        for i in range(len(filename)):
            self.image_name.append(image_root + filename[i])
            self.gt_name.append(gt_root + filename[i].split(".")[0]+".png")

        image_root = self.nist_data_root + "2/"
        gt_root = self.nist_data_root + "2_mask/"
        filename = os.listdir(image_root)
        for i in range(len(filename)):
            self.image_name.append(image_root + filename[i])
            self.gt_name.append(gt_root + filename[i].split(".")[0] + ".png")

    def filter(self, thresold=10):

        tmp_image_name = self.image_name.copy()
        tmp_gt_name = self.gt_name.copy()
        for idx in range(len(self.image_name)):
            img = Image.open(self.image_name[idx])
            img_gt = Image.open(self.gt_name[idx])
            img, img_gt = self.image_transform(img, img_gt)
            if(img_gt.sum()<thresold):
                #print("删除掉%d"%idx)
                tmp_image_name.remove(self.image_name[idx])
                tmp_gt_name.r
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值