from keras.utils import multi_gpu_model

# Replicates `model` on 8 GPUs.
# This assumes that your machine has 8 available GPUs.
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer='rmsprop')

# This `fit` call will be distributed on 8 GPUs.
# Since the batch size is 256, each GPU will process 32 samples.
parallel_model.fit(x, y, epochs=20, batch_size=256)
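The `multi_gpu_model` snippet assumes that `model`, `x`, and `y` are already defined elsewhere. A minimal sketch of those prerequisites, using an illustrative dense classifier and random data (the architecture and shapes here are assumptions, not part of the original example), might look like:

import numpy as np
import keras
from keras import layers

# Illustrative single-device model that multi_gpu_model would replicate
# (hypothetical architecture, shown only to make the snippet above runnable).
model = keras.models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(784,)),
    layers.Dense(10, activation='softmax'),
])

# Dummy data standing in for `x` and `y` above.
x = np.random.random((2048, 784))
y = keras.utils.to_categorical(np.random.randint(10, size=(2048,)), num_classes=10)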
Device parallelism code
import tensorflow as tf
import keras

# Model where a shared LSTM is used to encode two different sequences in parallel
input_a = keras.Input(shape=(140, 256))
input_b = keras.Input(shape=(140, 256))

shared_lstm = keras.layers.LSTM(64)

# Process the first sequence on one GPU
with tf.device('/gpu:0'):
    encoded_a = shared_lstm(input_a)

# Process the next sequence on another GPU
with tf.device('/gpu:1'):
    encoded_b = shared_lstm(input_b)

# Concatenate the results on the CPU
with tf.device('/cpu:0'):
    merged_vector = keras.layers.concatenate([encoded_a, encoded_b],
                                             axis=-1)
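The shared-LSTM graph above stops at the merged vector. To make it trainable end to end, the merged tensor would typically be fed into an output layer and wrapped in a `Model`; a minimal continuation (the sigmoid head and compile settings below are assumptions, not part of the original example) could be:

# Hypothetical classification head on top of the merged encoding.
predictions = keras.layers.Dense(1, activation='sigmoid')(merged_vector)

model = keras.models.Model(inputs=[input_a, input_b], outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])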