Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add new arguments-v3 schema #1641

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3489,6 +3489,116 @@ def arguments_schema(
)


class ArgumentsV3Parameter(TypedDict, total=False):
name: Required[str]
schema: Required[CoreSchema]
mode: Literal[
'positional_only',
'positional_or_keyword',
'keyword_only',
'var_args',
'var_kwargs_uniform',
'var_kwargs_unpacked_typed_dict',
] # default positional_or_keyword
alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]]


def arguments_v3_parameter(
name: str,
schema: CoreSchema,
*,
mode: Literal[
'positional_only',
'positional_or_keyword',
'keyword_only',
'var_args',
'var_kwargs_uniform',
'var_kwargs_unpacked_typed_dict',
]
| None = None,
alias: str | list[str | int] | list[list[str | int]] | None = None,
) -> ArgumentsV3Parameter:
"""
Returns a schema that matches an argument parameter, e.g.:

```py
from pydantic_core import SchemaValidator, core_schema

param = core_schema.arguments_v3_parameter(
name='a', schema=core_schema.str_schema(), mode='positional_only'
)
schema = core_schema.arguments_v3_schema([param])
v = SchemaValidator(schema)
assert v.validate_python({'a': 'hello'}) == (('hello',), {})
```

Args:
name: The name to use for the argument parameter
schema: The schema to use for the argument parameter
mode: The mode to use for the argument parameter
alias: The alias to use for the argument parameter
"""
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


class ArgumentsV3Schema(TypedDict, total=False):
type: Required[Literal['arguments-v3']]
arguments_schema: Required[list[ArgumentsV3Parameter]]
validate_by_name: bool
validate_by_alias: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
var_kwargs_schema: CoreSchema
ref: str
metadata: dict[str, Any]
serialization: SerSchema


