#include "clang/Driver/Options.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Lex/Lexer.h"
#include "clang/AST/ParentMap.h"

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <ctime>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

using namespace std;
using namespace clang;
using namespace clang::driver;
using namespace clang::tooling;
using namespace llvm;

// Global rewriter shared by all passes. It collects every textual edit, and
// main() prints its edit buffer for the main file after the tool has run.
Rewriter rewriter;

// Set of (variable name, type spelling) pairs describing variables.
typedef std::set< std::pair<std::string, std::string> > VarTypeMap;

// Returns the elements of v1 that are not in v2 (v1 \ v2).
// The sets are only read, so they are taken by const reference; the former
// by-value parameters copied both sets on every call.
VarTypeMap set_difference(const VarTypeMap& v1, const VarTypeMap& v2) {
  VarTypeMap res;
  // The algorithm produces its output in sorted order, so hinting at end()
  // makes each insertion amortized constant time.
  std::set_difference(v1.begin(), v1.end(), v2.begin(), v2.end(), std::inserter(res, res.end()));
  return res;
}

// Returns the elements present in both v1 and v2.
// The sets are only read, so they are taken by const reference; the former
// by-value parameters copied both sets on every call.
VarTypeMap set_intersection(const VarTypeMap& v1, const VarTypeMap& v2) {
  VarTypeMap res;
  // Output arrives sorted, so an end() hint keeps each insertion amortized O(1).
  std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), std::inserter(res, res.end()));
  return res;
}

// Returns the elements present in v1, v2, or both.
// The sets are only read, so they are taken by const reference; the former
// by-value parameters copied both sets on every call.
VarTypeMap set_union(const VarTypeMap& v1, const VarTypeMap& v2) {
  VarTypeMap res;
  // Output arrives sorted, so an end() hint keeps each insertion amortized O(1).
  std::set_union(v1.begin(), v1.end(), v2.begin(), v2.end(), std::inserter(res, res.end()));
  return res;
}

// Emitted by main() before the rewritten main file; RecurserVisitor appends
// the prototypes of the functions it introduces here.
std::string _preSource;
// Emitted by main() after the rewritten main file (not appended to anywhere
// in the code visible here).
std::string _postSource;

// Returns the source text spanned by `stmt`, treating its range as a token
// range so the last token is included in full.
// A null statement (e.g. an omitted clause of `for (;;)`) yields an empty
// string instead of being dereferenced.
std::string getSourceText(Stmt* stmt, ASTContext* astContext) {
  if (stmt == nullptr)
    return "";
  clang::SourceRange range = stmt->getSourceRange();
  return Lexer::getSourceText(CharSourceRange::getTokenRange(range), astContext->getSourceManager(), astContext->getLangOpts(), nullptr);
}

// Location just past the end of the token that begins at `curLoc`
// (no additional offset is applied).
clang::SourceLocation getLocForEndOfToken(clang::SourceLocation curLoc, ASTContext* astContext) {
  const auto& sourceMgr = astContext->getSourceManager();
  const auto& langOpts = astContext->getLangOpts();
  return clang::Lexer::getLocForEndOfToken(curLoc, /*Offset=*/0, sourceMgr, langOpts);
}

// Location immediately after the last token of `stmt`.
clang::SourceLocation locAfterStmt(Stmt* stmt, ASTContext* astContext) {
  const clang::SourceLocation lastTokenStart = stmt->getLocEnd();
  return getLocForEndOfToken(lastTokenStart, astContext);
}

// Thin wrapper over clang::Lexer::findLocationAfterToken for a token of kind
// `tok` following `curLoc`; the final `true` argument is that API's
// SkipTrailingWhitespaceAndNewLine flag.
clang::SourceLocation findLocAfterToken(clang::SourceLocation curLoc, clang::tok::TokenKind tok, ASTContext* astContext) {
  const bool skipTrailingWhitespaceAndNewLine = true;
  const auto& sourceMgr = astContext->getSourceManager();
  return clang::Lexer::findLocationAfterToken(curLoc, tok, sourceMgr, astContext->getLangOpts(), skipTrailingWhitespaceAndNewLine);
}

// findLocAfterToken specialised to a `;` token.
clang::SourceLocation findLocAfterSemi(clang::SourceLocation curLoc, ASTContext* astContext) {
  const clang::tok::TokenKind semicolon = clang::tok::semi;
  return findLocAfterToken(curLoc, semicolon, astContext);
}

// If the first non-whitespace character of `s` is `c`, returns the text
// after it (the leading whitespace and `c` are dropped). Otherwise returns
// `s` unchanged.
std::string stripLeading(const std::string& s, char c) {
  std::string::size_type i = 0;
  // isspace() requires a value representable as unsigned char; passing a
  // negative (non-ASCII) char is undefined behaviour.
  while (i < s.size() && isspace(static_cast<unsigned char>(s[i])))
    i++;
  if (i < s.size() && s[i] == c)
    return s.substr(i + 1);
  return s;
}

