본문 바로가기

딥러닝/머신러닝(ML)

Overfitting을 해결하는 방법 3가지

 

머신러닝은 학습 데이터로부터 배우고 테스트 데이터에 대해서 예측하는 것을 목표로 합니다. Overfitting은 학습 데이터의 성능은 높지만, 테스트 데이터에 대해서 성능이 나오지 않는 경우를 말합니다.

 

Overfitting을 해결하는 방법은 대표적으로 3가지가 있습니다.

 

 


  1. 데이터의 수를 늘린다.
  2. 모델의 Complexity를 줄인다.
  3. Regularization을 사용한다.

편의상 모델의 파라미터에 대한 복잡도를 M, 데이터에 대한 복잡도를 D라고 하겠습니다.

 

OverfittingM>D 인 경우에, 학습이 경과됨에 따라서 학습데이터를 모두 학습한 상태를 말합니다. (모델의 복잡도가 충분하므로 데이터를 모두 학습할 수 있습니다.)

 

문제는 완료된 학습이 데이터에 대한 패턴을 배우는 것이 아니라 데이터를 기억하는 수준이라는 것 입니다. 따라서 새로운 데이터가 등장한다면, 이 데이터의 패턴을 통해서 예측하는 방식이 아니고 기억 속에서 찾아야 하므로 잘 예측하지 못합니다. (보지 못했으니까요)

 

결국 데이터와 모델의 복잡도는 서로 연관되어 있습니다. Overfitting을 해결하는 방법은 이 두가지를 비슷한 크기로 만들어주는 것 입니다.

 

1. 데이터의 수를 늘린다.

첫 번째 해결 방법은 D 를 증가시키는 방법입니다. 동일한 모델에 대해서 다양한 데이터를 줌으로써 데이터의 복잡도를 높입니다. 이는 단순히 데이터의 양 관점에서 내린 결론이 아니며 모델 사이즈와 연관되어 있습니다.

 

2. 모델의 Complexity를 줄인다.

두 번째 해결 방법은 M을 줄이는 방법입니다. 동일한 데이터에 대해서 모델의 복잡도를 줄인다면 데이터로부터 일반적인 패턴을 학습하여 일반화가 됩니다. (이를 Generalization이 된다고 말합니다.)

 

랜덤으로 생성된 15개의 데이터에 대해서, Polynomial Curve Fitting을 해보겠습니다. Polynomial의 차수 [1, 3, 5, 9]에 대해서 그림과 같은 Polynomial Curve가 만들어집니다.

 


# 1, 3, 5, 9차 Polynomial의 계수
--------------
0.2292 x + 0.4642
--------------
-0.004657 x^3 + 0.1235 x^2 - 0.4795 x + 0.6676
--------------
0.0001364 x^5 - 0.007481 x^4 + 0.1426 x^3 - 1.122 x^2 + 3.715 x - 3.163
--------------
9.76e-08 x^9 - 1.029e-05 x^8 + 0.0004635 x^7 - 0.01156 x^6 + 0.1735 x^5
 - 1.594 x^4 + 8.746 x^3 - 26.74 x^2 + 39.75 x - 20.34


1차 다항식은 계수의 절대값이 작고 3차 Polynomial의 경우 계수의 절대값이 조금 커진 것을 확인할 수 있습니다. 9차 다항식은 1차의 계수가 39정도로 더 커진 것을 확인할 수 있습니다. 따라서 모델이 복잡하면 데이터에 Fitting하기 위해서 모델의 계수가 증가하게 됩니다. 이러한 현상은 다항식 뿐만아니라 Neural Network에서도 발생합니다.

 

모델이 복잡하면 데이터에 Fitting하기 위해서 모델의 계수가 증가하게 됩니다.

 

3. Regularization을 사용한다.

위 실험에서 알 수 있는 점은, 모델의 구조가 복잡할수록 계수가 점점 커진다는 것 입니다. 다항식의 계수가 크기 때문에 값에 민감하게 반응하게 되고, 바로 이 때문에 테스트 데이터에 대해서 모델은 높은 에러를 가지게 됩니다.

 

따라서 계수의 크기를 줄여준다면, 모델의 계수에 의한 영향력을 줄일 수 있습니다. Regularization은 모델의 계수를 0에 가깝게 만들어 계수의 절대값을 줄이는 방법입니다. 대표적으로 Ridge나 Lasso가 있는데, 목적이 계수의 절대값을 줄이는 것이라고 이해하면 충분할 것 같습니다.