深入理解Google DeepMind的Haiku框架基础
dm-haiku JAX-based neural network library 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku
Haiku框架概述
Haiku是Google DeepMind开发的一个基于JAX的神经网络库,它巧妙地将面向对象的编程模型与JAX的函数式编程范式相结合。Haiku的设计理念是让常见的神经网络操作(如参数管理和状态维护)变得简单直观,同时保留JAX纯函数变换的全部能力。
核心概念:模块与转换
基本模块构建
在Haiku中,我们通过继承hk.Module
类来构建自定义模块。下面是一个线性层的实现示例:
class MyLinear1(hk.Module):
def __init__(self, output_size, name=None):
super().__init__(name=name)
self.output_size = output_size
def __call__(self, x):
j, k = x.shape[-1], self.output_size
w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
return jnp.dot(x, w) + b
这个模块定义了权重矩阵和偏置向量的初始化方式,并在调用时执行线性变换。
函数转换
由于JAX需要纯函数操作,我们需要使用hk.transform
将模块转换为纯函数:
def _forward_fn_linear1(x):
module = MyLinear1(output_size=2)
return module(x)
forward_linear1 = hk.transform(_forward_fn_linear1)
转换后的函数对象包含两个关键方法:
init
: 初始化网络参数apply
: 执行前向传播
参数初始化与推理
参数初始化
dummy_x = jnp.array([[1., 2., 3.]])
rng_key = jax.random.PRNGKey(42)
params = forward_linear1.init(rng=rng_key, x=dummy_x)
初始化过程需要提供一个随机数种子和样本输入(用于确定网络各层的形状)。
执行推理
sample_x = jnp.array([[1., 2., 3.]])
output = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)
对于确定性操作,可以使用hk.without_apply_rng
简化调用:
forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
output = forward_without_rng.apply(x=sample_x, params=params)
状态管理
Haiku支持在模块中维护状态变量,这在实现如BatchNorm等需要统计运行信息的层时非常有用:
def stateful_f(x):
counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
hk.set_state("counter", counter + 1)
return x + multiplier * counter
stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
params, state = stateful_forward.init(x=sample_x, rng=rng_key)
output, state = stateful_forward.apply(params, state, x=sample_x)
内置模块与模块组合
Haiku提供了许多常用层的实现(如MLP、卷积层等),可以方便地组合使用:
class MyModuleCustom(hk.Module):
def __init__(self, output_size=2, name='custom_linear'):
super().__init__(name=name)
self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3])
self._internal_linear_2 = MyLinear1(output_size=output_size)
def __call__(self, x):
return self._internal_linear_2(self._internal_linear_1(x))
随机数生成
对于需要随机性的操作(如Dropout),可以使用hk.next_rng_key()
获取确定性随机数:
def stochastic_forward(x):
key = hk.next_rng_key()
return jax.random.normal(key, x.shape)
总结
Haiku框架通过以下特性简化了JAX上的神经网络开发:
- 面向对象的模块化设计
- 透明的参数管理
- 灵活的状态维护机制
- 与JAX生态系统的无缝集成
- 丰富的内置网络组件
这些特性使得Haiku成为在JAX上构建复杂神经网络模型的理想选择,既保持了JAX的函数式编程优势,又提供了直观的面向对象开发体验。
dm-haiku JAX-based neural network library 项目地址: https://gitcode.com/gh_mirrors/dm/dm-haiku
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考