Why do we need to call zero_grad() in PyTorch?


In PyTorch, for every mini-batch during the training phase, we typically want to explicitly set the gradients to zero before starting backpropagation (i.e., computing the gradients that are then used to update the weights and biases), because PyTorch accumulates the gradients on subsequent backward passes. This accumulating behavior is convenient while training RNNs. So, the default action has been set to accumulate (i.e. sum) the gradients on every loss.backward() call.
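
To make the accumulation concrete, here is a minimal sketch with a single scalar parameter (the tensor x and its value are made up for illustration). Calling backward() twice without resetting doubles the stored gradient:

```python
import torch

# A single scalar parameter so the gradient is easy to check by hand.
x = torch.tensor(2.0, requires_grad=True)

# d(x^2)/dx = 2x = 4 at x = 2.
(x ** 2).backward()
first = x.grad.item()

# A second backward pass adds to x.grad instead of overwriting it.
(x ** 2).backward()
second = x.grad.item()

# After resetting the gradient, the next backward pass starts from zero again.
x.grad.zero_()
(x ** 2).backward()
after_reset = x.grad.item()

print(first, second, after_reset)  # 4.0 8.0 4.0
```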

Because of this, when you start your training loop, ideally you should zero out the gradients so that you do the parameter update correctly. Otherwise, the accumulated gradient mixes stale gradients from earlier steps with the newly computed one, and it would point in some direction other than the intended one towards the minimum (or maximum, in case of maximization objectives).

Here is a simple example:

    import torch
    import torch.optim as optim

    def linear_model(x, W, b):
        return torch.matmul(x, W) + b

    data, targets = ...  # placeholder: supply your own samples and targets

    # Variable is deprecated; plain tensors with requires_grad=True do the same job
    W = torch.randn(4, 3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    optimizer = optim.Adam([W, b])

    for sample, target in zip(data, targets):
        # clear out the gradients of all tensors
        # in this optimizer (i.e. W, b)
        optimizer.zero_grad()
        output = linear_model(sample, W, b)
        # reduce to a scalar so that backward() needs no extra arguments
        loss = ((output - target) ** 2).sum()
        loss.backward()
        optimizer.step()

Alternatively, if you're doing a vanilla gradient descent, then:

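    W = torch.randn(4, 3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    learning_rate = 0.01  # illustrative value, not from the original answer

    for sample, target in zip(data, targets):
        # clear out the gradients of W and b
        # (.grad is still None before the first backward pass)
        if W.grad is not None:
            W.grad.zero_()
            b.grad.zero_()

        output = linear_model(sample, W, b)
        loss = ((output - target) ** 2).sum()
        loss.backward()

        # update in place without recording the update in the autograd graph
        with torch.no_grad():
            W -= learning_rate * W.grad
            b -= learning_rate * b.grad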

Note:

  • The accumulation (i.e., sum) of gradients happens when .backward() is called on the loss tensor.
  • As of v1.7.0, there's an option of resetting the gradients to None instead of filling them with a tensor of zeros: optimizer.zero_grad(set_to_none=True). The docs claim that this setting results in lower memory requirements and a slight improvement in performance, but it can be error-prone if not handled carefully, since any code that reads .grad must then cope with None. Since PyTorch 2.0, set_to_none=True is the default.
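
The difference between the two reset modes can be checked directly. This is a small sketch, assuming PyTorch 1.7 or later; the model, its sizes, and the helper run_backward are made up for illustration:

```python
import torch

model = torch.nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def run_backward():
    # Hypothetical helper: one forward/backward pass on random data.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

# Mode 1: keep the gradient tensors and fill them with zeros.
run_backward()
optimizer.zero_grad(set_to_none=False)
grads_are_zero = all(bool((p.grad == 0).all()) for p in model.parameters())

# Mode 2: drop the gradient tensors entirely.
run_backward()
optimizer.zero_grad(set_to_none=True)
grads_are_none = all(p.grad is None for p in model.parameters())

# Code such as p.grad.norm() would now raise an AttributeError,
# which is the kind of pitfall the note above warns about.
print(grads_are_zero, grads_are_none)  # True True
```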


Although the idea can be derived from the chosen answer, I want to state it explicitly.

Being able to decide when to call optimizer.zero_grad() and optimizer.step() gives more freedom over how gradients are accumulated and applied by the optimizer in the training loop. This is crucial when the model or input data is big and one actual training batch does not fit on the GPU.

In this example from google-research (linked below), there are two arguments, named train_batch_size and gradient_accumulation_steps.

  • train_batch_size is the batch size for each forward pass and the loss.backward() call that follows it. This is limited by the GPU memory.

  • gradient_accumulation_steps is the number of such batches whose gradients are accumulated before one optimizer update, so the actual training batch size is train_batch_size multiplied by gradient_accumulation_steps. This is NOT limited by the GPU memory.

From this example, you can see how optimizer.zero_grad() is paired with optimizer.step() but NOT with loss.backward(). loss.backward() is invoked in every single iteration (line 216), but optimizer.zero_grad() and optimizer.step() are only invoked when the number of accumulated training batches equals gradient_accumulation_steps (line 227, inside the if block at line 219).

https://github.com/google-research/xtreme/blob/master/third_party/run_classify.py
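
The pattern can be sketched in a few lines. This is not the code from the linked script, only a minimal illustration that reuses the two argument names; the model, data shapes, and learning rate are assumptions. It also checks that accumulating scaled gradients over several small batches gives the same gradient as one large batch:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)          # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

inputs = torch.randn(32, 4)            # made-up data
targets = torch.randn(32, 1)

train_batch_size = 8                   # what fits in memory per forward pass
gradient_accumulation_steps = 4        # batches accumulated per update

# Reference: the gradient of one big batch of 32 samples.
optimizer.zero_grad()
loss_fn(model(inputs), targets).backward()
full_batch_grad = model.weight.grad.clone()

optimizer.zero_grad()
batches = zip(inputs.split(train_batch_size), targets.split(train_batch_size))
for step, (x, y) in enumerate(batches, start=1):
    # Scale each loss so the accumulated sum matches the big-batch mean.
    loss = loss_fn(model(x), y) / gradient_accumulation_steps
    loss.backward()                    # called on every iteration
    if step % gradient_accumulation_steps == 0:
        accumulated_grad = model.weight.grad.clone()
        optimizer.step()               # called once per accumulated batch
        optimizer.zero_grad()

print(torch.allclose(accumulated_grad, full_batch_grad, atol=1e-6))
```

Dividing the loss by gradient_accumulation_steps is one common choice; without it the accumulated gradient is the sum rather than the mean over the small batches, which effectively scales the learning rate.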

Someone also asked about an equivalent method in TensorFlow. I guess tf.GradientTape serves the same purpose.

(I am still new to AI libraries; please correct me if anything I said is wrong.)


zero_grad() clears the gradients left over from the previous step, so that each iteration starts fresh when you use a gradient-based method to decrease the error (or loss).

If you do not use zero_grad(), the loss may increase instead of decreasing as required.

For example:

If you use zero_grad() you will get the following output:

    model training loss is 1.5
    model training loss is 1.4
    model training loss is 1.3
    model training loss is 1.2

If you do not use zero_grad() you will get the following output:

    model training loss is 1.4
    model training loss is 1.9
    model training loss is 2
    model training loss is 2.8
    model training loss is 3.5
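
To try this yourself, here is a small sketch that runs the same toy regression twice, once with and once without zero_grad(). The function train, the data, and the learning rate are all made up for illustration, and the exact numbers will differ from the ones quoted above. On this toy problem I would expect the run with zero_grad() to converge steadily, while the run without it tends to overshoot and oscillate rather than settle:

```python
import torch

def train(use_zero_grad, steps=50):
    # Hypothetical toy problem: recover three weights of a linear map.
    torch.manual_seed(0)
    x = torch.randn(64, 3)
    y = x @ torch.tensor([1.0, -2.0, 0.5])
    w = torch.zeros(3, requires_grad=True)
    optimizer = torch.optim.SGD([w], lr=0.1)

    losses = []
    for _ in range(steps):
        if use_zero_grad:
            optimizer.zero_grad()      # discard gradients from earlier steps
        loss = ((x @ w - y) ** 2).mean()
        loss.backward()                # adds to whatever is stored in w.grad
        optimizer.step()
        losses.append(loss.item())
    return losses

clean = train(use_zero_grad=True)
stale = train(use_zero_grad=False)

for name, losses in (("with zero_grad", clean), ("without zero_grad", stale)):
    for value in losses[:5]:
        print(f"{name}: model training loss is {value:.4f}")
```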