# Mini Sequence Labeler: PyTorch vs DyNet
## Overview
- Link to original PyTorch script: https://gist.github.com/hal3/8c170c4400576eb8d0a8bd94ab231232
- Link to original DyNet script: https://gist.github.com/neubig/7d874d327e7d700bb609479f6cf359bb
Initial results:
Run | Time (s) |
---|---|
DyNet | 1.206 |
PyTorch | 13.047 |
PyTorch (OMP_NUM_THREADS=1) | 8.36 |
## Investigation
Ok, so why are we so slow? Let's break down the time spent in each operation, starting with the forward.
The following table lists the cumulative running time if we continue after each labeled stage (these are different runs, so they don't match up exactly with the numbers above).

- Everything but `total` only includes the forward
- The DyNet runs include an `npvalue` call at the last stage because of DyNet's lazy evaluation
- The PyTorch runs are with OMP_NUM_THREADS=1
Run | preprocess | concat | affine_t | rectify | affine_transform | arg_min | loss | forward | total |
---|---|---|---|---|---|---|---|---|---|
DyNet | 0.52 | 0.73 | 0.78 | 0.82 | 0.82 | 0.87 | 0.95 | 1.00 | 1.29 |
PyTorch | 0.64 | 0.78 | 1.27 | 1.44 | 1.76 | 1.93 | 2.37 | 2.91 | 8.71 |
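
One way to collect numbers like these is sketched below: run the same loop once per labeled stage, cut the forward short after that stage, and record the total wall-clock time. The ops here are tiny stand-ins rather than the actual labeler, so treat this as an illustration of the setup, not the real script.

```python
import time
import torch

STAGES = ["preprocess", "concat", "affine", "rectify"]

def timed_run(stop_at, n_examples=1000):
    """Run a toy loop, cutting the forward short after the `stop_at` stage."""
    emb = torch.randn(100, 50)
    w, b = torch.randn(50, 150), torch.randn(50)
    start = time.perf_counter()
    for _ in range(n_examples):
        x = emb[torch.randint(0, 100, (3,))]   # "preprocess": embedding lookup
        if stop_at == "preprocess":
            continue
        h = torch.cat([x[0], x[1], x[2]])      # "concat"
        if stop_at == "concat":
            continue
        a = w @ h + b                          # "affine"
        if stop_at == "affine":
            continue
        torch.relu(a)                          # "rectify"
    return time.perf_counter() - start

for stage in STAGES:
    print(f"{stage}: {timed_run(stage):.3f}s")
```
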
### Some observations
- PyTorch really starts to fall behind when we get to affine_transform
- Most of the overhead is in backward + optimize, where PyTorch takes ~5.71 seconds compared to ~0.3 seconds for DyNet.
## Backward overhead
Let's get cumulative times for the various backward methods. I basically just recorded the time each call took in the backward dispatch of each framework. Note that DyNet doesn't include an AccumulateGrad equivalent because that's done in a separate loop, which I didn't annotate.
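
(For anyone reproducing the PyTorch side without patching the dispatch: the autograd profiler gives a similar per-function breakdown. The toy model below is an arbitrary stand-in, not the labeler.)

```python
import torch
from torch import nn
from torch.autograd import profiler

# Arbitrary stand-in model just to exercise a few backward functions.
model = nn.Sequential(nn.Linear(150, 50), nn.ReLU(), nn.Linear(50, 6))
x = torch.randn(32, 150, requires_grad=True)

with profiler.profile() as prof:
    for _ in range(100):
        loss = model(x).pow(2).mean()
        loss.backward()

# Aggregated CPU time per op, including the *Backward functions.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```
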
DyNet function | Cumulative Time (s) |
---|---|
AffineTransform | 0.213455 |
Concatenate | 0.0314116 |
Tanh | 0.0314116 |
SquaredEuclideanDistance | 0.0122211 |
CwiseMultiply | 0.00968469 |
Sum | 0.00484226 |
LogisticSigmoid | 0.00293803 |
Rectify | 0.00263496 |
PickBatchElements | 0.00236525 |
ConstantMinus | 0.00158334 |
PyTorch function | Cumulative Time (s) |
---|---|
AddmmBackward | 2.41561 |
ConcatBackward | 0.527592 |
EmbeddingBackward | 0.398528 |
MSELossBackward | 0.385419 |
ThresholdBackward | 0.332549 |
AddBackward | 0.308445 |
IndexBackward | 0.305842 |
TransposeBackward | 0.231972 |
MulBackward | 0.0893593 |
ChunkBackward | 0.0450809 |
SigmoidBackward | 0.0433337 |
SubBackward | 0.0280001 |
TanhBackward | 0.0200974 |
ViewBackward | 0.0142328 |
AccumulateGrad | 0.00393635 |
AddConstantBackward | 0.000338829 |
Graph Root | 0.000133762 |
As expected, PyTorch is slower here, and most of the time is spent in `AddmmBackward`, which is what `Linear` ends up calling. So let's look at that next.
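
To see why Addmm dominates: for a 2-D input, `Linear`'s forward is a single addmm (bias + input @ weight.T), so its backward shows up as the Addmm backward node. A quick sanity check (just a demonstration, not the library source):

```python
import torch
from torch import nn

lin = nn.Linear(150, 50)
x = torch.randn(32, 150)

# For a 2-D input, Linear computes bias + x @ weight.T, i.e. one addmm call,
# so the output's grad_fn in the autograd graph is an Addmm backward node.
out = lin(x)
via_addmm = torch.addmm(lin.bias, x, lin.weight.t())

print(torch.allclose(out, via_addmm, atol=1e-6))  # True
print(out.grad_fn)                                # an Addmm backward node
```
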
## Linear overhead
For investigating Linear, https://github.com/pytorch/pytorch/issues/2560 includes a simple benchmark -- let's look at that. I did some binary search over commits:
Commit | Time (s) |
---|---|
master | 5.74 |
v0.12 | 1.94 |
added twice differentiation for a bunch of ops | 3.94 |
added twice differentiation for a bunch of ops~1 | 1.62 |
Use torch.matmul in nn.Linear | 5.60 |
Use torch.matmul in nn.Linear~1 | 4.10 |
Those two commits basically represent the entire slowdown and suggest the issue is python overhead: twice differentiation causes the backward to invoke python `Variable` functions (which parse arguments in python) rather than `Tensor` functions. Using torch.matmul in nn.Linear also increases the python overhead (it ends up calling the autograd Addmm Function, which has similar issues).
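
(I won't reproduce the issue's benchmark here, but a micro-benchmark in the same spirit looks like the sketch below: many small Linear forward+backward calls, where per-call framework overhead rather than matrix math dominates. Sizes are invented.)

```python
import time
import torch
from torch import nn

# Small sizes on purpose: per-call overhead, not the matrix math, dominates.
lin = nn.Linear(64, 64)
x = torch.randn(1, 64)

start = time.perf_counter()
for _ in range(10000):
    lin(x).sum().backward()
    lin.zero_grad()
print(f"{time.perf_counter() - start:.3f}s for 10000 forward+backward calls")
```
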
## Larger model
So, if (python) overhead is an issue, PyTorch should do relatively better on a larger model. Let's take the example from here (https://twitter.com/haldaume3/status/900723511274196992) and set n_labels to 50000.
Results:
Run | Time (s) |
---|---|
DyNet | 105.84 |
PyTorch | 46.60 |
PyTorch (OMP_NUM_THREADS=1) | 78.80 |
Note that this is inconsistent with the numbers in the Twitter post, so we should investigate further; but it does suggest overhead is the main issue.
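
The intuition: per-op overhead is roughly constant per call, while the useful work grows with the output size, so a 50000-label output layer amortizes that overhead far better than a tiny one. A rough, self-contained illustration (toy sizes, not the script's):

```python
import time
import torch
from torch import nn

def time_output_layer(n_labels, iters=2000):
    lin = nn.Linear(128, n_labels)
    x = torch.randn(1, 128)
    start = time.perf_counter()
    for _ in range(iters):
        lin(x).sum().backward()
        lin.zero_grad()
    return time.perf_counter() - start

# Per-call overhead is ~constant, so time grows far more slowly than n_labels.
for n in (6, 50000):
    print(f"n_labels={n}: {time_output_layer(n):.3f}s")
```
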
## Reducing the overhead
The most obvious source of overhead with the above commits is python overhead with Variables. To give a rough estimate of how much we can reduce this, I wrote a (forward-only) C++ autograd function and benchmark here. (Ideally we'd implement Addmm/Linear, but that would take significantly more work just to get an estimate.) Here are the results:
Run | Time (s) for 100000 iterations | Variable overhead (s) |
---|---|---|
addmm tensors | 0.150 | 0 |
addmm variables | 1.086 | 0.936 |
add tensors | 0.082 | 0 |
add variables (python) | 0.515 | 0.433 |
add variables (C++) | 0.178 | 0.096 |
So, the C++ code is ~3x faster than the python code and the Variable overhead is reduced by ~4.5x; given that the Variable overhead is higher for Addmm/Linear than for Add, and that I spent no time optimizing this, it seems likely we can get more than 3x/4.5x out of moving the autograd functions to C++. Even with pessimistic assumptions, reducing the Variable overhead of the Linear benchmark by 4.5x would leave us at:
Version | Time (s) |
---|---|
master (1/3/18) | 2.60 |
master (11/28/17) | 3.22 |
master (C++) estimated | 2.78 |
and if those gains are consistent for the backward (we'd have to measure the non-Variable forward to estimate the forward overhead), we could reduce the backward from ~5.8 seconds currently to ~1.9 seconds. That's still slower than DyNet currently is, but this is with conservative assumptions and without estimating gains for the forward.
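
For reference, the tensors-vs-Variables comparison above is a micro-benchmark of roughly this shape. In current PyTorch, where Tensor and Variable are merged, the closest analogue is timing addmm with autograd tracking off vs. on; this is an approximation, not the original benchmark:

```python
import time
import torch

b = torch.randn(16, 16)
m2 = torch.randn(16, 16)

def bench_addmm(track_grad, iters=100000):
    # With track_grad=True each call also builds (and frees) an autograd node,
    # which stands in for the old Variable-vs-Tensor overhead comparison.
    m1 = torch.randn(16, 16, requires_grad=track_grad)
    start = time.perf_counter()
    for _ in range(iters):
        torch.addmm(b, m1, m2)
    return time.perf_counter() - start

print(f"tensors only     : {bench_addmm(False):.3f}s")
print(f"autograd tracking: {bench_addmm(True):.3f}s")
```
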
## Plan
The current short-term plan is to make it easier to write autograd functions in C++, utilizing ATen. I just opened https://github.com/zdevito/ATen/pull/49 to add broadcasting, so that moving autograd functions from python to C++ will be as seamless as possible.