#!/usr/bin/env python

import sys
import numpy as np

class DummyCommunicator:
    """Minimal stand-in for an MPI communicator used for dry-run simulation.

    Only bookkeeping is stored: this process' ``rank``, the communicator
    ``size``, the parent-communicator ``ranks`` it was built from (``None``
    for the world communicator) and ``txt``, a text suffix accumulated by
    the ``simulate_*`` methods and appended to every line that ``output``
    prints.
    """

    def __init__(self, rank, size, ranks=None):
        self.rank = rank
        self.size = size
        self.ranks = ranks
        self.txt = ''

    def new_communicator(self, ranks):
        """Return the sub-communicator made of the parent ranks in `ranks`.

        `ranks` may be any sequence (list, tuple or array) of parent ranks.
        Raises ValueError unless this process' rank appears exactly once.
        """
        # The process with rank i in new group is the
        # process with rank ranks[i] in the old group
        # np.asarray makes the comparison elementwise for plain sequences
        # too; comparing a list with an int would yield a single False.
        members = np.asarray(ranks)
        my_subcomm_rank = np.argwhere(members == self.rank).ravel()
        if my_subcomm_rank.size != 1:
            raise ValueError('rank %s must occur exactly once in ranks=%s'
                             % (self.rank, ranks))
        return DummyCommunicator(rank=my_subcomm_rank.item(),
                                 size=len(ranks), ranks=ranks)

    def simulate_kpt_comm(self, nspins, nibzkpts):
        """Print the (k-point, spin) pairs assigned to this rank.

        Each pair is rendered as the k-point index followed by '^' for
        spin index 0 and 'v' for any other spin index.
        """
        # From gpaw/wavefunctions.py lines 76-83 rev. 3893
        mynks = nspins * nibzkpts // self.size
        self.txt += ', mynks=%d' % mynks
        ks0 = self.rank * mynks
        kpt_u = []
        for ks in range(ks0, ks0 + mynks):
            s, k = divmod(ks, nibzkpts)
            kpt_u.append('%d%s' % (k, '^' if s == 0 else 'v'))
        self.txt += ', kpt_u=[' + ','.join(kpt_u) + ']'
        self.output('kpt_comm')

    def simulate_band_comm(self, bd):
        """Print the band count and band indices reported by descriptor `bd`.

        `bd` must provide a ``mynbands`` attribute and a
        ``get_band_indices()`` method.
        """
        self.txt += ', mynbands=%d' % bd.mynbands
        # One-element tuple so a tuple-valued result is formatted as a
        # single value instead of being unpacked as format arguments.
        self.txt += ', mybands=%s' % (bd.get_band_indices(),)
        self.output('band_comm')

    def output(self, name='world'):
        """Print one line describing this communicator, labelled `name`.

        Every label other than 'world' is indented by four spaces and
        left-justified to twelve characters.
        """
        # Compare by value: identity ('is not') only works by accident of
        # string interning.
        if name != 'world':
            name = '    ' + name.ljust(12)
        # A single parenthesised argument prints identically whether print
        # is a statement (Python 2) or a function (Python 3).
        print('%s: rank=%d, ranks=%s' % (name, self.rank, self.ranks)
              + self.txt)

# -------------------------------------------------------------------

def simulate(world_size, parsize_c, parsize_bands, parstride_bands, nspins,
             nibzkpts, nbands, show_kpt=True, show_band=True, show_domain=True):
    """Print the communicator layout every rank of a parallel run would get.

    A summary of the input parameters is printed first.  Then, for each
    rank in ``range(world_size)``, a DummyCommunicator playing the world
    communicator is printed and handed to ``gpaw.mpi.distribute_cpus``
    together with `parsize_c`, `parsize_bands`, `nspins` and `nibzkpts`.
    The returned domain, k-point and band communicators are printed
    according to the flags `show_domain`, `show_kpt` and `show_band`.
    `nbands` and `parstride_bands` are only used to build the
    BandDescriptor needed when `show_band` is true.
    """

    # We haven't imported GPAW until now because it messes with sys.argv
    from gpaw import mpi
    if show_band:
        # Loop-invariant, so import once instead of once per rank.
        from gpaw.band_descriptor import BandDescriptor

    # Single-argument parenthesised prints produce the same output whether
    # print is a statement (Python 2) or a function (Python 3).
    print('')
    print('Simulating: world.size = %d' % world_size)
    print('    parsize_c = %s' % (parsize_c,))
    print('    parsize_bands = %s' % (parsize_bands,))
    print('    parstride_bands = %s' % (parstride_bands,))
    print('    nspins = %s' % (nspins,))
    print('    nibzkpts = %s' % (nibzkpts,))
    print('    nbands = %s' % (nbands,))
    print('')

    for rank in range(world_size):
        world = DummyCommunicator(rank, world_size)
        world.output()

        domain_comm, kpt_comm, band_comm = mpi.distribute_cpus(
            parsize_c, parsize_bands, nspins, nibzkpts, world)

        if show_kpt:
            kpt_comm.simulate_kpt_comm(nspins, nibzkpts)
        if show_band:
            bd = BandDescriptor(nbands, band_comm, parstride_bands)
            band_comm.simulate_band_comm(bd)
        if show_domain:
            domain_comm.output('domain_comm') #TODO N_c parts!


