TensorFlow: Remember LSTM state for next batch (stateful LSTM)
I found it easiest to store the whole state for all layers in a single placeholder.

    import numpy as np
    import tensorflow as tf

    # Layout: (layer, c/h, batch, units)
    init_state = np.zeros((num_layers, 2, batch_size, state_size))

    ...

    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
Then unstack it and build a tuple of LSTMStateTuples, one per layer, before calling the native TensorFlow RNN API:

    # tf.unpack was renamed tf.unstack in TensorFlow 1.0.
    l = tf.unstack(state_placeholder, axis=0)
    # One LSTMStateTuple(c, h) per layer.
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )
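As a framework-free illustration of this layout, the NumPy sketch below packs and unpacks a state array of shape (num_layers, 2, batch_size, state_size). It assumes axis 1 holds c at index 0 and h at index 1, matching LSTMStateTuple(c, h); the namedtuple, pack_state and unpack_state are hypothetical stand-ins, not TensorFlow APIs.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple: a (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_layers, batch_size, state_size = 3, 4, 5

# One array holds everything: axis 0 = layer, axis 1 = (c, h).
init_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)


def unpack_state(packed):
    """Split a packed state array into one (c, h) tuple per layer."""
    return tuple(LSTMStateTuple(packed[i, 0], packed[i, 1])
                 for i in range(packed.shape[0]))


def pack_state(per_layer):
    """Inverse of unpack_state: stack per-layer (c, h) pairs into one array."""
    return np.array([[s.c, s.h] for s in per_layer], dtype=np.float32)


rnn_tuple_state = unpack_state(init_state)
print(len(rnn_tuple_state), rnn_tuple_state[0].c.shape)  # 3 (4, 5)
```

Because pack_state inverts unpack_state, the same array can be fed in, split per layer, and rebuilt without losing any values.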
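The tuple is then passed to the RNN API:

    # Build a separate cell object for each layer; newer TF 1.x releases
    # reject [cell] * num_layers because it reuses one cell for every layer.
    cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
             for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
    outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)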
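To see why carrying the state between batches works, here is a rough NumPy analogue of dynamic_rnn over a stacked LSTM. It assumes textbook LSTM equations (TensorFlow's LSTMCell differs in gate order and adds a forget bias), and lstm_step and multi_layer_rnn are hypothetical helpers. Each layer gets its own weights, and the top layer's h is the output.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, c, h, W, b):
    """One textbook LSTM step (not TensorFlow's exact gate layout)."""
    z = np.concatenate([x, h], axis=1) @ W + b  # (batch, 4 * units)
    i, f, o, g = np.split(z, 4, axis=1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return c_new, h_new


def multi_layer_rnn(x, state, params):
    """Rough analogue of dynamic_rnn over a MultiRNNCell.

    x: (batch, time, input_size); state: tuple of (c, h), one per layer;
    params: list of (W, b), one pair per layer (no weight sharing).
    Returns top-layer outputs (batch, time, units) and the final state.
    """
    state = list(state)
    outputs = []
    for t in range(x.shape[1]):
        layer_in = x[:, t]
        for k, (W, b) in enumerate(params):
            c, h = state[k]
            c, h = lstm_step(layer_in, c, h, W, b)
            state[k] = (c, h)
            layer_in = h  # the next layer consumes this layer's h
        outputs.append(layer_in)
    return np.stack(outputs, axis=1), tuple(state)


rng = np.random.default_rng(0)
num_layers, batch_size, time_steps, input_size, units = 2, 3, 4, 6, 5
params = []
for k in range(num_layers):
    in_dim = input_size if k == 0 else units
    W = rng.standard_normal((in_dim + units, 4 * units)) * 0.1
    params.append((W, np.zeros(4 * units)))

x = rng.standard_normal((batch_size, time_steps, input_size))
zero_state = tuple((np.zeros((batch_size, units)), np.zeros((batch_size, units)))
                   for _ in range(num_layers))

full_out, full_state = multi_layer_rnn(x, zero_state, params)

# Split the sequence into two "batches" and carry the state across.
out1, state1 = multi_layer_rnn(x[:, :2], zero_state, params)
out2, state2 = multi_layer_rnn(x[:, 2:], state1, params)
print(np.allclose(np.concatenate([out1, out2], axis=1), full_out))  # True
```

Running the first half of a sequence and then the second half, starting from the first half's final state, gives the same outputs as one run over the whole sequence. That equivalence is what feeding the state into the next batch relies on.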
The state returned by dynamic_rnn is then fetched and fed back into state_placeholder for the next batch.
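The feed loop described above might look roughly like this. run_batch is a hypothetical stand-in for sess.run([outputs, state], feed_dict={x_input_batch: ..., state_placeholder: ...}) with a toy state update; the point is only that the nested tuple of final states is re-packed with np.array into the placeholder's shape before the next feed.

```python
import numpy as np

num_layers, batch_size, state_size, num_steps = 2, 3, 4, 5


def run_batch(x_batch, packed_state):
    """Hypothetical stand-in for sess.run([outputs, state], feed_dict=...).

    Returns the final state as a tuple of (c, h) pairs, one per layer,
    which is the nested structure dynamic_rnn gives back.
    """
    new_state = []
    for layer in range(packed_state.shape[0]):
        # Toy update: add each row's input sum to c, then h = tanh(c).
        c = packed_state[layer, 0] + x_batch.sum(axis=1, keepdims=True)
        h = np.tanh(c)
        new_state.append((c, h))
    outputs = new_state[-1][1]
    return outputs, tuple(new_state)


rng = np.random.default_rng(0)
batches = [rng.standard_normal((batch_size, num_steps)) for _ in range(3)]

current_state = np.zeros((num_layers, 2, batch_size, state_size))
for x_batch in batches:
    _, final_state = run_batch(x_batch, current_state)
    # Re-pack the nested tuple into the placeholder's shape for the next feed.
    current_state = np.array(final_state)
```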
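"Tensorflow, best way to save state in RNNs?" was actually my original question. The code below is how I use the state tuples:

    with tf.variable_scope('decoder') as scope:
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
            tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True),
        ], state_is_tuple=True)

        # Zero state as a list of [c, h] lists, one per layer; the sizes come
        # from rnn_cell.state_size, e.g. (512, 256) for the first layer.
        state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer]
                 for sz_outer in rnn_cell.state_size]

        # y (per-step inputs, overwritten with outputs) and y_ (per-step
        # targets) are assumed to be Python lists of tensors defined earlier.
        for t in range(TIME_STEPS):
            if t:
                # Previous target while training, previous prediction otherwise.
                last = y_[t - 1] if TRAINING else y[t - 1]
            else:
                last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))
            y[t] = tf.concat([y[t], last], 1)  # pre-1.0: tf.concat(1, (y[t], last))
            y[t], state = rnn_cell(y[t], state)
            scope.reuse_variables()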
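The if t / TRAINING branch is a teacher-forcing switch: during training each step sees the previous target y_[t - 1], at inference it sees the model's own previous output. A NumPy sketch of that control flow, assuming a hypothetical one-layer toy_cell in place of the MultiRNNCell and made-up sizes:

```python
import numpy as np

BATCH_SIZE, WORD_VEC_SIZE, IN_SIZE, TIME_STEPS = 2, 3, 4, 5


def toy_cell(inputs, state, W):
    """Hypothetical single-layer cell; its output doubles as its state."""
    new_state = np.tanh(inputs @ W + state)
    return new_state, new_state


def decode(x, y_true, W, training):
    """x[t]: (BATCH_SIZE, IN_SIZE) input; y_true[t]: (BATCH_SIZE, WORD_VEC_SIZE) target."""
    state = np.zeros((BATCH_SIZE, WORD_VEC_SIZE))
    y = []
    for t in range(TIME_STEPS):
        if t:
            # Teacher forcing: ground truth while training, own prediction otherwise.
            last = y_true[t - 1] if training else y[t - 1]
        else:
            last = np.zeros((BATCH_SIZE, WORD_VEC_SIZE))
        step_input = np.concatenate([x[t], last], axis=1)
        out, state = toy_cell(step_input, state, W)
        y.append(out)
    return np.stack(y)


rng = np.random.default_rng(1)
W = rng.standard_normal((IN_SIZE + WORD_VEC_SIZE, WORD_VEC_SIZE)) * 0.5
x = rng.standard_normal((TIME_STEPS, BATCH_SIZE, IN_SIZE))
targets_a = rng.standard_normal((TIME_STEPS, BATCH_SIZE, WORD_VEC_SIZE))
targets_b = rng.standard_normal((TIME_STEPS, BATCH_SIZE, WORD_VEC_SIZE))

# At inference the targets are never read, so both runs agree.
print(np.allclose(decode(x, targets_a, W, False), decode(x, targets_b, W, False)))  # True
```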
Rather than using tf.nn.rnn_cell.LSTMStateTuple, I just create a list of lists, which works fine. In this example I am not saving the state. However, you could easily hold the state in variables and use assign to save the values between batches.
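A sketch of the variable-plus-assign alternative, with NumPy standing in for TensorFlow: StatefulState plays the role of a non-trainable state variable and its assign method mimics an assign op, so the state persists between batches without being fed in. The class, toy_batch_step and the toy update rule are all hypothetical.

```python
import numpy as np


class StatefulState:
    """Hypothetical holder for (c, h) of every layer, kept between calls,
    mimicking a TF variable that is updated with assign after each batch."""

    def __init__(self, num_layers, batch_size, state_size):
        self._shape = (num_layers, 2, batch_size, state_size)
        self.value = np.zeros(self._shape)

    def assign(self, new_value):
        new_value = np.asarray(new_value)
        if new_value.shape != self._shape:
            raise ValueError(f"expected shape {self._shape}, got {new_value.shape}")
        self.value[...] = new_value  # in-place update, like an assign op

    def reset(self):
        """Zero the state, e.g. at the start of a new, unrelated sequence."""
        self.value[...] = 0.0


def toy_batch_step(x_batch, saved):
    """Reads the saved state, applies a toy update, and writes it back."""
    c = saved.value[:, 0] + x_batch.mean()  # (layers, batch, size)
    h = np.tanh(c)
    saved.assign(np.stack([c, h], axis=1))
    return h[-1]  # top-layer h as the "output"


saved = StatefulState(num_layers=2, batch_size=3, state_size=4)
toy_batch_step(np.ones((3, 5)), saved)
toy_batch_step(np.ones((3, 5)), saved)
# The second batch started from the first batch's state, so c is now 2.0.
```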