// If the last non-whitespace character of `s` is `c`, returns the text
// before it (`c` and any trailing whitespace are dropped). Otherwise returns
// `s` unchanged.
std::string stripTrailing(const std::string& s, char c) {
  // `end` is one past the candidate character. An unsigned index avoids the
  // former `int i = s.size() - 1` narrowing, which relied on
  // implementation-defined conversion for empty strings and broke for huge ones.
  std::string::size_type end = s.size();
  // Cast to unsigned char: isspace() on a negative char is undefined behaviour.
  while (end > 0 && isspace(static_cast<unsigned char>(s[end - 1])))
    end--;
  if (end > 0 && s[end - 1] == c)
    return s.substr(0, end - 1);
  return s;
}

// trim from left
inline std::string& ltrim(std::string& s, const char* t = " \t\n\r\f\v") {
    s.erase(0, s.find_first_not_of(t));
    return s;
}

// trim from right
inline std::string& rtrim(std::string& s, const char* t = " \t\n\r\f\v") {
    s.erase(s.find_last_not_of(t) + 1);
    return s;
}

// trim from left & right
inline std::string& trim(std::string& s, const char* t = " \t\n\r\f\v") {
    return ltrim(rtrim(s, t), t);
}

// Trims `s`. If it then starts with '{', removes that brace and a trailing
// '}' (when present), then trims again. The result is stored back into `s`
// and also returned by value.
inline std::string stripBraces(std::string& s) {
  trim(s);
  if (!s.empty() && s.front() == '{') {
    const std::string withoutOpening = stripLeading(s, '{');
    s = stripTrailing(withoutOpening, '}');
  }
  trim(s);
  return s;
}


// Interface for a string-to-string mapping. makeArgList takes one for
// variable names and one for type names.
class StringTransformer {
public:
  // Polymorphic base: a virtual destructor makes destroying a derived
  // transformer through a StringTransformer pointer well-defined.
  virtual ~StringTransformer() {}
  virtual std::string transform(const std::string&) const = 0;
};

// Prepends a fixed prefix to every input string.
class PrefixStringTransformer : public StringTransformer {
public:
  PrefixStringTransformer(std::string prefix)
    : _prefix(prefix) {
  }

  virtual std::string transform(const std::string& s) const {
    std::string out(_prefix);
    out += s;
    return out;
  }

private:
  std::string _prefix;  // text placed before each input
};

// Appends a fixed suffix to every input string.
class SuffixStringTransformer : public StringTransformer {
public:
  SuffixStringTransformer(std::string suffix)
    : _suffix(suffix) {
  }

  virtual std::string transform(const std::string& s) const {
    std::string out(s);
    out.append(_suffix);
    return out;
  }

private:
  std::string _suffix;  // text placed after each input
};
  
// Maps every input to the empty string. makeArgList callers use it to omit
// either the type column or the name column.
class EmptyStringTransformer: public StringTransformer {
public:
  virtual std::string transform(const std::string& /*s*/) const {
    return std::string();
  }
};

// Returns every input unchanged.
class IdentityStringTransformer: public StringTransformer {
public:
  virtual std::string transform(const std::string& s) const {
    std::string unchanged = s;
    return unchanged;
  }
};
  
std::string makeArgList(const VarTypeMap& vars, const vector<string>& functionParams, const StringTransformer& varTransformer, const StringTransformer& typeTransformer) {
  std::vector< std::tuple<string, bool, string> > vars_vec;
  for (auto it = vars.begin(); it != vars.end(); ++it) {
    bool isParam = find(functionParams.begin(), functionParams.end(), it->first) != functionParams.end();
    vars_vec.push_back(make_tuple(it->second, isParam, it->first));
  }

  sort(vars_vec.begin(), vars_vec.end(),
       [](std::tuple<string, bool, string> a, std::tuple<string, bool, string> b) {
	 if (get<0>(a) < get<0>(b))
	   return true;
	 if (get<0>(a) > get<0>(b))
	   return false;
	 if (get<1>(a) && !get<1>(b))
	   return true;
	 if (!get<1>(a) && get<1>(b))
	   return false;
	 return get<2>(a) < get<2>(b);
       });
  
  std::string result;
  for (auto it = vars_vec.begin(); it != vars_vec.end(); ++it) {
    if (it != vars_vec.begin())
      result += ", ";
    result += typeTransformer.transform(get<0>(*it));
    result += " ";
    result += varTransformer.transform(get<2>(*it));
  }
  return result;
}

