Keras 模型系统解剖:Sequential vs Functional vs Subclassing
一、Keras 模型的三种构建方式
构建方式 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
Sequential | 简洁直观,适合线性模型 | 不支持多输入/输出、分支结构 | 入门教程,快速实验 |
Functional API | 支持 DAG 结构,可可视化、导出完整模型结构 | 写法稍复杂 | 多输入输出、ResNet/U-Net |
子类化(Subclassing) | 极致灵活,控制流程、条件、循环等 | 不可直接序列化图结构 | 自定义层、RNN、GAN、自定义损失等场景 |
我们用实际代码对比:
1. Sequential(线性模型)
from tensorflow.keras import Sequential, layers
model = Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.Dense(10, activation='softmax')
])
背后逻辑就是:构造一个 Model
,顺序叠加 Layer
,__call__
自动传值。
2. Functional API(有向图结构)
from tensorflow.keras import Model, Input, layers
inputs = Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10, activation='softmax')(x)
model = Model(inputs, outputs)
背后逻辑是:每个 Tensor
记录自己来自哪个 Layer,从而自动构建 DAG 结构,最后由 Model(inputs, outputs)
编译成一张完整的图。
3. Subclassing 自定义模型
from tensorflow.keras import Model, layers
class MyModel(Model):
def __init__(self):
super().__init__()
self.d1 = layers.Dense(64, activation='relu')
self.d2 = layers.Dense(10, activation='softmax')
def call(self, x, training=False):
x = self.d1(x)
return self.d2(x)
model = MyModel()
适合需要控制流(如 if/else、循环)、多个路径、额外 loss 输出的模型。
二、Layer / Model 的底层机制拆解
TensorFlow 中的 Layer
和 Model
都继承自 tf.keras.layers.Layer
,核心流程如下:
每个 Layer 内部执行流程:
def __call__(self, inputs):
# 自动追踪变量
self._maybe_build(inputs)
# 自动调用 call()
outputs = self.call(inputs)
# 添加到计算图
return outputs
build()
与 call()
的区别:
方法 | 用途 |
---|---|
build(input_shape) | 创建权重变量(weights、biases) |
call(inputs) | 实际执行前向传播逻辑 |
✅ 注意:若你在
__init__()
中没有提供 input_shape,Keras 会在第一次调用时自动触发 build()
示例:自定义 Layer 的写法
class MyLayer(tf.keras.layers.Layer):
def __init__(self, units=32):
super().__init__()
self.units = units
def build(self, input_shape):
self.w = self.add_weight(shape=(input_shape[-1], self.units),
initializer='random_normal',
trainable=True)
def call(self, inputs):
return tf.matmul(inputs, self.w)
Keras 会自动追踪 self.w
并纳入训练变量中。
三、Functional API 的“图构建”逻辑解析
每个 Tensor 都有 .keras_history
,记录了它是由哪个 Layer 操作哪个输入产生的。
inputs = Input(shape=(32,))
x = Dense(64)(inputs)
此时 x
会记录:
x._keras_history = (layer_instance, node_index, tensor_index)
最终调用 Model(inputs, outputs)
时,Keras 会自动根据图关系逆向展开整张计算图。
这让你可以做复杂模型:
# 多输入模型
image = Input(shape=(224, 224, 3))
text = Input(shape=(100,))
x1 = CNN_Block(image)
x2 = Embedding_Block(text)
merged = layers.concatenate([x1, x2])
outputs = layers.Dense(1, activation='sigmoid')(merged)
model = Model([image, text], outputs)
四、Subclass 模型的细节陷阱
-
call()
不等于可导出图结构子类模型没有“图结构”,所以:
- 不支持
.to_json()
导出 - 不能通过
.summary()
查看结构,除非先构建一次:
model.build(input_shape=(None, 784)) model.summary()
- 不支持
-
变量追踪必须在
__init__()
中定义否则
model.trainable_variables
无法追踪权重。 -
使用
@tf.function
会强制图化,但仍不可导出计算图结构@tf.function def inference(x): return model(x)
五、残差块实战:对比三种模型写法
Sequential ❌ 不支持跳连接
# 无法表达 x + F(x) 的残差结构
Functional ✅ 推荐方式
inputs = tf.keras.Input(shape=(64,))
x = layers.Dense(64, activation='relu')(inputs)
res = layers.Dense(64)(x)
outputs = layers.add([x, res]) # 残差连接
model = tf.keras.Model(inputs, outputs)
Subclass ✅ 结构自由
class ResidualBlock(tf.keras.Model):
def __init__(self):
super().__init__()
self.fc1 = layers.Dense(64, activation='relu')
self.fc2 = layers.Dense(64)
def call(self, x):
h = self.fc1(x)
return x + self.fc2(h)
model = ResidualBlock()
六、模型保存与导出差异
模型类型 | 支持 .save() | 支持 .to_json() | 可视化结构 |
---|---|---|---|
Sequential ✅ | ✅ | ✅ | |
Functional ✅ | ✅ | ✅ | |
Subclass ✅(SavedModel) | ❌(不支持 to_json) | ❌(需要手动 build 后 summary) |
✅ 所有类型都支持导出 SavedModel 用于部署,但只有 Functional/Sequential 支持
.model.to_json()
/.load_model()
。
七、本章建议与小结
场景 | 推荐方式 |
---|---|
快速线性网络 | Sequential |
复杂图结构、多输入/输出 | Functional API |
自定义层、控制流、组合模块 | Subclass Model |
使用 Functional API
是实际工作中的首选,既可追踪图,又可拓展模块;而子类模型则是高级开发者的主战场,适合实现 Transformer、GAN、分布式自定义训练逻辑。