Skip to content

Commit f91d773

Browse files
author
Sosnowsky
committed
Switch to enums for distributions instead of strings
1 parent 6683351 commit f91d773

8 files changed

+127
-86
lines changed

blobmodel/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .stochasticality import BlobFactory, DefaultBlobFactory
55
from .geometry import Geometry
66
from .blob_shape import AbstractBlobShape, BlobShapeImpl
7+
from .distributions import Distribution

blobmodel/distributions.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from enum import Enum
2+
from abc import ABC, abstractmethod
3+
import numpy as np
4+
5+
6+
class Distribution(Enum):
7+
"""Enum class used to identify distribution functions."""
8+
9+
deg = 1
10+
zeros = 2
11+
exp = 3
12+
gamma = 4
13+
normal = 5
14+
uniform = 6
15+
rayleigh = 7
16+
17+
18+
class AbstractDistribution(ABC):
19+
"""Abstract class used to represent and implement a distribution function."""
20+
21+
@abstractmethod
22+
def sample(
23+
self,
24+
num_blobs: int,
25+
**kwargs,
26+
) -> np.ndarray:
27+
raise NotImplementedError
28+
29+
30+
def _sample_deg(num_blobs, **kwargs):
31+
free_param = kwargs["free_param"]
32+
return free_param * np.ones(num_blobs).astype(np.float64)
33+
34+
35+
def _sample_zeros(num_blobs, **kwargs):
36+
return np.zeros(num_blobs).astype(np.float64)
37+
38+
39+
def _sample_exp(num_blobs, **kwargs):
40+
free_param = kwargs["free_param"]
41+
return np.random.exponential(scale=free_param, size=num_blobs).astype(np.float64)
42+
43+
44+
def _sample_gamma(num_blobs, **kwargs):
45+
free_param = kwargs["free_param"]
46+
return np.random.gamma(
47+
shape=free_param, scale=1 / free_param, size=num_blobs
48+
).astype(np.float64)
49+
50+
51+
def _sample_normal(num_blobs, **kwargs):
52+
free_param = kwargs["free_param"]
53+
return np.random.normal(loc=0, scale=free_param, size=num_blobs).astype(np.float64)
54+
55+
56+
def _sample_uniform(num_blobs, **kwargs):
57+
free_param = kwargs["free_param"]
58+
return np.random.uniform(
59+
low=1 - free_param / 2, high=1 + free_param / 2, size=num_blobs
60+
).astype(np.float64)
61+
62+
63+
def _sample_rayleigh(num_blobs, **kwargs):
64+
return np.random.rayleigh(scale=np.sqrt(2.0 / np.pi), size=num_blobs).astype(
65+
np.float64
66+
)
67+
68+
69+
DISTRIBUTIONS = {
70+
Distribution.deg: _sample_deg,
71+
Distribution.zeros: _sample_zeros,
72+
Distribution.exp: _sample_exp,
73+
Distribution.gamma: _sample_gamma,
74+
Distribution.normal: _sample_normal,
75+
Distribution.uniform: _sample_uniform,
76+
Distribution.rayleigh: _sample_rayleigh,
77+
}

blobmodel/stochasticality.py

+28-71
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, List, Union, Dict
77
from .blobs import Blob
88
from .blob_shape import AbstractBlobShape
9+
from .distributions import *
910

1011

1112
class BlobFactory(ABC):
@@ -40,13 +41,13 @@ class DefaultBlobFactory(BlobFactory):
4041

