AI Engineering Topic/AI 모델 배포

파이토치 2.0 torch.compile() 이 얼마나 빠른지 알아보자

Young_Metal 2023. 4. 25. 13:53

출처 : https://discuss.pytorch.kr/t/accelerating-large-language-models-with-accelerated-transformers/1417

 

파이토치 2를 사용한 가속화된 생성 확산 모델(Accelerated Generative Diffusion Models with PyTorch 2) 🎉

PyTorch 공식 블로그에 게시된 Accelerated Generative Diffusion Models with PyTorch 2 글을 퍼왔습니다. 🙂 아래는 원문과 함께 DeepL이 번역한 내용입니다 - Translated with DeepL Accelerated Generative Diffusion Models with PyT

discuss.pytorch.kr

 

Introduction

생성모델의 경우 생성 루프 안에서 돌아가는 코드의 최적화가 생성속도를 높이는 주요 요인이 된다. 

 

pytorch 2.0에서 compile과 빠른 attention 구현으로 생성속도를 높였다. 원래 xFormers를 실행하는 1.0버전의 토치와 2.0 버전의 토치를 비교해서 

 

Optimized Attention

 

Diffusion 모델에서 Attention을 U-Net의 여러 트랜스포머 블록으로 사용하고 이 U-Net은 모든 샘플링 단계에서 실행되므로 속도를 줄이는 주요한 포인트다. 구현이 된 nn.MultiHeadAttention 은 처음에는 아래와 같이 쓰이지만 후에는 

class CrossAttention(nn.Module):
    def __init__(self, ...):
        # Create matrices: Q, K, V, out_proj
        # 행렬을 생성합니다: Q, K, V, out_proj
        ...
    def forward(self, x, context=None, mask=None):
       # Compute out = SoftMax(Q*K/sqrt(d))V
       # Return out_proj(out)
       # out_proj(out) 반환
       …

아래와 같이 대체된다. 

class CrossAttention(nn.Module):
    def __init__(self, ...):
        self.mha = nn.MultiheadAttention(...)
    def forward(self, x, context):
	return self.mha(x, context, context)

model = torch.compile(model)

동적 컴파일러

기본 동작으로 PyTorch는 내부적으로 TorchDynamo를 사용하여 코드를 컴파일하고 TorchInductor 3를 사용하여 코드를 더욱 최적화합니다

파이토치가 컴파일 할 수 없는 부분, graph break를 피해야 한다. 예전과 달리 graph break에 컴파일이 중단되지 않는다. 성능은 저하될 수 있다. 컴파일러가 지원하지 않는 라이브러리에서 함수를 지운다.