Official Keras example explained, 4.18 (lstm_seq2seq.py) - Keras study notes IV


A sequence-to-sequence example (character-level) implemented with Keras.

     An example of translating English into French

Preparation

Download fra-eng.zip and unzip it into the same directory as this .py file (giving fra-eng/fra.txt).
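If you prefer to script the download, here is a minimal sketch using only the Python standard library. It assumes the URL from the docstring below is still reachable; some servers reject the default Python user agent, so downloading the zip manually in a browser is a fine fallback.

import os
import urllib.request
import zipfile

URL = 'http://www.manythings.org/anki/fra-eng.zip'
if not os.path.exists('fra-eng/fra.txt'):
    # Pretend to be a browser, since some hosts block the default Python user agent.
    request = urllib.request.Request(URL, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(request) as response, open('fra-eng.zip', 'wb') as f:
        f.write(response.read())
    with zipfile.ZipFile('fra-eng.zip') as archive:
        archive.extractall('fra-eng')  # should yield fra-eng/fra.txt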

Files in the fra-eng directory


Contents of fra.txt (the corpus): each line is an English sentence, a tab character, and the corresponding French sentence.
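For example, a line looks roughly like this, with a real tab character where <TAB> is shown (pairs taken from the decoded samples further down):

Go.<TAB>Va !
Run!<TAB>Cours !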


Keras examples index

Annotated code

'''Sequence to sequence example in Keras (character-level).

This script demonstrates how to implement a basic character-level
sequence-to-sequence model. We apply it to translating
short English sentences into short French sentences,
character-by-character. Note that it is fairly unusual to
do character-level machine translation, as word-level
models are more common in this domain.

# Summary of the algorithm

- We start with input sequences from a domain (e.g. English sentences)
    and corresponding target sequences from another domain
    (e.g. French sentences).
- An encoder LSTM turns input sequences to 2 state vectors
    (we keep the last LSTM state and discard the outputs).
- A decoder LSTM is trained to turn the target sequences into
    the same sequence but offset by one timestep in the future,
    a training process called "teacher forcing" in this context.
    It uses as initial state the state vectors from the encoder.
    Effectively, the decoder learns to generate `targets[t+1...]`
    given `targets[...t]`, conditioned on the input sequence.
- In inference mode, when we want to decode unknown input sequences, we:
    - Encode the input sequence into state vectors
    - Start with a target sequence of size 1
        (just the start-of-sequence character)
    - Feed the state vectors and 1-char target sequence
        to the decoder to produce predictions for the next character
    - Sample the next character using these predictions
        (we simply use argmax).
    - Append the sampled character to the target sequence
    - Repeat until we generate the end-of-sequence character or we
        hit the character limit.

# Data download

English to French sentence pairs.
http://www.manythings.org/anki/fra-eng.zip

Lots of neat sentence pairs datasets can be found at:
http://www.manythings.org/anki/

# References

- Sequence to Sequence Learning with Neural Networks
    https://arxiv.org/abs/1409.3215
- Learning Phrase Representations using
    RNN Encoder-Decoder for Statistical Machine Translation
    https://arxiv.org/abs/1406.1078
'''
from __future__ import print_function

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

batch_size = 64  # Batch size for training (samples per gradient update).
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk
# (download the archive first and unzip it next to this script).
data_path = 'fra-eng/fra.txt'

# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
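# With the 10,000-sample run shown below, these one-hot tensors have shapes
# (10000, 16, 71) for encoder_input_data and (10000, 59, 94) for the two decoder arrays.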

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
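
# Illustration of the one-step offset (using the pair "Go." -> "Va !" from the data):
# target_text is '\tVa !\n', so for that sample
#   decoder_input_data  encodes ['\t', 'V', 'a', ' ', '!', '\n']   (starts with the tab)
#   decoder_target_data encodes ['V', 'a', ' ', '!', '\n']         (same sequence, one step earlier)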

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
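# Each of state_h and state_c has shape (batch_size, latent_dim), i.e. (None, 256) here.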

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
# Save model
model.save('s2s.h5')
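# The saved file can later be reloaded with keras.models.load_model('s2s.h5');
# rebuilding the two inference models from a reloaded training model is shown in the
# companion Keras example lstm_seq2seq_restore.py.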

# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

# Define sampling models
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
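# decoder_model maps [one-hot char, h, c] -> [softmax over the next char, new h, new c];
# it reuses the trained decoder_lstm and decoder_dense layers, so no extra training is needed.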

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence


for seq_index in range(100):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
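
To translate a sentence of your own instead of a training sample, the helpers defined above can be reused. A minimal sketch, assuming it is appended to the end of the script after training (it raises a KeyError for characters that never occur in the training inputs, so it is illustrative only):

def encode_input_text(text):
    # One-hot encode a raw English string exactly like the training inputs above.
    seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(text[:max_encoder_seq_length]):
        seq[0, t, input_token_index[char]] = 1.
    return seq

print(decode_sequence(encode_input_text('Run!')))  # should print something close to "Cours !"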

Running the code

C:\ProgramData\Anaconda3\python.exe E:/keras-master/examples/lstm_seq2seq.py
Using TensorFlow backend.
Number of samples: 10000
Number of unique input tokens: 71
Number of unique output tokens: 94
Max sequence length for inputs: 16
Max sequence length for outputs: 59
Train on 8000 samples, validate on 2000 samples
Epoch 1/100

  64/8000 [..............................] - ETA: 3:58 - loss: 1.3394
 128/8000 [..............................] - ETA: 2:07 - loss: 1.3604
 192/8000 [..............................] - ETA: 1:30 - loss: 1.3536
...
7936/8000 [============================>.] - ETA: 0s - loss: 0.9284
8000/8000 [==============================] - 25s 3ms/step - loss: 0.9274 - val_loss: 1.0000
Epoch 2/100

  64/8000 [..............................] - ETA: 17s - loss: 0.8085
 128/8000 [..............................] - ETA: 22s - loss: 0.8087
 192/8000 [..............................] - ETA: 23s - loss: 0.8104
 

Epoch 100/100

  64/8000 [..............................] - ETA: 15s - loss: 0.0462
 128/8000 [..............................] - ETA: 17s - loss: 0.0495
 192/8000 [..............................] - ETA: 17s - loss: 0.0503
...
7936/8000 [============================>.] - ETA: 0s - loss: 0.0588
8000/8000 [==============================] - 19s 2ms/step - loss: 0.0588 - val_loss: 0.7629
E:\keras-master\keras\engine\topology.py:2365: UserWarning: Layer lstm_2 was passed non-serializable keyword arguments: {'initial_state': [<tf.Tensor 'lstm_1/while/Exit_2:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lstm_1/while/Exit_3:0' shape=(?, 256) dtype=float32>]}. They will not be included in the serialized model (and thus will be missing at deserialization time).
  str(node.arguments) + '. They will not be included '
-
Input sentence: Go.
Decoded sentence: Va !

-
Input sentence: Run!
Decoded sentence: Cours !

-
Input sentence: Run!
Decoded sentence: Cours !

-
Input sentence: Fire!
Decoded sentence: Au feu !

-
Input sentence: Help!
Decoded sentence: À l'aide !

-
Input sentence: Jump.
Decoded sentence: Saute.

-
Input sentence: Stop!
Decoded sentence: Arrête-toi !

-
Input sentence: Stop!
Decoded sentence: Arrête-toi !

-
Input sentence: Stop!
Decoded sentence: Arrête-toi !

-
Input sentence: Wait!
Decoded sentence: Attends !

-
Input sentence: Wait!
Decoded sentence: Attends !

-
Input sentence: Go on.
Decoded sentence: Continuez.

-
Input sentence: Go on.
Decoded sentence: Continuez.

-
Input sentence: Go on.
Decoded sentence: Continuez.

-
Input sentence: I see.
Decoded sentence: Je vois une lumière.

-
Input sentence: I try.
Decoded sentence: J'essaye.

-
Input sentence: I won!
Decoded sentence: J'ai demandé à dore.

-
Input sentence: I won!
Decoded sentence: J'ai demandé à dore.

-
Input sentence: Oh no!
Decoded sentence: Oh non !

-
Input sentence: Attack!
Decoded sentence: Attaquez !

-
Input sentence: Attack!
Decoded sentence: Attaquez !

-
Input sentence: Cheers!
Decoded sentence: À votre santé !

-
Input sentence: Cheers!
Decoded sentence: À votre santé !

-
Input sentence: Cheers!
Decoded sentence: À votre santé !

-
Input sentence: Cheers!
Decoded sentence: À votre santé !

-
Input sentence: Get up.
Decoded sentence: Lève-toi.

-
Input sentence: Go now.
Decoded sentence: Va doucement !

-
Input sentence: Go now.
Decoded sentence: Va doucement !

-
Input sentence: Go now.
Decoded sentence: Va doucement !

-
Input sentence: Got it!
Decoded sentence: Compris !

-
Input sentence: Got it!
Decoded sentence: Compris !

-
Input sentence: Got it?
Decoded sentence: Compris ?

-
Input sentence: Got it?
Decoded sentence: Compris ?

-
Input sentence: Got it?
Decoded sentence: Compris ?

-
Input sentence: Hop in.
Decoded sentence: Montez.

-
Input sentence: Hop in.
Decoded sentence: Montez.

-
Input sentence: Hug me.
Decoded sentence: Serre-moi dans tes bras !

-
Input sentence: Hug me.
Decoded sentence: Serre-moi dans tes bras !

-
Input sentence: I fell.
Decoded sentence: Je suis tombée.

-
Input sentence: I fell.
Decoded sentence: Je suis tombée.

-
Input sentence: I know.
Decoded sentence: Je sais.

-
Input sentence: I left.
Decoded sentence: Je suis parti.

-
Input sentence: I left.
Decoded sentence: Je suis parti.

-
Input sentence: I lost.
Decoded sentence: J'ai perdu.

-
Input sentence: I'm 19.
Decoded sentence: J'ai les chocontes.

-
Input sentence: I'm OK.
Decoded sentence: Je vais bien.

-
Input sentence: I'm OK.
Decoded sentence: Je vais bien.

-
Input sentence: Listen.
Decoded sentence: Écoutez !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: No way!
Decoded sentence: C'est exclu !

-
Input sentence: Really?
Decoded sentence: Vrai ?

-
Input sentence: Really?
Decoded sentence: Vrai ?

-
Input sentence: Really?
Decoded sentence: Vrai ?

-
Input sentence: Thanks.
Decoded sentence: Merci !

-
Input sentence: We try.
Decoded sentence: On essaye.

-
Input sentence: We won.
Decoded sentence: Nous avons réveillé.

-
Input sentence: We won.
Decoded sentence: Nous avons réveillé.

-
Input sentence: We won.
Decoded sentence: Nous avons réveillé.

-
Input sentence: We won.
Decoded sentence: Nous avons réveillé.

-
Input sentence: Ask Tom.
Decoded sentence: Demande-leur.

-
Input sentence: Awesome!
Decoded sentence: Faisalez-moi !

-
Input sentence: Be calm.
Decoded sentence: Sois calme !

-
Input sentence: Be calm.
Decoded sentence: Sois calme !

-
Input sentence: Be calm.
Decoded sentence: Sois calme !

-
Input sentence: Be cool.
Decoded sentence: Sois détendu !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be fair.
Decoded sentence: Soyez équitables !

-
Input sentence: Be kind.
Decoded sentence: Sois gentil.

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Be nice.
Decoded sentence: Sois gentille !

-
Input sentence: Beat it.
Decoded sentence: Dégage !

-
Input sentence: Call me.
Decoded sentence: Appellez-moi !

-
Input sentence: Call me.
Decoded sentence: Appellez-moi !

-
Input sentence: Call us.
Decoded sentence: Appelle-nous !

-
Input sentence: Call us.
Decoded sentence: Appelle-nous !

-
Input sentence: Come in.
Decoded sentence: Entrez !

-
Input sentence: Come in.
Decoded sentence: Entrez !

-
Input sentence: Come in.
Decoded sentence: Entrez !

-
Input sentence: Come in.
Decoded sentence: Entrez !

-
Input sentence: Come on!
Decoded sentence: Allez !

-
Input sentence: Come on.
Decoded sentence: Viens !

-
Input sentence: Come on.
Decoded sentence: Viens !

-
Input sentence: Come on.
Decoded sentence: Viens !

-
Input sentence: Drop it!
Decoded sentence: Laissez-le tomber !

-
Input sentence: Drop it!
Decoded sentence: Laissez-le tomber !


Process finished with exit code 0

More about Keras

English: https://keras.io/

Chinese: http://keras-cn.readthedocs.io/en/latest/

Example downloads

https://github.com/keras-team/keras

https://github.com/keras-team/keras/tree/master/examples

Complete project download

For readers without download credits, add QQ 452205574 to get access to the shared folder.

It includes: code, datasets (images), generated models, library installation files, and so on.

