Transformers documentation
GPU
GPU
GPU๋ ๋์ ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ๊ณผ ๋ณ๋ ฌ ์ฒ๋ฆฌ ๋ฅ๋ ฅ ๋๋ถ์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต์ ๋๋ฆฌ ์ฌ์ฉ๋ฉ๋๋ค. GPU ์ฌ์๊ณผ ๋ชจ๋ธ ํฌ๊ธฐ์ ๋ฐ๋ผ ์์ญ์ต ๊ฐ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ๋ชจ๋ธ๋ ํ์ตํ ์ ์์ต๋๋ค. ํต์ฌ์ GPU ๋ฉ๋ชจ๋ฆฌ ํ์ฉ๋(๋ฐ์ดํฐ ์ฒ๋ฆฌ๋/ํ์ต ์๊ฐ)์ ํ์ต ์๋ ์ฌ์ด์์ ์ต์ ์ ๊ท ํ์ ์ฐพ๋ ๊ฒ์ ๋๋ค.
์ด ๊ฐ์ด๋๋ Transformers์ PyTorch์์ GPU๋ฅผ ํ์ฉํด ๋ชจ๋ธ์ ํจ์จ์ ์ผ๋ก ํ์ตํ๊ธฐ ์ํด ์ ๊ณตํ๋ ๊ธฐ๋ฅ์ ์๊ฐํฉ๋๋ค. ๋๋ถ๋ถ์ ๊ฒฝ์ฐ, ์ด ๊ธฐ๋ฅ๋ค์ ์กฐํฉํด์ ํ์ต์ ์ต์ ํํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์๋ ํ๋ฅผ ์ฐธ๊ณ ํ๋ฉด ์์ ์ ํ์ต ์๋๋ฆฌ์ค์ ์ ํฉํ ๊ธฐ๋ฅ์ ๋น ๋ฅด๊ฒ ํ์ ํ ์ ์์ต๋๋ค.
๊ธฐ๋ฅ | ํ์ต ์๋ ๊ฐ์ | ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ์ฝ |
---|---|---|
๋ฐฐ์น ํฌ๊ธฐ | ์ | ์ |
๊ทธ๋ ์ด๋์ธํธ ๋์ | ์๋์ | ์ |
๊ทธ๋ ์ด๋์ธํธ ์ฒดํฌํฌ์ธํ | ์๋์ | ์ |
ํผํฉ ์ ๋ฐ๋ | ์ | ์กฐ๊ฑด๋ถ |
์ตํฐ๋ง์ด์ | ์ | ์ |
๋ฐ์ดํฐ ์ฌ์ ์ ์ฌ | ์ | ์๋์ |
torch_empty_cache_steps | ์๋์ | ์ |
torch.compile | ์ | ์๋์ |
์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ (SDPA) | ์ | ์ |
Trainer
Trainer๋ TrainingArguments๋ก ์ค์ ํ ์ ์๋ ๋ค์ํ ํ์ต ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค. ์ด๋ฒ ์น์ ์์๋ ํ์ต ์ต์ ํ์ ํนํ ์ ์ฉํ ์ฃผ์ ๊ธฐ๋ฅ ๋ช ๊ฐ์ง๋ฅผ ์ดํด๋ด ๋๋ค.
๋ฐฐ์น ํฌ๊ธฐ
๋ฐฐ์น ํฌ๊ธฐ๋ GPU ํ์ต ํจ์จ์ ์ข์ฐํ๋ ๊ฐ์ฅ ์ค์ํ ํ์ดํผํ๋ผ๋ฏธํฐ ์ค ํ๋๋ก, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ํ์ต ์๋์ ์ง์ ์ ์ธ ์ํฅ์ ์ค๋๋ค. ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ํฌ๊ฒ ํ๋ฉด GPU์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ๋ฅ๋ ฅ์ ๊ทน๋ํํ์ฌ ํ์ต ์๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก 8, 64, 128, 256, 512์ฒ๋ผ 2์ ๊ฑฐ๋ญ์ ๊ณฑ ๊ฐ์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ ์ ํ ๋ฐฐ์น ํฌ๊ธฐ๋ GPU ์ฌ์๊ณผ ๋ชจ๋ธ์ ๋ฐ์ดํฐ ํ์ ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋๋ค.
๋ฐฐ์น ํฌ๊ธฐ๋ TrainingArguments์ per_device_train_batch_size()
์ต์
์ผ๋ก ์ค์ ํฉ๋๋ค.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=256,
per_device_eval_batch_size=256,
)
์ฑ๋ฅ, ์ ๋ ฅ ํผ์ฒ ์์ ์ถ๋ ฅ ๋ด๋ฐ ์, ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์ฑ๋ฅ์ ๋ฏธ์น๋ ์ํฅ์ ๋ํด์๋ NVIDIA Performance ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ธ์. ์ด ๋งค๊ฐ๋ณ์๋ค์ GPU์์ ์คํ๋๋ General Matrix Multiplications(GEMMs)์ ์ฌ์ฉ๋ฉ๋๋ค. ๋งค๊ฐ๋ณ์๊ฐ ํด์๋ก ๋ณ๋ ฌํ์ ํจ์จ์ฑ์ด ํฅ์๋ฉ๋๋ค.
๋ฐ์ดํฐ ํ์ ๊ณผ GPU์ ๋ฐ๋ฅธ ์ต์ ์ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ์ ํํด ํ ์ ๊ณฑ์ ์๋๋ฅผ ๊ทน๋ํํ๋ ค๋ฉด, Tensor Core Requirements ์น์ ์ ์ฐธ๊ณ ํ๋ ๊ฒ์ด ์ ์ฉํฉ๋๋ค. ๊ทธ ์์๋ก, fp16์์๋ 8์ ๋ฐฐ์๊ฐ ๊ถ์ฅ๋์ง๋ง, A100 GPU์์๋ 64์ ๋ฐฐ์๊ฐ ๋ ์ ํฉํ๋ค๋ ์ฌ์ค์ ํ์ธํ ์ ์์ต๋๋ค.
๋ง์ง๋ง์ผ๋ก, ์์ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ ๋๋ Dimension Quantization Effects๋ฅผ ๊ณ ๋ คํ์ธ์. ํ๋ ฌ ์ฐจ์์ด GPU ์ค๋ ๋ ๋ธ๋ก์ ํ์ผ ํฌ๊ธฐ๋ก ๋๋์ด์ง์ง ์์ผ๋ฉด ํ์ผ ์์ํ๊ฐ ๋ฐ์ํ์ฌ GPU ์์์ ์ถฉ๋ถํ ํ์ฉํ์ง ๋ชปํฉ๋๋ค. ํ๋ ฌ์ด ํ์ผ ํฌ๊ธฐ๋ก ์ ํํ ๋๋๋๋ก ์ฌ๋ฐ๋ฅธ ๋ฐฐ์น ํฌ๊ธฐ ๋ฐฐ์๋ฅผ ์ ํํ๋ฉฐ ํ์ต ์๋๊ฐ ํฌ๊ฒ ํฅ์๋ฉ๋๋ค.
๊ทธ๋ ์ด๋์ธํธ ๋์
๊ทธ๋ ์ด๋์ธํธ ๋์ ์ ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ์ ๊ทน๋ณตํ๋ ๋ฐฉ๋ฒ์ผ๋ก, ๋จ์ผ GPU์ ๋ง์ง ์๋ ๋งค์ฐ ํฐ ๋ชจ๋ธ์ ํ์ตํ ๋ ์ ์ฉํฉ๋๋ค. ์ด๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฐ์ดํธํ๊ธฐ ์ ์ ์ฌ๋ฌ ๋ฏธ๋ ๋ฐฐ์น์ ๊ฑธ์ณ ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๋์ ํ๋ ๋ฐฉ์์ ๋๋ค. ๊ทธ ๊ฒฐ๊ณผ, ์ ์ฅํด์ผ ํ๋ ๊ทธ๋ ์ด๋์ธํธ ์๊ฐ ์ค์ด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ค์ด๋ค๊ณ , ์ผ๋ฐ์ ์ผ๋ก ํ๋์ ๋ฐฐ์น์์๋ง ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐฑ์ ํ๋ ๋ฐฉ์๋ณด๋ค ๋ ํฐ ์ ํจ ๋ฐฐ์น ํฌ๊ธฐ๋ก ํ์ตํ ์ ์์ต๋๋ค. ๋ค๋ง, ์ถ๊ฐ์ ์ธ ์์ ํ์ ์ญ์ ํ๊ฐ ํ์ํ๊ธฐ ๋๋ฌธ์ ํ์ต ์๋๊ฐ ๋๋ ค์ง ์ ์์ต๋๋ค.
๊ทธ๋ ์ด๋์ธํธ ๋์ ์ ํ์ฑํํ๋ ค๋ฉด TrainingArguments์์ TrainingArguments.per_device_train_batch_size()
์ต์
์ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
# ํจ์จ์ ์ธ ๋ฐฐ์น ํฌ๊ธฐ 64
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
)
ํ์ต ์๋๊ฐ ๋๋ ค์ง ์ ์๊ธฐ ๋๋ฌธ์ ๊ทธ๋ ์ด๋์ธํธ ๋์ ๋จ๊ณ๋ฅผ ๋๋ฌด ํฌ๊ฒ ์ค์ ํ์ง ์๋ ๊ฒ์ด ์ข์ต๋๋ค. ์๋ ์์๋ฅผ ์ฐธ๊ณ ํ์ธ์, GPU์ ๋ด์ ์ ์๋ ์ต๋ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ 4๋ผ๋ฉด GPU์ ํจ์จ์ ์ธ ์ฌ์ฉ์ ์ํด ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ 4๋ก ์ ์งํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋ฐฐ์น ํฌ๊ธฐ | ๊ทธ๋ ์ด๋์ธํธ ๋์ ๋จ๊ณ | ํจ์จ์ ์ธ ๋ฐฐ์น ํฌ๊ธฐ | |
---|---|---|---|
1 | 64 | 64 | ๐ |
4 | 16 | 64 | ๐ |
๊ทธ๋ ์ด๋์ธํธ ์ฒดํฌํฌ์ธํ
๊ทธ๋ ์ด๋์ธํธ ์ฒดํฌํฌ์ธํ ์ ์ญ์ ํ ๊ณผ์ ์์ ์ผ๋ถ ์ค๊ฐ ํ์ฑํ ๊ฐ๋ง ์ ์ฅํ๊ณ ๋๋จธ์ง๋ ๋ค์ ๊ณ์ฐํด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ ๋๋ค. ์ด๋ฅผ ํตํด ์์ ํ ๊ณผ์ ์์ ๋ชจ๋ ์ค๊ฐ ํ์ฑํ ๊ฐ์ ์ ์ฅํ์ง ์์๋ ๋์ด ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋๋ฅผ ํฌ๊ฒ ์ค์ผ ์ ์์ต๋๋ค. ๋ค๋ง, ํ์ต ์๋๊ฐ ์ฝ 20% ๋๋ ค์ง๋ ํ๊ณ๊ฐ ์์ต๋๋ค.
๊ทธ๋ ์ด๋์ธํธ ๋์ ์ ํ์ฑํํ๋ ค๋ฉด TrainingArguments์์ gradient_checkpointing()
์ต์
์ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
)
ํผํฉ ์ ๋ฐ๋
ํผํฉ ์ ๋ฐ๋๋ ์ผ๋ถ ๊ณ์ฐ์ ๋ฐ์ ๋ฐ๋(fp16)๋ก, ๋๋จธ์ง๋ฅผ ์ ์ ๋ฐ๋(fp32)๋ก ์ํํด ํ์ต ์๋๋ฅผ ๋์ด๋ ๊ธฐ๋ฒ์ ๋๋ค. ๋ฐ์ ๋ฐ๋ ๊ณ์ฐ์ ์ ์ ๋ฐ๋๋ณด๋ค ๊ณ์ฐ๋์ด ์ ์ด ๋ ๋น ๋ฅด๊ฒ ์ํ๋ฉ๋๋ค. ํํธ, ์ ์ ๋ฐ๋๋ก ์ผ๋ถ ๊ณ์ฐ์ ์ํํ๋ฉด ์ ํ๋๋ฅผ ์ ์งํ ์ ์์ต๋๋ค.
ํผํฉ ์ ๋ฐ๋ ํ์ต์ ์ํด ์ฌ๋ฌ ์๋ฃํ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
ํผํฉ ์ ๋ฐ๋ ํ์ต์ ์ฃผ์ ์ฅ์ ์ ํ์ฑํ ๊ฐ์ fp16์ผ๋ก ์ ์ฅํ ์ ์๋ค๋ ๊ฒ์ ๋๋ค.
fp16 ์๋ฃํ์ผ๋ก ํผํฉ ์ ๋ฐ๋ ํ์ต์ ํ์ฑํํ๋ ค๋ฉด TrainingArguments์์ fp16()
์ต์
์ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
fp16=True.
)
fp16์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ์ต์ ํ๋ ๋ฐฉ์์ด ์๋๋๋ค. ์ด๋ fp16์ผ๋ก ๊ณ์ฐ๋ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ์ต์ ํ ๋จ๊ณ์์ fp32๋ก ๋ค์ ๋ณํ๋๊ธฐ ๋๋ฌธ์ ๋๋ค. ํนํ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์์ ๋๋, GPU์ ๋ ๊ฐ์ง ์๋ฃํ(fp16, fp32)์ด ์ ์ฌ๋์ด ์๊ธฐ ๋๋ฌธ์ ๋ ๋ง์ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค.
์ตํฐ๋ง์ด์
Transformers๋ ๊ธฐ๋ณธ์ ์ผ๋ก PyTorch์ AdamW (adamw_torch) ์ตํฐ๋ง์ด์ ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ํ์ง๋ง, ์ด ์ตํฐ๋ง์ด์ ๋ ๊ณผ๊ฑฐ ๊ทธ๋ ์ด๋์ธํธ์ ๊ฐ์ค ํ๊ท ์ ์ ์ฅํ๊ธฐ ๋๋ฌธ์, ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ์ ์ฅํ๊ธฐ ์ํด ๋ชจ๋ธ ๋งค๊ฐ๋ณ์ ์์ ๋น๋กํ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์ด๋ ๋งค์ฐ ํฐ ๋ชจ๋ธ์ ํ์ตํ ๋ ๋ฌธ์ ๊ฐ ๋ ์ ์์ผ๋ฉฐ, ์ด๋ฌ๋ฉด ๋ค๋ฅธ ์ตํฐ๋ง์ด์ ๋ฅผ ์ ํํ๋ ๊ฒ์ ๊ณ ๋ คํด์ผ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, NVIDIA ๋๋ AMD์ Apex๊ฐ ์ค์น๋์ด ์๋ค๋ฉด, ๋ชจ๋ AdamW ์ตํฐ๋ง์ด์ ์ค adamw_apex_fused
์ตํฐ๋ง์ด์ ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ฐ์ฅ ๋น ๋ฅธ ํ์ต ์๋๋ฅผ ์ป์ ์ ์์ต๋๋ค.
์ตํฐ๋ง์ด์ ๋ฅผ ์ ํํ๊ธฐ ์ํด์๋ TrainingArguments์์ optim()
์ต์
์ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
bf16=True,
optim="adamw_bnb_8bit"
)
ํ์ต ์๋๋ฆฌ์ค์ ๋ง๊ฒ ์ ํํ ์ ์๋ ๋ค์ํ ์ตํฐ๋ง์ด์ ๊ฐ ์์ต๋๋ค. (์ ์ฒด ์ง์ ๋ชฉ๋ก์ OptimizerNames๋ฅผ ์ฐธ๊ณ ํ์ธ์) ์๋ฅผ ๋ค์ด Adafactor๋ ํ๋ ฌ์ ๊ฐ ์์ ๋์ ํ ๋๋ ์ด ๋จ์์ ๊ฐ์ค ํ๊ท ๋ง ์ ์ฅํด ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ ํฌ๊ฒ ์ค์ผ ์ ์์ง๋ง, ์๋ ด ์๋๋ ๋๋ ค์ง ์ ์์ต๋๋ค. ๋ ๋ค๋ฅธ ์๋ก, bitandbytes์ 8-bit AdamW optimizer๋ฅผ ์ฌ์ฉํ๋ฉด ์ตํฐ๋ง์ด์ ์ ์ํ๋ฅผ 8๋นํธ๋ก ์์ํํ ์ ์์ต๋๋ค. ์ตํฐ๋ง์ด์ ์ํ๋ ๋ฎ์ ์ ๋ฐ๋๋ก ์ ์ฅ๋์๋ค๊ฐ ์ตํฐ๋ง์ด์ ๋จ๊ณ์์ ์ฌ์ฉ๋๊ธฐ ์ ์ ์ญ ์์ํ๋ฉ๋๋ค.
ํนํ๋ ์ตํฐ๋ง์ด์ ์ ๋ํด ๋ ์๊ณ ์ถ๋ค๋ฉด optimizer ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ธ์.
๋ฐ์ดํฐ ์ฌ์ ์ ์ฌ
๋ฐ์ดํฐ ์ฌ์ ์ ์ฌ(Data preloading)๋ GPU๊ฐ ์ง์์ ์ผ๋ก ์์ ํ ์ ์๋๋ก CPU์์ ๋ฏธ๋ฆฌ ๋ฐฐ์น ๋จ์์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฌํ๊ณ ์ค๋นํ๋ ๊ธฐ๋ฅ์ ๋๋ค. ์ด๋ฅผ ํตํด GPU ์ ํด ์๊ฐ์ ์ค์ด๊ณ ํ์ฉ๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค. GPU๊ฐ ํญ์ ์์ ์ ๊ณ์ํ๋๋ก ํ๋ ค๋ฉด ๋ค์ ๋ฐ์ดํฐ ์ฌ์ ์ ์ฌ๋ฅผ ์ํ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
- ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ ๊ณ ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ CPU์ ํ ๋นํ ๋ค, ์ด๋ฅผ GPU๋ก ์ง์ ์ ์กํฉ๋๋ค.
- CPU ์ค๋ ๋ ๋ฐ ์์ปค ์๋ฅผ ๋๋ ค ๋ฐ์ดํฐ๋ฅผ ๋ ๋น ๋ฅด๊ฒ ์ฌ์ ์ ์ฌํฉ๋๋ค.
๊ณ ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํ๊ณ ์์ปค ์๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด์๋ TrainingArguments์์ dataloader_pin_memory()
์ dataloader_num_workers()
์ต์
์ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
bf16=True,
optim="adamw_bnb_8bit",
dataloader_pin_memory=True,
dataloader_num_workers=4,
)
PyTorch
PyTorch๋ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ์ฌํญ์ ์ค์ด๊ณ ํ์ต ์๋๋ฅผ ๋์ด๊ธฐ ์ํ ์ฌ๋ฌ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค. ์ด๋ฌํ ๊ธฐ๋ฅ๋ค์ Transformers์์ ๋ช ์ค์ ์ฝ๋๋ง ์ถ๊ฐํ์ฌ ํ์ฑํํ ์ ์์ต๋๋ค.
torch.empty_cache_steps
torch.cuda.empty_cache ํจ์๋ ์ฌ์ฉํ์ง ์๋ ์บ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํ์ฌ ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ(OOM) ์ค๋ฅ๋ฅผ ๋ฐฉ์งํ ์ ์์ง๋ง, ํ์ต ์๋๊ฐ ์ฝ 10% ๋๋ ค์ง ์ ์์ต๋๋ค.
ํน์ ํ์ต ๋จ๊ณ ์ดํ์ ์ด ๊ธฐ๋ฅ์ ํ์ฑํํ๊ณ ์ถ๋ค๋ฉด, TrainingArguments์์ torch_empty_cache_steps()๋ฅผ ์ค์ ํ์ธ์.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
bf16=True,
optim="adamw_bnb_8bit",
dataloader_pin_memory=True,
dataloader_num_workers=4,
torch_empty_cache_steps=4,
)
torch.compile
torch.compile์ PyTorch ์ฝ๋๋ฅผ ์ต์ ํ๋ ์ปค๋๋ก ์ปดํ์ผํด ํ์ต ์๋๋ฅผ ํฌ๊ฒ ๋์ฌ์ค๋๋ค. ์ด ๊ธฐ๋ฅ์ TorchDynamo๋ฅผ ์ฌ์ฉํด ํ๋ ์ ํ๊ฐ API๋ก๋ถํฐ PyTorch ๊ทธ๋ํ๋ฅผ ์บก์ฒํ๋ฉฐ, ์ด๋ ๊ฒ ์บก์ฒํ ๊ทธ๋ํ๋ ๋ค์ํ ๋ฐฑ์๋์ ์ถ๊ฐ๋ก ์ต์ ํ๋ ์ปค๋๋ก ์ปดํ์ผ๋ ์ ์์ต๋๋ค.
์ด๋ฅผ ํ์ฑํํ๋ ค๋ฉด TrainingArguments์์ torch_compile()
๋ฅผ ์ค์ ํ์ธ์. ๋ฐฑ์๋๋ torch_compile_backend()๋ฅผ ํตํด ์ ํํ ์ ์์ต๋๋ค.
from transformers import TrainingArguments
args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
bf16=True,
optim="adamw_bnb_8bit",
dataloader_pin_memory=True,
dataloader_num_workers=4,
torch_empty_cache_steps=4,
torch_compile=True,
torch_compile_backend="inductor"
)
์๋ ํ๋ฅผ ์ฐธ๊ณ ํ์ฌ ํ์ต ์๋๋ฆฌ์ค์ ์ ํฉํ ๋ฐฑ์๋๋ฅผ ์ ํํ์ธ์.
๋ฐฑ์๋ | ์ค๋ช | ๋ชฉํ |
---|---|---|
eager | PyTorch๋ฅผ ์ฌ์ฉํด ์ถ์ถ๋ GraphModule์ ์คํํฉ๋๋ค | ๋๋ฒ๊น |
aot_eager | AOTAutograd๋ก ์ถ์ถ๋ ์์ ํ ๋ฐ ์ญ์ ํ ๊ทธ๋ํ๋ฅผ Pytorch eager ๋ชจ๋๋ก ์คํํฉ๋๋ค | ๋๋ฒ๊น |
inductor | Triton ์ปค๋์ ํ์ฉํ๋ TorchInductor์ AOTAutograd, CUDA Graphs๋ฅผ ์ฌ์ฉํฉ๋๋ค | ํ์ต ๋ฐ ์ถ๋ก |
nvfuser | TorchScript์ ํจ๊ป nvFuser๋ฅผ ์ฌ์ฉํฉ๋๋ค | ํ์ต ๋ฐ ์ถ๋ก |
aot_nvfuser | AOTAutograd์ ํจ๊ป nvFuser๋ฅผ ์ฌ์ฉํฉ๋๋ค | ํ์ต ๋ฐ ์ถ๋ก |
aot_cudagraphs | AOTAutograd์ ํจ๊ป CUDA Graphs๋ฅผ ์ฌ์ฉํฉ๋๋ค | ํ์ต ๋ฐ ์ถ๋ก |
ofi | TorchScripts์ optimize_for_inference๋ฅผ ์ฌ์ฉํฉ๋๋ค | ์ถ๋ก |
fx2trt | Torch-TensorRT๋ฅผ ์ฌ์ฉํฉ๋๋ค | ์ถ๋ก |
onnxrt | CPU ๋ฐ GPU ์ถ๋ก ์ ์ํด ONNX-RT๋ฅผ ์ฌ์ฉํฉ๋๋ค | ์ถ๋ก |
ipex | CPU ์ถ๋ก ์ ์ํด IPEX๋ฅผ ์ฌ์ฉํฉ๋๋ค | ์ถ๋ก |
์ค์ผ์ผ๋ ๋ด์ ์ดํ ์
torch.nn.functional.scaled_dot_product_attention (SDPA)๋ ์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ ๋ฉ์ปค๋์ฆ์ PyTorch์ ๋ด์ฅํด ๊ตฌํํ ํจ์์ ๋๋ค. SDPA๋ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ์ ๊ธฐ์กด ์ดํ ์ ๋ฉ์ปค๋์ฆ๋ณด๋ค ๋ ํจ์จ์ ์ด๊ณ ์ต์ ํ๋์ด ์์ต๋๋ค. ์ธ ๊ฐ์ง ์ ํ์ ์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ ์ ์ง์ํฉ๋๋ค.
- FlashAttention2๋ fp16 ๋๋ bf16 torch ํ์ ๋ชจ๋ธ์์ ์๋์ผ๋ก ํ์ฑํ๋ฉ๋๋ค. ๋จผ์ ๋ชจ๋ธ์ ์ ์ ํ ํ์ ์ผ๋ก ์บ์คํ ํ๋์ง ํ์ธํ์ธ์.
- xFormers ๋๋ Memory-Efficient Attention์ fp32 torch ํ์ ๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
- C++๋ก ๊ตฌํ๋ ์ค์ผ์ผ๋ ๋ด์ ์ดํ ์ ์ ๋๋ค.
SDPA๋ PyTorch 2.1.1 ๋ฒ์ ์ด์์์ ๊ธฐ๋ณธ์ ์ผ๋ก ํ์ฑํ๋์ด ์์ง๋ง, from_pretrained()์์ attn_implementation="sdpa"
๋ฅผ ์ค์ ํด ๋ช
์์ ์ผ๋ก ํ์ฑํํ ์ ์์ต๋๋ค.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa")