C2CRS_System
C2CRS_System.py
是一个用于构建和训练对话推荐系统(Conversational Recommender System, C2CRS)的系统类。它继承自 BaseSystem
,并扩展了对话推荐系统的特定功能。
主要功能
-
系统初始化:
- 初始化推荐系统、对话系统和预训练模块。
- 设置数据加载器、词汇表、侧边数据(side data)和优化器。
- 支持恢复训练、交互模式和调试模式。
-
数据预处理:
- 扩展训练数据集,添加额外信息(如知识图谱和评论信息)。
- 初始化推荐、对话和预训练模块的属性。
-
预训练:
- 实现推荐模块的预训练,使用交叉熵损失进行优化。
- 支持在预训练阶段保存特定轮次的模型。
-
推荐任务:
- 训练推荐模块,使用交叉熵损失进行优化。
- 支持早停机制(early stopping)以避免过拟合。
- 在验证集和测试集上评估推荐性能,计算命中率(hit rate)等指标。
-
对话任务:
- 训练对话模块,使用生成任务的损失进行优化。
- 支持冻结参数(freeze parameters)以提高训练效率。
- 在验证集和测试集上评估对话性能,计算困惑度(PPL)和多样性指标。
-
模型保存与恢复:
- 支持在训练过程中保存模型。
- 支持从保存的模型中恢复训练。
-
日志记录与评估:
- 使用
loguru
记录训练过程中的日志。 - 支持多种评估指标,包括推荐任务的命中率和对话任务的困惑度、多样性等。
- 使用
-
交互模式:
- 提供与系统交互的接口(目前未实现具体逻辑)。
模块功能详解
1. 初始化
__init__
:- 初始化系统的基本属性,包括数据加载器、词汇表、侧边数据等。
- 调用
_init_token_attribute
、_init_rec_attribute
、_init_conv_attribute
和_init_pretrain_attribute
方法,分别初始化与词汇表、推荐、对话和预训练相关的属性。
2. 数据预处理
extend_datasets
:- 扩展训练、验证和测试数据集,添加额外信息(如视频信息)。
3. 预训练
pre_training
:- 初始化预训练优化器。
- 调用
pretrain_recommender_convergence
方法进行推荐模块的预训练。
pretrain_recommender_one_epoch
:- 训练预训练推荐模块的一个轮次。
valid_pretrain_recommender
:- 在验证集上评估预训练推荐模块的性能。
4. 推荐任务
train_recommender_default
:- 初始化推荐模块的优化器。
- 调用
train_recommender_convergence
方法进行推荐模块的训练。
train_recommender_one_epoch
:- 训练推荐模块的一个轮次。
valid_recommender
:- 在验证集上评估推荐模块的性能。
test_recommender
:- 在测试集上评估推荐模块的性能。
5. 对话任务
train_conversation_using_rec_model
:- 初始化对话模块的优化器。
- 调用
train_conversation_convergence
方法进行对话模块的训练。
train_conversation_one_epoch
:- 训练对话模块的一个轮次。
valid_conversation
:- 在验证集上评估对话模块的性能。
test_conversation
:- 在测试集上评估对话模块的性能。
6. 模型保存与恢复
is_early_stop
:- 检查是否满足早停条件。
save_model
:- 保存模型到指定路径。
restore_model_from_save
:- 从保存的路径中恢复模型。
7. 日志记录与评估
record_conv_gt_pred
、record_conv_gt
、record_conv_pred
:- 记录对话任务的预测和真实值。
get_file_writer
:- 获取文件写入器,用于记录日志。
convert_tensor_ids_to_tokens
:- 将张量 ID 转换为词汇表中的单词。
总结
C2CRS_System
适用于构建对话推荐系统,特别是在需要结合推荐和对话功能的场景中。它支持预训练、推荐和对话任务的训练和评估,能够有效地提升系统的性能和用户体验。