td.train()和td.eval()

文章讲述了深度学习中模型的训练模式(td.train())和评估模式(td.eval())的区别,涉及BatchNormalization和Dropout的行为变化,以及PyTorch在两种模式下的梯度管理和内存优化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

td.train() 是用于将模型切换为训练模式的一行代码。在深度学习中,模型通常会有两种模式:训练模式(training mode)和评估模式(evaluation mode)。

td 是您定义的模型对象,通过调用 td.train() 可以将模型切换到训练模式。主要影响是,当模型处于训练模式时,将启用一些针对训练过程中的特定功能或操作的设置。

在训练模式下,常见的设置和行为包括:

  • Batch Normalization 和 Dropout:在模型中使用了 Batch Normalization 或 Dropout 层时,这些层会在训练模式下以特定方式处理输入,进行归一化和随机失活操作。而在评估模式下,则会按照不同的规则处理输入。

  • 记录梯度信息:在训练模式下,PyTorch 默认会保存每个参数的梯度信息,以便进行反向传播和参数更新。而在评估模式下,梯度信息不会被保留,以降低内存消耗。

  • Dropout 随机性:在训练模式下,Dropout 层会以一定的概率进行神经元的随机失活操作,以防止过拟合。而在评估模式下,Dropout 不进行任何操作。

请注意,在使用 td.train() 切换模型为训练模式之前,通常要将模型的参数和数据移动到相应的设备(例如 GPU),以便进行实际的训练过程。

td.eval() 是用于将模型切换为评估模式的一行代码。在深度学习中,模型通常会有两种模式:训练模式(training mode)和评估模式(evaluation mode)。

td 是您定义的模型对象,通过调用 td.eval() 可以将模型切换到评估模式。主要影响是,在评估模式下,模型会禁用一些只在训练过程中使用的特定功能。

在评估模式下,常见的设置和行为包括:

  • Batch Normalization 和 Dropout:在评估模式下,Batch Normalization 层将使用完整的统计信息归一化输入,而不是使用每个批次的统计信息。Dropout 层也将停用,不进行随机失活操作。

  • 记录梯度信息:在评估模式下,PyTorch 不会保存梯度信息,以降低内存消耗。

  • Dropout 随机性:在评估模式下,Dropout 层不进行任何随机操作,保持原样。

通过将模型切换到评估模式,可以确保在模型推理和性能评估时遵循正确的设置。请注意,在开始评估之前,通常要将模型的参数和数据移动到相应的设备(例如 GPU)

import copy import ray from ray import tune from egpo_utils.cql.cql import CQLTrainer from egpo_utils.common import evaluation_config, ILCallBack, CQLInputReader from egpo_utils.expert_guided_env import ExpertGuidedEnv from egpo_utils.train import get_train_parser from egpo_utils.train.train import train import os data_set_file_path = os.path.join(os.path.dirname(__file__), 'expert_traj_500.json') def get_data_sampler_func(ioctx): return CQLInputReader(data_set_file_path) eval_config = copy.deepcopy(evaluation_config) eval_config["input"] = "sampler" # important to use pgdrive online evaluation eval_config["env_config"]["random_spawn"] = True if __name__ == '__main__': print(data_set_file_path) try: file = open(data_set_file_path) except FileNotFoundError: raise FileExistsError("Please collect dataset by using collect_dataset.py at first") assert ray.__version__ == "1.3.0" or ray.__version__ == "1.2.0", "ray 1.3.0 is required" args = get_train_parser().parse_args() exp_name = "CQL" or args.exp_name stop = {"timesteps_total": 100_0000_00000} config = dict( # ===== Evaluation ===== env=ExpertGuidedEnv, env_config=evaluation_config["env_config"], input_evaluation=["simulation"], evaluation_interval=1, evaluation_num_episodes=30, evaluation_config=eval_config, evaluation_num_workers=2, metrics_smoothing_episodes=20, # ===== Training ===== # cql para lagrangian=False, # Automatic temperature (alpha prime) control temperature=5, # alpha prime in paper, 5 is best in pgdrive min_q_weight=0.2, # best bc_iters=20_0000, # bc_iters > 20_0000 has no obvious improvement # offline setting no_done_at_end=True, input=get_data_sampler_func, optimization=dict(actor_learning_rate=1e-4, critic_learning_rate=1e-4, entropy_learning_rate=1e-4), rollout_fragment_length=200, prioritized_replay=False, horizon=2000, target_network_update_freq=1, timesteps_per_iteration=1000, learning_starts=10000, clip_actions=False, normalize_actions=True, num_cpus_for_driver=0.5, # No extra worker used for learning. But this config impact the evaluation workers. num_cpus_per_worker=0.1, # num_gpus_per_worker=0.1 if args.num_gpus != 0 else 0, num_gpus=0.2 if args.num_gpus != 0 else 0, framework="torch" ) train( CQLTrainer, exp_name=exp_name, keep_checkpoints_num=5, stop=stop, config=config, num_gpus=args.num_gpus, # num_seeds=2, num_seeds=5, custom_callback=ILCallBack, # test_mode=True, # local_mode=True ) 运行结果怎么可视化
04-03
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值