#include "EUF.h"
// Registers literal e with the solver: both operands e[0] and e[1] receive a
// union-find index via getIndex(), and e is remembered as a propagation
// candidate in the list that matches its polarity (_equalities or
// _disequalities). Literals that are neither are only indexed.
void EUF::addExpression(Expression e) {
//    coutput << "EUF add: " << e << endl;
    getIndex(e[0]);
    getIndex(e[1]);

    if (e.IsEquality()) {
        _equalities.push_back(e);
        return;
    }
    if (e.IsDisequality())
        _disequalities.push_back(e);
}

// Asserts literal e, an equality or disequality between e[0] and e[1].
//
// Both e and negate(e) are flagged as assigned, so theoryPropagate() will
// skip them.
//   * Equality: checked with checkEquality(). If consistent, the operands'
//     classes are merged with _unionFind.Union().
//   * Disequality: checked with checkDisequality().
// If the check fails, _unsat is set; the check routine has already filled
// _explanation.
//
// The literal is pushed onto its asserted stack whether or not the check
// succeeds. commitBacktrack() clears the assigned flags only for literals it
// finds on those stacks. If a conflicting literal were flagged but not
// recorded, it (and its negation) would stay "assigned" after backtracking
// and never be theory-propagated again.
void EUF::assertExpression(Expression e) {
//	coutput << "    ======================================" << endl;
//	coutput << "    EUF assert: " << e << endl;
    Expression l = e[0], r = e[1];
    unsigned li = getIndex(l), ri = getIndex(r);

    e.setAssigned(true);
    negate(e).setAssigned(true);

    if (e.IsEquality()) {
        if (!checkEquality(li, ri))
            _unsat = true;
        else
            _unionFind.Union(li, ri);   // merge classes only when consistent
        _assertedEqualities.push(e);    // always record, so backtracking unflags it
    } else if (e.IsDisequality()) {
        if (!checkDisequality(li, ri))
            _unsat = true;
        _assertedDisequalities.push(e); // always record, so backtracking unflags it
    }
}

// Appends to `explanation` the equalities that justify x ~ y.
// Each (a, b) index pair reported by _unionFind.Explain() becomes
// cannonize(Equality(_expressions[a], _expressions[b])).
// Nothing is appended when x == y. Existing entries of `explanation` are kept.
void EUF::explainEquality(unsigned x, unsigned y, std::vector<Expression>& explanation) {
    if (x != y) {
        typedef std::vector< std::pair<unsigned, unsigned> > EdgeList;
        EdgeList edges;
        _unionFind.Explain(x, y, edges);
        for (EdgeList::size_type k = 0; k < edges.size(); ++k) {
            const std::pair<unsigned, unsigned>& edge = edges[k];
            explanation.push_back(cannonize(Expression::Equality(_expressions[edge.first], _expressions[edge.second])));
        }
    }
}

bool EUF::checkEquality(unsigned xi, unsigned yi) {
//    coutput << "Check: " << x << " = " << y << endl;
    unsigned rxi = _unionFind.Find(xi);
    unsigned ryi = _unionFind.Find(yi);
    
    const Expression& rx = _expressions[rxi];
    const Expression& ry = _expressions[ryi];
    if (rx.IsNumeral() && ry.IsNumeral() && 
	rx.GetValueRational() != ry.GetValueRational()) {
	_explanation.clear();
	_explanation.push_back(cannonize(Expression::Equality(_expressions[xi], _expressions[yi])));
	explainEquality(xi, rxi, _explanation);
	explainEquality(yi, ryi, _explanation);
	return false;
    }
    return true;
}

bool EUF::checkDisequality(unsigned xi, unsigned yi) {
//    coutput << "Check: " << x << " = " << y << endl;
    unsigned rxi = _unionFind.Find(xi);
    unsigned ryi = _unionFind.Find(yi);
    
    const Expression& rx = _expressions[rxi];
    const Expression& ry = _expressions[ryi];

//    coutput << "Check: " << x << " != " << y << endl;
    if (rx == ry) {
	_explanation.clear();
	_explanation.push_back(cannonize(Expression::Disequality(_expressions[xi], _expressions[yi])));
	explainEquality(xi, yi, _explanation);
	return false;
    }
    return true;
}

