diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 91286ebcea33aa14db71f69580f59ad276d1ff75..b4293f6c977cd7abbebed76a9185c9c712c801dd 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -15,6 +15,7 @@ #include "llvm/Transforms/Scalar/LoopInterchange.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/DependenceAnalysis.h" @@ -112,8 +113,9 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, ValueVector::iterator I, IE, J, JE; - for (I = MemInstr.begin(), IE = MemInstr.end(); I != IE; ++I) { - for (J = I, JE = MemInstr.end(); J != JE; ++J) { + for (I = MemInstr.begin(), IE = MemInstr.end(); I < IE; ++I) { + for (J = I + 1, JE = MemInstr.end(); J < JE; ++J) { + std::vector Dep; Instruction *Src = cast(*I); Instruction *Dst = cast(*J); @@ -310,6 +312,8 @@ public: std::unique_ptr &CC); private: + int analyzeGEPGroupPattern(ArrayRef GEPGroup); + int analyzeSingleGEP(const GetElementPtrInst *GEP); int getInstrOrderCost(); std::optional isProfitablePerLoopCacheAnalysis( const DenseMap &CostMap, @@ -540,6 +544,15 @@ struct LoopInterchange { bool LoopInterchangeLegality::containsUnsafeInstructions(BasicBlock *BB) { return any_of(*BB, [](const Instruction &I) { + if (auto *LI = dyn_cast(&I)) { + if (LI->isSimple() && !LI->isVolatile()) + return false; + } + if (auto *SI = dyn_cast(&I)) { + if (SI->isSimple() && !SI->isVolatile()) + return false; + } + return I.mayHaveSideEffects() || I.mayReadFromMemory(); }); } @@ -561,15 +574,17 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { for (BasicBlock *Succ : successors(OuterLoopHeaderBI)) if (Succ != InnerLoopPreHeader && Succ != InnerLoop->getHeader() && - Succ != OuterLoopLatch) - return false; + Succ != OuterLoopLatch){ + return false; + } LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); // We do not have any basic block in between now make sure the outer header // and outer loop latch doesn't contain any unsafe instructions. if (containsUnsafeInstructions(OuterLoopHeader) || - containsUnsafeInstructions(OuterLoopLatch)) + containsUnsafeInstructions(OuterLoopLatch)){ return false; + } // Also make sure the inner loop preheader does not contain any unsafe // instructions. Note that all instructions in the preheader will be moved to @@ -1041,55 +1056,120 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, int LoopInterchangeProfitability::getInstrOrderCost() { unsigned GoodOrder, BadOrder; BadOrder = GoodOrder = 0; + + SmallPtrSet AnalyzedGEPs; + for (BasicBlock *BB : InnerLoop->blocks()) { for (Instruction &Ins : *BB) { if (const GetElementPtrInst *GEP = dyn_cast(&Ins)) { - unsigned NumOp = GEP->getNumOperands(); - bool FoundInnerInduction = false; - bool FoundOuterInduction = false; - for (unsigned i = 0; i < NumOp; ++i) { - // Skip operands that are not SCEV-able. - if (!SE->isSCEVable(GEP->getOperand(i)->getType())) - continue; - - const SCEV *OperandVal = SE->getSCEV(GEP->getOperand(i)); - const SCEVAddRecExpr *AR = dyn_cast(OperandVal); - if (!AR) - continue; - - // If we find the inner induction after an outer induction e.g. - // for(int i=0;igetLoop() == InnerLoop) { - // We found an InnerLoop induction after OuterLoop induction. It is - // a good order. - FoundInnerInduction = true; - if (FoundOuterInduction) { - GoodOrder++; - break; - } - } - // If we find the outer induction after an inner induction e.g. - // for(int i=0;igetLoop() == OuterLoop) { - // We found an OuterLoop induction after InnerLoop induction. It is - // a bad order. - FoundOuterInduction = true; - if (FoundInnerInduction) { - BadOrder++; - break; + if (AnalyzedGEPs.count(GEP)) + continue; + + bool isPtrPtrFirstLevel = false; + if (GEP->getNumOperands() < 3) { + isPtrPtrFirstLevel = true; + } + + if (isPtrPtrFirstLevel) { + SmallVector RelatedGEPs; + RelatedGEPs.push_back(GEP); + + for (const User *U : GEP->users()) { + if (const LoadInst *Load = dyn_cast(U)) { + for (const User *LoadUser : Load->users()) { + if (const GetElementPtrInst *SecondGEP = dyn_cast(LoadUser)) { + RelatedGEPs.push_back(SecondGEP); + AnalyzedGEPs.insert(SecondGEP); + } + } } } + + int patternCost = analyzeGEPGroupPattern(RelatedGEPs); + if (patternCost > 0) GoodOrder += 1; + else if (patternCost < 0) BadOrder += 1; + + AnalyzedGEPs.insert(GEP); + } else { + int gepCost = analyzeSingleGEP(GEP); + if (gepCost > 0) GoodOrder += 1; + else if (gepCost < 0) BadOrder += 1; + + AnalyzedGEPs.insert(GEP); } } } } - return GoodOrder - BadOrder; + + return GoodOrder - BadOrder; +} + +int LoopInterchangeProfitability::analyzeGEPGroupPattern(ArrayRef GEPGroup) { + SmallVector AllLoopIndices; + + for (const GetElementPtrInst *GEP : GEPGroup) { + for (unsigned i = 0; i < GEP->getNumOperands(); ++i) { + Value *Index = GEP->getOperand(i); + if (!SE->isSCEVable(Index->getType())) + continue; + + const SCEV *OperandVal = SE->getSCEV(Index); + if (const SCEVAddRecExpr *AR = dyn_cast(OperandVal)) { + AllLoopIndices.push_back(AR->getLoop()); + } + } + } + + bool FoundInnerInduction = false; + bool FoundOuterInduction = false; + + for (const Loop *L : AllLoopIndices) { + if (L == InnerLoop) { + FoundInnerInduction = true; + if (FoundOuterInduction) { + return 1; // Good pattern + } + } + + if (L == OuterLoop) { + FoundOuterInduction = true; + if (FoundInnerInduction) { + return -1; // Bad pattern + } + } + } + + return 0; // Unknown pattern +} + +int LoopInterchangeProfitability::analyzeSingleGEP(const GetElementPtrInst *GEP) { + bool FoundInnerInduction = false; + bool FoundOuterInduction = false; + + for (unsigned i = 0; i < GEP->getNumOperands(); ++i) { + if (!SE->isSCEVable(GEP->getOperand(i)->getType())) + continue; + const SCEV *OperandVal = SE->getSCEV(GEP->getOperand(i)); + const SCEVAddRecExpr *AR = dyn_cast(OperandVal); + if (!AR) + continue; + + if (AR->getLoop() == InnerLoop) { + FoundInnerInduction = true; + if (FoundOuterInduction) { + return 1; + } + } + + if (AR->getLoop() == OuterLoop) { + FoundOuterInduction = true; + if (FoundInnerInduction) { + return -1; + } + } + } + + return 0; } std::optional