def accuracy(y_hat, y): #@save
"""计算预测正确的数量"""
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
# y_hat.shape返回y_hat的形状,len()返回第一维度的长度
#如果y_hat不是标量且y_hat的列>1
y_hat = y_hat.argmax(axis=1)
#按列降维(列相加)
cmp = y_hat.type(y.dtype) == y
#
return float(cmp.type(y.dtype).sum())
1.size ():返回数组中元素总个数;
2.shape():返回数组各个维度对应长度;
3.len ():返回数组第一维度的长度。