Skip to content

Commit

Permalink
[Backend] Use symbol table to lookup smem base (#4853)
Browse files Browse the repository at this point in the history
Unlikely to make a huge perf difference, but doing a symbol table lookup
here 
is a bit cleaner.
  • Loading branch information
peterbell10 authored Oct 4, 2024
1 parent 41006e9 commit 3a9ddea
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
Expand Down Expand Up @@ -364,17 +365,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {

inline Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
if (!isKernel(funcOp)) {
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

auto mod = funcOp->getParentOfType<ModuleOp>();
LLVM::GlobalOp globalBase = nullptr;
mod.walk([&](LLVM::GlobalOp op) {
if (op.getSymName() == "global_smem")
globalBase = op;
});
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
assert(globalBase);
if (isKernel(funcOp))
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
else
return funcOp.getArgument(funcOp.getNumArguments() - 1);
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
}

inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Expand Down

0 comments on commit 3a9ddea

Please sign in to comment.