Quant Primitives

OneDNN INT8 OPs

static at::Tensor linear_int8_with_onednn_weight
    /* Supported cases for binary post op:
      +-------------------+--------------+---------------+
      | Extra input dtype | Output dtype | Post op       |
      +-------------------+--------------+---------------+
      | Fp32/bf16         | fp32/bf16    | sum           |
      +-------------------+--------------+---------------+
      | Fp32/bf16         | int8         | add           |
      +-------------------+--------------+---------------+
      | int8              | fp32/bf16    | not supported |
      +-------------------+--------------+---------------+
      | int8              | int8         | sum           |
      +-------------------+--------------+---------------+
    */
// https://github.com/pytorch/pytorch/blob/5c78c6b05abc1c19dabdb2b41755e8dbf9321e61/aten/src/ATen/native/quantized/cpu/qlinear.cpp#L948-L961
// https://github.com/pytorch/pytorch/pull/123144
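
The table above describes the binary post op variants of this kernel: roughly, "sum" accumulates the extra input in place (oneDNN reuses the extra-input buffer as the destination), while "add" is an out-of-place binary add. Functionally both compute linear-plus-extra. The sketch below spells out that reference math under assumed conventions (uint8 activations, symmetric per-channel int8 weights); the function and argument names are illustrative, not the actual oneDNN kernel or the ATen signature.

```python
import torch

def ref_int8_linear_with_binary_post_op(
    x_u8, x_scale, x_zp,            # uint8 activation and its quant params
    w_s8, w_scales,                 # int8 weight [out, in] + per-channel scales [out]
    bias,                           # fp32 bias or None
    extra,                          # the "extra input" from the table
    extra_scale=1.0, extra_zp=0,    # only used when `extra` is int8
    out_scale=None, out_zp=None,    # set both to requantize the output to int8
):
    # Dequantize activation and weight to fp32.
    x_fp32 = (x_u8.to(torch.float32) - x_zp) * x_scale
    w_fp32 = w_s8.to(torch.float32) * w_scales[:, None]
    y = torch.nn.functional.linear(x_fp32, w_fp32, bias)
    # Binary post op: add the (dequantized) extra input.
    if extra.dtype in (torch.int8, torch.uint8):
        extra = (extra.to(torch.float32) - extra_zp) * extra_scale
    y = y + extra.to(torch.float32)
    if out_scale is None:
        return y                    # fp32/bf16-output rows of the table
    # int8-output rows: requantize the result.
    return torch.clamp(torch.round(y / out_scale) + out_zp, -128, 127).to(torch.int8)
```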

class __lambda__(torch.nn.Module):
    def forward(self, arg12_1: "f32[128, 128]"):
        # No stacktrace found for following nodes
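        # Constant-folded ("frozen") parameters: fp32 biases (_frozen_param0..2), per-channel
        # weight scales and zero points (_frozen_param3/4, 6/7, 9/10), and prepacked int8
        # weights (_frozen_param12..14) for the three Linear layers (128 -> 256 -> 384 -> 512).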
        _frozen_param0: "f32[256]" = self._frozen_param0
        _frozen_param1: "f32[384]" = self._frozen_param1
        _frozen_param2: "f32[512]" = self._frozen_param2
        _frozen_param3: "f32[256]" = self._frozen_param3
        _frozen_param4: "i64[256]" = self._frozen_param4
        _frozen_param6: "f32[384]" = self._frozen_param6
        _frozen_param7: "i64[384]" = self._frozen_param7
        _frozen_param9: "f32[512]" = self._frozen_param9
        _frozen_param10: "i64[512]" = self._frozen_param10
        _frozen_param12: "i8[128, 256]" = self._frozen_param12
        _frozen_param13: "i8[256, 384]" = self._frozen_param13
        _frozen_param14: "i8[384, 512]" = self._frozen_param14

        # File: <eval_with_key>.12:7 in forward, code: quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg0_1, 0.0337982214987278, 122, 0, 255, torch.uint8);  arg0_1 = None
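        # Reference semantics: u8 = clamp(round(x / scale) + zero_point, qmin, qmax),
        # here scale=0.0337982214987278, zero_point=122, qmin=0, qmax=255.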
        quantize_per_tensor: "u8[128, 128]" = torch.ops.quantized_decomposed.quantize_per_tensor.default(
            arg12_1, 0.0337982214987278, 122, 0, 255, torch.uint8
        )
        arg12_1 = None

        # File: <eval_with_key>.12:16 in forward, code: quantize_per_tensor_default_1 = torch.ops.quantized_decomposed.quantize_per_tensor.default(relu, 0.00833055842667818, 0, 0, 255, torch.uint8);  relu = None
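        # torch__inductor_fx_passes_quantization_qlinear is the fused qlinear + requantize helper
        # from torch/_inductor/fx_passes/quantization.py, replacing the
        # dequantize -> linear -> quantize_per_tensor pattern: onednn.qlinear_pointwise produces
        # an fp32 result (output_dtype=torch.float32, output_scale=1.0), which is then requantized
        # to uint8 with o_inv_scale/o_zp/o_qmin/o_qmax, so the next qlinear consumes it directly
        # as x / x_scale / x_zp.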
        qlinear_1: "u8[128, 256]" = torch__inductor_fx_passes_quantization_qlinear(
            x=quantize_per_tensor,
            x_scale=0.0337982214987278,
            x_zp=122,
            packed_weight=_frozen_param12,
            w_scale=_frozen_param3,
            w_zp=_frozen_param4,
            b=_frozen_param0,
            output_scale=1.0,
            output_zero_point=0,
            output_dtype=torch.float32,
            postop_name="none",
            postop_args=[],
            postop_algorithm="",
            o_inv_scale=0.00833055842667818,
            o_zp=0,
            o_qmin=0,
            o_qmax=255,
            o_dtype=torch.uint8,
        )
        quantize_per_tensor = _frozen_param12 = _frozen_param3 = _frozen_param4 = _frozen_param0 = None

        # File: <eval_with_key>.12:24 in forward, code: quantize_per_tensor_default_2 = torch.ops.quantized_decomposed.quantize_per_tensor.default(linear_1, 0.007723582908511162, 132, 0, 255, torch.uint8);  linear_1 = None
        qlinear: "u8[128, 384]" = torch__inductor_fx_passes_quantization_qlinear_1(
            x=qlinear_1,
            x_scale=0.00833055842667818,
            x_zp=0,
            packed_weight=_frozen_param13,
            w_scale=_frozen_param6,
            w_zp=_frozen_param7,
            b=_frozen_param1,
            output_scale=1.0,
            output_zero_point=0,
            output_dtype=torch.float32,
            postop_name="none",
            postop_args=[],
            postop_algorithm="",
            o_inv_scale=0.007723582908511162,
            o_zp=132,
            o_qmin=0,
            o_qmax=255,
            o_dtype=torch.uint8,
        )
        qlinear_1 = _frozen_param13 = _frozen_param6 = _frozen_param7 = _frozen_param1 = None

        """
        onednn::qlinear_pointwise(
            Tensor qx,
            float x_scale,
            int x_zero_point,
            Tensor qw, Tensor w_scale,
            Tensor w_zero_point, Tensor? bias,
            float output_scale,
            int output_zero_point,
            ScalarType? output_dtype,
            str post_op_name,
            Scalar?[] post_op_args,
            str post_op_algorithm) -> Tensor"
        """
        # File: /home/st_liu/workspace/projects/torch2/test/pt2e_demo.py:27 in forward, code: x = self.lin3(x)
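        # Positional args follow the onednn::qlinear_pointwise schema above:
        # (qx, x_scale, x_zero_point, qw, w_scale, w_zero_point, bias, output_scale,
        #  output_zero_point, output_dtype, post_op_name, post_op_args, post_op_algorithm).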
        qlinear_pointwise_default: "f32[128, 512]" = torch.ops.onednn.qlinear_pointwise.default(
            qlinear,
            0.007723582908511162,
            132,
            _frozen_param14,
            _frozen_param9,
            _frozen_param10,
            _frozen_param2,
            1.0,
            0,
            torch.float32,
            "none",
            [],
            "",
        )
        qlinear = _frozen_param14 = _frozen_param9 = _frozen_param10 = _frozen_param2 = None
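        # The final Linear has no following quantize node, so its output stays in fp32
        # and is returned as-is (output_scale=1.0, output_zero_point=0).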
        return (qlinear_pointwise_default,)
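
For context, graphs like the one above come out of the PT2E x86/Inductor quantization flow compiled with freezing enabled. Below is a minimal sketch of that flow. The toy three-Linear model (128 -> 256 -> 384 -> 512, ReLU after the first layer) is reconstructed from the shapes and source traces above, and the export entry point varies by PyTorch release, so treat this as an approximation of pt2e_demo.py rather than the actual script.

```python
import torch
import torch._inductor.config as inductor_config
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class ToyModel(torch.nn.Module):  # shapes guessed from the frozen params above
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(128, 256)
        self.lin2 = torch.nn.Linear(256, 384)
        self.lin3 = torch.nn.Linear(384, 512)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        x = self.lin2(x)
        x = self.lin3(x)
        return x

model = ToyModel().eval()
example_inputs = (torch.randn(128, 128),)

# Export step (API differs across releases): capture_pre_autograd_graph in ~2.3/2.4,
# torch.export.export_for_training(...).module() in newer PyTorch.
exported = torch._export.capture_pre_autograd_graph(model, example_inputs)

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)        # calibration pass
converted = convert_pt2e(prepared)

# Freezing folds weights so Inductor can prepack them and fuse
# dequantize -> linear -> quantize into onednn.qlinear_pointwise.
inductor_config.freezing = True
with torch.no_grad():
    compiled = torch.compile(converted)
    compiled(*example_inputs)
```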