// All variables in an expression
// All variables in an expression: collects each VarDecl referenced by a
// DeclRefExpr inside the traversed statement, as (name, type spelling) pairs.
class AllVarsVisitor: public RecursiveASTVisitor<AllVarsVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  VarTypeMap _vars;
public:
  // `stmt` is not used here; callers traverse it explicitly with TraverseStmt.
  explicit AllVarsVisitor(ASTContext* ac, Stmt* stmt) 
    : astContext(ac) {
  }

  virtual bool VisitStmt(Stmt* s) {
    DeclRefExpr* ref = dyn_cast<DeclRefExpr>(s);
    if (ref == 0)
      return true;
    if (VarDecl* var = dyn_cast<VarDecl>(ref->getDecl()))
      _vars.insert(std::make_pair(var->getNameAsString(), var->getType().getAsString()));
    return true;
  }

  const VarTypeMap& getVars() {
    return _vars;
  }
};
  
// Input = Read - Declared - Initialized
// Output = Changed - Declared

// Classifies the variables used by the traversed statements:
//   changed     - occur in an assignment LHS or under ++/--
//   read        - occur in an assignment RHS, a compound-assignment LHS, an
//                 operand of a non-assignment binary operator, a call, a
//                 return value, or under ++/--
//   initialized - target of a plain `v = ...` made when no read of `v` has
//                 been recorded yet and not under an if/case (relative to
//                 the statement passed to the constructor)
//   declared    - declared by a VarDecl in the traversed code
// Variables are identified by (name, type spelling), not by declaration.
class VarVisitor: public RecursiveASTVisitor<VarVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  VarTypeMap _changedVars;     // var = ...
  VarTypeMap _readVars;        // ... = var
  VarTypeMap _initializedVars; // changed before every read
  VarTypeMap _declaredVars;    // type var;
  // Parent links of the constructor's statement, used by conditional().
  // Held by unique_ptr: the former raw owning pointer had no copy control,
  // so copying a VarVisitor would have deleted the map twice.
  std::unique_ptr<ParentMap> _map;

public:
  explicit VarVisitor(ASTContext* ac, Stmt* stmt) 
    : astContext(ac), _map(new ParentMap(stmt)) {
  }

public:
  // Records every variable declaration as "declared".
  virtual bool VisitDecl(Decl* d) {
    VarDecl* v = dyn_cast<VarDecl>(d);
    if (v != 0) {
      _declaredVars.insert(std::make_pair(v->getNameAsString(), v->getType().getAsString()));
    }
    return true;
  }

  // Adds every variable referenced anywhere in `s` to the changed set.
  void collectChangedVars(Expr* s) {
    AllVarsVisitor vv(astContext, s);
    vv.TraverseStmt(s);
    for (auto it : vv.getVars()) {
      _changedVars.insert(it);
    }
  }

  // Adds every variable referenced anywhere in `s` to the read set.
  // NOTE(review): VisitStmt may pass a null `s` (value-less `return;`);
  // this relies on RecursiveASTVisitor::TraverseStmt ignoring null.
  void collectReadVars(Expr* s) {
    AllVarsVisitor vv(astContext, s);
    vv.TraverseStmt(s);
    for (auto it : vv.getVars()) {
      _readVars.insert(it);
    }
  }

  // True if `s` or one of its ancestors (up to the root of _map) is an
  // IfStmt or CaseStmt.
  // NOTE(review): loops and ?: are not treated as conditional, so an
  // assignment inside a nested loop body can still count as an
  // initialization; confirm this is intended.
  bool conditional(Stmt* s) {
    while (s != nullptr) {
      if (isa<IfStmt>(s) || isa<CaseStmt>(s))
        return true;
      s = _map->getParent(s);
    }
    return false;
  }

  virtual bool VisitStmt(Stmt* s) {
    if (BinaryOperator* b = dyn_cast<BinaryOperator>(s)) {
      Expr* lhs = b->getLHS();
      Expr* rhs = b->getRHS();
      if (b->isAssignmentOp()) {
        collectReadVars(rhs);
        if (b->isCompoundAssignmentOp()) {
          collectReadVars(lhs);  // `v op= e` also reads v
        }
        collectChangedVars(lhs);

        // Is this an initialization of the lhs var?
        if (!b->isCompoundAssignmentOp()) {
          if (DeclRefExpr* dre = dyn_cast<DeclRefExpr>(lhs)) {
            if (VarDecl* v = dyn_cast<VarDecl>(dre->getDecl())) {
              auto p = std::make_pair(v->getNameAsString(), v->getType().getAsString());
              // Yes, if the var was not previously read and the current
              // statement is not within a conditional.
              if (_readVars.find(p) == _readVars.end() && !conditional(s)) {
                _initializedVars.insert(p);
              }
            }
          }
        }
      } else {
        collectReadVars(lhs);
        collectReadVars(rhs);
      }
    }

    if (UnaryOperator* u = dyn_cast<UnaryOperator>(s)) {
      if (u->isIncrementDecrementOp()) {
        collectReadVars(u->getSubExpr());
        collectChangedVars(u->getSubExpr());
      }
    }

    if (CallExpr* ce = dyn_cast<CallExpr>(s)) {
      collectReadVars(ce);
    }

    if (ReturnStmt* rs = dyn_cast<ReturnStmt>(s)) {
      collectReadVars(rs->getRetValue());
    }

    return true;
  }

  VarTypeMap getInitializedVars() {
    return _initializedVars;
  }

  VarTypeMap getOutputVars() {
    return set_difference(_changedVars, _declaredVars);
  }

  VarTypeMap getInputVars() {
    return set_difference(set_difference(_readVars, _declaredVars), _initializedVars);
  }
};

