How does one initialize a variable with tf.get_variable and a numpy value in TensorFlow?
The following works if you first convert the NumPy array into a constant Tensor:
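    import numpy as np
    import tensorflow as tf

    # The variable's shape and dtype (float64 here, since np.random.rand
    # returns float64) are taken from the tensor, so do not also pass shape.
    init = tf.constant(np.random.rand(1, 2))
    tf.get_variable('var_name', initializer=init)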
The documentation for get_variable is admittedly a little thin. For reference, the initializer argument has to be either a TensorFlow Tensor object (which, in your case, you can build by calling tf.constant on the NumPy value, as above) or a callable that receives the shape and data type (shape and dtype) of the value it is supposed to return. Later TF 1.x releases also pass a partition_info keyword argument to the callable, so it is safest to accept it. In your case, if you wanted to use the callable mechanism, you could write:
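    # get_variable calls the initializer with the requested shape and dtype;
    # later TF 1.x releases also pass partition_info as a keyword argument.
    init = lambda shape, dtype, partition_info=None: np.random.rand(*shape).astype(dtype.as_numpy_dtype)
    tf.get_variable('var_name', initializer=init, shape=[1, 2])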
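If the NumPy value is fixed rather than generated from the shape, the callable mechanism can still be used by wrapping the array. The sketch below is a hypothetical helper (numpy_initializer is not part of TensorFlow); it assumes the TF 1.x convention that the callable receives shape, dtype and partition_info, and it only needs NumPy, so the wrapper itself can be checked without TensorFlow.

```python
import numpy as np


def numpy_initializer(value):
    """Wrap a fixed NumPy value in a callable initializer (hypothetical helper).

    The returned function accepts (shape, dtype, partition_info), the argument
    list that tf.get_variable in TF 1.x is assumed to pass to a callable.
    """
    value = np.asarray(value)

    def _init(shape, dtype=None, partition_info=None):
        # Refuse to silently reshape: the requested shape must match the array.
        if tuple(shape) != value.shape:
            raise ValueError("shape %s does not match value shape %s"
                             % (tuple(shape), value.shape))
        if dtype is None:
            return value
        # A tf.DType exposes .as_numpy_dtype; plain NumPy dtypes are used as-is.
        np_dtype = getattr(dtype, "as_numpy_dtype", dtype)
        return value.astype(np_dtype)

    return _init


# Stand-alone check, without TensorFlow.
arr = np.arange(6).reshape(2, 3)
init = numpy_initializer(arr)
out = init([2, 3], dtype=np.float32)
print(out.dtype, out.shape)  # float32 (2, 3)
```

With TensorFlow 1.x it would be used as tf.get_variable('var_name', shape=[2, 3], initializer=numpy_initializer(arr)); shape is passed explicitly so that get_variable knows which shape to request from the callable, and the shape check turns a mismatch into an early, readable error.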
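@keveman answered this well. As a supplement, the TensorFlow documentation for tf.constant_initializer gives a comprehensive example of tf.get_variable(..., initializer=init) with init built from a Python list (or NumPy array), including what happens when the requested shape does not match the number of values: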
    import numpy as np
    import tensorflow as tf

    value = [0, 1, 2, 3, 4, 5, 6, 7]
    # value = np.array(value)
    # value = value.reshape([2, 4])
    init = tf.constant_initializer(value)

    print('fitting shape:')
    tf.reset_default_graph()
    with tf.Session():
        x = tf.get_variable('x', shape=[2, 4], initializer=init)
        x.initializer.run()
        print(x.eval())

    # fitting shape:
    # [[0. 1. 2. 3.]
    #  [4. 5. 6. 7.]]

    print('larger shape:')
    tf.reset_default_graph()
    with tf.Session():
        x = tf.get_variable('x', shape=[3, 4], initializer=init)
        x.initializer.run()
        print(x.eval())

    # larger shape:
    # [[0. 1. 2. 3.]
    #  [4. 5. 6. 7.]
    #  [7. 7. 7. 7.]]

    print('smaller shape:')
    tf.reset_default_graph()
    with tf.Session():
        x = tf.get_variable('x', shape=[2, 3], initializer=init)

    # ValueError: Too many elements provided. Needed at most 6, but received 8
https://www.tensorflow.org/api_docs/python/tf/constant_initializer
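The fill rule in that example (exact fit, padding with the last value when the shape is larger, an error when there are too many values) can be written out in plain NumPy to make it explicit. This is a hypothetical re-creation for illustration only, not TensorFlow's implementation; it assumes the TF 1.x behavior shown in the output of the documentation example.

```python
import numpy as np


def fill_like_constant_initializer(value, shape):
    """Hypothetical NumPy re-creation of how TF 1.x tf.constant_initializer
    fills a requested shape from a flat list of values."""
    flat = np.asarray(value, dtype=np.float32).ravel()
    needed = int(np.prod(shape))
    if flat.size == 0:
        raise ValueError("No values provided.")
    if flat.size > needed:
        # Mirrors the ValueError raised for shape [2, 3] in the example.
        raise ValueError("Too many elements provided. Needed at most %d, "
                         "but received %d" % (needed, flat.size))
    # Pad by repeating the last element, as in the [3, 4] case.
    padding = np.full(needed - flat.size, flat[-1], dtype=np.float32)
    return np.concatenate([flat, padding]).reshape(shape)


value = [0, 1, 2, 3, 4, 5, 6, 7]
print(fill_like_constant_initializer(value, [2, 4]))
print(fill_like_constant_initializer(value, [3, 4]))  # third row repeats 7
```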
If the variable has already been created (e.g. inside some complex function), you can simply assign a NumPy value to it with Variable.load:
https://www.tensorflow.org/api_docs/python/tf/Variable#load
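    import numpy as np
    import tensorflow as tf

    x_var = tf.Variable(tf.zeros((1, 2), tf.float32))
    x_val = np.random.rand(1, 2).astype(np.float32)
    sess = tf.Session()
    # load() feeds x_val into the variable's initializer op,
    # so no separate initialization run is needed first
    x_var.load(x_val, session=sess)

    # test
    assert np.all(sess.run(x_var) == x_val)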