Code
分享 Notebook
Python 3 (ipykernel)
import tensorflow as tf
1秒
+ Code
+ Markdown
import numpy as np
import os
import time
0秒
+ Code
+ Markdown
# Data loading
def load_dataset(file_path='shakespeare.txt'):
    """Read a UTF-8 text corpus and return its full contents as a ``str``.

    The file is read in binary mode and decoded explicitly, so line endings
    are returned exactly as stored on disk (no universal-newline translation).

    Args:
        file_path: Path to the corpus file. Defaults to ``'shakespeare.txt'``
            relative to the current working directory.

    Returns:
        The decoded file contents.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # The context manager guarantees the handle is closed even if decoding fails.
    with open(file_path, 'rb') as handle:
        return handle.read().decode(encoding='utf-8')
# Load the whole training corpus into memory as a single string.
text = load_dataset()
0秒
+ Code
+ Markdown
# Build the character <-> index lookup tables
def create_vocab_mapping(text):
    """Derive the character vocabulary of ``text`` and both lookup directions.

    Returns:
        A ``(vocab, char2idx, idx2char)`` triple: ``vocab`` is the sorted list
        of distinct characters, ``char2idx`` maps each character to its
        position in ``vocab``, and ``idx2char`` is a NumPy array of the same
        characters so an integer id indexes straight back to its character.
    """
    unique_chars = sorted(set(text))
    lookup = dict(zip(unique_chars, range(len(unique_chars))))
    reverse_lookup = np.array(unique_chars)
    return unique_chars, lookup, reverse_lookup
# Build the lookup tables from the corpus and show the sorted character set.
vocab, char2idx, idx2char = create_vocab_mapping(text=text)
print(vocab)
0秒
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
+ Code
+ Markdown
# Inspect the character -> integer id lookup table.
print(char2idx)
0秒
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, '[': 39, ']': 40, 'a': 41, 'b': 42, 'c': 43, 'd': 44, 'e': 45, 'f': 46, 'g': 47, 'h': 48, 'i': 49, 'j': 50, 'k': 51, 'l': 52, 'm': 53, 'n': 54, 'o': 55, 'p': 56, 'q': 57, 'r': 58, 's': 59, 't': 60, 'u': 61, 'v': 62, 'w': 63, 'x': 64, 'y': 65, 'z': 66}
+ Code
+ Markdown
# Inspect the id -> character array (index i holds the character whose id is i).
print(idx2char)
0秒
['\n' ' ' '!' '$' '&' "'" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'
'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'
'X' 'Y' 'Z' '[' ']' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm'
'n' 'o' 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']
+ Code
+ Markdown
# Convert text into a sequence of integer ids
def text_to_sequence(text, char2idx):
    """Encode ``text`` as a 1-D ``int64`` NumPy array of vocabulary ids.

    Args:
        text: String to encode.
        char2idx: Mapping from each character to its integer id.

    Returns:
        ``np.ndarray`` of dtype ``int64`` and length ``len(text)``. An empty
        ``text`` yields an empty ``int64`` array; ``np.array([])`` would
        silently be ``float64``.

    Raises:
        KeyError: If ``text`` contains a character missing from ``char2idx``.
    """
    # fromiter fills a preallocated array directly instead of first building a
    # Python list as long as the whole corpus.
    return np.fromiter(
        (char2idx[char] for char in text),
        dtype=np.int64,
        count=len(text),
    )
# Encode the whole corpus as integer ids and preview the result.
text_as_int = text_to_sequence(text=text, char2idx=char2idx)
print(text_as_int)
0秒
[18 49 58 ... 45 8 0]
+ Code
+ Markdown
# Split each chunk into an (input, target) pair
def split_input_target(chunk):
    """Turn one window into a next-element training pair.

    The input is every element except the last and the target is every
    element except the first, so ``target[i]`` is the element that follows
    ``input[i]`` in the original window.
    """
    return chunk[:-1], chunk[1:]
# Build the training dataset
def create_dataset(text_sequence, seq_length, batch_size, buffer_size=10000):
    """Build a shuffled, batched ``tf.data`` pipeline of (input, target) pairs.

    The encoded corpus is cut into consecutive, non-overlapping windows of
    ``seq_length + 1`` ids (a trailing partial window is dropped). Each window
    is split by ``split_input_target`` into an input and a target shifted by
    one position. Windows are shuffled with a buffer of ``buffer_size`` and
    grouped into batches of exactly ``batch_size`` (a final partial batch is
    dropped).

    Args:
        text_sequence: Encoded corpus, presumably the 1-D id array produced by
            ``text_to_sequence`` -- confirm if another source is used.
        seq_length: Length of each input/target sequence.
        batch_size: Number of sequences per batch.
        buffer_size: Size of the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` of ``(input, target)`` batches; for a 1-D
        ``text_sequence`` each has shape ``(batch_size, seq_length)``.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # One dataset element per character id.
    char_dataset = tf.data.Dataset.from_tensor_slices(text_sequence)
    # Group ids into windows of seq_length + 1 so each yields seq_length pairs.
    sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
    dataset = sequences.map(split_input_target, num_parallel_calls=autotune)
    dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)
    # Prepare the next batch while the current training step runs.
    return dataset.prefetch(autotune)
# NOTE: these literals duplicate SEQ_LENGTH / BATCH_SIZE / BUFFER_SIZE from the
# hyperparameter cell; keep them in sync, since the model is built with a
# fixed batch dimension of BATCH_SIZE.
dataset = create_dataset(
    text_sequence=text_as_int,
    seq_length=100,
    batch_size=64,
    buffer_size=10000,
)
0秒
WARNING:tensorflow:Entity <function split_input_target at 0x7f5e7de26b80> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: 'arguments' object has no attribute 'defaults'
WARNING: Entity <function split_input_target at 0x7f5e7de26b80> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: 'arguments' object has no attribute 'defaults'
+ Code
+ Markdown
# Build the model
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Assemble the Embedding -> stateful GRU -> Dense character model.

    Args:
        vocab_size: Number of distinct character ids (input and output size).
        embedding_dim: Width of the character embedding.
        rnn_units: Number of GRU units.
        batch_size: Fixed batch dimension; it is baked into the input shape
            because the GRU is stateful.

    Returns:
        An uncompiled ``tf.keras.Sequential`` whose final Dense layer has no
        activation, so it emits one unnormalised score (logit) per character.
    """
    keras_layers = tf.keras.layers
    embedding = keras_layers.Embedding(
        vocab_size, embedding_dim, batch_input_shape=[batch_size, None]
    )
    recurrent = keras_layers.GRU(
        rnn_units,
        return_sequences=True,
        stateful=True,
        recurrent_initializer='glorot_uniform',
    )
    projection = keras_layers.Dense(vocab_size)
    return tf.keras.Sequential([embedding, recurrent, projection])
0秒
+ Code
+ Markdown
# Data pipeline hyperparameters
SEQ_LENGTH = 100
BATCH_SIZE = 64
BUFFER_SIZE = 10000

# Model / training hyperparameters
EMBEDDING_DIM = 256
RNN_UNITS = 1024
EPOCHS = 10

# Training-time model: its batch dimension is fixed to BATCH_SIZE.
model = build_model(len(vocab), EMBEDDING_DIM, RNN_UNITS, BATCH_SIZE)
0秒
WARNING:tensorflow:From /opt/conda/lib/python3.8/site-packages/tensorflow_core/python/keras/initializers.py:118: calling RandomUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /opt/conda/lib/python3.8/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1623: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
+ Code
+ Markdown
# Custom loss function
def loss(labels, logits):
    """Per-position sparse categorical cross-entropy on raw model logits.

    ``from_logits=True`` because the model's final Dense layer applies no
    softmax.
    """
    per_position = tf.keras.losses.sparse_categorical_crossentropy(
        y_true=labels,
        y_pred=logits,
        from_logits=True,
    )
    return per_position
0秒
+ Code
+ Markdown
# Text generation
def generate_text(model, start_string, char2idx, idx2char, num_generate=1000, temperature=1.0):
    """Sample ``num_generate`` characters from ``model``, seeded by ``start_string``.

    The seed is fed through the model once; afterwards each sampled character
    is fed back in as the next single-step input.

    Args:
        model: Keras model expected to be built with batch size 1 and to
            return per-step logits over the vocabulary. Its recurrent state
            is reset before sampling.
        start_string: Non-empty seed text; every character must be a key of
            ``char2idx``.
        char2idx: Mapping from character to integer id.
        idx2char: Indexable lookup from integer id back to character.
        num_generate: Number of characters to sample.
        temperature: Positive divisor for the logits. Values below 1 make the
            output more predictable; values above 1 make it more random.

    Returns:
        ``start_string`` followed by the generated characters.

    Raises:
        ValueError: If ``start_string`` is empty or ``temperature`` is not positive.
        KeyError: If ``start_string`` contains a character not in ``char2idx``.
    """
    # Validate before touching the model: an empty seed would feed a
    # zero-length tensor, and temperature <= 0 yields inf/NaN or inverted logits.
    if not start_string:
        raise ValueError("start_string must be a non-empty string")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    # Encode the seed and add a batch dimension of 1.
    input_eval = tf.expand_dims([char2idx[char] for char in start_string], 0)
    text_generated = []
    # Clear any recurrent state left over from a previous call.
    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        # Drop the batch dimension -> (steps, vocab).
        predictions = tf.squeeze(predictions, 0)
        # Temperature-scale the logits, sample one id per step and keep only
        # the sample drawn for the last step.
        predictions = predictions / temperature
        prediction_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        # Feed the sampled character back in as the next single-step input.
        input_eval = tf.expand_dims([prediction_id], 0)
        text_generated.append(idx2char[prediction_id])
    return start_string + ''.join(text_generated)
0秒
+ Code
+ Markdown
# Compile with the Adam optimizer and the from-logits cross-entropy defined above.
model.compile(loss=loss, optimizer='adam')
0秒
+ Code
+ Markdown
# Checkpoint callback: save weights only. Keras fills in "{epoch}" in the
# file name each time it saves.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_prefix, save_weights_only=True)
0秒
+ Code
+ Markdown
# Train the model for the configured number of epochs (EPOCHS, set with the
# other hyperparameters; lower it there for a quick smoke run).
history = model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback],
)
2025-06-21 15:05:35.595888: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
WARNING:tensorflow:From /opt/conda/lib/python3.8/site-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 707 steps
250/707 [=========>....................] - ETA: 17:31 - loss: 2.4835
+ Code
+ Markdown
vocab_size = len(vocab)
# Rebuild the network with batch size 1 for generation, then restore the most
# recent checkpoint in checkpoint_dir. This is the latest checkpoint, not one
# selected by a metric.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    # latest_checkpoint returns None when nothing has been saved yet.
    raise FileNotFoundError(
        f"No checkpoint found in {checkpoint_dir!r}; train the model first."
    )
model = build_model(vocab_size, EMBEDDING_DIM, RNN_UNITS, batch_size=1)
model.load_weights(latest_checkpoint)
model.build(tf.TensorShape([1, None]))
+ Code
+ Markdown
Code
# Generate text (run once; the cell previously duplicated this call).
generated_text = generate_text(
    model,
    start_string="heLlo ",
    char2idx=char2idx,
    idx2char=idx2char,
    num_generate=1000,
    temperature=0.8,  # tunable: lower -> more predictable, higher -> more random
)
print("\nGenerated Text:")
print(generated_text)
7秒
Generated Text:
ROMEO: you do it.
KING HENRY VI:
O Lord, mad, you w
变其丰富
还有假注释
最新发布