【从0开始】使用Flax NNX API 构建简单神经网络并训练

与 Linen API 不同，NNX 对初学者更加友好，使用体验也更接近 PyTorch。

任务

使用 MLP 拟合简单函数：
$$y = 2x^2 + 1,\quad x \in [0, 10)$$

代码

import jax.numpy as jnp
import jax.random as jrm
import optax as ox
from jax import Array
from flax import nnx
from typing import Generator


class Network(nnx.Module):
    """Three-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        in_dim: number of input features fed to the first Linear layer.
        out_dim: number of output features produced by the last Linear layer.
        rng: NNX RNG streams used to initialise all three Linear layers.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, in_dim: int, out_dim: int, rng: nnx.Rngs, hidden_dim: int):
        super().__init__()
        # Layers are built input-to-output so they draw from `rng` in a fixed
        # order, which keeps parameter initialisation reproducible per seed.
        self.linear1 = nnx.Linear(in_dim, hidden_dim, rngs=rng)
        self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rng)
        self.linear3 = nnx.Linear(hidden_dim, out_dim, rngs=rng)

    def __call__(self, x) -> Array:
        """Forward pass; the output layer is left linear (no activation)."""
        hidden = nnx.relu(self.linear1(x))
        hidden = nnx.relu(self.linear2(hidden))
        return self.linear3(hidden)


def make_dataset(
    X: Array, Y: Array, batch: int, seed: int = 0
) -> Generator[tuple[jnp.ndarray, jnp.ndarray], None, None]:
    """Yield an endless stream of random (x, y) minibatches.

    Each minibatch draws `batch` (x, y) pairs from the dataset, with
    replacement (the default of `jax.random.choice`).

    Args:
        X: inputs; assumed 1-D and the same length as Y (TODO confirm for
            other shapes).
        Y: targets matching X element-wise.
        batch: number of pairs per minibatch.
        seed: seed for the PRNG key that drives sampling.

    Yields:
        (x, y) tuple; for 1-D X and Y each array has shape (batch, 1).
    """
    # (N, 2, 1): pairs kept together so x and y are sampled jointly; the
    # trailing axis gives each sample a feature dimension of size 1.
    combined = jnp.stack((X, Y), axis=1)[..., None]
    key = jrm.key(seed)
    while True:
        # Split every step: reusing one key would return the same batch forever.
        key, subkey = jrm.split(key)
        selected = jrm.choice(subkey, combined, shape=(batch,))
        yield selected[:, 0], selected[:, 1]


def loss_fn(model: Network, batch):
    """Return the mean of optax's l2_loss between model(x) and y.

    `batch` is an (inputs, targets) pair; optax's l2_loss is the halved squared error.
    """
    inputs, targets = batch
    per_sample = ox.l2_loss(model(inputs), targets)
    return jnp.mean(per_sample)


# hyper parameters
seed = 0
batch = 16
learning_rate = 0.001
train_steps = 6000  # loop breaks after step index `train_steps` -> 6001 updates

# make dataset: samples of y = 2x^2 + 1 on [0, 10) with spacing 0.005
X = jnp.arange(0, 10, 0.005)
Y = 2 * X**2 + 1.0

# build model & optimizer
model = Network(1, 1, hidden_dim=20, rng=nnx.Rngs(seed))
# The original positional 0.90 binds to adamw's `b1`; spelled out for clarity.
optimizer = nnx.Optimizer(model, ox.adamw(learning_rate, b1=0.90))


@nnx.jit
def train_step(net: Network, opt: nnx.Optimizer, data) -> Array:
    """Apply one AdamW update to `net` (in place via `opt`); return the loss.

    Compiled once by nnx.jit instead of dispatching every op eagerly per step.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(net, data)
    opt.update(grads)
    return loss


# train
for i, (x, y) in enumerate(make_dataset(X, Y, batch, seed=seed)):
    loss = train_step(model, optimizer, (x, y))
    print(i, loss)
    if i >= train_steps:
        break

依赖版本如下：

absl-py==2.1.0
chex==0.1.88
etils==1.11.0
flax==0.10.2
fsspec==2025.2.0
humanize==4.11.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.2
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.2
protobuf==3.20.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.1
simplejson==3.19.3
tensorstore==0.1.71
toolz==1.0.0
typing-extensions==4.12.2
zipp==3.21.0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值