How to use windows created by the Dataset.window() method in TensorFlow 2.0?


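Dataset.window() returns a dataset of windows, but each window is itself a small Dataset rather than a tensor, so you cannot slice it directly. The solution is to call flat_map() and batch each window into a single tensor: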

dataset = dataset.flat_map(lambda window: window.batch(5))
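
To see which values end up in each window, here is a plain-Python sketch that mimics the grouping rules of window() (size, shift, stride, drop_remainder) as described in the tf.data documentation. The function name window_groups is hypothetical and this is not TensorFlow code: it returns plain lists, whereas TensorFlow yields each group as a nested Dataset, which is exactly why the flat_map()/batch() step above is needed.

```python
def window_groups(elements, size, shift=None, stride=1, drop_remainder=False):
    """Hypothetical plain-Python mimic of how Dataset.window() groups elements.

    Returns a list of lists. In TensorFlow each group would instead be a
    nested tf.data.Dataset that has to be batched before it can be sliced.
    """
    items = list(elements)
    # tf.data defaults shift to size (non-overlapping windows)
    shift = size if shift is None else shift
    groups = []
    # A new window starts every `shift` elements while input remains
    for start in range(0, len(items), shift):
        # Take up to `size` elements, stepping by `stride`
        group = items[start:start + size * stride:stride]
        if drop_remainder and len(group) < size:
            continue  # incomplete trailing windows are skipped
        groups.append(group)
    return groups


print(window_groups(range(10), 5, shift=1, drop_remainder=True)[:2])
# [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]
```

With drop_remainder=True the call above yields six windows, from [0, 1, 2, 3, 4] to [5, 6, 7, 8, 9]. Without it, the shorter trailing windows [6, 7, 8, 9], [7, 8, 9], [8, 9] and [9] would also appear, which would break the fixed-size split below.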

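Now each item in the dataset is a 1-D tensor holding the 5 values of one window, so you can split it into inputs (the first 4 values) and a target (the last value) like this: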

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

So the full code is:

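import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
# Each element becomes a nested Dataset holding 5 consecutive values
dataset = dataset.window(5, shift=1, drop_remainder=True)
# Turn every window into a single tensor of shape (5,)
dataset = dataset.flat_map(lambda window: window.batch(5))
# Split into inputs (first 4 values) and target (last value)
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
for X, y in dataset:
    print("Input:", X.numpy(), "Target:", y.numpy())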

Which outputs:

Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]
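
If you want to unit-test a pipeline like this, one option is to compute the expected (input, target) pairs in plain Python and compare them with what the dataset yields. The helper expected_pairs below is hypothetical (not part of tf.data) and assumes the same settings as the code above: shift=1, drop_remainder=True, and a single-value target.

```python
def expected_pairs(values, window_size):
    """Hypothetical helper: the (input, target) pairs the tf.data pipeline
    should produce, computed in plain Python for use in tests."""
    values = list(values)
    pairs = []
    # Full windows only, mirroring window(window_size, shift=1, drop_remainder=True)
    for start in range(len(values) - window_size + 1):
        window = values[start:start + window_size]
        # Same split as window[:-1], window[-1:] in the TensorFlow code
        pairs.append((window[:-1], window[-1:]))
    return pairs


for inputs, target in expected_pairs(range(10), 5):
    print("Input:", inputs, "Target:", target)
```

With TensorFlow available, the check could then be written as: actual = [(X.numpy().tolist(), y.numpy().tolist()) for X, y in dataset], followed by assert actual == expected_pairs(range(10), 5).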