#!/usr/bin/python

import sys
import argparse
import numpy as np
from scipy.io import mmread
from pylab import *

def build_parser():
    """Create the command-line parser for this script.

    Each option is registered under both a short and a long spelling
    (``-x``/``--solution`` and ``-p``/``--part``); the parsed values are
    stored as ``args.x`` and ``args.p`` and default to ``x.mtx`` and
    ``partition.mtx``.
    """
    parser = argparse.ArgumentParser(sys.argv[0])
    # Short and long flags must be passed as separate strings; a single
    # '-x,--solution' string registers one oddly named option and leaves
    # the long form unusable.
    parser.add_argument('-x', '--solution', dest='x', default='x.mtx',
                        help='Solution vector')
    parser.add_argument('-p', '--part', dest='p', default='partition.mtx',
                        help='Partition vector')
    return parser


def read_vector(path):
    """Read a Matrix Market file and return its entries as a 1-D array.

    ``mmread`` returns a dense array for array-format files and a sparse
    object for coordinate-format files; sparse results are densified so
    both formats can be handled the same way.
    """
    data = mmread(path)
    if hasattr(data, 'toarray'):
        data = data.toarray()
    return np.asarray(data).reshape(-1)


def reorder_by_partition(x, p):
    """Scatter ``x`` according to the ordering induced by ``p``.

    Let ``order`` list the indices ``0..n-1`` sorted by ascending ``p``
    (ties keep their original relative order).  The result ``y`` is a
    float array with ``y[order[k]] = x[k]``.
    NOTE(review): presumably ``x`` is stored grouped by partition and this
    restores the original numbering -- confirm against the producer of
    the input files.

    Raises:
        ValueError: if ``x`` and ``p`` do not have the same number of
            entries.
    """
    x = np.asarray(x).reshape(-1)
    p = np.asarray(p).reshape(-1)
    if x.size != p.size:
        raise ValueError(
            'solution has {} entries but partition has {}'.format(
                x.size, p.size))

    # A stable argsort yields the same permutation as Python's sorted()
    # keyed on p, without a Python-level comparison per element.
    order = np.argsort(p, kind='stable')
    y = np.zeros(x.shape)
    y[order] = x
    return y


def to_square_grid(v):
    """Reshape a flat vector of length ``m*m`` into an ``(m, m)`` array.

    Raises:
        ValueError: if the number of entries is not a perfect square.
    """
    v = np.asarray(v).reshape(-1)
    n = v.size
    m = int(round(np.sqrt(n)))
    if m * m != n:
        raise ValueError(
            'cannot display {} entries as a square grid'.format(n))
    return v.reshape((m, m))


def show_panels(solution, partition):
    """Show the two 2-D arrays side by side, each with a colorbar."""
    figure(num=1, figsize=(10, 7))

    subplot(1, 2, 1)
    imshow(solution, origin='lower')
    colorbar()
    title('solution')

    subplot(1, 2, 2)
    imshow(partition, origin='lower')
    colorbar()
    title('partition')

    show()


def main(argv=None):
    """Parse the command line, load both vectors and display them.

    ``argv`` is the list of arguments without the program name; when it
    is omitted, ``sys.argv[1:]`` is used.
    """
    if argv is None:
        argv = sys.argv[1:]
    args = build_parser().parse_args(argv)

    x = read_vector(args.x)
    p = read_vector(args.p)

    y = reorder_by_partition(x, p)
    show_panels(to_square_grid(y), to_square_grid(p))


if __name__ == '__main__':
    main()
