[RFC] RPC Based Distributed Model Parallel

PyTorch currently provides simple APIs for single machine data parallel, distributed data parallel, and single machine model parallel. However, when it comes to distributed model parallel, applications have to build their own scaffold to stitch together local autograd graphs into one global graph. This proposal aims to fill in that gap by providing an RPC-Based distributed model parallel API. In short, applications may run RPC to execute code remotely in the forward pass, and autograd will automatically travel across RPC boundaries in the backward pass.

API

Core Concepts

RRef[T] - (abbreviation ref) A reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote tensor value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. It is valid to have a reference to a local value as well, and values of type T can be implicitly converted to RRef[T]. This implicit conversion will be critical later to allow the expression of different types of RPC. Think of it like the implicit conversion from std::string to const std::string &. See the System Design section for more details about RRef.

ref.owner() # what is the worker this value lives on

v = ref.local_value() # if ref.owner() is local worker, then
                      # this returns the underlying value, otherwise error.
# you can create a ref to a local tensor
t = torch.rand(3, 4)         
ref2 = torch.RRef(t)

# in TorchScript, T can be automatically converted to RRef[T]
ref3 : RRef[Tensor] = t

Future[T] - (abbreviation fut) a guarantee that at some future point in time the value of type T will be available locally. The action to create T locally is assumed to be scheduled and in-progress. Future is already supported in TorchScript and we are extending this to remote calls.

v = fut.wait() # block the current thread until v is ready

# local cpu task creation returns a future to the computed tensors
fut = torch.fork(lambda x, y: x + y, torch.rand(3, 4), torch.rand(3, 4))

Core Functions

# synchronous
result : T = torch.rpc(on : Worker, remote_callable : Callable, *args)
# asynchronous
result : Future[T] = torch.async_rpc(on : Worker, remote_callable : Callable, *args)
# remote reference
result : RRef[T] = torch.remote(on : Worker, remote_callable : Callable, *args)

Each function above invokes remote_callable on a remote worker. Value types in the args list are copied by value to the remote worker. RRef[T] types in the args list are copied by reference to the remote worker (again see the analogy between std::string and const std::string&).

The synchronous variant copies the result value back, blocking the calling thread until the response arrives. The asynchronous variant returns immediately with a future; the remote worker knows that the caller expects the value, so it will send the result back as soon as it is ready without further prompting.

The remote reference variant returns immediately with an RRef of the return value. The remote knows that the caller does not expect to receive the result value.

The example below shows how these functions are used:

# make some local tensors
a : Tensor = torch.rand(3, 4)
b : Tensor = torch.rand(3, 4)

# define a remote function, visible to all machines.
# type annotations define expected input/output types.
def remote_function(a : Tensor, b : RRef[Tensor]) -> Tensor:
   # 'b' in the type signature is a remote reference, so we must copy it here
   # to use it locally.
   # to_here() is defined later in the syntax sugar section, it synchronously
   # copies the tensor to this worker.
   b_l  : Tensor = b.to_here()
   return a + b_l

# run remote_function on a different worker.
# a is copied by value since it is a Tensor
# b is copied by reference to the remote machine due to the RRef[Tensor]
# type annotation in the signature, which causes an implicit conversion to a
# reference type.

# torch.remote always creates an RRef of the result type. 
# It does not wait for the remote's response. 
# There is no implied copy of the tensor data yet.
c : RRef[Tensor] = torch.remote("worker1", remote_function, a, b)

# we can explicitly request the data to be copied back here:
c_l : Tensor = c.to_here()

# another example:
def remote_function2(a : Tensor, b : Tensor) -> Tensor:
   return a + b

# Here we call torch.rpc which returns the value directly without
# creating a remote reference.
# we synchronously wait for remote_function2 to return.
c : Tensor = torch.rpc("worker2", remote_function2, a, b)

# When the RPC call returns a non-reference type, we need to wait for
# a response from the remote host. To avoid synchronously waiting, use the
# asynchronous variant (torch.async_rpc) to get a future instead.
c_f : Future[Tensor] = torch.async_rpc("worker2", remote_function2, a, b)
# even before calling wait, the remote knows that the data should be sent back
# to the caller as soon as it is ready.

# force the local thread to wait for the remote's response
c = c_f.wait()


# if you omit type annotations in the remote function, the assumption is that 
# arguments are passed without any implicit conversions
def remote_function3(a, b):
   # no annotations mean that a, b will be Tensor since there is no conversion
   return a + b

c: Tensor = torch.rpc("worker2", remote_function3, a, b)

RRef Forks

Implicit Conversions for RRef Arguments

We allow implicit conversion between T and RRef[T] for arguments of RPC functions. Both the actual and the formal parameter can be either a T or an RRef[T], leading to four possible cases:

T → T (passing a T to an rpc that accepts a T): the value T is copied by value and sent over the wire as part of the message invoking the RPC

T → RRef[T] (passing a T to an rpc that accepts RRef[T]): The caller constructs a remote reference to the argument, and sends the reference over the wire to the callee. The data is not sent. The callee can then use the reference as a handle to either request the data later or to make further remote calls.
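
For example, a minimal sketch of this case (the worker name and the helper remote_function0 are illustrative, not part of the proposal):

# the formal parameter is RRef[Tensor], so a local Tensor argument is
# implicitly converted to a remote reference; only the reference travels
# with the call, not the tensor data
def remote_function0(w : RRef[Tensor]) -> Tensor:
    # the callee decides if/when to fetch the data
    return w.to_here() * 2

w : Tensor = torch.rand(3, 4)
# no tensor data is sent here; worker1 pulls it later via to_here()
out : Tensor = torch.rpc("worker1", remote_function0, w)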

RRef[T] → T (passing an RRef[T] to an rpc that accepts T): The callee expects to get an actual value, so the callee needs to turn the reference into a value. The network behavior depends on where the RRef[T] lives.

  • If the RRef[T] lives on the caller, then the implementation looks up the actual value of T locally and passes it by value over the wire, similar to the T → T case.
  • If the RRef[T] lives on the callee, then the implementation just sends the reference and the callee does the lookup locally.
  • If the RRef[T] lives on some third machine, then the caller sends 2 messages. One to the third machine telling it to send the data in the remote reference directly to the callee, and one to the callee telling it to start the RPC and expect this input to come from the third machine. This effectively forwards the value of the RRef[T] to the callee without the caller having to load it or the callee having to request it later.

Examples:

def remote_function1() -> Tensor:
    return torch.ones(2)
    
def remote_function2(a : Tensor) -> Tensor:
    b = a * 2
    return b

aref : RRef[Tensor] = remote("worker1", remote_function1)

# this local worker will make two RPC calls: one to tell worker1 to send the 
# tensor to worker2, and another one to tell worker2 to expect this Tensor input
# from worker1. remote_function2 will run on worker2 only after it received the 
# tensor from worker1.  
bref : RRef[Tensor] = remote("worker2", remote_function2, aref) 

RRef[T] → RRef[T] (passing an RRef[T] to an RPC that accepts RRef[T]): The callee expects an RRef[T], but we must make sure we correctly keep track of references to the value on the remote. So the actual behavior depends on where the RRef[T] lives.

  • If RRef[T] lives on the caller, then we simply pass it to the remote and record that this remote now has a live reference to the value.
  • If the RRef[T] lives on the callee, then we pass it to the remote, and it becomes a local reference on the remote.
  • If RRef[T] lives on some third machine, then we must forward the reference. To do this the caller sends two messages. One to the third machine telling it to create a remote reference and send it to the callee, and one to the callee telling it where to expect the remote reference from. The callee code is not invoked until the remote reference is transferred, to ensure sane reference counting.

Examples:

def remote_function1() -> Tensor:
    return torch.ones(2)
    
def remote_function2(a : RRef[Tensor]) -> Tensor:
    delta = 10
    return a.to_here() + delta

aref : RRef[Tensor] = remote("worker1", remote_function1)

# this local worker will make two RPC calls: one to tell worker1 to create a 
# remote reference and send it to worker2, and another one to tell worker2 to 
# expect this remote reference input from worker1. remote_function2 code will 
# not run on worker2 until it receives the remote reference from worker1 to 
# ensure proper reference counting.   
bref : RRef[Tensor] = remote("worker2", remote_function2, aref) 

When an RRef[T] goes dead on machine A, a message is sent to the owner of T telling it that the reference from machine A is dead.
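
A minimal sketch of how this notification could look (the message tag and the _current_worker helper are illustrative, not the actual implementation):

class RRef[T]:
    ...
    def __del__(self):
        # when the last local reference on this machine goes away, notify the
        # owner of T that this machine's reference is dead
        if self.owner() != _current_worker():   # _current_worker() is an assumed helper
            send_msg(Message("RREF_DELETED", self.id), self.owner())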

Explicit RRef type for return values

The implicit RRef argument conversion above does not apply to return values. If remote_function returns RRef[T], calling it remotely using torch.remote returns RRef[RRef[T]] rather than RRef[T]. This is because the return-value RRef of torch.remote is first created on the caller, which does not yet know the owner of the real data T. T could be stored on the callee of torch.remote, but it could also live on a different worker, as the callee may itself make another remote call within remote_function and return an RRef[T] owned by someone else. Moreover, the caller is allowed to share the returned RRef with other workers immediately after torch.remote returns; since the caller does not yet know the real owner of T at that point, sharing an unwrapped RRef would break the reference counting algorithm.

Examples:

def remote_function3() -> RRef[Tensor]:
    return torch.remote("Worker2", torch.ones, 2, 2)
 
cref : RRef[RRef[Tensor]] = remote("worker1", remote_function3) 

Initialization API

Users may choose a communication backend for RPC, and they are responsible for setting up the backend properly before calling the init_rpc method.

# backend: specifies the underlying communication implementation
# init_method: contains the information to initialize/connect a name store to 
#              resolve names
# name: is a unique identifier for the current worker
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker1")

The init_rpc method will create an RpcAgent under the hood and make the current worker ready to send and receive RPC calls. If you call init_rpc with the ProcessGroup (pg) backend, it acts as a global barrier, where all node names are collectively synchronized before continuing. This is not the case with a peer-to-peer backend (e.g. TensorPipe), where calling init_rpc simply registers the node name in the specified store and starts serving.

Applications don’t need to explicitly register functions for remote execution, but we do assume the same functions are defined on both the caller and the callee. This is often true, as all workers can import the same set of libraries or even share the same Python script.
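
For instance, a minimal sketch under this assumption (worker names and the add function are illustrative; both workers run the same script or import the same module):

# defined identically on every worker, so it resolves on both caller and callee
# without any explicit registration step
def add(x : Tensor, y : Tensor) -> Tensor:
    return x + y

# on worker0
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker0")
c : Tensor = torch.rpc("worker1", add, torch.rand(2, 2), torch.rand(2, 2))

# on worker1: after init_rpc, the worker simply stays alive to serve incoming calls
torch.distributed.init_rpc(backend="pg", init_method="file:///...", name="worker1")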

Syntax Sugar

Other operations are now implementable using syntax sugar.

Retrieving Value From RRef

# helper private RPC functions
def _identity(v : Tensor) -> Tensor:
    # copy the tensor by value to this remote,
    return v

def _to_here(v : RRef[T]) -> T:
    # take a reference, send it to the device that owns it
    # and have that device return the actual tensor by value
    return v.local_value()

class RRef[T]:
    ...
    # copy a remote tensor to the local worker, sync version
    def to_here(self) -> T:
        return torch.rpc(self.owner(), _to_here, self)

Builtin Operators

# proxy methods for all builtin functions exist on references for 
# existing TorchScript types like Tensors. They always follow a fixed pattern:
def _mm(a : RRef[Tensor], b : RRef[Tensor]) -> Tensor:
    return a.local_value().mm(b.local_value())
     
class RRef[Tensor]:
    def mm(self : RRef[Tensor], other : RRef[Tensor]) -> RRef[Tensor]:
        # both operands must live on the same worker
        on = same_worker(self.owner(), other.owner())
        return torch.remote(on, _mm, self, other)


c : Tensor = a.mm(b).to_here()

Callable and RRef

If RRef[T] holds a callable object T, the application may directly call the RRef, which will be translated into a torch.remote call to the owner of the callable.

# if T is callable for RRef[T], rref(x) will be translated to calling T(x) 
# on the owner of the RRef
def _call_rref(v : RRef[T], *args):
    return v.local_value()(*args)

class RRef[T]:
    def __call__(self, *args):
        return torch.remote(self.owner(), _call_rref, self, *args)

net = torch.remote("Worker1", Net)
net(inputs)

Optimizer and RRef

As models might have remote sub-modules (i.e., RRef[nn.Module]), we should provide an optimizer sugar to handle them. The optimizer sugar (torch.optim.remote) takes a local optimizer constructor, a distributed model parallel model, and an argument list for the local optimizer constructor. torch.optim.remote recursively creates a local optimizer on every remote sub-module owner, and exposes the same step API as a local optimizer, which recursively calls every local optimizer.

class Net1(nn.Module):
    ...

class Net2(nn.Module):
    ...

class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
        
dmp = DMP()
# torch.optim.remote creates an optimizer on every RRef destination
optimizer = torch.optim.remote(torch.optim.SGD, dmp, lr=0.1)
loss = dmp(inputs)
# we need the autograd_context_id to distinguish distributed concurrent
# backward calls.
autograd_context_id = torch.distributed.autograd.backward(loss)
optimizer.step(autograd_context_id)

Model Parallel Training Examples

Multi-Machine Model Training

# 1. load data
inputs_rref = torch.remote("worker1", load_inputs, path_to_inputs) 
labels_rref = torch.remote("worker2", load_labels, path_to_labels)

# 2. define model
class Net1(nn.Module):
    ...

class Net2(nn.Module):
    ...

class DMP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = torch.remote("worker1", Net1)
        self.net2 = torch.remote("worker2", Net2)
        
    def forward(self, inputs_rref):
        # RRef[T].__call__(args) is a sugar that translates to
        # torch.remote(RRef.owner(), T, args)
        outputs1_rref = self.net1(inputs_rref)
        outputs2_rref = self.net2(outputs1_rref)
        return outputs2_rref
        
# 3. training, run it where you want to call autograd
def train(inputs_rref, labels_rref):
    dmp = DMP()
    # torch.optim.remote creates an optimizer on every RRef destination
    optimizer = torch.optim.remote(torch.optim.SGD, dmp, lr=0.1)
    outputs_rref = dmp(inputs_rref)
    loss = loss_func(outputs_rref.to_here(), labels_rref.to_here())
    autograd_ctx_id = torch.distributed.autograd.backward(loss)
    optimizer.step(autograd_ctx_id)
    
# dev2 is whichever worker should drive training and call autograd
torch.rpc(dev2, train, inputs_rref, labels_rref)

Parameter Server Training

class ParameterServer:
    def __init__(self):
        self.params = torch.zeros(100, 100).to(0)
        
    def get_params(self) -> Tensor:
        return self.params
        
    def add_grads(self, grad: Tensor):
        self.params += grad.to(0)

def train(ps):
    for _ in range(10):
        params = torch.rpc(ps.owner(), ParameterServer.get_params, ps)
        # run forward and backward
        torch.rpc(ps.owner(), ParameterServer.add_grads, ps, params.grad)
        # TRAINER_GROUP is assumed to be a process group containing the trainers
        torch.distributed.barrier(group=TRAINER_GROUP)

ps = torch.remote("worker1", ParameterServer)
torch.remote("worker2", train, ps)
torch.remote("worker3", train, ps)
            

System Design

Distributed Autograd

Basic Idea

In the first version, dist.autograd.backward does not support RRef arguments, but RRef can still help build the autograd graph. The overall idea is as follows.

  • When calling torch.rpc or RRef.to_here(), send and recv autograd functions will be inserted to connect local autograd graphs on multiple workers into one distributed autograd graph.
  • Every distributed backward pass is assigned a globally unique id (autograd_context_id), and every participating worker will keep a dedicated context for it.
  • When the backward computation reaches a recv function, it packs the gradient and the autograd_context_id into a message and passes it to its send counterpart.
  • Upon receiving a message for a send function in the backward pass, the worker uses the autograd_context_id in the message to identify which backward pass it belongs to, and uses the gradient in the message to continue the autograd computation locally.

Send and Recv Autograd Functions

Let’s start with a simple example where there is just one synchronous RPC call and only one tensor is passed across worker boundaries. The code is shown below; in the accompanying autograd-graph figure, AccumulateGrad autograd functions for leaf nodes are omitted for simplicity.

# the add function should be 
# defined on both workers
def add() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c
    
# make RPC call from worker0
# to execute add on worker1
c1 = torch.rpc("worker1", add)
d = torch.ones_like(c1)
e = c1 * d
e.sum().backward()

The send and recv autograd functions are inserted during the forward pass, connecting the two local graphs into one distributed graph. In the backward pass, the gradient will be passed to the recv autograd function on worker0, and the recv autograd function will then transmit the gradient tensor to worker1’s send autograd function. Worker1 can then kick off its local autograd engine to resume the backward pass. There are a few more details that need to be clarified in this simple example:

  • On worker1, how do we keep the autograd graph alive after the RPC call returns?
    • In short, the distributed autograd engine on worker1 will keep a reference to the send function which can keep the graph alive.
    • Reasoning: The graph can be kept alive by keeping a reference to either tensor C or the send autograd function, as both of them hold a reference to the add autograd function. We choose to keep a reference to the send function instead of tensor C, because C, as a non-leaf node produced by add, is not needed in the backward pass. It should be freed as soon as possible. It is not memory efficient to hold C alive just because we want to have an entry point to the autograd graph.
  • In the backward pass, how does recv on worker0 find the correct send on worker1 to talk to?
    • This can be done by assigning a globally unique ID (worker_id + local send/recv id) for each send / recv function pair (see the sketch after this list).
  • When can worker1 delete its local autograd graph?
    • send should have the same lifetime as its corresponding recv function. This can be done by sending a message from worker0 to worker1 when recv is destructed on worker0. The recv function is kept alive by the loss tensor. So, conceptually, the global autograd graph will be deleted when the final loss tensor is gone.
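
A hand-wavy sketch (in the spirit of the pseudo code later in this doc, not the actual implementation) of how worker1 could keep its graph alive and pair send with recv by a globally unique id; worker_id() and next_local_seq() are assumed helpers:

# on worker1 (the callee)
send_map = {}                        # send_id -> SendFunc, keeps the local graph alive

def make_send(tensor):
    send = SendFunc()
    send.next = tensor.grad_fn       # holds `add` alive instead of holding tensor c
    send.id = (worker_id(), next_local_seq())   # globally unique send/recv pair id
    send_map[send.id] = send
    return send.id                   # shipped back to worker0 with the response

# on worker1, when worker0 reports that the matching recv was destructed
def on_recv_destructed(send_id):
    del send_map[send_id]            # worker1's local autograd graph can now be freed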

Hidden Autograd Path and Circular Dependency

Things can become complicated when an autograd graph contains multiple send/recv pairs. Consider the following example.

# all functions should be defined on all workers
def worker0_func(c2: Tensor) -> Tensor:
    g = torch.rand(2, 2)
    h = g + c2
    return h

def worker1_func_top() -> Tensor:
    a = torch.rand(2, 2)
    b = torch.rand(2, 2)
    c = a + b
    return c

def worker1_func_bottom(c: Tensor, e1: Tensor) -> Tensor:
    f = c + e1
    return f

def worker2_func(c1: Tensor) -> Tensor:
    d = torch.rand(2, 2)
    e = c1 + d
    return e

# on Worker3
c_ref = torch.remote("Worker1", worker1_func_top)
h1 = torch.rpc("Worker0", worker0_func, c_ref)
e_ref = torch.remote("Worker2", worker2_func, c_ref)
f1 = torch.rpc("Worker1", worker1_func_bottom, c_ref, e_ref)
i = h1 + f1
i.sum().backward()

This example highlights two problems that we need to address:

  • Hidden Autograd Path: The existing local autograd engine starts from the loss (or all outputs) and does a discovery/marking phase to identify all participating functions before executing the real autograd computation, so all paths in the autograd graph are known upfront. We don’t have this luxury in distributed autograd, because some parts of the autograd graph reside on remote workers. For example, when the grad arrives at send5, worker1 cannot tell whether send3 will be in the backward pass if it only looks at local information. More specifically, i.sum().backward() looks the same as f1.sum().backward() from worker1’s perspective, but the former involves send3 and the latter does not.
    • To address this problem, we propose to record all globally upstream (upstream in the forward pass, downstream in the autograd graph) send / recv pairs in the forward pass, so that we know exactly which send / recv to wait for in the backward pass.
  • Circular Dependency: there are circular dependencies between worker1 and worker2, i.e., it is impossible to finish the autograd computation on one worker before kicking it off on the other. One option is to start the autograd computation on worker1 first and have an autograd thread block there waiting for the grad of send1, but this is less than ideal.
    • To address this problem, we propose to only create the send autograd function and put it in the ready queue when the grad is received. Note that, when computing dependency count for add1, the autograd engine still takes send1 into account, so that the engine will only start computing grads for add1 after both add2 and send1 finish.

Note that we need to record information in the forward pass and do the discovery in the backward pass because we don’t know which send functions will participate in the autograd computation. However, if the application can guarantee that all send functions will receive grads in the backward pass, we can skip all this complexity and have a more efficient version. Both scenarios are useful, so we propose to have two modes:

  • Smart Mode supports running backward on a subgraph of the global autograd graph, but there will be extra overhead in both forward and backward pass.
  • Fast Mode skips dependency recording in the forward pass and graph discovery in the backward pass, but the application needs to guarantee that all send autograd function will receive grad in the backward pass.

The two sections below describe the two algorithms in more detail.

Distributed Autograd Algorithm: Smart Mode

Forward pass:

For every send x:

  1. Find the send functions in x’s lineage, by:
    1. Finding all locally reachable recv functions from send x in the autograd graph. In the example above, send2 finds recv1, send4 finds recv3, and send5 finds recv2.
    2. Using those recv functions to find the globally reachable recv functions in send x’s lineage. Note that this is possible because in step 2 below we send enough information from send to recv. In the example above, send4 knows send3, and send5 knows send1 and send2.
  2. Then, send x includes the ids of its lineage send functions in the message. Intuitively, this means that if a grad is received for send x, the backward pass must reach all send functions in its lineage as well. This helps a node determine whether it should wait for a send grad.

# pseudo code to demonstrate how send works in forward
def find_global_lineage(tensor):
    # find local lineage
    recvs = find_recvs(tensor.grad_fn)
    dep_ids = {recv.id for recv in recvs}
    # find global lineage
    dep_ids.update({dep_id for recv in recvs for dep_id in recv.dep_ids})
    return dep_ids

def send(func, tensors, on):
    msg = Message(func)
    for tensor in tensors:
        lineage = find_global_lineage(tensor)
        # connect send to autograd graph
        send = SendFunc()
        send.next = tensor.grad_fn
        # remember the send by its id
        RpcAgent.send_map[send.id] = send
        # coalesce data
        msg.data.append((tensor, send.id, lineage))
    send_msg(msg, on)
    
def recv(func, data, src):
    tensors = []
    for tensor, send_id, lineage in data:
        # use send_id as recv_id, and remember global lineage
        recv = RecvFunc(send_id, lineage)
        tensor.grad_fn = recv
        tensors.append(tensor)
        
    return func(tensors)

Backward pass:

On the node that calls torch.distributed.backward:

  1. Find all send functions in the lineage of the loss tensor. In the above example, it will be all 5 send functions. These ids will be propagated to the recv functions and will be passed to the counterpart send functions accordingly.
    1. Optimizations can be added, e.g., drop unnecessary ids in backward pass to reduce message size.

On every node:

  1. Upon receiving the first message (be it a dedicated discovery message or the grad of a send), record its autograd_context_id and retrieve all participating send ids from the message. Compute dependency counts from those send functions (and also from the loss grad_fn if the loss is on this node). Set the dependency count of each send function to 1. If any autograd function has a dependency count of 0, put it into the ready queue.
  2. Upon receiving a send grad, decrement the dependency count of that send by 1, and add it to the ready queue. Note this is done on an RpcAgent thread, and some autograd engine thread will pick up the autograd function for execution.
# pseudo code to demonstrate backward
graph_tasks = {}
def backward(loss):
    global graph_tasks
    
    autograd_context_id = gen_autograd_id()
    lineage = find_global_lineage(loss)
    # these send will participate in the autograd pass
    roots = local_sends.intersection(lineage)
        
    # propagate the autograd_id and deps info to all
    # participating workers. This is non-blocking and can
    # run concurrently with the real backward computation. 
    # This step is not absolutely necessary, but can help other
    # workers to kick off autograd earlier.
    disseminate(autograd_context_id, lineage)

    # below is a handwaving impl to show how it works with local autograd engine
    graph_task = GraphTask()
    graph_tasks[autograd_context_id] = graph_task
    roots.add(loss.grad_fn)
    # setup dependency count properly
    compute_dependencies(GraphRoot(roots), graph_task)
    # insert the task to local engine ready queue. Only the FunctionTask
    # for loss is inserted now, send FunctionTasks will be inserted later
    # when their grad becomes available.
    ready_queue.push_back(FunctionTask(graph_task, loss.grad_fn, ...))
    return autograd_context_id
    
    
def on_grad_send(send_id, grad, autograd_id):
    global graph_tasks
    graph_task = graph_tasks[autograd_id]
    send_func = RpcAgent.send_map[send_id]
    ready_queue.push_back(FunctionTask(graph_task, send_func, grad))

Distributed Autograd Algorithm: Fast Mode

The problem with the above approach is that including ids in send / recv messages incurs overhead, especially when there are a lot of tensors communicated across multiple workers. And this discovery phase is only necessary when running autograd on a subgraph of the global autograd graph. For example, f1.sum().backward() requires the discovery phase to avoid waiting for send3, but it is easier for i.sum().backward() as all send functions are involved in the backward pass. So, we propose one additional mode for distributed autograd that bypasses send / recv dependency discovery in both the forward and backward pass, provided that all send functions for non-leaf or requires_grad tensors will receive grads in the backward pass. The mode can be toggled when initializing RPC agents:

# all_requires_grad (bool): If True, the application guarantees that all 
# send functions on non-leaf or requires_grad tensors will receive grad 
# in the backward pass. Hence, we can skip the distributed dependency 
# discovery algorithm (fast mode). If False, run smart mode, where 
# messages between send/recv will contain dependency ids in both forward
# and backward pass. (default False)
torch.distributed.init_rpc(name, backend="pg", all_requires_grad=False)

Internally, RpcAgent will create a thread-local driver ID, where a driver is the worker that pieces together the autograd graph. In the above example, Worker3 is the driver. In the forward pass, every send function originating from this driver will be tagged with its thread-local driver ID, and this applies to all downstream (upstream in the autograd graph) send functions as well. This can be done either by propagating the driver ID to RPC calls recursively, or by doing an active driver ID discovery that walks the autograd graph before sending a tensor. If this information is ambiguous, e.g., one send function traces back to two upstream (downstream in the autograd graph) recv functions from two different drivers, an error will be thrown. In the backward pass, the thread-local driver id of the loss is carried through the entire autograd execution to identify participating send functions. Note that, in this mode, the application cannot keep two disjoint autograd graphs alive at the same time, as that would break the assumption that all send functions (originating from the driver) will receive grads in the backward pass.
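
A hand-wavy sketch (illustrative only) of the active driver ID discovery on a send, reusing SendFunc / find_recvs from the smart-mode pseudo code above:

def tag_driver_id(send, tensor, local_driver_id):
    # collect driver ids recorded on upstream recv functions in the local graph
    ids = {recv.driver_id for recv in find_recvs(tensor.grad_fn)}
    ids.add(local_driver_id)
    ids.discard(None)
    if len(ids) > 1:
        # the tensor traces back to recv functions from two different drivers
        raise RuntimeError("ambiguous driver id for send function")
    send.driver_id = ids.pop()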

Concurrent Distributed Backward Passes

A = torch.rand(2, 2)
B = torch.rand(2, 2)
    
# on all workers
def add() -> Tensor:
    global A, B
    return A + B
    
# on worker0
C = torch.remote("worker2", add).to_here()
C.sum().backward()

# on worker1
C = torch.remote("worker2", add).to_here()
C.sum().backward()

In the above example, there are two concurrent backward passes triggered by worker0 and worker1 respectively, and both will reach worker2. To avoid races, the distributed autograd engine will use the globally unique autograd_context_id to create a dedicated context on every participating worker. This autograd_context_id is later passed to the optimizer to apply gradients. More concretely, this works as follows:

  1. Compute all the leaf nodes in the autograd graph.
  2. As part of running distributed backwards, use the outputs parameter of the autograd engine to avoid executing AccumulateGrad for the leaf nodes we have and instead return the appropriate output_edges to execute for accumulating gradients.
  3. Store the output_edges with the autograd_context_id. This would ensure multiple backward passes won't accumulate gradients in the same context.
  4. This completes the backward pass and gradients are accumulated in the autograd engine per autograd_context_id.
  5. Now we run the optimizer on each of the worker nodes and pass the autograd_context_id to the optimizer.
  6. The optimizer applies all the gradients to the leaf nodes that we computed originally.
  7. The context and enclosing gradients should be destroyed when the autograd_context_id is destructed on the caller of backward().

Some pseudo-code to illustrate this:

optimizer = dist.optimizer(model)
loss = model(inputs)
bw_ctx_id = dist.autograd.backward(loss, timeout=60) # timeout of 60s
optimizer.step(bw_ctx_id)
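
A hand-wavy sketch (illustrative only) of how gradients could be bucketed per context, so that concurrent backward passes do not race on param.grad:

# autograd_context_id -> {leaf variable -> accumulated grad}
contexts = {}

def accumulate_grad(ctx_id, variable, grad):
    bucket = contexts.setdefault(ctx_id, {})
    bucket[variable] = bucket[variable] + grad if variable in bucket else grad

# optimizer.step(ctx_id) reads contexts[ctx_id] instead of param.grad, and the
# whole entry is dropped when ctx_id is destructed on the caller of backward()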

RRef

RRef is an important concept for building a distributed autograd graph. Each RRef is owned by a single worker (i.e., owner) and can be used by multiple users. The owner stores the real data referenced by its RRefs, and keeps track of the global reference counts for its RRefs. Every RRef can be uniquely identified by a global id ref_id, which is assigned at the time it is first created either on a user or on the owner.

The owner only keeps one RRef instance for each data object, while users can fork as many RRef instances as necessary. All usage on the owner should retrieve the RRef instance using the globally unique ref_id. A fork of an RRef will be created when it is used as an argument or return value in an RPC call, but users don't need to worry about forking/forwarding and reference counting (RC) of RRefs. These are handled transparently, and every fork also has its own fork_id, which is guaranteed to be unique across all RRef instances for the same data object.

RRef needs to support fast and scalable RPC. Hence, in the RC design, we avoid using any global master to keep RRef states. Besides, when worker X invokes RPC on worker Y, Y should be able to start immediately after receiving the RPC request, without waiting for any third-party owner Z (unless Y needs to pull real data from Z), even if neither X nor Y owns the RRef. We propose the following algorithm:

  1. If the owner is the RPC caller, the owner will update RC for the RRef accordingly.
  2. If the owner is the RPC callee, the owner will drop the new fork, and use the unique RRef id in the fork to access its singleton local RRef instance.
  3. If the RPC is between two users:
    1. The caller sends an RPC message to the callee, and also notifies the owner of the new fork.
    2. The owner, upon receiving the notification, updates its local RC and then tells the callee the new fork is now known by the owner.
    3. The callee can start executing the RPC as soon as it receives the RPC message from the caller, and does not need to wait for the message from the owner. However, it cannot delete its local RRef fork until the owner's message arrives, as sketched below.
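
A hand-wavy sketch (illustrative only; message tags and helpers are made up) of case 3, where caller X passes an RRef owned by Z to callee Y:

# on X (caller)
def pass_rref(rref, callee):
    fork = rref.fork()                                      # fresh fork_id for this use
    send_msg(Message("RPC_CALL", fork, ...), callee)                    # X -> Y
    send_msg(Message("FORK_NOTIFY", fork, parent=rref.fork_id,
                     user=callee), rref.owner())                        # X -> Z

# on Z (owner)
def on_fork_notify(fork, parent, user):
    rc[fork.ref_id].add(fork.fork_id)                       # record the new fork
    send_msg(Message("FORK_CONFIRM", fork), user)                       # Z -> Y

# on Y (callee): may start executing the RPC as soon as X's message arrives,
# but must not delete its local fork until the FORK_CONFIRM from Z has arrived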

Reference Count

The right time to delete an RRef on the owner is when there are no living forks on any user and the Python GC also agrees to delete the RRef instance on the owner. The tricky part is determining whether there are any living forks.

A user can get a fork in three situations:

  1. Receiving a fork from the owner.
  2. Receiving a fork from another user.
  3. Creating a new RRef fork owned by another worker.

#1 is the simplest case, where the owner initiates the fork and hence can easily increase its local RC. The only requirement is that any fork must notify the owner before destruction. Hence, we need the first guarantee:

  • G1. The owner will be notified when any fork is deleted.

Note that the notification might come delayed or out-of-order.

With #2 and #3, it is possible that the owner only partially knows the RRef fork graph, or does not know it at all. For example, the RRef could be constructed on a user, and before the owner receives the RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is a tree rooted at the owner, because forking an RRef always creates a new RRef instance, and hence every RRef has a parent. One nasty detail is that when an RRef is created on a user, technically the owner is not its parent, but we still consider it that way, and this does not break the argument below.

The owner's view of any node (fork) in the tree has three stages: 1) unknown → 2) known → 3) deleted, and the owner's view of the entire tree keeps changing. The owner deletes its RRef instance when it thinks there are no living forks, i.e., when every fork is either indeed deleted or unknown. Therefore, the dangerous case is when some forks are unknown and others are deleted. We only need a simple guarantee to prevent this situation:

G2. No fork x can be deleted on a user before the owner knows x’s parent fork. This works because the owner's view of x can only change from known to deleted when x's parent is known or deleted. If the parent is known, the owner will not delete the local RRef. If the parent is deleted, this rule recursively applies to the parent's parent, until it reaches the root (the owner). To implement this guarantee, we only need to make the caller include its own fork_id when notifying the owner of a new fork.

G1 and G2 guarantee correct RC, but they do not prevent a user from deleting a fork before its own prior RPC calls using that fork have finished. This should be OK, because when the caller's RPC message is deserialized, the receiver holds a reference to that RRef, preventing it from being deleted.