Skip to content

[TIRx] Bundle CUDA tile primitive and op dispatch updates#19896

Merged
spectrometerHBH merged 18 commits into
apache:mainfrom
spectrometerHBH:tirx-upstream-bundle
Jun 29, 2026
Merged

[TIRx] Bundle CUDA tile primitive and op dispatch updates#19896
spectrometerHBH merged 18 commits into
apache:mainfrom
spectrometerHBH:tirx-upstream-bundle

Conversation

@spectrometerHBH

@spectrometerHBH spectrometerHBH commented Jun 27, 2026

Copy link
Copy Markdown
Contributor

Summary

This bundles the 18 commits currently carried in spectrometerHBH/tvm on top of apache/tvm:main.

Major areas:

  • Extend CUDA TIRx tile primitives and op dispatch paths, including vector PTX ld/st, shared-memory copy paths, TMA/tcgen05 descriptor handling, dense FP8/TF32 gemm_async, and CUDA elementwise tile dispatch.
  • Add support utilities for benchmark timing, CUDA ptxas option plumbing, and TMA/TFLOAT32 descriptors.
  • Fix unsigned integer floormod/floordiv simplification rewrites without overflow and update the corresponding TIRx constant-folding tests.
  • Update TIRx dtype handling for upstream PrimType compatibility.
  • Add and update TIRx CUDA/operator tests for copy, elementwise, permute layout, and gemm_async behavior.

Validation

  • git diff --check apache/main..HEAD
  • python -m tirx_kernels.bench_suite --check-imports
  • python -m tirx_kernels.registry --cc 10 --strict
  • python -m pytest tests/python/tirx/ -n 16
    • 2033 passed, 39 skipped, 3 xpassed
  • python -m pytest tests/python/tirx-base/test_tir_imm_values.py -q
    • 44 passed, 6 warnings
  • pre-commit run --files tests/python/tirx-base/test_tir_imm_values.py
  • Focused TIRx regression tests after formatting:
    • test_cast_vec2_packed_dispatch
    • test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic
    • test_gemm_tcgen05_cta_group_1[task0]
  • Full bench_suite --impls all sweep: 256/256 workloads completed successfully.
  • Apache PR CI on 928a0605d0: all required GitHub Actions and Jenkins checks passed.

Drop benchmark agent skill and related AGENTS.md/.gitignore entries.
* feat(lower-tirx): add vector PTX ld/st and remove copy_xxb

Extend T.ptx.ld/st with dst/vec forms, retire copy_xxb intrinsics, and route
reg/gmem_smem tile copy dispatch through inline PTX ld/st with unit tests.

* test(lower-tirx): fix PTX ld/st tests to use behavior not source reads

Add GPU reg↔shared roundtrip checks in test_ptx_ld_st_ops, drop
test_reg_nested_copy_op source-path hacks, and assert PTX ld/st via compiled CUDA in test_reg.

* test(tirx): document cta_sync stand-in after copy_bytes removal

Namespace printer test still needs a T.cuda device_intrin after copy_xxb was deleted.

* test(tirx): drop stale copy_bytes comment in namespace test

* test(tirx): drop redundant warpgroup PTX ld/st compile grep test

test_reg_roundtrip already covers the same warpgroup reg<->shared kernel
with GPU execution and numerical checks.

* test(tirx): consolidate PTX ld/st width tests into test_ptx_ld_st_ops

Delete legacy test_cuda_copy.py and cover 128/64/32/16/8-byte shared
copies with one parametrized GPU test. Fix u8 ld return type to uint32
per PTX codegen constraints.

* test(tirx): dedupe PTX ld/st width kernels with case table

Drive shared-copy GPU tests from _SHARED_COPY_CASES and three dtype
templates instead of five copy-pasted num_bytes branches.

* test(tirx): use one shared-copy kernel with uint32 smem

Single _shared_scratch_copy_kernel entry: uint32 smem for 128/64/32b,
small 16b/8b branches only where PTX scratch/buffer types require it.

* test(tirx): single shared-copy kernel via closure dtype vars

TVMScript accepts smem_dtype/tmp_dtype from the builder closure; only
in-body branching on buffer declarations is unsupported.
#7)

Extend RewriteSimplifier with unsigned-only identities (x%x, x%1, x/x,
x/1) and floormod(x*c1, c2) -> 0 when c1 is divisible by c2. Signed rules
that assume no wraparound are intentionally not applied to unsigned dtypes.
Add references/TIRX_BENCH_IMPLS support, in-bench rounds aggregation, and store
CUDA event and proton timings in microseconds. Point AGENTS.md at tirx-kernels
bench_suite and simplify tir-test import setup.
Add an optional tma_dtype="tf32" to copy_async(tma) so an fp32 gmem buffer is
loaded via a CU_TENSOR_MAP_DATA_TYPE_TFLOAT32 descriptor. The TMA hardware then
round-to-nearest truncates fp32 -> tf32 ON LOAD, matching a tf32 MMA's operand
precision (and a torch allow_tf32 / DeepGEMM reference). Loading as FLOAT32 and
letting the tf32 MMA RZ-truncate the top 19 bits at read diverges by up to one
tf32 ULP (~5e-4 on a GEMM output, _calc_diff ~1.5e-7 > a 1e-8 gate).

