[AMD] Fix shared layout order for batch dimension in pipeline passes (#4796)

The batch dimension must be the slowest-varying one; other orderings are not
supported by the MFMA/WMMA/MMA pipeline.
binarman authored Oct 3, 2024
1 parent 5f77e8c commit 33c0c1c
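
The rule this commit enforces can be illustrated with a small standalone sketch (the helper below and its name are hypothetical, not code from this commit): Triton layout `order` vectors list dimensions fastest-varying first, so moving the batch dimension (dim 0) to the last slot makes it the slowest-varying one.

#include <cstdio>
#include <vector>

// Hypothetical standalone version of the reordering both pipeline passes now
// apply: keep the relative order of the non-batch dims, then append dim 0.
static std::vector<unsigned>
moveBatchDimToSlowest(const std::vector<unsigned> &order) {
  if (order.size() != 3)
    return order; // only the 3-D (batched) case is rewritten
  std::vector<unsigned> result;
  for (unsigned dim : order)
    if (dim != 0)
      result.push_back(dim);
  result.push_back(0); // last entry of `order` == slowest-varying dimension
  return result;
}

int main() {
  // [2, 0, 1] (batch dim in the middle) becomes [2, 1, 0] (batch dim slowest).
  for (unsigned dim : moveBatchDimToSlowest({2, 0, 1}))
    std::printf("%u ", dim);
  std::printf("\n");
  return 0;
}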
Showing 3 changed files with 68 additions and 4 deletions.
35 changes: 35 additions & 0 deletions test/TritonGPU/loop-pipeline-hip.mlir
@@ -198,3 +198,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 } // end module
+
+// -----
+
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0]
+// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1]
+
+// CHECK-LABEL: tt.func public @slowest_dim_is_batch
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr<f32>, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr<f32>, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked>
+    %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2>
+    %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %33:3 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>) : i32 {
+      %39 = tt.load %arg9 : tensor<1x512x!tt.ptr<f32>, #blocked2>
+      %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+      %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
+      %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
+      %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
+      %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked>
+      %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<1x512xi32, #blocked2>
+      %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>, tensor<64x8x32xi32, #blocked1>
+      scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>
+    }
+    tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
21 changes: 18 additions & 3 deletions third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -403,9 +403,24 @@ void LoopPipeliner::createBufferTypes() {
   // unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
   //                         ? 32 / dotOpEnc.getMMAv2kWidth()
   //                         : ty.getElementType().getIntOrFloatBitWidth();
-  auto sharedEnc = ttg::SharedEncodingAttr::get(
-      ty.getContext(), dotOpEnc, ty.getShape(),
-      ttg::getOrder(ty.getEncoding()), CTALayout, eType);
+  auto srcOrder = ttg::getOrder(ty.getEncoding());
+  SmallVector<unsigned> sharedOrder;
+  int rank = srcOrder.size();
+  // TODO rework this when shared -> dotOp conversions support arbitrary
+  // shared memory ordering
+  if (rank == 3) {
+    // Move the batch dimension (dim #0) to be the last so that it will be the
+    // slowest varying dimension.
+    for (unsigned i = 0; i < rank; ++i)
+      if (srcOrder[i] != 0)
+        sharedOrder.emplace_back(srcOrder[i]);
+    sharedOrder.emplace_back(0);
+  } else {
+    sharedOrder = srcOrder;
+  }
+  auto sharedEnc =
+      ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
+                                   sharedOrder, CTALayout, eType);
   loadsBufferType[loadOp] = triton::MemDescType::get(
       bufferShape, eType, sharedEnc,
       triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()),
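
To see why the last entry of `order` is the slowest-varying dimension, here is a sketch (assumed semantics: `order[0]` is the contiguous, fastest-varying dimension, as in Triton's blocked layouts) that derives strides from a shape and an order. For the test's 64x8x32 operand with the corrected order [2, 1, 0], the batch dimension gets the largest stride:

#include <cstdio>
#include <vector>

// Derive per-dimension strides from an `order` vector, fastest dim first.
static std::vector<unsigned>
stridesFromOrder(const std::vector<unsigned> &shape,
                 const std::vector<unsigned> &order) {
  std::vector<unsigned> strides(shape.size());
  unsigned stride = 1;
  for (unsigned dim : order) { // walk from fastest to slowest dimension
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}

int main() {
  // Shape 64x8x32, order [2, 1, 0]: dim 0 (batch) ends up with stride 256,
  // i.e. it is the slowest-varying dimension, as the dot pipeline requires.
  auto strides = stridesFromOrder({64, 8, 32}, {2, 1, 0});
  std::printf("strides = [%u, %u, %u]\n", strides[0], strides[1], strides[2]);
  return 0;
}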
16 changes: 15 additions & 1 deletion third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
@@ -207,8 +207,22 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
       auto order = ttg::getOrder(srcTy.getEncoding());
       unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+      SmallVector<unsigned> sharedOrder;
+      int rank = order.size();
+      // TODO rework this when shared -> dotOp conversions support arbitrary
+      // shared memory ordering
+      if (rank == 3) {
+        // Move the batch dimension (dim #0) to be the last so that it will be
+        // the slowest varying dimension.
+        for (unsigned i = 0; i < rank; ++i)
+          if (order[i] != 0)
+            sharedOrder.emplace_back(order[i]);
+        sharedOrder.emplace_back(0);
+      } else {
+        sharedOrder = order;
+      }
       tempAttr = ttg::SharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+          val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout,
           bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
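
Both passes now apply the identical reordering, so a single property check covers them. This sketch reuses the hypothetical helper from the top of the page (not the commit's code) and ties the rule back to the test's CHECK lines:

#include <cassert>
#include <vector>

// Same hypothetical reordering as in the sketch near the top of the page.
static std::vector<unsigned>
moveBatchDimToSlowest(std::vector<unsigned> order) {
  if (order.size() != 3)
    return order;
  std::vector<unsigned> result;
  for (unsigned dim : order)
    if (dim != 0)
      result.push_back(dim);
  result.push_back(0);
  return result;
}

int main() {
  using V = std::vector<unsigned>;
  // The test's #blocked1/#blocked5 order [2, 0, 1] must become [2, 1, 0].
  assert(moveBatchDimToSlowest({2, 0, 1}) == (V{2, 1, 0}));
  // An order that is already batch-slowest stays unchanged.
  assert(moveBatchDimToSlowest({2, 1, 0}) == (V{2, 1, 0}));
  // Non-3-D orders pass through untouched.
  assert(moveBatchDimToSlowest({1, 0}) == (V{1, 0}));
  return 0;
}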
