Skip to content

Commit

Permalink
Multivariate distribution support and input scaling option (#117)
Browse files Browse the repository at this point in the history
Multivariate support, refactored domain space , bug fixes, added param normalization option
  • Loading branch information
tihom authored Aug 24, 2024
1 parent b130420 commit 5edd533
Show file tree
Hide file tree
Showing 15 changed files with 1,662 additions and 1,406 deletions.
19 changes: 16 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ Mango search space is compatible with scikit-learn's parameter space definitions
[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
or [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html).
The search space is defined as a dictionary with keys being the parameter names (string) and values being
list of discreet choices, range of integers or the distributions. Example of some common search spaces are:
list of discreet choices, range of integers or the distributions.

> [!NOTE]
> Mango does not scale or normalize the search space parameters by default. Users should use their judgement on whether input space needs to be normalized.
Example of some common search spaces are:

### Integer
Following space defines `x` as an integer parameters with values in `range(-10, 11)` (11 is not included):
Expand All @@ -151,7 +156,7 @@ param_space = dict(max_features=['auto', 0.2, 0.3])
Lists are uniformly sampled and are assumed to be unordered. They are one-hot encoded internally.

### Distributions
All the distributions supported by [`scipy.stats`](https://docs.scipy.org/doc/scipy/reference/stats.html) are supported.
All the distributions, including multivariate, supported by [`scipy.stats`](https://docs.scipy.org/doc/scipy/reference/stats.html) are supported.
In general, distributions must provide a `rvs` method for sampling.

#### Uniform distribution
Expand Down Expand Up @@ -383,10 +388,18 @@ It can be either:
This allows the user to customize the initial evaluation points and therefore guide the optimization process.
It also enables starting the optimizer from the results of a previous tuner run (see [this](<examples/warmup.ipynb>) notebook for a working example).
Note that if `initial_custom` option is given then `initial_random` is ignored.
- *scale_params*: `True` or `False` (default: `False`). Scales the search space parameter space using `MinMaxScaler`. Can be useful when the range of parameters is not comparable like below:
```python
{
'x': uniform(-1, 2), # -1 to 1
'y': uniform(-1000, 2000) # -1000 to 1000
}
```
However, use this option with caution as it could have unintended consequences.


The default configuration parameters can be modified, as shown below. Only the parameters whose values need to adjusted can be passed as the dictionary.

The configuration options can be modified, as shown below:
```python
conf_dict = dict(num_iteration=40, domain_size=10000, initial_random=3)

Expand Down
9 changes: 9 additions & 0 deletions dev_notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# 2024 Aug 16: Scaling parameters in GP space
Added an option `scale_params` to scale the parameters in GP space using min max scaler. Using 10 repeated `test_six_hump` test with param `y` scaled by a factor of 1000 we get the following results:

Base case (y param not scaled in inputs or GP space): 0 out of 10 tests failed
Y param not scaled, GP space scaled: 1 out of 10 tests failed
Without scaling GP space: 0 out of 10 tests failed
With scaling GP space: 2 out of 10 tests failed

Scaling is not clearly useful so not exposing it as an option.
115 changes: 108 additions & 7 deletions examples/Failure_Handling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluation failed\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1c23a4d25654b0c98935b4b48dad76d",
"model_id": "8824fb1e3ebe4fcb9f6941c6a2713e06",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))"
" 0%| | 0/20 [00:00<?, ?it/s]"
]
},
"metadata": {},
Expand All @@ -41,9 +48,9 @@
"Evaluation failed\n",
"Evaluation failed\n",
"Evaluation failed\n",
"\n",
"best parameters: {'x': -0.016781443194481938, 'y': -0.022917881560944764}\n",
"best objective: -0.0008068461309311162\n"
"Evaluation failed\n",
"best parameters: {'x': np.float64(0.0276446020826544), 'y': np.float64(0.28614807644106754)}\n",
"best objective: -0.08264494567523134\n"
]
}
],
Expand Down Expand Up @@ -94,6 +101,100 @@
"print('best objective:',results['best_objective'])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"x = results[\"params_tried\"]\n",
"y = results[\"objective_values\"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(x)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(y)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-20.74736909502119 -20.74736909502119\n",
"-42.31845686613488 -42.31845686613488\n",
"-28.24377819521724 -28.24377819521724\n",
"-0.10681853281376912 -0.10681853281376912\n",
"-4.671679570731408 -4.671679570731408\n",
"-20.018842110693964 -20.018842110693964\n",
"-13.151638672697791 -13.151638672697791\n",
"-14.253843472516818 -14.253843472516818\n",
"-24.466383649475585 -24.466383649475585\n",
"-1.9945915709452409 -1.9945915709452409\n",
"-16.554145481793704 -16.554145481793704\n",
"-1.3652187718750919 -1.3652187718750919\n",
"-0.5475161352600229 -0.5475161352600229\n",
"-0.1409257338336861 -0.1409257338336861\n",
"-0.9355543392875246 -0.9355543392875246\n",
"-0.08264494567523134 -0.08264494567523134\n",
"-2.3269622867369497 -2.3269622867369497\n"
]
}
],
"source": [
"for p, o in zip(x, y):\n",
" xx = p['x']\n",
" yy = p['y']\n",
" oo = -(xx**2 + yy**2)\n",
" print(o, oo)\n",
" \n",
" assert oo == o "
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -104,7 +205,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -118,7 +219,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
92 changes: 46 additions & 46 deletions examples/warmup.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "64520ba9",
"metadata": {},
"outputs": [],
Expand All @@ -28,14 +28,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "8906f595",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7842b7e7cf8d455b99ed74d4fc0d7c4b",
"model_id": "886f322be5a44eac9b42079e3c8c4f44",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -70,38 +70,38 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "b74f0d1e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[({'b': 99, 'a': 71}, 170),\n",
" ({'b': 92, 'a': 35}, 127),\n",
" ({'a': 40, 'b': 14}, 54),\n",
" ({'a': 69, 'b': 97}, 166),\n",
" ({'a': 66, 'b': 98}, 164),\n",
" ({'a': 52, 'b': 21}, 73),\n",
" ({'a': 38, 'b': 77}, 115),\n",
" ({'a': 67, 'b': 4}, 71),\n",
" ({'a': 99, 'b': 99}, 198),\n",
" ({'a': 27, 'b': 7}, 34),\n",
" ({'a': 99, 'b': 98}, 197),\n",
" ({'a': 98, 'b': 99}, 197),\n",
" ({'a': 15, 'b': 3}, 18),\n",
" ({'a': 54, 'b': 24}, 78),\n",
" ({'a': 51, 'b': 91}, 142),\n",
" ({'a': 99, 'b': 97}, 196),\n",
" ({'a': 97, 'b': 99}, 196),\n",
" ({'a': 47, 'b': 87}, 134),\n",
" ({'a': 85, 'b': 18}, 103),\n",
" ({'a': 76, 'b': 17}, 93),\n",
" ({'a': 98, 'b': 98}, 196),\n",
" ({'a': 99, 'b': 96}, 195)]"
"[({'b': 8, 'a': 85}, np.int64(93)),\n",
" ({'b': 15, 'a': 93}, np.int64(108)),\n",
" ({'a': 71, 'b': 19}, np.int64(90)),\n",
" ({'a': 41, 'b': 95}, np.int64(136)),\n",
" ({'a': 80, 'b': 22}, np.int64(102)),\n",
" ({'a': 61, 'b': 51}, np.int64(112)),\n",
" ({'a': 41, 'b': 65}, np.int64(106)),\n",
" ({'a': 72, 'b': 20}, np.int64(92)),\n",
" ({'a': 55, 'b': 99}, np.int64(154)),\n",
" ({'a': 11, 'b': 29}, np.int64(40)),\n",
" ({'a': 98, 'b': 66}, np.int64(164)),\n",
" ({'a': 56, 'b': 32}, np.int64(88)),\n",
" ({'a': 40, 'b': 3}, np.int64(43)),\n",
" ({'a': 99, 'b': 99}, np.int64(198)),\n",
" ({'a': 53, 'b': 14}, np.int64(67)),\n",
" ({'a': 98, 'b': 99}, np.int64(197)),\n",
" ({'a': 99, 'b': 98}, np.int64(197)),\n",
" ({'a': 97, 'b': 99}, np.int64(196)),\n",
" ({'a': 99, 'b': 97}, np.int64(196)),\n",
" ({'a': 91, 'b': 16}, np.int64(107)),\n",
" ({'a': 89, 'b': 15}, np.int64(104)),\n",
" ({'a': 98, 'b': 98}, np.int64(196))]"
]
},
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -123,14 +123,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"id": "65029dad",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28039081a31f44ddbab7a84499bf5875",
"model_id": "5698243cdab745359c6bd57559ceb018",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -159,31 +159,31 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"id": "8115b7a3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([{'b': 99, 'a': 71}, {'b': 92, 'a': 35}, {'a': 40, 'b': 14},\n",
" {'a': 69, 'b': 97}, {'a': 66, 'b': 98}, {'a': 52, 'b': 21},\n",
" {'a': 38, 'b': 77}, {'a': 67, 'b': 4}, {'a': 99, 'b': 99},\n",
" {'a': 27, 'b': 7}, {'a': 99, 'b': 98}, {'a': 98, 'b': 99},\n",
" {'a': 15, 'b': 3}, {'a': 54, 'b': 24}, {'a': 51, 'b': 91},\n",
" {'a': 99, 'b': 97}, {'a': 97, 'b': 99}, {'a': 47, 'b': 87},\n",
" {'a': 85, 'b': 18}, {'a': 76, 'b': 17}, {'a': 98, 'b': 98},\n",
" {'a': 99, 'b': 96}, {'a': 46, 'b': 44}, {'a': 79, 'b': 68},\n",
" {'a': 96, 'b': 99}, {'a': 75, 'b': 30}, {'a': 80, 'b': 54},\n",
" {'a': 98, 'b': 97}, {'a': 10, 'b': 27}, {'a': 97, 'b': 98},\n",
" {'a': 99, 'b': 95}, {'a': 6, 'b': 91}, {'a': 9, 'b': 99},\n",
" {'a': 95, 'b': 99}, {'a': 97, 'b': 97}, {'a': 91, 'b': 34},\n",
" {'a': 19, 'b': 55}, {'a': 98, 'b': 96}, {'a': 96, 'b': 98},\n",
" {'a': 38, 'b': 97}, {'a': 22, 'b': 59}, {'a': 99, 'b': 94}],\n",
"array([{'b': 8, 'a': 85}, {'b': 15, 'a': 93}, {'a': 71, 'b': 19},\n",
" {'a': 41, 'b': 95}, {'a': 80, 'b': 22}, {'a': 61, 'b': 51},\n",
" {'a': 41, 'b': 65}, {'a': 72, 'b': 20}, {'a': 55, 'b': 99},\n",
" {'a': 11, 'b': 29}, {'a': 98, 'b': 66}, {'a': 56, 'b': 32},\n",
" {'a': 40, 'b': 3}, {'a': 99, 'b': 99}, {'a': 53, 'b': 14},\n",
" {'a': 98, 'b': 99}, {'a': 99, 'b': 98}, {'a': 97, 'b': 99},\n",
" {'a': 99, 'b': 97}, {'a': 91, 'b': 16}, {'a': 89, 'b': 15},\n",
" {'a': 98, 'b': 98}, {'a': 50, 'b': 76}, {'a': 80, 'b': 88},\n",
" {'a': 96, 'b': 99}, {'a': 17, 'b': 12}, {'a': 99, 'b': 96},\n",
" {'a': 39, 'b': 48}, {'a': 98, 'b': 97}, {'a': 97, 'b': 98},\n",
" {'a': 38, 'b': 1}, {'a': 95, 'b': 99}, {'a': 45, 'b': 77},\n",
" {'a': 99, 'b': 95}, {'a': 85, 'b': 9}, {'a': 22, 'b': 93},\n",
" {'a': 97, 'b': 97}, {'a': 96, 'b': 98}, {'a': 76, 'b': 57},\n",
" {'a': 98, 'b': 96}, {'a': 98, 'b': 13}, {'a': 94, 'b': 99}],\n",
" dtype=object)"
]
},
"execution_count": 9,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -217,7 +217,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 5edd533

Please # to comment.