#ifndef __BACKTRACKABLE_UNION_FIND_H__
#define __BACKTRACKABLE_UNION_FIND_H__

#include <cassert>
#include <cstddef>
#include <iostream>
#include <map>
#include <utility>
#include <vector>
using namespace std;

/// Union-find (disjoint-set) structure whose unions can be undone level by
/// level, and which can explain why two elements ended up in the same class.
///
/// Elements are dense indices: addElement(x) grows the internal storage to at
/// least x + 1 slots, so T must be an integral-like type usable as a
/// std::vector index (it must also support ==, != and <, the latter for the
/// per-root size map).
///
/// Every element owns a stack of parent pointers: Union() pushes onto one
/// element's stack and Backtrack() pops it again. Find() does not compress
/// paths, so its cost is the depth of the element in its tree. Which root
/// survives a Union() is decided by the comparator, not by rank or size.
///
/// Comparator must provide a const member `compare(a, b)` whose result is
/// compared against 0. It is held by reference, so the comparator object
/// passed to the constructor must outlive this structure.
template<class T, class Comparator>
class BacktrackableUnionFind {
 public:
    /// Creates an empty structure with a single (base) backtrack level.
    BacktrackableUnionFind(const Comparator& comparator)
        : _comparator(comparator) {
        _levels.push_back(0);
    }

    /// Merges the classes of x and y; does nothing if they already coincide.
    ///
    /// The surviving root r is the one for which compare(otherRoot, r) >= 0;
    /// with a conventional three-way comparator that is the smaller root.
    /// The merge is recorded in the current level so Backtrack() can undo it.
    void Union(T x, T y) {
        T rx = Find(x);
        T ry = Find(y);
        if (rx == ry)
            return;

        // Orient the merge so that ry stays the root of the merged class.
        if (_comparator.compare(rx, ry) < 0) {
            std::swap(x, y);
            std::swap(rx, ry);
        }

        ++_levels.back();

        // Re-root x's tree at x and hang x directly under y, so the forest
        // edge created here is exactly the pair passed to Union() (in some
        // orientation); Explain() reports such edges.
        makeRoot_(x);
        _unionHistory.push_back(std::pair<T, T>(x, y));
        _oldRoots.push_back(rx);
        _unionFind[x].push_back(y);
        _size[ry].push_back(_size[rx].back() + _size[ry].back());
    }

    /// Returns the root (representative) of x's class.
    ///
    /// The returned reference points into internal storage; it may be
    /// invalidated or overwritten by any later Union(), Backtrack(),
    /// makeRoot() or addElement() call, so copy it if it must survive one.
    const T& Find(const T& x) {
        const T* node = &x;
        for (;;) {
            const T& parent = getParent(*node);
            if (parent == *node)
                return parent;
            node = &parent;
        }
    }

    /// Undoes every Union() recorded in the current level (since the most
    /// recent newLevel(), or since construction for the base level), most
    /// recent first, and closes that level. Backtracking the base level
    /// undoes its unions but leaves a fresh, empty base level behind so the
    /// structure stays usable.
    void Backtrack() {
        assert(!_levels.empty());
        const std::size_t count = _levels.back();
        assert(count <= _unionHistory.size());
        for (std::size_t i = 0; i < count; ++i) {
            const T x = _unionHistory.back().first;
            const T y = _unionHistory.back().second;
            const T oldRoot = _oldRoots.back();
            _unionHistory.pop_back();
            _oldRoots.pop_back();

            // Cut x loose from y (x becomes the root of its old tree), give
            // that tree back its original root, then drop the parent entry
            // Union() pushed for x.
            _unionFind[x].back() = x;
            makeRoot(oldRoot);
            _unionFind[x].pop_back();

            // y's class no longer contains x's tree: discard the merged size.
            _size[Find(y)].pop_back();
        }
        _levels.pop_back();
        if (_levels.empty())
            _levels.push_back(0);  // never leave Union()/Backtrack() without a level
    }

    /// Re-roots x's tree so that x becomes its root.
    void makeRoot(T x) {
        makeRoot_(x);
        _unionFind[x].back() = x;
    }

