From 6e319b3cbd198ff4edb09d13271eb7137dff58bc Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 23 Mar 2023 11:03:06 -0400 Subject: [PATCH] add cache to get_schema, combine instance checks with validate --- asdf/schema.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/asdf/schema.py b/asdf/schema.py index 835d7b9aa..533cef56b 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -425,6 +425,7 @@ def _make_jsonschema_resolver_or_registry(url_mapping): if _USE_REFERENCING: + @lru_cache def retrieve_schema(url): schema = schema_loader(url)[0] return referencing.Resource(schema, specification=referencing.jsonschema.DRAFT4) @@ -615,6 +616,7 @@ def __init__( visit_repeat_nodes, resolver=None, registry=None, + instance_checks=None, ): self._ctx = ctx self._serialization_context = serialization_context @@ -623,6 +625,7 @@ def __init__( self._visit_repeat_nodes = visit_repeat_nodes self._resolver = resolver self._registry = registry + self._instance_checks = instance_checks or [] def _create_validator(self, schema): if _USE_REFERENCING: @@ -631,13 +634,22 @@ def _create_validator(self, schema): return self._validator_class(schema, resolver=self._resolver) def _iter_errors(self, instance, _schema=None): + # if we have a schema for this instance, validate the instance if _schema is not None: yield from self._create_validator(_schema).iter_errors(instance) elif self._schema is not None: yield from self._create_validator(self._schema).iter_errors(instance) + # run _instance_checks on this node + [check(instance) for check in self._instance_checks] + + # next, look for tagged child nodes for node in treeutil.iter_tree(instance, _visit_repeat_nodes=self._visit_repeat_nodes): + # run _instance_checks on the child node + [check(node) for check in self._instance_checks] + tag = getattr(node, "_tag", None) + # if this node is tagged, check it against the corresponding schema if tag is not None: if self._serialization_context.extension_manager.handles_tag_definition(tag): tag_def = self._serialization_context.extension_manager.get_tag_definition(tag) @@ -679,6 +691,7 @@ def get_validator( url_mapping=None, _visit_repeat_nodes=False, _serialization_context=None, + _instance_checks=None, ): """ Get a validator object for the given schema. This method is not @@ -736,6 +749,7 @@ def get_validator( validator_class=validator_class, schema=schema, visit_repeat_nodes=_visit_repeat_nodes, + instance_checks=_instance_checks, **resolver_kwargs, ) @@ -825,18 +839,20 @@ def validate(instance, ctx=None, schema=None, validators=None, reading=False, *a ctx = AsdfFile() - validator = get_validator({} if schema is None else schema, ctx, validators, ctx._resolver, *args, **kwargs) - validator.validate(instance) - - additional_validators = [_validate_large_literals] + instance_checks = [lambda i: _validate_large_literals(i, reading)] if ctx.version >= versioning.RESTRICTED_KEYS_MIN_VERSION: - additional_validators.append(_validate_mapping_keys) + instance_checks.append(lambda i: _validate_mapping_keys(i, reading)) - def _callback(instance): - for validator in additional_validators: - validator(instance, reading) - - treeutil.walk(instance, _callback) + kwargs["_instance_checks"] = instance_checks + validator = get_validator( + {} if schema is None else schema, + ctx, + validators, + ctx._resolver, + *args, + **kwargs, + ) + validator.validate(instance) def fill_defaults(instance, ctx, reading=False):