oneDNN INT8 Ops

The onednn qlinear ops dispatch to a single C++ kernel in aten/src/ATen/native/quantized/cpu/qlinear.cpp:

static at::Tensor linear_int8_with_onednn_weight
/* Supported cases for binary post op:
  +-------------------+--------------+---------------+
  | Extra input dtype | Output dtype | Post op       |
  +-------------------+--------------+---------------+
  | Fp32/bf16         | fp32/bf16    | sum           |
  +-------------------+--------------+---------------+
  | Fp32/bf16         | int8         | add           |
  +-------------------+--------------+---------------+
  | int8              | fp32/bf16    | not supported |
  +-------------------+--------------+---------------+
  | int8              | int8         | sum           |
  +-------------------+--------------+---------------+
*/
// https://github.com/pytorch/pytorch/blob/5c78c6b05abc1c19dabdb2b41755e8dbf9321e61/aten/src/ATen/native/quantized/cpu/qlinear.cpp#L948-L961
// https://github.com/pytorch/pytorch/pull/123144
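The difference between "sum" and "add" in the table: with oneDNN's sum post op the linear result is accumulated in place into the extra input's buffer, while "add" is an out-of-place binary add. A minimal semantics sketch in plain PyTorch (not the oneDNN kernel; the helper name, extra, and post_op are illustrative, not real API):

import torch

def linear_with_binary_post_op(x, w, b, extra, post_op):
    # The linear part; x: [M, K], w: [N, K], b: [N], extra: [M, N]
    y = torch.nn.functional.linear(x, w, b)
    if post_op == "sum":
        extra += y        # in place: result accumulated into extra's buffer
        return extra
    if post_op == "add":
        return y + extra  # out of place: a fresh output buffer
    return y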
Below is the post-freezing FX graph Inductor produces for a three-layer linear model quantized with PT2E; freezing has folded the weights and quantization parameters into _frozen_param* constants.

class __lambda__(torch.nn.Module):
    def forward(self, arg12_1: "f32[128, 128]"):
        # No stacktrace found for following nodes
        _frozen_param0: "f32[256]" = self._frozen_param0
        _frozen_param1: "f32[384]" = self._frozen_param1
        _frozen_param2: "f32[512]" = self._frozen_param2
        _frozen_param3: "f32[256]" = self._frozen_param3
        _frozen_param4: "i64[256]" = self._frozen_param4
        _frozen_param6: "f32[384]" = self._frozen_param6
        _frozen_param7: "i64[384]" = self._frozen_param7
        _frozen_param9: "f32[512]" = self._frozen_param9
        _frozen_param10: "i64[512]" = self._frozen_param10
        _frozen_param12: "i8[128, 256]" = self._frozen_param12
        _frozen_param13: "i8[256, 384]" = self._frozen_param13
        _frozen_param14: "i8[384, 512]" = self._frozen_param14
        # File: <eval_with_key>.12:7 in forward, code: quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 0.0337982214987278, 122, 0, 255, torch.uint8); arg0_1 = None
        quantize_per_tensor: "u8[128, 128]" = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            arg12_1, 0.0337982214987278, 122, 0, 255, torch.uint8
        )
        arg12_1 = None
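        # quantized_decomposed.quantize_per_tensor computes, elementwise
        # (per its reference decomposition; a sketch, not the exact kernel):
        #   q = clamp(round(x / scale) + zero_point, quant_min, quant_max).to(dtype)
        # i.e. here: clamp(round(arg12_1 / 0.0337982214987278) + 122, 0, 255) as uint8.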
        # File: <eval_with_key>.12:16 in forward, code: quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 0.00833055842667818, 0, 0, 255, torch.uint8); relu = None
        qlinear_1: "u8[128, 256]" = torch__inductor_fx_passes_quantization_qlinear(
            x=quantize_per_tensor,
            x_scale=0.0337982214987278,
            x_zp=122,
            packed_weight=_frozen_param12,
            w_scale=_frozen_param3,
            w_zp=_frozen_param4,
            b=_frozen_param0,
            output_scale=1.0,
            output_zero_point=0,
            output_dtype=torch.float32,
            postop_name="none",
            postop_args=[],
            postop_algorithm="",
            o_inv_scale=0.00833055842667818,
            o_zp=0,
            o_qmin=0,
            o_qmax=255,
            o_dtype=torch.uint8,
        )
        quantize_per_tensor = _frozen_param12 = _frozen_param3 = _frozen_param4 = _frozen_param0 = None
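        # Each fused qlinear node stands for a whole dequantize -> linear ->
        # quantize chain matched by Inductor's quantization fusion pass: the
        # int8 activation and weight are dequantized with (x_scale, x_zp) and
        # (w_scale, w_zp), the linear runs in output_dtype (fp32 here), and
        # the result is requantized to uint8 using the trailing o_* arguments,
        # which are taken from the original quantize_per_tensor call shown in
        # the "File:" comment above.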
        # File: <eval_with_key>.12:24 in forward, code: quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(linear_1, 0.007723582908511162, 132, 0, 255, torch.uint8); linear_1 = None
        qlinear: "u8[128, 384]" = torch__inductor_fx_passes_quantization_qlinear_1(
            x=qlinear_1,
            x_scale=0.00833055842667818,
            x_zp=0,
            packed_weight=_frozen_param13,
            w_scale=_frozen_param6,
            w_zp=_frozen_param7,
            b=_frozen_param1,
            output_scale=1.0,
            output_zero_point=0,
            output_dtype=torch.float32,
            postop_name="none",
            postop_args=[],
            postop_algorithm="",
            o_inv_scale=0.007723582908511162,
            o_zp=132,
            o_qmin=0,
            o_qmax=255,
            o_dtype=torch.uint8,
        )
        qlinear_1 = _frozen_param13 = _frozen_param6 = _frozen_param7 = _frozen_param1 = None
"""
onednn::qlinear_pointwise(
Tensor qx,
float x_scale,
int x_zero_point,
Tensor qw, Tensor w_scale,
Tensor w_zero_point, Tensor? bias,
float output_scale,
int output_zero_point,
ScalarType? output_dtype,
str post_op_name,
Scalar?[] post_op_args,
str post_op_algorithm) -> Tensor"
"""
        # File: /home/st_liu/workspace/projects/torch2/test/pt2e_demo.py:27 in forward, code: x = self.lin3(x)
        qlinear_pointwise_default: "f32[128, 512]" = torch.ops.onednn.qlinear_pointwise.default(
            qlinear,
            0.007723582908511162,
            132,
            _frozen_param14,
            _frozen_param9,
            _frozen_param10,
            _frozen_param2,
            1.0,
            0,
            torch.float32,
            "none",
            [],
            "",
        )
        qlinear = _frozen_param14 = _frozen_param9 = _frozen_param10 = _frozen_param2 = None
        return (qlinear_pointwise_default,)
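For reference, a graph like the one above can be reproduced with the PT2E quantization flow plus Inductor freezing. A hedged sketch follows: the model layout is inferred from the tensor shapes in the dump, and the capture API differs across PyTorch versions (capture_pre_autograd_graph in older releases, torch.export.export_for_training in newer ones).

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

class M(nn.Module):
    # Hypothetical stand-in for the model in pt2e_demo.py, with the
    # layer sizes implied by the tensor shapes in the dump above.
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(128, 256)
        self.lin2 = nn.Linear(256, 384)
        self.lin3 = nn.Linear(384, 512)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        x = self.lin2(x)
        x = self.lin3(x)
        return x

m = M().eval()
x = torch.randn(128, 128)

# Capture, quantize, and convert with the x86 Inductor quantizer.
em = torch.export.export_for_training(m, (x,)).module()
quantizer = xiq.X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config())
pm = prepare_pt2e(em, quantizer)
pm(x)                      # calibration pass
qm = convert_pt2e(pm)

# Freezing folds weights/qparams into _frozen_param* constants, and the
# quantization fusion passes rewrite the graph into qlinear nodes.
torch._inductor.config.freezing = True
with torch.no_grad():
    compiled = torch.compile(qm)
    compiled(x)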