
Commit d08e0bb

update
1 parent c366b5a commit d08e0bb

8 files changed: +794, −287 lines

tests/models/testing_utils/__init__.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -1,8 +1,8 @@
 from .attention import AttentionTesterMixin, ContextParallelTesterMixin
-from .common import ModelTesterMixin
+from .common import BaseModelTesterConfig, ModelTesterMixin
 from .compile import TorchCompileTesterMixin
 from .ip_adapter import IPAdapterTesterMixin
-from .lora import LoraTesterMixin
+from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
 from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
 from .quantization import (
     BitsAndBytesTesterMixin,
@@ -17,14 +17,16 @@


 __all__ = [
-    "ContextParallelTesterMixin",
     "AttentionTesterMixin",
+    "BaseModelTesterConfig",
     "BitsAndBytesTesterMixin",
+    "ContextParallelTesterMixin",
     "CPUOffloadTesterMixin",
     "GGUFTesterMixin",
     "GroupOffloadTesterMixin",
     "IPAdapterTesterMixin",
     "LayerwiseCastingTesterMixin",
+    "LoraHotSwappingForModelTesterMixin",
     "LoraTesterMixin",
     "MemoryTesterMixin",
     "ModelOptTesterMixin",
```

tests/models/testing_utils/attention.py

Lines changed: 19 additions & 8 deletions
```diff
@@ -25,7 +25,13 @@
     AttnProcessor,
 )

-from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
+from ...testing_utils import (
+    assert_tensors_close,
+    is_attention,
+    is_context_parallel,
+    require_torch_multi_accelerator,
+    torch_device,
+)


 @is_attention
```
```diff
@@ -89,8 +95,12 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_fusion = output_after_fusion.to_tuple()[0]

         # Verify outputs match
-        assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
-            "Output should not change after fusing projections"
+        assert_tensors_close(
+            output_before_fusion,
+            output_after_fusion,
+            atol=self.base_precision,
+            rtol=0,
+            msg="Output should not change after fusing projections",
         )

         # Unfuse projections
@@ -110,8 +120,12 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_unfusion = output_after_unfusion.to_tuple()[0]

         # Verify outputs still match
-        assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
-            "Output should match original after unfusing projections"
+        assert_tensors_close(
+            output_before_fusion,
+            output_after_unfusion,
+            atol=self.base_precision,
+            rtol=0,
+            msg="Output should match original after unfusing projections",
         )

     def test_get_set_processor(self):
```
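Both assertions now route through `assert_tensors_close` instead of a bare `torch.allclose`. A minimal sketch of what such a helper typically looks like, assuming it wraps `torch.allclose` and reports the worst elementwise difference on failure (the real helper in the test utilities may differ):

```python
import torch


def assert_tensors_close(actual, expected, *, atol, rtol, msg=""):
    # Sketch only: wraps torch.allclose and surfaces the maximum absolute
    # difference so tolerance failures are actionable.
    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        max_diff = (actual - expected).abs().max().item()
        raise AssertionError(f"{msg} (max abs diff: {max_diff}, atol={atol}, rtol={rtol})")
```

Note that passing `rtol=0` makes the comparison purely absolute, whereas the old `torch.allclose` calls, which set only `atol`, also applied the default relative tolerance of 1e-5.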
```diff
@@ -238,9 +252,6 @@ def test_context_parallel_inference(self, cp_type):
         if not torch.distributed.is_available():
             pytest.skip("torch.distributed is not available.")

-        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
-            pytest.skip("Context parallel requires at least 2 CUDA devices.")
-
         if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

```

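The removed in-test device check duplicates what the `require_torch_multi_accelerator` marker added to the imports above already enforces, which is presumably why it was dropped. A minimal sketch of such a marker, assuming a pytest `skipif` over the CUDA device count (the real helper likely also covers non-CUDA accelerators):

```python
import pytest
import torch

# Sketch only: an assumed skipif-based stand-in for the real
# require_torch_multi_accelerator helper in the test utilities.
require_torch_multi_accelerator = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="test requires at least two accelerator devices",
)
```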