코딩일기

Gradient Accumulation과 Gradient CheckPointing을 왜 쓸까?(feat. LLM) 본문

Code/딥러닝(NL)

Gradient Accumulation과 Gradient CheckPointing을 왜 쓸까?(feat. LLM)

daje 2023. 10. 16. 18:19
728x90
반응형

Gradient Accumulation과 GradientCheckPointing이 무엇을까요? LLM을 fine-tuning 할 때, 메모리 부족 문제는 많은 연구자와 개발자들이 직면하는 고질적인 문제 중 하나입니다. 한정된 리소스 자원으로 대규모 모델을 훈련시키는 동안 메모리 부족으로 학습이 실패하는 상황은 흔히 발생하는 문제입니다. 그러나 걱정하지 마세요! 오늘은 여러분이 겪고 있는 이러한 어려움을 극복하는 두 가지 중요한 기술을 소개하려고 합니다. Gradient Accumulation과 Gradient CheckPointing은 메모리 부족 문제를 해결하는 데 필수적인 도구입니다. 이러한 기술을 익히고 적용함으로써, LLM 학습 중 메모리 문제로 인한 실패를 더 이상 겪지 않을 수 있습니다. 이 글에서는 이러한 기술들을 자세히 살펴보고, LLM fine-tuning을 원활하게 진행하는 방법에 대해 알아보겠습니다.

Gradient Accumulation

Batch size를 작게 하여 LLM을 돌리게 되면, 모델의 성능이 좋지 않게 되는데, 이를 극복하기 위해 나온 기술로 1step을 돌고 바로 가중치를 업데이트 하는 것이 아니라 4, 8, 16, 32 등 사용자가 설정한 step만큼 기다렸다가 일괄적으로 가중치를 업데이트 하는 방식입니다.

Gradient Accumulation(그래디언트 누적)은 딥 러닝 모델의 학습 과정 중에 그래디언트(gradient)를 여러 미니배치(mini-batch)에 대해 누적하고, 이 누적된 그래디언트를 사용하여 가중치 업데이트를 수행하는 기술입니다. 이 기술은 주로 큰 모델을 학습하거나 메모리가 제한된 환경에서 모델 학습을 수행할 때 유용합니다. Gradient Accumulation을 사용하는 주요 이유는 다음과 같습니다:

  1. 메모리 관리: 대규모 딥 러닝 모델을 학습할 때, 하나의 미니배치에서 그래디언트를 계산할 때 발생하는 메모리 부담이 큽니다. Gradient Accumulation을 사용하면 미니배치 크기를 작게 유지하면서도 여러 미니배치의 그래디언트를 누적하므로 메모리 사용량을 줄일 수 있습니다
  2. 더 정확한 그래디언트 추정: Gradient Accumulation은 그래디언트를 여러 미니배치에 걸쳐 누적하기 때문에, 단일 미니배치에 대한 그래디언트 추정보다 더 정확한 추정을 제공할 수 있습니다. 이는 모델 학습의 안정성을 향상시킬 수 있습니다.
  3. Gradient Accumulation을 구현하려면 다음 단계를 따를 수 있습니다:
  • 미니배치 크기 선택: 미니배치 크기를 선택하고, 각 미니배치에서 그래디언트를 계산합니다.
  • 그래디언트 누적: 각 미니배치에서 계산한 그래디언트를 누적합니다. 이를 위해 미니배치마다 그래디언트를 기존 그래디언트에 더해주거나, 누적 변수에 더해주는 방식을 사용합니다.
  • 업데이트: 그래디언트를 누적한 후, 가중치 업데이트를 수행합니다. 보통은 역전파(backpropagation) 및 경사 하강법(gradient descent)을 사용하여 가중치를 조정합니다. Gradient Accumulation을 통해 메모리 효율성을 높일 수 있지만, 학습 시간은 조금 더 오래 걸릴 수 있습니다. 따라서 미니배치 크기와 누적 스텝 수를 적절하게 조정하여 모델 학습의 메모리와 계산 비용을 조절해야 합니다. 이를 통해 모델 학습을 원활하게 진행할 수 있습니다.

이렇게 Gradient Accumulation을 사용한다고 하더라도, 메모리가 터지는 현상이 발생될 수 있습니다.

이럴때 사용할 수 있는 것이 gradient checkpointing입니다.

Gradient Checkpointing

Gradient Accumulation을 통해서 step을 모아서 가중치를 업데이트 한다고 해도, 역전파 과정에서 메모리가 터져버리는 현상이 발생될 수 있습니다. 이런 상황을 막고자 역전파 도중 가중치를 저장하여 터지는 현상을 방지하는 기술입니다.

"Gradient Checkpointing"은 딥 러닝 모델에서 역전파(backpropagation) 과정 중에 발생하는 메모리 부담을 줄이기 위한 기술입니다. 딥 러닝 모델의 역전파는 모델의 파라미터에 대한 그래디언트(gradient)를 계산하고, 이를 사용하여 가중치를 업데이트하는 과정입니다. 대규모 모델이나 시퀀스 길이가 긴 모델의 경우, 그래디언트 계산에 많은 메모리가 필요할 수 있습니다.

Gradient Checkpointing은 이러한 문제를 해결하기 위해 중간 계산 결과를 저장하고 나중에 필요할 때 계산을 다시 수행하는 방식으로 동작합니다. 이를 통해 메모리 사용량을 크게 줄일 수 있으며, 대규모 모델의 학습을 가능하게 합니다.

Gradient Checkpointing은 PyTorch에서 "torch.utils.checkpoint.checkpoint" 함수를 사용하여 구현할 수 있으며, 이 함수는 역전파 과정에서 중간 계산 결과를 저장하고 관리합니다.

예를 들어, 다음과 같이 사용할 수 있습니다:

pythonCopy code
import torch
from torch.utils.checkpoint import checkpoint

# 모델 정의
model = YourDeepLearningModel()

# 입력 데이터
input_data = torch.randn(1, 3, 64, 64)

# Gradient Checkpointing을 사용하여 역전파 계산
output = checkpoint(model, input_data)

이렇게 하면 역전파 과정에서 중간 계산 결과를 저장하고 메모리를 효율적으로 관리할 수 있습니다. Gradient Checkpointing은 메모리 사용량을 줄이면서 모델 학습을 가능하게 하므로 대규모 모델을 학습할 때 유용합니다.

그러나, 실제로 코드를 나중에 사용해보시면 아시겠지만, GradientCheckPointing은 Deepspeed:zero3에 더 많이 사용하시는 것을 보실 수 있습니다. 분산컴퓨팅할 때 반드시 필요하거든요!

이 외 읽을 만한 글

728x90
반응형
Comments