在最新版本1.0.3,上 遇到d2l.torch库里面缺失train_ch3函数,下面是个人写的替代补充函数可以完全平替。
所有函数都放在util.py文件中
import torch.nn
from d2l import torch as d2l
from IPython import display
class Accumulator:
"""
在n个变量上累加
"""
def __init__(self, n):
self.data = [0.0] * n # 创建一个长度为 n 的列表,初始化所有元素为0.0。
def add(self, *args): # 累加
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self): # 重置累加器的状态,将所有元素重置为0.0
self.data = [0.0] * len(self.data)
def __getitem__(self, idx): # 获取所有数据
return self.data[idx]
def accuracy(y_hat, y):
"""
计算正确的数量
:param y_hat:
:param y:
:return:
"""
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1) # 在每行中找到最大值的索引,以确定每个样本的预测类别
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
def evaluate_accuracy(net, data_iter):
"""
计算指定数据集的精度
:param net:
:param data_iter:
:return:
"""
if isinstance(net, torch

文章介绍了如何在d2l.torch库的1.0.3版本缺失train_ch3函数时,自定义一个替代函数,包含了Accumulator类用于计算精度,以及train_epoch_ch3和evaluate_accuracy等关键训练和评估函数。此外,还提到了Fashion-MNIST数据集的应用。
最低0.47元/天 解锁文章

1863





