Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new plugin type for custom schema validators #1328

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The ASDF Standard is at v1.6.0
- Update lower pins on ``numpy`` (per release policy), ``packaging``, and ``pyyaml`` to
ones that we can successfully build and test against. [#1360]
- Provide more informative filename when failing to open a file [#1357]
- Add new plugin type for custom schema validators. [#1328]

2.14.3 (2022-12-15)
-------------------
Expand Down
Empty file added asdf/core/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions asdf/core/_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from asdf.extension import ManifestExtension
braingram marked this conversation as resolved.
Show resolved Hide resolved

from ._validators import ndarray

VALIDATORS = [
ndarray.NdimValidator(),
ndarray.MaxNdimValidator(),
ndarray.DatatypeValidator(),
]


MANIFEST_URIS = [
"asdf://asdf-format.org/core/manifests/core-1.0.0",
"asdf://asdf-format.org/core/manifests/core-1.1.0",
"asdf://asdf-format.org/core/manifests/core-1.2.0",
"asdf://asdf-format.org/core/manifests/core-1.3.0",
"asdf://asdf-format.org/core/manifests/core-1.4.0",
"asdf://asdf-format.org/core/manifests/core-1.5.0",
"asdf://asdf-format.org/core/manifests/core-1.6.0",
]


EXTENSIONS = [ManifestExtension.from_uri(u, validators=VALIDATORS) for u in MANIFEST_URIS]
21 changes: 21 additions & 0 deletions asdf/core/_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from asdf.resource import JsonschemaResourceMapping
braingram marked this conversation as resolved.
Show resolved Hide resolved


def get_extensions():
"""
Get the extension instances for the core extensions. This method is registered with the
asdf.extensions entry point.

Returns
-------
list of asdf.extension.Extension
"""
from . import _extensions

return _extensions.EXTENSIONS


def get_json_schema_resource_mappings():
return [
JsonschemaResourceMapping(),
]
Empty file.
28 changes: 28 additions & 0 deletions asdf/core/_validators/ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from asdf.extension import Validator
from asdf.tags.core.ndarray import validate_datatype, validate_max_ndim, validate_ndim
braingram marked this conversation as resolved.
Show resolved Hide resolved


class NdimValidator(Validator):
schema_property = "ndim"
# The validators in this module should really only be applied
# to ndarray-* tags, but that will have to be a 3.0 change.
tags = ["**"]

def validate(self, expected_ndim, node, schema):
yield from validate_ndim(None, expected_ndim, node, schema)


class MaxNdimValidator(Validator):
schema_property = "max_ndim"
tags = ["**"]

def validate(self, max_ndim, node, schema):
yield from validate_max_ndim(None, max_ndim, node, schema)


class DatatypeValidator(Validator):
schema_property = "datatype"
tags = ["**"]

def validate(self, expected_datatype, node, schema):
yield from validate_datatype(None, expected_datatype, node, schema)
Empty file added asdf/core/tests/__init__.py
Empty file.
37 changes: 37 additions & 0 deletions asdf/core/tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import yaml

import asdf
from asdf.core._integration import get_extensions, get_json_schema_resource_mappings


@pytest.mark.parametrize(
"uri",
[
"http://json-schema.org/draft-04/schema",
],
)
def test_get_resource_mappings(uri):
mappings = get_json_schema_resource_mappings()

mapping = next(m for m in mappings if uri in m)
assert mapping is not None

assert uri.encode("utf-8") in mapping[uri]


def test_get_extensions():
extensions = get_extensions()
extension_uris = {e.extension_uri for e in extensions}

# No duplicates
assert len(extension_uris) == len(extensions)

resource_extension_uris = set()
resource_manager = asdf.get_config().resource_manager
for resource_uri in resource_manager:
if resource_uri.startswith("asdf://asdf-format.org/core/manifests/core-"):
resource_extension_uris.add(yaml.safe_load(resource_manager[resource_uri])["extension_uri"])

# Make sure every core manifest has a corresponding extension
assert resource_extension_uris == extension_uris
12 changes: 12 additions & 0 deletions asdf/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from jsonschema import ValidationError

__all__ = [
"AsdfConversionWarning",
"AsdfDeprecationWarning",
"AsdfProvisionalAPIWarning",
"AsdfWarning",
"DelimiterNotFoundError",
"ValidationError",
]


class AsdfWarning(Warning):
"""
The base warning class from which all ASDF warnings should inherit.
Expand Down
2 changes: 2 additions & 0 deletions asdf/extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._manager import ExtensionManager, get_cached_extension_manager
from ._manifest import ManifestExtension
from ._tag import TagDefinition
from ._validator import Validator

__all__ = [
# New API
Expand All @@ -28,6 +29,7 @@
"Converter",
"ConverterProxy",
"Compressor",
"Validator",
# Legacy API
"AsdfExtension",
"AsdfExtensionList",
Expand Down
33 changes: 33 additions & 0 deletions asdf/extension/_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ._converter import ConverterProxy
from ._legacy import AsdfExtension
from ._tag import TagDefinition
from ._validator import Validator


class Extension(abc.ABC):
Expand Down Expand Up @@ -117,6 +118,18 @@ def yaml_tag_handles(self):
"""
return {}

@property
def validators(self):
"""
Get the `asdf.extension.Validator` instances for additional
schema properties supported by this extension.

Returns
-------
iterable of asdf.extension.Validator
"""
return []


class ExtensionProxy(Extension, AsdfExtension):
"""
Expand Down Expand Up @@ -193,6 +206,14 @@ def __init__(self, delegate, package_name=None, package_version=None):
raise TypeError(msg)
self._compressors.append(compressor)

self._validators = []
if hasattr(self._delegate, "validators"):
for validator in self._delegate.validators:
if not isinstance(validator, Validator):
msg = "Extension property 'validators' must contain instances of asdf.extension.Validator"
raise TypeError(msg)
self._validators.append(validator)

@property
def extension_uri(self):
"""
Expand Down Expand Up @@ -373,6 +394,18 @@ def yaml_tag_handles(self):
"""
return self._yaml_tag_handles

@property
def validators(self):
"""
Get the `asdf.extension.Validator` instances for additional
schema properties supported by this extension.

Returns
-------
list of asdf.extension.Validator
"""
return self._validators

def __eq__(self, other):
if isinstance(other, ExtensionProxy):
return other.delegate is self.delegate
Expand Down
96 changes: 95 additions & 1 deletion asdf/extension/_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import lru_cache

from asdf.util import get_class_name
from asdf.tagged import Tagged
from asdf.util import get_class_name, uri_match
braingram marked this conversation as resolved.
Show resolved Hide resolved

from ._extension import ExtensionProxy

Expand All @@ -25,6 +26,8 @@ def __init__(self, extensions):
# This dict has both str and type keys:
self._converters_by_type = {}

validators = set()

for extension in self._extensions:
for tag_def in extension.tags:
if tag_def.tag_uri not in self._tag_defs_by_tag:
Expand All @@ -47,6 +50,10 @@ def __init__(self, extensions):
self._converters_by_type[typ] = converter
self._converters_by_type[type_class_name] = converter

validators.update(extension.validators)

self._validator_manager = _get_cached_validator_manager(tuple(validators))

@property
def extensions(self):
"""
Expand Down Expand Up @@ -182,6 +189,10 @@ def get_converter_for_type(self, typ):
)
raise KeyError(msg) from None

@property
def validator_manager(self):
return self._validator_manager


def get_cached_extension_manager(extensions):
"""
Expand Down Expand Up @@ -214,3 +225,86 @@ def get_cached_extension_manager(extensions):
@lru_cache
def _get_cached_extension_manager(extensions):
return ExtensionManager(extensions)


class ValidatorManager:
"""
Wraps a list of custom validators and indexes them by schema property.

Parameters
----------
validators : iterable of asdf.extension.Validator
List of validators to manage.
"""

def __init__(self, validators):
self._validators = list(validators)

self._validators_by_schema_property = {}
for validator in self._validators:
if validator.schema_property not in self._validators_by_schema_property:
self._validators_by_schema_property[validator.schema_property] = set()
self._validators_by_schema_property[validator.schema_property].add(validator)

self._jsonschema_validators_by_schema_property = {}
for schema_property in self._validators_by_schema_property:
self._jsonschema_validators_by_schema_property[schema_property] = self._get_jsonschema_validator(
schema_property,
)

def validate(self, schema_property, schema_property_value, node, schema):
"""
Validate an ASDF tree node against custom validators for a schema property.

Parameters
----------
schema_property : str
Name of the schema property (identifies the validator(s) to use).
schema_property_value : object
Value of the schema property.
node : asdf.tagged.Tagged
The ASDF node to validate.
schema : dict
The schema object that contains the property that triggered
the validation.

Yields
------
asdf.exceptions.ValidationError
"""
if schema_property in self._validators_by_schema_property:
for validator in self._validators_by_schema_property[schema_property]:
if _validator_matches(validator, node):
yield from validator.validate(schema_property_value, node, schema)

def get_jsonschema_validators(self):
"""
Get a dictionary of validator methods suitable for use
with the jsonschema library.

Returns
-------
dict of str: callable
"""
return dict(self._jsonschema_validators_by_schema_property)

def _get_jsonschema_validator(self, schema_property):
def _validator(_, schema_property_value, node, schema):
return self.validate(schema_property, schema_property_value, node, schema)
braingram marked this conversation as resolved.
Show resolved Hide resolved

return _validator


def _validator_matches(validator, node):
if any(t == "**" for t in validator.tags):
return True

if not isinstance(node, Tagged):
return False
eslavich marked this conversation as resolved.
Show resolved Hide resolved

return any(uri_match(t, node._tag) for t in validator.tags)


@lru_cache
def _get_cached_validator_manager(validators):
return ValidatorManager(validators)
14 changes: 13 additions & 1 deletion asdf/extension/_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class ManifestExtension(Extension):
compressors : iterable of asdf.extension.Compressor, optional
Compressor instances to support additional binary
block compression options.
validators : iterable of asdf.extension.Validator, optional
Validator instances to support validation of custom
schema properties.
legacy_class_names : iterable of str, optional
Fully-qualified class names used by older versions
of this extension.
Expand All @@ -43,7 +46,7 @@ def from_uri(cls, manifest_uri, **kwargs):
manifest = yaml.safe_load(get_config().resource_manager[manifest_uri])
return cls(manifest, **kwargs)

def __init__(self, manifest, *, legacy_class_names=None, converters=None, compressors=None):
def __init__(self, manifest, *, legacy_class_names=None, converters=None, compressors=None, validators=None):
self._manifest = manifest

if legacy_class_names is None:
Expand All @@ -61,6 +64,11 @@ def __init__(self, manifest, *, legacy_class_names=None, converters=None, compre
else:
self._compressors = compressors

if validators is None:
self._validators = []
else:
self._validators = validators

@property
def extension_uri(self):
return self._manifest["extension_uri"]
Expand Down Expand Up @@ -93,6 +101,10 @@ def converters(self):
def compressors(self):
return self._compressors

@property
def validators(self):
return self._validators

@property
def tags(self):
result = []
Expand Down
Loading