def arguments_v3_schema(
arguments: list[ArgumentsV3Parameter],
*,
validate_by_name: bool | None = None,
validate_by_alias: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: SerSchema | None = None,
) -> ArgumentsV3Schema:
"""
Returns a schema that matches an arguments schema, e.g.:

```py
from pydantic_core import SchemaValidator, core_schema

param_a = core_schema.arguments_v3_parameter(
name='a', schema=core_schema.str_schema(), mode='positional_only'
)
param_b = core_schema.arguments_v3_parameter(
name='kwargs', schema=core_schema.bool_schema(), mode='var_kwargs_uniform'
)
schema = core_schema.arguments_v3_schema([param_a, param_b])
v = SchemaValidator(schema)
assert v.validate_python({'a': 'hello', 'kwargs': {'extra': True}}) == (('hello',), {'extra': True})
```

Args:
arguments: The arguments to use for the arguments schema.
validate_by_name: Whether to populate by the parameter names, defaults to `False`.
validate_by_alias: Whether to populate by the parameter aliases, defaults to `True`.
ref: optional unique identifier of the schema, used to reference the schema in other places.
metadata: Any other information you want to include with the schema, not used by pydantic-core.
serialization: Custom serialization schema.
"""
return _dict_not_none(
type='arguments-v3',
arguments_schema=arguments,
validate_by_name=validate_by_name,
validate_by_alias=validate_by_alias,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class CallSchema(TypedDict, total=False):
type: Required[Literal['call']]
arguments_schema: Required[CoreSchema]
Expand Down Expand Up @@ -3916,6 +4026,7 @@ def definition_reference_schema(
DataclassArgsSchema,
DataclassSchema,
ArgumentsSchema,
ArgumentsV3Schema,
CallSchema,
CustomErrorSchema,
JsonSchema,
Expand Down
6 changes: 6 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ pub trait Input<'py>: fmt::Debug {

fn validate_args(&self) -> ValResult<Self::Arguments<'_>>;

fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>>;

fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<Self::Arguments<'a>>;

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;
Expand Down Expand Up @@ -265,6 +267,7 @@ pub trait ValidatedList<'py> {
pub trait ValidatedTuple<'py> {
type Item: BorrowInput<'py>;
fn len(&self) -> Option<usize>;
fn try_for_each(self, f: impl FnMut(PyResult<Self::Item>) -> ValResult<()>) -> ValResult<()>;
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R>;
}

Expand Down Expand Up @@ -313,6 +316,9 @@ impl<'py> ValidatedTuple<'py> for Never {
fn len(&self) -> Option<usize> {
unreachable!()
}
fn try_for_each(self, _f: impl FnMut(PyResult<Self::Item>) -> ValResult<()>) -> ValResult<()> {
unreachable!()
}
fn iterate<R>(self, _consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
unreachable!()
}
Expand Down
16 changes: 16 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult<JsonArgs<'a, 'data>> {
match self {
JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object))),
Expand Down Expand Up @@ -375,6 +380,11 @@ impl<'py> Input<'py> for str {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_args_v3(&self) -> ValResult<Never> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_dataclass_args(&self, class_name: &str) -> ValResult<Never> {
let class_name = class_name.to_string();
Expand Down Expand Up @@ -571,6 +581,12 @@ impl<'a, 'data> ValidatedTuple<'_> for &'a JsonArray<'data> {
fn len(&self) -> Option<usize> {
Some(SmallVec::len(self))
}
fn try_for_each(self, mut f: impl FnMut(PyResult<Self::Item>) -> ValResult<()>) -> ValResult<()> {
for item in self.iter() {
f(Ok(item))?;
}
Ok(())
}
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
Ok(consumer.consume_iterator(self.iter().map(Ok)))
}
Expand Down
23 changes: 22 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

fn validate_args_v3(&self) -> ValResult<PyArgs<'py>> {
if let Ok(args_kwargs) = self.extract::<ArgsKwargs>() {
let args = args_kwargs.args.into_bound(self.py());
let kwargs = args_kwargs.kwargs.map(|d| d.into_bound(self.py()));
Ok(PyArgs::new(Some(args), kwargs))
} else {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}
}

fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult<PyArgs<'py>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(PyArgs::new(None, Some(dict.clone())))
Expand Down Expand Up @@ -915,7 +925,15 @@ impl<'py> PySequenceIterable<'_, 'py> {
PySequenceIterable::Iterator(iter) => iter.len().ok(),
}
}

fn generic_try_for_each(self, f: impl FnMut(PyResult<Bound<'py, PyAny>>) -> ValResult<()>) -> ValResult<()> {
match self {
PySequenceIterable::List(iter) => iter.iter().map(Ok).try_for_each(f),
PySequenceIterable::Tuple(iter) => iter.iter().map(Ok).try_for_each(f),
PySequenceIterable::Set(iter) => iter.iter().map(Ok).try_for_each(f),
PySequenceIterable::FrozenSet(iter) => iter.iter().map(Ok).try_for_each(f),
PySequenceIterable::Iterator(mut iter) => iter.try_for_each(f),
}
}
fn generic_iterate<R>(
self,
consumer: impl ConsumeIterator<PyResult<Bound<'py, PyAny>>, Output = R>,
Expand Down Expand Up @@ -951,6 +969,9 @@ impl<'py> ValidatedTuple<'py> for PySequenceIterable<'_, 'py> {
fn len(&self) -> Option<usize> {
self.generic_len()
}
fn try_for_each(self, f: impl FnMut(PyResult<Self::Item>) -> ValResult<()>) -> ValResult<()> {
self.generic_try_for_each(f)
}
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
self.generic_iterate(consumer)
}
Expand Down
5 changes: 5 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ impl<'py> Input<'py> for StringMapping<'py> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>> {
// do we want to support this?
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_dataclass_args<'a>(&'a self, _dataclass_name: &str) -> ValResult<StringMappingDict<'py>> {
match self {
StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)),
Expand Down
Loading
Loading