/***************************************************************************
  ArgoSat
  Copyright (C) 2007 Filip Maric, Predrag Janicic

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License version 2
  as published by the Free Software Foundation.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
-----------------------------------------------------------------------------
  This program is inspired by MiniSat solver (C) Een, Sorensson 2003-2006.
  It uses Bliss (C) Tommi Junttila
*****************************************************************************/

#include "Clause.h"
#include "Solver.h"
#include "SolverStatistics.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
using std::cerr;
using std::endl;

#include "expressions/Expression.h"
#include "theory/Theory.h"

// extern std::ostream& logger;

void printExplanationLemma(const std::vector<Literal>& literals) {
    coutput << "lemma \"(" << endl;
    std::vector<Literal>::const_iterator it;
    for (it = literals.begin(); it != literals.end(); it++) {
	coutput << Literals::getExpression(Literals::getOpposite(*it));
	if (it != literals.end() - 1)
	    coutput << " &";
	coutput << endl;
    }
    coutput << ") --> False\"" << endl;
    cout << "apply auto" << endl << "done" << endl << endl;
}


void Solver::printClauses() {
    for (size_t i = 0; i < _initialClauses.size(); i++)
	cout << _initialClauses[i]->toString() << endl;
}

/**
 * Returns a pseudo-random boolean from rand(): true exactly when the drawn
 * value lies strictly above RAND_MAX / 2.
 */
bool random_boolean() {
    const int threshold = RAND_MAX / 2;
    const int draw = rand();
    return draw > threshold;
}

/**
 * Constructs an empty solver with the default heuristic parameters:
 *  - satisfiability status UNDEF and no conflict clause;
 *  - variable activity initialised with (1.0, 1/0.95). Presumably these are the
 *    MiniSat-style initial bump amount and its growth factor; confirm in the
 *    activity class;
 *  - negative decision polarity, max-activity variable selection;
 *  - the forget threshold grows by 1.1x after each forget;
 *  - clause activity initialised with (1.0, 1/0.999). The same caveat as for
 *    variable activity applies;
 *  - first restart after 100 conflicts, restart threshold grows by 1.5x.
 * NOTE(review): keep this initializer order in sync with the member
 * declaration order in Solver.h (not visible here) to avoid -Wreorder.
 */
Solver::Solver() 
    : 
    _satisfiable(UNDEF),
    _conflictClause(0),
    _backjumpClause(this),

    _variableActivity(1.0, 1/0.95), 
    _literalSelectionPolarity(LITERAL_SELECTION_POLARITY_FALSE),
    _variableSelectionStrategy(VARIABLE_SELECTION_MAX_ACTIVE),

    _forgetIncreaseFactor(1.1f),
    _clauseActivity(1.0, 1/0.999),

    _numConflictsForFirstRestart(100),
    _restartIncreaseFactor(1.5f) {
}

/**
 * Releases every clause still owned by the solver: first all learnt
 * clauses, then all initial clauses. Each is detached from the watch lists
 * and deleted via removeClause().
 */
Solver::~Solver() {
    for (size_t k = 0; k < _learntClauses.size(); ++k)
	removeClause(_learntClauses[k]);

    for (size_t k = 0; k < _initialClauses.size(); ++k)
	removeClause(_initialClauses[k]);
}

/****************************************************
 *      SOLVE
 *****************************************************/

/**
 * Runs the main SAT-modulo-theory search loop until _satisfiable is no
 * longer UNDEF.
 *
 * Each iteration:
 *  1. Exhaust unit propagation (stop early on a conflict).
 *  2. On a conflict, either report FALSE at decision level 0, or backjump
 *     and learn the backjump clause. Activities decay after each conflict.
 *  3. Otherwise possibly restart (conflict-count based) and/or forget
 *     learnt clauses (clause-count based).
 *  4. If some decision variable is unassigned, try theory propagation.
 *     Failing that, ask the theory for a check and either add a theory
 *     lemma or make a decision.
 *  5. If everything is assigned, run a final theory check. If it passes,
 *     set TRUE, store the model, run verifyModel() and backtrack to level 0.
 *
 * NOTE(review): forceCheck(false) / forceCheck(true) are presumably a cheap
 * partial check and a complete check respectively — confirm in Theory.
 * NOTE(review): the boolean result of verifyModel() is ignored here.
 */
