diff --git a/tests/pytorch/nvfp4/test_nvfp4_post_rht_amax_estimation_sanity.py b/tests/pytorch/nvfp4/test_nvfp4_post_rht_amax_estimation_sanity.py
new file mode 100644
index 00000000000..16aa8bf6f41
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_post_rht_amax_estimation_sanity.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import NVFP4Quantizer
+
+
+recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for NVFP4 quantization")
+def test_nvfp4_post_rht_amax_estimation_sanity() -> None:
+    """Sanity: when using post-RHT amax estimation, columnwise amax is scaled pre-RHT amax."""
+
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+
+    # Shape must satisfy NVFP4 constraints and RHT kernel constraints.
+    # rows % 64 == 0 and cols % 128 == 0 triggers the fast RHT-cast fusion path.
+    M, N = 128, 128
+    x = torch.randn((M, N), device="cuda", dtype=torch.bfloat16)
+
+    scale = 2.0
+    q = NVFP4Quantizer(
+        rowwise=True,
+        columnwise=True,
+        with_rht=True,
+        # Estimation path requires post-RHT amax kernel disabled.
+        with_post_rht_amax=False,
+        amax_estimation_scale=scale,
+        stochastic_rounding=False,
+    )
+
+    y = q(x)
+    assert y._amax_rowwise is not None
+    assert y._amax_columnwise is not None
+
+    amax_pre = torch.max(torch.abs(x)).to(torch.float32).view(1)
+    torch.testing.assert_close(y._amax_rowwise, amax_pre, atol=0.0, rtol=0.0)
+    torch.testing.assert_close(y._amax_columnwise, amax_pre * scale, atol=0.0, rtol=0.0)
diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h
index b1773a8db3a..78711b97278 100644
--- a/transformer_engine/common/include/transformer_engine/recipe.h
+++ b/transformer_engine/common/include/transformer_engine/recipe.h
@@ -99,6 +99,23 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
 void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
                                    const NVTEQuantizationConfig config, cudaStream_t stream);
 
+/*! \brief Scale a tensor's amax by a scalar.
+ *
+ * This is a lightweight utility intended for cases where the amax is
+ * derived/estimated from another amax value (e.g., post-transform amax
+ * estimated from pre-transform amax via a linear scale factor).
+ *
+ * If `columnwise` is true, scales `tensor.columnwise_amax` if present.
+ * Otherwise, scales `tensor.amax` if present. If the selected amax pointer
+ * is null, this function is a no-op.
+ *
+ *  \param[in,out] tensor      Tensor that owns the amax buffer(s).
+ *  \param[in]     columnwise  Whether to scale columnwise amax (true) or rowwise amax (false).
+ *  \param[in]     scale       Scalar multiplier applied to the amax value.
+ *  \param[in]     stream      CUDA stream used for the operation.
+ */
+void nvte_scale_amax(NVTETensor tensor, bool columnwise, float scale, cudaStream_t stream);
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.
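For intuition about what the test above asserts, here is a small CPU-only sketch (not part of the patch) that contrasts the true amax of a blockwise Hadamard transform with the scaled pre-transform estimate. It assumes a 16x16 transform block and omits the random sign mask, so it only illustrates the role of the estimate, not the exact RHT kernel:

    import torch


    def hadamard(n: int) -> torch.Tensor:
        """Sylvester construction of an n x n Hadamard matrix (n must be a power of two)."""
        h = torch.ones(1, 1)
        while h.shape[0] < n:
            h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
        return h


    torch.manual_seed(0)
    x = torch.randn(128, 128)
    h = hadamard(16) / 16**0.5                      # orthonormal 16x16 block (no random signs here)
    x_rht = (x.reshape(-1, 16) @ h).reshape_as(x)   # blockwise transform along the last dimension

    amax_pre = x.abs().amax()          # what the estimation path measures
    amax_post = x_rht.abs().amax()     # what the dedicated RHT+amax kernel would measure
    amax_est = amax_pre * 2.0          # the estimate: pre-RHT amax * amax_estimation_scale
    print(f"pre={amax_pre:.3f}  post={amax_post:.3f}  estimate={amax_est:.3f}")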
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 98e2a29df85..b541b7c696a 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -65,6 +65,9 @@ class QParams:
     amax_epsilon: optional minimum value of abs max
     random_hadamard_transform: whether to use random hadamard transform
     stochastic_rounding: whether to use stocastic rounding
+    amax_estimation_scale: scale factor for estimating post-RHT amax from pre-RHT amax.
+        When None, true post-RHT amax is computed (default behavior).
+        When set to a float, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
     """
 
     power_2_scale: bool = False
@@ -72,6 +75,7 @@ class QParams:
     random_hadamard_transform: bool = False
     stochastic_rounding: bool = False
     fp4_2d_quantization: bool = False
+    amax_estimation_scale: Optional[float] = None
 
     def __repr__(self) -> str:
         return (
@@ -79,7 +83,8 @@ def __repr__(self) -> str:
             f"amax_epsilon={self.amax_epsilon},\n"
             f"random_hadamard_transform={self.random_hadamard_transform},\n"
             f"stochastic_rounding={self.stochastic_rounding},\n"
-            f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
+            f"fp4_2d_quantization={self.fp4_2d_quantization},\n"
+            f"amax_estimation_scale={self.amax_estimation_scale}\n)"
         )
 
 
@@ -428,6 +433,16 @@ class NVFP4BlockScaling(Recipe):
        If set to `True`, stochastic rounding is disabled during quantization for all tensors.
    disable_2d_quantization : bool, default = False
        If set to `True`, 1D block scaling with block size 16 is used for all tensors.
+    use_post_rht_amax_estimation : bool, default = False
+        **EXPERIMENTAL**: If set to `True`, post-RHT amax is estimated from pre-RHT amax
+        instead of being computed by a separate RHT+amax kernel. This can reduce the
+        number of kernel launches but may affect numerical accuracy.
+    post_rht_amax_estimation_scale_fwd_inp : float, default = 2.0
+        Scale factor for estimating post-RHT amax for forward input activations.
+        Only used when `use_post_rht_amax_estimation=True`.
+    post_rht_amax_estimation_scale_bwd_grad : float, default = 1.0
+        Scale factor for estimating post-RHT amax for backward gradients.
+        Only used when `use_post_rht_amax_estimation=True`.
     """
 
     # Configuration envvars
@@ -444,10 +459,33 @@ class NVFP4BlockScaling(Recipe):
     fp8_dpa: bool = False
     fp8_mha: bool = False
 
+    # Experimental: Post-RHT amax estimation
+    use_post_rht_amax_estimation: bool = (
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION", "0") == "1"
+    )
+    post_rht_amax_estimation_scale_fwd_inp: float = float(
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE", "2.0")
+    )
+    post_rht_amax_estimation_scale_bwd_grad: float = float(
+        os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE", "1.0")
+    )
+
     def __post_init__(self) -> None:
         assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
         assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
 
+        # Determine amax estimation scales (None = use true post-RHT amax)
+        amax_scale_fwd_inp = (
+            self.post_rht_amax_estimation_scale_fwd_inp
+            if self.use_post_rht_amax_estimation
+            else None
+        )
+        amax_scale_bwd_grad = (
+            self.post_rht_amax_estimation_scale_bwd_grad
+            if self.use_post_rht_amax_estimation
+            else None
+        )
+
         # Quantization params
         # Note: RHT is currently only applied to column-wise usage so that
         # it can be used for wgrad GEMM.
@@ -455,6 +493,7 @@ def __post_init__(self) -> None:
             random_hadamard_transform=not self.disable_rht,
             stochastic_rounding=False,
             fp4_2d_quantization=False,
+            amax_estimation_scale=amax_scale_fwd_inp,
         )
         self.fp4_quant_fwd_weight = QParams(
             random_hadamard_transform=False,
@@ -465,6 +504,7 @@ def __post_init__(self) -> None:
             random_hadamard_transform=not self.disable_rht,
             stochastic_rounding=not self.disable_stochastic_rounding,
             fp4_2d_quantization=False,
+            amax_estimation_scale=amax_scale_bwd_grad,
         )
 
     def __repr__(self) -> str:
@@ -477,6 +517,7 @@ def __repr__(self) -> str:
             f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
             f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
             f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
+            f"use_post_rht_amax_estimation={self.use_post_rht_amax_estimation}, "
         )
 
 
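A minimal sketch of opting in through the new environment variables (illustrative only; it assumes te.fp8_autocast is the entry point used for NVFP4 recipes). Note that the recipe defaults read the environment at class-definition time, so the variables must be set before the recipe module is imported:

    import os

    # Set these before importing transformer_engine, since the dataclass defaults
    # read the environment when the recipe module is imported.
    os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION"] = "1"
    os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE"] = "2.0"  # forward input activations
    os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE"] = "1.0"  # backward gradients

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    recipe = NVFP4BlockScaling()
    assert recipe.use_post_rht_amax_estimation

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        ...  # run forward/backward through TE modules as usual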
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu
index ee2c845159b..f0d871dd82c 100644
--- a/transformer_engine/common/recipe/current_scaling.cu
+++ b/transformer_engine/common/recipe/current_scaling.cu
@@ -188,6 +188,46 @@ void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor out
   compute_amax_impl(input_, output_, stream, config_);
 }
 
+namespace {
+
+__global__ void scale_amax_kernel(float *amax_ptr, float scale) { amax_ptr[0] *= scale; }
+
+}  // namespace
+
+void nvte_scale_amax(NVTETensor tensor_, bool columnwise, float scale, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_scale_amax);
+  NVTE_CHECK(tensor_ != nullptr, "Invalid tensor (got NULL)");
+  auto &tensor = *transformer_engine::convertNVTETensorCheck(tensor_);
+
+  // Pick amax pointer
+  void *amax_dptr = nullptr;
+  if (columnwise) {
+    amax_dptr = tensor.columnwise_amax.dptr;
+  } else {
+    amax_dptr = tensor.amax.dptr;
+  }
+  if (amax_dptr == nullptr) {
+    return;
+  }
+  NVTE_CHECK((!columnwise && tensor.amax.numel() == 1) ||
+                 (columnwise && tensor.columnwise_amax.numel() == 1),
+             "Invalid amax buffer (expected 1 element)");
+  NVTE_CHECK(
+      (!columnwise && tensor.amax.dtype == transformer_engine::DType::kFloat32) ||
+          (columnwise && tensor.columnwise_amax.dtype == transformer_engine::DType::kFloat32),
+      "Invalid amax dtype (expected FP32)");
+
+  // No-op for scale==1 to save a launch
+  if (scale == 1.0f) {
+    return;
+  }
+  // Scale should be positive for amax estimation use-cases
+  NVTE_CHECK(scale > 0.0f, "nvte_scale_amax requires scale > 0 (got ", scale, ")");
+
+  scale_amax_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<float *>(amax_dptr), scale);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 namespace transformer_engine {
 
 namespace {
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 978bee52dc1..50b560ea7f0 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -293,6 +293,9 @@ class NVFP4Quantizer : public Quantizer {
   // random hadamard transform
   bool with_rht;
   bool with_post_rht_amax;
+  // Optional: estimate post-RHT amax from pre-RHT amax using a linear scale
+  bool with_amax_estimation;
+  float amax_estimation_scale;
   // 2D block scaling
   bool with_2d_quantization;
   bool stochastic_rounding;
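The new nvte_scale_amax utility above is essentially a guarded single-element in-place multiply. A rough torch analogue of its semantics (illustrative only; the function name here is not part of any API):

    from typing import Optional

    import torch


    def scale_amax(amax: Optional[torch.Tensor], scale: float) -> None:
        """Torch analogue of nvte_scale_amax (illustrative only)."""
        if amax is None:                 # null amax pointer -> silent no-op
            return
        assert amax.numel() == 1, "expected a single-element amax buffer"
        assert amax.dtype == torch.float32, "expected an FP32 amax buffer"
        if scale == 1.0:                 # skip the (one-thread) kernel launch
            return
        assert scale > 0.0, "scale must be positive"
        amax.mul_(scale)                 # in-place, like scale_amax_kernel<<<1, 1>>>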
diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp
index 14cc084c0c7..c988aa9f678 100644
--- a/transformer_engine/pytorch/csrc/extensions/activation.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !nvfp4_quantizer_cpp->with_amax_estimation) {
+      // True post-RHT amax is handled within NVFP4 quantizer
       impl = Impl::UNFUSED;
     } else {
       impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
@@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !nvfp4_quantizer_cpp->with_amax_estimation) {
+      // True post-RHT amax is handled within NVFP4 quantizer
       impl = Impl::UNFUSED;
     } else {
       impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp
index b0435d27230..e3aebb377f2 100644
--- a/transformer_engine/pytorch/csrc/extensions/bias.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -152,8 +152,9 @@ std::vector<py::object> dact_dbias(
   } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !nvfp4_quantizer_cpp->with_amax_estimation) {
+      // True post-RHT amax is handled within NVFP4 quantizer
       impl = Impl::UNFUSED;
     } else {
       impl = Impl::FUSED_DACT_AMAX_NVFP4;
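The dispatch changes in activation.cpp and bias.cpp all follow the same pattern. In Python-like pseudocode (a sketch, not an exposed API), the NVFP4 branch now falls back to the unfused path only when the dedicated post-RHT amax kernel is actually going to run:

    def use_unfused_nvfp4_path(q) -> bool:
        # Only the true post-RHT amax kernel forces the unfused path. The estimation
        # path only needs a pre-RHT amax, so the fused activation/bias + amax kernel
        # can still be used. (with_post_rht_amax and with_amax_estimation are mutually
        # exclusive in practice; the extra check is defensive.)
        return q.with_rht and q.with_post_rht_amax and not q.with_amax_estimation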
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index ac541435c7e..923dd904b9e 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -17,6 +17,7 @@
 #include "../extensions.h"
 #include "common.h"
 #include "pybind.h"
+#include "transformer_engine/recipe.h"
 #include "transformer_engine/transformer_engine.h"
 
 namespace transformer_engine {
@@ -827,8 +828,39 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
           input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(),
           num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream);
     });
+  } else if (quantizer.with_amax_estimation) {
+    // Consume/compute pre-RHT amax, and later estimate post-RHT amax from it
+    NVTE_SCOPED_GIL_RELEASE({
+      for (size_t i = 0; i < num_tensors; ++i) {
+        if (input_list[i].numel() == 0) {
+          continue;
+        }
+        nvte_compute_amax_with_config(input_list[i].data(), output_list[i].data(),
+                                      quant_config_list[i], stream);
+
+        auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr;
+        auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr;
+        void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
+        if (amax_ptr != nullptr) {
+          if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
+            NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
+                                            cudaMemcpyDeviceToDevice, stream));
+          }
+          if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
+            NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
+                                            cudaMemcpyDeviceToDevice, stream));
+          }
+        }
+
+        // Estimate post-RHT amax for columnwise path via scaling.
+        if (quantizer.with_amax_estimation && quantizer.columnwise_usage) {
+          nvte_scale_amax(output_list[i].data(), /*columnwise=*/true,
+                          quantizer.amax_estimation_scale, stream);
+        }
+      }
+    });
   } else {
-    // RHT is enabled, but amax is pre-RHT amax
+    // with_rht but not with_post_rht_amax and not using estimation
     NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax");
   }
 
diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index 3c5c17fc6f2..2d1332272d0 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -126,8 +126,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !nvfp4_quantizer_cpp->with_amax_estimation) {
+      // True post-RHT amax is handled within NVFP4 quantizer
       impl = Impl::UNFUSED;
     } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       // TE kernel supports amax output
@@ -355,8 +356,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
     auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
     NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
-    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
-      // Post-RHT amax is handled within NVFP4 quantizer
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
+        !nvfp4_quantizer_cpp->with_amax_estimation) {
+      // True post-RHT amax is handled within NVFP4 quantizer
       impl = Impl::UNFUSED;
     } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       // TE kernel supports amax output
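The estimation branch added to split_quantize_nvfp4_impl boils down to: compute one pre-RHT amax per tensor, mirror it into both amax buffers, then scale only the columnwise copy. A torch sketch of that data flow (illustrative, not the extension API):

    import torch


    def estimated_amaxes(x: torch.Tensor, scale: float, columnwise_usage: bool = True):
        """Return (rowwise_amax, columnwise_amax) the way the estimation branch fills them."""
        amax_pre = x.abs().amax().to(torch.float32).reshape(1)   # one pre-RHT amax kernel
        amax_rowwise = amax_pre.clone()                          # mirrored pre-RHT amax
        amax_columnwise = amax_pre.clone()
        if columnwise_usage:
            amax_columnwise *= scale                             # nvte_scale_amax on the columnwise buffer
        return amax_rowwise, amax_columnwise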
"amax_estimation_scale")) { + auto aes = quantizer.attr("amax_estimation_scale"); + if (!aes.is_none()) { + this->with_amax_estimation = true; + this->amax_estimation_scale = aes.cast(); + } + } this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); @@ -1510,16 +1521,45 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou if (input.dtype() != DType::kBFloat16) { NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); } + NVTE_CHECK(!(this->with_post_rht_amax && this->with_amax_estimation), + "Invalid NVFP4 config: cannot use post-RHT amax kernel and amax estimation"); + if (this->with_post_rht_amax) { - // We need: + // True post-RHT amax path: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for RHT(input.t) NVTE_SCOPED_GIL_RELEASE({ nvte_hadamard_transform_amax(input.data(), out.data(), 0, this->rht_matrix_random_sign_mask_t, stream); }); + } else if (this->with_amax_estimation) { + // Consume/compute pre-RHT amax, and later estimate post-RHT amax from it + if (compute_amax) { + // Compute amax of input tensor (pre-RHT) into an available amax pointer + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + // Compute amax of input tensor + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Mirror pre-RHT amax to both row-wise and column-wise amax buffers + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + // else: pre-RHT amax is assumed to already be populated (e.g., via fused op + quantize_with_amax) } else { - // raise error since it's not supported yet + // with_rht but not with_post_rht_amax and not using estimation NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); } } else { // Without RHT @@ -1570,6 +1610,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); } + // If enabled, estimate post-RHT amax for columnwise path by scaling the (pre-RHT) amax. + // This is intentionally done as a small standalone kernel to avoid touching compute kernels. 
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index fbe2ee6d1cf..2a9e1ca86e6 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1329,7 +1329,10 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 rowwise=True,
                 columnwise=True,
                 with_rht=qparams.random_hadamard_transform,
-                with_post_rht_amax=qparams.random_hadamard_transform,
+                with_post_rht_amax=(
+                    qparams.random_hadamard_transform and qparams.amax_estimation_scale is None
+                ),
+                amax_estimation_scale=qparams.amax_estimation_scale,
                 with_2d_quantization=qparams.fp4_2d_quantization,
                 stochastic_rounding=qparams.stochastic_rounding,
             )
@@ -1343,7 +1346,11 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 rowwise=True,
                 columnwise=True,
                 with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
-                with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
+                with_post_rht_amax=(
+                    self.recipe.fp4_quant_bwd_grad.random_hadamard_transform
+                    and self.recipe.fp4_quant_bwd_grad.amax_estimation_scale is None
+                ),
+                amax_estimation_scale=self.recipe.fp4_quant_bwd_grad.amax_estimation_scale,
                 with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
                 stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
             )
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 0c244628d65..effb3f0c41f 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -117,6 +117,7 @@ class NVFP4Quantizer(Quantizer):
     """Random Hadamard Transform"""
     with_rht: bool
     with_post_rht_amax: bool
+    amax_estimation_scale: Optional[float]
     """amax reduction options"""
     with_amax_reduction: bool
    amax_reduction_group: Optional[dist_group_type]
@@ -140,6 +141,7 @@ def __init__(
         amax_reduction_group: Optional[dist_group_type] = None,
         with_rht: bool = False,
         with_post_rht_amax: bool = False,
+        amax_estimation_scale: Optional[float] = None,
         with_2d_quantization: bool = False,
         stochastic_rounding: bool = False,
         with_random_sign_mask: bool = True,
@@ -148,6 +150,7 @@ def __init__(
         self.dtype = fp4_dtype
         self.with_rht = with_rht
         self.with_post_rht_amax = with_post_rht_amax
+        self.amax_estimation_scale = amax_estimation_scale
         self.with_amax_reduction = with_amax_reduction
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
@@ -189,6 +192,7 @@ def copy(self) -> NVFP4Quantizer:
             amax_reduction_group=self.amax_reduction_group,
             with_rht=self.with_rht,
             with_post_rht_amax=self.with_post_rht_amax,
+            amax_estimation_scale=self.amax_estimation_scale,
             with_2d_quantization=self.with_2d_quantization,
             stochastic_rounding=self.stochastic_rounding,
         )
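Putting the Python pieces together: when a QParams carries an amax_estimation_scale, _make_quantizer turns off with_post_rht_amax and forwards the scale, which is exactly the combination the sanity test at the top of this patch constructs directly. A condensed sketch of that wiring (the helper below is illustrative, not part of the module):

    from typing import Optional


    def quantizer_flags(random_hadamard_transform: bool, amax_estimation_scale: Optional[float]):
        """Mirror of how _make_quantizer derives NVFP4Quantizer arguments from QParams."""
        return dict(
            with_rht=random_hadamard_transform,
            # The dedicated RHT+amax kernel is only used when no estimation scale is set.
            with_post_rht_amax=random_hadamard_transform and amax_estimation_scale is None,
            amax_estimation_scale=amax_estimation_scale,
        )


    # Example: forward activations with RHT and estimation scale 2.0
    print(quantizer_flags(True, 2.0))
    # {'with_rht': True, 'with_post_rht_amax': False, 'amax_estimation_scale': 2.0}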