โ๐ป EXP Vision Transformer๋ก CIFAR 10 ํ์ตํ๊ธฐ [Korean]
ViT ๊ฒฐ๋ก (TL;DR)
๐ MNIST ๋ ํ์ต์ด ์์ฃผ ์ฝ๋ค.
๐ CIFAR 10 ์ CrossEntropy๋ก Scratch ํ์ต์ ์ด๋ ต๋ค.
๐ Pretrain ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ฉด 1 epoch ๋ง์ ๋์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค.
์ด ์คํ์ ์งํํ๊ธฐ ์ ์ ๋ชจ๋ธ์ ๋ํด์ ํ ๊ฐ์ง ๋ฏฟ์์ด ์์๋ค.
ํ์ต๋ฐ์ดํฐ์ ๋ํด์ Loss๋ฅผ ์ค์ด๋ ๊ฒ์ Validation Loss๋ฅผ ์ด๋์ ๋ ์ค์ธ๋ค.
" Decreasing training loss ensures large portion of validation"
๊ทธ๋ฌ๋ ๊ทธ๋ ์ง ์์ ๋ชจ๋ธ์ด ์์์ ์๊ฒ ๋์๋ค.
โ๐ปPost Structure
1. ViT ์ค๋ช
2. MNIST ํ์ต
3. CIFAR 10 ํ์ต
4. Pretrained -> CIFAR 10 ํ์ต
๋จผ์ ์ด์ผ๊ธฐ๋ ViT๋ฅผ ํ์ต์ํค๋๋ฐ์ ์์ํ๋ค. ViT๋ Transformer ์ธ์ฝ๋ ๋ธ๋ก์ ์ฌ๋ฌ ๊ฐ ์์์ฌ๋ฆฐ ๊ตฌ์กฐ๋ก, ๊ฐ ๋ธ๋ก์ Multi-head Attention๊ณผ MLP ๊ทธ๋ฆฌ๊ณ Layer Normalization์ผ๋ก ์ด๋ฃจ์ด์ ธ์๋ค. ๋ชจ๋ธ์ input sequence์ ๋ํด์ ์ดํ ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ต์ ์งํํ๋ค. [๊ทธ๋ฆผ์ฐธ์กฐ]
์ด๋ฌํ ๊ตฌ์กฐ๋ก๋ถํฐ ์ค๋ ViT์ ๋ ๊ฐ์ง ํน์ง์ ๋ค์๊ณผ ๊ฐ๋ค.
- ๋ ์ด์ Translation Equivariance ํ์ง ์๋๋คโ. ์๋ก ๋ค๋ฅธ ํจ์น๋ค์ Positional Encoding์ผ๋ก ์์น ์ ๋ณด๊ฐ ์ถ๊ฐ๋์ด ์๋ค. (CNN์์๋ Weight๊ฐ ๊ณฑํด์ง ๋ ์์น์ ๋ํ ์ ๋ณด๊ฐ ์๋ค.)
- CNN์ Localํ Neighbor ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ์ตํ๋ Local Receptive Field ์ธ ๋ฐ๋ฉด, ViT๋ Attention ์ ์ฌ์ฉํ๋ Global Receptive Field ๋ฅผ ๊ฐ์ง๋ค.
๋ ผ๋ฌธ ์ฐธ์กฐ : Transformers in Vision: A Survey (https://arxiv.org/abs/2101.01169)
๊ทธ๋ฆผ ์ถ์ฒ : ๋ด๋ธ๋ก๊ทธ https://fxnnxc.github.io/blog/2022/exp_20/
CNN์ inductive bias๊ฐ ์ฌํด์ ํ์ต์ด ์ฝ์ง๋ง ์์ ๋๊ฐ ๋ฎ๋ค. ๋ฐ๋ฉด์ ViT๋ Attention์ ๋ชจ๋ ํจ์น์ ๋ํด์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ํ์ต์ ์์ ๋๊ฐ ๋๋ค. ์์ ๋๊ฐ ๋์ ๋งํผ, Classification์์๋ CNN์ ๋นํด์ ๋ ๋ง์ ์ํ์ด ํ์ํ๋ค. ์ด๋ฌํ ํน์ง์ ๊ธฐ๋ฐ์ผ๋ก ๋ ๊ฐ ๋ชจ๋ธ์ ๊ฒฐํฉํ ๋ชจ๋ธ์ ์ ์ํ๊ธฐ๋ ํ์๋๋ฐ, ์ฌ๊ธฐ์๋ ๋ ผ์ธ๋ก ํ๊ณ Pureํ ViT ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์๋ค.
๋ ผ๋ฌธ ์ฐธ์กฐ : Convolution Vision Transformer (https://arxiv.org/abs/2103.15808)
๐ Story 1: MNIST Training ํ์ต
์ผ๋จ ๊ธฐ๋ณธ์ ์ผ๋ก ViT์ ์ดํดํ๊ธฐ ์ํด์ MNIST ๋ก ํ์ตํด๋ดค๋ค. ๋ชจ๋ธ ๊ตฌ์กฐ๋ Pytorch ์์ ๊ตฌํํ vit_16 ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋ฐ๋๋ค. ๊ทธ๋ฅ ํ์ตํ๋ฉด ์ฌ์ฌํ๋, Layer ๊ฐ์๋ฅผ ๋ค์ํ๊ฒ ์ค์ ํด๋ดค๋ค.
์คํ ๊ฒฐ๊ณผ, Training Loss๋ ์ ์ค์ด๋ค์๊ณ , Validation Loss๋ ๋๋ถ๋ถ์ 3๊ฐ ์ด์์ ๋ ์ด์ด์์๋ 99% ์ด์์ ์ ์๐ฆธ๐ปโ๏ธ๋ฅผ ์ป์๋ค. ๊ทธ๋ฌ๋ ํ๋์ ๋ ์ด์ด๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ์ ๋๋ก ํ์ตํ์ง ๋ชปํ์๋ค. ํ์ต์ ์ฌ์ฉ๋ hyperparameter๋ค์ ๋ค์๊ณผ ๊ฐ๋ค.
image size : 56, patch_size: 4 , batch_size:32 learning rate : 1e-4
trainining for 200 epochs. StepLR with gamma 0.5 with every 100 epochs
hidden_dim : 128, dropout=0.5
(์คํ ๊ฒฐ๋ก ) ์ด ์คํ์ผ๋ก ViT๋ฅผ ํ์ตํ๋๋ฐ ์์ด์, ๋ ์ด์ด ์๋ฅผ ๋ง์ด ํ ํ์๊ฐ ์์ผ๋ฉฐ, 3๊ฐ ์ ๋ ์ด์์ ๋ชจ๋ธ์ด ๋ฐ์ดํฐ์ ๋ํด์ ์ ๋๋ก ํ์ตํ๋ค๋ ๊ฒ์ ์๊ฒ๋์๋ค. ๊ทธ๋ฆฌ๊ณ ์ด์ผ๊ธฐ๋ CIFAR 10 ์ ํ์ตํ๋๋ฐ๋ก ๋์ด๊ฐ๋ค.
์ฐธ๊ณ ๋ก ViT ๊ณ์ด์ 12๊ฐ ๋ ์ด์ด๋ฅผ ๊ฐ์ง๋ ๊ฒ์ด ๋ณดํต์ด๋ค.
๋ ํฐ ๋ชจ๋ธ์ ๋ธ๋ก๋ ๋ง๊ณ ์ฐจ์๋ ์ปค์ ํ๋ผ๋ฏธํฐ๊ฐ ํจ์ฌ ๋ง๋ค.
๋ ผ๋ฌธ ์ฐธ์กฐ : An Empirical Study of Training Self-Supervised Vision Transformers
๐ Story 2 : Cifar 10 ํ์ต ์คํจ ์คํ ๋ฆฌ
MNIST์ ๋์ผํ ๋ชจ๋ธ๊ตฌ์กฐ๋ก ์ด๋ฏธ์ง๋ฅผ ํ์ต์์ผฐ๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก ๋ค๋ฅธ ๋ ์ด์ด ์์ ๋ํด์ ๊ฒ์ฆํ์๊ณ ๊ฒฐ๊ณผ๋ ๋ณ๋ก ์ข์ง ๋ชปํ๋ค.
๋จผ์ Training Loss๋ฅผ ๋ณด์. ํค๋ ๊ฐ์๋ฅผ ๋๋ฆด์๋ก Loss๋ ๋น ๋ฅด๊ฒ ์ค์ด๋ ๋ค. ์ด๋ ๋ ์ด์ด๊ฐ ๋ง์์ง์ผ๋ก ์ธํด์ ํํ๊ณต๊ฐ์ด ๋ ๋ง์ ์ ๋ณด๋ฅผ ๋ด๊ธฐ ๋๋ฌธ์ผ๋ก ๋ณด์ธ๋ค. Loss ๊ฐ 0์ ์ ์ ๊ฐ๊น์์ง๋ ๊ฒ์ ํ์ธํ ์ ์๋ค. ๊ทธ๋ฌ๋ ํ์ต๋ ๋ชจ๋ธ๋ก Validation์ ์งํํ๋ ๊ฒฝ์ฐ, ์์ฃผ ์๋ง์ธ ๊ฒ์ ํ์ธํ ์ ์๋ค. ์ด์ ๋ํด์ ํ ๊ฐ์ง ๊ฐ์ ์ Attention์ ๋ํด์ ๋ฐฐ์ฐ๊ธฐ ์ ์, Training ๋ฐ์ดํฐ์ ๋ํด์ Cheating ํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์ ์ง๋ง, ๊ทธ๊ฒ Validation์๋ ํตํ์ง ์์๊ธฐ ๋๋ฌธ์ผ๋ก ๋ณด์ธ๋ค. Attention์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ด๋ฏธ์ง์ ์ค์ํ ๋ถ๋ถ๋ค์ ๋ํ ๊ฒฐ๋จ์ ๋ด๋ ค์ผ ํ๋๋ฐ, ํด๋น ํํํ์ต์ ์งํํ๊ธฐ ์ ์ ์ข๋ ์ฌ์ด ๋ฐฉ์์ ๋ฐฐ์ฐ๋ ๊ฒ์ด๋ค. ์ ํ๋๊ฐ 0.7 ์ ๋์ง๋ชปํ๊ณ , ๋ ์ด์ ์ฆ๊ฐํ์ง ์๋๋ค. ๋ฌผ๋ก Regularization์ด๋ Drop-out, Self-supervised learning ๊ณผ ๊ฐ์ ์ฌ๋ฌ๊ฐ์ง ํธ๋ฆญ์ ์ ์ฉํ ์ ์์ง๋ง, ์ฌ๊ธฐ์๋ ๋จ์ํ Supervised loss ๋ง ๊ณ ๋ คํ์๋ค.
์ด์ ๋ฐ์ดํฐ์ ๊ตฌ์กฐ๋ฅผ ๋์ฑ ์ ํ์ ํ๊ณ ์๋ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ : ๐ฆธ๐ปโ๏ธPRETRAIN ๋ชจ๋ธ๐ฆธ๐ปโ๏ธ๋ก ํ์ต์์ผ๋ณด์.
๋ํ ์ด์ ๋ ์ด์ด ๊ฐ์๋ฅผ Base model์ธ 12 โ ซ๊ฐ๋ก ๋ง์ถ์.
๐ Story 3 : Pretrained Model --> CIFAR 10
Pretrained Model ์ ๊ตฌํ ์ ์๋ ๊ณณ์ torchvision์ด๋ค.
๋ฌผ๋ก HuggingFace์์๋ ์ ๊ณตํ๋๋ฐ, torch๋ฅผ ์ฌ์ฉํ๋ ์ ์ ๋ก์ ๋ ์ฌ์ฉํ๊ธฐ ๊ฐ๋จํ ๊ฑด torchvision ์ด์๋ค.
- ๐ง TorchVision : https://pytorch.org/vision/stable/models/generated/torchvision.models.vit_b_16.html
- ๐ค HuggingFace : https://huggingface.co/aaraki/vit-base-patch16-224-in21k-finetuned-cifar10
torchvision models ์๋ ๊ธฐ๋ณธ์ ์ผ๋ก 3 ์ข ๋ฅ์ Vision Transformer Weights ๊ฐ ์กด์ฌํ๋ค.
- SWAG๋ก ImageNet 1K - Finetuning๐๏ธโ๏ธ
- SWAG๋ก ImageNet1K (Frozen) + ImageNet 1K Finetuning๐๏ธโ๏ธ
- ImageNet 1K Scratch ํ์ต ๐๏ธโ๏ธ
๋ ผ๋ฌธ ์ฐธ์กฐ : SWAG(Revisiting Weakly Supervised Pre-Training of Visual Perception Models)
https://arxiv.org/abs/2201.08371
์ธ ๊ฐ์ง ๋ชจ๋ธ ๋ชจ๋ ๋ฐ์ดํฐ ImageNet์ ๋ํด์ ๋์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค. ์ด ๋ชจ๋ธ๋ค์ ๊ฐ์ ธ์์ CIFAR 10 ์ ํ์ต์์ผฐ๋ค. ์ฐธ๊ณ ๋ก ์ธ ๋ชจ๋ธ์ ๊ฐ๊ฐ Resize, CropSize, Patch Size๊ฐ ๋ค๋ฅด๋ค. ๋ฐ๋ผ์ ๋ชจ๋ ๋ค๋ฅธ ๊ฐ์ ์ฌ์ฉํด์ค์ผ ํ๋ค.
Fine Tuning ์งํํ๊ธฐ ๐ฏ
๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด, ์คํฌ๋์น์ ๋นํด์ ์ด๋ฐ์ accuracy๊ฐ ํ ๋์์ง๋ ๊ฒ์ ๊ด์ฐฐํ ์ ์๋ค. ๋ํ Frozen ๋์๋ Linear ๋ชจ๋ธ์ ์ฌ์ค์ ImageNet-1K์ ๋ํด์ ํ๋์ํจ๊ฒ ์๋๋ฏ๋ก Classification์ ๋ค์ ํ์ตํ๋๋ฐ ์๊ฐ์ด ์์๋์๋ค. ํ์คํ ์ ์ Pretrain์ด ์คํฌ๋์น๋ณด๋ค ๋์ผ๋ฉฐ, Classification์ Bias ๋ ๋ชจ๋ธ์ ๋ ์ข๋ค๋ ๊ฒ์ด๋ค.
์์ ์คํ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก training loss๋ฅผ ์ค์ด๋ ๊ฒ์ ์ฌ์ฐ๋ Validation ์ ์ฑ๋ฅ์ ๋์ด๋ ๊ฒ์ ์ฝ์ง ์์๋ค. ์ถ๊ฐ์ ์ผ๋ก ํ๋์ ํ๋ฉด ํ์ต์ ๋์ฑ ์ ์ํฌ ์ ์์ผ๋, Validation์ overfitting๋ ๊ฑฐ ๊ฐ์์, ๊ฐ์ฅ ๊ฐ๋จํ Adam Optimizer + learning rate 0.0001 ์ ์ฌ์ฉํ์๋ค. ๋จ์ํ CrossEntropy Loss๋ก ์ฑ๋ฅ์ ๋์ฑ ๋์ด๋๋ฐ๋ ํ๊ณ๊ฐ ์๋ ๊ฒ ๊ฐ๋ค.
FINAL ๊ฒฐ๋ก ๐
์คํ์ ๊ธฐํํ๋ฉด์ ์ ์ผ ๊ถ๊ธํ๋ ๊ฒ์ "Transformer ๋ฅผ Supervised Learning์ผ๋ก ํ์ตํ๋ฉด ์ด๋ป๊ฒ ๋๋๊ฐ" ์๋ค. ์คํ์ ์ผ๋ก ๋ณด์ธ ๊ฒ์ ์คํฌ๋์น๋ก ํ์ตํ๋ ๊ฒ์ด ์์ข๋ค๋ ๊ฒ์ด๋ค. ๋ชจ๋ธ์ ๋ฐ์ดํฐ์ ๋ํด์ ์๋ฏธ์๋ ์ ๋ณด๋ฅผ ์ถ์ถํ๋ ๊ฒ์ด ์๋๋ผ, ๋จ์ํ ํ์ต๋ฐ์ดํฐ์ Overfitting ํ๋ ๋ฐฉ์์ผ๋ก ํ์ตํ๋ ๊ฒ์ด์๊ณ , ์ด๋ Self-Supervised Learning์ผ๋ก๋ถํฐ ๋ฐฐ์ด ๋ชจ๋ธ์ด ๊ฐ์ง๋ Global Receptive๋ฅผ ๊ฐ์ง ๋ชปํ๋ค๋ ๊ฒ์ ๋ํ๋ธ๋ค.
๋ชจ๋ธ์ ํ์ตํ ๋ ๋ณดํต Trainining Loss ๊ฐ ์ค์ด๋ค๋ฉด Validation๋ ์ค์ด๋ค ๊ฒ์ด๋ผ๊ณ ์์ํ์ง๋ง ViT ๋ชจ๋ธ์ ๊ทธ๋ฌํ ๊ธฐ๋๋ฅผ ๊ณผ๊ฐํ๊ฒ ๋ถ์๊ณ ๋ฐ์ดํฐ์ ๋ํด์ Cheating ํ๋ ๊ฒ์ ๋ณด์ฌ์คฌ๋ค. ๊ฒฐ๊ตญ Pretraining ์์ฒด๊ฐ ๊ต์ฅํ ์ค์ํ๋ฉฐ, ์ด๋ ํ๋ผ๋ฏธํฐ Space์ ๋ํด์ ์๋ฏธ์๋ ๊ณต๊ฐ์ด ๋ฐ๋ก ์กด์ฌํ๋ ๊ฒ์ ๋ํ๋ธ๋ค. ์ฆ Initialization / Supervised / Self-Supervised ์ ๋ํ ํ๋ผ๋ฏธํฐ ๊ณต๊ฐ์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์, ๊ฐ ๊ณต๊ฐ์์ ์์ํด์ Downstream Task์ ๋ํ ํ์ต์ ์งํํ๋ค๋ฉด ๋ชจ๋ธ์ ๊ทธ ๊ณต๊ฐ์ ๊ทผ์ฒ์์ Loss๋ฅผ ์ต์ํํ๋ ๊ฒ์ ์ฐพ๋๋ค. ์ด๋ฌํ ๊ด์ฐฐ์ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ค์ด ๋ฐ์ดํฐ๋ฅผ ์ต๋ํ ์๊ณ ์๋์ง ๊ฒ์ฆํ๋ ๊ณผ์ ๊ณผ ์ ์ฐจ๋ฅผ ํ์๋ก ํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
์์ฝ : ๐จ ์คํฌ๋์น < ๐จ Self-Supervised < ๐จ Classification Pretraining