```
import tensorflow as tf
from tensorflow.keras import datasets
import matplotlib.pyplot as plt
# Load the MNIST data: training images/labels and test images/labels
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# Scale pixel values into the [0, 1] range (grayscale pixels run from 0 to 255,
# so dividing by 255 is all the normalization we need).
train_images, test_images = train_images / 255.0, test_images / 255.0

# Inspect the dataset shapes
print(train_images.shape, test_images.shape, train_labels.shape, test_labels.shape)

# Add an explicit channel axis so each image is 28x28x1
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
print(train_images.shape, test_images.shape, train_labels.shape, test_labels.shape)

# Cast to float32; the data was already scaled above, so do NOT divide by 255 again
train_images = train_images.astype("float32")
test_images = test_images.astype("float32")
def image_to_patches(images, patch_size=4):
    """Split a batch of 28x28x1 images into non-overlapping patch_size x patch_size patches."""
    batch_size = tf.shape(images)[0]
    # images must already be 4-D: (batch, height, width, channels)
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID"
    )
    # (batch, 7, 7, 16) -> (batch, 49, 16) for 28x28 inputs with patch_size=4
    return tf.reshape(patches, [batch_size, -1, patch_size * patch_size * 1])
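
# (Optional) sanity check -- a minimal sketch: 28/4 = 7 patches per side,
# so each image should become 49 patches of 4*4*1 = 16 values.
_demo_patches = image_to_patches(tf.zeros([2, 28, 28, 1]), patch_size=4)
print(_demo_patches.shape)  # expected: (2, 49, 16)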
class TransformerBlock(tf.keras.layers.Layer):
    """Post-norm Transformer encoder block: self-attention followed by a feed-forward network."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(embed_dim * 4, activation="relu"),
            tf.keras.layers.Dense(embed_dim)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization()
        self.layernorm2 = tf.keras.layers.LayerNormalization()

    def call(self, inputs):
        # Self-attention with a residual connection, then layer norm
        attn_output = self.att(inputs, inputs)
        out1 = self.layernorm1(inputs + attn_output)
        # Position-wise feed-forward network with a residual connection, then layer norm
        ffn_output = self.ffn(out1)
        return self.layernorm2(out1 + ffn_output)
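
# (Optional) quick shape check -- a TransformerBlock preserves the (batch, seq, embed) shape
_blk = TransformerBlock(embed_dim=64, num_heads=4)
print(_blk(tf.zeros([2, 49, 64])).shape)  # expected: (2, 49, 64)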
class PositionEmbedding(tf.keras.layers.Layer):
    """Adds a learned position embedding to each patch embedding."""
    def __init__(self, max_len, embed_dim):
        super().__init__()
        self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        # Positions 0 .. sequence_length-1, broadcast over the batch dimension
        positions = tf.range(start=0, limit=tf.shape(x)[1], delta=1)
        return x + self.pos_emb(positions)
def build_transformer_model():
    inputs = tf.keras.Input(shape=(49, 16))   # 49 patches of 4x4 pixels each
    x = tf.keras.layers.Dense(64)(inputs)     # project each patch to an embedding dimension of 64
    # Add learned position encodings
    x = PositionEmbedding(max_len=49, embed_dim=64)(x)
    # Stack two Transformer encoder blocks
    x = TransformerBlock(embed_dim=64, num_heads=4)(x)
    x = TransformerBlock(embed_dim=64, num_heads=4)(x)
    # Classification head: average over the patch dimension, then a 10-way softmax
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
model = build_transformer_model()
# Labels are integers, so use sparse categorical cross-entropy
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
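
# (Optional) print a layer-by-layer summary to confirm the architecture and parameter count
model.summary()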
# Data preprocessing: convert each image into a sequence of flattened patches.
# train_images / test_images are already 4-D (batch, 28, 28, 1), so no extra axis is added here;
# a 5-D input (e.g. images[..., tf.newaxis] applied to a 4-D array) would crash extract_patches.
train_images_pt = image_to_patches(train_images)   # (60000, 49, 16)
test_images_pt = image_to_patches(test_images)     # (10000, 49, 16)
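
# (Optional) verify the patch tensors before training
print(train_images_pt.shape, test_images_pt.shape)  # expected: (60000, 49, 16) (10000, 49, 16)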
# Train for 10 epochs, validating on the test split after each epoch
history = model.fit(
    train_images_pt, train_labels,
    validation_data=(test_images_pt, test_labels),
    epochs=10,
    batch_size=128
)
```

If the channel axis is added a second time before patch extraction (for example, calling `image_to_patches(train_images[..., tf.newaxis])` when `train_images` is already 4-D), `tf.image.extract_patches` receives a 5-D tensor and raises:

```
Exception has occurred: InvalidArgumentError
input must be 4-dimensional[60000,28,28,1,1] [Op:ExtractImagePatches]
tensorflow.python.eager.core._NotOkStatusException: InvalidArgumentError: input must be 4-dimensional[60000,28,28,1,1] [Op:ExtractImagePatches]

During handling of the above exception, another exception occurred:

  File "D:\source\test3\transform.py", line 32, in image_to_patches
    patches = tf.image.extract_patches(
  File "D:\source\test3\transform.py", line 118, in <module>
    train_images_pt = image_to_patches(train_images[..., tf.newaxis])  # expected shape (60000, 49, 16)
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be 4-dimensional[60000,28,28,1,1] [Op:ExtractImagePatches]
```
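
Once the patch tensors are built from the 4-D images (no extra `tf.newaxis`), training runs normally, and the `history` object together with the already-imported `matplotlib.pyplot` can be used to inspect the result. A minimal sketch, assuming the `model`, `history`, `test_images_pt`, and `test_labels` variables from the script above:

```
# Evaluate on the test patches and plot the accuracy curves recorded by model.fit
test_loss, test_acc = model.evaluate(test_images_pt, test_labels, verbose=2)
print("test accuracy:", test_acc)

plt.plot(history.history["accuracy"], label="train accuracy")
plt.plot(history.history["val_accuracy"], label="validation accuracy")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()
```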