JAX学习笔记(random)

本文介绍了JAX中PRNGKey的概念,它用于生成随机数种子。重点讲解了random.RandomState,它是numpy的一个伪随机数生成器,能够产生[0,1]区间内的随机数序列。在机器学习训练时,可以结合RandomState和permutation来随机打乱数据顺序,用于batch的选择。" 103729528,7521098,Ubuntu用户连续多次输错密码自动锁定机制,"['Linux系统', '安全配置', 'Ubuntu', '服务器管理']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

random.PRNGKey

from jax import random
key = random.PRNGKey(1)
print(key)

PRNGKey会生成一个(2,)shape array来作为seed的值

output: [0 1]

在未来需要生成随机数的时候,可以直接使用key值来作为seed,方便操作。

x = random.normal(key)
w = random.normal(key+1)
b = random.normal(key+2)

`

random.RandomState

RandomState其实是numpy中的语法,是一个伪随机数生成器,会生成[0,1]均匀分布的随机数序列,与seed()用法类似,但稍有不同。

具体代码用法:

rng = np.random.RandomState(1)
print(rng.rand(5))
output:[4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01
 1.46755891e-01]
rng = np.random.RandomState(1)
print(rng.rand(5))
output:[4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01

但要注意的是,RandomState必须在相同变量下使用,也就是上面的rng,否则就得不到相同的随机数组了。

在Machine Learning的训练过程中,可以结合permutation和RandomState来随机生成batch序列。

  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    #num_train 

以上code会打乱train input的序列顺序,方便batch选择。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值