How to save/restore a model after training?


I am improving my answer to add more details for saving and restoring models.

In (and after) TensorFlow version 0.11:

Save the model:

import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2

# Now, save the graph
saver.save(sess, 'my_test_model', global_step=1000)
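For reference, with the default V2 checkpoint format, `saver.save(sess, 'my_test_model', global_step=1000)` appends `-1000` to the prefix and writes several files side by side. The helper below is a hypothetical sketch (not a TensorFlow API) that lists the names you should expect on disk, assuming an unsharded save and the default `write_meta_graph=True`.

```python
def expected_saver_files(prefix, global_step=None):
    """Hypothetical helper: files written by tf.train.Saver.save (V2 format).

    Assumes a single, unsharded data file and write_meta_graph=True.
    """
    # Saver appends "-<global_step>" to the prefix when a step is given.
    ckpt_prefix = prefix if global_step is None else "{}-{}".format(prefix, global_step)
    return [
        ckpt_prefix + ".index",                  # maps variable names to data locations
        ckpt_prefix + ".data-00000-of-00001",    # the variable values themselves
        ckpt_prefix + ".meta",                   # the exported MetaGraph (graph structure)
        "checkpoint",                            # text file naming the latest checkpoint
    ]


print(expected_saver_files("my_test_model", global_step=1000))
# ['my_test_model-1000.index', 'my_test_model-1000.data-00000-of-00001',
#  'my_test_model-1000.meta', 'checkpoint']
```

The `.meta` file is what `tf.train.import_meta_graph` reads; the `.index`/`.data` pair is what `saver.restore` reads.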

Restore the model:

import tensorflow as tf

sess = tf.Session()
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2.0, which is the value of bias that we saved

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This will print 60.0, which is calculated using the new values of w1 and w2
# and the saved value of b1: (13 + 17) * 2
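`tf.train.latest_checkpoint('./')` does not scan the directory for the newest file; it reads the small `checkpoint` state file that `Saver.save` maintains. The function below is a simplified, hypothetical re-implementation (stdlib only) to show what that lookup amounts to. The real function parses the file as a `CheckpointState` protocol buffer and also verifies that the referenced checkpoint files exist.

```python
import os
import re
import tempfile


def latest_checkpoint_sketch(checkpoint_dir):
    """Simplified stand-in for tf.train.latest_checkpoint (hypothetical).

    Returns the prefix stored in model_checkpoint_path, or None if the
    directory has no 'checkpoint' state file.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # Only the model_checkpoint_path line names the latest checkpoint;
            # all_model_checkpoint_paths lines list older ones too.
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # Relative prefixes are resolved against the checkpoint directory.
                if not os.path.isabs(path):
                    path = os.path.join(checkpoint_dir, path)
                return path
    return None


with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my_test_model-1000"\n')
        f.write('all_model_checkpoint_paths: "my_test_model-1000"\n')
    print(latest_checkpoint_sketch(d))  # <d>/my_test_model-1000
```

The returned value is a prefix, not a file name, which is why it can be passed straight to `saver.restore`.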

This and some more advanced use cases are explained very well in the following tutorial:

A quick complete tutorial to save and restore Tensorflow models

In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, as described in the TensorFlow MetaGraph documentation.

Save the model

import tensorflow as tf

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# The `save` method calls `export_meta_graph` implicitly,
# so you will get the saved graph file my-model.meta (plus the checkpoint data files).

Restore the model

import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
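`tf.get_collection('vars')` still works after `import_meta_graph` because collections are serialized into the MetaGraph together with the graph, and importing it rebuilds them in their original order. The toy model below (plain Python, a hypothetical `ToyGraph` class, not TensorFlow code) mimics that round trip with JSON; it stores variable names where TensorFlow stores references to the actual variables.

```python
import json


class ToyGraph:
    """Hypothetical stand-in for a graph that carries named collections."""

    def __init__(self):
        self.collections = {}

    def add_to_collection(self, name, value):
        # Collections are ordered lists keyed by a string name.
        self.collections.setdefault(name, []).append(value)

    def get_collection(self, name):
        # Return a new list each call, as tf.get_collection does.
        return list(self.collections.get(name, []))

    def export_meta_graph(self):
        # Serialize the collections along with the (omitted) graph structure.
        return json.dumps({"collections": self.collections})

    @classmethod
    def import_meta_graph(cls, serialized):
        graph = cls()
        graph.collections = json.loads(serialized)["collections"]
        return graph


g = ToyGraph()
# Toy: store variable names instead of Variable objects.
g.add_to_collection("vars", "w1:0")
g.add_to_collection("vars", "w2:0")
restored = ToyGraph.import_meta_graph(g.export_meta_graph())
print(restored.get_collection("vars"))  # ['w1:0', 'w2:0']
```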

Tensorflow 2 Docs

Saving Checkpoints

Adapted from the docs

# -------------------------
# -----  Toy Context  -----
# -------------------------
import tensorflow as tf


class Net(tf.keras.Model):
    """A simple linear model."""

    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return (
        tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)
    )


def train_step(net, example, optimizer):
    """Trains `net` on `example` using `optimizer`."""
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


# ----------------------------
# -----  Create Objects  -----
# ----------------------------

net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# ----------------------------
# -----  Train and Save  -----
# ----------------------------

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))


# ---------------------
# -----  Restore  -----
# ---------------------

# In another script, re-initialize objects
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Re-use the manager code above ^

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.
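To see which checkpoints the training loop above leaves in `./tf_ckpts`, here is a plain-Python simulation of its bookkeeping (no TensorFlow involved). It assumes the default behaviour that `manager.save()` numbers checkpoints with the checkpoint's `save_counter` (`ckpt-1`, `ckpt-2`, ...) rather than with `ckpt.step`, and that `max_to_keep=3` deletes the oldest checkpoint once a fourth one is written. `simulate_checkpoints` is a hypothetical helper.

```python
def simulate_checkpoints(iterations=50, start_step=1, save_every=10, max_to_keep=3):
    """Mimic the step counter and CheckpointManager retention of the loop above."""
    step = start_step      # ckpt.step starts as tf.Variable(1)
    save_counter = 0       # incremented by every manager.save()
    kept = []              # checkpoint prefixes still on disk
    saved_at = []          # values of ckpt.step when a save happened
    for _ in range(iterations):
        step += 1          # ckpt.step.assign_add(1)
        if step % save_every == 0:
            save_counter += 1
            kept.append("ckpt-{}".format(save_counter))
            saved_at.append(step)
            if len(kept) > max_to_keep:
                kept.pop(0)  # the oldest checkpoint is deleted
    return saved_at, kept


saved_at, kept = simulate_checkpoints()
print(saved_at)  # [10, 20, 30, 40, 50]
print(kept)      # ['ckpt-3', 'ckpt-4', 'ckpt-5']
```

Because `step` is stored in the checkpoint as well, the restore script resumes from the saved step instead of starting again at 1.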

More links

Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when source code that will use the saved parameter values is available.

The SavedModel format on the other hand includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).

(Highlights are my own)

Tensorflow < 2

From the docs:


import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)


tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
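The restore snippet works because `tf.train.Saver` matches saved values to variables by name (`"v1"`, `"v2"`) and requires the shapes to agree; the freshly created variables do not need to be initialized first. As a rough, hypothetical illustration of that contract (plain Python with JSON and lists standing in for 1-D tensors, not TensorFlow's checkpoint format):

```python
import json
import os
import tempfile


def save_by_name(path, variables):
    """Write {name: list_of_values} to disk (toy stand-in for Saver.save)."""
    with open(path, "w") as f:
        json.dump(variables, f)
    return path


def restore_by_name(path, variables):
    """Overwrite each variable in place with the saved value of the same name.

    Like Saver.restore, a missing name or a shape mismatch is an error.
    """
    with open(path) as f:
        saved = json.load(f)
    for name, values in variables.items():
        if name not in saved:
            raise KeyError("{} not found in checkpoint".format(name))
        if len(saved[name]) != len(values):
            raise ValueError("shape mismatch for {}".format(name))
        values[:] = saved[name]


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.json")
    # Example values: zeros + 1 and zeros - 1, as inc_v1 / dec_v2 would produce.
    save_by_name(path, {"v1": [1.0] * 3, "v2": [-1.0] * 5})
    # "New graph": same names and shapes, arbitrary starting values.
    fresh = {"v1": [0.0] * 3, "v2": [0.0] * 5}
    restore_by_name(path, fresh)
    print(fresh)  # {'v1': [1.0, 1.0, 1.0], 'v2': [-1.0, -1.0, -1.0, -1.0, -1.0]}
```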


Many good answers already; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.

Python 3; TensorFlow 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )


restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        # Tensor names assume the placeholders were created with these `name=` arguments
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        # The output tensor name depends on how the model was built
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
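`get_tensor_by_name` expects a tensor name: the producing operation's name, a colon, and the output index. For example, `'dense/BiasAdd:0'` is output 0 of the op `dense/BiasAdd` that `tf.layers.dense` creates, while `get_operation_by_name` takes the bare op name. The hypothetical helpers below (plain Python, not part of TensorFlow) just make that convention explicit; treating a bare op name as index 0 is my own choice for the sketch.

```python
def tensor_name(op_name, output_index=0):
    """Build '<op_name>:<output_index>' as expected by get_tensor_by_name."""
    return "{}:{}".format(op_name, output_index)


def split_tensor_name(name):
    """Split 'scope/op:1' into ('scope/op', 1); a bare op name gets index 0."""
    op_name, sep, index = name.rpartition(":")
    if not sep:
        return name, 0
    return op_name, int(index)


print(tensor_name("dense/BiasAdd"))           # dense/BiasAdd:0
print(split_tensor_name("labels_data_ph:0"))  # ('labels_data_ph', 0)
```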

Standalone example

Original blog post

The following code generates random data for the sake of the demonstration.

  1. We start by creating the placeholders, which will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.
  2. The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
  3. The loss is a softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, then the model will be saved in a folder called simple/ in your current working directory.
  4. In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
  5. Lastly, we run inference on both batches in the dataset and check that the saved and restored models yield the same values. They do!
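Step 5 compares the two runs with exact equality, which holds here because the restored graph performs the same computation on the same inputs. When a model is loaded on different hardware or a different TensorFlow build, tiny floating-point differences are possible, so a tolerance-based check is a safer habit. A stdlib-only sketch, where the name `inferences_match` and the tolerances are assumptions, using the first sample row the script prints:

```python
import math


def inferences_match(values, restored_values, rel_tol=1e-6, abs_tol=1e-6):
    """Compare two lists of batches (nested lists of floats) element-wise."""
    if len(values) != len(restored_values):
        return False
    for batch, restored_batch in zip(values, restored_values):
        if len(batch) != len(restored_batch):
            return False
        for row, restored_row in zip(batch, restored_batch):
            if len(row) != len(restored_row):
                return False
            if not all(
                math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                for a, b in zip(row, restored_row)
            ):
                return False
    return True


original = [[[-0.26331693, -0.13013336, -0.12553, -0.04276478, 0.2933622]]]
restored = [[[-0.26331693, -0.13013336, -0.12553, -0.04276478, 0.2933622]]]
print(inferences_match(original, restored))  # True
```

With the NumPy arrays the script actually produces, the per-batch equivalent is `np.allclose(v, rv)`.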


import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a step of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }
            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

This will print:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True