[BACKEND] Linear Layout with stmatrix part 2: support stmatrix for `local_alloc` ops (#4763)

This PR enables the use of `stmatrix` for `local_alloc` ops through
linear layout and removes the legacy code from the `TargetInfo` class.
Jokeren authored Oct 1, 2024
1 parent 80a5cfb commit 49266aa
Showing 10 changed files with 287 additions and 215 deletions.
9 changes: 0 additions & 9 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -57,15 +57,6 @@ class TargetInfoBase {
unsigned numLaneToReduce,
unsigned interleave) const = 0;

// TODO (Keren): Remove this function once layout conversion using stmatrix is
// handled by Linear Layout.
virtual bool processReplicaUsingStMatrix(
RewriterBase &rewriter, Location loc, Value smemBase,
SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
int swizzleByteWidth = 0) const = 0;

virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
// Emits LLVM code with |rewriter| to print a message following the given
// format from the device. |formatStrStart| is the pointer to the start of
128 changes: 125 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// row0 reg[0-1] reg[4-5]
// row8 reg[2-3] reg[6-7]
//
// When `swizzleByteSize` is non-zero, the layout is constructed
// differently due to leading dimension offset and swizzling.
// There are two key concepts to understand:
//
// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
// into chunks, where each chunk's size is determined by `swizzleByteSize`.
// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
// rows to optimize memory access.
//
// - Concept 1: Chunks
//
// In the swizzled layout, the leading dimension is strided by
// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk
// spans a certain number of columns.
//
// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16
// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
// elements * 2 bytes per element = 32 bytes per row).
//
// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
// calculated as:
//
//   Number of tiles per chunk = swizzleByteSize / (bytes per row)
//                             = 128 bytes / 32 bytes
//                             = 4 tiles
//
// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
// (since each tile is 16 columns):
//
// col0-15 col16-31 col32-47 col48-63
// row0-15 tile0 tile1 tile2 tile3
//
// For a tensor of size 128x128 elements (#rows x #columns), with each element
// being 16 bits, the tensor can be divided into multiple chunks both
// horizontally and vertically. Chunks are stored in memory in "column-major"
// order, meaning chunk1's address immediately follows chunk0's.
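As a quick sanity check, the chunk arithmetic above can be sketched in a few lines (a standalone illustration; the constants and helper name are not part of the Triton code):

```python
# Chunk arithmetic for stmatrix.x4 tiles (16x16 elements, 16-bit elements),
# as described above.
ELEM_BYTES = 2    # 16-bit elements
TILE_COLS = 16    # columns per stmatrix.x4 tile

def tiles_per_chunk(swizzle_byte_size):
    # Each tile row occupies TILE_COLS * ELEM_BYTES = 32 bytes, so a chunk
    # of `swizzle_byte_size` bytes per row holds this many tiles side by side.
    bytes_per_tile_row = TILE_COLS * ELEM_BYTES
    return swizzle_byte_size // bytes_per_tile_row

print(tiles_per_chunk(128))              # 4 tiles per chunk
print(tiles_per_chunk(128) * TILE_COLS)  # 64 columns per chunk
```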
//
// Assuming we have 8 warps, we assign each warp a slice of 16 rows (the
// height of one tile) and 128 columns (the width of two chunks), so each warp
// handles one horizontal slice of the tensor.
//
// The overall layout can be visualized as:
//
// |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
// columns 0-63 columns 64-127
// warp0 | rows 0-15 chunk0 chunk8
// warp1 | rows 16-31 chunk1 chunk9
// warp2 | rows 32-47 chunk2 chunk10
// warp3 | rows 48-63 chunk3 chunk11
// warp4 | rows 64-79 chunk4 chunk12
// warp5 | rows 80-95 chunk5 chunk13
// warp6 | rows 96-111 chunk6 chunk14
// warp7 | rows 112-127 chunk7 chunk15
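The chunk numbering in the diagram above follows directly from the column-major chunk ordering; the helper below is a hypothetical sketch for illustration only:

```python
# Chunk index for a 128x128 tensor of 16-bit elements with
# swizzleByteSize = 128. Chunks are ordered "column-major": all 8 chunks
# covering columns 0-63 come before the 8 chunks covering columns 64-127.
CHUNK_ROWS = 16                        # one tile high
CHUNK_COLS = 64                        # four 16-column tiles wide
CHUNKS_PER_COLUMN = 128 // CHUNK_ROWS  # 8 chunks stacked vertically

def chunk_index(row, col):
    return (col // CHUNK_COLS) * CHUNKS_PER_COLUMN + row // CHUNK_ROWS

print(chunk_index(0, 0))     # 0  -> chunk0  (warp0, columns 0-63)
print(chunk_index(16, 0))    # 1  -> chunk1  (warp1, columns 0-63)
print(chunk_index(0, 64))    # 8  -> chunk8  (warp0, columns 64-127)
print(chunk_index(112, 64))  # 15 -> chunk15 (warp7, columns 64-127)
```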
//
// - Concept 2: Swizzling within tiles
//
// Within each 16x16 tile, rows are swizzled to optimize memory access patterns.
// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`, but
// it is applied at the level of each 16x16 tile rather than the entire tensor.
//
// Key parameters for swizzling:
//
// - `perPhase`: The number of consecutive rows that share the same swizzling
//   phase (i.e., the same XOR pattern).
// - `maxPhase`: The total number of phases.
// - `vectorWidth`: The number of elements per vector, which is 8 in this case
// because `stmatrix` stores 8 contiguous elements per thread.
//
// The offset of each element within a tile is calculated using the formula:
//
//   offset = row * swizzleByteSize
//            + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
//
// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
// elements).
//
// For example, consider the element at index `(row=1, col=0)` in chunk0:
//
// Without swizzling:
//
// offset = row * swizzleByteSize + col * elementSize
// = 1 * 128 bytes + 0 * 2 bytes
// = 128 bytes
//
// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
//
//   offset = row * swizzleByteSize
//            + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
//          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
//          = 128 bytes + (8 * (1 % 8)) * 2 bytes
//          = 128 bytes + 8 * 2 bytes
//          = 128 bytes + 16 bytes
//          = 144 bytes
//
// This swizzling ensures that elements are stored in a way that optimizes for
// memory bandwidth and reduces bank conflicts.
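The byte-offset calculation above is easy to check numerically. The sketch below reproduces it under the same assumptions stated in the comment (`swizzleByteSize=128`, `perPhase=1`, `maxPhase=8`, `vectorWidth=8`, 2-byte elements) for elements at column 0; the function name is illustrative, not part of the code:

```python
def swizzled_row_offset(row, swizzle_byte_size=128, per_phase=1,
                        max_phase=8, vector_width=8, elem_size=2):
    # offset = row * swizzleByteSize
    #          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
    return (row * swizzle_byte_size
            + vector_width * ((row // per_phase) % max_phase) * elem_size)

print(swizzled_row_offset(0))  # 0 bytes
print(swizzled_row_offset(1))  # 144 bytes, matching the worked example
print(swizzled_row_offset(8))  # 1024 bytes (phase wraps: 8 % 8 == 0)
```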
//
// - Verification through Linear Layout
//
// We can verify the offsets with the following outputs of the corresponding
// linear layout, where each element is 16 bits (2 bytes):
//
// - register=1 -> offset=1
// register=2 -> offset=2
// register=4 -> offset=4
// register=8 -> offset=16
// register=16 -> offset=32
// register=32 -> offset=8192
// - lane=1 -> offset=72
// lane=2 -> offset=144
// lane=4 -> offset=288
// lane=8 -> offset=512
// lane=16 -> offset=8
// - warp=1 -> offset=1024
// warp=2 -> offset=2048
// warp=4 -> offset=4096
//
// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
// matches our earlier calculation.
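The lane bases listed above can be rederived from the same swizzle formula, working in units of 16-bit elements (stride = swizzleByteSize / elementSize = 128 / 2 = 64 elements per row). The helper below is an illustration, not Triton code:

```python
# Element-indexed offsets for the lane bases of the linear layout above.
STRIDE = 64     # elements per row: swizzleByteSize (128 B) / 2 B per element
PER_PHASE, MAX_PHASE, VEC = 1, 8, 8

def lane_base(lane):
    if lane == 16:
        # The fifth lane bit moves within a row, to column 8 (one stmatrix
        # vector of 8 elements over).
        return 8
    row = lane  # lane bits 0-3 select rows 1, 2, 4, 8
    return row * STRIDE + VEC * ((row // PER_PHASE) % MAX_PHASE)

for lane in (1, 2, 4, 8, 16):
    print(lane, lane_base(lane))  # -> 72, 144, 288, 512, 8
```

For `lane=1` this gives 72 elements, i.e., 72 * 2 bytes = 144 bytes, matching the earlier calculation.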
//
// TODO(Keren): In the future, we should replace tensorTy with a LinearLayout
// and the element bit width of the tensor to support more flexible tensor
// encodings.
std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
18 changes: 6 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
if (repId != 0) {
barrier();
}
auto successful = targetInfo.processReplicaUsingStMatrix(
rewriter, loc, smemBase, vals, srcTy,
getTypeConverter()->convertType(srcTy.getElementType()),
paddedRepShape, origRepShape, outOrd, accumNumReplicates);
if (!successful) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, origRepShape,
outOrd, vals, smemBase);
}
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
vals, smemBase);
barrier();
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// Input dims: [reg, lane, warp]
// Output dims: [offset, iteration]
std::optional<LinearLayout> shmemStoreLayout =
chooseStMatrixLayoutForRegToRegConversion(
ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order);
chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
bool isStMatrix = shmemStoreLayout.has_value();
if (!isStMatrix) {
shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
// TODO: do we need to check if src is shared ?
if (isa<SharedEncodingAttr>(srcLayout) &&
isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout)) {
103 changes: 96 additions & 7 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -820,8 +820,8 @@ namespace {
// stmatrix. These restrictions are retained from legacy code, and we could
// relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order) {
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
auto mmaLayout =
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
@@ -840,17 +840,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
return false;
if (paddedRepShape[1] % 8 != 0)
return false;
if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
swizzleByteSize != 128)
return false;
return true;
}

} // anonymous namespace
std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
int swizzleByteSize) {
StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
StringAttr kCol = S("dim1");
StringAttr kRow = S("dim0");
StringAttr kOffset = S("offset");

int perPhase;
int maxPhase;
if (swizzleByteSize == 32) {
perPhase = 4;
maxPhase = 2;
} else if (swizzleByteSize == 64) {
perPhase = 2;
maxPhase = 4;
} else if (swizzleByteSize == 128) {
perPhase = 1;
maxPhase = 8;
} else {
llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n";
llvm::report_fatal_error("Illegal swizzleByteSize");
}

// stmatrix only supports 16-bit elements, and each vector has 8 elements
int elemBitWidth = 16;
int vecSize = 8;
int numRows = 16;
int numCols = 8 * swizzleByteSize / elemBitWidth;

// Construct a single stmatrix.x4 (16x16) tile
std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
std::vector<std::vector<int>> basesLane;
for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
int row = 1 << logRow;
basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
}
basesLane.push_back({8, 0});

// Expand the tile's register dimension so that a row spans a full chunk
// (i.e., `swizzleByteSize` bytes)
for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
int chunk = 1 << logChunk;
basesReg.push_back({16 * chunk, 0});
}

// Construct the layout for a single chunk
LinearLayout layout =
LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});

std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
// Expand the `warp` dimension according to warpsPerCTA.
auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));

// Expand the `register` dimension so that the size of the column dimension
// matches `n`.
int n = mma.getInstrShape()[1];
int numWarpRows = layout.getOutDimSize(kRow);
layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
LinearLayout::identity1D(n / numCols, kReg, kOffset))
.reshapeOuts({{kCol, n}, {kRow, numWarpRows}});

auto ret =
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
.reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order))
return std::nullopt;

StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
@@ -880,4 +950,23 @@ std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

} // anonymous namespace

std::optional<LinearLayout>
chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize) {
if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
swizzleByteSize))
return std::nullopt;

if (swizzleByteSize == 0)
return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
paddedRepShape, order);
else
return chooseStMatrixLayoutLeadingOffset(
ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
}

} // namespace mlir::triton::gpu
9 changes: 0 additions & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
@@ -136,15 +136,6 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
return false;
}

bool TargetInfo::processReplicaUsingStMatrix(
RewriterBase &rewriter, Location loc, Value smemBase,
SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
int swizzleByteWidth) const {
return false;
}

void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount,
ValueRange args, RewriterBase &rewriter,
bool useStdErr) const {
9 changes: 0 additions & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -46,15 +46,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
triton::ReduceOp op, unsigned numLaneToReduce,
unsigned interleave) const override;

bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc,
Value smemBase, SmallVector<Value> &vals,
RankedTensorType srcTy, Type elemTy,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd,
unsigned accumNumReplicates,
int swizzleByteWidth) const override;

std::string getMulhiFuncName(Type resultElementTy) const override;

void printf(RewriterBase &rewriter, Value formatStrStart,
