Gymnasium Tutorial: Building a Custom GridWorld Environment from Scratch

Why Build a Custom Environment?

In reinforcement learning (RL) research, the standard environments often fall short of what a specific problem needs. Have you ever run into one of these situations?

  • The existing environments do not match your research goal.
  • The public environments are too complex to debug an algorithm against.
  • You need precise control over environment parameters to test a hypothesis.

This tutorial walks you through building an extensible GridWorld environment and teaches the core patterns of Gymnasium environment development. By the end, you will be able to:

  • Design a custom environment that conforms to the Gymnasium API standard
  • Define standardized observation and action spaces
  • Add rendering to visualize the environment
  • Register the environment and integrate it seamlessly with mainstream RL frameworks
  • Avoid common pitfalls in environment design

Environment Design Specification and Core Components

Gymnasium environments follow a strict interface specification that keeps them compatible with a wide range of RL algorithms. The core components are:

(Diagram omitted: the core components of a Gymnasium environment are the observation_space and action_space definitions plus the reset(), step(), render(), and close() methods.)

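Before diving into GridWorld, it helps to see the bare contract every Gymnasium environment implements. The toy environment below is a minimal, runnable sketch of that contract (a made-up one-step environment for orientation only, not part of the GridWorld code that follows):

import gymnasium as gym
from gymnasium import spaces

class OneStepEnv(gym.Env):
    """Illustrative toy env: every episode ends after a single step."""
    metadata = {"render_modes": [], "render_fps": 4}

    def __init__(self):
        # Every environment must declare these two spaces
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)   # seeds self.np_random
        return 0, {}               # observation, info

    def step(self, action):
        # observation, reward, terminated, truncated, info
        return 1, 0.0, True, False, {}
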
Key design decisions:

  • Task goal: navigate to a randomly placed target
  • Observation space: a Dict holding the agent and target coordinates
  • Action space: 4 discrete directions (up, down, left, right)
  • Reward: +1 on reaching the target, 0 otherwise (sparse reward)
  • Termination: the agent reaches the target position

Full Implementation, Step by Step

1. Basic Environment Structure

from enum import Enum
import numpy as np
import pygame
import gymnasium as gym
from gymnasium import spaces

class Actions(Enum):
    RIGHT = 0
    UP = 1
    LEFT = 2
    DOWN = 3

class GridWorldEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
    
    def __init__(self, render_mode=None, size=5):
        self.size = size  # size of the square grid (number of cells per side)
        self.window_size = 512  # size of the PyGame rendering window in pixels
        
        # Observation space: the (x, y) grid coordinates of the agent and the target
        self.observation_space = spaces.Dict({
            "agent": spaces.Box(0, size-1, shape=(2,), dtype=int),
            "target": spaces.Box(0, size-1, shape=(2,), dtype=int),
        })
        
        # Action space: 4 discrete actions
        self.action_space = spaces.Discrete(4)
        
        # Map each discrete action to a movement direction on the grid
        self._action_to_direction = {
            Actions.RIGHT.value: np.array([1, 0]),
            Actions.UP.value: np.array([0, 1]),
            Actions.LEFT.value: np.array([-1, 0]),
            Actions.DOWN.value: np.array([0, -1]),
        }
        
        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._agent_location = np.array([-1, -1], dtype=int)
        self._target_location = np.array([-1, -1], dtype=int)

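With only the constructor in place you can already instantiate the class and inspect the spaces. A quick sanity-check sketch (the exact space repr may vary slightly between Gymnasium versions):

env = GridWorldEnv(size=5)
print(env.observation_space)
# e.g. Dict('agent': Box(0, 4, (2,), int64), 'target': Box(0, 4, (2,), int64))
print(env.observation_space.sample())  # e.g. {'agent': array([2, 0]), 'target': array([4, 1])}
print(env.action_space.sample())       # an integer in {0, 1, 2, 3}
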
2. Observations and Auxiliary Info

    def _get_obs(self):
        """Translate the internal environment state into an observation."""
        return {"agent": self._agent_location, "target": self._target_location}
    
    def _get_info(self):
        """Compute auxiliary info: the Manhattan distance between agent and target."""
        return {
            "distance": np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }

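Because ord=1 selects the L1 norm, the reported distance is the Manhattan distance. A quick standalone check (not part of the environment code): with the agent at (1, 1) and the target at (3, 4), the distance is |1 - 3| + |1 - 4| = 5, exactly the number of moves still needed.

import numpy as np

print(np.linalg.norm(np.array([1, 1]) - np.array([3, 4]), ord=1))  # 5.0
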
3. Implementing reset

    def reset(self, seed=None, options=None):
        """Start a new episode and return the initial observation and info."""
        super().reset(seed=seed)  # required so that self.np_random is correctly seeded
        
        # Place the agent at a random cell
        self._agent_location = self.np_random.integers(
            0, self.size, size=2, dtype=int
        )
        
        # Place the target at a random cell, re-sampling until it differs from the agent's
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )
        
        if self.render_mode == "human":
            self._render_frame()
            
        return self._get_obs(), self._get_info()

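A quick way to confirm that seeding is wired up correctly (a standalone sketch, assuming the imports and class definition above): resetting twice with the same seed must produce the same agent and target placement.

env = GridWorldEnv(size=5)
obs_a, _ = env.reset(seed=42)
obs_b, _ = env.reset(seed=42)
assert np.array_equal(obs_a["agent"], obs_b["agent"])
assert np.array_equal(obs_a["target"], obs_b["target"])
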
4. Core Step Logic

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: the action chosen by the agent (0-3, one per direction)

        Returns:
            observation: the new observation
            reward: the immediate reward
            terminated: whether a terminal state has been reached
            truncated: whether the time limit has been hit
            info: auxiliary information
        """
        # Map the action to a movement direction
        direction = self._action_to_direction[action]
        
        # Update the agent position, clipping so it cannot leave the grid
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )
        
        # Check whether the agent has reached the target
        terminated = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if terminated else 0  # sparse reward
        
        if self.render_mode == "human":
            self._render_frame()
            
        return (
            self._get_obs(),
            reward,
            terminated,
            False,  # this environment never truncates on its own
            self._get_info(),
        )

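With reset and step in place you can already exercise the boundary handling. The sketch below walks the agent left ten times; thanks to np.clip its x coordinate ends up pinned at 0 instead of leaving the grid.

env = GridWorldEnv(size=5)
obs, _ = env.reset(seed=0)
for _ in range(10):
    obs, reward, terminated, truncated, info = env.step(Actions.LEFT.value)
assert obs["agent"][0] == 0  # the agent never leaves the grid
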
5. Implementing Rendering

    def render(self):
        """Return an RGB frame or update the window, depending on the render mode."""
        if self.render_mode == "rgb_array":
            return self._render_frame()
    
    def _render_frame(self):
        """Render the current frame."""
        # Lazily create the PyGame window and clock (human mode only)
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size)
            )
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()
            
        # Create the canvas and fill it with a white background
        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_square_size = self.window_size / self.size  # pixel size of a single grid cell
        
        # Draw the target (red square)
        pygame.draw.rect(
            canvas,
            (255, 0, 0),  # red
            pygame.Rect(
                pix_square_size * self._target_location,
                (pix_square_size, pix_square_size),
            ),
        )
        
        # Draw the agent (blue circle)
        pygame.draw.circle(
            canvas,
            (0, 0, 255),  # blue
            (self._agent_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )
        
        # Draw the grid lines
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,  # black
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,  # black
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )
        
        # In human mode, copy the canvas to the window and keep to the target frame rate
        if self.render_mode == "human":
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array mode
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
    
    def close(self):
        """Shut down the rendering window."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

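A quick way to exercise the rgb_array path (an illustrative sketch; the frame shape follows from window_size=512):

env = GridWorldEnv(render_mode="rgb_array", size=5)
env.reset(seed=0)
frame = env.render()
print(frame.shape)  # (512, 512, 3) uint8 RGB array
env.close()

Frames produced this way can be fed to video-recording wrappers such as gymnasium.wrappers.RecordVideo, or saved with matplotlib for debugging.
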
6. Registering and Using the Environment

# Register the environment
from gymnasium.envs.registration import register

register(
    id="GridWorld-v0",
    entry_point=GridWorldEnv,
    max_episode_steps=300,  # gym.make adds a TimeLimit wrapper with this limit
)

# Usage example
if __name__ == "__main__":
    # Create an environment instance
    env = gym.make("GridWorld-v0", render_mode="human", size=5)
    
    # Run a quick test with a random policy
    observation, info = env.reset()
    for _ in range(1000):
        action = env.action_space.sample()  # random action
        observation, reward, terminated, truncated, info = env.step(action)
        
        if terminated or truncated:
            observation, info = env.reset()
    
    env.close()

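Because the registration passes max_episode_steps=300, gym.make wraps the environment in a TimeLimit wrapper: after 300 steps, truncated comes back as True even though GridWorldEnv itself never truncates. A small sketch to observe this (the large grid makes it unlikely that a random walk reaches the goal first):

env = gym.make("GridWorld-v0", size=50)
obs, info = env.reset(seed=0)
for t in range(400):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        print(t + 1, terminated, truncated)  # typically: 300 False True
        break
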
Advanced Environment Design Techniques

Reward Function Optimization

A sparse reward (+1 only when the goal is reached) makes learning hard. Two common alternatives:

# Option 1: add a small per-step penalty to encourage short paths
reward = 1.0 if terminated else -0.01

# Option 2: distance-based reward shaping; store the previous distance on the env
# (initialise self._prev_distance in reset() before using this)
distance = np.linalg.norm(self._agent_location - self._target_location, ord=1)
reward = 1.0 if terminated else (self._prev_distance - distance) * 0.1
self._prev_distance = distance

Parameterized Environment Design

def __init__(self, render_mode=None, size=5, reward_scale=1.0, step_penalty=0.0):
    self.size = size
    self.reward_scale = reward_scale
    self.step_penalty = step_penalty
    # ... rest of the initialisation ...
    
def step(self, action):
    # ... movement logic ...
    if terminated:
        reward = self.reward_scale
    else:
        reward = -self.step_penalty
    # ... return the results ...

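Assuming the extended constructor above replaces the original one, gym.make forwards extra keyword arguments straight to __init__, so the new parameters are available at creation time:

env = gym.make("GridWorld-v0", size=8, reward_scale=10.0, step_penalty=0.05)
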
Common Mistakes and Debugging

# Mistake 1: forgetting to call super().reset(), which breaks the random seeding
def reset(self, seed=None, options=None):
    # super().reset(seed=seed)  # this call is required

# Mistake 2: not handling the grid boundary, so the agent can walk off the grid
self._agent_location = self._agent_location + direction  # np.clip is missing

# Mistake 3: an observation space that does not match what the env actually returns
self.observation_space = spaces.Box(0, size-1, shape=(2,))  # the env returns a dict, not a Box

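Many of these mistakes are caught automatically by Gymnasium's environment checker, which resets and steps the environment and compares the returned values against the declared spaces. A minimal sketch:

from gymnasium.utils.env_checker import check_env

env = gym.make("GridWorld-v0").unwrapped   # check the raw env, without wrappers
check_env(env, skip_render_check=True)     # raises or warns if the API contract is violated
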
Wrapper Usage Examples

# 1. Flatten the observation (turn the dict into a single array)
from gymnasium.wrappers import FlattenObservation
env = FlattenObservation(gym.make("GridWorld-v0"))
print(env.observation_space)  # Box(0, 4, (4,), int64)

# 2. Normalize observations
from gymnasium.wrappers import NormalizeObservation
env = NormalizeObservation(FlattenObservation(gym.make("GridWorld-v0")))

# 3. Custom wrapper that observes the target position relative to the agent
class RelativePosition(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Box(
            -self.env.size + 1, self.env.size - 1, shape=(2,), dtype=int
        )
        
    def observation(self, obs):
        return obs["target"] - obs["agent"]

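The custom wrapper is used like any built-in one (an illustrative sketch):

env = RelativePosition(gym.make("GridWorld-v0"))
obs, _ = env.reset(seed=0)
print(obs)  # e.g. array([ 2, -3]): the vector from the agent to the target
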
Environment Evaluation and Visualization

Use the following code to analyze how the environment behaves under a random policy:

def analyze_environment(env_id="GridWorld-v0", episodes=100):
    env = gym.make(env_id)
    steps_per_episode = []
    rewards_per_episode = []
    
    for _ in range(episodes):
        obs, _ = env.reset()
        total_reward = 0
        steps = 0
        
        while True:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            steps += 1
            
            if terminated or truncated:
                steps_per_episode.append(steps)
                rewards_per_episode.append(total_reward)
                break
    
    print(f"Average steps: {np.mean(steps_per_episode):.2f} ± {np.std(steps_per_episode):.2f}")
    print(f"Average reward: {np.mean(rewards_per_episode):.2f} ± {np.std(rewards_per_episode):.2f}")
    
    # Plot the distribution of episode lengths
    import matplotlib.pyplot as plt
    plt.hist(steps_per_episode, bins=20)
    plt.title("Episode Length Distribution")
    plt.xlabel("Steps")
    plt.ylabel("Frequency")
    plt.show()

analyze_environment()

Extensions and Applications

GridWorld can be extended into more complex environments:

(Diagram omitted: possible GridWorld extensions.)

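As one concrete example of such an extension, static obstacles can be added by storing a set of blocked cells and rejecting moves into them. The sketch below is hypothetical and assumes self._obstacles (a list of NumPy cell coordinates) is sampled in reset():

def step(self, action):
    direction = self._action_to_direction[action]
    proposed = np.clip(self._agent_location + direction, 0, self.size - 1)
    # Only move if the proposed cell is not blocked by an obstacle
    if not any(np.array_equal(proposed, cell) for cell in self._obstacles):
        self._agent_location = proposed
    terminated = np.array_equal(self._agent_location, self._target_location)
    reward = 1 if terminated else 0
    return self._get_obs(), reward, terminated, False, self._get_info()
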
Summary and Best Practices

When creating a custom environment, follow these principles:

  1. Interface consistency: implement the Gymnasium interface exactly
  2. Reproducibility: handle random seeds correctly
  3. Observation design: include everything the agent needs to decide, and nothing redundant
  4. Reward engineering: balance sparsity against guidance
  5. Incremental complexity: implement the core first, then add features
  6. Thorough testing: verify boundary conditions and edge cases

With this tutorial you now have the complete workflow for building a Gymnasium-compatible environment. This GridWorld framework can serve as an experimental platform for reinforcement learning research and help you study the interplay between agent behavior and environment dynamics.


