深入理解Google DeepMind的Haiku框架基础

深入理解Google DeepMind的Haiku框架基础

dm-haiku JAX-based neural network library dm-haiku 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku

Haiku框架概述

Haiku是Google DeepMind开发的一个基于JAX的神经网络库,它巧妙地将面向对象的编程模型与JAX的函数式编程范式相结合。Haiku的设计理念是让常见的神经网络操作(如参数管理和状态维护)变得简单直观,同时保留JAX纯函数变换的全部能力。

核心概念:模块与转换

基本模块构建

在Haiku中,我们通过继承hk.Module类来构建自定义模块。下面是一个线性层的实现示例:

class MyLinear1(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size
    
    def __call__(self, x):
        j, k = x.shape[-1], self.output_size
        w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
        return jnp.dot(x, w) + b

这个模块定义了权重矩阵和偏置向量的初始化方式,并在调用时执行线性变换。

函数转换

由于JAX需要纯函数操作,我们需要使用hk.transform将模块转换为纯函数:

def _forward_fn_linear1(x):
    module = MyLinear1(output_size=2)
    return module(x)

forward_linear1 = hk.transform(_forward_fn_linear1)

转换后的函数对象包含两个关键方法:

  • init: 初始化网络参数
  • apply: 执行前向传播

参数初始化与推理

参数初始化

dummy_x = jnp.array([[1., 2., 3.]])
rng_key = jax.random.PRNGKey(42)
params = forward_linear1.init(rng=rng_key, x=dummy_x)

初始化过程需要提供一个随机数种子和样本输入(用于确定网络各层的形状)。

执行推理

sample_x = jnp.array([[1., 2., 3.]])
output = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)

对于确定性操作,可以使用hk.without_apply_rng简化调用:

forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
output = forward_without_rng.apply(x=sample_x, params=params)

状态管理

Haiku支持在模块中维护状态变量,这在实现如BatchNorm等需要统计运行信息的层时非常有用:

def stateful_f(x):
    counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
    multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
    hk.set_state("counter", counter + 1)
    return x + multiplier * counter

stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
params, state = stateful_forward.init(x=sample_x, rng=rng_key)
output, state = stateful_forward.apply(params, state, x=sample_x)

内置模块与模块组合

Haiku提供了许多常用层的实现(如MLP、卷积层等),可以方便地组合使用:

class MyModuleCustom(hk.Module):
    def __init__(self, output_size=2, name='custom_linear'):
        super().__init__(name=name)
        self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3])
        self._internal_linear_2 = MyLinear1(output_size=output_size)
    
    def __call__(self, x):
        return self._internal_linear_2(self._internal_linear_1(x))

随机数生成

对于需要随机性的操作(如Dropout),可以使用hk.next_rng_key()获取确定性随机数:

def stochastic_forward(x):
    key = hk.next_rng_key()
    return jax.random.normal(key, x.shape)

总结

Haiku框架通过以下特性简化了JAX上的神经网络开发:

  1. 面向对象的模块化设计
  2. 透明的参数管理
  3. 灵活的状态维护机制
  4. 与JAX生态系统的无缝集成
  5. 丰富的内置网络组件

这些特性使得Haiku成为在JAX上构建复杂神经网络模型的理想选择,既保持了JAX的函数式编程优势,又提供了直观的面向对象开发体验。

dm-haiku JAX-based neural network library dm-haiku 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

霍忻念

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值