Roadmap (Operators) - microsoft/tensorflow-directml GitHub Wiki

With TensorFlow-DirectML, our goal is to have largely equivalent operator coverage to the existing GPU device. We prioritize ops primarily by how frequently they appear in the models we've encountered, but other considerations matter as well. The following factors can lower an operator's priority:

  • Usage: we don't have enough evidence that an op is frequently used.
  • Cost: the implementation cost (e.g. a new DirectML API) is relatively high.
  • Performance: falling back to the CPU for the op isn't expected to cause a significant performance penalty, or the penalty is amortized over a long training session.
  • Domain: the op belongs to a set of ops that target a domain we don't intend to support at the moment.
  • Generality: the op is unique to TensorFlow, is not suitable as a standalone DirectML API, and cannot easily be implemented by composing existing APIs.
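The table below tags each deprioritized op with these same factor names, sometimes with a detail in parentheses, e.g. (Future: Domain (linear algebra), Cost). As an illustration only, here is a minimal Python sketch, assuming exactly that annotation format, that turns such a note into a list of factors plus optional details. `Reason`, `split_top_level`, and `parse_reasons` are hypothetical helpers, not part of TensorFlow-DirectML.

```python
import re
from enum import Enum
from typing import List, Optional, Tuple


class Reason(Enum):
    """Deprioritization factors, named after the criteria listed above."""
    USAGE = "Usage"
    COST = "Cost"
    PERFORMANCE = "Performance"
    DOMAIN = "Domain"
    GENERALITY = "Generality"


def split_top_level(text: str) -> List[str]:
    """Split on commas that are not nested inside parentheses."""
    parts: List[str] = []
    current: List[str] = []
    depth = 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    if current:
        parts.append("".join(current).strip())
    return [p for p in parts if p]


def parse_reasons(annotation: str) -> List[Tuple[Reason, Optional[str]]]:
    """Parse a note such as '(Future: Domain (linear algebra), Cost)'."""
    body = annotation.strip()
    # Drop the outer parentheses that wrap the whole note, if present.
    if body.startswith("(") and body.endswith(")"):
        body = body[1:-1].strip()
    if body.startswith("Future:"):
        body = body[len("Future:"):].strip()
    results: List[Tuple[Reason, Optional[str]]] = []
    for part in split_top_level(body):
        # A factor name, optionally followed by a parenthesized detail.
        match = re.fullmatch(r"(\w+)\s*(?:\((.*)\))?", part)
        if match is None:
            raise ValueError(f"unrecognized factor: {part!r}")
        results.append((Reason(match.group(1)), match.group(2)))
    return results
```

For example, parse_reasons("(Future: Cost, Performance (amortizable))") yields the COST factor with no detail followed by PERFORMANCE with the detail "amortizable".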

This prioritization may change, but we hope it's useful in highlighting the specific functionality we intend to enable over the coming months. Not all ops are equally important to accelerate, and 100% parity with the CPU is not required; even the GPU device implements only a little over half (~600) of the total number of ops (~1100). As of today, DirectML implements ~480 ops.
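To put those counts in proportion, here is the arithmetic using the approximate figures quoted above. The second ratio treats DirectML's ops as roughly a subset of the GPU-registered ops, which is only approximately true, since a few ops (for example DmlNonzeroCoordinates) are DirectML-specific.

```latex
\frac{600}{1100} \approx 0.55 \qquad
\frac{480}{600} = 0.80 \qquad
\frac{480}{1100} \approx 0.44
```

In other words, the GPU device covers a little over half of all ops, and DirectML already covers roughly four fifths of the GPU device's op set.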

Operators

The table below comprises all of the TF operators that have at least one GPU kernel.

  • The first column is the name of the op. You can find more info for most of these by looking up tf.raw_ops.<name> on TensorFlow's documentation (e.g. tf.raw_ops.Bitcast).
  • The second column indicates the operator type, which is a rough grouping of ops based on functionality. See Operator Types for details.
  • The third column indicates DirectML support:
    • ✅ = the op has at least one DirectML kernel registered
    • 🚧 = the op will be implemented in the next milestone (~3 months)
    • ❌ = the op may be implemented in the future
Op Name Type DirectML Support
_ParallelConcatStart array
_ParallelConcatUpdate array
BatchMatrixBandPart array
BatchMatrixDiag array
BatchMatrixDiagPart array
BatchMatrixSetDiag array
BatchToSpace array
BatchToSpaceND array
Bitcast array
BroadcastArgs array
BroadcastGradientArgs array
BroadcastTo array
CheckNumerics array
Concat array
ConcatOffset array
ConcatV2 array
Const array
DeepCopy array
DepthToSpace array
Diag array
DiagPart array
Empty array
EnsureShape array
ExpandDims array
ExtractImagePatches array
ExtractVolumePatches array
Fill array
Gather array
GatherNd array
GatherV2 array
Identity array
InplaceAdd array ❌ (Future: Usage, Cost)
InplaceSub array ❌ (Future: Usage, Cost)
InplaceUpdate array ❌ (Future: Usage, Cost)
InvertPermutation array
LowerBound array ❌ (Future: Usage, Cost)
MatrixBandPart array
MatrixDiag array
MatrixDiagPart array
MatrixDiagPartV2 array
MatrixDiagV2 array
MatrixSetDiag array
MatrixSetDiagV2 array
MirrorPad array
MirrorPadGrad array
OneHot array
OnesLike array
Pack array
Pad array
PadV2 array
ParallelConcat array
PlaceholderWithDefault array
PreventGradient array
Rank array
RefIdentity array
Reshape array
ResourceStridedSliceAssign array
Reverse array
ReverseSequence array
ReverseV2 array
Roll array
Shape array
ShapeN array
Size array
Slice array
Snapshot array
SpaceToBatch array
SpaceToBatchND array
SpaceToDepth array
Split array
SplitV array
Squeeze array
StopGradient array
StridedSlice array
StridedSliceAssign array
StridedSliceGrad array
TensorScatterAdd array
TensorScatterSub array
TensorScatterUpdate array
TensorStridedSliceUpdate array
Tile array
Transpose array
Unique array
Unpack array
UpperBound array ❌ (Future: Usage, Cost)
Where array
ZerosLike array
BitwiseAnd bitwise
BitwiseOr bitwise
BitwiseXor bitwise
Invert bitwise
LeftShift bitwise
PopulationCount bitwise
RightShift bitwise
Angle complex ❌ (Future: Domain (DirectML does not support complex data type))
Complex complex ❌ (Future: Domain (DirectML does not support complex data type))
ComplexAbs complex ❌ (Future: Domain (DirectML does not support complex data type))
Conj complex ❌ (Future: Domain (DirectML does not support complex data type))
ConjugateTranspose complex ❌ (Future: Domain (DirectML does not support complex data type))
Imag complex ❌ (Future: Domain (DirectML does not support complex data type))
Real complex ❌ (Future: Domain (DirectML does not support complex data type))
CudnnRNN cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNBackprop cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNBackpropV2 cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNBackpropV3 cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNCanonicalToParams cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNCanonicalToParamsV2 cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNParamsSize cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNParamsToCanonical cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNParamsToCanonicalV2 cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNV2 cudnn ❌ (Future: Generality (CUDA-only op))
CudnnRNNV3 cudnn ❌ (Future: Generality (CUDA-only op))
AnonymousIterator dataset
AnonymousIteratorV2 dataset
DeleteIterator dataset
ExperimentalMapDataset dataset
ExperimentalSleepDataset dataset
GeneratorDataset dataset
IteratorFromStringHandleV2 dataset
IteratorGetNext dataset
IteratorGetNextAsOptional dataset
IteratorGetNextSync dataset
IteratorToStringHandle dataset
IteratorV2 dataset
MakeIterator dataset
OptionalFromValue dataset
OptionalGetValue dataset
OptionalHasValue dataset
OptionalNone dataset
PrefetchDataset dataset
SleepDataset dataset
UnwrapDatasetVariant dataset
WrapDatasetVariant dataset
Copy debug
CopyHost debug
DebugIdentity debug
DebugNanCount debug
DebugNumericSummary debug
CollectiveBcastRecv distributed ❌ (Future: Domain (distributed execution), Cost)
CollectiveBcastSend distributed ❌ (Future: Domain (distributed execution), Cost)
CollectiveGather distributed ❌ (Future: Domain (distributed execution), Cost)
CollectiveReduce distributed ❌ (Future: Domain (distributed execution), Cost)
ApplyKerasMomentum directml
DmlNonzeroCoordinates directml
_Arg functional
_ArrayToList functional
_DeviceArg functional
_DeviceRetval functional
_ListToArray functional
_Retval functional
_SwitchN functional
_While functional
Case functional
DeleteSessionTensor functional
DynamicPartition functional ❌ (Future: Cost, Generality)
DynamicStitch functional
EagerPyFunc functional
EmptyTensorList functional
Enter functional
Exit functional
FakeParam functional
For functional
GetSessionHandle functional
GetSessionHandleV2 functional
GetSessionTensor functional
If functional
MapClear functional
MapIncompleteSize functional
MapPeek functional
MapSize functional
MapStage functional
MapUnstage functional
MapUnstageNoKey functional
Merge functional
NextIteration functional
OrderedMapClear functional
OrderedMapIncompleteSize functional
OrderedMapPeek functional
OrderedMapSize functional
OrderedMapStage functional
OrderedMapUnstage functional
OrderedMapUnstageNoKey functional
ParallelDynamicStitch functional
PartitionedCall functional
RefEnter functional
RefExit functional
RefMerge functional
RefNextIteration functional
RefSwitch functional
RemoteCall functional
Stage functional
StageClear functional
StagePeek functional
StageSize functional
StatefulPartitionedCall functional
StatelessIf functional
StatelessWhile functional
Switch functional
SymbolicGradient functional
TensorArray functional
TensorArrayClose functional
TensorArrayCloseV2 functional
TensorArrayCloseV3 functional
TensorArrayConcat functional
TensorArrayConcatV2 functional
TensorArrayConcatV3 functional
TensorArrayGather functional
TensorArrayGatherV2 functional
TensorArrayGatherV3 functional
TensorArrayGrad functional
TensorArrayGradV2 functional
TensorArrayGradV3 functional
TensorArrayGradWithShape functional
TensorArrayPack functional
TensorArrayRead functional
TensorArrayReadV2 functional
TensorArrayReadV3 functional
TensorArrayScatter functional
TensorArrayScatterV2 functional
TensorArrayScatterV3 functional
TensorArraySize functional
TensorArraySizeV2 functional
TensorArraySizeV3 functional
TensorArraySplit functional
TensorArraySplitV2 functional
TensorArraySplitV3 functional
TensorArrayUnpack functional
TensorArrayV2 functional
TensorArrayV3 functional
TensorArrayWrite functional
TensorArrayWriteV2 functional
TensorArrayWriteV3 functional
TensorListConcat functional
TensorListConcatLists functional
TensorListConcatV2 functional
TensorListElementShape functional
TensorListFromTensor functional
TensorListGather functional
TensorListGetItem functional
TensorListLength functional
TensorListPopBack functional
TensorListPushBack functional
TensorListPushBackBatch functional
TensorListReserve functional
TensorListResize functional
TensorListScatter functional
TensorListScatterIntoExistingList functional
TensorListScatterV2 functional
TensorListSetItem functional
TensorListSplit functional
TensorListStack functional
Unstage functional
While functional
AdjustContrast image
AdjustContrastv2 image
AdjustHue image
AdjustSaturation image
CropAndResize image
CropAndResizeGradBoxes image
CropAndResizeGradImage image
HSVToRGB image
NonMaxSuppressionV2 image ❌ (Future: Usage, Cost)
NonMaxSuppressionV3 image ❌ (Future: Usage, Cost)
ResizeBilinear image
ResizeBilinearGrad image
ResizeNearestNeighbor image
ResizeNearestNeighborGrad image
RGBToHSV image
_CopyFromGpuToHost internal
_CopyFromHostToGpu internal
_If internal
_ScopedAllocator internal ❌ (Future: Usage)
_ScopedAllocatorConcat internal ❌ (Future: Usage)
_ScopedAllocatorSplit internal ❌ (Future: Usage)
BatchMatrixTriangularSolve linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
BatchSvd linalg ❌ (Future: Domain (linear algebra), Cost)
Cholesky linalg ❌ (Future: Domain (linear algebra), Cost)
Einsum linalg ❌ (Future: Domain (linear algebra), Cost)
LogMatrixDeterminant linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
Lu linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
MatrixDeterminant linalg ❌ (Future: Domain (linear algebra), Cost)
MatrixInverse linalg ❌ (Future: Domain (linear algebra), Cost)
MatrixSolve linalg ❌ (Future: Domain (linear algebra), Cost)
MatrixTriangularSolve linalg ❌ (Future: Domain (linear algebra), Cost)
Qr linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
SelfAdjointEigV2 linalg ❌ (Future: Domain (linear algebra), Cost)
Svd linalg ❌ (Future: Domain (linear algebra), Cost)
TridiagonalMatMul linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
TridiagonalSolve linalg ❌ (Future: Domain (linear algebra), Cost, Usage)
Abs math
Acos math
Acosh math
Add math
AddN math
AddV2 math
All math
Any math
ApproximateEqual math
ArgMax math
ArgMin math
Asin math
Asinh math
Atan math
Atan2 math
Atanh math
BatchMatMul math
BatchMatMulV2 math
Bincount math ❌ (Future: Usage, Cost)
Bucketize math ❌ (Future: Usage, Cost)
Cast math
Ceil math
ClipByValue math
CompareAndBitpack math
Cos math
Cosh math
Cross math
Cumprod math
Cumsum math
CumulativeLogsumexp math ❌ (Future: Usage, Cost)
Div math
DivNoNan math
Equal math
Erf math
Erfc math
EuclideanNorm math
Exp math
Expm1 math
Floor math
FloorDiv math
FloorMod math
Greater math
GreaterEqual math
HistogramFixedWidth math ❌ (Future: Usage, Cost)
Inv math
InvGrad math
IsFinite math
IsInf math
IsNan math
Less math
LessEqual math
LinSpace math
Log math
Log1p math
LogicalAnd math
LogicalNot math
LogicalOr math
MatMul math
Max math
Maximum math
Mean math
Min math
Minimum math
Mod math
Mul math
MulNoNan math
Neg math
NextAfter math ❌ (Future: Usage, Cost)
NotEqual math
Pow math
Prod math
Range math
RealDiv math
Reciprocal math
ReciprocalGrad math
Rint math
Round math
Rsqrt math
RsqrtGrad math
SegmentSum math ❌ (Future: Generality, Cost)
Select math
SelectV2 math
Sigmoid math
SigmoidGrad math
Sign math
Sin math
Sinh math
Sqrt math
SqrtGrad math
Square math
SquaredDifference math
Sub math
Sum math
Tan math
Tanh math
TanhGrad math
TruncateDiv math
TruncateMod math
UnsortedSegmentMax math ❌ (Future: Generality, Cost)
UnsortedSegmentMin math ❌ (Future: Generality, Cost, Usage)
UnsortedSegmentProd math ❌ (Future: Generality, Cost, Usage)
UnsortedSegmentSum math ❌ (Future: Generality, Cost)
Xdivy math
Xlogy math
_FusedBatchNormEx nn
_FusedConv2D nn
AvgPool nn
AvgPool3D nn
AvgPool3DGrad nn
AvgPoolGrad nn
BatchNormWithGlobalNormalization nn
BatchNormWithGlobalNormalizationGrad nn
BiasAdd nn
BiasAddGrad nn
BiasAddV1 nn
Conv2D nn
Conv2DBackpropFilter nn
Conv2DBackpropInput nn
Conv3D nn
Conv3DBackpropFilter nn
Conv3DBackpropFilterV2 nn
Conv3DBackpropInput nn
Conv3DBackpropInputV2 nn
DataFormatDimMap nn
DataFormatVecPermute nn
DepthwiseConv2dNative nn
DepthwiseConv2dNativeBackpropFilter nn
DepthwiseConv2dNativeBackpropInput nn
Dilation2D nn ❌ (Future: Usage, Cost)
Dilation2DBackpropFilter nn ❌ (Future: Usage, Cost)
Dilation2DBackpropInput nn ❌ (Future: Usage, Cost)
Elu nn
EluGrad nn
FusedBatchNorm nn
FusedBatchNormGrad nn
FusedBatchNormGradV2 nn
FusedBatchNormGradV3 nn
FusedBatchNormV2 nn
FusedBatchNormV3 nn
InTopKV2 nn ❌ (Future: Usage, Cost)
L2Loss nn
LeakyRelu nn
LeakyReluGrad nn
LogSoftmax nn
LRN nn
LRNGrad nn
MaxPool nn
MaxPool3D nn
MaxPool3DGrad nn
MaxPool3DGradGrad nn ❌ (Future: Usage, Cost)
MaxPoolGrad nn
MaxPoolGradGrad nn ❌ (Future: Usage, Cost)
MaxPoolGradGradV2 nn ❌ (Future: Usage, Cost)
MaxPoolGradGradWithArgmax nn ❌ (Future: Usage, Cost)
MaxPoolGradV2 nn
MaxPoolGradWithArgmax nn ❌ (Future: Usage, Cost)
MaxPoolV2 nn
MaxPoolWithArgmax nn ❌ (Future: Usage, Cost)
Relu nn
Relu6 nn
Relu6Grad nn
ReluGrad nn
Selu nn
SeluGrad nn
Softmax nn
SoftmaxCrossEntropyWithLogits nn
Softplus nn
SoftplusGrad nn
Softsign nn
SoftsignGrad nn
SparseSoftmaxCrossEntropyWithLogits nn
TopK nn
TopKV2 nn
FakeQuantWithMinMaxArgs quantization ❌ (Future: Domain (quantization), Cost)
FakeQuantWithMinMaxArgsGradient quantization ❌ (Future: Domain (quantization), Cost)
FakeQuantWithMinMaxVars quantization ❌ (Future: Domain (quantization), Cost)
FakeQuantWithMinMaxVarsGradient quantization ❌ (Future: Domain (quantization), Cost)
FakeQuantWithMinMaxVarsPerChannel quantization ❌ (Future: Domain (quantization), Cost)
FakeQuantWithMinMaxVarsPerChannelGradient quantization ❌ (Future: Domain (quantization), Cost)
QuantizeAndDequantize quantization ❌ (Future: Domain (quantization), Cost)
QuantizeAndDequantizeV2 quantization ❌ (Future: Domain (quantization), Cost)
QuantizeAndDequantizeV3 quantization ❌ (Future: Domain (quantization), Cost)
Multinomial random ❌ (Future: Cost, Performance (amortizable))
ParameterizedTruncatedNormal random ❌ (Future: Cost, Performance (amortizable))
RandomGammaGrad random ❌ (Future: Cost, Performance (amortizable))
RandomStandardNormal random ❌ (Future: Cost, Performance (amortizable))
RandomUniform random
RandomUniformInt random
RngSkip random ❌ (Future: Cost, Performance (amortizable))
StatefulStandardNormalV2 random ❌ (Future: Cost, Performance (amortizable))
StatefulTruncatedNormal random ❌ (Future: Cost, Performance (amortizable))
StatefulUniform random ❌ (Future: Cost, Performance (amortizable))
StatefulUniformFullInt random ❌ (Future: Cost, Performance (amortizable))
StatefulUniformInt random ❌ (Future: Cost, Performance (amortizable))
StatelessMultinomial random ❌ (Future: Cost, Performance (amortizable))
StatelessRandomNormal random ❌ (Future: Cost, Performance (amortizable))
StatelessRandomUniform random
StatelessRandomUniformInt random
StatelessTruncatedNormal random ❌ (Future: Cost, Performance (amortizable))
TruncatedNormal random ❌ (Future: Cost, Performance (amortizable))
BlockLSTM rnn
BlockLSTMGrad rnn
BlockLSTMGradV2 rnn
BlockLSTMV2 rnn
GRUBlockCell rnn
GRUBlockCellGrad rnn
LSTMBlockCell rnn
LSTMBlockCellGrad rnn
BatchFFT signal ❌ (Future: Domain (signal processing), Cost)
BatchFFT2D signal ❌ (Future: Domain (signal processing), Cost)
BatchFFT3D signal ❌ (Future: Domain (signal processing), Cost)
BatchIFFT signal ❌ (Future: Domain (signal processing), Cost)
BatchIFFT2D signal ❌ (Future: Domain (signal processing), Cost)
BatchIFFT3D signal ❌ (Future: Domain (signal processing), Cost)
FFT signal ❌ (Future: Domain (signal processing), Cost)
FFT2D signal ❌ (Future: Domain (signal processing), Cost)
FFT3D signal ❌ (Future: Domain (signal processing), Cost)
IFFT signal ❌ (Future: Domain (signal processing), Cost)
IFFT2D signal ❌ (Future: Domain (signal processing), Cost)
IFFT3D signal ❌ (Future: Domain (signal processing), Cost)
IRFFT signal ❌ (Future: Domain (signal processing), Cost)
IRFFT2D signal ❌ (Future: Domain (signal processing), Cost)
IRFFT3D signal ❌ (Future: Domain (signal processing), Cost)
RFFT signal ❌ (Future: Domain (signal processing), Cost)
RFFT2D signal ❌ (Future: Domain (signal processing), Cost)
RFFT3D signal ❌ (Future: Domain (signal processing), Cost)
SparseTensorDenseMatMul sparse ❌ (Future: Domain (sparse tensors), Usage)
BesselI0e special ❌ (Future: Domain (special functions), Cost, Usage)
BesselI1e special ❌ (Future: Domain (special functions), Cost, Usage)
Betainc special ❌ (Future: Domain (special functions), Cost, Usage)
Digamma special ❌ (Future: Domain (special functions), Cost, Usage)
Igamma special ❌ (Future: Domain (special functions), Cost, Usage)
Igammac special ❌ (Future: Domain (special functions), Cost, Usage)
IgammaGradA special ❌ (Future: Domain (special functions), Cost, Usage)
Lgamma special ❌ (Future: Domain (special functions), Cost)
Polygamma special ❌ (Future: Domain (special functions), Cost, Usage)
Zeta special ❌ (Future: Domain (special functions), Cost, Usage)
_ReadVariablesOp state
_VarHandlesOp state
Assign state
AssignAdd state
AssignAddVariableOp state
AssignSub state
AssignSubVariableOp state
AssignVariableOp state
ConsumeMutexLock state
DestroyResourceOp state
DestroyTemporaryVariable state
IsVariableInitialized state
MutexLock state
MutexV2 state
ReadVariableOp state
ResourceGather state
ResourceGatherNd state
ResourceScatterAdd state
ResourceScatterDiv state
ResourceScatterMax state
ResourceScatterMin state
ResourceScatterMul state
ResourceScatterNdAdd state
ResourceScatterNdSub state
ResourceScatterNdUpdate state
ResourceScatterSub state
ResourceScatterUpdate state
ScatterAdd state
ScatterDiv state
ScatterMax state
ScatterMin state
ScatterMul state
ScatterNd state
ScatterNdAdd state
ScatterNdNonAliasingAdd state
ScatterNdSub state
ScatterNdUpdate state
ScatterSub state
ScatterUpdate state
TemporaryVariable state
VarHandleOp state
Variable state
VariableShape state
VariableV2 state
VarIsInitializedOp state
ApplyAdadelta training
ApplyAdagrad training
ApplyAdagradV2 training
ApplyAdam training
ApplyAdaMax training
ApplyAddSign training
ApplyCenteredRMSProp training
ApplyGradientDescent training
ApplyMomentum training
ApplyPowerSign training
ApplyRMSProp training
ResourceApplyAdadelta training
ResourceApplyAdagrad training
ResourceApplyAdagradV2 training
ResourceApplyAdam training
ResourceApplyAdaMax training
ResourceApplyAdamWithAmsgrad training
ResourceApplyAddSign training
ResourceApplyCenteredRMSProp training
ResourceApplyGradientDescent training
ResourceApplyKerasMomentum training
ResourceApplyMomentum training
ResourceApplyPowerSign training
ResourceApplyRMSProp training
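Each row above consists of the op name, its type, an optional status symbol, and an optional parenthesized note. For anyone who wants to tally coverage from a plain-text copy of this table, here is a minimal Python sketch, assuming exactly that whitespace-separated layout. `OpRow`, `STATUS`, and `parse_row` are hypothetical names for illustration, not part of any TensorFlow or DirectML API.

```python
from typing import NamedTuple, Optional

# Status symbols from the DirectML Support column, mapped to short labels.
STATUS = {
    "✅": "supported",       # at least one DirectML kernel registered
    "🚧": "next milestone",  # planned for the next milestone (~3 months)
    "❌": "future",          # may be implemented in the future
}


class OpRow(NamedTuple):
    name: str
    op_type: str
    status: Optional[str]  # None when the row carries no status symbol
    note: Optional[str]    # e.g. "(Future: Usage, Cost)"


def parse_row(line: str) -> OpRow:
    """Parse one 'Name type [symbol] [(note)]' row of the table."""
    # At most 4 fields: the note keeps its internal spaces intact.
    fields = line.split(maxsplit=3)
    if len(fields) < 2:
        raise ValueError(f"not a table row: {line!r}")
    name, op_type = fields[0], fields[1]
    status = None
    if len(fields) > 2:
        if fields[2] not in STATUS:
            raise ValueError(f"unknown status symbol: {fields[2]!r}")
        status = STATUS[fields[2]]
    note = fields[3] if len(fields) > 3 else None
    return OpRow(name, op_type, status, note)
```

Rows that carry no symbol come back with status=None rather than being assumed supported, so any tally built on this sketch reflects only what the text actually states.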

Operator Types

The operator type in the table above is a rough grouping of ops based on functionality. These labels are unofficial and serve only to organize the table. Many of the names are derived from the names of the kernel source files.

Op Type Purpose
array Tensor transforms such as tile, reshape, slice, join, split, broadcast, generate, etc.
bitwise Classic unary and binary bitwise ops: and, xor, not, bit count, etc.
complex Ops to support tensors that store complex numbers.
cudnn Ops specific to cuDNN (no CPU implementation).
dataset Training/inference data pipeline primitives for efficient storage, retrieval, and batching.
debug Helpers for working with tfdbg (TensorFlow debugger).
distributed Distributed/remote execution, including multi-device and multi-node operation.
directml Ops specific to DirectML (no CPU implementation).
functional Constructing dynamic graphs with higher-level primitives (tf.function, while_loop, AutoGraph).
image Encoding, decoding, and working with image formats. Includes color space transforms, cropping, etc.
internal Internal-only memory transfer optimization ops.
linalg Classic linear algebra routines: determinants, decomposition, solving systems of equations, etc.
math General-purpose compute ops: unary, binary, trig functions, reductions, norms, etc.
nn Neural-network ops: pooling, conv, normalization, activation, etc.
quantization Ops to support quantization: fake quantization and quantize/dequantize round trips.
random Random number generation and distributions: uniform, normal, truncated normal, binomial, etc.
rnn Recurrent neural-network ops, including GRU and LSTM.
signal Signal processing ops for converting between time and frequency domains (FFTs).
special Special mathematical functions: gamma-related functions, Bessel functions, zeta, etc.
state Graph execution state: variables, resources, mutexes, etc.
training Optimizers (e.g. SGD, Adam) and training-specific support ops.