
GPU

GPU๋Š” ๋†’์€ ๋ฉ”๋ชจ๋ฆฌ ๋Œ€์—ญํญ๊ณผ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ๋Šฅ๋ ฅ ๋•๋ถ„์— ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ํ•™์Šต์— ๋„๋ฆฌ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. GPU ์‚ฌ์–‘๊ณผ ๋ชจ๋ธ ํฌ๊ธฐ์— ๋”ฐ๋ผ ์ˆ˜์‹ญ์–ต ๊ฐœ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐ€์ง„ ๋ชจ๋ธ๋„ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•ต์‹ฌ์€ GPU ๋ฉ”๋ชจ๋ฆฌ ํ™œ์šฉ๋„(๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๋Ÿ‰/ํ•™์Šต ์‹œ๊ฐ„)์™€ ํ•™์Šต ์†๋„ ์‚ฌ์ด์—์„œ ์ตœ์ ์˜ ๊ท ํ˜•์„ ์ฐพ๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ๋Š” Transformers์™€ PyTorch์—์„œ GPU๋ฅผ ํ™œ์šฉํ•ด ๋ชจ๋ธ์„ ํšจ์œจ์ ์œผ๋กœ ํ•™์Šตํ•˜๊ธฐ ์œ„ํ•ด ์ œ๊ณตํ•˜๋Š” ๊ธฐ๋Šฅ์„ ์†Œ๊ฐœํ•ฉ๋‹ˆ๋‹ค. ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ, ์ด ๊ธฐ๋Šฅ๋“ค์„ ์กฐํ•ฉํ•ด์„œ ํ•™์Šต์„ ์ตœ์ ํ™”ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.

Refer to the table below to quickly identify the features relevant to your training scenario.

| Feature | Speeds up training | Saves memory |
|---------|--------------------|--------------|
| batch size | yes | yes |
| gradient accumulation | no | yes |
| gradient checkpointing | no | yes |
| mixed precision | yes | depends |
| optimizers | yes | yes |
| data preloading | yes | no |
| torch_empty_cache_steps | no | yes |
| torch.compile | yes | no |
| scaled dot product attention (SDPA) | yes | yes |

Trainer

Trainer supports many useful training features that can be configured through TrainingArguments. This section highlights some of the features that are especially useful for optimizing training.

๋ฐฐ์น˜ ํฌ๊ธฐ

๋ฐฐ์น˜ ํฌ๊ธฐ๋Š” GPU ํ•™์Šต ํšจ์œจ์„ ์ขŒ์šฐํ•˜๋Š” ๊ฐ€์žฅ ์ค‘์š”ํ•œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์ค‘ ํ•˜๋‚˜๋กœ, ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰๊ณผ ํ•™์Šต ์†๋„์— ์ง์ ‘์ ์ธ ์˜ํ–ฅ์„ ์ค๋‹ˆ๋‹ค. ๋ฐฐ์น˜ ํฌ๊ธฐ๋ฅผ ํฌ๊ฒŒ ํ•˜๋ฉด GPU์˜ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ๋Šฅ๋ ฅ์„ ๊ทน๋Œ€ํ™”ํ•˜์—ฌ ํ•™์Šต ์†๋„๋ฅผ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ผ๋ฐ˜์ ์œผ๋กœ 8, 64, 128, 256, 512์ฒ˜๋Ÿผ 2์˜ ๊ฑฐ๋“ญ์ œ๊ณฑ ๊ฐ’์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์ ์ ˆํ•œ ๋ฐฐ์น˜ ํฌ๊ธฐ๋Š” GPU ์‚ฌ์–‘๊ณผ ๋ชจ๋ธ์˜ ๋ฐ์ดํ„ฐ ํƒ€์ž…์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.

๋ฐฐ์น˜ ํฌ๊ธฐ๋Š” TrainingArguments์˜ per_device_train_batch_size() ์˜ต์…˜์œผ๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
)

Refer to the NVIDIA Performance guide to learn more about how the number of input features, the number of output neurons, and the batch size affect performance. These parameters define the General Matrix Multiplications (GEMMs) executed on the GPU, and larger parameters allow for better parallelization and efficiency.

๋ฐ์ดํ„ฐ ํƒ€์ž…๊ณผ GPU์— ๋”ฐ๋ฅธ ์ตœ์ ์˜ ๋ฐฐ์น˜ ํฌ๊ธฐ๋ฅผ ์„ ํƒํ•ด ํ…์„œ ๊ณฑ์…ˆ ์†๋„๋ฅผ ๊ทน๋Œ€ํ™”ํ•˜๋ ค๋ฉด, Tensor Core Requirements ์„น์…˜์„ ์ฐธ๊ณ ํ•˜๋Š” ๊ฒƒ์ด ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค. ๊ทธ ์˜ˆ์‹œ๋กœ, fp16์—์„œ๋Š” 8์˜ ๋ฐฐ์ˆ˜๊ฐ€ ๊ถŒ์žฅ๋˜์ง€๋งŒ, A100 GPU์—์„œ๋Š” 64์˜ ๋ฐฐ์ˆ˜๊ฐ€ ๋” ์ ํ•ฉํ•˜๋‹ค๋Š” ์‚ฌ์‹ค์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Finally, consider Dimension Quantization Effects when working with smaller parameters. Tile quantization occurs when matrix dimensions aren't divisible by a GPU thread block's tile size, leaving GPU resources underutilized. Choosing a batch size multiplier such that the matrices divide evenly by the tile size can significantly speed up training.
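
As a rough illustration, here is a small helper (hypothetical, not part of Transformers) that rounds a candidate batch size down to the nearest recommended multiple:

def round_to_multiple(batch_size: int, multiple: int = 8) -> int:
    """Round a candidate batch size down to the nearest supported multiple.

    Multiples of 8 suit fp16 Tensor Cores on most GPUs; use 64 on an A100.
    """
    return max(multiple, (batch_size // multiple) * multiple)

print(round_to_multiple(100))      # 96, a multiple of 8
print(round_to_multiple(100, 64))  # 64, a multiple of 64 for an A100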

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์ 

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์ ์€ ๋ฉ”๋ชจ๋ฆฌ ์ œ์•ฝ์„ ๊ทน๋ณตํ•˜๋Š” ๋ฐฉ๋ฒ•์œผ๋กœ, ๋‹จ์ผ GPU์— ๋งž์ง€ ์•Š๋Š” ๋งค์šฐ ํฐ ๋ชจ๋ธ์„ ํ•™์Šตํ•  ๋•Œ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์—…๋ฐ์ดํŠธํ•˜๊ธฐ ์ „์— ์—ฌ๋Ÿฌ ๋ฏธ๋‹ˆ ๋ฐฐ์น˜์— ๊ฑธ์ณ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ ๋ˆ„์ ํ•˜๋Š” ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค. ๊ทธ ๊ฒฐ๊ณผ, ์ €์žฅํ•ด์•ผ ํ•˜๋Š” ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ์ˆ˜๊ฐ€ ์ค„์–ด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์ด ์ค„์–ด๋“ค๊ณ , ์ผ๋ฐ˜์ ์œผ๋กœ ํ•˜๋‚˜์˜ ๋ฐฐ์น˜์—์„œ๋งŒ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฐฑ์‹ ํ•˜๋Š” ๋ฐฉ์‹๋ณด๋‹ค ๋” ํฐ ์œ ํšจ ๋ฐฐ์น˜ ํฌ๊ธฐ๋กœ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋งŒ, ์ถ”๊ฐ€์ ์ธ ์ˆœ์ „ํŒŒ์™€ ์—ญ์ „ํŒŒ๊ฐ€ ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•™์Šต ์†๋„๊ฐ€ ๋А๋ ค์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์ ์„ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด TrainingArguments์—์„œ TrainingArguments.per_device_train_batch_size() ์˜ต์…˜์„ ์„ค์ •ํ•˜์„ธ์š”.

from transformers import TrainingArguments

# effective batch size of 64
args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
)

ํ•™์Šต ์†๋„๊ฐ€ ๋А๋ ค์งˆ ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์  ๋‹จ๊ณ„๋ฅผ ๋„ˆ๋ฌด ํฌ๊ฒŒ ์„ค์ •ํ•˜์ง€ ์•Š๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์•„๋ž˜ ์˜ˆ์‹œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”, GPU์— ๋‹ด์„ ์ˆ˜ ์žˆ๋Š” ์ตœ๋Œ€ ๋ฐฐ์น˜ ํฌ๊ธฐ๊ฐ€ 4๋ผ๋ฉด GPU์˜ ํšจ์œจ์ ์ธ ์‚ฌ์šฉ์„ ์œ„ํ•ด ๋ฐฐ์น˜ ํฌ๊ธฐ๋ฅผ 4๋กœ ์œ ์ง€ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.

๋ฐฐ์น˜ ํฌ๊ธฐ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์  ๋‹จ๊ณ„ ํšจ์œจ์ ์ธ ๋ฐฐ์น˜ ํฌ๊ธฐ
1 64 64 ๐Ÿ‘Ž
4 16 64 ๐Ÿ‘
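
For intuition, here is a minimal sketch of what gradient accumulation amounts to in plain PyTorch. The toy model and data are illustrative; Trainer handles all of this for you.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model to keep the sketch self-contained
dataset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
dataloader = DataLoader(dataset, batch_size=4)
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accumulation_steps = 16  # effective batch size = 4 * 16 = 64

for step, (inputs, labels) in enumerate(dataloader):
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    # Scale the loss so the accumulated gradient matches one large batch of 64
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one parameter update per 16 mini-batches
        optimizer.zero_grad()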

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ์ฒดํฌํฌ์ธํŒ…

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ์ฒดํฌํฌ์ธํŒ…์€ ์—ญ์ „ํŒŒ ๊ณผ์ •์—์„œ ์ผ๋ถ€ ์ค‘๊ฐ„ ํ™œ์„ฑํ™” ๊ฐ’๋งŒ ์ €์žฅํ•˜๊ณ  ๋‚˜๋จธ์ง€๋Š” ๋‹ค์‹œ ๊ณ„์‚ฐํ•ด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ์ˆœ์ „ํŒŒ ๊ณผ์ •์—์„œ ๋ชจ๋“  ์ค‘๊ฐ„ ํ™œ์„ฑํ™” ๊ฐ’์„ ์ €์žฅํ•˜์ง€ ์•Š์•„๋„ ๋˜์–ด ๋ฉ”๋ชจ๋ฆฌ ์˜ค๋ฒ„ํ—ค๋“œ๋ฅผ ํฌ๊ฒŒ ์ค„์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋งŒ, ํ•™์Šต ์†๋„๊ฐ€ ์•ฝ 20% ๋А๋ ค์ง€๋Š” ํ•œ๊ณ„๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ˆ„์ ์„ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด TrainingArguments์—์„œ gradient_checkpointing() ์˜ต์…˜์„ ์„ค์ •ํ•˜์„ธ์š”.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
)
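
Outside of Trainer, the same technique is exposed through PyTorch's torch.utils.checkpoint module. Below is a minimal sketch with a toy layer stack; the shapes and segment count are illustrative only.

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack of layers standing in for transformer blocks
model = torch.nn.Sequential(*[torch.nn.Linear(128, 128) for _ in range(8)])
inputs = torch.randn(4, 128, requires_grad=True)

# Only the activations at the 2 segment boundaries are kept; everything
# else is recomputed during the backward pass, trading compute for memory
out = checkpoint_sequential(model, 2, inputs, use_reentrant=False)
out.sum().backward()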

ํ˜ผํ•ฉ ์ •๋ฐ€๋„

ํ˜ผํ•ฉ ์ •๋ฐ€๋„๋Š” ์ผ๋ถ€ ๊ณ„์‚ฐ์„ ๋ฐ˜์ •๋ฐ€๋„(fp16)๋กœ, ๋‚˜๋จธ์ง€๋ฅผ ์ „์ •๋ฐ€๋„(fp32)๋กœ ์ˆ˜ํ–‰ํ•ด ํ•™์Šต ์†๋„๋ฅผ ๋†’์ด๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค. ๋ฐ˜์ •๋ฐ€๋„ ๊ณ„์‚ฐ์€ ์ „์ •๋ฐ€๋„๋ณด๋‹ค ๊ณ„์‚ฐ๋Ÿ‰์ด ์ ์–ด ๋” ๋น ๋ฅด๊ฒŒ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค. ํ•œํŽธ, ์ „์ •๋ฐ€๋„๋กœ ์ผ๋ถ€ ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•˜๋ฉด ์ •ํ™•๋„๋ฅผ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ํ˜ผํ•ฉ ์ •๋ฐ€๋„ ํ•™์Šต์„ ์œ„ํ•ด ์—ฌ๋Ÿฌ ์ž๋ฃŒํ˜•์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

fp16
bf16
tf32

ํ˜ผํ•ฉ ์ •๋ฐ€๋„ ํ•™์Šต์˜ ์ฃผ์š” ์žฅ์ ์€ ํ™œ์„ฑํ™” ๊ฐ’์„ fp16์œผ๋กœ ์ €์žฅํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

fp16 ์ž๋ฃŒํ˜•์œผ๋กœ ํ˜ผํ•ฉ ์ •๋ฐ€๋„ ํ•™์Šต์„ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด TrainingArguments์—์„œ fp16() ์˜ต์…˜์„ ์„ค์ •ํ•˜์„ธ์š”.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    fp16=True,
)

fp16์€ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ์— ์ตœ์ ํ™”๋œ ๋ฐฉ์‹์ด ์•„๋‹™๋‹ˆ๋‹ค. ์ด๋Š” fp16์œผ๋กœ ๊ณ„์‚ฐ๋œ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๊ฐ€ ์ตœ์ ํ™” ๋‹จ๊ณ„์—์„œ fp32๋กœ ๋‹ค์‹œ ๋ณ€ํ™˜๋˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ํŠนํžˆ ๋ฐฐ์น˜ ํฌ๊ธฐ๊ฐ€ ์ž‘์„ ๋•Œ๋Š”, GPU์— ๋‘ ๊ฐ€์ง€ ์ž๋ฃŒํ˜•(fp16, fp32)์ด ์ ์žฌ๋˜์–ด ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋” ๋งŽ์€ GPU ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

์˜ตํ‹ฐ๋งˆ์ด์ €

Transformers uses PyTorch's AdamW (adamw_torch) optimizer by default. However, because it stores a weighted average of past gradients, it requires additional memory proportional to the number of model parameters to hold those gradients. This can be a problem when training very large models, in which case you should consider choosing a different optimizer. For example, if you have Apex installed on NVIDIA or AMD hardware, the adamw_apex_fused optimizer gives the fastest training speed of all the AdamW optimizers.

์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์„ ํƒํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” TrainingArguments์—์„œ optim() ์˜ต์…˜์„ ์„ค์ •ํ•˜์„ธ์š”.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_bnb_8bit"
)

There are many optimizers to choose from depending on your training scenario (refer to OptimizerNames for the full list of supported optimizers). For example, Adafactor can significantly reduce memory requirements by storing only a weighted average of each row or column of a matrix instead of every element, at the cost of slower convergence. As another example, the 8-bit AdamW optimizer from bitsandbytes quantizes the optimizer states to 8 bits: the states are stored in lower precision and dequantized before being used in the optimizer step.
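
As an illustration, here is a sketch of plugging the bitsandbytes 8-bit AdamW into a plain PyTorch setup. It assumes bitsandbytes is installed and a CUDA device is available.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()
# Optimizer states (the running gradient averages) are kept in 8-bit blocks
# and dequantized on the fly during each optimizer step
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)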

ํŠนํ™”๋œ ์˜ตํ‹ฐ๋งˆ์ด์ €์— ๋Œ€ํ•ด ๋” ์•Œ๊ณ  ์‹ถ๋‹ค๋ฉด optimizer ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.

๋ฐ์ดํ„ฐ ์‚ฌ์ „ ์ ์žฌ

๋ฐ์ดํ„ฐ ์‚ฌ์ „ ์ ์žฌ(Data preloading)๋Š” GPU๊ฐ€ ์ง€์†์ ์œผ๋กœ ์ž‘์—…ํ•  ์ˆ˜ ์žˆ๋„๋ก CPU์—์„œ ๋ฏธ๋ฆฌ ๋ฐฐ์น˜ ๋‹จ์œ„์˜ ๋ฐ์ดํ„ฐ๋ฅผ ์ ์žฌํ•˜๊ณ  ์ค€๋น„ํ•˜๋Š” ๊ธฐ๋Šฅ์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด GPU ์œ ํœด ์‹œ๊ฐ„์„ ์ค„์ด๊ณ  ํ™œ์šฉ๋„๋ฅผ ๋†’์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. GPU๊ฐ€ ํ•ญ์ƒ ์ž‘์—…์„ ๊ณ„์†ํ•˜๋„๋ก ํ•˜๋ ค๋ฉด ๋‹ค์Œ ๋ฐ์ดํ„ฐ ์‚ฌ์ „ ์ ์žฌ๋ฅผ ์œ„ํ•œ ๋‘ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  1. ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•  ๊ณ ์ • ๋ฉ”๋ชจ๋ฆฌ๋ฅผ CPU์— ํ• ๋‹นํ•œ ๋’ค, ์ด๋ฅผ GPU๋กœ ์ง์ ‘ ์ „์†กํ•ฉ๋‹ˆ๋‹ค.
  2. CPU ์Šค๋ ˆ๋“œ ๋ฐ ์›Œ์ปค ์ˆ˜๋ฅผ ๋Š˜๋ ค ๋ฐ์ดํ„ฐ๋ฅผ ๋” ๋น ๋ฅด๊ฒŒ ์‚ฌ์ „ ์ ์žฌํ•ฉ๋‹ˆ๋‹ค.

Allocate pinned memory and increase the number of workers with the dataloader_pin_memory and dataloader_num_workers options in TrainingArguments.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_bnb_8bit",
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
)
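
These two options map onto the pin_memory and num_workers arguments of PyTorch's DataLoader. A minimal plain-PyTorch equivalent (toy dataset; assumes a CUDA device):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 10))
loader = DataLoader(
    dataset,
    batch_size=4,
    pin_memory=True,   # page-locked CPU memory enables faster, async copies to the GPU
    num_workers=4,     # subprocesses prepare upcoming batches in parallel
)

for (batch,) in loader:
    # non_blocking only overlaps the copy with compute when the source is pinned
    batch = batch.to("cuda", non_blocking=True)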

PyTorch

PyTorch provides several features for reducing memory requirements and increasing training speed. These can often be enabled in Transformers with just a few lines of code.

torch.empty_cache_steps

The torch.cuda.empty_cache function releases unused cached memory, which can help avoid out-of-memory (OOM) errors at the cost of roughly 10% slower training.

Set torch_empty_cache_steps in TrainingArguments to enable it after a certain number of training steps.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_bnb_8bit",
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    torch_empty_cache_steps=4,
)
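
For intuition, this is roughly what the option does inside the training loop (a hand-written sketch, not the actual Trainer implementation):

import torch

for step in range(100):
    ...  # one training step: forward, backward, optimizer update
    if (step + 1) % 4 == 0:
        # Release unused cached blocks back to the GPU driver
        torch.cuda.empty_cache()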

torch.compile

torch.compile์€ PyTorch ์ฝ”๋“œ๋ฅผ ์ตœ์ ํ™”๋œ ์ปค๋„๋กœ ์ปดํŒŒ์ผํ•ด ํ•™์Šต ์†๋„๋ฅผ ํฌ๊ฒŒ ๋†’์—ฌ์ค๋‹ˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์€ TorchDynamo๋ฅผ ์‚ฌ์šฉํ•ด ํ”„๋ ˆ์ž„ ํ‰๊ฐ€ API๋กœ๋ถ€ํ„ฐ PyTorch ๊ทธ๋ž˜ํ”„๋ฅผ ์บก์ฒ˜ํ•˜๋ฉฐ, ์ด๋ ‡๊ฒŒ ์บก์ฒ˜ํ•œ ๊ทธ๋ž˜ํ”„๋Š” ๋‹ค์–‘ํ•œ ๋ฐฑ์—”๋“œ์— ์ถ”๊ฐ€๋กœ ์ตœ์ ํ™”๋œ ์ปค๋„๋กœ ์ปดํŒŒ์ผ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Enable it with the torch_compile option in TrainingArguments, and choose a backend with the torch_compile_backend option.

from transformers import TrainingArguments

args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    bf16=True,
    optim="adamw_bnb_8bit",
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    torch_empty_cache_steps=4,
    torch_compile=True,
    torch_compile_backend="inductor"
)

Refer to the table below to choose the right backend for your training scenario.

| Backend | Description | Goal |
|---------|-------------|------|
| eager | uses PyTorch to run the extracted GraphModule | debugging |
| aot_eager | uses PyTorch eager mode for AOTAutograd's extracted forward and backward graphs | debugging |
| inductor | uses TorchInductor with AOTAutograd and CUDA Graphs by leveraging Triton kernels | training and inference |
| nvfuser | uses nvFuser with TorchScript | training and inference |
| aot_nvfuser | uses nvFuser with AOTAutograd | training and inference |
| aot_cudagraphs | uses CUDA Graphs with AOTAutograd | training and inference |
| ofi | uses TorchScript's optimize_for_inference | inference |
| fx2trt | uses Torch-TensorRT | inference |
| onnxrt | uses ONNX-RT for CPU and GPU inference | inference |
| ipex | uses IPEX for CPU inference | inference |
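
torch.compile can also be applied directly to any module outside of Trainer. A minimal sketch with a toy model; the first call triggers graph capture and compilation, and later calls reuse the compiled kernels.

import torch

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU())
compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(8, 128))  # compiled on first call, fast afterwards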

์Šค์ผ€์ผ๋œ ๋‚ด์  ์–ดํ…์…˜

torch.nn.functional.scaled_dot_product_attention (SDPA) is a native PyTorch implementation of the scaled dot product attention mechanism. SDPA is more efficient and optimized than the original attention mechanism in transformer models. It supports three types of scaled dot product attention:

  • FlashAttention2 is automatically enabled for models with the fp16 or bf16 torch type. Make sure to cast your model to the appropriate type first.
  • xFormers or Memory-Efficient Attention supports models with the fp32 torch type.
  • A C++ implementation of scaled dot product attention.

SDPA is enabled by default for PyTorch 2.1.1+, but it can be explicitly enabled by setting attn_implementation="sdpa" in from_pretrained().

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", attn_implementation="sdpa")
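
For intuition, the underlying PyTorch function can also be called directly on query, key, and value tensors. A minimal sketch with toy shapes:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Toy tensors shaped (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 8, 128, 64, device=device, dtype=dtype) for _ in range(3))

# PyTorch dispatches to FlashAttention, memory-efficient attention, or the
# C++ math kernel based on the dtype, device, and input shapes
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)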