A TensorFlow bug when adding an axis

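While coding I hit a TensorFlow bug caused by trying to index along an axis and add a new dimension in the same step. The fix is to split these into two steps. count_value = rules[:, 1][tf.newaxis] correctly selects along the axis and then adds a dimension, whereas the combined form count_value = rules[tf.newaxis, :, 1] led to random errors. Understanding this difference helps avoid similar problems.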

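While writing code I ran into an unusual bug that produces random errors, so I'm recording it here.
When you want to select elements along an axis and add a new dimension at the same time, do it in two separate steps rather than in a single indexing expression.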

Correct usage

count_value = rules[:, 1][tf.newaxis]  # shape: (1, M)
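The two-step pattern can be sketched as follows. This assumes rules is a 2-D tf.Tensor of shape (M, K) with K ≥ 2, and the sample values are made up for illustration. The sketch also shows tf.expand_dims. That function is an alternative not taken from the original post: it inserts the axis through an explicit op, which likewise keeps column selection and axis insertion as separate steps.

```python
import tensorflow as tf

# Hypothetical example data: M = 4 rules, K = 3 columns per rule.
rules = tf.constant([[0, 10, 7],
                     [1, 20, 8],
                     [2, 30, 9],
                     [3, 40, 6]], dtype=tf.int32)

# Step 1: select column 1 along the second axis -> shape (M,)
column = rules[:, 1]

# Step 2: prepend a new leading axis -> shape (1, M)
count_value = column[tf.newaxis]

# Alternative: insert the axis with an explicit op instead of indexing.
count_value_alt = tf.expand_dims(rules[:, 1], axis=0)

print(count_value.shape)      # (1, 4)
print(count_value_alt.shape)  # (1, 4)
print(count_value.numpy())    # [[10 20 30 40]]
```

For this sample, both count_value and count_value_alt hold the same values with shape (1, 4).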

Incorrect usage

count_value = rules[tf.newaxis,:,1]
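You may want to check whether the single-step form misbehaves in your own setup. A rough sketch like the one below evaluates both forms in eager mode and inside tf.function, then compares them with NumPy. TensorFlow's slicing syntax is modeled on NumPy's basic indexing. Under standard indexing semantics, both forms should yield shape (1, M). The random errors described above were the author's observation. They may depend on the TensorFlow version or the execution context, and this sketch is not guaranteed to reproduce them. The helper names (two_step, one_step, compare_forms) and the random test data are hypothetical.

```python
import numpy as np
import tensorflow as tf


def two_step(rules):
    # Column selection first, then axis insertion (the form recommended above).
    return rules[:, 1][tf.newaxis]


def one_step(rules):
    # Column selection and axis insertion in a single index expression.
    return rules[tf.newaxis, :, 1]


# Graph-mode versions, created once so they can be reused across calls.
two_step_fn = tf.function(two_step)
one_step_fn = tf.function(one_step)


def compare_forms(rules_np):
    """Hypothetical helper: True if every variant matches the NumPy reference."""
    reference = rules_np[np.newaxis, :, 1]  # NumPy result, shape (1, M)
    rules = tf.constant(rules_np)
    candidates = [two_step(rules), one_step(rules),
                  two_step_fn(rules), one_step_fn(rules)]
    return all(
        c.shape.as_list() == list(reference.shape)
        and np.array_equal(c.numpy(), reference)
        for c in candidates
    )


rng = np.random.default_rng(0)
for m in (1, 5, 100):
    data = rng.integers(0, 50, size=(m, 3)).astype(np.int32)
    print(m, compare_forms(data))
```

If any line prints False, that supports the author's advice to keep the two steps separate. The two-step form (or tf.expand_dims) then serves as a workaround.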
