Commit 8a5b92c

Fix masked_multihead_self_attention meta registration (#3584)
* fix masked_multihead_self_attention
* update code

1 parent 0d20258

3 files changed: 18 additions & 14 deletions

intel_extension_for_pytorch/_meta_registrations.py

Lines changed: 14 additions & 12 deletions

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
 import torch
+import torch._custom_ops
 import torch.library
 from torch._prims_common import IntLike
 from .utils.channels_last_1d import to_channels_last_1d
@@ -623,7 +624,7 @@ def meta_tpp_linear_mul(
     return input.new_empty((*input.shape[:-1], out_features))
 
 
-@register_meta("masked_multihead_self_attention")
+@torch.library.register_fake("torch_ipex::masked_multihead_self_attention")
 def meta_masked_multihead_self_attention(
     query,
     key,
@@ -641,24 +642,25 @@ def meta_masked_multihead_self_attention(
     attn_output = query.new_empty(
         (query.shape[0], query.shape[2], query.shape[1], query.shape[3])
     )
-    if query.dtype == torch.bfloat16:
-        attn_output.as_strided_(
-            attn_output.shape,
-            (
-                query.shape[1] * query.shape[2] * query.shape[3],
-                query.shape[3],
-                query.shape[2] * query.shape[3],
-                1,
-            ),
-        )
+    attn_output.as_strided_(
+        attn_output.shape,
+        (
+            query.shape[1] * query.shape[2] * query.shape[3],
+            query.shape[3],
+            query.shape[2] * query.shape[3],
+            1,
+        ),
+    )
     attn_weights = None
     key_cache_out = query.new_empty(
         (key_cache.shape[0], key_cache.shape[1], key.shape[2], key.shape[3])
     )
     value_cache_out = query.new_empty(
         (value_cache.shape[0], value_cache.shape[1], value.shape[2], value.shape[3])
     )
-    beam_idx_out = query.new_empty(beam_idx.shape)
+    ctx = torch._custom_ops.get_ctx()
+    num_to_keep = ctx.new_dynamic_size()
+    beam_idx_out = query.new_empty((num_to_keep, beam_idx.shape[1]))
     return (attn_output, attn_weights, key_cache_out, value_cache_out, beam_idx_out)
 
 
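The substantive change is twofold: the meta (fake) kernel is now registered through the standard `torch.library.register_fake` API under the op's qualified name, and `beam_idx_out` gets a dynamic first dimension via `ctx.new_dynamic_size()`, since the number of beam-index rows kept is only known at run time. A minimal sketch of the same pattern on a made-up op (the name `mylib::gather_rows`, its schema, and both function bodies are hypothetical, not part of this commit):

```python
import torch
import torch._custom_ops

# Hypothetical op whose output row count depends on the mask's values.
torch.library.define("mylib::gather_rows", "(Tensor src, Tensor mask) -> Tensor")

@torch.library.impl("mylib::gather_rows", "CPU")
def gather_rows_cpu(src, mask):
    return src[mask]  # keeps only the rows where mask is True

@torch.library.register_fake("mylib::gather_rows")
def gather_rows_fake(src, mask):
    # The true row count is unknown at trace time, so ask the custom-op
    # context for a fresh dynamic size (an unbacked symbolic int), the
    # same mechanism the meta registration above uses for beam_idx_out.
    ctx = torch._custom_ops.get_ctx()
    num_rows = ctx.new_dynamic_size()
    return src.new_empty((num_rows, src.shape[1]))
```

Built-in ops with value-dependent output shapes, such as `torch.nonzero`, rely on the same unbacked-symint mechanism internally.
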
tests/cpu/test_masked_mha.py

Lines changed: 2 additions & 1 deletion

@@ -152,8 +152,9 @@ def _test_mha(self, torchcompile=False):
 
         if torchcompile:
             torch._dynamo.reset()
+            torch._dynamo.config.capture_dynamic_output_shape_ops = True
             ipex._set_compiler_backend("inductor")
-            mha = torch.compile(mha, backend="ipex")
+            mha = torch.compile(mha, backend="ipex", dynamic=True)
 
         # first token decode
         input_t = torch.randn(
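The test-side counterpart: since the fake kernel now allocates an unbacked symbolic dimension, Dynamo must be permitted to capture ops with data-dependent output shapes rather than graph-breaking on them. A self-contained sketch of what these two settings enable, using `torch.nonzero` as a stand-in for the custom op:

```python
import torch

# Without this flag, Dynamo graph-breaks on ops whose output shape
# depends on input values rather than input shapes.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(dynamic=True)
def count_hits(x):
    # nonzero's output length is data-dependent, like beam_idx_out above.
    return torch.nonzero(x).float().sum()

print(count_hits(torch.tensor([0.0, 1.0, 0.0, 2.0])))
```
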
tests/cpu/test_masked_mha_fp8.py

Lines changed: 2 additions & 1 deletion

@@ -169,8 +169,9 @@ def _test_mha(self, torchcompile=False):
 
         if torchcompile:
             torch._dynamo.reset()
+            torch._dynamo.config.capture_dynamic_output_shape_ops = True
             ipex._set_compiler_backend("inductor")
-            mha = torch.compile(mha, backend="ipex")
+            mha = torch.compile(mha, backend="ipex", dynamic=True)
 
         # first token decode
         input_t = torch.randn(
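The fp8 test receives the identical two-line change: both tests compile a module that calls `torch_ipex::masked_multihead_self_attention`, so both need the capture flag and `dynamic=True` for `torch.compile` to trace through the now data-dependent `beam_idx_out` shape.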