Google ColaboratoryのTPUでKerasを動かす


※公式ページを参考にしました.

準備

%tensorflow_version 2.x
import tensorflow as tf
# Report which TensorFlow version this runtime is using.
print(f"Tensorflow version {tf.__version__}")

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
  # Raise an Exception subclass rather than a bare BaseException so that
  # ordinary `except Exception` handlers still see the failure, and chain
  # the original ValueError so the root cause stays in the traceback.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err

# Connect to the resolved TPU cluster and initialize the TPU system, then
# build the distribution strategy that is used later to create the model.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)
def create_model():
  """Build and return the model to be trained (placeholder skeleton).

  NOTE(review): as written, nothing in this function assigns `model`; the
  model definition must be filled in below before the function is called.
  """
  # Define the model here.
  return model
# Create the model inside the TPU strategy's scope.
with tpu_strategy.scope():
  model = create_model()

学習部分

TPUを使わない場合と同様に,`model.fit`で学習を行う.

# Train the model and keep the value returned by fit() in `history`.
# x_train / y_train / x_test / y_test are not defined in this snippet and
# are expected to have been prepared beforehand.
history = model.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=2048,
    validation_data=(x_test, y_test),
)