Transformers documentation
Trainer
Trainer
Trainer๋ Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ตฌํ๋ PyTorch ๋ชจ๋ธ์ ๋ฐ๋ณตํ์ฌ ํ๋ จ ๋ฐ ํ๊ฐ ๊ณผ์ ์ ๋๋ค. ํ๋ จ์ ํ์ํ ์์(๋ชจ๋ธ, ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ์ , ํ๊ฐ ํจ์, ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ ๋ฑ)๋ง ์ ๊ณตํ๋ฉด Trainer๊ฐ ํ์ํ ๋๋จธ์ง ์์ ์ ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ง์ ํ๋ จ ๋ฃจํ๋ฅผ ์์ฑํ์ง ์๊ณ ๋ ๋น ๋ฅด๊ฒ ํ๋ จ์ ์์ํ ์ ์์ต๋๋ค. ๋ํ Trainer๋ ๊ฐ๋ ฅํ ๋ง์ถค ์ค์ ๊ณผ ๋ค์ํ ํ๋ จ ์ต์ ์ ์ ๊ณตํ์ฌ ์ฌ์ฉ์ ๋ง์ถค ํ๋ จ์ด ๊ฐ๋ฅํฉ๋๋ค.
Transformers๋ Trainer ํด๋์ค ์ธ์๋ ๋ฒ์ญ์ด๋ ์์ฝ๊ณผ ๊ฐ์ ์ํ์ค-ํฌ-์ํ์ค ์์
์ ์ํ Seq2SeqTrainer ํด๋์ค๋ ์ ๊ณตํฉ๋๋ค. ๋ํ TRL ๋ผ์ด๋ธ๋ฌ๋ฆฌ์๋ Trainer ํด๋์ค๋ฅผ ๊ฐ์ธ๊ณ Llama-2 ๋ฐ Mistral๊ณผ ๊ฐ์ ์ธ์ด ๋ชจ๋ธ์ ์๋ ํ๊ท ๊ธฐ๋ฒ์ผ๋ก ํ๋ จํ๋ ๋ฐ ์ต์ ํ๋ SFTTrainer
ํด๋์ค ์
๋๋ค. SFTTrainer
๋ ์ํ์ค ํจํน, LoRA, ์์ํ ๋ฐ DeepSpeed์ ๊ฐ์ ๊ธฐ๋ฅ์ ์ง์ํ์ฌ ํฌ๊ธฐ ์๊ด์์ด ๋ชจ๋ธ ํจ์จ์ ์ผ๋ก ํ์ฅํ ์ ์์ต๋๋ค.
์ด๋ค ๋ค๋ฅธ Trainer ์ ํ ํด๋์ค์ ๋ํด ๋ ์๊ณ ์ถ๋ค๋ฉด API ์ฐธ์กฐ๋ฅผ ํ์ธํ์ฌ ์ธ์ ์ด๋ค ํด๋์ค๊ฐ ์ ํฉํ ์ง ์ผ๋ง๋ ์ง ํ์ธํ์ธ์. ์ผ๋ฐ์ ์ผ๋ก Trainer๋ ๊ฐ์ฅ ๋ค์ฌ๋ค๋ฅํ ์ต์
์ผ๋ก, ๋ค์ํ ์์
์ ์ ํฉํฉ๋๋ค. Seq2SeqTrainer๋ ์ํ์ค-ํฌ-์ํ์ค ์์
์ ์ํด ์ค๊ณ๋์๊ณ , SFTTrainer
๋ ์ธ์ด ๋ชจ๋ธ ํ๋ จ์ ์ํด ์ค๊ณ๋์์ต๋๋ค.
์์ํ๊ธฐ ์ ์, ๋ถ์ฐ ํ๊ฒฝ์์ PyTorch ํ๋ จ๊ณผ ์คํ์ ํ ์ ์๊ฒ Accelerate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์๋์ง ํ์ธํ์ธ์.
pip install accelerate
# ์
๊ทธ๋ ์ด๋
pip install accelerate --upgrade
์ด ๊ฐ์ด๋๋ Trainer ํด๋์ค์ ๋ํ ๊ฐ์๋ฅผ ์ ๊ณตํฉ๋๋ค.
๊ธฐ๋ณธ ์ฌ์ฉ๋ฒ
Trainer๋ ๊ธฐ๋ณธ์ ์ธ ํ๋ จ ๋ฃจํ์ ํ์ํ ๋ชจ๋ ์ฝ๋๋ฅผ ํฌํจํ๊ณ ์์ต๋๋ค.
- ์์ค์ ๊ณ์ฐํ๋ ํ๋ จ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค.
backward
๋ฉ์๋๋ก ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ณ์ฐํฉ๋๋ค.- ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
- ์ ํด์ง ์ํญ ์์ ๋๋ฌํ ๋๊น์ง ์ด ๊ณผ์ ์ ๋ฐ๋ณตํฉ๋๋ค.
Trainer ํด๋์ค๋ PyTorch์ ํ๋ จ ๊ณผ์ ์ ์ต์ํ์ง ์๊ฑฐ๋ ๋ง ์์ํ ๊ฒฝ์ฐ์๋ ํ๋ จ์ด ๊ฐ๋ฅํ๋๋ก ํ์ํ ๋ชจ๋ ์ฝ๋๋ฅผ ์ถ์ํํ์์ต๋๋ค. ๋ํ ๋งค๋ฒ ํ๋ จ ๋ฃจํ๋ฅผ ์์ ์์ฑํ์ง ์์๋ ๋๋ฉฐ, ํ๋ จ์ ํ์ํ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์ ๊ฐ์ ํ์ ๊ตฌ์ฑ ์์๋ง ์ ๊ณตํ๋ฉด, [Trainer] ํด๋์ค๊ฐ ๋๋จธ์ง๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
ํ๋ จ ์ต์
์ด๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ง์ ํ๋ ค๋ฉด, TrainingArguments ํด๋์ค์์ ํ์ธ ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋ชจ๋ธ์ ์ ์ฅํ ๋๋ ํ ๋ฆฌ๋ฅผ output_dir
์ ์ ์ํ๊ณ , ํ๋ จ ํ์ Hub๋ก ๋ชจ๋ธ์ ํธ์ํ๋ ค๋ฉด push_to_hub=True
๋ก ์ค์ ํฉ๋๋ค.
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="your-model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
)
training_args
๋ฅผ Trainer์ ๋ชจ๋ธ, ๋ฐ์ดํฐ์
, ๋ฐ์ดํฐ์
์ ์ฒ๋ฆฌ ๋๊ตฌ(๋ฐ์ดํฐ ์ ํ์ ๋ฐ๋ผ ํ ํฌ๋์ด์ , ํน์ง ์ถ์ถ๊ธฐ ๋๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ผ ์ ์์), ๋ฐ์ดํฐ ์์ง๊ธฐ ๋ฐ ํ๋ จ ์ค ํ์ธํ ์งํ๋ฅผ ๊ณ์ฐํ ํจ์๋ฅผ ํจ๊ป ์ ๋ฌํ์ธ์.
๋ง์ง๋ง์ผ๋ก, train()๋ฅผ ํธ์ถํ์ฌ ํ๋ จ์ ์์ํ์ธ์!
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
์ฒดํฌํฌ์ธํธ
Trainer ํด๋์ค๋ TrainingArguments์ output_dir
๋งค๊ฐ๋ณ์์ ์ง์ ๋ ๋๋ ํ ๋ฆฌ์ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ checkpoint-000
ํ์ ํด๋์ ์ ์ฅ๋๋ฉฐ, ์ฌ๊ธฐ์ ๋์ ์ซ์๋ ํ๋ จ ๋จ๊ณ์ ํด๋นํฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๋ฉด ๋์ค์ ํ๋ จ์ ์ฌ๊ฐํ ๋ ์ ์ฉํฉ๋๋ค.
# ์ต์ ์ฒดํฌํฌ์ธํธ์์ ์ฌ๊ฐ
trainer.train(resume_from_checkpoint=True)
# ์ถ๋ ฅ ๋๋ ํ ๋ฆฌ์ ์ ์ฅ๋ ํน์ ์ฒดํฌํฌ์ธํธ์์ ์ฌ๊ฐ
trainer.train(resume_from_checkpoint="your-model/checkpoint-1000")
์ฒดํฌํฌ์ธํธ๋ฅผ Hub์ ํธ์ํ๋ ค๋ฉด TrainingArguments์์ push_to_hub=True
๋ก ์ค์ ํ์ฌ ์ปค๋ฐํ๊ณ ํธ์ํ ์ ์์ต๋๋ค. ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋ฐฉ๋ฒ์ ๊ฒฐ์ ํ๋ ๋ค๋ฅธ ์ต์
์ hub_strategy
๋งค๊ฐ๋ณ์์์ ์ค์ ํฉ๋๋ค:
hub_strategy="checkpoint"
๋ ์ต์ ์ฒดํฌํฌ์ธํธ๋ฅผ โlast-checkpointโ๋ผ๋ ํ์ ํด๋์ ํธ์ํ์ฌ ํ๋ จ์ ์ฌ๊ฐํ ์ ์์ต๋๋ค.hub_strategy="all_checkpoints"
๋ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ๋ฅผoutput_dir
์ ์ ์๋ ๋๋ ํ ๋ฆฌ์ ํธ์ํฉ๋๋ค(๋ชจ๋ธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ํด๋๋น ํ๋์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณผ ์ ์์ต๋๋ค).
์ฒดํฌํฌ์ธํธ์์ ํ๋ จ์ ์ฌ๊ฐํ ๋, Trainer๋ ์ฒดํฌํฌ์ธํธ๊ฐ ์ ์ฅ๋ ๋์ ๋์ผํ Python, NumPy ๋ฐ PyTorch RNG ์ํ๋ฅผ ์ ์งํ๋ ค๊ณ ํฉ๋๋ค. ํ์ง๋ง PyTorch๋ ๊ธฐ๋ณธ ์ค์ ์ผ๋ก โ์ผ๊ด๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฅํ์ง ์์โ์ผ๋ก ๋ง์ด ๋์ด์๊ธฐ ๋๋ฌธ์, RNG ์ํ๊ฐ ๋์ผํ ๊ฒ์ด๋ผ๊ณ ๋ณด์ฅํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์, ์ผ๊ด๋ ๊ฒฐ๊ณผ๊ฐ ๋ณด์ฅ๋๋๋ก ํ์ฑํ ํ๋ ค๋ฉด, ๋๋ค์ฑ ์ ์ด ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ฌ ํ๋ จ์ ์์ ํ ์ผ๊ด๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฅ ๋ฐ๋๋ก ๋ง๋ค๊ธฐ ์ํด ํ์ฑํํ ์ ์๋ ํญ๋ชฉ์ ํ์ธํ์ธ์. ๋ค๋ง, ํน์ ์ค์ ์ ๊ฒฐ์ ์ ์ผ๋ก ๋ง๋ค๋ฉด ํ๋ จ์ด ๋๋ ค์ง ์ ์์ต๋๋ค.
Trainer ๋ง์ถค ์ค์
Trainer ํด๋์ค๋ ์ ๊ทผ์ฑ๊ณผ ์ฉ์ด์ฑ์ ์ผ๋์ ๋๊ณ ์ค๊ณ๋์์ง๋ง, ๋ ๋ค์ํ ๊ธฐ๋ฅ์ ์ํ๋ ์ฌ์ฉ์๋ค์ ์ํด ๋ค์ํ ๋ง์ถค ์ค์ ์ต์ ์ ์ ๊ณตํฉ๋๋ค. Trainer์ ๋ง์ ๋ฉ์๋๋ ์๋ธํด๋์คํ ๋ฐ ์ค๋ฒ๋ผ์ด๋ํ์ฌ ์ํ๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ํตํด ์ ์ฒด ํ๋ จ ๋ฃจํ๋ฅผ ๋ค์ ์์ฑํ ํ์ ์์ด ์ํ๋ ๊ธฐ๋ฅ์ ์ถ๊ฐํ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ฉ์๋์๋ ๋ค์์ด ํฌํจ๋ฉ๋๋ค:
- get_train_dataloader()๋ ํ๋ จ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค.
- get_eval_dataloader()๋ ํ๊ฐ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค.
- get_test_dataloader()๋ ํ ์คํธ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค.
- log()๋ ํ๋ จ์ ๋ชจ๋ํฐ๋งํ๋ ๋ค์ํ ๊ฐ์ฒด์ ๋ํ ์ ๋ณด๋ฅผ ๋ก๊ทธ๋ก ๋จ๊น๋๋ค.
- create_optimizer_and_scheduler()๋
__init__
์์ ์ ๋ฌ๋์ง ์์ ๊ฒฝ์ฐ ์ตํฐ๋ง์ด์ ์ ํ์ต๋ฅ ์ค์ผ์ค๋ฌ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๋ค์ ๊ฐ๊ฐ create_optimizer() ๋ฐ create_scheduler()๋ก ๋ณ๋๋ก ๋ง์ถค ์ค์ ํ ์ ์์ต๋๋ค. - compute_loss()๋ ํ๋ จ ์ ๋ ฅ ๋ฐฐ์น์ ๋ํ ์์ค์ ๊ณ์ฐํฉ๋๋ค.
- training_step()๋ ํ๋ จ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค.
- prediction_step()๋ ์์ธก ๋ฐ ํ ์คํธ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค.
- evaluate()๋ ๋ชจ๋ธ์ ํ๊ฐํ๊ณ ํ๊ฐ ์งํ์ ๋ฐํํฉ๋๋ค.
- predict()๋ ํ ์คํธ ์ธํธ์ ๋ํ ์์ธก(๋ ์ด๋ธ์ด ์๋ ๊ฒฝ์ฐ ์งํ ํฌํจ)์ ์ํํฉ๋๋ค.
์๋ฅผ ๋ค์ด, compute_loss() ๋ฉ์๋๋ฅผ ๋ง์ถค ์ค์ ํ์ฌ ๊ฐ์ค ์์ค์ ์ฌ์ฉํ๋ ค๋ ๊ฒฝ์ฐ:
from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self,
model, inputs, return_outputs=False):
labels = inputs.pop("labels")
# ์๋ฐฉํฅ ์ ํ
outputs = model(**inputs)
logits = outputs.get("logits")
# ์๋ก ๋ค๋ฅธ ๊ฐ์ค์น๋ก 3๊ฐ์ ๋ ์ด๋ธ์ ๋ํ ์ฌ์ฉ์ ์ ์ ์์ค์ ๊ณ์ฐ
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
์ฝ๋ฐฑ
Trainer๋ฅผ ๋ง์ถค ์ค์ ํ๋ ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์ฝ๋ฐฑ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ์ฝ๋ฐฑ์ ํ๋ จ ๋ฃจํ์์ ๋ณํ๋ฅผ ์ฃผ์ง ์์ต๋๋ค. ํ๋ จ ๋ฃจํ์ ์ํ๋ฅผ ๊ฒ์ฌํ ํ ์ํ์ ๋ฐ๋ผ ์ผ๋ถ ์์ (์กฐ๊ธฐ ์ข ๋ฃ, ๊ฒฐ๊ณผ ๋ก๊ทธ ๋ฑ)์ ์คํํฉ๋๋ค. ์ฆ, ์ฝ๋ฐฑ์ ์ฌ์ฉ์ ์ ์ ์์ค ํจ์์ ๊ฐ์ ๊ฒ์ ๊ตฌํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ์ํด์๋ compute_loss() ๋ฉ์๋๋ฅผ ์๋ธํด๋์คํํ๊ณ ์ค๋ฒ๋ผ์ด๋ํด์ผ ํฉ๋๋ค.
์๋ฅผ ๋ค์ด, ํ๋ จ ๋ฃจํ์ 10๋จ๊ณ ํ ์กฐ๊ธฐ ์ข ๋ฃ ์ฝ๋ฐฑ์ ์ถ๊ฐํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค.
from transformers import TrainerCallback
class EarlyStoppingCallback(TrainerCallback):
def __init__(self, num_steps=10):
self.num_steps = num_steps
def on_step_end(self, args, state, control, **kwargs):
if state.global_step >= self.num_steps:
return {"should_training_stop": True}
else:
return {}
๊ทธ๋ฐ ๋ค์, ์ด๋ฅผ Trainer์ callback
๋งค๊ฐ๋ณ์์ ์ ๋ฌํฉ๋๋ค.
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()],
)
๋ก๊น
๋ก๊น API์ ๋ํ ์์ธํ ๋ด์ฉ์ ๋ก๊น API ๋ ํผ๋ฐ์ค๋ฅผ ํ์ธํ์ธ์.
Trainer๋ ๊ธฐ๋ณธ์ ์ผ๋ก logging.INFO
๋ก ์ค์ ๋์ด ์์ด ์ค๋ฅ, ๊ฒฝ๊ณ ๋ฐ ๊ธฐํ ๊ธฐ๋ณธ ์ ๋ณด๋ฅผ ๋ณด๊ณ ํฉ๋๋ค. ๋ถ์ฐ ํ๊ฒฝ์์๋ Trainer ๋ณต์ ๋ณธ์ด logging.WARNING
์ผ๋ก ์ค์ ๋์ด ์ค๋ฅ์ ๊ฒฝ๊ณ ๋ง ๋ณด๊ณ ํฉ๋๋ค. TrainingArguments์ log_level
๋ฐ log_level_replica
๋งค๊ฐ๋ณ์๋ก ๋ก๊ทธ ๋ ๋ฒจ์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
๊ฐ ๋
ธ๋์ ๋ก๊ทธ ๋ ๋ฒจ ์ค์ ์ ๊ตฌ์ฑํ๋ ค๋ฉด log_on_each_node
๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ๋
ธ๋์์ ๋ก๊ทธ ๋ ๋ฒจ์ ์ฌ์ฉํ ์ง ์๋๋ฉด ์ฃผ ๋
ธ๋์์๋ง ์ฌ์ฉํ ์ง ๊ฒฐ์ ํ์ธ์.
Trainer๋ Trainer.__init__()
๋ฉ์๋์์ ๊ฐ ๋
ธ๋์ ๋ํด ๋ก๊ทธ ๋ ๋ฒจ์ ๋ณ๋๋ก ์ค์ ํ๋ฏ๋ก, ๋ค๋ฅธ Transformers ๊ธฐ๋ฅ์ ์ฌ์ฉํ ๊ฒฝ์ฐ Trainer ๊ฐ์ฒด๋ฅผ ์์ฑํ๊ธฐ ์ ์ ์ด๋ฅผ ๋ฏธ๋ฆฌ ์ค์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์๋ฅผ ๋ค์ด, ๋ฉ์ธ ์ฝ๋์ ๋ชจ๋์ ๊ฐ ๋ ธ๋์ ๋ฐ๋ผ ๋์ผํ ๋ก๊ทธ ๋ ๋ฒจ์ ์ฌ์ฉํ๋๋ก ์ค์ ํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค.
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)
๊ฐ ๋
ธ๋์์ ๊ธฐ๋ก๋ ๋ด์ฉ์ ๊ตฌ์ฑํ๊ธฐ ์ํด log_level
๊ณผ log_level_replica
๋ฅผ ๋ค์ํ ์กฐํฉ์ผ๋ก ์ฌ์ฉํด๋ณด์ธ์.
my_app.py ... --log_level warning --log_level_replica error
NEFTune
NEFTune์ ํ๋ จ ์ค ์๋ฒ ๋ฉ ๋ฒกํฐ์ ๋
ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ์ฌ ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์๋ ๊ธฐ์ ์
๋๋ค. Trainer์์ ์ด๋ฅผ ํ์ฑํํ๋ ค๋ฉด TrainingArguments์ neftune_noise_alpha
๋งค๊ฐ๋ณ์๋ฅผ ์ค์ ํ์ฌ ๋
ธ์ด์ฆ์ ์์ ์กฐ์ ํฉ๋๋ค.
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(..., neftune_noise_alpha=0.1)
trainer = Trainer(..., args=training_args)
NEFTune์ ์์์น ๋ชปํ ๋์์ ํผํ ๋ชฉ์ ์ผ๋ก ์ฒ์ ์๋ฒ ๋ฉ ๋ ์ด์ด๋ก ๋ณต์ํ๊ธฐ ์ํด ํ๋ จ ํ ๋นํ์ฑํ ๋ฉ๋๋ค.
GaLore
Gradient Low-Rank Projection (GaLore)์ ์ ์ฒด ๋งค๊ฐ๋ณ์๋ฅผ ํ์ตํ๋ฉด์๋ LoRA์ ๊ฐ์ ์ผ๋ฐ์ ์ธ ์ ๊ณ์ ์ ์ ๋ฐฉ๋ฒ๋ณด๋ค ๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ธ ์ ๊ณ์ ํ์ต ์ ๋ต์ ๋๋ค.
๋จผ์ GaLore ๊ณต์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์ค์นํฉ๋๋ค:
pip install galore-torch
๊ทธ๋ฐ ๋ค์ optim
์ ["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]
์ค ํ๋์ ํจ๊ป optim_target_modules
๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ด๋ ์ ์ฉํ๋ ค๋ ๋์ ๋ชจ๋ ์ด๋ฆ์ ํด๋นํ๋ ๋ฌธ์์ด, ์ ๊ท ํํ์ ๋๋ ์ ์ฒด ๊ฒฝ๋ก์ ๋ชฉ๋ก์ผ ์ ์์ต๋๋ค. ์๋๋ end-to-end ์์ ์คํฌ๋ฆฝํธ์
๋๋ค(ํ์ํ ๊ฒฝ์ฐ pip install trl datasets
๋ฅผ ์คํ):
import datasets
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
output_dir="./test-galore",
max_steps=100,
optim="galore_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
)
trainer.train()
GaLore๊ฐ ์ง์ํ๋ ์ถ๊ฐ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฌํ๋ ค๋ฉด optim_args
๋ฅผ ์ค์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด:
import datasets
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
output_dir="./test-galore",
max_steps=100,
optim="galore_adamw",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
)
trainer.train()
ํด๋น ๋ฐฉ๋ฒ์ ๋ํ ์์ธํ ๋ด์ฉ์ ์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ ๋๋ ๋ ผ๋ฌธ์ ์ฐธ๊ณ ํ์ธ์.
ํ์ฌ GaLore ๋ ์ด์ด๋ก ๊ฐ์ฃผ๋๋ Linear ๋ ์ด์ด๋ง ํ๋ จ ํ ์ ์์ผ๋ฉฐ, ์ ๊ณ์ ๋ถํด๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ๋๊ณ ๋๋จธ์ง ๋ ์ด์ด๋ ๊ธฐ์กด ๋ฐฉ์์ผ๋ก ์ต์ ํ๋ฉ๋๋ค.
ํ๋ จ ์์ ์ ์ ์๊ฐ์ด ์ฝ๊ฐ ๊ฑธ๋ฆด ์ ์์ต๋๋ค(NVIDIA A100์์ 2B ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ฝ 3๋ถ), ํ์ง๋ง ์ดํ ํ๋ จ์ ์ํํ๊ฒ ์งํ๋ฉ๋๋ค.
๋ค์๊ณผ ๊ฐ์ด ์ตํฐ๋ง์ด์ ์ด๋ฆ์ layerwise
๋ฅผ ์ถ๊ฐํ์ฌ ๋ ์ด์ด๋ณ ์ต์ ํ๋ฅผ ์ํํ ์๋ ์์ต๋๋ค:
import datasets
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
output_dir="./test-galore",
max_steps=100,
optim="galore_adamw_layerwise",
optim_target_modules=[r".*.attn.*", r".*.mlp.*"],
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
)
trainer.train()
๋ ์ด์ด๋ณ ์ต์ ํ๋ ๋ค์ ์คํ์ ์ด๋ฉฐ DDP(๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ)๋ฅผ ์ง์ํ์ง ์์ผ๋ฏ๋ก, ๋จ์ผ GPU์์๋ง ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ์ด ๋ฌธ์๋ฅผ์ ์ฐธ์กฐํ์ธ์. gradient clipping, DeepSpeed ๋ฑ ๋ค๋ฅธ ๊ธฐ๋ฅ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์๋์ง ์์ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ฉด GitHub์ ์ด์๋ฅผ ์ฌ๋ ค์ฃผ์ธ์.
LOMO ์ตํฐ๋ง์ด์
LOMO ์ตํฐ๋ง์ด์ ๋ ์ ํ๋ ์์์ผ๋ก ๋ํ ์ธ์ด ๋ชจ๋ธ์ ์ ์ฒด ๋งค๊ฐ๋ณ์ ๋ฏธ์ธ ์กฐ์ ๊ณผ ์ ์ํ ํ์ต๋ฅ ์ ํตํ ์ ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ(AdaLomo)์์ ๋์
๋์์ต๋๋ค.
์ด๋ค์ ๋ชจ๋ ํจ์จ์ ์ธ ์ ์ฒด ๋งค๊ฐ๋ณ์ ๋ฏธ์ธ ์กฐ์ ๋ฐฉ๋ฒ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ์ด๋ฌํ ์ตํฐ๋ง์ด์ ๋ค์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๊ธฐ ์ํด ๊ทธ๋ ์ด๋์ธํธ ๊ณ์ฐ๊ณผ ๋งค๊ฐ๋ณ์ ์
๋ฐ์ดํธ๋ฅผ ํ๋์ ๋จ๊ณ๋ก ์ตํฉํฉ๋๋ค. LOMO์์ ์ง์๋๋ ์ตํฐ๋ง์ด์ ๋ "lomo"
์ "adalomo"
์
๋๋ค. ๋จผ์ pypi์์ pip install lomo-optim
๋ฅผ ํตํด lomo
๋ฅผ ์ค์นํ๊ฑฐ๋, GitHub ์์ค์์ pip install git+https://github.com/OpenLMLab/LOMO.git
๋ก ์ค์นํ์ธ์.
์ ์์ ๋ฐ๋ฅด๋ฉด, grad_norm
์์ด AdaLomo
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ๋์ ์ฑ๋ฅ๊ณผ ๋์ ์ฒ๋ฆฌ๋์ ์ ๊ณตํ๋ค๊ณ ํฉ๋๋ค.
๋ค์์ IMDB ๋ฐ์ดํฐ์ ์์ google/gemma-2b๋ฅผ ์ต๋ ์ ๋ฐ๋๋ก ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฐ๋จํ ์คํฌ๋ฆฝํธ์ ๋๋ค:
import datasets
from trl import SFTConfig, SFTTrainer
train_dataset = datasets.load_dataset('imdb', split='train')
args = SFTConfig(
output_dir="./test-lomo",
max_steps=100,
optim="adalomo",
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model="google/gemma-2b",
args=args,
train_dataset=train_dataset,
)
trainer.train()
Accelerate์ Trainer
Trainer ํด๋์ค๋ Accelerate๋ก ๊ตฌ๋๋๋ฉฐ, ์ด๋ FullyShardedDataParallel (FSDP) ๋ฐ DeepSpeed์ ๊ฐ์ ํตํฉ์ ์ง์ํ๋ ๋ถ์ฐ ํ๊ฒฝ์์ PyTorch ๋ชจ๋ธ์ ์ฝ๊ฒ ํ๋ จํ ์ ์๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋๋ค.
FSDP ์ค๋ฉ ์ ๋ต, CPU ์คํ๋ก๋ ๋ฐ Trainer์ ํจ๊ป ์ฌ์ฉํ ์ ์๋ ๋ ๋ง์ ๊ธฐ๋ฅ์ ์์๋ณด๋ ค๋ฉด Fully Sharded Data Parallel ๊ฐ์ด๋๋ฅผ ํ์ธํ์ธ์.
Trainer์ Accelerate๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด accelerate.config
๋ช
๋ น์ ์คํํ์ฌ ํ๋ จ ํ๊ฒฝ์ ์ค์ ํ์ธ์. ์ด ๋ช
๋ น์ ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ๋ ์ฌ์ฉํ config_file.yaml
์ ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์ ์์๋ ์ค์ ํ ์ ์๋ ์ผ๋ถ ๊ตฌ์ฑ ์์
๋๋ค.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 # ๋
ธ๋์ ๋ฐ๋ผ ์์๋ฅผ ๋ณ๊ฒฝํ์ธ์
main_process_ip: 192.168.20.1
main_process_port: 9898
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_launch
๋ช
๋ น์ Accelerate์ Trainer๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ์์คํ
์์ ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ ๊ถ์ฅ ๋ฐฉ๋ฒ์ด๋ฉฐ, config_file.yaml
์ ์ง์ ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด ํ์ผ์ Accelerate ์บ์ ํด๋์ ์ ์ฅ๋๋ฉฐ accelerate_launch
๋ฅผ ์คํํ ๋ ์๋์ผ๋ก ๋ก๋๋ฉ๋๋ค.
์๋ฅผ ๋ค์ด, FSDP ๊ตฌ์ฑ์ ์ฌ์ฉํ์ฌ run_glue.py ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค:
accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
config_file.yaml
ํ์ผ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ง์ ์ง์ ํ ์๋ ์์ต๋๋ค:
accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
accelerate_launch
์ ์ฌ์ฉ์ ์ ์ ๊ตฌ์ฑ์ ๋ํด ๋ ์์๋ณด๋ ค๋ฉด Accelerate ์คํฌ๋ฆฝํธ ์คํ ํํ ๋ฆฌ์ผ์ ํ์ธํ์ธ์.