文章目录
passion!!!
要解决的问题
故事是这样的,我在计算一个张量的均值mean().item()
的时候,发现计算得到结果在样本数量变多的时候总是得到NaN
的结果,我真的很纳闷。
出现问题的原因
很有可能是因为样本的数值太接近0了,导致计算均值时出现 NaN 通常意味着张量中包含了无效值,比如 NaN 或 inf(无穷大)。
解决方法:
1. 检查数据中是否有 NaN 或 inf 值
在计算均值之前检查数据中是否存在 NaN 或 inf 值,并做相应处理:
import torch
# 检查是否存在 NaN 或 inf
if torch.isnan(flattened_data1).any():
print("Data1 contains NaN values.")
if torch.isinf(flattened_data1).any():
print("Data1 contains inf values.")
2. 过滤 NaN 和 inf 值
如果确认张量中存在无效值,可以用以下方法过滤掉这些值:
# 过滤 NaN 和 inf 值
flattened_data1 = flattened_data1[~torch.isnan(flattened_data1)]
flattened_data1 = flattened_data1[~torch.isinf(flattened_data1)]
3. 用有效的默认值填充无效值
如果希望保留数组的形状,可以用默认值替换 NaN 或 inf:
# 用 0 替换 NaN 和 inf
flattened_data1 = torch.nan_to_num(flattened_data1, nan=0.0, posinf=0.0, neginf=0.0)
我采取的方案:找到然后替换为0
# 检查并处理 flattened_data1 中的 NaN 和 inf
if torch.isnan(flattened_data1).any():
print("Warning: flattened_data1 contains NaN values.")
flattened_data1 = torch.nan_to_num(flattened_data1, nan=0.0) # 将 NaN 替换为 0
if torch.isinf(flattened_data1).any():
print("Warning: flattened_data1 contains Inf values.")
flattened_data1 = torch.nan_to_num(flattened_data1, posinf=0.0, neginf=0.0) # 将 inf 替换为 0
检查张量的数据类型
要查看一个张量的数据类型,可以使用属性dtype来查看,不要用type。也就是说,使用tensor_a.dtype
, 不要用tensor_a.type
。