Add support for distributed cholla datasets. #4702

Open
wants to merge 1 commit into base: main
130 changes: 116 additions & 14 deletions yt/frontends/cholla/data_structures.py
@@ -5,19 +5,50 @@

from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch
from yt.data_objects.static_output import Dataset
from yt.funcs import setdefaultattr
from yt.funcs import get_pbar, setdefaultattr
from yt.geometry.api import Geometry
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.on_demand_imports import _h5py as h5py

from .fields import ChollaFieldInfo


def _split_fname_proc_suffix(filename: str):
Member: Could you put a short note about how this is different from os.path.splitext? Just to avoid future confusion.
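For reference, a small hypothetical snippet (not part of the diff) contrasting the two: os.path.splitext always splits at the last '.' and keeps the dot, whereas this helper only splits when the trailing piece is a pure decimal process id.

import os

os.path.splitext("0.h5.0")            # ('0.h5', '.0')  splits unconditionally and keeps the '.'
os.path.splitext("0.h5")              # ('0', '.h5')    would treat '.h5' itself as the suffix
_split_fname_proc_suffix("0.h5.0")    # ('0.h5', '0')
_split_fname_proc_suffix("0.h5")      # ('0.h5', '')    no decimal suffix, so nothing is split off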

"""Splits ``filename`` at the '.' separating the beginning part of the
string from the process-id suffix, and returns both parts in a 2-tuple.

When cholla is compiled with MPI and it directly writes data-files, each
process appends a suffix to each filename that denotes the process-id. For
example, the MPI-compiled version might write '0.h5.0'. If this function is
passed such a string, it returns ``('0.h5', '0')``.

In cases where there is no suffix, the output is ``(filename, '')``. This
might come up if the user concatenated the output files, which is common
practice.
"""

# At this time, we expect the suffix to be the minimum number of characters
# needed to represent the process id. For flexibility, we also allow extra
# zero-padding

sep_i = filename.rfind(".")
suf_len = len(filename) - (sep_i + 1)
if (sep_i == -1) or (suf_len == 0) or not filename[sep_i + 1 :].isdecimal():
return (filename, "")
elif (sep_i == 0) or ((sep_i - 1) == filename.rfind("/")):
raise ValueError(
f"can't split a process-suffix off of {filename!r} "
"since the remaining filename would be empty"
)
else:
return (filename[:sep_i], filename[sep_i + 1 :])


class ChollaGrid(AMRGridPatch):
_id_offset = 0

def __init__(self, id, index, level, dims):
super().__init__(id, filename=index.index_filename, index=index)
def __init__(self, id, index, level, dims, filename):
super().__init__(id, filename=filename, index=index)
self.Parent = None
self.Children = []
self.Level = level
@@ -42,23 +73,92 @@ def _detect_output_fields(self):
self.field_list = [("cholla", k) for k in h5f.keys()]

def _count_grids(self):
self.num_grids = 1
# the number of grids is equal to the number of processes, unless the
# dataset has been concatenated. But, when the dataset is concatenated
# (a common post-processing step), the "nprocs" hdf5 attribute is
# usually dropped.

with h5py.File(self.index_filename, mode="r") as h5f:
nprocs = h5f.attrs.get("nprocs", np.array([1, 1, 1]))[:].astype("=i8")
self.num_grids = np.prod(nprocs)

if self.num_grids > 1:
# When there's more than 1 grid, we expect the user to
# - have not changed the names of the output files
# - have passed the file written by process 0 to ``yt.load``
# Let's perform a sanity-check that self.index_filename has the
# expected suffix for a file written by mpi-process 0
if int(_split_fname_proc_suffix(self.index_filename)[1]) != 0:
raise ValueError(
"the primary file associated with a "
"distributed cholla dataset must end in '.0'"
)
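
Aside (not part of the diff): a minimal usage sketch of what the logic above expects, assuming cholla's default '<step>.h5.<rank>' naming.

import yt

# pass the file written by MPI rank 0; the other ranks' files are located
# from the shared prefix and the "nprocs" attribute
ds = yt.load("0.h5.0")

# a concatenated dataset (single file, no rank suffix) loads as a single grid
ds_cat = yt.load("0.h5")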

def _parse_index(self):
self.grid_left_edge[0][:] = self.ds.domain_left_edge[:]
self.grid_right_edge[0][:] = self.ds.domain_right_edge[:]
self.grid_dimensions[0][:] = self.ds.domain_dimensions[:]
self.grid_particle_count[0][0] = 0
self.grid_levels[0][0] = 0
self.grids = np.empty(self.num_grids, dtype="object")

# construct an iterable over the pairs of grid-index and corresponding
# filename
if self.num_grids == 1:
ind_fname_pairs = [(0, self.index_filename)]
else:
# index_filename should have the form f'{self.directory}/<prefix>.0'
# strip off the '.0' and determine the contents of <prefix>
pref, suf = _split_fname_proc_suffix(self.index_filename)
assert int(suf) == 0 # sanity check!

ind_fname_pairs = ((i, f"{pref}.{i}") for i in range(self.num_grids))

dims_global = self.ds.domain_dimensions[:]
pbar = get_pbar("Parsing Hierarchy", self.num_grids)

# It would be nice if we could avoid reading in every hdf5 file during
# this step... (to do this, Cholla could probably encode how the blocks
# are sorted in an hdf5 attribute)

for i, fname in ind_fname_pairs:
if self.num_grids == 1:
# if the file was concatenated, we might be missing attributes
# that are accessed in the other branch. To avoid issues, we use
# hardcoded values
left_frac, right_frac, dims_local = 0.0, 1.0, dims_global
else:
with h5py.File(fname, "r") as f:
offset = f.attrs["offset"][:].astype("=i8")
dims_local = f.attrs["dims_local"][:].astype("=i8")
left_frac = offset / dims_global
right_frac = (offset + dims_local) / dims_global

level = 0

self.grids[i] = self.grid(
i,
index=self,
level=level,
dims=dims_local,
filename=fname,
)

self.grid_left_edge[i] = left_frac
self.grid_right_edge[i] = right_frac
self.grid_dimensions[i] = dims_local
Member (on lines +142 to +144): Just for clarity, could we make it obvious that it's setting a slice to the values?

Suggested change:
- self.grid_left_edge[i] = left_frac
- self.grid_right_edge[i] = right_frac
- self.grid_dimensions[i] = dims_local
+ self.grid_left_edge[i,:] = left_frac
+ self.grid_right_edge[i,:] = right_frac
+ self.grid_dimensions[i,:] = dims_local

self.grid_levels[i, 0] = level
self.grid_particle_count[i, 0] = 0

pbar.update(i + 1)
pbar.finish()

slope = self.ds.domain_width / self.ds.arr(np.ones(3), "code_length")
self.grid_left_edge = self.grid_left_edge * slope + self.ds.domain_left_edge
self.grid_right_edge = self.grid_right_edge * slope + self.ds.domain_left_edge

self.max_level = 0

def _populate_grid_objects(self):
self.grids = np.empty(self.num_grids, dtype="object")
for i in range(self.num_grids):
g = self.grid(i, self, self.grid_levels.flat[i], self.grid_dimensions[i])
g = self.grids[i]
g._prepare_grid()
g._setup_dx()
self.grids[i] = g


class ChollaDataset(Dataset):
@@ -103,9 +203,11 @@ def _parse_parameter_file(self):
attrs = h5f.attrs
self.parameters = dict(attrs.items())
self.domain_left_edge = attrs["bounds"][:].astype("=f8")
self.domain_right_edge = attrs["domain"][:].astype("=f8")
self.domain_right_edge = self.domain_left_edge + attrs["domain"][:].astype(
"=f8"
)
self.dimensionality = len(attrs["dims"][:])
self.domain_dimensions = attrs["dims"][:].astype("=f8")
self.domain_dimensions = attrs["dims"][:].astype("=i8")
self.current_time = attrs["t"][:]
self._periodicity = tuple(attrs.get("periodicity", (False, False, False)))
self.gamma = attrs.get("gamma", 5.0 / 3.0)
35 changes: 18 additions & 17 deletions yt/frontends/cholla/io.py
@@ -1,4 +1,3 @@
import numpy as np

from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.on_demand_imports import _h5py as h5py
@@ -14,22 +13,24 @@ def _read_particle_coords(self, chunks, ptf):
def _read_particle_fields(self, chunks, ptf, selector):
raise NotImplementedError

def _read_fluid_selection(self, chunks, selector, fields, size):
data = {}
for field in fields:
data[field] = np.empty(size, dtype="float64")

with h5py.File(self.ds.parameter_filename, "r") as fh:
ind = 0
for chunk in chunks:
for grid in chunk.objs:
nd = 0
for field in fields:
ftype, fname = field
values = fh[fname][:].astype("=f8")
nd = grid.select(selector, values, data[field], ind)
ind += nd
return data
def io_iter(self, chunks, fields):
# this is loosely inspired by the implementation used for Enzo/Enzo-E
# - those frontends use the lower-level hdf5 interface. It's unclear
# whether that affords any advantages...
Member: Good question. I think in the past it did because we avoided having to re-allocate temporary scratch space, but I am not sure that would hold up to current inquiries. I think the big advantage those have is tracking the groups within the iteration.
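For context, a rough sketch (not from this PR; the filename and field name are assumed) of what the lower-level route looks like; the preallocated buffer is what avoids the temporary scratch space:

import numpy as np
from yt.utilities.on_demand_imports import _h5py as h5py

fid = h5py.h5f.open(b"0.h5.0", h5py.h5f.ACC_RDONLY)  # low-level file handle
dset = h5py.h5d.open(fid, b"density")                # low-level dataset handle
buf = np.empty(dset.shape, dtype=dset.dtype)         # allocate the buffer once
dset.read(h5py.h5s.ALL, h5py.h5s.ALL, buf)           # read directly into buf
fid.close()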

fh, filename = None, None
for chunk in chunks:
for obj in chunk.objs:
if obj.filename is None: # unclear when this case arises...
Member: likely it will not here, unless you manually construct virtual grids

Contributor (author): Out of curiosity, what is a virtual grid?

I realize this may be an involved answer - so if you could just point me to a frontend (or other area of the code) using virtual grids, I can probably investigate that on my own.

continue
elif obj.filename != filename:
if fh is not None:
fh.close()
fh, filename = h5py.File(obj.filename, "r"), obj.filename
for field in fields:
ftype, fname = field
yield field, obj, fh[fname][:].astype("=f8")
if fh is not None:
fh.close()

def _read_chunk_data(self, chunk, fields):
raise NotImplementedError