 from typing import List, Optional

 import torch
+import torch._custom_ops
 import torch.library
 from torch._prims_common import IntLike
 from .utils.channels_last_1d import to_channels_last_1d
@@ -623,7 +624,7 @@ def meta_tpp_linear_mul(
     return input.new_empty((*input.shape[:-1], out_features))


-@register_meta("masked_multihead_self_attention")
+@torch.library.register_fake("torch_ipex::masked_multihead_self_attention")
 def meta_masked_multihead_self_attention(
     query,
     key,
@@ -641,24 +642,25 @@ def meta_masked_multihead_self_attention(
     attn_output = query.new_empty(
         (query.shape[0], query.shape[2], query.shape[1], query.shape[3])
     )
-    if query.dtype == torch.bfloat16:
-        attn_output.as_strided_(
-            attn_output.shape,
-            (
-                query.shape[1] * query.shape[2] * query.shape[3],
-                query.shape[3],
-                query.shape[2] * query.shape[3],
-                1,
-            ),
-        )
+    attn_output.as_strided_(
+        attn_output.shape,
+        (
+            query.shape[1] * query.shape[2] * query.shape[3],
+            query.shape[3],
+            query.shape[2] * query.shape[3],
+            1,
+        ),
+    )
     attn_weights = None
     key_cache_out = query.new_empty(
         (key_cache.shape[0], key_cache.shape[1], key.shape[2], key.shape[3])
     )
     value_cache_out = query.new_empty(
         (value_cache.shape[0], value_cache.shape[1], value.shape[2], value.shape[3])
     )
-    beam_idx_out = query.new_empty(beam_idx.shape)
+    ctx = torch._custom_ops.get_ctx()
+    num_to_keep = ctx.new_dynamic_size()
+    beam_idx_out = query.new_empty((num_to_keep, beam_idx.shape[1]))
     return (attn_output, attn_weights, key_cache_out, value_cache_out, beam_idx_out)

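Below is a minimal, self-contained sketch of the pattern this commit adopts: a fake (meta) kernel registered with torch.library.register_fake that reports a data-dependent output dimension through new_dynamic_size(), the same way beam_idx_out is sized above. The op mylib::unique_rows, its schema, and the function names are hypothetical and exist only to illustrate the API; the sketch also uses the public torch.library.get_ctx(), while the diff calls the equivalent torch._custom_ops.get_ctx().

import torch
import torch.library

# Hypothetical custom op whose output row count depends on the input data.
torch.library.define("mylib::unique_rows", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::unique_rows", "cpu")
def unique_rows_cpu(x):
    # Eager kernel: the number of rows is only known once the data is seen.
    return torch.unique(x, dim=0)

@torch.library.register_fake("mylib::unique_rows")
def unique_rows_fake(x):
    # Under FakeTensor tracing the row count is unknown, so allocate an
    # unbacked symbolic size instead of guessing a concrete one.
    ctx = torch.library.get_ctx()
    num_rows = ctx.new_dynamic_size()
    return x.new_empty((num_rows, x.shape[1]))

With a fake kernel like this in place, torch.compile and export can trace through the op and propagate a symbolic first dimension rather than baking in a concrete size.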