From 640b53b0fb63f53a9a0e57789d3d9f0eb1f84ab8 Mon Sep 17 00:00:00 2001 From: blepping Date: Sat, 6 Sep 2025 21:34:37 -0600 Subject: [PATCH 1/8] Experimental Triton support for Q8_0 and Q4_K --- dequant.py | 12 ++- dequant_triton.py | 269 ++++++++++++++++++++++++++++++++++++++++++++++ nodes.py | 29 +++++ 3 files changed, 308 insertions(+), 2 deletions(-) create mode 100644 dequant_triton.py diff --git a/dequant.py b/dequant.py index 9e545b7..2cf699b 100644 --- a/dequant.py +++ b/dequant.py @@ -3,6 +3,14 @@ import torch from tqdm import tqdm +ALLOW_TRITON = True +try: + from . import dequant_triton + triton_dequantize_functions=dequant_triton.dequantize_functions +except Exception as exc: + print(f"\nGGUF: Failed to enable Triton: {exc}") + triton_dequantize_functions={} + TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) @@ -18,7 +26,7 @@ def dequantize_tensor(tensor, dtype=None, dequant_dtype=None): if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) - elif qtype in dequantize_functions: + if qtype in dequantize_functions: dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) else: @@ -32,7 +40,7 @@ def dequantize(data, qtype, oshape, dtype=None): Dequantize tensor back to usable shape/dtype """ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] - dequantize_blocks = dequantize_functions[qtype] + dequantize_blocks = (ALLOW_TRITON and triton_dequantize_functions.get(qtype)) or dequantize_functions[qtype] rows = data.reshape( (-1, data.shape[-1]) diff --git a/dequant_triton.py b/dequant_triton.py new file mode 100644 index 0000000..3548682 --- /dev/null +++ b/dequant_triton.py @@ -0,0 +1,269 @@ +import torch + +import triton +import triton.language as tl + +import gguf + +TORCH_TO_TRITON_DTYPE_MAP = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + +# K Quants # +QK_K = 256 +K_SCALE_SIZE = 12 + + +@triton.autotune( + configs=[ + # Test different numbers of GGUF blocks per program instance + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=8), + ], + key=["n_total_blocks"], # Tune based on the total number of blocks +) +@triton.jit +def dequantize_q8_0_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + GGUF_BLOCK_SIZE: tl.constexpr, + GGUF_TYPE_SIZE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles + OUT_DTYPE: tl.constexpr, +): + # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks + pid = tl.program_id(axis=0) + + # Starting GGUF block index for this program + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Create offsets for the weights within a GGUF block (0, 1, ..., 31) + weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) + + # Loop over the N blocks assigned to this program + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + + # Boundary check to avoid processing padding blocks + if current_block_idx < n_total_blocks: + # Pointer to the start of the current GGUF block in the input tensor + block_start_ptr = q_tensor_ptr 
+ current_block_idx * GGUF_TYPE_SIZE + + # Load scale (d) + uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) + uint16_val = tl.load(uint16_ptr) + scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) + scale = scale_fp16.to(tl.float32) + + # Load weights (x) + q_weights_ptr = block_start_ptr + 2 + uint8_weights = tl.load(q_weights_ptr + weight_indices) + q_weights = uint8_weights.to(tl.int8) + + # Dequantize + dequantized_weights = q_weights.to(OUT_DTYPE) * scale + # dequantized_weights = q_weights.to(tl.float32) * scale + + # Store the result + output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE + tl.store(output_start_ptr + weight_indices, dequantized_weights) + + +def dequantize_blocks_Q8_0_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + GGUF_BLOCK_SIZE = 32 + GGUF_TYPE_SIZE = 34 + + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % GGUF_TYPE_SIZE == 0 + n_total_blocks = n_elements // GGUF_TYPE_SIZE + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * GGUF_BLOCK_SIZE,), + dtype=dtype, + device=blocks.device, + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q8_0_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + GGUF_BLOCK_SIZE=GGUF_BLOCK_SIZE, + GGUF_TYPE_SIZE=GGUF_TYPE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q4_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + Q4_K_TYPE_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + qs_chunk_offsets = tl.arange(0, 32) + store_offsets = tl.arange(0, 32) + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + block_start_ptr = q_tensor_ptr + current_block_idx * Q4_K_TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + OUT_DTYPE + ) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qs_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + + # Process in 4 chunks of 64 values + for k_chunk in range(4): + # Scale indices for low (a) and high (b) nibbles + k_idx_a = 2 * k_chunk + k_idx_b = 2 * k_chunk + 1 + + # --- Calculate Scale A (for low nibbles) --- + if k_idx_a < 4: + d_sc_byte_a = (d_sc_word >> (k_idx_a * 8)) & 0xFF + m_byte_a = (m_word >> (k_idx_a * 8)) & 0xFF + sc_a = d_sc_byte_a & 0x3F + m_a = m_byte_a & 0x3F + 
else: + k_prime_a = k_idx_a - 4 + d_sc_byte_a = (d_sc_word >> (k_prime_a * 8)) & 0xFF + m_byte_a = (m_word >> (k_prime_a * 8)) & 0xFF + m_sc_byte_a = (m_sc_word >> (k_prime_a * 8)) & 0xFF + sc_a = (m_sc_byte_a & 0x0F) | ((d_sc_byte_a >> 2) & 0x30) + m_a = (m_sc_byte_a >> 4) | ((m_byte_a >> 2) & 0x30) + + # --- Calculate Scale B (for high nibbles) --- + if k_idx_b < 4: + d_sc_byte_b = (d_sc_word >> (k_idx_b * 8)) & 0xFF + m_byte_b = (m_word >> (k_idx_b * 8)) & 0xFF + sc_b = d_sc_byte_b & 0x3F + m_b = m_byte_b & 0x3F + else: + k_prime_b = k_idx_b - 4 + d_sc_byte_b = (d_sc_word >> (k_prime_b * 8)) & 0xFF + m_byte_b = (m_word >> (k_prime_b * 8)) & 0xFF + m_sc_byte_b = (m_sc_word >> (k_prime_b * 8)) & 0xFF + sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) + m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) + + current_d_a = d * sc_a.to(OUT_DTYPE) + current_dm_a = dmin * m_a.to(OUT_DTYPE) + current_d_b = d * sc_b.to(OUT_DTYPE) + current_dm_b = dmin * m_b.to(OUT_DTYPE) + + # Load 32 bytes of quantized data + chunk_qs_ptr = qs_start_ptr + k_chunk * 32 + qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) + + qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) + qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + + dequant_low = current_d_a * qs_low - current_dm_a + dequant_high = current_d_b * qs_high - current_dm_b + + # Store results contiguously + output_chunk_ptr = output_start_ptr + k_chunk * 64 + tl.store(output_chunk_ptr + store_offsets, dequant_low.to(OUT_DTYPE)) + tl.store( + output_chunk_ptr + 32 + store_offsets, dequant_high.to(OUT_DTYPE) + ) + + +def dequantize_blocks_Q4_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + Q4_K_TYPE_SIZE = 144 + + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % Q4_K_TYPE_SIZE == 0 + n_total_blocks = n_elements // Q4_K_TYPE_SIZE + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q4_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + Q4_K_TYPE_SIZE=Q4_K_TYPE_SIZE, + K_SCALE_SIZE=K_SCALE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + +dequantize_functions = { + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, +} + +__all__ = ("dequantize_functions",) diff --git a/nodes.py b/nodes.py index 4159142..dc962c6 100644 --- a/nodes.py +++ b/nodes.py @@ -16,6 +16,8 @@ from .loader import gguf_sd_loader, gguf_clip_loader from .dequant import is_quantized, is_torch_compatible +from . 
import dequant + def update_folder_names_and_paths(key, targets=[]): # check for existing key base = folder_paths.folder_names_and_paths.get(key, ([], {})) @@ -295,6 +297,32 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) +class GGUFTritonToggle: + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "passthrough_model": ("MODEL",), + "enabled": ( + "BOOLEAN", + {"default": bool(dequant.triton_dequantize_functions)}, + ) + } + } + + TITLE = "Triton toggle (GGUF)" + RETURN_TYPES = ("MODEL",) + FUNCTION = "go" + CATEGORY = "hacks" + + @classmethod + def go(cls, *, enabled: bool, passthrough_model: object) -> tuple[object]: + dequant.ALLOW_TRITON = dequant.triton_dequantize_functions and enabled + if enabled: + print(f"\nGGUF: Enabling Triton, supported quants: {tuple(dequant.triton_dequantize_functions)}") + return (passthrough_model.clone(),) + + NODE_CLASS_MAPPINGS = { "UnetLoaderGGUF": UnetLoaderGGUF, "CLIPLoaderGGUF": CLIPLoaderGGUF, @@ -302,4 +330,5 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable "TripleCLIPLoaderGGUF": TripleCLIPLoaderGGUF, "QuadrupleCLIPLoaderGGUF": QuadrupleCLIPLoaderGGUF, "UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced, + "GGUFTritonToggle": GGUFTritonToggle, } From ce4c9b258c15e99affceab85e6987da4973262a7 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 00:41:17 -0600 Subject: [PATCH 2/8] Add Q6_K Triton kernel --- dequant_triton.py | 133 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 9 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 3548682..d781a6a 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -71,7 +71,6 @@ def dequantize_q8_0_kernel( # Dequantize dequantized_weights = q_weights.to(OUT_DTYPE) * scale - # dequantized_weights = q_weights.to(tl.float32) * scale # Store the result output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE @@ -136,7 +135,7 @@ def dequantize_q4_k_kernel( out_tensor_ptr, n_total_blocks, QK_K: tl.constexpr, - Q4_K_TYPE_SIZE: tl.constexpr, + TYPE_SIZE: tl.constexpr, K_SCALE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, OUT_DTYPE: tl.constexpr, @@ -150,7 +149,7 @@ def dequantize_q4_k_kernel( for i in range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: - block_start_ptr = q_tensor_ptr + current_block_idx * Q4_K_TYPE_SIZE + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) @@ -228,13 +227,11 @@ def dequantize_blocks_Q4_K_triton( type_size: int, dtype=None, ) -> torch.Tensor: - Q4_K_TYPE_SIZE = 144 - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() n_elements = blocks.numel() - assert n_elements % Q4_K_TYPE_SIZE == 0 - n_total_blocks = n_elements // Q4_K_TYPE_SIZE + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size dtype = dtype or torch.float32 triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) @@ -253,7 +250,7 @@ def grid(meta): out_tensor, n_total_blocks, QK_K=QK_K, - Q4_K_TYPE_SIZE=Q4_K_TYPE_SIZE, + TYPE_SIZE=type_size, K_SCALE_SIZE=K_SCALE_SIZE, OUT_DTYPE=triton_dtype, ) @@ -261,9 +258,127 @@ def grid(meta): return 
out_tensor.reshape(n_total_blocks, -1) +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q6_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + TYPE_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + offsets_32 = tl.arange(0, 32) + mask_16 = offsets_32 < 16 + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + + d_ptr = block_start_ptr + 208 + scales_ptr = block_start_ptr + 192 + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to( + tl.float32 + ) + + # Process block in 8 chunks of 32 values + for chunk_idx in range(8): + # 1. Calculate ql source data and unpack + ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 + ql_ptr = block_start_ptr + ql_byte_offset + ql_32_bytes = tl.load(ql_ptr + offsets_32) + + use_low_nibbles = (chunk_idx // 2) % 2 == 0 + if use_low_nibbles: + ql_vec_32 = (ql_32_bytes & 0x0F).to(tl.int8) + else: + ql_vec_32 = (ql_32_bytes >> 4).to(tl.int8) + + # 2. Calculate qh source data and unpack + qh_byte_offset = (chunk_idx // 4) * 32 + qh_ptr = block_start_ptr + 128 + qh_byte_offset + qh_32_bytes = tl.load(qh_ptr + offsets_32) + + bit_shift = (chunk_idx % 4) * 2 + qh_vec_32 = ((qh_32_bytes >> bit_shift) & 0x03).to(tl.int8) + + # 3. Combine and dequantize + q_vec_32 = (ql_vec_32 | (qh_vec_32 << 4)) - 32 + + # 4. Load and apply correct scales + scale_0_ptr = scales_ptr + chunk_idx * 2 + scale_1_ptr = scales_ptr + chunk_idx * 2 + 1 + scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to( + tl.float32 + ) + scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to( + tl.float32 + ) + + scales_32 = tl.where(mask_16, scale_0, scale_1) + dequant_32 = q_vec_32.to(OUT_DTYPE) * scales_32 + + # 5. Store result + output_ptr = output_start_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) + + +def dequantize_blocks_Q6_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q6_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + TYPE_SIZE=type_size, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + dequantize_functions = { - gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, + # Q8_0 simply seems than the PyTorch implementation. 
+ # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } __all__ = ("dequantize_functions",) From 2ac36d4c6285ac9402bf656f89c090cc7c9204c0 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:01:36 -0600 Subject: [PATCH 3/8] Add Q5_K Triton kernel --- dequant_triton.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/dequant_triton.py b/dequant_triton.py index d781a6a..b51d793 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -258,6 +258,125 @@ def grid(meta): return out_tensor.reshape(n_total_blocks, -1) +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q5_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + TYPE_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + offsets_32 = tl.arange(0, 32) + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # Pointers and initial loads + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(tl.float32) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + tl.float32 + ) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qh_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + qs_start_ptr = qh_start_ptr + QK_K // 8 + + qh_bytes_all = tl.load(qh_start_ptr + offsets_32) + + # Process in 8 chunks of 32 values + for chunk_idx in range(8): + # 1. Unpack scale and min for this chunk + if chunk_idx < 4: + sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F + m = ((m_word >> (chunk_idx * 8)) & 0xFF) & 0x3F + else: + k_prime = chunk_idx - 4 + d_sc_byte = (d_sc_word >> (k_prime * 8)) & 0xFF + m_byte = (m_word >> (k_prime * 8)) & 0xFF + m_sc_byte = (m_sc_word >> (k_prime * 8)) & 0xFF + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + + final_d = d * sc.to(tl.float32) + final_dm = dmin * m.to(tl.float32) + + # 2. Unpack ql (lower 4 bits) for this chunk + qs_byte_offset = (chunk_idx // 2) * 32 + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset + offsets_32) + use_low_nibbles = chunk_idx % 2 == 0 + ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) + + # 3. Unpack qh (higher 1 bit) for this chunk + qh = (qh_bytes_all >> chunk_idx) & 0x01 + + # 4. 
Combine, dequantize, and store + q = ql.to(tl.uint8) | (qh.to(tl.uint8) << 4) + dequant_32 = final_d * q.to(tl.float32) - final_dm + + output_ptr = output_start_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) + + +def dequantize_blocks_Q5_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q5_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + TYPE_SIZE=type_size, + K_SCALE_SIZE=K_SCALE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + @triton.autotune( configs=[ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), @@ -378,6 +497,7 @@ def grid(meta): # Q8_0 simply seems than the PyTorch implementation. # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, + # gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } From a5d5f6d89cd35f444587bb421084abc71e8dd145 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:03:08 -0600 Subject: [PATCH 4/8] Actually enable the Q5_K kernel --- dequant_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dequant_triton.py b/dequant_triton.py index b51d793..f356671 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -497,7 +497,7 @@ def grid(meta): # Q8_0 simply seems than the PyTorch implementation. 
# gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, - # gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } From 6704e34505553f40a84dae2a28c8ad62f13f3d2a Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:54:01 -0600 Subject: [PATCH 5/8] Triton dequant code cleanups --- dequant_triton.py | 68 +++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index f356671..e4fc2bd 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -40,6 +40,8 @@ def dequantize_q8_0_kernel( N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles OUT_DTYPE: tl.constexpr, ): + out_dtype = OUT_DTYPE.value + # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks pid = tl.program_id(axis=0) @@ -62,7 +64,7 @@ def dequantize_q8_0_kernel( uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) uint16_val = tl.load(uint16_ptr) scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) - scale = scale_fp16.to(tl.float32) + scale = scale_fp16.to(out_dtype) # Load weights (x) q_weights_ptr = block_start_ptr + 2 @@ -70,7 +72,7 @@ def dequantize_q8_0_kernel( q_weights = uint8_weights.to(tl.int8) # Dequantize - dequantized_weights = q_weights.to(OUT_DTYPE) * scale + dequantized_weights = q_weights.to(out_dtype) * scale # Store the result output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE @@ -83,14 +85,11 @@ def dequantize_blocks_Q8_0_triton( type_size: int, dtype=None, ) -> torch.Tensor: - GGUF_BLOCK_SIZE = 32 - GGUF_TYPE_SIZE = 34 - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() n_elements = blocks.numel() - assert n_elements % GGUF_TYPE_SIZE == 0 - n_total_blocks = n_elements // GGUF_TYPE_SIZE + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size dtype = dtype or torch.float32 triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) @@ -98,7 +97,7 @@ def dequantize_blocks_Q8_0_triton( raise TypeError(f"Unsupported output dtype {dtype}") out_tensor = torch.empty( - (n_total_blocks * GGUF_BLOCK_SIZE,), + (n_total_blocks * block_size,), dtype=dtype, device=blocks.device, ) @@ -110,8 +109,8 @@ def grid(meta): blocks, out_tensor, n_total_blocks, - GGUF_BLOCK_SIZE=GGUF_BLOCK_SIZE, - GGUF_TYPE_SIZE=GGUF_TYPE_SIZE, + GGUF_BLOCK_SIZE=block_size, + GGUF_TYPE_SIZE=type_size, OUT_DTYPE=triton_dtype, ) @@ -140,6 +139,7 @@ def dequantize_q4_k_kernel( N_BLOCKS_PER_PROG: tl.constexpr, OUT_DTYPE: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -152,9 +152,9 @@ def dequantize_q4_k_kernel( block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - OUT_DTYPE + out_dtype ) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) @@ -198,27 +198,25 @@ def dequantize_q4_k_kernel( sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) - current_d_a = d * sc_a.to(OUT_DTYPE) - current_dm_a = dmin 
* m_a.to(OUT_DTYPE) - current_d_b = d * sc_b.to(OUT_DTYPE) - current_dm_b = dmin * m_b.to(OUT_DTYPE) + current_d_a = d * sc_a.to(out_dtype) + current_dm_a = dmin * m_a.to(out_dtype) + current_d_b = d * sc_b.to(out_dtype) + current_dm_b = dmin * m_b.to(out_dtype) # Load 32 bytes of quantized data chunk_qs_ptr = qs_start_ptr + k_chunk * 32 qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) - qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) - qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + qs_low = (qs_bytes_chunk & 0x0F).to(out_dtype) + qs_high = (qs_bytes_chunk >> 4).to(out_dtype) dequant_low = current_d_a * qs_low - current_dm_a dequant_high = current_d_b * qs_high - current_dm_b # Store results contiguously output_chunk_ptr = output_start_ptr + k_chunk * 64 - tl.store(output_chunk_ptr + store_offsets, dequant_low.to(OUT_DTYPE)) - tl.store( - output_chunk_ptr + 32 + store_offsets, dequant_high.to(OUT_DTYPE) - ) + tl.store(output_chunk_ptr + store_offsets, dequant_low) + tl.store(output_chunk_ptr + 32 + store_offsets, dequant_high) def dequantize_blocks_Q4_K_triton( @@ -280,6 +278,7 @@ def dequantize_q5_k_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -291,9 +290,9 @@ def dequantize_q5_k_kernel( # Pointers and initial loads block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(tl.float32) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - tl.float32 + out_dtype ) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) @@ -320,8 +319,8 @@ def dequantize_q5_k_kernel( sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) - final_d = d * sc.to(tl.float32) - final_dm = dmin * m.to(tl.float32) + final_d = d * sc.to(out_dtype) + final_dm = dmin * m.to(out_dtype) # 2. Unpack ql (lower 4 bits) for this chunk qs_byte_offset = (chunk_idx // 2) * 32 @@ -334,7 +333,7 @@ def dequantize_q5_k_kernel( # 4. Combine, dequantize, and store q = ql.to(tl.uint8) | (qh.to(tl.uint8) << 4) - dequant_32 = final_d * q.to(tl.float32) - final_dm + dequant_32 = final_d * q.to(out_dtype) - final_dm output_ptr = output_start_ptr + chunk_idx * 32 tl.store(output_ptr + offsets_32, dequant_32) @@ -398,6 +397,7 @@ def dequantize_q6_k_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG offsets_32 = tl.arange(0, 32) @@ -411,9 +411,7 @@ def dequantize_q6_k_kernel( d_ptr = block_start_ptr + 208 scales_ptr = block_start_ptr + 192 - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to( - tl.float32 - ) + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) # Process block in 8 chunks of 32 values for chunk_idx in range(8): @@ -442,15 +440,11 @@ def dequantize_q6_k_kernel( # 4. 
Load and apply correct scales scale_0_ptr = scales_ptr + chunk_idx * 2 scale_1_ptr = scales_ptr + chunk_idx * 2 + 1 - scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to( - tl.float32 - ) - scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to( - tl.float32 - ) + scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to(out_dtype) + scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to(out_dtype) scales_32 = tl.where(mask_16, scale_0, scale_1) - dequant_32 = q_vec_32.to(OUT_DTYPE) * scales_32 + dequant_32 = q_vec_32.to(out_dtype) * scales_32 # 5. Store result output_ptr = output_start_ptr + chunk_idx * 32 From a7dd75f24d7035f67ea905d8736356125ddece9d Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 22:29:35 -0600 Subject: [PATCH 6/8] Use static_range in Triton kernels --- dequant_triton.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index e4fc2bd..c63ad3d 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -52,7 +52,7 @@ def dequantize_q8_0_kernel( weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) # Loop over the N blocks assigned to this program - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i # Boundary check to avoid processing padding blocks @@ -146,7 +146,7 @@ def dequantize_q4_k_kernel( qs_chunk_offsets = tl.arange(0, 32) store_offsets = tl.arange(0, 32) - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE @@ -165,7 +165,7 @@ def dequantize_q4_k_kernel( qs_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE # Process in 4 chunks of 64 values - for k_chunk in range(4): + for k_chunk in tl.static_range(4): # Scale indices for low (a) and high (b) nibbles k_idx_a = 2 * k_chunk k_idx_b = 2 * k_chunk + 1 @@ -284,7 +284,7 @@ def dequantize_q5_k_kernel( offsets_32 = tl.arange(0, 32) - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: # Pointers and initial loads @@ -306,7 +306,7 @@ def dequantize_q5_k_kernel( qh_bytes_all = tl.load(qh_start_ptr + offsets_32) # Process in 8 chunks of 32 values - for chunk_idx in range(8): + for chunk_idx in tl.static_range(8): # 1. Unpack scale and min for this chunk if chunk_idx < 4: sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F @@ -403,7 +403,7 @@ def dequantize_q6_k_kernel( offsets_32 = tl.arange(0, 32) mask_16 = offsets_32 < 16 - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE @@ -414,7 +414,7 @@ def dequantize_q6_k_kernel( d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) # Process block in 8 chunks of 32 values - for chunk_idx in range(8): + for chunk_idx in tl.static_range(8): # 1. 
Calculate ql source data and unpack ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 ql_ptr = block_start_ptr + ql_byte_offset From f11b617d5236c340e84c2ef83755b178a4ea7360 Mon Sep 17 00:00:00 2001 From: blepping Date: Tue, 9 Sep 2025 14:15:47 -0600 Subject: [PATCH 7/8] Refactor Q4_K Triton kernel a bit to reduce code duplication --- dequant_triton.py | 72 +++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index c63ad3d..0b4da7a 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -117,6 +117,28 @@ def grid(meta): return out_tensor.reshape(n_total_blocks, -1) +@triton.jit +def dequantize_q4_k_get_scales_min( + k_idx: int, + d_sc_word: tl.tensor, + m_word: tl.tensor, + m_sc_word: tl.tensor, +) -> tuple[tl.tensor, tl.tensor]: + if k_idx < 4: + k_idx_x8 = k_idx * 8 + d_sc_byte = d_sc_word >> k_idx_x8 + m_byte = m_word >> k_idx_x8 + return d_sc_byte & 0x3F, m_byte & 0x3F + + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = (m_sc_word >> k_prime_x8) & 0xFF + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + return sc, m + + @triton.autotune( configs=[ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), @@ -133,11 +155,11 @@ def dequantize_q4_k_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, + OUT_DTYPE: tl.constexpr, TYPE_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, - OUT_DTYPE: tl.constexpr, + QK_K: tl.constexpr = QK_K, + K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -166,37 +188,17 @@ def dequantize_q4_k_kernel( # Process in 4 chunks of 64 values for k_chunk in tl.static_range(4): - # Scale indices for low (a) and high (b) nibbles - k_idx_a = 2 * k_chunk - k_idx_b = 2 * k_chunk + 1 - - # --- Calculate Scale A (for low nibbles) --- - if k_idx_a < 4: - d_sc_byte_a = (d_sc_word >> (k_idx_a * 8)) & 0xFF - m_byte_a = (m_word >> (k_idx_a * 8)) & 0xFF - sc_a = d_sc_byte_a & 0x3F - m_a = m_byte_a & 0x3F - else: - k_prime_a = k_idx_a - 4 - d_sc_byte_a = (d_sc_word >> (k_prime_a * 8)) & 0xFF - m_byte_a = (m_word >> (k_prime_a * 8)) & 0xFF - m_sc_byte_a = (m_sc_word >> (k_prime_a * 8)) & 0xFF - sc_a = (m_sc_byte_a & 0x0F) | ((d_sc_byte_a >> 2) & 0x30) - m_a = (m_sc_byte_a >> 4) | ((m_byte_a >> 2) & 0x30) - - # --- Calculate Scale B (for high nibbles) --- - if k_idx_b < 4: - d_sc_byte_b = (d_sc_word >> (k_idx_b * 8)) & 0xFF - m_byte_b = (m_word >> (k_idx_b * 8)) & 0xFF - sc_b = d_sc_byte_b & 0x3F - m_b = m_byte_b & 0x3F - else: - k_prime_b = k_idx_b - 4 - d_sc_byte_b = (d_sc_word >> (k_prime_b * 8)) & 0xFF - m_byte_b = (m_word >> (k_prime_b * 8)) & 0xFF - m_sc_byte_b = (m_sc_word >> (k_prime_b * 8)) & 0xFF - sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) - m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) + k_idx = 2 * k_chunk + + # --- Get scale A (for low nibbles) --- + sc_a, m_a = dequantize_q4_k_get_scales_min( + k_idx, d_sc_word, m_word, m_sc_word + ) + + # --- Get scale B (for high nibbles) --- + sc_b, m_b = dequantize_q4_k_get_scales_min( + k_idx + 1, d_sc_word, m_word, m_sc_word + ) current_d_a = d * sc_a.to(out_dtype) current_dm_a = dmin * m_a.to(out_dtype) @@ -247,9 +249,7 @@ def grid(meta): blocks, out_tensor, n_total_blocks, - QK_K=QK_K, TYPE_SIZE=type_size, - K_SCALE_SIZE=K_SCALE_SIZE, OUT_DTYPE=triton_dtype, ) From 
26ffedbafeb7e519c3557de815987870292075a4 Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 10 Sep 2025 18:06:32 -0600 Subject: [PATCH 8/8] Refactor/cleanup Triton support Remove Q8_0 Triton kernel --- dequant_triton.py | 395 ++++++++++++++-------------------------------- 1 file changed, 117 insertions(+), 278 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 0b4da7a..26511e5 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -3,161 +3,67 @@ import triton import triton.language as tl -import gguf +from gguf import GGML_QUANT_SIZES, QK_K, GGMLQuantizationType -TORCH_TO_TRITON_DTYPE_MAP = { +K_SCALE_SIZE = 12 + +TORCH_TO_TRITON_DTYPE_MAP: dict[torch.dtype, tl.dtype] = { torch.float32: tl.float32, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, } -# K Quants # -QK_K = 256 -K_SCALE_SIZE = 12 - - -@triton.autotune( - configs=[ - # Test different numbers of GGUF blocks per program instance - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=8), - ], - key=["n_total_blocks"], # Tune based on the total number of blocks -) -@triton.jit -def dequantize_q8_0_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - GGUF_BLOCK_SIZE: tl.constexpr, - GGUF_TYPE_SIZE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles - OUT_DTYPE: tl.constexpr, -): - out_dtype = OUT_DTYPE.value +_DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), +] - # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks - pid = tl.program_id(axis=0) - - # Starting GGUF block index for this program - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Create offsets for the weights within a GGUF block (0, 1, ..., 31) - weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) - - # Loop over the N blocks assigned to this program - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - - # Boundary check to avoid processing padding blocks - if current_block_idx < n_total_blocks: - # Pointer to the start of the current GGUF block in the input tensor - block_start_ptr = q_tensor_ptr + current_block_idx * GGUF_TYPE_SIZE - - # Load scale (d) - uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) - uint16_val = tl.load(uint16_ptr) - scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) - scale = scale_fp16.to(out_dtype) - - # Load weights (x) - q_weights_ptr = block_start_ptr + 2 - uint8_weights = tl.load(q_weights_ptr + weight_indices) - q_weights = uint8_weights.to(tl.int8) - - # Dequantize - dequantized_weights = q_weights.to(out_dtype) * scale - - # Store the result - output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE - tl.store(output_start_ptr + weight_indices, dequantized_weights) - - -def dequantize_blocks_Q8_0_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - 
dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * block_size,), - dtype=dtype, - device=blocks.device, - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q8_0_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - GGUF_BLOCK_SIZE=block_size, - GGUF_TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) +_AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} @triton.jit -def dequantize_q4_k_get_scales_min( +def dequantize_Q4_K_get_scales_min( k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, m_sc_word: tl.tensor, -) -> tuple[tl.tensor, tl.tensor]: +) -> tl.tuple: if k_idx < 4: k_idx_x8 = k_idx * 8 d_sc_byte = d_sc_word >> k_idx_x8 m_byte = m_word >> k_idx_x8 - return d_sc_byte & 0x3F, m_byte & 0x3F + sc = d_sc_byte & 0x3F + m = m_byte & 0x3F + else: + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = m_sc_word >> k_prime_x8 + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) + return tl.tuple((sc, m)) + - k_prime_x8 = (k_idx - 4) * 8 - d_sc_byte = d_sc_word >> k_prime_x8 - m_byte = m_word >> k_prime_x8 - m_sc_byte = (m_sc_word >> k_prime_x8) & 0xFF - sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) - m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) - return sc, m +# Same as Q4_K +dequantize_Q5_K_get_scales_min = dequantize_Q4_K_get_scales_min @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q4_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q4_k_kernel( +def dequantize_Q4_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, OUT_DTYPE: tl.constexpr, - TYPE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, QK_K: tl.constexpr = QK_K, K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): @@ -165,14 +71,14 @@ def dequantize_q4_k_kernel( pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG - qs_chunk_offsets = tl.arange(0, 32) - store_offsets = tl.arange(0, 32) + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + K_SCALE_SIZE for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( @@ -184,19 +90,19 @@ def dequantize_q4_k_kernel( m_word = tl.load(scales_ptr_u32 + 1) m_sc_word = tl.load(scales_ptr_u32 + 2) - qs_start_ptr = 
block_start_ptr + 4 + K_SCALE_SIZE + qs_start_ptr = block_start_ptr + offsets_scale # Process in 4 chunks of 64 values for k_chunk in tl.static_range(4): k_idx = 2 * k_chunk # --- Get scale A (for low nibbles) --- - sc_a, m_a = dequantize_q4_k_get_scales_min( + sc_a, m_a = dequantize_Q4_K_get_scales_min( k_idx, d_sc_word, m_word, m_sc_word ) # --- Get scale B (for high nibbles) --- - sc_b, m_b = dequantize_q4_k_get_scales_min( + sc_b, m_b = dequantize_Q4_K_get_scales_min( k_idx + 1, d_sc_word, m_word, m_sc_word ) @@ -207,7 +113,7 @@ def dequantize_q4_k_kernel( # Load 32 bytes of quantized data chunk_qs_ptr = qs_start_ptr + k_chunk * 32 - qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) + qs_bytes_chunk = tl.load(chunk_qs_ptr) qs_low = (qs_bytes_chunk & 0x0F).to(out_dtype) qs_high = (qs_bytes_chunk >> 4).to(out_dtype) @@ -217,79 +123,38 @@ def dequantize_q4_k_kernel( # Store results contiguously output_chunk_ptr = output_start_ptr + k_chunk * 64 - tl.store(output_chunk_ptr + store_offsets, dequant_low) - tl.store(output_chunk_ptr + 32 + store_offsets, dequant_high) - - -def dequantize_blocks_Q4_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q4_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) + output_chunk_ptr.store(dequant_low) + (output_chunk_ptr + 32).store(dequant_high) @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q5_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q5_k_kernel( +def dequantize_Q5_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, - TYPE_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + QK_K: tl.constexpr = QK_K, + K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + K_SCALE_SIZE for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: # Pointers and initial loads block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( out_dtype @@ 
-300,31 +165,24 @@ def dequantize_q5_k_kernel( m_word = tl.load(scales_ptr_u32 + 1) m_sc_word = tl.load(scales_ptr_u32 + 2) - qh_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + qh_start_ptr = block_start_ptr + offsets_scale qs_start_ptr = qh_start_ptr + QK_K // 8 - qh_bytes_all = tl.load(qh_start_ptr + offsets_32) + qh_bytes_all = tl.load(qh_start_ptr) # Process in 8 chunks of 32 values for chunk_idx in tl.static_range(8): - # 1. Unpack scale and min for this chunk - if chunk_idx < 4: - sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F - m = ((m_word >> (chunk_idx * 8)) & 0xFF) & 0x3F - else: - k_prime = chunk_idx - 4 - d_sc_byte = (d_sc_word >> (k_prime * 8)) & 0xFF - m_byte = (m_word >> (k_prime * 8)) & 0xFF - m_sc_byte = (m_sc_word >> (k_prime * 8)) & 0xFF - sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) - m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + # # 1. Unpack scale and min for this chunk + sc, m = dequantize_Q5_K_get_scales_min( + chunk_idx, d_sc_word, m_word, m_sc_word + ) final_d = d * sc.to(out_dtype) final_dm = dmin * m.to(out_dtype) # 2. Unpack ql (lower 4 bits) for this chunk qs_byte_offset = (chunk_idx // 2) * 32 - qs_bytes = tl.load(qs_start_ptr + qs_byte_offset + offsets_32) + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset) use_low_nibbles = chunk_idx % 2 == 0 ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) @@ -336,66 +194,22 @@ def dequantize_q5_k_kernel( dequant_32 = final_d * q.to(out_dtype) - final_dm output_ptr = output_start_ptr + chunk_idx * 32 - tl.store(output_ptr + offsets_32, dequant_32) - - -def dequantize_blocks_Q5_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q5_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - QK_K=QK_K, - TYPE_SIZE=type_size, - K_SCALE_SIZE=K_SCALE_SIZE, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) + output_ptr.store(dequant_32) @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q6_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q6_k_kernel( +def dequantize_Q6_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, - TYPE_SIZE: tl.constexpr, OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + QK_K: tl.constexpr = QK_K, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -451,48 +265,73 @@ def dequantize_q6_k_kernel( tl.store(output_ptr + offsets_32, dequant_32) -def dequantize_blocks_Q6_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert 
blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() +def dequantize_blocks_triton_wrapper_factory( + qtype: GGMLQuantizationType, + kernel, +): + ggml_type_size = GGML_QUANT_SIZES[qtype][1] + + def dequantize_blocks_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, + ) -> torch.Tensor: + if blocks.dtype != torch.uint8: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 but got {blocks.dtype}" + ) + if not blocks.is_cuda: + raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") + if not blocks.is_contiguous(): + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must be contiguous" + ) - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size + n_elements = blocks.numel() + if n_elements % ggml_type_size != 0: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must have a number of elements ({n_elements}) divisible by the type size {ggml_type_size}" + ) + n_total_blocks = n_elements // ggml_type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError( + f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" + ) - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) + def grid(meta: dict): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + TYPE_SIZE=ggml_type_size, + OUT_DTYPE=triton_dtype, + ) - dequantize_q6_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - QK_K=QK_K, - TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) + return out_tensor.reshape(n_total_blocks, -1) - return out_tensor.reshape(n_total_blocks, -1) + return dequantize_blocks_triton dequantize_functions = { - # Q8_0 simply seems than the PyTorch implementation. - # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, - gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, - gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, - gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, + GGMLQuantizationType.Q4_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q4_K, dequantize_Q4_K_kernel + ), + GGMLQuantizationType.Q5_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q5_K, dequantize_Q5_K_kernel + ), + GGMLQuantizationType.Q6_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q6_K, dequantize_Q6_K_kernel + ), } __all__ = ("dequantize_functions",)
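
A quick way to sanity-check the new kernels locally is to run the Triton and PyTorch dequantizers over the same raw block data and compare the outputs. The sketch below is not part of the patch series: the flat dequant / dequant_triton import paths are placeholders (inside ComfyUI these modules live in the custom node package and use relative imports), and it assumes a CUDA device with a working Triton install. Random bytes exercise the bit unpacking but can decode to inf/NaN fp16 scales, so blocks taken from a real .gguf file give a cleaner comparison.

    import torch
    import gguf

    # Placeholder import paths for standalone testing; adjust to the actual package layout.
    from dequant import dequantize_functions as torch_dequantize_functions
    from dequant_triton import dequantize_functions as triton_dequantize_functions

    qtype = gguf.GGMLQuantizationType.Q4_K
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]

    # 1024 GGUF blocks of raw quantized bytes (uint8, contiguous, on the GPU).
    blocks = torch.randint(0, 256, (1024, type_size), dtype=torch.uint8, device="cuda")

    ref = torch_dequantize_functions[qtype](blocks, block_size, type_size, dtype=torch.float32)
    tri = triton_dequantize_functions[qtype](blocks, block_size, type_size, dtype=torch.float32)

    # Both paths yield QK_K (256) dequantized values per block; flatten before comparing.
    ref = ref.reshape(blocks.shape[0], -1).float()
    tri = tri.reshape(blocks.shape[0], -1).float()
    print("max abs diff:", (ref - tri).abs().max().item())

With ALLOW_TRITON left at its default, dequantize() in dequant.py picks a Triton kernel automatically whenever the tensor's qtype has an entry in triton_dequantize_functions and falls back to the existing PyTorch implementation otherwise, so the only user-facing switch is the GGUFTritonToggle node added in nodes.py.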