parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', default='datasets/Ciao/', help='dataset directory path: datasets/Ciao/Epinions')
parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
parser.add_argument('--embed_dim', type=int, default=64, help='the dimension of embedding')
parser.add_argument('--epoch', type=int, default=10, help='the number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate') # [0.001, 0.0005, 0.0001]
parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument('--lr_dc_step', type=int, default=30, help='the number of steps after which the learning rate decay')
parser.add_argument('--test', action='store_true', help='test')
args = parser.parse_args()
print(args)
here = os.path.dirname(os.path.abspath(__file__))
device = torch
gnn代码实现
最新推荐文章于 2024-11-01 15:49:22 发布
该代码片段展示了使用PyTorch进行推荐系统模型训练的过程。它定义了命令行参数,加载数据集,创建数据加载器,定义模型,设置优化器和学习率调度器,并执行训练和验证循环。此外,还实现了模型检查点的保存和最佳模型的选择。

最低0.47元/天 解锁文章
4381





