diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index 2c4d207..054753f 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -121,15 +121,15 @@ def __bool__(self) -> bool: def resolve_sources( self, srcs: Sequence[MultiBandRasterSource] - ) -> List[List[RasterSource]]: - out: List[List[RasterSource]] = [] + ) -> List[List[tuple[int, RasterSource]]]: + out: List[List[tuple[int, RasterSource]]] = [] for layer in self.srcs: _srcs: List[RasterSource] = [] for idx in layer: src = srcs[idx].get(self.band, None) if src is not None: - _srcs.append(src) + _srcs.append((idx, src)) out.append(_srcs) return out @@ -263,7 +263,7 @@ def _task_futures( dask_reader: DaskRasterReader, layer_name: str, dsk: dict[Key, Any], - rdr_cache: dict[int, DaskRasterReader] = {}, + rdr_cache: dict[str, DaskRasterReader], ) -> list[list[Key]]: # pylint: disable=too-many-locals srcs = task.resolve_sources(self.srcs) @@ -274,13 +274,14 @@ def _task_futures( for i_time, layer in enumerate(srcs, start=task.idx[0]): keys_out: list[Key] = [] - for i_src, src in enumerate(layer): - idx = (i_time, *task.idx[1:], i_src) + for i_src, src in layer: + idx = (i_src, i_time, *task.idx[1:]) - rdr = rdr_cache.get(i_src) + src_hash = tokenize(src) + rdr = rdr_cache.get(src_hash, None) if rdr is None: rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src) - rdr_cache[i_src] = rdr + rdr_cache[src_hash] = rdr fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx) keys_out.append(fut.key) @@ -346,7 +347,7 @@ def __call__( cfg, resolve_src_nodata(cfg.fill_value, cfg), ) - rdr_cache: dict[int, DaskRasterReader] = {} + rdr_cache: dict[str, DaskRasterReader] = {} for task in self.load_tasks(name, shape[0]): task_key: Key = (band_layer, *task.idx) @@ -823,7 +824,7 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]: with rdr.restore_env(env, load_state) as ctx: for t_idx, layer in enumerate(layers): - loaders = [rdr.open(src, ctx) for src in layer] + loaders = [rdr.open(src, ctx) for _, src in layer] _ = _fill_nd_slice( loaders, task.dst_gbox,