From 856f237bd2bbebb44a3454729a011fe8b40731e2 Mon Sep 17 00:00:00 2001 From: rahul-flex Date: Fri, 4 Apr 2025 14:52:57 -0400 Subject: [PATCH] feat: autograd support for rotated Box --- CHANGELOG.md | 1 + tests/test_components/test_autograd.py | 120 ++++++++ .../test_box_compute_derivatives.py | 246 ++++++++++++++++ tidy3d/components/geometry/base.py | 269 ++++++++++++------ tidy3d/web/api/autograd/autograd.py | 46 ++- 5 files changed, 566 insertions(+), 116 deletions(-) create mode 100644 tests/test_components/test_box_compute_derivatives.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 86f7ec540..ac6755e2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - A property `interior_angle` in `PolySlab` that stores angles formed inside polygon by two adjacent edges. - `eps_component` argument in `td.Simulation.plot_eps()` to optionally select a specific permittivity component to plot (eg. `"xx"`). - Monitor `AuxFieldTimeMonitor` for aux fields like the free carrier density in `TwoPhotonAbsorption`. +- Gradient computation for rotated boxes in Transformed. ### Fixed - Compatibility with `xarray>=2025.03`. diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index d7268e8aa..54de293fb 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -2150,3 +2150,123 @@ def objective(args): # model is called without a frequency with AssertLogLevel("INFO"): grad = ag.grad(objective)(params0) + + +def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int): + wavelength = 1.5 + L = 10 * wavelength + freq0 = td.C_0 / wavelength + buffer = 1.0 * wavelength + + # Source + src = td.PointDipole( + center=(-L / 2 + buffer, 0, 0), + source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0), + polarization="Ez", + ) + # Monitor + mnt = td.FieldMonitor( + center=( + +L / 2 - buffer, + 0.5 * buffer, + 0.5 * buffer, + ), + size=(0.0, 0.0, 0.0), + freqs=[freq0], + name="point", + ) + # The box geometry + base_box = td.Box(center=center, size=size) + if angle is not None: + base_box = base_box.rotated(angle, axis) + + scatterer = td.Structure( + geometry=base_box, + medium=td.Medium(permittivity=2.0), + ) + + sim = td.Simulation( + size=(L, L, L), + grid_spec=td.GridSpec.auto(min_steps_per_wvl=50), + structures=[scatterer], + sources=[src], + monitors=[mnt], + run_time=120 / freq0, + ) + return sim + + +def objective_fn(center, size, angle, axis): + sim = make_sim_rotation(center, size, angle, axis) + sim_data = web.run(sim, task_name="emulated_rot_test", local_gradient=True, verbose=False) + return anp.sum(sim_data.get_intensity("point").values) + + +def get_grad(center, size, angle, axis): + def wrapped(c, s): + return objective_fn(c, s, angle, axis) + + val, (grad_c, grad_s) = ag.value_and_grad(wrapped, argnum=(0, 1))(center, size) + return val, grad_c, grad_s + + +@pytest.mark.numerical +@pytest.mark.parametrize( + "angle_deg, axis", + [ + (0.0, 1), + (180.0, 1), + (90.0, 1), + (270.0, 1), + ], +) +def test_box_rotation_gradients(use_emulated_run, angle_deg, axis): + center0 = (0.0, 0.0, 0.0) + size0 = (2.0, 2.0, 2.0) + + angle_rad = np.deg2rad(angle_deg) + val, grad_c, grad_s = get_grad(center0, size0, angle=None, axis=None) + npx, npy, npz = grad_c + sSx, sSy, sSz = grad_s + + assert not np.allclose(grad_c, 0.0), "center gradient is all zero." + assert not np.allclose(grad_s, 0.0), "size gradient is all zero." + + if angle_deg == 180.0: + # rotating 180° about y => (x,z) become negated, y stays same + _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis) + rSx, rSy, rSz = grad_s_ref + rx, ry, rz = grad_c_ref + + assert np.allclose(grad_c[0], -grad_c_ref[0], atol=1e-6), "center_x sign mismatch" + assert np.allclose(grad_c[1], grad_c_ref[1], atol=1e-6), "center_y mismatch" + assert np.allclose(grad_c[2], -grad_c_ref[2], atol=1e-6), "center_z sign mismatch" + assert np.allclose(grad_s, grad_s_ref, atol=1e-6), "size grads changed unexpectedly" + + elif angle_deg == 90.0: + # rotating 90° about y => new x= old z, new z=- old x, y stays same + _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis) + rSx, rSy, rSz = grad_s_ref + rx, ry, rz = grad_c_ref + + assert np.allclose(npx, rz, atol=1e-6), "center_x != old center_z" + assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly" + assert np.allclose(npz, -rx, atol=1e-6), "center_z != - old center_x" + + assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z" + assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly" + assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x" + + elif angle_deg == 270.0: + # rotating 270° about y => new x= - old z, new z= old x, y stays same + _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis) + rSx, rSy, rSz = grad_s_ref + rx, ry, rz = grad_c_ref + + assert np.allclose(npx, -rz, atol=1e-6), "center_x != - old center_z" + assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly" + assert np.allclose(npz, rx, atol=1e-6), "center_z != old center_x" + + assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z" + assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly" + assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x" diff --git a/tests/test_components/test_box_compute_derivatives.py b/tests/test_components/test_box_compute_derivatives.py new file mode 100644 index 000000000..4060d212f --- /dev/null +++ b/tests/test_components/test_box_compute_derivatives.py @@ -0,0 +1,246 @@ +import atexit +import os +from collections import defaultdict + +import autograd +import autograd.numpy as anp +import matplotlib.pyplot as plt +import numpy as np +import pytest +import tidy3d as td +import tidy3d.web as web + +SAVE_RESULTS = False +PLOT_RESULTS = False +RESULTS_DIR = "./fd_ad_results" +results_collector = defaultdict(list) + +wavelength = 1.5 +freq0 = td.C_0 / wavelength +L = 10 * wavelength +buffer = 1.0 * wavelength +run_time = 120 / freq0 + + +SCENARIOS = [ + { + "name": "(1) normal", + "has_background": False, + "background_eps": 3.0, + "box_eps": 2.0, + "rotation_deg": None, + "rotation_axis": None, + }, + { + "name": "(2) perm=1.5", + "has_background": True, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": None, + "rotation_axis": None, + }, + { + "name": "(3) rotation=0 deg about z", + "has_background": False, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": 0.0, + "rotation_axis": 2, + }, + { + "name": "(4) rotation=90 deg about z", + "has_background": False, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": 90.0, + "rotation_axis": 2, + }, + { + "name": "(5) rotation=45 deg about y", + "has_background": False, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": 45.0, + "rotation_axis": 1, + }, + { + "name": "(6) rotation=45 deg about x", + "has_background": False, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": 45.0, + "rotation_axis": 0, + }, + { + "name": "(7) rotation=45 deg about z", + "has_background": False, + "background_eps": 1.5, + "box_eps": 2.0, + "rotation_deg": 45.0, + "rotation_axis": 2, + }, +] + +PARAM_LABELS = ["center_x", "center_x", "center_y", "center_z", "size_x", "size_y", "size_z"] + + +def make_sim(center: tuple, size: tuple, scenario: dict): + source = td.PointDipole( + center=(-L / 2 + buffer, 0.0, 0.0), + source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0), + polarization="Ez", + ) + + monitor = td.FieldMonitor( + center=(+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer), + size=(0, 0, 0), + freqs=[freq0], + name="point_out", + ) + + structures = [] + if scenario["has_background"]: + back_box = td.Box(center=(0.0, 0.0, 0.0), size=(4.0, 1.6, 1.6)) + background_box = td.Structure( + geometry=back_box, + medium=td.Medium(permittivity=scenario["background_eps"]), + ) + structures.append(background_box) + + scatter_box = td.Box(center=center, size=size) + + if scenario["rotation_deg"] is not None: + angle_rad = np.deg2rad(scenario["rotation_deg"]) + rotated_geom = scatter_box.rotated(angle_rad, scenario["rotation_axis"]) + else: + rotated_geom = scatter_box + + scatter_struct = td.Structure( + geometry=rotated_geom, + medium=td.Medium(permittivity=scenario["box_eps"]), + ) + structures.append(scatter_struct) + + sim = td.Simulation( + size=(L, L, L), + run_time=run_time, + grid_spec=td.GridSpec.auto(min_steps_per_wvl=50), + sources=[source], + monitors=[monitor], + structures=structures, + ) + return sim + + +def objective_fn(center, size, scenario): + sim = make_sim(center, size, scenario) + sim_data = web.run(sim, task_name="autograd_vs_fd_scenario", local_gradient=True, verbose=False) + return anp.sum(sim_data.get_intensity("point_out").values) + + +def fd_vs_ad_param(center, size, scenario, param_label, delta=1e-3): + val_and_grad_fn = autograd.value_and_grad( + lambda c, s: objective_fn(c, s, scenario), argnum=(0, 1) + ) + _, (grad_center, grad_size) = val_and_grad_fn(center, size) + + param_map = { + "center_x": (0, "center"), + "center_y": (1, "center"), + "center_z": (2, "center"), + "size_x": (0, "size"), + "size_y": (1, "size"), + "size_z": (2, "size"), + } + idx, which = param_map[param_label] + if which == "center": + ad_val = grad_center[idx] + else: + ad_val = grad_size[idx] + + center_arr = np.array(center, dtype=float) + size_arr = np.array(size, dtype=float) + + if which == "center": + cplus = center_arr.copy() + cminus = center_arr.copy() + cplus[idx] += delta + cminus[idx] -= delta + p_plus = objective_fn(tuple(cplus), tuple(size_arr), scenario) + p_minus = objective_fn(tuple(cminus), tuple(size_arr), scenario) + else: + splus = size_arr.copy() + sminus = size_arr.copy() + splus[idx] += delta + sminus[idx] -= delta + p_plus = objective_fn(tuple(center_arr), tuple(splus), scenario) + p_minus = objective_fn(tuple(center_arr), tuple(sminus), scenario) + + fd_val = (p_plus - p_minus) / (2.0 * delta) + return fd_val, ad_val, p_plus, p_minus + + +@pytest.mark.numerical +@pytest.mark.parametrize("scenario", SCENARIOS, ids=[s["name"] for s in SCENARIOS]) +@pytest.mark.parametrize( + "param_label", ["center_x", "center_y", "center_z", "size_x", "size_y", "size_z"] +) +def test_autograd_vs_fd_scenarios(scenario, param_label): + center0 = (0.0, 0.0, 0.0) + size0 = (2.0, 2.0, 2.0) + delta = 0.03 + + fd_val, ad_val, p_plus, p_minus = fd_vs_ad_param(center0, size0, scenario, param_label, delta) + + assert np.isfinite(fd_val), f"FD derivative is not finite for param={param_label}" + assert np.isfinite(ad_val), f"AD derivative is not finite for param={param_label}" + + denom = max(abs(fd_val), 1e-12) + rel_diff = abs(fd_val - ad_val) / denom + assert rel_diff < 0.3, f"Autograd vs FD mismatch: param={param_label}, diff={rel_diff:.1%}" + + results_collector[param_label].append((scenario["name"], rel_diff)) + + if SAVE_RESULTS: + os.makedirs(RESULTS_DIR, exist_ok=True) + results_data = { + "scenario_name": scenario["name"], + "param_label": param_label, + "delta": float(delta), + "fd_val": float(fd_val), + "ad_val": float(ad_val), + "p_plus": float(p_plus), + "p_minus": float(p_minus), + "rel_diff": float(rel_diff), + } + filename_npy = f"fd_ad_{scenario['name'].replace(' ', '_')}_{param_label}.npy" + np.save(os.path.join(RESULTS_DIR, filename_npy), results_data) + + +def finalize_plotting(): + if not PLOT_RESULTS: + return + + os.makedirs(RESULTS_DIR, exist_ok=True) + + for param_label in PARAM_LABELS: + scenario_data = results_collector[param_label] + if not scenario_data: + continue + scenario_names = [sd[0] for sd in scenario_data] + rel_diffs = [sd[1] for sd in scenario_data] + + plt.figure(figsize=(6, 4)) + plt.bar(scenario_names, rel_diffs, color="blue") + plt.xticks(rotation=45, ha="right") + plt.title(f"Relative Error for param='{param_label}'\n(FD vs AD)") + plt.ylabel("Relative Error") + plt.tight_layout() + + filename_png = f"rel_error_{param_label.replace('_', '-')}.png" + plt.savefig(os.path.join(RESULTS_DIR, filename_png)) + plt.close() + print(f"Saved bar chart => {filename_png}") + + +atexit.register(finalize_plotting) diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index dbf95692e..f0dc9501c 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -10,7 +10,6 @@ import autograd.numpy as np import pydantic.v1 as pydantic import shapely -import xarray as xr try: from matplotlib import patches @@ -28,7 +27,10 @@ from ...log import log from ...packaging import check_import, verify_packages_import from ..autograd import AutogradFieldMap, TracedCoordinate, TracedSize, get_static -from ..autograd.derivative_utils import DerivativeInfo, integrate_within_bounds +from ..autograd.derivative_utils import ( + DerivativeInfo, + DerivativeSurfaceMesh, +) from ..base import Tidy3dBaseModel, cached_property from ..transformation import RotationAroundAxis from ..types import ( @@ -61,6 +63,7 @@ ) POLY_GRID_SIZE = 1e-12 +_NUM_PTS_DIM_BOX_FACE = 200 _shapely_operations = { @@ -1977,8 +1980,7 @@ def _normal_axis(self) -> Axis: """Axis normal to the Box. Errors if box is not planar.""" if self.size.count(0.0) != 1: raise ValidationError( - "Tried to get 'normal_axis' of 'Box' that is not planar. " - f"Given 'size={self.size}.'" + f"Tried to get 'normal_axis' of 'Box' that is not planar. Given 'size={self.size}.'" ) return self.size.index(0.0) @@ -2480,11 +2482,15 @@ def _surface_area(self, bounds: Bound) -> float: """ Autograd code """ - def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" + def compute_derivatives( + self, derivative_info: DerivativeInfo, rotation_matrix: np.ndarray = None + ) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object with optional rotation matrix.""" - # get gradients w.r.t. each of the 6 faces (in normal direction) - vjps_faces = self.derivative_faces(derivative_info=derivative_info) + # Compute gradients w.r.t. each of the 6 faces using the rotation matrix (if any) + vjps_faces = self.derivative_faces( + derivative_info=derivative_info, rotation_matrix=rotation_matrix + ) # post-process these values to give the gradients w.r.t. center and size vjps_center_size = self.derivatives_center_size(vjps_faces=vjps_faces) @@ -2522,18 +2528,26 @@ def derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: size=tuple(vjp_size.tolist()), ) - def derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: - """Derivative with respect to normal position of 6 faces of ``Box``.""" + def derivative_faces( + self, derivative_info: DerivativeInfo, rotation_matrix: np.ndarray = None + ) -> Bound: + """Compute derivatives with respect to the normal position of 6 faces of `Box`, using rotation matrix if provided.""" # change in permittivity between inside and outside vjp_faces = np.zeros((2, 3)) for min_max_index, _ in enumerate((0, -1)): for axis in range(3): + if rotation_matrix is not None: + rotation_matrix = rotation_matrix + else: + rotation_matrix = None + vjp_face = self.derivative_face( min_max_index=min_max_index, axis_normal=axis, derivative_info=derivative_info, + rotation_matrix=rotation_matrix, ) # record vjp for this face @@ -2544,100 +2558,126 @@ def derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: def derivative_face( self, min_max_index: int, - axis_normal: Axis, + axis_normal: int, derivative_info: DerivativeInfo, + rotation_matrix: np.ndarray = None, ) -> float: - """Compute the derivative w.r.t. shifting a face in the normal direction.""" - - # normal and tangential dims - dim_normal, dims_perp = self.pop_axis("xyz", axis=axis_normal) - fld_normal, flds_perp = self.pop_axis(("Ex", "Ey", "Ez"), axis=axis_normal) - - # normal and tangential fields - D_normal = derivative_info.D_der_map[fld_normal] - Es_perp = tuple(derivative_info.E_der_map[key] for key in flds_perp) - - # normal and tangential bounds - bounds_T = np.array(derivative_info.bounds).T # put (xyz) first dimension - bounds_normal, bounds_perp = self.pop_axis(bounds_T, axis=axis_normal) - - # define the integration plane - coord_normal_face = bounds_normal[min_max_index] - bounds_perp = np.array(bounds_perp).T # put (min / max) first dimension for integrator - - # normal field data coordinates - fld_coords_normal = D_normal.coords[dim_normal] - - # condition: a face is entirely outside of the domain, skip! - sign = (-1, 1)[min_max_index] - normal_coord_positive = sign * coord_normal_face - fld_coords_positive = sign * fld_coords_normal - if all(fld_coords_positive < normal_coord_positive): - log.info( - f"skipping VJP for 'Box' face '{dim_normal}{'-+'[min_max_index]}' " - "as it is entirely outside of the simulation domain." - ) - return 0.0 - - # grab permittivity data inside and outside edge in normal direction - eps_xyz = [derivative_info.eps_data[f"eps_{dim}{dim}"] for dim in "xyz"] - - # number of cells from the edge of data to register "inside" (index = num_cells_in - 1) - num_cells_in = 4 - - # if not enough data, just use best guess using eps in medium and simulation - needs_eps_approx = any(len(eps.coords[dim_normal]) <= num_cells_in for eps in eps_xyz) - - if derivative_info.eps_approx or needs_eps_approx: - eps_xyz_inside = 3 * [derivative_info.eps_in] - eps_xyz_outside = 3 * [derivative_info.eps_out] - # TODO: not tested... + """ + Compute the derivative (VJP) with respect to shifting a face of a rotated box, + using full integration over that face. This version uses bilinear interpolation + of the four corners to sample interior points. + """ - # otherwise, try to grab the data at the edges + if axis_normal == 0: + canonical_normal = np.array([1.0, 0.0, 0.0]) + elif axis_normal == 1: + canonical_normal = np.array([0.0, 1.0, 0.0]) + elif axis_normal == 2: + canonical_normal = np.array([0.0, 0.0, 1.0]) else: - if min_max_index == 0: - index_out, index_in = (0, num_cells_in - 1) - else: - index_out, index_in = (-1, -num_cells_in) - eps_xyz_inside = [eps.isel(**{dim_normal: index_in}) for eps in eps_xyz] - eps_xyz_outside = [eps.isel(**{dim_normal: index_out}) for eps in eps_xyz] + raise ValueError("Invalid axis_normal") - # put in normal / tangential basis - eps_in_normal, eps_in_perps = self.pop_axis(eps_xyz_inside, axis=axis_normal) - eps_out_normal, eps_out_perps = self.pop_axis(eps_xyz_outside, axis=axis_normal) + if min_max_index == 0: + canonical_normal *= -1.0 - # compute integration pre-factors - delta_eps_perps = [eps_in - eps_out for eps_in, eps_out in zip(eps_in_perps, eps_out_perps)] - delta_eps_inv_normal = 1.0 / eps_in_normal - 1.0 / eps_out_normal + if rotation_matrix is None: + rotation_matrix = np.eye(3) - def integrate_face(arr: xr.DataArray) -> complex: - """Interpolate and integrate a scalar field data over the face using bounds.""" + n_local = rotation_matrix @ canonical_normal + n_local = n_local / np.linalg.norm(n_local) - arr_at_face = arr.interp(**{dim_normal: float(coord_normal_face)}, assume_sorted=True) + def compute_tangential_vectors( + normal: np.ndarray, eps: float = 1e-8 + ) -> tuple[np.ndarray, np.ndarray]: + """Compute any two perpendicular tangential vectors t1, t2, given a normal.""" + if abs(normal[0]) > abs(normal[2]): + t1 = np.array([-normal[1], normal[0], 0.0]) + else: + t1 = np.array([0.0, -normal[2], normal[1]]) + t1_norm = np.linalg.norm(t1) + if t1_norm < eps: + raise ValueError("Degenerate normal vector.") + t1 = t1 / t1_norm + t2 = np.cross(normal, t1) + t2 /= np.linalg.norm(t2) + return t1, t2 + + t1_local, t2_local = compute_tangential_vectors(n_local) + + min_bound = np.array(self.center) - np.array(self.size) / 2.0 + max_bound = np.array(self.center) + np.array(self.size) / 2.0 + bounds_old = np.column_stack((min_bound, max_bound)) + + corners = np.array( + [ + [bounds_old[0, i], bounds_old[1, j], bounds_old[2, k]] + for i in (0, 1) + for j in (0, 1) + for k in (0, 1) + ] + ) - integral_result = integrate_within_bounds( - arr=arr_at_face, - dims=dims_perp, - bounds=bounds_perp, - ) + connectivity = { + 0: { # Faces perpendicular to x-axis + 0: [0, 1, 3, 2], + 1: [4, 5, 7, 6], + }, + 1: { # Faces perpendicular to y-axis + 0: [0, 1, 5, 4], + 1: [2, 3, 7, 6], + }, + 2: { # Faces perpendicular to z-axis + 0: [0, 4, 6, 2], + 1: [1, 5, 7, 3], + }, + } + + face_indices = connectivity[axis_normal][min_max_index] + face_corners = corners[face_indices, :] + + rotated_corners = (rotation_matrix @ face_corners.T).T + p1, p2, p3, p4 = rotated_corners + + num_s = _NUM_PTS_DIM_BOX_FACE + num_t = _NUM_PTS_DIM_BOX_FACE + s_vals = np.linspace(0, 1, 2 * num_s + 1)[1::2] + t_vals = np.linspace(0, 1, 2 * num_t + 1)[1::2] + S, T = np.meshgrid(s_vals, t_vals, indexing="ij") + + X = (1 - S) * (1 - T) * p1[0] + S * (1 - T) * p2[0] + S * T * p3[0] + (1 - S) * T * p4[0] + Y = (1 - S) * (1 - T) * p1[1] + S * (1 - T) * p2[1] + S * T * p3[1] + (1 - S) * T * p4[1] + Z = (1 - S) * (1 - T) * p1[2] + S * (1 - T) * p2[2] + S * T * p3[2] + (1 - S) * T * p4[2] + + centers = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()]) + + tri1_area = 0.5 * np.linalg.norm(np.cross((p2 - p1), (p3 - p1))) + tri2_area = 0.5 * np.linalg.norm(np.cross((p4 - p1), (p3 - p1))) + face_area = tri1_area + tri2_area + + num_cells = (num_s) * (num_t) + if num_cells > 0: + cell_area = face_area / num_cells + else: + cell_area = face_area - return complex(integral_result.sum(dim="f")) + areas = cell_area * np.ones(centers.shape[0]) - # put together VJP using D_normal and E_perp integration - vjp_value = 0.0 + normals = np.tile(n_local, (centers.shape[0], 1)) + perps1 = np.tile(t1_local, (centers.shape[0], 1)) + perps2 = np.tile(t2_local, (centers.shape[0], 1)) - # perform D-normal integral - integrand_D = -delta_eps_inv_normal * D_normal - integral_D = integrate_face(integrand_D) - vjp_value += integral_D + surface_mesh = DerivativeSurfaceMesh( + centers=centers, + areas=areas, + normals=normals, + perps1=perps1, + perps2=perps2, + ) - # perform E-perpendicular integrals - for E_perp, delta_eps_perp in zip(Es_perp, delta_eps_perps): - integrand_E = E_perp * delta_eps_perp - integral_E = integrate_face(integrand_E) - vjp_value += integral_E + grads = derivative_info.grad_surfaces(surface_mesh=surface_mesh) + vjp_value = np.real(np.sum(grads).item()) - return np.real(vjp_value) + return vjp_value """Compound subclasses""" @@ -2664,7 +2704,15 @@ def _transform_is_invertible(cls, val): @pydantic.validator("geometry") def _geometry_is_finite(cls, val): - if not np.isfinite(val.bounds).all(): + def preprocess(value): + return value._value if isinstance(value, np.numpy_boxes.ArrayBox) else value + + processed_bounds = tuple( + tuple(preprocess(coord) for coord in bound) for bound in val.bounds + ) + + # Ensure all values are finite + if not np.isfinite(processed_bounds).all(): raise ValidationError( "Transformations are only supported on geometries with finite dimensions. " "Try using a large value instead of 'inf' when creating geometries that undergo " @@ -2918,6 +2966,45 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Transf new_geometry = self.geometry._update_from_bounds(bounds=new_bounds, axis=axis) return self.updated_copy(geometry=new_geometry) + def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """ + Compute the adjoint derivatives for the transformed geometry by + transforming the base geometry. + + """ + derivative_map = {} + + transform_paths = [path for path in derivative_info.paths if path[0] == "transform"] + + if derivative_info.paths == [("transform",)]: + derivative_info = derivative_info.updated_copy( + paths=[("geometry", "center"), ("geometry", "size"), ("transform",)], deep=False + ) + geometry_paths = [path for path in derivative_info.paths if path[0] == "geometry"] + + if "transform" in [p[0] for p in transform_paths]: + transform_paths = [("transform", i, j) for i in range(4) for j in range(4)] + + if geometry_paths: + T = self.transform + R = T[:3, :3] # Rotation matrix + + geo_info = derivative_info.updated_copy( + paths=[path[1:] for path in geometry_paths], deep=False + ) + transformed_geometry_derivatives = self.geometry.compute_derivatives(geo_info, R) + + transformed_center_gradient = np.array( + transformed_geometry_derivatives.get(("center",), (0.0, 0.0, 0.0)) + ) + transformed_size_gradient = np.array( + transformed_geometry_derivatives.get(("size",), (0.0, 0.0, 0.0)) + ) + derivative_map[("geometry", "center")] = transformed_center_gradient + derivative_map[("geometry", "size")] = transformed_size_gradient + derivative_map[("transform",)] = np.zeros((4, 4)) + return derivative_map + class ClipOperation(Geometry): """Class representing the result of a set operation between geometries.""" diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 0d84553d5..4f34193de 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -1001,32 +1001,28 @@ def postprocess_adj( eps_background = None # manually override simulation medium as the background structure - if not isinstance(structure.geometry, td.Box): - # auto permittivity detection - sim_orig = sim_data_orig.simulation - plane_eps = eps_fwd.monitor.geometry - - # get permittivity without this structure - structs_no_struct = list(sim_orig.structures) - structs_no_struct.pop(structure_index) - sim_no_structure = sim_orig.updated_copy(structures=structs_no_struct) - eps_no_structure = sim_no_structure.epsilon( - box=plane_eps, coord_key="centers", freq=freq_adj - ) - - # get permittivity with structures on top of an infinite version of this structure - structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :] - sim_inf_structure = sim_orig.updated_copy( - structures=structs_inf_struct, - medium=structure.medium, - monitors=[], - ) - eps_inf_structure = sim_inf_structure.epsilon( - box=plane_eps, coord_key="centers", freq=freq_adj - ) + # auto permittivity detection + sim_orig = sim_data_orig.simulation + plane_eps = eps_fwd.monitor.geometry + + # get permittivity without this structure + structs_no_struct = list(sim_orig.structures) + structs_no_struct.pop(structure_index) + sim_no_structure = sim_orig.updated_copy(structures=structs_no_struct) + eps_no_structure = sim_no_structure.epsilon( + box=plane_eps, coord_key="centers", freq=freq_adj + ) - else: - eps_no_structure = eps_inf_structure = None + # get permittivity with structures on top of an infinite version of this structure + structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :] + sim_inf_structure = sim_orig.updated_copy( + structures=structs_inf_struct, + medium=structure.medium, + monitors=[], + ) + eps_inf_structure = sim_inf_structure.epsilon( + box=plane_eps, coord_key="centers", freq=freq_adj + ) # get minimum intersection of bounds with structure and sim struct_bounds = rmin_struct, rmax_struct = structure.geometry.bounds