[RFC] RPC Based Distributed Model Parallel - mrshenli/pytorch GitHub Wiki
PyTorch currently provides simple APIs for single machine data parallel, distributed data parallel, and single machine model parallel. However, when it comes to distributed model parallel, applications have to build their own scaffold to stitch together local autograd graphs into one global graph. This proposal aims to fill in that gap by providing an RPC-Based distributed model parallel API. In short, applications may run RPC to execute code remotely in the forward pass, and autograd will automatically travel across RPC boundaries in the backward pass.
API
Core Concepts
RRef[T] - (abbreviation ref) A reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. It is valid to have a reference to a local value as well, and values of type T can be implicitly converted to RRef[T]. This implicit conversion will be critical later to allow the expression of different types of RPC. Think of it like the implicit conversion from std::string to const std::string &. See the System Design section for more details about RRef.
ref.owner() # which worker this value lives on
v = ref.local_value() # if ref.owner() is the local worker, then
# this returns the underlying value, otherwise error.
# you can create a ref to a local tensor
t = torch.rand(3, 4)
ref2 = torch.RRef(t)
# in TorchScript, T can be automatically converted to RRef[T]
ref3 : RRef[Tensor] = t
Future[T] - (abbreviation fut) a guarantee that at some future point in time the value of type T will be available locally. The action to create T locally is assumed to be scheduled and in progress. Future is already supported in TorchScript and we are extending this to remote calls.
v = fut.wait() # block the current thread until v is ready
# local cpu task creation returns a future to the computed tensors
fut = torch.fork(lambda x, y: x + y, torch.rand(3, 4), torch.rand(3, 4))
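For readers more familiar with Python than TorchScript, the Future[T] semantics resemble the standard library's concurrent.futures.Future: the work is already scheduled when the future is returned, and result() blocks the calling thread like fut.wait(). The sketch below is only an analogy that uses a thread pool in place of torch.fork; it is not the proposed API.

```python
from concurrent.futures import ThreadPoolExecutor

# A thread pool stands in for the local task runner behind torch.fork.
with ThreadPoolExecutor(max_workers=2) as pool:
    # submit() schedules the work immediately and returns a Future
    fut = pool.submit(lambda x, y: x + y, 2, 3)
    # result() blocks until the value is ready, playing the role of fut.wait()
    v = fut.result()

print(v)  # 5
```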
Core Functions
# synchronous
result : T = torch.rpc(on : Worker, remote_callable : Callable, *args)
# asynchronous
result : Future[T] = torch.async_rpc(on : Worker, remote_callable : Callable, *args)
# remote reference
result : RRef[T] = torch.remote(on : Worker, remote_callable : Callable, *args)
Each function above invokes remote_callable on a remote worker. Value types in the args list are copied by value to the remote worker. RRef[T] types in the args list are copied by reference to the remote worker (again, see the analogy between std::string and const std::string&).
The synchronous variant copies the result value back, blocking the calling thread until the response arrives. The asynchronous variant returns immediately with a future. The remote knows that the caller expects to receive the value, so it will send a message back with the result at some point without further prompting.
The remote reference variant returns immediately with an RRef of the return value. The remote knows that the caller does not expect to receive the result value.
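To make the three calling conventions concrete, here is a single-process toy model in which each worker is a one-thread executor. The names rpc, rpc_async, remote, RRef and WORKERS are hypothetical stand-ins for torch.rpc, torch.async_rpc, torch.remote and RRef[T]; nothing crosses a network, so copying a value back is simply reading a future. (Released PyTorch versions expose comparable calls as torch.distributed.rpc.rpc_sync, rpc_async and remote, but this sketch does not use them.)

```python
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical model: calling a function "on" a worker = running it on that executor.
WORKERS = {name: ThreadPoolExecutor(max_workers=1) for name in ("worker1", "worker2")}

class RRef:
    """Toy RRef: remembers the owner and the (possibly pending) value held there."""
    def __init__(self, owner: str, value_future: Future):
        self._owner = owner
        self._value_future = value_future

    def owner(self) -> str:
        return self._owner

    def to_here(self):
        # explicit request to copy the value to the caller; blocks until it exists
        return self._value_future.result()

def rpc_async(on: str, fn, *args) -> Future:
    # returns immediately; the result is delivered when ready
    return WORKERS[on].submit(fn, *args)

def rpc(on: str, fn, *args):
    # synchronous variant: block until the response arrives
    return rpc_async(on, fn, *args).result()

def remote(on: str, fn, *args) -> RRef:
    # the result stays on `on`; the caller only receives a handle
    return RRef(on, WORKERS[on].submit(fn, *args))

add = lambda a, b: a + b
c = rpc("worker2", add, 1, 2)
c_f = rpc_async("worker2", add, 1, 2)
c_ref = remote("worker1", add, 1, 2)
print(c, c_f.result(), c_ref.owner(), c_ref.to_here())  # 3 3 worker1 3
```

The toy makes the difference between the variants visible only in what the caller holds: a value, a Future, or a handle whose owner() names the worker that keeps the result.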
The example below shows how these functions are used:
# make some local tensors
a : Tensor = torch.rand(3, 4)
b : Tensor = torch.rand(3, 4)
# define a remote function, visible to all machines.
# type annotations define expected input/output types.
def remote_function(a : Tensor, b : RRef[Tensor]) -> Tensor:
    # 'b' in the type signature is a remote reference, so we must copy it here
    # to use it locally.
    # to_here() is defined later in the syntax sugar section; it synchronously
    # copies the tensor to this worker.
    b_l : Tensor = b.to_here()
    return a + b_l
# run remote_function on a different worker.
# a is copied by value since it is a Tensor.
# b is copied by reference to the remote machine due to the RRef[Tensor]
# type annotation in the signature, which causes an implicit conversion to a
# reference type.
# torch.remote always creates an RRef of the result type.
# It does not wait for the remote's response.
# There is no implied copy of the tensor data yet.
c : RRef[Tensor] = torch.remote("worker1", remote_function, a, b)
# we can explicitly request the data to be copied back here:
c_l : Tensor = c.to_here()
# another example:
def remote_function2(a : Tensor, b : Tensor) -> Tensor:
    return a + b
# Here we call torch.rpc which returns the value directly without
# creating a remote reference.
# we synchronously wait for remote_function2 to return.
c : Tensor = torch.rpc("worker2", remote_function2, a, b)
# When the RPC call returns a non-reference type, we need to wait for
# a response from the remote host. To avoid waiting synchronously, use
# torch.async_rpc to get a future instead.
c_f : Future[Tensor] = torch.async_rpc("worker2", remote_function2, a, b)
# even before calling wait, the remote knows that the data should be sent back
# to the caller as soon as it is ready.
# force the local thread to wait for the remote's response
c = c_f.wait()
# if you omit type annotations in the remote function, the assumption is that
# arguments are passed without any implicit conversions
def remote_function3(a, b):
    # no annotations mean that a, b will be Tensor since there is no conversion
    return a + b
c: Tensor = torch.rpc("worker2", remote_function3, a, b)
RRef Forks
Implicit Conversions for RRef Arguments
We allow implicit conversion between T and RRef[T] for arguments of RPC functions. Both the actual and formal parameter can either be a T or an RRef[T], leading to four cases that might occur:
T → T (passing a T to an rpc that accepts a T): the value T is copied by value, and sent over the wire as part of the message invoking the RPC.
T → RRef[T] (passing a T to an rpc that accepts RRef[T]): The caller constructs a remote reference to the argument, and sends the reference over the wire to the callee. The data is not sent. The callee can then use the reference as a handle to either request the data later or to make further remote calls.
RRef[T] → T (passing an RRef[T] to an rpc that accepts T): The callee expects to get an actual value, so the reference needs to be turned into a value. The network behavior depends on where the RRef[T] lives.
- If the RRef[T] lives on the caller, then the implementation looks up the actual value of T locally and passes it by value along the wire, similar to the T → T case.
- If the RRef[T] lives on the callee, then the implementation just sends the reference and the callee does the lookup locally.
- If the RRef[T] lives on some third machine, then the caller sends 2 messages: one to the third machine telling it to send the data in the remote reference directly to the callee, and one to the callee telling it to start the RPC and expect this input to come from the third machine. This effectively forwards the value of the RRef[T] to the callee without the caller having to load it or the callee having to request it later.
Examples:
def remote_function1() -> Tensor:
    return torch.ones(2)
def remote_function2(a : Tensor) -> Tensor:
    b = a * 2
    return b
aref : RRef[Tensor] = remote("worker1", remote_function1)
# this local worker will make two RPC calls: one to tell worker1 to send the
# tensor to worker2, and another one to tell worker2 to expect this Tensor input
# from worker1. remote_function2 will run on worker2 only after it received the
# tensor from worker1.
bref : RRef[Tensor] = remote("worker2", remote_function2, aref)
RRef[T] → RRef[T] (passing an RRef[T] to an RPC that accepts RRef[T]): The callee expects an RRef[T], but we must make sure we correctly keep track of references to the value on a remote. So the actual behavior depends on where the RRef[T] lives.
- If the RRef[T] lives on the caller, then we simply pass it to the remote and record that this remote now has a live reference to the value.
- If the RRef[T] lives on the callee, then we pass it to the remote, and it becomes a local reference on the remote.
- If the RRef[T] lives on some third machine, then we must forward the reference. To do this, the caller sends two messages: one to the third machine telling it to create a remote reference and send it to the callee, and one to the callee telling it where to expect the reference from. The callee code is not invoked until the reference is transferred, to ensure sane reference counting.
Examples:
def remote_function1() -> Tensor:
    return torch.ones(2)
def remote_function2(a : RRef[Tensor]) -> Tensor:
    delta = 10
    return a.to_here() + delta
aref : RRef[Tensor] = remote("worker1", remote_function1)
# this local worker will make two RPC calls: one to tell worker1 to create a
# remote reference and send it to worker2, and another one to tell worker2 to
# expect this remote reference input from worker1. remote_function2 code will
# not run on worker2 until it receives the remote reference from worker1 to
# ensure proper reference counting.
bref : RRef[Tensor] = remote("worker2", remote_function2, aref)
When an RRef[T] goes dead on machine A, a message is sent to the owner of T telling it that the reference from machine A is dead.
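The four cases above amount to a small decision procedure on the caller side. The sketch below, with a hypothetical Arg record and plan_messages helper, returns human-readable descriptions of the messages the caller would send for one argument; it models who talks to whom, not serialization or reference counting. The local worker is called worker0 here, which is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Arg:
    """Hypothetical description of one actual argument."""
    is_rref: bool
    owner: Optional[str] = None   # owner of the referenced value, if is_rref

def plan_messages(arg: Arg, param_is_rref: bool, caller: str, callee: str) -> List[str]:
    """Describe the messages `caller` sends for one argument (toy model)."""
    if not arg.is_rref:
        if param_is_rref:
            # T -> RRef[T]: the caller creates a reference; the data stays here
            return [f"{caller}->{callee}: new reference owned by {caller}"]
        # T -> T: plain copy inside the RPC message
        return [f"{caller}->{callee}: value"]
    kind = "reference" if param_is_rref else "value"
    if arg.owner == caller:
        # value looked up locally (RRef -> T) or reference shared (RRef -> RRef)
        return [f"{caller}->{callee}: {kind}"]
    if arg.owner == callee:
        # only the reference travels; the callee resolves it locally
        return [f"{caller}->{callee}: reference (resolved locally by callee)"]
    # third-party owner: one message to the owner, one to the callee; the
    # callee waits for the forwarded value/reference before running
    return [
        f"{caller}->{arg.owner}: forward {kind} to {callee}",
        f"{caller}->{callee}: run, expect {kind} from {arg.owner}",
    ]

# The aref example: worker1 owns the value, remote_function2 on worker2 takes a Tensor
for msg in plan_messages(Arg(is_rref=True, owner="worker1"), param_is_rref=False,
                         caller="worker0", callee="worker2"):
    print(msg)
```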
Explicit RRef type for return values
The above implicit RRef argument conversion does not apply to return values. If remote_function returns RRef[T], calling it remotely using torch.remote would return RRef[RRef[T]] instead of RRef[T]. This is because the RRef returned by torch.remote is first created on the caller, which does not know the owner of the real data T. T could be stored on the callee of torch.remote, but it could also be on a different worker, as the callee may make another remote call within remote_function and return an RRef[T] owned by a different worker. Moreover, the caller is allowed to share the returned RRef with other workers immediately after torch.remote returns. However, since the caller does not yet know the real owner of T at that point, sharing the RRef would break the reference counting algorithm.
Examples:
def remote_function3() -> RRef[Tensor]:
    return torch.remote("Worker2", torch.ones, 2, 2)
cref : RRef[RRef[Tensor]] = remote("worker1", remote_function3)
Initialization API
Users may choose the communication backend for RPC, and are responsible for setting up the backend properly before calling the init_rpc method.
# backend: specifies the underlying communication implementation
# init_method: contains the information to initialize/connect a name store to
# resolve names
# name: is a unique identifier for the current worker
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker1")
The init_rpc method will create an RpcAgent under the hood and will make the current worker ready to send and receive RPC calls. If you call init_rpc and use the ProcessGroup (pg) backend, it acts as a global barrier, where all the node names are collectively synchronized before continuing. This is not the case if you use a peer-to-peer backend (e.g. TensorPipe), where calling init_rpc will register the node name in the specified store and start serving.
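The difference between barrier-style (pg) and peer-to-peer initialization can be sketched with threads in one process. NameStore, init_rpc_barrier and init_rpc_p2p are hypothetical; a dictionary stands in for the store named by init_method, and threading.Barrier stands in for the collective synchronization of the ProcessGroup backend.

```python
import threading

class NameStore:
    """Hypothetical stand-in for the name store described by init_method."""
    def __init__(self):
        self._names = {}
        self._lock = threading.Lock()

    def register(self, name, address):
        with self._lock:
            if name in self._names:
                raise ValueError(f"duplicate worker name {name!r}")
            self._names[name] = address

    def lookup(self, name):
        with self._lock:
            return self._names[name]

def init_rpc_barrier(store, barrier, name, address):
    # pg style: register, then wait until every worker has registered
    store.register(name, address)
    barrier.wait()

def init_rpc_p2p(store, name, address):
    # peer-to-peer style: register and start serving without waiting for peers
    store.register(name, address)

store = NameStore()
barrier = threading.Barrier(3)
seen = {}

def worker(i):
    init_rpc_barrier(store, barrier, f"worker{i}", f"addr{i}")
    # after the barrier, every peer name is resolvable
    seen[i] = [store.lookup(f"worker{j}") for j in range(3)]

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

p2p_store = NameStore()
init_rpc_p2p(p2p_store, "worker9", "addr9")   # returns at once, no barrier
```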
Applications don’t need to explicitly register functions for remote execution, but we do assume the same functions are defined on both caller and callee. This is often true, as all workers can import the same set of libraries or even share the same Python script.
Syntax Sugar
Other operations can now be implemented as syntax sugar on top of these core functions.
Retrieving Value From RRef
# helper private RPC functions
def _identity(v : Tensor) -> Tensor:
    # copy the tensor by value to this remote
    return v
def _to_here(v : RRef[T]) -> T:
    # take a reference, send it to the worker that owns it
    # and have that worker return the actual tensor by value
    return v.local_value()
class RRef[T]:
    ...
    # copy a remote tensor to the local worker, sync version
    def to_here(self) -> T:
        return torch.rpc(self.owner(), _to_here, self)
Builtin Operators
# proxy methods for all builtin functions exist on references for
# existing TorchScript types like Tensors. They always follow a fixed pattern:
def _mm(a : RRef[Tensor], b : RRef[Tensor]) -> Tensor:
    # runs on the worker holding both values; torch.remote wraps the result
    return a.local_value().mm(b.local_value())
class RRef[Tensor]:
    def mm(self : RRef[Tensor], other : RRef[Tensor]) -> RRef[Tensor]:
        # same_worker (not defined here) picks the worker that runs the op
        on = same_worker(self.owner(), other.owner())
        return torch.remote(on, _mm, self, other)
# a and b are RRef[Tensor] here
c : Tensor = a.mm(b).to_here()
Callable and RRef
If RRef[T] holds a callable object T, the application may directly call the RRef, which will be translated into a torch.remote call to the owner of the callable.
# if T is callable for RRef[T], rref(x) will be translated to calling T(x)
# on the owner of the RRef
def _call_rref(v : RRef[T], *args):
    return v.local_value()(*args)
class RRef[T]:
    def __call__(self, *args):
        return torch.remote(self.owner(), _call_rref, self, *args)
net = torch.remote("Worker1", Net)
net(inputs)
Optimizer and RRef
As models might have remote sub-modules (i.e., RRef[nn.Module]), we should provide an optimizer sugar to handle them. The optimizer sugar (torch.optim.remote) takes a local optimizer constructor, a distributed model parallel model, and an argument list for the local optimizer constructor. torch.optim.remote recursively creates a local optimizer on every remote sub-module owner, and exposes the same step API as a local optimizer, which recursively calls every local optimizer.
class Net1(nn.Module):
    ...
class Net2(nn.Module):
    ...
class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
dmp = DMP()
# torch.optim.remote creates an optimizer on every RRef destination
optimizer = torch.distributed.optimizer(torch.optim.SGD, dmp, lr=0.1)
loss = dmp(inputs)
# we need the autograd_context_id to distinguish concurrent distributed
# backward calls.
autograd_context_id = torch.distributed.autograd.backward(loss)
optimizer.step(autograd_context_id)
Model Parallel Training Examples
Multi-Machine Model Training
# 1. load data
inputs_rref = torch.remote("worker1", load_inputs, path_to_inputs)
labels_rref = torch.remote("worker2", load_labels, path_to_inputs)
# 2. define model
class Net1(nn.Module):
    ...
class Net2(nn.Module):
    ...
class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
    def forward(self, inputs_rref):
        # RRef[T].__call__(args) is a sugar that translates to
        # torch.remote(RRef.owner(), T, args)
        outputs1_rref = self.net1(inputs_rref)
        outputs2_rref = self.net2(outputs1_rref)
        return outputs2_rref
# 3. training, run it where you want to call autograd
def train(inputs_rref, labels_rref):
    dmp = DMP()
    # torch.optim.remote creates an optimizer on every RRef destination
    optimizer = torch.distributed.optimizer(torch.optim.SGD, dmp, lr=0.1)
    outputs_rref = dmp(inputs_rref)
    loss = loss_func(outputs_rref.to_here(), labels_rref.to_here())
    autograd_ctx_id = torch.distributed.autograd.backward(loss)
    optimizer.step(autograd_ctx_id)
torch.rpc(dev2, train, inputs_rref, labels_rref)
Parameter Server Training
class ParameterServer:
    def __init__(self):
        self.params = torch.zeros(100, 100).to(0)
    def get_params(self) -> Tensor:
        return self.params
    def add_grads(self, grad: Tensor):
        self.params += grad.to(0)
def train(ps):
    for _ in range(10):
        params = torch.rpc(ps.owner(), ParameterServer.get_params, ps)
        # run forward and backward
        torch.rpc(ps.owner(), ParameterServer.add_grads, ps, params.grad)
        torch.distributed.barrier(group=TRAINER_GROUP)
ps = torch.remote("worker1", ParameterServer)
torch.remote("worker2", train, ps)
torch.remote("worker3", train, ps)
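A single-process analogue of the parameter server example, assuming plain Python lists instead of tensors and threads instead of workers. The lock is an assumption of this sketch: the RFC does not say how the owner serializes concurrent add_grads calls. Gradients are constant here, since the forward/backward step is elided in the original as well, and the trainer barrier is omitted.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

class ParameterServer:
    """Toy parameter server holding a flat list of floats (hypothetical)."""
    def __init__(self, size):
        self.params = [0.0] * size
        self.lock = threading.Lock()   # trainers call concurrently

    def get_params(self):
        with self.lock:
            return list(self.params)   # a copy, like returning by value over RPC

    def add_grads(self, grad):
        with self.lock:
            self.params = [p + g for p, g in zip(self.params, grad)]

def train(ps, steps=10):
    for _ in range(steps):
        params = ps.get_params()
        # a real trainer would run forward/backward here; use a constant grad
        grad = [1.0] * len(params)
        ps.add_grads(grad)

ps = ParameterServer(size=4)
with ThreadPoolExecutor(max_workers=2) as pool:   # two "trainer" workers
    futures = [pool.submit(train, ps) for _ in range(2)]
    for f in futures:
        f.result()

print(ps.params)  # [20.0, 20.0, 20.0, 20.0]
```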
System Design
Distributed Autograd
Basic Idea
In the first version, dist.autograd.backward does not support RRef arguments, but RRef can still help build the autograd graph. The overall idea is as follows.
- When calling torch.rpc or RRef.to_here(), send and recv autograd functions will be inserted to connect local autograd graphs on multiple workers into one distributed autograd graph.
- Every distributed backward pass is assigned a globally unique id (autograd_context_id), and every participating worker will keep a dedicated context for it.
- When the backward computation reaches a recv function, it packs the gradient and the autograd_context_id in the message, and passes it to its send counterpart.
- Upon receiving a message for a send function in the backward pass, the worker uses the autograd_context_id in the message to identify which backward pass it belongs to, and uses the gradient in the message to continue autograd computation locally.
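The message flow in the list above can be illustrated with a single-process toy. Every name here (Worker, send_map, on_grad_message, add_on_worker1) is hypothetical, scalars stand in for tensors, and a direct method call stands in for the RPC message. It mirrors the simple example below: worker1 computes c = a + b, worker0 computes e = c1 * d and backpropagates the loss e.

```python
import itertools
from collections import defaultdict

_next_ctx_id = itertools.count()   # hypothetical generator of autograd_context_ids

class Worker:
    def __init__(self, name):
        self.name = name
        self.send_map = {}   # send_id -> local backward continuation
        self.contexts = defaultdict(lambda: defaultdict(float))  # ctx_id -> {leaf: grad}
        self._local_ids = itertools.count()

    def new_send_id(self):
        # globally unique because it pairs the worker name with a local counter
        return (self.name, next(self._local_ids))

    def on_grad_message(self, ctx_id, send_id, grad):
        # ctx_id selects the backward pass; send_id selects where to resume
        self.send_map[send_id](self.contexts[ctx_id], grad)

workers = {name: Worker(name) for name in ("worker0", "worker1")}

def add_on_worker1(a, b):
    """Forward pass of `add` on worker1; also installs the send function."""
    w = workers["worker1"]
    c = a + b
    send_id = w.new_send_id()

    def send_backward(ctx, grad):
        # d(a + b)/da = d(a + b)/db = 1, so both leaves receive `grad`
        ctx["a"] += grad
        ctx["b"] += grad

    w.send_map[send_id] = send_backward
    return c, send_id   # the response carries the value and the send id

# forward on worker0: c1 = rpc(add); e = c1 * d
c1, recv_id = add_on_worker1(2.0, 3.0)
d = 4.0
e = c1 * d

def backward_on_worker0():
    ctx_id = next(_next_ctx_id)
    grad_c1 = d                  # de/dc1 for e = c1 * d, with loss = e
    owner, _ = recv_id           # recv knows which worker holds its send
    workers[owner].on_grad_message(ctx_id, recv_id, grad_c1)
    return ctx_id

ctx_id = backward_on_worker0()
print(dict(workers["worker1"].contexts[ctx_id]))  # {'a': 4.0, 'b': 4.0}
```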
Send and Recv Autograd Functions
Let’s start with a simple example where there is just one synchronized RPC call and only one tensor passed across worker boundaries. Code is on the left and the autograd graph is on the right, where AccumulateGrad autograd functions for leaf nodes are omitted for simplicity.
# the add function should be
# defined on both workers
def add() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
# make RPC call from worker0
# to execute add on worker1
c1 = torch.rpc("worker1", add)
d = torch.ones_like(c1)
e = c1 * d
e.sum().backward()
The send and recv autograd functions are inserted during the forward pass, which connect two local graphs into one distributed graph. In the backward pass, the gradient will be passed to the recv autograd function on worker0, and the recv autograd function will then transmit the gradient tensor to worker1’s send autograd function. Then, worker1 can kick off the local autograd engine to resume the backward pass. A few more details need to be clarified in this simple example:
- On worker1, how do we keep the autograd graph alive after the RPC call returns?
  - In short, the distributed autograd engine on worker1 will keep a reference to the send function, which keeps the graph alive.
  - Reasoning: The graph can be kept alive by keeping a reference to either tensor C or the send autograd function, as both of them hold a reference to the add autograd function. We choose to keep a reference to the send function instead of tensor C, because C, as a non-leaf node produced by add, is not needed in the backward pass and should be freed as soon as possible. Holding C alive just to have an entry point into the autograd graph is not memory efficient.
- In the backward pass, how does recv on worker0 find the correct send on worker1 to talk to?
  - This can be done by assigning a globally unique ID (worker_id + local send/recv id) to each send/recv function pair.
- When can worker1 delete its local autograd graph?
  - send should have the same lifetime as its corresponding recv function. This can be done by sending a message from worker0 to worker1 when recv is destructed on worker0. The recv function is kept alive by the loss tensor. So, conceptually, the global autograd graph will be deleted when the final loss tensor is gone.
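The lifetime rule in the last item (send lives as long as recv) can be modeled with Python's weakref.finalize, which runs a callback once an object has been garbage-collected. In this hypothetical sketch, a dict stands in for the loss tensor that keeps recv alive, and a direct method call stands in for the message from worker0 to worker1.

```python
import gc
import weakref

class Worker1:
    """Hypothetical holder of send functions, keyed by send/recv pair id."""
    def __init__(self):
        self.send_map = {}

    def release_send(self, send_id):
        # handler for the "recv was destroyed" message
        self.send_map.pop(send_id, None)

class RecvFunction:
    """Hypothetical recv autograd function living on worker0."""
    def __init__(self, send_id, remote):
        self.send_id = send_id
        # When this object is collected, notify the remote side.
        # The callback must not reference `self`, or the object would never die.
        weakref.finalize(self, remote.release_send, send_id)

worker1 = Worker1()
send_id = ("worker1", 0)
worker1.send_map[send_id] = "send fn (keeps the add graph alive)"

loss = {"grad_fn": RecvFunction(send_id, worker1)}   # loss keeps recv alive
alive_before = send_id in worker1.send_map
del loss                                             # loss goes away ...
gc.collect()
alive_after = send_id in worker1.send_map            # ... and so does send
print(alive_before, alive_after)  # True False
```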
Hidden Autograd Path and Circular Dependency
Things can become complicated when an autograd graph contains multiple send/recv pairs. Consider the following example.
# all functions should be defined on all workers
def worker0_func(c2: Tensor) -> Tensor:
    g = torch.rand(2, 2)
    h = g + c2
    return h
def worker1_func_top() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
def worker1_func_bottom(c: Tensor, e1: Tensor) -> Tensor:
    f = c + e1
    return f
def worker2_func(c1: Tensor) -> Tensor:
    d = torch.rand(2, 2)
    e = c1 + d
    return e
# on Worker3
c_ref = torch.remote("Worker1", worker1_func_top)
h1 = torch.rpc("Worker0", worker0_func, c_ref)
e_ref = torch.remote("Worker2", worker2_func, c_ref)
f1 = torch.rpc("Worker1", worker1_func_bottom, c_ref, e_ref)
i = h1 + f1
i.sum().backward()
This example highlights two problems that we need to address:
- Hidden Autograd Path: The existing local autograd engine starts from the loss (or all outputs) and does a discovery/marking phase to identify all participating functions before executing the real autograd computation, so that all paths in the autograd graph are known upfront. However, we don’t have this luxury in distributed autograd because some parts of the autograd graph reside on remote workers. For example, when grad arrives at send5, worker1 cannot tell whether send3 will be in the backward pass if it only looks at local information. More specifically, i.sum().backward() will be the same as f1.sum().backward() from worker1’s perspective, but the former involves send3 and the latter does not.
  - To address this problem, we propose to record all globally upstream (upstream in the forward pass, downstream in the autograd graph) send/recv pairs in the forward pass, so that we know exactly which send/recv to wait for in the backward pass.
- Circular Dependency: there are circular dependencies between worker1 and worker2, i.e., it is impossible to finish autograd computation on one worker before kicking it off on the other. One option is to start autograd computation on worker1 first and have an autograd thread block there waiting for grads for send1, but this is less than ideal.
  - To address this problem, we propose to only create the send autograd function and put it in the ready queue when the grad is received. Note that, when computing the dependency count for add1, the autograd engine still takes send1 into account, so that the engine will only start computing grads for add1 after both add2 and send1 finish.
Note that we need to record information in the forward pass and do the discovery in the backward pass because we don’t know which send functions will participate in the autograd computation. However, if the application can guarantee that all send functions will receive grad in the backward pass, we can skip all this complexity and have a more efficient version. Both scenarios are useful, so we propose to have two modes:
- Smart Mode supports running backward on a subgraph of the global autograd graph, but there will be extra overhead in both the forward and backward pass.
- Fast Mode skips dependency recording in the forward pass and graph discovery in the backward pass, but the application needs to guarantee that all send autograd functions will receive grad in the backward pass.
The two sections below describe the two algorithms in more detail.
Distributed Autograd Algorithm Smart mode
Forward pass:
For every send x:
- Find send functions in x’s lineage, by:
  - Finding all locally reachable recv functions from send x in the autograd graph. In the example above, send2 finds recv1, send4 finds recv3, and send5 finds recv2.
  - Using those found recv functions to find globally reachable recv functions in send x’s lineage. Note that this can be done because in step 2 we send enough information from send to recv. In the example above, send4 knows send3, and send5 knows send1 and send2.
- Then, send x includes the ids of its lineage send functions in the message. Intuitively, it means that if there is a grad received for send x, the backward pass must reach all send functions in its lineage as well. This helps a node determine whether it should wait for a send grad.
# pseudo code to demonstrate how send works in forward
def find_global_lineage(tensor):
    # find local lineage
    recvs = find_recvs(tensor.grad_fn)
    dep_ids = {recv.id for recv in recvs}
    # find global lineage
    dep_ids.update({dep_id for recv in recvs for dep_id in recv.dep_ids})
    return dep_ids
def send(func, tensors, on):
    msg = Message(func)
    for tensor in tensors:
        lineage = find_global_lineage(tensor)
        # connect send to autograd graph
        send_fn = SendFunc()
        send_fn.next = tensor.grad_fn
        # remember the send by its id
        RpcAgent.send_map[send_fn.id] = send_fn
        # coalesce data
        msg.data.append((tensor, send_fn.id, lineage))
    send_msg(msg, on)
def recv(func, data, sender):
    tensors = []
    for tensor, send_id, lineage in data:
        # use send_id as recv_id, and remember global lineage
        recv_fn = RecvFunc(send_id, lineage)
        tensor.grad_fn = recv_fn
        tensors.append(tensor)
    return func(tensors)
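A runnable toy version of find_global_lineage, applied to the hidden-path example. Recv, Op and find_recvs are hypothetical stand-ins for the autograd graph. Because the figure with the send/recv numbering is not reproduced here, the mapping below is inferred from the text and should be treated as an assumption: send1 carries c from worker1 to worker2, send2 carries e from worker2 to worker1, send3 carries c from worker1 to worker0, send4 carries h from worker0 to Worker3, and send5 carries f from worker1 to Worker3.

```python
from dataclasses import dataclass, field

@dataclass(eq=False)
class Recv:
    """Toy recv function; its id equals the matching send id."""
    id: int
    dep_ids: frozenset = frozenset()   # global lineage shipped with the tensor

@dataclass(eq=False)
class Op:
    """Toy local autograd function with edges to its inputs' grad_fns."""
    next: list = field(default_factory=list)

def find_recvs(fn, seen=None):
    # depth-first walk of the local graph, stopping at recv functions
    seen = set() if seen is None else seen
    if id(fn) in seen:
        return []
    seen.add(id(fn))
    if isinstance(fn, Recv):
        return [fn]
    found = []
    for nxt in fn.next:
        found.extend(find_recvs(nxt, seen))
    return found

def find_global_lineage(grad_fn):
    recvs = find_recvs(grad_fn)
    dep_ids = {r.id for r in recvs}      # local lineage
    for r in recvs:
        dep_ids |= r.dep_ids             # plus what each recv already knows
    return dep_ids

# worker2: e = c1 + d, where c1 arrives through recv1
lineage_send2 = find_global_lineage(Op(next=[Recv(1)]))
# worker0: h = g + c2, where c2 arrives through recv3
lineage_send4 = find_global_lineage(Op(next=[Recv(3)]))
# worker1: f = c + e1, where c is local (add1) and e1 arrives through recv2
add1 = Op()
recv2 = Recv(2, frozenset(lineage_send2))
lineage_send5 = find_global_lineage(Op(next=[add1, recv2]))
# Worker3 (driver): i = h1 + f1
recv4 = Recv(4, frozenset(lineage_send4))
recv5 = Recv(5, frozenset(lineage_send5))
lineage_loss = find_global_lineage(Op(next=[recv4, recv5]))
print(lineage_send2, lineage_send4, lineage_send5, lineage_loss)
# {1} {3} {1, 2} {1, 2, 3, 4, 5}
```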
Backward pass:
On the node that calls torch.distributed.backward:
- Find all send functions in the lineage of the loss tensor. In the above example, it will be all 5 send functions. These ids will be propagated to the recv functions and will be passed to the counterpart send functions accordingly.
  - Optimizations can be added, e.g., dropping unnecessary ids in the backward pass to reduce message size.
On every node:
- Upon receiving the first message (be it a dedicated discovery message or the grad of a send), record its autograd_context_id, and retrieve all participating send ids from the message. Compute dependency counts from those send functions (and also from the loss grad_fn if the loss is on this node). Set the dependency count for send functions to 1. If any autograd function has a dependency count of 0, put it into the ready queue.
- Upon receiving a send grad, decrement the dependency count of that send by 1, and add it to the ready queue. Note this is done on an RpcAgent thread, and some autograd engine thread will pick up the autograd function for execution.
# pseudo code to demonstrate backward
graph_tasks = {}
def backward(loss):
    global graph_tasks
    autograd_context_id = gen_autograd_id()
    lineage = find_global_lineage(loss)
    # these sends will participate in the autograd pass
    roots = [send for send_id, send in RpcAgent.send_map.items() if send_id in lineage]
    # propagate the autograd_id and deps info to all
    # participating workers. This is non-blocking and can
    # run concurrently with the real backward computation.
    # This step is not absolutely necessary, but can help other
    # workers to kick off autograd earlier.
    disseminate(autograd_context_id, lineage)
    # below is a hand-waving impl to show how it works with local autograd engine
    graph_task = GraphTask()
    graph_tasks[autograd_context_id] = graph_task
    roots.append(loss.grad_fn)
    # set up dependency counts properly
    compute_dependencies(GraphRoot(roots), graph_task)
    # insert the task to local engine ready queue. Only the FunctionTask
    # for loss is inserted now; send FunctionTasks will be inserted later
    # when their grads become available.
    ready_queue.push_back(FunctionTask(graph_task, loss.grad_fn, ...))
    return autograd_context_id
def on_grad_send(send_id, grad, autograd_id):
    global graph_tasks
    graph_task = graph_tasks[autograd_id]
    send_func = RpcAgent.send_map[send_id]
    ready_queue.push_back(FunctionTask(graph_task, send_func, grad))
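The dependency-count rules can be simulated on worker1 for f1.sum().backward(), using the same inferred send numbering as the lineage toy (an assumption, since the figure is missing). In this hypothetical ToyEngine, sends start with a dependency count of 1 and only enter the ready queue when their grad message arrives, so add1 waits for both add2 and send1 without any thread blocking. send3 is left out because smart mode would tell worker1 it is not in the loss's lineage.

```python
from collections import deque

class ToyEngine:
    """Hypothetical per-context engine state on one worker."""
    def __init__(self, graph, participating_sends):
        self.graph = graph                      # fn -> list of next fns (toward leaves)
        self.deps = {fn: 0 for fn in graph}
        for nexts in graph.values():
            for nxt in nexts:
                self.deps[nxt] += 1
        for send in participating_sends:
            self.deps[send] = 1                 # waits for a grad from the network
        self.ready = deque()
        self.executed = []

    def on_grad_send(self, send):
        # runs on an RPC thread when a grad message for `send` arrives
        self.deps[send] -= 1
        if self.deps[send] == 0:
            self.ready.append(send)

    def run_ready(self):
        # what an autograd engine thread would do with the ready queue
        while self.ready:
            fn = self.ready.popleft()
            self.executed.append(fn)
            for nxt in self.graph[fn]:
                self.deps[nxt] -= 1
                if self.deps[nxt] == 0:
                    self.ready.append(nxt)

# worker1 during f1.sum().backward(); send3 is not in the lineage
graph = {
    "send5": ["add2"],          # f was sent to the driver
    "send1": ["add1"],          # c was sent to worker2
    "add2": ["add1", "recv2"],  # f = c + e1
    "add1": [],                 # c = a + b (AccumulateGrad omitted)
    "recv2": [],                # running recv2 ships the grad to send2 on worker2
}
engine = ToyEngine(graph, participating_sends=["send5", "send1"])

engine.on_grad_send("send5")    # grad of f1 arrives from the driver
engine.run_ready()
first_round = list(engine.executed)   # add1 still waits for send1

engine.on_grad_send("send1")    # worker2 returns the grad for c
engine.run_ready()
print(first_round, engine.executed)
# ['send5', 'add2', 'recv2'] ['send5', 'add2', 'recv2', 'send1', 'add1']
```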
Distributed Autograd Algorithm Fast mode
The problem with the above approach is that including ids in send/recv messages incurs overhead, especially when a lot of tensors are communicated across multiple workers. Moreover, this discovery phase is only necessary when running autograd on a subgraph. For example, f1.sum().backward() requires the discovery phase to avoid waiting for send3, but it is unnecessary for i.sum().backward(), as all sends are involved in the backward pass. So, we propose one additional mode for distributed autograd that bypasses send/recv dependency discovery in both the forward and backward pass if all sends for non-leaf or requires_grad tensors will receive grad in the backward pass. The mode can be toggled when initializing RPC agents:
# all_requires_grad (bool): If True, the application guarantees that all
# send functions on non-leaf or requires_grad tensors will receive grad
# in the backward pass. Hence, we can skip the distributed dependency
# discovery algorithm (fast mode). If False, run smart mode, where
# messages between send/recv will contain dependency ids in both forward
# and backward pass. (default False)
torch.distributed.init_rpc(name, backend="pg", all_requires_grad=False)
Internally, RpcAgent will create a thread-local driver ID, where a driver is the worker that pieces together the autograd graph. In the above example, Worker3 is the driver. In the forward pass, every send function originating from this driver will be tagged with its thread-local driver ID, and this applies to all downstream (upstream in the autograd graph) send functions as well. This can be done either by propagating this driver ID to RPC calls recursively, or by doing an active driver ID discovery by walking the autograd graph before sending a tensor. If this information is ambiguous, e.g., one send function traces back to two upstream (downstream in the autograd graph) recv functions from two different drivers, it will throw an error. In the backward pass, the thread-local driver ID of the loss will be included in the entire autograd execution to identify participating send functions. Note that, in this mode, the application cannot keep two disjoint autograd graphs alive at the same time, as that would break the assumption that all sends (originating from the driver) will receive grad in the backward pass.
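The driver ID tagging rule can be expressed as a small function. tag_send and AmbiguousDriverError are hypothetical names; thread_driver_id is the ID set on the driver or propagated with an incoming RPC, and upstream_recv_driver_ids are the driver IDs carried by the recv functions reachable from the tensor being sent.

```python
class AmbiguousDriverError(RuntimeError):
    pass

def tag_send(thread_driver_id, upstream_recv_driver_ids):
    """Hypothetical fast-mode rule for the driver ID attached to a send."""
    drivers = set(upstream_recv_driver_ids)
    if not drivers:
        # no upstream recv: use the thread-local (possibly propagated) driver ID
        return thread_driver_id
    if len(drivers) > 1:
        raise AmbiguousDriverError(f"send traces back to drivers {sorted(drivers)}")
    # inherit the single driver found upstream
    (driver,) = drivers
    return driver

print(tag_send("Worker3", []), tag_send("Worker3", ["Worker3", "Worker3"]))
# Worker3 Worker3
```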
Concurrent distributed Backward passes
A = torch.rand(2, 2)
B = torch.rand(2, 2)
# on all workers
def add() -> Tensor:
    global A, B
    return A + B
# on worker0
C = torch.remote("worker2", add).to_here()
C.sum().backward()
# on worker1
C = torch.remote("worker2", add).to_here()
C.sum().backward()
In the above example, there are two concurrent backward passes triggered by worker0 and worker1 respectively, and both will reach worker2. To avoid races, the distributed autograd engine will use the globally unique autograd_context_id to create a dedicated context on every participating worker. Later, this autograd_context_id is passed to the optimizer to apply gradients. More concretely, this would work as follows:
- Compute all the leaf nodes in the autograd graph.
- As part of running distributed backward, use the outputs parameter of the autograd engine to avoid executing AccumulateGrad for the leaf nodes we have, and instead return the appropriate output_edges to execute for accumulating gradients.
- Store the output_edges with the autograd_context_id. This ensures multiple backward passes won't accumulate gradients in the same context.
- This completes the backward pass, and gradients are accumulated in the autograd engine per autograd_context_id.
- Now we run the optimizer on each of the worker nodes and pass the autograd_context_id to the optimizer.
- The optimizer applies all the gradients to the leaf nodes that we computed originally.
- The context and enclosed gradients should be destroyed when the autograd_context_id is destructed on the caller of backward().
Some pseudo-code to illustrate this:
optimizer = dist.optimizer(model)
loss = model(inputs)
bw_ctx_id = dist.autograd.backward(loss, timeout=60) # timeout of 60s
optimizer.step(bw_ctx_id)
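A toy model of per-context gradient storage, assuming scalar parameters and a hand-written SGD step. DistAutogradContexts and ToySGD are hypothetical; the point is that each autograd_context_id gets its own gradient buffer, so two concurrent backward passes through worker2 (as in the example above) never mix, and step() applies only the buffer it is given. The gradient values are arbitrary.

```python
from collections import defaultdict

class DistAutogradContexts:
    """Hypothetical per-worker store: autograd_context_id -> {param: grad}."""
    def __init__(self):
        self._grads = defaultdict(lambda: defaultdict(float))

    def accumulate(self, ctx_id, name, grad):
        # plays the role of the output_edges stored with the context
        self._grads[ctx_id][name] += grad

    def grads(self, ctx_id):
        return dict(self._grads.get(ctx_id, {}))

    def release(self, ctx_id):
        # the caller of backward() dropped its autograd_context_id
        self._grads.pop(ctx_id, None)

class ToySGD:
    def __init__(self, params, contexts, lr):
        self.params = params          # name -> float, the leaves on this worker
        self.contexts = contexts
        self.lr = lr

    def step(self, ctx_id):
        # apply only the gradients recorded under this backward pass
        for name, grad in self.contexts.grads(ctx_id).items():
            self.params[name] -= self.lr * grad

# worker2 holds A and B; worker0 and worker1 each run a backward pass through it
contexts = DistAutogradContexts()
params = {"A": 1.0, "B": 1.0}
opt = ToySGD(params, contexts, lr=0.5)
contexts.accumulate(0, "A", 1.0)   # pass started by worker0
contexts.accumulate(0, "B", 1.0)
contexts.accumulate(1, "A", 3.0)   # concurrent pass started by worker1
opt.step(0)                        # only context 0's gradients are applied
contexts.release(0)
print(params, contexts.grads(1))   # {'A': 0.5, 'B': 0.5} {'A': 3.0}
```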
RRef
RRef is an important concept for building a distributed autograd graph. Each RRef is owned by a single worker (i.e., owner) and can be used by multiple users. The owner stores the real data referenced by its RRefs, and keeps track of the global reference counts for its RRefs. Every RRef can be uniquely identified by a global id ref_id, which is assigned at the time it is first created either on a user or on the owner.
The owner only keeps one RRef instance for each data object, while users can fork as many RRef instances as necessary. All usage on the owner should retrieve the RRef instance using the globally unique ref_id. A fork of an RRef will be created when it is used as an argument or return value in an RPC call, but users don't need to worry about forking/forwarding and reference counting (RC) RRefs. These will be handled transparently, and every fork will also have its own fork_id, which is guaranteed to be unique across all RRef instances for the same data object.
RRef needs to support fast and scalable RPC. Hence, in the RC design, we avoid using any global master to keep RRef states. Besides, when worker X invokes an RPC on worker Y, Y should be able to start immediately after receiving the RPC request, without waiting for any third-party owner Z (unless Y needs to pull real data from Z), even if neither X nor Y owns the RRef. We propose the following algorithm:
- If the owner is the RPC caller, the owner will update the RC for the RRef accordingly.
- If the owner is the RPC callee, the owner will drop the new fork, and use the unique RRef id in the fork to access its singleton local RRef instance.
- If the RPC is between two users:
  - The caller sends an RPC message to the callee, and also notifies the owner of the new fork.
  - The owner, upon receiving the notification, updates its local RC and then tells the callee that the new fork is now known by the owner.
  - The callee can start executing the RPC as soon as it receives the RPC message from the caller, and does not need to wait for the message from the owner. However, it cannot delete its local RRef fork until the owner's message arrives.
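The user-to-user case can be sketched as follows. Owner and UserFork are hypothetical, messages are direct method calls, and the sketch only covers the rule that the callee may run immediately but must defer deleting its fork until the owner's acknowledgement arrives; it does not model message reordering or the full argument behind G1 and G2 in the Reference Count section.

```python
class Owner:
    """Hypothetical owner-side bookkeeping for a single RRef."""
    def __init__(self):
        self.known_forks = set()     # forks the owner believes are alive

    def on_new_fork(self, parent_fork_id, fork_id):
        # the notification carries the parent's fork_id; unused in this toy
        self.known_forks.add(fork_id)

    def on_fork_deleted(self, fork_id):
        self.known_forks.discard(fork_id)

    def can_free(self, local_python_refs):
        # no known live forks, and local Python GC agrees
        return not self.known_forks and local_python_refs == 0

class UserFork:
    """Hypothetical user-side fork that defers deletion until acknowledged."""
    def __init__(self, fork_id, owner, acknowledged):
        self.fork_id = fork_id
        self.owner = owner
        self.acknowledged = acknowledged
        self.release_requested = False
        self.deleted = False

    def on_owner_ack(self):
        self.acknowledged = True
        if self.release_requested:
            self._delete()

    def release(self):
        self.release_requested = True
        if self.acknowledged:        # otherwise wait for the owner's message
            self._delete()

    def _delete(self):
        self.deleted = True
        self.owner.on_fork_deleted(self.fork_id)   # the owner is always told

owner = Owner()
a = UserFork("A", owner, acknowledged=True)
owner.known_forks.add("A")           # A was forked by the owner itself
# A passes the RRef to B; B runs the RPC at once, the owner hears about it later
b = UserFork("B", owner, acknowledged=False)
b.release()                          # B finishes quickly ...
deferred = not b.deleted             # ... but may not delete its fork yet
owner.on_new_fork("A", "B")          # A's notification reaches the owner
b.on_owner_ack()                     # the owner's message reaches B
a.release()
print(deferred, b.deleted, owner.can_free(local_python_refs=0))  # True True True
```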
Reference Count
The right time to delete an RRef on the owner is when there are no living forks on any user and Python GC also agrees to delete the RRef instance on the owner. The tricky part is to determine if there are any living forks.
A user can get a fork in three situations:
1. Receiving a fork from the owner.
2. Receiving a fork from another user.
3. Creating a new RRef fork owned by another worker.
#1 is the simplest case where the owner initiates the fork, and hence it can easily increase local RC. The only requirement is that any fork must notify the owner before destruction. Hence, we need the first guarantee:
G1. The owner will be notified when any fork is deleted.
Note that the notification might come delayed or out-of-order.
With #2 and #3, it is possible that the owner only partially knows the RRef fork graph, or does not know it at all. For example, the RRef could be constructed on a user, and before the owner receives the RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is a tree rooted at the owner, because forking an RRef always creates a new RRef instance, and hence every RRef has a parent. One nasty detail is that when an RRef is created on a user, technically the owner is not its parent, but we still consider it that way and it does not break the argument below.
The owner's view on any node (fork) in the tree has three stages, 1) unknown → 2) known → 3) deleted, and the owner's view on the entire tree keeps changing. The owner deletes its RRef instance when it thinks there are no living forks, i.e., all the forks are either actually deleted or unknown. Therefore, the dangerous case is when some forks are unknown and others are deleted. We only need a simple guarantee to prevent this situation:
G2. No fork x can be deleted on a user before the owner knows x’s parent fork.
This works because the owner's view on x can only change from known to deleted when x's parent is known or deleted. If the parent is known, the owner will not delete its local RRef. If the parent is deleted, this rule recursively applies to the parent's parent, until it reaches the root (owner). To implement the guarantee, we only need to make the caller include its own fork_id when notifying the owner on a new fork.
G1 and G2 guarantee correct RC, but do not prevent a user from deleting its fork before its own prior RPC calls using that RRef fork finish. This should be OK, because when the caller deserializes the RPC message, it would hold a reference to that RRef, preventing it from being deleted.