import os
import torch
import imageio
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from natsort import natsorted
Vector = [torch.Tensor, torch.Tensor]
def load_diabetes_data(csv_file_path: str, delim: str, data_type=np.float32) -> Vector:
if not os.path.exists(csv_file_path):
print('csv file not exists!')
x_y_data = np.loadtxt(csv_file_path, dtype=data_type, delimiter=delim)
x_data = torch.from_numpy(x_y_data[:, : -1])
y_data = torch.from_numpy(x_y_data[:, [-1]])
return [x_data, y_data]
class Model(torch.nn.Module):
def __init__(self)
loss曲线本地动态显示并保存成gif
最新推荐文章于 2023-11-01 20:56:47 发布
本文介绍了使用PyTorch库对糖尿病数据进行预处理、模型构建(包含线性层和sigmoid激活),以及通过训练过程动态显示损失曲线并将其转换为GIF动画。重点展示了如何读取CSV数据、定义模型结构、优化算法和训练过程的可视化。

最低0.47元/天 解锁文章
605

被折叠的 条评论
为什么被折叠?



