Skip to content

vllm.model_executor.kernels.linear.scaled_mm.cpu

CPUFp8BlockScaledMMKernel

Bases: Fp8BlockScaledMMLinearKernel

FP8 W8A16 block-quantized GEMM via AMX BRGEMM on CPU.

Source code in vllm/model_executor/kernels/linear/scaled_mm/cpu.py
class CPUFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    """FP8 W8A16 block-quantized GEMM via AMX BRGEMM on CPU.

    Weights are FP8 with per-block scales; activations are not quantized.
    The GEMM itself is performed by ``torch.ops._C.fp8_scaled_mm_cpu`` on a
    weight that ``process_weights_after_loading`` has prepacked with
    ``torch.ops._C.convert_weight_packed``.
    """

    # Input stays BF16 — no FP8 activation quantization.
    apply_input_quant = False

    # Required K-dimension (group-shape ``col``) weight block size.
    _K_BLOCK_SIZE = 128
    # The N-dimension (group-shape ``row``) weight block size must be a
    # positive multiple of this value.
    _N_BLOCK_MULTIPLE = 32

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current host.

        Args:
            compute_capability: Ignored by this CPU kernel.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        if not current_platform.is_cpu():
            return False, "requires CPU platform."
        if not torch.cpu._is_amx_tile_supported():
            return False, "requires AMX tile support (Sapphire Rapids or newer)."
        if not ops._supports_cpu_fp8_w8a16:
            return False, "fp8_scaled_mm_cpu op not available."
        return True, None

    @classmethod
    def can_implement(
        cls, config: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Check whether ``config`` is compatible with this kernel.

        Requirements: a K-dimension weight block of exactly
        ``_K_BLOCK_SIZE``, an N-dimension weight block that is a positive
        multiple of ``_N_BLOCK_MULTIPLE``, and a bfloat16/float32 output
        dtype.

        Returns:
            ``(True, None)`` when compatible, otherwise ``(False, reason)``.
        """
        weight_gs = config.weight_quant_key.scale.group_shape
        # An exact-equality test against a positive constant also rejects
        # non-positive sizes, so no separate ``<= 0`` check is needed here.
        if weight_gs.col != cls._K_BLOCK_SIZE:
            return False, (
                "CPU FP8 kernel requires K-dimension block size of "
                f"{cls._K_BLOCK_SIZE}, got {weight_gs.col}."
            )
        if weight_gs.row <= 0 or weight_gs.row % cls._N_BLOCK_MULTIPLE != 0:
            return False, (
                "CPU FP8 kernel requires N-dimension block size to be "
                f"a positive multiple of {cls._N_BLOCK_MULTIPLE}, "
                f"got {weight_gs.row}."
            )
        if config.out_dtype not in (torch.bfloat16, torch.float32):
            return False, "Only bfloat16/float32 output dtype supported."
        return True, None

    @staticmethod
    def _select_weight_scale(params) -> tuple[str, torch.Tensor]:
        """Return ``(attribute_name, tensor)`` for the layer's weight scale.

        ``weight_scale_inv`` is preferred when set; otherwise
        ``weight_scale`` is used. ``params`` is the object returned by
        ``_get_layer_params``.

        Raises:
            ValueError: If neither scale tensor is set.
        """
        if params.weight_scale_inv is not None:
            return params.WEIGHT_SCALE_INV, params.weight_scale_inv
        if params.weight_scale is not None:
            return params.WEIGHT_SCALE, params.weight_scale
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        raise ValueError(
            "CPUFp8BlockScaledMMKernel requires weight_scale_inv or "
            "weight_scale to be set on the layer."
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare ``layer``'s weight and block scale for the CPU GEMM op.

        Replaces the weight parameter with the output of
        ``convert_weight_packed`` and re-registers the weight scale as a
        plain, non-trainable ``torch.nn.Parameter``.

        Raises:
            ValueError: If the layer has neither ``weight_scale_inv`` nor
                ``weight_scale``.
        """
        # Skip the base class process (FP8 padding / fnuz normalization)
        # which is GPU-oriented.  Instead, VNNI-prepack weights for AMX.
        params = self._get_layer_params(layer)
        packed_weight = torch.ops._C.convert_weight_packed(params.weight)
        replace_parameter(
            layer,
            params.WEIGHT,
            torch.nn.Parameter(packed_weight, requires_grad=False),
        )

        # Re-wrap scale as a plain Parameter so the kernel can read it
        # without weight-loader metadata interfering.
        scale_attr, weight_scale = self._select_weight_scale(params)
        replace_parameter(
            layer,
            scale_attr,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the block-scaled FP8 linear layer for input ``x``.

        All leading dimensions of ``x`` are flattened into one row
        dimension before calling the op and restored on the output, so the
        result has shape ``(*x.shape[:-1], out_features)``.

        Args:
            layer: Layer previously processed by
                ``process_weights_after_loading``.
            x: Activation tensor whose last dimension is the reduction (K)
                dimension.
            bias: Optional bias forwarded to the op.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If the layer has neither ``weight_scale_inv`` nor
                ``weight_scale``.
        """
        params = self._get_layer_params(layer)
        _, weight_scale = self._select_weight_scale(params)

        # Always hand the op a 2-D (rows, K) view. For 2-D input this is a
        # no-op view, and 1-D input is handled as a single row.
        x_2d = x.reshape(-1, x.shape[-1])
        out = torch.ops._C.fp8_scaled_mm_cpu(
            x_2d,
            params.weight,
            weight_scale,
            list(self.weight_group_shape),
            bias,
            # NOTE(review): output dtype follows x.dtype rather than
            # config.out_dtype, even though can_implement() also admits
            # float32 output — confirm this is intended.
            x.dtype,
            True,  # is_vnni (weight already prepacked)
        )
        return out.reshape(*x.shape[:-1], out.size(-1))

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Not used: this kernel performs its GEMM in ``apply_weights``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "CPUFp8BlockScaledMMKernel overrides apply_weights directly."
        )