6
6
from typing import Any , List , Union , Dict
7
7
from .blobs import Blob
8
8
from .blob_shape import AbstractBlobShape
9
+ from .distributions import *
9
10
10
11
11
12
class BlobFactory (ABC ):
@@ -40,13 +41,13 @@ class DefaultBlobFactory(BlobFactory):
40
41
41
42
def __init__ (
42
43
self ,
43
- A_dist : str = " exp" ,
44
- wx_dist : str = " deg" ,
45
- wy_dist : str = " deg" ,
46
- vx_dist : str = " deg" ,
47
- vy_dist : str = " deg" ,
48
- spx_dist : str = " deg" ,
49
- spy_dist : str = " deg" ,
44
+ A_dist : Distribution = Distribution . exp ,
45
+ wx_dist : Distribution = Distribution . deg ,
46
+ wy_dist : Distribution = Distribution . deg ,
47
+ vx_dist : Distribution = Distribution . deg ,
48
+ vy_dist : Distribution = Distribution . deg ,
49
+ spx_dist : Distribution = Distribution . deg ,
50
+ spy_dist : Distribution = Distribution . deg ,
50
51
A_parameter : float = 1.0 ,
51
52
wx_parameter : float = 1.0 ,
52
53
wy_parameter : float = 1.0 ,
@@ -64,20 +65,20 @@ def __init__(
64
65
65
66
Parameters
66
67
----------
67
- A_dist : str , optional
68
- Distribution type for amplitude, by default "exp"
69
- wx_dist : str , optional
70
- Distribution type for width in the x-direction, by default "deg"
71
- wy_dist : str , optional
72
- Distribution type for width in the y-direction, by default "deg"
73
- vx_dist : str , optional
74
- Distribution type for velocity in the x-direction, by default "deg"
75
- vy_dist : str , optional
76
- Distribution type for velocity in the y-direction, by default "deg"
77
- spx_dist : str , optional
78
- Distribution type for shape parameter in the x-direction, by default "deg"
79
- spy_dist : str , optional
80
- Distribution type for shape parameter in the y-direction, by default "deg"
68
+ A_dist : Distribution , optional
69
+ Distribution type for amplitude, by default "Distribution. exp"
70
+ wx_dist : Distribution , optional
71
+ Distribution type for width in the x-direction, by default "Distribution. deg"
72
+ wy_dist : Distribution , optional
73
+ Distribution type for width in the y-direction, by default "Distribution. deg"
74
+ vx_dist : Distribution , optional
75
+ Distribution type for velocity in the x-direction, by default "Distribution. deg"
76
+ vy_dist : Distribution , optional
77
+ Distribution type for velocity in the y-direction, by default "Distribution. deg"
78
+ spx_dist : Distribution , optional
79
+ Distribution type for shape parameter in the x-direction, by default "Distribution. deg"
80
+ spy_dist : Distribution , optional
81
+ Distribution type for shape parameter in the y-direction, by default "Distribution. deg"
81
82
A_parameter : float, optional
82
83
Free parameter for the amplitude distribution, by default 1.0
83
84
wx_parameter : float, optional
@@ -126,54 +127,10 @@ def __init__(
126
127
self .theta_setter = lambda : 0
127
128
128
129
def _draw_random_variables (
129
- self ,
130
- dist_type : str ,
131
- free_parameter : float ,
132
- num_blobs : int ,
130
+ self , dist : Distribution , free_parameter : float , num_blobs : int
133
131
) -> np .ndarray :
134
- """
135
- Draws random variables from a specified distribution.
136
-
137
- Parameters
138
- ----------
139
- dist_type : str
140
- Type of distribution.
141
- free_parameter : float
142
- Free parameter for the distribution.
143
- num_blobs : int
144
- Number of random variables to draw.
145
-
146
- Returns
147
- -------
148
- NDArray[Any, Float[64]]
149
- Array of random variables drawn from the specified distribution.
150
- """
151
- if dist_type == "exp" :
152
- return np .random .exponential (scale = 1 , size = num_blobs ).astype (np .float64 )
153
- elif dist_type == "gamma" :
154
- return np .random .gamma (
155
- shape = free_parameter , scale = 1 / free_parameter , size = num_blobs
156
- ).astype (np .float64 )
157
- elif dist_type == "normal" :
158
- return np .random .normal (loc = 0 , scale = free_parameter , size = num_blobs ).astype (
159
- np .float64
160
- )
161
- elif dist_type == "uniform" :
162
- return np .random .uniform (
163
- low = 1 - free_parameter / 2 , high = 1 + free_parameter / 2 , size = num_blobs
164
- ).astype (np .float64 )
165
- elif dist_type == "ray" :
166
- return np .random .rayleigh (
167
- scale = np .sqrt (2.0 / np .pi ), size = num_blobs
168
- ).astype (np .float64 )
169
- elif dist_type == "deg" :
170
- return free_parameter * np .ones (num_blobs ).astype (np .float64 )
171
- elif dist_type == "zeros" :
172
- return np .zeros (num_blobs ).astype (np .float64 )
173
- else :
174
- raise NotImplementedError (
175
- self .__class__ .__name__ + ".distribution function not implemented"
176
- )
132
+ """Draws random variables from a specified distribution."""
133
+ return DISTRIBUTIONS [dist ](num_blobs , free_param = free_parameter )
177
134
178
135
def sample_blobs (
179
136
self ,
@@ -205,9 +162,9 @@ def sample_blobs(
205
162
List of Blob objects generated for the Model.
206
163
"""
207
164
amps = self ._draw_random_variables (
208
- dist_type = self .amplitude_dist ,
209
- free_parameter = self .amplitude_parameter ,
210
- num_blobs = num_blobs ,
165
+ self .amplitude_dist ,
166
+ self .amplitude_parameter ,
167
+ num_blobs ,
211
168
)
212
169
wxs = self ._draw_random_variables (
213
170
self .width_x_dist , self .width_x_parameter , num_blobs
0 commit comments