How to map a function with additional parameter using the new Dataset api in TF1.3?


Here is an example using a lambda expression to wrap the function to which we want to pass an argument:

    import tensorflow as tf

    def fun(x, arg):
        return x * arg

    my_arg = tf.constant(2, dtype=tf.int64)
    ds = tf.data.Dataset.range(5)
    ds = ds.map(lambda x: fun(x, my_arg))

In the above, the signature of the function provided to map must match the structure of the dataset's elements, so the lambda expression has to accept exactly those components. Here that is simple: each element has a single component, x, which takes the values 0 to 4.

If necessary, you can pass in an arbitrary number of external arguments from outside the dataset: ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)), and so on.

To verify that the above works, we can observe that the mapping indeed multiplies each dataset element by two:

    iterator = ds.make_initializable_iterator()
    next_x = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        while True:
            try:
                print(sess.run(next_x))
            except tf.errors.OutOfRangeError:
                break

The output:

    0
    2
    4
    6
    8


You can also use functools.partial to bind your parameter instead:

    def _parse_function(arg1, example_proto):
        features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
                    "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features["image"], parsed_features["label"]

Note that the extra parameter comes first in the signature: partial binds positional arguments from the left, so the dataset element is passed as the remaining argument. You can then bind a value to the parameter as follows:

    from functools import partial

    arg1 = ...
    dataset = dataset.map(partial(_parse_function, arg1))
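If reordering the parameters is inconvenient, partial can also bind the extra argument by keyword, which leaves the dataset element as the first positional parameter. The sketch below illustrates this with Python's built-in map standing in for Dataset.map, so it runs without TensorFlow. The names scale and factor are hypothetical and not part of the answer above.

```python
from functools import partial


def scale(x, factor):
    """Multiply one element by an externally supplied factor."""
    return x * factor


# Bind the extra argument by keyword: the element stays the first
# positional parameter, so the signature does not need reordering.
scale_by_two = partial(scale, factor=2)

# Built-in map is used here only as a stand-in for Dataset.map.
result = list(map(scale_by_two, range(5)))
print(result)
```

Assuming Dataset.map passes the element positionally, the same call shape should carry over as dataset.map(partial(fun, arg=my_arg)).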


Another solution is to use a class wrapper. In the following code, I passed the parameter shape to the parse function.

    class MyDataSets:
        def __init__(self, shape):
            self.shape = shape

        def parse_sample(self, sample):
            features = { ... }
            f = tf.parse_example([sample], features=features)
            image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
            # self.shape is the extra parameter, available through the instance
            image = tf.reshape(image_raw, self.shape)
            label = tf.cast(f['label'], tf.int32)
            return image, label

        def init(self):
            ds = tf.data.TFRecordDataset(...)
            ds = ds.map(self.parse_sample)
            ...
            return ds.make_initializable_iterator()
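The class wrapper works because self.parse_sample is a bound method: it already carries self (and therefore self.shape), so map only has to supply the sample. Here is a minimal TensorFlow-free sketch of that mechanism, with built-in map standing in for Dataset.map. The Reshaper class and its list-slicing body are hypothetical placeholders for the real parsing code.

```python
class Reshaper:
    """Hypothetical stand-in for MyDataSets: keeps the extra parameter as state."""

    def __init__(self, shape):
        self.shape = shape

    def parse_sample(self, sample):
        # Split a flat list into rows of `cols` items each; a plain-Python
        # placeholder for the reshape step in the real parse function.
        rows, cols = self.shape
        return [sample[r * cols:(r + 1) * cols] for r in range(rows)]


reshaper = Reshaper(shape=(2, 3))

# The bound method takes a single argument, which is what map expects.
samples = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
parsed = list(map(reshaper.parse_sample, samples))
print(parsed[0])
```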