Skip to content

Commit

Permalink
Cleaning. Resolving TODOs.
Browse files Browse the repository at this point in the history
  • Loading branch information
yonghakim committed Aug 8, 2024
1 parent b20325e commit b60e942
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 829 deletions.
3 changes: 1 addition & 2 deletions meent/on_numpy/emsolver/transfer_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,11 @@ def transfer_1d_4(pol, F, G, T, kz_top, kz_bot, theta, n_top, n_bot, type_comple
delta_i0[ff_xy // 2] = 1

if pol == 0: # TE
# TODO: check sign of H
inc_term = 1j * n_top * np.cos(theta) * delta_i0
T1 = np.linalg.inv(G + 1j * Kz_top @ F) @ (1j * Kz_top @ delta_i0 + inc_term)

elif pol == 1: # TM
inc_term = 1j * delta_i0 * np.cos(theta) / n_top # tODO: inc term?
inc_term = 1j * delta_i0 * np.cos(theta) / n_top
T1 = np.linalg.inv(G + 1j * Kz_top / (n_top ** 2) @ F) @ (1j * Kz_top / (n_top ** 2) @ delta_i0 + inc_term)

# T1 = np.linalg.inv(G + 1j * YZ_I @ F) @ (1j * YZ_I @ delta_i0 + inc_term)
Expand Down
191 changes: 13 additions & 178 deletions meent/on_torch/emsolver/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def __init__(self, n_top=1., n_bot=1., theta=0., phi=0., psi=None, pol=0., fto=(
self.layer_info_list = []
self.T1 = None

self.rayleigh_r = None # TODO
self.rayleigh_t = None
self.rayleigh_R = None # TODO: other bds
self.rayleigh_T = None

@property
def device(self):
Expand All @@ -72,16 +72,6 @@ def device(self, device):
else:
raise ValueError

# TODO: need this?
# try:
# self._theta = self._theta.to(self.device)
# self._phi = self._phi.to(self.device)
# self._psi = self._psi.to(self.device)
# self.thickness = self._thickness.to(self.device)
# except AssertionError as e:
# print(f'{e}. Get back to CPU')
# self._device = torch.device('cpu')

@property
def type_complex(self):
return self._type_complex
Expand All @@ -95,16 +85,6 @@ def type_complex(self, type_complex):
else:
raise ValueError('type_complex')

# TODO: need this?
# self._type_float = torch.float64 if self.type_complex is not torch.complex64 else torch.float32
# self._type_int = torch.int64 if self.type_complex is not torch.complex64 else torch.int32
# self._theta = self._theta.to(self.type_float)
# self._phi = self._phi.to(self.type_float)
# self._psi = self._psi.to(self.type_float)

# self.fto = self._fto
# self.thickness = self._thickness

@property
def type_float(self):
return self._type_float
Expand Down Expand Up @@ -229,22 +209,6 @@ def thickness(self, thickness):
else:
raise ValueError

# def get_kx_vector(self, wavelength):
#
# k0 = 2 * torch.pi / wavelength
# fourier_indices_x = torch.arange(-self.fto[0], self.fto[0] + 1, device=self.device,
# dtype=self.type_float)
# if self.grating_type == 0:
# kx = k0 * (self.n_top * torch.sin(self.theta) + fourier_indices_x * (wavelength / self.period[0])
# ).type(self.type_complex)
# else:
# kx = k0 * (self.n_top * torch.sin(self.theta) * torch.cos(self.phi) + fourier_indices_x * (
# wavelength / self.period[0])).type(self.type_complex)
#
# # kx = torch.where(kx == 0, self.perturbation, kx)
#
# return kx

def get_kx_ky_vector(self, wavelength):

fto_x_range = torch.arange(-self.fto[0], self.fto[0] + 1, device=self.device,
Expand All @@ -264,7 +228,7 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):

self.layer_info_list = []
self.T1 = None
self.rayleigh_r, self.rayleigh_t = [], [] # tODO
self.rayleigh_R, self.rayleigh_T = [], []

ff_x = self.fto[0] * 2 + 1

Expand All @@ -275,10 +239,6 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
kz_top, kz_bot, F, G, T \
= transfer_1d_1(self.pol, ff_x, kx, self.n_top, self.n_bot, device=self.device, type_complex=self.type_complex)

# kx, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \
# = transfer_1d_1(ff_x, self.pol, k0, self.n_top, self.n_bot, self.kx,
# self.theta, delta_i0, self.fto,
# device=self.device, type_complex=self.type_complex)
elif self.connecting_algo == 'SMM':
Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \
= scattering_1d_1(k0, self.n_top, self.n_bot, self.theta, self.phi, fourier_indices, self.period,
Expand All @@ -295,29 +255,6 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):

d = self.thickness[layer_index]

# if self.pol == 0:
# E_conv_i = None
# A = Kx ** 2 - E_conv
# Eig.perturbation = self.perturbation
# eigenvalues, W = Eig.apply(A)
# q = eigenvalues ** 0.5
# Q = torch.diag(q)
# V = W @ Q
#
# elif self.pol == 1:
# E_conv_i = torch.linalg.inv(E_conv)
# B = Kx @ E_conv_i @ Kx - torch.eye(E_conv.shape[0], device=self.device, dtype=self.type_complex)
# # o_E_conv_i = torch.linalg.inv(o_E_conv)
#
# Eig.perturbation = self.perturbation
# eigenvalues, W = Eig.apply(E_conv @ B)
# q = eigenvalues ** 0.5
# Q = torch.diag(q)
# # V = o_E_conv @ W @ Q
# V = E_conv_i @ W @ Q
#
# else:
# raise ValueError
if self.connecting_algo == 'TMM':
W, V, q = transfer_1d_2(self.pol, kx, epx_conv, epy_conv, epz_conv_i, device=self.device, type_complex=self.type_complex)

Expand All @@ -331,116 +268,26 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
else:
raise ValueError

# if self.algo == 'TMM':
# X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fto, T,
# device=self.device, type_complex=self.type_complex)
#
# layer_info = [E_conv_i, q, W, X, a_i, b, d]
# self.layer_info_list.append(layer_info)
#
# elif self.algo == 'SMM':
# A, B, S_dict, Sg = scattering_1d_2(W, Wg, V, Vg, d, k0, Q, Sg)
# else:
# raise ValueError

if self.connecting_algo == 'TMM':
de_ri, de_ti, T1 = transfer_1d_4(self.pol, F, G, T, kz_top, kz_bot, self.theta, self.n_top, self.n_bot,
de_ri, de_ti, T1, [R], [T] = transfer_1d_4(self.pol, F, G, T, kz_top, kz_bot, self.theta, self.n_top, self.n_bot,
device=self.device, type_complex=self.type_complex)
self.T1 = T1
self.rayleigh_R = [R]
self.rayleigh_T = [T]

elif self.connecting_algo == 'SMM':
de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fto, Kzr, Kzt,
self.n_top, self.n_bot, self.theta, self.pol)
else:
raise ValueError

return de_ri, de_ti, self.rayleigh_r, self.rayleigh_t, self.layer_info_list, self.T1

def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
"""
Deprecated.
Args:
wavelength:
E_conv_all:
o_E_conv_all:
Returns:
"""
self.layer_info_list = []
self.T1 = None
self.rayleigh_r, self.rayleigh_t = [], []

# fourier_indices = torch.arange(-self.fto, self.fto + 1, device=self.device)
ff = self.fto[0] * 2 + 1

delta_i0 = torch.zeros(ff, device=self.device, dtype=self.type_complex)
delta_i0[self.fto[0]] = 1

k0 = 2 * torch.pi / wavelength

if self.algo == 'TMM':
Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \
= transfer_1d_conical_1(ff, k0, self.n_top, self.n_bot, self.kx_vector, self.theta, self.phi,
device=self.device, type_complex=self.type_complex)
elif self.algo == 'SMM':
print('SMM for 1D conical is not implemented')
return torch.nan, torch.nan
else:
raise ValueError

count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness))

# From the last layer
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = torch.linalg.inv(E_conv)
# o_E_conv_i = torch.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
= transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d,
varphi, big_F, big_G, big_T,
device=self.device, type_complex=self.type_complex)

layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d]
self.layer_info_list.append(layer_info)

