def _aligned_length(data):
    """Return the common length of the arrays in ``data``, validating alignment."""
    if len(data) == 0:
        raise ValueError("data must contain at least one array")
    lengths = [len(d) for d in data]
    if any(n != lengths[0] for n in lengths):
        raise ValueError(f"arrays in data must have equal length, got {lengths}")
    return lengths[0]


def shuffle_aligned_list(data, rng=None):
    """Shuffle arrays in a list by shuffling each array identically.

    Every array in ``data`` is reordered along its first axis with the same
    random permutation, so row ``i`` of one array still corresponds to row
    ``i`` of every other array afterwards.

    Args:
        data: Non-empty list of arrays that support ``len`` and indexing with
            an integer index array (e.g. NumPy arrays), all of equal length.
        rng: Optional ``numpy.random.Generator`` / ``RandomState`` used to
            draw the permutation. When ``None`` the global ``np.random`` state
            is used, so ``np.random.seed`` keeps controlling the result.

    Returns:
        A new list of the shuffled arrays, in the same order as ``data``.

    Raises:
        ValueError: If ``data`` is empty or its arrays differ in length
            (previously a longer array was silently truncated).
    """
    num = _aligned_length(data)
    permute = np.random.permutation if rng is None else rng.permutation
    p = permute(num)
    return [d[p] for d in data]


def _iter_batches(data, batch_size, shuffle, rng):
    """Yield aligned ``batch_size`` slices of ``data`` forever, reshuffling each epoch."""
    if shuffle:
        data = shuffle_aligned_list(data, rng)
    batch_count = 0
    while True:
        # Not enough rows left for a full batch: start a new epoch. The
        # leftover rows (fewer than batch_size) are skipped for this epoch.
        if batch_count * batch_size + batch_size > len(data[0]):
            batch_count = 0
            if shuffle:
                data = shuffle_aligned_list(data, rng)
        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        yield [d[start:end] for d in data]


def batch_generator(data, batch_size, shuffle=True, rng=None):
    """Generate batches of data.

    Given a list of array-like objects, generate batches of a given
    size by yielding a list of array-like objects corresponding to the
    same slice of each input. The stream is infinite: once too few rows
    remain for a full batch, the remainder is dropped and a new epoch
    starts (reshuffled when ``shuffle`` is true). If ``batch_size`` exceeds
    the number of rows, every yielded batch is the whole dataset.

    Args:
        data: Non-empty list of equal-length array-like objects.
        batch_size: Number of rows per batch; must be at least 1.
        shuffle: Shuffle all arrays identically before each epoch.
        rng: Optional NumPy random generator passed to
            ``shuffle_aligned_list``; ``None`` uses the global state.

    Returns:
        A generator yielding lists of aligned batch slices.

    Raises:
        ValueError: Immediately (not on first ``next``) if ``batch_size`` < 1,
            ``data`` is empty, or the arrays differ in length.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Validate eagerly; a generator function would defer this to first next().
    _aligned_length(data)
    return _iter_batches(data, batch_size, shuffle, rng)
# Smoke-check the generator: draw a single batch and report the shape of
# each aligned piece.
batch_size = 128
data_gen = batch_generator(data=[x_train, y_train], batch_size=batch_size)
x_batch, y_batch = next(data_gen)
print(*(arr.shape for arr in (x_batch, y_batch)))