void Solver::solve() {
    Theory* theory = Theory::getInstance();
    theory->init();

    assert(getCurrentDecisionLevel() == 0);
//    logger << "Starting solver: " << endl
//	 << "\t" << getNumberOfInitialClauses() << " initial clauses, " << endl
//	 << "\t" << getNumberOfVariables() << " variables" << endl;

//    SolverStatistics::printHeader();
//    SolverStatistics::printStatistics();

    // Top-level literals may already be asserted (unit initial clauses), so
    // clauses they satisfy can be dropped before search starts.
    if (!_assertionTrail.empty())
	forgetSatisfiedClauses();

    _numConflictsSinceLastRestart = 0;
    _numConflictsForRestart = _numConflictsForFirstRestart;

    // Forgetting starts once the clause database reaches 3x the initial size.
    _numClausesForForget = (size_t)(3 * getNumberOfInitialClauses());

    while(_satisfiable == UNDEF) {
	while (applicableUnitPropagate() && !isConflicting()) {
            SolverStatistics::newPropagation();
	    applyUnitPropagate();
	}

	if (isConflicting()) {
	    Clause* conflictClause = getConflictClause();

//	    cout << _assertionTrail.toString() << endl;
//	    cout << "Conflict: " << conflictClause->toString() << endl;

	    bumpClauseActivity(conflictClause);
	    SolverStatistics::newConflict();
	    _numConflictsSinceLastRestart++;

	    // Makes solver online
	    // Pop trailing literals that play no part in the conflict, so the
	    // last asserted literal is one whose negation occurs in the conflict
	    // clause. Conflict analysis starts from getLastAssertedLiteral().
	    while(!conflictClause->containsLiteral(Literals::getOpposite(getLastAssertedLiteral())))
		backtrackLiteral();

//	    cout << "ONLINE: " << _assertionTrail.toString() << endl;

	    if (getCurrentDecisionLevel() == 0) {
//		cout << "Formula found UNSAT" << endl;
		// A conflict that depends only on level-0 assignments cannot be undone.
		_satisfiable = FALSE;
	    } else {
		applyBackjump();
		applyLearn(_backjumpClause.getClause());
	    }
	    
	    _clauseActivity.decay();
	    _variableActivity.decay();
	} else {
	    // Restart policy: the threshold grows geometrically by _restartIncreaseFactor.
	    if (_numConflictsSinceLastRestart >= _numConflictsForRestart) {
		applyRestart();
		forgetSatisfiedClauses();
		SolverStatistics::newRestart();
//		SolverStatistics::printStatistics();

		_numConflictsSinceLastRestart = 0;
		_numConflictsForRestart = (size_t)(_numConflictsForRestart * _restartIncreaseFactor);

	    }

	    // Forget policy: the threshold grows geometrically by _forgetIncreaseFactor.
	    if (getTotalNumberOfClauses() >= _numClausesForForget) {
		applyForget();
		_numClausesForForget = (size_t)(_numClausesForForget * _forgetIncreaseFactor);
	    }
	    
	    if (!allVariablesAssigned()) {
		// Theory propagation (if any) enqueues a unit clause; otherwise decide.
		if (!applyTheoryPropagate()) {
		    SolverStatistics::newDecision();

		    theory->forceCheck(false);
		    if (theory->isUnsat()) {
			// The lemma becomes the conflict clause for the next iteration.
			addTheoryLemma();
			continue;
		    }

		    applyDecide();
		}
	    } else {
		theory->forceCheck(true);
		if (theory->isUnsat()) {
		    addTheoryLemma();
		    continue;
		}


//		logger << "Formula found SAT" << endl;
		_satisfiable = TRUE;
		generateModel();
		verifyModel();
		backtrackToLevel(0);
	    }
	}
    }
//    SolverStatistics::printStatistics();
//    SolverStatistics::printFooter();
}

/****************************************************
 *       VARIABLE SET
 ****************************************************/
/**
 * Registers a fresh variable and grows the literal-indexed tables to match.
 * Two entries are appended to _reason and two to _watchClauses, one per
 * literal of the new variable. Decision variables enter the activity heap
 * via push_heap; the others are registered via push_inactive.
 *
 * @param useAsDecisionVariable whether the solver may branch on this variable.
 * @return the newly created variable.
 */
Variable Solver::newVariable(bool useAsDecisionVariable) {
    Variable fresh = Variables::newVariable(useAsDecisionVariable);

    for (int side = 0; side < 2; ++side)
	_reason.push_back(0);
    for (int side = 0; side < 2; ++side)
	_watchClauses.push_back(std::vector<Clause*>());

    if (!useAsDecisionVariable) {
	_variableActivityHeap.push_inactive(fresh);
    } else {
	_variableActivityHeap.push_heap(fresh);
    }

    return fresh;
}

/****************************************************
 *       MODEL
 ****************************************************/
void Solver::generateModel() {
    assert(isSatisfiable());
    _model.clear();

//    for (Variable variable = 0; variable < Variables::size(); variable++) {
//	_model.push_back(Variables::getValue(variable) == TRUE ?
//			 Literals::getLiteral(variable, true) :
//			 Literals::getLiteral(variable, false));
//    }
    BacktrackableStack<Literal>::const_iterator it;
    for (it = _assertionTrail.begin(); it != _assertionTrail.end(); it++) {
	_model.push_back(*it);
//	if (Literals::getExpression(*it) != Expression())
//	    coutput << Literals::getExpression(*it) << endl;
    }

    Theory* theory = Theory::getInstance();
    if (!theory->generateModel()) {
	cout << "Couldn't genrate theory model" << endl;
	exit(EXIT_FAILURE);
    }
}

bool Solver::verifyModel() const {
//    cout << _assertionTrail.toString() << endl;
    const std::vector<Literal>& model = getModel();
    std::vector<Clause*>::const_iterator it, 
	beg = _initialClauses.begin(), en = _initialClauses.end();
    for (it = beg; it != en; it++)
	if (!(*it)->isTrue(model)) {
	    cerr << "Clause: " << (*it)->toString() << " is not true!" << endl;
	    return false;
	}
//    cout << "Verified: " << getNumberOfInitialClauses() << " initial clauses" << endl;

    size_t verifiedAtomsNum = 0;
    Theory* theory = Theory::getInstance();
    std::vector<Literal>::const_iterator mit;
    for (mit = model.begin(); mit != model.end(); mit++) {
	Expression expr = Literals::getExpression(*mit);
	if (expr != Expression()) {
	    verifiedAtomsNum++;
	    if (!theory->checkAgainstModel(expr)) {
		coutput << "Error: " << expr << " is false in the model" << endl;
		exit(EXIT_FAILURE);
	    }
	}
    }
//    cout << "Verified: " << verifiedAtomsNum << " theory atoms" << endl;

    return true;
}


/****************************************************
 *      INITIAL CLAUSE SET
 ****************************************************/
