"""trainer的父类 (base class for model trainers)."""
from typing import Mapping, Dict, Optional
import torch
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan, EarlyStopping
from ignite.metrics import Loss, RunningAverage
from src.exception import ModelNotFoundException
from src.experiment import Number
class Trainer(object):
def __init__(self, *, model: Optional[torch.nn.Module] = None, file: Optional[str] = None,
             save: Optional[str] = None, device: Optional[str] = None):
    """Set up the trainer's model, storage path, device and default metrics.

    Args:
        model: Model instance to train. Takes precedence over ``file``.
        file: Path of a model file loadable with ``torch.load``; used only
            when ``model`` is None.
        save: Path where the model is to be stored. Required.
        device: Device string such as ``"cpu"`` or ``"cuda"``. When None,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Raises:
        ModelNotFoundException: if neither ``model`` nor ``file`` is given.
        ValueError: if ``save`` is None.
    """
    # Validate all arguments before any (potentially expensive) model loading,
    # keeping the "no model" error as the first one reported.
    if model is None and file is None:
        raise ModelNotFoundException("模型未定义,请传入 torch.nn.Module 对象或可加载的模型的文件路径.")
    if save is None:
        raise ValueError("模型存储路径未定义!")
    self.save: str = save

    # Resolve the device BEFORE loading, so a model read from disk is mapped
    # onto the device this trainer will actually use. Passing device=None to
    # map_location would restore tensors onto the device they were saved
    # from, which fails for CUDA checkpoints on CPU-only machines.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device

    if model is not None:
        self.model = model
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # model files from trusted sources.
        self.model = torch.load(file, map_location=self.device)

    # Metrics keyed by display name; "MSE" wraps torch's MSELoss as an ignite Loss metric.
    self.metrics: Dict = {
        "MSE": Loss(torch.nn.MSELoss())}
    # Training / evaluation engines; left unset here, presumably built by later setup methods.
    self.trainer: Optional[Engine] = None
    self.evaluator: Optional[Engine] = None
def set_dataset(self