Source code for nbodykit.algorithms.cgm

import numpy
import logging
import kdcount
import mpsort
import pandas as pd
from six import string_types
import warnings

from nbodykit import CurrentMPIComm
from nbodykit.source.catalog import ArrayCatalog

class CylindricalGroups(object):
    """
    Compute groups of objects using a cylindrical grouping method. We identify
    all satellites within a given cylindrical volume around a central object.

    Results are computed when the object is initialized, and the result is
    stored in the :attr:`groups` attribute; see the documentation of
    :func:`~CylindricalGroups.run`.

    Input parameters are stored in the :attr:`attrs` attribute dictionary.

    Parameters
    ----------
    source : subclass of :class:`~nbodykit.base.catalog.CatalogSource`
        the input source of particles providing the 'Position' column; the
        grouping algorithm is run on this catalog
    rankby : str, list, ``None``
        a single or list of column names to rank order the input source by
        before computing the cylindrical groups, such that objects ranked
        first are marked as CGM centrals; if ``None`` is supplied, no sorting
        will be done
    rperp : float
        the radius of the cylinder in the sky plane (i.e., perpendicular to
        the line-of-sight)
    rpar : float
        the radius along the line-of-sight direction; this is 1/2 the height
        of the cylinder
    flat_sky_los : array_like of length 3, optional
        a unit vector providing the line-of-sight direction, assuming a fixed
        line-of-sight across the box, e.g., [0, 0, 1] to use the z-axis; if
        ``None``, the observer at (0,0,0) is used to compute the line-of-sight
        for each pair
    periodic : bool, optional
        whether to use periodic boundary conditions
    BoxSize : float, 3-vector, optional
        the size of the box of the input data; must be provided as a keyword
        or in ``source.attrs`` if ``periodic=True``

    References
    ----------
    Okumura, Teppei, et al. "Reconstruction of halo power spectrum from
    redshift-space galaxy distribution: cylinder-grouping method and halo
    exclusion effect", arXiv:1611.04165, 2016.
    """
    logger = logging.getLogger('CylindricalGroups')

    def __init__(self, source, rankby, rperp, rpar, flat_sky_los=None,
                 periodic=False, BoxSize=None):

        if 'Position' not in source:
            raise ValueError("the 'Position' column must be defined in the input source")

        if rankby is None:
            rankby = []
        if isinstance(rankby, string_types):
            rankby = [rankby]
        for col in rankby:
            if col not in source:
                raise ValueError("cannot rank by column '%s'; no such column" % col)

        self.source = source
        self.comm = source.comm
        self.attrs = {}

        # need BoxSize
        self.attrs['BoxSize'] = numpy.empty(3)
        BoxSize = source.attrs.get('BoxSize', BoxSize)
        if periodic and BoxSize is None:
            raise ValueError("please specify a BoxSize if using periodic boundary conditions")
        self.attrs['BoxSize'][:] = BoxSize

        # LOS must be a unit vector
        if flat_sky_los is not None:
            if numpy.isscalar(flat_sky_los) or len(flat_sky_los) != 3:
                raise ValueError("line-of-sight ``flat_sky_los`` should be a vector of length 3")
            if not numpy.allclose(numpy.einsum('i,i', flat_sky_los, flat_sky_los), 1.0, rtol=1e-5):
                raise ValueError("line-of-sight ``flat_sky_los`` must be a unit vector")

        # warn if periodic and LOS is None
        if flat_sky_los is None and periodic:
            warnings.warn(("CylindricalGroups using periodic boundary conditions "
                           "with line-of-sight computed from origin (0,0,0); maybe specify a line-of-sight?"))

        # save meta-data
        self.attrs['rpar'] = rpar
        self.attrs['rperp'] = rperp
        self.attrs['periodic'] = periodic
        self.attrs['rankby'] = rankby
        self.attrs['flat_sky_los'] = flat_sky_los

        # log some info
        if self.comm.rank == 0:
            args = (str(rperp), str(rpar))
            self.logger.info("finding groups with rperp=%s and rpar=%s" % args)
            if flat_sky_los is None:
                self.logger.info("  using line-of-sight computed using observer at origin (0,0,0)")
            else:
                self.logger.info("  using line-of-sight vector %s" % str(flat_sky_los))
            msg = "periodic" if periodic else "non-periodic"
            msg = "  using %s boundary conditions" % msg
            if BoxSize is not None:
                msg += " (BoxSize = %s)" % str(self.attrs['BoxSize'])
            self.logger.info(msg)

        self.run()
    def run(self):
        """
        Compute the cylindrical groups, saving the results to the
        :attr:`groups` attribute.

        Attributes
        ----------
        groups : :class:`~nbodykit.source.catalog.array.ArrayCatalog`
            a catalog holding the result of the grouping. The length of the
            catalog is equal to the length of the input source, i.e., the
            length is equal to the :attr:`size` attribute. The relevant
            fields are:

            #. cgm_type :
                a flag specifying the type for each object, with 0 specifying
                CGM central and 1 denoting CGM satellite
            #. cgm_haloid :
                the index of the CGM object this object belongs to; an
                integer between 0 and the total number of CGM halos
            #. num_cgm_sats :
                the number of satellites in the CGM halo
        """
        from pmesh.domain import GridND
        from nbodykit.algorithms.fof import split_size_3d

        comm = self.comm
        rperp, rpar = self.attrs['rperp'], self.attrs['rpar']
        rankby = self.attrs['rankby']

        if self.attrs['periodic']:
            boxsize = self.attrs['BoxSize']
        else:
            boxsize = None

        np = split_size_3d(self.comm.size)
        if self.comm.rank == 0:
            self.logger.info("using cpu grid decomposition: %s" % str(np))

        # add a column for the original index
        self.source['origind'] = self.source.Index

        # sort the data
        data = self.source.sort(rankby, usecols=['Position', 'origind'])

        # add a column to track the sorted index
        data['sortindex'] = data.Index

        # global min/max across all ranks
        pos = data.compute(data['Position'])
        posmin = numpy.asarray(comm.allgather(pos.min(axis=0))).min(axis=0)
        posmax = numpy.asarray(comm.allgather(pos.max(axis=0))).max(axis=0)

        # domain decomposition
        grid = [
            numpy.linspace(posmin[0], posmax[0], np[0] + 1, endpoint=True),
            numpy.linspace(posmin[1], posmax[1], np[1] + 1, endpoint=True),
            numpy.linspace(posmin[2], posmax[2], np[2] + 1, endpoint=True),
        ]
        domain = GridND(grid, comm=comm)

        # run the CGM algorithm
        groups = cgm(comm, data, domain, rperp, rpar, self.attrs['flat_sky_los'], boxsize)

        # make the final structured array
        self.groups = ArrayCatalog(groups, comm=self.comm, **self.attrs)

        # log some info
        N_cen = (groups['cgm_type'] == 0).sum()
        isolated_N_cen = ((groups['cgm_type'] == 0) & (groups['num_cgm_sats'] == 0)).sum()
        N_cen = self.comm.allreduce(N_cen)
        isolated_N_cen = self.comm.allreduce(isolated_N_cen)
        if self.comm.rank == 0:
            self.logger.info("found %d CGM centrals total" % N_cen)
            self.logger.info("%d/%d are isolated centrals (no satellites)" % (isolated_N_cen, N_cen))

        # delete the column we added to the source
        del self.source['origind']
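# A minimal usage sketch (illustrative, not part of the original module): run
# the grouping on a random catalog. The catalog, the hypothetical 'Mass'
# ranking column, and the cylinder dimensions are all assumptions chosen for
# demonstration.
def _example_cylindrical_groups():
    from nbodykit.lab import UniformCatalog

    cat = UniformCatalog(nbar=100, BoxSize=1.0, seed=42)
    cat['Mass'] = numpy.random.uniform(size=cat.size)  # hypothetical ranking column

    # rank by the hypothetical 'Mass' column before grouping
    grp = CylindricalGroups(cat, rankby='Mass', rperp=0.05, rpar=0.1,
                            flat_sky_los=[0, 0, 1], periodic=True)
    return grp.groups  # catalog with cgm_type, cgm_haloid, num_cgm_sats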
def cgm(comm, data, domain, rperp, rpar, los, boxsize):
    """
    Perform the cylindrical grouping method.

    This outputs a structured array with the same length as the input data
    with the following fields for each object in the original data:

    #. cgm_type :
        a flag specifying the type for each object, with 0 specifying CGM
        central and 1 denoting CGM satellite
    #. cgm_haloid :
        the index of the CGM object this object belongs to; an integer
        between 0 and the total number of CGM halos
    #. num_cgm_sats :
        the number of satellites in the CGM halo

    Parameters
    ----------
    comm :
        the MPI communicator
    data : CatalogSource
        catalog with sorted input data, including Position
    domain :
        the domain decomposition
    rperp, rpar : float
        the maximum distances to group objects together in the directions
        perpendicular and parallel to the line-of-sight; the cylinder has
        radius ``rperp`` and height ``2 * rpar``
    los :
        the line-of-sight vector
    boxsize :
        the boxsize, or ``None`` if not using periodic boundary conditions
    """
    # whether we use periodic boundary conditions
    periodic = boxsize is not None
    flat_sky = los is not None

    # the maximum distance still inside the cylinder set by rperp, rpar
    rperp2 = rperp**2; rpar2 = rpar**2
    rmax = (rperp2 + rpar2)**0.5

    pos0, origind0, sortindex0 = data.compute(data['Position'], data['origind'], data['sortindex'])

    layout1 = domain.decompose(pos0, smoothing=0)
    pos1 = layout1.exchange(pos0)
    origind1 = layout1.exchange(origind0)
    sortindex1 = layout1.exchange(sortindex0)

    # exchange particles across ranks, accounting for the smoothing radius
    layout2 = domain.decompose(pos1, smoothing=rmax)
    pos2 = layout2.exchange(pos1)
    origind2 = layout2.exchange(origind1)
    sortindex2 = layout2.exchange(sortindex1)
    startrank = layout2.exchange(numpy.ones(len(pos1), dtype='i4') * comm.rank)

    # make the KD-trees
    tree1 = kdcount.KDTree(pos1, boxsize=boxsize).root
    tree2 = kdcount.KDTree(pos2, boxsize=boxsize).root

    dataframe = []
    j_gt_i = numpy.zeros(len(pos1), dtype='f4')
    wrong_rank = numpy.zeros(len(pos1), dtype='f4')

    def callback(r, i, j):

        r1 = pos1[i]
        r2 = pos2[j]
        dr = r1 - r2

        # enforce periodicity in dr
        if periodic:
            for axis, col in enumerate(dr.T):
                col[col > boxsize[axis]*0.5] -= boxsize[axis]
                col[col <= -boxsize[axis]*0.5] += boxsize[axis]

        # the line-of-sight distance
        if flat_sky:
            rlos2 = numpy.einsum("ij,j->i", dr, los)**2
        else:
            center = 0.5 * (r1 + r2)
            dot2 = numpy.einsum('ij,ij->i', dr, center)**2
            center2 = numpy.einsum('ij,ij->i', center, center)
            rlos2 = dot2 / center2

        # the sky distance
        dr2 = numpy.einsum('ij,ij->i', dr, dr)
        rsky2 = numpy.abs(dr2 - rlos2)

        # keep only the valid pairs: those inside the cylinder
        # (compare to rperp and rpar)
        valid = (rsky2 <= rperp2) & (rlos2 <= rpar2)
        i = i[valid]; j = j[valid]

        # the correctly sorted indices of the particles
        sort_i = sortindex1[i]
        sort_j = sortindex2[j]

        # the rank where the j object lives
        rank_j = startrank[j]

        # track pairs where sorted j > sorted i
        weights = numpy.where(sort_i < sort_j, 1, 0)
        j_gt_i[:] += numpy.bincount(i, weights=weights, minlength=len(pos1))

        # track pairs where the j rank is wrong
        weights *= numpy.where(rank_j != comm.rank, 1, 0)
        wrong_rank[:] += numpy.bincount(i, weights=weights, minlength=len(pos1))

        # save the valid pairs for the final calculations
        res = numpy.vstack([i, j, sort_i, sort_j]).T
        dataframe.append(res)

    # add all the valid pairs to a dataframe
    tree1.enum(tree2, rmax, process=callback)

    # sorted indices of objects that are centrals
    # (objects with no pairs with j > i)
    centrals = set(sortindex1[(j_gt_i == 0)])

    # sorted indices of objects that might be centrals
    # (pairs with j > i that live on other ranks)
    maybes = set(sortindex1[(wrong_rank > 0)])

    # store the pairs in a pandas dataframe for fast groupby
    dataframe = numpy.concatenate(dataframe, axis=0)
    df = pd.DataFrame(dataframe, columns=['i', 'j', 'sort_i', 'sort_j'])

    # sort by the correct sorted index in descending order, which puts the
    # highest priority objects first
    df.sort_values("sort_i", ascending=False, inplace=True)

    # index by the correct sorted order
    df.set_index('sort_i', inplace=True)

    # to find centrals, consider objects that could be satellites of another
    # (pairs with sort_j > sort_i)
    possible_cens = df[(df['sort_j'] > df.index.values)]
    possible_cens = possible_cens.drop(centrals, errors='ignore')
    _remove_objects_paired_with(possible_cens, centrals)  # remove objects paired with centrals

    # sorted indices of objects that have pairs on other ranks;
    # these objects are already "maybe" centrals
    on_other_ranks = sortindex1[(wrong_rank > 0)]

    # find the centrals and the associated halo labels for each central
    all_centrals, labels = _find_centrals(comm, possible_cens, on_other_ranks, centrals, maybes)

    # reset the index
    df.reset_index(inplace=True)

    # add the halo labels for each pair in the dataframe
    labels = pd.Series(labels, name='label_i', index=pd.Index(all_centrals, name='sort_i'))
    df = df.join(labels, on='sort_i')
    labels.name = 'label_j'; labels.index.name = 'sort_j'
    df = df.join(labels, on='sort_j')

    # initialize the output arrays
    labels = numpy.zeros(len(pos1), dtype='i8') - 1  # indexed by i
    types = numpy.zeros(len(pos1), dtype='u4')       # indexed by i
    counts = numpy.zeros(len(pos2), dtype='i8')      # indexed by j

    # assign the labels of the centrals
    cens = df.dropna(subset=['label_i']).drop_duplicates('i')
    labels[cens['i'].values] = cens['label_i'].values

    # objects on this rank that are satellites
    # (no label for the 1st object in the pair but a label for the 2nd object)
    sats = (df['label_i'].isnull()) & (~df['label_j'].isnull())
    df = df[sats]

    # find the corresponding central for each satellite
    df = df.sort_values('sort_j', ascending=False)
    df.set_index('sort_i', inplace=True)
    sats_grouped = df.groupby('sort_i', sort=False, as_index=False)
    centrals = sats_grouped.first()  # these are the centrals for each satellite

    # update the satellite info with its highest priority pair
    cens_i = centrals['i'].values; cens_j = centrals['j'].values
    counts += numpy.bincount(cens_j, minlength=len(pos2))
    types[cens_i] = 1
    labels[cens_i] = centrals['label_j'].values

    # sum the counts across ranks (take the sum of any repeated objects)
    counts = layout2.gather(counts, mode='sum')

    # output fields
    dtype = numpy.dtype([('cgm_haloid', 'i8'),
                         ('num_cgm_sats', 'i8'),
                         ('cgm_type', 'u4'),
                         ('origind', 'u4')])
    out = numpy.empty(len(data), dtype=dtype)

    # gather the data back onto the original ranks
    # (no ghosts for this domain layout, so choose any particle)
    out['cgm_haloid'] = layout1.gather(labels, mode='any')
    out['origind'] = layout1.gather(origind1, mode='any')
    out['num_cgm_sats'] = layout1.gather(counts, mode='any')
    out['cgm_type'] = layout1.gather(types, mode='any')

    # restore the original order
    mpsort.sort(out, orderby='origind', comm=comm)

    fields = ['cgm_type', 'cgm_haloid', 'num_cgm_sats']
    return out[fields]
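# A standalone sketch (illustrative, not part of the original module) of the
# flat-sky cylinder test applied in ``callback`` above: a pair is grouped
# when its separation parallel to the line-of-sight is within ``rpar`` and
# its transverse separation is within ``rperp``. All values here are made up.
def _example_cylinder_test():
    los = numpy.array([0., 0., 1.])                      # flat-sky line-of-sight
    r1 = numpy.array([[0., 0., 0.], [0., 0., 0.]])       # first object of each pair
    r2 = numpy.array([[0.03, 0., 0.08], [0.2, 0., 0.]])  # second object of each pair
    dr = r1 - r2

    rlos2 = numpy.einsum('ij,j->i', dr, los)**2          # parallel separation squared
    dr2 = numpy.einsum('ij,ij->i', dr, dr)               # total separation squared
    rsky2 = numpy.abs(dr2 - rlos2)                       # transverse separation squared

    rperp, rpar = 0.05, 0.1
    return (rsky2 <= rperp**2) & (rlos2 <= rpar**2)      # -> array([True, False])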
def _remove_objects_paired_with(df, bad_pairs):
    """
    Remove any objects that are paired with an object in ``bad_pairs``.

    This is done in place.
    """
    assert df.index.name == 'sort_i'
    df.reset_index(inplace=True)
    df.set_index('sort_j', inplace=True)

    # an exception could be raised if no pairs need to be dropped,
    # so just restore the index and return
    try:
        bad_pair_index = df.index.intersection(list(bad_pairs)).unique()
        to_drop = df.loc[bad_pair_index]['sort_i'].values
        df.reset_index(inplace=True)
        df.set_index('sort_i', inplace=True)
        df.drop(to_drop, inplace=True)
    except Exception:
        # reset first so the 'sort_j' column is preserved when re-indexing
        df.reset_index(inplace=True)
        df.set_index('sort_i', inplace=True)


def _find_centrals(comm, df, on_other_ranks, centrals, maybes):
    """
    Find the sorted index values of all of the centrals.

    Each rank determines the local centrals (which have no pairs on other
    ranks), and then root searches the objects spread out across multiple
    ranks.

    Returns
    -------
    all_centrals : list
        the sorted index values of all centrals
    labels : list
        the corresponding labels, ranging from ``0`` to
        ``len(all_centrals) - 1``
    """
    def find_local_centrals(grp):

        cenid = grp.index.values[0]

        # group numbers of the centrals that could host this object
        maybe_grp_nums = grp['sort_j'].values

        # this object is a satellite
        if any(num in centrals for num in maybe_grp_nums):
            return

        # this object could be a satellite
        if any(num in maybes for num in maybe_grp_nums):
            maybes.add(cenid)
            return

        # if we get here, this object is definitely a central
        centrals.add(cenid)

    # we only need to examine objects that have all higher priority pairs on
    # the same rank --> if they have pairs on other ranks, then they are
    # already marked as "maybe" centrals
    same_rank_df = df.drop(on_other_ranks, errors='ignore')

    # group by the sorted index and find the local centrals
    # (these are objects with no higher priority pairs)
    cens_grouped = same_rank_df.groupby('sort_i', sort=False)
    cens_grouped.apply(find_local_centrals)

    # the pairs associated with objects that might be satellites
    maybe_index = df.index.intersection(list(maybes)).unique()
    maybe_cen_groups = df.loc[maybe_index]

    # gather the data on the maybes
    maybes_data = comm.gather(maybe_cen_groups)
    all_centrals = numpy.concatenate(comm.allgather(list(centrals)), axis=0)

    # root identifies the remaining centrals
    if comm.rank == 0:

        # keep track of the new centrals
        new_centrals = set()
        all_maybes = pd.concat(maybes_data)

        # remove objects paired with a central
        _remove_objects_paired_with(all_maybes, all_centrals)

        # consider objects in sorted order
        all_maybes.sort_index(ascending=False, inplace=True)

        def finalize(grp):
            grp_nums = grp['sort_j'].values
            if not any(num in new_centrals for num in grp_nums):
                new_centrals.add(grp.index.values[0])

        # find out which of the maybes are actually centrals
        maybes_grouped = all_maybes.groupby('sort_i', sort=False)
        maybes_grouped.apply(finalize)
    else:
        new_centrals = None

    # get the list of all centrals on all ranks
    new_centrals = comm.bcast(new_centrals)
    all_centrals = numpy.append(all_centrals, list(new_centrals))

    # sort in descending order (in place) and create unique halo labels
    all_centrals[::-1].sort()
    labels = numpy.arange(0, len(all_centrals), dtype='i4')

    return all_centrals, labels
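# A toy demonstration (illustrative, not part of the original module) of
# ``_remove_objects_paired_with``: object 3 is paired with the "bad" object 5,
# so every pair whose ``sort_i`` is 3 is dropped in place. The index values
# here are made up.
def _example_remove_paired():
    df = pd.DataFrame({'sort_i': [3, 2], 'sort_j': [5, 4]}).set_index('sort_i')
    _remove_objects_paired_with(df, [5])
    return df  # only the sort_i=2 row survives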
def data_to_sort_key(data):
    """
    Convert floating type data to unique integers for sorting.
    """
    # note: numpy.frombuffer replaces the deprecated numpy.fromstring here
    return numpy.frombuffer(data.tobytes(), dtype='u8')
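# A quick sketch (illustrative, not part of the original module) of
# ``data_to_sort_key``: the raw float64 bits are reinterpreted as unsigned
# 64-bit integers, so equal floats map to equal keys and distinct floats map
# to distinct keys.
def _example_sort_key():
    data = numpy.array([1.5, 2.5, 1.5])
    keys = data_to_sort_key(data)
    return keys  # dtype uint64; keys[0] == keys[2]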