Tensorflow Dictionary lookup with String tensor


You might find `tensorflow.contrib.lookup` helpful (TF 1.x only; `tf.contrib` was removed in TF 2.0): https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py

https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable

In particular, you can do:

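```python
import tensorflow as tf  # TF 1.x only

# keys, values and input_tensor are existing tensors, e.g.
# keys = tf.constant(["a", "b"]), values = tf.constant([0, 1], tf.int64)
table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)  # -1 for missing keys
out = table.lookup(input_tensor)

with tf.Session():
    table.init.run()     # the table must be initialized before lookup
    print(out.eval())
```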


If you want to run this with TF 2.x, where eager execution is enabled by default, here is a quick snippet using `tf.lookup.StaticHashTable`:

```python
import tensorflow as tf

# build a lookup table
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3]),
        values=tf.constant([10, 11, 12, 13]),
    ),
    default_value=tf.constant(-1),
    name="class_weight")

# now let us do a lookup
input_tensor = tf.constant([0, 0, 1, 1, 2, 2, 3, 3])
out = table.lookup(input_tensor)
print(out)
```

Output:

```
tf.Tensor([10 10 11 11 12 12 13 13], shape=(8,), dtype=int32)
```
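Since the question is about string tensors, here is a minimal sketch of the same `StaticHashTable` approach with string keys. The labels and ids are made-up example data; note that `default_value` must have the same dtype as the values (here `int64`).

```python
import tensorflow as tf

# Hypothetical example data: map string labels to integer ids.
keys = tf.constant(["cat", "dog", "bird"])
values = tf.constant([0, 1, 2], dtype=tf.int64)

table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=tf.constant(-1, dtype=tf.int64),  # returned for unknown keys
)

# "fish" is not in the table, so it maps to the default value.
input_tensor = tf.constant(["dog", "fish", "cat"])
out = table.lookup(input_tensor)
print(out.numpy().tolist())  # [1, -1, 0]
```

Unknown strings fall back to the default instead of raising an error, which is usually what you want when mapping labels inside a data pipeline.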


`tf.gather` can help, but it only indexes by integer position, not by key. You can convert the dictionary into key and value lists, map each query key to its index, and then apply `tf.gather`. Example (TF 1.x API):

```python
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.InteractiveSession)

# Your dict
dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
# Concrete query
query_list = ['a', 'c']
# Unpack key and value lists
key, value = list(zip(*dict_.items()))
# Map each query key to its index, in query order -> [0, 2]
query_idx = [key.index(q) for q in query_list]
# Query as tensor
query = tf.placeholder(tf.int32, shape=[None])
# Convert value list to tensor
vl_tf = tf.constant(value)
# Get values
my_vl = tf.gather(vl_tf, query)
# Session run
sess = tf.InteractiveSession()
print(sess.run(my_vl, feed_dict={query: query_idx}))
```
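`tf.placeholder` and sessions do not exist in TF 2.x without `tf.compat.v1`. A minimal sketch of the same `tf.gather` approach under eager execution, reusing the example dict: the index list is built by looking up each query key in a key-to-position dict, so results come back in query order, and a key that is not in the dict raises a Python `KeyError` before anything reaches TensorFlow.

```python
import tensorflow as tf

dict_ = {'a': 1.12, 'b': 5.86, 'c': 68.}
query_list = ['c', 'a']

# Unpack keys and values; both tuples share the dict's iteration order.
keys, values = zip(*dict_.items())

# Position of each key, then one index per query key (query order kept).
key_to_index = {k: i for i, k in enumerate(keys)}
indices = tf.constant([key_to_index[q] for q in query_list], dtype=tf.int32)

values_tf = tf.constant(values)  # float32 tensor
result = tf.gather(values_tf, indices)
print(result.numpy())  # values for 'c' then 'a'
```

This keeps the key-to-index mapping in plain Python, so it only works when the query keys are known as Python strings. If the keys arrive as a string tensor, the `StaticHashTable` approach is the better fit.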