Source code for nbodykit

from .version import __version__

from mpi4py import MPI

import dask

import warnings


try:
    # prevents a "too many threads" exception when using MPI and dask
    # by disabling threading in dask.
    dask.config.set(scheduler='synchronous')
except AttributeError:
    # dask.config was added in 0.18; on older dask fall back to
    # set_options, which was deprecated in 0.18.1 in favor of dask.config.set
    dask.set_options(get=dask.get)
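
On a recent dask the effect of this block can be verified directly; a minimal sketch:

.. code:: python

    import dask

    dask.config.set(scheduler='synchronous')
    assert dask.config.get('scheduler') == 'synchronous'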

_global_options = {}
_global_options['global_cache_size'] = 1e8 # 100 MB
_global_options['dask_chunk_size'] = 100000
_global_options['paint_chunk_size'] = 1024 * 1024 * 4
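
These options are consumed elsewhere in nbodykit; for instance, ``dask_chunk_size`` controls the chunking when catalog columns are wrapped as dask arrays. A rough sketch of that pattern, using the ``_global_options`` dict defined above and a hypothetical ``data`` column (``da.from_array`` is standard dask API):

.. code:: python

    import numpy
    import dask.array as da

    data = numpy.zeros(1000000)  # hypothetical catalog column
    darr = da.from_array(data, chunks=_global_options['dask_chunk_size'])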

from contextlib import contextmanager
import logging

def _unpickle(name):
    return getattr(MPI, name)

def _comm_pickle(obj):
    if obj == MPI.COMM_NULL:
        return _unpickle, ('COMM_NULL',)
    if obj == MPI.COMM_SELF:
        return _unpickle, ('COMM_SELF',)
    if obj == MPI.COMM_WORLD:
        return _unpickle, ('COMM_WORLD',)
    raise TypeError("cannot pickle object")
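
A sketch of the round trip these reducers enable once they are registered (as ``_setup_for_distributed`` does below), assuming Python 3:

.. code:: python

    import pickle
    import copyreg

    copyreg.pickle(MPI.Comm, _comm_pickle, _unpickle)
    copyreg.pickle(MPI.Intracomm, _comm_pickle, _unpickle)

    # the built-in communicators pickle by name and
    # unpickle to the same singleton object
    assert pickle.loads(pickle.dumps(MPI.COMM_WORLD)) is MPI.COMM_WORLD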

def _setup_for_distributed():
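    # each dask.distributed worker is a single process, so new nbodykit
    # objects should default to MPI.COMM_SELF; also teach pickle how to
    # ship the built-in communicators between workers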
    CurrentMPIComm._stack[-1] = MPI.COMM_SELF

    try:
        import copyreg
    except ImportError:  # Python 2
        import copy_reg as copyreg

    copyreg.pickle(MPI.Comm, _comm_pickle, _unpickle)
    copyreg.pickle(MPI.Intracomm, _comm_pickle, _unpickle)

    set_options(dask_chunk_size=1024 * 1024 * 2)

def use_distributed(c=None):
    """
    Set up nbodykit to work with dask.distributed. This changes the
    default MPI communicator to MPI.COMM_SELF, such that each nbodykit
    object resides on a single MPI rank.

    This function should be called before any nbodykit object is created.

    Parameters
    ----------
    c : Client
        the distributed client. If not given, the default client is used.
        Note that if you switch to a new client, this function must be
        called again.
    """
    dask.config.set(scheduler="distributed")
    import distributed

    key = 'nbodykit_setup_for_distributed'

    if c is None:
        c = distributed.get_client()

    _setup_for_distributed()

    # use a lock to minimize the chance of seeing a KeyError from
    # publish_dataset; the error is annoyingly printed to stderr even
    # if we catch it.
    lock = distributed.Lock(key)
    locked = lock.acquire(timeout=3)
    if key not in c.list_datasets():
        try:
            c.publish_dataset(**{key: True})
            c.register_worker_callbacks(setup=_setup_for_distributed)
        except KeyError:
            # already published; someone else is registering the callback.
            pass
    if locked:
        lock.release()

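A hedged usage sketch; the local ``Client`` here is illustrative:

.. code:: python

    from distributed import Client
    from nbodykit import use_distributed

    client = Client()        # a hypothetical local cluster
    use_distributed(client)  # nbodykit objects now default to MPI.COMM_SELF
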
def use_mpi(comm=None):
    """
    Set up nbodykit to work with MPI. This changes the default MPI
    communicator to MPI.COMM_WORLD, such that each nbodykit object is
    partitioned across many MPI ranks.

    This function should be called before any nbodykit object is created.
    """
    dask.config.set(scheduler='synchronous')
    if comm is None:
        comm = MPI.COMM_WORLD
    CurrentMPIComm._stack[-1] = comm
    set_options(dask_chunk_size=1024 * 100)

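For example, in a script launched under ``mpirun``, a minimal sketch:

.. code:: python

    from mpi4py import MPI
    from nbodykit import use_mpi

    use_mpi(MPI.COMM_WORLD)  # partition nbodykit objects across all ranks
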
class CurrentMPIComm(object):
    """
    A class to facilitate getting and setting the current MPI communicator.
    """
    _stack = [MPI.COMM_WORLD]
    logger = logging.getLogger("CurrentMPIComm")

    @staticmethod
    def enable(func):
        """
        Decorator to attach the current MPI communicator to the input
        keyword arguments of ``func``, via the ``comm`` keyword.
        """
        import functools

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            kwargs.setdefault('comm', None)
            if kwargs['comm'] is None:
                kwargs['comm'] = CurrentMPIComm.get()
            return func(*args, **kwargs)

        return wrapped

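For illustration, a hypothetical function decorated this way picks up the current communicator whenever ``comm`` is omitted:

.. code:: python

    @CurrentMPIComm.enable
    def rank_of(comm=None):
        return comm.rank

    rank_of()                     # comm filled in via CurrentMPIComm.get()
    rank_of(comm=MPI.COMM_SELF)   # an explicit communicator wins
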
    @classmethod
    @contextmanager
    def enter(cls, comm):
        """
        Enter a context where the current default MPI communicator is
        modified to the argument ``comm``. After leaving the context
        manager, the communicator is restored.

        Example:

        .. code:: python

            with CurrentMPIComm.enter(comm):
                cat = UniformCatalog(...)

        is identical to

        .. code:: python

            cat = UniformCatalog(..., comm=comm)
        """
        cls.push(comm)
        yield
        cls.pop()

    @classmethod
    def push(cls, comm):
        """ Switch to a new current default MPI communicator """
        cls._stack.append(comm)
        if comm.rank == 0:
            cls.logger.info("Entering a current communicator of size %d" % comm.size)
        cls._stack[-1].barrier()

    @classmethod
    def pop(cls):
        """ Restore the previous current default MPI communicator """
        comm = cls._stack[-1]
        if comm.rank == 0:
            cls.logger.info("Leaving current communicator of size %d" % comm.size)
        cls._stack[-1].barrier()
        cls._stack.pop()

        comm = cls._stack[-1]
        if comm.rank == 0:
            cls.logger.info("Restored current communicator to size %d" % comm.size)

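``push`` and ``pop`` must be balanced; a sketch with a hypothetical split communicator (prefer the ``enter`` context manager above, which pairs them automatically):

.. code:: python

    comm = MPI.COMM_WORLD
    split = comm.Split(color=comm.rank % 2)

    CurrentMPIComm.push(split)
    # ... objects created here default to `split` ...
    CurrentMPIComm.pop()
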
    @classmethod
    def get(cls):
        """
        Get the default current MPI communicator. The initial value is
        ``MPI.COMM_WORLD``.
        """
        return cls._stack[-1]

    @classmethod
    def set(cls, comm):
        """
        Set the current MPI communicator to the input value.

        .. deprecated::
            Use ``with CurrentMPIComm.enter(comm):`` instead.
        """
        warnings.warn("CurrentMPIComm.set is deprecated. Use `with CurrentMPIComm.enter(comm):` instead",
                      DeprecationWarning)

        cls._stack[-1].barrier()
        cls._stack[-1] = comm
        cls._stack[-1].barrier()

import dask.cache

class GlobalCache(dask.cache.Cache):
    """
    The global dask cache used by nbodykit, based on :class:`dask.cache.Cache`.
    """

    @classmethod
    def get(cls):
        """
        Return the global cache object. The default size is controlled by
        the ``global_cache_size`` global option; see :class:`set_options`.

        Returns
        -------
        cache : :class:`dask.cache.Cache`
            the cache object, as provided by dask
        """
        # if not created, use the default cache size
        return _global_cache

_global_cache = GlobalCache(_global_options['global_cache_size'])
_global_cache.register()

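The module-level singleton can then be inspected through ``GlobalCache.get``; for example:

.. code:: python

    cache = GlobalCache.get()
    # the underlying cachey cache exposes its capacity in bytes
    print(cache.cache.available_bytes)
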
class set_options(object):
    """
    Set global configuration options.

    Parameters
    ----------
    dask_chunk_size : int
        the number of elements in the default chunk size for dask arrays;
        chunks should usually hold between 10 MB and 100 MB
    global_cache_size : float
        the size of the internal dask cache in bytes; the default is
        1e8 (100 MB)
    paint_chunk_size : int
        the number of objects to paint at the same time. This is
        independent of the dask chunk size.
    """
    def __init__(self, **kwargs):
        self.old = _global_options.copy()
        for key in sorted(kwargs):
            if key not in _global_options:
                raise KeyError("Option `%s` is not supported" % key)
        _global_options.update(kwargs)

        # resize the global Cache!
        # FIXME: after https://github.com/dask/cachey/pull/12
        if 'global_cache_size' in kwargs:
            cache = GlobalCache.get().cache
            cache.available_bytes = _global_options['global_cache_size']
            cache.shrink()

    def __enter__(self):
        return

    def __exit__(self, type, value, traceback):
        _global_options.clear()
        _global_options.update(self.old)

        # resize the Cache back to the original size
        # FIXME: after https://github.com/dask/cachey/pull/12
        cache = GlobalCache.get().cache
        cache.available_bytes = _global_options['global_cache_size']
        cache.shrink()

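Because ``__exit__`` restores the saved options, ``set_options`` also works as a context manager; a minimal sketch:

.. code:: python

    from nbodykit import set_options

    with set_options(dask_chunk_size=1024 * 1024, global_cache_size=5e8):
        pass  # code here sees the new options; they are restored on exit
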
_logging_handler = None

def setup_logging(log_level="info"):
    """
    Turn on logging, with the specified level.

    Parameters
    ----------
    log_level : 'info', 'debug', 'warning'
        the logging level to set; logging below this level is ignored
    """

    # This gives:
    #
    # [ 000000.43 ] 0: 06-28 14:49 measurestats INFO Nproc = [2, 1, 1]
    # [ 000000.43 ] 0: 06-28 14:49 measurestats INFO Rmax = 120

    import logging
    import time

    levels = {
        "info": logging.INFO,
        "debug": logging.DEBUG,
        "warning": logging.WARNING,
    }

    logger = logging.getLogger()
    t0 = time.time()
    rank = MPI.COMM_WORLD.rank

    class Formatter(logging.Formatter):
        def format(self, record):
            s1 = ('[ %09.2f ] % 3d: ' % (time.time() - t0, rank))
            return s1 + logging.Formatter.format(self, record)

    fmt = Formatter(fmt='%(asctime)s %(name)-15s %(levelname)-8s %(message)s',
                    datefmt='%m-%d %H:%M ')

    global _logging_handler
    if _logging_handler is None:
        _logging_handler = logging.StreamHandler()
        logger.addHandler(_logging_handler)

    _logging_handler.setFormatter(fmt)
    logger.setLevel(levels[log_level])

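Typical usage, producing log lines in the format shown above:

.. code:: python

    from nbodykit import setup_logging

    setup_logging('debug')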