Skip to content

Commit

Permalink
check was_unfrozen; separate errors module; import resolve
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladimir Mikhaylov committed Sep 18, 2020
1 parent e74a36e commit cb885d4
Show file tree
Hide file tree
Showing 15 changed files with 254 additions and 129 deletions.
1 change: 1 addition & 0 deletions ntc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .config import *
from .errors import *

__version__ = "0.1.0"
105 changes: 50 additions & 55 deletions ntc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,62 +5,44 @@

import yaml

from ntc.utils import import_module, merge_module


class ConfigError(Exception):
pass


class TypeMismatch(ConfigError):
pass


class NodeReassignment(ConfigError):
pass


class ModuleError(ConfigError):
pass


class SchemaError(ConfigError):
pass


class ValidationError(ConfigError):
pass


class SpecError(ConfigError):
pass


class NodeFrozenError(ConfigError):
pass
from ntc.errors import (
NodeFrozenError,
NodeReassignment,
SaveError,
SchemaError,
SchemaFrozenError,
SpecError,
TypeMismatch,
ValidationError,
)
from ntc.utils import add_yaml_representer, import_module, merge_cfg_module


class CfgNode:
_SCHEMA_FROZEN = "_schema_frozen"
_FROZEN = "_frozen"
_WAS_UNFROZEN = "_was_unfrozen"
_LEAF_SPEC = "_leaf_spec"
_MODULE = "_module"

def __init__(self, base: CfgNode = None, leaf_spec: Union[CfgLeaf, _CfgLeafSpec] = None,
schema_frozen: bool = False, frozen: bool = False):
def __init__(
self,
base: CfgNode = None,
leaf_spec: Union[CfgLeaf, _CfgLeafSpec] = None,
schema_frozen: bool = False,
frozen: bool = False,
):
# TODO: make access to attributes prettier
super().__setattr__(CfgNode._SCHEMA_FROZEN, schema_frozen)
super().__setattr__(CfgNode._FROZEN, frozen)
if leaf_spec and isinstance(leaf_spec, CfgLeaf):
leaf_spec = _CfgLeafSpec.from_leaf(leaf_spec)
super().__setattr__(CfgNode._LEAF_SPEC, leaf_spec)
super().__setattr__(CfgNode._MODULE, None)
super().__setattr__(CfgNode._WAS_UNFROZEN, False)

if base is not None:
super().__setattr__(CfgNode._LEAF_SPEC, base.leaf_spec())
self._set_attrs(base.attrs())
self.freeze_schema()
super().__setattr__(CfgNode._FROZEN, base.is_frozen())
self._init_with_base(base)

def __setattr__(self, key: str, value: Any) -> None:
if self.is_frozen():
Expand All @@ -77,28 +59,36 @@ def __getattribute__(self, item: str) -> Any:
return attr

def __eq__(self, other: CfgNode) -> bool:
print(self.to_dict())
print(other.to_dict())
return self.to_dict() == other.to_dict()

# def __str__(self) -> str:
# # TODO: handle custom class objects dump
# attrs = self.to_dict()
# return yaml.dump(attrs)
def __str__(self) -> str:
add_yaml_representer()
attrs = self.to_dict()
return yaml.dump(attrs)

def __len__(self):
return len(self.attrs())

def load(self, cfg_path: Path) -> CfgNode:
@staticmethod
def load(cfg_path: Union[Path, str]) -> CfgNode:
module = import_module(cfg_path)
cfg = module.cfg
cfg.validate()
cfg.set_module(module)
cfg.freeze()

return cfg

@staticmethod
def merge_module(module_path: Path, output_path: Path):
merge_cfg_module(module_path, output_path)

def save(self, path: Path) -> None:
# TODO: implement
merge_module(self.get_module(), path)
if self.was_unfrozen():
raise SaveError("Trying to save config which was unfrozen.")
if not self.get_module():
raise SaveError("Config was not loaded.")
self.merge_module(self.get_module(), path)

def clone(self) -> CfgNode:
return CfgNode(base=self)
Expand Down Expand Up @@ -149,6 +139,7 @@ def freeze(self):

def unfreeze(self):
super().__setattr__(CfgNode._FROZEN, False)
super().__setattr__(CfgNode._WAS_UNFROZEN, True)
for key, attr in self.attrs():
if isinstance(attr, CfgNode):
attr.unfreeze()
Expand All @@ -165,6 +156,10 @@ def leaf_spec(self):
# TODO: try to make it property
return super().__getattribute__(CfgNode._LEAF_SPEC)

def was_unfrozen(self):
# TODO: try to make it property
return super().__getattribute__(CfgNode._WAS_UNFROZEN)

def set_module(self, module):
super().__setattr__(CfgNode._MODULE, module)

Expand All @@ -181,7 +176,7 @@ def _set_new_attr(self, key: str, value: Any) -> None:
if leaf_spec:
raise SchemaError(f"Key {key} cannot contain nested nodes as leaf spec is defined for it.")
if self.is_schema_frozen():
raise SchemaError(f"Trying to add node {key}, but schema is frozen.")
raise SchemaFrozenError(f"Trying to add node {key}, but schema is frozen.")
value_to_set = value
elif isinstance(value, CfgLeaf):
value_to_set = value
Expand All @@ -195,7 +190,7 @@ def _set_new_attr(self, key: str, value: Any) -> None:
raise SchemaError(f"Leaf at key {key} mismatches config node's leaf spec.")
else:
if self.is_schema_frozen():
raise SchemaError(f"Trying to add leaf {key} to frozen node without leaf spec.")
raise SchemaFrozenError(f"Trying to add leaf {key} to frozen node without leaf spec.")
super().__setattr__(key, value_to_set)