// Traverses a function body and collects the variables that are read before
// being re-initialized by the code visited after the traversal leaves the
// subtree of `_stmt` (typically a loop).
class UsedVarsAfterStmtVisitor: public RecursiveASTVisitor<UsedVarsAfterStmtVisitor> {
private:
  ASTContext *_astContext; // used for getting additional AST info
  Stmt* _stmt;             // statement whose "after" region is analysed
  bool _after;             // true once traversal has left _stmt's subtree
  // Parent links of _stmt; non-null only while traversal is inside _stmt.
  std::unique_ptr<ParentMap> _stmt_map;
  ParentMap _body_map;     // parent links of the whole body
  VarTypeMap _vars;
  VarTypeMap _initialized_vars;
  // Statements already analysed as a whole; their descendants are skipped.
  std::set<Stmt*> _analysed;

  // True if a proper ancestor of `s` (within the body) was already analysed.
  bool insideAnalysedStmt(Stmt* s) {
    for (Stmt* p = _body_map.getParent(s); p != nullptr; p = _body_map.getParent(p))
      if (_analysed.count(p) != 0)
        return true;
    return false;
  }

public:
  explicit UsedVarsAfterStmtVisitor(ASTContext* a, Stmt* s, Stmt* body)
    : _astContext(a), _stmt(s), _after(false), _stmt_map(), _body_map(body) {
  }

  virtual bool VisitStmt(Stmt* s) {
    if (s == _stmt) {
      _stmt_map.reset(new ParentMap(_stmt));
    } else if (_stmt_map && !_stmt_map->getParent(s)) {
      // First statement outside _stmt's subtree: everything from here on is "after".
      _stmt_map.reset();
      _after = true;
    }

    // Analyse only outermost statements after _stmt (excluding the body
    // root). Re-analysing a nested child on its own lost its enclosing
    // if/case, so a conditional assignment was wrongly counted as an
    // initialization.
    if (_after && _body_map.getParent(s) && !insideAnalysedStmt(s)) {
      _analysed.insert(s);
      VarVisitor vv(_astContext, s);
      vv.TraverseStmt(s);
      _initialized_vars = set_union(_initialized_vars, vv.getInitializedVars());
      // Only reads that follow an initialization are ignored; a variable
      // already recorded as used stays used. The former
      // `(_vars U inputs) - _initialized_vars` dropped variables that were
      // read first and overwritten afterwards.
      _vars = set_union(_vars, set_difference(vv.getInputVars(), _initialized_vars));
    }
    return true;
  }

  VarTypeMap& getVars() {
    return _vars;
  }
};

// A loop contains return (not within an inner loop).
// Traversing the loop `_top` records every ReturnStmt that returns directly
// from it, i.e. one not nested in another while/do/for loop inside `_top`.
// For each such return it also records its top-level parent: the outermost
// non-CompoundStmt ancestor just below `_top`.
class LoopContainsReturnVisitor : public RecursiveASTVisitor<LoopContainsReturnVisitor> {
private:
  bool _containsReturn;
  Stmt* _top;
  ParentMap* _pm;
  std::vector<ReturnStmt*> _returns;
  std::vector<Stmt*> _returnParents; // top level parents of returns

public:
  explicit LoopContainsReturnVisitor(Stmt* top)
    : _containsReturn(false), _top(top), _pm(new ParentMap(top)) {
  }

  ~LoopContainsReturnVisitor() {
    // FIXME: gives segfault
    // delete pm;
    // NOTE(review): _pm is therefore leaked; the cause of the reported
    // segfault is not visible here.
  }

