From 3a9ddeaf075766696fe4910898877b55a001d98a Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Fri, 4 Oct 2024 19:53:54 +0100 Subject: [PATCH] [Backend] Use symbol table to lookup smem base (#4853) Unlikely to make a huge perf difference, but doing a symbol table lookup here is a bit cleaner. --- .../triton/Conversion/TritonGPUToLLVM/Utility.h | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index d9ebe7ccc1e8..e47a025fbbf1 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -6,6 +6,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" @@ -364,17 +365,14 @@ inline bool isKernel(FunctionOpInterface funcOp) { inline Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { + if (!isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 1); + } + auto mod = funcOp->getParentOfType(); - LLVM::GlobalOp globalBase = nullptr; - mod.walk([&](LLVM::GlobalOp op) { - if (op.getSymName() == "global_smem") - globalBase = op; - }); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); assert(globalBase); - if (isKernel(funcOp)) - return rewriter.create(funcOp.getLoc(), globalBase); - else - return funcOp.getArgument(funcOp.getNumArguments() - 1); + return rewriter.create(funcOp.getLoc(), globalBase); } inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,