While learning deep learning, I noticed that many examples use mini-batch training: each iteration draws, say, 50 or 100 samples from the dataset. The benefit is faster computation, but in my tests an overly large batch size tended to produce less accurate results and a poorer fit. The following example illustrates this:
import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime

starttime = datetime.datetime.now()

def generate(sample_size, mean, cov, diff, regression):
    # Build a two-class Gaussian data set: one cluster at `mean`,
    # one shifted by each offset in `diff`.
    num_class = 2
    sample_per_class = int(sample_size / 2)
    X0 = np.random.multivariate_normal(mean, cov, sample_per_class)
    Y0 = np.zeros(sample_per_class)
    for ci, d in enumerate(diff):
        X1 = np.random.multivariate_normal(mean + d, cov, sample_per_class)
        Y1 = (ci + 1) * np.ones(sample_per_class)
        X0 = np.concatenate((X0, X1))
        Y0 = np.concatenate((Y0, Y1))
    if regression == False:
        # One-hot encode the labels for classification.
        class_ind = [Y0 == class_number for class_number in range(num_class)]
        Y0 = np.asarray(np.column_stack(class_ind), dtype=np.float32)
    X, Y = shuffle(X0, Y0)
    return X, Y

np.random.seed(10)
num_classes = 2
mean = np.random.randn(num_classes)
cov = np.eye(num_classes)
# regression=True keeps the labels as scalar 0/1 values, which is what the
# sigmoid cross-entropy loss below expects (lab_dim=1).
X, Y = generate(1000, mean, cov, [3.0], True)

lab_dim = 1
input_dim = 2
input_features = tf.placeholder(tf.float32, [None, input_dim])
input_labels = tf.placeholder(tf.float32, [None, lab_dim])
W = tf.Variable(tf.random_normal([input_dim, lab_dim]), name='weight')
b = tf.Variable(tf.zeros([lab_dim]), name='bias')
output = tf.nn.sigmoid(tf.matmul(input_features, W) + b)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=input_labels, logits=tf.matmul(input_features, W) + b))
optimizer = tf.train.AdamOptimizer(0.01)
train = optimizer.minimize(loss)

maxEpochs = 50
minibatchSize = 5  # 25 for the first run below, 5 for the second

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(maxEpochs):
        for i in range(len(Y) // minibatchSize):
            x1 = X[i * minibatchSize:(i + 1) * minibatchSize, :]
            y1 = np.reshape(Y[i * minibatchSize:(i + 1) * minibatchSize], [-1, 1])
            _, lossval, outputval = sess.run([train, loss, output],
                feed_dict={input_features: x1, input_labels: y1})
        print(epoch, "----", lossval)
    print("Finished")

endtime = datetime.datetime.now()
print("Elapsed Time:", (endtime - starttime).seconds)
This example generates 1,000 samples and trains for 50 epochs. With the mini-batch size set to 25, the program prints:
0 ---- 0.8833265
1 ---- 0.7246625
2 ---- 0.56082326
3 ---- 0.4263128
4 ---- 0.3264776
5 ---- 0.25968412
6 ---- 0.21754791
7 ---- 0.19074523
8 ---- 0.17274845
9 ---- 0.15978481
10 ---- 0.14979471
11 ---- 0.14165106
12 ---- 0.13472195
13 ---- 0.12864223
14 ---- 0.12319314
15 ---- 0.11823835
16 ---- 0.11368922
17 ---- 0.109484985
18 ---- 0.10558146
19 ---- 0.101945356
20 ---- 0.0985498
21 ---- 0.0953724
22 ---- 0.092394
23 ---- 0.08959783
24 ---- 0.08696909
25 ---- 0.08449426
26 ---- 0.08216138
27 ---- 0.07995938
28 ---- 0.0778783
29 ---- 0.07590899
30 ---- 0.07404316
31 ---- 0.07227333
32 ---- 0.070592485
33 ---- 0.06899432
34 ---- 0.06747306
35 ---- 0.06602346
36 ---- 0.06464061
37 ---- 0.06331999
38 ---- 0.062057685
39 ---- 0.060849763
40 ---- 0.059692908
41 ---- 0.058583792
42 ---- 0.057519644
43 ---- 0.05649774
44 ---- 0.05551556
45 ---- 0.05457068
46 ---- 0.05366121
47 ---- 0.052785005
48 ---- 0.05194026
49 ---- 0.05112519
Finished
Elapsed Time: 3
With the mini-batch size set to 5, the program prints:
0 ---- 2.4122727
1 ---- 1.2697283
2 ---- 0.56022817
3 ---- 0.26786795
4 ---- 0.15360412
5 ---- 0.10067397
6 ---- 0.071917914
7 ---- 0.054414533
8 ---- 0.04286511
9 ---- 0.034778215
10 ---- 0.028854271
11 ---- 0.02435914
12 ---- 0.020851564
13 ---- 0.018052543
14 ---- 0.015778149
15 ---- 0.013902565
16 ---- 0.01233704
17 ---- 0.0110171195
18 ---- 0.009894882
19 ---- 0.008933785
20 ---- 0.008105488
21 ---- 0.007387633
22 ---- 0.0067623234
23 ---- 0.006215036
24 ---- 0.0057339044
25 ---- 0.005309134
26 ---- 0.0049326196
27 ---- 0.0045975544
28 ---- 0.004298257
29 ---- 0.004029934
30 ---- 0.0037885192
31 ---- 0.0035705771
32 ---- 0.0033731621
33 ---- 0.0031937952
34 ---- 0.0030303255
35 ---- 0.0028808962
36 ---- 0.002743928
37 ---- 0.0026180497
38 ---- 0.0025020558
39 ---- 0.0023949211
40 ---- 0.0022957386
41 ---- 0.0022037134
42 ---- 0.0021181535
43 ---- 0.0020384365
44 ---- 0.0019640164
45 ---- 0.0018944215
46 ---- 0.0018292269
47 ---- 0.0017680468
48 ---- 0.0017105412
49 ---- 0.0016564125
Finished
Elapsed Time: 14
Comparing the two runs: shrinking the batch size to one fifth lowered the final loss by a factor of roughly 30 (0.051 vs. 0.0017 — admittedly a loose way to talk about "accuracy"), while the training time grew by almost a factor of 5 (3 s vs. 14 s). Which trade-off to make depends on the size of the training set and on how accurate the result needs to be.
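Much of the gap comes down to the number of parameter updates per epoch: with 1,000 samples, a batch of 25 yields 40 updates per epoch, while a batch of 5 yields 200. The sketch below re-creates the same experiment in plain NumPy (a simplified stand-in, not the TensorFlow script above: plain gradient descent instead of Adam, a hypothetical `train_logreg` helper, and fixed cluster means) so the update counts and the effect of batch size are easy to verify:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_logreg(X, Y, batch_size, epochs=50, lr=0.01):
    # Plain mini-batch gradient descent on binary cross-entropy.
    rng = np.random.default_rng(0)
    w = rng.normal(size=X.shape[1])
    b = 0.0
    updates = 0
    for _ in range(epochs):
        for i in range(len(Y) // batch_size):
            xb = X[i * batch_size:(i + 1) * batch_size]
            yb = Y[i * batch_size:(i + 1) * batch_size]
            p = sigmoid(xb @ w + b)
            w -= lr * (xb.T @ (p - yb)) / batch_size
            b -= lr * np.mean(p - yb)
            updates += 1
    # Final mean cross-entropy over the whole set.
    p = sigmoid(X @ w + b)
    eps = 1e-12
    loss = -np.mean(Y * np.log(p + eps) + (1 - Y) * np.log(1 - p + eps))
    return loss, updates

# Two well-separated Gaussian clusters, roughly as in the example above.
rng = np.random.default_rng(10)
X = np.concatenate([rng.normal(0.0, 1.0, (500, 2)),
                    rng.normal(3.0, 1.0, (500, 2))])
Y = np.concatenate([np.zeros(500), np.ones(500)])
idx = rng.permutation(1000)
X, Y = X[idx], Y[idx]

loss25, n25 = train_logreg(X, Y, batch_size=25)
loss5, n5 = train_logreg(X, Y, batch_size=5)
print(n25, n5)          # 2000 vs. 10000 updates over the 50 epochs
print(loss25, loss5)    # the smaller batch typically ends at a lower loss
```

The 5x difference in update count mirrors the roughly 5x difference in wall-clock time seen above: each update carries fixed overhead, so more, smaller batches trade throughput for more optimization steps.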