# vmap onboarding lab
The goal of this lab is to give first-hand experience with writing vmap rules (a.k.a. batching rules) and adding them to PyTorch. An alternative to this lab is taking on real tasks to add batching rules to PyTorch.
The deliverable is a stack of PRs (which won't be merged into master) containing all of the code for the sections below. The task is considered finished when the reviewers accept the PRs; you can close them at that point.
For this lab, we'll be writing a batching rule for a new operator, `simple_mul`, that is similar to `torch.mul` but has more restrictions on its input types. Concretely, `simple_mul` looks like the following:

```python
def simple_mul(x: Tensor, y: Tensor):
    return torch.mul(x, y)
```
For the sake of learning, please do not read the existing batching rule for `torch.mul`. When the lab is complete, the following test cases should work:
```python
import torch
from functorch import vmap

# The dimensions being vmapped over
B = 2
B1 = 3

for op in [torch.mul, torch.simple_mul]:
    # A) Simple case
    x = torch.randn(B)
    y = torch.randn(B)
    vmap(op)(x, y)

    # B) vmap over some Tensors
    x = torch.randn(3)
    y = torch.randn(B)
    vmap(op, in_dims=(None, 0))(x, y)

    # C) More complicated case
    x = torch.randn(4, 3)
    y = torch.randn(3, B)
    vmap(op, in_dims=(None, 1))(x, y)

    # D) Nested vmap
    x = torch.randn(B)
    y = torch.randn(B1)
    vmap(vmap(op, (0, None)), (None, 0))(x, y)
```
For example, in case A, `vmap(torch.mul)(x, y)` can be rewritten as `torch.mul(x, y)`.
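Cases B and C can likewise be rewritten without `vmap`. The sketch below (not part of the lab deliverables; shapes follow the test cases above) shows manual equivalents, which hint at the general recipe of moving the batch dimension to the front and letting broadcasting handle the rest:

```python
import torch

B = 2

# Case B: x is unbatched, y is vmapped over dim 0.
x = torch.randn(3)
y = torch.randn(B)
# vmap(torch.mul, in_dims=(None, 0))(x, y) has shape (B, 3);
# unsqueezing y to (B, 1) makes plain broadcasting produce the same result.
out_b = torch.mul(x, y.unsqueeze(-1))
assert out_b.shape == (B, 3)

# Case C: y is vmapped over dim 1.
x = torch.randn(4, 3)
y = torch.randn(3, B)
# Move y's batch dim to the front, then pad to (B, 1, 3) so its
# non-batch dims right-align with x under broadcasting.
out_c = torch.mul(x, y.movedim(1, 0).unsqueeze(1))
assert out_c.shape == (B, 4, 3)
```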
Write a function in Python with the following signature:

```python
def mul_batched(in_dims: Tuple[Optional[int], Optional[int]], x: Tensor, y: Tensor) -> Tensor:
    pass
```

`mul_batched(in_dims, x, y)` should return the same thing as `vmap(simple_mul, in_dims)(x, y)`. Do not use `vmap` in the definition of `mul_batched`. Test it out with the above cases.
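If you want to check your work, here is one possible sketch of `mul_batched` (derive your own first!). It assumes the "move the batch dim to the front, then pad with singleton dims" strategy; other correct formulations exist:

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


def mul_batched(in_dims: Tuple[Optional[int], Optional[int]], x: Tensor, y: Tensor) -> Tensor:
    x_bdim, y_bdim = in_dims
    # "Logical" rank: the rank each tensor has from the perspective of the
    # function being vmapped, i.e. not counting the batch dimension.
    x_logical_rank = x.dim() - (0 if x_bdim is None else 1)
    y_logical_rank = y.dim() - (0 if y_bdim is None else 1)
    max_logical_rank = max(x_logical_rank, y_logical_rank)

    def prepare(t: Tensor, bdim: Optional[int], logical_rank: int) -> Tensor:
        if bdim is None:
            # An unbatched tensor already right-aligns correctly under
            # broadcasting; leave it as-is.
            return t
        # Move the batch dim to the front, then insert singleton dims right
        # after it so the remaining dims right-align with the other operand.
        t = t.movedim(bdim, 0)
        for _ in range(max_logical_rank - logical_rank):
            t = t.unsqueeze(1)
        return t

    x = prepare(x, x_bdim, x_logical_rank)
    y = prepare(y, y_bdim, y_logical_rank)
    return torch.mul(x, y)
```

You can sanity-check it against `vmap(torch.mul, in_dims)(x, y)` with `torch.allclose` on the cases above.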
Write a brand new `CompositeImplicitAutograd` function in `native_functions.yaml` called `simple_mul`.
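As a loose sketch (the exact schema details in `native_functions.yaml` evolve; treat this as an assumption to adapt, not the definitive entry), a minimal entry could look like:

```yaml
- func: simple_mul(Tensor self, Tensor other) -> Tensor
  variants: function
```

Entries without an explicit `dispatch:` section default to `CompositeImplicitAutograd`, and the corresponding C++ kernel you write would simply call `at::mul(self, other)` (plus whatever input-type checks you want).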
Add an `OpInfo` entry for it to test it. Don't use `BinaryUfuncInfo` as the constructor (that's beyond the scope of this lab); just use `OpInfo` and provide the above test cases (and more, if you wish) as the sample inputs.
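A hedged sketch of what this might look like (the helper name `sample_inputs_simple_mul` is made up for illustration; the conventions in `torch/testing/_internal/common_methods_invocations.py` in your checkout take precedence):

```python
from functools import partial

from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput


def sample_inputs_simple_mul(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       requires_grad=requires_grad)
    # Broadcastable shape pairs; the vmap tests add batch dims on top of these.
    shapes = [((2,), (2,)), ((4, 3), (3,)), ((4, 3), (4, 3))]
    return [SampleInput(make_arg(x_shape), args=(make_arg(y_shape),))
            for x_shape, y_shape in shapes]


# Added to op_db:
OpInfo(
    'simple_mul',
    dtypes=floating_types(),
    sample_inputs_func=sample_inputs_simple_mul,
    supports_out=False,
)
```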
When this step is done, you'll have a `torch.simple_mul` operator that is callable from Python, as well as an OpInfo test for it.
Run

```
pytest test/functorch/test_vmap.py -v -k "simple_mul"
```

to verify that the OpInfo is correctly hooked up to our vmap testing.
Mark the operator as `CompositeExplicitAutograd` and add vmap support for `torch.simple_mul` in `BatchRulesBinaryOps.cpp`. The goal is to get the following tests to pass:

- `TestVmapOperators.test_vmap_exhaustive_simple_mul*`
- `TestVmapOperators.test_op_has_batch_rule_simple_mul*`
Please do this step by adding the following C++ function, which should be a transcription of your Python `mul_batched` into C++.

```cpp
std::tuple<Tensor, optional<int64_t>> simple_mul_batch_rule(
    const Tensor& x, optional<int64_t> x_bdim,
    const Tensor& y, optional<int64_t> y_bdim) {
  // your code here.
}
```
Avoid using the `BINARY_POINTWISE` macro (that would solve the problem trivially).
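If you get stuck, here is one hedged transcription, assuming the helpers `rankWithoutBatchDim`, `moveBatchDimToFront`, and `maybePadToLogicalRank` from `BatchRulesHelper.h` are available in your checkout:

```cpp
// In BatchRulesBinaryOps.cpp; the helpers come from BatchRulesHelper.h.
std::tuple<Tensor, optional<int64_t>> simple_mul_batch_rule(
    const Tensor& x, optional<int64_t> x_bdim,
    const Tensor& y, optional<int64_t> y_bdim) {
  // Logical rank: the rank each tensor has from the perspective of the
  // function being vmapped, i.e. not counting the batch dim.
  auto x_logical_rank = rankWithoutBatchDim(x, x_bdim);
  auto y_logical_rank = rankWithoutBatchDim(y, y_bdim);
  auto max_logical_rank = std::max(x_logical_rank, y_logical_rank);

  // Move any batch dims to the front...
  auto x_ = moveBatchDimToFront(x, x_bdim);
  auto y_ = moveBatchDimToFront(y, y_bdim);
  // ...then insert singleton dims after the batch dim so the remaining
  // dims right-align under broadcasting.
  x_ = maybePadToLogicalRank(x_, x_bdim, max_logical_rank);
  y_ = maybePadToLogicalRank(y_, y_bdim, max_logical_rank);

  // The batch rule only runs when at least one input is batched, so the
  // output's batch dim sits at the front.
  return std::make_tuple(at::mul(x_, y_), 0);
}
```

You'd then register the rule, e.g. with the `VMAP_SUPPORT(simple_mul, simple_mul_batch_rule)` macro inside the file's `TORCH_LIBRARY_IMPL` block (the exact registration spelling depends on your checkout).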