From 1ffdc66d7b7ef074cec048441330fcdd721eeff1 Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Fri, 16 Jan 2026 15:43:24 -0800
Subject: [PATCH] Don't partition fp16 GeLU on XNNPACK for now

Summary: Older versions of the XNNPACK library don't properly support fp16
GeLU. Disable partitioning until we pull in a newer version of XNNPACK
(tracking in #16679).

Differential Revision: D90899836
---
 .../partition/config/generic_node_configs.py | 13 +++++++++++
 backends/xnnpack/test/ops/test_gelu.py       | 22 ++++++++++++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index d36072d1991..0e588af66cb 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -430,6 +430,19 @@ class GeluConfig(GenericNodePartitionerConfig):
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # XNNPACK does not support GELU for fp16
+        node_val = node.meta.get("val", None)
+        if node_val is not None and isinstance(node_val, torch.Tensor):
+            if node_val.dtype == torch.float16:
+                why(node, reason="GELU does not support fp16")
+                return False
+
+        return True
+
 
 class HardswishConfig(GenericNodePartitionerConfig):
     target_name = "hardswish.default"
diff --git a/backends/xnnpack/test/ops/test_gelu.py b/backends/xnnpack/test/ops/test_gelu.py
index 5f2708bb306..fdb7e2e2848 100644
--- a/backends/xnnpack/test/ops/test_gelu.py
+++ b/backends/xnnpack/test/ops/test_gelu.py
@@ -63,8 +63,28 @@ def run_gelu_test(self, inputs):
         )
 
     def test_fp16_gelu(self):
+        # Older versions of XNNPACK don't support fp16 GELU.
+        # TODO (gjcomer) Remove this when we update XNNPACK. (#16679)
         inputs = (torch.randn(20).to(torch.float16),)
-        self.run_gelu_test(inputs)
+
+        with torch.no_grad():
+            ref_output = torch.nn.functional.gelu(inputs[0].to(torch.float32)).to(
+                torch.float16
+            )
+        atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
+
+        (
+            Tester(self.Gelu(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.gelu.default": 1})
+            .to_edge_transform_and_lower()
+            # Expect no delegation
+            .check(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
+            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
+        )
 
     def test_fp32_gelu(self):
         inputs = (torch.randn(20),)
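
The test hunk above calls calculate_fp16_gelu_tolerance, which is not added in
these hunks and is presumably defined elsewhere in test_gelu.py or a shared
test helper. Purely as an illustrative sketch (not the repository's actual
helper), a tolerance function for fp16 comparisons could derive atol/rtol from
half-precision machine epsilon and the magnitude of the reference output:

# Hypothetical sketch only -- not the helper used by this patch.
# Assumes tolerances are based on fp16 machine epsilon and the largest
# magnitude in the reference output.
import torch


def calculate_fp16_gelu_tolerance(ref_output: torch.Tensor) -> tuple[float, float]:
    eps = torch.finfo(torch.float16).eps  # ~9.77e-4 for half precision
    max_abs = ref_output.abs().max().item()
    # Allow a few ULPs of absolute error scaled by the output range (with a
    # small floor), plus a matching relative tolerance.
    atol = max(4.0 * eps * max_abs, 1e-3)
    rtol = 4.0 * eps
    return atol, rtol

For the torch.randn(20) input used in the test, this sketch would typically
yield an atol on the order of 1e-2 and an rtol of roughly 4e-3, loose enough
to absorb fp16 rounding in the reference path while still catching gross
numerical errors.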