#include "clang/Driver/Options.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Lex/Lexer.h"
#include <ctime>
#include <iostream>

using namespace std;
using namespace clang;
using namespace clang::driver;
using namespace clang::tooling;
using namespace llvm;

// Shared rewriter. The visitors below bind it to the current source manager
// and record their edits in it; the edited main-file buffer is what the tool
// eventually prints.
Rewriter rewriter;

// Set of (string, string) pairs. Not referenced by the code visible in this file.
typedef std::set< std::pair<std::string, std::string> > VarTypeMap;

// Name of the array parameter appended to the generated check_main()
// prototype; values passed to printf are stored into it.
const std::string RETURN_ARRAY = "ret_array";

// Text printed before the rewritten source (empty unless assigned elsewhere).
std::string _preSource;
// Text printed after the rewritten source: an #include of "../check.h".
std::string _postSource = "\n#include \"../check.h\"\n";

// Returns true when the whole of `s` is a base-10 integer as accepted by
// strtol (optional leading whitespace, optional sign, then digits).
// Returns false for the empty string and for any string with trailing
// non-numeric characters.
bool isNumber(const std::string& s) {
  const char* begin = s.c_str();
  char* end = nullptr;
  strtol(begin, &end, 10);
  // end == begin means strtol consumed nothing; without this check an empty
  // string would pass because it already points at the terminator.
  return end != begin && *end == '\0';
}

// Counts the occurrences of `needle` in `haystack`. The search resumes one
// character after each match, so overlapping matches are all counted.
int count(const std::string& needle, const std::string& haystack) {
  int hits = 0;
  for (std::size_t at = haystack.find(needle);
       at != std::string::npos;
       at = haystack.find(needle, at + 1)) {
    ++hits;
  }
  return hits;
}

// Finds the first decimal number embedded in `str`.
//
// The number starts at the first digit (or at a '-' immediately preceding
// it), runs over the following digits and, if a '.' comes next, over the
// digits after that '.'. On success the text of the number is stored in
// `num` and its start index is returned. When `str` has no digit, -1 is
// returned and `num` is left untouched.
int containsNum(const std::string& str, std::string& num) {
  static const char* const kDigits = "0123456789";

  std::size_t first = str.find_first_of(kDigits);
  if (first == std::string::npos)
    return -1;

  // A minus sign right in front of the first digit belongs to the number.
  if (first > 0 && str[first - 1] == '-')
    --first;

  // Skip the integer part, then an optional '.' followed by more digits.
  std::size_t stop = str.find_first_not_of(kDigits, first + 1);
  const bool hasPoint =
      stop != std::string::npos && stop < str.length() && str[stop] == '.';
  if (hasPoint)
    stop = str.find_first_not_of(kDigits, stop + 1);

  if (stop == std::string::npos)
    num = str.substr(first);            // number runs to the end of the string
  else
    num = str.substr(first, stop - first);
  return first;
}

// Returns the source-like text of argument `i` of call `e`, produced by
// pretty-printing the argument expression with a C++ printing policy.
std::string getArg(CallExpr* e, int i) {
  clang::LangOptions langOptions;
  langOptions.CPlusPlus = true;
  clang::PrintingPolicy printingPolicy(langOptions);

  std::string text;
  llvm::raw_string_ostream stream(text);
  // No printer helper is supplied.
  e->getArg(i)->printPretty(stream, nullptr, printingPolicy);
  return stream.str();
}

typedef std::vector< std::pair<std::string, VarDecl*> > VarVector;
class Variables {
public:
  
  void insert(const std::string& s, VarDecl* v) {
    _vars.push_back(std::make_pair(s, v));
  }

  bool contains(const std::string& s) {
    for (auto it = _vars.begin(); it != _vars.end(); it++)
      if (it->first == s)
	return true;
    return false;
  }

  VarVector::const_iterator begin() {
    return _vars.begin();
  }

  VarVector::const_iterator end() {
    return _vars.end();
  }

private:
  VarVector _vars;
};

