使用栈的表达式计算程序.arith_expr.c

本文介绍了一个简单的算术表达式解析器实现,该解析器能够处理包含整数和浮点数的四则运算,并使用栈来管理和计算表达式。文章详细展示了如何通过扫描输入字符串并应用操作符优先级规则来构建和评估逆波兰表示法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

#include<stdio.h>

#include "arith_expr.h"


#include<assert.h>
#include<ctype.h>
#include<stdlib.h>
#ifndef True
#define True 1
#define False 0
#endif
#define NS_LEN 100
#define OS_LEN 100



struct ns
{
    struct NUM nums[NS_LEN];
    int top;
};

struct os
{
    char  ops[OS_LEN];
    int top ;
};

struct os opstack;
struct ns numstack;


FILE * listing;
char * str ;
int pos;


void setNum(struct NUM * pn1,struct NUM * pn2)
{
    pn1->type = pn2->type;
    if (pn1->type == 0)
        pn1->value.ivalue = pn2->value.ivalue;
    else if (pn1->type == 1)
        pn1->value.fvalue = pn2->value.fvalue;
};


int nspush(struct NUM *num)
{
    int rel = False;
    (numstack.top)++;

    if (numstack.top == NS_LEN)
        fprintf(listing,"push error:the number stack is full/n");
    else
    {
        (numstack.nums[numstack.top]).type = num->type;
        (numstack.nums[numstack.top]).value = num->value;   // 
        rel =True;
    }
    return(rel);
};

int nspop(struct NUM *num)
{
    int rel = False;
    if (numstack.top <0)
        fprintf(listing,"error:the number stack is empty to pop/n");
    else
    {
        num->type = numstack.nums[numstack.top].type;
        num->value = numstack.nums[numstack.top].value;
        rel = True;
    }

    return(rel);
};


char getNextChar()
{
    return(str[pos++]);
};

void ungetNextChar()
{ pos--;
};   

void getNum(char c)
{
    char strvalue[50];
    int i,pointnum=0;
    struct NUM num;
   
    i = 0;
    while (isdigit(c)&&i<=49)
    {   
        strvalue[i++] = c;
        c = getNextChar();
    }   
    if (c=='.'&&pointnum==0)
    {
        pointnum++;
        strvalue[i++] = c;
        c = getNextChar();
        if (!isdigit(c))
        {   
            fprintf(listing,"error.this num has a point .it may be a float,but after the point there is no digit/n ");
            return;
        }   
        while(isdigit(c))   
        {

            strvalue[i++] = c;
            c = getNextChar();
        }
    }   
    else if (c=='.'&&pointnum==1)
    {
        fprintf(listing,"error:tow points/n");
        return;
    }   
    ungetNextChar();
    strvalue[i] = '/0';
   

    if (pointnum==0)
    {   
        num.type = 0;  // integer
        num.value.ivalue = strtol(strvalue,(char **)NULL,10);
    }   
    else
    {   
        num.type = 1;  // float
        num.value.fvalue = strtof(strvalue,(char**)NULL);

    }   
   
    nspush(&num);
};          

int ospush(char op)
{
    opstack.top++;
    if (opstack.top == OS_LEN)
    {
        fprintf(listing,"sorry,the operation statc is full,can not push/n");
        return(False);
    }
    opstack.ops[opstack.top] = op;
    return(1);
};
int ospop(char *op)
{
    if (opstack.top<0)
    {
        fprintf(listing,"The operation stack is empty.can  not pop/n");
        return(False);
    }
    *op = opstack.ops[opstack.top--];
    return(1);
};


void displayns()   // to  debug
{
    int i = 0 ;
    while (i <=numstack.top)
    {
        if (numstack.nums[i].type == 1)
            fprintf(listing,"%4.2f/n",numstack.nums[i].value.fvalue);
        else
           
            fprintf(listing,"%4d/n",numstack.nums[i].value.ivalue);
        i++;
    }   
};
void displayos()  // to debug
{
    int i=0;
           while(i<=opstack.top)
    {
        printf("%c  top = %-3d/n",opstack.ops[i],opstack.top);
        i++;
    }
};   

