Implementing a chatbot in PyTorch

Papers referenced

A Neural Conversational Model (Vinyals & Le)  https://arxiv.org/abs/1506.05869
Effective Approaches to Attention-based Neural Machine Translation (Luong attention)  https://arxiv.org/abs/1508.04025
Sequence to Sequence Learning with Neural Networks (Sutskever et al.)  https://arxiv.org/abs/1409.3215
Learning Phrase Representations using RNN Encoder-Decoder (GRU, Cho et al.)  https://arxiv.org/pdf/1406.1078v3.pdf
Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al.)  https://arxiv.org/abs/1409.0473
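
The model behind these references is a GRU-based sequence-to-sequence network with Luong-style attention. As a rough illustration of the "dot" scoring variant from the Luong et al. paper, a minimal PyTorch sketch (the class name, tensor shapes and usage here are assumptions for illustration, not code taken from model.py) could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotAttention(nn.Module):
    # Luong "dot" attention: score(h_t, h_s) = h_t . h_s
    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden:  (1, batch, hidden_size)       current decoder GRU state
        # encoder_outputs: (src_len, batch, hidden_size) all encoder GRU outputs
        scores = torch.sum(decoder_hidden * encoder_outputs, dim=2)  # (src_len, batch)
        weights = F.softmax(scores.t(), dim=1).unsqueeze(1)          # (batch, 1, src_len)
        context = weights.bmm(encoder_outputs.transpose(0, 1))       # (batch, 1, hidden_size)
        return context, weights

At each decoding step the context vector is concatenated with the GRU output and passed through a linear layer to predict the next token, which is the global attention scheme described in the Luong et al. paper.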

Dataset

Corpus homepage: https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
Corpus download: http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
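
Both movie_lines.txt and movie_conversations.txt in this corpus separate fields with the token " +++$+++ " (this is what _get_object in corpus_dataset.py splits on). A raw line of movie_lines.txt looks roughly like the following, with the field names (typically lineID, characterID, movieID, character, text) supplied by movie_lines_fields in the config:

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!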

Code files

chatbot_test.py
chatbot_train.py
corpus_dataset.py
vocabulary.py
graph.py  
model.py
etc.py             
main.py   

chatbot_test.py

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_test():
    config = etc.config
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(config)
    g = graph.CorpusGraph(config)
    train_model = g.create_train_model(voc, "test")
    g.evaluate_input(voc, train_model)

chatbot_train.py

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_train():
    config = etc.config
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(config)
    g = graph.CorpusGraph(config)
    print("Create model")
    train_model = g.create_train_model(voc)
    print("Starting Training!")
    g.trainIters(voc, pairs, train_model)
#    print("Starting evaluate!")
#    g.evaluate_input(voc, train_model)
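
Neither run_train() nor run_test() is called at module level; the entry point is presumably main.py from the file list above, which is not shown in this post. A hypothetical wiring (an illustration only, not the original main.py) might be:

# hypothetical main.py -- illustration, not part of the original listing
import chatbot_train
import chatbot_test

if __name__ == "__main__":
    chatbot_train.run_train()   # train the seq2seq model
    chatbot_test.run_test()     # then start the interactive evaluation loop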

corpus_dataset.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : corpus_dataset.py
# Create date : 2019-01-16 11:16
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import os
import re
import csv
import codecs
import unicodedata
import vocabulary

def _check_is_have_file(file_name):
    return os.path.exists(file_name)

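# A pair is kept only if both of its sentences have fewer than max_length words.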
def _filter_pair(p, max_length):
    return len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length

def _filter_pairs(pairs, max_length):
    return [pair for pair in pairs if _filter_pair(pair, max_length)]

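# Read the formatted tab-separated file (one "query<TAB>response" pair per line),
# normalize each sentence, and create an empty Voc for the corpus; words are only
# added to the Voc later, in _load_formatted_data.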
def _read_vocabulary(datafile, corpus_name):
    print("Reading lines...")
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    voc = vocabulary.Voc(corpus_name)
    return voc, pairs

def _unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
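
# normalize_string() is used in _read_vocabulary() above but is not part of this
# (truncated) listing; a typical implementation, following the standard PyTorch
# chatbot tutorial, lowercases, strips accents and isolates punctuation:
def normalize_string(s):
    s = _unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s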

def _get_delimiter(config):
    delimiter = config["delimiter"]
    delimiter = str(codecs.decode(delimiter, "unicode_escape"))
    return delimiter

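# Split one raw corpus line on the " +++$+++ " separator and map the resulting
# values onto the given field names.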
def _get_object(line, fields):
    values = line.split(" +++$+++ ")
    obj = {}
    for i, field in enumerate(fields):
        obj[field] = values[i]
    return obj

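# Load the movie lines file into a dict keyed by lineID, so that conversations
# can later be reassembled from their utterance ids.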
def _load_lines(config):
    lines_file_name = config["lines_file_name"]
    corpus_path = config["corpus_path"]
    lines_file_full_path = "%s/%s" % (corpus_path, lines_file_name)
    fields = config["movie_lines_fields"]

    lines = {}
    f = open(lines_file_full_path, 'r', encoding='iso-8859-1')
    for line in f:
        line_obj = _get_object(line, fields)
        lines[line_obj['lineID']] = line_obj
    f.close()
    return lines

def _collect_lines(conv_obj, lines):
    # Convert string to list (conv_obj["utteranceIDs"] == "['L598485', 'L598486', ...]")
    line_ids = eval(conv_obj["utteranceIDs"])
    # Reassemble lines
    conv_obj["lines"] = []
    for line_id in line_ids:
        conv_obj["lines"].append(lines[line_id])
    return conv_obj

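# Group the loaded lines into conversations using the conversations file, which
# lists the utterance ids that make up each exchange.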
def _load_conversations(lines, config):
    conversations = []
    corpus_path = config["corpus_path"]
    conversation_file_name = config["conversation_file_name"]
    conversation_file_full_path = "%s/%s" % (corpus_path, conversation_file_name)
    fields = config["movie_conversations_fields"]
    f = open(conversation_file_full_path, 'r', encoding='iso-8859-1')
    for line in f:
        conv_obj = _get_object(line, fields)
        conv_obj = _collect_lines(conv_obj, lines)
        conversations.append(conv_obj)
    f.close()
    return conversations

def _get_conversations(config):
    lines = {}
    conversations = []

    lines = _load_lines(config)
    print("lines count:", len(lines))
    conversations = _load_conversations(lines, config)
    print("conversations count:", len(conversations))
    return conversations

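# Pair every line of a conversation with the line that follows it, yielding
# (query, response) training pairs; pairs with an empty side are dropped.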
def _extract_sentence_pairs(conversations):
    pairs = []
    for conversation in conversations:
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                pairs.append([inputLine, targetLine])
    return pairs

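# Build the vocabulary and the length-filtered list of sentence pairs from the
# formatted data file.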
def _load_formatted_data(config):
    max_length = config["max_length"]
    corpus_name = config["corpus_name"]

    formatted_file_full_path = get_formatted_file_full_path(config)
    print("Start preparing training data ...")
    voc, pairs = _read_vocabulary(formatted_file_full_path, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = _filter_pairs(pairs, max_length)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

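# Remove words seen fewer than min_count times from the vocabulary, then keep
# only the pairs in which every word is still in the vocabulary.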
def _trim_rare_words(voc, pairs, min_count):
    voc.trim(min_count)
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len