// First pass: collects the variables whose address is passed to scanf().
//
// Every VarDecl seen during traversal is remembered by name; every
// `&name` argument of a scanf call whose name matches a remembered
// declaration is recorded (once) in the list returned by getVars().
class ScanfVarVisitor : public  RecursiveASTVisitor<ScanfVarVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  std::map<std::string, VarDecl*> _vars;   // every variable declaration seen so far, by name
  Variables _scanf_vars;                   // variables read through scanf, in order of first use
public:
  explicit ScanfVarVisitor(CompilerInstance *CI)
    : astContext(&(CI->getASTContext()))
  {
    rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
  }

  // Remembers each variable declaration under its name (a later declaration
  // with the same name replaces the earlier one).
  virtual bool VisitDecl(Decl* d) {
    if (VarDecl* v = dyn_cast<VarDecl>(d)) {
      _vars[v->getNameAsString()] = v;
    }
    return true;
  }

  // Records the `&var` arguments of scanf calls. Always returns true so the
  // traversal continues.
  virtual bool VisitCallExpr(CallExpr* e) {
    // getDirectCallee() is null for calls that do not name a function
    // directly (e.g. through a function pointer).
    const FunctionDecl* callee = e->getDirectCallee();
    if (callee == nullptr || callee->getNameInfo().getAsString() != "scanf")
      return true;

    // Argument 0 is the format string; the rest are the destinations.
    for (unsigned i = 1; i < e->getNumArgs(); ++i) {
      const std::string arg = getArg(e, i);
      if (arg.empty() || arg[0] != '&')
        continue;
      const std::string name = arg.substr(1);

      // Only plain variables with a known declaration can become parameters;
      // use find() so unknown names do not create null map entries.
      std::map<std::string, VarDecl*>::const_iterator known = _vars.find(name);
      if (known == _vars.end())
        continue;
      // A variable read by several scanf calls must be listed only once.
      if (_scanf_vars.contains(name))
        continue;
      _scanf_vars.insert(name, known->second);
    }
    return true;
  }

  // Variables collected so far.
  const Variables& getVars() {
    return _scanf_vars;
  }
};

// Second pass: rewrites the program so that main() becomes
// check_main(<scanf variables>, double ret_array[]).
//
//  * declarations of the scanf variables are removed from local DeclStmts
//    (they become parameters),
//  * scanf calls are commented out,
//  * after each printf call, the printed arguments are stored into
//    consecutive slots of the return array.
//
// All edits are recorded in the global `rewriter`.
class CheckMainVisitor : public RecursiveASTVisitor<CheckMainVisitor> {
private:
  ASTContext *astContext; // used for getting additional AST info
  Variables _vars;        // variables that become parameters of check_main

  clang::SourceLocation getLocForEndOfToken(clang::SourceLocation curLoc) {
    return clang::Lexer::getLocForEndOfToken(curLoc, 0, astContext->getSourceManager(), astContext->getLangOpts());
  }

  // Counts the '%' conversions in `fmt` that consume an argument. A literal
  // "%%" prints a percent sign and consumes nothing, so it is skipped.
  static unsigned countFormatArgs(const std::string& fmt) {
    unsigned conversions = 0;
    for (std::size_t i = 0; i < fmt.size(); ++i) {
      if (fmt[i] != '%')
        continue;
      if (i + 1 < fmt.size() && fmt[i + 1] == '%') {
        ++i;  // skip the second '%' of "%%"
        continue;
      }
      ++conversions;
    }
    return conversions;
  }

public:
  explicit CheckMainVisitor(CompilerInstance *CI)
    : astContext(&(CI->getASTContext()))
  {
    rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
  }

  // Replaces the list of variables to turn into parameters.
  void setVars(const Variables& vars) {
    _vars = vars;
  }

  // True for characters that may appear in an identifier.
  bool isidchar(char c) {
    // <cctype> functions require a value representable as unsigned char.
    const unsigned char uc = static_cast<unsigned char>(c);
    return isalpha(uc) || isdigit(uc) || c == '_';
  }

