18、使用文本 text = “hello world” 实现一个用于文本生成的简单循环神经网络(RNN)。
以下是使用文本 “hello world” 实现简单 RNN 进行文本生成的代码:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN
from tensorflow.keras.utils import to_categorical
# 样本文本语料
text = "hello world"
# 创建字符级词汇表
chars = sorted(set(text))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
# 创建训练用的输入 - 输出对
sequence_length = 3
X = []
y = []
for i in range(len(text) - sequence_length):
X.append([char_to_idx[char] for char in text[i:i + sequence_length]])
y.append(char_to_idx[text[i + sequence_length]])
X = np.array(X)
y = to_categorical(y, num_classes=len(chars))
# 重塑输入以适应 RNN 输入
X = X.reshape((X.shape[0], X.shape[1], 1))
# 定义 RNN 模型
model = Sequential()
model.add(SimpleRNN(50, input_shape=(sequence_length, 1)))
model.add(Dense(len(chars), activation='softmax'))
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy')
# 训练模型
model.fit(X, y, epochs=200, verbose=1)
# 用于使用训练好的模型生成文本的函数
def generate_text(model, start_string, num_generate):
input_eval = [char_to_idx[s] for s in start_string]
input_eval = np.array(input_eval).reshape((1, len(input_eval), 1))
text_generated = []
for i in range(num_generate):
predictions = model.predict(input_eval)
predicted_id = np.argmax(predictions[-1])
input_eval = np.append(input_eval[:, 1:], [[predicted_id]], axis=1)
text_generated.append(idx_to_char[predicted_id])
return start_string + ''.join(text_generated)
# 生成新文本
start_string = "hel"
generated_text = generate_text(model, start_string, 5)
print("Generated text:")
print(generated_text)
上述代码通过 TensorFlow 和 Keras 构建并训练了一个简单的字符级 RNN,目标是根据给定的输入序列生成文本。
19、使用文本 text = “hello world” 实现一个用于文本生成的长短期记忆网络(LSTM)
以下是实现代码:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.utils import to_categorical
# 样本文本语料
text = "hello world"
# 创建字符级词汇表
chars = sorted(set(text))
char_to_idx = {char: idx for idx, char in enumerate(chars)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}
# 创建训练用的输入 - 输出对
seq

最低0.47元/天 解锁文章

被折叠的 条评论
为什么被折叠?



