Source code for nbodykit.source.catalog.array
from nbodykit.base.catalog import CatalogSource
from nbodykit.utils import is_structured_array
from nbodykit import CurrentMPIComm
from astropy.table import Table
import numpy
[docs]class ArrayCatalog(CatalogSource):
"""
A CatalogSource initialized from an in-memory :obj:`dict`,
structured :class:`numpy.ndarray`, or :class:`astropy.table.Table`.
See :ref:`the documentation <array-data>` for examples.
Parameters
----------
data : obj:`dict`, :class:`numpy.ndarray`, :class:`astropy.table.Table`
a dictionary, structured ndarray, or astropy Table; items are
interpreted as the columns of the catalog; the length of any item is
used as the size of the catalog.
comm : MPI Communicator, optional
the MPI communicator instance; default (``None``) sets to the
current communicator
**kwargs :
additional keywords to store as meta-data in :attr:`attrs`
"""
@CurrentMPIComm.enable
def __init__(self, data, comm=None, **kwargs):
# convert astropy Tables to structured numpy arrays
if isinstance(data, Table):
data = data.as_array()
# check for structured data
if not isinstance(data, dict):
if not is_structured_array(data):
raise ValueError(("input data to ArrayCatalog must have a "
"structured data type with fields"))
self.comm = comm
self._source = data
# compute the data type
if hasattr(data, 'dtype'):
keys = sorted(data.dtype.names)
else:
keys = sorted(data.keys())
dtype = numpy.dtype([(key, (data[key].dtype, data[key].shape[1:])) for key in keys])
self._dtype = dtype
# verify data types are the same
dtypes = self.comm.gather(dtype, root=0)
if self.comm.rank == 0:
if any(dt != dtypes[0] for dt in dtypes):
raise ValueError("mismatch between dtypes across ranks in Array")
# the local size
self._size = len(self._source[keys[0]])
for key in keys:
if len(self._source[key]) != self._size:
raise ValueError("column `%s` and column `%s` has different size" % (keys[0], key))
# update the meta-data
self.attrs.update(kwargs)
CatalogSource.__init__(self, comm=comm)
@property
def hardcolumns(self):
"""
The union of the columns in the file and any transformed columns.
"""
defaults = CatalogSource.hardcolumns.fget(self)
return list(self._dtype.names) + defaults
[docs] def get_hardcolumn(self, col):
"""
Return a column from the underlying data array/dict.
Columns are returned as dask arrays.
"""
if col in self._dtype.names:
return self.make_column(self._source[col])
else:
return CatalogSource.get_hardcolumn(self, col)