Skip to content

feat: autograd support for rotated Box #2362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- A property `interior_angle` in `PolySlab` that stores angles formed inside polygon by two adjacent edges.
- `eps_component` argument in `td.Simulation.plot_eps()` to optionally select a specific permittivity component to plot (eg. `"xx"`).
- Monitor `AuxFieldTimeMonitor` for aux fields like the free carrier density in `TwoPhotonAbsorption`.
- Gradient computation for rotated `Box` geometries in `Transformed`.

### Fixed
- Compatibility with `xarray>=2025.03`.
Expand Down
120 changes: 120 additions & 0 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,3 +2150,123 @@ def objective(args):
# model is called without a frequency
with AssertLogLevel("INFO"):
grad = ag.grad(objective)(params0)


def make_sim_rotation(center: tuple, size: tuple, angle: float, axis: int):
    """Assemble a cubic simulation containing one box scatterer, optionally rotated.

    The box is built from ``center`` and ``size``; when ``angle`` is not ``None`` it is
    replaced by ``box.rotated(angle, axis)``. The simulation also holds an ``Ez``-polarized
    point dipole at ``x = -L/2 + buffer`` and a zero-size field monitor named ``"point"``
    at ``x = +L/2 - buffer``, offset by half a buffer along y and z.

    Returns the assembled ``td.Simulation``.
    """
    wavelength = 1.5
    freq0 = td.C_0 / wavelength
    L = 10 * wavelength
    buffer = 1.0 * wavelength

    # Excitation: a single dipole near the -x side of the domain.
    dipole = td.PointDipole(
        center=(-L / 2 + buffer, 0, 0),
        source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
        polarization="Ez",
    )

    # Observation: a single-point monitor near the +x side, off the source axis.
    probe_center = (+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer)
    probe = td.FieldMonitor(
        center=probe_center,
        size=(0.0, 0.0, 0.0),
        freqs=[freq0],
        name="point",
    )

    # Scatterer geometry; ``angle=None`` means "use the plain, untransformed box".
    geometry = td.Box(center=center, size=size)
    if angle is not None:
        geometry = geometry.rotated(angle, axis)
    scatterer = td.Structure(geometry=geometry, medium=td.Medium(permittivity=2.0))

    return td.Simulation(
        size=(L, L, L),
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        structures=[scatterer],
        sources=[dipole],
        monitors=[probe],
        run_time=120 / freq0,
    )


def objective_fn(center, size, angle, axis):
    """Run the rotated-box simulation and return the summed intensity at the ``"point"`` monitor."""
    data = web.run(
        make_sim_rotation(center, size, angle, axis),
        task_name="emulated_rot_test",
        local_gradient=True,
        verbose=False,
    )
    intensity = data.get_intensity("point")
    return anp.sum(intensity.values)


def get_grad(center, size, angle, axis):
    """Return ``(value, d/d_center, d/d_size)`` of ``objective_fn`` at the given box parameters.

    ``angle`` and ``axis`` are held fixed; only ``center`` and ``size`` are differentiated.
    """
    value_and_grad = ag.value_and_grad(
        lambda c, s: objective_fn(c, s, angle, axis),
        argnum=(0, 1),
    )
    value, grads = value_and_grad(center, size)
    grad_center, grad_size = grads
    return value, grad_center, grad_size


