Today I was helping a friend debug a TensorFlow program and ran into a NaN problem. It took a long time to track down, but it was finally solved.
The culprit turned out to be tf.sqrt. See this Stack Overflow question: Why is my loss function returning nan?
The explanation given there:
It was coming from the fact that x was approaching a tensor with all zeros for entries. This was making the derivative of sigma wrt x a NaN. The solution was to add a small value to that quantity.
In other words, the NaN comes from tf.sqrt(x) when x is 0: the derivative of sqrt(x), 1/(2·sqrt(x)), blows up at x = 0, so backpropagation through it produces NaN.
The fix: add a tiny epsilon so the argument is never exactly 0, i.e.:
tf.sqrt(x+1e-8)
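To see why the epsilon helps without needing TensorFlow installed, here is a minimal pure-Python sketch of the math. The helper sqrt_grad is hypothetical (not a TensorFlow API); it just computes the analytic derivative 1/(2·sqrt(x + eps)), which is what autodiff effectively evaluates for the sqrt op:

```python
import math

def sqrt_grad(x, eps=0.0):
    # Analytic derivative of sqrt(x + eps) with respect to x:
    # d/dx sqrt(x + eps) = 1 / (2 * sqrt(x + eps))
    s = math.sqrt(x + eps)
    # At x = 0 with no epsilon, we divide by zero -> the gradient is undefined,
    # which is exactly where autodiff produces NaN/Inf.
    return float('nan') if s == 0.0 else 1.0 / (2.0 * s)

print(sqrt_grad(0.0))        # nan: gradient undefined at x = 0
print(sqrt_grad(0.0, 1e-8))  # 5000.0: large but finite, thanks to the epsilon
```

The epsilon trades a tiny bias in the forward value for a bounded gradient, which is why tf.sqrt(x + 1e-8) keeps training stable.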
Finally, a similar Zhihu question worth reading: "Why does training a network with TensorFlow give loss=nan while accuracy stays stuck at a fixed value?"