Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,19 @@ class GeluConfig(GenericNodePartitionerConfig):
def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

# XNNPACK does not support GELU for fp16
node_val = node.meta.get("val", None)
if node_val is not None and isinstance(node_val, torch.Tensor):
if node_val.dtype == torch.float16:
why(node, reason="GELU does not support fp16")
return False

return True


class HardswishConfig(GenericNodePartitionerConfig):
target_name = "hardswish.default"
Expand Down
22 changes: 21 additions & 1 deletion backends/xnnpack/test/ops/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,28 @@ def run_gelu_test(self, inputs):
)

def test_fp16_gelu(self):
# Older versions of XNNPACK don't support fp16 GELU.
# TODO (gjcomer) Remove this when we update XNNPACK. (#16679)
inputs = (torch.randn(20).to(torch.float16),)
self.run_gelu_test(inputs)

with torch.no_grad():
ref_output = torch.nn.functional.gelu(inputs[0].to(torch.float32)).to(
torch.float16
)
atol, rtol = calculate_fp16_gelu_tolerance(ref_output)

(
Tester(self.Gelu(), inputs)
.export()
.check_count({"torch.ops.aten.gelu.default": 1})
.to_edge_transform_and_lower()
# Expect no delegation
.check(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
.check_not(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs(atol=atol, rtol=rtol)
)

def test_fp32_gelu(self):
inputs = (torch.randn(20),)
Expand Down
Loading