4142
def __init__(
4243
self,
43-
A_dist: str = "exp",
44-
wx_dist: str = "deg",
45-
wy_dist: str = "deg",
46-
vx_dist: str = "deg",
47-
vy_dist: str = "deg",
48-
spx_dist: str = "deg",
49-
spy_dist: str = "deg",
44+
A_dist: Distribution = Distribution.exp,
45+
wx_dist: Distribution = Distribution.deg,
46+
wy_dist: Distribution = Distribution.deg,
47+
vx_dist: Distribution = Distribution.deg,
48+
vy_dist: Distribution = Distribution.deg,
49+
spx_dist: Distribution = Distribution.deg,
50+
spy_dist: Distribution = Distribution.deg,
5051
A_parameter: float = 1.0,
5152
wx_parameter: float = 1.0,
5253
wy_parameter: float = 1.0,
@@ -64,20 +65,20 @@ def __init__(
6465
6566
Parameters
6667
----------
67-
A_dist : str, optional
68-
Distribution type for amplitude, by default "exp"
69-
wx_dist : str, optional
70-
Distribution type for width in the x-direction, by default "deg"
71-
wy_dist : str, optional
72-
Distribution type for width in the y-direction, by default "deg"
73-
vx_dist : str, optional
74-
Distribution type for velocity in the x-direction, by default "deg"
75-
vy_dist : str, optional
76-
Distribution type for velocity in the y-direction, by default "deg"
77-
spx_dist : str, optional
78-
Distribution type for shape parameter in the x-direction, by default "deg"
79-
spy_dist : str, optional
80-
Distribution type for shape parameter in the y-direction, by default "deg"
68+
A_dist : Distribution, optional
69+
Distribution type for amplitude, by default "Distribution.exp"
70+
wx_dist : Distribution, optional
71+
Distribution type for width in the x-direction, by default "Distribution.deg"
72+
wy_dist : Distribution, optional
73+
Distribution type for width in the y-direction, by default "Distribution.deg"
74+
vx_dist : Distribution, optional
75+
Distribution type for velocity in the x-direction, by default "Distribution.deg"
76+
vy_dist : Distribution, optional
77+
Distribution type for velocity in the y-direction, by default "Distribution.deg"
78+
spx_dist : Distribution, optional
79+
Distribution type for shape parameter in the x-direction, by default "Distribution.deg"
80+
spy_dist : Distribution, optional
81+
Distribution type for shape parameter in the y-direction, by default "Distribution.deg"
8182
A_parameter : float, optional
8283
Free parameter for the amplitude distribution, by default 1.0
8384
wx_parameter : float, optional
@@ -126,54 +127,10 @@ def __init__(
126127
self.theta_setter = lambda: 0
127128

128129
def _draw_random_variables(
129-
self,
130-
dist_type: str,
131-
free_parameter: float,
132-
num_blobs: int,
130+
self, dist: Distribution, free_parameter: float, num_blobs: int
133131
) -> np.ndarray:
134-
"""
135-
Draws random variables from a specified distribution.
136-
137-
Parameters
138-
----------
139-
dist_type : str
140-
Type of distribution.
141-
free_parameter : float
142-
Free parameter for the distribution.
143-
num_blobs : int
144-
Number of random variables to draw.
145-
146-
Returns
147-
-------
148-
NDArray[Any, Float[64]]
149-
Array of random variables drawn from the specified distribution.
150-
"""
151-
if dist_type == "exp":
152-
return np.random.exponential(scale=1, size=num_blobs).astype(np.float64)
153-
elif dist_type == "gamma":
154-
return np.random.gamma(
155-
shape=free_parameter, scale=1 / free_parameter, size=num_blobs
156-
).astype(np.float64)
157-
elif dist_type == "normal":
158-
return np.random.normal(loc=0, scale=free_parameter, size=num_blobs).astype(
159-
np.float64
160-
)
161-
elif dist_type == "uniform":
162-
return np.random.uniform(
163-
low=1 - free_parameter / 2, high=1 + free_parameter / 2, size=num_blobs
164-
).astype(np.float64)
165-
elif dist_type == "ray":
166-
return np.random.rayleigh(
167-
scale=np.sqrt(2.0 / np.pi), size=num_blobs
168-
).astype(np.float64)
169-
elif dist_type == "deg":
170-
return free_parameter * np.ones(num_blobs).astype(np.float64)
171-
elif dist_type == "zeros":
172-
return np.zeros(num_blobs).astype(np.float64)
173-
else:
174-
raise NotImplementedError(
175-
self.__class__.__name__ + ".distribution function not implemented"
176-
)
132+
"""Draws random variables from a specified distribution."""
133+
return DISTRIBUTIONS[dist](num_blobs, free_param=free_parameter)
177134

178135
def sample_blobs(
179136
self,
@@ -205,9 +162,9 @@ def sample_blobs(
205162
List of Blob objects generated for the Model.
206163
"""
207164
amps = self._draw_random_variables(
208-
dist_type=self.amplitude_dist,
209-
free_parameter=self.amplitude_parameter,
210-
num_blobs=num_blobs,
165+
self.amplitude_dist,
166+
self.amplitude_parameter,
167+
num_blobs,
211168
)
212169
wxs = self._draw_random_variables(
213170
self.width_x_dist, self.width_x_parameter, num_blobs

tests/test_analytical.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from blobmodel import Model, DefaultBlobFactory
1+
from blobmodel import Model, DefaultBlobFactory, Distribution
22
import xarray as xr
33
import numpy as np
44

55

66
# use DefaultBlobFactory to define distribution functions fo random variables
7-
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
7+
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)
88

99
tmp = Model(
1010
Nx=100,

tests/test_blob.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from blobmodel import Model, DefaultBlobFactory, Blob, BlobShapeImpl
1+
from blobmodel import Model, DefaultBlobFactory, Blob, BlobShapeImpl, Distribution
22
import numpy as np
33
from unittest.mock import MagicMock
44

@@ -204,7 +204,7 @@ def test_kwargs():
204204

205205

206206
def test_get_blobs():
207-
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="deg")
207+
bf = DefaultBlobFactory(A_dist=Distribution.deg)
208208
one_blob = Model(
209209
Nx=100,
210210
Ny=100,
@@ -218,6 +218,6 @@ def test_get_blobs():
218218
num_blobs=3,
219219
blob_factory=bf,
220220
)
221-
ds = one_blob.make_realization()
221+
one_blob.make_realization()
222222
blob_list = one_blob.get_blobs()
223223
assert len(blob_list) == 3

tests/test_changing_t_drain.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from blobmodel import Model, DefaultBlobFactory
1+
from blobmodel import Model, DefaultBlobFactory, Distribution
22
import xarray as xr
33
import numpy as np
44

55

66
# use DefaultBlobFactory to define distribution functions fo random variables
7-
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
7+
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)
88

99
t_drain = np.linspace(2, 1, 10)
1010

tests/test_one_dim.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pytest
2-
from blobmodel import Model, DefaultBlobFactory
2+
from blobmodel import Model, DefaultBlobFactory, Distribution
33
import numpy as np
44

55

66
# use DefaultBlobFactory to define distribution functions fo random variables
7-
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
7+
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)
88

99
one_dim_model = Model(
1010
Nx=100,

tests/test_stochasticality.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
11
import pytest
2-
from blobmodel import DefaultBlobFactory, BlobShapeImpl, BlobFactory
2+
from blobmodel import DefaultBlobFactory, BlobShapeImpl, BlobFactory, Distribution
33

44

55
def test_mean_of_distribution():
66
bf = DefaultBlobFactory()
7-
distributions_mean_1 = ["exp", "gamma", "uniform", "ray", "deg"]
8-
distributions_mean_0 = ["normal", "zeros"]
7+
distributions_mean_1 = [
8+
Distribution.exp,
9+
Distribution.gamma,
10+
Distribution.uniform,
11+
Distribution.rayleigh,
12+
Distribution.deg,
13+
]
14+
distributions_mean_0 = [Distribution.normal, Distribution.zeros]
915

1016
for dist in distributions_mean_1:
1117
tmp = bf._draw_random_variables(
12-
dist_type=dist,
18+
dist=dist,
1319
free_parameter=1,
1420
num_blobs=10000,
1521
)
1622
assert 0.95 <= tmp.mean() <= 1.05
1723

1824
for dist in distributions_mean_0:
1925
tmp = bf._draw_random_variables(
20-
dist_type=dist,
26+
dist=dist,
2127
free_parameter=1,
2228
num_blobs=10000,
2329
)
2430
assert -0.05 <= tmp.mean() <= 0.05
2531

2632

2733
def test_not_implemented_distribution():
28-
with pytest.raises(NotImplementedError):
34+
with pytest.raises(KeyError):
2935
bf = DefaultBlobFactory(A_dist="something_different")
3036
bf.sample_blobs(1, 1, 1, BlobShapeImpl("gauss"), 1)
3137

0 commit comments

Comments
 (0)