From 4cf483fdc0ac590496ca3ae65a3cadab41226ba8 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 30 Dec 2025 11:33:33 -0800 Subject: [PATCH] [ET-VK][ez] Restrict batch norm operator to 4D input tensors only Batch normalization is typically used with 4D tensors (batch, channels, height, width) in convolutional neural networks. This change adds input validation to ensure batch norm is only lowered to the Vulkan backend when the input tensor is 4-dimensional. For other input shapes, the operator will fall back to other backends. The implementation follows the same pattern as the convolution operator, using an `are_node_inputs_supported_fn` callback to validate input shapes during graph partitioning. This prevents potential issues with batch norm on unsupported tensor shapes and makes the operator requirements explicit. Differential Revision: [D89935219](https://our.internmc.facebook.com/intern/diff/D89935219/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index feba4f6f072..cbabadcfc80 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -793,17 +793,30 @@ def register_ported_op_all_packed_dims(): # Ported ops that support their own prepacking. -@update_features( - [ - exir_ops.edge.aten.embedding.default, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - ] -) -def register_ported_ops_with_prepacking(): +@update_features(exir_ops.edge.aten.embedding.default) +def register_embedding_op(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + supports_prepacking=True, + supports_resize=True, + ) + + +@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) +def register_batch_norm_op(): + def check_batch_norm_node(node: torch.fx.Node) -> bool: + x = node.args[0] + if not isinstance(x, torch.fx.Node): + return False + x_shape = x.meta["val"].size() + # Only support 4-D input tensors + return len(x_shape) == 4 + return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, supports_prepacking=True, supports_resize=True, + are_node_inputs_supported_fn=check_batch_norm_node, )