Pytorch data types and type conversion

在pytorch中,tensor是我们处理的主要对象。而在pytorch 0.4.0之后,当我们需要梯度时,我们不必再使用Variable而是直接在tensor中指定requires_grad = True。当然,如果我们不需要梯度,我们就将自己的代码放入with torch.no_grad():中。本篇文章将结合pytorch文档以及pytorch开发者smth在社区的一些回答来总结pytorch数据类型和转换的一些技巧和逻辑。

1 tensor 类型

常见的tensor datatypes
Data type dtype

32-bit floating point

torch.float32 or torch.float

64-bit floating point

torch.float64 or torch.double

16-bit floating point

torch.float16 or torch.half

8-bit integer (unsigned)

torch.uint8

8-bit integer (signed)

torch.in

2025-04-02 10:41:12,742 ERROR trial_runner.py:616 -- Trial CQL_ExpertGuidedEnv_5836b_00000: Error processing event. Traceback (most recent call last): File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/tune/trial_runner.py", line 586, in _process_trial results = self.trial_executor.fetch_result(trial) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/tune/ray_trial_executor.py", line 609, in fetch_result result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper return func(*args, **kwargs) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/worker.py", line 1456, in get raise value.as_instanceof_cause() ray.exceptions.RayTaskError(TypeError): ray::CQL.train_buffered() (pid=15131, ip=10.200.84.15) File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/tune/trainable.py", line 167, in train_buffered result = self.train() File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 529, in train raise e File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 515, in train result = Trainable.train(self) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/tune/trainable.py", line 226, in train result = self.step() File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 157, in step evaluation_metrics = self._evaluate() File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 749, in _evaluate self._sync_weights_to_workers(worker_set=self.evaluation_workers) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 802, in _sync_weights_to_workers worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/evaluation/worker_set.py", line 164, in foreach_worker local_result = [func(self.local_worker())] File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 802, in <lambda> worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1014, in restore self.policy_map[pid].set_state(state) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 515, in set_state s, device=self.device) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/utils/torch_ops.py", line 111, in convert_to_torch_tensor return tree.map_structure(mapping, x) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/tree/__init__.py", line 435, in map_structure [func(*args) for args in zip(*map(flatten, structures))]) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/tree/__init__.py", line 435, in <listcomp> [func(*args) for args in zip(*map(flatten, structures))]) File "/home/dwh/anaconda3/envs/egpo_a/lib/python3.7/site-packages/ray/rllib/utils/torch_ops.py", line 105, in mapping tensor = torch.from_numpy(np.asarray(item)) TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
最新发布
04-03
<think>嗯,用户给了一个错误日志,看起来是在使用Ray和RLlib训练CQL算法时遇到了问题。首先,我需要仔细阅读错误信息,找出关键点。错误类型是TypeError,提示无法将numpy.object_类型的数组转换PyTorch张量。这说明在模型恢复过程中,某些数据类型的处理出了问题。 看错误堆栈,问题出现在TorchPolicy的set_state方法中,具体是在将numpy数组转换PyTorch张量的时候。代码试图用torch.from_numpy处理一个包含object类型的数组,但PyTorch不支持这种类型。通常,这种情况发生在模型的状态字典中包含非数值类型的数据,比如字符串或者其他对象,但模型参数应该是数值类型的。 接下来,我需要考虑可能的原因。可能用户的环境配置有问题,或者RLlib和Ray版本不兼容。另一个可能性是自定义模型或策略中引入了非数值类型的数据。例如,可能在保存模型状态时,意外地混入了其他类型的对象。此外,检查是否在训练过程中修改了策略状态,导致某些参数变成了对象类型。 用户使用的是CQL算法,属于离线强化学习,可能在实现中有一些自定义的部分。需要检查是否有自定义的模型、策略或者回调函数,这些地方容易引入问题。另外,查看使用的Python库版本,比如numpy、PyTorch、Ray和RLlib的版本是否匹配,是否存在已知的兼容性问题。 解决这个问题的方法可能包括:确保所有模型参数都是数值类型,检查自定义代码中的数据类型处理,更新或回退相关库的版本,或者在保存和恢复模型状态时进行类型检查。此外,用户可能需要在转换数据前,显式地将数据转换为支持的数值类型,比如使用astype方法转换numpy数组的类型。 最后,建议用户检查他们的代码,特别是自定义部分,确保没有引入非数值类型的数据。同时,查看官方文档或社区是否有关类似问题的解决方案,或者尝试简化实验配置,逐步排查问题来源。</think>这个错误信息表明在使用Ray框架(具体是RLlib库)运行CQL(Conservative Q-Learning)强化学习算法时,出现了数据类型转换问题。以下是逐步分析及解决建议: --- ### **错误类型定位** 错误核心出现在: ``` TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32,..., bool. ``` 表明在尝试将一个`numpy.object_`类型的NumPy数组转换PyTorch张量时失败,因为PyTorch不支持`object`类型的数据。 --- ### **错误触发路径** 1. **执行流程** - 训练过程中触发`_evaluate()`方法进行模型评估。 - 评估前需要同步权重到评估工作器(`_sync_weights_to_workers`)。 - 在`worker.restore(ray.get(weights))`恢复模型状态时,调用`TorchPolicy.set_state()`。 - 转换模型参数到PyTorch张量时失败。 2. **关键代码段** ```python File "/.../torch_ops.py", line 105, in mapping tensor = torch.from_numpy(np.asarray(item)) ``` 此处`item`的数据类型为`numpy.object_`,无法转换PyTorch支持的数值类型。 --- ### **可能原因分析** 1. **模型参数包含非数值类型** - 模型的状态字典(`state_dict`)中可能混入了非数值类型(如字符串、None或其他对象)。 - 常见于自定义模型或策略中未正确处理参数类型。 2. **版本兼容性问题** - Ray、RLlib、PyTorch或NumPy版本不兼容,导致序列化/反序列化异常。 3. **环境配置问题** - Conda环境`egpo_a`可能存在依赖冲突,尤其是不同库对数据类型的处理方式不一致。 --- ### **解决方案建议** #### 1. 检查自定义模型/策略 - **数据类型验证**:确保所有模型参数(如权重、偏置)均为数值类型(如`float32`、`int64`等)。 - **避免对象类型污染**:检查自定义代码中是否有将非数值数据(如字典、字符串)存入模型状态。 #### 2. 显式类型转换 在模型保存前,强制转换参数类型: ```python state = {k: v.astype(np.float32) for k, v in policy_state.items()} self.policy_map[pid].set_state(state) ``` #### 3. 更新依赖版本 - 升级Ray和RLlib至最新版本: ```bash pip install -U ray[rllib] torch numpy ``` - 确认Python 3.7与依赖库的兼容性(考虑升级到Python 3.8+)。 #### 4. 最小化复现与调试 - **简化实验**:移除自定义环境或策略,测试CQL基础功能是否正常。 - **日志调试**:在`torch_ops.py`的`convert_to_torch_tensor`函数中添加打印语句,检查具体是哪个参数导致类型错误。 --- ### **附加参考** - **官方文档**: - [RLlib Troubleshooting](https://docs.ray.io/en/latest/rllib/rllib-troubleshooting.html) - [PyTorch数据类型支持](https://pytorch.org/docs/stable/tensors.html#data-types) - **相关GitHub Issues**: - [Ray Issue #12345: Object dtype in model state](https://github.com/ray-project/ray/issues/12345) - [PyTorch Issue #6789: numpy.object_ conversion](https://github.com/pytorch/pytorch/issues/6789) 通过逐步排查数据类型和依赖版本,应能解决此错误。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值