keras 时间序列数据预测与结果分析

本文介绍了一个用于预测股市的深度学习模型,包括CNN、GRU、ResNet和混合Attention的训练和评估过程。模型通过tushare接口获取历史数据,并利用Keras框架实现,详细展示了模型在不同股票和时间范围内的准确率评估方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数据来源

使用tushare的接口获取股票历史数据.

# 安装 tushare
pip3 install tushare -i https://pypi.tuna.tsinghua.edu.cn/simple 

需要在 tushare 官网注册并完善个人资料,然后获取 token 才能使用接口.

# 设置 token
>>> ts.set_token('*****************')

tushare返回的是panda.dataframe格式的数据.
使用时,先将数据下载到 "../data/" 文件夹下,避免每次使用都要调用网络接口.
要更新数据,将 "../data/" 文件夹清空即可.

文件结构

  • project
    • data # 存放数据
    • market # 代码和模型
      • model # 存放模型
      • load_tools.py # 加载模型,数据
      • get_tools.py # 提供数据
      • get_samples.py # 构建模型需要的数据样本
      • new_generator.py # 数据样本生成器
      • dotrain.py # 训练模型
      • evaluate_model.py # 衡量模型各种性能
      • history_predict.py # 显示模型的历史预测曲线
      • serch_predict.py # 搜索目标时间预测的结果
      • run.py # 执行今日预测
      • run.bat # 执行今日预测

代码

load_tools.py

import os

import keras
import pandas as pd
import tushare as ts
from keras import backend as K


# Model quality metrics (Keras backend ops, usable as Keras compile metrics).
# Share of actual positives that the model flags as positive.
def recall(y_true, y_pred):
    """Recall at the default 0.5 rounding threshold."""
    hits = K.round(K.clip(y_pred * y_true, 0, 1))
    return K.sum(hits) / (K.sum(y_true) + K.epsilon())