  virtual bool VisitStmt(Stmt* s) {
    ReturnStmt* ret = dyn_cast<ReturnStmt>(s);
    if (ret == nullptr)
      return true;

    // Walk up to _top, remembering the outermost non-compound ancestor and
    // whether some loop other than _top lies in between.
    Stmt* topLevelParent = s;
    bool reachedTop = false;
    bool inInnerLoop = false;
    for (Stmt* cur = s; cur != nullptr; ) {
      if (!isa<CompoundStmt>(cur))
        topLevelParent = cur;
      Stmt* parent = _pm->getParent(cur);
      if (parent == _top) {
        reachedTop = true;
        break;
      }
      if (parent != nullptr &&
          (isa<WhileStmt>(parent) || isa<DoStmt>(parent) || isa<ForStmt>(parent)))
        inInnerLoop = true;
      cur = parent;
    }

    // A return inside an inner loop still exits the function as-is and is
    // handled when that loop is processed. It must neither be recorded nor
    // clear a return found earlier; the former code reset _containsReturn to
    // false here, losing previously seen outer-level returns.
    if (!reachedTop || inInnerLoop)
      return true;

    _containsReturn = true;
    _returns.push_back(ret);
    _returnParents.push_back(topLevelParent);
    return true;
  }

  bool containsReturn() {
    return _containsReturn;
  }

  const std::vector<ReturnStmt*>& returnStmts() {
    return _returns;
  }

  // Deduplicated top-level parents (sorted by pointer value, in place).
  const std::vector<Stmt*>& returnParents() {
    std::sort( _returnParents.begin(), _returnParents.end() );
    _returnParents.erase( std::unique( _returnParents.begin(), _returnParents.end() ), _returnParents.end() );
    return _returnParents;
  }
};

// Removes the direct returns from the first while loop that has any:
//   - declares `<T> __ret<N> = __RET_UNDEF;` before the loop,
//   - conjoins `__ret<N> == __RET_UNDEF` to the loop condition,
//   - replaces each such `return e;` by `__ret<N> = e;`, or by
//     `__ret<N> = __RET_DEF;` when no value is returned,
//   - wraps the rest of the loop body after each return-containing
//     statement in `if (__ret<N> == __RET_UNDEF) { ... }`,
//   - emits `if (!(__ret<N> == __RET_UNDEF)) return [__ret<N>];` after the loop.
// __RET_UNDEF / __RET_DEF are presumably macros supplied elsewhere (confirm).
// Only one loop is rewritten per visitor.
class ReturnInLoopElimVisitor : public RecursiveASTVisitor<ReturnInLoopElimVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  bool _rewritten;
public:
  explicit ReturnInLoopElimVisitor(CompilerInstance *CI)
    : astContext(&(CI->getASTContext())),
      _rewritten(false) // initialize private members
  {
  }

  virtual bool VisitStmt(Stmt* s) {
    if (_rewritten)
      return true;
    // TODO: do-while
    WhileStmt* ws = dyn_cast<WhileStmt>(s);
    if (ws == nullptr)
      return true;

    LoopContainsReturnVisitor v(ws);
    v.TraverseStmt(ws);
    const std::vector<ReturnStmt*>& returnStmts = v.returnStmts();
    if (!v.containsReturn() || returnStmts.empty())
      return true;
    _rewritten = true;

    // If the first return statement has a value, then we assume that all
    // must have a value and all values must have the same type.
    Expr* firstRetValue = returnStmts[0]->getRetValue();
    bool hasRetValue = firstRetValue != nullptr;
    string retValType = hasRetValue ? firstRetValue->getType().getAsString() : "int";

    string retVar = "__ret" + std::to_string(rand());
    string retUndef = "__RET_UNDEF";
    string varDecl = "  " + retValType + " " + retVar + " = " + retUndef + ";\n  ";
    rewriter.InsertText(ws->getLocStart(), varDecl);

    Stmt* cond = ws->getCond();
    string condText = getSourceText(cond, astContext);
    string stopCond = retVar + " == " + retUndef;
    rewriter.ReplaceText(cond->getLocStart(), rewriter.getRangeSize(cond->getSourceRange()), "(" + condText + ") && (" + stopCond + ")");

    // Open a guard after each return-containing statement; the loop body's
    // own closing brace closes the guard, and the extra "}" inserted after
    // the loop closes the body.
    const std::vector<Stmt*>& returnParents = v.returnParents();
    for (Stmt* parent : returnParents) {
      rewriter.InsertTextAfterToken(locAfterStmt(parent, astContext), "\n  if (" + stopCond + ") {\n");
      rewriter.InsertText(locAfterStmt(ws, astContext), "}");
    }

    for (ReturnStmt* ret : returnStmts) {
      // Use the returned expression's own text. The former `substr(7)` of
      // the whole statement assumed the spelling "return " (mangling
      // `return(x)`) and threw std::out_of_range for a value-less `return;`.
      Expr* value = ret->getRetValue();
      string retVal = (hasRetValue && value != nullptr) ? getSourceText(value, astContext) : "__RET_DEF";
      string newRetBody = retVar + " = " + retVal + ";";
      // Extend the range past the statement's last token so that the
      // following token (normally the ';') is replaced as well.
      clang::SourceRange range = ret->getSourceRange();
      range.setEnd(getLocForEndOfToken(range.getEnd(), astContext));
      rewriter.ReplaceText(ret->getLocStart(), rewriter.getRangeSize(range), newRetBody);
    }

    string ifDoneRet = "\n  if (!(" + stopCond + "))" + " return";
    if (hasRetValue)
      ifDoneRet += " " + retVar;
    ifDoneRet += ";\n";
    rewriter.InsertText(locAfterStmt(ws, astContext), ifDoneRet);
    return true;
  }

  bool rewritten() {
    return _rewritten;
  }  
};

