@@ -25,7 +25,13 @@
     AttnProcessor,
 )
 
-from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
+from ...testing_utils import (
+    assert_tensors_close,
+    is_attention,
+    is_context_parallel,
+    require_torch_multi_accelerator,
+    torch_device,
+)
 
 
 @is_attention
@@ -89,8 +95,12 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_fusion = output_after_fusion.to_tuple()[0]
 
         # Verify outputs match
-        assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
-            "Output should not change after fusing projections"
+        assert_tensors_close(
+            output_before_fusion,
+            output_after_fusion,
+            atol=self.base_precision,
+            rtol=0,
+            msg="Output should not change after fusing projections",
         )
 
         # Unfuse projections
@@ -110,8 +120,12 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_unfusion = output_after_unfusion.to_tuple()[0]
 
         # Verify outputs still match
-        assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
-            "Output should match original after unfusing projections"
+        assert_tensors_close(
+            output_before_fusion,
+            output_after_unfusion,
+            atol=self.base_precision,
+            rtol=0,
+            msg="Output should match original after unfusing projections",
         )
 
     def test_get_set_processor(self):
@@ -238,9 +252,6 @@ def test_context_parallel_inference(self, cp_type):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")
 
-        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
-            pytest.skip("Context parallel requires at least 2 CUDA devices.")
-
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")