#include "Matrix.h"

void Matrix::pivot(unsigned i, unsigned x) {
    _nonBasicVariables[i]->insert(_basicVariables[i]);
    setBasicVariable(i, x);
    _nonBasicVariables[i]->erase(x);
    const RATIONAL& aix = _matrix[i][x];
    normalize(i, aix);

    for (unsigned k = 0; k < rows(); k++) {
	if (k == i)
	    continue;
	const RATIONAL& mb = _matrix[k][x];

	if (mb == 0)
	    continue;

	_nonBasicVariables[k]->erase(x);
	
	std::set<unsigned>::const_iterator it, 
	    beg = _nonBasicVariables[i]->begin(), 
	    en = _nonBasicVariables[i]->end();
	for (it = beg; it != en; it++) {
	    if (_matrix[k][*it] == 0) {
		if (mb == -1) {
		    _matrix[k][*it] = _matrix[i][*it];
		} else if (mb == 1) {
		    _matrix[k][*it] = -_matrix[i][*it];
		} else {
		    _matrix[k][*it] = -_matrix[i][*it];
		    _matrix[k][*it] *= mb;
		}
		_nonBasicVariables[k]->insert(*it);
	    }
	    else {
		if (mb == -1) {
		    _matrix[k][*it] += _matrix[i][*it];
		} else if (mb == 1) {
		    _matrix[k][*it] -= _matrix[i][*it];
		} else {
		    _matrix[k][*it] -= mb*_matrix[i][*it];
		}
		if (_matrix[k][*it] == 0)
		    _nonBasicVariables[k]->erase(*it);
	    }
	}
	_matrix[k][x] = 0;
    }
}

void Matrix::normalize(unsigned i, RATIONAL a) {
    if (a == 1)
	return;
    for (unsigned j = 0; j < cols(); j++) {
	if (_matrix[i][j] != 0) {
	    _matrix[i][j] /= a;
	    _matrix[i][j].canonicalize();
	}
    }
}

void Matrix::print() {
    cout << "ROWS: " << rows() << endl;
    cout << "ORDER: " << _orderedBasicVariables.size() << endl;
    std::set<unsigned, BasicVariableOrder>::const_iterator i;
    for (i = begin(); i != end(); i++) {
	std::set<unsigned, NonBasicVariableOrder>::const_iterator it, 
	    beg = _nonBasicVariables[*i]->begin(), 
	    en = _nonBasicVariables[*i]->end();
	for (it = beg; it != en; it++)
	    cout << _variables[*it] << " ";
	cout << endl;


	cout << _variables[_basicVariables[*i]] << " : ";
	for (unsigned j = 0; j < cols(); j++) {
	    if (_matrix[*i][j] != 0)
		cout << _matrix[*i][j] << " " << _variables[j] << "\t";
	}
	cout << endl;
    }
    std::cin.get();
}

/*
 * Make v the basic variable of equation eq, keeping the bookkeeping in sync:
 *   - _basicVariables[eq] becomes v (appended when eq is not yet a valid
 *     index -- presumably callers only do this with eq == size(); TODO confirm);
 *   - _basicVariableEquations drops the previous basic variable of eq and
 *     maps v -> eq;
 *   - _orderedBasicVariables contains eq afterwards.
 */
void Matrix::setBasicVariable(unsigned eq, unsigned v) {
    const bool knownRow = eq < _basicVariables.size();

    if (!knownRow) {
        _basicVariables.push_back(v);
    } else {
        _basicVariableEquations.erase(_basicVariables[eq]);
        // Take eq out of the ordered set *before* its basic variable changes
        // (presumably BasicVariableOrder compares rows by their basic
        // variable -- TODO confirm); erase() is a no-op if eq is absent.
        _orderedBasicVariables.erase(eq);
        _basicVariables[eq] = v;
    }

    _basicVariableEquations[v] = eq;
    _orderedBasicVariables.insert(eq);
}

/*
 * Rebuild the ordered containers so they reflect their comparators' current
 * ordering.
 *
 * A std::set never reorders elements it already stores, so an ordering change
 * can only be applied by re-inserting everything into a fresh set. This
 * rebuilds _orderedBasicVariables with _basicVariableOrder, then the
 * non-basic set of every row in it with _nonBasicVariableOrder.
 * (Presumably the comparators consult state that changes while solving --
 * TODO confirm.)
 */
void Matrix::sort() {
//    coutput << "Sorting" << endl;
    std::set<unsigned, BasicVariableOrder> sorted_order(_basicVariableOrder);
    sorted_order.insert(begin(), end());
    // Swap instead of copy-assign: the resulting elements and comparator are
    // the same, but the whole set is not copied a second time.
    _orderedBasicVariables.swap(sorted_order);

    for (const_iterator it = begin(); it != end(); ++it) {
        const std::set<unsigned, NonBasicVariableOrder>& nonBasics = getNonBasicVariables(*it);
        std::set<unsigned, NonBasicVariableOrder>* sorted_nonBasics =
            new std::set<unsigned, NonBasicVariableOrder>(_nonBasicVariableOrder);
        try {
            sorted_nonBasics->insert(nonBasics.begin(), nonBasics.end());
        } catch (...) {
            // Don't leak the partially built set; the row keeps its old one.
            delete sorted_nonBasics;
            throw;
        }
        // The row's previous set is owned here and replaced by the rebuilt one.
        delete _nonBasicVariables[*it];
        _nonBasicVariables[*it] = sorted_nonBasics;
    }
}
