Keras 模型系统解剖:Sequential vs Functional vs Subclassing

Keras 模型系统解剖:Sequential vs Functional vs Subclassing


一、Keras 模型的三种构建方式

构建方式优点缺点适用场景
Sequential简洁直观,适合线性模型不支持多输入/输出、分支结构入门教程,快速实验
Functional API支持 DAG 结构,可可视化、导出完整模型结构写法稍复杂多输入输出、ResNet/U-Net
子类化(Subclassing)极致灵活,控制流程、条件、循环等不可直接序列化图结构自定义层、RNN、GAN、自定义损失等场景

我们用实际代码对比:


1. Sequential(线性模型)

from tensorflow.keras import Sequential, layers

model = Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,)),
    layers.Dense(10, activation='softmax')
])

背后逻辑就是:构造一个 Model,顺序叠加 Layer__call__ 自动传值。


2. Functional API(有向图结构)

from tensorflow.keras import Model, Input, layers

inputs = Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10, activation='softmax')(x)

model = Model(inputs, outputs)

背后逻辑是:每个 Tensor 记录自己来自哪个 Layer,从而自动构建 DAG 结构,最后由 Model(inputs, outputs) 编译成一张完整的图。


3. Subclassing 自定义模型

from tensorflow.keras import Model, layers

class MyModel(Model):
    def __init__(self):
        super().__init__()
        self.d1 = layers.Dense(64, activation='relu')
        self.d2 = layers.Dense(10, activation='softmax')

    def call(self, x, training=False):
        x = self.d1(x)
        return self.d2(x)

model = MyModel()

适合需要控制流(如 if/else、循环)、多个路径、额外 loss 输出的模型。


二、Layer / Model 的底层机制拆解

TensorFlow 中的 LayerModel 都继承自 tf.keras.layers.Layer,核心流程如下:

每个 Layer 内部执行流程:

def __call__(self, inputs):
    # 自动追踪变量
    self._maybe_build(inputs)
    # 自动调用 call()
    outputs = self.call(inputs)
    # 添加到计算图
    return outputs

build()call() 的区别:

方法用途
build(input_shape)创建权重变量(weights、biases)
call(inputs)实际执行前向传播逻辑

注意:若你在 __init__() 中没有提供 input_shape,Keras 会在第一次调用时自动触发 build()


示例:自定义 Layer 的写法

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w)

Keras 会自动追踪 self.w 并纳入训练变量中。


三、Functional API 的“图构建”逻辑解析

每个 Tensor 都有 .keras_history,记录了它是由哪个 Layer 操作哪个输入产生的。

inputs = Input(shape=(32,))
x = Dense(64)(inputs)

此时 x 会记录:

x._keras_history = (layer_instance, node_index, tensor_index)

最终调用 Model(inputs, outputs) 时,Keras 会自动根据图关系逆向展开整张计算图。

这让你可以做复杂模型:

# 多输入模型
image = Input(shape=(224, 224, 3))
text = Input(shape=(100,))
x1 = CNN_Block(image)
x2 = Embedding_Block(text)
merged = layers.concatenate([x1, x2])
outputs = layers.Dense(1, activation='sigmoid')(merged)
model = Model([image, text], outputs)

四、Subclass 模型的细节陷阱

  1. call() 不等于可导出图结构

    子类模型没有“图结构”,所以:

    • 不支持 .to_json() 导出
    • 不能通过 .summary() 查看结构,除非先构建一次:
    model.build(input_shape=(None, 784))
    model.summary()
    
  2. 变量追踪必须在 __init__() 中定义

    否则 model.trainable_variables 无法追踪权重。

  3. 使用 @tf.function 会强制图化,但仍不可导出计算图结构

    @tf.function
    def inference(x):
        return model(x)
    

五、残差块实战:对比三种模型写法

Sequential ❌ 不支持跳连接

# 无法表达 x + F(x) 的残差结构

Functional ✅ 推荐方式

inputs = tf.keras.Input(shape=(64,))
x = layers.Dense(64, activation='relu')(inputs)
res = layers.Dense(64)(x)
outputs = layers.add([x, res])  # 残差连接
model = tf.keras.Model(inputs, outputs)

Subclass ✅ 结构自由

class ResidualBlock(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.fc1 = layers.Dense(64, activation='relu')
        self.fc2 = layers.Dense(64)

    def call(self, x):
        h = self.fc1(x)
        return x + self.fc2(h)

model = ResidualBlock()

六、模型保存与导出差异

模型类型支持 .save()支持 .to_json()可视化结构
Sequential ✅
Functional ✅
Subclass ✅(SavedModel)❌(不支持 to_json)❌(需要手动 build 后 summary)

✅ 所有类型都支持导出 SavedModel 用于部署,但只有 Functional/Sequential 支持 .model.to_json() / .load_model()


七、本章建议与小结

场景推荐方式
快速线性网络Sequential
复杂图结构、多输入/输出Functional API
自定义层、控制流、组合模块Subclass Model

使用 Functional API 是实际工作中的首选,既可追踪图,又可拓展模块;而子类模型则是高级开发者的主战场,适合实现 Transformer、GAN、分布式自定义训练逻辑。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

观熵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值