问题描述
在使用GRU训练序列数据时报错

报错部位代码为:
self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers,bidirectional=self.n_directions)
解决方案:
解决方案来源:https://github.com/Sundrops/video-caption.pytorch/issues/4
在给bidirectional传参时,将整型转为bool型。
self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers,bidirectional=bool(self.n_directions))
运行后成功解决
原因分析:
事后再来看这个报错,确实是需要传入一个bool值的,造成这样的错误可能是接口不知道什么时候又更新了?

在使用GRU进行序列数据训练时遇到错误,原因是torch.nn.GRU模块的bidirectional参数需要传入bool值而非整型。通过将bidirectional参数转换为bool类型(例如bool(self.n_directions)),成功解决了报错,表明可能是库的接口要求更新导致的错误。
1309

被折叠的 条评论
为什么被折叠?



