
PyTorch: how to set .requires_grad to False


If you want to freeze part of your model and train the rest, you can set the requires_grad attribute of the parameters you want to freeze to False.

For example, if you only want to keep the convolutional part of VGG16 fixed:

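```python
import torchvision

# Recent torchvision versions deprecate pretrained=True in favor of the weights= argument.
model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False
```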
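Once parameters are frozen, a common companion step is to give the optimizer only the parameters that still require gradients. The sketch below uses a small hypothetical `TinyNet` (its `features` and `classifier` parts are named to mirror VGG16, but no pretrained weights are downloaded) so it runs offline; the SGD optimizer and learning rate are arbitrary assumptions.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for VGG16: a frozen 'features' part and a trainable 'classifier'."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU())
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        x = self.features(x)       # (N, 4, H-2, W-2)
        x = x.mean(dim=(2, 3))     # global average pool -> (N, 4)
        return self.classifier(x)

model = TinyNet()

# Freeze the feature extractor, as in the VGG16 example.
for param in model.features.parameters():
    param.requires_grad = False

# Hand the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

x = torch.randn(2, 3, 8, 8)
loss = model(x).sum()
loss.backward()
optimizer.step()

# Frozen parameters receive no gradient; trainable ones do.
print([p.grad is None for p in model.features.parameters()])    # [True, True]
print([p.grad is None for p in model.classifier.parameters()])  # [False, False]
```

Filtering the parameter list is optional when the frozen parameters never get a gradient, but it makes the intent explicit.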

With the requires_grad flags set to False, autograd saves no intermediate buffers for these operations; it only starts recording once the computation reaches an operation where at least one input requires a gradient.
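One way to observe this is to inspect `requires_grad` and `grad_fn` on intermediate outputs. In this minimal sketch (two arbitrary `nn.Linear` layers, assumed for illustration), the frozen layer's output has no `grad_fn`, meaning nothing was recorded for it; the graph only begins at the first layer whose parameters require gradients.

```python
import torch
import torch.nn as nn

frozen = nn.Linear(3, 3)
trainable = nn.Linear(3, 1)

# Freeze the first layer.
for param in frozen.parameters():
    param.requires_grad = False

x = torch.randn(4, 3)   # plain input, does not require grad

h = frozen(x)           # no input requires grad, so nothing is recorded
print(h.requires_grad, h.grad_fn)              # False None

y = trainable(h)        # weight requires grad, so recording starts here
print(y.requires_grad, y.grad_fn is not None)  # True True
```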


Using the torch.no_grad context manager is another way to achieve this: inside a no_grad block, the results of all computations have requires_grad=False, even if their inputs have requires_grad=True. As a consequence, you won't be able to backpropagate the gradient through the no_grad block to layers that come before it. For example:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)

x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)  # computed without recording the graph
x3 = lin2(x2)
x3.sum().backward()
print((lin0.weight.grad, lin1.weight.grad, lin2.weight.grad))
```


```
(None, None, tensor([[-1.4481, -1.1789],
        [-1.4481, -1.1789]]))
```

Here lin1.weight.requires_grad was True, but its gradient wasn't computed because the operation was done in the no_grad context. lin0 received no gradient either, because backpropagation cannot pass through the no_grad block.
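The difference from freezing with `requires_grad = False` is easiest to see side by side. This sketch rebuilds the same three-layer setup so it runs on its own, but freezes `lin1` instead of wrapping it in `no_grad`: `lin1` still gets no gradient, yet the gradient now flows through it back to `lin0`. The seed is arbitrary and only makes runs repeatable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 2)
lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)

# Freeze lin1's parameters instead of using torch.no_grad().
for param in lin1.parameters():
    param.requires_grad = False

x3 = lin2(lin1(lin0(x)))
x3.sum().backward()

print(lin0.weight.grad is None)  # False: gradient flowed through the frozen lin1
print(lin1.weight.grad is None)  # True: frozen parameters get no gradient
print(lin2.weight.grad is None)  # False
```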


If your goal is not to fine-tune but to run your model in inference mode, the most convenient way is to use the torch.no_grad context manager. In this case you also have to put your model in evaluation mode, which you do by calling eval() on the nn.Module, for example:

```python
import torchvision

model = torchvision.models.vgg16(pretrained=True)
model.eval()
```

This sets the training attribute of the module and all of its submodules to False. In practice this changes the behavior of layers such as Dropout and BatchNorm, which must behave differently at training and test time.
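Putting the two together, a typical inference pass looks like the sketch below. The small `nn.Sequential` model with a `Dropout` layer is a hypothetical stand-in for VGG16, chosen so the effect of `eval()` is visible without downloading weights.

```python
import torch
import torch.nn as nn

# Hypothetical model containing a layer whose behavior depends on the mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))

model.eval()  # sets .training = False on the module and every submodule
print(all(not m.training for m in model.modules()))  # True

x = torch.randn(8, 4)
with torch.no_grad():  # no graph is recorded during inference
    out1 = model(x)
    out2 = model(x)

# In eval mode Dropout is the identity, so repeated calls agree.
print(torch.equal(out1, out2))  # True
print(out1.requires_grad)       # False
```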

Here is one way to do it:

```python
import torch
import torch.nn as nn

linear = nn.Linear(1, 1)
for param in linear.parameters():
    param.requires_grad = False
with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)  # False
```
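`nn.Module` also has a `requires_grad_()` method that sets the flag on all of a module's parameters at once, which can replace the explicit loop. A minimal sketch, assuming a PyTorch release that provides `Module.requires_grad_` (current releases do):

```python
import torch.nn as nn

linear = nn.Linear(1, 1)

# Equivalent to looping over linear.parameters() and setting requires_grad = False.
linear.requires_grad_(False)
print(all(not p.requires_grad for p in linear.parameters()))  # True

# Calling it with True unfreezes the parameters again.
linear.requires_grad_(True)
print(all(p.requires_grad for p in linear.parameters()))  # True
```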


To complete @Salih_Karagoz's answer, there is also the torch.set_grad_enabled() context manager (see its entry in the PyTorch documentation). Because it takes a boolean, it makes it easy to turn gradient computation on or off when switching between training and evaluation:

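```python
import torch
import torch.nn as nn

linear = nn.Linear(1, 1)
is_train = False
for param in linear.parameters():
    param.requires_grad = is_train
with torch.set_grad_enabled(is_train):
    linear.train(is_train)  # train(False) is equivalent to eval()
    print(linear.weight.requires_grad)  # False
```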
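A common use of this pattern is one step function shared by training and evaluation. The sketch below is hypothetical (the name `run_step` and the toy model are assumptions, not a PyTorch API); it pairs `torch.set_grad_enabled(is_train)` with `model.train(is_train)` so that both the grad mode and the module mode follow the same flag.

```python
import torch
import torch.nn as nn

def run_step(model, x, is_train):
    """Hypothetical helper: forward pass whose grad mode and module mode follow is_train."""
    model.train(is_train)  # train(False) is equivalent to eval()
    with torch.set_grad_enabled(is_train):
        out = model(x)
    return out

model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(p=0.5), nn.Linear(3, 1))
x = torch.randn(5, 3)

train_out = run_step(model, x, is_train=True)
print(train_out.requires_grad, model.training)  # True True

eval_out = run_step(model, x, is_train=False)
print(eval_out.requires_grad, model.training)   # False False
```

Grad mode is restored when the `with` block exits, so the caller's setting is not affected.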