elif self.algo == 'SMM':
raise ValueError
else:
raise ValueError

if self.algo == 'TMM':
de_ri, de_ti, big_T1, self.rayleigh_r, self.rayleigh_t = transfer_1d_conical_3(big_F, big_G, big_T, Z_I,
Y_I, self.psi, self.theta,
ff,
delta_i0, k_I_z, k0,
self.n_top, self.n_bot,
k_II_z,
device=self.device,
type_complex=self.type_complex)
self.T1 = big_T1

elif self.algo == 'SMM':
raise ValueError
else:
raise ValueError

return de_ri, de_ti, self.rayleigh_r, self.rayleigh_t, self.layer_info_list, self.T1
return de_ri, de_ti, self.rayleigh_R, self.rayleigh_T, self.layer_info_list, self.T1

def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):

self.layer_info_list = []
self.T1 = None
self.rayleigh_r, self.rayleigh_t = [], []
self.rayleigh_R, self.rayleigh_T = [], []

ff_x = self.fto[0] * 2 + 1
ff_y = self.fto[1] * 2 + 1
Expand Down Expand Up @@ -472,20 +319,10 @@ def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
d = self.thickness[layer_index]

if self.connecting_algo == 'TMM':
# W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv,
# device=self.device, type_complex=self.type_complex)

