Working with LLMs#
The ray.data.llm module integrates with key large language model (LLM) inference engines and deployed models to enable LLM batch inference.
This guide shows you how to use ray.data.llm to:
Perform batch inference with LLMs
Configure vLLM for LLM inference
Run batch inference with vision language models (VLMs)
Run batch inference with embedding models
Run batch inference with an OpenAI-compatible endpoint
Perform batch inference with LLMs#
At a high level, the ray.data.llm module provides a Processor object which encapsulates logic for performing batch inference with LLMs on a Ray Data dataset. You can use the build_llm_processor API to construct a processor.
The following example uses the vLLMEngineProcessorConfig to construct a processor for the unsloth/Llama-3.1-8B-Instruct model.
To start, install Ray Data + LLMs. This also installs vLLM, which is a popular and optimized LLM inference engine.
pip install -U "ray[data, llm]>=2.49.1"
The vLLMEngineProcessorConfig is a configuration object for the vLLM engine. It contains the model name, the number of GPUs to use, and the number of shards to use, along with other vLLM engine configurations. Upon execution, the Processor object instantiates replicas of the vLLM engine (using map_batches under the hood).
Here’s a simple configuration example:
# Basic vLLM configuration
from ray.data.llm import vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,  # Reduce if CUDA OOM occurs
        "max_model_len": 16384,
    },
    concurrency=1,
    batch_size=64,
)
Key parameters in this configuration:
`concurrency`: Number of vLLM engine replicas (typically 1 per node)
`batch_size`: Number of samples processed per batch (reduce if GPU memory is limited)
`max_num_batched_tokens`: Maximum tokens processed simultaneously (reduce if CUDA OOM occurs)
`accelerator_type`: Specify GPU type for optimal resource allocation
Each processor requires specific input columns based on the model and configuration. The vLLM processor expects input in OpenAI chat format with a 'messages' column. This basic configuration pattern is used throughout this guide.
This configuration creates a processor that expects:
Input: Dataset with ‘messages’ column (OpenAI chat format)
Output: Dataset with ‘generated_text’ column containing model responses
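To make this concrete, here's a minimal end-to-end sketch that builds a processor from the configuration above and runs it on a small in-memory dataset. The haiku prompt and the sampling parameters are illustrative; ray.data.from_items over a list of strings exposes each string under the "item" column.
import ray
from ray.data.llm import build_llm_processor

# Build a processor from the config above. The preprocess function maps each
# input row to the 'messages' column (plus optional sampling parameters), and
# the postprocess function picks fields from the model output.
processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": "You are a bot that responds with haikus."},
            {"role": "user", "content": row["item"]},
        ],
        sampling_params=dict(temperature=0.3, max_tokens=250),
    ),
    postprocess=lambda row: dict(answer=row["generated_text"]),
)

ds = ray.data.from_items(["Write a haiku about distributed computing."])
ds = processor(ds)
ds.show(limit=1)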
Some models may require a Hugging Face token. You can specify the token in the runtime_env argument.
# Configuration with Hugging Face token
config_with_token = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
    concurrency=1,
    batch_size=64,
)
Configure vLLM for LLM inference#
Use the vLLMEngineProcessorConfig to configure the vLLM engine. For handling larger models, specify model parallelism:
# Model parallelism configuration for larger models
# tensor_parallel_size=2: Split model across 2 GPUs for tensor parallelism
# pipeline_parallel_size=2: Use 2 pipeline stages (total 4 GPUs needed)
# Total GPUs required = tensor_parallel_size * pipeline_parallel_size = 4
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "max_model_len": 16384,
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 2,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 2048,
    },
    concurrency=1,
    batch_size=32,
    accelerator_type="L4",
)
The underlying Processor object instantiates replicas of the vLLM engine and automatically configures parallel workers to handle model parallelism (for tensor parallelism and pipeline parallelism, if specified).
To optimize model loading, you can configure the load_format to runai_streamer or tensorizer.
Note
In this case, install vLLM with runai dependencies: pip install -U "vllm[runai]>=0.10.1"
# RunAI streamer configuration for optimized model loading
# Note: Install vLLM with runai dependencies: pip install -U "vllm[runai]>=0.10.1"
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "load_format": "runai_streamer",
        "max_model_len": 16384,
    },
    concurrency=1,
    batch_size=64,
)
If your model is hosted on AWS S3, you can specify the S3 path in the model_source argument, and specify load_format="runai_streamer" in the engine_kwargs argument.
# S3 hosted model configuration
s3_config = vLLMEngineProcessorConfig(
    model_source="s3://your-bucket/your-model-path/",
    engine_kwargs={
        "load_format": "runai_streamer",
        "max_model_len": 16384,
    },
    concurrency=1,
    batch_size=64,
)
To run multi-LoRA batch inference, you need to set LoRA-related parameters in engine_kwargs. See the vLLM with LoRA example for details.
# Multi-LoRA configuration
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_lora": True,
        "max_lora_rank": 32,
        "max_loras": 1,
        "max_model_len": 16384,
    },
    concurrency=1,
    batch_size=32,
)
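As a hedged sketch of how per-row adapter selection could look with this configuration: it assumes the preprocessor can emit a model field naming a LoRA adapter, as in the vLLM with LoRA example, and the adapter name below is a placeholder.
# Hypothetical sketch: choose a LoRA adapter per row by emitting a `model`
# field from the preprocessor. The adapter name below is a placeholder.
from ray.data.llm import build_llm_processor

lora_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        model="your-org/your-lora-adapter",  # placeholder LoRA adapter
        messages=[{"role": "user", "content": row["item"]}],
        sampling_params=dict(temperature=0.0, max_tokens=100),
    ),
    postprocess=lambda row: dict(answer=row["generated_text"]),
)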
Batch inference with vision-language-model (VLM)#
Ray Data LLM also supports running batch inference with vision language models. This example shows how to prepare a dataset with images and run batch inference with a vision language model.
This example applies 2 adjustments on top of the previous example:
Set has_image=True in vLLMEngineProcessorConfig
Prepare the image input inside the preprocessor
First, install the required dependencies:
# Install required dependencies for vision-language models
pip install "datasets>=4.0.0"
Next, load a vision dataset:
"""
Load vision dataset from Hugging Face.
This function loads the LMMs-Eval-Lite dataset which contains:
- Images with associated questions
- Multiple choice answers
- Various visual reasoning tasks
"""
try:
import datasets
# Load "LMMs-Eval-Lite" dataset from Hugging Face
vision_dataset_llms_lite = datasets.load_dataset(
"lmms-lab/LMMs-Eval-Lite", "coco2017_cap_val"
)
vision_dataset = ray.data.from_huggingface(vision_dataset_llms_lite["lite"])
return vision_dataset
except ImportError:
print(
"datasets package not available. Install with: pip install datasets>=4.0.0"
)
return None
except Exception as e:
print(f"Error loading dataset: {e}")
return None
Next, configure the VLM processor with the essential settings:
vision_processor_config = vLLMEngineProcessorConfig(
    model_source="Qwen/Qwen2.5-VL-3B-Instruct",
    engine_kwargs=dict(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        max_model_len=4096,
        enable_chunked_prefill=True,
        max_num_batched_tokens=2048,
    ),
    # Override Ray's runtime env to include the Hugging Face token.
    # Ray Data uses Ray under the hood to orchestrate the inference pipeline.
    runtime_env=dict(
        env_vars=dict(
            # HF_TOKEN=HF_TOKEN,  # Token not needed for public models
            VLLM_USE_V1="1",
        ),
    ),
    batch_size=16,
    accelerator_type="L4",
    concurrency=1,
    has_image=True,
)
For a more comprehensive VLM configuration with advanced options:
"""Create VLM configuration."""
return vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-VL-3B-Instruct",
engine_kwargs=dict(
tensor_parallel_size=1,
pipeline_parallel_size=1,
max_model_len=4096,
trust_remote_code=True,
limit_mm_per_prompt={"image": 1},
),
runtime_env={
# "env_vars": {"HF_TOKEN": "your-hf-token-here"} # Token not needed for public models
},
batch_size=1,
accelerator_type="L4",
concurrency=1,
has_image=True,
)
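The workflow below passes a vision_preprocess function to build_llm_processor. That function isn't defined in this guide, so here's a minimal sketch of what it could look like; the column names ("question", "image") and the multimodal message layout are assumptions and may need to be adapted to your dataset's schema and the chat format your model expects.
# Hypothetical preprocessor: the "question" and "image" column names and the
# multimodal message layout are assumptions for illustration only.
def vision_preprocess(row: dict) -> dict:
    return dict(
        messages=[
            {"role": "system", "content": "You answer questions about images."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": row["question"]},
                    {"type": "image", "image": row["image"]},
                ],
            },
        ],
        sampling_params=dict(temperature=0.0, max_tokens=150),
    )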
Finally, run the VLM inference:
"""Run the complete VLM example workflow."""
config = create_vlm_config()
vision_dataset = load_vision_dataset()
if vision_dataset:
# Build processor with preprocessing
processor = build_llm_processor(config, preprocess=vision_preprocess)
print("VLM processor configured successfully")
print(f"Model: {config.model_source}")
print(f"Has image support: {config.has_image}")
result = processor(vision_dataset).take_all()
return config, processor, result
return None, None, None
Batch inference with embedding models#
Ray Data LLM supports batch inference with embedding models using vLLM:
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

embedding_config = vLLMEngineProcessorConfig(
    model_source="sentence-transformers/all-MiniLM-L6-v2",
    task_type="embed",
    engine_kwargs=dict(
        enable_prefix_caching=False,
        enable_chunked_prefill=False,
        max_model_len=256,
        enforce_eager=True,
    ),
    batch_size=32,
    concurrency=1,
    apply_chat_template=False,
    detokenize=False,
)

embedding_processor = build_llm_processor(
    embedding_config,
    preprocess=lambda row: dict(prompt=row["text"]),
    postprocess=lambda row: {
        "text": row["prompt"],
        "embedding": row["embeddings"],
    },
)

texts = [
    "Hello world",
    "This is a test sentence",
    "Embedding models convert text to vectors",
]
ds = ray.data.from_items([{"text": text} for text in texts])

embedded_ds = embedding_processor(ds)
embedded_ds.show(limit=1)
{'text': 'Hello world', 'embedding': [0.1, -0.2, 0.3, ...]}
Key differences for embedding models:
Set task_type="embed"
Set apply_chat_template=False and detokenize=False
Use direct prompt input instead of messages
Access embeddings through row["embeddings"]
For a complete embedding configuration example, see:
# Embedding model configuration
embedding_config = vLLMEngineProcessorConfig(
    model_source="sentence-transformers/all-MiniLM-L6-v2",
    task_type="embed",
    engine_kwargs=dict(
        enable_prefix_caching=False,
        enable_chunked_prefill=False,
        max_model_len=256,
        enforce_eager=True,
    ),
    batch_size=32,
    concurrency=1,
    apply_chat_template=False,
    detokenize=False,
)


# Example usage for embeddings
def create_embedding_processor():
    return build_llm_processor(
        embedding_config,
        preprocess=lambda row: dict(prompt=row["text"]),
        postprocess=lambda row: {
            "text": row["prompt"],
            "embedding": row["embeddings"],
        },
    )
Batch inference with an OpenAI-compatible endpoint#
You can also make calls to deployed models that have an OpenAI-compatible API endpoint.
import os

import ray
from ray.data.llm import HttpRequestProcessorConfig, build_llm_processor

OPENAI_KEY = os.environ["OPENAI_API_KEY"]
ds = ray.data.from_items(["Hand me a haiku."])

config = HttpRequestProcessorConfig(
    url="https://api.openai.com/v1/chat/completions",
    headers={"Authorization": f"Bearer {OPENAI_KEY}"},
    qps=1,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        payload=dict(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "system",
                    "content": "You are a bot that responds with haikus.",
                },
                {"role": "user", "content": row["item"]},
            ],
            temperature=0.0,
            max_tokens=150,
        ),
    ),
    postprocess=lambda row: dict(
        response=row["http_response"]["choices"][0]["message"]["content"]
    ),
)

ds = processor(ds)
print(ds.take_all())
Usage Data Collection#
Data for the following features and attributes is collected to improve Ray Data LLM:
config name used for building the llm processor
number of concurrent users for data parallelism
batch size of requests
model architecture used for building vLLMEngineProcessor
task type used for building vLLMEngineProcessor
engine arguments used for building vLLMEngineProcessor
tensor parallel size and pipeline parallel size used
GPU type used and number of GPUs used
To opt out of usage data collection, follow the Ray usage stats documentation to turn it off.
Frequently Asked Questions (FAQs)#
How to configure LLM stage to parallelize across multiple nodes?#
At the moment, Ray Data LLM doesn't support cross-node parallelism (either tensor parallelism or pipeline parallelism).
The processing pipeline is designed to run on a single node: the number of GPUs per replica is calculated as the product of the tensor parallel size and the pipeline parallel size, and Ray Data LLM applies the [STRICT_PACK strategy](https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html#pgroup-strategy) to ensure that each replica of the LLM stage is executed on a single node.
Nevertheless, you can still horizontally scale the LLM stage to multiple nodes as long as each replica (TP * PP GPUs) fits on a single node. The number of replicas is configured by the concurrency argument in vLLMEngineProcessorConfig.
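For example, here's a sketch of scaling out with concurrency under these constraints; the model and GPU counts are illustrative.
# Each replica uses tensor_parallel_size * pipeline_parallel_size = 2 GPUs and
# is packed onto a single node; concurrency=4 launches 4 such replicas, which
# Ray can schedule across multiple nodes (8 GPUs total).
from ray.data.llm import vLLMEngineProcessorConfig

config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 1,
        "max_model_len": 16384,
    },
    concurrency=4,
    batch_size=32,
)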
GPU Memory Management and CUDA OOM Prevention#
If you encounter CUDA out of memory errors, Ray Data LLM provides several configuration options to optimize GPU memory usage:
# GPU memory management configuration
# If you encounter CUDA out of memory errors, try these optimizations:
config_memory_optimized = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "max_model_len": 8192,
        "max_num_batched_tokens": 2048,
        "enable_chunked_prefill": True,
        "gpu_memory_utilization": 0.85,
        "block_size": 16,
    },
    concurrency=1,
    batch_size=16,
)

# For very large models or limited GPU memory:
config_minimal_memory = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "max_model_len": 4096,
        "max_num_batched_tokens": 1024,
        "enable_chunked_prefill": True,
        "gpu_memory_utilization": 0.75,
    },
    concurrency=1,
    batch_size=8,
)
Key strategies for handling GPU memory issues:
Reduce batch size: Start with smaller batches (8-16) and increase gradually
Lower `max_num_batched_tokens`: Reduce from 4096 to 2048 or 1024
Decrease `max_model_len`: Use shorter context lengths when possible
Set `gpu_memory_utilization`: Use 0.75-0.85 instead of default 0.90
Use smaller models: Consider using smaller model variants for resource-constrained environments
If you run into CUDA out of memory, your batch size is likely too large. Set an explicit, smaller batch size, use a smaller model, or use a larger GPU.
How to cache model weights in remote object storage#
When deploying Ray Data LLM to large-scale clusters, model loading may be rate limited by Hugging Face. In this case, you can cache the model in remote object storage (AWS S3 or Google Cloud Storage) for more stable model loading.
Ray Data LLM provides the following utility to help upload models to remote object storage.
# Download model from HuggingFace, and upload to GCS
python -m ray.llm.utils.upload_model \
--model-source facebook/opt-350m \
--bucket-uri gs://my-bucket/path/to/facebook-opt-350m
# Or upload a local custom model to S3
python -m ray.llm.utils.upload_model \
--model-source local/path/to/model \
--bucket-uri s3://my-bucket/path/to/model_name
Later, you can use the remote object store URI as the model_source in the config.
# S3 hosted model configuration
s3_config = vLLMEngineProcessorConfig(
    model_source="s3://your-bucket/your-model-path/",
    engine_kwargs={
        "load_format": "runai_streamer",
        "max_model_len": 16384,
    },
    concurrency=1,
    batch_size=64,
)