# tensorflow datasetのshuffle()、repeat()、batch()の使い方
# (How to use shuffle(), repeat() and batch() on a TensorFlow tf.data Dataset)
# 610 ワード (610 words)
import numpy as np
import tensorflow as tf
# Fixed seed so the printed values are the same on every run.
np.random.seed(0)

NUM_SAMPLES = 11  # deliberately not a multiple of BATCH_SIZE (see batch() below)
SHUFFLE_BUFFER = 2  # number of elements held in shuffle()'s buffer
BATCH_SIZE = 4  # rows per batch

# Toy input: NUM_SAMPLES rows of 2 random floats in [0.0, 1.0).
x = np.random.sample((NUM_SAMPLES, 2))
print(x)

# Make a dataset from the numpy array. from_tensor_slices() slices along the
# first axis, so each dataset element is one row of x (2 values).
dataset = tf.data.Dataset.from_tensor_slices(x)

# Shuffle using a buffer of only SHUFFLE_BUFFER elements. Each output is drawn
# at random from that small window, so rows move only slightly from their
# original order. A full uniform shuffle needs a buffer at least as large as
# the dataset.
dataset = dataset.shuffle(SHUFFLE_BUFFER)

# Group consecutive elements into batches of BATCH_SIZE rows. The final partial
# batch is kept (drop_remainder defaults to False). With NUM_SAMPLES = 11 and
# BATCH_SIZE = 4, one pass yields batches of 4, 4 and 3 rows.
dataset = dataset.batch(BATCH_SIZE)

# Repeat the batched dataset indefinitely (no count given).
# Order matters: because repeat() comes AFTER batch(), every pass ends with
# its own short batch, and no batch mixes rows from two passes. Calling
# repeat() before batch() would instead batch the endless row stream. That
# gives only full batches of BATCH_SIZE, some of which span the boundary
# between passes.
dataset = dataset.repeat()
# Create the iterator. Dataset.make_one_shot_iterator() and tf.Session are
# TF 1.x graph-mode APIs.
# NOTE(review): both are gone from the TF 2.x top-level API. There, iterate
# eagerly instead, e.g. `for value in dataset.take(6): print(value.numpy())`,
# or use tf.compat.v1 equivalents.
iterator = dataset.make_one_shot_iterator()  # renamed from `iter` to avoid shadowing the builtin
# Symbolic op: each sess.run() of it produces the next batch.
el = iterator.get_next()
with tf.Session() as sess:
    # Pull 6 batches. With the pipeline above (11 rows, batch(4), repeat()
    # after batch()), that is two full passes of 4 + 4 + 3 rows.
    for _ in range(6):
        value = sess.run(el)
        print(value)