From 791f1831e7d48a8455e581f7b641f469ff9885c7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 14:11:38 +0000 Subject: [PATCH 01/12] code drop Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/cpu_offload.py | 61 +++++++++++++++---- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/cast.cpp | 37 +++++------ .../pytorch/csrc/extensions/pybind.cpp | 3 +- .../quantization_current_scaling.py | 4 ++ .../pytorch/module/grouped_linear.py | 13 +++- transformer_engine/pytorch/module/linear.py | 3 +- .../pytorch/quantized_tensor.py | 3 + .../pytorch/tensor/float8_blockwise_tensor.py | 52 +++++++++++----- .../pytorch/tensor/mxfp8_tensor.py | 6 +- .../pytorch/tensor/nvfp4_tensor.py | 56 ++++++++++++----- 11 files changed, 175 insertions(+), 66 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 58ed063066..d6bff166b9 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -19,6 +19,7 @@ from .quantized_tensor import ( restore_from_saved, prepare_for_saving, + QuantizedTensor, ) @@ -255,6 +256,8 @@ def start_offload(self): Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream. Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded. This event is recorded in the start_offload or push_tensor call. + + Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor). """ self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"]) self.state = "offload_started" @@ -275,19 +278,18 @@ def start_offload(self): with torch.cuda.stream(self.offload_stream): if allocate_cpu_buffers: - # empty_like is defined also for QuantizedTensors offloaded_tensor = torch.empty_like( tensor, device=torch.device("cpu"), pin_memory=True ) self.cpu_tensor_group.tensor_list.append(offloaded_tensor) else: - assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, ( + offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] + assert offloaded_tensor.shape == tensor.shape, ( "CPU buffer shape does not match the offloaded tensor shape:" - f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} " - " Make sure that tensor shaped do not change between" + f" {offloaded_tensor.shape} != {tensor.shape} " + "Make sure that tensor shapes do not change between" " iterations if retain_pinned_cpu_buffers is True." ) - offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] offloaded_tensor.copy_(tensor, non_blocking=True) # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, @@ -318,6 +320,9 @@ def start_reload(self): """ Start reloading of tensors. It allocates new tensors on GPU and puts copy from CPU tasks on offload stream. + + Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor + and reconstructed in pop_tensor). """ self._validate_state(func_name="start_reload", allowed_states=["offload_finished"]) self.state = "reload_started" @@ -330,7 +335,6 @@ def start_reload(self): # cannot move tensors from pool of one stream to another without # calling cudaFree and cudaMalloc again. - # empty_like is defined also for QuantizedTensors. reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) self.offload_stream.wait_stream(torch.cuda.current_stream()) @@ -347,15 +351,30 @@ def start_reload(self): self.bwd_gpu_tensor_group ) - def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + def push_tensor( + self, tensor: torch.Tensor + ) -> int | torch.Tensor | tuple[list, list]: """ It is called when a tensor is saved for backward pass. If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group. If tensor is not offloaded, returns the tensor itself. + For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple. """ self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) + # For QuantizedTensor: decompose into component tensors, push each one recursively + if isinstance(tensor, QuantizedTensor): + # Make a copy because prepare_for_saving modifies the object (sets fields to None) + tensor_copy = tensor.detach() + # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, + # so the generic prepare_for_saving would not call tensor.prepare_for_saving() + saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() + push_results = [ + self.push_tensor(t) if t is not None else None for t in saved_tensors + ] + return (push_results, [tensor_obj]) + if self._check_if_offload(tensor): self.fwd_gpu_tensor_group.tensor_list.append(tensor) # The group is processed and offloaded at the end of the forward pass of current layer. @@ -370,23 +389,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: return len(self.fwd_gpu_tensor_group.tensor_list) - 1 return tensor - def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + def pop_tensor( + self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list] + ) -> torch.Tensor: """ It is called when a tensor is used in backward pass. Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish. + For QuantizedTensor (tuple input), reconstructs from component tensors. """ self._validate_state( func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"] ) - # 1. tensor not offloaded + # 1. tensor not offloaded (regular tensor returned as-is from push) if isinstance(tensor_or_tensor_id, torch.Tensor): return tensor_or_tensor_id - # 2. the layer was not offloaded at all + + # 2. QuantizedTensor case: tuple of (push_results, tensor_objs) + if isinstance(tensor_or_tensor_id, tuple): + push_results, tensor_objs = tensor_or_tensor_id + # Recursively pop each component + reloaded_tensors = [ + self.pop_tensor(pr) if pr is not None else None for pr in push_results + ] + # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy + tensor_obj = tensor_objs[0] + tensor_obj.restore_from_saved(reloaded_tensors) + return tensor_obj + + # 3. Regular tensor index case if self.state == "not_offloaded": return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] - # 3. the layer was offloaded + # 4. the layer was offloaded assert self.state == "reload_started" # wait for the tensor to be reloaded torch.cuda.current_stream().wait_event( @@ -419,6 +454,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool: ) return False + # Only offload tensors with at least 256k elements (~1MB for float32) + if t.numel() < 256 * 1024: + return False + return True return False diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf4..484d04c436 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -254,7 +254,8 @@ std::vector multi_tensor_quantize(const std::vector &ten std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, - std::vector quantizer_list); + std::vector quantizer_list, + bool disable_bulk_allocation = false); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ac541435c7..2c08087343 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -944,7 +944,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, - std::vector quantizer_list) { + std::vector quantizer_list, + bool disable_bulk_allocation) { init_extension(); // Check number of tensors @@ -996,22 +997,24 @@ std::vector split_quantize(const at::Tensor &tensor, enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsMXFP8Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + if (!disable_bulk_allocation) { + if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; + } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsMXFP8Quantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_MXFP8; + } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_NVFP4; + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } // Allocate output tensors diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0f450bc71..0cbb457173 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -248,7 +248,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list")); + py::arg("quantizer_list"), + py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 96cbca772c..49e2086238 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -487,8 +487,12 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, # pylint: disable=unused-argument + share_scales: bool = False, + like: Optional[QuantizedTensor] = None, # pylint: disable=unused-argument ) -> CurrentScalingTensorRef: assert len(shape) == 2, "shape is not 2d" + if share_scales: + raise ValueError("share_scales is not supported for CurrentScalingTensorRef") # Canonicalize tensor attributes if device is None: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c4d35a9c2c..94d6f149ac 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -143,7 +143,9 @@ def forward( inp_view = inp.reshape(-1, in_features) inputmats: list if fp8 and not debug: - inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) + inputmats = tex.split_quantize( + inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -347,6 +349,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, + disable_bulk_allocation=ctx.cpu_offloading, ) else: # Multi-tensor quantize @@ -354,6 +357,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, + disable_bulk_allocation=ctx.cpu_offloading, ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) @@ -438,7 +442,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats = tex.split_quantize( + inp_view, + ctx.m_splits, + ctx.input_quantizers, + disable_bulk_allocation=ctx.cpu_offloading, + ) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb..6cbfcd8859 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,7 +428,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - mark_not_offload(weight, weightmat, bias) + if cpu_offloading: + mark_not_offload(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index c9a4467a82..3988e58a2f 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -267,6 +267,8 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + requires_grad: bool = False, + pin_memory: bool = False, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" raise NotImplementedError( @@ -467,6 +469,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] + kwargs = kwargs or {} device = kwargs.get("device", tensor.device) requires_grad = kwargs.get("requires_grad", tensor.requires_grad) pin_memory = kwargs.get("pin_memory", False) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 01e03e5355..d9c7401065 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -231,6 +231,8 @@ def make_empty( device: Optional[torch.device] = None, requires_grad: bool = False, pin_memory: bool = False, + share_scales: bool = False, + like: Optional[QuantizedTensor] = None, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -244,16 +246,25 @@ def make_empty( # Allocate FP8 data data = None - scale_inv = None + rowwise_scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - pin_memory=pin_memory, - ) + if share_scales: + if ( + like is None + or not hasattr(like, "_rowwise_scale_inv") + or like._rowwise_scale_inv is None + ): + raise ValueError("share_scales requested but no rowwise scale tensor provided") + rowwise_scale_inv = like._rowwise_scale_inv + else: + scale_shape = self.get_scale_shape(shape, columnwise=False) + rowwise_scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + pin_memory=pin_memory, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -265,13 +276,22 @@ def make_empty( device=device, pin_memory=pin_memory, ) - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) - columnwise_scale_inv = torch.empty( - columnwise_scale_shape, - dtype=torch.float32, - device=device, - pin_memory=pin_memory, - ) + if share_scales: + if ( + like is None + or not hasattr(like, "_columnwise_scale_inv") + or like._columnwise_scale_inv is None + ): + raise ValueError("share_scales requested but no columnwise scale tensor provided") + columnwise_scale_inv = like._columnwise_scale_inv + else: + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + pin_memory=pin_memory, + ) # Construct FP8 tensor return Float8BlockwiseQTensor( @@ -279,7 +299,7 @@ def make_empty( dtype=dtype, fp8_dtype=self.dtype, rowwise_data=data, - rowwise_scale_inv=scale_inv, + rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6dcf9ae79a..893c2e8c65 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -138,7 +138,9 @@ def make_empty( shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple( + math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4 + ), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, @@ -151,7 +153,7 @@ def make_empty( dtype=dtype, fp8_dtype=self.dtype, rowwise_data=data, - rowwise_scale_inv=scale_inv, + rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 0c244628d6..eca7ca031e 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -289,6 +289,8 @@ def make_empty( device: Optional[torch.device] = None, pin_memory: bool = False, requires_grad: bool = False, + share_scales: bool = False, + like: Optional[QuantizedTensor] = None, ) -> NVFP4Tensor: # Canonicalize tensor attributes @@ -308,7 +310,7 @@ def make_empty( # Allocate FP4 data data = None - scale_inv = None + rowwise_scale_inv = None amax_rowwise = None if self.rowwise_usage: data = torch.empty( @@ -317,12 +319,24 @@ def make_empty( device=device, pin_memory=pin_memory, ) - scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty( - scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - # Allocate per tensor scale inverse. FP32 format. - amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + if share_scales: + if ( + like is None + or not hasattr(like, "_rowwise_scale_inv") + or like._rowwise_scale_inv is None + ): + raise ValueError("share_scales requested but no rowwise scale tensor provided") + rowwise_scale_inv = like._rowwise_scale_inv + amax_rowwise = getattr(like, "_amax_rowwise", None) + else: + scale_shape = self.get_scale_shape(shape, columnwise=False) + rowwise_scale_inv = torch.empty( + scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) + # Allocate per tensor scale inverse. FP32 format. + amax_rowwise = torch.zeros( + 1, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -338,20 +352,32 @@ def make_empty( device=device, pin_memory=pin_memory, ) - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) - columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - amax_columnwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory - ) + if share_scales: + if ( + like is None + or not hasattr(like, "_columnwise_scale_inv") + or like._columnwise_scale_inv is None + ): + raise ValueError( + "share_scales requested but no columnwise scale tensor provided" + ) + columnwise_scale_inv = like._columnwise_scale_inv + amax_columnwise = getattr(like, "_amax_columnwise", None) + else: + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) + amax_columnwise = torch.zeros( + 1, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Construct FP8 tensor return NVFP4Tensor( shape=shape, dtype=dtype, rowwise_data=data, - rowwise_scale_inv=scale_inv, + rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, From 6d2f43b44d6f10d452ea7c1b3af4da7c6eebfb52 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 14:21:14 +0000 Subject: [PATCH 02/12] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/tensor/float8_blockwise_tensor.py | 52 ++++++----------- .../pytorch/tensor/nvfp4_tensor.py | 56 +++++-------------- 2 files changed, 31 insertions(+), 77 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index d9c7401065..01e03e5355 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -231,8 +231,6 @@ def make_empty( device: Optional[torch.device] = None, requires_grad: bool = False, pin_memory: bool = False, - share_scales: bool = False, - like: Optional[QuantizedTensor] = None, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -246,25 +244,16 @@ def make_empty( # Allocate FP8 data data = None - rowwise_scale_inv = None + scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - if share_scales: - if ( - like is None - or not hasattr(like, "_rowwise_scale_inv") - or like._rowwise_scale_inv is None - ): - raise ValueError("share_scales requested but no rowwise scale tensor provided") - rowwise_scale_inv = like._rowwise_scale_inv - else: - scale_shape = self.get_scale_shape(shape, columnwise=False) - rowwise_scale_inv = torch.empty( - scale_shape, - dtype=torch.float32, - device=device, - pin_memory=pin_memory, - ) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + pin_memory=pin_memory, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -276,22 +265,13 @@ def make_empty( device=device, pin_memory=pin_memory, ) - if share_scales: - if ( - like is None - or not hasattr(like, "_columnwise_scale_inv") - or like._columnwise_scale_inv is None - ): - raise ValueError("share_scales requested but no columnwise scale tensor provided") - columnwise_scale_inv = like._columnwise_scale_inv - else: - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) - columnwise_scale_inv = torch.empty( - columnwise_scale_shape, - dtype=torch.float32, - device=device, - pin_memory=pin_memory, - ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + pin_memory=pin_memory, + ) # Construct FP8 tensor return Float8BlockwiseQTensor( @@ -299,7 +279,7 @@ def make_empty( dtype=dtype, fp8_dtype=self.dtype, rowwise_data=data, - rowwise_scale_inv=rowwise_scale_inv, + rowwise_scale_inv=scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index eca7ca031e..0c244628d6 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -289,8 +289,6 @@ def make_empty( device: Optional[torch.device] = None, pin_memory: bool = False, requires_grad: bool = False, - share_scales: bool = False, - like: Optional[QuantizedTensor] = None, ) -> NVFP4Tensor: # Canonicalize tensor attributes @@ -310,7 +308,7 @@ def make_empty( # Allocate FP4 data data = None - rowwise_scale_inv = None + scale_inv = None amax_rowwise = None if self.rowwise_usage: data = torch.empty( @@ -319,24 +317,12 @@ def make_empty( device=device, pin_memory=pin_memory, ) - if share_scales: - if ( - like is None - or not hasattr(like, "_rowwise_scale_inv") - or like._rowwise_scale_inv is None - ): - raise ValueError("share_scales requested but no rowwise scale tensor provided") - rowwise_scale_inv = like._rowwise_scale_inv - amax_rowwise = getattr(like, "_amax_rowwise", None) - else: - scale_shape = self.get_scale_shape(shape, columnwise=False) - rowwise_scale_inv = torch.empty( - scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - # Allocate per tensor scale inverse. FP32 format. - amax_rowwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory - ) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) + # Allocate per tensor scale inverse. FP32 format. + amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed columnwise_data = None @@ -352,32 +338,20 @@ def make_empty( device=device, pin_memory=pin_memory, ) - if share_scales: - if ( - like is None - or not hasattr(like, "_columnwise_scale_inv") - or like._columnwise_scale_inv is None - ): - raise ValueError( - "share_scales requested but no columnwise scale tensor provided" - ) - columnwise_scale_inv = like._columnwise_scale_inv - amax_columnwise = getattr(like, "_amax_columnwise", None) - else: - columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) - columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory - ) - amax_columnwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory - ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) + amax_columnwise = torch.zeros( + 1, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Construct FP8 tensor return NVFP4Tensor( shape=shape, dtype=dtype, rowwise_data=data, - rowwise_scale_inv=rowwise_scale_inv, + rowwise_scale_inv=scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, From e28c581669b419267636682fd7f5e196ea7295c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 14:22:42 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cpu_offload.py | 8 ++------ transformer_engine/pytorch/csrc/extensions/pybind.cpp | 3 +-- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4 +--- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index d6bff166b9..73dda74190 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -351,9 +351,7 @@ def start_reload(self): self.bwd_gpu_tensor_group ) - def push_tensor( - self, tensor: torch.Tensor - ) -> int | torch.Tensor | tuple[list, list]: + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]: """ It is called when a tensor is saved for backward pass. @@ -370,9 +368,7 @@ def push_tensor( # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, # so the generic prepare_for_saving would not call tensor.prepare_for_saving() saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() - push_results = [ - self.push_tensor(t) if t is not None else None for t in saved_tensors - ] + push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors] return (push_results, [tensor_obj]) if self._check_if_offload(tensor): diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0cbb457173..62337ee7e7 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -248,8 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list"), - py::arg("disable_bulk_allocation") = false); + py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 893c2e8c65..15a95e4124 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -138,9 +138,7 @@ def make_empty( shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple( - math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4 - ), + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, From 056f06e247b34192fc1820f310a108bc0387d061 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 14:25:04 +0000 Subject: [PATCH 04/12] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/custom_recipes/quantization_current_scaling.py | 4 ---- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py index 49e2086238..96cbca772c 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py @@ -487,12 +487,8 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, # pylint: disable=unused-argument - share_scales: bool = False, - like: Optional[QuantizedTensor] = None, # pylint: disable=unused-argument ) -> CurrentScalingTensorRef: assert len(shape) == 2, "shape is not 2d" - if share_scales: - raise ValueError("share_scales is not supported for CurrentScalingTensorRef") # Canonicalize tensor attributes if device is None: diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 15a95e4124..6dcf9ae79a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -151,7 +151,7 @@ def make_empty( dtype=dtype, fp8_dtype=self.dtype, rowwise_data=data, - rowwise_scale_inv=rowwise_scale_inv, + rowwise_scale_inv=scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, From c8e09840bcb62a4962fd47d2a3a1c1491f3cd8a0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 14:57:55 +0000 Subject: [PATCH 05/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/grouped_linear.py | 5 +---- transformer_engine/pytorch/quantized_tensor.py | 3 --- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 94d6f149ac..9c8efaabb0 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -349,7 +349,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, - disable_bulk_allocation=ctx.cpu_offloading, ) else: # Multi-tensor quantize @@ -357,7 +356,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, - disable_bulk_allocation=ctx.cpu_offloading, ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) @@ -445,8 +443,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = tex.split_quantize( inp_view, ctx.m_splits, - ctx.input_quantizers, - disable_bulk_allocation=ctx.cpu_offloading, + ctx.input_quantizers ) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3988e58a2f..c9a4467a82 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -267,8 +267,6 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" raise NotImplementedError( @@ -469,7 +467,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] - kwargs = kwargs or {} device = kwargs.get("device", tensor.device) requires_grad = kwargs.get("requires_grad", tensor.requires_grad) pin_memory = kwargs.get("pin_memory", False) From 3f45dcdc15b534e151bbf2c332551d4d5bf2a83c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 15:01:34 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9c8efaabb0..3b5bc73eb9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -440,11 +440,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize( - inp_view, - ctx.m_splits, - ctx.input_quantizers - ) + inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype From 04912bf024514826f31cc02485d396974af68516 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 7 Jan 2026 14:33:56 +0000 Subject: [PATCH 07/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/optimizers/fused_adam.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b5c87b4815..330e595ebd 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -14,6 +14,7 @@ from torch.distributed._tensor import DTensor import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier @@ -372,10 +373,12 @@ def _initialize_state( store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] + # Handle QuantizedTensor by dequantizing first + param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param if store_param_remainders: - data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) + data = torch.zeros_like(param_for_empty, dtype=torch.int16) else: - data = torch.empty(param.shape, dtype=dtype, device=param.device) + data = torch.empty_like(param_for_empty, dtype=dtype) if zero_buffer: data.zero_() From 04b104d11a1531e26b51ad4561f66941e949d4b5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 8 Jan 2026 13:50:47 +0100 Subject: [PATCH 08/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/cpu_offload.py | 39 +++++++++++++---------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 73dda74190..2bc9c77ede 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -361,17 +361,17 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, """ self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) - # For QuantizedTensor: decompose into component tensors, push each one recursively - if isinstance(tensor, QuantizedTensor): - # Make a copy because prepare_for_saving modifies the object (sets fields to None) - tensor_copy = tensor.detach() - # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, - # so the generic prepare_for_saving would not call tensor.prepare_for_saving() - saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() - push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors] - return (push_results, [tensor_obj]) - if self._check_if_offload(tensor): + # For QuantizedTensor: decompose into component tensors, push each one recursively + if isinstance(tensor, QuantizedTensor): + # Make a copy because prepare_for_saving modifies the object (sets fields to None) + tensor_copy = tensor.detach() + # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, + # so the generic prepare_for_saving would not call tensor.prepare_for_saving() + saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() + push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors] + return (push_results, [tensor_obj]) + self.fwd_gpu_tensor_group.tensor_list.append(tensor) # The group is processed and offloaded at the end of the forward pass of current layer. # To enable offloading of tensors faster we use self.offload_stream and record @@ -436,7 +436,11 @@ def release_all_memory(self): def _check_if_offload(self, t: torch.Tensor) -> bool: """ Check if tensor needs to be offloaded. - """ + """ + # Only offload tensors with at least 256k elements (~1MB for float32) + if t.numel() < 256 * 1024: + return False + if ( not isinstance(t, torch.nn.Parameter) and not getattr(t, "_TE_do_not_offload", False) @@ -449,11 +453,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool: " this tensor will be skipped." ) return False - - # Only offload tensors with at least 256k elements (~1MB for float32) - if t.numel() < 256 * 1024: - return False - return True return False @@ -627,6 +626,12 @@ def bwd_step(self, layer_num: int): for layer in self.start_reload_map[layer_num]: self.layer_states[layer].start_reload() + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead.""" + if not self.offload_layer_map.get(self.num_of_fwds, False): + return tensor + return self.layer_states[self.num_of_fwds].push_tensor(tensor) + class ManualOffloadSynchronizer(OffloadSynchronizer): """ @@ -672,7 +677,7 @@ def get_cpu_offload_context( offload_weights: bool = False, double_buffering: bool = False, # pylint: disable=unused-argument manual_synchronization: bool = False, - retain_pinned_cpu_buffers: bool = False, + retain_pinned_cpu_buffers: bool = True, offload_stream: Optional[torch.cuda.Stream] = None, ): """ From 3ddab74b53112d877f9f237c2557cee8928c2388 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 12:51:35 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cpu_offload.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 2bc9c77ede..9311f44e32 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -369,7 +369,9 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, # so the generic prepare_for_saving would not call tensor.prepare_for_saving() saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() - push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors] + push_results = [ + self.push_tensor(t) if t is not None else None for t in saved_tensors + ] return (push_results, [tensor_obj]) self.fwd_gpu_tensor_group.tensor_list.append(tensor) @@ -436,11 +438,11 @@ def release_all_memory(self): def _check_if_offload(self, t: torch.Tensor) -> bool: """ Check if tensor needs to be offloaded. - """ + """ # Only offload tensors with at least 256k elements (~1MB for float32) if t.numel() < 256 * 1024: return False - + if ( not isinstance(t, torch.nn.Parameter) and not getattr(t, "_TE_do_not_offload", False) From 14fbfaeb3f52f3d37173e8870694fe008de13b4b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 8 Jan 2026 18:21:50 +0100 Subject: [PATCH 10/12] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/quantized_tensor.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index c9a4467a82..aea9f08ee9 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -20,10 +20,6 @@ _stride_from_shape, ) -_quantized_tensor_cpu_supported_ops = ( - torch.ops.aten.empty_like.default, - torch.ops.aten.copy_.default, -) class QuantizedTensorStorage: @@ -539,15 +535,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - def check_if_cpu(arg): - if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu": - assert ( - func in _quantized_tensor_cpu_supported_ops - ), f"QuantizedTensor on CPU does not support this operation: {func}" - return arg - - args = tree_map(check_if_cpu, args) - # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) From 5f7675c97cc21713c1e10ee67b1a705f3146a510 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 17:22:59 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/quantized_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index aea9f08ee9..3e9ab288dc 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -21,7 +21,6 @@ ) - class QuantizedTensorStorage: r"""Base class for all TensorStorage classes. From ccf54b920abaf5b3c6dcb7011d45b9c0e7fc6d27 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 9 Jan 2026 16:53:37 +0100 Subject: [PATCH 12/12] test fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_cpu_offloading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 385998a8c5..0b4e3014f7 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -54,9 +54,11 @@ class Utils: + # Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements, + # so that they are offloaded to GPU. tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16) - _B = 64 - _S = 256 + _B = 128 + _S = 512 _H = 4 _D = 256