Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ $else:
#include "broadcasting_utils.h"
#include "indexing_utils.h"

$if MASK_PADDING:
#define MASK_PADDING

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
Expand Down Expand Up @@ -140,11 +143,26 @@ void main() {
other_texel = other_texel.xxxx;
}

write_texel_lpos(
t_out,
lpos,
VEC4_OUT_T(op(in_texel, other_texel, alpha)),
out_axis_map);
VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha));

#ifdef MASK_PADDING
// Handle padding elements in the last texel to prevent NaN propagation.
// When the packed dimension size is not a multiple of 4, the last texel
// will have padding elements. For division operations, padding elements
// (which are 0/0) can produce NaN values that propagate through reductions.
const int nspill = mod4(out_sizes[packed_dim]);
const int texels_per_batch = divup4(out_sizes[packed_dim]);
const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1);

if (is_last_texel && nspill > 0) {
// Explicitly set padding elements to 0 to avoid NaN
[[unroll]] for (int i = nspill; i < 4; i++) {
out_texel[i] = 0;
}
}
#endif

write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
}

#endif
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ binary_op:
NDIM: 3
DTYPE: float
PACKING: C_packed
MASK_PADDING: 0
generate_variant_forall:
STORAGE:
- VALUE: texture3d
Expand All @@ -26,10 +27,12 @@ binary_op:
OPERATOR: X * Y
- NAME: binary_div
OPERATOR: X / Y
MASK_PADDING: 1
- NAME: binary_pow
OPERATOR: pow(X, Y)
- NAME: binary_floor_divide
OPERATOR: floor(X / Y)
MASK_PADDING: 1
- NAME: binary_minimum
OPERATOR: min(X, Y)
- NAME: binary_eq_int32
Expand Down
Loading