diff --git a/src/validators/function.rs b/src/validators/function.rs index e65ff79c3..ddb096054 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -380,7 +380,9 @@ impl Validator for FunctionWrapValidator { let handler = Bound::new(py, handler)?; #[allow(clippy::used_underscore_items)] let result = self._validate(handler.as_any(), py, input, state); - state.exactness = handler.borrow_mut().validator.exactness; + let handler = handler.borrow(); + state.exactness = handler.validator.exactness; + state.fields_set_count = handler.validator.fields_set_count; result } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 12d76b3be..b4be4c68f 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -225,6 +225,7 @@ pub struct InternalValidator { self_instance: Option, recursion_guard: RecursionState, pub(crate) exactness: Option, + pub(crate) fields_set_count: Option, validation_mode: InputType, hide_input_in_errors: bool, validation_error_cause: bool, @@ -256,6 +257,7 @@ impl InternalValidator { self_instance: extra.self_instance.map(|d| d.clone().unbind()), recursion_guard: state.recursion_guard.clone(), exactness: state.exactness, + fields_set_count: state.fields_set_count, validation_mode: extra.input_type, hide_input_in_errors, validation_error_cause, @@ -323,6 +325,7 @@ impl InternalValidator { }; let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into()); state.exactness = self.exactness; + state.fields_set_count = self.fields_set_count; let result = self.validator.validate(py, input, &mut state).map_err(|e| { ValidationError::from_val_error( py, @@ -335,6 +338,7 @@ impl InternalValidator { ) }); self.exactness = state.exactness; + self.fields_set_count = state.fields_set_count; result } } diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index f35ebfec0..7e0daacec 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1358,3 +1358,83 @@ class Model: assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo) assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar) + + +def test_smart_union_wrap_validator_should_not_change_nested_model_field_counts() -> None: + """Adding a wrap validator on a union member should not affect smart union behavior""" + + class SubModel: + x: str = 'x' + + class ModelA: + type: str = 'A' + sub: SubModel + + class ModelB: + type: str = 'B' + sub: SubModel + + submodel_schema = core_schema.model_schema( + SubModel, + core_schema.model_fields_schema(fields={'x': core_schema.model_field(core_schema.str_schema())}), + ) + + wrapped_submodel_schema = core_schema.no_info_wrap_validator_function( + lambda v, handler: handler(v), submodel_schema + ) + + model_a_schema = core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + fields={ + 'type': core_schema.model_field( + core_schema.with_default_schema(core_schema.literal_schema(['A']), default='A'), + ), + 'sub': core_schema.model_field(wrapped_submodel_schema), + }, + ), + ) + + model_b_schema = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={ + 'type': core_schema.model_field( + core_schema.with_default_schema(core_schema.literal_schema(['B']), default='B'), + ), + 'sub': core_schema.model_field(submodel_schema), + }, + ), + ) + + for choices in permute_choices([model_a_schema, model_b_schema]): + schema = core_schema.union_schema(choices) + validator = SchemaValidator(schema) + + assert isinstance(validator.validate_python({'type': 'A', 'sub': {'x': 'x'}}), ModelA) + assert isinstance(validator.validate_python({'type': 'B', 'sub': {'x': 'x'}}), ModelB) + + # defaults to leftmost choice if there's a tie + assert isinstance(validator.validate_python({'sub': {'x': 'x'}}), choices[0]['cls']) + + # test validate_assignment + class RootModel: + ab: Union[ModelA, ModelB] + + root_model = core_schema.model_schema( + RootModel, + core_schema.model_fields_schema( + fields={'ab': core_schema.model_field(core_schema.union_schema([model_a_schema, model_b_schema]))} + ), + ) + + validator = SchemaValidator(root_model) + m = validator.validate_python({'ab': {'type': 'B', 'sub': {'x': 'x'}}}) + assert isinstance(m, RootModel) + assert isinstance(m.ab, ModelB) + assert m.ab.sub.x == 'x' + + m = validator.validate_assignment(m, 'ab', {'sub': {'x': 'y'}}) + assert isinstance(m, RootModel) + assert isinstance(m.ab, ModelA) + assert m.ab.sub.x == 'y'