@pytest.mark.numerical
@pytest.mark.parametrize(
    "angle_deg, axis",
    [
        (0.0, 1),
        (180.0, 1),
        (90.0, 1),
        (270.0, 1),
    ],
)
def test_box_rotation_gradients(use_emulated_run, angle_deg, axis):
    """Compare gradients of an unrotated box against the same box rotated about ``axis``.

    The gradients w.r.t. ``center`` and ``size`` of the plain box must be non-zero. For
    rotations of 90, 180 and 270 degrees they are additionally compared, component by
    component, against the gradients obtained with the rotated geometry, using the
    axis permutation / sign flips stated in each branch.
    """
    center0 = (0.0, 0.0, 0.0)
    size0 = (2.0, 2.0, 2.0)

    # Gradients of the plain (never rotated) box.
    _, grad_c, grad_s = get_grad(center0, size0, angle=None, axis=None)
    npx, npy, npz = grad_c
    sSx, sSy, sSz = grad_s

    assert not np.allclose(grad_c, 0.0), "center gradient is all zero."
    assert not np.allclose(grad_s, 0.0), "size gradient is all zero."

    if angle_deg not in (90.0, 180.0, 270.0):
        # NOTE(review): the 0 degree case only gets the non-zero checks above; the rotated
        # code path is not exercised for it. Consider comparing against the plain box.
        return

    # Gradients with the rotation applied; computed once for whichever branch follows.
    angle_rad = np.deg2rad(angle_deg)
    _, grad_c_ref, grad_s_ref = get_grad(center0, size0, angle_rad, axis)
    rx, ry, rz = grad_c_ref
    rSx, rSy, rSz = grad_s_ref

    if angle_deg == 180.0:
        # rotating 180° about y => (x,z) become negated, y stays same
        assert np.allclose(npx, -rx, atol=1e-6), "center_x sign mismatch"
        assert np.allclose(npy, ry, atol=1e-6), "center_y mismatch"
        assert np.allclose(npz, -rz, atol=1e-6), "center_z sign mismatch"
        assert np.allclose(grad_s, grad_s_ref, atol=1e-6), "size grads changed unexpectedly"

    elif angle_deg == 90.0:
        # rotating 90° about y => new x= old z, new z=- old x, y stays same
        assert np.allclose(npx, rz, atol=1e-6), "center_x != old center_z"
        assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(npz, -rx, atol=1e-6), "center_z != - old center_x"

        assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
        assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
        assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"

    elif angle_deg == 270.0:
        # rotating 270° about y => new x= - old z, new z= old x, y stays same
        assert np.allclose(npx, -rz, atol=1e-6), "center_x != - old center_z"
        assert np.allclose(npy, ry, atol=1e-6), "center_y changed unexpectedly"
        assert np.allclose(npz, rx, atol=1e-6), "center_z != old center_x"

        assert np.allclose(sSx, rSz, atol=1e-6), "size_x != old size_z"
        assert np.allclose(sSy, rSy, atol=1e-6), "size_y changed unexpectedly"
        assert np.allclose(sSz, rSx, atol=1e-6), "size_z != old size_x"
246 changes: 246 additions & 0 deletions tests/test_components/test_box_compute_derivatives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import atexit
import os
from collections import defaultdict

import autograd
import autograd.numpy as anp
import matplotlib.pyplot as plt
import numpy as np
import pytest
import tidy3d as td
import tidy3d.web as web

# Diagnostics switches: set to True to dump per-case numbers / bar charts into RESULTS_DIR.
SAVE_RESULTS = False
PLOT_RESULTS = False
RESULTS_DIR = "./fd_ad_results"

# Maps a parameter label to the list of (scenario name, relative error) pairs gathered
# while the tests run; consumed by the plotting hook at interpreter exit.
results_collector = defaultdict(list)

# Simulation parameters shared by every scenario, all derived from the wavelength.
wavelength = 1.5
freq0 = td.C_0 / wavelength
run_time = 120 / freq0
buffer = 1.0 * wavelength
L = 10 * wavelength


# Column order of the scenario table below; every row becomes one dict with these keys.
_SCENARIO_KEYS = (
    "name",
    "has_background",
    "background_eps",
    "box_eps",
    "rotation_deg",
    "rotation_axis",
)

# One entry per test scenario. ``rotation_deg=None`` means the box is used untransformed;
# ``background_eps`` is only read when ``has_background`` is True.
SCENARIOS = [
    dict(zip(_SCENARIO_KEYS, row))
    for row in (
        ("(1) normal", False, 3.0, 2.0, None, None),
        ("(2) perm=1.5", True, 1.5, 2.0, None, None),
        ("(3) rotation=0 deg about z", False, 1.5, 2.0, 0.0, 2),
        ("(4) rotation=90 deg about z", False, 1.5, 2.0, 90.0, 2),
        ("(5) rotation=45 deg about y", False, 1.5, 2.0, 45.0, 1),
        ("(6) rotation=45 deg about x", False, 1.5, 2.0, 45.0, 0),
        ("(7) rotation=45 deg about z", False, 1.5, 2.0, 45.0, 2),
    )
]