/**
 * Adds an input clause at decision level 0.
 *
 * - A clause already false at level 0 makes the formula unsatisfiable.
 * - A clause already true is skipped.
 * - A single-literal clause is asserted directly on the trail.
 * - Any other clause is stored and watched.
 *
 * Fixes:
 * - Adding a multi-literal clause no longer resets a previously established
 *   FALSE status back to UNDEF. Unsatisfiability is permanent.
 * - A unit clause added after a TRUE result reopens the search (UNDEF), since
 *   the new literal may invalidate the previous model. This mirrors what the
 *   multi-literal branch already did.
 *
 * NOTE(review): in the first three cases the clause is neither stored nor
 * deleted here. If ownership is transferred to the solver (the destructor
 * deletes stored clauses), these clauses leak. Confirm the caller's
 * expectations before deleting them.
 */
void Solver::addInitialClause(Clause* clause) {
    assert(getCurrentDecisionLevel() == 0);
    assert(clause->isInitial());

    SolverStatistics::newInitialClause();
//    cout << "Adding clause: " << clause->toString() << endl;

    if (clause->isFalse()) {
//	logger << "Top Level FALSE clause: " << clause->toString() << endl;
	_satisfiable = FALSE;
    } else if (clause->isTrue()) {
//	logger << "Top Level TRUE clause, skipping: " << clause->toString() << endl;
    } else if (clause->size() == 1) {
//	logger << "Single literal clause: " << clause->toString() << endl;
	if (_satisfiable != FALSE)
	    _satisfiable = UNDEF;
	assertLiteral((*clause)[0]);
    } else {
	// Never downgrade a proven FALSE back to UNDEF.
	if (_satisfiable != FALSE)
	    _satisfiable = UNDEF;
	_initialClauses.push_back(clause);
	initializeWatches(clause);
    }
}

/** @return the number of initial clauses currently stored by the solver. */
size_t Solver::getNumberOfInitialClauses() const {
    const size_t stored = _initialClauses.size();
    return stored;
}

/**
 * @return whether the clause is flagged as initial. In debug builds this
 * also asserts that every clause flagged initial is actually present in
 * _initialClauses.
 */
bool Solver::isInitialClause(Clause* clause) const {
    const bool initial = clause->isInitial();
    assert(!initial ||
	   std::find(_initialClauses.begin(), _initialClauses.end(), clause) !=
	   _initialClauses.end());
    return initial;
}


/****************************************************
 *       ASSERTION TRAIL  -  M
 ****************************************************/
/**
 * Makes literal l true at the current decision level.
 *
 * The steps are:
 *  1. Push l onto the assertion trail and mark it true at the current level.
 *  2. Forward its theory atom (if it has one) to the theory.
 *  3. If the theory now reports inconsistency, add a theory lemma, which
 *     sets the conflict clause. Otherwise notify the clauses watching the
 *     now-false negation of l.
 *
 * NOTE(review): when the theory is unsat, the watch lists for the negation of
 * l are not updated. theory->isUnsat() is also consulted when l has no theory
 * atom.
 */
void Solver::assertLiteral(Literal l) {
//    cout << "Assert: " << Literals::toString(l) << endl;
    _assertionTrail.push(l);
//    cout << _assertionTrail.toString() << endl;


    int currentDecisionLevel = getCurrentDecisionLevel();
    Literals::setTrue(l, currentDecisionLevel);

    Theory* theory = Theory::getInstance();
    Expression expr(Literals::getExpression(l));
    // A default-constructed Expression means "no theory atom" for this literal.
    if (expr != Expression()) {
	theory->assertExpression(expr);
    }

    if (theory->isUnsat()) {
	addTheoryLemma();
    } else 
	notifyWatchClauses(Literals::getOpposite(l));
}

/** @return the most recently asserted literal (top of the assertion trail). */
Literal Solver::getLastAssertedLiteral() {
    const Literal top = _assertionTrail.back();
    return top;
}

/** @return the current decision level, as tracked by the assertion trail. */
size_t Solver::getCurrentDecisionLevel() {
    const size_t level = _assertionTrail.getLevel();
    return level;
}

/** @return the decision level recorded for the given literal by Literals. */
size_t Solver::getLiteralsDecisionLevel(Literal literal) {
    const size_t level = Literals::getDecisionLevel(literal);
    return level;
}

/**
 * Pops the last literal from the assertion trail and undoes its effects.
 *
 * The steps are:
 *  1. If the pop leaves the top decision level empty, retract that level
 *     and tell the theory to backtrack one decision.
 *  2. Unassign the literal.
 *  3. Retract its theory atom (if any) from the theory.
 *  4. If its variable is a decision variable no longer in the activity heap,
 *     re-insert the variable into the heap.
 *
 * @return the literal that was removed from the trail.
 */
Literal Solver::backtrackLiteral() {
    Literal literal = _assertionTrail.pop();
    if (_assertionTrail.lastLevelEmpty()) {
	_assertionTrail.retractLevel();
	Theory::getInstance()->backtrackDecision();
    }
    
    Literals::setUnassigned(literal);
    Expression expr = Literals::getExpression(literal);
    if(expr != Expression()) {
	Theory::getInstance()->backtrack();
    }

    // The variable becomes a candidate for future decisions again.
    Variable variable = Literals::getVariable(literal);
    if (!_variableActivityHeap.contains(variable) && 
	Variables::useAsDecisionVariable(variable))
	_variableActivityHeap.push_heap(variable);

    return literal;
}

/**
 * Pops literals until the decision level drops below the level at entry,
 * i.e. removes the whole current decision level.
 * @return the last literal popped, which is the decision literal of that level.
 */
Literal Solver::backtrackLevel() {
    const size_t startLevel = getCurrentDecisionLevel();
    Literal removed = backtrackLiteral();
    while (getCurrentDecisionLevel() == startLevel)
	removed = backtrackLiteral();
    return removed;
}

/** Removes whole decision levels until the current level is at most `level`. */
void Solver::backtrackToLevel(size_t level) {
    for (; level < getCurrentDecisionLevel(); )
	backtrackLevel();
}