// Rewrites the first `for` loop found into a while loop:
//   for (init; cond; inc) body   ==>   init; while (cond) { body; inc; }
// When `init` declares a variable, the result is wrapped in an extra block
// so the variable keeps its loop-local scope.
// Only one loop is rewritten per visitor.
class ForElimVisitor : public RecursiveASTVisitor<ForElimVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  bool _rewritten;
public:
  explicit ForElimVisitor(CompilerInstance *CI)
    : astContext(&(CI->getASTContext())),
      _rewritten(false) // initialize private members
  {
  }

  // NOTE(review): a `continue` in the body skips `inc` in the while form but
  // not in the for form; such loops are still rewritten here.
  virtual bool VisitStmt(Stmt* s) {
    if (_rewritten)
      return true;
    // Transform for to while
    ForStmt* fs = dyn_cast<ForStmt>(s);
    if (fs == nullptr)
      return true;

    // Every clause of a for header is optional (`for (;;)`): a missing
    // init/inc is omitted, and a missing condition means "always true".
    Stmt* init = fs->getInit();
    Stmt* cond = fs->getCond();
    Stmt* inc = fs->getInc();
    std::string init_text = init ? getSourceText(init, astContext) : std::string();
    std::string cond_text = cond ? getSourceText(cond, astContext) : std::string("1");
    std::string inc_text = inc ? getSourceText(inc, astContext) : std::string();
    std::string body_text = getSourceText(fs->getBody(), astContext);
    body_text = stripBraces(body_text);

    // A variable declared in the init clause is local to the loop. Without
    // the extra block it would leak into the enclosing scope and clash with
    // e.g. a second `for (int i = ...)` there.
    const bool scoped = init != nullptr && isa<DeclStmt>(init);

    std::string while_text;
    if (scoped)
      while_text += "{\n  ";
    if (!init_text.empty())
      while_text += init_text + ";\n";
    while_text += "  while(" + cond_text + ") {\n";
    while_text += "  " + body_text;
    // back() on an empty body (e.g. `{}`) was undefined behaviour.
    if (body_text.empty() || body_text.back() != '}')
      while_text += ";\n";
    if (!inc_text.empty())
      while_text += inc_text + ";\n";
    while_text += "  }\n";
    if (scoped)
      while_text += "}\n";

    rewriter.ReplaceText(fs->getLocStart(), rewriter.getRangeSize(fs->getSourceRange()), while_text);
    _rewritten = true;
    return true;
  }

  bool rewritten() {
    return _rewritten;
  }
};

// Replaces the first while / do-while loop with one unrolled iteration plus
// a call to a freshly named function __UF__<arity>_<k> that stands for the
// rest of the loop. Only its prototype is emitted (appended to _preSource);
// presumably it is treated as uninterpreted downstream (confirm):
//   while (c) { b }    ==>  if (c) { b; x = __UF__n_k(inputs); }
//   do { b } while (c) ==>  { b } if (c) { x = __UF__n_k(inputs); }
// Here x is the single loop-changed variable that is used after the loop.
class RecurserVisitor : public RecursiveASTVisitor<RecurserVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  bool _rewritten;
  FunctionDecl* _current_func;          // most recently visited function (may be null)
  std::vector<std::string> _functions;  // names of all functions visited

