【Stable Baselines3】Custom environment code, PPO, SAC, discrete actions / continuous states

Problems encountered

The code looked perfectly normal, yet it kept raising the error below; at first I assumed the program itself was at fault.

The algorithm only supports (<class 'gym.spaces.discrete.Discrete'>,) as action spaces but Discrete(4) was provided

It turned out to be a library compatibility issue: the installed stable-baselines3 version still type-checks action spaces against the old gym package, so a Discrete space built from gymnasium fails the check even though it looks identical.
Correct:

from gym import spaces

Wrong:

import gymnasium as gym
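
The mismatch is easy to confirm by hand. Below is a minimal diagnostic sketch (assuming both gym and gymnasium are installed): the two Discrete classes print identically, but they come from different packages, so stable-baselines3's isinstance check rejects the gymnasium one.

import gym
import gymnasium

a = gym.spaces.Discrete(4)
b = gymnasium.spaces.Discrete(4)

# Both render as "Discrete(4)", yet they are distinct classes from different packages.
print(type(a).__module__)                  # gym.spaces.discrete
print(type(b).__module__)                  # gymnasium.spaces.discrete
print(isinstance(b, gym.spaces.Discrete))  # False -> exactly why the error above is raised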

Code example

import os

import gym
import numpy as np
import pandas as pd
from gym import spaces
from matplotlib import pyplot as plt
from pylab import mpl

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback

# Configure matplotlib so Chinese labels render correctly
mpl.rcParams["font.sans-serif"] = ["SimHei"]

e_sell = 0.3  # sell-back price factor applied when power is fed back to the grid
price = pd.read_csv('.//data//price_data.csv')  # prices are already divided by 1000
price = np.array(price['实时电价']).flatten()

class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface."""

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, applience):
        # The appliance's title may come in as a 1-tuple (trailing comma), so unwrap it
        self.title = applience.title[0] if type(applience.title) == tuple else applience.title
        self.power_rating = np.array(applience.power_rating, dtype=float).flatten()[0]
        self.T_ini = applience.T_ini  # earliest hour the task may start
        self.T_end = applience.T_end  # hour by which the task must be finished
        self.T_last = applience.T_last  # required run duration (hours)
        self.T_run = 0  # hours already run
        self.ds = applience.ds  # dissatisfaction parameter, e.g. 0.045 / 0.07
        self.t = 0
        self.prices = price
        self.n_state_space = 3  # state: [inside run window?, remaining window fraction, run progress]
        self.n_action_space = 2  # actions: 1 = on, 0 = off
        self.n_action_type = 'Discrete'
        # Build the spaces from gym.spaces (not gymnasium), otherwise SB3 rejects them
        self.action_space = spaces.Discrete(self.n_action_space)
        # For a continuous-action variant (e.g. SAC) a Box space would be used instead:
        # self.action_space = spaces.Box(low=-1, high=1, shape=(self.n_action_space,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'),
                                            shape=(self.n_state_space,), dtype=np.float32)
        self.run_list = np.zeros(24)  # run status at each hour
        self.power_list = np.zeros(24)  # power drawn at each hour
        self.reward_list = np.zeros(24)  # reward at each hour
        self.reward_detail = np.zeros([24, 2])  # [energy cost, penalty] at each hour

    def step(self, action):
        if self.T_ini <= self.t % 24 < self.T_end:  # inside the allowed run window
            cost_penalty = 0
            if self.T_run == 0:  # the task has not started yet
                if self.t < self.T_end - self.T_last:  # there is still enough time left to finish
                    if action > 0:
                        power = self.power_rating
                        self.T_run += 1
                        self.run_list[self.t] = self.T_run
                        cost_penalty += self.ds
                    else:
                        power = 0
                        self.run_list[self.t] = 0
                        cost_penalty += self.ds
                else:
                    power = self.power_rating
                    self.T_run += 1
                    self.run_list[self.t] = self.T_run
                    cost_penalty += self.ds
                    if action <= 0:
                        cost_penalty += 10
            elif self.T_last > self.T_run > 0:  # running, but the required duration is not yet reached
                power = self.power_rating
                self.T_run += 1
                self.run_list[self.t] = self.T_run
                cost_penalty += self.ds
                if action <= 0:
                    cost_penalty += 10
            else:  # the required run time has already been completed
                power = 0
                self.run_list[self.t] = 0
            if power >= 0:
                cost_energy = self.prices[self.t] * power
            else:
                cost_energy = e_sell * self.prices[self.t] * power
            self.reward_list[self.t] = -cost_energy - cost_penalty
            self.reward_detail[self.t] = [cost_energy, cost_penalty]
            self.power_list[self.t] = power
            self.t = (self.t + 1) % 24
        else:  # outside the allowed run window
            self.run_list[self.t] = 0
            self.power_list[self.t] = 0
            self.reward_list[self.t] = 0
            self.reward_detail[self.t] = [0, 0]
            self.t = (self.t + 1) % 24

        if self.T_ini <= self.t < self.T_end:  # the next hour is inside the run window
            observation = [1, (self.T_end - self.t) / (self.T_end - self.T_ini), self.T_run / self.T_last]
        else:
            observation = [0, 0, 0]

        reward = self.reward_list[self.t - 1] if self.t > 0 else self.reward_list[24 - 1]

        dones = 1 if self.t == 12 else 0  # one episode is a full 24-hour cycle starting from hour 12

        return np.array(observation, dtype=np.float32), reward, dones, {}

    def reset(self, seed=None, options=None):
        self.t = 12  # episodes start at hour 12
        self.T_run = 0  # hours already run
        self.power_list = np.zeros(24)  # power drawn at each hour
        self.reward_list = np.zeros(24)  # reward at each hour
        self.run_list = np.zeros(24)  # run status at each hour
        self.reward_detail = np.zeros([24, 2])  # [energy cost, penalty] at each hour
        if self.T_ini <= self.t < self.T_end:  # inside the run window
            observation = [1, (self.T_end - self.t) / (self.T_end - self.T_ini), self.T_run / self.T_last]
        else:
            observation = [0, 0, 0]
        return np.array(observation, dtype=np.float32)

    def get_output(self, visualize=False):
        column = ['电价', self.title + '功率', self.title + '运行状态', self.title + '电费',
                  self.title + '舒适度', self.title + '总成本']
        data = np.concatenate([price.reshape(24, 1), self.power_list.reshape(24, 1), self.run_list.reshape(24, 1),
                               self.reward_detail, self.reward_list.reshape(24, 1)], axis=1)
        data = pd.DataFrame(data, columns=column)
        if visualize:
            figure = plt.figure(tight_layout=True)
            cost_e, cost_c, cost_a = sum(data[self.title + '电费']), sum(data[self.title + '舒适度']), \
                                     sum(data[self.title + '总成本'])
            cost_e, cost_c, cost_a = "{:.2f}".format(cost_e), "{:.2f}".format(cost_c), "{:.2f}".format(cost_a)
            # Subplot 1 of 3: electricity price
            plt.subplot(3, 1, 1)
            plt.plot(range(24), price[:24], linestyle=':', marker='o', color="red")
            plt.xlabel("电价")
            plt.title(f'{cost_e, cost_c, cost_a}')
            # Subplot 2 of 3: run status
            plt.subplot(3, 1, 2)
            plt.plot(range(24), data[self.title + '运行状态'], linestyle='-', marker='s', color="black")
            plt.xlabel(self.title + '运行状态')
            # Subplot 3 of 3: power
            plt.subplot(3, 1, 3)
            plt.plot(range(24), data[self.title + '功率'], linestyle='-', marker='s', color="blue")
            plt.xlabel(self.title + '功率')
            plt.show()
        return data

class wm2():
    def __init__(self):
        self.title = 'WM'
        self.power_rating = 0.7
        self.T_ini = 15
        self.T_end = 23
        self.T_last = 3
        self.ds = 0.2  # dissatisfaction parameter 0.045 0.07
        self.thita = 0.001
        self.dispatch = 0


def output_model_detail(model, env):
    episode_reward = 0
    state = env.reset()
    record = []
    for step in range(24):
        action, _states = model.predict(state)
        # print(state)
        obs, rewards, dones, info = env.step(action)
        # print(action, rewards)
        # print(info)
        if not step:
            record_name = list(info.keys())
            record.append(list(info.values()))
        else:
            record.append(list(info.values()))
        episode_reward += rewards
        if dones:
            break
        state = obs
    record = np.array(record)
    record_d = pd.DataFrame(record)
    record_d.columns = record_name
    output = env.get_output()
    output.to_csv(env.title + '输出结果.csv', header=True, encoding='utf_8_sig')
    return episode_reward

env = CustomEnv(wm2())
# Create a PPO model
model = PPO("MlpPolicy", env, verbose=1)

# CheckpointCallback periodically saves a snapshot of the model during training
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')

# Train the model
model.learn(total_timesteps=50000, callback=checkpoint_callback, tb_log_name="ppo_custom_env")

# Save the final model
best_model_path = os.path.join(checkpoint_callback.save_path, 'best_model')
model.save(best_model_path)

# Load the saved model back
loaded_model = PPO.load(best_model_path, env=env)

obs = env.reset()
while True:
    action, _states = loaded_model.predict(obs)
    print(obs, action)
    obs, rewards, dones, info = env.step(action)
    if dones:
        break

output = env.get_output(visualize=True)
output.to_csv(f'{env.title}_输出结果.csv', header=True, encoding='utf-8')
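
Before training, it is also worth validating the environment with SB3's built-in checker. A minimal sketch, assuming check_env is available in the installed stable-baselines3 version; it warns if the spaces or the step/reset return values don't match what the library expects.

from stable_baselines3.common.env_checker import check_env

check_env(CustomEnv(wm2()), warn=True)

A note on the algorithms in the title: PPO works with the Discrete action space defined above, but SAC only supports continuous (Box) action spaces, so switching to SAC would require enabling the commented-out Box action space in __init__ and interpreting the continuous action (for example by thresholding it) inside step().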