random.PRNGKey
from jax import random
key = random.PRNGKey(1)
print(key)
PRNGKey会生成一个(2,)shape array来作为seed的值
output: [0 1]
在未来需要生成随机数的时候,可以直接使用key值来作为seed,方便操作。
x = random.normal(key)
w = random.normal(key+1)
b = random.normal(key+2)
`
random.RandomState
RandomState其实是numpy中的语法,是一个伪随机数生成器,会生成[0,1]均匀分布的随机数序列,与seed()用法类似,但稍有不同。
具体代码用法:
rng = np.random.RandomState(1)
print(rng.rand(5))
output:[4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01
1.46755891e-01]
rng = np.random.RandomState(1)
print(rng.rand(5))
output:[4.17022005e-01 7.20324493e-01 1.14374817e-04 3.02332573e-01
但要注意的是,RandomState必须在相同变量下使用,也就是上面的rng,否则就得不到相同的随机数组了。
在Machine Learning的训练过程中,可以结合permutation和RandomState来随机生成batch序列。
rng = npr.RandomState(0)
while True:
perm = rng.permutation(num_train)
#num_train
以上code会打乱train input的序列顺序,方便batch选择。