Skip to content

Commit

Permalink
Add new arguments schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Viicos committed Feb 27, 2025
1 parent 741961c commit 8148d18
Show file tree
Hide file tree
Showing 7 changed files with 732 additions and 0 deletions.
106 changes: 106 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3498,6 +3498,112 @@ def arguments_schema(
)


class ArgumentsV3Parameter(TypedDict, total=False):
name: Required[str]
schema: Required[CoreSchema]
mode: Literal[
'positional_only',
'positional_or_keyword',
'keyword_only',
'var_args',
'var_kwargs_uniform',
'var_kwargs_unpacked_typed_dict',
] # default positional_or_keyword
alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]]


def arguments_v3_parameter(
name: str,
schema: CoreSchema,
*,
mode: Literal[
'positional_only',
'positional_or_keyword',
'keyword_only',
'var_args',
'var_kwargs_uniform',
'var_kwargs_unpacked_typed_dict',
]
| None = None,
alias: str | list[str | int] | list[list[str | int]] | None = None,
) -> ArgumentsV3Parameter:
"""
Returns a schema that matches an argument parameter, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
param = core_schema.arguments_v3_parameter(
name='a', schema=core_schema.str_schema(), mode='positional_only'
)
schema = core_schema.arguments_v3_schema([param])
v = SchemaValidator(schema)
assert v.validate_python({'a': 'hello'}) == (('hello',), {})
```
Args:
name: The name to use for the argument parameter
schema: The schema to use for the argument parameter
mode: The mode to use for the argument parameter
alias: The alias to use for the argument parameter
"""
return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias)


class ArgumentsV3Schema(TypedDict, total=False):
type: Required[Literal['arguments-v3']]
arguments_schema: Required[list[ArgumentsV3Parameter]]
populate_by_name: bool
var_args_schema: CoreSchema
var_kwargs_mode: VarKwargsMode
var_kwargs_schema: CoreSchema
ref: str
metadata: dict[str, Any]
serialization: SerSchema


def arguments_v3_schema(
arguments: list[ArgumentsV3Parameter],
*,
populate_by_name: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: SerSchema | None = None,
) -> ArgumentsV3Schema:
"""
Returns a schema that matches an arguments schema, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
param_a = core_schema.arguments_v3_parameter(
name='a', schema=core_schema.str_schema(), mode='positional_only'
)
param_b = core_schema.arguments_v3_parameter(
name='kwargs', schema=core_schema.bool_schema(), mode='var_kwargs_uniform'
)
schema = core_schema.arguments_v3_schema([param_a, param_b])
v = SchemaValidator(schema)
assert v.validate_python({'a': 'hello', 'kwargs': {'extra': True}}) == (('hello',), {'extra': True})
```
Args:
arguments: The arguments to use for the arguments schema
populate_by_name: Whether to populate by name
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='arguments-v2',
arguments_schema=arguments,
populate_by_name=populate_by_name,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class CallSchema(TypedDict, total=False):
type: Required[Literal['call']]
arguments_schema: Required[CoreSchema]
Expand Down
2 changes: 2 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ pub trait Input<'py>: fmt::Debug {

fn validate_args(&self) -> ValResult<Self::Arguments<'_>>;

fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>>;

fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<Self::Arguments<'a>>;

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;
Expand Down
10 changes: 10 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult<JsonArgs<'a, 'data>> {
match self {
JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object))),
Expand Down Expand Up @@ -375,6 +380,11 @@ impl<'py> Input<'py> for str {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_args_v3(&self) -> ValResult<Never> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_dataclass_args(&self, class_name: &str) -> ValResult<Never> {
let class_name = class_name.to_string();
Expand Down
10 changes: 10 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

fn validate_args_v3(&self) -> ValResult<PyArgs<'py>> {
if let Ok(args_kwargs) = self.extract::<ArgsKwargs>() {
let args = args_kwargs.args.into_bound(self.py());
let kwargs = args_kwargs.kwargs.map(|d| d.into_bound(self.py()));
Ok(PyArgs::new(Some(args), kwargs))
} else {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}
}

fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult<PyArgs<'py>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(PyArgs::new(None, Some(dict.clone())))
Expand Down
5 changes: 5 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ impl<'py> Input<'py> for StringMapping<'py> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_args_v3(&self) -> ValResult<Self::Arguments<'_>> {
// do we want to support this?
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_dataclass_args<'a>(&'a self, _dataclass_name: &str) -> ValResult<StringMappingDict<'py>> {
match self {
StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)),
Expand Down
Loading

0 comments on commit 8148d18

Please # to comment.