class CPUExpertsFp8(mk.FusedMoEExpertsMonolithic):
    """CPU FP8 W8A16 block-quantized monolithic MoE experts.

    Weights are stored as FP8 with per-block scales; activations stay in
    their unquantized dtype, so the kernel performs weight-only (W8A16)
    dequantization.
    """

    def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
moe_config,
quant_config,
        )

    @property
def expects_unquantized_inputs(self) -> bool:
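        # W8A16: activations are consumed in their original (unquantized)
        # dtype; only the weights are FP8.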
        return True

    @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
def _supports_current_device() -> bool:
        return current_platform.is_cpu()

    @staticmethod
def _supports_no_act_and_mul() -> bool:
        return False

    @staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
        return activation == MoEActivation.SILU

    @staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
        return True

    @staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
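        # The declared scheme pairs per-128-block static FP8 weights with
        # dynamic per-group FP8 activations, but this backend ignores the
        # activation scheme and runs W8A16 (see expects_unquantized_inputs).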
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return routing_method in [
RoutingMethodType.Default,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
        ]

    @staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
        return True

    def supports_expert_map(self) -> bool:
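        # apply() never consumes expert_map, so expert-parallel expert-id
        # remapping is unsupported here.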
        return False

    def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import (
select_experts,
)
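
        # Pick the top-k experts per token: topk_ids holds expert indices
        # and topk_weights the routing weights (renormalized under the
        # Renormalize* routing methods).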
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
use_grouped_topk=num_expert_group is not None,
top_k=self.moe_config.experts_per_token,
renormalize=self.moe_config.routing_method
in (
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
),
topk_group=topk_group,
num_expert_group=num_expert_group,
scoring_func="softmax",
routed_scaling_factor=(
routed_scaling_factor if routed_scaling_factor is not None else 1.0
),
e_score_correction_bias=e_score_correction_bias,
)
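
        # Resolve the 2-D quantization block shape: prefer the explicitly
        # configured value, else fall back to the group shape recorded on
        # the w1 weight-scale descriptor.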
        if self.quant_config.block_shape:
            block_shape = list(self.quant_config.block_shape)
        elif self.quant_config._w1.shape is not None:
            block_shape = [
                self.quant_config._w1.shape.row,
                self.quant_config._w1.shape.col,
            ]
        else:
            block_shape = None

        return torch.ops._C.fused_experts_cpu(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
            False,  # inplace
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            self.w1_scale,  # w1_scale
            self.w2_scale,  # w2_scale
            block_shape,  # block_size
            None,  # a1_scale (W8A16: no activation quantization)
            None,  # a2_scale
            True,  # is_vnni
)
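
# Illustrative shape sketch (an assumption for readers, not enforced by this
# module). With E experts, hidden size H, intermediate size I, and 128x128
# weight blocks, apply() would typically see:
#
#   hidden_states: (num_tokens, H)   in the activation dtype (e.g. bf16)
#   router_logits: (num_tokens, E)
#   w1:            (E, 2 * I, H)     FP8, gate and up projections fused
#   w1_scale:      (E, ceil(2 * I / 128), ceil(H / 128))
#   w2:            (E, H, I)         FP8
#   w2_scale:      (E, ceil(H / 128), ceil(I / 128))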