import datetime
import pathlib
import threading
import time
import hydra
import numpy as np
import torch
import zmq # ZeroMQ用于进程间通信
from typing import Optional
import dill # 用于Python对象序列化
from omegaconf import OmegaConf # 配置管理库
# Project-local imports
from agent.schema import RobotState, RobotObsShape # 机器人状态和观察数据结构,消息定义发送
from agent.utils import (
dict_apply, # 递归应用函数到字典中的张量
interpolate_image_batch, # 批量调整图像尺寸并进行归一化
)
from controller.policy.dexgraspvla_controller import DexGraspVLAController # 抓取控制策略
import logging
import cv2
import os
# Directory containing this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Locally vendored DINOv2 ViT-B/14 pretrained weights, resolved relative to
# this file: ../../../dinov2/checkpoints/dinov2_vitb14_pretrain.pth
pretrain_path = os.path.join(current_dir, "..", "..", "..", "dinov2", "checkpoints","dinov2_vitb14_pretrain.pth")
# Logging setup
def log_init():
    """Create and configure the module-level logger.

    Returns:
        logging.Logger: a DEBUG-level logger with one console handler.

    Safe to call more than once: the console handler is only attached on
    the first call, so repeated initialization does not produce duplicate
    log lines (the original appended a new handler every call).
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        # `fmt` instead of `format` — avoid shadowing the builtin.
        fmt = (
            "[%(asctime)s %(levelname)s %(filename)s %(funcName)s:%(lineno)d] %(message)s"
        )
        handler = logging.StreamHandler()  # console handler
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S"))
        logger.addHandler(handler)
    logger.info("Logger inited")
    return logger
logger = log_init()  # module-wide logger
class Robot:
    """Grasping-robot server: loads the DexGraspVLA controller, serves a
    ZeroMQ REP socket (observation in -> action chunk out), and runs a
    keyboard-driven state machine (IDLE -> INITALIZING -> ACTING -> ...)."""
    def __init__(self, config):
        """
        Initialize the robot system from a config object.

        Expected config fields:
            - controller_checkpoint_path: path to the controller checkpoint
            - device: compute device ("cpu"/"cuda")
            - port: port the ZeroMQ REP socket binds to
            - executions_per_action_chunk: actions executed per predicted chunk
        """
        # Load the pretrained controller checkpoint. dill (not plain pickle)
        # is used because the checkpoint was saved with dill, which can
        # serialize objects pickle cannot (lambdas, local functions, ...).
        # SECURITY NOTE(review): torch.load deserializes arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint_path = config.controller_checkpoint_path
        payload = torch.load(checkpoint_path, pickle_module=dill)
        # Redirect the obs-encoder head to the locally vendored DINOv2 weights
        # instead of whatever absolute path was baked into the checkpoint.
        payload["cfg"]["policy"]["obs_encoder"]["model_config"]["head"][
            "local_weights_path"
        ] = pretrain_path
        cfg = payload["cfg"]
        cls = hydra.utils.get_class(cfg._target_)  # workspace class named in the checkpoint config
        workspace = cls(cfg)  # build the workspace
        workspace.load_payload(payload, exclude_keys=None, include_keys=None)  # restore model weights
        # The controller policy restored from the workspace.
        self.controller: DexGraspVLAController
        self.controller = workspace.model
        # Runtime parameters.
        self.device = config.device
        self.executions_per_action_chunk = config.executions_per_action_chunk
        # Inference mode; fall back to CPU when CUDA is unavailable.
        self.controller.eval()
        print(f'!! torch.cuda.is_available()={torch.cuda.is_available()}')
        self.controller.to(self.device if torch.cuda.is_available() else "cpu")
        # Latest sensor readings (populated by _get_obs).
        self.head_image = None  # head camera frame
        self.wrist_image = None  # wrist camera frame
        self.proprioception = None  # joint angles / gripper state
        # ZeroMQ REP socket: strict recv -> send alternation with the client.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.port = config.port
        self.socket.bind(f"tcp://*:{self.port}")
        # State machine starts idle.
        self.state = RobotState.IDLE
        # Manual-control listener thread (reads stdin; see listening_mannual).
        self.resetting_action = None
        self.listening_thread = threading.Thread(target=self.listening_mannual)
        self.listening_thread.start()
    def _parse_obs(self, message: "zmq.Frame") -> Optional[dict]:
        """Parse one binary observation frame received over the socket.

        Frame layout (per RobotObsShape): head image bytes, then wrist image
        bytes, then float32 proprioception state at the tail.

        Returns:
            dict with "head_image", "wrist_image", "proprioception" arrays,
            or None (without replying) when the frame size is wrong.
        """
        # Reject frames of unexpected size.
        if len(message) != RobotObsShape.CHUNK_SIZE:
            logger.error(
                f"Invalid message size, required {RobotObsShape.CHUNK_SIZE} bytes"
            )
            return None
        # Head camera image (uint8); .buffer is the frame's zero-copy memoryview.
        head_image = np.frombuffer(
            message.buffer[: RobotObsShape.HEAD_IMAGE_SIZE],
            dtype=np.uint8,
        ).reshape(RobotObsShape.HEAD_IMAGE_SHAPE)
        # Wrist camera image (uint8).
        wrist_image = np.frombuffer(
            message.buffer[RobotObsShape.HEAD_IMAGE_SIZE : RobotObsShape.HEAD_IMAGE_SIZE+ RobotObsShape.WRIST_IMAGE_SIZE],
            dtype=np.uint8,
        ).reshape(RobotObsShape.WRIST_IMAGE_SHAPE)
        # Proprioception (float32), taken from the tail of the frame.
        proprioception = np.frombuffer(
            message.buffer[-RobotObsShape.STATE_SIZE :],
            dtype=np.float32,
        ).reshape(RobotObsShape.STATE_SHAPE)
        logger.info("Received head_image, wrist_image, and joint_angle")
        return {
            "head_image": head_image,
            "wrist_image": wrist_image,
            "proprioception": proprioception,
        }
    def listening_mannual(self) -> None:
        """Blocking stdin loop that drives manual state transitions.

        q / f -> FINISHED, i -> INITALIZING, r -> RESETTING; anything else
        (including a bare <Enter>, despite the prompt) logs "Invalid input".
        """
        logger.info("Robot is listening...")
        while True:
            user_input = input("Press <Enter> or <q> to quit: ")
            if user_input == "q":
                self.state = RobotState.FINISHED  # quit
            elif user_input == "i":
                self.state = RobotState.INITALIZING  # re-initialize
            elif user_input == "r":
                self.state = RobotState.RESETTING  # reset pose
            elif user_input == "f":
                self.state = RobotState.FINISHED  # finish task
            else:
                logger.info("Invalid input. Please press <Enter> or <q>.")
    def _initialize(self) -> None:
        """Move the robot to its grasp-ready pose and switch to ACTING."""
        assert self.state == RobotState.INITALIZING
        logger.info("Initializing robot...")
        # A real deployment would command the arm's initialization motion here.
        self.state = RobotState.ACTING
        logger.info("Robot initialized")
    def _reset_socket(self) -> None:
        """Tear down and recreate the REP socket.

        Called after a bad frame: a REP socket that received but never
        replied cannot recv() again, so the socket (and context) is rebuilt.
        """
        logger.info("Resetting socket...")
        self.socket.close()
        self.context.term()
        # Fresh context + socket on the same port.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://*:{self.port}")
        logger.info("Socket reset")
    def _reset(self) -> None:
        """Return the robot to its start pose after a task, then resume ACTING."""
        assert self.state == RobotState.RESETTING
        logger.info("Resetting robot...")
        # A real deployment would command the arm's reset motion here.
        self.state = RobotState.ACTING
        logger.info("Robot reset")
    def _get_obs(self) -> Optional[dict]:
        """Receive one observation, preprocess it, and build the model input.

        Blocks on the socket. Returns None (after resetting the socket) when
        the incoming frame cannot be parsed.
        """
        logger.info("Waiting for obs...")
        message = self.socket.recv(copy=False)  # zero-copy zmq.Frame
        obs = self._parse_obs(message)
        if obs is None:
            self._reset_socket()  # REP socket is stuck until a reply; rebuild it
            return None
        # Cache the raw sensor data on the instance.
        self.head_image = obs["head_image"]
        self.wrist_image = obs["wrist_image"]
        self.proprioception = obs["proprioception"]
        # Resize/normalize each image and add batch + time dimensions.
        rgb_head = interpolate_image_batch(self.head_image[None, ...]).unsqueeze(0)
        rgb_wrist = interpolate_image_batch(self.wrist_image[None, ...]).unsqueeze(0)
        logger.info("Robot state updated")
        return {
            "rgb": rgb_head,  # (1,1,3,H,W)
            "right_cam_img": rgb_wrist,  # (1,1,3,H,W)
            "right_state": torch.from_numpy(self.proprioception)
            .unsqueeze(0)
            .unsqueeze(0),  # (1,1,state_dim) — state_dim set by RobotObsShape.STATE_SHAPE
        }
    def act(self, obs: dict) -> bool:
        """Run one controller inference and reply with the action chunk.

        Sends the selected steps as raw float32 bytes on the REP socket.
        Always returns True on completion.
        """
        # Move observation tensors to the controller's device.
        obs = dict_apply(obs, lambda x: x.to(self.controller.device))
        # Inference without gradients.
        with torch.no_grad():
            actions = self.controller.predict_action(obs_dict=obs)  # (B,64,action_dim)
        # Skip the first n_latency_steps to compensate for inference latency,
        # then keep executions_per_action_chunk steps of the chunk.
        n_latency_steps = 3
        actions = (
            actions.detach()
            .cpu()
            .numpy()[
                0, n_latency_steps : self.executions_per_action_chunk + n_latency_steps
            ]  # (executions_per_action_chunk, action_dim)
        )
        # Reply on the socket with the raw float32 buffer.
        logger.info(f"Sent action {actions}")
        self.socket.send(actions.tobytes())
        return True
    def step(self) -> bool:
        """One control cycle: receive an observation, then act on it."""
        logger.info("Waiting for obs...")
        obs = self._get_obs()
        if obs is None:
            logger.error("Broken obs")
            return False
        logger.info("Robot state updated, acting...")
        if not self.act(obs):
            logger.error("Failed to send action")
            return False
        logger.info("Action sent, waiting for next obs...")
        return True
    def run(self) -> None:
        """Main state-machine loop; entered once from main().

        Dispatches on self.state, which the stdin listener thread may change
        at any time. Only an unknown state exits the loop.
        """
        logger.info("Robot loop starting...")
        assert self.state == RobotState.IDLE
        self.state = RobotState.INITALIZING
        # State-machine dispatch loop.
        while True:
            logger.info(f"run loop with robot state: {self.state}")
            if self.state == RobotState.INITALIZING:
                self._initialize()
            elif self.state == RobotState.RESETTING:
                self._reset()
            elif self.state == RobotState.ACTING:
                self.step()  # one obs -> action exchange
            elif self.state == RobotState.FINISHED:
                logger.info("Robot loop finished, waiting for next command")
                # NOTE(review): this branch busy-spins until the listener
                # thread changes the state; waiting logic could be added here.
            else:
                logger.error("Robot loop in unknown state.")
                break
# OmegaConf resolver registration
def now_resolver(pattern: str) -> str:
    """Resolve ``${now:<pattern>}`` to the current local time via strftime.

    BUG FIX: this module does ``import datetime`` (the module), so the class
    is ``datetime.datetime``; the original bare ``datetime.now()`` raised
    AttributeError on first use.
    """
    return datetime.datetime.now().strftime(pattern)
# Register custom resolvers (replace=True tolerates re-registration on reload).
OmegaConf.register_new_resolver("now", now_resolver, replace=True)
# SECURITY NOTE(review): the "eval" resolver executes arbitrary Python
# expressions found in config files — only load trusted configs.
OmegaConf.register_new_resolver("eval", eval, replace=True)
@hydra.main(version_base=None,config_path="config", config_name=pathlib.Path(__file__).stem)
def main(cfg):
    """Entry point: build a Robot from the hydra config and run its loop."""
    robot = Robot(cfg)
    robot.run()
if __name__ == "__main__":
    main()
# NOTE(review): a receiver-side parse_actions() snippet plus a question were
# pasted after main(), which made this file a syntax error. The receiver
# returned None for every action because it compared len(action_bytes)
# against RobotObsShape.ACTIONS_SHAPE (a shape, not a byte count), so the
# length check always failed. The receiver must compare against the byte
# size of the float32 chunk sent by Robot.act(), e.g.:
#     expected_bytes = int(np.prod(RobotObsShape.ACTIONS_SHAPE)) * 4  # float32
#     if len(action_bytes) != expected_bytes:
#         return None
#     return np.frombuffer(action_bytes, dtype=np.float32).reshape(
#         RobotObsShape.ACTIONS_SHAPE)
最新发布