-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add results to xugrid #1369
Add results to xugrid #1369
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -264,10 +264,10 @@ | |
# TODO | ||
# self.validate_model() | ||
filepath = Path(filepath) | ||
self.filepath = filepath | ||
if not filepath.suffix == ".toml": | ||
raise ValueError(f"Filepath '{filepath}' is not a .toml file.") | ||
context_file_loading.set({}) | ||
filepath = Path(filepath) | ||
directory = filepath.parent | ||
directory.mkdir(parents=True, exist_ok=True) | ||
self._save(directory, self.input_dir) | ||
|
@@ -280,7 +280,7 @@ | |
def _load(cls, filepath: Path | None) -> dict[str, Any]: | ||
context_file_loading.set({}) | ||
|
||
if filepath is not None: | ||
if filepath is not None and filepath.is_file(): | ||
with open(filepath, "rb") as f: | ||
config = tomli.load(f) | ||
|
||
|
@@ -395,9 +395,10 @@ | |
|
||
return ax | ||
|
||
def to_xugrid(self): | ||
def to_xugrid(self, add_results: bool = True): | ||
""" | ||
Convert the network to a `xugrid.UgridDataset`. | ||
Convert the network and results to a `xugrid.UgridDataset`. | ||
To get the network only, set `add_results=False`. | ||
This method will throw `ImportError`, | ||
if the optional dependency `xugrid` isn't installed. | ||
""" | ||
|
@@ -449,4 +450,63 @@ | |
uds = uds.assign_coords(from_node_id=(edge_dim, from_node_id)) | ||
uds = uds.assign_coords(to_node_id=(edge_dim, to_node_id)) | ||
|
||
if add_results: | ||
uds = self._add_results(uds) | ||
|
||
return uds | ||
|
||
def _add_results(self, uds): | ||
toml_path = self.filepath | ||
if toml_path is None: | ||
raise FileNotFoundError("Model must be written to disk to add results.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does The error happens when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't |
||
|
||
results_path = toml_path.parent / self.results_dir | ||
basin_path = results_path / "basin.arrow" | ||
flow_path = results_path / "flow.arrow" | ||
|
||
if not basin_path.is_file() or not flow_path.is_file(): | ||
raise FileNotFoundError( | ||
f"Cannot find results in '{results_path}', " | ||
"perhaps the model needs to be run first." | ||
) | ||
|
||
basin_df = pd.read_feather(basin_path) | ||
flow_df = pd.read_feather(flow_path) | ||
|
||
edge_dim = uds.grid.edge_dimension | ||
node_dim = uds.grid.node_dimension | ||
|
||
# from node_id to the node_dim index | ||
node_lookup = pd.Series( | ||
index=uds["node_id"], | ||
data=uds[edge_dim], | ||
name="node_index", | ||
) | ||
# from edge_id to the edge_dim index | ||
edge_lookup = pd.Series( | ||
index=uds["edge_id"], | ||
data=uds[edge_dim], | ||
name="edge_index", | ||
) | ||
|
||
basin_df = pd.read_feather(basin_path) | ||
flow_df = pd.read_feather(flow_path) | ||
|
||
# datetime64[ms] gives trouble; https://github.com/pydata/xarray/issues/6318 | ||
flow_df["time"] = flow_df["time"].astype("datetime64[ns]") | ||
basin_df["time"] = basin_df["time"].astype("datetime64[ns]") | ||
|
||
# add flow results to the UgridDataset | ||
flow_df[edge_dim] = edge_lookup[flow_df["edge_id"]].to_numpy() | ||
flow_da = flow_df.set_index(["time", edge_dim])["flow_rate"].to_xarray() | ||
uds[flow_da.name] = flow_da | ||
|
||
# add basin results to the UgridDataset | ||
basin_df[node_dim] = node_lookup[basin_df["node_id"]].to_numpy() | ||
basin_df.drop(columns=["node_id"], inplace=True) | ||
basin_ds = basin_df.set_index(["time", node_dim]).to_xarray() | ||
|
||
for var_name, da in basin_ds.data_vars.items(): | ||
uds[var_name] = da | ||
|
||
return uds |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this, it feels a bit weird that we store
self.filepath
, but then not use it here.Should the
filepath
be maybe optional forwrite
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could make
filepath
optional forwrite
, though then we are encouraging overwriting models, and it would perhaps not be obvious to code readers where their model ended up.This feels ok to me since once an in-memory model is written to disk, it has a
filepath
. Otherwise that would only be set when reading a model withModel.read("path/to/ribasim.toml")
.