reference, declaration → definition
definition → references, declarations, derived classes, virtual overrides
reference to multiple definitions → definitions
unreferenced
    1
    2
    3
    4
    5
    6
    7
    8
    9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   21
   22
   23
   24
   25
   26
   27
   28
   29
   30
   31
   32
   33
   34
   35
   36
   37
   38
   39
   40
   41
   42
   43
   44
   45
   46
   47
   48
   49
   50
   51
   52
   53
   54
   55
   56
   57
   58
   59
   60
   61
   62
   63
   64
   65
   66
   67
   68
   69
   70
   71
   72
   73
   74
   75
   76
   77
   78
   79
   80
   81
   82
   83
   84
   85
   86
   87
   88
   89
   90
   91
   92
//===- unittests/AST/ASTPrint.h ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helpers to simplify testing of printing of AST constructs provided in the
// form of source code.
//
//===----------------------------------------------------------------------===//

#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/SmallString.h"
#include "gtest/gtest.h"

namespace clang {

/// Optional callback that tweaks the PrintingPolicy before a statement is
/// printed. When it is empty, the ASTContext's printing policy is used as-is.
///
/// NOTE: llvm::function_ref does not own its callee, so the callable an
/// adjuster refers to must outlive every use of the adjuster.
using PolicyAdjusterType =
    Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;

/// Pretty-prints statement \p S to \p Out using the printing policy of
/// \p Context, first passing that policy through \p PolicyAdjuster if one is
/// provided.
///
/// Declared `inline` rather than `static`: this header is included by several
/// test translation units, and PrintMatch::run (an inline member defined
/// below) calls this function. With internal linkage every TU would get its
/// own PrintStmt, so the inline definitions of run() in different TUs would
/// refer to different entities (an ODR violation), and TUs that include the
/// header without using it would get unused-function warnings.
inline void PrintStmt(raw_ostream &Out, const ASTContext *Context,
                      const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
  assert(Context != nullptr && "Expected non-null ASTContext");
  assert(S != nullptr && "Expected non-null Stmt");
  PrintingPolicy Policy = Context->getPrintingPolicy();
  if (PolicyAdjuster)
    (*PolicyAdjuster)(Policy);
  S->printPretty(Out, /*Helper*/ nullptr, Policy);
}

/// MatchFinder callback that counts every statement bound to the "id" node
/// and pretty-prints only the first one it sees.
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
  /// Printed text of the first matched statement.
  SmallString<1024> Printed;
  /// Number of statements bound to "id" seen so far.
  unsigned NumFoundStmts = 0;
  /// Optional tweak applied to the printing policy (see PolicyAdjusterType
  /// for the lifetime requirement).
  PolicyAdjusterType PolicyAdjuster;

public:
  PrintMatch(PolicyAdjusterType PolicyAdjuster)
      : PolicyAdjuster(PolicyAdjuster) {}

  void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
    const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
    if (!S)
      return;
    NumFoundStmts++;
    // Keep counting extra matches so the caller can report them, but only
    // the first statement is printed.
    if (NumFoundStmts > 1)
      return;

    llvm::raw_svector_ostream Out(Printed);
    PrintStmt(Out, Result.Context, S, PolicyAdjuster);
  }

  /// Returns the printed text of the first matched statement (empty if no
  /// statement has matched).
  StringRef getPrinted() const { return Printed; }

  /// Returns how many statements bound to "id" have been matched.
  unsigned getNumFoundStmts() const { return NumFoundStmts; }
};

/// Parses \p Code with \p Args, runs \p NodeMatch over the resulting AST, and
/// checks that exactly one statement bound to "id" matched and that it
/// pretty-prints as \p ExpectedPrinted.
///
/// \param Code source code to parse.
/// \param Args command-line arguments passed to the tool.
/// \param NodeMatch matcher expected to bind the statement of interest to "id".
/// \param ExpectedPrinted the expected pretty-printed text.
/// \param PolicyAdjuster optional tweak applied to the PrintingPolicy.
/// \returns AssertionSuccess, or an AssertionFailure describing a parse
///          error, a wrong match count, or a mismatch in the printed text.
template <typename T>
::testing::AssertionResult
PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
                   const T &NodeMatch, StringRef ExpectedPrinted,
                   PolicyAdjusterType PolicyAdjuster = None) {

  PrintMatch Printer(PolicyAdjuster);
  ast_matchers::MatchFinder Finder;
  Finder.addMatcher(NodeMatch, &Printer);
  std::unique_ptr<tooling::FrontendActionFactory> Factory(
      tooling::newFrontendActionFactory(&Finder));

  // gtest names are qualified with the global namespace throughout: an
  // unqualified `testing::` inside namespace clang would silently bind to a
  // `clang::testing` namespace if one were ever declared.
  if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
    return ::testing::AssertionFailure()
           << "Parsing error in \"" << Code.str() << "\"";

  const unsigned NumFound = Printer.getNumFoundStmts();
  if (NumFound == 0)
    return ::testing::AssertionFailure()
           << "Matcher didn't find any statements";

  if (NumFound > 1)
    return ::testing::AssertionFailure()
           << "Matcher should match only one statement (found " << NumFound
           << ")";

  if (Printer.getPrinted() != ExpectedPrinted)
    return ::testing::AssertionFailure()
           << "Expected \"" << ExpectedPrinted.str() << "\", got \""
           << Printer.getPrinted().str() << "\"";

  return ::testing::AssertionSuccess();
}

} // namespace clang