/***************************************************
 *     CONFLICT CLAUSE DETECTION
 ***************************************************/
/**
 * Records the current conflict clause. Passing 0 clears the conflict.
 * In debug builds a non-null clause must be false.
 */
void Solver::setConflictClause(Clause* clause) {
    const bool valid = (clause == 0) || clause->isFalse();
    assert(valid);
    (void)valid;
    _conflictClause = clause;
}


/*****************************************************
 *         UNIT PROPAGATE
 *****************************************************/
bool Solver::applicableUnitPropagate() {
    return getUnitClause() != 0;
}

void Solver::applyUnitPropagate() {
    Clause* unitClause = getUnitClause();
    Literal unitLiteral = getUnitLiteral(unitClause);
//  cout << unitClause->toString() << " -> " << Literals::toString(unitLiteral) << endl;
    if (!Literals::isTrue(unitLiteral)) {
	setReason(unitLiteral, unitClause);
	assertLiteral(unitLiteral);
    }

    findNextUnitClause();
}

/**
 * @return the literal implied by a unit clause, which this solver keeps in
 * watch position 0. In debug builds the clause must be unit, true or false.
 */
Literal Solver::getUnitLiteral(Clause* unitClause) {
    assert(unitClause->isUnit() || unitClause->isTrue() || unitClause->isFalse());
    const char impliedWatch = 0;
    return getWatch(impliedWatch, unitClause);
}

/** @return the next unit clause to propagate (queue back), or 0 if none. */
Clause* Solver::getUnitClause() {
    if (_unitClauseQueue.empty())
	return 0;
    return _unitClauseQueue.back();
}

/** Drops the unit clause just processed (the queue back), if any. */
void Solver::findNextUnitClause() {
    if (_unitClauseQueue.empty())
	return;
    _unitClauseQueue.pop_back();
}


/**
 * Schedules a unit clause for propagation. It is added at the front, while
 * processing takes from the back, so clauses are handled in FIFO order.
 * In debug builds the clause must be unit.
 */
void Solver::enqueueUnitClause(Clause* unitClause) {
    assert(unitClause->isUnit());
    _unitClauseQueue.push_front(unitClause);
}

void Solver::setWatch(char watchNumber, Clause* clause, size_t position) {
    assert(watchNumber == 0 || watchNumber == 1);
    clause->swapLiterals(watchNumber, position);
    addWatchClause((*clause)[watchNumber], clause);
}

/** @return the literal currently in watch slot `watchNumber` (0 or 1). */
Literal Solver::getWatch(char watchNumber, Clause* clause) {
    assert(watchNumber == 0 || watchNumber == 1);
    const Literal watched = (*clause)[watchNumber];
    return watched;
}

void Solver::removeWatch(char watchNumber, Clause* clause) {
    assert(watchNumber == 0 || watchNumber == 1);
    removeWatchClause((*clause)[watchNumber], clause);
}

/**
 * Exchanges the literals in watch slots 0 and 1. Watch lists are keyed by
 * literal, so they need no update.
 */
void Solver::swapWatches(Clause* clause) {
    const int firstSlot = 0;
    const int secondSlot = 1;
    clause->swapLiterals(firstSlot, secondSlot);
}

/**
 * Sets up watches for a newly added clause and reacts to its current state.
 *
 * Single-literal clauses get no watches:
 * - a false literal makes the clause the conflict clause and flushes the
 *   unit queue;
 * - an unassigned literal enqueues the clause for propagation.
 *
 * Longer clauses get watches on two literals, preferring unfalsified ones
 * and otherwise the falsified literal with the highest level. The clause is
 * enqueued when watch 1 is false and watch 0 is unassigned, which makes the
 * clause unit.
 *
 * NOTE(review): a multi-literal clause that is entirely false is not flagged
 * as a conflict here. Callers such as addTheoryLemma() set the conflict
 * themselves.
 */
void Solver::initializeWatches(Clause* clause) {
    assert(clause->size() > 0);

    if (clause->size() == 1) {
	Literal onlyLiteral = (*clause)[0];
	if (Literals::isFalse(onlyLiteral)) {
	    setConflictClause(clause);
	    _unitClauseQueue.clear();
	} else if (!Literals::isTrue(onlyLiteral)) {
	    enqueueUnitClause(clause);
	}
	return;
    }

    assert(clause->size() >= 2);

    findAndSetWatch(0, clause);
    findAndSetWatch(1, clause);

    const bool secondFalse = Literals::isFalse(getWatch(1, clause));
    Literal first = getWatch(0, clause);
    const bool firstUnassigned = !Literals::isTrue(first) && !Literals::isFalse(first);
    if (secondFalse && firstUnassigned)
	enqueueUnitClause(clause);
}

/**
 * Chooses and installs watch `watchNumber`, searching from that position on.
 * It picks the first unfalsified literal. If all remaining literals are false,
 * it picks the falsified literal with the highest decision level.
 */
void Solver::findAndSetWatch(char watchNumber, Clause* clause) {
    assert(watchNumber == 0 || watchNumber == 1);

    const int unfalsified = findUnfalsifiedLiteralPosition(clause, watchNumber);
    const int position = (unfalsified != -1)
	? unfalsified
	: findLatestFalsifiedLiteralPosition(clause, watchNumber);

    assert(position != -1);
    setWatch(watchNumber, clause, position);
}

/**
 * @return the index of the first literal, at or after position `watchNumber`,
 * that is not false (true or unassigned), or -1 if there is none.
 */
int Solver::findUnfalsifiedLiteralPosition(Clause* clause, char watchNumber) {
    const Clause::const_iterator first = clause->begin();
    const Clause::const_iterator last = clause->end();
    for (Clause::const_iterator cur = first + watchNumber; cur != last; ++cur) {
	if (Literals::isFalse(*cur))
	    continue;
	return cur - first;
    }
    return -1;
}