  // True when the nearest non-whitespace character before index `pos` in
  // `s` is `c`.
  bool followsChar(std::string s, int pos, char c) {
    do {
      pos--;
    } while (pos >= 0 && isspace(static_cast<unsigned char>(s[pos])));
    return pos >= 0 && s[pos] == c;
  }

  // Removes the declarator named `var` from the declaration text `s`.
  //
  // The first occurrence of `var` that is a whole identifier and is not
  // directly preceded by '=' (i.e. not used as an initializer) is located.
  // If none exists, `s` is returned unchanged. Otherwise:
  //  * if a comma precedes it, the text from that comma through the end of
  //    the name is erased;
  //  * else if a comma follows it, the text from the name through that
  //    comma is erased;
  //  * else it is the only declarator and the empty string is returned.
  std::string removeVar(std::string s, std::string var) {
    std::size_t pos = s.find(var);
    do {
      if (pos == std::string::npos)
        return s;
      bool startWord = pos == 0 || !isidchar(s[pos-1]);

      // avoid variables in initialization
      if (followsChar(s, pos, '='))
        startWord = false;

      std::size_t afterPos = pos + var.size();
      bool endWord = afterPos == s.size() || (afterPos < s.size() && !isidchar(s[afterPos]));

      if (startWord && endWord) break;
      pos = s.find(var, pos+1);
    } while (true);

    std::size_t commaBefore = s.rfind(",", pos);
    std::size_t commaAfter = s.find(",", pos);
    if (commaBefore != std::string::npos) {
      s.erase(commaBefore, pos + var.size() - commaBefore);
      return s;
    } else if (commaAfter != std::string::npos) {
      s.erase(pos, commaAfter - pos + 1);
      return s;
    } else
      return "";
  }

  // For every declaration statement directly inside a compound statement,
  // rewrites its text with the parameter variables removed.
  virtual bool VisitStmt(Stmt* stmt) {
    // Remove _var declaration
    if (const CompoundStmt *cstmt = dyn_cast<CompoundStmt>(stmt)) {
      for(auto substmt = cstmt->body_begin(); substmt != cstmt->body_end(); ++substmt) {
        if(const DeclStmt *dstmt = dyn_cast<DeclStmt>(*substmt)) {
          const DeclGroupRef DGR = dstmt->getDeclGroup();
          // From the start of the first declared entity to the end of the
          // token that ends the last one.
          SourceRange range;
          range.setBegin((*DGR.begin())->getLocStart());
          range.setEnd(getLocForEndOfToken((*(DGR.end() - 1))->getLocEnd()));
          std::string decl_text = clang::Lexer::getSourceText(CharSourceRange::getTokenRange(range), astContext->getSourceManager(), astContext->getLangOpts(), 0);
          for (auto it = _vars.begin(); it != _vars.end(); ++it) {
            decl_text = removeVar(decl_text, it->first);
          }

          rewriter.ReplaceText(range.getBegin(), rewriter.getRangeSize(CharSourceRange(range, true)), decl_text);
        }
      }
    }
    return true;
  }

  // Replaces the header of the definition of main() (everything before its
  // body) with the check_main prototype.
  virtual bool VisitDecl(Decl* d) {
    FunctionDecl* f = dyn_cast<FunctionDecl>(d);
    // Only the definition is rewritten: a bodiless declaration of main has
    // no body location to anchor the header range on.
    if (f == nullptr || !f->isMain() || !f->doesThisDeclarationHaveABody())
      return true;

    SourceRange range = f->getSourceRange();
    range.setEnd(f->getBody()->getLocStart().getLocWithOffset(-1));

    std::string proto = "int check_main(";
    for (auto it = _vars.begin(); it != _vars.end(); ++it) {
      // Entries without a declaration cannot be given a type; skip them.
      if (it->second == nullptr)
        continue;
      // Each variable is followed by ", " because the return array always
      // comes last; with no variables the list is just the array.
      proto += it->second->getType().getAsString() + " " + it->first + ", ";
    }
    proto += "double " + RETURN_ARRAY + "[]";
    proto += ")";
    rewriter.ReplaceText(range.getBegin(), rewriter.getRangeSize(CharSourceRange(range, true)), proto);

    return true;
  }

