#include <iostream>
#include <map>
#include <vector>

#include "expressions/Expression.h"
#include "auxiliary/Names.h"

/**
 * Removes if-then-else (ITE) sub-expressions from an Expression.
 *
 * Each distinct ITE node ite(c, t, f) is replaced by a fresh variable v,
 * and the side constraint
 *     (c' -> v = t') AND (NOT c' -> v = f')
 * is recorded, where c', t', f' are c, t, f with their own ITEs removed.
 * The conjunction of all recorded constraints is returned by
 * getAdditionals(); the rewritten expression is only equivalent to the
 * original when combined with those constraints.
 *
 * Replacements are memoised per instance: an ITE that compares equal (via
 * the ordering used by std::map<Expression, Expression>) to one already
 * seen is replaced by the same variable and adds no new constraint.
 */
class ITEElimination {
public:
	/**
	 * Returns @p e with every ITE sub-expression replaced by its fresh
	 * variable, recording one side constraint per newly created variable.
	 *
	 * @param e  expression to rewrite (not modified).
	 * @return   the ITE-free rewrite of @p e.
	 * @throws const char*  "Unknown expression type" when a node with
	 *         operands is none of the kinds handled below; the node's type
	 *         is printed to std::cout before throwing.
	 */
	Expression eliminateITE(const Expression& e) {
		if (e.IsITE())
			return eliminateSingleITE(e);

		// Leaves have no sub-expressions, so they cannot contain an ITE.
		if (!e.hasOperands())
			return e;

		// Rewrite every child, then rebuild a node of the same kind.
		std::vector<Expression> operands;
		for (Expression::operands_iterator it = e.begin(); it != e.end(); ++it)
			operands.push_back(eliminateITE(*it));

		if (e.IsNOT())
			return Expression::NOT(operands[0]);
		if (e.IsAND())
			return Expression::AND(operands);
		if (e.IsOR())
			return Expression::OR(operands);
		if (e.IsIMPL())
			return Expression::IMPL(operands[0], operands[1]);
		if (e.IsIFF())
			return Expression::IFF(operands[0], operands[1]);
		if (e.IsXOR())
			return Expression::XOR(operands[0], operands[1]);
		if (e.IsPredicate())
			return Expression::Predicate(e.GetName(), operands);
		if (e.IsEquality())
			return Expression::Equality(operands[0], operands[1]);
		if (e.IsDisequality())
			return Expression::Disequality(operands[0], operands[1]);
		if (e.IsFunction())
			return Expression::Function(e.GetName(), operands);

		std::cout << e.GetType() << std::endl;
		throw "Unknown expression type";
	}

	/**
	 * Returns Expression::AND over all side constraints recorded so far by
	 * this instance.
	 * NOTE(review): the result when no ITE was eliminated depends on how
	 * Expression::AND handles an empty vector — confirm.
	 */
	Expression getAdditionals() {
		return Expression::AND(additionals);
	}

private:
	/**
	 * Replaces one ITE node with its fresh variable, creating the variable
	 * and recording its side constraint the first time @p node is seen.
	 */
	Expression eliminateSingleITE(const Expression& node) {
		std::map<Expression, Expression>::iterator cached = ites.find(node);
		if (cached != ites.end())
			return cached->second;

		Expression v = getFreshVariable();
		ites[node] = v;

		// Rewrite condition, then-branch and else-branch as separate
		// statements. If these calls appeared as arguments of one call
		// expression, their evaluation order would be unspecified, so nested
		// ITEs could be numbered and recorded in a compiler-dependent order.
		Expression condition = eliminateITE(node[0]);
		Expression thenBranch = eliminateITE(node[1]);
		Expression elseBranch = eliminateITE(node[2]);

		// NOTE(review): v is tied to the branches with Equality, which
		// presumes the branches are terms; a formula-valued ITE would need
		// IFF instead — confirm against Expression's semantics.
		additionals.push_back(
			Expression::AND(
				Expression::IMPL(condition, Expression::Equality(v, thenBranch)),
				Expression::IMPL(Expression::NOT(condition), Expression::Equality(v, elseBranch))
			)
		);
		return v;
	}

	/**
	 * Returns a new variable named "_ite_<k>".
	 * The counter is a function-local static, so it is shared by every
	 * ITEElimination instance (names stay distinct across instances) and
	 * its increment is not synchronized.
	 */
	Expression getFreshVariable() {
		static int num = 0;
		return Expression::Variable("_ite_" + Names::itoa(num++));
	}

	// Memo: ITE sub-expression -> the fresh variable that replaces it.
	std::map<Expression, Expression> ites;
	// One side constraint per fresh variable. A nested ITE's constraint is
	// appended before the constraint of the ITE that contains it.
	std::vector<Expression> additionals;
};

