Skip to content

Commit

Permalink
Merge pull request #6 from braingram/braingram-no-patch-iter-errors
Browse files Browse the repository at this point in the history
add cache to get_schema, combine instance checks with validate
  • Loading branch information
eslavich authored Mar 28, 2023
2 parents c24fa51 + 6e319b3 commit b1b9ec9
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions asdf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def _make_jsonschema_resolver_or_registry(url_mapping):

if _USE_REFERENCING:

@lru_cache
def retrieve_schema(url):
schema = schema_loader(url)[0]
return referencing.Resource(schema, specification=referencing.jsonschema.DRAFT4)
Expand Down Expand Up @@ -615,6 +616,7 @@ def __init__(
visit_repeat_nodes,
resolver=None,
registry=None,
instance_checks=None,
):
self._ctx = ctx
self._serialization_context = serialization_context
Expand All @@ -623,6 +625,7 @@ def __init__(
self._visit_repeat_nodes = visit_repeat_nodes
self._resolver = resolver
self._registry = registry
self._instance_checks = instance_checks or []

def _create_validator(self, schema):
if _USE_REFERENCING:
Expand All @@ -631,13 +634,22 @@ def _create_validator(self, schema):
return self._validator_class(schema, resolver=self._resolver)

def _iter_errors(self, instance, _schema=None):
# if we have a schema for this instance, validate the instance
if _schema is not None:
yield from self._create_validator(_schema).iter_errors(instance)
elif self._schema is not None:
yield from self._create_validator(self._schema).iter_errors(instance)

# run _instance_checks on this node
[check(instance) for check in self._instance_checks]

# next, look for tagged child nodes
for node in treeutil.iter_tree(instance, _visit_repeat_nodes=self._visit_repeat_nodes):
# run _instance_checks on the child node
[check(node) for check in self._instance_checks]

tag = getattr(node, "_tag", None)
# if this node is tagged, check it against the corresponding schema
if tag is not None:
if self._serialization_context.extension_manager.handles_tag_definition(tag):
tag_def = self._serialization_context.extension_manager.get_tag_definition(tag)
Expand Down Expand Up @@ -679,6 +691,7 @@ def get_validator(
url_mapping=None,
_visit_repeat_nodes=False,
_serialization_context=None,
_instance_checks=None,
):
"""
Get a validator object for the given schema. This method is not
Expand Down Expand Up @@ -736,6 +749,7 @@ def get_validator(
validator_class=validator_class,
schema=schema,
visit_repeat_nodes=_visit_repeat_nodes,
instance_checks=_instance_checks,
**resolver_kwargs,
)

Expand Down Expand Up @@ -825,18 +839,20 @@ def validate(instance, ctx=None, schema=None, validators=None, reading=False, *a

ctx = AsdfFile()

validator = get_validator({} if schema is None else schema, ctx, validators, ctx._resolver, *args, **kwargs)
validator.validate(instance)

additional_validators = [_validate_large_literals]
instance_checks = [lambda i: _validate_large_literals(i, reading)]
if ctx.version >= versioning.RESTRICTED_KEYS_MIN_VERSION:
additional_validators.append(_validate_mapping_keys)
instance_checks.append(lambda i: _validate_mapping_keys(i, reading))

def _callback(instance):
for validator in additional_validators:
validator(instance, reading)

treeutil.walk(instance, _callback)
kwargs["_instance_checks"] = instance_checks
validator = get_validator(
{} if schema is None else schema,
ctx,
validators,
ctx._resolver,
*args,
**kwargs,
)
validator.validate(instance)


def fill_defaults(instance, ctx, reading=False):
Expand Down

0 comments on commit b1b9ec9

Please # to comment.