Source code for nbodykit
from .version import __version__
from mpi4py import MPI
# prevents too many threads exception when using MPI and dask
import dask
dask.set_options(get=dask.get)
[docs]class CurrentMPIComm(object):
"""
A class to faciliate getting and setting the current MPI communicator.
"""
_instance = None
[docs] @staticmethod
def enable(func):
"""
Decorator to attach the current MPI communicator to the input
keyword arguments of ``func``, via the ``comm`` keyword.
"""
import functools
@functools.wraps(func)
def wrapped(*args, **kwargs):
kwargs.setdefault('comm', None)
if kwargs['comm'] is None:
kwargs['comm'] = CurrentMPIComm.get()
return func(*args, **kwargs)
return wrapped
[docs] @classmethod
def get(cls):
"""
Get the current MPI communicator, returning ``MPI.COMM_WORLD``
if it has not be explicitly set yet.
"""
# initialize MPI and set the comm if we need to
if not cls._instance:
comm = MPI.COMM_WORLD
cls._instance = comm
return cls._instance
[docs] @classmethod
def set(cls, comm):
"""
Set the current MPI communicator to the input value.
"""
cls._instance = comm
_logging_handler = None
[docs]def setup_logging(log_level="info"):
"""
Turn on logging, with the specified level.
Parameters
----------
log_level : 'info', 'debug', 'warning'
the logging level to set; logging below this level is ignored
"""
# This gives:
#
# [ 000000.43 ] 0: 06-28 14:49 measurestats INFO Nproc = [2, 1, 1]
# [ 000000.43 ] 0: 06-28 14:49 measurestats INFO Rmax = 120
import logging
levels = {
"info" : logging.INFO,
"debug" : logging.DEBUG,
"warning" : logging.WARNING,
}
import time
logger = logging.getLogger();
t0 = time.time()
rank = MPI.COMM_WORLD.rank
class Formatter(logging.Formatter):
def format(self, record):
s1 = ('[ %09.2f ] % 3d: ' % (time.time() - t0, rank))
return s1 + logging.Formatter.format(self, record)
fmt = Formatter(fmt='%(asctime)s %(name)-15s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M ')
global _logging_handler
if _logging_handler is None:
_logging_handler = logging.StreamHandler()
logger.addHandler(_logging_handler)
_logging_handler.setFormatter(fmt)
logger.setLevel(levels[log_level])