Transformers documentation
bitsandbytes
bitsandbytes
bitsandbytes๋ ๋ชจ๋ธ์ 8๋นํธ ๋ฐ 4๋นํธ๋ก ์์ํํ๋ ๊ฐ์ฅ ์ฌ์ด ๋ฐฉ๋ฒ์ ๋๋ค. 8๋นํธ ์์ํ๋ fp16์ ์ด์์น์ int8์ ๋น์ด์์น๋ฅผ ๊ณฑํ ํ, ๋น์ด์์น ๊ฐ์ fp16์ผ๋ก ๋ค์ ๋ณํํ๊ณ , ์ด๋ค์ ํฉ์ฐํ์ฌ fp16์ผ๋ก ๊ฐ์ค์น๋ฅผ ๋ฐํํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ์ด์์น ๊ฐ์ด ๋ชจ๋ธ ์ฑ๋ฅ์ ๋ฏธ์น๋ ์ ํ ํจ๊ณผ๋ฅผ ์ค์ผ ์ ์์ต๋๋ค. 4๋นํธ ์์ํ๋ ๋ชจ๋ธ์ ๋์ฑ ์์ถํ๋ฉฐ, QLoRA์ ํจ๊ป ์ฌ์ฉํ์ฌ ์์ํ๋ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ๋ ๋ฐ ํํ ์ฌ์ฉ๋ฉ๋๋ค.
bitsandbytes๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด ๋ค์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์์ด์ผ ํฉ๋๋ค:
pip install transformers accelerate bitsandbytes>0.37.0
์ด์ BitsAndBytesConfig
๋ฅผ from_pretrained() ๋ฉ์๋์ ์ ๋ฌํ์ฌ ๋ชจ๋ธ์ ์์ํํ ์ ์์ต๋๋ค. ์ด๋ Accelerate ๊ฐ์ ธ์ค๊ธฐ๋ฅผ ์ง์ํ๊ณ torch.nn.Linear
๋ ์ด์ด๊ฐ ํฌํจ๋ ๋ชจ๋ ๋ชจ๋ธ์์ ์๋ํฉ๋๋ค.
๋ชจ๋ธ์ 8๋นํธ๋ก ์์ํํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ ๋ฐ์ผ๋ก ์ค์ด๋ค๋ฉฐ, ๋๊ท๋ชจ ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ฌ์ฉ ๊ฐ๋ฅํ GPU๋ฅผ ํจ์จ์ ์ผ๋ก ํ์ฉํ๋ ค๋ฉด device_map="auto"
๋ฅผ ์ค์ ํ์ธ์.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-1b7",
quantization_config=quantization_config
)
๊ธฐ๋ณธ์ ์ผ๋ก torch.nn.LayerNorm
๊ณผ ๊ฐ์ ๋ค๋ฅธ ๋ชจ๋์ torch.float16
์ผ๋ก ๋ณํ๋ฉ๋๋ค. ์ํ๋ค๋ฉด dtype
๋งค๊ฐ๋ณ์๋ก ์ด๋ค ๋ชจ๋์ ๋ฐ์ดํฐ ์ ํ์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
"facebook/opt-350m",
quantization_config=quantization_config,
dtype=torch.float32
)
model_8bit.model.decoder.layers[-1].final_layer_norm.weight.dtype
๋ชจ๋ธ์ด 8๋นํธ๋ก ์์ํ๋๋ฉด ์ต์ ๋ฒ์ ์ Transformers์ bitsandbytes๋ฅผ ์ฌ์ฉํ์ง ์๋ ํ ์์ํ๋ ๊ฐ์ค์น๋ฅผ Hub์ ํธ์ํ ์ ์์ต๋๋ค. ์ต์ ๋ฒ์ ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, push_to_hub() ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ 8๋นํธ ๋ชจ๋ธ์ Hub์ ํธ์ํ ์ ์์ต๋๋ค. ์์ํ config.json ํ์ผ์ด ๋จผ์ ํธ์๋๊ณ , ๊ทธ ๋ค์ ์์ํ๋ ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ํธ์๋ฉ๋๋ค.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-560m",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model.push_to_hub("bloom-560m-8bit")
8๋นํธ ๋ฐ 4๋นํธ ๊ฐ์ค์น๋ก ํ๋ จํ๋ ๊ฒ์ ์ถ๊ฐ ๋งค๊ฐ๋ณ์์ ๋ํด์๋ง ์ง์๋ฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ํ์ธํ๋ ค๋ฉด get_memory_footprint
๋ฅผ ์ฌ์ฉํ์ธ์:
print(model.get_memory_footprint())
์์ํ๋ ๋ชจ๋ธ์ from_pretrained() ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ load_in_8bit
๋๋ load_in_4bit
๋งค๊ฐ๋ณ์๋ฅผ ์ง์ ํ์ง ์๊ณ ๋ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit", device_map="auto")
8๋นํธ (LLM.int8() ์๊ณ ๋ฆฌ์ฆ)
8๋นํธ ์์ํ์ ๋ํ ์์ธํ ๋ด์ฉ์ ์๊ณ ์ถ๋ค๋ฉด ์ด ๋ธ๋ก๊ทธ ํฌ์คํธ๋ฅผ ์ฐธ์กฐํ์ธ์!
์ด ์น์ ์์๋ ์คํ๋ก๋ฉ, ์ด์์น ์๊ณ๊ฐ, ๋ชจ๋ ๋ณํ ๊ฑด๋๋ฐ๊ธฐ ๋ฐ ๋ฏธ์ธ ์กฐ์ ๊ณผ ๊ฐ์ 8๋นํธ ๋ชจ๋ธ์ ํน์ ๊ธฐ๋ฅ์ ์ดํด๋ด ๋๋ค.
์คํ๋ก๋ฉ
8๋นํธ ๋ชจ๋ธ์ CPU์ GPU ๊ฐ์ ๊ฐ์ค์น๋ฅผ ์คํ๋ก๋ํ์ฌ ๋งค์ฐ ํฐ ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ์ ์ฅ์ฐฉํ ์ ์์ต๋๋ค. CPU๋ก ์ ์ก๋ ๊ฐ์ค์น๋ ์ค์ ๋ก float32๋ก ์ ์ฅ๋๋ฉฐ 8๋นํธ๋ก ๋ณํ๋์ง ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, bigscience/bloom-1b7 ๋ชจ๋ธ์ ์คํ๋ก๋๋ฅผ ํ์ฑํํ๋ ค๋ฉด BitsAndBytesConfig๋ฅผ ์์ฑํ๋ ๊ฒ๋ถํฐ ์์ํ์ธ์:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
CPU์ ์ ๋ฌํ lm_head
๋ฅผ ์ ์ธํ ๋ชจ๋ ๊ฒ์ GPU์ ์ ์ฌํ ์ ์๋๋ก ์ฌ์ฉ์ ์ ์ ๋๋ฐ์ด์ค ๋งต์ ์ค๊ณํฉ๋๋ค:
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": 0,
"lm_head": "cpu",
"transformer.h": 0,
"transformer.ln_f": 0,
}
์ด์ ์ฌ์ฉ์ ์ ์ device_map
๊ณผ quantization_config
์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๊ฐ์ ธ์ต๋๋ค:
model_8bit = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-1b7",
device_map=device_map,
quantization_config=quantization_config,
)
์ด์์น ์๊ณ๊ฐ
โ์ด์์นโ๋ ํน์ ์๊ณ๊ฐ์ ์ด๊ณผํ๋ ์๋ ์ํ ๊ฐ์ ์๋ฏธํ๋ฉฐ, ์ด๋ฌํ ๊ฐ์ fp16์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ๊ฐ์ ์ผ๋ฐ์ ์ผ๋ก ์ ๊ท ๋ถํฌ ([-3.5, 3.5])๋ฅผ ๋ฐ๋ฅด์ง๋ง, ๋๊ท๋ชจ ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ด ๋ถํฌ๋ ๋งค์ฐ ๋ค๋ฅผ ์ ์์ต๋๋ค ([-60, 6] ๋๋ [6, 60]). 8๋นํธ ์์ํ๋ ~5 ์ ๋์ ๊ฐ์์ ์ ์๋ํ์ง๋ง, ๊ทธ ์ด์์์๋ ์๋นํ ์ฑ๋ฅ ์ ํ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ข์ ๊ธฐ๋ณธ ์๊ณ๊ฐ ๊ฐ์ 6์ด์ง๋ง, ๋ ๋ถ์์ ํ ๋ชจ๋ธ (์ํ ๋ชจ๋ธ ๋๋ ๋ฏธ์ธ ์กฐ์ )์๋ ๋ ๋ฎ์ ์๊ณ๊ฐ์ด ํ์ํ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ ๊ฐ์ฅ ์ ํฉํ ์๊ณ๊ฐ์ ์ฐพ์ผ๋ ค๋ฉด BitsAndBytesConfig์์ llm_int8_threshold
๋งค๊ฐ๋ณ์๋ฅผ ์คํํด๋ณด๋ ๊ฒ์ด ์ข์ต๋๋ค:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
model_id = "bigscience/bloom-1b7"
quantization_config = BitsAndBytesConfig(
llm_int8_threshold=10,
)
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
quantization_config=quantization_config,
)
๋ชจ๋ ๋ณํ ๊ฑด๋๋ฐ๊ธฐ
Jukebox์ ๊ฐ์ ์ผ๋ถ ๋ชจ๋ธ์ ๋ชจ๋ ๋ชจ๋์ 8๋นํธ๋ก ์์ํํ ํ์๊ฐ ์์ผ๋ฉฐ, ์ด๋ ์ค์ ๋ก ๋ถ์์ ์ฑ์ ์ ๋ฐํ ์ ์์ต๋๋ค. Jukebox์ ๊ฒฝ์ฐ, BitsAndBytesConfig์ llm_int8_skip_modules
๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ lm_head
๋ชจ๋์ ๊ฑด๋๋ฐ์ด์ผ ํฉ๋๋ค:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_id = "bigscience/bloom-1b7"
quantization_config = BitsAndBytesConfig(
llm_int8_skip_modules=["lm_head"],
)
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
quantization_config=quantization_config,
)
๋ฏธ์ธ ์กฐ์
PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด flan-t5-large ๋ฐ facebook/opt-6.7b์ ๊ฐ์ ๋๊ท๋ชจ ๋ชจ๋ธ์ 8๋นํธ ์์ํ๋ก ๋ฏธ์ธ ์กฐ์ ํ ์ ์์ต๋๋ค. ํ๋ จ ์ device_map
๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฌํ ํ์๊ฐ ์์ผ๋ฉฐ, ๋ชจ๋ธ์ ์๋์ผ๋ก GPU์ ๊ฐ์ ธ์ต๋๋ค. ๊ทธ๋ฌ๋ ์ํ๋ ๊ฒฝ์ฐ device_map
๋งค๊ฐ๋ณ์๋ก ์ฅ์น ๋งต์ ์ฌ์ฉ์ ์ ์ํ ์ ์์ต๋๋ค (device_map="auto"
๋ ์ถ๋ก ์๋ง ์ฌ์ฉํด์ผ ํฉ๋๋ค).
4๋นํธ (QLoRA ์๊ณ ๋ฆฌ์ฆ)
์ด ๋ ธํธ๋ถ์์ 4๋นํธ ์์ํ๋ฅผ ์๋ํด๋ณด๊ณ ์์ธํ ๋ด์ฉ์ ์ด ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ์์ ํ์ธํ์ธ์.
์ด ์น์ ์์๋ ๊ณ์ฐ ๋ฐ์ดํฐ ์ ํ ๋ณ๊ฒฝ, Normal Float 4 (NF4) ๋ฐ์ดํฐ ์ ํ ์ฌ์ฉ, ์ค์ฒฉ ์์ํ ์ฌ์ฉ๊ณผ ๊ฐ์ 4๋นํธ ๋ชจ๋ธ์ ํน์ ๊ธฐ๋ฅ ์ผ๋ถ๋ฅผ ํ๊ตฌํฉ๋๋ค.
๋ฐ์ดํฐ ์ ํ ๊ณ์ฐ
๊ณ์ฐ ์๋๋ฅผ ๋์ด๊ธฐ ์ํด BitsAndBytesConfig์์ bnb_4bit_compute_dtype
๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ ํ์ float32(๊ธฐ๋ณธ๊ฐ)์์ bf16์ผ๋ก ๋ณ๊ฒฝํ ์ ์์ต๋๋ค:
import torch
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
Normal Float 4 (NF4)
NF4๋ QLoRA ๋
ผ๋ฌธ์์ ์๊ฐ๋ 4๋นํธ ๋ฐ์ดํฐ ์ ํ์ผ๋ก, ์ ๊ท ๋ถํฌ์์ ์ด๊ธฐํ๋ ๊ฐ์ค์น์ ์ ํฉํฉ๋๋ค. 4๋นํธ ๊ธฐ๋ฐ ๋ชจ๋ธ์ ํ๋ จํ ๋ NF4๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ์ด๋ BitsAndBytesConfig์์ bnb_4bit_quant_type
๋งค๊ฐ๋ณ์๋ก ์ค์ ํ ์ ์์ต๋๋ค:
from transformers import BitsAndBytesConfig
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
์ถ๋ก ์ ๊ฒฝ์ฐ, bnb_4bit_quant_type
์ ์ฑ๋ฅ์ ํฐ ์ํฅ์ ๋ฏธ์น์ง ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ชจ๋ธ ๊ฐ์ค์น์ ์ผ๊ด์ฑ์ ์ ์งํ๊ธฐ ์ํด bnb_4bit_compute_dtype
๋ฐ dtype
๊ฐ์ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์ค์ฒฉ ์์ํ
์ค์ฒฉ ์์ํ๋ ์ถ๊ฐ์ ์ธ ์ฑ๋ฅ ์์ค ์์ด ์ถ๊ฐ์ ์ธ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ ์ ์๋ ๊ธฐ์ ์ ๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ด๋ฏธ ์์ํ๋ ๊ฐ์ค์น์ 2์ฐจ ์์ํ๋ฅผ ์ํํ์ฌ ๋งค๊ฐ๋ณ์๋น ์ถ๊ฐ๋ก 0.4๋นํธ๋ฅผ ์ ์ฝํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์ค์ฒฉ ์์ํ๋ฅผ ํตํด 16GB NVIDIA T4 GPU์์ ์ํ์ค ๊ธธ์ด 1024, ๋ฐฐ์น ํฌ๊ธฐ 1, ๊ทธ๋ ์ด๋์ธํธ ๋์ 4๋จ๊ณ๋ฅผ ์ฌ์ฉํ์ฌ Llama-13b ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ ํ ์ ์์ต๋๋ค.
from transformers import BitsAndBytesConfig
double_quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
bitsandbytes ๋ชจ๋ธ์ ๋น์์ํ
์์ํ๋ ํ์๋ ๋ชจ๋ธ์ ์๋์ ์ ๋ฐ๋๋ก ๋น์์ํํ ์ ์์ง๋ง, ์ด๋ ๋ชจ๋ธ์ ํ์ง์ด ์ฝ๊ฐ ์ ํ๋ ์ ์์ต๋๋ค. ๋น์์ํ๋ ๋ชจ๋ธ์ ๋ง์ถ ์ ์๋ ์ถฉ๋ถํ GPU RAM์ด ์๋์ง ํ์ธํ์ธ์.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id, BitsAndBytesConfig(load_in_4bit=True))
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.dequantize()
text = tokenizer("Hello my name is", return_tensors="pt").to(0)
out = model.generate(**text)
print(tokenizer.decode(out[0]))