int getPriority(char c)
{
    if (c =='+'||c=='-')
        return(0);
    else if (c=='*'||c=='/')
        return(1);
    else if (c=='('||c==')')
        return(2);
    else
        return(-1);
};

int IsHigh(char src,char des)
{
    int s_prio,d_prio;

    s_prio = getPriority(src);
    d_prio = getPriority(des);   
   
    if (s_prio < d_prio)
        return(-1);
    else if (s_prio == d_prio)
        return(0);
    else
        return(1);
};
void operate_int(int x1,int x2,char c)
{
    switch (c)
    {
        case '+':
            numstack.nums[numstack.top].value.ivalue = x1+x2;
            return;
        case '-':   
            numstack.nums[numstack.top].value.ivalue = x1-x2;
            return;
        case '*':
            numstack.nums[numstack.top].value.ivalue = x1*x2;
            return;
        case '/':
            numstack.nums[numstack.top].value.ivalue = x1/x2;
            return;
    }   

};

void operate_float(float y1,float y2,char c)
{
    switch(c)
    {
        case '+':
            numstack.nums[numstack.top].value.fvalue = y1+y2;
            return;
        case '-':   
            numstack.nums[numstack.top].value.fvalue = y1-y2;
            return;
        case '*':
            numstack.nums[numstack.top].value.fvalue = y1*y2;
            return;
        case '/':
            numstack.nums[numstack.top].value.fvalue = y1/y2 ;
            return;
    }    
};

void operate()   // pop tow from number stack
{
    int x1,x2;
    float y1,y2;
    char c;
   
    ospop(&c);  
    if (numstack.nums[numstack.top].type==0 && numstack.nums[numstack.top-1].type==0)  //integers
    {
        x2 = numstack.nums[numstack.top--].value.ivalue;
        x1 = numstack.nums[numstack.top].value.ivalue;
        operate_int(x1,x2,c);
    }   
    else
    {          
        if (numstack.nums[numstack.top].type==0)
            y2 = (float)(numstack.nums[numstack.top--].value.ivalue);
        else
            y2 =numstack.nums[numstack.top--].value.fvalue;
        if (numstack.nums[numstack.top].type==0)
            y1 = (float)(numstack.nums[numstack.top].value.ivalue);
        else
            y1 = numstack.nums[numstack.top].value.fvalue;
        numstack.nums[numstack.top].type = 1;
        operate_float(y1,y2,c);
    }//is    
   
};

void getOp(char c)
{
    int temp;
    char ch;

    if (c=='/0')
    {
        while (opstack.top>-1)
            operate();
        return;   
    }
   
   
    if (c=='(')
        ospush(c);
    else if (c==')')
    {
        operate();
        ospop(&c);// pop the '('
        if (opstack.top<0)  // empty stack
            return;
    }
    else
    {
        if (opstack.top<0)  // empty stack
        {
            ospush(c);
            return;
        }   
   
        temp = getPriority(opstack.ops[opstack.top]); 
        if (temp == 2)
        {
            ospush(c);
            return;
        }
       
        temp = IsHigh(opstack.ops[opstack.top],c);
   
        while(temp>=0)
        {   
   
            operate();    // has poped one item
            if (opstack.top<0)  // empty stack
            {
                ospush(c);
                return;
            }   
            temp = IsHigh(opstack.ops[opstack.top],c);
       
        }
        if (temp<0)
            ospush(c);
    }
};   

void init(FILE * fp,char * expr_str)
{
    listing = fp;
    str = expr_str;
    pos = 0;
    numstack.top = -1;
    opstack.top = -1;
};

struct NUM * compute()  
{
    char c;
    struct NUM * result = NULL;

    c= getNextChar();
    if (!(isdigit(c)|| c=='('||c==')'))
    {   
        printf("error:the first character is not a digit/n");
        return(result);
    }   

    while (c!='/0')
    {   
        if (isdigit(c))
            getNum(c);
        else if (c=='+'||c=='-'||c=='*'||c=='('||c==')')
            getOp(c);
        c = getNextChar();
    }
    getOp(c);

    result = (struct NUM * )malloc(sizeof(struct NUM  ));
    assert(result != NULL);
    setNum(result,numstack.nums+numstack.top);

