PyTorch: how to set .requires_grad to False
requires_grad=False
If you want to freeze part of your model and train the rest, you can set requires_grad of the parameters you want to freeze to False.
For example, if you only want to keep the convolutional part of VGG16 fixed:
import torchvision

model = torchvision.models.vgg16(pretrained=True)
# Freeze the convolutional layers
for param in model.features.parameters():
    param.requires_grad = False
By switching the requires_grad flags to False, no intermediate buffers will be saved until the computation gets to some point where one of the inputs of the operation requires the gradient.
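As a minimal sketch of the freezing pattern described above, the snippet below uses a small hypothetical two-layer model (the names `features` and `classifier` are stand-ins, not the real VGG16 modules) so it runs without downloading weights. It assumes you also want to hand only the trainable parameters to the optimizer, which is a common choice but not required.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained network:
# "features" is frozen, "classifier" is trained.
model = nn.ModuleDict({
    "features": nn.Linear(4, 8),
    "classifier": nn.Linear(8, 2),
})

# Freeze the feature extractor.
for param in model["features"].parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that still require gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

x = torch.randn(3, 4)
out = model["classifier"](torch.relu(model["features"](x)))
out.sum().backward()
optimizer.step()

frozen_grad = model["features"].weight.grad      # None: no gradient was computed
trained_grad = model["classifier"].weight.grad   # a tensor of shape (2, 8)
print(frozen_grad is None, trained_grad.shape)
```

After `backward()`, the frozen layer's `.grad` stays `None`, while the classifier receives gradients as usual.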
torch.no_grad()
Using the context manager torch.no_grad is a different way to achieve that goal: in the no_grad context, all the results of the computations will have requires_grad=False, even if the inputs have requires_grad=True. Notice that you won't be able to backpropagate the gradient to layers before the no_grad. For example:
import torch
import torch.nn as nn

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)

x1 = lin0(x)
with torch.no_grad():
    x2 = lin1(x1)  # not tracked by autograd
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)
outputs:
(None, None, tensor([[-1.4481, -1.1789], [-1.4481, -1.1789]]))
Here lin1.weight.requires_grad was True, but the gradient wasn't computed because the operation was done in the no_grad context.
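To make the effect of the context manager easier to see in isolation, here is a small sketch that runs the same operation inside and outside `torch.no_grad()` and inspects the result. The variable names are illustrative only.

```python
import torch

x = torch.ones(2, 2, requires_grad=True)

y = x * 2                  # tracked by autograd
with torch.no_grad():
    z = x * 2              # same operation, not tracked

outside_flag = y.requires_grad   # True
inside_flag = z.requires_grad    # False, even though x requires grad
z_has_no_graph = z.grad_fn is None

# The input itself is unchanged by the context manager.
input_flag = x.requires_grad     # True
print(outside_flag, inside_flag, z_has_no_graph, input_flag)
```

The result `z` has no `grad_fn`, which is why nothing computed before it can receive a gradient through it.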
model.eval()
If your goal is not to finetune but to run your model in inference mode, the most convenient way is to use the torch.no_grad context manager. In this case you also have to set your model to evaluation mode, which is done by calling eval() on the nn.Module, for example:
model = torchvision.models.vgg16(pretrained=True)
model.eval()
This operation sets the attribute self.training of the layers to False. In practice, this changes the behavior of operations like Dropout or BatchNorm, which must behave differently at training and test time.
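The following sketch illustrates the difference on a standalone `nn.Dropout` layer. It also shows that `eval()` by itself does not turn off gradient tracking, which is why it is usually combined with `torch.no_grad()` for inference.

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)   # entries are either 0 or scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)    # dropout is a no-op in evaluation mode

# eval() only changes layer behavior; autograd is still active.
w = torch.ones(3, requires_grad=True)
still_tracked = drop(w).sum().requires_grad

print(drop.training, torch.equal(eval_out, x), still_tracked)
```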
Here is the way:
linear = nn.Linear(1, 1)
for param in linear.parameters():
    param.requires_grad = False
with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)
OUTPUT: False
To complete @Salih_Karagoz's answer, you also have the torch.set_grad_enabled() context (further documentation here), which can be used to easily switch between train/eval modes:
linear = nn.Linear(1, 1)
is_train = False
for param in linear.parameters():
    param.requires_grad = is_train
with torch.set_grad_enabled(is_train):
    linear.eval()
    print(linear.weight.requires_grad)
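One possible way to package this pattern is a small helper that takes a single flag and uses it for both the module mode and gradient tracking. The function name `forward_pass` is hypothetical, and this variant leaves the parameters' own `requires_grad` flags untouched, which is a design choice rather than something the snippet above requires.

```python
import torch
from torch import nn

def forward_pass(model, inputs, is_train):
    """Run the model in train or eval mode, tracking gradients only when training."""
    model.train(is_train)                    # same as model.eval() when is_train is False
    with torch.set_grad_enabled(is_train):
        return model(inputs)

model = nn.Linear(1, 1)
inputs = torch.randn(4, 1)

train_out = forward_pass(model, inputs, is_train=True)
eval_out = forward_pass(model, inputs, is_train=False)

print(train_out.requires_grad, eval_out.requires_grad)
```

Because the parameters are never modified, the same model can go back to training on the next call without restoring any flags.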