#include "sat/SatAbstraction.h"

// Builds the propositional skeleton of `e`.
//
// Boolean connectives (NOT, AND, OR, IMPL, IFF) are rebuilt over the
// abstracted operands. TOP, BOT and formula variables are returned unchanged.
// Any other expression is treated as a theory atom:
//   - it is canonized by the Theory singleton;
//   - if the theory already decides it, it becomes TOP or BOT;
//   - otherwise it is mapped to a literal. An existing literal for the atom
//     (or the opposite of an existing literal for its canonized negation) is
//     reused. If neither exists, a fresh solver variable is created, whose
//     positive literal denotes the atom and whose negative literal denotes
//     its negation. Both are also registered with the theory.
//
// @param e       expression to abstract
// @param solver  SAT solver that supplies fresh variables for new atoms
// @return        the abstracted expression
Expression SatAbstraction::abstractExpression(const Expression& e, Solver& solver) {
    // Abstract children first; only connectives have any.
    std::vector<Expression> args;
    if (e.IsConnective()) {
        for (Expression::operands_iterator child = e.begin(); child != e.end(); ++child)
            args.push_back(abstractExpression(*child, solver));
    }

    // Rebuild the Boolean structure over the abstracted children.
    if (e.IsNOT())  return Expression::NOT(args[0]);
    if (e.IsAND())  return Expression::AND(args);
    if (e.IsOR())   return Expression::OR(args);
    if (e.IsIMPL()) return Expression::IMPL(args[0], args[1]);
    if (e.IsIFF())  return Expression::IFF(args[0], args[1]);

    // Constants and propositional variables are already propositional.
    if (e.IsTOP() || e.IsBOT() || e.IsFormulaVariable())
        return e;

    // From here on `e` is a theory atom.
    Theory* theory = Theory::getInstance();
    Expression positive = theory->cannonize(e);

    // Atoms the theory decides by itself collapse to constants.
    if (theory->isTrue(positive))
        return Expression::TOP();
    if (theory->isFalse(positive))
        return Expression::BOT();

    Expression negative = theory->cannonize(theory->negate(positive));

    // getLiteral() signals "no literal assigned yet" with this value.
    const Literal undefined = static_cast<Literal>(-1);

    // Reuse the atom's literal, or the complement of its negation's literal.
    Literal known = getLiteral(positive);
    if (known != undefined)
        return getLiteralAsExpression(known);
    known = getLiteral(negative);
    if (known != undefined)
        return getLiteralAsExpression(Literals::getOpposite(known));

    // First occurrence: allocate a variable whose two polarities denote the
    // atom and its negation respectively.
    // (The disequality hook, addDisequality(), is currently disabled here.)
    Variable fresh = solver.newVariable();
    Literal posLit = Literals::getLiteral(fresh, true);
    Literal negLit = Literals::getLiteral(fresh, false);

    Literals::setExpression(posLit, positive);
    theory->addExpression(positive);

    Literals::setExpression(negLit, negative);
    theory->addExpression(negative);

    return getLiteralAsExpression(posLit);
}

void SatAbstraction::addDisequality(const Expression& equality, const Expression& disequality, Literal eq, Literal deq,
				    Solver& solver) {
    Theory* theory = Theory::getInstance();	

    Expression ltExpr = Expression::Predicate("<", equality[0], equality[1]);
    Literal lt = getLiteral(ltExpr);
    if (lt == (Literal)(-1)) {
	Expression geqExpr = Expression::Predicate(">=", equality[0], equality[1]);
	assert(getLiteral(geqExpr) == (Literal)(-1));
	Variable ltvar = solver.newVariable();
	lt = Literals::getLiteral(ltvar, true);
	Literal geq = Literals::getLiteral(ltvar, false);
	Literals::setExpression(lt, ltExpr);
	Literals::setExpression(geq, geqExpr);
	theory->addExpression(ltExpr);
	theory->addExpression(geqExpr);
    }

    Expression gtExpr = Expression::Predicate(">", equality[0], equality[1]);
    Literal gt = getLiteral(gtExpr);
    if (gt == (Literal)(-1)) {
	Expression leqExpr = Expression::Predicate("<=", equality[0], equality[1]);
	assert(getLiteral(leqExpr) == (Literal)(-1));
	Variable gtvar = solver.newVariable();
	gt = Literals::getLiteral(gtvar, true);
	Literal leq = Literals::getLiteral(gtvar, false);
	Literals::setExpression(gt, gtExpr);
	Literals::setExpression(leq, leqExpr);
	theory->addExpression(gtExpr);
	theory->addExpression(leqExpr);
    }

    // equality:      x != y       deq
    // disequality:   x = y        eq
    //         :   x < y        lt
    //         :   x > y        gt
    // x != y <-> x < y \/ x > y
    // eq \/ lt \/ gt
    // -lt \/ deq
    // -gt \/ deq
    // -lt \/ -gt
    std::vector<Literal> literals;
    literals.push_back(eq);
    literals.push_back(lt);
    literals.push_back(gt);
    solver.addInitialClause(new Clause(literals, true));
	    
    literals.clear();
    literals.push_back(Literals::getOpposite(lt));
    literals.push_back(deq);
    solver.addInitialClause(new Clause(literals, true));
	    
    literals.clear();
    literals.push_back(Literals::getOpposite(gt));
    literals.push_back(deq);
    solver.addInitialClause(new Clause(literals, true));


    literals.clear();
    literals.push_back(Literals::getOpposite(lt));
    literals.push_back(Literals::getOpposite(gt));
    solver.addInitialClause(new Clause(literals, true));

}