# -------------------------------------------------------------------

from optparse import Option, OptionValueError

def check_int3(option, opt, value):
    """optparse type checker for 'int3': parse 'X,Y,Z' into a tuple of ints.

    Raises OptionValueError when `value` does not consist of exactly three
    comma-separated integers.
    """
    try:
        # Unpacking enforces exactly three fields; int() validates each one.
        first, second, third = [int(field) for field in value.split(',')]
    except ValueError:
        message = 'option %s: invalid int3 value: %r' % (opt, value)
        raise OptionValueError(message)
    return (first, second, third)

class CustomOption(Option):
    """optparse Option that additionally understands the 'int3' value type."""

    TYPES = Option.TYPES + ('int3',)
    # Fresh dict so the checker table of the base Option class is untouched.
    TYPE_CHECKER = dict(Option.TYPE_CHECKER, int3=check_int3)


if __name__ in ['__main__', '__builtin__']:

    # Command-line driver: build the option parser, validate the parsed
    # values and hand them to simulate().
    from optparse import OptionParser, OptionGroup

    parser = OptionParser(
        usage='%prog [options]',
        version='%prog 0.1',
        description='Simulate MPI parallelization of a GPAW calculation.',
        option_class=CustomOption)

    # Switches selecting which communicators are printed.
    output_group = OptionGroup(parser, 'Output-specific options')
    output_group.add_option(
        '-v', '--verbose', action='store_true', default=None,
        dest='show_all', help='Simulate all communicators')
    output_group.add_option(
        '-0', '--kpoint-communicators', action='store_true', default=False,
        dest='show_kpoint', help='Simulate kpoint communicators')
    output_group.add_option(
        '-1', '--band-communicators', action='store_true', default=False,
        dest='show_band', help='Simulate band communicators')
    output_group.add_option(
        '-2', '--domain-communicators', action='store_true', default=False,
        dest='show_domain', help='Simulate domain communicators')
    parser.add_option_group(output_group)

    # Parameters describing the parallel layout.
    mpi_group = OptionGroup(parser, 'MPI-specific options')
    mpi_group.add_option(
        '-w', '--dry-run', default=0, type='int', dest='world_size',
        metavar='<N>',
        help='Number of processes in parallelization [default: %default]')
    mpi_group.add_option(
        '-c', '--domain-decomposition', default=None, type='int3',
        dest='parsize_c', metavar='<X,Y,Z>',
        help='Domain decomposition along three axes [default: %default]')
    mpi_group.add_option(
        '-b', '--state-parallelization', default=1, type='int',
        dest='parsize_bands', metavar='<B>',
        help='Divide bands into this many blocks [default: %default]')
    mpi_group.add_option(
        '-x', '--strided-bandgroups', action='store_true', default=None,
        dest='parstride_bands',
        help='Simulate strided band grouping (instead of blocked)')
    parser.add_option_group(mpi_group)

    # Parameters describing the calculation itself.
    gpaw_group = OptionGroup(parser, 'GPAW-specific options')
    gpaw_group.add_option(
        '-s', '--spins', default=1, type='int', dest='nspins',
        metavar='<nspins>', help='Number of spins [default: %default]')
    gpaw_group.add_option(
        '-k', '--kpoints', default=1, type='int', dest='nibzkpts',
        metavar='<nibzkpts>',
        help='Number of irreducible k-points [default: %default]')
    gpaw_group.add_option(
        '-n', '--bands', default=1, type='int', dest='nbands',
        metavar='<nbands>', help='Number of bands [default: %default]')
    parser.add_option_group(gpaw_group)
    #parser.add_option('-g', '--grid-size', default=None, type='int3', \
    #    dest='N_c', metavar='<X,Y,Z>', \
    #    help='Size of the grid along three axes [default: %default]')

    opts, args = parser.parse_args()

    # --verbose switches on every communicator listing at once.
    if opts.show_all is not None:
        opts.show_kpoint = opts.show_band = opts.show_domain = True

    domain_invalid = (opts.parsize_c is not None and
                      (len(opts.parsize_c) != 3 or
                       (np.array(opts.parsize_c) < 1).any()))

    # (is_valid, message) pairs, examined in order; the first failing one
    # is reported together with the help text and aborts the run.
    checks = [
        (opts.world_size > 0,
         'ERROR: MPI world size is unspecified or invalid.\n'),
        (not domain_invalid,
         'ERROR: Domain decomposition is invalid.\n'),
        (opts.parsize_bands > 0,
         'ERROR: State-parallelization size is invalid.\n'),
        (opts.nspins in [1, 2],
         'ERROR: Invalid number of spins.\n'),
        (opts.nibzkpts > 0,
         'ERROR: Invalid number of irreducible k-points.\n'),
        (opts.nbands > 0,
         'ERROR: Invalid number of bands.\n'),
    ]
    for is_valid, complaint in checks:
        if not is_valid:
            sys.stderr.write(complaint)
            parser.print_help()
            raise SystemExit(-1)

    simulate(opts.world_size, opts.parsize_c, opts.parsize_bands,
             opts.parstride_bands, opts.nspins, opts.nibzkpts, opts.nbands,
             show_kpt=opts.show_kpoint, show_band=opts.show_band,
             show_domain=opts.show_domain)

