PyTorch Autograd Case Study (create_graph for gradients and retain_graph for multiple backwards)
Torch provides automatic differentiation, so we never have to compute derivatives by hand.
The problem is that without knowing how it works, it is not easy to implement things correctly.
So I have collected a few use cases where autograd is used in torch.
1. Computing simple high-order derivatives
2. Computing gradients of a neural network with respect to the input or the weights
3. Model-Agnostic Meta Learning (MAML)
4. Hessian of a neural network (Hessian Vector Product)
5. Implicit Neural Representation
6. Multiple Backwards
---- If you know of other good use cases, please suggest them in the comments and I will add them
---- Code used in this post : https://github.com/fxnnxc/torch_autograd/blob/main/autograd_case_study.ipynb
---- Autograd introduction (PyTorch) : https://tutorials.pytorch.kr/beginner/blitz/autograd_tutorial.html
---- Automatic differentiation introduction (YouTube) : https://www.youtube.com/watch?v=MswxJw-8PvE (highly recommended)
1. Computing simple high-order derivatives
If we want the derivatives of a function $f$ with respect to an input $x$, we can apply $\frac{d}{dx}$ repeatedly:
$$\frac{df}{dx} ~~~~ \frac{d}{dx}\Big(\frac{df}{dx}\Big) ~~~~ \frac{d}{dx}\Big(\frac{d^2f}{dx^2}\Big)$$
Similarly, in torch we can call torch.autograd.grad(f, x, create_graph=True) to obtain a gradient and then differentiate that gradient again. The role of create_graph is to build a computational graph for the gradient itself. The gradient obtained this way is a tensor that is connected to the graph, so any further operation on it (squaring, addition, even another gradient) can be differentiated as well.
import torch
v = -2.0
x = torch.tensor(v, requires_grad=True)
# function
f = 2 * x**3
print(f"x={v}, f = {f}")
# first order
f.backward(retain_graph=True) # retain the graph because we will call backward on it again below
print(f"x={v}, df/dx = {x.grad}")
# second order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
df.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^2f/dx^2 = {x.grad}")
# third order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
ddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^3f/dx^3 = {x.grad}")
Running this code, you can check that the values computed with autograd match the analytic values above.
2. Computing gradients of a neural network with respect to the input or the weights
If $f(x)$ is a deep neural network with parameters $\theta$, gradients can be computed with respect to two things:
1. with respect to the input $x$ : $df/dx$
2. with respect to the parameters $\theta$ : $df/d\theta$
We can even take gradients of a quantity aggregated from the parameter gradients themselves, for example the sum of squared gradients
$$ G = \sum_\theta \Big(\frac{df}{d\theta} \Big)^2 $$
One trick in the implementation is to flatten the parameters.
Instead of keeping them split per layer as ([A,A,A], [B,B], [C], [D]), they are flattened into the form [A,A,A,B,B,C,D] (a short sketch of this appears after the code below).
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2, 5)
        self.out = nn.Linear(5, 1)

    def forward(self, x):
        x = self.mlp(x**2)
        x = nn.functional.relu(x)
        x = self.out(x)
        return x

net = Net()
input = torch.tensor([1.0]*2, requires_grad=True)
params = [p for p in net.parameters() if len(p.size()) > 1]  # weight matrices only (skip biases)

# gradient with respect to the input
print("------------ Gradient with respect to the input")
f = net(input)
f.backward()
print(input.grad)

# gradient with respect to the weights
print("------------ Gradient with respect to the weights")
input.grad.zero_()
f = net(input)
gx = torch.autograd.grad(f, params, create_graph=True)[0]  # df/dW for the first weight matrix
for i in range(2):
    # create_graph=True lets us backprop through gx again;
    # the mixed second derivatives accumulate into input.grad
    gx[i].sum().backward(retain_graph=True)
print(input.grad)

# gradient of the weights to minimize the sum of squared gradients
print("------------ Gradients of the sum of squared gradients")
net.zero_grad()
f = net(input)
gx = torch.autograd.grad(f, params, create_graph=True)[0]
(gx**2).sum().backward()
print([p.grad for p in params])
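As mentioned above, a common convenience is to flatten the per-layer gradients into a single vector before aggregating them. Below is a minimal sketch of that trick, reusing net, params, and input from the snippet above (the name flat_grad is only for illustration).

# flatten the per-parameter gradients into one 1-D vector
net.zero_grad()
f = net(input)
grads = torch.autograd.grad(f, params, create_graph=True)   # tuple: one tensor per weight matrix
flat_grad = torch.cat([g.reshape(-1) for g in grads])       # 1-D vector with all weight-gradient entries
print(flat_grad.shape)
# the aggregated quantity G = sum_theta (df/dtheta)^2 from the formula above
G = (flat_grad ** 2).sum()
G.backward()                                                 # fills p.grad for every parameter that influences the gradients
print([p.grad for p in params])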
3. Model-Agnostic Meta Learning (MAML)
The following is the gradient update used in MAML.
MAML trains by (1) computing updated weights on the meta-train data, and (2) minimizing the loss of those updated weights on the meta-test data. In step (1) the updated weights are obtained as $\theta' = \theta - \alpha \nabla_\theta L_{task} (f(\theta))$; because this gradient is computed with respect to $\theta$ while keeping its graph, $\theta'$ remains a differentiable function of $\theta$. Step (2) then computes the gradient through this update once more:
$$ \theta \leftarrow \theta - \nabla_\theta L_{meta} (f(\theta')) $$
Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2, 5)
        self.out = nn.Linear(5, 1)

    def forward(self, x, fast_weights=None):
        if fast_weights is None:
            x = self.mlp(x)
            x = nn.functional.relu(x)
            x = self.out(x)
        else:  # same logic, but with the given fast weights
            w, b = fast_weights[0], fast_weights[1]
            x = F.linear(x, w, b)
            x = nn.functional.relu(x)
            w, b = fast_weights[2], fast_weights[3]
            x = F.linear(x, w, b)
        return x

net = Net()
lr = 1e-3
meta_optim = torch.optim.Adam(net.parameters(), lr=lr)
meta_loss = 0
num_tasks = 3
for task in range(num_tasks):
    # Step 1 : compute the one-step updated weights on the meta-train input
    input = torch.tensor([(task + 1.0)] * 2)
    f = net(input)
    grad = torch.autograd.grad(f, net.parameters(), create_graph=True)  # Step 1 Core
    fast_weights = list(map(lambda p: p[1] - lr * p[0], zip(grad, net.parameters())))
    # Step 2 : compute the loss of the one-step weights on the meta-test input
    meta_input = torch.tensor([(task + 2.0)] * 2)
    test_loss = net(meta_input, fast_weights)  # Step 2 Core
    meta_loss += test_loss
meta_loss /= num_tasks

# optimize the original theta parameters through the inner update
meta_optim.zero_grad()
meta_loss.backward()
meta_optim.step()
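As a quick sanity check (reusing the variables from the snippet above): because Step 1 uses create_graph=True, the fast weights are non-leaf tensors that carry a grad_fn, which is what lets meta_loss.backward() differentiate through the inner update.

# the inner-update result is still attached to the graph of the original parameters
print(fast_weights[0].grad_fn)        # a backward node (not None), e.g. a subtraction node
print(fast_weights[0].requires_grad)  # True
# with create_graph=False the inner gradients would be detached, and the meta update
# would fall back to a first-order approximation of MAML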
4. Hessian of a neural network (Hessian Vector Product)
The product of the Hessian matrix $H(\theta)_{N\times N} = \frac{d^2}{d\theta^2} \Big( L(\theta) \Big)$ with a vector $r_{N\times 1}$ is
$$ \begin{aligned} H(\theta) \cdot r &= \frac{d^2}{d\theta^2} \Big( L(\theta) \Big) \cdot r \\ &= \frac{d}{d\theta} \Big( \frac{d}{d\theta} L(\theta) \cdot r \Big) \end{aligned} $$
Here $r$ is a vector that does not depend on $\theta$, so moving it inside the outer derivative does not change the second derivative. This gives a two-step recipe:
Step 1. Compute the gradient $\frac{d}{d\theta} L(\theta)$
Step 2. Compute the gradient-vector product $\Big(\frac{d}{d\theta} L(\theta) \Big) \cdot r$ and backpropagate it once more.
net = Net()
params = [p for p in net.parameters() if len(p.size()) > 1]
N = sum(p.numel() for p in params)
print(f"Number of parameters : {N}")
print("Params", params)

# Compute the gradients
input = torch.tensor([1.0] * 2)
f = net(input)
grad = torch.autograd.grad(f, inputs=params, create_graph=True)  # Step 1 Core

# Hessian Vector Product
vec = [torch.rand_like(g) for g in grad]  # random vector r, one block per parameter
prod = torch.zeros(1)
for (g, v) in zip(grad, vec):
    prod = prod + (g * v).sum()  # Step 2 Core : (dL/dtheta) . r
prod.backward()  # now p.grad stores the Hessian-vector product for each parameter
print("----------------------")
print("Hessian Vector Product")
for p in params:
    print(p.grad)
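As a small check of the recipe (independent of the network above), the double-backward HVP can be compared against torch.autograd.functional.hvp on a function whose Hessian is known in closed form:

import torch

def L(theta):
    return (theta ** 3).sum()            # gradient: 3*theta^2, Hessian: diag(6*theta)

theta = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
r = torch.tensor([0.5, -1.0, 2.0])

# double-backward HVP, same recipe as above
g = torch.autograd.grad(L(theta), theta, create_graph=True)[0]
(g * r).sum().backward()
print(theta.grad)                        # 6*theta*r = [3., -12., 36.]

# reference value
_, hvp = torch.autograd.functional.hvp(L, theta, r)
print(hvp)                               # the same values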
5. Implicit Neural Representation
An image can be understood as a function: a function that predicts the color at each pixel position. Because of this, an image also has gradients, and we can compute, for example, the gradient of the red channel in the x and y directions. For details, see the Sobel filter. The gradient of an image can also be fitted with a neural network, and this line of work is called Implicit Neural Representation.
1. Original Function : $ f: p(x,y) \rightarrow (r,g,b) $
2. Neural Network Fitting : a neural network $\phi \approx f$ that approximates the function above
The objective for fitting the function splits into two cases (fitting the pixel values, or fitting the gradients):
1. Pixel Fitting : compute the color at pixel position $p'$ $$ \mathcal{L} = ||f(p') - \phi(p')||^2 $$
2. Gradient Fitting : compute the gradient of color at pixel position $p'$
$$ \mathcal{L} = ||\frac{df}{dp}\Big|_{p'} - \frac{d\phi(p)}{dp}\Big|_{p'}||^2 $$
class ImplicitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(2, 64),   # input: the (x, y) position
            nn.ReLU(),
            nn.Linear(64, 3)    # RGB prediction
        )

    def forward(self, x):
        # re-attach the coordinates as a leaf that requires grad,
        # so the color can be differentiated with respect to the position
        x = x.clone().detach().requires_grad_(True)
        return self.layer(x), x

net = ImplicitNet()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# given: position (x, y) --> output: color prediction, position
coordinates = torch.tensor([[0.5, 0.5]])
all_color_prediction, coordinates = net(coordinates)

# Matching the gradient
# assume a ground-truth gradient
gt = torch.rand(1, 2)
loss = 0
for channel in range(3):  # each of R, G, B
    color_predict = all_color_prediction[:, channel]
    grad_predict = torch.autograd.grad(color_predict,
                                       coordinates,
                                       create_graph=True)[0]  # create_graph makes the optimization possible
    loss += F.mse_loss(gt, grad_predict)
optimizer.zero_grad()
loss.backward()
optimizer.step()
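In practice the ground-truth gradient gt would come from the image itself rather than torch.rand. A minimal sketch, assuming a dummy RGB image tensor img and using simple finite differences as a crude stand-in for a Sobel filter:

# hypothetical ground-truth image gradients via finite differences
img = torch.rand(3, 32, 32)                # dummy RGB image (channels, height, width)
dx = img[:, :, 1:] - img[:, :, :-1]        # horizontal differences, shape (3, 32, 31)
dy = img[:, 1:, :] - img[:, :-1, :]        # vertical differences, shape (3, 31, 32)
# the (dx, dy) values at a pixel would play the role of gt for the matching coordinate
print(dx.shape, dy.shape)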
6. Multiple Backwards
When gradient descent is run on two branches that share a common trunk, as in the model below,
retain_graph=True keeps the computational graph alive after the first backward so that a second backward can be run on the other branch.
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.common = nn.Linear(2, 3)
        self.net1 = nn.Sequential(
            nn.Linear(3, 5),
            nn.ReLU(),
            nn.Linear(5, 1)
        )
        self.net2 = nn.Sequential(
            nn.Linear(3, 5),
            nn.ReLU(),
            nn.Linear(5, 1)
        )

    def forward(self, x):
        common = self.common(x)
        path1 = self.net1(common)
        path2 = self.net2(common)
        return path1, path2

net = Net()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
input = torch.rand(1, 2)
out_1, out_2 = net(input)

# optimize net1 and common
optimizer.zero_grad()
out_1.backward(retain_graph=True)  # Core Code for retaining the computational graph
optimizer.step()

# optimize net2 and common (again)
optimizer.zero_grad()
out_2.backward()
optimizer.step()
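For comparison, here is a minimal sketch of what happens when retain_graph is omitted: the first backward frees the saved buffers of the shared part of the graph, and the second backward raises a RuntimeError.

# without retain_graph=True the shared graph is freed by the first backward
net = Net()
out_1, out_2 = net(torch.rand(1, 2))
out_1.backward()                 # frees the graph through self.common
try:
    out_2.backward()             # needs the already-freed part of the graph
except RuntimeError as e:
    print(e)                     # "Trying to backward through the graph a second time ..."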
If you have another good use case, please share it!
I will add it to the post with a reference~!
Please give feedback on any incorrect implementations :)
Code used in this post : https://github.com/fxnnxc/torch_autograd/blob/main/autograd_case_study.ipynb
Likes and comments give me the energy to keep posting