void EUF::forceCheck(bool strong) {
//    cout << "FORCE CHECK" << endl;
    BacktrackableStack<Expression>::const_iterator it;
    for (it = _assertedDisequalities.begin(); it != _assertedDisequalities.end(); ++it) {
//	coutput << "Checking: " << *it << endl;
	if (!checkDisequality(getIndex((*it)[0]), getIndex((*it)[1]))) {
	    _unsat = true;
	    return;
	}
    }
}

bool EUF::theoryPropagate(std::vector<Expression>& explanation) {
    std::vector<Expression>::const_iterator it, 
	beg = _equalities.begin(), en = _equalities.end();
    for (it = beg; it != en; ++it) {
	if (it->assigned())
	    continue;

	unsigned li = getIndex((*it)[0]), ri = getIndex((*it)[1]);
	unsigned rli = _unionFind.Find(li), rri = _unionFind.Find(ri);
	const Expression& rx = _expressions[rli]; 
	const Expression& ry = _expressions[rri];
	if (rx == ry) {
	    explanation.clear();
	    explanation.push_back(*it);
	    explainEquality(li, ri, explanation);
//	    coutput << "EUF: --> " << (*it) << endl;
//	    print();
	    return true;
	}
    }
    beg = _disequalities.begin(); en = _disequalities.end();
    for (it = beg; it != en; it++) {
	if (it->assigned())
	    continue;
	unsigned li = getIndex((*it)[0]), ri = getIndex((*it)[1]);
	unsigned rli = _unionFind.Find(li), rri = _unionFind.Find(ri);
	const Expression& rx = _expressions[rli]; 
	const Expression& ry = _expressions[rri];
	if (rx.IsNumeral() && ry.IsNumeral() && 
	    rx.GetValueRational() != ry.GetValueRational()) {
	    explanation.clear();
	    explanation.push_back(*it);
	    explainEquality(li, rli, explanation);
	    explainEquality(ri, rri, explanation);
//	    coutput << "EUF: --> " << (*it) << endl;
//	    print();
	    return true;
	}

/*
	newLevel();
	Expression eq = Expression::Equality(l, r);
	assertExpression(eq);
	forceCheck();
	if (_unsat) {
	    coutput << "EUF: --> " << (*it) << endl;
	    explanation.clear();
	    explanation.push_back(*it);
	    std::vector<Expression>::const_iterator jt;
	    for (jt = _explanation.begin(); jt != _explanation.end(); jt++) {
		if (*jt != eq)
		    explanation.push_back(*jt);
	    }
	    backtrack();
	    return true;
	}
	backtrack();
	_print = true;
*/
    }
    return false;
}

// Declaration of an external helper, presumably converting an Expression to a
// std::string and defined in another translation unit. No function visible in
// this file calls it. NOTE(review): likely kept for debugging; verify its
// users before removing.
extern std::string toString_(const Expression& e);

// Undoes one level of solver state. It calls _unionFind.Backtrack(), lowers
// _currentLevel by one and clears any pending conflict flag (_unsat).
// The asserted-literal stacks and their assigned flags are not touched here;
// commitBacktrack() takes care of those.
void EUF::backtrack() {
    _unionFind.Backtrack();
    --_currentLevel;
    _unsat = false;
}

void EUF::commitBacktrack() {
    BacktrackableStack<Expression>::const_iterator it;
    for (it = _assertedEqualities.beginLastLevel(); it != _assertedEqualities.endLastLevel(); ++it) {
	Expression e = *it;
	e.setAssigned(false);
	negate(e).setAssigned(false);
    }
    for (it = _assertedDisequalities.beginLastLevel(); it != _assertedDisequalities.endLastLevel(); ++it) {
	Expression e = *it;
	e.setAssigned(false);
	negate(e).setAssigned(false);
    }

    _assertedEqualities.backtrack();
    _assertedDisequalities.backtrack();
}

void EUF::print() const {
    BacktrackableStack<Expression>::const_iterator it;
    for (it = _assertedEqualities.begin(); it != _assertedEqualities.end(); ++it)
	coutput << *it << endl;
    for (it = _assertedDisequalities.begin(); it != _assertedDisequalities.end(); ++it)
	coutput << *it << endl;
}