  // Wraps the statement containing `e` in braces and appends
  // "ret_array[k]=<arg>;" after it, where k is a counter that increases with
  // every call to this function.
  void addToReturnArray(Expr* e, std::string arg) {
    static int i = 0;
    rewriter.InsertText(e->getLocStart(), "{");
    std::string stmt = RETURN_ARRAY + "[" + std::to_string(i++) + "]" + "=" + arg + ";";
    // Offset 1 past the end of the expression's last token, so the inserted
    // text lands after the character that follows the call.
    rewriter.InsertText(getLocForEndOfToken(e->getLocEnd()).getLocWithOffset(1), "\n" + stmt + "}");
  }

  // Comments out scanf calls and records the values printed by printf calls.
  virtual bool VisitCallExpr(CallExpr* e) {
    // getDirectCallee() is null for calls that do not name a function
    // directly (e.g. through a function pointer); nothing to rewrite then.
    const FunctionDecl* callee = e->getDirectCallee();
    if (callee == nullptr)
      return true;
    const std::string name = callee->getNameInfo().getAsString();

    // Remove scanf
    if (name == "scanf") {
      rewriter.InsertText(e->getLocStart(), "// ");
      return true;
    }

    // Add return after printf number
    if (name == "printf" && e->getNumArgs() > 0 &&
        e->getArg(0)->getType().getAsString() == "const char *") {
      // The quotes and escape sequences of the printed literal contain no
      // '%', so the printed text can be scanned as it is.
      unsigned wanted = countFormatArgs(getArg(e, 0));
      // Never index past the arguments actually supplied to the call.
      const unsigned available = e->getNumArgs() - 1;
      if (wanted > available)
        wanted = available;
      for (unsigned k = 1; k <= wanted; ++k)
        addToReturnArray(e, getArg(e, k));
    }
    return true;
  }
};

class CheckMainASTConsumer : public ASTConsumer {
private:
  CheckMainVisitor *check_main_visitor; // doesn't have to be private
  ScanfVarVisitor *scanf_var_visitor; // doesn't have to be private

public:
    // override the constructor in order to pass CI
  explicit CheckMainASTConsumer(CompilerInstance *CI)
    : check_main_visitor(new CheckMainVisitor(CI)),
      scanf_var_visitor(new ScanfVarVisitor(CI))
  { }

  // override this to call our RecurserVisitor on each top-level Decl
  virtual bool HandleTopLevelDecl(DeclGroupRef DG) {
    // a DeclGroupRef may have multiple Decls, so we iterate through each one
    for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) {
      Decl *D = *i;
      scanf_var_visitor->TraverseDecl(D);
      check_main_visitor->setVars(scanf_var_visitor->getVars());
      check_main_visitor->TraverseDecl(D); // recursively visit each AST node in Decl "D"
    }
    return true;
  }  
};



class CheckMainFrontendAction : public ASTFrontendAction {
public:
  virtual unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, StringRef file) {
    return unique_ptr<ASTConsumer>(new CheckMainASTConsumer(&CI)); // pass CI pointer to ASTConsumer
  }
};


static cl::OptionCategory CheckMainCategory("CheckMainTool options");

int main(int argc, const char **argv) {
  srand(time(NULL));
  
  // parse the command-line args passed to your code
  CommonOptionsParser op(argc, argv, CheckMainCategory);
  // create a new Clang Tool instance (a LibTooling environment)
  ClangTool Tool(op.getCompilations(), op.getSourcePathList());

  // run the Clang Tool, creating a new FrontendAction (explained below)
  int result = Tool.run(newFrontendActionFactory<CheckMainFrontendAction>().get());

  // print out the rewritten source code ("rewriter" is a global var.)
  outs() << _preSource << "\n";
  rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(outs());
  outs() << _postSource << "\n";
  return result;
}
