Skip to content

Commit

Permalink
fix(ParticleData): support partlocs as ndarray or list of lists (#1752)
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Mar 23, 2023
1 parent c307420 commit 8d52ece
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
51 changes: 51 additions & 0 deletions autotest/test_particledata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np

from flopy.modpath import ParticleData

structured_plocs = [(1, 1, 1), (1, 1, 2)]
structured_dtype = np.dtype(
[
("k", "<i4"),
("i", "<i4"),
("j", "<i4"),
("localx", "<f4"),
("localy", "<f4"),
("localz", "<f4"),
("timeoffset", "<f4"),
("drape", "<i4"),
]
)
structured_array = np.core.records.fromrecords(
[
(1, 1, 1, 0.5, 0.5, 0.5, 0.0, 0),
(1, 1, 2, 0.5, 0.5, 0.5, 0.0, 0),
],
dtype=structured_dtype,
)


def test_particledata_structured_partlocs_as_list_of_tuples():
locs = structured_plocs
data = ParticleData(partlocs=locs, structured=True)

assert data.particlecount == 2
assert data.dtype == structured_dtype
assert np.array_equal(data.particledata, structured_array)


def test_particledata_structured_partlocs_as_ndarray():
locs = np.array(structured_plocs)
data = ParticleData(partlocs=locs, structured=True)

assert data.particlecount == 2
assert data.dtype == structured_dtype
assert np.array_equal(data.particledata, structured_array)


def test_particledata_structured_partlocs_as_list_of_lists():
locs = [list(p) for p in structured_plocs]
data = ParticleData(partlocs=locs, structured=True)

assert data.particlecount == 2
assert data.dtype == structured_dtype
assert np.array_equal(data.particledata, structured_array)
11 changes: 7 additions & 4 deletions flopy/modpath/mp7particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import numpy as np
from numpy.lib.recfunctions import unstructured_to_structured

from ..utils.recarray_utils import create_empty_recarray

Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
alllen3 = all(len(el) == 3 for el in partlocs)
if not alllen3:
raise ValueError(
"{}: all partlocs entries must have 3 items for "
"{}: all partlocs entries must have 3 items for "
"structured particle data".format(self.name)
)
else:
Expand Down Expand Up @@ -164,14 +165,16 @@ def __init__(

# convert partlocs composed of a lists/tuples of lists/tuples
# to a numpy array
partlocs = np.array(partlocs, dtype=dtype)
partlocs = unstructured_to_structured(
np.array(partlocs), dtype=dtype
)
elif isinstance(partlocs, np.ndarray):
dtypein = partlocs.dtype
if dtypein != dtype:
partlocs = np.array(partlocs, dtype=dtype)
partlocs = unstructured_to_structured(partlocs, dtype=dtype)
else:
raise ValueError(
f"{self.name}: partlocs must be a list or tuple with lists or tuples"
f"{self.name}: partlocs must be a list or tuple with lists or tuples, or an ndarray"
)

# localx
Expand Down

0 comments on commit 8d52ece

Please sign in to comment.