public:
  explicit RecurserVisitor(CompilerInstance *CI) 
    : astContext(&(CI->getASTContext())),
      _rewritten(false),
      _current_func(nullptr)
  {
  }

  // Names (first components) of the variables in `vars`.
  std::set<std::string> varsSet(const VarTypeMap& vars) {
    std::set<std::string> res;
    for (const auto& var : vars)
      res.insert(var.first);
    return res;
  }

  // Replaces every identifier in `s` that is in `vars` with `(*p_<name>)`.
  std::string rewriteVars(const std::string& s, const std::set<std::string>& vars) {
    std::string res = "";
    std::string::size_type i = 0;
    while (i < s.length()) {
      // ctype functions need unsigned-char values; negative chars are UB.
      unsigned char ch = static_cast<unsigned char>(s[i]);
      if (isalpha(ch) || ch == '_') {
        std::string::size_type j = i;
        while (i < s.length() && (isalnum(static_cast<unsigned char>(s[i])) || s[i] == '_'))
          i++;
        std::string var = s.substr(j, i - j);
        if (vars.find(var) != vars.end())
          res += "(*p_" + var + ")";
        else
          res += var;
      } else
        res += s[i++];
    }
    return res;
  }

  // Parameter names of the current function, in declaration order.
  vector<string> getCurrentFunctionParams() {
    vector<string> params;
    for (int i = 0, n = _current_func->getNumParams(); i < n; i++)
      params.push_back(_current_func->parameters()[i]->getNameAsString());
    return params;
  }

  virtual bool VisitDecl(Decl* decl) {
    if (decl->isFunctionOrFunctionTemplate()) {
      _current_func = decl->getAsFunction();
      _functions.push_back(_current_func->getNameInfo().getAsString());
    }
    return true;
  }

  virtual bool VisitStmt(Stmt* s) {
    if (_rewritten)
      return true;

    // Transform while to recursion
    WhileStmt* ws = dyn_cast<WhileStmt>(s);
    DoStmt* ds = dyn_cast<DoStmt>(s);
    if (ws == nullptr && ds == nullptr)
      return true;
    // _current_func was previously left uninitialized; do not dereference it
    // unless a function with a body has actually been seen.
    if (_current_func == nullptr || _current_func->getBody() == nullptr)
      return true;

    Stmt* body = ws != nullptr ? ws->getBody() : ds->getBody();
    Stmt* cond = ws != nullptr ? ws->getCond() : ds->getCond();

    VarVisitor vv(astContext, s);
    vv.TraverseStmt(cond);
    vv.TraverseStmt(body);
    VarTypeMap inputVars = vv.getInputVars();
    VarTypeMap outputVars = vv.getOutputVars();
    VarTypeMap initializedVars = vv.getInitializedVars();

    Stmt* func_body = _current_func->getBody();
    UsedVarsAfterStmtVisitor usedAfterVisitor(astContext, s, func_body);
    usedAfterVisitor.TraverseStmt(func_body);
    VarTypeMap usedAfterVars = usedAfterVisitor.getVars();
    VarTypeMap usedOutputVars = set_intersection(outputVars, usedAfterVars);
    bool error = false;
    if (usedOutputVars.size() != 1) {
      errs() << "Error: none or more than one changed variables: \n";
      error = true;
    }

    // Add output var to input if it is not initialized. Guarded because the
    // set is empty when no changed variable is used after the loop, and the
    // former unconditional *begin() was undefined behaviour in that case.
    if (!usedOutputVars.empty()) {
      const auto& out = *usedOutputVars.begin();
      if (inputVars.find(out) == inputVars.end() &&
          initializedVars.find(out) == initializedVars.end())
        inputVars.insert(out);
    }

    const vector<string> params = getCurrentFunctionParams();
    errs() << "Input vars: "
           << makeArgList(inputVars, params, IdentityStringTransformer(), EmptyStringTransformer()) << "\n";
    errs() << "Initialized vars: "
           << makeArgList(initializedVars, params, IdentityStringTransformer(), EmptyStringTransformer()) << "\n";
    errs() << "Output vars: "
           << makeArgList(outputVars, params, IdentityStringTransformer(), EmptyStringTransformer()) << "\n";
    errs() << "Used after vars: "
           << makeArgList(usedAfterVars, params, IdentityStringTransformer(), EmptyStringTransformer()) << "\n";

    if (error)
      return false;  // abort the traversal

    // Number the new function after the __UF_ functions already present.
    int count_uf_functions = 0;
    std::string uf_prefix = "__UF_";
    for (const std::string& name : _functions)
      if (name.compare(0, uf_prefix.size(), uf_prefix) == 0)
        count_uf_functions++;

    string recFunName = uf_prefix + "_" + std::to_string(inputVars.size()) + "_" + std::to_string(count_uf_functions);

    std::string retType = makeArgList(usedOutputVars, params, EmptyStringTransformer(), IdentityStringTransformer());
    std::string recFunProto = retType + " " + recFunName + "(";
    recFunProto += makeArgList(inputVars, params, EmptyStringTransformer(), IdentityStringTransformer());
    recFunProto += ")";
    _preSource += recFunProto + ";";

    std::string var = usedOutputVars.begin()->first;
    string while_rec_call = var + " = " + recFunName + "(" +
      makeArgList(inputVars, params, IdentityStringTransformer(), EmptyStringTransformer()) + ");";

    std::string cond_text = getSourceText(cond, astContext);
    std::string body_text = getSourceText(body, astContext);
    body_text = stripBraces(body_text);
    // A single-statement body such as `while (c) i++;` has no ';' inside the
    // statement's source range; add one so the call is not glued onto it.
    if (!body_text.empty() && body_text.back() != ';' && body_text.back() != '}')
      body_text += ";";

    std::string exitCond = "if (" + cond_text + ")";
    std::string code;
    if (ws != nullptr) {
      code = exitCond + "{" + "\n" +
        body_text + "\n" +
        while_rec_call + "\n" +
        "}";
    } else {
      // A do-while body runs once before the condition is tested. The former
      // code emitted the while form here, skipping that first iteration
      // whenever the condition was initially false.
      code = "{\n" + body_text + "\n}\n" +
        exitCond + "{" + "\n" +
        while_rec_call + "\n" +
        "}";
    }

    rewriter.ReplaceText(s->getLocStart(), rewriter.getRangeSize(s->getSourceRange()), code);

    _rewritten = true;
    return true;
  }

  bool rewritten() {
    return _rewritten;
  }
};

