Ray Direct Transport (RDT)#
Ray objects are normally stored in Ray’s CPU-based object store and copied and deserialized when accessed by a Ray task or actor.
For GPU data specifically, this can lead to unnecessary and expensive data transfers.
For example, passing a CUDA torch.Tensor from one Ray task to another would require a copy from GPU to CPU memory, then back again to GPU memory.
Ray Direct Transport (RDT) is a new feature that allows Ray to store and pass objects directly between Ray actors.
This feature augments the familiar Ray ObjectRef API by:
Keeping GPU data in GPU memory until a transfer is needed
Avoiding expensive serialization and copies to and from the Ray object store
Using efficient data transports like collective communication libraries (Gloo or NCCL) or point-to-point RDMA (via NVIDIA’s NIXL) to transfer data directly between devices, including both CPU and GPUs
Note
RDT is currently in alpha. Not all Ray Core APIs are supported yet. Future releases may introduce breaking API changes. See the limitations section for more details.
Getting started#
Tip
RDT currently supports torch.Tensor objects created by Ray actor tasks. Other datatypes and Ray non-actor tasks may be supported in future releases.
This walkthrough will show how to create and use RDT with different tensor transports, i.e. the mechanism used to transfer the tensor between actors. Currently, RDT supports the following tensor transports:
Gloo: A collective communication library for PyTorch and CPUs.
NVIDIA NCCL: A collective communication library for NVIDIA GPUs.
NVIDIA NIXL (backed by UCX): A library for accelerating point-to-point transfers via RDMA, especially between various types of memory and NVIDIA GPUs.
For ease of following along, we’ll start with the Gloo transport, which can be used without any physical GPUs.
Usage with Gloo (CPUs only)#
Installation#
Note
Under construction.
Walkthrough#
To get started, define an actor class with a task that returns a torch.Tensor:
import torch
import ray

@ray.remote
class MyActor:
    def random_tensor(self):
        return torch.randn(1000, 1000)
As written, when the torch.Tensor is returned, it will be copied into Ray’s CPU-based object store.
For CPU-based tensors, this can require an expensive step to copy and serialize the object, while GPU-based tensors additionally require a copy to and from CPU memory.
To enable RDT, use the tensor_transport option in the @ray.method decorator.
@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)
This decorator can be added to any actor tasks that return a torch.Tensor, or that return torch.Tensors nested inside other Python objects.
Adding this decorator will change Ray’s behavior in the following ways:
1. When returning the tensor, Ray will store a reference to the tensor instead of copying it to CPU memory.
2. When the ray.ObjectRef is passed to another task, Ray will use Gloo to transfer the tensor to the destination task.
Note that for (2) to work, the @ray.method(tensor_transport) decorator only needs to be added to the actor task that returns the tensor. It should not be added to actor tasks that consume the tensor (unless those tasks also return tensors).
Also, for (2) to work, we must first create a collective group of actors.
Creating a collective group#
To create a collective group for use with RDT:
1. Create multiple Ray actors.
2. Create a collective group on the actors using the ray.experimental.collective.create_collective_group function. The backend specified must match the tensor_transport used in the @ray.method decorator.
Here is an example:
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()

# The tensor_transport specified here must match the one used in the @ray.method
# decorator.
group = create_collective_group([sender, receiver], backend="torch_gloo")
The actors can now communicate directly via Gloo.
The group can also be destroyed using the ray.experimental.collective.destroy_collective_group function. After calling this function, a new collective group can be created on the same actors.
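For illustration, here is a minimal sketch of tearing down a group and creating a new one on the same actors. It assumes destroy_collective_group accepts the group handle returned by create_collective_group; check the API reference for the exact signature.

from ray.experimental.collective import (
    create_collective_group,
    destroy_collective_group,
)

group = create_collective_group([sender, receiver], backend="torch_gloo")
# ... run RDT workloads that transfer tensors between the actors ...

# Assumption: destroy_collective_group takes the handle returned by
# create_collective_group.
destroy_collective_group(group)

# The same actors can now join a fresh collective group.
group = create_collective_group([sender, receiver], backend="torch_gloo")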
Passing objects to other actors#
Now that we have a collective group, we can create and pass RDT objects between the actors. Here is a full example:
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
print(ray.get(result))
When the ray.ObjectRef is passed to another task, Ray will use Gloo to transfer the tensor directly from the source actor to the destination actor instead of going through the default object store.
Note that the @ray.method(tensor_transport) decorator is only added to the actor task that returns the tensor; once this hint has been added, the receiving actor task receiver.sum will automatically use Gloo to receive the tensor.
In this example, because MyActor.sum does not have the @ray.method(tensor_transport) decorator, it will use the default Ray object store transport to return torch.sum(tensor).
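If a consuming task itself returns tensors that you also want to keep under RDT, decorate that method as well. The following is a minimal sketch under the same Gloo setup; the double method is illustrative and not part of the example above.

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    # Hypothetical consumer that also produces a tensor. Decorating it keeps
    # its return value with the actor under RDT instead of copying it into
    # Ray's object store.
    @ray.method(tensor_transport="gloo")
    def double(self, tensor: torch.Tensor):
        return tensor * 2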
RDT also supports passing tensors nested inside Python data structures, as well as actor tasks that return multiple tensors, like in this example:
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor_dict(self):
        return {"tensor1": torch.randn(1000, 1000), "tensor2": torch.randn(1000, 1000)}

    def sum(self, tensor_dict: dict):
        return torch.sum(tensor_dict["tensor1"]) + torch.sum(tensor_dict["tensor2"])

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

# Both tensor values in the dictionary will be stored by the `sender` actor
# instead of in Ray's object store.
tensor_dict = sender.random_tensor_dict.remote()
result = receiver.sum.remote(tensor_dict)
print(ray.get(result))
Passing RDT objects to the actor that produced them#
RDT ray.ObjectRefs can also be passed to the actor that produced them.
This avoids any copies and just provides a reference to the same torch.Tensor that was previously created.
For example:
import torch
import ray
# pytest is used in the ray.get example below, which continues this example.
import pytest
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
# Pass the ObjectRef back to the actor that produced it. The tensor will be
# passed back to the same actor without copying.
sum1 = sender.sum.remote(tensor)
sum2 = receiver.sum.remote(tensor)
assert torch.allclose(*ray.get([sum1, sum2]))
Note
Ray only keeps a reference to the tensor created by the user, so the tensor objects are mutable.
If sender.sum were to modify the tensor in the above example, the changes would also be seen by receiver.sum.
This differs from the normal Ray Core API, which always makes an immutable copy of data returned by actors.
ray.get#
The ray.get function can also be used as usual to retrieve the result of an RDT object. However, by default ray.get uses the same tensor transport as the one specified in the @ray.method decorator. For collective-based transports, this will not work if the caller is not part of the collective group.
Therefore, the caller needs to explicitly select the Ray object store as the tensor transport by setting _tensor_transport in ray.get.
# Wrong example of ray.get(). Since the tensor transport in the @ray.method decorator is Gloo,
# ray.get() will try to use Gloo to fetch the tensor, which is not supported
# because the caller is not part of the collective group.
with pytest.raises(ValueError) as e:
    ray.get(tensor)
assert (
    "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got TensorTransportEnum.GLOO, please specify the correct tensor transport in ray.get()"
    in str(e.value)
)

# Correct example of ray.get(), explicitly setting the tensor transport to use the Ray object store.
print(ray.get(tensor, _tensor_transport="object_store"))
# torch.Tensor(...)
Object mutability#
Unlike objects in the Ray object store, RDT objects are mutable, meaning that Ray only holds a reference to the tensor and will not copy it until a transfer is requested. This means that if the actor that returns a tensor also keeps a reference to the tensor, and the actor later modifies it in place while Ray is still storing the tensor reference, it’s possible that some or all of the changes may be seen by receiving actors.
Here is an example of what can go wrong:
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        self.tensor = torch.randn(1000, 1000)
        # After this function returns, Ray and this actor will both hold a
        # reference to the same tensor.
        return self.tensor

    def increment_and_sum_stored_tensor(self):
        # NOTE: In-place update, while Ray still holds a reference to the same tensor.
        self.tensor += 1
        return torch.sum(self.tensor)

    def increment_and_sum(self, tensor: torch.Tensor):
        return torch.sum(tensor + 1)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum_stored_tensor.remote()
# Wait for sender.increment_and_sum_stored_tensor task to finish.
tensor1 = ray.get(tensor1)

# Receiver will now receive the updated value instead of the original.
tensor2 = receiver.increment_and_sum.remote(tensor)

try:
    # This assertion will fail because sender.increment_and_sum_stored_tensor
    # modified the tensor in place before sending it to
    # receiver.increment_and_sum.
    assert torch.allclose(tensor1, ray.get(tensor2))
except AssertionError:
    print("AssertionError: sender and receiver returned different sums.")
In this example, the sender actor returns a tensor to Ray, but it also keeps a reference to the tensor in its local state.
Then, in sender.increment_and_sum_stored_tensor, the sender actor modifies the tensor in place while Ray is still holding the tensor reference.
As a result, the receiver.increment_and_sum task receives the modified tensor instead of the original, so the assertion fails.
To fix this kind of error, use the ray.experimental.wait_tensor_freed function to wait for Ray to release all references to the tensor, so that the actor can safely write to the tensor again.
wait_tensor_freed will unblock once all tasks that depend on the tensor have finished executing and all corresponding ObjectRefs have gone out of scope.
Ray tracks tasks that depend on the tensor by keeping track of which tasks take the ObjectRef corresponding to the tensor as an argument.
Here’s a fixed version of the earlier example.
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        self.tensor = torch.randn(1000, 1000)
        return self.tensor

    def increment_and_sum_stored_tensor(self):
        # 1. Sender actor waits for Ray to release all references to the tensor
        # before modifying the tensor in place.
        ray.experimental.wait_tensor_freed(self.tensor)
        # NOTE: In-place update, but Ray guarantees that it has already released
        # its references to this tensor.
        self.tensor += 1
        return torch.sum(self.tensor)

    def increment_and_sum(self, tensor: torch.Tensor):
        # Receiver task remains the same.
        return torch.sum(tensor + 1)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum_stored_tensor.remote()
# 2. Skip `ray.get` because `wait_tensor_freed` will block until all
# references to `tensor` are freed, so calling `ray.get` here would cause a
# deadlock.
# tensor1 = ray.get(tensor1)
tensor2 = receiver.increment_and_sum.remote(tensor)

# 3. Delete all references to `tensor`, to unblock wait_tensor_freed.
del tensor

# This assertion will now pass.
assert torch.allclose(ray.get(tensor1), ray.get(tensor2))
The main changes are:
1. sender calls wait_tensor_freed before modifying the tensor in place.
2. The driver skips ray.get because wait_tensor_freed blocks until all ObjectRefs pointing to the tensor are freed, so calling ray.get here would cause a deadlock.
3. The driver calls del tensor to release its reference to the tensor. Again, this is necessary because wait_tensor_freed blocks until all ObjectRefs pointing to the tensor are freed.
When an RDT ObjectRef is passed back to the same actor that produced it, Ray passes back a reference to the tensor instead of a copy. Therefore, the same kind of bug can occur.
To help catch such cases, Ray will print a warning if an RDT object is passed both to the actor that produced it and to a different actor, like so:
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote
class MyActor:
    @ray.method(tensor_transport="gloo")
    def random_tensor(self):
        return torch.randn(1000, 1000)

    def increment_and_sum(self, tensor: torch.Tensor):
        # In-place update.
        tensor += 1
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="torch_gloo")

tensor = sender.random_tensor.remote()
tensor1 = sender.increment_and_sum.remote(tensor)
tensor2 = receiver.increment_and_sum.remote(tensor)
# A warning will be printed:
# UserWarning: GPU ObjectRef(...) is being passed back to the actor that created it Actor(MyActor, ...). Note that GPU objects are mutable. If the tensor is modified, Ray's internal copy will also be updated, and subsequent passes to other actors will receive the updated version instead of the original.

try:
    # This assertion may fail because the tensor returned by sender.random_tensor
    # is modified in-place by sender.increment_and_sum while being sent to
    # receiver.increment_and_sum.
    assert torch.allclose(ray.get(tensor1), ray.get(tensor2))
except AssertionError:
    print("AssertionError: sender and receiver returned different sums.")
Usage with NCCL (NVIDIA GPUs only)#
RDT requires just a few lines of code change to switch tensor transports. Here is the Gloo example, modified to use NVIDIA GPUs and the NCCL library for collective GPU communication.
import torch
import ray
from ray.experimental.collective import create_collective_group

@ray.remote(num_gpus=1)
class MyActor:
    @ray.method(tensor_transport="nccl")
    def random_tensor(self):
        return torch.randn(1000, 1000).cuda()

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

sender, receiver = MyActor.remote(), MyActor.remote()
group = create_collective_group([sender, receiver], backend="nccl")

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
ray.get(result)
The main code differences are:
1. The @ray.method decorator uses tensor_transport="nccl" instead of tensor_transport="gloo".
2. The ray.experimental.collective.create_collective_group call uses backend="nccl" instead of backend="torch_gloo".
3. Each actor requests a GPU via @ray.remote(num_gpus=1), and the tensor is created on the GPU using the .cuda() method.
Usage with NIXL (CPUs or NVIDIA GPUs)#
Installation#
For maximum performance, run the install_gdrcopy.sh script (e.g., install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"). You can find available OS versions here. If gdrcopy is not installed, things will still work with a plain pip install nixl, just with lower performance. nixl and ucx are installed as dependencies via pip.
Walkthrough#
NIXL can transfer data between different devices, including CPUs and NVIDIA GPUs, but doesn’t require a collective group to be created ahead of time. This means that any actor that has NIXL installed in its environment can be used to create and pass an RDT object.
Otherwise, the usage is the same as in the Gloo example.
Here is an example showing how to use NIXL to transfer an RDT object between two actors:
import torch
import ray

@ray.remote(num_gpus=1)
class MyActor:
    @ray.method(tensor_transport="nixl")
    def random_tensor(self):
        return torch.randn(1000, 1000).cuda()

    def sum(self, tensor: torch.Tensor):
        return torch.sum(tensor)

    def produce(self, tensors):
        refs = []
        for t in tensors:
            refs.append(ray.put(t, _tensor_transport="nixl"))
        return refs

    def consume_with_nixl(self, refs):
        # ray.get will also use NIXL to retrieve the result.
        tensors = [ray.get(ref) for ref in refs]
        total = 0
        for t in tensors:
            assert t.device.type == "cuda"
            total += t.sum().item()
        return total

# No collective group is needed. The two actors just need to have NIXL
# installed.
sender, receiver = MyActor.remote(), MyActor.remote()

# The tensor will be stored by the `sender` actor instead of in Ray's object
# store.
tensor = sender.random_tensor.remote()
result = receiver.sum.remote(tensor)
ray.get(result)
Compared to the Gloo example, the main code differences are:
1. The @ray.method decorator uses tensor_transport="nixl" instead of tensor_transport="gloo".
2. No collective group is needed.
ray.put and ray.get with NIXL#
Unlike the collective-based tensor transports (Gloo and NCCL), the ray.get function can use NIXL to retrieve a copy of the result.
By default, the tensor transport for ray.get will be the one specified in the @ray.method decorator.
# ray.get will also use NIXL to retrieve the
# result.
print(ray.get(tensor))
# torch.Tensor(...)
You can also use NIXL to retrieve the result from references created by ray.put.
tensor1 = torch.randn(1000, 1000).cuda()
tensor2 = torch.randn(1000, 1000).cuda()
refs = sender.produce.remote([tensor1, tensor2])
ref1 = receiver.consume_with_nixl.remote(refs)
print(ray.get(ref1))
Summary#
RDT allows Ray to store and pass objects directly between Ray actors, using accelerated transports like Gloo, NCCL, and NIXL. Here are the main points to keep in mind:
If using a collective-based tensor transport (Gloo or NCCL), a collective group must be created ahead of time. NIXL just requires all involved actors to have NIXL installed.
Unlike objects in the Ray object store, RDT objects are mutable, meaning that Ray only holds a reference, not a copy, to the stored tensor(s).
Otherwise, actors can be used as normal.
For a full list of limitations, see the limitations section.
Microbenchmarks#
Note
Under construction.
Limitations#
RDT is currently in alpha and has the following limitations, which may be addressed in future releases:
Support for torch.Tensor objects only.
Support for Ray actors only, not Ray tasks.
Not yet compatible with asyncio. Follow the tracking issue for updates.
Support for the following transports: Gloo, NCCL, and NIXL.
Support for CPUs and NVIDIA GPUs only.
RDT objects are mutable. This means that Ray only holds a reference to the tensor, and will not copy it until a transfer is requested. Thus, if the application code also keeps a reference to a tensor before returning it, and modifies the tensor in place, then some or all of the changes may be seen by the receiving actor.
For collective-based tensor transports (Gloo and NCCL):
  Only the process that created the collective group can submit actor tasks that return and pass RDT objects. If the creating process passes the actor handles to other processes, those processes can submit actor tasks as usual, but they will not be able to use RDT objects.
  Similarly, the process that created the collective group cannot serialize and pass RDT ray.ObjectRefs to other Ray tasks or actors. Instead, the ray.ObjectRefs can only be passed as direct arguments to other actor tasks, and those actors must be in the same collective group.
  Each actor can only be in one collective group per tensor transport at a time.
  No support for ray.put.
  If a system-level error occurs during a collective operation, the collective group will be destroyed and the actors will no longer be able to communicate via the collective group. Note that application-level errors, i.e., exceptions raised by user code, will not destroy the collective group and will instead be propagated to any dependent task(s), as for non-RDT Ray objects. System-level errors include:
    Errors internal to the third-party transport, e.g., NCCL network errors
    Actor and node failure
    Tensors returned by the user that are located on an unsupported device, e.g., a CPU tensor when using NCCL
    Any unexpected system bugs
Advanced: RDT Internals#
Note
Under construction.