[PyT] Plumbing correct bias dims from TE to cudnn #2537
Conversation
/te-ci pytorch L0 L1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed from 200fd98 to 8da3252
for more information, see https://pre-commit.ci
Greptile Summary
This PR correctly plumbs bias tensor dimensions from Transformer Engine to cuDNN by extracting the actual bias shape dimensions instead of assuming they match the query/key-value sequence lengths.
Key Changes:
Impact:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant Caller
participant FusedAttnFwd as fused_attn_arbitrary_seqlen_fwd
participant FusedAttnBwd as fused_attn_arbitrary_seqlen_bwd
participant FwdImpl as fused_attn_arbitrary_seqlen_fwd_impl
participant BwdImpl as fused_attn_arbitrary_seqlen_bwd_impl
participant CuDNN as cudnn_frontend
Note over Caller,CuDNN: Forward Pass
Caller->>FusedAttnFwd: input_Bias tensor
FusedAttnFwd->>FusedAttnFwd: Extract bias_b, bias_h from shape[0:2]
FusedAttnFwd->>FusedAttnFwd: Extract bias_sq, bias_skv from shape[2:4]
FusedAttnFwd->>FwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
FwdImpl->>FwdImpl: Create FADescriptor_v1 with all bias dims
FwdImpl->>CuDNN: Set bias tensor dims to {bias_b, bias_h, bias_sq, bias_skv}
CuDNN-->>FwdImpl: Execute attention with correct bias shape
FwdImpl->>FusedAttnFwd: Set output_bias shape to {bias_b, bias_h, bias_sq, bias_skv}
FusedAttnFwd-->>Caller: Return results
Note over Caller,CuDNN: Backward Pass
Caller->>FusedAttnBwd: input_Bias, output_dBias tensors
FusedAttnBwd->>FusedAttnBwd: Extract bias_b, bias_h from output_dBias shape[0:2]
FusedAttnBwd->>FusedAttnBwd: Extract bias_sq, bias_skv from input_Bias shape[2:4]
FusedAttnBwd->>BwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
BwdImpl->>BwdImpl: Create FADescriptor_v1 with all bias dims
BwdImpl->>CuDNN: Set bias/dBias dims to {bias_b, bias_h, bias_sq, bias_skv}
CuDNN-->>BwdImpl: Execute backward attention with correct shapes
BwdImpl-->>FusedAttnBwd: Return gradients
FusedAttnBwd-->>Caller: Return results
```
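For illustration, here is a minimal C++ sketch of the extraction step the diagram describes, assuming 4-D (b, h, s_q, s_kv) bias shapes; the struct and helper names are hypothetical and not the actual TE code. The forward path reads all four bias dims from `input_Bias`, while the backward path takes `bias_b`/`bias_h` from `output_dBias` and `bias_sq`/`bias_skv` from `input_Bias`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: bias dims read from the tensors' own shapes instead of
// being assumed equal to [b, h, s_q, s_kv]. Shapes are taken to be 4-D.
struct BiasDims {
  int64_t b, h, sq, skv;
};

// Forward pass: all four dims come from the bias tensor itself.
BiasDims bias_dims_fwd(const std::vector<int64_t> &bias_shape) {
  return {bias_shape[0], bias_shape[1], bias_shape[2], bias_shape[3]};
}

// Backward pass: batch/head dims from dBias, sequence dims from the input bias.
BiasDims bias_dims_bwd(const std::vector<int64_t> &dbias_shape,
                       const std::vector<int64_t> &bias_shape) {
  return {dbias_shape[0], dbias_shape[1], bias_shape[2], bias_shape[3]};
}
```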
Looks good - please pick the 111s test from my branch as well. Thanks!
Description
TE common was not plumbing the attention bias tensor dimensions correctly to cuDNN. Instead of using the shape from Bias, i.e. `[bias_sq, bias_skv]`, it was using `[sq, skv]`, thereby passing larger-than-required dims. Using the reproducer https://github.com/cyanguwa/TransformerEngine/tree/test_111s with a bias of shape [1,1,1,s], the cuDNN FE logs show that prior to this PR the bias dims passed from TE to cuDNN were
`{"data_type":null,"dim":[1,1,128,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[16384,16384,128,1],"uid":0,"uid_assigned":false},`
and after this PR they are:
`"bias":{"data_type":null,"dim":[1,1,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[128,128,128,1],"uid":0,"uid_assigned":false},`
Type of change
Changes
- Pass `bias_sq` and `bias_skv` to `fused_attn_arbitrary_seqlen_fwd_impl()` and `fused_attn_arbitrary_seqlen_bwd_impl()`
- Add `bias_sq` and `bias_skv` in `FADescriptor_v1`
- Use `bias_sq` and `bias_skv` instead of `s_q` and `s_kv`

Checklist: