Skip to content

Commit 4879c84

Browse files
committed
fix import
1 parent 6a0aca4 commit 4879c84

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pre-commit = ">=2.20.0"
5151
pytest = ">=7.1.2"
5252
pytest-cov = ">=3.0.0"
5353
sphinx = ">=5.0.2"
54+
tensorflow = ">=2.16.1"
5455

5556
[tool.ruff]
5657
select = [

scikeras/_saving_utils.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
import keras as keras
55
import keras.saving
6-
import keras.saving.object_registration
76
import numpy as np
8-
from keras.saving.saving_lib import load_model, save_model
7+
from keras.src.saving.saving_lib import load_model, save_model
98

109

1110
def unpack_keras_model(
@@ -25,7 +24,7 @@ def pack_keras_model(
2524
"""Support for Pythons's Pickle protocol."""
2625
tp = type(model)
2726
out = BytesIO()
28-
if tp not in keras.saving.object_registration.GLOBAL_CUSTOM_OBJECTS:
27+
if tp not in keras.saving.get_custom_objects():
2928
module = ".".join(tp.__qualname__.split(".")[:-1])
3029
name = tp.__qualname__.split(".")[-1]
3130
keras.saving.register_keras_serializable(module, name)(tp)

tests/test_compile_kwargs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from keras import losses as losses_module
66
from keras import metrics as metrics_module
77
from keras import optimizers as optimizers_module
8-
from keras.backend.common.variables import KerasVariable
98
from keras.layers import Dense, Input
109
from keras.models import Model
10+
from keras.src.backend.common.variables import KerasVariable
1111
from sklearn.datasets import make_classification
1212

1313
from scikeras.wrappers import KerasClassifier

0 commit comments

Comments
 (0)