inception_v3
Model download code:
#%%
# coding: utf-8
import tensorflow as tf
import os
import tarfile
import requests

# Download URL for the pretrained Inception model
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# Directory to store the model
inception_pretrain_model_dir = "inception_model"
if not os.path.exists(inception_pretrain_model_dir):
    os.makedirs(inception_pretrain_model_dir)
# Derive the file name and the local file path
filename = inception_pretrain_model_url.split('/')[-1]
filepath = os.path.join(inception_pretrain_model_dir, filename)
# Download the model
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_pretrain_model_url, stream=True)
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
print("finish: ", filename)
# Unpack the archive
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)
# Directory for the graph summary
log_dir = 'inception_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
# classify_image_graph_def.pb is the model trained by Google
inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')
with tf.Session() as sess:
    # Create a graph to hold the pretrained model
    with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    # Save the graph structure
    writer = tf.summary.FileWriter(log_dir, sess.graph)
    writer.close()
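The exported graph structure can then be inspected with TensorBoard (standard CLI; the directory matches log_dir above):

tensorboard --logdir=inception_log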
Dataset: flickr30K_images
Image captions: each image has 5 captions. Each line starts with the image name plus a '#' and a caption index, followed by a Tab and the caption itself.
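Schematically, the lines of the token file for one image look like this (caption text comes from results_20130124.token and is elided here):

2778832101.jpg#0	<caption 1>
2778832101.jpg#1	<caption 2>
...
2778832101.jpg#4	<caption 5>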
1. Vocabulary statistics
import os
import sys
import pprint

input_description_file = r""
output_vocab_file = r""

def count_vocab(input_description_file):
    """
    Generates the vocabulary.
    In addition, counts the distribution of sentence lengths
    and the max length of the image descriptions.
    :param input_description_file:
    :return:
    """
    with open(input_description_file, 'r') as f:
        lines = f.readlines()
    max_length_of_sentences = 0
    length_dict = {}
    vocab_dict = {}
    for line in lines:
        image_id, description = line.strip('\n').split('\t')
        words = description.strip(' ').split()
        max_length_of_sentences = max(max_length_of_sentences, len(words))
        length_dict.setdefault(len(words), 0)
        length_dict[len(words)] += 1
        for word in words:
            vocab_dict.setdefault(word, 0)
            vocab_dict[word] += 1
    print(max_length_of_sentences)
    pprint.pprint(length_dict)
    return vocab_dict

'''With the sentence-length distribution in hand, we can pick a fixed sentence length for training; 40 is a reasonable choice.'''
vocab_dict = count_vocab(input_description_file)
sorted_vocab_dict = sorted(vocab_dict.items(), key=lambda d: d[1], reverse=True)
with open(output_vocab_file, 'w') as f:
    f.write('<UNK>\t10000000\n')
    for item in sorted_vocab_dict:
        f.write('%s\t%d\n' % item)
The generated format:
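Based on the writing loop above, vocab.txt begins with the <UNK> entry and then lists one word per line with its count, tab-separated and sorted by descending frequency (the words and counts below are illustrative placeholders):

<UNK>	10000000
a	<count>
.	<count>
in	<count>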
Next, use Inception v3 to process the images.
import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import numpy as np
import pickle

model_file = r".\deep_learn\image2text\classify_image_graph_def.pb"
input_description_file = r".\deep_learn\image2text\results_20130124.token"
input_img_dir = r".\deep_learn\image2text\flickr30k-images"
output_folder = r".\deep_learn\image2text\download_inception_v3_features"

batch_size = 1000

if not gfile.Exists(output_folder):
    gfile.MakeDirs(output_folder)

def parse_token_file(token_file):
    """Parses the image description file."""
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()
    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)
    return img_name_to_tokens

img_name_to_tokens = parse_token_file(input_description_file)
all_img_names = img_name_to_tokens.keys()
logging.info("num of all images: {}".format(len(all_img_names)))
pprint.pprint(list(img_name_to_tokens.keys())[0:10])
pprint.pprint(img_name_to_tokens['2778832101.jpg'])

def load_pretrained_inception_v3(model_file):
    with gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

load_pretrained_inception_v3(model_file)

# Split the 30k+ images into 30+ small feature files
num_batches = int(len(all_img_names) / batch_size)
if len(all_img_names) % batch_size != 0:
    num_batches += 1

with tf.Session() as sess:
    second_to_last_tensor = sess.graph.get_tensor_by_name("pool_3:0")
    for i in range(num_batches):
        batch_img_names = list(all_img_names)[i*batch_size: (i+1)*batch_size]
        batch_features = []
        valid_img_names = []
        for img_name in batch_img_names:
            img_path = os.path.join(input_img_dir, img_name)
            if not gfile.Exists(img_path):
                logging.info("image {} does not exist".format(img_name))
                continue
            img_data = gfile.FastGFile(img_path, "rb").read()
            # Run the image through Inception v3 to get a feature vector
            feature_vector = sess.run(second_to_last_tensor,
                                      feed_dict={
                                          "DecodeJpeg/contents:0": img_data
                                      })
            batch_features.append(feature_vector)
            # Keep names aligned with features (skipped images are dropped)
            valid_img_names.append(img_name)
        batch_features = np.vstack(batch_features)
        output_filename = os.path.join(
            output_folder,
            "image_features-%d.pickle" % i
        )
        logging.info("writing to file {}".format(output_filename))
        with gfile.GFile(output_filename, 'wb') as f:
            pickle.dump((valid_img_names, batch_features), f)
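As a quick sanity check, one of the generated pickles can be loaded back and inspected (a minimal sketch; the path assumes the output layout produced above):

import pickle

with open(r".\deep_learn\image2text\download_inception_v3_features\image_features-0.pickle", 'rb') as f:
    names, features = pickle.load(f)
# names and features stay aligned; the features come from pool_3,
# stacked by np.vstack into shape (num_images, 1, 1, 2048).
print(len(names), features.shape)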
Loading the data
"""
步骤
1. Data generator
a. Loads vocab
b. Loads image features
c. provide data for training.
2. Builds image caption model
3. Trains the model
"""
import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import pickle
import numpy as np
import math
import random
input_description_file = r".\deep_learn\image2text\results_20130124.token"
input_img_feature_dir = r".\deep_learn\image2text\download_inception_v3_features"
input_vocab_file = r".\deep_learn\image2text\vocab.txt"
output_dir = r".\deep_learn\image2text\local_run"
if not gfile.Exists(output_dir):
gfile.MakeDirs(output_dir)
def get_default_params():
return tf.contrib.training.HParams(
# 过滤低频率词汇
num_vocab_word_threshold = 3,
num_embedding_nodes = 32,
num_timesteps = 10,
num_lstm_nodes = [64, 64],
num_lstm_layers = 2,
num_fc_nodes = 32,
batch_size = 80,
cell_type = "lstm",
clip_lstm_grads = 1.0,
learning_rate = 0.001,
keep_prob = 0.8,
log_frequent = 100,
save_frequent = 1000,
)
hps = get_default_params()
class Vocab:
    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}
        self._word_to_id = {}
        self._unk = -1
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        self._read_dict(filename)

    def _read_dict(self, filename):
        with gfile.GFile(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            word, occurrence = line.strip('\r\n').split('\t')
            occurrence = int(occurrence)
            if occurrence < self._word_num_threshold:
                continue
            idx = len(self._id_to_word)
            if word == '<UNK>':
                self._unk = idx
            elif word == '.':
                self._eos = idx
            if word in self._word_to_id or idx in self._id_to_word:
                raise Exception("duplicate words in vocab")
            self._word_to_id[word] = idx
            self._id_to_word[idx] = word

    @property
    def unk(self):
        return self._unk

    @property
    def eos(self):
        return self._eos

    def word_to_id(self, word):
        return self._word_to_id.get(word, self._unk)

    def id_to_word(self, word_id):
        return self._id_to_word.get(word_id, '<UNK>')

    def size(self):
        return len(self._id_to_word)

    def encode(self, sentence):
        return [self.word_to_id(word) for word in sentence.split(' ')]

    def decode(self, sentence_id):
        words = [self.id_to_word(word_id) for word_id in sentence_id]
        return ' '.join(words)

vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)
vocab_size = vocab.size()
logging.info('vocab_size: {}'.format(vocab_size))
pprint.pprint(vocab.encode("I have a dream."))
pprint.pprint(vocab.decode([5, 10, 9, 21]))
def parse_token_file(token_file):
    """Parses the image description file."""
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()
    for line in lines:
        img_id, description = line.strip('\r\n').split('\t')
        img_name, _ = img_id.split('#')
        img_name_to_tokens.setdefault(img_name, [])
        img_name_to_tokens[img_name].append(description)
    return img_name_to_tokens

def convert_token_to_id(img_name_to_tokens, vocab):
    """Converts the tokens of each image description to ids."""
    img_name_to_tokens_id = {}
    for img_name in img_name_to_tokens:
        img_name_to_tokens_id.setdefault(img_name, [])
        for description in img_name_to_tokens[img_name]:
            token_ids = vocab.encode(description)
            img_name_to_tokens_id[img_name].append(token_ids)
    return img_name_to_tokens_id

img_name_to_tokens = parse_token_file(input_description_file)
img_name_to_tokens_id = convert_token_to_id(img_name_to_tokens, vocab)
logging.info("num of all imgs: {}".format(len(img_name_to_tokens)))
pprint.pprint(img_name_to_tokens['2778832101.jpg'])
logging.info("num of all imgs: {}".format(len(img_name_to_tokens_id)))
pprint.pprint(img_name_to_tokens_id['2778832101.jpg'])
class ImageCaptionData:
    """Provides data for the image caption model."""
    def __init__(self,
                 img_name_to_tokens_id,
                 img_feature_dir,
                 num_timesteps,
                 vocab,
                 deterministic=False):
        self._vocab = vocab
        self._img_name_to_tokens_id = img_name_to_tokens_id
        # Sentences are truncated to this length
        self._num_timesteps = num_timesteps
        # Whether to skip random shuffling
        self._deterministic = deterministic
        self._indicator = 0
        # Image names
        self._img_feature_filenames = []
        # Image features
        self._img_feature_data = []
        self._all_img_feature_filepaths = []
        for filename in gfile.ListDirectory(img_feature_dir):
            if not filename[0] == '.':
                self._all_img_feature_filepaths.append(
                    os.path.join(img_feature_dir, filename)
                )
        pprint.pprint(self._all_img_feature_filepaths)
        self._load_img_feature_pickle()
        if not self._deterministic:
            self._random_shuffle()

    def _load_img_feature_pickle(self):
        """Loads image feature data from the pickles."""
        for filepath in self._all_img_feature_filepaths:
            logging.info("loading {}".format(filepath))
            with gfile.GFile(filepath, 'rb') as f:
                filenames, features = pickle.load(f)
                # Merge the image name lists
                self._img_feature_filenames += filenames
                self._img_feature_data.append(features)
        # vstack: [(1000, 1, 1, 2048), (1000, 1, 1, 2048)] -> (2000, 1, 1, 2048)
        self._img_feature_data = np.vstack(self._img_feature_data)
        origin_shape = self._img_feature_data.shape
        self._img_feature_data = np.reshape(self._img_feature_data,
                                            (origin_shape[0], origin_shape[3]))
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)
        print(self._img_feature_data.shape)
        print(self._img_feature_filenames.shape)

    def size(self):
        return len(self._img_feature_filenames)

    def _random_shuffle(self):
        """Shuffles the data randomly."""
        p = np.random.permutation(self.size())
        self._img_feature_filenames = self._img_feature_filenames[p]
        self._img_feature_data = self._img_feature_data[p]

    def img_feature_size(self):
        """Dimension of the image features."""
        return self._img_feature_data.shape[1]

    def _img_desc(self, batch_filenames):
        """Picks one description for each filename in the batch."""
        batch_sentence_ids = []
        batch_weights = []
        for filename in batch_filenames:
            token_ids_set = self._img_name_to_tokens_id[filename]
            chosen_token_ids = random.choice(token_ids_set)
            chosen_token_ids_length = len(chosen_token_ids)
            # Real tokens get weight 1
            weight = [1 for i in range(chosen_token_ids_length)]
            if chosen_token_ids_length >= self._num_timesteps:
                chosen_token_ids = chosen_token_ids[0: self._num_timesteps]
                weight = weight[0: self._num_timesteps]
            else:
                remaining_length = self._num_timesteps - chosen_token_ids_length
                # Pad with eos, which gets weight 0
                chosen_token_ids += [self._vocab.eos for i in range(remaining_length)]
                weight += [0 for i in range(remaining_length)]
            batch_sentence_ids.append(chosen_token_ids)
            batch_weights.append(weight)
        batch_sentence_ids = np.asarray(batch_sentence_ids)
        batch_weights = np.asarray(batch_weights)
        return batch_sentence_ids, batch_weights

    def next_batch(self, batch_size):
        """Returns the next batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size
        assert end_indicator <= self.size()
        batch_filenames = self._img_feature_filenames[self._indicator: end_indicator]
        batch_img_features = self._img_feature_data[self._indicator: end_indicator]
        # batch_weights masks out the padded positions, e.g.
        # sentence_ids [100,101,123,4,5,eos,eos,eos] -> weights [1,1,1,1,1,0,0,0],
        # so the meaningless padding is excluded from the loss and the gradients.
        batch_sentence_ids, batch_weights = self._img_desc(batch_filenames)
        self._indicator = end_indicator
        return batch_img_features, batch_sentence_ids, batch_weights, batch_filenames

caption_data = ImageCaptionData(img_name_to_tokens_id,
                                input_img_feature_dir,
                                hps.num_timesteps,
                                vocab)
img_feature_dim = caption_data.img_feature_size()
caption_data_size = caption_data.size()
logging.info("img_feature_dim: {}".format(img_feature_dim))
logging.info("caption_data_size: {}".format(caption_data_size))
batch_img_features, batch_sentence_ids, batch_weights, batch_img_names = caption_data.next_batch(5)
pprint.pprint(batch_img_features)
pprint.pprint(batch_sentence_ids)
pprint.pprint(batch_weights)
pprint.pprint(batch_img_names)
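For orientation, the shapes returned by next_batch(5) with the pool_3 features above are:

# batch_img_features: (5, 2048)  image feature vectors
# batch_sentence_ids: (5, 10)    one caption per image, truncated or padded to num_timesteps
# batch_weights:      (5, 10)    1 for real tokens, 0 for eos padding
# batch_img_names:    (5,)       the corresponding image file names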
Training
The full training script begins with the same data-preparation code as above (imports, get_default_params, Vocab, parse_token_file / convert_token_to_id, and the ImageCaptionData class); it is identical to the "Loading the data" section and is not repeated here. The model definition and training loop follow.
#%%
def create_rnn_cell(hidden_dim, cell_type):
    """Returns a specific cell according to cell_type."""
    if cell_type == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_dim,
                                            state_is_tuple=True)
    elif cell_type == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_dim)
    else:
        raise Exception("{} type has not been supported.".format(cell_type))

def dropout(cell, keep_prob):
    """Wraps a cell with dropout."""
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
def get_train_model(hps, vocab_size, img_feature_dim):
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size

    img_feature = tf.placeholder(tf.float32,
                                 (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32,
                              (batch_size, num_timesteps))
    # Mask for the padded positions
    mask = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")

    global_step = tf.Variable(tf.zeros([], tf.int32),
                              name='global_step',
                              trainable=False)

    # Prediction process:
    # sentence: [a, b, c, d, e]
    # input:    [img, a, b, c, d]
    # img_feature: [0.4, 0.3, 10, 2]
    # predict #1: img_feature -> embedding_img  -> lstm -> (a)
    # predict #2: a           -> embedding_word -> lstm -> (b)
    # predict #3: b           -> embedding_word -> lstm -> (c)
    # ...
    # That is the original pipeline; here we simply make embedding_img
    # the same size as embedding_word.

    # Sets up the embedding layer.
    embedding_initializer = tf.random_uniform_initializer(-0.1, 1.0)
    with tf.variable_scope('embedding',
                           initializer=embedding_initializer):
        # Converts the sentence tokens to embeddings
        embeddings = tf.get_variable('embedding',
                                     [vocab_size, hps.num_embedding_nodes],
                                     tf.float32)
        # embed_token_ids: [batch_size, num_timesteps-1, num_embedding_nodes]
        embed_token_ids = tf.nn.embedding_lookup(
            embeddings,
            sentence[:, 0: num_timesteps - 1]
        )
    img_feature_embed_init = tf.uniform_unit_scaling_initializer(
        factor=1.0
    )
    with tf.variable_scope('img_feature_embed',
                           initializer=img_feature_embed_init):
        # img_feature: [batch_size, img_feature_dim]
        # embed_img:   [batch_size, num_embedding_nodes]
        embed_img = tf.layers.dense(img_feature,
                                    hps.num_embedding_nodes)
        # embed_img: [batch_size, 1, num_embedding_nodes]
        embed_img = tf.expand_dims(embed_img, 1)
    # embed_inputs: [batch_size, num_timesteps, num_embedding_nodes]
    # Concatenates the image embedding and the sentence embeddings
    embed_inputs = tf.concat([embed_img, embed_token_ids], axis=1)

    # Sets up the rnn network
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1]) / 3.0
    rnn_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_nn', initializer=rnn_init):
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cell = dropout(cell, keep_prob)
            cells.append(cell)
        cell = tf.contrib.rnn.MultiRNNCell(cells)
        init_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                           embed_inputs,
                                           initial_state=init_state)

    # Sets up the fully-connected layer
    fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('fc', initializer=fc_init):
        rnn_outputs_2d = tf.reshape(rnn_outputs,
                                    [-1, hps.num_lstm_nodes[-1]])
        fc1 = tf.layers.dense(rnn_outputs_2d,
                              hps.num_fc_nodes,
                              name='fc1')
        # tf.nn.dropout takes a keep probability directly
        # (tf.layers.dropout expects a drop rate and is inactive unless training=True)
        fc1_dropout = tf.nn.dropout(fc1, keep_prob)
        fc1_relu = tf.nn.relu(fc1_dropout)
        logits = tf.layers.dense(fc1_relu,
                                 vocab_size,
                                 name='logits')

    # Calculates the loss
    with tf.variable_scope('loss'):
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])
        mask_sum = tf.reduce_sum(mask_flatten)
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=sentence_flatten
        )
        # Masks out the eos padding
        weighted_softmax_loss = tf.multiply(
            softmax_loss, tf.cast(mask_flatten, tf.float32)
        )
        loss = tf.reduce_sum(weighted_softmax_loss) / tf.cast(mask_sum, tf.float32)
        prediction = tf.argmax(logits, 1, output_type=tf.int32)
        correct_prediction = tf.equal(prediction,
                                      sentence_flatten)
        weighted_correct_prediction = tf.multiply(
            tf.cast(correct_prediction, tf.float32),
            tf.cast(mask_flatten, tf.float32)
        )
        accuracy = tf.reduce_sum(weighted_correct_prediction) / tf.cast(mask_sum, tf.float32)
        tf.summary.scalar('loss', loss)

    # Defines the train op.
    with tf.variable_scope('train_op'):
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info('variable name: {}'.format(var.name))
        grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, tvars), hps.clip_lstm_grads
        )
        optimizer = tf.train.AdamOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step=global_step
        )

    return ((img_feature, sentence, mask, keep_prob),
            (loss, accuracy, train_op),
            global_step)
placeholders, metrics, global_step = get_train_model(
    hps, vocab_size, img_feature_dim
)
img_feature, sentence, mask, keep_prob = placeholders
loss, accuracy, train_op = metrics

summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=10)
#%%
training_steps = 1000

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    for i in range(training_steps):
        (batch_img_features,
         batch_sentence_ids,
         batch_weights, _) = caption_data.next_batch(hps.batch_size)
        input_vals = (batch_img_features,
                      batch_sentence_ids,
                      batch_weights,
                      hps.keep_prob)
        feed_dict = dict(zip(placeholders, input_vals))
        fetches = [global_step, loss, accuracy, train_op]
        should_log = (i+1) % hps.log_frequent == 0
        should_save = (i+1) % hps.save_frequent == 0
        if should_log:
            fetches += [summary_op]
        outputs = sess.run(fetches, feed_dict=feed_dict)
        global_step_val, loss_val, accuracy_val = outputs[0:3]
        if should_log:
            summary_str = outputs[-1]
            writer.add_summary(summary_str, global_step_val)
            logging.info("Step: {}, loss: {:.2f}, acc: {:.2f}".format(
                global_step_val, loss_val, accuracy_val))
        if should_save:
            model_save_file = os.path.join(output_dir, "image_caption")
            logging.info("Step: {}, model saved".format(global_step_val))
            saver.save(sess, model_save_file, global_step=global_step_val)
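One practical note: logging here is tensorflow.logging, whose default verbosity suppresses info-level messages, so to actually see the vocab-size and training logs above, raise the verbosity near the top of the script (standard TF 1.x API):

logging.set_verbosity(logging.INFO)

The loss curve written by tf.summary.scalar can likewise be monitored by pointing TensorBoard at output_dir, the same way the Inception graph was inspected earlier.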