When working with TensorFlow, you often need to draw random numbers from a normal distribution. There are usually two options: tf.random_normal and tf.truncated_normal.
a = tf.Variable(tf.random_normal(shape=[10,10],mean=0.0, stddev=1.0))
b = tf.Variable(tf.truncated_normal(shape=[10,10],mean=0.0, stddev=1.0))
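What the two functions have in common:
Both generate random numbers that follow a normal distribution. The shape parameter gives the shape of the output tensor, mean is the mean, and stddev is the standard deviation (for tf.truncated_normal, the standard deviation of the distribution before truncation).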
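How the two functions differ:
"Truncated" means cut short or with the ends removed. As the name suggests, tf.truncated_normal draws from the same distribution as tf.random_normal but discards part of it: any value more than two standard deviations from the mean is dropped and drawn again, so only values in the interval (μ-2σ, μ+2σ) are kept and the output tensor still has the requested shape. As anyone who has studied probability knows, the area under the normal curve over that interval is 95.449974%, so each individual draw has roughly a 4.55% chance of falling outside the interval and being rejected.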
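The 95.45% figure can be checked with the Python standard library alone. The following is a minimal sketch, assuming only the identity P(|Z| < k) = erf(k/√2) for a standard normal Z; the helper name prob_within is made up for this illustration and is not part of TensorFlow.

```python
import math

def prob_within(k):
    """Probability that a normal sample lies within k standard deviations of the mean."""
    return math.erf(k / math.sqrt(2.0))

kept = prob_within(2.0)   # fraction of draws inside (mu - 2*sigma, mu + 2*sigma)
rejected = 1.0 - kept     # fraction of draws outside that interval

print("kept: %.6f%%, rejected: %.2f%%" % (kept * 100, rejected * 100))
```

The two printed percentages should agree with the values quoted above, up to rounding.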
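An example:
import tensorflow as tf  # written against the TensorFlow 1.x API

# Two 10x10 variables drawn from N(0, 1): one plain, one truncated at two standard deviations
a = tf.Variable(tf.random_normal(shape=[10, 10], mean=0.0, stddev=1.0))
b = tf.Variable(tf.truncated_normal(shape=[10, 10], mean=0.0, stddev=1.0))
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # initialize both variables before reading them
    print(sess.run(a))
    print(sess.run(b))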
Output:
[[ 2.03344369 0.95904255 0.51551449 -0.64910495 -1.54670858 -0.36222935
-1.35218477 0.90432709 1.54785538 0.07634345]
[ 0.98553091 1.6959343 -0.23737453 -1.40345669 0.23556276 0.20473556
0.92165202 0.53387034 0.96706229 -0.60246515]
[-1.70365489 1.4985795 0.9192369 -1.93338513 1.97124398 0.0449861
-0.40005141 1.70940208 1.34348738 -0.41837004]
[ 0.9452104 -0.39224333 -0.5954296 -0.03549456 1.41368639 0.84007186
1.36516595 -0.0983744 0.22925216 -1.02553356]
[ 0.13267736 0.34456405 -0.18935715 0.62262297 0.8849178 0.27924481
0.98577827 -0.58259553 -0.2697261 0.29480031]
[-0.12653549 0.47688225 -0.99692768 0.67438906 -0.01882477 -0.73439389
-2.33074999 -0.03368849 -0.84927607 0.26850846]
[-0.64319742 0.64246458 -1.54543269 0.20438036 0.3918426 -0.34693053
-2.31869936 -0.92044121 0.25214902 1.07554746]
[ 1.18516409 0.9817906 0.81630504 -0.29319233 -0.03353639 1.10301745
0.9214775 -0.48573691 1.81510174 0.98491228]
[ 0.56139094 1.28018165 -0.1233808 0.17519344 0.1390176 -0.64548278
0.17056739 0.86900598 -0.39172679 -1.47325706]
[ 0.29449868 -1.48477566 2.23907995 -1.60770559 -0.70809013 0.01188888
-1.47039604 -2.38012648 -0.46762401 -0.12005028]]
[[ 0.08508326 -0.52920961 -1.3884685 1.35747743 -1.33630145 -0.18566781
1.52860987 0.90255803 -0.32859045 0.14518318]
[ 0.70132804 -0.39033541 1.26529586 0.21072282 0.04094095 -0.56563634
-0.23343883 1.36468518 -1.06854677 -0.93521482]
[-0.52809483 0.36721036 -1.3816942 -1.33670199 0.90239221 -1.23525608
0.13001908 -0.69113421 1.60243666 -1.76012647]
[ 0.17061962 0.99714231 -0.01754658 -0.00242918 0.03781764 -0.09301429
1.3679347 0.73169291 0.7708528 -0.42838347]
[ 0.53765416 0.8646881 -1.04237461 0.41709244 -0.65324241 -1.5069294
-1.15489745 -0.17940596 -0.01608775 -0.42601049]
[ 0.12074001 -1.06748223 -1.11893451 0.10148379 -0.25133476 0.13374318
-1.87533033 0.59769326 0.21703653 0.42207873]
[-0.16803861 0.80412644 -0.7267732 0.18619744 -0.15367813 1.57330763
-0.10576735 -0.14867106 0.02034684 -0.83152997]
[ 0.82828605 1.64573359 -1.28018653 0.57960075 -0.517268 -1.73364758
0.92954189 -0.50420767 -0.35316199 -1.57559478]
[ 1.25463605 1.08494401 -0.91807097 -0.27777147 0.99380928 0.76619095
-0.21474986 -1.63276267 -0.42645523 -1.36141109]
[-0.44800332 1.99873173 0.42046729 1.299456 0.0731993 1.16181815
-0.732741 0.48478526 0.0165507 -0.73229623]]
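It is easy to see that every value produced by tf.truncated_normal() lies within (-2, 2) (the standard deviation here is 1). The output of tf.random_normal() is an ordinary, untruncated normal sample, and the first matrix does contain values outside that range, such as 2.03344369, 2.23907995 and -2.38012648.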
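The drop-and-redraw behavior can be imitated in plain Python to make the difference concrete. This is only a sketch of the idea, not TensorFlow's actual implementation: truncated_normal_sample is a hypothetical helper, and the fixed seed and sample size of 10000 are arbitrary choices for the illustration.

```python
import random

def truncated_normal_sample(mean=0.0, stddev=1.0, rng=random):
    """Draw from N(mean, stddev^2), redrawing until the value is within 2 standard deviations."""
    while True:
        x = rng.gauss(mean, stddev)
        if abs(x - mean) <= 2.0 * stddev:
            return x

rng = random.Random(0)  # fixed seed so repeated runs give the same numbers

# 10000 plain draws and 10000 truncated draws from N(0, 1)
plain = [rng.gauss(0.0, 1.0) for _ in range(10000)]
truncated = [truncated_normal_sample(0.0, 1.0, rng) for _ in range(10000)]

print("plain     min/max:", min(plain), max(plain))
print("truncated min/max:", min(truncated), max(truncated))
```

With this many samples the plain draws should stray beyond ±2, while the truncated draws cannot, which mirrors the two matrices printed above.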