/**
 * @return the index, among positions [watchNumber, size), of the literal with
 * the highest decision level. Ties go to the last such position. Returns -1
 * when the range is empty.
 *
 * It is used by findAndSetWatch() when every candidate literal is false, so
 * that the watch sits on the most recently falsified literal.
 *
 * Fix: the original seeded the scan with `begin() + watchNumber - 1`. For
 * watch 0 that is `begin() - 1`, which is undefined behaviour for container
 * iterators. The scan now starts at the first real candidate. The result is
 * identical for every non-empty range.
 */
int Solver::findLatestFalsifiedLiteralPosition(Clause* clause, char watchNumber) {
    if (static_cast<size_t>(watchNumber) >= clause->size())
	return -1;

    Clause::const_iterator b = clause->begin(), e = clause->end();
    Clause::const_iterator maxPosition = b + watchNumber;
    size_t maxLevel = getLiteralsDecisionLevel(*maxPosition);
    for (Clause::const_iterator i = maxPosition + 1; i != e; ++i) {
	size_t level = getLiteralsDecisionLevel(*i);
	// `>=` keeps the last of equally-levelled literals, as before.
	if (level >= maxLevel) {
	    maxPosition = i;
	    maxLevel = level;
	}
    }
    return static_cast<int>(maxPosition - b);
}


/** Appends the clause to the watch list of `literal`. */
void Solver::addWatchClause(Literal literal, Clause* clause) {
    std::vector<Clause*>& watchers = _watchClauses[literal];
    watchers.push_back(clause);
}

/**
 * Removes the first occurrence of the clause from the watch list of
 * `literal`. Does nothing if the clause is not in that list.
 */
void Solver::removeWatchClause(Literal literal, Clause* clause) {
    std::vector<Clause*>& watchers = _watchClauses[literal];
    for (std::vector<Clause*>::iterator pos = watchers.begin(); pos != watchers.end(); ++pos) {
	if (*pos != clause)
	    continue;
	watchers.erase(pos);
	return;
    }
}

/**
 * Two-watched-literal update after `literal` has become false.
 *
 * For every clause in the watch list of `literal`, the false watch is first
 * moved to slot 1. Then one of the following applies:
 *  - watch 0 is true: the clause is satisfied and stays in this list;
 *  - another non-false literal exists from position 2 on: watch 1 moves to
 *    it, and the clause leaves this list (it is registered on the new literal);
 *  - otherwise: the clause is a conflict (watch 0 false) or unit (watch 0
 *    unassigned). It stays in this list and is recorded/enqueued.
 * The list is compacted in place through `lastKeptClause`. On a conflict, the
 * pending unit queue is flushed.
 *
 * Legend in the assertion comments below: T true, F false, U unassigned,
 * X anything, UT unassigned-or-true, UF unassigned-or-false.
 */
void Solver::notifyWatchClauses(Literal literal) {
    std::vector<Clause*>& clauseList = _watchClauses[literal];

    std::vector<Clause*>::const_iterator currentClause,
	beg = clauseList.begin(), en = clauseList.end();
    // Write cursor: clauses that keep watching `literal` are copied down here.
    std::vector<Clause*>::iterator lastKeptClause = clauseList.begin();
    for (currentClause = beg; currentClause != en; currentClause++) {
	Clause* clause = (*currentClause);

	// Assert: Watch literal is falsified
	// X F X X X X X X X    or   F X X X X X X X X X
        assert(Literals::isFalse(getWatch(0, clause)) ||
               Literals::isFalse(getWatch(1, clause)));

        // Ensure that the falsified literal is watch 1
        if (getWatch(0, clause) == literal) {
            swapWatches(clause);
        }

        // Assert: X F X X X X X X X X
        assert(Literals::isFalse(getWatch(1, clause)));

	Literal watch0 = getWatch(0, clause);

        // Clause is true, there is no need to change watch literals
	// Assert: T F X X X X X X X X
        if (Literals::isTrue(watch0)) {
	    *lastKeptClause++ = clause;
	    continue;
        }

        int firstUnfalsified = -1;
        if ( clause->size() != 2 && 
	     (firstUnfalsified = findUnfalsifiedLiteralPosition(clause, 2)) != -1) {
	    // There is another non-false literal, so move watch 1 onto it.
	    // The clause is then dropped from this list (not copied down).
	    // Assert: UT F F F F UF X X X X
            assert(!Literals::isFalse((*clause)[firstUnfalsified]));
            setWatch(1, clause, firstUnfalsified);
            // Assert: UT UF F F F F X X X X
	    continue;
        }

        // No unfalsified literals except possibly the first watch
        assert(clause->size() == 2 || firstUnfalsified == -1);

	// Assert: UT F F F F F F F F F
        // Clause is a conflict or unit
	if (Literals::isFalse(watch0)) {
	    setConflictClause(clause);
	} else {
	    enqueueUnitClause(clause);
	}

	*lastKeptClause++ = clause;
    }
    clauseList.erase(lastKeptClause, clauseList.end());

    // Pending propagations are meaningless once a conflict has been found.
    if (isConflicting())
	_unitClauseQueue.clear();
}

// Propagation graph

/** @return the clause recorded as the reason for `literal` (0 if none was set). */
Clause* Solver::getReason(Literal literal) {
    Clause* reason = _reason[literal];
    return reason;
}

/** Records `clause` as the reason for `literal`; 0 marks no reason (e.g. a decision). */
void Solver::setReason(Literal literal, Clause* clause) {
    Clause*& slot = _reason[literal];
    slot = clause;
}

/*****************************************************
 *        BACKTRACK
 *****************************************************/
