Unlike the Linen API, NNX is much simpler for beginners to pick up, and the experience is much closer to PyTorch.
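To make the contrast concrete, here is a minimal sketch (my own illustration, not part of the original post) of the same idea in both APIs: Linen separates parameter creation (init) from the forward pass (apply) and threads the parameters around explicitly, while NNX modules build their parameters eagerly in __init__ and are called directly, much like a PyTorch nn.Module.

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

# Linen: parameters live outside the module and are threaded explicitly.
class LinenMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(8)(x)))

x = jnp.ones((4, 1))
linen_model = LinenMLP()
params = linen_model.init(jax.random.key(0), x)  # explicit init step
y = linen_model.apply(params, x)                 # explicit apply step

# NNX: parameters are created eagerly and stored on the instance.
class NnxMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(1, 8, rngs=rngs)
        self.fc2 = nnx.Linear(8, 1, rngs=rngs)

    def __call__(self, x):
        return self.fc2(nnx.relu(self.fc1(x)))

nnx_model = NnxMLP(nnx.Rngs(0))
y = nnx_model(x)  # call directly, PyTorch style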
Task
Fit a simple function with an MLP:

y = 2x^2 + 1
Code
import jax.numpy as jnp
import jax.random as jrm
import optax as ox
from jax import Array
from flax import nnx
from collections.abc import Generator


class Network(nnx.Module):
    """A simple three-layer MLP."""

    def __init__(self, in_dim: int, out_dim: int, rng: nnx.Rngs, hidden_dim: int):
        super().__init__()
        self.linear1 = nnx.Linear(in_dim, hidden_dim, rngs=rng)
        self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rng)
        self.linear3 = nnx.Linear(hidden_dim, out_dim, rngs=rng)

    def __call__(self, x) -> Array:
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = nnx.relu(x)
        x = self.linear3(x)
        return x


def make_dataset(
    X: Array, Y: Array, batch: int, seed: int = 0
) -> Generator[tuple[Array, Array], None, None]:
    """Yield random mini-batches of (x, y) pairs forever."""
    # Shape (N, 2, 1): pair every x with its target y, keeping a feature axis
    # so each sample matches the (batch, 1) input expected by nnx.Linear(1, ...).
    combined = jnp.stack((X, Y), axis=1)[..., None]
    key = jrm.key(seed)
    while True:
        # Split the key on every iteration; reusing the same key would
        # yield the identical batch forever.
        key, subkey = jrm.split(key)
        selected = jrm.choice(subkey, combined, shape=(batch,))
        yield selected[:, 0], selected[:, 1]


def loss_fn(model: Network, batch):
    x, y = batch
    predicted = model(x)
    # optax.l2_loss computes 0.5 * (predicted - y) ** 2 element-wise.
    return ox.l2_loss(predicted, y).mean()


# hyperparameters
seed = 0
batch = 16

# make the dataset
X = jnp.arange(0, 10, 0.005)
Y = 2 * X**2 + 1.0

# build the model & optimizer (AdamW with lr=1e-3, b1=0.9)
model = Network(1, 1, hidden_dim=20, rng=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, ox.adamw(0.001, 0.90))

# train
for i, (x, y) in enumerate(make_dataset(X, Y, batch)):
    loss, grads = nnx.value_and_grad(loss_fn)(model, (x, y))
    optimizer.update(grads)  # applies the gradients to the model in place
    print(i, loss)
    if i >= 6000:
        break
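One thing the loop above leaves out: nothing is compiled, so every Python-level iteration re-traces the model and training is slow. A minimal sketch of the usual NNX pattern for compiling the update with nnx.jit follows; the train_step name is my own, and it assumes the Network, make_dataset, and loss setup defined above.

@nnx.jit  # compile the whole update; model and optimizer state are mutated in place
def train_step(model: Network, optimizer: nnx.Optimizer, x: Array, y: Array) -> Array:
    def batch_loss(m: Network) -> Array:
        return ox.l2_loss(m(x), y).mean()

    loss, grads = nnx.value_and_grad(batch_loss)(model)
    optimizer.update(grads)
    return loss

for i, (x, y) in enumerate(make_dataset(X, Y, batch)):
    loss = train_step(model, optimizer, x, y)
    if i % 500 == 0:
        print(i, loss)
    if i >= 6000:
        break

As a quick sanity check after training, model(jnp.array([[2.0]])) should land near 2 * 2**2 + 1 = 9.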
Dependencies:
absl-py==2.1.0
chex==0.1.88
etils==1.11.0
flax==0.10.2
fsspec==2025.2.0
humanize==4.11.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.2
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.2
protobuf==3.20.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.1
simplejson==3.19.3
tensorstore==0.1.71
toolz==1.0.0
typing-extensions==4.12.2
zipp==3.21.0