#ifndef __EUF_H__
#define __EUF_H__

#include <vector>
#include "expressions/Expression.h"
#include "theory/Theory.h"
#include "auxiliary/BacktrackableUnionFind.h"
#include "auxiliary/BacktrackableStack_.h"

// Theory solver over equalities and disequalities between expressions.
// Terms are identified by a per-expression index (see getIndex) and
// tracked in a backtrackable union-find ordered by ExpressionComparator.
// NOTE(review): the name suggests equality with uninterpreted functions;
// only (dis)equality handling is visible in this header.
class EUF : public Theory {
 public:
    // Registers this instance through registerTheory(). The initializers
    // depend on member declaration order (see the private section):
    // _expressions precedes _expressionComparator, which precedes
    // _unionFind, and _currentLevel precedes both asserted-literal stacks.
    EUF() 
	: _unsat(false),
	_currentLevel(0),
	_expressionComparator(_expressions), 
	_unionFind(_expressionComparator),
	_assertedDisequalities(_currentLevel), 
	_assertedEqualities(_currentLevel) {
	registerTheory(this);
    }

    // No initialization work is needed for this theory.
    virtual void init() {
    }

    // Defined out of line.
    virtual void addExpression(Expression e);

    // Decision hooks are intentionally no-ops here; per-level state is
    // handled through newLevel()/backtrack().
    virtual void newDecision() {
    }

    virtual void backtrackDecision() {
    }
    

    // Defined out of line.
    virtual void assertExpression(Expression e);
    virtual void backtrack();
    virtual void commitBacktrack();


    // Copies the most recently stored explanation into `explanation`.
    // NOTE(review): _explanation is filled by out-of-line code not visible here.
    virtual void explain(std::vector<Expression>& explanation) {
	explanation = _explanation;
    }

    // Defined out of line; x and y are expression indices (see getIndex).
    void explainEquality(unsigned x, unsigned y, std::vector<Expression>& explanation);

    virtual bool theoryPropagate(std::vector<Expression>& explanation);

    virtual void forceCheck(bool strong);

    virtual bool isUnsat() {
	return _unsat;
    }

    // Model generation is trivial for this theory: always reports success.
    virtual bool generateModel() {
	return true;
    }

    // Only traces the expression to coutput; every expression is accepted.
    virtual bool checkAgainstModel(Expression e) {
	coutput << "Checking: " << e << endl;
	return true;
    }

    // Returns a canonical orientation of an equality/disequality, so that
    // "a = b" and "b = a" (and likewise "a != b" / "b != a") map to the same
    // literal. Operands are ordered so the left one never compares lower
    // than the right one under ExpressionComparator. Numerals compare lower
    // than non-numerals, so "x = 5" is preferred over "5 = x".
    // Any other kind of expression is returned unchanged.
    // Side effect: operands without an index are registered via getIndex().
    virtual Expression cannonize(Expression e) {
	// Only binary (dis)equalities have operands to reorder; do not
	// access e[0]/e[1] on any other kind of expression.
	if (!e.IsEquality() && !e.IsDisequality())
	    return e;

	Expression l = e[0];
	Expression r = e[1];
	// Evaluate in a fixed order: when both operands are fresh, the
	// indices they receive must not depend on the (unspecified)
	// evaluation order of function arguments.
	const unsigned li = getIndex(l);
	const unsigned ri = getIndex(r);
	if (_expressionComparator.compare(li, ri) >= 0)
	    return e;

	// The left operand compares lower: rebuild with swapped operands.
	// Rebuilding as (l, r) would leave the orientation unchanged.
	if (e.IsEquality())
	    return Expression::Equality(r, l);
	return Expression::Disequality(r, l);
    }

    // Negation flips equality <-> disequality; anything else is wrapped
    // in NOT.
    virtual Expression negate(Expression e) {
	if (e.IsDisequality())
	    return Expression::Equality(e[0], e[1]);
	else if (e.IsEquality())
	    return Expression::Disequality(e[0], e[1]);
	else
	    return Expression::NOT(e);
    }

    // This theory never decides truth values here; both queries are
    // conservatively false.
    virtual bool isTrue(Expression e) {
	return false;
    }

    virtual bool isFalse(Expression e) { 
	return false;
    }

    // Opens a new backtracking level in the union-find and both
    // asserted-literal stacks, then bumps the level counter.
    // NOTE(review): the stacks were constructed from _currentLevel,
    // presumably by reference; keep the increment after their newLevel().
    void newLevel() {
	_unionFind.newLevel();
	_assertedDisequalities.newLevel();
	_assertedEqualities.newLevel();
	_currentLevel++;
    }

    void print() const;
 private:
    // Non-copyable: _expressionComparator holds a reference to _expressions,
    // and _unionFind is constructed from _expressionComparator, so a
    // member-wise copy would still refer to the source object.
    // Declared but deliberately not defined.
    EUF(const EUF&);
    EUF& operator=(const EUF&);

    bool _unsat;
    std::vector<Expression> _explanation;

    bool checkDisequality(unsigned xi, unsigned yi);
    bool checkEquality(unsigned xi, unsigned yi);

    std::vector<Expression> _equalities;
    std::vector<Expression> _disequalities;

    // Total order on expression indices: numerals sort before all other
    // expressions; otherwise (including numeral vs numeral) ordering is by
    // GetName().compare(). Returns <0, 0 or >0.
    // NOTE(review): if GetName() is a std::string, numerals compare
    // lexicographically ("10" < "9"); fine for canonical ordering.
    struct ExpressionComparator {
	ExpressionComparator(const std::vector<Expression>& expressions)
	    : _expressions(expressions) {
	}

	int compare(unsigned i1, unsigned i2) const {
	    const Expression& e1 = _expressions[i1];
	    const Expression& e2 = _expressions[i2];
	    if (e1.IsNumeral() && !e2.IsNumeral())
		return -1;
	    if (!e1.IsNumeral() && e2.IsNumeral())
		return 1;
	    return e1.GetName().compare(e2.GetName());
	}

	// Refers to the owning EUF's _expressions.
	const std::vector<Expression>& _expressions;
    };

    // Returns expr's index, assigning one lazily. An index of
    // static_cast<unsigned>(-1) means "not registered yet": the expression
    // is then appended to _expressions, given its position as index, and
    // added to the union-find.
    unsigned getIndex(Expression& expr) {
	if (expr.getIndex() == static_cast<unsigned>(-1)) {
	    _expressions.push_back(expr);
	    expr.setIndex(_expressions.size()-1);
	    _unionFind.addElement(expr.getIndex());
	}
	return expr.getIndex();
    }
    
    // Read-only lookup; never registers the expression.
    unsigned getIndex(const Expression& expr) const {
	return expr.getIndex();
    }

    // Declaration order below matters for the constructor's initializers.
    unsigned _currentLevel;
    std::vector<Expression> _expressions;
    ExpressionComparator _expressionComparator;
    BacktrackableUnionFind<unsigned, ExpressionComparator> _unionFind;
    BacktrackableStack<Expression> _assertedDisequalities;
    BacktrackableStack<Expression> _assertedEqualities;

};

#endif
