
tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?


len(list(dataset)) works in eager mode, although that's obviously not a good general solution.
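For instance, a minimal sketch in eager mode (the dataset here is just a toy placeholder):

import tensorflow as tf

# Toy dataset used only for illustration; substitute your own pipeline.
dataset = tf.data.Dataset.range(100).batch(10)

# In eager mode, materializing the dataset gives its length directly.
# Note that this pulls every element through the pipeline, so it can be
# slow, and it never terminates for infinite/repeated datasets.
num_elements = len(list(dataset))
print(num_elements)  # 10 batches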


Take a look here: https://github.com/tensorflow/tensorflow/issues/26966

cardinality() doesn't work for TFRecord datasets (they report an unknown cardinality), but it works fine for other dataset types.

TL;DR:

num_elements = tf.data.experimental.cardinality(dataset).numpy()
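A minimal sketch of how the unknown-cardinality fallback could look, assuming eager execution (the example dataset is a placeholder):

import tensorflow as tf

# Placeholder dataset; swap in your own (e.g. a TFRecordDataset).
dataset = tf.data.Dataset.range(42)

cardinality = tf.data.experimental.cardinality(dataset)
if cardinality == tf.data.experimental.UNKNOWN_CARDINALITY:
    # TFRecord-backed datasets typically report an unknown cardinality,
    # so fall back to counting by iteration.
    num_elements = sum(1 for _ in dataset)
else:
    num_elements = cardinality.numpy()
print(num_elements)  # 42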


UPDATE:

Use tf.data.experimental.cardinality(dataset) - see here.


In the case of TensorFlow Datasets (tfds) you can use _, info = tfds.load(with_info=True) and then call info.splits['train'].num_examples. But even this doesn't work properly if you define your own split.
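For example, a sketch using 'mnist' purely as a placeholder dataset name:

import tensorflow_datasets as tfds

# 'mnist' is just a placeholder; use whatever tfds name you actually load.
dataset, info = tfds.load('mnist', with_info=True)
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
print(num_train_examples, num_test_examples)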

So you may either count your files or iterate over the dataset (as described in other answers):

num_training_examples = 0
num_validation_examples = 0
for example in training_set:
    num_training_examples += 1
for example in validation_set:
    num_validation_examples += 1