
TensorFlow: Remember LSTM state for next batch (stateful LSTM)


I found it easiest to store the whole state for all layers in a single placeholder.

    # Axis 1 holds the (c, h) pair of each layer
    init_state = np.zeros((num_layers, 2, batch_size, state_size))
    ...
    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

Then unpack it and build a tuple of LSTMStateTuples before passing it to the native TensorFlow RNN API.

    # tf.unpack was renamed tf.unstack in TensorFlow 1.0
    l = tf.unstack(state_placeholder, axis=0)
    # One LSTMStateTuple(c, h) per layer
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)])
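To make the layout concrete, here is a NumPy-only sketch (no TensorFlow required) of converting between the single (num_layers, 2, batch_size, state_size) array and a tuple of per-layer (c, h) pairs. StateTuple, unpack_state and pack_state are hypothetical names; StateTuple is a plain namedtuple standing in for tf.nn.rnn_cell.LSTMStateTuple, which likewise stores c first and h second.

```python
import collections

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple: one (c, h) pair per layer.
StateTuple = collections.namedtuple("StateTuple", ["c", "h"])


def unpack_state(packed):
    """Split a (num_layers, 2, batch_size, state_size) array into per-layer (c, h) pairs."""
    # Iterating over axis 0 yields one (2, batch_size, state_size) slice per layer.
    return tuple(StateTuple(layer[0], layer[1]) for layer in packed)


def pack_state(state_tuples):
    """Inverse of unpack_state: stack the per-layer (c, h) pairs back into one array."""
    return np.stack([np.stack([s.c, s.h]) for s in state_tuples])


num_layers, batch_size, state_size = 3, 4, 5
init_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)
state = unpack_state(init_state)  # tuple of 3 StateTuples, each c/h of shape (4, 5)
```

Because namedtuples are tuples, calling np.array on an evaluated nested state of this kind produces the same stacked array that pack_state builds.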

Then pass it to the RNN API as the initial state:

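    # Build a separate cell per layer: in TF 1.x, [cell] * num_layers reuses one
    # cell object (and therefore its weights) for every layer
    cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
             for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
    outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)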

The returned state is then evaluated and fed back through the placeholder for the next batch.
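The batch-to-batch feedback can be sketched in plain NumPy, assuming a hand-written stacked LSTM in place of tf.nn.dynamic_rnn (a simplified gate layout, not TensorFlow's exact parameterization). run_batch takes and returns the state in the same (num_layers, 2, batch_size, state_size) layout as state_placeholder, so the loop mirrors feeding the evaluated state back in. make_params and run_batch are hypothetical helpers.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def make_params(rng, input_size, state_size, num_layers):
    """Hypothetical helper: one (W, b) pair per layer with its own weights."""
    params = []
    in_dim = input_size  # layer 0 sees the input, later layers see the h below
    for _ in range(num_layers):
        W = rng.standard_normal((in_dim + state_size, 4 * state_size)) * 0.1
        b = np.zeros(4 * state_size)
        params.append((W, b))
        in_dim = state_size
    return params


def run_batch(params, x_batch, state):
    """Run a stacked LSTM over x_batch (batch, time, input) starting from `state`.

    `state` uses the placeholder layout (num_layers, 2, batch, state_size);
    returns (outputs, final_state) with final_state in the same layout.
    """
    state = state.copy()  # leave the caller's array untouched
    outputs = []
    for t in range(x_batch.shape[1]):
        inp = x_batch[:, t, :]
        for layer, (W, b) in enumerate(params):
            c, h = state[layer, 0], state[layer, 1]
            z = np.concatenate([inp, h], axis=1) @ W + b
            i, f, g, o = np.split(z, 4, axis=1)
            c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
            h = sigmoid(o) * np.tanh(c)
            state[layer, 0], state[layer, 1] = c, h
            inp = h  # this layer's output feeds the next layer
        outputs.append(inp)
    return np.stack(outputs, axis=1), state


rng = np.random.default_rng(0)
num_layers, batch_size, state_size, input_size = 2, 3, 4, 5
params = make_params(rng, input_size, state_size, num_layers)
data = rng.standard_normal((batch_size, 10, input_size))

# Stateful loop: the final state of one batch is the initial state of the next,
# like feeding the evaluated `state` back into `state_placeholder`.
current_state = np.zeros((num_layers, 2, batch_size, state_size))
chunks = []
for start in range(0, 10, 5):
    out, current_state = run_batch(params, data[:, start:start + 5], current_state)
    chunks.append(out)
```

Carrying the state this way means that processing a sequence in two batches gives the same outputs as processing it in one go; resetting current_state to zeros is how you would start an unrelated sequence.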


My original question was actually "Tensorflow, best way to save state in RNNs?". The code below shows how I use the state tuples.

    with tf.variable_scope('decoder') as scope:
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
            tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
        ], state_is_tuple=True)
        # Zero initial state: one [c, h] list per layer, sized from state_size
        state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer]
                 for sz_outer in rnn_cell.state_size]
        for t in range(TIME_STEPS):
            if t:
                # Previous target while training, previous output otherwise
                last = y_[t - 1] if TRAINING else y[t - 1]
            else:
                last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))
            # TF >= 1.0 argument order is tf.concat(values, axis)
            y[t] = tf.concat([y[t], last], 1)
            y[t], state = rnn_cell(y[t], state)
            scope.reuse_variables()
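The decoder loop can be mirrored in NumPy to show the list-of-lists state and the choice between the target (training) and the model's own previous output (inference) as the extra input. This is a sketch under assumptions: lstm_proj_step, make_layer and decode are hypothetical helpers, the gate layout is simplified, the projection only approximates LSTMCell(num_units, num_proj=...), and the sizes are tiny stand-ins for 512/256/WORD_VEC_SIZE.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def make_layer(rng, in_dim, num_units, num_proj):
    """Hypothetical helper: LSTM weights plus an output projection matrix P."""
    W = rng.standard_normal((in_dim + num_proj, 4 * num_units)) * 0.1
    b = np.zeros(4 * num_units)
    P = rng.standard_normal((num_units, num_proj)) * 0.1
    return W, b, P


def lstm_proj_step(x, state, W, b, P):
    """One step; state is a [c, h] list (c: num_units wide, h: num_proj wide)."""
    c, h = state
    z = np.concatenate([x, h], axis=1) @ W + b
    i, f, g, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = (sigmoid(o) * np.tanh(c)) @ P  # project num_units -> num_proj
    return h, [c, h]


def decode(layers, y_in, y_target, training, batch_size, word_vec_size):
    # Zero state as a list of [c, h] lists, one per layer
    state = [[np.zeros((batch_size, W.shape[1] // 4)), np.zeros((batch_size, P.shape[1]))]
             for W, b, P in layers]
    outputs = []
    for t in range(len(y_in)):
        if t:
            last = y_target[t - 1] if training else outputs[t - 1]
        else:
            last = np.zeros((batch_size, word_vec_size))
        inp = np.concatenate([y_in[t], last], axis=1)
        for k, (W, b, P) in enumerate(layers):
            inp, state[k] = lstm_proj_step(inp, state[k], W, b, P)
        outputs.append(inp)
    return outputs, state


rng = np.random.default_rng(1)
BATCH_SIZE, TIME_STEPS, IN_SIZE, WORD_VEC_SIZE = 2, 5, 3, 4
layers = [make_layer(rng, IN_SIZE + WORD_VEC_SIZE, 8, 6),
          make_layer(rng, 6, 8, WORD_VEC_SIZE)]
y_in = [rng.standard_normal((BATCH_SIZE, IN_SIZE)) for _ in range(TIME_STEPS)]
y_target = [rng.standard_normal((BATCH_SIZE, WORD_VEC_SIZE)) for _ in range(TIME_STEPS)]
train_out, train_state = decode(layers, y_in, y_target, True, BATCH_SIZE, WORD_VEC_SIZE)
infer_out, _ = decode(layers, y_in, y_target, False, BATCH_SIZE, WORD_VEC_SIZE)
```

Only the first step is identical between the two modes, since both start from a zero "previous word"; after that, training sees the targets and inference sees its own outputs.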

Rather than using tf.nn.rnn_cell.LSTMStateTuple, I just create a list of lists, which works fine. In this example I am not saving the state; however, you could easily make the state out of variables and use assign to save the values.
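A sketch of the variables-plus-assign idea, again in NumPy: the state lives in buffers owned by the object and is overwritten in place after every call, much as non-trainable tf.Variables updated with tf.assign would keep it inside the graph between session runs. StatefulLSTM and its reset method are hypothetical; it is a single layer with a simplified gate layout.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class StatefulLSTM:
    """Hypothetical single-layer LSTM that keeps its state between calls.

    self.c / self.h play the role of non-trainable state variables; the
    in-place writes at the end of __call__ play the role of assign.
    """

    def __init__(self, input_size, state_size, batch_size, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((input_size + state_size, 4 * state_size)) * 0.1
        self.b = np.zeros(4 * state_size)
        self.c = np.zeros((batch_size, state_size))
        self.h = np.zeros((batch_size, state_size))

    def reset(self):
        """Zero the stored state, e.g. before an unrelated sequence."""
        self.c[...] = 0.0
        self.h[...] = 0.0

    def __call__(self, x_batch):
        c, h = self.c, self.h  # start from the stored state
        outputs = []
        for t in range(x_batch.shape[1]):
            z = np.concatenate([x_batch[:, t, :], h], axis=1) @ self.W + self.b
            i, f, g, o = np.split(z, 4, axis=1)
            c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
            h = sigmoid(o) * np.tanh(c)
            outputs.append(h)
        # "assign": persist the final state for the next call
        self.c[...] = c
        self.h[...] = h
        return np.stack(outputs, axis=1)


lstm = StatefulLSTM(input_size=3, state_size=4, batch_size=2)
data = np.random.default_rng(1).standard_normal((2, 8, 3))
first = lstm(data[:, :4])
second = lstm(data[:, 4:])  # continues from the state left by the first call
lstm.reset()
full = lstm(data)
```

Compared with feeding a placeholder, this keeps the state out of the feed dict entirely, at the cost of needing an explicit reset between independent sequences.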