Skip to content

Commit

Permalink
chore[test_files.py]: Add test for transform and validate
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Vasenin committed Sep 23, 2020
1 parent aa1ce34 commit 9fb912a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
32 changes: 26 additions & 6 deletions ntc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CfgNode(UserDict):
_NEW_ALLOWED = "_new_allowed"

_BUILT_IN_ATTRS = [_SCHEMA_FROZEN, _FROZEN, _WAS_UNFROZEN, _LEAF_SPEC, _MODULE, _NEW_ALLOWED]
RESERVED_KEYS = [*_BUILT_IN_ATTRS, "data"]
RESERVED_KEYS = [*_BUILT_IN_ATTRS, "data", "_post_load", "transform", "_validate"]

def __init__(
self,
Expand Down Expand Up @@ -95,8 +95,8 @@ def __len__(self) -> int:
@staticmethod
def load(cfg_path: Union[Path, str]) -> CfgNode:
module = import_module(cfg_path)
cfg = module.cfg
cfg.validate()
cfg: CfgNode = module.cfg
cfg._post_load(cfg_path)
cfg.set_module(module)
cfg.freeze()

Expand All @@ -116,13 +116,33 @@ def save(self, path: Path) -> None:
def clone(self) -> CfgNode:
return CfgNode(base=self)

def _post_load(self, cfg_path: Union[Path, str]) -> None:
"""
Any actions to be done after loading
:param cfg_path: File from which config was loaded
"""
self._transform()
self.validate()

def _validate(self) -> None:
"""
Specify additional rules to check
"""

def _transform(self) -> None:
"""
Specify additional changes to be made after manual changes
"""

def validate(self) -> None:
self._validate()
for key, attr in self.attrs:
if isinstance(attr, CfgNode):
attr.validate()
else:
if attr.required and attr.value is None:
raise MissingRequired(f"Key {key} is required, but was not provided.")
continue
if attr.required and attr.value is None:
raise MissingRequired(f"Key {key} is required, but was not provided.")

def to_dict(self) -> Dict[str, Any]:
attrs = {}
Expand Down
14 changes: 14 additions & 0 deletions tests/data/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ntc import CN

from .base_cfg import cfg as bc

cfg = CN(bc)
cfg.NAME = "Name"


def transform(cfg: CN):
if cfg.DICT.FOO == "foo":
cfg.DICT.FOO = "bar"


cfg.transform = transform
13 changes: 13 additions & 0 deletions tests/data/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ntc import CN

from .base_cfg import cfg as bc

cfg = CN(bc)
cfg.NAME = "Name"


def validate(cfg: CN):
assert cfg.NAME != "Name"


cfg.validate = validate
10 changes: 10 additions & 0 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,13 @@ def test_bad_node_subclass():
def test_bad_node_instance():
with pytest.raises(SchemaError):
CN.load(DATA_DIR / "bad_node_instance.py")


def test_transform():
transform = CN.load(DATA_DIR / "transform.py")
assert transform.DICT.FOO == "bar"


def test_validate():
with pytest.raises(AssertionError):
CN.load(DATA_DIR / "validate.py")

0 comments on commit 9fb912a

Please sign in to comment.