
Conversation

@blepping
Contributor

@blepping blepping commented Sep 8, 2025

I'm opening this for discussion; this pull isn't really intended to be merged in its current state. If there's interest, I'll probably refactor things quite a bit.

Right now, this adds kernels for dequantizing Q4_K and Q6_K. Note: this is vibe-coded with Gemini, so... magical LLM code. I did tests comparing the dequantization results with the official implementations and they seem to match. There's also a Q8_0 kernel, but it performs badly in the real world, so it's disabled.

Results of some tests with Chroma Radiance (based on Chroma, which is based on Flux Schnell; slightly smaller in terms of parameters):

Normal 16x16 Q4_K:
49/50 [00:34<00:00,  1.43it/s] Prompt executed in 34.38 seconds
49/50 [00:34<00:00,  1.44it/s] Prompt executed in 34.22 seconds

Triton 16x16 Q4_K:
49/50 [00:30<00:00,  1.60it/s] Prompt executed in 30.74 seconds
49/50 [00:30<00:00,  1.59it/s] Prompt executed in 30.83 seconds

Normal 16x16 Q6_K:
49/50 [00:48<00:00,  1.01it/s] Prompt executed in 48.78 seconds
49/50 [00:48<00:00,  1.01it/s] Prompt executed in 48.67 seconds

Triton 16x16 Q6_K:
49/50 [00:30<00:00,  1.58it/s] Prompt executed in 31.00 seconds
49/50 [00:30<00:00,  1.59it/s] Prompt executed in 30.93 seconds

Normal 1024x1024 Q6_K:
24/25 [04:13<00:10, 10.54s/it] Prompt executed in 254.80 seconds (DCT scaling)
24/25 [04:10<00:10, 10.43s/it] Prompt executed in 250.87 seconds (DCT scaling)

Triton 1024x1024 Q6_K:
24/25 [02:47<00:06,  6.97s/it] Prompt executed in 170.63 seconds (DCT scaling)
24/25 [02:46<00:06,  6.95s/it] Prompt executed in 167.27 seconds (DCT scaling)

Normal 16x16 Q8_0:
49/50 [00:13<00:00,  3.52it/s] Prompt executed in 14.89 seconds
49/50 [00:13<00:00,  3.53it/s] Prompt executed in 14.00 seconds

Triton 16x16 Q8_0:
49/50 [00:31<00:00,  1.57it/s] Prompt executed in 31.20 seconds

There's a small but noticeable improvement for Q4_K (the model I was testing had some Q5_K tensors, which aren't implemented yet), and a very noticeable improvement for Q6_K, around 1.5x the performance.

Most of those tests are at 16x16 resolution, so basically the ideal case for maximizing quantization overhead and exaggerating dequantization performance changes. Interestingly, real generations with Q6_K retained the performance benefit.

Gemini needs a lot of tries and hand-holding for these, but it does seem like it can get there in the end.

Forgot to mention: there's a simple Triton toggle (GGUF) node that takes a model input. This is not a real model patch; it manipulates a global variable in dequant.py, so the node has to run to apply its settings, and deleting/muting/bypassing it to turn Triton off won't work. Triton defaults to enabled here if there's support.
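Roughly, the node looks something like this simplified sketch (the actual class and flag names in the pull may differ; TRITON_ENABLED is just a placeholder name here):

```python
# Simplified sketch only; names do not necessarily match the real node.
from . import dequant  # module holding the global Triton flag


class GGUFTritonToggle:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "model": ("MODEL",),
            "use_triton": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "go"
    CATEGORY = "gguf"

    def go(self, model, use_triton):
        # Not a real model patch: this mutates global state, so the node
        # has to actually execute for the setting to take effect.
        dequant.TRITON_ENABLED = use_triton
        return (model,)
```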

@blepping
Contributor Author

blepping commented Sep 8, 2025

I don't have a pure Q5_K model to test with at the moment but my Q4_K model has some Q5_K tensors. Adding a Q5_K kernel resulted in a noticeable improvement:

Normal:
49/50 [00:34<00:00,  1.43it/s] Prompt executed in 34.38 seconds
49/50 [00:34<00:00,  1.44it/s] Prompt executed in 34.22 seconds

Triton with only Q4_K kernel:
49/50 [00:30<00:00,  1.60it/s] Prompt executed in 30.62 seconds
49/50 [00:30<00:00,  1.59it/s] Prompt executed in 30.85 seconds

Triton with Q5_K kernel:
49/50 [00:28<00:00,  1.74it/s] Prompt executed in 28.21 seconds
49/50 [00:28<00:00,  1.74it/s] Prompt executed in 28.19 seconds

edit: Did some testing with Qwen Image Q4_K:

Normal 768x688:
15/15 [01:27<00:00,  5.82s/it]

Triton 768x688:
15/15 [00:57<00:00,  3.83s/it]

Normal 1328x1328:
15/15 [02:46<00:00, 11.10s/it]

Triton 1328x1328:
15/15 [02:13<00:00,  8.87s/it]

@Ph0rk0z

Ph0rk0z commented Sep 8, 2025

And here I am with only Q8 Chromas. Until you add a LoRA, Q8 was almost as fast as FP8. Possible to speed up the TE as well?

@blepping
Contributor Author

blepping commented Sep 8, 2025

And here I am with only Q8 chromas.

There's a Q8_0 kernel in there but the performance was pretty bad in my testing. If you want to try it you can uncomment the line at the bottom of dequant_triton.py to enable it. The format is so simple that it doesn't really seem like the overhead of stuff like launching Triton kernels is worth it in that case.

Until you add lora, Q8 was almost as fast as FP8.

Try using the advanced node and setting the types to the same type as the model you're running. (You'd want to set it to "target", but there's a bug that prevents that from working, unless you want to apply my other pull.) Setting patch on device might hurt performance too.

Possible to speed up the TE as well?

Dequantization doesn't care about the type of model, so as long as it's something ComfyUI-GGUF supports and you have the model in GGUF format, you can try it. Quantized models are typically slower, though; the benefit is that more memory may be available, so ComfyUI doesn't have to shuffle layers back and forth between RAM and the GPU. In other words, if you can already fit the whole model in memory and run it, the non-quantized version is going to be faster. Of course, there are other considerations like disk space. You also only do a single pass through the model each time you encode the prompt, compared to sampling where there are repeated passes, so the benefit is generally going to be less noticeable.

@Ph0rk0z

Ph0rk0z commented Sep 8, 2025

Now that I look at the code, since the toggle takes a model input, I don't think it can be included. I will definitely see what happens to lora after playing with the advanced node.

With LLMs, the quantized models are faster than the bare ones. I'm not sure how much of that is custom kernels and how much is the lower memory footprint. On CPU at least, the bigger the model, the higher the memory bandwidth requirements. It's mirrored for me on distributed Wan if I use a BF16/FP16 model over an FP8 one, despite the activations probably being cast the same and having plenty of memory for both.

@blepping
Contributor Author

blepping commented Sep 8, 2025

Now that I look at the code, since the toggle takes a model input, I don't think it can be included.

As I mentioned in the initial comment, this is not a normal model patch, it just sets a global flag. The Triton kernels are also enabled by default (globally) if you have Triton support.

The node takes an input just so there is something that requires its output (otherwise it wouldn't run). Note that this is a temporary testing thing; if this actually turns into a real pull and gets merged, there will be a less clunky way to configure whether those kernels get used or not.

@city96
Owner

city96 commented Sep 9, 2025

This definitely looks interesting. It almost feels like we could try to include multiple different dequant kernels and just run a small throughput benchmark with fake data to pick the fastest one for each quantization level, with the default being the PyTorch ones.
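Something like this rough sketch of the selection logic (names are made up, and the fake data would need to match each quant's block layout):

```python
import torch


def pick_fastest(candidates, fake_blocks, iters: int = 20):
    # Rough sketch: time each dequant implementation on fake data and
    # return the fastest; the PyTorch version would just be one of the
    # candidates, so it remains the fallback.
    best_fn, best_ms = None, float("inf")
    for fn in candidates:
        fn(fake_blocks)  # warmup / trigger any compilation
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(fake_blocks)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters
        if ms < best_ms:
            best_fn, best_ms = fn, ms
    return best_fn
```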

You could use add_object_patch or whatever the function was called to patch the value, meaning it should be possible to undo that as long as we pass it along properly and make it depend on the model instead of a global variable (kinda like dequant dtype or patch dtype I guess).

I don't have much experience with triton, so I can't comment on the actual dequant kernels. I still think an interesting thing to test would be Isotr0py/ggml-libtorch, which if I understood it right has the vLLM kernels for both dequantizing and direct matmul as well.

@Ph0rk0z

Ph0rk0z commented Sep 9, 2025

Custom matmul kernels instead of whatever PyTorch does would probably crank. Only Nunchaku and some dead projects have gone this route.

@blepping
Contributor Author

@city96

This definitely looks interesting, almost feels like we could try and include multiple different dequant kernels and just run a small throughput benchmark with fake data to pick the fastest one for each quantization level, with the default being the pytorch ones.

It's not impossible to do this, but I think it might be tricky, and I'm not really convinced it's necessary, since at least for my GPU (4060 Ti) everything except Q8_0 is a significant performance benefit. It also makes sense why Q8_0 wouldn't be worthwhile: it's super simple compared to the K-quants, so there isn't really anything to optimize, while something like Q6_K requires a ton of complicated bit twiddling.
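For comparison, this is roughly everything a Q8_0 dequantizer has to do (a simplified PyTorch sketch assuming the standard GGML Q8_0 block layout of a 2-byte fp16 scale followed by 32 int8 values; the actual function in dequant.py differs in the details):

```python
import torch


def dequant_q8_0_sketch(blocks: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    # blocks: uint8 tensor of shape (n_blocks, 34); each block is a 2-byte
    # fp16 scale followed by 32 int8 quantized values.
    d = blocks[:, :2].contiguous().view(torch.float16).to(dtype)  # (n_blocks, 1)
    q = blocks[:, 2:].contiguous().view(torch.int8).to(dtype)     # (n_blocks, 32)
    return d * q  # one multiply per value: not much for a custom kernel to win back
```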

You could use add_object_patch or whatever the function was called to patch the value, meaning it should be possible to undo that as long as we pass it along properly and make it depend on the model instead of a global variable (kinda like dequant dtype or patch dtype I guess).

Hmm, it doesn't seem like anything that complicated would be necessary, since you already have a way to pass options like the dequant type down to the dequantization code. How does something like this sound as an approach to configuring it:

Add a string parameter to the advanced loader that takes a comma- (or space-, if you prefer) separated list of quants to enable Triton kernels for. If empty, none are enabled. We could make it a little easier to use by letting * mean "add everything" and prefixing an item with - (minus) to remove it, so you could do something like *, -q4_k to enable everything except Q4_K. The default could be blank; later on, if you wanted to enable it by default when Triton is available, you could just change that to *.

At the point the node is created we'd also know whether Triton is available, so the tooltip for the widget could list the supported quants and indicate whether Triton is available. Would something like that work?
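To be concrete, parsing that string would be something trivial like this (hypothetical helper, just to illustrate the proposed syntax):

```python
# Hypothetical parser for the proposed widget value, e.g. "*, -q4_k".
SUPPORTED = {"q4_k", "q5_k", "q6_k"}  # whichever Triton kernels actually exist


def parse_triton_quants(value: str) -> set:
    enabled = set()
    for item in value.replace(",", " ").split():
        item = item.lower()
        if item == "*":
            enabled |= SUPPORTED
        elif item.startswith("-"):
            enabled.discard(item[1:])
        elif item in SUPPORTED:
            enabled.add(item)
    return enabled


# parse_triton_quants("*, -q4_k") -> {"q5_k", "q6_k"}
```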

I don't have much experience with triton, so I can't comment on the actual dequant kernels.

Me neither, I just wanted to see if this was possible initially. As a Python programmer, I'm sure you can look at that stuff like I can and understand some of it, since it's basically a restricted version of Python. For example, you can't do things like break out of loops or use global variables in a Triton kernel (that's why we're passing stuff like QK_K as an argument).
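To give an idea of what that looks like in practice, here's a trivial kernel (nothing to do with dequantization) where the block size comes in as a tl.constexpr argument rather than a Python global, which is the same pattern the dequant kernels use for QK_K:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK: tl.constexpr):
    # Everything the kernel needs (sizes, constants like QK_K) is passed in
    # as an argument; no Python globals, no breaking out of loops.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    scale_kernel[grid](x, out, x.numel(), s, BLOCK=1024)
    return out
```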

I still think an interesting thing to test would be Isotr0py/ggml-libtorch, which if I understood it right has the vLLM kernels for both dequantizing and direct matmul as well.

That would be harder to integrate since those are C++ CUDA kernels, not Triton, which means anyone who wanted to use them would need a development toolchain set up to compile CUDA kernels. I'm also not sure whether this could cause problems with ComfyUI's normal memory management; maybe it would be okay, since I think there are ways to integrate C++ CUDA kernels with PyTorch. Doing it that way would probably have to come from someone other than me, though.

So I guess the question is: Would you be interested in merging something like this with Triton kernels?


@Ph0rk0z

custom matmul kernels instead of whatever pytorch does would probably crank.

Maybe, but it would be fairly complicated to implement. The way GGML does it uses a row-based approach, I believe: it dequantizes rows and performs the matmul. You can't just directly matmul with the quantized data since GGML quants are block-based, so there are probably two main benefits: slightly lower memory usage (you'd save the size of one dequantized tensor, roughly) and maybe better cache usage.

Just doing a normal performant matmul certainly isn't simple: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

Also, I think that's only matrix-to-matrix; the vector-to-matrix case is also pretty common. Anyway, I'm not really thinking too much about that right now. The first step is just normal accelerated dequantization, and then maybe eventually we can look at stuff like fusing dequantization with operations like matmul.

I did a little experimentation with refactoring the kernels into modules like "dequantize a block", "dequantize a chunk", etc. So refactoring it so that there's a common dequantize function that both one-shot dequantization and fused versions could use is probably possible.
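Structurally, something like this sketch (the block body is just a stand-in for the real bit twiddling): a @triton.jit helper that dequantizes a single block, which both the plain dequantize kernel and an eventual fused kernel could call:

```python
import triton
import triton.language as tl


@triton.jit
def dequant_one_block(block_ptr, QK_K: tl.constexpr):
    # Stand-in for the real per-block work (scales, mins, packed nibbles,
    # ...); here it just loads QK_K int8 values and widens them to float.
    q = tl.load(block_ptr + tl.arange(0, QK_K))
    return q.to(tl.float32)


@triton.jit
def dequant_kernel(q_ptr, out_ptr, QK_K: tl.constexpr):
    # One program per block. A fused dequant+matmul kernel could call
    # dequant_one_block() the same way instead of first materializing the
    # whole dequantized tensor.
    pid = tl.program_id(0)
    vals = dequant_one_block(q_ptr + pid * QK_K, QK_K)
    tl.store(out_ptr + pid * QK_K + tl.arange(0, QK_K), vals)
```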

Remove Q8_0 Triton kernel
@blepping
Contributor Author

blepping commented Sep 11, 2025

Simple Triton benchmarking/testing utility: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2

You can run this with something like python -O test_triton_dequant.py --qtype=q4_k --dtype=float32 to get output like:

GGUF: Failed to enable Triton: attempted relative import with no known parent package


*** Testing quantization type Q4_K
Original tensor shape: (1, 128, 1024, 1024), will use dtype torch.float32
Quantized tensor shape: torch.Size([1, 128, 1024, 576]), dtype: torch.uint8

--- Running validation ---
Verification 1: User's PyTorch reference matches official GGUF library: True
Verification 2: Triton implementation matches official GGUF library: True
Verification 3: Triton implementation matches PyTorch reference: True
------------------------------


--- Benchmarking ---
PyTorch  : 11.2213 ms
Triton   : 3.1613 ms

Speedup  : 3.55x

The "failed to enable Triton" stuff can be ignored. The default shape requires about 5GB VRAM to run. Run the utility with --help for available options and default values.

Validation fails when the dtype isn't float32, but the existing PyTorch version also fails in that case. I think the differences are small and due to running some of the internal dequant math in the output dtype; the models still seem to work.
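If anyone wants to quantify those differences instead of relying on a strict match, something like this works (hypothetical helper, not part of the script):

```python
import torch


def report_diff(result: torch.Tensor, reference: torch.Tensor, rtol=1e-2, atol=1e-3):
    # Compare a lower-precision dequant result against a float32 reference
    # with tolerances loose enough for fp16/bf16 internal math.
    a, b = result.float(), reference.float()
    print("max abs diff:", (a - b).abs().max().item())
    print("allclose:", torch.allclose(a, b, rtol=rtol, atol=atol))
```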

These are the results I get running this for the permutations of supported qtypes and dtypes with a 4060Ti 16GB. The value is the speed factor compared to the PyTorch implementation.

| qtype | float32 | float16 | bfloat16 |
| ----- | ------- | ------- | -------- |
| Q4_K  | 3.57    | 3.78    | 3.82     |
| Q5_K  | 4.73    | 4.60    | 4.36     |
| Q6_K  | 3.85    | 4.36    | 4.21     |

Command used: for qt in q4_k q5_k q6_k; do for dt in float32 float16 bfloat16; do python -O test_triton_dequant.py --no-validate --qtype=$qt --dtype=$dt; done; done

@blepping
Contributor Author

Well, it pains me to say this, but I tried just wrapping the existing dequantization functions with torch.compile and the performance is equivalent. :( I guess I'll close this; it doesn't really make sense to add this much complexity when you can literally just torch.compile(quant_function) to get the same effect. Silly that I didn't think to try that before I put so much work into it!

It would probably make sense to add an option to the loader to use compiled GGUF dequantization functions. It's not always desirable or practical to compile the whole model, but compiling just these individual functions is quite fast, so it should work for anyone with torch.compile support, and it would be pretty simple to implement.
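Roughly what I have in mind (a sketch with made-up names for the per-qtype dequant function table; the real module layout may differ):

```python
import torch


def maybe_compile_dequant_functions(dequant_functions: dict, enabled: bool) -> dict:
    # Sketch: wrap each per-qtype dequant function in torch.compile when the
    # (hypothetical) loader option is enabled; otherwise leave them as-is.
    if not enabled or not hasattr(torch, "compile"):
        return dequant_functions
    return {qtype: torch.compile(fn) for qtype, fn in dequant_functions.items()}
```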

@blepping blepping closed this Sep 12, 2025
@Ph0rk0z

Ph0rk0z commented Sep 12, 2025

Did you try the torch compile on multiple card generations? It may be faster on a 4xxx card but not 2xxx or 3xxx. Plus I'm not sure that we can compile the TE very easily with a node.

@blepping
Contributor Author

@Ph0rk0z

Did you try the torch compile on multiple card generations? It may be faster on a 4xxx card but not 2xxx or 3xxx.

No, I've only tested on my own GPU. It would be surprising to find that the Triton kernels outperform torch.compile on some GPUs though. Also, you're going to have a hard time getting Triton working on legacy GPUs compared to torch.compile (if it's supported at all).

I updated the benchmark tool to take a --compile flag which will enable compiling the PyTorch dequant functions: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2
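For example, combining it with the flags shown earlier: python -O test_triton_dequant.py --qtype=q6_k --dtype=float16 --compile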

If you can get someone to test it with that and they find that the Triton kernels are a noticeable performance benefit compared to compiling then let me know and I can revisit this. It also doesn't necessarily have to be in the ComfyUI-GGUF project, so I could potentially make it a separate addon or something.

Plus I'm not sure that we can compile the TE very easily with a node.

I'm only talking about compiling the GGUF dequantization functions, not the whole model. It's definitely possible to do this; you'd just need a parameter in the node to enable using compiled dequantization functions.

@blepping
Contributor Author

It seems I closed this prematurely. Wrapping the dequant functions in torch.compile actually did not seem to help when running actual models. Continued in #336 since I can't reopen this pull.
