TensorFlow (Neural Networks) Study Notes (5): Image-to-Text in Practice (Notes)

Code to download the inception_v3 model:

#%%

# coding: utf-8
 
import tensorflow as tf
import os
import tarfile
import requests
 
# Download URL for the Inception model
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
 
# Directory to store the model
inception_pretrain_model_dir = "inception_model"
if not os.path.exists(inception_pretrain_model_dir):
    os.makedirs(inception_pretrain_model_dir)
 
# Get the file name and file path
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename)
 
# Download the model if it is not already present
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_pretrain_model_url, stream=True)
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
print("finish: ", filename)
# Extract the archive
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)
 
# Directory for the graph-structure log
log_dir = 'inception_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
 
# classify_image_graph_def.pb is the model pre-trained by Google
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
    # Create a graph to hold Google's pre-trained model
    with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    # Save the graph structure
    writer = tf.summary.FileWriter(log_dir, sess.graph)
    writer.close()

Dataset: flickr30K_images
Captions: each image has 5 captions. Each line starts with the image name plus '#' and a caption index, separated from the caption text by a Tab.
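For reference, a minimal sketch of how one such line is parsed (the sample line is hypothetical, following the format above):

line = "1000092795.jpg#0\tTwo young men stand in a yard ."  # hypothetical sample line
img_id, description = line.split('\t')      # -> "1000092795.jpg#0", caption text
img_name, caption_idx = img_id.split('#')   # -> "1000092795.jpg", "0"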
1. Vocabulary counting


import os
import sys
import pprint


input_description_file = r""  # path to the token file, e.g. results_20130124.token
output_vocab_file = r""       # path to write vocab.txt

def count_vocab(input_description_file):
    """
    Genenrates vocabulary.
    In addition, count distribution od length of sentence
    and max legnth of image description.
    :param input_description_file: 
    :return: 
    """
    with open(input_description_file, 'r') as f:
        lines = f.readlines()
    max_length_of_sentences = 0
    length_dict = {}
    vocab_dict = {}
    for line in lines:
        image_id, description = line.strip('\n').split('\t')
        words = description.strip(' ').split()
        max_length_of_sentences = max(max_length_of_sentences, len(words))
        length_dict.setdefault(len(words), 0)
        length_dict[len(words)] += 1
        
        for word in words:
            vocab_dict.setdefault(word, 0)
            vocab_dict[word] += 1
    print(max_length_of_sentences)
    pprint.pprint(length_dict)
    return vocab_dict
'''By counting the sentence-length distribution we can pick a fixed sentence length for training; 40 is a reasonable cutoff (see the coverage sketch after this script).'''
vocab_dict = count_vocab(input_description_file)

sorted_vocab_dict = sorted(vocab_dict.items(), key= lambda d:d[1], reverse=True)

with open(output_vocab_file, 'w') as f:
    f.write('<UNK>\t10000000\n')
    for item in sorted_vocab_dict:
        f.write('%s\t%d\n' % item)
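As a minimal sketch (assuming the length_dict computed by count_vocab above), the cutoff can be chosen as the shortest length that covers, say, 99% of all captions:

def pick_length_cutoff(length_dict, coverage=0.99):
    """Returns the smallest length covering `coverage` of all sentences."""
    total = sum(length_dict.values())
    accumulated = 0
    for length in sorted(length_dict):
        accumulated += length_dict[length]
        if accumulated >= coverage * total:
            return length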

The generated vocab file format: one word per line, followed by a Tab and its occurrence count, sorted by frequency, with <UNK> forced to the top.
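For example (the counts are illustrative):

<UNK>	10000000
a	62093
in	30912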
Use inception_v3 to extract features from the images.



import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import numpy as np
import pickle 


model_file = r".\deep_learn\image2text\classify_image_graph_def.pb"
input_description_file = r".\deep_learn\image2text\results_20130124.token"
input_img_dir = r".\deep_learn\image2text\flickr30k-images"
output_folder = r".\deep_learn\image2text\download_inception_v3_features"

batch_size = 1000
if not gfile.Exists(output_folder):
    gfile.MakeDirs(output_folder)

def parse_token_file(token_file):
    """Parses image description file."""
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()
    
    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)
    return img_name_to_tokens

img_name_to_tokens = parse_token_file(input_description_file)
all_img_names = img_name_to_tokens.keys()

logging.info("num of all images: {}".format(len(all_img_names)))
pprint.pprint(list(img_name_to_tokens.keys())[0:10])
pprint.pprint(img_name_to_tokens['2778832101.jpg'])



def load_pretrained_inception_v3(model_file):
    with gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
load_pretrained_inception_v3(model_file)



# Split the 30K+ images into 30+ small pickle files
num_batches = int(len(all_img_names) / batch_size)
if len(all_img_names) % batch_size != 0:
    num_batches += 1

with tf.Session() as sess:
    second_to_last_tensor = sess.graph.get_tensor_by_name("pool_3:0")
    for i in range(num_batches):
        batch_img_names = list(all_img_names)[i*batch_size: (i+1)*batch_size]
        batch_features = []
        for img_name in batch_img_names:
            img_path = os.path.join(input_img_dir, img_name)
            if not gfile.Exists(img_path):
                tf.logging.info("image {} not found, skipping".format(img_name))
                continue
            img_data = gfile.FastGFile(img_path, "rb").read()
            # Run the image through Inception v3 to get its feature vector
            feature_vector = sess.run(second_to_last_tensor,
                                      feed_dict={
                                          "DecodeJpeg/contents:0": img_data
                                      })
            batch_features.append(feature_vector)
        batch_features = np.vstack(batch_features)
        output_filename = os.path.join(
            output_folder,
            "image_features-%d.pickle" % i
        )
        logging.info("writing to file {} ".format(output_filename))
        with gfile.GFile(output_filename, 'wb') as f:
            pickle.dump((batch_img_names, batch_features), f)
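To sanity-check one of the generated pickle files, a minimal sketch (the path is illustrative and the shapes depend on your run):

import pickle

with open("download_inception_v3_features/image_features-0.pickle", "rb") as f:
    names, features = pickle.load(f)
print(len(names))       # up to batch_size, e.g. 1000
print(features.shape)   # expected (num_images, 1, 1, 2048) from the pool_3 tensor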

Loading the data

"""
步骤
1. Data generator
    a. Loads vocab
    b. Loads image features
    c. provide data for training.
2. Builds image caption model
3. Trains the model
"""
import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np
import math
import random

input_description_file = r".\deep_learn\image2text\results_20130124.token"
input_img_feature_dir = r".\deep_learn\image2text\download_inception_v3_features"
input_vocab_file = r".\deep_learn\image2text\vocab.txt"
output_dir = r".\deep_learn\image2text\local_run"

if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)

def get_default_params():
    return tf.contrib.training.HParams(
        # Filter out low-frequency words
        num_vocab_word_threshold = 3,
        num_embedding_nodes = 32,
        num_timesteps = 10,
        num_lstm_nodes = [64, 64],
        num_lstm_layers = 2,
        num_fc_nodes = 32,
        batch_size = 80,
        cell_type = "lstm",
        clip_lstm_grads = 1.0,
        learning_rate = 0.001,
        keep_prob = 0.8,
        log_frequent = 100,
        save_frequent = 1000,
    )

hps = get_default_params()


class Vocab:
    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}
        self._word_to_id = {}
        self._unk = -1
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        self._read_dict(filename)
        
    def _read_dict(self, filename):
        with gfile.GFile(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            
            word, occurrence = line.strip('\r\n').split('\t')
            occurrence = int(occurrence)
            if occurrence < self._word_num_threshold:
                continue
            idx = len(self._id_to_word)
            if word == '<UNK>':
                self._unk = idx
            elif word == '.':
                self._eos = idx
            if word in self._word_to_id or idx in self._id_to_word:
                raise Exception("duplicate words in vocab")
            self._word_to_id[word] = idx
            self._id_to_word[idx] = word
    
    @property
    def unk(self):
        return self._unk
    @property
    def eos(self):
        return self._eos
    
    def word_to_id(self, word):
        return self._word_to_id.get(word, self._unk)
    
    def id_to_word(self, word_id):
        return self._id_to_word.get(word_id, '<UNK>')
    
    def size(self):
        return len(self._id_to_word)
    
    def encode(self, sentence):
        return [self.word_to_id(word) for word in sentence.split(' ')]
    
    def decode(self, sentence_id):
        words =  [self.id_to_word(word_id) for word_id in sentence_id]
        return ' '.join(words)

vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)
vocab_size = vocab.size()
logging.info('vocab_size : {}'.format(vocab_size))

pprint.pprint(vocab.encode("I have a dream."))
pprint.pprint(vocab.decode([5, 10, 9, 21]))

def parse_token_file(token_file):
    """Parses image description file."""
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r', ) as f:
        lines = f.readlines()
    
    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)
    return img_name_to_tokens

def convert_token_to_id(img_name_to_tokens, vocab):
    """Converts tokens of each description of imgs to id."""
    img_name_to_tokens_id = {}
    for img_name in img_name_to_tokens:
        img_name_to_tokens_id.setdefault(img_name, [])
        for description in img_name_to_tokens[img_name]:
            token_ids = vocab.encode(description)
            img_name_to_tokens_id[img_name].append(token_ids)
    return img_name_to_tokens_id

img_name_to_tokens = parse_token_file(input_description_file)
img_name_to_tokens_id = convert_token_to_id(img_name_to_tokens, vocab)

logging.info("num of all imgs: {}".format(len(img_name_to_tokens)))
pprint.pprint(img_name_to_tokens['2778832101.jpg'])
logging.info("num of all imgs: {}".format(len(img_name_to_tokens_id)))
pprint.pprint(img_name_to_tokens_id['2778832101.jpg'])


class ImageCaptionData:
    """Provides data for image caption model."""
    def __init__(self,
                 img_name_to_tokens_id,
                 img_feature_dir,
                 num_timesteps,
                 vocab,
                 deterministic = False):
        self._vocab = vocab
        self._img_name_to_tokens_id = img_name_to_tokens_id
        # Sentence length to truncate or pad to
        self._num_timesteps = num_timesteps
        # If deterministic, skip random shuffling
        self._deterministic = deterministic
        self._indicator = 0
        # Image file names
        self._img_feature_filenames = []
        # Image feature vectors
        self._img_feature_data = []
        
        self._all_img_feature_filepaths = []
        for filename in gfile.ListDirectory(img_feature_dir):
            if not filename[0] == '.':
                self._all_img_feature_filepaths.append(
                    os.path.join(img_feature_dir, filename)
                )
        pprint.pprint(self._all_img_feature_filepaths)
        self._load_img_feature_pickle()
        
        if not self._deterministic:
            self._random_shuffle()
            
    def _load_img_feature_pickle(self):
        """Loads img feature data from pickle"""
        for filepath in self._all_img_feature_filepaths:
            logging.info("loading {}".format(filepath))
            with gfile.GFile(filepath, 'rb') as f:
                filenames, features = pickle.load(f)
                # Merge the lists of image names
                self._img_feature_filenames += filenames
                self._img_feature_data.append(features)
        # vstack: [(1000, 1, 1, 2048), (1000, 1, 1, 2048), ...] -> (2000, 1, 1, 2048)
        self._img_feature_data = np.vstack(self._img_feature_data)
        origin_shape = self._img_feature_data.shape
        self._img_feature_data = np.reshape(self._img_feature_data,
                                            (origin_shape[0], origin_shape[3]))
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)
        print(self._img_feature_data.shape)
        print(self._img_feature_filenames.shape)
        
    def size(self):
        return len(self._img_feature_filenames)
    
    def _random_shuffle(self):
        """Shuffle data randomly."""
        p = np.random.permutation(self.size())
        self._img_feature_filenames = self._img_feature_filenames[p]
        self._img_feature_data = self._img_feature_data[p]
    # Dimension of an image feature vector
    def img_feature_size(self):
        return self._img_feature_data.shape[1]
    # Choose one description per image
    def _img_desc(self, batch_filenames):
        """Gets description for filenames in batch."""
        batch_sentence_ids = []
        batch_weights = []
        for filename in batch_filenames:
            token_ids_set = self._img_name_to_tokens_id[filename]
            chosen_token_ids = random.choice(token_ids_set)
            chosen_token_ids_length = len(chosen_token_ids)
            # Weight 1 for every real token
            weight = [1 for i in range(chosen_token_ids_length)]
            if chosen_token_ids_length >= self._num_timesteps:
                chosen_token_ids = chosen_token_ids[0: self._num_timesteps]
                weight = weight[0: self._num_timesteps]
            else:
                remaining_length = self._num_timesteps - chosen_token_ids_length
                # Pad with eos
                chosen_token_ids += [self._vocab.eos for i in range(remaining_length)]
                weight += [0 for i in range(remaining_length)]
            batch_sentence_ids.append(chosen_token_ids)
            batch_weights.append(weight)
        batch_sentence_ids = np.asarray(batch_sentence_ids)
        batch_weights = np.asarray(batch_weights)
        return batch_sentence_ids, batch_weights
    def next_batch(self, batch_size):
        """Returns next batch size."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size
        assert end_indicator <= self.size()
        
        batch_filenames = self._img_feature_filenames[self._indicator: end_indicator]
        batch_img_features = self._img_feature_data[self._indicator: end_indicator]
        # batch_weights masks out the padded positions: sentence_ids [100,101,123,4,5,eos,eos,eos] -> weights [1,1,1,1,1,0,0,0], so the padding is excluded from the loss and the gradients
        batch_sentence_ids, batch_weights = self._img_desc(batch_filenames)
        self._indicator = end_indicator
        return batch_img_features, batch_sentence_ids, batch_weights, batch_filenames
    
caption_data = ImageCaptionData(img_name_to_tokens_id,
                                input_img_feature_dir,
                                hps.num_timesteps,
                                vocab)
img_feature_dim = caption_data.img_feature_size()
caption_data_size = caption_data.size()
logging.info("img_feature_dim: {}".format(img_feature_dim))
logging.info("caption_data_size: {}".format(caption_data_size))

batch_img_features, batch_sentence_ids, batch_weights, batch_img_names = caption_data.next_batch(5) 
pprint.pprint(batch_img_features)
pprint.pprint(batch_sentence_ids)
pprint.pprint(batch_weights)
pprint.pprint(batch_img_names)
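The effect of batch_weights can be illustrated with a small standalone numpy example (the loss values are hypothetical): padded steps contribute nothing, and the loss is averaged over real tokens only.

import numpy as np

losses = np.array([2.3, 1.1, 0.7, 4.0, 4.0])  # hypothetical per-step cross-entropy
mask = np.array([1, 1, 1, 0, 0])              # the last two steps are eos padding
masked_loss = np.sum(losses * mask) / np.sum(mask)
print(masked_loss)  # 1.366..., padding excluded from the average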

Training

With the vocabulary and data generator above in place, we now build the caption model (image embedding + LSTM) and train it:

#%%

def create_rnn_cell(hidden_dim, cell_type):
    """Return specific cell according to cell_type."""
    if cell_type == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_dim,
                                            state_is_tuple=True)
    elif cell_type == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_dim)
    
    else:
        raise Exception("{} type has not been supported.".format(cell_type))
    
def dropout(cell, keep_prob):
    """Wrap cell with dropout."""
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob = keep_prob)

def get_train_model(hps, vocab_size, img_feature_dim):
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size
    img_feature = tf.placeholder(tf.float32,
                                 (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32, 
                              (batch_size, num_timesteps))
    # Mask marking real tokens vs padding
    mask = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(tf.zeros([], tf.int32),
                              name='global_step',
                              trainable=False)
    
    # prediction process:
    # sentence: [a, b, c, d, e]
    # input: [img, a, b, c, d]
    # img_feature: [0.4, 0.3, 10, 2]
    # predict #1 : img_feature - >embedding_img - >lstm -> (a)
    # predict #2 : a -> embedding_word -> lstm -> (b)
    # predict #3 : b -> embedding_word -> lstm -> (c)
    # .... The original flow is as above; here we make embedding_img the same size as embedding_word so the image embedding can serve as the first LSTM input
    
    # Sets up embedding layer.
    embedding_initializer = tf.random_uniform_initializer(-0.1, 1.0)
    with tf.variable_scope('embedding',
                           initializer=embedding_initializer):
        # Convert sentence ids to embeddings
        embeddings = tf.get_variable('embedding',
                                     [vocab_size, hps.num_embedding_nodes],
                                     tf.float32)
        # embed_token_ids: [batch_size, num_timesteps-1, num_embedding_nodes]
        embed_token_ids = tf.nn.embedding_lookup(
            embeddings,
            # drop the last token: the image embedding occupies step 0, so word inputs are shifted right by one
            sentence[:, 0: num_timesteps -1 ]
        )
    img_feature_embed_init = tf.uniform_unit_scaling_initializer(
        factor=1.0
    )
    with tf.variable_scope('img_feature_embed',
                           initializer=img_feature_embed_init):
        # img_features: [batch_size, img_feature_dim]
        # embed_img : [batch_size, num_embedding_nodes]
        embed_img = tf.layers.dense(img_feature,
                                    hps.num_embedding_nodes)
        # embed_img : [batch_size, 1, num_embedding_nodes]
        embed_img = tf.expand_dims(embed_img, 1)
        # embed_inputs : [batch_size, num_timesteps, num_embedding_nodes]
        # Concatenate the image embedding with the word embeddings
        embed_inputs = tf.concat([embed_img, embed_token_ids], axis=1)
        
    # Sets up the RNN network
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1]) / 3.0
    rnn_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_nn', initializer= rnn_init):
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cell = dropout(cell, keep_prob)
            cells.append(cell)
        cell = tf.contrib.rnn.MultiRNNCell(cells)
        
        init_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                           embed_inputs,
                                           initial_state = init_state)
    # Sets up fully-connected layer
    fc_init = tf.uniform_unit_scaling_initializer(factor= 1.0)
    with tf.variable_scope('fc', initializer=fc_init):
        rnn_outputs_2d = tf.reshape(rnn_outputs,
                                    [-1, hps.num_lstm_nodes[-1]])
        fc1 = tf.layers.dense(rnn_outputs_2d,
                              hps.num_fc_nodes,
                              name='fc1')
        # tf.layers.dropout takes a drop rate rather than a keep prob, so use tf.nn.dropout here
        fc1_dropout = tf.nn.dropout(fc1, keep_prob)
        fc1_relu = tf.nn.relu(fc1_dropout)
        logits = tf.layers.dense(fc1_relu,
                                 vocab_size,
                                 name='logits')
        
    # Calculates loss
    with tf.variable_scope('loss'):
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])
        mask_sum = tf.reduce_sum(mask_flatten)
        
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=sentence_flatten
        )
        # Zero out the loss on the padded eos positions
        weighted_softmax_loss = tf.multiply(
            softmax_loss, tf.cast(mask_flatten, tf.float32)
        )
        loss = tf.reduce_sum(weighted_softmax_loss) / tf.cast(mask_sum, tf.float32)
        
        prediction = tf.argmax(logits, 1, output_type=tf.int32)
        correct_prediction = tf.equal(prediction,
                                      sentence_flatten)
        weighted_correct_prediction = tf.multiply(
            tf.cast(correct_prediction, tf.float32),
            tf.cast(mask_flatten, tf.float32)
        )
        accuracy = tf.reduce_sum(weighted_correct_prediction) / tf.cast(mask_sum, tf.float32)
        tf.summary.scalar('loss', loss)
        
    # Defines train op.
    with tf.variable_scope('train_op'):
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info('variable name: {}'.format(var.name))
        grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, tvars), hps.clip_lstm_grads
        )
        optimizer = tf.train.AdamOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step= global_step
        )
        
    return ((img_feature, sentence, mask, keep_prob),
            (loss, accuracy, train_op),
            global_step)

placeholders, metrics, global_step = get_train_model(
    hps, vocab_size, img_feature_dim
)
img_feature, sentence, mask, keep_prob = placeholders
loss, accuracy, train_op = metrics
        
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=10)


#%%

training_steps = 1000

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    for i in range(training_steps):
        (batch_img_features,
         batch_sentence_ids,
         batch_weights, _) = caption_data.next_batch(hps.batch_size)
        input_vals = (batch_img_features,
                      batch_sentence_ids,
                      batch_weights,
                      hps.keep_prob)
        feed_dict = dict(zip(placeholders, input_vals))
        fetches = [global_step, loss, accuracy, train_op]
        should_log = (i+1) % hps.log_frequent == 0
        should_save = (i+1) % hps.save_frequent == 0
        
        if should_log:
            fetches += [summary_op]
            
        outputs = sess.run(fetches, feed_dict=feed_dict)
        global_step_val, loss_val, accuracy_val = outputs[0:3]
        if should_log:
            summary_str = outputs[-1]
            writer.add_summary(summary_str, global_step_val)
            logging.info("Step: {}, loss: {:.2f}, acc: {:.2f}".format(global_step_val, loss_val, accuracy_val))
            
        if should_save:
            model_save_file = os.path.join(output_dir, "image_caption")
            logging.info("Step: {}, model saved ".format(global_step_val))
            saver.save(sess, model_save_file, global_step=global_step_val)

