利用tensorflow训练数据时,要自己手动给网络“喂”数据,在使用自己的数据集时,有时要手动选取batch数据,可利用如下代码选取随机batch
def random_batch(X_train, y_train, batch_size):
rnd_indices = np.random.randint(0, len(X_train), batch_size)
X_batch = X_train[rnd_indices]
y_batch = y_train[rnd_indices]
return X_batch, y_batch
注意最后“喂”的数据是numpy array类型,不用转换为tensor。这样选取的数据有的数据可能从来未被选取过,有的数据则被多次选取,不过这在训练过程中影响不大。