Triton dequantization/config framework #336
base: main
Conversation
Remove Q8_0 Triton kernel
Fix module config handling
Allow compiling GGUF dequant functions
Hmm, I think if you use the ComfyUI builtin or Kijai's torch compile nodes, you should get the same, or maybe even better, perf than the Triton kernels. What torch.compile does under the hood is generate optimized Triton kernels automatically.
This pull actually already includes an option to compile the dequantization functions. You can try it if you want; it didn't affect the performance in my testing or in the reports I've gotten. It's harder to compare performance when compiling the whole model rather than just the dequant functions, but compiling the whole model also isn't always desirable (or possible in some cases).
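To illustrate what compiling a dequant function in isolation means, here is a minimal toy sketch (the function below is a made-up stand-in, not one of the repo's actual dequant implementations):

import torch

# Made-up stand-in for a PyTorch dequant function: subtract an offset and scale per row.
def toy_dequant(qs: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return (qs.to(torch.float32) - 8.0) * scales.unsqueeze(-1)

# torch.compile wraps the function; on CUDA the inductor backend emits Triton kernels for it.
compiled_dequant = torch.compile(toy_dequant, dynamic=True)

qs = torch.randint(0, 16, (1024, 32), dtype=torch.uint8)
scales = torch.rand(1024)
print(torch.allclose(compiled_dequant(qs, scales), toy_dequant(qs, scales)))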
With Wan this doesn't make a difference (tried with q4km), but as you can see from my tests it certainly works. BTW, I am also using the standard torch.compile and sage-attention on top of this, so this is just about GGUF itself and it works.
Refactoring/cleanups
The q3 kernel also shows the same amount of speed improvement, around 12% at 2048x2048 with hunyuan-image.
Q8_0 isn't worth it, but as of now there are Triton kernels for all the K-quants and legacy quants. The following tests were done using this utility I made: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2
Using it requires running from the repo directory with this branch checked out and a venv with the necessary dependencies (PyTorch, Triton, etc.) activated. I used this command (Linux/Unix, presumably it would work with WSL) to run it over all the available quants and dtypes:

PyTorch vs Triton
Comparing the performance of the Triton kernels vs the existing PyTorch dequant functions. 2.0 in a column would mean the Triton version was two times faster. These results are from benchmarking the dequant functions in isolation, so you won't see the same speedup running an actual model. For reference, Q4_K is ~3.5x here; for moderate image sizes with models like Flux and Qwen the real-world performance benefit is more like 1.2x. The Q8_0 kernel, which wasn't worth using, was around 1.4x here.

I will have to do some real testing with the quants that seem a bit borderline to find out if having them enabled is actually worth it (Q4_0, Q2_K at non-32-bit, etc.).
PyTorch vs PyTorch float32
Comparing the PyTorch dequant functions with
PyTorch Q4_0 vs PyTorch other quants
This is also between PyTorch implementations, no Triton. Comparing the dequantization speed vs Q4_0 (Q8_0 may be faster but I don't have it implemented). So we say Q4_0 is 1.0x speed and compare the other quants to this. Comparison done with
These tests were done before my changes to eliminate the slow
Quick'n'dirty testing with the legacy quants and Chroma Radiance (float16 output): looking pretty good, especially the 5-bit quants. I also tried testing a full-sized generation with Q5_1 (bottom results). I ran out of memory twice in a row testing the non-Triton version; VRAM was around 96% used before it failed. With the Triton kernel it stayed around 80%, so at least for Q5_1 it seems like using Triton uses less memory as well as being faster.
@city96 This needs some cleanup, but if I get it presentable, would you be interested in merging these changes?
Show Triton status in tooltip for optimize param in advanced loader
Tweak autotune configs a bit
The return of the Q8_0 Triton kernel, matches or slightly exceeds PT performance now
city96 left a comment
@blepping I'll have to test this properly sometime, just haven't really had time lately (or access to usable hardware most of the time lol I'm writing this review on my laptop so I can't even run the code).
Most of the comments are just random nitpicks or stuff I noticed looking over the code, again, I can't run it and haven't tested the control flow so half of them could be wrong.
Also, for the final file layout, I'd probably make a new folder named dequant, and rename dequant.py to dequant/kernel_torch.py and the triton one to dequant/kernel_triton.py or something similar. Then we can have the logic for picking either in the dequant folder's init, so it picks the correct one when you import from it. Though that might look weird on the git diff, so idk.
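A rough sketch of what that dispatch in the dequant folder's __init__.py could look like (the file names come from the suggestion above; the try/except fallback is an assumption, not code from this PR):

# dequant/__init__.py (sketch)
# Prefer the Triton kernels when Triton is importable, otherwise fall back to the PyTorch ones.
try:
    from .kernel_triton import dequantize_functions
    HAVE_TRITON = True
except ImportError:
    from .kernel_torch import dequantize_functions
    HAVE_TRITON = False

__all__ = ("dequantize_functions", "HAVE_TRITON")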
import torch
from tqdm import tqdm

HAVE_BFLOAT16=hasattr(torch, "bfloat16")
Wouldn't we need to check if the current device actually supports it? i.e. we'd want to test this on RTX 20XX, Volta and Pascal. I can test on volta+pascal sometime.
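If a runtime check is wanted, something along these lines might work (a sketch; torch.cuda.is_bf16_supported() only covers CUDA devices, so other backends are just assumed to be fine here):

import torch

def device_supports_bf16(device: torch.device) -> bool:
    # Checks bf16 compute support; merely storing/viewing bf16 data may work regardless.
    if device.type == "cuda":
        return torch.cuda.is_bf16_supported()
    # Assumption: CPU and other backends can at least hold bf16 tensors.
    return True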
I'll make the bfloat16 changes a separate PR.
@@ -1,38 +1,72 @@
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
from typing import Callable, Literal, NamedTuple, Optional, Union
The repo generally doesn't have typing anywhere, so I'd say just remove it unless we plan to add it everywhere.
Well, typing stuff is generally considered good and if you'd be interested in adding it everywhere I can certainly do that. Also having type annotations in some places means it would be less work if you wanted it (everywhere) later on. I can just remove all the type annotations if that's what you really want though.
def dequantize(data, qtype, oshape, dtype=None):
    if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16:
        return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype)
Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16
(also, I don't super like the if/elif/else being flattened, the diff is harder to read with all the small changes)
I will remove these changes and make it a separate PR. I did test though and it seemed to work fine with a model that had a bunch of BF16 tensors so I am pretty sure it couldn't be a different layout.
also, I don't super like the if/elif/else being flattened
You mean you don't like:

def blah():
    if condition:
        return 1
    elif other_condition:
        return 2
    else:
        return 3

Compared to:

def blah():
    if condition:
        return 1
    if other_condition:
        return 2
    return 3

Linters will complain about the former version because the elif and else are redundant and it's usually considered a "code smell", but I can do it that way if you want. (This particular part isn't going to be relevant for the reviewable pull, but I can make sure I follow your style preference in other places.)
Does this actually work...? ggml bf16 has a completely different data layout compared to pytorch bf16
>>> import numpy as np
>>> import torch
>>> def quantize_blocks(blocks: np.ndarray) -> np.ndarray: # From gguf-py
...     n = blocks.view(np.uint32)
...     # force nan to quiet
...     n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
...     # round to nearest even
...     n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
...     return n.astype(np.uint16).view(np.uint8)
...
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x7f841bbf9650>
>>> x = torch.randn(1000, dtype=torch.bfloat16)
>>> xnp = x.to(dtype=torch.float32).numpy()
>>> xqnp = quantize_blocks(xnp)
>>> xq = torch.tensor(xqnp)
>>> xq.dtype
torch.uint8
>>> xdq_manual = (xq.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
>>> xdq_view = xq.view(dtype=torch.bfloat16).to(torch.float32)
>>> torch.equal(xdq_manual, xdq_view)
True
TL;DR: It's the same layout and just viewing is safe as long as Torch has bf16 support.
I'm not sure the GPU even needs to have BF16 support here since it's just a storage type and there's no math involved. Of course if the user has the compute dtype set to bf16 and their GPU doesn't support it then they're going to run into issues.
If you have access to a GPU without bf16 support, a simple test would just be to temporarily manifest a bf16 tensor and see if it causes problems:
>>> import torch
>>> torch.arange(100, dtype=torch.uint8).view(torch.bfloat16).view(torch.uint8)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype=torch.uint8)
@@ -0,0 +1,723 @@
from __future__ import annotations
Is this required here?
It makes the blah: some_type | None pattern usable with all the Python versions ComfyUI supports. Otherwise that has to be blah: Optional[type]. So it would be irrelevant if you don't want type annotations at all (or if you prefer the Optional pattern for some reason).
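Just to illustrate the pattern (not code from the PR):

from __future__ import annotations

# With the future import, annotations are stored as strings and never evaluated at runtime,
# so the `int | None` syntax works even on Python versions older than 3.10.
def lookup(name: str, default: int | None = None) -> int | None:
    return default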
_MODULE_NAMES = ("Linear", "Conv2d", "Embedding", "LayerNorm", "GroupNorm")

def __init__(self, *args, ggufconfig: Optional[GGUFConfig]=None, **kwargs):
    super().__init__(*args, **kwargs)
    linear_config = ggufconfig or DEFAULT_CONFIG
    # Ignore patch_dtype and dequant_dtype for non-Linear layers.
    other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None)
    self.ggufconfig = linear_config
    for module_name in self._MODULE_NAMES:
        module = getattr(self.__class__, module_name)
        curr_config = linear_config if module_name == "Linear" else other_config
        setattr(self, module_name, type(module_name, (module,), {"ggufconfig": curr_config}))
I feel like there has to be a better way to pass the config to the modules but I just woke up so I can't think of anything...
(also, ggufconfig -> gguf_config if we're being consistent)
I will change the name to gguf_config.
This was the best way I could think of to make sure it is present even if something else is trying to clone the class or whatever, since doing it this way makes sure it's there as a default attribute. It could be done as a parameter to __init__ for each module, but that would require more changes and I'm not positive there isn't code somewhere that might try to recreate/rebuild them and wouldn't know to pass parameters like that.
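For comparison, the __init__-parameter alternative mentioned above might look roughly like this (a sketch only; the GGMLLinear name and DEFAULT_CONFIG placeholder are illustrative, not the PR's actual code):

import torch

DEFAULT_CONFIG = None  # placeholder standing in for the PR's default GGUFConfig

class GGMLLinear(torch.nn.Linear):
    def __init__(self, *args, gguf_config=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Anything that clones/rebuilds this module must know to pass gguf_config again,
        # unlike the subclass approach above where the config is baked in as a class attribute.
        self.gguf_config = gguf_config if gguf_config is not None else DEFAULT_CONFIG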
    HAVE_TRITON=True
except Exception as exc:
    HAVE_TRITON=False
    print(f"\nGGUF: Failed to enable Triton: {exc}")
logging.warning instead of print, that's what comfy uses. I should probably add my linter config to the repo.
These debug print statements will be removed before I mark this PR as ready. If you actually want some Triton status information logged to the console, let me know and I can add it.
triton_dequantize_functions={}

TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16))
frozenset? we can probably leave it as a normal set
I can change it if you want. Usually it's better to use non-mutable types for stuff that isn't supposed to be mutable. It's a little easier to reason about (since you know stuff won't randomly be changing) and it's usually a little more efficient also.
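A quick illustration of the difference:

>>> qtypes = frozenset((1, 2, 3))
>>> qtypes.add(4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'frozenset' object has no attribute 'add'
>>> {1, 2, 3}.add(4)  # a plain set happily mutates in place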
dequant_triton.py (outdated)
def dequantize_kernel(
    q_tensor_ptr,
    out_tensor_ptr,
    n_total_blocks,
    DTYPE: tl.constexpr,
    N_BLOCKS_PER_PROG: tl.constexpr,
    CTX: tl.constexpr,
) -> None:
I usually put all the args on one line, though that's mostly just a nitpick.
I will change it to be on one line before I mark this as ready for review. (Let's leave this open so I don't forget.)
ml = dmin * (scale_byte >> 4).to(DTYPE)

# --- Map the 16 output elements to their source data ---
# This logic correctly models the Python reshape from a flat 256-element array.
I don't like the numbered llm slop comments lol, they don't really do anything other than making the file longer to scroll through.
Some of the comments are about the structure of what it's dequantizing so might be useful (if they're accurate, I haven't read through/verified them). There's also stuff like "I really did it right this time!" which isn't serving a purpose. I can try to curate them or I can just remove them all entirely if you prefer.
Sorry, I might have wasted your time since I wasn't clear. I wasn't expecting you to go through the changes in-depth at this point. I was mostly just asking if you'd be interested in the general idea, like merging in Triton stuff at all, before I put more time into cleaning it up. There's quite a bit of stuff in there that's just for testing, like the print statements, etc. There are also some other changes (like the bfloat16 stuff) which are only there because this is the tree I'm testing my changes in. Again, I apologize for not being clear enough about that.
The current system allows you to override individual kernels/dequant functions (or even the main one) via configuration. So you could do stuff like Triton only for Q6_K and the existing PT functions for everything else, that kind of thing (rough sketch below). We definitely could put the actual dequantization handling functions in separate files in a dequant folder like you suggested.
I'll respond to all your reviews, but it's going to be a lot of stuff to deal with and I know you're busy (also some of them aren't really that relevant), so no pressure to reply, and I'll try to mark the ones that can just be closed.
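A rough sketch of what that kind of per-quant override could look like (the import paths and names here are illustrative, not the PR's actual config API):

import gguf
from dequant import dequantize_functions as torch_dequant          # hypothetical module layout
from dequant_triton import dequantize_functions as triton_dequant  # (see the discussion above)

# Use the Triton kernel just for Q6_K and keep the existing PyTorch functions for the rest.
overrides = {gguf.GGMLQuantizationType.Q6_K: triton_dequant[gguf.GGMLQuantizationType.Q6_K]}

def pick_dequant_function(qtype):
    return overrides.get(qtype, torch_dequant[qtype])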
Cleanups/style fixes
Sorry for the lack of updates, there's some not-so-great stuff going on in my personal life at the moment. It's not abandoned, but I haven't been able to work on any personal projects for a while now and I'm not sure when I will be able to complete this. I intend to, but it will probably be at least several weeks before I can even start on it again. Still no doubts about the value; I wouldn't use GGUF without these changes anymore since it's such a significant performance/memory improvement. You can close this if having a stale open pull is bugging you, though I don't think I'll be able to reopen it due to the force-pushed changes.
No worries, I'm also not super active these days, part of that due to similar life reasons. I could try and take this on but probably wouldn't get too far (though if I do I'll definitely add you as a co-author and tag you). I don't mind the PR being open, we can leave it as a draft like this for now, though I do still think it'd be cool to have alternate kernels like this as an option, assuming there are no edge cases we run into. On that, sorry for never replying to your comments in the review. I'm somewhat easier to reach on Discord since I can check that on my phone, so if you do need more immediate feedback for something, do feel free to add me (name is the same as everywhere). I did end up testing it on Pascal/Volta and at least based on this it does work (I didn't have Comfy or any models so I kinda had to improvise lol):

import gguf  # needed for the qtype constants below
import torch
from dequant_triton import dequantize_functions

qtype = gguf.GGMLQuantizationType.Q4_0
fn = dequantize_functions[qtype]
x = torch.ones(1024, 1024*18, device="cuda:0", dtype=torch.uint8)
_ = fn(x, *gguf.GGML_QUANT_SIZES[qtype])
Sorry, a bit late. I might be missing something, but my original thoughts were: if you ... (Line 177 in 02dac86)
The fact that you couldn't replicate the perf in real-world models sounds strange, would you mind elaborating on what you tried?



Continuation of #331. Sorry for making a new pull, I couldn't reopen the existing one since it had been force pushed.
Wrapping the existing dequant functions with torch.compile in my benchmark tool brought the results to parity with Triton: https://gist.github.com/blepping/963459b244a4140b081cebdec24c56b2
However, I was not able to replicate that in the real world, so the Triton kernels are back!
This includes a generic approach to passing configuration parameters around, including overriding the dequant functions (potentially with compiled or Triton versions). There is an optimize parameter in the advanced loader that lets you choose between none, compile and triton. Enabling Triton for me is noticeably faster; using torch.compile is actually slower than nothing.
This also fixes an issue with the existing approach to setting configuration like dequant_dtype:
This is not kosher because while ops is an instance of GGMLOps, attributes like ops.Linear are not instance-specific, so setting it here will change that attribute on the global class. This means if there are multiple loaders with different settings, you will get the configuration of whatever loader overwrote that value last.
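A condensed toy reproduction of the problem and of the per-instance-subclass fix (GGMLOps and Linear mirror the repo's names; everything else is simplified for illustration):

import torch

class GGMLOps:
    class Linear(torch.nn.Linear):
        dequant_dtype = None  # class-level default shared by every GGMLOps instance

ops_a, ops_b = GGMLOps(), GGMLOps()
ops_a.Linear.dequant_dtype = "float16"  # loader A mutates the shared class attribute...
print(ops_b.Linear.dequant_dtype)       # ...so loader B now sees "float16" too

# Building a per-instance subclass (roughly what this PR does) keeps loaders independent:
ops_b.Linear = type("Linear", (GGMLOps.Linear,), {"dequant_dtype": "float32"})
print(GGMLOps.Linear.dequant_dtype)     # the global class still has loader A's "float16"
print(ops_b.Linear.dequant_dtype)       # only this loader instance sees "float32"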