# Labels of the six differentiated box parameters, in plotting order. Each label must
# appear exactly once: the plotting hook emits one chart per entry, so a duplicate would
# redraw and overwrite the same file.
PARAM_LABELS = ["center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]


def make_sim(center: tuple, size: tuple, scenario: dict):
    """Build the simulation for one scenario around a box with the given ``center``/``size``.

    ``scenario`` supplies ``has_background`` / ``background_eps`` (an optional fixed
    background box placed before the scatterer), ``box_eps`` (scatterer permittivity) and
    ``rotation_deg`` / ``rotation_axis`` (optional rotation of the scatterer; the angle is
    converted from degrees to radians before being passed to ``Box.rotated``).

    Returns the assembled ``td.Simulation`` with a point-dipole source and a zero-size
    field monitor named ``"point_out"``.
    """
    dipole = td.PointDipole(
        center=(-L / 2 + buffer, 0.0, 0.0),
        source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
        polarization="Ez",
    )
    probe = td.FieldMonitor(
        center=(+L / 2 - buffer, 0.5 * buffer, 0.5 * buffer),
        size=(0, 0, 0),
        freqs=[freq0],
        name="point_out",
    )

    structures = []

    # Optional fixed background box; it goes first so the scatterer is listed after it.
    if scenario["has_background"]:
        structures.append(
            td.Structure(
                geometry=td.Box(center=(0.0, 0.0, 0.0), size=(4.0, 1.6, 1.6)),
                medium=td.Medium(permittivity=scenario["background_eps"]),
            )
        )

    # The scatterer whose center/size are being differentiated.
    geometry = td.Box(center=center, size=size)
    rotation_deg = scenario["rotation_deg"]
    if rotation_deg is not None:
        geometry = geometry.rotated(np.deg2rad(rotation_deg), scenario["rotation_axis"])
    structures.append(
        td.Structure(
            geometry=geometry,
            medium=td.Medium(permittivity=scenario["box_eps"]),
        )
    )

    return td.Simulation(
        size=(L, L, L),
        run_time=run_time,
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=50),
        sources=[dipole],
        monitors=[probe],
        structures=structures,
    )


def objective_fn(center, size, scenario):
    """Run the scenario simulation and return the summed intensity at the ``"point_out"`` monitor."""
    data = web.run(
        make_sim(center, size, scenario),
        task_name="autograd_vs_fd_scenario",
        local_gradient=True,
        verbose=False,
    )
    intensity = data.get_intensity("point_out")
    return anp.sum(intensity.values)


def fd_vs_ad_param(center, size, scenario, param_label, delta=1e-3):
    """Compute the finite-difference and autograd derivatives of the objective for one parameter.

    Parameters
    ----------
    center, size
        Box center and size at which the derivative is evaluated.
    scenario
        Scenario dictionary forwarded to ``objective_fn``.
    param_label
        One of ``"center_x"``, ``"center_y"``, ``"center_z"``, ``"size_x"``, ``"size_y"``,
        ``"size_z"``; selects which component is differentiated.
    delta
        Step used for the central finite difference.

    Returns
    -------
    tuple
        ``(fd_val, ad_val, p_plus, p_minus)``: the central-difference derivative, the
        autograd derivative, and the objective values at ``+delta`` and ``-delta``.

    Raises
    ------
    KeyError
        If ``param_label`` is not one of the supported labels. This is raised before any
        simulation is run.
    """
    param_map = {
        "center_x": (0, "center"),
        "center_y": (1, "center"),
        "center_z": (2, "center"),
        "size_x": (0, "size"),
        "size_y": (1, "size"),
        "size_z": (2, "size"),
    }
    # Resolve the label first so a typo fails immediately instead of after the
    # (expensive) adjoint run below.
    try:
        idx, which = param_map[param_label]
    except KeyError:
        raise KeyError(
            f"unknown param_label {param_label!r}; expected one of {sorted(param_map)}"
        ) from None

    # Autograd derivative w.r.t. both center and size, then pick the requested component.
    val_and_grad_fn = autograd.value_and_grad(
        lambda c, s: objective_fn(c, s, scenario), argnum=(0, 1)
    )
    _, (grad_center, grad_size) = val_and_grad_fn(center, size)
    ad_val = grad_center[idx] if which == "center" else grad_size[idx]

    def shifted_objective(step):
        """Evaluate the objective with the selected component shifted by ``step``."""
        center_arr = np.array(center, dtype=float)
        size_arr = np.array(size, dtype=float)
        target = center_arr if which == "center" else size_arr
        target[idx] += step
        return objective_fn(tuple(center_arr), tuple(size_arr), scenario)

    p_plus = shifted_objective(+delta)
    p_minus = shifted_objective(-delta)

    # Central difference.
    fd_val = (p_plus - p_minus) / (2.0 * delta)
    return fd_val, ad_val, p_plus, p_minus


