Skip to content

Commit

Permalink
[KED-2214] Enable loading config files with Jinja2 templating (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitriiDeriabinQB committed Nov 17, 2020
1 parent 80e2202 commit c466c8a
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 70 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Deprecated `KedroContext.hooks`. Instead, hooks should be registered in `settings.py`.
* Made `context_path` an optional key in `pyproject.toml`. `KedroContext` is used by default.
* Removed `ProjectContext` from `src/<package_name>/run.py`.
* `TemplatedConfigLoader` now supports Jinja2 template syntax alongside its original one.

## Bug fixes and other changes
* Bumped maximum required `fsspec` version to 0.9.
Expand Down
45 changes: 45 additions & 0 deletions docs/source/04_kedro_project_setup/02_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ raw_car_data:

> Note: `TemplatedConfigLoader` uses the `jmespath` package in the background to extract elements from the global dictionary. For more information about JMESPath syntax, please see: https://github.com/jmespath/jmespath.py.

### Jinja2 support

From version 0.17.0, `TemplatedConfigLoader` also supports the [Jinja2](https://palletsprojects.com/p/jinja/) template engine alongside the original template syntax. Below is an example of a `catalog.yml` file that uses both features:

```
{% for speed in ['fast', 'slow'] %}
{{ speed }}-trains:
    type: MemoryDataSet
{{ speed }}-cars:
    type: pandas.CSVDataSet
    filepath: s3://${bucket_name}/{{ speed }}-cars.csv
    save_args:
        index: true
{% endfor %}
```

When parsing this configuration file, `TemplatedConfigLoader` will:

1. Read the `catalog.yml` and compile it using Jinja2
2. Use YAML parser to parse the compiled config into a Python dictionary
3. Expand `${bucket_name}` in `filepath` using the `globals_*` arguments for the `TemplatedConfigLoader` instance as in the previous examples

The output Python dictionary will look as follows:

```python
{
"fast-trains": {"type": "MemoryDataSet"},
"fast-cars": {
"type": "pandas.CSVDataSet",
"filepath": "s3://my_s3_bucket/fast-cars.csv",
"save_args": {"index": True},
},
"slow-trains": {"type": "MemoryDataSet"},
"slow-cars": {
"type": "pandas.CSVDataSet",
"filepath": "s3://my_s3_bucket/slow-cars.csv",
"save_args": {"index": True},
},
}
```

> Note: Although Jinja2 is a very powerful and extremely flexible template engine, which comes with a wide range of features, we do _not_ recommend using it to template your configuration unless absolutely necessary. The flexibility of dynamic configuration comes at the cost of significantly reduced readability and much higher maintenance overhead. We believe that, for the majority of analytics projects, dynamically compiled configuration does more harm than good.


## Parameters

Expand Down
152 changes: 90 additions & 62 deletions kedro/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,66 @@ def __init__(self, conf_paths: Union[str, Iterable[str]]):
self.conf_paths = _remove_duplicates(conf_paths)
self.logger = logging.getLogger(__name__)

@staticmethod
def _load_config_file(config_file: Path) -> Dict[str, Any]:
    """Parse a single configuration file with ``anyconfig``.

    Top-level keys whose name starts with an underscore are left out
    of the returned dictionary.

    Args:
        config_file: Path to the config file to parse.

    Returns:
        Parsed configuration without underscore-prefixed top-level keys.
    """
    # imported lazily for performance reasons
    import anyconfig  # pylint: disable=import-outside-toplevel

    parsed = anyconfig.load(config_file)

    public_config = {}
    for key, value in parsed.items():
        if key.startswith("_"):
            continue
        public_config[key] = value
    return public_config

def _load_configs(self, config_filepaths: List[Path]) -> Dict[str, Any]:
    """Load the given configuration files and merge them into one dictionary.

    Args:
        config_filepaths: Configuration files sorted in the order of precedence.

    Raises:
        ValueError: If 2 or more configuration files contain the same key(s).

    Returns:
        Resulting configuration dictionary.
    """
    merged = {}
    keys_per_file = {}  # type: Dict[Path, AbstractSet[str]]

    for filepath in config_filepaths:
        loaded = self._load_config_file(filepath)
        # Compare against every file loaded so far before recording this one.
        _check_duplicate_keys(keys_per_file, filepath, loaded)
        keys_per_file[filepath] = loaded.keys()
        merged.update(loaded)

    return merged

def _lookup_config_filepaths(
    self, conf_path: Path, patterns: Iterable[str], processed_files: Set[Path]
) -> List[Path]:
    """Find config files under ``conf_path`` that have not been processed yet.

    Files matching ``patterns`` that are already in ``processed_files`` are
    reported with a warning and excluded from the result.

    Args:
        conf_path: Path to the configuration directory to search.
        patterns: Glob patterns to match the filenames against.
        processed_files: Config files that have already been loaded.

    Returns:
        Sorted list of the matching config files left to load.
    """
    matched = _path_lookup(conf_path, patterns)
    already_processed = matched & processed_files

    if not already_processed:
        return sorted(matched)

    skipped = ", ".join(str(path) for path in sorted(already_processed))
    self.logger.warning(
        "Config file(s): %s already processed, skipping loading...", skipped
    )
    return sorted(matched - already_processed)

def get(self, *patterns: str) -> Dict[str, Any]:
"""Recursively scan for configuration files, load and merge them, and
return them in the form of a config dictionary.
Expand Down Expand Up @@ -156,19 +216,14 @@ def get(self, *patterns: str) -> Dict[str, Any]:
for conf_path in self.conf_paths:
if not Path(conf_path).is_dir():
raise ValueError(
"Given configuration path either does not exist "
"or is not a valid directory: {0}".format(conf_path)
f"Given configuration path either does not exist "
f"or is not a valid directory: {conf_path}"
)

config_files = _path_lookup(Path(conf_path), patterns)
seen_files = set(config_files) & processed_files
if seen_files:
self.logger.warning(
"Config file(s): %s already processed, skipping loading...",
",".join(str(seen) for seen in sorted(seen_files)),
)
config_files = [cf for cf in config_files if cf not in seen_files]
new_conf = _load_config(config_files)
config_filepaths = self._lookup_config_filepaths(
Path(conf_path), patterns, processed_files
)
new_conf = self._load_configs(config_filepaths)

common_keys = config.keys() & new_conf.keys()
if common_keys:
Expand All @@ -180,73 +235,45 @@ def get(self, *patterns: str) -> Dict[str, Any]:
self.logger.info(msg, conf_path, sorted_keys)

config.update(new_conf)
processed_files |= set(config_files)
processed_files |= set(config_filepaths)

if not processed_files:
raise MissingConfigException(
"No files found in {} matching the glob "
"pattern(s): {}".format(str(self.conf_paths), str(list(patterns)))
f"No files found in {self.conf_paths} matching the glob "
f"pattern(s): {list(patterns)}"
)
return config


def _load_config(config_files: List[Path]) -> Dict[str, Any]:
"""Recursively load all configuration files, which satisfy
a given list of glob patterns from a specific path.
Args:
config_files: Configuration files sorted in the order of precedence.
Raises:
ValueError: If 2 or more configuration files contain the same key(s).
def _check_duplicate_keys(
processed_files: Dict[Path, AbstractSet[str]], filepath: Path, conf: Dict[str, Any]
) -> None:
duplicates = []

Returns:
Resulting configuration dictionary.
for processed_file, keys in processed_files.items():
overlapping_keys = conf.keys() & keys

"""
# for performance reasons
import anyconfig # pylint: disable=import-outside-toplevel

config = {}
keys_by_filepath = {} # type: Dict[Path, AbstractSet[str]]

def _check_dups(file1: Path, conf: Dict[str, Any]) -> None:
dups = set()
for file2, keys in keys_by_filepath.items():
common = ", ".join(sorted(conf.keys() & keys))
if common:
if len(common) > 100:
common = common[:100] + "..."
dups.add("{}: {}".format(str(file2), common))

if dups:
msg = "Duplicate keys found in {0} and:\n- {1}".format(
file1, "\n- ".join(dups)
)
raise ValueError(msg)
if overlapping_keys:
sorted_keys = ", ".join(sorted(overlapping_keys))
if len(sorted_keys) > 100:
sorted_keys = sorted_keys[:100] + "..."
duplicates.append(f"{processed_file}: {sorted_keys}")

for config_file in config_files:
cfg = {
k: v
for k, v in anyconfig.load(config_file).items()
if not k.startswith("_")
}
_check_dups(config_file, cfg)
keys_by_filepath[config_file] = cfg.keys()
config.update(cfg)
return config
if duplicates:
dup_str = "\n- ".join(duplicates)
raise ValueError(f"Duplicate keys found in {filepath} and:\n- {dup_str}")


def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> List[Path]:
"""Return a sorted list of all configuration files from ``conf_path`` or
def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> Set[Path]:
"""Return a set of all configuration files from ``conf_path`` or
its subdirectories, which satisfy a given list of glob patterns.
Args:
conf_path: Path to configuration directory.
patterns: List of glob patterns to match the filenames against.
Returns:
Sorted list of ``Path`` objects representing configuration files.
A set of paths to configuration files.
"""
config_files = set()
Expand All @@ -259,7 +286,8 @@ def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> List[Path]:
path = Path(each).resolve()
if path.is_file() and path.suffix in SUPPORTED_EXTENSIONS:
config_files.add(path)
return sorted(config_files)

return config_files


def _remove_duplicates(items: Iterable[str]):
Expand All @@ -270,7 +298,7 @@ def _remove_duplicates(items: Iterable[str]):
unique_items.append(item)
else:
warn(
"Duplicate environment detected! "
"Skipping re-loading from configuration path: {}".format(item)
f"Duplicate environment detected! "
f"Skipping re-loading from configuration path: {item}"
)
return unique_items
20 changes: 20 additions & 0 deletions kedro/config/templated_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"""
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union

import jmespath
Expand Down Expand Up @@ -145,6 +146,25 @@ def __init__(
globals_dict = deepcopy(globals_dict) or {}
self._arg_dict = {**self._arg_dict, **globals_dict}

@staticmethod
def _load_config_file(config_file: Path) -> Dict[str, Any]:
    """Parse a single configuration file with ``anyconfig``, passing
    ``ac_template=True`` so that the backend's template support is enabled.

    Top-level keys whose name starts with an underscore are left out
    of the returned dictionary.

    Args:
        config_file: Path to the config file to parse.

    Returns:
        Parsed configuration without underscore-prefixed top-level keys.
    """
    # imported lazily for performance reasons
    import anyconfig  # pylint: disable=import-outside-toplevel

    rendered = anyconfig.load(config_file, ac_template=True)

    public_config = {}
    for key, value in rendered.items():
        if not key.startswith("_"):
            public_config[key] = value
    return public_config

def get(self, *patterns: str) -> Dict[str, Any]:
"""Tries to resolve the template variables in the config dictionary
provided by the ``ConfigLoader`` (super class) ``get`` method using the
Expand Down
Loading

0 comments on commit c466c8a

Please sign in to comment.