/*
void Solver::applyBacktrack() {
    Literal literal = backtrackLevel();
    setConflictClause(0);
    assertLiteral(Literals::getOpposite(literal));
}
*/

/*****************************************************
 *        BACKJUMP
 *****************************************************/
/**
 * Resolves the current conflict by backjumping.
 *
 * The steps are:
 *  1. Compute the backjump clause. This backtracks the trail during analysis.
 *  2. Clear the conflict.
 *  3. Backtrack to the clause's maximal previous level.
 *  4. Commit the backtrack to the theory.
 *  5. Assert the clause's resolution literal, with the backjump clause as its
 *     reason.
 *
 * The caller (solve) learns the backjump clause afterwards.
 */
void Solver::applyBackjump() {
    findBackjumpClause();
    setConflictClause(0);

    Clause* backjumpClause = _backjumpClause.getClause();
    size_t  backjumpLevel = _backjumpClause.getMaxPreviousLevel();
    Literal backjumpLiteral = _backjumpClause.getResolutionLiteral();

//    cout << "BackjumpClause: " << backjumpClause->toString() << endl;
//    cout << "BackjumpLiteral: " << Literals::toString(backjumpLiteral) << endl;
//    cout << "BackjumpLevel: " << backjumpLevel << endl;

    backtrackToLevel(backjumpLevel);
    // The reason is set before asserting, so it is in place while assertLiteral runs.
    setReason(backjumpLiteral, backjumpClause);
//    cout << "Backtrack: " << _assertionTrail.toString() << endl;

    Theory::getInstance()->commitBacktrack();
    assertLiteral(backjumpLiteral);
}

/**
 * Builds the backjump clause: a first-UIP clause from conflict analysis,
 * then post-processed by performSubsumptionResolution().
 */
void Solver::findBackjumpClause() {
    findFirstUIPClause();
    performSubsumptionResolution();
}

/**
 * Conflict analysis: builds the first-UIP clause in _backjumpClause.
 *
 * It starts from the conflict clause and repeatedly resolves with the reason
 * of the last asserted literal, while more than one literal of the current
 * decision level remains. After each resolution step, the trail is popped
 * down to the next literal whose variable occurs in the clause. Variables of
 * every clause used are activity-bumped.
 *
 * Precondition: there is a conflict, and the last asserted literal's negation
 * occurs in the conflict clause (ensured by the "online" loop in solve()).
 * Side effect: the trail is partially backtracked.
 */
void Solver::findFirstUIPClause() {
    _backjumpClause = ResolutionClause(this);
    assert(isConflicting());

    Clause* conflictClause = getConflictClause();

    _backjumpClause.setConflictClause(conflictClause);
    bumpClauseVariablesActivity(conflictClause);

    Literal resolutionLiteral = Literals::getOpposite(getLastAssertedLiteral());
    _backjumpClause.setResolutionLiteral(resolutionLiteral);

//    cout << _backjumpClause.toString() << endl;

    while(_backjumpClause.getNumberOfCurrentLevelLiterals() > 1) {
	Clause* reasonClause = getReason(getLastAssertedLiteral());
//	cout << "Resolve with: " << reasonClause->toString() << endl;
	_backjumpClause.resolve(reasonClause);
	bumpClauseVariablesActivity(reasonClause);

	// backtrack Resolution literal
	backtrackLiteral();
	// backtrack Non Resolving Literals
	while(!_backjumpClause.containsVariable(Literals::getVariable(getLastAssertedLiteral()))) {
	    backtrackLiteral();
	}

	// NOTE(review): this declaration shadows the outer resolutionLiteral.
	Literal resolutionLiteral = Literals::getOpposite(getLastAssertedLiteral());
	_backjumpClause.setResolutionLiteral(resolutionLiteral);

//	cout << _backjumpClause.toString() << endl;
    }
}

void Solver::performSubsumptionResolution() {
    std::vector<Literal> literals;
    _backjumpClause.getPreviousLevelLiterals(literals);
    std::vector<Literal>::const_iterator i;
    for (i = literals.begin(); i != literals.end(); i++) {
	Clause* predecessor = getReason(Literals::getOpposite(*i));
	if (predecessor != 0) {
	    if (_backjumpClause.subsumes(predecessor)) {
//		cout << "Resolve with: " << predecessor->toString() << endl;
		_backjumpClause.removeLiteral(*i);
	    }
	}
    }
}

/*****************************************************
 *        DECIDE
 *****************************************************/
/**
 * Opens a new decision level (on the trail and in the theory), then asserts
 * the literal chosen by selectDecisionLiteral() without a reason.
 *
 * Fix: the reason is now cleared before the assertion rather than after it,
 * matching applyUnitPropagate() and applyBackjump(). Nothing triggered by
 * assertLiteral() can observe a stale reason left over from an earlier
 * assignment of the same literal.
 */
void Solver::applyDecide() {
//    cout << "---------------------------------------------------------------------------------" << endl;
//    cout << _assertionTrail.toString() << endl;
    _assertionTrail.newLevel();
    Theory::getInstance()->newDecision();

    Literal decisionLiteral = selectDecisionLiteral();
//    cout << "Decide: " << Literals::toString(decisionLiteral) << endl;
    setReason(decisionLiteral, 0);
    assertLiteral(decisionLiteral);
}

bool Solver::allVariablesAssigned() {
    size_t nvars = Variables::size();
    for (Variable var = 0; var < nvars; var++)
	if (Variables::useAsDecisionVariable(var) && Variables::getValue(var) == UNDEF)
	    return false;
    return true;
}

/**
 * Chooses the next decision literal. The variable comes from the configured
 * variable-selection strategy, and the polarity from the configured polarity
 * strategy (fixed true, fixed false, or random).
 * @throws const char* for an unknown strategy value.
 */
