lynx   »   [go: up one dir, main page]

Diffusers documentation

Wuerstchen

You are viewing main version, which requires installation from source. If you'd like regular pip install, checkout the latest stable version (v0.35.1).
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Wuerstchen

Wuerstchen 模型通过将潜在空间压缩 42 倍,在不影响图像质量的情况下大幅降低计算成本并加速推理。在训练过程中,Wuerstchen 使用两个模型(VQGAN + 自动编码器)来压缩潜在表示,然后第三个模型(文本条件潜在扩散模型)在这个高度压缩的空间上进行条件化以生成图像。

为了将先验模型放入 GPU 内存并加速训练,尝试分别启用 gradient_accumulation_stepsgradient_checkpointingmixed_precision

本指南探讨 train_text_to_image_prior.py 脚本,帮助您更熟悉它,以及如何根据您的用例进行适配。

在运行脚本之前,请确保从源代码安装库:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

然后导航到包含训练脚本的示例文件夹,并安装脚本所需的依赖项:

cd examples/wuerstchen/text_to_image
pip install -r requirements.txt

🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate 快速入门 以了解更多信息。

初始化一个 🤗 Accelerate 环境:

accelerate config

要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:

accelerate config default

或者,如果您的环境不支持交互式 shell,例如笔记本,您可以使用:

from accelerate.utils import write_basic_config

write_basic_config()

最后,如果您想在自己的数据集上训练模型,请查看 创建训练数据集 指南,了解如何创建与训练脚本兼容的数据集。

以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未涵盖 脚本 的详细信息。如果您有兴趣了解更多,请随时阅读脚本,并告诉我们您是否有任何问题或疑虑。

脚本参数

训练脚本提供了许多参数来帮助您自定义训练运行。所有参数及其描述都可以在 parse_args() 函数中找到。它为每个参数提供了默认值,例如训练批次大小和学习率,但如果您愿意,也可以在训练命令中设置自己的值。

例如,要使用 fp16 格式的混合精度加速训练,请在训练命令中添加 --mixed_precision 参数:

accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16"

大多数参数与 文本到图像 训练指南中的参数相同,因此让我们直接深入 Wuerstchen 训练脚本!

训练脚本

训练脚本也与 文本到图像 训练指南类似,但已修改以支持 Wuerstchen。本指南重点介绍 Wuerstchen 训练脚本中独特的代码。

main() 函数首先初始化图像编码器 - 一个 EfficientNet - 以及通常的调度器和分词器。

with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
    state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
    image_encoder = EfficientNetEncoder()
    image_encoder.load_state_dict(state_dict["effnet_state_dict"])
    image_encoder.eval()

您还将加载 WuerstchenPrior 模型以进行优化。

prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")

optimizer = optimizer_cls(
    prior.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

接下来,您将对图像应用一些 transforms 并对标题进行 tokenize

def preprocess_train(examples):
    images = [image.conver
t("RGB") for image in examples[image_column]]
    examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
    examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
    return examples

最后,训练循环处理使用EfficientNetEncoder将图像压缩到潜在空间,向潜在表示添加噪声,并使用WuerstchenPrior模型预测噪声残差。

pred_noise = prior(noisy_latents, timesteps, prompt_embeds)

如果您想了解更多关于训练循环的工作原理,请查看理解管道、模型和调度器教程,该教程分解了去噪过程的基本模式。

启动脚本

一旦您完成了所有更改或对默认配置满意,就可以启动训练脚本了!🚀

设置DATASET_NAME环境变量为Hub中的数据集名称。本指南使用Naruto BLIP captions数据集,但您也可以创建和训练自己的数据集(参见创建用于训练的数据集指南)。

要使用Weights & Biases监控训练进度,请在训练命令中添加--report_to=wandb参数。您还需要在训练命令中添加--validation_prompt以跟踪结果。这对于调试模型和查看中间结果非常有用。

export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch  train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-naruto-model"

训练完成后,您可以使用新训练的模型进行推理!

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")

caption = "A cute bird naruto holding a shield"
images = pipeline(
    caption,
    width=1024,
    height=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images

下一步

恭喜您训练了一个Wuerstchen模型!要了解更多关于如何使用您的新模型的信息,请参 以下内容可能有所帮助:

  • 查看 Wuerstchen API 文档,了解更多关于如何使用该管道进行文本到图像生成及其限制的信息。
< > Update on GitHub

Лучший частный хостинг