大批量图像处理(7)——用各种各样的常见模型识别分类大量的图片(绝对干货)

本文介绍了一种使用深度学习模型进行图像分类的评估方法,通过加载不同的模型结构和参数,对预处理的图像数据进行预测,并与目标类别进行比较,以计算分类准确率。

准备如下:

所有常见模型的网络结构文件
在这里插入图片描述
需要用到的各种模型的参数文件
在这里插入图片描述

代码如下:

换模型需要改动的地方并不多,会在代码里注释到~~

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import os

from cleverhans.attacks import FastGradientMethod
from io import BytesIO
import numpy as np
from PIL import Image
from scipy.misc import imread
from scipy.misc import imsave
import tensorflow as tf
#from tensorflow.contrib.slim.nets import inception
from nets import inception_v3, inception_resnet_v2                 #改动地方之一,换网络结构,加上去就行了

slim = tf.contrib.slim
tensorflow_master = ""
checkpoint_path   = "/home/caad/workspace/CAAD2018/dataset/models/inception_v3.ckpt"  #改动地方之一,ckpt文件
input_csv         = "/home/caad/workspace/CAAD2018/dataset/images"
input_dir         ="/home/NEWDISK/output_champion"
max_epsilon       = 16.0
image_width       = 299
image_height      = 299
batch_size        = 10

eps = 2.0 * max_epsilon / 255.0
batch_shape = [batch_size, image_height, image_width, 3]
num_classes = 1001

def load_images(input_dir, batch_shape):
    images = np.zeros(batch_shape)
    filenames = []
    idx = 0
    batch_size = batch_shape[0]
    for filepath in sorted(tf.gfile.Glob(os.path.join(input_dir, '*.png'))):
        with tf.gfile.Open(filepath, "rb") as f:
            images[idx, :, :, :] = imread(f, mode='RGB').astype(np.float)*2.0/255.0 - 1.0
        filenames.append(os.path.basename(filepath))
        idx += 1
        if idx == batch_size:
            yield filenames, images
            filenames = []
            images = np.zeros(batch_shape)
            idx = 0
    if idx > 0:
        yield filenames, images

def load_target_class(input_dir):
  """Loads target classes."""
  with tf.gfile.Open(os.path.join(input_dir, 'target_class.csv')) as f:
    return {row[0]+'.png': int(row[7]) for row in csv.reader(f) if len(row) >= 7}

all_images_taget_class = load_target_class(input_csv)
sum=0
right_number = 0
with tf.Graph().as_default():

    x_input = tf.placeholder(tf.float32, shape=batch_shape)
	'''
    六个常见模型,选择其中一个,把其他五个挡住,需要其他要额外添加
    '''
        '''
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        _, end_points = inception_v3.inception_v3(x_input, num_classes=num_classes, is_training=False)
        
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        _, end_points = inception_v3.inception_v3(
            x_input, num_classes=num_classes, is_training=False, scope='AdvInceptionV3')

    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        _, end_points = inception_v3.inception_v3(
            x_input, num_classes=num_classes, is_training=False, scope='Ens3AdvInceptionV3')

    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        _, end_points = inception_v3.inception_v3(
            x_input, num_classes=num_classes, is_training=False, scope='Ens4AdvInceptionV3')

    with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        _, end_points = inception_resnet_v2.inception_resnet_v2(
            x_input, num_classes=num_classes, is_training=False, scope='EnsAdvInceptionResnetV2')

    with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        _, end_points = inception_resnet_v2.inception_resnet_v2(
            x_input, num_classes=num_classes, is_training=False, scope='AdvInceptionResnetV2')
        '''

    predicted_labels = tf.argmax(end_points['Predictions'], 1)

    saver = tf.train.Saver(slim.get_model_variables())
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=tf.train.Scaffold(saver=saver),
        checkpoint_filename_with_path=checkpoint_path,
        master=tensorflow_master)

    with tf.train.MonitoredSession(session_creator=session_creator) as sess:
        for filenames, images in load_images(input_dir, batch_shape):
            target_class_for_batch = (
                [all_images_taget_class[n] for n in filenames]
                + [0] * (batch_size - len(filenames)))
            predicted_targeted_classes = sess.run(predicted_labels, feed_dict={x_input: images})

            for i in range(len(images)):
                if(predicted_targeted_classes[i]==target_class_for_batch[i]):
                    right_number+=1
                if(target_class_for_batch[i]!=0):
                    sum+=1
                print("TARGETED ADVERSARIAL IMAGE",sum,
                      "\n\tPredicted class:", predicted_targeted_classes[i],
                      "\n\tattack class:", target_class_for_batch[i])
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~over~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print("\n\t分类图片总数: ", sum)
        print("\n\t攻击成功数量: ",right_number)
        print("\n\taccuracy:",right_number/sum)






评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wujiekd

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值