From c466c8aa36488c8fa8d1fe6bb3d5bcde81100e4a Mon Sep 17 00:00:00 2001 From: Dmitrii Deriabin <44967953+DmitriiDeriabinQB@users.noreply.github.com> Date: Tue, 17 Nov 2020 18:26:27 +0000 Subject: [PATCH] [KED-2214] Enable loading config files with Jinja2 templating (#867) --- RELEASE.md | 1 + .../02_configuration.md | 45 ++++++ kedro/config/config.py | 152 +++++++++++------- kedro/config/templated_config.py | 20 +++ tests/config/test_templated_config.py | 62 ++++++- 5 files changed, 210 insertions(+), 70 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 6a8522e238..fd2fda5746 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,7 @@ * Deprecated `KedroContext.hooks`. Instead, hooks should be registered in `settings.py`. * Made `context_path` an optional key in `pyproject.toml`. `KedroContext` is used by default. * Removed `ProjectContext` from `src//run.py`. +* `TemplatedConfigLoader` now supports Jinja2 template syntax alongside its original one. ## Bug fixes and other changes * Bumped maximum required `fsspec` version to 0.9. diff --git a/docs/source/04_kedro_project_setup/02_configuration.md b/docs/source/04_kedro_project_setup/02_configuration.md index 5fbae6c034..48d9c2590f 100644 --- a/docs/source/04_kedro_project_setup/02_configuration.md +++ b/docs/source/04_kedro_project_setup/02_configuration.md @@ -136,6 +136,51 @@ raw_car_data: > Note: `TemplatedConfigLoader` uses `jmespath` package in the background to extract elements from global dictionary. For more information about JMESPath syntax please see: https://github.com/jmespath/jmespath.py. +### Jinja2 support + +From version 0.17.0 `TemplateConfigLoader` also supports [Jinja2](https://palletsprojects.com/p/jinja/) template engine alongside the original template syntax. Below is the example of a `catalog.yml` file, which uses both features: + +``` +{% for speed in ['fast', 'slow'] %} +{{ speed }}-trains: + type: MemoryDataSet + +{{ speed }}-cars: + type: pandas.CSVDataSet + filepath: s3://${bucket_name}/{{ speed }}-cars.csv + save_args: + index: true + +{% endfor %} +``` + +When parsing this configuration file, `TemplateConfigLoader` will: + +1. Read the `catalog.yml` and compile it using Jinja2 +2. Use YAML parser to parse the compiled config into a Python dictionary +3. Expand `${bucket_name}` in `filepath` using the `globals_*` arguments for the `TemplateConfigLoader` instance as in the previous examples + +The output Python dictionary will look as follows: + +```python +{ + "fast-trains": {"type": "MemoryDataSet"}, + "fast-cars": { + "type": "pandas.CSVDataSet", + "filepath": "s3://my_s3_bucket/fast-cars.csv", + "save_args": {"index": True}, + }, + "slow-trains": {"type": "MemoryDataSet"}, + "slow-cars": { + "type": "pandas.CSVDataSet", + "filepath": "s3://my_s3_bucket/slow-cars.csv", + "save_args": {"index": True}, + }, +} +``` + +> Note: Although Jinja2 is a very powerful and extremely flexible template engine, which comes with a wide range of features, we do _not_ recommend to use it to template your configuration unless absolutely necessary. The flexibility of dynamic configuration comes at a cost of significantly reduced readability and much higher maintenance overhead. We believe that, for the majority of analytics projects, dynamically compiled configuration does more harm than good. + ## Parameters diff --git a/kedro/config/config.py b/kedro/config/config.py index fbb3aae6fc..175f0a2200 100644 --- a/kedro/config/config.py +++ b/kedro/config/config.py @@ -123,6 +123,66 @@ def __init__(self, conf_paths: Union[str, Iterable[str]]): self.conf_paths = _remove_duplicates(conf_paths) self.logger = logging.getLogger(__name__) + @staticmethod + def _load_config_file(config_file: Path) -> Dict[str, Any]: + """Load an individual config file using `anyconfig` as a backend. + + Args: + config_file: Path to a config file to process. + + Returns: + Parsed configuration. + """ + # for performance reasons + import anyconfig # pylint: disable=import-outside-toplevel + + return { + k: v + for k, v in anyconfig.load(config_file).items() + if not k.startswith("_") + } + + def _load_configs(self, config_filepaths: List[Path]) -> Dict[str, Any]: + """Recursively load all configuration files, which satisfy + a given list of glob patterns from a specific path. + + Args: + config_filepaths: Configuration files sorted in the order of precedence. + + Raises: + ValueError: If 2 or more configuration files contain the same key(s). + + Returns: + Resulting configuration dictionary. + + """ + + aggregate_config = {} + seen_file_to_keys = {} # type: Dict[Path, AbstractSet[str]] + + for config_filepath in config_filepaths: + single_config = self._load_config_file(config_filepath) + _check_duplicate_keys(seen_file_to_keys, config_filepath, single_config) + seen_file_to_keys[config_filepath] = single_config.keys() + aggregate_config.update(single_config) + + return aggregate_config + + def _lookup_config_filepaths( + self, conf_path: Path, patterns: Iterable[str], processed_files: Set[Path] + ) -> List[Path]: + config_files = _path_lookup(conf_path, patterns) + + seen_files = config_files & processed_files + if seen_files: + self.logger.warning( + "Config file(s): %s already processed, skipping loading...", + ", ".join(str(seen) for seen in sorted(seen_files)), + ) + config_files -= seen_files + + return sorted(config_files) + def get(self, *patterns: str) -> Dict[str, Any]: """Recursively scan for configuration files, load and merge them, and return them in the form of a config dictionary. @@ -156,19 +216,14 @@ def get(self, *patterns: str) -> Dict[str, Any]: for conf_path in self.conf_paths: if not Path(conf_path).is_dir(): raise ValueError( - "Given configuration path either does not exist " - "or is not a valid directory: {0}".format(conf_path) + f"Given configuration path either does not exist " + f"or is not a valid directory: {conf_path}" ) - config_files = _path_lookup(Path(conf_path), patterns) - seen_files = set(config_files) & processed_files - if seen_files: - self.logger.warning( - "Config file(s): %s already processed, skipping loading...", - ",".join(str(seen) for seen in sorted(seen_files)), - ) - config_files = [cf for cf in config_files if cf not in seen_files] - new_conf = _load_config(config_files) + config_filepaths = self._lookup_config_filepaths( + Path(conf_path), patterns, processed_files + ) + new_conf = self._load_configs(config_filepaths) common_keys = config.keys() & new_conf.keys() if common_keys: @@ -180,65 +235,37 @@ def get(self, *patterns: str) -> Dict[str, Any]: self.logger.info(msg, conf_path, sorted_keys) config.update(new_conf) - processed_files |= set(config_files) + processed_files |= set(config_filepaths) if not processed_files: raise MissingConfigException( - "No files found in {} matching the glob " - "pattern(s): {}".format(str(self.conf_paths), str(list(patterns))) + f"No files found in {self.conf_paths} matching the glob " + f"pattern(s): {list(patterns)}" ) return config -def _load_config(config_files: List[Path]) -> Dict[str, Any]: - """Recursively load all configuration files, which satisfy - a given list of glob patterns from a specific path. - - Args: - config_files: Configuration files sorted in the order of precedence. - - Raises: - ValueError: If 2 or more configuration files contain the same key(s). +def _check_duplicate_keys( + processed_files: Dict[Path, AbstractSet[str]], filepath: Path, conf: Dict[str, Any] +) -> None: + duplicates = [] - Returns: - Resulting configuration dictionary. + for processed_file, keys in processed_files.items(): + overlapping_keys = conf.keys() & keys - """ - # for performance reasons - import anyconfig # pylint: disable=import-outside-toplevel - - config = {} - keys_by_filepath = {} # type: Dict[Path, AbstractSet[str]] - - def _check_dups(file1: Path, conf: Dict[str, Any]) -> None: - dups = set() - for file2, keys in keys_by_filepath.items(): - common = ", ".join(sorted(conf.keys() & keys)) - if common: - if len(common) > 100: - common = common[:100] + "..." - dups.add("{}: {}".format(str(file2), common)) - - if dups: - msg = "Duplicate keys found in {0} and:\n- {1}".format( - file1, "\n- ".join(dups) - ) - raise ValueError(msg) + if overlapping_keys: + sorted_keys = ", ".join(sorted(overlapping_keys)) + if len(sorted_keys) > 100: + sorted_keys = sorted_keys[:100] + "..." + duplicates.append(f"{processed_file}: {sorted_keys}") - for config_file in config_files: - cfg = { - k: v - for k, v in anyconfig.load(config_file).items() - if not k.startswith("_") - } - _check_dups(config_file, cfg) - keys_by_filepath[config_file] = cfg.keys() - config.update(cfg) - return config + if duplicates: + dup_str = "\n- ".join(duplicates) + raise ValueError(f"Duplicate keys found in {filepath} and:\n- {dup_str}") -def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> List[Path]: - """Return a sorted list of all configuration files from ``conf_path`` or +def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> Set[Path]: + """Return a set of all configuration files from ``conf_path`` or its subdirectories, which satisfy a given list of glob patterns. Args: @@ -246,7 +273,7 @@ def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> List[Path]: patterns: List of glob patterns to match the filenames against. Returns: - Sorted list of ``Path`` objects representing configuration files. + A set of paths to configuration files. """ config_files = set() @@ -259,7 +286,8 @@ def _path_lookup(conf_path: Path, patterns: Iterable[str]) -> List[Path]: path = Path(each).resolve() if path.is_file() and path.suffix in SUPPORTED_EXTENSIONS: config_files.add(path) - return sorted(config_files) + + return config_files def _remove_duplicates(items: Iterable[str]): @@ -270,7 +298,7 @@ def _remove_duplicates(items: Iterable[str]): unique_items.append(item) else: warn( - "Duplicate environment detected! " - "Skipping re-loading from configuration path: {}".format(item) + f"Duplicate environment detected! " + f"Skipping re-loading from configuration path: {item}" ) return unique_items diff --git a/kedro/config/templated_config.py b/kedro/config/templated_config.py index bd2b4ba684..78ceb98a71 100644 --- a/kedro/config/templated_config.py +++ b/kedro/config/templated_config.py @@ -31,6 +31,7 @@ """ import re from copy import deepcopy +from pathlib import Path from typing import Any, Dict, Iterable, Optional, Union import jmespath @@ -145,6 +146,25 @@ def __init__( globals_dict = deepcopy(globals_dict) or {} self._arg_dict = {**self._arg_dict, **globals_dict} + @staticmethod + def _load_config_file(config_file: Path) -> Dict[str, Any]: + """Load an individual config file using `anyconfig` as a backend. + + Args: + config_file: Path to a config file to process. + + Returns: + Parsed configuration. + """ + # for performance reasons + import anyconfig # pylint: disable=import-outside-toplevel + + return { + k: v + for k, v in anyconfig.load(config_file, ac_template=True).items() + if not k.startswith("_") + } + def get(self, *patterns: str) -> Dict[str, Any]: """Tries to resolve the template variables in the config dictionary provided by the ``ConfigLoader`` (super class) ``get`` method using the diff --git a/tests/config/test_templated_config.py b/tests/config/test_templated_config.py index fc2ead65a1..a7529a1c10 100644 --- a/tests/config/test_templated_config.py +++ b/tests/config/test_templated_config.py @@ -77,6 +77,28 @@ def template_config(): } +@pytest.fixture +def catalog_with_jinja2_syntax(tmp_path): + filepath = tmp_path / "base" / "catalog.yml" + + catalog = """ +{% for speed in ['fast', 'slow'] %} +{{ speed }}-trains: + type: MemoryDataSet + +{{ speed }}-cars: + type: pandas.CSVDataSet + filepath: ${s3_bucket}/{{ speed }}-cars.csv + save_args: + index: true + +{% endfor %} +""" + + filepath.parent.mkdir(parents=True, exist_ok=True) + filepath.write_text(catalog) + + @pytest.fixture def proj_catalog_param(tmp_path, param_config): proj_catalog = tmp_path / "base" / "catalog.yml" @@ -205,7 +227,7 @@ def proj_catalog_param_w_vals_exceptional(tmp_path, param_config_exceptional): class TestTemplatedConfigLoader: @pytest.mark.usefixtures("proj_catalog_param") - def test_catlog_parameterized_w_dict(self, tmp_path, conf_paths, template_config): + def test_catalog_parameterized_w_dict(self, tmp_path, conf_paths, template_config): """Test parameterized config with input from dictionary with values""" (tmp_path / "local").mkdir(exist_ok=True) @@ -223,7 +245,7 @@ def test_catlog_parameterized_w_dict(self, tmp_path, conf_paths, template_config assert catalog["boats"]["users"] == ["fred", "ron"] @pytest.mark.usefixtures("proj_catalog_param", "proj_catalog_globals") - def test_catlog_parameterized_w_globals(self, tmp_path, conf_paths): + def test_catalog_parameterized_w_globals(self, tmp_path, conf_paths): """Test parameterized config with globals yaml file""" (tmp_path / "local").mkdir(exist_ok=True) @@ -241,7 +263,7 @@ def test_catlog_parameterized_w_globals(self, tmp_path, conf_paths): assert catalog["boats"]["users"] == ["fred", "ron"] @pytest.mark.usefixtures("proj_catalog_param") - def test_catlog_parameterized_no_params(self, tmp_path, conf_paths): + def test_catalog_parameterized_no_params(self, tmp_path, conf_paths): """Test parameterized config without input""" (tmp_path / "local").mkdir(exist_ok=True) @@ -258,7 +280,7 @@ def test_catlog_parameterized_no_params(self, tmp_path, conf_paths): assert catalog["boats"]["users"] == ["fred", "${write_only_user}"] @pytest.mark.usefixtures("proj_catalog_advanced") - def test_catlog_advanced(self, tmp_path, conf_paths, normal_config_advanced): + def test_catalog_advanced(self, tmp_path, conf_paths, normal_config_advanced): """Test whether it responds well to advanced yaml values (i.e. nested dicts, booleans, lists, etc.)""" (tmp_path / "local").mkdir(exist_ok=True) @@ -275,7 +297,7 @@ def test_catlog_advanced(self, tmp_path, conf_paths, normal_config_advanced): assert catalog["planes"]["secret_tables"] == ["models", "pilots", "engines"] @pytest.mark.usefixtures("proj_catalog_param_w_vals_advanced") - def test_catlog_parameterized_advanced( + def test_catalog_parameterized_advanced( self, tmp_path, conf_paths, template_config_advanced ): """Test advanced templating (i.e. nested dicts, booleans, lists, etc.)""" @@ -293,7 +315,9 @@ def test_catlog_parameterized_advanced( assert catalog["planes"]["secret_tables"] == ["models", "pilots", "engines"] @pytest.mark.usefixtures("proj_catalog_param_mixed", "proj_catalog_globals") - def test_catlog_parameterized_w_dict_mixed(self, tmp_path, conf_paths, get_environ): + def test_catalog_parameterized_w_dict_mixed( + self, tmp_path, conf_paths, get_environ + ): """Test parameterized config with input from dictionary with values and globals.yml""" (tmp_path / "local").mkdir(exist_ok=True) @@ -312,7 +336,7 @@ def test_catlog_parameterized_w_dict_mixed(self, tmp_path, conf_paths, get_envir assert catalog["boats"]["users"] == ["fred", "ron"] @pytest.mark.usefixtures("proj_catalog_param_namespaced") - def test_catlog_parameterized_w_dict_namespaced( + def test_catalog_parameterized_w_dict_namespaced( self, tmp_path, conf_paths, template_config, get_environ ): """Test parameterized config with namespacing in the template values""" @@ -332,7 +356,7 @@ def test_catlog_parameterized_w_dict_namespaced( assert catalog["boats"]["users"] == ["fred", "ron"] @pytest.mark.usefixtures("proj_catalog_param_w_vals_exceptional") - def test_catlog_parameterized_exceptional( + def test_catalog_parameterized_exceptional( self, tmp_path, conf_paths, template_config_exceptional ): """Test templating with mixed type replacement values going into one string""" @@ -344,6 +368,28 @@ def test_catlog_parameterized_exceptional( assert catalog["postcode"] == "NW10 2JK" + @pytest.mark.usefixtures("catalog_with_jinja2_syntax") + def test_catalog_with_jinja2_syntax(self, tmp_path, conf_paths, template_config): + (tmp_path / "local").mkdir(exist_ok=True) + catalog = TemplatedConfigLoader(conf_paths, globals_dict=template_config).get( + "catalog*.yml" + ) + expected_catalog = { + "fast-trains": {"type": "MemoryDataSet"}, + "fast-cars": { + "type": "pandas.CSVDataSet", + "filepath": "s3a://boat-and-car-bucket/fast-cars.csv", + "save_args": {"index": True}, + }, + "slow-trains": {"type": "MemoryDataSet"}, + "slow-cars": { + "type": "pandas.CSVDataSet", + "filepath": "s3a://boat-and-car-bucket/slow-cars.csv", + "save_args": {"index": True}, + }, + } + assert catalog == expected_catalog + class TestFormatObject: @pytest.mark.parametrize(