W, V, q = transfer_2d_2(kx, ky, epx_conv, epy_conv, epz_conv_i, device=self.device,
type_complex=self.type_complex)

# big_X, big_F, big_G, big_T, big_A_i, big_B, \
# W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \
# = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, device=self.device,
# type_complex=self.type_complex)
#
# layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d]
# self.layer_info_list.append(layer_info)

big_X, big_F, big_G, big_T, big_A_i, big_B, \
= transfer_2d_3(k0, W, V, q, d, varphi, big_F, big_G, big_T, device=self.device,
type_complex=self.type_complex)
Expand All @@ -499,15 +336,13 @@ def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
else:
raise ValueError

if self.connecting_algo == 'TMM': # TODO: cleaning
# de_ri, de_ti, big_T1, self.rayleigh_r, self.rayleigh_t = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy,
# delta_i0, k_I_z, k0, self.n_top, self.n_bot, k_II_z, device=self.device,
# type_complex=self.type_complex)
# TODO: AA and BB
de_ri, de_ti, big_T1, AA, BB = transfer_2d_4(big_F, big_G, big_T, kz_top, kz_bot, self.psi, self.theta,
if self.connecting_algo == 'TMM':
de_ri, de_ti, big_T1, [R_s, R_p], [T_s, T_p], = transfer_2d_4(big_F, big_G, big_T, kz_top, kz_bot, self.psi, self.theta,
self.n_top, self.n_bot, device=self.device,
type_complex=self.type_complex)
self.T1 = big_T1
self.rayleigh_R = [R_s, R_p]
self.rayleigh_T = [T_s, T_p]

elif self.connecting_algo == 'SMM':
de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_top,
Expand All @@ -518,4 +353,4 @@ def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
de_ri = de_ri.reshape((ff_y, ff_x)).T
de_ti = de_ti.reshape((ff_y, ff_x)).T

return de_ri, de_ti, self.rayleigh_r, self.rayleigh_t, self.layer_info_list, self.T1
return de_ri, de_ti, self.rayleigh_R, self.rayleigh_T, self.layer_info_list, self.T1
Loading

0 comments on commit b60e942

Please # to comment.