Deep Learning

Pytorch Autograd Case Study (create_graph for gradients and retain_graph for multiple backwards)

Rudi 2022. 11. 7. 16:24

๐Ÿ”ฅ Torch์—๋Š” ๐Ÿ”–์ž๋™๋ฏธ๋ถ„์ด๋ผ๋Š” ๊ธฐ๋Šฅ์ด ์žˆ์–ด์„œ, ๋ฏธ๋ถ„์„ ์šฐ๋ฆฌ๊ฐ€ ๊ตณ์ด ํ•˜์ง€ ์•Š์•„๋„ ๋ฉ๋‹ˆ๋‹ค. 

๊ทธ๋Ÿฐ๋ฐ ๋ฌธ์ œ๋Š” ์–ด๋–ป๊ฒŒ ๋™์ž‘ํ•˜๋Š”์ง€ ๋ชจ๋ฅด๋‹ˆ, ๊ตฌํ˜„์„ ์ œ๋Œ€๋กœ ํ•˜๊ธฐ ์‰ฝ์ง€ ์•Š์Šต๋‹ˆ๋‹ค. 

So I have organized a few use cases where 🔖autograd is used in torch.

1.  Simply computing a high-order derivative

 

2.  Computing gradients with respect to the input or the weights of a neural network

 

3.  Model-Agnostic Meta-Learning (MAML) : meta-learning

 

4.  Hessian of a neural network (Hessian Vector Product)

 

5.  Implicit Neural Representation

 

6.  Multiple Backwards

 

 

---- If you know other good use cases, please introduce them in the comments and I will add them 🙂

---- Code used in this post : https://github.com/fxnnxc/torch_autograd/blob/main/autograd_case_study.ipynb

---- Autograd introduction (PyTorch) : https://tutorials.pytorch.kr/beginner/blitz/autograd_tutorial.html

---- Automatic differentiation introduction (YouTube) : https://www.youtube.com/watch?v=MswxJw-8PvE (highly recommended)

 


1.  🔖 Simply computing a high-order derivative

๊ฐ€๋ น ์ธํ’‹ $x$์— ๋Œ€ํ•˜์—ฌ ํ•จ์ˆ˜๊ฐ’$f$ ์˜ ๋ฏธ๋ถ„๊ณ„์ˆ˜๋ฅผ ๊ตฌํ•œ๋‹ค๋ฉด, $\frac{d}{dx}$ ๋ฅผ ์—ฌ๋Ÿฌ๋ฒˆ ํ•˜์—ฌ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

 

$$\frac{df}{dx} ~~~~ \frac{d}{dx}\Big(\frac{df}{dx}\Big) ~~~~ \frac{d}{dx}\Big(\frac{d^2f}{dx^2}\Big)$$
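For the concrete function used in the code below, $f(x) = 2x^3$, these derivatives are easy to write down by hand, which is handy for checking the autograd results:

$$\frac{df}{dx} = 6x^2 \qquad \frac{d^2f}{dx^2} = 12x \qquad \frac{d^3f}{dx^3} = 12$$

At $x = -2$ they evaluate to $24$, $-24$, and $12$.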

 

Similarly, in torch we can take the gradient of a gradient obtained with  torch.autograd.grad(f, x, create_graph=True) . The role of 🔖create_graph is to build a computational graph for the gradient itself. The gradient we get back is therefore not a detached 🔖leaf value but a node in the graph, so we can keep computing with it (squares, sums, even another gradient).

 

import torch 
v = -2.0
x = torch.tensor(v, requires_grad=True)

# function 
f = 2 * x**3
print(f"x={v},         f = {f}")

# first order
f.backward(retain_graph=True)  # keep the graph alive so we can call backward on it again below
print(f"x={v},     df/dx = {x.grad}")

# second order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
df.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^2f/dx^2 = {x.grad}")

# third order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
ddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^3f/dx^3 = {x.grad}")

 

If you run this code, you can check that the values computed by autograd match the analytic values.
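As a quick cross-check (a minimal sketch of my own, not part of the original notebook), the same derivatives can be computed with three consecutive torch.autograd.grad calls and compared against the analytic values above:

import torch

x = torch.tensor(-2.0, requires_grad=True)
f = 2 * x**3

df   = torch.autograd.grad(f,   x, create_graph=True)[0]   # 6x^2  ->  24
ddf  = torch.autograd.grad(df,  x, create_graph=True)[0]   # 12x   -> -24
dddf = torch.autograd.grad(ddf, x)[0]                       # 12

assert torch.allclose(df,   torch.tensor(24.0))
assert torch.allclose(ddf,  torch.tensor(-24.0))
assert torch.allclose(dddf, torch.tensor(12.0))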

 

 

 


2.  🔖 Computing gradients with respect to the input or the weights of a neural network

If $f(x)$ is a deep neural network with parameters $\theta$, gradients can be computed with respect to two things:

 

1. with respect to the input $x$ : $df/dx$

 

2. with respect to the parameters $\theta$ : $df/d\theta$

 

ํ˜น์€ ์‹ฌ์ง€์–ด ํŒŒ๋ผ๋ฏธํ„ฐ์— ๋Œ€ํ•œ gradient ๋ฅผ ํ‰๊ท ๋‚ธ ๊ฐ’์— ๋Œ€ํ•ด์„œ๋„ ๊ตฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. 

$$ G = \sum_\theta \Big(\frac{df}{d\theta} \Big)^2 $$

 

๊ตฌํ˜„์—์„œ ํ•œ ๊ฐ€์ง€ ํŠธ๋ฆญ์€ ๐Ÿ”–ํŒŒ๋ผ๋ฏธํ„ฐ๋“ค์„ ์ผ์ž๋กœ ๋งŒ๋“œ๋Š” ๊ฒƒ ์ž…๋‹ˆ๋‹ค.

๊ตณ์ด ([A,A,A], [B,B], [C], [D]) ๋ ˆ์ด์–ด๋กœ ๋‚˜๋ˆ ์ ธ ์žˆ์„ ํ•„์š”๊ฐ€ ์—†์ด, [A,A,A,B,B,C,D] ํ˜•ํƒœ๋กœ ๋งŒ๋“ญ๋‹ˆ๋‹ค. 

import torch
import torch.nn as nn 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2, 5)
        self.out = nn.Linear(5, 1)
    def forward(self, x):
        x = self.mlp(x**2)
        x = nn.functional.relu(x)
        x = self.out(x)
        return x 
        
net = Net()
input = torch.tensor([1.0]*2, requires_grad=True)
params = [p for p in net.parameters() if len(p.size()) > 1]  # keep only the weight matrices

# gradient with respect to the input
print("------------ Gradient with respect to the input")
f = net(input)
f.backward()
print(input.grad)   # df/dx


# gradient with respect to the weights 
print("------------ Gradient with respect to the weights")
input.grad.zero_() 
f = net(input)
gx = torch.autograd.grad(f, params, create_graph=True)[0]   # df/dW for the first layer's weight, shape (5, 2)
print(gx)
# because of create_graph=True we can backward through this gradient again;
# each backward below accumulates the mixed second derivative d(gx[i])/dx into input.grad
for i in range(2):
    gx[i].sum().backward(retain_graph=True)
    print(input.grad)
    
# gradient with respect to the weights of the sum of squared gradients
print("------------ Gradients of the sum of squared gradients")
net.zero_grad()
f = net(input)
gx = torch.autograd.grad(f, params, create_graph=True)[0]   # only the first layer's weight here, for brevity
(gx**2).sum().backward()

print([p.grad for p in params])
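The 🔖flattening trick mentioned above does not appear in the snippet itself, so here is a minimal sketch of it (my own addition; it reuses net, input, and params from the code above, and torch.nn.utils.parameters_to_vector does the analogous concatenation for parameters):

# flatten the per-layer gradients ([A,A,A], [B,B], ...) into one long vector [A,A,A,B,B,...]
net.zero_grad()
f = net(input)
grads = torch.autograd.grad(f, params, create_graph=True)   # one gradient tensor per weight matrix
flat_grad = torch.cat([g.reshape(-1) for g in grads])       # single vector of length N
G = (flat_grad ** 2).sum()                                  # G = sum_theta (df/dtheta)^2 over all weights
G.backward()                                                # dG/dtheta accumulated into each p.grad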

 

 


3.  Model-Agnostic Meta-Learning (MAML) : Meta-Learning

๋‹ค์Œ ์˜ˆ์‹œ๋Š” MAML์—์„œ ์‚ฌ์šฉํ•˜๋Š” Gradient Update ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค. 

MAML์€ (1) meta-train ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด์„œ weight๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , (2) meta-test ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด์„œ update ๋œ weight๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค. ์ด ๋•Œ, (1) ์—์„œ ๊ตฌํ•œ weight๋Š” $\theta' = \theta - \alpha \nabla_\theta L_{task} (f(\theta))$ ์œผ๋กœ ๊ตฌํ•ด์ง€๋Š”๋ฐ, gradient๊ฐ€ $\theta$ ์— ๋Œ€ํ•ด์„œ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค. ์ด๋กœ๋ถ€ํ„ฐ (2) ์—์„œ ๋‹ค์‹œ gradient ๋ฅผ ๊ณ„์‚ฐํ•ด์ค๋‹ˆ๋‹ค. 

$$ \theta \leftarrow \theta - \nabla_\theta L_{\text{meta}} (f(\theta')) $$

Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

import torch 
import torch.nn as nn 
import torch.nn.functional as F 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2,5)
        self.out = nn.Linear(5, 1)
    def forward(self, x, fast_weights=None):
        if fast_weights is None:
            x = self.mlp(x)
            x = nn.functional.relu(x)
            x = self.out(x)
        else:  # same forward pass, but with the given fast weights
            w, b = fast_weights[0], fast_weights[1]
            x = F.linear(x, w, b)
            x = nn.functional.relu(x)
            w, b = fast_weights[2], fast_weights[3]
            x = F.linear(x, w, b)
        return x 

net = Net()
lr=1e-3
meta_optim = torch.optim.Adam(net.parameters(), lr=lr)


meta_loss = 0
num_tasks = 3
for task in range(num_tasks):
    # Step 1 : compute one_step weight
    input= torch.tensor([(task+1.0)]*2)
    f = net(input)
    grad = torch.autograd.grad(f, net.parameters(), create_graph=True)    # Step 1 Core
    fast_weights = list(map(lambda p: p[1] - lr * p[0], zip(grad, net.parameters())))

    # Step 2 : compute the loss with one step weight
    meta_input = torch.tensor([(task+2.0)]*2)
    test_loss = net(meta_input, fast_weights)      # Step 2 Core : evaluate with the one-step (fast) weights
    meta_loss += test_loss 

meta_loss /= num_tasks

# optimize theta parameters
meta_optim.zero_grad()
meta_loss.backward()
meta_optim.step()
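A side note of my own (not from the original post): the create_graph=True in Step 1 is what lets the meta-gradient in Step 2 flow back through the inner update $\theta' = \theta - \alpha \nabla_\theta L_{\text{task}}$. If you drop it, the inner gradient is treated as a constant with respect to $\theta$, and you effectively get the first-order approximation of MAML.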

 

 


4.  Hessian of a neural network (Hessian Vector Product)

The product of the Hessian matrix $H(\theta)_{N\times N} =  \frac{d^2}{d\theta^2} \Big(  L(\theta) \Big)$ and a vector $r_{N\times 1}$ can be computed as

 

$$ \begin{aligned} H(\theta) \cdot r &= \frac{d^2}{d\theta^2} \Big(  L(\theta) \Big) \cdot r \\ &= \frac{d}{d\theta} \Big( \frac{d}{d\theta} L(\theta) \cdot r \Big) \end{aligned} $$

 

Here, $r$ is a vector that does not depend on $\theta$, so it does not affect the computation of the second derivative.

 

 

โœ… Step 1. Compute the gradient   $\frac{d}{d\theta} L(\theta)$

 

โœ… Step 2. Compute the gradient times vector $\Big(\frac{d}{d\theta} L(\theta) \Big) \cdot r$

 

net = Net()
params = [p for p in net.parameters() if len(p.size()) > 1]
N = sum(p.numel() for p in params)
print(f"Number of parameters : {N}")
print("Params", params)

# Compute the gradients
input = torch.tensor([1.0]*2)
f = net(input)
grad = torch.autograd.grad(f, inputs=params, create_graph=True)    # Step 1 Core

# Hessian Vector Product
vec = [torch.rand_like(g) for g in grad]   # the random vector r, one block per parameter
prod = torch.zeros(1)
for (g, v) in zip(grad, vec):
    prod = prod + (g * v).sum()            # Step 2 Core : (dL/dtheta) . r
    
prod.backward()  # now each p.grad stores the Hessian-vector product H(theta) . r

print("----------------------")
print("Hessian Vector Product")
for p in params:
    print(p.grad)
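As a small sanity check of this double-backward trick (my own sketch, not in the original notebook): for $L(\theta) = \sum_i \theta_i^2$ the Hessian is $2I$, so the product should simply be $2r$; torch.autograd.functional.hvp returns the same value.

theta = torch.randn(4, requires_grad=True)
r = torch.randn(4)

L = (theta ** 2).sum()
g = torch.autograd.grad(L, theta, create_graph=True)[0]          # dL/dtheta = 2*theta
hvp = torch.autograd.grad((g * r).sum(), theta)[0]               # d/dtheta (g . r) = H @ r

print(hvp)                                                       # equals 2 * r
print(torch.autograd.functional.hvp(lambda t: (t**2).sum(), theta, r)[1])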

 

 


5.  Implicit Neural Representation

 

 

์ด๋ฏธ์ง€๋„ ํ•จ์ˆ˜๋กœ ์ดํ•ดํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.  ๊ฐ ํ”ฝ์…€ ์œ„์น˜์— ๋Œ€ํ•ด์„œ ์ƒ‰๊น”์„ ์˜ˆ์ธกํ•˜๋Š” ํ•จ์ˆ˜๋กœ ๋ง์ด์ฃ . ๊ทธ๋ ‡๊ธฐ ๋•Œ๋ฌธ์—, ์ด๋ฏธ์ง€์— ๋Œ€ํ•ด์„œ๋„ Gradient ๊ฐ€ ์žˆ์œผ๋ฉฐ, ์œ„ ๊ทธ๋ฆผ์ฒ˜๋Ÿผ x ๋ฐฉํ–ฅ, y๋ฐฉํ–ฅ์— ๋Œ€ํ•ด์„œ ๋นจ๊ฐ„์ƒ‰ ์ฒด๋„์˜ Gradient๋ฅผ ๊ตฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๐Ÿ”–Sobel Filter ๋ฅผ ์ฐธ์กฐํ•ด์ฃผ์„ธ์š”.  ์ด๋ฏธ์ง€์˜ Gradient๋ฅผ ์‹ ๊ฒฝ๋ง์œผ๋กœ Fitting ํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ด ๋ถ„์•ผ๋ฅผ Implicit Neural Representation ์ด๋ผ๊ณ  ๋ถ€๋ฆ…๋‹ˆ๋‹ค. 

 

1. Original Function : $ f: p(x,y) \rightarrow  (r,g,b) $

2. Neural Network Fitting : a neural network $\phi \approx f$ that approximates the function above

 

The objective for fitting the function splits into two parts (pixel values, gradient matching).

 

1. Pixel Fitting : compute the color at pixel position $p'$ $$ \mathcal{L} = ||f(p') - \phi(p')||^2 $$ 

2. Gradient Fitting : compute the gradient of color at pixel position $p'$ 

$$ \mathcal{L} = ||\frac{df}{dp}\Big|_{p'} - \frac{d\phi(p)}{dp}\Big|_{p'}||^2 $$

 

class ImplicitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(2, 64), # Input the x,y position
            nn.ReLU(),
            nn.Linear(64, 3)  # RGB prediction
        )
    def forward(self, x):
        x = x.clone().detach().requires_grad_(True)
        return self.layer(x), x
        
net = ImplicitNet()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# given: position(x,y) --> output: ColorPrediction, position
coordinates = torch.tensor([[0.5, 0.5]]) 
all_color_prediction, coordinates = net(coordinates)

# Matching the gradient 
# assume the ground truth gradient
gt = torch.rand(1, 2) 

loss = 0 
for channel in range(3):  # each R G B
    color_predict = all_color_prediction[:, channel]
    grad_predict = torch.autograd.grad(color_predict, 
                                       coordinates,    
                                       create_graph=True)[0]  # create_graph makes the optimization possible. 
    loss+= F.mse_loss(gt, grad_predict)
    
optimizer.zero_grad()
loss.backward()
optimizer.step()

 


6.  Multiple Backwards

When gradient descent is run from two branches that share a common trunk, as in the network below,

retain_graph=True lets us keep the computational graph after the first backward so that a second backward is still possible.

import torch 
import torch.nn as nn 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.common = nn.Linear(2, 3)
        
        self.net1 = nn.Sequential(
            nn.Linear(3, 5),
            nn.ReLU(),
            nn.Linear(5,1)
        )
        self.net2 = nn.Sequential(
            nn.Linear(3, 5),
            nn.ReLU(),
            nn.Linear(5,1)
        )   
        
    def forward(self, x):
        common = self.common(x)
        path1 = self.net1(common)
        path2 = self.net2(common)
        return path1, path2 
    

net = Net()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
input = torch.rand(1, 2)

out_1, out_2 = net(input)

# optimize net1 and common
optimizer.zero_grad()
out_1.backward(retain_graph=True)  # Core Code for retaining computational graph
optimizer.step()

# optimize net2 and common (again)
optimizer.zero_grad()
out_2.backward()
optimizer.step()
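A design note of my own (not from the original post): if both branches are optimized by the same optimizer anyway, you can also add the two outputs into one scalar and call backward a single time, which avoids retain_graph entirely (reusing net, input, and optimizer from above):

out_1, out_2 = net(input)

optimizer.zero_grad()
(out_1 + out_2).backward()    # one backward pass through both branches and the common trunk
optimizer.step()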

 

๊ดœ์ฐฎ์€ Use Case๊ฐ€ ์žˆ๋‹ค๋ฉด ๊ณต์œ ํ•ด์ฃผ์„ธ์š” ๐Ÿค— 

Reference๋ฅผ ๋‹ฌ๊ณ  ์—…๋กœ๋“œ ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค~!

์ž˜๋ชป๋œ ๊ตฌํ˜„๋„ ํ”ผ๋“œ๋ฐฑ ํ•ด์ฃผ์„ธ์š” :) 

๋ณธ๋ฌธ์—์„œ ์‚ฌ์šฉ๋œ ์ฝ”๋“œ : https://github.com/fxnnxc/torch_autograd/blob/main/autograd_case_study.ipynb

๐Ÿ’› ์ข‹์•„์š” /๋Š” ๐Ÿ“ํฌ์ŠคํŒ…์— ๐Ÿ’ช๐Ÿผํž˜์ด ๋ฉ๋‹ˆ๋‹ค