Literal Solver::selectDecisionLiteral() {
    Variable chosen;
    if (_variableSelectionStrategy == VARIABLE_SELECTION_FIRST_UNDEFINED)
	chosen = getFirstUndefinedVariable();
    else if (_variableSelectionStrategy == VARIABLE_SELECTION_MAX_ACTIVE)
	chosen = getMaxActiveUndefinedVariable();
    else
	throw "Unknown variable selection algorithm";

    bool positive;
    if (_literalSelectionPolarity == LITERAL_SELECTION_POLARITY_TRUE)
	positive = true;
    else if (_literalSelectionPolarity == LITERAL_SELECTION_POLARITY_FALSE)
	positive = false;
    else if (_literalSelectionPolarity == LITERAL_SELECTION_POLARITY_RANDOM)
	positive = random_boolean();
    else
	throw "Unknown literal polarity strategy";

    return Literals::getLiteral(chosen, positive);
}

/**
 * @return the lowest-numbered unassigned decision variable, or 0 if every
 * decision variable is assigned. solve() only decides after
 * allVariablesAssigned() returned false, so the fallback is not normally
 * reached.
 *
 * Fix: non-decision variables are now skipped. The original could branch on
 * them, which is inconsistent with allVariablesAssigned() and with the
 * activity heap, where non-decision variables are only inserted inactive.
 */
Variable Solver::getFirstUndefinedVariable() {
    size_t nvars = Variables::size();
    for (Variable var = 0; var < nvars; var++)
	if (Variables::useAsDecisionVariable(var) && Variables::getValue(var) == UNDEF)
	    return var;
    return 0;
}

/**
 * Increases the activity of `variable`. When it exceeds MAX_ACTIVITY, all
 * variable activities and the bump amount are rescaled. The variable's heap
 * position is then updated if it is currently in the heap.
 */
void Solver::bumpVariableActivity(Variable variable) {
    _variableActivity.bump(Variables::getActivity(variable));

    const bool tooLarge = Variables::getActivity(variable) > MAX_ACTIVITY;
    if (tooLarge) {
	const size_t total = Variables::size();
	for (Variable v = 0; v < total; v++)
	    _variableActivity.rescale(Variables::getActivity(v));
	_variableActivity.rescaleBumpAmount();
    }

    if (!_variableActivityHeap.contains(variable))
	return;
    _variableActivityHeap.increase(variable);
}

/**
 * Pops variables off the activity heap until an unassigned one is found,
 * and returns it. Assigned variables popped along the way are re-inserted
 * when they are backtracked (see backtrackLiteral). The caller must ensure
 * that an unassigned decision variable exists.
 */
Variable Solver::getMaxActiveUndefinedVariable() {
    for (;;) {
	Variable candidate = _variableActivityHeap.pop_heap();
	if (Variables::getValue(candidate) == UNDEF)
	    return candidate;
    }
}

/*****************************************************
 *        LEARN
 *****************************************************/
/** Learn rule: adds `clause` to the learnt-clause database. */
void Solver::applyLearn(Clause* clause) {
    Clause* learnt = clause;
    addLearntClause(learnt);
}

/**
 * Stores a learnt clause, records it in the statistics, and sets up its
 * watches. Setting up watches may enqueue the clause or mark it as a conflict.
 */
void Solver::addLearntClause(Clause* clause) {
    SolverStatistics::newLearntClause();
    _learntClauses.push_back(clause);
    initializeWatches(clause);
}

/**
 * @return true when the clause is not flagged initial. In debug builds this
 * also asserts that such a clause is present in _learntClauses.
 */
bool Solver::isLearntClause(Clause* clause) const {
    const bool learnt = !clause->isInitial();
    assert(!learnt ||
	   std::find(_learntClauses.begin(), _learntClauses.end(), clause) !=
	   _learntClauses.end());
    return learnt;
}

/** @return the number of learnt clauses currently stored. */
size_t Solver::getNumberOfLearntClauses() const {
    const size_t stored = _learntClauses.size();
    return stored;
}

/** @return the number of stored initial plus learnt clauses. */
size_t Solver::getTotalNumberOfClauses() const {
    const size_t initial = getNumberOfInitialClauses();
    const size_t learnt = getNumberOfLearntClauses();
    return initial + learnt;
}

/*****************************************************
 *        FORGET
 *****************************************************/
/**
 * Forget rule: shrinks the learnt-clause database.
 *
 * Learnt clauses are sorted with ClauseComparator: locked clauses first,
 * then by decreasing activity. The kept prefix is every locked clause plus at
 * least half of all learnt clauses. The rest are detached, deleted, and
 * counted in the statistics.
 */
void Solver::applyForget() {
//    logger << "Forgetting: " << _numClausesForForget << " clauses reached" << endl;
    std::sort(_learntClauses.begin(), _learntClauses.end(), ClauseComparator(this));

    const size_t total = _learntClauses.size();
    size_t keep = 0;
    while (keep < total &&
	   (isLockedByCurrentTrail(_learntClauses[keep]) || 2 * keep < total))
	++keep;

    for (size_t k = keep; k < total; ++k) {
	assert(!isLockedByCurrentTrail(_learntClauses[k]));
	SolverStatistics::removeLearntClause();
	removeClause(_learntClauses[k]);
    }
    _learntClauses.erase(_learntClauses.begin() + keep, _learntClauses.end());
//    logger << "\t" << getTotalNumberOfClauses() << " clauses after" << endl;
}

/**
 * Detaches the clause from its watch list(s) and deletes it. Single-literal
 * clauses only ever have slot 0 checked.
 */
void Solver::removeClause(Clause* clause) {
    const bool hasSecondWatch = clause->size() > 1;
    removeWatch(0, clause);
    if (hasSecondWatch)
	removeWatch(1, clause);

    delete clause;
}

