-
Notifications
You must be signed in to change notification settings - Fork 326
/
Copy pathint_range_to_choice.py
85 lines (76 loc) · 3.24 KB
/
int_range_to_choice.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from numbers import Real
from typing import cast, Optional, TYPE_CHECKING
from ax.core.observation import Observation
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import construct_new_search_space
from ax.models.types import TConfig
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
class IntRangeToChoice(Transform):
    """Convert a RangeParameter of type int to an ordered ChoiceParameter.

    Each eligible integer ``RangeParameter`` is replaced by an ordered
    ``ChoiceParameter`` whose values are every integer from ``lower`` to
    ``upper`` (inclusive). Parameter names and the integer values themselves
    are kept as-is ("in-place"); only the parameter's representation in the
    search space changes. This class overrides only
    ``_transform_search_space``.

    Config:
        max_choices: Only integer range parameters whose cardinality is at
            most this value are converted. Defaults to ``float("inf")``, i.e.
            every integer range parameter is converted.
    """

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.Adapter"] = None,
        config: TConfig | None = None,
    ) -> None:
        """Record which parameters of ``search_space`` will be converted.

        Args:
            search_space: Search space whose parameters are inspected. Required.
            observations: Not used by this transform.
            modelbridge: Not used by this transform.
            config: Optional config; the ``"max_choices"`` key caps the
                cardinality of the integer range parameters that are converted.

        Raises:
            AssertionError: If ``search_space`` is None.
        """
        assert search_space is not None, "IntRangeToChoice requires search space"
        config = config or {}
        self.max_choices: float = float(
            cast(Real, (config.get("max_choices", float("inf"))))
        )
        # Names of the int RangeParameters small enough to be converted.
        self.transform_parameters: set[str] = {
            p_name
            for p_name, p in search_space.parameters.items()
            if isinstance(p, RangeParameter)
            and p.parameter_type == ParameterType.INT
            and p.cardinality() <= self.max_choices
        }

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a new search space with eligible int ranges made choices.

        A parameter is converted only if it was selected in ``__init__`` and is
        still an int ``RangeParameter`` within ``max_choices``; every other
        parameter is passed through unchanged. Parameter constraints are cloned
        against the transformed parameters.

        Args:
            search_space: The search space to transform.

        Returns:
            The new search space built by ``construct_new_search_space``.

        Raises:
            ValueError: If a converted parameter's ``target_value`` is not one
                of the integers in ``[lower, upper]``.
        """
        transformed_parameters: dict[str, Parameter] = {}
        for p_name, p in search_space.parameters.items():
            if (
                p_name in self.transform_parameters
                and isinstance(p, RangeParameter)
                and p.parameter_type == ParameterType.INT
                and p.cardinality() <= self.max_choices
            ):
                values = list(range(int(p.lower), int(p.upper) + 1))
                # The choice values are the very same integers as the range, so
                # the target value carries over unchanged. (Using its *index*
                # into ``values`` would shift the target whenever lower != 0.)
                target_value = p.target_value
                if target_value is not None and target_value not in values:
                    raise ValueError(
                        f"Target value {target_value} of parameter {p_name} is "
                        f"not an integer in [{p.lower}, {p.upper}]."
                    )
                transformed_parameters[p_name] = ChoiceParameter(
                    name=p_name,
                    parameter_type=p.parameter_type,
                    values=values,  # pyre-fixme[6]
                    is_ordered=True,
                    is_fidelity=p.is_fidelity,
                    target_value=target_value,
                )
            else:
                transformed_parameters[p.name] = p
        return construct_new_search_space(
            search_space=search_space,
            parameters=list(transformed_parameters.values()),
            parameter_constraints=[
                pc.clone_with_transformed_parameters(
                    transformed_parameters=transformed_parameters
                )
                for pc in search_space.parameter_constraints
            ],
        )