PyTorch 中的“with torch no_grad”有什么作用?
“with”torch.no_grad()的使用就像一个循环,其中循环内的每个张量都将requires_grad设置为False。这意味着当前与当前计算图相连的任何具有梯度的张量现在都与当前图分离。我们不再能够计算关于这个张量的梯度。
张量从当前图中分离,直到它在循环内。一旦它离开循环,如果张量是用梯度定义的,它就会再次附加到当前图。
让我们举几个例子来更好地理解它是如何工作的。
示例1
在这个例子中,我们用requires_grad=true创建了一个张量x。接下来,我们定义这个张量x的函数y并将该函数放入with循环中。现在x在循环内,所以它的requires_grad被设置为False。torch.no_grad()
在循环中,无法计算y相对于x的梯度。所以,y.requires_grad返回False。
# import torch library import torch # define a torch tensor x = torch.tensor(2., requires_grad = True) print("x:", x) # define a function y with torch.no_grad(): y = x ** 2 print("y:", y) # check gradient for Y print("y.requires_grad:", y.requires_grad)输出结果
x: tensor(2., requires_grad=True) y: tensor(4.) y.requires_grad: False
示例2
在这个例子中,我们在循环之外定义了函数z。所以,z.requires_grad返回True。
# import torch library import torch # define three tensors x = torch.tensor(2., requires_grad = False) w = torch.tensor(3., requires_grad = True) b = torch.tensor(1., requires_grad = True) print("x:", x) print("w:", w) print("b:", b) # define a function y y = w * x + b print("y:", y) # define a function z with torch.no_grad(): z = w * x + b print("z:", z) # check if requires grad is true or not print("y.requires_grad:", y.requires_grad) print("z.requires_grad:", z.requires_grad)输出结果
x: tensor(2.) w: tensor(3., requires_grad=True) b: tensor(1., requires_grad=True) y: tensor(7., grad_fn=<AddBackward0>) z: tensor(7.) y.requires_grad: True z.requires_grad: False