How to use windows created by the Dataset.window() method in TensorFlow 2.0?
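Dataset.window() returns a dataset of datasets: each window is itself a small Dataset, not a tensor. To turn each window into a single tensor, call flat_map() and batch each window to its full size, like this: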
dataset = dataset.flat_map(lambda window: window.batch(5))
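To see why this step is needed, here is a small sketch that inspects one element before and after flat_map(). The variable names (windows, first_window, flat, first_tensor) are only illustrative, and it assumes eager execution, which is the TensorFlow 2 default.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
windows = dataset.window(5, shift=1, drop_remainder=True)

# Before flat_map(): each element is a nested Dataset, not a tensor.
first_window = next(iter(windows))
is_nested = isinstance(first_window, tf.data.Dataset)
first_values = [int(x.numpy()) for x in first_window]

# After flat_map() + batch(5): each element is one tensor of 5 values.
flat = windows.flat_map(lambda window: window.batch(5))
first_tensor = next(iter(flat))

print("Nested dataset:", is_nested, "values:", first_values)
print("Flattened element shape:", first_tensor.shape)
```

Batching with the same size as the window is what collapses each nested dataset into exactly one tensor.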
Now each item in the dataset is a single tensor holding one window of 5 elements, so you can split it into inputs and a target like this:
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
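The slicing can be checked on a plain Python list standing in for one window (an assumption made for illustration; slicing a 1-D tensor follows the same rule). Note that window[-1:] keeps the target as a length-1 sequence, whereas window[-1] would give a bare scalar.

```python
# A plain list standing in for one batched window (illustration only).
window = [0, 1, 2, 3, 4]

inputs = window[:-1]        # everything except the last element
target = window[-1:]        # the last element, kept as a length-1 sequence
scalar_target = window[-1]  # the last element as a bare scalar

print("Input:", inputs, "Target:", target, "Scalar target:", scalar_target)
```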
So the full code is:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
for X, y in dataset:
    print("Input:", X.numpy(), "Target:", y.numpy())
Which outputs:
Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]
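If you need this for other window sizes, the same steps can be wrapped in a function. This is only a sketch: make_windowed_dataset and its parameters are hypothetical names, not part of TensorFlow, and it assumes a 1-D series where the last element of each window is the target.

```python
import tensorflow as tf

def make_windowed_dataset(series, window_size):
    """Hypothetical helper: yields (inputs, target) pairs from a 1-D series."""
    dataset = tf.data.Dataset.from_tensor_slices(series)
    # Sliding windows of window_size elements; incomplete windows are dropped.
    dataset = dataset.window(window_size, shift=1, drop_remainder=True)
    # Collapse each nested window dataset into a single tensor.
    dataset = dataset.flat_map(lambda window: window.batch(window_size))
    # All but the last element are inputs; the last element is the target.
    return dataset.map(lambda window: (window[:-1], window[-1:]))

pairs = [(X.numpy().tolist(), y.numpy().tolist())
         for X, y in make_windowed_dataset(tf.range(6), window_size=3)]
for X, y in pairs:
    print("Input:", X, "Target:", y)
```

Passing window_size to both window() and batch() keeps the two in sync, which is the one thing that is easy to get wrong when the size is hard-coded in two places.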