def recall1(y_true, y_pred):
    """Recall with predictions shifted down by 0.2 (stricter positive cut)."""
    shifted = y_pred - 0.2
    hits = K.round(K.clip(y_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(y_true) + K.epsilon())


def recall2(y_true, y_pred):
    """Recall with predictions shifted down by 0.4 (strictest positive cut)."""
    shifted = y_pred - 0.4
    hits = K.round(K.clip(y_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(y_true) + K.epsilon())


# Share of actual negatives that the model flags as negative.
def n_recall(y_true, y_pred):
    """Recall on the negative class at the default threshold."""
    neg_true = 1 - y_true
    hits = K.round(K.clip(neg_true * (1 - y_pred), 0, 1))
    return K.sum(hits) / (K.sum(neg_true) + K.epsilon())


def n_recall1(y_true, y_pred):
    """Negative-class recall with the negative score shifted down by 0.2."""
    neg_true = 1 - y_true
    shifted = (1 - y_pred) - 0.2
    hits = K.round(K.clip(neg_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(neg_true) + K.epsilon())


def n_recall2(y_true, y_pred):
    """Negative-class recall with the negative score shifted down by 0.4."""
    neg_true = 1 - y_true
    shifted = (1 - y_pred) - 0.4
    hits = K.round(K.clip(neg_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(neg_true) + K.epsilon())


# Share of predicted positives that are actually positive.
def precision(y_true, y_pred):
    """Precision at the default 0.5 rounding threshold."""
    flagged = K.round(K.clip(y_pred, 0, 1))
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


def precision1(y_true, y_pred):
    """Precision with predictions shifted down by 0.2 (stricter positive cut)."""
    shifted = y_pred - 0.2
    flagged = K.round(K.clip(shifted, 0, 1))
    hits = K.round(K.clip(y_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


def precision2(y_true, y_pred):
    """Precision with predictions shifted down by 0.4 (strictest positive cut)."""
    shifted = y_pred - 0.4
    flagged = K.round(K.clip(shifted, 0, 1))
    hits = K.round(K.clip(y_true * shifted, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


# Share of predicted negatives that are actually negative.
def n_precision(y_true, y_pred):
    """Precision on the negative class at the default threshold."""
    neg_pred = 1 - y_pred
    flagged = K.round(K.clip(neg_pred, 0, 1))
    hits = K.round(K.clip((1 - y_true) * neg_pred, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


def n_precision1(y_true, y_pred):
    """Negative-class precision with the negative score shifted down by 0.2."""
    shifted = (1 - y_pred) - 0.2
    flagged = K.round(K.clip(shifted, 0, 1))
    hits = K.round(K.clip((1 - y_true) * shifted, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


def n_precision2(y_true, y_pred):
    """Negative-class precision with the negative score shifted down by 0.4."""
    shifted = (1 - y_pred) - 0.4
    flagged = K.round(K.clip(shifted, 0, 1))
    hits = K.round(K.clip((1 - y_true) * shifted, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


# Fraction of samples the model predicts as positive.
def prate(y_true, y_pred):
    """Predicted-positive rate: mean of rounded predictions."""
    rounded = K.round(K.clip(y_pred, 0, 1))
    return K.mean(rounded)


# Fraction of samples that are actually positive.
def trate(y_true, y_pred):
    """Actual-positive rate: mean of rounded ground-truth labels."""
    rounded = K.round(K.clip(y_true, 0, 1))
    return K.mean(rounded)


# Load a saved model from disk and recompile it with the custom metrics.
def load_model(model_name='./model/cnn960to1080b.model', lookback=61, shape=5):
    """Load a Keras model, registering the custom metric functions.

    keras.models.load_model requires every custom object used at save time
    to be passed via custom_objects, otherwise deserialization fails.
    """
    custom = dict(
        recall=recall, recall1=recall1, recall2=recall2,
        precision=precision, precision1=precision1, precision2=precision2,
        prate=prate, trate=trate,
        lookback=lookback, shape=shape,
    )
    model = keras.models.load_model(model_name, custom_objects=custom)
    # Recompile so evaluate() reports the full metric set in a fixed order.
    metric_fns = [recall, precision, recall1, precision1, recall2, precision2,
                  n_recall, n_precision, n_recall1, n_precision1, n_recall2,
                  n_precision2, trate, prate]
    model.compile(optimizer=keras.optimizers.RMSprop(),
                  loss=keras.losses.binary_crossentropy,
                  metrics=metric_fns)
    return model


# Fetch one stock's full daily history, cached as a CSV under ../data/.
def load_data(ts_code):
    """Return a DataFrame of daily bars for *ts_code*, oldest row first.

    Downloads via tushare on first use and caches to ../data/<code>.csv;
    delete the cached file to force a refresh. Returns None when tushare
    has no data for the code.
    """
    data_dir = '../data/'
    csv_path = data_dir + ts_code + '.csv'
    if not os.path.exists(csv_path):
        # Forward-adjusted (前复权) prices.
        df = ts.pro_bar(ts_code=ts_code, adj='qfq')
        if df is None:
            print('can not get data')
            return
        df.to_csv(csv_path, index=False)
    df = pd.read_csv(csv_path)
    # Columns: ts_code, trade_date, open, high, low, close, pre_close,
    # change, pct_chg, vol, amount
    df.dropna(inplace=True)
    # Sort explicitly by trade_date (oldest first). The original code used
    # sort_index(ascending=False), which only works if the CSV happens to
    # arrive newest-first; sorting by the date column is order-independent.
    df = df.sort_values('trade_date', ascending=True)
    df.reset_index(drop=True, inplace=True)
    return df


# Load the stock-code list for one exchange, cached under ../data/.
def load_code_list(market='SSE'):
    """Return a flat array of ts_code strings for *market* (e.g. 'SSE'/'SZSE')."""
    file_dir = '../data/' + 'code_list_' + market + '.csv'
    if not os.path.exists(file_dir):
        # First run: query all listed companies on the exchange via tushare.
        pro = ts.pro_api('*****************************')
        code_list = pro.stock_basic(exchange=market, list_status='L', fields='ts_code')  # ,symbol,name,market,list_date
        # Cache for subsequent runs.
        code_list.to_csv(file_dir, index=False)
    else:
        code_list = pd.read_csv(file_dir)
    return code_list[['ts_code']].values.flatten()


# Print only when verbose mode is enabled.
def print_verbose(verbose, text):
    """Emit *text* via print() when *verbose* is truthy; otherwise stay silent."""
    if not verbose:
        return
    print(text)

get_tools.py

import numpy as np
from load_tools import *
from matplotlib import pyplot as plt


# Load and pre-check all market data into parallel in-memory lists.
def init(market_list, normalize=False):
    """Load every stock of every market in *market_list* into memory.

    Returns four parallel lists-of-lists indexed as [market][stock]:
    names, raw DataFrames, [normalized array, mean, std] triples, and
    trade-date lists. A stock that fails any check keeps a None
    placeholder so indices stay aligned across the four structures.
    """
    market_names = []
    market_datas = []
    market_datas_normal = []
    market_datas_date = []
    for i in range(len(market_list)):
        print('Load ', market_list[i])
        market_names.append([])
        market_datas.append([])
        market_datas_normal.append([])
        market_datas_date.append([])
        print('正在加载数据进入内存')
        for code_name in load_code_list(market=market_list[i]):
            print(code_name)
            market_names[i].append(code_name)
            market_datas[i].append(load_data(code_name))
        print('数据加载完毕')
        print('正在检查数据')
        for j in range(len(market_names[i])):
            # Missing or empty history: keep None placeholders and move on.
            if market_datas[i][j] is None:
                market_datas_normal[i].append(None)
                market_datas_date[i].append(None)
                continue
            if market_datas[i][j].empty:
                market_datas_normal[i].append(None)
                market_datas_date[i].append(None)
                continue
            # data = market_datas[i][j][['close', 'high', 'low', 'amount']].values
            # Feature columns. 'close' must stay first: downstream target
            # construction (get_samples_targets) reads column 0 as the close.
            data = market_datas[i][j][['close', 'open', 'high', 'low', 'amount']].values
            # Reject series containing NaN values.
            if np.isnan(data).any():
                print('nan in %s' % market_names[i][j])
                market_datas_normal[i].append(None)
                market_datas_date[i].append(None)
                continue
            # Per-stock standardization over the stock's own history.
            if normalize:
                mean = data.mean(axis=0)  # [6.98017146e+00, 7.12046020e+00, 6.83100609e+00, 1.65669341e+05]
                std = data.std(axis=0)  # [6.36818017e+00, 6.50689074e+00, 6.22204203e+00, 4.74562019e+05]
            else:
                # Identity transform: broadcasting [0] / [1] leaves data unchanged.
                mean = [0]
                std = [1]
            data -= mean
            # Zero std on the close column means a flat, unusable series.
            if data.std(axis=0)[0] == 0:
                print('std is 0 in %s' % market_names[i][j])
                market_datas_normal[i].append(None)
                market_datas_date[i].append(None)
                continue
            data /= std
            market_datas_normal[i].append([data, mean, std])
            market_datas_date[i].append(market_datas[i][j]['trade_date'].tolist())
        print('数据检查完成')
    return market_names, market_datas, market_datas_normal, market_datas_date


# Module-level bootstrap: loading this module triggers a full (normalized)
# load of SSE + SZSE data — the first import may download for a long time.
market_list = ['SSE', 'SZSE']
market_names, market_datas, market_datas_normal, market_datas_date = init(market_list, True)
# Trading calendar taken from the first stock of the first market.
date_list = market_datas[0][0]['trade_date'].values.tolist()


def get_data(ts_code):
    """Return the raw DataFrame for *ts_code*, or None if the code is unknown."""
    for idx in range(len(market_list)):
        names = market_names[idx]
        if ts_code in names:
            return market_datas[idx][names.index(ts_code)]


def get_data_normal(ts_code):
    """Return [normalized array, mean, std] for *ts_code*, or None if unknown."""
    for idx in range(len(market_list)):
        names = market_names[idx]
        if ts_code in names:
            return market_datas_normal[idx][names.index(ts_code)]


def get_data_date(ts_code):
    """Return the trade-date list for *ts_code*, or None if unknown."""
    for idx in range(len(market_list)):
        names = market_names[idx]
        if ts_code in names:
            return market_datas_date[idx][names.index(ts_code)]


def get_code_list(market='SSE'):
    """Return code names for one market, or all markets combined when 'ALL'."""
    if market != 'ALL':
        return market_names[market_list.index(market)]
    combined = []
    for names in market_names:
        combined += names
    return combined


# Plot one training metric against its validation counterpart.
def show_train_history(train_history, train_metrics, validation_metrics):
    """Draw train/validation curves for one metric into the current axes."""
    history = train_history.history
    plt.plot(history[train_metrics])
    plt.plot(history[validation_metrics])
    # plt.title('Train History')
    plt.ylabel(train_metrics)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='upper left')


# Render the learning curves recorded during training.
def plot_history(history):
    """Show loss/recall/precision curves in a 2x2 grid, saved to ./model/."""
    panels = [('loss', 'val_loss'),
              ('recall', 'val_recall'),
              ('precision', 'val_precision'),
              ('precision2', 'val_precision2')]
    plt.figure(figsize=(12, 8))
    for pos, (train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(2, 2, pos)
        show_train_history(history, train_key, val_key)
    plt.savefig('./model/auto_save.jpg')
    plt.show()


# Map a date range to (start, end) row indices for one stock.
def date2index(ts_code, start_date, end_date, lookback, delay, verbose=0):
    """Translate trade dates into index bounds for *ts_code*.

    An empty-string date means "as early/late as the data allows". Returns
    a (start, end) tuple, or None when a date is absent from the history
    or the resulting window is empty.
    """
    dl = get_data_date(ts_code)
    if not dl:
        return
    # Never start earlier than one full lookback window.
    start = lookback - 1
    if start_date != '':
        if start_date not in dl:
            print_verbose(verbose, 'can not find date')
            return
        start = max(start, dl.index(start_date))
    # Leave room at the end for the delay-day-ahead target.
    end = len(dl) - delay
    if end_date != '':
        if end_date not in dl:
            print_verbose(verbose, 'can not find date')
            return
        end = min(end, dl.index(end_date))
    if start >= end:
        print_verbose(verbose, 'data range too small, may be date too close to boundary.')
        return
    return start, end

L36:根据提供的数据筛选某些值进行训练,这里选了[‘close’, ‘open’, ‘high’, ‘low’, ‘amount’],分别是收盘价,开盘价,最高值,最低值,成交量.由于后面的预测都是根据每日的收盘价进行,所以收盘价必须在这个列表的第一位
L44:若normalize=True,这里对每支股票单独进行正规化,这样可能会导致模型发现该股票的最高值和最低值,但是后来发现好像影响不大.
L64:market_list 这里进行设置要使用的交易所的数据 ‘SSE’:上海交易所,‘SZSE’:深圳交易所
L65:开始加载数据.init的第二个参数选择是否进行正规化.第一次运行可能需要差不多一个小时来下载数据.
L108:plot_history()根据训练的类型来调整这个函数里各项的值以显示需要的学习曲线.

get_samples.py

import random

import numpy as np

from get_tools import *


# Build prediction samples for one stock ending at a given date.
def get_samples(ts_code='600004.SH', date=20191108, lookback=61, duiring=20, verbose=1, normalize=True):
    """Return a (duiring, lookback, features) array of normalized windows.

    The last window ends on *date*; earlier windows step back one trading
    day each. Returns None when the stock or its history is insufficient.
    """
    data_normal = get_data_normal(ts_code)
    if data_normal is None:
        print_verbose(verbose, 'can not find date normal')
        return
    data = data_normal[0]

    # Locate the target date in the index space.
    se_index = date2index(ts_code, date, '', lookback, 0, verbose)
    if not se_index:
        print_verbose(verbose, 'can not get date')
        return
    last = se_index[0] + 1
    window_ends = np.arange(last - duiring + 1, last + 1)
    samples = np.zeros((len(window_ends), lookback, data.shape[-1]))
    for j, row in enumerate(window_ends):
        if row - lookback < 0:
            # Not enough history before the earliest requested window.
            print_verbose(verbose, 'date range too small in %s' % ts_code)
            return
        samples[j] = data[range(row - lookback, row)]
    return samples


# Build (samples, targets) pairs for one stock over a date range.
def get_samples_targets(ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1, uprate=0.0, mod='', rand=False, verbose=1):
    """Return lookback-windows and per-window targets for *ts_code*.

    mod='delta' -> target is the absolute close change after *delay* days;
    mod='rate'  -> target is the relative change (ratio minus 1);
    mod=''      -> binary target: 1 if the close rises by more than *uprate*.
    rand=True picks a single random window from the range. Returns None
    when the stock or the date range is unavailable.
    """
    # Normalized data plus the mean/std needed to undo the scaling.
    data_normal = get_data_normal(ts_code)
    if data_normal is None:
        print_verbose(verbose, 'can not find date normal')
        return
    data = data_normal[0]
    mean = data_normal[1]
    std = data_normal[2]

    # Resolve the date range into index bounds.
    se_index = date2index(ts_code, start_date, end_date, lookback, delay, verbose)  # (original note: 0.08 — presumably seconds per call)
    if se_index is None:
        return
    start, end = se_index

    # Optionally sample one random window from the range.
    if rand:
        start = random.randint(start, end - 1)
        end = start + 1

    # Build the window tensor and per-window targets.
    rows = np.arange(start, end)
    samples = np.zeros((len(rows),
                        lookback,
                        data.shape[-1]))
    targets = np.zeros((len(rows),))
    for j, row in enumerate(rows):
        # Window covers [row - lookback + 1, row], inclusive of the current day.
        indices = range(rows[j] - (lookback - 1), rows[j] + 1)
        samples[j] = data[indices]
        # Column 0 is the close price; de-normalize before comparing values.
        # Absolute price change after *delay* days.
        if mod == 'delta':
            targets[j] = (data[row + delay][0] * std[0] + mean[0]) - (data[row][0] * std[0] + mean[0])
            continue
        # Relative price change after *delay* days.
        if mod == 'rate':
            targets[j] = (data[row + delay][0] * std[0] + mean[0]) / (data[row][0] * std[0] + mean[0]) - 1
            continue
        # Binary label: did the close rise by more than uprate?
        if data[row + delay][0] * std[0] + mean[0] > (data[row][0] * std[0] + mean[0]) * (1 + uprate):
            targets[j] = 1
        else:
            targets[j] = 0
    return samples, targets


# Count how many usable samples each stock contributes in a date range.
def count_samples_weight(market, start_date='', end_date='', lookback=61, delay=1, verbose=1):
    """Return (names, weights): stocks with data in range and their sample counts."""
    names = []
    weight = []
    for code_name in get_code_list(market=market):
        print_verbose(verbose, code_name)
        if get_data(code_name) is None:
            print_verbose(verbose, 'can not find data')
            continue
        # Resolve the date range; skip stocks without usable data in it.
        se_index = date2index(code_name, start_date, end_date, lookback, delay, verbose)
        if se_index is None:
            continue
        first, last = se_index
        names.append(code_name)
        weight.append(last - first)
    return names, weight

get_samples_targets()中mod参数控制产生target的方式,取决于是要预测涨跌值mod=‘delta’,涨跌幅mod=‘rate’,还是是否上涨mod=’’.

new_generator.py

import os
import random

import tushare as ts
import numpy as np
import pandas as pd

from get_tools import *
from get_samples import *


def new_generator(market='SSE', batch_size=1024, shape=4, start_date='', end_date='', lookback=61, delay=1,
                  uprate=0.0):
    """Infinite batch generator of (samples, targets) for Keras training.

    Stocks are drawn with probability proportional to how many samples
    they contribute in the date range; each draw takes one random window.
    NOTE(review): default shape=4, but the project loads five feature
    columns — confirm callers always pass a shape matching the data.
    """
    # Per-stock sample counts, used below as sampling weights.
    # print('init generator')
    data = count_samples_weight(market, start_date=start_date, end_date=end_date, lookback=lookback,
                                delay=delay, verbose=0)
    if not data[0]:
        # Bare `return` inside a generator ends iteration (StopIteration).
        print('can not get data, maybe date wrong')
        return
    samples = np.zeros((batch_size,
                        lookback,
                        shape))
    targets = np.zeros((batch_size,))
    while 1:
        for i in range(batch_size):
            # Weighted random stock, then one random window from it.
            name = random.choices(data[0], data[1])[0]
            sample, target = get_samples_targets(ts_code=name, start_date=start_date, end_date=end_date,
                                                 lookback=lookback, delay=delay, uprate=uprate, rand=True, mod='')
            samples[i] = sample[0]
            targets[i] = target[0]
        yield samples, targets

evaluate_model.py

import os
import random

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_samples import get_samples_targets
from get_tools import *
from make_generators import make_generators
from new_generator import new_generator


# Legacy single-stock accuracy check via the generator pipeline.
def evaluate_old(model, ts_code='600004.SH'):
    """Evaluate *model* on one stock using make_generators; prints the metrics."""
    print(ts_code)
    generator = make_generators(ts_code, shuffle=False, batch_size='auto')
    if generator is None:
        return
    print(model.evaluate_generator(generator[0], steps=1))


# Average realized price change per traded day when following the model's
# positive outputs (only meaningful for delay=1).
def evaluate_delta(model, ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1, base_line=0.5, verbose=1):
    """Mean price delta per traded day, trading where prediction > base_line."""
    data = get_samples_targets(ts_code=ts_code, start_date=start_date, end_date=end_date,
                               lookback=lookback, delay=delay, mod='delta', verbose=verbose)
    if data is None:
        return
    scores = model.predict(data[0]).T[0]
    decisions = np.round(scores - base_line + 0.5)
    if base_line < 0.5 and base_line != 0.0:
        # Low baselines invert the decision: trade on the complement instead.
        decisions = 1 - decisions
    return sum(decisions * data[1]) / sum(decisions)


# Metric evaluation for one stock over a date range.
def evaluate(model, ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1):
    """Return model.evaluate metrics for one stock, or None when data is missing."""
    print(ts_code)
    data = get_samples_targets(ts_code=ts_code, start_date=start_date, end_date=end_date,
                               lookback=lookback, delay=delay)
    if data is None:
        return
    samples, targets = data
    return model.evaluate(samples, targets, batch_size=9999, verbose=0)


# Evaluate metrics over random batches drawn from the whole market.
def evaluate_total(model, market='ALL', steps=10, shape=5, start_date='', end_date='', lookback=61, delay=1, uprate=0.0):
    """Average model metrics over *steps* market-wide batches; None if no data.

    The batch size equals the number of listed codes, so every step sees a
    market-sized random sample.
    """
    generator = new_generator(market=market,
                              shape=shape,
                              start_date=start_date,
                              end_date=end_date,
                              lookback=lookback,
                              delay=delay,
                              uprate=uprate,
                              batch_size=len(get_code_list(market)))
    # new_generator bails out with a bare `return` when the range has no
    # data, so next() raises StopIteration rather than yielding None. The
    # old `test is None` check was unreachable; guard the exception instead.
    try:
        next(generator)
    except StopIteration:
        return
    return model.evaluate_generator(generator, steps=steps)


# Print per-stock metric evaluations for a whole market.
def evaluate_all(model, market='SSE', start_date='', end_date='', lookback=61, delay=1):
    """Evaluate every stock in *market* and print each metrics list."""
    for code_name in get_code_list(market=market):
        metrics = evaluate(model=model, ts_code=code_name, start_date=start_date, end_date=end_date,
                           lookback=lookback, delay=delay)
        print(metrics)


# Average per-trading-day price delta across all stocks in a market.
def evaluate_all_delta(model, market='SSE', start_date='', end_date='', lookback=61, delay=1, base_line=0.5, verbose=0):
    """Collect evaluate_delta per stock; prints the mean, returns the list."""
    sum_list = []
    for code_name in get_code_list(market=market):
        print_verbose(verbose, code_name)
        result = evaluate_delta(model=model, ts_code=code_name, start_date=start_date, end_date=end_date,
                                lookback=lookback, delay=delay, base_line=base_line, verbose=verbose)
        print_verbose(verbose, result)
        # Skip missing data (None) and NaN deltas (no trades triggered).
        if result and not np.isnan(result):
            sum_list.append(result)
    print("平均:", np.average(sum_list))
    return sum_list


# Evaluate model accuracy over consecutive time slices and plot the trend.
def evaluate_total_time(model, date_step=61, steps=3, start_date='20170103', end_date='', lookback=61, delay=1, uprate=0.0):
    """Evaluate over successive *date_step*-day windows and plot each metric.

    Uses the module-level date_list (trade dates as ints). The plotted
    indices follow the metric order compiled in load_model:
    [loss, recall, precision, recall1, precision1, recall2, precision2,
     n_recall, n_precision, n_recall1, n_precision1, n_recall2,
     n_precision2, trate, prate].
    """
    # Map start/end dates to positions in the trading calendar.
    if start_date == '':
        start = 0
    elif int(start_date) not in date_list:
        print('can not find date')
        return
    else:
        start = date_list.index(int(start_date))
    if end_date == '':
        end = len(date_list)
    elif int(end_date) not in date_list:
        print('can not find date')
        return
    else:
        end = date_list.index(int(end_date))
    # Evaluate each full window; a partial trailing window is skipped.
    dates = []
    results = []
    for i in range(start, end, date_step):
        if i + date_step >= end:
            continue
        date = '%s : %s' % (date_list[i], date_list[i + date_step])
        print(date)
        result = evaluate_total(model, market='ALL', steps=steps, start_date=date_list[i],
                                end_date=date_list[i + date_step], lookback=lookback, delay=delay, uprate=uprate)
        if result:
            dates.append(date)
            results.append(result)
            print(result)
    # acc*/rec* labels: precision/recall at thresholds 0.5 / 0.7 / 0.9.
    plt.plot([i[2] for i in results], label='acc5', c='green')
    plt.plot([i[4] for i in results], label='acc7', c='blue')
    plt.plot([i[6] for i in results], label='acc9', c='red')
    plt.plot([i[1] for i in results], label='rec5', c='lightgreen')
    plt.plot([i[3] for i in results], label='rec7', c='lightblue')
    plt.plot([i[5] for i in results], label='rec9', c='pink')
    # plt.plot([i[8] for i in results], label='n_acc5', c='green')
    # plt.plot([i[10] for i in results], label='n_acc3', c='blue')
    # plt.plot([i[12] for i in results], label='n_acc1', c='red')
    # plt.plot([i[7] for i in results], label='n_rec5', c='lightgreen')
    # plt.plot([i[9] for i in results], label='n_rec3', c='lightblue')
    # plt.plot([i[11] for i in results], label='n_rec1', c='pink')
    plt.plot([i[13] for i in results], label='Trate', c='black')
    plt.plot([i[14] for i in results], label='Prate', c='brown')
    plt.legend()
    plt.show()
    return dates, results

大部分衡量方法只对进行预测是否上涨(即mod=’’)的模型有效

衡量对所有股票的准确率
evaluate_total()

In [1]: evaluate_total(model)
Out[1]:
[0.6048318326473237,	# loss
 0.61862553358078,		# 0.5 recall
 0.6655296385288239,	# 0.5 precision
 0.26063627749681473,	# 0.7 recall
 0.8474650800228118,	# 0.7 precision
 0.08491439446806907,	# 0.9 recall
 0.9620700001716613,	# 0.9 precision
 0.6881303012371063,	# 负结果 0.5 recall
 0.6426444888114929,	# 负结果 0.5 precision
 0.2039469599723816,	# 负结果 0.3 recall
 0.8001361966133118,	# 负结果 0.3 precision
 0.03344998843967915,	# 负结果 0.1 recall
 0.945061206817627,		# 负结果 0.1 precision
 0.5008360147476196,	# 预测得正的比例
 0.4655305534601212]	# 实际为正的比例

对每支股票按照模型的输出交易,衡量平均每个交易日的价格变化
evaluate_all_delta()

In [11]: sum_list = evaluate_all_delta(model, market='ALL', start_date=20191108, end_date='', base_line=0.9, delay=1)
平均: 0.01869930191972075

按时间衡量模型准确度
evaluate_total_time()
date_step:每个时间段持续时间
steps:计算次数,次数越多结果越准确

In [13]: evaluate_total_time(model, date_step=61, steps=3, start_date=20170103, end_date='', lookback=61, delay=1, uprate=0.0)
20170103 : 20170407
[0.5714327494303385, 0.5911784172058105, 0.7350501616795858, 0.29650260011355084, 0.8807578881581625, 0.09752903630336125, 0.9589986602465311, 0.7945913275082906, 0.6683776577313741, 0.3276803294817607, 0.8117450277010599, 0.03380454579989115, 0.9170262813568115, 0.49095123012860614, 0.3948471049467723]
20170407 : 20170706
[0.5693944891293844, 0.5984461506207784, 0.7022876739501953, 0.29884132742881775, 0.861210823059082, 0.11071240405241649, 0.9637072682380676, 0.7712886929512024, 0.6806734005610148, 0.34076804916063946, 0.8298511107762655, 0.0571845310429732, 0.940330425898234, 0.47428010900815326, 0.4042078951994578]
20170706 : 20170929
[0.5798394282658895, 0.5932405988375345, 0.7093638777732849, 0.2722257872422536, 0.8895139296849569, 0.08448692659536998, 0.9772547682126363, 0.7707486947377523, 0.6676571170488993, 0.27164021134376526, 0.8189897139867147, 0.03643824098010858, 0.9544040362040201, 0.48542391260464984, 0.40599090854326886]
20170929 : 20180102
[0.5729995369911194, 0.6039112210273743, 0.7030234138170878, 0.27464069922765094, 0.8634665409723917, 0.07280133416255315, 0.9295691649119059, 0.7785913745562235, 0.6936554312705994, 0.29994242389996845, 0.8393307328224182, 0.0484745018184185, 0.9419203003247579, 0.46465187271436054, 0.3991263310114543]
20180102 : 20180404
[0.734302838643392, 0.4271387755870819, 0.5242762168248495, 0.09093193213144939, 0.5486705501874288, 0.005700886559983094, 0.6018797159194946, 0.6124776403109232, 0.5167023837566376, 0.10401579737663269, 0.5127503971258799, 0.0012460101085404556, 0.3690476218859355, 0.500044584274292, 0.40723901987075806]
20180404 : 20180705
[0.7454104423522949, 0.46056893467903137, 0.4625197152296702, 0.12217340618371964, 0.5129115482171377, 0.007561440734813611, 0.4466666678587596, 0.5656497875849406, 0.5636967420578003, 0.10787152250607808, 0.5407769481341044, 0.003071665569829444, 0.5476190447807312, 0.44806989034016925, 0.4461085796356201]
20180705 : 20181008
[0.7507510582605997, 0.45997997124989826, 0.4787188569704692, 0.12299211571613948, 0.5070395072301229, 0.009159183750549952, 0.44206081827481586, 0.5611742933591207, 0.5425703724225363, 0.10632317264874776, 0.5173460145791372, 0.0021679462709774575, 0.49470900495847064, 0.4669697781403859, 0.44869394103686017]
20181008 : 20190103
[0.7226652503013611, 0.49212220311164856, 0.5187315940856934, 0.12433483948310216, 0.5550123651822408, 0.008488856721669436, 0.5555555522441864, 0.5768180886904398, 0.5506282846132914, 0.10334843893845876, 0.5724086364110311, 0.0015320322709158063, 0.5301587382952372, 0.4810555378595988, 0.45636088649431866]
20190103 : 20190408
[0.7271912296613058, 0.4668257534503937, 0.5723404884338379, 0.07485490789016087, 0.5707581837972006, 0.0043184501118958, 0.6057692368825277, 0.560244639714559, 0.45459697643915814, 0.054048859824736915, 0.44462979833285016, 0.0008052353902409474, 0.8333332538604736, 0.5576357245445251, 0.4548453191916148]
20190408 : 20190708
[0.767724335193634, 0.5531678597132365, 0.47463998198509216, 0.16808140774567923, 0.47606175144513446, 0.01771405277152856, 0.48153934876124066, 0.4719281991322835, 0.550476054350535, 0.08642558256785075, 0.5370267828305563, 0.004146686522290111, 0.7166666587193807, 0.46304715673128766, 0.5397165020306905]
20190708 : 20191009
[0.7437284390131632, 0.42941097418467206, 0.4692676564057668, 0.08559017876784007, 0.45823829372723895, 0.00663522615407904, 0.5171670218308767, 0.5693001747131348, 0.5294030706087748, 0.09079131732384364, 0.5481707056363424, 0.0008423991190890471, 0.4833333343267441, 0.4700900415579478, 0.43015066782633465]

在这里插入图片描述

history_predict.py

import os
import random
import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from evaluate_model import evaluate
from get_samples import get_samples, get_data, get_samples_targets


# Plot the model's historical prediction curve against actual price moves.
def history_predict(model, ts_code='600004.SH', date=20191128, delay=1, during=244, mod='simple'):
    """Plot *during* days of predictions ending at *date* for one stock.

    mod 'complex'/'c' overlays predictions on the raw price curves; any
    other value plots them against the post-*delay* price change.
    """
    # Raw price history.
    df = get_data(ts_code)
    if df is None:
        print('can not find data')
        return
    # Print the stock's historical metric scores first.
    print(evaluate(model, ts_code=ts_code))
    # Prediction windows covering the requested span.
    data = get_samples(ts_code=ts_code, date=date, duiring=during)
    if data is None:
        return
    result = model.predict(data)
    # NOTE(review): hard-coded 20180102 — presumably the train/validation
    # split date used elsewhere in the project; confirm before reuse.
    print('数据分割日:', df[df['trade_date'].isin(['20180102'])].index[0])
    today = df[df['trade_date'].isin([date])].index[0] + 1

    # Plotting
    if mod != 'complex' and mod != 'c':
        # Simple view: prediction vs. future price change.
        plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK labels correctly
        plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        # Symmetric y-axis around zero for the price-change curve.
        axis_max = max(abs((df['close'].shift(-delay) - df['close'])[today - during:today]))
        plt.ylim(ymin=-axis_max, ymax=axis_max)
        ax1.plot((df['close'].shift(-delay) - df['close'])[today - during:today], c='b', label='目标时间后涨跌幅')
        ax1.set_ylabel('目标时间涨跌幅')
        ax2 = ax1.twinx()
        plt.ylim(ymin=0, ymax=1)
        # Reference lines at the 0.5 / 0.7 / 0.9 decision thresholds.
        ax2.plot(range(today - during, today), result * 0 + 0.5, c='r')
        ax2.plot(range(today - during, today), result * 0 + 0.7, c='r')
        ax2.plot(range(today - during, today), result * 0 + 0.9, c='r')
        ax2.plot(range(today - during, today), result, c='y', label='预测值')
        ax2.set_ylabel('预测值')
        # Combined legend for both axes. (The duplicate call is in the original.)
        handles1, labels1 = ax1.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
        plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
        plt.title(ts_code)
        plt.show()
    else:
        # Complex view: prediction overlaid on the raw close-price curves.
        plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK labels correctly
        plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        ax1.plot(df['close'][today - during:today], c='g', label='今天')
        ax1.plot(df['close'].shift(-delay)[today - during:today], c='b', label='目标时间')
        ax1.set_ylabel('走势')
        ax2 = ax1.twinx()
        plt.ylim(ymin=0, ymax=1)
        # Reference lines at the 0.5 / 0.7 / 0.9 decision thresholds.
        ax2.plot(range(today - during, today), result * 0 + 0.5, c='r')
        ax2.plot(range(today - during, today), result * 0 + 0.7, c='r')
        ax2.plot(range(today - during, today), result * 0 + 0.9, c='r')
        ax2.plot(range(today - during, today), result, c='y', label='预测值')
        ax2.set_ylabel('预测值')
        # Combined legend for both axes. (The duplicate call is in the original.)
        handles1, labels1 = ax1.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
        plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
        plt.title(ts_code)
        plt.show()

In [87]: history_predict(model, ts_code='600004.SH', date=20191128, delay=1, during=244, mod='simple')
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
模型构建日: 3967

在这里插入图片描述

serch_predict.py

import os
import random

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_samples import get_samples
from get_tools import *


# Scan a market for stocks whose recent predictions exceed a baseline.
def search_predict(model, date=20191128, market='SSE', duiring=1, baseline=0.9, verbose=1):
    """Search *duiring* days ending at *date* for predictions above *baseline*.

    Returns (result_code, result_pred, rate_pred): matching codes, their
    prediction arrays, and the market-wide positive rate (None when no
    stock produced samples).
    """
    code_list = get_code_list(market=market)
    sum_pred = None
    sum_count = 0
    result_code = []
    result_pred = []
    for code_name in code_list[:]:
        print_verbose(verbose, code_name)
        samples = get_samples(ts_code=code_name, date=date, duiring=duiring, verbose=verbose)
        if samples is None:
            continue
        # Accumulate rounded predictions for the market-wide positive rate.
        pred = model.predict(samples)
        if sum_count == 0:
            sum_pred = np.round(pred)
        else:
            sum_pred = sum_pred + np.round(pred)
        sum_count += 1
        # Keep any stock whose prediction clears the baseline.
        if any(pred > baseline):
            result_code.append(code_name)
            result_pred.append(pred)
            print_verbose(verbose, '%s*****************************************' % code_name)
            print_verbose(verbose, pred)
    # Guard: with no samples at all (e.g. *date* is not a trading day) the
    # original code crashed dividing None by zero.
    if sum_count == 0:
        print('no samples found for date %s' % date)
        return result_code, result_pred, None
    rate_pred = sum_pred / sum_count
    print('rate_pred:\n%s' % rate_pred)
    for i in range(len(result_code)):
        print('%s*****************************************\n%s' % (result_code[i], result_pred[i]))
    return result_code, result_pred, rate_pred

搜索从date开始往前during时间内预测结果达到baseline的股票

In [7]: results=search_predict(model, date=20191108, market='SSE', duiring=1, baseline=0.9, verbose=0)
rate_pred:
[[0.36472148]]
600127.SH*****************************************
[[0.9194063]]
603018.SH*****************************************
[[0.9124184]]

run.py

import os
import shutil
import time

# Reset the data folder so today's run downloads fresh market data.
# NOTE: this must happen BEFORE the project imports below, because those
# modules read cached files from ../data/ at import time.
data_dir = '../data/'
if os.path.exists(data_dir + "code_list_SSE.csv"):
    shutil.rmtree(data_dir)
    os.mkdir(data_dir)

from serch_predict import *
from evaluate_model import *
from history_predict import *

# Load the trained model; enable GPU memory growth so TensorFlow does not
# pre-allocate the whole GPU (TF1-style session configuration).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.keras.backend.set_session(tf.Session(config=config))
model = load_model(model_name='./model/binary/ATT140to740.model')
# model = load_model(model_name='./model/1y1d/cudnnGRU/cudnnGRU210to340conv.model')

# Run today's prediction over all markets
today = time.strftime("%Y%m%d")
print('today : %s' % today)
result_code, result_pred, rate_pred = search_predict(model, date=today, duiring=1, market='ALL', verbose=1, baseline=0.9)

# Persist results; `with` guarantees the file is closed even if a write fails
# (the original open()/close() pair leaked the handle on exceptions).
with open("result.txt", "w") as f:
    f.write('rate_pred:\n%s\n' % rate_pred)
    for i in range(len(result_code)):
        f.write('%s*****************************************\n%s\n' % (result_code[i], result_pred[i]))

run.bat

python -i run.py

运行run.bat将清空数据并重新加载最新的数据,以今天为目标预测目标时间(今天+delay)的结果,并储存>0.9的结果到result.txt文件中.

dotrain.py

import os
import random
import time

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from keras.layers import *

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_tools import *
from new_generator import new_generator

# Let GPU memory grow on demand instead of pre-allocating it all
# (TensorFlow 1.x session-style configuration).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.keras.backend.set_session(tf.Session(config=config))

# Hyper-parameters for the training / validation sample generators.
batch_size = 1024
shape = 5                    # number of features per trading day
train_val_date = 20180102    # train / validation split date (yyyymmdd)
val_test_date = 20191108     # validation / test split date (yyyymmdd)
lookback = 61  # 244/year -- input window in trading days (~244 trading days per year)
delay = 1                    # predict `delay` trading days ahead
uprate = 0.0                 # labelling threshold -- presumably the minimum rise rate counted as positive; confirm against new_generator
# Training generator: all data strictly before train_val_date.
generator = new_generator(market='ALL', batch_size=batch_size, shape=shape,
                          start_date='', end_date=train_val_date,
                          lookback=lookback, delay=delay, uprate=uprate)
# Validation generator: data between the two split dates.
val_generator = new_generator(market='ALL', batch_size=batch_size, shape=shape,
                              start_date=train_val_date, end_date=val_test_date,
                              lookback=lookback, delay=delay, uprate=uprate)

# 建模
# *************************************** CNN ***********************************
# model = Sequential()
# kernel_size = 4
# dropout_rate = 0.3
# model.add(layers.Conv1D(8, kernel_size=kernel_size, strides=2, padding='same',
#                         input_shape=(lookback, shape)))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(16, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(32, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(64, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(128, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(256, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(512, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Flatten())
# model.add(layers.Dense(1, activation='sigmoid'))
# model.compile(optimizer=keras.optimizers.Adam(),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# ************************************* G R U ******************************************
# dropout_rate = 0.5
# model = Sequential()
# # model.add(layers.BatchNormalization())
# model.add(layers.GRU(256,
#                      dropout=0.1,
#                      recurrent_dropout=0.5,
#                      input_shape=(None, shape)))
# model.add(layers.Dense(64, activation='relu'))
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Dense(1, activation='sigmoid'))
# model.compile(optimizer=keras.optimizers.RMSprop(1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# *************************************** ResNet ***************************************
# def ResBlock(x, num_filters, resampling=None, kernel_size=3):
#     def BatchActivation(x):
#         x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
#         x = Activation('relu')(x)
#         return x
#
#     def Conv(x, resampling=resampling):
#         weight_decay = 1e-4
#         if resampling is None:
#             x = Conv1D(num_filters, kernel_size=kernel_size, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         elif resampling == 'up':
#             x = UpSampling2D()(x)
#             x = Conv1D(num_filters, kernel_size=kernel_size, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         elif resampling == 'down':
#             x = Conv1D(num_filters, kernel_size=kernel_size, strides=2, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         return x
#
#     a = BatchActivation(x)
#     y = Conv(a, resampling=resampling)
#     y = BatchActivation(y)
#     y = Conv(y, resampling=None)
#     if resampling is not None:
#         x = Conv(a, resampling=resampling)
#     return add([y, x])
#
#
# num_layers = int(np.log2(lookback)) - 3
# max_num_channels = lookback * 8
# weight_decay = 1e-4
#
# x_in = Input(shape=(lookback, shape))
# x = x_in
# for i in range(num_layers + 1):
#     num_channels = max_num_channels // 2 ** (num_layers - i)
#     if i > 0:
#         x = ResBlock(x, num_channels, resampling='down')
#     else:
#         x = Conv1D(num_channels, kernel_size=3, strides=2, padding='same',
#                    kernel_initializer="he_normal",
#                    kernel_regularizer=regularizers.l2(weight_decay))(x)
# x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
# x = Activation('relu')(x)
# x = GlobalAveragePooling1D()(x)
# x = Dense(1, activation='sigmoid')(x)
# model = keras.Model(x_in, x)
# model.compile(optimizer=keras.optimizers.Adam(),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# # model.summary()
# *********************************** Attention *********************************************
# Hybrid "attention" model: a Conv1D branch produces a (lookback, shape)
# weight tensor that is multiplied element-wise with the raw input before
# a GRU summarizes the re-weighted sequence into a binary prediction.
dropout_rate = 0.3
x_in = Input(shape=(lookback, shape))
x = x_in
# x = BatchNormalization()(x)
# Attention branch: conv features -> dense back to lookback*shape values,
# reshaped to match the input so it can act as per-element weights.
c = Conv1D(32, 5, activation='relu')(x)
c = LeakyReLU()(c)
c = Dropout(dropout_rate)(c)
c = Flatten()(c)
c = Dense(lookback * shape)(c)
c = LeakyReLU()(c)
c = Lambda(lambda k: K.reshape(k, (-1, lookback, shape)))(c)
# Element-wise re-weighting of the raw input by the attention branch.
m = multiply([x, c])
# Recurrent summary of the weighted sequence.
r = GRU(256)(m)
r = LeakyReLU()(r)
r = Dropout(dropout_rate)(r)
d = Dense(256)(r)
d = LeakyReLU()(d)
d = Dropout(dropout_rate)(d)
# Output: probability that the target day is an "up" day.
res = Dense(1, activation='sigmoid')(d)
model = keras.Model(inputs=x_in, outputs=res)
# Metrics (recall/precision at several thresholds) come from get_tools.
model.compile(optimizer=keras.optimizers.Adam(lr=1e-4),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
              loss=keras.losses.binary_crossentropy,
              metrics=[recall, precision, recall2, precision2, trate, prate]
              )
# model = load_model('./model/ATTSMALL480bad.model')
# model.load_weights('./model/ATTSMALL480bad.weight')

# Callbacks: keep the best model by validation loss, and halve the learning
# rate when the training loss plateaus for 60 epochs.
checkpoint = keras.callbacks.ModelCheckpoint('./model/auto_save_best.model', monitor='val_loss',
                                             verbose=1, save_best_only=True, mode='min')
learning_rate_reduction = keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=60,
                                                            factor=0.5, min_lr=1e-8, verbose=1)
callbacks_list = [checkpoint, learning_rate_reduction]

# Train from the streaming generators (TF1-era fit_generator API).
history = model.fit_generator(generator,
                              steps_per_epoch=200,  # 1min/epoch
                              epochs=180,
                              validation_data=val_generator,
                              validation_steps=10,
                              callbacks=callbacks_list,
                              # class_weight=class_weight,
                              verbose=1)

# Save the final model and weights regardless of checkpointing.
model.save('./model/auto_save.model')
model.save_weights('./model/auto_save.weight')
# show_train_history(history, 'loss', 'val_loss')
# plt.savefig('./model/auto_save.jpg')
# plt.show()
plot_history(history)

训练

(以进行lookback=61,delay=1,进行对第二天是否上涨进行训练)

CNN

在这里插入图片描述

GRU

在这里插入图片描述

ResNet

在这里插入图片描述

混合Attention

在这里插入图片描述

效果

(使用ATT140to740.model作为测试)

衡量总体准确率

全部数据

In [19]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date='', end_date='', lookback=61, delay=1, uprate=0.0)
Out[19]:
[0.6090528190135955,
 0.6116366863250733,
 0.6690988123416901, # 取0.5时准确率
 0.256537552177906,
 0.8376343965530395,
 0.08675041273236275,
 0.9569046974182129, # 取0.9时准确率
 0.689475154876709,
 0.6336585283279419,
 0.20669595450162886,
 0.7789588630199432,
 0.03284015003591776,
 0.9456658005714417,
 0.50647232234478,
 0.4630115032196045]

交叉验证集数据

In [22]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date=20180102, end_date=20191108, lookback=61, delay=1, uprate=0.0)
Out[22]:
[0.7405495762825012,
 0.46770868003368377,
 0.5000730127096176, # 取0.5时准确率
 0.1060352623462677,
 0.5055912852287292,
 0.009883439354598521,
 0.5355186939239502, # 取0.9时准确率
 0.5606197416782379,
 0.5284954369068146,
 0.09424014165997505,
 0.5357251167297363,
 0.0018147096503525971,
 0.6095238149166107,
 0.4844076007604599,
 0.4531158059835434]

测试集数据

In [23]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date=20191108, end_date="", lookback=61, delay=1, uprate=0.0)
Out[23]:
[0.7380412459373474,
 0.4238017022609711,
 0.47870981097221377, # 取0.5时准确率
 0.09282655492424965,
 0.5137168973684311,
 0.006658474449068308,
 0.5203565269708633, # 取0.9时准确率
 0.5914961218833923,
 0.5369187474250794,
 0.09529061019420623,
 0.5270131379365921,
 0.0015642174810636788,
 0.6009523928165436,
 0.46967103481292727,
 0.41572613418102267]

衡量模型对每只股票的准确率

全部数据

In [26]: evaluate_all(model, market='ALL', start_date='', end_date='', lookback=61, delay=1)
# code
# [loss, 0.5 recall, 0.5 precision, 0.7 recall, 0.7 precision, 0.9 recall, 0.9 precision, 0.5 neg-recall, 0.5 neg-precision, 0.3 neg-recall, 0.3 neg-precision, 0.1 neg-recall, 0.1 neg-precision, Prate, Trate]
600000.SH
[0.6263973116874695, 0.5945805311203003, 0.6121244430541992, 0.15789473056793213, 0.8885630369186401, 0.029702970758080482, 1.0, 0.64207923412323, 0.6250602602958679, 0.10198019444942474, 0.8995633125305176, 0.0089108906686306, 1.0, 0.4871794879436493, 0.47321656346321106]
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
600006.SH
[0.5564819574356079, 0.6736953258514404, 0.7041322588920593, 0.3494991958141327, 0.8655352592468262, 0.13020558655261993, 0.9610894918441772, 0.7370225191116333, 0.7085687518119812, 0.30313417315483093, 0.8398914337158203, 0.05239960923790932, 0.9553571343421936, 0.4815943241119385, 0.46077683568000793]
600007.SH
[0.5916787385940552, 0.6156201958656311, 0.6794366240501404, 0.26033690571784973, 0.886956512928009, 0.09596733003854752, 0.9740932583808899, 0.7126262784004211, 0.652033269405365, 0.2050504982471466, 0.8319672346115112, 0.03181818127632141, 0.9692307710647583, 0.497334361076355, 0.45062199234962463]
600008.SH
[0.5691049098968506, 0.6397739052772522, 0.6939799189567566, 0.3288797438144684, 0.8839778900146484, 0.1320657730102539, 0.9661654233932495, 0.7245358824729919, 0.6731934547424316, 0.2679377794265747, 0.8436018824577332, 0.05017561465501785, 0.9174311757087708, 0.4940340220928192, 0.4554455578327179]
600009.SH
[0.6188881397247314, 0.5797174572944641, 0.6558219194412231, 0.16195762157440186, 0.9093484282493591, 0.03279515728354454, 1.0, 0.6918753385543823, 0.6191129684448242, 0.11854879558086395, 0.8787878751754761, 0.01532958634197712, 0.9375, 0.5031734108924866, 0.44478294253349304]
600010.SH
[0.5921441912651062, 0.6390403509140015, 0.647871732711792, 0.2764449417591095, 0.8311475515365601, 0.1035986915230751, 0.949999988079071, 0.697387158870697, 0.6892018914222717, 0.227078378200531, 0.8400703072547913, 0.0370546318590641, 0.9397590160369873, 0.4656004011631012, 0.45925360918045044]
600011.SH
[0.6160035729408264, 0.5720338821411133, 0.638675332069397, 0.19968220591545105, 0.8285714387893677, 0.05932203307747841, 0.991150438785553, 0.7020965218544006, 0.6405693888664246, 0.19697707891464233, 0.800000011920929, 0.02096538245677948, 0.9555555582046509, 0.47930946946144104, 0.4292967617511749]
600012.SH
[0.5865558981895447, 0.6079456210136414, 0.6793224215507507, 0.26346054673194885, 0.8704662919044495, 0.07736539095640182, 0.9801324605941772, 0.7290226817131042, 0.6632240414619446, 0.2221125364303589, 0.855513334274292, 0.026159921661019325, 0.9636363387107849, 0.4856562614440918, 0.43462806940078735]
600015.SH
[0.6084867715835571, 0.6029331684112549, 0.6479859948158264, 0.19282998144626617, 0.8897243142127991, 0.04888647422194481, 0.9890109896659851, 0.6902927756309509, 0.6477108597755432, 0.15716487169265747, 0.8571428656578064, 0.02003081701695919, 0.9750000238418579, 0.4860084354877472, 0.4522175192832947]
600016.SH
[0.6412468552589417, 0.589085042476654, 0.5941716432571411, 0.12038522958755493, 0.8426966071128845, 0.024077046662569046, 0.9375, 0.6367149949073792, 0.6318312287330627, 0.08647342771291733, 0.8564593195915222, 0.010628019459545612, 1.0, 0.4744859039783478, 0.47042396664619446]
600017.SH
[0.557805061340332, 0.6498655676841736, 0.6931899785995483, 0.36491936445236206, 0.8715890645980835, 0.13508065044879913, 0.9757281541824341, 0.7346559166908264, 0.6946072578430176, 0.3161810338497162, 0.8279221057891846, 0.06695598363876343, 0.9908257126808167, 0.47984522581100464, 0.44985488057136536]
600018.SH
[0.5923075079917908, 0.6194751262664795, 0.6476534008979797, 0.2790055274963379, 0.8541226387023926, 0.10428176820278168, 0.9679487347602844, 0.6970825791358948, 0.6708482503890991, 0.2178770899772644, 0.839712917804718, 0.03414028510451317, 0.9166666865348816, 0.47335731983184814, 0.45276233553886414]
600019.SH
[0.6044419407844543, 0.6025437116622925, 0.6553314328193665, 0.22363540530204773, 0.8737059831619263, 0.07101219147443771, 0.9781022071838379, 0.7085769772529602, 0.659709632396698, 0.18323586881160736, 0.8430493474006653, 0.025828460231423378, 0.9298245906829834, 0.47905558347702026, 0.44046711921691895]
600020.SH
[0.5614857077598572, 0.6450549364089966, 0.6988095045089722, 0.3434065878391266, 0.8680555820465088, 0.14725275337696075, 0.9605734944343567, 0.7481334209442139, 0.6993950605392456, 0.2862120568752289, 0.8468335866928101, 0.056246887892484665, 0.9576271176338196, 0.47531992197036743, 0.4387568533420563]
# ***********************************数据过多,省略部分***********************************
300770.SZ
[0.6937742233276367, 0.5438596606254578, 0.6458333134651184, 0.12280701845884323, 0.699999988079071, 0.0, 0.0, 0.574999988079071, 0.4693877696990967, 0.05000000074505806, 0.5, 0.0, 0.0, 0.5876288414001465, 0.49484536051750183]
300771.SZ
[0.7008504271507263, 0.5964912176132202, 0.6938775777816772, 0.2631579041481018, 0.7142857313156128, 0.0, 0.0, 0.6153846383094788, 0.5106382966041565, 0.12820513546466827, 0.4545454680919647, 0.0, 0.0, 0.59375, 0.5104166865348816]
300772.SZ
[0.8239073753356934, 0.5227272510528564, 0.46000000834465027, 0.15909090638160706, 0.4117647111415863, 0.0, 0.0, 0.4375, 0.5, 0.1666666716337204, 0.800000011920929, 0.0, 0.0, 0.47826087474823, 0.54347825050354]
300773.SZ
[0.8311894536018372, 0.45098039507865906, 0.5, 0.13725490868091583, 0.46666666865348816, 0.019607843831181526, 0.5, 0.4523809552192688, 0.40425533056259155, 0.0714285746216774, 0.3333333432674408, 0.0, 0.0, 0.5483871102333069, 0.49462366104125977]
300775.SZ
[0.8624421954154968, 0.4444444477558136, 0.4444444477558136, 0.1388888955116272, 0.3125, 0.0, 0.0, 0.523809552192688, 0.523809552192688, 0.095238097012043, 0.3636363744735718, 0.0, 0.0, 0.4615384638309479, 0.4615384638309479]
300776.SZ
[0.7996936440467834, 0.39534884691238403, 0.5, 0.23255814611911774, 0.5555555820465088, 0.023255813866853714, 0.9999998807907104, 0.5405405163764954, 0.43478259444236755, 0.1621621549129486, 0.6000000238418579, 0.0, 0.0, 0.5375000238418579, 0.42500001192092896]
300777.SZ
[0.7923781871795654, 0.574999988079071, 0.5476190447807312, 0.20000000298023224, 0.42105263471603394, 0.05000000074505806, 0.4000000059604645, 0.5365853905677795, 0.5641025900840759, 0.24390244483947754, 0.7142857313156128, 0.0, 0.0, 0.4938271641731262, 0.5185185074806213]
300778.SZ
[0.7725061774253845, 0.800000011920929, 0.5161290168762207, 0.2750000059604645, 0.47826087474823, 0.05000000074505806, 1.0, 0.3333333432674408, 0.6521739363670349, 0.02222222276031971, 0.3333333432674408, 0.0, 0.0, 0.47058823704719543, 0.729411780834198]
300779.SZ
[0.8568037152290344, 0.5348837375640869, 0.574999988079071, 0.1860465109348297, 0.380952388048172, 0.023255813866853714, 0.5, 0.5, 0.45945945382118225, 0.20588235557079315, 0.5, 0.0, 0.0, 0.5584415793418884, 0.5194805264472961]
300780.SZ
[0.6919370293617249, 0.7368420958518982, 0.6222222447395325, 0.34210526943206787, 0.8125, 0.0, 0.0, 0.46875, 0.6000000238418579, 0.03125, 0.25, 0.0, 0.0, 0.5428571701049805, 0.6428571343421936]
300781.SZ
[0.7137081623077393, 0.6216216087341309, 0.6388888955116272, 0.2432432472705841, 0.692307710647583, 0.0, 0.0, 0.5517241358757019, 0.5333333611488342, 0.13793103396892548, 0.6666666865348816, 0.0, 0.0, 0.560606062412262, 0.5454545617103577]
300782.SZ
[0.6372032165527344, 0.71875, 0.6764705777168274, 0.09375, 0.5, 0.03125, 0.9999998807907104, 0.5925925970077515, 0.6399999856948853, 0.1111111119389534, 1.0, 0.0, 0.0, 0.5423728823661804, 0.5762711763381958]
300783.SZ
[0.822144091129303, 0.4545454680919647, 0.5, 0.13636364042758942, 0.5, 0.0, 0.0, 0.4736842215061188, 0.4285714328289032, 0.10526315867900848, 0.3333333432674408, 0.0, 0.0, 0.5365853905677795, 0.4878048896789551]
300785.SZ
[0.824234664440155, 0.6190476417541504, 0.5, 0.1428571492433548, 0.27272728085517883, 0.095238097012043, 0.6666666865348816, 0.31578946113586426, 0.4285714328289032, 0.10526315867900848, 0.5, 0.0, 0.0, 0.5249999761581421, 0.6499999761581421]
300786.SZ
[0.888309121131897, 0.6428571343421936, 0.375, 0.4285714328289032, 0.4000000059604645, 0.0714285746216774, 0.3333333432674408, 0.25, 0.5, 0.05000000074505806, 0.9999998807907104, 0.0, 0.0, 0.4117647111415863, 0.7058823704719543]
300787.SZ
[0.8598592281341553, 0.699999988079071, 0.5384615659713745, 0.10000000149011612, 0.3333333432674408, 0.0, 0.0, 0.1428571492433548, 0.25, 0.1428571492433548, 0.9999998807907104, 0.0, 0.0, 0.5882353186607361, 0.7647058963775635]
300788.SZ
[0.7071133255958557, 0.6666666865348816, 0.6666666865348816, 0.1666666716337204, 0.6666666865348816, 0.0, 0.0, 0.6363636255264282, 0.6363636255264282, 0.1818181872367859, 0.800000011920929, 0.0, 0.0, 0.52173912525177, 0.52173912525177]
300789.SZ
[0.9431849718093872, 0.75, 0.5, 0.5, 0.5, 0.0, 0.0, 0.25, 0.5, 0.0, 0.0, 0.0, 0.0, 0.5, 0.75]
300790.SZ
data range too small, may be date too close to boundary.
None
300791.SZ
data range too small, may be date too close to boundary.
None

交叉验证集

In [24]: evaluate_all(model, market='ALL', start_date=20180102, end_date=20191108, lookback=61, delay=1)
# code
# [loss, 0.5 recall, 0.5 precision, 0.7 recall, 0.7 precision, 0.9 recall, 0.9 precision, 0.5 neg-recall, 0.5 neg-precision, 0.3 neg-recall, 0.3 neg-precision, 0.1 neg-recall, 0.1 neg-precision, Prate, Trate]
600000.SH
[0.693852424621582, 0.3611111044883728, 0.49367088079452515, 0.0, 0.0, 0.0, 0.0, 0.6566523313522339, 0.5257731676101685, 0.004291845485568047, 0.5, 0.0, 0.0, 0.48106902837753296, 0.3518930971622467]
600004.SH
[0.7271435856819153, 0.3504672944545746, 0.5, 0.04205607622861862, 0.44999998807907104, 0.004672897048294544, 0.9999998807907104, 0.6808510422706604, 0.5351170301437378, 0.059574469923973083, 0.5185185074806213, 0.0, 0.0, 0.47661471366882324, 0.3340757191181183]
600006.SH
[0.7446303963661194, 0.4333333373069763, 0.45728641748428345, 0.09047619253396988, 0.4523809552192688, 0.009523809887468815, 1.0, 0.5481171607971191, 0.5239999890327454, 0.12970711290836334, 0.6326530575752258, 0.0, 0.0, 0.46770602464675903, 0.44320711493492126]
600007.SH
[0.7179648876190186, 0.36057692766189575, 0.4491018056869507, 0.028846153989434242, 0.6000000238418579, 0.0, 0.0, 0.6182572841644287, 0.5283687710762024, 0.029045643284916878, 0.46666666865348816, 0.0, 0.0, 0.46325168013572693, 0.37193763256073]
600008.SH
[0.7209228873252869, 0.40594059228897095, 0.48235294222831726, 0.0891089141368866, 0.529411792755127, 0.004950494971126318, 0.5, 0.6437246799468994, 0.5698924660682678, 0.11740890890359879, 0.6041666865348816, 0.004048583097755909, 0.9999998807907104, 0.4498886466026306, 0.37861916422843933]
600009.SH
[0.7013201713562012, 0.36771300435066223, 0.5430463552474976, 0.0044843051582574844, 0.25, 0.0, 0.0, 0.6946902871131897, 0.5268456339836121, 0.08407079428434372, 0.6785714030265808, 0.0, 0.0, 0.4966592490673065, 0.33630290627479553]
600010.SH
[0.7444199323654175, 0.44171780347824097, 0.35820895433425903, 0.07975459843873978, 0.3611111044883728, 0.0061349691823124886, 0.5, 0.5342960357666016, 0.6192468404769897, 0.10469313710927963, 0.6304348111152649, 0.0, 0.0, 0.3704545497894287, 0.45681819319725037]
600011.SH
[0.7138883471488953, 0.4285714328289032, 0.5113636255264282, 0.04285714402794838, 0.40909090638160706, 0.0, 0.0, 0.6401673555374146, 0.5604395866394043, 0.10878661274909973, 0.604651153087616, 0.0041841003112494946, 0.9999998807907104, 0.46770602464675903, 0.39198216795921326]
600012.SH
[0.714976966381073, 0.4439024329185486, 0.4715026021003723, 0.07804878056049347, 0.5714285969734192, 0.0, 0.0, 0.5819672346115112, 0.5546875, 0.069672130048275, 0.6538461446762085, 0.004098360426723957, 0.9999998807907104, 0.4565701484680176, 0.42984411120414734]
600015.SH
[0.6911038160324097, 0.3448275923728943, 0.48275861144065857, 0.004926108289510012, 0.5, 0.0, 0.0, 0.6951219439506531, 0.5625, 0.016260161995887756, 0.6666666865348816, 0.0, 0.0, 0.4521158039569855, 0.3229398727416992]
600016.SH
[0.7003122568130493, 0.3400000035762787, 0.4197530746459961, 0.004999999888241291, 0.5, 0.0, 0.0, 0.6224899888038635, 0.5400696992874146, 0.00803212821483612, 0.6666666865348816, 0.0, 0.0, 0.4454343020915985, 0.3608017861843109]
600017.SH
[0.7326346635818481, 0.42487046122550964, 0.4205128252506256, 0.10362694412469864, 0.5555555820465088, 0.005181347019970417, 0.5, 0.55859375, 0.5629921555519104, 0.07421875, 0.5428571701049805, 0.0, 0.0, 0.42984411120414734, 0.43429845571517944]
600018.SH
[0.7232927083969116, 0.3849765360355377, 0.46857142448425293, 0.061032865196466446, 0.5416666865348816, 0.0, 0.0, 0.6059321761131287, 0.5218977928161621, 0.09322033822536469, 0.6285714507102966, 0.0, 0.0, 0.474387526512146, 0.3897550106048584]
# ***********************************数据过多,省略部分***********************************
300722.SZ
[0.7533032298088074, 0.49344977736473083, 0.5736040472984314, 0.14847160875797272, 0.5862069129943848, 0.0043668122962117195, 0.25, 0.5714285969734192, 0.4912280738353729, 0.19387754797935486, 0.5507246255874634, 0.0, 0.0, 0.5388235449790955, 0.46352940797805786]
300723.SZ
[0.7201671600341797, 0.49514561891555786, 0.5125628113746643, 0.16019417345523834, 0.5593220591545105, 0.019417475908994675, 0.5714285969734192, 0.5488371849060059, 0.5315315127372742, 0.13488371670246124, 0.707317054271698, 0.004651162773370743, 0.9999998807907104, 0.489311158657074, 0.4726840853691101]
300724.SZ
can not find date
None
300725.SZ
[0.7402744889259338, 0.5311004519462585, 0.5235849022865295, 0.1818181872367859, 0.5757575631141663, 0.019138755276799202, 0.6666666865348816, 0.521327018737793, 0.5288461446762085, 0.11848340928554535, 0.5813953280448914, 0.0, 0.0, 0.49761903285980225, 0.5047619342803955]
300726.SZ
[0.7371765375137329, 0.5115207433700562, 0.5388349294662476, 0.1751152127981186, 0.5846154093742371, 0.0138248847797513, 1.0, 0.5273631811141968, 0.5, 0.13930347561836243, 0.5283018946647644, 0.004975124262273312, 0.9999998807907104, 0.519138753414154, 0.4928229749202728]
300727.SZ
[0.7675619721412659, 0.44954127073287964, 0.5051546096801758, 0.1376146823167801, 0.5, 0.027522936463356018, 0.8571428656578064, 0.5102040767669678, 0.4545454680919647, 0.13265305757522583, 0.48148149251937866, 0.005102040711790323, 0.9999998807907104, 0.5265700221061707, 0.46859902143478394]
300729.SZ
[0.7517649531364441, 0.5071770548820496, 0.5299999713897705, 0.13875597715377808, 0.5686274766921997, 0.0, 0.0, 0.5323383212089539, 0.5095238089561462, 0.12437810748815536, 0.5208333134651184, 0.0, 0.0, 0.5097560882568359, 0.4878048896789551]
300730.SZ
[0.771763801574707, 0.43192487955093384, 0.5257142782211304, 0.1690140813589096, 0.5538461804389954, 0.0, 0.0, 0.5631579160690308, 0.46929824352264404, 0.17894737422466278, 0.48571428656578064, 0.0, 0.0, 0.5285359621047974, 0.43424317240715027]
300731.SZ
[0.8007940649986267, 0.4545454680919647, 0.43589743971824646, 0.14438502490520477, 0.3913043439388275, 0.010695187374949455, 0.4000000059604645, 0.45812806487083435, 0.4769230782985687, 0.09852216392755508, 0.4878048896789551, 0.004926108289510012, 0.9999998807907104, 0.47948718070983887, 0.5]
300732.SZ
[0.7830197811126709, 0.5056179761886597, 0.5027933120727539, 0.13483145833015442, 0.4444444477558136, 0.00561797758564353, 0.1666666716337204, 0.5082873106002808, 0.5111111402511597, 0.14364640414714813, 0.5531914830207825, 0.01104972418397665, 1.0, 0.4958217144012451, 0.49860724806785583]
300733.SZ
can not find date
None
300735.SZ
[0.7560451626777649, 0.5095238089561462, 0.5431472063064575, 0.13333334028720856, 0.5490196347236633, 0.014285714365541935, 0.75, 0.5, 0.4663212299346924, 0.05000000074505806, 0.3461538553237915, 0.0055555556900799274, 0.9999998807907104, 0.5384615659713745, 0.5051282048225403]
300736.SZ
can not find date
None
300737.SZ
can not find date
None
300738.SZ
can not find date
None

按时间衡量模型准确度

2003-2019,按年衡量

In [27]: evaluate_total_time(model, date_step=244, steps=3, start_date='', end_date='', lookback=61, delay=1, uprate=0.0)
20030114 : 20040203
[0.6568018794059753, 0.46427900592486065, 0.5843748847643534, 0.06550159056981404, 0.773417055606842, 0.0038116557989269495, 0.8251082301139832, 0.7334766785303751, 0.6291241844495138, 0.08584380894899368, 0.7580034534136454, 0.0007994713766189913, 0.6666666666666666, 0.44664348165194195, 0.3549077312151591]
20040203 : 20050127
[0.6508340040842692, 0.535581111907959, 0.5979280471801758, 0.07494993011156718, 0.7752963105837504, 0.004925861178586881, 0.9333333373069763, 0.6983536680539449, 0.6421788732210795, 0.09007209291060765, 0.7887819210688273, 0.00033121334854513407, 0.666666587193807, 0.4558259844779968, 0.40821966528892517]
20050127 : 20060214
[0.667553981145223, 0.5478649338086446, 0.5861708919207255, 0.047291661302248635, 0.8047292033831278, 0.0020023582813640437, 0.8666666746139526, 0.6283081372578939, 0.5911598006884257, 0.03305310135086378, 0.7518921494483948, 0.0, 0.0, 0.4900597333908081, 0.4581438899040222]
20060214 : 20070424
[0.6465277671813965, 0.6215876539548238, 0.6484155257542928, 0.1326626588900884, 0.874542772769928, 0.02713957242667675, 0.9722222288449606, 0.5738389492034912, 0.5453460415204366, 0.059145551174879074, 0.7547684907913208, 0.0022163967757175365, 0.8888888955116272, 0.5583489338556925, 0.535258968671163]
20070424 : 20080428
[0.5526228348414103, 0.7043639024098715, 0.7278452714284261, 0.36831281582514447, 0.8801937301953634, 0.1299465224146843, 0.9723483721415201, 0.7016325195630392, 0.6768803000450134, 0.2890618046124776, 0.8412723143895467, 0.05818237240115801, 0.9391853213310242, 0.5311580499013265, 0.5139520565668741]
20080428 : 20090429
[0.5299248894055685, 0.7118967771530151, 0.7440837621688843, 0.4226831793785095, 0.8908692598342896, 0.18500982224941254, 0.9596697290738424, 0.7404838601748148, 0.7079606850941976, 0.33868083357810974, 0.8527288834253947, 0.051117694626251854, 0.9455873966217041, 0.5146652460098267, 0.492377628882726]
20090429 : 20100517
[0.5744742155075073, 0.7015180389086405, 0.7103689312934875, 0.32112271587053937, 0.8679341475168864, 0.09785450746615727, 0.967859665552775, 0.6708837946256002, 0.6615198651949564, 0.21318497757116953, 0.8435181975364685, 0.025281783193349838, 0.9693877498308817, 0.5349915226300558, 0.5283944010734558]
20100517 : 20110524
[0.5779383579889933, 0.6811196804046631, 0.7198743422826132, 0.3169198234875997, 0.8725708723068237, 0.08273773143688838, 0.9551554322242737, 0.6999802788098654, 0.659738302230835, 0.22190992534160614, 0.8136094411214193, 0.020922282089789707, 0.93113245566686, 0.5309797724088033, 0.5024516383806864]
20110524 : 20120525
[0.563612699508667, 0.6636995673179626, 0.6954499880472819, 0.32321880261103314, 0.8615734577178955, 0.09794096151987712, 0.9653286337852478, 0.735595683256785, 0.7062193155288696, 0.2857717474301656, 0.8460407257080078, 0.040361875047286354, 0.9873376687367758, 0.47641971707344055, 0.45466703176498413]
20120525 : 20130530
[0.5708633462587992, 0.6553651293118795, 0.7170006036758423, 0.28112663825352985, 0.8910457094510397, 0.0848972921570142, 0.9739351073900858, 0.7358709971110026, 0.6764885187149048, 0.24324760834376016, 0.8477358420689901, 0.031711009020606674, 0.9562129179636637, 0.5052153070767721, 0.4617990553379059]
20130530 : 20140605
[0.5766666332880656, 0.6566015680631002, 0.6961189905802408, 0.29123929142951965, 0.8662963509559631, 0.08678068965673447, 0.9734111825625101, 0.7191708286603292, 0.6812556187311808, 0.2289514938990275, 0.8459783395131429, 0.029651952907443047, 0.9576589266459147, 0.494873841603597, 0.46679147084554035]
20140605 : 20150603
[0.5726994872093201, 0.6853102246920267, 0.7366012533505758, 0.3206663131713867, 0.8904303908348083, 0.09374689559141795, 0.9706396659215292, 0.682244082291921, 0.6252208948135376, 0.20458133021990457, 0.8254867792129517, 0.0190351443986098, 0.9315588275591532, 0.5649460554122925, 0.5255416035652161]
20150603 : 20160719
[0.46518025795618695, 0.7420702576637268, 0.7916798988978068, 0.543129583199819, 0.9014248053232828, 0.33435707290967304, 0.9645818471908569, 0.7944092949231466, 0.7452153960863749, 0.4977227846781413, 0.8516929944356283, 0.1906640032927195, 0.9499153892199198, 0.5131496985753378, 0.48087724049886066]
20160719 : 20170720
[0.5732301076253256, 0.5929640928904215, 0.7088420589764913, 0.2925766507784526, 0.8704588413238525, 0.11173826456069946, 0.9588313500086466, 0.7756170431772867, 0.6740103562672933, 0.30973106622695923, 0.8179851770401001, 0.04629548266530037, 0.942859947681427, 0.47953999042510986, 0.40108763178189594]
20170720 : 20180719
[0.6642247041066488, 0.5296896497408549, 0.5801922480265299, 0.18797219296296439, 0.7028467853864034, 0.04001099616289139, 0.8669108748435974, 0.6634118358294169, 0.6162890593210856, 0.1820149372021357, 0.7172070145606995, 0.017918592939774197, 0.9570804635683695, 0.46759383877118427, 0.4268520971139272]
20180719 : 20190722
[0.7431689103444418, 0.48683969179789227, 0.5005723039309183, 0.11705733835697174, 0.507864753405253, 0.00803534360602498, 0.5076609253883362, 0.5368337829907736, 0.5231437285741171, 0.08796707292397817, 0.5264059901237488, 0.0013925166955838602, 0.600000003973643, 0.488187571366628, 0.4747258722782135]

在这里插入图片描述

2016-2019年,按月衡量

In [31]: evaluate_total_time(model, date_step=20, steps=3, start_date=20160104, end_date='', lookback=61, delay=1, uprate=0.0)
20160104 : 20160201
[0.3300749659538269, 0.8333790898323059, 0.8538292646408081, 0.7283818125724792, 0.9206656614939371, 0.5122989416122437, 0.9672505259513855, 0.879724125067393, 0.8622852166493734, 0.7320470015207926, 0.9174177447954813, 0.47305670380592346, 0.9752546151479086, 0.4574306805928548, 0.4464651842912038]
20160201 : 20160331
[0.4164064625898997, 0.7700598041216532, 0.8489522139231364, 0.6232134501139323, 0.9363580544789633, 0.46540990471839905, 0.9835724830627441, 0.833103617032369, 0.7484932343165079, 0.5455527305603027, 0.8428746660550436, 0.17206532756487528, 0.9406526684761047, 0.5491664409637451, 0.4981724222501119]
20160331 : 20160429
[0.5099559426307678, 0.680933674176534, 0.7686257163683573, 0.4594770272572835, 0.8953756093978882, 0.25383887191613513, 0.9734796682993571, 0.8032093644142151, 0.7237701018651327, 0.4142257372538249, 0.8234434127807617, 0.08621150255203247, 0.9528759717941284, 0.4898814260959625, 0.4338949918746948]
20160429 : 20160530
[0.5221199989318848, 0.6397658785184225, 0.736701230208079, 0.38599979877471924, 0.8892502983411154, 0.21077392001946768, 0.9650895595550537, 0.8060749967892965, 0.7252204418182373, 0.4048948486646016, 0.842205802599589, 0.13769917686780295, 0.9546858469645182, 0.45885709921518963, 0.39850228031476337]
20160530 : 20160629
[0.5209031701087952, 0.7052847544352213, 0.7694478432337443, 0.4261919856071472, 0.920309861501058, 0.23533829549948374, 0.9804604848225912, 0.7425735394159952, 0.6740126411120096, 0.2977428336938222, 0.8266976277033488, 0.08406675358613332, 0.9445225795110067, 0.5492556095123291, 0.5034323036670685]
20160629 : 20160727
[0.577679435412089, 0.6042988896369934, 0.7121192812919617, 0.2804385224978129, 0.8790800968805949, 0.11709364255269368, 0.9579868714014689, 0.7689986030260721, 0.6724971532821655, 0.25034789244333905, 0.8153106768925985, 0.033337254698077835, 0.9791463017463684, 0.48622627059618634, 0.41258804003397626]
20160727 : 20160824
[0.554732064406077, 0.6512430508931478, 0.7526447375615438, 0.3729069133599599, 0.8827725251515707, 0.18998457491397858, 0.9538133343060812, 0.7602296868960062, 0.6606735587120056, 0.33815370003382367, 0.7928603291511536, 0.04158638541897138, 0.9380980332692465, 0.5283052325248718, 0.4571632345517476]
20160824 : 20160923
[0.580666204293569, 0.6040971676508585, 0.6967231233914694, 0.2599690556526184, 0.8593732317288717, 0.07136169075965881, 0.9675402045249939, 0.7572202086448669, 0.6743767460187277, 0.30664053559303284, 0.8178765575091044, 0.03992802401383718, 0.978074312210083, 0.4800748825073242, 0.4162432054678599]
20160923 : 20161028
[0.5732341408729553, 0.586691419283549, 0.716315766175588, 0.2835877339045207, 0.8973076740900675, 0.11739646891752879, 0.9615386724472046, 0.7705871065457662, 0.6538620789845785, 0.27918680508931476, 0.8132069309552511, 0.0357852429151535, 0.9409313201904297, 0.4967460036277771, 0.4068824152151744]
20161028 : 20161125
[0.6074507435162863, 0.5327289700508118, 0.7111793756484985, 0.2069567491610845, 0.8885485728581747, 0.07067673156658809, 0.976579487323761, 0.7743290662765503, 0.6136269966761271, 0.21285533905029297, 0.7852751612663269, 0.010242866973082224, 0.9273664951324463, 0.5106534560521444, 0.382544348637263]
20161125 : 20161223
[0.5683644811312357, 0.5833013852437338, 0.6826119621594747, 0.27338143189748126, 0.864722470442454, 0.10352096954981486, 0.9643404086430868, 0.7853506604830424, 0.7042409181594849, 0.33368319272994995, 0.8386533657709757, 0.04345869769652685, 0.9276942610740662, 0.4418293635050456, 0.3775519331296285]
20161223 : 20170123
[0.5523642102877299, 0.6215600172678629, 0.7359422047932943, 0.34820979833602905, 0.8872754772504171, 0.11651075383027394, 0.9595034718513489, 0.797012209892273, 0.6985656420389811, 0.37634897232055664, 0.8233867685000101, 0.0510556697845459, 0.9010580778121948, 0.4760631223519643, 0.40224658449490863]
20170123 : 20170227
[0.5793325304985046, 0.6538892587025961, 0.7454086343447367, 0.34014496207237244, 0.8570442199707031, 0.10472188889980316, 0.9422970414161682, 0.7396511435508728, 0.6470246315002441, 0.2873672346274058, 0.7673331697781881, 0.011376170751949152, 0.9003527363141378, 0.5382009545962015, 0.47205134232838947]
20170227 : 20170327
[0.6039857069651285, 0.5411506295204163, 0.6933780908584595, 0.22313153743743896, 0.8510897159576416, 0.07072568933169048, 0.9354835549990336, 0.7844724853833517, 0.6548874179522196, 0.27066503961881, 0.783754567305247, 0.007463090121746063, 0.8140350977579752, 0.47392351428667706, 0.36988498767217]
20170327 : 20170426
[0.5763227144877116, 0.4942589004834493, 0.6415992180506388, 0.1934088667233785, 0.8342723250389099, 0.05644249667723974, 0.9770851532618204, 0.8163425525029501, 0.7080773909886678, 0.3732957144578298, 0.842408299446106, 0.07203817615906398, 0.9345154364903768, 0.39957208434740704, 0.3077471653620402]
20170426 : 20170525
[0.575218121210734, 0.5653707981109619, 0.6791580518086752, 0.25602561235427856, 0.8325321674346924, 0.06646337856849034, 0.9657052159309387, 0.7783886591593424, 0.6833484768867493, 0.35639803608258563, 0.8446292281150818, 0.06118106345335642, 0.9591080546379089, 0.45350806911786395, 0.3773736258347829]
20170525 : 20170626
[0.5323339104652405, 0.6749416589736938, 0.7707645098368326, 0.40354963143666583, 0.906677226225535, 0.19083259999752045, 0.973891019821167, 0.7609143455823263, 0.6626681685447693, 0.3228462835152944, 0.8150670131047567, 0.07761002580324809, 0.9612560669581095, 0.5436391234397888, 0.4759739637374878]
20170626 : 20170724
[0.5840664903322855, 0.6040328939755758, 0.6892185807228088, 0.2734345992406209, 0.8448772033055624, 0.07662152623136838, 0.9442191123962402, 0.7424192031224569, 0.664755642414093, 0.2839577893416087, 0.8331067562103271, 0.05305617799361547, 0.9661096334457397, 0.48604796330134076, 0.42596060037612915]
20170724 : 20170821
[0.5749452710151672, 0.6443217992782593, 0.7131819923718771, 0.3207562466462453, 0.8860552906990051, 0.10838656624158223, 0.9762597680091858, 0.7355836232503256, 0.6696333686510721, 0.24367683132489523, 0.81400199731191, 0.02395140565931797, 0.9423990249633789, 0.5050370097160339, 0.4562717278798421]
20170821 : 20170918
[0.614362915356954, 0.5368500749270121, 0.6896100242932638, 0.20729578534762064, 0.8752237359682719, 0.053821162631114326, 0.9794501264890035, 0.7633321285247803, 0.6270307302474976, 0.1895776391029358, 0.7742082476615906, 0.003888158050055305, 0.9666666587193807, 0.49505213896433514, 0.38539716601371765]
20170918 : 20171023
[0.5713058908780416, 0.5930356979370117, 0.7139697869618734, 0.2773873209953308, 0.8864696621894836, 0.07482695579528809, 0.9767130414644877, 0.7787680427233378, 0.6725957989692688, 0.296446959177653, 0.8383763829867045, 0.0320691696057717, 0.9491950869560242, 0.4822144905726115, 0.4004635810852051]
20171023 : 20171120
[0.5949939688046774, 0.5280914306640625, 0.658048411210378, 0.17660345137119293, 0.8440783818562826, 0.02302190288901329, 0.9464573264122009, 0.7902607520421346, 0.6865093111991882, 0.27049265305201214, 0.8357558449109396, 0.035086605697870255, 0.9732725222905477, 0.4334492286046346, 0.3477756977081299]
20171120 : 20171218
[0.5636487801869711, 0.6233387589454651, 0.6986027161280314, 0.31632526715596515, 0.8656089504559835, 0.11647027482589085, 0.9493562976519266, 0.7688710689544678, 0.7036671241124471, 0.32529234886169434, 0.8361475268999735, 0.06481341272592545, 0.9492471814155579, 0.46224479873975116, 0.4124097327391307]
20171218 : 20180116
[0.6507112582524618, 0.5340342919031779, 0.6045805811882019, 0.20625622073809305, 0.7602420051892599, 0.039490206787983574, 0.9187212189038595, 0.6878354748090109, 0.6228431661923727, 0.17598187426726022, 0.71885613600413, 0.014866196550428867, 0.8887022137641907, 0.47196219364802044, 0.4168672462304433]
20180116 : 20180213
[0.7278844118118286, 0.4599837561448415, 0.4602884848912557, 0.1162630170583725, 0.5063264071941376, 0.008183811946461598, 0.502400149901708, 0.5823865334192911, 0.5821098883946737, 0.09899737934271495, 0.5990180373191833, 0.002371248362275461, 0.7354497412840525, 0.4363911946614583, 0.4361237386862437]
20180213 : 20180320
[0.7609819769859314, 0.40759652853012085, 0.5561315218607584, 0.07812982300917308, 0.5676490664482117, 0.0039500615481908126, 0.6099715133508047, 0.5780378381411234, 0.429262638092041, 0.0909203365445137, 0.3707800308863322, 0.0004082465699563424, 0.07407407462596893, 0.5647677580515543, 0.413836141427358]
20180320 : 20180419
[0.7320980230967203, 0.4390726884206136, 0.5026789903640747, 0.1000448614358902, 0.5208313961823782, 0.006649315978089969, 0.5265359580516815, 0.5986337661743164, 0.5359570980072021, 0.12316566954056422, 0.5758434136708578, 0.002742952046295007, 0.5703703860441843, 0.4802531798680623, 0.41945261756579083]
20180419 : 20180521
[0.7409135103225708, 0.47827096780141193, 0.5364595651626587, 0.1279920091231664, 0.5895991921424866, 0.01170201258112987, 0.6014203230539957, 0.5556497176488241, 0.49759358167648315, 0.078870490193367, 0.44944079717000324, 0.0011092253068151574, 0.3682539810736974, 0.5182312726974487, 0.46188820401827496]
20180521 : 20180619
[0.7485045591990153, 0.42101940512657166, 0.3634924689928691, 0.10312127818663915, 0.3786735236644745, 0.008763244065145651, 0.4955555597941081, 0.5734093983968099, 0.6313724716504415, 0.11546564350525539, 0.6275175015131632, 0.0019732690804327526, 0.5111111303170522, 0.3663189808527629, 0.4246233403682709]
20180619 : 20180717
[0.7536139289538065, 0.4777560234069824, 0.5488387942314148, 0.14947044352690378, 0.5735464890797933, 0.015432522632181644, 0.6164737145105997, 0.5742554068565369, 0.5035801033178965, 0.14320466915766397, 0.48119376103083294, 0.003156732146938642, 0.42606837550799054, 0.5201034148534139, 0.452794869740804]
20180717 : 20180814
[0.7543952465057373, 0.4393500288327535, 0.4530588189760844, 0.10608174155155818, 0.46461119254430133, 0.006722171790897846, 0.4158549904823303, 0.559668223063151, 0.5459282199541727, 0.11492372552553813, 0.5481685002644857, 0.0016486222545305889, 0.4662698457638423, 0.45359720786412555, 0.4397789041201274]
20180814 : 20180911
[0.7457842429478964, 0.45029595494270325, 0.4246404270331065, 0.12239157408475876, 0.45915862917900085, 0.011601551125446955, 0.4523486892382304, 0.5808447599411011, 0.605999767780304, 0.12287912021080653, 0.5751838684082031, 0.002706772298552096, 0.6160714328289032, 0.4072390099366506, 0.43184452255566913]
20180911 : 20181017
[0.7386055986086527, 0.4584916631380717, 0.48098509510358173, 0.10596313327550888, 0.48386502265930176, 0.009336340085913738, 0.5197132527828217, 0.5827573935190836, 0.560651163260142, 0.1124221682548523, 0.5634302099545797, 0.0013082946922319632, 0.5714285870393118, 0.4575198292732239, 0.4361237386862437]
20181017 : 20181114
[0.7217960158983866, 0.48923853039741516, 0.6020886500676473, 0.12335498879353206, 0.6565253138542175, 0.0061720275940994425, 0.5736111203829447, 0.5822777350743612, 0.46872907876968384, 0.12075733641783397, 0.5177871783574423, 0.0010121881496161222, 0.31111112236976624, 0.5637871225674947, 0.45805474122365314]
20181114 : 20181212
[0.7150911688804626, 0.45989829301834106, 0.5372519493103027, 0.09378999720017116, 0.5626257260640463, 0.004192532428229849, 0.47306398550669354, 0.6187326510747274, 0.5434430440266927, 0.09837564080953598, 0.5888131856918335, 0.0007019117280530432, 0.6666666368643442, 0.4904163380463918, 0.41980921228726703]
20181212 : 20190111
[0.7351961731910706, 0.5297191540400187, 0.48315619428952533, 0.1464984118938446, 0.5075726807117462, 0.014215043745934963, 0.6187636852264404, 0.5218847791353861, 0.5680830478668213, 0.07329785575469334, 0.5738670229911804, 0.001151621089472125, 0.36666667461395264, 0.4576089878877004, 0.5017384390036265]
20190111 : 20190215
[0.7424782117207845, 0.46431917945543927, 0.49375678102175397, 0.08034547666708629, 0.47481729586919147, 0.005360288700709741, 0.5629458427429199, 0.5253990491231283, 0.4959198832511902, 0.05608691523472468, 0.46760066350301105, 0.0001793078651341299, 0.08333333333333333, 0.49924222628275555, 0.4694659908612569]
20190215 : 20190315
[0.7225977381070455, 0.4271313150723775, 0.645826001962026, 0.04964143534501394, 0.7091577053070068, 0.002311936734865109, 0.5704665879408518, 0.6215667525927225, 0.40172789494196576, 0.04497983058293661, 0.3651146988073985, 0.0002305209830713769, 0.3333332935969035, 0.6178122361501058, 0.40848712126413983]
20190315 : 20190415
[0.7330876191457113, 0.5333328445752462, 0.5149634480476379, 0.10322515666484833, 0.5073318779468536, 0.007539146890242894, 0.5527859230836233, 0.5027465720971426, 0.5211473703384399, 0.0569533904393514, 0.5033248861630758, 0.0021315155706057944, 0.614814817905426, 0.4974592129389445, 0.515200138092041]
20190415 : 20190516
[0.7772349715232849, 0.5582201282183329, 0.465119868516922, 0.18224313855171204, 0.48673463861147565, 0.02209085536499818, 0.5040304362773895, 0.4351457456747691, 0.528003474076589, 0.07193617274363835, 0.5164011915524801, 0.0032036421665300927, 0.48055557409922284, 0.4679504334926605, 0.5614691972732544]
20190516 : 20190614
[0.770625114440918, 0.49911148349444073, 0.4433234731356303, 0.15570829312006632, 0.4420177141825358, 0.020726947113871574, 0.46438642342885333, 0.5272402366002401, 0.5824871857961019, 0.11702437698841095, 0.5896152456601461, 0.0018686645586664479, 0.28733766575654346, 0.4299723605314891, 0.48399750391642254]
20190614 : 20190712
[0.7580109437306722, 0.5463989575703939, 0.5106613536675771, 0.16729296743869781, 0.5230797231197357, 0.018630903214216232, 0.4902886251608531, 0.4822925428549449, 0.5181306600570679, 0.07482960323492686, 0.4963911871115367, 0.0023025821428745985, 0.5833333432674408, 0.4971026082833608, 0.5320495764414469]
20190712 : 20190809
[0.7350301941235861, 0.4402405619621277, 0.4303315778573354, 0.09702148040135701, 0.4221001962820689, 0.007121649881203969, 0.4404761989911397, 0.5892868041992188, 0.5989782015482584, 0.11840027074019115, 0.6435391902923584, 0.002883524284698069, 0.6515151659647623, 0.4134795367717743, 0.42292948563893634]
20190809 : 20190906
[0.7477307319641113, 0.36333195368448895, 0.5445758104324341, 0.06072922423481941, 0.5755683382352194, 0.003100892761722207, 0.5476190646489462, 0.631607711315155, 0.4500224788983663, 0.10058060536781947, 0.44978277881940204, 0.0011707139977564414, 0.48888889451821643, 0.5480074882507324, 0.36560577154159546]
20190906 : 20191014
[0.7384669780731201, 0.47571444511413574, 0.5011145075162252, 0.10083309312661488, 0.5169278581937155, 0.00750604597851634, 0.49126983682314557, 0.5304775635401408, 0.5050997932751974, 0.05927569419145584, 0.4959230919679006, 0.0007218405759582917, 0.4166666666666667, 0.4978158175945282, 0.4725862542788188]
20191014 : 20191111
[0.7220893700917562, 0.48640816410382587, 0.4252362052599589, 0.11220294733842213, 0.4682639042536418, 0.008166713795314232, 0.5461446841557821, 0.5729021032651266, 0.6318842768669128, 0.08572812626759212, 0.6430438955624899, 0.0017701273318380117, 0.725000003973643, 0.39386645952860516, 0.45038779576619464]
20191111 : 20191209
[0.7353871663411459, 0.42778585354487103, 0.5087565382321676, 0.09316737701495488, 0.5437935789426168, 0.00675405686100324, 0.6495310366153717, 0.6070831418037415, 0.5271624724070231, 0.09810450424750645, 0.5071024199326833, 0.001036028688152631, 0.45000000794728595, 0.48765265941619873, 0.40973522265752155]

在这里插入图片描述

2019年,按天衡量

In [28]: evaluate_total_time(model, date_step=1, steps=3, start_date=20190102, end_date='', lookback=61, delay=1, uprate=0.0)
20190102 : 20190103
[0.705572267373403, 0.16796444356441498, 0.34837595621744794, 0.019603818655014038, 0.2872556348641713, 0.0004671862892185648, 0.3333333333333333, 0.8054861227671305, 0.6101937095324198, 0.20878386000792185, 0.6408450206120809, 0.0030286312103271484, 1.0, 0.38209859530131024, 0.18436301747957864]
20190103 : 20190104
[0.7293160359064738, 0.4662709931532542, 0.9566953976949056, 0.11041913429896037, 0.9694114128748575, 0.005439450653890769, 0.9207017620404562, 0.5960521896680196, 0.05528130133946737, 0.078614491969347, 0.061956015725930534, 0.0, 0.0, 0.9502540628115336, 0.46313631534576416]
20190104 : 20190107
[0.611860195795695, 0.6754318475723267, 0.9231745402018229, 0.21614141762256622, 0.9518770178159078, 0.013156835610667864, 0.9726867278416952, 0.4886184235413869, 0.1419760783513387, 0.0676371989150842, 0.1624330331881841, 0.0, 0.0, 0.9009539087613424, 0.6591780384381613]
20190107 : 20190108
[0.7243722875912985, 0.4978080987930298, 0.4071895082791646, 0.08383707453807195, 0.3930290639400482, 0.0060423092606167, 0.5100233256816864, 0.5485713283220927, 0.6368274291356405, 0.05625045796235403, 0.6789639393488566, 0.001010074425721541, 0.8222221930821737, 0.3837924599647522, 0.4691985348860423]
20190108 : 20190109
[0.7111263871192932, 0.40813730160395306, 0.5202601154645284, 0.03683588653802872, 0.4601799746354421, 0.0007390297250822186, 0.26666667064030963, 0.6501630942026774, 0.5416418711344401, 0.03953922167420387, 0.5126555760701498, 0.0, 0.0, 0.4817687471707662, 0.37790852785110474]
20190109 : 20190110
[0.7426256934801737, 0.5168083707491556, 0.38925134142239887, 0.08652547498544057, 0.3327723840872447, 0.0026520504616200924, 0.35952381292978924, 0.5255582531293234, 0.6502917011578878, 0.05014399935801824, 0.5897094209988912, 0.00042585157401238877, 0.6666666269302368, 0.369082639614741, 0.4901488820711772]
20190110 : 20190111
[0.6785141626993815, 0.602329154809316, 0.7728579839070638, 0.12033901115258534, 0.7923774719238281, 0.004967429519941409, 0.8315789500872294, 0.5077767968177795, 0.31471818685531616, 0.04856595521171888, 0.3893883327643077, 0.0010114980395883322, 0.9999998807907104, 0.7355799078941345, 0.5732370416323344]
20190111 : 20190114
[0.6953628063201904, 0.2894180317719777, 0.28369874755541485, 0.034236966321865715, 0.2994111180305481, 0.0, 0.0, 0.6982697447141012, 0.704126238822937, 0.09876756866772969, 0.6752941211064657, 0.0003779436810873449, 0.9999998807907104, 0.29223500688870746, 0.29811891913414]
20190114 : 20190115
[0.7420263091723124, 0.43638981382052106, 0.822896420955658, 0.05986353134115537, 0.7901981472969055, 0.002437230432406068, 0.8773809472719828, 0.614015797773997, 0.20959109564622244, 0.036879474918047585, 0.14328667024771372, 0.0, 0.0, 0.8043148914972941, 0.4264954924583435]
20190115 : 20190116
[0.732903261979421, 0.509804348150889, 0.36012641588846844, 0.05011440689365069, 0.3243343234062195, 0.0024426247303684554, 0.4212121268113454, 0.482978622118632, 0.6332911849021912, 0.032063632582624756, 0.6195197502772013, 0.00041390730378528434, 0.3333333333333333, 0.36337701479593915, 0.5144869486490885]
20190116 : 20190117
[0.7116859952608744, 0.4611208339532216, 0.20259833335876465, 0.05057622243960699, 0.21190446615219116, 0.0027048024348914623, 0.3777777850627899, 0.5551392634709676, 0.8075803319613138, 0.05120836322506269, 0.7793605923652649, 0.0006700240968105694, 0.6666666666666666, 0.19702237844467163, 0.44798075159390766]
20190117 : 20190118
[0.7438547015190125, 0.4360784391562144, 0.6822526653607687, 0.062473613768815994, 0.6320405205090841, 0.0030737899554272494, 0.8671024044354757, 0.5344203511873881, 0.2924879988034566, 0.03547515037159125, 0.2521167993545532, 0.0, 0.0, 0.6962645848592123, 0.4450387756029765]
20190118 : 20190121
[0.7438400189081827, 0.39680179953575134, 0.6188981930414835, 0.04004719853401184, 0.6233495473861694, 0.0017499179036046069, 0.6500000059604645, 0.6146770914395651, 0.39254732926686603, 0.047346084068218865, 0.29914529124895733, 0.0, 0.0, 0.6119283239046732, 0.39235090216000873]
20190121 : 20190122
[0.6645412643750509, 0.3758576611677806, 0.21247625350952148, 0.06790528694788615, 0.3083601991335551, 0.003910915615657966, 0.55158731341362, 0.6874891320864359, 0.8308535814285278, 0.08534777909517288, 0.8197145859400431, 0.0003274783957749605, 0.9999998807907104, 0.18320405979951224, 0.3241508404413859]
# ***********************************数据过多,省略部分***********************************

在这里插入图片描述

历史预测曲线

2017六月到2018六月.

In [34]: history_predict(model, ts_code='600004.SH', date=20180601, delay=1, during=244, mod='simple')
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
数据分割日: 3518

在这里插入图片描述

总结

多个模型发现均出现过拟合现象(训练准确率可达0.61,以0.9作为baseline准确率甚至高达0.96),经过分析发现过拟合的原因应该是通过对过去的大盘走向进行过拟合,毕竟800,000个参数(上述混合Attention模型)应该不至于拟合深交所和上交所从2004年到2018年的所有数据.
训练期间尝试了进行统一归一化(即所有股票计算出一个mid和std进行归一化),发现并不能提升性能反而会导致训练过程不稳定.
之前做图像识别的时候有很多预训练模型,可以直接拿来用,但是找了下发现没有进行时间序列预测的预训练模型.
训练过程中进行过多种条件的预测,包括:

  • 使用过去一年的数据对一个月后是否上涨10%进行预测(由于数据不平衡,训练曲线波动极大,练不起来),
  • 使用过去一年的数据对一个月后是否上涨进行预测(更容易过拟合),
  • 使用过去四个月的数据对明天是否上涨进行预测(上述例子),
  • 使用过去一个月数据对明天上涨幅度进行预测(卡在baseline练不起来).

正常情况下,模型的阈值(baseline)取得越高,recall越低,准确率越高.
尝试过在Generator出口进行正则化(直接加了个BN层在模型开头),好像能一定程度上提高稳定性.
一开始使用沪股作训练集,港股作测试集,这样无法检测出过拟合现象;后来改以时间作为训练集与测试集的分割标准,才准确地识别出过拟合现象.
最初使用的生成器每次只能生成同一只股票的数据,现在换了个每次生成的数据完全随机的生成器,虽然性能降低了,但是对稳定性应该有帮助.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值