- runtime.cuTensorMapEncodeTiled (cuda_device_api.cc): optional trailing int
  force_cu_dtype (>= 0 overrides the dtype-derived CUtensorMapDataType; 11 ==
  TFLOAT32). Backward-compatible -- older callers omit the arg.
- tma.py: _TMA_DTYPE_TO_CU maps "tf32"/"tfloat32" -> 11; reads
  op_call.config["tma_dtype"], validates a float32 buffer, folds it into the
  tensormap cache key, and appends force_cu_dtype to the encode call ONLY when
  set so the default path stays byte-identical (existing codegen tests unchanged).

tir-test (tests/python/tirx): 2040 passed, 0 failed.
)

Factor ptxas flags into _ptxas_option_flags() and forward
TVM_CUDA_PTXAS_REG_LEVEL / TVM_CUDA_PTXAS_EXTRA_OPTS for both nvcc
(comma-joined) and nvrtc (one --ptxas-options entry per token).
Extend tcgen05 dense MMA dispatch with semantic A/B dtypes (fp8, tf32 via
is_AB_tf32) and add end-to-end gemm_async tests for fp8 and tf32 TMA loads.
* fix(infra): reg-copy R2S/S2R for tcgen05 split-laneid atom layouts

Preserve r_perm through align_layouts_raw so split-laneid register
layouts pair correctly with swizzled SMEM. Add layout and compile tests.

* refactor(infra): shorten align_layouts_raw docstring

* test(infra): document tcgen05 D epilogue reg-copy regression case

Name the production copy (shapes, layouts, slice), assert layout pairing
3 vs 1 groups, and check generated outer loop f<16 with st.shared.v2.u32.

* test(infra): add GPU roundtrip for tcgen05 D epilogue reg copy

Exercise R→S deposit and S→R reload on Layout-F reg vs 128B swizzled SMEM;
assert bit-exact recovery of the host-filled logical tile.

* refactor(op-dispatch): remove slice@llvm layout workarounds

FA4's global pre-canon in TileLayoutNode::Slice makes llvm-only slicing
unnecessary; slice and canonicalize under sctx.target in reg, ldstmatrix,
and elementwise dispatch.
Wire maximum into BINARY_OPS with scalar FMNMX lowering (no f32x2 vec).
Add GPU reg roundtrip test for Tx.maximum on local operands.
* fix(lower-tirx): slice elementwise reg layouts under sctx.target (B00011)

Run layout slice and canonicalize under sctx.target in
_align_layouts_no_post_canon and _check_layout_operands_agree so split-laneid
tcgen05 atom layouts do not hit conflicting thread scopes during dispatch.

* test(op-dispatch): add B00011 tcgen05 cast regression tests

Cover warpgroup cast on split-laneid tcgen05 atoms and drop redundant
hasattr guards before TileLayout.canonicalize().

* test(op-dispatch): add GPU regression for tcgen05 warpgroup cast

Roundtrip scatter/cast/gather on the split-laneid .16x256b atom layout
mirrors the tf32 A-cast path and checks bf16→fp32 values on device.

* test(op-dispatch): drop redundant B00011 predicate unit test

Keep compile and GPU roundtrip regressions for tcgen05 warpgroup cast only.

* refactor(op-dispatch): drop unrelated elementwise changes from B00011 PR

Revert noop _common.py slice split and unrelated hasattr cleanup; keep only
_check_layout_operands_agree slice+canon under sctx.target.

* refactor(op-dispatch): drop hasattr guard before layout canonicalize

Operands are always TileLayouts in the reg elementwise predicate path.
…matrix (#18)

Step 2 canonicalized the register/shared layout *before* slicing, to fuse a
split-laneid tcgen05 atom (separate laneid + wid_in_wg axes) into a single
tid_in_wg axis on the full layout — otherwise slicing left an ill-formed
sub-layout that GetScope rejected, silently dropping ldmatrix for a scalar
reg path.

`TileLayout.Slice` now globally pre-canonicalizes internally before grouping,
so the Python pre-slice canonicalize is redundant: `slice().canonicalize()`
produces the same layout. Drop it.

Add a compile-only regression test: a warpgroup `.16x256b` atom loaded from a
128B-swizzled SMEM tile must emit `ldmatrix.x4` (the tf32-prenorm cast-warp A
load).
…r cp (#19)

The smem->tmem cp dispatch encoded a fresh SMEM matrix descriptor per buffer
(cache key included hash(s_buf)) and set LDO=16 (a placeholder cp ignores for
data). Instead encode ONE descriptor template at SMEM base 0 — so the cache
key drops the per-buffer hash and identical (ldo, sdo, swizzle) templates are
shared — and patch its 14-bit address field per cp via
`cvta(addr) >> 4 & 0x3FFF` (`_desc_set_addr`, mirroring the hand-rolled
`replace_smem_desc_addr`). LDO is set to 0 since cp ignores it for data and a
non-zero LDO only bloats the address-patch codegen.

Add a compile-only regression test: a 4-tile copy emits one
`encode_matrix_descriptor` reused across four `tcgen05.cp.32x128b.warpx4`
issues, each with the per-cp address-field patch.
* feat(op-dispatch): support uint32 shape extents in TMA copy

A deepgemm TMA source buffer wants a uint32 runtime shape (e.g. shape_m)
with no int32 cast. Two grouping proofs blocked it:

- gmem: `_canonicalize_gmem` fused contiguous dims into one prod-extent
  iter (n*64), so regrouping by the buffer shape needed `(n*64) % n == 0`
  — unprovable for unsigned under wraparound. Drop the canonicalize: a
  plain gmem buffer already has one iter per dim, so grouping only needs
  the overflow-free `dim % dim` proof. The grouped result is identical for
  every signed case (group re-splits to the buffer dims), so this is a
  pure proof simplification (full golden suite unchanged) that also
  unlocks uint32 shapes.
- smem: an unsigned slice base leaks its dtype into the copy extent;
  `_regroup_smem_by_extgt1_shape` now views unsigned extents as signed for
  the structural proof (value-preserving; emitted base stays unsigned).

* refactor(op-dispatch): canonicalize TMA gmem groups after grouping

* fix(layout): fold unsigned floormod constants
#20)

A shared->shared warp_xor_swizzle permute indexed both operands through
``buf[...]``, so the swizzled layout lowered to a per-element IMAD flatten on
the hot SF-transpose path. For a 4/8-byte dtype with both operands in shared
memory, compute one base ptr via ``ptr_to(stride_offset)`` and add a
compile-time ``off * dtype_bytes`` per register slot, then issue
``T.ptx.ld/st(..., space="shared")`` directly.

The permute only shuffles bits, so move them through an unsigned container of
the matching width: ``ld.bN`` rejects a float return dtype, so this also lets
a float32/float64 shared tile use the direct path (the predicate already
admitted 4/8-byte floats — previously they hit a codegen error).

Add a compile-only test asserting the direct ld.shared/st.shared path fires
for both uint32 (SF case) and float32.
…mm_async (#15)

Rework the tcgen05 gemm_async SMEM matrix-descriptor handling:

- Build the descriptor per MMA from the buffer base address, selected by a
  new ``smem_desc`` config:
    * ``hoist`` (default): allocate + encode one warp-uniform descriptor per
      operand and add the per-MMA 16B offset (``smem_desc_add_16B_offset``).
    * ``recompute``: build the full descriptor inline per MMA (``_uniform_desc``)
      with no allocated/encoded cell — one fewer live register on the fa4 hot
      path at the cost of a few ALU ops.
  The descriptor base is always the buffer origin (stage 0); the per-MMA
  operand offset is applied on top, so both modes are address-correct.
- ``_atom_off``: a leading atom dim of extent 1 contributes no LBO/SBO, so its
  (meaningless) stride no longer leaks into the descriptor offset.
- Inline the A/B operands and SF addresses into the MMA call via ``T.meta_var``
  to avoid per-iteration LMEM temporaries, and fold ``needs_sf_id`` /
  descriptor-rotation handling so the block-scaled and dense paths share one
  ``main_impl``.

Add a compile-only test asserting the hoist vs recompute descriptor
fingerprints (and that both emit the MMA).

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces the legacy cuda_copy_bytes intrinsics with generic PTX ld and st operations across various copy and register dispatch paths. It introduces several new PTX load/store intrinsics (including relaxed, release, and MMIO variants), adds support for unsigned FloorDiv and FloorMod simplifications, and updates the benchmarking infrastructure to support microsecond precision and independent rounds. Additionally, it refactors tcgen05 MMA descriptor handling and adds an optional TMA descriptor dtype override. The reviewer suggested extending the shared-to-shared copy optimization in warp_xor_swizzle.py to support 16-bit types, as ld.shared.b16 is a valid PTX instruction.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.


# Shared 32/64b: base ptr + stride offset avoids buf[] flatten IMAD path.
direct = (
dtype_bytes in (4, 8)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a great optimization to use direct ld.shared/st.shared for shared memory copies. It seems this could be extended to support 16-bit types as well, since ld.shared.b16 is a valid PTX instruction.

Suggested change
dtype_bytes in (4, 8)
dtype_bytes in (2, 4, 8)

@spectrometerHBH spectrometerHBH marked this pull request as ready for review June 28, 2026 02:08
@spectrometerHBH spectrometerHBH merged commit 4224d51 into apache:main Jun 29, 2026
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants