Power method의 원리와 코드 구현
Pytorch로 구현한 Power method 는 다음과 같습니다.
import torch
import numpy as np
def power_method(weight, power_iteration=1):
# --- variables
dim_in, dim_out = weight.size(1), weight.size(0)
u = torch.rand(dim_out)
v = torch.rand(dim_in)
# --- iteration
for i in range(power_iteration):
u = torch.matmul(weight, v) / torch.matmul(weight, v).norm()
v = torch.matmul(weight.T, u) / torch.matmul(weight.T, u).norm()
return torch.matmul(torch.matmul(weight, v).T, u)
# --- variable
주어진 matrix $W$의 사이즈는 $\text{dim_out} \times \text{dim_in}$ 입니다. 알고리즘은 랜덤하게 두 개의 변수를 선언합니다.
- $u$ : 사이즈 $\text{dim_out}$
- $v$ 사이즈 $\text{dim_in}$
$u$ 와 $v$는 $W$의 앞 뒤에 곱해지는 벡터로, 계산하면 $W v$ 와 $u^\text{T} W = (W^\text{T} u)^{\text{T}}$ 을 얻게 됩니다.
# --- iteration
iteration 부분에서는 $u$ 와 $v$의 값을 업데이트 합니다. $u \leftarrow Wv$ 로 업데이트 하며, $u$의 크기를 1로 고정하기 위해서 $u \leftarrow Wv /||Wv||$로 업데이트 합니다. 마찬가지로 $v \leftarrow W^\text{T} u/ ||W^\text{T}u||$로 $v$의 값을 업데이트 합니다. 업데이트는 총 power iteration 만큼 진행하게 되는데, 보통 1이나 2처럼 한 두번 만 해도 충분합니다.
# --- return 값
$v$ 값은 $W$의 뒤에 곱하고, $u$ 값은 $W$의 앞에 곱한다고 했는데, $u^\text{T} (W v)$ scalar 값을 반환합니다. 이 값은 $W$ 메트릭스의 가장 큰 singular value 값과 동일합니다. (iteartion을 많이 돌수록 비슷해집니다.)
알고리즘 설명
Largest singular value를 구할 matrix $w$를 SVD를 한 형태로 적습니다.
그리고 $u^1, v^1$을 적당한 값으로 초기화합니다.
- $W = \sigma_1 u_1 v_1^\text{T} + \sigma_2 u_2 v_2^\text{T} + \cdots + \sigma_r u_r v_r^\text{T} $
- $W\in \mathbb{R}^{M\times N}$
- $u_i \in \mathbb{R}^M$
- $v_i \in \mathbb{R}^N$
- $u^1, v^1 : $ initialize randonly
- $u^k, v^k :$ 알고리즘에서 k 번째 업데이트로 얻은 벡터들
Iteration을 돌리면 무슨 일이 일어나는가요?
$u^2$값을 $v^1$ 값을 이용해서 구하는 방법은 $u^2 = Wv^1$으로 계산하는 것 입니다. 이 결과로 Singular Value Decomposition을 했을 때 생기는 $v_1, \cdots, v_r$ 과 dot-product가 이루어집니다. $v_i^\text{T} v^1$ 값과 $\sigma_i$는 곱해서 scalar 값이 나오며 $S_i = \sigma_i v_i v^1$ 라는 최종 Scalar 값을 얻게 됩니다.
가장 마지막 식에서, $u^2$ 값이 $u_1, u_2, \cdots, u_r$에 대한 linear-combination으로 표현되는 것을 확일할 수 있습니다. 또한 SVD의 결과로 $\sigma_1 \ge \sigma_2 \ge \cdots \ge \sigma_r$가 성립하게 되므로, $S_1 \ge S_2 \ge S_3 \ge \cdots \ge S_r$값이 대략적으로 성립하는 것을 알 수 있습니다. 그림으로 표현하면 다음과 같습니다.
따라서 한 번의 iteration결과로 $u^1$ 값은, $u_1$에 가까운 벡터로 근사하게 됩니다. 마찬가지로 $v^2$ 도 $u^1$ 을 $W$ 앞에 곱함으로써, scalar 값을을 얻고 $v^2$는 $v_1, \cdots, v_r$ 에 대한 linear combination으로 바뀌게 됩니다. 업데이트 하는 과정에서 SVD의 맨 앞에 있는 $\sigma_1 u_1 v_1^{\text{T}}$에 가까운 $u^k, v^k$로 업데이트가 진행되고, $u^k W v^k$ 를 곱하면, 알고리즘에서 return 값이었던 $\sigma_1$값이 나오게 됩니다.
Assert을 포함한 전체 코드
테스트를 위한 전체 코드는 다음과 같습니다. 중간에 `assert` 부분은 계산할 때, 결과의 차원이 예상대로 나오는지 확인하는 부분입니다. 굳이 필요하진 않지만 디버깅 할 때 편리해서 넣었습니다 :)
import torch
import numpy as np
def power_method(weight, power_iteration=1):
assert weight.ndim == 2
dim_in, dim_out = weight.size(1), weight.size(0)
u = torch.rand(dim_out)
v = torch.rand(dim_in)
for i in range(power_iteration):
u = torch.matmul(weight, v) / torch.matmul(weight, v).norm()
assert u.size(0) == dim_out
v = torch.matmul(weight.T, u) / torch.matmul(weight.T, u).norm()
assert v.size(0) == dim_in
return torch.matmul(torch.matmul(weight, v).T, u)
if __name__ == "__main__":
power_iteration = 2
dim_in, dim_out = 2, 10
weight = torch.rand(dim_out, dim_in)
largest_singular_value = power_method(weight, power_iteration=power_iteration)
numpy_result = np.linalg.svd(weight.numpy())[1][0]
print_string = f"| power_method: {largest_singular_value:.6f} " +\
f"| numpy svd: {numpy_result:.6f} " +\
f"| iteration:{power_iteration}| "
print(print_string)
댓글과 조언은 자유롭게 남기셔도 됩니다 :)
bumjin@kaist.ac.kr