TensorFlow 2.0: do you need a @tf.function decorator on top of each function?
While the decorator @tf.function applies only to the function defined immediately below it, any functions that function calls are traced and executed in graph mode as well. See the Effective TF2 guide, which states:
In TensorFlow 2.0, users should refactor their code into smaller functions which are called as needed. In general, it's not necessary to decorate each of these smaller functions with tf.function; only use tf.function to decorate high-level computations - for example, one step of training, or the forward pass of your model.
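A small sketch of this behavior follows. The names `helper`, `outer` and `eager_flags` are hypothetical. `helper` is not decorated, but when `outer` (decorated) calls it, `tf.executing_eagerly()` reports `False` because `helper` is being traced into `outer`'s graph. Called directly, the same helper runs eagerly. This assumes the default setting where `tf.config.run_functions_eagerly` has not been turned on.

```python
import tensorflow as tf

# Records whether helper ran eagerly each time its Python body executed.
eager_flags = []

def helper(x):
    # No @tf.function here; it inherits graph mode from its caller.
    eager_flags.append(tf.executing_eagerly())
    return x * 2

@tf.function
def outer(x):
    return helper(x) + 1

result = outer(tf.constant(3))
print(result.numpy())  # 7
print(eager_flags)     # [False]: helper was traced into outer's graph

helper(tf.constant(3))  # called outside any tf.function
print(eager_flags)      # [False, True]
```

Because the Python body of `helper` runs only while `outer` is being traced, the flag is recorded once for that call, not on every graph execution.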
@tf.function converts a Python function to its graph representation.
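One way to see that graph is to request a concrete function for a given input signature and list the operations it contains. This is a minimal sketch; `square_plus_one` is a hypothetical example function, and the exact set of op types can vary between TensorFlow versions, so only a couple of stable ones are checked.

```python
import tensorflow as tf

@tf.function
def square_plus_one(x):
    return x * x + 1.0

# Trace the function for a scalar float32 input and inspect the resulting graph.
concrete = square_plus_one.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32)
)
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)  # includes a Placeholder for x and a Mul for x * x
```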
The pattern to follow is to define the training step function, which is usually the most computationally intensive function, and decorate it with @tf.function.
Usually, the code looks like:
```python
import tensorflow as tf

# model, loss, and optimizer are defined previously

@tf.function
def train_step(features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss_value = loss(labels, predictions)
    # Differentiate the computed loss value, not the loss function object
    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

for features, labels in dataset:
    lv = train_step(features, labels)
    print("loss: ", lv)
```
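For a version that runs on its own, here is a sketch that fills in the "defined previously" parts with assumptions: a one-layer Keras regression model, mean squared error, plain SGD, and synthetic, noise-free data following y = 3x + 2. None of these choices come from the answer above; they only make the loop executable. Only `train_step` is decorated. The model's forward pass, the loss, and the optimizer update are all traced into its graph.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy data: y = 3x + 2 (no noise), 256 samples in batches of 32.
x = np.random.default_rng(0).normal(size=(256, 1)).astype("float32")
y = 3.0 * x + 2.0
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

# Assumed model, loss, and optimizer; the model is built up front via Input.
model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
loss = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss_value = loss(labels, predictions)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

history = []
for epoch in range(5):
    for features, labels in dataset:
        history.append(float(train_step(features, labels)))

print("first loss:", history[0], "last loss:", history[-1])
```

Every batch has the same shape and dtype, so `train_step` is traced once and the compiled graph is reused for all 40 steps.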