diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index d36072d1991..0e588af66cb 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -430,6 +430,19 @@ class GeluConfig(GenericNodePartitionerConfig):
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # XNNPACK does not support GELU for fp16
+        node_val = node.meta.get("val", None)
+        if node_val is not None and isinstance(node_val, torch.Tensor):
+            if node_val.dtype == torch.float16:
+                why(node, reason="GELU does not support fp16")
+                return False
+
+        return True
+
 
 class HardswishConfig(GenericNodePartitionerConfig):
     target_name = "hardswish.default"
diff --git a/backends/xnnpack/test/ops/test_gelu.py b/backends/xnnpack/test/ops/test_gelu.py
index 5f2708bb306..fdb7e2e2848 100644
--- a/backends/xnnpack/test/ops/test_gelu.py
+++ b/backends/xnnpack/test/ops/test_gelu.py
@@ -63,8 +63,28 @@ def run_gelu_test(self, inputs):
         )
 
     def test_fp16_gelu(self):
+        # Older versions of XNNPACK don't support fp16 GELU.
+        # TODO (gjcomer) Remove this when we update XNNPACK. (#16679)
         inputs = (torch.randn(20).to(torch.float16),)
-        self.run_gelu_test(inputs)
+
+        with torch.no_grad():
+            ref_output = torch.nn.functional.gelu(inputs[0].to(torch.float32)).to(
+                torch.float16
+            )
+        atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
+
+        (
+            Tester(self.Gelu(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.gelu.default": 1})
+            .to_edge_transform_and_lower()
+            # Expect no delegation
+            .check(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
+            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
+        )
 
     def test_fp32_gelu(self):
         inputs = (torch.randn(20),)