tff.learning.from_keras_model

from_keras_model函数用于从tf.keras.Model创建tff.learning.Model对象,它处理前向传播和自动微分。输入参数包括未编译的keras_model、自定义损失函数、输入规范(input_spec)以及可选的损失权重和指标。input_spec必须是描述输入数据类型的结构。该函数不支持子类化的tf.keras.Models,且当Model包含BatchNormalization层时会发出警告。

tff.learning.from_keras_model(

    keras_model: tf.keras.Model,

    loss: Loss,

    input_spec,

    loss_weights: Optional[list[float]] = None,

    metrics: Optional[Union[list[tf.keras.metrics.Metric], list[Callable[[], tf.keras.

        metrics.Metric]]]] = None

) -> tff.learning.Model

功能:该函数返回的 "tff.learning.Model "使用 "keras_model "进行前向传递和自动分化。

参数:

keras_model: tf.keras.Model:一个未被编译的`tf.keras.Model`对象。

loss:损失函数

input_spec。一个`tf.TensorSpec`或`tff.Type`的结构,指定模型期望的参数类型。

如果`input_spec`是一个`tff.Type`,其叶子节点必须是`TensorType'。     

请注意,“input_spec”必须是两个元素的复合结构,指定输入的数据:生成预测的模型 (x)以及和标签(y)预期的类型。

如果作为列表提供,则必须按 [x, y] 顺序排列;如果作为字典提供,则键必须显式命名为“{}”和 `'{}'`.

loss_weights: (可选)一个Python浮点数的列表,用于加权每个模型输出的损失率。

metrics: (可选)一个`tf.keras.metrics.Metric`对象的列表,或者一个无参数的可调用对象的列表。

如果发现名称为 "num_examples "或 "num_batches "的指标,它们将被原样使用。

如果在`metrics`参数列表中没有找到它们,它们将被默认添加到`metrics`中。

默认使用`tff.learning.metrics.NumExamplesCounter`和`tff.learning.metrics.NumBatchesCounter'。

返回值:A tff.learning.Model` object.

注1:这个函数目前不接受子类化的`tf.keras.Models`

  因为它对某些属性的存在做了假设,而这些属性是通过函数式或序列式保证存在的。

  这些属性通过函数式或序列式API保证存在,但对于子类化的模型来说却不一定存在。

注2:如果`tf.keras.Model`包含BatchNormalization层,这个函数会引发一个UserWarning。

 因为批次的平均值和方差将被视为训练过程中不会被更新,可以考虑使用组归一化来代替。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值