Skip to content

Commit

Permalink
sqme: fix rdr_cache impl
Browse files Browse the repository at this point in the history
- use actual src_idx, not index for that tile
- use tokenize(RasterSource) as cache key
  • Loading branch information
Kirill888 committed Jun 28, 2024
1 parent ba32007 commit 1860eb3
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ def __bool__(self) -> bool:

def resolve_sources(
self, srcs: Sequence[MultiBandRasterSource]
) -> List[List[RasterSource]]:
out: List[List[RasterSource]] = []
) -> List[List[tuple[int, RasterSource]]]:
out: List[List[tuple[int, RasterSource]]] = []

for layer in self.srcs:
_srcs: List[RasterSource] = []
for idx in layer:
src = srcs[idx].get(self.band, None)
if src is not None:
_srcs.append(src)
_srcs.append((idx, src))
out.append(_srcs)
return out

Expand Down Expand Up @@ -263,7 +263,7 @@ def _task_futures(
dask_reader: DaskRasterReader,
layer_name: str,
dsk: dict[Key, Any],
rdr_cache: dict[int, DaskRasterReader] = {},
rdr_cache: dict[str, DaskRasterReader],
) -> list[list[Key]]:
# pylint: disable=too-many-locals
srcs = task.resolve_sources(self.srcs)
Expand All @@ -274,13 +274,14 @@ def _task_futures(

for i_time, layer in enumerate(srcs, start=task.idx[0]):
keys_out: list[Key] = []
for i_src, src in enumerate(layer):
idx = (i_time, *task.idx[1:], i_src)
for i_src, src in layer:
idx = (i_src, i_time, *task.idx[1:])

rdr = rdr_cache.get(i_src)
src_hash = tokenize(src)
rdr = rdr_cache.get(src_hash, None)
if rdr is None:
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
rdr_cache[i_src] = rdr
rdr_cache[src_hash] = rdr

fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
keys_out.append(fut.key)
Expand Down Expand Up @@ -346,7 +347,7 @@ def __call__(
cfg,
resolve_src_nodata(cfg.fill_value, cfg),
)
rdr_cache: dict[int, DaskRasterReader] = {}
rdr_cache: dict[str, DaskRasterReader] = {}

for task in self.load_tasks(name, shape[0]):
task_key: Key = (band_layer, *task.idx)
Expand Down Expand Up @@ -823,7 +824,7 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:

with rdr.restore_env(env, load_state) as ctx:
for t_idx, layer in enumerate(layers):
loaders = [rdr.open(src, ctx) for src in layer]
loaders = [rdr.open(src, ctx) for _, src in layer]
_ = _fill_nd_slice(
loaders,
task.dst_gbox,
Expand Down

0 comments on commit 1860eb3

Please sign in to comment.