@pytest.mark.numerical
@pytest.mark.parametrize("scenario", SCENARIOS, ids=[s["name"] for s in SCENARIOS])
@pytest.mark.parametrize(
    "param_label", ["center_x", "center_y", "center_z", "size_x", "size_y", "size_z"]
)
def test_autograd_vs_fd_scenarios(scenario, param_label):
    """Check that the autograd derivative agrees with a central finite difference.

    For the given scenario and box parameter, both derivatives must be finite and their
    relative difference (normalized by the finite-difference value) must be below 30%.
    The relative difference is recorded, and optionally saved, before the tolerance is
    asserted so that failing cases also show up in the diagnostics.
    """
    center0 = (0.0, 0.0, 0.0)
    size0 = (2.0, 2.0, 2.0)
    delta = 0.03

    fd_val, ad_val, p_plus, p_minus = fd_vs_ad_param(center0, size0, scenario, param_label, delta)

    assert np.isfinite(fd_val), f"FD derivative is not finite for param={param_label}"
    assert np.isfinite(ad_val), f"AD derivative is not finite for param={param_label}"

    # Guard the normalization against a vanishing finite-difference value.
    denom = max(abs(fd_val), 1e-12)
    rel_diff = abs(fd_val - ad_val) / denom

    # Record first, assert last: otherwise a mismatch would abort the test before its
    # numbers reach the collector / results directory.
    results_collector[param_label].append((scenario["name"], rel_diff))

    if SAVE_RESULTS:
        os.makedirs(RESULTS_DIR, exist_ok=True)
        results_data = {
            "scenario_name": scenario["name"],
            "param_label": param_label,
            "delta": float(delta),
            "fd_val": float(fd_val),
            "ad_val": float(ad_val),
            "p_plus": float(p_plus),
            "p_minus": float(p_minus),
            "rel_diff": float(rel_diff),
        }
        filename_npy = f"fd_ad_{scenario['name'].replace(' ', '_')}_{param_label}.npy"
        np.save(os.path.join(RESULTS_DIR, filename_npy), results_data)

    assert rel_diff < 0.3, f"Autograd vs FD mismatch: param={param_label}, diff={rel_diff:.1%}"


@atexit.register
def finalize_plotting():
    """Write one relative-error bar chart per parameter label when plotting is enabled.

    Registered to run at interpreter exit. Does nothing unless ``PLOT_RESULTS`` is true;
    labels with no collected results are skipped.
    """
    if not PLOT_RESULTS:
        return

    os.makedirs(RESULTS_DIR, exist_ok=True)

    for param_label in PARAM_LABELS:
        entries = results_collector[param_label]
        if not entries:
            continue

        # Each entry is a (scenario name, relative error) pair.
        scenario_names = [name for name, _ in entries]
        rel_diffs = [err for _, err in entries]

        plt.figure(figsize=(6, 4))
        plt.bar(scenario_names, rel_diffs, color="blue")
        plt.xticks(rotation=45, ha="right")
        plt.title(f"Relative Error for param='{param_label}'\n(FD vs AD)")
        plt.ylabel("Relative Error")
        plt.tight_layout()

        filename_png = f"rel_error_{param_label.replace('_', '-')}.png"
        out_path = os.path.join(RESULTS_DIR, filename_png)
        plt.savefig(out_path)
        plt.close()
        print(f"Saved bar chart => {filename_png}")
Loading