Rant
Although there isn't much code, it took me a long time to untangle how the handful of .py files call one another. For someone reading the source, chasing functions back and forth across files is genuinely tiring, and I don't understand why simple logic has to be spread over so many steps.
If a function is only ever called from one place, I really wouldn't split it out into a separate function (unless there is a genuine need); even for other readers it just adds friction, flipping back and forth between pages to work out what you wrote. I'd also avoid giving a function more than five parameters; it makes for a poor experience both for callers and for anyone reading the code.
Training code (several files condensed into one)
#-*-coding:utf-8-*-
import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib
import os
import sys
import sqlite3
import json
import copy
import time
from collections import OrderedDict
'''-- Data access and vocabulary --'''
class BucketData(object):
def __init__(self, buckets_dir, encoder_size, decoder_size):
self.encoder_size = encoder_size
self.decoder_size = decoder_size
self.conn = sqlite3.connect(os.path.join(buckets_dir, 'bucket_%d_%d.db' % (encoder_size, decoder_size)))
self.cur = self.conn.cursor()
sql = '''SELECT MAX(ROWID) FROM conversation;'''
self.size = self.cur.execute(sql).fetchall()[0][0]
# fetch all answers stored for a given question
def all_answers(self, ask):
sql = '''SELECT answer FROM conversation WHERE ask = '{}';'''.format(ask.replace("'", "''"))
ret = []
for s in self.cur.execute(sql):
ret.append(s[0])
return list(set(ret))
# draw one random question/answer pair from this bucket
def random(self):
while True:
rowid = np.random.randint(1, self.size + 1)
sql = '''SELECT ask, answer FROM conversation WHERE ROWID = {};'''.format(rowid)
ret = self.cur.execute(sql).fetchall()
# ret looks like [('咱们梅家从你爷爷起', '就一直小心翼翼地唱戏')]
if len(ret) == 1:
ask, answer = ret[0]
if ask is not None and answer is not None:
return ask, answer
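# each bucket is (encoder_size, decoder_size): the padded lengths used for questions and answers in that bucket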
buckets = [(5, 15),(10, 20),(15, 25),(20, 30)]
buckets_dir='./bucket_dbs'
# list of BucketData objects, one per bucket
buckets_object_list = []
for encode_size,decode_size in buckets:
bucket_data=BucketData(buckets_dir,encode_size,decode_size)
buckets_object_list.append(bucket_data)
# number of samples held by each bucket
bucket_sizes = []
for i in range(len(buckets)):
bucket_size = buckets_object_list[i].size
bucket_sizes.append(bucket_size)
print('bucket {} contains {} samples'.format(i, bucket_size))
total_size = sum(bucket_sizes) # total number of question/answer pairs
print('{} samples in total'.format(total_size))
DICTIONARY_PATH = 'db/dictionary.json';EOS = '<eos>';UNK = '<unk>';PAD = '<pad>';GO = '<go>'
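# special tokens: <go> marks the start of a decoder sequence, <eos> its end, <pad> fills sequences up to the bucket length, <unk> replaces characters missing from the dictionary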
current_dir = os.path.dirname(os.path.abspath(__file__))
join_path=lambda file:os.path.join(current_dir,file)
with open(join_path(DICTIONARY_PATH), 'r', encoding='UTF-8') as fp:
dictionary = [EOS, UNK, PAD, GO] + json.load(fp)
index2word = OrderedDict()
word2index = OrderedDict()
for index, word in enumerate(dictionary):
index2word[index] = word
word2index[word] = index
vocab_size = len(dictionary) # vocabulary size
print('vocabulary size:', vocab_size)
'''-- Build the model --'''
placeholder_encoder_inputs = []
placeholder_decoder_inputs = []
placeholder_decoder_weights= []
# placeholders up to the longest encoder input length
for i in range(buckets[-1][0]): #20
placeholder_encoder_inputs.append(tf.placeholder(tf.int32,shape=[None],name='encoder_input_{}'.format(i)))
# one extra decoder input so that targets below can be shifted left by one position
for i in range(buckets[-1][1] + 1): #31
placeholder_decoder_inputs.append(tf.placeholder(tf.int32,shape=[None],name='decoder_input_{}'.format(i)))
placeholder_decoder_weights.append(tf.placeholder(tf.float32,shape=[None],name='decoder_weight_{}'.format(i)))
targets =placeholder_decoder_inputs[1:] #30
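# e.g. for a decoder input [<go>, w1, w2, <eos>, <pad>, ...] the targets are [w1, w2, <eos>, <pad>, ...]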
# encoder cell
cell_num=512 # number of hidden units per layer
target_vocab_size=vocab_size
encode_cell_layer_num=2 # number of stacked RNN layers
cell=contrib.rnn.BasicRNNCell(cell_num)
cell=contrib.rnn.DropoutWrapper(cell,output_keep_prob=1.0)
encoder_cell=contrib.rnn.MultiRNNCell([cell]*encode_cell_layer_num)
# without these patches, deepcopy of the RNN cells fails
setattr(tf.contrib.rnn.GRUCell, '__deepcopy__', lambda self, _: self)
setattr(tf.contrib.rnn.BasicLSTMCell, '__deepcopy__', lambda self, _: self)
setattr(tf.contrib.rnn.MultiRNNCell, '__deepcopy__', lambda self, _: self)
# define the loss function
num_samples=512
embed_size=512
w=tf.get_variable('w',[cell_num,target_vocab_size],dtype=tf.float32)
b = tf.get_variable("b",[target_vocab_size],dtype=tf.float32)
if num_samples<0 or num_samples>target_vocab_size: # num_samples: how many classes are sampled for the sampled softmax
raise ValueError('num_samples is out of range')
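# sampled softmax: the loss is computed over num_samples randomly sampled classes instead of all target_vocab_size classes, which is much cheaper for a large vocabulary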
def softmax_loss_function(labels, logits): # the second parameter must be named logits; any other name raises an error
return tf.nn.sampled_softmax_loss(weights=tf.transpose(w),biases=b,labels=tf.reshape(labels, [-1, 1]),inputs=logits,num_sampled=num_samples,num_classes=target_vocab_size)
# embedding attention seq2seq
source_vocab_size=vocab_size
tmp_cell = copy.deepcopy(encoder_cell)
def attention_cell(encoder_inputs_x,decoder_inputs_y):
return contrib.legacy_seq2seq.embedding_attention_seq2seq(encoder_inputs_x,decoder_inputs_y,tmp_cell,num_encoder_symbols=source_vocab_size, num_decoder_symbols=target_vocab_size,embedding_size=embed_size,output_projection=(w,b), feed_previous=True,dtype=tf.float32)
# model outputs and per-bucket losses
outputs, losses = contrib.legacy_seq2seq.model_with_buckets(placeholder_encoder_inputs,placeholder_decoder_inputs,targets,placeholder_decoder_weights,buckets,attention_cell,softmax_loss_function=softmax_loss_function)
# optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
max_threshold_value=5.0
gradient_norms = []
updates = []
params= tf.trainable_variables() # automatically collect all trainable variables
for output, loss in zip(outputs, losses): # build a clipped-gradient update op for each bucket
gradients = tf.gradients(loss,params) # gradients of the loss w.r.t. the parameters
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_threshold_value)
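# rescale the gradients so their global norm does not exceed max_threshold_value, to avoid exploding gradients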
gradient_norms.append(norm)
updates.append(optimizer.apply_gradients(zip(clipped_gradients, params)))
# saver for writing checkpoints
saver = tf.train.Saver(tf.global_variables())#,write_version=tf.train.SaverDef.V2)
'''-- Train the model --'''
# format a duration in seconds as hours/minutes/seconds
def time_display(s):
ret = ''
if s >= 60 * 60:
h = np.math.floor(s / (60 * 60))
ret += '{}h'.format(h)
s -= h * 60 * 60
if s >= 60:
m = np.math.floor(s / 60)
ret += '{}m'.format(m)
s -= m * 60
if s >= 1:
s = np.math.floor(s)
ret += '{}s'.format(s)
return ret
# convert a sentence into a list of character indices
def sentence2index(chinese_sentence): # TODO: how should mixed Chinese/English input be handled?
sentence_ret = []
for word in chinese_sentence:
if word in word2index:
sentence_ret.append(word2index[word])
else:
sentence_ret.append(word2index[UNK])
return sentence_ret
num_epoch=1600
batch_size=64
num_per_epoch=1024
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# cumulative share of the data held by each bucket; used below to pick a bucket at random, weighted by its size
buckets_scale = [sum(bucket_sizes[:i + 1]) / total_size for i in range(len(bucket_sizes))]
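# e.g. bucket_sizes of [10, 20, 30, 40] give buckets_scale [0.1, 0.3, 0.6, 1.0]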
# start training
metrics = ' '.join(['\r[{}]','{:.1f}%','{}/{}','loss={:.3f}','{}/{}'])
bars_max = 20
with tf.device('/gpu:0'):
for epoch_index in range(1, num_epoch+1):
print('Epoch {}:'.format(epoch_index))
time_start = time.time()
epoch_trained = 0
batch_loss = []
while True:
# pick a bucket to train on
random_number = np.random.random_sample()
selected_bucket_id = min([i for i in range(len(buckets_scale)) if buckets_scale[i] > random_number])
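# the first bucket whose cumulative share exceeds random_number, so buckets with more data are picked more often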
# collect a batch of question/answer pairs from the selected bucket
ask_answer_tuble_list = []
# answer_ask_tuble_list = []
selected_bucket_db = buckets_object_list[selected_bucket_id]
for _ in range(batch_size): # randomly draw batch_size samples
ask, answer = selected_bucket_db.random()
ask_answer_tuble_list.append((ask, answer))
# convert the characters to indices and pad every sequence to the bucket length
encoder_size, decoder_size = buckets[selected_bucket_id]
encoder_inputs, decoder_inputs = [], []
for encoder_input, decoder_input in ask_answer_tuble_list:
encoder_input = sentence2index(encoder_input)
decoder_input = sentence2index(decoder_input)
# Encoder
encoder_pad = [word2index[PAD]] * (encoder_size - len(encoder_input)) # pad the question with <pad> up to encoder_size
encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) # each encoder input is reversed
# Decoder
decoder_pad_size = decoder_size - len(decoder_input) - 2 # -2 leaves room for the <go> and <eos> tokens
decoder_inputs.append([word2index[GO]] + decoder_input +[word2index[EOS]] +[word2index[PAD]] * decoder_pad_size) # decoder input: <go> + answer + <eos> + padding
# transpose the batch into time-major arrays and build the target weights
batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
# batch encoder
for i in range(encoder_size): # transpose to encoder_size x batch_size
batch_encoder_inputs.append(np.array([encoder_inputs[j][i] for j in range(batch_size)],dtype=np.int32))
# batch decoder
for i in range(decoder_size): # transpose to decoder_size x batch_size
batch_decoder_inputs.append(np.array([decoder_inputs[j][i] for j in range(batch_size)],dtype=np.int32))
batch_weight = np.ones(batch_size, dtype=np.float32)
for j in range(batch_size):
if i < decoder_size - 1:
target = decoder_inputs[j][i + 1] # the target at step i is the decoder input at step i + 1 (shifted past <go>)
if i == decoder_size - 1 or target == word2index[PAD]: # real characters keep weight 1, padding and the final step get weight 0
batch_weight[j] = 0.0
batch_weights.append(batch_weight) # batch_weights ends up decoder_size x batch_size
# build the feed dict
input_feed = {}
for i in range(encoder_size):
input_feed[placeholder_encoder_inputs[i].name] = batch_encoder_inputs[i]
for i in range(decoder_size):
input_feed[placeholder_decoder_inputs[i].name] = batch_decoder_inputs[i]
input_feed[placeholder_decoder_weights[i].name] = batch_weights[i]
last_target = placeholder_decoder_inputs[decoder_size].name # the decoder input placeholder one step past this bucket's decoder_size
input_feed[last_target] = np.zeros([batch_size], dtype=np.int32)
# ops to fetch: the update op, the gradient norm and the loss for the selected bucket
output_feed = [updates[selected_bucket_id],gradient_norms[selected_bucket_id],losses[selected_bucket_id]]
output_feed.append(outputs[selected_bucket_id][i])
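# note: i still holds decoder_size - 1 from the loop above, so only the last decoder output is fetched; training only needs the update op and the loss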
outputs1= sess.run(output_feed, input_feed)
epoch_trained += batch_size
batch_loss.append(outputs1[2]) # outputs1 is [update, gradient norm, loss, last output]; index 2 is the loss
time_now = time.time()
time_spend = time_now - time_start
time_estimate = time_spend / (epoch_trained /num_per_epoch)
percent = min(100, epoch_trained / num_per_epoch) * 100
bars = np.math.floor(percent / 100 * bars_max)
sys.stdout.write(metrics.format('=' * bars + '-' * (bars_max - bars),percent,epoch_trained, num_per_epoch,np.mean(batch_loss),time_display(time_spend), time_display(time_estimate)))
sys.stdout.flush()
if epoch_trained >= num_per_epoch:
break
print('\n')
model_dir = './model'
model_name = 'model3'
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if epoch_index % 800 == 0:
saver.save(sess, os.path.join(model_dir, model_name))
Inference code (calling the trained model)
#-*-coding:utf-8-*-
import numpy as np
import tensorflow as tf
import os
import json
from collections import OrderedDict
buckets = [(5, 15),(10, 20),(15, 25),(20, 30)]
DICTIONARY_PATH = 'db/dictionary.json';EOS = '<eos>';UNK = '<unk>';PAD = '<pad>';GO = '<go>'
current_dir = os.path.dirname(os.path.abspath(__file__))
join_path=lambda file:os.path.join(current_dir,file)
with open(join_path(DICTIONARY_PATH), 'r', encoding='UTF-8') as fp:
dictionary = [EOS, UNK, PAD, GO] + json.load(fp)
index2word = OrderedDict()
word2index = OrderedDict()
for index, word in enumerate(dictionary):
index2word[index] = word
word2index[word] = index
def sentence2index(chinese_sentence):
sentence_ret = []
for word in chinese_sentence:
if word in word2index:
sentence_ret.append(word2index[word])
else:
sentence_ret.append(word2index[UNK])
return sentence_ret
def index2sentence(indice):
ret = []
for index in indice:
word = index2word[index]
if word == EOS:
break
if word != UNK and word != GO and word != PAD:
ret.append(word)
return ''.join(ret)
with tf.Session() as sess:
# restore the graph and weights
saver=tf.train.import_meta_graph("model/model3.meta")
saver.restore(sess,'model/model3')
# look up the tensors we need by name
# per-bucket losses (bucket 0 has no index suffix in its tensor name)
loss_name = "model_out_and_loss/sequence_loss/truediv:0"
graph=tf.get_default_graph()
losses=[]
losses.append(graph.get_tensor_by_name(loss_name))
for i in range(1,4):
losses.append(graph.get_tensor_by_name(f"model_out_and_loss/sequence_loss_{i}/truediv:0"))
# fetch the decoder output tensors for each bucket
outputs=[]
buck_list=[15,20,25,30]
for j in range(len(buck_list)):
sedj="_" + str(j)
if j == 0:
sedj = ''
outputj=[]
for k in range(buck_list[j]):
sedk = "_" + str(k)
if k==0:
sedk=''
out_dir=f'model_out_and_loss/embedding_attention_seq2seq{sedj}/embedding_attention_decoder/attention_decoder/AttnOutputProjection{sedk}/BiasAdd:0'
outputj.append(graph.get_tensor_by_name(out_dir))
outputs.append(outputj)
# fetch the input placeholders
placeholder_encoder_inputs = []
placeholder_decoder_inputs = []
placeholder_decoder_weights = []
# placeholders up to the longest encoder input length
for i in range(buckets[-1][0]): # 20
placeholder_encoder_inputs.append(graph.get_tensor_by_name(f'encoder_input_{i}:0'))
# one extra decoder input so that targets can be shifted left by one position
for i in range(buckets[-1][1] + 1): # 31
placeholder_decoder_inputs.append(graph.get_tensor_by_name(f'decoder_input_{i}:0'))
placeholder_decoder_weights.append(graph.get_tensor_by_name(f'decoder_weight_{i}:0'))
targets = placeholder_decoder_inputs[1:] # 30
# interactive test loop
batch_size = 1
sentence=input(">")
while sentence:
selected_bucket_id=np.random.randint(4)
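# the bucket is picked at random here, independent of the input sentence length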
# convert the sentence and build the feed dict, same as during training
ask_answer_tuble_list = []
ask_answer_tuble_list.append((sentence, ""))
# convert the characters to indices and pad to the bucket length
encoder_size, decoder_size = buckets[selected_bucket_id]
encoder_inputs, decoder_inputs = [], []
for encoder_input, decoder_input in ask_answer_tuble_list:
encoder_input = sentence2index(encoder_input)
decoder_input = sentence2index(decoder_input)
# Encoder
encoder_pad = [word2index[PAD]] * (encoder_size - len(encoder_input)) # pad the question with <pad> up to encoder_size
encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) # each encoder input is reversed
# Decoder
decoder_pad_size = decoder_size - len(decoder_input) - 2 # -2 leaves room for the <go> and <eos> tokens
decoder_inputs.append(
[word2index[GO]] + decoder_input + [word2index[EOS]] + [word2index[PAD]] * decoder_pad_size) # decoder input: <go> + answer + <eos> + padding
# transpose the batch into time-major arrays and build the target weights
batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
# batch encoder
for i in range(encoder_size): # transpose to encoder_size x batch_size
batch_encoder_inputs.append(np.array([encoder_inputs[j][i] for j in range(batch_size)], dtype=np.int32))
# batch decoder
for i in range(decoder_size): # transpose to decoder_size x batch_size
batch_decoder_inputs.append(np.array([decoder_inputs[j][i] for j in range(batch_size)], dtype=np.int32))
batch_weight = np.ones(batch_size, dtype=np.float32)
for j in range(batch_size):
if i < decoder_size - 1:
target = decoder_inputs[j][i + 1] # the target at step i is the decoder input at step i + 1 (shifted past <go>)
if i == decoder_size - 1 or target == word2index[PAD]: # real characters keep weight 1, padding and the final step get weight 0
batch_weight[j] = 0.0
batch_weights.append(batch_weight) # batch_weights ends up decoder_size x batch_size
# build the feed dict
input_feed = {}
for i in range(encoder_size):
input_feed[placeholder_encoder_inputs[i].name] = batch_encoder_inputs[i]
for i in range(decoder_size):
input_feed[placeholder_decoder_inputs[i].name] = batch_decoder_inputs[i]
input_feed[placeholder_decoder_weights[i].name] = batch_weights[i]
last_target = placeholder_decoder_inputs[decoder_size].name # the decoder input placeholder one step past this bucket's decoder_size
input_feed[last_target] = np.zeros([batch_size], dtype=np.int32)
output_feed = [losses[selected_bucket_id]]
for i in range(decoder_size):
output_feed.append(outputs[selected_bucket_id][i])
output_logits = sess.run(output_feed, input_feed)
outputsx = [int(np.argmax(logit, axis=1)) for logit in output_logits[1:]]
retx = index2sentence(outputsx)
print(retx)
sentence = input('>')
Results
Closing remarks
Notes for running the code: the environment needs sqlite3, and there must be a bucket_dbs directory under the directory of this script containing files such as bucket_5_15.db, bucket_10_20.db and so on. As for why the data is organized into buckets, go look that up yourself.
The data in bucket_10_20.db looks like this; the other buckets are the same:
“你接到这封信的时候|不知道大伯还在不在人世了”
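For reference, here is a minimal sketch of how such a bucket database could be built with Python's built-in sqlite3 module. The CREATE TABLE statement and the pairs variable are my own assumptions (the original build script is not shown); the only things the training code above actually relies on are the bucket_<encoder>_<decoder>.db file names, a table called conversation, and its ask / answer columns.
import os
import sqlite3

buckets = [(5, 15), (10, 20), (15, 25), (20, 30)]
buckets_dir = './bucket_dbs'
os.makedirs(buckets_dir, exist_ok=True)

# hypothetical corpus: an iterable of (ask, answer) string pairs
pairs = [('你接到这封信的时候', '不知道大伯还在不在人世了')]

# one connection per bucket; the two-text-column schema is an assumption
conns = {}
for enc, dec in buckets:
    conn = sqlite3.connect(os.path.join(buckets_dir, 'bucket_%d_%d.db' % (enc, dec)))
    conn.execute('CREATE TABLE IF NOT EXISTS conversation (ask TEXT, answer TEXT);')
    conns[(enc, dec)] = conn

for ask, answer in pairs:
    # drop the pair into the smallest bucket it fits; -2 leaves room for <go> and <eos>
    for enc, dec in buckets:
        if len(ask) <= enc and len(answer) <= dec - 2:
            conns[(enc, dec)].execute('INSERT INTO conversation (ask, answer) VALUES (?, ?);', (ask, answer))
            break

for conn in conns.values():
    conn.commit()
    conn.close()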
Finally, the code can still be trimmed further; I will keep revising, adding to and pruning this post.