
Trainer

The Trainer is a complete training and evaluation loop for PyTorch models implemented in the Transformers library. You only need to provide the components required for training (a model, a tokenizer, a dataset, an evaluation function, training hyperparameters, and so on), and the Trainer takes care of the rest. This makes it easy to start training quickly without writing your own training loop. At the same time, the Trainer is highly customizable and offers a wide range of training options, so you can tailor it to your exact training needs.

In addition to the Trainer class, Transformers also provides a Seq2SeqTrainer class for sequence-to-sequence tasks such as translation or summarization. The TRL library also includes an SFTTrainer class, which wraps the Trainer class and is optimized for training language models like Llama-2 and Mistral with autoregressive techniques. SFTTrainer supports features such as sequence packing, LoRA, quantization, and DeepSpeed, so it can scale efficiently to models of any size.


Feel free to check out the API reference for these other Trainer-type classes to learn more about when to use which one. In general, Trainer is the most versatile option and is appropriate for a broad range of tasks; Seq2SeqTrainer is designed for sequence-to-sequence tasks, and SFTTrainer is designed for training language models.
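
To make the distinction concrete, below is a minimal, illustrative sketch of a Seq2SeqTrainer setup. The model, tokenizer, and dataset here are assumed to be defined elsewhere; the main differences from Trainer are the Seq2SeqTrainingArguments class and generation-related options such as predict_with_generate.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# minimal sketch; model, tokenizer, and dataset are assumed to exist
training_args = Seq2SeqTrainingArguments(
    output_dir="your-model",
    predict_with_generate=True,  # use generate() during evaluation, e.g. for ROUGE/BLEU
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
)
trainer.train()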

Before you start, make sure the Accelerate library is installed so you can run and launch PyTorch training in distributed environments.

pip install accelerate

# ์—…๊ทธ๋ ˆ์ด๋“œ
pip install accelerate --upgrade

์ด ๊ฐ€์ด๋“œ๋Š” Trainer ํด๋ž˜์Šค์— ๋Œ€ํ•œ ๊ฐœ์š”๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

Basic usage

Trainer includes all the code you'll find in a basic training loop (a hand-written sketch of the same loop follows the list below):

  1. ์†์‹ค์„ ๊ณ„์‚ฐํ•˜๋Š” ํ›ˆ๋ จ ๋‹จ๊ณ„๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  2. backward ๋ฉ”์†Œ๋“œ๋กœ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  3. ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๊ฐ€์ค‘์น˜๋ฅผ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค.
  4. ์ •ํ•ด์ง„ ์—ํญ ์ˆ˜์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ์ด ๊ณผ์ •์„ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
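
For comparison, here is a minimal sketch of that loop written by hand in plain PyTorch; model, train_dataloader, and num_epochs are hypothetical placeholders:

import torch

# the loop the Trainer abstracts away; model, train_dataloader, and
# num_epochs are hypothetical placeholders defined elsewhere
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for epoch in range(num_epochs):
    for batch in train_dataloader:
        outputs = model(**batch)   # 1. training step that calculates the loss
        loss = outputs.loss
        loss.backward()            # 2. calculate the gradients
        optimizer.step()           # 3. update the weights
        optimizer.zero_grad()
# 4. repeated until the set number of epochs is reached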

Trainer ํด๋ž˜์Šค๋Š” PyTorch์™€ ํ›ˆ๋ จ ๊ณผ์ •์— ์ต์ˆ™ํ•˜์ง€ ์•Š๊ฑฐ๋‚˜ ๋ง‰ ์‹œ์ž‘ํ•œ ๊ฒฝ์šฐ์—๋„ ํ›ˆ๋ จ์ด ๊ฐ€๋Šฅํ•˜๋„๋ก ํ•„์š”ํ•œ ๋ชจ๋“  ์ฝ”๋“œ๋ฅผ ์ถ”์ƒํ™”ํ•˜์˜€์Šต๋‹ˆ๋‹ค. ๋˜ํ•œ ๋งค๋ฒˆ ํ›ˆ๋ จ ๋ฃจํ”„๋ฅผ ์†์ˆ˜ ์ž‘์„ฑํ•˜์ง€ ์•Š์•„๋„ ๋˜๋ฉฐ, ํ›ˆ๋ จ์— ํ•„์š”ํ•œ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ์…‹ ๊ฐ™์€ ํ•„์ˆ˜ ๊ตฌ์„ฑ ์š”์†Œ๋งŒ ์ œ๊ณตํ•˜๋ฉด, [Trainer] ํด๋ž˜์Šค๊ฐ€ ๋‚˜๋จธ์ง€๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.

ํ›ˆ๋ จ ์˜ต์…˜์ด๋‚˜ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ง€์ •ํ•˜๋ ค๋ฉด, TrainingArguments ํด๋ž˜์Šค์—์„œ ํ™•์ธ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ๋ชจ๋ธ์„ ์ €์žฅํ•  ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ output_dir์— ์ •์˜ํ•˜๊ณ , ํ›ˆ๋ จ ํ›„์— Hub๋กœ ๋ชจ๋ธ์„ ํ‘ธ์‹œํ•˜๋ ค๋ฉด push_to_hub=True๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="your-model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

Pass training_args to the Trainer along with the model, a dataset, something to preprocess the dataset with (depending on your data type it could be a tokenizer, feature extractor, or image processor), a data collator, and a function to compute the metrics you want to track during training.

Finally, call train() to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

์ฒดํฌํฌ์ธํŠธ

Trainer ํด๋ž˜์Šค๋Š” TrainingArguments์˜ output_dir ๋งค๊ฐœ๋ณ€์ˆ˜์— ์ง€์ •๋œ ๋””๋ ‰ํ† ๋ฆฌ์— ๋ชจ๋ธ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋Š” checkpoint-000 ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋˜๋ฉฐ, ์—ฌ๊ธฐ์„œ ๋์˜ ์ˆซ์ž๋Š” ํ›ˆ๋ จ ๋‹จ๊ณ„์— ํ•ด๋‹นํ•ฉ๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•˜๋ฉด ๋‚˜์ค‘์— ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•  ๋•Œ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

# resume from the latest checkpoint
trainer.train(resume_from_checkpoint=True)

# resume from a specific checkpoint saved in the output directory
trainer.train(resume_from_checkpoint="your-model/checkpoint-1000")

์ฒดํฌํฌ์ธํŠธ๋ฅผ Hub์— ํ‘ธ์‹œํ•˜๋ ค๋ฉด TrainingArguments์—์„œ push_to_hub=True๋กœ ์„ค์ •ํ•˜์—ฌ ์ปค๋ฐ‹ํ•˜๊ณ  ํ‘ธ์‹œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฐฉ๋ฒ•์„ ๊ฒฐ์ •ํ•˜๋Š” ๋‹ค๋ฅธ ์˜ต์…˜์€ hub_strategy ๋งค๊ฐœ๋ณ€์ˆ˜์—์„œ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:

  • hub_strategy="checkpoint" pushes the latest checkpoint to a subfolder named "last-checkpoint" from which you can resume training.
  • hub_strategy="all_checkpoints" pushes all checkpoints to the directory defined in output_dir (you'll see one checkpoint per folder in your model repository).
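
For example, a minimal sketch that pushes the final model together with a resume-ready latest checkpoint (both parameters shown are TrainingArguments options):

from transformers import TrainingArguments

# push the model to the Hub and keep the latest checkpoint in "last-checkpoint"
training_args = TrainingArguments(
    output_dir="your-model",
    push_to_hub=True,
    hub_strategy="checkpoint",  # or "all_checkpoints" to mirror output_dir
)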

์ฒดํฌํฌ์ธํŠธ์—์„œ ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•  ๋•Œ, Trainer๋Š” ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์ €์žฅ๋  ๋•Œ์™€ ๋™์ผํ•œ Python, NumPy ๋ฐ PyTorch RNG ์ƒํƒœ๋ฅผ ์œ ์ง€ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ PyTorch๋Š” ๊ธฐ๋ณธ ์„ค์ •์œผ๋กœ โ€˜์ผ๊ด€๋œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์žฅํ•˜์ง€ ์•Š์Œโ€™์œผ๋กœ ๋งŽ์ด ๋˜์–ด์žˆ๊ธฐ ๋•Œ๋ฌธ์—, RNG ์ƒํƒœ๊ฐ€ ๋™์ผํ•  ๊ฒƒ์ด๋ผ๊ณ  ๋ณด์žฅํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ, ์ผ๊ด€๋œ ๊ฒฐ๊ณผ๊ฐ€ ๋ณด์žฅ๋˜๋„๋ก ํ™œ์„ฑํ™” ํ•˜๋ ค๋ฉด, ๋žœ๋ค์„ฑ ์ œ์–ด ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ ํ›ˆ๋ จ์„ ์™„์ „ํžˆ ์ผ๊ด€๋œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์žฅ ๋ฐ›๋„๋ก ๋งŒ๋“ค๊ธฐ ์œ„ํ•ด ํ™œ์„ฑํ™”ํ•  ์ˆ˜ ์žˆ๋Š” ํ•ญ๋ชฉ์„ ํ™•์ธํ•˜์„ธ์š”. ๋‹ค๋งŒ, ํŠน์ • ์„ค์ •์„ ๊ฒฐ์ •์ ์œผ๋กœ ๋งŒ๋“ค๋ฉด ํ›ˆ๋ จ์ด ๋А๋ ค์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Customize the Trainer

Trainer ํด๋ž˜์Šค๋Š” ์ ‘๊ทผ์„ฑ๊ณผ ์šฉ์ด์„ฑ์„ ์—ผ๋‘์— ๋‘๊ณ  ์„ค๊ณ„๋˜์—ˆ์ง€๋งŒ, ๋” ๋‹ค์–‘ํ•œ ๊ธฐ๋Šฅ์„ ์›ํ•˜๋Š” ์‚ฌ์šฉ์ž๋“ค์„ ์œ„ํ•ด ๋‹ค์–‘ํ•œ ๋งž์ถค ์„ค์ • ์˜ต์…˜์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. Trainer์˜ ๋งŽ์€ ๋ฉ”์†Œ๋“œ๋Š” ์„œ๋ธŒํด๋ž˜์Šคํ™” ๋ฐ ์˜ค๋ฒ„๋ผ์ด๋“œํ•˜์—ฌ ์›ํ•˜๋Š” ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ด๋ฅผ ํ†ตํ•ด ์ „์ฒด ํ›ˆ๋ จ ๋ฃจํ”„๋ฅผ ๋‹ค์‹œ ์ž‘์„ฑํ•  ํ•„์š” ์—†์ด ์›ํ•˜๋Š” ๊ธฐ๋Šฅ์„ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฉ”์†Œ๋“œ์—๋Š” ๋‹ค์Œ์ด ํฌํ•จ๋ฉ๋‹ˆ๋‹ค:

  • get_train_dataloader() creates a training DataLoader.
  • get_eval_dataloader() creates an evaluation DataLoader.
  • get_test_dataloader() creates a test DataLoader.
  • log() logs information about the various objects that monitor training.
  • create_optimizer_and_scheduler() creates an optimizer and learning rate scheduler if they weren't passed in __init__; these can also be customized separately with create_optimizer() and create_scheduler(), respectively.
  • compute_loss() computes the loss on a batch of training inputs.
  • training_step() performs the training step.
  • prediction_step() performs the prediction and test step.
  • evaluate() evaluates the model and returns the evaluation metrics.
  • predict() makes predictions (with metrics if labels are available) on the test set.

For example, if you want to customize the compute_loss() method to use a weighted loss instead:

import torch
from torch import nn
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute a custom loss for 3 labels with different weights
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
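
In the same spirit, here is a hedged sketch of overriding create_optimizer() to swap in a different PyTorch optimizer; plain SGD is used purely as an illustration:

import torch
from transformers import Trainer

class SGDTrainer(Trainer):
    def create_optimizer(self):
        # only build an optimizer if one wasn't already passed through __init__
        if self.optimizer is None:
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.args.learning_rate,
                momentum=0.9,
            )
        return self.optimizer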

Callbacks

Another option for customizing the Trainer is to use callbacks. Callbacks don't change anything in the training loop. They inspect the state of the training loop and then execute some action (early stopping, logging results, etc.) depending on that state. In other words, a callback can't be used to implement something like a custom loss function; for that you need to subclass and override the compute_loss() method instead.

For example, if you want to add an early stopping callback to the training loop after 10 steps:

from transformers import TrainerCallback

class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, num_steps=10):
        self.num_steps = num_steps

    def on_step_end(self, args, state, control, **kwargs):
        # signal the Trainer to stop by flipping the flag on the control object
        if state.global_step >= self.num_steps:
            control.should_training_stop = True
        return control

๊ทธ๋Ÿฐ ๋‹ค์Œ, ์ด๋ฅผ Trainer์˜ callback ๋งค๊ฐœ๋ณ€์ˆ˜์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback()],
)

Logging

Check out the logging API reference for more information about the different logging levels.

Trainer๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ logging.INFO๋กœ ์„ค์ •๋˜์–ด ์žˆ์–ด ์˜ค๋ฅ˜, ๊ฒฝ๊ณ  ๋ฐ ๊ธฐํƒ€ ๊ธฐ๋ณธ ์ •๋ณด๋ฅผ ๋ณด๊ณ ํ•ฉ๋‹ˆ๋‹ค. ๋ถ„์‚ฐ ํ™˜๊ฒฝ์—์„œ๋Š” Trainer ๋ณต์ œ๋ณธ์ด logging.WARNING์œผ๋กœ ์„ค์ •๋˜์–ด ์˜ค๋ฅ˜์™€ ๊ฒฝ๊ณ ๋งŒ ๋ณด๊ณ ํ•ฉ๋‹ˆ๋‹ค. TrainingArguments์˜ log_level ๋ฐ log_level_replica ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ๋กœ๊ทธ ๋ ˆ๋ฒจ์„ ๋ณ€๊ฒฝํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๊ฐ ๋…ธ๋“œ์˜ ๋กœ๊ทธ ๋ ˆ๋ฒจ ์„ค์ •์„ ๊ตฌ์„ฑํ•˜๋ ค๋ฉด log_on_each_node ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ๋…ธ๋“œ์—์„œ ๋กœ๊ทธ ๋ ˆ๋ฒจ์„ ์‚ฌ์šฉํ• ์ง€ ์•„๋‹ˆ๋ฉด ์ฃผ ๋…ธ๋“œ์—์„œ๋งŒ ์‚ฌ์šฉํ• ์ง€ ๊ฒฐ์ •ํ•˜์„ธ์š”.

The Trainer sets the log level separately for each node in the Trainer.__init__() method, so if you're using other Transformers functionality before creating the Trainer object, consider setting this earlier.

์˜ˆ๋ฅผ ๋“ค์–ด, ๋ฉ”์ธ ์ฝ”๋“œ์™€ ๋ชจ๋“ˆ์„ ๊ฐ ๋…ธ๋“œ์— ๋”ฐ๋ผ ๋™์ผํ•œ ๋กœ๊ทธ ๋ ˆ๋ฒจ์„ ์‚ฌ์šฉํ•˜๋„๋ก ์„ค์ •ํ•˜๋ ค๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ•ฉ๋‹ˆ๋‹ค.

import logging
import sys

import datasets
import transformers

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)

trainer = Trainer(...)

๊ฐ ๋…ธ๋“œ์—์„œ ๊ธฐ๋ก๋  ๋‚ด์šฉ์„ ๊ตฌ์„ฑํ•˜๊ธฐ ์œ„ํ•ด log_level๊ณผ log_level_replica๋ฅผ ๋‹ค์–‘ํ•œ ์กฐํ•ฉ์œผ๋กœ ์‚ฌ์šฉํ•ด๋ณด์„ธ์š”.

# single node
my_app.py ... --log_level warning --log_level_replica error

# multi-node (log only on the main node)
my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0

NEFTune

NEFTune์€ ํ›ˆ๋ จ ์ค‘ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์— ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ์„ฑ๋Šฅ์„ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋Š” ๊ธฐ์ˆ ์ž…๋‹ˆ๋‹ค. Trainer์—์„œ ์ด๋ฅผ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด TrainingArguments์˜ neftune_noise_alpha ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์„ค์ •ํ•˜์—ฌ ๋…ธ์ด์ฆˆ์˜ ์–‘์„ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(..., neftune_noise_alpha=0.1)
trainer = Trainer(..., args=training_args)

NEFTune์€ ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ๋™์ž‘์„ ํ”ผํ•  ๋ชฉ์ ์œผ๋กœ ์ฒ˜์Œ ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด๋กœ ๋ณต์›ํ•˜๊ธฐ ์œ„ํ•ด ํ›ˆ๋ จ ํ›„ ๋น„ํ™œ์„ฑํ™” ๋ฉ๋‹ˆ๋‹ค.

GaLore

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA.

First, install GaLore from its official repository:

pip install galore-torch

๊ทธ๋Ÿฐ ๋‹ค์Œ optim์— ["galore_adamw", "galore_adafactor", "galore_adamw_8bit"] ์ค‘ ํ•˜๋‚˜์™€ ํ•จ๊ป˜ optim_target_modules๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. ์ด๋Š” ์ ์šฉํ•˜๋ ค๋Š” ๋Œ€์ƒ ๋ชจ๋“ˆ ์ด๋ฆ„์— ํ•ด๋‹นํ•˜๋Š” ๋ฌธ์ž์—ด, ์ •๊ทœ ํ‘œํ˜„์‹ ๋˜๋Š” ์ „์ฒด ๊ฒฝ๋กœ์˜ ๋ชฉ๋ก์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์•„๋ž˜๋Š” end-to-end ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ์ž…๋‹ˆ๋‹ค(ํ•„์š”ํ•œ ๊ฒฝ์šฐ pip install trl datasets๋ฅผ ์‹คํ–‰):

import datasets
from trl import SFTConfig, SFTTrainer

train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
    output_dir="./test-galore",
    max_steps=100,
    optim="galore_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    gradient_checkpointing=True,
)
trainer = SFTTrainer(
    model="google/gemma-2b",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

To pass extra arguments supported by GaLore, set optim_args. For example:

import datasets
from trl import SFTConfig, SFTTrainer

train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
    output_dir="./test-galore",
    max_steps=100,
    optim="galore_adamw",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
    gradient_checkpointing=True,
)
trainer = SFTTrainer(
    model="google/gemma-2b",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

You can read more about the method in the original repository or the paper.

Currently, you can only train Linear layers that are considered GaLore layers; these are trained with low-rank decomposition, while the remaining layers are optimized in the conventional manner.

Note that it takes a bit of time before training starts (about 3 minutes for a 2B model on an NVIDIA A100), but training should proceed smoothly afterwards.

๋‹ค์Œ๊ณผ ๊ฐ™์ด ์˜ตํ‹ฐ๋งˆ์ด์ € ์ด๋ฆ„์— layerwise๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ๋ ˆ์ด์–ด๋ณ„ ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:

import datasets
from trl import SFTConfig, SFTTrainer

train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
    output_dir="./test-galore",
    max_steps=100,
    optim="galore_adamw_layerwise",
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
    gradient_checkpointing=True,
)
trainer = SFTTrainer(
    model="google/gemma-2b",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

๋ ˆ์ด์–ด๋ณ„ ์ตœ์ ํ™”๋Š” ๋‹ค์†Œ ์‹คํ—˜์ ์ด๋ฉฐ DDP(๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ)๋ฅผ ์ง€์›ํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ, ๋‹จ์ผ GPU์—์„œ๋งŒ ํ›ˆ๋ จ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ์ด ๋ฌธ์„œ๋ฅผ์„ ์ฐธ์กฐํ•˜์„ธ์š”. gradient clipping, DeepSpeed ๋“ฑ ๋‹ค๋ฅธ ๊ธฐ๋Šฅ์€ ๊ธฐ๋ณธ์ ์œผ๋กœ ์ง€์›๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•˜๋ฉด GitHub์— ์ด์Šˆ๋ฅผ ์˜ฌ๋ ค์ฃผ์„ธ์š”.

LOMO ์˜ตํ‹ฐ๋งˆ์ด์ €

LOMO ์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” ์ œํ•œ๋œ ์ž์›์œผ๋กœ ๋Œ€ํ˜• ์–ธ์–ด ๋ชจ๋ธ์˜ ์ „์ฒด ๋งค๊ฐœ๋ณ€์ˆ˜ ๋ฏธ์„ธ ์กฐ์ •๊ณผ ์ ์‘ํ˜• ํ•™์Šต๋ฅ ์„ ํ†ตํ•œ ์ €๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”(AdaLomo)์—์„œ ๋„์ž…๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด๋“ค์€ ๋ชจ๋‘ ํšจ์œจ์ ์ธ ์ „์ฒด ๋งค๊ฐœ๋ณ€์ˆ˜ ๋ฏธ์„ธ ์กฐ์ • ๋ฐฉ๋ฒ•์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์˜ตํ‹ฐ๋งˆ์ด์ €๋“ค์€ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๊ธฐ ์œ„ํ•ด ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๊ณ„์‚ฐ๊ณผ ๋งค๊ฐœ๋ณ€์ˆ˜ ์—…๋ฐ์ดํŠธ๋ฅผ ํ•˜๋‚˜์˜ ๋‹จ๊ณ„๋กœ ์œตํ•ฉํ•ฉ๋‹ˆ๋‹ค. LOMO์—์„œ ์ง€์›๋˜๋Š” ์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” "lomo"์™€ "adalomo"์ž…๋‹ˆ๋‹ค. ๋จผ์ € pypi์—์„œ pip install lomo-optim๋ฅผ ํ†ตํ•ด lomo๋ฅผ ์„ค์น˜ํ•˜๊ฑฐ๋‚˜, GitHub ์†Œ์Šค์—์„œ pip install git+https://github.com/OpenLMLab/LOMO.git๋กœ ์„ค์น˜ํ•˜์„ธ์š”.

์ €์ž์— ๋”ฐ๋ฅด๋ฉด, grad_norm ์—†์ด AdaLomo๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ๋” ๋‚˜์€ ์„ฑ๋Šฅ๊ณผ ๋†’์€ ์ฒ˜๋ฆฌ๋Ÿ‰์„ ์ œ๊ณตํ•œ๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ์€ IMDB ๋ฐ์ดํ„ฐ์…‹์—์„œ google/gemma-2b๋ฅผ ์ตœ๋Œ€ ์ •๋ฐ€๋„๋กœ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ์Šคํฌ๋ฆฝํŠธ์ž…๋‹ˆ๋‹ค:

import datasets
from trl import SFTConfig, SFTTrainer

train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
    output_dir="./test-lomo",
    max_steps=100,
    optim="adalomo",
    gradient_checkpointing=True,
)
trainer = SFTTrainer(
    model="google/gemma-2b",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

Accelerate and Trainer

Trainer ํด๋ž˜์Šค๋Š” Accelerate๋กœ ๊ตฌ๋™๋˜๋ฉฐ, ์ด๋Š” FullyShardedDataParallel (FSDP) ๋ฐ DeepSpeed์™€ ๊ฐ™์€ ํ†ตํ•ฉ์„ ์ง€์›ํ•˜๋Š” ๋ถ„์‚ฐ ํ™˜๊ฒฝ์—์„œ PyTorch ๋ชจ๋ธ์„ ์‰ฝ๊ฒŒ ํ›ˆ๋ จํ•  ์ˆ˜ ์žˆ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ž…๋‹ˆ๋‹ค.

Learn more about FSDP sharding strategies, CPU offloading, and more features you can use with the Trainer in the Fully Sharded Data Parallel guide.

To use Accelerate with the Trainer, run the accelerate config command to set up your training environment. This command creates a config_file.yaml that is used when you launch your training script. Some of the configurations you can set up include DistributedDataParallel, FSDP, DeepSpeed, and DeepSpeed with the Accelerate plugin; the DistributedDataParallel configuration is shown below.

compute_environment: LOCAL_MACHINE                                                                                             
distributed_type: MULTI_GPU                                                                                                    
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 # change rank as per the node
main_process_ip: 192.168.20.1
main_process_port: 9898
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The accelerate launch command is the recommended way to launch your training script on a distributed system with Accelerate and the Trainer, using the parameters specified in config_file.yaml. This file is saved to the Accelerate cache folder and automatically loaded when you run accelerate launch.

For example, to run the run_glue.py training script with the FSDP configuration:

accelerate launch \
    ./examples/pytorch/text-classification/run_glue.py \
    --model_name_or_path google-bert/bert-base-cased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 16 \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --output_dir /tmp/$TASK_NAME/ \
    --overwrite_output_dir

You could also specify the parameters from the config_file.yaml file directly in the command line:

accelerate launch --num_processes=2 \
    --use_fsdp \
    --mixed_precision=bf16 \
    --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP  \
    --fsdp_transformer_layer_cls_to_wrap="BertLayer" \
    --fsdp_sharding_strategy=1 \
    --fsdp_state_dict_type=FULL_STATE_DICT \
    ./examples/pytorch/text-classification/run_glue.py \
    --model_name_or_path google-bert/bert-base-cased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 16 \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --output_dir /tmp/$TASK_NAME/ \
    --overwrite_output_dir

Check out the Launching your Accelerate scripts tutorial to learn more about accelerate launch and custom configurations.