    /// Appends to `explanation` the forest edges on the path between x and y:
    /// first the edges from x up to their lowest common ancestor, then the
    /// edges from y up to it. Every such edge was introduced by some Union()
    /// call. x and y must be in the same class (asserted); if they are not
    /// (only reachable with NDEBUG), nothing is appended.
    void Explain(const T& x, const T& y, std::vector< std::pair<T, T> >& explanation) {
        assert(Find(x) == Find(y));
        std::vector<T> xpath;
        std::vector<T> ypath;
        getPath(x, xpath);
        getPath(y, ypath);

        // Both paths end at the shared root, and the ancestors x and y have
        // in common form a common suffix of the two paths; strip it. The
        // first stripped node, xpath[xi] == ypath[yi], is the lowest common
        // ancestor.
        std::size_t xi = xpath.size();
        std::size_t yi = ypath.size();
        while (xi > 0 && yi > 0 && xpath[xi - 1] == ypath[yi - 1]) {
            --xi;
            --yi;
        }
        if (xi == xpath.size())
            return;  // no common ancestor: different trees

        for (std::size_t i = 0; i < xi; ++i)
            explanation.push_back(std::pair<T, T>(xpath[i], xpath[i + 1]));
        for (std::size_t i = 0; i < yi; ++i)
            explanation.push_back(std::pair<T, T>(ypath[i], ypath[i + 1]));
    }

    /// Debug dump to stdout: one line per registered element with its
    /// current parent and root; roots additionally show their class size.
    void print() {
        for (std::size_t i = 0; i < _unionFind.size(); ++i) {
            if (_unionFind[i].empty())
                continue;  // slot never registered via addElement()
            const T node = static_cast<T>(i);
            const T parent = _unionFind[i].back();
            std::cout << node << "\t " << parent << "\t" << Find(node);
            if (node == parent)
                std::cout << "   : " << _size[node].back();
            std::cout << '\n';
        }
        std::cout << std::endl;
    }

    /// Opens a new backtrack level; the next Backtrack() undoes exactly the
    /// unions performed after this call.
    void newLevel() {
        _levels.push_back(0);
    }

    /// Registers x as a singleton class of size 1. Elements may be added in
    /// any order: storage only grows, it is never truncated. Re-adding an
    /// existing element resets its parent stack and size.
    /// NOTE(review): presumably only safe while x takes part in no union.
    void addElement(const T& x) {
        const std::size_t index = static_cast<std::size_t>(x);
        if (_unionFind.size() <= index)
            _unionFind.resize(index + 1);
        _unionFind[x].assign(1, x);
        _size[x].assign(1, 1);
    }

 private:
    /// Reverses the parent pointers on the path from x to its root, so every
    /// former ancestor now points back toward x. x's own entry is left
    /// untouched (x and its old parent briefly point at each other), so the
    /// caller must overwrite x's top entry right afterwards. Iterative, so
    /// deep trees cannot overflow the call stack.
    void makeRoot_(const T& x) {
        T child = x;
        T parent = getParent(child);
        while (parent != child) {
            const T grandparent = getParent(parent);
            _unionFind[parent].back() = child;
            child = parent;
            parent = grandparent;
        }
    }

    /// Current parent of x (x itself when x is a root).
    const T& getParent(const T& x) {
        assert(static_cast<std::size_t>(x) < _unionFind.size() && !_unionFind[x].empty());
        return _unionFind[x].back();
    }

    /// Appends x, its parent, its grandparent, ... up to and including the root.
    void getPath(T x, std::vector<T>& path) {
        path.push_back(x);
        while (getParent(x) != x) {
            x = getParent(x);
            path.push_back(x);
        }
    }

    // Per-element stack of parent pointers; the top is the current parent.
    std::vector< std::vector<T> > _unionFind;

    // Per-element stack of class sizes; the top is meaningful while the
    // element is a root.
    std::map< T, std::vector<int> > _size;
    // (x, y) edges added by Union(), in the order they were applied.
    std::vector< std::pair<T, T> > _unionHistory;
    // Root of x's tree before each recorded union, used to restore it.
    std::vector<T> _oldRoots;
    // Held by reference: must outlive this object.
    const Comparator& _comparator;

    // Number of unions recorded in each open level; back() is the current one.
    std::vector<unsigned> _levels;
};

#endif
