Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openapi_core/casting/schemas/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ def __init__(

def create(
self,
spec: SchemaPath,
schema: SchemaPath,
format_validators: Optional[FormatValidatorsDict] = None,
extra_format_validators: Optional[FormatValidatorsDict] = None,
) -> SchemaCaster:
schema_validator = self.schema_validators_factory.create(
spec,
schema,
format_validators=format_validators,
extra_format_validators=extra_format_validators,
Expand Down
5 changes: 4 additions & 1 deletion openapi_core/deserializing/media_types/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_deserializer_callable(
class MediaTypeDeserializer:
def __init__(
self,
spec: SchemaPath,
style_deserializers_factory: StyleDeserializersFactory,
media_types_deserializer: MediaTypesDeserializer,
mimetype: str,
Expand All @@ -75,6 +76,7 @@ def __init__(
encoding: Optional[SchemaPath] = None,
**parameters: str,
):
self.spec = spec
self.style_deserializers_factory = style_deserializers_factory
self.media_types_deserializer = media_types_deserializer
self.mimetype = mimetype
Expand Down Expand Up @@ -117,6 +119,7 @@ def evolve(
schema_caster = self.schema_caster.evolve(schema)

return cls(
self.spec,
self.style_deserializers_factory,
self.media_types_deserializer,
mimetype=mimetype or self.mimetype,
Expand Down Expand Up @@ -221,7 +224,7 @@ def decode_property_style(
prep_encoding, default_location="query"
)
prop_deserializer = self.style_deserializers_factory.create(
prop_style, prop_explode, prop_schema, name=prop_name
self.spec, prop_schema, prop_style, prop_explode, name=prop_name
)
return prop_deserializer.deserialize(location)

Expand Down
4 changes: 3 additions & 1 deletion openapi_core/deserializing/media_types/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def from_schema_casters_factory(

def create(
self,
spec: SchemaPath,
mimetype: str,
schema: Optional[SchemaPath] = None,
schema_validator: Optional[SchemaValidator] = None,
Expand Down Expand Up @@ -89,11 +90,12 @@ def create(
):
schema_caster = (
self.style_deserializers_factory.schema_casters_factory.create(
schema
spec, schema
)
)

return MediaTypeDeserializer(
spec,
self.style_deserializers_factory,
media_types_deserializer,
mimetype,
Expand Down
7 changes: 2 additions & 5 deletions openapi_core/deserializing/styles/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import Mapping
from typing import Optional

from jsonschema_path import SchemaPath

from openapi_core.casting.schemas.casters import SchemaCaster
from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
Expand All @@ -17,15 +15,14 @@ def __init__(
style: str,
explode: bool,
name: str,
schema: SchemaPath,
schema_type: str,
caster: SchemaCaster,
deserializer_callable: Optional[DeserializerCallable] = None,
):
self.style = style
self.explode = explode
self.name = name
self.schema = schema
self.schema_type = (schema / "type").read_str("")
self.schema_type = schema_type
self.caster = caster
self.deserializer_callable = deserializer_callable

Expand Down
8 changes: 5 additions & 3 deletions openapi_core/deserializing/styles/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def __init__(

def create(
self,
spec: SchemaPath,
schema: SchemaPath,
style: str,
explode: bool,
schema: SchemaPath,
name: str,
) -> StyleDeserializer:
deserialize_callable = self.style_deserializers.get(style)
caster = self.schema_casters_factory.create(schema)
caster = self.schema_casters_factory.create(spec, schema)
schema_type = (schema / "type").read_str("")
return StyleDeserializer(
style, explode, name, schema, caster, deserialize_callable
style, explode, name, schema_type, caster, deserialize_callable
)
2 changes: 2 additions & 0 deletions openapi_core/unmarshalling/schemas/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(

def create(
self,
spec: SchemaPath,
schema: SchemaPath,
format_validators: Optional[FormatValidatorsDict] = None,
format_unmarshallers: Optional[FormatUnmarshallersDict] = None,
Expand All @@ -51,6 +52,7 @@ def create(
if extra_format_validators is None:
extra_format_validators = {}
schema_validator = self.schema_validators_factory.create(
spec,
schema,
format_validators=format_validators,
extra_format_validators=extra_format_validators,
Expand Down
1 change: 1 addition & 0 deletions openapi_core/unmarshalling/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(

def _unmarshal_schema(self, schema: SchemaPath, value: Any) -> Any:
unmarshaller = self.schema_unmarshallers_factory.create(
self.spec,
schema,
format_validators=self.format_validators,
extra_format_validators=self.extra_format_validators,
Expand Down
39 changes: 8 additions & 31 deletions openapi_core/validation/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from functools import partial

from lazy_object_proxy import Proxy
from openapi_schema_validator import OAS31_BASE_DIALECT_ID
from openapi_schema_validator import OAS32_BASE_DIALECT_ID
from openapi_schema_validator import OAS30ReadValidator
from openapi_schema_validator import OAS30WriteValidator
from openapi_schema_validator import OAS31Validator
from openapi_schema_validator import OAS32Validator

from openapi_core.validation.schemas._validators import (
build_forbid_unspecified_additional_properties_validator,
from openapi_core.validation.schemas.factories import (
DialectSchemaValidatorsFactory,
)
from openapi_core.validation.schemas.factories import SchemaValidatorsFactory

Expand All @@ -20,44 +19,22 @@

oas30_write_schema_validators_factory = SchemaValidatorsFactory(
OAS30WriteValidator,
Proxy(
partial(
build_forbid_unspecified_additional_properties_validator,
OAS30WriteValidator,
)
),
)

oas30_read_schema_validators_factory = SchemaValidatorsFactory(
OAS30ReadValidator,
Proxy(
partial(
build_forbid_unspecified_additional_properties_validator,
OAS30ReadValidator,
)
),
)

oas31_schema_validators_factory = SchemaValidatorsFactory(
oas31_schema_validators_factory = DialectSchemaValidatorsFactory(
OAS31Validator,
Proxy(
partial(
build_forbid_unspecified_additional_properties_validator,
OAS31Validator,
)
),
OAS31_BASE_DIALECT_ID,
# NOTE: Intentionally use OAS 3.0 format checker for OAS 3.1 to preserve
# backward compatibility for `byte`/`binary` formats.
# See https://github.com/python-openapi/openapi-core/issues/506
format_checker=OAS30ReadValidator.FORMAT_CHECKER,
)

oas32_schema_validators_factory = SchemaValidatorsFactory(
oas32_schema_validators_factory = DialectSchemaValidatorsFactory(
OAS32Validator,
Proxy(
partial(
build_forbid_unspecified_additional_properties_validator,
OAS32Validator,
)
),
OAS32_BASE_DIALECT_ID,
)
89 changes: 73 additions & 16 deletions openapi_core/validation/schemas/factories.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
from copy import deepcopy
from functools import lru_cache
from typing import Any
from typing import Optional
from typing import cast

from jsonschema._format import FormatChecker
from jsonschema.protocols import Validator
from jsonschema.validators import validator_for
from jsonschema_path import SchemaPath

from openapi_core.validation.schemas._validators import (
build_enforce_properties_required_validator,
)
from openapi_core.validation.schemas._validators import (
build_forbid_unspecified_additional_properties_validator,
)
from openapi_core.validation.schemas.datatypes import FormatValidatorsDict
from openapi_core.validation.schemas.validators import SchemaValidator


class SchemaValidatorsFactory:
def __init__(
self,
schema_validator_class: type[Validator],
strict_schema_validator_class: Optional[type[Validator]] = None,
schema_validator_cls: type[Validator],
format_checker: Optional[FormatChecker] = None,
):
self.schema_validator_class = schema_validator_class
self.strict_schema_validator_class = strict_schema_validator_class
self.schema_validator_cls = schema_validator_cls
if format_checker is None:
format_checker = self.schema_validator_class.FORMAT_CHECKER
format_checker = self.schema_validator_cls.FORMAT_CHECKER
assert format_checker is not None
self.format_checker = format_checker

def get_validator_cls(
self, spec: SchemaPath, schema: SchemaPath
) -> type[Validator]:
return self.schema_validator_cls

def get_format_checker(
self,
format_validators: Optional[FormatValidatorsDict] = None,
Expand Down Expand Up @@ -57,34 +66,82 @@ def _add_validators(

def create(
self,
spec: SchemaPath,
schema: SchemaPath,
format_validators: Optional[FormatValidatorsDict] = None,
extra_format_validators: Optional[FormatValidatorsDict] = None,
forbid_unspecified_additional_properties: bool = False,
enforce_properties_required: bool = False,
) -> SchemaValidator:
validator_class: type[Validator] = self.schema_validator_class
validator_cls: type[Validator] = self.get_validator_cls(spec, schema)
if enforce_properties_required:
validator_cls = build_enforce_properties_required_validator(
validator_cls
)
if forbid_unspecified_additional_properties:
if self.strict_schema_validator_class is None:
raise ValueError(
"Strict additional properties validation is not supported "
"by this factory."
validator_cls = (
build_forbid_unspecified_additional_properties_validator(
validator_cls
)
validator_class = self.strict_schema_validator_class

if enforce_properties_required:
validator_class = build_enforce_properties_required_validator(
validator_class
)

format_checker = self.get_format_checker(
format_validators, extra_format_validators
)
with schema.resolve() as resolved:
jsonschema_validator = validator_class(
jsonschema_validator = validator_cls(
resolved.contents,
_resolver=resolved.resolver,
format_checker=format_checker,
)

return SchemaValidator(schema, jsonschema_validator)


class DialectSchemaValidatorsFactory(SchemaValidatorsFactory):
def __init__(
self,
schema_validator_cls: type[Validator],
default_jsonschema_dialect_id: str,
format_checker: Optional[FormatChecker] = None,
):
super().__init__(schema_validator_cls, format_checker)
self.default_jsonschema_dialect_id = default_jsonschema_dialect_id

def get_validator_cls(
self, spec: SchemaPath, schema: SchemaPath
) -> type[Validator]:
dialect_id = self._get_dialect_id(spec, schema)

validator_cls = self._get_validator_class_for_dialect(dialect_id)
if validator_cls is None:
raise ValueError(f"Unknown JSON Schema dialect: {dialect_id!r}")

return validator_cls

def _get_dialect_id(
self,
spec: SchemaPath,
schema: SchemaPath,
) -> str:
try:
return (schema / "$schema").read_str()
except KeyError:
return self._get_default_jsonschema_dialect_id(spec)

def _get_default_jsonschema_dialect_id(self, spec: SchemaPath) -> str:
return (spec / "jsonSchemaDialect").read_str(
default=self.default_jsonschema_dialect_id
)

@lru_cache
def _get_validator_class_for_dialect(
self, dialect_id: str
) -> type[Validator] | None:
return cast(
type[Validator] | None,
validator_for(
{"$schema": dialect_id},
default=cast(Any, None),
),
)
5 changes: 4 additions & 1 deletion openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,15 @@ def _deserialise_media_type(
schema_validator = None
if schema is not None:
schema_validator = self.schema_validators_factory.create(
self.spec,
schema,
format_validators=self.format_validators,
extra_format_validators=self.extra_format_validators,
forbid_unspecified_additional_properties=self.forbid_unspecified_additional_properties,
enforce_properties_required=self.enforce_properties_required,
)
deserializer = self.media_type_deserializers_factory.create(
self.spec,
mimetype,
schema=schema,
schema_validator=schema_validator,
Expand All @@ -169,12 +171,13 @@ def _deserialise_style(
style, explode = get_style_and_explode(param_or_header)
schema = param_or_header / "schema"
deserializer = self.style_deserializers_factory.create(
style, explode, schema, name=name
self.spec, schema, style, explode, name=name
)
return deserializer.deserialize(location)

def _validate_schema(self, schema: SchemaPath, value: Any) -> None:
validator = self.schema_validators_factory.create(
self.spec,
schema,
format_validators=self.format_validators,
extra_format_validators=self.extra_format_validators,
Expand Down
Loading
Loading