InstructPix2Pix


InstructPix2Pix is a method for fine-tuning text-conditioned diffusion models so that they can follow an edit instruction given for an input image. Models fine-tuned with this method take the following as inputs:

[Figure: InstructPix2Pix inputs]

์ถœ๋ ฅ์€ ์ž…๋ ฅ ์ด๋ฏธ์ง€์— ํŽธ์ง‘ ์ง€์‹œ๊ฐ€ ๋ฐ˜์˜๋œ โ€œ์ˆ˜์ •๋œโ€ ์ด๋ฏธ์ง€์ž…๋‹ˆ๋‹ค:

[Figure: InstructPix2Pix output]

The train_instruct_pix2pix.py script (found in the examples/instruct_pix2pix folder of the diffusers repository) shows how to implement the training procedure and adapt it for Stable Diffusion.

Disclaimer: Although train_instruct_pix2pix.py implements the InstructPix2Pix training procedure faithfully to the original implementation, it has only been tested on a small-scale dataset, which can affect the final results. For better results, we recommend training for longer on a larger dataset. A large dataset for InstructPix2Pix training can be found here.


Running locally with PyTorch

์ข…์†์„ฑ(dependencies) ์„ค์น˜ํ•˜๊ธฐ

์ด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „์—, ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ํ•™์Šต ์ข…์†์„ฑ์„ ์„ค์น˜ํ•˜์„ธ์š”:

Important

์ตœ์‹  ๋ฒ„์ „์˜ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์„ฑ๊ณต์ ์œผ๋กœ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•ด, ์›๋ณธ์œผ๋กœ๋ถ€ํ„ฐ ์„ค์น˜ํ•˜๋Š” ๊ฒƒ๊ณผ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ž์ฃผ ์—…๋ฐ์ดํŠธํ•˜๊ณ  ์˜ˆ์ œ๋ณ„ ์š”๊ตฌ์‚ฌํ•ญ์„ ์„ค์น˜ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ตœ์‹  ์ƒํƒœ๋กœ ์œ ์ง€ํ•˜๋Š” ๊ฒƒ์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด, ์ƒˆ๋กœ์šด ๊ฐ€์ƒ ํ™˜๊ฒฝ์—์„œ ๋‹ค์Œ ์Šคํ…์„ ์‹คํ–‰ํ•˜์„ธ์š”:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .

Then cd into the example folder:

cd examples/instruct_pix2pix

Now run:

pip install -r requirements.txt

And initialize an 🤗 Accelerate environment with:

accelerate config

ํ˜น์€ ํ™˜๊ฒฝ์— ๋Œ€ํ•œ ์งˆ๋ฌธ ์—†์ด ๊ธฐ๋ณธ์ ์ธ accelerate ๊ตฌ์„ฑ์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ๋‹ค์Œ์„ ์‹คํ–‰ํ•˜์„ธ์š”.

accelerate config default

ํ˜น์€ ์‚ฌ์šฉ ์ค‘์ธ ํ™˜๊ฒฝ์ด notebook๊ณผ ๊ฐ™์€ ๋Œ€ํ™”ํ˜• ์‰˜์€ ์ง€์›ํ•˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ๋Š” ๋‹ค์Œ ์ ˆ์ฐจ๋ฅผ ๋”ฐ๋ผ์ฃผ์„ธ์š”.

from accelerate.utils import write_basic_config

# Write a default 🤗 Accelerate config file without interactive prompts
write_basic_config()

Example

As mentioned earlier, we will use a small toy dataset for training. It is a smaller version of the original dataset used in the InstructPix2Pix paper. To use your own dataset, take a look at the Create a dataset for training guide.
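Each training example pairs an original image, an edit instruction, and the edited image. The minimal sketch below checks that a single record carries those three fields before you assemble your own dataset. The column names input_image, edit_prompt, and edited_image are an assumption modeled on the fusing/instructpix2pix-1000-samples layout; if your columns are named differently, check the training script's arguments for how to point it at them. validate_record is a hypothetical helper, not part of diffusers.

```python
# Hypothetical helper: check that one InstructPix2Pix-style training record
# has an original image, an edit instruction, and an edited image.
# The default column names are an assumption; adjust them to your dataset.
DEFAULT_COLUMNS = ("input_image", "edit_prompt", "edited_image")


def validate_record(record, columns=DEFAULT_COLUMNS):
    """Return a list of problems found in `record` (empty if it looks usable)."""
    original_col, prompt_col, edited_col = columns
    problems = []
    for col in columns:
        if col not in record:
            problems.append(f"missing column: {col}")
    prompt = record.get(prompt_col)
    if prompt is not None and (not isinstance(prompt, str) or not prompt.strip()):
        problems.append(f"{prompt_col} must be a non-empty string")
    # Images are treated as opaque values (e.g. PIL images or file paths);
    # only their presence is checked here.
    for col in (original_col, edited_col):
        if col in record and record[col] is None:
            problems.append(f"{col} is empty")
    return problems


# Placeholder file names, for illustration only.
example = {
    "input_image": "mountain.png",
    "edit_prompt": "make the mountains snowy",
    "edited_image": "mountain_snowy.png",
}
print(validate_record(example))  # []
print(validate_record({"input_image": "a.png", "edit_prompt": ""}))
# ['missing column: edited_image', 'edit_prompt must be a non-empty string']
```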

Set the MODEL_NAME environment variable (either a Hub model repository ID or a path to a directory containing the model weights) and pass it to the pretrained_model_name_or_path argument. You also need to set the dataset name in DATASET_ID:

export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export DATASET_ID="fusing/instructpix2pix-1000-samples"

์ง€๊ธˆ, ํ•™์Šต์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์Šคํฌ๋ฆฝํŠธ๋Š” ๋ ˆํฌ์ง€ํ† ๋ฆฌ์˜ ํ•˜์œ„ ํด๋”์˜ ๋ชจ๋“  ๊ตฌ์„ฑ์š”์†Œ(feature_extractor, scheduler, text_encoder, unet ๋“ฑ)๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET_ID \
    --enable_xformers_memory_efficient_attention \
    --resolution=256 --random_flip \
    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
    --max_train_steps=15000 \
    --checkpointing_steps=5000 --checkpoints_total_limit=1 \
    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
    --conditioning_dropout_prob=0.05 \
    --mixed_precision=fp16 \
    --seed=42 \
    --push_to_hub
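The --conditioning_dropout_prob=0.05 flag relates to classifier-free guidance: during training, the text condition, the image condition, or both are occasionally replaced by a null condition, so the model also learns to predict without them. The InstructPix2Pix paper describes dropping only the image condition, only the text condition, and both, for 5% of training examples each. The minimal sketch below reproduces that scheme from one uniform draw per example, with p standing in for --conditioning_dropout_prob. It illustrates the idea and is not copied from train_instruct_pix2pix.py; sample_dropout_masks is a hypothetical name.

```python
import random


def sample_dropout_masks(batch_size, p, rng=random):
    """For each example, decide whether to drop the text and/or image condition.

    A single uniform draw u per example is split into bands:
      [0, p)    -> drop the text condition only
      [p, 2p)   -> drop both conditions
      [2p, 3p)  -> drop the image condition only
      [3p, 1)   -> keep both
    so each condition is dropped with probability 2p and both with probability p.
    """
    if not 0 <= p <= 1 / 3:
        raise ValueError("p must be in [0, 1/3] for the three bands to fit")
    masks = []
    for _ in range(batch_size):
        u = rng.random()
        drop_text = u < 2 * p
        drop_image = p <= u < 3 * p
        masks.append((drop_text, drop_image))
    return masks


rng = random.Random(42)
masks = sample_dropout_masks(100_000, p=0.05, rng=rng)
text_only = sum(t and not i for t, i in masks) / len(masks)
image_only = sum(i and not t for t, i in masks) / len(masks)
both = sum(t and i for t, i in masks) / len(masks)
print(f"text only: {text_only:.3f}, image only: {image_only:.3f}, both: {both:.3f}")
```

Each of the three fractions should come out close to 0.05, matching the 5% / 5% / 5% split described in the paper.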

์ถ”๊ฐ€์ ์œผ๋กœ, ๊ฐ€์ค‘์น˜์™€ ๋ฐ”์ด์–ด์Šค๋ฅผ ํ•™์Šต ๊ณผ์ •์— ๋ชจ๋‹ˆํ„ฐ๋งํ•˜์—ฌ ๊ฒ€์ฆ ์ถ”๋ก ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค. report_to="wandb"์™€ ์ด ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET_ID \
    --enable_xformers_memory_efficient_attention \
    --resolution=256 --random_flip \
    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
    --max_train_steps=15000 \
    --checkpointing_steps=5000 --checkpoints_total_limit=1 \
    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
    --conditioning_dropout_prob=0.05 \
    --mixed_precision=fp16 \
    --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
    --validation_prompt="make the mountains snowy" \
    --seed=42 \
    --report_to=wandb \
    --push_to_hub

๋ชจ๋ธ ๋””๋ฒ„๊น…์— ์œ ์šฉํ•œ ์ด ํ‰๊ฐ€ ๋ฐฉ๋ฒ• ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด wandb๋ฅผ ์„ค์น˜ํ•˜๋Š” ๊ฒƒ์„ ์ฃผ๋ชฉํ•ด์ฃผ์„ธ์š”. pip install wandb๋กœ ์‹คํ–‰ํ•ด wandb๋ฅผ ์„ค์น˜ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

An example training run, including some validation samples and the training parameters, can be found here.

์ฐธ๊ณ : ์›๋ณธ ๋…ผ๋ฌธ์—์„œ, ์ €์ž๋“ค์€ 256x256 ์ด๋ฏธ์ง€ ํ•ด์ƒ๋„๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ๋กœ 512x512์™€ ๊ฐ™์€ ๋” ํฐ ํ•ด์ƒ๋„๋กœ ์ž˜ ์ผ๋ฐ˜ํ™”๋˜๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ํ•™์Šต์— ์‚ฌ์šฉํ•œ ํฐ ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ–ˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

๋‹ค์ˆ˜์˜ GPU๋กœ ํ•™์Šตํ•˜๊ธฐ

accelerate enables seamless multi-GPU training. Follow the instructions here to run distributed training with accelerate. Here is an example command:

accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
 --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
 --dataset_name=sayakpaul/instructpix2pix-1000-samples \
 --use_ema \
 --enable_xformers_memory_efficient_attention \
 --resolution=512 --random_flip \
 --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
 --max_train_steps=15000 \
 --checkpointing_steps=5000 --checkpoints_total_limit=1 \
 --learning_rate=5e-05 --lr_warmup_steps=0 \
 --conditioning_dropout_prob=0.05 \
 --mixed_precision=fp16 \
 --seed=42 \
 --push_to_hub
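The batch-related flags in these commands combine multiplicatively. The sketch below works out the effective batch size (per-device batch size × gradient accumulation steps × number of processes) and which checkpoint directories remain when checkpointing every 5000 steps with --checkpoints_total_limit=1. The 2-GPU setup is an assumption for illustration, as is the checkpoint-{step} directory naming and the rule that the oldest checkpoints are deleted once the limit is exceeded; effective_batch_size and surviving_checkpoints are hypothetical helpers.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes):
    """Samples that contribute to each optimizer step across all processes."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def surviving_checkpoints(max_train_steps, checkpointing_steps, checkpoints_total_limit=None):
    """Simulate which checkpoint directories remain at the end of training.

    Assumes a checkpoint is written every `checkpointing_steps` optimizer steps and
    that the oldest ones are deleted so at most `checkpoints_total_limit` remain.
    """
    kept = []
    for step in range(checkpointing_steps, max_train_steps + 1, checkpointing_steps):
        kept.append(f"checkpoint-{step}")
        if checkpoints_total_limit is not None and len(kept) > checkpoints_total_limit:
            kept = kept[-checkpoints_total_limit:]
    return kept


# Values from the commands above; num_processes=2 assumes a 2-GPU machine.
print(effective_batch_size(4, 4, num_processes=1))  # 16
print(effective_batch_size(4, 4, num_processes=2))  # 32
print(surviving_checkpoints(15000, 5000, checkpoints_total_limit=1))  # ['checkpoint-15000']
print(surviving_checkpoints(15000, 5000))  # ['checkpoint-5000', 'checkpoint-10000', 'checkpoint-15000']
```

Assuming max_train_steps counts optimizer steps, 15000 steps at an effective batch size of 16 on one GPU correspond to 240,000 training samples, or about 240 passes over a 1000-sample dataset.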

Inference

์ผ๋‹จ ํ•™์Šต์ด ์™„๋ฃŒ๋˜๋ฉด, ์ถ”๋ก  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

import requests
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "your_model_id"  # <- change this to your trained model's repository ID or local path
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
generator = torch.Generator("cuda").manual_seed(0)

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"


def download_image(url):
    # Fetch the image, apply its EXIF orientation, and convert it to RGB
    response = requests.get(url, stream=True)
    response.raise_for_status()
    image = Image.open(response.raw)
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


image = download_image(url)
prompt = "wipe out the lake"
num_inference_steps = 20
image_guidance_scale = 1.5
guidance_scale = 10

edited_image = pipe(
    prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    image_guidance_scale=image_guidance_scale,
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]
edited_image.save("edited_image.png")

ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‚ฌ์šฉํ•ด ์–ป์€ ์˜ˆ์‹œ์˜ ๋ชจ๋ธ ๋ ˆํฌ์ง€ํ† ๋ฆฌ๋Š” ์—ฌ๊ธฐ sayakpaul/instruct-pix2pix์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์„ฑ๋Šฅ์„ ์œ„ํ•œ ์†๋„์™€ ํ’ˆ์งˆ์„ ์ œ์–ดํ•˜๊ธฐ ์œ„ํ•ด ์„ธ ๊ฐ€์ง€ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค:

  • num_inference_steps
  • image_guidance_scale
  • guidance_scale

In particular, image_guidance_scale and guidance_scale can have a large impact on the generated ("edited") image (see the example here).
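To see why the two scales interact, it helps to look at how three noise predictions are combined at each denoising step. The sketch below follows the two-scale guidance rule described in the InstructPix2Pix paper: starting from the fully unconditional prediction, image_guidance_scale pushes toward the prediction conditioned on the input image, and guidance_scale pushes further toward the prediction conditioned on both the image and the instruction. Plain Python lists stand in for the noise tensors; combine_guidance is a hypothetical name, and this is an illustration rather than the pipeline's source code.

```python
def combine_guidance(eps_uncond, eps_image, eps_full, image_guidance_scale, guidance_scale):
    """Combine three noise predictions elementwise.

    eps_uncond: prediction with neither the input image nor the instruction
    eps_image:  prediction conditioned on the input image only
    eps_full:   prediction conditioned on both the input image and the instruction
    """
    return [
        u + image_guidance_scale * (i - u) + guidance_scale * (f - i)
        for u, i, f in zip(eps_uncond, eps_image, eps_full)
    ]


# Toy two-element "noise predictions", for illustration only.
eps_uncond = [0.0, 1.0]
eps_image = [0.5, 1.0]
eps_full = [1.0, 2.0]

print(combine_guidance(eps_uncond, eps_image, eps_full, 1.0, 1.0))  # [1.0, 2.0]
print(combine_guidance(eps_uncond, eps_image, eps_full, 1.5, 10))  # [5.75, 11.0]
```

With both scales set to 1 the result reduces to the fully conditioned prediction. Raising image_guidance_scale keeps the output closer to the input image, while raising guidance_scale follows the instruction more strongly.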

If you are looking for interesting ways to use the InstructPix2Pix training method, check out the blog post Instruction-tuning Stable Diffusion with InstructPix2Pix.
