pytorch 관련 compile 기능 정리....
on
1. 정적 그래프의 시대: TorchScript (torch.jit.script
, torch.jit.trace
)
PyTorch는 처음부터 “동적 그래프(Dynamic Graph)”를 강조했습니다. 코드가 실행될 때마다 연산 그래프가 실시간으로 생성되어 디버깅이 쉽고 유연하다는 장점이 있었죠. 하지만 프로덕션 환경에서 모델을 배포하거나, 파이썬 인터프리터의 오버헤드를 줄여 최고 성능을 내기 위해서는 정적인 그래프(Static Graph)가 필요했습니다.
이러한 필요성에서 PyTorch 1.0과 함께 TorchScript가 등장했습니다.
torch.jit.script
:- 탄생 배경: 파이썬 인터프리터의 오버헤드를 줄이고, 파이썬 환경 없이도 모델을 배포(예: C++ 애플리케이션)하기 위해 개발되었습니다.
if
,for
같은 파이썬의 제어 흐름까지 그래프에 포함할 수 있도록 코드를 명시적으로 정적 그래프로 스크립팅하는 방식이었습니다. - 장점: 제어 흐름을 포함한 모델 전체를 컴파일 가능하며, 파이썬 종속성 없이 배포 가능합니다.
- 단점: TorchScript가 이해할 수 있는 “TorchScript-friendly”한 코드로 작성해야 했기에, 기존 파이썬 코드를 수정해야 하는 경우가 많았고, 일부 파이썬 기능은 지원되지 않아 개발자에게 제약이 있었습니다.
- 탄생 배경: 파이썬 인터프리터의 오버헤드를 줄이고, 파이썬 환경 없이도 모델을 배포(예: C++ 애플리케이션)하기 위해 개발되었습니다.
torch.jit.trace
:- 탄생 배경:
script
의 코드 수정 부담을 줄이기 위해 등장했습니다. - 장점: 모델을 실제 입력으로 한 번 실행시켜서 그 실행 경로를 “추적(trace)”하여 그래프를 생성합니다. 기존 파이썬 코드를 거의 변경하지 않고도 적용하기 쉽습니다.
- 단점: 실행 시점의 특정 입력에 대한 경로만 추적하기 때문에, 입력 데이터에 따라 실행 경로가 달라지는 동적인 제어 흐름(예: 배치 크기에 따라 분기하는
if
문)이 있다면 추적된 그래프는 하나의 경로만 고정되어 있어 다른 입력에 대해 올바르게 작동하지 않을 수 있습니다.
- 탄생 배경:
2. 동적 그래프 컴파일의 혁신: torch.compile
시대
TorchScript는 중요한 역할을 했지만, 파이썬의 동적인 특성과 타협해야 하는 부분이 많아 개발자들이 아쉬움을 느끼는 지점이 있었습니다. 이에 PyTorch는 2.0에서 torch.compile
을 필두로 한 새로운 컴파일러 스택을 도입하며 혁신을 꾀했습니다. 목표는 하나였습니다: “파이썬의 유연성을 유지하면서도 컴파일의 성능 이점을 모두 누리게 하자!”
torch.compile
:- 탄생 배경:
torch.jit
의 한계(코드 수정 필요, 동적 제어 흐름 처리의 어려움)를 극복하고, PyTorch 모델에 더 쉽고 강력한 최적화를 적용하기 위해 개발되었습니다. - 장점:
@torch.compile
데코레이터나torch.compile(model)
한 줄만으로 적용 가능하며, 대부분의 경우 기존 파이썬 코드를 수정할 필요가 없습니다. 내부적으로torch.dynamo
기술을 사용하여 파이썬 코드를 분석, PyTorch 연산 그래프를 추출하고, 컴파일러가 처리할 수 없는 일반 파이썬 연산은 동적 실행(폴백)으로 넘겨 유연성을 극대화합니다. - 백엔드: 추출된 그래프는 다양한 컴파일러 백엔드(기본
inductor
백엔드, Triton, C++/CUDA 등을 활용)로 전달되어 최적화된 실행 가능한 코드를 생성합니다. 이는 성능 향상에 크게 기여합니다.
- 탄생 배경:
3. 중간 다리 역할: torch.fx
와 Torch-MLIR
torch.compile
이 파이썬 코드를 컴파일하는 상위 레벨의 인터페이스라면, 그 밑단에서는 모델의 중간 표현을 다루고 다양한 백엔드로 변환하는 중요한 도구들이 작동합니다.
torch.fx
(Functional eXchange):- 탄생 배경:
torch.jit
그래프는 C++ 기반이라 파이썬에서 직접 조작하기 어려웠습니다. PyTorch 모델의 연산 그래프를 파이썬에서 직접 다루고 변형할 수 있는 표준화된 방법을 제공하고자 했습니다. - 용도: PyTorch 모델의 연산 그래프를 파이썬에서 조작 가능한 심볼릭 그래프(FX Graph)로 표현합니다.
torch.dynamo
가 파이썬 코드를 캡처하면, 이 결과물이 바로torch.fx
그래프 형태로 생성됩니다.torch.fx
는 컴파일러 최적화를 위한 중요한 중간 표현(IR)이자, 사용자 정의 그래프 변형을 위한 강력한 도구입니다.
- 탄생 배경:
- Torch-MLIR:
- 탄생 배경: PyTorch 모델을 LLVM, TVM, IREE 등 다양한 하드웨어 백엔드에 배포하고 최적화하기 위해 MLIR(Multi-Level Intermediate Representation)의 확장 가능한 컴파일러 인프라를 활용하고자 했습니다. PyTorch의 복잡한 연산을 MLIR로 매핑하는 표준화된 방식이 필요했습니다.
- 용도:
torch.fx
그래프(또는torch.jit
그래프)와 같은 PyTorch 중간 표현을 MLIR로 변환하는 데 특화된 오픈소스 프로젝트입니다. MLIR 내에torch
다이얼렉트를 정의하고, 이를linalg
,tensor
,scf
등 MLIR의 하위 다이얼렉트로 점진적으로 로어링(lowering)하는 패스들을 제공합니다. - 위치:
torch.compile
의 플러그인 가능한 백엔드 중 하나로 사용될 수 있습니다. 즉,torch.compile
이torch.dynamo
를 통해torch.fx
그래프를 생성하면,Torch-MLIR
백엔드는 이를 받아서 MLIR로 변환하고 MLIR의 강력한 최적화 파이프라인을 거쳐 최종 코드를 생성합니다.
마치며
PyTorch의 컴파일러 스택은 계속 진화하고 있으며, torch.compile
은 그 중심에서 파이썬의 유연성과 컴파일된 코드의 성능을 동시에 제공하고자 합니다. torch.fx
는 이 과정에서 중요한 중간 표현 역할을 하고, Torch-MLIR
은 MLIR이라는 강력한 인프라를 활용하여 PyTorch 모델을 더 넓은 범위의 하드웨어에서 효율적으로 실행할 수 있도록 돕는 다리 역할을 합니다.
이러한 도구들을 이해하면 PyTorch 모델을 더 깊이 있게 최적화하고 원하는 배포 환경에 맞춰 활용할 수 있는 인사이트를 얻으실 수 있을 거예요.