11
11
import torch
12
12
from botorch import settings
13
13
from botorch .exceptions .errors import BotorchTensorDimensionError
14
+ from botorch .exceptions .warnings import UserInputWarning
14
15
from botorch .models .transforms .input import (
15
16
AffineInputTransform ,
16
17
AppendFeatures ,
@@ -155,16 +156,29 @@ def test_normalize(self):
155
156
self .assertEqual (nlz ._d , 2 )
156
157
self .assertEqual (nlz .mins .shape , torch .Size ([3 , 1 , 2 ]))
157
158
self .assertEqual (nlz .ranges .shape , torch .Size ([3 , 1 , 2 ]))
159
+ self .assertTrue (nlz .equals (Normalize (** nlz .get_init_args ())))
158
160
159
- # basic init, fixed bounds
161
+ # learn_bounds=False with no bounds.
162
+ with self .assertWarnsRegex (UserInputWarning , "learn_bounds" ):
163
+ Normalize (d = 2 , learn_bounds = False )
164
+
165
+ # learn_bounds=True with bounds provided.
160
166
bounds = torch .zeros (2 , 2 , device = self .device , dtype = dtype )
167
+ nlz = Normalize (d = 2 , bounds = bounds , learn_bounds = True )
168
+ self .assertTrue (nlz .learn_bounds )
169
+ self .assertTrue (torch .equal (nlz .mins , bounds [..., 0 :1 , :]))
170
+ self .assertTrue (
171
+ torch .equal (nlz .ranges , bounds [..., 1 :2 , :] - bounds [..., 0 :1 , :])
172
+ )
173
+
174
+ # basic init, fixed bounds
161
175
nlz = Normalize (d = 2 , bounds = bounds )
162
176
self .assertFalse (nlz .learn_bounds )
163
177
self .assertTrue (nlz .training )
164
178
self .assertEqual (nlz ._d , 2 )
165
179
self .assertTrue (torch .equal (nlz .mins , bounds [..., 0 :1 , :]))
166
180
self .assertTrue (
167
- torch .equal (nlz .mins , bounds [..., 1 :2 , :] - bounds [..., 0 :1 , :])
181
+ torch .equal (nlz .ranges , bounds [..., 1 :2 , :] - bounds [..., 0 :1 , :])
168
182
)
169
183
# with grad
170
184
bounds .requires_grad = True
@@ -180,6 +194,7 @@ def test_normalize(self):
180
194
nlz .eval ()
181
195
self .assertIsNone (nlz .coefficient .grad_fn )
182
196
self .assertIsNone (nlz .offset .grad_fn )
197
+ self .assertTrue (nlz .equals (Normalize (** nlz .get_init_args ())))
183
198
184
199
# basic init, provided indices
185
200
with self .assertRaises (ValueError ):
@@ -204,6 +219,7 @@ def test_normalize(self):
204
219
== torch .tensor ([0 ], dtype = torch .long , device = self .device )
205
220
).all ()
206
221
)
222
+ self .assertTrue (nlz .equals (Normalize (** nlz .get_init_args ())))
207
223
208
224
# test .to
209
225
other_dtype = torch .float if dtype == torch .double else torch .double
@@ -594,13 +610,15 @@ def test_round_transform(self):
594
610
self .assertTrue (round_tf .training )
595
611
self .assertFalse (round_tf .approximate )
596
612
self .assertEqual (round_tf .tau , 1e-3 )
613
+ self .assertTrue (round_tf .equals (Round (** round_tf .get_init_args ())))
597
614
598
615
# With tensor indices.
599
616
round_tf = Round (
600
617
integer_indices = torch .tensor (int_idcs , dtype = dtype , device = self .device ),
601
618
categorical_features = categorical_feats ,
602
619
)
603
620
self .assertEqual (round_tf .integer_indices .tolist (), int_idcs )
621
+ self .assertTrue (round_tf .equals (Round (** round_tf .get_init_args ())))
604
622
605
623
# basic usage
606
624
for batch_shape , approx , categorical_features in itertools .product (
0 commit comments