Skip to content

Commit

Permalink
imp(node.py): Add pathlib.PosixPath to safe save types
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Vasenin committed Aug 10, 2021
1 parent d636ee7 commit 96b1c84
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
5 changes: 4 additions & 1 deletion ntc/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import logging
from collections import UserDict
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Any, Callable, Dict, List, Tuple, Union

import yaml
Expand Down Expand Up @@ -394,6 +394,9 @@ def _update_module(self, full_key: str, value) -> None:
lines.append(f"{full_key} = {value.__name__}\n")
elif type(value) in [int, float, str]:
lines.append(f"{full_key} = {value!r}\n")
elif type(value) == PosixPath:
lines.append("from pathlib import PosixPath\n")
lines.append(f"{full_key} = {value!r}\n")
elif isinstance(value, CfgSavable):
import_str, cls_name, args, kwargs = value.save_strs()
lines.append(f"{import_str}\n")
Expand Down
12 changes: 12 additions & 0 deletions tests/data/good/save_safe_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pathlib import Path

from ntc import CN
from tests.data.base_cfg import cfg
from tests.data.base_class import BaseClass

cfg = CN(cfg)
cfg.NEW.int = 1
cfg.NEW.str = "foo"
cfg.NEW.float = 3.14
cfg.NEW.path = Path("example")
cfg.NEW.type = BaseClass
14 changes: 12 additions & 2 deletions tests/test_good.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def test_new_allowed():
def test_save_modify(tmp_path):
cfg = CN.load(DATA_DIR / "good.py")
cfg.DICT.INT = 2
cfg.DICT.FOO = "bar"
cfg.CLASS = SavableClass(dt.date(2021, 1, 1))
cfg.SUBCLASS = SubClass

save_path = tmp_path / "good.py"
cfg.save(save_path)
Expand Down Expand Up @@ -133,3 +131,15 @@ def test_save_clone(tmp_path):
cfg.save(save_path)
cfg2 = CN.load(save_path)
assert cfg == cfg2


@pytest.mark.parametrize(
"name, value", {"int": 2, "str": "bar", "float": 2.71, "path": Path("another"), "type": SubClass}.items()
)
def test_save_safe_types(tmp_path, name, value):
cfg = CN.load(DATA_DIR / "save_safe_types.py")
setattr(cfg.NEW, name, value)
save_path = tmp_path / "save_safe_types.py"
cfg.save(save_path)
cfg2 = CN.load(save_path)
assert cfg == cfg2

0 comments on commit 96b1c84

Please sign in to comment.