from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
warnings.filterwarnings('ignore')
#长期预测类
class Exp_Long_Term_Forecast(Exp_Basic):
#构造函数
def __init__(self, args):
super(Exp_Long_Term_Forecast, self).__init__(args)
#创建模型
def _build_model(self):
model = self.model_dict[self.args.model].Model(self.args).float()
#多gpu且gpu可用
if self.args.use_multi_gpu and self.args.use_gpu:
model = nn.DataParallel(model, device_ids=self.args.device_ids)
return model
#从data_provider函数获取数据集合和数据加载器,并提供标志(train,val,test)
def _get_data(self, flag):
data_set, data_loader = data_provider(self.args, flag)
return data_set, data_loader
#选择优化器,该函数使用adam优化器,从传入的参数self 添加self.args.learning_rate学习率
def _select_optimizer(self):
model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
return model_optim
#选择损失函数,MSELoss(均方误差损失)
def _select_criterion(self):
criterion = nn.MSELoss()
return criterion
#验证方法,通过计算模型验证的误差来评估模型性能,即向前传播时不根据学习率计算梯度
def vali(self, vali_data, vali_loader, criterion):
total_loss = []
#设置评估模式
self.model.eval()
with torch.no_grad():
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
#将转化为浮点型的数据加载到cpu或gpu
batch_x = batch_x.float().to(self.device)
batch_y = batch_y.float()
#将转化为浮点型的数据加载到cpu或gpu
batch_x_mark = batch_x_mark.float().to(self.device)
batch_y_mark = batch_y_mark.float().to(self.device)
# decoder input
#输出一个形状与输入一致的全零张量,并转化为浮点型格式
dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
#在给定维度对输入的张量序列进行连续操作,并加载到cpu或者gpu
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
# encoder - decoder
if self.args.use_amp:
with torch.cuda.amp.autocast():
if self.args.output_attention:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
else:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
else:
if self.args.output_attention:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
else: