Enable NVFP4 RHT amax for grouped SReLU MLP #3133

Open
sraman-rgb wants to merge 8 commits into NVIDIA:main from sraman-rgb:te-nvfp4-srelu-rht-hadamard

Conversation

@sraman-rgb

Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sraman-rgb sraman-rgb requested a review from timmoon10 as a code owner June 16, 2026 18:42
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 16, 2026
Signed-off-by: Siddhartha Raman <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the te-nvfp4-srelu-rht-hadamard branch from fa32e3b to 79def34 Compare June 16, 2026 18:45
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Contributor

Greptile Summary

This PR extends the NVFP4 RHT (Randomized Hadamard Transform) amax computation to the GroupedMLP_CuTeGEMMUnary (SReLU) path, matching existing support previously only available for the SwiGLU path. It adds a cuDNN frontend version gate (≥ 1.26.0), renames the GLU-specific hadamard kernel accessor to the shared grouped_gemm_act_hadamard_kernel, and adds a dedicated test for the nvfp4_rht + scaled_srelu combination.

  • grouped_mlp.py: Refactors use_fc1_glu_hadamard → use_fc1_act_hadamard / use_fc1_act_hadamard_srelu, adds a version guard, and implements grouped_gemm_act_hadamard_kernel() on GroupedMLP_CuTeGEMMUnary reusing grouped_gemm_glu_hadamard_wrapper_sm100 with act_func="srelu".
  • test_fusible_ops.py: Generalises test_grouped_mlp to accept an activation parameter, updates tolerance selection to use quantization_tols for quantized runs, and adds test_grouped_mlp_nvfp4_rht_srelu as a targeted non-parametrised test.
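The version gate mentioned in the summary can be pictured with a small sketch. This is not the PR's code: `parse_version`, `MIN_CUDNN_FE_FOR_SRELU_HADAMARD`, and `srelu_hadamard_supported` are hypothetical names chosen for illustration, and only the 1.26.0 threshold comes from the summary above. It shows why versions should be compared as integer tuples rather than as strings.

```python
# Hypothetical sketch of a cuDNN frontend version gate; names are illustrative,
# not taken from grouped_mlp.py. Only the 1.26.0 threshold comes from the PR summary.

MIN_CUDNN_FE_FOR_SRELU_HADAMARD = (1, 26, 0)


def parse_version(text: str) -> tuple:
    """Parse a dotted version such as "1.26.0" into a 3-tuple of ints.

    Non-numeric suffixes (for example "1.26.0+local") are ignored, and
    missing components are padded with zeros so "1.26" equals (1, 26, 0).
    """
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def srelu_hadamard_supported(cudnn_fe_version: str) -> bool:
    """Return True when the reported version meets the minimum for the SReLU path."""
    return parse_version(cudnn_fe_version) >= MIN_CUDNN_FE_FOR_SRELU_HADAMARD
```

Comparing tuples avoids the string-comparison trap in which "1.9.0" sorts after "1.26.0".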

Confidence Score: 5/5

Safe to merge; the new SReLU hadamard path is gated behind both a cuDNN frontend version check and a runtime capability probe.

The production path change is minimal — a renamed method accessor, a new version gate, and a small SReLU branch — all mirroring the existing SwiGLU hadamard implementation. The test extension correctly guards the quantization_tols call and the new dedicated test exercises the exact added path.

No files require special attention.

Important Files Changed

Filename | Overview
transformer_engine/pytorch/ops/fused/grouped_mlp.py | Renames hadamard kernel accessor, adds SReLU-specific hadamard path with cuDNN FE >= 1.26.0 version gate; logic is sound and consistent with the existing SwiGLU path.
tests/pytorch/test_fusible_ops.py | Generalises test_grouped_mlp to support scaled_srelu activation, fixes tolerance selection via quantization_tols guard, and adds targeted test_grouped_mlp_nvfp4_rht_srelu.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuser_forward] --> B{use_nvfp4_rht_amax?}
    B -- No --> Z[grouped_gemm_activation_kernel]
    B -- Yes --> C{activation_supports_hadamard?}
    C -- No --> Z
    C -- Yes --> D[kernel_getter = grouped_gemm_act_hadamard_kernel]
    D --> E{kernel available?}
    E -- No --> Z
    E -- Yes --> F{activation_is_srelu?}
    F -- Yes --> G[act_func = srelu]
    F -- No --> H[act_func = _cudnn_act_func]
    G --> I[grouped_gemm_act_hadamard_kernel]
    H --> I
    I --> J[_group_quantize_with_amax_for_grouped_mlp]
    Z --> K[norm_const_tensor path]
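
The decision flow in the chart can be restated as a short Python sketch. It is a reading of the flowchart only, not the implementation in grouped_mlp.py: `select_fc1_kernel` and its parameters are hypothetical, and the assumption that an unavailable kernel is signalled by the getter returning `None` is mine.

```python
from typing import Callable, Optional, Tuple

# Path labels taken from the flowchart nodes.
FALLBACK_PATH = "grouped_gemm_activation_kernel"
HADAMARD_PATH = "grouped_gemm_act_hadamard_kernel"


def select_fc1_kernel(
    use_nvfp4_rht_amax: bool,
    activation_supports_hadamard: bool,
    hadamard_kernel_getter: Callable[[], Optional[object]],
    activation_is_srelu: bool,
    default_act_func: str,
) -> Tuple[str, Optional[str]]:
    """Return (path, act_func) following the flowchart's branches.

    Assumption: the getter returns None when the hadamard kernel
    is not available in the installed cuDNN frontend.
    """
    # Both the recipe and the activation must opt in to the hadamard path.
    if not use_nvfp4_rht_amax or not activation_supports_hadamard:
        return FALLBACK_PATH, None
    # Runtime capability probe: fall back if the kernel cannot be obtained.
    if hadamard_kernel_getter() is None:
        return FALLBACK_PATH, None
    # SReLU overrides the activation name; other activations keep their default.
    act_func = "srelu" if activation_is_srelu else default_act_func
    return HADAMARD_PATH, act_func
```

For example, `select_fc1_kernel(True, True, lambda: object(), True, "swiglu")` takes the hadamard path with `act_func` set to `"srelu"`, while any failed check falls back to the plain grouped GEMM activation kernel.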

Reviews (6): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread tests/pytorch/test_fusible_ops.py

@vthumbe1503 vthumbe1503 left a comment

Collaborator

Mostly LGTM, except for the cuDNN guard update, which I think is needed.

Comment thread tests/pytorch/test_fusible_ops.py Outdated
"""Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes."""
try:
    from cudnn import (
        grouped_gemm_glu_hadamard_wrapper_sm100,

Collaborator

Do we need a new cuDNN version to support SReLU in this kernel? If so, we should update it.

@vthumbe1503

Collaborator

/te-ci pytorch

Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
Comment thread tests/pytorch/test_fusible_ops.py Outdated
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Comment thread tests/pytorch/test_fusible_ops.py Outdated
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the te-nvfp4-srelu-rht-hadamard branch from a076f41 to 6ce5259 Compare June 24, 2026 03:14
@sraman-rgb

Contributor Author

/te-ci pytorch

Set default tolerance values for quantization checks.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503

Collaborator

/te-ci pytorch

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

2 participants