Skip to content

Conversation

@hauhaut
Copy link

@hauhaut hauhaut commented Dec 16, 2025

cuda kernel for delta-net linear attention layers in qwen3next.

adds GGML_OP_DELTA_NET + recurrent kernel for decode, blackwell path (sm12.0+) for prefill with 64k shmem. also improved solve_tri for the chunked prefill path.

getting ~45-55 t/s on q4/mxfp4 and ~40 t/s bf16 on 80B-A3B (blackwell). pre-blackwell cards get ~38-40 t/s from solve_tri improvements (baseline was the original ~20 t/s).

LLM assistance has been used throughout the process to help alleviate workload. Every single aspect of the process though has been manually tested and validated before being posted here. My key concern is the feasibility of the implementation onto the main branch as it's ultimately an entirely new addition to it which comes with its own risks. I defer to the active maintainers of the project on this.

Edit: omitted some small bits.

@hauhaut hauhaut closed this Dec 16, 2025
@hauhaut hauhaut reopened this Dec 16, 2025
Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am in principle willing to review the CUDA code but from a broader ggml perspective the PR cannot be merged like this. Preferably implement your kernel as a fused operation that is entirely contained within the CUDA backend. If this is not possible, support in the CPU backend is mandatory both as a fallback for other backends and to assert that the new ggml op is working correctly in test-backend-ops. (For a fused op new tests in test-backend-ops should also be added.)

cc @am17an (who has recently worked on fusion within the CUDA backend)

@hauhaut
Copy link
Author

hauhaut commented Dec 16, 2025

I am in principle willing to review the CUDA code but from a broader ggml perspective the PR cannot be merged like this. Preferably implement your kernel as a fused operation that is entirely contained within the CUDA backend. If this is not possible, support in the CPU backend is mandatory both as a fallback for other backends and to assert that the new ggml op is working correctly in test-backend-ops. (For a fused op new tests in test-backend-ops should also be added.)

cc @am17an (who has recently worked on fusion within the CUDA backend)

thanks for the feedback. looked into the fused-op-only approach but delta-net has recurrent state that persists across calls - similar to mamba's ssm_scan or rwkv's wkv ops. the state update semantics are subtle enough that pattern-based fusion would be fragile. will add cpu fallback + test-backend-ops tests. should have that up soon.

@hauhaut
Copy link
Author

hauhaut commented Dec 16, 2025

Added CPU fallback and test-backend-ops coverage. CUDA passes against CPU reference.

$ ./test-backend-ops -o DELTA_NET
Backend 1/4: CUDA0 (RTX 6000 Ada)
DELTA_NET(type=f32,n_heads=8,head_dim=64,n_tokens=1,n_seqs=1): OK
DELTA_NET(type=f32,n_heads=8,head_dim=64,n_tokens=32,n_seqs=1): OK
DELTA_NET(type=f32,n_heads=8,head_dim=64,n_tokens=32,n_seqs=2): OK
DELTA_NET(type=f32,n_heads=8,head_dim=64,n_tokens=128,n_seqs=2): OK
4/4 tests passed
Backend 2/4: CUDA1 (RTX PRO 6000 Blackwell)
4/4 tests passed
Backend 3/4: CUDA2 (RTX PRO 6000 Blackwell)
4/4 tests passed
4/4 backends passed

@pwilkin
Copy link
Collaborator

pwilkin commented Dec 16, 2025

Would be good to ask @ggerganov for his opinion because when I was implementing Qwen3Next he said he didn't want to add custom per-model kernels.

@hauhaut
Copy link
Author

hauhaut commented Dec 16, 2025

GGML_OP_DELTA_NET isn't a per-model kernel. It's a general linear attention op, same category as GLA, WKV6/7, and SSM_SCAN. These exist in ggml because they're architectural primitives that could be used by any model implementing that attention mechanism. Happy to hear ggerganov's take though (and to drop the PR of course if not feasible for llama.cpp)

@github-actions github-actions bot added model Model specific testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Dec 16, 2025
@jeffbolznv
Copy link
Collaborator

I'm not an ML expert, but this does seem to be an operation that appears in multiple models and is "too much" to reconstruct in fusion. I can't comment on the correctness of this definition, but if it's self-contained like this and can replace this whole block in qwen3next it seems appealing to have as an op.

Did you use AI to write the code? It seems like too many variants stamped out that don't need to be (can just be templated) in a way that AI might do. Also the weird deleting of comments...

@hauhaut
Copy link
Author

hauhaut commented Dec 17, 2025

I'm not an ML expert, but this does seem to be an operation that appears in multiple models and is "too much" to reconstruct in fusion. I can't comment on the correctness of this definition, but if it's self-contained like this and can replace this whole block in qwen3next it seems appealing to have as an op.

Did you use AI to write the code? It seems like too many variants stamped out that don't need to be (can just be templated) in a way that AI might do. Also the weird deleting of comments...

The kernel variants aren't duplicates. They target different memory hierarchies (global vs 64KB shared), data types (FP32 vs FP16 with half2), and parallelization strategies (single block vs column-parallel). Templating them together would generate the same code with more complex dispatch. The deleted comments were noise (// Apply sigmoid before ggml_sigmoid()). Happy to discuss specific consolidations if you see any.

Yes, I use AI for scaffolding and iteration. The kernels went through extensive validation and debugging before landing here.

@am17an
Copy link
Collaborator

am17an commented Dec 17, 2025

The kernel variants aren't duplicates. They target different memory hierarchies (global vs 64KB shared), data types (FP32 vs FP16 with half2), and parallelization strategies (single block vs column-parallel). Templating them together would generate the same code with more complex dispatch. The deleted comments were noise (// Apply sigmoid before ggml_sigmoid()). Happy to discuss specific consolidations if you see any.

This also seems to be generated by AI

Yes, I use AI for scaffolding and iteration. The kernels went through extensive validation and debugging before landing here.

I see that the code assumes Blackwell has 64kb per block and has 2 separate kernels for it, which it doesn't according to the spec (still returns 48kb). So I am suspicious of this claim.

In general I believe the PR could be useful but I have low confidence in its claims, and I'm not going to review a 1400 line cuda kernel. I think it could be consolidated to a single kernel with some templates for data-types, even then the AI responses to questions are off-putting

@IIIIIllllIIIIIlllll
Copy link

IIIIIllllIIIIIlllll commented Dec 17, 2025

Fantastic work!
I managed to get this branch running on ROCm with some straightforward modifications. The performance improvement on my AI MAX+ 395 is very noticeable. I really hope to see this merged into the main branch once it's been refined a bit more.

(Of course, I am not familiar with the relevant knowledge; I am just sharing some test results.)

before (master):

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        587.16 ± 0.89 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         27.54 ± 0.04 |

build: unknown (0)

after (this PR):

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        516.30 ± 0.72 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         30.00 ± 0.04 |

build: unknown (0)

@hauhaut
Copy link
Author

hauhaut commented Dec 17, 2025

Fantastic work! I managed to get this branch running on ROCm with some straightforward modifications. The performance improvement on my AI MAX+ 395 is very noticeable. I really hope to see this merged into the main branch once it's been refined a bit more.

commad:

before (master):

**ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        587.16 ± 0.89 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         27.54 ± 0.04 |

build: unknown (0)**

after (this PR):

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        516.30 ± 0.72 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         30.00 ± 0.04 |

build: unknown (0)

Thank you, we'll see what the others think. Starting to regret a little getting convinced to share.

@jeffbolznv
Copy link
Collaborator

The kernel variants aren't duplicates. They target different memory hierarchies (global vs 64KB shared), data types (FP32 vs FP16 with half2), and parallelization strategies (single block vs column-parallel).

I may have missed some subtleties in the delta net kernels, it's a lot of code and I only skimmed it. But I can say with some confidence that >800 lines of solve_tri is overkill ;-)

I see that this change takes qwen3next from 11185 down to 6001 nodes, at first glance that seems really good. I get mixed performance results - +10% for tg128, -20% for pp512 (cuda backend, rtx 5090). But the cuda backend is currently significantly slower than the vulkan backend so there's some kind of perf bug, and it makes it hard to make performance judgments on this change until that issue gets resolved.

@JohannesGaessler
Copy link
Collaborator

@hauhaut you are expected to read the contributing guidelines before opening a PR where it clearly states:

Using AI to generate PRs is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before publishing the PR. Note that trivial tab autocompletions do not require disclosure.

My previous comment regarding willingness to review was without me having looked at the code at all. In its current state I agree with Aman's comments: this PR imposes too much of a maintenance burden vs. the potential benefits and I will not review it.

@ggerganov
Copy link
Member

Would be good to ask @ggerganov for his opinion because when I was implementing Qwen3Next he said he didn't want to add custom per-model kernels.

I'm not an ML expert, but this does seem to be an operation that appears in multiple models and is "too much" to reconstruct in fusion. I can't comment on the correctness of this definition, but if it's self-contained like this and can replace this whole block in qwen3next it seems appealing to have as an op.

So generally, it's not super clear to me where to draw the line between fundamental ops and dedicated high-level ops. We do have ggml_flash_attn which is very similar to this, but IMO FA is something that has been very well established by this point and is not going away in the future.

Delta Net on the other hand is a brand new technique which we don't know what is the future of. I only know about Qwen3 Next using it and after that Kimi Linear came out and as far as I remember it already introduced some changes to the mechanism.

So at some point we will probably have to add a dedicated operator, I just cannot say that this is the right moment to commit, because I feel its definition is still shaky.

Happy to be corrected and to consider other povs.

Regarding the other similar ops:

  • ggml_rwkv_wkv6()
  • ggml_rwkv_wkv7()
  • ggml_gated_linear_attn()

In retrospect, I think it was too premature to introduce those. I don't think we see the benefits of having them and it would have been better to wait for the RWKV architecture to settle before adding those.

@pwilkin
Copy link
Collaborator

pwilkin commented Dec 17, 2025

@ggerganov Yeah, I agree it would probably be better to first have a few models in Delta Net to see if we can establish a common kernel. The Kimi kernel is similar but has subtle differences, I'd need to implement the full Kimi Linear model to see (I've been waiting for the PR that's pending, but I don't know if it's going anywhere).

@hauhaut For SOLVE_TRI, can you show the performance data from test-backend-ops perf -b CUDA0 -o SOLVE_TRI before and after this patch?

@hauhaut
Copy link
Author

hauhaut commented Dec 17, 2025

Quite a lot of messages to come back to!

Firstly, thank you to those taking time off their busy schedules to read the PR and take the time to provide constructive feedback. I'm currently looking to onboard as much of it as possible - nothing goes to waste as I see it as a learning opportunity.

Second, regarding clearly mentioning in OP about AI contribution, that's a direct oversight on my end for which (despite, when asked, I openly admit) is something I should have paid more attention to. I apologize for it and have now updated the OP to clearly state it. I ask of you to keep an open mind still, as some of you have shown already, on the implementation as I strongly feel this is the way moving forward (if wrong, I'm happy to buy the dev teaching me why a coffee).

Third, I will repeat (mostly for my own peace of mind), all code has been extensively manually tested and I have yet to come across any issues. All tests passed and all runtime manual testing seems to be issues free. I've tested across 3 different machines and 4 models of GPU's (RTX 6000 ADA, 6000 Blackwell Pro, RTX 5090 and RTX 5090M) - all showing qualitative clear and stable gains, especially in inference. It was a nice bonus to see the delta_net_recurrent_f32 kernel (keeps state in global memory) being put to good use even in ROCm.

Fourth, for SOLVE_TRI code reduction - I'm happy to look into it more and see what further optimizations I can make. There's always room for (clear) improvement. Also, worth noting, here are the test results on my workstation. @jeffbolznv I can also conduct further tests on my 5090 and 5090M if needed (both for SOLVE_TRI and for inference/pp/tg):
@pwilkin
Edit for full context:

Latest Llama.cpp build:

Backend 1/4: CUDA0
Device description: NVIDIA RTX 6000 Ada Generation
Device memory: 49139 MB (47217 MB free)

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]): 65520 runs - 15.50 us/run - 2.13 MFLOP/run - 137.38 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]): 16380 runs - 72.51 us/run - 4.23 MFLOP/run - 58.29 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[64,64,8,32],ne_rhs=[64,64,8,32]): 24956 runs - 40.87 us/run - 68.16 MFLOP/run - 1.67 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,32],ne_rhs=[128,128,4,32]): 11100 runs - 92.49 us/run - 270.53 MFLOP/run - 2.93 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[256,256,4,2],ne_rhs=[128,256,4,2]): 8910 runs - 122.11 us/run - 67.37 MFLOP/run - 551.71 GFLOPS
Backend CUDA0: OK

Backend 2/4: CUDA1
Device description: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Device memory: 97886 MB (95276 MB free)

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]): 65520 runs - 15.75 us/run - 2.13 MFLOP/run - 135.25 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]): 16380 runs - 78.73 us/run - 4.23 MFLOP/run - 53.69 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[64,64,8,32],ne_rhs=[64,64,8,32]): 30828 runs - 33.68 us/run - 68.16 MFLOP/run - 2.02 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,32],ne_rhs=[128,128,4,32]): 14060 runs - 72.20 us/run - 270.53 MFLOP/run - 3.75 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[256,256,4,2],ne_rhs=[128,256,4,2]): 8910 runs - 112.95 us/run - 67.37 MFLOP/run - 596.47 GFLOPS
Backend CUDA1: OK

PR Implementation:

Device description: NVIDIA RTX 6000 Ada Generation
Device memory: 49139 MB (47217 MB free)

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]): 49140 runs - 22.83 us/run - 2.13 MFLOP/run - 93.30 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]): 16380 runs - 70.57 us/run - 4.23 MFLOP/run - 59.90 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[64,64,8,32],ne_rhs=[64,64,8,32]): 24956 runs - 40.56 us/run - 68.16 MFLOP/run - 1.68 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,32],ne_rhs=[128,128,4,32]): 11100 runs - 92.03 us/run - 270.53 MFLOP/run - 2.94 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[256,256,4,2],ne_rhs=[128,256,4,2]): 8910 runs - 119.72 us/run - 67.37 MFLOP/run - 562.75 GFLOPS
Backend CUDA0: OK

Device description: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Device memory: 97886 MB (95276 MB free)

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]): 49140 runs - 21.75 us/run - 2.13 MFLOP/run - 97.93 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]): 16380 runs - 68.10 us/run - 4.23 MFLOP/run - 62.08 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[64,64,8,32],ne_rhs=[64,64,8,32]): 33764 runs - 30.07 us/run - 68.16 MFLOP/run - 2.27 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,32],ne_rhs=[128,128,4,32]): 15170 runs - 66.88 us/run - 270.53 MFLOP/run - 4.04 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[256,256,4,2],ne_rhs=[128,256,4,2]): 10395 runs - 104.22 us/run - 67.37 MFLOP/run - 646.45 GFLOPS
Backend CUDA1: OK

Fifth, on whether delta-net is mature enough ( @ggerganov ): I'm currently working on extending support to KDA (Kimi-Linear's diagonal decay) on a private branch - same op, but with shape-driven detection. Early results look good and to me and getting ahead of this now makes sense rather than catching up later. Also, for my own due dilligence, I was looking at the landscape and Delta-Net keeps showing up in newer linear attention models. I couldn't find another approach that would let us unlock these kinds of optimizations down the line.

However, having said all of that, from the very first post I mentioned one clear concern (which i will reiterate now) - I don't know how easily implementable it all is in the main branch currently. This is a big project and although this feels like the way forward, it comes with the usual risks for which I defer to you all whether or not it's worth proceeding with currently (or in my iteration of it). It's easy for me to implement things for myself and my own use / POC, it's another to be responsible for the main branch.

I hope I didn't miss anything, but if I did, let me know.

@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Dec 17, 2025

all code has been extensively manually tested and I have yet to come across any issues.

You are missing the point. This is not about whether the code works correctly right now, this is about the maintenance burden down the line. It is almost guaranteed that over the course of the project's lifetime it will be necessary to make modifications to the code in this PR. For that the code needs to be analyzed and understood. So unless the code has been written by a maintainer that is available long-term and can modify the code with little upfront investment, keeping it compact and deduplicated is a major priority. Let me be frank: if I were to merge the code as it is right now it would be a net negative for the project due to the opportunity cost for maintainers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning model Model specific Nvidia GPU Issues specific to Nvidia GPUs testing Everything test related

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants