ElegantRL Source Code Walkthrough
Framework structure
- Create the agent. Each algorithm is already wrapped as an agent class, so selecting an algorithm just means importing the corresponding agent:

  from elegantrl.agent import AgentDuelingDQN
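Concretely, swapping algorithms only swaps the agent class; the rest of the pipeline stays the same. A minimal sketch, assuming the classic `elegantrl.agent` layout (`AgentDoubleDQN` is shown only as an illustrative alternative):

```python
from elegantrl.agent import AgentDuelingDQN

# Choosing an algorithm = instantiating a different agent class;
# the training pipeline below stays identical.
agent = AgentDuelingDQN()

# e.g. to try another off-policy algorithm instead:
# from elegantrl.agent import AgentDoubleDQN
# agent = AgentDoubleDQN()
```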
- Choose the environment:

  PreprocessEnv(env=gym.make('LunarLander-v2'))

  # Stock-trading environment:
  from envs.FinRL.StockTrading import StockTradingEnv

  # Inside the environment, load the data and add the technical indicators:
  self.price_ary, self.tech_ary = self.load_data(cwd, if_eval, ticker_list, tech_indicator_list, start_date, end_date, env_eval_date)
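A minimal sketch of wrapping a Gym task, assuming `PreprocessEnv` lives in `elegantrl.env` (the module path varies across versions). The wrapper records the metadata that the agent and the replay buffer need later:

```python
import gym
from elegantrl.env import PreprocessEnv  # module path is an assumption

env = PreprocessEnv(env=gym.make('LunarLander-v2'))

# The wrapper extracts the env metadata that agent.init() and ReplayBuffer use.
print(env.state_dim, env.action_dim, env.if_discrete)
```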
- Train and evaluate:

  train_and_evaluate(args)

  # Multi-process version:
  args.rollout_num = 4
  train_and_evaluate_mp(args)
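Putting the three pieces together, a minimal end-to-end run might look like the sketch below (the exact `Arguments` fields differ between ElegantRL versions, so treat them as assumptions):

```python
import gym
from elegantrl.agent import AgentDuelingDQN
from elegantrl.env import PreprocessEnv
from elegantrl.run import Arguments, train_and_evaluate, train_and_evaluate_mp

args = Arguments(if_on_policy=False)   # DQN-family agents are off-policy
args.agent = AgentDuelingDQN()
args.env = PreprocessEnv(env=gym.make('LunarLander-v2'))

train_and_evaluate(args)               # single-process training loop

# Multi-process variant: several rollout workers feed one learner.
# args.rollout_num = 4
# train_and_evaluate_mp(args)
```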
  Internally, train_and_evaluate performs the following steps:

  - Initialize the agent:

    agent.init(net_dim, state_dim, action_dim, if_per)

  - Create the networks (this happens inside agent.init; see the sketch below)
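A hedged sketch of what `agent.init` plausibly does for a DQN-family agent: build the online network, clone it into a target network, and create the optimizer. `QNetDuel` follows the naming in ElegantRL's net.py, but treat the details as assumptions, not the exact implementation:

```python
import torch
from elegantrl.net import QNetDuel  # dueling Q-network; exact path is an assumption

class AgentDuelingDQNSketch:
    def init(self, net_dim, state_dim, action_dim, if_per=False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Online network plus a target network with identical architecture.
        self.cri = QNetDuel(net_dim, state_dim, action_dim).to(self.device)
        self.cri_target = QNetDuel(net_dim, state_dim, action_dim).to(self.device)
        self.cri_target.load_state_dict(self.cri.state_dict())
        self.optimizer = torch.optim.Adam(self.cri.parameters(), lr=1e-4)
        self.if_per = if_per  # with PER, TD-errors become sampling priorities
```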
  - Create the replay buffer:

    buffer = ReplayBuffer(max_len=max_memo + max_step,
                          state_dim=state_dim,
                          action_dim=1 if if_discrete else action_dim,
                          if_on_policy=if_on_policy,
                          if_per=if_per,
                          if_gpu=True)

    Question: what does self.tree = BinarySearchTree(max_len) do? See the sum-tree sketch below.
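To the question above: BinarySearchTree is the sum tree that ReplayBuffer builds when if_per=True, so that prioritized experience replay (PER) can sample transitions in proportion to their priority in O(log n). A minimal, self-contained sketch of the idea (not ElegantRL's exact implementation):

```python
import numpy as np

class SumTreeSketch:
    """Minimal sum tree: leaves hold priorities, inner nodes hold partial sums."""

    def __init__(self, max_len):
        self.max_len = max_len
        self.tree = np.zeros(2 * max_len - 1)   # inner nodes followed by leaves

    def update(self, data_idx, priority):
        idx = data_idx + self.max_len - 1       # leaf position in the tree array
        delta = priority - self.tree[idx]
        while idx >= 0:                         # propagate the change up to the root
            self.tree[idx] += delta
            idx = (idx - 1) // 2 if idx > 0 else -1

    def sample(self):
        v = np.random.uniform(0, self.tree[0])  # tree[0] holds the total priority
        idx = 0
        while idx < self.max_len - 1:           # descend until a leaf is reached
            left = 2 * idx + 1
            if v <= self.tree[left]:
                idx = left
            else:
                v -= self.tree[left]
                idx = left + 1
        return idx - (self.max_len - 1)         # map the leaf back to a data index

tree = SumTreeSketch(max_len=4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, p)
print(tree.sample())  # index 3 is drawn ~40% of the time (4.0 / 10.0)
```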
  - Collect data:

    with torch.no_grad():  # update the replay buffer
        steps = explore_before_training(env, buffer, target_step, reward_scale, gamma)

    # Exploration uses the environment's step():
    next_state, reward, done, _ = env.step(action)
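A hedged sketch of what explore_before_training plausibly does: roll random actions through the environment and append the scaled transitions to the buffer. The `append_buffer` method name and the `(reward, mask, action)` layout are assumptions about ElegantRL's internals:

```python
def explore_before_training_sketch(env, buffer, target_step, reward_scale, gamma):
    """Fill the buffer with random-action transitions before any learning."""
    state = env.reset()
    steps = 0
    while steps < target_step:
        action = env.action_space.sample()              # random exploration
        next_state, reward, done, _ = env.step(action)
        # ElegantRL-style storage: scaled reward plus a 'mask' = gamma * (1 - done),
        # so the Q-target can later be written as r + mask * max_a Q(s', a).
        other = (reward * reward_scale, 0.0 if done else gamma, action)
        buffer.append_buffer(state, other)              # method name is an assumption
        state = env.reset() if done else next_state
        steps += 1
    return steps
```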