
Conversation

@blepping (Contributor)

Continuation of #331. Sorry for making a new pull, I couldn't reopen the existing one since it had been force pushed.

Wrapping the existing dequant functions with torch.compile in my benchmark tool brought the results to parity with Triton: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2

However, I was not able to replicate that in the real world so the Triton kernels are back!

This includes a generic approach to passing configuration parameters around, including overriding the dequant functions (potentially with compiled or Triton versions). There is an optimize parameter in the advanced loader that lets you choose between none, compile and triton. Enabling Triton is noticeably faster for me; using torch.compile is actually slower than doing nothing.
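For a rough idea of what the compile option does conceptually, here is an illustrative sketch (not the actual wiring in this PR; the dequantize_functions dict of per-qtype PyTorch dequant functions and its import path are assumptions):

import torch
from dequant import dequantize_functions  # per-qtype PyTorch dequant functions (assumed import path)

# "compile" mode: wrap each existing dequant function with torch.compile so Inductor
# generates kernels for it, instead of shipping hand-written Triton kernels.
compiled_dequantize_functions = {
    qtype: torch.compile(fn, dynamic=True)
    for qtype, fn in dequantize_functions.items()
}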

This also fixes an issue with the existing approach to setting configuration like dequant_dtype:

        ops = GGMLOps()

        if dequant_dtype in ("default", None):
            ops.Linear.dequant_dtype = None
        elif dequant_dtype in ["target"]:
            ops.Linear.dequant_dtype = dequant_dtype
        else:
            ops.Linear.dequant_dtype = getattr(torch, dequant_dtype)

This is not kosher because while ops is an instance of GGMLOps, attributes like ops.Linear are not instance specific, so setting it here changes that attribute on the shared class. This means that if there are multiple loaders with different settings, you will get the configuration of whichever loader overwrote that value last.
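For illustration, a minimal sketch of the per-instance-subclass approach this PR uses instead (class and attribute names below are simplified stand-ins, not the real ops module):

import torch

class GGMLOps:
    class Linear:
        dequant_dtype = None  # class attribute shared by every GGMLOps() instance

def make_ops(dequant_dtype=None):
    ops = GGMLOps()
    # Subclass per instance so the shared GGMLOps.Linear class is never mutated.
    ops.Linear = type("Linear", (GGMLOps.Linear,), {"dequant_dtype": dequant_dtype})
    return ops

a = make_ops(torch.float16)
b = make_ops(None)
assert a.Linear.dequant_dtype is torch.float16
assert b.Linear.dequant_dtype is None
assert GGMLOps.Linear.dequant_dtype is None  # the global default is untouched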

@blepping (Contributor, Author) commented Sep 14, 2025

[Screenshots: test feedback from Discord showing the Triton performance benefit]

Some helpful test feedback I got on Discord. The Triton performance benefit is about what I've observed in my own testing. Interesting to see that it's pretty constant even across Nvidia and AMD.

@StrongerXi (Contributor)

Hmm I think if you use the comfyui builtin or Kijai's torch compile nodes, you should get the same, or maybe even better perf than the triton kernels. What torch.compile does under the hood is generating optimized triton kernels automatically.

@blepping (Contributor, Author)

Hmm I think if you use the comfyui builtin or Kijai's torch compile nodes, you should get the same, or maybe even better perf than the triton kernels.

This pull actually already includes an option to compile the dequantization functions. You can try it if you want; it didn't affect the performance in my testing or in the reports I've gotten.

It's harder to compare performance when compiling the whole model rather than just the dequant functions, but compiling the whole model also isn't always desirable (or possible in some cases).

@patientx

Hmm I think if you use the comfyui builtin or Kijai's torch compile nodes, you should get the same, or maybe even better perf than the triton kernels. What torch.compile does under the hood is generating optimized triton kernels automatically.

With wan this doesn't make a difference, tried with q4km, BUT as you can see from my tests it certainly works. BTW, I am also using the standard torch.compile and sage-attention on top of this, so this is just about gguf itself and it works.

Refactoring/cleanups
@patientx

The q3 kernel also shows the same amount of speed improvement, around 12% at 2048x2048 with hunyuan-image.

@blepping (Contributor, Author)

Q8_0 isn't worth it, but as of now there are Triton kernels for all the K-quants and legacy quants.

The following tests were done using this utility I made: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2

Using it requires running from the repo directory with this branch checked out and a venv with the necessary dependencies (PyTorch, Triton, etc) activated.

I used this command (Linux/Unix, presumably would work with WSL) to run it over all the available quants and dtypes:

for dt in float32 float16 bfloat16; do for qt in q4_0 q4_1 q5_0 q5_1 q2_k q3_k q4_k q5_k q6_k; do python -O test_triton_dequant.py --no-validate --qtype=$qt --dtype=$dt; done; done

PyTorch vs Triton

Comparing performance of the Triton kernels vs the existing PyTorch dequant functions. 2.0 in a column would mean the Triton version was two times faster. These results are from benchmarking the dequant functions in isolation so you won't see the same speedup running an actual model.

For reference, Q4_K is ~3.5x here; for moderate image sizes with models like Flux or Qwen, the real-world performance benefit is more like 1.2x. The Q8_0 kernel which wasn't worth using was around 1.4x here. I will have to do some real testing with the quants that seem a bit borderline to find out if having them enabled is actually worth it (Q4_0, Q2_K at non-32-bit, etc).

qtype  float32  float16  bfloat16
Q4_0   2.39     2.41     2.37
Q4_1   3.07     2.42     2.39
Q5_0   5.55     5.75     5.67
Q5_1   6.14     5.72     5.45
Q2_K   3.61     2.52     2.57
Q3_K   3.47     3.29     3.17
Q4_K   3.54     3.91     3.75
Q5_K   4.64     4.61     4.67
Q6_K   3.82     4.13     4.29
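For anyone curious how ratios like these are measured, here is a rough sketch of GPU timing with CUDA events (this is not the linked gist; torch_dequant and triton_dequant in the usage comment are placeholder names):

import torch

def bench_ms(fn, *args, warmup=5, iters=50):
    # Average milliseconds per call, measured with CUDA events after a warmup.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# speedup = bench_ms(torch_dequant, blocks) / bench_ms(triton_dequant, blocks)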

PyTorch vs PyTorch float32

Comparing the PyTorch dequant functions with float16 and bfloat16 vs float32. For example 1.27 would mean dequantizing for that dtype was 1.27x the performance of dequantizing at float32. Note: No Triton here.

qtype  float16  bfloat16
Q4_0   1.27     1.09
Q4_1   1.44     1.38
Q5_0   1.06     1.01
Q5_1   1.15     1.17
Q2_K   1.48     1.49
Q3_K   1.07     1.07
Q4_K   1.40     1.40
Q5_K   1.22     1.23
Q6_K   1.12     1.07

PyTorch Q4_0 vs PyTorch other quants

This is also between PyTorch implementations, no Triton. It compares dequantization speed vs Q4_0 (Q8_0 may be faster but I don't have it implemented), so Q4_0 is taken as 1.0x and the other quants are compared to it. Comparison done with float32.

qtype  Q4_0 factor
Q4_0   1.0000
Q4_1   1.3485
Q5_0   2.8110
Q5_1   3.1835
Q2_K   1.3208
Q3_K   1.8160
Q4_K   1.3653
Q5_K   1.9745
Q6_K   1.7487

These tests were done before my changes to eliminate the slow to_uint32 function that Q5_0 and Q5_1 use. That change improved the PyTorch Q5_0 and Q5_1 implementations by about 8-9% for float32.

@blepping (Contributor, Author)

Quick'n'dirty testing with the legacy quants and Chroma Radiance (float16 output):

16x16 generations:

q4_0
normal: 17/18 [00:21<00:01,  1.28s/it
triton: 17/18 [00:11<00:00,  1.49it/s

q4_1
normal: 17/18 [00:22<00:01,  1.33s/it
triton: 17/18 [00:11<00:00,  1.47it/s

q5_0
normal: 17/18 [00:53<00:03,  3.15s/it
triton: 17/18 [00:12<00:00,  1.31it/s

q5_1
normal: 17/18 [00:54<00:03,  3.23s/it
triton: 17/18 [00:12<00:00,  1.33it/s

1152x1344 generation:

q5_1
normal: 15/18 [03:01<00:36, 12.11s/it (OOM)
triton: 17/18 [01:35<00:05,  5.61s/it

Looking pretty good, especially the 5-bit quants. I also tried testing a full-sized generation with Q5_1 (bottom results). I ran out of memory twice in a row testing the non-Triton version; VRAM was around 96% used before it failed. With the Triton kernel it stayed around 80%, so at least for Q5_1 it seems like Triton uses less memory as well as being faster.

@blepping (Contributor, Author)

@city96 This needs some cleanup, but if I get it presentable would you be interested in merging these changes?

Show Triton status in tooltip for optimize param in advanced loader
Tweak autotune configs a bit
The return of the Q8_0 Triton kernel, matches or slightly exceeds PT performance now
@city96 (Owner) left a comment

@blepping I'll have to test this properly sometime, just haven't really had time lately (or access to usable hardware most of the time lol, I'm writing this review on my laptop so I can't even run the code).

Most of the comments are just random nitpicks or stuff I noticed looking over the code; again, I can't run it and haven't tested the control flow so half of them could be wrong.

Also, for the final file layout, I'd probably make a new folder named dequant, and rename dequant.py to dequant/kernel_torch.py and the triton one to dequant/kernel_triton.py or something similar. Then we can have the logic for picking either in the dequant folder's init, so it picks the correct one when you import from it. Though that might look weird on the git diff so idk.

import torch
from tqdm import tqdm

HAVE_BFLOAT16=hasattr(torch, "bfloat16")
@city96 (Owner)

Wouldn't we need to check if the current device actually supports it? i.e. we'd want to test this on RTX 20XX, Volta and Pascal. I can test on volta+pascal sometime.

@blepping (Contributor, Author)

I'll make the bfloat16 changes a separate PR.

@@ -1,38 +1,72 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
from typing import Callable, Literal, NamedTuple, Optional, Union
@city96 (Owner)

The repo generally doesn't have typing anywhere, so I'd say just remove it unless we plan to add it in everywhere.

@blepping (Contributor, Author)

Well, typing stuff is generally considered good and if you'd be interested in adding it everywhere I can certainly do that. Also having type annotations in some places means it would be less work if you wanted it (everywhere) later on. I can just remove all the type annotations if that's what you really want though.


def dequantize(data, qtype, oshape, dtype=None):
    if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16:
        return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype)
@city96 (Owner)

Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16

(also, I don't super like the if/elif/else being flattened, the diff is harder to read with all the small changes)

@blepping (Contributor, Author)

I will remove these changes and make it a separate PR. I did test though and it seemed to work fine with a model that had a bunch of BF16 tensors so I am pretty sure it couldn't be a different layout.

also, I don't super like the if/elif/else being flattened

You mean you don't like:

def blah():
  if condition:
    return 1
  elif other_condition:
    return 2
  else:
    return 3

Compared to:

def blah():
  if condition:
    return 1
  if other_condition:
    return 2
  return 3

Linters will complain about the former version because the elif and else are redundant and it's usually considered a "code smell" but I can do it that way if you want. (This particular part isn't going to be relevant for the reviewable pull but I can make sure I follow your style preference in other places.)

@blepping (Contributor, Author)

Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16

>>> import numpy  as np
>>> import torch
>>> def quantize_blocks(blocks: np.ndarray) -> np.ndarray: # From gguf-py
...     n = blocks.view(np.uint32)
...     # force nan to quiet
...     n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
...     # round to nearest even
...     n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
...     return n.astype(np.uint16).view(np.uint8)
...
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x7f841bbf9650>
>>> x = torch.randn(1000, dtype=torch.bfloat16)
>>> xnp = x.to(dtype=torch.float32).numpy()
>>> xqnp = quantize_blocks(xnp)
>>> xq = torch.tensor(xqnp)
>>> xq.dtype
torch.uint8
>>> xdq_manual = (xq.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
>>> xdq_view = xq.view(dtype=torch.bfloat16).to(torch.float32)
>>> torch.equal(xdq_manual, xdq_view)
True

TL;DR: It's the same layout and just viewing is safe as long as Torch has bf16 support.

I'm not sure the GPU even needs to have BF16 support here since it's just a storage type and there's no math involved. Of course if the user has the compute dtype set to bf16 and their GPU doesn't support it then they're going to run into issues.

If you have access to a GPU without bf16 support, a simple test would just be to temporarily manifest a bf16 tensor and see if it causes problems:

>>> import torch
>>> torch.arange(100, dtype=torch.uint8).view(torch.bfloat16).view(torch.uint8)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype=torch.uint8)

@@ -0,0 +1,723 @@
from __future__ import annotations
@city96 (Owner)

Is this required here?

@blepping (Contributor, Author)

It makes the blah: some_type | None pattern usable with all the Python versions ComfyUI supports. Otherwise that has to be blah: Optional[type]. So it would be irrelevant if you don't want type annotations at all (or if you prefer the Optional pattern for some reason).
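A tiny illustration of what that future import buys (the function below is hypothetical, not from the repo): with PEP 563, annotations are stored as strings and never evaluated, so the X | None syntax parses fine even on Python versions older than 3.10.

from __future__ import annotations  # annotations become lazily-evaluated strings (PEP 563)

# Without the future import, "int | None" raises TypeError at definition time on
# Python 3.9 and older, because the | operator on types only exists in 3.10+.
def lookup(table: dict, key: str, default: int | None = None) -> int | None:
    return table.get(key, default)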

Comment on lines +237 to +251

_MODULE_NAMES = ("Linear", "Conv2d", "Embedding", "LayerNorm", "GroupNorm")

def __init__(self, *args, ggufconfig: Optional[GGUFConfig]=None, **kwargs):
    super().__init__(*args, **kwargs)
    linear_config = ggufconfig or DEFAULT_CONFIG
    # Ignore patch_dtype and dequant_dtype for non-Linear layers.
    other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None)
    self.ggufconfig = linear_config
    for module_name in self._MODULE_NAMES:
        module = getattr(self.__class__, module_name)
        curr_config = linear_config if module_name == "Linear" else other_config
        setattr(self, module_name, type(module_name, (module,), {"ggufconfig": curr_config}))


@city96 (Owner)

I feel like there has to be a better way to pass the config to the modules but I just woke up so I can't think of anything...

(also, ggufconfig -> gguf_config if we're being consistent)

@blepping (Contributor, Author)

I will change the name to gguf_config.

This was the best way I could think of to make sure it is present even if something else tries to clone the class or whatever, since doing it this way makes sure it's there as a default attribute. It could be done as a parameter to __init__ for each module, but that would require more changes and I'm not positive there isn't code somewhere that might try to recreate/rebuild them and wouldn't know to pass parameters like that.

    HAVE_TRITON=True
except Exception as exc:
    HAVE_TRITON=False
    print(f"\nGGUF: Failed to enable Triton: {exc}")
@city96 (Owner)

logging.warning instead of print, that's what comfy uses. I should probably add my linter config to the repo.

@blepping (Contributor, Author)

These debug print statements will be removed before I mark this PR as ready. If you actually want some Triton status information logged to the console, let me know and I can add it.

triton_dequantize_functions={}


TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16))
@city96 (Owner)

frozenset? we can probably leave it as a normal set

@blepping (Contributor, Author)

I can change it if you want. Usually it's better to use immutable types for stuff that isn't supposed to be mutated. It's a little easier to reason about (since you know the contents won't randomly change) and it's usually a little more efficient as well.
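As a small illustration of the difference, using the constant from the diff above (the behaviour shown in the comments is standard Python, nothing repo-specific):

import gguf

# Immutable: membership tests work the same as a set, but nothing can add or remove entries later.
TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16))

print(gguf.GGMLQuantizationType.F16 in TORCH_COMPATIBLE_QTYPES)  # True
# TORCH_COMPATIBLE_QTYPES.add(gguf.GGMLQuantizationType.Q4_0)    # AttributeError: frozenset has no .add()
# A frozenset is also hashable, so it could be used as a dict key or cached safely if ever needed.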

Comment on lines 73 to 80
def dequantize_kernel(
    q_tensor_ptr,
    out_tensor_ptr,
    n_total_blocks,
    DTYPE: tl.constexpr,
    N_BLOCKS_PER_PROG: tl.constexpr,
    CTX: tl.constexpr,
) -> None:
@city96 (Owner)

I usually put all the args on one line, though that's mostly just a nitpick.

@blepping (Contributor, Author)

I will change it to be on one line before I mark this as ready for review. (Let's leave this open so I don't forget.)

ml = dmin * (scale_byte >> 4).to(DTYPE)

# --- Map the 16 output elements to their source data ---
# This logic correctly models the Python reshape from a flat 256-element array.
@city96 (Owner)

I don't like the numbered llm slop comments lol, they don't really do anything other than making the file longer to scroll through.

@blepping (Contributor, Author)

Some of the comments are about the structure of what it's dequantizing so might be useful (if they're accurate, I haven't read through/verified them). There's also stuff like "I really did it right this time!" which isn't serving a purpose. I can try to curate them or I can just remove them all entirely if you prefer.

@blepping (Contributor, Author)

@city96

Sorry, I might have wasted your time since I wasn't clear. I wasn't expecting you to go through the changes in-depth at this point; I was mostly just asking whether you'd be interested in the general idea, i.e. merging in Triton stuff at all, before I put more time into cleaning it up. There's quite a bit of stuff in there that's just for testing, like the print statements, etc. There are also some other changes (like the bfloat16 stuff) which are only there because this is the tree I'm testing my changes in. Again, I apologize for not being clear enough about that.

The ggufconfig (or soon to be gguf_config) config-passing stuff and the Triton kernels kind of have to be in the same pull because there isn't really a good way to pass configuration like that through the current system (and the current approach also has the issue I mentioned in the top comment). Other changes like (possibly) improving bfloat16 handling and the changes to the q5_0/q5_1 PyTorch dequant functions can be in a different pull.

Also, for the final file layout, I'd probably make a new folder named dequant, and rename dequant.py to dequant/kernel_torch.py and the triton one to dequant/kernel_triton.py or something similar. Then we can have the logic for picking either in the dequant folder's init, so it picks the correct one when you import from it. Though that might look weird on the git diff so idk.

The current system allows you to override individual kernels/dequant functions (or even the main one) via configuration. So you could do stuff like Triton only for Q6_K and the existing PT functions for everything else and that kind of thing. We definitely could put the actual dequantization handling functions in separate files in a dequant sub-module though and have the __init__.py or whatever export stuff like the dicts of dequant functions and common types/functions. Does that idea sound okay?
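A rough sketch of how that dequant package could look (the module names kernel_torch/kernel_triton follow the suggestion above, and the helper function is an assumption, not code from this PR):

# dequant/__init__.py (sketch)
from .kernel_torch import dequantize_functions as torch_dequantize_functions

try:
    from .kernel_triton import dequantize_functions as triton_dequantize_functions
    HAVE_TRITON = True
except ImportError:
    triton_dequantize_functions = {}
    HAVE_TRITON = False

def get_dequantize_function(qtype, optimize="none"):
    # Prefer the Triton kernel only when requested and actually available for this qtype,
    # otherwise fall back to the existing PyTorch implementation.
    if optimize == "triton" and HAVE_TRITON and qtype in triton_dequantize_functions:
        return triton_dequantize_functions[qtype]
    return torch_dequantize_functions[qtype]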


I'll respond to all your reviews, but it's going to be a lot of stuff to deal with and I know you're busy (also some of them aren't really that relevant), so no pressure to reply; I'll try to mark the ones that can just be closed.

@blepping (Contributor, Author) commented Nov 5, 2025

Sorry for the lack of updates; there's some not so great stuff going on in my personal life at the moment. It's not abandoned, but I haven't been able to work on any personal projects for a while now and I'm not sure when I will be able to complete this. I intend to, but it will probably be at least several weeks before I can even start on it again. Still no doubts about the value; I wouldn't use GGUF without these changes anymore since it's such a significant performance/memory improvement.

You can close this if having a stale open pull is bugging you; I don't think I'll be able to reopen it though, due to it having force-pushed changes.

@city96 (Owner) commented Nov 8, 2025

No worries, I'm also not super active these days, partly due to similar life reasons. I could try and take this on but probably wouldn't get too far (though if I do I'll definitely add you as a co-author and tag you).

I don't mind the PR being open, we can leave it as a draft like this for now, though I do still think it'd be cool to have alternate kernels like this as an option, assuming there are no edge cases we run into.

On that, sorry for never replying to your comments in the review. I'm somewhat easier to reach on discord since I can check that on my phone so if you do need more immediate feedback for something, do feel free to add me (name is the same as everywhere).

I did end up testing it on pascal/volta and at least based on this it does work (I didn't have comfy or any models so kinda had to improvise lol)

import torch
import gguf  # gguf-py, for the quant type enum and block sizes used below
from dequant_triton import dequantize_functions

qtype = gguf.GGMLQuantizationType.Q4_0
fn = dequantize_functions[qtype]
x = torch.ones(1024, 1024*18, device="cuda:0", dtype=torch.uint8)
_ = fn(x, *gguf.GGML_QUANT_SIZES[qtype])

@StrongerXi (Contributor)

Sorry, a bit late and I might be missing something, but my original thought was: if you torch.compile the dequantize_tensor function here, you'd essentially get the same or better perf than the manually written Triton kernels, so why not just do that and avoid code debt?

weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

Wrapping the existing dequant functions with torch.compile in my benchmark tool brought the results to parity with Triton: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2

However, I was not able to replicate that in the real world so the Triton kernels are back!

The fact that you couldn't replicate the perf in real-world models sounds strange; would you mind elaborating on what you tried?
