reference, declarationdefinition
definition → references, declarations, derived classes, virtual overrides
reference to multiple definitions → definitions
unreferenced
    1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
   93
   94
   95
   96
   97
   98
   99
  100
  101
  102
  103
  104
  105
  106
  107
  108
  109
  110
  111
  112
  113
  114
  115
  116
  117
  118
  119
  120
  121
  122
  123
  124
  125
  126
  127
  128
  129
  130
  131
  132
  133
  134
  135
  136
  137
  138
  139
  140
  141
  142
  143
  144
  145
  146
  147
  148
  149
  150
  151
  152
  153
  154
  155
  156
  157
  158
  159
  160
  161
  162
  163
  164
  165
  166
  167
  168
  169
  170
  171
  172
  173
  174
  175
  176
  177
  178
  179
  180
  181
  182
  183
  184
  185
  186
  187
  188
  189
  190
  191
  192
  193
  194
  195
  196
  197
  198
  199
  200
  201
  202
  203
  204
  205
  206
  207
  208
  209
  210
  211
  212
  213
  214
  215
  216
  217
  218
  219
  220
  221
  222
  223
  224
  225
  226
  227
  228
  229
  230
  231
  232
  233
  234
  235
  236
  237
  238
  239
  240
  241
  242
  243
  244
  245
  246
  247
  248
  249
  250
  251
  252
  253
  254
  255
  256
  257
  258
  259
  260
  261
  262
  263
  264
  265
  266
  267
  268
  269
  270
  271
  272
  273
  274
  275
  276
  277
  278
  279
  280
  281
  282
  283
  284
  285
  286
  287
  288
  289
  290
  291
  292
  293
  294
  295
  296
  297
  298
  299
  300
  301
  302
  303
  304
  305
  306
  307
  308
  309
  310
  311
  312
  313
  314
  315
  316
  317
  318
  319
  320
  321
  322
  323
  324
  325
  326
  327
  328
  329
  330
  331
  332
  333
  334
  335
  336
  337
  338
  339
  340
  341
  342
  343
  344
  345
  346
  347
  348
  349
  350
  351
  352
  353
  354
  355
  356
  357
  358
  359
  360
  361
  362
  363
  364
  365
  366
  367
  368
  369
  370
  371
  372
  373
  374
  375
  376
  377
  378
  379
  380
//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
//--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SPE '01, Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
//     and the Control Dependence Graph", TOPLAS '91,
//           Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
// Our flavor of SSA construction for x will construct the following
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//      x2=phi    E
//         \     /
//          x3=phi
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"

#include <stack>
#include <unordered_set>

#define DEBUG_TYPE "sync-dependence"

namespace llvm {

// Shared, always-empty result returned for queries that trivially have no
// join blocks.
ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;

// Binds the analysis to the given dominator tree, post-dominator tree and loop
// info. The RPO traversal member is built from the function that owns the
// dominator tree's root block.
SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
                                               const PostDominatorTree &PDT,
                                               const LoopInfo &LI)
    : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}

// Nothing to release by hand; cached join sets are owned by the cache maps.
SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;

// Reverse post-order traversal over the blocks of a (const) function.
using FunctionRPOT = ReversePostOrderTraversal<const Function *>;

// divergence propagator for reducible CFGs
//
// Answers a single join-point query. Starting at a node (a block with a
// divergent terminator, or the header of a loop with divergent exits) it
// propagates a "reaching definition" of a virtual variable through the CFG in
// reverse post-order and records every block where two different definitions
// meet. Loops nested inside ParentLoop are treated as single nodes whose
// successors are their exit blocks.
struct DivergencePropagator {
  const FunctionRPOT &FuncRPOT;
  const DominatorTree &DT;
  const PostDominatorTree &PDT;
  const LoopInfo &LI;

  // identified join points; ownership is handed out by computeJoinPoints
  std::unique_ptr<ConstBlockSet> JoinBlocks;

  // reached loop exits (by a path disjoint to a path to the loop header)
  SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;

  // if DefMap[B] == C then C is the dominating definition at block B
  // if DefMap[B] ~ undef then we haven't seen B yet
  // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
  // an immediate successor of X (initial value).
  using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
  DefiningBlockMap DefMap;

  // all blocks with pending visits
  std::unordered_set<const BasicBlock *> PendingUpdates;

  DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
                       const PostDominatorTree &PDT, const LoopInfo &LI)
      : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
        JoinBlocks(new ConstBlockSet) {}

  // Set the definition at @Block to @DefBlock and mark @Block as pending for
  // a visit.
  //
  // This must overwrite an existing entry: when a join is detected, the join
  // block becomes its own definition. Keeping the stale first definition (as
  // DefMap.emplace would) lets that old definition leak to the join's
  // successors and hides joins further downstream.
  void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
    DefMap[&Block] = &DefBlock;
    PendingUpdates.insert(&Block);
  }

  // Debug helper: print the reaching definition of every block in RPO.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::DefMap {\n";
    for (const auto *Block : FuncRPOT) {
      auto It = DefMap.find(Block);
      Out << Block->getName() << " : ";
      if (It == DefMap.end()) {
        Out << "\n";
      } else {
        const auto *DefBlock = It->second;
        Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
      }
    }
    Out << "}\n";
  }

  // process @SuccBlock with reaching definition @DefBlock
  // the original divergent branch was in @ParentLoop (if any)
  void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
                      const BasicBlock &DefBlock) {

    // @SuccBlock is a loop exit: record its reaching definition. Whether the
    // exit is divergent relative to the loop header is decided after
    // propagation (see computeJoinPoints).
    if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
      auto ItExitDef = DefMap.emplace(&SuccBlock, &DefBlock).first;
      // Two different definitions reach the same exit: it is a join of
      // disjoint paths, independently of the definition at the header.
      if (ItExitDef->second != &DefBlock)
        JoinBlocks->insert(&SuccBlock);
      ReachedLoopExits.insert(&SuccBlock);
      return;
    }

    // first reaching def?
    auto ItLastDef = DefMap.find(&SuccBlock);
    if (ItLastDef == DefMap.end()) {
      addPending(SuccBlock, DefBlock);
      return;
    }

    // same definition again: nothing changes
    if (ItLastDef->second == &DefBlock)
      return;

    // a join of at least two definitions; do we know this join already?
    if (!JoinBlocks->insert(&SuccBlock).second)
      return;

    // the join block now defines the variable itself
    addPending(SuccBlock, SuccBlock);
  }

  // find all blocks reachable by two disjoint paths from @RootBlock.
  // This method works for both divergent terminators and loops with
  // divergent exits.
  // @RootBlock is either the block containing the branch or the header of the
  // divergent loop.
  // @NodeSuccessors is the set of successors of the node (Loop or Terminator)
  // headed by @RootBlock.
  // @ParentLoop is the parent loop of the Loop or the loop that contains the
  // Terminator.
  // Returns the join blocks; the propagator must not be reused afterwards.
  template <typename SuccessorIterable>
  std::unique_ptr<ConstBlockSet>
  computeJoinPoints(const BasicBlock &RootBlock,
                    SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
    assert(JoinBlocks);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: "
                      << (ParentLoop ? ParentLoop->getName() : "<null>")
                      << "\n");

    // bootstrap with branch targets: each immediate successor is its own
    // initial definition
    for (const auto *SuccBlock : NodeSuccessors) {
      DefMap.emplace(SuccBlock, SuccBlock);

      if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
        // immediate loop exit from node.
        ReachedLoopExits.insert(SuccBlock);
      } else {
        // regular successor
        PendingUpdates.insert(SuccBlock);
      }
    }

    LLVM_DEBUG(dbgs() << "SDA: rpo order:\n";
               for (const auto *RpoBlock : FuncRPOT) {
                 dbgs() << "- " << RpoBlock->getName() << "\n";
               });

    // skip until RootBlock (TODO RPOT won't let us start at RootBlock
    // directly). The scan is bounded by the end of the traversal so that a
    // RootBlock which does not appear in it cannot walk past end().
    auto ItEndRPO = FuncRPOT.end();
    auto ItBeginRPO = FuncRPOT.begin();
    while (ItBeginRPO != ItEndRPO && *ItBeginRPO != &RootBlock)
      ++ItBeginRPO;
    assert(ItBeginRPO != ItEndRPO && "root block not found in RPO traversal");

    // propagate definitions at the immediate successors of the node in RPO,
    // starting right after RootBlock
    auto ItBlockRPO = ItBeginRPO;
    if (ItBlockRPO != ItEndRPO)
      ++ItBlockRPO;
    for (; ItBlockRPO != ItEndRPO && !PendingUpdates.empty(); ++ItBlockRPO) {
      const auto *Block = *ItBlockRPO;
      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");

      // skip Block if not pending update
      auto ItPending = PendingUpdates.find(Block);
      if (ItPending == PendingUpdates.end())
        continue;
      PendingUpdates.erase(ItPending);

      // propagate definition at Block to its successors; every pending block
      // was given a definition when it was marked pending
      auto ItDef = DefMap.find(Block);
      assert(ItDef != DefMap.end() && "pending block without a definition");
      const auto *DefBlock = ItDef->second;
      assert(DefBlock);

      auto *BlockLoop = LI.getLoopFor(Block);
      if (ParentLoop &&
          (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
        // if the successor is the header of a nested loop pretend its a
        // single node with the loop's exits as successors
        SmallVector<BasicBlock *, 4> BlockLoopExits;
        BlockLoop->getExitBlocks(BlockLoopExits);
        for (const auto *BlockLoopExit : BlockLoopExits) {
          visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
        }

      } else {
        // the successors are either on the same loop level or loop exits
        for (const auto *SuccBlock : successors(Block)) {
          visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
        }
      }
    }

    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));

    // We need to know the definition at the parent loop header to decide
    // whether the definition at the header is different from the definition at
    // the loop exits, which would indicate a divergent loop exit.
    //
    // A // loop header
    // |
    // B // nested loop header
    // |
    // C -> X (exit from B loop) -..-> (A latch)
    // |
    // D -> back to B (B latch)
    // |
    // proper exit from both loops
    //
    // analyze reached loop exits
    if (!ReachedLoopExits.empty()) {
      // loop exits are only ever recorded when there is a parent loop
      assert(ParentLoop);
      const BasicBlock *ParentLoopHeader =
          ParentLoop ? ParentLoop->getHeader() : nullptr;

      auto ItHeaderDef = DefMap.find(ParentLoopHeader);
      const auto *HeaderDefBlock =
          (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;

      LLVM_DEBUG(printDefs(dbgs()));
      assert(HeaderDefBlock && "no definition at header of carrying loop");

      for (const auto *ExitBlock : ReachedLoopExits) {
        auto ItExitDef = DefMap.find(ExitBlock);
        assert((ItExitDef != DefMap.end()) &&
               "no reaching def at reachable loop exit");
        if (ItExitDef->second != HeaderDefBlock) {
          JoinBlocks->insert(ExitBlock);
        }
      }
    }

    return std::move(JoinBlocks);
  }
};