def _set_existing_attr(self, key: str, value: Any) -> None:
Expand All @@ -212,6 +207,11 @@ def _set_existing_attr(self, key: str, value: Any) -> None:
)
super().__setattr__(key, value)

def _init_with_base(self, base: CfgNode):
super().__setattr__(CfgNode._LEAF_SPEC, base.leaf_spec())
self._set_attrs(base.attrs())
self.freeze_schema()


class CfgLeaf:
def __init__(self, value: Any, type_: Type, required: bool = False):
Expand Down Expand Up @@ -254,9 +254,4 @@ def from_leaf(cfg_leaf: CfgLeaf) -> _CfgLeafSpec:
"CL",
"CfgNode",
"CfgLeaf",
"TypeMismatch",
"NodeReassignment",
"ModuleError",
"ValidationError",
"SchemaError",
]
38 changes: 38 additions & 0 deletions ntc/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
class ConfigError(Exception):
pass


class TypeMismatch(ConfigError):
pass


class NodeReassignment(ConfigError):
pass


class ModuleError(ConfigError):
pass


class SchemaError(ConfigError):
pass


class SchemaFrozenError(SchemaError):
pass


class ValidationError(ConfigError):
pass


class SpecError(ConfigError):
pass


class NodeFrozenError(ConfigError):
pass


class SaveError(ConfigError):
pass
91 changes: 55 additions & 36 deletions ntc/utils.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,91 @@
import ast
import astor
import inspect
import importlib.util
import logging
import sys
from pathlib import Path
from types import ModuleType
from typing import Set, Union

import yaml

def import_module(module_path: Path) -> ModuleType:
# TODO: resolve it generally, not for tests
spec = importlib.util.spec_from_file_location("tests" + "." + "data", module_path.parent / "__init__.py")
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
logger = logging.getLogger(__name__)

spec = importlib.util.spec_from_file_location("tests" + "." + "data" + "." + module_path.stem, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

return module
def import_module(module_path: Union[Path, str]) -> ModuleType:
if isinstance(module_path, str):
module_path = Path(module_path)
package = _load_package(module_path.parent)
module_name = module_path.stem
if package:
module_name = ".".join((package, module_name))

return _load_module(module_name, module_path)

def merge_module(module: ModuleType, output_path: Path, clean: bool = True) -> None:

def merge_cfg_module(
module: Union[ModuleType, Path, str], output_path: Path, clean: bool = True, imported_modules: Set[str] = None
) -> None:
if isinstance(module, (Path, str)):
module = import_module(module)
if isinstance(output_path, str):
output_path = Path(output_path)
if imported_modules is None:
imported_modules = set()
module_name = module.__spec__.name
if module_name in imported_modules:
return
if clean:
output_path.unlink(missing_ok=True)

module_path = Path(module.__file__)
# raise RuntimeError(_merge_source(module))

with module_path.open() as module_file:
_append_to_file(output_path, f"# START --- {module_path} ---\n")
for line in module_file:
if line.startswith("from "):
import_members = line.strip().split(" ")
imported_module = importlib.import_module(import_members[1], package=module.__package__)
if imported_module.__spec__.name in imported_modules:
continue
if import_members[3] == "cfg":
imported_module = importlib.import_module(import_members[1], package=module.__package__)
# if imported_module.__spec__.loader.is_package(imported_module.__spec__.name):
# pass
# else:
# if not _is_builtin_module(imported_module):
merge_module(imported_module, output_path, clean=False)
merge_cfg_module(imported_module, output_path, clean=False, imported_modules=imported_modules)
elif import_members[1].startswith("."):
logger.debug(f"Skipping non cfg relative import {line}")
else:
_append_to_file(output_path, line)
# if "import " not in line:
else:
_append_to_file(output_path, line)
_append_to_file(output_path, f"# END --- {module_path} ---\n")

imported_modules.add(module_name)

def _merge_source(module: ModuleType) -> ast.AST:
module_path = Path(module.__file__)

with module_path.open() as module_file:
source = module_file.read()
ast_node = ast.parse(source)
ast.fix_missing_locations(ast_node)
raise RuntimeError(astor.to_source(ast_node))
def add_yaml_representer():
def obj_representer(dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", str(data))

yaml.add_multi_representer(object, obj_representer)


def _load_module(module_name: str, module_path: Path) -> ModuleType:
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

return module

def _is_module_package(module: ModuleType) -> bool:
return module.__spec__.loader.is_package(module.__spec__.name)

def _load_package(package_path: Path) -> str:
init_path = package_path / "__init__.py"
if not init_path.exists():
return ""
package_name = package_path.stem
parent_package_name = _load_package(package_path.parent)
if parent_package_name:
package_name = ".".join((parent_package_name, package_name))
_load_module(package_name, init_path)

def _is_builtin_module(module: ModuleType) -> bool:
for path_dir in sys.path[1:]:
if module.__file__.startswith(path_dir):
return True
return False
return package_name


def _append_to_file(output_path: Path, data: str) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/data/bad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_cfg import cfg
from ntc import CN

from .base_cfg import cfg

cfg = CN(cfg)
cfg.NAME = 1
3 changes: 2 additions & 1 deletion tests/data/bad_attr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_cfg import cfg
from ntc import CN

from .base_cfg import cfg

cfg = CN(cfg)

cfg.NAME = "bad_attr"
Expand Down
3 changes: 2 additions & 1 deletion tests/data/bad_class.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .good import cfg
from ntc import CN

from .good import cfg


class BadClass:
pass
Expand Down
3 changes: 2 additions & 1 deletion tests/data/bad_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_cfg import cfg
from ntc import CN

from .base_cfg import cfg


class BadClass:
pass
Expand Down
Loading

0 comments on commit cb885d4

Please sign in to comment.