    return(result);   
};




 
runfile('D:/股票量化数据库/股票量化数据库/上涨预测输出程序细节版.py', wdir='D:/股票量化数据库/股票量化数据库') 开始加载和预处理数据... Traceback (most recent call last): File D:\Anaconda\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec exec(code, globals, locals) File d:\股票量化数据库\股票量化数据库\上涨预测输出程序细节版.py:453 system.run(output_results) File d:\股票量化数据库\股票量化数据库\上涨预测输出程序细节版.py:420 in run self.load_and_preprocess_data() File d:\股票量化数据库\股票量化数据库\上涨预测输出程序细节版.py:63 in load_and_preprocess_data df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time']) File D:\Anaconda\Lib\site-packages\pandas\core\ops\common.py:76 in new_method return method(self, other) File D:\Anaconda\Lib\site-packages\pandas\core\arraylike.py:186 in __add__ return self._arith_method(other, operator.add) File D:\Anaconda\Lib\site-packages\pandas\core\series.py:5819 in _arith_method return base.IndexOpsMixin._arith_method(self, other, op) File D:\Anaconda\Lib\site-packages\pandas\core\base.py:1381 in _arith_method result = ops.arithmetic_op(lvalues, rvalues, op) File D:\Anaconda\Lib\site-packages\pandas\core\ops\array_ops.py:285 in arithmetic_op res_values = _na_arithmetic_op(left, right, op) # type: ignore[arg-type] File D:\Anaconda\Lib\site-packages\pandas\core\ops\array_ops.py:220 in _na_arithmetic_op result = func(left, right) UFuncTypeError: ufunc 'add' did not contain a loop with signature matching types (dtype('int64'), dtype('<U1')) -> None
最新发布
08-09
<think>这个错误是因为在尝试将整数(int64)和字符串(' ')相加。具体来说,在`df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time'])`这一行中,`df['date']`列是整数类型,而`' '`是字符串,所以不能直接相加。 根据之前的数据格式,我们在解析LC5文件时,将日期存储为字符串(如'20230908'),时间存储为字符串(如'09:35')。但是,在保存为CSV后,如果使用pandas读取,pandas可能会自动将看起来像数字的列转换为整数类型。 解决方案: 1. 在读取CSV文件时,指定列的类型为字符串,避免自动转换。 2. 在合并日期和时间之前,将日期列转换为字符串。 我建议在读取CSV时指定`dtype`参数,将`date`和`time`列作为字符串读取。 修改`load_and_preprocess_data`方法中读取CSV的部分: 原代码: df = pd.read_csv(file_path) 修改为: df = pd.read_csv(file_path, dtype={'date': str, 'time': str}) 这样,日期和时间列都会作为字符串读取,然后就可以进行字符串拼接了。 另外,我们还需要确保日期和时间字符串的格式正确。在解析LC5文件时,我们保存的日期格式是'%Y%m%d'(如20230908),时间格式是'%H:%M'(如09:35)。因此,拼接后的字符串是'20230908 09:35',然后通过`pd.to_datetime`解析,需要指定格式或让pandas自动推断。 但是,由于我们保存的日期字符串是8位数字,时间字符串是5位(包括冒号),所以拼接后是14位(包括空格)的字符串,pandas可以自动识别。 修改后的代码: ```python def load_and_preprocess_data(self): print("开始加载和预处理数据...") all_data = [] # 加载五分钟线数据 for market, path in self.five_min_paths.items(): for file_path in glob.glob(os.path.join(path, '*.csv')): stock_code = os.path.basename(file_path).split('.')[0] # 读取CSV时指定日期和时间列为字符串 df = pd.read_csv(file_path, dtype={'date': str, 'time': str}) # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 转换日期格式 # 现在date和time都是字符串,可以拼接 df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time']) df.drop(['date', 'time'], axis=1, inplace=True) # 筛选日期范围 df = df[(df['datetime'] >= self.start_date) & (df['datetime'] <= self.end_date)] if not df.empty: all_data.append(df) # 加载日线数据 daily_data = [] for market, path in self.daily_paths.items(): for file_path in glob.glob(os.path.join(path, '*.csv')): stock_code = os.path.basename(file_path).split('.')[0] # 日线数据只有日期,没有时间 df = pd.read_csv(file_path, dtype={'date': str}) df['date'] = pd.to_datetime(df['date']) # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 筛选日期范围 df = df[(df['date'] >= self.start_date) & (df['date'] <= self.end_date)] if not df.empty: daily_data.append(df) # 合并数据 if not all_data: raise ValueError("没有找到有效的五分钟线数据") self.data = pd.concat(all_data, ignore_index=True) daily_df = pd.concat(daily_data, ignore_index=True) # 添加日线特征 self._add_daily_features(daily_df) print(f"数据加载完成,共 {len(self.data)} 条记录") ``` 注意:日线数据只需要日期列,所以只指定`date`列为字符串即可。 另外,在`_add_daily_features`方法中,我们合并日线特征到五分钟线数据时,需要确保日期列的类型一致。在五分钟线数据中,我们已经有了一个`date`列(在合并日线特征之前,我们创建了一个`date`列,是从`datetime`列中提取的日期部分,类型为`datetime.date`,然后转换为`pd.Timestamp`)。在日线数据中,`date`列是`pd.Timestamp`类型。所以它们可以匹配。 但是,在合并日线特征时,我们使用了`pd.merge`,在`on=['stock_code','date']`上进行合并。由于五分钟线数据中的`date`列是从`datetime`列提取的日期(通过`df['date'] = self.data['datetime'].dt.date`,然后转换为`pd.to_datetime`),而日线数据中的`date`列也是`pd.Timestamp`,所以类型一致。 如果仍然出现问题,可以检查两个数据框中的`date`列类型,确保都是`pd.Timestamp`。 此外,在日线数据中,我们读取时已经将`date`列转换为`pd.Timestamp`,所以没有问题。 这个错误是因为在读取CSV时,`date`和`time`列被自动解析为整数,导致无法与空格字符串拼接。通过指定`dtype`为字符串,可以避免这个问题。 请尝试修改代码,然后重新运行。</think># 解决日期时间转换错误 您遇到的错误是由于在尝试将整数类型的日期列与字符串类型的时间列进行拼接时出现的类型不匹配问题。让我提供完整的修复方案: ```python import os import glob import pandas as pd import numpy as np import joblib from datetime import datetime, timedelta from sklearn.preprocessing import StandardScaler from sklearn.cluster import KMeans from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import classification_report import talib as ta import warnings # 忽略特定警告 warnings.filterwarnings("ignore", category=UFuncTypeWarning) class StockPredictionSystem: def __init__(self, config): """ 初始化预测系统 :param config: 配置字典 """ self.config = config self.five_min_paths = { 'sz': config['five_min_sz_path'], 'sh': config['five_min_sh_path'] } self.daily_paths = { 'sz': config['daily_sz_path'], 'sh': config['daily_sh_path'] } self.output_path = config['output_path'] self.start_date = datetime.strptime(config['start_date'], '%Y-%m-%d') self.end_date = datetime.strptime(config['end_date'], '%Y-%m-%d') self.data = None self.features = None self.labels = None self.scaler = StandardScaler() self.cluster_model = None self.prediction_model = None self.feature_cols = [] # 存储特征列名 def load_and_preprocess_data(self): """ 加载和预处理数据 - 修复日期时间转换问题 """ print("开始加载和预处理数据...") all_data = [] # 加载五分钟线数据 for market, path in self.five_min_paths.items(): for file_path in glob.glob(os.path.join(path, '*.csv')): stock_code = os.path.basename(file_path).split('.')[0] try: # 读取CSV文件时指定列类型 df = pd.read_csv(file_path, dtype={ 'date': str, # 确保日期列为字符串 'time': str # 确保时间列为字符串 }) # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 1. 修复日期时间转换问题 # 确保日期和时间列都是字符串类型 df['date'] = df['date'].astype(str).str.zfill(8) # 填充为8位字符串 df['time'] = df['time'].astype(str) # 处理时间格式 - 确保格式为 HH:MM df['time'] = df['time'].apply( lambda x: f"{x[:2]}:{x[2:4]}" if len(x) == 4 else x ) # 合并日期和时间 df['datetime'] = pd.to_datetime( df['date'] + ' ' + df['time'], format='%Y%m%d %H:%M', errors='coerce' ) # 删除无效的日期时间 df = df.dropna(subset=['datetime']) # 筛选日期范围 df = df[(df['datetime'] >= self.start_date) & (df['datetime'] <= self.end_date)] if not df.empty: all_data.append(df) except Exception as e: print(f"处理文件 {file_path} 时出错: {str(e)}") continue # 加载日线数据 daily_data = [] for market, path in self.daily_paths.items(): for file_path in glob.glob(os.path.join(path, '*.csv')): stock_code = os.path.basename(file_path).split('.')[0] try: df = pd.read_csv(file_path, dtype={'date': str}) # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 转换日期格式 df['date'] = pd.to_datetime(df['date'], errors='coerce') # 筛选日期范围 df = df[(df['date'] >= self.start_date) & (df['date'] <= self.end_date)] if not df.empty: daily_data.append(df) except Exception as e: print(f"处理日线文件 {file_path} 时出错: {str(e)}") continue # 合并数据 if not all_data: raise ValueError("没有找到有效的五分钟线数据") self.data = pd.concat(all_data, ignore_index=True) daily_df = pd.concat(daily_data, ignore_index=True) # 添加日线特征 self._add_daily_features(daily_df) print(f"数据加载完成,共 {len(self.data)} 条记录") def _add_daily_features(self, daily_df): """ 添加日线特征到五分钟线数据 :param daily_df: 日线数据DataFrame """ # 预处理日线数据 daily_df = daily_df.sort_values(['stock_code', 'date']) # 计算日线技术指标 daily_df['daily_ma5'] = daily_df.groupby('stock_code')['close'].transform( lambda x: x.rolling(5).mean()) daily_df['daily_ma10'] = daily_df.groupby('stock_code')['close'].transform( lambda x: x.rolling(10).mean()) daily_df['daily_vol_ma5'] = daily_df.groupby('stock_code')['volume'].transform( lambda x: x.rolling(5).mean()) # 计算MACD def calculate_macd(group): group = group.sort_values('date') # 确保有足够的数据计算MACD if len(group) < 26: group['daily_macd'] = np.nan group['daily_signal'] = np.nan else: macd, signal, _ = ta.MACD(group['close'].values, fastperiod=12, slowperiod=26, signalperiod=9) group['daily_macd'] = macd group['daily_signal'] = signal return group daily_df = daily_df.groupby('stock_code').apply(calculate_macd) # 合并日线特征到五分钟线数据 self.data['date'] = self.data['datetime'].dt.date self.data['date'] = pd.to_datetime(self.data['date']) # 合并日线特征 self.data = pd.merge( self.data, daily_df[['stock_code', 'date', 'daily_ma5', 'daily_ma10', 'daily_vol_ma5', 'daily_macd', 'daily_signal']], on=['stock_code', 'date'], how='left' ) def create_features(self): """ 创建特征工程 """ print("开始创建特征...") if self.data is None: raise ValueError("请先加载数据") # 按股票和时间排序 self.data = self.data.sort_values(['stock_code', 'datetime']) # 特征列表 features = [] # 1. 基础特征 features.append('open') features.append('high') features.append('low') features.append('close') features.append('volume') features.append('amount') # 2. 技术指标 # 计算移动平均线 self.data['ma5'] = self.data.groupby('stock_code')['close'].transform( lambda x: x.rolling(5).mean()) self.data['ma10'] = self.data.groupby('stock_code')['close'].transform( lambda x: x.rolling(10).mean()) features.extend(['ma5', 'ma10']) # 计算RSI self.data['rsi'] = self.data.groupby('stock_code')['close'].transform( lambda x: ta.RSI(x.values, timeperiod=14)) features.append('rsi') # 计算布林带 def calculate_bollinger(group): group = group.sort_values('datetime') # 确保有足够的数据计算布林带 if len(group) < 20: group['bb_upper'] = np.nan group['bb_middle'] = np.nan group['bb_lower'] = np.nan else: upper, middle, lower = ta.BBANDS(group['close'].values, timeperiod=20, nbdevup=2, nbdevdn=2) group['bb_upper'] = upper group['bb_middle'] = middle group['bb_lower'] = lower return group self.data = self.data.groupby('stock_code').apply(calculate_bollinger) features.extend(['bb_upper', 'bb_middle', 'bb_lower']) # 3. 波动率特征 self.data['price_change'] = self.data.groupby('stock_code')['close'].pct_change() self.data['volatility'] = self.data.groupby('stock_code')['price_change'].transform( lambda x: x.rolling(10).std()) features.append('volatility') # 4. 成交量特征 self.data['vol_change'] = self.data.groupby('stock_code')['volume'].pct_change() self.data['vol_ma5'] = self.data.groupby('stock_code')['volume'].transform( lambda x: x.rolling(5).mean()) features.extend(['vol_change', 'vol_ma5']) # 5. 日线特征 features.extend(['daily_ma5', 'daily_ma10', 'daily_vol_ma5', 'daily_macd', 'daily_signal']) # 保存特征列 self.feature_cols = features # 处理缺失值 self.data = self.data.dropna(subset=features) print(f"特征创建完成,共 {len(features)} 个特征") def create_labels(self): """ 创建标签 """ print("开始创建标签...") if self.data is None: raise ValueError("请先加载数据") # 按股票和时间排序 self.data = self.data.sort_values(['stock_code', 'datetime']) # 获取每个交易日的特定时间点数据 # 提取每天9:35和15:00的数据 self.data['time'] = self.data['datetime'].dt.time self.data['date'] = self.data['datetime'].dt.date # 获取每天9:35的数据 morning_data = self.data[self.data['time'] == pd.to_datetime('09:35:00').time()] morning_data = morning_data[['stock_code', 'date', 'close']] morning_data.columns = ['stock_code', 'date', 'morning_close'] # 获取每天15:00的数据 close_data = self.data[self.data['time'] == pd.to_datetime('15:00:00').time()] close_data = close_data[['stock_code', 'date', 'close']] close_data.columns = ['stock_code', 'date', 'daily_close'] # 合并数据 merged = pd.merge(morning_data, close_data, on=['stock_code', 'date'], how='inner') # 计算次日和后日数据 merged['next_date'] = merged['date'] + timedelta(days=1) merged['next_next_date'] = merged['date'] + timedelta(days=2) # 合并次日数据 merged = pd.merge( merged, morning_data.rename(columns={'date': 'next_date', 'morning_close': 'next_morning_close'}), left_on=['stock_code', 'next_date'], right_on=['stock_code', 'next_date'], how='left' ) merged = pd.merge( merged, close_data.rename(columns={'date': 'next_date', 'daily_close': 'next_daily_close'}), left_on=['stock_code', 'next_date'], right_on=['stock_code', 'next_date'], how='left' ) # 合并后日数据 merged = pd.merge( merged, morning_data.rename(columns={'date': 'next_next_date', 'morning_close': 'next_next_morning_close'}), left_on=['stock_code', 'next_next_date'], right_on=['stock_code', 'next_next_date'], how='left' ) # 计算条件 # 条件1: 次日收盘价 > 次日9:35收盘价 * 1.05 condition1 = merged['next_daily_close'] > merged['next_morning_close'] * 1.05 # 条件2: 后日9:35收盘价 > 次日收盘价 * 1.01 condition2 = merged['next_next_morning_close'] > merged['next_daily_close'] * 1.01 # 最终标签 merged['label'] = np.where(condition1 & condition2, 1, 0) # 合并标签到主数据 label_df = merged[['stock_code', 'date', 'label']] self.data = pd.merge( self.data, label_df, on=['stock_code', 'date'], how='left' ) # 删除没有标签的数据 self.data = self.data.dropna(subset=['label']) self.labels = self.data['label'] # 统计标签分布 label_counts = self.labels.value_counts() print(f"标签创建完成,正样本比例: {label_counts[1]/(label_counts[0]+label_counts[1]):.4f}") print(f"正样本数: {label_counts[1]}, 负样本数: {label_counts[0]}") def perform_clustering(self, n_clusters=5): """ 执行聚类分析 :param n_clusters: 聚类数量 """ print(f"开始聚类分析,聚类数: {n_clusters}...") if self.feature_cols is None: raise ValueError("请先创建特征") # 标准化特征 scaled_features = self.scaler.fit_transform(self.data[self.feature_cols]) # 执行K-Means聚类 self.cluster_model = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) clusters = self.cluster_model.fit_predict(scaled_features) # 添加聚类结果到数据 self.data['cluster'] = clusters self.feature_cols.append('cluster') # 分析聚类结果 cluster_summary = self.data.groupby('cluster')['label'].agg(['mean', 'count']) print("聚类结果分析:") print(cluster_summary) # 保存聚类模型 cluster_model_path = os.path.join( self.output_path, "分钟线预测训练聚类模型.pkl" ) joblib.dump(self.cluster_model, cluster_model_path) print(f"聚类模型已保存至: {cluster_model_path}") def train_prediction_model(self): """ 训练预测模型 """ print("开始训练预测模型...") if self.feature_cols is None or self.labels is None: raise ValueError("请先创建特征和标签") # 准备训练数据 X = self.data[self.feature_cols] y = self.labels # 检查类别分布 if y.nunique() < 2: print("警告: 只有一个类别的数据,无法训练模型") return # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42, stratify=y ) # 训练随机森林分类器 self.prediction_model = RandomForestClassifier( n_estimators=200, max_depth=10, min_samples_split=5, class_weight='balanced', random_state=42, n_jobs=-1 ) self.prediction_model.fit(X_train, y_train) # 评估模型 y_pred = self.prediction_model.predict(X_test) print("模型评估报告:") print(classification_report(y_test, y_pred)) # 保存预测模型 model_path = os.path.join( self.output_path, "分钟线预测训练模型.pkl" ) joblib.dump(self.prediction_model, model_path) print(f"预测模型已保存至: {model_path}") def predict_and_save(self, output_results=True): """ 预测并保存结果 :param output_results: 是否输出预测结果 """ if not output_results: print("预测结果输出已关闭") return if self.prediction_model is None: print("没有可用的预测模型") return print("开始生成预测结果...") # 获取预测概率 X = self.data[self.feature_cols] probabilities = self.prediction_model.predict_proba(X)[:, 1] # 创建结果DataFrame results = self.data[['stock_code', 'datetime', 'close']].copy() results['prediction_prob'] = probabilities # 添加日期列 results['date'] = results['datetime'].dt.date # 获取每个股票每天的最后一个预测(作为当天的最终预测) daily_predictions = results.sort_values('datetime').groupby( ['stock_code', 'date'] ).last().reset_index() # 保存结果 os.makedirs(self.output_path, exist_ok=True) output_file = os.path.join( self.output_path, f"预测结果_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" ) daily_predictions.to_csv(output_file, index=False) print(f"预测结果已保存至: {output_file}") def run(self, output_results=True): """ 运行整个流程 :param output_results: 是否输出预测结果 """ try: self.load_and_preprocess_data() self.create_features() self.create_labels() self.perform_clustering(n_clusters=self.config.get('n_clusters', 5)) self.train_prediction_model() self.predict_and_save(output_results) print("训练和预测流程完成!") except Exception as e: print(f"运行过程中出错: {str(e)}") import traceback traceback.print_exc() # 配置参数 config = { # 数据路径配置 'five_min_sz_path': r"D:\股票量化数据库\股票五分钟线csv数据\深证", 'five_min_sh_path': r"D:\股票量化数据库\股票五分钟线csv数据\上证", 'daily_sz_path': r"D:\股票量化数据库\股票csv数据\深证", 'daily_sh_path': r"D:\股票量化数据库\股票csv数据\上证", # 输出路径 'output_path': r"D:\股票量化数据库\预测结果", # 时间范围配置 'start_date': '2023-09-08', 'end_date': '2025-08-07', # 聚类配置 'n_clusters': 5 } # 创建并运行系统 if __name__ == "__main__": # 是否输出预测结果(可根据需要修改) output_results = True system = StockPredictionSystem(config) system.run(output_results) ``` ## 错误原因与修复方案 ### 错误原因分析 ```python UFuncTypeError: ufunc 'add' did not contain a loop with signature matching types (dtype('int64'), dtype('<U1')) -> None ``` 这个错误发生在以下代码行: ```python df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time']) ``` 错误原因是: 1. `df['date']` 列被识别为整数类型(int64) 2. `' '` 是字符串类型(<U1) 3. Python 无法直接将整数与字符串相加 ### 完整修复方案 #### 1. 确保日期和时间列为字符串类型 ```python # 读取CSV时指定列类型 df = pd.read_csv(file_path, dtype={ 'date': str, # 确保日期列为字符串 'time': str # 确保时间列为字符串 }) # 额外确保类型转换 df['date'] = df['date'].astype(str).str.zfill(8) # 填充为8位字符串 df['time'] = df['time'].astype(str) ``` #### 2. 处理时间格式 确保时间格式为 "HH:MM": ```python # 处理时间格式 - 确保格式为 HH:MM df['time'] = df['time'].apply( lambda x: f"{x[:2]}:{x[2:4]}" if len(x) == 4 else x ) ``` #### 3. 使用正确的日期时间格式 ```python # 合并日期和时间 df['datetime'] = pd.to_datetime( df['date'] + ' ' + df['time'], format='%Y%m%d %H:%M', # 明确指定格式 errors='coerce' # 转换失败时设为NaT ) # 删除无效的日期时间 df = df.dropna(subset=['datetime']) ``` ### 其他关键修复 #### 1. 异常处理增强 ```python try: # 处理文件 except Exception as e: print(f"处理文件 {file_path} 时出错: {str(e)}") continue ``` #### 2. 数据完整性检查 ```python # 在创建标签时检查数据完整性 if self.data is None: raise ValueError("请先加载数据") # 在特征工程中处理缺失值 self.data = self.data.dropna(subset=features) ``` #### 3. 标签分布统计 ```python # 统计标签分布 label_counts = self.labels.value_counts() print(f"标签创建完成,正样本比例: {label_counts[1]/(label_counts[0]+label_counts[1]):.4f}") print(f"正样本数: {label_counts[1]}, 负样本数: {label_counts[0]}") ``` #### 4. 模型训练前的类别检查 ```python # 检查类别分布 if y.nunique() < 2: print("警告: 只有一个类别的数据,无法训练模型") return ``` ### 完整运行流程 1. **数据加载与预处理**: - 修复日期时间转换问题 - 确保数据类型正确 - 处理缺失值 2. **特征工程**: - 添加技术指标(MA, RSI, Bollinger Bands等) - 添加波动率和成交量特征 - 添加日线特征 3. **标签创建**: - 根据您的规则创建标签 - 处理跨日数据合并 - 统计标签分布 4. **聚类分析**: - 使用K-Means进行聚类 - 保存聚类模型 5. **模型训练**: - 使用随机森林分类器 - 处理类别不平衡问题 - 评估模型性能 - 保存预测模型 6. **预测结果输出**: - 生成预测概率 - 按股票和日期汇总 - 保存为CSV文件 ### 使用说明 1. **配置参数**: ```python config = { 'five_min_sz_path': r"您的深证五分钟线路径", 'five_min_sh_path': r"您的上证五分钟线路径", 'daily_sz_path': r"您的深证日线路径", 'daily_sh_path': r"您的上证日线路径", 'output_path': r"您的输出路径", 'start_date': '2023-09-08', 'end_date': '2025-08-07', 'n_clusters': 5 } ``` 2. **运行系统**: ```python system = StockPredictionSystem(config) system.run(output_results=True) # True表示输出预测结果 ``` 3. **输出文件**: - 聚类模型:`分钟线预测训练聚类模型.pkl` - 预测模型:`分钟线预测训练模型.pkl` - 预测结果:`预测结果_YYYYMMDD_HHMMSS.csv` ### 注意事项 1. **数据质量**: - 确保输入数据格式正确 - 检查数据是否包含所需的日期范围 2. **内存管理**: - 如果数据量很大,考虑分批处理 - 使用`dtype`参数减少内存占用 3. **类别不平衡**: - 如果正样本比例过低,考虑调整采样策略 - 使用SMOTE等技术处理不平衡数据 4. **模型选择**: - 随机森林适合处理高维特征 - 对于时间序列数据,也可尝试LSTM等序列模型
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值