// Returns the exit blocks of @Loop that are join points of the divergent
// loop, i.e. exits that are reached through paths that are disjoint from a
// path back to the header of the enclosing loop. Results are cached per loop;
// loops without exit blocks get the shared empty set.
const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
  // already available in cache? Check this first so that cache hits do not
  // pay for computing the loop's exit blocks. Exit-less loops are never
  // cached, so the result is the same as checking the exits first.
  auto ItCached = CachedLoopExitJoins.find(&Loop);
  if (ItCached != CachedLoopExitJoins.end()) {
    return *ItCached->second;
  }

  using LoopExitVec = SmallVector<BasicBlock *, 4>;
  LoopExitVec LoopExits;
  Loop.getExitBlocks(LoopExits);
  // a loop without exits has no divergent exits
  if (LoopExits.empty()) {
    return EmptyBlockSet;
  }

  // compute all join points
  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
  auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
      *Loop.getHeader(), LoopExits, Loop.getParentLoop());

  auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}

// Returns the blocks that are reachable by two disjoint paths from the
// terminator @Term (plus divergent exits of the loop containing it). Results
// are cached per terminator; terminators without successors get the shared
// empty set.
const ConstBlockSet &
SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
  // no successors, no joins
  if (Term.getNumSuccessors() == 0)
    return EmptyBlockSet;

  // reuse a previously computed answer for this terminator
  auto ItCached = CachedBranchJoins.find(&Term);
  if (ItCached != CachedBranchJoins.end())
    return *ItCached->second;

  // run the propagation from the terminator's block, within its loop (if any)
  const BasicBlock &TermBlock = *Term.getParent();
  const Loop *TermLoop = LI.getLoopFor(&TermBlock);
  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
  auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
      TermBlock, successors(&TermBlock), TermLoop);

  const auto Inserted =
      CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
  assert(Inserted.second);
  return *Inserted.first->second;
}

} // namespace llvm