from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, cal_accuracy
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
import pdb
warnings.filterwarnings('ignore')
class Exp_Classification(Exp_Basic):
def __init__(self, args):
    """Initialize the classification experiment with configuration *args*."""
    super().__init__(args)
# Build the classification model; input dimensions are inferred from the data.
def _build_model(self):
    """Instantiate the configured model, sizing it from the datasets.

    Reads the TRAIN and TEST splits to determine the maximum sequence
    length, the number of input channels, and the number of target
    classes, writes those into ``self.args``, then constructs the model
    selected by ``self.args.model``. Wraps it in ``nn.DataParallel``
    when multi-GPU execution is enabled.
    """
    train_set, _ = self._get_data(flag='TRAIN')
    test_set, _ = self._get_data(flag='TEST')

    args = self.args  # alias: mutations below are visible on self.args
    # Classification has no forecasting horizon; the encoder must cover
    # the longest sequence seen in either split.
    args.seq_len = max(train_set.max_seq_len, test_set.max_seq_len)
    args.pred_len = 0
    args.enc_in = train_set.feature_df.shape[1]
    args.num_class = len(train_set.class_names)

    net = self.model_dict[args.model].Model(args).float()
    if args.use_multi_gpu and args.use_gpu:
        net = nn.DataParallel(net, device_ids=args.device_ids)
    return net
# Fetch a dataset/dataloader pair for the requested split.
def _get_data(self, flag):
    """Return the ``(dataset, dataloader)`` pair for split *flag*."""
    return data_provider(self.args, flag)
# Build the optimizer.
def _select_optimizer(self):
    """Create an Adam optimizer over all model parameters."""
    return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
# Build the loss function.
def _select_criterion(self):
    """Return the cross-entropy loss used for classification training."""
    return nn.CrossEntropyLoss()
# Validation: evaluate model performance by computing the loss on the validation set.
def vali(self, vali_data, vali_loader, criterion):
total_loss = []