/**
 * Increases the activity of a learnt clause; initial clauses are ignored.
 * When the activity exceeds MAX_ACTIVITY, all learnt-clause activities and
 * the bump amount are rescaled.
 */
void Solver::bumpClauseActivity(Clause* clause) {
    if (isLearntClause(clause)) {
	_clauseActivity.bump(clause->getActivity());
	if (clause->getActivity() > MAX_ACTIVITY) {
	    for (size_t k = 0; k < _learntClauses.size(); ++k)
		_clauseActivity.rescale(_learntClauses[k]->getActivity());
	    _clauseActivity.rescaleBumpAmount();
	}
    }
}

/** Bumps the activity of the variable of every literal in the clause, in order. */
void Solver::bumpClauseVariablesActivity(Clause* clause) {
    for (Clause::const_iterator lit = clause->begin(); lit != clause->end(); ++lit) {
	Variable v = Literals::getVariable(*lit);
	bumpVariableActivity(v);
    }
}

/**
 * @return true when the clause must not be deleted: binary clauses are
 * always kept, and any clause that is the reason of its watch-0 literal is
 * kept too.
 */
bool Solver::isLockedByCurrentTrail(Clause* clause) {
    if (clause->size() == 2)
	return true;
    Clause* reason = getReason(getWatch(0, clause));
    return reason == clause;
}

/**
 * Ordering used by applyForget(): locked clauses come before unlocked ones.
 * Within the same lock status, clauses are ordered by decreasing activity.
 */
bool Solver::ClauseComparator::operator() (Clause* c1, Clause* c2) {
    const bool locked1 = solver->isLockedByCurrentTrail(c1);
    const bool locked2 = solver->isLockedByCurrentTrail(c2);

    if (locked1 != locked2)
	return locked1;

    return c1->getActivity() > c2->getActivity();
}

void Solver::forgetSatisfiedClauses() {
    assert(getCurrentDecisionLevel() == 0);

    size_t initialClausesBefore = getNumberOfInitialClauses();
    size_t learntClausesBefore  = getNumberOfLearntClauses();

    std::vector<Clause*> newClauses;
    std::vector<Clause*>::const_iterator it, 
	beg = _learntClauses.begin(), en = _learntClauses.end();
    for (it = beg; it != en; it++)
	if ((*it)->isTrue() && !isLockedByCurrentTrail(*it)) {
	    SolverStatistics::removeLearntClause();
	    removeClause(*it);
	} else {
	    newClauses.push_back(*it);
	}
    _learntClauses = newClauses;

    newClauses.clear();
    beg = _initialClauses.begin(); en = _initialClauses.end();

    for (it = beg; it != en; it++) {
	if ((*it)->isTrue() && !isLockedByCurrentTrail(*it)) {
	    SolverStatistics::removeInitialClause();
	    removeClause(*it);
	} else {
	    newClauses.push_back(*it);
	}
    }
    _initialClauses = newClauses;

    if (initialClausesBefore < getNumberOfInitialClauses() ||
	learntClausesBefore < getNumberOfLearntClauses()) {
//	logger << "Removed satisfied clauses: " << endl;
//	logger << "\t" << initialClausesBefore << "+" << learntClausesBefore << " clauses before" << endl;
//	logger << "\t" << getNumberOfInitialClauses() << "+" << getNumberOfLearntClauses() << " clauses after" << endl;
    }
}

/*****************************************************
 *        RESTART
 *****************************************************/
/** Restart rule: backtracks to level 0 and commits that backtrack to the theory. */
void Solver::applyRestart() {
//    logger << "Restarting: " << _numConflictsSinceLastRestart << " conflicts reached" << endl;
    const size_t rootLevel = 0;
    backtrackToLevel(rootLevel);
    Theory::getInstance()->commitBacktrack();
}

/*****************************************************
 *        THEORY PROPAGATE
*****************************************************/
void Solver::addTheoryLemma() {
    Theory* theory = Theory::getInstance();
    std::vector<Expression> explanation;
    theory->explain(explanation);
    std::vector<Literal> literals;
    std::vector<Expression>::const_iterator it;
    for (it = explanation.begin(); it != explanation.end(); it++) {
	Literal l = Literals::getOpposite(Literals::getLiteral(*it));
	if (find(literals.begin(), literals.end(), l) == literals.end())
	    literals.push_back(l);
    }
    Clause* conflictClause = new Clause(literals);
//    cout << "Conflict Clause: " << conflictClause->toString() << endl;
//    printExplanationLemma(literals);
    addLearntClause(conflictClause);
    setConflictClause(conflictClause);
    _unitClauseQueue.clear();
//	std::cin.get();

}
bool Solver::applyTheoryPropagate() {
//    return false;
    Theory* theory = Theory::getInstance();
    std::vector<Expression> explanation;
    if (!theory->theoryPropagate(explanation))
	return false;
    std::vector<Expression>::const_iterator it;
    std::vector<Literal> literals;
    literals.push_back(Literals::getLiteral(explanation[0]));
    for (it = explanation.begin()+1; it != explanation.end(); it++) {
	Literal l = Literals::getOpposite(Literals::getLiteral(*it));
	if (find(literals.begin(), literals.end(), l) == literals.end())
		 literals.push_back(l);
    }

    Clause* unitClause = new Clause(literals);
//    cout << "Unit Clause: " << unitClause->toString() << endl;
//    printExplanationLemma(literals);

    if (!unitClause->isUnit()) {
	cout << "ERROR: not unit" << endl;
	cout << _assertionTrail.toString() << endl;
	exit(EXIT_FAILURE);
    }
    addLearntClause(unitClause);
    enqueueUnitClause(unitClause);
    return true;
}