// Set once any pass has rewritten the input. main() returns it as the exit
// status; once it is set, RecurserASTConsumer skips the return-elimination
// and recursion passes.
bool rewritten = false;
class RecurserASTConsumer : public ASTConsumer {
private:
  ForElimVisitor *forElimVisitor;
  ReturnInLoopElimVisitor* returnElimVisitor;
  RecurserVisitor *recurserVisitor;

public:
    // override the constructor in order to pass CI
  explicit RecurserASTConsumer(CompilerInstance *CI)
    : recurserVisitor(new RecurserVisitor(CI)),
      forElimVisitor(new ForElimVisitor(CI)),
      returnElimVisitor(new ReturnInLoopElimVisitor(CI))
  {
    rewriter.setSourceMgr(CI->getASTContext().getSourceManager(), CI->getASTContext().getLangOpts());
  }

  // override this to call our RecurserVisitor on each top-level Decl
  virtual bool HandleTopLevelDecl(DeclGroupRef DG) {
    // a DeclGroupRef may have multiple Decls, so we iterate through each one
    for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) {
      Decl *D = *i;
      forElimVisitor->TraverseDecl(D); // recursively visit each AST node in Decl "D"
      if (forElimVisitor->rewritten()) {
	errs() << "For elimination rewritten\n";
	rewritten = true;
      }
    }

    if (!rewritten) {
      // a DeclGroupRef may have multiple Decls, so we iterate through each one
      for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) {
	Decl *D = *i;    
	returnElimVisitor->TraverseDecl(D); // recursively visit each AST node in Decl "D"
      }      
      if (returnElimVisitor->rewritten()) {
	errs() << "Return eliminiation rewritten\n";
	rewritten = true;
      }
    }
    
    if (!rewritten) {
      // a DeclGroupRef may have multiple Decls, so we iterate through each one
      for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) {
	Decl *D = *i;
	recurserVisitor->TraverseDecl(D); // recursively visit each AST node in Decl "D"
      }
      if (recurserVisitor->rewritten()) {
	errs() << "Recurser rewritten\n";
	rewritten = true;
      }
    }
    return true;
  }
};



// Frontend action that attaches a RecurserASTConsumer to each translation
// unit the tool processes.
class RecurserFrontendAction : public ASTFrontendAction {
public:
  virtual unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, StringRef file) {
    // The consumer and its visitors use CI to reach the AST context and the
    // source manager.
    unique_ptr<ASTConsumer> consumer(new RecurserASTConsumer(&CI));
    return consumer;
  }
};


static cl::OptionCategory RecurserCategory("RecurserTool options");

int main(int argc, const char **argv) {
  // parse the command-line args passed to your code
  CommonOptionsParser op(argc, argv, RecurserCategory);
  // create a new Clang Tool instance (a LibTooling environment)
  ClangTool Tool(op.getCompilations(), op.getSourcePathList());

  srand(time(NULL));
  // run the Clang Tool, creating a new FrontendAction (explained below)
  int result = Tool.run(newFrontendActionFactory<RecurserFrontendAction>().get());

  // print out the rewritten source code ("rewriter" is a global var.)
  outs() << _preSource << "\n";
  rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(outs());
  outs() << _postSource << "\n";

  return rewritten;
}
