Referenced papers
Neural Conversational Model https://arxiv.org/abs/1506.05869
Luong attention mechanism(s) https://arxiv.org/abs/1508.04025
Sutskever et al. https://arxiv.org/abs/1409.3215
GRU Cho et al. https://arxiv.org/pdf/1406.1078v3.pdf
Bahdanau et al. https://arxiv.org/abs/1409.0473
Dataset used
Corpus homepage: https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
Corpus download: http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
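The loader in corpus_dataset.py below splits every raw corpus line on the " +++$+++ " separator. As a minimal sketch, the field layout follows the corpus README; the sample line is illustrative:

# Field layout of the two raw corpus files (per the corpus README).
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# An illustrative raw line from movie_lines.txt and how _get_object() splits it.
sample = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!"
values = sample.split(" +++$+++ ")
line_obj = dict(zip(MOVIE_LINES_FIELDS, values))
print(line_obj["lineID"], "->", line_obj["text"])   # L1045 -> They do not!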
Code files
chatbot_test.py
chatbot_train.py
corpus_dataset.py
vocabulary.py
graph.py
model.py
etc.py
main.py
chatbot_test.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_test():
    config = etc.config
    # Load the vocabulary and the filtered sentence pairs from the corpus.
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(config)
    g = graph.CorpusGraph(config)
    # Build the model in "test" mode and chat with it interactively.
    train_model = g.create_train_model(voc, "test")
    g.evaluate_input(voc, train_model)
chatbot_train.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_train():
    config = etc.config
    # Load the vocabulary and the filtered sentence pairs from the corpus.
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(config)
    g = graph.CorpusGraph(config)
    print("Create model")
    train_model = g.create_train_model(voc)
    print("Starting Training!")
    g.trainIters(voc, pairs, train_model)
    # print("Starting evaluate!")
    # g.evaluate_input(voc, train_model)
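etc.py itself is not shown here; the sketch below is only a guess at the shape of etc.config, limited to the keys that corpus_dataset.py below actually reads. The concrete values (paths, max_length) are assumptions, not the author's settings.

# Illustrative only: keys inferred from corpus_dataset.py, values are assumptions.
config = {
    "corpus_name": "cornell movie-dialogs corpus",
    "corpus_path": "data/cornell movie-dialogs corpus",   # assumed location on disk
    "lines_file_name": "movie_lines.txt",
    "conversation_file_name": "movie_conversations.txt",
    "movie_lines_fields": ["lineID", "characterID", "movieID", "character", "text"],
    "movie_conversations_fields": ["character1ID", "character2ID", "movieID", "utteranceIDs"],
    "delimiter": "\\t",   # decoded with unicode_escape into a real tab by _get_delimiter()
    "max_length": 10,     # sentence-length cutoff used by _filter_pairs()
}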
corpus_dataset.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : corpus_dataset.py
# Create date : 2019-01-16 11:16
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function
import os
import re
import csv
import codecs
import unicodedata
import vocabulary
def _check_is_have_file(file_name):
    return os.path.exists(file_name)

def _filter_pair(p, max_length):
    return len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length

def _filter_pairs(pairs, max_length):
    return [pair for pair in pairs if _filter_pair(pair, max_length)]

def _read_vocabulary(datafile, corpus_name):
    print("Reading lines...")
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    voc = vocabulary.Voc(corpus_name)
    return voc, pairs

def _unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def _get_delimiter(config):
    delimiter = config["delimiter"]
    delimiter = str(codecs.decode(delimiter, "unicode_escape"))
    return delimiter

def _get_object(line, fields):
    values = line.split(" +++$+++ ")
    obj = {}
    for i, field in enumerate(fields):
        obj[field] = values[i]
    return obj
def _load_lines(config):
    lines_file_name = config["lines_file_name"]
    corpus_path = config["corpus_path"]
    lines_file_full_path = "%s/%s" % (corpus_path, lines_file_name)
    fields = config["movie_lines_fields"]
    lines = {}
    f = open(lines_file_full_path, 'r', encoding='iso-8859-1')
    for line in f:
        line_obj = _get_object(line, fields)
        lines[line_obj['lineID']] = line_obj
    f.close()
    return lines

def _cellect_lines(conv_obj, lines):
    # Convert string to list (conv_obj["utteranceIDs"] == "['L598485', 'L598486', ...]")
    line_ids = eval(conv_obj["utteranceIDs"])
    # Reassemble lines
    conv_obj["lines"] = []
    for line_id in line_ids:
        conv_obj["lines"].append(lines[line_id])
    return conv_obj

def _load_conversations(lines, config):
    conversations = []
    corpus_path = config["corpus_path"]
    conversation_file_name = config["conversation_file_name"]
    conversation_file_full_path = "%s/%s" % (corpus_path, conversation_file_name)
    fields = config["movie_conversations_fields"]
    f = open(conversation_file_full_path, 'r', encoding='iso-8859-1')
    for line in f:
        conv_obj = _get_object(line, fields)
        conv_obj = _cellect_lines(conv_obj, lines)
        conversations.append(conv_obj)
    f.close()
    return conversations

def _get_conversations(config):
    lines = _load_lines(config)
    print("lines count:", len(lines))
    conversations = _load_conversations(lines, config)
    print("conversations count:", len(conversations))
    return conversations
def _extract_sentence_pairs(conversations):
    pairs = []
    for conversation in conversations:
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                pairs.append([inputLine, targetLine])
    return pairs

def _load_formatted_data(config):
    max_length = config["max_length"]
    corpus_name = config["corpus_name"]
    formatted_file_full_path = get_formatted_file_full_path(config)
    print("Start preparing training data ...")
    voc, pairs = _read_vocabulary(formatted_file_full_path, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = _filter_pairs(pairs, max_length)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs
def _trim_rare_words(voc, pairs, min_count):
    voc.trim(min_count)
    # Keep only pairs in which every word survived the vocabulary trim.
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
        if keep_input and keep_output:
            keep_pairs.append(pair)
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs
