From e8f2e12ca561284282fe846bafa974bfe4f66cd9 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 30 Dec 2025 11:33:29 -0800
Subject: [PATCH] [ET-VK][ez] Fix NaN propagation in binary div operations for padded texels

When the packed dimension size is not a multiple of 4, texture-backed
tensors have padding elements in the last texel. For division operations,
these padding regions contain 0/0 = NaN, which propagates through
subsequent reduction operations and corrupts results.

This fix adds conditional padding masking logic to the binary_op shaders:
- Introduced a MASK_PADDING codegen variable in binary_op.yaml
- Enabled MASK_PADDING=1 for the binary_div and binary_floor_divide ops
- Added the corresponding GLSL preprocessor macro definition in
  binary_op.glsl
- Implemented padding masking logic that uses modulo arithmetic to
  correctly identify the last texels, including in batch concatenation
  scenarios
- Zeroed out padding elements explicitly to prevent NaN propagation

The implementation follows GLSL best practices by using Python
preprocessing only for the macro definition, keeping the core shader
logic as pure GLSL guarded by standard #ifdef directives.

Differential Revision: [D89935220](https://our.internmc.facebook.com/intern/diff/D89935220/)

[ghstack-poisoned]
---
 .../runtime/graph/ops/glsl/binary_op.glsl | 28 +++++++++++++++----
 .../runtime/graph/ops/glsl/binary_op.yaml |  3 ++
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index 6e638a3275c..b801ddfc183 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -64,6 +64,9 @@ $else:
 #include "broadcasting_utils.h"
 #include "indexing_utils.h"
 
+$if MASK_PADDING:
+  #define MASK_PADDING
+
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
@@ -140,11 +143,26 @@ void main() {
     other_texel = other_texel.xxxx;
   }
 
-  write_texel_lpos(
-      t_out,
-      lpos,
-      VEC4_OUT_T(op(in_texel, other_texel, alpha)),
-      out_axis_map);
+  VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha));
+
+#ifdef MASK_PADDING
+  // Handle padding elements in the last texel to prevent NaN propagation.
+  // When the packed dimension size is not a multiple of 4, the last texel
+  // will have padding elements. For division operations, padding elements
+  // (which are 0/0) can produce NaN values that propagate through reductions.
+  const int nspill = mod4(out_sizes[packed_dim]);
+  const int texels_per_batch = divup4(out_sizes[packed_dim]);
+  const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1);
+
+  if (is_last_texel && nspill > 0) {
+    // Explicitly set padding elements to 0 to avoid NaN
+    [[unroll]] for (int i = nspill; i < 4; i++) {
+      out_texel[i] = 0;
+    }
+  }
+#endif
+
+  write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
 }
 
 #endif

diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
index 70793628d80..ee96b5c05b4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
@@ -10,6 +10,7 @@ binary_op:
     NDIM: 3
     DTYPE: float
     PACKING: C_packed
+    MASK_PADDING: 0
   generate_variant_forall:
     STORAGE:
       - VALUE: texture3d
@@ -26,10 +27,12 @@ binary_op:
       OPERATOR: X * Y
     - NAME: binary_div
       OPERATOR: X / Y
+      MASK_PADDING: 1
     - NAME: binary_pow
       OPERATOR: pow(X, Y)
    - NAME: binary_floor_divide
      OPERATOR: floor(X / Y)
+      MASK_PADDING: 1
    - NAME: binary_minimum
      OPERATOR: min(X, Y)
    - NAME: binary_eq_int32
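
For reference (not part of the patch): a minimal host-side C sketch of the masking arithmetic above, assuming mod4(x) behaves as x % 4 and divup4(x) as (x + 3) / 4; the real helpers come from indexing_utils.h and are not shown in this diff. It prints, for a hypothetical packed dimension of size 6 spread across concatenated batches, which texels contain padding lanes that the MASK_PADDING branch would zero.

#include <stdio.h>

/* Assumed stand-ins for the indexing_utils.h helpers used by the shader. */
static int mod4(int x)   { return x % 4; }       /* valid elements spilling into the last texel */
static int divup4(int x) { return (x + 3) / 4; } /* texels needed to hold x packed elements */

int main(void) {
  /* Hypothetical example: packed dimension of size 6, with 3 batches
   * concatenated along the packed dimension, so 3 * divup4(6) = 6 texels. */
  const int packed_dim_size = 6;
  const int num_batches = 3;

  const int nspill = mod4(packed_dim_size);              /* 2 valid lanes in the last texel */
  const int texels_per_batch = divup4(packed_dim_size);  /* 2 texels per batch */

  for (int texel = 0; texel < num_batches * texels_per_batch; texel++) {
    /* Same test as the shader: the last texel of each batch along the packed dim. */
    const int is_last_texel = (texel % texels_per_batch) == (texels_per_batch - 1);
    if (is_last_texel && nspill > 0) {
      printf("texel %d: lanes %d..3 are padding -> zeroed\n", texel, nspill);
    } else {
      printf("texel %d: all 4 lanes valid\n", texel);
    }
  }
  return 0;
}

Compiled with any C compiler (e.g. cc sketch.c), this reports that texels 1, 3, and 5 have lanes 2..3 zeroed, mirroring what the MASK_PADDING branch does per shader invocation.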