diff --git a/.gitmodules b/.gitmodules index 2ce36b88a2..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +0,0 @@ -[submodule "docs/notebooks"] - path = docs/notebooks - url = git@github.com:flexcompute/tidy3d-notebooks.git -[submodule "docs/faq"] - path = docs/faq - url = https://github.com/flexcompute/tidy3d-faq diff --git a/scripts/ensure_imports_from_common.py b/scripts/ensure_imports_from_common.py new file mode 100644 index 0000000000..f44e93e5ff --- /dev/null +++ b/scripts/ensure_imports_from_common.py @@ -0,0 +1,114 @@ +"""Ensure tidy3d._common modules avoid importing from tidy3d outside tidy3d._common.""" + +from __future__ import annotations + +import argparse +import ast +import sys +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class ImportViolation: + file: str + line: int + statement: str + + +def parse_args(argv: Iterable[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Ensure tidy3d._common does not import from tidy3d modules outside tidy3d._common." + ) + ) + parser.add_argument( + "--root", + default="tidy3d/_common", + help="Root directory to scan (relative to repo root).", + ) + return parser.parse_args(argv) + + +def main(argv: Iterable[str]) -> None: + args = parse_args(argv) + repo_root = Path.cwd().resolve() + root = (repo_root / args.root).resolve() + if not root.exists(): + print(f"No directory found at {root}. 
Skipping check.") + return + + violations: list[ImportViolation] = [] + for path in sorted(root.rglob("*.py")): + violations.extend(_violations_in_file(path, repo_root)) + + if violations: + print("Invalid tidy3d imports found in tidy3d._common:") + for violation in violations: + print(f"{violation.file}:{violation.line}: {violation.statement}") + raise SystemExit(1) + + print("No invalid tidy3d imports found in tidy3d._common.") + + +def _violations_in_file(path: Path, repo_root: Path) -> list[ImportViolation]: + source = path.read_text(encoding="utf-8") + try: + tree = ast.parse(source) + except SyntaxError as exc: + raise SystemExit(f"Syntax error parsing {path}: {exc}") from exc + + rel_path = str(path.relative_to(repo_root)) + violations: list[ImportViolation] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + name = alias.name + if name == "tidy3d" or ( + name.startswith("tidy3d.") and not name.startswith("tidy3d._common") + ): + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + elif isinstance(node, ast.ImportFrom): + if node.level: + continue + module = node.module + if not module: + continue + if module == "tidy3d": + for alias in node.names: + if alias.name != "_common": + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + continue + if module.startswith("tidy3d.") and not module.startswith("tidy3d._common"): + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + return violations + + +def _statement(source: str, node: ast.AST) -> str: + segment = ast.get_source_segment(source, node) + if segment: + return " ".join(segment.strip().splitlines()) + return node.__class__.__name__ + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tests/config/test_legacy_env.py 
b/tests/config/test_legacy_env.py index cfeafe574f..2785376f08 100644 --- a/tests/config/test_legacy_env.py +++ b/tests/config/test_legacy_env.py @@ -93,11 +93,9 @@ def test_env_vars_follow_profile_switch( def test_web_core_environment_reexports(): """Legacy `tidy3d.web.core.environment` exports remain available via config shim.""" - - import tidy3d.web as web + from tidy3d._common.web.core import environment from tidy3d.config import Env as ConfigEnv - environment = web.core.environment assert environment.Env is ConfigEnv with warnings.catch_warnings(record=True) as caught: diff --git a/tests/config/test_loader.py b/tests/config/test_loader.py index 6ef42bf2fa..203685f5a3 100644 --- a/tests/config/test_loader.py +++ b/tests/config/test_loader.py @@ -6,8 +6,8 @@ from click.testing import CliRunner from pydantic import Field +from tidy3d._common.config import loader as config_loader # import from common as it is patched from tidy3d.config import get_manager, reload_config -from tidy3d.config import loader as config_loader from tidy3d.config import registry as config_registry from tidy3d.config.legacy import finalize_legacy_migration from tidy3d.config.loader import migrate_legacy_config diff --git a/tests/test_components/test_IO.py b/tests/test_components/test_IO.py index 495a2501b5..2d2e1b4924 100644 --- a/tests/test_components/test_IO.py +++ b/tests/test_components/test_IO.py @@ -15,7 +15,8 @@ import tidy3d as td from tidy3d import __version__ -from tidy3d.components.base import DATA_ARRAY_MAP, Tidy3dBaseModel +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.data.data_array import DATA_ARRAY_MAP from tidy3d.components.data.sim_data import DATA_TYPE_MAP from ..test_data.test_monitor_data import make_flux_data diff --git a/tests/test_components/test_viz.py b/tests/test_components/test_viz.py index d4d3a50b29..3be8f59cfb 100644 --- a/tests/test_components/test_viz.py +++ b/tests/test_components/test_viz.py @@ -224,7 +224,9 @@ def 
plot_with_multi_viz_spec(alphas, facecolors, edgecolors, rng, use_viz_spec=T def test_no_matlab_install(monkeypatch): """Test that the `VisualizationSpec` only throws a warning on validation if matplotlib is not installed.""" - monkeypatch.setattr("tidy3d.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False) + monkeypatch.setattr( + "tidy3d._common.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False + ) EXPECTED_WARNING_MSG_PIECE = ( "matplotlib was not successfully imported, but is required to validate colors" diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py index 987862011c..9691ddb884 100644 --- a/tests/test_web/test_local_cache.py +++ b/tests/test_web/test_local_cache.py @@ -56,7 +56,7 @@ @pytest.fixture(autouse=True) def _isolate_local_cache(tmp_path, monkeypatch): """Keep cache operations in a temp dir and avoid moving/deleting real cache.""" - import tidy3d.web.cache as cache_mod + import tidy3d._common.web.cache as cache_mod from tidy3d.config import get_manager real_remove_cache_dir = cache_mod._remove_cache_dir @@ -611,13 +611,13 @@ def test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulati file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME file1.write_text("a") - cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation) assert len(cache) == 1 sim2 = basic_simulation.updated_copy(shutoff=1e-4) file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME file2.write_text("b") - cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD") + cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2) entries = cache.list() assert len(entries) == 1 @@ -631,13 +631,13 @@ def test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation) file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME file1.write_text("a" * 8_000) 
- cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD") + cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation) assert len(cache) == 1 sim2 = basic_simulation.updated_copy(shutoff=1e-4) file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME file2.write_text("b" * 8_000) - cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD") + cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2) entries = cache.list() assert len(cache) == 1 @@ -653,7 +653,7 @@ def test_cache_stats_tracking(monkeypatch, tmp_path_factory, basic_simulation): payload = "stats-payload" artifact.write_text(payload) - cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD") + cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation) stats_path = cache.root / CACHE_STATS_NAME assert stats_path.exists() @@ -692,12 +692,12 @@ def test_cache_stats_sync(monkeypatch, tmp_path_factory, basic_simulation): artifact1 = tmp_path_factory.mktemp("artifact_sync1") / CACHE_ARTIFACT_NAME payload1 = "sync-one" artifact1.write_text(payload1) - cache.store_result(_FakeStubData(sim1), f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD") + cache.store_result(f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD", simulation=sim1) artifact2 = tmp_path_factory.mktemp("artifact_sync2") / CACHE_ARTIFACT_NAME payload2 = "sync-two" artifact2.write_text(payload2) - cache.store_result(_FakeStubData(sim2), f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD") + cache.store_result(f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD", simulation=sim2) stats_path = cache.root / CACHE_STATS_NAME assert stats_path.exists() @@ -734,7 +734,7 @@ def _counting_iter(): artifact = tmp_path / "iter_guard.hdf5" artifact.write_text("payload") - cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD") + cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation) assert 
iter_calls["count"] == 0 entry_dirs = [] @@ -809,9 +809,7 @@ def test_cache_cli_commands(monkeypatch, tmp_path_factory, basic_simulation, tmp artifact = artifact_dir / CACHE_ARTIFACT_NAME artifact.write_text("payload_cli") - cache.store_result( - _FakeStubData(basic_simulation), f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD" - ) + cache.store_result(f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD", simulation=basic_simulation) info_result = runner.invoke(tidy3d_cli, ["cache", "info"]) assert info_result.exit_code == 0 diff --git a/tests/test_web/test_s3utils.py b/tests/test_web/test_s3utils.py index a6552737a4..d1849bb32a 100644 --- a/tests/test_web/test_s3utils.py +++ b/tests/test_web/test_s3utils.py @@ -4,9 +4,11 @@ import pytest -import tidy3d +from tidy3d._common.web.core import s3utils as s3utils_common from tidy3d.web.core import s3utils +s3_utils_path = "tidy3d._common.web.core.s3utils" + @pytest.fixture def mock_S3STSToken(monkeypatch): @@ -16,9 +18,9 @@ def mock_S3STSToken(monkeypatch): mock_token.get_bucket = lambda: "" mock_token.get_s3_key = lambda: "" mock_token.is_expired = lambda: False - mock_token.get_client = lambda: tidy3d.web.core.s3utils.boto3.client() + mock_token.get_client = lambda: s3utils_common.boto3.client() monkeypatch.setattr( - target=tidy3d.web.core.s3utils, name="_S3STSToken", value=MagicMock(return_value=mock_token) + target=s3utils_common, name="_S3STSToken", value=MagicMock(return_value=mock_token) ) return mock_token @@ -26,10 +28,10 @@ def mock_S3STSToken(monkeypatch): @pytest.fixture def mock_get_s3_sts_token(monkeypatch): def _mock_get_s3_sts_token(resource_id, remote_filename): - return s3utils._S3STSToken(resource_id, remote_filename) + return s3utils_common._S3STSToken(resource_id, remote_filename) monkeypatch.setattr( - target=tidy3d.web.core.s3utils, name="get_s3_sts_token", value=_mock_get_s3_sts_token + target=s3utils_common, name="get_s3_sts_token", value=_mock_get_s3_sts_token ) return _mock_get_s3_sts_token @@ -44,7 +46,7 
@@ def mock_s3_client(monkeypatch): # Patch the `client` as it is imported within `tidy3d.web.core.s3utils.boto3` so that # whenever it's invoked (for example with "s3"), it returns our `mock_client`. monkeypatch.setattr( - target=tidy3d.web.core.s3utils.boto3, + target=s3utils_common.boto3, name="client", value=MagicMock(return_value=mock_client), ) @@ -148,11 +150,11 @@ def test_s3_token_get_client_with_custom_endpoint(tmp_path, monkeypatch): # Mock boto3.client mock_boto_client = MagicMock() - monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client) + monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client) # Test 1: Without custom endpoint - use fresh config test_config = ConfigManager(config_dir=tmp_path) - monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config) + monkeypatch.setattr(f"{s3_utils_path}.config", test_config) token.get_client() # Verify boto3.client was called without endpoint_url @@ -195,12 +197,12 @@ def test_s3_token_get_client_respects_ssl_verify(tmp_path, monkeypatch): token = _S3STSToken(**token_data) mock_boto_client = MagicMock() - monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client) + monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client) # Use fresh config with ssl_verify=False test_config = ConfigManager(config_dir=tmp_path) test_config.update_section("web", ssl_verify=False) - monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config) + monkeypatch.setattr(f"{s3_utils_path}.config", test_config) token.get_client() diff --git a/tests/test_web/test_tidy3d_task.py b/tests/test_web/test_tidy3d_task.py index b69ed14c2e..843592d472 100644 --- a/tests/test_web/test_tidy3d_task.py +++ b/tests/test_web/test_tidy3d_task.py @@ -7,6 +7,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_webapi import task_core_path from tidy3d import config from tidy3d.web.core import http_util from tidy3d.web.core.environment import Env @@ 
-31,7 +32,7 @@ def make_sim(): @pytest.fixture def set_api_key(monkeypatch): """Set the api key.""" - import tidy3d.web.core.http_util as httputil + import tidy3d._common.web.core.http_util as httputil monkeypatch.setattr(httputil, "api_key", lambda: "apikey") monkeypatch.setattr(httputil, "get_version", lambda: td.version.__version__) @@ -85,7 +86,7 @@ def mock_download(*args, **kwargs): to_file = kwargs["to_file"] sim.to_file(to_file) - monkeypatch.setattr("tidy3d.web.core.task_core.download_gz_file", mock_download) + monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download) responses.add( responses.GET, @@ -121,7 +122,7 @@ def test_upload(monkeypatch, set_api_key): def mock_download(*args, **kwargs): pass - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_download) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_download) task = SimulationTask.get("3eb06d16-208b-487b-864b-e9b1d3e010a7") with tempfile.NamedTemporaryFile() as temp: task.upload_file(temp.name, "temp.json") @@ -353,7 +354,7 @@ def mock(*args, **kwargs): with open(file_path, "w") as f: f.write("0.3,5.7") - monkeypatch.setattr("tidy3d.web.core.task_core.download_file", mock) + monkeypatch.setattr(f"{task_core_path}.download_file", mock) responses.add( responses.GET, f"{Env.current.web_api_endpoint}/tidy3d/tasks/3eb06d16-208b-487b-864b-e9b1d3e010a7/detail", diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index bf206c92b5..bf29a83cb5 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -64,7 +64,7 @@ common.CONNECTION_RETRY_TIME = 0.1 INVALID_TASK_ID = "INVALID_TASK_ID" -task_core_path = "tidy3d.web.core.task_core" +task_core_path = "tidy3d._common.web.core.task_core" api_path = "tidy3d.web.api.webapi" config.switch_profile("dev") @@ -204,7 +204,7 @@ def mock_upload(monkeypatch, set_api_key): def mock_upload_file(*args, **kwargs): pass - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", 
mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) @pytest.fixture diff --git a/tests/test_web/test_webapi_account.py b/tests/test_web/test_webapi_account.py index 3d3882cc16..4b3f4d9415 100644 --- a/tests/test_web/test_webapi_account.py +++ b/tests/test_web/test_webapi_account.py @@ -11,7 +11,6 @@ ) from tidy3d.web.core.environment import Env -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" config.switch_profile("dev") diff --git a/tests/test_web/test_webapi_eme.py b/tests/test_web/test_webapi_eme.py index 646f317077..b847788409 100644 --- a/tests/test_web/test_webapi_eme.py +++ b/tests/test_web/test_webapi_eme.py @@ -7,6 +7,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_webapi import task_core_path from tidy3d import EMESimulation from tidy3d.exceptions import SetupError from tidy3d.web.api.asynchronous import run_async @@ -37,7 +38,6 @@ EST_FLEX_UNIT = 11.11 FILE_SIZE_GB = 4.0 -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" @@ -89,7 +89,7 @@ def mock_upload(monkeypatch, set_api_key): def mock_upload_file(*args, **kwargs): pass - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) @pytest.fixture diff --git a/tests/test_web/test_webapi_extra.py b/tests/test_web/test_webapi_extra.py index 16cfd9422d..315a40e9c7 100644 --- a/tests/test_web/test_webapi_extra.py +++ b/tests/test_web/test_webapi_extra.py @@ -5,15 +5,14 @@ import pytest import responses +from tests.test_web.test_webapi import task_core_path from tidy3d.web.api.webapi import delete, get_info, get_tasks, real_cost, start @responses.activate def test_get_info_not_found(monkeypatch): """Tests that get_info raises a ValueError when the task is not found.""" - monkeypatch.setattr( - "tidy3d.web.core.task_core.SimulationTask.get", lambda *args, **kwargs: None - ) + 
monkeypatch.setattr(f"{task_core_path}.SimulationTask.get", lambda *args, **kwargs: None) with pytest.raises(ValueError, match="Task not found."): get_info("non_existent_task_id") @@ -21,9 +20,7 @@ def test_get_info_not_found(monkeypatch): @responses.activate def test_start_not_found(monkeypatch): """Tests that start raises a ValueError when the task is not found.""" - monkeypatch.setattr( - "tidy3d.web.core.task_core.SimulationTask.get", lambda *args, **kwargs: None - ) + monkeypatch.setattr(f"{task_core_path}.SimulationTask.get", lambda *args, **kwargs: None) with pytest.raises(ValueError, match="Task not found."): start("non_existent_task_id") @@ -42,9 +39,7 @@ class MockFolder: def list_tasks(self): return [] - monkeypatch.setattr( - "tidy3d.web.core.task_core.Folder.get", lambda *args, **kwargs: MockFolder() - ) + monkeypatch.setattr(f"{task_core_path}.Folder.get", lambda *args, **kwargs: MockFolder()) assert get_tasks() == [] @@ -71,9 +66,7 @@ def list_tasks(self): MockTask(datetime(2023, 1, 3), "3"), ] - monkeypatch.setattr( - "tidy3d.web.core.task_core.Folder.get", lambda *args, **kwargs: MockFolder() - ) + monkeypatch.setattr(f"{task_core_path}.Folder.get", lambda *args, **kwargs: MockFolder()) tasks = get_tasks(order="old") assert [t["task_id"] for t in tasks] == ["1", "2", "3"] diff --git a/tests/test_web/test_webapi_heat.py b/tests/test_web/test_webapi_heat.py index 6b33a33488..46135cdac6 100644 --- a/tests/test_web/test_webapi_heat.py +++ b/tests/test_web/test_webapi_heat.py @@ -7,6 +7,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_webapi import task_core_path from tidy3d import HeatSimulation from tidy3d.web.api.asynchronous import run_async from tidy3d.web.api.container import Batch, Job @@ -34,7 +35,6 @@ EST_FLEX_UNIT = 11.11 FILE_SIZE_GB = 4.0 -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" @@ -86,7 +86,7 @@ def mock_upload(monkeypatch, set_api_key): def mock_upload_file(*args, 
**kwargs): pass - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) @pytest.fixture diff --git a/tests/test_web/test_webapi_mode.py b/tests/test_web/test_webapi_mode.py index 09a07d5c81..ebace6a4ac 100644 --- a/tests/test_web/test_webapi_mode.py +++ b/tests/test_web/test_webapi_mode.py @@ -8,7 +8,8 @@ from responses import matchers import tidy3d as td -from tidy3d.components.data.dataset import ModeIndexDataArray +from tests.test_web.test_webapi import task_core_path +from tidy3d.components.data.data_array import ModeIndexDataArray from tidy3d.plugins.mode import ModeSolver from tidy3d.web.api.asynchronous import run_async from tidy3d.web.api.container import Batch, Job @@ -35,7 +36,6 @@ EST_FLEX_UNIT = 11.11 FILE_SIZE_GB = 4.0 -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" f, AX = plt.subplots() @@ -125,9 +125,9 @@ def mock_upload_file(*args, **kwargs): pass monkeypatch.setattr( - "tidy3d.web.core.task_core.SimulationTask.upload_simulation", mock_upload_simulation + f"{task_core_path}.SimulationTask.upload_simulation", mock_upload_simulation ) - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) return uploaded_stub diff --git a/tests/test_web/test_webapi_mode_sim.py b/tests/test_web/test_webapi_mode_sim.py index 2b9d9e7e09..332e49d324 100644 --- a/tests/test_web/test_webapi_mode_sim.py +++ b/tests/test_web/test_webapi_mode_sim.py @@ -7,6 +7,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_webapi import task_core_path from tidy3d.plugins.mode import ModeSolver from tidy3d.web.api.asynchronous import run_async from tidy3d.web.api.container import Batch, Job @@ -33,7 +34,6 @@ EST_FLEX_UNIT = 11.11 FILE_SIZE_GB = 4.0 -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" @@ 
-121,9 +121,9 @@ def mock_upload_file(*args, **kwargs): pass monkeypatch.setattr( - "tidy3d.web.core.task_core.SimulationTask.upload_simulation", mock_upload_simulation + f"{task_core_path}.SimulationTask.upload_simulation", mock_upload_simulation ) - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) return uploaded_stub diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index dd70f96c19..5388fd17c0 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -2,6 +2,9 @@ from __future__ import annotations +# ruff: noqa: I001 - ensure config is imported first +from .config import config + from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.boundary import BroadbandModeABCFitterParam, BroadbandModeABCSpec from tidy3d.components.data.index import SimulationDataMap @@ -458,8 +461,6 @@ from .components.transformation import RotationAroundAxis from .components.viz import VisualizationSpec, restore_matplotlib_rcparams -# config -from .config import config # constants imported as `C_0 = td.C_0` or `td.constants.C_0` from .constants import C_0, EPSILON_0, ETA_0, HBAR, K_B, MU_0, Q_e, inf @@ -634,7 +635,6 @@ def set_logging_level(level: str) -> None: "EMEScalarModeFieldDataArray", "EMESimulation", "EMESimulationData", - "EMESweepSpec", "EMEUniformGrid", "FieldData", "FieldDataset", diff --git a/tidy3d/_common/__init__.py b/tidy3d/_common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/_runtime.py b/tidy3d/_common/_runtime.py new file mode 100644 index 0000000000..6dbf61accd --- /dev/null +++ b/tidy3d/_common/_runtime.py @@ -0,0 +1,12 @@ +"""Runtime environment detection for tidy3d. + +This module must have ZERO dependencies on other tidy3d modules to avoid +circular imports. It is imported very early in the initialization chain. 
+""" + +from __future__ import annotations + +import sys + +# Detect WASM/Pyodide environment where web and filesystem features are unavailable +WASM_BUILD = "pyodide" in sys.modules or sys.platform == "emscripten" diff --git a/tidy3d/_common/compat.py b/tidy3d/_common/compat.py new file mode 100644 index 0000000000..a616a41895 --- /dev/null +++ b/tidy3d/_common/compat.py @@ -0,0 +1,31 @@ +"""Compatibility layer for handling differences between package versions.""" + +from __future__ import annotations + +import importlib +from functools import cache + +from packaging.version import parse + +try: + from xarray.structure import alignment +except ImportError: + from xarray.core import alignment + +try: + from numpy import trapezoid as np_trapezoid +except ImportError: # NumPy < 2.0 + from numpy import trapz as np_trapezoid + +try: + from typing import Self, TypeAlias # Python >= 3.11 +except ImportError: # Python <3.11 + from typing_extensions import Self, TypeAlias + + +@cache +def _package_is_older_than(package: str, version: str) -> bool: + return parse(importlib.metadata.version(package)) < parse(version) + + +__all__ = ["Self", "TypeAlias", "_package_is_older_than", "alignment", "np_trapezoid"] diff --git a/tidy3d/_common/components/__init__.py b/tidy3d/_common/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/autograd/__init__.py b/tidy3d/_common/components/autograd/__init__.py new file mode 100644 index 0000000000..cd81c18f3d --- /dev/null +++ b/tidy3d/_common/components/autograd/__init__.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from tidy3d._common.components.autograd.boxes import TidyArrayBox +from tidy3d._common.components.autograd.functions import interpn +from tidy3d._common.components.autograd.types import ( + AutogradFieldMap, + InterpolationType, + PathType, + TracedArrayFloat2D, + TracedArrayLike, + TracedComplex, + TracedCoordinate, + TracedFloat, + TracedPoleAndResidue, + 
TracedPolesAndResidues, + TracedPositiveFloat, + TracedSize, + TracedSize1D, +) + +from .utils import get_static, hasbox, is_tidy_box, split_list + +__all__ = [ + "AutogradFieldMap", + "InterpolationType", + "PathType", + "TidyArrayBox", + "TracedArrayFloat2D", + "TracedArrayLike", + "TracedComplex", + "TracedCoordinate", + "TracedFloat", + "TracedPoleAndResidue", + "TracedPolesAndResidues", + "TracedPositiveFloat", + "TracedSize", + "TracedSize1D", + "get_static", + "hasbox", + "interpn", + "is_tidy_box", + "split_list", +] diff --git a/tidy3d/_common/components/autograd/boxes.py b/tidy3d/_common/components/autograd/boxes.py new file mode 100644 index 0000000000..d51e948a85 --- /dev/null +++ b/tidy3d/_common/components/autograd/boxes.py @@ -0,0 +1,162 @@ +# Adds some functionality to the autograd arraybox and related autograd patches +# NOTE: we do not subclass ArrayBox since that would break autograd's internal checks +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +from autograd.extend import VJPNode, defjvp, register_notrace +from autograd.numpy.numpy_boxes import ArrayBox +from autograd.numpy.numpy_wrapper import _astype + +if TYPE_CHECKING: + from typing import Callable + +TidyArrayBox = ArrayBox # NOT a subclass + +_autograd_module_cache = {} # cache for imported autograd modules + +register_notrace(VJPNode, anp.full_like) + +defjvp( + _astype, + lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype), +) + +anp.astype = _astype +anp.permute_dims = anp.transpose + + +@classmethod +def from_arraybox(cls: Any, box: ArrayBox) -> TidyArrayBox: + """Construct a TidyArrayBox from an ArrayBox.""" + return cls(box._value, box._trace, box._node) + + +def __array_function__( + self: Any, + func: Callable, + types: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """ + Handle the dispatch of NumPy functions to autograd's numpy 
implementation. + + Parameters + ---------- + self : Any + The instance of the class. + func : Callable + The NumPy function being called. + types : list[Any] + The types of the arguments that implement __array_function__. + args : tuple[Any, ...] + The positional arguments to the function. + kwargs : dict[str, Any] + The keyword arguments to the function. + + Returns + ------- + Any + The result of the function call, or NotImplemented. + + Raises + ------ + NotImplementedError + If the function is not implemented in autograd.numpy. + + See Also + -------- + https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__ + """ + if not all(t in TidyArrayBox.type_mappings for t in types): + return NotImplemented + + module_name = func.__module__ + + if module_name.startswith("numpy"): + anp_module_name = "autograd." + module_name + else: + return NotImplemented + + # Use the cached module if available + anp_module = _autograd_module_cache.get(anp_module_name) + if anp_module is None: + try: + anp_module = importlib.import_module(anp_module_name) + _autograd_module_cache[anp_module_name] = anp_module + except ImportError: + return NotImplemented + + f = getattr(anp_module, func.__name__, None) + if f is None: + return NotImplemented + + if f.__name__ == "nanmean": # somehow xarray always dispatches to nanmean + f = anp.mean + kwargs.pop("dtype", None) # autograd mean vjp doesn't support dtype + + return f(*args, **kwargs) + + +def __array_ufunc__( + self: Any, + ufunc: Callable, + method: str, + *inputs: Any, + **kwargs: dict[str, Any], +) -> Any: + """ + Handle the dispatch of NumPy ufuncs to autograd's numpy implementation. + + Parameters + ---------- + self : Any + The instance of the class. + ufunc : Callable + The universal function being called. + method : str + The method of the ufunc being called. + inputs : Any + The input arguments to the ufunc. + kwargs : dict[str, Any] + The keyword arguments to the ufunc. 
+ + Returns + ------- + Any + The result of the ufunc call, or NotImplemented. + + See Also + -------- + https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__ + """ + if method != "__call__": + return NotImplemented + + ufunc_name = ufunc.__name__ + + anp_ufunc = getattr(anp, ufunc_name, None) + if anp_ufunc is not None: + return anp_ufunc(*inputs, **kwargs) + + return NotImplemented + + +def item(self: Any) -> Any: + if self.size != 1: + raise ValueError("Can only convert an array of size 1 to a scalar") + return anp.ravel(self)[0] + + +TidyArrayBox._tidy = True +TidyArrayBox.from_arraybox = from_arraybox +TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp +TidyArrayBox.__array_ufunc__ = __array_ufunc__ +TidyArrayBox.__array_function__ = __array_function__ +TidyArrayBox.real = property(anp.real) +TidyArrayBox.imag = property(anp.imag) +TidyArrayBox.conj = anp.conj +TidyArrayBox.item = item diff --git a/tidy3d/_common/components/autograd/derivative_utils.py b/tidy3d/_common/components/autograd/derivative_utils.py new file mode 100644 index 0000000000..f91a7d7b16 --- /dev/null +++ b/tidy3d/_common/components/autograd/derivative_utils.py @@ -0,0 +1,1105 @@ +"""Utilities for autograd derivative computation and field gradient evaluation.""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass, field, replace +from functools import reduce +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from tidy3d._common.components.data.data_array import FreqDataArray, ScalarFieldDataArray +from tidy3d._common.components.types.base import ArrayLike, Bound +from tidy3d._common.config import config +from tidy3d._common.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0 +from tidy3d._common.log import log + +from .types import PathType +from .utils import get_static + +if TYPE_CHECKING: + from 
collections.abc import Iterator + from typing import Callable, Union + + from tidy3d._common.compat import Self + from tidy3d._common.components.types.base import xyz + + +FieldDataDict = dict[str, ScalarFieldDataArray] +PermittivityData = dict[str, ScalarFieldDataArray] +EpsType = FreqDataArray +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] + + +class LazyInterpolator: + """Lazy wrapper for interpolators that creates them on first access.""" + + def __init__(self, creator_func: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None: + """Initialize with a function that creates the interpolator when called.""" + self.creator_func = creator_func + self._interpolator: Optional[Callable[[ArrayFloat], ArrayComplex]] = None + + def __call__(self, *args: Any, **kwargs: Any) -> ArrayComplex: + """Create interpolator on first call and delegate to it.""" + if self._interpolator is None: + self._interpolator = self.creator_func() + return self._interpolator(*args, **kwargs) + + +@dataclass +class DerivativeInfo: + """Stores derivative information passed to the ``._compute_derivatives`` methods. + + This dataclass contains all the field data and parameters needed for computing + gradients with respect to geometry perturbations. + """ + + # Required fields + paths: list[PathType] + """List of paths to the traced fields that need derivatives calculated.""" + + E_der_map: FieldDataDict + """Electric field gradient map. + Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication + of the forward and adjoint electric fields. The tangential components of this + dataset are used when computing adjoint gradients for shifting boundaries. + All components are used when computing volume-based gradients.""" + + D_der_map: FieldDataDict + """Displacement field gradient map. + Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication + of the forward and adjoint displacement fields. 
The normal component of this + dataset is used when computing adjoint gradients for shifting boundaries.""" + + E_fwd: FieldDataDict + """Forward electric fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the forward + electric fields used for computing gradients for a given structure.""" + + E_adj: FieldDataDict + """Adjoint electric fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint + electric fields used for computing gradients for a given structure.""" + + D_fwd: FieldDataDict + """Forward displacement fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the forward + displacement fields used for computing gradients for a given structure.""" + + D_adj: FieldDataDict + """Adjoint displacement fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint + displacement fields used for computing gradients for a given structure.""" + + eps_data: PermittivityData + """Permittivity dataset. + Dataset of relative permittivity values along all three dimensions. + Used for automatically computing permittivity inside or outside of a simple geometry.""" + + eps_in: EpsType | None + """Permittivity inside the Structure. + Computed only when structure.medium.is_custom is False. Contains the simulation + permittivity inside the structure when the simulation background medium is set to + the structure medium and all structures after the current structure are kept. Should + be used as the inside permittivity for shape derivative computations.""" + + eps_out: EpsType + """Permittivity outside the Structure. + Contains the simulation permittivity outside the structure when the current structure + is removed from the structure list. Should be used as the outside permittivity for + shape derivative computations.""" + + bounds: Bound + """Geometry bounds. 
+    Bounds corresponding to the structure, used in Medium calculations."""
+
+    bounds_intersect: Bound
+    """Geometry and simulation intersection bounds.
+    Bounds corresponding to the minimum intersection between the structure
+    and the simulation it is contained in."""
+
+    simulation_bounds: Bound
+    """Simulation bounds.
+    Bounds corresponding to the simulation domain containing this structure.
+    Unlike bounds_intersect, this is independent of the structure's bounds and
+    is purely based on the simulation geometry."""
+
+    frequencies: ArrayLike
+    """Frequencies at which the adjoint gradient should be computed."""
+
+    # Optional fields with defaults
+
+    H_der_map: Optional[FieldDataDict] = None
+    """Magnetic field gradient map.
+    Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication
+    of the forward and adjoint magnetic fields. The tangential component of this
+    dataset is used when computing adjoint gradients for shifting boundaries of
+    structures composed of PEC mediums."""
+
+    H_fwd: Optional[FieldDataDict] = None
+    """Forward magnetic fields.
+    Dataset where the field components ("Hx", "Hy", "Hz") represent the forward
+    magnetic fields used for computing gradients for a given structure."""
+
+    H_adj: Optional[FieldDataDict] = None
+    """Adjoint magnetic fields.
+    Dataset where the field components ("Hx", "Hy", "Hz") represent the adjoint
+    magnetic fields used for computing gradients for a given structure."""
+
+    is_medium_pec: bool = False
+    """Indicates if structure material is PEC.
+    If True, the structure contains a PEC material which changes the gradient
+    formulation at the boundary compared to the dielectric case."""
+
+    background_medium_is_pec: bool = False
+    """Indicates if the background material is PEC.
+    If True, the structure is partially surrounded by a PEC material."""
+
+    interpolators: Optional[dict] = None
+    """Pre-computed interpolators.
+    Optional pre-computed interpolators for field components and permittivity data.
+ When provided, avoids redundant interpolator creation for multiple geometries + sharing the same field data. This significantly improves performance for + GeometryGroup processing.""" + + cached_min_spacing_from_permittivity: Optional[float] = None + """Cached `min_spacing_from_permittivity` to be used for objects like GeometryGroup + to avoid recomputing this value multiple times in `adaptive_vjp_spacing`.""" + + # private cache for interpolators + _interpolators_cache: dict = field(default_factory=dict, init=False, repr=False) + + def updated_copy(self, **kwargs: Any) -> Self: + """Create a copy with updated fields.""" + kwargs.pop("deep", None) + kwargs.pop("validate", None) + return replace(self, **kwargs) + + @staticmethod + def _nan_to_num_if_needed( + coords: Union[ArrayFloat, ArrayComplex], + ) -> Union[ArrayFloat, ArrayComplex]: + """Convert NaN and infinite values to finite numbers, optimized for finite inputs.""" + # skip check for small arrays + if coords.size < 1000: + return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) + + if np.isfinite(coords).all(): + return coords + return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) + + @staticmethod + def _evaluate_with_interpolators( + interpolators: dict[str, Callable[[ArrayFloat], ArrayComplex]], + coords: ArrayFloat, + ) -> dict[str, ArrayComplex]: + """Evaluate field components at coordinates using cached interpolators. + + Parameters + ---------- + interpolators : dict + Dictionary mapping field component names to ``RegularGridInterpolator`` objects. + coords : np.ndarray + Spatial coordinates (N, 3) where fields are evaluated. + + Returns + ------- + dict[str, np.ndarray] + Dictionary mapping component names to field values at coordinates. 
+ """ + auto_cfg = config.adjoint + float_dtype = auto_cfg.gradient_dtype_float + complex_dtype = auto_cfg.gradient_dtype_complex + + coords = DerivativeInfo._nan_to_num_if_needed(coords) + if coords.dtype != float_dtype and coords.dtype != complex_dtype: + coords = coords.astype(float_dtype, copy=False) + return {name: interp(coords) for name, interp in interpolators.items()} + + def create_interpolators(self, dtype: Optional[np.dtype[Any]] = None) -> dict[str, Any]: + """Create interpolators for field components and permittivity data. + + Creates and caches ``RegularGridInterpolator`` objects for all field components + (E_fwd, E_adj, D_fwd, D_adj) and permittivity data (eps_in, eps_out, eps_data). + Contains (H_fwd, H_adj) field components when relevant for certain material types. + This caching strategy significantly improves performance by avoiding + repeated interpolator construction in gradient evaluation loops. + + Parameters + ---------- + dtype : np.dtype[Any], optional = None + Data type for interpolation coordinates and values. Defaults to the + current ``config.adjoint.gradient_dtype_float``. 
+ + Returns + ------- + dict + Nested dictionary structure: + - Field data: {"E_fwd": {"Ex": interpolator, ...}, ...} + - Permittivity: {"eps_in": interpolator, "eps_out": interpolator, "eps_data": interpolator} + """ + from scipy.interpolate import RegularGridInterpolator + + auto_cfg = config.adjoint + if dtype is None: + dtype = auto_cfg.gradient_dtype_float + complex_dtype = auto_cfg.gradient_dtype_complex + + cache_key = str(dtype) + if cache_key in self._interpolators_cache: + return self._interpolators_cache[cache_key] + + interpolators = {} + coord_cache = {} + + def _make_lazy_interpolator_group( + field_data_dict: Optional[FieldDataDict], + group_key: Optional[str], + is_field_group: bool = True, + override_method: Optional[str] = None, + ) -> None: + """Helper to create a group of lazy interpolators.""" + if not field_data_dict: + return + if is_field_group: + interpolators[group_key] = {} + + for component_name, arr in field_data_dict.items(): + # use object ID for caching to handle shared grids + arr_id = id(arr.data) + if arr_id not in coord_cache: + points = tuple(c.data.astype(dtype, copy=False) for c in (arr.x, arr.y, arr.z)) + coord_cache[arr_id] = points + points = coord_cache[arr_id] + + def creator_func( + arr: ScalarFieldDataArray = arr, + points: tuple[np.ndarray, ...] 
= points, + ) -> Callable[[ArrayFloat], ArrayComplex]: + data = arr.data.astype( + complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False + ) + # create interpolator with frequency dimension + if "f" in arr.dims: + freq_coords = arr.coords["f"].data.astype(dtype, copy=False) + # ensure frequency dimension is last + if arr.dims != ("x", "y", "z", "f"): + freq_dim_idx = arr.dims.index("f") + axes = list(range(data.ndim)) + axes.append(axes.pop(freq_dim_idx)) + data = np.transpose(data, axes) + else: + # single frequency case - add singleton dimension + freq_coords = np.array([0.0], dtype=dtype) + data = data[..., np.newaxis] + + points_with_freq = (*points, freq_coords) + # If PEC, use nearest interpolation instead of linear to avoid interpolating + # with field values inside the PEC (which are 0). Instead, we make sure to + # choose interpolation points such that their nearest location is outside of + # the PEC surface. The same applies if the background_medium is marked as PEC + # since we will need to use the same interpolation strategy inside the structure + # border. 
+ method = ( + "nearest" + if (self.is_medium_pec or self.background_medium_is_pec) + else "linear" + ) + if override_method is not None: + method = override_method + interpolator_obj = RegularGridInterpolator( + points_with_freq, data, method=method, bounds_error=False, fill_value=None + ) + + def interpolator(coords: ArrayFloat) -> ArrayComplex: + # coords: (N, 3) spatial points + n_points = coords.shape[0] + n_freqs = len(freq_coords) + + # build coordinates with frequency dimension + coords_with_freq = np.empty((n_points * n_freqs, 4), dtype=coords.dtype) + coords_with_freq[:, :3] = np.repeat(coords, n_freqs, axis=0) + coords_with_freq[:, 3] = np.tile(freq_coords, n_points) + + result = interpolator_obj(coords_with_freq) + return result.reshape(n_points, n_freqs) + + return interpolator + + if is_field_group: + interpolators[group_key][component_name] = LazyInterpolator(creator_func) + else: + interpolators[component_name] = LazyInterpolator(creator_func) + + # process field interpolators (nested dictionaries) + interpolator_groups = [ + ("E_fwd", self.E_fwd), + ("E_adj", self.E_adj), + ("D_fwd", self.D_fwd), + ("D_adj", self.D_adj), + ] + if self.is_medium_pec or self.background_medium_is_pec: + interpolator_groups += [("H_fwd", self.H_fwd), ("H_adj", self.H_adj)] # type: ignore[list-item] + for group_key, data_dict in interpolator_groups: + _make_lazy_interpolator_group( + data_dict, f"{group_key}_linear", is_field_group=True, override_method="linear" + ) + _make_lazy_interpolator_group( + data_dict, f"{group_key}_nearest", is_field_group=True, override_method="nearest" + ) + + if self.eps_data is not None: + _make_lazy_interpolator_group( + self.eps_data, "eps_data", is_field_group=True, override_method="nearest" + ) + + if self.eps_in is not None: + _make_lazy_interpolator_group( + {"eps_in": self.eps_in}, None, is_field_group=False, override_method="nearest" + ) + if self.eps_out is not None: + _make_lazy_interpolator_group( + {"eps_out": self.eps_out}, 
None, is_field_group=False, override_method="nearest" + ) + + self._interpolators_cache[cache_key] = interpolators + return interpolators + + def evaluate_gradient_at_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + perps1: np.ndarray, + perps2: np.ndarray, + interpolators: Optional[dict] = None, + ) -> np.ndarray: + """Compute adjoint gradients at surface points for shape optimization. + + Implements the surface integral formulation for computing gradients with respect + to geometry perturbations. + + Parameters + ---------- + spatial_coords : np.ndarray + (N, 3) array of surface evaluation points. + normals : np.ndarray + (N, 3) array of outward-pointing normal vectors at each surface point. + perps1 : np.ndarray + (N, 3) array of first tangent vectors perpendicular to normals. + perps2 : np.ndarray + (N, 3) array of second tangent vectors perpendicular to both normals and perps1. + interpolators : dict = None + Pre-computed field interpolators for efficiency. + + Returns + ------- + np.ndarray + (N,) array of gradient values at each surface point. Must be integrated + with appropriate quadrature weights to get total gradient. + """ + if interpolators is None: + raise NotImplementedError( + "Direct field evaluation without interpolators is not implemented. " + "Please create interpolators using 'create_interpolators()' first." + ) + + # In all paths below, we need to have computed the gradient integration for a + # dielectric-dielectric interface. + vjps_dielectric = self._evaluate_dielectric_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + self.eps_in, + self.eps_out, + ) + + if self.is_medium_pec: + # The structure medium is PEC, but there may be a part of the interface that has + # dielectric placed on top of or around it where we want to use the dielectric + # gradient integration. We use the mask to choose between the PEC-dielectric and + # dielectric-dielectric parts of the border. 
+ + # Detect PEC by looking just inside the boundary + mask_pec = self._detect_pec_gradient_points( + spatial_coords, + normals, + self.eps_in, + interpolators["eps_data"], + is_outside=False, + ) + + # Compute PEC gradients, pulling fields outside of the boundary + vjps_pec = self._evaluate_pec_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + ("eps_out", self.eps_out), + is_outside=True, + ) + + vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric + elif self.background_medium_is_pec: + # The structure medium is dielectric, but there may be a part of the interface that has + # PEC placed on top of or around it where we want to use the PEC gradient integration. + # We use the mask to choose between the dielectric-dielectric and PEC-dielectric parts + # of the border. + + # Detect PEC by looking just outside the boundary + mask_pec = self._detect_pec_gradient_points( + spatial_coords, + normals, + self.eps_out, + interpolators["eps_data"], + is_outside=True, + ) + + # Compute PEC gradients, pulling fields inside of the boundary and applying a negative + # sign compared to above because inside and outside definitions are switched + vjps_pec = -self._evaluate_pec_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + ("eps_in", self.eps_in), + is_outside=False, + ) + + vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric + else: + # The structure and its background are both assumed to be dielectric, so we use the + # dielectric-dielectric gradient integration. 
+ vjps = vjps_dielectric + + # sum over frequency dimension + vjps = np.sum(vjps, axis=-1) + + return vjps + + def _evaluate_dielectric_gradient_at_points( + self, + spatial_coords: ArrayFloat, + normals: ArrayFloat, + perps1: ArrayFloat, + perps2: ArrayFloat, + interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]], + eps_in_data: ScalarFieldDataArray, + eps_out_data: ScalarFieldDataArray, + ) -> ArrayComplex: + eps_out_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=True, + data_array=eps_out_data, + ) + eps_in_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=False, + data_array=eps_in_data, + ) + + eps_out = interpolators["eps_out"](eps_out_coords) + eps_in = interpolators["eps_in"](eps_in_coords) + + # evaluate all field components at surface points + E_fwd_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["E_fwd_linear"].items() + } + E_adj_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["E_adj_linear"].items() + } + D_fwd_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["D_fwd_linear"].items() + } + D_adj_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["D_adj_linear"].items() + } + + delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out + delta_eps = eps_in - eps_out + + # project fields onto local surface basis (normal + two tangents) + D_fwd_norm = self._project_in_basis(D_fwd_at_coords, basis_vector=normals) + D_adj_norm = self._project_in_basis(D_adj_at_coords, basis_vector=normals) + + E_fwd_perp1 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps1) + E_adj_perp1 = self._project_in_basis(E_adj_at_coords, basis_vector=perps1) + + E_fwd_perp2 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps2) + E_adj_perp2 = self._project_in_basis(E_adj_at_coords, basis_vector=perps2) + + D_der_norm = D_fwd_norm * D_adj_norm + E_der_perp1 = 
E_fwd_perp1 * E_adj_perp1 + E_der_perp2 = E_fwd_perp2 * E_adj_perp2 + + vjps = -delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps + + return vjps + + def _snap_spatial_coords_boundary( + self, + spatial_coords: ArrayFloat, + normals: ArrayFloat, + is_outside: bool, + data_array: ScalarFieldDataArray, + ) -> np.ndarray: + """Assuming a nearest interpolation, adjust the interpolation points given the grid + defined by `grid_centers` and using `spatial_coords` as a starting point such that we + select a point inside/outside the boundary depending on is_outside. + + *** (nearest point outside boundary) + ^ + | n (normal direction) + | + _.-~'`-._.-~'`-._ (boundary) + * (nearest point) + + Parameters + ---------- + spatial_coords : np.ndarray + (N, 3) array of surface evaluation points. + normals : np.ndarray + (N, 3) array of outward-pointing normal vectors at each surface point. + is_outside: bool + Indicator specifying if coordinates should be snapped inside or outside the boundary. + data_array: ScalarFieldDataArray + Data array to pull grid centers from when snapping coordinates. 
+
+        Returns
+        -------
+        np.ndarray
+            (N, 3) array of coordinate centers at which to interpolate such that they line up
+            with a grid center and are inside/outside the boundary
+        """
+        coords = data_array.coords
+        grid_centers = {key: np.array(coords[key].values) for key in coords}
+
+        grid_ddim = np.zeros_like(normals)
+        for idx, dim in enumerate("xyz"):
+            expanded_coords = np.expand_dims(spatial_coords[:, idx], axis=1)
+            grid_centers_select = grid_centers[dim]
+
+            diff = np.abs(expanded_coords - grid_centers_select)
+
+            nearest_grid = np.argmin(diff, axis=-1)
+            nearest_grid = np.minimum(np.maximum(nearest_grid, 1), len(grid_centers_select) - 1)
+
+            # compute the local grid spacing near the boundary
+            grid_ddim[:, idx] = (
+                grid_centers_select[nearest_grid] - grid_centers_select[nearest_grid - 1]
+            )
+
+        #
+        # Assuming we move in the normal direction, finds which dimension we need to move the least
+        # in order to ensure we snap to a point outside the boundary in the worst case (i.e. - the
+        # nearest point is just inside the surface)
+        #
+        # Cover for 2D cases using filter below:
+        # 2D case 1:
+        #   - in plane gradients where normal: [a, b, 0] and grid: [dx, dy, 0]
+        #   - want to rely on in plane normals for boundary snapping (filter on normal component = 0)
+        # 2D case 2:
+        #   - out of plane gradients where normal: [0, 0, 1] and grid: [dx, dy, 0]
+        #   - want to rely on out of plane normal (so do not want to filter on grid component = 0)
+        #   - data may not be captured out of plane, so no snapping will occur even with coords_dn = 0
+        #
+        small_number = np.finfo(normals.dtype).eps
+        coords_dn = np.min(
+            np.where(
+                (np.abs(normals) > small_number),
+                np.abs(grid_ddim) / (np.abs(normals) + small_number),
+                np.inf,
+            ),
+            axis=1,
+            keepdims=True,
+        )
+
+        # adjust coordinates by half a grid point outside boundary such that nearest interpolation
+        # point snaps to outside the boundary
+        normal_direction = 1.0 if is_outside else -1.0
+        adjust_spatial_coords = (
+
spatial_coords
+            + normal_direction * normals * config.adjoint.boundary_snapping_fraction * coords_dn
+        )
+
+        return adjust_spatial_coords
+
+    def _compute_edge_distance(
+        self,
+        spatial_coords: np.ndarray,
+        grid_centers: dict[str, np.ndarray],
+        adjust_spatial_coords: np.ndarray,
+    ) -> np.ndarray:
+        """Assuming nearest neighbor interpolation, computes the edge distance after interpolation when using the
+        adjust_spatial_coords computed from _snap_spatial_coords_boundary.
+
+        Parameters
+        ----------
+        spatial_coords : np.ndarray
+            (N, 3) array of surface evaluation points.
+        grid_centers: dict[str, np.ndarray]
+            The grid points for a given field component indexed by dimension. These grid points
+            are used to find the nearest snapping point and adjust the interpolation coordinates
+            to ensure we fall inside/outside of a boundary.
+        adjust_spatial_coords : np.ndarray
+            (N, 3) array of snapped interpolation coordinates produced by
+            `_snap_spatial_coords_boundary`.
+
+        Returns
+        -------
+        np.ndarray
+            (N,) array of distances from the nearest interpolation points to the desired surface
+            edge points specified by `spatial_coords`
+        """
+
+        edge_distance_squared_sum = np.zeros_like(adjust_spatial_coords[:, 0])
+        for idx, dim in enumerate("xyz"):
+            expanded_adjusted_coords = np.expand_dims(adjust_spatial_coords[:, idx], axis=1)
+            grid_centers_select = grid_centers[dim]
+
+            # find nearest grid point from the adjusted coordinates
+            diff = np.abs(expanded_adjusted_coords - grid_centers_select)
+            nearest_grid = np.argmin(diff, axis=-1)
+
+            # compute edge distance from the nearest interpolated point to the boundary edge
+            edge_distance_squared_sum += (
+                np.abs(spatial_coords[:, idx] - grid_centers_select[nearest_grid]) ** 2
+            )
+
+        # this edge distance is useful when correcting for edge singularities like those from a PEC
+        # material and is used when the PEC PolySlab structure has zero thickness, for example
+        edge_distance = np.sqrt(edge_distance_squared_sum)
+
+        return edge_distance
+
+    def
_detect_pec_gradient_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + eps_data: ScalarFieldDataArray, + interpolator: LazyInterpolator, + is_outside: bool, + ) -> np.ndarray: + def _detect_pec(eps_mask: np.ndarray) -> np.ndarray: + return 1.0 * (eps_mask < config.adjoint.pec_detection_threshold) + + adjusted_coords = self._snap_spatial_coords_boundary( + spatial_coords=spatial_coords, + normals=normals, + is_outside=is_outside, + data_array=eps_data, + ) + + eps_adjusted_all = [ + component_interpolator(adjusted_coords) + for _, component_interpolator in interpolator.items() + ] + eps_detect_pec = reduce(np.minimum, eps_adjusted_all) + + return _detect_pec(eps_detect_pec) + + def _evaluate_pec_gradient_at_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + perps1: np.ndarray, + perps2: np.ndarray, + interpolators: dict, + eps_dielectric: tuple[str, ScalarFieldDataArray], + is_outside: bool, + ) -> np.ndarray: + eps_dielectric_key, eps_dielectric_data = eps_dielectric + + def _snap_coordinate_outside( + field_components: FieldDataDict, + ) -> dict[str, dict[str, ArrayFloat]]: + """Helper function to perform coordinate adjustment and compute edge distance for each + component in `field_components`. + + Parameters + ---------- + field_components: FieldDataDict + The field components (i.e - Ex, Ey, Ez, Hx, Hy, Hz) that we would like to sample just + outside the PEC surface using nearest interpolation. + + Returns + ------- + dict[str, dict[str, np.ndarray]] + Dictionary mapping each field component name to a dictionary of adjusted coordinates + and edge distances for that component. 
+ """ + adjustment = {} + for name in field_components: + field_component = field_components[name] + field_component_coords = field_component.coords + + grid_centers = { + key: np.array(field_component_coords[key].values) + for key in field_component_coords + } + + adjusted_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=is_outside, + data_array=field_component, + ) + + edge_distance = self._compute_edge_distance( + spatial_coords=spatial_coords, + grid_centers=grid_centers, + adjust_spatial_coords=adjusted_coords, + ) + adjustment[name] = {"coords": adjusted_coords, "edge_distance": edge_distance} + + return adjustment + + def _interpolate_field_components( + interp_coords: dict[str, dict[str, ArrayFloat]], field_name: str + ) -> dict[str, ArrayComplex]: + return { + name: interp(interp_coords[name]["coords"]) + for name, interp in interpolators[field_name].items() + } + + # adjust coordinates for PEC to be outside structure bounds and get edge distance for singularity correction. 
+        E_fwd_coords_adjusted = _snap_coordinate_outside(self.E_fwd)
+        E_adj_coords_adjusted = _snap_coordinate_outside(self.E_adj)
+
+        H_fwd_coords_adjusted = _snap_coordinate_outside(self.H_fwd)
+        H_adj_coords_adjusted = _snap_coordinate_outside(self.H_adj)
+
+        # using the adjusted coordinates, evaluate all field components at surface points
+        E_fwd_at_coords = _interpolate_field_components(
+            E_fwd_coords_adjusted, field_name="E_fwd_nearest"
+        )
+        E_adj_at_coords = _interpolate_field_components(
+            E_adj_coords_adjusted, field_name="E_adj_nearest"
+        )
+        H_fwd_at_coords = _interpolate_field_components(
+            H_fwd_coords_adjusted, field_name="H_fwd_nearest"
+        )
+        H_adj_at_coords = _interpolate_field_components(
+            H_adj_coords_adjusted, field_name="H_adj_nearest"
+        )
+
+        eps_coords_adjusted = self._snap_spatial_coords_boundary(
+            spatial_coords,
+            normals,
+            is_outside=is_outside,
+            data_array=eps_dielectric_data,
+        )
+        eps_dielectric = interpolators[eps_dielectric_key](eps_coords_adjusted)
+
+        structure_sizes = np.array(
+            [self.bounds[1][idx] - self.bounds[0][idx] for idx in range(len(self.bounds[0]))]
+        )
+
+        is_flat_perp_dim1 = np.isclose(np.abs(np.sum(perps1[0] * structure_sizes)), 0.0)
+        is_flat_perp_dim2 = np.isclose(np.abs(np.sum(perps2[0] * structure_sizes)), 0.0)
+        flat_perp_dims = [is_flat_perp_dim1, is_flat_perp_dim2]
+
+        # check if this integration is happening along an edge in which case we will eliminate
+        # one of the H field integration components and apply singularity correction
+        pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2
+
+        def _compute_singularity_correction(
+            adjustment_: dict[str, dict[str, ArrayFloat]],
+        ) -> ArrayFloat:
+            """
+            Given the `adjustment_` which contains the distance from the PEC edge each field
+            component is nearest interpolated at, computes the singularity correction when
+            working with 2D PEC using the average edge_distance for each component.
In the case + of 3D PEC gradients, no singularity correction is applied so an array of ones is returned. + + Parameters + ---------- + adjustment_: dict[str, dict[str, np.ndarray]] + Dictionary that maps field component name to a dictionary containing the coordinate + adjustment and the distance to the PEC edge for those coordinates. The edge distance + is used for 2D PEC singularity correction. + + Returns + ------- + np.ndarray + Returns the singularity correction which has shape (N,) where there are N points in + `spatial_coords` + """ + return ( + ( + 0.5 + * np.pi + * np.mean([adjustment_[name]["edge_distance"] for name in adjustment_], axis=0) + ) + if pec_line_integration + else np.ones_like(spatial_coords, shape=spatial_coords.shape[0]) + ) + + E_norm_singularity_correction = np.expand_dims( + _compute_singularity_correction(E_fwd_coords_adjusted), axis=1 + ) + H_perp_singularity_correction = np.expand_dims( + _compute_singularity_correction(H_fwd_coords_adjusted), axis=1 + ) + + E_fwd_norm = self._project_in_basis(E_fwd_at_coords, basis_vector=normals) + E_adj_norm = self._project_in_basis(E_adj_at_coords, basis_vector=normals) + + # compute the normal E contribution to the gradient (the tangential E contribution + # is 0 in the case of PEC since this field component is continuous and thus 0 at + # the boundary) + contrib_E = E_norm_singularity_correction * eps_dielectric * E_fwd_norm * E_adj_norm + vjps = contrib_E + + # compute the tangential H contribution to the gradient (the normal H contribution + # is 0 for PEC) + H_fwd_perp1 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps1) + H_adj_perp1 = self._project_in_basis(H_adj_at_coords, basis_vector=perps1) + + H_fwd_perp2 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps2) + H_adj_perp2 = self._project_in_basis(H_adj_at_coords, basis_vector=perps2) + + H_der_perp1 = H_perp_singularity_correction * H_fwd_perp1 * H_adj_perp1 + H_der_perp2 = H_perp_singularity_correction * H_fwd_perp2 
* H_adj_perp2 + + H_integration_components = (H_der_perp1, H_der_perp2) + if pec_line_integration: + # if we are integrating along the line, we choose the H component normal to + # the edge which corresponds to a surface current along the edge whereas the other + # tangential component corresponds to a surface current along the flat dimension. + H_integration_components = tuple( + H_comp for idx, H_comp in enumerate(H_integration_components) if flat_perp_dims[idx] + ) + + # for each of the tangential components we are integrating the H fields over, + # adjust weighting to account for pre-weighting of the source by `EPSILON_0` + # and multiply by appropriate `MU_0` factor + for H_perp in H_integration_components: + contrib_H = MU_0 * H_perp / EPSILON_0 + vjps += contrib_H + + return vjps + + @staticmethod + def _project_in_basis( + field_components: dict[str, np.ndarray], + basis_vector: np.ndarray, + ) -> np.ndarray: + """Project 3D field components onto a basis vector. + + Parameters + ---------- + field_components : dict[str, np.ndarray] + Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values. + Values have shape (N, F) where F is the number of frequencies. + basis_vector : np.ndarray + (N, 3) array of basis vectors, one per evaluation point. + + Returns + ------- + np.ndarray + Projected field values with shape (N, F). + """ + prefix = next(iter(field_components.keys()))[0] + field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=0) + + # always expect (3, N, F) shape, transpose to (N, 3, F) + field_matrix = np.transpose(field_matrix, (1, 0, 2)) + return np.einsum("ij...,ij->i...", field_matrix, basis_vector) + + def project_der_map_to_axis( + self, axis: xyz, field_type: str = "E" + ) -> dict[str, ScalarFieldDataArray] | None: + """Return a copy of the selected derivative map with only one axis kept. + + Parameters + ---------- + axis: + Axis to keep (``"x"``, ``"y"``, ``"z"``, case-insensitive). 
+ field_type: + Map selector: ``"E"`` (``self.E_der_map``) or ``"D"`` (``self.D_der_map``). + + Returns + ------- + dict[str, ScalarFieldDataArray] | None + Copied map where non-selected components are replaced by zeros, or ``None`` + if the requested map is unavailable. + """ + field_map = {"E": self.E_der_map, "D": self.D_der_map}.get(field_type) + if field_map is None: + raise ValueError("field type must be 'D' or 'E'.") + + axis = axis.lower() + projected = dict(field_map) + if not field_map: + return projected + for dim in "xyz": + key = f"E{dim}" + if key not in field_map: + continue + if dim != axis: + projected[key] = xr.zeros_like(field_map[key]) + else: + projected[key] = field_map[key] + return projected + + @property + def min_spacing_from_permittivity(self) -> float: + if self.cached_min_spacing_from_permittivity is not None: + return self.cached_min_spacing_from_permittivity + + def spacing_by_permittivity(eps_array: ScalarFieldDataArray) -> float: + eps_real = np.asarray(eps_array.values, dtype=np.complex128).real + + dx_candidates = [] + max_frequency = np.max(self.frequencies) + + # wavelength-based sampling for dielectrics + if np.any(eps_real > 0): + eps_max = eps_real[eps_real > 0].max() + lambda_min = self.wavelength_min / np.sqrt(eps_max) + dx_candidates.append(lambda_min) + + # skin depth sampling for metals + if np.any(eps_real <= 0): + omega = 2 * np.pi * max_frequency + eps_neg = eps_real[eps_real <= 0] + delta_min = C_0 / (omega * np.sqrt(np.abs(eps_neg).max())) + dx_candidates.append(delta_min) + + computed_spacing = min(dx_candidates) + + return computed_spacing + + eps_spacings = [ + spacing_by_permittivity(eps_array) for _, eps_array in self.eps_data.items() + ] + min_spacing = np.min(eps_spacings) + + return min_spacing + + @contextmanager + def cache_min_spacing_from_permittivity(self) -> Iterator[None]: + """ + Cache min_spacing_from_permittivity for the duration of the block. Cache + is always cleared on exit. 
+ """ + + self.cached_min_spacing_from_permittivity = self.min_spacing_from_permittivity + try: + yield + finally: + self.cached_min_spacing_from_permittivity = None + + def adaptive_vjp_spacing( + self, + wl_fraction: Optional[float] = None, + min_allowed_spacing_fraction: Optional[float] = None, + ) -> float: + """Compute adaptive spacing for finite-difference gradient evaluation. + + Determines an appropriate spatial resolution based on the material + properties and electromagnetic wavelength/skin depth. + + Parameters + ---------- + wl_fraction : float, optional + Fraction of wavelength/skin depth to use as spacing. Defaults to the configured + ``autograd.default_wavelength_fraction`` when ``None``. + min_allowed_spacing_fraction : float, optional + Minimum allowed spacing fraction of free space wavelength used to + prevent numerical issues. Defaults to ``config.adjoint.minimum_spacing_fraction`` + when not specified. + + Returns + ------- + float + Adaptive spacing value for gradient evaluation. + """ + if wl_fraction is None or min_allowed_spacing_fraction is None: + from tidy3d._common.config import config + + if wl_fraction is None: + wl_fraction = config.adjoint.default_wavelength_fraction + if min_allowed_spacing_fraction is None: + min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction + + computed_spacing = wl_fraction * self.min_spacing_from_permittivity + + min_allowed_spacing = self.wavelength_min * min_allowed_spacing_fraction + + if computed_spacing < min_allowed_spacing: + log.warning( + f"Based on the material, the adaptive spacing for integrating the polyslab surface " + f"would be {computed_spacing:.3e} μm. 
The spacing has been clipped to {min_allowed_spacing:.3e} μm " + f"to prevent a performance degradation.", + log_once=True, + ) + + return max(computed_spacing, min_allowed_spacing) + + @property + def wavelength_min(self) -> float: + return C_0 / np.max(self.frequencies) + + @property + def wavelength_max(self) -> float: + return C_0 / np.min(self.frequencies) + + +def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: + """Integrate a data array within specified spatial bounds. + + Clips the integration domain to the specified bounds and performs + numerical integration using the trapezoidal rule. + + Parameters + ---------- + arr : xr.DataArray + Data array to integrate. + dims : list[str] + Dimensions to integrate over (e.g., ['x', 'y', 'z']). + bounds : Bound + Integration bounds as [[xmin, ymin, zmin], [xmax, ymax, zmax]]. + + Returns + ------- + xr.DataArray + Result of integration with specified dimensions removed. + + Notes + ----- + - Coordinates outside bounds are clipped, effectively setting dL=0 + - Only integrates dimensions with more than one coordinate point + - Uses xarray's integrate method (trapezoidal rule) + """ + bounds = np.asarray(bounds).T + all_coords = {} + + for dim, (bmin, bmax) in zip(dims, bounds): + bmin = get_static(bmin) + bmax = get_static(bmax) + + # clip coordinates to bounds (sets dL=0 outside bounds) + coord_values = arr.coords[dim].data + all_coords[dim] = np.clip(coord_values, bmin, bmax) + + _arr = arr.assign_coords(**all_coords) + + # only integrate dimensions with multiple points + dims_integrate = [dim for dim in dims if len(_arr.coords[dim]) > 1] + return _arr.integrate(coord=dims_integrate) + + +__all__ = [ + "DerivativeInfo", + "integrate_within_bounds", +] diff --git a/tidy3d/_common/components/autograd/field_map.py b/tidy3d/_common/components/autograd/field_map.py new file mode 100644 index 0000000000..159a7d1527 --- /dev/null +++ 
b/tidy3d/_common/components/autograd/field_map.py @@ -0,0 +1,77 @@ +"""Typed containers for autograd traced field metadata.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Union + +from pydantic import Field + +from tidy3d._common.components.autograd.types import TracedArrayLike, TracedComplex, TracedFloat +from tidy3d._common.components.base import Tidy3dBaseModel + +if TYPE_CHECKING: + from typing import Callable + + from tidy3d._common.components.autograd.types import AutogradFieldMap + + +class Tracer(Tidy3dBaseModel): + """Representation of a single traced element within a model.""" + + path: tuple[Any, ...] = Field( + title="Path to the traced object in the model dictionary.", + ) + data: Union[TracedFloat, TracedComplex, TracedArrayLike] = Field(title="Tracing data") + + +class FieldMap(Tidy3dBaseModel): + """Collection of traced elements.""" + + tracers: tuple[Tracer, ...] = Field( + title="Collection of Tracers.", + ) + + @property + def to_autograd_field_map(self) -> AutogradFieldMap: + """Convert to ``AutogradFieldMap`` autograd dictionary.""" + return {tracer.path: tracer.data for tracer in self.tracers} + + @classmethod + def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap: + """Initialize from an ``AutogradFieldMap`` autograd dictionary.""" + tracers = [] + for path, data in autograd_field_map.items(): + tracers.append(Tracer(path=path, data=data)) + return cls(tracers=tuple(tracers)) + + +def _encoded_path(path: tuple[Any, ...]) -> str: + """Return a stable JSON representation for a traced path.""" + return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True) + + +class TracerKeys(Tidy3dBaseModel): + """Collection of traced field paths.""" + + keys: tuple[tuple[Any, ...], ...] 
= Field( + title="Collection of tracer keys.", + ) + + def encoded_keys(self) -> list[str]: + """Return the JSON-encoded representation of keys.""" + return [_encoded_path(path) for path in self.keys] + + @classmethod + def from_field_mapping( + cls, + field_mapping: AutogradFieldMap, + *, + sort_key: Callable[[tuple[Any, ...]], str] | None = None, + ) -> TracerKeys: + """Construct keys from an autograd field mapping.""" + if sort_key is None: + sort_key = _encoded_path + + sorted_paths = tuple(sorted(field_mapping.keys(), key=sort_key)) + return cls(keys=sorted_paths) diff --git a/tidy3d/_common/components/autograd/functions.py b/tidy3d/_common/components/autograd/functions.py new file mode 100644 index 0000000000..86beaec421 --- /dev/null +++ b/tidy3d/_common/components/autograd/functions.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +import numpy as np +from autograd.extend import defjvp, defvjp, primitive +from autograd.numpy.numpy_jvps import broadcast +from autograd.numpy.numpy_vjps import unbroadcast_f + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from tidy3d._common.components.autograd.types import InterpolationType + + +def _evaluate_nearest( + indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] +) -> NDArray[np.float64]: + """Perform nearest neighbor interpolation in an n-dimensional space. + + This function determines the nearest neighbor in a grid for a given point + and returns the corresponding value from the input array. + + Parameters + ---------- + indices : np.ndarray[np.int64] + Indices of the lower bounds of the grid cell containing the interpolation point. + norm_distances : np.ndarray[np.float64] + Normalized distances from the lower bounds of the grid cell to the + interpolation point, for each dimension. 
+ values : np.ndarray[np.float64] + The n-dimensional array of values to interpolate from. + + Returns + ------- + np.ndarray[np.float64] + The value of the nearest neighbor to the interpolation point. + """ + idx_res = tuple(anp.where(yi <= 0.5, i, i + 1) for i, yi in zip(indices, norm_distances)) + return values[idx_res] + + +def _evaluate_linear( + indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] +) -> NDArray[np.float64]: + """Perform linear interpolation in an n-dimensional space. + + This function calculates the linearly interpolated value at a point in an + n-dimensional grid, given the indices of the surrounding grid points and + the normalized distances to these points. + The multi-linear interpolation is implemented by computing a weighted + average of the values at the vertices of the hypercube surrounding the + interpolation point. + + Parameters + ---------- + indices : np.ndarray[np.int64] + Indices of the lower bounds of the grid cell containing the interpolation point. + norm_distances : np.ndarray[np.float64] + Normalized distances from the lower bounds of the grid cell to the + interpolation point, for each dimension. + values : np.ndarray[np.float64] + The n-dimensional array of values to interpolate from. + + Returns + ------- + np.ndarray[np.float64] + The interpolated value at the desired point. 
+ """ + # Create a slice object for broadcasting over trailing dimensions + _slice = (slice(None),) + (None,) * (values.ndim - len(indices)) + + # Prepare iterables for lower and upper bounds of the hypercube + ix = zip(indices, (1 - yi for yi in norm_distances)) + iy = zip((i + 1 for i in indices), norm_distances) + + # Initialize the result + value = anp.zeros(1) + + # Iterate over all vertices of the hypercube + for h in itertools.product(*zip(ix, iy)): + edge_indices, weights = zip(*h) + + # Compute the weight for this vertex + weight = anp.ones(1) + for w in weights: + weight = weight * w + + # Compute the contribution of this vertex and add it to the result + term = values[edge_indices] * weight[_slice] + value = value + term + + return value + + +def interpn( + points: tuple[NDArray[np.float64], ...], + values: NDArray[np.float64], + xi: tuple[NDArray[np.float64], ...], + *, + method: InterpolationType = "linear", + **kwargs: Any, +) -> NDArray[np.float64]: + """Interpolate over a rectilinear grid in arbitrary dimensions. + + This function mirrors the interface of `scipy.interpolate.interpn` but is differentiable with autograd. + + Parameters + ---------- + points : tuple[np.ndarray[np.float64], ...] + The points defining the rectilinear grid in n dimensions. + values : np.ndarray[np.float64] + The data values on the rectilinear grid. + xi : tuple[np.ndarray[np.float64], ...] + The coordinates to sample the gridded data at. + method : InterpolationType = "linear" + The method of interpolation to perform. Supported are "linear" and "nearest". + + Returns + ------- + np.ndarray[np.float64] + The interpolated values. + + Raises + ------ + ValueError + If the interpolation method is not supported. 
+ + See Also + -------- + `scipy.interpolate.interpn `_ + """ + from scipy.interpolate import RegularGridInterpolator + + if method == "nearest": + interp_fn = _evaluate_nearest + elif method == "linear": + interp_fn = _evaluate_linear + else: + raise ValueError(f"Unsupported interpolation method: {method}") + + # Avoid SciPy coercing autograd ArrayBox values during _check_values. + dummy_values = np.zeros(np.shape(values), dtype=float) + if kwargs.get("fill_value") == "extrapolate": + itrp = RegularGridInterpolator( + points, dummy_values, method=method, fill_value=None, bounds_error=False + ) + else: + itrp = RegularGridInterpolator(points, dummy_values, method=method) + + # Prepare the grid for interpolation + # This step reshapes the grid, checks for NaNs and out-of-bounds values + # It returns: + # - reshaped grid + # - original shape + # - number of dimensions + # - boolean array indicating NaN positions + # - (discarded) boolean array for out-of-bounds values + xi, shape, ndim, nans, _ = itrp._prepare_xi(xi) + + # Find the indices of the grid cells containing the interpolation points + # and calculate the normalized distances (ranging from 0 at lower grid point to 1 + # at upper grid point) within these cells + indices, norm_distances = itrp._find_indices(xi.T) + + result = interp_fn(indices, norm_distances, values) + nans = anp.reshape(nans, (-1,) + (1,) * (result.ndim - 1)) + result = anp.where(nans, np.nan, result) + return anp.reshape(result, shape[:-1] + values.shape[ndim:]) + + +def trapz(y: NDArray, x: NDArray = None, dx: float = 1.0, axis: int = -1) -> float: + """ + Integrate along the given axis using the composite trapezoidal rule. + + Parameters + ---------- + y : np.ndarray + Input array to integrate. + x : np.ndarray = None + The sample points corresponding to the y values. If None, the sample points are assumed to be evenly spaced + with spacing `dx`. + dx : float = 1.0 + The spacing between sample points when `x` is None. Default is 1.0. 
+ axis : int = -1 + The axis along which to integrate. Default is the last axis. + + Returns + ------- + float + Definite integral as approximated by the trapezoidal rule. + """ + if x is None: + d = dx + elif x.ndim == 1: + d = np.diff(x) + shape = [1] * y.ndim + shape[axis] = d.shape[0] + d = np.reshape(d, shape) + else: + d = np.diff(x, axis=axis) + + slice1 = [slice(None)] * y.ndim + slice2 = [slice(None)] * y.ndim + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + + return anp.sum((y[tuple(slice1)] + y[tuple(slice2)]) * d / 2, axis=axis) + + +@primitive +def _add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: + """ + Add values to specified indices of an array. + + Autograd requires that arguments to primitives are passed in positionally. + ``add_at`` is the public-facing wrapper for this function, + which allows keyword arguments in case users pass in kwargs. + """ + out = np.copy(x) # Copy to preserve 'x' for gradient computation + out[tuple(indices_x)] += y + return out + + +defvjp( + _add_at, + lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g), + lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)], + argnums=(0, 2), +) + +defjvp( + _add_at, + lambda g, ans, x, indices_x, y: broadcast(g, ans), + lambda g, ans, x, indices_x, y: _add_at(anp.zeros_like(ans), indices_x, g), + argnums=(0, 2), +) + + +def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: + """ + Add values to specified indices of an array. + + This function creates a copy of the input array `x`, adds the values from `y` to the specified + indices `indices_x`, and returns the modified array. + + Parameters + ---------- + x : np.ndarray + Input array to which values will be added. + indices_x : tuple + Indices of `x` where values from `y` will be added. + y : np.ndarray + Values to add to the specified indices of `x`. + + Returns + ------- + np.ndarray + The modified array with values added at the specified indices. 
+ """ + return _add_at(x, indices_x, y) + + +@primitive +def _straight_through_clip(x: NDArray, a_min: Any, a_max: Any) -> NDArray: + """Passthrough clip can be used to preserve gradients at the endpoints of the clip range where + there is a discontinuity in the derivative. This is useful when values are at the endpoints but may + have a gradient away from the boundary or in cases where numerical precision causes a function that is + typically bounded by the clip bounds to produce a value just outside the bounds. In the forward pass, + this runs the standard clip.""" + return anp.clip(x, a_min=a_min, a_max=a_max) + + +def _straight_through_clip_vjp(ans: Any, x: NDArray, a_min: Any, a_max: Any) -> NDArray: + """Preserve original gradient information in the backward pass up until a tolerance beyond the clip bounds.""" + tolerance = 1e-5 + mask = (x >= a_min - tolerance) & (x <= a_max + tolerance) + return lambda g: g * mask + + +defvjp(_straight_through_clip, _straight_through_clip_vjp) + +__all__ = [ + "add_at", + "interpn", + "trapz", +] diff --git a/tidy3d/_common/components/autograd/types.py b/tidy3d/_common/components/autograd/types.py new file mode 100644 index 0000000000..baea29e1fc --- /dev/null +++ b/tidy3d/_common/components/autograd/types.py @@ -0,0 +1,136 @@ +# type information for autograd + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, get_origin + +import autograd.numpy as anp +from autograd.builtins import dict as TracedDict +from autograd.extend import Box, defvjp, primitive +from autograd.numpy.numpy_boxes import ArrayBox +from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter + +from tidy3d._common.components.autograd.utils import get_static, hasbox +from tidy3d._common.components.types.base import ( + ArrayFloat2D, + ArrayLike, + Complex, + Size1D, + _auto_serializer, +) +from tidy3d._common.components.types.utils import _add_schema + +if 
TYPE_CHECKING: + from typing import Optional + + from pydantic import SerializationInfo + + from tidy3d._common.compat import TypeAlias + +# add schema to the Box +_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") +_add_schema(ArrayBox, title="AutogradArrayBox", field_type_str="autograd.numpy.ArrayBox") + +# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph +_copy = primitive(copy.copy) +_deepcopy = primitive(copy.deepcopy) + +defvjp(_copy, lambda ans, x: lambda g: _copy(g)) +defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) + +Box.__copy__ = lambda v: _copy(v) +Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) +Box.__str__ = lambda self: f"{self._value} <{type(self).__name__}>" +Box.__repr__ = Box.__str__ + + +def traced_alias(base_alias: Any, *, name: Optional[str] = None) -> TypeAlias: + base_adapter = TypeAdapter(base_alias, config={"arbitrary_types_allowed": True}) + + def _validate_box_or_container(v: Any) -> Any: + # case 1: v itself is a tracer + # in this case we just validate but leave the tracer untouched + if isinstance(v, Box): + base_adapter.validate_python(get_static(v)) + return v + + # case 2: v is a plain container that contains at least one tracer + # in this case we try to coerce into ArrayBox for one-shot validation, + # but always return the original v, and fall back to a structural walk if needed + if hasbox(v): + # decide whether we must return an array + origin = get_origin(base_alias) + is_array_field = base_alias in (ArrayLike, ArrayFloat2D) or origin is None + + if is_array_field: + dense = anp.array(v) + base_adapter.validate_python(get_static(dense)) + return dense + + # otherwise it's a Python container type + # try the fast-path array validation, but return the array so ops work + try: + dense = anp.array(v) + base_adapter.validate_python(get_static(dense)) + return dense + + except Exception: + # ragged/un-coercible -> rebuild container of 
Boxes + if isinstance(v, tuple): + return tuple(_validate_box_or_container(x) for x in v) + if isinstance(v, list): + return [_validate_box_or_container(x) for x in v] + if isinstance(v, dict): + return {k: _validate_box_or_container(x) for k, x in v.items()} + # fallback: can't handle this structure + raise + + return base_adapter.validate_python(v) + + def _serialize_traced(a: Any, info: SerializationInfo) -> Any: + return _auto_serializer(get_static(a), info) + + return Annotated[ + object, + BeforeValidator(_validate_box_or_container), + PlainSerializer(_serialize_traced, when_used="json"), + ] + + +# "primitive" types that can use traced_alias +TracedArrayLike = traced_alias(ArrayLike) +TracedArrayFloat2D = traced_alias(ArrayFloat2D) +TracedFloat = traced_alias(float) +TracedPositiveFloat = traced_alias(PositiveFloat) +TracedComplex = traced_alias(Complex) +TracedSize1D = traced_alias(Size1D) + +# derived traced types (these mirror the types in `components.types`) +TracedSize = tuple[TracedSize1D, TracedSize1D, TracedSize1D] +TracedCoordinate = tuple[TracedFloat, TracedFloat, TracedFloat] +TracedPoleAndResidue = tuple[TracedComplex, TracedComplex] +TracedPolesAndResidues = tuple[TracedPoleAndResidue, ...] + +# The data type that we pass in and out of the web.run() @autograd.primitive +PathType = tuple[Union[int, str], ...] 
+AutogradFieldMap = TracedDict[PathType, Box] + +InterpolationType = Literal["nearest", "linear"] + +__all__ = [ + "AutogradFieldMap", + "InterpolationType", + "PathType", + "TracedArrayFloat2D", + "TracedArrayLike", + "TracedComplex", + "TracedCoordinate", + "TracedDict", + "TracedFloat", + "TracedPoleAndResidue", + "TracedPolesAndResidues", + "TracedPositiveFloat", + "TracedSize", + "TracedSize1D", +] diff --git a/tidy3d/_common/components/autograd/utils.py b/tidy3d/_common/components/autograd/utils.py new file mode 100644 index 0000000000..76c13b583f --- /dev/null +++ b/tidy3d/_common/components/autograd/utils.py @@ -0,0 +1,84 @@ +# utilities for working with autograd +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +from autograd.tracer import getval, isbox + +if TYPE_CHECKING: + from typing import Union + + from autograd.numpy.numpy_boxes import ArrayBox + from numpy.typing import ArrayLike, NDArray + +__all__ = [ + "asarray1d", + "contains", + "get_static", + "hasbox", + "is_tidy_box", + "pack_complex_vec", + "split_list", +] + + +def get_static(item: Any) -> Any: + """ + Get the 'static' (untraced) version of some value by recursively calling getval + on Box instances within a nested structure. 
+ """ + if isbox(item): + return getval(item) + elif isinstance(item, list): + return [get_static(x) for x in item] + elif isinstance(item, tuple): + return tuple(get_static(x) for x in item) + elif isinstance(item, dict): + return {k: get_static(v) for k, v in item.items()} + return item + + +def split_list(x: list[Any], index: int) -> tuple[list, list]: + """Split a list at a given index.""" + x = list(x) + return x[:index], x[index:] + + +def is_tidy_box(x: Any) -> bool: + """Check if a value is a tidy box.""" + return getattr(x, "_tidy", False) + + +def contains(target: Any, seq: Iterable[Any]) -> bool: + """Return ``True`` if target occurs anywhere within arbitrarily nested iterables.""" + for x in seq: + if x == target: + return True + if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): + if contains(target, x): + return True + return False + + +def hasbox(obj: Any) -> bool: + """True if any element inside obj is an autograd Box.""" + if isbox(obj): + return True + if isinstance(obj, Mapping): + return any(hasbox(v) for v in obj.values()) + if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)): + return any(hasbox(i) for i in obj) + return False + + +def pack_complex_vec(z: Union[NDArray, ArrayBox]) -> Union[NDArray, ArrayBox]: + """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" + return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) + + +def asarray1d(x: Union[ArrayLike, ArrayBox]) -> Union[NDArray, ArrayBox]: + """Autograd-friendly 1D flatten: returns ndarray of shape (-1,).""" + x = anp.array(x) + return x if x.ndim == 1 else anp.ravel(x) diff --git a/tidy3d/_common/components/base.py b/tidy3d/_common/components/base.py new file mode 100644 index 0000000000..a3f370749f --- /dev/null +++ b/tidy3d/_common/components/base.py @@ -0,0 +1,1895 @@ +"""global configuration / base class for pydantic models used to make simulation.""" + +from __future__ import annotations + +import hashlib +import io +import 
json +import math +import os +import tempfile +import typing as _t +from collections import defaultdict +from collections.abc import Mapping, Sequence +from functools import total_ordering, wraps +from math import ceil +from os import PathLike +from pathlib import Path +from types import UnionType +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, get_args, get_origin + +import h5py +import numpy as np +import rich +import xarray as xr +import yaml +from autograd.builtins import dict as TracedDict +from autograd.numpy.numpy_boxes import ArrayBox +from autograd.tracer import isbox +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator + +from tidy3d._common.components.autograd.utils import get_static +from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP +from tidy3d._common.components.file_util import compress_file_to_gzip, extract_gzip_file +from tidy3d._common.components.types.base import TYPE_TAG_STR, Undefined +from tidy3d._common.exceptions import FileError +from tidy3d._common.log import log + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Callable + + from pydantic.fields import FieldInfo + from pydantic.functional_validators import ModelWrapValidatorHandler + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd.types import AutogradFieldMap + + +INDENT_JSON_FILE = 4 # default indentation of json string in json files +INDENT = None # default indentation of json string used internally +JSON_TAG = "JSON_STRING" +# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5 +MAX_STRING_LENGTH = 1_000_000_000 +FORBID_SPECIAL_CHARACTERS = ["/"] +TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__" +TYPE_TO_CLASS_MAP: dict[str, type[Tidy3dBaseModel]] = {} + +_CacheReturn = TypeVar("_CacheReturn") + + +def cache(prop: Callable[[Any], _CacheReturn]) -> Callable[[Any], _CacheReturn]: + 
"""Decorates a property to cache the first computed value and return it on subsequent calls.""" + + # note, we could also just use `prop` as dict key, but hashing property might be slow + prop_name = prop.__name__ + + @wraps(prop) + def cached_property_getter(self: Any) -> _CacheReturn: + """The new property method to be returned by decorator.""" + + stored_value = self._cached_properties.get(prop_name) + + if stored_value is not None: + return stored_value + + computed_value = prop(self) + self._cached_properties[prop_name] = computed_value + return computed_value + + return cached_property_getter + + +def cached_property(cached_property_getter: Callable[[Any], _CacheReturn]) -> property: + """Shortcut for property(cache()) of a getter.""" + + return property(cache(cached_property_getter)) + + +_GuardedReturn = TypeVar("_GuardedReturn") + + +def cached_property_guarded( + key_func: Callable[[Any], Any], +) -> Callable[[Callable[[Any], _GuardedReturn]], property]: + """Like cached_property, but invalidates when the key_func(self) changes.""" + + def _decorator(getter: Callable[[Any], _GuardedReturn]) -> property: + prop_name = getter.__name__ + + @wraps(getter) + def _guarded(self: Any) -> _GuardedReturn: + cache_store = self._cached_properties.get(prop_name) + current_key = key_func(self) + if cache_store is not None: + cached_key, cached_value = cache_store + if cached_key == current_key: + return cached_value + value = getter(self) + self._cached_properties[prop_name] = (current_key, value) + return value + + return property(_guarded) + + return _decorator + + +def make_json_compatible(json_string: str) -> str: + """Makes the string compatible with json standards, notably for infinity.""" + + tmp_string = "<>" + json_string = json_string.replace("-Infinity", tmp_string) + json_string = json_string.replace('""-Infinity""', tmp_string) + json_string = json_string.replace("Infinity", '"Infinity"') + json_string = json_string.replace('""Infinity""', '"Infinity"') + 
return json_string.replace(tmp_string, '"-Infinity"') + + +def _get_valid_extension(fname: PathLike) -> str: + """Return the file extension from fname, validated to accepted ones.""" + valid_extensions = [".json", ".yaml", ".hdf5", ".h5", ".hdf5.gz"] + path = Path(fname) + extensions = [s.lower() for s in path.suffixes[-2:]] + if len(extensions) == 0: + raise FileError(f"File '{path}' missing extension.") + single_extension = extensions[-1] + if single_extension in valid_extensions: + return single_extension + double_extension = "".join(extensions) + if double_extension in valid_extensions: + return double_extension + raise FileError( + f"File extension must be one of {', '.join(valid_extensions)}; file '{path}' does not " + "match any of those." + ) + + +def _fmt_ann_literal(ann: Any) -> str: + """Spell the annotation exactly as written.""" + if ann is None: + return "Any" + if isinstance(ann, _t._GenericAlias): + return str(ann).replace("typing.", "") + return ann.__name__ if hasattr(ann, "__name__") else str(ann) + + +T = TypeVar("T", bound="Tidy3dBaseModel") + + +def field_allows_scalar(field: FieldInfo) -> bool: + annotation = field.annotation + + def allows_scalar(a: Any) -> bool: + origin = get_origin(a) + if origin in (Union, UnionType): + args = (arg for arg in get_args(a) if arg is not type(None)) + return any(allows_scalar(arg) for arg in args) + if origin is not None: + return False + return isinstance(a, type) and issubclass(a, (float, int, np.generic)) + + return allows_scalar(annotation) + + +@total_ordering +class Tidy3dBaseModel(BaseModel): + """Base pydantic model that all Tidy3d components inherit from. + Defines configuration for handling data structures + as well as methods for importing, exporting, and hashing tidy3d objects. 
+ For more details on pydantic base models, see: + `Pydantic models `_ + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + defer_build=True, + validate_default=True, + populate_by_name=True, + ser_json_inf_nan="strings", + extra="forbid", + frozen=True, + ) + + attrs: dict = Field( + default_factory=dict, + title="Attributes", + description="Dictionary storing arbitrary metadata for a Tidy3D object. " + "This dictionary can be freely used by the user for storing data without affecting the " + "operation of Tidy3D as it is not used internally. " + "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " + "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " + "Also note that Tidy3D will raise a ``TypeError`` if ``attrs`` contain objects " + "that can not be serialized. One can check if ``attrs`` are serializable " + "by calling ``obj.model_dump_json()``.", + ) + + _cached_properties: dict = PrivateAttr(default_factory=dict) + _has_tracers: Optional[bool] = PrivateAttr(default=None) + + @field_validator("name", check_fields=False) + @classmethod + def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> Optional[str]: + if name is None: + return name + for character in FORBID_SPECIAL_CHARACTERS: + if character in name: + raise ValueError( + f"Special character '{character}' not allowed in component name {name}." + ) + return name + + def __init_subclass__(cls: type[T], **kwargs: Any) -> None: + """Injects a constant discriminator field before Pydantic builds the model. + + Adds + type: Literal[""] = "" + to every concrete subclass so it can participate in a + `Field(discriminator="type")` union without manual boilerplate. + + Must run *before* `super().__init_subclass__()`; that call lets Pydantic + see the injected field during its normal schema/validator generation. 
+ See also: https://peps.python.org/pep-0487/ + """ + tag = cls.__name__ + cls.__annotations__[TYPE_TAG_STR] = Literal[tag] + setattr(cls, TYPE_TAG_STR, tag) + TYPE_TO_CLASS_MAP[tag] = cls + + if "__tidy3d_end_capture__" not in cls.__dict__: + + @model_validator(mode="after") + def __tidy3d_end_capture__(self: T) -> T: + if log._capture: + log.end_capture(self) + return self + + cls.__tidy3d_end_capture__ = __tidy3d_end_capture__ + + super().__init_subclass__(**kwargs) + + @classmethod + def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None: + super().__pydantic_init_subclass__(**kwargs) + + # add docstring once pydantic is done constructing the class + cls.__doc__ = cls.generate_docstring() + + @model_validator(mode="wrap") + @classmethod + def _capture_validation_warnings( + cls: type[T], + data: Any, + handler: ModelWrapValidatorHandler[T], + ) -> T: + if not log._capture: + return handler(data) + + log.begin_capture() + try: + return handler(data) + except Exception: + log.abort_capture() + raise + + def __hash__(self) -> int: + """Hash method.""" + return self._recursive_hash(self) + + @staticmethod + def _recursive_hash(value: Any) -> int: + # Handle Autograd ArrayBoxes + if isinstance(value, ArrayBox): + # Unwrap the underlying numpy array and recurse + return Tidy3dBaseModel._recursive_hash(value._value) + if isinstance(value, np.ndarray): + # numpy arrays are not hashable by default, use byte representation + v_hash = hashlib.md5(value.tobytes()).hexdigest() + return hash(v_hash) + if isinstance(value, (xr.DataArray, xr.Dataset)): + # we choose to not hash data arrays as this would require a lot of careful handling of units, metadata. 
+ # technically this is incorrect, but should never lead to bugs in current implementation + return hash(str(value.__class__.__name__)) + if isinstance(value, str): + # this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below + return hash(value) + if isinstance(value, Sequence): + # this assumes all objects in lists are hashable by default and do not require special handling + v_hash = tuple([Tidy3dBaseModel._recursive_hash(vi) for vi in value]) + return hash(v_hash) + if isinstance(value, dict): + to_hash_list = [] + for k, v in value.items(): + v_hash = Tidy3dBaseModel._recursive_hash(v) + to_hash_list.append((k, v_hash)) + return hash(tuple(to_hash_list)) + if isinstance(value, Tidy3dBaseModel): + # This function needs to take special care because of mutable attributes inside of frozen pydantic models + to_hash_list = [] + for k in type(value).model_fields: + if k == "attrs": + continue + v_hash = Tidy3dBaseModel._recursive_hash(getattr(value, k)) + to_hash_list.append((k, v_hash)) + extra = getattr(value, "__pydantic_extra__", None) + if extra: + for k, v in extra.items(): + v_hash = Tidy3dBaseModel._recursive_hash(v) + to_hash_list.append((k, v_hash)) + # attrs is mutable, use serialized output as safe hashing option + if value.attrs: + attrs_str = value._attrs_digest() + attrs_hash = hash(attrs_str) + to_hash_list.append(("attrs", attrs_hash)) + return hash(tuple(to_hash_list)) + return hash(value) + + def _hash_self(self) -> str: + """Hash this component with ``hashlib`` in a way that is the same every session.""" + bf = io.BytesIO() + self.to_hdf5(bf) + return hashlib.md5(bf.getvalue()).hexdigest() + + @model_validator(mode="before") + @classmethod + def coerce_numpy_scalars_for_model(cls, data: Any) -> Any: + """ + coerce numpy scalars / size-1 arrays to native Python + scalars, but only for fields whose annotations allow scalars. 
+ """ + if not isinstance(data, dict): + return data + + for name, field in cls.model_fields.items(): + if name not in data or not field_allows_scalar(field): + continue + + v = data[name] + if isinstance(v, np.generic) or (isinstance(v, np.ndarray) and v.size == 1): + data[name] = v.item() + + return data + + @classmethod + def _get_type_value(cls, obj: dict[str, Any]) -> str: + """Return the type tag from a raw dictionary.""" + if not isinstance(obj, dict): + raise TypeError("Input must be a dict") + try: + type_value = obj[TYPE_TAG_STR] + except KeyError as exc: + raise ValueError(f'Missing "{TYPE_TAG_STR}" in data') from exc + if not isinstance(type_value, str) or not type_value: + raise ValueError(f'Invalid "{TYPE_TAG_STR}" value: {type_value!r}') + return type_value + + @classmethod + def _get_registered_class(cls, type_value: str) -> type[Tidy3dBaseModel]: + try: + return TYPE_TO_CLASS_MAP[type_value] + except KeyError as exc: + raise ValueError(f"Unknown type: {type_value}") from exc + + @classmethod + def _should_dispatch_to(cls, target_cls: type[Tidy3dBaseModel]) -> bool: + """Return True if ``cls`` allows auto-dispatch to ``target_cls``.""" + return issubclass(target_cls, cls) + + @classmethod + def _resolve_dispatch_target(cls, obj: dict[str, Any]) -> type[Tidy3dBaseModel]: + """Determine which subclass should receive ``obj``.""" + type_value = cls._get_type_value(obj) + target_cls = cls._get_registered_class(type_value) + if cls._should_dispatch_to(target_cls): + return target_cls + if target_cls is cls: + return cls + raise ValueError( + f'Cannot parse type "{type_value}" using {cls.__name__}; expected subclass of {cls.__name__}.' 
+ ) + + @classmethod + def _target_cls_from_file( + cls, fname: PathLike, group_path: Optional[str] = None + ) -> type[Tidy3dBaseModel]: + """Peek the file metadata to determine the subclass to instantiate.""" + model_dict = cls.dict_from_file( + fname=fname, + group_path=group_path, + load_data_arrays=False, + ) + return cls._resolve_dispatch_target(model_dict) + + @classmethod + def _model_validate(cls, obj: dict[str, Any], **parse_obj_kwargs: Any) -> Tidy3dBaseModel: + """Dispatch ``obj`` to the correct subclass registered in the type map.""" + target_cls = cls._resolve_dispatch_target(obj) + if target_cls is cls: + return super().model_validate(obj, **parse_obj_kwargs) + return target_cls.model_validate(obj, **parse_obj_kwargs) + + @classmethod + def _validate_model_dict( + cls, model_dict: dict[str, Any], **parse_obj_kwargs: Any + ) -> Tidy3dBaseModel: + """Parse ``model_dict`` while optionally auto-dispatching when called on the base class.""" + if cls is Tidy3dBaseModel: + return cls._model_validate(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) + + def _preprocess_update_values(self, update: Mapping[str, Any]) -> dict[str, Any]: + """Preprocess update values to convert lists to tuples where appropriate. + + This helps avoid Pydantic v2 serialization warnings when using `model_copy()` + with list values for tuple fields. + """ + if not update: + return {} + + def get_tuple_element_type(annotation: Any) -> Optional[type]: + """Get the element type of a tuple annotation if it has one consistent type.""" + origin = get_origin(annotation) + if origin is tuple: + args = get_args(annotation) + if args: + # Check if it's a homogeneous tuple like tuple[bool, ...] or tuple[str, ...] 
+ if len(args) == 2 and args[1] is ...: + return args[0] + # Check if all elements have the same type + if all(arg == args[0] for arg in args): + return args[0] + return None + + def should_convert_to_tuple(annotation: Any) -> tuple[bool, Optional[type[Any]]]: + """Check if the given annotation represents a tuple type and return element type if any.""" + origin = get_origin(annotation) + + if origin is tuple: + return True, get_tuple_element_type(annotation) + + # Union types containing tuple + if origin is Union: + args = get_args(annotation) + for arg in args: + if get_origin(arg) is tuple: + return True, get_tuple_element_type(arg) + + return False, None + + def convert_value(value: Any, field_info: FieldInfo) -> Any: + """Convert value based on field type information.""" + annotation = field_info.annotation + + # Handle list/tuple to tuple conversion with proper element types + is_tuple, element_type = should_convert_to_tuple(annotation) + + # Check if value is a numpy array and needs to be converted to tuple + try: + import numpy as np + + if isinstance(value, np.ndarray) and is_tuple: + # Convert numpy array to list first + value = value.tolist() + except ImportError: + pass + + # Handle autograd SequenceBox - convert to tuple + if ( + is_tuple + and hasattr(value, "__class__") + and value.__class__.__name__ == "SequenceBox" + ): + # SequenceBox is iterable, so convert it to tuple + return tuple(value) + + if isinstance(value, (list, tuple)) and is_tuple: + # Convert elements based on element type + if element_type is bool: + # Convert integers to booleans + value = [bool(item) if isinstance(item, int) else item for item in value] + elif element_type is str: + # Ensure all elements are strings + value = [str(item) if not isinstance(item, str) else item for item in value] + else: + # Check if it's a numpy array or contains numpy types + try: + import numpy as np + + if any(isinstance(item, np.generic) for item in value): + # Convert numpy types to Python types 
+ value = [ + item.item() if isinstance(item, np.generic) else item + for item in value + ] + except ImportError: + pass + return tuple(value) + + # Handle int to bool conversion + if annotation is bool and isinstance(value, int): + return bool(value) + + # Handle dict to Tidy3dBaseModel conversion + if isinstance(value, dict): + # Check if the annotation is a Tidy3dBaseModel subclass + origin = get_origin(annotation) + if origin is None: + # Not a generic type, check if it's a direct subclass + try: + if isinstance(annotation, type) and issubclass(annotation, Tidy3dBaseModel): + return annotation(**value) + except (TypeError, AttributeError): + pass + elif origin is Union: + # For Union types, try to convert to the first matching Tidy3dBaseModel type + args = get_args(annotation) + for arg in args: + try: + if isinstance(arg, type) and issubclass(arg, Tidy3dBaseModel): + return arg(**value) + except (TypeError, AttributeError, ValueError): + continue + + return value + + processed = {} + for field_name, value in update.items(): + if field_name in type(self).model_fields: + field_info = type(self).model_fields[field_name] + processed[field_name] = convert_value(value, field_info) + else: + processed[field_name] = value + + return processed + + def copy( + self, + deep: bool = True, + *, + validate: bool = True, + update: Optional[Mapping[str, Any]] = None, + ) -> Self: + """Return a copy of the model. + + Parameters + ---------- + deep : bool = True + Whether to make a deep copy first (same as v1). + validate : bool = True + If ``True``, run full Pydantic validation on the copied data. + update : Optional[Mapping[str, Any]] = None + Optional mapping of fields to overwrite (passed straight + through to ``model_copy(update=...)``). 
+ """ + if update and self.model_config.get("extra") == "forbid": + invalid = set(update) - set(type(self).model_fields) + if invalid: + raise KeyError(f"'{self.type}' received invalid fields on copy: {invalid}") + + # preprocess update values to convert lists to tuples where appropriate + if update: + update = self._preprocess_update_values(update) + + new_model = self.model_copy(deep=deep, update=update) + + if validate: + return self.__class__.model_validate(new_model.model_dump()) + else: + # make sure cache is always cleared + new_model._cached_properties = {} + + new_model._has_tracers = None + return new_model + + def updated_copy( + self, + path: Optional[str] = None, + *, + deep: bool = True, + validate: bool = True, + **kwargs: Any, + ) -> Self: + """Make copy of a component instance with ``**kwargs`` indicating updated field values. + + Note + ---- + If ``path`` is supplied, applies the updated copy with the update performed on the sub- + component corresponding to the path. For indexing into a tuple or list, use the integer + value. + + Example + ------- + >>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP + """ + if not path: + return self.copy(deep=deep, validate=validate, update=kwargs) + + path_parts = path.split("/") + field_name, *rest = path_parts + + try: + sub_component = getattr(self, field_name) + except AttributeError as exc: + raise AttributeError( + f"Could not find field '{field_name}' in path '{path}'. " + f"Available top-level fields: {tuple(type(self).model_fields)}." + ) from exc + + if isinstance(sub_component, (list, tuple)): + try: + index = int(rest[0]) + except (IndexError, ValueError): + raise ValueError( + f"Expected integer index into '{field_name}' in path '{path}'." 
+ ) from None + sub_component_list = list(sub_component) + sub_component_list[index] = sub_component_list[index].updated_copy( + path="/".join(rest[1:]), + deep=deep, + validate=validate, + **kwargs, + ) + new_value = type(sub_component)(sub_component_list) + else: + new_value = sub_component.updated_copy( + path="/".join(rest), + deep=deep, + validate=validate, + **kwargs, + ) + + return self.copy(deep=deep, validate=validate, update={field_name: new_value}) + + @staticmethod + def _core_model_traversal( + current_obj: Any, current_path_segments: tuple[str, ...] + ) -> Iterator[tuple[Self, tuple[str, ...]]]: + """ + Recursively traverses a model structure yielding Tidy3dBaseModel instances and their paths. + + This is an internal helper method used by :meth:`find_paths` and :meth:`find_submodels` + to navigate nested :class:`Tidy3dBaseModel` structures. + + Parameters + ---------- + current_obj : Any + The current object in the traversal, which can be a :class:`Tidy3dBaseModel`, + list, tuple, or other type. + current_path_segments : tuple[str, ...] + A tuple of strings representing the path segments from the initial model + to the ``current_obj``. + + Returns + ------- + Iterator[tuple[Self, tuple[str, ...]]] + An iterator yielding tuples, where the first element is a found :class:`Tidy3dBaseModel` instance + and the second is a tuple of strings representing the path to that instance + from the initial object. The path for the top-level model itself will be an empty tuple. 
+ """ + if isinstance(current_obj, Tidy3dBaseModel): + yield current_obj, current_path_segments + + for field_name in type(current_obj).model_fields: + if ( + field_name == "type" + and getattr(current_obj, field_name, None) == current_obj.__class__.__name__ + ): + continue + + field_value = getattr(current_obj, field_name) + yield from Tidy3dBaseModel._core_model_traversal( + field_value, (*current_path_segments, field_name) + ) + elif isinstance(current_obj, (list, tuple)): + for index, item in enumerate(current_obj): + yield from Tidy3dBaseModel._core_model_traversal( + item, (*current_path_segments, str(index)) + ) + + def find_paths(self, target_field_name: str, target_field_value: Any = Undefined) -> list[str]: + """ + Finds paths to nested model instances that have a specific field, optionally matching a value. + + The paths are string representations like ``"structures/0/geometry"``, designed for direct + use with the :meth:`updated_copy` method to modify specific parts of this model. + An empty string ``""`` in the returned list indicates that this model instance + itself (the one ``find_paths`` is called on) matches the criteria. + + Parameters + ---------- + target_field_name : str + The name of the attribute (field) to search for within nested + :class:`Tidy3dBaseModel` instances. For example, ``"name"`` or ``"permittivity"``. + target_field_value : Any, optional + If provided, only paths to model instances where ``target_field_name`` also has this + specific value will be returned. If omitted, paths are returned if the + ``target_field_name`` exists, regardless of its value. + + Returns + ------- + list[str] + A sorted list of unique string paths. Each path points to a + :class:`Tidy3dBaseModel` instance that possesses the ``target_field_name`` + (and optionally matches ``target_field_value``). 
+ + Example + ------- + >>> # Assume 'sim' is a Tidy3D simulation object + >>> # Find all geometries named "waveguide" + >>> paths = sim.find_paths(target_field_name="name", target_field_value="waveguide") # doctest: +SKIP + >>> # paths might be ['structures/0', 'structures/3'] + >>> # Update the size of the first found "waveguide" + >>> new_sim = sim.updated_copy(path=paths[0], size=(1.0, 0.5, 0.22)) # doctest: +SKIP + """ + found_paths_set = set() + + for sub_model_instance, path_segments_to_sub_model in Tidy3dBaseModel._core_model_traversal( + self, () + ): + if target_field_name in type(sub_model_instance).model_fields: + passes_value_filter = True + if target_field_value is not Undefined: + actual_value = getattr(sub_model_instance, target_field_name) + if actual_value != target_field_value: + passes_value_filter = False + + if passes_value_filter: + path_str = "/".join(path_segments_to_sub_model) + found_paths_set.add(path_str) + + return sorted(found_paths_set) + + def find_submodels(self, target_type: Self) -> list[Self]: + """ + Finds all unique nested instances of a specific Tidy3D model type within this model. + + This method traverses the model structure and collects all instances that are of + the ``target_type`` (e.g., :class:`~tidy3d.Structure`, :class:`~tidy3d.Medium`, + :class:`~tidy3d.Box`). + Uniqueness is determined by the model's content. The order of models + in the returned list corresponds to their first encounter during a depth-first traversal. + + Parameters + ---------- + target_type : Tidy3dBaseModel + The specific Tidy3D class (e.g., ``Structure``, ``Medium``, ``Box``) to search for. + This class must be a subclass of :class:`Tidy3dBaseModel`. + + Returns + ------- + list[Tidy3dBaseModel] + A list of unique instances found within this model that are of the + provided ``target_type``. 
+ + Example + ------- + >>> # Assume 'sim' is a Tidy3D Simulation object + >>> # Find all Structure instances within the simulation + >>> all_structures = sim.find_submodels(td.Structure) # doctest: +SKIP + >>> for struct in all_structures: + ... print(f"Structure: {struct.name}, medium: {struct.medium}") # doctest: +SKIP + + >>> # Find all Box geometries within the simulation + >>> all_boxes = sim.find_submodels(td.Box) # doctest: +SKIP + >>> for box in all_boxes: + ... print(f"Found Box with size: {box.size}") # doctest: +SKIP + + >>> # Find all Medium instances (useful for checking materials) + >>> all_media = sim.find_submodels(td.Medium) # doctest: +SKIP + >>> # Note: This would find td.Medium instances, but not td.PECMedium or td.PoleResidue + >>> # unless they inherit directly from td.Medium and not just Tidy3dBaseModel or td.AbstractMedium. + >>> # To find all medium types, one might search for td.AbstractMedium if that's a common base. + """ + found_models_dict = {} + + for sub_model_candidate, _ in Tidy3dBaseModel._core_model_traversal(self, ()): + if isinstance(sub_model_candidate, target_type): + if sub_model_candidate not in found_models_dict: + found_models_dict[sub_model_candidate] = True + + return list(found_models_dict.keys()) + + def help(self, methods: bool = False) -> None: + """Prints message describing the fields and methods of a :class:`Tidy3dBaseModel`. + + Parameters + ---------- + methods : bool = False + Whether to also print out information about object's methods. + + Example + ------- + >>> simulation.help(methods=True) # doctest: +SKIP + """ + rich.inspect(type(self), methods=methods) + + @classmethod + def from_file( + cls, + fname: PathLike, + group_path: Optional[str] = None, + lazy: bool = False, + on_load: Optional[Callable[[Any], None]] = None, + **parse_obj_kwargs: Any, + ) -> Self: + """Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file. 
        Parameters
        ----------
        fname : PathLike
            Full path to the file to load the :class:`Tidy3dBaseModel` from.
        group_path : Optional[str] = None
            Path to a group inside the file to use as the base level. Only for hdf5 files.
            Starting `/` is optional.
        lazy : bool = False
            Whether to load the actual data (``lazy=False``) or return a proxy that loads
            the data when accessed (``lazy=True``).
        on_load : Optional[Callable[[Any], None]] = None
            Callback function executed once the model is fully materialized.
            Invoked with the loaded instance as its sole argument (in both eager and
            lazy modes), enabling post-processing such as validation, logging, or
            warnings checks.
        **parse_obj_kwargs
            Keyword arguments passed to pydantic's ``model_validate`` method when loading model.

        Returns
        -------
        Self
            An instance of the component class calling ``load``.

        Example
        -------
        >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP
        """
        if lazy:
            # Peek at the file metadata only to determine the concrete subclass, then
            # return a proxy that defers the actual data load until first access.
            target_cls = cls._target_cls_from_file(fname=fname, group_path=group_path)
            Proxy = _make_lazy_proxy(target_cls, on_load=on_load)
            return Proxy(fname, group_path, parse_obj_kwargs)
        model_dict = cls.dict_from_file(fname=fname, group_path=group_path)
        obj = cls._validate_model_dict(model_dict, **parse_obj_kwargs)
        # NOTE(review): 'not lazy' is always True here (the lazy branch returned above).
        if not lazy and on_load is not None:
            on_load(obj)
        return obj

    @classmethod
    def dict_from_file(
        cls: type[T],
        fname: PathLike,
        group_path: Optional[str] = None,
        *,
        load_data_arrays: bool = True,
    ) -> dict:
        """Loads a dictionary containing the model from a .yaml, .json, .hdf5, or .hdf5.gz file.

        Parameters
        ----------
        fname : PathLike
            Full path to the file to load the :class:`Tidy3dBaseModel` from.
        group_path : str, optional
            Path to a group inside the file to use as the base level.
        load_data_arrays : bool = True
            Whether to also load ``DataArray`` contents (hdf5 formats only).

        Returns
        -------
        dict
            A dictionary containing the model.
+ + Example + ------- + >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP + """ + fname_path = Path(fname) + extension = _get_valid_extension(fname_path) + kwargs = {"fname": fname_path} + + if group_path is not None: + if extension in {".hdf5", ".hdf5.gz", ".h5"}: + kwargs["group_path"] = group_path + else: + log.warning("'group_path' provided, but this feature only works with hdf5 files.") + + if extension in {".hdf5", ".hdf5.gz", ".h5"}: + kwargs["load_data_arrays"] = load_data_arrays + + converter = { + ".json": cls.dict_from_json, + ".yaml": cls.dict_from_yaml, + ".hdf5": cls.dict_from_hdf5, + ".hdf5.gz": cls.dict_from_hdf5_gz, + ".h5": cls.dict_from_hdf5, + }[extension] + return converter(**kwargs) + + def to_file(self, fname: PathLike) -> None: + """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, or .hdf5 file + + Parameters + ---------- + fname : PathLike + Full path to the .yaml or .json file to save the :class:`Tidy3dBaseModel` to. + + Example + ------- + >>> simulation.to_file(fname='folder/sim.json') # doctest: +SKIP + """ + extension = _get_valid_extension(fname) + converter = { + ".json": self.to_json, + ".yaml": self.to_yaml, + ".hdf5": self.to_hdf5, + ".hdf5.gz": self.to_hdf5_gz, + }[extension] + return converter(fname=fname) + + @classmethod + def from_json(cls: type[T], fname: PathLike, **model_validate_kwargs: Any) -> Self: + """Load a :class:`Tidy3dBaseModel` from .json file. + + Parameters + ---------- + fname : PathLike + Full path to the .json file to load the :class:`Tidy3dBaseModel` from. + + Returns + ------- + Self + An instance of the component class calling `load`. + **model_validate_kwargs + Keyword arguments passed to pydantic's ``model_validate`` method. 
+ + Example + ------- + >>> simulation = Simulation.from_json(fname='folder/sim.json') # doctest: +SKIP + """ + model_dict = cls.dict_from_json(fname=fname) + return cls._validate_model_dict(model_dict, **model_validate_kwargs) + + @classmethod + def dict_from_json(cls: type[T], fname: PathLike) -> dict: + """Load dictionary of the model from a .json file. + + Parameters + ---------- + fname : PathLike + Full path to the .json file to load the :class:`Tidy3dBaseModel` from. + + Returns + ------- + dict + A dictionary containing the model. + + Example + ------- + >>> sim_dict = Simulation.dict_from_json(fname='folder/sim.json') # doctest: +SKIP + """ + with open(fname, encoding="utf-8") as json_fhandle: + model_dict = json.load(json_fhandle) + return model_dict + + def to_json(self, fname: PathLike) -> None: + """Exports :class:`Tidy3dBaseModel` instance to .json file + + Parameters + ---------- + fname : PathLike + Full path to the .json file to save the :class:`Tidy3dBaseModel` to. + + Example + ------- + >>> simulation.to_json(fname='folder/sim.json') # doctest: +SKIP + """ + export_model = self.to_static() + json_string = export_model.model_dump_json(indent=INDENT_JSON_FILE) + self._warn_if_contains_data(json_string) + path = Path(fname) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as file_handle: + file_handle.write(json_string) + + @classmethod + def from_yaml(cls: type[T], fname: PathLike, **model_validate_kwargs: Any) -> Self: + """Loads :class:`Tidy3dBaseModel` from .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. + **model_validate_kwargs + Keyword arguments passed to pydantic's ``model_validate`` method. + + Returns + ------- + Self + An instance of the component class calling `from_yaml`. 
+ + Example + ------- + >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + model_dict = cls.dict_from_yaml(fname=fname) + return cls._validate_model_dict(model_dict, **model_validate_kwargs) + + @classmethod + def dict_from_yaml(cls: type[T], fname: PathLike) -> dict: + """Load dictionary of the model from a .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. + + Returns + ------- + dict + A dictionary containing the model. + + Example + ------- + >>> sim_dict = Simulation.dict_from_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + with open(fname, encoding="utf-8") as yaml_in: + model_dict = yaml.safe_load(yaml_in) + return model_dict + + def to_yaml(self, fname: PathLike) -> None: + """Exports :class:`Tidy3dBaseModel` instance to .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to save the :class:`Tidy3dBaseModel` to. + + Example + ------- + >>> simulation.to_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + export_model = self.to_static() + # We intentionally round-trip through JSON to preserve the exact JSON-mode serialization + # behavior in YAML output (notably `ser_json_inf_nan="strings"` for Infinity/-Infinity/NaN). + json_string = export_model.model_dump_json() + self._warn_if_contains_data(json_string) + model_dict = json.loads(json_string) + path = Path(fname) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w+", encoding="utf-8") as file_handle: + yaml.dump(model_dict, file_handle, indent=INDENT_JSON_FILE) + + @staticmethod + def _warn_if_contains_data(json_str: str) -> None: + """Log a warning if the json string contains data, used in '.json' and '.yaml' file.""" + if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())): + log.warning( + "Data contents found in the model to be written to file. 
" + "Note that this data will not be included in '.json' or '.yaml' formats. " + "As a result, it will not be possible to load the file back to the original model. " + "Instead, use '.hdf5' extension in filename passed to 'to_file()'." + ) + + @staticmethod + def _construct_group_path(group_path: str) -> str: + """Construct a group path with the leading forward slash if not supplied.""" + + # empty string or None + if not group_path: + return "/" + + # missing leading forward slash + if group_path[0] != "/": + return f"/{group_path}" + + return group_path + + @staticmethod + def get_tuple_group_name(index: int) -> str: + """Get the group name of a tuple element.""" + return str(int(index)) + + @staticmethod + def get_tuple_index(key_name: str) -> int: + """Get the index into the tuple based on its group name.""" + return int(str(key_name)) + + @classmethod + def tuple_to_dict(cls: type[T], tuple_values: tuple) -> dict: + """How we generate a dictionary mapping new keys to tuple values for hdf5.""" + return {cls.get_tuple_group_name(index=i): val for i, val in enumerate(tuple_values)} + + @classmethod + def get_sub_model( + cls: type[T], group_path: str, model_dict: Union[dict[str, Any], list[Any]] + ) -> dict: + """Get the sub model for a given group path.""" + + for key in group_path.split("/"): + if key: + if isinstance(model_dict, list): + tuple_index = cls.get_tuple_index(key_name=key) + model_dict = model_dict[tuple_index] + else: + model_dict = model_dict[key] + return model_dict + + @staticmethod + def _json_string_key(index: int) -> str: + """Get json string key for string chunk number ``index``.""" + if index: + return f"{JSON_TAG}_{index}" + return JSON_TAG + + @classmethod + def _json_string_from_hdf5(cls: type[T], fname: PathLike) -> str: + """Load the model json string from an hdf5 file.""" + with h5py.File(fname, "r") as f_handle: + num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) + json_string = b"" + for ind in 
range(num_string_parts): + json_string += f_handle[cls._json_string_key(ind)][()] + return json_string + + @classmethod + def dict_from_hdf5( + cls: type[T], + fname: PathLike, + group_path: str = "", + custom_decoders: Optional[list[Callable]] = None, + load_data_arrays: bool = True, + ) -> dict: + """Loads a dictionary containing the model contents from a .hdf5 file. + + Parameters + ---------- + fname : PathLike + Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. + group_path : str, optional + Path to a group inside the file to selectively load a sub-element of the model only. + custom_decoders : List[Callable] + List of functions accepting + (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the + value in the model dict after a custom decoding. + + Returns + ------- + dict + Dictionary containing the model. + + Example + ------- + >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP + """ + + def is_data_array(value: Any) -> bool: + """Whether a value is supposed to be a data array based on the contents.""" + return isinstance(value, str) and value in DATA_ARRAY_MAP + + fname_path = Path(fname) + + def load_data_from_file(model_dict: dict, group_path: str = "") -> None: + """For every DataArray item in dictionary, load path of hdf5 group as value.""" + + for key, value in model_dict.items(): + subpath = f"{group_path}/{key}" + + # apply custom validation to the key value pair and modify model_dict + if custom_decoders: + for custom_decoder in custom_decoders: + custom_decoder( + fname=str(fname_path), + group_path=subpath, + model_dict=model_dict, + key=key, + value=value, + ) + + # write the path to the element of the json dict where the data_array should be + if is_data_array(value): + data_array_type = DATA_ARRAY_MAP[value] + model_dict[key] = data_array_type.from_hdf5( + fname=fname_path, group_path=subpath + ) + continue + + # if a list, assign each element a unique key, 
recurse
                if isinstance(value, (list, tuple)):
                    value_dict = cls.tuple_to_dict(tuple_values=value)
                    load_data_from_file(model_dict=value_dict, group_path=subpath)

                    # handle case of nested list of DataArray elements
                    val_tuple = list(value_dict.values())
                    for ind, (model_item, value_item) in enumerate(zip(model_dict[key], val_tuple)):
                        if is_data_array(model_item):
                            model_dict[key][ind] = value_item

                # if a dict, recurse
                elif isinstance(value, dict):
                    load_data_from_file(model_dict=value, group_path=subpath)

        # parse the JSON blob stored alongside the data, then narrow to the requested sub-model
        model_dict = json.loads(cls._json_string_from_hdf5(fname=fname_path))
        group_path = cls._construct_group_path(group_path)
        model_dict = cls.get_sub_model(group_path=group_path, model_dict=model_dict)
        if load_data_arrays:
            load_data_from_file(model_dict=model_dict, group_path=group_path)
        return model_dict

    @classmethod
    def from_hdf5(
        cls: type[T],
        fname: PathLike,
        group_path: str = "",
        custom_decoders: Optional[list[Callable]] = None,
        **model_validate_kwargs: Any,
    ) -> Self:
        """Loads :class:`Tidy3dBaseModel` instance from a .hdf5 file.

        Parameters
        ----------
        fname : PathLike
            Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from.
        group_path : str, optional
            Path to a group inside the file to selectively load a sub-element of the model only.
            Starting `/` is optional.
        custom_decoders : List[Callable]
            List of functions accepting
            (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the
            value in the model dict after a custom decoding.
        **model_validate_kwargs
            Keyword arguments passed to pydantic's ``model_validate`` method.

        Example
        -------
        >>> simulation = Simulation.from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP
        """

        group_path = cls._construct_group_path(group_path)
        model_dict = cls.dict_from_hdf5(
            fname=fname,
            group_path=group_path,
            custom_decoders=custom_decoders,
        )
        return cls._validate_model_dict(model_dict, **model_validate_kwargs)

    def to_hdf5(
        self,
        fname: Union[PathLike, io.BytesIO],
        custom_encoders: Optional[list[Callable]] = None,
    ) -> None:
        """Exports :class:`Tidy3dBaseModel` instance to .hdf5 file.

        Parameters
        ----------
        fname : Union[PathLike, BytesIO]
            Full path to the .hdf5 file or buffer to save the :class:`Tidy3dBaseModel` to.
        custom_encoders : List[Callable]
            List of functions accepting (fname: str, group_path: str, value: Any) that take
            the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``.

        Example
        -------
        >>> simulation.to_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP
        """

        # serialize a tracer-free copy; tracers are written separately via their keys payload
        export_model = self.to_static()
        traced_keys_payload = export_model.attrs.get(TRACED_FIELD_KEYS_ATTR)

        # fall back to this (possibly traced) model's attrs, then to a fresh serialization
        if traced_keys_payload is None:
            traced_keys_payload = self.attrs.get(TRACED_FIELD_KEYS_ATTR)
        if traced_keys_payload is None:
            traced_keys_payload = self._serialized_traced_field_keys()
        path = Path(fname) if isinstance(fname, PathLike) else fname
        with h5py.File(path, "w") as f_handle:
            # the JSON string is chunked because hdf5 limits the size of a single string entry
            json_str = export_model.model_dump_json()
            for ind in range(ceil(len(json_str) / MAX_STRING_LENGTH)):
                ind_start = int(ind * MAX_STRING_LENGTH)
                ind_stop = min(int(ind + 1) * MAX_STRING_LENGTH, len(json_str))
                f_handle[self._json_string_key(ind)] = json_str[ind_start:ind_stop]

            def add_data_to_file(data_dict: dict, group_path: str = "") -> None:
                """For every DataArray item in dictionary, write path of hdf5 group as value."""

                for key, value in data_dict.items():
                    # append the key to the path
                    subpath = f"{group_path}/{key}"

                    if custom_encoders:
                        for custom_encoder in custom_encoders:
                            custom_encoder(fname=f_handle, group_path=subpath, value=value)

                    # write the path to the element of the json dict where the data_array should be
                    if isinstance(value, xr.DataArray):
                        value.to_hdf5(fname=f_handle, group_path=subpath)

                    # if a tuple, assign each element a unique key
                    if isinstance(value, (list, tuple)):
                        value_dict = export_model.tuple_to_dict(tuple_values=value)
                        add_data_to_file(data_dict=value_dict, group_path=subpath)

                    # if a dict, recurse
                    elif isinstance(value, dict):
                        add_data_to_file(data_dict=value, group_path=subpath)

            add_data_to_file(data_dict=export_model.model_dump())
            if traced_keys_payload:
                f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload

    @classmethod
    def dict_from_hdf5_gz(
        cls: type[T],
        fname: PathLike,
        group_path: str = "",
        custom_decoders: Optional[list[Callable]] = None,
        load_data_arrays: bool = True,
    ) -> dict:
        """Loads a dictionary containing the model contents from a .hdf5.gz file.

        Parameters
        ----------
        fname : PathLike
            Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from.
        group_path : str, optional
            Path to a group inside the file to selectively load a sub-element of the model only.
        custom_decoders : List[Callable]
            List of functions accepting
            (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the
            value in the model dict after a custom decoding.

        Returns
        -------
        dict
            Dictionary containing the model.

        Example
        -------
        >>> sim_dict = Simulation.dict_from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP
        """
        # decompress into a temporary .hdf5 file, read it, and always clean it up
        file_descriptor, extracted = tempfile.mkstemp(".hdf5")
        os.close(file_descriptor)
        extracted_path = Path(extracted)
        try:
            extract_gzip_file(fname, extracted_path)
            result = cls.dict_from_hdf5(
                extracted_path,
                group_path=group_path,
                custom_decoders=custom_decoders,
                load_data_arrays=load_data_arrays,
            )
        finally:
            extracted_path.unlink(missing_ok=True)

        return result

    @classmethod
    def from_hdf5_gz(
        cls: type[T],
        fname: PathLike,
        group_path: str = "",
        custom_decoders: Optional[list[Callable]] = None,
        **model_validate_kwargs: Any,
    ) -> Self:
        """Loads :class:`Tidy3dBaseModel` instance from a .hdf5.gz file.

        Parameters
        ----------
        fname : PathLike
            Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from.
        group_path : str, optional
            Path to a group inside the file to selectively load a sub-element of the model only.
            Starting `/` is optional.
        custom_decoders : List[Callable]
            List of functions accepting
            (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the
            value in the model dict after a custom decoding.
        **model_validate_kwargs
            Keyword arguments passed to pydantic's ``model_validate`` method.

        Example
        -------
        >>> simulation = Simulation.from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP
        """

        group_path = cls._construct_group_path(group_path)
        model_dict = cls.dict_from_hdf5_gz(
            fname=fname,
            group_path=group_path,
            custom_decoders=custom_decoders,
        )
        return cls._validate_model_dict(model_dict, **model_validate_kwargs)

    def to_hdf5_gz(
        self,
        fname: Union[PathLike, io.BytesIO],
        custom_encoders: Optional[list[Callable]] = None,
    ) -> None:
        """Exports :class:`Tidy3dBaseModel` instance to .hdf5.gz file.
+ + Parameters + ---------- + fname : Union[PathLike, BytesIO] + Full path to the .hdf5.gz file or buffer to save the :class:`Tidy3dBaseModel` to. + custom_encoders : List[Callable] + List of functions accepting (fname: str, group_path: str, value: Any) that take + the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. + + Example + ------- + >>> simulation.to_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP + """ + file, decompressed = tempfile.mkstemp(".hdf5") + os.close(file) + try: + self.to_hdf5(decompressed, custom_encoders=custom_encoders) + compress_file_to_gzip(decompressed, fname) + finally: + os.unlink(decompressed) + + def __lt__(self, other: object) -> bool: + """define < for getting unique indices based on hash.""" + return hash(self) < hash(other) + + def __eq__(self, other: object) -> bool: + """Two models are equal when origins match and every public or extra field matches.""" + if not isinstance(other, BaseModel): + return NotImplemented + + self_origin = ( + getattr(self, "__pydantic_generic_metadata__", {}).get("origin") or self.__class__ + ) + other_origin = ( + getattr(other, "__pydantic_generic_metadata__", {}).get("origin") or other.__class__ + ) + if self_origin is not other_origin: + return False + + if getattr(self, "__pydantic_extra__", None) != getattr(other, "__pydantic_extra__", None): + return False + + def _fields_equal(a: Any, b: Any) -> bool: + a = get_static(a) + b = get_static(b) + + if a is b: + return True + if type(a) is not type(b): + if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))): + return False + if isinstance(a, np.ndarray): + return np.array_equal(a, b) + if isinstance(a, (xr.DataArray, xr.Dataset)): + return a.equals(b) + if isinstance(a, Mapping): + if a.keys() != b.keys(): + return False + return all(_fields_equal(a[k], b[k]) for k in a) + if isinstance(a, Sequence) and not isinstance(a, (str, bytes)): + if len(a) != len(b): + return False + return all(_fields_equal(x, 
y) for i, (x, y) in enumerate(zip(a, b))) + if isinstance(a, float) and isinstance(b, float) and np.isnan(a) and np.isnan(b): + return True + return a == b + + for name in type(self).model_fields: + if not _fields_equal(getattr(self, name), getattr(other, name)): + return False + + return True + + def _attrs_digest(self) -> str: + """Stable digest of `attrs` using the same JSON encoding rules as `model_dump_json()`.""" + # encoders = getattr(self.__config__, "json_encoders", {}) or {} + + # def _default(o): + # return custom_pydantic_encoder(encoders, o) + + json_str = json.dumps( + self.attrs, + # default=_default, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ) + json_str = make_json_compatible(json_str) + + return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + + @cached_property_guarded(lambda self: self._attrs_digest()) + def _json_string(self) -> str: + """Returns string representation of a :class:`Tidy3dBaseModel`. + + Returns + ------- + str + Json-formatted string holding :class:`Tidy3dBaseModel` data. + """ + return self.model_dump_json(indent=INDENT, exclude_unset=False) + + def _strip_traced_fields( + self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False + ) -> AutogradFieldMap: + """Extract a dictionary mapping paths in the model to the data traced by ``autograd``. + + Parameters + ---------- + starting_path : tuple[str, ...] = () + If provided, starts recursing in self.model_dump() from this path of field names + include_untraced_data_arrays : bool = False + Whether to include ``DataArray`` objects without tracers. + We need to include these when returning data, but are unnecessary for structures. 

        Returns
        -------
        dict
            mapping of traced fields used by ``autograd``

        """

        path = tuple(starting_path)
        # cached negative result: skip the walk entirely unless untraced arrays are requested
        if self._has_tracers is False and not include_untraced_data_arrays:
            return TracedDict()

        field_mapping = {}

        def handle_value(x: Any, path: tuple[str, ...]) -> None:
            """recursively update ``field_mapping`` with path to the autograd data."""

            # this is a leaf node that we want to trace, add this path and data to the mapping
            if isbox(x):
                field_mapping[path] = x

            # for data arrays, need to be more careful as their tracers are stored in .data
            elif isinstance(x, xr.DataArray):
                data = x.data
                if isbox(data) or any(isbox(el) for el in np.asarray(data).ravel()):
                    field_mapping[path] = x.data
                elif include_untraced_data_arrays:
                    field_mapping[path] = x.data

            # for sequences, add (i,) to the path and handle each value individually
            elif isinstance(x, (list, tuple)):
                for i, val in enumerate(x):
                    handle_value(val, path=(*path, i))

            # for dictionaries, add the (key,) to the path and handle each value individually
            elif isinstance(x, dict):
                for key, val in x.items():
                    handle_value(val, path=(*path, key))

        # recursively parse the dictionary of this object
        self_dict = self.model_dump(round_trip=True)

        # if an include_only string was provided, only look at that subset of the dict
        if path:
            for key in path:
                self_dict = self_dict[key]

        handle_value(self_dict, path=path)

        if field_mapping:
            # only cache the positive result when no untraced arrays were folded in
            if not include_untraced_data_arrays:
                self._has_tracers = True
            return TracedDict(field_mapping)

        # only cache a negative result for a full (non-subset) scan
        if not include_untraced_data_arrays and not path:
            self._has_tracers = False
        return TracedDict()

    def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
        """Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
        self_dict = self.model_dump(round_trip=True)

        def insert_value(x: Any, path: tuple[str, ...], sub_dict: dict[str, Any]) -> None:
            """Insert a value into the path into a dictionary."""
            current_dict = sub_dict
            for key in path[:-1]:
                # tuples are immutable; convert to list so the path can be written through
                if isinstance(current_dict[key], tuple):
                    current_dict[key] = list(current_dict[key])
                current_dict = current_dict[key]

            final_key = path[-1]
            if isinstance(current_dict[final_key], tuple):
                current_dict[final_key] = list(current_dict[final_key])

            sub_element = current_dict[final_key]
            if isinstance(sub_element, xr.DataArray):
                # keep coords/attrs, swap in the traced data without a deep copy
                current_dict[final_key] = sub_element.copy(deep=False, data=x)

            else:
                current_dict[final_key] = x

        for path, value in field_mapping.items():
            insert_value(value, path=path, sub_dict=self_dict)

        return self.__class__.model_validate(self_dict)

    def _serialized_traced_field_keys(
        self, field_mapping: Optional[AutogradFieldMap] = None
    ) -> Optional[str]:
        """Return a serialized, order-independent representation of traced field paths."""

        if field_mapping is None:
            field_mapping = self._strip_traced_fields()
        if not field_mapping:
            return None

        # TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel.
        from tidy3d._common.components.autograd.field_map import TracerKeys

        tracer_keys = TracerKeys.from_field_mapping(field_mapping)
        return tracer_keys.model_dump_json()

    def to_static(self) -> Self:
        """Version of object with all autograd-traced fields removed."""

        if self._has_tracers is False:
            return self

        # get dictionary of all traced fields
        field_mapping = self._strip_traced_fields()

        # shortcut to just return self if no tracers found, for performance
        if not field_mapping:
            self._has_tracers = False
            return self

        # convert all fields to static values
        field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()}

        # insert the static values into a copy of self
        static_self = self._insert_traced_fields(field_mapping_static)
        static_self._has_tracers = False
        return static_self

    @classmethod
    def generate_docstring(cls) -> str:
        """Generates a docstring for a Tidy3D model."""

        doc = ""

        # keep any pre-existing class description
        original_docstrings = []
        if cls.__doc__:
            original_docstrings = cls.__doc__.split("\n\n")
            doc += original_docstrings.pop(0)
        original_docstrings = "\n\n".join(original_docstrings)

        # parameters
        doc += "\n\n    Parameters\n    ----------\n"
        for field_name, field in cls.model_fields.items():  # v2
            if field_name == TYPE_TAG_STR:
                continue

            # type
            ann = getattr(field, "annotation", None)
            data_type = _fmt_ann_literal(ann)

            # default / default_factory
            default_val = (
                f"{field.default_factory.__name__}()"
                if field.default_factory is not None
                else field.get_default(call_default_factory=False)
            )

            if isinstance(default_val, BaseModel) or (
                "=" in str(default_val) if default_val is not None else False
            ):
                default_val = ", ".join(
                    str(f"{default_val.__class__.__name__}({default_val})").split(" ")
                )

            default_str = "" if field.is_required() else f" = {default_val}"
            doc += f"    {field_name} : {data_type}{default_str}\n"

            parts = []

            # units
            units = None
            extra = getattr(field, "json_schema_extra", None)
            if isinstance(extra, dict):
                units = extra.get("units")
            # fall back to pydantic v2 field metadata entries for units
            if units is None and hasattr(field, "metadata"):
                for meta in field.metadata:
                    if isinstance(meta, dict) and "units" in meta:
                        units = meta["units"]
                        break
            if units is not None:
                unitstr = (
                    f"({', '.join(str(u) for u in units)})"
                    if isinstance(units, (list, tuple))
                    else str(units)
                )
                parts.append(f"[units = {unitstr}].")

            # description
            desc = getattr(field, "description", None)
            if desc:
                parts.append(desc)

            if parts:
                doc += "        " + " ".join(parts) + "\n"

        if original_docstrings:
            doc += "\n" + original_docstrings
        doc += "\n"

        return doc

    def get_submodels_by_hash(self) -> dict[int, list[Union[str, tuple[str, int]]]]:
        """
        Return a mapping ``{hash(submodel): [field_path, ...]}`` for every
        nested ``Tidy3dBaseModel`` inside this model.
        """
        out = defaultdict(list)

        for name in type(self).model_fields:
            value = getattr(self, name)

            if isinstance(value, Tidy3dBaseModel):
                out[hash(value)].append(name)
                continue

            if isinstance(value, (list, tuple)):
                for idx, item in enumerate(value):
                    if isinstance(item, Tidy3dBaseModel):
                        out[hash(item)].append((name, idx))

            elif isinstance(value, np.ndarray):
                for idx, item in enumerate(value.flat):
                    if isinstance(item, Tidy3dBaseModel):
                        out[hash(item)].append((name, idx))

            elif isinstance(value, dict):
                for k, item in value.items():
                    if isinstance(item, Tidy3dBaseModel):
                        out[hash(item)].append((name, k))

        return dict(out)

    @staticmethod
    def _scientific_notation(
        min_val: float, max_val: float, min_digits: int = 4
    ) -> tuple[str, str]:
        """
        Convert numbers to scientific notation, displaying only digits up to the point of difference,
        with a minimum number of significant digits specified by `min_digits`.
        """

        def to_sci(value: float, exponent: int, precision: int) -> str:
            normalized_value = value / (10**exponent)
            return f"{normalized_value:.{precision}f}e{exponent}"

        # zero has no log10; fall back to plain scientific formatting
        if min_val == 0 or max_val == 0:
            return f"{min_val:.0e}", f"{max_val:.0e}"

        exponent_min = math.floor(math.log10(abs(min_val)))
        exponent_max = math.floor(math.log10(abs(max_val)))

        # share the smaller exponent so both values are comparable digit by digit
        common_exponent = min(exponent_min, exponent_max)
        normalized_min = min_val / (10**common_exponent)
        normalized_max = max_val / (10**common_exponent)

        if normalized_min == normalized_max:
            precision = min_digits
        else:
            # grow precision until the two rounded values differ
            precision = 0
            while round(normalized_min, precision) == round(normalized_max, precision):
                precision += 1

        precision = max(precision, min_digits)

        sci_min = to_sci(min_val, common_exponent, precision)
        sci_max = to_sci(max_val, common_exponent, precision)

        return sci_min, sci_max

    def __rich_repr__(self) -> rich.repr.Result:
        """How to pretty-print instances of ``Tidy3dBaseModel``."""
        for name in type(self).model_fields:
            value = getattr(self, name)

            # don't print the type field we add to the models
            if name == "type":
                continue

            # skip `attrs` if it's an empty dictionary
            if name == "attrs" and isinstance(value, dict) and not value:
                continue

            yield name, value

    def __str__(self) -> str:
        """Return a pretty-printed string representation of the model."""
        from io import StringIO

        from rich.console import Console

        sio = StringIO()
        console = Console(file=sio)
        console.print(self)
        output = sio.getvalue()
        return output.rstrip("\n")


def _make_lazy_proxy(
    target_cls: type[Tidy3dBaseModel],
    on_load: Optional[Callable[[Any], None]] = None,
) -> type[Tidy3dBaseModel]:
    """
    Return a lazy-loading proxy subclass of ``target_cls``.

    Parameters
    ----------
    target_cls : type
        Must implement ``dict_from_file`` and ``model_validate``.
    on_load : Optional[Callable[[Any], None]] = None
        A function to call with the fully loaded instance once loaded.

    Returns
    -------
    type
        A class named ``Proxy`` with init args:
        ``(fname, group_path, parse_obj_kwargs)``.
    """

    proxy_name = f"{target_cls.__name__}Proxy"

    class _LazyProxy(target_cls):  # type: ignore[misc]
        def __init__(
            self,
            fname: PathLike,
            group_path: Optional[str],
            parse_obj_kwargs: Any,
        ) -> None:
            # store lazy context only in __dict__
            object.__setattr__(self, "_lazy_fname", Path(fname))
            object.__setattr__(self, "_lazy_group_path", group_path)
            object.__setattr__(self, "_lazy_parse_obj_kwargs", dict(parse_obj_kwargs or {}))

        def copy(self, **kwargs: Any) -> Self:
            """Return another lazy proxy instead of materializing."""
            return _LazyProxy(
                object.__getattribute__(self, "_lazy_fname"),
                object.__getattribute__(self, "_lazy_group_path"),
                {
                    **object.__getattribute__(self, "_lazy_parse_obj_kwargs"),
                    **kwargs,
                },
            )

        def __getattribute__(self, name: str) -> Any:
            # Attributes that must *not* trigger materialization
            if name.startswith("_lazy_") or name in {
                "__class__",
                "__dict__",
                "__weakref__",
                "__post_root_validators__",
                "__pydantic_decorators__",
                "copy",  # don't materialize just for .copy()
            }:
                return object.__getattribute__(self, name)

            d = object.__getattribute__(self, "__dict__")

            # still lazy: load from disk, then morph this object into the real instance
            if "_lazy_fname" in d:
                fname = d["_lazy_fname"]
                group_path = d["_lazy_group_path"]
                kwargs = d["_lazy_parse_obj_kwargs"]

                # Build the real instance
                model_dict = target_cls.dict_from_file(fname=fname, group_path=group_path)
                target = target_cls._validate_model_dict(model_dict, **kwargs)

                d.clear()
                d.update(target.__dict__)

                object.__setattr__(self, "__class__", target.__class__)
                fields_set = getattr(target, "__pydantic_fields_set__", None)
                if fields_set is not None:
                    object.__setattr__(self, "__pydantic_fields_set__", set(fields_set))

                pvt = getattr(target, "__pydantic_private__", None)
                if pvt is not None:
                    object.__setattr__(self, "__pydantic_private__", pvt)

                if on_load is not None:
                    on_load(self)

            return object.__getattribute__(self, name)

    _LazyProxy.__name__ = proxy_name
    return _LazyProxy
diff --git a/tidy3d/_common/components/base_sim/__init__.py b/tidy3d/_common/components/base_sim/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tidy3d/_common/components/base_sim/source.py b/tidy3d/_common/components/base_sim/source.py
new file mode 100644
index 0000000000..eb6f51deca
--- /dev/null
+++ b/tidy3d/_common/components/base_sim/source.py
@@ -0,0 +1,30 @@
"""Abstract base for classes that define simulation sources."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from pydantic import Field

from tidy3d._common.components.base import Tidy3dBaseModel
from tidy3d._common.components.validators import validate_name_str

if TYPE_CHECKING:
    from tidy3d._common.components.viz import PlotParams


class AbstractSource(Tidy3dBaseModel, ABC):
    """Abstract base class for all sources."""

    name: Optional[str] = Field(
        None,
        title="Name",
        description="Optional name for the source.",
    )

    @abstractmethod
    def plot_params(self) -> PlotParams:
        """Default parameters for plotting a Source object."""

    _name_validator = validate_name_str()
diff --git a/tidy3d/_common/components/data/__init__.py b/tidy3d/_common/components/data/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tidy3d/_common/components/data/data_array.py b/tidy3d/_common/components/data/data_array.py
new file mode 100644
index 0000000000..f2df15783b
--- /dev/null
+++ b/tidy3d/_common/components/data/data_array.py
@@ -0,0 +1,871 @@
"""Storing tidy3d data at it's most fundamental level as xr.DataArray objects"""

from __future__ import annotations

import pathlib
from abc import ABC
from typing import TYPE_CHECKING, Any

import autograd.numpy as anp
import h5py
import numpy as np
import xarray as xr
from autograd.tracer import isbox
from pydantic_core import core_schema
from xarray.core import missing
from xarray.core.indexes import PandasIndex
from xarray.core.indexing import _outer_to_numpy_indexer
from xarray.core.utils import OrderedSet, either_dict_or_kwargs
from xarray.core.variable import as_variable

from tidy3d._common.compat import alignment
from tidy3d._common.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
from tidy3d._common.components.geometry.bound_ops import bounds_contains
from tidy3d._common.constants import (
    HERTZ,
    MICROMETER,
    RADIAN,
    SECOND,
)
from tidy3d._common.exceptions import DataError, FileError

if TYPE_CHECKING:
    from collections.abc import Mapping
    from os import PathLike
    from typing import Optional, Union

    from numpy.typing import NDArray
    from pydantic.annotated_handlers import GetCoreSchemaHandler
    from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
    from xarray.core.types import InterpOptions, Self

    from tidy3d._common.components.autograd import InterpolationType
    from tidy3d._common.components.types.base import Axis, Bound


# maps the dimension names to their attributes
DIM_ATTRS = {
    "x": {"units": MICROMETER, "long_name": "x position"},
    "y": {"units": MICROMETER, "long_name": "y position"},
    "z": {"units": MICROMETER, "long_name": "z position"},
    "f": {"units": HERTZ, "long_name": "frequency"},
    "t": {"units": SECOND, "long_name": "time"},
    "direction": {"long_name": "propagation direction"},
    "mode_index": {"long_name": "mode index"},
    "eme_port_index": {"long_name": "EME port index"},
    "eme_cell_index": {"long_name": "EME cell index"},
    "mode_index_in": {"long_name": "mode index in"},
    "mode_index_out": {"long_name": "mode index out"},
    "sweep_index": {"long_name": "sweep index"},
    "theta": {"units": RADIAN,
"long_name": "elevation angle"}, + "phi": {"units": RADIAN, "long_name": "azimuth angle"}, + "ux": {"long_name": "normalized kx"}, + "uy": {"long_name": "normalized ky"}, + "orders_x": {"long_name": "diffraction order"}, + "orders_y": {"long_name": "diffraction order"}, + "face_index": {"long_name": "face index"}, + "vertex_index": {"long_name": "vertex index"}, + "axis": {"long_name": "axis"}, +} + + +# name of the DataArray.values in the hdf5 file (xarray's default name too) +DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__" + + +# maps the dimension names to their attributes +DIM_ATTRS = { + "x": {"units": MICROMETER, "long_name": "x position"}, + "y": {"units": MICROMETER, "long_name": "y position"}, + "z": {"units": MICROMETER, "long_name": "z position"}, + "f": {"units": HERTZ, "long_name": "frequency"}, + "t": {"units": SECOND, "long_name": "time"}, + "direction": {"long_name": "propagation direction"}, + "mode_index": {"long_name": "mode index"}, + "eme_port_index": {"long_name": "EME port index"}, + "eme_cell_index": {"long_name": "EME cell index"}, + "mode_index_in": {"long_name": "mode index in"}, + "mode_index_out": {"long_name": "mode index out"}, + "sweep_index": {"long_name": "sweep index"}, + "theta": {"units": RADIAN, "long_name": "elevation angle"}, + "phi": {"units": RADIAN, "long_name": "azimuth angle"}, + "ux": {"long_name": "normalized kx"}, + "uy": {"long_name": "normalized ky"}, + "orders_x": {"long_name": "diffraction order"}, + "orders_y": {"long_name": "diffraction order"}, + "face_index": {"long_name": "face index"}, + "vertex_index": {"long_name": "vertex index"}, + "axis": {"long_name": "axis"}, +} + + +# name of the DataArray.values in the hdf5 file (xarray's default name too) +DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__" + + +class DataArray(xr.DataArray): + """Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords.""" + + # Always set __slots__ = () to avoid xarray warnings + __slots__ = () 
    # stores an ordered tuple of strings corresponding to the data dimensions
    _dims = ()
    # stores a dictionary of attributes corresponding to the data values
    _data_attrs: dict[str, str] = {}

    def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
        # if data is a vanilla autograd box, convert to our box
        if isbox(data) and not is_tidy_box(data):
            data = TidyArrayBox.from_arraybox(data)
        # do the same for xr.Variable or xr.DataArray type
        elif isinstance(data, (xr.Variable, xr.DataArray)):
            if isbox(data.data) and not is_tidy_box(data.data):
                data.data = TidyArrayBox.from_arraybox(data.data)
        super().__init__(data, *args, **kwargs)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Core schema definition for validation & serialization."""

        def _initial_parser(value: Any) -> Self:
            # pass instances through untouched
            if isinstance(value, cls):
                return value

            # a bare class-name string is the JSON placeholder, not real data
            if isinstance(value, str) and value == cls.__name__:
                raise DataError(
                    f"Trying to load '{cls.__name__}' from string placeholder '{value}' "
                    "but the actual data is missing. DataArrays are not typically stored "
                    "in JSON. Load from HDF5 or ensure the DataArray object is provided."
                )

            try:
                instance = cls(value)
                if not isinstance(instance, cls):
                    raise TypeError(
                        f"Constructor for {cls.__name__} returned unexpected type {type(instance)}"
                    )
                return instance
            except Exception as e:
                raise ValueError(
                    f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. "
                    f"Ensure input is compatible with xarray.DataArray constructor. Original error: {e}"
                ) from e

        # chain: parse -> check dims -> attach data attrs -> attach coord attrs
        validation_schema = core_schema.no_info_plain_validator_function(_initial_parser)
        validation_schema = core_schema.no_info_after_validator_function(
            cls._validate_dims, validation_schema
        )
        validation_schema = core_schema.no_info_after_validator_function(
            cls._assign_data_attrs, validation_schema
        )
        validation_schema = core_schema.no_info_after_validator_function(
            cls._assign_coord_attrs, validation_schema
        )

        def _serialize_to_name(instance: Self) -> str:
            return type(instance).__name__

        # serialization behavior:
        # - for JSON ('json' mode), use the _serialize_to_name function.
        # - for Python ('python' mode), use Pydantic's default for the object type
        serialization_schema = core_schema.plain_serializer_function_ser_schema(
            _serialize_to_name,
            return_schema=core_schema.str_schema(),
            when_used="json",
        )

        return core_schema.json_or_python_schema(
            python_schema=validation_schema,
            json_schema=validation_schema,  # Use same validation rules for JSON input
            serialization=serialization_schema,
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """JSON schema definition (defines how it LOOKS in a schema, not the data)."""
        return {
            "type": "string",
            "title": cls.__name__,
            "description": (
                f"Placeholder for a '{cls.__name__}' object. Actual data is typically "
                "serialized separately (e.g., via HDF5) and not embedded in JSON."
            ),
        }

    @classmethod
    def _validate_dims(cls, val: Self) -> Self:
        """Make sure the dims are the same as ``_dims``, then put them in the correct order."""
        if set(val.dims) != set(cls._dims):
            raise ValueError(
                f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'"
            )
        if val.dims != cls._dims:
            val = val.transpose(*cls._dims)
        return val

    @classmethod
    def _assign_data_attrs(cls, val: Self) -> Self:
        """Assign the correct data attributes to the :class:`.DataArray`."""
        for attr_name, attr_val in cls._data_attrs.items():
            val.attrs[attr_name] = attr_val
        return val

    @classmethod
    def _assign_coord_attrs(cls, val: Self) -> Self:
        """Assign the correct coordinate attributes to the :class:`.DataArray`."""
        target_dims = set(val.dims) & set(cls._dims) & set(val.coords)
        for dim in target_dims:
            template = DIM_ATTRS.get(dim)
            if not template:
                continue

            coord_attrs = val.coords[dim].attrs
            # NOTE(review): this local `missing` shadows the `xarray.core.missing`
            # module imported at file top — harmless here, but consider renaming.
            missing = {k: v for k, v in template.items() if coord_attrs.get(k) != v}
            coord_attrs.update(missing)
        return val

    def _interp_validator(self, field_name: Optional[str] = None) -> None:
        """Ensure the data can be interpolated or selected by checking for duplicate coordinates.

        NOTE
        ----
        This does not check every 'DataArray' by default. Instead, when required, this check can be
        called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'.
        """
        if field_name is None:
            field_name = self.__class__.__name__

        for dim, coord in self.coords.items():
            if coord.to_index().duplicated().any():
                raise DataError(
                    f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. "
                    "Duplicates can be removed by running "
                    f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
                )

    def __eq__(self, other: Any) -> bool:
        """Whether two data array objects are equal."""

        if not isinstance(other, xr.DataArray):
            return False

        if not self.data.shape == other.data.shape or not np.all(self.data == other.data):
            return False
        for key, val in self.coords.items():
            if not np.all(np.array(val) == np.array(other.coords[key])):
                return False
        return True

    @property
    def values(self) -> NDArray:
        """
        The array's data converted to a numpy.ndarray.

        Returns
        -------
        np.ndarray
            The values of the DataArray.
        """
        return self.data if isbox(self.data) else super().values

    @values.setter
    def values(self, value: Any) -> None:
        self.variable.values = value

    def to_numpy(self) -> np.ndarray:
        """Return `.data` when traced to avoid `dtype=object` NumPy conversion."""
        return self.data if isbox(self.data) else super().to_numpy()

    @property
    def abs(self) -> Self:
        """Absolute value of data array."""
        return abs(self)

    @property
    def angle(self) -> Self:
        """Angle or phase value of data array."""
        values = np.angle(self.values)
        return type(self)(values, coords=self.coords)

    @property
    def is_uniform(self) -> bool:
        """Whether each element is of equal value in the data array"""
        raw_data = self.data.ravel()
        return np.allclose(raw_data, raw_data[0])

    def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None:
        """Save an ``xr.DataArray`` to the hdf5 file or file handle with a given path to the group."""
        if isinstance(fname, (str, pathlib.Path)):
            path = pathlib.Path(fname)
            path.parent.mkdir(parents=True, exist_ok=True)
            with h5py.File(path, "w") as f_handle:
                self.to_hdf5_handle(f_handle=f_handle, group_path=group_path)
        else:
            self.to_hdf5_handle(f_handle=fname, group_path=group_path)

    def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None:
        """Save an ``xr.DataArray`` to the hdf5 file handle with a given path to the group."""
        sub_group = \
f_handle.create_group(group_path) + sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) + for key, val in self.coords.items(): + if val.dtype == " Self: + """Load a DataArray from an hdf5 file with a given path to the group.""" + path = pathlib.Path(fname) + with h5py.File(path, "r") as f: + sub_group = f[group_path] + values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) + coords = {dim: np.array(sub_group[dim]) for dim in cls._dims if dim in sub_group} + for key, val in coords.items(): + if val.dtype == "O": + coords[key] = [byte_string.decode() for byte_string in val.tolist()] + return cls(values, coords=coords, dims=cls._dims) + + @classmethod + def from_file(cls, fname: PathLike, group_path: str) -> Self: + """Load a DataArray from an hdf5 file with a given path to the group.""" + path = pathlib.Path(fname) + if not any(suffix.lower() == ".hdf5" for suffix in path.suffixes): + raise FileError( + f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}." + ) + return cls.from_hdf5(fname=path, group_path=group_path) + + def __hash__(self) -> int: + """Generate hash value for a :class:`.DataArray` instance, needed for custom components.""" + import dask + + token_str = dask.base.tokenize(self) + return hash(token_str) + + def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: + """Multiply self by value at indices.""" + if isbox(self.data) or isbox(value): + return self._ag_multiply_at(value, coord_name, indices) + + self_mult = self.copy() + self_mult[{coord_name: indices}] *= value + return self_mult + + def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: + """Autograd multiply_at override when tracing.""" + key = {coord_name: indices} + _, index_tuple, _ = self.variable._broadcast_indexes(key) + idx = _outer_to_numpy_indexer(index_tuple, self.data.shape) + mask = np.zeros(self.data.shape, dtype="?") + mask[idx] = True + return self.copy(deep=False, data=anp.where(mask, 
self.data * value, self.data)) + + def interp( + self, + coords: Mapping[Any, Any] | None = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + **coords_kwargs: Any, + ) -> Self: + """Interpolate this DataArray to new coordinate values. + + Parameters + ---------- + coords : Union[Mapping[Any, Any], None] = None + A mapping from dimension names to new coordinate labels. + method : InterpOptions = "linear" + The interpolation method to use. + assume_sorted : bool = False + If True, skip sorting of coordinates. + kwargs : Union[Mapping[str, Any], None] = None + Additional keyword arguments to pass to the interpolation function. + **coords_kwargs : Any + The keyword arguments form of coords. + + Returns + ------- + DataArray + A new DataArray with interpolated values. + + Raises + ------ + KeyError + If any of the specified coordinates are not in the DataArray. + """ + if isbox(self.data): + return self._ag_interp(coords, method, assume_sorted, kwargs, **coords_kwargs) + + return super().interp(coords, method, assume_sorted, kwargs, **coords_kwargs) + + def _ag_interp( + self, + coords: Union[Mapping[Any, Any], None] = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Union[Mapping[str, Any], None] = None, + **coords_kwargs: Any, + ) -> Self: + """Autograd interp override when tracing over self.data. + + This implementation closely follows the interp implementation of xarray + to match its behavior as closely as possible while supporting autograd. 
+ + See: + - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html + - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html + """ + if kwargs is None: + kwargs = {} + + ds = self._to_temp_dataset() + + coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") + indexers = dict(ds._validate_interp_indexers(coords)) + + if coords: + # Find shared dimensions between the dataset and the indexers + sdims = ( + set(ds.dims) + .intersection(*[set(nx.dims) for nx in indexers.values()]) + .difference(coords.keys()) + ) + indexers.update({d: ds.variables[d] for d in sdims}) + + obj = ds if assume_sorted else ds.sortby(list(coords)) + + # workaround to get a variable for a dimension without a coordinate + validated_indexers = { + k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v) + for k, v in indexers.items() + } + + for k, v in validated_indexers.items(): + obj, newidx = missing._localize(obj, {k: v}) + validated_indexers[k] = newidx[k] + + variables = {} + reindex = False + for name, var in obj._variables.items(): + if name in indexers: + continue + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # Interpolation for numeric types + var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims} + variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims): + # Stepwise interpolation for non-numeric types + reindex = True + elif all(d not in indexers for d in var.dims): + # Keep variables not dependent on interpolated coords + variables[name] = var + + if reindex: + # Reindex for non-numeric types + reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)} + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, + method="nearest", + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed._indexes) + variables.update(reindexed.variables) + else: + # Get the 
indexes that are not being interpolated along + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} + + # Get the coords that also exist in the variables + coord_names = obj._coord_names & variables.keys() + selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes) + + # Attach indexer as coordinate + for k, v in indexers.items(): + if v.dims == (k,): + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables({k: v}) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v + + # Extract coordinates from indexers + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) + variables.update(coord_vars) + indexes.update(new_indexes) + + coord_names = obj._coord_names & variables.keys() | coord_vars.keys() + ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes) + return self._from_temp_dataset(ds) + + @staticmethod + def _ag_interp_func( + var: xr.Variable, + indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]], + method: InterpolationType, + **kwargs: Any, + ) -> xr.Variable: + """ + Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`. + + The implementation follows xarray's interp implementation in xarray.core.missing, + but replaces some of the pre-processing as well as the actual interpolation + function with an autograd-compatible approach. + + + Parameters + ---------- + var : xr.Variable + The variable to be interpolated. + indexes_coords : dict + A dictionary mapping dimension names to coordinate values for interpolation. + method : Literal["nearest", "linear"] + The interpolation method to use. + **kwargs : dict + Additional keyword arguments to pass to the interpolation function. + + Returns + ------- + xr.Variable + The interpolated variable. 
+ """ + if not indexes_coords: + return var.copy() + result = var + for indep_indexes_coords in missing.decompose_interp(indexes_coords): + var = result + + # target dimensions + dims = list(indep_indexes_coords) + x, new_x = zip(*[indep_indexes_coords[d] for d in dims]) + destination = missing.broadcast_variables(*new_x) + + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + + x, new_x = missing._floatize_x(x, new_x) + + permutation = [var.dims.index(dim) for dim in original_dims] + combined_permutation = permutation[-len(x) :] + permutation[: -len(x)] + data = anp.transpose(var.data, combined_permutation) + xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1) + + result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs) + + result = anp.moveaxis(result, 0, -1) + result = anp.reshape(result, result.shape[:-1] + new_x[0].shape) + + result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True) + + out_dims: OrderedSet = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indep_indexes_coords[d][1].dims) + else: + out_dims.add(d) + if len(out_dims) > 1: + result = result.transpose(*out_dims) + return result + + def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray: + """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible + + Constraints / Edge cases: + - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays + - `data` will be reshaped to try to match `self.shape` except where `coords` present + """ + + # make mask + mask = xr.zeros_like(self, dtype=bool) + mask.loc[coords] = True + + # reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis + old_data = self.data + new_shape = list(old_data.shape) + for i, dim in enumerate(self.dims): + if dim in coords: + new_shape[i] = 1 + try: + new_data = 
data.reshape(new_shape) + except ValueError as e: + raise ValueError( + "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was " + f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this " + "error please raise an issue on the tidy3d github repository with the context." + ) from e + + # broadcast data to repeat data along the selected dimensions to match mask + new_data = new_data + np.zeros_like(old_data) + + new_data = np.where(mask, new_data, old_data) + + return self.copy(deep=True, data=new_data) + + +class FreqDataArray(DataArray): + """Frequency-domain array. + + Example + ------- + >>> f = [2e14, 3e14] + >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f)) + """ + + __slots__ = () + _dims = ("f",) + + +class AbstractSpatialDataArray(DataArray, ABC): + """Spatial distribution.""" + + __slots__ = () + _dims = ("x", "y", "z") + _data_attrs = {"long_name": "field value"} + + @property + def _spatially_sorted(self) -> Self: + """Check whether sorted and sort if not.""" + needs_sorting = [] + for axis in "xyz": + axis_coords = self.coords[axis].values + if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]): + needs_sorting.append(axis) + + if len(needs_sorting) > 0: + return self.sortby(needs_sorting) + + return self + + def sel_inside(self, bounds: Bound) -> Self: + """Return a new SpatialDataArray that contains the minimal amount data necessary to cover + a spatial region defined by ``bounds``. Note that the returned data is sorted with respect + to spatial coordinates. + + + Parameters + ---------- + bounds : Tuple[float, float, float], Tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + SpatialDataArray + Extracted spatial data array. 
+ """ + if any(bmin > bmax for bmin, bmax in zip(*bounds)): + raise DataError( + "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." + ) + + # make sure data is sorted with respect to coordinates + sorted_self = self._spatially_sorted + + inds_list = [] + + coords = (sorted_self.x, sorted_self.y, sorted_self.z) + + for coord, smin, smax in zip(coords, bounds[0], bounds[1]): + length = len(coord) + + # one point along direction, assume invariance + if length == 1: + comp_inds = [0] + else: + # if data does not cover structure at all take the closest index + if smax < coord[0]: # structure is completely on the left side + # take 2 if possible, so that linear iterpolation is possible + comp_inds = np.arange(0, max(2, length)) + + elif smin > coord[-1]: # structure is completely on the right side + # take 2 if possible, so that linear iterpolation is possible + comp_inds = np.arange(min(0, length - 2), length) + + else: + if smin < coord[0]: + ind_min = 0 + else: + ind_min = max(0, (coord >= smin).data.argmax() - 1) + + if smax > coord[-1]: + ind_max = length - 1 + else: + ind_max = (coord >= smax).data.argmax() + + comp_inds = np.arange(ind_min, ind_max + 1) + + inds_list.append(comp_inds) + + return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2]) + + def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: + """Check whether data fully covers specified by ``bounds`` spatial region. If data contains + only one point along a given direction, then it is assumed the data is constant along that + direction and coverage is not checked. + + + Parameters + ---------- + bounds : Tuple[float, float, float], Tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + rtol : float = 0.0 + Relative tolerance for comparing bounds + atol : float = 0.0 + Absolute tolerance for comparing bounds + + Returns + ------- + bool + Full cover check outcome. 
+ """ + if any(bmin > bmax for bmin, bmax in zip(*bounds)): + raise DataError( + "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." + ) + xyz = [self.x, self.y, self.z] + self_min = [0] * 3 + self_max = [0] * 3 + for dim in range(3): + coords = xyz[dim] + if len(coords) == 1: + self_min[dim] = bounds[0][dim] + self_max[dim] = bounds[1][dim] + else: + self_min[dim] = np.min(coords) + self_max[dim] = np.max(coords) + self_bounds = (tuple(self_min), tuple(self_max)) + return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol) + + +class ScalarFieldDataArray(AbstractSpatialDataArray): + """Spatial distribution in the frequency-domain. + + Example + ------- + >>> x = [1,2] + >>> y = [2,3,4] + >>> z = [3,4,5,6] + >>> f = [2e14, 3e14] + >>> coords = dict(x=x, y=y, z=z, f=f) + >>> fd = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) + """ + + __slots__ = () + _dims = ("x", "y", "z", "f") + + +class TriangleMeshDataArray(DataArray): + """Data of the triangles of a surface mesh as in the STL file format.""" + + __slots__ = () + _dims = ("face_index", "vertex_index", "axis") + _data_attrs = {"long_name": "surface mesh triangles"} + + +class TimeDataArray(DataArray): + """Time-domain array. + + Example + ------- + >>> t = [0, 1e-12, 2e-12] + >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t)) + """ + + __slots__ = () + _dims = ("t",) + + +class SpatialDataArray(AbstractSpatialDataArray): + """Spatial distribution. + + Example + ------- + >>> x = [1,2] + >>> y = [2,3,4] + >>> z = [3,4,5,6] + >>> coords = dict(x=x, y=y, z=z) + >>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords) + """ + + __slots__ = () + + def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> Self: + """Reflect data across the plane define by parameters ``axis`` and ``center`` from right to + left. Note that the returned data is sorted with respect to spatial coordinates. 
+ + Parameters + ---------- + axis : Literal[0, 1, 2] + Normal direction of the reflection plane. + center : float + Location of the reflection plane along its normal direction. + reflection_only : bool = False + Return only reflected data. + + Returns + ------- + SpatialDataArray + Data after reflection is performed. + """ + + sorted_self = self._spatially_sorted + + coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values] + data = np.array(sorted_self.data) + + data_left_bound = coords[axis][0] + + if np.isclose(center, data_left_bound): + num_duplicates = 1 + elif center > data_left_bound: + raise DataError("Reflection center must be outside and to the left of the data region.") + else: + num_duplicates = 0 + + if reflection_only: + coords[axis] = 2 * center - coords[axis] + coords_dict = dict(zip("xyz", coords)) + + tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict) + + return tmp_arr.sortby("xyz"[axis]) + + shape = np.array(np.shape(data)) + old_len = shape[axis] + shape[axis] = 2 * old_len - num_duplicates + + ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + + ind_left[axis] = slice(old_len - 1, None, -1) + ind_right[axis] = slice(old_len - num_duplicates, None) + + new_data = np.zeros(shape) + + new_data[ind_left[0], ind_left[1], ind_left[2]] = data + new_data[ind_right[0], ind_right[1], ind_right[2]] = data + + new_coords = np.zeros(shape[axis]) + new_coords[old_len - num_duplicates :] = coords[axis] + new_coords[old_len - 1 :: -1] = 2 * center - coords[axis] + + coords[axis] = new_coords + coords_dict = dict(zip("xyz", coords)) + + return SpatialDataArray(new_data, coords=coords_dict) + + +DATA_ARRAY_TYPES: list[type[DataArray]] = [ + FreqDataArray, + TriangleMeshDataArray, + TimeDataArray, + SpatialDataArray, + ScalarFieldDataArray, +] +DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES} diff --git 
a/tidy3d/_common/components/data/dataset.py b/tidy3d/_common/components/data/dataset.py new file mode 100644 index 0000000000..21ee22b6b8 --- /dev/null +++ b/tidy3d/_common/components/data/dataset.py @@ -0,0 +1,207 @@ +"""Collections of DataArrays.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import numpy as np +import xarray as xr +from pydantic import Field + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.data.data_array import ( + DataArray, + ScalarFieldDataArray, + TimeDataArray, + TriangleMeshDataArray, +) +from tidy3d._common.exceptions import DataError +from tidy3d._common.log import log + +if TYPE_CHECKING: + from typing import Callable + + from tidy3d._common.components.types.base import ArrayLike, Axis + +DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 +DEFAULT_MAX_CELLS_PER_STEP = 10_000 +DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 + + +class Dataset(Tidy3dBaseModel, ABC): + """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" + + @property + def data_arrs(self) -> dict: + """Returns a dictionary of all `:class:`.DataArray`s in the dataset.""" + data_arrs = {} + for key in self.__class__.model_fields.keys(): + data = getattr(self, key) + if isinstance(data, DataArray): + data_arrs[key] = data + return data_arrs + + +class TriangleMeshDataset(Dataset): + """Dataset for storing triangular surface data.""" + + surface_mesh: TriangleMeshDataArray = Field( + title="Surface mesh data", + description="Dataset containing the surface triangles and corresponding face indices " + "for a surface mesh.", + ) + + +class AbstractFieldDataset(Dataset, ABC): + """Collection of scalar fields with some symmetry properties.""" + + @property + @abstractmethod + def field_components(self) -> dict[str, DataArray]: + """Maps the field components to their associated data.""" + + def apply_phase(self, phase: float) -> AbstractFieldDataset: + 
"""Create a copy where all elements are phase-shifted by a value (in radians).""" + if phase == 0.0: + return self + phasor = np.exp(1j * phase) + field_components_shifted = {} + for fld_name, fld_cmp in self.field_components.items(): + fld_cmp_shifted = phasor * fld_cmp + field_components_shifted[fld_name] = fld_cmp_shifted + return self.updated_copy(**field_components_shifted) + + @property + @abstractmethod + def grid_locations(self) -> dict[str, str]: + """Maps field components to the string key of their grid locations on the yee lattice.""" + + @property + @abstractmethod + def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: + """Maps field components to their (positive) symmetry eigenvalues.""" + + def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any: + """How to package the dictionary of fields computed via self.colocate().""" + return xr.Dataset(centered_fields) + + def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset: + """Colocate all of the data at a set of x, y, z coordinates. + + Parameters + ---------- + x : Optional[array-like] = None + x coordinates of locations. + If not supplied, does not try to colocate on this dimension. + y : Optional[array-like] = None + y coordinates of locations. + If not supplied, does not try to colocate on this dimension. + z : Optional[array-like] = None + z coordinates of locations. + If not supplied, does not try to colocate on this dimension. + + Returns + ------- + xr.Dataset + Dataset containing all fields at the same spatial locations. + For more details refer to `xarray's Documentation `_. + + Note + ---- + For many operations (such as flux calculations and plotting), + it is important that the fields are colocated at the same spatial locations. + Be sure to apply this method to your field data in those cases. 
+ """ + + if hasattr(self, "monitor") and self.monitor.colocate: + with log as consolidated_logger: + consolidated_logger.warning( + "Colocating data that has already been colocated during the solver " + "run. For most accurate results when colocating to custom coordinates set " + "'Monitor.colocate' to 'False' to use the raw data on the Yee grid " + "and avoid double interpolation. Note: the default value was changed to 'True' " + "in Tidy3D version 2.4.0." + ) + + # convert supplied coordinates to array and assign string mapping to them + supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None} + + # dict of data arrays to combine in dataset and return + centered_fields = {} + + # loop through field components + for field_name, field_data in self.field_components.items(): + # loop through x, y, z dimensions and raise an error if only one element along dim + for coord_name, coords_supplied in supplied_coord_map.items(): + coord_data = np.array(field_data.coords[coord_name]) + if coord_data.size == 1: + raise DataError( + f"colocate given {coord_name}={coords_supplied}, but " + f"data only has one coordinate at {coord_name}={coord_data[0]}. " + "Therefore, can't colocate along this dimension. " + f"supply {coord_name}=None to skip it." 
+ ) + + centered_fields[field_name] = field_data.interp( + **supplied_coord_map, kwargs={"bounds_error": True} + ) + + # combine all centered fields in a dataset + return self.package_colocate_results(centered_fields) + + +class TimeDataset(Dataset): + """Dataset for storing a function of time.""" + + values: TimeDataArray = Field( + title="Values", + description="Values as a function of time.", + ) + + +class AbstractMediumPropertyDataset(AbstractFieldDataset, ABC): + """Dataset storing medium property.""" + + eps_xx: ScalarFieldDataArray = Field( + title="Epsilon xx", + description="Spatial distribution of the xx-component of the relative permittivity.", + ) + eps_yy: ScalarFieldDataArray = Field( + title="Epsilon yy", + description="Spatial distribution of the yy-component of the relative permittivity.", + ) + eps_zz: ScalarFieldDataArray = Field( + title="Epsilon zz", + description="Spatial distribution of the zz-component of the relative permittivity.", + ) + + +class PermittivityDataset(AbstractMediumPropertyDataset): + """Dataset storing the diagonal components of the permittivity tensor. 
+ + Example + ------- + >>> x = [-1,1] + >>> y = [-2,0,2] + >>> z = [-3,-1,1,3] + >>> f = [2e14, 3e14] + >>> coords = dict(x=x, y=y, z=z, f=f) + >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) + >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld) + """ + + @property + def field_components(self) -> dict[str, ScalarFieldDataArray]: + """Maps the field components to their associated data.""" + return {"eps_xx": self.eps_xx, "eps_yy": self.eps_yy, "eps_zz": self.eps_zz} + + @property + def grid_locations(self) -> dict[str, str]: + """Maps field components to the string key of their grid locations on the yee lattice.""" + return {"eps_xx": "Ex", "eps_yy": "Ey", "eps_zz": "Ez"} + + @property + def symmetry_eigenvalues(self) -> dict[str, None]: + """Maps field components to their (positive) symmetry eigenvalues.""" + return {"eps_xx": None, "eps_yy": None, "eps_zz": None} diff --git a/tidy3d/_common/components/data/validators.py b/tidy3d/_common/components/data/validators.py new file mode 100644 index 0000000000..17763e819d --- /dev/null +++ b/tidy3d/_common/components/data/validators.py @@ -0,0 +1,90 @@ +# special validators for Datasets +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from pydantic import field_validator + +from tidy3d._common.components.data.data_array import DataArray, ScalarFieldDataArray +from tidy3d._common.components.data.dataset import AbstractFieldDataset +from tidy3d._common.exceptions import ValidationError + +if TYPE_CHECKING: + from typing import Callable, Optional + + from pydantic_core.core_schema import ValidationInfo + +if TYPE_CHECKING: + from typing import Callable, Optional + + from pydantic_core.core_schema import ValidationInfo + + +# this can't go in validators.py because that file imports dataset.py +def validate_no_nans(*field_names: str) -> Callable[[Any, ValidationInfo], Any]: + """Raise validation error if 
nans found in Dataset, or other data-containing item.""" + + @field_validator(*field_names) + def no_nans(val: Any, info: ValidationInfo) -> Any: + """Raise validation error if nans found in Dataset, or other data-containing item.""" + + if val is None: + return val + + def error_if_has_nans(value: Any, identifier: Optional[str] = None) -> None: + """Recursively check if value (or iterable) has nans and error if so.""" + + def has_nans(values: Any) -> bool: + """Base case: do these values contain NaN?""" + try: + return np.any(np.isnan(values)) + # if this fails for some reason (fails in adjoint, for example), don't check it. + except Exception: + return False + + if isinstance(value, (tuple, list)): + for i, _value in enumerate(value): + error_if_has_nans(_value, identifier=f"[{i}]") + + elif isinstance(value, AbstractFieldDataset): + for key, val in value.field_components.items(): + error_if_has_nans(val, identifier=f".{key}") + + elif isinstance(value, DataArray): + error_if_has_nans(value.values) + + else: + if has_nans(value): + # the identifier is used to make the message more clear by appending some more info + field_name_display = info.field_name + if identifier: + field_name_display += identifier + + raise ValidationError( + f"Found 'NaN' values in '{field_name_display}'. " + "If they were not intended, please double check your construction. " + "If intended, to replace these data points with a value 'x', " + "call 'values = np.nan_to_num(values, nan=x)'." 
+ ) + + error_if_has_nans(val) + return val + + return no_nans + + +def validate_can_interpolate( + *field_names: str, +) -> Callable[[AbstractFieldDataset], AbstractFieldDataset]: + """Make sure the data in ``field_name`` can be interpolated.""" + + @field_validator(*field_names) + def check_fields_interpolate(val: AbstractFieldDataset) -> AbstractFieldDataset: + if isinstance(val, AbstractFieldDataset): + for name, data in val.field_components.items(): + if isinstance(data, ScalarFieldDataArray): + data._interp_validator(name) + return val + + return check_fields_interpolate diff --git a/tidy3d/_common/components/data/zbf.py b/tidy3d/_common/components/data/zbf.py new file mode 100644 index 0000000000..5fa0e5a1a1 --- /dev/null +++ b/tidy3d/_common/components/data/zbf.py @@ -0,0 +1,156 @@ +"""ZBF utilities""" + +from __future__ import annotations + +from struct import unpack + +import numpy as np +from pydantic import Field + +from tidy3d._common.components.base import Tidy3dBaseModel + + +class ZBFData(Tidy3dBaseModel): + """ + Contains data read in from a ``.zbf`` file + """ + + version: int = Field(title="Version", description="File format version number.") + nx: int = Field(title="Samples in X", description="Number of samples in the x direction.") + ny: int = Field(title="Samples in Y", description="Number of samples in the y direction.") + ispol: bool = Field( + title="Is Polarized", + description="``True`` if the beam is polarized, ``False`` otherwise.", + ) + unit: str = Field( + title="Spatial Units", description="Spatial units, either 'mm', 'cm', 'in', or 'm'." 
+ ) + dx: float = Field(title="Grid Spacing, X", description="Grid spacing in x.") + dy: float = Field(title="Grid Spacing, Y", description="Grid spacing in y.") + zposition_x: float = Field( + title="Z Position, X Direction", + description="The pilot beam z position with respect to the pilot beam waist, x direction.", + ) + zposition_y: float = Field( + title="Z Position, Y Direction", + description="The pilot beam z position with respect to the pilot beam waist, y direction.", + ) + rayleigh_x: float = Field( + title="Rayleigh Distance, X Direction", + description="The pilot beam Rayleigh distance in the x direction.", + ) + rayleigh_y: float = Field( + title="Rayleigh Distance, Y Direction", + description="The pilot beam Rayleigh distance in the y direction.", + ) + waist_x: float = Field( + title="Beam Waist, X", description="The pilot beam waist in the x direction." + ) + waist_y: float = Field( + title="Beam Waist, Y", description="The pilot beam waist in the y direction." + ) + wavelength: float = Field(title="Wavelength", description="The wavelength of the beam.") + background_refractive_index: float = Field( + title="Background Refractive Index", + description="The index of refraction in the current medium.", + ) + receiver_eff: float = Field( + title="Receiver Efficiency", + description="The receiver efficiency. Zero if fiber coupling is not computed.", + ) + system_eff: float = Field( + title="System Efficiency", + description="The system efficiency. Zero if fiber coupling is not computed.", + ) + Ex: np.ndarray = Field( + title="Electric Field, X Component", + description="Complex-valued electric field, x component.", + ) + Ey: np.ndarray = Field( + title="Electric Field, Y Component", + description="Complex-valued electric field, y component.", + ) + + def read_zbf(filename: str) -> ZBFData: + """Reads a Zemax Beam File (``.zbf``) + + Parameters + ---------- + filename : str + The file name of the ``.zbf`` file to read. 
+ + Returns + ------- + :class:`.ZBFData` + """ + + # Read the zbf file + with open(filename, "rb") as f: + # Load the header + version, nx, ny, ispol, units = unpack("<5I", f.read(20)) + f.read(16) # unused values + ( + dx, + dy, + zposition_x, + rayleigh_x, + waist_x, + zposition_y, + rayleigh_y, + waist_y, + wavelength, + background_refractive_index, + receiver_eff, + system_eff, + ) = unpack("<12d", f.read(96)) + f.read(64) # unused values + + # read E field + nsamps = 2 * nx * ny + rawx = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) + if ispol: + rawy = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) + + # convert unit key to unit string + map_units = {0: "mm", 1: "cm", 2: "in", 3: "m"} + try: + unit = map_units[units] + except KeyError: + raise KeyError( + f"Invalid units specified in the zbf file (expected '0', '1', '2', or '3', got '{units}')." + ) from None + + # load E field + Ex_real = np.asarray(rawx[0::2]).reshape(nx, ny, order="F") + Ex_imag = np.asarray(rawx[1::2]).reshape(nx, ny, order="F") + if ispol: + Ey_real = np.asarray(rawy[0::2]).reshape(nx, ny, order="F") + Ey_imag = np.asarray(rawy[1::2]).reshape(nx, ny, order="F") + else: + Ey_real = np.zeros((nx, ny)) + Ey_imag = np.zeros((nx, ny)) + + Ex = Ex_real + 1j * Ex_imag + Ey = Ey_real + 1j * Ey_imag + + return ZBFData( + version=version, + nx=nx, + ny=ny, + ispol=ispol, + unit=unit, + dx=dx, + dy=dy, + zposition_x=zposition_x, + zposition_y=zposition_y, + rayleigh_x=rayleigh_x, + rayleigh_y=rayleigh_y, + waist_x=waist_x, + waist_y=waist_y, + wavelength=wavelength, + background_refractive_index=background_refractive_index, + receiver_eff=receiver_eff, + system_eff=system_eff, + Ex=Ex, + Ey=Ey, + ) diff --git a/tidy3d/_common/components/file_util.py b/tidy3d/_common/components/file_util.py new file mode 100644 index 0000000000..51e13f586d --- /dev/null +++ b/tidy3d/_common/components/file_util.py @@ -0,0 +1,83 @@ +"""File compression utilities""" + +from __future__ import annotations + +import 
gzip +import pathlib +import shutil +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from io import BytesIO + from os import PathLike + + +def compress_file_to_gzip(input_file: PathLike, output_gz_file: PathLike | BytesIO) -> None: + """ + Compress a file using gzip. + + Parameters + ---------- + input_file : PathLike + The path to the input file. + output_gz_file : PathLike | BytesIO + The path to the output gzip file or an in-memory buffer. + """ + input_file = pathlib.Path(input_file) + with input_file.open("rb") as file_in: + with gzip.open(output_gz_file, "wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def extract_gzip_file(input_gz_file: PathLike, output_file: PathLike) -> None: + """ + Extract a gzip-compressed file. + + Parameters + ---------- + input_gz_file : PathLike + The path to the gzip-compressed input file. + output_file : PathLike + The path to the extracted output file. + """ + input_path = pathlib.Path(input_gz_file) + output_path = pathlib.Path(output_file) + with gzip.open(input_path, "rb") as file_in: + with output_path.open("wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def replace_values(values: Any, search_value: Any, replace_value: Any) -> Any: + """ + Create a copy of ``values`` where any elements equal to ``search_value`` are replaced by ``replace_value``. + + Parameters + ---------- + values : Any + The input object to iterate through. + search_value : Any + An object to match for in ``values``. + replace_value : Any + A replacement object for the matched value in ``values``. + + Returns + ------- + Any + values type object with ``search_value`` terms replaced by ``replace_value``. 
+ """ + # np.all allows for arrays to be evaluated + if np.all(values == search_value): + return replace_value + if isinstance(values, dict): + return { + key: replace_values(val, search_value, replace_value) for key, val in values.items() + } + if isinstance( + values, (tuple, list) + ): # Parts of the nested dict structure include tuples with more dicts + return type(values)(replace_values(val, search_value, replace_value) for val in values) + + # Used to maintain values that are not search_value or containers + return values diff --git a/tidy3d/_common/components/geometry/__init__.py b/tidy3d/_common/components/geometry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/geometry/base.py b/tidy3d/_common/components/geometry/base.py new file mode 100644 index 0000000000..6871739c65 --- /dev/null +++ b/tidy3d/_common/components/geometry/base.py @@ -0,0 +1,3718 @@ +"""Abstract base classes for geometry.""" + +from __future__ import annotations + +import functools +import pathlib +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import autograd.numpy as np +import shapely +from pydantic import Field, NonNegativeFloat, field_validator, model_validator + +from tidy3d._common.compat import _package_is_older_than +from tidy3d._common.components.autograd import TracedCoordinate, TracedFloat, TracedSize, get_static +from tidy3d._common.components.base import Tidy3dBaseModel, cached_property +from tidy3d._common.components.geometry.bound_ops import bounds_intersection, bounds_union +from tidy3d._common.components.geometry.float_utils import increment_float +from tidy3d._common.components.transformation import ReflectionFromPlane, RotationAroundAxis +from tidy3d._common.components.types.base import ( + Axis, + ClipOperationType, + MatrixReal4x4, + PlanePosition, + discriminated_union, +) +from tidy3d._common.components.viz import ( + ARROW_LENGTH, + PLOT_BUFFER, + add_ax_if_none, + 
arrow_style, + equal_aspect, + plot_params_geometry, + polygon_patch, + set_default_labels_and_title, +) +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER, RADIAN, fp_eps, inf +from tidy3d._common.exceptions import ( + SetupError, + Tidy3dError, + Tidy3dImportError, + Tidy3dKeyError, + ValidationError, +) +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from os import PathLike + from typing import Callable, Union + + import pydantic + from gdstk import Cell + from matplotlib.backend_bases import Event + from matplotlib.patches import FancyArrowPatch + from numpy.typing import ArrayLike, NDArray + from pydantic import NonNegativeInt, PositiveFloat + from typing_extensions import Self + + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import ( + ArrayFloat2D, + ArrayFloat3D, + Ax, + Bound, + Coordinate, + Coordinate2D, + LengthUnit, + Shapely, + Size, + ) + from tidy3d._common.components.viz import PlotParams, VisualizationSpec + +try: + from matplotlib import patches +except ImportError: + pass + +POLY_GRID_SIZE = 1e-12 +POLY_TOLERANCE_RATIO = 1e-12 +POLY_DISTANCE_TOLERANCE = 8e-12 + + +_shapely_operations = { + "union": shapely.union, + "intersection": shapely.intersection, + "difference": shapely.difference, + "symmetric_difference": shapely.symmetric_difference, +} + +_bit_operations = { + "union": lambda a, b: a | b, + "intersection": lambda a, b: a & b, + "difference": lambda a, b: a & ~b, + "symmetric_difference": lambda a, b: a != b, +} + + +class Geometry(Tidy3dBaseModel, ABC): + """Abstract base class, defines where something exists in space.""" + + @cached_property + def plot_params(self) -> PlotParams: + """Default parameters for plotting a Geometry object.""" + return 
plot_params_geometry + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + + def point_inside(x: float, y: float, z: float) -> bool: + """Returns ``True`` if a single point ``(x, y, z)`` is inside.""" + shapes_intersect = self.intersections_plane(z=z) + loc = self.make_shapely_point(x, y) + return any(shape.contains(loc) for shape in shapes_intersect) + + arrays = tuple(map(np.array, (x, y, z))) + self._ensure_equal_shape(*arrays) + inside = np.zeros((arrays[0].size,), dtype=bool) + arrays_flat = map(np.ravel, arrays) + for ipt, args in enumerate(zip(*arrays_flat)): + inside[ipt] = point_inside(*args) + return inside.reshape(arrays[0].shape) + + @staticmethod + def _ensure_equal_shape(*arrays: Any) -> None: + """Ensure all input arrays have the same shape.""" + shapes = {np.array(arr).shape for arr in arrays} + if len(shapes) > 1: + raise ValueError("All coordinate inputs (x, y, z) must have the same shape.") + + @staticmethod + def make_shapely_box(minx: float, miny: float, maxx: float, maxy: float) -> shapely.box: + """Make a shapely box ensuring everything untraced.""" + + minx = get_static(minx) + miny = get_static(miny) + maxx = get_static(maxx) + maxy = get_static(maxy) + + return shapely.box(minx, miny, maxx, maxy) + + @staticmethod + def make_shapely_point(minx: float, miny: float) -> shapely.Point: + """Make a shapely Point ensuring everything 
untraced.""" + + minx = get_static(minx) + miny = get_static(miny) + + return shapely.Point(minx, miny) + + def _inds_inside_bounds( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> tuple[slice, slice, slice]: + """Return slices into the sorted input arrays that are inside the geometry bounds. + + Parameters + ---------- + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + tuple[slice, slice, slice] + Slices into each of the three arrays that are inside the geometry bounds. + """ + bounds = self.bounds + inds_in = [] + for dim, coords in enumerate([x, y, z]): + inds = np.nonzero((bounds[0][dim] <= coords) * (coords <= bounds[1][dim]))[0] + inds_in.append(slice(0, 0) if inds.size == 0 else slice(inds[0], inds[-1] + 1)) + + return tuple(inds_in) + + def inside_meshgrid( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> NDArray[bool]: + """Perform ``self.inside`` on a set of sorted 1D coordinates. Applies meshgrid to the + supplied coordinates before checking inside. + + Parameters + ---------- + + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every + point that is inside the geometry. 
+ """ + + arrays = tuple(map(np.array, (x, y, z))) + if any(arr.ndim != 1 for arr in arrays): + raise ValueError("Each of the supplied coordinates (x, y, z) must be 1D.") + shape = tuple(arr.size for arr in arrays) + is_inside = np.zeros(shape, dtype=bool) + inds_inside = self._inds_inside_bounds(*arrays) + coords_inside = tuple(arr[ind] for ind, arr in zip(inds_inside, arrays)) + coords_3d = np.meshgrid(*coords_inside, indexing="ij") + is_inside[inds_inside] = self.inside(*coords_3d) + return is_inside + + @abstractmethod + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. 
+ y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + origin = self.unpop_axis(position, (0, 0), axis=axis) + normal = self.unpop_axis(1, (0, 0), axis=axis) + to_2D = np.eye(4) + if axis != 2: + last, indices = self.pop_axis((0, 1, 2), axis) + to_2D = to_2D[[*list(indices), last, 3]] + return self.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + + def intersections_2dbox(self, plane: Box) -> list[Shapely]: + """Returns list of shapely geometries representing the intersections of the geometry with + a 2D box. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. For more details refer to + `Shapely's Documentation `_. + """ + log.warning( + "'intersections_2dbox()' is deprecated and will be removed in the future. " + "Use 'plane.intersections_with(...)' for the same functionality." + ) + return plane.intersections_with(self) + + def intersects( + self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] + ) -> bool: + """Returns ``True`` if two :class:`Geometry` have intersecting `.bounds`. + + Parameters + ---------- + other : :class:`Geometry` + Geometry to check intersection with. 
+ strict_inequality : tuple[bool, bool, bool] = [False, False, False] + For each dimension, defines whether to include equality in the boundaries comparison. + If ``False``, equality is included, and two geometries that only intersect at their + boundaries will evaluate as ``True``. If ``True``, such geometries will evaluate as + ``False``. + + Returns + ------- + bool + Whether the rectangular bounding boxes of the two geometries intersect. + """ + + self_bmin, self_bmax = self.bounds + other_bmin, other_bmax = other.bounds + + for smin, omin, smax, omax, strict in zip( + self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality + ): + # are all of other's minimum coordinates less than self's maximum coordinate? + in_minus = omin < smax if strict else omin <= smax + # are all of other's maximum coordinates greater than self's minimum coordinate? + in_plus = omax > smin if strict else omax >= smin + + # if either failed, return False + if not all((in_minus, in_plus)): + return False + + return True + + def contains( + self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] + ) -> bool: + """Returns ``True`` if the `.bounds` of ``other`` are contained within the + `.bounds` of ``self``. + + Parameters + ---------- + other : :class:`Geometry` + Geometry to check containment with. + strict_inequality : tuple[bool, bool, bool] = [False, False, False] + For each dimension, defines whether to include equality in the boundaries comparison. + If ``False``, equality will be considered as contained. If ``True``, ``other``'s + bounds must be strictly within the bounds of ``self``. + + Returns + ------- + bool + Whether the rectangular bounding box of ``other`` is contained within the bounding + box of ``self``. 
+ """ + + self_bmin, self_bmax = self.bounds + other_bmin, other_bmax = other.bounds + + for smin, omin, smax, omax, strict in zip( + self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality + ): + # are all of other's minimum coordinates greater than self's minimim coordinate? + in_minus = omin > smin if strict else omin >= smin + # are all of other's maximum coordinates less than self's maximum coordinate? + in_plus = omax < smax if strict else omax <= smax + + # if either failed, return False + if not all((in_minus, in_plus)): + return False + + return True + + def intersects_plane( + self, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None + ) -> bool: + """Whether self intersects plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + + Returns + ------- + bool + Whether this geometry intersects the plane. + """ + + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + return self.intersects_axis_position(axis, position) + + def intersects_axis_position(self, axis: int, position: float) -> bool: + """Whether self intersects plane specified by a given position along a normal axis. + + Parameters + ---------- + axis : int = None + Axis normal to the plane. + position : float = None + Position of plane along the normal axis. + + Returns + ------- + bool + Whether this geometry intersects the plane. + """ + return self.bounds[0][axis] <= position <= self.bounds[1][axis] + + @cached_property + @abstractmethod + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. 
+ + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + + @staticmethod + def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the intersection of two bounds.""" + return bounds_intersection(bounds1, bounds2) + + @staticmethod + def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the union of two bounds.""" + return bounds_union(bounds1, bounds2) + + @cached_property + def bounding_box(self) -> Box: + """Returns :class:`Box` representation of the bounding box of a :class:`Geometry`. + + Returns + ------- + :class:`Box` + Geometric object representing bounding box. + """ + return Box.from_bounds(*self.bounds) + + @cached_property + def zero_dims(self) -> list[Axis]: + """A list of axes along which the :class:`Geometry` is zero-sized based on its bounds.""" + zero_dims = [] + for dim in range(3): + if self.bounds[1][dim] == self.bounds[0][dim]: + zero_dims.append(dim) + return zero_dims + + def _pop_bounds(self, axis: Axis) -> tuple[Coordinate2D, tuple[Coordinate2D, Coordinate2D]]: + """Returns min and max bounds in plane normal to and tangential to ``axis``. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + tuple[float, float], tuple[tuple[float, float], tuple[float, float]] + Bounds along axis and a tuple of bounds in the ordered planar coordinates. + Packed as ``(zmin, zmax), ((xmin, ymin), (xmax, ymax))``. 
+ """ + b_min, b_max = self.bounds + zmin, (xmin, ymin) = self.pop_axis(b_min, axis=axis) + zmax, (xmax, ymax) = self.pop_axis(b_max, axis=axis) + return (zmin, zmax), ((xmin, ymin), (xmax, ymax)) + + @staticmethod + def _get_center(pt_min: float, pt_max: float) -> float: + """Returns center point based on bounds along dimension.""" + if np.isneginf(pt_min) and np.isposinf(pt_max): + return 0.0 + if np.isneginf(pt_min) or np.isposinf(pt_max): + raise SetupError( + f"Bounds of ({pt_min}, {pt_max}) supplied along one dimension. " + "We currently don't support a single ``inf`` value in bounds for ``Box``. " + "To construct a semi-infinite ``Box``, " + "please supply a large enough number instead of ``inf``. " + "For example, a location extending outside of the " + "Simulation domain (including PML)." + ) + return (pt_min + pt_max) / 2.0 + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + raise ValidationError("'Medium2D' is not compatible with this geometry class.") + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geometry: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + raise NotImplementedError( + "'_update_from_bounds' is not compatible with this geometry class." + ) + + @equal_aspect + @add_ax_if_none + def plot( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + ax: Ax = None, + plot_length_units: LengthUnit = None, + viz_spec: VisualizationSpec = None, + **patch_kwargs: Any, + ) -> Ax: + """Plot geometry cross section at single (x,y,z) coordinate. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. 
+ z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + plot_length_units : LengthUnit = None + Specify units to use for axis labels, tick labels, and the title. + viz_spec : VisualizationSpec = None + Plotting parameters associated with a medium to use instead of defaults. + **patch_kwargs + Optional keyword arguments passed to the matplotlib patch plotting of structure. + For details on accepted values, refer to + `Matplotlib's documentation `_. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + # find shapes that intersect self at plane + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + shapes_intersect = self.intersections_plane(x=x, y=y, z=z) + + plot_params = self.plot_params + if viz_spec is not None: + plot_params = plot_params.override_with_viz_spec(viz_spec) + plot_params = plot_params.include_kwargs(**patch_kwargs) + + # for each intersection, plot the shape + for shape in shapes_intersect: + ax = self.plot_shape(shape, plot_params=plot_params, ax=ax) + + # clean up the axis display + ax = self.add_ax_lims(axis=axis, ax=ax) + ax.set_aspect("equal") + # Add the default axis labels, tick labels, and title + ax = Box.add_ax_labels_and_title(ax=ax, x=x, y=y, z=z, plot_length_units=plot_length_units) + return ax + + def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax: + """Defines how a shape is plotted on a matplotlib axes.""" + if shape.geom_type in ( + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", + ): + for sub_shape in shape.geoms: + ax = self.plot_shape(shape=sub_shape, plot_params=plot_params, ax=ax) + + return ax + + _shape = Geometry.evaluate_inf_shape(shape) + + if _shape.geom_type == "LineString": + xs, ys = zip(*_shape.coords) + ax.plot(xs, ys, 
color=plot_params.facecolor, linewidth=plot_params.linewidth) + elif _shape.geom_type == "Point": + ax.scatter(shape.x, shape.y, color=plot_params.facecolor) + else: + patch = polygon_patch(_shape, **plot_params.to_kwargs()) + ax.add_artist(patch) + return ax + + @staticmethod + def _do_not_intersect( + bounds_a: float, bounds_b: float, shape_a: Shapely, shape_b: Shapely + ) -> bool: + """Check whether two shapes intersect.""" + + # do a bounding box check to see if any intersection to do anything about + if ( + bounds_a[0] > bounds_b[2] + or bounds_b[0] > bounds_a[2] + or bounds_a[1] > bounds_b[3] + or bounds_b[1] > bounds_a[3] + ): + return True + + # look more closely to see if intersected. + if shape_b.is_empty or not shape_a.intersects(shape_b): + return True + + return False + + @staticmethod + def _get_plot_labels(axis: Axis) -> tuple[str, str]: + """Returns planar coordinate x and y axis labels for cross section plots. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + str, str + Labels of plot, packaged as ``(xlabel, ylabel)``. + """ + _, (xlabel, ylabel) = Geometry.pop_axis("xyz", axis=axis) + return xlabel, ylabel + + def _get_plot_limits( + self, axis: Axis, buffer: float = PLOT_BUFFER + ) -> tuple[Coordinate2D, Coordinate2D]: + """Gets planar coordinate limits for cross section plots. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + buffer : float = 0.3 + Amount of space to add around the limits on the + and - sides. + + Returns + ------- + tuple[float, float], tuple[float, float] + The x and y plot limits, packed as ``(xmin, xmax), (ymin, ymax)``. + """ + _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) + return (xmin - buffer, xmax + buffer), (ymin - buffer, ymax + buffer) + + def add_ax_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax: + """Sets the x,y limits based on ``self.bounds``. 
+ + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + ax : matplotlib.axes._subplots.Axes + Matplotlib axes to add labels and limits on. + buffer : float = 0.3 + Amount of space to place around the limits on the + and - sides. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + (xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer) + + # note: axes limits dont like inf values, so we need to evaluate them first if present + xmin, xmax, ymin, ymax = self._evaluate_inf((xmin, xmax, ymin, ymax)) + + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + return ax + + @staticmethod + def add_ax_labels_and_title( + ax: Ax, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + plot_length_units: LengthUnit = None, + ) -> Ax: + """Sets the axis labels, tick labels, and title based on ``axis`` + and an optional ``plot_length_units`` argument. + + Parameters + ---------- + ax : matplotlib.axes._subplots.Axes + Matplotlib axes to add labels and limits on. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + plot_length_units : LengthUnit = None + When set to a supported ``LengthUnit``, plots will be produced with annotated axes + and title with the proper units. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied matplotlib axes. 
+ """ + axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z) + axis_labels = Box._get_plot_labels(axis) + ax = set_default_labels_and_title( + axis_labels=axis_labels, + axis=axis, + position=position, + ax=ax, + plot_length_units=plot_length_units, + ) + return ax + + @staticmethod + def _evaluate_inf(array: ArrayLike) -> NDArray[np.floating]: + """Processes values and evaluates any infs into large (signed) numbers.""" + array = get_static(np.array(array)) + return np.where(np.isinf(array), np.sign(array) * LARGE_NUMBER, array) + + @staticmethod + def evaluate_inf_shape(shape: Shapely) -> Shapely: + """Returns a copy of shape with inf vertices replaced by large numbers if polygon.""" + if not any(np.isinf(b) for b in shape.bounds): + return shape + + def _processed_coords(coords: Sequence[tuple[Any, ...]]) -> list[tuple[float, ...]]: + evaluated = Geometry._evaluate_inf(np.array(coords)) + return [tuple(point) for point in evaluated.tolist()] + + if shape.geom_type == "Polygon": + shell = _processed_coords(shape.exterior.coords) + holes = [_processed_coords(g.coords) for g in shape.interiors] + return shapely.Polygon(shell, holes) + if shape.geom_type in {"Point", "LineString", "LinearRing"}: + return shape.__class__(Geometry._evaluate_inf(np.array(shape.coords))) + if shape.geom_type in { + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", + }: + return shape.__class__([Geometry.evaluate_inf_shape(g) for g in shape.geoms]) + return shape + + @staticmethod + def pop_axis(coord: tuple[Any, Any, Any], axis: int) -> tuple[Any, tuple[Any, Any]]: + """Separates coordinate at ``axis`` index from coordinates on the plane tangent to ``axis``. + + Parameters + ---------- + coord : tuple[Any, Any, Any] + Tuple of three values in original coordinate system. + axis : int + Integer index into 'xyz' (0,1,2). 
+ + Returns + ------- + Any, tuple[Any, Any] + The input coordinates are separated into the one along the axis provided + and the two on the planar coordinates, + like ``axis_coord, (planar_coord1, planar_coord2)``. + """ + plane_vals = list(coord) + axis_val = plane_vals.pop(axis) + return axis_val, tuple(plane_vals) + + @staticmethod + def unpop_axis(ax_coord: Any, plane_coords: tuple[Any, Any], axis: int) -> tuple[Any, Any, Any]: + """Combine coordinate along axis with coordinates on the plane tangent to the axis. + + Parameters + ---------- + ax_coord : Any + Value along axis direction. + plane_coords : tuple[Any, Any] + Values along ordered planar directions. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + tuple[Any, Any, Any] + The three values in the xyz coordinate system. + """ + coords = list(plane_coords) + coords.insert(axis, ax_coord) + return tuple(coords) + + @staticmethod + def parse_xyz_kwargs(**xyz: Any) -> tuple[Axis, float]: + """Turns x,y,z kwargs into index of the normal axis and position along that axis. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + + Returns + ------- + int, float + Index into xyz axis (0,1,2) and position along that axis. + """ + xyz_filtered = {k: v for k, v in xyz.items() if v is not None} + if len(xyz_filtered) != 1: + raise ValueError("exactly one kwarg in [x,y,z] must be specified.") + axis_label, position = list(xyz_filtered.items())[0] + axis = "xyz".index(axis_label) + return axis, position + + @staticmethod + def parse_two_xyz_kwargs(**xyz: Any) -> list[tuple[Axis, float]]: + """Turns x,y,z kwargs into indices of axes and the position along each axis. 
+ + Parameters + ---------- + x : float = None + Position in x direction, only two of x,y,z can be specified to define line. + y : float = None + Position in y direction, only two of x,y,z can be specified to define line. + z : float = None + Position in z direction, only two of x,y,z can be specified to define line. + + Returns + ------- + [(int, float), (int, float)] + Index into xyz axis (0,1,2) and position along that axis. + """ + xyz_filtered = {k: v for k, v in xyz.items() if v is not None} + assert len(xyz_filtered) == 2, "exactly two kwarg in [x,y,z] must be specified." + xyz_list = list(xyz_filtered.items()) + return [("xyz".index(axis_label), position) for axis_label, position in xyz_list] + + @staticmethod + def rotate_points(points: ArrayFloat3D, axis: Coordinate, angle: float) -> ArrayFloat3D: + """Rotate a set of points in 3D. + + Parameters + ---------- + points : ArrayLike[float] + Array of shape ``(3, ...)``. + axis : Coordinate + Axis of rotation + angle : float + Angle of rotation counter-clockwise around the axis (rad). + """ + rotation = RotationAroundAxis(axis=axis, angle=angle) + return rotation.rotate_vector(points) + + def reflect_points( + self, + points: ArrayFloat3D, + polar_axis: Axis, + angle_theta: float, + angle_phi: float, + ) -> ArrayFloat3D: + """Reflect a set of points in 3D at a plane passing through the coordinate origin defined + and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the + ``polar_axis`` which can be 0, 1, or 2. + + Parameters + ---------- + points : ArrayLike[float] + Array of shape ``(3, ...)``. + polar_axis : Axis + Cartesian axis w.r.t. which the normal axis angles are defined. + angle_theta : float + Polar angle w.r.t. the polar axis. + angle_phi : float + Azimuth angle around the polar axis. 
+ """ + + # Rotate such that the plane normal is along the polar_axis + axis_theta, axis_phi = [0, 0, 0], [0, 0, 0] + axis_phi[polar_axis] = 1 + plane_axes = [0, 1, 2] + plane_axes.pop(polar_axis) + axis_theta[plane_axes[1]] = 1 + points_new = self.rotate_points(points, axis_phi, -angle_phi) + points_new = self.rotate_points(points_new, axis_theta, -angle_theta) + + # Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane + points_new[polar_axis, :] *= -1 + + # Rotate back + points_new = self.rotate_points(points_new, axis_theta, angle_theta) + points_new = self.rotate_points(points_new, axis_phi, angle_phi) + + return points_new + + def volume(self, bounds: Bound = None) -> float: + """Returns object's volume with optional bounds. + + Parameters + ---------- + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + float + Volume in um^3. + """ + + if not bounds: + bounds = self.bounds + + return self._volume(bounds) + + @abstractmethod + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + def surface_area(self, bounds: Bound = None) -> float: + """Returns object's surface area with optional bounds. + + Parameters + ---------- + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + float + Surface area in um^2. + """ + + if not bounds: + bounds = self.bounds + + return self._surface_area(bounds) + + @abstractmethod + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + def translated(self, x: float, y: float, z: float) -> Geometry: + """Return a translated copy of this geometry. + + Parameters + ---------- + x : float + Translation along x. + y : float + Translation along y. 
+ z : float + Translation along z. + + Returns + ------- + :class:`Geometry` + Translated copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.translation(x, y, z)) + + def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> Geometry: + """Return a scaled copy of this geometry. + + Parameters + ---------- + x : float = 1.0 + Scaling factor along x. + y : float = 1.0 + Scaling factor along y. + z : float = 1.0 + Scaling factor along z. + + Returns + ------- + :class:`Geometry` + Scaled copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.scaling(x, y, z)) + + def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> Geometry: + """Return a rotated copy of this geometry. + + Parameters + ---------- + angle : float + Rotation angle (in radians). + axis : Union[int, tuple[float, float, float]] + Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. + + Returns + ------- + :class:`Geometry` + Rotated copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.rotation(angle, axis)) + + def reflected(self, normal: Coordinate) -> Geometry: + """Return a reflected copy of this geometry. + + Parameters + ---------- + normal : tuple[float, float, float] + The 3D normal vector of the plane of reflection. The plane is assumed + to pass through the origin (0,0,0). + + Returns + ------- + :class:`Geometry` + Reflected copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.reflection(normal)) + + """ Field and coordinate transformations """ + + @staticmethod + def car_2_sph(x: float, y: float, z: float) -> tuple[float, float, float]: + """Convert Cartesian to spherical coordinates. + + Parameters + ---------- + x : float + x coordinate relative to ``local_origin``. + y : float + y coordinate relative to ``local_origin``. + z : float + z coordinate relative to ``local_origin``. 
        Returns
        -------
        tuple[float, float, float]
            r, theta, and phi coordinates relative to ``local_origin``.
        """
        r = np.sqrt(x**2 + y**2 + z**2)
        # NOTE(review): r == 0 (point at the origin) makes z / r divide by zero,
        # so theta becomes nan there — confirm callers never pass the origin.
        theta = np.arccos(z / r)
        phi = np.arctan2(y, x)
        return r, theta, phi

    @staticmethod
    def sph_2_car(r: float, theta: float, phi: float) -> tuple[float, float, float]:
        """Convert spherical to Cartesian coordinates.

        Parameters
        ----------
        r : float
            radius.
        theta : float
            polar angle (rad) downward from x=y=0 line.
        phi : float
            azimuthal (rad) angle from y=z=0 line.

        Returns
        -------
        tuple[float, float, float]
            x, y, and z coordinates relative to ``local_origin``.
        """
        r_sin_theta = r * np.sin(theta)
        x = r_sin_theta * np.cos(phi)
        y = r_sin_theta * np.sin(phi)
        z = r * np.cos(theta)
        return x, y, z

    @staticmethod
    def sph_2_car_field(
        f_r: float, f_theta: float, f_phi: float, theta: float, phi: float
    ) -> tuple[complex, complex, complex]:
        """Convert vector field components in spherical coordinates to cartesian.

        Parameters
        ----------
        f_r : float
            radial component of the vector field.
        f_theta : float
            polar angle component of the vector field.
        f_phi : float
            azimuthal angle component of the vector field.
        theta : float
            polar angle (rad) of location of the vector field.
        phi : float
            azimuthal angle (rad) of location of the vector field.

        Returns
        -------
        tuple[float, float, float]
            x, y, and z components of the vector field in cartesian coordinates.
        """
        # Standard spherical-to-Cartesian rotation of the local basis vectors
        # (r_hat, theta_hat, phi_hat) evaluated at (theta, phi).
        sin_theta = np.sin(theta)
        cos_theta = np.cos(theta)
        sin_phi = np.sin(phi)
        cos_phi = np.cos(phi)
        f_x = f_r * sin_theta * cos_phi + f_theta * cos_theta * cos_phi - f_phi * sin_phi
        f_y = f_r * sin_theta * sin_phi + f_theta * cos_theta * sin_phi + f_phi * cos_phi
        f_z = f_r * cos_theta - f_theta * sin_theta
        return f_x, f_y, f_z

    @staticmethod
    def car_2_sph_field(
        f_x: float, f_y: float, f_z: float, theta: float, phi: float
    ) -> tuple[complex, complex, complex]:
        """Convert vector field components in cartesian coordinates to spherical.

        Parameters
        ----------
        f_x : float
            x component of the vector field.
        f_y : float
            y component of the vector field.
        f_z : float
            z component of the vector field.
        theta : float
            polar angle (rad) of location of the vector field.
        phi : float
            azimuthal angle (rad) of location of the vector field.

        Returns
        -------
        tuple[float, float, float]
            radial (s), elevation (theta), and azimuthal (phi) components
            of the vector field in spherical coordinates.
        """
        # Inverse rotation of ``sph_2_car_field`` (transpose of the same matrix).
        sin_theta = np.sin(theta)
        cos_theta = np.cos(theta)
        sin_phi = np.sin(phi)
        cos_phi = np.cos(phi)
        f_r = f_x * sin_theta * cos_phi + f_y * sin_theta * sin_phi + f_z * cos_theta
        f_theta = f_x * cos_theta * cos_phi + f_y * cos_theta * sin_phi - f_z * sin_theta
        f_phi = -f_x * sin_phi + f_y * cos_phi
        return f_r, f_theta, f_phi

    @staticmethod
    def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]:
        """Convert normalized k-space coordinates to angles.

        Parameters
        ----------
        ux : float
            normalized kx coordinate.
        uy : float
            normalized ky coordinate.
        axis : int
            axis along which the observation plane is oriented.

        Returns
        -------
        tuple[float, float]
            theta and phi coordinates relative to ``local_origin``.
+ """ + phi_local = np.arctan2(uy, ux) + with np.errstate(invalid="ignore"): + theta_local = np.arcsin(np.sqrt(ux**2 + uy**2)) + # Spherical coordinates rotation matrix reference: + # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation + if axis == 2: + return theta_local, phi_local + + x = np.cos(theta_local) + y = np.sin(theta_local) * np.cos(phi_local) + z = np.sin(theta_local) * np.sin(phi_local) + + if axis == 1: + x, y, z = y, x, z + + theta = np.arccos(z) + phi = np.arctan2(y, x) + return theta, phi + + @staticmethod + @verify_packages_import(["gdstk"]) + def load_gds_vertices_gdstk( + gds_cell: Cell, + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + ) -> list[ArrayFloat2D]: + """Load polygon vertices from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into + the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of micrometer. For example, if gds file uses + nanometers, set ``gds_scale=1e-3``. Must be positive. 
+ + Returns + ------- + list[ArrayFloat2D] + List of polygon vertices + """ + + # apply desired scaling and load the polygon vertices + if gds_dtype is not None: + # if both layer and datatype are specified, let gdstk do the filtering for better + # performance on large layouts + all_vertices = [ + polygon.scale(gds_scale).points + for polygon in gds_cell.get_polygons(layer=gds_layer, datatype=gds_dtype) + ] + else: + all_vertices = [ + polygon.scale(gds_scale).points + for polygon in gds_cell.get_polygons() + if polygon.layer == gds_layer + ] + # make sure something got loaded, otherwise error + if not all_vertices: + raise Tidy3dKeyError( + f"Couldn't load gds_cell, no vertices found at gds_layer={gds_layer} " + f"with specified gds_dtype={gds_dtype}." + ) + + return all_vertices + + @staticmethod + @verify_packages_import(["gdstk"]) + def from_gds( + gds_cell: Cell, + axis: Axis, + slab_bounds: tuple[float, float], + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> Geometry: + """Import a ``gdstk.Cell`` and extrude it into a GeometryGroup. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + axis : int + Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). + slab_bounds: tuple[float, float] + Minimal and maximal positions of the extruded slab along ``axis``. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into + the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of micrometer. For example, if gds file uses + nanometers, set ``gds_scale=1e-3``. Must be positive. + dilation : float = 0.0 + Dilation (positive) or erosion (negative) amount to be applied to the original polygons. 
+ sidewall_angle : float = 0 + Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive + (negative) values result in slabs larger (smaller) at the base than at the top. + reference_plane : PlanePosition = "middle" + Reference position of the (dilated/eroded) polygons along the slab axis. One of + ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` + (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value + has no effect if ``sidewall_angle == 0``. + + Returns + ------- + :class:`Geometry` + Geometries created from the 2D data. + """ + import gdstk + + if not isinstance(gds_cell, gdstk.Cell): + # Check if it might be a gdstk cell but gdstk is not found (should be caught by decorator) + # or if it's an entirely different type. + if "gdstk" in gds_cell.__class__.__name__.lower(): + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to import gdstk cells." + ) + raise Tidy3dImportError("Argument 'gds_cell' must be an instance of 'gdstk.Cell'.") + + gds_loader_fn = Geometry.load_gds_vertices_gdstk + geometries = [] + with log as consolidated_logger: + for vertices in gds_loader_fn(gds_cell, gds_layer, gds_dtype, gds_scale): + # buffer(0) is necessary to merge self-intersections + shape = shapely.set_precision(shapely.Polygon(vertices).buffer(0), POLY_GRID_SIZE) + try: + geometries.append( + from_shapely( + shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane + ) + ) + except ValidationError as error: + consolidated_logger.warning(str(error)) + except Tidy3dError as error: + consolidated_logger.warning(str(error)) + return geometries[0] if len(geometries) == 1 else GeometryGroup(geometries=geometries) + + @staticmethod + def from_shapely( + shape: Shapely, + axis: Axis, + slab_bounds: tuple[float, float], + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> Geometry: + """Convert a shapely 
primitive into a geometry instance by extrusion. + + Parameters + ---------- + shape : shapely.geometry.base.BaseGeometry + Shapely primitive to be converted. It must be a linear ring, a polygon or a collection + of any of those. + axis : int + Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). + slab_bounds: tuple[float, float] + Minimal and maximal positions of the extruded slab along ``axis``. + dilation : float + Dilation of the polygon in the base by shifting each edge along its normal outwards + direction by a distance; a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive + (negative) values result in slabs larger (smaller) at the base than at the top. + reference_plane : PlanePosition = "middle" + Reference position of the (dilated/eroded) polygons along the slab axis. One of + ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` + (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value + has no effect if ``sidewall_angle == 0``. + + Returns + ------- + :class:`Geometry` + Geometry extruded from the 2D data. + """ + return from_shapely(shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane) + + @verify_packages_import(["gdstk"]) + def to_gdstk( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> list: + """Convert a Geometry object's planar slice to a .gds type polygon. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. 
+ gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + + Return + ------ + List + List of `gdstk.Polygon`. + """ + import gdstk + + shapes = self.intersections_plane(x=x, y=y, z=z) + polygons = [] + for shape in shapes: + for vertices in vertices_from_shapely(shape): + if len(vertices) == 1: + polygons.append(gdstk.Polygon(vertices[0], gds_layer, gds_dtype)) + else: + polygons.extend( + gdstk.boolean( + vertices[:1], + vertices[1:], + "not", + layer=gds_layer, + datatype=gds_dtype, + ) + ) + return polygons + + @verify_packages_import(["gdstk"]) + def to_gds( + self, + cell: Cell, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> None: + """Append a Geometry object's planar slice to a .gds cell. + + Parameters + ---------- + cell : ``gdstk.Cell`` + Cell object to which the generated polygons are added. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + """ + import gdstk + + if not isinstance(cell, gdstk.Cell): + if "gdstk" in cell.__class__.__name__.lower(): + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to export shapes to gdstk cells." 
+ ) + raise Tidy3dImportError("Argument 'cell' must be an instance of 'gdstk.Cell'.") + + polygons = self.to_gdstk(x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) + if polygons: + cell.add(*polygons) + + @verify_packages_import(["gdstk"]) + def to_gds_file( + self, + fname: PathLike, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + gds_cell_name: str = "MAIN", + ) -> None: + """Export a Geometry object's planar slice to a .gds file. + + Parameters + ---------- + fname : PathLike + Full path to the .gds file to save the :class:`Geometry` slice to. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + gds_cell_name : str = 'MAIN' + Name of the cell created in the .gds file to store the geometry. + """ + try: + import gdstk + except ImportError as e: + raise Tidy3dImportError( + "Python module 'gdstk' not found. To export geometries to .gds " + "files, please install it." 
            ) from e

        library = gdstk.Library()
        cell = library.new_cell(gds_cell_name)
        self.to_gds(cell, x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype)
        # Create parent directories as needed before writing the .gds file.
        fname = pathlib.Path(fname)
        fname.parent.mkdir(parents=True, exist_ok=True)
        library.write_gds(fname)

    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
        """Compute the adjoint derivatives for this object."""
        raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.")

    def _as_union(self) -> tuple[Geometry, ...]:
        """Return the geometries that, united, make up the given geometry.

        Used by the ``+``/``|`` operators to flatten unions instead of nesting them.
        """
        if isinstance(self, GeometryGroup):
            # NOTE(review): annotation assumes ``geometries`` is a tuple (as the other
            # branches return tuples); a list here would break ``+`` concatenation in
            # ``__add__``/``__radd__``/``__or__`` — confirm the field type.
            return self.geometries

        if isinstance(self, ClipOperation) and self.operation == "union":
            return (self.geometry_a, self.geometry_b)
        return (self,)

    def __add__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]:
        """Union of geometries"""
        # This allows the user to write sum(geometries...) with the default start=0
        if isinstance(other, int):
            return self
        if not isinstance(other, Geometry):
            return NotImplemented  # type: ignore[return-value]
        return GeometryGroup(geometries=self._as_union() + other._as_union())

    def __radd__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]:
        """Union of geometries"""
        # This allows the user to write sum(geometries...) with the default start=0
        if isinstance(other, int):
            return self
        if not isinstance(other, Geometry):
            return NotImplemented  # type: ignore[return-value]
        return GeometryGroup(geometries=other._as_union() + self._as_union())

    def __or__(self, other: Geometry) -> GeometryGroup:
        """Union of geometries"""
        if not isinstance(other, Geometry):
            return NotImplemented
        return GeometryGroup(geometries=self._as_union() + other._as_union())

    def __mul__(self, other: Geometry) -> ClipOperation:
        """Intersection of geometries"""
        if not isinstance(other, Geometry):
            return NotImplemented
        return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other)

    def __and__(self, other: Geometry) -> ClipOperation:
        """Intersection of geometries"""
        if not isinstance(other, Geometry):
            return NotImplemented
        return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other)

    def __sub__(self, other: Geometry) -> ClipOperation:
        """Difference of geometries"""
        if not isinstance(other, Geometry):
            return NotImplemented  # type: ignore[return-value]
        return ClipOperation(operation="difference", geometry_a=self, geometry_b=other)

    def __xor__(self, other: Geometry) -> ClipOperation:
        """Symmetric difference of geometries"""
        if not isinstance(other, Geometry):
            return NotImplemented
        return ClipOperation(operation="symmetric_difference", geometry_a=self, geometry_b=other)

    def __pos__(self) -> Self:
        """No op"""
        return self

    def __neg__(self) -> ClipOperation:
        """Opposite of a geometry: everything outside it (infinite box minus self)."""
        return ClipOperation(
            operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self
        )

    def __invert__(self) -> ClipOperation:
        """Opposite of a geometry: everything outside it (infinite box minus self)."""
        return ClipOperation(
            operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self
        )


""" Abstract subclasses """


class Centered(Geometry, ABC):
    """Geometry with a well defined center."""

    center:
Optional[TracedCoordinate] = Field( + None, + title="Center", + description="Center of object in x, y, and z.", + json_schema_extra={"units": MICROMETER}, + ) + + @field_validator("center", mode="before") + @classmethod + def _center_default(cls, val: Any) -> Any: + """Make sure center is not infinitiy.""" + if val is None: + val = (0.0, 0.0, 0.0) + return val + + @field_validator("center") + @classmethod + def _center_not_inf(cls, val: tuple[float, float, float]) -> tuple[float, float, float]: + """Make sure center is not infinitiy.""" + if any(np.isinf(v) for v in val): + raise ValidationError("center can not contain td.inf terms.") + return val + + +class SimplePlaneIntersection(Geometry, ABC): + """A geometry where intersections with an axis aligned plane may be computed efficiently.""" + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + Checks special cases before relying on the complete computation. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + # Check if normal is a special case, where the normal is aligned with an axis. 
+ if np.sum(np.isclose(normal, 0.0)) == 2: + axis = np.argmax(np.abs(normal)).item() + coord = "xyz"[axis] + kwargs = {coord: origin[axis]} + section = self.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **kwargs) + # Apply transformation in the plane by removing row and column + to_2D_in_plane = np.delete(np.delete(to_2D, 2, 0), axis, 1) + + def transform(p_array: NDArray) -> NDArray: + return np.dot( + np.hstack((p_array, np.ones((p_array.shape[0], 1)))), to_2D_in_plane.T + )[:, :2] + + transformed_section = shapely.transform(section, transformation=transform) + return transformed_section + # Otherwise compute the arbitrary intersection + return self._do_intersections_tilted_plane( + normal=normal, origin=origin, to_2D=to_2D, quad_segs=quad_segs + ) + + @abstractmethod + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + +class Planar(SimplePlaneIntersection, Geometry, ABC): + """Geometry with one ``axis`` that is slab-like with thickness ``height``.""" + + axis: Axis = Field( + 2, + title="Axis", + description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z).", + ) + + sidewall_angle: TracedFloat = Field( + 0.0, + title="Sidewall angle", + description="Angle of the sidewall. 
" + "``sidewall_angle=0`` (default) specifies a vertical wall; " + "``0 float: + lower_bound = -np.pi / 2 + upper_bound = np.pi / 2 + if (val <= lower_bound) or (val >= upper_bound): + # u03C0 is unicode for pi + raise ValidationError(f"Sidewall angle ({val}) must be between -π/2 and π/2 rad.") + return val + + @property + @abstractmethod + def center_axis(self) -> float: + """Gets the position of the center of the geometry in the out of plane dimension.""" + + @property + @abstractmethod + def length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension.""" + + @property + def finite_length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension. + If the length is td.inf, return ``LARGE_NUMBER`` + """ + return min(self.length_axis, LARGE_NUMBER) + + @property + def reference_axis_pos(self) -> float: + """Coordinate along the slab axis at the reference plane. + + Returns the axis coordinate corresponding to the selected + reference_plane: + - "bottom": lower bound of slab_bounds + - "middle": center_axis + - "top": upper bound of slab_bounds + """ + if self.reference_plane == "bottom": + return self.slab_bounds[0] + if self.reference_plane == "top": + return self.slab_bounds[1] + # default to middle + return self.center_axis + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns shapely geometry at plane specified by one non None value of x,y,z. + + Parameters + ---------- + x : float + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float + Position of plane in z direction, only one of x,y,z can be specified to define plane. 
+ cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation ``. + """ + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + if not self.intersects_axis_position(axis, position): + return [] + if axis == self.axis: + return self._intersections_normal(position, quad_segs=quad_segs) + return self._intersections_side(position, axis) + + @abstractmethod + def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list: + """Find shapely geometries intersecting planar geometry with axis normal to slab. + + Parameters + ---------- + z : float + Position along the axis normal to slab + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + @abstractmethod + def _intersections_side(self, position: float, axis: Axis) -> list[Shapely]: + """Find shapely geometries intersecting planar geometry with axis orthogonal to plane. + + Parameters + ---------- + position : float + Position along axis. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + def _order_axis(self, axis: int) -> int: + """Order the axis as if self.axis is along z-direction. + + Parameters + ---------- + axis : int + Integer index into the structure's planar axis. + + Returns + ------- + int + New index of axis. 
+ """ + axis_index = [0, 1] + axis_index.insert(self.axis, 2) + return axis_index[axis] + + def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> tuple[Any, Any]: + """Orders a value in the plane and value along axis in correct (x,y) order for plotting. + Note: sometimes if axis=1 and we compute cross section values orthogonal to axis, + they can either be x or y in the plots. + This function allows one to figure out the ordering. + + Parameters + ---------- + plane_val : Any + The value in the planar coordinate. + axis_val : Any + The value in the ``axis`` coordinate. + axis : int + Integer index into the structure's planar axis. + + Returns + ------- + ``(Any, Any)`` + The two planar coordinates in this new coordinate system. + """ + vals = 3 * [plane_val] + vals[self.axis] = axis_val + _, (val_x, val_y) = self.pop_axis(vals, axis=axis) + return val_x, val_y + + @cached_property + def _tanq(self) -> float: + """Value of ``tan(sidewall_angle)``. + + The (possibliy infinite) geometry offset is given by ``_tanq * length_axis``. + """ + return np.tan(self.sidewall_angle) + + +class Circular(Geometry): + """Geometry with circular characteristics (specified by a radius).""" + + radius: NonNegativeFloat = Field( + title="Radius", + description="Radius of geometry.", + json_schema_extra={"units": MICROMETER}, + ) + + @field_validator("radius") + @classmethod + def _radius_not_inf(cls, val: float) -> float: + """Make sure center is not infinitiy.""" + if np.isinf(val): + raise ValidationError("radius can not be 'td.inf'.") + return val + + def _intersect_dist(self, position: float, z0: float) -> float: + """Distance between points on circle at z=position where center of circle at z=z0. + + Parameters + ---------- + position : float + position along z. + z0 : float + center of circle in z. + + Returns + ------- + float + Distance between points on the circle intersecting z=z, if no points, ``None``. 
+ """ + dz = np.abs(z0 - position) + if dz > self.radius: + return None + return 2 * np.sqrt(self.radius**2 - dz**2) + + +"""Primitive classes""" + + +class Box(SimplePlaneIntersection, Centered): + """Rectangular prism. + Also base class for :class:`.Simulation`, :class:`Monitor`, and :class:`Source`. + + Example + ------- + >>> b = Box(center=(1,2,3), size=(2,2,2)) + """ + + size: TracedSize = Field( + title="Size", + description="Size in x, y, and z directions.", + json_schema_extra={"units": MICROMETER}, + ) + + @classmethod + def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs: Any) -> Self: + """Constructs a :class:`Box` from minimum and maximum coordinate bounds + + Parameters + ---------- + rmin : tuple[float, float, float] + (x, y, z) coordinate of the minimum values. + rmax : tuple[float, float, float] + (x, y, z) coordinate of the maximum values. + + Example + ------- + >>> b = Box.from_bounds(rmin=(-1, -2, -3), rmax=(3, 2, 1)) + """ + + center = tuple(cls._get_center(pt_min, pt_max) for pt_min, pt_max in zip(rmin, rmax)) + size = tuple((pt_max - pt_min) for pt_min, pt_max in zip(rmin, rmax)) + return cls(center=center, size=size, **kwargs) + + @cached_property + def _normal_axis(self) -> Axis: + """Axis normal to the Box. Errors if box is not planar.""" + if self.size.count(0.0) != 1: + raise ValidationError( + f"Tried to get 'normal_axis' of 'Box' that is not planar. Given 'size={self.size}.'" + ) + return self.size.index(0.0) + + @classmethod + def surfaces(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]: + """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume. + The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z + denote which axis is perpendicular to that surface, while "-" and "+" denote the direction + of the normal vector of that surface. 
If a name is provided, each output surface's name + will be that of the provided name appended with the above symbols. E.g., if the provided + name is "box", the x+ surfaces's name will be "box_x+". + + Parameters + ---------- + size : tuple[float, float, float] + Size of object in x, y, and z directions. + center : tuple[float, float, float] + Center of object in x, y, and z. + + Example + ------- + >>> b = Box.surfaces(size=(1, 2, 3), center=(3, 2, 1)) + """ + + if any(s == 0.0 for s in size): + raise SetupError( + "Can't generate surfaces for the given object because it has zero volume." + ) + + bounds = Box(center=center, size=size).bounds + + # Set up geometry data and names for each surface: + centers = [list(center) for _ in range(6)] + sizes = [list(size) for _ in range(6)] + + surface_index = 0 + for dim_index in range(3): + for min_max_index in range(2): + new_center = centers[surface_index] + new_size = sizes[surface_index] + + new_center[dim_index] = bounds[min_max_index][dim_index] + new_size[dim_index] = 0.0 + + centers[surface_index] = new_center + sizes[surface_index] = new_size + + surface_index += 1 + + name_base = kwargs.pop("name", "") + kwargs.pop("normal_dir", None) + + names = [] + normal_dirs = [] + + for coord in "xyz": + for direction in "-+": + surface_name = name_base + "_" + coord + direction + names.append(surface_name) + normal_dirs.append(direction) + + # ignore surfaces that are infinitely far away + del_idx = [] + for idx, _size in enumerate(size): + if _size == inf: + del_idx.append(idx) + del_idx = [[2 * i, 2 * i + 1] for i in del_idx] + del_idx = [item for sublist in del_idx for item in sublist] + + def del_items(items: Iterable, indices: int) -> list: + """Delete list items at indices.""" + return [i for j, i in enumerate(items) if j not in indices] + + centers = del_items(centers, del_idx) + sizes = del_items(sizes, del_idx) + names = del_items(names, del_idx) + normal_dirs = del_items(normal_dirs, del_idx) + + surfaces = [] + 
for _cent, _size, _name, _normal_dir in zip(centers, sizes, names, normal_dirs): + if "normal_dir" in cls.model_fields: + kwargs["normal_dir"] = _normal_dir + + if "name" in cls.model_fields: + kwargs["name"] = _name + + surface = cls(center=_cent, size=_size, **kwargs) + surfaces.append(surface) + + return surfaces + + @classmethod + def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]: + """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume. + The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z + denote which axis is perpendicular to that surface, while "-" and "+" denote the direction + of the normal vector of that surface. If a name is provided, each output surface's name + will be that of the provided name appended with the above symbols. E.g., if the provided + name is "box", the x+ surfaces's name will be "box_x+". If ``kwargs`` contains an + ``exclude_surfaces`` parameter, the returned list of surfaces will not include the excluded + surfaces. Otherwise, the behavior is identical to that of ``surfaces()``. + + Parameters + ---------- + size : tuple[float, float, float] + Size of object in x, y, and z directions. + center : tuple[float, float, float] + Center of object in x, y, and z. + + Example + ------- + >>> b = Box.surfaces_with_exclusion( + ... size=(1, 2, 3), center=(3, 2, 1), exclude_surfaces=["x-"] + ... 
) + """ + exclude_surfaces = kwargs.pop("exclude_surfaces", None) + surfaces = cls.surfaces(size=size, center=center, **kwargs) + if "name" in cls.model_fields and exclude_surfaces: + surfaces = [surf for surf in surfaces if surf.name[-2:] not in exclude_surfaces] + return surfaces + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for Box geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + import trimesh + + (x0, y0, z0), (x1, y1, z1) = self.bounds + vertices = [ + (x0, y0, z0), # 0 + (x0, y0, z1), # 1 + (x0, y1, z0), # 2 + (x0, y1, z1), # 3 + (x1, y0, z0), # 4 + (x1, y0, z1), # 5 + (x1, y1, z0), # 6 + (x1, y1, z1), # 7 + ] + faces = [ + (0, 1, 3, 2), # -x + (4, 6, 7, 5), # +x + (0, 4, 5, 1), # -y + (2, 3, 7, 6), # +y + (0, 2, 6, 4), # -z + (1, 5, 7, 3), # +z + ] + mesh = trimesh.Trimesh(vertices, faces) + + section = mesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns shapely geometry at plane specified by one non None value of x,y,z. 
+ + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for Box geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + if not self.intersects_axis_position(axis, position): + return [] + z0, (x0, y0) = self.pop_axis(self.center, axis=axis) + Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis) + dz = np.abs(z0 - position) + if dz > Lz / 2 + fp_eps: + return [] + + minx = x0 - Lx / 2 + miny = y0 - Ly / 2 + maxx = x0 + Lx / 2 + maxy = y0 + Ly / 2 + + # handle case where the box vertices are identical + if np.isclose(minx, maxx) and np.isclose(miny, maxy): + return [self.make_shapely_point(minx, miny)] + + return [self.make_shapely_box(minx, miny, maxx, maxy)] + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. 
+ + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + self._ensure_equal_shape(x, y, z) + x0, y0, z0 = self.center + Lx, Ly, Lz = self.size + dist_x = np.abs(x - x0) + dist_y = np.abs(y - y0) + dist_z = np.abs(z - z0) + return (dist_x <= Lx / 2) * (dist_y <= Ly / 2) * (dist_z <= Lz / 2) + + def intersections_with( + self, other: Shapely, cleanup: bool = True, quad_segs: Optional[int] = None + ) -> list[Shapely]: + """Returns list of shapely geometries representing the intersections of the geometry with + this 2D box. + + Parameters + ---------- + other : Shapely + Geometry to intersect with. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect this 2D box. + For more details refer to + `Shapely's Documentation `_. + """ + + # Verify 2D + if self.size.count(0.0) != 1: + raise ValidationError( + "Intersections with other geometry are only calculated from a 2D box." 
+ ) + + # dont bother if the geometry doesn't intersect the self at all + if not other.intersects(self): + return [] + + # get list of Shapely shapes that intersect at the self + normal_ind = self.size.index(0.0) + dim = "xyz"[normal_ind] + pos = self.center[normal_ind] + xyz_kwargs = {dim: pos} + shapes_plane = other.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **xyz_kwargs) + + # intersect all shapes with the input self + bs_min, bs_max = (self.pop_axis(bounds, axis=normal_ind)[1] for bounds in self.bounds) + + shapely_box = self.make_shapely_box(bs_min[0], bs_min[1], bs_max[0], bs_max[1]) + shapely_box = Geometry.evaluate_inf_shape(shapely_box) + return [Geometry.evaluate_inf_shape(shape) & shapely_box for shape in shapes_plane] + + def slightly_enlarged_copy(self) -> Box: + """Box size slightly enlarged around machine precision.""" + size = [increment_float(orig_length, 1) for orig_length in self.size] + return self.updated_copy(size=size) + + def padded_copy( + self, + x: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, + y: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, + z: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, + ) -> Box: + """Created a padded copy of a :class:`Box` instance. + + Parameters + ---------- + x : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None + Padding sizes at the left and right boundaries of the box along x-axis. + y : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None + Padding sizes at the left and right boundaries of the box along y-axis. + z : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None + Padding sizes at the left and right boundaries of the box along z-axis. + + Returns + ------- + Box + Padded instance of :class:`Box`. 
+ """ + + # Validate that padding values are non-negative + for axis_name, axis_padding in zip(("x", "y", "z"), (x, y, z)): + if axis_padding is not None: + if not isinstance(axis_padding, (tuple, list)) or len(axis_padding) != 2: + raise ValueError(f"Padding for {axis_name}-axis must be a tuple of two values.") + if any(p < 0 for p in axis_padding): + raise ValueError( + f"Padding values for {axis_name}-axis must be non-negative. Got {axis_padding}." + ) + + rmin, rmax = self.bounds + + def bound_array(arrs: ArrayLike, idx: int) -> NDArray: + return np.array([(a[idx] if a is not None else 0) for a in arrs]) + + # parse padding sizes for simulation + drmin = bound_array((x, y, z), 0) + drmax = bound_array((x, y, z), 1) + + rmin = np.array(rmin) - drmin + rmax = np.array(rmax) + drmax + + return Box.from_bounds(rmin=rmin, rmax=rmax) + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + size = self.size + center = self.center + coord_min = tuple(c - s / 2 for (s, c) in zip(size, center)) + coord_max = tuple(c + s / 2 for (s, c) in zip(size, center)) + return (coord_min, coord_max) + + @cached_property + def geometry(self) -> Box: + """:class:`Box` representation of self (used for subclasses of Box). + + Returns + ------- + :class:`Box` + Instance of :class:`Box` representing self's geometry. 
+ """ + return Box(center=self.center, size=self.size) + + @cached_property + def zero_dims(self) -> list[Axis]: + """A list of axes along which the :class:`Box` is zero-sized.""" + return [dim for dim, size in enumerate(self.size) if size == 0] + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if np.count_nonzero(self.size) != 2: + raise ValidationError( + "'Medium2D' requires exactly one of the 'Box' dimensions to have size zero." + ) + return self.size.index(0) + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Box: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_center = list(self.center) + new_center[axis] = (bounds[0] + bounds[1]) / 2 + new_size = list(self.size) + new_size[axis] = bounds[1] - bounds[0] + return self.updated_copy(center=tuple(new_center), size=tuple(new_size)) + + def _plot_arrow( + self, + direction: tuple[float, float, float], + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + color: Optional[str] = None, + alpha: Optional[float] = None, + bend_radius: Optional[float] = None, + bend_axis: Axis = None, + both_dirs: bool = False, + ax: Ax = None, + arrow_base: Coordinate = None, + ) -> Ax: + """Adds an arrow to the axis if with options if certain conditions met. + + Parameters + ---------- + direction: tuple[float, float, float] + Normalized vector describing the arrow direction. + x : float = None + Position of plotting plane in x direction. + y : float = None + Position of plotting plane in y direction. + z : float = None + Position of plotting plane in z direction. + color : str = None + Color of the arrow. + alpha : float = None + Opacity of the arrow (0, 1) + bend_radius : float = None + Radius of curvature for this arrow. + bend_axis : Axis = None + Axis of curvature of ``bend_radius``. 
+ both_dirs : bool = False + If True, plots an arrow pointing in direction and one in -direction. + arrow_base : :class:`.Coordinate` = None + Custom base of the arrow. Uses the geometry's center if not provided. + + Returns + ------- + matplotlib.axes._subplots.Axes + The matplotlib axes with the arrow added. + """ + + plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) + _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) + + # conditions to check to determine whether to plot arrow, taking into account the + # possibility of a custom arrow base + arrow_intersecting_plane = len(self.intersections_plane(x=x, y=y, z=z)) > 0 + center = self.center + if arrow_base: + arrow_intersecting_plane = arrow_intersecting_plane and any( + a == b for a, b in zip(arrow_base, [x, y, z]) + ) + center = arrow_base + + _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) + components_in_plane = any(not np.isclose(component, 0) for component in (dx, dy)) + + # plot if arrow in plotting plane and some non-zero component can be displayed. + if arrow_intersecting_plane and components_in_plane: + _, (x0, y0) = self.pop_axis(center, axis=plot_axis) + + # Reasonable value for temporary arrow size. The correct size and direction + # have to be calculated after all transforms have been set. That is why we + # use a callback to do these calculations only at the drawing phase. 
+ xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + v_x = (xmax - xmin) / 10 + v_y = (ymax - ymin) / 10 + + directions = (1.0, -1.0) if both_dirs else (1.0,) + for sign in directions: + arrow = patches.FancyArrowPatch( + (x0, y0), + (x0 + v_x, y0 + v_y), + arrowstyle=arrow_style, + color=color, + alpha=alpha, + zorder=np.inf, + ) + # Don't draw this arrow until it's been reshaped + arrow.set_visible(False) + + callback = self._arrow_shape_cb( + arrow, (x0, y0), (dx, dy), sign, bend_radius if bend_axis == plot_axis else None + ) + callback_id = ax.figure.canvas.mpl_connect("draw_event", callback) + + # Store a reference to the callback because mpl_connect does not. + arrow.set_shape_cb = (callback_id, callback) + + ax.add_patch(arrow) + + return ax + + @staticmethod + def _arrow_shape_cb( + arrow: FancyArrowPatch, + pos: tuple[float, float], + direction: ArrayLike, + sign: float, + bend_radius: float | None, + ) -> Callable[[Event], None]: + def _cb(event: Event) -> None: + # We only want to set the shape once, so we disconnect ourselves + event.canvas.mpl_disconnect(arrow.set_shape_cb[0]) + + transform = arrow.axes.transData.transform + scale_x = transform((1, 0))[0] - transform((0, 0))[0] + scale_y = transform((0, 1))[1] - transform((0, 0))[1] + scale = max(scale_x, scale_y) # <-- Hack: This is a somewhat arbitrary choice. 
+ arrow_length = ARROW_LENGTH * event.canvas.figure.get_dpi() / scale + + if bend_radius: + v_norm = (direction[0] ** 2 + direction[1] ** 2) ** 0.5 + vx_norm = direction[0] / v_norm + vy_norm = direction[1] / v_norm + bend_angle = -sign * arrow_length / bend_radius + t_x = 1 - np.cos(bend_angle) + t_y = np.sin(bend_angle) + v_x = -bend_radius * (vx_norm * t_y - vy_norm * t_x) + v_y = -bend_radius * (vx_norm * t_x + vy_norm * t_y) + tangent_angle = np.arctan2(direction[1], direction[0]) + arrow.set_connectionstyle( + patches.ConnectionStyle.Angle3( + angleA=180 / np.pi * tangent_angle, + angleB=180 / np.pi * (tangent_angle + bend_angle), + ) + ) + + else: + v_x = sign * arrow_length * direction[0] + v_y = sign * arrow_length * direction[1] + + arrow.set_positions(pos, (pos[0] + v_x, pos[1] + v_y)) + arrow.set_visible(True) + arrow.draw(event.renderer) + + return _cb + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + volume = 1 + + for axis in range(3): + min_bound = max(self.bounds[0][axis], bounds[0][axis]) + max_bound = min(self.bounds[1][axis], bounds[1][axis]) + + volume *= max_bound - min_bound + + return volume + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + min_bounds = list(self.bounds[0]) + max_bounds = list(self.bounds[1]) + + in_bounds_factor = [2, 2, 2] + length = [0, 0, 0] + + for axis in (0, 1, 2): + if min_bounds[axis] < bounds[0][axis]: + min_bounds[axis] = bounds[0][axis] + in_bounds_factor[axis] -= 1 + + if max_bounds[axis] > bounds[1][axis]: + max_bounds[axis] = bounds[1][axis] + in_bounds_factor[axis] -= 1 + + length[axis] = max_bounds[axis] - min_bounds[axis] + + return ( + length[0] * length[1] * in_bounds_factor[2] + + length[1] * length[2] * in_bounds_factor[0] + + length[2] * length[0] * in_bounds_factor[1] + ) + + """ Autograd code """ + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> 
AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + # get gradients w.r.t. each of the 6 faces (in normal direction) + vjps_faces = self._derivative_faces(derivative_info=derivative_info) + + # post-process these values to give the gradients w.r.t. center and size + vjps_center_size = self._derivatives_center_size(vjps_faces=vjps_faces) + + # store only the gradients asked for in 'field_paths' + derivative_map = {} + for field_path in derivative_info.paths: + field_name, *index = field_path + + if field_name in vjps_center_size: + # if the vjp calls for a specific index into the tuple + if index and len(index) == 1: + index = int(index[0]) + if field_path not in derivative_map: + derivative_map[field_path] = vjps_center_size[field_name][index] + + # otherwise, just grab the whole array + else: + derivative_map[field_path] = vjps_center_size[field_name] + + return derivative_map + + @staticmethod + def _derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: + """Derivatives with respect to the ``center`` and ``size`` fields in the ``Box``.""" + + vjps_faces_min, vjps_faces_max = np.array(vjps_faces) + + # post-process min and max face gradients into center and size + vjp_center = vjps_faces_max - vjps_faces_min + vjp_size = (vjps_faces_min + vjps_faces_max) / 2.0 + + return { + "center": tuple(vjp_center.tolist()), + "size": tuple(vjp_size.tolist()), + } + + def _derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: + """Derivative with respect to normal position of 6 faces of ``Box``.""" + + axes_to_compute = (0, 1, 2) + if len(derivative_info.paths[0]) > 1: + axes_to_compute = tuple(info[1] for info in derivative_info.paths) + + # change in permittivity between inside and outside + vjp_faces = np.zeros((2, 3)) + + for min_max_index, _ in enumerate((0, -1)): + for axis in axes_to_compute: + vjp_face = self._derivative_face( + min_max_index=min_max_index, + axis_normal=axis, + derivative_info=derivative_info, + ) 
+ + # record vjp for this face + vjp_faces[min_max_index, axis] = vjp_face + + return vjp_faces + + def _derivative_face( + self, + min_max_index: int, + axis_normal: Axis, + derivative_info: DerivativeInfo, + ) -> float: + """Compute the derivative w.r.t. shifting a face in the normal direction.""" + + interpolators = derivative_info.interpolators or derivative_info.create_interpolators() + _, axis_perp = self.pop_axis((0, 1, 2), axis=axis_normal) + + # First, check if the face is outside the simulation domain in which case set the + # face gradient to 0. + bounds_normal, bounds_perp = self.pop_axis( + np.array(derivative_info.bounds).T, axis=axis_normal + ) + coord_normal_face = bounds_normal[min_max_index] + + if min_max_index == 0: + if coord_normal_face < derivative_info.simulation_bounds[0][axis_normal]: + return 0.0 + else: + if coord_normal_face > derivative_info.simulation_bounds[1][axis_normal]: + return 0.0 + + intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) + extents = intersect_max - intersect_min + _, intersect_min_perp = self.pop_axis(np.array(intersect_min), axis=axis_normal) + _, intersect_max_perp = self.pop_axis(np.array(intersect_max), axis=axis_normal) + + is_2d_map = [] + for axis_idx in range(3): + if axis_idx == axis_normal: + continue + is_2d_map.append(np.isclose(extents[axis_idx], 0.0)) + + if np.all(is_2d_map): + return 0.0 + + is_2d = np.any(is_2d_map) + + sim_bounds_normal, sim_bounds_perp = self.pop_axis( + np.array(derivative_info.simulation_bounds).T, axis=axis_normal + ) + + # Build point grid + adaptive_spacing = derivative_info.adaptive_vjp_spacing() + + def spacing_to_grid_points( + spacing: float, min_coord: float, max_coord: float + ) -> NDArray[float]: + N = np.maximum(3, 1 + int((max_coord - min_coord) / spacing)) + + points = np.linspace(min_coord, max_coord, N) + centers = 0.5 * (points[0:-1] + points[1:]) + + return centers + + def verify_integration_interval(bound: tuple[float, float]) -> 
bool: + # assume the bounds should not be equal or else this integration interval + # would be the flat dimension of a 2D geometry. + return bound[1] > bound[0] + + def compute_integration_weight(grid_points: NDArray[float]) -> float: + grid_spacing = grid_points[1] - grid_points[0] + if grid_spacing == 0.0: + integration_weight = 1.0 / len(grid_points) + else: + integration_weight = grid_points[1] - grid_points[0] + + return integration_weight + + if is_2d: + # build 1D grid for sampling points along the face, which is an edge in the 2D case + zero_dim = np.where(is_2d_map)[0][0] + # zero dim is one of the perpendicular directions, so the other perpendicular direction + # is the nonzero dimension + nonzero_dim = 1 - zero_dim + + # clip at simulation bounds for integration dimension + integration_bounds_perp = ( + intersect_min_perp[nonzero_dim], + intersect_max_perp[nonzero_dim], + ) + + if not verify_integration_interval(integration_bounds_perp): + return 0.0 + + grid_points_linear = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[0], integration_bounds_perp[1] + ) + integration_weight = compute_integration_weight(grid_points_linear) + + grid_points = np.repeat(np.expand_dims(grid_points_linear.copy(), 1), 3, axis=1) + + # set up grid points to pass into evaluate_gradient_at_points + grid_points[:, axis_perp[nonzero_dim]] = grid_points_linear + grid_points[:, axis_perp[zero_dim]] = intersect_min_perp[zero_dim] + grid_points[:, axis_normal] = coord_normal_face + else: + # build 3D grid for sampling points along the face + + # clip at simulation bounds for each integration dimension + integration_bounds_perp = ( + (intersect_min_perp[0], intersect_max_perp[0]), + (intersect_min_perp[1], intersect_max_perp[1]), + ) + + if not np.all([verify_integration_interval(b) for b in integration_bounds_perp]): + return 0.0 + + grid_points_perp_1 = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[0][0], integration_bounds_perp[0][1] + ) + 
grid_points_perp_2 = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[1][0], integration_bounds_perp[1][1] + ) + integration_weight = compute_integration_weight( + grid_points_perp_1 + ) * compute_integration_weight(grid_points_perp_2) + + mesh_perp1, mesh_perp2 = np.meshgrid(grid_points_perp_1, grid_points_perp_2) + + zip_perp_coords = np.array(list(zip(mesh_perp1.flatten(), mesh_perp2.flatten()))) + + grid_points = np.pad(zip_perp_coords.copy(), ((0, 0), (1, 0)), mode="constant") + + # set up grid points to pass into evaluate_gradient_at_points + grid_points[:, axis_perp[0]] = zip_perp_coords[:, 0] + grid_points[:, axis_perp[1]] = zip_perp_coords[:, 1] + grid_points[:, axis_normal] = coord_normal_face + + normals = np.zeros_like(grid_points) + perps1 = np.zeros_like(grid_points) + perps2 = np.zeros_like(grid_points) + + normals[:, axis_normal] = -1 if (min_max_index == 0) else 1 + perps1[:, axis_perp[0]] = 1 + perps2[:, axis_perp[1]] = 1 + + gradient_at_points = derivative_info.evaluate_gradient_at_points( + spatial_coords=grid_points, + normals=normals, + perps1=perps1, + perps2=perps2, + interpolators=interpolators, + ) + + vjp_value = np.sum(integration_weight * np.real(gradient_at_points)) + return vjp_value + + +"""Compound subclasses""" + + +class Transformed(Geometry): + """Class representing a transformed geometry.""" + + geometry: discriminated_union(GeometryType) = Field( + title="Geometry", + description="Base geometry to be transformed.", + ) + + transform: MatrixReal4x4 = Field( + default_factory=lambda: np.eye(4).tolist(), + title="Transform", + description="Transform matrix applied to the base geometry.", + ) + + @field_validator("transform") + @classmethod + def _transform_is_invertible(cls, val: MatrixReal4x4) -> MatrixReal4x4: + # If the transform is not invertible, this will raise an error + _ = np.linalg.inv(val) + return val + + @field_validator("geometry") + @classmethod + def _geometry_is_finite(cls, val: GeometryType) -> 
GeometryType: + if not np.isfinite(val.bounds).all(): + raise ValidationError( + "Transformations are only supported on geometries with finite dimensions. " + "Try using a large value instead of 'inf' when creating geometries that undergo " + "transformations." + ) + return val + + @model_validator(mode="after") + def _apply_transforms(self: dict[str, Any]) -> dict[str, Any]: + while isinstance(self.geometry, Transformed): + inner = self.geometry + object.__setattr__(self, "geometry", inner.geometry) + object.__setattr__(self, "transform", np.dot(self.transform, inner.transform)) + return self + + @cached_property + def inverse(self) -> MatrixReal4x4: + """Inverse of this transform.""" + return np.linalg.inv(self.transform) + + @staticmethod + def _vertices_from_bounds(bounds: Bound) -> ArrayFloat2D: + """Return the 8 vertices derived from bounds. + + The vertices are returned as homogeneous coordinates (with 4 components). + + Parameters + ---------- + bounds : Bound + Bounds from which to derive the vertices. + + Returns + ------- + ArrayFloat2D + Array with shape (4, 8) with all vertices from ``bounds``. + """ + (x0, y0, z0), (x1, y1, z1) = bounds + return np.array( + ( + (x0, x0, x0, x0, x1, x1, x1, x1), + (y0, y0, y1, y1, y0, y0, y1, y1), + (z0, z1, z0, z1, z0, z1, z0, z1), + (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0), + ) + ) + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float, float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + # NOTE (Lucas): The bounds are overestimated because we don't want to calculate + # precise TriangleMesh representations for GeometryGroup or ClipOperation. 
+ vertices = np.dot(self.transform, self._vertices_from_bounds(self.geometry.bounds))[:3] + return (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + return self.geometry.intersections_tilted_plane( + tuple(np.dot((normal[0], normal[1], normal[2], 0.0), self.transform)[:3]), + tuple(np.dot(self.inverse, (origin[0], origin[1], origin[2], 1.0))[:3]), + np.dot(to_2D, self.transform), + cleanup=cleanup, + quad_segs=quad_segs, + ) + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. 
+ + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + x = np.array(x) + y = np.array(y) + z = np.array(z) + xyz = np.dot(self.inverse, np.vstack((x.flat, y.flat, z.flat, np.ones(x.size)))) + if xyz.shape[1] == 1: + # TODO: This "fix" is required because of a bug in PolySlab.inside (with non-zero sidewall angle) + return self.geometry.inside(xyz[0][0], xyz[1][0], xyz[2][0]).reshape(x.shape) + return self.geometry.inside(xyz[0], xyz[1], xyz[2]).reshape(x.shape) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # NOTE (Lucas): Bounds are overestimated. + vertices = np.dot(self.inverse, self._vertices_from_bounds(bounds))[:3] + inverse_bounds = (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) + return abs(np.linalg.det(self.transform)) * self.geometry.volume(inverse_bounds) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + log.warning("Surface area of transformed elements cannot be calculated.") + return None + + @staticmethod + def translation(x: float, y: float, z: float) -> MatrixReal4x4: + """Return a translation matrix. + + Parameters + ---------- + x : float + Translation along x. + y : float + Translation along y. + z : float + Translation along z. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + return np.array( + [ + (1.0, 0.0, 0.0, x), + (0.0, 1.0, 0.0, y), + (0.0, 0.0, 1.0, z), + (0.0, 0.0, 0.0, 1.0), + ], + dtype=float, + ) + + @staticmethod + def scaling(x: float = 1.0, y: float = 1.0, z: float = 1.0) -> MatrixReal4x4: + """Return a scaling matrix. + + Parameters + ---------- + x : float = 1.0 + Scaling factor along x. + y : float = 1.0 + Scaling factor along y. + z : float = 1.0 + Scaling factor along z. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). 
+ """ + if np.isclose((x, y, z), 0.0).any(): + raise Tidy3dError("Scaling factors cannot be zero in any dimensions.") + return np.array( + [ + (x, 0.0, 0.0, 0.0), + (0.0, y, 0.0, 0.0), + (0.0, 0.0, z, 0.0), + (0.0, 0.0, 0.0, 1.0), + ], + dtype=float, + ) + + @staticmethod + def rotation(angle: float, axis: Union[Axis, Coordinate]) -> MatrixReal4x4: + """Return a rotation matrix. + + Parameters + ---------- + angle : float + Rotation angle (in radians). + axis : Union[int, tuple[float, float, float]] + Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + transform = np.eye(4) + transform[:3, :3] = RotationAroundAxis(angle=angle, axis=axis).matrix + return transform + + @staticmethod + def reflection(normal: Coordinate) -> MatrixReal4x4: + """Return a reflection matrix. + + Parameters + ---------- + normal : tuple[float, float, float] + Normal of the plane of reflection. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + + transform = np.eye(4) + transform[:3, :3] = ReflectionFromPlane(normal=normal).matrix + return transform + + @staticmethod + def preserves_axis(transform: MatrixReal4x4, axis: Axis) -> bool: + """Indicate if the transform preserves the orientation of a given axis. + + Parameters: + transform: MatrixReal4x4 + Transform matrix to check. + axis : int + Axis to check. Values 0, 1, or 2, to check x, y, or z, respectively. + + Returns + ------- + bool + ``True`` if the transformation preserves the axis orientation, ``False`` otherwise. 
+ """ + i = (axis + 1) % 3 + j = (axis + 2) % 3 + return np.isclose(transform[i, axis], 0) and np.isclose(transform[j, axis], 0) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + normal = self.geometry._normal_2dmaterial + preserves_axis = Transformed.preserves_axis(self.transform, normal) + + if not preserves_axis: + raise ValidationError( + "'Medium2D' requires geometries of type 'Transformed' to " + "perserve the axis normal to the 'Medium2D'." + ) + + return normal + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Transformed: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + min_bound = np.array([0, 0, 0, 1.0]) + min_bound[axis] = bounds[0] + max_bound = np.array([0, 0, 0, 1.0]) + max_bound[axis] = bounds[1] + new_bounds = [] + new_bounds.append(np.dot(self.inverse, min_bound)[axis]) + new_bounds.append(np.dot(self.inverse, max_bound)[axis]) + new_geometry = self.geometry._update_from_bounds(bounds=new_bounds, axis=axis) + return self.updated_copy(geometry=new_geometry) + + +class ClipOperation(Geometry): + """Class representing the result of a set operation between geometries.""" + + operation: ClipOperationType = Field( + title="Operation Type", + description="Operation to be performed between geometries.", + ) + + geometry_a: discriminated_union(GeometryType) = Field( + title="Geometry A", + description="First operand for the set operation. It can be any geometry type, including " + ":class:`GeometryGroup`.", + ) + + geometry_b: discriminated_union(GeometryType) = Field( + title="Geometry B", + description="Second operand for the set operation. 
It can also be any geometry type.", + ) + + @field_validator("geometry_a", "geometry_b") + @classmethod + def _geometries_untraced(cls, val: GeometryType) -> GeometryType: + """Make sure that ``ClipOperation`` geometries do not contain tracers.""" + traced = val._strip_traced_fields() + if traced: + raise ValidationError( + f"{val.type} contains traced fields {list(traced.keys())}. Note that " + "'ClipOperation' does not currently support automatic differentiation." + ) + return val + + @staticmethod + def to_polygon_list(base_geometry: Shapely, cleanup: bool = False) -> list[Shapely]: + """Return a list of valid polygons from a shapely geometry, discarding points, lines, and + empty polygons, and empty triangles within polygons. + + Parameters + ---------- + base_geometry : shapely.geometry.base.BaseGeometry + Base geometry for inspection. + cleanup: bool = False + If True, removes extremely small features from each polygon's boundary. + This is useful for removing artifacts from 2D plots displayed to the user. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + Valid polygons retrieved from ``base geometry``. + """ + unfiltered_geoms = [] + if base_geometry.geom_type == "GeometryCollection": + unfiltered_geoms = [ + p + for geom in base_geometry.geoms + for p in ClipOperation.to_polygon_list(geom, cleanup) + ] + if base_geometry.geom_type == "MultiPolygon": + unfiltered_geoms = [p for p in base_geometry.geoms if not p.is_empty] + if base_geometry.geom_type == "Polygon" and not base_geometry.is_empty: + unfiltered_geoms = [base_geometry] + geoms = [] + if cleanup: + # Optional: "clean" each of the polygons (by removing extremely small or thin features). 
+ for geom in unfiltered_geoms: + geom_clean = cleanup_shapely_object(geom) + if geom_clean.geom_type == "Polygon": + geoms.append(geom_clean) + if geom_clean.geom_type == "MultiPolygon": + geoms += [p for p in geom_clean.geoms if not p.is_empty] + # Ignore other types of shapely objects (points and lines) + else: + geoms = unfiltered_geoms + return geoms + + @property + def _shapely_operation(self) -> Callable[[Shapely, Shapely], Shapely]: + """Return a Shapely function equivalent to this operation.""" + result = _shapely_operations.get(self.operation, None) + if not result: + raise ValueError( + "'operation' must be one of 'union', 'intersection', 'difference', or " + "'symmetric_difference'." + ) + return result + + @property + def _bit_operation(self) -> Callable[[Any, Any], Any]: + """Return a function equivalent to this operation using bit operators.""" + result = _bit_operations.get(self.operation, None) + if not result: + raise ValueError( + "'operation' must be one of 'union', 'intersection', 'difference', or " + "'symmetric_difference'." + ) + return result + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. 
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+        a = self.geometry_a.intersections_tilted_plane(
+            normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs
+        )
+        b = self.geometry_b.intersections_tilted_plane(
+            normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs
+        )
+        geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a])
+        geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b])
+        return ClipOperation.to_polygon_list(
+            self._shapely_operation(geom_a, geom_b),
+            cleanup=cleanup,
+        )
+
+    def intersections_plane(
+        self,
+        x: Optional[float] = None,
+        y: Optional[float] = None,
+        z: Optional[float] = None,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Returns list of shapely geometries at plane specified by one non-None value of x,y,z.
+
+        Parameters
+        ----------
+        x : float = None
+            Position of plane in x direction, only one of x,y,z can be specified to define plane.
+        y : float = None
+            Position of plane in y direction, only one of x,y,z can be specified to define plane.
+        z : float = None
+            Position of plane in z direction, only one of x,y,z can be specified to define plane.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            high-quality visualization settings.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+ """ + a = self.geometry_a.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs) + b = self.geometry_b.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs) + geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a]) + geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b]) + return ClipOperation.to_polygon_list( + self._shapely_operation(geom_a, geom_b), + cleanup=cleanup, + ) + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + # Overestimates + if self.operation == "difference": + result = self.geometry_a.bounds + elif self.operation == "intersection": + bounds = (self.geometry_a.bounds, self.geometry_b.bounds) + result = ( + tuple(max(b[i] for b, _ in bounds) for i in range(3)), + tuple(min(b[i] for _, b in bounds) for i in range(3)), + ) + if any(result[0][i] > result[1][i] for i in range(3)): + result = ((0, 0, 0), (0, 0, 0)) + else: + bounds = (self.geometry_a.bounds, self.geometry_b.bounds) + result = ( + tuple(min(b[i] for b, _ in bounds) for i in range(3)), + tuple(max(b[i] for _, b in bounds) for i in range(3)), + ) + return result + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
+ """ + inside_a = self.geometry_a.inside(x, y, z) + inside_b = self.geometry_b.inside(x, y, z) + return self._bit_operation(inside_a, inside_b) + + def inside_meshgrid( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> NDArray[bool]: + """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted. + + Parameters + ---------- + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every + point that is inside the geometry. + """ + inside_a = self.geometry_a.inside_meshgrid(x, y, z) + inside_b = self.geometry_b.inside_meshgrid(x, y, z) + return self._bit_operation(inside_a, inside_b) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # Overestimates + if self.operation == "intersection": + return min(self.geometry_a.volume(bounds), self.geometry_b.volume(bounds)) + if self.operation == "difference": + return self.geometry_a.volume(bounds) + return self.geometry_a.volume(bounds) + self.geometry_b.volume(bounds) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + # Overestimates + return self.geometry_a.surface_area(bounds) + self.geometry_b.surface_area(bounds) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + normal_a = self.geometry_a._normal_2dmaterial + normal_b = self.geometry_b._normal_2dmaterial + + if normal_a != normal_b: + raise ValidationError( + "'Medium2D' requires both geometries in the 'ClipOperation' to " + "have exactly one dimension with zero size in common." 
+ ) + + plane_position_a = self.geometry_a.bounds[0][normal_a] + plane_position_b = self.geometry_b.bounds[0][normal_b] + + if plane_position_a != plane_position_b: + raise ValidationError( + "'Medium2D' requires both geometries in the 'ClipOperation' to be co-planar." + ) + return normal_a + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOperation: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_geom_a = self.geometry_a._update_from_bounds(bounds=bounds, axis=axis) + new_geom_b = self.geometry_b._update_from_bounds(bounds=bounds, axis=axis) + return self.updated_copy(geometry_a=new_geom_a, geometry_b=new_geom_b) + + +class GeometryGroup(Geometry): + """A collection of Geometry objects that can be called as a single geometry object.""" + + geometries: tuple[discriminated_union(GeometryType), ...] = Field( + title="Geometries", + description="Tuple of geometries in a single grouping. " + "Can provide significant performance enhancement in ``Structure`` when all geometries are " + "assigned the same medium.", + ) + + @field_validator("geometries") + @classmethod + def _geometries_not_empty(cls, val: tuple[GeometryType, ...]) -> tuple[GeometryType, ...]: + """make sure geometries are not empty.""" + if not len(val) > 0: + raise ValidationError("GeometryGroup.geometries must not be empty.") + return val + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float, float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
+ """ + + bounds = tuple(geometry.bounds for geometry in self.geometries) + return ( + tuple(min(b[i] for b, _ in bounds) for i in range(3)), + tuple(max(b[i] for _, b in bounds) for i in range(3)), + ) + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + return [ + intersection + for geometry in self.geometries + for intersection in geometry.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + ] + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. 
+ cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + if not self.intersects_plane(x, y, z): + return [] + return [ + intersection + for geometry in self.geometries + for intersection in geometry.intersections_plane( + x=x, y=y, z=z, cleanup=cleanup, quad_segs=quad_segs + ) + ] + + def intersects_axis_position(self, axis: float, position: float) -> bool: + """Whether self intersects plane specified by a given position along a normal axis. + + Parameters + ---------- + axis : int = None + Axis normal to the plane. + position : float = None + Position of plane along the normal axis. + + Returns + ------- + bool + Whether this geometry intersects the plane. + """ + return any(geom.intersects_axis_position(axis, position) for geom in self.geometries) + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
+ """ + individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries) + return functools.reduce(lambda a, b: a | b, individual_insides) + + def inside_meshgrid( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> NDArray[bool]: + """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted. + + Parameters + ---------- + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every + point that is inside the geometry. + """ + individual_insides = (geom.inside_meshgrid(x, y, z) for geom in self.geometries) + return functools.reduce(lambda a, b: a | b, individual_insides) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + return sum(geometry.volume(bounds) for geometry in self.geometries) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + return sum(geometry.surface_area(bounds) for geometry in self.geometries) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + + normals = {geom._normal_2dmaterial for geom in self.geometries} + + if len(normals) != 1: + raise ValidationError( + "'Medium2D' requires all geometries in the 'GeometryGroup' to " + "share exactly one dimension with zero size." + ) + normal = list(normals)[0] + positions = {geom.bounds[0][normal] for geom in self.geometries} + if len(positions) != 1: + raise ValidationError( + "'Medium2D' requires all geometries in the 'GeometryGroup' to be co-planar." 
+ ) + return normal + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> GeometryGroup: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_geometries = tuple( + geometry._update_from_bounds(bounds=bounds, axis=axis) for geometry in self.geometries + ) + return self.updated_copy(geometries=new_geometries) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + grad_vjps = {} + + # create interpolators once for all geometries to avoid redundant field data conversions + interpolators = derivative_info.interpolators or derivative_info.create_interpolators() + + with derivative_info.cache_min_spacing_from_permittivity(): + for field_path in derivative_info.paths: + _, index, *geo_path = field_path + + geo = self.geometries[index] + # pass pre-computed interpolators if available + geo_info = derivative_info.updated_copy( + paths=[tuple(geo_path)], + bounds=geo.bounds, + bounds_intersect=self.bounds_intersection( + geo.bounds, derivative_info.simulation_bounds + ), + deep=False, + interpolators=interpolators, + ) + + vjp_dict_geo = geo._compute_derivatives(geo_info) + + if len(vjp_dict_geo) != 1: + raise AssertionError("Got multiple gradients for single geometry field.") + + grad_vjps[field_path] = vjp_dict_geo.popitem()[1] + + return grad_vjps + + +def cleanup_shapely_object(obj: Shapely, tolerance_ratio: float = POLY_TOLERANCE_RATIO) -> Shapely: + """Remove small geometric features from the boundaries of a shapely object including + inward and outward spikes, thin holes, and thin connections between larger regions. 
+ + Parameters + ---------- + obj : shapely + a shapely object (typically a ``Polygon`` or a ``MultiPolygon``) + tolerance_ratio : float = ``POLY_TOLERANCE_RATIO`` + Features on the boundaries of polygons will be discarded if they are smaller + or narrower than ``tolerance_ratio`` multiplied by the size of the object. + + Returns + ------- + Shapely + A new shapely object whose small features (eg. thin spikes or holes) are removed. + + Notes + ----- + This function does not attempt to delete overlapping, nearby, or collinear vertices. + To solve that problem, use ``shapely.simplify()`` afterwards. + """ + if _package_is_older_than("shapely", "2.1"): + log.warning("Versions of shapely prior to v2.1 may cause plot errors.", log_once=True) + return obj + if obj.is_empty: + return obj + centroid = obj.centroid + object_size = min(obj.bounds[2] - obj.bounds[0], obj.bounds[3] - obj.bounds[1]) + if object_size == 0.0: + return shapely.Polygon([]) + + # To prevent numerical over- or underflow errors, subtract the centroid and rescale + normalized_obj = shapely.affinity.affine_transform( + obj, + matrix=[ + 1 / object_size, + 0.0, + 0.0, + 1 / object_size, + -centroid.x / object_size, + -centroid.y / object_size, + ], + ) + # Important: Remove any self intersections beforehand using `shapely.make_valid()`. + valid_obj = shapely.make_valid(normalized_obj, method="structure", keep_collapsed=False) + + # To get rid of small thin features, erode(shrink), dilate(expand), and erode again. + eroded_obj = shapely.buffer( + valid_obj, + distance=-tolerance_ratio, + cap_style="square", + quad_segs=3, + ) + dilated_obj = shapely.buffer( + eroded_obj, + distance=2 * tolerance_ratio, + cap_style="square", + quad_segs=3, + ) + cleaned_obj = dilated_obj + + # Optional: Now shrink the polygon back to the original size. 
+ cleaned_obj = shapely.buffer( + cleaned_obj, + distance=-tolerance_ratio, + cap_style="square", + quad_segs=3, + ) + # Clean vertices of very close distances created during the erosion/dilation process. + # The distance value is heuristic. + cleaned_obj = cleaned_obj.simplify(POLY_DISTANCE_TOLERANCE, preserve_topology=True) + # Revert to the original scale and position. + rescaled_clean_obj = shapely.affinity.affine_transform( + cleaned_obj, + matrix=[ + object_size, + 0.0, + 0.0, + object_size, + centroid.x, + centroid.y, + ], + ) + return rescaled_clean_obj + + +from tidy3d._common.components.geometry.utils import ( # noqa: E402 + GeometryType, + from_shapely, + vertices_from_shapely, +) diff --git a/tidy3d/_common/components/geometry/bound_ops.py b/tidy3d/_common/components/geometry/bound_ops.py new file mode 100644 index 0000000000..a34e4fa0cf --- /dev/null +++ b/tidy3d/_common/components/geometry/bound_ops.py @@ -0,0 +1,74 @@ +"""Geometry operations for bounding box type with minimal imports.""" + +from __future__ import annotations + +from math import isclose +from typing import TYPE_CHECKING + +from tidy3d._common.constants import fp_eps + +if TYPE_CHECKING: + from tidy3d._common.components.types.base import Bound + +if TYPE_CHECKING: + from tidy3d._common.components.types import Bound + + +def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the intersection of two bounds.""" + rmin1, rmax1 = bounds1 + rmin2, rmax2 = bounds2 + rmin = tuple(max(v1, v2) for v1, v2 in zip(rmin1, rmin2)) + rmax = tuple(min(v1, v2) for v1, v2 in zip(rmax1, rmax2)) + return (rmin, rmax) + + +def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the union of two bounds.""" + rmin1, rmax1 = bounds1 + rmin2, rmax2 = bounds2 + rmin = tuple(min(v1, v2) for v1, v2 in zip(rmin1, rmin2)) + rmax = tuple(max(v1, v2) for v1, v2 in zip(rmax1, rmax2)) + return (rmin, rmax) + + +def bounds_contains( + 
outer_bounds: Bound, inner_bounds: Bound, rtol: float = fp_eps, atol: float = 0.0 +) -> bool: + """Checks whether ``inner_bounds`` is contained within ``outer_bounds`` within specified tolerances. + + Parameters + ---------- + outer_bounds : Bound + The outer bounds to check containment against + inner_bounds : Bound + The inner bounds to check if contained + rtol : float = fp_eps + Relative tolerance for comparing bounds + atol : float = 0.0 + Absolute tolerance for comparing bounds + + Returns + ------- + bool + True if ``inner_bounds`` is contained within ``outer_bounds`` within tolerances + """ + outer_min, outer_max = outer_bounds + inner_min, inner_max = inner_bounds + for dim in range(3): + outer_min_dim = outer_min[dim] + outer_max_dim = outer_max[dim] + inner_min_dim = inner_min[dim] + inner_max_dim = inner_max[dim] + within_min = ( + isclose(outer_min_dim, inner_min_dim, rel_tol=rtol, abs_tol=atol) + or outer_min_dim <= inner_min_dim + ) + within_max = ( + isclose(outer_max_dim, inner_max_dim, rel_tol=rtol, abs_tol=atol) + or outer_max_dim >= inner_max_dim + ) + + if not within_min or not within_max: + return False + return True diff --git a/tidy3d/_common/components/geometry/float_utils.py b/tidy3d/_common/components/geometry/float_utils.py new file mode 100644 index 0000000000..2b5848666d --- /dev/null +++ b/tidy3d/_common/components/geometry/float_utils.py @@ -0,0 +1,31 @@ +"""Utilities for float manipulation.""" + +from __future__ import annotations + +import numpy as np + +from tidy3d._common.constants import inf + + +def increment_float(val: float, sign: int) -> float: + """Applies a small positive or negative shift as though `val` is a 32bit float + using numpy.nextafter, but additionally handles some corner cases. 
+ """ + # Infinity is left unchanged + if val == inf or val == -inf: + return val + + if sign >= 0: + sign = 1 + else: + sign = -1 + + # Avoid small increments within subnormal values + if np.abs(val) <= np.finfo(np.float32).tiny: + return val + sign * np.finfo(np.float32).tiny + + # Numpy seems to skip over the increment from -0.0 and +0.0 + # which is different from c++ + val_inc = np.nextafter(val, sign * inf, dtype=np.float32) + + return np.float32(val_inc) diff --git a/tidy3d/_common/components/geometry/mesh.py b/tidy3d/_common/components/geometry/mesh.py new file mode 100644 index 0000000000..416b9eaaf1 --- /dev/null +++ b/tidy3d/_common/components/geometry/mesh.py @@ -0,0 +1,1285 @@ +"""Mesh-defined geometry.""" + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +from autograd import numpy as anp +from numpy.typing import NDArray +from pydantic import Field, PrivateAttr, field_validator, model_validator + +from tidy3d._common.components.autograd import get_static +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP, TriangleMeshDataArray +from tidy3d._common.components.data.dataset import TriangleMeshDataset +from tidy3d._common.components.data.validators import validate_no_nans +from tidy3d._common.components.geometry import base +from tidy3d._common.components.viz import add_ax_if_none, equal_aspect +from tidy3d._common.config import config +from tidy3d._common.constants import fp_eps, inf +from tidy3d._common.exceptions import DataError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Literal, Union + + from trimesh import Trimesh + + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils 
import DerivativeInfo + from tidy3d._common.components.types.base import Ax, Bound, Coordinate, MatrixReal4x4, Shapely + +AREA_SIZE_THRESHOLD = 1e-36 + + +class TriangleMesh(base.Geometry, ABC): + """Custom surface geometry given by a triangle mesh, as in the STL file format. + + Example + ------- + >>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) + >>> faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]]) + >>> stl_geom = TriangleMesh.from_vertices_faces(vertices, faces) + """ + + mesh_dataset: Optional[TriangleMeshDataset] = Field( + None, + title="Surface mesh data", + description="Surface mesh data.", + ) + + _no_nans_mesh = validate_no_nans("mesh_dataset") + _barycentric_samples: dict[int, NDArray] = PrivateAttr(default_factory=dict) + + @verify_packages_import(["trimesh"]) + @model_validator(mode="before") + @classmethod + def _validate_trimesh_library(cls, data: dict[str, Any]) -> dict[str, Any]: + """Check if the trimesh package is imported as a validator.""" + return data + + @field_validator("mesh_dataset", mode="before") + @classmethod + def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: + """Warn if the Dataset fails to load.""" + if isinstance(val, dict): + if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): + log.warning("Loading 'mesh_dataset' without data.") + return None + return val + + @field_validator("mesh_dataset") + @classmethod + def _check_mesh(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: + """Check that the mesh is valid.""" + if val is None: + return None + + import trimesh + + surface_mesh = val.surface_mesh + triangles = get_static(surface_mesh.data) + mesh = cls._triangles_to_trimesh(triangles) + if not all(np.array(mesh.area_faces) > AREA_SIZE_THRESHOLD): + old_tol = trimesh.tol.merge + trimesh.tol.merge = np.sqrt(2 * AREA_SIZE_THRESHOLD) + new_mesh = mesh.process(validate=True) + trimesh.tol.merge = old_tol + val = 
TriangleMesh.from_trimesh(new_mesh).mesh_dataset + log.warning( + f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " + "Triangles which have one edge of their 2D oriented bounding box shorter than " + f"'sqrt(2*{AREA_SIZE_THRESHOLD}) are being automatically removed.'" + ) + if not all(np.array(new_mesh.area_faces) > AREA_SIZE_THRESHOLD): + raise ValidationError( + f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " + "The automatic removal of these triangles has failed. You can try " + "using numpy-stl's 'from_file' import with 'remove_empty_areas' set " + "to True and a suitable 'AREA_SIZE_THRESHOLD' to remove them." + ) + if not mesh.is_watertight: + log.warning( + "The provided mesh is not watertight. " + "This can lead to incorrect permittivity distributions, " + "and can also cause problems with plotting and mesh validation. " + "You can try 'TriangleMesh.fill_holes', which attempts to repair the mesh. " + "Otherwise, the mesh may require manual repair. You can use a " + "'PermittivityMonitor' to check if the permittivity distribution is correct. " + "You can see which faces are broken using 'trimesh.repair.broken_faces'." + ) + if not mesh.is_winding_consistent: + log.warning( + "The provided mesh does not have consistent winding (face orientations). " + "This can lead to incorrect permittivity distributions, " + "and can also cause problems with plotting and mesh validation. " + "You can try 'TriangleMesh.fix_winding', which attempts to repair the mesh. " + "Otherwise, the mesh may require manual repair. You can use a " + "'PermittivityMonitor' to check if the permittivity distribution is correct. " + ) + if not mesh.is_volume: + log.warning( + "The provided mesh does not represent a valid volume, possibly due to " + "incorrect normal vector orientation. " + "This can lead to incorrect permittivity distributions, " + "and can also cause problems with plotting and mesh validation. 
" + "You can try 'TriangleMesh.fix_normals', " + "which attempts to fix the normals to be consistent and outward-facing. " + "Otherwise, the mesh may require manual repair. You can use a " + "'PermittivityMonitor' to check if the permittivity distribution is correct." + ) + + return val + + @verify_packages_import(["trimesh"]) + def fix_winding(self) -> TriangleMesh: + """Try to fix winding in the mesh.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fix_winding(mesh) + return TriangleMesh.from_trimesh(mesh) + + @verify_packages_import(["trimesh"]) + def fill_holes(self) -> TriangleMesh: + """Try to fill holes in the mesh. Can be used to repair non-watertight meshes.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fill_holes(mesh) + return TriangleMesh.from_trimesh(mesh) + + @verify_packages_import(["trimesh"]) + def fix_normals(self) -> TriangleMesh: + """Try to fix normals to be consistent and outward-facing.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fix_normals(mesh) + return TriangleMesh.from_trimesh(mesh) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_stl( + cls, + filename: str, + scale: float = 1.0, + origin: tuple[float, float, float] = (0, 0, 0), + solid_index: Optional[int] = None, + **kwargs: Any, + ) -> Union[TriangleMesh, base.GeometryGroup]: + """Load a :class:`.TriangleMesh` directly from an STL file. + The ``solid_index`` parameter can be used to select a single solid from the file. + Otherwise, if the file contains a single solid, it will be loaded as a + :class:`.TriangleMesh`; if the file contains multiple solids, + they will all be loaded as a :class:`.GeometryGroup`. + + Parameters + ---------- + filename : str + The name of the STL file containing the surface geometry mesh data. 
+ scale : float = 1.0 + The length scale for the loaded geometry (um). + For example, a scale of 10.0 means that a vertex (1, 0, 0) will be placed at + x = 10 um. + origin : tuple[float, float, float] = (0, 0, 0) + The origin of the loaded geometry, in units of ``scale``. + Translates from (0, 0, 0) to this point after applying the scaling. + solid_index : int = None + If set, read a single solid with this index from the file. + + Returns + ------- + Union[:class:`.TriangleMesh`, :class:`.GeometryGroup`] + The geometry or geometry group from the file. + """ + import trimesh + + from tidy3d._common.components.types.third_party import TrimeshType + + def process_single(mesh: TrimeshType) -> TriangleMesh: + """Process a single 'trimesh.Trimesh' using scale and origin.""" + mesh.apply_scale(scale) + mesh.apply_translation(origin) + return cls.from_trimesh(mesh) + + scene = trimesh.load(filename, **kwargs) + meshes = [] + if isinstance(scene, trimesh.Trimesh): + meshes = [scene] + elif isinstance(scene, trimesh.Scene): + meshes = scene.dump() + else: + raise ValidationError( + "Invalid trimesh type in file. Supported types are 'trimesh.Trimesh' " + "and 'trimesh.Scene'." + ) + + if solid_index is None: + if isinstance(scene, trimesh.Trimesh): + return process_single(scene) + if isinstance(scene, trimesh.Scene): + geoms = [process_single(mesh) for mesh in meshes] + return base.GeometryGroup(geometries=geoms) + + if solid_index < len(meshes): + return process_single(meshes[solid_index]) + raise ValidationError("No solid found at 'solid_index' in the stl file.") + + @verify_packages_import(["trimesh"]) + def to_stl( + self, + filename: PathLike, + *, + binary: bool = True, + ) -> None: + """Export this TriangleMesh to an STL file. + + Parameters + ---------- + filename : str + Output STL filename. + binary : bool = True + Whether to write binary STL. Set False for ASCII STL. 
+ """ + triangles = get_static(self.mesh_dataset.surface_mesh.data) + mesh = self._triangles_to_trimesh(triangles) + + file_type = "stl" if binary else "stl_ascii" + mesh.export(file_obj=filename, file_type=file_type) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_trimesh(cls, mesh: Trimesh) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from a ``trimesh.Trimesh`` object. + + Parameters + ---------- + trimesh : ``trimesh.Trimesh`` + The Trimesh object containing the surface geometry mesh data. + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the ``trimesh.Trimesh`` provided. + """ + return cls.from_vertices_faces(mesh.vertices, mesh.faces) + + @classmethod + def from_triangles(cls, triangles: NDArray) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from a numpy array + containing the triangles of a surface mesh. + + Parameters + ---------- + triangles : ``np.ndarray`` + A numpy array of shape (N, 3, 3) storing the triangles of the surface mesh. + The first index labels the triangle, the second index labels the vertex + within a given triangle, and the third index is the coordinate (x, y, or z). + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the triangles provided. + + """ + triangles = anp.array(triangles) + if len(triangles.shape) != 3 or triangles.shape[1] != 3 or triangles.shape[2] != 3: + raise ValidationError( + f"Provided 'triangles' must be an N x 3 x 3 array, given {triangles.shape}." 
+ ) + num_faces = len(triangles) + coords = { + "face_index": np.arange(num_faces), + "vertex_index": np.arange(3), + "axis": np.arange(3), + } + vertices = TriangleMeshDataArray(triangles, coords=coords) + mesh_dataset = TriangleMeshDataset(surface_mesh=vertices) + return TriangleMesh(mesh_dataset=mesh_dataset) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_vertices_faces(cls, vertices: NDArray, faces: NDArray) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from numpy arrays containing the data + of a surface mesh. The first array contains the vertices, and the second array contains + faces formed from triples of the vertices. + + Parameters + ---------- + vertices: ``np.ndarray`` + A numpy array of shape (N, 3) storing the vertices of the surface mesh. + The first index labels the vertex, and the second index is the coordinate + (x, y, or z). + faces : ``np.ndarray`` + A numpy array of shape (M, 3) storing the indices of the vertices of each face + in the surface mesh. The first index labels the face, and the second index + labels the vertex index within the ``vertices`` array. + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the vertices and faces provided. + + """ + import trimesh + + vertices = np.array(vertices) + faces = np.array(faces) + if len(vertices.shape) != 2 or vertices.shape[1] != 3: + raise ValidationError( + f"Provided 'vertices' must be an N x 3 array, given {vertices.shape}." 
+ ) + if len(faces.shape) != 2 or faces.shape[1] != 3: + raise ValidationError(f"Provided 'faces' must be an M x 3 array, given {faces.shape}.") + return cls.from_triangles(trimesh.Trimesh(vertices, faces).triangles) + + @classmethod + @verify_packages_import(["trimesh"]) + def _triangles_to_trimesh( + cls, triangles: NDArray + ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) + """Convert an (N, 3, 3) numpy array of triangles to a ``trimesh.Trimesh``.""" + import trimesh + + # ``triangles`` may contain autograd ``ArrayBox`` entries when differentiating + # geometry parameters. ``trimesh`` expects plain ``float`` values, so strip any + # tracing information before constructing the mesh. + triangles = get_static(anp.array(triangles)) + return trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles)) + + @classmethod + def from_height_grid( + cls, + axis: Ax, + direction: Literal["-", "+"], + base: float, + grid: tuple[np.ndarray, np.ndarray], + height: NDArray, + ) -> TriangleMesh: + """Construct a TriangleMesh object from grid based height information. + + Parameters + ---------- + axis : Ax + Axis of extrusion. + direction : Literal["-", "+"] + Direction of extrusion. + base : float + Coordinate of the base surface along the geometry's axis. + grid : Tuple[np.ndarray, np.ndarray] + Tuple of two one-dimensional arrays representing the sampling grid (XY, YZ, or ZX + corresponding to values of axis) + height : np.ndarray + Height values sampled on the given grid. Can be 1D (raveled) or 2D (matching grid mesh). + + Returns + ------- + TriangleMesh + The resulting TriangleMesh geometry object. 
+ """ + + x_coords = grid[0] + y_coords = grid[1] + + nx = len(x_coords) + ny = len(y_coords) + nt = nx * ny + + x_mesh, y_mesh = np.meshgrid(x_coords, y_coords, indexing="ij") + + sign = 1 + if direction == "-": + sign = -1 + + flat_height = np.ravel(height) + if flat_height.shape[0] != nt: + raise ValueError( + f"Shape of flattened height array {flat_height.shape} does not match " + f"the number of grid points {nt}." + ) + + if np.any(flat_height < 0): + raise ValueError("All height values must be non-negative.") + + max_h = np.max(flat_height) + min_h_clip = fp_eps * max_h + flat_height = np.clip(flat_height, min_h_clip, inf) + + vertices_raw_list = [ + [np.ravel(x_mesh), np.ravel(y_mesh), base + sign * flat_height], # Alpha surface + [np.ravel(x_mesh), np.ravel(y_mesh), base * np.ones(nt)], + ] + + if direction == "-": + vertices_raw_list = vertices_raw_list[::-1] + + vertices = np.hstack(vertices_raw_list).T + vertices = np.roll(vertices, shift=axis - 2, axis=1) + + q0 = (np.arange(nx - 1)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() + q1 = (np.arange(1, nx)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() + q2 = (np.arange(1, nx)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() + q3 = (np.arange(nx - 1)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() + + q0_b = nt + q0 + q1_b = nt + q1 + q2_b = nt + q2 + q3_b = nt + q3 + + top_quads = np.stack((q0, q1, q2, q3), axis=-1) + bottom_quads = np.stack((q0_b, q3_b, q2_b, q1_b), axis=-1) + + s1_q0 = (0 * ny + np.arange(ny - 1)).ravel() + s1_q1 = (0 * ny + np.arange(1, ny)).ravel() + s1_q2 = (nt + 0 * ny + np.arange(1, ny)).ravel() + s1_q3 = (nt + 0 * ny + np.arange(ny - 1)).ravel() + side1_quads = np.stack((s1_q0, s1_q1, s1_q2, s1_q3), axis=-1) + + s2_q0 = ((nx - 1) * ny + np.arange(ny - 1)).ravel() + s2_q1 = (nt + (nx - 1) * ny + np.arange(ny - 1)).ravel() + s2_q2 = (nt + (nx - 1) * ny + np.arange(1, ny)).ravel() + s2_q3 = ((nx - 1) * ny + np.arange(1, ny)).ravel() + side2_quads = np.stack((s2_q0, 
s2_q1, s2_q2, s2_q3), axis=-1) + + s3_q0 = (np.arange(nx - 1) * ny + 0).ravel() + s3_q1 = (nt + np.arange(nx - 1) * ny + 0).ravel() + s3_q2 = (nt + np.arange(1, nx) * ny + 0).ravel() + s3_q3 = (np.arange(1, nx) * ny + 0).ravel() + side3_quads = np.stack((s3_q0, s3_q1, s3_q2, s3_q3), axis=-1) + + s4_q0 = (np.arange(nx - 1) * ny + ny - 1).ravel() + s4_q1 = (np.arange(1, nx) * ny + ny - 1).ravel() + s4_q2 = (nt + np.arange(1, nx) * ny + ny - 1).ravel() + s4_q3 = (nt + np.arange(nx - 1) * ny + ny - 1).ravel() + side4_quads = np.stack((s4_q0, s4_q1, s4_q2, s4_q3), axis=-1) + + all_quads = np.vstack( + (top_quads, bottom_quads, side1_quads, side2_quads, side3_quads, side4_quads) + ) + + triangles_list = [ + np.stack((all_quads[:, 0], all_quads[:, 1], all_quads[:, 3]), axis=-1), + np.stack((all_quads[:, 3], all_quads[:, 1], all_quads[:, 2]), axis=-1), + ] + tri_faces = np.vstack(triangles_list) + + return cls.from_vertices_faces(vertices=vertices, faces=tri_faces) + + @classmethod + def from_height_function( + cls, + axis: Ax, + direction: Literal["-", "+"], + base: float, + center: tuple[float, float], + size: tuple[float, float], + grid_size: tuple[int, int], + height_func: Callable[[np.ndarray, np.ndarray], np.ndarray], + ) -> TriangleMesh: + """Construct a TriangleMesh object from analytical expression of height function. + The height function should be vectorized to accept 2D meshgrid arrays. + + Parameters + ---------- + axis : Ax + Axis of extrusion. + direction : Literal["-", "+"] + Direction of extrusion. + base : float + Coordinate of the base rectangle along the geometry's axis. + center : Tuple[float, float] + Center of the base rectangle in the plane perpendicular to the extrusion axis + (XY, YZ, or ZX corresponding to values of axis). + size : Tuple[float, float] + Size of the base rectangle in the plane perpendicular to the extrusion axis + (XY, YZ, or ZX corresponding to values of axis). 
+ grid_size : Tuple[int, int] + Number of grid points for discretization of the base rectangle + (XY, YZ, or ZX corresponding to values of axis). + height_func : Callable[[np.ndarray, np.ndarray], np.ndarray] + Vectorized function to compute height values from 2D meshgrid coordinate arrays. + It should take two ndarrays (x_mesh, y_mesh) and return an ndarray of heights. + + Returns + ------- + TriangleMesh + The resulting TriangleMesh geometry object. + """ + x_lin = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0]) + y_lin = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1]) + + x_mesh, y_mesh = np.meshgrid(x_lin, y_lin, indexing="ij") + + height_values = height_func(x_mesh, y_mesh) + + if not (isinstance(height_values, np.ndarray) and height_values.shape == x_mesh.shape): + raise ValueError( + f"The 'height_func' must return a NumPy array with shape {x_mesh.shape}, " + f"but got shape {getattr(height_values, 'shape', type(height_values))}." 
+ ) + + return cls.from_height_grid( + axis=axis, + direction=direction, + base=base, + grid=(x_lin, y_lin), + height=height_values, + ) + + @cached_property + @verify_packages_import(["trimesh"]) + def trimesh( + self, + ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) + """A ``trimesh.Trimesh`` object representing the custom surface mesh geometry.""" + return self._triangles_to_trimesh(self.triangles) + + @cached_property + def triangles(self) -> np.ndarray: + """The triangles of the surface mesh as an ``np.ndarray``.""" + if self.mesh_dataset is None: + raise DataError("Can't get triangles as 'mesh_dataset' is None.") + return np.asarray(get_static(self.mesh_dataset.surface_mesh.data)) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + # currently ignores bounds + return self.trimesh.area + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # currently ignores bounds + return self.trimesh.volume + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + if self.mesh_dataset is None: + return ((-inf, -inf, -inf), (inf, inf, inf)) + return self.trimesh.bounds + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. 
+ to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for TriangleMesh. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + section = self.trimesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for TriangleMesh. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentaton `_. 
+ """ + + if self.mesh_dataset is None: + return [] + + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + + origin = self.unpop_axis(position, (0, 0), axis=axis) + normal = self.unpop_axis(1, (0, 0), axis=axis) + + mesh = self.trimesh + + try: + section = mesh.section(plane_origin=origin, plane_normal=normal) + + if section is None: + return [] + + # homogeneous transformation matrix to map to xy plane + mapping = np.eye(4) + + # translate to origin + mapping[3, :3] = -np.array(origin) + + # permute so normal is aligned with z axis + # and (y, z), (x, z), resp. (x, y) are aligned with (x, y) + identity = np.eye(3) + permutation = self.unpop_axis(identity[2], identity[0:2], axis=axis) + mapping[:3, :3] = np.array(permutation).T + + section2d, _ = section.to_2D(to_2D=mapping) + return list(section2d.polygons_full) + + except ValueError as e: + if not mesh.is_watertight: + log.warning( + "Unable to compute 'TriangleMesh.intersections_plane' " + "because the mesh was not watertight. Using bounding box instead. " + "This may be overly strict; consider using 'TriangleMesh.fill_holes' " + "to repair the non-watertight mesh." + ) + else: + log.warning( + "Unable to compute 'TriangleMesh.intersections_plane'. " + "Using bounding box instead." + ) + log.warning(f"Error encountered: {e}") + return self.bounding_box.intersections_plane(x=x, y=y, z=z, cleanup=cleanup) + + def inside(self, x: NDArray, y: NDArray, z: NDArray) -> np.ndarray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. 
+ + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + + arrays = tuple(map(np.array, (x, y, z))) + self._ensure_equal_shape(*arrays) + arrays_flat = map(np.ravel, arrays) + arrays_stacked = np.stack(tuple(arrays_flat), axis=-1) + inside = self.trimesh.contains(arrays_stacked) + return inside.reshape(arrays[0].shape) + + @equal_aspect + @add_ax_if_none + def plot( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + ax: Ax = None, + **patch_kwargs: Any, + ) -> Ax: + """Plot geometry cross section at single (x,y,z) coordinate. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + **patch_kwargs + Optional keyword arguments passed to the matplotlib patch plotting of structure. + For details on accepted values, refer to + `Matplotlib's documentation `_. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + log.warning( + "Plotting a 'TriangleMesh' may give inconsistent results " + "if the mesh is not unionized. We recommend unionizing all meshes before import. " + "A 'PermittivityMonitor' can be used to check that the mesh is loaded correctly." 
+ ) + + return base.Geometry.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute adjoint derivatives for a ``TriangleMesh`` geometry.""" + vjps: AutogradFieldMap = {} + + if not self.mesh_dataset: + raise DataError("Can't compute derivatives without mesh data.") + + valid_paths = {("mesh_dataset", "surface_mesh")} + for path in derivative_info.paths: + if path not in valid_paths: + raise ValueError(f"No derivative defined w.r.t. 'TriangleMesh' field '{path}'.") + + if ("mesh_dataset", "surface_mesh") not in derivative_info.paths: + return vjps + + triangles = np.asarray(self.triangles, dtype=config.adjoint.gradient_dtype_float) + + # early exit if geometry is completely outside simulation bounds + sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) + mesh_min, mesh_max = map(np.asarray, self.bounds) + if np.any(mesh_max < sim_min) or np.any(mesh_min > sim_max): + log.warning( + "'TriangleMesh' lies completely outside the simulation domain.", + log_once=True, + ) + zeros = np.zeros_like(triangles) + vjps[("mesh_dataset", "surface_mesh")] = zeros + return vjps + + # gather surface samples within the simulation bounds + dx = derivative_info.adaptive_vjp_spacing() + samples = self._collect_surface_samples( + triangles=triangles, + spacing=dx, + sim_min=sim_min, + sim_max=sim_max, + ) + + if samples["points"].shape[0] == 0: + zeros = np.zeros_like(triangles) + vjps[("mesh_dataset", "surface_mesh")] = zeros + return vjps + + interpolators = derivative_info.interpolators + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + g = derivative_info.evaluate_gradient_at_points( + samples["points"], + samples["normals"], + samples["perps1"], + samples["perps2"], + interpolators, + ) + + # accumulate per-vertex contributions using barycentric weights + weights = (samples["weights"] * 
g).real + normals = samples["normals"] + faces = samples["faces"] + bary = samples["barycentric"] + + contrib_vec = weights[:, None] * normals + + triangle_grads = np.zeros_like(triangles, dtype=config.adjoint.gradient_dtype_float) + for vertex_idx in range(3): + scaled = contrib_vec * bary[:, vertex_idx][:, None] + np.add.at(triangle_grads[:, vertex_idx, :], faces, scaled) + + vjps[("mesh_dataset", "surface_mesh")] = triangle_grads + return vjps + + def _collect_surface_samples( + self, + triangles: NDArray, + spacing: float, + sim_min: NDArray, + sim_max: NDArray, + ) -> dict[str, np.ndarray]: + """Deterministic per-triangle sampling used historically.""" + + dtype = config.adjoint.gradient_dtype_float + tol = config.adjoint.edge_clip_tolerance + + sim_min = np.asarray(sim_min, dtype=dtype) + sim_max = np.asarray(sim_max, dtype=dtype) + + points_list: list[np.ndarray] = [] + normals_list: list[np.ndarray] = [] + perps1_list: list[np.ndarray] = [] + perps2_list: list[np.ndarray] = [] + weights_list: list[np.ndarray] = [] + faces_list: list[np.ndarray] = [] + bary_list: list[np.ndarray] = [] + + spacing = max(float(spacing), np.finfo(float).eps) + triangles_arr = np.asarray(triangles, dtype=dtype) + + sim_extents = sim_max - sim_min + valid_axes = np.abs(sim_extents) > tol + collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol)) + collapsed_axis: Optional[int] = None + plane_value: Optional[float] = None + if collapsed_indices.size == 1: + collapsed_axis = int(collapsed_indices[0]) + plane_value = float(sim_min[collapsed_axis]) + + warned = False + warning_msg = "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
+ for face_index, tri in enumerate(triangles_arr): + area, normal = self._triangle_area_and_normal(tri) + if area <= AREA_SIZE_THRESHOLD: + continue + + perps = self._triangle_tangent_basis(tri, normal) + if perps is None: + continue + perp1, perp2 = perps + + if collapsed_axis is not None and plane_value is not None: + samples, outside_bounds = self._collect_surface_samples_2d( + triangle=tri, + face_index=face_index, + normal=normal, + perp1=perp1, + perp2=perp2, + spacing=spacing, + collapsed_axis=collapsed_axis, + plane_value=plane_value, + sim_min=sim_min, + sim_max=sim_max, + valid_axes=valid_axes, + tol=tol, + dtype=dtype, + ) + else: + samples, outside_bounds = self._collect_surface_samples_3d( + triangle=tri, + face_index=face_index, + normal=normal, + perp1=perp1, + perp2=perp2, + area=area, + spacing=spacing, + sim_min=sim_min, + sim_max=sim_max, + valid_axes=valid_axes, + tol=tol, + dtype=dtype, + ) + + if outside_bounds and not warned: + log.warning(warning_msg) + warned = True + + if samples is None: + continue + + points_list.append(samples["points"]) + normals_list.append(samples["normals"]) + perps1_list.append(samples["perps1"]) + perps2_list.append(samples["perps2"]) + weights_list.append(samples["weights"]) + faces_list.append(samples["faces"]) + bary_list.append(samples["barycentric"]) + + if not points_list: + return { + "points": np.zeros((0, 3), dtype=dtype), + "normals": np.zeros((0, 3), dtype=dtype), + "perps1": np.zeros((0, 3), dtype=dtype), + "perps2": np.zeros((0, 3), dtype=dtype), + "weights": np.zeros((0,), dtype=dtype), + "faces": np.zeros((0,), dtype=int), + "barycentric": np.zeros((0, 3), dtype=dtype), + } + + return { + "points": np.concatenate(points_list, axis=0), + "normals": np.concatenate(normals_list, axis=0), + "perps1": np.concatenate(perps1_list, axis=0), + "perps2": np.concatenate(perps2_list, axis=0), + "weights": np.concatenate(weights_list, axis=0), + "faces": np.concatenate(faces_list, axis=0), + "barycentric": 
np.concatenate(bary_list, axis=0), + } + + def _collect_surface_samples_2d( + self, + triangle: NDArray, + face_index: int, + normal: np.ndarray, + perp1: np.ndarray, + perp2: np.ndarray, + spacing: float, + collapsed_axis: int, + plane_value: float, + sim_min: np.ndarray, + sim_max: np.ndarray, + valid_axes: np.ndarray, + tol: float, + dtype: np.dtype, + ) -> tuple[Optional[dict[str, np.ndarray]], bool]: + """Collect samples when the simulation bounds collapse onto a 2D plane.""" + + segments = self._triangle_plane_segments( + triangle=triangle, axis=collapsed_axis, plane_value=plane_value, tol=tol + ) + + points: list[np.ndarray] = [] + normals: list[np.ndarray] = [] + perps1_list: list[np.ndarray] = [] + perps2_list: list[np.ndarray] = [] + weights: list[np.ndarray] = [] + faces: list[np.ndarray] = [] + barycentric: list[np.ndarray] = [] + outside_bounds = False + + for start, end in segments: + vec = end - start + length = float(np.linalg.norm(vec)) + if length <= tol: + continue + + subdivisions = max(1, int(np.ceil(length / spacing))) + t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions + sample_points = start[None, :] + t_vals[:, None] * vec[None, :] + bary = self._barycentric_coordinates(triangle, sample_points, tol) + + inside_mask = np.ones(sample_points.shape[0], dtype=bool) + if np.any(valid_axes): + min_bound = (sim_min - tol)[valid_axes] + max_bound = (sim_max + tol)[valid_axes] + coords = sample_points[:, valid_axes] + inside_mask = np.all(coords >= min_bound, axis=1) & np.all( + coords <= max_bound, axis=1 + ) + + outside_bounds = outside_bounds or (not np.all(inside_mask)) + if not np.any(inside_mask): + continue + + sample_points = sample_points[inside_mask] + bary_inside = bary[inside_mask] + n_inside = sample_points.shape[0] + + normal_tile = np.repeat(normal[None, :], n_inside, axis=0) + perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) + perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) + weights_tile = 
np.full(n_inside, length / subdivisions, dtype=dtype) + faces_tile = np.full(n_inside, face_index, dtype=int) + + points.append(sample_points) + normals.append(normal_tile) + perps1_list.append(perp1_tile) + perps2_list.append(perp2_tile) + weights.append(weights_tile) + faces.append(faces_tile) + barycentric.append(bary_inside) + + if not points: + return None, outside_bounds + + samples = { + "points": np.concatenate(points, axis=0), + "normals": np.concatenate(normals, axis=0), + "perps1": np.concatenate(perps1_list, axis=0), + "perps2": np.concatenate(perps2_list, axis=0), + "weights": np.concatenate(weights, axis=0), + "faces": np.concatenate(faces, axis=0), + "barycentric": np.concatenate(barycentric, axis=0), + } + return samples, outside_bounds + + def _collect_surface_samples_3d( + self, + triangle: NDArray, + face_index: int, + normal: np.ndarray, + perp1: np.ndarray, + perp2: np.ndarray, + area: float, + spacing: float, + sim_min: np.ndarray, + sim_max: np.ndarray, + valid_axes: np.ndarray, + tol: float, + dtype: np.dtype, + ) -> tuple[Optional[dict[str, np.ndarray]], bool]: + """Collect samples when the simulation bounds represent a full 3D region.""" + + edge_lengths = ( + np.linalg.norm(triangle[1] - triangle[0]), + np.linalg.norm(triangle[2] - triangle[1]), + np.linalg.norm(triangle[0] - triangle[2]), + ) + subdivisions = self._subdivision_count(area, spacing, edge_lengths) + barycentric = self._get_barycentric_samples(subdivisions, dtype) + num_samples = barycentric.shape[0] + base_weight = area / num_samples + + sample_points = barycentric @ triangle + + inside_mask = np.all( + sample_points[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 + ) & np.all(sample_points[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) + outside_bounds = not np.all(inside_mask) + if not np.any(inside_mask): + return None, outside_bounds + + sample_points = sample_points[inside_mask] + bary_inside = barycentric[inside_mask] + n_samples_inside = 
sample_points.shape[0] + + normal_tile = np.repeat(normal[None, :], n_samples_inside, axis=0) + perp1_tile = np.repeat(perp1[None, :], n_samples_inside, axis=0) + perp2_tile = np.repeat(perp2[None, :], n_samples_inside, axis=0) + weights_tile = np.full(n_samples_inside, base_weight, dtype=dtype) + faces_tile = np.full(n_samples_inside, face_index, dtype=int) + + samples = { + "points": sample_points, + "normals": normal_tile, + "perps1": perp1_tile, + "perps2": perp2_tile, + "weights": weights_tile, + "faces": faces_tile, + "barycentric": bary_inside, + } + return samples, outside_bounds + + @staticmethod + def _triangle_area_and_normal(triangle: NDArray) -> tuple[float, np.ndarray]: + """Return area and outward normal of the provided triangle.""" + + edge01 = triangle[1] - triangle[0] + edge02 = triangle[2] - triangle[0] + cross = np.cross(edge01, edge02) + norm = np.linalg.norm(cross) + if norm <= 0.0: + return 0.0, np.zeros(3, dtype=triangle.dtype) + normal = (cross / norm).astype(triangle.dtype, copy=False) + area = 0.5 * norm + return area, normal + + @staticmethod + def _triangle_plane_segments( + triangle: NDArray, axis: int, plane_value: float, tol: float + ) -> list[tuple[np.ndarray, np.ndarray]]: + """Return intersection segments between a triangle and an axis-aligned plane.""" + + vertices = np.asarray(triangle) + distances = vertices[:, axis] - plane_value + edges = ((0, 1), (1, 2), (2, 0)) + + segments: list[tuple[np.ndarray, np.ndarray]] = [] + points: list[np.ndarray] = [] + + def add_point(pt: np.ndarray) -> None: + for existing in points: + if np.linalg.norm(existing - pt) <= tol: + return + points.append(pt.copy()) + + for i, j in edges: + di = distances[i] + dj = distances[j] + vi = vertices[i] + vj = vertices[j] + + if abs(di) <= tol and abs(dj) <= tol: + segments.append((vi.copy(), vj.copy())) + continue + + if di * dj > 0.0: + continue + + if abs(di) <= tol: + add_point(vi) + continue + + if abs(dj) <= tol: + add_point(vj) + continue + + denom 
= di - dj + if abs(denom) <= tol: + continue + t = di / denom + if t < 0.0 or t > 1.0: + continue + point = vi + t * (vj - vi) + add_point(point) + + if segments: + return segments + + if len(points) >= 2: + return [(points[0], points[1])] + + return [] + + @staticmethod + def _barycentric_coordinates(triangle: NDArray, points: np.ndarray, tol: float) -> np.ndarray: + """Compute barycentric coordinates of ``points`` with respect to ``triangle``.""" + + pts = np.asarray(points, dtype=triangle.dtype) + v0 = triangle[0] + v1 = triangle[1] + v2 = triangle[2] + v0v1 = v1 - v0 + v0v2 = v2 - v0 + + d00 = float(np.dot(v0v1, v0v1)) + d01 = float(np.dot(v0v1, v0v2)) + d11 = float(np.dot(v0v2, v0v2)) + denom = d00 * d11 - d01 * d01 + if abs(denom) <= tol: + return np.tile( + np.array([1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], dtype=triangle.dtype), (pts.shape[0], 1) + ) + + v0p = pts - v0 + d20 = v0p @ v0v1 + d21 = v0p @ v0v2 + v = (d11 * d20 - d01 * d21) / denom + w = (d00 * d21 - d01 * d20) / denom + u = 1.0 - v - w + bary = np.stack((u, v, w), axis=1) + return bary.astype(triangle.dtype, copy=False) + + @classmethod + def _subdivision_count( + cls, + area: float, + spacing: float, + edge_lengths: Optional[tuple[float, float, float]] = None, + ) -> int: + """Determine the number of subdivisions needed for the given area and spacing.""" + + spacing = max(float(spacing), np.finfo(float).eps) + + target = np.sqrt(max(area, 0.0)) + area_based = np.ceil(np.sqrt(2.0) * target / spacing) + + edge_based = 0.0 + if edge_lengths: + max_edge = max(edge_lengths) + if max_edge > 0.0: + edge_based = np.ceil(max_edge / spacing) + + subdivisions = max(1, int(max(area_based, edge_based))) + return subdivisions + + def _get_barycentric_samples(self, subdivisions: int, dtype: np.dtype) -> np.ndarray: + """Return barycentric sample coordinates for a subdivision level.""" + + cache = self._barycentric_samples + if subdivisions not in cache: + cache[subdivisions] = 
self._build_barycentric_samples(subdivisions) + return cache[subdivisions].astype(dtype, copy=False) + + @staticmethod + def _build_barycentric_samples(subdivisions: int) -> np.ndarray: + """Construct barycentric sampling points for a given subdivision level.""" + + if subdivisions <= 1: + return np.array([[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]]) + + bary = [] + for i in range(subdivisions): + for j in range(subdivisions - i): + l1 = (i + 1.0 / 3.0) / subdivisions + l2 = (j + 1.0 / 3.0) / subdivisions + l0 = 1.0 - l1 - l2 + bary.append((l0, l1, l2)) + return np.asarray(bary, dtype=float) + + @staticmethod + def subdivide_faces(vertices: NDArray, faces: NDArray) -> tuple[np.ndarray, np.ndarray]: + """Uniformly subdivide each triangular face by inserting edge midpoints.""" + + midpoint_cache: dict[tuple[int, int], int] = {} + verts_list = [np.asarray(v, dtype=float) for v in vertices] + + def midpoint(i: int, j: int) -> int: + key = (i, j) if i < j else (j, i) + if key in midpoint_cache: + return midpoint_cache[key] + vm = 0.5 * (verts_list[i] + verts_list[j]) + verts_list.append(vm) + idx = len(verts_list) - 1 + midpoint_cache[key] = idx + return idx + + new_faces: list[tuple[int, int, int]] = [] + for tri in faces: + a = midpoint(tri[0], tri[1]) + b = midpoint(tri[1], tri[2]) + c = midpoint(tri[2], tri[0]) + new_faces.extend(((tri[0], a, c), (tri[1], b, a), (tri[2], c, b), (a, b, c))) + + verts_arr = np.asarray(verts_list, dtype=float) + return verts_arr, np.asarray(new_faces, dtype=int) + + @staticmethod + def _triangle_tangent_basis( + triangle: NDArray, normal: NDArray + ) -> Optional[tuple[np.ndarray, np.ndarray]]: + """Compute orthonormal tangential vectors for a triangle.""" + + tol = np.finfo(triangle.dtype).eps + edges = [triangle[1] - triangle[0], triangle[2] - triangle[0], triangle[2] - triangle[1]] + + edge = None + for candidate in edges: + length = np.linalg.norm(candidate) + if length > tol: + edge = (candidate / length).astype(triangle.dtype, copy=False) 
+ break + + if edge is None: + return None + + perp1 = edge + perp2 = np.cross(normal, perp1) + perp2_norm = np.linalg.norm(perp2) + if perp2_norm <= tol: + return None + perp2 = (perp2 / perp2_norm).astype(triangle.dtype, copy=False) + return perp1, perp2 diff --git a/tidy3d/_common/components/geometry/polyslab.py b/tidy3d/_common/components/geometry/polyslab.py new file mode 100644 index 0000000000..4aea970634 --- /dev/null +++ b/tidy3d/_common/components/geometry/polyslab.py @@ -0,0 +1,2774 @@ +"""Geometry extruded from polygonal shapes.""" + +from __future__ import annotations + +import math +from copy import copy +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +import autograd.numpy as np +import shapely +from autograd.tracer import getval +from numpy.polynomial.legendre import leggauss as _leggauss +from pydantic import Field, field_validator, model_validator + +from tidy3d._common.components.autograd import TracedArrayFloat2D, get_static +from tidy3d._common.components.autograd.types import TracedFloat +from tidy3d._common.components.autograd.utils import hasbox +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.geometry import base, triangulation +from tidy3d._common.components.transformation import ReflectionFromPlane, RotationAroundAxis +from tidy3d._common.config import config +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER, fp_eps +from tidy3d._common.exceptions import SetupError, Tidy3dImportError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from typing import Optional, Union + + from gdstk import Cell + from numpy.typing import NDArray + from pydantic import PositiveFloat + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from 
tidy3d._common.components.types.base import ( + ArrayFloat1D, + ArrayFloat2D, + ArrayLike, + Axis, + Bound, + Coordinate, + MatrixReal4x4, + PlanePosition, + Shapely, + ) + +# sampling polygon along dilation for validating polygon to be +# non self-intersecting during the entire dilation process +_N_SAMPLE_POLYGON_INTERSECT = 5 + +_IS_CLOSE_RTOL = np.finfo(float).eps + +# Warn for too many divided polyslabs +_COMPLEX_POLYSLAB_DIVISIONS_WARN = 100 + +# Warn before triangulating large polyslabs due to inefficiency +_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION = 500 + +_MIN_POLYGON_AREA = fp_eps + + +@lru_cache(maxsize=128) +def leggauss(n: int) -> tuple[NDArray, NDArray]: + """Cached version of leggauss with dtype conversions.""" + g, w = _leggauss(n) + return g.astype(config.adjoint.gradient_dtype_float, copy=False), w.astype( + config.adjoint.gradient_dtype_float, copy=False + ) + + +class PolySlab(base.Planar): + """Polygon extruded with optional sidewall angle along axis direction. + + Example + ------- + >>> vertices = np.array([(0,0), (1,0), (1,1)]) + >>> p = PolySlab(vertices=vertices, axis=2, slab_bounds=(-1, 1)) + """ + + slab_bounds: tuple[TracedFloat, TracedFloat] = Field( + title="Slab Bounds", + description="Minimum and maximum positions of the slab along axis dimension.", + json_schema_extra={"units": MICROMETER}, + ) + + dilation: float = Field( + 0.0, + title="Dilation", + description="Dilation of the supplied polygon by shifting each edge along its " + "normal outwards direction by a distance; a negative value corresponds to erosion.", + json_schema_extra={"units": MICROMETER}, + ) + + vertices: TracedArrayFloat2D = Field( + title="Vertices", + description="List of (d1, d2) defining the 2 dimensional positions of the polygon " + "face vertices at the ``reference_plane``. " + "The index of dimension should be in the ascending order: e.g. 
if " + "the slab normal axis is ``axis=y``, the coordinate of the vertices will be in (x, z)", + json_schema_extra={"units": MICROMETER}, + ) + + @staticmethod + def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: + """Make a shapely polygon from some vertices, first ensures they are untraced.""" + vertices = get_static(vertices) + return shapely.Polygon(vertices) + + @field_validator("slab_bounds") + @classmethod + def slab_bounds_order(cls, val: tuple[float, float]) -> tuple[float, float]: + """Maximum position of the slab should be no smaller than its minimal position.""" + if val[1] < val[0]: + raise SetupError( + "Polyslab.slab_bounds must be specified in the order of " + "minimum and maximum positions of the slab along the axis. " + f"But now the maximum {val[1]} is smaller than the minimum {val[0]}." + ) + return val + + @field_validator("vertices") + @classmethod + def correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D: + """Makes sure vertices size is correct. Make sure no intersecting edges.""" + # overall shape of vertices + if val.shape[1] != 2: + raise SetupError( + "PolySlab.vertices must be a 2 dimensional array shaped (N, 2). " + f"Given array with shape of {val.shape}." + ) + # make sure no polygon splitting, islands, 0 area + poly_heal = shapely.make_valid(cls.make_shapely_polygon(val)) + if poly_heal.area < _MIN_POLYGON_AREA: + raise SetupError("The polygon almost collapses to a 1D curve.") + + if not poly_heal.geom_type == "Polygon" or len(poly_heal.interiors) > 0: + raise SetupError( + "Polygon is self-intersecting, resulting in " + "polygon splitting or generation of holes/islands. " + "A general treatment to self-intersecting polygon will be available " + "in future releases." + ) + return val + + @model_validator(mode="after") + def no_complex_self_intersecting_polygon_at_reference_plane(self: Self) -> Self: + """At the reference plane, check if the polygon is self-intersecting. 
+ + There are two types of self-intersection that can occur during dilation: + 1) the one that creates holes/islands, or splits polygons, or removes everything; + 2) the one that does not. + + For 1), we issue an error since it is yet to be supported; + For 2), we heal the polygon, and warn that the polygon has been cleaned up. + """ + val = self.vertices + # no need to validate anything here + if math.isclose(self.dilation, 0): + return self + + val_np = PolySlab._proper_vertices(val) + dist = self.dilation + + # 0) fully eroded + if dist < 0 and dist < -PolySlab._maximal_erosion(val_np): + raise SetupError("Erosion value is too large. The polygon is fully eroded.") + + # no edge events + if not PolySlab._edge_events_detection(val_np, dist, ignore_at_dist=False): + return self + + poly_offset = PolySlab._shift_vertices(val_np, dist)[0] + if PolySlab._area(poly_offset) < fp_eps**2: + raise SetupError("Erosion value is too large. The polygon is fully eroded.") + + # edge events + poly_offset = shapely.make_valid(self.make_shapely_polygon(poly_offset)) + # 1) polygon split or create holes/islands + if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: + raise SetupError( + "Dilation/Erosion value is too large, resulting in " + "polygon splitting or generation of holes/islands. " + "A general treatment to self-intersecting polygon will be available " + "in future releases." + ) + + # case 2 + log.warning( + "The dilation/erosion value is too large. resulting in a " + "self-intersecting polygon. " + "The vertices have been modified to make a valid polygon." + ) + return self + + @model_validator(mode="after") + def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self: + """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that + any normal cross section of the PolySlab cannot be self-intersecting. This part checks + if any self-interction will occur during extrusion with non-zero sidewall angle. 
+ + There are two types of self-intersection, known as edge events, + that can occur during dilation: + 1) neighboring vertex-vertex crossing. This type of edge event can be treated with + ``ComplexPolySlab`` which divides the polyslab into a list of simple polyslabs. + + 2) other types of edge events that can create holes/islands or split polygons. + To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation + of polygons/holes, and changes in vertices number. + """ + val = self.vertices + + # no need to validate anything here + # sidewall_angle may be autograd-traced; use static value for this check only + if math.isclose(getval(self.sidewall_angle), 0): + return self + + # apply dilation + poly_ref = PolySlab._proper_vertices(val) + if not math.isclose(self.dilation, 0): + poly_ref = PolySlab._shift_vertices(poly_ref, self.dilation)[0] + poly_ref = PolySlab._heal_polygon(poly_ref) + + slab_bounds = get_static(self.slab_bounds) + slab_min, slab_max = slab_bounds + + # first, check vertex-vertex crossing at any point during extrusion + length = slab_bounds[1] - slab_bounds[0] + dist = [-length * np.tan(self.sidewall_angle)] + # reverse the dilation value if it's defined on the top + if self.reference_plane == "top": + dist = [-dist[0]] + # for middle, both direction needs to be examined + elif self.reference_plane == "middle": + dist = [dist[0] / 2, -dist[0] / 2] + + # capture vertex crossing events + max_thick = [] + for dist_val in dist: + max_dist = PolySlab._neighbor_vertices_crossing_detection(poly_ref, dist_val) + + if max_dist is not None: + max_thick.append(max_dist / abs(dist_val) * length) + + if len(max_thick) > 0: + max_thick = min(max_thick) + raise SetupError( + "Sidewall angle or structure thickness is so large that the polygon " + "is self-intersecting during extrusion. 
" + f"Please either reduce structure thickness to be < {max_thick:.3e}, " + "or use our plugin 'ComplexPolySlab' to divide the complex polyslab " + "into a list of simple polyslabs." + ) + + # vertex-edge crossing event. + for dist_val in dist: + if PolySlab._edge_events_detection(poly_ref, dist_val): + raise SetupError( + "Sidewall angle or structure thickness is too large, " + "resulting in polygon splitting or generation of holes/islands. " + "A general treatment to self-intersecting polygon will be available " + "in future releases." + ) + return self + + @classmethod + def from_gds( + cls, + gds_cell: Cell, + axis: Axis, + slab_bounds: tuple[float, float], + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> list[PolySlab]: + """Import :class:`PolySlab` from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + axis : int + Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). + slab_bounds: tuple[float, float] + Minimum and maximum positions of the slab along ``axis``. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. + If ``None``, imports all data for this layer into the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of MICROMETER. + For example, if gds file uses nanometers, set ``gds_scale=1e-3``. + Must be positive. + dilation : float = 0.0 + Dilation of the polygon in the base by shifting each edge along its + normal outwards direction by a distance; + a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the sidewall. + ``sidewall_angle=0`` (default) specifies vertical wall, + while ``0 list[ArrayFloat2D]: + """Import :class:`PolySlab` from a ``gdstk.Cell``. 
+ + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. + If ``None``, imports all data for this layer into the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of MICROMETER. + For example, if gds file uses nanometers, set ``gds_scale=1e-3``. + Must be positive. + + Returns + ------- + list[ArrayFloat2D] + List of :class:`.ArrayFloat2D` + """ + import gdstk + + gds_cell_class_name = str(gds_cell.__class__) + if not isinstance(gds_cell, gdstk.Cell): + if ( + "gdstk" in gds_cell_class_name + ): # Check if it might be a gdstk cell but gdstk is not found + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to import gdstk cells." + ) + raise ValueError( + f"validate 'gds_cell' of type '{gds_cell_class_name}' " + "does not seem to be associated with 'gdstk' package " + "and therefore can't be loaded by Tidy3D." 
+ ) + + all_vertices = base.Geometry.load_gds_vertices_gdstk( + gds_cell=gds_cell, + gds_layer=gds_layer, + gds_dtype=gds_dtype, + gds_scale=gds_scale, + ) + + # convert vertices into polyslabs + polygons = [PolySlab.make_shapely_polygon(vertices).buffer(0) for vertices in all_vertices] + polys_union = shapely.unary_union(polygons, grid_size=base.POLY_GRID_SIZE) + + if polys_union.geom_type == "Polygon": + all_vertices = [np.array(polys_union.exterior.coords)] + elif polys_union.geom_type == "MultiPolygon": + all_vertices = [np.array(polygon.exterior.coords) for polygon in polys_union.geoms] + return all_vertices + + @property + def center_axis(self) -> float: + """Gets the position of the center of the geometry in the out of plane dimension.""" + zmin, zmax = self.slab_bounds + if np.isneginf(zmin) and np.isposinf(zmax): + return 0.0 + zmin = max(zmin, -LARGE_NUMBER) + zmax = min(zmax, LARGE_NUMBER) + return (zmax + zmin) / 2.0 + + @property + def length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension.""" + zmin, zmax = self.slab_bounds + return zmax - zmin + + @property + def finite_length_axis(self) -> float: + """Gets the length of the PolySlab along the out of plane dimension. + First clips the slab bounds to LARGE_NUMBER and then returns difference. + """ + zmin, zmax = self.slab_bounds + zmin = max(zmin, -LARGE_NUMBER) + zmax = min(zmax, LARGE_NUMBER) + return zmax - zmin + + @cached_property + def reference_polygon(self) -> NDArray: + """The polygon at the reference plane. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the reference plane. + """ + vertices = self._proper_vertices(self.vertices) + if math.isclose(self.dilation, 0): + return vertices + offset_vertices = self._shift_vertices(vertices, self.dilation)[0] + return self._heal_polygon(offset_vertices) + + @cached_property + def middle_polygon(self) -> NDArray: + """The polygon at the middle. 
+ + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the middle. + """ + + dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) + if self.reference_plane == "bottom": + return self._shift_vertices(self.reference_polygon, dist)[0] + if self.reference_plane == "top": + return self._shift_vertices(self.reference_polygon, -dist)[0] + # middle case + return self.reference_polygon + + @cached_property + def base_polygon(self) -> NDArray: + """The polygon at the base, derived from the ``middle_polygon``. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the base. + """ + if self.reference_plane == "bottom": + return self.reference_polygon + dist = self._extrusion_length_to_offset_distance(-self.finite_length_axis / 2) + return self._shift_vertices(self.middle_polygon, dist)[0] + + @cached_property + def top_polygon(self) -> NDArray: + """The polygon at the top, derived from the ``middle_polygon``. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the top. + """ + if self.reference_plane == "top": + return self.reference_polygon + dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) + return self._shift_vertices(self.middle_polygon, dist)[0] + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if self.slab_bounds[0] != self.slab_bounds[1]: + raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.") + return self.axis + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> PolySlab: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + if axis != self.axis: + raise ValueError( + f"'_update_from_bounds' may only be applied along axis '{self.axis}', " + f"but was given axis '{axis}'." 
+ ) + return self.updated_copy(slab_bounds=tuple(bounds)) + + @cached_property + def is_ccw(self) -> bool: + """Is this ``PolySlab`` CCW-oriented?""" + return PolySlab._area(self.vertices) > 0 + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Note + ---- + For slanted sidewalls, this function only works if x, y, and z are arrays produced by a + ``meshgrid call``, i.e. 3D arrays and each is constant along one axis. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
+ """ + self._ensure_equal_shape(x, y, z) + + z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) + + z0 = self.center_axis + dist_z = np.abs(z - z0) + inside_height = dist_z <= (self.finite_length_axis / 2) + + # avoid going into face checking if no points are inside slab bounds + if not np.any(inside_height): + return inside_height + + # check what points are inside polygon cross section (face) + z_local = z - z0 # distance to the middle + dist = -z_local * self._tanq + + if isinstance(x, np.ndarray): + inside_polygon = np.zeros_like(inside_height) + xs_slab = x[inside_height] + ys_slab = y[inside_height] + + # vertical sidewall + if math.isclose(self.sidewall_angle, 0): + face_polygon = shapely.Polygon(self.reference_polygon).buffer(fp_eps) + shapely.prepare(face_polygon) + inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs_slab, y=ys_slab) + inside_polygon[inside_height] = inside_polygon_slab + # slanted sidewall, offsetting vertices at each z + else: + # a helper function for moving axis + def _move_axis(arr: NDArray) -> NDArray: + return np.moveaxis(arr, source=self.axis, destination=-1) + + def _move_axis_reverse(arr: NDArray) -> NDArray: + return np.moveaxis(arr, source=-1, destination=self.axis) + + inside_polygon_axis = _move_axis(inside_polygon) + x_axis = _move_axis(x) + y_axis = _move_axis(y) + + for z_i in range(z.shape[self.axis]): + if not _move_axis(inside_height)[0, 0, z_i]: + continue + vertices_z = self._shift_vertices( + self.middle_polygon, _move_axis(dist)[0, 0, z_i] + )[0] + face_polygon = shapely.Polygon(vertices_z).buffer(fp_eps) + shapely.prepare(face_polygon) + xs = x_axis[:, :, 0].flatten() + ys = y_axis[:, :, 0].flatten() + inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs, y=ys) + inside_polygon_axis[:, :, z_i] = inside_polygon_slab.reshape(x_axis.shape[:2]) + inside_polygon = _move_axis_reverse(inside_polygon_axis) + else: + vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] + face_polygon = 
self.make_shapely_polygon(vertices_z).buffer(fp_eps) + point = shapely.Point(x, y) + inside_polygon = face_polygon.covers(point) + return inside_height * inside_polygon + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for PolySlab geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + import trimesh + + if len(self.base_polygon) > _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION: + log.warning( + f"Processing PolySlabs with over {_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION} vertices can be slow.", + log_once=True, + ) + base_triangles = triangulation.triangulate(self.base_polygon) + top_triangles = ( + base_triangles + if math.isclose(self.sidewall_angle, 0) + else triangulation.triangulate(self.top_polygon) + ) + + n = len(self.base_polygon) + faces = ( + [[a, b, c] for c, b, a in base_triangles] + + [[n + a, n + b, n + c] for a, b, c in top_triangles] + + [(i, (i + 1) % n, n + i) for i in range(n)] + + [((i + 1) % n, n + ((i + 1) % n), n + i) for i in range(n)] + ) + + x = np.hstack((self.base_polygon[:, 0], self.top_polygon[:, 0])) + y = np.hstack((self.base_polygon[:, 1], self.top_polygon[:, 1])) + z = np.hstack((np.full(n, self.slab_bounds[0]), np.full(n, self.slab_bounds[1]))) + vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T + mesh = 
trimesh.Trimesh(vertices, faces) + + section = mesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list[Shapely]: + """Find shapely geometries intersecting planar geometry with axis normal to slab. + + Parameters + ---------- + z : float + Position along the axis normal to slab. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + if math.isclose(self.sidewall_angle, 0): + return [self.make_shapely_polygon(self.reference_polygon)] + + z0 = self.center_axis + z_local = z - z0 # distance to the middle + dist = -z_local * self._tanq + vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] + return [self.make_shapely_polygon(vertices_z)] + + def _intersections_side(self, position: float, axis: int) -> list[Shapely]: + """Find shapely geometries intersecting planar geometry with axis orthogonal to slab. + + For slanted polyslab, the procedure is as follows, + 1) Find out all z-coordinates where the plane will intersect directly with a vertex. + Denote the coordinates as (z_0, z_1, z_2, ... ) + 2) Find out all polygons that can be formed between z_i and z_{i+1}. There are two + types of polygons: + a) formed by the plane intersecting the edges + b) formed by the plane intersecting the vertices. + For either type, one needs to compute: + i) intersecting position + ii) angle between the plane and the intersecting edge + For a), both are straightforward to compute; while for b), one needs to compute + which edge the plane will slide into. + 3) Looping through z_i, and merge all polygons. The partition by z_i is because once + the plane intersects the vertex, it can intersect with other edges during + the extrusion. 
+ + Parameters + ---------- + position : float + Position along ``axis``. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + # find out all z_i where the plane will intersect the vertex + z0 = self.center_axis + z_base = z0 - self.finite_length_axis / 2 + + axis_ordered = self._order_axis(axis) + height_list = self._find_intersecting_height(position, axis_ordered) + polys = [] + + # looping through z_i to assemble the polygons + height_list = np.append(height_list, self.finite_length_axis) + h_base = 0.0 + for h_top in height_list: + # length within between top and bottom + h_length = h_top - h_base + + # coordinate of each subsection + z_min = z_base + h_base + z_max = np.inf if np.isposinf(h_top) else z_base + h_top + + # for vertical sidewall, no need for complications + if math.isclose(self.sidewall_angle, 0): + ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( + self.reference_polygon, position, axis_ordered + ) + else: + # for slanted sidewall, move up by `fp_eps` in case vertices are degenerate at the base. 
+ dist = -(h_base - self.finite_length_axis / 2 + fp_eps) * self._tanq + vertices = self._shift_vertices(self.middle_polygon, dist)[0] + ints_y, ints_angle = self._find_intersecting_ys_angle_slant( + vertices, position, axis_ordered + ) + + # make polygon with intersections and z axis information + for y_index in range(len(ints_y) // 2): + y_min = ints_y[2 * y_index] + y_max = ints_y[2 * y_index + 1] + minx, miny = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) + maxx, maxy = self._order_by_axis(plane_val=y_max, axis_val=z_max, axis=axis) + + if math.isclose(self.sidewall_angle, 0): + polys.append(self.make_shapely_box(minx, miny, maxx, maxy)) + else: + angle_min = ints_angle[2 * y_index] + angle_max = ints_angle[2 * y_index + 1] + + angle_min = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_min)) + angle_max = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_max)) + + dy_min = h_length * np.tan(angle_min) + dy_max = h_length * np.tan(angle_max) + + x1, y1 = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) + x2, y2 = self._order_by_axis(plane_val=y_max, axis_val=z_min, axis=axis) + x3, y3 = self._order_by_axis( + plane_val=y_max - dy_max, axis_val=z_max, axis=axis + ) + x4, y4 = self._order_by_axis( + plane_val=y_min + dy_min, axis_val=z_max, axis=axis + ) + vertices = ((x1, y1), (x2, y2), (x3, y3), (x4, y4)) + polys.append(self.make_shapely_polygon(vertices).buffer(0)) + # update the base coordinate for the next subsection + h_base = h_top + + # merge touching polygons + polys_union = shapely.unary_union(polys, grid_size=base.POLY_GRID_SIZE) + if polys_union.geom_type == "Polygon": + return [polys_union] + if polys_union.geom_type == "MultiPolygon": + return polys_union.geoms + # in other cases, just return the original unmerged polygons + return polys + + def _find_intersecting_height(self, position: float, axis: int) -> NDArray: + """Found a list of height where the plane will intersect with the vertices; + For vertical 
sidewall, just return np.array([]). + Assumes axis is handles so this function works on xy plane. + + Parameters + ---------- + position : float + position along axis. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + np.ndarray + Height (relative to the base) where the plane will intersect with vertices. + """ + if math.isclose(self.sidewall_angle, 0): + return np.array([]) + + # shift rate + dist = 1.0 + shift_x, shift_y = PolySlab._shift_vertices(self.middle_polygon, dist)[2] + shift_val = shift_x if axis == 0 else shift_y + shift_val[np.isclose(shift_val, 0, rtol=_IS_CLOSE_RTOL)] = np.inf # for static vertices + + # distance to the plane in the direction of vertex shifting + distance = self.middle_polygon[:, axis] - position + height = distance / self._tanq / shift_val + self.finite_length_axis / 2 + height = np.unique(height) + # further filter very close ones + is_not_too_close = np.insert((np.diff(height) > fp_eps), 0, True) + height = height[is_not_too_close] + + height = height[height > fp_eps] + height = height[height < self.finite_length_axis - fp_eps] + return height + + def _find_intersecting_ys_angle_vertical( + self, + vertices: NDArray, + position: float, + axis: int, + exclude_on_vertices: bool = False, + ) -> tuple[NDArray, NDArray, NDArray]: + """Finds pairs of forward and backwards vertices where polygon intersects position at axis, + Find intersection point (in y) assuming straight line,and intersecting angle between plane + and edges. (For unslanted polyslab). + Assumes axis is handles so this function works on xy plane. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + position : float + position along axis. + axis : int + Integer index into 'xyz' (0,1,2). + exclude_on_vertices : bool = False + Whether to exclude those intersecting directly with the vertices. + + Returns + ------- + Union[np.ndarray, np.ndarray] + List of intersection points along y direction. 
        List of angles between plane and edges.
        """

        vertices_axis = vertices

        # flip vertices x,y for axis = y
        if axis == 1:
            vertices_axis = np.roll(vertices_axis, shift=1, axis=1)

        # get the forward vertices
        vertices_f = np.roll(vertices_axis, shift=-1, axis=0)

        # x coordinate of the two sets of vertices
        x_vertices_f, _ = vertices_f.T
        x_vertices_axis, _ = vertices_axis.T

        # Find which segments intersect:
        # 1. Strictly crossing: one endpoint strictly left, one strictly right
        # 2. Touching: exactly one endpoint on the plane (xor), which excludes
        #    edges lying entirely on the plane (both endpoints at position).
        orig_on_plane = np.isclose(x_vertices_axis, position, rtol=_IS_CLOSE_RTOL)
        f_on_plane = np.roll(orig_on_plane, shift=-1)
        crosses_b = (x_vertices_axis > position) & (x_vertices_f < position)
        crosses_f = (x_vertices_axis < position) & (x_vertices_f > position)

        if exclude_on_vertices:
            # exclude vertices at the position
            not_touching = np.logical_not(orig_on_plane | f_on_plane)
            intersects_segment = (crosses_b | crosses_f) & not_touching
        else:
            single_touch = np.logical_xor(orig_on_plane, f_on_plane)
            intersects_segment = crosses_b | crosses_f | single_touch

        iverts_b = vertices_axis[intersects_segment]
        iverts_f = vertices_f[intersects_segment]

        # intersecting positions and angles
        ints_y = []
        ints_angle = []
        # NOTE(review): the loop names pair iverts_b with ``vertices_f_local`` (and vice
        # versa) — the labels look swapped, but the line interpolation below is symmetric
        # in the two endpoints, so the computed intersection is unaffected.
        for vertices_f_local, vertices_b_local in zip(iverts_b, iverts_f):
            x1, y1 = vertices_f_local
            x2, y2 = vertices_b_local
            slope = (y2 - y1) / (x2 - x1)
            y = y1 + slope * (position - x1)
            ints_y.append(y)
            ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope)))

        ints_y = np.array(ints_y)
        ints_angle = np.array(ints_angle)

        # Get rid of duplicate intersection points (vertices counted twice if directly on position)
        ints_y_sort, sort_index = np.unique(ints_y, return_index=True)
        ints_angle_sort = ints_angle[sort_index]

        # For tangent touches (vertex on plane, both neighbors on same side),
        # add y-value back to form a degenerate pair
        if not exclude_on_vertices:
            n = len(vertices_axis)
            for idx in np.where(orig_on_plane)[0]:
                prev_on = orig_on_plane[(idx - 1) % n]
                next_on = orig_on_plane[(idx + 1) % n]
                if not prev_on and not next_on:
                    prev_side = x_vertices_axis[(idx - 1) % n] > position
                    next_side = x_vertices_axis[(idx + 1) % n] > position
                    if prev_side == next_side:
                        ints_y_sort = np.append(ints_y_sort, vertices_axis[idx, 1])
                        ints_angle_sort = np.append(ints_angle_sort, 0)

        sort_index = np.argsort(ints_y_sort)
        ints_y_sort = ints_y_sort[sort_index]
        ints_angle_sort = ints_angle_sort[sort_index]
        return ints_y_sort, ints_angle_sort

    def _find_intersecting_ys_angle_slant(
        self, vertices: NDArray, position: float, axis: int
    ) -> tuple[NDArray, NDArray]:
        """Finds pairs of forward and backwards vertices where polygon intersects position at axis.
        Find intersection point (in y) assuming straight line, and intersecting angle between plane
        and edges. (For slanted polyslab)
        Assumes axis is handled so this function works on the xy plane.

        Parameters
        ----------
        vertices : np.ndarray
            Shape (N, 2) defining the polygon vertices in the xy-plane.
        position : float
            position along axis.
        axis : int
            Integer index into 'xyz' (0,1,2).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            List of intersection points along y direction.
            List of angles between plane and edges.
        """

        vertices_axis = vertices.copy()
        # flip vertices x,y for axis = y
        if axis == 1:
            vertices_axis = np.roll(vertices_axis, shift=1, axis=1)

        # get the forward vertices
        vertices_f = np.roll(vertices_axis, shift=-1, axis=0)
        # get the backward vertices
        vertices_b = np.roll(vertices_axis, shift=1, axis=0)

        ## First part, plane intersects with edges, same as vertical
        ints_y, ints_angle = self._find_intersecting_ys_angle_vertical(
            vertices, position, axis, exclude_on_vertices=True
        )
        ints_y = ints_y.tolist()
        ints_angle = ints_angle.tolist()

        ## Second part, plane intersects directly with vertices
        # vertices on the intersection
        intersects_on = np.isclose(vertices_axis[:, 0], position, rtol=_IS_CLOSE_RTOL)
        iverts_on = vertices_axis[intersects_on]
        # position of the neighbouring vertices
        iverts_b = vertices_b[intersects_on]
        iverts_f = vertices_f[intersects_on]
        # shift rate
        dist = -np.sign(self.sidewall_angle)
        shift_x, shift_y = self._shift_vertices(self.middle_polygon, dist)[2]
        shift_val = shift_x if axis == 0 else shift_y
        shift_val = shift_val[intersects_on]

        for vertices_f_local, vertices_b_local, vertices_on_local, shift_local in zip(
            iverts_f, iverts_b, iverts_on, shift_val
        ):
            x_on, y_on = vertices_on_local
            x_f, y_f = vertices_f_local
            x_b, y_b = vertices_b_local

            num_added = 0  # keep track the number of added vertices
            slope = []  # list of slopes for added vertices
            # case 1, shifting velocity is 0
            if np.isclose(shift_local, 0, rtol=_IS_CLOSE_RTOL):
                ints_y.append(y_on)
                # Slope w.r.t. forward and backward should equal,
                # just pick one of them.
                slope.append((y_on - y_b) / (x_on - x_b))
                ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0])))
                continue

            # case 2, shifting towards backward direction
            if (x_b - position) * shift_local < 0:
                ints_y.append(y_on)
                slope.append((y_on - y_b) / (x_on - x_b))
                num_added += 1

            # case 3, shifting towards forward direction
            if (x_f - position) * shift_local < 0:
                ints_y.append(y_on)
                slope.append((y_on - y_f) / (x_on - x_f))
                num_added += 1

            # in case 2, and case 3, if just num_added = 1
            if num_added == 1:
                ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0])))
            # if num_added = 2, the order of the two new vertices needs to handled correctly;
            # it should be sorted according to the -slope * moving direction
            elif num_added == 2:
                dressed_slope = [-s_i * shift_local for s_i in slope]
                sort_index = np.argsort(np.array(dressed_slope))
                sorted_slope = np.array(slope)[sort_index]

                ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[0])))
                ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[1])))

        ints_y = np.array(ints_y)
        ints_angle = np.array(ints_angle)

        sort_index = np.argsort(ints_y)
        ints_y_sort = ints_y[sort_index]
        ints_angle_sort = ints_angle[sort_index]

        return ints_y_sort, ints_angle_sort

    @cached_property
    def bounds(self) -> Bound:
        """Returns bounding box min and max coordinates. The dilation and slant angle are not
        taken into account exactly for speed. Instead, the polygon may be slightly smaller than
        the returned bounds, but it should always be fully contained.

        Returns
        -------
        tuple[float, float, float], tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+ """ + + # check for the maximum possible contribution from dilation/slant on each side + max_offset = self.dilation + # sidewall_angle may be autograd-traced; unbox for this check + if not math.isclose(getval(self.sidewall_angle), 0): + if self.reference_plane == "bottom": + max_offset += max(0, -self._tanq * self.finite_length_axis) + elif self.reference_plane == "top": + max_offset += max(0, self._tanq * self.finite_length_axis) + elif self.reference_plane == "middle": + max_offset += max(0, abs(self._tanq) * self.finite_length_axis / 2) + + # special care when dilated + if max_offset > 0: + dilated_vertices = self._shift_vertices( + self._proper_vertices(self.vertices), max_offset + )[0] + xmin, ymin = np.amin(dilated_vertices, axis=0) + xmax, ymax = np.amax(dilated_vertices, axis=0) + else: + # otherwise, bounds are directly based on the supplied vertices + xmin, ymin = np.amin(self.vertices, axis=0) + xmax, ymax = np.amax(self.vertices, axis=0) + + # get bounds in (local) z + zmin, zmax = self.slab_bounds + + # rearrange axes + coords_min = self.unpop_axis(zmin, (xmin, ymin), axis=self.axis) + coords_max = self.unpop_axis(zmax, (xmax, ymax), axis=self.axis) + return (tuple(coords_min), tuple(coords_max)) + + def _extrusion_length_to_offset_distance(self, extrusion: float) -> float: + """Convert extrusion length to offset distance.""" + if math.isclose(self.sidewall_angle, 0): + return 0 + return -extrusion * self._tanq + + @staticmethod + def _area(vertices: NDArray) -> float: + """Compute the signed polygon area (positive for CCW orientation). + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + float + Signed polygon area (positive for CCW orientation). 
+ """ + vert_shift = np.roll(vertices, axis=0, shift=-1) + + xs, ys = vertices.T + xs_shift, ys_shift = vert_shift.T + + term1 = xs * ys_shift + term2 = ys * xs_shift + return np.sum(term1 - term2) * 0.5 + + @staticmethod + def _perimeter(vertices: NDArray) -> float: + """Compute the polygon perimeter. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + float + Polygon perimeter. + """ + + vert_shift = np.roll(vertices, axis=0, shift=-1) + squared_diffs = (vertices - vert_shift) ** 2 + + # distance along each edge + dists = np.sqrt(squared_diffs.sum(axis=-1)) + + # total distance along all edges + return np.sum(dists) + + @staticmethod + def _orient(vertices: NDArray) -> NDArray: + """Return a CCW-oriented polygon. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + np.ndarray + Vertices of a CCW-oriented polygon. + """ + return vertices if PolySlab._area(vertices) > 0 else vertices[::-1, :] + + @staticmethod + def _remove_duplicate_vertices(vertices: NDArray) -> NDArray: + """Remove redundant/identical nearest neighbour vertices. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + np.ndarray + Vertices of polygon. + """ + + vertices_f = np.roll(vertices, shift=-1, axis=0) + vertices_diff = np.linalg.norm(vertices - vertices_f, axis=1) + return vertices[~np.isclose(vertices_diff, 0, rtol=_IS_CLOSE_RTOL)] + + @staticmethod + def _proper_vertices(vertices: ArrayFloat2D) -> NDArray: + """convert vertices to np.array format, + removing duplicate neighbouring vertices, + and oriented in CCW direction. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon for internal use. 
+ """ + vertices_np = np.array(vertices) + return PolySlab._orient(PolySlab._remove_duplicate_vertices(vertices_np)) + + @staticmethod + def _edge_events_detection( + proper_vertices: NDArray, dilation: float, ignore_at_dist: bool = True + ) -> bool: + """Detect any edge events within the offset distance ``dilation``. + If ``ignore_at_dist=True``, the edge event at ``dist`` is ignored. + """ + + # ignore the event that occurs right at the offset distance + if ignore_at_dist: + dilation -= fp_eps * dilation / abs(dilation) + # number of vertices before offsetting + num_vertices = proper_vertices.shape[0] + + # 0) fully eroded? + if dilation < 0 and dilation < -PolySlab._maximal_erosion(proper_vertices): + return True + + # sample at a few dilation values + dist_list = ( + dilation + * np.linspace( + 0, 1, 1 + _N_SAMPLE_POLYGON_INTERSECT, dtype=config.adjoint.gradient_dtype_float + )[1:] + ) + for dist in dist_list: + # offset: we offset the vertices first, and then use shapely to make it proper + # in principle, one can offset with shapely.buffer directly, but shapely somehow + # automatically removes some vertices even though no change of topology. + poly_offset = PolySlab._shift_vertices(proper_vertices, dist)[0] + # flipped winding number + if PolySlab._area(poly_offset) < fp_eps**2: + return True + + poly_offset = shapely.make_valid(PolySlab.make_shapely_polygon(poly_offset)) + # 1) polygon split or create holes/islands + if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: + return True + + # 2) reduction in vertex number + offset_vertices = PolySlab._proper_vertices(poly_offset.exterior.coords) + if offset_vertices.shape[0] != num_vertices: + return True + + # 3) some split polygon might fully disappear after the offset, but they + # can be detected if we offset back. 
            poly_offset_back = shapely.make_valid(
                PolySlab.make_shapely_polygon(PolySlab._shift_vertices(offset_vertices, -dist)[0])
            )
            if poly_offset_back.geom_type == "MultiPolygon" or len(poly_offset_back.interiors) > 0:
                return True
            offset_back_vertices = poly_offset_back.exterior.coords
            if PolySlab._proper_vertices(offset_back_vertices).shape[0] != num_vertices:
                return True

        return False

    @staticmethod
    def _neighbor_vertices_crossing_detection(
        vertices: NDArray, dist: float, ignore_at_dist: bool = True
    ) -> Optional[float]:
        """Detect if neighboring vertices will cross after a dilation distance dist.

        Parameters
        ----------
        vertices : np.ndarray
            Shape (N, 2) defining the polygon vertices in the xy-plane.
        dist : float
            Distance to offset.
        ignore_at_dist : bool, optional
            whether to ignore the event right at ``dist``.

        Returns
        -------
        Optional[float]
            the absolute value of the maximal allowed dilation
            if there are any crossing, otherwise return ``None``.
        """
        # ignore the event that occurs right at the offset distance
        if ignore_at_dist:
            dist -= fp_eps * dist / abs(dist)

        edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices)
        length_remaining = edge_length - edge_reduction * dist

        if np.any(length_remaining < 0):
            index_oversized = length_remaining < 0
            max_dist = np.min(
                np.abs(edge_length[index_oversized] / edge_reduction[index_oversized])
            )
            return max_dist
        return None

    @staticmethod
    def array_to_vertices(arr_vertices: NDArray) -> ArrayFloat2D:
        """Converts a numpy array of vertices to a list of tuples."""
        return list(arr_vertices)

    @staticmethod
    def vertices_to_array(vertices_tuple: ArrayFloat2D) -> NDArray:
        """Converts a list of tuples (vertices) to a numpy array."""
        return np.array(vertices_tuple)

    @cached_property
    def interior_angle(self) -> ArrayFloat1D:
        """Angle formed inside polygon by two adjacent edges."""

        def normalize(v: NDArray) -> NDArray:
            return v / np.linalg.norm(v, axis=0)

        vs_orig = self.reference_polygon.T
        vs_next = np.roll(vs_orig, axis=-1, shift=-1)
        vs_previous = np.roll(vs_orig, axis=-1, shift=+1)

        asp = normalize(vs_next - vs_orig)
        asm = normalize(vs_previous - vs_orig)

        # interior angle from dot (cos) and 2d cross (sin) of the two edge directions
        cos_angle = asp[0] * asm[0] + asp[1] * asm[1]
        sin_angle = asp[0] * asm[1] - asp[1] * asm[0]

        angle = np.arccos(cos_angle)
        # concave angles
        angle[sin_angle < 0] = 2 * np.pi - angle[sin_angle < 0]
        return angle

    @staticmethod
    def _shift_vertices(
        vertices: NDArray, dist: float
    ) -> tuple[NDArray, NDArray, tuple[NDArray, NDArray]]:
        """Shifts the vertices of a polygon outward uniformly by distances
        `dists`.

        Parameters
        ----------
        np.ndarray
            Shape (N, 2) defining the polygon vertices in the xy-plane.
        dist : float
            Distance to offset.

        Returns
        -------
        tuple[np.ndarray, np.narray,tuple[np.ndarray,np.ndarray]]
            New polygon vertices;
            and the shift of vertices in direction parallel to the edges.
            Shift along x and y direction.
        """

        # 'dist' may be autograd-traced; unbox for the zero-check only
        if math.isclose(getval(dist), 0):
            return vertices, np.zeros(vertices.shape[0], dtype=float), None

        def rot90(v: tuple[NDArray, NDArray]) -> NDArray:
            """90 degree rotation of 2d vector
            vx -> vy
            vy -> -vx
            """
            vxs, vys = v
            return np.stack((-vys, vxs), axis=0)

        def cross(u: NDArray, v: NDArray) -> Any:
            return u[0] * v[1] - u[1] * v[0]

        def normalize(v: NDArray) -> NDArray:
            return v / np.linalg.norm(v, axis=0)

        vs_orig = copy(vertices.T)
        vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1)
        vs_previous = np.roll(copy(vs_orig), axis=-1, shift=+1)

        asp = normalize(vs_next - vs_orig)
        asm = normalize(vs_orig - vs_previous)

        # the vertex shift is decomposed into parallel and perpendicular directions
        perpendicular_shift = -dist
        det = cross(asm, asp)

        # where the two edge directions are (nearly) parallel, det ~ 0 and the
        # tangent term is forced to 0; the denominator is nudged to avoid 0/0
        tan_half_angle = np.where(
            np.isclose(det, 0, rtol=_IS_CLOSE_RTOL),
            0.0,
            cross(asm, rot90(asm - asp)) / (det + np.isclose(det, 0, rtol=_IS_CLOSE_RTOL)),
        )
        parallel_shift = dist * tan_half_angle

        shift_total = perpendicular_shift * rot90(asm) + parallel_shift * asm
        shift_x = shift_total[0, :]
        shift_y = shift_total[1, :]

        return (
            np.swapaxes(vs_orig + shift_total, -2, -1),
            parallel_shift,
            (shift_x, shift_y),
        )

    @staticmethod
    def _edge_length_and_reduction_rate(
        vertices: NDArray,
    ) -> tuple[NDArray, NDArray]:
        """Edge length of reduction rate of each edge with unit offset length.

        Parameters
        ----------
        vertices : np.ndarray
            Shape (N, 2) defining the polygon vertices in the xy-plane.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            edge length, and reduction rate
        """

        # edge length
        vs_orig = copy(vertices.T)
        vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1)
        edge_length = np.linalg.norm(vs_next - vs_orig, axis=0)

        # edge length remaining
        dist = 1
        parallel_shift = PolySlab._shift_vertices(vertices, dist)[1]
        parallel_shift_p = np.roll(copy(parallel_shift), shift=-1)
        edge_reduction = -(parallel_shift + parallel_shift_p)
        return edge_length, edge_reduction

    @staticmethod
    def _maximal_erosion(vertices: NDArray) -> float:
        """The erosion value that reduces the length of
        all edges to be non-positive.
        """
        edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices)
        ind_nonzero = abs(edge_reduction) > fp_eps
        return -np.min(edge_length[ind_nonzero] / edge_reduction[ind_nonzero])

    @staticmethod
    def _heal_polygon(vertices: NDArray) -> NDArray:
        """Heal a self-intersecting polygon."""
        shapely_poly = PolySlab.make_shapely_polygon(vertices)
        if shapely_poly.is_valid:
            return vertices
        elif hasbox(vertices):
            # vertices are autograd-traced; healing would break differentiation
            raise NotImplementedError(
                "The dilation caused damage to the polygon. "
                "Automatically healing this is currently not supported when "
                "differentiating w.r.t. the vertices. Try increasing the spacing "
                "between vertices or reduce the amount of dilation."
            )
        # perform healing
        poly_heal = shapely.make_valid(shapely_poly)
        return PolySlab._proper_vertices(list(poly_heal.exterior.coords))

    def _volume(self, bounds: Bound) -> float:
        """Returns object's volume within given bounds."""

        z_min, z_max = self.slab_bounds

        # clip the slab extent to the query bounds along the slab axis
        z_min = max(z_min, bounds[0][self.axis])
        z_max = min(z_max, bounds[1][self.axis])

        length = z_max - z_min

        top_area = abs(self._area(self.top_polygon))
        base_area = abs(self._area(self.base_polygon))

        # https://mathworld.wolfram.com/PyramidalFrustum.html
        return 1.0 / 3.0 * length * (top_area + base_area + np.sqrt(top_area * base_area))

    def _surface_area(self, bounds: Bound) -> float:
        """Returns object's surface area within given bounds."""

        area = 0

        top_polygon = self.top_polygon
        base_polygon = self.base_polygon

        top_area = abs(self._area(top_polygon))
        base_area = abs(self._area(base_polygon))

        top_perim = self._perimeter(top_polygon)
        base_perim = self._perimeter(base_polygon)

        z_min, z_max = self.slab_bounds

        # count a face's area only if that face lies inside the query bounds
        if z_min < bounds[0][self.axis]:
            z_min = bounds[0][self.axis]
        else:
            area += base_area

        if z_max > bounds[1][self.axis]:
            z_max = bounds[1][self.axis]
        else:
            area += top_area

        length = z_max - z_min

        # sidewall contribution approximated with the mean of top/base perimeters
        area += 0.5 * (top_perim + base_perim) * length

        return area

    """ Autograd code """

    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
        """
        Return VJPs while handling several edge-cases:

        - If the slab volume does not overlap the simulation, all grads are zero
          (one warning is issued).
        - Faces that lie completely outside the simulation give zero ``slab_bounds``
          gradients; this includes the +/- inf cases.
        - A 2d simulation collapses the surface integral to a line integral
        """
        vjps: AutogradFieldMap = {}

        intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect)
        sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds)

        extents = intersect_max - intersect_min
        is_2d = np.isclose(extents[self.axis], 0.0)

        # early return if polyslab is not in simulation domain
        slab_min, slab_max = self.slab_bounds
        if (slab_max < sim_min[self.axis]) or (slab_min > sim_max[self.axis]):
            log.warning(
                "'PolySlab' lies completely outside the simulation domain.",
                log_once=True,
            )
            for p in derivative_info.paths:
                vjps[p] = np.zeros_like(self.vertices) if p == ("vertices",) else 0.0
            return vjps

        # create interpolators once for ALL derivative computations
        # use provided interpolators if available to avoid redundant field data conversions
        interpolators = derivative_info.interpolators or derivative_info.create_interpolators(
            dtype=config.adjoint.gradient_dtype_float
        )

        for path in derivative_info.paths:
            if path == ("vertices",):
                vjps[path] = self._compute_derivative_vertices(
                    derivative_info, sim_min, sim_max, is_2d, interpolators
                )

            elif path == ("sidewall_angle",):
                vjps[path] = self._compute_derivative_sidewall_angle(
                    derivative_info, sim_min, sim_max, is_2d, interpolators
                )
            elif path[0] == "slab_bounds":
                idx = path[1]
                face_coord = self.slab_bounds[idx]

                # face entirely outside -> gradient 0
                if (
                    np.isinf(face_coord)
                    or face_coord < sim_min[self.axis]
                    or face_coord > sim_max[self.axis]
                    or is_2d
                ):
                    vjps[path] = 0.0
                    continue

                v = self._compute_derivative_slab_bounds(derivative_info, idx, interpolators)
                # outward-normal convention
                if idx == 0:
                    v *= -1
                vjps[path] = v
            else:
                raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.")

        return vjps

    # ---- Shared helpers for VJP surface integrations ----
    def _z_slices(
        self, sim_min: NDArray, sim_max: NDArray, is_2d: bool, dx: float
    ) -> tuple[NDArray, float, float, float]:
        """Compute z-slice centers and spacing within bounds.

        Returns (z_centers, dz, z0, z1). For 2D, returns single center and dz=1.
        """
        if is_2d:
            # clamp the slab center to the simulation extent along the axis
            midpoint_z = np.maximum(
                np.minimum(self.center_axis, sim_max[self.axis]),
                sim_min[self.axis],
            )
            zc = np.array([midpoint_z], dtype=config.adjoint.gradient_dtype_float)
            return zc, 1.0, self.center_axis, self.center_axis

        z0 = max(self.slab_bounds[0], sim_min[self.axis])
        z1 = min(self.slab_bounds[1], sim_max[self.axis])
        if z1 <= z0:
            # empty overlap along the axis -> no slices
            return np.array([], dtype=config.adjoint.gradient_dtype_float), 0.0, z0, z1

        n_z = max(1, int(np.ceil((z1 - z0) / dx)))
        dz = (z1 - z0) / n_z
        z_centers = np.linspace(
            z0 + dz / 2, z1 - dz / 2, n_z, dtype=config.adjoint.gradient_dtype_float
        )
        return z_centers, dz, z0, z1

    @staticmethod
    def _clip_edges_to_bounds_batch(
        segment_starts: NDArray,
        segment_ends: NDArray,
        sim_min: NDArray,
        sim_max: NDArray,
        *,
        _edge_clip_tol: Optional[float] = None,
        _dtype: Optional[type] = None,
    ) -> tuple[NDArray, NDArray, NDArray]:
        """
        Compute parametric bounds for multiple segments clipped to simulation bounds.

        Parameters
        ----------
        segment_starts : NDArray
            (N, 3) array of segment start coordinates.
        segment_ends : NDArray
            (N, 3) array of segment end coordinates.
        sim_min : NDArray
            (3,) array of simulation minimum bounds.
        sim_max : NDArray
            (3,) array of simulation maximum bounds.

        Returns
        -------
        is_within_bounds : NDArray
            (N,) boolean array indicating if the segment intersects the bounds.
        t_starts : NDArray
            (N,) array of parametric start values (0.0 to 1.0).
        t_ends : NDArray
            (N,) array of parametric end values (0.0 to 1.0).
+ """ + n = segment_starts.shape[0] + if _edge_clip_tol is None: + _edge_clip_tol = config.adjoint.edge_clip_tolerance + if _dtype is None: + _dtype = config.adjoint.gradient_dtype_float + + t_starts = np.zeros(n, dtype=_dtype) + t_ends = np.ones(n, dtype=_dtype) + is_within_bounds = np.ones(n, dtype=bool) + + for dim in range(3): + start_coords = segment_starts[:, dim] + end_coords = segment_ends[:, dim] + bound_min = sim_min[dim] + bound_max = sim_max[dim] + + # check for parallel edges (faster than isclose) + parallel = np.abs(start_coords - end_coords) < 1e-12 + + # parallel edges: check if outside bounds + outside = parallel & ( + (start_coords < (bound_min - _edge_clip_tol)) + | (start_coords > (bound_max + _edge_clip_tol)) + ) + is_within_bounds &= ~outside + + # non-parallel edges: compute t_min, t_max + not_parallel = ~parallel & is_within_bounds + if np.any(not_parallel): + denom = np.where(not_parallel, end_coords - start_coords, 1.0) # avoid div by zero + t_min = (bound_min - start_coords) / denom + t_max = (bound_max - start_coords) / denom + + # swap if needed + swap = t_min > t_max + t_min_new = np.where(swap, t_max, t_min) + t_max_new = np.where(swap, t_min, t_max) + + # update t_starts and t_ends for valid non-parallel edges + t_starts = np.where(not_parallel, np.maximum(t_starts, t_min_new), t_starts) + t_ends = np.where(not_parallel, np.minimum(t_ends, t_max_new), t_ends) + + # still valid? + is_within_bounds &= ~not_parallel | (t_starts < t_ends) + + is_within_bounds &= t_ends > t_starts + _edge_clip_tol + + return is_within_bounds, t_starts, t_ends + + @staticmethod + def _adaptive_edge_samples( + L: float, + dx: float, + t_start: float = 0.0, + t_end: float = 1.0, + *, + _sample_fraction: Optional[float] = None, + _gauss_order: Optional[int] = None, + _dtype: Optional[type] = None, + ) -> tuple[NDArray, NDArray]: + """ + Compute Gauss samples and weights along [t_start, t_end] with adaptive count. 

        Parameters
        ----------
        L : float
            Physical length of the full edge.
        dx : float
            Target discretization step size.
        t_start : float, optional
            Start parameter, by default 0.0.
        t_end : float, optional
            End parameter, by default 1.0.

        Returns
        -------
        tuple[NDArray, NDArray]
            Tuple of (samples, weights) for the integration.
        """
        if _sample_fraction is None:
            _sample_fraction = config.adjoint.quadrature_sample_fraction
        if _gauss_order is None:
            _gauss_order = config.adjoint.gauss_quadrature_order
        if _dtype is None:
            _dtype = config.adjoint.gradient_dtype_float

        # pick a sample count from the effective (clipped) edge length
        L_eff = L * max(0.0, t_end - t_start)
        n_uniform = max(1, int(np.ceil(L_eff / dx)))
        n_gauss = n_uniform if n_uniform <= 3 else max(2, int(n_uniform * _sample_fraction))
        if n_gauss <= _gauss_order:
            # single Gauss-Legendre rule mapped from [-1, 1] to [t_start, t_end]
            g, w = leggauss(n_gauss)
            half_range = 0.5 * (t_end - t_start)
            s = (half_range * g + 0.5 * (t_end + t_start)).astype(_dtype, copy=False)
            wt = (w * half_range).astype(_dtype, copy=False)
            return s, wt

        # composite Gauss with fixed local order
        g_loc, w_loc = leggauss(_gauss_order)
        segs = n_uniform
        edges_t = np.linspace(t_start, t_end, segs + 1, dtype=_dtype)

        # compute all segments at once
        a = edges_t[:-1]  # (segs,)
        b = edges_t[1:]  # (segs,)
        half_width = 0.5 * (b - a)  # (segs,)
        mid = 0.5 * (b + a)  # (segs,)

        # (segs, 1) * (order,) + (segs, 1) -> (segs, order)
        S = (half_width[:, None] * g_loc + mid[:, None]).astype(_dtype, copy=False)
        W = (half_width[:, None] * w_loc).astype(_dtype, copy=False)
        return S.ravel(), W.ravel()

    def _collect_sidewall_patches(
        self,
        vertices: NDArray,
        next_v: NDArray,
        edges: NDArray,
        basis: dict,
        sim_min: NDArray,
        sim_max: NDArray,
        is_2d: bool,
        dx: float,
    ) -> dict:
        """
        Collect sidewall patch geometry for batched VJP evaluation.

        Parameters
        ----------
        vertices : NDArray
            Array of polygon vertices.
        next_v : NDArray
            Array of next vertices (forming edges).
        edges : NDArray
            Edge vectors. NOTE(review): not referenced in this body; kept for
            signature compatibility with callers.
        basis : dict
            Basis vectors dictionary.
        sim_min : NDArray
            Simulation minimum bounds.
        sim_max : NDArray
            Simulation maximum bounds.
        is_2d : bool
            Whether the simulation is 2D.
        dx : float
            Discretization step.

        Returns
        -------
        dict
            Dictionary containing:
            - centers: (N, 3) array of patch centers.
            - normals: (N, 3) array of patch normals.
            - perps1: (N, 3) array of first tangent vectors.
            - perps2: (N, 3) array of second tangent vectors.
            - Ls: (N,) array of edge lengths.
            - s_vals: (N,) array of parametric coordinates along the edge.
            - s_weights: (N,) array of quadrature weights.
            - zc_vals: (N,) array of z-coordinates.
            - dz: float, slice thickness.
            - edge_indices: (N,) array of original edge indices.
        """
        # cache config values to avoid repeated lookups (overhead not insignificant here)
        _dtype = config.adjoint.gradient_dtype_float
        _edge_clip_tol = config.adjoint.edge_clip_tolerance
        _sample_fraction = config.adjoint.quadrature_sample_fraction
        _gauss_order = config.adjoint.gauss_quadrature_order

        theta = get_static(self.sidewall_angle)
        z_ref = self.reference_axis_pos

        cos_th = np.cos(theta)
        cos_th = np.clip(cos_th, 1e-12, 1.0)
        tan_th = np.tan(theta)
        dprime = -tan_th  # dd/dz

        # axis unit vector in 3D
        axis_vec = np.zeros(3, dtype=_dtype)
        axis_vec[self.axis] = 1.0

        # densify along axis as |theta| grows, dz scales with cos(theta)
        z_centers, dz, z0, z1 = self._z_slices(sim_min, sim_max, is_2d=is_2d, dx=dx * cos_th)

        # early exit: no slices
        if (not is_2d) and len(z_centers) == 0:
            return {
                "centers": np.empty((0, 3), dtype=_dtype),
                "normals": np.empty((0, 3), dtype=_dtype),
                "perps1": np.empty((0, 3), dtype=_dtype),
                "perps2": np.empty((0, 3), dtype=_dtype),
                "Ls": np.empty((0,), dtype=_dtype),
                "s_vals": np.empty((0,), dtype=_dtype),
                "s_weights": np.empty((0,), dtype=_dtype),
                "zc_vals": np.empty((0,), dtype=_dtype),
                "dz": dz,
                "edge_indices": np.empty((0,), dtype=int),
            }

        # estimate patches for pre-allocation
        n_edges = len(vertices)
        estimated_patches = 0
        denom_edge = max(dx * cos_th, 1e-12)
        for ei in range(n_edges):
            v0, v1 = vertices[ei], next_v[ei]
            L = np.linalg.norm(v1 - v0)
            if not np.isclose(L, 0.0):
                # prealloc guided by actual step; ds_phys scales with cos(theta)
                n_samples = max(1, int(np.ceil(L / denom_edge) * 0.6))
                estimated_patches += n_samples * max(1, len(z_centers))
        estimated_patches = int(max(1, estimated_patches) * 1.2)

        # pre-allocate arrays
        centers = np.empty((estimated_patches, 3), dtype=_dtype)
        normals = np.empty((estimated_patches, 3), dtype=_dtype)
        perps1 = np.empty((estimated_patches, 3), dtype=_dtype)
        perps2 = np.empty((estimated_patches, 3), dtype=_dtype)
        Ls = np.empty((estimated_patches,), dtype=_dtype)
        s_vals = np.empty((estimated_patches,), dtype=_dtype)
        s_weights = np.empty((estimated_patches,), dtype=_dtype)
        zc_vals = np.empty((estimated_patches,), dtype=_dtype)
        edge_indices = np.empty((estimated_patches,), dtype=int)

        patch_idx = 0

        # if the simulation is effectively 2D (one tangential dimension collapsed),
        # slightly expand degenerate bounds to enable finite-length clipping of edges.
        sim_min_eff = np.array(sim_min, dtype=_dtype)
        sim_max_eff = np.array(sim_max, dtype=_dtype)
        for dim in range(3):
            if dim == self.axis:
                continue
            if np.isclose(sim_max_eff[dim] - sim_min_eff[dim], 0.0):
                sim_min_eff[dim] -= 0.5 * dx
                sim_max_eff[dim] += 0.5 * dx

        # pre-compute values that are constant across z slices
        n_z = len(z_centers)
        z_centers_arr = np.asarray(z_centers, dtype=_dtype)

        # slanted local basis (constant across z for non-slanted case)
        # for slanted: rz = axis_vec + dprime * n2d, but dprime is constant
        for ei, (v0, v1) in enumerate(zip(vertices, next_v)):
            edge_vec = v1 - v0
            L = np.sqrt(np.dot(edge_vec, edge_vec))
            if L < 1e-12:
                # degenerate (zero-length) edge contributes nothing
                continue

            # constant along edge: unit tangent in 3D (no axis component)
            t_edge = basis["perp1"][ei]

            # outward in-plane normal from canonical basis normal
            n2d = basis["norm"][ei].copy()
            n2d[self.axis] = 0.0
            nrm = np.linalg.norm(n2d)
            if not np.isclose(nrm, 0.0):
                n2d = n2d / nrm
            else:
                # fallback to right-handed construction if degenerate
                tmp = np.cross(axis_vec, t_edge)
                n2d = tmp / (np.linalg.norm(tmp) + 1e-20)

            # compute basis vectors once per edge
            rz = axis_vec + dprime * n2d
            T1_vec = t_edge
            N_vec = np.cross(T1_vec, rz)
            N_norm = np.linalg.norm(N_vec)
            if not np.isclose(N_norm, 0.0):
                N_vec = N_vec / N_norm

            # align N with outward edge normal
            if float(np.dot(N_vec, basis["norm"][ei])) < 0.0:
                N_vec = -N_vec

            T2_vec = np.cross(N_vec, T1_vec)
            T2_norm = np.linalg.norm(T2_vec)
            if not np.isclose(T2_norm, 0.0):
                T2_vec = T2_vec / T2_norm

            # batch compute offsets for all z slices at once
            d_all = -(z_centers_arr - z_ref) * tan_th  # (n_z,)
            offsets_3d = d_all[:, None] * n2d  # (n_z, 3) - faster than np.outer

            # batch compute segment starts and ends for all z slices
            segment_starts = np.empty((n_z, 3), dtype=_dtype)
            segment_ends = np.empty((n_z, 3), dtype=_dtype)
            plane_axes = [i for i in range(3) if i != self.axis]
            segment_starts[:, self.axis] = z_centers_arr
            segment_starts[:, plane_axes] = v0
            segment_starts += offsets_3d
            segment_ends[:, self.axis] = z_centers_arr
            segment_ends[:, plane_axes] = v1
            segment_ends += offsets_3d

            # batch clip all z slices at once
            is_within_bounds, t_starts, t_ends = self._clip_edges_to_bounds_batch(
                segment_starts,
                segment_ends,
                sim_min_eff,
                sim_max_eff,
                _edge_clip_tol=_edge_clip_tol,
                _dtype=_dtype,
            )

            # process only valid z slices (sampling has variable output sizes)
            valid_indices = np.nonzero(is_within_bounds)[0]
            if len(valid_indices) == 0:
                continue

            # group z slices by unique (t0, t1) pairs to avoid redundant quadrature calculations.
            # since most z-slices will have identical clipping bounds (0.0, 1.0),
            # we can compute the Gauss samples once and reuse them for almost all slices.
            # rounding ensures we get cache hits despite tiny floating point differences.
            t0_valid = np.round(t_starts[valid_indices], 10)
            t1_valid = np.round(t_ends[valid_indices], 10)

            # simple cache for sampling results: (t0, t1) -> (s_list, w_list)
            sample_cache = {}

            # process each z slice
            for zi, t0, t1 in zip(valid_indices, t0_valid, t1_valid):
                if (t0, t1) not in sample_cache:
                    sample_cache[(t0, t1)] = self._adaptive_edge_samples(
                        L,
                        denom_edge,
                        t0,
                        t1,
                        _sample_fraction=_sample_fraction,
                        _gauss_order=_gauss_order,
                        _dtype=_dtype,
                    )

                s_list, w_list = sample_cache[(t0, t1)]
                if len(s_list) == 0:
                    continue

                zc = z_centers_arr[zi]
                offset3d = offsets_3d[zi]

                pts2d = v0 + s_list[:, None] * edge_vec  # faster than np.outer

                # inline unpop_axis_vect for xyz computation
                n_pts = len(s_list)
                xyz = np.empty((n_pts, 3), dtype=_dtype)
                xyz[:, self.axis] = zc
                xyz[:, plane_axes] = pts2d
                xyz += offset3d

                n_patches = n_pts
                new_size_needed = patch_idx + n_patches
                if new_size_needed > centers.shape[0]:
                    # grow arrays by 1.5x to avoid frequent reallocations
                    new_size = int(new_size_needed * 1.5)
                    centers.resize((new_size, 3), refcheck=False)
                    normals.resize((new_size, 3), refcheck=False)
                    perps1.resize((new_size, 3), refcheck=False)
                    perps2.resize((new_size, 3), refcheck=False)
                    Ls.resize((new_size,), refcheck=False)
                    s_vals.resize((new_size,), refcheck=False)
                    s_weights.resize((new_size,), refcheck=False)
                    zc_vals.resize((new_size,), refcheck=False)
                    edge_indices.resize((new_size,), refcheck=False)

                sl = slice(patch_idx, patch_idx + n_patches)
                centers[sl] = xyz
                normals[sl] = N_vec
                perps1[sl] = T1_vec
                perps2[sl] = T2_vec
                Ls[sl] = L
                s_vals[sl] = s_list
                s_weights[sl] = w_list
                zc_vals[sl] = zc
                edge_indices[sl] = ei

                patch_idx += n_patches

        # trim arrays to final size
        centers = centers[:patch_idx]
        normals = normals[:patch_idx]
        perps1 = perps1[:patch_idx]
        perps2 = perps2[:patch_idx]
        Ls = Ls[:patch_idx]
        s_vals = s_vals[:patch_idx]
        s_weights = s_weights[:patch_idx]
        zc_vals = zc_vals[:patch_idx]
        edge_indices = edge_indices[:patch_idx]

        return {
            "centers": centers,
            "normals": normals,
            "perps1": perps1,
            "perps2": perps2,
            "Ls": Ls,
            "s_vals": s_vals,
            "s_weights": s_weights,
            "zc_vals": zc_vals,
            "dz": dz,
            "edge_indices": edge_indices,
        }

    def _compute_derivative_sidewall_angle(
        self,
        derivative_info: DerivativeInfo,
        sim_min: NDArray,
        sim_max: NDArray,
        is_2d: bool = False,
        interpolators: Optional[dict] = None,
    ) -> float:
        """VJP for dJ/dtheta where theta = sidewall_angle.

        Use dJ/dtheta = integral_S g(x) * V_n(x; theta) * dA, with g(x) from
        `evaluate_gradient_at_points`. For a ruled sidewall built by
        offsetting the mid-plane polygon by d(z) = -(z - z_ref) * tan(theta),
        the normal velocity is V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta)
        and the area element is dA = (dz/cos(theta)) * d_ell.
        Therefore each patch weight is w = L * dz * (-(z - z_ref)) / cos(theta)^2.
+ """ + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + # 2D sim => no dependence on theta (z_local=0) + if is_2d: + return 0.0 + + vertices, next_v, edges, basis = self._edge_geometry_arrays() + + dx = derivative_info.adaptive_vjp_spacing() + + # collect patches once + patch = self._collect_sidewall_patches( + vertices=vertices, + next_v=next_v, + edges=edges, + basis=basis, + sim_min=sim_min, + sim_max=sim_max, + is_2d=False, + dx=dx, + ) + if patch["centers"].shape[0] == 0: + return 0.0 + + # Shape-derivative factors: + # - Offset: d(z) = -(z - z_ref) * tan(theta) + # - Tangential rate: dd/dtheta = -(z - z_ref) * sec(theta)^2 + # - Normal velocity (project to surface normal): V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) + # - Area element of slanted strip: dA = (dz/cos(theta)) * d_ell + # => Patch weight scales as: V_n * dA = -(z - z_ref) * dz * d_ell / cos(theta)^2 + cos_theta = np.cos(get_static(self.sidewall_angle)) + inv_cos2 = 1.0 / (cos_theta * cos_theta) + z_ref = self.reference_axis_pos + + g = derivative_info.evaluate_gradient_at_points( + patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators + ) + z_local = patch["zc_vals"] - z_ref + weights = patch["Ls"] * patch["s_weights"] * patch["dz"] * (-z_local) * inv_cos2 + return float(np.real(np.sum(g * weights))) + + def _compute_derivative_slab_bounds( + self, derivative_info: DerivativeInfo, min_max_index: int, interpolators: dict + ) -> TracedArrayFloat2D: + """VJP for one of the two horizontal faces of a ``PolySlab``. + + The face is discretized into a Cartesian grid of small planar patches. + The adjoint surface integral is evaluated on every retained patch; the + resulting derivative is split equally between the two vertices that bound + the edge segment. 
+ """ + # rmin/rmax over the geometry and simulation box + if np.isclose(self.slab_bounds[1] - self.slab_bounds[0], 0.0): + log.warning( + "Computing slab face derivatives for flat structures is not fully supported and " + "may give zero for the derivative. Try using a structure with a small, but nonzero " + "thickness for slab bound derivatives." + ) + rmin, rmax = derivative_info.bounds_intersect + _, (r1_min, r2_min) = self.pop_axis(rmin, axis=self.axis) + _, (r1_max, r2_max) = self.pop_axis(rmax, axis=self.axis) + ax_val = self.slab_bounds[min_max_index] + + # planar grid resolution, clipped to polygon bounding box + face_verts = self.base_polygon if min_max_index == 0 else self.top_polygon + face_poly = shapely.Polygon(face_verts).buffer(fp_eps) + + # limit the patch grid to the face that lives inside the simulation box + poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds + r1_min = max(r1_min, poly_min_r1) + r1_max = min(r1_max, poly_max_r1) + r2_min = max(r2_min, poly_min_r2) + r2_max = min(r2_max, poly_max_r2) + + # intersect the polygon with the simulation bounds + face_poly = face_poly.intersection(shapely.box(r1_min, r2_min, r1_max, r2_max)) + + if (r1_max <= r1_min) and (r2_max <= r2_min): + # the polygon does not intersect the current simulation slice + return 0.0 + + # re-compute the extents after clipping to the polygon bounds + extents = np.array([r1_max - r1_min, r2_max - r2_min]) + + # choose surface or line integral + integral_fun = ( + self.compute_derivative_slab_bounds_line + if np.isclose(extents, 0).any() + else self.compute_derivative_slab_bounds_surface + ) + return integral_fun( + derivative_info, + extents, + r1_min, + r1_max, + r2_min, + r2_max, + ax_val, + face_poly, + min_max_index, + interpolators, + ) + + def compute_derivative_slab_bounds_line( + self, + derivative_info: DerivativeInfo, + extents: NDArray, + r1_min: float, + r1_max: float, + r2_min: float, + r2_max: float, + ax_val: float, + face_poly: 
shapely.Polygon, + min_max_index: int, + interpolators: dict, + ) -> float: + """Handle degenerate line cross-section case""" + line_dim = 1 if np.isclose(extents[0], 0) else 0 + + poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds + if line_dim == 0: # x varies, y is fixed + l_min = max(r1_min, poly_min_r1) + l_max = min(r1_max, poly_max_r1) + else: # y varies, x is fixed + l_min = max(r2_min, poly_min_r2) + l_max = min(r2_max, poly_max_r2) + + length = l_max - l_min + if np.isclose(length, 0): + return 0.0 + + dx = derivative_info.adaptive_vjp_spacing() + n_seg = max(1, int(np.ceil(length / dx))) + coords = np.linspace( + l_min, l_max, 2 * n_seg + 1, dtype=config.adjoint.gradient_dtype_float + )[1::2] + + # build XY coordinates and in-plane direction vectors + if line_dim == 0: + xy = np.column_stack((coords, np.full_like(coords, r2_min))) + dir_vec_plane = np.column_stack((np.ones_like(coords), np.zeros_like(coords))) + else: + xy = np.column_stack((np.full_like(coords, r1_min), coords)) + dir_vec_plane = np.column_stack((np.zeros_like(coords), np.ones_like(coords))) + + inside = shapely.contains_xy(face_poly, xy[:, 0], xy[:, 1]) + if not inside.any(): + return 0.0 + + xy = xy[inside] + dir_vec_plane = dir_vec_plane[inside] + n_pts = len(xy) + + centers_xyz = self.unpop_axis_vect(np.full(n_pts, ax_val), xy) + areas = np.full(n_pts, length / n_seg) # patch length + + normals_xyz = self.unpop_axis_vect( + np.full( + n_pts, -1 if min_max_index == 0 else 1, dtype=config.adjoint.gradient_dtype_float + ), + np.zeros_like(xy, dtype=config.adjoint.gradient_dtype_float), + ) + perps1_xyz = self.unpop_axis_vect(np.zeros(n_pts), dir_vec_plane) + perps2_xyz = self.unpop_axis_vect(np.zeros(n_pts), np.zeros_like(dir_vec_plane)) + + vjps = derivative_info.evaluate_gradient_at_points( + centers_xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators + ) + return np.real(np.sum(vjps * areas)).item() + + def compute_derivative_slab_bounds_surface( + self, + 
derivative_info: DerivativeInfo, + extents: NDArray, + r1_min: float, + r1_max: float, + r2_min: float, + r2_max: float, + ax_val: float, + face_poly: shapely.Polygon, + min_max_index: int, + interpolators: dict, + ) -> float: + """2d surface integral on a Gauss quadrature grid""" + dx = derivative_info.adaptive_vjp_spacing() + + # uniform grid would use n1 x n2 points + n1_uniform, n2_uniform = np.maximum(1, np.ceil(extents / dx).astype(int)) + + # use ~1/2 Gauss points in each direction for similar accuracy + n1 = max(2, n1_uniform // 2) + n2 = max(2, n2_uniform // 2) + + g1, w1 = leggauss(n1) + g2, w2 = leggauss(n2) + + coords1 = (0.5 * (r1_max - r1_min) * g1 + 0.5 * (r1_max + r1_min)).astype( + config.adjoint.gradient_dtype_float, copy=False + ) + coords2 = (0.5 * (r2_max - r2_min) * g2 + 0.5 * (r2_max + r2_min)).astype( + config.adjoint.gradient_dtype_float, copy=False + ) + + r1_grid, r2_grid = np.meshgrid(coords1, coords2, indexing="ij") + r1_flat = r1_grid.flatten() + r2_flat = r2_grid.flatten() + pts = np.column_stack((r1_flat, r2_flat)) + + in_face = shapely.contains_xy(face_poly, pts[:, 0], pts[:, 1]) + if not in_face.any(): + return 0.0 + + xyz = self.unpop_axis_vect( + np.full(in_face.sum(), ax_val, dtype=config.adjoint.gradient_dtype_float), pts[in_face] + ) + n_patches = xyz.shape[0] + + normals_xyz = self.unpop_axis_vect( + np.full( + n_patches, + -1 if min_max_index == 0 else 1, + dtype=config.adjoint.gradient_dtype_float, + ), + np.zeros((n_patches, 2), dtype=config.adjoint.gradient_dtype_float), + ) + perps1_xyz = self.unpop_axis_vect( + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.column_stack( + ( + np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + ) + ), + ) + perps2_xyz = self.unpop_axis_vect( + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.column_stack( + ( + np.zeros(n_patches, 
dtype=config.adjoint.gradient_dtype_float), + np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), + ) + ), + ) + + w1_grid, w2_grid = np.meshgrid(w1, w2, indexing="ij") + weights_flat = (w1_grid * w2_grid).flatten()[in_face] + jacobian = 0.25 * (r1_max - r1_min) * (r2_max - r2_min) + + # area-based correction for non-rectangular domains (e.g. concave polygon) + # for constant integrand, integral should equal polygon area + sum_weights = np.sum(weights_flat) + if sum_weights > 0: + area_correction = face_poly.area / (sum_weights * jacobian) + weights_flat = weights_flat * area_correction + + vjps = derivative_info.evaluate_gradient_at_points( + xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators + ) + return np.real(np.sum(vjps * weights_flat * jacobian)).item() + + def _compute_derivative_vertices( + self, + derivative_info: DerivativeInfo, + sim_min: NDArray, + sim_max: NDArray, + is_2d: bool = False, + interpolators: Optional[dict] = None, + ) -> NDArray: + """VJP for the vertices of a ``PolySlab``. + + Uses shared sidewall patch collection and batched field evaluation. 
+ """ + vertices, next_v, edges, basis = self._edge_geometry_arrays() + dx = derivative_info.adaptive_vjp_spacing() + + # collect patches once + patch = self._collect_sidewall_patches( + vertices=vertices, + next_v=next_v, + edges=edges, + basis=basis, + sim_min=sim_min, + sim_max=sim_max, + is_2d=is_2d, + dx=dx, + ) + + # early return if no patches + if patch["centers"].shape[0] == 0: + return np.zeros_like(vertices) + + dz = patch["dz"] + dz_surf = 1.0 if is_2d else dz / np.cos(self.sidewall_angle) + + # use provided interpolators or create them if not provided + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + # evaluate integrand + g = derivative_info.evaluate_gradient_at_points( + patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators + ) + + # compute area-based weights and weighted vjps + areas = patch["Ls"] * patch["s_weights"] * dz_surf + patch_vjps = (g * areas).real + + # distribute to vertices using vectorized accumulation + normals_2d = np.delete(basis["norm"], self.axis, axis=1) + edge_idx = patch["edge_indices"] + s = patch["s_vals"] + w0 = (1.0 - s) * patch_vjps + w1 = s * patch_vjps + edge_norms = normals_2d[edge_idx] + + # Accumulate per-vertex contributions using bincount (O(N_patches)) + num_vertices = vertices.shape[0] + contrib0 = w0[:, None] * edge_norms # (n_patches, 2) + contrib1 = w1[:, None] * edge_norms # (n_patches, 2) + + idx0 = edge_idx + idx1 = (edge_idx + 1) % num_vertices + + v0x = np.bincount(idx0, weights=contrib0[:, 0], minlength=num_vertices) + v0y = np.bincount(idx0, weights=contrib0[:, 1], minlength=num_vertices) + v1x = np.bincount(idx1, weights=contrib1[:, 0], minlength=num_vertices) + v1y = np.bincount(idx1, weights=contrib1[:, 1], minlength=num_vertices) + + vjp_per_vertex = np.stack((v0x + v1x, v0y + v1y), axis=1) + return vjp_per_vertex + + def _edge_geometry_arrays( + self, dtype: np.dtype = 
config.adjoint.gradient_dtype_float + ) -> tuple[NDArray, NDArray, NDArray, dict[str, NDArray]]: + """Return (vertices, next_v, edges, basis) arrays for sidewall edge geometry.""" + vertices = np.asarray(self.vertices, dtype=dtype) + next_v = np.roll(vertices, -1, axis=0) + edges = next_v - vertices + basis = self.edge_basis_vectors(edges) + return vertices, next_v, edges, basis + + def edge_basis_vectors( + self, + edges: NDArray, # (N, 2) + ) -> dict[str, NDArray]: # (N, 3) + """Normalized basis vectors for ``normal`` direction, ``slab`` tangent direction and ``edge``.""" + + # ensure edges have consistent dtype + edges = edges.astype(config.adjoint.gradient_dtype_float, copy=False) + + num_vertices, _ = edges.shape + zeros = np.zeros(num_vertices, dtype=config.adjoint.gradient_dtype_float) + ones = np.ones(num_vertices, dtype=config.adjoint.gradient_dtype_float) + + # normalized vectors along edges + edges_norm_in_plane = self.normalize_vect(edges) + edges_norm_xyz = self.unpop_axis_vect(zeros, edges_norm_in_plane) + + # normalized vectors from base of edges to tops of edges + cos_angle = np.cos(self.sidewall_angle) + sin_angle = np.sin(self.sidewall_angle) + slabs_axis_components = cos_angle * ones + + # create axis_norm as array directly to avoid tuple->array conversion in np.cross + axis_norm = np.zeros(3, dtype=config.adjoint.gradient_dtype_float) + axis_norm[self.axis] = 1.0 + slab_normal_xyz = -sin_angle * np.cross(edges_norm_xyz, axis_norm) + _, slab_normal_in_plane = self.pop_axis_vect(slab_normal_xyz) + slabs_norm_xyz = self.unpop_axis_vect(slabs_axis_components, slab_normal_in_plane) + + # normalized vectors pointing in normal direction of edge + # cross yields inward normal when the extrusion axis is y, so negate once for axis==1 + sign = (-1 if self.axis == 1 else 1) * (-1 if not self.is_ccw else 1) + normals_norm_xyz = sign * np.cross(edges_norm_xyz, slabs_norm_xyz) + + return { + "norm": normals_norm_xyz, + "perp1": edges_norm_xyz, + "perp2": 
slabs_norm_xyz, + } + + def unpop_axis_vect(self, ax_coords: NDArray, plane_coords: NDArray) -> NDArray: + """Combine coordinate along axis with coordinates on the plane tangent to the axis. + + ax_coords.shape == [N] + plane_coords.shape == [N, 2] + return shape == [N, 3] + """ + n_pts = ax_coords.shape[0] + arr_xyz = np.zeros((n_pts, 3), dtype=ax_coords.dtype) + + plane_axes = [i for i in range(3) if i != self.axis] + + arr_xyz[:, self.axis] = ax_coords + arr_xyz[:, plane_axes] = plane_coords + + return arr_xyz + + def pop_axis_vect(self, coord: NDArray) -> tuple[NDArray, tuple[NDArray, NDArray]]: + """Combine coordinate along axis with coordinates on the plane tangent to the axis. + + coord.shape == [N, 3] + return shape == ([N], [N, 2] + """ + + arr_axis, arrs_plane = self.pop_axis(coord.T, axis=self.axis) + arrs_plane = np.array(arrs_plane).T + + return arr_axis, arrs_plane + + @staticmethod + def normalize_vect(arr: NDArray) -> NDArray: + """normalize an array shaped (N, d) along the `d` axis and return (N, 1).""" + norm = np.linalg.norm(arr, axis=-1, keepdims=True) + norm = np.where(norm == 0, 1, norm) + return arr / norm + + def translated(self, x: float, y: float, z: float) -> PolySlab: + """Return a translated copy of this geometry. + + Parameters + ---------- + x : float + Translation along x. + y : float + Translation along y. + z : float + Translation along z. + + Returns + ------- + :class:`PolySlab` + Translated copy of this ``PolySlab``. + """ + + t_normal, t_plane = self.pop_axis((x, y, z), axis=self.axis) + translated_vertices = np.array(self.vertices) + np.array(t_plane)[None, :] + translated_slab_bounds = (self.slab_bounds[0] + t_normal, self.slab_bounds[1] + t_normal) + return self.updated_copy(vertices=translated_vertices, slab_bounds=translated_slab_bounds) + + def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> PolySlab: + """Return a scaled copy of this geometry. 
+ + Parameters + ---------- + x : float = 1.0 + Scaling factor along x. + y : float = 1.0 + Scaling factor along y. + z : float = 1.0 + Scaling factor along z. + + Returns + ------- + :class:`Geometry` + Scaled copy of this geometry. + """ + scale_normal, scale_in_plane = self.pop_axis((x, y, z), axis=self.axis) + scaled_vertices = self.vertices * np.array(scale_in_plane) + scaled_slab_bounds = tuple(scale_normal * bound for bound in self.slab_bounds) + return self.updated_copy(vertices=scaled_vertices, slab_bounds=scaled_slab_bounds) + + def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> PolySlab: + """Return a rotated copy of this geometry. + + Parameters + ---------- + angle : float + Rotation angle (in radians). + axis : Union[int, tuple[float, float, float]] + Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. + + Returns + ------- + :class:`PolySlab` + Rotated copy of this ``PolySlab``. + """ + _, plane_axs = self.pop_axis([0, 1, 2], self.axis) + if (isinstance(axis, int) and axis == self.axis) or ( + isinstance(axis, tuple) and all(axis[ax] == 0 for ax in plane_axs) + ): + verts_3d = np.zeros((3, self.vertices.shape[0])) + verts_3d[plane_axs[0], :] = self.vertices[:, 0] + verts_3d[plane_axs[1], :] = self.vertices[:, 1] + rotation = RotationAroundAxis(angle=angle, axis=axis) + rotated_vertices = rotation.rotate_vector(verts_3d) + rotated_vertices = rotated_vertices[plane_axs, :].T + return self.updated_copy(vertices=rotated_vertices) + + return super().rotated(angle=angle, axis=axis) + + def reflected(self, normal: Coordinate) -> PolySlab: + """Return a reflected copy of this geometry. + + Parameters + ---------- + normal : tuple[float, float, float] + The 3D normal vector of the plane of reflection. The plane is assumed + to pass through the origin (0,0,0). + + Returns + ------- + ------- + :class:`PolySlab` + Reflected copy of this ``PolySlab``. 
+ """ + if math.isclose(normal[self.axis], 0): + _, plane_axs = self.pop_axis((0, 1, 2), self.axis) + verts_3d = np.zeros((3, self.vertices.shape[0])) + verts_3d[plane_axs[0], :] = self.vertices[:, 0] + verts_3d[plane_axs[1], :] = self.vertices[:, 1] + reflection = ReflectionFromPlane(normal=normal) + reflected_vertices = reflection.reflect_vector(verts_3d) + reflected_vertices = reflected_vertices[plane_axs, :].T + return self.updated_copy(vertices=reflected_vertices) + + return super().reflected(normal=normal) + + +class ComplexPolySlabBase(PolySlab): + """Interface for dividing a complex polyslab where self-intersecting polygon can + occur during extrusion. This class should not be used directly. Use instead + :class:`plugins.polyslab.ComplexPolySlab`.""" + + @model_validator(mode="after") + def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self: + """Turn off the validation for this class.""" + return self + + @classmethod + def from_gds( + cls, + gds_cell: Cell, + axis: Axis, + slab_bounds: tuple[float, float], + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> list[PolySlab]: + """Import :class:`.PolySlab` from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + axis : int + Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). + slab_bounds: tuple[float, float] + Minimum and maximum positions of the slab along ``axis``. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. + If ``None``, imports all data for this layer into the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of MICROMETER. + For example, if gds file uses nanometers, set ``gds_scale=1e-3``. + Must be positive. 
+ dilation : float = 0.0 + Dilation of the polygon in the base by shifting each edge along its + normal outwards direction by a distance; + a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the sidewall. + ``sidewall_angle=0`` (default) specifies vertical wall, + while ``0 base.GeometryGroup: + """Divide a complex polyslab into a list of simple polyslabs, which + are assembled into a :class:`.GeometryGroup`. + + Returns + ------- + :class:`.GeometryGroup` + GeometryGroup for a list of simple polyslabs divided from the complex + polyslab. + """ + return base.GeometryGroup(geometries=self.sub_polyslabs) + + @property + def sub_polyslabs(self) -> list[PolySlab]: + """Divide a complex polyslab into a list of simple polyslabs. + Only neighboring vertex-vertex crossing events are treated in this + version. + + Returns + ------- + list[PolySlab] + A list of simple polyslabs. + """ + sub_polyslab_list = [] + num_division_count = 0 + # initialize sub-polyslab parameters + sub_polyslab_dict = self.model_dump(exclude={"type"}).copy() + if math.isclose(self.sidewall_angle, 0): + return [PolySlab.model_validate(sub_polyslab_dict)] + + sub_polyslab_dict.update({"dilation": 0}) # dilation accounted in setup + # initialize offset distance + offset_distance = 0 + + for dist_val in self._dilation_length: + dist_now = 0.0 + vertices_now = self.reference_polygon + + # constructing sub-polyslabs until reaching the base/top + while not math.isclose(dist_now, dist_val): + # bounds for sub-polyslabs assuming no self-intersection + slab_bounds = [ + self._dilation_value_at_reference_to_coord(dist_now), + self._dilation_value_at_reference_to_coord(dist_val), + ] + # 1) find out any vertices touching events between the current + # position to the base/top + max_dist = PolySlab._neighbor_vertices_crossing_detection( + vertices_now, dist_val - dist_now + ) + + # vertices touching events captured, update bounds for sub-polyslab + if max_dist is not None: + # 
max_dist doesn't have sign, so construct signed offset distance + offset_distance = max_dist * dist_val / abs(dist_val) + slab_bounds[1] = self._dilation_value_at_reference_to_coord( + dist_now + offset_distance + ) + + # 2) construct sub-polyslab + slab_bounds.sort() # for reference_plane=top/bottom, bounds need to be ordered + # direction of marching + reference_plane = "bottom" if dist_val / self._tanq < 0 else "top" + sub_polyslab_dict.update( + { + "slab_bounds": tuple(slab_bounds), + "vertices": vertices_now, + "reference_plane": reference_plane, + } + ) + sub_polyslab_list.append(PolySlab.model_validate(sub_polyslab_dict)) + + # Now Step 3 + if max_dist is None: + break + dist_now += offset_distance + # new polygon vertices where collapsing vertices are removed but keep one + vertices_now = PolySlab._shift_vertices(vertices_now, offset_distance)[0] + vertices_now = PolySlab._remove_duplicate_vertices(vertices_now) + # all vertices collapse + if len(vertices_now) < 3: + break + # polygon collapse into 1D + if self.make_shapely_polygon(vertices_now).buffer(0).area < fp_eps: + break + vertices_now = PolySlab._orient(vertices_now) + num_division_count += 1 + + if num_division_count > _COMPLEX_POLYSLAB_DIVISIONS_WARN: + log.warning( + f"Too many self-intersecting events: the polyslab has been divided into " + f"{num_division_count} polyslabs; more than {_COMPLEX_POLYSLAB_DIVISIONS_WARN} may " + f"slow down the simulation." 
+ ) + + return sub_polyslab_list + + @property + def _dilation_length(self) -> list[float]: + """dilation length from reference plane to the top/bottom of the polyslab.""" + + # for "bottom", only needs to compute the offset length to the top + dist = [self._extrusion_length_to_offset_distance(self.finite_length_axis)] + # reverse the dilation value if the reference plane is on the top + if self.reference_plane == "top": + dist = [-dist[0]] + # for middle, both directions + elif self.reference_plane == "middle": + dist = [dist[0] / 2, -dist[0] / 2] + return dist + + def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: + """Compute the coordinate based on the dilation value to the reference plane.""" + + z_coord = -dilation / self._tanq + self.slab_bounds[0] + if self.reference_plane == "middle": + return z_coord + self.finite_length_axis / 2 + if self.reference_plane == "top": + return z_coord + self.finite_length_axis + # bottom case + return z_coord + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for PolySlab. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+ """ + return [ + shapely.unary_union( + [ + base.Geometry.evaluate_inf_shape(shape) + for polyslab in self.sub_polyslabs + for shape in polyslab.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + ] + ) + ] diff --git a/tidy3d/_common/components/geometry/primitives.py b/tidy3d/_common/components/geometry/primitives.py new file mode 100644 index 0000000000..0b96d7be5c --- /dev/null +++ b/tidy3d/_common/components/geometry/primitives.py @@ -0,0 +1,1294 @@ +"""Concrete primitive geometrical objects.""" + +from __future__ import annotations + +from math import isclose +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +import numpy as np +import shapely +from pydantic import Field, PrivateAttr, model_validator + +from tidy3d._common.components.autograd import TracedSize1D, get_static +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.geometry import base +from tidy3d._common.components.geometry.mesh import TriangleMesh +from tidy3d._common.components.geometry.polyslab import PolySlab +from tidy3d._common.config import config +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER +from tidy3d._common.exceptions import SetupError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from typing import Optional + + from shapely.geometry.base import BaseGeometry + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import Axis, Bound, Coordinate, MatrixReal4x4, Shapely + +# for sampling conical frustum in visualization +_N_SAMPLE_CURVE_SHAPELY = 40 + +# for shapely circular shapes discretization in visualization +_N_SHAPELY_QUAD_SEGS_VISUALIZATION = 200 + +# Default number of points to discretize 
polyslab in `Cylinder.to_polyslab()` +_N_PTS_CYLINDER_POLYSLAB = 51 +_MAX_ICOSPHERE_SUBDIVISIONS = 7 # this would have 164K vertices and 328K faces +_DEFAULT_EDGE_FRACTION = 0.25 + + +def _base_icosahedron() -> tuple[np.ndarray, np.ndarray]: + """Return vertices and faces of a unit icosahedron.""" + + phi = (1.0 + np.sqrt(5.0)) / 2.0 + vertices = np.array( + [ + (-1, phi, 0), + (1, phi, 0), + (-1, -phi, 0), + (1, -phi, 0), + (0, -1, phi), + (0, 1, phi), + (0, -1, -phi), + (0, 1, -phi), + (phi, 0, -1), + (phi, 0, 1), + (-phi, 0, -1), + (-phi, 0, 1), + ], + dtype=float, + ) + vertices /= np.linalg.norm(vertices, axis=1)[:, None] + faces = np.array( + [ + (0, 11, 5), + (0, 5, 1), + (0, 1, 7), + (0, 7, 10), + (0, 10, 11), + (1, 5, 9), + (5, 11, 4), + (11, 10, 2), + (10, 7, 6), + (7, 1, 8), + (3, 9, 4), + (3, 4, 2), + (3, 2, 6), + (3, 6, 8), + (3, 8, 9), + (4, 9, 5), + (2, 4, 11), + (6, 2, 10), + (8, 6, 7), + (9, 8, 1), + ], + dtype=int, + ) + return vertices, faces + + +_ICOSAHEDRON_VERTS, _ICOSAHEDRON_FACES = _base_icosahedron() + + +def discretization_wavelength(derivative_info: DerivativeInfo, geometry_label: str) -> float: + """Choose reference wavelength for surface discretization.""" + wvl0_min = derivative_info.wavelength_min + wvl_mat = wvl0_min / np.max([1.0, np.max(np.sqrt(abs(derivative_info.eps_in)))]) + + grid_cfg = config.adjoint + + min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min + if wvl_mat < min_wvl_mat: + log.warning( + f"The minimum wavelength inside the {geometry_label} material is {wvl_mat:.3e} μm, which would " + f"create a large number of discretization points for computing the gradient. " + f"To prevent performance degradation, the discretization wavelength has " + f"been clipped to {min_wvl_mat:.3e} μm.", + log_once=True, + ) + return max(wvl_mat, min_wvl_mat) + + +class Sphere(base.Centered, base.Circular): + """Spherical geometry. 
+ + Example + ------- + >>> b = Sphere(center=(1,2,3), radius=2) + """ + + radius: TracedSize1D = Field( + title="Radius", + description="Radius of geometry.", + json_schema_extra={"units": MICROMETER}, + ) + + _icosphere_cache: dict[int, tuple[np.ndarray, float]] = PrivateAttr(default_factory=dict) + + @verify_packages_import(["trimesh"]) + def to_triangle_mesh( + self, + *, + max_edge_length: Optional[float] = None, + subdivisions: Optional[int] = None, + ) -> TriangleMesh: + """Approximate the sphere surface with a ``TriangleMesh``. + + Parameters + ---------- + max_edge_length : float = None + Maximum edge length for triangulation in micrometers. + subdivisions : int = None + Number of subdivisions for icosphere generation. + + Returns + ------- + TriangleMesh + Triangle mesh approximation of the sphere surface. + """ + + triangles, _ = self._triangulated_surface( + max_edge_length=max_edge_length, subdivisions=subdivisions + ) + return TriangleMesh.from_triangles(triangles) + + def inside( + self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] + ) -> np.ndarray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
+ """ + self._ensure_equal_shape(x, y, z) + x0, y0, z0 = self.center + dist_x = np.abs(x - x0) + dist_y = np.abs(y - y0) + dist_z = np.abs(z - z0) + return (dist_x**2 + dist_y**2 + dist_z**2) <= (self.radius**2) + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+ """ + if quad_segs is None: + quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION + + normal = np.array(normal) + unit_normal = normal / (np.sum(normal**2) ** 0.5) + projection = np.dot(np.array(origin) - np.array(self.center), unit_normal) + if abs(projection) >= self.radius: + return [] + + radius = (self.radius**2 - projection**2) ** 0.5 + center = np.array(self.center) + projection * unit_normal + + v = np.zeros(3) + v[np.argmin(np.abs(unit_normal))] = 1 + u = np.cross(unit_normal, v) + u /= np.sum(u**2) ** 0.5 + v = np.cross(unit_normal, u) + + angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1)[:-1] + circ = center + np.outer(np.cos(angles), radius * u) + np.outer(np.sin(angles), radius * v) + vertices = np.dot(np.hstack((circ, np.ones((angles.size, 1)))), to_2D.T) + return [shapely.Polygon(vertices[:, :2])] + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[BaseGeometry]: + """Returns shapely geometry at plane specified by one non None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation ``. 
+ """ + if quad_segs is None: + quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION + + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + if not self.intersects_axis_position(axis, position): + return [] + z0, (x0, y0) = self.pop_axis(self.center, axis=axis) + intersect_dist = self._intersect_dist(position, z0) + if not intersect_dist: + return [] + return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)] + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + Tuple[float, float, float], Tuple[float, float, float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + coord_min = tuple(c - self.radius for c in self.center) + coord_max = tuple(c + self.radius for c in self.center) + return (coord_min, coord_max) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + volume = 4.0 / 3.0 * np.pi * self.radius**3 + + # a very loose upper bound on how much of sphere is in bounds + for axis in range(3): + if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: + volume *= 0.5 + + return volume + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + area = 4.0 * np.pi * self.radius**2 + + # a very loose upper bound on how much of sphere is in bounds + for axis in range(3): + if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: + area *= 0.5 + + return area + + @classmethod + def unit_sphere_triangles( + cls, + *, + target_edge_length: Optional[float] = None, + subdivisions: Optional[int] = None, + ) -> np.ndarray: + """Return unit sphere triangles discretized via an icosphere.""" + + unit_tris = UNIT_SPHERE._unit_sphere_triangles( + target_edge_length=target_edge_length, + subdivisions=subdivisions, + copy_result=True, + ) + return unit_tris + + def _compute_derivatives(self, derivative_info: 
DerivativeInfo) -> AutogradFieldMap:
+        """Compute adjoint derivatives using smooth sphere surface samples."""
+        valid_paths = {("radius",), *{("center", i) for i in range(3)}}
+        for path in derivative_info.paths:
+            if path not in valid_paths:
+                raise ValueError(
+                    f"No derivative defined w.r.t. 'Sphere' field '{path}'. "
+                    "Supported fields are 'radius' and 'center'."
+                )
+
+        if not derivative_info.paths:
+            return {}
+
+        grid_cfg = config.adjoint
+        radius = float(get_static(self.radius))
+        if radius == 0.0:
+            log.warning(
+                "Sphere gradients cannot be computed for zero radius; gradients are zero.",
+                log_once=True,
+            )
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        wvl_mat = discretization_wavelength(derivative_info, "sphere")
+        target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps)
+        triangles, _ = self._triangulated_surface(max_edge_length=target_edge)
+        triangles = triangles.astype(grid_cfg.gradient_dtype_float, copy=False)
+
+        sim_min, sim_max = (
+            np.asarray(arr, dtype=grid_cfg.gradient_dtype_float)
+            for arr in derivative_info.simulation_bounds
+        )
+        tol = config.adjoint.edge_clip_tolerance
+
+        sim_extents = sim_max - sim_min
+        collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol))
+        if collapsed_indices.size:
+            if collapsed_indices.size > 1:
+                return dict.fromkeys(derivative_info.paths, 0.0)
+            axis_idx = int(collapsed_indices[0])
+            plane_value = float(sim_min[axis_idx])
+            return self._compute_derivatives_collapsed_axis(
+                derivative_info=derivative_info,
+                axis_idx=axis_idx,
+                plane_value=plane_value,
+            )
+
+        trimesh_obj = TriangleMesh._triangles_to_trimesh(triangles)
+        vertices = np.asarray(trimesh_obj.vertices, dtype=grid_cfg.gradient_dtype_float)
+        center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float)
+        verts_centered = vertices - center
+        norms = np.linalg.norm(verts_centered, axis=1, keepdims=True)
+        norms = np.where(norms == 0, 1, norms)
+        normals = verts_centered / norms
+
+        if vertices.size == 0:
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        # vertex quadrature weights: each face lumps one third of its area onto each corner
+        faces = np.asarray(trimesh_obj.faces, dtype=int)
+        face_areas = np.asarray(trimesh_obj.area_faces, dtype=grid_cfg.gradient_dtype_float)
+        weights = np.zeros(len(vertices), dtype=grid_cfg.gradient_dtype_float)
+        np.add.at(weights, faces[:, 0], face_areas / 3.0)
+        np.add.at(weights, faces[:, 1], face_areas / 3.0)
+        np.add.at(weights, faces[:, 2], face_areas / 3.0)
+
+        perp1, perp2 = self._tangent_basis_from_normals(normals)
+
+        valid_axes = np.abs(sim_max - sim_min) > tol
+        inside_mask = np.all(
+            vertices[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1
+        ) & np.all(vertices[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1)
+
+        if not np.any(inside_mask):
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        points = vertices[inside_mask]
+        normals_sel = normals[inside_mask]
+        perp1_sel = perp1[inside_mask]
+        perp2_sel = perp2[inside_mask]
+        weights_sel = weights[inside_mask]
+
+        interpolators = derivative_info.interpolators
+        if interpolators is None:
+            interpolators = derivative_info.create_interpolators(
+                dtype=grid_cfg.gradient_dtype_float
+            )
+
+        g = derivative_info.evaluate_gradient_at_points(
+            points,
+            normals_sel,
+            perp1_sel,
+            perp2_sel,
+            interpolators,
+        )
+
+        weighted = (weights_sel * g).real
+        grad_center = np.sum(weighted[:, None] * normals_sel, axis=0)
+        grad_radius = np.sum(weighted)
+
+        vjps: AutogradFieldMap = {}
+        for path in derivative_info.paths:
+            if path == ("radius",):
+                vjps[path] = float(grad_radius)
+            else:
+                _, idx = path
+                vjps[path] = float(grad_center[idx])
+
+        return vjps
+
+    def _compute_derivatives_collapsed_axis(
+        self,
+        derivative_info: DerivativeInfo,
+        axis_idx: int,
+        plane_value: float,
+    ) -> AutogradFieldMap:
+        """Delegate collapsed-axis gradients to a Cylinder cross section."""
+        tol = config.adjoint.edge_clip_tolerance
+        radius = float(self.radius)
+        center = np.asarray(self.center, dtype=float)
+        delta = 
plane_value - center[axis_idx]
+        radius_sq = radius**2 - delta**2
+        if radius_sq <= tol**2:
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        radius_plane = float(np.sqrt(max(radius_sq, 0.0)))
+        if radius_plane <= tol:
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        cyl_paths: set[tuple[str, int | None]] = set()
+        need_radius = False
+        for path in derivative_info.paths:
+            if path == ("radius",) or path == ("center", axis_idx):
+                cyl_paths.add(("radius",))
+                need_radius = True
+            elif path[0] == "center" and path[1] != axis_idx:
+                cyl_paths.add(("center", path[1]))
+
+        if not cyl_paths:
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        cyl_center = center.copy()
+        cyl_center[axis_idx] = plane_value
+        cylinder = Cylinder(
+            center=tuple(cyl_center),
+            radius=radius_plane,
+            length=discretization_wavelength(derivative_info, "sphere") * 2.0,
+            axis=axis_idx,
+        )
+
+        bounds_min = list(cyl_center)
+        bounds_max = list(cyl_center)
+        for dim in range(3):
+            if dim == axis_idx:
+                continue
+            bounds_min[dim] = center[dim] - radius_plane
+            bounds_max[dim] = center[dim] + radius_plane
+
+        bounds = (tuple(bounds_min), tuple(bounds_max))
+        sim_min_arr, sim_max_arr = (
+            np.asarray(arr, dtype=float) for arr in derivative_info.simulation_bounds
+        )
+        intersect_min = tuple(max(bounds[0][i], sim_min_arr[i]) for i in range(3))
+        intersect_max = tuple(min(bounds[1][i], sim_max_arr[i]) for i in range(3))
+        if any(lo > hi for lo, hi in zip(intersect_min, intersect_max)):
+            return dict.fromkeys(derivative_info.paths, 0.0)
+
+        derivative_info_cyl = derivative_info.updated_copy(
+            paths=list(cyl_paths),
+            bounds=bounds,
+            bounds_intersect=(intersect_min, intersect_max),
+        )
+
+        vjps_cyl = cylinder._compute_derivatives(derivative_info_cyl)
+        result = dict.fromkeys(derivative_info.paths, 0.0)
+        vjp_radius = float(vjps_cyl.get(("radius",), 0.0)) if need_radius else 0.0
+
+        for path in derivative_info.paths:
+            if path == ("radius",):
+                result[path] = vjp_radius * (radius / radius_plane)
+            elif path == ("center", axis_idx):
+                result[path] = vjp_radius * (delta / radius_plane)
+            elif path[0] == "center" and path[1] != axis_idx:
+                result[path] = float(vjps_cyl.get(("center", path[1]), 0.0))
+
+        return result
+
+    def _edge_length_on_unit_sphere(
+        self, max_edge_length: Optional[float] = _DEFAULT_EDGE_FRACTION
+    ) -> Optional[float]:
+        """Convert ``max_edge_length`` (μm) to unit-sphere units (edge / radius); ``None`` if radius is not positive."""
+        max_edge_length = _DEFAULT_EDGE_FRACTION if max_edge_length is None else max_edge_length
+        radius = float(self.radius)
+        if radius <= 0.0:
+            return None
+        return max_edge_length / radius
+
+    def _triangulated_surface(
+        self,
+        *,
+        max_edge_length: Optional[float] = None,
+        subdivisions: Optional[int] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Return physical and unit triangles for the surface discretization. Pass either max_edge_length or subdivisions."""
+        max_edge_length_unit = None
+        if subdivisions is None:
+            max_edge_length_unit = self._edge_length_on_unit_sphere(max_edge_length)
+
+        unit_tris = self._unit_sphere_triangles(
+            target_edge_length=max_edge_length_unit,
+            subdivisions=subdivisions,
+            copy_result=False,
+        )
+
+        radius = float(get_static(self.radius))
+        center = np.asarray(self.center, dtype=float)
+        dtype = config.adjoint.gradient_dtype_float
+
+        physical = radius * unit_tris + center
+        return physical.astype(dtype, copy=False), unit_tris.astype(dtype, copy=False)
+
+    def _unit_sphere_triangles(
+        self,
+        *,
+        target_edge_length: Optional[float] = None,
+        subdivisions: Optional[int] = None,
+        copy_result: bool = True,
+    ) -> np.ndarray:
+        """Return cached unit-sphere triangles with optional copying. 
Pass either target_edge_length or subdivisions."""
+        if target_edge_length is not None and subdivisions is not None:
+            raise ValueError("Specify either target_edge_length OR subdivisions, not both.")
+
+        if subdivisions is None:
+            subdivisions = self._subdivisions_for_edge(target_edge_length)
+
+        triangles, _ = self._icosphere_data(subdivisions)
+        return np.array(triangles, copy=copy_result)
+
+    def _subdivisions_for_edge(self, target_edge_length: Optional[float]) -> int:
+        if target_edge_length is None or target_edge_length <= 0.0:
+            return 0
+
+        for subdiv in range(_MAX_ICOSPHERE_SUBDIVISIONS + 1):
+            _, max_edge = self._icosphere_data(subdiv)
+            if max_edge <= target_edge_length:
+                return subdiv
+
+        log.warning(
+            f"Requested sphere mesh edge length {target_edge_length:.3e} μm requires more than "
+            f"{_MAX_ICOSPHERE_SUBDIVISIONS} subdivisions. "
+            "Clipping to the finest available mesh.",
+            log_once=True,
+        )
+        return _MAX_ICOSPHERE_SUBDIVISIONS
+
+    @staticmethod
+    def _tangent_basis_from_normals(normals: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """Construct per-row orthonormal tangential bases (basis1, basis2) for each normal vector, vectorized."""
+
+        dtype = normals.dtype
+        tol = np.finfo(dtype).eps
+
+        # Normalize normals (in case they are not perfectly unit length).
+        n_norm = np.linalg.norm(normals, axis=1)
+        n = normals / np.maximum(n_norm, tol)[:, None]
+
+        # Pick a reference axis least aligned with each normal: argmin(|nx|,|ny|,|nz|).
+        ref_idx = np.argmin(np.abs(n), axis=1)
+        ref = np.zeros_like(n)
+        ref[np.arange(n.shape[0]), ref_idx] = 1.0
+
+        basis1 = np.cross(n, ref)
+        b1_norm = np.linalg.norm(basis1, axis=1)
+        basis1 = basis1 / np.maximum(b1_norm, tol)[:, None]
+
+        basis2 = np.cross(n, basis1)
+        b2_norm = np.linalg.norm(basis2, axis=1)
+        basis2 = basis2 / np.maximum(b2_norm, tol)[:, None]
+
+        return basis1, basis2
+
+    def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]:
+        cache = self._icosphere_cache
+        if subdivisions in cache:
+            return cache[subdivisions]
+
+        vertices = np.asarray(_ICOSAHEDRON_VERTS, dtype=float)
+        faces = np.asarray(_ICOSAHEDRON_FACES, dtype=int)
+        if subdivisions > 0:
+            vertices = vertices.copy()
+            faces = faces.copy()
+            for _ in range(subdivisions):
+                vertices, faces = TriangleMesh.subdivide_faces(vertices, faces)
+
+        norms = np.linalg.norm(vertices, axis=1, keepdims=True)
+        norms = np.where(norms == 0.0, 1.0, norms)
+        vertices = vertices / norms
+
+        triangles = vertices[faces]
+        max_edge = self._max_edge_length(triangles)
+        cache[subdivisions] = (triangles, max_edge)
+        return triangles, max_edge
+
+    @staticmethod
+    def _max_edge_length(triangles: np.ndarray) -> float:
+        v = triangles
+        edges = np.stack(
+            [
+                v[:, 1] - v[:, 0],
+                v[:, 2] - v[:, 1],
+                v[:, 0] - v[:, 2],
+            ],
+            axis=1,
+        )
+        return float(np.linalg.norm(edges, axis=2).max())
+
+
+UNIT_SPHERE = Sphere(center=(0.0, 0.0, 0.0), radius=1.0)
+
+
+class Cylinder(base.Centered, base.Circular, base.Planar):
+    """Cylindrical geometry with optional sidewall angle along axis
+    direction. When ``sidewall_angle`` is nonzero, the shape is a
+    conical frustum or a cone.
+ + Example + ------- + >>> c = Cylinder(center=(1,2,3), radius=2, length=5, axis=2) + + See Also + -------- + + **Notebooks** + + * `THz integrated demultiplexer/filter based on a ring resonator <../../../notebooks/THzDemultiplexerFilter.html>`_ + * `Photonic crystal waveguide polarization filter <../../../notebooks/PhotonicCrystalWaveguidePolarizationFilter.html>`_ + """ + + # Provide more explanations on where radius is defined + radius: TracedSize1D = Field( + title="Radius", + description="Radius of geometry at the ``reference_plane``.", + json_schema_extra={"units": MICROMETER}, + ) + + length: TracedSize1D = Field( + title="Length", + description="Defines thickness of cylinder along axis dimension.", + json_schema_extra={"units": MICROMETER}, + ) + + @model_validator(mode="after") + def _only_middle_for_infinite_length_slanted_cylinder(self: Self) -> Self: + """For a slanted cylinder of infinite length, ``reference_plane`` can only + be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0. + """ + if isclose(self.sidewall_angle, 0) or not np.isinf(self.length): + return self + if self.reference_plane != "middle": + raise SetupError( + "For a slanted cylinder here is of infinite length, " + "defining the reference_plane other than 'middle' " + "leads to undefined cylinder behaviors near 'center'." + ) + return self + + def to_polyslab( + self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB, **kwargs: Any + ) -> PolySlab: + """Convert instance of ``Cylinder`` into a discretized version using ``PolySlab``. + + Parameters + ---------- + num_pts_circumference : int = 51 + Number of points in the circumference of the discretized polyslab. + **kwargs: + Extra keyword arguments passed to ``PolySlab()``, such as ``dilation``. + + Returns + ------- + PolySlab + Extruded polygon representing a discretized version of the cylinder. 
+ """ + + center_axis = self.center_axis + length_axis = self.length_axis + slab_bounds = (center_axis - length_axis / 2.0, center_axis + length_axis / 2.0) + + if num_pts_circumference < 3: + raise ValueError("'PolySlab' from 'Cylinder' must have 3 or more radius points.") + + _, (x0, y0) = self.pop_axis(self.center, axis=self.axis) + + xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) + + xs = x0 + self.radius * xs_ + ys = y0 + self.radius * ys_ + + vertices = anp.stack((xs, ys), axis=-1) + + return PolySlab( + vertices=vertices, + axis=self.axis, + slab_bounds=slab_bounds, + sidewall_angle=self.sidewall_angle, + reference_plane=self.reference_plane, + **kwargs, + ) + + def _points_unit_circle( + self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB + ) -> np.ndarray: + """Set of x and y points for the unit circle when discretizing cylinder as a polyslab.""" + angles = np.linspace(0, 2 * np.pi, num_pts_circumference, endpoint=False) + xs = np.cos(angles) + ys = np.sin(angles) + return np.stack((xs, ys), axis=0) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + # compute circumference discretization + wvl_mat = discretization_wavelength(derivative_info, "cylinder") + + circumference = 2 * np.pi * self.radius + wvls_in_circumference = circumference / wvl_mat + + grid_cfg = config.adjoint + num_pts_circumference = int(np.ceil(grid_cfg.points_per_wavelength * wvls_in_circumference)) + num_pts_circumference = max(3, num_pts_circumference) + + # construct equivalent polyslab and compute the derivatives + polyslab = self.to_polyslab(num_pts_circumference=num_pts_circumference) + + # build PolySlab derivative paths based on requested Cylinder paths + ps_paths = set() + for path in derivative_info.paths: + if path == ("length",): + ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) + elif path == ("radius",): + 
ps_paths.add(("vertices",)) + elif "center" in path: + _, center_index = path + _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) + if center_index in (index_x, index_y): + ps_paths.add(("vertices",)) + else: + ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) + elif path == ("sidewall_angle",): + ps_paths.add(("sidewall_angle",)) + + # pass interpolators to PolySlab if available to avoid redundant conversions + update_kwargs = { + "paths": list(ps_paths), + "deep": False, + } + if derivative_info.interpolators is not None: + update_kwargs["interpolators"] = derivative_info.interpolators + + derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs) + vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab) + + vjps = {} + for path in derivative_info.paths: + if path == ("length",): + vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) + vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) + vjps[path] = vjp_top - vjp_bot + + elif path == ("radius",): + # transform polyslab vertices derivatives into radius derivative + xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T + vjp_xs = np.sum(xs_ * vjps_vertices_xs) + vjp_ys = np.sum(ys_ * vjps_vertices_ys) + vjps[path] = vjp_xs + vjp_ys + + elif "center" in path: + _, center_index = path + _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) + if center_index == index_x: + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_xs = vjps_polyslab[("vertices",)][:, 0] + vjps[path] = np.sum(vjps_vertices_xs) + elif center_index == index_y: + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_ys = vjps_polyslab[("vertices",)][:, 1] + vjps[path] = np.sum(vjps_vertices_ys) + else: + vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) + 
vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) + vjps[path] = vjp_top + vjp_bot + + elif path == ("sidewall_angle",): + # direct mapping: cylinder angle equals polyslab angle + vjps[path] = vjps_polyslab.get(("sidewall_angle",), 0.0) + + else: + raise NotImplementedError( + f"Differentiation with respect to 'Cylinder' '{path}' field not supported. " + "If you would like this feature added, please feel free to raise " + "an issue on the tidy3d front end repository." + ) + + return vjps + + @property + def center_axis(self) -> Any: + """Gets the position of the center of the geometry in the out of plane dimension.""" + z0, _ = self.pop_axis(self.center, axis=self.axis) + return z0 + + @property + def length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension.""" + return self.length + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if self.length != 0: + raise ValidationError("'Medium2D' requires the 'Cylinder' length to be zero.") + return self.axis + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Cylinder: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + if axis != self.axis: + raise ValueError( + f"'_update_from_bounds' may only be applied along axis '{self.axis}', " + f"but was given axis '{axis}'." + ) + new_center = list(self.center) + new_center[axis] = (bounds[0] + bounds[1]) / 2 + new_length = bounds[1] - bounds[0] + return self.updated_copy(center=tuple(new_center), length=new_length) + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. 
+
+        Parameters
+        ----------
+        normal : Coordinate
+            Vector defining the normal direction to the plane.
+        origin : Coordinate
+            Vector defining the plane origin.
+        to_2D : MatrixReal4x4
+            Transformation matrix to apply to resulting shapes.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+        import trimesh
+
+        if quad_segs is None:
+            quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION
+
+        z0, (x0, y0) = self.pop_axis(self.center, self.axis)
+        half_length = self.finite_length_axis / 2
+
+        z_top = z0 + half_length
+        z_bot = z0 - half_length
+
+        if np.isclose(self.sidewall_angle, 0):
+            r_top = self.radius
+            r_bot = self.radius
+        else:
+            r_top = self.radius_top
+            r_bot = self.radius_bottom
+            if r_top < 0 or np.isclose(r_top, 0):
+                r_top = 0
+                z_top = z0 + self._radius_z(z0) / self._tanq
+            elif r_bot < 0 or np.isclose(r_bot, 0):
+                r_bot = 0
+                z_bot = z0 + self._radius_z(z0) / self._tanq
+
+        angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1)
+
+        if r_bot > 0:
+            x_bot = x0 + r_bot * np.cos(angles)
+            y_bot = y0 + r_bot * np.sin(angles)
+            x_bot[-1] = x0
+            y_bot[-1] = y0
+        else:
+            x_bot = np.array([x0])
+            y_bot = np.array([y0])
+
+        if r_top > 0:
+            x_top = x0 + r_top * np.cos(angles)
+            y_top = y0 + r_top * np.sin(angles)
+            x_top[-1] = x0
+            y_top[-1] = y0
+        else:
+            x_top = np.array([x0])
+            y_top = np.array([y0])
+
+        x = np.hstack((x_bot, x_top))
+        y = np.hstack((y_bot, y_top))
+        z = np.hstack((np.full_like(x_bot, z_bot), np.full_like(x_top, z_top)))
+        vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T
+        # Triangulate: degenerate bottom cap (cone), degenerate top cap, or full frustum.
+        if x_bot.shape[0] == 1:
+            m = 1
+            n = x_top.shape[0] - 1
+            faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)]
+            faces_side = [(m + (i + 1) % n, m + i, 0) for i in range(n)]
+            faces = faces_top + faces_side
+        elif x_top.shape[0] == 1:
+            m = x_bot.shape[0]
+            n = m - 1
+            faces_bot = [(n, (i + 1) % n, i) for i in range(n)]
+            faces_side = [(i, (i + 1) % n, m) for i in range(n)]
+            faces = faces_bot + faces_side
+        else:
+            m = x_bot.shape[0]
+            n = m - 1
+            faces_bot = [(n, (i + 1) % n, i) for i in range(n)]
+            faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)]
+            faces_side_bot = [(i, (i + 1) % n, m + (i + 1) % n) for i in range(n)]
+            faces_side_top = [(m + (i + 1) % n, m + i, i) for i in range(n)]
+            faces = faces_bot + faces_top + faces_side_bot + faces_side_top
+
+        mesh = trimesh.Trimesh(vertices, faces)
+
+        section = mesh.section(plane_origin=origin, plane_normal=normal)
+        if section is None:
+            return []
+        path, _ = section.to_2D(to_2D=to_2D)
+        return path.polygons_full
+
+    def _intersections_normal(
+        self, z: float, quad_segs: Optional[int] = None
+    ) -> list[BaseGeometry]:
+        """Find shapely geometries intersecting cylindrical geometry with axis normal to slab.
+
+        Parameters
+        ----------
+        z : float
+            Position along the axis normal to slab
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+ """ + if quad_segs is None: + quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION + + static_self = self.to_static() + + # radius at z + radius_offset = static_self._radius_z(z) + + if radius_offset <= 0: + return [] + + _, (x0, y0) = self.pop_axis(static_self.center, axis=self.axis) + return [shapely.Point(x0, y0).buffer(radius_offset, quad_segs=quad_segs)] + + def _intersections_side(self, position: float, axis: int) -> list[BaseGeometry]: + """Find shapely geometries intersecting cylindrical geometry with axis orthogonal to length. + When ``sidewall_angle`` is nonzero, so that it's in fact a conical frustum or cone, the + cross section can contain hyperbolic curves. This is currently approximated by a polygon + of many vertices. + + Parameters + ---------- + position : float + Position along axis direction. + axis : int + Integer index into 'xyz' (0, 1, 2). + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + # position in the local coordinate of the cylinder + position_local = position - self.center[axis] + + # no intersection + if abs(position_local) >= self.radius_max: + return [] + + # half of intersection length at the top and bottom + intersect_half_length_max = np.sqrt(self.radius_max**2 - position_local**2) + intersect_half_length_min = -LARGE_NUMBER + if abs(position_local) < self.radius_min: + intersect_half_length_min = np.sqrt(self.radius_min**2 - position_local**2) + + # the vertices on the max side of top/bottom + # The two vertices are present in all scenarios. 
+ vertices_max = [ + self._local_to_global_side_cross_section([-intersect_half_length_max, 0], axis), + self._local_to_global_side_cross_section([intersect_half_length_max, 0], axis), + ] + + # Extending to a cone, the maximal height of the cone + h_cone = ( + LARGE_NUMBER if isclose(self.sidewall_angle, 0) else self.radius_max / abs(self._tanq) + ) + # The maximal height of the cross section + height_max = min( + (1 - abs(position_local) / self.radius_max) * h_cone, self.finite_length_axis + ) + + # more vertices to add for conical frustum shape + vertices_frustum_right = [] + vertices_frustum_left = [] + if not (isclose(position, self.center[axis]) or isclose(self.sidewall_angle, 0)): + # The y-coordinate for the additional vertices + y_list = height_max * np.linspace(0, 1, _N_SAMPLE_CURVE_SHAPELY) + # `abs()` to make sure np.sqrt(0-fp_eps) goes through + x_list = np.sqrt( + np.abs(self.radius_max**2 * (1 - y_list / h_cone) ** 2 - position_local**2) + ) + for i in range(_N_SAMPLE_CURVE_SHAPELY): + vertices_frustum_right.append( + self._local_to_global_side_cross_section([x_list[i], y_list[i]], axis) + ) + vertices_frustum_left.append( + self._local_to_global_side_cross_section( + [ + -x_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], + y_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], + ], + axis, + ) + ) + + # the vertices on the min side of top/bottom + vertices_min = [] + + ## termination at the top/bottom + if intersect_half_length_min > 0: + vertices_min.append( + self._local_to_global_side_cross_section( + [intersect_half_length_min, self.finite_length_axis], axis + ) + ) + vertices_min.append( + self._local_to_global_side_cross_section( + [-intersect_half_length_min, self.finite_length_axis], axis + ) + ) + ## early termination + else: + vertices_min.append(self._local_to_global_side_cross_section([0, height_max], axis)) + + return [ + shapely.Polygon( + vertices_max + vertices_frustum_right + vertices_min + vertices_frustum_left + ) + ] + + def inside( + self, x: 
np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] + ) -> np.ndarray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + # radius at z + self._ensure_equal_shape(x, y, z) + z0, (x0, y0) = self.pop_axis(self.center, axis=self.axis) + z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) + radius_offset = self._radius_z(z) + positive_radius = radius_offset > 0 + + dist_x = np.abs(x - x0) + dist_y = np.abs(y - y0) + dist_z = np.abs(z - z0) + inside_radius = (dist_x**2 + dist_y**2) <= (radius_offset**2) + inside_height = dist_z <= (self.finite_length_axis / 2) + return positive_radius * inside_radius * inside_height + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + Tuple[float, float, float], Tuple[float, float, float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
+ """ + coord_min = [c - self.radius_max for c in self.center] + coord_max = [c + self.radius_max for c in self.center] + coord_min[self.axis] = self.center[self.axis] - self.length_axis / 2.0 + coord_max[self.axis] = self.center[self.axis] + self.length_axis / 2.0 + return (tuple(coord_min), tuple(coord_max)) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + coord_min = max(self.bounds[0][self.axis], bounds[0][self.axis]) + coord_max = min(self.bounds[1][self.axis], bounds[1][self.axis]) + + length = coord_max - coord_min + + volume = np.pi * self.radius_max**2 * length + + # a very loose upper bound on how much of the cylinder is in bounds + for axis in range(3): + if axis != self.axis: + if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: + volume *= 0.5 + + return volume + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + area = 0 + + coord_min = self.bounds[0][self.axis] + coord_max = self.bounds[1][self.axis] + + if coord_min < bounds[0][self.axis]: + coord_min = bounds[0][self.axis] + else: + area += np.pi * self.radius_max**2 + + if coord_max > bounds[1][self.axis]: + coord_max = bounds[1][self.axis] + else: + area += np.pi * self.radius_max**2 + + length = coord_max - coord_min + + area += 2.0 * np.pi * self.radius_max * length + + # a very loose upper bound on how much of the cylinder is in bounds + for axis in range(3): + if axis != self.axis: + if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: + area *= 0.5 + + return area + + @cached_property + def radius_bottom(self) -> float: + """radius of bottom""" + return self._radius_z(self.center_axis - self.finite_length_axis / 2) + + @cached_property + def radius_top(self) -> float: + """radius of bottom""" + return self._radius_z(self.center_axis + self.finite_length_axis / 2) + + @cached_property + def radius_max(self) -> float: + 
"""max(radius of top, radius of bottom)""" + return max(self.radius_bottom, self.radius_top) + + @cached_property + def radius_min(self) -> float: + """min(radius of top, radius of bottom). It can be negative for a large + sidewall angle. + """ + return min(self.radius_bottom, self.radius_top) + + def _radius_z(self, z: float) -> float: + """Compute the radius of the cross section at the position z. + + Parameters + ---------- + z : float + Position along the axis normal to slab + """ + if isclose(self.sidewall_angle, 0): + return self.radius + + radius_middle = self.radius + if self.reference_plane == "top": + radius_middle += self.finite_length_axis / 2 * self._tanq + elif self.reference_plane == "bottom": + radius_middle -= self.finite_length_axis / 2 * self._tanq + + return radius_middle - (z - self.center_axis) * self._tanq + + def _local_to_global_side_cross_section(self, coords: list[float], axis: int) -> list[float]: + """Map a point (x,y) from local to global coordinate system in the + side cross section. + + The definition of the local: y=0 lies at the base if ``sidewall_angle>=0``, + and at the top if ``sidewall_angle<0``; x=0 aligns with the corresponding + ``self.center``. In both cases, y-axis is pointing towards the narrowing + direction of cylinder. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0, 1, 2). + coords : list[float, float] + The value in the planar coordinate. + + Returns + ------- + Tuple[float, float] + The point in the global coordinate for plotting `_intersection_side`. 
+ + """ + + # For negative sidewall angle, quantities along axis direction usually needs a flipped sign + axis_sign = 1 + if self.sidewall_angle < 0: + axis_sign = -1 + + lx_offset, ly_offset = self._order_by_axis( + plane_val=coords[0], + axis_val=axis_sign * (-self.finite_length_axis / 2 + coords[1]), + axis=axis, + ) + _, (x_center, y_center) = self.pop_axis(self.center, axis=axis) + return [x_center + lx_offset, y_center + ly_offset] diff --git a/tidy3d/_common/components/geometry/triangulation.py b/tidy3d/_common/components/geometry/triangulation.py new file mode 100644 index 0000000000..db80da30a7 --- /dev/null +++ b/tidy3d/_common/components/geometry/triangulation.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import shapely + +from tidy3d._common.components.types.base import ArrayFloat1D +from tidy3d._common.exceptions import Tidy3dError + +if TYPE_CHECKING: + from tidy3d._common.components.types.base import ArrayFloat2D + +if TYPE_CHECKING: + from tidy3d._common.components.types import ArrayFloat2D + + +@dataclass +class Vertex: + """Simple data class to hold triangulation data structures. + + Parameters + ---------- + coordinate: ArrayFloat1D + Vertex coordinate. + index : int + Vertex index in the original polygon. + convexity : float = 0.0 + Value representing the convexity (> 0) or concavity (< 0) of the vertex in the polygon. + is_ear : bool = False + Flag indicating whether this is an ear of the polygon. + """ + + coordinate: ArrayFloat1D + + index: int + + convexity: float + + is_ear: bool + + +def update_convexity(vertices: list[Vertex], i: int) -> int: + """Update the convexity of a vertex in a polygon. + + Parameters + ---------- + vertices : list[Vertex] + Vertices of the polygon. + i : int + Index of the vertex to be updated. + + Returns + ------- + int + Value indicating vertex convexity change w.r.t. 0. See note below. 
+ + Note + ---- + Besides updating the vertex, this function returns a value indicating whether the updated vertex + convexity changed to or from 0 (0 convexity means the vertex is collinear with its neighbors). + If the convexity changes from zero to non-zero, return -1. If it changes from non-zero to zero, + return +1. Return 0 in any other case. This allows the main triangulation loop to keep track of + the total number of collinear vertices in the polygon. + + """ + result = -1 if vertices[i].convexity == 0.0 else 0 + j = (i + 1) % len(vertices) + vertices[i].convexity = np.linalg.det( + [ + vertices[i].coordinate - vertices[i - 1].coordinate, + vertices[j].coordinate - vertices[i].coordinate, + ] + ) + if vertices[i].convexity == 0.0: + result += 1 + return result + + +def is_inside( + vertex: ArrayFloat1D, triangle: tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] +) -> bool: + """Check if a vertex is inside a triangle. + + Parameters + ---------- + vertex : ArrayFloat1D + Vertex coordinates. + triangle : tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + Vertices of the triangle in CCW order. + + Returns + ------- + bool: + Flag indicating if the vertex is inside the triangle. + """ + return all( + np.linalg.det([triangle[i] - triangle[i - 1], vertex - triangle[i - 1]]) > 0 + for i in range(3) + ) + + +def update_ear_flag(vertices: list[Vertex], i: int) -> None: + """Update the ear flag of a vertex in a polygon. + + Parameters + ---------- + vertices : list[Vertex] + Vertices of the polygon. + i : int + Index of the vertex to be updated. + """ + h = (i - 1) % len(vertices) + j = (i + 1) % len(vertices) + triangle = (vertices[h].coordinate, vertices[i].coordinate, vertices[j].coordinate) + vertices[i].is_ear = vertices[i].convexity > 0 and not any( + is_inside(v.coordinate, triangle) + for k, v in enumerate(vertices) + if not (v.convexity > 0 or k == h or k == i or k == j) + ) + + +# TODO: This is an inefficient algorithm that runs in O(n^2). 
We should use something +# better, and probably as a compiled extension. +def triangulate(vertices: ArrayFloat2D) -> list[tuple[int, int, int]]: + """Triangulate a simple polygon. + + Parameters + ---------- + vertices : ArrayFloat2D + Vertices of the polygon. + + Returns + ------- + list[tuple[int, int, int]] + List of indices of the vertices of the triangles. + """ + is_ccw = shapely.LinearRing(vertices).is_ccw + + # Initialize vertices as non-collinear because we will update the actual value below and count + # the number of collinear vertices. + vertices = [Vertex(v, i, -1.0, False) for i, v in enumerate(vertices)] + if not is_ccw: + vertices.reverse() + + collinears = 0 + for i in range(len(vertices)): + collinears += update_convexity(vertices, i) + + for i in range(len(vertices)): + update_ear_flag(vertices, i) + + triangles = [] + + ear_found = True + while len(vertices) > 3: + if not ear_found: + raise Tidy3dError( + "Impossible to triangulate polygon. Verify that the polygon is valid." 
+ ) + ear_found = False + i = 0 + while i < len(vertices): + if vertices[i].is_ear: + removed = vertices.pop(i) + h = (i - 1) % len(vertices) + j = i % len(vertices) + collinears += update_convexity(vertices, h) + collinears += update_convexity(vertices, j) + if collinears == len(vertices): + # Undo removal because only collinear vertices remain + vertices.insert(i, removed) + collinears += update_convexity(vertices, (i - 1) % len(vertices)) + collinears += update_convexity(vertices, (i + 1) % len(vertices)) + i += 1 + else: + ear_found = True + triangles.append((vertices[h].index, removed.index, vertices[j].index)) + update_ear_flag(vertices, h) + update_ear_flag(vertices, j) + if len(vertices) == 3: + break + else: + i += 1 + + triangles.append(tuple(v.index for v in vertices)) + return triangles diff --git a/tidy3d/_common/components/geometry/utils.py b/tidy3d/_common/components/geometry/utils.py new file mode 100644 index 0000000000..eedd543126 --- /dev/null +++ b/tidy3d/_common/components/geometry/utils.py @@ -0,0 +1,481 @@ +"""Utilities for geometry manipulation.""" + +from __future__ import annotations + +from collections import defaultdict +from enum import Enum +from math import isclose +from typing import TYPE_CHECKING, Any, Optional, Union + +import numpy as np +import shapely +from pydantic import Field, NonNegativeInt +from shapely.geometry import ( + Polygon, +) +from shapely.geometry.base import ( + BaseMultipartGeometry, +) + +from tidy3d._common.components.autograd.utils import get_static +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.geometry import base, mesh, polyslab, primitives +from tidy3d._common.components.types.base import Shapely +from tidy3d._common.exceptions import SetupError, Tidy3dError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from numpy.typing import ArrayLike + + from tidy3d._common.components.geometry.base import Box + from tidy3d._common.components.types.base 
GeometryType = Union[
    base.Box,
    base.Transformed,
    base.ClipOperation,
    base.GeometryGroup,
    primitives.Sphere,
    primitives.Cylinder,
    polyslab.PolySlab,
    polyslab.ComplexPolySlabBase,
    mesh.TriangleMesh,
]


def flatten_shapely_geometries(
    geoms: Union[Shapely, Iterable[Shapely]], keep_types: tuple[type, ...] = (Polygon,)
) -> list[Shapely]:
    """
    Flatten nested geometries into a flat list, while only keeping the specified types.

    Recursively extracts and returns non-empty geometries of the given types from input geometries,
    expanding any GeometryCollections or Multi* types.

    Parameters
    ----------
    geoms : Union[Shapely, Iterable[Shapely]]
        Input geometries to flatten.

    keep_types : tuple[type, ...]
        Geometry types to keep (e.g., (Polygon, LineString)). Default is
        (Polygon).

    Returns
    -------
    list[Shapely]
        Flat list of non-empty geometries matching the specified types.
    """
    # Handle single Shapely object by wrapping it in a list
    if isinstance(geoms, Shapely):
        geoms = [geoms]

    flat = []
    for geom in geoms:
        if geom.is_empty:
            continue
        if isinstance(geom, keep_types):
            flat.append(geom)
        elif isinstance(geom, BaseMultipartGeometry):
            # recurse into Multi*/GeometryCollection containers
            flat.extend(flatten_shapely_geometries(geom.geoms, keep_types))
    return flat


def merging_geometries_on_plane(
    geometries: list[GeometryType],
    plane: Box,
    property_list: list[Any],
    interior_disjoint_geometries: bool = False,
    cleanup: bool = True,
    quad_segs: Optional[int] = None,
) -> list[tuple[Any, Shapely]]:
    """Compute list of shapes on plane. Overlaps are removed or merged depending on
    provided property_list.

    Parameters
    ----------
    geometries : list[GeometryType]
        List of structures to filter on the plane.
    plane : Box
        Plane specification.
    property_list : List = None
        Property value for each structure.
    interior_disjoint_geometries: bool = False
        If ``True``, geometries of different properties on the plane must not be overlapping.
    cleanup : bool = True
        If True, removes extremely small features from each polygon's boundary.
    quad_segs : Optional[int] = None
        Number of segments used to discretize circular shapes. If ``None``, uses
        high-quality visualization settings.

    Returns
    -------
    list[tuple[Any, Shapely]]
        List of shapes and their property value on the plane after merging.

    Raises
    ------
    SetupError
        If ``geometries`` and ``property_list`` have different lengths.
    """

    if len(geometries) != len(property_list):
        raise SetupError(
            "Number of provided property values is not equal to the number of geometries."
        )

    shapes = []
    for geo, prop in zip(geometries, property_list):
        # get list of Shapely shapes that intersect at the plane
        shapes_plane = plane.intersections_with(geo, cleanup=cleanup, quad_segs=quad_segs)

        # Append each of them and their property information to the list of shapes
        for shape in shapes_plane:
            shapes.append((prop, shape, shape.bounds))

    if interior_disjoint_geometries:
        # No need to consider overlapping. We simply group shapes by property, and union_all
        # shapes of the same property.
        shapes_by_prop = defaultdict(list)
        for prop, shape, _ in shapes:
            shapes_by_prop[prop].append(shape)
        # union shapes of same property
        results = []
        for prop, prop_shapes in shapes_by_prop.items():
            unionized = shapely.union_all(prop_shapes).buffer(0).normalize()
            if not unionized.is_empty:
                results.append((prop, unionized))
        return results

    background_shapes = []
    for prop, shape, bounds in shapes:
        minx, miny, maxx, maxy = bounds

        # loop through background_shapes (note: all background are non-intersecting or merged)
        for index, background_entry in enumerate(background_shapes):
            # BUGFIX: entries already marked for removal must be skipped; previously
            # they were unpacked directly, raising TypeError on a `None` entry.
            if background_entry is None:
                continue
            _prop, _shape, _bounds = background_entry
            _minx, _miny, _maxx, _maxy = _bounds

            # do a bounding box check to see if any intersection to do anything about
            if minx > _maxx or _minx > maxx or miny > _maxy or _miny > maxy:
                continue

            # look more closely to see if intersected.
            if shape.disjoint(_shape):
                continue

            # different prop, remove intersection from background shape
            if prop != _prop:
                diff_shape = (_shape - shape).buffer(0).normalize()
                if diff_shape.is_empty or len(diff_shape.bounds) == 0:
                    # mark background shape for removal if nothing left
                    background_shapes[index] = None
                else:
                    # BUGFIX: previously this assignment ran unconditionally and
                    # overwrote the `None` removal marker set just above.
                    background_shapes[index] = (_prop, diff_shape, diff_shape.bounds)
            # same prop, unionize shapes and mark background shape for removal
            else:
                shape = (shape | _shape).buffer(0).normalize()
                background_shapes[index] = None

        # after doing this with all background shapes, add this shape to the background
        background_shapes.append((prop, shape, shape.bounds))

    # remove any existing background shapes that have been marked as 'None'
    background_shapes = [b for b in background_shapes if b is not None]

    # filter out any remaining None or empty shapes (shapes with area completely removed)
    return [(prop, shape) for (prop, shape, _) in background_shapes if shape]
def flatten_groups(
    *geometries: GeometryType,
    flatten_nonunion_type: bool = False,
    flatten_transformed: bool = False,
    transform: Optional[MatrixReal4x4] = None,
) -> GeometryType:
    """Iterates over all geometries, flattening groups and unions.

    Parameters
    ----------
    *geometries : GeometryType
        Geometries to flatten.
    flatten_nonunion_type : bool = False
        If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten
        all clip operations.
    flatten_transformed : bool = False
        If ``True``, ``Transformed`` groups are flattened into individual transformed geometries.
    transform : Optional[MatrixReal4x4]
        Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``.

    Yields
    ------
    GeometryType
        Geometries after flattening groups and unions.
    """
    for geometry in geometries:
        if isinstance(geometry, base.GeometryGroup):
            # expand the group's children recursively
            yield from flatten_groups(
                *geometry.geometries,
                flatten_nonunion_type=flatten_nonunion_type,
                flatten_transformed=flatten_transformed,
                transform=transform,
            )
            continue

        is_flattenable_clip = isinstance(geometry, base.ClipOperation) and (
            flatten_nonunion_type or geometry.operation == "union"
        )
        if is_flattenable_clip:
            yield from flatten_groups(
                geometry.geometry_a,
                geometry.geometry_b,
                flatten_nonunion_type=flatten_nonunion_type,
                flatten_transformed=flatten_transformed,
                transform=transform,
            )
            continue

        if flatten_transformed and isinstance(geometry, base.Transformed):
            # accumulate the parent transform onto this node's transform
            combined = (
                geometry.transform
                if transform is None
                else np.matmul(transform, geometry.transform)
            )
            yield from flatten_groups(
                geometry.geometry,
                flatten_nonunion_type=flatten_nonunion_type,
                flatten_transformed=flatten_transformed,
                transform=combined,
            )
            continue

        if flatten_transformed and transform is not None:
            # re-wrap a leaf geometry with the accumulated transform
            yield base.Transformed(geometry=geometry, transform=transform)
            continue

        yield geometry


def traverse_geometries(geometry: GeometryType) -> GeometryType:
    """Iterator over all geometries within the given geometry.

    Yields children (post-order) of groups and clip operations, then the
    geometry itself.

    Parameters
    ----------
    geometry: GeometryType
        Base geometry to start iteration.

    Returns
    -------
    :class:`Geometry`
        Geometries within the base geometry.
    """
    if isinstance(geometry, base.GeometryGroup):
        for child in geometry.geometries:
            yield from traverse_geometries(child)
    elif isinstance(geometry, base.ClipOperation):
        yield from traverse_geometries(geometry.geometry_a)
        yield from traverse_geometries(geometry.geometry_b)
    yield geometry
+ """ + if shape.geom_type == "LinearRing": + if sidewall_angle == 0: + return polyslab.PolySlab( + vertices=shape.coords[:-1], + axis=axis, + slab_bounds=slab_bounds, + dilation=dilation, + reference_plane=reference_plane, + ) + group = polyslab.ComplexPolySlabBase( + vertices=shape.coords[:-1], + axis=axis, + slab_bounds=slab_bounds, + dilation=dilation, + sidewall_angle=sidewall_angle, + reference_plane=reference_plane, + ).geometry_group + return group.geometries[0] if len(group.geometries) == 1 else group + + if shape.geom_type == "Polygon": + exterior = from_shapely( + shape.exterior, axis, slab_bounds, dilation, sidewall_angle, reference_plane + ) + interior = [ + from_shapely(hole, axis, slab_bounds, -dilation, -sidewall_angle, reference_plane) + for hole in shape.interiors + ] + if len(interior) == 0: + return exterior + interior = interior[0] if len(interior) == 1 else base.GeometryGroup(geometries=interior) + return base.ClipOperation(operation="difference", geometry_a=exterior, geometry_b=interior) + + if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: + return base.GeometryGroup( + geometries=[ + from_shapely(geo, axis, slab_bounds, dilation, sidewall_angle, reference_plane) + for geo in shape.geoms + ] + ) + + raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") + + +def vertices_from_shapely(shape: Shapely) -> ArrayFloat2D: + """Iterate over the polygons of a shapely geometry returning the vertices. + + Parameters + ---------- + shape : shapely.geometry.base.BaseGeometry + Shapely primitive to have its vertices extracted. It must be a linear ring, a polygon or a + collection of any of those. + + Returns + ------- + list[tuple[ArrayFloat2D]] + List of tuples ``(exterior, *interiors)``. 
+ """ + if shape.geom_type == "LinearRing": + return [(shape.coords[:-1],)] + if shape.geom_type == "Polygon": + return [(shape.exterior.coords[:-1], *tuple(hole.coords[:-1] for hole in shape.interiors))] + if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: + return sum(vertices_from_shapely(geo) for geo in shape.geoms) + + raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") + + +def validate_no_transformed_polyslabs( + geometry: GeometryType, transform: MatrixReal4x4 = None +) -> None: + """Prevents the creation of slanted polyslabs rotated out of plane.""" + if transform is None: + transform = np.eye(4) + if isinstance(geometry, polyslab.PolySlab): + # sidewall_angle may be autograd-traced; unbox for the check only + if not ( + isclose(get_static(geometry.sidewall_angle), 0) + or base.Transformed.preserves_axis(transform, geometry.axis) + ): + raise Tidy3dError( + "Slanted PolySlabs are not allowed to be rotated out of the slab plane." + ) + elif isinstance(geometry, base.Transformed): + transform = np.dot(transform, geometry.transform) + validate_no_transformed_polyslabs(geometry.geometry, transform) + elif isinstance(geometry, base.GeometryGroup): + for geo in geometry.geometries: + validate_no_transformed_polyslabs(geo, transform) + elif isinstance(geometry, base.ClipOperation): + validate_no_transformed_polyslabs(geometry.geometry_a, transform) + validate_no_transformed_polyslabs(geometry.geometry_b, transform) + + +class SnapLocation(Enum): + """Describes different methods for defining the snapping locations.""" + + Boundary = 1 + """ + Choose the boundaries of Yee cells. + """ + Center = 2 + """ + Choose the center of Yee cells. + """ + + +class SnapBehavior(Enum): + """Describes different methods for snapping intervals, which are defined by two endpoints.""" + + Closest = 1 + """ + Snaps the interval's endpoints to the closest grid point. 
+ """ + Expand = 2 + """ + Snaps the interval's endpoints to the closest grid points, + while guaranteeing that the snapping location will never move endpoints inwards. + """ + Contract = 3 + """ + Snaps the interval's endpoints to the closest grid points, + while guaranteeing that the snapping location will never move endpoints outwards. + """ + StrictExpand = 4 + """ + Same as Expand, but will always move endpoints outwards, even if already coincident with grid. + """ + StrictContract = 5 + """ + Same as Contract, but will always move endpoints inwards, even if already coincident with grid. + """ + Off = 6 + """ + Do not use snapping. + """ + + +class SnappingSpec(Tidy3dBaseModel): + """Specifies how to apply grid snapping along each dimension.""" + + location: tuple[SnapLocation, SnapLocation, SnapLocation] = Field( + title="Location", + description="Describes which positions in the grid will be considered for snapping.", + ) + + behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = Field( + title="Behavior", + description="Describes how snapping positions will be chosen.", + ) + + margin: Optional[tuple[NonNegativeInt, NonNegativeInt, NonNegativeInt]] = Field( + (0, 0, 0), + title="Margin", + description="Number of additional grid points to consider when expanding or contracting " + "during snapping. Only applies when ``SnapBehavior`` is ``Expand`` or ``Contract``.", + ) + + +def get_closest_value(test: float, coords: ArrayLike, upper_bound_idx: int) -> float: + """Helper to choose the closest value in an array to a given test value, + using the index of the upper bound. The ``upper_bound_idx`` corresponds to the first value in + the ``coords`` array which is greater than or equal to the test value. 
+ """ + # Handle corner cases first + if upper_bound_idx == 0: + return coords[upper_bound_idx] + if upper_bound_idx == len(coords): + return coords[upper_bound_idx - 1] + # General case + lower_bound = coords[upper_bound_idx - 1] + upper_bound = coords[upper_bound_idx] + dlower = abs(test - lower_bound) + dupper = abs(test - upper_bound) + return lower_bound if dlower < dupper else upper_bound diff --git a/tidy3d/_common/components/source/__init__.py b/tidy3d/_common/components/source/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/source/base.py b/tidy3d/_common/components/source/base.py new file mode 100644 index 0000000000..2afa5cc18c --- /dev/null +++ b/tidy3d/_common/components/source/base.py @@ -0,0 +1,135 @@ +"""Defines an abstract base for electromagnetic sources.""" + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any + +from pydantic import Field, field_validator + +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.base_sim.source import AbstractSource +from tidy3d._common.components.geometry.base import Box +from tidy3d._common.components.types import TYPE_TAG_STR +from tidy3d._common.components.validators import _assert_min_freq, _warn_unsupported_traced_argument +from tidy3d._common.components.viz import ( + ARROW_ALPHA, + ARROW_COLOR_POLARIZATION, + ARROW_COLOR_SOURCE, + plot_params_source, +) + +from .time import SourceTimeType + +if TYPE_CHECKING: + from typing import Optional + + from tidy3d._common.components.types import Ax + from tidy3d._common.components.viz import PlotParams + + +class Source(Box, AbstractSource, ABC): + """Abstract base class for all sources.""" + + source_time: SourceTimeType = Field( + title="Source Time", + description="Specification of the source time-dependence.", + discriminator=TYPE_TAG_STR, + ) + + @cached_property + def plot_params(self) -> PlotParams: + """Default parameters 
    @cached_property
    def geometry(self) -> Box:
        """:class:`Box` representation of source."""

        return Box(center=self.center, size=self.size)

    @cached_property
    def _injection_axis(self) -> None:
        """Injection axis of the source. ``None`` in this base class; subclasses override."""
        return

    @cached_property
    def _dir_vector(self) -> None:
        """Returns a vector indicating the source direction for arrow plotting, if not None."""
        return None

    @cached_property
    def _pol_vector(self) -> None:
        """Returns a vector indicating the source polarization for arrow plotting, if not None."""
        return None

    # warn when traced (autograd) values are supplied for arguments that do not support them
    _warn_traced_center = _warn_unsupported_traced_argument("center")
    _warn_traced_size = _warn_unsupported_traced_argument("size")

    @field_validator("source_time")
    @classmethod
    def _freqs_lower_bound(cls, val: SourceTimeType) -> SourceTimeType:
        """Raise validation error if central frequency is too low."""
        _assert_min_freq(val._freq0_sigma_centroid, msg_start="'source_time.freq0'")
        return val

    def plot(
        self,
        x: Optional[float] = None,
        y: Optional[float] = None,
        z: Optional[float] = None,
        ax: Ax = None,
        **patch_kwargs: Any,
    ) -> Ax:
        """Plot this source.

        Plots the source's bounding box on the given cross-section plane, then
        overlays direction and polarization arrows when the source defines them.

        Parameters
        ----------
        x : Optional[float] = None
            Position of the cross-section plane along x (supply exactly one of x, y, z).
        y : Optional[float] = None
            Position of the cross-section plane along y.
        z : Optional[float] = None
            Position of the cross-section plane along z.
        ax : Ax = None
            Matplotlib axes to plot on; if ``None`` one is created downstream.
        **patch_kwargs : Any
            Extra keyword arguments forwarded to the box patch plotting; the
            special key ``arrow_base`` is popped and forwarded to the arrows.

        Returns
        -------
        Ax
            The axes with the source plotted.
        """

        kwargs_arrow_base = patch_kwargs.pop("arrow_base", None)

        # draw the source's bounding box first via the Box base class
        ax = Box.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs)

        # arrows reuse the caller's alpha when given, else the default arrow alpha
        kwargs_alpha = patch_kwargs.get("alpha")
        arrow_alpha = ARROW_ALPHA if kwargs_alpha is None else kwargs_alpha

        # then add the arrow based on the propagation direction
        if self._dir_vector is not None:
            bend_radius = None
            bend_axis = None
            # mode-like sources may carry a bend specification that curves the arrow
            if hasattr(self, "mode_spec") and self.mode_spec.bend_radius is not None:
                bend_radius = self.mode_spec.bend_radius
                bend_axis = self._bend_axis
                sign = 1 if self.direction == "+" else -1
                # Curvature has to be reversed because of plotting coordinates
                if (self.size.index(0), bend_axis) in [(1, 2), (2, 0), (2, 1)]:
                    bend_radius *= -sign
                else:
                    bend_radius *= sign

            ax = self._plot_arrow(
                x=x,
                y=y,
                z=z,
                ax=ax,
                direction=self._dir_vector,
                bend_radius=bend_radius,
                bend_axis=bend_axis,
                color=ARROW_COLOR_SOURCE,
                alpha=arrow_alpha,
                both_dirs=False,
                arrow_base=kwargs_arrow_base,
            )

        # polarization arrow, drawn independently of the direction arrow
        if self._pol_vector is not None:
            ax = self._plot_arrow(
                x=x,
                y=y,
                z=z,
                ax=ax,
                direction=self._pol_vector,
                color=ARROW_COLOR_POLARIZATION,
                alpha=arrow_alpha,
                both_dirs=False,
                arrow_base=kwargs_arrow_base,
            )

        return ax
# how many units of ``twidth`` from the ``offset`` until a gaussian pulse is considered "off"
END_TIME_FACTOR_GAUSSIAN = 10

# warn if source amplitude is too small at the endpoints of frequency range
WARN_SOURCE_AMPLITUDE = 0.1
# absolute tolerance used in the Brentq root solver
_ROOTS_TOL = 1e-10
# Default sigma value in frequency_range
DEFAULT_SIGMA = 4.0
# Offset in fwidth in finding frequency_range_sigma[1] to ensure the interval brackets the root
OFFSET_FWIDTH_FMAX = 100


class SourceTime(AbstractTimeDependence):
    """Base class describing the time dependence of a source."""

    @add_ax_if_none
    def plot_spectrum(
        self,
        times: ArrayFloat1D,
        num_freqs: int = 101,
        val: PlotVal = "real",
        ax: Ax = None,
    ) -> Ax:
        """Plot the complex-valued amplitude of the source time-dependence.
        Note: Only the real part of the time signal is used.

        Parameters
        ----------
        times : np.ndarray
            Array of evenly-spaced times (seconds) to evaluate source time-dependence at.
            The spectrum is computed from this value and the source time frequency content.
            To see source spectrum for a specific :class:`.Simulation`,
            pass ``simulation.tmesh``.
        num_freqs : int = 101
            Number of frequencies to plot within the SourceTime.frequency_range.
        val : PlotVal = "real"
            Which part of the complex spectrum to plot ("real", "imag", or "abs").
        ax : matplotlib.axes._subplots.Axes = None
            Matplotlib axes to plot on, if not specified, one is created.

        Returns
        -------
        matplotlib.axes._subplots.Axes
            The supplied or created matplotlib axes.
        """

        # plot over the sigma-based range where amplitude is non-negligible
        fmin, fmax = self.frequency_range_sigma()
        return self.plot_spectrum_in_frequency_range(
            times, fmin, fmax, num_freqs=num_freqs, val=val, ax=ax
        )

    @abstractmethod
    def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound:
        """Frequency range within plus/minus ``num_fwidth * fwidth`` of the central frequency."""

    def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound:
        """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude."""
        return self.frequency_range(num_fwidth=sigma)

    @cached_property
    def _frequency_range_sigma_cached(self) -> FreqBound:
        """Cached `frequency_range_sigma` for the default sigma value."""
        return self.frequency_range_sigma(sigma=DEFAULT_SIGMA)

    @abstractmethod
    def end_time(self) -> Optional[float]:
        """Time after which the source is effectively turned off / close to zero amplitude."""
If not present in input parameters, returns `_freq0_sigma_centroid`.""" + return self._freq0_sigma_centroid + + @cached_property + def _freq0_sigma_centroid(self) -> float: + """Central of frequency range at 1-sigma drop from the peak amplitude.""" + return np.mean(self.frequency_range_sigma(sigma=1)) + + +class Pulse(SourceTime, ABC): + """A source time that ramps up with some ``fwidth`` and oscillates at ``freq0``.""" + + freq0: PositiveFloat = Field( + title="Central Frequency", + description="Central frequency of the pulse.", + json_schema_extra={"units": HERTZ}, + ) + fwidth: PositiveFloat = Field( + title="", + description="Standard deviation of the frequency content of the pulse.", + json_schema_extra={"units": HERTZ}, + ) + + offset: float = Field( + 5.0, + title="Offset", + description="Time delay of the maximum value of the " + "pulse in units of 1 / (``2pi * fwidth``).", + ge=2.5, + ) + + @cached_property + def _freq0(self) -> float: + """Central frequency.""" + return self.freq0 + + @property + def offset_time(self) -> float: + """Offset time in seconds.""" + return self.offset * self.twidth + + @property + def twidth(self) -> float: + """Width of pulse in seconds.""" + return 1.0 / (2 * np.pi * self.fwidth) + + def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: + """Frequency range within 5 standard deviations of the central frequency. + + Parameters + ---------- + num_fwidth : float = 4. + Frequency range defined as plus/minus ``num_fwidth * self.fwdith``. + + Returns + ------- + Tuple[float, float] + Minimum and maximum frequencies of the :class:`GaussianPulse` or :class:`ContinuousWave` + power. + """ + + freq_width_range = num_fwidth * self.fwidth + freq_min = max(0, self.freq0 - freq_width_range) + freq_max = self.freq0 + freq_width_range + return (freq_min, freq_max) + + +class GaussianPulse(Pulse): + """Source time dependence that describes a Gaussian pulse. 
+ + Example + ------- + >>> pulse = GaussianPulse(freq0=200e12, fwidth=20e12) + """ + + remove_dc_component: bool = Field( + True, + title="Remove DC Component", + description="Whether to remove the DC component in the Gaussian pulse spectrum. " + "If ``True``, the Gaussian pulse is modified at low frequencies to zero out the " + "DC component, which is usually desirable so that the fields will decay. However, " + "for broadband simulations, it may be better to have non-vanishing source power " + "near zero frequency. Setting this to ``False`` results in an unmodified Gaussian " + "pulse spectrum which can have a nonzero DC component.", + ) + + @property + def peak_time(self) -> float: + """Peak time in seconds, defined by ``offset``.""" + return self.offset * self.twidth + + @property + def _peak_time_shift(self) -> float: + """In the case of DC removal, correction to offset_time so that ``offset`` indeed defines time delay + of pulse peak. + """ + if self.remove_dc_component and self.fwidth > self.freq0: + return self.twidth * np.sqrt(1 - self.freq0**2 / self.fwidth**2) + return 0 + + @property + def offset_time(self) -> float: + """Offset time in seconds. 
Note that in the case of DC removal, the maximal value of pulse can be shifted.""" + return self.peak_time + self._peak_time_shift + + def amp_time(self, time: float) -> complex: + """Complex-valued source amplitude as a function of time.""" + + omega0 = 2 * np.pi * self.freq0 + time_shifted = time - self.offset_time + + offset = np.exp(1j * self.phase) + oscillation = np.exp(-1j * omega0 * time) + amp = np.exp(-(time_shifted**2) / 2 / self.twidth**2) * self.amplitude + + pulse_amp = offset * oscillation * amp + + # subtract out DC component + if self.remove_dc_component: + pulse_amp = pulse_amp * (1j * omega0 + time_shifted / self.twidth**2) + # normalize by peak frequency instead of omega0, as for small omega0, omega0 approaches 0 faster + pulse_amp /= 2 * np.pi * self.peak_frequency + else: + # 1j to make it agree in large omega0 limit + pulse_amp = pulse_amp * 1j + + return pulse_amp + + def end_time(self) -> Optional[float]: + """Time after which the source is effectively turned off / close to zero amplitude.""" + + # TODO: decide if we should continue to return an end_time if the DC component remains + # if not self.remove_dc_component: + # return None + + end_time = self.offset_time + END_TIME_FACTOR_GAUSSIAN * self.twidth + + # for derivative Gaussian that contains two peaks, add time interval between them + if self.remove_dc_component and self.fwidth > self.freq0: + end_time += 2 * self._peak_time_shift + return end_time + + def amp_freq(self, freq: float) -> complex: + """Complex-valued source spectrum in frequency domain.""" + phase = np.exp(1j * self.phase + 1j * 2 * np.pi * (freq - self.freq0) * self.offset_time) + envelope = np.exp(-((freq - self.freq0) ** 2) / 2 / self.fwidth**2) + amp = 1j * self.amplitude / self.fwidth * phase * envelope + if not self.remove_dc_component: + return amp + + # derivative of Gaussian when DC is removed + return freq * amp / (2 * np.pi * self.peak_frequency) + + def _rel_amp_freq(self, freq: float) -> complex: + 
"""Complex-valued source spectrum in frequency domain normalized by peak amplitude.""" + return self.amp_freq(freq) / self._peak_freq_amp + + @property + def peak_frequency(self) -> float: + """Frequency at which the source time dependence has its peak amplitude in the frequency domain.""" + if not self.remove_dc_component: + return self.freq0 + return 0.5 * (self.freq0 + np.sqrt(self.freq0**2 + 4 * self.fwidth**2)) + + @property + def _peak_freq_amp(self) -> complex: + """Peak amplitude in frequency domain""" + return self.amp_freq(self.peak_frequency) + + @property + def _peak_time_amp(self) -> complex: + """Peak amplitude in time domain""" + return self.amp_time(self.peak_time) + + def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: + """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" + if not self.remove_dc_component: + return self.frequency_range(num_fwidth=sigma) + + # With dc removed, we'll need to solve for the transcendental equation to find the frequency range + def equation_for_sigma_frequency(freq: float) -> float: + """computes A / A_p - exp(-sigma)""" + return np.abs(self._rel_amp_freq(freq)) - np.exp(-(sigma**2) / 2) + + logger = logging.getLogger("pyroots") + logger.setLevel(logging.CRITICAL) + root_scalar = Brentq(raise_on_fail=False, epsilon=_ROOTS_TOL) + fmin_data = root_scalar(equation_for_sigma_frequency, xa=0, xb=self.peak_frequency) + fmax_data = root_scalar( + equation_for_sigma_frequency, + xa=self.peak_frequency, + xb=self.peak_frequency + + self.fwidth + * ( + OFFSET_FWIDTH_FMAX + 2 * sigma**2 + ), # offset slightly to make sure that it flips sign + ) + fmin, fmax = fmin_data.x0, fmax_data.x0 + + # if unconverged, fall back to `frequency_range` + if not (fmin_data.converged and fmax_data.converged and fmax > fmin): + return self.frequency_range(num_fwidth=sigma) + + # converged + return fmin.item(), fmax.item() + + @property + def amp_complex(self) -> complex: + 
"""Grab the complex amplitude from a ``GaussianPulse``.""" + phase = np.exp(1j * self.phase) + return self.amplitude * phase + + @classmethod + def from_amp_complex(cls, amp: complex, **kwargs: Any) -> GaussianPulse: + """Set the complex amplitude of a ``GaussianPulse``. + + Parameters + ---------- + amp : complex + Complex-valued amplitude to set in the returned ``GaussianPulse``. + kwargs : dict + Keyword arguments passed to ``GaussianPulse()``, excluding ``amplitude`` & ``phase``. + """ + amplitude = abs(amp) + phase = np.angle(amp) + return cls(amplitude=amplitude, phase=phase, **kwargs) + + @staticmethod + def _minimum_source_bandwidth( + fmin: float, fmax: float, minimum_source_bandwidth: float + ) -> tuple[float, float]: + """Define a source bandwidth based on fmin and fmax, but enforce a minimum bandwidth.""" + if minimum_source_bandwidth <= 0: + raise ValidationError("'minimum_source_bandwidth' must be positive") + if minimum_source_bandwidth >= 1: + raise ValidationError("'minimum_source_bandwidth' must less than or equal to 1") + + f_difference = fmax - fmin + f_middle = 0.5 * (fmin + fmax) + + full_width = minimum_source_bandwidth * f_middle + if f_difference < full_width: + half_width = 0.5 * full_width + fmin = f_middle - half_width + fmax = f_middle + half_width + + return fmin, fmax + + @classmethod + def from_frequency_range( + cls, + fmin: PositiveFloat, + fmax: PositiveFloat, + minimum_source_bandwidth: Optional[PositiveFloat] = None, + **kwargs: Any, + ) -> GaussianPulse: + """Create a ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. + + Parameters + ---------- + fmin : float + Lower bound of frequency of interest. + fmax : float + Upper bound of frequency of interest. + kwargs : dict + Keyword arguments passed to ``GaussianPulse()``, excluding ``freq0`` & ``fwidth``. + + Returns + ------- + GaussianPulse + A ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. 
+ """ + # validate that fmin and fmax must positive, and fmax > fmin + if fmin <= 0: + raise ValidationError("'fmin' must be positive.") + if fmax <= fmin: + raise ValidationError("'fmax' must be greater than 'fmin'.") + + if minimum_source_bandwidth is not None: + fmin, fmax = cls._minimum_source_bandwidth(fmin, fmax, minimum_source_bandwidth) + + # frequency range and center + freq_range = fmax - fmin + freq_center = (fmax + fmin) / 2.0 + + # If remove_dc_component=False, simply return the standard GaussianPulse parameters + if kwargs.get("remove_dc_component", True) is False: + return cls(freq0=freq_center, fwidth=freq_range / 2.0, **kwargs) + + # If remove_dc_component=True, the Gaussian pulse is distorted + kwargs.update({"remove_dc_component": True}) + log_ratio = np.log(fmax / fmin) + coeff = ((1 + log_ratio**2) ** 0.5 - 1) / 2.0 + freq0 = freq_center - coeff / log_ratio * freq_range + fwidth = freq_range / log_ratio * coeff**0.5 + pulse = cls(freq0=freq0, fwidth=fwidth, **kwargs) + if np.abs(pulse._rel_amp_freq(fmin)) < WARN_SOURCE_AMPLITUDE: + log.warning( + "Default source time profile is less accurate for the specified broadband frequency range. " + "For more accurate results, consider reducing the frequency range or using a 'BroadbandSource'.", + ) + return pulse + + +class ContinuousWave(Pulse): + """Source time dependence that ramps up to continuous oscillation + and holds until end of simulation. + + Note + ---- + Field decay will not occur, so the simulation will run for the full ``run_time``. + Also, source normalization of frequency-domain monitors is not meaningful. 
+ + Example + ------- + >>> cw = ContinuousWave(freq0=200e12, fwidth=20e12) + """ + + def amp_time(self, time: float) -> complex: + """Complex-valued source amplitude as a function of time.""" + + twidth = 1.0 / (2 * np.pi * self.fwidth) + omega0 = 2 * np.pi * self.freq0 + time_shifted = time - self.offset_time + + const = 1.0 + offset = np.exp(1j * self.phase) + oscillation = np.exp(-1j * omega0 * time) + amp = 1 / (1 + np.exp(-time_shifted / twidth)) * self.amplitude + + return const * offset * oscillation * amp + + def end_time(self) -> Optional[float]: + """Time after which the source is effectively turned off / close to zero amplitude.""" + return None + + +class CustomSourceTime(Pulse): + """Custom source time dependence consisting of a real or complex envelope + modulated at a central frequency, as shown below. + + Note + ---- + .. math:: + + amp\\_time(t) = amplitude \\cdot \\ + e^{i \\cdot phase - 2 \\pi i \\cdot freq0 \\cdot t} \\cdot \\ + envelope(t - offset / (2 \\pi \\cdot fwidth)) + + Note + ---- + Depending on the envelope, field decay may not occur. + If field decay does not occur, then the simulation will run for the full ``run_time``. + Also, if field decay does not occur, then source normalization of frequency-domain + monitors is not meaningful. + + Note + ---- + The source time dependence is linearly interpolated to the simulation time steps. + The sampling rate should be sufficiently fast that this interpolation does not + introduce artifacts. The source time dependence should also start at zero and ramp up smoothly. + The first and last values of the envelope will be used for times that are out of range + of the provided data. + + Example + ------- + >>> cst = CustomSourceTime.from_values(freq0=1, fwidth=0.1, + ... 
values=np.linspace(0, 9, 10), dt=0.1) + + """ + + offset: float = Field( + 0.0, + title="Offset", + description="Time delay of the envelope in units of 1 / (``2pi * fwidth``).", + ) + + source_time_dataset: Optional[TimeDataset] = Field( + None, + title="Source time dataset", + description="Dataset for storing the envelope of the custom source time. " + "This envelope will be modulated by a complex exponential at frequency ``freq0``.", + ) + + _no_nans_dataset = validate_no_nans("source_time_dataset") + _source_time_dataset_none_warning = warn_if_dataset_none("source_time_dataset") + + @field_validator("source_time_dataset") + @classmethod + def _more_than_one_time(cls, val: Optional[TimeDataset]) -> Optional[TimeDataset]: + """Must have more than one time to interpolate.""" + if val is None: + return val + if val.values.size <= 1: + raise ValidationError("'CustomSourceTime' must have more than one time coordinate.") + return val + + @classmethod + def from_values( + cls, freq0: float, fwidth: float, values: ArrayComplex1D, dt: float + ) -> CustomSourceTime: + """Create a :class:`.CustomSourceTime` from a numpy array. + + Parameters + ---------- + freq0 : float + Central frequency of the source. The envelope provided will be modulated + by a complex exponential at this frequency. + fwidth : float + Estimated frequency width of the source. + values: ArrayComplex1D + Complex values of the source envelope. + dt: float + Time step for the ``values`` array. This value should be sufficiently small + that the interpolation to simulation time steps does not introduce artifacts. + + Returns + ------- + CustomSourceTime + :class:`.CustomSourceTime` with envelope given by ``values``, modulated by a complex + exponential at frequency ``freq0``. The time coordinates are evenly spaced + between ``0`` and ``dt * (N-1)`` with a step size of ``dt``, where ``N`` is the length of + the values array. 
+ """ + + times = np.arange(len(values)) * dt + source_time_dataarray = TimeDataArray(values, coords={"t": times}) + source_time_dataset = TimeDataset(values=source_time_dataarray) + return CustomSourceTime( + freq0=freq0, + fwidth=fwidth, + source_time_dataset=source_time_dataset, + ) + + @property + def data_times(self) -> ArrayFloat1D: + """Times of envelope definition.""" + if self.source_time_dataset is None: + return [] + data_times = self.source_time_dataset.values.coords["t"].values.squeeze() + return data_times + + def _all_outside_range(self, run_time: float) -> bool: + """Whether all times are outside range of definition.""" + + # can't validate if data isn't loaded + if self.source_time_dataset is None: + return False + + # make time a numpy array for uniform handling + data_times = self.data_times + + # shift time + max_time_shifted = run_time - self.offset_time + min_time_shifted = -self.offset_time + + return (max_time_shifted < min(data_times)) | (min_time_shifted > max(data_times)) + + def amp_time(self, time: float) -> complex: + """Complex-valued source amplitude as a function of time. + + Parameters + ---------- + time : float + Time in seconds. + + Returns + ------- + complex + Complex-valued source amplitude at that time. 
+ """ + + if self.source_time_dataset is None: + return None + + # make time a numpy array for uniform handling + times = np.array([time] if isinstance(time, (int, float)) else time) + data_times = self.data_times + + # shift time + twidth = 1.0 / (2 * np.pi * self.fwidth) + time_shifted = times - self.offset * twidth + + # mask times that are out of range + mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times)) + + # get envelope + envelope = np.zeros(len(time_shifted), dtype=complex) + values = self.source_time_dataset.values + envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy() + if not all(mask): + envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy() + + # modulation, phase, amplitude + omega0 = 2 * np.pi * self.freq0 + offset = np.exp(1j * self.phase) + oscillation = np.exp(-1j * omega0 * times) + amp = self.amplitude + + return offset * oscillation * amp * envelope + + def end_time(self) -> Optional[float]: + """Time after which the source is effectively turned off / close to zero amplitude.""" + + if self.source_time_dataset is None: + return None + + data_array = self.source_time_dataset.values + + t_coords = data_array.coords["t"] + source_is_non_zero = ~np.isclose(abs(data_array), 0) + t_non_zero = t_coords[source_is_non_zero] + + return np.max(t_non_zero) + + +class BroadbandPulse(SourceTime): + """A source time injecting significant energy in the entire custom frequency range.""" + + freq_range: FreqBound = Field( + title="Frequency Range", + description="Frequency range where the pulse should have significant energy.", + json_schema_extra={"units": HERTZ}, + ) + minimum_amplitude: float = Field( + 0.3, + title="Minimum Amplitude", + description="Minimum amplitude of the pulse relative to the peak amplitude in the frequency range.", + gt=0.05, + lt=0.5, + ) + offset: float = Field( + 0.0, + title="Offset", + description="An automatic time delay of the peak value of the pulse has been applied 
under the hood " + "to ensure smooth ramping up of the pulse at time = 0. This offfset is added on top of the automatic time delay " + "in units of 1 / [``2pi * (freq_range[1] - freq_range[0])``].", + ) + + @field_validator("freq_range") + @classmethod + def _validate_freq_range(cls, val: FreqBound) -> FreqBound: + """Validate that freq_range is positive and properly ordered.""" + if val[0] <= 0 or val[1] <= 0: + raise ValidationError("Both elements of 'freq_range' must be positive.") + if val[1] <= val[0]: + raise ValidationError( + f"'freq_range[1]' ({val[1]}) must be greater than 'freq_range[0]' ({val[0]})." + ) + return val + + @model_validator(mode="before") + @classmethod + def _check_broadband_pulse_available(cls, values: dict[str, Any]) -> dict[str, Any]: + """Check if BroadbandPulse is available.""" + check_tidy3d_extras_licensed_feature("BroadbandPulse") + return values + + @cached_property + def _source(self) -> Any: + """Implementation of broadband pulse.""" + return tidy3d_extras["mod"].extension.BroadbandPulse( + fmin=self.freq_range[0], + fmax=self.freq_range[1], + minRelAmp=self.minimum_amplitude, + amp=self.amplitude, + phase=self.phase, + offset=self.offset, + ) + + def end_time(self) -> float: + """Time after which the source is effectively turned off / close to zero amplitude.""" + return self._source.end_time(END_TIME_FACTOR_GAUSSIAN) + + def amp_time(self, time: float) -> complex: + """Complex-valued source amplitude as a function of time.""" + return self._source.amp_time(time) + + def amp_freq(self, freq: float) -> complex: + """Complex-valued source amplitude as a function of frequency.""" + return self._source.amp_freq(freq) + + def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: + """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" + return self._source.frequency_range(sigma) + + def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: + 
"""Delegated to `frequency_range_sigma(sigma=num_fwidth)` for computing the frequency range where the source amplitude + is within ``exp(-num_fwidth**2/2)`` of the peak amplitude. + """ + return self.frequency_range_sigma(num_fwidth) + + +SourceTimeType = Union[GaussianPulse, ContinuousWave, CustomSourceTime, BroadbandPulse] diff --git a/tidy3d/_common/components/time.py b/tidy3d/_common/components/time.py new file mode 100644 index 0000000000..75bc367dd1 --- /dev/null +++ b/tidy3d/_common/components/time.py @@ -0,0 +1,210 @@ +"""Defines time dependence""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +from pydantic import Field, NonNegativeFloat + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.viz import add_ax_if_none +from tidy3d._common.constants import RADIAN +from tidy3d._common.exceptions import SetupError + +if TYPE_CHECKING: + from tidy3d._common.components.types import ArrayFloat1D, Ax, PlotVal + +if TYPE_CHECKING: + from .types import ArrayFloat1D, Ax, PlotVal + +# in spectrum computation, discard amplitudes with relative magnitude smaller than cutoff +DFT_CUTOFF = 1e-8 + + +class AbstractTimeDependence(ABC, Tidy3dBaseModel): + """Base class describing time dependence.""" + + amplitude: NonNegativeFloat = Field( + 1.0, title="Amplitude", description="Real-valued maximum amplitude of the time dependence." + ) + + phase: float = Field( + 0.0, + title="Phase", + description="Phase shift of the time dependence.", + json_schema_extra={"units": RADIAN}, + ) + + @abstractmethod + def amp_time(self, time: float) -> complex: + """Complex-valued amplitude as a function of time. + + Parameters + ---------- + time : float + Time in seconds. + + Returns + ------- + complex + Complex-valued amplitude at that time. 
+ """ + + def spectrum( + self, + times: ArrayFloat1D, + freqs: ArrayFloat1D, + dt: float, + ) -> complex: + """Complex-valued spectrum as a function of frequency. + Note: Only the real part of the time signal is used. + + Parameters + ---------- + times : np.ndarray + Times to use to evaluate spectrum Fourier transform. + (Typically the simulation time mesh). + freqs : np.ndarray + Frequencies in Hz to evaluate spectrum at. + dt : float or np.ndarray + Time step to weight FT integral with. + If array, use to weigh each of the time intervals in ``times``. + + Returns + ------- + np.ndarray + Complex-valued array (of len(freqs)) containing spectrum at those frequencies. + """ + + times = np.array(times) + freqs = np.array(freqs) + time_amps = np.real(self.amp_time(times)) + + # if all time amplitudes are zero, just return (complex-valued) zeros for spectrum + if np.all(np.equal(time_amps, 0.0)): + return (0.0 + 0.0j) * np.zeros_like(freqs) + + # Cut to only relevant times + relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF) + # find first and last index where the filter is True + start_ind = relevant_time_inds[0][0] + stop_ind = relevant_time_inds[0][-1] + 1 + time_amps = time_amps[start_ind:stop_ind] + times_cut = times[start_ind:stop_ind] + if times_cut.size == 0: + return (0.0 + 0.0j) * np.zeros_like(freqs) + + # only need to compute DTFT kernel for distinct dts + # usually, there is only one dt, if times is simulation time mesh + dts = np.diff(times_cut) + dts_unique, kernel_indices = np.unique(dts, return_inverse=True) + + dft_kernels = [np.exp(2j * np.pi * freqs * curr_dt) for curr_dt in dts_unique] + running_kernel = np.exp(2j * np.pi * freqs * times_cut[0]) + dft = np.zeros(len(freqs), dtype=complex) + for amp, kernel_index in zip(time_amps, kernel_indices): + dft += running_kernel * amp + running_kernel *= dft_kernels[kernel_index] + + # kernel_indices was one index shorter than time_amps + dft += running_kernel * 
time_amps[-1] + + return dt * dft / np.sqrt(2 * np.pi) + + @add_ax_if_none + def plot_spectrum_in_frequency_range( + self, + times: ArrayFloat1D, + fmin: float, + fmax: float, + num_freqs: int = 101, + val: PlotVal = "real", + ax: Ax = None, + ) -> Ax: + """Plot the complex-valued amplitude of the time-dependence. + Note: Only the real part of the time signal is used. + + Parameters + ---------- + times : np.ndarray + Array of evenly-spaced times (seconds) to evaluate time-dependence at. + The spectrum is computed from this value and the time frequency content. + To see spectrum for a specific :class:`.Simulation`, + pass ``simulation.tmesh``. + fmin : float + Lower bound of frequency for the spectrum plot. + fmax : float + Upper bound of frequency for the spectrum plot. + num_freqs : int = 101 + Number of frequencies to plot within the [fmin, fmax]. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + times = np.array(times) + + dts = np.diff(times) + if not np.allclose(dts, dts[0] * np.ones_like(dts), atol=1e-17): + raise SetupError("Supplied times not evenly spaced.") + + dt = np.mean(dts) + freqs = np.linspace(fmin, fmax, num_freqs) + + spectrum = self.spectrum(times=times, dt=dt, freqs=freqs) + + if val == "real": + ax.plot(freqs, spectrum.real, color="blueviolet", label="real") + elif val == "imag": + ax.plot(freqs, spectrum.imag, color="crimson", label="imag") + elif val == "abs": + ax.plot(freqs, np.abs(spectrum), color="k", label="abs") + else: + raise ValueError(f"Plot 'val' option of '{val}' not recognized.") + ax.set_xlabel("frequency (Hz)") + ax.set_title("source spectrum") + ax.legend() + ax.set_aspect("auto") + return ax + + @add_ax_if_none + def plot(self, times: ArrayFloat1D, val: PlotVal = "real", ax: Ax = None) -> Ax: + """Plot the complex-valued amplitude of the time-dependence. 
+ + Parameters + ---------- + times : np.ndarray + Array of times (seconds) to plot source at. + To see source time amplitude for a specific :class:`.Simulation`, + pass ``simulation.tmesh``. + val : Literal['real', 'imag', 'abs'] = 'real' + Which part of the spectrum to plot. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + times = np.array(times) + amp_complex = self.amp_time(times) + + if val == "real": + ax.plot(times, amp_complex.real, color="blueviolet", label="real") + elif val == "imag": + ax.plot(times, amp_complex.imag, color="crimson", label="imag") + elif val == "abs": + ax.plot(times, np.abs(amp_complex), color="k", label="abs") + else: + raise ValueError(f"Plot 'val' option of '{val}' not recognized.") + ax.set_xlabel("time (s)") + ax.set_title("source amplitude") + ax.legend() + ax.set_aspect("auto") + return ax diff --git a/tidy3d/_common/components/transformation.py b/tidy3d/_common/components/transformation.py new file mode 100644 index 0000000000..17038fb503 --- /dev/null +++ b/tidy3d/_common/components/transformation.py @@ -0,0 +1,210 @@ +"""Defines geometric transformation classes""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Union + +import numpy as np +from pydantic import Field, field_validator + +from tidy3d._common.components.autograd import TracedFloat +from tidy3d._common.components.base import Tidy3dBaseModel, cached_property +from tidy3d._common.components.types.base import Axis, Coordinate +from tidy3d._common.constants import RADIAN +from tidy3d._common.exceptions import ValidationError + +if TYPE_CHECKING: + from tidy3d._common.components.types.base import ArrayFloat2D, TensorReal + + +class AbstractRotation(ABC, Tidy3dBaseModel): + """Abstract rotation of vectors and tensors.""" + + @cached_property + 
@abstractmethod + def matrix(self) -> TensorReal: + """Rotation matrix.""" + + @cached_property + @abstractmethod + def isidentity(self) -> bool: + """Check whether rotation is identity.""" + + def rotate_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: + """Rotate a vector/point or a list of vectors/points. + + Parameters + ---------- + points : ArrayLike[float] + Array of shape ``(3, ...)``. + + Returns + ------- + Coordinate + Rotated vector. + """ + + if self.isidentity: + return vector + + if len(vector.shape) == 1: + return self.matrix @ vector + + return np.tensordot(self.matrix, vector, axes=1) + + def rotate_tensor(self, tensor: TensorReal) -> TensorReal: + """Rotate a tensor. + + Parameters + ---------- + tensor : ArrayLike[float] + Array of shape ``(3, 3)``. + + Returns + ------- + TensorReal + Rotated tensor. + """ + + if self.isidentity: + return tensor + + return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) + + +class RotationAroundAxis(AbstractRotation): + """Rotation of vectors and tensors around a given vector.""" + + axis: Union[Axis, Coordinate] = Field( + 0, + title="Axis of Rotation", + description="A vector that specifies the axis of rotation, or a single int: 0, 1, or 2, " + "indicating x, y, or z.", + ) + + angle: TracedFloat = Field( + 0.0, + title="Angle of Rotation", + description="Angle of rotation in radians.", + json_schema_extra={"units": RADIAN}, + ) + + @field_validator("axis") + @classmethod + def _validate_axis_vector(cls, val: Union[Axis, Coordinate]) -> Coordinate: + if not isinstance(val, tuple): + axis = [0.0, 0.0, 0.0] + axis[val] = 1.0 + val = tuple(axis) + return val + + @field_validator("axis") + @classmethod + def _validate_axis_nonzero_norm(cls, val: Coordinate) -> Coordinate: + norm = np.linalg.norm(val) + if np.isclose(norm, 0): + raise ValidationError( + "The norm of vector 'axis' cannot be zero. Please provide a proper rotation axis." 
+ ) + return val + + @cached_property + def isidentity(self) -> bool: + """Check whether rotation is identity.""" + + return np.isclose(self.angle % (2 * np.pi), 0) + + @cached_property + def matrix(self) -> TensorReal: + """Rotation matrix.""" + + if self.isidentity: + return np.eye(3) + + norm = np.linalg.norm(self.axis) + n = self.axis / norm + c = np.cos(self.angle) + s = np.sin(self.angle) + K = np.array([[0, -n[2], n[1]], [n[2], 0, -n[0]], [-n[1], n[0], 0]]) + R = np.eye(3) + s * K + (1 - c) * K @ K + + return R + + +class AbstractReflection(ABC, Tidy3dBaseModel): + """Abstract reflection of vectors and tensors.""" + + @cached_property + @abstractmethod + def matrix(self) -> TensorReal: + """Reflection matrix.""" + + def reflect_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: + """Reflect a vector/point or a list of vectors/points. + + Parameters + ---------- + vector : ArrayLike[float] + Array of shape ``(3, ...)``. + + Returns + ------- + Coordinate + Reflected vector. + """ + + if len(vector.shape) == 1: + return self.matrix @ vector + + return np.tensordot(self.matrix, vector, axes=1) + + def reflect_tensor(self, tensor: TensorReal) -> TensorReal: + """Reflect a tensor. + + Parameters + ---------- + tensor : ArrayLike[float] + Array of shape ``(3, 3)``. + + Returns + ------- + TensorReal + Reflected tensor. + """ + + return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) + + +class ReflectionFromPlane(AbstractReflection): + """Reflection of vectors and tensors around a given vector.""" + + normal: Coordinate = Field( + (1, 0, 0), + title="Normal of the reflecting plane", + description="A vector that specifies the normal of the plane of reflection", + ) + + @field_validator("normal") + @classmethod + def _validate_normal_nonzero_norm(cls, val: Coordinate) -> Coordinate: + norm = np.linalg.norm(val) + if np.isclose(norm, 0): + raise ValidationError( + "The norm of vector 'normal' cannot be zero. Please provide a proper normal vector." 
+ ) + return val + + @cached_property + def matrix(self) -> TensorReal: + """Reflection matrix.""" + + norm = np.linalg.norm(self.normal) + n = self.normal / norm + R = np.eye(3) - 2 * np.outer(n, n) + + return R + + +RotationType = Union[RotationAroundAxis] +ReflectionType = Union[ReflectionFromPlane] diff --git a/tidy3d/_common/components/types/__init__.py b/tidy3d/_common/components/types/__init__.py new file mode 100644 index 0000000000..51654144cb --- /dev/null +++ b/tidy3d/_common/components/types/__init__.py @@ -0,0 +1,131 @@ +"""Exports all data types from the .base module for easy access.""" + +from __future__ import annotations + +from tidy3d._common.components.types.base import ( + TYPE_TAG_STR, + ArrayComplex1D, + ArrayComplex2D, + ArrayComplex3D, + ArrayComplex4D, + ArrayFloat1D, + ArrayFloat2D, + ArrayFloat3D, + ArrayFloat4D, + ArrayInt1D, + ArrayLike, + AuxField, + Ax, + Axis, + Axis2D, + Bound, + BoxSurface, + ClipOperationType, + ColormapType, + Complex, + Coordinate, + Coordinate2D, + CoordinateOptional, + Direction, + EMField, + EpsSpecType, + FieldType, + FieldVal, + FreqArray, + FreqBound, + FreqBoundMax, + FreqBoundMin, + GridSize, + InterpMethod, + LengthUnit, + LumpDistType, + MatrixReal4x4, + ModeClassification, + ModeSolverType, + ObsGridArray, + PermittivityComponent, + PlanePosition, + PlotScale, + PlotVal, + Polarization, + PolarizationBasis, + PoleAndResidue, + PriorityMode, + RealFieldVal, + ScalarSymmetry, + Shapely, + Size, + Size1D, + Symmetry, + TensorReal, + TrackFreq, + Undefined, + UnitsZBF, + xyz, +) +from tidy3d._common.components.types.third_party import TrimeshType +from tidy3d._common.components.types.utils import _add_schema + +__all__ = [ + "TYPE_TAG_STR", + "ArrayComplex1D", + "ArrayComplex2D", + "ArrayComplex3D", + "ArrayComplex4D", + "ArrayFloat1D", + "ArrayFloat2D", + "ArrayFloat3D", + "ArrayFloat4D", + "ArrayInt1D", + "ArrayLike", + "AuxField", + "Ax", + "Axis", + "Axis2D", + "Bound", + "BoxSurface", + 
"ClipOperationType", + "ColormapType", + "Complex", + "Coordinate", + "Coordinate2D", + "CoordinateOptional", + "Direction", + "EMField", + "EpsSpecType", + "FieldType", + "FieldVal", + "FreqArray", + "FreqBound", + "FreqBoundMax", + "FreqBoundMin", + "GridSize", + "InterpMethod", + "LengthUnit", + "LumpDistType", + "MatrixReal4x4", + "ModeClassification", + "ModeSolverType", + "ObsGridArray", + "PermittivityComponent", + "PlanePosition", + "PlotScale", + "PlotVal", + "Polarization", + "PolarizationBasis", + "PoleAndResidue", + "PriorityMode", + "RealFieldVal", + "ScalarSymmetry", + "Shapely", + "Size", + "Size1D", + "Symmetry", + "TensorReal", + "TrackFreq", + "TrimeshType", + "Undefined", + "UnitsZBF", + "_add_schema", + "xyz", +] diff --git a/tidy3d/_common/components/types/base.py b/tidy3d/_common/components/types/base.py new file mode 100644 index 0000000000..ea408643fd --- /dev/null +++ b/tidy3d/_common/components/types/base.py @@ -0,0 +1,320 @@ +"""Defines 'types' that various fields can be""" + +from __future__ import annotations + +import numbers +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union + +import numpy as np +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + NonNegativeFloat, + PlainValidator, + PositiveFloat, +) +from pydantic.functional_serializers import PlainSerializer +from pydantic.json_schema import WithJsonSchema + +if TYPE_CHECKING: + from numpy.typing import NDArray + +try: + from matplotlib.axes import Axes +except ImportError: + Axes = None + +from shapely.geometry.base import BaseGeometry + +# type tag default name +TYPE_TAG_STR = "type" + + +def discriminated_union(union: type, discriminator: str = TYPE_TAG_STR) -> type: + return Annotated[union, Field(discriminator=discriminator)] + + +""" Numpy Arrays """ + + +def _dtype2python(value: Any) -> Any: + """Converts numpy scalar types to their python equivalents.""" + if isinstance(value, np.integer): + return int(value) + if 
def _from_complex_dict(value: Any) -> Any:
    """Rebuild a complex-valued array from its serialized ``{"real", "imag"}`` form."""
    if isinstance(value, dict) and "real" in value and "imag" in value:
        return np.asarray(value["real"]) + 1j * np.asarray(value["imag"])
    return value


def _auto_serializer(value: Any, _: Any) -> Any:
    """Serializes numpy arrays and scalars for JSON."""
    # complex scalars become {"real": ..., "imag": ...}
    np_complex = hasattr(np, "complexfloating") and isinstance(value, np.complexfloating)
    if isinstance(value, complex) or np_complex:
        return {"real": float(value.real), "imag": float(value.imag)}
    # arrays become (nested) lists; complex arrays split into real/imag parts
    if isinstance(value, np.ndarray):
        if np.iscomplexobj(value):
            return {"real": value.real.tolist(), "imag": value.imag.tolist()}
        return value.tolist()
    # numpy scalars collapse to the corresponding basic Python type
    if isinstance(value, float) or (hasattr(np, "floating") and isinstance(value, np.floating)):
        return float(value)
    if isinstance(value, int) or (hasattr(np, "integer") and isinstance(value, np.integer)):
        return int(value)
    if hasattr(np, "number") and isinstance(value, np.number):
        return value.item()
    return value


# np.dtype validated/constructed in place; schema is an opaque marker string
DTypeLike = Annotated[np.dtype, PlainValidator(np.dtype), WithJsonSchema({"type": "np.dtype"})]


class ArrayConstraints(BaseModel):
    """Immutable bundle of constraints applied when coercing inputs to arrays."""

    model_config = ConfigDict(frozen=True)

    # target dtype (None keeps whatever numpy infers)
    dtype: Optional[DTypeLike] = None
    # required number of dimensions (None allows any)
    ndim: Optional[int] = None
    # exact required shape (None allows any)
    shape: Optional[tuple[int, ...]] = None
    # reject arrays containing NaN values
    forbid_nan: bool = True
    # promote 0-d input to a length-1 1-d array
    scalar_to_1d: bool = False
    # refuse plain scalar inputs entirely
    strict: bool = False
+ """ + if constraints.strict and np.isscalar(v): + raise ValueError( + f"strict mode: scalar value {type(v).__name__!r} cannot be coerced to a NumPy array. " + ) + + try: + # constraints.dtype is already an np.dtype object or None + arr = np.asarray(v) if constraints.dtype is None else np.asarray(v, dtype=constraints.dtype) + except Exception as e: + raise ValueError(f"cannot convert {type(v).__name__!r} to a NumPy array") from e + + if arr.dtype == np.dtype("object"): + raise ValueError(f"unsupported element type {type(v).__name__!r} for array coercion") + + if ( + arr.ndim == 0 + and (constraints.ndim == 1 or constraints.ndim is None) + and constraints.scalar_to_1d + ): + arr = arr.reshape(1) + if constraints.ndim is not None and arr.ndim != constraints.ndim: + raise ValueError(f"expected {constraints.ndim}-D, got {arr.ndim}-D") + if constraints.shape is not None and tuple(arr.shape) != constraints.shape: + raise ValueError(f"expected shape {constraints.shape}, got {tuple(arr.shape)}") + if constraints.forbid_nan and np.any(np.isnan(arr)): + raise ValueError("array contains NaN") + + # enforce immutability of our Pydantic models + arr.flags.writeable = False + + return arr + + +def array_alias( + *, + dtype: Optional[Any] = None, + ndim: Optional[int] = None, + shape: Optional[tuple[int, ...]] = None, + forbid_nan: bool = True, + scalar_to_1d: bool = False, + strict: bool = False, +) -> Any: + constraints = ArrayConstraints( + dtype=dtype, + ndim=ndim, + shape=shape, + forbid_nan=forbid_nan, + scalar_to_1d=scalar_to_1d, + strict=strict, + ) + serializer = PlainSerializer(_auto_serializer, when_used="json") + + base_schema = { + "type": "ArrayLike", + "x-array-dtype": getattr(constraints.dtype, "str", None), + "x-array-ndim": constraints.ndim, + "x-array-shape": constraints.shape, + "x-array-forbid_nan": constraints.forbid_nan, + "x-array-scalar_to_1d": constraints.scalar_to_1d, + "x-array-strict": constraints.strict, + } + + return Annotated[ + np.ndarray, + 
ArrayLike = array_alias()
ArrayLikeStrict = array_alias(strict=True)

ArrayInt1D = array_alias(dtype=int, ndim=1, scalar_to_1d=True)

ArrayFloat = array_alias(dtype=float)
ArrayFloat1D = array_alias(dtype=float, ndim=1, scalar_to_1d=True)
ArrayFloat2D = array_alias(dtype=float, ndim=2)
ArrayFloat3D = array_alias(dtype=float, ndim=3)
ArrayFloat4D = array_alias(dtype=float, ndim=4)

ArrayComplex = array_alias(dtype=complex)
ArrayComplex1D = array_alias(dtype=complex, ndim=1, scalar_to_1d=True)
ArrayComplex2D = array_alias(dtype=complex, ndim=2)
ArrayComplex3D = array_alias(dtype=complex, ndim=3)
ArrayComplex4D = array_alias(dtype=complex, ndim=4)

TensorReal = array_alias(dtype=float, ndim=2, shape=(3, 3))
MatrixReal4x4 = array_alias(dtype=float, ndim=2, shape=(4, 4))

""" Complex Values """


def _parse_complex(value: Any) -> complex:
    """Best-effort conversion of ``value`` to ``complex``; unrecognized inputs
    are returned unchanged so pydantic can report its own type error."""
    if isinstance(value, complex):
        return value

    # serialized form: {"real": ..., "imag": ...}
    if isinstance(value, dict) and "real" in value and "imag" in value:
        return complex(value["real"], value["imag"])

    if isinstance(value, numbers.Number):
        return complex(value)

    # objects advertising the __complex__ protocol
    if hasattr(value, "__complex__"):
        try:
            return complex(value.__complex__())
        except Exception:
            pass

    # a (real, imag) pair
    if isinstance(value, (list, tuple)) and len(value) == 2:
        real_part, imag_part = value
        return complex(real_part, imag_part)

    return value


Complex = Annotated[
    complex,
    BeforeValidator(_parse_complex),
    PlainSerializer(
        lambda z, _: {"real": z.real, "imag": z.imag},
        when_used="json",
        return_type=dict,
    ),
]

""" symmetry """

Symmetry = Annotated[Literal[0, -1, 1], BeforeValidator(_dtype2python)]
ScalarSymmetry = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)]

""" geometric """

Size1D = NonNegativeFloat
Size = tuple[Size1D, Size1D, Size1D]
Coordinate = tuple[float, float, float]
CoordinateOptional = tuple[Optional[float], Optional[float], Optional[float]]
Coordinate2D = tuple[float, float]
Bound = tuple[Coordinate, Coordinate]
GridSize = Union[PositiveFloat, tuple[PositiveFloat, ...]]
Axis = Annotated[Literal[0, 1, 2], BeforeValidator(_dtype2python)]
Axis2D = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)]
Shapely = BaseGeometry
PlanePosition = Literal["bottom", "middle", "top"]
ClipOperationType = Literal["union", "intersection", "difference", "symmetric_difference"]
BoxSurface = Literal["x-", "x+", "y-", "y+", "z-", "z+"]
LengthUnit = Literal["nm", "μm", "um", "mm", "cm", "m", "mil", "in"]
PriorityMode = Literal["equal", "conductor"]

""" medium """

# custom medium
InterpMethod = Literal["nearest", "linear"]

PoleAndResidue = tuple[Complex, Complex]
PolesAndResidues = tuple[PoleAndResidue, ...]
FreqBoundMax = float
FreqBoundMin = float
FreqBound = tuple[FreqBoundMin, FreqBoundMax]

PermittivityComponent = Literal["xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz"]

""" sources """

Polarization = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
Direction = Literal["+", "-"]

""" monitors """


def _list_to_tuple(value: Any) -> Any:
    """Normalize a ``list`` input to a ``tuple``; all other inputs pass through."""
    return tuple(value) if isinstance(value, list) else value


EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
FreqArray = ArrayFloat1D
ObsGridArray = FreqArray
PolarizationBasis = Literal["linear", "circular"]
AuxField = Literal["Nfx", "Nfy", "Nfz"]

""" plotting """

# Axes is None when matplotlib is not installed (see import guard above)
Ax = Axes
PlotVal = Literal["real", "imag", "abs"]
FieldVal = Literal["real", "imag", "abs", "abs^2", "phase"]
RealFieldVal = Literal["real", "abs", "abs^2"]
PlotScale = Literal["lin", "dB", "log", "symlog"]
ColormapType = Literal["divergent", "sequential", "cyclic"]

""" mode solver """

ModeSolverType = Literal["tensorial", "diagonal"]
EpsSpecType = Literal["diagonal", "tensorial_real", "tensorial_complex"]
ModeClassification = Literal["TEM", "quasi-TEM", "TE", "TM", "Hybrid"]

""" mode tracking """

TrackFreq = Literal["central", "lowest", "highest"]
Literal["central", "lowest", "highest"] + +""" lumped elements""" + +LumpDistType = Literal["off", "laterally_only", "on"] + +""" dataset """ + +xyz = Literal["x", "y", "z"] +UnitsZBF = Literal["mm", "cm", "in", "m"] + +""" sentinel """ +Undefined = object() diff --git a/tidy3d/_common/components/types/third_party.py b/tidy3d/_common/components/types/third_party.py new file mode 100644 index 0000000000..1530d2f088 --- /dev/null +++ b/tidy3d/_common/components/types/third_party.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from typing import Any + +from tidy3d._common.packaging import check_import + +# TODO Complicated as trimesh should be a core package unless decoupled implementation types in functional location. +# We need to restructure. +if check_import("trimesh"): + import trimesh # Won't add much overhead if already imported + + TrimeshType = trimesh.Trimesh +else: + TrimeshType = Any diff --git a/tidy3d/_common/components/types/utils.py b/tidy3d/_common/components/types/utils.py new file mode 100644 index 0000000000..333cdb807e --- /dev/null +++ b/tidy3d/_common/components/types/utils.py @@ -0,0 +1,33 @@ +"""Utilities for type & schema creation.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic_core import core_schema + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler + + +def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None: + """Adds a schema to the ``arbitrary_type`` class without subclassing.""" + + @classmethod + def __get_pydantic_core_schema__( + cls: type, _source_type: type, _handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + def _serialize(value: Any, info: core_schema.SerializationInfo) -> Any: + from tidy3d._common.components.autograd.utils import get_static + from tidy3d._common.components.types.base import _auto_serializer + + return _auto_serializer(get_static(value), info) + + return core_schema.any_schema( + metadata={"title": 
title, "type": field_type_str}, + serialization=core_schema.plain_serializer_function_ser_schema( + _serialize, info_arg=True + ), + ) + + arbitrary_type.__get_pydantic_core_schema__ = __get_pydantic_core_schema__ diff --git a/tidy3d/_common/components/validators.py b/tidy3d/_common/components/validators.py new file mode 100644 index 0000000000..ded4659f02 --- /dev/null +++ b/tidy3d/_common/components/validators.py @@ -0,0 +1,122 @@ +"""Defines various validation functions that get used to ensure inputs are legit""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, TypeVar, Union + +import numpy as np +from numpy.typing import NDArray +from pydantic import field_validator + +from tidy3d._common.components.autograd.utils import get_static, hasbox +from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP +from tidy3d._common.exceptions import ValidationError +from tidy3d._common.log import log + +if TYPE_CHECKING: + from typing import Callable, Optional + + from pydantic import FieldValidationInfo + +T = TypeVar("T") + +""" Explanation of pydantic validators: + + Validators are class methods that are added to the models to validate their fields (kwargs). + The functions on this page return validators based on config arguments + and are generally in multiple components of tidy3d. + The inner functions (validators) are decorated with @pydantic.validator, which is configured. + First argument is the string of the field being validated in the model. + ``allow_reuse`` lets us use the validator in more than one model. + ``always`` makes sure if the model is changed, the validator gets called again. + + The function being decorated by @pydantic.validator generally takes + ``cls`` the class that the validator is added to. + ``val`` the value of the field being validated. + ``values`` a dictionary containing all of the other fields of the model. 
+ It is important to note that the validator only has access to fields that are defined + before the field being validated. + Fields defined under the validated field will not be in ``values``. + + All validators generally should throw an exception if the validation fails + and return val if it passes. + Sometimes, we can use validators to change ``val`` or ``values``, + but this should be done with caution as it can be hard to reason about. + + To add a validator from this file to the pydantic model, + put it in the model's main body and assign it to a variable (class method). + For example ``_plane_validator = assert_plane()``. + Note, if the assigned name ``_plane_validator`` is used later on for another validator, say, + the original validator will be overwritten so be aware of this. + + For more details: `Pydantic Validators `_ +""" + +# Lowest frequency supported (Hz) +MIN_FREQUENCY = 1e5 + +FloatArray = Union[Sequence[float], NDArray] + + +def _assert_min_freq(freqs: FloatArray, msg_start: str) -> None: + """Check if all ``freqs`` are above the minimum frequency.""" + if np.min(freqs) < MIN_FREQUENCY: + raise ValidationError( + f"{msg_start} must be no lower than {MIN_FREQUENCY:.0e} Hz. " + "Note that the unit of frequency is 'Hz'." + ) + + +def _warn_unsupported_traced_argument( + *names: str, +) -> Callable[[type, Any, FieldValidationInfo], Any]: + @field_validator(*names) + @classmethod + def _warn_traced_arg(cls: type, val: Any, info: FieldValidationInfo) -> Any: + if hasbox(val): + log.warning( + f"Field '{info.field_name}' of '{cls.__name__}' received an autograd tracer " + f"(i.e., a value being tracked for automatic differentiation). " + f"Automatic differentiation through this field is unsupported, " + f"so the tracer has been converted to its static value. " + f"If you want to avoid this warning, you manually unbox the value " + f"using the 'autograd.tracer.getval' function before passing it to Tidy3D." 
+ ) + return get_static(val) + return val + + return _warn_traced_arg + + +def warn_if_dataset_none( + field_name: str, +) -> Callable[[type, Optional[dict[str, Any]]], Optional[dict[str, Any]]]: + """Warn if a Dataset field has None in its dictionary.""" + + @field_validator(field_name, mode="before") + @classmethod + def _warn_if_none(cls: type, val: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + """Warn if the DataArrays fail to load.""" + if isinstance(val, dict): + if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): + log.warning(f"Loading {field_name} without data.", custom_loc=[field_name]) + return None + return val + + return _warn_if_none + + +# FIXME: this validator doesn't do anything +def validate_name_str() -> Callable[[type, Optional[str]], Optional[str]]: + """make sure the name does not include [, ] (used for default names)""" + + @field_validator("name") + @classmethod + def field_has_unique_names(cls: type, val: Optional[str]) -> Optional[str]: + """raise exception if '[' or ']' in name""" + # if val and ('[' in val or ']' in val): + # raise SetupError(f"'[' or ']' not allowed in name: {val} (used for defaults)") + return val + + return field_has_unique_names diff --git a/tidy3d/_common/components/viz/__init__.py b/tidy3d/_common/components/viz/__init__.py new file mode 100644 index 0000000000..307d61142c --- /dev/null +++ b/tidy3d/_common/components/viz/__init__.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from tidy3d._common.components.viz.axes_utils import ( + add_ax_if_none, + equal_aspect, + make_ax, + set_default_labels_and_title, +) +from tidy3d._common.components.viz.descartes import Polygon, polygon_patch, polygon_path +from tidy3d._common.components.viz.flex_style import ( + apply_tidy3d_params, + restore_matplotlib_rcparams, +) +from tidy3d._common.components.viz.plot_params import ( + AbstractPlotParams, + PathPlotParams, + PlotParams, + plot_params_abc, + plot_params_absorber, + 
plot_params_bloch, + plot_params_fluid, + plot_params_geometry, + plot_params_grid, + plot_params_lumped_element, + plot_params_min_grid_size, + plot_params_monitor, + plot_params_override_structures, + plot_params_pec, + plot_params_pmc, + plot_params_pml, + plot_params_source, + plot_params_structure, + plot_params_symmetry, +) +from tidy3d._common.components.viz.plot_sim_3d import plot_scene_3d, plot_sim_3d +from tidy3d._common.components.viz.styles import ( + ARROW_ALPHA, + ARROW_COLOR_ABSORBER, + ARROW_COLOR_MONITOR, + ARROW_COLOR_POLARIZATION, + ARROW_COLOR_SOURCE, + ARROW_LENGTH, + FLEXCOMPUTE_COLORS, + MEDIUM_CMAP, + PLOT_BUFFER, + STRUCTURE_EPS_CMAP, + STRUCTURE_EPS_CMAP_R, + STRUCTURE_HEAT_COND_CMAP, + arrow_style, +) +from tidy3d._common.components.viz.visualization_spec import MATPLOTLIB_IMPORTED, VisualizationSpec + +apply_tidy3d_params() + +__all__ = [ + "ARROW_ALPHA", + "ARROW_COLOR_ABSORBER", + "ARROW_COLOR_MONITOR", + "ARROW_COLOR_POLARIZATION", + "ARROW_COLOR_SOURCE", + "ARROW_LENGTH", + "FLEXCOMPUTE_COLORS", + "MATPLOTLIB_IMPORTED", + "MEDIUM_CMAP", + "PLOT_BUFFER", + "STRUCTURE_EPS_CMAP", + "STRUCTURE_EPS_CMAP_R", + "STRUCTURE_HEAT_COND_CMAP", + "AbstractPlotParams", + "PathPlotParams", + "PlotParams", + "Polygon", + "VisualizationSpec", + "add_ax_if_none", + "arrow_style", + "equal_aspect", + "make_ax", + "plot_params_abc", + "plot_params_absorber", + "plot_params_bloch", + "plot_params_fluid", + "plot_params_geometry", + "plot_params_grid", + "plot_params_lumped_element", + "plot_params_min_grid_size", + "plot_params_monitor", + "plot_params_override_structures", + "plot_params_pec", + "plot_params_pmc", + "plot_params_pml", + "plot_params_source", + "plot_params_structure", + "plot_params_symmetry", + "plot_scene_3d", + "plot_sim_3d", + "polygon_patch", + "polygon_path", + "restore_matplotlib_rcparams", + "set_default_labels_and_title", +] diff --git a/tidy3d/_common/components/viz/axes_utils.py b/tidy3d/_common/components/viz/axes_utils.py 
new file mode 100644 index 0000000000..4a3e342a7b --- /dev/null +++ b/tidy3d/_common/components/viz/axes_utils.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from functools import wraps +from typing import TYPE_CHECKING + +from tidy3d._common.components.types.base import LengthUnit +from tidy3d._common.constants import UnitScaling +from tidy3d._common.exceptions import Tidy3dKeyError + +if TYPE_CHECKING: + from typing import Callable, ParamSpec, TypeVar + + import matplotlib.ticker as ticker + from matplotlib.axes import Axes + + P = ParamSpec("P") + T = TypeVar("T", bound=Callable[..., Axes]) + from typing import Optional + + from tidy3d._common.components.types.base import Ax, Axis + + +def _create_unit_aware_locator() -> ticker.Locator: + """Create UnitAwareLocator lazily due to matplotlib import restrictions.""" + import matplotlib.ticker as ticker + + class UnitAwareLocator(ticker.Locator): + """Custom tick locator that places ticks at nice positions in the target unit.""" + + def __init__(self, scale_factor: float) -> None: + """ + Parameters + ---------- + scale_factor : float + Factor to convert from micrometers to the target unit. 
+ """ + super().__init__() + self.scale_factor = scale_factor + + def __call__(self) -> list[float]: + vmin, vmax = self.axis.get_view_interval() + return self.tick_values(vmin, vmax) + + def view_limits(self, vmin: float, vmax: float) -> tuple[float, float]: + """Override to prevent matplotlib from adjusting our limits.""" + return vmin, vmax + + def tick_values(self, vmin: float, vmax: float) -> list[float]: + # convert the view range to the target unit + vmin_unit = vmin * self.scale_factor + vmax_unit = vmax * self.scale_factor + + # tolerance for floating point comparisons in target unit + unit_range = vmax_unit - vmin_unit + unit_tol = unit_range * 1e-8 + + locator = ticker.MaxNLocator(nbins=11, prune=None, min_n_ticks=2) + + ticks_unit = locator.tick_values(vmin_unit, vmax_unit) + + # ensure we have ticks that cover the full range + if len(ticks_unit) > 0: + if ticks_unit[0] > vmin_unit + unit_tol or ticks_unit[-1] < vmax_unit - unit_tol: + # try with fewer bins to get better coverage + for n in [10, 9, 8, 7, 6, 5]: + locator = ticker.MaxNLocator(nbins=n, prune=None, min_n_ticks=2) + ticks_unit = locator.tick_values(vmin_unit, vmax_unit) + if ( + len(ticks_unit) >= 3 + and ticks_unit[0] <= vmin_unit + unit_tol + and ticks_unit[-1] >= vmax_unit - unit_tol + ): + break + + # if still no good coverage, manually ensure edge coverage + if len(ticks_unit) > 0: + if ( + ticks_unit[0] > vmin_unit + unit_tol + or ticks_unit[-1] < vmax_unit - unit_tol + ): + # find a reasonable step size from existing ticks + if len(ticks_unit) > 1: + step = ticks_unit[1] - ticks_unit[0] + else: + step = unit_range / 5 + + # extend the range to ensure coverage + extended_min = vmin_unit - step + extended_max = vmax_unit + step + + # try one more time with extended range + locator = ticker.MaxNLocator(nbins=8, prune=None, min_n_ticks=2) + ticks_unit = locator.tick_values(extended_min, extended_max) + + # filter to reasonable bounds around the original range + ticks_unit = [ + t + for t 
in ticks_unit + if t >= vmin_unit - step / 2 and t <= vmax_unit + step / 2 + ] + + # convert the nice ticks back to the original data unit (micrometers) + ticks_um = ticks_unit / self.scale_factor + + # filter to ensure ticks are within bounds (with small tolerance) + eps = (vmax - vmin) * 1e-8 + return [tick for tick in ticks_um if vmin - eps <= tick <= vmax + eps] + + return UnitAwareLocator + + +def make_ax() -> Ax: + """makes an empty ``ax``.""" + import matplotlib.pyplot as plt + + _, ax = plt.subplots(1, 1, tight_layout=True) + return ax + + +def add_ax_if_none(plot: T) -> T: + """Decorates ``plot(*args, **kwargs, ax=None)`` function. + if ax=None in the function call, creates an ax and feeds it to rest of function. + """ + + @wraps(plot) + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: + """New plot function using a generated ax if None.""" + if kwargs.get("ax") is None: + ax = make_ax() + kwargs["ax"] = ax + return plot(*args, **kwargs) + + return _plot + + +def equal_aspect(plot: T) -> T: + """Decorates a plotting function returning a matplotlib axes. + Ensures the aspect ratio of the returned axes is set to equal. + Useful for 2D plots, like sim.plot() or sim_data.plot_fields() + """ + + @wraps(plot) + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: + """New plot function with equal aspect ratio axes returned.""" + ax = plot(*args, **kwargs) + ax.set_aspect("equal") + return ax + + return _plot + + +def set_default_labels_and_title( + axis_labels: tuple[str, str], + axis: Axis, + position: float, + ax: Ax, + plot_length_units: Optional[LengthUnit] = None, +) -> Ax: + """Adds axis labels and title to plots involving spatial dimensions. + When the ``plot_length_units`` are specified, the plot axes are scaled, and + the title and axis labels include the desired units. 
+ """ + + import matplotlib.ticker as ticker + + xlabel = axis_labels[0] + ylabel = axis_labels[1] + if plot_length_units is not None: + if plot_length_units not in UnitScaling: + raise Tidy3dKeyError( + f"Provided units '{plot_length_units}' are not supported. " + f"Please choose one of '{LengthUnit}'." + ) + ax.set_xlabel(f"{xlabel} ({plot_length_units})") + ax.set_ylabel(f"{ylabel} ({plot_length_units})") + + scale_factor = UnitScaling[plot_length_units] + + # for imperial units, use custom tick locator for nice tick positions + if plot_length_units in ["mil", "in"]: + UnitAwareLocator = _create_unit_aware_locator() + x_locator = UnitAwareLocator(scale_factor) + y_locator = UnitAwareLocator(scale_factor) + ax.xaxis.set_major_locator(x_locator) + ax.yaxis.set_major_locator(y_locator) + + formatter = ticker.FuncFormatter(lambda y, _: f"{y * scale_factor:.2f}") + + ax.xaxis.set_major_formatter(formatter) + ax.yaxis.set_major_formatter(formatter) + + position_scaled = position * scale_factor + ax.set_title(f"cross section at {'xyz'[axis]}={position_scaled:.2f} ({plot_length_units})") + else: + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") + return ax diff --git a/tidy3d/_common/components/viz/descartes.py b/tidy3d/_common/components/viz/descartes.py new file mode 100644 index 0000000000..572dfc44ba --- /dev/null +++ b/tidy3d/_common/components/viz/descartes.py @@ -0,0 +1,113 @@ +"""================================================================================================= +Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0 + +Copyright Flexcompute 2022 + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from numpy.typing import NDArray + from shapely.geometry.base import BaseGeometry + +try: + from matplotlib.patches import PathPatch + from matplotlib.path import Path +except ImportError: + pass +from numpy import array, concatenate, ones + + +class Polygon: + """Adapt Shapely polygons to a common interface""" + + def __init__(self, context: dict[str, Any]) -> None: + if isinstance(context, dict): + self.context = context["coordinates"] + else: + self.context = context + + @property + def exterior(self) -> Any: + """Get polygon exterior.""" + value = getattr(self.context, "exterior", None) + if value is None: + value = self.context[0] + return value + + @property + def interiors(self) -> Any: + """Get polygon interiors.""" + value = getattr(self.context, "interiors", None) + if value is None: + value = self.context[1:] + return value + + +def polygon_path(polygon: BaseGeometry) -> Path: + """Constructs a compound matplotlib path from a Shapely or GeoJSON-like + geometric object""" + + def coding(obj: Any) -> NDArray: + # The codes will be all "LINETO" commands, except for "MOVETO"s at the + # beginning of each subpath + crds = getattr(obj, "coords", None) + if crds is None: + crds = obj + n = len(crds) + vals = ones(n, dtype=Path.code_type) * Path.LINETO + if len(vals) > 0: + vals[0] = Path.MOVETO + return vals + + ptype = polygon.geom_type + if 
ptype == "Polygon": + polygon = [Polygon(polygon)] + elif ptype == "MultiPolygon": + polygon = [Polygon(p) for p in polygon.geoms] + + vertices = concatenate( + [ + concatenate( + [array(t.exterior.coords)[:, :2]] + [array(r.coords)[:, :2] for r in t.interiors] + ) + for t in polygon + ] + ) + codes = concatenate( + [concatenate([coding(t.exterior)] + [coding(r) for r in t.interiors]) for t in polygon] + ) + + return Path(vertices, codes) + + +def polygon_patch(polygon: BaseGeometry, **kwargs: Any) -> PathPatch: + """Constructs a matplotlib patch from a geometric object + + The ``polygon`` may be a Shapely or GeoJSON-like object with or without holes. + The ``kwargs`` are those supported by the matplotlib.patches.Polygon class + constructor. Returns an instance of matplotlib.patches.PathPatch. + + Example + ------- + >>> b = Point(0, 0).buffer(1.0) # doctest: +SKIP + >>> patch = PolygonPatch(b, fc='blue', ec='blue', alpha=0.5) # doctest: +SKIP + >>> axis.add_patch(patch) # doctest: +SKIP + + """ + return PathPatch(polygon_path(polygon), **kwargs) + + +"""End descartes modification +=================================================================================================""" diff --git a/tidy3d/_common/components/viz/flex_color_palettes.py b/tidy3d/_common/components/viz/flex_color_palettes.py new file mode 100644 index 0000000000..7fc1454a0b --- /dev/null +++ b/tidy3d/_common/components/viz/flex_color_palettes.py @@ -0,0 +1,3306 @@ +from __future__ import annotations + +SEQUENTIAL_PALETTES_HEX = { + "flex_turquoise_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfb", + "#fafbfa", + "#f9fafa", + "#f8f9f9", + "#f7f8f8", + "#f6f7f7", + "#f5f6f6", + "#f3f5f5", + "#f2f4f4", + "#f1f3f3", + "#f0f3f2", + "#eff2f1", + "#eef1f1", + "#edf0f0", + "#ecefef", + "#ebeeee", + "#eaeded", + "#e9edec", + "#e8eceb", + "#e7ebeb", + "#e6eaea", + "#e5e9e9", + "#e4e8e8", + "#e3e7e7", + "#e2e7e6", + "#e1e6e5", + "#e0e5e5", + "#dfe4e4", + "#dee3e3", + "#dde2e2", + 
"#dce2e1", + "#dbe1e0", + "#dae0df", + "#d9dfdf", + "#d8dede", + "#d7dedd", + "#d6dddc", + "#d5dcdb", + "#d4dbdb", + "#d3dada", + "#d2dad9", + "#d1d9d8", + "#d1d8d7", + "#d0d7d6", + "#cfd6d6", + "#ced6d5", + "#cdd5d4", + "#ccd4d3", + "#cbd3d2", + "#cad2d2", + "#c9d2d1", + "#c8d1d0", + "#c7d0cf", + "#c6cfce", + "#c5cece", + "#c4cecd", + "#c3cdcc", + "#c2cccb", + "#c1cbca", + "#c0cbca", + "#bfcac9", + "#bec9c8", + "#bec8c7", + "#bdc8c7", + "#bcc7c6", + "#bbc6c5", + "#bac5c4", + "#b9c5c3", + "#b8c4c3", + "#b7c3c2", + "#b6c2c1", + "#b5c2c0", + "#b4c1c0", + "#b3c0bf", + "#b2bfbe", + "#b2bfbd", + "#b1bebd", + "#b0bdbc", + "#afbcbb", + "#aebcba", + "#adbbba", + "#acbab9", + "#abbab8", + "#aab9b7", + "#a9b8b7", + "#a9b7b6", + "#a8b7b5", + "#a7b6b4", + "#a6b5b4", + "#a5b4b3", + "#a4b4b2", + "#a3b3b2", + "#a2b2b1", + "#a1b2b0", + "#a1b1af", + "#a0b0af", + "#9fb0ae", + "#9eafad", + "#9daeac", + "#9cadac", + "#9badab", + "#9aacaa", + "#99abaa", + "#99aba9", + "#98aaa8", + "#97a9a7", + "#96a9a7", + "#95a8a6", + "#94a7a5", + "#93a6a5", + "#92a6a4", + "#92a5a3", + "#91a4a2", + "#90a4a2", + "#8fa3a1", + "#8ea2a0", + "#8da2a0", + "#8ca19f", + "#8ca09e", + "#8ba09e", + "#8a9f9d", + "#899e9c", + "#889e9c", + "#879d9b", + "#869c9a", + "#869c9a", + "#859b99", + "#849a98", + "#839a97", + "#829997", + "#819896", + "#809895", + "#809795", + "#7f9694", + "#7e9693", + "#7d9593", + "#7c9492", + "#7b9491", + "#7a9391", + "#7a9290", + "#79928f", + "#78918f", + "#77908e", + "#76908d", + "#758f8d", + "#758f8c", + "#748e8b", + "#738d8b", + "#728d8a", + "#718c89", + "#708b89", + "#708b88", + "#6f8a87", + "#6e8987", + "#6d8986", + "#6c8885", + "#6b8885", + "#6a8784", + "#6a8684", + "#698683", + "#688582", + "#678482", + "#668481", + "#658380", + "#658280", + "#64827f", + "#63817e", + "#62817e", + "#61807d", + "#607f7c", + "#607f7c", + "#5f7e7b", + "#5e7d7b", + "#5d7d7a", + "#5c7c79", + "#5b7c79", + "#5b7b78", + "#5a7a77", + "#597a77", + "#587976", + "#577975", + "#567875", + "#567774", + "#557774", 
+ "#547673", + "#537572", + "#527572", + "#517471", + "#507470", + "#507370", + "#4f726f", + "#4e726f", + "#4d716e", + "#4c716d", + "#4b706d", + "#4b6f6c", + "#4a6f6b", + "#496e6b", + "#486e6a", + "#476d6a", + "#466c69", + "#456c68", + "#446b68", + "#446b67", + "#436a67", + "#426966", + "#416965", + "#406865", + "#3f6864", + "#3e6763", + "#3e6663", + "#3d6662", + "#3c6562", + "#3b6561", + "#3a6460", + "#396360", + "#38635f", + "#37625f", + "#36625e", + "#35615d", + "#35605d", + "#34605c", + "#335f5c", + "#325f5b", + "#315e5a", + "#305d5a", + "#2f5d59", + "#2e5c58", + "#2d5c58", + "#2c5b57", + "#2b5a57", + "#2a5a56", + "#295955", + "#285955", + "#275854", + "#265754", + "#255753", + "#245652", + "#235652", + "#225551", + "#215551", + "#205450", + "#1e534f", + "#1d534f", + "#1c524e", + "#1b524e", + "#1a514d", + "#18504c", + "#17504c", + "#164f4b", + "#144f4b", + "#134e4a", + ], + "flex_green_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfb", + "#f9fafa", + "#f8f9f9", + "#f7f8f8", + "#f6f7f7", + "#f5f6f6", + "#f4f5f5", + "#f3f5f3", + "#f2f4f2", + "#f1f3f1", + "#f0f2f0", + "#eff1ef", + "#eef0ee", + "#ecefed", + "#ebeeec", + "#eaedeb", + "#e9ecea", + "#e8ebe9", + "#e7eae8", + "#e6eae7", + "#e5e9e6", + "#e4e8e5", + "#e3e7e4", + "#e2e6e3", + "#e1e5e2", + "#e0e4e1", + "#dfe3e0", + "#dee3df", + "#dde2de", + "#dce1dd", + "#dbe0dc", + "#dadfdb", + "#d8deda", + "#d7ddd9", + "#d6dcd8", + "#d5dcd7", + "#d4dbd6", + "#d3dad5", + "#d2d9d4", + "#d1d8d3", + "#d0d7d2", + "#cfd6d1", + "#ced6d0", + "#cdd5cf", + "#ccd4ce", + "#cbd3ce", + "#cad2cd", + "#c9d1cc", + "#c8d1cb", + "#c7d0ca", + "#c6cfc9", + "#c5cec8", + "#c4cdc7", + "#c3cdc6", + "#c2ccc5", + "#c1cbc4", + "#c0cac3", + "#bfc9c2", + "#bec9c1", + "#bdc8c0", + "#bcc7bf", + "#bbc6be", + "#bac5bd", + "#b9c5bd", + "#b9c4bc", + "#b8c3bb", + "#b7c2ba", + "#b6c1b9", + "#b5c1b8", + "#b4c0b7", + "#b3bfb6", + "#b2beb5", + "#b1bdb4", + "#b0bdb3", + "#afbcb2", + "#aebbb1", + "#adbab1", + "#acbab0", + "#abb9af", + 
"#aab8ae", + "#a9b7ad", + "#a8b7ac", + "#a7b6ab", + "#a6b5aa", + "#a5b4a9", + "#a5b4a8", + "#a4b3a8", + "#a3b2a7", + "#a2b1a6", + "#a1b1a5", + "#a0b0a4", + "#9fafa3", + "#9eaea2", + "#9daea1", + "#9cada0", + "#9baca0", + "#9aab9f", + "#99ab9e", + "#99aa9d", + "#98a99c", + "#97a89b", + "#96a89a", + "#95a799", + "#94a699", + "#93a598", + "#92a597", + "#91a496", + "#90a395", + "#90a394", + "#8fa293", + "#8ea193", + "#8da092", + "#8ca091", + "#8b9f90", + "#8a9e8f", + "#899e8e", + "#889d8d", + "#879c8d", + "#879b8c", + "#869b8b", + "#859a8a", + "#849989", + "#839988", + "#829888", + "#819787", + "#809786", + "#809685", + "#7f9584", + "#7e9483", + "#7d9483", + "#7c9382", + "#7b9281", + "#7a9280", + "#79917f", + "#79907e", + "#78907e", + "#778f7d", + "#768e7c", + "#758e7b", + "#748d7a", + "#738c79", + "#728c79", + "#728b78", + "#718a77", + "#708a76", + "#6f8975", + "#6e8875", + "#6d8774", + "#6c8773", + "#6c8672", + "#6b8571", + "#6a8571", + "#698470", + "#68836f", + "#67836e", + "#66826d", + "#66816d", + "#65816c", + "#64806b", + "#637f6a", + "#627f69", + "#617e69", + "#607d68", + "#607d67", + "#5f7c66", + "#5e7c65", + "#5d7b65", + "#5c7a64", + "#5b7a63", + "#5a7962", + "#5a7861", + "#597861", + "#587760", + "#57765f", + "#56765e", + "#55755d", + "#55745d", + "#54745c", + "#53735b", + "#52725a", + "#51725a", + "#507159", + "#4f7058", + "#4f7057", + "#4e6f56", + "#4d6e56", + "#4c6e55", + "#4b6d54", + "#4a6d53", + "#4a6c53", + "#496b52", + "#486b51", + "#476a50", + "#466950", + "#45694f", + "#44684e", + "#44674d", + "#43674c", + "#42664c", + "#41654b", + "#40654a", + "#3f6449", + "#3e6449", + "#3e6348", + "#3d6247", + "#3c6246", + "#3b6146", + "#3a6045", + "#396044", + "#385f43", + "#385e43", + "#375e42", + "#365d41", + "#355c40", + "#345c40", + "#335b3f", + "#325b3e", + "#315a3d", + "#30593d", + "#30593c", + "#2f583b", + "#2e573a", + "#2d573a", + "#2c5639", + "#2b5538", + "#2a5537", + "#295437", + "#285436", + "#275335", + "#265234", + "#265234", + "#255133", + "#245032", 
+ "#235031", + "#224f31", + "#214e30", + "#204e2f", + "#1f4d2e", + "#1e4c2e", + "#1d4c2d", + "#1c4b2c", + "#1b4b2b", + "#1a4a2b", + "#18492a", + "#174929", + "#164828", + "#154728", + "#144727", + "#134626", + "#124525", + "#104525", + "#0f4424", + ], + "flex_blue_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fbfcfc", + "#fafbfb", + "#f9fafa", + "#f8f9f9", + "#f7f7f8", + "#f6f6f8", + "#f4f5f7", + "#f3f4f6", + "#f2f3f5", + "#f1f2f4", + "#f0f1f3", + "#eff0f2", + "#eeeff1", + "#eceef0", + "#ebedf0", + "#eaecef", + "#e9ebee", + "#e8eaed", + "#e7e9ec", + "#e6e8eb", + "#e4e7ea", + "#e3e6ea", + "#e2e5e9", + "#e1e4e8", + "#e0e3e7", + "#dfe2e6", + "#dee1e5", + "#dde0e5", + "#dcdfe4", + "#dadee3", + "#d9dde2", + "#d8dce1", + "#d7dbe0", + "#d6dae0", + "#d5d9df", + "#d4d8de", + "#d3d7dd", + "#d2d6dc", + "#d1d5dc", + "#d0d4db", + "#cfd3da", + "#ced2d9", + "#ccd1d9", + "#cbd0d8", + "#cacfd7", + "#c9ced6", + "#c8cdd5", + "#c7ccd5", + "#c6ccd4", + "#c5cbd3", + "#c4cad2", + "#c3c9d2", + "#c2c8d1", + "#c1c7d0", + "#c0c6cf", + "#bfc5cf", + "#bec4ce", + "#bdc3cd", + "#bcc2cd", + "#bbc1cc", + "#bac0cb", + "#b9c0ca", + "#b8bfca", + "#b7bec9", + "#b6bdc8", + "#b5bcc7", + "#b4bbc7", + "#b3bac6", + "#b2b9c5", + "#b1b8c5", + "#b0b7c4", + "#afb7c3", + "#aeb6c3", + "#adb5c2", + "#acb4c1", + "#abb3c1", + "#aab2c0", + "#a9b1bf", + "#a8b0be", + "#a7b0be", + "#a6afbd", + "#a5aebc", + "#a4adbc", + "#a3acbb", + "#a2abba", + "#a1aaba", + "#a0aab9", + "#9fa9b8", + "#9ea8b8", + "#9da7b7", + "#9ca6b7", + "#9ba5b6", + "#9aa4b5", + "#99a4b5", + "#98a3b4", + "#97a2b3", + "#96a1b3", + "#95a0b2", + "#949fb1", + "#939fb1", + "#929eb0", + "#919db0", + "#909caf", + "#8f9bae", + "#8f9aae", + "#8e9aad", + "#8d99ac", + "#8c98ac", + "#8b97ab", + "#8a96ab", + "#8996aa", + "#8895a9", + "#8794a9", + "#8693a8", + "#8592a8", + "#8492a7", + "#8391a6", + "#8290a6", + "#818fa5", + "#818ea5", + "#808ea4", + "#7f8da3", + "#7e8ca3", + "#7d8ba2", + "#7c8aa2", + "#7b8aa1", + "#7a89a1", + "#7988a0", + "#78879f", + 
"#77869f", + "#76869e", + "#76859e", + "#75849d", + "#74839d", + "#73829c", + "#72829b", + "#71819b", + "#70809a", + "#6f7f9a", + "#6e7f99", + "#6d7e99", + "#6c7d98", + "#6c7c98", + "#6b7c97", + "#6a7b97", + "#697a96", + "#687995", + "#677895", + "#667894", + "#657794", + "#647693", + "#637593", + "#637592", + "#627492", + "#617391", + "#607291", + "#5f7290", + "#5e7190", + "#5d708f", + "#5c6f8f", + "#5b6f8e", + "#5b6e8e", + "#5a6d8d", + "#596c8c", + "#586b8c", + "#576b8b", + "#566a8b", + "#55698a", + "#54688a", + "#536889", + "#536789", + "#526688", + "#516588", + "#506587", + "#4f6487", + "#4e6386", + "#4d6286", + "#4c6285", + "#4b6185", + "#4a6084", + "#4a5f84", + "#495f83", + "#485e83", + "#475d83", + "#465d82", + "#455c82", + "#445b81", + "#435a81", + "#425a80", + "#425980", + "#41587f", + "#40577f", + "#3f577e", + "#3e567e", + "#3d557d", + "#3c547d", + "#3b547c", + "#3a537c", + "#39527b", + "#39517b", + "#38517b", + "#37507a", + "#364f7a", + "#354e79", + "#344e79", + "#334d78", + "#324c78", + "#314b77", + "#304b77", + "#2f4a76", + "#2e4976", + "#2d4876", + "#2c4875", + "#2c4775", + "#2b4674", + "#2a4574", + "#294473", + "#284473", + "#274373", + "#264272", + "#254172", + "#244171", + "#234071", + "#223f70", + "#213e70", + "#203e70", + "#1f3d6f", + "#1e3c6f", + "#1d3b6e", + "#1c3a6e", + "#1b3a6e", + "#1a396d", + "#19386d", + "#17376c", + "#16366c", + "#15366c", + "#14356b", + "#13346b", + "#12336b", + "#10326a", + "#0f326a", + "#0e316a", + "#0d3069", + "#0b2f69", + "#0a2e68", + "#082d68", + "#072c68", + "#062c68", + "#042b67", + "#032a67", + "#022967", + "#012866", + "#002766", + ], + "flex_orange_seq": [ + "#ffffff", + "#fefefe", + "#fefdfd", + "#fdfdfc", + "#fdfcfb", + "#fcfbfa", + "#fbfafa", + "#fbf9f9", + "#faf9f8", + "#faf8f7", + "#f9f7f6", + "#f8f6f5", + "#f8f6f4", + "#f7f5f3", + "#f7f4f2", + "#f6f3f1", + "#f5f2f1", + "#f5f2f0", + "#f4f1ef", + "#f3f0ee", + "#f3efed", + "#f2efec", + "#f2eeeb", + "#f1edea", + "#f1ece9", + "#f0ece8", + "#f0ebe7", + 
"#efeae6", + "#efe9e5", + "#eee9e4", + "#eee8e3", + "#ede7e2", + "#ede6e1", + "#ece5e0", + "#ece5df", + "#ebe4de", + "#ebe3dd", + "#eae2dc", + "#eae2db", + "#e9e1da", + "#e9e0d9", + "#e9dfd8", + "#e8dfd7", + "#e8ded6", + "#e7ddd5", + "#e7dcd4", + "#e6dbd3", + "#e6dbd2", + "#e6dad1", + "#e5d9d0", + "#e5d8cf", + "#e4d8ce", + "#e4d7cd", + "#e3d6cc", + "#e3d5cb", + "#e3d5ca", + "#e2d4c9", + "#e2d3c8", + "#e1d2c7", + "#e1d2c6", + "#e0d1c5", + "#e0d0c4", + "#e0cfc3", + "#dfcfc2", + "#dfcec1", + "#decdc0", + "#deccbf", + "#deccbe", + "#ddcbbd", + "#ddcabc", + "#dcc9bb", + "#dcc9ba", + "#dcc8b9", + "#dbc7b8", + "#dbc6b8", + "#dbc6b7", + "#dac5b6", + "#dac4b5", + "#d9c4b4", + "#d9c3b3", + "#d9c2b2", + "#d8c1b1", + "#d8c1b0", + "#d7c0af", + "#d7bfae", + "#d7bead", + "#d6beac", + "#d6bdab", + "#d6bcaa", + "#d5bba9", + "#d5bba8", + "#d4baa7", + "#d4b9a6", + "#d4b9a5", + "#d3b8a4", + "#d3b7a3", + "#d3b6a2", + "#d2b6a1", + "#d2b5a0", + "#d2b49f", + "#d1b49e", + "#d1b39d", + "#d0b29c", + "#d0b19b", + "#d0b19a", + "#cfb099", + "#cfaf99", + "#cfaf98", + "#ceae97", + "#cead96", + "#ceac95", + "#cdac94", + "#cdab93", + "#cdaa92", + "#ccaa91", + "#cca990", + "#cca88f", + "#cba78e", + "#cba78d", + "#caa68c", + "#caa58b", + "#caa58a", + "#c9a489", + "#c9a388", + "#c9a387", + "#c8a286", + "#c8a185", + "#c8a085", + "#c7a084", + "#c79f83", + "#c79e82", + "#c69e81", + "#c69d80", + "#c69c7f", + "#c59c7e", + "#c59b7d", + "#c59a7c", + "#c4997b", + "#c4997a", + "#c49879", + "#c39778", + "#c39777", + "#c29676", + "#c29575", + "#c29575", + "#c19474", + "#c19373", + "#c19372", + "#c09271", + "#c09170", + "#c0906f", + "#bf906e", + "#bf8f6d", + "#bf8e6c", + "#be8e6b", + "#be8d6a", + "#be8c69", + "#bd8c68", + "#bd8b67", + "#bd8a67", + "#bc8a66", + "#bc8965", + "#bc8864", + "#bb8863", + "#bb8762", + "#bb8661", + "#ba8660", + "#ba855f", + "#ba845e", + "#b9835d", + "#b9835c", + "#b8825b", + "#b8815b", + "#b8815a", + "#b78059", + "#b77f58", + "#b77f57", + "#b67e56", + "#b67d55", + "#b67d54", + "#b57c53", 
+ "#b57b52", + "#b57b51", + "#b47a50", + "#b4794f", + "#b4794f", + "#b3784e", + "#b3774d", + "#b3774c", + "#b2764b", + "#b2754a", + "#b17549", + "#b17448", + "#b17347", + "#b07346", + "#b07245", + "#b07144", + "#af7144", + "#af7043", + "#af6f42", + "#ae6f41", + "#ae6e40", + "#ae6d3f", + "#ad6d3e", + "#ad6c3d", + "#ac6b3c", + "#ac6b3b", + "#ac6a3a", + "#ab6939", + "#ab6939", + "#ab6838", + "#aa6737", + "#aa6736", + "#aa6635", + "#a96534", + "#a96533", + "#a86432", + "#a86331", + "#a86330", + "#a7622f", + "#a7612e", + "#a7612d", + "#a6602c", + "#a65f2b", + "#a55f2a", + "#a55e2a", + "#a55d29", + "#a45d28", + "#a45c27", + "#a35b26", + "#a35b25", + "#a35a24", + "#a25923", + "#a25922", + "#a25821", + "#a15720", + "#a1571f", + "#a0561e", + "#a0551d", + "#a0551c", + "#9f541b", + "#9f531a", + "#9e5318", + "#9e5217", + "#9e5116", + "#9d5115", + "#9d5014", + "#9c4f13", + "#9c4f12", + "#9b4e10", + "#9b4d0f", + "#9b4d0e", + "#9a4c0c", + "#9a4b0b", + "#994b09", + "#994a08", + ], + "flex_red_seq": [ + "#ffffff", + "#fefefe", + "#fefdfd", + "#fdfcfc", + "#fcfbfb", + "#fcfafa", + "#fbf9f9", + "#faf8f8", + "#faf7f7", + "#f9f6f6", + "#f8f5f5", + "#f8f4f5", + "#f7f3f4", + "#f6f2f3", + "#f5f2f2", + "#f5f1f1", + "#f4f0f0", + "#f3efef", + "#f3eeee", + "#f2eded", + "#f1ecec", + "#f1ebec", + "#f0eaeb", + "#efe9ea", + "#efe8e9", + "#eee7e8", + "#eee6e7", + "#ede5e6", + "#ece4e6", + "#ece3e5", + "#ebe2e4", + "#ebe1e3", + "#eae0e2", + "#eae0e1", + "#e9dfe0", + "#e9dedf", + "#e8dddf", + "#e8dcde", + "#e7dbdd", + "#e7dadc", + "#e6d9db", + "#e6d8da", + "#e5d7d9", + "#e5d6d8", + "#e4d5d7", + "#e4d4d7", + "#e3d3d6", + "#e3d2d5", + "#e2d1d4", + "#e2d0d3", + "#e1d0d2", + "#e1cfd1", + "#e0ced1", + "#e0cdd0", + "#dfcccf", + "#dfcbce", + "#decacd", + "#dec9cc", + "#ddc8cb", + "#ddc7cb", + "#dcc6ca", + "#dcc5c9", + "#dbc4c8", + "#dbc4c7", + "#dbc3c6", + "#dac2c5", + "#dac1c5", + "#d9c0c4", + "#d9bfc3", + "#d8bec2", + "#d8bdc1", + "#d7bcc0", + "#d7bbc0", + "#d7babf", + "#d6babe", + "#d6b9bd", + 
"#d5b8bc", + "#d5b7bb", + "#d4b6bb", + "#d4b5ba", + "#d4b4b9", + "#d3b3b8", + "#d3b2b7", + "#d2b1b6", + "#d2b0b6", + "#d1b0b5", + "#d1afb4", + "#d1aeb3", + "#d0adb2", + "#d0acb1", + "#cfabb1", + "#cfaab0", + "#cfa9af", + "#cea8ae", + "#cea8ad", + "#cda7ad", + "#cda6ac", + "#cca5ab", + "#cca4aa", + "#cca3a9", + "#cba2a9", + "#cba1a8", + "#caa0a7", + "#caa0a6", + "#ca9fa5", + "#c99ea5", + "#c99da4", + "#c89ca3", + "#c89ba2", + "#c89aa1", + "#c799a1", + "#c799a0", + "#c6989f", + "#c6979e", + "#c6969d", + "#c5959d", + "#c5949c", + "#c4939b", + "#c4929a", + "#c49299", + "#c39199", + "#c39098", + "#c28f97", + "#c28e96", + "#c28d96", + "#c18c95", + "#c18c94", + "#c18b93", + "#c08a92", + "#c08992", + "#bf8891", + "#bf8790", + "#bf868f", + "#be858f", + "#be858e", + "#bd848d", + "#bd838c", + "#bd828c", + "#bc818b", + "#bc808a", + "#bb7f89", + "#bb7f88", + "#bb7e88", + "#ba7d87", + "#ba7c86", + "#b97b85", + "#b97a85", + "#b97984", + "#b87983", + "#b87882", + "#b87782", + "#b77681", + "#b77580", + "#b6747f", + "#b6737f", + "#b6737e", + "#b5727d", + "#b5717c", + "#b4707c", + "#b46f7b", + "#b46e7a", + "#b36d79", + "#b36d79", + "#b26c78", + "#b26b77", + "#b26a76", + "#b16976", + "#b16875", + "#b06774", + "#b06773", + "#b06673", + "#af6572", + "#af6471", + "#ae6371", + "#ae6270", + "#ae616f", + "#ad616e", + "#ad606e", + "#ac5f6d", + "#ac5e6c", + "#ab5d6b", + "#ab5c6b", + "#ab5b6a", + "#aa5b69", + "#aa5a69", + "#a95968", + "#a95867", + "#a95766", + "#a85666", + "#a85565", + "#a75464", + "#a75463", + "#a65363", + "#a65262", + "#a65161", + "#a55061", + "#a54f60", + "#a44e5f", + "#a44d5e", + "#a34d5e", + "#a34c5d", + "#a34b5c", + "#a24a5c", + "#a2495b", + "#a1485a", + "#a1475a", + "#a04659", + "#a04558", + "#9f4557", + "#9f4457", + "#9f4356", + "#9e4255", + "#9e4155", + "#9d4054", + "#9d3f53", + "#9c3e52", + "#9c3d52", + "#9b3c51", + "#9b3b50", + "#9a3b50", + "#9a3a4f", + "#99394e", + "#99384e", + "#98374d", + "#98364c", + "#98354b", + "#97344b", + "#97334a", + "#963249", + "#963149", 
+ "#953048", + "#952f47", + "#942e47", + "#942d46", + "#932c45", + "#932b45", + "#922a44", + "#922943", + "#912843", + "#912742", + "#902641", + "#902540", + "#8f2440", + "#8e223f", + "#8e213e", + "#8d203e", + "#8d1f3d", + "#8c1e3c", + "#8c1d3c", + "#8b1b3b", + "#8b1a3a", + "#8a193a", + "#8a1739", + "#891638", + "#891438", + "#881337", + ], + "flex_purple_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfd", + "#fbfbfc", + "#fafafb", + "#f9f9fa", + "#f8f8f9", + "#f7f7f9", + "#f6f6f8", + "#f5f5f7", + "#f4f4f6", + "#f3f3f6", + "#f2f2f5", + "#f1f1f4", + "#f0f0f3", + "#efeff3", + "#eeeef2", + "#ededf1", + "#ececf0", + "#ebebf0", + "#eaeaef", + "#e9e9ee", + "#e8e8ed", + "#e8e8ed", + "#e7e7ec", + "#e6e6eb", + "#e5e5eb", + "#e4e4ea", + "#e3e3e9", + "#e2e2e8", + "#e1e1e8", + "#e0e0e7", + "#dfdfe6", + "#dedee6", + "#dddde5", + "#dcdce4", + "#dbdbe4", + "#dadae3", + "#d9dae2", + "#d9d9e2", + "#d8d8e1", + "#d7d7e0", + "#d6d6e0", + "#d5d5df", + "#d4d4df", + "#d3d3de", + "#d2d2dd", + "#d1d1dd", + "#d0d1dc", + "#d0d0db", + "#cfcfdb", + "#ceceda", + "#cdcdd9", + "#ccccd9", + "#cbcbd8", + "#cacad8", + "#c9c9d7", + "#c8c9d6", + "#c8c8d6", + "#c7c7d5", + "#c6c6d5", + "#c5c5d4", + "#c4c4d3", + "#c3c3d3", + "#c2c2d2", + "#c2c2d2", + "#c1c1d1", + "#c0c0d1", + "#bfbfd0", + "#bebecf", + "#bdbdcf", + "#bcbcce", + "#bcbcce", + "#bbbbcd", + "#babacd", + "#b9b9cc", + "#b8b8cc", + "#b7b7cb", + "#b7b7ca", + "#b6b6ca", + "#b5b5c9", + "#b4b4c9", + "#b3b3c8", + "#b3b2c8", + "#b2b2c7", + "#b1b1c7", + "#b0b0c6", + "#afafc6", + "#aeaec5", + "#aeadc5", + "#adadc4", + "#acacc4", + "#ababc3", + "#aaaac3", + "#aaa9c2", + "#a9a8c2", + "#a8a8c1", + "#a7a7c1", + "#a6a6c0", + "#a6a5c0", + "#a5a4bf", + "#a4a4bf", + "#a3a3be", + "#a3a2be", + "#a2a1bd", + "#a1a0bd", + "#a0a0bc", + "#9f9fbc", + "#9f9ebb", + "#9e9dbb", + "#9d9cba", + "#9c9cba", + "#9c9bb9", + "#9b9ab9", + "#9a99b8", + "#9998b8", + "#9998b8", + "#9897b7", + "#9796b7", + "#9695b6", + "#9694b6", + "#9594b5", + "#9493b5", + "#9392b4", + 
"#9391b4", + "#9291b4", + "#9190b3", + "#908fb3", + "#908eb2", + "#8f8db2", + "#8e8db1", + "#8d8cb1", + "#8d8bb1", + "#8c8ab0", + "#8b8ab0", + "#8a89af", + "#8a88af", + "#8987ae", + "#8886ae", + "#8886ae", + "#8785ad", + "#8684ad", + "#8583ac", + "#8583ac", + "#8482ac", + "#8381ab", + "#8280ab", + "#8280ab", + "#817faa", + "#807eaa", + "#807da9", + "#7f7ca9", + "#7e7ca9", + "#7e7ba8", + "#7d7aa8", + "#7c79a8", + "#7b79a7", + "#7b78a7", + "#7a77a6", + "#7976a6", + "#7976a6", + "#7875a5", + "#7774a5", + "#7773a5", + "#7673a4", + "#7572a4", + "#7571a4", + "#7470a3", + "#736fa3", + "#736fa3", + "#726ea2", + "#716da2", + "#716ca2", + "#706ca1", + "#6f6ba1", + "#6f6aa1", + "#6e69a0", + "#6d69a0", + "#6d68a0", + "#6c679f", + "#6b669f", + "#6b669f", + "#6a659e", + "#69649e", + "#69639e", + "#68629d", + "#67629d", + "#67619d", + "#66609d", + "#655f9c", + "#655f9c", + "#645e9c", + "#635d9b", + "#635c9b", + "#625c9b", + "#625b9b", + "#615a9a", + "#60599a", + "#60589a", + "#5f589a", + "#5e5799", + "#5e5699", + "#5d5599", + "#5d5498", + "#5c5498", + "#5b5398", + "#5b5298", + "#5a5198", + "#595097", + "#595097", + "#584f97", + "#584e97", + "#574d96", + "#564c96", + "#564c96", + "#554b96", + "#554a95", + "#544995", + "#544895", + "#534795", + "#524795", + "#524694", + "#514594", + "#514494", + "#504394", + "#504294", + "#4f4294", + "#4e4193", + "#4e4093", + "#4d3f93", + "#4d3e93", + "#4c3d93", + "#4c3c93", + "#4b3b93", + "#4b3a92", + "#4a3992", + "#4a3892", + "#493892", + "#493792", + "#483692", + "#483592", + "#473492", + "#473391", + "#463291", + "#463191", + "#452f91", + "#452e91", + "#442d91", + "#442c91", + "#432b91", + "#432a91", + "#422991", + "#422891", + "#412691", + "#412591", + ], + "flex_grey_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfc", + "#fafafb", + "#f9f9fa", + "#f8f9f9", + "#f8f8f8", + "#f7f7f7", + "#f6f6f6", + "#f5f5f6", + "#f4f4f5", + "#f3f3f4", + "#f2f2f3", + "#f1f1f2", + "#f0f0f1", + "#eff0f1", + "#eeeff0", + "#eeeeef", + "#ededee", 
+ "#ececed", + "#ebebec", + "#eaeaec", + "#e9e9eb", + "#e8e9ea", + "#e7e8e9", + "#e6e7e8", + "#e6e6e8", + "#e5e5e7", + "#e4e4e6", + "#e3e3e5", + "#e2e3e4", + "#e1e2e4", + "#e0e1e3", + "#dfe0e2", + "#dfdfe1", + "#dedee0", + "#dddde0", + "#dcdddf", + "#dbdcde", + "#dadbdd", + "#d9dadd", + "#d9d9dc", + "#d8d8db", + "#d7d8da", + "#d6d7da", + "#d5d6d9", + "#d4d5d8", + "#d4d4d7", + "#d3d4d6", + "#d2d3d6", + "#d1d2d5", + "#d0d1d4", + "#cfd0d3", + "#cfcfd3", + "#cecfd2", + "#cdced1", + "#cccdd0", + "#cbccd0", + "#cacbcf", + "#cacbce", + "#c9cace", + "#c8c9cd", + "#c7c8cc", + "#c6c8cb", + "#c6c7cb", + "#c5c6ca", + "#c4c5c9", + "#c3c4c8", + "#c2c4c8", + "#c2c3c7", + "#c1c2c6", + "#c0c1c6", + "#bfc0c5", + "#bec0c4", + "#bebfc3", + "#bdbec3", + "#bcbdc2", + "#bbbdc1", + "#babcc1", + "#babbc0", + "#b9babf", + "#b8babe", + "#b7b9be", + "#b7b8bd", + "#b6b7bc", + "#b5b7bc", + "#b4b6bb", + "#b3b5ba", + "#b3b4ba", + "#b2b4b9", + "#b1b3b8", + "#b0b2b8", + "#b0b1b7", + "#afb1b6", + "#aeb0b6", + "#adafb5", + "#adaeb4", + "#acaeb3", + "#abadb3", + "#aaacb2", + "#aaabb1", + "#a9abb1", + "#a8aab0", + "#a7a9af", + "#a7a8af", + "#a6a8ae", + "#a5a7ad", + "#a4a6ad", + "#a4a6ac", + "#a3a5ab", + "#a2a4ab", + "#a1a3aa", + "#a1a3aa", + "#a0a2a9", + "#9fa1a8", + "#9ea1a8", + "#9ea0a7", + "#9d9fa6", + "#9c9ea6", + "#9c9ea5", + "#9b9da4", + "#9a9ca4", + "#999ca3", + "#999ba2", + "#989aa2", + "#979aa1", + "#9799a0", + "#9698a0", + "#95979f", + "#94979f", + "#94969e", + "#93959d", + "#92959d", + "#92949c", + "#91939b", + "#90939b", + "#8f929a", + "#8f919a", + "#8e9199", + "#8d9098", + "#8d8f98", + "#8c8e97", + "#8b8e96", + "#8b8d96", + "#8a8c95", + "#898c95", + "#888b94", + "#888a93", + "#878a93", + "#868992", + "#868891", + "#858891", + "#848790", + "#848690", + "#83868f", + "#82858e", + "#82848e", + "#81848d", + "#80838d", + "#80828c", + "#7f828b", + "#7e818b", + "#7e808a", + "#7d808a", + "#7c7f89", + "#7b7e88", + "#7b7e88", + "#7a7d87", + "#797c87", + "#797c86", + "#787b85", + "#777a85", + 
"#777a84", + "#767984", + "#757983", + "#757882", + "#747782", + "#737781", + "#737681", + "#727580", + "#71757f", + "#71747f", + "#70737e", + "#70737e", + "#6f727d", + "#6e717d", + "#6e717c", + "#6d707b", + "#6c707b", + "#6c6f7a", + "#6b6e7a", + "#6a6e79", + "#6a6d79", + "#696c78", + "#686c77", + "#686b77", + "#676a76", + "#666a76", + "#666975", + "#656974", + "#646874", + "#646773", + "#636773", + "#636672", + "#626572", + "#616571", + "#616470", + "#606370", + "#5f636f", + "#5f626f", + "#5e626e", + "#5d616e", + "#5d606d", + "#5c606c", + "#5b5f6c", + "#5b5e6b", + "#5a5e6b", + "#5a5d6a", + "#595d6a", + "#585c69", + "#585b69", + "#575b68", + "#565a67", + "#565967", + "#555966", + "#555866", + "#545865", + "#535765", + "#535664", + "#525663", + "#515563", + "#515562", + "#505462", + "#4f5361", + "#4f5361", + "#4e5260", + "#4e5160", + "#4d515f", + "#4c505e", + "#4c505e", + "#4b4f5d", + "#4a4e5d", + "#4a4e5c", + "#494d5c", + "#494d5b", + "#484c5a", + "#474b5a", + "#474b59", + "#464a59", + "#454958", + "#454958", + "#444857", + "#444857", + "#434756", + ], +} +CATEGORICAL_PALETTES_HEX = { + "flex_distinct": [ + "#176737", + "#FF7B0D", + "#979BAA", + "#F44E6A", + "#0062FF", + "#26AB5B", + "#6D3EF2", + "#F59E0B", + ] +} +DIVERGING_PALETTES_HEX = { + "flex_BuRd": [ + "#002766", + "#022967", + "#052b67", + "#072d68", + "#0a2e69", + "#0d3069", + "#10326a", + "#12346b", + "#15356c", + "#17376c", + "#1a396d", + "#1c3a6e", + "#1e3c6f", + "#203e70", + "#223f71", + "#244171", + "#264372", + "#284473", + "#2a4674", + "#2c4775", + "#2e4976", + "#304a77", + "#324c78", + "#344e79", + "#364f7a", + "#38517b", + "#3a527c", + "#3c547d", + "#3e557e", + "#3f577f", + "#415980", + "#435a80", + "#455c81", + "#475d83", + "#495f84", + "#4b6085", + "#4c6286", + "#4e6387", + "#506588", + "#526789", + "#54688a", + "#566a8b", + "#586b8c", + "#5a6d8d", + "#5b6e8e", + "#5d708f", + "#5f7290", + "#617391", + "#637592", + "#657694", + "#677895", + "#687a96", + "#6a7b97", + "#6c7d98", + "#6e7e99", + 
"#70809a", + "#72829b", + "#74839d", + "#76859e", + "#78879f", + "#7a88a0", + "#7b8aa1", + "#7d8ca3", + "#7f8da4", + "#818fa5", + "#8391a6", + "#8592a8", + "#8794a9", + "#8996aa", + "#8b97ab", + "#8d99ad", + "#8f9bae", + "#919daf", + "#939eb1", + "#95a0b2", + "#97a2b3", + "#99a4b4", + "#9ba5b6", + "#9da7b7", + "#9fa9b9", + "#a1abba", + "#a3acbb", + "#a5aebd", + "#a7b0be", + "#a9b2c0", + "#abb4c1", + "#adb6c2", + "#afb7c4", + "#b1b9c5", + "#b4bbc7", + "#b6bdc8", + "#b8bfca", + "#bac1cb", + "#bcc3cd", + "#bec5ce", + "#c0c6d0", + "#c3c8d1", + "#c5cad3", + "#c7ccd5", + "#c9ced6", + "#cbd0d8", + "#ced2d9", + "#d0d4db", + "#d2d6dd", + "#d4d8de", + "#d7dae0", + "#d9dce2", + "#dbdee3", + "#dee0e5", + "#e0e3e7", + "#e2e5e9", + "#e4e7ea", + "#e7e9ec", + "#e9ebee", + "#ecedf0", + "#eeeff2", + "#f0f2f3", + "#f3f4f5", + "#f5f6f7", + "#f8f8f9", + "#fafafb", + "#fdfdfd", + "#FFFFFF", + "#fefdfd", + "#fcfbfb", + "#fbf9f9", + "#f9f7f7", + "#f8f5f5", + "#f6f3f3", + "#f5f1f1", + "#f4efef", + "#f2edee", + "#f1ebec", + "#efe9ea", + "#eee7e8", + "#ede5e6", + "#ece3e4", + "#ebe1e3", + "#e9dfe1", + "#e8dddf", + "#e7dbdd", + "#e6d9db", + "#e5d7d9", + "#e4d5d8", + "#e3d3d6", + "#e2d1d4", + "#e1cfd2", + "#e0cdd0", + "#dfcccf", + "#decacd", + "#ddc8cb", + "#dcc6c9", + "#dbc4c7", + "#dac2c6", + "#d9c0c4", + "#d8bec2", + "#d7bcc0", + "#d6babf", + "#d6b8bd", + "#d5b6bb", + "#d4b5b9", + "#d3b3b8", + "#d2b1b6", + "#d1afb4", + "#d0adb2", + "#cfabb1", + "#cfa9af", + "#cea8ad", + "#cda6ac", + "#cca4aa", + "#cba2a8", + "#caa0a7", + "#c99ea5", + "#c99ca3", + "#c89ba2", + "#c799a0", + "#c6979e", + "#c5959d", + "#c4939b", + "#c49199", + "#c39098", + "#c28e96", + "#c18c94", + "#c08a93", + "#c08891", + "#bf8790", + "#be858e", + "#bd838c", + "#bc818b", + "#bb7f89", + "#bb7e88", + "#ba7c86", + "#b97a84", + "#b87883", + "#b77681", + "#b67580", + "#b6737e", + "#b5717d", + "#b46f7b", + "#b36d79", + "#b26c78", + "#b26a76", + "#b16875", + "#b06673", + "#af6572", + "#ae6370", + "#ad616f", + "#ac5f6d", + "#ac5d6c", 
+ "#ab5c6a", + "#aa5a69", + "#a95867", + "#a85666", + "#a75464", + "#a65263", + "#a55161", + "#a44f60", + "#a44d5e", + "#a34b5d", + "#a2495b", + "#a1475a", + "#a04658", + "#9f4457", + "#9e4255", + "#9d4054", + "#9c3e52", + "#9b3c51", + "#9a3a4f", + "#99384e", + "#98364c", + "#97344b", + "#96324a", + "#953048", + "#942e47", + "#932c45", + "#922a44", + "#912842", + "#902541", + "#8f233f", + "#8e213e", + "#8d1e3d", + "#8b1c3b", + "#8a193a", + "#891638", + "#881337", + ], + "flex_RdBu": [ + "#881337", + "#891638", + "#8a193a", + "#8b1c3b", + "#8d1e3d", + "#8e213e", + "#8f233f", + "#902541", + "#912842", + "#922a44", + "#932c45", + "#942e47", + "#953048", + "#96324a", + "#97344b", + "#98364c", + "#99384e", + "#9a3a4f", + "#9b3c51", + "#9c3e52", + "#9d4054", + "#9e4255", + "#9f4457", + "#a04658", + "#a1475a", + "#a2495b", + "#a34b5d", + "#a44d5e", + "#a44f60", + "#a55161", + "#a65263", + "#a75464", + "#a85666", + "#a95867", + "#aa5a69", + "#ab5c6a", + "#ac5d6c", + "#ac5f6d", + "#ad616f", + "#ae6370", + "#af6572", + "#b06673", + "#b16875", + "#b26a76", + "#b26c78", + "#b36d79", + "#b46f7b", + "#b5717d", + "#b6737e", + "#b67580", + "#b77681", + "#b87883", + "#b97a84", + "#ba7c86", + "#bb7e88", + "#bb7f89", + "#bc818b", + "#bd838c", + "#be858e", + "#bf8790", + "#c08891", + "#c08a93", + "#c18c94", + "#c28e96", + "#c39098", + "#c49199", + "#c4939b", + "#c5959d", + "#c6979e", + "#c799a0", + "#c89ba2", + "#c99ca3", + "#c99ea5", + "#caa0a7", + "#cba2a8", + "#cca4aa", + "#cda6ac", + "#cea8ad", + "#cfa9af", + "#cfabb1", + "#d0adb2", + "#d1afb4", + "#d2b1b6", + "#d3b3b8", + "#d4b5b9", + "#d5b6bb", + "#d6b8bd", + "#d6babf", + "#d7bcc0", + "#d8bec2", + "#d9c0c4", + "#dac2c6", + "#dbc4c7", + "#dcc6c9", + "#ddc8cb", + "#decacd", + "#dfcccf", + "#e0cdd0", + "#e1cfd2", + "#e2d1d4", + "#e3d3d6", + "#e4d5d8", + "#e5d7d9", + "#e6d9db", + "#e7dbdd", + "#e8dddf", + "#e9dfe1", + "#ebe1e3", + "#ece3e4", + "#ede5e6", + "#eee7e8", + "#efe9ea", + "#f1ebec", + "#f2edee", + "#f4efef", + "#f5f1f1", + 
"#f6f3f3", + "#f8f5f5", + "#f9f7f7", + "#fbf9f9", + "#fcfbfb", + "#fefdfd", + "#FFFFFF", + "#fdfdfd", + "#fafafb", + "#f8f8f9", + "#f5f6f7", + "#f3f4f5", + "#f0f2f3", + "#eeeff2", + "#ecedf0", + "#e9ebee", + "#e7e9ec", + "#e4e7ea", + "#e2e5e9", + "#e0e3e7", + "#dee0e5", + "#dbdee3", + "#d9dce2", + "#d7dae0", + "#d4d8de", + "#d2d6dd", + "#d0d4db", + "#ced2d9", + "#cbd0d8", + "#c9ced6", + "#c7ccd5", + "#c5cad3", + "#c3c8d1", + "#c0c6d0", + "#bec5ce", + "#bcc3cd", + "#bac1cb", + "#b8bfca", + "#b6bdc8", + "#b4bbc7", + "#b1b9c5", + "#afb7c4", + "#adb6c2", + "#abb4c1", + "#a9b2c0", + "#a7b0be", + "#a5aebd", + "#a3acbb", + "#a1abba", + "#9fa9b9", + "#9da7b7", + "#9ba5b6", + "#99a4b4", + "#97a2b3", + "#95a0b2", + "#939eb1", + "#919daf", + "#8f9bae", + "#8d99ad", + "#8b97ab", + "#8996aa", + "#8794a9", + "#8592a8", + "#8391a6", + "#818fa5", + "#7f8da4", + "#7d8ca3", + "#7b8aa1", + "#7a88a0", + "#78879f", + "#76859e", + "#74839d", + "#72829b", + "#70809a", + "#6e7e99", + "#6c7d98", + "#6a7b97", + "#687a96", + "#677895", + "#657694", + "#637592", + "#617391", + "#5f7290", + "#5d708f", + "#5b6e8e", + "#5a6d8d", + "#586b8c", + "#566a8b", + "#54688a", + "#526789", + "#506588", + "#4e6387", + "#4c6286", + "#4b6085", + "#495f84", + "#475d83", + "#455c81", + "#435a80", + "#415980", + "#3f577f", + "#3e557e", + "#3c547d", + "#3a527c", + "#38517b", + "#364f7a", + "#344e79", + "#324c78", + "#304a77", + "#2e4976", + "#2c4775", + "#2a4674", + "#284473", + "#264372", + "#244171", + "#223f71", + "#203e70", + "#1e3c6f", + "#1c3a6e", + "#1a396d", + "#17376c", + "#15356c", + "#12346b", + "#10326a", + "#0d3069", + "#0a2e69", + "#072d68", + "#052b67", + "#022967", + "#002766", + ], + "flex_GrPu": [ + "#0f4424", + "#124526", + "#144727", + "#174829", + "#19492a", + "#1b4b2c", + "#1d4c2d", + "#1f4d2f", + "#214f30", + "#235032", + "#255234", + "#275335", + "#295437", + "#2b5638", + "#2d573a", + "#2f583b", + "#315a3d", + "#335b3e", + "#355c40", + "#365e42", + "#385f43", + "#3a6045", + "#3c6246", + 
"#3e6348", + "#3f644a", + "#41664b", + "#43674d", + "#45684e", + "#476a50", + "#486b52", + "#4a6c53", + "#4c6e55", + "#4e6f56", + "#4f7058", + "#51725a", + "#53735b", + "#55745d", + "#57765f", + "#587760", + "#5a7962", + "#5c7a63", + "#5e7b65", + "#5f7d67", + "#617e68", + "#637f6a", + "#65816c", + "#67826d", + "#68846f", + "#6a8571", + "#6c8672", + "#6e8874", + "#6f8976", + "#718b78", + "#738c79", + "#758e7b", + "#778f7d", + "#79907e", + "#7a9280", + "#7c9382", + "#7e9584", + "#809685", + "#829887", + "#849989", + "#859b8b", + "#879c8c", + "#899d8e", + "#8b9f90", + "#8da092", + "#8fa294", + "#91a395", + "#92a597", + "#94a699", + "#96a89b", + "#98aa9d", + "#9aab9e", + "#9cada0", + "#9eaea2", + "#a0b0a4", + "#a2b1a6", + "#a4b3a8", + "#a6b4aa", + "#a8b6ab", + "#aab8ad", + "#acb9af", + "#aebbb1", + "#b0bcb3", + "#b2beb5", + "#b4c0b7", + "#b6c1b9", + "#b8c3bb", + "#bac5bd", + "#bcc6bf", + "#bec8c1", + "#c0cac2", + "#c2cbc4", + "#c4cdc6", + "#c6cfc8", + "#c8d0ca", + "#cad2cc", + "#ccd4ce", + "#ced6d0", + "#d0d7d2", + "#d3d9d5", + "#d5dbd7", + "#d7ddd9", + "#d9dfdb", + "#dbe0dd", + "#dde2df", + "#dfe4e1", + "#e2e6e3", + "#e4e8e5", + "#e6eae7", + "#e8ebe9", + "#ebedeb", + "#edefee", + "#eff1f0", + "#f1f3f2", + "#f4f5f4", + "#f6f7f6", + "#f8f9f8", + "#fafbfb", + "#fdfdfd", + "#FFFFFF", + "#fdfdfd", + "#fbfbfc", + "#f9f9fa", + "#f7f7f8", + "#f5f5f7", + "#f3f3f5", + "#f1f0f4", + "#efeef2", + "#ececf0", + "#eaeaef", + "#e8e8ed", + "#e6e6ec", + "#e4e5ea", + "#e3e3e9", + "#e1e1e8", + "#dfdfe6", + "#dddde5", + "#dbdbe3", + "#d9d9e2", + "#d7d7e1", + "#d5d5df", + "#d3d3de", + "#d1d1dd", + "#cfd0db", + "#ceceda", + "#ccccd9", + "#cacad7", + "#c8c8d6", + "#c6c6d5", + "#c4c4d4", + "#c3c3d2", + "#c1c1d1", + "#bfbfd0", + "#bdbdcf", + "#bcbcce", + "#babacd", + "#b8b8cb", + "#b6b6ca", + "#b5b4c9", + "#b3b3c8", + "#b1b1c7", + "#afafc6", + "#aeaec5", + "#acacc4", + "#aaaac3", + "#a9a8c1", + "#a7a7c0", + "#a5a5bf", + "#a4a3be", + "#a2a2bd", + "#a1a0bc", + "#9f9ebb", + "#9d9dba", + "#9c9bb9", 
+ "#9a99b8", + "#9898b8", + "#9796b7", + "#9594b6", + "#9493b5", + "#9291b4", + "#918fb3", + "#8f8eb2", + "#8e8cb1", + "#8c8ab0", + "#8b89af", + "#8987af", + "#8786ae", + "#8684ad", + "#8482ac", + "#8381ab", + "#817faa", + "#807eaa", + "#7f7ca9", + "#7d7aa8", + "#7c79a7", + "#7a77a6", + "#7976a6", + "#7774a5", + "#7672a4", + "#7471a4", + "#736fa3", + "#726ea2", + "#706ca1", + "#6f6aa1", + "#6d69a0", + "#6c679f", + "#6b669f", + "#69649e", + "#68629d", + "#67619d", + "#655f9c", + "#645e9c", + "#635c9b", + "#615a9a", + "#60599a", + "#5f5799", + "#5d5599", + "#5c5498", + "#5b5298", + "#595097", + "#584f97", + "#574d96", + "#564b96", + "#554a95", + "#534895", + "#524695", + "#514494", + "#504394", + "#4f4193", + "#4d3f93", + "#4c3d93", + "#4b3b92", + "#4a3992", + "#493792", + "#483592", + "#473392", + "#463191", + "#452f91", + "#442d91", + "#432a91", + "#422891", + "#412591", + ], + "flex_PuGr": [ + "#412591", + "#422891", + "#432a91", + "#442d91", + "#452f91", + "#463191", + "#473392", + "#483592", + "#493792", + "#4a3992", + "#4b3b92", + "#4c3d93", + "#4d3f93", + "#4f4193", + "#504394", + "#514494", + "#524695", + "#534895", + "#554a95", + "#564b96", + "#574d96", + "#584f97", + "#595097", + "#5b5298", + "#5c5498", + "#5d5599", + "#5f5799", + "#60599a", + "#615a9a", + "#635c9b", + "#645e9c", + "#655f9c", + "#67619d", + "#68629d", + "#69649e", + "#6b669f", + "#6c679f", + "#6d69a0", + "#6f6aa1", + "#706ca1", + "#726ea2", + "#736fa3", + "#7471a4", + "#7672a4", + "#7774a5", + "#7976a6", + "#7a77a6", + "#7c79a7", + "#7d7aa8", + "#7f7ca9", + "#807eaa", + "#817faa", + "#8381ab", + "#8482ac", + "#8684ad", + "#8786ae", + "#8987af", + "#8b89af", + "#8c8ab0", + "#8e8cb1", + "#8f8eb2", + "#918fb3", + "#9291b4", + "#9493b5", + "#9594b6", + "#9796b7", + "#9898b8", + "#9a99b8", + "#9c9bb9", + "#9d9dba", + "#9f9ebb", + "#a1a0bc", + "#a2a2bd", + "#a4a3be", + "#a5a5bf", + "#a7a7c0", + "#a9a8c1", + "#aaaac3", + "#acacc4", + "#aeaec5", + "#afafc6", + "#b1b1c7", + "#b3b3c8", + "#b5b4c9", + 
"#b6b6ca", + "#b8b8cb", + "#babacd", + "#bcbcce", + "#bdbdcf", + "#bfbfd0", + "#c1c1d1", + "#c3c3d2", + "#c4c4d4", + "#c6c6d5", + "#c8c8d6", + "#cacad7", + "#ccccd9", + "#ceceda", + "#cfd0db", + "#d1d1dd", + "#d3d3de", + "#d5d5df", + "#d7d7e1", + "#d9d9e2", + "#dbdbe3", + "#dddde5", + "#dfdfe6", + "#e1e1e8", + "#e3e3e9", + "#e4e5ea", + "#e6e6ec", + "#e8e8ed", + "#eaeaef", + "#ececf0", + "#efeef2", + "#f1f0f4", + "#f3f3f5", + "#f5f5f7", + "#f7f7f8", + "#f9f9fa", + "#fbfbfc", + "#fdfdfd", + "#FFFFFF", + "#fdfdfd", + "#fafbfb", + "#f8f9f8", + "#f6f7f6", + "#f4f5f4", + "#f1f3f2", + "#eff1f0", + "#edefee", + "#ebedeb", + "#e8ebe9", + "#e6eae7", + "#e4e8e5", + "#e2e6e3", + "#dfe4e1", + "#dde2df", + "#dbe0dd", + "#d9dfdb", + "#d7ddd9", + "#d5dbd7", + "#d3d9d5", + "#d0d7d2", + "#ced6d0", + "#ccd4ce", + "#cad2cc", + "#c8d0ca", + "#c6cfc8", + "#c4cdc6", + "#c2cbc4", + "#c0cac2", + "#bec8c1", + "#bcc6bf", + "#bac5bd", + "#b8c3bb", + "#b6c1b9", + "#b4c0b7", + "#b2beb5", + "#b0bcb3", + "#aebbb1", + "#acb9af", + "#aab8ad", + "#a8b6ab", + "#a6b4aa", + "#a4b3a8", + "#a2b1a6", + "#a0b0a4", + "#9eaea2", + "#9cada0", + "#9aab9e", + "#98aa9d", + "#96a89b", + "#94a699", + "#92a597", + "#91a395", + "#8fa294", + "#8da092", + "#8b9f90", + "#899d8e", + "#879c8c", + "#859b8b", + "#849989", + "#829887", + "#809685", + "#7e9584", + "#7c9382", + "#7a9280", + "#79907e", + "#778f7d", + "#758e7b", + "#738c79", + "#718b78", + "#6f8976", + "#6e8874", + "#6c8672", + "#6a8571", + "#68846f", + "#67826d", + "#65816c", + "#637f6a", + "#617e68", + "#5f7d67", + "#5e7b65", + "#5c7a63", + "#5a7962", + "#587760", + "#57765f", + "#55745d", + "#53735b", + "#51725a", + "#4f7058", + "#4e6f56", + "#4c6e55", + "#4a6c53", + "#486b52", + "#476a50", + "#45684e", + "#43674d", + "#41664b", + "#3f644a", + "#3e6348", + "#3c6246", + "#3a6045", + "#385f43", + "#365e42", + "#355c40", + "#335b3e", + "#315a3d", + "#2f583b", + "#2d573a", + "#2b5638", + "#295437", + "#275335", + "#255234", + "#235032", + "#214f30", + "#1f4d2f", 
+ "#1d4c2d", + "#1b4b2c", + "#19492a", + "#174829", + "#144727", + "#124526", + "#0f4424", + ], + "flex_TuOr": [ + "#134e4a", + "#164f4b", + "#19504d", + "#1b524e", + "#1e534f", + "#205450", + "#225552", + "#255753", + "#275854", + "#295955", + "#2b5a57", + "#2d5c58", + "#2f5d59", + "#315e5a", + "#335f5c", + "#35615d", + "#37625e", + "#39635f", + "#3b6461", + "#3c6662", + "#3e6763", + "#406865", + "#426966", + "#446b67", + "#456c68", + "#476d6a", + "#496e6b", + "#4b706c", + "#4d716e", + "#4e726f", + "#507370", + "#527572", + "#547673", + "#557774", + "#577975", + "#597a77", + "#5b7b78", + "#5d7c79", + "#5e7e7b", + "#607f7c", + "#62807d", + "#64827f", + "#658380", + "#678482", + "#698683", + "#6b8784", + "#6c8886", + "#6e8a87", + "#708b88", + "#728c8a", + "#738e8b", + "#758f8c", + "#77908e", + "#79928f", + "#7a9391", + "#7c9492", + "#7e9693", + "#809795", + "#819896", + "#839a98", + "#859b99", + "#879d9b", + "#899e9c", + "#8a9f9d", + "#8ca19f", + "#8ea2a0", + "#90a4a2", + "#92a5a3", + "#93a7a5", + "#95a8a6", + "#97a9a8", + "#99aba9", + "#9bacab", + "#9daeac", + "#9eafae", + "#a0b1af", + "#a2b2b1", + "#a4b4b2", + "#a6b5b4", + "#a8b7b5", + "#aab8b7", + "#acbab8", + "#adbbba", + "#afbdbb", + "#b1bebd", + "#b3c0bf", + "#b5c1c0", + "#b7c3c2", + "#b9c5c3", + "#bbc6c5", + "#bdc8c7", + "#bfc9c8", + "#c1cbca", + "#c3cccc", + "#c5cecd", + "#c7d0cf", + "#c9d1d1", + "#cbd3d2", + "#cdd5d4", + "#cfd6d6", + "#d1d8d7", + "#d3dad9", + "#d5dbdb", + "#d7dddc", + "#d9dfde", + "#dbe0e0", + "#dde2e2", + "#dfe4e3", + "#e1e6e5", + "#e3e7e7", + "#e5e9e9", + "#e7ebeb", + "#e9edec", + "#ebeeee", + "#eef0f0", + "#f0f2f2", + "#f2f4f4", + "#f4f6f6", + "#f6f8f7", + "#f8f9f9", + "#fbfbfb", + "#fdfdfd", + "#FFFFFF", + "#fefdfd", + "#fcfcfb", + "#fbfaf9", + "#faf8f7", + "#f9f7f6", + "#f7f5f4", + "#f6f4f2", + "#f5f2f0", + "#f4f0ee", + "#f2efec", + "#f1edea", + "#f0ece8", + "#efeae6", + "#eee8e4", + "#ede7e2", + "#ece5df", + "#ebe4dd", + "#eae2db", + "#e9e0d9", + "#e8dfd7", + "#e7ddd5", + "#e6dcd3", + 
"#e5dad1", + "#e5d8cf", + "#e4d7cd", + "#e3d5cb", + "#e2d4c9", + "#e1d2c7", + "#e0d0c5", + "#dfcfc3", + "#dfcdc1", + "#deccbe", + "#ddcabc", + "#dcc9ba", + "#dbc7b8", + "#dac6b6", + "#dac4b4", + "#d9c2b2", + "#d8c1b0", + "#d7bfae", + "#d6beac", + "#d6bcaa", + "#d5bba8", + "#d4b9a6", + "#d3b8a4", + "#d3b6a2", + "#d2b5a0", + "#d1b39e", + "#d0b29c", + "#d0b09a", + "#cfaf98", + "#cead96", + "#cdac94", + "#cdaa92", + "#cca990", + "#cba78e", + "#caa68c", + "#caa48a", + "#c9a388", + "#c8a286", + "#c7a084", + "#c79f82", + "#c69d80", + "#c59c7e", + "#c59a7c", + "#c4997a", + "#c39778", + "#c29676", + "#c29474", + "#c19372", + "#c09270", + "#c0906e", + "#bf8f6c", + "#be8d6b", + "#bd8c69", + "#bd8a67", + "#bc8965", + "#bb8863", + "#bb8661", + "#ba855f", + "#b9835d", + "#b8825b", + "#b88059", + "#b77f57", + "#b67e55", + "#b57c53", + "#b57b51", + "#b47950", + "#b3784e", + "#b3774c", + "#b2754a", + "#b17448", + "#b07246", + "#b07144", + "#af7042", + "#ae6e40", + "#ad6d3e", + "#ad6b3c", + "#ac6a3a", + "#ab6938", + "#aa6737", + "#a96635", + "#a96433", + "#a86331", + "#a7622f", + "#a6602d", + "#a65f2b", + "#a55d29", + "#a45c27", + "#a35b25", + "#a25923", + "#a15821", + "#a1571f", + "#a0551c", + "#9f541a", + "#9e5218", + "#9d5116", + "#9c5013", + "#9c4e11", + "#9b4d0e", + "#9a4b0b", + "#994a08", + ], + "flex_OrTu": [ + "#994a08", + "#9a4b0b", + "#9b4d0e", + "#9c4e11", + "#9c5013", + "#9d5116", + "#9e5218", + "#9f541a", + "#a0551c", + "#a1571f", + "#a15821", + "#a25923", + "#a35b25", + "#a45c27", + "#a55d29", + "#a65f2b", + "#a6602d", + "#a7622f", + "#a86331", + "#a96433", + "#a96635", + "#aa6737", + "#ab6938", + "#ac6a3a", + "#ad6b3c", + "#ad6d3e", + "#ae6e40", + "#af7042", + "#b07144", + "#b07246", + "#b17448", + "#b2754a", + "#b3774c", + "#b3784e", + "#b47950", + "#b57b51", + "#b57c53", + "#b67e55", + "#b77f57", + "#b88059", + "#b8825b", + "#b9835d", + "#ba855f", + "#bb8661", + "#bb8863", + "#bc8965", + "#bd8a67", + "#bd8c69", + "#be8d6b", + "#bf8f6c", + "#c0906e", + "#c09270", + 
"#c19372", + "#c29474", + "#c29676", + "#c39778", + "#c4997a", + "#c59a7c", + "#c59c7e", + "#c69d80", + "#c79f82", + "#c7a084", + "#c8a286", + "#c9a388", + "#caa48a", + "#caa68c", + "#cba78e", + "#cca990", + "#cdaa92", + "#cdac94", + "#cead96", + "#cfaf98", + "#d0b09a", + "#d0b29c", + "#d1b39e", + "#d2b5a0", + "#d3b6a2", + "#d3b8a4", + "#d4b9a6", + "#d5bba8", + "#d6bcaa", + "#d6beac", + "#d7bfae", + "#d8c1b0", + "#d9c2b2", + "#dac4b4", + "#dac6b6", + "#dbc7b8", + "#dcc9ba", + "#ddcabc", + "#deccbe", + "#dfcdc1", + "#dfcfc3", + "#e0d0c5", + "#e1d2c7", + "#e2d4c9", + "#e3d5cb", + "#e4d7cd", + "#e5d8cf", + "#e5dad1", + "#e6dcd3", + "#e7ddd5", + "#e8dfd7", + "#e9e0d9", + "#eae2db", + "#ebe4dd", + "#ece5df", + "#ede7e2", + "#eee8e4", + "#efeae6", + "#f0ece8", + "#f1edea", + "#f2efec", + "#f4f0ee", + "#f5f2f0", + "#f6f4f2", + "#f7f5f4", + "#f9f7f6", + "#faf8f7", + "#fbfaf9", + "#fcfcfb", + "#fefdfd", + "#FFFFFF", + "#fdfdfd", + "#fbfbfb", + "#f8f9f9", + "#f6f8f7", + "#f4f6f6", + "#f2f4f4", + "#f0f2f2", + "#eef0f0", + "#ebeeee", + "#e9edec", + "#e7ebeb", + "#e5e9e9", + "#e3e7e7", + "#e1e6e5", + "#dfe4e3", + "#dde2e2", + "#dbe0e0", + "#d9dfde", + "#d7dddc", + "#d5dbdb", + "#d3dad9", + "#d1d8d7", + "#cfd6d6", + "#cdd5d4", + "#cbd3d2", + "#c9d1d1", + "#c7d0cf", + "#c5cecd", + "#c3cccc", + "#c1cbca", + "#bfc9c8", + "#bdc8c7", + "#bbc6c5", + "#b9c5c3", + "#b7c3c2", + "#b5c1c0", + "#b3c0bf", + "#b1bebd", + "#afbdbb", + "#adbbba", + "#acbab8", + "#aab8b7", + "#a8b7b5", + "#a6b5b4", + "#a4b4b2", + "#a2b2b1", + "#a0b1af", + "#9eafae", + "#9daeac", + "#9bacab", + "#99aba9", + "#97a9a8", + "#95a8a6", + "#93a7a5", + "#92a5a3", + "#90a4a2", + "#8ea2a0", + "#8ca19f", + "#8a9f9d", + "#899e9c", + "#879d9b", + "#859b99", + "#839a98", + "#819896", + "#809795", + "#7e9693", + "#7c9492", + "#7a9391", + "#79928f", + "#77908e", + "#758f8c", + "#738e8b", + "#728c8a", + "#708b88", + "#6e8a87", + "#6c8886", + "#6b8784", + "#698683", + "#678482", + "#658380", + "#64827f", + "#62807d", + "#607f7c", 
+ "#5e7e7b", + "#5d7c79", + "#5b7b78", + "#597a77", + "#577975", + "#557774", + "#547673", + "#527572", + "#507370", + "#4e726f", + "#4d716e", + "#4b706c", + "#496e6b", + "#476d6a", + "#456c68", + "#446b67", + "#426966", + "#406865", + "#3e6763", + "#3c6662", + "#3b6461", + "#39635f", + "#37625e", + "#35615d", + "#335f5c", + "#315e5a", + "#2f5d59", + "#2d5c58", + "#2b5a57", + "#295955", + "#275854", + "#255753", + "#225552", + "#205450", + "#1e534f", + "#1b524e", + "#19504d", + "#164f4b", + "#134e4a", + ], +} diff --git a/tidy3d/_common/components/viz/flex_style.py b/tidy3d/_common/components/viz/flex_style.py new file mode 100644 index 0000000000..babcd3e9f1 --- /dev/null +++ b/tidy3d/_common/components/viz/flex_style.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from tidy3d._common.log import log + +_ORIGINAL_PARAMS = None + + +def apply_tidy3d_params() -> None: + """ + Applies a set of defaults to the matplotlib params that are following the tidy3d color palettes and design. + """ + global _ORIGINAL_PARAMS + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + + _ORIGINAL_PARAMS = mpl.rcParams.copy() + + try: + plt.style.use("tidy3d.style") + except Exception as e: + log.error(f"Failed to apply Tidy3D plotting style on import. Error: {e}") + _ORIGINAL_PARAMS = {} + except ImportError: + pass + + +def restore_matplotlib_rcparams() -> None: + """ + Resets matplotlib rcParams to the values they had before the Tidy3D + style was automatically applied on import. + """ + global _ORIGINAL_PARAMS + try: + import matplotlib.pyplot as plt + from matplotlib import style + + if not _ORIGINAL_PARAMS: + style.use("default") + return + + plt.rcParams.update(_ORIGINAL_PARAMS) + except ImportError: + log.error("Matplotlib is not installed on your system. Failed to reset to default styles.") + except Exception as e: + log.error(f"Failed to reset previous Matplotlib style. 
Error: {e}") diff --git a/tidy3d/_common/components/viz/plot_params.py b/tidy3d/_common/components/viz/plot_params.py new file mode 100644 index 0000000000..4da16d5f4a --- /dev/null +++ b/tidy3d/_common/components/viz/plot_params.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from numpy import inf +from pydantic import Field, NonNegativeFloat + +from tidy3d._common.components.base import Tidy3dBaseModel + +if TYPE_CHECKING: + from tidy3d._common.components.viz.visualization_spec import VisualizationSpec + +if TYPE_CHECKING: + from tidy3d._common.components.viz.visualization_spec import VisualizationSpec + + +class AbstractPlotParams(Tidy3dBaseModel): + """Abstract class for storing plotting parameters. + Corresponds with select properties of ``matplotlib.artist.Artist``. + """ + + alpha: Any = Field(1.0, title="Opacity") + zorder: Optional[float] = Field(None, title="Display Order") + + def include_kwargs(self, **kwargs: Any) -> AbstractPlotParams: + """Update the plot params with supplied kwargs.""" + update_dict = { + key: value + for key, value in kwargs.items() + if key not in ("type",) and value is not None and key in type(self).model_fields + } + return self.copy(update=update_dict) + + def override_with_viz_spec(self, viz_spec: VisualizationSpec) -> AbstractPlotParams: + """Override plot params with supplied VisualizationSpec.""" + return self.include_kwargs(**dict(viz_spec)) + + def to_kwargs(self) -> dict[str, Any]: + """Export the plot parameters as kwargs dict that can be supplied to plot function.""" + kwarg_dict = self.model_dump() + for ignore_key in ("type", "attrs"): + kwarg_dict.pop(ignore_key) + return kwarg_dict + + +class PathPlotParams(AbstractPlotParams): + """Stores plotting parameters / specifications for a path. + Corresponds with select properties of ``matplotlib.lines.Line2D``. 
+ """ + + color: Optional[Any] = Field(None, title="Color", alias="c") + linewidth: NonNegativeFloat = Field(2, title="Line Width", alias="lw") + linestyle: str = Field("--", title="Line Style", alias="ls") + marker: Any = Field("o", title="Marker Style") + markeredgecolor: Optional[Any] = Field(None, title="Marker Edge Color", alias="mec") + markerfacecolor: Optional[Any] = Field(None, title="Marker Face Color", alias="mfc") + markersize: NonNegativeFloat = Field(10, title="Marker Size", alias="ms") + + +class PlotParams(AbstractPlotParams): + """Stores plotting parameters / specifications for a given model. + Corresponds with select properties of ``matplotlib.patches.Patch``. + """ + + edgecolor: Optional[Any] = Field(None, title="Edge Color", alias="ec") + facecolor: Optional[Any] = Field(None, title="Face Color", alias="fc") + fill: bool = Field(True, title="Is Filled") + hatch: Optional[str] = Field(None, title="Hatch Style") + linewidth: NonNegativeFloat = Field(1, title="Line Width", alias="lw") + + +# defaults for different tidy3d objects +plot_params_geometry = PlotParams() +plot_params_structure = PlotParams() +plot_params_source = PlotParams(alpha=0.4, facecolor="limegreen", edgecolor="limegreen", lw=3) +plot_params_absorber = PlotParams( + alpha=0.4, facecolor="lightskyblue", edgecolor="lightskyblue", lw=3 +) +plot_params_monitor = PlotParams(alpha=0.4, facecolor="orange", edgecolor="orange", lw=3) +plot_params_pml = PlotParams(alpha=0.7, facecolor="gray", edgecolor="gray", hatch="x", zorder=inf) +plot_params_pec = PlotParams(alpha=1.0, facecolor="gold", edgecolor="black", zorder=inf) +plot_params_pmc = PlotParams(alpha=1.0, facecolor="lightsteelblue", edgecolor="black", zorder=inf) +plot_params_bloch = PlotParams(alpha=1.0, facecolor="orchid", edgecolor="black", zorder=inf) +plot_params_abc = PlotParams(alpha=1.0, facecolor="lightskyblue", edgecolor="black", zorder=inf) +plot_params_symmetry = PlotParams(edgecolor="gray", facecolor="gray", alpha=0.6, 
zorder=inf) +plot_params_override_structures = PlotParams( + linewidth=0.4, edgecolor="black", fill=False, zorder=inf +) +plot_params_fluid = PlotParams(facecolor="white", edgecolor="lightsteelblue", lw=0.4, hatch="xx") +plot_params_grid = PlotParams(edgecolor="black", lw=0.2) +plot_params_lumped_element = PlotParams( + alpha=0.4, facecolor="mediumblue", edgecolor="mediumblue", lw=3 +) +plot_params_min_grid_size = PlotParams( + alpha=0.5, facecolor="gray", edgecolor="darkred", lw=0, fill=True, hatch=".", zorder=0 +) diff --git a/tidy3d/_common/components/viz/plot_sim_3d.py b/tidy3d/_common/components/viz/plot_sim_3d.py new file mode 100644 index 0000000000..6a2471e4ff --- /dev/null +++ b/tidy3d/_common/components/viz/plot_sim_3d.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from html import escape +from typing import ( + TYPE_CHECKING, + Protocol, + runtime_checkable, +) + +from tidy3d._common.exceptions import SetupError + +if TYPE_CHECKING: + import io + from collections.abc import Sequence + from os import PathLike + from typing import ( + Callable, + Optional, + Union, + runtime_checkable, + ) + + from IPython.core.display_functions import DisplayHandle + + +@runtime_checkable +class PlotSim3DProtocol(Protocol): + def to_hdf5_gz( + self, + fname: Union[PathLike[str], io.BytesIO], + custom_encoders: Optional[Sequence[Callable[..., object]]] = None, + ) -> None: ... + + +@runtime_checkable +class PlotScene3DProtocol(Protocol): + # Used by plot_scene_3d to patch JSON_STRING + size: Sequence[float] + center: Sequence[float] + + def to_hdf5( + self, + fname: Union[PathLike[str], io.BytesIO], + custom_encoders: Optional[Sequence[Callable[..., object]]] = None, + ) -> None: ... 
+ + +def plot_scene_3d(scene: PlotScene3DProtocol, width: int = 800, height: int = 800) -> None: + import gzip + import json + from base64 import b64encode + from io import BytesIO + + import h5py + + # Serialize scene to HDF5 in-memory + buffer = BytesIO() + scene.to_hdf5(buffer) + buffer.seek(0) + + # Open source HDF5 for reading and prepare modified copy + with h5py.File(buffer, "r") as src: + buffer2 = BytesIO() + with h5py.File(buffer2, "w") as dst: + + def copy_item(name: str, obj: h5py.Group | h5py.Dataset) -> None: + if isinstance(obj, h5py.Group): + dst.create_group(name) + for k, v in obj.attrs.items(): + dst[name].attrs[k] = v + elif isinstance(obj, h5py.Dataset): + data = obj[()] + if name == "JSON_STRING": + # Parse and update JSON string + json_str = ( + data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data + ) + json_data = json.loads(json_str) + json_data["size"] = list(scene.size) + json_data["center"] = list(scene.center) + json_data["grid_spec"] = {} + new_str = json.dumps(json_data) + dst.create_dataset(name, data=new_str.encode("utf-8")) + else: + dst.create_dataset(name, data=data) + for k, v in obj.attrs.items(): + dst[name].attrs[k] = v + + src.visititems(copy_item) + buffer2.seek(0) + + # Gzip the modified HDF5 + gz_buffer = BytesIO() + with gzip.GzipFile(fileobj=gz_buffer, mode="wb") as gz: + gz.write(buffer2.read()) + gz_buffer.seek(0) + + # Base64 encode and display with gzipped flag + sim_base64 = b64encode(gz_buffer.read()).decode("utf-8") + plot_sim_3d(sim_base64, width=width, height=height, is_gz_base64=True) + + +def plot_sim_3d( + sim: Union[PlotSim3DProtocol, str], + width: int = 800, + height: int = 800, + is_gz_base64: bool = False, +) -> DisplayHandle: + """Make 3D display of simulation in ipython notebook.""" + + try: + from IPython.display import HTML, display + except ImportError as e: + raise SetupError( + "3D plotting requires ipython to be installed " + "and the code to be running on a jupyter notebook." 
+ ) from e + + from base64 import b64encode + from io import BytesIO + + if not is_gz_base64: + buffer = BytesIO() + sim.to_hdf5_gz(buffer) + buffer.seek(0) + base64 = b64encode(buffer.read()).decode("utf-8") + else: + base64 = sim + + js_code = """ + /** + * Simulation Viewer Injector + * + * Monitors the document for elements being added in the form: + * + *
+ * + * This script will then inject an iframe to the viewer application, and pass it the simulation data + * via the postMessage API on request. The script may be safely included multiple times, with only the + * configuration of the first started script (e.g. viewer URL) applying. + * + */ + (function() { + const TARGET_CLASS = "simulation-viewer"; + const ACTIVE_CLASS = "simulation-viewer-active"; + const VIEWER_URL = "https://tidy3d.simulation.cloud/simulation-viewer"; + + class SimulationViewerInjector { + constructor() { + for (var node of document.getElementsByClassName(TARGET_CLASS)) { + this.injectViewer(node); + } + + // Monitor for newly added nodes to the DOM + this.observer = new MutationObserver(this.onMutations.bind(this)); + this.observer.observe(document.body, {childList: true, subtree: true}); + } + + onMutations(mutations) { + for (var mutation of mutations) { + if (mutation.type === 'childList') { + /** + * Have found that adding the element does not reliably trigger the mutation observer. + * It may be the case that setting content with innerHTML does not trigger. + * + * It seems to be sufficient to re-scan the document for un-activated viewers + * whenever an event occurs, as Jupyter triggers multiple events on cell evaluation. 
+ */ + var viewers = document.getElementsByClassName(TARGET_CLASS); + for (var node of viewers) { + this.injectViewer(node); + } + } + } + } + + injectViewer(node) { + // (re-)check that this is a valid simulation container and has not already been injected + if (node.classList.contains(TARGET_CLASS) && !node.classList.contains(ACTIVE_CLASS)) { + // Mark node as injected, to prevent re-runs + node.classList.add(ACTIVE_CLASS); + + var uuid; + if (window.crypto && window.crypto.randomUUID) { + uuid = window.crypto.randomUUID(); + } else { + uuid = "" + Math.random(); + } + + var frame = document.createElement("iframe"); + frame.width = node.dataset.width || 800; + frame.height = node.dataset.height || 800; + frame.style.cssText = `width:${frame.width}px;height:${frame.height}px;max-width:none;border:0;display:block` + frame.src = VIEWER_URL + "?uuid=" + uuid; + + var postMessageToViewer; + postMessageToViewer = event => { + if(event.data.type === 'viewer' && event.data.uuid===uuid){ + frame.contentWindow.postMessage({ type: 'jupyter', uuid, value: node.dataset.simulation, fileType: 'hdf5'}, '*'); + + // Run once only + window.removeEventListener('message', postMessageToViewer); + } + }; + window.addEventListener( + 'message', + postMessageToViewer, + false + ); + + node.appendChild(frame); + } + } + } + + if (!window.simulationViewerInjector) { + window.simulationViewerInjector = new SimulationViewerInjector(); + } + })(); + """ + html_code = f""" +
+ + """ + + return display(HTML(html_code)) diff --git a/tidy3d/_common/components/viz/styles.py b/tidy3d/_common/components/viz/styles.py new file mode 100644 index 0000000000..067afa9327 --- /dev/null +++ b/tidy3d/_common/components/viz/styles.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +try: + from matplotlib.patches import ArrowStyle + + arrow_style = ArrowStyle.Simple(head_length=11, head_width=9, tail_width=4) +except ImportError: + arrow_style = None + +FLEXCOMPUTE_COLORS = { + "brand_green": "#00643C", + "brand_tan": "#B8A18B", + "brand_blue": "#6DB5DD", + "brand_purple": "#8851AD", + "brand_black": "#000000", + "brand_orange": "#FC7A4C", +} +ARROW_COLOR_SOURCE = FLEXCOMPUTE_COLORS["brand_green"] +ARROW_COLOR_POLARIZATION = FLEXCOMPUTE_COLORS["brand_tan"] +ARROW_COLOR_MONITOR = FLEXCOMPUTE_COLORS["brand_orange"] +ARROW_COLOR_ABSORBER = FLEXCOMPUTE_COLORS["brand_blue"] +PLOT_BUFFER = 0.3 +ARROW_ALPHA = 0.8 +ARROW_LENGTH = 0.3 + +# stores color of simulation.structures for given index in simulation.medium_map +MEDIUM_CMAP = [ + "#689DBC", + "#D0698E", + "#5E6EAD", + "#C6224E", + "#BDB3E2", + "#9EC3E0", + "#616161", + "#877EBC", +] + +# colormap for structure's permittivity in plot_eps +STRUCTURE_EPS_CMAP = "gist_yarg" +STRUCTURE_EPS_CMAP_R = "gist_yarg_r" +STRUCTURE_HEAT_COND_CMAP = "gist_yarg" diff --git a/tidy3d/_common/components/viz/visualization_spec.py b/tidy3d/_common/components/viz/visualization_spec.py new file mode 100644 index 0000000000..3d4c6d4afb --- /dev/null +++ b/tidy3d/_common/components/viz/visualization_spec.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import Field, field_validator + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.log import log + +if TYPE_CHECKING: + from pydantic import ValidationInfo + +if TYPE_CHECKING: + from pydantic import ValidationInfo + +MATPLOTLIB_IMPORTED = True +try: + from matplotlib.colors import 
is_color_like +except ImportError: + is_color_like = None + MATPLOTLIB_IMPORTED = False + + +def is_valid_color(value: str) -> str: + if not MATPLOTLIB_IMPORTED: + log.warning( + "matplotlib was not successfully imported, but is required " + "to validate colors in the VisualizationSpec. The specified colors " + "have not been validated." + ) + else: + if is_color_like is not None and not is_color_like(value): + raise ValueError(f"{value} is not a valid plotting color") + + return value + + +class VisualizationSpec(Tidy3dBaseModel): + """Defines specification for visualization when used with plotting functions.""" + + facecolor: str = Field( + "", + title="Face color", + description="Color applied to the faces in visualization.", + ) + + edgecolor: str = Field( + "", + title="Edge color", + description="Color applied to the edges in visualization.", + ) + + alpha: float = Field( + 1.0, + title="Opacity", + description="Opacity/alpha value in plotting between 0 and 1.", + ge=0, + le=1, + ) + + @field_validator("facecolor") + @classmethod + def _validate_facecolor(cls, value: str) -> str: + return is_valid_color(value) + + @field_validator("edgecolor") + @classmethod + def _ensure_edgecolor(cls, value: str, info: ValidationInfo) -> str: + # if no explicit edgecolor given, fall back to facecolor + if (value == "") and "facecolor" in info.data: + return is_valid_color(info.data["facecolor"]) + return is_valid_color(value) diff --git a/tidy3d/config/README.md b/tidy3d/_common/config/README.md similarity index 100% rename from tidy3d/config/README.md rename to tidy3d/_common/config/README.md diff --git a/tidy3d/_common/config/__init__.py b/tidy3d/_common/config/__init__.py new file mode 100644 index 0000000000..4be3a4ab1a --- /dev/null +++ b/tidy3d/_common/config/__init__.py @@ -0,0 +1,85 @@ +"""Tidy3D configuration system public API.""" + +from __future__ import annotations + +from typing import Any + +from .legacy import ( + LegacyConfigWrapper, + LegacyEnvironment, + 
LegacyEnvironmentConfig, +) +from .manager import ConfigManager +from .registry import ( + get_handlers, + get_sections, + register_handler, + register_plugin, + register_section, +) + +__all__ = [ + "ConfigManager", + "Env", + "Environment", + "EnvironmentConfig", + "config", + "get_handlers", + "get_sections", + "register_handler", + "register_plugin", + "register_section", +] + + +def _create_manager() -> ConfigManager: + return ConfigManager() + + +_base_manager = _create_manager() +# TODO(FXC-3827): Drop LegacyConfigWrapper once legacy accessors are removed in Tidy3D 2.12. +_config_wrapper = LegacyConfigWrapper(_base_manager) +config = _config_wrapper + +# TODO(FXC-3827): Remove legacy Env exports after deprecation window (planned 2.12). +Environment = LegacyEnvironment +EnvironmentConfig = LegacyEnvironmentConfig +Env: LegacyEnvironment | None = None + + +def initialize_env() -> None: + """Initialize legacy Env after sections register.""" + + global Env + if Env is None: + Env = LegacyEnvironment(_base_manager) + + +def reload_config(*, profile: str | None = None) -> LegacyConfigWrapper: + """Recreate the global configuration manager (primarily for tests).""" + + global _base_manager, Env + if _base_manager is not None: + try: + _base_manager.apply_web_env({}) + except AttributeError: + pass + _base_manager = ConfigManager(profile=profile) + _config_wrapper.reset_manager(_base_manager) + if Env is None: + initialize_env() + Env.reset_manager(_base_manager) + return _config_wrapper + + +def get_manager() -> ConfigManager: + """Return the underlying configuration manager instance.""" + + return _base_manager + + +def __getattr__(name: str) -> Any: + if name == "Env": + initialize_env() + return Env + return getattr(config, name) diff --git a/tidy3d/_common/config/legacy.py b/tidy3d/_common/config/legacy.py new file mode 100644 index 0000000000..015600349e --- /dev/null +++ b/tidy3d/_common/config/legacy.py @@ -0,0 +1,541 @@ +"""Legacy compatibility layer for 
tidy3d.config. + +This module holds (most) of the compatibility layer to the pre-2.10 tidy3d config +and is intended to be removed in a future release. +""" + +from __future__ import annotations + +import os +import warnings +from typing import TYPE_CHECKING, Any + +import toml + +from tidy3d._common._runtime import WASM_BUILD +from tidy3d._common.log import log + +# TODO(FXC-3827): Remove LegacyConfigWrapper/Environment shims and related helpers in Tidy3D 2.12. +from .manager import ConfigManager, normalize_profile_name +from .profiles import BUILTIN_PROFILES + +if TYPE_CHECKING: + from pathlib import Path + from typing import Optional + + from tidy3d._common.log import LogLevel + + +def _warn_env_deprecated() -> None: + message = "'tidy3d.config.Env' is deprecated; use 'config.switch_profile(...)' instead." + warnings.warn(message, DeprecationWarning, stacklevel=3) + log.warning(message, log_once=True) + + +# TODO(FXC-3827): Delete LegacyConfigWrapper once legacy attribute access is dropped. 
+class LegacyConfigWrapper: + """Provide attribute-level compatibility with the legacy config module.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + self._frozen = False # retained for backwards compatibility tests + + @property + def logging_level(self) -> LogLevel: + return self._manager.get_section("logging").level + + @logging_level.setter + def logging_level(self, value: LogLevel) -> None: + from warnings import warn + + warn( + "'config.logging_level' is deprecated; use 'config.logging.level' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("logging", level=value) + + @property + def log_suppression(self) -> bool: + return self._manager.get_section("logging").suppression + + @log_suppression.setter + def log_suppression(self, value: bool) -> None: + from warnings import warn + + warn( + "'config.log_suppression' is deprecated; use 'config.logging.suppression'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("logging", suppression=value) + + @property + def use_local_subpixel(self) -> Optional[bool]: + return self._manager.get_section("simulation").use_local_subpixel + + @use_local_subpixel.setter + def use_local_subpixel(self, value: Optional[bool]) -> None: + from warnings import warn + + warn( + "'config.use_local_subpixel' is deprecated; use 'config.simulation.use_local_subpixel'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("simulation", use_local_subpixel=value) + + @property + def suppress_rf_license_warning(self) -> bool: + return self._manager.get_section("microwave").suppress_rf_license_warning + + @suppress_rf_license_warning.setter + def suppress_rf_license_warning(self, value: bool) -> None: + from warnings import warn + + warn( + "'config.suppress_rf_license_warning' is deprecated; " + "use 'config.microwave.suppress_rf_license_warning'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("microwave", 
suppress_rf_license_warning=value) + + @property + def frozen(self) -> bool: + return self._frozen + + @frozen.setter + def frozen(self, value: bool) -> None: + self._frozen = bool(value) + + def save(self, include_defaults: bool = False) -> None: + self._manager.save(include_defaults=include_defaults) + + def reset_manager(self, manager: ConfigManager) -> None: + """Swap the underlying manager instance.""" + + self._manager = manager + + def switch_profile(self, profile: str) -> None: + """Switch active profile and synchronize the legacy environment proxy.""" + + normalized = normalize_profile_name(profile) + self._manager.switch_profile(normalized) + try: + from tidy3d._common.config import Env as _legacy_env + except Exception: + _legacy_env = None + if _legacy_env is not None: + _legacy_env._sync_to_manager(apply_env=True) + + def __getattr__(self, name: str) -> Any: + return getattr(self._manager, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + elif name in { + "logging_level", + "log_suppression", + "use_local_subpixel", + "suppress_rf_license_warning", + "frozen", + }: + prop = getattr(type(self), name) + prop.fset(self, value) + else: + setattr(self._manager, name, value) + + def __str__(self) -> str: + return self._manager.format() + + +# TODO(FXC-3827): Delete LegacyEnvironmentConfig once profile-based Env shim is removed. 
+class LegacyEnvironmentConfig: + """Backward compatible environment config wrapper that proxies ConfigManager.""" + + def __init__( + self, + manager: Optional[ConfigManager] = None, + name: Optional[str] = None, + *, + web_api_endpoint: Optional[str] = None, + website_endpoint: Optional[str] = None, + s3_region: Optional[str] = None, + ssl_verify: Optional[bool] = None, + enable_caching: Optional[bool] = None, + ssl_version: Optional[str] = None, + env_vars: Optional[dict[str, str]] = None, + environment: Optional[LegacyEnvironment] = None, + ) -> None: + if name is None: + raise ValueError("Environment name is required") + self._manager = manager + self._name = normalize_profile_name(name) + self._environment = environment + self._pending: dict[str, Any] = {} + if web_api_endpoint is not None: + self._pending["api_endpoint"] = web_api_endpoint + if website_endpoint is not None: + self._pending["website_endpoint"] = website_endpoint + if s3_region is not None: + self._pending["s3_region"] = s3_region + if ssl_verify is not None: + self._pending["ssl_verify"] = ssl_verify + if enable_caching is not None: + self._pending["enable_caching"] = enable_caching + if ssl_version is not None: + self._pending["ssl_version"] = ssl_version + if env_vars is not None: + self._pending["env_vars"] = dict(env_vars) + + def reset_manager(self, manager: ConfigManager) -> None: + self._manager = manager + + @property + def manager(self) -> Optional[ConfigManager]: + if self._manager is not None: + return self._manager + if self._environment is not None: + return self._environment._manager + return None + + def active(self) -> None: + _warn_env_deprecated() + environment = self._environment + if environment is None: + from tidy3d._common.config import Env # local import to avoid circular + + environment = Env + + environment.set_current(self) + + @property + def web_api_endpoint(self) -> Optional[str]: + value = self._value("api_endpoint") + return _maybe_str(value) + + @property + 
def website_endpoint(self) -> Optional[str]: + value = self._value("website_endpoint") + return _maybe_str(value) + + @property + def s3_region(self) -> Optional[str]: + return self._value("s3_region") + + @property + def ssl_verify(self) -> bool: + value = self._value("ssl_verify") + if value is None: + return True + return bool(value) + + @property + def enable_caching(self) -> bool: + value = self._value("enable_caching") + if value is None: + return True + return bool(value) + + @enable_caching.setter + def enable_caching(self, value: Optional[bool]) -> None: + self._set_pending("enable_caching", value) + + @property + def ssl_version(self) -> Optional[str]: + return self._value("ssl_version") + + @ssl_version.setter + def ssl_version(self, value: Optional[str]) -> None: + self._set_pending("ssl_version", value) + + @property + def env_vars(self) -> dict[str, str]: + value = self._value("env_vars") + if value is None: + return {} + return dict(value) + + @env_vars.setter + def env_vars(self, value: dict[str, str]) -> None: + self._set_pending("env_vars", dict(value)) + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, value: str) -> None: + self._name = normalize_profile_name(value) + + def copy_state_from(self, other: LegacyEnvironmentConfig) -> None: + if not isinstance(other, LegacyEnvironmentConfig): + raise TypeError("Expected LegacyEnvironmentConfig instance.") + for key, value in other._pending.items(): + if key == "env_vars" and value is not None: + self._pending[key] = dict(value) + else: + self._pending[key] = value + + def get_real_url(self, path: str) -> str: + manager = self.manager + if manager is not None and manager.profile == self._name: + web_section = manager.get_section("web") + if hasattr(web_section, "build_api_url"): + return web_section.build_api_url(path) + + endpoint = self.web_api_endpoint or "" + if not path: + return endpoint + return "/".join([endpoint.rstrip("/"), str(path).lstrip("/")]) + 
+ def apply_pending_overrides(self) -> None: + manager = self.manager + if manager is None or manager.profile != self._name: + return + if not self._pending: + return + updates = dict(self._pending) + manager.update_section("web", **updates) + self._pending.clear() + + def _set_pending(self, key: str, value: Any) -> None: + if key == "env_vars" and value is not None: + self._pending[key] = dict(value) + else: + self._pending[key] = value + self.apply_pending_overrides() + + def _web_section(self) -> dict[str, Any]: + manager = self.manager + if manager is None or WASM_BUILD: + return {} + profile = normalize_profile_name(self._name) + if manager.profile == profile: + section = manager.get_section("web") + return section.model_dump(mode="python", exclude_unset=False) + preview = manager.preview_profile(profile) + source = preview.get("web", {}) + return dict(source) if isinstance(source, dict) else {} + + def _value(self, key: str) -> Any: + if key in self._pending: + return self._pending[key] + return self._web_section().get(key) + + +# TODO(FXC-3827): Delete LegacyEnvironment after deprecating `tidy3d.config.Env`. 
+class LegacyEnvironment: + """Legacy Env wrapper that maps to profiles.""" + + def __init__(self, manager: ConfigManager): + self._previous_env_vars: dict[str, Optional[str]] = {} + self.env_map: dict[str, LegacyEnvironmentConfig] = {} + self._current: Optional[LegacyEnvironmentConfig] = None + self._manager: Optional[ConfigManager] = None + self._applied_profile: Optional[str] = None + self.reset_manager(manager) + + def reset_manager(self, manager: ConfigManager) -> None: + self._manager = manager + self.env_map = {} + for name in BUILTIN_PROFILES: + key = normalize_profile_name(name) + self.env_map[key] = LegacyEnvironmentConfig(manager, key, environment=self) + self._applied_profile = None + self._current = None + self._sync_to_manager(apply_env=True) + + @property + def current(self) -> LegacyEnvironmentConfig: + self._sync_to_manager() + assert self._current is not None + return self._current + + def set_current(self, env_config: LegacyEnvironmentConfig) -> None: + _warn_env_deprecated() + key = normalize_profile_name(env_config.name) + stored = self._get_config(key) + stored.copy_state_from(env_config) + if self._manager and self._manager.profile != key: + self._manager.switch_profile(key) + self._sync_to_manager(apply_env=True) + + def enable_caching(self, enable_caching: Optional[bool] = True) -> None: + config = self.current + config.enable_caching = enable_caching + self._sync_to_manager() + + def set_ssl_version(self, ssl_version: Optional[str]) -> None: + config = self.current + config.ssl_version = ssl_version + self._sync_to_manager() + + def __getattr__(self, name: str) -> LegacyEnvironmentConfig: + return self._get_config(name) + + def _get_config(self, name: str) -> LegacyEnvironmentConfig: + key = normalize_profile_name(name) + config = self.env_map.get(key) + if config is None: + config = LegacyEnvironmentConfig(self._manager, key, environment=self) + self.env_map[key] = config + else: + manager = self._manager + if manager is not None: + 
config.reset_manager(manager) + config._environment = self + return config + + def _sync_to_manager(self, *, apply_env: bool = False) -> None: + if self._manager is None: + return + active = normalize_profile_name(self._manager.profile) + config = self._get_config(active) + config.apply_pending_overrides() + self._current = config + if apply_env or self._applied_profile != active: + self._apply_env_vars(config) + self._applied_profile = active + + def _apply_env_vars(self, config: LegacyEnvironmentConfig) -> None: + self._restore_env_vars() + env_vars = config.env_vars or {} + self._previous_env_vars = {} + for key, value in env_vars.items(): + self._previous_env_vars[key] = os.environ.get(key) + os.environ[key] = value + + def _restore_env_vars(self) -> None: + for key, previous in self._previous_env_vars.items(): + if previous is None: + os.environ.pop(key, None) + else: + os.environ[key] = previous + self._previous_env_vars = {} + + +def _maybe_str(value: Any) -> Optional[str]: + if value is None: + return None + return str(value) + + +def load_legacy_flat_config(config_dir: Path) -> dict[str, Any]: + """Load legacy flat configuration file (pre-migration format). + + This function now supports both the original flat config format and + Nexus custom deployment settings introduced in later versions. 
+ + Legacy key mappings: + - apikey -> web.apikey + - web_api_endpoint -> web.api_endpoint + - website_endpoint -> web.website_endpoint + - s3_region -> web.s3_region + - s3_endpoint -> web.env_vars.AWS_ENDPOINT_URL_S3 + - ssl_verify -> web.ssl_verify + - enable_caching -> web.enable_caching + """ + + legacy_path = config_dir / "config" + if not legacy_path.exists(): + return {} + + try: + text = legacy_path.read_text(encoding="utf-8") + except Exception as exc: + log.warning(f"Failed to read legacy configuration file '{legacy_path}': {exc}") + return {} + + try: + parsed = toml.loads(text) + except Exception as exc: + log.warning(f"Failed to decode legacy configuration file '{legacy_path}': {exc}") + return {} + + legacy_data: dict[str, Any] = {} + + # Migrate API key (original functionality) + apikey = parsed.get("apikey") + if apikey is not None: + legacy_data.setdefault("web", {})["apikey"] = apikey + + # Migrate Nexus API endpoint + web_api = parsed.get("web_api_endpoint") + if web_api is not None: + legacy_data.setdefault("web", {})["api_endpoint"] = web_api + + # Migrate Nexus website endpoint + website = parsed.get("website_endpoint") + if website is not None: + legacy_data.setdefault("web", {})["website_endpoint"] = website + + # Migrate S3 region + s3_region = parsed.get("s3_region") + if s3_region is not None: + legacy_data.setdefault("web", {})["s3_region"] = s3_region + + # Migrate SSL verification setting + ssl_verify = parsed.get("ssl_verify") + if ssl_verify is not None: + legacy_data.setdefault("web", {})["ssl_verify"] = ssl_verify + + # Migrate caching setting + enable_caching = parsed.get("enable_caching") + if enable_caching is not None: + legacy_data.setdefault("web", {})["enable_caching"] = enable_caching + + # Migrate S3 endpoint to env_vars + s3_endpoint = parsed.get("s3_endpoint") + if s3_endpoint is not None: + env_vars = legacy_data.setdefault("web", {}).setdefault("env_vars", {}) + env_vars["AWS_ENDPOINT_URL_S3"] = s3_endpoint + + return 
legacy_data + + +__all__ = [ + "LegacyConfigWrapper", + "LegacyEnvironment", + "LegacyEnvironmentConfig", + "finalize_legacy_migration", + "load_legacy_flat_config", +] + + +def finalize_legacy_migration(config_dir: Path) -> None: + """Promote a copied legacy configuration tree into the structured format. + + Parameters + ---------- + config_dir : Path + Destination directory (typically the canonical config location). + """ + + legacy_data = load_legacy_flat_config(config_dir) + + from .manager import ConfigManager # local import to avoid circular dependency + + manager = ConfigManager(profile="default", config_dir=config_dir) + config_path = config_dir / "config.toml" + for section, values in legacy_data.items(): + if isinstance(values, dict): + manager.update_section(section, **values) + try: + manager.save(include_defaults=True) + except Exception: + if config_path.exists(): + try: + config_path.unlink() + except Exception: + pass + raise + + legacy_flat_path = config_dir / "config" + if legacy_flat_path.exists(): + try: + legacy_flat_path.unlink() + except Exception as exc: + log.warning(f"Failed to remove legacy configuration file '{legacy_flat_path}': {exc}") diff --git a/tidy3d/_common/config/loader.py b/tidy3d/_common/config/loader.py new file mode 100644 index 0000000000..6875fcb4c8 --- /dev/null +++ b/tidy3d/_common/config/loader.py @@ -0,0 +1,451 @@ +"""Filesystem helpers and persistence utilities for the configuration system.""" + +from __future__ import annotations + +import os +import shutil +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import toml +import tomlkit + +from tidy3d._common.log import log + +from .profiles import BUILTIN_PROFILES +from .serializer import build_document, collect_descriptions + +if TYPE_CHECKING: + from typing import Optional + + +class ConfigLoader: + """Handle reading and writing configuration files.""" + + def __init__(self, config_dir: Optional[Path] = 
None): + self.config_dir = config_dir or resolve_config_directory() + self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + self._docs: dict[Path, tomlkit.TOMLDocument] = {} + + def load_base(self) -> dict[str, Any]: + """Load base configuration from config.toml. + + If config.toml doesn't exist but the legacy flat config does, + automatically migrate to the new format. + """ + + config_path = self.config_dir / "config.toml" + data = self._read_toml(config_path) + if data: + return data + + # Check for legacy flat config + from .legacy import load_legacy_flat_config + + legacy_path = self.config_dir / "config" + legacy = load_legacy_flat_config(self.config_dir) + + # Auto-migrate if legacy config exists + if legacy and legacy_path.exists(): + log.info( + f"Detected legacy configuration at '{legacy_path}'. " + "Automatically migrating to new format..." + ) + + try: + # Save in new format + self.save_base(legacy) + + # Rename old config to preserve it + backup_path = legacy_path.with_suffix(".migrated") + legacy_path.rename(backup_path) + + log.info( + f"Migration complete. Configuration saved to '{config_path}'. " + f"Legacy config backed up as '{backup_path.name}'." + ) + + # Re-read the newly created config + return self._read_toml(config_path) + except Exception as exc: + log.warning( + f"Failed to auto-migrate legacy configuration: {exc}. " + "Using legacy data without migration." 
+ ) + return legacy + + if legacy: + return legacy + return {} + + def load_user_profile(self, profile: str) -> dict[str, Any]: + """Load user profile overrides (if any).""" + + if profile in ("default", "prod"): + # default and prod share the same baseline; user overrides live in config.toml + return {} + + profile_path = self.profile_path(profile) + return self._read_toml(profile_path) + + def get_builtin_profile(self, profile: str) -> dict[str, Any]: + """Return builtin profile data if available.""" + + return BUILTIN_PROFILES.get(profile, {}) + + def save_base(self, data: dict[str, Any]) -> None: + """Persist base configuration.""" + + config_path = self.config_dir / "config.toml" + self._atomic_write(config_path, data) + + def save_profile(self, profile: str, data: dict[str, Any]) -> None: + """Persist profile overrides (remove file if empty).""" + + profile_path = self.profile_path(profile) + if not data: + if profile_path.exists(): + profile_path.unlink() + self._docs.pop(profile_path, None) + return + profile_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) + self._atomic_write(profile_path, data) + + def profile_path(self, profile: str) -> Path: + """Return on-disk path for a profile.""" + + return self.config_dir / "profiles" / f"{profile}.toml" + + def get_default_profile(self) -> Optional[str]: + """Read the default_profile from config.toml. + + Returns + ------- + Optional[str] + The default profile name if set, None otherwise. + """ + + config_path = self.config_dir / "config.toml" + if not config_path.exists(): + return None + + try: + text = config_path.read_text(encoding="utf-8") + data = toml.loads(text) + return data.get("default_profile") + except Exception as exc: + log.warning(f"Failed to read default_profile from '{config_path}': {exc}") + return None + + def set_default_profile(self, profile: Optional[str]) -> None: + """Set the default_profile in config.toml. 
+ + Parameters + ---------- + profile : Optional[str] + The profile name to set as default, or None to remove the setting. + """ + + config_path = self.config_dir / "config.toml" + data = self._read_toml(config_path) + + if profile is None: + # Remove default_profile if it exists + if "default_profile" in data: + del data["default_profile"] + else: + # Set default_profile as a top-level key + data["default_profile"] = profile + + self._atomic_write(config_path, data) + + def _read_toml(self, path: Path) -> dict[str, Any]: + if not path.exists(): + self._docs.pop(path, None) + return {} + + try: + text = path.read_text(encoding="utf-8") + except Exception as exc: + log.warning(f"Failed to read configuration file '{path}': {exc}") + self._docs.pop(path, None) + return {} + + try: + document = tomlkit.parse(text) + except Exception as exc: + log.warning(f"Failed to parse configuration file '{path}': {exc}") + document = tomlkit.document() + self._docs[path] = document + + try: + return toml.loads(text) + except Exception as exc: + log.warning(f"Failed to decode configuration file '{path}': {exc}") + return {} + + def _atomic_write(self, path: Path, data: dict[str, Any]) -> None: + path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) + tmp_dir = path.parent + + cleaned = _clean_data(deepcopy(data)) + descriptions = collect_descriptions() + + base_document = self._docs.get(path) + document = build_document(cleaned, base_document, descriptions) + toml_text = tomlkit.dumps(document) + + with tempfile.NamedTemporaryFile( + "w", dir=tmp_dir, delete=False, encoding="utf-8" + ) as handle: + tmp_path = Path(handle.name) + handle.write(toml_text) + handle.flush() + os.fsync(handle.fileno()) + + backup_path = path.with_suffix(path.suffix + ".bak") + try: + if path.exists(): + shutil.copy2(path, backup_path) + tmp_path.replace(path) + os.chmod(path, 0o600) + if backup_path.exists(): + backup_path.unlink() + except Exception: + if tmp_path.exists(): + tmp_path.unlink() + if 
backup_path.exists(): + try: + backup_path.replace(path) + except Exception: + log.warning("Failed to restore configuration backup") + raise + + self._docs[path] = tomlkit.parse(toml_text) + + +def load_environment_overrides() -> dict[str, Any]: + """Parse environment variables into a nested configuration dict.""" + + overrides: dict[str, Any] = {} + for key, value in os.environ.items(): + if key == "SIMCLOUD_APIKEY": + _assign_path(overrides, ("web", "apikey"), value) + continue + if not key.startswith("TIDY3D_"): + continue + rest = key[len("TIDY3D_") :] + if "__" not in rest: + continue + segments = tuple(segment.lower() for segment in rest.split("__") if segment) + if not segments: + continue + if segments[0] == "auth": + segments = ("web",) + segments[1:] + _assign_path(overrides, segments, value) + return overrides + + +def deep_merge(*sources: dict[str, Any]) -> dict[str, Any]: + """Deep merge multiple dictionaries into a new dict.""" + + result: dict[str, Any] = {} + for source in sources: + _merge_into(result, source) + return result + + +def _merge_into(target: dict[str, Any], source: dict[str, Any]) -> None: + for key, value in source.items(): + if isinstance(value, dict): + node = target.setdefault(key, {}) + if isinstance(node, dict): + _merge_into(node, value) + else: + target[key] = deepcopy(value) + else: + target[key] = value + + +def deep_diff(base: dict[str, Any], target: dict[str, Any]) -> dict[str, Any]: + """Return keys from target that differ from base.""" + + diff: dict[str, Any] = {} + keys = set(base.keys()) | set(target.keys()) + for key in keys: + base_value = base.get(key) + target_value = target.get(key) + if isinstance(base_value, dict) and isinstance(target_value, dict): + nested = deep_diff(base_value, target_value) + if nested: + diff[key] = nested + elif target_value != base_value: + if isinstance(target_value, dict): + diff[key] = deepcopy(target_value) + else: + diff[key] = target_value + return diff + + +def 
_assign_path(target: dict[str, Any], path: tuple[str, ...], value: Any) -> None: + node = target + for segment in path[:-1]: + node = node.setdefault(segment, {}) + node[path[-1]] = value + + +def _clean_data(data: Any) -> Any: + if isinstance(data, dict): + cleaned: dict[str, Any] = {} + for key, value in data.items(): + cleaned_value = _clean_data(value) + if cleaned_value is None: + continue + cleaned[key] = cleaned_value + return cleaned + if isinstance(data, list): + cleaned_list = [_clean_data(item) for item in data] + return [item for item in cleaned_list if item is not None] + if data is None: + return None + return data + + +def legacy_config_directory() -> Path: + """Return the legacy configuration directory (~/.tidy3d).""" + + return Path.home() / ".tidy3d" + + +def canonical_config_directory() -> Path: + """Return the platform-dependent canonical configuration directory.""" + + return _xdg_config_home() / "tidy3d" + + +def _warn_legacy_dir_ignored(*, canonical_dir: Path, legacy_dir: Path) -> None: + if legacy_dir.exists(): + log.warning( + f"Using canonical configuration directory at '{canonical_dir}'. " + "Found legacy directory at '~/.tidy3d', which will be ignored. " + "Remove it manually or run 'tidy3d config migrate --delete-legacy' to clean up.", + log_once=True, + ) + + +def resolve_config_directory() -> Path: + """Determine the directory used to store tidy3d configuration files.""" + + base_override = os.getenv("TIDY3D_BASE_DIR") + if base_override: + base_path = Path(base_override).expanduser().resolve() + path = base_path / "config" + if path.is_dir(): + return path + if _is_writable(path.parent): + return path + log.warning( + "'TIDY3D_BASE_DIR' is not writable; using temporary configuration directory instead." 
+ ) + return _temporary_config_dir() + + canonical_dir = canonical_config_directory() + legacy_dir = legacy_config_directory() + if canonical_dir.is_dir(): + _warn_legacy_dir_ignored(canonical_dir=canonical_dir, legacy_dir=legacy_dir) + return canonical_dir + if _is_writable(canonical_dir.parent): + _warn_legacy_dir_ignored(canonical_dir=canonical_dir, legacy_dir=legacy_dir) + return canonical_dir + + if legacy_dir.exists(): + log.warning( + "Configuration found in legacy location '~/.tidy3d'. Consider running 'tidy3d config migrate'.", + log_once=True, + ) + return legacy_dir + + log.warning(f"Unable to write to '{canonical_dir}'; falling back to temporary directory.") + return _temporary_config_dir() + + +def _xdg_config_home() -> Path: + xdg_home = os.getenv("XDG_CONFIG_HOME") + if xdg_home: + return Path(xdg_home).expanduser() + return Path.home() / ".config" + + +def _temporary_config_dir() -> Path: + base = Path(tempfile.gettempdir()) / "tidy3d" + base.mkdir(mode=0o700, exist_ok=True) + return base / "config" + + +def _is_writable(path: Path) -> bool: + try: + path.mkdir(parents=True, exist_ok=True) + fd, test_path = tempfile.mkstemp(dir=path, prefix=".tidy3d_write_test_") + os.close(fd) + try: + Path(test_path).unlink() + except FileNotFoundError: + pass + return True + except Exception: + return False + + +def migrate_legacy_config(*, overwrite: bool = False, remove_legacy: bool = False) -> Path: + """Copy configuration files from the legacy ``~/.tidy3d`` directory to the canonical location. + + Parameters + ---------- + overwrite : bool + If ``True``, existing files in the canonical directory will be replaced. + remove_legacy : bool + If ``True``, the legacy directory is removed after a successful migration. + + Returns + ------- + Path + The path of the canonical configuration directory. + + Raises + ------ + FileNotFoundError + If the legacy directory does not exist. + FileExistsError + If the destination already exists and ``overwrite`` is ``False``. 
+ RuntimeError + If the legacy and canonical directories resolve to the same location. + """ + + legacy_dir = legacy_config_directory() + if not legacy_dir.exists(): + raise FileNotFoundError("Legacy configuration directory '~/.tidy3d' was not found.") + + canonical_dir = canonical_config_directory() + if canonical_dir.resolve() == legacy_dir.resolve(): + raise RuntimeError( + "Legacy and canonical configuration directories are the same path; nothing to migrate." + ) + + if canonical_dir.exists() and not overwrite: + raise FileExistsError( + f"Destination '{canonical_dir}' already exists. Pass overwrite=True to replace existing files." + ) + + canonical_dir.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(legacy_dir, canonical_dir, dirs_exist_ok=overwrite) + + from .legacy import finalize_legacy_migration # local import to avoid circular dependency + + finalize_legacy_migration(canonical_dir) + + if remove_legacy: + shutil.rmtree(legacy_dir) + + return canonical_dir diff --git a/tidy3d/_common/config/manager.py b/tidy3d/_common/config/manager.py new file mode 100644 index 0000000000..ffd1277913 --- /dev/null +++ b/tidy3d/_common/config/manager.py @@ -0,0 +1,634 @@ +"""Central configuration manager implementation.""" + +from __future__ import annotations + +import os +import shutil +from collections import defaultdict +from copy import deepcopy +from enum import Enum +from io import StringIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, get_args, get_origin + +from pydantic import BaseModel +from rich.console import Console +from rich.panel import Panel +from rich.pretty import Pretty +from rich.text import Text +from rich.tree import Tree + +from tidy3d._common.log import log + +from .loader import ConfigLoader, deep_diff, deep_merge, load_environment_overrides +from .profiles import BUILTIN_PROFILES +from .registry import attach_manager, get_handlers, get_sections + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + 
from typing import Optional + + +def normalize_profile_name(name: str) -> str: + """Return a canonical profile name for builtin profiles.""" + + normalized = name.strip() + lowered = normalized.lower() + if lowered in BUILTIN_PROFILES: + return lowered + return normalized + + +class SectionAccessor: + """Attribute proxy that routes assignments back through the manager.""" + + def __init__(self, manager: ConfigManager, path: str): + self._manager = manager + self._path = path + + def __getattr__(self, name: str) -> Any: + model = self._manager._get_model(self._path) + if model is None: + raise AttributeError(f"Section '{self._path}' is not available") + return getattr(model, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + return + self._manager.update_section(self._path, **{name: value}) + + def __repr__(self) -> str: + model = self._manager._get_model(self._path) + return f"SectionAccessor({self._path}={model!r})" + + def __rich__(self) -> Panel: + model = self._manager._get_model(self._path) + if model is None: + return Panel(Text(f"Section '{self._path}' is unavailable", style="red")) + data = _prepare_for_display(model.model_dump(exclude_unset=False)) + return _build_section_panel(self._path, data) + + def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + model = self._manager._get_model(self._path) + if model is None: + return {} + return model.model_dump(*args, **kwargs) + + def __str__(self) -> str: + return self._manager.format_section(self._path) + + +class PluginsAccessor: + """Provides access to registered plugin configurations.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + + def __getattr__(self, plugin: str) -> SectionAccessor: + if plugin not in self._manager._plugin_models: + raise AttributeError(f"Plugin '{plugin}' is not registered") + return SectionAccessor(self._manager, f"plugins.{plugin}") + + def list(self) -> 
Iterable[str]: + return sorted(self._manager._plugin_models.keys()) + + +class ProfilesAccessor: + """Read-only profile helper.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + + def list(self) -> dict[str, list[str]]: + return self._manager.list_profiles() + + def __getattr__(self, profile: str) -> dict[str, Any]: + return self._manager.preview_profile(profile) + + +class ConfigManager: + """High-level orchestrator for tidy3d configuration.""" + + def __init__( + self, + profile: Optional[str] = None, + config_dir: Optional[os.PathLike[str]] = None, + ): + loader_path = None if config_dir is None else Path(config_dir) + self._loader = ConfigLoader(loader_path) + self._runtime_overrides: dict[str, dict[str, Any]] = defaultdict(dict) + self._plugin_models: dict[str, BaseModel] = {} + self._section_models: dict[str, BaseModel] = {} + self._profile = self._resolve_initial_profile(profile) + self._builtin_data: dict[str, Any] = {} + self._base_data: dict[str, Any] = {} + self._profile_data: dict[str, Any] = {} + self._raw_tree: dict[str, Any] = {} + self._effective_tree: dict[str, Any] = {} + self._env_overrides: dict[str, Any] = load_environment_overrides() + self._web_env_previous: dict[str, Optional[str]] = {} + + attach_manager(self) + self._reload() + + # Notify users when using a non-default profile + if self._profile != "default": + log.info(f"Using configuration profile: '{self._profile}'", log_once=True) + + self._apply_handlers() + + @property + def profile(self) -> str: + return self._profile + + @property + def config_dir(self) -> Path: + return self._loader.config_dir + + @property + def plugins(self) -> PluginsAccessor: + return PluginsAccessor(self) + + @property + def profiles(self) -> ProfilesAccessor: + return ProfilesAccessor(self) + + def update_section(self, name: str, **updates: Any) -> None: + if not updates: + return + segments = name.split(".") + overrides = self._runtime_overrides[self._profile] + previous = 
deepcopy(overrides) + node = overrides + for segment in segments[:-1]: + node = node.setdefault(segment, {}) + section_key = segments[-1] + section_payload = node.setdefault(section_key, {}) + for key, value in updates.items(): + section_payload[key] = _serialize_value(value) + try: + self._reload() + except Exception: + self._runtime_overrides[self._profile] = previous + raise + self._apply_handlers(section=name) + + def switch_profile(self, profile: str) -> None: + if not profile: + raise ValueError("Profile name cannot be empty") + normalized = normalize_profile_name(profile) + if not normalized: + raise ValueError("Profile name cannot be empty") + self._profile = normalized + self._reload() + + # Notify users when switching to a non-default profile + if self._profile != "default": + log.info(f"Switched to configuration profile: '{self._profile}'") + + self._apply_handlers() + + def set_default_profile(self, profile: Optional[str]) -> None: + """Set the default profile to be used on startup. + + Parameters + ---------- + profile : Optional[str] + The profile name to use as default, or None to clear the default. + When set, this profile will be automatically loaded unless overridden + by environment variables (TIDY3D_CONFIG_PROFILE, TIDY3D_PROFILE, or TIDY3D_ENV). + + Notes + ----- + This setting is persisted to config.toml and survives across sessions. + Environment variables always take precedence over the default profile. + """ + + if profile is not None: + normalized = normalize_profile_name(profile) + if not normalized: + raise ValueError("Profile name cannot be empty") + self._loader.set_default_profile(normalized) + else: + self._loader.set_default_profile(None) + + def get_default_profile(self) -> Optional[str]: + """Get the currently configured default profile. + + Returns + ------- + Optional[str] + The default profile name if set, None otherwise. 
+ """ + + return self._loader.get_default_profile() + + def save(self, include_defaults: bool = False) -> None: + if self._profile == "default": + # For base config: only save fields marked with persist=True + base_without_env = self._filter_persisted(self._compose_without_env()) + if include_defaults: + defaults = self._filter_persisted(self._default_tree()) + base_without_env = deep_merge(defaults, base_without_env) + self._loader.save_base(base_without_env) + else: + # For profile overrides: save any field that differs from baseline + # (don't filter by persist flag - profiles should save all customizations) + base_without_env = self._compose_without_env() + baseline = deep_merge(self._builtin_data, self._base_data) + diff = deep_diff(baseline, base_without_env) + self._loader.save_profile(self._profile, diff) + # refresh cached base/profile data after saving + self._base_data = self._loader.load_base() + self._profile_data = self._loader.load_user_profile(self._profile) + self._reload() + + def reset_to_defaults(self, *, include_profiles: bool = True) -> None: + """Reset configuration files to their default annotated state.""" + + self._runtime_overrides = defaultdict(dict) + defaults = self._filter_persisted(self._default_tree()) + self._loader.save_base(defaults) + + if include_profiles: + profiles_dir = self._loader.profile_path("_dummy").parent + if profiles_dir.exists(): + shutil.rmtree(profiles_dir) + loader_docs = getattr(self._loader, "_docs", {}) + for path in list(loader_docs.keys()): + try: + path.relative_to(profiles_dir) + except ValueError: + continue + loader_docs.pop(path, None) + self._profile = "default" + + self._reload() + self._apply_handlers() + + def apply_web_env(self, env_vars: Mapping[str, str]) -> None: + """Apply environment variable overrides for the web configuration section.""" + + self._restore_web_env() + for key, value in env_vars.items(): + self._web_env_previous[key] = os.environ.get(key) + os.environ[key] = value + + def 
_restore_web_env(self) -> None: + """Restore previously overridden environment variables.""" + + for key, previous in self._web_env_previous.items(): + if previous is None: + os.environ.pop(key, None) + else: + os.environ[key] = previous + self._web_env_previous.clear() + + def list_profiles(self) -> dict[str, list[str]]: + profiles_dir = self._loader.config_dir / "profiles" + user_profiles = [] + if profiles_dir.exists(): + for path in profiles_dir.glob("*.toml"): + user_profiles.append(path.stem) + built_in = sorted(name for name in BUILTIN_PROFILES.keys()) + return {"built_in": built_in, "user": sorted(user_profiles)} + + def preview_profile(self, profile: str) -> dict[str, Any]: + builtin = self._loader.get_builtin_profile(profile) + base = self._loader.load_base() + overrides = self._loader.load_user_profile(profile) + view = deep_merge(builtin, base, overrides) + return deepcopy(view) + + def get_section(self, name: str) -> BaseModel: + model = self._get_model(name) + if model is None: + raise AttributeError(f"Section '{name}' is not available") + return model + + def as_dict(self, include_env: bool = True) -> dict[str, Any]: + """Return the current configuration tree, including defaults for all sections.""" + + tree = self._compose_without_env() + if include_env: + tree = deep_merge(tree, self._env_overrides) + return deep_merge(self._default_tree(), tree) + + def __rich__(self) -> Panel: + """Return a rich renderable representation of the full configuration.""" + + return _build_config_panel( + title=f"Config (profile='{self._profile}')", + data=_prepare_for_display(self.as_dict(include_env=True)), + ) + + def format(self, *, include_env: bool = True) -> str: + """Return a human-friendly representation of the full configuration.""" + + panel = _build_config_panel( + title=f"Config (profile='{self._profile}')", + data=_prepare_for_display(self.as_dict(include_env=include_env)), + ) + return _render_panel(panel) + + def format_section(self, name: str) -> str: 
+ """Return a string representation for an individual section.""" + + model = self._get_model(name) + if model is None: + raise AttributeError(f"Section '{name}' is not available") + data = _prepare_for_display(model.model_dump(exclude_unset=False)) + panel = _build_section_panel(name, data) + return _render_panel(panel) + + def on_section_registered(self, section: str) -> None: + self._reload() + self._apply_handlers(section=section) + + def on_handler_registered(self, section: str) -> None: + self._apply_handlers(section=section) + + def _resolve_initial_profile(self, profile: Optional[str]) -> str: + if profile: + return normalize_profile_name(str(profile)) + + # Check environment variables first (highest priority) + env_profile = ( + os.getenv("TIDY3D_CONFIG_PROFILE") + or os.getenv("TIDY3D_PROFILE") + or os.getenv("TIDY3D_ENV") + ) + if env_profile: + return normalize_profile_name(env_profile) + + # Check for default_profile in config file + config_default = self._loader.get_default_profile() + if config_default: + return normalize_profile_name(config_default) + + # Fall back to "default" profile + return "default" + + def _reload(self) -> None: + self._env_overrides = load_environment_overrides() + self._builtin_data = deepcopy(self._loader.get_builtin_profile(self._profile)) + self._base_data = deepcopy(self._loader.load_base()) + self._profile_data = deepcopy(self._loader.load_user_profile(self._profile)) + self._raw_tree = deep_merge(self._builtin_data, self._base_data, self._profile_data) + + runtime = deepcopy(self._runtime_overrides.get(self._profile, {})) + effective = deep_merge(self._raw_tree, self._env_overrides, runtime) + self._effective_tree = effective + self._build_models() + + def _build_models(self) -> None: + sections = get_sections() + new_sections: dict[str, BaseModel] = {} + new_plugins: dict[str, BaseModel] = {} + + errors: list[tuple[str, Exception]] = [] + for name, schema in sections.items(): + if name.startswith("plugins."): + 
plugin_name = name.split(".", 1)[1] + plugin_data = _deep_get(self._effective_tree, ("plugins", plugin_name)) or {} + try: + new_plugins[plugin_name] = schema(**plugin_data) + except Exception as exc: + log.error(f"Failed to load configuration for plugin '{plugin_name}': {exc}") + errors.append((name, exc)) + continue + if name == "plugins": + continue + section_data = self._effective_tree.get(name, {}) + try: + new_sections[name] = schema(**section_data) + except Exception as exc: + log.error(f"Failed to load configuration for section '{name}': {exc}") + errors.append((name, exc)) + + if errors: + # propagate the first error; others already logged + raise errors[0][1] + + self._section_models = new_sections + self._plugin_models = new_plugins + + def _get_model(self, name: str) -> Optional[BaseModel]: + if name.startswith("plugins."): + plugin = name.split(".", 1)[1] + return self._plugin_models.get(plugin) + return self._section_models.get(name) + + def _apply_handlers(self, section: Optional[str] = None) -> None: + handlers = get_handlers() + targets = [section] if section else handlers.keys() + for target in targets: + handler = handlers.get(target) + if handler is None: + continue + model = self._get_model(target) + if model is None: + continue + try: + handler(model) + except Exception as exc: + log.error(f"Failed to apply configuration handler for '{target}': {exc}") + + def _compose_without_env(self) -> dict[str, Any]: + runtime = self._runtime_overrides.get(self._profile, {}) + return deep_merge(self._raw_tree, runtime) + + def _default_tree(self) -> dict[str, Any]: + defaults: dict[str, Any] = {} + for name, schema in get_sections().items(): + if name.startswith("plugins."): + plugin = name.split(".", 1)[1] + defaults.setdefault("plugins", {})[plugin] = _model_dict(schema()) + elif name == "plugins": + defaults.setdefault("plugins", {}) + else: + defaults[name] = _model_dict(schema()) + return defaults + + def _filter_persisted(self, tree: dict[str, Any]) 
-> dict[str, Any]: + sections = get_sections() + filtered: dict[str, Any] = {} + plugins_source = tree.get("plugins", {}) + plugin_filtered: dict[str, Any] = {} + + for name, schema in sections.items(): + if name == "plugins": + continue + if name.startswith("plugins."): + plugin_name = name.split(".", 1)[1] + plugin_data = plugins_source.get(plugin_name, {}) + if not isinstance(plugin_data, dict): + continue + persisted_plugin = _extract_persisted(schema, plugin_data) + if persisted_plugin: + plugin_filtered[plugin_name] = persisted_plugin + continue + + section_data = tree.get(name, {}) + if not isinstance(section_data, dict): + continue + persisted_section = _extract_persisted(schema, section_data) + if persisted_section: + filtered[name] = persisted_section + + if plugin_filtered: + filtered["plugins"] = plugin_filtered + return filtered + + def __getattr__(self, name: str) -> Any: + if name in self._section_models: + return SectionAccessor(self, name) + if name == "plugins": + return self.plugins + raise AttributeError(f"Config has no section '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + return + if name in self._section_models: + if isinstance(value, BaseModel): + payload = value.model_dump(exclude_unset=False) + else: + payload = value + self.update_section(name, **payload) + return + object.__setattr__(self, name, value) + + def __str__(self) -> str: + return self.format() + + +def _deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, Any]]: + node: Any = tree + for segment in path: + if not isinstance(node, dict): + return None + node = node.get(segment) + if node is None: + return None + return node if isinstance(node, dict) else None + + +def _resolve_model_type(annotation: Any) -> Optional[type[BaseModel]]: + """Return the first BaseModel subclass found in an annotation (if any).""" + + if isinstance(annotation, type) and 
issubclass(annotation, BaseModel): + return annotation + + origin = get_origin(annotation) + if origin is None: + return None + + for arg in get_args(annotation): + nested = _resolve_model_type(arg) + if nested is not None: + return nested + return None + + +def _serialize_value(value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(exclude_unset=False) + if hasattr(value, "get_secret_value"): + return value.get_secret_value() + return value + + +def _prepare_for_display(value: Any) -> Any: + if isinstance(value, BaseModel): + return { + k: _prepare_for_display(v) for k, v in value.model_dump(exclude_unset=False).items() + } + if isinstance(value, dict): + return {str(k): _prepare_for_display(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set)): + return [_prepare_for_display(v) for v in value] + if isinstance(value, Path): + return str(value) + if isinstance(value, Enum): + return value.value + if hasattr(value, "get_secret_value"): + displayed = getattr(value, "display", None) + if callable(displayed): + return displayed() + return str(value) + return value + + +def _build_config_panel(title: str, data: dict[str, Any]) -> Panel: + tree = Tree(Text(title, style="bold cyan")) + if data: + for key in sorted(data.keys()): + branch = tree.add(Text(key, style="bold magenta")) + branch.add(Pretty(data[key], expand_all=True)) + else: + tree.add(Text("", style="dim")) + return Panel(tree, border_style="cyan", padding=(0, 1)) + + +def _build_section_panel(name: str, data: Any) -> Panel: + tree = Tree(Text(name, style="bold cyan")) + tree.add(Pretty(data, expand_all=True)) + return Panel(tree, border_style="cyan", padding=(0, 1)) + + +def _render_panel(renderable: Panel, *, width: int = 100) -> str: + buffer = StringIO() + console = Console(file=buffer, record=True, force_terminal=True, width=width, color_system=None) + console.print(renderable) + return buffer.getvalue().rstrip() + + +def _model_dict(model: BaseModel) -> 
dict[str, Any]: + data = model.model_dump(exclude_unset=False) + for key, value in list(data.items()): + if hasattr(value, "get_secret_value"): + data[key] = value.get_secret_value() + return data + + +def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]: + persisted: dict[str, Any] = {} + for field_name, field in schema.model_fields.items(): + schema_extra = field.json_schema_extra or {} + annotation = field.annotation + persist = bool(schema_extra.get("persist")) if isinstance(schema_extra, dict) else False + if not persist: + continue + if field_name not in data: + continue + value = data[field_name] + if value is None: + persisted[field_name] = None + continue + + nested_type = _resolve_model_type(annotation) + if nested_type is not None: + nested_source = value if isinstance(value, dict) else {} + nested_persisted = _extract_persisted(nested_type, nested_source) + if nested_persisted: + persisted[field_name] = nested_persisted + continue + + if hasattr(value, "get_secret_value"): + persisted[field_name] = value.get_secret_value() + else: + persisted[field_name] = deepcopy(value) + + return persisted + + +__all__ = [ + "ConfigManager", + "PluginsAccessor", + "ProfilesAccessor", + "SectionAccessor", + "normalize_profile_name", +] diff --git a/tidy3d/_common/config/profiles.py b/tidy3d/_common/config/profiles.py new file mode 100644 index 0000000000..29bbb43180 --- /dev/null +++ b/tidy3d/_common/config/profiles.py @@ -0,0 +1,64 @@ +"""Built-in configuration profiles for tidy3d.""" + +from __future__ import annotations + +from typing import Any + +BUILTIN_PROFILES: dict[str, dict[str, Any]] = { + "default": { + "web": { + "api_endpoint": "https://tidy3d-api.simulation.cloud", + "website_endpoint": "https://tidy3d.simulation.cloud", + "s3_region": "us-gov-west-1", + } + }, + "prod": { + "web": { + "api_endpoint": "https://tidy3d-api.simulation.cloud", + "website_endpoint": "https://tidy3d.simulation.cloud", + "s3_region": 
"us-gov-west-1", + } + }, + "dev": { + "web": { + "api_endpoint": "https://tidy3d-api.dev-simulation.cloud", + "website_endpoint": "https://tidy3d.dev-simulation.cloud", + "s3_region": "us-east-1", + } + }, + "uat": { + "web": { + "api_endpoint": "https://tidy3d-api.uat-simulation.cloud", + "website_endpoint": "https://tidy3d.uat-simulation.cloud", + "s3_region": "us-west-2", + } + }, + "pre": { + "web": { + "api_endpoint": "https://preprod-tidy3d-api.simulation.cloud", + "website_endpoint": "https://preprod-tidy3d.simulation.cloud", + "s3_region": "us-gov-west-1", + } + }, + "nexus": { + "web": { + "api_endpoint": "http://127.0.0.1:5000", + "website_endpoint": "http://127.0.0.1/tidy3d", + "ssl_verify": False, + "enable_caching": False, + "s3_region": "us-east-1", + "env_vars": { + "AWS_ENDPOINT_URL_S3": "http://127.0.0.1:9000", + }, + } + }, + "test": { + "web": { + "s3_region": "test", + "api_endpoint": "https://test", + "website_endpoint": "https://test", + } + }, +} + +__all__ = ["BUILTIN_PROFILES"] diff --git a/tidy3d/_common/config/registry.py b/tidy3d/_common/config/registry.py new file mode 100644 index 0000000000..7c1b16b7a1 --- /dev/null +++ b/tidy3d/_common/config/registry.py @@ -0,0 +1,83 @@ +"""Registry utilities for tidy3d configuration sections and handlers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from pydantic import BaseModel + +if TYPE_CHECKING: + from typing import Callable, Optional + +T = TypeVar("T", bound=BaseModel) + +_SECTIONS: dict[str, type[BaseModel]] = {} +_HANDLERS: dict[str, Callable[[BaseModel], None]] = {} +_MANAGER: Optional[ConfigManagerProtocol] = None + + +class ConfigManagerProtocol: + """Protocol-like interface for manager notifications.""" + + def on_section_registered(self, section: str) -> None: + """Called when a new section schema is registered.""" + + def on_handler_registered(self, section: str) -> None: + """Called when a handler is registered.""" + + +def 
attach_manager(manager: ConfigManagerProtocol) -> None: + """Attach the active configuration manager for registry callbacks.""" + + global _MANAGER + _MANAGER = manager + + +def get_manager() -> Optional[ConfigManagerProtocol]: + """Return the currently attached configuration manager, if any.""" + + return _MANAGER + + +def register_section(name: str) -> Callable[[type[T]], type[T]]: + """Decorator to register a configuration section schema.""" + + def decorator(cls: type[T]) -> type[T]: + _SECTIONS[name] = cls + if _MANAGER is not None: + _MANAGER.on_section_registered(name) + return cls + + return decorator + + +def register_plugin(name: str) -> Callable[[type[T]], type[T]]: + """Decorator to register a plugin configuration schema.""" + + return register_section(f"plugins.{name}") + + +def register_handler( + name: str, +) -> Callable[[Callable[[BaseModel], None]], Callable[[BaseModel], None]]: + """Decorator to register a handler for a configuration section.""" + + def decorator(func: Callable[[BaseModel], None]) -> Callable[[BaseModel], None]: + _HANDLERS[name] = func + if _MANAGER is not None: + _MANAGER.on_handler_registered(name) + return func + + return decorator + + +def get_sections() -> dict[str, type[BaseModel]]: + """Return registered section schemas.""" + + return dict(_SECTIONS) + + +def get_handlers() -> dict[str, Callable[[BaseModel], None]]: + """Return registered configuration handlers.""" + + return dict(_HANDLERS) diff --git a/tidy3d/_common/config/serializer.py b/tidy3d/_common/config/serializer.py new file mode 100644 index 0000000000..5db5dc5d97 --- /dev/null +++ b/tidy3d/_common/config/serializer.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, get_args, get_origin + +import tomlkit +from pydantic import BaseModel +from tomlkit.items import Item, Table + +from .registry import get_sections + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pydantic.fields import FieldInfo + 
+Path = tuple[str, ...] + + +def collect_descriptions() -> dict[Path, str]: + """Collect description strings for registered configuration fields.""" + + descriptions: dict[Path, str] = {} + for section_name, model in get_sections().items(): + base_path = tuple(segment for segment in section_name.split(".") if segment) + section_doc = (model.__doc__ or "").strip() + if section_doc and base_path: + descriptions[base_path] = descriptions.get( + base_path, section_doc.splitlines()[0].strip() + ) + for field_name, field in model.model_fields.items(): + descriptions.update(_describe_field(field, prefix=(*base_path, field_name))) + return descriptions + + +def _describe_field(field: FieldInfo, prefix: Path) -> dict[Path, str]: + descriptions: dict[Path, str] = {} + description = (field.description or "").strip() + if description: + descriptions[prefix] = description + + nested_models: Iterable[type[BaseModel]] = _iter_model_types(field.annotation) + for model in nested_models: + nested_doc = (model.__doc__ or "").strip() + if nested_doc: + descriptions[prefix] = descriptions.get(prefix, nested_doc.splitlines()[0].strip()) + for sub_name, sub_field in model.model_fields.items(): + descriptions.update(_describe_field(sub_field, prefix=(*prefix, sub_name))) + return descriptions + + +def _iter_model_types(annotation: Any) -> Iterable[type[BaseModel]]: + """Yield BaseModel subclasses referenced by a field annotation (if any).""" + + if annotation is None: + return + + stack = [annotation] + seen: set[type[BaseModel]] = set() + + while stack: + current = stack.pop() + if isinstance(current, type) and issubclass(current, BaseModel): + if current not in seen: + seen.add(current) + yield current + continue + + origin = get_origin(current) + if origin is None: + continue + + stack.extend(get_args(current)) + + +def build_document( + data: dict[str, Any], + existing: tomlkit.TOMLDocument | None, + descriptions: dict[Path, str] | None = None, +) -> tomlkit.TOMLDocument: + """Return 
a TOML document populated with data and annotated comments.""" + + descriptions = descriptions or collect_descriptions() + document = existing if existing is not None else tomlkit.document() + _prune_missing_keys(document, data.keys()) + for key, value in data.items(): + _apply_value( + container=document, + key=key, + value=value, + path=(key,), + descriptions=descriptions, + is_new=key not in document, + ) + return document + + +def _prune_missing_keys(container: Table | tomlkit.TOMLDocument, keys: Iterable[str]) -> None: + desired = set(keys) + for existing_key in list(container.keys()): + if existing_key not in desired: + del container[existing_key] + + +def _apply_value( + container: Table | tomlkit.TOMLDocument, + key: str, + value: Any, + path: Path, + descriptions: dict[Path, str], + is_new: bool, +) -> None: + description = descriptions.get(path) + if isinstance(value, dict): + existing = container.get(key) + table = existing if isinstance(existing, Table) else tomlkit.table() + _prune_missing_keys(table, value.keys()) + for sub_key, sub_value in value.items(): + _apply_value( + container=table, + key=sub_key, + value=sub_value, + path=(*path, sub_key), + descriptions=descriptions, + is_new=not isinstance(existing, Table) or sub_key not in table, + ) + if key in container: + container[key] = table + else: + if isinstance(container, tomlkit.TOMLDocument) and len(container) > 0: + container.add(tomlkit.nl()) + container.add(key, table) + return + + if value is None: + return + + existing_item = container.get(key) + new_item = tomlkit.item(value) + if isinstance(existing_item, Item): + new_item.trivia.comment = existing_item.trivia.comment + new_item.trivia.comment_ws = existing_item.trivia.comment_ws + elif description: + new_item.comment(description) + + if key in container: + container[key] = new_item + else: + container.add(key, new_item) diff --git a/tidy3d/_common/constants.py b/tidy3d/_common/constants.py new file mode 100644 index 
0000000000..81b168cad5 100644 --- /dev/null +++ b/tidy3d/_common/constants.py @@ -0,0 +1,313 @@ +"""Defines importable constants. + +Attributes: + inf (float): Tidy3d representation of infinity. + C_0 (float): Speed of light in vacuum [um/s] + EPSILON_0 (float): Vacuum permittivity [F/um] + MU_0 (float): Vacuum permeability [H/um] + ETA_0 (float): Vacuum impedance + HBAR (float): reduced Planck constant [eV*s] + Q_e (float): fundamental charge [C] +""" + +from __future__ import annotations + +from types import MappingProxyType + +import numpy as np + +# fundamental constants (https://physics.nist.gov) +C_0 = 2.99792458e14 +""" +Speed of light in vacuum [um/s] +""" + +MU_0 = 1.25663706212e-12 +""" +Vacuum permeability [H/um] +""" + +EPSILON_0 = 1 / (MU_0 * C_0**2) +""" +Vacuum permittivity [F/um] +""" + +#: Free space impedance +ETA_0 = np.sqrt(MU_0 / EPSILON_0) +""" +Vacuum impedance in Ohms +""" + +Q_e = 1.602176634e-19 +""" +Fundamental charge [C] +""" + +HBAR = 6.582119569e-16 +""" +Reduced Planck constant [eV*s] +""" + +K_B = 8.617333262e-5 +""" +Boltzmann constant [eV/K] +""" + +GRAV_ACC = 9.80665 * 1e6 +""" +Gravitational acceleration (g) [um/s^2]. +""" + +M_E_C_SQUARE = 0.51099895069e6 +""" +Electron rest mass energy (m_e * c^2) [eV] +""" + +M_E_EV = M_E_C_SQUARE / C_0**2 +""" +Electron mass [eV*s^2/um^2] +""" + +# floating point precisions +dp_eps = np.finfo(np.float64).eps +""" +Double floating point precision. +""" + +fp_eps = np.float64(np.finfo(np.float32).eps) +""" +Floating point precision. +""" + +# values of PEC for mode solver +pec_val = -1e8 +""" +PEC values for mode solver +""" + +# unit labels +HERTZ = "Hz" +""" +One cycle per second. +""" + +TERAHERTZ = "THz" +""" +One trillion (10^12) cycles per second. +""" + +SECOND = "sec" +""" +SI unit of time. +""" + +PICOSECOND = "ps" +""" +One trillionth (10^-12) of a second. +""" + +METER = "m" +""" +SI unit of length. +""" + +PERMETER = "1/m" +""" +SI unit of inverse length. 
+""" + +MICROMETER = "um" +""" +One millionth (10^-6) of a meter. +""" + +NANOMETER = "nm" +""" +One billionth (10^-9) of a meter. +""" + +RADIAN = "rad" +""" +SI unit of angle. +""" + +CONDUCTIVITY = "S/um" +""" +Siemens per micrometer. +""" + +PERMITTIVITY = "None (relative permittivity)" +""" +Relative permittivity. +""" + +PML_SIGMA = "2*EPSILON_0/dt" +""" +2 times vacuum permittivity over time differential step. +""" + +RADPERSEC = "rad/sec" +""" +One radian per second. +""" + +RADPERMETER = "rad/m" +""" +One radian per meter. +""" + +NEPERPERMETER = "Np/m" +""" +SI unit for attenuation constant. +""" + + +ELECTRON_VOLT = "eV" +""" +Unit of energy. +""" + +KELVIN = "K" +""" +SI unit of temperature. +""" + +CMCUBE = "cm^3" +""" +Cubic centimeter unit of volume. +""" + +PERCMCUBE = "1/cm^3" +""" +Unit per centimeter cube. +""" + +WATT = "W" +""" +SI unit of power. +""" + +VOLT = "V" +""" +SI unit of electric potential. +""" + +PICOSECOND_PER_NANOMETER_PER_KILOMETER = "ps/(nm km)" +""" +Picosecond per (nanometer kilometer). +""" + +OHM = "ohm" +""" +SI unit of resistance. +""" + +FARAD = "farad" +""" +SI unit of capacitance. +""" + +HENRY = "henry" +""" +SI unit of inductance. +""" + +AMP = "A" +""" +SI unit of electric current. +""" + +THERMAL_CONDUCTIVITY = "W/(um*K)" +""" +Watts per (micrometer Kelvin). +""" + +SPECIFIC_HEAT_CAPACITY = "J/(kg*K)" +""" +Joules per (kilogram Kelvin). +""" + +DENSITY = "kg/um^3" +""" +Kilograms per cubic micrometer. +""" + +HEAT_FLUX = "W/um^2" +""" +Watts per square micrometer. +""" + +VOLUMETRIC_HEAT_RATE = "W/um^3" +""" +Watts per cube micrometer. +""" + +HEAT_TRANSFER_COEFF = "W/(um^2*K)" +""" +Watts per (square micrometer Kelvin). +""" + +CURRENT_DENSITY = "A/um^2" +""" +Amperes per square micrometer +""" + +DYNAMIC_VISCOSITY = "kg/(um*s)" +""" +Kilograms per (micrometer second) +""" + +SPECIFIC_HEAT = "um^2/(s^2*K)" +""" +Square micrometers per (square second Kelvin). 
+""" + +THERMAL_EXPANSIVITY = "1/K" +""" +Inverse Kelvin. +""" + +VELOCITY_SI = "m/s" +""" +SI unit of velocity +""" + +ACCELERATION = "um/s^2" +""" +Acceleration unit. +""" + +LARGE_NUMBER = 1e10 +""" +Large number used for comparing infinity. +""" + +LARGEST_FP_NUMBER = 1e38 +""" +Largest number used for single precision floating point number. +""" + +inf = np.inf +""" +Representation of infinity used within tidy3d. +""" + +# if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning +GLANCING_CUTOFF = 0.1 +""" +if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning. +""" + +UnitScaling = MappingProxyType( + { + "nm": 1e3, + "μm": 1e0, + "um": 1e0, + "mm": 1e-3, + "cm": 1e-4, + "m": 1e-6, + "mil": 1.0 / 25.4, + "in": 1.0 / 25400, + } +) +"""Immutable dictionary for converting microns to another spatial unit, eg. nm = um * UnitScaling["nm"].""" diff --git a/tidy3d/_common/exceptions.py b/tidy3d/_common/exceptions.py new file mode 100644 index 0000000000..24f53345c2 --- /dev/null +++ b/tidy3d/_common/exceptions.py @@ -0,0 +1,64 @@ +"""Custom Tidy3D exceptions""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tidy3d._common.log import log + +if TYPE_CHECKING: + from typing import Optional + + +class Tidy3dError(ValueError): + """Any error in tidy3d""" + + def __init__(self, message: Optional[str] = None, log_error: bool = True) -> None: + """Log just the error message and then raise the Exception.""" + super().__init__(message) + if log_error: + log.error(message) + + +class ConfigError(Tidy3dError): + """Error when configuring Tidy3d.""" + + +class Tidy3dKeyError(Tidy3dError): + """Could not find a key in a Tidy3d dictionary.""" + + +class ValidationError(Tidy3dError): + """Error when constructing Tidy3d components.""" + + +class SetupError(Tidy3dError): + """Error regarding the setup of the components 
(outside of domains, etc).""" + + +class FileError(Tidy3dError): + """Error reading or writing to file.""" + + +class WebError(Tidy3dError): + """Error with the webAPI.""" + + +class AuthenticationError(Tidy3dError): + """Error authenticating a user through webapi webAPI.""" + + +class DataError(Tidy3dError): + """Error accessing data.""" + + +class Tidy3dImportError(Tidy3dError): + """Error importing a package needed for tidy3d.""" + + +class Tidy3dNotImplementedError(Tidy3dError): + """Error when a functionality is not (yet) supported.""" + + +class AdjointError(Tidy3dError): + """An error in setting up the adjoint solver.""" diff --git a/tidy3d/_common/log.py b/tidy3d/_common/log.py new file mode 100644 index 0000000000..290ae8b0e5 --- /dev/null +++ b/tidy3d/_common/log.py @@ -0,0 +1,520 @@ +"""Logging Configuration for Tidy3d.""" + +from __future__ import annotations + +import inspect +from contextlib import contextmanager +from datetime import datetime +from typing import TYPE_CHECKING, Any, Literal, Union + +from rich.console import Console +from rich.text import Text + +if TYPE_CHECKING: + from collections.abc import Iterator + from os import PathLike + from types import TracebackType + from typing import Callable, Optional + + from pydantic import BaseModel + from rich.progress import Progress as RichProgress + + from tidy3d._common.compat import Self +# Note: "SUPPORT" and "USER" levels are meant for backend runs only. +# Logging in frontend code should just use the standard debug/info/warning/error/critical. 
+LogLevel = Literal["DEBUG", "SUPPORT", "USER", "INFO", "WARNING", "ERROR", "CRITICAL"] +LogValue = Union[int, LogLevel] + +# Logging levels compatible with logging module +_level_value = { + "DEBUG": 10, + "SUPPORT": 12, + "USER": 15, + "INFO": 20, + "WARNING": 30, + "ERROR": 40, + "CRITICAL": 50, +} + +_level_name = {v: k for k, v in _level_value.items()} + +DEFAULT_LEVEL = "WARNING" + +DEFAULT_LOG_STYLES = { + "DEBUG": None, + "SUPPORT": None, + "USER": None, + "INFO": None, + "WARNING": "red", + "ERROR": "red bold", + "CRITICAL": "red bold", +} + +# Width of the console used for rich logging (in characters). +CONSOLE_WIDTH = 80 + + +def _default_log_level_format(level: str, message: str) -> tuple[str, str]: + """By default just return unformatted prefix and message.""" + return level, message + + +def _get_level_int(level: LogValue) -> int: + """Get the integer corresponding to the level string.""" + if isinstance(level, int): + return level + + if level not in _level_value: + # We don't want to import ConfigError to avoid a circular dependency + raise ValueError( + f"logging level {level} not supported, must be " + "'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', or 'CRITICAL'" + ) + return _level_value[level] + + +class LogHandler: + """Handle log messages depending on log level""" + + def __init__( + self, + console: Console, + level: LogValue, + log_level_format: Callable = _default_log_level_format, + prefix_every_line: bool = False, + ) -> None: + self.level = _get_level_int(level) + self.console = console + self.log_level_format = log_level_format + self.prefix_every_line = prefix_every_line + + def handle(self, level: int, level_name: str, message: str) -> None: + """Output log messages depending on log level""" + if level >= self.level: + stack = inspect.stack() + console = self.console + offset = 4 + if stack[offset - 1].filename.endswith("exceptions.py"): + # We want the calling site for exceptions.py + offset += 1 + prefix, msg = 
self.log_level_format(level_name, message) + if self.prefix_every_line: + wrapped_text = Text(msg, style="default") + msgs = wrapped_text.wrap(console=console, width=console.width - len(prefix) - 2) + else: + msgs = [msg] + for msg in msgs: + console.log( + prefix, + msg, + sep=": ", + style=DEFAULT_LOG_STYLES[level_name], + _stack_offset=offset, + ) + + +class Logger: + """Custom logger to avoid the complexities of the logging module. + + Notes + ----- + The logger can be used in a context manager to avoid the emission of multiple messages. In this + case, the first message in the context is emitted normally, but any others are discarded. When + the context is exited, the number of discarded messages of each level is displayed with the + highest level of the captured messages. + + Messages can also be captured for post-processing. That can be enabled through 'set_capture' to + record warnings emitted during model validation (and other explicit begin/end capture regions, + e.g. validation routines like ``validate_pre_upload``). A structured copy of captured warnings + can then be recovered through 'captured_warnings'. 
+ """ + + _static_cache = set() + + def __init__(self) -> None: + self.handlers = {} + self.suppression = True + self.warn_once = False + self._counts = None + self._stack = None + self._capture = False + self._captured_warnings = [] + + def set_capture(self, capture: bool) -> None: + """Turn on/off tree-like capturing of log messages.""" + self._capture = capture + + def captured_warnings(self) -> list[dict[str, Any]]: + """Get the formatted list of captured log messages.""" + captured_warnings = self._captured_warnings + self._captured_warnings = [] + return captured_warnings + + def __enter__(self) -> Self: + """If suppression is enabled, enter a consolidation context (only a single message is + emitted).""" + if self.suppression and self._counts is None: + self._counts = {} + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Literal[False]: + """Exit a consolidation context (report the number of messages discarded).""" + if self._counts is not None: + total = sum(v for v in self._counts.values()) + if total > 0: + max_level = max(k for k, v in self._counts.items() if v > 0) + counts = [f"{v} {_level_name[k]}" for k, v in self._counts.items() if v > 0] + self._counts = None + if total > 0: + noun = " messages." if total > 1 else " message." + # Temporarily prevent capturing messages to emit consolidated summary + stack = self._stack + self._stack = None + self.log(max_level, "Suppressed " + ", ".join(counts) + noun) + self._stack = stack + return False + + def begin_capture(self) -> None: + """Start capturing log stack for consolidated validation log. + + This method should be called before a validation routine starts. It must be followed by a + corresponding 'end_capture'. 
+ """ + if not self._capture: + return + + stack_item = {"messages": [], "children": {}} + if self._stack: + self._stack.append(stack_item) + else: + self._stack = [stack_item] + + def abort_capture(self) -> None: + """Undo the last ``begin_capture()`` call. + + This is used when validation fails before reaching the corresponding ``end_capture()``. + """ + if not self._stack: + return + + self._stack.pop() + if len(self._stack) == 0: + self._stack = None + + def end_capture(self, model: BaseModel) -> None: + """End capturing log stack for consolidated validation log. + + This method should be called after a validation routine ends. It must follow a + corresponding 'begin_capture'. + """ + if not self._stack: + return + + stack_item = self._stack.pop() + if len(self._stack) == 0: + self._stack = None + + # Check if this stack item contains any messages or children + if len(stack_item["messages"]) > 0 or len(stack_item["children"]) > 0: + stack_item["type"] = model.__class__.__name__ + + # Set the path for each children + model_fields = model.get_submodels_by_hash() + for child_hash, child_dict in stack_item["children"].items(): + child_dict["parent_fields"] = model_fields.get(child_hash, []) + + # Are we at the bottom of the stack? 
        # (continuation of a capture method that begins above this hunk)
        # Attach the finished capture subtree either as the root of the
        # capture tree or as a child of the currently-open stack frame.
        if self._stack is None:
            # Yes, we're root
            self._parse_warning_capture(current_loc=[], stack_item=stack_item)
        else:
            # No, we're someone else's child
            hash_ = hash(model)
            self._stack[-1]["children"][hash_] = stack_item

    def _parse_warning_capture(self, current_loc: list[Any], stack_item: dict[str, Any]) -> None:
        """Process capture tree to compile formatted captured warnings.

        Walks ``stack_item`` recursively, prefixing each captured warning's
        location with ``current_loc`` and appending the result to
        ``self._captured_warnings`` as ``{"loc": ..., "msg": ...}`` dicts.
        """

        if "parent_fields" in stack_item:
            # Non-root node: the same warnings are emitted once per parent
            # field, each with a distinct location prefix.
            for field in stack_item["parent_fields"]:
                if isinstance(field, tuple):
                    # array field: the tuple already encodes (field, index)
                    new_loc = current_loc + list(field)
                else:
                    # single field
                    new_loc = [*current_loc, field]

                # process current level warnings
                for level, msg, custom_loc in stack_item["messages"]:
                    if level == "WARNING":
                        self._captured_warnings.append({"loc": new_loc + custom_loc, "msg": msg})

                # initialize processing at children level
                for child_stack in stack_item["children"].values():
                    self._parse_warning_capture(current_loc=new_loc, stack_item=child_stack)

        else:  # for root object
            # process current level warnings
            for level, msg, custom_loc in stack_item["messages"]:
                if level == "WARNING":
                    self._captured_warnings.append({"loc": current_loc + custom_loc, "msg": msg})

            # initialize processing at children level
            for child_stack in stack_item["children"].values():
                self._parse_warning_capture(current_loc=current_loc, stack_item=child_stack)

    def _log(
        self,
        level: int,
        level_name: str,
        message: str,
        *args: Any,
        log_once: bool = False,
        custom_loc: Optional[list] = None,
        capture: bool = True,
    ) -> None:
        """Distribute log messages to all handlers

        Parameters
        ----------
        level : int
            Numeric priority of the message.
        level_name : str
            Name matching ``level`` (e.g. ``"WARNING"``).
        message : str
            %-style format string (or plain message if no ``args``).
        *args : Any
            Arguments interpolated into ``message`` via ``%``.
        log_once : bool = False
            Deduplicate by raw message body across the whole process.
        custom_loc : Optional[list] = None
            Extra location suffix recorded with captured warnings.
        capture : bool = True
            Whether to record the message in the active capture stack.

        Notes
        -----
        The ordering below is deliberate: dedup check first (before
        composing/capturing, so duplicates are not captured twice), then
        capture (even messages later suppressed), then suppression
        counting, then handler dispatch.
        """

        # Check global cache if requested or if warn_once is enabled for warnings
        # (before composing/capturing to avoid duplicates)
        should_check_cache = log_once or (self.warn_once and level_name == "WARNING")
        if should_check_cache:
            # Use the message body before composition as key
            if message in self._static_cache:
                return
            self._static_cache.add(message)

        # Compose message
        if len(args) > 0:
            try:
                composed_message = str(message) % args
            # fall back to showing the raw message and args rather than
            # losing the log line to a formatting error
            except Exception as e:
                composed_message = f"{message} % {args}\n{e}"
        else:
            composed_message = str(message)

        # Capture all messages (even if suppressed later)
        if self._stack and capture:
            if custom_loc is None:
                custom_loc = []
            self._stack[-1]["messages"].append((level_name, composed_message, custom_loc))

        # Context-local logger emits a single message and consolidates the rest
        if self._counts is not None:
            if len(self._counts) > 0:
                # already emitted one message: just count this one and stop
                self._counts[level] = 1 + self._counts.get(level, 0)
                return
            # first message: mark the level seen, then fall through to emit
            self._counts[level] = 0

        # Forward message to handlers
        for handler in self.handlers.values():
            handler.handle(level, level_name, composed_message)

    def log(self, level: LogValue, message: str, *args: Any, log_once: bool = False) -> None:
        """Log (message) % (args) with given level"""
        # Accept either a level name or a numeric level.
        if isinstance(level, str):
            level_name = level
            level = _get_level_int(level)
        else:
            level_name = _level_name.get(level, "unknown")
        self._log(level, level_name, message, *args, log_once=log_once)

    def debug(self, message: str, *args: Any, log_once: bool = False) -> None:
        """Log (message) % (args) at debug level"""
        self._log(_level_value["DEBUG"], "DEBUG", message, *args, log_once=log_once)

    def support(self, message: str, *args: Any, log_once: bool = False) -> None:
        """Log (message) % (args) at support level"""
        self._log(_level_value["SUPPORT"], "SUPPORT", message, *args, log_once=log_once)

    def user(self, message: str, *args: Any, log_once: bool = False) -> None:
        """Log (message) % (args) at user level"""
        self._log(_level_value["USER"], "USER", message, *args, log_once=log_once)

    def info(self, message: str, *args: Any, log_once: bool = False) -> None:
        """Log (message) % (args) at info level"""
        self._log(_level_value["INFO"], "INFO", message, *args, log_once=log_once)
False, + custom_loc: Optional[list] = None, + capture: bool = True, + ) -> None: + """Log (message) % (args) at warning level""" + self._log( + _level_value["WARNING"], + "WARNING", + message, + *args, + log_once=log_once, + custom_loc=custom_loc, + capture=capture, + ) + + def error(self, message: str, *args: Any, log_once: bool = False) -> None: + """Log (message) % (args) at error level""" + self._log(_level_value["ERROR"], "ERROR", message, *args, log_once=log_once) + + def critical(self, message: str, *args: Any, log_once: bool = False) -> None: + """Log (message) % (args) at critical level""" + self._log(_level_value["CRITICAL"], "CRITICAL", message, *args, log_once=log_once) + + +def set_logging_level(level: LogValue = DEFAULT_LEVEL) -> None: + """Set tidy3d console logging level priority. + + Parameters + ---------- + level : str + The lowest priority level of logging messages to display. One of ``{'DEBUG', 'SUPPORT', + 'USER', INFO', 'WARNING', 'ERROR', 'CRITICAL'}`` (listed in increasing priority). + """ + if "console" in log.handlers: + log.handlers["console"].level = _get_level_int(level) + + +def set_log_suppression(value: bool) -> None: + """Control log suppression for repeated messages.""" + log.suppression = value + + +def set_warn_once(value: bool) -> None: + """Control whether warnings are only shown once per unique message. + + Parameters + ---------- + value : bool + When True, each unique warning message is only shown once per process. + """ + log.warn_once = value + + +def get_aware_datetime() -> datetime: + """Get an aware current local datetime(with local timezone info)""" + return datetime.now().astimezone() + + +def set_logging_console(stderr: bool = False) -> None: + """Set stdout or stderr as console output + + Parameters + ---------- + stderr : bool + If False, logs are directed to stdout, otherwise to stderr. 
+ """ + if "console" in log.handlers: + previous_level = log.handlers["console"].level + else: + previous_level = DEFAULT_LEVEL + log.handlers["console"] = LogHandler( + Console( + stderr=stderr, + width=CONSOLE_WIDTH, + log_path=False, + get_datetime=get_aware_datetime, + log_time_format="%X %Z", + ), + previous_level, + ) + + +def set_logging_file( + fname: PathLike, + filemode: str = "w", + level: LogValue = DEFAULT_LEVEL, + log_path: bool = False, +) -> None: + """Set a file to write log to, independently from the stdout and stderr + output chosen using :meth:`set_logging_level`. + + Parameters + ---------- + fname : PathLike + Path to file to direct the output to. If empty string, a previously set logging file will + be closed, if any, but nothing else happens. + filemode : str + 'w' or 'a', defining if the file should be overwritten or appended. + level : str + One of ``{'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}``. This is set + for the file independently of the console output level set by :meth:`set_logging_level`. + log_path : bool = False + Whether to log the path to the file that issued the message. 
+ """ + if filemode not in "wa": + raise ValueError("filemode must be either 'w' or 'a'") + + # Close previous handler, if any + if "file" in log.handlers: + try: + log.handlers["file"].console.file.close() + except Exception: # TODO: catch specific exception + log.warning("Log file could not be closed") + finally: + del log.handlers["file"] + + if str(fname) == "": + # Empty string can be passed to just stop previously opened file handler + return + + try: + file = open(fname, filemode) + except Exception: # TODO: catch specific exception + log.error(f"File {fname} could not be opened") + return + + log.handlers["file"] = LogHandler( + Console(file=file, force_jupyter=False, log_path=log_path), level + ) + + +# Initialize Tidy3d's logger +log = Logger() + +# Set default logging output +set_logging_console() + + +def get_logging_console() -> Console: + """Get console from logging handlers.""" + if "console" not in log.handlers: + set_logging_console() + return log.handlers["console"].console + + +class NoOpProgress: + """Dummy progress manager that doesn't show any output.""" + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + pass + + def add_task(self, *args: Any, **kwargs: Any) -> None: + pass + + def update(self, *args: Any, **kwargs: Any) -> None: + pass + + +@contextmanager +def Progress(console: Console, show_progress: bool) -> Iterator[Union[RichProgress, NoOpProgress]]: + """Progress manager that wraps ``rich.Progress`` if ``show_progress`` is ``True``, + and ``NoOpProgress`` otherwise.""" + if show_progress: + from rich.progress import Progress + + with Progress(console=console) as progress: + yield progress + else: + with NoOpProgress() as progress: + yield progress diff --git a/tidy3d/_common/packaging.py b/tidy3d/_common/packaging.py new file mode 100644 index 0000000000..4763349c99 --- /dev/null +++ b/tidy3d/_common/packaging.py @@ -0,0 +1,289 @@ +""" +This file contains a set of functions 
relating to packaging tidy3d for distribution. Sections of the codebase should depend on this file, but this file should not depend on any other part of the codebase. + +This section should only depend on the standard core installation in the pyproject.toml, and should not depend on any other part of the codebase optional imports. +""" + +from __future__ import annotations + +import functools +from importlib import import_module +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +import numpy as np + +from tidy3d._common.exceptions import Tidy3dImportError + +if TYPE_CHECKING: + from typing import Literal + +__version__ = "2.11.0.dev0" # TODO change version handling + +F = TypeVar("F", bound=Callable[..., Any]) + +if TYPE_CHECKING: + from typing import Literal + +F = TypeVar("F", bound=Callable[..., Any]) + +vtk = { + "mod": None, + "id_type": np.int64, + "vtk_to_numpy": None, + "numpy_to_vtkIdTypeArray": None, + "numpy_to_vtk": None, +} + +tidy3d_extras = {"mod": None, "use_local_subpixel": None} + + +def check_import(module_name: str) -> bool: + """ + Check if a module or submodule section has been imported. This is a functional way of loading packages that will still load the corresponding module into the total space. + + Parameters + ---------- + module_name + + Returns + ------- + bool + True if the module has been imported, False otherwise. + + """ + try: + import_module(module_name) + return True + except ImportError: + return False + + +def verify_packages_import( + modules: list[str], required: Literal["any", "all"] = "all" +) -> Callable[[F], F]: + def decorator(func: F) -> F: + """ + When decorating a method, requires that the specified modules are available. It will raise an error if the + module is not available depending on the value of the 'required' parameter which represents the type of + import required. 
+ + There are a few options to choose for the 'required' parameter: + - 'all': All the modules must be available for the operation to continue without raising an error + - 'any': At least one of the modules must be available for the operation to continue without raising an error + + Parameters + ---------- + func + The function to decorate. + + Returns + ------- + checks_modules_import + The decorated function. + + """ + + @functools.wraps(func) + def checks_modules_import(*args: Any, **kwargs: Any) -> Any: + """ + Checks if the modules are available. If they are not available, it will raise an error depending on the value. + """ + available_modules_status = [] + maximum_amount_modules = len(modules) + + module_id_i = 0 + for module in modules: + # Starts counting from one so that it can be compared to len(modules) + module_id_i += 1 + import_available = check_import(module) + available_modules_status.append( + import_available + ) # Stores the status of the module import + + if not import_available: + if required == "all": + raise Tidy3dImportError( + f"The package '{module}' is required for this operation, but it was not found. " + f"Please install the '{module}' dependencies using, for example, " + f"'pip install tidy3d[]" + ) + if required == "any": + # Means we need to verify that at least one of the modules is available + if ( + not any(available_modules_status) + ) and module_id_i == maximum_amount_modules: + # Means that we have reached the last module and none of them were available + raise Tidy3dImportError( + f"The package '{module}' is required for this operation, but it was not found. " + f"Please install the '{module}' dependencies using, for example, " + f"'pip install tidy3d[]" + ) + else: + raise ValueError( + f"The value '{required}' is not a valid value for the 'required' parameter. " + f"Please use any 'all' or 'any'." 
+ ) + else: + # Means that the module is available, so we can just continue with the operation + pass + return func(*args, **kwargs) + + return checks_modules_import + + return decorator + + +def requires_vtk(fn: F) -> F: + """When decorating a method, requires that vtk is available.""" + + @functools.wraps(fn) + def _fn(*args: Any, **kwargs: Any) -> Any: + if vtk["mod"] is None: + try: + import vtk as vtk_mod + from vtk.util.numpy_support import ( + numpy_to_vtk, + numpy_to_vtkIdTypeArray, + vtk_to_numpy, + ) + from vtkmodules.vtkCommonCore import vtkLogger + + vtk["mod"] = vtk_mod + vtk["vtk_to_numpy"] = vtk_to_numpy + vtk["numpy_to_vtkIdTypeArray"] = numpy_to_vtkIdTypeArray + vtk["numpy_to_vtk"] = numpy_to_vtk + + vtkLogger.SetStderrVerbosity(vtkLogger.VERBOSITY_WARNING) + + if vtk["mod"].vtkIdTypeArray().GetDataTypeSize() == 4: + vtk["id_type"] = np.int32 + + except ImportError as exc: + raise Tidy3dImportError( + "The package 'vtk' is required for this operation, but it was not found. " + "Please install the 'vtk' dependencies using, for example, " + "'pip install .[vtk]'." + ) from exc + + return fn(*args, **kwargs) + + return _fn + + +def get_numpy_major_version(module: Any = np) -> int: + """ + Extracts the major version of the installed numpy accordingly. + + Parameters + ---------- + module : module + The module to extract the version from. Default is numpy. + + Returns + ------- + int + The major version of the module. + """ + # Get the version of the module + module_version = module.__version__ + + # Extract the major version number + major_version = int(module_version.split(".")[0]) + + return major_version + + +def _check_tidy3d_extras_available(quiet: bool = False) -> None: + """Helper function to check if 'tidy3d-extras' is available and version matched. + + Parameters + ---------- + quiet : bool + If True, suppress error logging when raising exceptions. 
+ + Raises + ------ + Tidy3dImportError + If tidy3d-extras is not available or not properly initialized. + """ + if tidy3d_extras["mod"] is not None: + return + + module_exists = find_spec("tidy3d_extras") is not None + if not module_exists: + raise Tidy3dImportError( + "The package 'tidy3d-extras' is absent. " + "Please install the 'tidy3d-extras' package using, for " + r"example, 'pip install tidy3d\[extras]'.", + log_error=not quiet, + ) + + try: + import tidy3d_extras as tidy3d_extras_mod + + except ImportError as exc: + raise Tidy3dImportError( + "The package 'tidy3d-extras' did not initialize correctly.", + log_error=not quiet, + ) from exc + + if not hasattr(tidy3d_extras_mod, "__version__"): + raise Tidy3dImportError( + "The package 'tidy3d-extras' did not initialize correctly. " + "Please install the 'tidy3d-extras' package using, for " + r"example, 'pip install tidy3d\[extras]'.", + log_error=not quiet, + ) + + version = tidy3d_extras_mod.__version__ + + if version is None: + raise Tidy3dImportError( + "The package 'tidy3d-extras' did not initialize correctly, " + "likely due to an invalid API key.", + log_error=not quiet, + ) + + if version != __version__: + raise Tidy3dImportError( + f"The version of 'tidy3d-extras' is {version}, but the version of 'tidy3d' is {__version__}. " + "They must match. You can install the correct " + r"version using 'pip install tidy3d\[extras]'.", + log_error=not quiet, + ) + + tidy3d_extras["mod"] = tidy3d_extras_mod + + +def check_tidy3d_extras_licensed_feature(feature_name: str, quiet: bool = False) -> None: + """Helper function to check if a specific feature is licensed in 'tidy3d-extras'. + + Parameters + ---------- + feature_name : str + The name of the feature to check for. + quiet : bool + If True, suppress error logging when raising exceptions. + + Raises + ------ + Tidy3dImportError + If the feature is not available with your license. 
+ """ + + try: + _check_tidy3d_extras_available(quiet=quiet) + except Tidy3dImportError as exc: + raise Tidy3dImportError( + f"The package 'tidy3d-extras' is required for this feature '{feature_name}'.", + log_error=not quiet, + ) from exc + + features = tidy3d_extras["mod"].extension._features() + if feature_name not in features: + raise Tidy3dImportError( + f"The feature '{feature_name}' is not available with your license. " + "Please contact Tidy3D support, or upgrade your license.", + log_error=not quiet, + ) diff --git a/tidy3d/_common/web/__init__.py b/tidy3d/_common/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/web/cache.py b/tidy3d/_common/web/cache.py new file mode 100644 index 0000000000..0c92475674 --- /dev/null +++ b/tidy3d/_common/web/cache.py @@ -0,0 +1,895 @@ +"""Local simulation cache manager.""" + +from __future__ import annotations + +import hashlib +import json +import os +import shutil +import tempfile +import threading +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol + +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt + +from tidy3d._common import config +from tidy3d._common.log import log +from tidy3d._common.web.core.http_util import get_version as _get_protocol_version + +if TYPE_CHECKING: + from collections.abc import Iterator + + from tidy3d._common.web.core.constants import TaskId + + +class CacheableSimulation(Protocol): + """Protocol for simulation objects that can be cached.""" + + def _hash_self(self) -> str: + """Return a stable hash for cache key construction.""" + + +_GetWorkflowType = Callable[[CacheableSimulation], str] +_get_workflow_type_callback: Optional[_GetWorkflowType] = None + + +def register_get_workflow_type(callback: _GetWorkflowType) -> None: + """Register workflow type 
def register_get_workflow_type(callback: "_GetWorkflowType") -> None:
    """Register workflow type resolver for cache logging."""
    global _get_workflow_type_callback
    _get_workflow_type_callback = callback


def _get_workflow_type(simulation: "CacheableSimulation") -> str:
    """Best-effort human-readable workflow type for cache log messages.

    Falls back to the simulation's class name when no resolver is registered
    or the registered resolver raises.
    """
    if _get_workflow_type_callback is None:
        return type(simulation).__name__
    try:
        return _get_workflow_type_callback(simulation)
    except Exception:
        # The callback only feeds log messages; never let it break caching.
        return type(simulation).__name__


# NOTE(review): the original re-declared the ``if TYPE_CHECKING`` imports
# (Iterator, TaskId) here, duplicating the block at the top of the module;
# the duplicate has been removed.

# File names used inside each cache entry directory.
CACHE_ARTIFACT_NAME = "simulation_data.hdf5"
CACHE_METADATA_NAME = "metadata.json"
CACHE_STATS_NAME = "stats.json"

# Prefixes marking temporary (in-progress) directories to skip when scanning.
TMP_PREFIX = "tidy3d-cache-"
TMP_BATCH_PREFIX = "tmp_batch"

# Annotation quoted so the name is not needed at definition time; equivalent
# under this module's ``from __future__ import annotations``.
_CACHE: "Optional[LocalCache]" = None


def _remove_cache_dir(path: os.PathLike, *, recreate: bool) -> None:
    """Remove a cache directory and optionally recreate it empty."""
    cache_path = Path(path)
    if cache_path.exists():
        try:
            shutil.rmtree(cache_path)
        # FileNotFoundError is a subclass of OSError, so a single except
        # suffices (the original listed both redundantly). Best-effort:
        # a failed removal leaves the directory in place.
        except OSError:
            return
    if recreate:
        cache_path.mkdir(parents=True, exist_ok=True)


def get_cache_entry_dir(root: os.PathLike, key: str) -> Path:
    """
    Returns the cache directory for a given key.
    A three-character prefix subdirectory is used to avoid hitting filesystem limits on the number of entries per folder.
    """
    return Path(root) / key[:3] / key
+ """ + return Path(root) / key[:3] / key + + +class CacheStats(BaseModel): + """Lightweight summary of cache usage persisted in ``stats.json``.""" + + last_used: dict[str, str] = Field( + default_factory=dict, + description="Mapping from cache entry key to the most recent ISO-8601 access timestamp.", + ) + total_size: NonNegativeInt = Field( + default=0, + description="Aggregate size in bytes across cached artifacts captured in the stats file.", + ) + updated_at: Optional[datetime] = Field( + default=None, + description="UTC timestamp indicating when the statistics were last refreshed.", + ) + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + @property + def total_entries(self) -> int: + return len(self.last_used) + + +class CacheEntryMetadata(BaseModel): + """Schema for cache entry metadata persisted on disk.""" + + cache_key: str + checksum: str + created_at: datetime + last_used: datetime + file_size: int = Field(ge=0) + simulation_hash: str + workflow_type: str + versions: Any + task_id: str + path: str + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + def bump_last_used(self) -> None: + self.last_used = datetime.now(timezone.utc) + + def as_dict(self) -> dict[str, Any]: + return self.model_dump(mode="json") + + def get(self, key: str, default: Any = None) -> Any: + return self.as_dict().get(key, default) + + def __getitem__(self, key: str) -> Any: + data = self.as_dict() + if key not in data: + raise KeyError(key) + return data[key] + + +@dataclass +class CacheEntry: + """Internal representation of a cache entry.""" + + key: str + root: Path + metadata: CacheEntryMetadata + + @property + def path(self) -> Path: + return get_cache_entry_dir(self.root, self.key) + + @property + def artifact_path(self) -> Path: + return self.path / CACHE_ARTIFACT_NAME + + @property + def metadata_path(self) -> Path: + return self.path / CACHE_METADATA_NAME + + def exists(self) -> bool: + return self.path.exists() and 
self.artifact_path.exists() and self.metadata_path.exists() + + def verify(self) -> bool: + if not self.exists(): + return False + checksum = self.metadata.checksum + if not checksum: + return False + try: + actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) + except FileNotFoundError: + return False + if checksum != actual_checksum: + log.warning( + "Simulation cache checksum mismatch for key '%s'. Removing stale entry.", self.key + ) + return False + if self.metadata.file_size != file_size: + self.metadata.file_size = file_size + _write_metadata(self.metadata_path, self.metadata) + return True + + def materialize(self, target: Path) -> Path: + """Copy cached artifact to ``target`` and return the resulting path.""" + target = Path(target) + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.artifact_path, target) + return target + + +class LocalCache: + """Manages storing and retrieving cached simulation artifacts.""" + + def __init__(self, directory: os.PathLike, max_size_gb: float, max_entries: int) -> None: + self.max_size_gb = max_size_gb + self.max_entries = max_entries + self._root = Path(directory) + self._lock = threading.RLock() + self._syncing_stats = False + self._sync_pending = False + + @property + def _stats_path(self) -> Path: + return self._root / CACHE_STATS_NAME + + def _schedule_sync(self) -> None: + self._sync_pending = True + + def _run_pending_sync(self) -> None: + if self._sync_pending and not self._syncing_stats: + self._sync_pending = False + self.sync_stats() + + @contextmanager + def _with_lock(self) -> Iterator[None]: + self._run_pending_sync() + with self._lock: + yield + self._run_pending_sync() + + def _write_stats(self, stats: CacheStats) -> CacheStats: + updated = stats.model_copy(update={"updated_at": datetime.now(timezone.utc)}) + payload = updated.model_dump(mode="json") + payload["total_entries"] = updated.total_entries + self._stats_path.parent.mkdir(parents=True, exist_ok=True) + 
_write_metadata(self._stats_path, payload) + self._sync_pending = False + return updated + + def _load_stats(self, *, rebuild: bool = False) -> CacheStats: + path = self._stats_path + if not path.exists(): + if not self._syncing_stats: + self._schedule_sync() + return CacheStats() + try: + data = json.loads(path.read_text(encoding="utf-8")) + if "last_used" not in data and "entries" in data: + data["last_used"] = data.pop("entries") + stats = CacheStats.model_validate(data) + except Exception: + if rebuild and not self._syncing_stats: + self._schedule_sync() + return CacheStats() + if stats.total_size < 0: + self._schedule_sync() + return CacheStats() + return stats + + def _record_store_stats( + self, + key: str, + *, + last_used: str, + file_size: int, + previous_size: int, + ) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + entries[key] = last_used + total_size = stats.total_size - previous_size + file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + self._write_stats(updated) + + def _record_touch_stats( + self, key: str, last_used: str, *, file_size: Optional[int] = None + ) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + existed = key in entries + total_size = stats.total_size + if not existed and file_size is not None: + total_size += file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + entries[key] = last_used + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + self._write_stats(updated) + + def _record_remove_stats(self, key: str, file_size: int) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + entries.pop(key, None) + total_size = stats.total_size - file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + 
self._write_stats(updated) + + def _enforce_limits_post_sync(self, entries: list[CacheEntry]) -> None: + if not entries: + return + + entries_map = {entry.key: entry.metadata.last_used.isoformat() for entry in entries} + + if self.max_entries > 0 and len(entries) > self.max_entries: + excess = len(entries) - self.max_entries + self._evict(entries_map, remove_count=excess, exclude_keys=set()) + + max_size_bytes = int(self.max_size_gb * (1024**3)) + if max_size_bytes > 0: + total_size = sum(entry.metadata.file_size for entry in entries) + if total_size > max_size_bytes: + bytes_to_free = total_size - max_size_bytes + self._evict_by_size(entries_map, bytes_to_free, exclude_keys=set()) + + def sync_stats(self) -> CacheStats: + with self._lock: + self._syncing_stats = True + log.debug("Syncing stats.json of local cache") + try: + entries: list[CacheEntry] = [] + last_used_map: dict[str, str] = {} + total_size = 0 + for entry in self._iter_entries(): + entries.append(entry) + total_size += entry.metadata.file_size + last_used_map[entry.key] = entry.metadata.last_used.isoformat() + stats = CacheStats(last_used=last_used_map, total_size=total_size) + written = self._write_stats(stats) + self._enforce_limits_post_sync(entries) + return written + finally: + self._syncing_stats = False + + @property + def root(self) -> Path: + return self._root + + def list(self) -> list[dict[str, Any]]: + """Return metadata for all cache entries.""" + with self._with_lock(): + entries = [entry.metadata.model_dump(mode="json") for entry in self._iter_entries()] + return entries + + def clear(self, hard: bool = False) -> None: + """Remove all cache contents. 
If set to hard, root directory is removed.""" + with self._with_lock(): + _remove_cache_dir(self._root, recreate=not hard) + if not hard: + self._write_stats(CacheStats()) + + def _fetch(self, key: str) -> Optional[CacheEntry]: + """Retrieve an entry by key, verifying checksum.""" + with self._with_lock(): + entry = self._load_entry(key) + if not entry or not entry.exists(): + return None + if not entry.verify(): + self._remove_entry(entry) + return None + self._touch(entry) + return entry + + def __len__(self) -> int: + """Return number of valid cache entries.""" + with self._with_lock(): + count = self._load_stats().total_entries + return count + + def _store( + self, key: str, source_path: Path, metadata: CacheEntryMetadata + ) -> Optional[CacheEntry]: + """Store a new cache entry from ``source_path``. + + Parameters + ---------- + key : str + Cache key computed from simulation hash and runtime context. + source_path : Path + Location of the artifact to cache. + metadata : CacheEntryMetadata + Metadata describing the cache entry to be persisted. + + Returns + ------- + CacheEntry + Representation of the stored cache entry. 
+ """ + source_path = Path(source_path) + if not source_path.exists(): + raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") + os.makedirs(self._root, exist_ok=True) + tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) + tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME + tmp_meta = tmp_dir / CACHE_METADATA_NAME + os.makedirs(tmp_dir, exist_ok=True) + + checksum, file_size = _copy_and_hash(source_path, tmp_artifact) + metadata.cache_key = key + metadata.created_at = datetime.now(timezone.utc) + metadata.last_used = metadata.created_at + metadata.checksum = checksum + metadata.file_size = file_size + + _write_metadata(tmp_meta, metadata) + entry: Optional[CacheEntry] = None + try: + with self._with_lock(): + self._root.mkdir(parents=True, exist_ok=True) + existing_entry = self._load_entry(key) + previous_size = ( + existing_entry.metadata.file_size if existing_entry is not None else 0 + ) + self._ensure_limits( + file_size, + incoming_key=key, + replacing_size=previous_size, + ) + final_dir = get_cache_entry_dir(self._root, key) + final_dir.parent.mkdir(parents=True, exist_ok=True) + if final_dir.exists(): + shutil.rmtree(final_dir) + os.replace(tmp_dir, final_dir) + entry = CacheEntry(key=key, root=self._root, metadata=metadata) + + self._record_store_stats( + key, + last_used=metadata.last_used.isoformat(), + file_size=file_size, + previous_size=previous_size, + ) + log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) + finally: + try: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + except FileNotFoundError: + pass + return entry + + def invalidate(self, key: str) -> None: + with self._with_lock(): + entry = self._load_entry(key) + if entry: + self._remove_entry(entry) + + def _ensure_limits( + self, + incoming_size: int, + *, + incoming_key: Optional[str] = None, + replacing_size: int = 0, + ) -> None: + max_entries = self.max_entries + max_size_bytes = int(self.max_size_gb * (1024**3)) 
+ + try: + incoming_size_int = int(incoming_size) + except (TypeError, ValueError): + incoming_size_int = 0 + if incoming_size_int < 0: + incoming_size_int = 0 + + stats = self._load_stats() + entries_info = dict(stats.last_used) + existing_keys = set(entries_info) + projected_entries = stats.total_entries + if not incoming_key or incoming_key not in existing_keys: + projected_entries += 1 + + if projected_entries > max_entries > 0: + excess = projected_entries - max_entries + exclude = {incoming_key} if incoming_key else set() + self._evict(entries_info, remove_count=excess, exclude_keys=exclude) + stats = self._load_stats() + entries_info = dict(stats.last_used) + existing_keys = set(entries_info) + + if max_size_bytes == 0: # no limit + return + + existing_size = stats.total_size + try: + replacing_size_int = int(replacing_size) + except (TypeError, ValueError): + replacing_size_int = 0 + if incoming_key and incoming_key in existing_keys: + projected_size = existing_size - replacing_size_int + incoming_size_int + else: + projected_size = existing_size + incoming_size_int + + if max_size_bytes > 0 and projected_size > max_size_bytes: + bytes_to_free = projected_size - max_size_bytes + exclude = {incoming_key} if incoming_key else set() + self._evict_by_size(entries_info, bytes_to_free, exclude_keys=exclude) + + def _evict(self, entries: dict[str, str], *, remove_count: int, exclude_keys: set[str]) -> None: + if remove_count <= 0: + return + candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] + if not candidates: + return + candidates.sort(key=lambda item: item[1] or "") + for key, _ in candidates[:remove_count]: + self._remove_entry_by_key(key) + + def _evict_by_size( + self, entries: dict[str, str], bytes_to_free: int, *, exclude_keys: set[str] + ) -> None: + if bytes_to_free <= 0: + return + candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] + if not candidates: + return + 
candidates.sort(key=lambda item: item[1] or "") + reclaimed = 0 + for key, _ in candidates: + if reclaimed >= bytes_to_free: + break + entry = self._load_entry(key) + if entry is None: + log.debug("Could not find entry for eviction.") + self._schedule_sync() + break + size = entry.metadata.file_size + self._remove_entry(entry) + reclaimed += size + log.info(f"Simulation cache evicted entry '{key}' to reclaim {size} bytes.") + + def _iter_entries(self) -> Iterator[CacheEntry]: + """Iterate lazily over all cache entries, including those in prefix subdirectories.""" + if not self._root.exists(): + return + + for prefix_dir in self._root.iterdir(): + if not prefix_dir.is_dir() or prefix_dir.name.startswith( + (TMP_PREFIX, TMP_BATCH_PREFIX) + ): + continue + + # if cache is directly flat (no prefix directories), include that level too + subdirs = [prefix_dir] + if any((prefix_dir / name).is_dir() for name in prefix_dir.iterdir()): + subdirs = prefix_dir.iterdir() + + for child in subdirs: + if not child.is_dir(): + continue + if child.name.startswith((TMP_PREFIX, TMP_BATCH_PREFIX)): + continue + + meta_path = child / CACHE_METADATA_NAME + if not meta_path.exists(): + continue + + try: + metadata = _read_metadata(meta_path, child / CACHE_ARTIFACT_NAME) + except Exception: + log.debug( + "Failed to parse metadata for '%s'; scheduling stats sync.", child.name + ) + self._schedule_sync() + continue + + yield CacheEntry(key=child.name, root=self._root, metadata=metadata) + + def _load_entry(self, key: str) -> Optional[CacheEntry]: + entry = CacheEntry(key=key, root=self._root, metadata={}) + if not entry.metadata_path.exists() or not entry.artifact_path.exists(): + return None + try: + metadata = _read_metadata(entry.metadata_path, entry.artifact_path) + except Exception: + return None + return CacheEntry(key=key, root=self._root, metadata=metadata) + + def _touch(self, entry: CacheEntry) -> None: + entry.metadata.bump_last_used() + _write_metadata(entry.metadata_path, 
entry.metadata) + self._record_touch_stats( + entry.key, + entry.metadata.last_used.isoformat(), + file_size=entry.metadata.file_size, + ) + + def _remove_entry_by_key(self, key: str) -> None: + entry = self._load_entry(key) + if entry is None: + path = get_cache_entry_dir(self._root, key) + if path.exists(): + shutil.rmtree(path, ignore_errors=True) + else: + log.debug("Could not find entry for key '%s' to delete.", key) + self._record_remove_stats(key, 0) + return + self._remove_entry(entry) + + def _remove_entry(self, entry: CacheEntry) -> None: + file_size = entry.metadata.file_size + if entry.path.exists(): + shutil.rmtree(entry.path, ignore_errors=True) + self._record_remove_stats(entry.key, file_size) + + def try_fetch( + self, + simulation: CacheableSimulation, + verbose: bool = False, + ) -> Optional[CacheEntry]: + """ + Attempt to resolve and fetch a cached result entry for the given simulation context. + On miss or any cache error, returns None (the caller should proceed with upload/run). + """ + try: + simulation_hash = simulation._hash_self() + workflow_type = _get_workflow_type(simulation) + + versions = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=versions, + ) + + entry = self._fetch(cache_key) + if not entry: + return None + + if verbose: + log.info( + f"Simulation cache hit for workflow '{workflow_type}'; using local results." + ) + + return entry + except Exception as e: + log.error("Failed to fetch cache results: " + str(e)) + return None + + def store_result( + self, + task_id: TaskId, + path: str, + workflow_type: str, + simulation: CacheableSimulation, + ) -> bool: + """ + Stores completed workflow results in the local cache using a canonical cache key. + + Parameters + ---------- + task_id : str + Unique identifier of the finished workflow task. + path : str + Path to the results file on disk. 
+ workflow_type : str + Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``). + simulation : :class:`.CacheableSimulation` + Simulation object to use when computing the cache key. If not provided, + it will be inferred from ``stub_data.simulation`` when possible. + + Returns + ------- + bool + ``True`` if the result was successfully stored in the local cache, ``False`` otherwise. + + Notes + ----- + The cache entry is keyed by the simulation hash, workflow type, environment, and protocol version. + This enables automatic reuse of identical simulation results across future runs. + Legacy task ID mappings are recorded to support backward lookup compatibility. + """ + try: + simulation_hash = simulation._hash_self() + if not simulation_hash: + log.debug("Failed storing local cache entry: Could not hash simulation.") + return False + + version = _get_protocol_version() + + cache_key = build_cache_key( + simulation_hash=simulation_hash, + version=version, + ) + + metadata = build_entry_metadata( + simulation_hash=simulation_hash, + workflow_type=workflow_type, + task_id=task_id, + version=version, + path=Path(path), + ) + + self._store( + key=cache_key, + source_path=Path(path), + metadata=metadata, + ) + log.debug("Stored local cache entry for workflow type '%s'.", workflow_type) + except Exception as e: + log.error(f"Could not store cache entry: {e}") + return False + return True + + +def _copy_and_hash( + source: Path, dest: Optional[Path], existing_hash: Optional[str] = None +) -> tuple[str, int]: + """Copy ``source`` to ``dest`` while computing SHA256 checksum. + + Parameters + ---------- + source : Path + Source file path. + dest : Path or None + Destination file path. If ``None``, no copy is performed. + existing_hash : str, optional + If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match. + + Returns + ------- + tuple[str, int] + The hexadecimal digest and file size in bytes. 
+ """ + source = Path(source) + if dest is not None: + dest = Path(dest) + sha256 = _Hasher() + size = 0 + with source.open("rb") as src: + if dest is None: + while chunk := src.read(1024 * 1024): + sha256.update(chunk) + size += len(chunk) + else: + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as dst: + while chunk := src.read(1024 * 1024): + dst.write(chunk) + sha256.update(chunk) + size += len(chunk) + return sha256.hexdigest(), size + + +def _write_metadata(path: Path, metadata: CacheEntryMetadata | dict[str, Any]) -> None: + tmp_path = path.with_suffix(".tmp") + payload: dict[str, Any] + if isinstance(metadata, CacheEntryMetadata): + payload = metadata.model_dump(mode="json") + else: + payload = metadata + with tmp_path.open("w", encoding="utf-8") as fh: + json.dump(payload, fh, indent=2, sort_keys=True) + os.replace(tmp_path, path) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _timestamp_suffix() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f") + + +def _read_metadata(meta_path: Path, artifact_path: Path) -> CacheEntryMetadata: + raw = json.loads(meta_path.read_text(encoding="utf-8")) + if "file_size" not in raw: + try: + raw["file_size"] = artifact_path.stat().st_size + except FileNotFoundError: + raw["file_size"] = 0 + raw.setdefault("created_at", _now()) + raw.setdefault("last_used", raw["created_at"]) + raw.setdefault("cache_key", meta_path.parent.name) + return CacheEntryMetadata.model_validate(raw) + + +class _Hasher: + def __init__(self) -> None: + self._hasher = hashlib.sha256() + + def update(self, data: bytes) -> None: + self._hasher.update(data) + + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + +def clear() -> None: + """Remove all cache entries.""" + cache = resolve_local_cache(use_cache=True) + if cache is not None: + cache.clear() + + +def _canonicalize(value: Any) -> Any: + """Convert value into a JSON-serializable object for 
hashing/metadata.""" + + if isinstance(value, dict): + return { + str(k): _canonicalize(v) + for k, v in sorted(value.items(), key=lambda item: str(item[0])) + } + if isinstance(value, (list, tuple)): + return [_canonicalize(v) for v in value] + if isinstance(value, set): + return sorted(_canonicalize(v) for v in value) + if isinstance(value, Enum): + return value.value + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, bytes): + return value.decode("utf-8", errors="ignore") + return value + + +def build_cache_key( + *, + simulation_hash: str, + version: str, +) -> str: + """Construct a deterministic cache key.""" + + payload = { + "simulation_hash": simulation_hash, + "versions": _canonicalize(version), + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def build_entry_metadata( + *, + simulation_hash: str, + workflow_type: str, + task_id: str, + version: str, + path: Path, +) -> CacheEntryMetadata: + """Create metadata object for a cache entry.""" + + now = datetime.now(timezone.utc) + return CacheEntryMetadata( + cache_key="", + checksum="", + created_at=now, + last_used=now, + file_size=0, + simulation_hash=simulation_hash, + workflow_type=workflow_type, + versions=_canonicalize(version), + task_id=task_id, + path=str(path), + ) + + +def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache]: + """ + Returns LocalCache instance if enabled. + Returns None if use_cached=False or config-fetched 'enabled' is False. + Deletes old cache directory if existing. 
+ """ + global _CACHE + + if use_cache is False or (use_cache is not True and not config.local_cache.enabled): + return None + + if _CACHE is not None and _CACHE._root != Path(config.local_cache.directory): + old_root = _CACHE._root + new_root = Path(config.local_cache.directory) + log.debug(f"Moving cache directory from {old_root} → {new_root}") + try: + new_root.parent.mkdir(parents=True, exist_ok=True) + if old_root.exists(): + shutil.move(old_root, new_root) + except Exception as e: + log.warning(f"Failed to move cache directory: {e}. Delete old cache.") + _remove_cache_dir(old_root, recreate=False) + + _CACHE = LocalCache( + directory=config.local_cache.directory, + max_entries=config.local_cache.max_entries, + max_size_gb=config.local_cache.max_size_gb, + ) + + try: + return _CACHE + except Exception as err: + log.debug(f"Simulation cache unavailable: {err}") + return None + + +resolve_local_cache() diff --git a/tidy3d/_common/web/core/__init__.py b/tidy3d/_common/web/core/__init__.py new file mode 100644 index 0000000000..f1a0e1eaa8 --- /dev/null +++ b/tidy3d/_common/web/core/__init__.py @@ -0,0 +1,8 @@ +"""Tidy3d core package imports""" + +from __future__ import annotations + +# TODO(FXC-3827): Drop this import once the legacy shim is removed in Tidy3D 2.12. 
+from tidy3d._common.web.core import environment + +__all__ = ["environment"] diff --git a/tidy3d/_common/web/core/account.py b/tidy3d/_common/web/core/account.py new file mode 100644 index 0000000000..aefc41bd61 --- /dev/null +++ b/tidy3d/_common/web/core/account.py @@ -0,0 +1,66 @@ +"""Tidy3d user account.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from tidy3d._common.web.core.http_util import http +from tidy3d._common.web.core.types import Tidy3DResource + + +class Account(Tidy3DResource, extra="allow"): + """Tidy3D User Account.""" + + allowance_cycle_type: Optional[str] = Field( + None, + title="AllowanceCycleType", + description="Daily or Monthly", + alias="allowanceCycleType", + ) + credit: Optional[float] = Field( + 0, title="credit", description="Current FlexCredit balance", alias="credit" + ) + credit_expiration: Optional[datetime] = Field( + None, + title="creditExpiration", + description="Expiration date", + alias="creditExpiration", + ) + allowance_current_cycle_amount: Optional[float] = Field( + 0, + title="allowanceCurrentCycleAmount", + description="Daily/Monthly free simulation balance", + alias="allowanceCurrentCycleAmount", + ) + allowance_current_cycle_end_date: Optional[datetime] = Field( + None, + title="allowanceCurrentCycleEndDate", + description="Daily/Monthly free simulation balance expiration date", + alias="allowanceCurrentCycleEndDate", + ) + daily_free_simulation_counts: Optional[int] = Field( + 0, + title="dailyFreeSimulationCounts", + description="Daily free simulation counts", + alias="dailyFreeSimulationCounts", + ) + + @classmethod + def get(cls) -> Optional[Account]: + """Get user account information. 
+ + Parameters + ---------- + + Returns + ------- + account : Account + """ + resp = http.get("tidy3d/py/account") + if resp: + account = Account(**resp) + return account + return None diff --git a/tidy3d/_common/web/core/cache.py b/tidy3d/_common/web/core/cache.py new file mode 100644 index 0000000000..d83421ca21 --- /dev/null +++ b/tidy3d/_common/web/core/cache.py @@ -0,0 +1,6 @@ +"""Local caches.""" + +from __future__ import annotations + +FOLDER_CACHE = {} +S3_STS_TOKENS = {} diff --git a/tidy3d/_common/web/core/constants.py b/tidy3d/_common/web/core/constants.py new file mode 100644 index 0000000000..623af2bba8 --- /dev/null +++ b/tidy3d/_common/web/core/constants.py @@ -0,0 +1,38 @@ +"""Defines constants for core.""" + +# HTTP Header key and value +from __future__ import annotations + +HEADER_APIKEY = "simcloud-api-key" +HEADER_VERSION = "tidy3d-python-version" +HEADER_SOURCE = "source" +HEADER_SOURCE_VALUE = "Python" +HEADER_USER_AGENT = "User-Agent" +HEADER_APPLICATION = "Application" +HEADER_APPLICATION_VALUE = "TIDY3D" + + +SIMCLOUD_APIKEY = "SIMCLOUD_APIKEY" +KEY_APIKEY = "apikey" +JSON_TAG = "JSON_STRING" +# type of the task_id +TaskId = str +# type of task_name +TaskName = str + + +SIMULATION_JSON = "simulation.json" +SIMULATION_DATA_HDF5 = "output/monitor_data.hdf5" +SIMULATION_DATA_HDF5_GZ = "output/simulation_data.hdf5.gz" +RUNNING_INFO = "output/solver_progress.csv" +SIM_LOG_FILE = "output/tidy3d.log" +SIM_FILE_HDF5 = "simulation.hdf5" +SIM_FILE_HDF5_GZ = "simulation.hdf5.gz" +MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz" +MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz" +SIM_ERROR_FILE = "output/tidy3d_error.json" +SIM_VALIDATION_FILE = "output/tidy3d_validation.json" + +# Component modeler specific artifacts +MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz" +CM_DATA_HDF5_GZ = "output/cm_data.hdf5.gz" diff --git a/tidy3d/_common/web/core/core_config.py b/tidy3d/_common/web/core/core_config.py new file mode 100644 index 0000000000..3a9a517c71 --- 
/dev/null +++ b/tidy3d/_common/web/core/core_config.py @@ -0,0 +1,50 @@ +"""Tidy3d core log, need init config from Tidy3d api""" + +from __future__ import annotations + +import logging as log +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rich.console import Console + + from tidy3d._common.log import Logger + +# default setting +config_setting = { + "logger": log, + "logger_console": None, + "version": "", +} + + +def set_config(logger: Logger, logger_console: Console, version: str) -> None: + """Init tidy3d core logger and logger console. + + Parameters + ---------- + logger : :class:`.Logger` + Tidy3d log Logger. + logger_console : :class:`.Console` + Get console from logging handlers. + version : str + tidy3d version + """ + config_setting["logger"] = logger + config_setting["logger_console"] = logger_console + config_setting["version"] = version + + +def get_logger() -> Logger: + """Get logging handlers.""" + return config_setting["logger"] + + +def get_logger_console() -> Console: + """Get console from logging handlers.""" + return config_setting["logger_console"] + + +def get_version() -> str: + """Get version from cache.""" + return config_setting["version"] diff --git a/tidy3d/_common/web/core/environment.py b/tidy3d/_common/web/core/environment.py new file mode 100644 index 0000000000..58bd8ceaef --- /dev/null +++ b/tidy3d/_common/web/core/environment.py @@ -0,0 +1,42 @@ +"""Legacy re-export of configuration environment helpers.""" + +from __future__ import annotations + +# TODO(FXC-3827): Remove this module-level legacy shim in Tidy3D 2.12. 
+import warnings +from typing import Any + +from tidy3d._common.config import Env, Environment, EnvironmentConfig + +__all__ = [ # noqa: F822 + "Env", + "Environment", + "EnvironmentConfig", + "dev", + "nexus", + "pre", + "prod", + "uat", +] + +_LEGACY_ENV_NAMES = {"dev", "uat", "pre", "prod", "nexus"} +_DEPRECATION_MESSAGE = ( + "'tidy3d.web.core.environment.{name}' is deprecated and will be removed in " + "Tidy3D 2.12. Transition to 'tidy3d.config.Env.{name}' or " + "'tidy3d.config.config.switch_profile(...)'." +) + + +def _get_legacy_env(name: str) -> Any: + warnings.warn(_DEPRECATION_MESSAGE.format(name=name), DeprecationWarning, stacklevel=2) + return getattr(Env, name) + + +def __getattr__(name: str) -> Any: + if name in _LEGACY_ENV_NAMES: + return _get_legacy_env(name) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def __dir__() -> list[str]: + return sorted(set(__all__)) diff --git a/tidy3d/_common/web/core/exceptions.py b/tidy3d/_common/web/core/exceptions.py new file mode 100644 index 0000000000..e0d00d772d --- /dev/null +++ b/tidy3d/_common/web/core/exceptions.py @@ -0,0 +1,27 @@ +"""Custom Tidy3D exceptions""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tidy3d._common.web.core.core_config import get_logger + +if TYPE_CHECKING: + from typing import Optional + +if TYPE_CHECKING: + from typing import Optional + + +class WebError(Exception): + """Any error in tidy3d""" + + def __init__(self, message: Optional[str] = None) -> None: + """Log just the error message and then raise the Exception.""" + log = get_logger() + super().__init__(message) + log.error(message) + + +class WebNotFoundError(WebError): + """A generic error indicating an HTTP 404 (resource not found).""" diff --git a/tidy3d/_common/web/core/file_util.py b/tidy3d/_common/web/core/file_util.py new file mode 100644 index 0000000000..41a9dac0a2 --- /dev/null +++ b/tidy3d/_common/web/core/file_util.py @@ -0,0 +1,87 @@ +"""File 
compression utilities""" + +from __future__ import annotations + +import gzip +import os +import shutil +import tempfile + +import h5py + +from tidy3d._common.web.core.constants import JSON_TAG + + +def compress_file_to_gzip(input_file: os.PathLike, output_gz_file: os.PathLike) -> None: + """Compresses a file using gzip. + + Parameters + ---------- + input_file : PathLike + The path of the input file. + output_gz_file : PathLike + The path of the output gzip file. + """ + with open(input_file, "rb") as file_in: + with gzip.open(output_gz_file, "wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def extract_gzip_file(input_gz_file: os.PathLike, output_file: os.PathLike) -> None: + """Extract a gzip file. + + Parameters + ---------- + input_gz_file : PathLike + The path of the gzip input file. + output_file : PathLike + The path of the output file. + """ + with gzip.open(input_gz_file, "rb") as file_in: + with open(output_file, "wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def read_simulation_from_hdf5_gz(file_name: os.PathLike) -> str: + """read simulation str from hdf5.gz""" + + hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5") + os.close(hdf5_file) + try: + extract_gzip_file(file_name, hdf5_file_path) + # Pass the uncompressed temporary file path to the reader + json_str = read_simulation_from_hdf5(hdf5_file_path) + finally: + os.unlink(hdf5_file_path) + return json_str + + +"""TODO: _json_string_key and read_simulation_from_hdf5 are duplicated functions that also exist +as methods in Tidy3dBaseModel. 
For consistency it would be best if this duplication is avoided.""" + + +def _json_string_key(index: int) -> str: + """Get json string key for string chunk number ``index``.""" + if index: + return f"{JSON_TAG}_{index}" + return JSON_TAG + + +def read_simulation_from_hdf5(file_name: os.PathLike) -> bytes: + """read simulation str from hdf5""" + with h5py.File(file_name, "r") as f_handle: + num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) + json_string = b"" + for ind in range(num_string_parts): + json_string += f_handle[_json_string_key(ind)][()] + return json_string + + +"""End TODO""" + + +def read_simulation_from_json(file_name: os.PathLike) -> str: + """read simulation str from json""" + with open(file_name, encoding="utf-8") as json_file: + json_data = json_file.read() + return json_data diff --git a/tidy3d/_common/web/core/http_util.py b/tidy3d/_common/web/core/http_util.py new file mode 100644 index 0000000000..8512f3db59 --- /dev/null +++ b/tidy3d/_common/web/core/http_util.py @@ -0,0 +1,288 @@ +"""Http connection pool and authentication management.""" + +from __future__ import annotations + +import json +import os +import ssl +from enum import Enum +from functools import wraps +from typing import TYPE_CHECKING, Any + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.ssl_ import create_urllib3_context + +from tidy3d._common import log +from tidy3d._common.config import config +from tidy3d._common.web.core import core_config +from tidy3d._common.web.core.constants import ( + HEADER_APIKEY, + HEADER_APPLICATION, + HEADER_APPLICATION_VALUE, + HEADER_SOURCE, + HEADER_SOURCE_VALUE, + HEADER_USER_AGENT, + HEADER_VERSION, + SIMCLOUD_APIKEY, +) +from tidy3d._common.web.core.core_config import get_logger +from tidy3d._common.web.core.exceptions import WebError, WebNotFoundError + +if TYPE_CHECKING: + from typing import Callable, Optional, TypeAlias + +if TYPE_CHECKING: + from typing import Callable, Optional, 
TypeAlias + +JSONType: TypeAlias = dict[str, Any] | list[Any] | str | int + + +class ResponseCodes(Enum): + """HTTP response codes to handle individually.""" + + UNAUTHORIZED = 401 + OK = 200 + NOT_FOUND = 404 + + +def get_version() -> str: + """Get the version for the current environment.""" + return core_config.get_version() + # return "2.10.0rc2.1" + + +def get_user_agent() -> str: + """Get the user agent the current environment.""" + return os.environ.get("TIDY3D_AGENT", f"Python-Client/{get_version()}") + + +def api_key() -> Optional[str]: + """Get the api key for the current environment.""" + + if os.environ.get(SIMCLOUD_APIKEY): + return os.environ.get(SIMCLOUD_APIKEY) + + try: + apikey = config.web.apikey + except AttributeError: + return None + + if apikey is None: + return None + if hasattr(apikey, "get_secret_value"): + return apikey.get_secret_value() + return str(apikey) + + +def api_key_auth(request: requests.request) -> requests.request: + """Save the authentication info in a request. + + Parameters + ---------- + request : requests.request + The original request to set authentication for. + + Returns + ------- + requests.request + The request with authentication set. + """ + key = api_key() + version = get_version() + if not key: + raise ValueError( + "API key not found. To get your API key, sign into 'https://tidy3d.simulation.cloud' " + "and copy it from your 'Account' page. Then you can configure tidy3d through command " + "line 'tidy3d configure' and enter your API key when prompted. " + "Alternatively, especially if using windows, you can manually create the configuration " + "file by creating a file at their home directory '~/.tidy3d/config' (unix) or " + "'.tidy3d/config' (windows) containing the following line: " + "apikey = 'XXX'. Here XXX is your API key copied from your account page within quotes." 
+ ) + if not version: + raise ValueError("version not found.") + + request.headers[HEADER_APIKEY] = key + request.headers[HEADER_VERSION] = version + request.headers[HEADER_SOURCE] = HEADER_SOURCE_VALUE + request.headers[HEADER_USER_AGENT] = get_user_agent() + return request + + +def get_headers() -> dict[str, Optional[str]]: + """get headers for http request. + + Returns + ------- + dict[str, str] + dictionary with "Authorization" and "Application" keys. + """ + return { + HEADER_APIKEY: api_key(), + HEADER_APPLICATION: HEADER_APPLICATION_VALUE, + HEADER_USER_AGENT: get_user_agent(), + } + + +def http_interceptor(func: Callable[..., Any]) -> Callable[..., JSONType]: + """Intercept the response and raise an exception if the status code is not 200.""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> JSONType: + """The wrapper function.""" + suppress_404 = kwargs.pop("suppress_404", False) + + # Extend some capabilities of func + resp = func(*args, **kwargs) + + if resp.status_code != ResponseCodes.OK.value: + if resp.status_code == ResponseCodes.NOT_FOUND.value: + if suppress_404: + return None + raise WebNotFoundError("Resource not found (HTTP 404).") + try: + json_resp = resp.json() + except Exception: + json_resp = None + + # Build a helpful error message using available fields + err_msg = None + if isinstance(json_resp, dict): + parts = [] + for key in ("error", "message", "msg", "detail", "code", "httpStatus", "warning"): + val = json_resp.get(key) + if not val: + continue + if key == "error": + # Always include the raw 'error' payload for debugging. Also try to extract a nested message. 
+ if isinstance(val, str): + try: + nested = json.loads(val) + if isinstance(nested, dict): + nested_msg = ( + nested.get("message") + or nested.get("error") + or nested.get("msg") + ) + if nested_msg: + parts.append(str(nested_msg)) + except Exception: + pass + parts.append(f"error={val}") + else: + parts.append(f"error={val!s}") + continue + parts.append(str(val)) + if parts: + err_msg = "; ".join(parts) + if not err_msg: + # Fallback to response text or status code + err_msg = resp.text or f"HTTP {resp.status_code}" + + # Append request context to aid debugging + try: + method = getattr(resp.request, "method", "") + url = getattr(resp.request, "url", "") + err_msg = f"{err_msg} [HTTP {resp.status_code} {method} {url}]" + except Exception: + pass + + raise WebError(err_msg) + + if not resp.text: + return None + result = resp.json() + + if isinstance(result, dict): + warning = result.get("warning") + if warning: + log = get_logger() + log.warning(warning) + + if "data" in result: + return result["data"] + + return result + + return wrapper + + +class TLSAdapter(HTTPAdapter): + def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: + try: + ssl_version = ( + ssl.TLSVersion[config.web.ssl_version] + if config.web.ssl_version is not None + else None + ) + except KeyError: + log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default") + ssl_version = None + context = create_urllib3_context(ssl_version=ssl_version) + kwargs["ssl_context"] = context + return super().init_poolmanager(*args, **kwargs) + + +class HttpSessionManager: + """Http util class.""" + + def __init__(self, session: requests.Session) -> None: + """Initialize the session.""" + self.session = session + self._mounted_ssl_version = None + self._ensure_tls_adapter(config.web.ssl_version) + self.session.verify = config.web.ssl_verify + + def reinit(self) -> None: + """Reinitialize the session.""" + ssl_version = config.web.ssl_version + self._ensure_tls_adapter(ssl_version) + 
self.session.verify = config.web.ssl_verify + + def _ensure_tls_adapter(self, ssl_version: str) -> None: + if not ssl_version: + self._mounted_ssl_version = None + return + if self._mounted_ssl_version != ssl_version: + self.session.mount("https://", TLSAdapter()) + self._mounted_ssl_version = ssl_version + + @http_interceptor + def get( + self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Get the resource.""" + self.reinit() + return self.session.get( + url=config.web.build_api_url(path), auth=api_key_auth, json=json, params=params + ) + + @http_interceptor + def post(self, path: str, json: JSONType = None) -> requests.Response: + """Create the resource.""" + self.reinit() + return self.session.post(config.web.build_api_url(path), json=json, auth=api_key_auth) + + @http_interceptor + def put( + self, path: str, json: JSONType = None, files: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Update the resource.""" + self.reinit() + return self.session.put( + config.web.build_api_url(path), json=json, auth=api_key_auth, files=files + ) + + @http_interceptor + def delete( + self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Delete the resource.""" + self.reinit() + return self.session.delete( + config.web.build_api_url(path), auth=api_key_auth, json=json, params=params + ) + + +http = HttpSessionManager(requests.Session()) diff --git a/tidy3d/_common/web/core/s3utils.py b/tidy3d/_common/web/core/s3utils.py new file mode 100644 index 0000000000..4529f60aa9 --- /dev/null +++ b/tidy3d/_common/web/core/s3utils.py @@ -0,0 +1,478 @@ +"""handles filesystem, storage""" + +from __future__ import annotations + +import os +import tempfile +import urllib +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import boto3 +from boto3.s3.transfer import TransferConfig +from pydantic 
import BaseModel, Field +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +from tidy3d._common.config import config +from tidy3d._common.web.core.core_config import get_logger_console +from tidy3d._common.web.core.exceptions import WebError +from tidy3d._common.web.core.file_util import extract_gzip_file +from tidy3d._common.web.core.http_util import http + +if TYPE_CHECKING: + from collections.abc import Mapping + from os import PathLike + from typing import Callable, Optional + + import rich + +if TYPE_CHECKING: + from collections.abc import Mapping + from os import PathLike + from typing import Callable, Optional + + import rich + +IN_TRANSIT_SUFFIX = ".tmp" + + +class _UserCredential(BaseModel): + """Stores information about user credentials.""" + + access_key_id: str = Field(alias="accessKeyId") + expiration: datetime + secret_access_key: str = Field(alias="secretAccessKey") + session_token: str = Field(alias="sessionToken") + + +class _S3STSToken(BaseModel): + """Stores information about S3 token.""" + + cloud_path: str = Field(alias="cloudpath") + user_credential: _UserCredential = Field(alias="userCredentials") + + def get_bucket(self) -> str: + """Get the bucket name for this token.""" + + r = urllib.parse.urlparse(self.cloud_path) + return r.netloc + + def get_s3_key(self) -> str: + """Get the s3 key for this token.""" + + r = urllib.parse.urlparse(self.cloud_path) + return r.path[1:] + + def get_client(self) -> boto3.client: + """Get the boto client for this token. + + Automatically configures custom S3 endpoint if specified in web.env_vars. 
+ """ + + client_kwargs = { + "service_name": "s3", + "region_name": config.web.s3_region, + "aws_access_key_id": self.user_credential.access_key_id, + "aws_secret_access_key": self.user_credential.secret_access_key, + "aws_session_token": self.user_credential.session_token, + "verify": config.web.ssl_verify, + } + + # Add custom S3 endpoint if configured (e.g., for Nexus deployments) + if config.web.env_vars and "AWS_ENDPOINT_URL_S3" in config.web.env_vars: + s3_endpoint = config.web.env_vars["AWS_ENDPOINT_URL_S3"] + client_kwargs["endpoint_url"] = s3_endpoint + + return boto3.client(**client_kwargs) + + def is_expired(self) -> bool: + """True if token is expired.""" + + return ( + self.user_credential.expiration + - datetime.now(tz=self.user_credential.expiration.tzinfo) + ).total_seconds() < 300 + + +class UploadProgress: + """Updates progressbar with the upload status. + + Attributes + ---------- + progress : rich.progress.Progress() + Progressbar instance from rich + ul_task : rich.progress.Task + Progressbar task instance. + """ + + def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None: + """initialize with the size of file and rich.progress.Progress() instance. + + Parameters + ---------- + size_bytes: int + Number of total bytes to upload. + progress : rich.progress.Progress() + Progressbar instance from rich + """ + self.progress = progress + self.ul_task = self.progress.add_task("[red]Uploading...", total=size_bytes) + + def report(self, bytes_in_chunk: Any) -> None: + """Update the progressbar with the most recent chunk. + + Parameters + ---------- + bytes_in_chunk : int + Description + """ + self.progress.update(self.ul_task, advance=bytes_in_chunk) + + +class DownloadProgress: + """Updates progressbar using the download status. + + Attributes + ---------- + progress : rich.progress.Progress() + Progressbar instance from rich + ul_task : rich.progress.Task + Progressbar task instance. 
+ """ + + def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None: + """initialize with the size of file and rich.progress.Progress() instance + + Parameters + ---------- + size_bytes: float + Number of total bytes to download. + progress : rich.progress.Progress() + Progressbar instance from rich + """ + self.progress = progress + self.dl_task = self.progress.add_task("[red]Downloading...", total=size_bytes) + + def report(self, bytes_in_chunk: int) -> None: + """Update the progressbar with the most recent chunk. + + Parameters + ---------- + bytes_in_chunk : float + Description + """ + self.progress.update(self.dl_task, advance=bytes_in_chunk) + + +class _S3Action(Enum): + UPLOADING = "↑" + DOWNLOADING = "↓" + + +def _get_progress(action: _S3Action) -> Progress: + """Get the progress of an action.""" + + col = ( + TextColumn(f"[bold green]{_S3Action.DOWNLOADING.value}") + if action == _S3Action.DOWNLOADING + else TextColumn(f"[bold red]{_S3Action.UPLOADING.value}") + ) + return Progress( + col, + TextColumn("[bold blue]{task.fields[filename]}"), + BarColumn(), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + "•", + TimeRemainingColumn(), + console=get_logger_console(), + ) + + +_s3_config = TransferConfig() + +_s3_sts_tokens: dict[str, _S3STSToken] = {} + + +def get_s3_sts_token( + resource_id: str, file_name: PathLike, extra_arguments: Optional[Mapping[str, str]] = None +) -> _S3STSToken: + """Get s3 sts token for the given resource id and file name. + + Parameters + ---------- + resource_id : str + The resource id, e.g. task id. + file_name : PathLike + The remote file name on S3. + extra_arguments : Mapping[str, str] + Additional arguments for the query url. + + Returns + ------- + _S3STSToken + The S3 STS token. 
+ """ + file_name = str(Path(file_name).as_posix()) + cache_key = f"{resource_id}:{file_name}" + if cache_key not in _s3_sts_tokens or _s3_sts_tokens[cache_key].is_expired(): + method = f"tidy3d/py/tasks/{resource_id}/file?filename={file_name}" + if extra_arguments is not None: + method += "&" + "&".join(f"{k}={v}" for k, v in extra_arguments.items()) + resp = http.get(method) + token = _S3STSToken.model_validate(resp) + _s3_sts_tokens[cache_key] = token + return _s3_sts_tokens[cache_key] + + +def upload_file( + resource_id: str, + path: PathLike, + remote_filename: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + extra_arguments: Optional[Mapping[str, str]] = None, +) -> None: + """Upload a file to S3. + + Parameters + ---------- + resource_id : str + The resource id, e.g. task id. + path : PathLike + Path to the file to upload. + remote_filename : PathLike + The remote file name on S3 relative to the resource context root path. + verbose : bool = True + Whether to display a progressbar for the upload. + progress_callback : Callable[[float], None] = None + User-supplied callback function with ``bytes_in_chunk`` as argument. + extra_arguments : Mapping[str, str] + Additional arguments used to specify the upload bucket. + """ + + path = Path(path) + token = get_s3_sts_token(resource_id, remote_filename, extra_arguments) + + def _upload(_callback: Callable) -> None: + """Perform the upload with a callback function. 
+ + Parameters + ---------- + _callback : Callable[[float], None] + Callback function for upload, accepts ``bytes_in_chunk`` + """ + + with path.open("rb") as data: + token.get_client().upload_fileobj( + data, + Bucket=token.get_bucket(), + Key=token.get_s3_key(), + Callback=_callback, + Config=_s3_config, + ExtraArgs={"ContentEncoding": "gzip"} + if token.get_s3_key().endswith(".gz") + else None, + ) + + if progress_callback is not None: + _upload(progress_callback) + else: + if verbose: + with _get_progress(_S3Action.UPLOADING) as progress: + total_size = path.stat().st_size + task_id = progress.add_task( + "upload", filename=str(remote_filename), total=total_size + ) + + def _callback(bytes_in_chunk: int) -> None: + progress.update(task_id, advance=bytes_in_chunk) + + _upload(_callback) + + progress.update(task_id, completed=total_size, refresh=True) + + else: + _upload(lambda bytes_in_chunk: None) + + +def download_file( + resource_id: str, + remote_filename: PathLike, + to_file: Optional[PathLike] = None, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, +) -> Path: + """Download file from S3. + + Parameters + ---------- + resource_id : str + The resource id, e.g. task id. + remote_filename : PathLike + Path to the remote file. + to_file : PathLike = None + Local filename to save to; if not specified, defaults to ``remote_filename`` in a + directory named after ``resource_id``. + verbose : bool = True + Whether to display a progressbar for the upload + progress_callback : Callable[[float], None] = None + User-supplied callback function with ``bytes_in_chunk`` as argument. 
+ """ + + token = get_s3_sts_token(resource_id, remote_filename) + client = token.get_client() + meta_data = client.head_object(Bucket=token.get_bucket(), Key=token.get_s3_key()) + + # Get only last part of the remote file name + remote_basename = Path(remote_filename).name + + # set to_file if None + if to_file is None: + to_path = Path(resource_id) / remote_basename + else: + to_path = Path(to_file) + + # make the leading directories in the 'to_path', if any + to_path.parent.mkdir(parents=True, exist_ok=True) + + def _download(_callback: Callable) -> None: + """Perform the download with a callback function. + + Parameters + ---------- + _callback : Callable[[float], None] + Callback function for download, accepts ``bytes_in_chunk`` + """ + # Caller can assume the existence of the file means download succeeded. + # So make sure this file does not exist until that assumption is true. + to_path.unlink(missing_ok=True) + # Download to a temporary file. + try: + fd, tmp_file_path_str = tempfile.mkstemp(suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent) + os.close(fd) # `tempfile.mkstemp()` creates and opens a randomly named file. close it. + to_path_tmp = Path(tmp_file_path_str) + client.download_file( + Bucket=token.get_bucket(), + Filename=str(to_path_tmp), + Key=token.get_s3_key(), + Callback=_callback, + Config=_s3_config, + ) + to_path_tmp.rename(to_path) + except Exception as e: + to_path_tmp.unlink(missing_ok=True) # Delete incompletely downloaded file. 
+ raise e + + if progress_callback is not None: + _download(progress_callback) + else: + if verbose: + with _get_progress(_S3Action.DOWNLOADING) as progress: + total_size = meta_data.get("ContentLength", 0) + progress.start() + task_id = progress.add_task("download", filename=remote_basename, total=total_size) + + def _callback(bytes_in_chunk: int) -> None: + progress.update(task_id, advance=bytes_in_chunk) + + _download(_callback) + + progress.update(task_id, completed=total_size, refresh=True) + + else: + _download(lambda bytes_in_chunk: None) + + return to_path + + +def download_gz_file( + resource_id: str, + remote_filename: PathLike, + to_file: Optional[PathLike] = None, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, +) -> Path: + """Download a ``.gz`` file and unzip it into ``to_file``, unless ``to_file`` itself + ends in .gz + + Parameters + ---------- + resource_id : str + The resource id, e.g. task id. + remote_filename : PathLike + Path to the remote file. + to_file : Optional[PathLike] = None + Local filename to save to; if not specified, defaults to ``remote_filename`` with the + ``.gz`` suffix removed in a directory named after ``resource_id``. + verbose : bool = True + Whether to display a progressbar for the upload + progress_callback : Callable[[float], None] = None + User-supplied callback function with ``bytes_in_chunk`` as argument. 
+ """ + + # If to_file is a gzip extension, just download + if to_file is None: + remote_basename = Path(remote_filename).name + if remote_basename.endswith(".gz"): + remote_basename = remote_basename[:-3] + to_path = Path(resource_id) / remote_basename + else: + to_path = Path(to_file) + + suffixes = "".join(to_path.suffixes).lower() + if suffixes.endswith(".gz"): + return download_file( + resource_id, + remote_filename, + to_file=to_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + # Otherwise, download and unzip + # The tempfile is set as ``hdf5.gz`` so that the mock download in the webapi tests works + tmp_file, tmp_file_path_str = tempfile.mkstemp(".hdf5.gz") + os.close(tmp_file) + + # make the leading directories in the 'to_file', if any + to_path.parent.mkdir(parents=True, exist_ok=True) + try: + download_file( + resource_id, + remote_filename, + to_file=Path(tmp_file_path_str), + verbose=verbose, + progress_callback=progress_callback, + ) + if not Path(tmp_file_path_str).exists(): + raise WebError(f"Failed to download and extract '{remote_filename}'.") + + tmp_out_fd, tmp_out_path_str = tempfile.mkstemp( + suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent + ) + os.close(tmp_out_fd) + tmp_out_path = Path(tmp_out_path_str) + try: + extract_gzip_file(Path(tmp_file_path_str), tmp_out_path) + tmp_out_path.replace(to_path) + except Exception as e: + tmp_out_path.unlink(missing_ok=True) + raise WebError( + f"Failed to extract '{remote_filename}' from '{tmp_file_path_str}' to '{to_path}'." 
+ ) from e + finally: + Path(tmp_file_path_str).unlink(missing_ok=True) + return to_path diff --git a/tidy3d/_common/web/core/stub.py b/tidy3d/_common/web/core/stub.py new file mode 100644 index 0000000000..cebffd9ba0 --- /dev/null +++ b/tidy3d/_common/web/core/stub.py @@ -0,0 +1,84 @@ +"""Defines interface that can be subclassed to use with the tidy3d webapi""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from os import PathLike + + +class TaskStubData(ABC): + @abstractmethod + def from_file(self, file_path: PathLike) -> TaskStubData: + """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file. + + Parameters + ---------- + file_path : PathLike + Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. + + Returns + ------- + :class:`Stub` + An instance of the component class calling ``load``. + + """ + + @abstractmethod + def to_file(self, file_path: PathLike) -> None: + """Loads a :class:`Stub` from .yaml, .json, or .hdf5 file. + + Parameters + ---------- + file_path : PathLike + Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. + + Returns + ------- + :class:`Stub` + An instance of the component class calling ``load``. + """ + + +class TaskStub(ABC): + @abstractmethod + def from_file(self, file_path: PathLike) -> TaskStub: + """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file. + + Parameters + ---------- + file_path : PathLike + Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. + + Returns + ------- + :class:`TaskStubData` + An instance of the component class calling ``load``. + """ + + @abstractmethod + def to_file(self, file_path: PathLike) -> None: + """Loads a :class:`TaskStub` from .yaml, .json, .hdf5 or .hdf5.gz file. + + Parameters + ---------- + file_path : PathLike + Full path to the .yaml or .json or .hdf5 file to load the :class:`TaskStub` from. 
+ + Returns + ------- + :class:`Stub` + An instance of the component class calling ``load``. + """ + + @abstractmethod + def to_hdf5_gz(self, fname: PathLike) -> None: + """Exports :class:`TaskStub` instance to .hdf5.gz file. + + Parameters + ---------- + fname : PathLike + Full path to the .hdf5.gz file to save the :class:`TaskStub` to. + """ diff --git a/tidy3d/_common/web/core/task_core.py b/tidy3d/_common/web/core/task_core.py new file mode 100644 index 0000000000..d613dec53d --- /dev/null +++ b/tidy3d/_common/web/core/task_core.py @@ -0,0 +1,1010 @@ +"""Tidy3d webapi types.""" + +from __future__ import annotations + +import os +import pathlib +import tempfile +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from botocore.exceptions import ClientError +from pydantic import Field, TypeAdapter + +from tidy3d._common.config import config +from tidy3d._common.exceptions import ValidationError +from tidy3d._common.log import log +from tidy3d._common.web.core import http_util +from tidy3d._common.web.core.cache import FOLDER_CACHE +from tidy3d._common.web.core.constants import ( + SIM_ERROR_FILE, + SIM_FILE_HDF5_GZ, + SIM_LOG_FILE, + SIM_VALIDATION_FILE, + SIMULATION_DATA_HDF5_GZ, +) +from tidy3d._common.web.core.core_config import get_logger_console +from tidy3d._common.web.core.exceptions import WebError, WebNotFoundError +from tidy3d._common.web.core.file_util import read_simulation_from_hdf5 +from tidy3d._common.web.core.http_util import get_version as _get_protocol_version +from tidy3d._common.web.core.http_util import http +from tidy3d._common.web.core.s3utils import download_file, download_gz_file, upload_file +from tidy3d._common.web.core.task_info import BatchDetail, TaskInfo +from tidy3d._common.web.core.types import ( + PayType, + Queryable, + ResourceLifecycle, + Submittable, + Tidy3DResource, +) + +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Union + + import requests + + from 
tidy3d._common.web.core.stub import TaskStub + + +class Folder(Tidy3DResource, Queryable, extra="allow"): + """Tidy3D Folder.""" + + folder_id: str = Field( + title="Folder id", + description="folder id", + alias="projectId", + ) + folder_name: str = Field( + title="Folder name", + description="folder name", + alias="projectName", + ) + + @classmethod + def list(cls, projects_endpoint: str = "tidy3d/projects") -> []: + """List all folders. + + Returns + ------- + folders : [Folder] + List of folders + """ + resp = http.get(projects_endpoint) + return TypeAdapter(list[Folder]).validate_python(resp) if resp else None + + @classmethod + def get( + cls, + folder_name: str, + create: bool = False, + projects_endpoint: str = "tidy3d/projects", + project_endpoint: str = "tidy3d/project", + ) -> Folder: + """Get folder by name. + + Parameters + ---------- + folder_name : str + Name of the folder. + create : str + If the folder doesn't exist, create it. + + Returns + ------- + folder : Folder + """ + folder = FOLDER_CACHE.get(folder_name) + if not folder: + resp = http.get(project_endpoint, params={"projectName": folder_name}) + if resp: + folder = Folder(**resp) + if create and not folder: + resp = http.post(projects_endpoint, {"projectName": folder_name}) + if resp: + folder = Folder(**resp) + FOLDER_CACHE[folder_name] = folder + return folder + + @classmethod + def create(cls, folder_name: str) -> Folder: + """Create a folder, return existing folder if there is one has the same name. + + Parameters + ---------- + folder_name : str + Name of the folder. 
+ + Returns + ------- + folder : Folder + """ + return Folder.get(folder_name, True) + + def delete(self, projects_endpoint: str = "tidy3d/projects") -> None: + """Remove this folder.""" + + http.delete(f"{projects_endpoint}/{self.folder_id}") + + def delete_old(self, days_old: int) -> int: + """Remove folder contents older than ``days_old``.""" + + return http.delete( + f"tidy3d/tasks/{self.folder_id}/tasks", + params={"daysOld": days_old}, + ) + + def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DResource]: + """List all tasks in this folder. + + Returns + ------- + tasks : list[:class:`.SimulationTask`] + List of tasks in this folder + """ + resp = http.get(f"{projects_endpoint}/{self.folder_id}/tasks") + return TypeAdapter(list[SimulationTask]).validate_python(resp) if resp else None + + +class WebTask(ResourceLifecycle, Submittable, extra="allow"): + """Interface for managing the running a task on the server.""" + + task_id: Optional[str] = Field( + None, + title="task_id", + description="Task ID number, set when the task is uploaded, leave as None.", + alias="taskId", + ) + + @classmethod + def create( + cls, + task_type: str, + task_name: str, + folder_name: str = "default", + callback_url: Optional[str] = None, + simulation_type: str = "tidy3d", + parent_tasks: Optional[list[str]] = None, + file_type: str = "Gz", + projects_endpoint: str = "tidy3d/projects", + ) -> SimulationTask: + """Create a new task on the server. + + Parameters + ---------- + task_type: :class".TaskType" + The type of task. + task_name: str + The name of the task. + folder_name: str, + The name of the folder to store the task. Default is "default". + callback_url: str + Http PUT url to receive simulation finish event. The body content is a json file with + fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``. + simulation_type : str + Type of simulation being uploaded. + parent_tasks : list[str] + List of related task ids. 
+ file_type: str + the simulation file type Json, Hdf5, Gz + + Returns + ------- + :class:`SimulationTask` + :class:`SimulationTask` object containing info about status, size, + credits of task and others. + """ + + # handle backwards compatibility, "tidy3d" is the default simulation_type + if simulation_type is None: + simulation_type = "tidy3d" + + folder = Folder.get(folder_name, create=True) + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: + payload = { + "groupName": task_name, + "folderId": folder.folder_id, + "fileType": file_type, + "taskType": task_type, + } + resp = http.post("rf/task", payload) + else: + payload = { + "taskName": task_name, + "taskType": task_type, + "callbackUrl": callback_url, # type: ignore[dict-item] + "simulationType": simulation_type, + "parentTasks": parent_tasks, # type: ignore[dict-item] + "fileType": file_type, + } + resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) + return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) + + def get_url(self) -> str: + base = str(config.web.website_endpoint or "") + if isinstance(self, BatchTask): + return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"]) + return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"]) + + def get_folder_url(self) -> Optional[str]: + folder_id = getattr(self, "folder_id", None) + if not folder_id: + return None + base = str(config.web.website_endpoint or "") + return "/".join([base.rstrip("/"), f"folders/{folder_id}"]) + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. 
+ + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_data_hdf5( + self, + to_file: PathLike, + remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Download data artifact (simulation or batch) with gz fallback handling. + + Parameters + ---------- + remote_data_file_gz : PathLike + Gzipped remote filename. + to_file : PathLike + Local target path. + verbose : bool + Whether to log progress. + progress_callback : Optional[Callable[[float], None]] + Progress callback. + + Returns + ------- + pathlib.Path + Saved local path. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + target_path = pathlib.Path(to_file) + file = None + try: + file = download_gz_file( + resource_id=self.task_id, + remote_filename=remote_data_file_gz, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except ClientError: + if verbose: + console = get_logger_console() + console.log(f"Unable to download '{remote_data_file_gz}'.") + if not file: + try: + file = download_file( + resource_id=self.task_id, + remote_filename=str(remote_data_file_gz)[:-3], + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except Exception as e: + raise WebError( + "Failed to download the data file from the server. " + "Please confirm that the task completed successfully." + ) from e + return file + + @staticmethod + def is_batch(resource_id: str) -> bool: + """Checks if a given resource ID corresponds to a valid batch task. 
+ + This is a utility function to verify a batch task's existence before + instantiating the class. + + Parameters + ---------- + resource_id : str + The unique identifier for the resource. + + Returns + ------- + bool + ``True`` if the resource is a valid batch task, ``False`` otherwise. + """ + try: + # TODO PROPERLY FIXME + # Disable non critical logs due to check for resourceId, until we have a dedicated API for this + resp = http.get( + f"rf/task/{resource_id}/statistics", + suppress_404=True, + ) + status = bool(resp and isinstance(resp, dict) and "status" in resp) + return status + except Exception: + return False + + def delete(self, versions: bool = False) -> None: + """Delete current task from server. + + Parameters + ---------- + versions : bool = False + If ``True``, delete all versions of the task in the task group. Otherwise, delete only + the version associated with the current task ID. + """ + if not self.task_id: + raise ValueError("Task id not found.") + + task_details = self.detail().model_dump() + + if task_details and "groupId" in task_details: + group_id = task_details["groupId"] + if versions: + http.delete("tidy3d/group", json={"groupIds": [group_id]}) + return + elif "version" in task_details: + version = task_details["version"] + http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) + return + + # Fallback to old method if we can't get the groupId and version + http.delete(f"tidy3d/tasks/{self.task_id}") + + +class SimulationTask(WebTask): + """Interface for managing the running of solver tasks on the server.""" + + folder_id: Optional[str] = Field( + None, + title="folder_id", + description="Folder ID number, set when the task is uploaded, leave as None.", + alias="folderId", + ) + status: Optional[str] = Field(None, title="status", description="Simulation task status.") + + real_flex_unit: Optional[float] = Field( + None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" + ) + + 
created_at: Optional[datetime] = Field( + None, + title="created_at", + description="Time at which this task was created.", + alias="createdAt", + ) + + task_type: Optional[str] = Field( + None, title="task_type", description="The type of task.", alias="taskType" + ) + + folder_name: Optional[str] = Field( + "default", + title="Folder Name", + description="Name of the folder associated with this task.", + alias="folderName", + ) + + callback_url: Optional[str] = Field( + None, + title="Callback URL", + description="Http PUT url to receive simulation finish event. " + "The body content is a json file with fields " + "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", + ) + + # simulation_type: str = Field( + # None, + # title="Simulation Type", + # description="Type of simulation, used internally only.", + # ) + + # parent_tasks: Tuple[TaskId, ...] = Field( + # None, + # title="Parent Tasks", + # description="List of parent task ids for the simulation, used internally only." + # ) + + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: + """Get task from the server by id. + + Parameters + ---------- + task_id: str + Unique identifier of task on server. + verbose: + If `True`, will print progressbars and status, otherwise, will run silently. + + Returns + ------- + :class:`.SimulationTask` + :class:`.SimulationTask` object containing info about status, + size, credits of task and others. + """ + try: + resp = http.get(f"tidy3d/tasks/{task_id}/detail") + except WebNotFoundError as e: + log.error(f"The requested task ID '{task_id}' does not exist.") + raise e + + task = SimulationTask(**resp) if resp else None + return task + + @classmethod + def get_running_tasks(cls) -> list[SimulationTask]: + """Get a list of running tasks from the server" + + Returns + ------- + List[:class:`.SimulationTask`] + :class:`.SimulationTask` object containing info about status, + size, credits of task and others. 
+ """ + resp = http.get("tidy3d/py/tasks") + if not resp: + return [] + return TypeAdapter(list[SimulationTask]).validate_python(resp) + + def detail(self) -> TaskInfo: + """Fetches the detailed information and status of the task. + + Returns + ------- + TaskInfo + An object containing the task's latest data. + """ + resp = http.get(f"tidy3d/tasks/{self.task_id}/detail") + return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp}) # type: ignore[dict-item] + + def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None: + """Get json file for a :class:`.Simulation` from server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + to_file = pathlib.Path(to_file) + + hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5") + os.close(hdf5_file) + try: + self.get_simulation_hdf5(hdf5_file_path) + if os.path.exists(hdf5_file_path): + json_string = read_simulation_from_hdf5(hdf5_file_path) + to_file.parent.mkdir(parents=True, exist_ok=True) + with to_file.open("w", encoding="utf-8") as file: + # Write the string to the file + file.write(json_string.decode("utf-8")) + if verbose: + console = get_logger_console() + console.log(f"Generate {to_file} successfully.") + else: + raise WebError("Failed to download simulation.json.") + finally: + os.unlink(hdf5_file_path) + + def upload_simulation( + self, + stub: TaskStub, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + remote_sim_file: PathLike = SIM_FILE_HDF5_GZ, + ) -> None: + """Upload :class:`.Simulation` object to Server. + + Parameters + ---------- + stub: :class:`TaskStub` + An instance of TaskStub. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while uploading the data. 
+ """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + if not stub: + raise WebError("Expected field 'simulation' is unset.") + # Also upload hdf5.gz containing all data. + file, file_name = tempfile.mkstemp() + os.close(file) + try: + # upload simulation + # compress .hdf5 to .hdf5.gz + stub.to_hdf5_gz(file_name) + upload_file( + self.task_id, + file_name, + remote_sim_file, + verbose=verbose, + progress_callback=progress_callback, + ) + finally: + os.unlink(file_name) + + def upload_file( + self, + local_file: PathLike, + remote_filename: str, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> None: + """ + Upload file to platform. Using this method when the json file is too large to parse + as :class".simulation". + Parameters + ---------- + local_file: PathLike + Local file path. + remote_filename: str + file name on the server + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while uploading the data. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + upload_file( + self.task_id, + local_file, + remote_filename, + verbose=verbose, + progress_callback=progress_callback, + ) + + def submit( + self, + solver_version: Optional[str] = None, + worker_group: Optional[str] = None, + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, + ) -> None: + """Kick off this task. + + It will be uploaded to server before + starting the task. Otherwise, this method assumes that the Simulation has been uploaded by + the upload_file function, so the task will be kicked off directly. + + Parameters + ---------- + solver_version: str = None + target solver version. + worker_group: str = None + worker group + pay_type: Union[PayType, str] = PayType.AUTO + Which method to pay the simulation. 
+ priority: int = None + Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest). + It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits. + """ + pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type + + if solver_version: + protocol_version = None + else: + protocol_version = http_util.get_version() + + http.post( + f"tidy3d/tasks/{self.task_id}/submit", + { + "solverVersion": solver_version, + "workerGroup": worker_group, + "protocolVersion": protocol_version, + "enableCaching": config.web.enable_caching, + "payType": pay_type.value, + "priority": priority, + }, + ) + + def estimate_cost(self, solver_version: Optional[str] = None) -> float: + """Compute the maximum flex unit charge for a given task, assuming the simulation runs for + the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately. + + Parameters + ---------- + solver_version: str + target solver version. + + Returns + ------- + flex_unit_cost: float + estimated cost in FlexCredits + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + if solver_version: + protocol_version = None + else: + protocol_version = http_util.get_version() + + resp = http.post( + f"tidy3d/tasks/{self.task_id}/metadata", + { + "solverVersion": solver_version, + "protocolVersion": protocol_version, + }, + ) + return resp + + def get_simulation_hdf5( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + remote_sim_file: PathLike = SIM_FILE_HDF5_GZ, + ) -> pathlib.Path: + """Get simulation.hdf5 file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. 
+ + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_gz_file( + resource_id=self.task_id, + remote_filename=remote_sim_file, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_running_info(self) -> tuple[float, float]: + """Gets the % done and field_decay for a running task. + + Returns + ------- + perc_done : float + Percentage of run done (in terms of max number of time steps). + Is ``None`` if run info not available. + field_decay : float + Average field intensity normalized to max value (1.0). + Is ``None`` if run info not available. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + resp = http.get(f"tidy3d/tasks/{self.task_id}/progress") + perc_done = resp.get("perc_done") + field_decay = resp.get("field_decay") + return perc_done, field_decay + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_error_json( + self, to_file: PathLike, verbose: bool = True, validation: bool = False + ) -> pathlib.Path: + """Get error json file for a :class:`.Simulation` from server. 
+ + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + validation: bool = False + Whether to get a validation error file or a solver error file. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE + + return download_file( + self.task_id, + target_file, + to_file=target_path, + verbose=verbose, + ) + + def abort(self) -> requests.Response: + """Abort the current task on the server.""" + if not self.task_id: + raise ValueError("Task id not found.") + return http.put( + "tidy3d/tasks/abort", json={"taskType": self.task_type, "taskId": self.task_id} + ) + + def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None: + """Perform checks after task is uploaded and metadata is processed.""" + if self.task_type == "HEAT_CHARGE" and parent_tasks: + try: + if len(parent_tasks) > 1: + raise ValueError( + "A single parent 'task_id' corresponding to the task in which the meshing " + "was run must be provided." + ) + try: + # get mesh task info + mesh_task = SimulationTask.get(parent_tasks[0], verbose=False) + assert mesh_task.task_type == "VOLUME_MESH" + assert mesh_task.status == "success" + # get up-to-date task info + task = SimulationTask.get(self.task_id, verbose=False) + if task.fileMd5 != mesh_task.childFileMd5: + raise ValidationError( + "Simulation stored in parent task 'VolumeMesher' does not match the " + "current simulation." + ) + except Exception as e: + raise ValidationError( + "The parent task must be a 'VolumeMesher' task which has been successfully " + "run and is associated to the same 'HeatChargeSimulation' as provided here." 
+ ) from e + + except Exception as e: + raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e + + +class BatchTask(WebTask): + """Interface for managing a batch task on the server.""" + + task_type: Optional[str] = Field( + None, title="task_type", description="The type of task.", alias="taskType" + ) + + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> BatchTask: + """Get batch task by id. + + Parameters + ---------- + task_id: str + Unique identifier of batch on server. + verbose: + If `True`, will print progressbars and status, otherwise, will run silently. + + Returns + ------- + :class:`.BatchTask` | None + BatchTask object if found, otherwise None. + """ + try: + resp = http.get(f"rf/task/{task_id}/statistics") + except WebNotFoundError as e: + log.error(f"The requested batch ID '{task_id}' does not exist.") + raise e + # Extract taskType from response if available + if resp: + task_type = resp.get("taskType") if isinstance(resp, dict) else None + return BatchTask(taskId=task_id, taskType=task_type) + return None + + def detail(self) -> BatchDetail: + """Fetches the detailed information and status of the batch. + + Returns + ------- + BatchDetail + An object containing the batch's latest data. + """ + resp = http.get( + f"rf/task/{self.task_id}/statistics", + ) + # Some backends may return null for collection fields; coerce to sensible defaults + if isinstance(resp, dict): + if resp.get("tasks") is None: + resp["tasks"] = [] + return BatchDetail(**(resp or {})) + + def check( + self, + check_task_type: str, + solver_version: Optional[str] = None, + protocol_version: Optional[str] = None, + ) -> requests.Response: + """Submits a request to validate the batch configuration on the server. + + Parameters + ---------- + solver_version : Optional[str], default=None + The version of the solver to use for validation. + protocol_version : Optional[str], default=None + The data protocol version. Defaults to the current version. 
+ + Returns + ------- + Any + The server's response to the check request. + """ + if protocol_version is None: + protocol_version = _get_protocol_version() + return http.post( + f"rf/task/{self.task_id}/check", + { + "solverVersion": solver_version, + "protocolVersion": protocol_version, + "taskType": check_task_type, + }, + ) + + def submit( + self, + solver_version: Optional[str] = None, + protocol_version: Optional[str] = None, + worker_group: Optional[str] = None, + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, + ) -> requests.Response: + """Submits the batch for execution on the server. + + Parameters + ---------- + solver_version : Optional[str], default=None + The version of the solver to use for execution. + protocol_version : Optional[str], default=None + The data protocol version. Defaults to the current version. + worker_group : Optional[str], default=None + Optional identifier for a specific worker group to run on. + + Returns + ------- + Any + The server's response to the submit request. + """ + + # TODO: add support for pay_type and priority arguments + if pay_type != PayType.AUTO: + raise NotImplementedError( + "The 'pay_type' argument is not yet supported and will be ignored." + ) + if priority is not None: + raise NotImplementedError( + "The 'priority' argument is not yet supported and will be ignored." 
+ ) + + if protocol_version is None: + protocol_version = _get_protocol_version() + return http.post( + f"rf/task/{self.task_id}/submit", + { + "solverVersion": solver_version, + "protocolVersion": protocol_version, + "workerGroup": worker_group, + }, + ) + + def abort(self) -> requests.Response: + """Abort the current task on the server.""" + if not self.task_id: + raise ValueError("Batch id not found.") + return http.put(f"rf/task/{self.task_id}/abort", {}) + + +class TaskFactory: + """Factory for obtaining the correct task subclass.""" + + _REGISTRY: dict[str, type[WebTask]] = {} + + @classmethod + def reset(cls) -> None: + """Clear the cached task kind registry (used in tests).""" + cls._REGISTRY.clear() + + @classmethod + def register(cls, task_id: str, kind: type[WebTask]) -> None: + cls._REGISTRY[task_id] = kind + + @classmethod + def get_kind(cls, task_id: str, verbose: bool = True) -> type[WebTask]: + """Return cached task class, fetching and caching if needed.""" + kind = cls._REGISTRY.get(task_id) + if kind: + return kind + if WebTask.is_batch(task_id): + cls.register(task_id, BatchTask) + return BatchTask + task = SimulationTask.get(task_id, verbose=verbose) + if task: + cls.register(task_id, SimulationTask) + return SimulationTask + + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> WebTask: + kind = cls._REGISTRY.get(task_id) + if kind is BatchTask: + return BatchTask.get(task_id, verbose=verbose) + if kind is SimulationTask: + task = SimulationTask.get(task_id, verbose=verbose) + return task + if WebTask.is_batch(task_id): + cls.register(task_id, BatchTask) + return BatchTask.get(task_id, verbose=verbose) + task = SimulationTask.get(task_id, verbose=verbose) + if task: + cls.register(task_id, SimulationTask) + return task diff --git a/tidy3d/_common/web/core/task_info.py b/tidy3d/_common/web/core/task_info.py new file mode 100644 index 0000000000..c42ba0f220 --- /dev/null +++ b/tidy3d/_common/web/core/task_info.py @@ -0,0 +1,328 @@ 
+"""Defines information about a task""" + +from __future__ import annotations + +from abc import ABC +from datetime import datetime +from enum import Enum +from typing import Annotated, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class TaskBase(BaseModel, ABC): + """Base configuration for all task objects.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ChargeType(str, Enum): + """The payment method of the task.""" + + FREE = "free" + """No payment required.""" + + PAID = "paid" + """Payment required.""" + + +class TaskBlockInfo(TaskBase): + """Information about the task's block status. + + Notes + ----- + This includes details about how the task can be blocked by various features + such as user limits and insufficient balance. + """ + + chargeType: Optional[ChargeType] = None + """The type of charge applicable to the task (free or paid).""" + + maxFreeCount: Optional[int] = None + """The maximum number of free tasks allowed.""" + + maxGridPoints: Optional[int] = None + """The maximum number of grid points permitted.""" + + maxTimeSteps: Optional[int] = None + """The maximum number of time steps allowed.""" + + +class TaskInfo(TaskBase): + """General information about a task.""" + + taskId: str + """Unique identifier for the task.""" + + taskName: Optional[str] = None + """Name of the task.""" + + nodeSize: Optional[int] = None + """Size of the node allocated for the task.""" + + completedAt: Optional[datetime] = None + """Timestamp when the task was completed.""" + + status: Optional[str] = None + """Current status of the task.""" + + realCost: Optional[float] = None + """Actual cost incurred by the task.""" + + timeSteps: Optional[int] = None + """Number of time steps involved in the task.""" + + solverVersion: Optional[str] = None + """Version of the solver used for the task.""" + + createAt: Optional[datetime] = None + """Timestamp when the task was created.""" + + estCostMin: Optional[float] = None + 
"""Estimated minimum cost for the task.""" + + estCostMax: Optional[float] = None + """Estimated maximum cost for the task.""" + + realFlexUnit: Optional[float] = None + """Actual flexible units used by the task.""" + + oriRealFlexUnit: Optional[float] = None + """Original real flexible units.""" + + estFlexUnit: Optional[float] = None + """Estimated flexible units for the task.""" + + estFlexCreditTimeStepping: Optional[float] = None + """Estimated flexible credits for time stepping.""" + + estFlexCreditPostProcess: Optional[float] = None + """Estimated flexible credits for post-processing.""" + + estFlexCreditMode: Optional[float] = None + """Estimated flexible credits based on the mode.""" + + s3Storage: Optional[float] = None + """Amount of S3 storage used by the task.""" + + startSolverTime: Optional[datetime] = None + """Timestamp when the solver started.""" + + finishSolverTime: Optional[datetime] = None + """Timestamp when the solver finished.""" + + totalSolverTime: Optional[int] = None + """Total time taken by the solver.""" + + callbackUrl: Optional[str] = None + """Callback URL for task notifications.""" + + taskType: Optional[str] = None + """Type of the task.""" + + metadataStatus: Optional[str] = None + """Status of the metadata for the task.""" + + taskBlockInfo: Optional[TaskBlockInfo] = None + """Blocking information for the task.""" + + version: Optional[str] = None + """Version of the task.""" + + +class RunInfo(TaskBase): + """Information about the run of a task.""" + + perc_done: Annotated[float, Field(ge=0.0, le=100.0)] + """Percentage of the task that is completed (0 to 100).""" + + field_decay: Annotated[float, Field(ge=0.0, le=1.0)] + """Field decay from the maximum value (0 to 1).""" + + def display(self) -> None: + """Print some info about the task's progress.""" + print(f" - {self.perc_done:.2f} (%) done") + print(f" - {self.field_decay:.2e} field decay from max") + + +# ---------------------- Batch (Modeler) detail schema 
---------------------- # + + +class BatchTaskBlockInfo(TaskBlockInfo): + """ + Extends `TaskBlockInfo` with specific details for batch task blocking. + + Attributes: + accountLimit: A usage or cost limit imposed by the user's account. + taskBlockMsg: A human-readable message describing the reason for the block. + taskBlockType: The specific type of block (e.g., 'balance', 'limit'). + blockStatus: The current blocking status for the batch. + taskStatus: The status of the task when it was blocked. + """ + + accountLimit: Optional[float] = None + taskBlockMsg: Optional[str] = None + taskBlockType: Optional[str] = None + blockStatus: Optional[str] = None + taskStatus: Optional[str] = None + + +class BatchMember(TaskBase): + """ + Represents a single task within a larger batch operation. + + Attributes: + refId: A reference identifier for the member task. + folderId: The identifier of the folder containing the task. + sweepId: The identifier for the parameter sweep, if applicable. + taskId: The unique identifier of the task. + linkedTaskId: The identifier of a task linked to this one. + groupId: The identifier of the group this task belongs to. + taskName: The name of the individual task. + status: The current status of this specific task. + sweepData: Data associated with a parameter sweep. + validateInfo: Information related to the task's validation. + replaceData: Data used for replacements or modifications. + protocolVersion: The version of the protocol used. + variable: The variable parameter for this task in a sweep. + createdAt: The timestamp when the member task was created. + updatedAt: The timestamp when the member task was last updated. + denormalizeStatus: The status of the data denormalization process. + summary: A dictionary containing summary information for the task. 
+ """ + + refId: Optional[str] = None + folderId: Optional[str] = None + sweepId: Optional[str] = None + taskId: Optional[str] = None + linkedTaskId: Optional[str] = None + groupId: Optional[str] = None + taskName: Optional[str] = None + status: Optional[str] = None + sweepData: Optional[str] = None + validateInfo: Optional[str] = None + replaceData: Optional[str] = None + protocolVersion: Optional[str] = None + variable: Optional[str] = None + createdAt: Optional[datetime] = None + updatedAt: Optional[datetime] = None + denormalizeStatus: Optional[str] = None + summary: Optional[dict] = None + + +class BatchDetail(TaskBase): + """Provides a detailed, top-level view of a batch of tasks. + + Notes + ----- + This model serves as the main payload for retrieving comprehensive + information about a batch operation. + + Attributes + ---------- + refId + A reference identifier for the entire batch. + optimizationId + Identifier for the optimization process, if any. + groupId + Identifier for the group the batch belongs to. + name + The user-defined name of the batch. + status + The current status of the batch. + totalTask + The total number of tasks in the batch. + preprocessSuccess + The count of tasks that completed preprocessing. + postprocessStatus + The status of the batch's postprocessing stage. + validateSuccess + The count of tasks that passed validation. + runSuccess + The count of tasks that ran successfully. + postprocessSuccess + The count of tasks that completed postprocessing. + taskBlockInfo + Information on what might be blocking the batch. + estFlexUnit + The estimated total flexible compute units for the batch. + totalSeconds + The total time in seconds the batch has taken. + totalCheckMillis + Total time in milliseconds spent on checks. + message + A general message providing information about the batch status. + tasks + A list of `BatchMember` objects, one for each task in the batch. + taskType + The type of tasks contained in the batch. 
+ """ + + refId: Optional[str] = None + optimizationId: Optional[str] = None + groupId: Optional[str] = None + name: Optional[str] = None + status: Optional[str] = None + totalTask: int = 0 + preprocessSuccess: int = 0 + postprocessStatus: Optional[str] = None + validateSuccess: int = 0 + runSuccess: int = 0 + postprocessSuccess: int = 0 + taskBlockInfo: Optional[BatchTaskBlockInfo] = None + estFlexUnit: Optional[float] = None + realFlexUnit: Optional[float] = None + totalSeconds: Optional[int] = None + totalCheckMillis: Optional[int] = None + message: Optional[str] = None + tasks: list[BatchMember] = [] + validateErrors: Optional[dict] = None + taskType: str = None + version: Optional[str] = None + + +class AsyncJobDetail(TaskBase): + """Provides a detailed view of an asynchronous job and its sub-tasks. + + Notes + ----- + This model represents a long-running operation. The 'result' attribute holds + the output of a completed job, which for orchestration jobs, is often a + JSON string mapping sub-task names to their unique IDs. + + Attributes + ---------- + asyncId + The unique identifier for the asynchronous job. + status + The current overall status of the job (e.g., 'RUNNING', 'COMPLETED'). + progress + The completion percentage of the job (from 0.0 to 100.0). + createdAt + The timestamp when the job was created. + completedAt + The timestamp when the job finished (successfully or not). + tasks + A dictionary mapping logical task keys to their unique task IDs. + This is often populated by parsing the 'result' of an orchestration task. + result + The raw string output of the completed job. If the job spawns other + tasks, this is expected to be a JSON string detailing those tasks. + taskBlockInfo + Information on any dependencies blocking the job from running. + message + A human-readable message about the job's status. 
+ """ + + asyncId: str + status: str + progress: Optional[float] = None + createdAt: Optional[datetime] = None + completedAt: Optional[datetime] = None + tasks: Optional[dict[str, str]] = None + result: Optional[str] = None + taskBlockInfo: Optional[TaskBlockInfo] = None + message: Optional[str] = None + + +AsyncJobDetail.model_rebuild() diff --git a/tidy3d/_common/web/core/types.py b/tidy3d/_common/web/core/types.py new file mode 100644 index 0000000000..aaac18612a --- /dev/null +++ b/tidy3d/_common/web/core/types.py @@ -0,0 +1,73 @@ +"""Tidy3d abstraction types for the core.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any + +from pydantic import BaseModel + + +class Tidy3DResource(BaseModel, ABC): + """Abstract base class / template for a webservice that implements resource query.""" + + @classmethod + @abstractmethod + def get(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: + """Get a resource from the server.""" + + +class ResourceLifecycle(Tidy3DResource, ABC): + """Abstract base class for a webservice that implements resource life cycle management.""" + + @classmethod + @abstractmethod + def create(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: + """Create a new resource and return it.""" + + @abstractmethod + def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete the resource.""" + + +class Submittable(BaseModel, ABC): + """Abstract base class / template for a webservice that implements a submit method.""" + + @abstractmethod + def submit(self, *args: Any, **kwargs: Any) -> None: + """Submit the task to the webservice.""" + + +class Queryable(BaseModel, ABC): + """Abstract base class / template for a webservice that implements a query method.""" + + @classmethod + @abstractmethod + def list(cls, *args: Any, **kwargs: Any) -> list[Queryable]: + """List all resources of this type.""" + + +class TaskType(str, Enum): + FDTD = "FDTD" + MODE_SOLVER = "MODE_SOLVER" + 
HEAT = "HEAT" + HEAT_CHARGE = "HEAT_CHARGE" + EME = "EME" + MODE = "MODE" + VOLUME_MESH = "VOLUME_MESH" + MODAL_CM = "MODAL_CM" + TERMINAL_CM = "TERMINAL_CM" + + +class PayType(str, Enum): + CREDITS = "FLEX_CREDIT" + AUTO = "AUTO" + + @classmethod + def _missing_(cls, value: object) -> PayType: + if isinstance(value, str): + key = value.strip().replace(" ", "_").upper() + if key in cls.__members__: + return cls.__members__[key] + return super()._missing_(value) diff --git a/tidy3d/_runtime.py b/tidy3d/_runtime.py index 6dbf61accd..068dfd2d92 100644 --- a/tidy3d/_runtime.py +++ b/tidy3d/_runtime.py @@ -1,12 +1,10 @@ -"""Runtime environment detection for tidy3d. +"""Compatibility shim for :mod:`tidy3d._common._runtime`.""" -This module must have ZERO dependencies on other tidy3d modules to avoid -circular imports. It is imported very early in the initialization chain. -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -import sys - -# Detect WASM/Pyodide environment where web and filesystem features are unavailable -WASM_BUILD = "pyodide" in sys.modules or sys.platform == "emscripten" +from tidy3d._common._runtime import ( + WASM_BUILD, +) diff --git a/tidy3d/compat.py b/tidy3d/compat.py index a616a41895..dd9ff00a6b 100644 --- a/tidy3d/compat.py +++ b/tidy3d/compat.py @@ -1,31 +1,8 @@ -"""Compatibility layer for handling differences between package versions.""" +"""Compatibility shim for :mod:`tidy3d._common.compat`.""" -from __future__ import annotations - -import importlib -from functools import cache - -from packaging.version import parse - -try: - from xarray.structure import alignment -except ImportError: - from xarray.core import alignment - -try: - from numpy import trapezoid as np_trapezoid -except ImportError: # NumPy < 2.0 - from numpy import trapz as np_trapezoid - -try: - from typing import Self, TypeAlias # Python >= 3.11 -except ImportError: # Python <3.11 - 
from typing_extensions import Self, TypeAlias - - -@cache -def _package_is_older_than(package: str, version: str) -> bool: - return parse(importlib.metadata.version(package)) < parse(version) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = ["Self", "TypeAlias", "_package_is_older_than", "alignment", "np_trapezoid"] +from tidy3d._common.compat import Self, TypeAlias, _package_is_older_than, alignment, np_trapezoid diff --git a/tidy3d/components/autograd/__init__.py b/tidy3d/components/autograd/__init__.py index a415ea67f1..d83a2da2c6 100644 --- a/tidy3d/components/autograd/__init__.py +++ b/tidy3d/components/autograd/__init__.py @@ -1,11 +1,15 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.autograd`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -from .boxes import TidyArrayBox -from .functions import interpn -from .types import ( +from tidy3d._common.components.autograd import ( AutogradFieldMap, InterpolationType, PathType, + TidyArrayBox, TracedArrayFloat2D, TracedArrayLike, TracedComplex, @@ -16,27 +20,9 @@ TracedPositiveFloat, TracedSize, TracedSize1D, + get_static, + hasbox, + interpn, + is_tidy_box, + split_list, ) -from .utils import get_static, hasbox, is_tidy_box, split_list - -__all__ = [ - "AutogradFieldMap", - "InterpolationType", - "PathType", - "TidyArrayBox", - "TracedArrayFloat2D", - "TracedArrayLike", - "TracedComplex", - "TracedCoordinate", - "TracedFloat", - "TracedPoleAndResidue", - "TracedPolesAndResidues", - "TracedPositiveFloat", - "TracedSize", - "TracedSize1D", - "get_static", - "hasbox", - "interpn", - "is_tidy_box", - "split_list", -] diff --git a/tidy3d/components/autograd/boxes.py b/tidy3d/components/autograd/boxes.py index d51e948a85..d8765e865c 100644 --- a/tidy3d/components/autograd/boxes.py +++ 
b/tidy3d/components/autograd/boxes.py @@ -1,162 +1,13 @@ -# Adds some functionality to the autograd arraybox and related autograd patches -# NOTE: we do not subclass ArrayBox since that would break autograd's internal checks -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING, Any - -import autograd.numpy as anp -from autograd.extend import VJPNode, defjvp, register_notrace -from autograd.numpy.numpy_boxes import ArrayBox -from autograd.numpy.numpy_wrapper import _astype - -if TYPE_CHECKING: - from typing import Callable - -TidyArrayBox = ArrayBox # NOT a subclass +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.boxes`.""" -_autograd_module_cache = {} # cache for imported autograd modules +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -register_notrace(VJPNode, anp.full_like) +# marked as migrated to _common +from __future__ import annotations -defjvp( - _astype, - lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype), +from tidy3d._common.components.autograd.boxes import ( + TidyArrayBox, + _autograd_module_cache, + from_arraybox, + item, ) - -anp.astype = _astype -anp.permute_dims = anp.transpose - - -@classmethod -def from_arraybox(cls: Any, box: ArrayBox) -> TidyArrayBox: - """Construct a TidyArrayBox from an ArrayBox.""" - return cls(box._value, box._trace, box._node) - - -def __array_function__( - self: Any, - func: Callable, - types: list[Any], - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> Any: - """ - Handle the dispatch of NumPy functions to autograd's numpy implementation. - - Parameters - ---------- - self : Any - The instance of the class. - func : Callable - The NumPy function being called. - types : list[Any] - The types of the arguments that implement __array_function__. - args : tuple[Any, ...] - The positional arguments to the function. - kwargs : dict[str, Any] - The keyword arguments to the function. 
- - Returns - ------- - Any - The result of the function call, or NotImplemented. - - Raises - ------ - NotImplementedError - If the function is not implemented in autograd.numpy. - - See Also - -------- - https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__ - """ - if not all(t in TidyArrayBox.type_mappings for t in types): - return NotImplemented - - module_name = func.__module__ - - if module_name.startswith("numpy"): - anp_module_name = "autograd." + module_name - else: - return NotImplemented - - # Use the cached module if available - anp_module = _autograd_module_cache.get(anp_module_name) - if anp_module is None: - try: - anp_module = importlib.import_module(anp_module_name) - _autograd_module_cache[anp_module_name] = anp_module - except ImportError: - return NotImplemented - - f = getattr(anp_module, func.__name__, None) - if f is None: - return NotImplemented - - if f.__name__ == "nanmean": # somehow xarray always dispatches to nanmean - f = anp.mean - kwargs.pop("dtype", None) # autograd mean vjp doesn't support dtype - - return f(*args, **kwargs) - - -def __array_ufunc__( - self: Any, - ufunc: Callable, - method: str, - *inputs: Any, - **kwargs: dict[str, Any], -) -> Any: - """ - Handle the dispatch of NumPy ufuncs to autograd's numpy implementation. - - Parameters - ---------- - self : Any - The instance of the class. - ufunc : Callable - The universal function being called. - method : str - The method of the ufunc being called. - inputs : Any - The input arguments to the ufunc. - kwargs : dict[str, Any] - The keyword arguments to the ufunc. - - Returns - ------- - Any - The result of the ufunc call, or NotImplemented. 
- - See Also - -------- - https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__ - """ - if method != "__call__": - return NotImplemented - - ufunc_name = ufunc.__name__ - - anp_ufunc = getattr(anp, ufunc_name, None) - if anp_ufunc is not None: - return anp_ufunc(*inputs, **kwargs) - - return NotImplemented - - -def item(self: Any) -> Any: - if self.size != 1: - raise ValueError("Can only convert an array of size 1 to a scalar") - return anp.ravel(self)[0] - - -TidyArrayBox._tidy = True -TidyArrayBox.from_arraybox = from_arraybox -TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp -TidyArrayBox.__array_ufunc__ = __array_ufunc__ -TidyArrayBox.__array_function__ = __array_function__ -TidyArrayBox.real = property(anp.real) -TidyArrayBox.imag = property(anp.imag) -TidyArrayBox.conj = anp.conj -TidyArrayBox.item = item diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index 658bd67f05..28b5c8f52e 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -1,1106 +1,17 @@ -"""Utilities for autograd derivative computation and field gradient evaluation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.derivative_utils`.""" -from __future__ import annotations - -from contextlib import contextmanager -from dataclasses import dataclass, field, replace -from functools import reduce -from typing import TYPE_CHECKING, Any, Optional - -import numpy as np -from numpy.typing import NDArray - -from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray -from tidy3d.components.data.utils import _zeros_like -from tidy3d.components.types import ArrayLike, Bound -from tidy3d.config import config -from tidy3d.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0 -from tidy3d.log import log - -from .types import PathType -from .utils import get_static - -if TYPE_CHECKING: - from 
collections.abc import Iterator - from typing import Callable, Union - - import xarray as xr - - from tidy3d.compat import Self - from tidy3d.components.types import xyz - -FieldDataDict = dict[str, ScalarFieldDataArray] -PermittivityData = dict[str, ScalarFieldDataArray] -EpsType = FreqDataArray -ArrayFloat = NDArray[np.floating] -ArrayComplex = NDArray[np.complexfloating] - - -class LazyInterpolator: - """Lazy wrapper for interpolators that creates them on first access.""" - - def __init__(self, creator_func: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None: - """Initialize with a function that creates the interpolator when called.""" - self.creator_func = creator_func - self._interpolator: Optional[Callable[[ArrayFloat], ArrayComplex]] = None - - def __call__(self, *args: Any, **kwargs: Any) -> ArrayComplex: - """Create interpolator on first call and delegate to it.""" - if self._interpolator is None: - self._interpolator = self.creator_func() - return self._interpolator(*args, **kwargs) - - -@dataclass -class DerivativeInfo: - """Stores derivative information passed to the ``._compute_derivatives`` methods. - - This dataclass contains all the field data and parameters needed for computing - gradients with respect to geometry perturbations. - """ - - # Required fields - paths: list[PathType] - """List of paths to the traced fields that need derivatives calculated.""" - - E_der_map: FieldDataDict - """Electric field gradient map. - Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication - of the forward and adjoint electric fields. The tangential components of this - dataset are used when computing adjoint gradients for shifting boundaries. - All components are used when computing volume-based gradients.""" - - D_der_map: FieldDataDict - """Displacement field gradient map. - Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication - of the forward and adjoint displacement fields. 
The normal component of this - dataset is used when computing adjoint gradients for shifting boundaries.""" - - E_fwd: FieldDataDict - """Forward electric fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the forward - electric fields used for computing gradients for a given structure.""" - - E_adj: FieldDataDict - """Adjoint electric fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint - electric fields used for computing gradients for a given structure.""" - - D_fwd: FieldDataDict - """Forward displacement fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the forward - displacement fields used for computing gradients for a given structure.""" - - D_adj: FieldDataDict - """Adjoint displacement fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint - displacement fields used for computing gradients for a given structure.""" - - eps_data: PermittivityData - """Permittivity dataset. - Dataset of relative permittivity values along all three dimensions. - Used for automatically computing permittivity inside or outside of a simple geometry.""" - - eps_in: EpsType | None - """Permittivity inside the Structure. - Computed only when structure.medium.is_custom is False. Contains the simulation - permittivity inside the structure when the simulation background medium is set to - the structure medium and all structures after the current structure are kept. Should - be used as the inside permittivity for shape derivative computations.""" - - eps_out: EpsType - """Permittivity outside the Structure. - Contains the simulation permittivity outside the structure when the current structure - is removed from the structure list. Should be used as the outside permittivity for - shape derivative computations.""" - - bounds: Bound - """Geometry bounds. 
- Bounds corresponding to the structure, used in Medium calculations.""" - - bounds_intersect: Bound - """Geometry and simulation intersection bounds. - Bounds corresponding to the minimum intersection between the structure - and the simulation it is contained in.""" - - simulation_bounds: Bound - """Simulation bounds. - Bounds corresponding to the simulation domain containing this structure. - Unlike bounds_intersect, this is independent of the structure's bounds and - is purely based on the simulation geometry.""" - - frequencies: ArrayLike - """Frequencies at which the adjoint gradient should be computed.""" - - # Optional fields with defaults - - H_der_map: Optional[FieldDataDict] = None - """Magnetic field gradient map. - Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication - of the forward and adjoint magnetic fields. The tangential component of this - dataset is used when computing adjoint gradients for shifting boundaries of - structures composed of PEC mediums.""" - - H_fwd: Optional[FieldDataDict] = None - """Forward magnetic fields. - Dataset where the field components ("Hx", "Hy", "Hz") represent the forward - magnetic fields used for computing gradients for a given structure.""" - - H_adj: Optional[FieldDataDict] = None - """Adjoint magnetic fields. - Dataset where the field components ("Hx", "Hy", "Hz") represent the adjoint - magnetic fields used for computing gradients for a given structure.""" - - is_medium_pec: bool = False - """Indicates if structure material is PEC. - If True, the structure contains a PEC material which changes the gradient - formulation at the boundary compared to the dielectric case.""" - - background_medium_is_pec: bool = False - """Indicates if structure material is PEC. - If True, the structure is partially surrounded by a PEC material.""" - - interpolators: Optional[dict] = None - """Pre-computed interpolators. - Optional pre-computed interpolators for field components and permittivity data. 
- When provided, avoids redundant interpolator creation for multiple geometries - sharing the same field data. This significantly improves performance for - GeometryGroup processing.""" - - cached_min_spacing_from_permittivity: Optional[float] = None - """Cached `min_spacing_from_permittivity` to be used for objects like GeometryGroup - to avoid recomputing this value multiple times in `adaptive_vjp_spacing`.""" - - # private cache for interpolators - _interpolators_cache: dict = field(default_factory=dict, init=False, repr=False) - - def updated_copy(self, **kwargs: Any) -> Self: - """Create a copy with updated fields.""" - kwargs.pop("deep", None) - kwargs.pop("validate", None) - return replace(self, **kwargs) - - @staticmethod - def _nan_to_num_if_needed( - coords: Union[ArrayFloat, ArrayComplex], - ) -> Union[ArrayFloat, ArrayComplex]: - """Convert NaN and infinite values to finite numbers, optimized for finite inputs.""" - # skip check for small arrays - if coords.size < 1000: - return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) - - if np.isfinite(coords).all(): - return coords - return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) - - @staticmethod - def _evaluate_with_interpolators( - interpolators: dict[str, Callable[[ArrayFloat], ArrayComplex]], - coords: ArrayFloat, - ) -> dict[str, ArrayComplex]: - """Evaluate field components at coordinates using cached interpolators. - - Parameters - ---------- - interpolators : dict - Dictionary mapping field component names to ``RegularGridInterpolator`` objects. - coords : np.ndarray - Spatial coordinates (N, 3) where fields are evaluated. - - Returns - ------- - dict[str, np.ndarray] - Dictionary mapping component names to field values at coordinates. 
- """ - auto_cfg = config.adjoint - float_dtype = auto_cfg.gradient_dtype_float - complex_dtype = auto_cfg.gradient_dtype_complex - - coords = DerivativeInfo._nan_to_num_if_needed(coords) - if coords.dtype != float_dtype and coords.dtype != complex_dtype: - coords = coords.astype(float_dtype, copy=False) - return {name: interp(coords) for name, interp in interpolators.items()} - - def create_interpolators(self, dtype: Optional[np.dtype[Any]] = None) -> dict[str, Any]: - """Create interpolators for field components and permittivity data. - - Creates and caches ``RegularGridInterpolator`` objects for all field components - (E_fwd, E_adj, D_fwd, D_adj) and permittivity data (eps_in, eps_out, eps_data). - Contains (H_fwd, H_adj) field components when relevant for certain material types. - This caching strategy significantly improves performance by avoiding - repeated interpolator construction in gradient evaluation loops. - - Parameters - ---------- - dtype : np.dtype[Any], optional = None - Data type for interpolation coordinates and values. Defaults to the - current ``config.adjoint.gradient_dtype_float``. 
- - Returns - ------- - dict - Nested dictionary structure: - - Field data: {"E_fwd": {"Ex": interpolator, ...}, ...} - - Permittivity: {"eps_in": interpolator, "eps_out": interpolator, "eps_data": interpolator} - """ - from scipy.interpolate import RegularGridInterpolator - - auto_cfg = config.adjoint - if dtype is None: - dtype = auto_cfg.gradient_dtype_float - complex_dtype = auto_cfg.gradient_dtype_complex - - cache_key = str(dtype) - if cache_key in self._interpolators_cache: - return self._interpolators_cache[cache_key] - - interpolators = {} - coord_cache = {} - - def _make_lazy_interpolator_group( - field_data_dict: Optional[FieldDataDict], - group_key: Optional[str], - is_field_group: bool = True, - override_method: Optional[str] = None, - ) -> None: - """Helper to create a group of lazy interpolators.""" - if not field_data_dict: - return - if is_field_group: - interpolators[group_key] = {} - - for component_name, arr in field_data_dict.items(): - # use object ID for caching to handle shared grids - arr_id = id(arr.data) - if arr_id not in coord_cache: - points = tuple(c.data.astype(dtype, copy=False) for c in (arr.x, arr.y, arr.z)) - coord_cache[arr_id] = points - points = coord_cache[arr_id] - - def creator_func( - arr: ScalarFieldDataArray = arr, - points: tuple[np.ndarray, ...] 
= points, - ) -> Callable[[ArrayFloat], ArrayComplex]: - data = arr.data.astype( - complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False - ) - # create interpolator with frequency dimension - if "f" in arr.dims: - freq_coords = arr.coords["f"].data.astype(dtype, copy=False) - # ensure frequency dimension is last - if arr.dims != ("x", "y", "z", "f"): - freq_dim_idx = arr.dims.index("f") - axes = list(range(data.ndim)) - axes.append(axes.pop(freq_dim_idx)) - data = np.transpose(data, axes) - else: - # single frequency case - add singleton dimension - freq_coords = np.array([0.0], dtype=dtype) - data = data[..., np.newaxis] - - points_with_freq = (*points, freq_coords) - # If PEC, use nearest interpolation instead of linear to avoid interpolating - # with field values inside the PEC (which are 0). Instead, we make sure to - # choose interpolation points such that their nearest location is outside of - # the PEC surface. The same applies if the background_medium is marked as PEC - # since we will need to use the same interpolation strategy inside the structure - # border. 
- method = ( - "nearest" - if (self.is_medium_pec or self.background_medium_is_pec) - else "linear" - ) - if override_method is not None: - method = override_method - interpolator_obj = RegularGridInterpolator( - points_with_freq, data, method=method, bounds_error=False, fill_value=None - ) - - def interpolator(coords: ArrayFloat) -> ArrayComplex: - # coords: (N, 3) spatial points - n_points = coords.shape[0] - n_freqs = len(freq_coords) - - # build coordinates with frequency dimension - coords_with_freq = np.empty((n_points * n_freqs, 4), dtype=coords.dtype) - coords_with_freq[:, :3] = np.repeat(coords, n_freqs, axis=0) - coords_with_freq[:, 3] = np.tile(freq_coords, n_points) - - result = interpolator_obj(coords_with_freq) - return result.reshape(n_points, n_freqs) - - return interpolator - - if is_field_group: - interpolators[group_key][component_name] = LazyInterpolator(creator_func) - else: - interpolators[component_name] = LazyInterpolator(creator_func) - - # process field interpolators (nested dictionaries) - interpolator_groups = [ - ("E_fwd", self.E_fwd), - ("E_adj", self.E_adj), - ("D_fwd", self.D_fwd), - ("D_adj", self.D_adj), - ] - if self.is_medium_pec or self.background_medium_is_pec: - interpolator_groups += [("H_fwd", self.H_fwd), ("H_adj", self.H_adj)] # type: ignore[list-item] - for group_key, data_dict in interpolator_groups: - _make_lazy_interpolator_group( - data_dict, f"{group_key}_linear", is_field_group=True, override_method="linear" - ) - _make_lazy_interpolator_group( - data_dict, f"{group_key}_nearest", is_field_group=True, override_method="nearest" - ) - - if self.eps_data is not None: - _make_lazy_interpolator_group( - self.eps_data, "eps_data", is_field_group=True, override_method="nearest" - ) - - if self.eps_in is not None: - _make_lazy_interpolator_group( - {"eps_in": self.eps_in}, None, is_field_group=False, override_method="nearest" - ) - if self.eps_out is not None: - _make_lazy_interpolator_group( - {"eps_out": self.eps_out}, 
None, is_field_group=False, override_method="nearest" - ) - - self._interpolators_cache[cache_key] = interpolators - return interpolators - - def evaluate_gradient_at_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: Optional[dict] = None, - ) -> np.ndarray: - """Compute adjoint gradients at surface points for shape optimization. - - Implements the surface integral formulation for computing gradients with respect - to geometry perturbations. - - Parameters - ---------- - spatial_coords : np.ndarray - (N, 3) array of surface evaluation points. - normals : np.ndarray - (N, 3) array of outward-pointing normal vectors at each surface point. - perps1 : np.ndarray - (N, 3) array of first tangent vectors perpendicular to normals. - perps2 : np.ndarray - (N, 3) array of second tangent vectors perpendicular to both normals and perps1. - interpolators : dict = None - Pre-computed field interpolators for efficiency. - - Returns - ------- - np.ndarray - (N,) array of gradient values at each surface point. Must be integrated - with appropriate quadrature weights to get total gradient. - """ - if interpolators is None: - raise NotImplementedError( - "Direct field evaluation without interpolators is not implemented. " - "Please create interpolators using 'create_interpolators()' first." - ) - - # In all paths below, we need to have computed the gradient integration for a - # dielectric-dielectric interface. - vjps_dielectric = self._evaluate_dielectric_gradient_at_points( - spatial_coords, - normals, - perps1, - perps2, - interpolators, - self.eps_in, - self.eps_out, - ) - - if self.is_medium_pec: - # The structure medium is PEC, but there may be a part of the interface that has - # dielectric placed on top of or around it where we want to use the dielectric - # gradient integration. We use the mask to choose between the PEC-dielectric and - # dielectric-dielectric parts of the border. 
- - # Detect PEC by looking just inside the boundary - mask_pec = self._detect_pec_gradient_points( - spatial_coords, - normals, - self.eps_in, - interpolators["eps_data"], - is_outside=False, - ) - - # Compute PEC gradients, pulling fields outside of the boundary - vjps_pec = self._evaluate_pec_gradient_at_points( - spatial_coords, - normals, - perps1, - perps2, - interpolators, - ("eps_out", self.eps_out), - is_outside=True, - ) - - vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric - elif self.background_medium_is_pec: - # The structure medium is dielectric, but there may be a part of the interface that has - # PEC placed on top of or around it where we want to use the PEC gradient integration. - # We use the mask to choose between the dielectric-dielectric and PEC-dielectric parts - # of the border. - - # Detect PEC by looking just outside the boundary - mask_pec = self._detect_pec_gradient_points( - spatial_coords, - normals, - self.eps_out, - interpolators["eps_data"], - is_outside=True, - ) - - # Compute PEC gradients, pulling fields inside of the boundary and applying a negative - # sign compared to above because inside and outside definitions are switched - vjps_pec = -self._evaluate_pec_gradient_at_points( - spatial_coords, - normals, - perps1, - perps2, - interpolators, - ("eps_in", self.eps_in), - is_outside=False, - ) - - vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric - else: - # The structure and its background are both assumed to be dielectric, so we use the - # dielectric-dielectric gradient integration. 
- vjps = vjps_dielectric - - # sum over frequency dimension - vjps = np.sum(vjps, axis=-1) - - return vjps - - def _evaluate_dielectric_gradient_at_points( - self, - spatial_coords: ArrayFloat, - normals: ArrayFloat, - perps1: ArrayFloat, - perps2: ArrayFloat, - interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]], - eps_in_data: ScalarFieldDataArray, - eps_out_data: ScalarFieldDataArray, - ) -> ArrayComplex: - eps_out_coords = self._snap_spatial_coords_boundary( - spatial_coords, - normals, - is_outside=True, - data_array=eps_out_data, - ) - eps_in_coords = self._snap_spatial_coords_boundary( - spatial_coords, - normals, - is_outside=False, - data_array=eps_in_data, - ) - - eps_out = interpolators["eps_out"](eps_out_coords) - eps_in = interpolators["eps_in"](eps_in_coords) - - # evaluate all field components at surface points - E_fwd_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["E_fwd_linear"].items() - } - E_adj_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["E_adj_linear"].items() - } - D_fwd_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["D_fwd_linear"].items() - } - D_adj_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["D_adj_linear"].items() - } - - delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out - delta_eps = eps_in - eps_out - - # project fields onto local surface basis (normal + two tangents) - D_fwd_norm = self._project_in_basis(D_fwd_at_coords, basis_vector=normals) - D_adj_norm = self._project_in_basis(D_adj_at_coords, basis_vector=normals) - - E_fwd_perp1 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps1) - E_adj_perp1 = self._project_in_basis(E_adj_at_coords, basis_vector=perps1) - - E_fwd_perp2 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps2) - E_adj_perp2 = self._project_in_basis(E_adj_at_coords, basis_vector=perps2) - - D_der_norm = D_fwd_norm * D_adj_norm - E_der_perp1 = 
E_fwd_perp1 * E_adj_perp1 - E_der_perp2 = E_fwd_perp2 * E_adj_perp2 - - vjps = -delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps - - return vjps - - def _snap_spatial_coords_boundary( - self, - spatial_coords: ArrayFloat, - normals: ArrayFloat, - is_outside: bool, - data_array: ScalarFieldDataArray, - ) -> np.ndarray: - """Assuming a nearest interpolation, adjust the interpolation points given the grid - defined by `grid_centers` and using `spatial_coords` as a starting point such that we - select a point inside/outside the boundary depending on is_outside. - - *** (nearest point outside boundary) - ^ - | n (normal direction) - | - _.-~'`-._.-~'`-._ (boundary) - * (nearest point) - - Parameters - ---------- - spatial_coords : np.ndarray - (N, 3) array of surface evaluation points. - normals : np.ndarray - (N, 3) array of outward-pointing normal vectors at each surface point. - is_outside: bool - Indicator specifying if coordinates should be snapped inside or outside the boundary. - data_array: ScalarFieldDataArray - Data array to pull grid centers from when snapping coordinates. 
- - Returns - ------- - np.ndarray - (N, 3) array of coordinate centers at which to interpolate such that they line up - with a grid center and are inside/outside the boundary - """ - coords = data_array.coords - grid_centers = {key: np.array(coords[key].values) for key in coords} - - grid_ddim = np.zeros_like(normals) - for idx, dim in enumerate("xyz"): - expanded_coords = np.expand_dims(spatial_coords[:, idx], axis=1) - grid_centers_select = grid_centers[dim] - - diff = np.abs(expanded_coords - grid_centers_select) - - nearest_grid = np.argmin(diff, axis=-1) - nearest_grid = np.minimum(np.maximum(nearest_grid, 1), len(grid_centers_select) - 1) - - # compute the local grid spacing near the boundary - grid_ddim[:, idx] = ( - grid_centers_select[nearest_grid] - grid_centers_select[nearest_grid - 1] - ) - - # - # Assuming we move in the normal direction, finds which dimension we need to move the least - # in order to ensure we snap to a point outside the boundary in the worst case (i.e. - the - # nearest point is just inside the surface) - # - # Cover for 2D cases using filter below: - # 2D case 1: - # - in plane gradients where normal: [a, b, 0] and grid: [dx, dy, 0] - # - want to rely on in plane normals for boundary snapping (filter on normal component = 0) - # 2D case 2: - # - out of plane gradietns where normal: [0, 0, 1] and grid: [dx, dy, 0] - # - want to rely on out of plane normal (so do not want to filter on grid component = 0) - # - data may not be captured out of plane, so no snapping will occur even with coords_dn = 0 - # - small_number = np.finfo(normals.dtype).eps - coords_dn = np.min( - np.where( - (np.abs(normals) > small_number), - np.abs(grid_ddim) / (np.abs(normals) + small_number), - np.inf, - ), - axis=1, - keepdims=True, - ) - - # adjust coordinates by half a grid point outside boundary such that nearest interpolation - # point snaps to outside the boundary - normal_direction = 1.0 if is_outside else -1.0 - adjust_spatial_coords = ( - 
spatial_coords - + normal_direction * normals * config.adjoint.boundary_snapping_fraction * coords_dn - ) - - return adjust_spatial_coords - - def _compute_edge_distance( - self, - spatial_coords: np.ndarray, - grid_centers: dict[str, np.ndarray], - adjust_spatial_coords: np.ndarray, - ) -> np.ndarray: - """Assuming nearest neighbor interpolation, computes the edge distance after interpolation when using the - adjust_spatial_coords computed from _snap_spatial_coords_boundary. - - Parameters - ---------- - spatial_coords : np.ndarray - (N, 3) array of surface evaluation points. - normals : np.ndarray - (N, 3) array of outward-pointing normal vectors at each surface point. - grid_centers: dict[str, np.ndarray] - The grid points for a given field component indexed by dimension. These grid points - are used to find the nearest snapping point and adjust the interpolation coordinates - to ensure we fall inside/outside of a boundary. - - Returns - ------- - np.ndarray - (N,) array of distances from the nearest interpolation points to the desired surface - edge points specified by `spatial_coords` - """ - - edge_distance_squared_sum = np.zeros_like(adjust_spatial_coords[:, 0]) - for idx, dim in enumerate("xyz"): - expanded_adjusted_coords = np.expand_dims(adjust_spatial_coords[:, idx], axis=1) - grid_centers_select = grid_centers[dim] - - # find nearest grid point from the adjusted coordinates - diff = np.abs(expanded_adjusted_coords - grid_centers_select) - nearest_grid = np.argmin(diff, axis=-1) - - # compute edge distance from the nearest interpolated point to the boundary edge - edge_distance_squared_sum += ( - np.abs(spatial_coords[:, idx] - grid_centers_select[nearest_grid]) ** 2 - ) - - # this edge distance is useful when correcting for edge singularities like those from a PEC - # material and is used when the PEC PolySlab structure has zero thickness, for example - edge_distance = np.sqrt(edge_distance_squared_sum) - - return edge_distance - - def 
_detect_pec_gradient_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - eps_data: ScalarFieldDataArray, - interpolator: LazyInterpolator, - is_outside: bool, - ) -> np.ndarray: - def _detect_pec(eps_mask: np.ndarray) -> np.ndarray: - return 1.0 * (eps_mask < config.adjoint.pec_detection_threshold) - - adjusted_coords = self._snap_spatial_coords_boundary( - spatial_coords=spatial_coords, - normals=normals, - is_outside=is_outside, - data_array=eps_data, - ) - - eps_adjusted_all = [ - component_interpolator(adjusted_coords) - for _, component_interpolator in interpolator.items() - ] - eps_detect_pec = reduce(np.minimum, eps_adjusted_all) - - return _detect_pec(eps_detect_pec) - - def _evaluate_pec_gradient_at_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: dict, - eps_dielectric: tuple[str, ScalarFieldDataArray], - is_outside: bool, - ) -> np.ndarray: - eps_dielectric_key, eps_dielectric_data = eps_dielectric - - def _snap_coordinate_outside( - field_components: FieldDataDict, - ) -> dict[str, dict[str, ArrayFloat]]: - """Helper function to perform coordinate adjustment and compute edge distance for each - component in `field_components`. - - Parameters - ---------- - field_components: FieldDataDict - The field components (i.e - Ex, Ey, Ez, Hx, Hy, Hz) that we would like to sample just - outside the PEC surface using nearest interpolation. - - Returns - ------- - dict[str, dict[str, np.ndarray]] - Dictionary mapping each field component name to a dictionary of adjusted coordinates - and edge distances for that component. 
- """ - adjustment = {} - for name in field_components: - field_component = field_components[name] - field_component_coords = field_component.coords - - grid_centers = { - key: np.array(field_component_coords[key].values) - for key in field_component_coords - } - - adjusted_coords = self._snap_spatial_coords_boundary( - spatial_coords, - normals, - is_outside=is_outside, - data_array=field_component, - ) - - edge_distance = self._compute_edge_distance( - spatial_coords=spatial_coords, - grid_centers=grid_centers, - adjust_spatial_coords=adjusted_coords, - ) - adjustment[name] = {"coords": adjusted_coords, "edge_distance": edge_distance} - - return adjustment - - def _interpolate_field_components( - interp_coords: dict[str, dict[str, ArrayFloat]], field_name: str - ) -> dict[str, ArrayComplex]: - return { - name: interp(interp_coords[name]["coords"]) - for name, interp in interpolators[field_name].items() - } - - # adjust coordinates for PEC to be outside structure bounds and get edge distance for singularity correction. 
- E_fwd_coords_adjusted = _snap_coordinate_outside(self.E_fwd) - E_adj_coords_adjusted = _snap_coordinate_outside(self.E_adj) - - H_fwd_coords_adjusted = _snap_coordinate_outside(self.H_fwd) - H_adj_coords_adjusted = _snap_coordinate_outside(self.H_adj) - - # using the adjusted coordinates, evaluate all field components at surface points - E_fwd_at_coords = _interpolate_field_components( - E_fwd_coords_adjusted, field_name="E_fwd_nearest" - ) - E_adj_at_coords = _interpolate_field_components( - E_adj_coords_adjusted, field_name="E_adj_nearest" - ) - H_fwd_at_coords = _interpolate_field_components( - H_fwd_coords_adjusted, field_name="H_fwd_nearest" - ) - H_adj_at_coords = _interpolate_field_components( - H_adj_coords_adjusted, field_name="H_adj_nearest" - ) - - eps_coords_adjusted = self._snap_spatial_coords_boundary( - spatial_coords, - normals, - is_outside=is_outside, - data_array=eps_dielectric_data, - ) - eps_dielectric = interpolators[eps_dielectric_key](eps_coords_adjusted) - - structure_sizes = np.array( - [self.bounds[1][idx] - self.bounds[0][idx] for idx in range(len(self.bounds[0]))] - ) - - is_flat_perp_dim1 = np.isclose(np.abs(np.sum(perps1[0] * structure_sizes)), 0.0) - is_flat_perp_dim2 = np.isclose(np.abs(np.sum(perps2[0] * structure_sizes)), 0.0) - flat_perp_dims = [is_flat_perp_dim1, is_flat_perp_dim2] - - # check if this integration is happening along an edge in which case we will eliminate - # on of the H field integration components and apply singularity correction - pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2 - - def _compute_singularity_correction( - adjustment_: dict[str, dict[str, ArrayFloat]], - ) -> ArrayFloat: - """ - Given the `adjustment_` which contains the distance from the PEC edge each field - component is nearest interpolated at, computes the singularity correction when - working with 2D PEC using the average edge_distance for each component. 
In the case - of 3D PEC gradients, no singularity correction is applied so an array of ones is returned. - - Parameters - ---------- - adjustment_: dict[str, dict[str, np.ndarray]] - Dictionary that maps field component name to a dictionary containing the coordinate - adjustment and the distance to the PEC edge for those coordinates. The edge distance - is used for 2D PEC singularity correction. - - Returns - ------- - np.ndarray - Returns the singularity correction which has shape (N,) where there are N points in - `spatial_coords` - """ - return ( - ( - 0.5 - * np.pi - * np.mean([adjustment_[name]["edge_distance"] for name in adjustment_], axis=0) - ) - if pec_line_integration - else np.ones_like(spatial_coords, shape=spatial_coords.shape[0]) - ) - - E_norm_singularity_correction = np.expand_dims( - _compute_singularity_correction(E_fwd_coords_adjusted), axis=1 - ) - H_perp_singularity_correction = np.expand_dims( - _compute_singularity_correction(H_fwd_coords_adjusted), axis=1 - ) - - E_fwd_norm = self._project_in_basis(E_fwd_at_coords, basis_vector=normals) - E_adj_norm = self._project_in_basis(E_adj_at_coords, basis_vector=normals) - - # compute the normal E contribution to the gradient (the tangential E contribution - # is 0 in the case of PEC since this field component is continuous and thus 0 at - # the boundary) - contrib_E = E_norm_singularity_correction * eps_dielectric * E_fwd_norm * E_adj_norm - vjps = contrib_E - - # compute the tangential H contribution to the gradient (the normal H contribution - # is 0 for PEC) - H_fwd_perp1 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps1) - H_adj_perp1 = self._project_in_basis(H_adj_at_coords, basis_vector=perps1) - - H_fwd_perp2 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps2) - H_adj_perp2 = self._project_in_basis(H_adj_at_coords, basis_vector=perps2) - - H_der_perp1 = H_perp_singularity_correction * H_fwd_perp1 * H_adj_perp1 - H_der_perp2 = H_perp_singularity_correction * H_fwd_perp2 
* H_adj_perp2 - - H_integration_components = (H_der_perp1, H_der_perp2) - if pec_line_integration: - # if we are integrating along the line, we choose the H component normal to - # the edge which corresponds to a surface current along the edge whereas the other - # tangential component corresponds to a surface current along the flat dimension. - H_integration_components = tuple( - H_comp for idx, H_comp in enumerate(H_integration_components) if flat_perp_dims[idx] - ) - - # for each of the tangential components we are integrating the H fields over, - # adjust weighting to account for pre-weighting of the source by `EPSILON_0` - # and multiply by appropriate `MU_0` factor - for H_perp in H_integration_components: - contrib_H = MU_0 * H_perp / EPSILON_0 - vjps += contrib_H - - return vjps - - @staticmethod - def _project_in_basis( - field_components: dict[str, np.ndarray], - basis_vector: np.ndarray, - ) -> np.ndarray: - """Project 3D field components onto a basis vector. - - Parameters - ---------- - field_components : dict[str, np.ndarray] - Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values. - Values have shape (N, F) where F is the number of frequencies. - basis_vector : np.ndarray - (N, 3) array of basis vectors, one per evaluation point. - - Returns - ------- - np.ndarray - Projected field values with shape (N, F). - """ - prefix = next(iter(field_components.keys()))[0] - field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=0) - - # always expect (3, N, F) shape, transpose to (N, 3, F) - field_matrix = np.transpose(field_matrix, (1, 0, 2)) - return np.einsum("ij...,ij->i...", field_matrix, basis_vector) - - def project_der_map_to_axis( - self, axis: xyz, field_type: str = "E" - ) -> dict[str, ScalarFieldDataArray] | None: - """Return a copy of the selected derivative map with only one axis kept. - - Parameters - ---------- - axis: - Axis to keep (``"x"``, ``"y"``, ``"z"``, case-insensitive). 
- field_type: - Map selector: ``"E"`` (``self.E_der_map``) or ``"D"`` (``self.D_der_map``). - - Returns - ------- - dict[str, ScalarFieldDataArray] | None - Copied map where non-selected components are replaced by zeros, or ``None`` - if the requested map is unavailable. - """ - field_map = {"E": self.E_der_map, "D": self.D_der_map}.get(field_type) - if field_map is None: - raise ValueError("field type must be 'D' or 'E'.") - - axis = axis.lower() - projected = dict(field_map) - if not field_map: - return projected - for dim in "xyz": - key = f"E{dim}" - if key not in field_map: - continue - if dim != axis: - projected[key] = _zeros_like(field_map[key]) - else: - projected[key] = field_map[key] - return projected - - @property - def min_spacing_from_permittivity(self) -> float: - if self.cached_min_spacing_from_permittivity is not None: - return self.cached_min_spacing_from_permittivity - - def spacing_by_permittivity(eps_array: ScalarFieldDataArray) -> float: - eps_real = np.asarray(eps_array.values, dtype=np.complex128).real - - dx_candidates = [] - max_frequency = np.max(self.frequencies) - - # wavelength-based sampling for dielectrics - if np.any(eps_real > 0): - eps_max = eps_real[eps_real > 0].max() - lambda_min = self.wavelength_min / np.sqrt(eps_max) - dx_candidates.append(lambda_min) - - # skin depth sampling for metals - if np.any(eps_real <= 0): - omega = 2 * np.pi * max_frequency - eps_neg = eps_real[eps_real <= 0] - delta_min = C_0 / (omega * np.sqrt(np.abs(eps_neg).max())) - dx_candidates.append(delta_min) - - computed_spacing = min(dx_candidates) - - return computed_spacing - - eps_spacings = [ - spacing_by_permittivity(eps_array) for _, eps_array in self.eps_data.items() - ] - min_spacing = np.min(eps_spacings) - - return min_spacing - - @contextmanager - def cache_min_spacing_from_permittivity(self) -> Iterator[None]: - """ - Cache min_spacing_from_permittivity for the duration of the block. Cache - is always cleared on exit. 
- """ - - self.cached_min_spacing_from_permittivity = self.min_spacing_from_permittivity - try: - yield - finally: - self.cached_min_spacing_from_permittivity = None - - def adaptive_vjp_spacing( - self, - wl_fraction: Optional[float] = None, - min_allowed_spacing_fraction: Optional[float] = None, - ) -> float: - """Compute adaptive spacing for finite-difference gradient evaluation. - - Determines an appropriate spatial resolution based on the material - properties and electromagnetic wavelength/skin depth. - - Parameters - ---------- - wl_fraction : float, optional - Fraction of wavelength/skin depth to use as spacing. Defaults to the configured - ``autograd.default_wavelength_fraction`` when ``None``. - min_allowed_spacing_fraction : float, optional - Minimum allowed spacing fraction of free space wavelength used to - prevent numerical issues. Defaults to ``config.adjoint.minimum_spacing_fraction`` - when not specified. - - Returns - ------- - float - Adaptive spacing value for gradient evaluation. - """ - if wl_fraction is None or min_allowed_spacing_fraction is None: - from tidy3d.config import config - - if wl_fraction is None: - wl_fraction = config.adjoint.default_wavelength_fraction - if min_allowed_spacing_fraction is None: - min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction - - computed_spacing = wl_fraction * self.min_spacing_from_permittivity - - min_allowed_spacing = self.wavelength_min * min_allowed_spacing_fraction - - if computed_spacing < min_allowed_spacing: - log.warning( - f"Based on the material, the adaptive spacing for integrating the polyslab surface " - f"would be {computed_spacing:.3e} μm. 
The spacing has been clipped to {min_allowed_spacing:.3e} μm " - f"to prevent a performance degradation.", - log_once=True, - ) - - return max(computed_spacing, min_allowed_spacing) - - @property - def wavelength_min(self) -> float: - return C_0 / np.max(self.frequencies) - - @property - def wavelength_max(self) -> float: - return C_0 / np.min(self.frequencies) - - -def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: - """Integrate a data array within specified spatial bounds. - - Clips the integration domain to the specified bounds and performs - numerical integration using the trapezoidal rule. - - Parameters - ---------- - arr : xr.DataArray - Data array to integrate. - dims : list[str] - Dimensions to integrate over (e.g., ['x', 'y', 'z']). - bounds : Bound - Integration bounds as [[xmin, ymin, zmin], [xmax, ymax, zmax]]. - - Returns - ------- - xr.DataArray - Result of integration with specified dimensions removed. - - Notes - ----- - - Coordinates outside bounds are clipped, effectively setting dL=0 - - Only integrates dimensions with more than one coordinate point - - Uses xarray's integrate method (trapezoidal rule) - """ - bounds = np.asarray(bounds).T - all_coords = {} - - for dim, (bmin, bmax) in zip(dims, bounds): - bmin = get_static(bmin) - bmax = get_static(bmax) - - # clip coordinates to bounds (sets dL=0 outside bounds) - coord_values = arr.coords[dim].data - all_coords[dim] = np.clip(coord_values, bmin, bmax) - - _arr = arr.assign_coords(**all_coords) - - # only integrate dimensions with multiple points - dims_integrate = [dim for dim in dims if len(_arr.coords[dim]) > 1] - return _arr.integrate(coord=dims_integrate) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "DerivativeInfo", - "integrate_within_bounds", -] +from tidy3d._common.components.autograd.derivative_utils import ( + ArrayComplex, + 
ArrayFloat, + DerivativeInfo, + EpsType, + FieldDataDict, + LazyInterpolator, + PermittivityData, + integrate_within_bounds, +) diff --git a/tidy3d/components/autograd/field_map.py b/tidy3d/components/autograd/field_map.py index f5a352df02..6690a71649 100644 --- a/tidy3d/components/autograd/field_map.py +++ b/tidy3d/components/autograd/field_map.py @@ -1,77 +1,13 @@ -"""Typed containers for autograd traced field metadata.""" +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.field_map`.""" -from __future__ import annotations - -import json -from typing import TYPE_CHECKING, Any, Union - -from pydantic import Field - -from tidy3d.components.autograd.types import TracedArrayLike, TracedComplex, TracedFloat -from tidy3d.components.base import Tidy3dBaseModel - -if TYPE_CHECKING: - from typing import Callable - - from tidy3d.components.autograd.types import AutogradFieldMap - - -class Tracer(Tidy3dBaseModel): - """Representation of a single traced element within a model.""" - - path: tuple[Any, ...] = Field( - title="Path to the traced object in the model dictionary.", - ) - data: Union[TracedFloat, TracedComplex, TracedArrayLike] = Field(title="Tracing data") - - -class FieldMap(Tidy3dBaseModel): - """Collection of traced elements.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - tracers: tuple[Tracer, ...] 
= Field( - title="Collection of Tracers.", - ) - - @property - def to_autograd_field_map(self) -> AutogradFieldMap: - """Convert to ``AutogradFieldMap`` autograd dictionary.""" - return {tracer.path: tracer.data for tracer in self.tracers} - - @classmethod - def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap: - """Initialize from an ``AutogradFieldMap`` autograd dictionary.""" - tracers = [] - for path, data in autograd_field_map.items(): - tracers.append(Tracer(path=path, data=data)) - return cls(tracers=tuple(tracers)) - - -def _encoded_path(path: tuple[Any, ...]) -> str: - """Return a stable JSON representation for a traced path.""" - return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True) - - -class TracerKeys(Tidy3dBaseModel): - """Collection of traced field paths.""" - - keys: tuple[tuple[Any, ...], ...] = Field( - title="Collection of tracer keys.", - ) - - def encoded_keys(self) -> list[str]: - """Return the JSON-encoded representation of keys.""" - return [_encoded_path(path) for path in self.keys] - - @classmethod - def from_field_mapping( - cls, - field_mapping: AutogradFieldMap, - *, - sort_key: Callable[[tuple[Any, ...]], str] | None = None, - ) -> TracerKeys: - """Construct keys from an autograd field mapping.""" - if sort_key is None: - sort_key = _encoded_path +# marked as migrated to _common +from __future__ import annotations - sorted_paths = tuple(sorted(field_mapping.keys(), key=sort_key)) - return cls(keys=sorted_paths) +from tidy3d._common.components.autograd.field_map import ( + FieldMap, + Tracer, + TracerKeys, + _encoded_path, +) diff --git a/tidy3d/components/autograd/functions.py b/tidy3d/components/autograd/functions.py index 708ebc14a5..3e23c18503 100644 --- a/tidy3d/components/autograd/functions.py +++ b/tidy3d/components/autograd/functions.py @@ -1,289 +1,16 @@ -from __future__ import annotations - -import itertools -from typing import TYPE_CHECKING, Any - -import autograd.numpy as anp 
-import numpy as np -from autograd.extend import defjvp, defvjp, primitive -from autograd.numpy.numpy_jvps import broadcast -from autograd.numpy.numpy_vjps import unbroadcast_f - -if TYPE_CHECKING: - from numpy.typing import NDArray - - from .types import InterpolationType - - -def _evaluate_nearest( - indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] -) -> NDArray[np.float64]: - """Perform nearest neighbor interpolation in an n-dimensional space. - - This function determines the nearest neighbor in a grid for a given point - and returns the corresponding value from the input array. - - Parameters - ---------- - indices : np.ndarray[np.int64] - Indices of the lower bounds of the grid cell containing the interpolation point. - norm_distances : np.ndarray[np.float64] - Normalized distances from the lower bounds of the grid cell to the - interpolation point, for each dimension. - values : np.ndarray[np.float64] - The n-dimensional array of values to interpolate from. - - Returns - ------- - np.ndarray[np.float64] - The value of the nearest neighbor to the interpolation point. - """ - idx_res = tuple(anp.where(yi <= 0.5, i, i + 1) for i, yi in zip(indices, norm_distances)) - return values[idx_res] - - -def _evaluate_linear( - indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] -) -> NDArray[np.float64]: - """Perform linear interpolation in an n-dimensional space. - - This function calculates the linearly interpolated value at a point in an - n-dimensional grid, given the indices of the surrounding grid points and - the normalized distances to these points. - The multi-linear interpolation is implemented by computing a weighted - average of the values at the vertices of the hypercube surrounding the - interpolation point. - - Parameters - ---------- - indices : np.ndarray[np.int64] - Indices of the lower bounds of the grid cell containing the interpolation point. 
- norm_distances : np.ndarray[np.float64] - Normalized distances from the lower bounds of the grid cell to the - interpolation point, for each dimension. - values : np.ndarray[np.float64] - The n-dimensional array of values to interpolate from. - - Returns - ------- - np.ndarray[np.float64] - The interpolated value at the desired point. - """ - # Create a slice object for broadcasting over trailing dimensions - _slice = (slice(None),) + (None,) * (values.ndim - len(indices)) - - # Prepare iterables for lower and upper bounds of the hypercube - ix = zip(indices, (1 - yi for yi in norm_distances)) - iy = zip((i + 1 for i in indices), norm_distances) - - # Initialize the result - value = anp.zeros(1) - - # Iterate over all vertices of the hypercube - for h in itertools.product(*zip(ix, iy)): - edge_indices, weights = zip(*h) - - # Compute the weight for this vertex - weight = anp.ones(1) - for w in weights: - weight = weight * w - - # Compute the contribution of this vertex and add it to the result - term = values[edge_indices] * weight[_slice] - value = value + term - - return value - - -def interpn( - points: tuple[NDArray[np.float64], ...], - values: NDArray[np.float64], - xi: tuple[NDArray[np.float64], ...], - *, - method: InterpolationType = "linear", - **kwargs: Any, -) -> NDArray[np.float64]: - """Interpolate over a rectilinear grid in arbitrary dimensions. - - This function mirrors the interface of `scipy.interpolate.interpn` but is differentiable with autograd. - - Parameters - ---------- - points : tuple[np.ndarray[np.float64], ...] - The points defining the rectilinear grid in n dimensions. - values : np.ndarray[np.float64] - The data values on the rectilinear grid. - xi : tuple[np.ndarray[np.float64], ...] - The coordinates to sample the gridded data at. - method : InterpolationType = "linear" - The method of interpolation to perform. Supported are "linear" and "nearest". - - Returns - ------- - np.ndarray[np.float64] - The interpolated values. 
- - Raises - ------ - ValueError - If the interpolation method is not supported. - - See Also - -------- - `scipy.interpolate.interpn `_ - """ - from scipy.interpolate import RegularGridInterpolator - - if method == "nearest": - interp_fn = _evaluate_nearest - elif method == "linear": - interp_fn = _evaluate_linear - else: - raise ValueError(f"Unsupported interpolation method: {method}") +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.functions`.""" - # Avoid SciPy coercing autograd ArrayBox values during _check_values. - dummy_values = np.zeros(np.shape(values), dtype=float) - if kwargs.get("fill_value") == "extrapolate": - itrp = RegularGridInterpolator( - points, dummy_values, method=method, fill_value=None, bounds_error=False - ) - else: - itrp = RegularGridInterpolator(points, dummy_values, method=method) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # Prepare the grid for interpolation - # This step reshapes the grid, checks for NaNs and out-of-bounds values - # It returns: - # - reshaped grid - # - original shape - # - number of dimensions - # - boolean array indicating NaN positions - # - (discarded) boolean array for out-of-bounds values - xi, shape, ndim, nans, _ = itrp._prepare_xi(xi) - - # Find the indices of the grid cells containing the interpolation points - # and calculate the normalized distances (ranging from 0 at lower grid point to 1 - # at upper grid point) within these cells - indices, norm_distances = itrp._find_indices(xi.T) - - result = interp_fn(indices, norm_distances, values) - nans = anp.reshape(nans, (-1,) + (1,) * (result.ndim - 1)) - result = anp.where(nans, np.nan, result) - return anp.reshape(result, shape[:-1] + values.shape[ndim:]) - - -def trapz(y: NDArray, x: NDArray = None, dx: float = 1.0, axis: int = -1) -> float: - """ - Integrate along the given axis using the composite trapezoidal rule. - - Parameters - ---------- - y : np.ndarray - Input array to integrate. 
- x : np.ndarray = None - The sample points corresponding to the y values. If None, the sample points are assumed to be evenly spaced - with spacing `dx`. - dx : float = 1.0 - The spacing between sample points when `x` is None. Default is 1.0. - axis : int = -1 - The axis along which to integrate. Default is the last axis. - - Returns - ------- - float - Definite integral as approximated by the trapezoidal rule. - """ - if x is None: - d = dx - elif x.ndim == 1: - d = np.diff(x) - shape = [1] * y.ndim - shape[axis] = d.shape[0] - d = np.reshape(d, shape) - else: - d = np.diff(x, axis=axis) - - slice1 = [slice(None)] * y.ndim - slice2 = [slice(None)] * y.ndim - slice1[axis] = slice(1, None) - slice2[axis] = slice(None, -1) - - return anp.sum((y[tuple(slice1)] + y[tuple(slice2)]) * d / 2, axis=axis) - - -@primitive -def _add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: - """ - Add values to specified indices of an array. - - Autograd requires that arguments to primitives are passed in positionally. - ``add_at`` is the public-facing wrapper for this function, - which allows keyword arguments in case users pass in kwargs. - """ - out = np.copy(x) # Copy to preserve 'x' for gradient computation - out[tuple(indices_x)] += y - return out - - -defvjp( - _add_at, - lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g), - lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)], - argnums=(0, 2), -) +# marked as migrated to _common +from __future__ import annotations -defjvp( +from tidy3d._common.components.autograd.functions import ( _add_at, - lambda g, ans, x, indices_x, y: broadcast(g, ans), - lambda g, ans, x, indices_x, y: _add_at(anp.zeros_like(ans), indices_x, g), - argnums=(0, 2), + _evaluate_linear, + _evaluate_nearest, + _straight_through_clip, + add_at, + interpn, + trapz, ) - - -def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: - """ - Add values to specified indices of an array. 
- - This function creates a copy of the input array `x`, adds the values from `y` to the specified - indices `indices_x`, and returns the modified array. - - Parameters - ---------- - x : np.ndarray - Input array to which values will be added. - indices_x : tuple - Indices of `x` where values from `y` will be added. - y : np.ndarray - Values to add to the specified indices of `x`. - - Returns - ------- - np.ndarray - The modified array with values added at the specified indices. - """ - return _add_at(x, indices_x, y) - - -@primitive -def _straight_through_clip(x: NDArray, a_min: Any, a_max: Any) -> NDArray: - """Passthrough clip can be used to preserve gradients at the endpoints of the clip range where - there is a discontinuity in the derivative. This is useful when values are at the endpoints but may - have a gradient away from the boundary or in cases where numerical precision causes a function that is - typically bounded by the clip bounds to produce a value just outside the bounds. 
In the forward pass, - this runs the standard clip.""" - return anp.clip(x, a_min=a_min, a_max=a_max) - - -def _straight_through_clip_vjp(ans: Any, x: NDArray, a_min: Any, a_max: Any) -> NDArray: - """Preserve original gradient information in the backward pass up until a tolerance beyond the clip bounds.""" - tolerance = 1e-5 - mask = (x >= a_min - tolerance) & (x <= a_max + tolerance) - return lambda g: g * mask - - -defvjp(_straight_through_clip, _straight_through_clip_vjp) - -__all__ = [ - "add_at", - "interpn", - "trapz", -] diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py index 93c445a5de..dcbfd4a274 100644 --- a/tidy3d/components/autograd/types.py +++ b/tidy3d/components/autograd/types.py @@ -1,132 +1,26 @@ -# type information for autograd +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.types`.""" -from __future__ import annotations - -import copy -from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, get_origin - -import autograd.numpy as anp -from autograd.builtins import dict as TracedDict -from autograd.extend import Box, defvjp, primitive -from autograd.numpy.numpy_boxes import ArrayBox -from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter - -from tidy3d.components.types import ArrayFloat2D, ArrayLike, Complex, Size1D -from tidy3d.components.types.base import _auto_serializer -from tidy3d.components.types.utils import _add_schema - -from .utils import get_static, hasbox - -if TYPE_CHECKING: - from typing import Optional - - from pydantic import SerializationInfo - - from tidy3d.compat import TypeAlias - -# add schema to the Box -_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") -_add_schema(ArrayBox, title="AutogradArrayBox", field_type_str="autograd.numpy.ArrayBox") - -# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph -_copy = primitive(copy.copy) -_deepcopy = primitive(copy.deepcopy) - 
-defvjp(_copy, lambda ans, x: lambda g: _copy(g)) -defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) - -Box.__copy__ = lambda v: _copy(v) -Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) -Box.__str__ = lambda self: f"{self._value} <{type(self).__name__}>" -Box.__repr__ = Box.__str__ - - -def traced_alias(base_alias: Any, *, name: Optional[str] = None) -> TypeAlias: - base_adapter = TypeAdapter(base_alias, config={"arbitrary_types_allowed": True}) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - def _validate_box_or_container(v: Any) -> Any: - # case 1: v itself is a tracer - # in this case we just validate but leave the tracer untouched - if isinstance(v, Box): - base_adapter.validate_python(get_static(v)) - return v - - # case 2: v is a plain container that contains at least one tracer - # in this case we try to coerce into ArrayBox for one-shot validation, - # but always return the original v, and fall back to a structural walk if needed - if hasbox(v): - # decide whether we must return an array - origin = get_origin(base_alias) - is_array_field = base_alias in (ArrayLike, ArrayFloat2D) or origin is None - - if is_array_field: - dense = anp.array(v) - base_adapter.validate_python(get_static(dense)) - return dense - - # otherwise it's a Python container type - # try the fast-path array validation, but return the array so ops work - try: - dense = anp.array(v) - base_adapter.validate_python(get_static(dense)) - return dense - - except Exception: - # ragged/un-coercible -> rebuild container of Boxes - if isinstance(v, tuple): - return tuple(_validate_box_or_container(x) for x in v) - if isinstance(v, list): - return [_validate_box_or_container(x) for x in v] - if isinstance(v, dict): - return {k: _validate_box_or_container(x) for k, x in v.items()} - # fallback: can't handle this structure - raise - - return base_adapter.validate_python(v) - - def _serialize_traced(a: Any, info: SerializationInfo) -> Any: - return 
_auto_serializer(get_static(a), info) - - return Annotated[ - object, - BeforeValidator(_validate_box_or_container), - PlainSerializer(_serialize_traced, when_used="json"), - ] - - -# "primitive" types that can use traced_alias -TracedArrayLike = traced_alias(ArrayLike) -TracedArrayFloat2D = traced_alias(ArrayFloat2D) -TracedFloat = traced_alias(float) -TracedPositiveFloat = traced_alias(PositiveFloat) -TracedComplex = traced_alias(Complex) -TracedSize1D = traced_alias(Size1D) - -# derived traced types (these mirror the types in `components.types`) -TracedSize = tuple[TracedSize1D, TracedSize1D, TracedSize1D] -TracedCoordinate = tuple[TracedFloat, TracedFloat, TracedFloat] -TracedPoleAndResidue = tuple[TracedComplex, TracedComplex] -TracedPolesAndResidues = tuple[TracedPoleAndResidue, ...] - -# The data type that we pass in and out of the web.run() @autograd.primitive -PathType = tuple[Union[int, str], ...] -AutogradFieldMap = TracedDict[PathType, Box] - -InterpolationType = Literal["nearest", "linear"] +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "AutogradFieldMap", - "InterpolationType", - "PathType", - "TracedArrayFloat2D", - "TracedArrayLike", - "TracedComplex", - "TracedCoordinate", - "TracedDict", - "TracedFloat", - "TracedPoleAndResidue", - "TracedPolesAndResidues", - "TracedPositiveFloat", - "TracedSize", - "TracedSize1D", -] +from tidy3d._common.components.autograd.types import ( + AutogradFieldMap, + InterpolationType, + PathType, + TracedArrayFloat2D, + TracedArrayLike, + TracedComplex, + TracedCoordinate, + TracedDict, + TracedFloat, + TracedPoleAndResidue, + TracedPolesAndResidues, + TracedPositiveFloat, + TracedSize, + TracedSize1D, + _copy, + _deepcopy, + traced_alias, +) diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py index 76c13b583f..f8d9f4304e 100644 --- a/tidy3d/components/autograd/utils.py +++ b/tidy3d/components/autograd/utils.py @@ -1,84 +1,16 @@ -# utilities for 
working with autograd -from __future__ import annotations - -from collections.abc import Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import autograd.numpy as anp -from autograd.tracer import getval, isbox - -if TYPE_CHECKING: - from typing import Union - - from autograd.numpy.numpy_boxes import ArrayBox - from numpy.typing import ArrayLike, NDArray - -__all__ = [ - "asarray1d", - "contains", - "get_static", - "hasbox", - "is_tidy_box", - "pack_complex_vec", - "split_list", -] - - -def get_static(item: Any) -> Any: - """ - Get the 'static' (untraced) version of some value by recursively calling getval - on Box instances within a nested structure. - """ - if isbox(item): - return getval(item) - elif isinstance(item, list): - return [get_static(x) for x in item] - elif isinstance(item, tuple): - return tuple(get_static(x) for x in item) - elif isinstance(item, dict): - return {k: get_static(v) for k, v in item.items()} - return item - +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.utils`.""" -def split_list(x: list[Any], index: int) -> tuple[list, list]: - """Split a list at a given index.""" - x = list(x) - return x[:index], x[index:] - - -def is_tidy_box(x: Any) -> bool: - """Check if a value is a tidy box.""" - return getattr(x, "_tidy", False) - - -def contains(target: Any, seq: Iterable[Any]) -> bool: - """Return ``True`` if target occurs anywhere within arbitrarily nested iterables.""" - for x in seq: - if x == target: - return True - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - if contains(target, x): - return True - return False - - -def hasbox(obj: Any) -> bool: - """True if any element inside obj is an autograd Box.""" - if isbox(obj): - return True - if isinstance(obj, Mapping): - return any(hasbox(v) for v in obj.values()) - if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)): - return any(hasbox(i) for i in obj) - return False - - -def pack_complex_vec(z: Union[NDArray, 
ArrayBox]) -> Union[NDArray, ArrayBox]: - """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" - return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def asarray1d(x: Union[ArrayLike, ArrayBox]) -> Union[NDArray, ArrayBox]: - """Autograd-friendly 1D flatten: returns ndarray of shape (-1,).""" - x = anp.array(x) - return x if x.ndim == 1 else anp.ravel(x) +from tidy3d._common.components.autograd.utils import ( + asarray1d, + contains, + get_static, + hasbox, + is_tidy_box, + pack_complex_vec, + split_list, +) diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index 265f02acb8..cfae36803c 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -1,1896 +1,27 @@ -"""global configuration / base class for pydantic models used to make simulation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.base`.""" -from __future__ import annotations - -import hashlib -import io -import json -import math -import os -import tempfile -import typing as _t -from collections import defaultdict -from collections.abc import Mapping, Sequence -from functools import total_ordering, wraps -from math import ceil -from os import PathLike -from pathlib import Path -from types import UnionType -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, get_args, get_origin - -import h5py -import numpy as np -import rich -import xarray as xr -import yaml -from autograd.numpy.numpy_boxes import ArrayBox -from autograd.tracer import isbox -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator - -from tidy3d.exceptions import FileError -from tidy3d.log import log - -from .autograd.types import TracedDict -from .autograd.utils import get_static -from .data.data_array import DATA_ARRAY_MAP -from .file_util import 
compress_file_to_gzip, extract_gzip_file -from .types import TYPE_TAG_STR, Undefined - -if TYPE_CHECKING: - from collections.abc import Iterator - from typing import Callable - - from pydantic.fields import FieldInfo - from pydantic.functional_validators import ModelWrapValidatorHandler - - from tidy3d.compat import Self - - from .autograd.types import AutogradFieldMap - -INDENT_JSON_FILE = 4 # default indentation of json string in json files -INDENT = None # default indentation of json string used internally -JSON_TAG = "JSON_STRING" -# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5 -MAX_STRING_LENGTH = 1_000_000_000 -FORBID_SPECIAL_CHARACTERS = ["/"] -TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__" -TYPE_TO_CLASS_MAP: dict[str, type[Tidy3dBaseModel]] = {} - -_CacheReturn = TypeVar("_CacheReturn") - - -def cache(prop: Callable[[Any], _CacheReturn]) -> Callable[[Any], _CacheReturn]: - """Decorates a property to cache the first computed value and return it on subsequent calls.""" - - # note, we could also just use `prop` as dict key, but hashing property might be slow - prop_name = prop.__name__ - - @wraps(prop) - def cached_property_getter(self: Any) -> _CacheReturn: - """The new property method to be returned by decorator.""" - - stored_value = self._cached_properties.get(prop_name) - - if stored_value is not None: - return stored_value - - computed_value = prop(self) - self._cached_properties[prop_name] = computed_value - return computed_value - - return cached_property_getter - - -def cached_property(cached_property_getter: Callable[[Any], _CacheReturn]) -> property: - """Shortcut for property(cache()) of a getter.""" - - return property(cache(cached_property_getter)) - - -_GuardedReturn = TypeVar("_GuardedReturn") - - -def cached_property_guarded( - key_func: Callable[[Any], Any], -) -> Callable[[Callable[[Any], _GuardedReturn]], property]: - """Like cached_property, but invalidates when the key_func(self) 
changes.""" - - def _decorator(getter: Callable[[Any], _GuardedReturn]) -> property: - prop_name = getter.__name__ - - @wraps(getter) - def _guarded(self: Any) -> _GuardedReturn: - cache_store = self._cached_properties.get(prop_name) - current_key = key_func(self) - if cache_store is not None: - cached_key, cached_value = cache_store - if cached_key == current_key: - return cached_value - value = getter(self) - self._cached_properties[prop_name] = (current_key, value) - return value - - return property(_guarded) - - return _decorator - - -def make_json_compatible(json_string: str) -> str: - """Makes the string compatible with json standards, notably for infinity.""" - - tmp_string = "<>" - json_string = json_string.replace("-Infinity", tmp_string) - json_string = json_string.replace('""-Infinity""', tmp_string) - json_string = json_string.replace("Infinity", '"Infinity"') - json_string = json_string.replace('""Infinity""', '"Infinity"') - return json_string.replace(tmp_string, '"-Infinity"') - - -def _get_valid_extension(fname: PathLike) -> str: - """Return the file extension from fname, validated to accepted ones.""" - valid_extensions = [".json", ".yaml", ".hdf5", ".h5", ".hdf5.gz"] - path = Path(fname) - extensions = [s.lower() for s in path.suffixes[-2:]] - if len(extensions) == 0: - raise FileError(f"File '{path}' missing extension.") - single_extension = extensions[-1] - if single_extension in valid_extensions: - return single_extension - double_extension = "".join(extensions) - if double_extension in valid_extensions: - return double_extension - raise FileError( - f"File extension must be one of {', '.join(valid_extensions)}; file '{path}' does not " - "match any of those." 
- ) - - -def _fmt_ann_literal(ann: Any) -> str: - """Spell the annotation exactly as written.""" - if ann is None: - return "Any" - if isinstance(ann, _t._GenericAlias): - return str(ann).replace("typing.", "") - return ann.__name__ if hasattr(ann, "__name__") else str(ann) - - -T = TypeVar("T", bound="Tidy3dBaseModel") - - -def field_allows_scalar(field: FieldInfo) -> bool: - annotation = field.annotation - - def allows_scalar(a: Any) -> bool: - origin = get_origin(a) - if origin in (Union, UnionType): - args = (arg for arg in get_args(a) if arg is not type(None)) - return any(allows_scalar(arg) for arg in args) - if origin is not None: - return False - return isinstance(a, type) and issubclass(a, (float, int, np.generic)) - - return allows_scalar(annotation) - - -@total_ordering -class Tidy3dBaseModel(BaseModel): - """Base pydantic model that all Tidy3d components inherit from. - Defines configuration for handling data structures - as well as methods for importing, exporting, and hashing tidy3d objects. - For more details on pydantic base models, see: - `Pydantic models `_ - """ - - model_config = ConfigDict( - arbitrary_types_allowed=True, - defer_build=True, - validate_default=True, - populate_by_name=True, - ser_json_inf_nan="strings", - extra="forbid", - frozen=True, - ) - - attrs: dict = Field( - default_factory=dict, - title="Attributes", - description="Dictionary storing arbitrary metadata for a Tidy3D object. " - "This dictionary can be freely used by the user for storing data without affecting the " - "operation of Tidy3D as it is not used internally. " - "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " - "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " - "Also note that Tidy3D will raise a ``TypeError`` if ``attrs`` contain objects " - "that can not be serialized. 
One can check if ``attrs`` are serializable " - "by calling ``obj.model_dump_json()``.", - ) - - _cached_properties: dict = PrivateAttr(default_factory=dict) - _has_tracers: Optional[bool] = PrivateAttr(default=None) - - @field_validator("name", check_fields=False) - @classmethod - def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> Optional[str]: - if name is None: - return name - for character in FORBID_SPECIAL_CHARACTERS: - if character in name: - raise ValueError( - f"Special character '{character}' not allowed in component name {name}." - ) - return name - - def __init_subclass__(cls: type[T], **kwargs: Any) -> None: - """Injects a constant discriminator field before Pydantic builds the model. - - Adds - type: Literal[""] = "" - to every concrete subclass so it can participate in a - `Field(discriminator="type")` union without manual boilerplate. - - Must run *before* `super().__init_subclass__()`; that call lets Pydantic - see the injected field during its normal schema/validator generation. 
- See also: https://peps.python.org/pep-0487/ - """ - tag = cls.__name__ - cls.__annotations__[TYPE_TAG_STR] = Literal[tag] - setattr(cls, TYPE_TAG_STR, tag) - TYPE_TO_CLASS_MAP[tag] = cls - - if "__tidy3d_end_capture__" not in cls.__dict__: - - @model_validator(mode="after") - def __tidy3d_end_capture__(self: T) -> T: - if log._capture: - log.end_capture(self) - return self - - cls.__tidy3d_end_capture__ = __tidy3d_end_capture__ - - super().__init_subclass__(**kwargs) - - @classmethod - def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None: - super().__pydantic_init_subclass__(**kwargs) - - # add docstring once pydantic is done constructing the class - cls.__doc__ = cls.generate_docstring() - - @model_validator(mode="wrap") - @classmethod - def _capture_validation_warnings( - cls: type[T], - data: Any, - handler: ModelWrapValidatorHandler[T], - ) -> T: - if not log._capture: - return handler(data) - - log.begin_capture() - try: - return handler(data) - except Exception: - log.abort_capture() - raise - - def __hash__(self) -> int: - """Hash method.""" - return self._recursive_hash(self) - - @staticmethod - def _recursive_hash(value: Any) -> int: - # Handle Autograd ArrayBoxes - if isinstance(value, ArrayBox): - # Unwrap the underlying numpy array and recurse - return Tidy3dBaseModel._recursive_hash(value._value) - if isinstance(value, np.ndarray): - # numpy arrays are not hashable by default, use byte representation - v_hash = hashlib.md5(value.tobytes()).hexdigest() - return hash(v_hash) - if isinstance(value, (xr.DataArray, xr.Dataset)): - # we choose to not hash data arrays as this would require a lot of careful handling of units, metadata. 
- # technically this is incorrect, but should never lead to bugs in current implementation - return hash(str(value.__class__.__name__)) - if isinstance(value, str): - # this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below - return hash(value) - if isinstance(value, Sequence): - # this assumes all objects in lists are hashable by default and do not require special handling - v_hash = tuple([Tidy3dBaseModel._recursive_hash(vi) for vi in value]) - return hash(v_hash) - if isinstance(value, dict): - to_hash_list = [] - for k, v in value.items(): - v_hash = Tidy3dBaseModel._recursive_hash(v) - to_hash_list.append((k, v_hash)) - return hash(tuple(to_hash_list)) - if isinstance(value, Tidy3dBaseModel): - # This function needs to take special care because of mutable attributes inside of frozen pydantic models - to_hash_list = [] - for k in type(value).model_fields: - if k == "attrs": - continue - v_hash = Tidy3dBaseModel._recursive_hash(getattr(value, k)) - to_hash_list.append((k, v_hash)) - extra = getattr(value, "__pydantic_extra__", None) - if extra: - for k, v in extra.items(): - v_hash = Tidy3dBaseModel._recursive_hash(v) - to_hash_list.append((k, v_hash)) - # attrs is mutable, use serialized output as safe hashing option - if value.attrs: - attrs_str = value._attrs_digest() - attrs_hash = hash(attrs_str) - to_hash_list.append(("attrs", attrs_hash)) - return hash(tuple(to_hash_list)) - return hash(value) - - def _hash_self(self) -> str: - """Hash this component with ``hashlib`` in a way that is the same every session.""" - bf = io.BytesIO() - self.to_hdf5(bf) - return hashlib.md5(bf.getvalue()).hexdigest() - - @model_validator(mode="before") - @classmethod - def coerce_numpy_scalars_for_model(cls, data: Any) -> Any: - """ - coerce numpy scalars / size-1 arrays to native Python - scalars, but only for fields whose annotations allow scalars. 
- """ - if not isinstance(data, dict): - return data - - for name, field in cls.model_fields.items(): - if name not in data or not field_allows_scalar(field): - continue - - v = data[name] - if isinstance(v, np.generic) or (isinstance(v, np.ndarray) and v.size == 1): - data[name] = v.item() - - return data - - @classmethod - def _get_type_value(cls, obj: dict[str, Any]) -> str: - """Return the type tag from a raw dictionary.""" - if not isinstance(obj, dict): - raise TypeError("Input must be a dict") - try: - type_value = obj[TYPE_TAG_STR] - except KeyError as exc: - raise ValueError(f'Missing "{TYPE_TAG_STR}" in data') from exc - if not isinstance(type_value, str) or not type_value: - raise ValueError(f'Invalid "{TYPE_TAG_STR}" value: {type_value!r}') - return type_value - - @classmethod - def _get_registered_class(cls, type_value: str) -> type[Tidy3dBaseModel]: - try: - return TYPE_TO_CLASS_MAP[type_value] - except KeyError as exc: - raise ValueError(f"Unknown type: {type_value}") from exc - - @classmethod - def _should_dispatch_to(cls, target_cls: type[Tidy3dBaseModel]) -> bool: - """Return True if ``cls`` allows auto-dispatch to ``target_cls``.""" - return issubclass(target_cls, cls) - - @classmethod - def _resolve_dispatch_target(cls, obj: dict[str, Any]) -> type[Tidy3dBaseModel]: - """Determine which subclass should receive ``obj``.""" - type_value = cls._get_type_value(obj) - target_cls = cls._get_registered_class(type_value) - if cls._should_dispatch_to(target_cls): - return target_cls - if target_cls is cls: - return cls - raise ValueError( - f'Cannot parse type "{type_value}" using {cls.__name__}; expected subclass of {cls.__name__}.' 
- ) - - @classmethod - def _target_cls_from_file( - cls, fname: PathLike, group_path: Optional[str] = None - ) -> type[Tidy3dBaseModel]: - """Peek the file metadata to determine the subclass to instantiate.""" - model_dict = cls.dict_from_file( - fname=fname, - group_path=group_path, - load_data_arrays=False, - ) - return cls._resolve_dispatch_target(model_dict) - - @classmethod - def _model_validate(cls, obj: dict[str, Any], **parse_obj_kwargs: Any) -> Tidy3dBaseModel: - """Dispatch ``obj`` to the correct subclass registered in the type map.""" - target_cls = cls._resolve_dispatch_target(obj) - if target_cls is cls: - return super().model_validate(obj, **parse_obj_kwargs) - return target_cls.model_validate(obj, **parse_obj_kwargs) - - @classmethod - def _validate_model_dict( - cls, model_dict: dict[str, Any], **parse_obj_kwargs: Any - ) -> Tidy3dBaseModel: - """Parse ``model_dict`` while optionally auto-dispatching when called on the base class.""" - if cls is Tidy3dBaseModel: - return cls._model_validate(model_dict, **parse_obj_kwargs) - return cls.model_validate(model_dict, **parse_obj_kwargs) - - def _preprocess_update_values(self, update: Mapping[str, Any]) -> dict[str, Any]: - """Preprocess update values to convert lists to tuples where appropriate. - - This helps avoid Pydantic v2 serialization warnings when using `model_copy()` - with list values for tuple fields. - """ - if not update: - return {} - - def get_tuple_element_type(annotation: Any) -> Optional[type]: - """Get the element type of a tuple annotation if it has one consistent type.""" - origin = get_origin(annotation) - if origin is tuple: - args = get_args(annotation) - if args: - # Check if it's a homogeneous tuple like tuple[bool, ...] or tuple[str, ...] 
- if len(args) == 2 and args[1] is ...: - return args[0] - # Check if all elements have the same type - if all(arg == args[0] for arg in args): - return args[0] - return None - - def should_convert_to_tuple(annotation: Any) -> tuple[bool, Optional[type[Any]]]: - """Check if the given annotation represents a tuple type and return element type if any.""" - origin = get_origin(annotation) - - if origin is tuple: - return True, get_tuple_element_type(annotation) - - # Union types containing tuple - if origin is Union: - args = get_args(annotation) - for arg in args: - if get_origin(arg) is tuple: - return True, get_tuple_element_type(arg) - - return False, None - - def convert_value(value: Any, field_info: FieldInfo) -> Any: - """Convert value based on field type information.""" - annotation = field_info.annotation - - # Handle list/tuple to tuple conversion with proper element types - is_tuple, element_type = should_convert_to_tuple(annotation) - - # Check if value is a numpy array and needs to be converted to tuple - try: - import numpy as np - - if isinstance(value, np.ndarray) and is_tuple: - # Convert numpy array to list first - value = value.tolist() - except ImportError: - pass - - # Handle autograd SequenceBox - convert to tuple - if ( - is_tuple - and hasattr(value, "__class__") - and value.__class__.__name__ == "SequenceBox" - ): - # SequenceBox is iterable, so convert it to tuple - return tuple(value) - - if isinstance(value, (list, tuple)) and is_tuple: - # Convert elements based on element type - if element_type is bool: - # Convert integers to booleans - value = [bool(item) if isinstance(item, int) else item for item in value] - elif element_type is str: - # Ensure all elements are strings - value = [str(item) if not isinstance(item, str) else item for item in value] - else: - # Check if it's a numpy array or contains numpy types - try: - import numpy as np - - if any(isinstance(item, np.generic) for item in value): - # Convert numpy types to Python types 
- value = [ - item.item() if isinstance(item, np.generic) else item - for item in value - ] - except ImportError: - pass - return tuple(value) - - # Handle int to bool conversion - if annotation is bool and isinstance(value, int): - return bool(value) - - # Handle dict to Tidy3dBaseModel conversion - if isinstance(value, dict): - # Check if the annotation is a Tidy3dBaseModel subclass - origin = get_origin(annotation) - if origin is None: - # Not a generic type, check if it's a direct subclass - try: - if isinstance(annotation, type) and issubclass(annotation, Tidy3dBaseModel): - return annotation(**value) - except (TypeError, AttributeError): - pass - elif origin is Union: - # For Union types, try to convert to the first matching Tidy3dBaseModel type - args = get_args(annotation) - for arg in args: - try: - if isinstance(arg, type) and issubclass(arg, Tidy3dBaseModel): - return arg(**value) - except (TypeError, AttributeError, ValueError): - continue - - return value - - processed = {} - for field_name, value in update.items(): - if field_name in type(self).model_fields: - field_info = type(self).model_fields[field_name] - processed[field_name] = convert_value(value, field_info) - else: - processed[field_name] = value - - return processed - - def copy( - self, - deep: bool = True, - *, - validate: bool = True, - update: Optional[Mapping[str, Any]] = None, - ) -> Self: - """Return a copy of the model. - - Parameters - ---------- - deep : bool = True - Whether to make a deep copy first (same as v1). - validate : bool = True - If ``True``, run full Pydantic validation on the copied data. - update : Optional[Mapping[str, Any]] = None - Optional mapping of fields to overwrite (passed straight - through to ``model_copy(update=...)``). 
- """ - if update and self.model_config.get("extra") == "forbid": - invalid = set(update) - set(type(self).model_fields) - if invalid: - raise KeyError(f"'{self.type}' received invalid fields on copy: {invalid}") - - # preprocess update values to convert lists to tuples where appropriate - if update: - update = self._preprocess_update_values(update) - - new_model = self.model_copy(deep=deep, update=update) - - if validate: - return self.__class__.model_validate(new_model.model_dump()) - else: - # make sure cache is always cleared - new_model._cached_properties = {} - - new_model._has_tracers = None - return new_model - - def updated_copy( - self, - path: Optional[str] = None, - *, - deep: bool = True, - validate: bool = True, - **kwargs: Any, - ) -> Self: - """Make copy of a component instance with ``**kwargs`` indicating updated field values. - - Note - ---- - If ``path`` is supplied, applies the updated copy with the update performed on the sub- - component corresponding to the path. For indexing into a tuple or list, use the integer - value. - - Example - ------- - >>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP - """ - if not path: - return self.copy(deep=deep, validate=validate, update=kwargs) - - path_parts = path.split("/") - field_name, *rest = path_parts - - try: - sub_component = getattr(self, field_name) - except AttributeError as exc: - raise AttributeError( - f"Could not find field '{field_name}' in path '{path}'. " - f"Available top-level fields: {tuple(type(self).model_fields)}." - ) from exc - - if isinstance(sub_component, (list, tuple)): - try: - index = int(rest[0]) - except (IndexError, ValueError): - raise ValueError( - f"Expected integer index into '{field_name}' in path '{path}'." 
- ) from None - sub_component_list = list(sub_component) - sub_component_list[index] = sub_component_list[index].updated_copy( - path="/".join(rest[1:]), - deep=deep, - validate=validate, - **kwargs, - ) - new_value = type(sub_component)(sub_component_list) - else: - new_value = sub_component.updated_copy( - path="/".join(rest), - deep=deep, - validate=validate, - **kwargs, - ) - - return self.copy(deep=deep, validate=validate, update={field_name: new_value}) - - @staticmethod - def _core_model_traversal( - current_obj: Any, current_path_segments: tuple[str, ...] - ) -> Iterator[tuple[Self, tuple[str, ...]]]: - """ - Recursively traverses a model structure yielding Tidy3dBaseModel instances and their paths. - - This is an internal helper method used by :meth:`find_paths` and :meth:`find_submodels` - to navigate nested :class:`Tidy3dBaseModel` structures. - - Parameters - ---------- - current_obj : Any - The current object in the traversal, which can be a :class:`Tidy3dBaseModel`, - list, tuple, or other type. - current_path_segments : tuple[str, ...] - A tuple of strings representing the path segments from the initial model - to the ``current_obj``. - - Returns - ------- - Iterator[tuple[Self, tuple[str, ...]]] - An iterator yielding tuples, where the first element is a found :class:`Tidy3dBaseModel` instance - and the second is a tuple of strings representing the path to that instance - from the initial object. The path for the top-level model itself will be an empty tuple. 
- """ - if isinstance(current_obj, Tidy3dBaseModel): - yield current_obj, current_path_segments - - for field_name in type(current_obj).model_fields: - if ( - field_name == "type" - and getattr(current_obj, field_name, None) == current_obj.__class__.__name__ - ): - continue - - field_value = getattr(current_obj, field_name) - yield from Tidy3dBaseModel._core_model_traversal( - field_value, (*current_path_segments, field_name) - ) - elif isinstance(current_obj, (list, tuple)): - for index, item in enumerate(current_obj): - yield from Tidy3dBaseModel._core_model_traversal( - item, (*current_path_segments, str(index)) - ) - - def find_paths(self, target_field_name: str, target_field_value: Any = Undefined) -> list[str]: - """ - Finds paths to nested model instances that have a specific field, optionally matching a value. - - The paths are string representations like ``"structures/0/geometry"``, designed for direct - use with the :meth:`updated_copy` method to modify specific parts of this model. - An empty string ``""`` in the returned list indicates that this model instance - itself (the one ``find_paths`` is called on) matches the criteria. - - Parameters - ---------- - target_field_name : str - The name of the attribute (field) to search for within nested - :class:`Tidy3dBaseModel` instances. For example, ``"name"`` or ``"permittivity"``. - target_field_value : Any, optional - If provided, only paths to model instances where ``target_field_name`` also has this - specific value will be returned. If omitted, paths are returned if the - ``target_field_name`` exists, regardless of its value. - - Returns - ------- - list[str] - A sorted list of unique string paths. Each path points to a - :class:`Tidy3dBaseModel` instance that possesses the ``target_field_name`` - (and optionally matches ``target_field_value``). 
- - Example - ------- - >>> # Assume 'sim' is a Tidy3D simulation object - >>> # Find all geometries named "waveguide" - >>> paths = sim.find_paths(target_field_name="name", target_field_value="waveguide") # doctest: +SKIP - >>> # paths might be ['structures/0', 'structures/3'] - >>> # Update the size of the first found "waveguide" - >>> new_sim = sim.updated_copy(path=paths[0], size=(1.0, 0.5, 0.22)) # doctest: +SKIP - """ - found_paths_set = set() - - for sub_model_instance, path_segments_to_sub_model in Tidy3dBaseModel._core_model_traversal( - self, () - ): - if target_field_name in type(sub_model_instance).model_fields: - passes_value_filter = True - if target_field_value is not Undefined: - actual_value = getattr(sub_model_instance, target_field_name) - if actual_value != target_field_value: - passes_value_filter = False - - if passes_value_filter: - path_str = "/".join(path_segments_to_sub_model) - found_paths_set.add(path_str) - - return sorted(found_paths_set) - - def find_submodels(self, target_type: Self) -> list[Self]: - """ - Finds all unique nested instances of a specific Tidy3D model type within this model. - - This method traverses the model structure and collects all instances that are of - the ``target_type`` (e.g., :class:`~tidy3d.Structure`, :class:`~tidy3d.Medium`, - :class:`~tidy3d.Box`). - Uniqueness is determined by the model's content. The order of models - in the returned list corresponds to their first encounter during a depth-first traversal. - - Parameters - ---------- - target_type : Tidy3dBaseModel - The specific Tidy3D class (e.g., ``Structure``, ``Medium``, ``Box``) to search for. - This class must be a subclass of :class:`Tidy3dBaseModel`. - - Returns - ------- - list[Tidy3dBaseModel] - A list of unique instances found within this model that are of the - provided ``target_type``. 
def help(self, methods: bool = False) -> None:
    """Print a description of the fields (and optionally the methods) of this
    :class:`Tidy3dBaseModel` via ``rich.inspect``.

    Parameters
    ----------
    methods : bool = False
        Whether to also print out information about the object's methods.

    Example
    -------
    >>> simulation.help(methods=True)  # doctest: +SKIP
    """
    rich.inspect(type(self), methods=methods)


@classmethod
def from_file(
    cls,
    fname: PathLike,
    group_path: Optional[str] = None,
    lazy: bool = False,
    on_load: Optional[Callable[[Any], None]] = None,
    **parse_obj_kwargs: Any,
) -> Self:
    """Load a :class:`Tidy3dBaseModel` from a .yaml, .json, .hdf5, or .hdf5.gz file.

    Parameters
    ----------
    fname : PathLike
        Full path to the file to load the :class:`Tidy3dBaseModel` from.
    group_path : Optional[str] = None
        Path to a group inside the file to use as the base level. Only for
        hdf5 files. Starting `/` is optional.
    lazy : bool = False
        Whether to load the actual data (``lazy=False``) or return a proxy
        that loads the data when accessed (``lazy=True``).
    on_load : Optional[Callable[[Any], None]] = None
        Callback executed once the model is fully materialized, with the
        loaded instance as its sole argument. For ``lazy=True`` the proxy is
        responsible for invoking it on materialization.
    **parse_obj_kwargs
        Keyword arguments passed to pydantic's ``model_validate`` method.

    Returns
    -------
    Self
        An instance of the component class calling ``load``.

    Example
    -------
    >>> simulation = Simulation.from_file(fname='folder/sim.json')  # doctest: +SKIP
    """
    if lazy:
        # Peek only at the file metadata; the proxy loads the full model on
        # first access and runs ``on_load`` at that point.
        target_cls = cls._target_cls_from_file(fname=fname, group_path=group_path)
        Proxy = _make_lazy_proxy(target_cls, on_load=on_load)
        return Proxy(fname, group_path, parse_obj_kwargs)
    model_dict = cls.dict_from_file(fname=fname, group_path=group_path)
    obj = cls._validate_model_dict(model_dict, **parse_obj_kwargs)
    # Fix: the original guard was ``not lazy and on_load is not None``;
    # ``not lazy`` is always true here because the lazy branch returns above.
    if on_load is not None:
        on_load(obj)
    return obj


@classmethod
def dict_from_file(
    cls: type[T],
    fname: PathLike,
    group_path: Optional[str] = None,
    *,
    load_data_arrays: bool = True,
) -> dict:
    """Load a dictionary containing the model from a .yaml, .json, .hdf5, or
    .hdf5.gz file.

    Parameters
    ----------
    fname : PathLike
        Full path to the file to load the :class:`Tidy3dBaseModel` from.
    group_path : str, optional
        Path to a group inside the file to use as the base level; only
        meaningful for hdf5 files (a warning is logged otherwise).
    load_data_arrays : bool = True
        Whether to load data arrays from hdf5 files (keyword-only).

    Returns
    -------
    dict
        A dictionary containing the model.

    Example
    -------
    >>> sim_dict = Simulation.dict_from_file(fname='folder/sim.json')  # doctest: +SKIP
    """
    fname_path = Path(fname)
    extension = _get_valid_extension(fname_path)
    kwargs = {"fname": fname_path}

    hdf5_extensions = {".hdf5", ".hdf5.gz", ".h5"}
    if group_path is not None:
        if extension in hdf5_extensions:
            kwargs["group_path"] = group_path
        else:
            log.warning("'group_path' provided, but this feature only works with hdf5 files.")

    if extension in hdf5_extensions:
        kwargs["load_data_arrays"] = load_data_arrays

    converter = {
        ".json": cls.dict_from_json,
        ".yaml": cls.dict_from_yaml,
        ".hdf5": cls.dict_from_hdf5,
        ".hdf5.gz": cls.dict_from_hdf5_gz,
        ".h5": cls.dict_from_hdf5,
    }[extension]
    return converter(**kwargs)


def to_file(self, fname: PathLike) -> None:
    """Export this :class:`Tidy3dBaseModel` instance to a .yaml, .json,
    .hdf5(.gz), or .h5 file; the extension selects the format.

    Parameters
    ----------
    fname : PathLike
        Full path of the file to save the :class:`Tidy3dBaseModel` to.

    Example
    -------
    >>> simulation.to_file(fname='folder/sim.json')  # doctest: +SKIP
    """
    extension = _get_valid_extension(fname)
    converter = {
        ".json": self.to_json,
        ".yaml": self.to_yaml,
        ".hdf5": self.to_hdf5,
        # Fix: '.h5' is accepted by _get_valid_extension (and readable via
        # dict_from_file) but was missing here, so to_file('x.h5') raised a
        # bare KeyError instead of writing an hdf5 file.
        ".h5": self.to_hdf5,
        ".hdf5.gz": self.to_hdf5_gz,
    }[extension]
    return converter(fname=fname)
@classmethod
def dict_from_json(cls: type[T], fname: PathLike) -> dict:
    """Load a dictionary representation of the model from a .json file.

    Parameters
    ----------
    fname : PathLike
        Full path to the .json file to load the :class:`Tidy3dBaseModel` from.

    Returns
    -------
    dict
        A dictionary containing the model.

    Example
    -------
    >>> sim_dict = Simulation.dict_from_json(fname='folder/sim.json')  # doctest: +SKIP
    """
    raw_text = Path(fname).read_text(encoding="utf-8")
    return json.loads(raw_text)


def to_json(self, fname: PathLike) -> None:
    """Export this :class:`Tidy3dBaseModel` instance to a .json file.

    Parameters
    ----------
    fname : PathLike
        Full path to the .json file to save the :class:`Tidy3dBaseModel` to.

    Example
    -------
    >>> simulation.to_json(fname='folder/sim.json')  # doctest: +SKIP
    """
    static_self = self.to_static()
    serialized = static_self.model_dump_json(indent=INDENT_JSON_FILE)
    # Data arrays are silently dropped by the json format; warn if present.
    self._warn_if_contains_data(serialized)
    target = Path(fname)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(serialized, encoding="utf-8")
- - Example - ------- - >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - model_dict = cls.dict_from_yaml(fname=fname) - return cls._validate_model_dict(model_dict, **model_validate_kwargs) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - @classmethod - def dict_from_yaml(cls: type[T], fname: PathLike) -> dict: - """Load dictionary of the model from a .yaml file. - - Parameters - ---------- - fname : PathLike - Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. - - Returns - ------- - dict - A dictionary containing the model. - - Example - ------- - >>> sim_dict = Simulation.dict_from_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - with open(fname, encoding="utf-8") as yaml_in: - model_dict = yaml.safe_load(yaml_in) - return model_dict - - def to_yaml(self, fname: PathLike) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .yaml file. - - Parameters - ---------- - fname : PathLike - Full path to the .yaml file to save the :class:`Tidy3dBaseModel` to. - - Example - ------- - >>> simulation.to_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - export_model = self.to_static() - # We intentionally round-trip through JSON to preserve the exact JSON-mode serialization - # behavior in YAML output (notably `ser_json_inf_nan="strings"` for Infinity/-Infinity/NaN). - json_string = export_model.model_dump_json() - self._warn_if_contains_data(json_string) - model_dict = json.loads(json_string) - path = Path(fname) - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w+", encoding="utf-8") as file_handle: - yaml.dump(model_dict, file_handle, indent=INDENT_JSON_FILE) - - @staticmethod - def _warn_if_contains_data(json_str: str) -> None: - """Log a warning if the json string contains data, used in '.json' and '.yaml' file.""" - if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())): - log.warning( - "Data contents found in the model to be written to file. 
" - "Note that this data will not be included in '.json' or '.yaml' formats. " - "As a result, it will not be possible to load the file back to the original model. " - "Instead, use '.hdf5' extension in filename passed to 'to_file()'." - ) - - @staticmethod - def _construct_group_path(group_path: str) -> str: - """Construct a group path with the leading forward slash if not supplied.""" - - # empty string or None - if not group_path: - return "/" - - # missing leading forward slash - if group_path[0] != "/": - return f"/{group_path}" - - return group_path - - @staticmethod - def get_tuple_group_name(index: int) -> str: - """Get the group name of a tuple element.""" - return str(int(index)) - - @staticmethod - def get_tuple_index(key_name: str) -> int: - """Get the index into the tuple based on its group name.""" - return int(str(key_name)) - - @classmethod - def tuple_to_dict(cls: type[T], tuple_values: tuple) -> dict: - """How we generate a dictionary mapping new keys to tuple values for hdf5.""" - return {cls.get_tuple_group_name(index=i): val for i, val in enumerate(tuple_values)} - - @classmethod - def get_sub_model( - cls: type[T], group_path: str, model_dict: Union[dict[str, Any], list[Any]] - ) -> dict: - """Get the sub model for a given group path.""" - - for key in group_path.split("/"): - if key: - if isinstance(model_dict, list): - tuple_index = cls.get_tuple_index(key_name=key) - model_dict = model_dict[tuple_index] - else: - model_dict = model_dict[key] - return model_dict - - @staticmethod - def _json_string_key(index: int) -> str: - """Get json string key for string chunk number ``index``.""" - if index: - return f"{JSON_TAG}_{index}" - return JSON_TAG - - @classmethod - def _json_string_from_hdf5(cls: type[T], fname: PathLike) -> str: - """Load the model json string from an hdf5 file.""" - with h5py.File(fname, "r") as f_handle: - num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) - json_string = b"" - for ind in 
range(num_string_parts): - json_string += f_handle[cls._json_string_key(ind)][()] - return json_string - - @classmethod - def dict_from_hdf5( - cls: type[T], - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - load_data_arrays: bool = True, - ) -> dict: - """Loads a dictionary containing the model contents from a .hdf5 file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - - Returns - ------- - dict - Dictionary containing the model. - - Example - ------- - >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - def is_data_array(value: Any) -> bool: - """Whether a value is supposed to be a data array based on the contents.""" - return isinstance(value, str) and value in DATA_ARRAY_MAP - - fname_path = Path(fname) - - def load_data_from_file(model_dict: dict, group_path: str = "") -> None: - """For every DataArray item in dictionary, load path of hdf5 group as value.""" - - for key, value in model_dict.items(): - subpath = f"{group_path}/{key}" - - # apply custom validation to the key value pair and modify model_dict - if custom_decoders: - for custom_decoder in custom_decoders: - custom_decoder( - fname=str(fname_path), - group_path=subpath, - model_dict=model_dict, - key=key, - value=value, - ) - - # write the path to the element of the json dict where the data_array should be - if is_data_array(value): - data_array_type = DATA_ARRAY_MAP[value] - model_dict[key] = data_array_type.from_hdf5( - fname=fname_path, group_path=subpath - ) - continue - - # if a list, assign each element a unique key, 
recurse - if isinstance(value, (list, tuple)): - value_dict = cls.tuple_to_dict(tuple_values=value) - load_data_from_file(model_dict=value_dict, group_path=subpath) - - # handle case of nested list of DataArray elements - val_tuple = list(value_dict.values()) - for ind, (model_item, value_item) in enumerate(zip(model_dict[key], val_tuple)): - if is_data_array(model_item): - model_dict[key][ind] = value_item - - # if a dict, recurse - elif isinstance(value, dict): - load_data_from_file(model_dict=value, group_path=subpath) - - model_dict = json.loads(cls._json_string_from_hdf5(fname=fname_path)) - group_path = cls._construct_group_path(group_path) - model_dict = cls.get_sub_model(group_path=group_path, model_dict=model_dict) - if load_data_arrays: - load_data_from_file(model_dict=model_dict, group_path=group_path) - return model_dict - - @classmethod - def from_hdf5( - cls: type[T], - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - **model_validate_kwargs: Any, - ) -> Self: - """Loads :class:`Tidy3dBaseModel` instance to .hdf5 file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - Starting `/` is optional. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - **model_validate_kwargs - Keyword arguments passed to pydantic's ``model_validate`` method. 
- - Example - ------- - >>> simulation = Simulation.from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - group_path = cls._construct_group_path(group_path) - model_dict = cls.dict_from_hdf5( - fname=fname, - group_path=group_path, - custom_decoders=custom_decoders, - ) - return cls._validate_model_dict(model_dict, **model_validate_kwargs) - - def to_hdf5( - self, - fname: Union[PathLike, io.BytesIO], - custom_encoders: Optional[list[Callable]] = None, - ) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .hdf5 file. - - Parameters - ---------- - fname : Union[PathLike, BytesIO] - Full path to the .hdf5 file or buffer to save the :class:`Tidy3dBaseModel` to. - custom_encoders : List[Callable] - List of functions accepting (fname: str, group_path: str, value: Any) that take - the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. - - Example - ------- - >>> simulation.to_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - export_model = self.to_static() - traced_keys_payload = export_model.attrs.get(TRACED_FIELD_KEYS_ATTR) - - if traced_keys_payload is None: - traced_keys_payload = self.attrs.get(TRACED_FIELD_KEYS_ATTR) - if traced_keys_payload is None: - traced_keys_payload = self._serialized_traced_field_keys() - path = Path(fname) if isinstance(fname, PathLike) else fname - with h5py.File(path, "w") as f_handle: - json_str = export_model.model_dump_json() - for ind in range(ceil(len(json_str) / MAX_STRING_LENGTH)): - ind_start = int(ind * MAX_STRING_LENGTH) - ind_stop = min(int(ind + 1) * MAX_STRING_LENGTH, len(json_str)) - f_handle[self._json_string_key(ind)] = json_str[ind_start:ind_stop] - - def add_data_to_file(data_dict: dict, group_path: str = "") -> None: - """For every DataArray item in dictionary, write path of hdf5 group as value.""" - - for key, value in data_dict.items(): - # append the key to the path - subpath = f"{group_path}/{key}" - - if custom_encoders: - for custom_encoder in custom_encoders: - 
custom_encoder(fname=f_handle, group_path=subpath, value=value) - - # write the path to the element of the json dict where the data_array should be - if isinstance(value, xr.DataArray): - value.to_hdf5(fname=f_handle, group_path=subpath) - - # if a tuple, assign each element a unique key - if isinstance(value, (list, tuple)): - value_dict = export_model.tuple_to_dict(tuple_values=value) - add_data_to_file(data_dict=value_dict, group_path=subpath) - - # if a dict, recurse - elif isinstance(value, dict): - add_data_to_file(data_dict=value, group_path=subpath) - - add_data_to_file(data_dict=export_model.model_dump()) - if traced_keys_payload: - f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload - - @classmethod - def dict_from_hdf5_gz( - cls: type[T], - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - load_data_arrays: bool = True, - ) -> dict: - """Loads a dictionary containing the model contents from a .hdf5.gz file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - - Returns - ------- - dict - Dictionary containing the model. 
- - Example - ------- - >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - file_descriptor, extracted = tempfile.mkstemp(".hdf5") - os.close(file_descriptor) - extracted_path = Path(extracted) - try: - extract_gzip_file(fname, extracted_path) - result = cls.dict_from_hdf5( - extracted_path, - group_path=group_path, - custom_decoders=custom_decoders, - load_data_arrays=load_data_arrays, - ) - finally: - extracted_path.unlink(missing_ok=True) - - return result - - @classmethod - def from_hdf5_gz( - cls: type[T], - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - **model_validate_kwargs: Any, - ) -> Self: - """Loads :class:`Tidy3dBaseModel` instance to .hdf5.gz file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - Starting `/` is optional. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - **model_validate_kwargs - Keyword arguments passed to pydantic's ``model_validate`` method. - - Example - ------- - >>> simulation = Simulation.from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - - group_path = cls._construct_group_path(group_path) - model_dict = cls.dict_from_hdf5_gz( - fname=fname, - group_path=group_path, - custom_decoders=custom_decoders, - ) - return cls._validate_model_dict(model_dict, **model_validate_kwargs) - - def to_hdf5_gz( - self, - fname: Union[PathLike, io.BytesIO], - custom_encoders: Optional[list[Callable]] = None, - ) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .hdf5.gz file. 
- - Parameters - ---------- - fname : Union[PathLike, BytesIO] - Full path to the .hdf5.gz file or buffer to save the :class:`Tidy3dBaseModel` to. - custom_encoders : List[Callable] - List of functions accepting (fname: str, group_path: str, value: Any) that take - the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. - - Example - ------- - >>> simulation.to_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - file, decompressed = tempfile.mkstemp(".hdf5") - os.close(file) - try: - self.to_hdf5(decompressed, custom_encoders=custom_encoders) - compress_file_to_gzip(decompressed, fname) - finally: - os.unlink(decompressed) - - def __lt__(self, other: object) -> bool: - """define < for getting unique indices based on hash.""" - return hash(self) < hash(other) - - def __eq__(self, other: object) -> bool: - """Two models are equal when origins match and every public or extra field matches.""" - if not isinstance(other, BaseModel): - return NotImplemented - - self_origin = ( - getattr(self, "__pydantic_generic_metadata__", {}).get("origin") or self.__class__ - ) - other_origin = ( - getattr(other, "__pydantic_generic_metadata__", {}).get("origin") or other.__class__ - ) - if self_origin is not other_origin: - return False - - if getattr(self, "__pydantic_extra__", None) != getattr(other, "__pydantic_extra__", None): - return False - - def _fields_equal(a: Any, b: Any) -> bool: - a = get_static(a) - b = get_static(b) - - if a is b: - return True - if type(a) is not type(b): - if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))): - return False - if isinstance(a, np.ndarray): - return np.array_equal(a, b) - if isinstance(a, (xr.DataArray, xr.Dataset)): - return a.equals(b) - if isinstance(a, Mapping): - if a.keys() != b.keys(): - return False - return all(_fields_equal(a[k], b[k]) for k in a) - if isinstance(a, Sequence) and not isinstance(a, (str, bytes)): - if len(a) != len(b): - return False - return all(_fields_equal(x, 
y) for i, (x, y) in enumerate(zip(a, b))) - if isinstance(a, float) and isinstance(b, float) and np.isnan(a) and np.isnan(b): - return True - return a == b - - for name in type(self).model_fields: - if not _fields_equal(getattr(self, name), getattr(other, name)): - return False - - return True - - def _attrs_digest(self) -> str: - """Stable digest of `attrs` using the same JSON encoding rules as `model_dump_json()`.""" - # encoders = getattr(self.__config__, "json_encoders", {}) or {} - - # def _default(o): - # return custom_pydantic_encoder(encoders, o) - - json_str = json.dumps( - self.attrs, - # default=_default, - sort_keys=True, - separators=(",", ":"), - ensure_ascii=False, - ) - json_str = make_json_compatible(json_str) - - return hashlib.sha256(json_str.encode("utf-8")).hexdigest() - - @cached_property_guarded(lambda self: self._attrs_digest()) - def _json_string(self) -> str: - """Returns string representation of a :class:`Tidy3dBaseModel`. - - Returns - ------- - str - Json-formatted string holding :class:`Tidy3dBaseModel` data. - """ - return self.model_dump_json(indent=INDENT, exclude_unset=False) - - def _strip_traced_fields( - self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False - ) -> AutogradFieldMap: - """Extract a dictionary mapping paths in the model to the data traced by ``autograd``. - - Parameters - ---------- - starting_path : tuple[str, ...] = () - If provided, starts recursing in self.model_dump() from this path of field names - include_untraced_data_arrays : bool = False - Whether to include ``DataArray`` objects without tracers. - We need to include these when returning data, but are unnecessary for structures. 
- - Returns - ------- - dict - mapping of traced fields used by ``autograd`` - - """ - - path = tuple(starting_path) - if self._has_tracers is False and not include_untraced_data_arrays: - return TracedDict() - - field_mapping = {} - - def handle_value(x: Any, path: tuple[str, ...]) -> None: - """recursively update ``field_mapping`` with path to the autograd data.""" - - # this is a leaf node that we want to trace, add this path and data to the mapping - if isbox(x): - field_mapping[path] = x - - # for data arrays, need to be more careful as their tracers are stored in .data - elif isinstance(x, xr.DataArray): - data = x.data - if isbox(data) or any(isbox(el) for el in np.asarray(data).ravel()): - field_mapping[path] = x.data - elif include_untraced_data_arrays: - field_mapping[path] = x.data - - # for sequences, add (i,) to the path and handle each value individually - elif isinstance(x, (list, tuple)): - for i, val in enumerate(x): - handle_value(val, path=(*path, i)) - - # for dictionaries, add the (key,) to the path and handle each value individually - elif isinstance(x, dict): - for key, val in x.items(): - handle_value(val, path=(*path, key)) - - # recursively parse the dictionary of this object - self_dict = self.model_dump(round_trip=True) - - # if an include_only string was provided, only look at that subset of the dict - if path: - for key in path: - self_dict = self_dict[key] - - handle_value(self_dict, path=path) - - if field_mapping: - if not include_untraced_data_arrays: - self._has_tracers = True - return TracedDict(field_mapping) - - if not include_untraced_data_arrays and not path: - self._has_tracers = False - return TracedDict() - - def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self: - """Recursively insert a map of paths to autograd-traced fields into a copy of this obj.""" - self_dict = self.model_dump(round_trip=True) - - def insert_value(x: Any, path: tuple[str, ...], sub_dict: dict[str, Any]) -> None: - """Insert a 
value into the path into a dictionary.""" - current_dict = sub_dict - for key in path[:-1]: - if isinstance(current_dict[key], tuple): - current_dict[key] = list(current_dict[key]) - current_dict = current_dict[key] - - final_key = path[-1] - if isinstance(current_dict[final_key], tuple): - current_dict[final_key] = list(current_dict[final_key]) - - sub_element = current_dict[final_key] - if isinstance(sub_element, xr.DataArray): - current_dict[final_key] = sub_element.copy(deep=False, data=x) - - else: - current_dict[final_key] = x - - for path, value in field_mapping.items(): - insert_value(value, path=path, sub_dict=self_dict) - - return self.__class__.model_validate(self_dict) - - def _serialized_traced_field_keys( - self, field_mapping: Optional[AutogradFieldMap] = None - ) -> Optional[str]: - """Return a serialized, order-independent representation of traced field paths.""" - - if field_mapping is None: - field_mapping = self._strip_traced_fields() - if not field_mapping: - return None - - # TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel. 
- from tidy3d.components.autograd.field_map import TracerKeys - - tracer_keys = TracerKeys.from_field_mapping(field_mapping) - return tracer_keys.model_dump_json() - - def to_static(self) -> Self: - """Version of object with all autograd-traced fields removed.""" - - if self._has_tracers is False: - return self - - # get dictionary of all traced fields - field_mapping = self._strip_traced_fields() - - # shortcut to just return self if no tracers found, for performance - if not field_mapping: - self._has_tracers = False - return self - - # convert all fields to static values - field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()} - - # insert the static values into a copy of self - static_self = self._insert_traced_fields(field_mapping_static) - static_self._has_tracers = False - return static_self - - @classmethod - def generate_docstring(cls) -> str: - """Generates a docstring for a Tidy3D model.""" - - doc = "" - - # keep any pre-existing class description - original_docstrings = [] - if cls.__doc__: - original_docstrings = cls.__doc__.split("\n\n") - doc += original_docstrings.pop(0) - original_docstrings = "\n\n".join(original_docstrings) - - # parameters - doc += "\n\n Parameters\n ----------\n" - for field_name, field in cls.model_fields.items(): # v2 - if field_name == TYPE_TAG_STR: - continue - - # type - ann = getattr(field, "annotation", None) - data_type = _fmt_ann_literal(ann) - - # default / default_factory - default_val = ( - f"{field.default_factory.__name__}()" - if field.default_factory is not None - else field.get_default(call_default_factory=False) - ) - - if isinstance(default_val, BaseModel) or ( - "=" in str(default_val) if default_val is not None else False - ): - default_val = ", ".join( - str(f"{default_val.__class__.__name__}({default_val})").split(" ") - ) - - default_str = "" if field.is_required() else f" = {default_val}" - doc += f" {field_name} : {data_type}{default_str}\n" - - parts = [] - - # units - 
units = None - extra = getattr(field, "json_schema_extra", None) - if isinstance(extra, dict): - units = extra.get("units") - if units is None and hasattr(field, "metadata"): - for meta in field.metadata: - if isinstance(meta, dict) and "units" in meta: - units = meta["units"] - break - if units is not None: - unitstr = ( - f"({', '.join(str(u) for u in units)})" - if isinstance(units, (list, tuple)) - else str(units) - ) - parts.append(f"[units = {unitstr}].") - - # description - desc = getattr(field, "description", None) - if desc: - parts.append(desc) - - if parts: - doc += " " + " ".join(parts) + "\n" - - if original_docstrings: - doc += "\n" + original_docstrings - doc += "\n" - - return doc - - def get_submodels_by_hash(self) -> dict[int, list[Union[str, tuple[str, int]]]]: - """ - Return a mapping ``{hash(submodel): [field_path, ...]}`` for every - nested ``Tidy3dBaseModel`` inside this model. - """ - out = defaultdict(list) - - for name in type(self).model_fields: - value = getattr(self, name) - - if isinstance(value, Tidy3dBaseModel): - out[hash(value)].append(name) - continue - - if isinstance(value, (list, tuple)): - for idx, item in enumerate(value): - if isinstance(item, Tidy3dBaseModel): - out[hash(item)].append((name, idx)) - - elif isinstance(value, np.ndarray): - for idx, item in enumerate(value.flat): - if isinstance(item, Tidy3dBaseModel): - out[hash(item)].append((name, idx)) - - elif isinstance(value, dict): - for k, item in value.items(): - if isinstance(item, Tidy3dBaseModel): - out[hash(item)].append((name, k)) - - return dict(out) - - @staticmethod - def _scientific_notation( - min_val: float, max_val: float, min_digits: int = 4 - ) -> tuple[str, str]: - """ - Convert numbers to scientific notation, displaying only digits up to the point of difference, - with a minimum number of significant digits specified by `min_digits`. 
- """ - - def to_sci(value: float, exponent: int, precision: int) -> str: - normalized_value = value / (10**exponent) - return f"{normalized_value:.{precision}f}e{exponent}" - - if min_val == 0 or max_val == 0: - return f"{min_val:.0e}", f"{max_val:.0e}" - - exponent_min = math.floor(math.log10(abs(min_val))) - exponent_max = math.floor(math.log10(abs(max_val))) - - common_exponent = min(exponent_min, exponent_max) - normalized_min = min_val / (10**common_exponent) - normalized_max = max_val / (10**common_exponent) - - if normalized_min == normalized_max: - precision = min_digits - else: - precision = 0 - while round(normalized_min, precision) == round(normalized_max, precision): - precision += 1 - - precision = max(precision, min_digits) - - sci_min = to_sci(min_val, common_exponent, precision) - sci_max = to_sci(max_val, common_exponent, precision) - - return sci_min, sci_max - - def __rich_repr__(self) -> rich.repr.Result: - """How to pretty-print instances of ``Tidy3dBaseModel``.""" - for name in type(self).model_fields: - value = getattr(self, name) - - # don't print the type field we add to the models - if name == "type": - continue - - # skip `attrs` if it's an empty dictionary - if name == "attrs" and isinstance(value, dict) and not value: - continue - - yield name, value - - def __str__(self) -> str: - """Return a pretty-printed string representation of the model.""" - from io import StringIO - - from rich.console import Console - - sio = StringIO() - console = Console(file=sio) - console.print(self) - output = sio.getvalue() - return output.rstrip("\n") - - -def _make_lazy_proxy( - target_cls: type[Tidy3dBaseModel], - on_load: Optional[Callable[[Any], None]] = None, -) -> type[Tidy3dBaseModel]: - """ - Return a lazy-loading proxy subclass of ``target_cls``. - - Parameters - ---------- - target_cls : type - Must implement ``dict_from_file`` and ``model_validate``. 
- on_load : Optional[Callable[[Any], None]] = None - A function to call with the fully loaded instance once loaded. - - Returns - ------- - type - A class named ``Proxy`` with init args: - ``(fname, group_path, parse_obj_kwargs)``. - """ - - proxy_name = f"{target_cls.__name__}Proxy" - - class _LazyProxy(target_cls): # type: ignore[misc] - def __init__( - self, - fname: PathLike, - group_path: Optional[str], - parse_obj_kwargs: Any, - ) -> None: - # store lazy context only in __dict__ - object.__setattr__(self, "_lazy_fname", Path(fname)) - object.__setattr__(self, "_lazy_group_path", group_path) - object.__setattr__(self, "_lazy_parse_obj_kwargs", dict(parse_obj_kwargs or {})) - - def copy(self, **kwargs: Any) -> Self: - """Return another lazy proxy instead of materializing.""" - return _LazyProxy( - object.__getattribute__(self, "_lazy_fname"), - object.__getattribute__(self, "_lazy_group_path"), - { - **object.__getattribute__(self, "_lazy_parse_obj_kwargs"), - **kwargs, - }, - ) - - def __getattribute__(self, name: str) -> Any: - # Attributes that must *not* trigger materialization - if name.startswith("_lazy_") or name in { - "__class__", - "__dict__", - "__weakref__", - "__post_root_validators__", - "__pydantic_decorators__", - "copy", # don't materialize just for .copy() - }: - return object.__getattribute__(self, name) - - d = object.__getattribute__(self, "__dict__") - - if "_lazy_fname" in d: - fname = d["_lazy_fname"] - group_path = d["_lazy_group_path"] - kwargs = d["_lazy_parse_obj_kwargs"] - - # Build the real instance - model_dict = target_cls.dict_from_file(fname=fname, group_path=group_path) - target = target_cls._validate_model_dict(model_dict, **kwargs) - - d.clear() - d.update(target.__dict__) - - object.__setattr__(self, "__class__", target.__class__) - fields_set = getattr(target, "__pydantic_fields_set__", None) - if fields_set is not None: - object.__setattr__(self, "__pydantic_fields_set__", set(fields_set)) - - pvt = getattr(target, 
"__pydantic_private__", None) - if pvt is not None: - object.__setattr__(self, "__pydantic_private__", pvt) - - if on_load is not None: - on_load(self) - - return object.__getattribute__(self, name) +# marked as migrated to _common +from __future__ import annotations - _LazyProxy.__name__ = proxy_name - return _LazyProxy +from tidy3d._common.components.base import ( + FORBID_SPECIAL_CHARACTERS, + INDENT, + INDENT_JSON_FILE, + JSON_TAG, + MAX_STRING_LENGTH, + TRACED_FIELD_KEYS_ATTR, + TYPE_TO_CLASS_MAP, + T, + Tidy3dBaseModel, + _CacheReturn, + _fmt_ann_literal, + _get_valid_extension, + _GuardedReturn, + _make_lazy_proxy, + cache, + cached_property, + cached_property_guarded, + make_json_compatible, +) diff --git a/tidy3d/components/base_sim/source.py b/tidy3d/components/base_sim/source.py index 2f7391e21d..277c4e6de5 100644 --- a/tidy3d/components/base_sim/source.py +++ b/tidy3d/components/base_sim/source.py @@ -1,30 +1,10 @@ -"""Abstract base for classes that define simulation sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.base_sim.source`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional - -from pydantic import Field - -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.validators import validate_name_str - -if TYPE_CHECKING: - from tidy3d.components.viz import PlotParams +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class AbstractSource(Tidy3dBaseModel, ABC): - """Abstract base class for all sources.""" - - name: Optional[str] = Field( - None, - title="Name", - description="Optional name for the source.", - ) - - @abstractmethod - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Source object.""" - - _name_validator = validate_name_str() +from tidy3d._common.components.base_sim.source import ( + AbstractSource, +) diff 
--git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 37eaf7efcc..86c5ed68fc 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -1,599 +1,38 @@ -"""Storing tidy3d data at it's most fundamental level as xr.DataArray objects""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.data_array`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations -import pathlib -from abc import ABC -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Union -import autograd.numpy as anp -import h5py import numpy as np -import xarray as xr -from autograd.tracer import isbox -from pydantic_core import core_schema -from xarray.core import missing -from xarray.core.indexes import PandasIndex -from xarray.core.indexing import _outer_to_numpy_indexer -from xarray.core.utils import OrderedSet, either_dict_or_kwargs -from xarray.core.variable import as_variable - -from tidy3d.compat import alignment -from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box -from tidy3d.components.geometry.bound_ops import bounds_contains -from tidy3d.constants import ( + +from tidy3d._common.components.data.data_array import ( + DATA_ARRAY_MAP, + DATA_ARRAY_TYPES, + AbstractSpatialDataArray, + DataArray, + FreqDataArray, + ScalarFieldDataArray, + SpatialDataArray, + TimeDataArray, + TriangleMeshDataArray, +) +from tidy3d._common.constants import ( AMP, - HERTZ, - MICROMETER, OHM, PICOSECOND_PER_NANOMETER_PER_KILOMETER, - RADIAN, - SECOND, VOLT, WATT, ) -from tidy3d.exceptions import DataError, FileError +from tidy3d._common.exceptions import DataError, FileError if TYPE_CHECKING: - from collections.abc import Mapping - from os import PathLike - from typing import Optional - - from numpy.typing import NDArray - from pydantic.annotated_handlers import 
GetCoreSchemaHandler - from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue - from xarray.core.types import InterpOptions, Self - - from tidy3d.components.autograd import InterpolationType - from tidy3d.components.types import Axis, Bound - -# maps the dimension names to their attributes -DIM_ATTRS = { - "x": {"units": MICROMETER, "long_name": "x position"}, - "y": {"units": MICROMETER, "long_name": "y position"}, - "z": {"units": MICROMETER, "long_name": "z position"}, - "f": {"units": HERTZ, "long_name": "frequency"}, - "t": {"units": SECOND, "long_name": "time"}, - "direction": {"long_name": "propagation direction"}, - "mode_index": {"long_name": "mode index"}, - "eme_port_index": {"long_name": "EME port index"}, - "eme_cell_index": {"long_name": "EME cell index"}, - "mode_index_in": {"long_name": "mode index in"}, - "mode_index_out": {"long_name": "mode index out"}, - "sweep_index": {"long_name": "sweep index"}, - "theta": {"units": RADIAN, "long_name": "elevation angle"}, - "phi": {"units": RADIAN, "long_name": "azimuth angle"}, - "ux": {"long_name": "normalized kx"}, - "uy": {"long_name": "normalized ky"}, - "orders_x": {"long_name": "diffraction order"}, - "orders_y": {"long_name": "diffraction order"}, - "face_index": {"long_name": "face index"}, - "vertex_index": {"long_name": "vertex index"}, - "axis": {"long_name": "axis"}, -} - - -# name of the DataArray.values in the hdf5 file (xarray's default name too) -DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__" - - -class DataArray(xr.DataArray): - """Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords.""" + from xarray.core.types import Self - # Always set __slots__ = () to avoid xarray warnings - __slots__ = () - # stores an ordered tuple of strings corresponding to the data dimensions - _dims = () - # stores a dictionary of attributes corresponding to the data values - _data_attrs: dict[str, str] = {} - - def __init__(self, data: Any, *args: Any, 
**kwargs: Any) -> None: - # if data is a vanilla autograd box, convert to our box - if isbox(data) and not is_tidy_box(data): - data = TidyArrayBox.from_arraybox(data) - # do the same for xr.Variable or xr.DataArray type - elif isinstance(data, (xr.Variable, xr.DataArray)): - if isbox(data.data) and not is_tidy_box(data.data): - data.data = TidyArrayBox.from_arraybox(data.data) - super().__init__(data, *args, **kwargs) - - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - """Core schema definition for validation & serialization.""" - - def _initial_parser(value: Any) -> Self: - if isinstance(value, cls): - return value - - if isinstance(value, str) and value == cls.__name__: - raise DataError( - f"Trying to load '{cls.__name__}' from string placeholder '{value}' " - "but the actual data is missing. DataArrays are not typically stored " - "in JSON. Load from HDF5 or ensure the DataArray object is provided." - ) - - try: - instance = cls(value) - if not isinstance(instance, cls): - raise TypeError( - f"Constructor for {cls.__name__} returned unexpected type {type(instance)}" - ) - return instance - except Exception as e: - raise ValueError( - f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. " - f"Ensure input is compatible with xarray.DataArray constructor. 
Original error: {e}" - ) from e - - validation_schema = core_schema.no_info_plain_validator_function(_initial_parser) - validation_schema = core_schema.no_info_after_validator_function( - cls._validate_dims, validation_schema - ) - validation_schema = core_schema.no_info_after_validator_function( - cls._assign_data_attrs, validation_schema - ) - validation_schema = core_schema.no_info_after_validator_function( - cls._assign_coord_attrs, validation_schema - ) - - def _serialize_to_name(instance: Self) -> str: - return type(instance).__name__ - - # serialization behavior: - # - for JSON ('json' mode), use the _serialize_to_name function. - # - for Python ('python' mode), use Pydantic's default for the object type - serialization_schema = core_schema.plain_serializer_function_ser_schema( - _serialize_to_name, - return_schema=core_schema.str_schema(), - when_used="json", - ) - - return core_schema.json_or_python_schema( - python_schema=validation_schema, - json_schema=validation_schema, # Use same validation rules for JSON input - serialization=serialization_schema, - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """JSON schema definition (defines how it LOOKS in a schema, not the data).""" - return { - "type": "string", - "title": cls.__name__, - "description": ( - f"Placeholder for a '{cls.__name__}' object. Actual data is typically " - "serialized separately (e.g., via HDF5) and not embedded in JSON." 
- ), - } - - @classmethod - def _validate_dims(cls, val: Self) -> Self: - """Make sure the dims are the same as ``_dims``, then put them in the correct order.""" - if set(val.dims) != set(cls._dims): - raise ValueError( - f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'" - ) - if val.dims != cls._dims: - val = val.transpose(*cls._dims) - return val - - @classmethod - def _assign_data_attrs(cls, val: Self) -> Self: - """Assign the correct data attributes to the :class:`.DataArray`.""" - for attr_name, attr_val in cls._data_attrs.items(): - val.attrs[attr_name] = attr_val - return val - - @classmethod - def _assign_coord_attrs(cls, val: Self) -> Self: - """Assign the correct coordinate attributes to the :class:`.DataArray`.""" - target_dims = set(val.dims) & set(cls._dims) & set(val.coords) - for dim in target_dims: - template = DIM_ATTRS.get(dim) - if not template: - continue - - coord_attrs = val.coords[dim].attrs - missing = {k: v for k, v in template.items() if coord_attrs.get(k) != v} - coord_attrs.update(missing) - return val - - def _interp_validator(self, field_name: Optional[str] = None) -> None: - """Ensure the data can be interpolated or selected by checking for duplicate coordinates. - - NOTE - ---- - This does not check every 'DataArray' by default. Instead, when required, this check can be - called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'. - """ - if field_name is None: - field_name = self.__class__.__name__ - - for dim, coord in self.coords.items(): - if coord.to_index().duplicated().any(): - raise DataError( - f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. " - "Duplicates can be removed by running " - f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'." 
- ) - - def __eq__(self, other: Any) -> bool: - """Whether two data array objects are equal.""" - - if not isinstance(other, xr.DataArray): - return False - - if not self.data.shape == other.data.shape or not np.all(self.data == other.data): - return False - for key, val in self.coords.items(): - if not np.all(np.array(val) == np.array(other.coords[key])): - return False - return True - - @property - def values(self) -> NDArray: - """ - The array's data converted to a numpy.ndarray. - - Returns - ------- - np.ndarray - The values of the DataArray. - """ - return self.data if isbox(self.data) else super().values - - @values.setter - def values(self, value: Any) -> None: - self.variable.values = value - - def to_numpy(self) -> np.ndarray: - """Return `.data` when traced to avoid `dtype=object` NumPy conversion.""" - return self.data if isbox(self.data) else super().to_numpy() - - @property - def abs(self) -> Self: - """Absolute value of data array.""" - return abs(self) - - @property - def angle(self) -> Self: - """Angle or phase value of data array.""" - values = np.angle(self.values) - return type(self)(values, coords=self.coords) - - @property - def is_uniform(self) -> bool: - """Whether each element is of equal value in the data array""" - raw_data = self.data.ravel() - return np.allclose(raw_data, raw_data[0]) - - def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None: - """Save an ``xr.DataArray`` to the hdf5 file or file handle with a given path to the group.""" - if isinstance(fname, (str, pathlib.Path)): - path = pathlib.Path(fname) - path.parent.mkdir(parents=True, exist_ok=True) - with h5py.File(path, "w") as f_handle: - self.to_hdf5_handle(f_handle=f_handle, group_path=group_path) - else: - self.to_hdf5_handle(f_handle=fname, group_path=group_path) - - def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: - """Save an ``xr.DataArray`` to the hdf5 file handle with a given path to the group.""" - sub_group = 
f_handle.create_group(group_path) - sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) - for key, val in self.coords.items(): - if val.dtype == "<U1": - sub_group[key] = val.values.tolist() - else: - sub_group[key] = val - - @classmethod - def from_hdf5(cls, fname: PathLike, group_path: str) -> Self: - """Load a DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - with h5py.File(path, "r") as f: - sub_group = f[group_path] - values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) - coords = {dim: np.array(sub_group[dim]) for dim in cls._dims if dim in sub_group} - for key, val in coords.items(): - if val.dtype == "O": - coords[key] = [byte_string.decode() for byte_string in val.tolist()] - return cls(values, coords=coords, dims=cls._dims) - - @classmethod - def from_file(cls, fname: PathLike, group_path: str) -> Self: - """Load a DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - if not any(suffix.lower() == ".hdf5" for suffix in path.suffixes): - raise FileError( - f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}." - ) - return cls.from_hdf5(fname=path, group_path=group_path) - - def __hash__(self) -> int: - """Generate hash value for a :class:`.DataArray` instance, needed for custom components.""" - import dask - - token_str = dask.base.tokenize(self) - return hash(token_str) - - def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: - """Multiply self by value at indices.""" - if isbox(self.data) or isbox(value): - return self._ag_multiply_at(value, coord_name, indices) - - self_mult = self.copy() - self_mult[{coord_name: indices}] *= value - return self_mult - - def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: - """Autograd multiply_at override when tracing.""" - key = {coord_name: indices} - _, index_tuple, _ = self.variable._broadcast_indexes(key) - idx = _outer_to_numpy_indexer(index_tuple, self.data.shape) - mask = np.zeros(self.data.shape, dtype="?") - mask[idx] = True - return self.copy(deep=False, data=anp.where(mask, 
self.data * value, self.data)) - - def interp( - self, - coords: Mapping[Any, Any] | None = None, - method: InterpOptions = "linear", - assume_sorted: bool = False, - kwargs: Mapping[str, Any] | None = None, - **coords_kwargs: Any, - ) -> Self: - """Interpolate this DataArray to new coordinate values. - - Parameters - ---------- - coords : Union[Mapping[Any, Any], None] = None - A mapping from dimension names to new coordinate labels. - method : InterpOptions = "linear" - The interpolation method to use. - assume_sorted : bool = False - If True, skip sorting of coordinates. - kwargs : Union[Mapping[str, Any], None] = None - Additional keyword arguments to pass to the interpolation function. - **coords_kwargs : Any - The keyword arguments form of coords. - - Returns - ------- - DataArray - A new DataArray with interpolated values. - - Raises - ------ - KeyError - If any of the specified coordinates are not in the DataArray. - """ - if isbox(self.data): - return self._ag_interp(coords, method, assume_sorted, kwargs, **coords_kwargs) - - return super().interp(coords, method, assume_sorted, kwargs, **coords_kwargs) - - def _ag_interp( - self, - coords: Union[Mapping[Any, Any], None] = None, - method: InterpOptions = "linear", - assume_sorted: bool = False, - kwargs: Union[Mapping[str, Any], None] = None, - **coords_kwargs: Any, - ) -> Self: - """Autograd interp override when tracing over self.data. - - This implementation closely follows the interp implementation of xarray - to match its behavior as closely as possible while supporting autograd. 
- - See: - - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html - - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html - """ - if kwargs is None: - kwargs = {} - - ds = self._to_temp_dataset() - - coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") - indexers = dict(ds._validate_interp_indexers(coords)) - - if coords: - # Find shared dimensions between the dataset and the indexers - sdims = ( - set(ds.dims) - .intersection(*[set(nx.dims) for nx in indexers.values()]) - .difference(coords.keys()) - ) - indexers.update({d: ds.variables[d] for d in sdims}) - - obj = ds if assume_sorted else ds.sortby(list(coords)) - - # workaround to get a variable for a dimension without a coordinate - validated_indexers = { - k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v) - for k, v in indexers.items() - } - - for k, v in validated_indexers.items(): - obj, newidx = missing._localize(obj, {k: v}) - validated_indexers[k] = newidx[k] - - variables = {} - reindex = False - for name, var in obj._variables.items(): - if name in indexers: - continue - dtype_kind = var.dtype.kind - if dtype_kind in "uifc": - # Interpolation for numeric types - var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims} - variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs) - elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims): - # Stepwise interpolation for non-numeric types - reindex = True - elif all(d not in indexers for d in var.dims): - # Keep variables not dependent on interpolated coords - variables[name] = var - - if reindex: - # Reindex for non-numeric types - reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)} - reindexed = alignment.reindex( - obj, - indexers=reindex_indexers, - method="nearest", - exclude_vars=variables.keys(), - ) - indexes = dict(reindexed._indexes) - variables.update(reindexed.variables) - else: - # Get the 
indexes that are not being interpolated along - indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} - - # Get the coords that also exist in the variables - coord_names = obj._coord_names & variables.keys() - selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes) - - # Attach indexer as coordinate - for k, v in indexers.items(): - if v.dims == (k,): - index = PandasIndex(v, k, coord_dtype=v.dtype) - index_vars = index.create_variables({k: v}) - indexes[k] = index - variables.update(index_vars) - else: - variables[k] = v - - # Extract coordinates from indexers - coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) - variables.update(coord_vars) - indexes.update(new_indexes) - - coord_names = obj._coord_names & variables.keys() | coord_vars.keys() - ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes) - return self._from_temp_dataset(ds) - - @staticmethod - def _ag_interp_func( - var: xr.Variable, - indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]], - method: InterpolationType, - **kwargs: Any, - ) -> xr.Variable: - """ - Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`. - - The implementation follows xarray's interp implementation in xarray.core.missing, - but replaces some of the pre-processing as well as the actual interpolation - function with an autograd-compatible approach. - - - Parameters - ---------- - var : xr.Variable - The variable to be interpolated. - indexes_coords : dict - A dictionary mapping dimension names to coordinate values for interpolation. - method : Literal["nearest", "linear"] - The interpolation method to use. - **kwargs : dict - Additional keyword arguments to pass to the interpolation function. - - Returns - ------- - xr.Variable - The interpolated variable. 
- """ - if not indexes_coords: - return var.copy() - result = var - for indep_indexes_coords in missing.decompose_interp(indexes_coords): - var = result - - # target dimensions - dims = list(indep_indexes_coords) - x, new_x = zip(*[indep_indexes_coords[d] for d in dims]) - destination = missing.broadcast_variables(*new_x) - - broadcast_dims = [d for d in var.dims if d not in dims] - original_dims = broadcast_dims + dims - new_dims = broadcast_dims + list(destination[0].dims) - - x, new_x = missing._floatize_x(x, new_x) - - permutation = [var.dims.index(dim) for dim in original_dims] - combined_permutation = permutation[-len(x) :] + permutation[: -len(x)] - data = anp.transpose(var.data, combined_permutation) - xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1) - - result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs) - - result = anp.moveaxis(result, 0, -1) - result = anp.reshape(result, result.shape[:-1] + new_x[0].shape) - - result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True) - - out_dims: OrderedSet = OrderedSet() - for d in var.dims: - if d in dims: - out_dims.update(indep_indexes_coords[d][1].dims) - else: - out_dims.add(d) - if len(out_dims) > 1: - result = result.transpose(*out_dims) - return result - - def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray: - """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible - - Constraints / Edge cases: - - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays - - `data` will be reshaped to try to match `self.shape` except where `coords` present - """ - - # make mask - mask = xr.zeros_like(self, dtype=bool) - mask.loc[coords] = True - - # reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis - old_data = self.data - new_shape = list(old_data.shape) - for i, dim in enumerate(self.dims): - if dim in coords: - new_shape[i] = 1 - try: - new_data = 
data.reshape(new_shape) - except ValueError as e: - raise ValueError( - "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was " - f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this " - "error please raise an issue on the tidy3d github repository with the context." - ) from e - - # broadcast data to repeat data along the selected dimensions to match mask - new_data = new_data + np.zeros_like(old_data) - - new_data = np.where(mask, new_data, old_data) - - return self.copy(deep=True, data=new_data) - - -class FreqDataArray(DataArray): - """Frequency-domain array. - - Example - ------- - >>> f = [2e14, 3e14] - >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f)) - """ - - __slots__ = () - _dims = ("f",) + from tidy3d._common.components.types.base import Axis, Bound class FreqVoltageDataArray(DataArray): @@ -629,19 +68,6 @@ class FreqModeDataArray(DataArray): _dims = ("f", "mode_index") -class TimeDataArray(DataArray): - """Time-domain array. 
- - Example - ------- - >>> t = [0, 1e-12, 2e-12] - >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t)) - """ - - __slots__ = () - _dims = ("t",) - - class MixedModeDataArray(DataArray): """Scalar property associated with mode pairs @@ -658,224 +84,6 @@ class MixedModeDataArray(DataArray): _dims = ("f", "mode_index_0", "mode_index_1") -class AbstractSpatialDataArray(DataArray, ABC): - """Spatial distribution.""" - - __slots__ = () - _dims = ("x", "y", "z") - _data_attrs = {"long_name": "field value"} - - @property - def _spatially_sorted(self) -> Self: - """Check whether sorted and sort if not.""" - needs_sorting = [] - for axis in "xyz": - axis_coords = self.coords[axis].values - if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]): - needs_sorting.append(axis) - - if len(needs_sorting) > 0: - return self.sortby(needs_sorting) - - return self - - def sel_inside(self, bounds: Bound) -> Self: - """Return a new SpatialDataArray that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. Note that the returned data is sorted with respect - to spatial coordinates. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - SpatialDataArray - Extracted spatial data array. - """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." 
- ) - - # make sure data is sorted with respect to coordinates - sorted_self = self._spatially_sorted - - inds_list = [] - - coords = (sorted_self.x, sorted_self.y, sorted_self.z) - - for coord, smin, smax in zip(coords, bounds[0], bounds[1]): - length = len(coord) - - # one point along direction, assume invariance - if length == 1: - comp_inds = [0] - else: - # if data does not cover structure at all take the closest index - if smax < coord[0]: # structure is completely on the left side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(0, max(2, length)) - - elif smin > coord[-1]: # structure is completely on the right side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(min(0, length - 2), length) - - else: - if smin < coord[0]: - ind_min = 0 - else: - ind_min = max(0, (coord >= smin).data.argmax() - 1) - - if smax > coord[-1]: - ind_max = length - 1 - else: - ind_max = (coord >= smax).data.argmax() - - comp_inds = np.arange(ind_min, ind_max + 1) - - inds_list.append(comp_inds) - - return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2]) - - def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: - """Check whether data fully covers specified by ``bounds`` spatial region. If data contains - only one point along a given direction, then it is assumed the data is constant along that - direction and coverage is not checked. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - rtol : float = 0.0 - Relative tolerance for comparing bounds - atol : float = 0.0 - Absolute tolerance for comparing bounds - - Returns - ------- - bool - Full cover check outcome. - """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." 
- ) - xyz = [self.x, self.y, self.z] - self_min = [0] * 3 - self_max = [0] * 3 - for dim in range(3): - coords = xyz[dim] - if len(coords) == 1: - self_min[dim] = bounds[0][dim] - self_max[dim] = bounds[1][dim] - else: - self_min[dim] = np.min(coords) - self_max[dim] = np.max(coords) - self_bounds = (tuple(self_min), tuple(self_max)) - return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol) - - -class SpatialDataArray(AbstractSpatialDataArray): - """Spatial distribution. - - Example - ------- - >>> x = [1,2] - >>> y = [2,3,4] - >>> z = [3,4,5,6] - >>> coords = dict(x=x, y=y, z=z) - >>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords) - """ - - __slots__ = () - - def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> Self: - """Reflect data across the plane define by parameters ``axis`` and ``center`` from right to - left. Note that the returned data is sorted with respect to spatial coordinates. - - Parameters - ---------- - axis : Literal[0, 1, 2] - Normal direction of the reflection plane. - center : float - Location of the reflection plane along its normal direction. - reflection_only : bool = False - Return only reflected data. - - Returns - ------- - SpatialDataArray - Data after reflection is performed. 
- """ - - sorted_self = self._spatially_sorted - - coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values] - data = np.array(sorted_self.data) - - data_left_bound = coords[axis][0] - - if np.isclose(center, data_left_bound): - num_duplicates = 1 - elif center > data_left_bound: - raise DataError("Reflection center must be outside and to the left of the data region.") - else: - num_duplicates = 0 - - if reflection_only: - coords[axis] = 2 * center - coords[axis] - coords_dict = dict(zip("xyz", coords)) - - tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict) - - return tmp_arr.sortby("xyz"[axis]) - - shape = np.array(np.shape(data)) - old_len = shape[axis] - shape[axis] = 2 * old_len - num_duplicates - - ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - - ind_left[axis] = slice(old_len - 1, None, -1) - ind_right[axis] = slice(old_len - num_duplicates, None) - - new_data = np.zeros(shape) - - new_data[ind_left[0], ind_left[1], ind_left[2]] = data - new_data[ind_right[0], ind_right[1], ind_right[2]] = data - - new_coords = np.zeros(shape[axis]) - new_coords[old_len - num_duplicates :] = coords[axis] - new_coords[old_len - 1 :: -1] = 2 * center - coords[axis] - - coords[axis] = new_coords - coords_dict = dict(zip("xyz", coords)) - - return SpatialDataArray(new_data, coords=coords_dict) - - -class ScalarFieldDataArray(AbstractSpatialDataArray): - """Spatial distribution in the frequency-domain. - - Example - ------- - >>> x = [1,2] - >>> y = [2,3,4] - >>> z = [3,4,5,6] - >>> f = [2e14, 3e14] - >>> coords = dict(x=x, y=y, z=z, f=f) - >>> fd = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) - """ - - __slots__ = () - _dims = ("x", "y", "z", "f") - - class ScalarFieldTimeDataArray(AbstractSpatialDataArray): """Spatial distribution in the time-domain. 
@@ -1103,14 +311,6 @@ class DiffractionDataArray(DataArray): _data_attrs = {"long_name": "diffraction amplitude"} -class TriangleMeshDataArray(DataArray): - """Data of the triangles of a surface mesh as in the STL file format.""" - - __slots__ = () - _dims = ("face_index", "vertex_index", "axis") - _data_attrs = {"long_name": "surface mesh triangles"} - - class HeatDataArray(DataArray): """Heat data array. @@ -1662,57 +862,54 @@ def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType: return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) -DATA_ARRAY_TYPES = [ - SpatialDataArray, - ScalarFieldDataArray, - ScalarFieldTimeDataArray, - ScalarModeFieldDataArray, - FluxDataArray, - FluxTimeDataArray, - ModeAmpsDataArray, - ModeIndexDataArray, - GroupIndexDataArray, - ModeDispersionDataArray, - FieldProjectionAngleDataArray, - FieldProjectionCartesianDataArray, - FieldProjectionKSpaceDataArray, - DiffractionDataArray, - FreqModeDataArray, - FreqDataArray, - TimeDataArray, - FreqModeDataArray, - FreqVoltageDataArray, - TriangleMeshDataArray, - HeatDataArray, - EMEScalarFieldDataArray, - EMEScalarModeFieldDataArray, - EMESMatrixDataArray, - EMEInterfaceSMatrixDataArray, - EMECoefficientDataArray, - EMEModeIndexDataArray, - EMEFluxDataArray, - EMEFreqModeDataArray, - ChargeDataArray, - SteadyVoltageDataArray, - PointDataArray, - CellDataArray, - IndexedDataArray, - IndexedFieldVoltageDataArray, - IndexedVoltageDataArray, - SpatialVoltageDataArray, - PerturbationCoefficientDataArray, - IndexedTimeDataArray, - VoltageFreqDataArray, - VoltageTimeDataArray, - VoltageFreqModeDataArray, - CurrentFreqDataArray, - CurrentTimeDataArray, - CurrentFreqModeDataArray, - ImpedanceFreqDataArray, - ImpedanceTimeDataArray, - ImpedanceFreqModeDataArray, -] -DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES} +DATA_ARRAY_TYPES.extend( + [ + ScalarFieldTimeDataArray, + ScalarModeFieldDataArray, + FluxDataArray, + 
FluxTimeDataArray, + ModeAmpsDataArray, + ModeIndexDataArray, + GroupIndexDataArray, + ModeDispersionDataArray, + FieldProjectionAngleDataArray, + FieldProjectionCartesianDataArray, + FieldProjectionKSpaceDataArray, + DiffractionDataArray, + FreqModeDataArray, + FreqModeDataArray, + FreqVoltageDataArray, + HeatDataArray, + EMEScalarFieldDataArray, + EMEScalarModeFieldDataArray, + EMESMatrixDataArray, + EMEInterfaceSMatrixDataArray, + EMECoefficientDataArray, + EMEModeIndexDataArray, + EMEFluxDataArray, + EMEFreqModeDataArray, + ChargeDataArray, + SteadyVoltageDataArray, + PointDataArray, + CellDataArray, + IndexedDataArray, + IndexedFieldVoltageDataArray, + IndexedVoltageDataArray, + SpatialVoltageDataArray, + PerturbationCoefficientDataArray, + IndexedTimeDataArray, + VoltageFreqDataArray, + VoltageTimeDataArray, + VoltageFreqModeDataArray, + CurrentFreqDataArray, + CurrentTimeDataArray, + CurrentFreqModeDataArray, + ImpedanceFreqDataArray, + ImpedanceTimeDataArray, + ImpedanceFreqModeDataArray, + ] +) +DATA_ARRAY_MAP.update({data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES}) IndexedDataArrayTypes = Union[ IndexedDataArray, diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index 324ce679a6..3f3decbd20 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -1,4 +1,8 @@ -"""Collections of DataArrays.""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.dataset`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations @@ -9,14 +13,18 @@ import xarray as xr from pydantic import Field -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import xyz -from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling -from tidy3d.exceptions import DataError -from tidy3d.log import log - -from .data_array import ( - DataArray, 
+from tidy3d._common.components.data.dataset import ( + DEFAULT_MAX_CELLS_PER_STEP, + DEFAULT_MAX_SAMPLES_PER_STEP, + DEFAULT_TOLERANCE_CELL_FINDING, + AbstractFieldDataset, + AbstractMediumPropertyDataset, + Dataset, + PermittivityDataset, + TimeDataset, + TriangleMeshDataset, +) +from tidy3d.components.data.data_array import ( EMEScalarFieldDataArray, EMEScalarModeFieldDataArray, GroupIndexDataArray, @@ -27,9 +35,12 @@ ScalarModeFieldCylindricalDataArray, ScalarModeFieldDataArray, TimeDataArray, - TriangleMeshDataArray, ) -from .zbf import ZBFData +from tidy3d.components.data.zbf import ZBFData +from tidy3d.components.types.base import xyz +from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling +from tidy3d.exceptions import DataError +from tidy3d.log import log if TYPE_CHECKING: from typing import Callable, Literal @@ -37,25 +48,8 @@ from numpy.typing import ArrayLike from tidy3d.compat import Self - from tidy3d.components.types import Axis, FreqArray - -DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 -DEFAULT_MAX_CELLS_PER_STEP = 10_000 -DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 - - -class Dataset(Tidy3dBaseModel, ABC): - """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" - - @property - def data_arrs(self) -> dict: - """Returns a dictionary of all `:class:`.DataArray`s in the dataset.""" - data_arrs = {} - for key in self.__class__.model_fields.keys(): - data = getattr(self, key) - if isinstance(data, DataArray): - data_arrs[key] = data - return data_arrs + from tidy3d.components.data.data_array import DataArray + from tidy3d.components.types.base import Axis, FreqArray class FreqDataset(Dataset, ABC): @@ -195,104 +189,6 @@ def _apply_mode_reorder(self, sort_inds_2d: np.ndarray) -> Self: return self.updated_copy(**modify_data) -class AbstractFieldDataset(Dataset, ABC): - """Collection of scalar fields with some symmetry properties.""" - - @property - @abstractmethod - def field_components(self) -> dict[str, 
DataArray]: - """Maps the field components to their associated data.""" - - def apply_phase(self, phase: float) -> AbstractFieldDataset: - """Create a copy where all elements are phase-shifted by a value (in radians).""" - if phase == 0.0: - return self - phasor = np.exp(1j * phase) - field_components_shifted = {} - for fld_name, fld_cmp in self.field_components.items(): - fld_cmp_shifted = phasor * fld_cmp - field_components_shifted[fld_name] = fld_cmp_shifted - return self.updated_copy(**field_components_shifted) - - @property - @abstractmethod - def grid_locations(self) -> dict[str, str]: - """Maps field components to the string key of their grid locations on the yee lattice.""" - - @property - @abstractmethod - def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: - """Maps field components to their (positive) symmetry eigenvalues.""" - - def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any: - """How to package the dictionary of fields computed via self.colocate().""" - return xr.Dataset(centered_fields) - - def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset: - """Colocate all of the data at a set of x, y, z coordinates. - - Parameters - ---------- - x : Optional[array-like] = None - x coordinates of locations. - If not supplied, does not try to colocate on this dimension. - y : Optional[array-like] = None - y coordinates of locations. - If not supplied, does not try to colocate on this dimension. - z : Optional[array-like] = None - z coordinates of locations. - If not supplied, does not try to colocate on this dimension. - - Returns - ------- - xr.Dataset - Dataset containing all fields at the same spatial locations. - For more details refer to `xarray's Documentation `_. - - Note - ---- - For many operations (such as flux calculations and plotting), - it is important that the fields are colocated at the same spatial locations. 
- Be sure to apply this method to your field data in those cases. - """ - - if hasattr(self, "monitor") and self.monitor.colocate: - with log as consolidated_logger: - consolidated_logger.warning( - "Colocating data that has already been colocated during the solver " - "run. For most accurate results when colocating to custom coordinates set " - "'Monitor.colocate' to 'False' to use the raw data on the Yee grid " - "and avoid double interpolation. Note: the default value was changed to 'True' " - "in Tidy3D version 2.4.0." - ) - - # convert supplied coordinates to array and assign string mapping to them - supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None} - - # dict of data arrays to combine in dataset and return - centered_fields = {} - - # loop through field components - for field_name, field_data in self.field_components.items(): - # loop through x, y, z dimensions and raise an error if only one element along dim - for coord_name, coords_supplied in supplied_coord_map.items(): - coord_data = np.array(field_data.coords[coord_name]) - if coord_data.size == 1: - raise DataError( - f"colocate given {coord_name}={coords_supplied}, but " - f"data only has one coordinate at {coord_name}={coord_data[0]}. " - "Therefore, can't colocate along this dimension. " - f"supply {coord_name}=None to skip it." 
- ) - - centered_fields[field_name] = field_data.interp( - **supplied_coord_map, kwargs={"bounds_error": True} - ) - - # combine all centered fields in a dataset - return self.package_colocate_results(centered_fields) - - EMScalarFieldType = Union[ ScalarFieldDataArray, ScalarFieldTimeDataArray, @@ -710,7 +606,7 @@ class ModeSolverDataset(ElectromagneticFieldDataset, ModeFreqDataset): None, title="Dispersion", description="Dispersion parameter for the mode.", - json_schema_extra={"units": PICOSECOND_PER_NANOMETER_PER_KILOMETER}, + units=PICOSECOND_PER_NANOMETER_PER_KILOMETER, ) @property @@ -772,53 +668,6 @@ def plot_field(self, *args: Any, **kwargs: Any) -> None: ) -class AbstractMediumPropertyDataset(AbstractFieldDataset, ABC): - """Dataset storing medium property.""" - - eps_xx: ScalarFieldDataArray = Field( - title="Epsilon xx", - description="Spatial distribution of the xx-component of the relative permittivity.", - ) - eps_yy: ScalarFieldDataArray = Field( - title="Epsilon yy", - description="Spatial distribution of the yy-component of the relative permittivity.", - ) - eps_zz: ScalarFieldDataArray = Field( - title="Epsilon zz", - description="Spatial distribution of the zz-component of the relative permittivity.", - ) - - -class PermittivityDataset(AbstractMediumPropertyDataset): - """Dataset storing the diagonal components of the permittivity tensor. 
- - Example - ------- - >>> x = [-1,1] - >>> y = [-2,0,2] - >>> z = [-3,-1,1,3] - >>> f = [2e14, 3e14] - >>> coords = dict(x=x, y=y, z=z, f=f) - >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) - >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld) - """ - - @property - def field_components(self) -> dict[str, ScalarFieldDataArray]: - """Maps the field components to their associated data.""" - return {"eps_xx": self.eps_xx, "eps_yy": self.eps_yy, "eps_zz": self.eps_zz} - - @property - def grid_locations(self) -> dict[str, str]: - """Maps field components to the string key of their grid locations on the yee lattice.""" - return {"eps_xx": "Ex", "eps_yy": "Ey", "eps_zz": "Ez"} - - @property - def symmetry_eigenvalues(self) -> dict[str, None]: - """Maps field components to their (positive) symmetry eigenvalues.""" - return {"eps_xx": None, "eps_yy": None, "eps_zz": None} - - class MediumDataset(AbstractMediumPropertyDataset): """Dataset storing the diagonal components of the permittivity and permeability tensor. 
@@ -881,22 +730,3 @@ def symmetry_eigenvalues(self) -> dict[str, None]: "mu_yy": None, "mu_zz": None, } - - -class TriangleMeshDataset(Dataset): - """Dataset for storing triangular surface data.""" - - surface_mesh: TriangleMeshDataArray = Field( - title="Surface mesh data", - description="Dataset containing the surface triangles and corresponding face indices " - "for a surface mesh.", - ) - - -class TimeDataset(Dataset): - """Dataset for storing a function of time.""" - - values: TimeDataArray = Field( - title="Values", - description="Values as a function of time.", - ) diff --git a/tidy3d/components/data/validators.py b/tidy3d/components/data/validators.py index 5ce9e9e804..86f6d6043e 100644 --- a/tidy3d/components/data/validators.py +++ b/tidy3d/components/data/validators.py @@ -1,86 +1,11 @@ -# special validators for Datasets -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import numpy as np -from pydantic import field_validator - -from tidy3d.exceptions import ValidationError - -from .data_array import DataArray -from .dataset import AbstractFieldDataset, ScalarFieldDataArray - -if TYPE_CHECKING: - from typing import Callable, Optional - - from pydantic_core.core_schema import ValidationInfo - - -# this can't go in validators.py because that file imports dataset.py -def validate_no_nans(*field_names: str) -> Callable[[Any, ValidationInfo], Any]: - """Raise validation error if nans found in Dataset, or other data-containing item.""" - - @field_validator(*field_names) - def no_nans(val: Any, info: ValidationInfo) -> Any: - """Raise validation error if nans found in Dataset, or other data-containing item.""" - - if val is None: - return val +"""Compatibility shim for :mod:`tidy3d._common.components.data.validators`.""" - def error_if_has_nans(value: Any, identifier: Optional[str] = None) -> None: - """Recursively check if value (or iterable) has nans and error if so.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure 
compatibility - def has_nans(values: Any) -> bool: - """Base case: do these values contain NaN?""" - try: - return np.any(np.isnan(values)) - # if this fails for some reason (fails in adjoint, for example), don't check it. - except Exception: - return False - - if isinstance(value, (tuple, list)): - for i, _value in enumerate(value): - error_if_has_nans(_value, identifier=f"[{i}]") - - elif isinstance(value, AbstractFieldDataset): - for key, val in value.field_components.items(): - error_if_has_nans(val, identifier=f".{key}") - - elif isinstance(value, DataArray): - error_if_has_nans(value.values) - - else: - if has_nans(value): - # the identifier is used to make the message more clear by appending some more info - field_name_display = info.field_name - if identifier: - field_name_display += identifier - - raise ValidationError( - f"Found 'NaN' values in '{field_name_display}'. " - "If they were not intended, please double check your construction. " - "If intended, to replace these data points with a value 'x', " - "call 'values = np.nan_to_num(values, nan=x)'." 
- ) - - error_if_has_nans(val) - return val - - return no_nans - - -def validate_can_interpolate( - *field_names: str, -) -> Callable[[AbstractFieldDataset], AbstractFieldDataset]: - """Make sure the data in ``field_name`` can be interpolated.""" - - @field_validator(*field_names) - def check_fields_interpolate(val: AbstractFieldDataset) -> AbstractFieldDataset: - if isinstance(val, AbstractFieldDataset): - for name, data in val.field_components.items(): - if isinstance(data, ScalarFieldDataArray): - data._interp_validator(name) - return val +# marked as migrated to _common +from __future__ import annotations - return check_fields_interpolate +from tidy3d._common.components.data.validators import ( + validate_can_interpolate, + validate_no_nans, +) diff --git a/tidy3d/components/data/zbf.py b/tidy3d/components/data/zbf.py index dedd22660b..6827e6a4c6 100644 --- a/tidy3d/components/data/zbf.py +++ b/tidy3d/components/data/zbf.py @@ -1,156 +1,10 @@ -"""ZBF utilities""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.zbf`.""" -from __future__ import annotations - -from struct import unpack - -import numpy as np -from pydantic import Field - -from tidy3d.components.base import Tidy3dBaseModel - - -class ZBFData(Tidy3dBaseModel): - """ - Contains data read in from a ``.zbf`` file - """ - - version: int = Field(title="Version", description="File format version number.") - nx: int = Field(title="Samples in X", description="Number of samples in the x direction.") - ny: int = Field(title="Samples in Y", description="Number of samples in the y direction.") - ispol: bool = Field( - title="Is Polarized", - description="``True`` if the beam is polarized, ``False`` otherwise.", - ) - unit: str = Field( - title="Spatial Units", description="Spatial units, either 'mm', 'cm', 'in', or 'm'." 
- ) - dx: float = Field(title="Grid Spacing, X", description="Grid spacing in x.") - dy: float = Field(title="Grid Spacing, Y", description="Grid spacing in y.") - zposition_x: float = Field( - title="Z Position, X Direction", - description="The pilot beam z position with respect to the pilot beam waist, x direction.", - ) - zposition_y: float = Field( - title="Z Position, Y Direction", - description="The pilot beam z position with respect to the pilot beam waist, y direction.", - ) - rayleigh_x: float = Field( - title="Rayleigh Distance, X Direction", - description="The pilot beam Rayleigh distance in the x direction.", - ) - rayleigh_y: float = Field( - title="Rayleigh Distance, Y Direction", - description="The pilot beam Rayleigh distance in the y direction.", - ) - waist_x: float = Field( - title="Beam Waist, X", description="The pilot beam waist in the x direction." - ) - waist_y: float = Field( - title="Beam Waist, Y", description="The pilot beam waist in the y direction." - ) - wavelength: float = Field(title="Wavelength", description="The wavelength of the beam.") - background_refractive_index: float = Field( - title="Background Refractive Index", - description="The index of refraction in the current medium.", - ) - receiver_eff: float = Field( - title="Receiver Efficiency", - description="The receiver efficiency. Zero if fiber coupling is not computed.", - ) - system_eff: float = Field( - title="System Efficiency", - description="The system efficiency. 
Zero if fiber coupling is not computed.", - ) - Ex: np.ndarray = Field( - title="Electric Field, X Component", - description="Complex-valued electric field, x component.", - ) - Ey: np.ndarray = Field( - title="Electric Field, Y Component", - description="Complex-valued electric field, y component.", - ) - - def read_zbf(filename: str) -> ZBFData: - """Reads a Zemax Beam File (``.zbf``) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - Parameters - ---------- - filename : str - The file name of the ``.zbf`` file to read. - - Returns - ------- - :class:`.ZBFData` - """ - - # Read the zbf file - with open(filename, "rb") as f: - # Load the header - version, nx, ny, ispol, units = unpack("<5I", f.read(20)) - f.read(16) # unused values - ( - dx, - dy, - zposition_x, - rayleigh_x, - waist_x, - zposition_y, - rayleigh_y, - waist_y, - wavelength, - background_refractive_index, - receiver_eff, - system_eff, - ) = unpack("<12d", f.read(96)) - f.read(64) # unused values - - # read E field - nsamps = 2 * nx * ny - rawx = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) - if ispol: - rawy = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) - - # convert unit key to unit string - map_units = {0: "mm", 1: "cm", 2: "in", 3: "m"} - try: - unit = map_units[units] - except KeyError: - raise KeyError( - f"Invalid units specified in the zbf file (expected '0', '1', '2', or '3', got '{units}')." 
- ) from None - - # load E field - Ex_real = np.asarray(rawx[0::2]).reshape(nx, ny, order="F") - Ex_imag = np.asarray(rawx[1::2]).reshape(nx, ny, order="F") - if ispol: - Ey_real = np.asarray(rawy[0::2]).reshape(nx, ny, order="F") - Ey_imag = np.asarray(rawy[1::2]).reshape(nx, ny, order="F") - else: - Ey_real = np.zeros((nx, ny)) - Ey_imag = np.zeros((nx, ny)) - - Ex = Ex_real + 1j * Ex_imag - Ey = Ey_real + 1j * Ey_imag +# marked as migrated to _common +from __future__ import annotations - return ZBFData( - version=version, - nx=nx, - ny=ny, - ispol=ispol, - unit=unit, - dx=dx, - dy=dy, - zposition_x=zposition_x, - zposition_y=zposition_y, - rayleigh_x=rayleigh_x, - rayleigh_y=rayleigh_y, - waist_x=waist_x, - waist_y=waist_y, - wavelength=wavelength, - background_refractive_index=background_refractive_index, - receiver_eff=receiver_eff, - system_eff=system_eff, - Ex=Ex, - Ey=Ey, - ) +from tidy3d._common.components.data.zbf import ( + ZBFData, +) diff --git a/tidy3d/components/file_util.py b/tidy3d/components/file_util.py index 51e13f586d..54b1ab4c7d 100644 --- a/tidy3d/components/file_util.py +++ b/tidy3d/components/file_util.py @@ -1,83 +1,12 @@ -"""File compression utilities""" +"""Compatibility shim for :mod:`tidy3d._common.components.file_util`.""" -from __future__ import annotations - -import gzip -import pathlib -import shutil -from typing import TYPE_CHECKING, Any - -import numpy as np - -if TYPE_CHECKING: - from io import BytesIO - from os import PathLike - - -def compress_file_to_gzip(input_file: PathLike, output_gz_file: PathLike | BytesIO) -> None: - """ - Compress a file using gzip. - - Parameters - ---------- - input_file : PathLike - The path to the input file. - output_gz_file : PathLike | BytesIO - The path to the output gzip file or an in-memory buffer. 
- """ - input_file = pathlib.Path(input_file) - with input_file.open("rb") as file_in: - with gzip.open(output_gz_file, "wb") as file_out: - shutil.copyfileobj(file_in, file_out) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def extract_gzip_file(input_gz_file: PathLike, output_file: PathLike) -> None: - """ - Extract a gzip-compressed file. - - Parameters - ---------- - input_gz_file : PathLike - The path to the gzip-compressed input file. - output_file : PathLike - The path to the extracted output file. - """ - input_path = pathlib.Path(input_gz_file) - output_path = pathlib.Path(output_file) - with gzip.open(input_path, "rb") as file_in: - with output_path.open("wb") as file_out: - shutil.copyfileobj(file_in, file_out) - - -def replace_values(values: Any, search_value: Any, replace_value: Any) -> Any: - """ - Create a copy of ``values`` where any elements equal to ``search_value`` are replaced by ``replace_value``. - - Parameters - ---------- - values : Any - The input object to iterate through. - search_value : Any - An object to match for in ``values``. - replace_value : Any - A replacement object for the matched value in ``values``. - - Returns - ------- - Any - values type object with ``search_value`` terms replaced by ``replace_value``. 
- """ - # np.all allows for arrays to be evaluated - if np.all(values == search_value): - return replace_value - if isinstance(values, dict): - return { - key: replace_values(val, search_value, replace_value) for key, val in values.items() - } - if isinstance( - values, (tuple, list) - ): # Parts of the nested dict structure include tuples with more dicts - return type(values)(replace_values(val, search_value, replace_value) for val in values) - - # Used to maintain values that are not search_value or containers - return values +from tidy3d._common.components.file_util import ( + compress_file_to_gzip, + extract_gzip_file, + replace_values, +) diff --git a/tidy3d/components/geometry/__init__.py b/tidy3d/components/geometry/__init__.py index e69de29bb2..a8ed42d8cc 100644 --- a/tidy3d/components/geometry/__init__.py +++ b/tidy3d/components/geometry/__init__.py @@ -0,0 +1,8 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.geometry`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common +from __future__ import annotations + +import tidy3d._common.components.geometry as _module diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index b570a4aeee..5bfe288e48 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -1,3709 +1,24 @@ -"""Abstract base classes for geometry.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.base`.""" -from __future__ import annotations - -import functools -import pathlib -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import autograd.numpy as np -import shapely -from pydantic import Field, NonNegativeFloat, field_validator, model_validator +# marked as migrated to _common +from __future__ import annotations -from tidy3d.compat import _package_is_older_than -from 
tidy3d.components.autograd import TracedCoordinate, TracedFloat, TracedSize, get_static -from tidy3d.components.base import Tidy3dBaseModel, cached_property -from tidy3d.components.geometry.bound_ops import bounds_intersection, bounds_union -from tidy3d.components.geometry.float_utils import increment_float -from tidy3d.components.transformation import ReflectionFromPlane, RotationAroundAxis -from tidy3d.components.types import Axis, ClipOperationType, MatrixReal4x4, PlanePosition -from tidy3d.components.types.base import discriminated_union -from tidy3d.components.viz import ( - ARROW_LENGTH, - PLOT_BUFFER, - add_ax_if_none, - arrow_style, - equal_aspect, - plot_params_geometry, - polygon_patch, - set_default_labels_and_title, -) -from tidy3d.constants import LARGE_NUMBER, MICROMETER, RADIAN, fp_eps, inf -from tidy3d.exceptions import ( - SetupError, - Tidy3dError, - Tidy3dImportError, - Tidy3dKeyError, - ValidationError, +from tidy3d._common.components.geometry.base import ( + POLY_DISTANCE_TOLERANCE, + POLY_GRID_SIZE, + POLY_TOLERANCE_RATIO, + Box, + Centered, + Circular, + ClipOperation, + Geometry, + GeometryGroup, + Planar, + SimplePlaneIntersection, + Transformed, + _bit_operations, + _shapely_operations, + cleanup_shapely_object, ) -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - from os import PathLike - from typing import Callable, Union - - import pydantic - from gdstk import Cell - from matplotlib.backend_bases import Event - from matplotlib.patches import FancyArrowPatch - from numpy.typing import ArrayLike, NDArray - from pydantic import NonNegativeInt, PositiveFloat - from typing_extensions import Self - - from tidy3d.components.autograd import AutogradFieldMap - from tidy3d.components.autograd.derivative_utils import DerivativeInfo - from tidy3d.components.types import ( - ArrayFloat2D, - ArrayFloat3D, - Ax, - Bound, - Coordinate, - Coordinate2D, 
- LengthUnit, - Shapely, - Size, - ) - from tidy3d.components.viz import PlotParams, VisualizationSpec - -try: - from matplotlib import patches -except ImportError: - pass - -POLY_GRID_SIZE = 1e-12 -POLY_TOLERANCE_RATIO = 1e-12 -POLY_DISTANCE_TOLERANCE = 8e-12 - - -_shapely_operations = { - "union": shapely.union, - "intersection": shapely.intersection, - "difference": shapely.difference, - "symmetric_difference": shapely.symmetric_difference, -} - -_bit_operations = { - "union": lambda a, b: a | b, - "intersection": lambda a, b: a & b, - "difference": lambda a, b: a & ~b, - "symmetric_difference": lambda a, b: a != b, -} - - -class Geometry(Tidy3dBaseModel, ABC): - """Abstract base class, defines where something exists in space.""" - - @cached_property - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Geometry object.""" - return plot_params_geometry - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - - def point_inside(x: float, y: float, z: float) -> bool: - """Returns ``True`` if a single point ``(x, y, z)`` is inside.""" - shapes_intersect = self.intersections_plane(z=z) - loc = self.make_shapely_point(x, y) - return any(shape.contains(loc) for shape in shapes_intersect) - - arrays = tuple(map(np.array, (x, y, z))) - self._ensure_equal_shape(*arrays) - inside = np.zeros((arrays[0].size,), dtype=bool) - arrays_flat = map(np.ravel, arrays) - for ipt, args in enumerate(zip(*arrays_flat)): - inside[ipt] = point_inside(*args) - return inside.reshape(arrays[0].shape) - - @staticmethod - def _ensure_equal_shape(*arrays: Any) -> None: - """Ensure all input arrays have the same shape.""" - shapes = {np.array(arr).shape for arr in arrays} - if len(shapes) > 1: - raise ValueError("All coordinate inputs (x, y, z) must have the same shape.") - - @staticmethod - def make_shapely_box(minx: float, miny: float, maxx: float, maxy: float) -> shapely.box: - """Make a shapely box ensuring everything untraced.""" - - minx = get_static(minx) - miny = get_static(miny) - maxx = get_static(maxx) - maxy = get_static(maxy) - - return shapely.box(minx, miny, maxx, maxy) - - @staticmethod - def make_shapely_point(minx: float, miny: float) -> shapely.Point: - """Make a shapely Point ensuring everything untraced.""" - - minx = get_static(minx) - miny = get_static(miny) - - return shapely.Point(minx, miny) - - def _inds_inside_bounds( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> tuple[slice, slice, slice]: - """Return slices into the sorted input arrays that are inside the geometry bounds. - - Parameters - ---------- - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - tuple[slice, slice, slice] - Slices into each of the three arrays that are inside the geometry bounds. 
- """ - bounds = self.bounds - inds_in = [] - for dim, coords in enumerate([x, y, z]): - inds = np.nonzero((bounds[0][dim] <= coords) * (coords <= bounds[1][dim]))[0] - inds_in.append(slice(0, 0) if inds.size == 0 else slice(inds[0], inds[-1] + 1)) - - return tuple(inds_in) - - def inside_meshgrid( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> NDArray[bool]: - """Perform ``self.inside`` on a set of sorted 1D coordinates. Applies meshgrid to the - supplied coordinates before checking inside. - - Parameters - ---------- - - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every - point that is inside the geometry. - """ - - arrays = tuple(map(np.array, (x, y, z))) - if any(arr.ndim != 1 for arr in arrays): - raise ValueError("Each of the supplied coordinates (x, y, z) must be 1D.") - shape = tuple(arr.size for arr in arrays) - is_inside = np.zeros(shape, dtype=bool) - inds_inside = self._inds_inside_bounds(*arrays) - coords_inside = tuple(arr[ind] for ind, arr in zip(inds_inside, arrays)) - coords_3d = np.meshgrid(*coords_inside, indexing="ij") - is_inside[inds_inside] = self.inside(*coords_3d) - return is_inside - - @abstractmethod - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. 
- cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - origin = self.unpop_axis(position, (0, 0), axis=axis) - normal = self.unpop_axis(1, (0, 0), axis=axis) - to_2D = np.eye(4) - if axis != 2: - last, indices = self.pop_axis((0, 1, 2), axis) - to_2D = to_2D[[*list(indices), last, 3]] - return self.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - - def intersections_2dbox(self, plane: Box) -> list[Shapely]: - """Returns list of shapely geometries representing the intersections of the geometry with - a 2D box. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. For more details refer to - `Shapely's Documentation `_. - """ - log.warning( - "'intersections_2dbox()' is deprecated and will be removed in the future. " - "Use 'plane.intersections_with(...)' for the same functionality." - ) - return plane.intersections_with(self) - - def intersects( - self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] - ) -> bool: - """Returns ``True`` if two :class:`Geometry` have intersecting `.bounds`. - - Parameters - ---------- - other : :class:`Geometry` - Geometry to check intersection with. - strict_inequality : tuple[bool, bool, bool] = [False, False, False] - For each dimension, defines whether to include equality in the boundaries comparison. - If ``False``, equality is included, and two geometries that only intersect at their - boundaries will evaluate as ``True``. If ``True``, such geometries will evaluate as - ``False``. - - Returns - ------- - bool - Whether the rectangular bounding boxes of the two geometries intersect. - """ - - self_bmin, self_bmax = self.bounds - other_bmin, other_bmax = other.bounds - - for smin, omin, smax, omax, strict in zip( - self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality - ): - # are all of other's minimum coordinates less than self's maximum coordinate? 
- in_minus = omin < smax if strict else omin <= smax - # are all of other's maximum coordinates greater than self's minimum coordinate? - in_plus = omax > smin if strict else omax >= smin - - # if either failed, return False - if not all((in_minus, in_plus)): - return False - - return True - - def contains( - self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] - ) -> bool: - """Returns ``True`` if the `.bounds` of ``other`` are contained within the - `.bounds` of ``self``. - - Parameters - ---------- - other : :class:`Geometry` - Geometry to check containment with. - strict_inequality : tuple[bool, bool, bool] = [False, False, False] - For each dimension, defines whether to include equality in the boundaries comparison. - If ``False``, equality will be considered as contained. If ``True``, ``other``'s - bounds must be strictly within the bounds of ``self``. - - Returns - ------- - bool - Whether the rectangular bounding box of ``other`` is contained within the bounding - box of ``self``. - """ - - self_bmin, self_bmax = self.bounds - other_bmin, other_bmax = other.bounds - - for smin, omin, smax, omax, strict in zip( - self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality - ): - # are all of other's minimum coordinates greater than self's minimim coordinate? - in_minus = omin > smin if strict else omin >= smin - # are all of other's maximum coordinates less than self's maximum coordinate? - in_plus = omax < smax if strict else omax <= smax - - # if either failed, return False - if not all((in_minus, in_plus)): - return False - - return True - - def intersects_plane( - self, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None - ) -> bool: - """Whether self intersects plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. 
- y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - - Returns - ------- - bool - Whether this geometry intersects the plane. - """ - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - return self.intersects_axis_position(axis, position) - - def intersects_axis_position(self, axis: int, position: float) -> bool: - """Whether self intersects plane specified by a given position along a normal axis. - - Parameters - ---------- - axis : int = None - Axis normal to the plane. - position : float = None - Position of plane along the normal axis. - - Returns - ------- - bool - Whether this geometry intersects the plane. - """ - return self.bounds[0][axis] <= position <= self.bounds[1][axis] - - @cached_property - @abstractmethod - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - - @staticmethod - def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the intersection of two bounds.""" - return bounds_intersection(bounds1, bounds2) - - @staticmethod - def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the union of two bounds.""" - return bounds_union(bounds1, bounds2) - - @cached_property - def bounding_box(self) -> Box: - """Returns :class:`Box` representation of the bounding box of a :class:`Geometry`. - - Returns - ------- - :class:`Box` - Geometric object representing bounding box. 
- """ - return Box.from_bounds(*self.bounds) - - @cached_property - def zero_dims(self) -> list[Axis]: - """A list of axes along which the :class:`Geometry` is zero-sized based on its bounds.""" - zero_dims = [] - for dim in range(3): - if self.bounds[1][dim] == self.bounds[0][dim]: - zero_dims.append(dim) - return zero_dims - - def _pop_bounds(self, axis: Axis) -> tuple[Coordinate2D, tuple[Coordinate2D, Coordinate2D]]: - """Returns min and max bounds in plane normal to and tangential to ``axis``. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - tuple[float, float], tuple[tuple[float, float], tuple[float, float]] - Bounds along axis and a tuple of bounds in the ordered planar coordinates. - Packed as ``(zmin, zmax), ((xmin, ymin), (xmax, ymax))``. - """ - b_min, b_max = self.bounds - zmin, (xmin, ymin) = self.pop_axis(b_min, axis=axis) - zmax, (xmax, ymax) = self.pop_axis(b_max, axis=axis) - return (zmin, zmax), ((xmin, ymin), (xmax, ymax)) - - @staticmethod - def _get_center(pt_min: float, pt_max: float) -> float: - """Returns center point based on bounds along dimension.""" - if np.isneginf(pt_min) and np.isposinf(pt_max): - return 0.0 - if np.isneginf(pt_min) or np.isposinf(pt_max): - raise SetupError( - f"Bounds of ({pt_min}, {pt_max}) supplied along one dimension. " - "We currently don't support a single ``inf`` value in bounds for ``Box``. " - "To construct a semi-infinite ``Box``, " - "please supply a large enough number instead of ``inf``. " - "For example, a location extending outside of the " - "Simulation domain (including PML)." 
- ) - return (pt_min + pt_max) / 2.0 - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - raise ValidationError("'Medium2D' is not compatible with this geometry class.") - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geometry: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - raise NotImplementedError( - "'_update_from_bounds' is not compatible with this geometry class." - ) - - @equal_aspect - @add_ax_if_none - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - plot_length_units: LengthUnit = None, - viz_spec: VisualizationSpec = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot geometry cross section at single (x,y,z) coordinate. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - plot_length_units : LengthUnit = None - Specify units to use for axis labels, tick labels, and the title. - viz_spec : VisualizationSpec = None - Plotting parameters associated with a medium to use instead of defaults. - **patch_kwargs - Optional keyword arguments passed to the matplotlib patch plotting of structure. - For details on accepted values, refer to - `Matplotlib's documentation `_. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - - # find shapes that intersect self at plane - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - shapes_intersect = self.intersections_plane(x=x, y=y, z=z) - - plot_params = self.plot_params - if viz_spec is not None: - plot_params = plot_params.override_with_viz_spec(viz_spec) - plot_params = plot_params.include_kwargs(**patch_kwargs) - - # for each intersection, plot the shape - for shape in shapes_intersect: - ax = self.plot_shape(shape, plot_params=plot_params, ax=ax) - - # clean up the axis display - ax = self.add_ax_lims(axis=axis, ax=ax) - ax.set_aspect("equal") - # Add the default axis labels, tick labels, and title - ax = Box.add_ax_labels_and_title(ax=ax, x=x, y=y, z=z, plot_length_units=plot_length_units) - return ax - - def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax: - """Defines how a shape is plotted on a matplotlib axes.""" - if shape.geom_type in ( - "MultiPoint", - "MultiLineString", - "MultiPolygon", - "GeometryCollection", - ): - for sub_shape in shape.geoms: - ax = self.plot_shape(shape=sub_shape, plot_params=plot_params, ax=ax) - - return ax - - _shape = Geometry.evaluate_inf_shape(shape) - - if _shape.geom_type == "LineString": - xs, ys = zip(*_shape.coords) - ax.plot(xs, ys, color=plot_params.facecolor, linewidth=plot_params.linewidth) - elif _shape.geom_type == "Point": - ax.scatter(shape.x, shape.y, color=plot_params.facecolor) - else: - patch = polygon_patch(_shape, **plot_params.to_kwargs()) - ax.add_artist(patch) - return ax - - @staticmethod - def _do_not_intersect( - bounds_a: float, bounds_b: float, shape_a: Shapely, shape_b: Shapely - ) -> bool: - """Check whether two shapes intersect.""" - - # do a bounding box check to see if any intersection to do anything about - if ( - bounds_a[0] > bounds_b[2] - or bounds_b[0] > bounds_a[2] - or bounds_a[1] > bounds_b[3] - or bounds_b[1] > bounds_a[3] - ): - return True - - # look more closely to see if intersected. 
- if shape_b.is_empty or not shape_a.intersects(shape_b): - return True - - return False - - @staticmethod - def _get_plot_labels(axis: Axis) -> tuple[str, str]: - """Returns planar coordinate x and y axis labels for cross section plots. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - str, str - Labels of plot, packaged as ``(xlabel, ylabel)``. - """ - _, (xlabel, ylabel) = Geometry.pop_axis("xyz", axis=axis) - return xlabel, ylabel - - def _get_plot_limits( - self, axis: Axis, buffer: float = PLOT_BUFFER - ) -> tuple[Coordinate2D, Coordinate2D]: - """Gets planar coordinate limits for cross section plots. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - buffer : float = 0.3 - Amount of space to add around the limits on the + and - sides. - - Returns - ------- - tuple[float, float], tuple[float, float] - The x and y plot limits, packed as ``(xmin, xmax), (ymin, ymax)``. - """ - _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) - return (xmin - buffer, xmax + buffer), (ymin - buffer, ymax + buffer) - - def add_ax_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax: - """Sets the x,y limits based on ``self.bounds``. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - ax : matplotlib.axes._subplots.Axes - Matplotlib axes to add labels and limits on. - buffer : float = 0.3 - Amount of space to place around the limits on the + and - sides. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - (xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer) - - # note: axes limits dont like inf values, so we need to evaluate them first if present - xmin, xmax, ymin, ymax = self._evaluate_inf((xmin, xmax, ymin, ymax)) - - ax.set_xlim(xmin, xmax) - ax.set_ylim(ymin, ymax) - return ax - - @staticmethod - def add_ax_labels_and_title( - ax: Ax, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - plot_length_units: LengthUnit = None, - ) -> Ax: - """Sets the axis labels, tick labels, and title based on ``axis`` - and an optional ``plot_length_units`` argument. - - Parameters - ---------- - ax : matplotlib.axes._subplots.Axes - Matplotlib axes to add labels and limits on. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - plot_length_units : LengthUnit = None - When set to a supported ``LengthUnit``, plots will be produced with annotated axes - and title with the proper units. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied matplotlib axes. 
- """ - axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z) - axis_labels = Box._get_plot_labels(axis) - ax = set_default_labels_and_title( - axis_labels=axis_labels, - axis=axis, - position=position, - ax=ax, - plot_length_units=plot_length_units, - ) - return ax - - @staticmethod - def _evaluate_inf(array: ArrayLike) -> NDArray[np.floating]: - """Processes values and evaluates any infs into large (signed) numbers.""" - array = get_static(np.array(array)) - return np.where(np.isinf(array), np.sign(array) * LARGE_NUMBER, array) - - @staticmethod - def evaluate_inf_shape(shape: Shapely) -> Shapely: - """Returns a copy of shape with inf vertices replaced by large numbers if polygon.""" - if not any(np.isinf(b) for b in shape.bounds): - return shape - - def _processed_coords(coords: Sequence[tuple[Any, ...]]) -> list[tuple[float, ...]]: - evaluated = Geometry._evaluate_inf(np.array(coords)) - return [tuple(point) for point in evaluated.tolist()] - - if shape.geom_type == "Polygon": - shell = _processed_coords(shape.exterior.coords) - holes = [_processed_coords(g.coords) for g in shape.interiors] - return shapely.Polygon(shell, holes) - if shape.geom_type in {"Point", "LineString", "LinearRing"}: - return shape.__class__(Geometry._evaluate_inf(np.array(shape.coords))) - if shape.geom_type in { - "MultiPoint", - "MultiLineString", - "MultiPolygon", - "GeometryCollection", - }: - return shape.__class__([Geometry.evaluate_inf_shape(g) for g in shape.geoms]) - return shape - - @staticmethod - def pop_axis(coord: tuple[Any, Any, Any], axis: int) -> tuple[Any, tuple[Any, Any]]: - """Separates coordinate at ``axis`` index from coordinates on the plane tangent to ``axis``. - - Parameters - ---------- - coord : tuple[Any, Any, Any] - Tuple of three values in original coordinate system. - axis : int - Integer index into 'xyz' (0,1,2). 
- - Returns - ------- - Any, tuple[Any, Any] - The input coordinates are separated into the one along the axis provided - and the two on the planar coordinates, - like ``axis_coord, (planar_coord1, planar_coord2)``. - """ - plane_vals = list(coord) - axis_val = plane_vals.pop(axis) - return axis_val, tuple(plane_vals) - - @staticmethod - def unpop_axis(ax_coord: Any, plane_coords: tuple[Any, Any], axis: int) -> tuple[Any, Any, Any]: - """Combine coordinate along axis with coordinates on the plane tangent to the axis. - - Parameters - ---------- - ax_coord : Any - Value along axis direction. - plane_coords : tuple[Any, Any] - Values along ordered planar directions. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - tuple[Any, Any, Any] - The three values in the xyz coordinate system. - """ - coords = list(plane_coords) - coords.insert(axis, ax_coord) - return tuple(coords) - - @staticmethod - def parse_xyz_kwargs(**xyz: Any) -> tuple[Axis, float]: - """Turns x,y,z kwargs into index of the normal axis and position along that axis. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - - Returns - ------- - int, float - Index into xyz axis (0,1,2) and position along that axis. - """ - xyz_filtered = {k: v for k, v in xyz.items() if v is not None} - if len(xyz_filtered) != 1: - raise ValueError("exactly one kwarg in [x,y,z] must be specified.") - axis_label, position = list(xyz_filtered.items())[0] - axis = "xyz".index(axis_label) - return axis, position - - @staticmethod - def parse_two_xyz_kwargs(**xyz: Any) -> list[tuple[Axis, float]]: - """Turns x,y,z kwargs into indices of axes and the position along each axis. 
- - Parameters - ---------- - x : float = None - Position in x direction, only two of x,y,z can be specified to define line. - y : float = None - Position in y direction, only two of x,y,z can be specified to define line. - z : float = None - Position in z direction, only two of x,y,z can be specified to define line. - - Returns - ------- - [(int, float), (int, float)] - Index into xyz axis (0,1,2) and position along that axis. - """ - xyz_filtered = {k: v for k, v in xyz.items() if v is not None} - assert len(xyz_filtered) == 2, "exactly two kwarg in [x,y,z] must be specified." - xyz_list = list(xyz_filtered.items()) - return [("xyz".index(axis_label), position) for axis_label, position in xyz_list] - - @staticmethod - def rotate_points(points: ArrayFloat3D, axis: Coordinate, angle: float) -> ArrayFloat3D: - """Rotate a set of points in 3D. - - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - axis : Coordinate - Axis of rotation - angle : float - Angle of rotation counter-clockwise around the axis (rad). - """ - rotation = RotationAroundAxis(axis=axis, angle=angle) - return rotation.rotate_vector(points) - - def reflect_points( - self, - points: ArrayFloat3D, - polar_axis: Axis, - angle_theta: float, - angle_phi: float, - ) -> ArrayFloat3D: - """Reflect a set of points in 3D at a plane passing through the coordinate origin defined - and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the - ``polar_axis`` which can be 0, 1, or 2. - - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - polar_axis : Axis - Cartesian axis w.r.t. which the normal axis angles are defined. - angle_theta : float - Polar angle w.r.t. the polar axis. - angle_phi : float - Azimuth angle around the polar axis. 
- """ - - # Rotate such that the plane normal is along the polar_axis - axis_theta, axis_phi = [0, 0, 0], [0, 0, 0] - axis_phi[polar_axis] = 1 - plane_axes = [0, 1, 2] - plane_axes.pop(polar_axis) - axis_theta[plane_axes[1]] = 1 - points_new = self.rotate_points(points, axis_phi, -angle_phi) - points_new = self.rotate_points(points_new, axis_theta, -angle_theta) - - # Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane - points_new[polar_axis, :] *= -1 - - # Rotate back - points_new = self.rotate_points(points_new, axis_theta, angle_theta) - points_new = self.rotate_points(points_new, axis_phi, angle_phi) - - return points_new - - def volume(self, bounds: Bound = None) -> float: - """Returns object's volume with optional bounds. - - Parameters - ---------- - bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - float - Volume in um^3. - """ - - if not bounds: - bounds = self.bounds - - return self._volume(bounds) - - @abstractmethod - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - def surface_area(self, bounds: Bound = None) -> float: - """Returns object's surface area with optional bounds. - - Parameters - ---------- - bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - float - Surface area in um^2. - """ - - if not bounds: - bounds = self.bounds - - return self._surface_area(bounds) - - @abstractmethod - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - def translated(self, x: float, y: float, z: float) -> Geometry: - """Return a translated copy of this geometry. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. 
- z : float - Translation along z. - - Returns - ------- - :class:`Geometry` - Translated copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.translation(x, y, z)) - - def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> Geometry: - """Return a scaled copy of this geometry. - - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - :class:`Geometry` - Scaled copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.scaling(x, y, z)) - - def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> Geometry: - """Return a rotated copy of this geometry. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - :class:`Geometry` - Rotated copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.rotation(angle, axis)) - - def reflected(self, normal: Coordinate) -> Geometry: - """Return a reflected copy of this geometry. - - Parameters - ---------- - normal : tuple[float, float, float] - The 3D normal vector of the plane of reflection. The plane is assumed - to pass through the origin (0,0,0). - - Returns - ------- - :class:`Geometry` - Reflected copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.reflection(normal)) - - """ Field and coordinate transformations """ - - @staticmethod - def car_2_sph(x: float, y: float, z: float) -> tuple[float, float, float]: - """Convert Cartesian to spherical coordinates. - - Parameters - ---------- - x : float - x coordinate relative to ``local_origin``. - y : float - y coordinate relative to ``local_origin``. - z : float - z coordinate relative to ``local_origin``. 
- - Returns - ------- - tuple[float, float, float] - r, theta, and phi coordinates relative to ``local_origin``. - """ - r = np.sqrt(x**2 + y**2 + z**2) - theta = np.arccos(z / r) - phi = np.arctan2(y, x) - return r, theta, phi - - @staticmethod - def sph_2_car(r: float, theta: float, phi: float) -> tuple[float, float, float]: - """Convert spherical to Cartesian coordinates. - - Parameters - ---------- - r : float - radius. - theta : float - polar angle (rad) downward from x=y=0 line. - phi : float - azimuthal (rad) angle from y=z=0 line. - - Returns - ------- - tuple[float, float, float] - x, y, and z coordinates relative to ``local_origin``. - """ - r_sin_theta = r * np.sin(theta) - x = r_sin_theta * np.cos(phi) - y = r_sin_theta * np.sin(phi) - z = r * np.cos(theta) - return x, y, z - - @staticmethod - def sph_2_car_field( - f_r: float, f_theta: float, f_phi: float, theta: float, phi: float - ) -> tuple[complex, complex, complex]: - """Convert vector field components in spherical coordinates to cartesian. - - Parameters - ---------- - f_r : float - radial component of the vector field. - f_theta : float - polar angle component of the vector fielf. - f_phi : float - azimuthal angle component of the vector field. - theta : float - polar angle (rad) of location of the vector field. - phi : float - azimuthal angle (rad) of location of the vector field. - - Returns - ------- - tuple[float, float, float] - x, y, and z components of the vector field in cartesian coordinates. 
- """ - sin_theta = np.sin(theta) - cos_theta = np.cos(theta) - sin_phi = np.sin(phi) - cos_phi = np.cos(phi) - f_x = f_r * sin_theta * cos_phi + f_theta * cos_theta * cos_phi - f_phi * sin_phi - f_y = f_r * sin_theta * sin_phi + f_theta * cos_theta * sin_phi + f_phi * cos_phi - f_z = f_r * cos_theta - f_theta * sin_theta - return f_x, f_y, f_z - - @staticmethod - def car_2_sph_field( - f_x: float, f_y: float, f_z: float, theta: float, phi: float - ) -> tuple[complex, complex, complex]: - """Convert vector field components in cartesian coordinates to spherical. - - Parameters - ---------- - f_x : float - x component of the vector field. - f_y : float - y component of the vector fielf. - f_z : float - z component of the vector field. - theta : float - polar angle (rad) of location of the vector field. - phi : float - azimuthal angle (rad) of location of the vector field. - - Returns - ------- - tuple[float, float, float] - radial (s), elevation (theta), and azimuthal (phi) components - of the vector field in spherical coordinates. - """ - sin_theta = np.sin(theta) - cos_theta = np.cos(theta) - sin_phi = np.sin(phi) - cos_phi = np.cos(phi) - f_r = f_x * sin_theta * cos_phi + f_y * sin_theta * sin_phi + f_z * cos_theta - f_theta = f_x * cos_theta * cos_phi + f_y * cos_theta * sin_phi - f_z * sin_theta - f_phi = -f_x * sin_phi + f_y * cos_phi - return f_r, f_theta, f_phi - - @staticmethod - def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]: - """Convert normalized k-space coordinates to angles. - - Parameters - ---------- - ux : float - normalized kx coordinate. - uy : float - normalized ky coordinate. - axis : int - axis along which the observation plane is oriented. - - Returns - ------- - tuple[float, float] - theta and phi coordinates relative to ``local_origin``. 
- """ - phi_local = np.arctan2(uy, ux) - with np.errstate(invalid="ignore"): - theta_local = np.arcsin(np.sqrt(ux**2 + uy**2)) - # Spherical coordinates rotation matrix reference: - # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation - if axis == 2: - return theta_local, phi_local - - x = np.cos(theta_local) - y = np.sin(theta_local) * np.cos(phi_local) - z = np.sin(theta_local) * np.sin(phi_local) - - if axis == 1: - x, y, z = y, x, z - - theta = np.arccos(z) - phi = np.arctan2(y, x) - return theta, phi - - @staticmethod - @verify_packages_import(["gdstk"]) - def load_gds_vertices_gdstk( - gds_cell: Cell, - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: PositiveFloat = 1.0, - ) -> list[ArrayFloat2D]: - """Load polygon vertices from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into - the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of micrometer. For example, if gds file uses - nanometers, set ``gds_scale=1e-3``. Must be positive. 
- - Returns - ------- - list[ArrayFloat2D] - List of polygon vertices - """ - - # apply desired scaling and load the polygon vertices - if gds_dtype is not None: - # if both layer and datatype are specified, let gdstk do the filtering for better - # performance on large layouts - all_vertices = [ - polygon.scale(gds_scale).points - for polygon in gds_cell.get_polygons(layer=gds_layer, datatype=gds_dtype) - ] - else: - all_vertices = [ - polygon.scale(gds_scale).points - for polygon in gds_cell.get_polygons() - if polygon.layer == gds_layer - ] - # make sure something got loaded, otherwise error - if not all_vertices: - raise Tidy3dKeyError( - f"Couldn't load gds_cell, no vertices found at gds_layer={gds_layer} " - f"with specified gds_dtype={gds_dtype}." - ) - - return all_vertices - - @staticmethod - @verify_packages_import(["gdstk"]) - def from_gds( - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> Geometry: - """Import a ``gdstk.Cell`` and extrude it into a GeometryGroup. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into - the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of micrometer. For example, if gds file uses - nanometers, set ``gds_scale=1e-3``. Must be positive. - dilation : float = 0.0 - Dilation (positive) or erosion (negative) amount to be applied to the original polygons. 
- sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometries created from the 2D data. - """ - import gdstk - - if not isinstance(gds_cell, gdstk.Cell): - # Check if it might be a gdstk cell but gdstk is not found (should be caught by decorator) - # or if it's an entirely different type. - if "gdstk" in gds_cell.__class__.__name__.lower(): - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to import gdstk cells." - ) - raise Tidy3dImportError("Argument 'gds_cell' must be an instance of 'gdstk.Cell'.") - - gds_loader_fn = Geometry.load_gds_vertices_gdstk - geometries = [] - with log as consolidated_logger: - for vertices in gds_loader_fn(gds_cell, gds_layer, gds_dtype, gds_scale): - # buffer(0) is necessary to merge self-intersections - shape = shapely.set_precision(shapely.Polygon(vertices).buffer(0), POLY_GRID_SIZE) - try: - geometries.append( - from_shapely( - shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane - ) - ) - except ValidationError as error: - consolidated_logger.warning(str(error)) - except Tidy3dError as error: - consolidated_logger.warning(str(error)) - return geometries[0] if len(geometries) == 1 else GeometryGroup(geometries=geometries) - - @staticmethod - def from_shapely( - shape: Shapely, - axis: Axis, - slab_bounds: tuple[float, float], - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> Geometry: - """Convert a shapely 
primitive into a geometry instance by extrusion. - - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to be converted. It must be a linear ring, a polygon or a collection - of any of those. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - dilation : float - Dilation of the polygon in the base by shifting each edge along its normal outwards - direction by a distance; a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometry extruded from the 2D data. - """ - return from_shapely(shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane) - - @verify_packages_import(["gdstk"]) - def to_gdstk( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: NonNegativeInt = 0, - gds_dtype: NonNegativeInt = 0, - ) -> list: - """Convert a Geometry object's planar slice to a .gds type polygon. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. 
- gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - - Return - ------ - List - List of `gdstk.Polygon`. - """ - import gdstk - - shapes = self.intersections_plane(x=x, y=y, z=z) - polygons = [] - for shape in shapes: - for vertices in vertices_from_shapely(shape): - if len(vertices) == 1: - polygons.append(gdstk.Polygon(vertices[0], gds_layer, gds_dtype)) - else: - polygons.extend( - gdstk.boolean( - vertices[:1], - vertices[1:], - "not", - layer=gds_layer, - datatype=gds_dtype, - ) - ) - return polygons - - @verify_packages_import(["gdstk"]) - def to_gds( - self, - cell: Cell, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: NonNegativeInt = 0, - gds_dtype: NonNegativeInt = 0, - ) -> None: - """Append a Geometry object's planar slice to a .gds cell. - - Parameters - ---------- - cell : ``gdstk.Cell`` - Cell object to which the generated polygons are added. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - """ - import gdstk - - if not isinstance(cell, gdstk.Cell): - if "gdstk" in cell.__class__.__name__.lower(): - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to export shapes to gdstk cells." 
- ) - raise Tidy3dImportError("Argument 'cell' must be an instance of 'gdstk.Cell'.") - - polygons = self.to_gdstk(x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) - if polygons: - cell.add(*polygons) - - @verify_packages_import(["gdstk"]) - def to_gds_file( - self, - fname: PathLike, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: NonNegativeInt = 0, - gds_dtype: NonNegativeInt = 0, - gds_cell_name: str = "MAIN", - ) -> None: - """Export a Geometry object's planar slice to a .gds file. - - Parameters - ---------- - fname : PathLike - Full path to the .gds file to save the :class:`Geometry` slice to. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - gds_cell_name : str = 'MAIN' - Name of the cell created in the .gds file to store the geometry. - """ - try: - import gdstk - except ImportError as e: - raise Tidy3dImportError( - "Python module 'gdstk' not found. To export geometries to .gds " - "files, please install it." 
- ) from e - - library = gdstk.Library() - cell = library.new_cell(gds_cell_name) - self.to_gds(cell, x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) - fname = pathlib.Path(fname) - fname.parent.mkdir(parents=True, exist_ok=True) - library.write_gds(fname) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") - - def _as_union(self) -> list[Geometry]: - """Return a list of geometries that, united, make up the given geometry.""" - if isinstance(self, GeometryGroup): - return self.geometries - - if isinstance(self, ClipOperation) and self.operation == "union": - return (self.geometry_a, self.geometry_b) - return (self,) - - def __add__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: - """Union of geometries""" - # This allows the user to write sum(geometries...) with the default start=0 - if isinstance(other, int): - return self - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return GeometryGroup(geometries=self._as_union() + other._as_union()) - - def __radd__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: - """Union of geometries""" - # This allows the user to write sum(geometries...) 
with the default start=0 - if isinstance(other, int): - return self - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return GeometryGroup(geometries=other._as_union() + self._as_union()) - - def __or__(self, other: Geometry) -> GeometryGroup: - """Union of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return GeometryGroup(geometries=self._as_union() + other._as_union()) - - def __mul__(self, other: Geometry) -> ClipOperation: - """Intersection of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other) - - def __and__(self, other: Geometry) -> ClipOperation: - """Intersection of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other) - - def __sub__(self, other: Geometry) -> ClipOperation: - """Difference of geometries""" - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return ClipOperation(operation="difference", geometry_a=self, geometry_b=other) - - def __xor__(self, other: Geometry) -> ClipOperation: - """Symmetric difference of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="symmetric_difference", geometry_a=self, geometry_b=other) - - def __pos__(self) -> Self: - """No op""" - return self - - def __neg__(self) -> ClipOperation: - """Opposite of a geometry""" - return ClipOperation( - operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self - ) - - def __invert__(self) -> ClipOperation: - """Opposite of a geometry""" - return ClipOperation( - operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self - ) - - -""" Abstract subclasses """ - - -class Centered(Geometry, ABC): - """Geometry with a well defined center.""" - - center: 
    Optional[TracedCoordinate] = Field(
        None,
        title="Center",
        description="Center of object in x, y, and z.",
        json_schema_extra={"units": MICROMETER},
    )

    @field_validator("center", mode="before")
    @classmethod
    def _center_default(cls, val: Any) -> Any:
        """Default the center to the origin when no value is provided."""
        if val is None:
            val = (0.0, 0.0, 0.0)
        return val

    @field_validator("center")
    @classmethod
    def _center_not_inf(cls, val: tuple[float, float, float]) -> tuple[float, float, float]:
        """Make sure center is not infinity."""
        if any(np.isinf(v) for v in val):
            raise ValidationError("center can not contain td.inf terms.")
        return val


class SimplePlaneIntersection(Geometry, ABC):
    """A geometry where intersections with an axis aligned plane may be computed efficiently."""

    def intersections_tilted_plane(
        self,
        normal: Coordinate,
        origin: Coordinate,
        to_2D: MatrixReal4x4,
        cleanup: bool = True,
        quad_segs: Optional[int] = None,
    ) -> list[Shapely]:
        """Return a list of shapely geometries at the plane specified by normal and origin.
        Checks special cases before relying on the complete computation.

        Parameters
        ----------
        normal : Coordinate
            Vector defining the normal direction to the plane.
        origin : Coordinate
            Vector defining the plane origin.
        to_2D : MatrixReal4x4
            Transformation matrix to apply to resulting shapes.
        cleanup : bool = True
            If True, removes extremely small features from each polygon's boundary.
        quad_segs : Optional[int] = None
            Number of segments used to discretize circular shapes. If ``None``, uses
            high-quality visualization settings.

        Returns
        -------
        list[shapely.geometry.base.BaseGeometry]
            List of 2D shapes that intersect plane.
            For more details refer to
            `Shapely's Documentation `_.
        """

        # Check if normal is a special case, where the normal is aligned with an axis.
- if np.sum(np.isclose(normal, 0.0)) == 2: - axis = np.argmax(np.abs(normal)).item() - coord = "xyz"[axis] - kwargs = {coord: origin[axis]} - section = self.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **kwargs) - # Apply transformation in the plane by removing row and column - to_2D_in_plane = np.delete(np.delete(to_2D, 2, 0), axis, 1) - - def transform(p_array: NDArray) -> NDArray: - return np.dot( - np.hstack((p_array, np.ones((p_array.shape[0], 1)))), to_2D_in_plane.T - )[:, :2] - - transformed_section = shapely.transform(section, transformation=transform) - return transformed_section - # Otherwise compute the arbitrary intersection - return self._do_intersections_tilted_plane( - normal=normal, origin=origin, to_2D=to_2D, quad_segs=quad_segs - ) - - @abstractmethod - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - -class Planar(SimplePlaneIntersection, Geometry, ABC): - """Geometry with one ``axis`` that is slab-like with thickness ``height``.""" - - axis: Axis = Field( - 2, - title="Axis", - description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z).", - ) - - sidewall_angle: TracedFloat = Field( - 0.0, - title="Sidewall angle", - description="Angle of the sidewall. 
" - "``sidewall_angle=0`` (default) specifies a vertical wall; " - "``0 float: - lower_bound = -np.pi / 2 - upper_bound = np.pi / 2 - if (val <= lower_bound) or (val >= upper_bound): - # u03C0 is unicode for pi - raise ValidationError(f"Sidewall angle ({val}) must be between -π/2 and π/2 rad.") - return val - - @property - @abstractmethod - def center_axis(self) -> float: - """Gets the position of the center of the geometry in the out of plane dimension.""" - - @property - @abstractmethod - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - - @property - def finite_length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension. - If the length is td.inf, return ``LARGE_NUMBER`` - """ - return min(self.length_axis, LARGE_NUMBER) - - @property - def reference_axis_pos(self) -> float: - """Coordinate along the slab axis at the reference plane. - - Returns the axis coordinate corresponding to the selected - reference_plane: - - "bottom": lower bound of slab_bounds - - "middle": center_axis - - "top": upper bound of slab_bounds - """ - if self.reference_plane == "bottom": - return self.slab_bounds[0] - if self.reference_plane == "top": - return self.slab_bounds[1] - # default to middle - return self.center_axis - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns shapely geometry at plane specified by one non None value of x,y,z. - - Parameters - ---------- - x : float - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float - Position of plane in z direction, only one of x,y,z can be specified to define plane. 
- cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation ``. - """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - if axis == self.axis: - return self._intersections_normal(position, quad_segs=quad_segs) - return self._intersections_side(position, axis) - - @abstractmethod - def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list: - """Find shapely geometries intersecting planar geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - @abstractmethod - def _intersections_side(self, position: float, axis: Axis) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis orthogonal to plane. - - Parameters - ---------- - position : float - Position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - def _order_axis(self, axis: int) -> int: - """Order the axis as if self.axis is along z-direction. - - Parameters - ---------- - axis : int - Integer index into the structure's planar axis. - - Returns - ------- - int - New index of axis. 
- """ - axis_index = [0, 1] - axis_index.insert(self.axis, 2) - return axis_index[axis] - - def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> tuple[Any, Any]: - """Orders a value in the plane and value along axis in correct (x,y) order for plotting. - Note: sometimes if axis=1 and we compute cross section values orthogonal to axis, - they can either be x or y in the plots. - This function allows one to figure out the ordering. - - Parameters - ---------- - plane_val : Any - The value in the planar coordinate. - axis_val : Any - The value in the ``axis`` coordinate. - axis : int - Integer index into the structure's planar axis. - - Returns - ------- - ``(Any, Any)`` - The two planar coordinates in this new coordinate system. - """ - vals = 3 * [plane_val] - vals[self.axis] = axis_val - _, (val_x, val_y) = self.pop_axis(vals, axis=axis) - return val_x, val_y - - @cached_property - def _tanq(self) -> float: - """Value of ``tan(sidewall_angle)``. - - The (possibliy infinite) geometry offset is given by ``_tanq * length_axis``. - """ - return np.tan(self.sidewall_angle) - - -class Circular(Geometry): - """Geometry with circular characteristics (specified by a radius).""" - - radius: NonNegativeFloat = Field( - title="Radius", - description="Radius of geometry.", - json_schema_extra={"units": MICROMETER}, - ) - - @field_validator("radius") - @classmethod - def _radius_not_inf(cls, val: float) -> float: - """Make sure center is not infinitiy.""" - if np.isinf(val): - raise ValidationError("radius can not be 'td.inf'.") - return val - - def _intersect_dist(self, position: float, z0: float) -> float: - """Distance between points on circle at z=position where center of circle at z=z0. - - Parameters - ---------- - position : float - position along z. - z0 : float - center of circle in z. - - Returns - ------- - float - Distance between points on the circle intersecting z=z, if no points, ``None``. 
- """ - dz = np.abs(z0 - position) - if dz > self.radius: - return None - return 2 * np.sqrt(self.radius**2 - dz**2) - - -"""Primitive classes""" - - -class Box(SimplePlaneIntersection, Centered): - """Rectangular prism. - Also base class for :class:`.Simulation`, :class:`Monitor`, and :class:`Source`. - - Example - ------- - >>> b = Box(center=(1,2,3), size=(2,2,2)) - """ - - size: TracedSize = Field( - title="Size", - description="Size in x, y, and z directions.", - json_schema_extra={"units": MICROMETER}, - ) - - @classmethod - def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs: Any) -> Self: - """Constructs a :class:`Box` from minimum and maximum coordinate bounds - - Parameters - ---------- - rmin : tuple[float, float, float] - (x, y, z) coordinate of the minimum values. - rmax : tuple[float, float, float] - (x, y, z) coordinate of the maximum values. - - Example - ------- - >>> b = Box.from_bounds(rmin=(-1, -2, -3), rmax=(3, 2, 1)) - """ - - center = tuple(cls._get_center(pt_min, pt_max) for pt_min, pt_max in zip(rmin, rmax)) - size = tuple((pt_max - pt_min) for pt_min, pt_max in zip(rmin, rmax)) - return cls(center=center, size=size, **kwargs) - - @cached_property - def _normal_axis(self) -> Axis: - """Axis normal to the Box. Errors if box is not planar.""" - if self.size.count(0.0) != 1: - raise ValidationError( - f"Tried to get 'normal_axis' of 'Box' that is not planar. Given 'size={self.size}.'" - ) - return self.size.index(0.0) - - @classmethod - def surfaces(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]: - """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume. - The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z - denote which axis is perpendicular to that surface, while "-" and "+" denote the direction - of the normal vector of that surface. 
If a name is provided, each output surface's name - will be that of the provided name appended with the above symbols. E.g., if the provided - name is "box", the x+ surfaces's name will be "box_x+". - - Parameters - ---------- - size : tuple[float, float, float] - Size of object in x, y, and z directions. - center : tuple[float, float, float] - Center of object in x, y, and z. - - Example - ------- - >>> b = Box.surfaces(size=(1, 2, 3), center=(3, 2, 1)) - """ - - if any(s == 0.0 for s in size): - raise SetupError( - "Can't generate surfaces for the given object because it has zero volume." - ) - - bounds = Box(center=center, size=size).bounds - - # Set up geometry data and names for each surface: - centers = [list(center) for _ in range(6)] - sizes = [list(size) for _ in range(6)] - - surface_index = 0 - for dim_index in range(3): - for min_max_index in range(2): - new_center = centers[surface_index] - new_size = sizes[surface_index] - - new_center[dim_index] = bounds[min_max_index][dim_index] - new_size[dim_index] = 0.0 - - centers[surface_index] = new_center - sizes[surface_index] = new_size - - surface_index += 1 - - name_base = kwargs.pop("name", "") - kwargs.pop("normal_dir", None) - - names = [] - normal_dirs = [] - - for coord in "xyz": - for direction in "-+": - surface_name = name_base + "_" + coord + direction - names.append(surface_name) - normal_dirs.append(direction) - - # ignore surfaces that are infinitely far away - del_idx = [] - for idx, _size in enumerate(size): - if _size == inf: - del_idx.append(idx) - del_idx = [[2 * i, 2 * i + 1] for i in del_idx] - del_idx = [item for sublist in del_idx for item in sublist] - - def del_items(items: Iterable, indices: int) -> list: - """Delete list items at indices.""" - return [i for j, i in enumerate(items) if j not in indices] - - centers = del_items(centers, del_idx) - sizes = del_items(sizes, del_idx) - names = del_items(names, del_idx) - normal_dirs = del_items(normal_dirs, del_idx) - - surfaces = [] - 
for _cent, _size, _name, _normal_dir in zip(centers, sizes, names, normal_dirs): - if "normal_dir" in cls.model_fields: - kwargs["normal_dir"] = _normal_dir - - if "name" in cls.model_fields: - kwargs["name"] = _name - - surface = cls(center=_cent, size=_size, **kwargs) - surfaces.append(surface) - - return surfaces - - @classmethod - def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]: - """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume. - The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z - denote which axis is perpendicular to that surface, while "-" and "+" denote the direction - of the normal vector of that surface. If a name is provided, each output surface's name - will be that of the provided name appended with the above symbols. E.g., if the provided - name is "box", the x+ surfaces's name will be "box_x+". If ``kwargs`` contains an - ``exclude_surfaces`` parameter, the returned list of surfaces will not include the excluded - surfaces. Otherwise, the behavior is identical to that of ``surfaces()``. - - Parameters - ---------- - size : tuple[float, float, float] - Size of object in x, y, and z directions. - center : tuple[float, float, float] - Center of object in x, y, and z. - - Example - ------- - >>> b = Box.surfaces_with_exclusion( - ... size=(1, 2, 3), center=(3, 2, 1), exclude_surfaces=["x-"] - ... 
) - """ - exclude_surfaces = kwargs.pop("exclude_surfaces", None) - surfaces = cls.surfaces(size=size, center=center, **kwargs) - if "name" in cls.model_fields and exclude_surfaces: - surfaces = [surf for surf in surfaces if surf.name[-2:] not in exclude_surfaces] - return surfaces - - @verify_packages_import(["trimesh"]) - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for Box geometry. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - import trimesh - - (x0, y0, z0), (x1, y1, z1) = self.bounds - vertices = [ - (x0, y0, z0), # 0 - (x0, y0, z1), # 1 - (x0, y1, z0), # 2 - (x0, y1, z1), # 3 - (x1, y0, z0), # 4 - (x1, y0, z1), # 5 - (x1, y1, z0), # 6 - (x1, y1, z1), # 7 - ] - faces = [ - (0, 1, 3, 2), # -x - (4, 6, 7, 5), # +x - (0, 4, 5, 1), # -y - (2, 3, 7, 6), # +y - (0, 2, 6, 4), # -z - (1, 5, 7, 3), # +z - ] - mesh = trimesh.Trimesh(vertices, faces) - - section = mesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns shapely geometry at plane specified by one non None value of x,y,z. 
- - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for Box geometry. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - z0, (x0, y0) = self.pop_axis(self.center, axis=axis) - Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis) - dz = np.abs(z0 - position) - if dz > Lz / 2 + fp_eps: - return [] - - minx = x0 - Lx / 2 - miny = y0 - Ly / 2 - maxx = x0 + Lx / 2 - maxy = y0 + Ly / 2 - - # handle case where the box vertices are identical - if np.isclose(minx, maxx) and np.isclose(miny, maxy): - return [self.make_shapely_point(minx, miny)] - - return [self.make_shapely_box(minx, miny, maxx, maxy)] - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. 
- - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - self._ensure_equal_shape(x, y, z) - x0, y0, z0 = self.center - Lx, Ly, Lz = self.size - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - return (dist_x <= Lx / 2) * (dist_y <= Ly / 2) * (dist_z <= Lz / 2) - - def intersections_with( - self, other: Shapely, cleanup: bool = True, quad_segs: Optional[int] = None - ) -> list[Shapely]: - """Returns list of shapely geometries representing the intersections of the geometry with - this 2D box. - - Parameters - ---------- - other : Shapely - Geometry to intersect with. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect this 2D box. - For more details refer to - `Shapely's Documentation `_. - """ - - # Verify 2D - if self.size.count(0.0) != 1: - raise ValidationError( - "Intersections with other geometry are only calculated from a 2D box." 
- ) - - # dont bother if the geometry doesn't intersect the self at all - if not other.intersects(self): - return [] - - # get list of Shapely shapes that intersect at the self - normal_ind = self.size.index(0.0) - dim = "xyz"[normal_ind] - pos = self.center[normal_ind] - xyz_kwargs = {dim: pos} - shapes_plane = other.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **xyz_kwargs) - - # intersect all shapes with the input self - bs_min, bs_max = (self.pop_axis(bounds, axis=normal_ind)[1] for bounds in self.bounds) - - shapely_box = self.make_shapely_box(bs_min[0], bs_min[1], bs_max[0], bs_max[1]) - shapely_box = Geometry.evaluate_inf_shape(shapely_box) - return [Geometry.evaluate_inf_shape(shape) & shapely_box for shape in shapes_plane] - - def slightly_enlarged_copy(self) -> Box: - """Box size slightly enlarged around machine precision.""" - size = [increment_float(orig_length, 1) for orig_length in self.size] - return self.updated_copy(size=size) - - def padded_copy( - self, - x: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, - y: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, - z: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None, - ) -> Box: - """Created a padded copy of a :class:`Box` instance. - - Parameters - ---------- - x : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None - Padding sizes at the left and right boundaries of the box along x-axis. - y : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None - Padding sizes at the left and right boundaries of the box along y-axis. - z : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None - Padding sizes at the left and right boundaries of the box along z-axis. - - Returns - ------- - Box - Padded instance of :class:`Box`. 
- """ - - # Validate that padding values are non-negative - for axis_name, axis_padding in zip(("x", "y", "z"), (x, y, z)): - if axis_padding is not None: - if not isinstance(axis_padding, (tuple, list)) or len(axis_padding) != 2: - raise ValueError(f"Padding for {axis_name}-axis must be a tuple of two values.") - if any(p < 0 for p in axis_padding): - raise ValueError( - f"Padding values for {axis_name}-axis must be non-negative. Got {axis_padding}." - ) - - rmin, rmax = self.bounds - - def bound_array(arrs: ArrayLike, idx: int) -> NDArray: - return np.array([(a[idx] if a is not None else 0) for a in arrs]) - - # parse padding sizes for simulation - drmin = bound_array((x, y, z), 0) - drmax = bound_array((x, y, z), 1) - - rmin = np.array(rmin) - drmin - rmax = np.array(rmax) + drmax - - return Box.from_bounds(rmin=rmin, rmax=rmax) - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - size = self.size - center = self.center - coord_min = tuple(c - s / 2 for (s, c) in zip(size, center)) - coord_max = tuple(c + s / 2 for (s, c) in zip(size, center)) - return (coord_min, coord_max) - - @cached_property - def geometry(self) -> Box: - """:class:`Box` representation of self (used for subclasses of Box). - - Returns - ------- - :class:`Box` - Instance of :class:`Box` representing self's geometry. 
- """ - return Box(center=self.center, size=self.size) - - @cached_property - def zero_dims(self) -> list[Axis]: - """A list of axes along which the :class:`Box` is zero-sized.""" - return [dim for dim, size in enumerate(self.size) if size == 0] - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if np.count_nonzero(self.size) != 2: - raise ValidationError( - "'Medium2D' requires exactly one of the 'Box' dimensions to have size zero." - ) - return self.size.index(0) - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Box: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_center = list(self.center) - new_center[axis] = (bounds[0] + bounds[1]) / 2 - new_size = list(self.size) - new_size[axis] = bounds[1] - bounds[0] - return self.updated_copy(center=tuple(new_center), size=tuple(new_size)) - - def _plot_arrow( - self, - direction: tuple[float, float, float], - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - color: Optional[str] = None, - alpha: Optional[float] = None, - bend_radius: Optional[float] = None, - bend_axis: Axis = None, - both_dirs: bool = False, - ax: Ax = None, - arrow_base: Coordinate = None, - ) -> Ax: - """Adds an arrow to the axis if with options if certain conditions met. - - Parameters - ---------- - direction: tuple[float, float, float] - Normalized vector describing the arrow direction. - x : float = None - Position of plotting plane in x direction. - y : float = None - Position of plotting plane in y direction. - z : float = None - Position of plotting plane in z direction. - color : str = None - Color of the arrow. - alpha : float = None - Opacity of the arrow (0, 1) - bend_radius : float = None - Radius of curvature for this arrow. - bend_axis : Axis = None - Axis of curvature of ``bend_radius``. 
- both_dirs : bool = False - If True, plots an arrow pointing in direction and one in -direction. - arrow_base : :class:`.Coordinate` = None - Custom base of the arrow. Uses the geometry's center if not provided. - - Returns - ------- - matplotlib.axes._subplots.Axes - The matplotlib axes with the arrow added. - """ - - plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) - - # conditions to check to determine whether to plot arrow, taking into account the - # possibility of a custom arrow base - arrow_intersecting_plane = len(self.intersections_plane(x=x, y=y, z=z)) > 0 - center = self.center - if arrow_base: - arrow_intersecting_plane = arrow_intersecting_plane and any( - a == b for a, b in zip(arrow_base, [x, y, z]) - ) - center = arrow_base - - _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) - components_in_plane = any(not np.isclose(component, 0) for component in (dx, dy)) - - # plot if arrow in plotting plane and some non-zero component can be displayed. - if arrow_intersecting_plane and components_in_plane: - _, (x0, y0) = self.pop_axis(center, axis=plot_axis) - - # Reasonable value for temporary arrow size. The correct size and direction - # have to be calculated after all transforms have been set. That is why we - # use a callback to do these calculations only at the drawing phase. 
- xmin, xmax = ax.get_xlim() - ymin, ymax = ax.get_ylim() - v_x = (xmax - xmin) / 10 - v_y = (ymax - ymin) / 10 - - directions = (1.0, -1.0) if both_dirs else (1.0,) - for sign in directions: - arrow = patches.FancyArrowPatch( - (x0, y0), - (x0 + v_x, y0 + v_y), - arrowstyle=arrow_style, - color=color, - alpha=alpha, - zorder=np.inf, - ) - # Don't draw this arrow until it's been reshaped - arrow.set_visible(False) - - callback = self._arrow_shape_cb( - arrow, (x0, y0), (dx, dy), sign, bend_radius if bend_axis == plot_axis else None - ) - callback_id = ax.figure.canvas.mpl_connect("draw_event", callback) - - # Store a reference to the callback because mpl_connect does not. - arrow.set_shape_cb = (callback_id, callback) - - ax.add_patch(arrow) - - return ax - - @staticmethod - def _arrow_shape_cb( - arrow: FancyArrowPatch, - pos: tuple[float, float], - direction: ArrayLike, - sign: float, - bend_radius: float | None, - ) -> Callable[[Event], None]: - def _cb(event: Event) -> None: - # We only want to set the shape once, so we disconnect ourselves - event.canvas.mpl_disconnect(arrow.set_shape_cb[0]) - - transform = arrow.axes.transData.transform - scale_x = transform((1, 0))[0] - transform((0, 0))[0] - scale_y = transform((0, 1))[1] - transform((0, 0))[1] - scale = max(scale_x, scale_y) # <-- Hack: This is a somewhat arbitrary choice. 
- arrow_length = ARROW_LENGTH * event.canvas.figure.get_dpi() / scale - - if bend_radius: - v_norm = (direction[0] ** 2 + direction[1] ** 2) ** 0.5 - vx_norm = direction[0] / v_norm - vy_norm = direction[1] / v_norm - bend_angle = -sign * arrow_length / bend_radius - t_x = 1 - np.cos(bend_angle) - t_y = np.sin(bend_angle) - v_x = -bend_radius * (vx_norm * t_y - vy_norm * t_x) - v_y = -bend_radius * (vx_norm * t_x + vy_norm * t_y) - tangent_angle = np.arctan2(direction[1], direction[0]) - arrow.set_connectionstyle( - patches.ConnectionStyle.Angle3( - angleA=180 / np.pi * tangent_angle, - angleB=180 / np.pi * (tangent_angle + bend_angle), - ) - ) - - else: - v_x = sign * arrow_length * direction[0] - v_y = sign * arrow_length * direction[1] - - arrow.set_positions(pos, (pos[0] + v_x, pos[1] + v_y)) - arrow.set_visible(True) - arrow.draw(event.renderer) - - return _cb - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - volume = 1 - - for axis in range(3): - min_bound = max(self.bounds[0][axis], bounds[0][axis]) - max_bound = min(self.bounds[1][axis], bounds[1][axis]) - - volume *= max_bound - min_bound - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - min_bounds = list(self.bounds[0]) - max_bounds = list(self.bounds[1]) - - in_bounds_factor = [2, 2, 2] - length = [0, 0, 0] - - for axis in (0, 1, 2): - if min_bounds[axis] < bounds[0][axis]: - min_bounds[axis] = bounds[0][axis] - in_bounds_factor[axis] -= 1 - - if max_bounds[axis] > bounds[1][axis]: - max_bounds[axis] = bounds[1][axis] - in_bounds_factor[axis] -= 1 - - length[axis] = max_bounds[axis] - min_bounds[axis] - - return ( - length[0] * length[1] * in_bounds_factor[2] - + length[1] * length[2] * in_bounds_factor[0] - + length[2] * length[0] * in_bounds_factor[1] - ) - - """ Autograd code """ - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> 
AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - # get gradients w.r.t. each of the 6 faces (in normal direction) - vjps_faces = self._derivative_faces(derivative_info=derivative_info) - - # post-process these values to give the gradients w.r.t. center and size - vjps_center_size = self._derivatives_center_size(vjps_faces=vjps_faces) - - # store only the gradients asked for in 'field_paths' - derivative_map = {} - for field_path in derivative_info.paths: - field_name, *index = field_path - - if field_name in vjps_center_size: - # if the vjp calls for a specific index into the tuple - if index and len(index) == 1: - index = int(index[0]) - if field_path not in derivative_map: - derivative_map[field_path] = vjps_center_size[field_name][index] - - # otherwise, just grab the whole array - else: - derivative_map[field_path] = vjps_center_size[field_name] - - return derivative_map - - @staticmethod - def _derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: - """Derivatives with respect to the ``center`` and ``size`` fields in the ``Box``.""" - - vjps_faces_min, vjps_faces_max = np.array(vjps_faces) - - # post-process min and max face gradients into center and size - vjp_center = vjps_faces_max - vjps_faces_min - vjp_size = (vjps_faces_min + vjps_faces_max) / 2.0 - - return { - "center": tuple(vjp_center.tolist()), - "size": tuple(vjp_size.tolist()), - } - - def _derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: - """Derivative with respect to normal position of 6 faces of ``Box``.""" - - axes_to_compute = (0, 1, 2) - if len(derivative_info.paths[0]) > 1: - axes_to_compute = tuple(info[1] for info in derivative_info.paths) - - # change in permittivity between inside and outside - vjp_faces = np.zeros((2, 3)) - - for min_max_index, _ in enumerate((0, -1)): - for axis in axes_to_compute: - vjp_face = self._derivative_face( - min_max_index=min_max_index, - axis_normal=axis, - derivative_info=derivative_info, - ) 
- - # record vjp for this face - vjp_faces[min_max_index, axis] = vjp_face - - return vjp_faces - - def _derivative_face( - self, - min_max_index: int, - axis_normal: Axis, - derivative_info: DerivativeInfo, - ) -> float: - """Compute the derivative w.r.t. shifting a face in the normal direction.""" - - interpolators = derivative_info.interpolators or derivative_info.create_interpolators() - _, axis_perp = self.pop_axis((0, 1, 2), axis=axis_normal) - - # First, check if the face is outside the simulation domain in which case set the - # face gradient to 0. - bounds_normal, bounds_perp = self.pop_axis( - np.array(derivative_info.bounds).T, axis=axis_normal - ) - coord_normal_face = bounds_normal[min_max_index] - - if min_max_index == 0: - if coord_normal_face < derivative_info.simulation_bounds[0][axis_normal]: - return 0.0 - else: - if coord_normal_face > derivative_info.simulation_bounds[1][axis_normal]: - return 0.0 - - intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) - extents = intersect_max - intersect_min - _, intersect_min_perp = self.pop_axis(np.array(intersect_min), axis=axis_normal) - _, intersect_max_perp = self.pop_axis(np.array(intersect_max), axis=axis_normal) - - is_2d_map = [] - for axis_idx in range(3): - if axis_idx == axis_normal: - continue - is_2d_map.append(np.isclose(extents[axis_idx], 0.0)) - - if np.all(is_2d_map): - return 0.0 - - is_2d = np.any(is_2d_map) - - sim_bounds_normal, sim_bounds_perp = self.pop_axis( - np.array(derivative_info.simulation_bounds).T, axis=axis_normal - ) - - # Build point grid - adaptive_spacing = derivative_info.adaptive_vjp_spacing() - - def spacing_to_grid_points( - spacing: float, min_coord: float, max_coord: float - ) -> NDArray[float]: - N = np.maximum(3, 1 + int((max_coord - min_coord) / spacing)) - - points = np.linspace(min_coord, max_coord, N) - centers = 0.5 * (points[0:-1] + points[1:]) - - return centers - - def verify_integration_interval(bound: tuple[float, float]) -> 
bool: - # assume the bounds should not be equal or else this integration interval - # would be the flat dimension of a 2D geometry. - return bound[1] > bound[0] - - def compute_integration_weight(grid_points: NDArray[float]) -> float: - grid_spacing = grid_points[1] - grid_points[0] - if grid_spacing == 0.0: - integration_weight = 1.0 / len(grid_points) - else: - integration_weight = grid_points[1] - grid_points[0] - - return integration_weight - - if is_2d: - # build 1D grid for sampling points along the face, which is an edge in the 2D case - zero_dim = np.where(is_2d_map)[0][0] - # zero dim is one of the perpendicular directions, so the other perpendicular direction - # is the nonzero dimension - nonzero_dim = 1 - zero_dim - - # clip at simulation bounds for integration dimension - integration_bounds_perp = ( - intersect_min_perp[nonzero_dim], - intersect_max_perp[nonzero_dim], - ) - - if not verify_integration_interval(integration_bounds_perp): - return 0.0 - - grid_points_linear = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[0], integration_bounds_perp[1] - ) - integration_weight = compute_integration_weight(grid_points_linear) - - grid_points = np.repeat(np.expand_dims(grid_points_linear.copy(), 1), 3, axis=1) - - # set up grid points to pass into evaluate_gradient_at_points - grid_points[:, axis_perp[nonzero_dim]] = grid_points_linear - grid_points[:, axis_perp[zero_dim]] = intersect_min_perp[zero_dim] - grid_points[:, axis_normal] = coord_normal_face - else: - # build 3D grid for sampling points along the face - - # clip at simulation bounds for each integration dimension - integration_bounds_perp = ( - (intersect_min_perp[0], intersect_max_perp[0]), - (intersect_min_perp[1], intersect_max_perp[1]), - ) - - if not np.all([verify_integration_interval(b) for b in integration_bounds_perp]): - return 0.0 - - grid_points_perp_1 = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[0][0], integration_bounds_perp[0][1] - ) - 
grid_points_perp_2 = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[1][0], integration_bounds_perp[1][1] - ) - integration_weight = compute_integration_weight( - grid_points_perp_1 - ) * compute_integration_weight(grid_points_perp_2) - - mesh_perp1, mesh_perp2 = np.meshgrid(grid_points_perp_1, grid_points_perp_2) - - zip_perp_coords = np.array(list(zip(mesh_perp1.flatten(), mesh_perp2.flatten()))) - - grid_points = np.pad(zip_perp_coords.copy(), ((0, 0), (1, 0)), mode="constant") - - # set up grid points to pass into evaluate_gradient_at_points - grid_points[:, axis_perp[0]] = zip_perp_coords[:, 0] - grid_points[:, axis_perp[1]] = zip_perp_coords[:, 1] - grid_points[:, axis_normal] = coord_normal_face - - normals = np.zeros_like(grid_points) - perps1 = np.zeros_like(grid_points) - perps2 = np.zeros_like(grid_points) - - normals[:, axis_normal] = -1 if (min_max_index == 0) else 1 - perps1[:, axis_perp[0]] = 1 - perps2[:, axis_perp[1]] = 1 - - gradient_at_points = derivative_info.evaluate_gradient_at_points( - spatial_coords=grid_points, - normals=normals, - perps1=perps1, - perps2=perps2, - interpolators=interpolators, - ) - - vjp_value = np.sum(integration_weight * np.real(gradient_at_points)) - return vjp_value - - -"""Compound subclasses""" - - -class Transformed(Geometry): - """Class representing a transformed geometry.""" - - geometry: discriminated_union(GeometryType) = Field( - title="Geometry", - description="Base geometry to be transformed.", - ) - - transform: MatrixReal4x4 = Field( - default_factory=lambda: np.eye(4).tolist(), - title="Transform", - description="Transform matrix applied to the base geometry.", - ) - - @field_validator("transform") - @classmethod - def _transform_is_invertible(cls, val: MatrixReal4x4) -> MatrixReal4x4: - # If the transform is not invertible, this will raise an error - _ = np.linalg.inv(val) - return val - - @field_validator("geometry") - @classmethod - def _geometry_is_finite(cls, val: GeometryType) -> 
GeometryType: - if not np.isfinite(val.bounds).all(): - raise ValidationError( - "Transformations are only supported on geometries with finite dimensions. " - "Try using a large value instead of 'inf' when creating geometries that undergo " - "transformations." - ) - return val - - @model_validator(mode="after") - def _apply_transforms(self: dict[str, Any]) -> dict[str, Any]: - while isinstance(self.geometry, Transformed): - inner = self.geometry - object.__setattr__(self, "geometry", inner.geometry) - object.__setattr__(self, "transform", np.dot(self.transform, inner.transform)) - return self - - @cached_property - def inverse(self) -> MatrixReal4x4: - """Inverse of this transform.""" - return np.linalg.inv(self.transform) - - @staticmethod - def _vertices_from_bounds(bounds: Bound) -> ArrayFloat2D: - """Return the 8 vertices derived from bounds. - - The vertices are returned as homogeneous coordinates (with 4 components). - - Parameters - ---------- - bounds : Bound - Bounds from which to derive the vertices. - - Returns - ------- - ArrayFloat2D - Array with shape (4, 8) with all vertices from ``bounds``. - """ - (x0, y0, z0), (x1, y1, z1) = bounds - return np.array( - ( - (x0, x0, x0, x0, x1, x1, x1, x1), - (y0, y0, y1, y1, y0, y0, y1, y1), - (z0, z1, z0, z1, z0, z1, z0, z1), - (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0), - ) - ) - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - # NOTE (Lucas): The bounds are overestimated because we don't want to calculate - # precise TriangleMesh representations for GeometryGroup or ClipOperation. 
- vertices = np.dot(self.transform, self._vertices_from_bounds(self.geometry.bounds))[:3] - return (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - return self.geometry.intersections_tilted_plane( - tuple(np.dot((normal[0], normal[1], normal[2], 0.0), self.transform)[:3]), - tuple(np.dot(self.inverse, (origin[0], origin[1], origin[2], 1.0))[:3]), - np.dot(to_2D, self.transform), - cleanup=cleanup, - quad_segs=quad_segs, - ) - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. 
- - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - x = np.array(x) - y = np.array(y) - z = np.array(z) - xyz = np.dot(self.inverse, np.vstack((x.flat, y.flat, z.flat, np.ones(x.size)))) - if xyz.shape[1] == 1: - # TODO: This "fix" is required because of a bug in PolySlab.inside (with non-zero sidewall angle) - return self.geometry.inside(xyz[0][0], xyz[1][0], xyz[2][0]).reshape(x.shape) - return self.geometry.inside(xyz[0], xyz[1], xyz[2]).reshape(x.shape) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # NOTE (Lucas): Bounds are overestimated. - vertices = np.dot(self.inverse, self._vertices_from_bounds(bounds))[:3] - inverse_bounds = (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) - return abs(np.linalg.det(self.transform)) * self.geometry.volume(inverse_bounds) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - log.warning("Surface area of transformed elements cannot be calculated.") - return None - - @staticmethod - def translation(x: float, y: float, z: float) -> MatrixReal4x4: - """Return a translation matrix. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. - z : float - Translation along z. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - return np.array( - [ - (1.0, 0.0, 0.0, x), - (0.0, 1.0, 0.0, y), - (0.0, 0.0, 1.0, z), - (0.0, 0.0, 0.0, 1.0), - ], - dtype=float, - ) - - @staticmethod - def scaling(x: float = 1.0, y: float = 1.0, z: float = 1.0) -> MatrixReal4x4: - """Return a scaling matrix. - - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). 
- """ - if np.isclose((x, y, z), 0.0).any(): - raise Tidy3dError("Scaling factors cannot be zero in any dimensions.") - return np.array( - [ - (x, 0.0, 0.0, 0.0), - (0.0, y, 0.0, 0.0), - (0.0, 0.0, z, 0.0), - (0.0, 0.0, 0.0, 1.0), - ], - dtype=float, - ) - - @staticmethod - def rotation(angle: float, axis: Union[Axis, Coordinate]) -> MatrixReal4x4: - """Return a rotation matrix. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - transform = np.eye(4) - transform[:3, :3] = RotationAroundAxis(angle=angle, axis=axis).matrix - return transform - - @staticmethod - def reflection(normal: Coordinate) -> MatrixReal4x4: - """Return a reflection matrix. - - Parameters - ---------- - normal : tuple[float, float, float] - Normal of the plane of reflection. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - - transform = np.eye(4) - transform[:3, :3] = ReflectionFromPlane(normal=normal).matrix - return transform - - @staticmethod - def preserves_axis(transform: MatrixReal4x4, axis: Axis) -> bool: - """Indicate if the transform preserves the orientation of a given axis. - - Parameters: - transform: MatrixReal4x4 - Transform matrix to check. - axis : int - Axis to check. Values 0, 1, or 2, to check x, y, or z, respectively. - - Returns - ------- - bool - ``True`` if the transformation preserves the axis orientation, ``False`` otherwise. 
- """ - i = (axis + 1) % 3 - j = (axis + 2) % 3 - return np.isclose(transform[i, axis], 0) and np.isclose(transform[j, axis], 0) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - normal = self.geometry._normal_2dmaterial - preserves_axis = Transformed.preserves_axis(self.transform, normal) - - if not preserves_axis: - raise ValidationError( - "'Medium2D' requires geometries of type 'Transformed' to " - "perserve the axis normal to the 'Medium2D'." - ) - - return normal - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Transformed: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - min_bound = np.array([0, 0, 0, 1.0]) - min_bound[axis] = bounds[0] - max_bound = np.array([0, 0, 0, 1.0]) - max_bound[axis] = bounds[1] - new_bounds = [] - new_bounds.append(np.dot(self.inverse, min_bound)[axis]) - new_bounds.append(np.dot(self.inverse, max_bound)[axis]) - new_geometry = self.geometry._update_from_bounds(bounds=new_bounds, axis=axis) - return self.updated_copy(geometry=new_geometry) - - -class ClipOperation(Geometry): - """Class representing the result of a set operation between geometries.""" - - operation: ClipOperationType = Field( - title="Operation Type", - description="Operation to be performed between geometries.", - ) - - geometry_a: discriminated_union(GeometryType) = Field( - title="Geometry A", - description="First operand for the set operation. It can be any geometry type, including " - ":class:`GeometryGroup`.", - ) - - geometry_b: discriminated_union(GeometryType) = Field( - title="Geometry B", - description="Second operand for the set operation. 
It can also be any geometry type.", - ) - - @field_validator("geometry_a", "geometry_b") - @classmethod - def _geometries_untraced(cls, val: GeometryType) -> GeometryType: - """Make sure that ``ClipOperation`` geometries do not contain tracers.""" - traced = val._strip_traced_fields() - if traced: - raise ValidationError( - f"{val.type} contains traced fields {list(traced.keys())}. Note that " - "'ClipOperation' does not currently support automatic differentiation." - ) - return val - - @staticmethod - def to_polygon_list(base_geometry: Shapely, cleanup: bool = False) -> list[Shapely]: - """Return a list of valid polygons from a shapely geometry, discarding points, lines, and - empty polygons, and empty triangles within polygons. - - Parameters - ---------- - base_geometry : shapely.geometry.base.BaseGeometry - Base geometry for inspection. - cleanup: bool = False - If True, removes extremely small features from each polygon's boundary. - This is useful for removing artifacts from 2D plots displayed to the user. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - Valid polygons retrieved from ``base geometry``. - """ - unfiltered_geoms = [] - if base_geometry.geom_type == "GeometryCollection": - unfiltered_geoms = [ - p - for geom in base_geometry.geoms - for p in ClipOperation.to_polygon_list(geom, cleanup) - ] - if base_geometry.geom_type == "MultiPolygon": - unfiltered_geoms = [p for p in base_geometry.geoms if not p.is_empty] - if base_geometry.geom_type == "Polygon" and not base_geometry.is_empty: - unfiltered_geoms = [base_geometry] - geoms = [] - if cleanup: - # Optional: "clean" each of the polygons (by removing extremely small or thin features). 
- for geom in unfiltered_geoms: - geom_clean = cleanup_shapely_object(geom) - if geom_clean.geom_type == "Polygon": - geoms.append(geom_clean) - if geom_clean.geom_type == "MultiPolygon": - geoms += [p for p in geom_clean.geoms if not p.is_empty] - # Ignore other types of shapely objects (points and lines) - else: - geoms = unfiltered_geoms - return geoms - - @property - def _shapely_operation(self) -> Callable[[Shapely, Shapely], Shapely]: - """Return a Shapely function equivalent to this operation.""" - result = _shapely_operations.get(self.operation, None) - if not result: - raise ValueError( - "'operation' must be one of 'union', 'intersection', 'difference', or " - "'symmetric_difference'." - ) - return result - - @property - def _bit_operation(self) -> Callable[[Any, Any], Any]: - """Return a function equivalent to this operation using bit operators.""" - result = _bit_operations.get(self.operation, None) - if not result: - raise ValueError( - "'operation' must be one of 'union', 'intersection', 'difference', or " - "'symmetric_difference'." - ) - return result - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. 
- For more details refer to - `Shapely's Documentation `_. - """ - a = self.geometry_a.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - b = self.geometry_b.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a]) - geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b]) - return ClipOperation.to_polygon_list( - self._shapely_operation(geom_a, geom_b), - cleanup=cleanup, - ) - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentaton `_. 
- """ - a = self.geometry_a.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs) - b = self.geometry_b.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs) - geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a]) - geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b]) - return ClipOperation.to_polygon_list( - self._shapely_operation(geom_a, geom_b), - cleanup=cleanup, - ) - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - # Overestimates - if self.operation == "difference": - result = self.geometry_a.bounds - elif self.operation == "intersection": - bounds = (self.geometry_a.bounds, self.geometry_b.bounds) - result = ( - tuple(max(b[i] for b, _ in bounds) for i in range(3)), - tuple(min(b[i] for _, b in bounds) for i in range(3)), - ) - if any(result[0][i] > result[1][i] for i in range(3)): - result = ((0, 0, 0), (0, 0, 0)) - else: - bounds = (self.geometry_a.bounds, self.geometry_b.bounds) - result = ( - tuple(min(b[i] for b, _ in bounds) for i in range(3)), - tuple(max(b[i] for _, b in bounds) for i in range(3)), - ) - return result - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - inside_a = self.geometry_a.inside(x, y, z) - inside_b = self.geometry_b.inside(x, y, z) - return self._bit_operation(inside_a, inside_b) - - def inside_meshgrid( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> NDArray[bool]: - """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted. - - Parameters - ---------- - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every - point that is inside the geometry. - """ - inside_a = self.geometry_a.inside_meshgrid(x, y, z) - inside_b = self.geometry_b.inside_meshgrid(x, y, z) - return self._bit_operation(inside_a, inside_b) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # Overestimates - if self.operation == "intersection": - return min(self.geometry_a.volume(bounds), self.geometry_b.volume(bounds)) - if self.operation == "difference": - return self.geometry_a.volume(bounds) - return self.geometry_a.volume(bounds) + self.geometry_b.volume(bounds) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - # Overestimates - return self.geometry_a.surface_area(bounds) + self.geometry_b.surface_area(bounds) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - normal_a = self.geometry_a._normal_2dmaterial - normal_b = self.geometry_b._normal_2dmaterial - - if normal_a != normal_b: - raise ValidationError( - "'Medium2D' requires both geometries in the 'ClipOperation' to " - "have exactly one dimension with zero size in common." 
- ) - - plane_position_a = self.geometry_a.bounds[0][normal_a] - plane_position_b = self.geometry_b.bounds[0][normal_b] - - if plane_position_a != plane_position_b: - raise ValidationError( - "'Medium2D' requires both geometries in the 'ClipOperation' to be co-planar." - ) - return normal_a - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOperation: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_geom_a = self.geometry_a._update_from_bounds(bounds=bounds, axis=axis) - new_geom_b = self.geometry_b._update_from_bounds(bounds=bounds, axis=axis) - return self.updated_copy(geometry_a=new_geom_a, geometry_b=new_geom_b) - - -class GeometryGroup(Geometry): - """A collection of Geometry objects that can be called as a single geometry object.""" - - geometries: tuple[discriminated_union(GeometryType), ...] = Field( - title="Geometries", - description="Tuple of geometries in a single grouping. " - "Can provide significant performance enhancement in ``Structure`` when all geometries are " - "assigned the same medium.", - ) - - @field_validator("geometries") - @classmethod - def _geometries_not_empty(cls, val: tuple[GeometryType, ...]) -> tuple[GeometryType, ...]: - """make sure geometries are not empty.""" - if not len(val) > 0: - raise ValidationError("GeometryGroup.geometries must not be empty.") - return val - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
- """ - - bounds = tuple(geometry.bounds for geometry in self.geometries) - return ( - tuple(min(b[i] for b, _ in bounds) for i in range(3)), - tuple(max(b[i] for _, b in bounds) for i in range(3)), - ) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - return [ - intersection - for geometry in self.geometries - for intersection in geometry.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - ] - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. 
- cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - if not self.intersects_plane(x, y, z): - return [] - return [ - intersection - for geometry in self.geometries - for intersection in geometry.intersections_plane( - x=x, y=y, z=z, cleanup=cleanup, quad_segs=quad_segs - ) - ] - - def intersects_axis_position(self, axis: float, position: float) -> bool: - """Whether self intersects plane specified by a given position along a normal axis. - - Parameters - ---------- - axis : int = None - Axis normal to the plane. - position : float = None - Position of plane along the normal axis. - - Returns - ------- - bool - Whether this geometry intersects the plane. - """ - return any(geom.intersects_axis_position(axis, position) for geom in self.geometries) - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries) - return functools.reduce(lambda a, b: a | b, individual_insides) - - def inside_meshgrid( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> NDArray[bool]: - """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted. - - Parameters - ---------- - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every - point that is inside the geometry. - """ - individual_insides = (geom.inside_meshgrid(x, y, z) for geom in self.geometries) - return functools.reduce(lambda a, b: a | b, individual_insides) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - return sum(geometry.volume(bounds) for geometry in self.geometries) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - return sum(geometry.surface_area(bounds) for geometry in self.geometries) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - - normals = {geom._normal_2dmaterial for geom in self.geometries} - - if len(normals) != 1: - raise ValidationError( - "'Medium2D' requires all geometries in the 'GeometryGroup' to " - "share exactly one dimension with zero size." - ) - normal = list(normals)[0] - positions = {geom.bounds[0][normal] for geom in self.geometries} - if len(positions) != 1: - raise ValidationError( - "'Medium2D' requires all geometries in the 'GeometryGroup' to be co-planar." 
- ) - return normal - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> GeometryGroup: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_geometries = tuple( - geometry._update_from_bounds(bounds=bounds, axis=axis) for geometry in self.geometries - ) - return self.updated_copy(geometries=new_geometries) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - grad_vjps = {} - - # create interpolators once for all geometries to avoid redundant field data conversions - interpolators = derivative_info.interpolators or derivative_info.create_interpolators() - - with derivative_info.cache_min_spacing_from_permittivity(): - for field_path in derivative_info.paths: - _, index, *geo_path = field_path - - geo = self.geometries[index] - # pass pre-computed interpolators if available - geo_info = derivative_info.updated_copy( - paths=[tuple(geo_path)], - bounds=geo.bounds, - bounds_intersect=self.bounds_intersection( - geo.bounds, derivative_info.simulation_bounds - ), - deep=False, - interpolators=interpolators, - ) - - vjp_dict_geo = geo._compute_derivatives(geo_info) - - if len(vjp_dict_geo) != 1: - raise AssertionError("Got multiple gradients for single geometry field.") - - grad_vjps[field_path] = vjp_dict_geo.popitem()[1] - - return grad_vjps - - -def cleanup_shapely_object(obj: Shapely, tolerance_ratio: float = POLY_TOLERANCE_RATIO) -> Shapely: - """Remove small geometric features from the boundaries of a shapely object including - inward and outward spikes, thin holes, and thin connections between larger regions. 
- - Parameters - ---------- - obj : shapely - a shapely object (typically a ``Polygon`` or a ``MultiPolygon``) - tolerance_ratio : float = ``POLY_TOLERANCE_RATIO`` - Features on the boundaries of polygons will be discarded if they are smaller - or narrower than ``tolerance_ratio`` multiplied by the size of the object. - - Returns - ------- - Shapely - A new shapely object whose small features (eg. thin spikes or holes) are removed. - - Notes - ----- - This function does not attempt to delete overlapping, nearby, or collinear vertices. - To solve that problem, use ``shapely.simplify()`` afterwards. - """ - if _package_is_older_than("shapely", "2.1"): - log.warning("Versions of shapely prior to v2.1 may cause plot errors.", log_once=True) - return obj - if obj.is_empty: - return obj - centroid = obj.centroid - object_size = min(obj.bounds[2] - obj.bounds[0], obj.bounds[3] - obj.bounds[1]) - if object_size == 0.0: - return shapely.Polygon([]) - - # To prevent numerical over- or underflow errors, subtract the centroid and rescale - normalized_obj = shapely.affinity.affine_transform( - obj, - matrix=[ - 1 / object_size, - 0.0, - 0.0, - 1 / object_size, - -centroid.x / object_size, - -centroid.y / object_size, - ], - ) - # Important: Remove any self intersections beforehand using `shapely.make_valid()`. - valid_obj = shapely.make_valid(normalized_obj, method="structure", keep_collapsed=False) - - # To get rid of small thin features, erode(shrink), dilate(expand), and erode again. - eroded_obj = shapely.buffer( - valid_obj, - distance=-tolerance_ratio, - cap_style="square", - quad_segs=3, - ) - dilated_obj = shapely.buffer( - eroded_obj, - distance=2 * tolerance_ratio, - cap_style="square", - quad_segs=3, - ) - cleaned_obj = dilated_obj - - # Optional: Now shrink the polygon back to the original size. 
- cleaned_obj = shapely.buffer( - cleaned_obj, - distance=-tolerance_ratio, - cap_style="square", - quad_segs=3, - ) - # Clean vertices of very close distances created during the erosion/dilation process. - # The distance value is heuristic. - cleaned_obj = cleaned_obj.simplify(POLY_DISTANCE_TOLERANCE, preserve_topology=True) - # Revert to the original scale and position. - rescaled_clean_obj = shapely.affinity.affine_transform( - cleaned_obj, - matrix=[ - object_size, - 0.0, - 0.0, - object_size, - centroid.x, - centroid.y, - ], - ) - return rescaled_clean_obj - - -from .utils import GeometryType, from_shapely, vertices_from_shapely # noqa: E402 diff --git a/tidy3d/components/geometry/bound_ops.py b/tidy3d/components/geometry/bound_ops.py index 4fd550a9e7..58aa09c5f8 100644 --- a/tidy3d/components/geometry/bound_ops.py +++ b/tidy3d/components/geometry/bound_ops.py @@ -1,71 +1,12 @@ -"""Geometry operations for bounding box type with minimal imports.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.bound_ops`.""" -from __future__ import annotations - -from math import isclose -from typing import TYPE_CHECKING - -from tidy3d.constants import fp_eps - -if TYPE_CHECKING: - from tidy3d.components.types import Bound - - -def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the intersection of two bounds.""" - rmin1, rmax1 = bounds1 - rmin2, rmax2 = bounds2 - rmin = tuple(max(v1, v2) for v1, v2 in zip(rmin1, rmin2)) - rmax = tuple(min(v1, v2) for v1, v2 in zip(rmax1, rmax2)) - return (rmin, rmax) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the union of two bounds.""" - rmin1, rmax1 = bounds1 - rmin2, rmax2 = bounds2 - rmin = tuple(min(v1, v2) for v1, v2 in zip(rmin1, rmin2)) - rmax = tuple(max(v1, v2) for v1, v2 in 
zip(rmax1, rmax2)) - return (rmin, rmax) - - -def bounds_contains( - outer_bounds: Bound, inner_bounds: Bound, rtol: float = fp_eps, atol: float = 0.0 -) -> bool: - """Checks whether ``inner_bounds`` is contained within ``outer_bounds`` within specified tolerances. - - Parameters - ---------- - outer_bounds : Bound - The outer bounds to check containment against - inner_bounds : Bound - The inner bounds to check if contained - rtol : float = fp_eps - Relative tolerance for comparing bounds - atol : float = 0.0 - Absolute tolerance for comparing bounds - - Returns - ------- - bool - True if ``inner_bounds`` is contained within ``outer_bounds`` within tolerances - """ - outer_min, outer_max = outer_bounds - inner_min, inner_max = inner_bounds - for dim in range(3): - outer_min_dim = outer_min[dim] - outer_max_dim = outer_max[dim] - inner_min_dim = inner_min[dim] - inner_max_dim = inner_max[dim] - within_min = ( - isclose(outer_min_dim, inner_min_dim, rel_tol=rtol, abs_tol=atol) - or outer_min_dim <= inner_min_dim - ) - within_max = ( - isclose(outer_max_dim, inner_max_dim, rel_tol=rtol, abs_tol=atol) - or outer_max_dim >= inner_max_dim - ) - - if not within_min or not within_max: - return False - return True +from tidy3d._common.components.geometry.bound_ops import ( + bounds_contains, + bounds_intersection, + bounds_union, +) diff --git a/tidy3d/components/geometry/float_utils.py b/tidy3d/components/geometry/float_utils.py index 5ab7b438be..a45674303e 100644 --- a/tidy3d/components/geometry/float_utils.py +++ b/tidy3d/components/geometry/float_utils.py @@ -1,31 +1,10 @@ -"""Utilities for float manipulation.""" - -from __future__ import annotations - -import numpy as np - -from tidy3d.constants import inf +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.float_utils`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def increment_float(val: float, sign: int) -> float: - """Applies a small positive or negative shift as 
though `val` is a 32bit float - using numpy.nextafter, but additionally handles some corner cases. - """ - # Infinity is left unchanged - if val == inf or val == -inf: - return val - - if sign >= 0: - sign = 1 - else: - sign = -1 - - # Avoid small increments within subnormal values - if np.abs(val) <= np.finfo(np.float32).tiny: - return val + sign * np.finfo(np.float32).tiny - - # Numpy seems to skip over the increment from -0.0 and +0.0 - # which is different from c++ - val_inc = np.nextafter(val, sign * inf, dtype=np.float32) +# marked as migrated to _common +from __future__ import annotations - return np.float32(val_inc) +from tidy3d._common.components.geometry.float_utils import ( + increment_float, +) diff --git a/tidy3d/components/geometry/mesh.py b/tidy3d/components/geometry/mesh.py index 3c7551ca39..c1a305e221 100644 --- a/tidy3d/components/geometry/mesh.py +++ b/tidy3d/components/geometry/mesh.py @@ -1,1286 +1,11 @@ -"""Mesh-defined geometry.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.mesh`.""" -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Any, Optional - -import numpy as np -from autograd import numpy as anp -from numpy.typing import NDArray -from pydantic import Field, PrivateAttr, field_validator, model_validator - -from tidy3d.components.autograd import get_static -from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import DATA_ARRAY_MAP, TriangleMeshDataArray -from tidy3d.components.data.dataset import TriangleMeshDataset -from tidy3d.components.data.validators import validate_no_nans -from tidy3d.components.viz import add_ax_if_none, equal_aspect -from tidy3d.config import config -from tidy3d.constants import fp_eps, inf -from tidy3d.exceptions import DataError, ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -from . 
import base - -if TYPE_CHECKING: - from os import PathLike - from typing import Callable, Literal, Union - - from trimesh import Trimesh - - from tidy3d.components.autograd import AutogradFieldMap - from tidy3d.components.autograd.derivative_utils import DerivativeInfo - from tidy3d.components.types import Ax, Bound, Coordinate, MatrixReal4x4, Shapely - -AREA_SIZE_THRESHOLD = 1e-36 - - -class TriangleMesh(base.Geometry, ABC): - """Custom surface geometry given by a triangle mesh, as in the STL file format. - - Example - ------- - >>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - >>> faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]]) - >>> stl_geom = TriangleMesh.from_vertices_faces(vertices, faces) - """ - - mesh_dataset: Optional[TriangleMeshDataset] = Field( - None, - title="Surface mesh data", - description="Surface mesh data.", - ) - - _no_nans_mesh = validate_no_nans("mesh_dataset") - _barycentric_samples: dict[int, NDArray] = PrivateAttr(default_factory=dict) - - @verify_packages_import(["trimesh"]) - @model_validator(mode="before") - @classmethod - def _validate_trimesh_library(cls, data: dict[str, Any]) -> dict[str, Any]: - """Check if the trimesh package is imported as a validator.""" - return data - - @field_validator("mesh_dataset", mode="before") - @classmethod - def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: - """Warn if the Dataset fails to load.""" - if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): - log.warning("Loading 'mesh_dataset' without data.") - return None - return val - - @field_validator("mesh_dataset") - @classmethod - def _check_mesh(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: - """Check that the mesh is valid.""" - if val is None: - return None - - import trimesh - - surface_mesh = val.surface_mesh - triangles = get_static(surface_mesh.data) - mesh = cls._triangles_to_trimesh(triangles) - if not 
all(np.array(mesh.area_faces) > AREA_SIZE_THRESHOLD): - old_tol = trimesh.tol.merge - trimesh.tol.merge = np.sqrt(2 * AREA_SIZE_THRESHOLD) - new_mesh = mesh.process(validate=True) - trimesh.tol.merge = old_tol - val = TriangleMesh.from_trimesh(new_mesh).mesh_dataset - log.warning( - f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " - "Triangles which have one edge of their 2D oriented bounding box shorter than " - f"'sqrt(2*{AREA_SIZE_THRESHOLD}) are being automatically removed.'" - ) - if not all(np.array(new_mesh.area_faces) > AREA_SIZE_THRESHOLD): - raise ValidationError( - f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " - "The automatic removal of these triangles has failed. You can try " - "using numpy-stl's 'from_file' import with 'remove_empty_areas' set " - "to True and a suitable 'AREA_SIZE_THRESHOLD' to remove them." - ) - if not mesh.is_watertight: - log.warning( - "The provided mesh is not watertight. " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fill_holes', which attempts to repair the mesh. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct. " - "You can see which faces are broken using 'trimesh.repair.broken_faces'." - ) - if not mesh.is_winding_consistent: - log.warning( - "The provided mesh does not have consistent winding (face orientations). " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fix_winding', which attempts to repair the mesh. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct. 
" - ) - if not mesh.is_volume: - log.warning( - "The provided mesh does not represent a valid volume, possibly due to " - "incorrect normal vector orientation. " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fix_normals', " - "which attempts to fix the normals to be consistent and outward-facing. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct." - ) - - return val - - @verify_packages_import(["trimesh"]) - def fix_winding(self) -> TriangleMesh: - """Try to fix winding in the mesh.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fix_winding(mesh) - return TriangleMesh.from_trimesh(mesh) - - @verify_packages_import(["trimesh"]) - def fill_holes(self) -> TriangleMesh: - """Try to fill holes in the mesh. Can be used to repair non-watertight meshes.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fill_holes(mesh) - return TriangleMesh.from_trimesh(mesh) - - @verify_packages_import(["trimesh"]) - def fix_normals(self) -> TriangleMesh: - """Try to fix normals to be consistent and outward-facing.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fix_normals(mesh) - return TriangleMesh.from_trimesh(mesh) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_stl( - cls, - filename: str, - scale: float = 1.0, - origin: tuple[float, float, float] = (0, 0, 0), - solid_index: Optional[int] = None, - **kwargs: Any, - ) -> Union[TriangleMesh, base.GeometryGroup]: - """Load a :class:`.TriangleMesh` directly from an STL file. - The ``solid_index`` parameter can be used to select a single solid from the file. 
- Otherwise, if the file contains a single solid, it will be loaded as a - :class:`.TriangleMesh`; if the file contains multiple solids, - they will all be loaded as a :class:`.GeometryGroup`. - - Parameters - ---------- - filename : str - The name of the STL file containing the surface geometry mesh data. - scale : float = 1.0 - The length scale for the loaded geometry (um). - For example, a scale of 10.0 means that a vertex (1, 0, 0) will be placed at - x = 10 um. - origin : tuple[float, float, float] = (0, 0, 0) - The origin of the loaded geometry, in units of ``scale``. - Translates from (0, 0, 0) to this point after applying the scaling. - solid_index : int = None - If set, read a single solid with this index from the file. - - Returns - ------- - Union[:class:`.TriangleMesh`, :class:`.GeometryGroup`] - The geometry or geometry group from the file. - """ - import trimesh - - from tidy3d.components.types.third_party import TrimeshType - - def process_single(mesh: TrimeshType) -> TriangleMesh: - """Process a single 'trimesh.Trimesh' using scale and origin.""" - mesh.apply_scale(scale) - mesh.apply_translation(origin) - return cls.from_trimesh(mesh) - - scene = trimesh.load(filename, **kwargs) - meshes = [] - if isinstance(scene, trimesh.Trimesh): - meshes = [scene] - elif isinstance(scene, trimesh.Scene): - meshes = scene.dump() - else: - raise ValidationError( - "Invalid trimesh type in file. Supported types are 'trimesh.Trimesh' " - "and 'trimesh.Scene'." 
- ) - - if solid_index is None: - if isinstance(scene, trimesh.Trimesh): - return process_single(scene) - if isinstance(scene, trimesh.Scene): - geoms = [process_single(mesh) for mesh in meshes] - return base.GeometryGroup(geometries=geoms) - - if solid_index < len(meshes): - return process_single(meshes[solid_index]) - raise ValidationError("No solid found at 'solid_index' in the stl file.") - - @verify_packages_import(["trimesh"]) - def to_stl( - self, - filename: PathLike, - *, - binary: bool = True, - ) -> None: - """Export this TriangleMesh to an STL file. - - Parameters - ---------- - filename : str - Output STL filename. - binary : bool = True - Whether to write binary STL. Set False for ASCII STL. - """ - triangles = get_static(self.mesh_dataset.surface_mesh.data) - mesh = self._triangles_to_trimesh(triangles) - - file_type = "stl" if binary else "stl_ascii" - mesh.export(file_obj=filename, file_type=file_type) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_trimesh(cls, mesh: Trimesh) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from a ``trimesh.Trimesh`` object. - - Parameters - ---------- - trimesh : ``trimesh.Trimesh`` - The Trimesh object containing the surface geometry mesh data. - - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the ``trimesh.Trimesh`` provided. - """ - return cls.from_vertices_faces(mesh.vertices, mesh.faces) - - @classmethod - def from_triangles(cls, triangles: NDArray) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from a numpy array - containing the triangles of a surface mesh. - - Parameters - ---------- - triangles : ``np.ndarray`` - A numpy array of shape (N, 3, 3) storing the triangles of the surface mesh. - The first index labels the triangle, the second index labels the vertex - within a given triangle, and the third index is the coordinate (x, y, or z). 
- - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the triangles provided. - - """ - triangles = anp.array(triangles) - if len(triangles.shape) != 3 or triangles.shape[1] != 3 or triangles.shape[2] != 3: - raise ValidationError( - f"Provided 'triangles' must be an N x 3 x 3 array, given {triangles.shape}." - ) - num_faces = len(triangles) - coords = { - "face_index": np.arange(num_faces), - "vertex_index": np.arange(3), - "axis": np.arange(3), - } - vertices = TriangleMeshDataArray(triangles, coords=coords) - mesh_dataset = TriangleMeshDataset(surface_mesh=vertices) - return TriangleMesh(mesh_dataset=mesh_dataset) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_vertices_faces(cls, vertices: NDArray, faces: NDArray) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from numpy arrays containing the data - of a surface mesh. The first array contains the vertices, and the second array contains - faces formed from triples of the vertices. - - Parameters - ---------- - vertices: ``np.ndarray`` - A numpy array of shape (N, 3) storing the vertices of the surface mesh. - The first index labels the vertex, and the second index is the coordinate - (x, y, or z). - faces : ``np.ndarray`` - A numpy array of shape (M, 3) storing the indices of the vertices of each face - in the surface mesh. The first index labels the face, and the second index - labels the vertex index within the ``vertices`` array. - - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the vertices and faces provided. - - """ - import trimesh - - vertices = np.array(vertices) - faces = np.array(faces) - if len(vertices.shape) != 2 or vertices.shape[1] != 3: - raise ValidationError( - f"Provided 'vertices' must be an N x 3 array, given {vertices.shape}." 
- ) - if len(faces.shape) != 2 or faces.shape[1] != 3: - raise ValidationError(f"Provided 'faces' must be an M x 3 array, given {faces.shape}.") - return cls.from_triangles(trimesh.Trimesh(vertices, faces).triangles) - - @classmethod - @verify_packages_import(["trimesh"]) - def _triangles_to_trimesh( - cls, triangles: NDArray - ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) - """Convert an (N, 3, 3) numpy array of triangles to a ``trimesh.Trimesh``.""" - import trimesh - - # ``triangles`` may contain autograd ``ArrayBox`` entries when differentiating - # geometry parameters. ``trimesh`` expects plain ``float`` values, so strip any - # tracing information before constructing the mesh. - triangles = get_static(anp.array(triangles)) - return trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles)) - - @classmethod - def from_height_grid( - cls, - axis: Ax, - direction: Literal["-", "+"], - base: float, - grid: tuple[np.ndarray, np.ndarray], - height: NDArray, - ) -> TriangleMesh: - """Construct a TriangleMesh object from grid based height information. - - Parameters - ---------- - axis : Ax - Axis of extrusion. - direction : Literal["-", "+"] - Direction of extrusion. - base : float - Coordinate of the base surface along the geometry's axis. - grid : Tuple[np.ndarray, np.ndarray] - Tuple of two one-dimensional arrays representing the sampling grid (XY, YZ, or ZX - corresponding to values of axis) - height : np.ndarray - Height values sampled on the given grid. Can be 1D (raveled) or 2D (matching grid mesh). - - Returns - ------- - TriangleMesh - The resulting TriangleMesh geometry object. 
- """ - - x_coords = grid[0] - y_coords = grid[1] - - nx = len(x_coords) - ny = len(y_coords) - nt = nx * ny - - x_mesh, y_mesh = np.meshgrid(x_coords, y_coords, indexing="ij") - - sign = 1 - if direction == "-": - sign = -1 - - flat_height = np.ravel(height) - if flat_height.shape[0] != nt: - raise ValueError( - f"Shape of flattened height array {flat_height.shape} does not match " - f"the number of grid points {nt}." - ) - - if np.any(flat_height < 0): - raise ValueError("All height values must be non-negative.") - - max_h = np.max(flat_height) - min_h_clip = fp_eps * max_h - flat_height = np.clip(flat_height, min_h_clip, inf) - - vertices_raw_list = [ - [np.ravel(x_mesh), np.ravel(y_mesh), base + sign * flat_height], # Alpha surface - [np.ravel(x_mesh), np.ravel(y_mesh), base * np.ones(nt)], - ] - - if direction == "-": - vertices_raw_list = vertices_raw_list[::-1] - - vertices = np.hstack(vertices_raw_list).T - vertices = np.roll(vertices, shift=axis - 2, axis=1) - - q0 = (np.arange(nx - 1)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() - q1 = (np.arange(1, nx)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() - q2 = (np.arange(1, nx)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() - q3 = (np.arange(nx - 1)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() - - q0_b = nt + q0 - q1_b = nt + q1 - q2_b = nt + q2 - q3_b = nt + q3 - - top_quads = np.stack((q0, q1, q2, q3), axis=-1) - bottom_quads = np.stack((q0_b, q3_b, q2_b, q1_b), axis=-1) - - s1_q0 = (0 * ny + np.arange(ny - 1)).ravel() - s1_q1 = (0 * ny + np.arange(1, ny)).ravel() - s1_q2 = (nt + 0 * ny + np.arange(1, ny)).ravel() - s1_q3 = (nt + 0 * ny + np.arange(ny - 1)).ravel() - side1_quads = np.stack((s1_q0, s1_q1, s1_q2, s1_q3), axis=-1) - - s2_q0 = ((nx - 1) * ny + np.arange(ny - 1)).ravel() - s2_q1 = (nt + (nx - 1) * ny + np.arange(ny - 1)).ravel() - s2_q2 = (nt + (nx - 1) * ny + np.arange(1, ny)).ravel() - s2_q3 = ((nx - 1) * ny + np.arange(1, ny)).ravel() - side2_quads = np.stack((s2_q0, 
s2_q1, s2_q2, s2_q3), axis=-1) - - s3_q0 = (np.arange(nx - 1) * ny + 0).ravel() - s3_q1 = (nt + np.arange(nx - 1) * ny + 0).ravel() - s3_q2 = (nt + np.arange(1, nx) * ny + 0).ravel() - s3_q3 = (np.arange(1, nx) * ny + 0).ravel() - side3_quads = np.stack((s3_q0, s3_q1, s3_q2, s3_q3), axis=-1) - - s4_q0 = (np.arange(nx - 1) * ny + ny - 1).ravel() - s4_q1 = (np.arange(1, nx) * ny + ny - 1).ravel() - s4_q2 = (nt + np.arange(1, nx) * ny + ny - 1).ravel() - s4_q3 = (nt + np.arange(nx - 1) * ny + ny - 1).ravel() - side4_quads = np.stack((s4_q0, s4_q1, s4_q2, s4_q3), axis=-1) - - all_quads = np.vstack( - (top_quads, bottom_quads, side1_quads, side2_quads, side3_quads, side4_quads) - ) - - triangles_list = [ - np.stack((all_quads[:, 0], all_quads[:, 1], all_quads[:, 3]), axis=-1), - np.stack((all_quads[:, 3], all_quads[:, 1], all_quads[:, 2]), axis=-1), - ] - tri_faces = np.vstack(triangles_list) - - return cls.from_vertices_faces(vertices=vertices, faces=tri_faces) - - @classmethod - def from_height_function( - cls, - axis: Ax, - direction: Literal["-", "+"], - base: float, - center: tuple[float, float], - size: tuple[float, float], - grid_size: tuple[int, int], - height_func: Callable[[np.ndarray, np.ndarray], np.ndarray], - ) -> TriangleMesh: - """Construct a TriangleMesh object from analytical expression of height function. - The height function should be vectorized to accept 2D meshgrid arrays. - - Parameters - ---------- - axis : Ax - Axis of extrusion. - direction : Literal["-", "+"] - Direction of extrusion. - base : float - Coordinate of the base rectangle along the geometry's axis. - center : Tuple[float, float] - Center of the base rectangle in the plane perpendicular to the extrusion axis - (XY, YZ, or ZX corresponding to values of axis). - size : Tuple[float, float] - Size of the base rectangle in the plane perpendicular to the extrusion axis - (XY, YZ, or ZX corresponding to values of axis). 
- grid_size : Tuple[int, int] - Number of grid points for discretization of the base rectangle - (XY, YZ, or ZX corresponding to values of axis). - height_func : Callable[[np.ndarray, np.ndarray], np.ndarray] - Vectorized function to compute height values from 2D meshgrid coordinate arrays. - It should take two ndarrays (x_mesh, y_mesh) and return an ndarray of heights. - - Returns - ------- - TriangleMesh - The resulting TriangleMesh geometry object. - """ - x_lin = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0]) - y_lin = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1]) - - x_mesh, y_mesh = np.meshgrid(x_lin, y_lin, indexing="ij") - - height_values = height_func(x_mesh, y_mesh) - - if not (isinstance(height_values, np.ndarray) and height_values.shape == x_mesh.shape): - raise ValueError( - f"The 'height_func' must return a NumPy array with shape {x_mesh.shape}, " - f"but got shape {getattr(height_values, 'shape', type(height_values))}." 
- ) - - return cls.from_height_grid( - axis=axis, - direction=direction, - base=base, - grid=(x_lin, y_lin), - height=height_values, - ) - - @cached_property - @verify_packages_import(["trimesh"]) - def trimesh( - self, - ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) - """A ``trimesh.Trimesh`` object representing the custom surface mesh geometry.""" - return self._triangles_to_trimesh(self.triangles) - - @cached_property - def triangles(self) -> np.ndarray: - """The triangles of the surface mesh as an ``np.ndarray``.""" - if self.mesh_dataset is None: - raise DataError("Can't get triangles as 'mesh_dataset' is None.") - return np.asarray(get_static(self.mesh_dataset.surface_mesh.data)) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - # currently ignores bounds - return self.trimesh.area - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # currently ignores bounds - return self.trimesh.volume - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - tuple[float, float, float], tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - if self.mesh_dataset is None: - return ((-inf, -inf, -inf), (inf, inf, inf)) - return self.trimesh.bounds - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. 
- to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for TriangleMesh. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - section = self.trimesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for TriangleMesh. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentaton `_. 
- """ - - if self.mesh_dataset is None: - return [] - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - - origin = self.unpop_axis(position, (0, 0), axis=axis) - normal = self.unpop_axis(1, (0, 0), axis=axis) - - mesh = self.trimesh - - try: - section = mesh.section(plane_origin=origin, plane_normal=normal) - - if section is None: - return [] +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # homogeneous transformation matrix to map to xy plane - mapping = np.eye(4) - - # translate to origin - mapping[3, :3] = -np.array(origin) - - # permute so normal is aligned with z axis - # and (y, z), (x, z), resp. (x, y) are aligned with (x, y) - identity = np.eye(3) - permutation = self.unpop_axis(identity[2], identity[0:2], axis=axis) - mapping[:3, :3] = np.array(permutation).T - - section2d, _ = section.to_2D(to_2D=mapping) - return list(section2d.polygons_full) - - except ValueError as e: - if not mesh.is_watertight: - log.warning( - "Unable to compute 'TriangleMesh.intersections_plane' " - "because the mesh was not watertight. Using bounding box instead. " - "This may be overly strict; consider using 'TriangleMesh.fill_holes' " - "to repair the non-watertight mesh." - ) - else: - log.warning( - "Unable to compute 'TriangleMesh.intersections_plane'. " - "Using bounding box instead." - ) - log.warning(f"Error encountered: {e}") - return self.bounding_box.intersections_plane(x=x, y=y, z=z, cleanup=cleanup) - - def inside(self, x: NDArray, y: NDArray, z: NDArray) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. 
- z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - - arrays = tuple(map(np.array, (x, y, z))) - self._ensure_equal_shape(*arrays) - arrays_flat = map(np.ravel, arrays) - arrays_stacked = np.stack(tuple(arrays_flat), axis=-1) - inside = self.trimesh.contains(arrays_stacked) - return inside.reshape(arrays[0].shape) - - @equal_aspect - @add_ax_if_none - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot geometry cross section at single (x,y,z) coordinate. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - **patch_kwargs - Optional keyword arguments passed to the matplotlib patch plotting of structure. - For details on accepted values, refer to - `Matplotlib's documentation `_. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - - log.warning( - "Plotting a 'TriangleMesh' may give inconsistent results " - "if the mesh is not unionized. We recommend unionizing all meshes before import. " - "A 'PermittivityMonitor' can be used to check that the mesh is loaded correctly." 
- ) - - return base.Geometry.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint derivatives for a ``TriangleMesh`` geometry.""" - vjps: AutogradFieldMap = {} - - if not self.mesh_dataset: - raise DataError("Can't compute derivatives without mesh data.") - - valid_paths = {("mesh_dataset", "surface_mesh")} - for path in derivative_info.paths: - if path not in valid_paths: - raise ValueError(f"No derivative defined w.r.t. 'TriangleMesh' field '{path}'.") - - if ("mesh_dataset", "surface_mesh") not in derivative_info.paths: - return vjps - - triangles = np.asarray(self.triangles, dtype=config.adjoint.gradient_dtype_float) - - # early exit if geometry is completely outside simulation bounds - sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) - mesh_min, mesh_max = map(np.asarray, self.bounds) - if np.any(mesh_max < sim_min) or np.any(mesh_min > sim_max): - log.warning( - "'TriangleMesh' lies completely outside the simulation domain.", - log_once=True, - ) - zeros = np.zeros_like(triangles) - vjps[("mesh_dataset", "surface_mesh")] = zeros - return vjps - - # gather surface samples within the simulation bounds - dx = derivative_info.adaptive_vjp_spacing() - samples = self._collect_surface_samples( - triangles=triangles, - spacing=dx, - sim_min=sim_min, - sim_max=sim_max, - ) - - if samples["points"].shape[0] == 0: - zeros = np.zeros_like(triangles) - vjps[("mesh_dataset", "surface_mesh")] = zeros - return vjps - - interpolators = derivative_info.interpolators - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - g = derivative_info.evaluate_gradient_at_points( - samples["points"], - samples["normals"], - samples["perps1"], - samples["perps2"], - interpolators, - ) - - # accumulate per-vertex contributions using barycentric weights - weights = (samples["weights"] * 
g).real - normals = samples["normals"] - faces = samples["faces"] - bary = samples["barycentric"] - - contrib_vec = weights[:, None] * normals - - triangle_grads = np.zeros_like(triangles, dtype=config.adjoint.gradient_dtype_float) - for vertex_idx in range(3): - scaled = contrib_vec * bary[:, vertex_idx][:, None] - np.add.at(triangle_grads[:, vertex_idx, :], faces, scaled) - - vjps[("mesh_dataset", "surface_mesh")] = triangle_grads - return vjps - - def _collect_surface_samples( - self, - triangles: NDArray, - spacing: float, - sim_min: NDArray, - sim_max: NDArray, - ) -> dict[str, np.ndarray]: - """Deterministic per-triangle sampling used historically.""" - - dtype = config.adjoint.gradient_dtype_float - tol = config.adjoint.edge_clip_tolerance - - sim_min = np.asarray(sim_min, dtype=dtype) - sim_max = np.asarray(sim_max, dtype=dtype) - - points_list: list[np.ndarray] = [] - normals_list: list[np.ndarray] = [] - perps1_list: list[np.ndarray] = [] - perps2_list: list[np.ndarray] = [] - weights_list: list[np.ndarray] = [] - faces_list: list[np.ndarray] = [] - bary_list: list[np.ndarray] = [] - - spacing = max(float(spacing), np.finfo(float).eps) - triangles_arr = np.asarray(triangles, dtype=dtype) - - sim_extents = sim_max - sim_min - valid_axes = np.abs(sim_extents) > tol - collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol)) - collapsed_axis: Optional[int] = None - plane_value: Optional[float] = None - if collapsed_indices.size == 1: - collapsed_axis = int(collapsed_indices[0]) - plane_value = float(sim_min[collapsed_axis]) - - warned = False - warning_msg = "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
- for face_index, tri in enumerate(triangles_arr): - area, normal = self._triangle_area_and_normal(tri) - if area <= AREA_SIZE_THRESHOLD: - continue - - perps = self._triangle_tangent_basis(tri, normal) - if perps is None: - continue - perp1, perp2 = perps - - if collapsed_axis is not None and plane_value is not None: - samples, outside_bounds = self._collect_surface_samples_2d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - spacing=spacing, - collapsed_axis=collapsed_axis, - plane_value=plane_value, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - else: - samples, outside_bounds = self._collect_surface_samples_3d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - area=area, - spacing=spacing, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - - if outside_bounds and not warned: - log.warning(warning_msg) - warned = True - - if samples is None: - continue - - points_list.append(samples["points"]) - normals_list.append(samples["normals"]) - perps1_list.append(samples["perps1"]) - perps2_list.append(samples["perps2"]) - weights_list.append(samples["weights"]) - faces_list.append(samples["faces"]) - bary_list.append(samples["barycentric"]) - - if not points_list: - return { - "points": np.zeros((0, 3), dtype=dtype), - "normals": np.zeros((0, 3), dtype=dtype), - "perps1": np.zeros((0, 3), dtype=dtype), - "perps2": np.zeros((0, 3), dtype=dtype), - "weights": np.zeros((0,), dtype=dtype), - "faces": np.zeros((0,), dtype=int), - "barycentric": np.zeros((0, 3), dtype=dtype), - } - - return { - "points": np.concatenate(points_list, axis=0), - "normals": np.concatenate(normals_list, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights_list, axis=0), - "faces": np.concatenate(faces_list, axis=0), - "barycentric": 
np.concatenate(bary_list, axis=0), - } - - def _collect_surface_samples_2d( - self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, - spacing: float, - collapsed_axis: int, - plane_value: float, - sim_min: np.ndarray, - sim_max: np.ndarray, - valid_axes: np.ndarray, - tol: float, - dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds collapse onto a 2D plane.""" - - segments = self._triangle_plane_segments( - triangle=triangle, axis=collapsed_axis, plane_value=plane_value, tol=tol - ) - - points: list[np.ndarray] = [] - normals: list[np.ndarray] = [] - perps1_list: list[np.ndarray] = [] - perps2_list: list[np.ndarray] = [] - weights: list[np.ndarray] = [] - faces: list[np.ndarray] = [] - barycentric: list[np.ndarray] = [] - outside_bounds = False - - for start, end in segments: - vec = end - start - length = float(np.linalg.norm(vec)) - if length <= tol: - continue - - subdivisions = max(1, int(np.ceil(length / spacing))) - t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions - sample_points = start[None, :] + t_vals[:, None] * vec[None, :] - bary = self._barycentric_coordinates(triangle, sample_points, tol) - - inside_mask = np.ones(sample_points.shape[0], dtype=bool) - if np.any(valid_axes): - min_bound = (sim_min - tol)[valid_axes] - max_bound = (sim_max + tol)[valid_axes] - coords = sample_points[:, valid_axes] - inside_mask = np.all(coords >= min_bound, axis=1) & np.all( - coords <= max_bound, axis=1 - ) - - outside_bounds = outside_bounds or (not np.all(inside_mask)) - if not np.any(inside_mask): - continue - - sample_points = sample_points[inside_mask] - bary_inside = bary[inside_mask] - n_inside = sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) - weights_tile = 
np.full(n_inside, length / subdivisions, dtype=dtype) - faces_tile = np.full(n_inside, face_index, dtype=int) - - points.append(sample_points) - normals.append(normal_tile) - perps1_list.append(perp1_tile) - perps2_list.append(perp2_tile) - weights.append(weights_tile) - faces.append(faces_tile) - barycentric.append(bary_inside) - - if not points: - return None, outside_bounds - - samples = { - "points": np.concatenate(points, axis=0), - "normals": np.concatenate(normals, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights, axis=0), - "faces": np.concatenate(faces, axis=0), - "barycentric": np.concatenate(barycentric, axis=0), - } - return samples, outside_bounds - - def _collect_surface_samples_3d( - self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, - area: float, - spacing: float, - sim_min: np.ndarray, - sim_max: np.ndarray, - valid_axes: np.ndarray, - tol: float, - dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds represent a full 3D region.""" - - edge_lengths = ( - np.linalg.norm(triangle[1] - triangle[0]), - np.linalg.norm(triangle[2] - triangle[1]), - np.linalg.norm(triangle[0] - triangle[2]), - ) - subdivisions = self._subdivision_count(area, spacing, edge_lengths) - barycentric = self._get_barycentric_samples(subdivisions, dtype) - num_samples = barycentric.shape[0] - base_weight = area / num_samples - - sample_points = barycentric @ triangle - - inside_mask = np.all( - sample_points[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 - ) & np.all(sample_points[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) - outside_bounds = not np.all(inside_mask) - if not np.any(inside_mask): - return None, outside_bounds - - sample_points = sample_points[inside_mask] - bary_inside = barycentric[inside_mask] - n_samples_inside = 
sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_samples_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_samples_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_samples_inside, axis=0) - weights_tile = np.full(n_samples_inside, base_weight, dtype=dtype) - faces_tile = np.full(n_samples_inside, face_index, dtype=int) - - samples = { - "points": sample_points, - "normals": normal_tile, - "perps1": perp1_tile, - "perps2": perp2_tile, - "weights": weights_tile, - "faces": faces_tile, - "barycentric": bary_inside, - } - return samples, outside_bounds - - @staticmethod - def _triangle_area_and_normal(triangle: NDArray) -> tuple[float, np.ndarray]: - """Return area and outward normal of the provided triangle.""" - - edge01 = triangle[1] - triangle[0] - edge02 = triangle[2] - triangle[0] - cross = np.cross(edge01, edge02) - norm = np.linalg.norm(cross) - if norm <= 0.0: - return 0.0, np.zeros(3, dtype=triangle.dtype) - normal = (cross / norm).astype(triangle.dtype, copy=False) - area = 0.5 * norm - return area, normal - - @staticmethod - def _triangle_plane_segments( - triangle: NDArray, axis: int, plane_value: float, tol: float - ) -> list[tuple[np.ndarray, np.ndarray]]: - """Return intersection segments between a triangle and an axis-aligned plane.""" - - vertices = np.asarray(triangle) - distances = vertices[:, axis] - plane_value - edges = ((0, 1), (1, 2), (2, 0)) - - segments: list[tuple[np.ndarray, np.ndarray]] = [] - points: list[np.ndarray] = [] - - def add_point(pt: np.ndarray) -> None: - for existing in points: - if np.linalg.norm(existing - pt) <= tol: - return - points.append(pt.copy()) - - for i, j in edges: - di = distances[i] - dj = distances[j] - vi = vertices[i] - vj = vertices[j] - - if abs(di) <= tol and abs(dj) <= tol: - segments.append((vi.copy(), vj.copy())) - continue - - if di * dj > 0.0: - continue - - if abs(di) <= tol: - add_point(vi) - continue - - if abs(dj) <= tol: - add_point(vj) - continue - - denom 
= di - dj - if abs(denom) <= tol: - continue - t = di / denom - if t < 0.0 or t > 1.0: - continue - point = vi + t * (vj - vi) - add_point(point) - - if segments: - return segments - - if len(points) >= 2: - return [(points[0], points[1])] - - return [] - - @staticmethod - def _barycentric_coordinates(triangle: NDArray, points: np.ndarray, tol: float) -> np.ndarray: - """Compute barycentric coordinates of ``points`` with respect to ``triangle``.""" - - pts = np.asarray(points, dtype=triangle.dtype) - v0 = triangle[0] - v1 = triangle[1] - v2 = triangle[2] - v0v1 = v1 - v0 - v0v2 = v2 - v0 - - d00 = float(np.dot(v0v1, v0v1)) - d01 = float(np.dot(v0v1, v0v2)) - d11 = float(np.dot(v0v2, v0v2)) - denom = d00 * d11 - d01 * d01 - if abs(denom) <= tol: - return np.tile( - np.array([1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], dtype=triangle.dtype), (pts.shape[0], 1) - ) - - v0p = pts - v0 - d20 = v0p @ v0v1 - d21 = v0p @ v0v2 - v = (d11 * d20 - d01 * d21) / denom - w = (d00 * d21 - d01 * d20) / denom - u = 1.0 - v - w - bary = np.stack((u, v, w), axis=1) - return bary.astype(triangle.dtype, copy=False) - - @classmethod - def _subdivision_count( - cls, - area: float, - spacing: float, - edge_lengths: Optional[tuple[float, float, float]] = None, - ) -> int: - """Determine the number of subdivisions needed for the given area and spacing.""" - - spacing = max(float(spacing), np.finfo(float).eps) - - target = np.sqrt(max(area, 0.0)) - area_based = np.ceil(np.sqrt(2.0) * target / spacing) - - edge_based = 0.0 - if edge_lengths: - max_edge = max(edge_lengths) - if max_edge > 0.0: - edge_based = np.ceil(max_edge / spacing) - - subdivisions = max(1, int(max(area_based, edge_based))) - return subdivisions - - def _get_barycentric_samples(self, subdivisions: int, dtype: np.dtype) -> np.ndarray: - """Return barycentric sample coordinates for a subdivision level.""" - - cache = self._barycentric_samples - if subdivisions not in cache: - cache[subdivisions] = 
self._build_barycentric_samples(subdivisions) - return cache[subdivisions].astype(dtype, copy=False) - - @staticmethod - def _build_barycentric_samples(subdivisions: int) -> np.ndarray: - """Construct barycentric sampling points for a given subdivision level.""" - - if subdivisions <= 1: - return np.array([[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]]) - - bary = [] - for i in range(subdivisions): - for j in range(subdivisions - i): - l1 = (i + 1.0 / 3.0) / subdivisions - l2 = (j + 1.0 / 3.0) / subdivisions - l0 = 1.0 - l1 - l2 - bary.append((l0, l1, l2)) - return np.asarray(bary, dtype=float) - - @staticmethod - def subdivide_faces(vertices: NDArray, faces: NDArray) -> tuple[np.ndarray, np.ndarray]: - """Uniformly subdivide each triangular face by inserting edge midpoints.""" - - midpoint_cache: dict[tuple[int, int], int] = {} - verts_list = [np.asarray(v, dtype=float) for v in vertices] - - def midpoint(i: int, j: int) -> int: - key = (i, j) if i < j else (j, i) - if key in midpoint_cache: - return midpoint_cache[key] - vm = 0.5 * (verts_list[i] + verts_list[j]) - verts_list.append(vm) - idx = len(verts_list) - 1 - midpoint_cache[key] = idx - return idx - - new_faces: list[tuple[int, int, int]] = [] - for tri in faces: - a = midpoint(tri[0], tri[1]) - b = midpoint(tri[1], tri[2]) - c = midpoint(tri[2], tri[0]) - new_faces.extend(((tri[0], a, c), (tri[1], b, a), (tri[2], c, b), (a, b, c))) - - verts_arr = np.asarray(verts_list, dtype=float) - return verts_arr, np.asarray(new_faces, dtype=int) - - @staticmethod - def _triangle_tangent_basis( - triangle: NDArray, normal: NDArray - ) -> Optional[tuple[np.ndarray, np.ndarray]]: - """Compute orthonormal tangential vectors for a triangle.""" - - tol = np.finfo(triangle.dtype).eps - edges = [triangle[1] - triangle[0], triangle[2] - triangle[0], triangle[2] - triangle[1]] - - edge = None - for candidate in edges: - length = np.linalg.norm(candidate) - if length > tol: - edge = (candidate / length).astype(triangle.dtype, copy=False) 
- break - - if edge is None: - return None +# marked as migrated to _common +from __future__ import annotations - perp1 = edge - perp2 = np.cross(normal, perp1) - perp2_norm = np.linalg.norm(perp2) - if perp2_norm <= tol: - return None - perp2 = (perp2 / perp2_norm).astype(triangle.dtype, copy=False) - return perp1, perp2 +from tidy3d._common.components.geometry.mesh import ( + AREA_SIZE_THRESHOLD, + TriangleMesh, +) diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 0f077889cb..f433226fd0 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -1,2775 +1,17 @@ -"""Geometry extruded from polygonal shapes.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.polyslab`.""" -from __future__ import annotations - -import math -from copy import copy -from functools import lru_cache -from typing import TYPE_CHECKING, Any - -import autograd.numpy as np -import shapely -from autograd.tracer import getval -from numpy.polynomial.legendre import leggauss as _leggauss -from pydantic import Field, field_validator, model_validator - -from tidy3d.components.autograd import TracedArrayFloat2D, get_static -from tidy3d.components.autograd.types import TracedFloat -from tidy3d.components.autograd.utils import hasbox -from tidy3d.components.base import cached_property -from tidy3d.components.transformation import ReflectionFromPlane, RotationAroundAxis -from tidy3d.config import config -from tidy3d.constants import LARGE_NUMBER, MICROMETER, fp_eps -from tidy3d.exceptions import SetupError, Tidy3dImportError, ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -from . 
import base, triangulation - -if TYPE_CHECKING: - from typing import Optional, Union - - from gdstk import Cell - from numpy.typing import NDArray - from pydantic import PositiveFloat - - from tidy3d.compat import Self - from tidy3d.components.autograd import AutogradFieldMap - from tidy3d.components.autograd.derivative_utils import DerivativeInfo - from tidy3d.components.types import ( - ArrayFloat1D, - ArrayFloat2D, - ArrayLike, - Axis, - Bound, - Coordinate, - MatrixReal4x4, - PlanePosition, - Shapely, - ) - -# sampling polygon along dilation for validating polygon to be -# non self-intersecting during the entire dilation process -_N_SAMPLE_POLYGON_INTERSECT = 5 - -_IS_CLOSE_RTOL = np.finfo(float).eps - -# Warn for too many divided polyslabs -_COMPLEX_POLYSLAB_DIVISIONS_WARN = 100 - -# Warn before triangulating large polyslabs due to inefficiency -_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION = 500 - -_MIN_POLYGON_AREA = fp_eps - - -@lru_cache(maxsize=128) -def leggauss(n: int) -> tuple[NDArray, NDArray]: - """Cached version of leggauss with dtype conversions.""" - g, w = _leggauss(n) - return g.astype(config.adjoint.gradient_dtype_float, copy=False), w.astype( - config.adjoint.gradient_dtype_float, copy=False - ) - - -class PolySlab(base.Planar): - """Polygon extruded with optional sidewall angle along axis direction. 
- - Example - ------- - >>> vertices = np.array([(0,0), (1,0), (1,1)]) - >>> p = PolySlab(vertices=vertices, axis=2, slab_bounds=(-1, 1)) - """ - - slab_bounds: tuple[TracedFloat, TracedFloat] = Field( - title="Slab Bounds", - description="Minimum and maximum positions of the slab along axis dimension.", - json_schema_extra={"units": MICROMETER}, - ) - - dilation: float = Field( - 0.0, - title="Dilation", - description="Dilation of the supplied polygon by shifting each edge along its " - "normal outwards direction by a distance; a negative value corresponds to erosion.", - json_schema_extra={"units": MICROMETER}, - ) - - vertices: TracedArrayFloat2D = Field( - title="Vertices", - description="List of (d1, d2) defining the 2 dimensional positions of the polygon " - "face vertices at the ``reference_plane``. " - "The index of dimension should be in the ascending order: e.g. if " - "the slab normal axis is ``axis=y``, the coordinate of the vertices will be in (x, z)", - json_schema_extra={"units": MICROMETER}, - ) - - @staticmethod - def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: - """Make a shapely polygon from some vertices, first ensures they are untraced.""" - vertices = get_static(vertices) - return shapely.Polygon(vertices) - - @field_validator("slab_bounds") - @classmethod - def slab_bounds_order(cls, val: tuple[float, float]) -> tuple[float, float]: - """Maximum position of the slab should be no smaller than its minimal position.""" - if val[1] < val[0]: - raise SetupError( - "Polyslab.slab_bounds must be specified in the order of " - "minimum and maximum positions of the slab along the axis. " - f"But now the maximum {val[1]} is smaller than the minimum {val[0]}." - ) - return val - - @field_validator("vertices") - @classmethod - def correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D: - """Makes sure vertices size is correct. 
Make sure no intersecting edges.""" - # overall shape of vertices - if val.shape[1] != 2: - raise SetupError( - "PolySlab.vertices must be a 2 dimensional array shaped (N, 2). " - f"Given array with shape of {val.shape}." - ) - # make sure no polygon splitting, islands, 0 area - poly_heal = shapely.make_valid(cls.make_shapely_polygon(val)) - if poly_heal.area < _MIN_POLYGON_AREA: - raise SetupError("The polygon almost collapses to a 1D curve.") - - if not poly_heal.geom_type == "Polygon" or len(poly_heal.interiors) > 0: - raise SetupError( - "Polygon is self-intersecting, resulting in " - "polygon splitting or generation of holes/islands. " - "A general treatment to self-intersecting polygon will be available " - "in future releases." - ) - return val - - @model_validator(mode="after") - def no_complex_self_intersecting_polygon_at_reference_plane(self: Self) -> Self: - """At the reference plane, check if the polygon is self-intersecting. - - There are two types of self-intersection that can occur during dilation: - 1) the one that creates holes/islands, or splits polygons, or removes everything; - 2) the one that does not. - - For 1), we issue an error since it is yet to be supported; - For 2), we heal the polygon, and warn that the polygon has been cleaned up. - """ - val = self.vertices - # no need to validate anything here - if math.isclose(self.dilation, 0): - return self - - val_np = PolySlab._proper_vertices(val) - dist = self.dilation - - # 0) fully eroded - if dist < 0 and dist < -PolySlab._maximal_erosion(val_np): - raise SetupError("Erosion value is too large. The polygon is fully eroded.") - - # no edge events - if not PolySlab._edge_events_detection(val_np, dist, ignore_at_dist=False): - return self - - poly_offset = PolySlab._shift_vertices(val_np, dist)[0] - if PolySlab._area(poly_offset) < fp_eps**2: - raise SetupError("Erosion value is too large. 
The polygon is fully eroded.") - - # edge events - poly_offset = shapely.make_valid(self.make_shapely_polygon(poly_offset)) - # 1) polygon split or create holes/islands - if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: - raise SetupError( - "Dilation/Erosion value is too large, resulting in " - "polygon splitting or generation of holes/islands. " - "A general treatment to self-intersecting polygon will be available " - "in future releases." - ) - - # case 2 - log.warning( - "The dilation/erosion value is too large. resulting in a " - "self-intersecting polygon. " - "The vertices have been modified to make a valid polygon." - ) - return self - - @model_validator(mode="after") - def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self: - """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that - any normal cross section of the PolySlab cannot be self-intersecting. This part checks - if any self-interction will occur during extrusion with non-zero sidewall angle. - - There are two types of self-intersection, known as edge events, - that can occur during dilation: - 1) neighboring vertex-vertex crossing. This type of edge event can be treated with - ``ComplexPolySlab`` which divides the polyslab into a list of simple polyslabs. - - 2) other types of edge events that can create holes/islands or split polygons. - To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation - of polygons/holes, and changes in vertices number. 
- """ - val = self.vertices - - # no need to validate anything here - # sidewall_angle may be autograd-traced; use static value for this check only - if math.isclose(getval(self.sidewall_angle), 0): - return self - - # apply dilation - poly_ref = PolySlab._proper_vertices(val) - if not math.isclose(self.dilation, 0): - poly_ref = PolySlab._shift_vertices(poly_ref, self.dilation)[0] - poly_ref = PolySlab._heal_polygon(poly_ref) - - slab_bounds = get_static(self.slab_bounds) - slab_min, slab_max = slab_bounds - - # first, check vertex-vertex crossing at any point during extrusion - length = slab_bounds[1] - slab_bounds[0] - dist = [-length * np.tan(self.sidewall_angle)] - # reverse the dilation value if it's defined on the top - if self.reference_plane == "top": - dist = [-dist[0]] - # for middle, both direction needs to be examined - elif self.reference_plane == "middle": - dist = [dist[0] / 2, -dist[0] / 2] - - # capture vertex crossing events - max_thick = [] - for dist_val in dist: - max_dist = PolySlab._neighbor_vertices_crossing_detection(poly_ref, dist_val) - - if max_dist is not None: - max_thick.append(max_dist / abs(dist_val) * length) - - if len(max_thick) > 0: - max_thick = min(max_thick) - raise SetupError( - "Sidewall angle or structure thickness is so large that the polygon " - "is self-intersecting during extrusion. " - f"Please either reduce structure thickness to be < {max_thick:.3e}, " - "or use our plugin 'ComplexPolySlab' to divide the complex polyslab " - "into a list of simple polyslabs." - ) - - # vertex-edge crossing event. - for dist_val in dist: - if PolySlab._edge_events_detection(poly_ref, dist_val): - raise SetupError( - "Sidewall angle or structure thickness is too large, " - "resulting in polygon splitting or generation of holes/islands. " - "A general treatment to self-intersecting polygon will be available " - "in future releases." 
- ) - return self - - @classmethod - def from_gds( - cls, - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> list[PolySlab]: - """Import :class:`PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: tuple[float, float] - Minimum and maximum positions of the slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. - If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. - dilation : float = 0.0 - Dilation of the polygon in the base by shifting each edge along its - normal outwards direction by a distance; - a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the sidewall. - ``sidewall_angle=0`` (default) specifies vertical wall, - while ``0 list[ArrayFloat2D]: - """Import :class:`PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. - If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. 
- - Returns - ------- - list[ArrayFloat2D] - List of :class:`.ArrayFloat2D` - """ - import gdstk - - gds_cell_class_name = str(gds_cell.__class__) - if not isinstance(gds_cell, gdstk.Cell): - if ( - "gdstk" in gds_cell_class_name - ): # Check if it might be a gdstk cell but gdstk is not found - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to import gdstk cells." - ) - raise ValueError( - f"validate 'gds_cell' of type '{gds_cell_class_name}' " - "does not seem to be associated with 'gdstk' package " - "and therefore can't be loaded by Tidy3D." - ) - - all_vertices = base.Geometry.load_gds_vertices_gdstk( - gds_cell=gds_cell, - gds_layer=gds_layer, - gds_dtype=gds_dtype, - gds_scale=gds_scale, - ) - - # convert vertices into polyslabs - polygons = [PolySlab.make_shapely_polygon(vertices).buffer(0) for vertices in all_vertices] - polys_union = shapely.unary_union(polygons, grid_size=base.POLY_GRID_SIZE) - - if polys_union.geom_type == "Polygon": - all_vertices = [np.array(polys_union.exterior.coords)] - elif polys_union.geom_type == "MultiPolygon": - all_vertices = [np.array(polygon.exterior.coords) for polygon in polys_union.geoms] - return all_vertices - - @property - def center_axis(self) -> float: - """Gets the position of the center of the geometry in the out of plane dimension.""" - zmin, zmax = self.slab_bounds - if np.isneginf(zmin) and np.isposinf(zmax): - return 0.0 - zmin = max(zmin, -LARGE_NUMBER) - zmax = min(zmax, LARGE_NUMBER) - return (zmax + zmin) / 2.0 - - @property - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - zmin, zmax = self.slab_bounds - return zmax - zmin - - @property - def finite_length_axis(self) -> float: - """Gets the length of the PolySlab along the out of plane dimension. - First clips the slab bounds to LARGE_NUMBER and then returns difference. 
- """ - zmin, zmax = self.slab_bounds - zmin = max(zmin, -LARGE_NUMBER) - zmax = min(zmax, LARGE_NUMBER) - return zmax - zmin - - @cached_property - def reference_polygon(self) -> NDArray: - """The polygon at the reference plane. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the reference plane. - """ - vertices = self._proper_vertices(self.vertices) - if math.isclose(self.dilation, 0): - return vertices - offset_vertices = self._shift_vertices(vertices, self.dilation)[0] - return self._heal_polygon(offset_vertices) - - @cached_property - def middle_polygon(self) -> NDArray: - """The polygon at the middle. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the middle. - """ - - dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) - if self.reference_plane == "bottom": - return self._shift_vertices(self.reference_polygon, dist)[0] - if self.reference_plane == "top": - return self._shift_vertices(self.reference_polygon, -dist)[0] - # middle case - return self.reference_polygon - - @cached_property - def base_polygon(self) -> NDArray: - """The polygon at the base, derived from the ``middle_polygon``. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the base. - """ - if self.reference_plane == "bottom": - return self.reference_polygon - dist = self._extrusion_length_to_offset_distance(-self.finite_length_axis / 2) - return self._shift_vertices(self.middle_polygon, dist)[0] - - @cached_property - def top_polygon(self) -> NDArray: - """The polygon at the top, derived from the ``middle_polygon``. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the top. 
- """ - if self.reference_plane == "top": - return self.reference_polygon - dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) - return self._shift_vertices(self.middle_polygon, dist)[0] - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if self.slab_bounds[0] != self.slab_bounds[1]: - raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.") - return self.axis - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> PolySlab: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - if axis != self.axis: - raise ValueError( - f"'_update_from_bounds' may only be applied along axis '{self.axis}', " - f"but was given axis '{axis}'." - ) - return self.updated_copy(slab_bounds=tuple(bounds)) - - @cached_property - def is_ccw(self) -> bool: - """Is this ``PolySlab`` CCW-oriented?""" - return PolySlab._area(self.vertices) > 0 - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Note - ---- - For slanted sidewalls, this function only works if x, y, and z are arrays produced by a - ``meshgrid call``, i.e. 3D arrays and each is constant along one axis. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - self._ensure_equal_shape(x, y, z) - - z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) - - z0 = self.center_axis - dist_z = np.abs(z - z0) - inside_height = dist_z <= (self.finite_length_axis / 2) - - # avoid going into face checking if no points are inside slab bounds - if not np.any(inside_height): - return inside_height - - # check what points are inside polygon cross section (face) - z_local = z - z0 # distance to the middle - dist = -z_local * self._tanq - - if isinstance(x, np.ndarray): - inside_polygon = np.zeros_like(inside_height) - xs_slab = x[inside_height] - ys_slab = y[inside_height] - - # vertical sidewall - if math.isclose(self.sidewall_angle, 0): - face_polygon = shapely.Polygon(self.reference_polygon).buffer(fp_eps) - shapely.prepare(face_polygon) - inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs_slab, y=ys_slab) - inside_polygon[inside_height] = inside_polygon_slab - # slanted sidewall, offsetting vertices at each z - else: - # a helper function for moving axis - def _move_axis(arr: NDArray) -> NDArray: - return np.moveaxis(arr, source=self.axis, destination=-1) - - def _move_axis_reverse(arr: NDArray) -> NDArray: - return np.moveaxis(arr, source=-1, destination=self.axis) - - inside_polygon_axis = _move_axis(inside_polygon) - x_axis = _move_axis(x) - y_axis = _move_axis(y) - - for z_i in range(z.shape[self.axis]): - if not _move_axis(inside_height)[0, 0, z_i]: - continue - vertices_z = self._shift_vertices( - self.middle_polygon, _move_axis(dist)[0, 0, z_i] - )[0] - face_polygon = shapely.Polygon(vertices_z).buffer(fp_eps) - shapely.prepare(face_polygon) - xs = x_axis[:, :, 0].flatten() - ys = y_axis[:, :, 0].flatten() - inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs, y=ys) - inside_polygon_axis[:, :, z_i] = inside_polygon_slab.reshape(x_axis.shape[:2]) - inside_polygon = _move_axis_reverse(inside_polygon_axis) - else: - vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] - face_polygon = 
self.make_shapely_polygon(vertices_z).buffer(fp_eps) - point = shapely.Point(x, y) - inside_polygon = face_polygon.covers(point) - return inside_height * inside_polygon - - @verify_packages_import(["trimesh"]) - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for PolySlab geometry. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - import trimesh - - if len(self.base_polygon) > _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION: - log.warning( - f"Processing PolySlabs with over {_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION} vertices can be slow.", - log_once=True, - ) - base_triangles = triangulation.triangulate(self.base_polygon) - top_triangles = ( - base_triangles - if math.isclose(self.sidewall_angle, 0) - else triangulation.triangulate(self.top_polygon) - ) - - n = len(self.base_polygon) - faces = ( - [[a, b, c] for c, b, a in base_triangles] - + [[n + a, n + b, n + c] for a, b, c in top_triangles] - + [(i, (i + 1) % n, n + i) for i in range(n)] - + [((i + 1) % n, n + ((i + 1) % n), n + i) for i in range(n)] - ) - - x = np.hstack((self.base_polygon[:, 0], self.top_polygon[:, 0])) - y = np.hstack((self.base_polygon[:, 1], self.top_polygon[:, 1])) - z = np.hstack((np.full(n, self.slab_bounds[0]), np.full(n, self.slab_bounds[1]))) - vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T - mesh = 
trimesh.Trimesh(vertices, faces) - - section = mesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - if math.isclose(self.sidewall_angle, 0): - return [self.make_shapely_polygon(self.reference_polygon)] - - z0 = self.center_axis - z_local = z - z0 # distance to the middle - dist = -z_local * self._tanq - vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] - return [self.make_shapely_polygon(vertices_z)] - - def _intersections_side(self, position: float, axis: int) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis orthogonal to slab. - - For slanted polyslab, the procedure is as follows, - 1) Find out all z-coordinates where the plane will intersect directly with a vertex. - Denote the coordinates as (z_0, z_1, z_2, ... ) - 2) Find out all polygons that can be formed between z_i and z_{i+1}. There are two - types of polygons: - a) formed by the plane intersecting the edges - b) formed by the plane intersecting the vertices. - For either type, one needs to compute: - i) intersecting position - ii) angle between the plane and the intersecting edge - For a), both are straightforward to compute; while for b), one needs to compute - which edge the plane will slide into. - 3) Looping through z_i, and merge all polygons. The partition by z_i is because once - the plane intersects the vertex, it can intersect with other edges during - the extrusion. 
- - Parameters - ---------- - position : float - Position along ``axis``. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - # find out all z_i where the plane will intersect the vertex - z0 = self.center_axis - z_base = z0 - self.finite_length_axis / 2 - - axis_ordered = self._order_axis(axis) - height_list = self._find_intersecting_height(position, axis_ordered) - polys = [] - - # looping through z_i to assemble the polygons - height_list = np.append(height_list, self.finite_length_axis) - h_base = 0.0 - for h_top in height_list: - # length within between top and bottom - h_length = h_top - h_base - - # coordinate of each subsection - z_min = z_base + h_base - z_max = np.inf if np.isposinf(h_top) else z_base + h_top - - # for vertical sidewall, no need for complications - if math.isclose(self.sidewall_angle, 0): - ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( - self.reference_polygon, position, axis_ordered - ) - else: - # for slanted sidewall, move up by `fp_eps` in case vertices are degenerate at the base. 
- dist = -(h_base - self.finite_length_axis / 2 + fp_eps) * self._tanq - vertices = self._shift_vertices(self.middle_polygon, dist)[0] - ints_y, ints_angle = self._find_intersecting_ys_angle_slant( - vertices, position, axis_ordered - ) - - # make polygon with intersections and z axis information - for y_index in range(len(ints_y) // 2): - y_min = ints_y[2 * y_index] - y_max = ints_y[2 * y_index + 1] - minx, miny = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) - maxx, maxy = self._order_by_axis(plane_val=y_max, axis_val=z_max, axis=axis) - - if math.isclose(self.sidewall_angle, 0): - polys.append(self.make_shapely_box(minx, miny, maxx, maxy)) - else: - angle_min = ints_angle[2 * y_index] - angle_max = ints_angle[2 * y_index + 1] - - angle_min = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_min)) - angle_max = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_max)) - - dy_min = h_length * np.tan(angle_min) - dy_max = h_length * np.tan(angle_max) - - x1, y1 = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) - x2, y2 = self._order_by_axis(plane_val=y_max, axis_val=z_min, axis=axis) - x3, y3 = self._order_by_axis( - plane_val=y_max - dy_max, axis_val=z_max, axis=axis - ) - x4, y4 = self._order_by_axis( - plane_val=y_min + dy_min, axis_val=z_max, axis=axis - ) - vertices = ((x1, y1), (x2, y2), (x3, y3), (x4, y4)) - polys.append(self.make_shapely_polygon(vertices).buffer(0)) - # update the base coordinate for the next subsection - h_base = h_top - - # merge touching polygons - polys_union = shapely.unary_union(polys, grid_size=base.POLY_GRID_SIZE) - if polys_union.geom_type == "Polygon": - return [polys_union] - if polys_union.geom_type == "MultiPolygon": - return polys_union.geoms - # in other cases, just return the original unmerged polygons - return polys - - def _find_intersecting_height(self, position: float, axis: int) -> NDArray: - """Found a list of height where the plane will intersect with the vertices; - For vertical 
sidewall, just return np.array([]). - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - np.ndarray - Height (relative to the base) where the plane will intersect with vertices. - """ - if math.isclose(self.sidewall_angle, 0): - return np.array([]) - - # shift rate - dist = 1.0 - shift_x, shift_y = PolySlab._shift_vertices(self.middle_polygon, dist)[2] - shift_val = shift_x if axis == 0 else shift_y - shift_val[np.isclose(shift_val, 0, rtol=_IS_CLOSE_RTOL)] = np.inf # for static vertices - - # distance to the plane in the direction of vertex shifting - distance = self.middle_polygon[:, axis] - position - height = distance / self._tanq / shift_val + self.finite_length_axis / 2 - height = np.unique(height) - # further filter very close ones - is_not_too_close = np.insert((np.diff(height) > fp_eps), 0, True) - height = height[is_not_too_close] - - height = height[height > fp_eps] - height = height[height < self.finite_length_axis - fp_eps] - return height - - def _find_intersecting_ys_angle_vertical( - self, - vertices: NDArray, - position: float, - axis: int, - exclude_on_vertices: bool = False, - ) -> tuple[NDArray, NDArray, NDArray]: - """Finds pairs of forward and backwards vertices where polygon intersects position at axis, - Find intersection point (in y) assuming straight line,and intersecting angle between plane - and edges. (For unslanted polyslab). - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - exclude_on_vertices : bool = False - Whether to exclude those intersecting directly with the vertices. - - Returns - ------- - Union[np.ndarray, np.ndarray] - List of intersection points along y direction. 
- List of angles between plane and edges. - """ - - vertices_axis = vertices - - # flip vertices x,y for axis = y - if axis == 1: - vertices_axis = np.roll(vertices_axis, shift=1, axis=1) - - # get the forward vertices - vertices_f = np.roll(vertices_axis, shift=-1, axis=0) - - # x coordinate of the two sets of vertices - x_vertices_f, _ = vertices_f.T - x_vertices_axis, _ = vertices_axis.T - - # Find which segments intersect: - # 1. Strictly crossing: one endpoint strictly left, one strictly right - # 2. Touching: exactly one endpoint on the plane (xor), which excludes - # edges lying entirely on the plane (both endpoints at position). - orig_on_plane = np.isclose(x_vertices_axis, position, rtol=_IS_CLOSE_RTOL) - f_on_plane = np.roll(orig_on_plane, shift=-1) - crosses_b = (x_vertices_axis > position) & (x_vertices_f < position) - crosses_f = (x_vertices_axis < position) & (x_vertices_f > position) - - if exclude_on_vertices: - # exclude vertices at the position - not_touching = np.logical_not(orig_on_plane | f_on_plane) - intersects_segment = (crosses_b | crosses_f) & not_touching - else: - single_touch = np.logical_xor(orig_on_plane, f_on_plane) - intersects_segment = crosses_b | crosses_f | single_touch - - iverts_b = vertices_axis[intersects_segment] - iverts_f = vertices_f[intersects_segment] - - # intersecting positions and angles - ints_y = [] - ints_angle = [] - for vertices_f_local, vertices_b_local in zip(iverts_b, iverts_f): - x1, y1 = vertices_f_local - x2, y2 = vertices_b_local - slope = (y2 - y1) / (x2 - x1) - y = y1 + slope * (position - x1) - ints_y.append(y) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope))) - - ints_y = np.array(ints_y) - ints_angle = np.array(ints_angle) - - # Get rid of duplicate intersection points (vertices counted twice if directly on position) - ints_y_sort, sort_index = np.unique(ints_y, return_index=True) - ints_angle_sort = ints_angle[sort_index] - - # For tangent touches (vertex on plane, both neighbors on same 
side), - # add y-value back to form a degenerate pair - if not exclude_on_vertices: - n = len(vertices_axis) - for idx in np.where(orig_on_plane)[0]: - prev_on = orig_on_plane[(idx - 1) % n] - next_on = orig_on_plane[(idx + 1) % n] - if not prev_on and not next_on: - prev_side = x_vertices_axis[(idx - 1) % n] > position - next_side = x_vertices_axis[(idx + 1) % n] > position - if prev_side == next_side: - ints_y_sort = np.append(ints_y_sort, vertices_axis[idx, 1]) - ints_angle_sort = np.append(ints_angle_sort, 0) - - sort_index = np.argsort(ints_y_sort) - ints_y_sort = ints_y_sort[sort_index] - ints_angle_sort = ints_angle_sort[sort_index] - return ints_y_sort, ints_angle_sort - - def _find_intersecting_ys_angle_slant( - self, vertices: NDArray, position: float, axis: int - ) -> tuple[NDArray, NDArray, NDArray]: - """Finds pairs of forward and backwards vertices where polygon intersects position at axis, - Find intersection point (in y) assuming straight line,and intersecting angle between plane - and edges. (For slanted polyslab) - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - Union[np.ndarray, np.ndarray] - List of intersection points along y direction. - List of angles between plane and edges. 
- """ - - vertices_axis = vertices.copy() - # flip vertices x,y for axis = y - if axis == 1: - vertices_axis = np.roll(vertices_axis, shift=1, axis=1) - - # get the forward vertices - vertices_f = np.roll(vertices_axis, shift=-1, axis=0) - # get the backward vertices - vertices_b = np.roll(vertices_axis, shift=1, axis=0) - - ## First part, plane intersects with edges, same as vertical - ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( - vertices, position, axis, exclude_on_vertices=True - ) - ints_y = ints_y.tolist() - ints_angle = ints_angle.tolist() - - ## Second part, plane intersects directly with vertices - # vertices on the intersection - intersects_on = np.isclose(vertices_axis[:, 0], position, rtol=_IS_CLOSE_RTOL) - iverts_on = vertices_axis[intersects_on] - # position of the neighbouring vertices - iverts_b = vertices_b[intersects_on] - iverts_f = vertices_f[intersects_on] - # shift rate - dist = -np.sign(self.sidewall_angle) - shift_x, shift_y = self._shift_vertices(self.middle_polygon, dist)[2] - shift_val = shift_x if axis == 0 else shift_y - shift_val = shift_val[intersects_on] - - for vertices_f_local, vertices_b_local, vertices_on_local, shift_local in zip( - iverts_f, iverts_b, iverts_on, shift_val - ): - x_on, y_on = vertices_on_local - x_f, y_f = vertices_f_local - x_b, y_b = vertices_b_local - - num_added = 0 # keep track the number of added vertices - slope = [] # list of slopes for added vertices - # case 1, shifting velocity is 0 - if np.isclose(shift_local, 0, rtol=_IS_CLOSE_RTOL): - ints_y.append(y_on) - # Slope w.r.t. forward and backward should equal, - # just pick one of them. 
- slope.append((y_on - y_b) / (x_on - x_b)) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) - continue - - # case 2, shifting towards backward direction - if (x_b - position) * shift_local < 0: - ints_y.append(y_on) - slope.append((y_on - y_b) / (x_on - x_b)) - num_added += 1 - - # case 3, shifting towards forward direction - if (x_f - position) * shift_local < 0: - ints_y.append(y_on) - slope.append((y_on - y_f) / (x_on - x_f)) - num_added += 1 - - # in case 2, and case 3, if just num_added = 1 - if num_added == 1: - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) - # if num_added = 2, the order of the two new vertices needs to handled correctly; - # it should be sorted according to the -slope * moving direction - elif num_added == 2: - dressed_slope = [-s_i * shift_local for s_i in slope] - sort_index = np.argsort(np.array(dressed_slope)) - sorted_slope = np.array(slope)[sort_index] - - ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[0]))) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[1]))) - - ints_y = np.array(ints_y) - ints_angle = np.array(ints_angle) - - sort_index = np.argsort(ints_y) - ints_y_sort = ints_y[sort_index] - ints_angle_sort = ints_angle[sort_index] - - return ints_y_sort, ints_angle_sort - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. The dilation and slant angle are not - taken into account exactly for speed. Instead, the polygon may be slightly smaller than - the returned bounds, but it should always be fully contained. - - Returns - ------- - tuple[float, float, float], tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
- """ - - # check for the maximum possible contribution from dilation/slant on each side - max_offset = self.dilation - # sidewall_angle may be autograd-traced; unbox for this check - if not math.isclose(getval(self.sidewall_angle), 0): - if self.reference_plane == "bottom": - max_offset += max(0, -self._tanq * self.finite_length_axis) - elif self.reference_plane == "top": - max_offset += max(0, self._tanq * self.finite_length_axis) - elif self.reference_plane == "middle": - max_offset += max(0, abs(self._tanq) * self.finite_length_axis / 2) - - # special care when dilated - if max_offset > 0: - dilated_vertices = self._shift_vertices( - self._proper_vertices(self.vertices), max_offset - )[0] - xmin, ymin = np.amin(dilated_vertices, axis=0) - xmax, ymax = np.amax(dilated_vertices, axis=0) - else: - # otherwise, bounds are directly based on the supplied vertices - xmin, ymin = np.amin(self.vertices, axis=0) - xmax, ymax = np.amax(self.vertices, axis=0) - - # get bounds in (local) z - zmin, zmax = self.slab_bounds - - # rearrange axes - coords_min = self.unpop_axis(zmin, (xmin, ymin), axis=self.axis) - coords_max = self.unpop_axis(zmax, (xmax, ymax), axis=self.axis) - return (tuple(coords_min), tuple(coords_max)) - - def _extrusion_length_to_offset_distance(self, extrusion: float) -> float: - """Convert extrusion length to offset distance.""" - if math.isclose(self.sidewall_angle, 0): - return 0 - return -extrusion * self._tanq - - @staticmethod - def _area(vertices: NDArray) -> float: - """Compute the signed polygon area (positive for CCW orientation). - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - float - Signed polygon area (positive for CCW orientation). 
- """ - vert_shift = np.roll(vertices, axis=0, shift=-1) - - xs, ys = vertices.T - xs_shift, ys_shift = vert_shift.T - - term1 = xs * ys_shift - term2 = ys * xs_shift - return np.sum(term1 - term2) * 0.5 - - @staticmethod - def _perimeter(vertices: NDArray) -> float: - """Compute the polygon perimeter. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - float - Polygon perimeter. - """ - - vert_shift = np.roll(vertices, axis=0, shift=-1) - squared_diffs = (vertices - vert_shift) ** 2 - - # distance along each edge - dists = np.sqrt(squared_diffs.sum(axis=-1)) - - # total distance along all edges - return np.sum(dists) - - @staticmethod - def _orient(vertices: NDArray) -> NDArray: - """Return a CCW-oriented polygon. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - np.ndarray - Vertices of a CCW-oriented polygon. - """ - return vertices if PolySlab._area(vertices) > 0 else vertices[::-1, :] - - @staticmethod - def _remove_duplicate_vertices(vertices: NDArray) -> NDArray: - """Remove redundant/identical nearest neighbour vertices. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - np.ndarray - Vertices of polygon. - """ - - vertices_f = np.roll(vertices, shift=-1, axis=0) - vertices_diff = np.linalg.norm(vertices - vertices_f, axis=1) - return vertices[~np.isclose(vertices_diff, 0, rtol=_IS_CLOSE_RTOL)] - - @staticmethod - def _proper_vertices(vertices: ArrayFloat2D) -> NDArray: - """convert vertices to np.array format, - removing duplicate neighbouring vertices, - and oriented in CCW direction. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon for internal use. 
- """ - vertices_np = np.array(vertices) - return PolySlab._orient(PolySlab._remove_duplicate_vertices(vertices_np)) - - @staticmethod - def _edge_events_detection( - proper_vertices: NDArray, dilation: float, ignore_at_dist: bool = True - ) -> bool: - """Detect any edge events within the offset distance ``dilation``. - If ``ignore_at_dist=True``, the edge event at ``dist`` is ignored. - """ - - # ignore the event that occurs right at the offset distance - if ignore_at_dist: - dilation -= fp_eps * dilation / abs(dilation) - # number of vertices before offsetting - num_vertices = proper_vertices.shape[0] - - # 0) fully eroded? - if dilation < 0 and dilation < -PolySlab._maximal_erosion(proper_vertices): - return True - - # sample at a few dilation values - dist_list = ( - dilation - * np.linspace( - 0, 1, 1 + _N_SAMPLE_POLYGON_INTERSECT, dtype=config.adjoint.gradient_dtype_float - )[1:] - ) - for dist in dist_list: - # offset: we offset the vertices first, and then use shapely to make it proper - # in principle, one can offset with shapely.buffer directly, but shapely somehow - # automatically removes some vertices even though no change of topology. - poly_offset = PolySlab._shift_vertices(proper_vertices, dist)[0] - # flipped winding number - if PolySlab._area(poly_offset) < fp_eps**2: - return True - - poly_offset = shapely.make_valid(PolySlab.make_shapely_polygon(poly_offset)) - # 1) polygon split or create holes/islands - if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: - return True - - # 2) reduction in vertex number - offset_vertices = PolySlab._proper_vertices(poly_offset.exterior.coords) - if offset_vertices.shape[0] != num_vertices: - return True - - # 3) some split polygon might fully disappear after the offset, but they - # can be detected if we offset back. 
- poly_offset_back = shapely.make_valid( - PolySlab.make_shapely_polygon(PolySlab._shift_vertices(offset_vertices, -dist)[0]) - ) - if poly_offset_back.geom_type == "MultiPolygon" or len(poly_offset_back.interiors) > 0: - return True - offset_back_vertices = poly_offset_back.exterior.coords - if PolySlab._proper_vertices(offset_back_vertices).shape[0] != num_vertices: - return True - - return False - - @staticmethod - def _neighbor_vertices_crossing_detection( - vertices: NDArray, dist: float, ignore_at_dist: bool = True - ) -> float: - """Detect if neighboring vertices will cross after a dilation distance dist. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - dist : float - Distance to offset. - ignore_at_dist : bool, optional - whether to ignore the event right at ``dist`. - - Returns - ------- - float - the absolute value of the maximal allowed dilation - if there are any crossing, otherwise return ``None``. - """ - # ignore the event that occurs right at the offset distance - if ignore_at_dist: - dist -= fp_eps * dist / abs(dist) - - edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) - length_remaining = edge_length - edge_reduction * dist - - if np.any(length_remaining < 0): - index_oversized = length_remaining < 0 - max_dist = np.min( - np.abs(edge_length[index_oversized] / edge_reduction[index_oversized]) - ) - return max_dist - return None - - @staticmethod - def array_to_vertices(arr_vertices: NDArray) -> ArrayFloat2D: - """Converts a numpy array of vertices to a list of tuples.""" - return list(arr_vertices) - - @staticmethod - def vertices_to_array(vertices_tuple: ArrayFloat2D) -> NDArray: - """Converts a list of tuples (vertices) to a numpy array.""" - return np.array(vertices_tuple) - - @cached_property - def interior_angle(self) -> ArrayFloat1D: - """Angle formed inside polygon by two adjacent edges.""" - - def normalize(v: NDArray) -> NDArray: - return v 
/ np.linalg.norm(v, axis=0) - - vs_orig = self.reference_polygon.T - vs_next = np.roll(vs_orig, axis=-1, shift=-1) - vs_previous = np.roll(vs_orig, axis=-1, shift=+1) - - asp = normalize(vs_next - vs_orig) - asm = normalize(vs_previous - vs_orig) - - cos_angle = asp[0] * asm[0] + asp[1] * asm[1] - sin_angle = asp[0] * asm[1] - asp[1] * asm[0] - - angle = np.arccos(cos_angle) - # concave angles - angle[sin_angle < 0] = 2 * np.pi - angle[sin_angle < 0] - return angle - - @staticmethod - def _shift_vertices( - vertices: NDArray, dist: float - ) -> tuple[NDArray, NDArray, tuple[NDArray, NDArray]]: - """Shifts the vertices of a polygon outward uniformly by distances - `dists`. - - Parameters - ---------- - np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - dist : float - Distance to offset. - - Returns - ------- - tuple[np.ndarray, np.narray,tuple[np.ndarray,np.ndarray]] - New polygon vertices; - and the shift of vertices in direction parallel to the edges. - Shift along x and y direction. 
- """ - - # 'dist' may be autograd-traced; unbox for the zero-check only - if math.isclose(getval(dist), 0): - return vertices, np.zeros(vertices.shape[0], dtype=float), None - - def rot90(v: tuple[NDArray, NDArray]) -> NDArray: - """90 degree rotation of 2d vector - vx -> vy - vy -> -vx - """ - vxs, vys = v - return np.stack((-vys, vxs), axis=0) - - def cross(u: NDArray, v: NDArray) -> Any: - return u[0] * v[1] - u[1] * v[0] - - def normalize(v: NDArray) -> NDArray: - return v / np.linalg.norm(v, axis=0) - - vs_orig = copy(vertices.T) - vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) - vs_previous = np.roll(copy(vs_orig), axis=-1, shift=+1) - - asp = normalize(vs_next - vs_orig) - asm = normalize(vs_orig - vs_previous) - - # the vertex shift is decomposed into parallel and perpendicular directions - perpendicular_shift = -dist - det = cross(asm, asp) - - tan_half_angle = np.where( - np.isclose(det, 0, rtol=_IS_CLOSE_RTOL), - 0.0, - cross(asm, rot90(asm - asp)) / (det + np.isclose(det, 0, rtol=_IS_CLOSE_RTOL)), - ) - parallel_shift = dist * tan_half_angle +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - shift_total = perpendicular_shift * rot90(asm) + parallel_shift * asm - shift_x = shift_total[0, :] - shift_y = shift_total[1, :] - - return ( - np.swapaxes(vs_orig + shift_total, -2, -1), - parallel_shift, - (shift_x, shift_y), - ) - - @staticmethod - def _edge_length_and_reduction_rate( - vertices: NDArray, - ) -> tuple[NDArray, NDArray]: - """Edge length of reduction rate of each edge with unit offset length. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. 
- - Returns - ------- - tuple[np.ndarray, np.narray] - edge length, and reduction rate - """ - - # edge length - vs_orig = copy(vertices.T) - vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) - edge_length = np.linalg.norm(vs_next - vs_orig, axis=0) - - # edge length remaining - dist = 1 - parallel_shift = PolySlab._shift_vertices(vertices, dist)[1] - parallel_shift_p = np.roll(copy(parallel_shift), shift=-1) - edge_reduction = -(parallel_shift + parallel_shift_p) - return edge_length, edge_reduction - - @staticmethod - def _maximal_erosion(vertices: NDArray) -> float: - """The erosion value that reduces the length of - all edges to be non-positive. - """ - edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) - ind_nonzero = abs(edge_reduction) > fp_eps - return -np.min(edge_length[ind_nonzero] / edge_reduction[ind_nonzero]) - - @staticmethod - def _heal_polygon(vertices: NDArray) -> NDArray: - """heal a self-intersecting polygon.""" - shapely_poly = PolySlab.make_shapely_polygon(vertices) - if shapely_poly.is_valid: - return vertices - elif hasbox(vertices): - raise NotImplementedError( - "The dilation caused damage to the polygon. " - "Automatically healing this is currently not supported when " - "differentiating w.r.t. the vertices. Try increasing the spacing " - "between vertices or reduce the amount of dilation." 
- ) - # perform healing - poly_heal = shapely.make_valid(shapely_poly) - return PolySlab._proper_vertices(list(poly_heal.exterior.coords)) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - z_min, z_max = self.slab_bounds - - z_min = max(z_min, bounds[0][self.axis]) - z_max = min(z_max, bounds[1][self.axis]) - - length = z_max - z_min - - top_area = abs(self._area(self.top_polygon)) - base_area = abs(self._area(self.base_polygon)) - - # https://mathworld.wolfram.com/PyramidalFrustum.html - return 1.0 / 3.0 * length * (top_area + base_area + np.sqrt(top_area * base_area)) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 0 - - top_polygon = self.top_polygon - base_polygon = self.base_polygon - - top_area = abs(self._area(top_polygon)) - base_area = abs(self._area(base_polygon)) - - top_perim = self._perimeter(top_polygon) - base_perim = self._perimeter(base_polygon) - - z_min, z_max = self.slab_bounds - - if z_min < bounds[0][self.axis]: - z_min = bounds[0][self.axis] - else: - area += base_area - - if z_max > bounds[1][self.axis]: - z_max = bounds[1][self.axis] - else: - area += top_area - - length = z_max - z_min - - area += 0.5 * (top_perim + base_perim) * length - - return area - - """ Autograd code """ - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """ - Return VJPs while handling several edge-cases: - - - If the slab volume does not overlap the simulation, all grads are zero - (one warning is issued). - - Faces that lie completely outside the simulation give zero ``slab_bounds`` - gradients; this includes the +/- inf cases. 
- - A 2d simulation collapses the surface integral to a line integral - """ - vjps: AutogradFieldMap = {} - - intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) - sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) - - extents = intersect_max - intersect_min - is_2d = np.isclose(extents[self.axis], 0.0) - - # early return if polyslab is not in simulation domain - slab_min, slab_max = self.slab_bounds - if (slab_max < sim_min[self.axis]) or (slab_min > sim_max[self.axis]): - log.warning( - "'PolySlab' lies completely outside the simulation domain.", - log_once=True, - ) - for p in derivative_info.paths: - vjps[p] = np.zeros_like(self.vertices) if p == ("vertices",) else 0.0 - return vjps - - # create interpolators once for ALL derivative computations - # use provided interpolators if available to avoid redundant field data conversions - interpolators = derivative_info.interpolators or derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - for path in derivative_info.paths: - if path == ("vertices",): - vjps[path] = self._compute_derivative_vertices( - derivative_info, sim_min, sim_max, is_2d, interpolators - ) - - elif path == ("sidewall_angle",): - vjps[path] = self._compute_derivative_sidewall_angle( - derivative_info, sim_min, sim_max, is_2d, interpolators - ) - elif path[0] == "slab_bounds": - idx = path[1] - face_coord = self.slab_bounds[idx] - - # face entirely outside -> gradient 0 - if ( - np.isinf(face_coord) - or face_coord < sim_min[self.axis] - or face_coord > sim_max[self.axis] - or is_2d - ): - vjps[path] = 0.0 - continue - - v = self._compute_derivative_slab_bounds(derivative_info, idx, interpolators) - # outward-normal convention - if idx == 0: - v *= -1 - vjps[path] = v - else: - raise ValueError(f"No derivative defined w.r.t. 
'PolySlab' field '{path}'.") - - return vjps - - # ---- Shared helpers for VJP surface integrations ---- - def _z_slices( - self, sim_min: NDArray, sim_max: NDArray, is_2d: bool, dx: float - ) -> tuple[NDArray, float, float, float]: - """Compute z-slice centers and spacing within bounds. - - Returns (z_centers, dz, z0, z1). For 2D, returns single center and dz=1. - """ - if is_2d: - midpoint_z = np.maximum( - np.minimum(self.center_axis, sim_max[self.axis]), - sim_min[self.axis], - ) - zc = np.array([midpoint_z], dtype=config.adjoint.gradient_dtype_float) - return zc, 1.0, self.center_axis, self.center_axis - - z0 = max(self.slab_bounds[0], sim_min[self.axis]) - z1 = min(self.slab_bounds[1], sim_max[self.axis]) - if z1 <= z0: - return np.array([], dtype=config.adjoint.gradient_dtype_float), 0.0, z0, z1 - - n_z = max(1, int(np.ceil((z1 - z0) / dx))) - dz = (z1 - z0) / n_z - z_centers = np.linspace( - z0 + dz / 2, z1 - dz / 2, n_z, dtype=config.adjoint.gradient_dtype_float - ) - return z_centers, dz, z0, z1 - - @staticmethod - def _clip_edges_to_bounds_batch( - segment_starts: NDArray, - segment_ends: NDArray, - sim_min: NDArray, - sim_max: NDArray, - *, - _edge_clip_tol: Optional[float] = None, - _dtype: Optional[type] = None, - ) -> tuple[NDArray, NDArray, NDArray]: - """ - Compute parametric bounds for multiple segments clipped to simulation bounds. - - Parameters - ---------- - segment_starts : NDArray - (N, 3) array of segment start coordinates. - segment_ends : NDArray - (N, 3) array of segment end coordinates. - sim_min : NDArray - (3,) array of simulation minimum bounds. - sim_max : NDArray - (3,) array of simulation maximum bounds. - - Returns - ------- - is_within_bounds : NDArray - (N,) boolean array indicating if the segment intersects the bounds. - t_starts : NDArray - (N,) array of parametric start values (0.0 to 1.0). - t_ends : NDArray - (N,) array of parametric end values (0.0 to 1.0). 
- """ - n = segment_starts.shape[0] - if _edge_clip_tol is None: - _edge_clip_tol = config.adjoint.edge_clip_tolerance - if _dtype is None: - _dtype = config.adjoint.gradient_dtype_float - - t_starts = np.zeros(n, dtype=_dtype) - t_ends = np.ones(n, dtype=_dtype) - is_within_bounds = np.ones(n, dtype=bool) - - for dim in range(3): - start_coords = segment_starts[:, dim] - end_coords = segment_ends[:, dim] - bound_min = sim_min[dim] - bound_max = sim_max[dim] - - # check for parallel edges (faster than isclose) - parallel = np.abs(start_coords - end_coords) < 1e-12 - - # parallel edges: check if outside bounds - outside = parallel & ( - (start_coords < (bound_min - _edge_clip_tol)) - | (start_coords > (bound_max + _edge_clip_tol)) - ) - is_within_bounds &= ~outside - - # non-parallel edges: compute t_min, t_max - not_parallel = ~parallel & is_within_bounds - if np.any(not_parallel): - denom = np.where(not_parallel, end_coords - start_coords, 1.0) # avoid div by zero - t_min = (bound_min - start_coords) / denom - t_max = (bound_max - start_coords) / denom - - # swap if needed - swap = t_min > t_max - t_min_new = np.where(swap, t_max, t_min) - t_max_new = np.where(swap, t_min, t_max) - - # update t_starts and t_ends for valid non-parallel edges - t_starts = np.where(not_parallel, np.maximum(t_starts, t_min_new), t_starts) - t_ends = np.where(not_parallel, np.minimum(t_ends, t_max_new), t_ends) - - # still valid? - is_within_bounds &= ~not_parallel | (t_starts < t_ends) - - is_within_bounds &= t_ends > t_starts + _edge_clip_tol - - return is_within_bounds, t_starts, t_ends - - @staticmethod - def _adaptive_edge_samples( - L: float, - dx: float, - t_start: float = 0.0, - t_end: float = 1.0, - *, - _sample_fraction: Optional[float] = None, - _gauss_order: Optional[int] = None, - _dtype: Optional[type] = None, - ) -> tuple[NDArray, NDArray]: - """ - Compute Gauss samples and weights along [t_start, t_end] with adaptive count. 
- - Parameters - ---------- - L : float - Physical length of the full edge. - dx : float - Target discretization step size. - t_start : float, optional - Start parameter, by default 0.0. - t_end : float, optional - End parameter, by default 1.0. - - Returns - ------- - tuple[NDArray, NDArray] - Tuple of (samples, weights) for the integration. - """ - if _sample_fraction is None: - _sample_fraction = config.adjoint.quadrature_sample_fraction - if _gauss_order is None: - _gauss_order = config.adjoint.gauss_quadrature_order - if _dtype is None: - _dtype = config.adjoint.gradient_dtype_float - - L_eff = L * max(0.0, t_end - t_start) - n_uniform = max(1, int(np.ceil(L_eff / dx))) - n_gauss = n_uniform if n_uniform <= 3 else max(2, int(n_uniform * _sample_fraction)) - if n_gauss <= _gauss_order: - g, w = leggauss(n_gauss) - half_range = 0.5 * (t_end - t_start) - s = (half_range * g + 0.5 * (t_end + t_start)).astype(_dtype, copy=False) - wt = (w * half_range).astype(_dtype, copy=False) - return s, wt - - # composite Gauss with fixed local order - g_loc, w_loc = leggauss(_gauss_order) - segs = n_uniform - edges_t = np.linspace(t_start, t_end, segs + 1, dtype=_dtype) - - # compute all segments at once - a = edges_t[:-1] # (segs,) - b = edges_t[1:] # (segs,) - half_width = 0.5 * (b - a) # (segs,) - mid = 0.5 * (b + a) # (segs,) - - # (segs, 1) * (order,) + (segs, 1) -> (segs, order) - S = (half_width[:, None] * g_loc + mid[:, None]).astype(_dtype, copy=False) - W = (half_width[:, None] * w_loc).astype(_dtype, copy=False) - return S.ravel(), W.ravel() - - def _collect_sidewall_patches( - self, - vertices: NDArray, - next_v: NDArray, - edges: NDArray, - basis: dict, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool, - dx: float, - ) -> dict: - """ - Collect sidewall patch geometry for batched VJP evaluation. - - Parameters - ---------- - vertices : NDArray - Array of polygon vertices. - next_v : NDArray - Array of next vertices (forming edges). 
- edges : NDArray - Edge vectors. - basis : dict - Basis vectors dictionary. - sim_min : NDArray - Simulation minimum bounds. - sim_max : NDArray - Simulation maximum bounds. - is_2d : bool - Whether the simulation is 2D. - dx : float - Discretization step. - - Returns - ------- - dict - Dictionary containing: - - centers: (N, 3) array of patch centers. - - normals: (N, 3) array of patch normals. - - perps1: (N, 3) array of first tangent vectors. - - perps2: (N, 3) array of second tangent vectors. - - Ls: (N,) array of edge lengths. - - s_vals: (N,) array of parametric coordinates along the edge. - - s_weights: (N,) array of quadrature weights. - - zc_vals: (N,) array of z-coordinates. - - dz: float, slice thickness. - - edge_indices: (N,) array of original edge indices. - """ - # cache config values to avoid repeated lookups (overhead not insignificant here) - _dtype = config.adjoint.gradient_dtype_float - _edge_clip_tol = config.adjoint.edge_clip_tolerance - _sample_fraction = config.adjoint.quadrature_sample_fraction - _gauss_order = config.adjoint.gauss_quadrature_order - - theta = get_static(self.sidewall_angle) - z_ref = self.reference_axis_pos - - cos_th = np.cos(theta) - cos_th = np.clip(cos_th, 1e-12, 1.0) - tan_th = np.tan(theta) - dprime = -tan_th # dd/dz - - # axis unit vector in 3D - axis_vec = np.zeros(3, dtype=_dtype) - axis_vec[self.axis] = 1.0 - - # densify along axis as |theta| grows, dz scales with cos(theta) - z_centers, dz, z0, z1 = self._z_slices(sim_min, sim_max, is_2d=is_2d, dx=dx * cos_th) - - # early exit: no slices - if (not is_2d) and len(z_centers) == 0: - return { - "centers": np.empty((0, 3), dtype=_dtype), - "normals": np.empty((0, 3), dtype=_dtype), - "perps1": np.empty((0, 3), dtype=_dtype), - "perps2": np.empty((0, 3), dtype=_dtype), - "Ls": np.empty((0,), dtype=_dtype), - "s_vals": np.empty((0,), dtype=_dtype), - "s_weights": np.empty((0,), dtype=_dtype), - "zc_vals": np.empty((0,), dtype=_dtype), - "dz": dz, - "edge_indices": 
np.empty((0,), dtype=int), - } - - # estimate patches for pre-allocation - n_edges = len(vertices) - estimated_patches = 0 - denom_edge = max(dx * cos_th, 1e-12) - for ei in range(n_edges): - v0, v1 = vertices[ei], next_v[ei] - L = np.linalg.norm(v1 - v0) - if not np.isclose(L, 0.0): - # prealloc guided by actual step; ds_phys scales with cos(theta) - n_samples = max(1, int(np.ceil(L / denom_edge) * 0.6)) - estimated_patches += n_samples * max(1, len(z_centers)) - estimated_patches = int(max(1, estimated_patches) * 1.2) - - # pre-allocate arrays - centers = np.empty((estimated_patches, 3), dtype=_dtype) - normals = np.empty((estimated_patches, 3), dtype=_dtype) - perps1 = np.empty((estimated_patches, 3), dtype=_dtype) - perps2 = np.empty((estimated_patches, 3), dtype=_dtype) - Ls = np.empty((estimated_patches,), dtype=_dtype) - s_vals = np.empty((estimated_patches,), dtype=_dtype) - s_weights = np.empty((estimated_patches,), dtype=_dtype) - zc_vals = np.empty((estimated_patches,), dtype=_dtype) - edge_indices = np.empty((estimated_patches,), dtype=int) - - patch_idx = 0 - - # if the simulation is effectively 2D (one tangential dimension collapsed), - # slightly expand degenerate bounds to enable finite-length clipping of edges. 
- sim_min_eff = np.array(sim_min, dtype=_dtype) - sim_max_eff = np.array(sim_max, dtype=_dtype) - for dim in range(3): - if dim == self.axis: - continue - if np.isclose(sim_max_eff[dim] - sim_min_eff[dim], 0.0): - sim_min_eff[dim] -= 0.5 * dx - sim_max_eff[dim] += 0.5 * dx - - # pre-compute values that are constant across z slices - n_z = len(z_centers) - z_centers_arr = np.asarray(z_centers, dtype=_dtype) - - # slanted local basis (constant across z for non-slanted case) - # for slanted: rz = axis_vec + dprime * n2d, but dprime is constant - for ei, (v0, v1) in enumerate(zip(vertices, next_v)): - edge_vec = v1 - v0 - L = np.sqrt(np.dot(edge_vec, edge_vec)) - if L < 1e-12: - continue - - # constant along edge: unit tangent in 3D (no axis component) - t_edge = basis["perp1"][ei] - - # outward in-plane normal from canonical basis normal - n2d = basis["norm"][ei].copy() - n2d[self.axis] = 0.0 - nrm = np.linalg.norm(n2d) - if not np.isclose(nrm, 0.0): - n2d = n2d / nrm - else: - # fallback to right-handed construction if degenerate - tmp = np.cross(axis_vec, t_edge) - n2d = tmp / (np.linalg.norm(tmp) + 1e-20) - - # compute basis vectors once per edge - rz = axis_vec + dprime * n2d - T1_vec = t_edge - N_vec = np.cross(T1_vec, rz) - N_norm = np.linalg.norm(N_vec) - if not np.isclose(N_norm, 0.0): - N_vec = N_vec / N_norm - - # align N with outward edge normal - if float(np.dot(N_vec, basis["norm"][ei])) < 0.0: - N_vec = -N_vec - - T2_vec = np.cross(N_vec, T1_vec) - T2_norm = np.linalg.norm(T2_vec) - if not np.isclose(T2_norm, 0.0): - T2_vec = T2_vec / T2_norm - - # batch compute offsets for all z slices at once - d_all = -(z_centers_arr - z_ref) * tan_th # (n_z,) - offsets_3d = d_all[:, None] * n2d # (n_z, 3) - faster than np.outer - - # batch compute segment starts and ends for all z slices - segment_starts = np.empty((n_z, 3), dtype=_dtype) - segment_ends = np.empty((n_z, 3), dtype=_dtype) - plane_axes = [i for i in range(3) if i != self.axis] - segment_starts[:, 
self.axis] = z_centers_arr - segment_starts[:, plane_axes] = v0 - segment_starts += offsets_3d - segment_ends[:, self.axis] = z_centers_arr - segment_ends[:, plane_axes] = v1 - segment_ends += offsets_3d - - # batch clip all z slices at once - is_within_bounds, t_starts, t_ends = self._clip_edges_to_bounds_batch( - segment_starts, - segment_ends, - sim_min_eff, - sim_max_eff, - _edge_clip_tol=_edge_clip_tol, - _dtype=_dtype, - ) - - # process only valid z slices (sampling has variable output sizes) - valid_indices = np.nonzero(is_within_bounds)[0] - if len(valid_indices) == 0: - continue - - # group z slices by unique (t0, t1) pairs to avoid redundant quadrature calculations. - # since most z-slices will have identical clipping bounds (0.0, 1.0), - # we can compute the Gauss samples once and reuse them for almost all slices. - # rounding ensures we get cache hits despite tiny floating point differences. - t0_valid = np.round(t_starts[valid_indices], 10) - t1_valid = np.round(t_ends[valid_indices], 10) - - # simple cache for sampling results: (t0, t1) -> (s_list, w_list) - sample_cache = {} - - # process each z slice - for zi, t0, t1 in zip(valid_indices, t0_valid, t1_valid): - if (t0, t1) not in sample_cache: - sample_cache[(t0, t1)] = self._adaptive_edge_samples( - L, - denom_edge, - t0, - t1, - _sample_fraction=_sample_fraction, - _gauss_order=_gauss_order, - _dtype=_dtype, - ) - - s_list, w_list = sample_cache[(t0, t1)] - if len(s_list) == 0: - continue - - zc = z_centers_arr[zi] - offset3d = offsets_3d[zi] - - pts2d = v0 + s_list[:, None] * edge_vec # faster than np.outer - - # inline unpop_axis_vect for xyz computation - n_pts = len(s_list) - xyz = np.empty((n_pts, 3), dtype=_dtype) - xyz[:, self.axis] = zc - xyz[:, plane_axes] = pts2d - xyz += offset3d - - n_patches = n_pts - new_size_needed = patch_idx + n_patches - if new_size_needed > centers.shape[0]: - # grow arrays by 1.5x to avoid frequent reallocations - new_size = int(new_size_needed * 1.5) - 
centers.resize((new_size, 3), refcheck=False) - normals.resize((new_size, 3), refcheck=False) - perps1.resize((new_size, 3), refcheck=False) - perps2.resize((new_size, 3), refcheck=False) - Ls.resize((new_size,), refcheck=False) - s_vals.resize((new_size,), refcheck=False) - s_weights.resize((new_size,), refcheck=False) - zc_vals.resize((new_size,), refcheck=False) - edge_indices.resize((new_size,), refcheck=False) - - sl = slice(patch_idx, patch_idx + n_patches) - centers[sl] = xyz - normals[sl] = N_vec - perps1[sl] = T1_vec - perps2[sl] = T2_vec - Ls[sl] = L - s_vals[sl] = s_list - s_weights[sl] = w_list - zc_vals[sl] = zc - edge_indices[sl] = ei - - patch_idx += n_patches - - # trim arrays to final size - centers = centers[:patch_idx] - normals = normals[:patch_idx] - perps1 = perps1[:patch_idx] - perps2 = perps2[:patch_idx] - Ls = Ls[:patch_idx] - s_vals = s_vals[:patch_idx] - s_weights = s_weights[:patch_idx] - zc_vals = zc_vals[:patch_idx] - edge_indices = edge_indices[:patch_idx] - - return { - "centers": centers, - "normals": normals, - "perps1": perps1, - "perps2": perps2, - "Ls": Ls, - "s_vals": s_vals, - "s_weights": s_weights, - "zc_vals": zc_vals, - "dz": dz, - "edge_indices": edge_indices, - } - - def _compute_derivative_sidewall_angle( - self, - derivative_info: DerivativeInfo, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool = False, - interpolators: Optional[dict] = None, - ) -> float: - """VJP for dJ/dtheta where theta = sidewall_angle. - - Use dJ/dtheta = integral_S g(x) * V_n(x; theta) * dA, with g(x) from - `evaluate_gradient_at_points`. For a ruled sidewall built by - offsetting the mid-plane polygon by d(z) = -(z - z_ref) * tan(theta), - the normal velocity is V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) - and the area element is dA = (dz/cos(theta)) * d_ell. - Therefore each patch weight is w = L * dz * (-(z - z_ref)) / cos(theta)^2. 
- """ - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - # 2D sim => no dependence on theta (z_local=0) - if is_2d: - return 0.0 - - vertices, next_v, edges, basis = self._edge_geometry_arrays() - - dx = derivative_info.adaptive_vjp_spacing() - - # collect patches once - patch = self._collect_sidewall_patches( - vertices=vertices, - next_v=next_v, - edges=edges, - basis=basis, - sim_min=sim_min, - sim_max=sim_max, - is_2d=False, - dx=dx, - ) - if patch["centers"].shape[0] == 0: - return 0.0 - - # Shape-derivative factors: - # - Offset: d(z) = -(z - z_ref) * tan(theta) - # - Tangential rate: dd/dtheta = -(z - z_ref) * sec(theta)^2 - # - Normal velocity (project to surface normal): V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) - # - Area element of slanted strip: dA = (dz/cos(theta)) * d_ell - # => Patch weight scales as: V_n * dA = -(z - z_ref) * dz * d_ell / cos(theta)^2 - cos_theta = np.cos(get_static(self.sidewall_angle)) - inv_cos2 = 1.0 / (cos_theta * cos_theta) - z_ref = self.reference_axis_pos - - g = derivative_info.evaluate_gradient_at_points( - patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators - ) - z_local = patch["zc_vals"] - z_ref - weights = patch["Ls"] * patch["s_weights"] * patch["dz"] * (-z_local) * inv_cos2 - return float(np.real(np.sum(g * weights))) - - def _compute_derivative_slab_bounds( - self, derivative_info: DerivativeInfo, min_max_index: int, interpolators: dict - ) -> TracedArrayFloat2D: - """VJP for one of the two horizontal faces of a ``PolySlab``. - - The face is discretized into a Cartesian grid of small planar patches. - The adjoint surface integral is evaluated on every retained patch; the - resulting derivative is split equally between the two vertices that bound - the edge segment. 
- """ - # rmin/rmax over the geometry and simulation box - if np.isclose(self.slab_bounds[1] - self.slab_bounds[0], 0.0): - log.warning( - "Computing slab face derivatives for flat structures is not fully supported and " - "may give zero for the derivative. Try using a structure with a small, but nonzero " - "thickness for slab bound derivatives." - ) - rmin, rmax = derivative_info.bounds_intersect - _, (r1_min, r2_min) = self.pop_axis(rmin, axis=self.axis) - _, (r1_max, r2_max) = self.pop_axis(rmax, axis=self.axis) - ax_val = self.slab_bounds[min_max_index] - - # planar grid resolution, clipped to polygon bounding box - face_verts = self.base_polygon if min_max_index == 0 else self.top_polygon - face_poly = shapely.Polygon(face_verts).buffer(fp_eps) - - # limit the patch grid to the face that lives inside the simulation box - poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds - r1_min = max(r1_min, poly_min_r1) - r1_max = min(r1_max, poly_max_r1) - r2_min = max(r2_min, poly_min_r2) - r2_max = min(r2_max, poly_max_r2) - - # intersect the polygon with the simulation bounds - face_poly = face_poly.intersection(shapely.box(r1_min, r2_min, r1_max, r2_max)) - - if (r1_max <= r1_min) and (r2_max <= r2_min): - # the polygon does not intersect the current simulation slice - return 0.0 - - # re-compute the extents after clipping to the polygon bounds - extents = np.array([r1_max - r1_min, r2_max - r2_min]) - - # choose surface or line integral - integral_fun = ( - self.compute_derivative_slab_bounds_line - if np.isclose(extents, 0).any() - else self.compute_derivative_slab_bounds_surface - ) - return integral_fun( - derivative_info, - extents, - r1_min, - r1_max, - r2_min, - r2_max, - ax_val, - face_poly, - min_max_index, - interpolators, - ) - - def compute_derivative_slab_bounds_line( - self, - derivative_info: DerivativeInfo, - extents: NDArray, - r1_min: float, - r1_max: float, - r2_min: float, - r2_max: float, - ax_val: float, - face_poly: 
shapely.Polygon, - min_max_index: int, - interpolators: dict, - ) -> float: - """Handle degenerate line cross-section case""" - line_dim = 1 if np.isclose(extents[0], 0) else 0 - - poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds - if line_dim == 0: # x varies, y is fixed - l_min = max(r1_min, poly_min_r1) - l_max = min(r1_max, poly_max_r1) - else: # y varies, x is fixed - l_min = max(r2_min, poly_min_r2) - l_max = min(r2_max, poly_max_r2) - - length = l_max - l_min - if np.isclose(length, 0): - return 0.0 - - dx = derivative_info.adaptive_vjp_spacing() - n_seg = max(1, int(np.ceil(length / dx))) - coords = np.linspace( - l_min, l_max, 2 * n_seg + 1, dtype=config.adjoint.gradient_dtype_float - )[1::2] - - # build XY coordinates and in-plane direction vectors - if line_dim == 0: - xy = np.column_stack((coords, np.full_like(coords, r2_min))) - dir_vec_plane = np.column_stack((np.ones_like(coords), np.zeros_like(coords))) - else: - xy = np.column_stack((np.full_like(coords, r1_min), coords)) - dir_vec_plane = np.column_stack((np.zeros_like(coords), np.ones_like(coords))) - - inside = shapely.contains_xy(face_poly, xy[:, 0], xy[:, 1]) - if not inside.any(): - return 0.0 - - xy = xy[inside] - dir_vec_plane = dir_vec_plane[inside] - n_pts = len(xy) - - centers_xyz = self.unpop_axis_vect(np.full(n_pts, ax_val), xy) - areas = np.full(n_pts, length / n_seg) # patch length - - normals_xyz = self.unpop_axis_vect( - np.full( - n_pts, -1 if min_max_index == 0 else 1, dtype=config.adjoint.gradient_dtype_float - ), - np.zeros_like(xy, dtype=config.adjoint.gradient_dtype_float), - ) - perps1_xyz = self.unpop_axis_vect(np.zeros(n_pts), dir_vec_plane) - perps2_xyz = self.unpop_axis_vect(np.zeros(n_pts), np.zeros_like(dir_vec_plane)) - - vjps = derivative_info.evaluate_gradient_at_points( - centers_xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators - ) - return np.real(np.sum(vjps * areas)).item() - - def compute_derivative_slab_bounds_surface( - self, - 
derivative_info: DerivativeInfo, - extents: NDArray, - r1_min: float, - r1_max: float, - r2_min: float, - r2_max: float, - ax_val: float, - face_poly: shapely.Polygon, - min_max_index: int, - interpolators: dict, - ) -> float: - """2d surface integral on a Gauss quadrature grid""" - dx = derivative_info.adaptive_vjp_spacing() - - # uniform grid would use n1 x n2 points - n1_uniform, n2_uniform = np.maximum(1, np.ceil(extents / dx).astype(int)) - - # use ~1/2 Gauss points in each direction for similar accuracy - n1 = max(2, n1_uniform // 2) - n2 = max(2, n2_uniform // 2) - - g1, w1 = leggauss(n1) - g2, w2 = leggauss(n2) - - coords1 = (0.5 * (r1_max - r1_min) * g1 + 0.5 * (r1_max + r1_min)).astype( - config.adjoint.gradient_dtype_float, copy=False - ) - coords2 = (0.5 * (r2_max - r2_min) * g2 + 0.5 * (r2_max + r2_min)).astype( - config.adjoint.gradient_dtype_float, copy=False - ) - - r1_grid, r2_grid = np.meshgrid(coords1, coords2, indexing="ij") - r1_flat = r1_grid.flatten() - r2_flat = r2_grid.flatten() - pts = np.column_stack((r1_flat, r2_flat)) - - in_face = shapely.contains_xy(face_poly, pts[:, 0], pts[:, 1]) - if not in_face.any(): - return 0.0 - - xyz = self.unpop_axis_vect( - np.full(in_face.sum(), ax_val, dtype=config.adjoint.gradient_dtype_float), pts[in_face] - ) - n_patches = xyz.shape[0] - - normals_xyz = self.unpop_axis_vect( - np.full( - n_patches, - -1 if min_max_index == 0 else 1, - dtype=config.adjoint.gradient_dtype_float, - ), - np.zeros((n_patches, 2), dtype=config.adjoint.gradient_dtype_float), - ) - perps1_xyz = self.unpop_axis_vect( - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.column_stack( - ( - np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - ) - ), - ) - perps2_xyz = self.unpop_axis_vect( - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.column_stack( - ( - np.zeros(n_patches, 
dtype=config.adjoint.gradient_dtype_float), - np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), - ) - ), - ) - - w1_grid, w2_grid = np.meshgrid(w1, w2, indexing="ij") - weights_flat = (w1_grid * w2_grid).flatten()[in_face] - jacobian = 0.25 * (r1_max - r1_min) * (r2_max - r2_min) - - # area-based correction for non-rectangular domains (e.g. concave polygon) - # for constant integrand, integral should equal polygon area - sum_weights = np.sum(weights_flat) - if sum_weights > 0: - area_correction = face_poly.area / (sum_weights * jacobian) - weights_flat = weights_flat * area_correction - - vjps = derivative_info.evaluate_gradient_at_points( - xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators - ) - return np.real(np.sum(vjps * weights_flat * jacobian)).item() - - def _compute_derivative_vertices( - self, - derivative_info: DerivativeInfo, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool = False, - interpolators: Optional[dict] = None, - ) -> NDArray: - """VJP for the vertices of a ``PolySlab``. - - Uses shared sidewall patch collection and batched field evaluation. 
- """ - vertices, next_v, edges, basis = self._edge_geometry_arrays() - dx = derivative_info.adaptive_vjp_spacing() - - # collect patches once - patch = self._collect_sidewall_patches( - vertices=vertices, - next_v=next_v, - edges=edges, - basis=basis, - sim_min=sim_min, - sim_max=sim_max, - is_2d=is_2d, - dx=dx, - ) - - # early return if no patches - if patch["centers"].shape[0] == 0: - return np.zeros_like(vertices) - - dz = patch["dz"] - dz_surf = 1.0 if is_2d else dz / np.cos(self.sidewall_angle) - - # use provided interpolators or create them if not provided - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - # evaluate integrand - g = derivative_info.evaluate_gradient_at_points( - patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators - ) - - # compute area-based weights and weighted vjps - areas = patch["Ls"] * patch["s_weights"] * dz_surf - patch_vjps = (g * areas).real - - # distribute to vertices using vectorized accumulation - normals_2d = np.delete(basis["norm"], self.axis, axis=1) - edge_idx = patch["edge_indices"] - s = patch["s_vals"] - w0 = (1.0 - s) * patch_vjps - w1 = s * patch_vjps - edge_norms = normals_2d[edge_idx] - - # Accumulate per-vertex contributions using bincount (O(N_patches)) - num_vertices = vertices.shape[0] - contrib0 = w0[:, None] * edge_norms # (n_patches, 2) - contrib1 = w1[:, None] * edge_norms # (n_patches, 2) - - idx0 = edge_idx - idx1 = (edge_idx + 1) % num_vertices - - v0x = np.bincount(idx0, weights=contrib0[:, 0], minlength=num_vertices) - v0y = np.bincount(idx0, weights=contrib0[:, 1], minlength=num_vertices) - v1x = np.bincount(idx1, weights=contrib1[:, 0], minlength=num_vertices) - v1y = np.bincount(idx1, weights=contrib1[:, 1], minlength=num_vertices) - - vjp_per_vertex = np.stack((v0x + v1x, v0y + v1y), axis=1) - return vjp_per_vertex - - def _edge_geometry_arrays( - self, dtype: np.dtype = 
config.adjoint.gradient_dtype_float - ) -> tuple[NDArray, NDArray, NDArray, dict[str, NDArray]]: - """Return (vertices, next_v, edges, basis) arrays for sidewall edge geometry.""" - vertices = np.asarray(self.vertices, dtype=dtype) - next_v = np.roll(vertices, -1, axis=0) - edges = next_v - vertices - basis = self.edge_basis_vectors(edges) - return vertices, next_v, edges, basis - - def edge_basis_vectors( - self, - edges: NDArray, # (N, 2) - ) -> dict[str, NDArray]: # (N, 3) - """Normalized basis vectors for ``normal`` direction, ``slab`` tangent direction and ``edge``.""" - - # ensure edges have consistent dtype - edges = edges.astype(config.adjoint.gradient_dtype_float, copy=False) - - num_vertices, _ = edges.shape - zeros = np.zeros(num_vertices, dtype=config.adjoint.gradient_dtype_float) - ones = np.ones(num_vertices, dtype=config.adjoint.gradient_dtype_float) - - # normalized vectors along edges - edges_norm_in_plane = self.normalize_vect(edges) - edges_norm_xyz = self.unpop_axis_vect(zeros, edges_norm_in_plane) - - # normalized vectors from base of edges to tops of edges - cos_angle = np.cos(self.sidewall_angle) - sin_angle = np.sin(self.sidewall_angle) - slabs_axis_components = cos_angle * ones - - # create axis_norm as array directly to avoid tuple->array conversion in np.cross - axis_norm = np.zeros(3, dtype=config.adjoint.gradient_dtype_float) - axis_norm[self.axis] = 1.0 - slab_normal_xyz = -sin_angle * np.cross(edges_norm_xyz, axis_norm) - _, slab_normal_in_plane = self.pop_axis_vect(slab_normal_xyz) - slabs_norm_xyz = self.unpop_axis_vect(slabs_axis_components, slab_normal_in_plane) - - # normalized vectors pointing in normal direction of edge - # cross yields inward normal when the extrusion axis is y, so negate once for axis==1 - sign = (-1 if self.axis == 1 else 1) * (-1 if not self.is_ccw else 1) - normals_norm_xyz = sign * np.cross(edges_norm_xyz, slabs_norm_xyz) - - return { - "norm": normals_norm_xyz, - "perp1": edges_norm_xyz, - "perp2": 
slabs_norm_xyz, - } - - def unpop_axis_vect(self, ax_coords: NDArray, plane_coords: NDArray) -> NDArray: - """Combine coordinate along axis with coordinates on the plane tangent to the axis. - - ax_coords.shape == [N] - plane_coords.shape == [N, 2] - return shape == [N, 3] - """ - n_pts = ax_coords.shape[0] - arr_xyz = np.zeros((n_pts, 3), dtype=ax_coords.dtype) - - plane_axes = [i for i in range(3) if i != self.axis] - - arr_xyz[:, self.axis] = ax_coords - arr_xyz[:, plane_axes] = plane_coords - - return arr_xyz - - def pop_axis_vect(self, coord: NDArray) -> tuple[NDArray, tuple[NDArray, NDArray]]: - """Combine coordinate along axis with coordinates on the plane tangent to the axis. - - coord.shape == [N, 3] - return shape == ([N], [N, 2] - """ - - arr_axis, arrs_plane = self.pop_axis(coord.T, axis=self.axis) - arrs_plane = np.array(arrs_plane).T - - return arr_axis, arrs_plane - - @staticmethod - def normalize_vect(arr: NDArray) -> NDArray: - """normalize an array shaped (N, d) along the `d` axis and return (N, 1).""" - norm = np.linalg.norm(arr, axis=-1, keepdims=True) - norm = np.where(norm == 0, 1, norm) - return arr / norm - - def translated(self, x: float, y: float, z: float) -> PolySlab: - """Return a translated copy of this geometry. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. - z : float - Translation along z. - - Returns - ------- - :class:`PolySlab` - Translated copy of this ``PolySlab``. - """ - - t_normal, t_plane = self.pop_axis((x, y, z), axis=self.axis) - translated_vertices = np.array(self.vertices) + np.array(t_plane)[None, :] - translated_slab_bounds = (self.slab_bounds[0] + t_normal, self.slab_bounds[1] + t_normal) - return self.updated_copy(vertices=translated_vertices, slab_bounds=translated_slab_bounds) - - def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> PolySlab: - """Return a scaled copy of this geometry. 
- - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - :class:`Geometry` - Scaled copy of this geometry. - """ - scale_normal, scale_in_plane = self.pop_axis((x, y, z), axis=self.axis) - scaled_vertices = self.vertices * np.array(scale_in_plane) - scaled_slab_bounds = tuple(scale_normal * bound for bound in self.slab_bounds) - return self.updated_copy(vertices=scaled_vertices, slab_bounds=scaled_slab_bounds) - - def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> PolySlab: - """Return a rotated copy of this geometry. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - :class:`PolySlab` - Rotated copy of this ``PolySlab``. - """ - _, plane_axs = self.pop_axis([0, 1, 2], self.axis) - if (isinstance(axis, int) and axis == self.axis) or ( - isinstance(axis, tuple) and all(axis[ax] == 0 for ax in plane_axs) - ): - verts_3d = np.zeros((3, self.vertices.shape[0])) - verts_3d[plane_axs[0], :] = self.vertices[:, 0] - verts_3d[plane_axs[1], :] = self.vertices[:, 1] - rotation = RotationAroundAxis(angle=angle, axis=axis) - rotated_vertices = rotation.rotate_vector(verts_3d) - rotated_vertices = rotated_vertices[plane_axs, :].T - return self.updated_copy(vertices=rotated_vertices) - - return super().rotated(angle=angle, axis=axis) - - def reflected(self, normal: Coordinate) -> PolySlab: - """Return a reflected copy of this geometry. - - Parameters - ---------- - normal : tuple[float, float, float] - The 3D normal vector of the plane of reflection. The plane is assumed - to pass through the origin (0,0,0). - - Returns - ------- - ------- - :class:`PolySlab` - Reflected copy of this ``PolySlab``. 
- """ - if math.isclose(normal[self.axis], 0): - _, plane_axs = self.pop_axis((0, 1, 2), self.axis) - verts_3d = np.zeros((3, self.vertices.shape[0])) - verts_3d[plane_axs[0], :] = self.vertices[:, 0] - verts_3d[plane_axs[1], :] = self.vertices[:, 1] - reflection = ReflectionFromPlane(normal=normal) - reflected_vertices = reflection.reflect_vector(verts_3d) - reflected_vertices = reflected_vertices[plane_axs, :].T - return self.updated_copy(vertices=reflected_vertices) - - return super().reflected(normal=normal) - - -class ComplexPolySlabBase(PolySlab): - """Interface for dividing a complex polyslab where self-intersecting polygon can - occur during extrusion. This class should not be used directly. Use instead - :class:`plugins.polyslab.ComplexPolySlab`.""" - - @model_validator(mode="after") - def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self: - """Turn off the validation for this class.""" - return self - - @classmethod - def from_gds( - cls, - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> list[PolySlab]: - """Import :class:`.PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: tuple[float, float] - Minimum and maximum positions of the slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. - If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. 
- dilation : float = 0.0 - Dilation of the polygon in the base by shifting each edge along its - normal outwards direction by a distance; - a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the sidewall. - ``sidewall_angle=0`` (default) specifies vertical wall, - while ``0 base.GeometryGroup: - """Divide a complex polyslab into a list of simple polyslabs, which - are assembled into a :class:`.GeometryGroup`. - - Returns - ------- - :class:`.GeometryGroup` - GeometryGroup for a list of simple polyslabs divided from the complex - polyslab. - """ - return base.GeometryGroup(geometries=self.sub_polyslabs) - - @property - def sub_polyslabs(self) -> list[PolySlab]: - """Divide a complex polyslab into a list of simple polyslabs. - Only neighboring vertex-vertex crossing events are treated in this - version. - - Returns - ------- - list[PolySlab] - A list of simple polyslabs. - """ - sub_polyslab_list = [] - num_division_count = 0 - # initialize sub-polyslab parameters - sub_polyslab_dict = self.model_dump(exclude={"type"}).copy() - if math.isclose(self.sidewall_angle, 0): - return [PolySlab.model_validate(sub_polyslab_dict)] - - sub_polyslab_dict.update({"dilation": 0}) # dilation accounted in setup - # initialize offset distance - offset_distance = 0 - - for dist_val in self._dilation_length: - dist_now = 0.0 - vertices_now = self.reference_polygon - - # constructing sub-polyslabs until reaching the base/top - while not math.isclose(dist_now, dist_val): - # bounds for sub-polyslabs assuming no self-intersection - slab_bounds = [ - self._dilation_value_at_reference_to_coord(dist_now), - self._dilation_value_at_reference_to_coord(dist_val), - ] - # 1) find out any vertices touching events between the current - # position to the base/top - max_dist = PolySlab._neighbor_vertices_crossing_detection( - vertices_now, dist_val - dist_now - ) - - # vertices touching events captured, update bounds for sub-polyslab - if max_dist is not None: - # 
max_dist doesn't have sign, so construct signed offset distance - offset_distance = max_dist * dist_val / abs(dist_val) - slab_bounds[1] = self._dilation_value_at_reference_to_coord( - dist_now + offset_distance - ) - - # 2) construct sub-polyslab - slab_bounds.sort() # for reference_plane=top/bottom, bounds need to be ordered - # direction of marching - reference_plane = "bottom" if dist_val / self._tanq < 0 else "top" - sub_polyslab_dict.update( - { - "slab_bounds": tuple(slab_bounds), - "vertices": vertices_now, - "reference_plane": reference_plane, - } - ) - sub_polyslab_list.append(PolySlab.model_validate(sub_polyslab_dict)) - - # Now Step 3 - if max_dist is None: - break - dist_now += offset_distance - # new polygon vertices where collapsing vertices are removed but keep one - vertices_now = PolySlab._shift_vertices(vertices_now, offset_distance)[0] - vertices_now = PolySlab._remove_duplicate_vertices(vertices_now) - # all vertices collapse - if len(vertices_now) < 3: - break - # polygon collapse into 1D - if self.make_shapely_polygon(vertices_now).buffer(0).area < fp_eps: - break - vertices_now = PolySlab._orient(vertices_now) - num_division_count += 1 - - if num_division_count > _COMPLEX_POLYSLAB_DIVISIONS_WARN: - log.warning( - f"Too many self-intersecting events: the polyslab has been divided into " - f"{num_division_count} polyslabs; more than {_COMPLEX_POLYSLAB_DIVISIONS_WARN} may " - f"slow down the simulation." 
- ) - - return sub_polyslab_list - - @property - def _dilation_length(self) -> list[float]: - """dilation length from reference plane to the top/bottom of the polyslab.""" - - # for "bottom", only needs to compute the offset length to the top - dist = [self._extrusion_length_to_offset_distance(self.finite_length_axis)] - # reverse the dilation value if the reference plane is on the top - if self.reference_plane == "top": - dist = [-dist[0]] - # for middle, both directions - elif self.reference_plane == "middle": - dist = [dist[0] / 2, -dist[0] / 2] - return dist - - def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: - """Compute the coordinate based on the dilation value to the reference plane.""" - - z_coord = -dilation / self._tanq + self.slab_bounds[0] - if self.reference_plane == "middle": - return z_coord + self.finite_length_axis / 2 - if self.reference_plane == "top": - return z_coord + self.finite_length_axis - # bottom case - return z_coord - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for PolySlab. +# marked as migrated to _common +from __future__ import annotations - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - return [ - shapely.unary_union( - [ - base.Geometry.evaluate_inf_shape(shape) - for polyslab in self.sub_polyslabs - for shape in polyslab.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - ] - ) - ] +from tidy3d._common.components.geometry.polyslab import ( + _COMPLEX_POLYSLAB_DIVISIONS_WARN, + _IS_CLOSE_RTOL, + _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION, + _MIN_POLYGON_AREA, + _N_SAMPLE_POLYGON_INTERSECT, + ComplexPolySlabBase, + PolySlab, + leggauss, +) diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index ea656a9800..f0921d132a 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -1,1294 +1,19 @@ -"""Concrete primitive geometrical objects.""" - -from __future__ import annotations - -from math import isclose -from typing import TYPE_CHECKING, Any - -import autograd.numpy as anp -import numpy as np -import shapely -from pydantic import Field, PrivateAttr, model_validator - -from tidy3d.components.autograd import TracedSize1D, get_static -from tidy3d.components.base import cached_property -from tidy3d.components.geometry import base -from tidy3d.components.geometry.mesh import TriangleMesh -from tidy3d.components.geometry.polyslab import PolySlab -from tidy3d.config import config -from tidy3d.constants import LARGE_NUMBER, MICROMETER -from tidy3d.exceptions import SetupError, ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -if TYPE_CHECKING: - from typing import Optional - - from shapely.geometry.base import BaseGeometry - - from tidy3d.compat import Self - from tidy3d.components.autograd import AutogradFieldMap - from tidy3d.components.autograd.derivative_utils import DerivativeInfo - from tidy3d.components.types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely - -# for sampling conical frustum in visualization -_N_SAMPLE_CURVE_SHAPELY = 40 - -# for shapely 
circular shapes discretization in visualization -_N_SHAPELY_QUAD_SEGS_VISUALIZATION = 200 - -# Default number of points to discretize polyslab in `Cylinder.to_polyslab()` -_N_PTS_CYLINDER_POLYSLAB = 51 -_MAX_ICOSPHERE_SUBDIVISIONS = 7 # this would have 164K vertices and 328K faces -_DEFAULT_EDGE_FRACTION = 0.25 - - -def _base_icosahedron() -> tuple[np.ndarray, np.ndarray]: - """Return vertices and faces of a unit icosahedron.""" - - phi = (1.0 + np.sqrt(5.0)) / 2.0 - vertices = np.array( - [ - (-1, phi, 0), - (1, phi, 0), - (-1, -phi, 0), - (1, -phi, 0), - (0, -1, phi), - (0, 1, phi), - (0, -1, -phi), - (0, 1, -phi), - (phi, 0, -1), - (phi, 0, 1), - (-phi, 0, -1), - (-phi, 0, 1), - ], - dtype=float, - ) - vertices /= np.linalg.norm(vertices, axis=1)[:, None] - faces = np.array( - [ - (0, 11, 5), - (0, 5, 1), - (0, 1, 7), - (0, 7, 10), - (0, 10, 11), - (1, 5, 9), - (5, 11, 4), - (11, 10, 2), - (10, 7, 6), - (7, 1, 8), - (3, 9, 4), - (3, 4, 2), - (3, 2, 6), - (3, 6, 8), - (3, 8, 9), - (4, 9, 5), - (2, 4, 11), - (6, 2, 10), - (8, 6, 7), - (9, 8, 1), - ], - dtype=int, - ) - return vertices, faces - - -_ICOSAHEDRON_VERTS, _ICOSAHEDRON_FACES = _base_icosahedron() - - -def discretization_wavelength(derivative_info: DerivativeInfo, geometry_label: str) -> float: - """Choose reference wavelength for surface discretization.""" - wvl0_min = derivative_info.wavelength_min - wvl_mat = wvl0_min / np.max([1.0, np.max(np.sqrt(abs(derivative_info.eps_in)))]) - - grid_cfg = config.adjoint - - min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min - if wvl_mat < min_wvl_mat: - log.warning( - f"The minimum wavelength inside the {geometry_label} material is {wvl_mat:.3e} μm, which would " - f"create a large number of discretization points for computing the gradient. 
" - f"To prevent performance degradation, the discretization wavelength has " - f"been clipped to {min_wvl_mat:.3e} μm.", - log_once=True, - ) - return max(wvl_mat, min_wvl_mat) - - -class Sphere(base.Centered, base.Circular): - """Spherical geometry. - - Example - ------- - >>> b = Sphere(center=(1,2,3), radius=2) - """ - - radius: TracedSize1D = Field( - title="Radius", - description="Radius of geometry.", - json_schema_extra={"units": MICROMETER}, - ) - - _icosphere_cache: dict[int, tuple[np.ndarray, float]] = PrivateAttr(default_factory=dict) - - @verify_packages_import(["trimesh"]) - def to_triangle_mesh( - self, - *, - max_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - ) -> TriangleMesh: - """Approximate the sphere surface with a ``TriangleMesh``. - - Parameters - ---------- - max_edge_length : float = None - Maximum edge length for triangulation in micrometers. - subdivisions : int = None - Number of subdivisions for icosphere generation. - - Returns - ------- - TriangleMesh - Triangle mesh approximation of the sphere surface. - """ - - triangles, _ = self._triangulated_surface( - max_edge_length=max_edge_length, subdivisions=subdivisions - ) - return TriangleMesh.from_triangles(triangles) - - def inside( - self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] - ) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - self._ensure_equal_shape(x, y, z) - x0, y0, z0 = self.center - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - return (dist_x**2 + dist_y**2 + dist_z**2) <= (self.radius**2) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - normal = np.array(normal) - unit_normal = normal / (np.sum(normal**2) ** 0.5) - projection = np.dot(np.array(origin) - np.array(self.center), unit_normal) - if abs(projection) >= self.radius: - return [] - - radius = (self.radius**2 - projection**2) ** 0.5 - center = np.array(self.center) + projection * unit_normal - - v = np.zeros(3) - v[np.argmin(np.abs(unit_normal))] = 1 - u = np.cross(unit_normal, v) - u /= np.sum(u**2) ** 0.5 - v = np.cross(unit_normal, u) - - angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1)[:-1] - circ = center + np.outer(np.cos(angles), radius * u) + np.outer(np.sin(angles), radius * v) - vertices = np.dot(np.hstack((circ, np.ones((angles.size, 1)))), to_2D.T) - return [shapely.Polygon(vertices[:, :2])] - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[BaseGeometry]: - """Returns shapely geometry at plane specified by one non None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation ``. 
- """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - z0, (x0, y0) = self.pop_axis(self.center, axis=axis) - intersect_dist = self._intersect_dist(position, z0) - if not intersect_dist: - return [] - return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)] - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - coord_min = tuple(c - self.radius for c in self.center) - coord_max = tuple(c + self.radius for c in self.center) - return (coord_min, coord_max) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - volume = 4.0 / 3.0 * np.pi * self.radius**3 - - # a very loose upper bound on how much of sphere is in bounds - for axis in range(3): - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - volume *= 0.5 - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 4.0 * np.pi * self.radius**2 - - # a very loose upper bound on how much of sphere is in bounds - for axis in range(3): - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - area *= 0.5 - - return area - - @classmethod - def unit_sphere_triangles( - cls, - *, - target_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - ) -> np.ndarray: - """Return unit sphere triangles discretized via an icosphere.""" - - unit_tris = UNIT_SPHERE._unit_sphere_triangles( - target_edge_length=target_edge_length, - subdivisions=subdivisions, - copy_result=True, - ) - return unit_tris - - def _compute_derivatives(self, derivative_info: 
DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint derivatives using smooth sphere surface samples.""" - valid_paths = {("radius",), *{("center", i) for i in range(3)}} - for path in derivative_info.paths: - if path not in valid_paths: - raise ValueError( - f"No derivative defined w.r.t. 'Sphere' field '{path}'. " - "Supported fields are 'radius' and 'center'." - ) - - if not derivative_info.paths: - return {} - - grid_cfg = config.adjoint - radius = float(get_static(self.radius)) - if radius == 0.0: - log.warning( - "Sphere gradients cannot be computed for zero radius; gradients are zero.", - log_once=True, - ) - return dict.fromkeys(derivative_info.paths, 0.0) - - wvl_mat = discretization_wavelength(derivative_info, "sphere") - target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps) - triangles, _ = self._triangulated_surface(max_edge_length=target_edge) - triangles = triangles.astype(grid_cfg.gradient_dtype_float, copy=False) - - sim_min, sim_max = ( - np.asarray(arr, dtype=grid_cfg.gradient_dtype_float) - for arr in derivative_info.simulation_bounds - ) - tol = config.adjoint.edge_clip_tolerance - - sim_extents = sim_max - sim_min - collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol)) - if collapsed_indices.size: - if collapsed_indices.size > 1: - return dict.fromkeys(derivative_info.paths, 0.0) - axis_idx = int(collapsed_indices[0]) - plane_value = float(sim_min[axis_idx]) - return self._compute_derivatives_collapsed_axis( - derivative_info=derivative_info, - axis_idx=axis_idx, - plane_value=plane_value, - ) - - trimesh_obj = TriangleMesh._triangles_to_trimesh(triangles) - vertices = np.asarray(trimesh_obj.vertices, dtype=grid_cfg.gradient_dtype_float) - center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float) - verts_centered = vertices - center - norms = np.linalg.norm(verts_centered, axis=1, keepdims=True) - norms = np.where(norms == 0, 1, norms) - normals = verts_centered / norms - - if 
vertices.size == 0: - return dict.fromkeys(derivative_info.paths, 0.0) - - # get vertex weights - faces = np.asarray(trimesh_obj.faces, dtype=int) - face_areas = np.asarray(trimesh_obj.area_faces, dtype=grid_cfg.gradient_dtype_float) - weights = np.zeros(len(vertices), dtype=grid_cfg.gradient_dtype_float) - np.add.at(weights, faces[:, 0], face_areas / 3.0) - np.add.at(weights, faces[:, 1], face_areas / 3.0) - np.add.at(weights, faces[:, 2], face_areas / 3.0) - - perp1, perp2 = self._tangent_basis_from_normals(normals) - - valid_axes = np.abs(sim_max - sim_min) > tol - inside_mask = np.all( - vertices[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 - ) & np.all(vertices[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) - - if not np.any(inside_mask): - return dict.fromkeys(derivative_info.paths, 0.0) - - points = vertices[inside_mask] - normals_sel = normals[inside_mask] - perp1_sel = perp1[inside_mask] - perp2_sel = perp2[inside_mask] - weights_sel = weights[inside_mask] - - interpolators = derivative_info.interpolators - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=grid_cfg.gradient_dtype_float - ) - - g = derivative_info.evaluate_gradient_at_points( - points, - normals_sel, - perp1_sel, - perp2_sel, - interpolators, - ) - - weighted = (weights_sel * g).real - grad_center = np.sum(weighted[:, None] * normals_sel, axis=0) - grad_radius = np.sum(weighted) - - vjps: AutogradFieldMap = {} - for path in derivative_info.paths: - if path == ("radius",): - vjps[path] = float(grad_radius) - else: - _, idx = path - vjps[path] = float(grad_center[idx]) - - return vjps - - def _compute_derivatives_collapsed_axis( - self, - derivative_info: DerivativeInfo, - axis_idx: int, - plane_value: float, - ) -> AutogradFieldMap: - """Delegate collapsed-axis gradients to a Cylinder cross section.""" - tol = config.adjoint.edge_clip_tolerance - radius = float(self.radius) - center = np.asarray(self.center, dtype=float) - delta = 
plane_value - center[axis_idx] - radius_sq = radius**2 - delta**2 - if radius_sq <= tol**2: - return dict.fromkeys(derivative_info.paths, 0.0) - - radius_plane = float(np.sqrt(max(radius_sq, 0.0))) - if radius_plane <= tol: - return dict.fromkeys(derivative_info.paths, 0.0) - - cyl_paths: set[tuple[str, int | None]] = set() - need_radius = False - for path in derivative_info.paths: - if path == ("radius",) or path == ("center", axis_idx): - cyl_paths.add(("radius",)) - need_radius = True - elif path[0] == "center" and path[1] != axis_idx: - cyl_paths.add(("center", path[1])) - - if not cyl_paths: - return dict.fromkeys(derivative_info.paths, 0.0) - - cyl_center = center.copy() - cyl_center[axis_idx] = plane_value - cylinder = Cylinder( - center=tuple(cyl_center), - radius=radius_plane, - length=discretization_wavelength(derivative_info, "sphere") * 2.0, - axis=axis_idx, - ) - - bounds_min = list(cyl_center) - bounds_max = list(cyl_center) - for dim in range(3): - if dim == axis_idx: - continue - bounds_min[dim] = center[dim] - radius_plane - bounds_max[dim] = center[dim] + radius_plane - - bounds = (tuple(bounds_min), tuple(bounds_max)) - sim_min_arr, sim_max_arr = ( - np.asarray(arr, dtype=float) for arr in derivative_info.simulation_bounds - ) - intersect_min = tuple(max(bounds[0][i], sim_min_arr[i]) for i in range(3)) - intersect_max = tuple(min(bounds[1][i], sim_max_arr[i]) for i in range(3)) - if any(lo > hi for lo, hi in zip(intersect_min, intersect_max)): - return dict.fromkeys(derivative_info.paths, 0.0) - - derivative_info_cyl = derivative_info.updated_copy( - paths=list(cyl_paths), - bounds=bounds, - bounds_intersect=(intersect_min, intersect_max), - ) - - vjps_cyl = cylinder._compute_derivatives(derivative_info_cyl) - result = dict.fromkeys(derivative_info.paths, 0.0) - vjp_radius = float(vjps_cyl.get(("radius",), 0.0)) if need_radius else 0.0 - - for path in derivative_info.paths: - if path == ("radius",): - result[path] = vjp_radius * (radius / 
radius_plane) - elif path == ("center", axis_idx): - result[path] = vjp_radius * (delta / radius_plane) - elif path[0] == "center" and path[1] != axis_idx: - result[path] = float(vjps_cyl.get(("center", path[1]), 0.0)) - - return result - - def _edge_length_on_unit_sphere( - self, max_edge_length: Optional[float] = _DEFAULT_EDGE_FRACTION - ) -> Optional[float]: - """Convert ``max_edge_length`` in μm to unit-sphere coordinates.""" - max_edge_length = _DEFAULT_EDGE_FRACTION if max_edge_length is None else max_edge_length - radius = float(self.radius) - if radius <= 0.0: - return None - return max_edge_length / radius - - def _triangulated_surface( - self, - *, - max_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Return physical and unit triangles for the surface discretization. Pass either max_edge_length or subdivisions.""" - max_edge_length_unit = None - if subdivisions is None: - max_edge_length_unit = self._edge_length_on_unit_sphere(max_edge_length) - - unit_tris = self._unit_sphere_triangles( - target_edge_length=max_edge_length_unit, - subdivisions=subdivisions, - copy_result=False, - ) - - radius = float(get_static(self.radius)) - center = np.asarray(self.center, dtype=float) - dtype = config.adjoint.gradient_dtype_float - - physical = radius * unit_tris + center - return physical.astype(dtype, copy=False), unit_tris.astype(dtype, copy=False) - - def _unit_sphere_triangles( - self, - *, - target_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - copy_result: bool = True, - ) -> np.ndarray: - """Return cached unit-sphere triangles with optional copying. 
Pass either target_edge_length or subdivisions.""" - if target_edge_length is not None and subdivisions is not None: - raise ValueError("Specify either target_edge_length OR subdivisions, not both.") - - if subdivisions is None: - subdivisions = self._subdivisions_for_edge(target_edge_length) - - triangles, _ = self._icosphere_data(subdivisions) - return np.array(triangles, copy=copy_result) - - def _subdivisions_for_edge(self, target_edge_length: Optional[float]) -> int: - if target_edge_length is None or target_edge_length <= 0.0: - return 0 - - for subdiv in range(_MAX_ICOSPHERE_SUBDIVISIONS + 1): - _, max_edge = self._icosphere_data(subdiv) - if max_edge <= target_edge_length: - return subdiv - - log.warning( - f"Requested sphere mesh edge length {target_edge_length:.3e} μm requires more than " - f"{_MAX_ICOSPHERE_SUBDIVISIONS} subdivisions. " - "Clipping to the finest available mesh.", - log_once=True, - ) - return _MAX_ICOSPHERE_SUBDIVISIONS - - @staticmethod - def _tangent_basis_from_normals(normals: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Construct orthonormal tangential bases for each normal vector (vectorized).""" - - dtype = normals.dtype - tol = np.finfo(dtype).eps - - # Normalize normals (in case they are not perfectly unit length). - n_norm = np.linalg.norm(normals, axis=1) - n = normals / np.maximum(n_norm, tol)[:, None] - - # Pick a reference axis least aligned with each normal: argmin(|nx|,|ny|,|nz|). 
- ref_idx = np.argmin(np.abs(n), axis=1) - ref = np.zeros_like(n) - ref[np.arange(n.shape[0]), ref_idx] = 1.0 - - basis1 = np.cross(n, ref) - b1_norm = np.linalg.norm(basis1, axis=1) - basis1 = basis1 / np.maximum(b1_norm, tol)[:, None] - - basis2 = np.cross(n, basis1) - b2_norm = np.linalg.norm(basis2, axis=1) - basis2 = basis2 / np.maximum(b2_norm, tol)[:, None] - - return basis1, basis2 - - def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]: - cache = self._icosphere_cache - if subdivisions in cache: - return cache[subdivisions] - - vertices = np.asarray(_ICOSAHEDRON_VERTS, dtype=float) - faces = np.asarray(_ICOSAHEDRON_FACES, dtype=int) - if subdivisions > 0: - vertices = vertices.copy() - faces = faces.copy() - for _ in range(subdivisions): - vertices, faces = TriangleMesh.subdivide_faces(vertices, faces) - - norms = np.linalg.norm(vertices, axis=1, keepdims=True) - norms = np.where(norms == 0.0, 1.0, norms) - vertices = vertices / norms - - triangles = vertices[faces] - max_edge = self._max_edge_length(triangles) - cache[subdivisions] = (triangles, max_edge) - return triangles, max_edge - - @staticmethod - def _max_edge_length(triangles: np.ndarray) -> float: - v = triangles - edges = np.stack( - [ - v[:, 1] - v[:, 0], - v[:, 2] - v[:, 1], - v[:, 0] - v[:, 2], - ], - axis=1, - ) - return float(np.linalg.norm(edges, axis=2).max()) +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.primitives`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -UNIT_SPHERE = Sphere(center=(0.0, 0.0, 0.0), radius=1.0) - - -class Cylinder(base.Centered, base.Circular, base.Planar): - """Cylindrical geometry with optional sidewall angle along axis - direction. When ``sidewall_angle`` is nonzero, the shape is a - conical frustum or a cone. 
- - Example - ------- - >>> c = Cylinder(center=(1,2,3), radius=2, length=5, axis=2) - - See Also - -------- - - **Notebooks** - - * `THz integrated demultiplexer/filter based on a ring resonator <../../../notebooks/THzDemultiplexerFilter.html>`_ - * `Photonic crystal waveguide polarization filter <../../../notebooks/PhotonicCrystalWaveguidePolarizationFilter.html>`_ - """ - - # Provide more explanations on where radius is defined - radius: TracedSize1D = Field( - title="Radius", - description="Radius of geometry at the ``reference_plane``.", - json_schema_extra={"units": MICROMETER}, - ) - - length: TracedSize1D = Field( - title="Length", - description="Defines thickness of cylinder along axis dimension.", - json_schema_extra={"units": MICROMETER}, - ) - - @model_validator(mode="after") - def _only_middle_for_infinite_length_slanted_cylinder(self: Self) -> Self: - """For a slanted cylinder of infinite length, ``reference_plane`` can only - be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0. - """ - if isclose(self.sidewall_angle, 0) or not np.isinf(self.length): - return self - if self.reference_plane != "middle": - raise SetupError( - "For a slanted cylinder here is of infinite length, " - "defining the reference_plane other than 'middle' " - "leads to undefined cylinder behaviors near 'center'." - ) - return self - - def to_polyslab( - self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB, **kwargs: Any - ) -> PolySlab: - """Convert instance of ``Cylinder`` into a discretized version using ``PolySlab``. - - Parameters - ---------- - num_pts_circumference : int = 51 - Number of points in the circumference of the discretized polyslab. - **kwargs: - Extra keyword arguments passed to ``PolySlab()``, such as ``dilation``. - - Returns - ------- - PolySlab - Extruded polygon representing a discretized version of the cylinder. 
- """ - - center_axis = self.center_axis - length_axis = self.length_axis - slab_bounds = (center_axis - length_axis / 2.0, center_axis + length_axis / 2.0) - - if num_pts_circumference < 3: - raise ValueError("'PolySlab' from 'Cylinder' must have 3 or more radius points.") - - _, (x0, y0) = self.pop_axis(self.center, axis=self.axis) - - xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) - - xs = x0 + self.radius * xs_ - ys = y0 + self.radius * ys_ - - vertices = anp.stack((xs, ys), axis=-1) - - return PolySlab( - vertices=vertices, - axis=self.axis, - slab_bounds=slab_bounds, - sidewall_angle=self.sidewall_angle, - reference_plane=self.reference_plane, - **kwargs, - ) - - def _points_unit_circle( - self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB - ) -> np.ndarray: - """Set of x and y points for the unit circle when discretizing cylinder as a polyslab.""" - angles = np.linspace(0, 2 * np.pi, num_pts_circumference, endpoint=False) - xs = np.cos(angles) - ys = np.sin(angles) - return np.stack((xs, ys), axis=0) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - # compute circumference discretization - wvl_mat = discretization_wavelength(derivative_info, "cylinder") - - circumference = 2 * np.pi * self.radius - wvls_in_circumference = circumference / wvl_mat - - grid_cfg = config.adjoint - num_pts_circumference = int(np.ceil(grid_cfg.points_per_wavelength * wvls_in_circumference)) - num_pts_circumference = max(3, num_pts_circumference) - - # construct equivalent polyslab and compute the derivatives - polyslab = self.to_polyslab(num_pts_circumference=num_pts_circumference) - - # build PolySlab derivative paths based on requested Cylinder paths - ps_paths = set() - for path in derivative_info.paths: - if path == ("length",): - ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) - elif path == ("radius",): - 
ps_paths.add(("vertices",)) - elif "center" in path: - _, center_index = path - _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) - if center_index in (index_x, index_y): - ps_paths.add(("vertices",)) - else: - ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) - elif path == ("sidewall_angle",): - ps_paths.add(("sidewall_angle",)) - - # pass interpolators to PolySlab if available to avoid redundant conversions - update_kwargs = { - "paths": list(ps_paths), - "deep": False, - } - if derivative_info.interpolators is not None: - update_kwargs["interpolators"] = derivative_info.interpolators - - derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs) - vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab) - - vjps = {} - for path in derivative_info.paths: - if path == ("length",): - vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) - vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) - vjps[path] = vjp_top - vjp_bot - - elif path == ("radius",): - # transform polyslab vertices derivatives into radius derivative - xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T - vjp_xs = np.sum(xs_ * vjps_vertices_xs) - vjp_ys = np.sum(ys_ * vjps_vertices_ys) - vjps[path] = vjp_xs + vjp_ys - - elif "center" in path: - _, center_index = path - _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) - if center_index == index_x: - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_xs = vjps_polyslab[("vertices",)][:, 0] - vjps[path] = np.sum(vjps_vertices_xs) - elif center_index == index_y: - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_ys = vjps_polyslab[("vertices",)][:, 1] - vjps[path] = np.sum(vjps_vertices_ys) - else: - vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) - 
vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) - vjps[path] = vjp_top + vjp_bot - - elif path == ("sidewall_angle",): - # direct mapping: cylinder angle equals polyslab angle - vjps[path] = vjps_polyslab.get(("sidewall_angle",), 0.0) - - else: - raise NotImplementedError( - f"Differentiation with respect to 'Cylinder' '{path}' field not supported. " - "If you would like this feature added, please feel free to raise " - "an issue on the tidy3d front end repository." - ) - - return vjps - - @property - def center_axis(self) -> Any: - """Gets the position of the center of the geometry in the out of plane dimension.""" - z0, _ = self.pop_axis(self.center, axis=self.axis) - return z0 - - @property - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - return self.length - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if self.length != 0: - raise ValidationError("'Medium2D' requires the 'Cylinder' length to be zero.") - return self.axis - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Cylinder: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - if axis != self.axis: - raise ValueError( - f"'_update_from_bounds' may only be applied along axis '{self.axis}', " - f"but was given axis '{axis}'." - ) - new_center = list(self.center) - new_center[axis] = (bounds[0] + bounds[1]) / 2 - new_length = bounds[1] - bounds[0] - return self.updated_copy(center=tuple(new_center), length=new_length) - - @verify_packages_import(["trimesh"]) - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. 
- - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - import trimesh - - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - z0, (x0, y0) = self.pop_axis(self.center, self.axis) - half_length = self.finite_length_axis / 2 - - z_top = z0 + half_length - z_bot = z0 - half_length - - if np.isclose(self.sidewall_angle, 0): - r_top = self.radius - r_bot = self.radius - else: - r_top = self.radius_top - r_bot = self.radius_bottom - if r_top < 0 or np.isclose(r_top, 0): - r_top = 0 - z_top = z0 + self._radius_z(z0) / self._tanq - elif r_bot < 0 or np.isclose(r_bot, 0): - r_bot = 0 - z_bot = z0 + self._radius_z(z0) / self._tanq - - angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1) - - if r_bot > 0: - x_bot = x0 + r_bot * np.cos(angles) - y_bot = y0 + r_bot * np.sin(angles) - x_bot[-1] = x0 - y_bot[-1] = y0 - else: - x_bot = np.array([x0]) - y_bot = np.array([y0]) - - if r_top > 0: - x_top = x0 + r_top * np.cos(angles) - y_top = y0 + r_top * np.sin(angles) - x_top[-1] = x0 - y_top[-1] = y0 - else: - x_top = np.array([x0]) - y_top = np.array([y0]) - - x = np.hstack((x_bot, x_top)) - y = np.hstack((y_bot, y_top)) - z = np.hstack((np.full_like(x_bot, z_bot), np.full_like(x_top, z_top))) - vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T - - if x_bot.shape[0] == 1: - m = 1 - n = x_top.shape[0] - 1 - faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] - faces_side = [(m + (i + 1) % n, m + i, 0) 
for i in range(n)] - faces = faces_top + faces_side - elif x_top.shape[0] == 1: - m = x_bot.shape[0] - n = m - 1 - faces_bot = [(n, (i + 1) % n, i) for i in range(n)] - faces_side = [(i, (i + 1) % n, m) for i in range(n)] - faces = faces_bot + faces_side - else: - m = x_bot.shape[0] - n = m - 1 - faces_bot = [(n, (i + 1) % n, i) for i in range(n)] - faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] - faces_side_bot = [(i, (i + 1) % n, m + (i + 1) % n) for i in range(n)] - faces_side_top = [(m + (i + 1) % n, m + i, i) for i in range(n)] - faces = faces_bot + faces_top + faces_side_bot + faces_side_top - - mesh = trimesh.Trimesh(vertices, faces) - - section = mesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def _intersections_normal( - self, z: float, quad_segs: Optional[int] = None - ) -> list[BaseGeometry]: - """Find shapely geometries intersecting cylindrical geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - static_self = self.to_static() - - # radius at z - radius_offset = static_self._radius_z(z) - - if radius_offset <= 0: - return [] - - _, (x0, y0) = self.pop_axis(static_self.center, axis=self.axis) - return [shapely.Point(x0, y0).buffer(radius_offset, quad_segs=quad_segs)] - - def _intersections_side(self, position: float, axis: int) -> list[BaseGeometry]: - """Find shapely geometries intersecting cylindrical geometry with axis orthogonal to length. - When ``sidewall_angle`` is nonzero, so that it's in fact a conical frustum or cone, the - cross section can contain hyperbolic curves. This is currently approximated by a polygon - of many vertices. - - Parameters - ---------- - position : float - Position along axis direction. - axis : int - Integer index into 'xyz' (0, 1, 2). - - Returns - ------- - list[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - # position in the local coordinate of the cylinder - position_local = position - self.center[axis] - - # no intersection - if abs(position_local) >= self.radius_max: - return [] - - # half of intersection length at the top and bottom - intersect_half_length_max = np.sqrt(self.radius_max**2 - position_local**2) - intersect_half_length_min = -LARGE_NUMBER - if abs(position_local) < self.radius_min: - intersect_half_length_min = np.sqrt(self.radius_min**2 - position_local**2) - - # the vertices on the max side of top/bottom - # The two vertices are present in all scenarios. 
- vertices_max = [ - self._local_to_global_side_cross_section([-intersect_half_length_max, 0], axis), - self._local_to_global_side_cross_section([intersect_half_length_max, 0], axis), - ] - - # Extending to a cone, the maximal height of the cone - h_cone = ( - LARGE_NUMBER if isclose(self.sidewall_angle, 0) else self.radius_max / abs(self._tanq) - ) - # The maximal height of the cross section - height_max = min( - (1 - abs(position_local) / self.radius_max) * h_cone, self.finite_length_axis - ) - - # more vertices to add for conical frustum shape - vertices_frustum_right = [] - vertices_frustum_left = [] - if not (isclose(position, self.center[axis]) or isclose(self.sidewall_angle, 0)): - # The y-coordinate for the additional vertices - y_list = height_max * np.linspace(0, 1, _N_SAMPLE_CURVE_SHAPELY) - # `abs()` to make sure np.sqrt(0-fp_eps) goes through - x_list = np.sqrt( - np.abs(self.radius_max**2 * (1 - y_list / h_cone) ** 2 - position_local**2) - ) - for i in range(_N_SAMPLE_CURVE_SHAPELY): - vertices_frustum_right.append( - self._local_to_global_side_cross_section([x_list[i], y_list[i]], axis) - ) - vertices_frustum_left.append( - self._local_to_global_side_cross_section( - [ - -x_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], - y_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], - ], - axis, - ) - ) - - # the vertices on the min side of top/bottom - vertices_min = [] - - ## termination at the top/bottom - if intersect_half_length_min > 0: - vertices_min.append( - self._local_to_global_side_cross_section( - [intersect_half_length_min, self.finite_length_axis], axis - ) - ) - vertices_min.append( - self._local_to_global_side_cross_section( - [-intersect_half_length_min, self.finite_length_axis], axis - ) - ) - ## early termination - else: - vertices_min.append(self._local_to_global_side_cross_section([0, height_max], axis)) - - return [ - shapely.Polygon( - vertices_max + vertices_frustum_right + vertices_min + vertices_frustum_left - ) - ] - - def inside( - self, x: 
np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] - ) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - # radius at z - self._ensure_equal_shape(x, y, z) - z0, (x0, y0) = self.pop_axis(self.center, axis=self.axis) - z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) - radius_offset = self._radius_z(z) - positive_radius = radius_offset > 0 - - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - inside_radius = (dist_x**2 + dist_y**2) <= (radius_offset**2) - inside_height = dist_z <= (self.finite_length_axis / 2) - return positive_radius * inside_radius * inside_height - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
- """ - coord_min = [c - self.radius_max for c in self.center] - coord_max = [c + self.radius_max for c in self.center] - coord_min[self.axis] = self.center[self.axis] - self.length_axis / 2.0 - coord_max[self.axis] = self.center[self.axis] + self.length_axis / 2.0 - return (tuple(coord_min), tuple(coord_max)) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - coord_min = max(self.bounds[0][self.axis], bounds[0][self.axis]) - coord_max = min(self.bounds[1][self.axis], bounds[1][self.axis]) - - length = coord_max - coord_min - - volume = np.pi * self.radius_max**2 * length - - # a very loose upper bound on how much of the cylinder is in bounds - for axis in range(3): - if axis != self.axis: - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - volume *= 0.5 - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 0 - - coord_min = self.bounds[0][self.axis] - coord_max = self.bounds[1][self.axis] - - if coord_min < bounds[0][self.axis]: - coord_min = bounds[0][self.axis] - else: - area += np.pi * self.radius_max**2 - - if coord_max > bounds[1][self.axis]: - coord_max = bounds[1][self.axis] - else: - area += np.pi * self.radius_max**2 - - length = coord_max - coord_min - - area += 2.0 * np.pi * self.radius_max * length - - # a very loose upper bound on how much of the cylinder is in bounds - for axis in range(3): - if axis != self.axis: - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - area *= 0.5 - - return area - - @cached_property - def radius_bottom(self) -> float: - """radius of bottom""" - return self._radius_z(self.center_axis - self.finite_length_axis / 2) - - @cached_property - def radius_top(self) -> float: - """radius of bottom""" - return self._radius_z(self.center_axis + self.finite_length_axis / 2) - - @cached_property - def radius_max(self) -> float: - 
"""max(radius of top, radius of bottom)""" - return max(self.radius_bottom, self.radius_top) - - @cached_property - def radius_min(self) -> float: - """min(radius of top, radius of bottom). It can be negative for a large - sidewall angle. - """ - return min(self.radius_bottom, self.radius_top) - - def _radius_z(self, z: float) -> float: - """Compute the radius of the cross section at the position z. - - Parameters - ---------- - z : float - Position along the axis normal to slab - """ - if isclose(self.sidewall_angle, 0): - return self.radius - - radius_middle = self.radius - if self.reference_plane == "top": - radius_middle += self.finite_length_axis / 2 * self._tanq - elif self.reference_plane == "bottom": - radius_middle -= self.finite_length_axis / 2 * self._tanq - - return radius_middle - (z - self.center_axis) * self._tanq - - def _local_to_global_side_cross_section(self, coords: list[float], axis: int) -> list[float]: - """Map a point (x,y) from local to global coordinate system in the - side cross section. - - The definition of the local: y=0 lies at the base if ``sidewall_angle>=0``, - and at the top if ``sidewall_angle<0``; x=0 aligns with the corresponding - ``self.center``. In both cases, y-axis is pointing towards the narrowing - direction of cylinder. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0, 1, 2). - coords : list[float, float] - The value in the planar coordinate. - - Returns - ------- - Tuple[float, float] - The point in the global coordinate for plotting `_intersection_side`. 
- - """ - - # For negative sidewall angle, quantities along axis direction usually needs a flipped sign - axis_sign = 1 - if self.sidewall_angle < 0: - axis_sign = -1 +# marked as migrated to _common +from __future__ import annotations - lx_offset, ly_offset = self._order_by_axis( - plane_val=coords[0], - axis_val=axis_sign * (-self.finite_length_axis / 2 + coords[1]), - axis=axis, - ) - _, (x_center, y_center) = self.pop_axis(self.center, axis=axis) - return [x_center + lx_offset, y_center + ly_offset] +from tidy3d._common.components.geometry.primitives import ( + _DEFAULT_EDGE_FRACTION, + _MAX_ICOSPHERE_SUBDIVISIONS, + _N_PTS_CYLINDER_POLYSLAB, + _N_SAMPLE_CURVE_SHAPELY, + _N_SHAPELY_QUAD_SEGS_VISUALIZATION, + UNIT_SPHERE, + Cylinder, + Sphere, + _base_icosahedron, + discretization_wavelength, +) diff --git a/tidy3d/components/geometry/triangulation.py b/tidy3d/components/geometry/triangulation.py index 34624c86ea..96debe35cf 100644 --- a/tidy3d/components/geometry/triangulation.py +++ b/tidy3d/components/geometry/triangulation.py @@ -1,185 +1,14 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import numpy as np -import shapely - -from tidy3d.components.types import ArrayFloat1D -from tidy3d.exceptions import Tidy3dError - -if TYPE_CHECKING: - from tidy3d.components.types import ArrayFloat2D - - -@dataclass -class Vertex: - """Simple data class to hold triangulation data structures. - - Parameters - ---------- - coordinate: ArrayFloat1D - Vertex coordinate. - index : int - Vertex index in the original polygon. - convexity : float = 0.0 - Value representing the convexity (> 0) or concavity (< 0) of the vertex in the polygon. - is_ear : bool = False - Flag indicating whether this is an ear of the polygon. 
- """ - - coordinate: ArrayFloat1D - - index: int - - convexity: float - - is_ear: bool - - -def update_convexity(vertices: list[Vertex], i: int) -> int: - """Update the convexity of a vertex in a polygon. - - Parameters - ---------- - vertices : list[Vertex] - Vertices of the polygon. - i : int - Index of the vertex to be updated. - - Returns - ------- - int - Value indicating vertex convexity change w.r.t. 0. See note below. - - Note - ---- - Besides updating the vertex, this function returns a value indicating whether the updated vertex - convexity changed to or from 0 (0 convexity means the vertex is collinear with its neighbors). - If the convexity changes from zero to non-zero, return -1. If it changes from non-zero to zero, - return +1. Return 0 in any other case. This allows the main triangulation loop to keep track of - the total number of collinear vertices in the polygon. +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.triangulation`.""" - """ - result = -1 if vertices[i].convexity == 0.0 else 0 - j = (i + 1) % len(vertices) - vertices[i].convexity = np.linalg.det( - [ - vertices[i].coordinate - vertices[i - 1].coordinate, - vertices[j].coordinate - vertices[i].coordinate, - ] - ) - if vertices[i].convexity == 0.0: - result += 1 - return result +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def is_inside( - vertex: ArrayFloat1D, triangle: tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] -) -> bool: - """Check if a vertex is inside a triangle. - - Parameters - ---------- - vertex : ArrayFloat1D - Vertex coordinates. - triangle : tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] - Vertices of the triangle in CCW order. - - Returns - ------- - bool: - Flag indicating if the vertex is inside the triangle. 
- """ - return all( - np.linalg.det([triangle[i] - triangle[i - 1], vertex - triangle[i - 1]]) > 0 - for i in range(3) - ) - - -def update_ear_flag(vertices: list[Vertex], i: int) -> None: - """Update the ear flag of a vertex in a polygon. - - Parameters - ---------- - vertices : list[Vertex] - Vertices of the polygon. - i : int - Index of the vertex to be updated. - """ - h = (i - 1) % len(vertices) - j = (i + 1) % len(vertices) - triangle = (vertices[h].coordinate, vertices[i].coordinate, vertices[j].coordinate) - vertices[i].is_ear = vertices[i].convexity > 0 and not any( - is_inside(v.coordinate, triangle) - for k, v in enumerate(vertices) - if not (v.convexity > 0 or k == h or k == i or k == j) - ) - - -# TODO: This is an inefficient algorithm that runs in O(n^2). We should use something -# better, and probably as a compiled extension. -def triangulate(vertices: ArrayFloat2D) -> list[tuple[int, int, int]]: - """Triangulate a simple polygon. - - Parameters - ---------- - vertices : ArrayFloat2D - Vertices of the polygon. - - Returns - ------- - list[tuple[int, int, int]] - List of indices of the vertices of the triangles. - """ - is_ccw = shapely.LinearRing(vertices).is_ccw - - # Initialize vertices as non-collinear because we will update the actual value below and count - # the number of collinear vertices. - vertices = [Vertex(v, i, -1.0, False) for i, v in enumerate(vertices)] - if not is_ccw: - vertices.reverse() - - collinears = 0 - for i in range(len(vertices)): - collinears += update_convexity(vertices, i) - - for i in range(len(vertices)): - update_ear_flag(vertices, i) - - triangles = [] - - ear_found = True - while len(vertices) > 3: - if not ear_found: - raise Tidy3dError( - "Impossible to triangulate polygon. Verify that the polygon is valid." 
- ) - ear_found = False - i = 0 - while i < len(vertices): - if vertices[i].is_ear: - removed = vertices.pop(i) - h = (i - 1) % len(vertices) - j = i % len(vertices) - collinears += update_convexity(vertices, h) - collinears += update_convexity(vertices, j) - if collinears == len(vertices): - # Undo removal because only collinear vertices remain - vertices.insert(i, removed) - collinears += update_convexity(vertices, (i - 1) % len(vertices)) - collinears += update_convexity(vertices, (i + 1) % len(vertices)) - i += 1 - else: - ear_found = True - triangles.append((vertices[h].index, removed.index, vertices[j].index)) - update_ear_flag(vertices, h) - update_ear_flag(vertices, j) - if len(vertices) == 3: - break - else: - i += 1 - - triangles.append(tuple(v.index for v in vertices)) - return triangles +from tidy3d._common.components.geometry.triangulation import ( + Vertex, + is_inside, + triangulate, + update_convexity, + update_ear_flag, +) diff --git a/tidy3d/components/geometry/utils.py b/tidy3d/components/geometry/utils.py index 244585d55e..ef97e26dca 100644 --- a/tidy3d/components/geometry/utils.py +++ b/tidy3d/components/geometry/utils.py @@ -1,491 +1,47 @@ -"""Utilities for geometry manipulation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.utils`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations -from collections import defaultdict -from enum import Enum from math import isclose -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING import numpy as np -import shapely -from pydantic import Field, NonNegativeInt -from shapely.geometry import ( - Polygon, -) -from shapely.geometry.base import ( - BaseMultipartGeometry, -) -from tidy3d.components.autograd.utils import get_static -from tidy3d.components.base import Tidy3dBaseModel +from tidy3d._common.components.geometry.utils import ( + 
GeometryType, + SnapBehavior, + SnapLocation, + SnappingSpec, # noqa: TC + flatten_groups, + flatten_shapely_geometries, + from_shapely, + get_closest_value, + merging_geometries_on_plane, + traverse_geometries, + validate_no_transformed_polyslabs, + vertices_from_shapely, +) from tidy3d.components.geometry.base import Box -from tidy3d.components.types import Shapely from tidy3d.constants import fp_eps -from tidy3d.exceptions import SetupError, Tidy3dError - -from . import base, mesh, polyslab, primitives +from tidy3d.exceptions import SetupError if TYPE_CHECKING: - from collections.abc import Iterable + from typing import Optional from numpy.typing import ArrayLike + from pydantic import NonNegativeInt from tidy3d.components.grid.grid import Grid - from tidy3d.components.types import ( - ArrayFloat2D, - Axis, + from tidy3d.components.types.base import ( Bound, Coordinate, Direction, - MatrixReal4x4, - PlanePosition, - ) - -GeometryType = Union[ - base.Box, - base.Transformed, - base.ClipOperation, - base.GeometryGroup, - primitives.Sphere, - primitives.Cylinder, - polyslab.PolySlab, - polyslab.ComplexPolySlabBase, - mesh.TriangleMesh, -] - - -def flatten_shapely_geometries( - geoms: Union[Shapely, Iterable[Shapely]], keep_types: tuple[type, ...] = (Polygon,) -) -> list[Shapely]: - """ - Flatten nested geometries into a flat list, while only keeping the specified types. - - Recursively extracts and returns non-empty geometries of the given types from input geometries, - expanding any GeometryCollections or Multi* types. - - Parameters - ---------- - geoms : Union[Shapely, Iterable[Shapely]] - Input geometries to flatten. - - keep_types : tuple[type, ...] - Geometry types to keep (e.g., (Polygon, LineString)). Default is - (Polygon). - - Returns - ------- - list[Shapely] - Flat list of non-empty geometries matching the specified types. 
- """ - # Handle single Shapely object by wrapping it in a list - if isinstance(geoms, Shapely): - geoms = [geoms] - - flat = [] - for geom in geoms: - if geom.is_empty: - continue - if isinstance(geom, keep_types): - flat.append(geom) - elif isinstance(geom, BaseMultipartGeometry): - flat.extend(flatten_shapely_geometries(geom.geoms, keep_types)) - return flat - - -def merging_geometries_on_plane( - geometries: list[GeometryType], - plane: Box, - property_list: list[Any], - interior_disjoint_geometries: bool = False, - cleanup: bool = True, - quad_segs: Optional[int] = None, -) -> list[tuple[Any, Shapely]]: - """Compute list of shapes on plane. Overlaps are removed or merged depending on - provided property_list. - - Parameters - ---------- - geometries : list[GeometryType] - List of structures to filter on the plane. - plane : Box - Plane specification. - property_list : List = None - Property value for each structure. - interior_disjoint_geometries: bool = False - If ``True``, geometries of different properties on the plane must not be overlapping. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - list[tuple[Any, Shapely]] - List of shapes and their property value on the plane after merging. - """ - - if len(geometries) != len(property_list): - raise SetupError( - "Number of provided property values is not equal to the number of geometries." 
- ) - - shapes = [] - for geo, prop in zip(geometries, property_list): - # get list of Shapely shapes that intersect at the plane - shapes_plane = plane.intersections_with(geo, cleanup=cleanup, quad_segs=quad_segs) - - # Append each of them and their property information to the list of shapes - for shape in shapes_plane: - shapes.append((prop, shape, shape.bounds)) - - if interior_disjoint_geometries: - # No need to consider overlapping. We simply group shapes by property, and union_all - # shapes of the same property. - shapes_by_prop = defaultdict(list) - for prop, shape, _ in shapes: - shapes_by_prop[prop].append(shape) - # union shapes of same property - results = [] - for prop, shapes in shapes_by_prop.items(): - unionized = shapely.union_all(shapes).buffer(0).normalize() - if not unionized.is_empty: - results.append((prop, unionized)) - return results - - background_shapes = [] - for prop, shape, bounds in shapes: - minx, miny, maxx, maxy = bounds - - # loop through background_shapes (note: all background are non-intersecting or merged) - for index, (_prop, _shape, _bounds) in enumerate(background_shapes): - _minx, _miny, _maxx, _maxy = _bounds - - # do a bounding box check to see if any intersection to do anything about - if minx > _maxx or _minx > maxx or miny > _maxy or _miny > maxy: - continue - - # look more closely to see if intersected. 
- if shape.disjoint(_shape): - continue - - # different prop, remove intersection from background shape - if prop != _prop: - diff_shape = (_shape - shape).buffer(0).normalize() - # mark background shape for removal if nothing left - if diff_shape.is_empty or len(diff_shape.bounds) == 0: - background_shapes[index] = None - background_shapes[index] = (_prop, diff_shape, diff_shape.bounds) - # same prop, unionize shapes and mark background shape for removal - else: - shape = (shape | _shape).buffer(0).normalize() - background_shapes[index] = None - - # after doing this with all background shapes, add this shape to the background - background_shapes.append((prop, shape, shape.bounds)) - - # remove any existing background shapes that have been marked as 'None' - background_shapes = [b for b in background_shapes if b is not None] - - # filter out any remaining None or empty shapes (shapes with area completely removed) - return [(prop, shape) for (prop, shape, _) in background_shapes if shape] - - -def flatten_groups( - *geometries: GeometryType, - flatten_nonunion_type: bool = False, - flatten_transformed: bool = False, - transform: Optional[MatrixReal4x4] = None, -) -> GeometryType: - """Iterates over all geometries, flattening groups and unions. - - Parameters - ---------- - *geometries : GeometryType - Geometries to flatten. - flatten_nonunion_type : bool = False - If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten - all clip operations. - flatten_transformed : bool = False - If ``True``, ``Transformed`` groups are flattened into individual transformed geometries. - transform : Optional[MatrixReal4x4] - Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``. - - Yields - ------ - GeometryType - Geometries after flattening groups and unions. 
- """ - for geometry in geometries: - if isinstance(geometry, base.GeometryGroup): - yield from flatten_groups( - *geometry.geometries, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=transform, - ) - elif isinstance(geometry, base.ClipOperation) and ( - flatten_nonunion_type or geometry.operation == "union" - ): - yield from flatten_groups( - geometry.geometry_a, - geometry.geometry_b, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=transform, - ) - elif flatten_transformed and isinstance(geometry, base.Transformed): - new_transform = geometry.transform - if transform is not None: - new_transform = np.matmul(transform, new_transform) - yield from flatten_groups( - geometry.geometry, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=new_transform, - ) - elif flatten_transformed and transform is not None: - yield base.Transformed(geometry=geometry, transform=transform) - else: - yield geometry - - -def traverse_geometries(geometry: GeometryType) -> GeometryType: - """Iterator over all geometries within the given geometry. - - Iterates over groups and clip operations within the given geometry, yielding each one. - - Parameters - ---------- - geometry: GeometryType - Base geometry to start iteration. - - Returns - ------- - :class:`Geometry` - Geometries within the base geometry. 
- """ - if isinstance(geometry, base.GeometryGroup): - for g in geometry.geometries: - yield from traverse_geometries(g) - elif isinstance(geometry, base.ClipOperation): - yield from traverse_geometries(geometry.geometry_a) - yield from traverse_geometries(geometry.geometry_b) - yield geometry - - -def from_shapely( - shape: Shapely, - axis: Axis, - slab_bounds: tuple[float, float], - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", -) -> base.Geometry: - """Convert a shapely primitive into a geometry instance by extrusion. - - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to be converted. It must be a linear ring, a polygon or a collection - of any of those. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - dilation : float - Dilation of the polygon in the base by shifting each edge along its normal outwards - direction by a distance; a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometry extruded from the 2D data. 
- """ - if shape.geom_type == "LinearRing": - if sidewall_angle == 0: - return polyslab.PolySlab( - vertices=shape.coords[:-1], - axis=axis, - slab_bounds=slab_bounds, - dilation=dilation, - reference_plane=reference_plane, - ) - group = polyslab.ComplexPolySlabBase( - vertices=shape.coords[:-1], - axis=axis, - slab_bounds=slab_bounds, - dilation=dilation, - sidewall_angle=sidewall_angle, - reference_plane=reference_plane, - ).geometry_group - return group.geometries[0] if len(group.geometries) == 1 else group - - if shape.geom_type == "Polygon": - exterior = from_shapely( - shape.exterior, axis, slab_bounds, dilation, sidewall_angle, reference_plane - ) - interior = [ - from_shapely(hole, axis, slab_bounds, -dilation, -sidewall_angle, reference_plane) - for hole in shape.interiors - ] - if len(interior) == 0: - return exterior - interior = interior[0] if len(interior) == 1 else base.GeometryGroup(geometries=interior) - return base.ClipOperation(operation="difference", geometry_a=exterior, geometry_b=interior) - - if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: - return base.GeometryGroup( - geometries=[ - from_shapely(geo, axis, slab_bounds, dilation, sidewall_angle, reference_plane) - for geo in shape.geoms - ] - ) - - raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") - - -def vertices_from_shapely(shape: Shapely) -> ArrayFloat2D: - """Iterate over the polygons of a shapely geometry returning the vertices. - - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to have its vertices extracted. It must be a linear ring, a polygon or a - collection of any of those. - - Returns - ------- - list[tuple[ArrayFloat2D]] - List of tuples ``(exterior, *interiors)``. 
- """ - if shape.geom_type == "LinearRing": - return [(shape.coords[:-1],)] - if shape.geom_type == "Polygon": - return [(shape.exterior.coords[:-1], *tuple(hole.coords[:-1] for hole in shape.interiors))] - if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: - return sum(vertices_from_shapely(geo) for geo in shape.geoms) - - raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") - - -def validate_no_transformed_polyslabs( - geometry: GeometryType, transform: MatrixReal4x4 = None -) -> None: - """Prevents the creation of slanted polyslabs rotated out of plane.""" - if transform is None: - transform = np.eye(4) - if isinstance(geometry, polyslab.PolySlab): - # sidewall_angle may be autograd-traced; unbox for the check only - if not ( - isclose(get_static(geometry.sidewall_angle), 0) - or base.Transformed.preserves_axis(transform, geometry.axis) - ): - raise Tidy3dError( - "Slanted PolySlabs are not allowed to be rotated out of the slab plane." - ) - elif isinstance(geometry, base.Transformed): - transform = np.dot(transform, geometry.transform) - validate_no_transformed_polyslabs(geometry.geometry, transform) - elif isinstance(geometry, base.GeometryGroup): - for geo in geometry.geometries: - validate_no_transformed_polyslabs(geo, transform) - elif isinstance(geometry, base.ClipOperation): - validate_no_transformed_polyslabs(geometry.geometry_a, transform) - validate_no_transformed_polyslabs(geometry.geometry_b, transform) - - -class SnapLocation(Enum): - """Describes different methods for defining the snapping locations.""" - - Boundary = 1 - """ - Choose the boundaries of Yee cells. - """ - Center = 2 - """ - Choose the center of Yee cells. - """ - - -class SnapBehavior(Enum): - """Describes different methods for snapping intervals, which are defined by two endpoints.""" - - Closest = 1 - """ - Snaps the interval's endpoints to the closest grid point. 
- """ - Expand = 2 - """ - Snaps the interval's endpoints to the closest grid points, - while guaranteeing that the snapping location will never move endpoints inwards. - """ - Contract = 3 - """ - Snaps the interval's endpoints to the closest grid points, - while guaranteeing that the snapping location will never move endpoints outwards. - """ - StrictExpand = 4 - """ - Same as Expand, but will always move endpoints outwards, even if already coincident with grid. - """ - StrictContract = 5 - """ - Same as Contract, but will always move endpoints inwards, even if already coincident with grid. - """ - Off = 6 - """ - Do not use snapping. - """ - - -class SnappingSpec(Tidy3dBaseModel): - """Specifies how to apply grid snapping along each dimension.""" - - location: tuple[SnapLocation, SnapLocation, SnapLocation] = Field( - title="Location", - description="Describes which positions in the grid will be considered for snapping.", ) - behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = Field( - title="Behavior", - description="Describes how snapping positions will be chosen.", - ) - - margin: Optional[tuple[NonNegativeInt, NonNegativeInt, NonNegativeInt]] = Field( - (0, 0, 0), - title="Margin", - description="Number of additional grid points to consider when expanding or contracting " - "during snapping. Only applies when ``SnapBehavior`` is ``Expand`` or ``Contract``.", - ) - - -def get_closest_value(test: float, coords: ArrayLike, upper_bound_idx: int) -> float: - """Helper to choose the closest value in an array to a given test value, - using the index of the upper bound. The ``upper_bound_idx`` corresponds to the first value in - the ``coords`` array which is greater than or equal to the test value. 
- """ - # Handle corner cases first - if upper_bound_idx == 0: - return coords[upper_bound_idx] - if upper_bound_idx == len(coords): - return coords[upper_bound_idx - 1] - # General case - lower_bound = coords[upper_bound_idx - 1] - upper_bound = coords[upper_bound_idx] - dlower = abs(test - lower_bound) - dupper = abs(test - upper_bound) - return lower_bound if dlower < dupper else upper_bound - def snap_box_to_grid(grid: Grid, box: Box, snap_spec: SnappingSpec, rtol: float = fp_eps) -> Box: """Snaps a :class:`.Box` to the grid, so that the boundaries of the box are aligned with grid centers or boundaries. diff --git a/tidy3d/components/lumped_element.py b/tidy3d/components/lumped_element.py index 44fea6aee1..53cd053e10 100644 --- a/tidy3d/components/lumped_element.py +++ b/tidy3d/components/lumped_element.py @@ -96,7 +96,8 @@ def to_snapping_points(self) -> list[CoordinateOptional]: @abstractmethod def to_geometry(self) -> Geometry: - """Converts the :class:`.LumpedElement` object to a :class:`.Geometry`.""" + """Converts the :class:`.LumpedElement` object to a + :class:`~tidy3d.Geometry`.""" @abstractmethod def to_structure(self, grid: Optional[Grid] = None) -> Structure: @@ -441,7 +442,8 @@ def to_structure(self, grid: Optional[Grid] = None) -> Structure: ) def to_geometry(self, grid: Optional[Grid] = None) -> ClipOperation: - """Converts the :class:`CoaxialLumpedResistor` object to a :class:`Geometry`.""" + """Converts the :class:`CoaxialLumpedResistor` object to a + :class:`~tidy3d.Geometry`.""" rout = self.outer_diameter / 2 rin = self.inner_diameter / 2 disk_out = Cylinder(axis=self.normal_axis, radius=rout, length=0, center=self.center) diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index c946b7ee90..136514ed12 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -97,7 +97,7 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FreqMonitor(Monitor, ABC): - 
""":class:`Monitor` that records data in the frequency-domain.""" + """:class:`~tidy3d.Monitor` that records data in the frequency-domain.""" freqs: FreqArray = Field( title="Frequencies", @@ -147,7 +147,7 @@ def frequency_range(self) -> FreqBound: class TimeMonitor(Monitor, ABC): - """:class:`Monitor` that records data in the time-domain.""" + """:class:`~tidy3d.Monitor` that records data in the time-domain.""" start: NonNegativeFloat = Field( 0.0, @@ -246,7 +246,7 @@ def num_steps(self, tmesh: ArrayFloat1D) -> int: class AbstractFieldMonitor(Monitor, ABC): - """:class:`Monitor` that records electromagnetic field data as a function of x,y,z.""" + """:class:`~tidy3d.Monitor` that records electromagnetic field data as a function of x,y,z.""" fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], @@ -334,7 +334,7 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class PlanarMonitor(Monitor, ABC): - """:class:`Monitor` that has a planar geometry.""" + """:class:`~tidy3d.Monitor` that has a planar geometry.""" _plane_validator = assert_plane() @@ -345,7 +345,7 @@ def normal_axis(self) -> Axis: class AbstractOverlapMonitor(PlanarMonitor, FreqMonitor): - """:class:`Monitor` that projects fields onto a specified basis and stores overlap amplitudes. + """:class:`~tidy3d.Monitor` that projects fields onto a specified basis and stores overlap amplitudes. This base is shared by ModeMonitor and Gaussian-overlap monitors. 
""" @@ -423,7 +423,7 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractModeMonitor(AbstractOverlapMonitor): - """:class:`Monitor` that records mode-related data.""" + """:class:`~tidy3d.Monitor` that records mode-related data.""" _draw_overlap_arrows: bool = False # AbstractModeMonitor.plot() draws its own arrows @@ -534,7 +534,7 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractGaussianOverlapMonitor(AbstractOverlapMonitor): - """:class:`Monitor` that records amplitudes from decomposition onto a Gaussian-like beam. + """:class:`~tidy3d.Monitor` that records amplitudes from decomposition onto a Gaussian-like beam. Common fields and behavior shared by GaussianOverlapMonitor and AstigmaticGaussianOverlapMonitor. @@ -580,7 +580,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class GaussianOverlapMonitor(AbstractGaussianOverlapMonitor): - """:class:`Monitor` that records amplitudes from decomposition onto a Gaussian beam. + """:class:`~tidy3d.Monitor` that records amplitudes from decomposition onto a Gaussian beam. Example ------- @@ -617,7 +617,7 @@ class GaussianOverlapMonitor(AbstractGaussianOverlapMonitor): class AstigmaticGaussianOverlapMonitor(AbstractGaussianOverlapMonitor): - """:class:`Monitor` that records amplitudes from decomposition onto an astigmatic Gaussian beam. + """:class:`~tidy3d.Monitor` that records amplitudes from decomposition onto an astigmatic Gaussian beam. The simple astigmatic Gaussian distribution allows both an elliptical intensity profile and different waist locations for the two principal axes @@ -664,7 +664,7 @@ class AstigmaticGaussianOverlapMonitor(AbstractGaussianOverlapMonitor): class FieldMonitor(AbstractFieldMonitor, FreqMonitor): - """:class:`Monitor` that records electromagnetic fields in the frequency domain. + """:class:`~tidy3d.Monitor` that records electromagnetic fields in the frequency domain. 
Notes ----- @@ -704,7 +704,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FieldTimeMonitor(AbstractFieldMonitor, TimeMonitor): - """:class:`Monitor` that records electromagnetic fields in the time domain. + """:class:`~tidy3d.Monitor` that records electromagnetic fields in the time domain. Notes ----- @@ -777,7 +777,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractMediumPropertyMonitor(FreqMonitor, ABC): - """:class:`Monitor` that records material properties in the frequency domain.""" + """:class:`~tidy3d.Monitor` that records material properties in the frequency domain.""" colocate: Literal[False] = Field( False, @@ -802,7 +802,7 @@ class AbstractMediumPropertyMonitor(FreqMonitor, ABC): class MediumMonitor(AbstractMediumPropertyMonitor): - """:class:`Monitor` that records the diagonal components of the complex-valued relative + """:class:`~tidy3d.Monitor` that records the diagonal components of the complex-valued relative permittivity and permeability tensor in the frequency domain. The recorded data has the same shape as a :class:`.FieldMonitor` of the same geometry: the permittivity and permeability values are saved at the Yee grid locations, and can be interpolated to any point inside the monitor. @@ -831,7 +831,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class PermittivityMonitor(AbstractMediumPropertyMonitor): - """:class:`Monitor` that records the diagonal components of the complex-valued relative + """:class:`~tidy3d.Monitor` that records the diagonal components of the complex-valued relative permittivity tensor in the frequency domain. The recorded data has the same shape as a :class:`.FieldMonitor` of the same geometry: the permittivity values are saved at the Yee grid locations, and can be interpolated to any point inside the monitor. 
@@ -923,11 +923,11 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractFluxMonitor(SurfaceIntegrationMonitor, ABC): - """:class:`Monitor` that records flux during the solver run.""" + """:class:`~tidy3d.Monitor` that records flux during the solver run.""" class FluxMonitor(AbstractFluxMonitor, FreqMonitor): - """:class:`Monitor` that records power flux in the frequency domain. + """:class:`~tidy3d.Monitor` that records power flux in the frequency domain. Notes ----- @@ -960,7 +960,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FluxTimeMonitor(AbstractFluxMonitor, TimeMonitor): - """:class:`Monitor` that records power flux in the time domain. + """:class:`~tidy3d.Monitor` that records power flux in the time domain. Notes ----- @@ -989,7 +989,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class ModeMonitor(AbstractModeMonitor): - """:class:`Monitor` that records amplitudes from modal decomposition of fields on plane. + """:class:`~tidy3d.Monitor` that records amplitudes from modal decomposition of fields on plane. Notes ------ @@ -1052,7 +1052,7 @@ def storage_size(self, num_cells: int, tmesh: int) -> int: class ModeSolverMonitor(AbstractModeMonitor): - """:class:`Monitor` that stores the mode field profiles returned by the mode solver in the + """:class:`~tidy3d.Monitor` that stores the mode field profiles returned by the mode solver in the monitor plane. Example @@ -1151,7 +1151,7 @@ def is_plane(cls, val: FieldMonitor) -> FieldMonitor: class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): - """:class:`Monitor` that samples electromagnetic near fields in the frequency domain + """:class:`~tidy3d.Monitor` that samples electromagnetic near fields in the frequency domain and projects them to a given set of observation points. 
""" @@ -1326,7 +1326,7 @@ def window_function( class FieldProjectionAngleMonitor(AbstractFieldProjectionMonitor): - """:class:`Monitor` that samples electromagnetic near fields in the frequency domain + """:class:`~tidy3d.Monitor` that samples electromagnetic near fields in the frequency domain and projects them at given observation angles. Notes @@ -1460,7 +1460,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class DirectivityMonitor(MicrowaveBaseModel, FieldProjectionAngleMonitor, FluxMonitor): """ - :class:`Monitor` that records the radiation characteristics of antennas in the frequency domain + :class:`~tidy3d.Monitor` that records the radiation characteristics of antennas in the frequency domain at specified observation angles. Note @@ -1508,7 +1508,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): - """:class:`Monitor` that samples electromagnetic near fields in the frequency domain + """:class:`~tidy3d.Monitor` that samples electromagnetic near fields in the frequency domain and projects them on a Cartesian observation plane. Notes @@ -1650,7 +1650,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): - """:class:`Monitor` that samples electromagnetic near fields in the frequency domain + """:class:`~tidy3d.Monitor` that samples electromagnetic near fields in the frequency domain and projects them on an observation plane defined in k-space. Notes @@ -1765,7 +1765,7 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class DiffractionMonitor(PlanarMonitor, FreqMonitor): - """:class:`Monitor` that uses a 2D Fourier transform to compute the + """:class:`~tidy3d.Monitor` that uses a 2D Fourier transform to compute the diffraction amplitudes and efficiency for allowed diffraction orders. 
Note diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index c46597fade..9061669834 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -21,6 +21,7 @@ from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.types.base import discriminated_union +from tidy3d.components.viz import plot_params_min_grid_size from tidy3d.constants import C_0, SECOND, fp_eps, inf from tidy3d.exceptions import ( AdjointError, @@ -127,7 +128,6 @@ equal_aspect, plot_params_abc, plot_params_bloch, - plot_params_min_grid_size, plot_params_override_structures, plot_params_pec, plot_params_pmc, diff --git a/tidy3d/components/source/base.py b/tidy3d/components/source/base.py index 6ffe49bd57..f080f49ad7 100644 --- a/tidy3d/components/source/base.py +++ b/tidy3d/components/source/base.py @@ -1,135 +1,10 @@ -"""Defines an abstract base for electromagnetic sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.source.base`.""" -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Any +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -from pydantic import Field, field_validator +# marked as migrated to _common +from __future__ import annotations -from tidy3d.components.base import cached_property -from tidy3d.components.base_sim.source import AbstractSource -from tidy3d.components.geometry.base import Box -from tidy3d.components.types import TYPE_TAG_STR -from tidy3d.components.validators import _assert_min_freq, _warn_unsupported_traced_argument -from tidy3d.components.viz import ( - ARROW_ALPHA, - ARROW_COLOR_POLARIZATION, - ARROW_COLOR_SOURCE, - plot_params_source, +from tidy3d._common.components.source.base import ( + Source, ) - -from .time import SourceTimeType - -if TYPE_CHECKING: - from typing import Optional - - from tidy3d.components.types import Ax - from tidy3d.components.viz import PlotParams - - -class 
Source(Box, AbstractSource, ABC): - """Abstract base class for all sources.""" - - source_time: SourceTimeType = Field( - title="Source Time", - description="Specification of the source time-dependence.", - discriminator=TYPE_TAG_STR, - ) - - @cached_property - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Source object.""" - return plot_params_source - - @cached_property - def geometry(self) -> Box: - """:class:`Box` representation of source.""" - - return Box(center=self.center, size=self.size) - - @cached_property - def _injection_axis(self) -> None: - """Injection axis of the source.""" - return - - @cached_property - def _dir_vector(self) -> None: - """Returns a vector indicating the source direction for arrow plotting, if not None.""" - return None - - @cached_property - def _pol_vector(self) -> None: - """Returns a vector indicating the source polarization for arrow plotting, if not None.""" - return None - - _warn_traced_center = _warn_unsupported_traced_argument("center") - _warn_traced_size = _warn_unsupported_traced_argument("size") - - @field_validator("source_time") - @classmethod - def _freqs_lower_bound(cls, val: SourceTimeType) -> SourceTimeType: - """Raise validation error if central frequency is too low.""" - _assert_min_freq(val._freq0_sigma_centroid, msg_start="'source_time.freq0'") - return val - - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot this source.""" - - kwargs_arrow_base = patch_kwargs.pop("arrow_base", None) - - # call the `Source.plot()` function first. 
- ax = Box.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) - - kwargs_alpha = patch_kwargs.get("alpha") - arrow_alpha = ARROW_ALPHA if kwargs_alpha is None else kwargs_alpha - - # then add the arrow based on the propagation direction - if self._dir_vector is not None: - bend_radius = None - bend_axis = None - if hasattr(self, "mode_spec") and self.mode_spec.bend_radius is not None: - bend_radius = self.mode_spec.bend_radius - bend_axis = self._bend_axis - sign = 1 if self.direction == "+" else -1 - # Curvature has to be reversed because of ploting coordinates - if (self.size.index(0), bend_axis) in [(1, 2), (2, 0), (2, 1)]: - bend_radius *= -sign - else: - bend_radius *= sign - - ax = self._plot_arrow( - x=x, - y=y, - z=z, - ax=ax, - direction=self._dir_vector, - bend_radius=bend_radius, - bend_axis=bend_axis, - color=ARROW_COLOR_SOURCE, - alpha=arrow_alpha, - both_dirs=False, - arrow_base=kwargs_arrow_base, - ) - - if self._pol_vector is not None: - ax = self._plot_arrow( - x=x, - y=y, - z=z, - ax=ax, - direction=self._pol_vector, - color=ARROW_COLOR_POLARIZATION, - alpha=arrow_alpha, - both_dirs=False, - arrow_base=kwargs_arrow_base, - ) - - return ax diff --git a/tidy3d/components/source/time.py b/tidy3d/components/source/time.py index 5ff4ec868d..628cd35b8b 100644 --- a/tidy3d/components/source/time.py +++ b/tidy3d/components/source/time.py @@ -1,691 +1,21 @@ -"""Defines time dependencies of injected electromagnetic sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.source.time`.""" -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Union - -import numpy as np -from pydantic import Field, PositiveFloat, field_validator, model_validator -from pyroots import Brentq - -from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import TimeDataArray -from tidy3d.components.data.dataset import TimeDataset -from 
tidy3d.components.data.validators import validate_no_nans -from tidy3d.components.time import AbstractTimeDependence -from tidy3d.components.types import FreqBound -from tidy3d.components.validators import warn_if_dataset_none -from tidy3d.components.viz import add_ax_if_none -from tidy3d.constants import HERTZ -from tidy3d.exceptions import ValidationError -from tidy3d.log import log -from tidy3d.packaging import check_tidy3d_extras_licensed_feature, tidy3d_extras - -if TYPE_CHECKING: - from tidy3d.components.types import ArrayComplex1D, ArrayFloat1D, Ax, PlotVal - -# how many units of ``twidth`` from the ``offset`` until a gaussian pulse is considered "off" -END_TIME_FACTOR_GAUSSIAN = 10 - -# warn if source amplitude is too small at the endpoints of frequency range -WARN_SOURCE_AMPLITUDE = 0.1 -# used in Brentq -_ROOTS_TOL = 1e-10 -# Default sigma value in frequency_range -DEFAULT_SIGMA = 4.0 -# Offset in fwidth in finding frequency_range_sigma[1] to ensure the interval brackets the root -OFFSET_FWIDTH_FMAX = 100 - - -class SourceTime(AbstractTimeDependence): - """Base class describing the time dependence of a source.""" - - @add_ax_if_none - def plot_spectrum( - self, - times: ArrayFloat1D, - num_freqs: int = 101, - val: PlotVal = "real", - ax: Ax = None, - ) -> Ax: - """Plot the complex-valued amplitude of the source time-dependence. - Note: Only the real part of the time signal is used. - - Parameters - ---------- - times : np.ndarray - Array of evenly-spaced times (seconds) to evaluate source time-dependence at. - The spectrum is computed from this value and the source time frequency content. - To see source spectrum for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - num_freqs : int = 101 - Number of frequencies to plot within the SourceTime.frequency_range. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. 
- - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - - fmin, fmax = self.frequency_range_sigma() - return self.plot_spectrum_in_frequency_range( - times, fmin, fmax, num_freqs=num_freqs, val=val, ax=ax - ) - - @abstractmethod - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range within plus/minus ``num_fwidth * fwidth`` of the central frequency.""" - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - return self.frequency_range(num_fwidth=sigma) - - @cached_property - def _frequency_range_sigma_cached(self) -> FreqBound: - """Cached `frequency_range_sigma` for the default sigma value.""" - return self.frequency_range_sigma(sigma=DEFAULT_SIGMA) - - @abstractmethod - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - @cached_property - def _freq0(self) -> float: - """Central frequency. 
If not present in input parameters, returns `_freq0_sigma_centroid`.""" - return self._freq0_sigma_centroid - - @cached_property - def _freq0_sigma_centroid(self) -> float: - """Central of frequency range at 1-sigma drop from the peak amplitude.""" - return np.mean(self.frequency_range_sigma(sigma=1)) - - -class Pulse(SourceTime, ABC): - """A source time that ramps up with some ``fwidth`` and oscillates at ``freq0``.""" - - freq0: PositiveFloat = Field( - title="Central Frequency", - description="Central frequency of the pulse.", - json_schema_extra={"units": HERTZ}, - ) - fwidth: PositiveFloat = Field( - title="", - description="Standard deviation of the frequency content of the pulse.", - json_schema_extra={"units": HERTZ}, - ) - - offset: float = Field( - 5.0, - title="Offset", - description="Time delay of the maximum value of the " - "pulse in units of 1 / (``2pi * fwidth``).", - ge=2.5, - ) - - @cached_property - def _freq0(self) -> float: - """Central frequency.""" - return self.freq0 - - @property - def offset_time(self) -> float: - """Offset time in seconds.""" - return self.offset * self.twidth - - @property - def twidth(self) -> float: - """Width of pulse in seconds.""" - return 1.0 / (2 * np.pi * self.fwidth) - - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range within 5 standard deviations of the central frequency. - - Parameters - ---------- - num_fwidth : float = 4. - Frequency range defined as plus/minus ``num_fwidth * self.fwdith``. - - Returns - ------- - Tuple[float, float] - Minimum and maximum frequencies of the :class:`GaussianPulse` or :class:`ContinuousWave` - power. - """ - - freq_width_range = num_fwidth * self.fwidth - freq_min = max(0, self.freq0 - freq_width_range) - freq_max = self.freq0 + freq_width_range - return (freq_min, freq_max) - - -class GaussianPulse(Pulse): - """Source time dependence that describes a Gaussian pulse. 
- - Example - ------- - >>> pulse = GaussianPulse(freq0=200e12, fwidth=20e12) - """ - - remove_dc_component: bool = Field( - True, - title="Remove DC Component", - description="Whether to remove the DC component in the Gaussian pulse spectrum. " - "If ``True``, the Gaussian pulse is modified at low frequencies to zero out the " - "DC component, which is usually desirable so that the fields will decay. However, " - "for broadband simulations, it may be better to have non-vanishing source power " - "near zero frequency. Setting this to ``False`` results in an unmodified Gaussian " - "pulse spectrum which can have a nonzero DC component.", - ) - - @property - def peak_time(self) -> float: - """Peak time in seconds, defined by ``offset``.""" - return self.offset * self.twidth - - @property - def _peak_time_shift(self) -> float: - """In the case of DC removal, correction to offset_time so that ``offset`` indeed defines time delay - of pulse peak. - """ - if self.remove_dc_component and self.fwidth > self.freq0: - return self.twidth * np.sqrt(1 - self.freq0**2 / self.fwidth**2) - return 0 - - @property - def offset_time(self) -> float: - """Offset time in seconds. 
Note that in the case of DC removal, the maximal value of pulse can be shifted.""" - return self.peak_time + self._peak_time_shift - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - - omega0 = 2 * np.pi * self.freq0 - time_shifted = time - self.offset_time - - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * time) - amp = np.exp(-(time_shifted**2) / 2 / self.twidth**2) * self.amplitude - - pulse_amp = offset * oscillation * amp - - # subtract out DC component - if self.remove_dc_component: - pulse_amp = pulse_amp * (1j * omega0 + time_shifted / self.twidth**2) - # normalize by peak frequency instead of omega0, as for small omega0, omega0 approaches 0 faster - pulse_amp /= 2 * np.pi * self.peak_frequency - else: - # 1j to make it agree in large omega0 limit - pulse_amp = pulse_amp * 1j - - return pulse_amp - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - # TODO: decide if we should continue to return an end_time if the DC component remains - # if not self.remove_dc_component: - # return None - - end_time = self.offset_time + END_TIME_FACTOR_GAUSSIAN * self.twidth - - # for derivative Gaussian that contains two peaks, add time interval between them - if self.remove_dc_component and self.fwidth > self.freq0: - end_time += 2 * self._peak_time_shift - return end_time - - def amp_freq(self, freq: float) -> complex: - """Complex-valued source spectrum in frequency domain.""" - phase = np.exp(1j * self.phase + 1j * 2 * np.pi * (freq - self.freq0) * self.offset_time) - envelope = np.exp(-((freq - self.freq0) ** 2) / 2 / self.fwidth**2) - amp = 1j * self.amplitude / self.fwidth * phase * envelope - if not self.remove_dc_component: - return amp - - # derivative of Gaussian when DC is removed - return freq * amp / (2 * np.pi * self.peak_frequency) - - def _rel_amp_freq(self, freq: float) -> complex: - 
"""Complex-valued source spectrum in frequency domain normalized by peak amplitude.""" - return self.amp_freq(freq) / self._peak_freq_amp - - @property - def peak_frequency(self) -> float: - """Frequency at which the source time dependence has its peak amplitude in the frequency domain.""" - if not self.remove_dc_component: - return self.freq0 - return 0.5 * (self.freq0 + np.sqrt(self.freq0**2 + 4 * self.fwidth**2)) - - @property - def _peak_freq_amp(self) -> complex: - """Peak amplitude in frequency domain""" - return self.amp_freq(self.peak_frequency) - - @property - def _peak_time_amp(self) -> complex: - """Peak amplitude in time domain""" - return self.amp_time(self.peak_time) - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - if not self.remove_dc_component: - return self.frequency_range(num_fwidth=sigma) - - # With dc removed, we'll need to solve for the transcendental equation to find the frequency range - def equation_for_sigma_frequency(freq: float) -> float: - """computes A / A_p - exp(-sigma)""" - return np.abs(self._rel_amp_freq(freq)) - np.exp(-(sigma**2) / 2) - - logger = logging.getLogger("pyroots") - logger.setLevel(logging.CRITICAL) - root_scalar = Brentq(raise_on_fail=False, epsilon=_ROOTS_TOL) - fmin_data = root_scalar(equation_for_sigma_frequency, xa=0, xb=self.peak_frequency) - fmax_data = root_scalar( - equation_for_sigma_frequency, - xa=self.peak_frequency, - xb=self.peak_frequency - + self.fwidth - * ( - OFFSET_FWIDTH_FMAX + 2 * sigma**2 - ), # offset slightly to make sure that it flips sign - ) - fmin, fmax = fmin_data.x0, fmax_data.x0 - - # if unconverged, fall back to `frequency_range` - if not (fmin_data.converged and fmax_data.converged and fmax > fmin): - return self.frequency_range(num_fwidth=sigma) - - # converged - return fmin.item(), fmax.item() - - @property - def amp_complex(self) -> complex: - 
"""Grab the complex amplitude from a ``GaussianPulse``.""" - phase = np.exp(1j * self.phase) - return self.amplitude * phase - - @classmethod - def from_amp_complex(cls, amp: complex, **kwargs: Any) -> GaussianPulse: - """Set the complex amplitude of a ``GaussianPulse``. - - Parameters - ---------- - amp : complex - Complex-valued amplitude to set in the returned ``GaussianPulse``. - kwargs : dict - Keyword arguments passed to ``GaussianPulse()``, excluding ``amplitude`` & ``phase``. - """ - amplitude = abs(amp) - phase = np.angle(amp) - return cls(amplitude=amplitude, phase=phase, **kwargs) - - @staticmethod - def _minimum_source_bandwidth( - fmin: float, fmax: float, minimum_source_bandwidth: float - ) -> tuple[float, float]: - """Define a source bandwidth based on fmin and fmax, but enforce a minimum bandwidth.""" - if minimum_source_bandwidth <= 0: - raise ValidationError("'minimum_source_bandwidth' must be positive") - if minimum_source_bandwidth >= 1: - raise ValidationError("'minimum_source_bandwidth' must less than or equal to 1") - - f_difference = fmax - fmin - f_middle = 0.5 * (fmin + fmax) - - full_width = minimum_source_bandwidth * f_middle - if f_difference < full_width: - half_width = 0.5 * full_width - fmin = f_middle - half_width - fmax = f_middle + half_width - - return fmin, fmax - - @classmethod - def from_frequency_range( - cls, - fmin: PositiveFloat, - fmax: PositiveFloat, - minimum_source_bandwidth: Optional[PositiveFloat] = None, - **kwargs: Any, - ) -> GaussianPulse: - """Create a ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. - - Parameters - ---------- - fmin : float - Lower bound of frequency of interest. - fmax : float - Upper bound of frequency of interest. - kwargs : dict - Keyword arguments passed to ``GaussianPulse()``, excluding ``freq0`` & ``fwidth``. - - Returns - ------- - GaussianPulse - A ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. 
- """ - # validate that fmin and fmax must positive, and fmax > fmin - if fmin <= 0: - raise ValidationError("'fmin' must be positive.") - if fmax <= fmin: - raise ValidationError("'fmax' must be greater than 'fmin'.") - - if minimum_source_bandwidth is not None: - fmin, fmax = cls._minimum_source_bandwidth(fmin, fmax, minimum_source_bandwidth) - - # frequency range and center - freq_range = fmax - fmin - freq_center = (fmax + fmin) / 2.0 - - # If remove_dc_component=False, simply return the standard GaussianPulse parameters - if kwargs.get("remove_dc_component", True) is False: - return cls(freq0=freq_center, fwidth=freq_range / 2.0, **kwargs) - - # If remove_dc_component=True, the Gaussian pulse is distorted - kwargs.update({"remove_dc_component": True}) - log_ratio = np.log(fmax / fmin) - coeff = ((1 + log_ratio**2) ** 0.5 - 1) / 2.0 - freq0 = freq_center - coeff / log_ratio * freq_range - fwidth = freq_range / log_ratio * coeff**0.5 - pulse = cls(freq0=freq0, fwidth=fwidth, **kwargs) - if np.abs(pulse._rel_amp_freq(fmin)) < WARN_SOURCE_AMPLITUDE: - log.warning( - "Default source time profile is less accurate for the specified broadband frequency range. " - "For more accurate results, consider reducing the frequency range or using a 'BroadbandSource'.", - ) - return pulse - - -class ContinuousWave(Pulse): - """Source time dependence that ramps up to continuous oscillation - and holds until end of simulation. - - Note - ---- - Field decay will not occur, so the simulation will run for the full ``run_time``. - Also, source normalization of frequency-domain monitors is not meaningful. 
- - Example - ------- - >>> cw = ContinuousWave(freq0=200e12, fwidth=20e12) - """ - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - - twidth = 1.0 / (2 * np.pi * self.fwidth) - omega0 = 2 * np.pi * self.freq0 - time_shifted = time - self.offset_time - - const = 1.0 - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * time) - amp = 1 / (1 + np.exp(-time_shifted / twidth)) * self.amplitude - - return const * offset * oscillation * amp - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - return None - - -class CustomSourceTime(Pulse): - """Custom source time dependence consisting of a real or complex envelope - modulated at a central frequency, as shown below. - - Note - ---- - .. math:: - - amp\\_time(t) = amplitude \\cdot \\ - e^{i \\cdot phase - 2 \\pi i \\cdot freq0 \\cdot t} \\cdot \\ - envelope(t - offset / (2 \\pi \\cdot fwidth)) - - Note - ---- - Depending on the envelope, field decay may not occur. - If field decay does not occur, then the simulation will run for the full ``run_time``. - Also, if field decay does not occur, then source normalization of frequency-domain - monitors is not meaningful. - - Note - ---- - The source time dependence is linearly interpolated to the simulation time steps. - The sampling rate should be sufficiently fast that this interpolation does not - introduce artifacts. The source time dependence should also start at zero and ramp up smoothly. - The first and last values of the envelope will be used for times that are out of range - of the provided data. - - Example - ------- - >>> cst = CustomSourceTime.from_values(freq0=1, fwidth=0.1, - ... 
values=np.linspace(0, 9, 10), dt=0.1) - - """ - - offset: float = Field( - 0.0, - title="Offset", - description="Time delay of the envelope in units of 1 / (``2pi * fwidth``).", - ) - - source_time_dataset: Optional[TimeDataset] = Field( - None, - title="Source time dataset", - description="Dataset for storing the envelope of the custom source time. " - "This envelope will be modulated by a complex exponential at frequency ``freq0``.", - ) - - _no_nans_dataset = validate_no_nans("source_time_dataset") - _source_time_dataset_none_warning = warn_if_dataset_none("source_time_dataset") - - @field_validator("source_time_dataset") - @classmethod - def _more_than_one_time(cls, val: Optional[TimeDataset]) -> Optional[TimeDataset]: - """Must have more than one time to interpolate.""" - if val is None: - return val - if val.values.size <= 1: - raise ValidationError("'CustomSourceTime' must have more than one time coordinate.") - return val - - @classmethod - def from_values( - cls, freq0: float, fwidth: float, values: ArrayComplex1D, dt: float - ) -> CustomSourceTime: - """Create a :class:`.CustomSourceTime` from a numpy array. - - Parameters - ---------- - freq0 : float - Central frequency of the source. The envelope provided will be modulated - by a complex exponential at this frequency. - fwidth : float - Estimated frequency width of the source. - values: ArrayComplex1D - Complex values of the source envelope. - dt: float - Time step for the ``values`` array. This value should be sufficiently small - that the interpolation to simulation time steps does not introduce artifacts. - - Returns - ------- - CustomSourceTime - :class:`.CustomSourceTime` with envelope given by ``values``, modulated by a complex - exponential at frequency ``freq0``. The time coordinates are evenly spaced - between ``0`` and ``dt * (N-1)`` with a step size of ``dt``, where ``N`` is the length of - the values array. 
- """ - - times = np.arange(len(values)) * dt - source_time_dataarray = TimeDataArray(values, coords={"t": times}) - source_time_dataset = TimeDataset(values=source_time_dataarray) - return CustomSourceTime( - freq0=freq0, - fwidth=fwidth, - source_time_dataset=source_time_dataset, - ) - - @property - def data_times(self) -> ArrayFloat1D: - """Times of envelope definition.""" - if self.source_time_dataset is None: - return [] - data_times = self.source_time_dataset.values.coords["t"].values.squeeze() - return data_times - - def _all_outside_range(self, run_time: float) -> bool: - """Whether all times are outside range of definition.""" - - # can't validate if data isn't loaded - if self.source_time_dataset is None: - return False - - # make time a numpy array for uniform handling - data_times = self.data_times - - # shift time - max_time_shifted = run_time - self.offset_time - min_time_shifted = -self.offset_time - - return (max_time_shifted < min(data_times)) | (min_time_shifted > max(data_times)) - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time. - - Parameters - ---------- - time : float - Time in seconds. - - Returns - ------- - complex - Complex-valued source amplitude at that time. 
- """ - - if self.source_time_dataset is None: - return None - - # make time a numpy array for uniform handling - times = np.array([time] if isinstance(time, (int, float)) else time) - data_times = self.data_times - - # shift time - twidth = 1.0 / (2 * np.pi * self.fwidth) - time_shifted = times - self.offset * twidth - - # mask times that are out of range - mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times)) - - # get envelope - envelope = np.zeros(len(time_shifted), dtype=complex) - values = self.source_time_dataset.values - envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy() - if not all(mask): - envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy() - - # modulation, phase, amplitude - omega0 = 2 * np.pi * self.freq0 - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * times) - amp = self.amplitude - - return offset * oscillation * amp * envelope - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - if self.source_time_dataset is None: - return None - - data_array = self.source_time_dataset.values - - t_coords = data_array.coords["t"] - source_is_non_zero = ~np.isclose(abs(data_array), 0) - t_non_zero = t_coords[source_is_non_zero] - - return np.max(t_non_zero) - - -class BroadbandPulse(SourceTime): - """A source time injecting significant energy in the entire custom frequency range.""" - - freq_range: FreqBound = Field( - title="Frequency Range", - description="Frequency range where the pulse should have significant energy.", - json_schema_extra={"units": HERTZ}, - ) - minimum_amplitude: float = Field( - 0.3, - title="Minimum Amplitude", - description="Minimum amplitude of the pulse relative to the peak amplitude in the frequency range.", - gt=0.05, - lt=0.5, - ) - offset: float = Field( - 0.0, - title="Offset", - description="An automatic time delay of the peak value of the pulse has been applied 
under the hood " - "to ensure smooth ramping up of the pulse at time = 0. This offfset is added on top of the automatic time delay " - "in units of 1 / [``2pi * (freq_range[1] - freq_range[0])``].", - ) - - @field_validator("freq_range") - @classmethod - def _validate_freq_range(cls, val: FreqBound) -> FreqBound: - """Validate that freq_range is positive and properly ordered.""" - if val[0] <= 0 or val[1] <= 0: - raise ValidationError("Both elements of 'freq_range' must be positive.") - if val[1] <= val[0]: - raise ValidationError( - f"'freq_range[1]' ({val[1]}) must be greater than 'freq_range[0]' ({val[0]})." - ) - return val - - @model_validator(mode="before") - @classmethod - def _check_broadband_pulse_available(cls, values: dict[str, Any]) -> dict[str, Any]: - """Check if BroadbandPulse is available.""" - check_tidy3d_extras_licensed_feature("BroadbandPulse") - return values - - @cached_property - def _source(self) -> Any: - """Implementation of broadband pulse.""" - return tidy3d_extras["mod"].extension.BroadbandPulse( - fmin=self.freq_range[0], - fmax=self.freq_range[1], - minRelAmp=self.minimum_amplitude, - amp=self.amplitude, - phase=self.phase, - offset=self.offset, - ) - - def end_time(self) -> float: - """Time after which the source is effectively turned off / close to zero amplitude.""" - return self._source.end_time(END_TIME_FACTOR_GAUSSIAN) - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - return self._source.amp_time(time) - - def amp_freq(self, freq: float) -> complex: - """Complex-valued source amplitude as a function of frequency.""" - return self._source.amp_freq(freq) - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - return self._source.frequency_range(sigma) - - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - 
"""Delegated to `frequency_range_sigma(sigma=num_fwidth)` for computing the frequency range where the source amplitude - is within ``exp(-num_fwidth**2/2)`` of the peak amplitude. - """ - return self.frequency_range_sigma(num_fwidth) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -SourceTimeType = Union[GaussianPulse, ContinuousWave, CustomSourceTime, BroadbandPulse] +from tidy3d._common.components.source.time import ( + _ROOTS_TOL, + DEFAULT_SIGMA, + END_TIME_FACTOR_GAUSSIAN, + OFFSET_FWIDTH_FMAX, + WARN_SOURCE_AMPLITUDE, + BroadbandPulse, + ContinuousWave, + CustomSourceTime, + GaussianPulse, + Pulse, + SourceTime, + SourceTimeType, +) diff --git a/tidy3d/components/time.py b/tidy3d/components/time.py index a5c5be0933..c14f051d8b 100644 --- a/tidy3d/components/time.py +++ b/tidy3d/components/time.py @@ -1,208 +1,11 @@ -"""Defines time dependence""" +"""Compatibility shim for :mod:`tidy3d._common.components.time`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -import numpy as np -from pydantic import Field, NonNegativeFloat - -from tidy3d.constants import RADIAN -from tidy3d.exceptions import SetupError - -from .base import Tidy3dBaseModel -from .viz import add_ax_if_none - -if TYPE_CHECKING: - from .types import ArrayFloat1D, Ax, PlotVal - -# in spectrum computation, discard amplitudes with relative magnitude smaller than cutoff -DFT_CUTOFF = 1e-8 - - -class AbstractTimeDependence(ABC, Tidy3dBaseModel): - """Base class describing time dependence.""" - - amplitude: NonNegativeFloat = Field( - 1.0, title="Amplitude", description="Real-valued maximum amplitude of the time dependence." 
- ) - - phase: float = Field( - 0.0, - title="Phase", - description="Phase shift of the time dependence.", - json_schema_extra={"units": RADIAN}, - ) - - @abstractmethod - def amp_time(self, time: float) -> complex: - """Complex-valued amplitude as a function of time. - - Parameters - ---------- - time : float - Time in seconds. - - Returns - ------- - complex - Complex-valued amplitude at that time. - """ - - def spectrum( - self, - times: ArrayFloat1D, - freqs: ArrayFloat1D, - dt: float, - ) -> complex: - """Complex-valued spectrum as a function of frequency. - Note: Only the real part of the time signal is used. - - Parameters - ---------- - times : np.ndarray - Times to use to evaluate spectrum Fourier transform. - (Typically the simulation time mesh). - freqs : np.ndarray - Frequencies in Hz to evaluate spectrum at. - dt : float or np.ndarray - Time step to weight FT integral with. - If array, use to weigh each of the time intervals in ``times``. - - Returns - ------- - np.ndarray - Complex-valued array (of len(freqs)) containing spectrum at those frequencies. 
- """ +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - times = np.array(times) - freqs = np.array(freqs) - time_amps = np.real(self.amp_time(times)) - - # if all time amplitudes are zero, just return (complex-valued) zeros for spectrum - if np.all(np.equal(time_amps, 0.0)): - return (0.0 + 0.0j) * np.zeros_like(freqs) - - # Cut to only relevant times - relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF) - # find first and last index where the filter is True - start_ind = relevant_time_inds[0][0] - stop_ind = relevant_time_inds[0][-1] + 1 - time_amps = time_amps[start_ind:stop_ind] - times_cut = times[start_ind:stop_ind] - if times_cut.size == 0: - return (0.0 + 0.0j) * np.zeros_like(freqs) - - # only need to compute DTFT kernel for distinct dts - # usually, there is only one dt, if times is simulation time mesh - dts = np.diff(times_cut) - dts_unique, kernel_indices = np.unique(dts, return_inverse=True) - - dft_kernels = [np.exp(2j * np.pi * freqs * curr_dt) for curr_dt in dts_unique] - running_kernel = np.exp(2j * np.pi * freqs * times_cut[0]) - dft = np.zeros(len(freqs), dtype=complex) - for amp, kernel_index in zip(time_amps, kernel_indices): - dft += running_kernel * amp - running_kernel *= dft_kernels[kernel_index] - - # kernel_indices was one index shorter than time_amps - dft += running_kernel * time_amps[-1] - - return dt * dft / np.sqrt(2 * np.pi) - - @add_ax_if_none - def plot_spectrum_in_frequency_range( - self, - times: ArrayFloat1D, - fmin: float, - fmax: float, - num_freqs: int = 101, - val: PlotVal = "real", - ax: Ax = None, - ) -> Ax: - """Plot the complex-valued amplitude of the time-dependence. - Note: Only the real part of the time signal is used. - - Parameters - ---------- - times : np.ndarray - Array of evenly-spaced times (seconds) to evaluate time-dependence at. - The spectrum is computed from this value and the time frequency content. 
- To see spectrum for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - fmin : float - Lower bound of frequency for the spectrum plot. - fmax : float - Upper bound of frequency for the spectrum plot. - num_freqs : int = 101 - Number of frequencies to plot within the [fmin, fmax]. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - times = np.array(times) - - dts = np.diff(times) - if not np.allclose(dts, dts[0] * np.ones_like(dts), atol=1e-17): - raise SetupError("Supplied times not evenly spaced.") - - dt = np.mean(dts) - freqs = np.linspace(fmin, fmax, num_freqs) - - spectrum = self.spectrum(times=times, dt=dt, freqs=freqs) - - if val == "real": - ax.plot(freqs, spectrum.real, color="blueviolet", label="real") - elif val == "imag": - ax.plot(freqs, spectrum.imag, color="crimson", label="imag") - elif val == "abs": - ax.plot(freqs, np.abs(spectrum), color="k", label="abs") - else: - raise ValueError(f"Plot 'val' option of '{val}' not recognized.") - ax.set_xlabel("frequency (Hz)") - ax.set_title("source spectrum") - ax.legend() - ax.set_aspect("auto") - return ax - - @add_ax_if_none - def plot(self, times: ArrayFloat1D, val: PlotVal = "real", ax: Ax = None) -> Ax: - """Plot the complex-valued amplitude of the time-dependence. - - Parameters - ---------- - times : np.ndarray - Array of times (seconds) to plot source at. - To see source time amplitude for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - val : Literal['real', 'imag', 'abs'] = 'real' - Which part of the spectrum to plot. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - times = np.array(times) - amp_complex = self.amp_time(times) +# marked as migrated to _common +from __future__ import annotations - if val == "real": - ax.plot(times, amp_complex.real, color="blueviolet", label="real") - elif val == "imag": - ax.plot(times, amp_complex.imag, color="crimson", label="imag") - elif val == "abs": - ax.plot(times, np.abs(amp_complex), color="k", label="abs") - else: - raise ValueError(f"Plot 'val' option of '{val}' not recognized.") - ax.set_xlabel("time (s)") - ax.set_title("source amplitude") - ax.legend() - ax.set_aspect("auto") - return ax +from tidy3d._common.components.time import ( + DFT_CUTOFF, + AbstractTimeDependence, +) diff --git a/tidy3d/components/transformation.py b/tidy3d/components/transformation.py index 2178a7d9ef..3add2c4b41 100644 --- a/tidy3d/components/transformation.py +++ b/tidy3d/components/transformation.py @@ -1,211 +1,15 @@ -"""Defines geometric transformation classes""" +"""Compatibility shim for :mod:`tidy3d._common.components.transformation`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Union - -import numpy as np -from pydantic import Field, field_validator - -from tidy3d.constants import RADIAN -from tidy3d.exceptions import ValidationError - -from .autograd import TracedFloat -from .base import Tidy3dBaseModel, cached_property -from .types import Axis, Coordinate - -if TYPE_CHECKING: - from .types import ArrayFloat2D, TensorReal - - -class AbstractRotation(ABC, Tidy3dBaseModel): - """Abstract rotation of vectors and tensors.""" - - @cached_property - @abstractmethod - def matrix(self) -> TensorReal: - """Rotation matrix.""" - - @cached_property - @abstractmethod - def isidentity(self) -> bool: - """Check whether rotation is identity.""" - - def rotate_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: - """Rotate a vector/point or a list of vectors/points. 
- - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - - Returns - ------- - Coordinate - Rotated vector. - """ - - if self.isidentity: - return vector - - if len(vector.shape) == 1: - return self.matrix @ vector - - return np.tensordot(self.matrix, vector, axes=1) - - def rotate_tensor(self, tensor: TensorReal) -> TensorReal: - """Rotate a tensor. - - Parameters - ---------- - tensor : ArrayLike[float] - Array of shape ``(3, 3)``. - - Returns - ------- - TensorReal - Rotated tensor. - """ - - if self.isidentity: - return tensor - - return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) - - -class RotationAroundAxis(AbstractRotation): - """Rotation of vectors and tensors around a given vector.""" - - axis: Union[Axis, Coordinate] = Field( - 0, - title="Axis of Rotation", - description="A vector that specifies the axis of rotation, or a single int: 0, 1, or 2, " - "indicating x, y, or z.", - ) - - angle: TracedFloat = Field( - 0.0, - title="Angle of Rotation", - description="Angle of rotation in radians.", - json_schema_extra={"units": RADIAN}, - ) - - @field_validator("axis") - @classmethod - def _validate_axis_vector(cls, val: Union[Axis, Coordinate]) -> Coordinate: - if not isinstance(val, tuple): - axis = [0.0, 0.0, 0.0] - axis[val] = 1.0 - val = tuple(axis) - return val - - @field_validator("axis") - @classmethod - def _validate_axis_nonzero_norm(cls, val: Coordinate) -> Coordinate: - norm = np.linalg.norm(val) - if np.isclose(norm, 0): - raise ValidationError( - "The norm of vector 'axis' cannot be zero. Please provide a proper rotation axis." 
- ) - return val - - @cached_property - def isidentity(self) -> bool: - """Check whether rotation is identity.""" - - return np.isclose(self.angle % (2 * np.pi), 0) - - @cached_property - def matrix(self) -> TensorReal: - """Rotation matrix.""" - - if self.isidentity: - return np.eye(3) - - norm = np.linalg.norm(self.axis) - n = self.axis / norm - c = np.cos(self.angle) - s = np.sin(self.angle) - K = np.array([[0, -n[2], n[1]], [n[2], 0, -n[0]], [-n[1], n[0], 0]]) - R = np.eye(3) + s * K + (1 - c) * K @ K - - return R - - -class AbstractReflection(ABC, Tidy3dBaseModel): - """Abstract reflection of vectors and tensors.""" - - @cached_property - @abstractmethod - def matrix(self) -> TensorReal: - """Reflection matrix.""" - - def reflect_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: - """Reflect a vector/point or a list of vectors/points. - - Parameters - ---------- - vector : ArrayLike[float] - Array of shape ``(3, ...)``. - - Returns - ------- - Coordinate - Reflected vector. - """ - - if len(vector.shape) == 1: - return self.matrix @ vector - - return np.tensordot(self.matrix, vector, axes=1) - - def reflect_tensor(self, tensor: TensorReal) -> TensorReal: - """Reflect a tensor. - - Parameters - ---------- - tensor : ArrayLike[float] - Array of shape ``(3, 3)``. - - Returns - ------- - TensorReal - Reflected tensor. - """ - - return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) - - -class ReflectionFromPlane(AbstractReflection): - """Reflection of vectors and tensors around a given vector.""" - - normal: Coordinate = Field( - (1, 0, 0), - title="Normal of the reflecting plane", - description="A vector that specifies the normal of the plane of reflection", - ) - - @field_validator("normal") - @classmethod - def _validate_normal_nonzero_norm(cls, val: Coordinate) -> Coordinate: - norm = np.linalg.norm(val) - if np.isclose(norm, 0): - raise ValidationError( - "The norm of vector 'normal' cannot be zero. Please provide a proper normal vector." 
- ) - return val - - @cached_property - def matrix(self) -> TensorReal: - """Reflection matrix.""" - - norm = np.linalg.norm(self.normal) - n = self.normal / norm - R = np.eye(3) - 2 * np.outer(n, n) - - return R +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -RotationType = Union[RotationAroundAxis] -ReflectionType = Union[ReflectionFromPlane] +from tidy3d._common.components.transformation import ( + AbstractReflection, + AbstractRotation, + ReflectionFromPlane, + ReflectionType, + RotationAroundAxis, + RotationType, +) diff --git a/tidy3d/components/types/base.py b/tidy3d/components/types/base.py index ea408643fd..39695d994e 100644 --- a/tidy3d/components/types/base.py +++ b/tidy3d/components/types/base.py @@ -1,320 +1,82 @@ -"""Defines 'types' that various fields can be""" +"""Compatibility shim for :mod:`tidy3d._common.components.types.base`.""" -from __future__ import annotations +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import numbers -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union +# marked as migrated to _common +from __future__ import annotations -import numpy as np -from pydantic import ( - BaseModel, - BeforeValidator, - ConfigDict, - Field, - NonNegativeFloat, - PlainValidator, - PositiveFloat, +from tidy3d._common.components.types.base import ( + TYPE_TAG_STR, + ArrayComplex, + ArrayComplex1D, + ArrayComplex2D, + ArrayComplex3D, + ArrayComplex4D, + ArrayConstraints, + ArrayFloat, + ArrayFloat1D, + ArrayFloat2D, + ArrayFloat3D, + ArrayFloat4D, + ArrayInt1D, + ArrayLike, + ArrayLikeStrict, + AuxField, + Ax, + Axis, + Axis2D, + Bound, + BoxSurface, + ClipOperationType, + ColormapType, + Complex, + Coordinate, + Coordinate2D, + CoordinateOptional, + Direction, + DTypeLike, + EMField, + EpsSpecType, + FieldType, + FieldVal, + FreqArray, + FreqBound, + FreqBoundMax, + FreqBoundMin, + GridSize, + 
InterpMethod, + LengthUnit, + LumpDistType, + MatrixReal4x4, + ModeClassification, + ModeSolverType, + ObsGridArray, + PermittivityComponent, + PlanePosition, + PlotScale, + PlotVal, + Polarization, + PolarizationBasis, + PoleAndResidue, + PolesAndResidues, + PriorityMode, + RealFieldVal, + ScalarSymmetry, + Shapely, + Size, + Size1D, + Symmetry, + TensorReal, + TrackFreq, + Undefined, + UnitsZBF, + _auto_serializer, + _coerce, + _dtype2python, + _from_complex_dict, + _list_to_tuple, + _parse_complex, + array_alias, + discriminated_union, + xyz, ) -from pydantic.functional_serializers import PlainSerializer -from pydantic.json_schema import WithJsonSchema - -if TYPE_CHECKING: - from numpy.typing import NDArray - -try: - from matplotlib.axes import Axes -except ImportError: - Axes = None - -from shapely.geometry.base import BaseGeometry - -# type tag default name -TYPE_TAG_STR = "type" - - -def discriminated_union(union: type, discriminator: str = TYPE_TAG_STR) -> type: - return Annotated[union, Field(discriminator=discriminator)] - - -""" Numpy Arrays """ - - -def _dtype2python(value: Any) -> Any: - """Converts numpy scalar types to their python equivalents.""" - if isinstance(value, np.integer): - return int(value) - if isinstance(value, np.floating): - return float(value) - if isinstance(value, np.complexfloating): - return complex(value) - if isinstance(value, np.bool_): - return bool(value) - return value - - -def _from_complex_dict(v: Any) -> Any: - if isinstance(v, dict) and "real" in v and "imag" in v: - return np.asarray(v["real"]) + 1j * np.asarray(v["imag"]) - return v - - -def _auto_serializer(a: Any, _: Any) -> Any: - """Serializes numpy arrays and scalars for JSON.""" - if isinstance(a, complex) or ( - hasattr(np, "complexfloating") and isinstance(a, np.complexfloating) - ): - return {"real": float(a.real), "imag": float(a.imag)} - if isinstance(a, np.ndarray): - if np.iscomplexobj(a): - return {"real": a.real.tolist(), "imag": a.imag.tolist()} - else: 
- return a.tolist() - if isinstance(a, float) or (hasattr(np, "floating") and isinstance(a, np.floating)): - return float(a) # Ensure basic Python float - if isinstance(a, int) or (hasattr(np, "integer") and isinstance(a, np.integer)): - return int(a) # Ensure basic Python int - if hasattr(np, "number") and isinstance(a, np.number): - return a.item() - return a - - -DTypeLike = Annotated[np.dtype, PlainValidator(np.dtype), WithJsonSchema({"type": "np.dtype"})] - - -class ArrayConstraints(BaseModel): - """Container for array constraints.""" - - model_config = ConfigDict(frozen=True) - - dtype: Optional[DTypeLike] = None - ndim: Optional[int] = None - shape: Optional[tuple[int, ...]] = None - forbid_nan: bool = True - scalar_to_1d: bool = False - strict: bool = False - - -def _coerce(v: Any, *, constraints: ArrayConstraints) -> NDArray: - """Convert input to a NumPy array with constraints. - - Raises - ------ - ValueError - - If conversion to an array fails. - - If the array ends up with dtype=object (unsupported element type). - - If the number of dimensions or shape does not match the expectations. - - If ``forbid_nan`` is ``True`` and the array contains NaN values. - """ - if constraints.strict and np.isscalar(v): - raise ValueError( - f"strict mode: scalar value {type(v).__name__!r} cannot be coerced to a NumPy array. 
" - ) - - try: - # constraints.dtype is already an np.dtype object or None - arr = np.asarray(v) if constraints.dtype is None else np.asarray(v, dtype=constraints.dtype) - except Exception as e: - raise ValueError(f"cannot convert {type(v).__name__!r} to a NumPy array") from e - - if arr.dtype == np.dtype("object"): - raise ValueError(f"unsupported element type {type(v).__name__!r} for array coercion") - - if ( - arr.ndim == 0 - and (constraints.ndim == 1 or constraints.ndim is None) - and constraints.scalar_to_1d - ): - arr = arr.reshape(1) - if constraints.ndim is not None and arr.ndim != constraints.ndim: - raise ValueError(f"expected {constraints.ndim}-D, got {arr.ndim}-D") - if constraints.shape is not None and tuple(arr.shape) != constraints.shape: - raise ValueError(f"expected shape {constraints.shape}, got {tuple(arr.shape)}") - if constraints.forbid_nan and np.any(np.isnan(arr)): - raise ValueError("array contains NaN") - - # enforce immutability of our Pydantic models - arr.flags.writeable = False - - return arr - - -def array_alias( - *, - dtype: Optional[Any] = None, - ndim: Optional[int] = None, - shape: Optional[tuple[int, ...]] = None, - forbid_nan: bool = True, - scalar_to_1d: bool = False, - strict: bool = False, -) -> Any: - constraints = ArrayConstraints( - dtype=dtype, - ndim=ndim, - shape=shape, - forbid_nan=forbid_nan, - scalar_to_1d=scalar_to_1d, - strict=strict, - ) - serializer = PlainSerializer(_auto_serializer, when_used="json") - - base_schema = { - "type": "ArrayLike", - "x-array-dtype": getattr(constraints.dtype, "str", None), - "x-array-ndim": constraints.ndim, - "x-array-shape": constraints.shape, - "x-array-forbid_nan": constraints.forbid_nan, - "x-array-scalar_to_1d": constraints.scalar_to_1d, - "x-array-strict": constraints.strict, - } - - return Annotated[ - np.ndarray, - BeforeValidator(_from_complex_dict), - BeforeValidator(lambda v: _coerce(v, constraints=constraints)), - serializer, - WithJsonSchema(base_schema), - ] - - 
-ArrayLike = array_alias() -ArrayLikeStrict = array_alias(strict=True) - -ArrayInt1D = array_alias(dtype=int, ndim=1, scalar_to_1d=True) - -ArrayFloat = array_alias(dtype=float) -ArrayFloat1D = array_alias(dtype=float, ndim=1, scalar_to_1d=True) -ArrayFloat2D = array_alias(dtype=float, ndim=2) -ArrayFloat3D = array_alias(dtype=float, ndim=3) -ArrayFloat4D = array_alias(dtype=float, ndim=4) - -ArrayComplex = array_alias(dtype=complex) -ArrayComplex1D = array_alias(dtype=complex, ndim=1, scalar_to_1d=True) -ArrayComplex2D = array_alias(dtype=complex, ndim=2) -ArrayComplex3D = array_alias(dtype=complex, ndim=3) -ArrayComplex4D = array_alias(dtype=complex, ndim=4) - -TensorReal = array_alias(dtype=float, ndim=2, shape=(3, 3)) -MatrixReal4x4 = array_alias(dtype=float, ndim=2, shape=(4, 4)) - -""" Complex Values """ - - -def _parse_complex(v: Any) -> complex: - if isinstance(v, complex): - return v - - if isinstance(v, dict) and "real" in v and "imag" in v: - return complex(v["real"], v["imag"]) - - if isinstance(v, numbers.Number): - return complex(v) - - if hasattr(v, "__complex__"): - try: - return complex(v.__complex__()) - except Exception: - pass - - if isinstance(v, (list, tuple)) and len(v) == 2: - return complex(v[0], v[1]) - - return v - - -Complex = Annotated[ - complex, - BeforeValidator(_parse_complex), - PlainSerializer( - lambda z, _: {"real": z.real, "imag": z.imag}, - when_used="json", - return_type=dict, - ), -] - -""" symmetry """ - -Symmetry = Annotated[Literal[0, -1, 1], BeforeValidator(_dtype2python)] -ScalarSymmetry = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)] - -""" geometric """ - -Size1D = NonNegativeFloat -Size = tuple[Size1D, Size1D, Size1D] -Coordinate = tuple[float, float, float] -CoordinateOptional = tuple[Optional[float], Optional[float], Optional[float]] -Coordinate2D = tuple[float, float] -Bound = tuple[Coordinate, Coordinate] -GridSize = Union[PositiveFloat, tuple[PositiveFloat, ...]] -Axis = Annotated[Literal[0, 1, 2], 
BeforeValidator(_dtype2python)] -Axis2D = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)] -Shapely = BaseGeometry -PlanePosition = Literal["bottom", "middle", "top"] -ClipOperationType = Literal["union", "intersection", "difference", "symmetric_difference"] -BoxSurface = Literal["x-", "x+", "y-", "y+", "z-", "z+"] -LengthUnit = Literal["nm", "μm", "um", "mm", "cm", "m", "mil", "in"] -PriorityMode = Literal["equal", "conductor"] - -""" medium """ - -# custom medium -InterpMethod = Literal["nearest", "linear"] - -PoleAndResidue = tuple[Complex, Complex] -PolesAndResidues = tuple[PoleAndResidue, ...] -FreqBoundMax = float -FreqBoundMin = float -FreqBound = tuple[FreqBoundMin, FreqBoundMax] - -PermittivityComponent = Literal["xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz"] - -""" sources """ - -Polarization = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -Direction = Literal["+", "-"] - -""" monitors """ - - -def _list_to_tuple(v: Any) -> Any: - if isinstance(v, list): - return tuple(v) - return v - - -EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -FreqArray = ArrayFloat1D -ObsGridArray = FreqArray -PolarizationBasis = Literal["linear", "circular"] -AuxField = Literal["Nfx", "Nfy", "Nfz"] - -""" plotting """ - -Ax = Axes -PlotVal = Literal["real", "imag", "abs"] -FieldVal = Literal["real", "imag", "abs", "abs^2", "phase"] -RealFieldVal = Literal["real", "abs", "abs^2"] -PlotScale = Literal["lin", "dB", "log", "symlog"] -ColormapType = Literal["divergent", "sequential", "cyclic"] - -""" mode solver """ - -ModeSolverType = Literal["tensorial", "diagonal"] -EpsSpecType = Literal["diagonal", "tensorial_real", "tensorial_complex"] -ModeClassification = Literal["TEM", "quasi-TEM", "TE", "TM", "Hybrid"] - -""" mode tracking """ - -TrackFreq = Literal["central", "lowest", "highest"] - -""" lumped elements""" - -LumpDistType = Literal["off", "laterally_only", "on"] - -""" dataset """ - -xyz = 
Literal["x", "y", "z"] -UnitsZBF = Literal["mm", "cm", "in", "m"] - -""" sentinel """ -Undefined = object() diff --git a/tidy3d/components/types/third_party.py b/tidy3d/components/types/third_party.py index 4fe305ce68..7e2eda6240 100644 --- a/tidy3d/components/types/third_party.py +++ b/tidy3d/components/types/third_party.py @@ -1,14 +1,8 @@ -from __future__ import annotations - -from typing import Any +"""Compatibility shim for :mod:`tidy3d._common.components.types.third_party`.""" -from tidy3d.packaging import check_import +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -# TODO Complicated as trimesh should be a core package unless decoupled implementation types in functional location. -# We need to restructure. -if check_import("trimesh"): - import trimesh # Won't add much overhead if already imported +# marked as migrated to _common +from __future__ import annotations - TrimeshType = trimesh.Trimesh -else: - TrimeshType = Any +from tidy3d._common.components.types.third_party import TrimeshType diff --git a/tidy3d/components/types/utils.py b/tidy3d/components/types/utils.py index f05a68b146..39cb1b0f5c 100644 --- a/tidy3d/components/types/utils.py +++ b/tidy3d/components/types/utils.py @@ -1,33 +1,10 @@ -"""Utilities for type & schema creation.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -from pydantic_core import core_schema - -if TYPE_CHECKING: - from pydantic import GetCoreSchemaHandler +"""Compatibility shim for :mod:`tidy3d._common.components.types.utils`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None: - """Adds a schema to the ``arbitrary_type`` class without subclassing.""" - - @classmethod - def __get_pydantic_core_schema__( - cls: type, _source_type: type, _handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - def _serialize(value: Any, info: 
core_schema.SerializationInfo) -> Any: - from tidy3d.components.autograd.utils import get_static - from tidy3d.components.types.base import _auto_serializer - - return _auto_serializer(get_static(value), info) - - return core_schema.any_schema( - metadata={"title": title, "type": field_type_str}, - serialization=core_schema.plain_serializer_function_ser_schema( - _serialize, info_arg=True - ), - ) +# marked as migrated to _common +from __future__ import annotations - arbitrary_type.__get_pydantic_core_schema__ = __get_pydantic_core_schema__ +from tidy3d._common.components.types.utils import ( + _add_schema, +) diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index e604f97f21..d3be4fa0d9 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -1,59 +1,44 @@ -"""Defines various validation functions that get used to ensure inputs are legit""" +"""Compatibility shim for :mod:`tidy3d._common.components.validators`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -from collections.abc import Sequence from typing import TYPE_CHECKING, Any, TypeVar, Union import numpy as np from numpy.typing import NDArray from pydantic import field_validator, model_validator +from tidy3d._common.components.validators import ( + MIN_FREQUENCY, + FloatArray, + _assert_min_freq, + _warn_unsupported_traced_argument, + validate_name_str, + warn_if_dataset_none, +) +from tidy3d.components.data.data_array import DATA_ARRAY_MAP +from tidy3d.components.geometry.base import Box from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log -from .autograd.utils import get_static, hasbox -from .base import DATA_ARRAY_MAP -from .geometry.base import Box - if TYPE_CHECKING: + from collections.abc import Sequence from typing import Callable, Optional from pydantic import FieldValidationInfo from tidy3d import Simulation + 
from tidy3d._common.components.validators import T from tidy3d.components.base_sim.simulation import AbstractSimulation from tidy3d.components.data.monitor_data import AbstractFieldData from tidy3d.components.types import FreqArray from tidy3d.plugins.smatrix import AbstractComponentModeler -T = TypeVar("T") - -"""Explanation of Pydantic validators (v2). - - Validators are class methods that validate and coerce model inputs. This module defines - reusable validator factories that are shared across tidy3d components. - - In Pydantic v2 we use: - - ``@field_validator("field_name")`` for field-local checks/coercions. It can access - already-validated fields via ``info.data``, but ``info.data`` only contains fields - validated earlier, so avoid order-dependent cross-field logic. - - ``@model_validator(mode="after")`` for cross-field constraints that need the full model. - - To attach a validator from this file to a Pydantic model, assign the factory result in the - class body, e.g. ``_plane_validator = assert_plane()``. Avoid reusing the same attribute - name for multiple validators, or earlier validators may be overwritten. 
- - For more details: `Pydantic validators `_ -""" - -# Lowest frequency supported (Hz) -MIN_FREQUENCY = 1e5 - -FloatArray = Union[Sequence[float], NDArray] - - def named_obj_descr(obj: Any, field_name: str, position_index: int) -> str: """Generate a string describing a named object which can be used in error messages.""" descr = f"simulation.{field_name}[{position_index}] (no `name` was specified)" @@ -124,21 +109,6 @@ def is_volumetric(cls: type, val: tuple[float, ...]) -> tuple[float, ...]: return is_volumetric -# FIXME: this validator doesn't do anything -def validate_name_str() -> Callable[[type, Optional[str]], Optional[str]]: - """make sure the name does not include [, ] (used for default names)""" - - @field_validator("name") - @classmethod - def field_has_unique_names(cls: type, val: Optional[str]) -> Optional[str]: - """raise exception if '[' or ']' in name""" - # if val and ('[' in val or ']' in val): - # raise SetupError(f"'[' or ']' not allowed in name: {val} (used for defaults)") - return val - - return field_has_unique_names - - def validate_unique( *field_names: str, ) -> Callable[[type, Sequence[Any], FieldValidationInfo], Sequence[Any]]: @@ -309,30 +279,12 @@ def _make_required(self: T) -> T: return _make_required -def warn_if_dataset_none( - field_name: str, -) -> Callable[[type, Optional[dict[str, Any]]], Optional[dict[str, Any]]]: - """Warn if a Dataset field has None in its dictionary.""" - - @field_validator(field_name, mode="before") - @classmethod - def _warn_if_none(cls: type, val: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: - """Warn if the DataArrays fail to load.""" - if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): - log.warning(f"Loading {field_name} without data.", custom_loc=[field_name]) - return None - return val - - return _warn_if_none - - def warn_backward_waist_distance(field_name: str) -> Callable[[T], T]: - """Warn about changed waist distance behavior for 
backward-propagating beams.""" + """Warn if a backward-propagating beam uses a non-zero waist distance.""" @model_validator(mode="after") def _warn_backward_nonzero(self: T) -> T: - """Emit warning about changed waist distance interpretation.""" + """Emit deprecation warning for backward propagation with non-zero waist.""" direction = self.direction if direction != "-": return self @@ -340,15 +292,12 @@ def _warn_backward_nonzero(self: T) -> T: waist_array = np.atleast_1d(waist_value) if not np.all(np.isclose(waist_array, 0.0)): log.warning( - f"Starting in version 2.11, the behavior of {self.__class__.__name__} with direction '-' " - f"and non-zero '{field_name}' has changed. The waist position is now defined " - "consistently for both forward- and backward-propagating beams: a positive " - f"'{field_name}' always places the beam waist behind the source/monitor plane " - "(toward the negative normal axis). This ensures reciprocity between Gaussian " - "sources and overlap monitors used for port-based S-matrix calculations. " - "If your simulation relied on the previous behavior (where the waist position " - "flipped with direction), you may need to adjust your waist distance values.", - log_once=True, + f"Behavior of {self.__class__.__name__} with direction '-' and non-zero '{field_name}' will " + "change in version 2.11 to be consistent with upcoming beam overlap monitors and " + "ports. Currently, the waist distance is interpreted w.r.t. the directed " + "propagation axis, so switching 'direction' also switches the position of the " + "waist in the global reference frame. 
In the future, the waist position will be " + "defined such that it is the same for backward- and forward-propagating beams.", ) return self @@ -424,15 +373,6 @@ def _check_perturbed_val(cls: type, val: Any, info: FieldValidationInfo) -> Any: return _check_perturbed_val -def _assert_min_freq(freqs: FloatArray, msg_start: str) -> None: - """Check if all ``freqs`` are above the minimum frequency.""" - if np.min(freqs) < MIN_FREQUENCY: - raise ValidationError( - f"{msg_start} must be no lower than {MIN_FREQUENCY:.0e} Hz. " - "Note that the unit of frequency is 'Hz'." - ) - - def validate_freqs_min() -> Callable[[type, FreqArray], FreqArray]: """Validate lower bound for monitor, and mode solver frequencies.""" @@ -472,24 +412,3 @@ def freqs_unique(cls: AbstractComponentModeler, val: FreqArray) -> FreqArray: return val return freqs_unique - - -def _warn_unsupported_traced_argument( - *names: str, -) -> Callable[[type, Any, FieldValidationInfo], Any]: - @field_validator(*names) - @classmethod - def _warn_traced_arg(cls: type, val: Any, info: FieldValidationInfo) -> Any: - if hasbox(val): - log.warning( - f"Field '{info.field_name}' of '{cls.__name__}' received an autograd tracer " - f"(i.e., a value being tracked for automatic differentiation). " - f"Automatic differentiation through this field is unsupported, " - f"so the tracer has been converted to its static value. " - f"If you want to avoid this warning, you manually unbox the value " - f"using the 'autograd.tracer.getval' function before passing it to Tidy3D." 
- ) - return get_static(val) - return val - - return _warn_traced_arg diff --git a/tidy3d/components/viz/__init__.py b/tidy3d/components/viz/__init__.py index 169b63fa2f..e94c153f6a 100644 --- a/tidy3d/components/viz/__init__.py +++ b/tidy3d/components/viz/__init__.py @@ -1,12 +1,33 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.viz`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -from .axes_utils import add_ax_if_none, equal_aspect, make_ax, set_default_labels_and_title -from .descartes import Polygon, polygon_patch, polygon_path -from .flex_style import apply_tidy3d_params, restore_matplotlib_rcparams -from .plot_params import ( +from tidy3d._common.components.viz import ( + ARROW_ALPHA, + ARROW_COLOR_ABSORBER, + ARROW_COLOR_MONITOR, + ARROW_COLOR_POLARIZATION, + ARROW_COLOR_SOURCE, + ARROW_LENGTH, + FLEXCOMPUTE_COLORS, + MATPLOTLIB_IMPORTED, + MEDIUM_CMAP, + PLOT_BUFFER, + STRUCTURE_EPS_CMAP, + STRUCTURE_EPS_CMAP_R, + STRUCTURE_HEAT_COND_CMAP, AbstractPlotParams, PathPlotParams, PlotParams, + Polygon, + VisualizationSpec, + add_ax_if_none, + arrow_style, + equal_aspect, + make_ax, plot_params_abc, plot_params_absorber, plot_params_bloch, @@ -23,70 +44,10 @@ plot_params_source, plot_params_structure, plot_params_symmetry, + plot_scene_3d, + plot_sim_3d, + polygon_patch, + polygon_path, + restore_matplotlib_rcparams, + set_default_labels_and_title, ) -from .plot_sim_3d import plot_scene_3d, plot_sim_3d -from .styles import ( - ARROW_ALPHA, - ARROW_COLOR_ABSORBER, - ARROW_COLOR_MONITOR, - ARROW_COLOR_POLARIZATION, - ARROW_COLOR_SOURCE, - ARROW_LENGTH, - FLEXCOMPUTE_COLORS, - MEDIUM_CMAP, - PLOT_BUFFER, - STRUCTURE_EPS_CMAP, - STRUCTURE_EPS_CMAP_R, - STRUCTURE_HEAT_COND_CMAP, - arrow_style, -) -from .visualization_spec import MATPLOTLIB_IMPORTED, VisualizationSpec - -apply_tidy3d_params() - -__all__ = [ - "ARROW_ALPHA", - "ARROW_COLOR_ABSORBER", - 
"ARROW_COLOR_MONITOR", - "ARROW_COLOR_POLARIZATION", - "ARROW_COLOR_SOURCE", - "ARROW_LENGTH", - "FLEXCOMPUTE_COLORS", - "MATPLOTLIB_IMPORTED", - "MEDIUM_CMAP", - "PLOT_BUFFER", - "STRUCTURE_EPS_CMAP", - "STRUCTURE_EPS_CMAP_R", - "STRUCTURE_HEAT_COND_CMAP", - "AbstractPlotParams", - "PathPlotParams", - "PlotParams", - "Polygon", - "VisualizationSpec", - "add_ax_if_none", - "arrow_style", - "equal_aspect", - "make_ax", - "plot_params_abc", - "plot_params_absorber", - "plot_params_bloch", - "plot_params_fluid", - "plot_params_geometry", - "plot_params_grid", - "plot_params_lumped_element", - "plot_params_min_grid_size", - "plot_params_monitor", - "plot_params_override_structures", - "plot_params_pec", - "plot_params_pmc", - "plot_params_pml", - "plot_params_source", - "plot_params_structure", - "plot_params_symmetry", - "plot_scene_3d", - "plot_sim_3d", - "polygon_patch", - "polygon_path", - "restore_matplotlib_rcparams", - "set_default_labels_and_title", -] diff --git a/tidy3d/components/viz/axes_utils.py b/tidy3d/components/viz/axes_utils.py index e1579abed3..cdc5454a95 100644 --- a/tidy3d/components/viz/axes_utils.py +++ b/tidy3d/components/viz/axes_utils.py @@ -1,198 +1,14 @@ -from __future__ import annotations - -from functools import wraps -from typing import TYPE_CHECKING - -from tidy3d.components.types import LengthUnit -from tidy3d.constants import UnitScaling -from tidy3d.exceptions import Tidy3dKeyError - -if TYPE_CHECKING: - from typing import Callable, ParamSpec, TypeVar - - import matplotlib.ticker as ticker - from matplotlib.axes import Axes - - P = ParamSpec("P") - T = TypeVar("T", bound=Callable[..., Axes]) - from typing import Optional - - from tidy3d.components.types import Ax, Axis - - -def _create_unit_aware_locator() -> ticker.Locator: - """Create UnitAwareLocator lazily due to matplotlib import restrictions.""" - import matplotlib.ticker as ticker - - class UnitAwareLocator(ticker.Locator): - """Custom tick locator that places ticks at nice 
positions in the target unit.""" - - def __init__(self, scale_factor: float) -> None: - """ - Parameters - ---------- - scale_factor : float - Factor to convert from micrometers to the target unit. - """ - super().__init__() - self.scale_factor = scale_factor - - def __call__(self) -> list[float]: - vmin, vmax = self.axis.get_view_interval() - return self.tick_values(vmin, vmax) - - def view_limits(self, vmin: float, vmax: float) -> tuple[float, float]: - """Override to prevent matplotlib from adjusting our limits.""" - return vmin, vmax - - def tick_values(self, vmin: float, vmax: float) -> list[float]: - # convert the view range to the target unit - vmin_unit = vmin * self.scale_factor - vmax_unit = vmax * self.scale_factor - - # tolerance for floating point comparisons in target unit - unit_range = vmax_unit - vmin_unit - unit_tol = unit_range * 1e-8 - - locator = ticker.MaxNLocator(nbins=11, prune=None, min_n_ticks=2) - - ticks_unit = locator.tick_values(vmin_unit, vmax_unit) - - # ensure we have ticks that cover the full range - if len(ticks_unit) > 0: - if ticks_unit[0] > vmin_unit + unit_tol or ticks_unit[-1] < vmax_unit - unit_tol: - # try with fewer bins to get better coverage - for n in [10, 9, 8, 7, 6, 5]: - locator = ticker.MaxNLocator(nbins=n, prune=None, min_n_ticks=2) - ticks_unit = locator.tick_values(vmin_unit, vmax_unit) - if ( - len(ticks_unit) >= 3 - and ticks_unit[0] <= vmin_unit + unit_tol - and ticks_unit[-1] >= vmax_unit - unit_tol - ): - break - - # if still no good coverage, manually ensure edge coverage - if len(ticks_unit) > 0: - if ( - ticks_unit[0] > vmin_unit + unit_tol - or ticks_unit[-1] < vmax_unit - unit_tol - ): - # find a reasonable step size from existing ticks - if len(ticks_unit) > 1: - step = ticks_unit[1] - ticks_unit[0] - else: - step = unit_range / 5 - - # extend the range to ensure coverage - extended_min = vmin_unit - step - extended_max = vmax_unit + step - - # try one more time with extended range - locator = 
ticker.MaxNLocator(nbins=8, prune=None, min_n_ticks=2) - ticks_unit = locator.tick_values(extended_min, extended_max) - - # filter to reasonable bounds around the original range - ticks_unit = [ - t - for t in ticks_unit - if t >= vmin_unit - step / 2 and t <= vmax_unit + step / 2 - ] +"""Compatibility shim for :mod:`tidy3d._common.components.viz.axes_utils`.""" - # convert the nice ticks back to the original data unit (micrometers) - ticks_um = ticks_unit / self.scale_factor +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # filter to ensure ticks are within bounds (with small tolerance) - eps = (vmax - vmin) * 1e-8 - return [tick for tick in ticks_um if vmin - eps <= tick <= vmax + eps] - - return UnitAwareLocator - - -def make_ax() -> Ax: - """makes an empty ``ax``.""" - import matplotlib.pyplot as plt - - _, ax = plt.subplots(1, 1, tight_layout=True) - return ax - - -def add_ax_if_none(plot: T) -> T: - """Decorates ``plot(*args, **kwargs, ax=None)`` function. - if ax=None in the function call, creates an ax and feeds it to rest of function. - """ - - @wraps(plot) - def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: - """New plot function using a generated ax if None.""" - if kwargs.get("ax") is None: - ax = make_ax() - kwargs["ax"] = ax - return plot(*args, **kwargs) - - return _plot - - -def equal_aspect(plot: T) -> T: - """Decorates a plotting function returning a matplotlib axes. - Ensures the aspect ratio of the returned axes is set to equal. 
- Useful for 2D plots, like sim.plot() or sim_data.plot_fields() - """ - - @wraps(plot) - def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: - """New plot function with equal aspect ratio axes returned.""" - ax = plot(*args, **kwargs) - ax.set_aspect("equal") - return ax - - return _plot - - -def set_default_labels_and_title( - axis_labels: tuple[str, str], - axis: Axis, - position: float, - ax: Ax, - plot_length_units: Optional[LengthUnit] = None, -) -> Ax: - """Adds axis labels and title to plots involving spatial dimensions. - When the ``plot_length_units`` are specified, the plot axes are scaled, and - the title and axis labels include the desired units. - """ - - import matplotlib.ticker as ticker - - xlabel = axis_labels[0] - ylabel = axis_labels[1] - if plot_length_units is not None: - if plot_length_units not in UnitScaling: - raise Tidy3dKeyError( - f"Provided units '{plot_length_units}' are not supported. " - f"Please choose one of '{LengthUnit}'." - ) - ax.set_xlabel(f"{xlabel} ({plot_length_units})") - ax.set_ylabel(f"{ylabel} ({plot_length_units})") - - scale_factor = UnitScaling[plot_length_units] - - # for imperial units, use custom tick locator for nice tick positions - if plot_length_units in ["mil", "in"]: - UnitAwareLocator = _create_unit_aware_locator() - x_locator = UnitAwareLocator(scale_factor) - y_locator = UnitAwareLocator(scale_factor) - ax.xaxis.set_major_locator(x_locator) - ax.yaxis.set_major_locator(y_locator) - - formatter = ticker.FuncFormatter(lambda y, _: f"{y * scale_factor:.2f}") - - ax.xaxis.set_major_formatter(formatter) - ax.yaxis.set_major_formatter(formatter) +# marked as migrated to _common +from __future__ import annotations - position_scaled = position * scale_factor - ax.set_title(f"cross section at {'xyz'[axis]}={position_scaled:.2f} ({plot_length_units})") - else: - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") - return ax +from 
tidy3d._common.components.viz.axes_utils import ( + _create_unit_aware_locator, + add_ax_if_none, + equal_aspect, + make_ax, + set_default_labels_and_title, +) diff --git a/tidy3d/components/viz/descartes.py b/tidy3d/components/viz/descartes.py index 572dfc44ba..a1b2f54fc2 100644 --- a/tidy3d/components/viz/descartes.py +++ b/tidy3d/components/viz/descartes.py @@ -1,113 +1,12 @@ -"""================================================================================================= -Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0 +"""Compatibility shim for :mod:`tidy3d._common.components.viz.descartes`.""" -Copyright Flexcompute 2022 - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER -IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from numpy.typing import NDArray - from shapely.geometry.base import BaseGeometry - -try: - from matplotlib.patches import PathPatch - from matplotlib.path import Path -except ImportError: - pass -from numpy import array, concatenate, ones - - -class Polygon: - """Adapt Shapely polygons to a common interface""" - - def __init__(self, context: dict[str, Any]) -> None: - if isinstance(context, dict): - self.context = context["coordinates"] - else: - self.context = context - - @property - def exterior(self) -> Any: - """Get polygon exterior.""" - value = getattr(self.context, "exterior", None) - if value is None: - value = self.context[0] - return value - - @property - def interiors(self) -> Any: - """Get polygon interiors.""" - value = getattr(self.context, "interiors", None) - if value is None: - value = self.context[1:] - return value - - -def polygon_path(polygon: BaseGeometry) -> Path: - """Constructs a compound matplotlib path from a Shapely or GeoJSON-like - geometric object""" - - def coding(obj: Any) -> NDArray: - # The codes will be all "LINETO" commands, except for "MOVETO"s at the - # beginning of each subpath - crds = getattr(obj, "coords", None) - if crds is None: - crds = obj - n = len(crds) - vals = ones(n, dtype=Path.code_type) * Path.LINETO - if len(vals) > 0: - vals[0] = Path.MOVETO - return vals - - ptype = polygon.geom_type - if ptype == "Polygon": - polygon = [Polygon(polygon)] - elif ptype == "MultiPolygon": - polygon = [Polygon(p) for p in polygon.geoms] - - vertices = concatenate( - [ - concatenate( - [array(t.exterior.coords)[:, :2]] + [array(r.coords)[:, :2] for r in t.interiors] - ) - for t in polygon - ] - ) - codes = concatenate( - [concatenate([coding(t.exterior)] + [coding(r) for r in t.interiors]) for t in polygon] - ) - - return 
Path(vertices, codes) - - -def polygon_patch(polygon: BaseGeometry, **kwargs: Any) -> PathPatch: - """Constructs a matplotlib patch from a geometric object - - The ``polygon`` may be a Shapely or GeoJSON-like object with or without holes. - The ``kwargs`` are those supported by the matplotlib.patches.Polygon class - constructor. Returns an instance of matplotlib.patches.PathPatch. - - Example - ------- - >>> b = Point(0, 0).buffer(1.0) # doctest: +SKIP - >>> patch = PolygonPatch(b, fc='blue', ec='blue', alpha=0.5) # doctest: +SKIP - >>> axis.add_patch(patch) # doctest: +SKIP - - """ - return PathPatch(polygon_path(polygon), **kwargs) - - -"""End descartes modification -=================================================================================================""" +from tidy3d._common.components.viz.descartes import ( + Polygon, + polygon_patch, + polygon_path, +) diff --git a/tidy3d/components/viz/flex_color_palettes.py b/tidy3d/components/viz/flex_color_palettes.py index 7fc1454a0b..0b80b28bef 100644 --- a/tidy3d/components/viz/flex_color_palettes.py +++ b/tidy3d/components/viz/flex_color_palettes.py @@ -1,3306 +1,12 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.viz.flex_color_palettes`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -SEQUENTIAL_PALETTES_HEX = { - "flex_turquoise_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfb", - "#fafbfa", - "#f9fafa", - "#f8f9f9", - "#f7f8f8", - "#f6f7f7", - "#f5f6f6", - "#f3f5f5", - "#f2f4f4", - "#f1f3f3", - "#f0f3f2", - "#eff2f1", - "#eef1f1", - "#edf0f0", - "#ecefef", - "#ebeeee", - "#eaeded", - "#e9edec", - "#e8eceb", - "#e7ebeb", - "#e6eaea", - "#e5e9e9", - "#e4e8e8", - "#e3e7e7", - "#e2e7e6", - "#e1e6e5", - "#e0e5e5", - "#dfe4e4", - "#dee3e3", - "#dde2e2", - "#dce2e1", - "#dbe1e0", - "#dae0df", - "#d9dfdf", - "#d8dede", - "#d7dedd", - "#d6dddc", - "#d5dcdb", - "#d4dbdb", - 
"#d3dada", - "#d2dad9", - "#d1d9d8", - "#d1d8d7", - "#d0d7d6", - "#cfd6d6", - "#ced6d5", - "#cdd5d4", - "#ccd4d3", - "#cbd3d2", - "#cad2d2", - "#c9d2d1", - "#c8d1d0", - "#c7d0cf", - "#c6cfce", - "#c5cece", - "#c4cecd", - "#c3cdcc", - "#c2cccb", - "#c1cbca", - "#c0cbca", - "#bfcac9", - "#bec9c8", - "#bec8c7", - "#bdc8c7", - "#bcc7c6", - "#bbc6c5", - "#bac5c4", - "#b9c5c3", - "#b8c4c3", - "#b7c3c2", - "#b6c2c1", - "#b5c2c0", - "#b4c1c0", - "#b3c0bf", - "#b2bfbe", - "#b2bfbd", - "#b1bebd", - "#b0bdbc", - "#afbcbb", - "#aebcba", - "#adbbba", - "#acbab9", - "#abbab8", - "#aab9b7", - "#a9b8b7", - "#a9b7b6", - "#a8b7b5", - "#a7b6b4", - "#a6b5b4", - "#a5b4b3", - "#a4b4b2", - "#a3b3b2", - "#a2b2b1", - "#a1b2b0", - "#a1b1af", - "#a0b0af", - "#9fb0ae", - "#9eafad", - "#9daeac", - "#9cadac", - "#9badab", - "#9aacaa", - "#99abaa", - "#99aba9", - "#98aaa8", - "#97a9a7", - "#96a9a7", - "#95a8a6", - "#94a7a5", - "#93a6a5", - "#92a6a4", - "#92a5a3", - "#91a4a2", - "#90a4a2", - "#8fa3a1", - "#8ea2a0", - "#8da2a0", - "#8ca19f", - "#8ca09e", - "#8ba09e", - "#8a9f9d", - "#899e9c", - "#889e9c", - "#879d9b", - "#869c9a", - "#869c9a", - "#859b99", - "#849a98", - "#839a97", - "#829997", - "#819896", - "#809895", - "#809795", - "#7f9694", - "#7e9693", - "#7d9593", - "#7c9492", - "#7b9491", - "#7a9391", - "#7a9290", - "#79928f", - "#78918f", - "#77908e", - "#76908d", - "#758f8d", - "#758f8c", - "#748e8b", - "#738d8b", - "#728d8a", - "#718c89", - "#708b89", - "#708b88", - "#6f8a87", - "#6e8987", - "#6d8986", - "#6c8885", - "#6b8885", - "#6a8784", - "#6a8684", - "#698683", - "#688582", - "#678482", - "#668481", - "#658380", - "#658280", - "#64827f", - "#63817e", - "#62817e", - "#61807d", - "#607f7c", - "#607f7c", - "#5f7e7b", - "#5e7d7b", - "#5d7d7a", - "#5c7c79", - "#5b7c79", - "#5b7b78", - "#5a7a77", - "#597a77", - "#587976", - "#577975", - "#567875", - "#567774", - "#557774", - "#547673", - "#537572", - "#527572", - "#517471", - "#507470", - "#507370", - "#4f726f", - "#4e726f", - "#4d716e", 
- "#4c716d", - "#4b706d", - "#4b6f6c", - "#4a6f6b", - "#496e6b", - "#486e6a", - "#476d6a", - "#466c69", - "#456c68", - "#446b68", - "#446b67", - "#436a67", - "#426966", - "#416965", - "#406865", - "#3f6864", - "#3e6763", - "#3e6663", - "#3d6662", - "#3c6562", - "#3b6561", - "#3a6460", - "#396360", - "#38635f", - "#37625f", - "#36625e", - "#35615d", - "#35605d", - "#34605c", - "#335f5c", - "#325f5b", - "#315e5a", - "#305d5a", - "#2f5d59", - "#2e5c58", - "#2d5c58", - "#2c5b57", - "#2b5a57", - "#2a5a56", - "#295955", - "#285955", - "#275854", - "#265754", - "#255753", - "#245652", - "#235652", - "#225551", - "#215551", - "#205450", - "#1e534f", - "#1d534f", - "#1c524e", - "#1b524e", - "#1a514d", - "#18504c", - "#17504c", - "#164f4b", - "#144f4b", - "#134e4a", - ], - "flex_green_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfb", - "#f9fafa", - "#f8f9f9", - "#f7f8f8", - "#f6f7f7", - "#f5f6f6", - "#f4f5f5", - "#f3f5f3", - "#f2f4f2", - "#f1f3f1", - "#f0f2f0", - "#eff1ef", - "#eef0ee", - "#ecefed", - "#ebeeec", - "#eaedeb", - "#e9ecea", - "#e8ebe9", - "#e7eae8", - "#e6eae7", - "#e5e9e6", - "#e4e8e5", - "#e3e7e4", - "#e2e6e3", - "#e1e5e2", - "#e0e4e1", - "#dfe3e0", - "#dee3df", - "#dde2de", - "#dce1dd", - "#dbe0dc", - "#dadfdb", - "#d8deda", - "#d7ddd9", - "#d6dcd8", - "#d5dcd7", - "#d4dbd6", - "#d3dad5", - "#d2d9d4", - "#d1d8d3", - "#d0d7d2", - "#cfd6d1", - "#ced6d0", - "#cdd5cf", - "#ccd4ce", - "#cbd3ce", - "#cad2cd", - "#c9d1cc", - "#c8d1cb", - "#c7d0ca", - "#c6cfc9", - "#c5cec8", - "#c4cdc7", - "#c3cdc6", - "#c2ccc5", - "#c1cbc4", - "#c0cac3", - "#bfc9c2", - "#bec9c1", - "#bdc8c0", - "#bcc7bf", - "#bbc6be", - "#bac5bd", - "#b9c5bd", - "#b9c4bc", - "#b8c3bb", - "#b7c2ba", - "#b6c1b9", - "#b5c1b8", - "#b4c0b7", - "#b3bfb6", - "#b2beb5", - "#b1bdb4", - "#b0bdb3", - "#afbcb2", - "#aebbb1", - "#adbab1", - "#acbab0", - "#abb9af", - "#aab8ae", - "#a9b7ad", - "#a8b7ac", - "#a7b6ab", - "#a6b5aa", - "#a5b4a9", - "#a5b4a8", - "#a4b3a8", - "#a3b2a7", - 
"#a2b1a6", - "#a1b1a5", - "#a0b0a4", - "#9fafa3", - "#9eaea2", - "#9daea1", - "#9cada0", - "#9baca0", - "#9aab9f", - "#99ab9e", - "#99aa9d", - "#98a99c", - "#97a89b", - "#96a89a", - "#95a799", - "#94a699", - "#93a598", - "#92a597", - "#91a496", - "#90a395", - "#90a394", - "#8fa293", - "#8ea193", - "#8da092", - "#8ca091", - "#8b9f90", - "#8a9e8f", - "#899e8e", - "#889d8d", - "#879c8d", - "#879b8c", - "#869b8b", - "#859a8a", - "#849989", - "#839988", - "#829888", - "#819787", - "#809786", - "#809685", - "#7f9584", - "#7e9483", - "#7d9483", - "#7c9382", - "#7b9281", - "#7a9280", - "#79917f", - "#79907e", - "#78907e", - "#778f7d", - "#768e7c", - "#758e7b", - "#748d7a", - "#738c79", - "#728c79", - "#728b78", - "#718a77", - "#708a76", - "#6f8975", - "#6e8875", - "#6d8774", - "#6c8773", - "#6c8672", - "#6b8571", - "#6a8571", - "#698470", - "#68836f", - "#67836e", - "#66826d", - "#66816d", - "#65816c", - "#64806b", - "#637f6a", - "#627f69", - "#617e69", - "#607d68", - "#607d67", - "#5f7c66", - "#5e7c65", - "#5d7b65", - "#5c7a64", - "#5b7a63", - "#5a7962", - "#5a7861", - "#597861", - "#587760", - "#57765f", - "#56765e", - "#55755d", - "#55745d", - "#54745c", - "#53735b", - "#52725a", - "#51725a", - "#507159", - "#4f7058", - "#4f7057", - "#4e6f56", - "#4d6e56", - "#4c6e55", - "#4b6d54", - "#4a6d53", - "#4a6c53", - "#496b52", - "#486b51", - "#476a50", - "#466950", - "#45694f", - "#44684e", - "#44674d", - "#43674c", - "#42664c", - "#41654b", - "#40654a", - "#3f6449", - "#3e6449", - "#3e6348", - "#3d6247", - "#3c6246", - "#3b6146", - "#3a6045", - "#396044", - "#385f43", - "#385e43", - "#375e42", - "#365d41", - "#355c40", - "#345c40", - "#335b3f", - "#325b3e", - "#315a3d", - "#30593d", - "#30593c", - "#2f583b", - "#2e573a", - "#2d573a", - "#2c5639", - "#2b5538", - "#2a5537", - "#295437", - "#285436", - "#275335", - "#265234", - "#265234", - "#255133", - "#245032", - "#235031", - "#224f31", - "#214e30", - "#204e2f", - "#1f4d2e", - "#1e4c2e", - "#1d4c2d", - "#1c4b2c", - "#1b4b2b", 
- "#1a4a2b", - "#18492a", - "#174929", - "#164828", - "#154728", - "#144727", - "#134626", - "#124525", - "#104525", - "#0f4424", - ], - "flex_blue_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fbfcfc", - "#fafbfb", - "#f9fafa", - "#f8f9f9", - "#f7f7f8", - "#f6f6f8", - "#f4f5f7", - "#f3f4f6", - "#f2f3f5", - "#f1f2f4", - "#f0f1f3", - "#eff0f2", - "#eeeff1", - "#eceef0", - "#ebedf0", - "#eaecef", - "#e9ebee", - "#e8eaed", - "#e7e9ec", - "#e6e8eb", - "#e4e7ea", - "#e3e6ea", - "#e2e5e9", - "#e1e4e8", - "#e0e3e7", - "#dfe2e6", - "#dee1e5", - "#dde0e5", - "#dcdfe4", - "#dadee3", - "#d9dde2", - "#d8dce1", - "#d7dbe0", - "#d6dae0", - "#d5d9df", - "#d4d8de", - "#d3d7dd", - "#d2d6dc", - "#d1d5dc", - "#d0d4db", - "#cfd3da", - "#ced2d9", - "#ccd1d9", - "#cbd0d8", - "#cacfd7", - "#c9ced6", - "#c8cdd5", - "#c7ccd5", - "#c6ccd4", - "#c5cbd3", - "#c4cad2", - "#c3c9d2", - "#c2c8d1", - "#c1c7d0", - "#c0c6cf", - "#bfc5cf", - "#bec4ce", - "#bdc3cd", - "#bcc2cd", - "#bbc1cc", - "#bac0cb", - "#b9c0ca", - "#b8bfca", - "#b7bec9", - "#b6bdc8", - "#b5bcc7", - "#b4bbc7", - "#b3bac6", - "#b2b9c5", - "#b1b8c5", - "#b0b7c4", - "#afb7c3", - "#aeb6c3", - "#adb5c2", - "#acb4c1", - "#abb3c1", - "#aab2c0", - "#a9b1bf", - "#a8b0be", - "#a7b0be", - "#a6afbd", - "#a5aebc", - "#a4adbc", - "#a3acbb", - "#a2abba", - "#a1aaba", - "#a0aab9", - "#9fa9b8", - "#9ea8b8", - "#9da7b7", - "#9ca6b7", - "#9ba5b6", - "#9aa4b5", - "#99a4b5", - "#98a3b4", - "#97a2b3", - "#96a1b3", - "#95a0b2", - "#949fb1", - "#939fb1", - "#929eb0", - "#919db0", - "#909caf", - "#8f9bae", - "#8f9aae", - "#8e9aad", - "#8d99ac", - "#8c98ac", - "#8b97ab", - "#8a96ab", - "#8996aa", - "#8895a9", - "#8794a9", - "#8693a8", - "#8592a8", - "#8492a7", - "#8391a6", - "#8290a6", - "#818fa5", - "#818ea5", - "#808ea4", - "#7f8da3", - "#7e8ca3", - "#7d8ba2", - "#7c8aa2", - "#7b8aa1", - "#7a89a1", - "#7988a0", - "#78879f", - "#77869f", - "#76869e", - "#76859e", - "#75849d", - "#74839d", - "#73829c", - "#72829b", - "#71819b", - "#70809a", - 
"#6f7f9a", - "#6e7f99", - "#6d7e99", - "#6c7d98", - "#6c7c98", - "#6b7c97", - "#6a7b97", - "#697a96", - "#687995", - "#677895", - "#667894", - "#657794", - "#647693", - "#637593", - "#637592", - "#627492", - "#617391", - "#607291", - "#5f7290", - "#5e7190", - "#5d708f", - "#5c6f8f", - "#5b6f8e", - "#5b6e8e", - "#5a6d8d", - "#596c8c", - "#586b8c", - "#576b8b", - "#566a8b", - "#55698a", - "#54688a", - "#536889", - "#536789", - "#526688", - "#516588", - "#506587", - "#4f6487", - "#4e6386", - "#4d6286", - "#4c6285", - "#4b6185", - "#4a6084", - "#4a5f84", - "#495f83", - "#485e83", - "#475d83", - "#465d82", - "#455c82", - "#445b81", - "#435a81", - "#425a80", - "#425980", - "#41587f", - "#40577f", - "#3f577e", - "#3e567e", - "#3d557d", - "#3c547d", - "#3b547c", - "#3a537c", - "#39527b", - "#39517b", - "#38517b", - "#37507a", - "#364f7a", - "#354e79", - "#344e79", - "#334d78", - "#324c78", - "#314b77", - "#304b77", - "#2f4a76", - "#2e4976", - "#2d4876", - "#2c4875", - "#2c4775", - "#2b4674", - "#2a4574", - "#294473", - "#284473", - "#274373", - "#264272", - "#254172", - "#244171", - "#234071", - "#223f70", - "#213e70", - "#203e70", - "#1f3d6f", - "#1e3c6f", - "#1d3b6e", - "#1c3a6e", - "#1b3a6e", - "#1a396d", - "#19386d", - "#17376c", - "#16366c", - "#15366c", - "#14356b", - "#13346b", - "#12336b", - "#10326a", - "#0f326a", - "#0e316a", - "#0d3069", - "#0b2f69", - "#0a2e68", - "#082d68", - "#072c68", - "#062c68", - "#042b67", - "#032a67", - "#022967", - "#012866", - "#002766", - ], - "flex_orange_seq": [ - "#ffffff", - "#fefefe", - "#fefdfd", - "#fdfdfc", - "#fdfcfb", - "#fcfbfa", - "#fbfafa", - "#fbf9f9", - "#faf9f8", - "#faf8f7", - "#f9f7f6", - "#f8f6f5", - "#f8f6f4", - "#f7f5f3", - "#f7f4f2", - "#f6f3f1", - "#f5f2f1", - "#f5f2f0", - "#f4f1ef", - "#f3f0ee", - "#f3efed", - "#f2efec", - "#f2eeeb", - "#f1edea", - "#f1ece9", - "#f0ece8", - "#f0ebe7", - "#efeae6", - "#efe9e5", - "#eee9e4", - "#eee8e3", - "#ede7e2", - "#ede6e1", - "#ece5e0", - "#ece5df", - "#ebe4de", - 
"#ebe3dd", - "#eae2dc", - "#eae2db", - "#e9e1da", - "#e9e0d9", - "#e9dfd8", - "#e8dfd7", - "#e8ded6", - "#e7ddd5", - "#e7dcd4", - "#e6dbd3", - "#e6dbd2", - "#e6dad1", - "#e5d9d0", - "#e5d8cf", - "#e4d8ce", - "#e4d7cd", - "#e3d6cc", - "#e3d5cb", - "#e3d5ca", - "#e2d4c9", - "#e2d3c8", - "#e1d2c7", - "#e1d2c6", - "#e0d1c5", - "#e0d0c4", - "#e0cfc3", - "#dfcfc2", - "#dfcec1", - "#decdc0", - "#deccbf", - "#deccbe", - "#ddcbbd", - "#ddcabc", - "#dcc9bb", - "#dcc9ba", - "#dcc8b9", - "#dbc7b8", - "#dbc6b8", - "#dbc6b7", - "#dac5b6", - "#dac4b5", - "#d9c4b4", - "#d9c3b3", - "#d9c2b2", - "#d8c1b1", - "#d8c1b0", - "#d7c0af", - "#d7bfae", - "#d7bead", - "#d6beac", - "#d6bdab", - "#d6bcaa", - "#d5bba9", - "#d5bba8", - "#d4baa7", - "#d4b9a6", - "#d4b9a5", - "#d3b8a4", - "#d3b7a3", - "#d3b6a2", - "#d2b6a1", - "#d2b5a0", - "#d2b49f", - "#d1b49e", - "#d1b39d", - "#d0b29c", - "#d0b19b", - "#d0b19a", - "#cfb099", - "#cfaf99", - "#cfaf98", - "#ceae97", - "#cead96", - "#ceac95", - "#cdac94", - "#cdab93", - "#cdaa92", - "#ccaa91", - "#cca990", - "#cca88f", - "#cba78e", - "#cba78d", - "#caa68c", - "#caa58b", - "#caa58a", - "#c9a489", - "#c9a388", - "#c9a387", - "#c8a286", - "#c8a185", - "#c8a085", - "#c7a084", - "#c79f83", - "#c79e82", - "#c69e81", - "#c69d80", - "#c69c7f", - "#c59c7e", - "#c59b7d", - "#c59a7c", - "#c4997b", - "#c4997a", - "#c49879", - "#c39778", - "#c39777", - "#c29676", - "#c29575", - "#c29575", - "#c19474", - "#c19373", - "#c19372", - "#c09271", - "#c09170", - "#c0906f", - "#bf906e", - "#bf8f6d", - "#bf8e6c", - "#be8e6b", - "#be8d6a", - "#be8c69", - "#bd8c68", - "#bd8b67", - "#bd8a67", - "#bc8a66", - "#bc8965", - "#bc8864", - "#bb8863", - "#bb8762", - "#bb8661", - "#ba8660", - "#ba855f", - "#ba845e", - "#b9835d", - "#b9835c", - "#b8825b", - "#b8815b", - "#b8815a", - "#b78059", - "#b77f58", - "#b77f57", - "#b67e56", - "#b67d55", - "#b67d54", - "#b57c53", - "#b57b52", - "#b57b51", - "#b47a50", - "#b4794f", - "#b4794f", - "#b3784e", - "#b3774d", - "#b3774c", - "#b2764b", 
- "#b2754a", - "#b17549", - "#b17448", - "#b17347", - "#b07346", - "#b07245", - "#b07144", - "#af7144", - "#af7043", - "#af6f42", - "#ae6f41", - "#ae6e40", - "#ae6d3f", - "#ad6d3e", - "#ad6c3d", - "#ac6b3c", - "#ac6b3b", - "#ac6a3a", - "#ab6939", - "#ab6939", - "#ab6838", - "#aa6737", - "#aa6736", - "#aa6635", - "#a96534", - "#a96533", - "#a86432", - "#a86331", - "#a86330", - "#a7622f", - "#a7612e", - "#a7612d", - "#a6602c", - "#a65f2b", - "#a55f2a", - "#a55e2a", - "#a55d29", - "#a45d28", - "#a45c27", - "#a35b26", - "#a35b25", - "#a35a24", - "#a25923", - "#a25922", - "#a25821", - "#a15720", - "#a1571f", - "#a0561e", - "#a0551d", - "#a0551c", - "#9f541b", - "#9f531a", - "#9e5318", - "#9e5217", - "#9e5116", - "#9d5115", - "#9d5014", - "#9c4f13", - "#9c4f12", - "#9b4e10", - "#9b4d0f", - "#9b4d0e", - "#9a4c0c", - "#9a4b0b", - "#994b09", - "#994a08", - ], - "flex_red_seq": [ - "#ffffff", - "#fefefe", - "#fefdfd", - "#fdfcfc", - "#fcfbfb", - "#fcfafa", - "#fbf9f9", - "#faf8f8", - "#faf7f7", - "#f9f6f6", - "#f8f5f5", - "#f8f4f5", - "#f7f3f4", - "#f6f2f3", - "#f5f2f2", - "#f5f1f1", - "#f4f0f0", - "#f3efef", - "#f3eeee", - "#f2eded", - "#f1ecec", - "#f1ebec", - "#f0eaeb", - "#efe9ea", - "#efe8e9", - "#eee7e8", - "#eee6e7", - "#ede5e6", - "#ece4e6", - "#ece3e5", - "#ebe2e4", - "#ebe1e3", - "#eae0e2", - "#eae0e1", - "#e9dfe0", - "#e9dedf", - "#e8dddf", - "#e8dcde", - "#e7dbdd", - "#e7dadc", - "#e6d9db", - "#e6d8da", - "#e5d7d9", - "#e5d6d8", - "#e4d5d7", - "#e4d4d7", - "#e3d3d6", - "#e3d2d5", - "#e2d1d4", - "#e2d0d3", - "#e1d0d2", - "#e1cfd1", - "#e0ced1", - "#e0cdd0", - "#dfcccf", - "#dfcbce", - "#decacd", - "#dec9cc", - "#ddc8cb", - "#ddc7cb", - "#dcc6ca", - "#dcc5c9", - "#dbc4c8", - "#dbc4c7", - "#dbc3c6", - "#dac2c5", - "#dac1c5", - "#d9c0c4", - "#d9bfc3", - "#d8bec2", - "#d8bdc1", - "#d7bcc0", - "#d7bbc0", - "#d7babf", - "#d6babe", - "#d6b9bd", - "#d5b8bc", - "#d5b7bb", - "#d4b6bb", - "#d4b5ba", - "#d4b4b9", - "#d3b3b8", - "#d3b2b7", - "#d2b1b6", - "#d2b0b6", - 
"#d1b0b5", - "#d1afb4", - "#d1aeb3", - "#d0adb2", - "#d0acb1", - "#cfabb1", - "#cfaab0", - "#cfa9af", - "#cea8ae", - "#cea8ad", - "#cda7ad", - "#cda6ac", - "#cca5ab", - "#cca4aa", - "#cca3a9", - "#cba2a9", - "#cba1a8", - "#caa0a7", - "#caa0a6", - "#ca9fa5", - "#c99ea5", - "#c99da4", - "#c89ca3", - "#c89ba2", - "#c89aa1", - "#c799a1", - "#c799a0", - "#c6989f", - "#c6979e", - "#c6969d", - "#c5959d", - "#c5949c", - "#c4939b", - "#c4929a", - "#c49299", - "#c39199", - "#c39098", - "#c28f97", - "#c28e96", - "#c28d96", - "#c18c95", - "#c18c94", - "#c18b93", - "#c08a92", - "#c08992", - "#bf8891", - "#bf8790", - "#bf868f", - "#be858f", - "#be858e", - "#bd848d", - "#bd838c", - "#bd828c", - "#bc818b", - "#bc808a", - "#bb7f89", - "#bb7f88", - "#bb7e88", - "#ba7d87", - "#ba7c86", - "#b97b85", - "#b97a85", - "#b97984", - "#b87983", - "#b87882", - "#b87782", - "#b77681", - "#b77580", - "#b6747f", - "#b6737f", - "#b6737e", - "#b5727d", - "#b5717c", - "#b4707c", - "#b46f7b", - "#b46e7a", - "#b36d79", - "#b36d79", - "#b26c78", - "#b26b77", - "#b26a76", - "#b16976", - "#b16875", - "#b06774", - "#b06773", - "#b06673", - "#af6572", - "#af6471", - "#ae6371", - "#ae6270", - "#ae616f", - "#ad616e", - "#ad606e", - "#ac5f6d", - "#ac5e6c", - "#ab5d6b", - "#ab5c6b", - "#ab5b6a", - "#aa5b69", - "#aa5a69", - "#a95968", - "#a95867", - "#a95766", - "#a85666", - "#a85565", - "#a75464", - "#a75463", - "#a65363", - "#a65262", - "#a65161", - "#a55061", - "#a54f60", - "#a44e5f", - "#a44d5e", - "#a34d5e", - "#a34c5d", - "#a34b5c", - "#a24a5c", - "#a2495b", - "#a1485a", - "#a1475a", - "#a04659", - "#a04558", - "#9f4557", - "#9f4457", - "#9f4356", - "#9e4255", - "#9e4155", - "#9d4054", - "#9d3f53", - "#9c3e52", - "#9c3d52", - "#9b3c51", - "#9b3b50", - "#9a3b50", - "#9a3a4f", - "#99394e", - "#99384e", - "#98374d", - "#98364c", - "#98354b", - "#97344b", - "#97334a", - "#963249", - "#963149", - "#953048", - "#952f47", - "#942e47", - "#942d46", - "#932c45", - "#932b45", - "#922a44", - "#922943", - "#912843", 
- "#912742", - "#902641", - "#902540", - "#8f2440", - "#8e223f", - "#8e213e", - "#8d203e", - "#8d1f3d", - "#8c1e3c", - "#8c1d3c", - "#8b1b3b", - "#8b1a3a", - "#8a193a", - "#8a1739", - "#891638", - "#891438", - "#881337", - ], - "flex_purple_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfd", - "#fbfbfc", - "#fafafb", - "#f9f9fa", - "#f8f8f9", - "#f7f7f9", - "#f6f6f8", - "#f5f5f7", - "#f4f4f6", - "#f3f3f6", - "#f2f2f5", - "#f1f1f4", - "#f0f0f3", - "#efeff3", - "#eeeef2", - "#ededf1", - "#ececf0", - "#ebebf0", - "#eaeaef", - "#e9e9ee", - "#e8e8ed", - "#e8e8ed", - "#e7e7ec", - "#e6e6eb", - "#e5e5eb", - "#e4e4ea", - "#e3e3e9", - "#e2e2e8", - "#e1e1e8", - "#e0e0e7", - "#dfdfe6", - "#dedee6", - "#dddde5", - "#dcdce4", - "#dbdbe4", - "#dadae3", - "#d9dae2", - "#d9d9e2", - "#d8d8e1", - "#d7d7e0", - "#d6d6e0", - "#d5d5df", - "#d4d4df", - "#d3d3de", - "#d2d2dd", - "#d1d1dd", - "#d0d1dc", - "#d0d0db", - "#cfcfdb", - "#ceceda", - "#cdcdd9", - "#ccccd9", - "#cbcbd8", - "#cacad8", - "#c9c9d7", - "#c8c9d6", - "#c8c8d6", - "#c7c7d5", - "#c6c6d5", - "#c5c5d4", - "#c4c4d3", - "#c3c3d3", - "#c2c2d2", - "#c2c2d2", - "#c1c1d1", - "#c0c0d1", - "#bfbfd0", - "#bebecf", - "#bdbdcf", - "#bcbcce", - "#bcbcce", - "#bbbbcd", - "#babacd", - "#b9b9cc", - "#b8b8cc", - "#b7b7cb", - "#b7b7ca", - "#b6b6ca", - "#b5b5c9", - "#b4b4c9", - "#b3b3c8", - "#b3b2c8", - "#b2b2c7", - "#b1b1c7", - "#b0b0c6", - "#afafc6", - "#aeaec5", - "#aeadc5", - "#adadc4", - "#acacc4", - "#ababc3", - "#aaaac3", - "#aaa9c2", - "#a9a8c2", - "#a8a8c1", - "#a7a7c1", - "#a6a6c0", - "#a6a5c0", - "#a5a4bf", - "#a4a4bf", - "#a3a3be", - "#a3a2be", - "#a2a1bd", - "#a1a0bd", - "#a0a0bc", - "#9f9fbc", - "#9f9ebb", - "#9e9dbb", - "#9d9cba", - "#9c9cba", - "#9c9bb9", - "#9b9ab9", - "#9a99b8", - "#9998b8", - "#9998b8", - "#9897b7", - "#9796b7", - "#9695b6", - "#9694b6", - "#9594b5", - "#9493b5", - "#9392b4", - "#9391b4", - "#9291b4", - "#9190b3", - "#908fb3", - "#908eb2", - "#8f8db2", - "#8e8db1", - "#8d8cb1", - "#8d8bb1", - 
"#8c8ab0", - "#8b8ab0", - "#8a89af", - "#8a88af", - "#8987ae", - "#8886ae", - "#8886ae", - "#8785ad", - "#8684ad", - "#8583ac", - "#8583ac", - "#8482ac", - "#8381ab", - "#8280ab", - "#8280ab", - "#817faa", - "#807eaa", - "#807da9", - "#7f7ca9", - "#7e7ca9", - "#7e7ba8", - "#7d7aa8", - "#7c79a8", - "#7b79a7", - "#7b78a7", - "#7a77a6", - "#7976a6", - "#7976a6", - "#7875a5", - "#7774a5", - "#7773a5", - "#7673a4", - "#7572a4", - "#7571a4", - "#7470a3", - "#736fa3", - "#736fa3", - "#726ea2", - "#716da2", - "#716ca2", - "#706ca1", - "#6f6ba1", - "#6f6aa1", - "#6e69a0", - "#6d69a0", - "#6d68a0", - "#6c679f", - "#6b669f", - "#6b669f", - "#6a659e", - "#69649e", - "#69639e", - "#68629d", - "#67629d", - "#67619d", - "#66609d", - "#655f9c", - "#655f9c", - "#645e9c", - "#635d9b", - "#635c9b", - "#625c9b", - "#625b9b", - "#615a9a", - "#60599a", - "#60589a", - "#5f589a", - "#5e5799", - "#5e5699", - "#5d5599", - "#5d5498", - "#5c5498", - "#5b5398", - "#5b5298", - "#5a5198", - "#595097", - "#595097", - "#584f97", - "#584e97", - "#574d96", - "#564c96", - "#564c96", - "#554b96", - "#554a95", - "#544995", - "#544895", - "#534795", - "#524795", - "#524694", - "#514594", - "#514494", - "#504394", - "#504294", - "#4f4294", - "#4e4193", - "#4e4093", - "#4d3f93", - "#4d3e93", - "#4c3d93", - "#4c3c93", - "#4b3b93", - "#4b3a92", - "#4a3992", - "#4a3892", - "#493892", - "#493792", - "#483692", - "#483592", - "#473492", - "#473391", - "#463291", - "#463191", - "#452f91", - "#452e91", - "#442d91", - "#442c91", - "#432b91", - "#432a91", - "#422991", - "#422891", - "#412691", - "#412591", - ], - "flex_grey_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfc", - "#fafafb", - "#f9f9fa", - "#f8f9f9", - "#f8f8f8", - "#f7f7f7", - "#f6f6f6", - "#f5f5f6", - "#f4f4f5", - "#f3f3f4", - "#f2f2f3", - "#f1f1f2", - "#f0f0f1", - "#eff0f1", - "#eeeff0", - "#eeeeef", - "#ededee", - "#ececed", - "#ebebec", - "#eaeaec", - "#e9e9eb", - "#e8e9ea", - "#e7e8e9", - "#e6e7e8", - "#e6e6e8", - "#e5e5e7", 
- "#e4e4e6", - "#e3e3e5", - "#e2e3e4", - "#e1e2e4", - "#e0e1e3", - "#dfe0e2", - "#dfdfe1", - "#dedee0", - "#dddde0", - "#dcdddf", - "#dbdcde", - "#dadbdd", - "#d9dadd", - "#d9d9dc", - "#d8d8db", - "#d7d8da", - "#d6d7da", - "#d5d6d9", - "#d4d5d8", - "#d4d4d7", - "#d3d4d6", - "#d2d3d6", - "#d1d2d5", - "#d0d1d4", - "#cfd0d3", - "#cfcfd3", - "#cecfd2", - "#cdced1", - "#cccdd0", - "#cbccd0", - "#cacbcf", - "#cacbce", - "#c9cace", - "#c8c9cd", - "#c7c8cc", - "#c6c8cb", - "#c6c7cb", - "#c5c6ca", - "#c4c5c9", - "#c3c4c8", - "#c2c4c8", - "#c2c3c7", - "#c1c2c6", - "#c0c1c6", - "#bfc0c5", - "#bec0c4", - "#bebfc3", - "#bdbec3", - "#bcbdc2", - "#bbbdc1", - "#babcc1", - "#babbc0", - "#b9babf", - "#b8babe", - "#b7b9be", - "#b7b8bd", - "#b6b7bc", - "#b5b7bc", - "#b4b6bb", - "#b3b5ba", - "#b3b4ba", - "#b2b4b9", - "#b1b3b8", - "#b0b2b8", - "#b0b1b7", - "#afb1b6", - "#aeb0b6", - "#adafb5", - "#adaeb4", - "#acaeb3", - "#abadb3", - "#aaacb2", - "#aaabb1", - "#a9abb1", - "#a8aab0", - "#a7a9af", - "#a7a8af", - "#a6a8ae", - "#a5a7ad", - "#a4a6ad", - "#a4a6ac", - "#a3a5ab", - "#a2a4ab", - "#a1a3aa", - "#a1a3aa", - "#a0a2a9", - "#9fa1a8", - "#9ea1a8", - "#9ea0a7", - "#9d9fa6", - "#9c9ea6", - "#9c9ea5", - "#9b9da4", - "#9a9ca4", - "#999ca3", - "#999ba2", - "#989aa2", - "#979aa1", - "#9799a0", - "#9698a0", - "#95979f", - "#94979f", - "#94969e", - "#93959d", - "#92959d", - "#92949c", - "#91939b", - "#90939b", - "#8f929a", - "#8f919a", - "#8e9199", - "#8d9098", - "#8d8f98", - "#8c8e97", - "#8b8e96", - "#8b8d96", - "#8a8c95", - "#898c95", - "#888b94", - "#888a93", - "#878a93", - "#868992", - "#868891", - "#858891", - "#848790", - "#848690", - "#83868f", - "#82858e", - "#82848e", - "#81848d", - "#80838d", - "#80828c", - "#7f828b", - "#7e818b", - "#7e808a", - "#7d808a", - "#7c7f89", - "#7b7e88", - "#7b7e88", - "#7a7d87", - "#797c87", - "#797c86", - "#787b85", - "#777a85", - "#777a84", - "#767984", - "#757983", - "#757882", - "#747782", - "#737781", - "#737681", - "#727580", - "#71757f", - 
"#71747f", - "#70737e", - "#70737e", - "#6f727d", - "#6e717d", - "#6e717c", - "#6d707b", - "#6c707b", - "#6c6f7a", - "#6b6e7a", - "#6a6e79", - "#6a6d79", - "#696c78", - "#686c77", - "#686b77", - "#676a76", - "#666a76", - "#666975", - "#656974", - "#646874", - "#646773", - "#636773", - "#636672", - "#626572", - "#616571", - "#616470", - "#606370", - "#5f636f", - "#5f626f", - "#5e626e", - "#5d616e", - "#5d606d", - "#5c606c", - "#5b5f6c", - "#5b5e6b", - "#5a5e6b", - "#5a5d6a", - "#595d6a", - "#585c69", - "#585b69", - "#575b68", - "#565a67", - "#565967", - "#555966", - "#555866", - "#545865", - "#535765", - "#535664", - "#525663", - "#515563", - "#515562", - "#505462", - "#4f5361", - "#4f5361", - "#4e5260", - "#4e5160", - "#4d515f", - "#4c505e", - "#4c505e", - "#4b4f5d", - "#4a4e5d", - "#4a4e5c", - "#494d5c", - "#494d5b", - "#484c5a", - "#474b5a", - "#474b59", - "#464a59", - "#454958", - "#454958", - "#444857", - "#444857", - "#434756", - ], -} -CATEGORICAL_PALETTES_HEX = { - "flex_distinct": [ - "#176737", - "#FF7B0D", - "#979BAA", - "#F44E6A", - "#0062FF", - "#26AB5B", - "#6D3EF2", - "#F59E0B", - ] -} -DIVERGING_PALETTES_HEX = { - "flex_BuRd": [ - "#002766", - "#022967", - "#052b67", - "#072d68", - "#0a2e69", - "#0d3069", - "#10326a", - "#12346b", - "#15356c", - "#17376c", - "#1a396d", - "#1c3a6e", - "#1e3c6f", - "#203e70", - "#223f71", - "#244171", - "#264372", - "#284473", - "#2a4674", - "#2c4775", - "#2e4976", - "#304a77", - "#324c78", - "#344e79", - "#364f7a", - "#38517b", - "#3a527c", - "#3c547d", - "#3e557e", - "#3f577f", - "#415980", - "#435a80", - "#455c81", - "#475d83", - "#495f84", - "#4b6085", - "#4c6286", - "#4e6387", - "#506588", - "#526789", - "#54688a", - "#566a8b", - "#586b8c", - "#5a6d8d", - "#5b6e8e", - "#5d708f", - "#5f7290", - "#617391", - "#637592", - "#657694", - "#677895", - "#687a96", - "#6a7b97", - "#6c7d98", - "#6e7e99", - "#70809a", - "#72829b", - "#74839d", - "#76859e", - "#78879f", - "#7a88a0", - "#7b8aa1", - "#7d8ca3", - "#7f8da4", - 
"#818fa5", - "#8391a6", - "#8592a8", - "#8794a9", - "#8996aa", - "#8b97ab", - "#8d99ad", - "#8f9bae", - "#919daf", - "#939eb1", - "#95a0b2", - "#97a2b3", - "#99a4b4", - "#9ba5b6", - "#9da7b7", - "#9fa9b9", - "#a1abba", - "#a3acbb", - "#a5aebd", - "#a7b0be", - "#a9b2c0", - "#abb4c1", - "#adb6c2", - "#afb7c4", - "#b1b9c5", - "#b4bbc7", - "#b6bdc8", - "#b8bfca", - "#bac1cb", - "#bcc3cd", - "#bec5ce", - "#c0c6d0", - "#c3c8d1", - "#c5cad3", - "#c7ccd5", - "#c9ced6", - "#cbd0d8", - "#ced2d9", - "#d0d4db", - "#d2d6dd", - "#d4d8de", - "#d7dae0", - "#d9dce2", - "#dbdee3", - "#dee0e5", - "#e0e3e7", - "#e2e5e9", - "#e4e7ea", - "#e7e9ec", - "#e9ebee", - "#ecedf0", - "#eeeff2", - "#f0f2f3", - "#f3f4f5", - "#f5f6f7", - "#f8f8f9", - "#fafafb", - "#fdfdfd", - "#FFFFFF", - "#fefdfd", - "#fcfbfb", - "#fbf9f9", - "#f9f7f7", - "#f8f5f5", - "#f6f3f3", - "#f5f1f1", - "#f4efef", - "#f2edee", - "#f1ebec", - "#efe9ea", - "#eee7e8", - "#ede5e6", - "#ece3e4", - "#ebe1e3", - "#e9dfe1", - "#e8dddf", - "#e7dbdd", - "#e6d9db", - "#e5d7d9", - "#e4d5d8", - "#e3d3d6", - "#e2d1d4", - "#e1cfd2", - "#e0cdd0", - "#dfcccf", - "#decacd", - "#ddc8cb", - "#dcc6c9", - "#dbc4c7", - "#dac2c6", - "#d9c0c4", - "#d8bec2", - "#d7bcc0", - "#d6babf", - "#d6b8bd", - "#d5b6bb", - "#d4b5b9", - "#d3b3b8", - "#d2b1b6", - "#d1afb4", - "#d0adb2", - "#cfabb1", - "#cfa9af", - "#cea8ad", - "#cda6ac", - "#cca4aa", - "#cba2a8", - "#caa0a7", - "#c99ea5", - "#c99ca3", - "#c89ba2", - "#c799a0", - "#c6979e", - "#c5959d", - "#c4939b", - "#c49199", - "#c39098", - "#c28e96", - "#c18c94", - "#c08a93", - "#c08891", - "#bf8790", - "#be858e", - "#bd838c", - "#bc818b", - "#bb7f89", - "#bb7e88", - "#ba7c86", - "#b97a84", - "#b87883", - "#b77681", - "#b67580", - "#b6737e", - "#b5717d", - "#b46f7b", - "#b36d79", - "#b26c78", - "#b26a76", - "#b16875", - "#b06673", - "#af6572", - "#ae6370", - "#ad616f", - "#ac5f6d", - "#ac5d6c", - "#ab5c6a", - "#aa5a69", - "#a95867", - "#a85666", - "#a75464", - "#a65263", - "#a55161", - "#a44f60", - "#a44d5e", 
- "#a34b5d", - "#a2495b", - "#a1475a", - "#a04658", - "#9f4457", - "#9e4255", - "#9d4054", - "#9c3e52", - "#9b3c51", - "#9a3a4f", - "#99384e", - "#98364c", - "#97344b", - "#96324a", - "#953048", - "#942e47", - "#932c45", - "#922a44", - "#912842", - "#902541", - "#8f233f", - "#8e213e", - "#8d1e3d", - "#8b1c3b", - "#8a193a", - "#891638", - "#881337", - ], - "flex_RdBu": [ - "#881337", - "#891638", - "#8a193a", - "#8b1c3b", - "#8d1e3d", - "#8e213e", - "#8f233f", - "#902541", - "#912842", - "#922a44", - "#932c45", - "#942e47", - "#953048", - "#96324a", - "#97344b", - "#98364c", - "#99384e", - "#9a3a4f", - "#9b3c51", - "#9c3e52", - "#9d4054", - "#9e4255", - "#9f4457", - "#a04658", - "#a1475a", - "#a2495b", - "#a34b5d", - "#a44d5e", - "#a44f60", - "#a55161", - "#a65263", - "#a75464", - "#a85666", - "#a95867", - "#aa5a69", - "#ab5c6a", - "#ac5d6c", - "#ac5f6d", - "#ad616f", - "#ae6370", - "#af6572", - "#b06673", - "#b16875", - "#b26a76", - "#b26c78", - "#b36d79", - "#b46f7b", - "#b5717d", - "#b6737e", - "#b67580", - "#b77681", - "#b87883", - "#b97a84", - "#ba7c86", - "#bb7e88", - "#bb7f89", - "#bc818b", - "#bd838c", - "#be858e", - "#bf8790", - "#c08891", - "#c08a93", - "#c18c94", - "#c28e96", - "#c39098", - "#c49199", - "#c4939b", - "#c5959d", - "#c6979e", - "#c799a0", - "#c89ba2", - "#c99ca3", - "#c99ea5", - "#caa0a7", - "#cba2a8", - "#cca4aa", - "#cda6ac", - "#cea8ad", - "#cfa9af", - "#cfabb1", - "#d0adb2", - "#d1afb4", - "#d2b1b6", - "#d3b3b8", - "#d4b5b9", - "#d5b6bb", - "#d6b8bd", - "#d6babf", - "#d7bcc0", - "#d8bec2", - "#d9c0c4", - "#dac2c6", - "#dbc4c7", - "#dcc6c9", - "#ddc8cb", - "#decacd", - "#dfcccf", - "#e0cdd0", - "#e1cfd2", - "#e2d1d4", - "#e3d3d6", - "#e4d5d8", - "#e5d7d9", - "#e6d9db", - "#e7dbdd", - "#e8dddf", - "#e9dfe1", - "#ebe1e3", - "#ece3e4", - "#ede5e6", - "#eee7e8", - "#efe9ea", - "#f1ebec", - "#f2edee", - "#f4efef", - "#f5f1f1", - "#f6f3f3", - "#f8f5f5", - "#f9f7f7", - "#fbf9f9", - "#fcfbfb", - "#fefdfd", - "#FFFFFF", - "#fdfdfd", - "#fafafb", - 
"#f8f8f9", - "#f5f6f7", - "#f3f4f5", - "#f0f2f3", - "#eeeff2", - "#ecedf0", - "#e9ebee", - "#e7e9ec", - "#e4e7ea", - "#e2e5e9", - "#e0e3e7", - "#dee0e5", - "#dbdee3", - "#d9dce2", - "#d7dae0", - "#d4d8de", - "#d2d6dd", - "#d0d4db", - "#ced2d9", - "#cbd0d8", - "#c9ced6", - "#c7ccd5", - "#c5cad3", - "#c3c8d1", - "#c0c6d0", - "#bec5ce", - "#bcc3cd", - "#bac1cb", - "#b8bfca", - "#b6bdc8", - "#b4bbc7", - "#b1b9c5", - "#afb7c4", - "#adb6c2", - "#abb4c1", - "#a9b2c0", - "#a7b0be", - "#a5aebd", - "#a3acbb", - "#a1abba", - "#9fa9b9", - "#9da7b7", - "#9ba5b6", - "#99a4b4", - "#97a2b3", - "#95a0b2", - "#939eb1", - "#919daf", - "#8f9bae", - "#8d99ad", - "#8b97ab", - "#8996aa", - "#8794a9", - "#8592a8", - "#8391a6", - "#818fa5", - "#7f8da4", - "#7d8ca3", - "#7b8aa1", - "#7a88a0", - "#78879f", - "#76859e", - "#74839d", - "#72829b", - "#70809a", - "#6e7e99", - "#6c7d98", - "#6a7b97", - "#687a96", - "#677895", - "#657694", - "#637592", - "#617391", - "#5f7290", - "#5d708f", - "#5b6e8e", - "#5a6d8d", - "#586b8c", - "#566a8b", - "#54688a", - "#526789", - "#506588", - "#4e6387", - "#4c6286", - "#4b6085", - "#495f84", - "#475d83", - "#455c81", - "#435a80", - "#415980", - "#3f577f", - "#3e557e", - "#3c547d", - "#3a527c", - "#38517b", - "#364f7a", - "#344e79", - "#324c78", - "#304a77", - "#2e4976", - "#2c4775", - "#2a4674", - "#284473", - "#264372", - "#244171", - "#223f71", - "#203e70", - "#1e3c6f", - "#1c3a6e", - "#1a396d", - "#17376c", - "#15356c", - "#12346b", - "#10326a", - "#0d3069", - "#0a2e69", - "#072d68", - "#052b67", - "#022967", - "#002766", - ], - "flex_GrPu": [ - "#0f4424", - "#124526", - "#144727", - "#174829", - "#19492a", - "#1b4b2c", - "#1d4c2d", - "#1f4d2f", - "#214f30", - "#235032", - "#255234", - "#275335", - "#295437", - "#2b5638", - "#2d573a", - "#2f583b", - "#315a3d", - "#335b3e", - "#355c40", - "#365e42", - "#385f43", - "#3a6045", - "#3c6246", - "#3e6348", - "#3f644a", - "#41664b", - "#43674d", - "#45684e", - "#476a50", - "#486b52", - "#4a6c53", - "#4c6e55", - 
"#4e6f56", - "#4f7058", - "#51725a", - "#53735b", - "#55745d", - "#57765f", - "#587760", - "#5a7962", - "#5c7a63", - "#5e7b65", - "#5f7d67", - "#617e68", - "#637f6a", - "#65816c", - "#67826d", - "#68846f", - "#6a8571", - "#6c8672", - "#6e8874", - "#6f8976", - "#718b78", - "#738c79", - "#758e7b", - "#778f7d", - "#79907e", - "#7a9280", - "#7c9382", - "#7e9584", - "#809685", - "#829887", - "#849989", - "#859b8b", - "#879c8c", - "#899d8e", - "#8b9f90", - "#8da092", - "#8fa294", - "#91a395", - "#92a597", - "#94a699", - "#96a89b", - "#98aa9d", - "#9aab9e", - "#9cada0", - "#9eaea2", - "#a0b0a4", - "#a2b1a6", - "#a4b3a8", - "#a6b4aa", - "#a8b6ab", - "#aab8ad", - "#acb9af", - "#aebbb1", - "#b0bcb3", - "#b2beb5", - "#b4c0b7", - "#b6c1b9", - "#b8c3bb", - "#bac5bd", - "#bcc6bf", - "#bec8c1", - "#c0cac2", - "#c2cbc4", - "#c4cdc6", - "#c6cfc8", - "#c8d0ca", - "#cad2cc", - "#ccd4ce", - "#ced6d0", - "#d0d7d2", - "#d3d9d5", - "#d5dbd7", - "#d7ddd9", - "#d9dfdb", - "#dbe0dd", - "#dde2df", - "#dfe4e1", - "#e2e6e3", - "#e4e8e5", - "#e6eae7", - "#e8ebe9", - "#ebedeb", - "#edefee", - "#eff1f0", - "#f1f3f2", - "#f4f5f4", - "#f6f7f6", - "#f8f9f8", - "#fafbfb", - "#fdfdfd", - "#FFFFFF", - "#fdfdfd", - "#fbfbfc", - "#f9f9fa", - "#f7f7f8", - "#f5f5f7", - "#f3f3f5", - "#f1f0f4", - "#efeef2", - "#ececf0", - "#eaeaef", - "#e8e8ed", - "#e6e6ec", - "#e4e5ea", - "#e3e3e9", - "#e1e1e8", - "#dfdfe6", - "#dddde5", - "#dbdbe3", - "#d9d9e2", - "#d7d7e1", - "#d5d5df", - "#d3d3de", - "#d1d1dd", - "#cfd0db", - "#ceceda", - "#ccccd9", - "#cacad7", - "#c8c8d6", - "#c6c6d5", - "#c4c4d4", - "#c3c3d2", - "#c1c1d1", - "#bfbfd0", - "#bdbdcf", - "#bcbcce", - "#babacd", - "#b8b8cb", - "#b6b6ca", - "#b5b4c9", - "#b3b3c8", - "#b1b1c7", - "#afafc6", - "#aeaec5", - "#acacc4", - "#aaaac3", - "#a9a8c1", - "#a7a7c0", - "#a5a5bf", - "#a4a3be", - "#a2a2bd", - "#a1a0bc", - "#9f9ebb", - "#9d9dba", - "#9c9bb9", - "#9a99b8", - "#9898b8", - "#9796b7", - "#9594b6", - "#9493b5", - "#9291b4", - "#918fb3", - "#8f8eb2", - "#8e8cb1", 
- "#8c8ab0", - "#8b89af", - "#8987af", - "#8786ae", - "#8684ad", - "#8482ac", - "#8381ab", - "#817faa", - "#807eaa", - "#7f7ca9", - "#7d7aa8", - "#7c79a7", - "#7a77a6", - "#7976a6", - "#7774a5", - "#7672a4", - "#7471a4", - "#736fa3", - "#726ea2", - "#706ca1", - "#6f6aa1", - "#6d69a0", - "#6c679f", - "#6b669f", - "#69649e", - "#68629d", - "#67619d", - "#655f9c", - "#645e9c", - "#635c9b", - "#615a9a", - "#60599a", - "#5f5799", - "#5d5599", - "#5c5498", - "#5b5298", - "#595097", - "#584f97", - "#574d96", - "#564b96", - "#554a95", - "#534895", - "#524695", - "#514494", - "#504394", - "#4f4193", - "#4d3f93", - "#4c3d93", - "#4b3b92", - "#4a3992", - "#493792", - "#483592", - "#473392", - "#463191", - "#452f91", - "#442d91", - "#432a91", - "#422891", - "#412591", - ], - "flex_PuGr": [ - "#412591", - "#422891", - "#432a91", - "#442d91", - "#452f91", - "#463191", - "#473392", - "#483592", - "#493792", - "#4a3992", - "#4b3b92", - "#4c3d93", - "#4d3f93", - "#4f4193", - "#504394", - "#514494", - "#524695", - "#534895", - "#554a95", - "#564b96", - "#574d96", - "#584f97", - "#595097", - "#5b5298", - "#5c5498", - "#5d5599", - "#5f5799", - "#60599a", - "#615a9a", - "#635c9b", - "#645e9c", - "#655f9c", - "#67619d", - "#68629d", - "#69649e", - "#6b669f", - "#6c679f", - "#6d69a0", - "#6f6aa1", - "#706ca1", - "#726ea2", - "#736fa3", - "#7471a4", - "#7672a4", - "#7774a5", - "#7976a6", - "#7a77a6", - "#7c79a7", - "#7d7aa8", - "#7f7ca9", - "#807eaa", - "#817faa", - "#8381ab", - "#8482ac", - "#8684ad", - "#8786ae", - "#8987af", - "#8b89af", - "#8c8ab0", - "#8e8cb1", - "#8f8eb2", - "#918fb3", - "#9291b4", - "#9493b5", - "#9594b6", - "#9796b7", - "#9898b8", - "#9a99b8", - "#9c9bb9", - "#9d9dba", - "#9f9ebb", - "#a1a0bc", - "#a2a2bd", - "#a4a3be", - "#a5a5bf", - "#a7a7c0", - "#a9a8c1", - "#aaaac3", - "#acacc4", - "#aeaec5", - "#afafc6", - "#b1b1c7", - "#b3b3c8", - "#b5b4c9", - "#b6b6ca", - "#b8b8cb", - "#babacd", - "#bcbcce", - "#bdbdcf", - "#bfbfd0", - "#c1c1d1", - "#c3c3d2", - "#c4c4d4", - 
"#c6c6d5", - "#c8c8d6", - "#cacad7", - "#ccccd9", - "#ceceda", - "#cfd0db", - "#d1d1dd", - "#d3d3de", - "#d5d5df", - "#d7d7e1", - "#d9d9e2", - "#dbdbe3", - "#dddde5", - "#dfdfe6", - "#e1e1e8", - "#e3e3e9", - "#e4e5ea", - "#e6e6ec", - "#e8e8ed", - "#eaeaef", - "#ececf0", - "#efeef2", - "#f1f0f4", - "#f3f3f5", - "#f5f5f7", - "#f7f7f8", - "#f9f9fa", - "#fbfbfc", - "#fdfdfd", - "#FFFFFF", - "#fdfdfd", - "#fafbfb", - "#f8f9f8", - "#f6f7f6", - "#f4f5f4", - "#f1f3f2", - "#eff1f0", - "#edefee", - "#ebedeb", - "#e8ebe9", - "#e6eae7", - "#e4e8e5", - "#e2e6e3", - "#dfe4e1", - "#dde2df", - "#dbe0dd", - "#d9dfdb", - "#d7ddd9", - "#d5dbd7", - "#d3d9d5", - "#d0d7d2", - "#ced6d0", - "#ccd4ce", - "#cad2cc", - "#c8d0ca", - "#c6cfc8", - "#c4cdc6", - "#c2cbc4", - "#c0cac2", - "#bec8c1", - "#bcc6bf", - "#bac5bd", - "#b8c3bb", - "#b6c1b9", - "#b4c0b7", - "#b2beb5", - "#b0bcb3", - "#aebbb1", - "#acb9af", - "#aab8ad", - "#a8b6ab", - "#a6b4aa", - "#a4b3a8", - "#a2b1a6", - "#a0b0a4", - "#9eaea2", - "#9cada0", - "#9aab9e", - "#98aa9d", - "#96a89b", - "#94a699", - "#92a597", - "#91a395", - "#8fa294", - "#8da092", - "#8b9f90", - "#899d8e", - "#879c8c", - "#859b8b", - "#849989", - "#829887", - "#809685", - "#7e9584", - "#7c9382", - "#7a9280", - "#79907e", - "#778f7d", - "#758e7b", - "#738c79", - "#718b78", - "#6f8976", - "#6e8874", - "#6c8672", - "#6a8571", - "#68846f", - "#67826d", - "#65816c", - "#637f6a", - "#617e68", - "#5f7d67", - "#5e7b65", - "#5c7a63", - "#5a7962", - "#587760", - "#57765f", - "#55745d", - "#53735b", - "#51725a", - "#4f7058", - "#4e6f56", - "#4c6e55", - "#4a6c53", - "#486b52", - "#476a50", - "#45684e", - "#43674d", - "#41664b", - "#3f644a", - "#3e6348", - "#3c6246", - "#3a6045", - "#385f43", - "#365e42", - "#355c40", - "#335b3e", - "#315a3d", - "#2f583b", - "#2d573a", - "#2b5638", - "#295437", - "#275335", - "#255234", - "#235032", - "#214f30", - "#1f4d2f", - "#1d4c2d", - "#1b4b2c", - "#19492a", - "#174829", - "#144727", - "#124526", - "#0f4424", - ], - "flex_TuOr": [ - 
"#134e4a", - "#164f4b", - "#19504d", - "#1b524e", - "#1e534f", - "#205450", - "#225552", - "#255753", - "#275854", - "#295955", - "#2b5a57", - "#2d5c58", - "#2f5d59", - "#315e5a", - "#335f5c", - "#35615d", - "#37625e", - "#39635f", - "#3b6461", - "#3c6662", - "#3e6763", - "#406865", - "#426966", - "#446b67", - "#456c68", - "#476d6a", - "#496e6b", - "#4b706c", - "#4d716e", - "#4e726f", - "#507370", - "#527572", - "#547673", - "#557774", - "#577975", - "#597a77", - "#5b7b78", - "#5d7c79", - "#5e7e7b", - "#607f7c", - "#62807d", - "#64827f", - "#658380", - "#678482", - "#698683", - "#6b8784", - "#6c8886", - "#6e8a87", - "#708b88", - "#728c8a", - "#738e8b", - "#758f8c", - "#77908e", - "#79928f", - "#7a9391", - "#7c9492", - "#7e9693", - "#809795", - "#819896", - "#839a98", - "#859b99", - "#879d9b", - "#899e9c", - "#8a9f9d", - "#8ca19f", - "#8ea2a0", - "#90a4a2", - "#92a5a3", - "#93a7a5", - "#95a8a6", - "#97a9a8", - "#99aba9", - "#9bacab", - "#9daeac", - "#9eafae", - "#a0b1af", - "#a2b2b1", - "#a4b4b2", - "#a6b5b4", - "#a8b7b5", - "#aab8b7", - "#acbab8", - "#adbbba", - "#afbdbb", - "#b1bebd", - "#b3c0bf", - "#b5c1c0", - "#b7c3c2", - "#b9c5c3", - "#bbc6c5", - "#bdc8c7", - "#bfc9c8", - "#c1cbca", - "#c3cccc", - "#c5cecd", - "#c7d0cf", - "#c9d1d1", - "#cbd3d2", - "#cdd5d4", - "#cfd6d6", - "#d1d8d7", - "#d3dad9", - "#d5dbdb", - "#d7dddc", - "#d9dfde", - "#dbe0e0", - "#dde2e2", - "#dfe4e3", - "#e1e6e5", - "#e3e7e7", - "#e5e9e9", - "#e7ebeb", - "#e9edec", - "#ebeeee", - "#eef0f0", - "#f0f2f2", - "#f2f4f4", - "#f4f6f6", - "#f6f8f7", - "#f8f9f9", - "#fbfbfb", - "#fdfdfd", - "#FFFFFF", - "#fefdfd", - "#fcfcfb", - "#fbfaf9", - "#faf8f7", - "#f9f7f6", - "#f7f5f4", - "#f6f4f2", - "#f5f2f0", - "#f4f0ee", - "#f2efec", - "#f1edea", - "#f0ece8", - "#efeae6", - "#eee8e4", - "#ede7e2", - "#ece5df", - "#ebe4dd", - "#eae2db", - "#e9e0d9", - "#e8dfd7", - "#e7ddd5", - "#e6dcd3", - "#e5dad1", - "#e5d8cf", - "#e4d7cd", - "#e3d5cb", - "#e2d4c9", - "#e1d2c7", - "#e0d0c5", - "#dfcfc3", - "#dfcdc1", 
- "#deccbe", - "#ddcabc", - "#dcc9ba", - "#dbc7b8", - "#dac6b6", - "#dac4b4", - "#d9c2b2", - "#d8c1b0", - "#d7bfae", - "#d6beac", - "#d6bcaa", - "#d5bba8", - "#d4b9a6", - "#d3b8a4", - "#d3b6a2", - "#d2b5a0", - "#d1b39e", - "#d0b29c", - "#d0b09a", - "#cfaf98", - "#cead96", - "#cdac94", - "#cdaa92", - "#cca990", - "#cba78e", - "#caa68c", - "#caa48a", - "#c9a388", - "#c8a286", - "#c7a084", - "#c79f82", - "#c69d80", - "#c59c7e", - "#c59a7c", - "#c4997a", - "#c39778", - "#c29676", - "#c29474", - "#c19372", - "#c09270", - "#c0906e", - "#bf8f6c", - "#be8d6b", - "#bd8c69", - "#bd8a67", - "#bc8965", - "#bb8863", - "#bb8661", - "#ba855f", - "#b9835d", - "#b8825b", - "#b88059", - "#b77f57", - "#b67e55", - "#b57c53", - "#b57b51", - "#b47950", - "#b3784e", - "#b3774c", - "#b2754a", - "#b17448", - "#b07246", - "#b07144", - "#af7042", - "#ae6e40", - "#ad6d3e", - "#ad6b3c", - "#ac6a3a", - "#ab6938", - "#aa6737", - "#a96635", - "#a96433", - "#a86331", - "#a7622f", - "#a6602d", - "#a65f2b", - "#a55d29", - "#a45c27", - "#a35b25", - "#a25923", - "#a15821", - "#a1571f", - "#a0551c", - "#9f541a", - "#9e5218", - "#9d5116", - "#9c5013", - "#9c4e11", - "#9b4d0e", - "#9a4b0b", - "#994a08", - ], - "flex_OrTu": [ - "#994a08", - "#9a4b0b", - "#9b4d0e", - "#9c4e11", - "#9c5013", - "#9d5116", - "#9e5218", - "#9f541a", - "#a0551c", - "#a1571f", - "#a15821", - "#a25923", - "#a35b25", - "#a45c27", - "#a55d29", - "#a65f2b", - "#a6602d", - "#a7622f", - "#a86331", - "#a96433", - "#a96635", - "#aa6737", - "#ab6938", - "#ac6a3a", - "#ad6b3c", - "#ad6d3e", - "#ae6e40", - "#af7042", - "#b07144", - "#b07246", - "#b17448", - "#b2754a", - "#b3774c", - "#b3784e", - "#b47950", - "#b57b51", - "#b57c53", - "#b67e55", - "#b77f57", - "#b88059", - "#b8825b", - "#b9835d", - "#ba855f", - "#bb8661", - "#bb8863", - "#bc8965", - "#bd8a67", - "#bd8c69", - "#be8d6b", - "#bf8f6c", - "#c0906e", - "#c09270", - "#c19372", - "#c29474", - "#c29676", - "#c39778", - "#c4997a", - "#c59a7c", - "#c59c7e", - "#c69d80", - "#c79f82", - 
"#c7a084", - "#c8a286", - "#c9a388", - "#caa48a", - "#caa68c", - "#cba78e", - "#cca990", - "#cdaa92", - "#cdac94", - "#cead96", - "#cfaf98", - "#d0b09a", - "#d0b29c", - "#d1b39e", - "#d2b5a0", - "#d3b6a2", - "#d3b8a4", - "#d4b9a6", - "#d5bba8", - "#d6bcaa", - "#d6beac", - "#d7bfae", - "#d8c1b0", - "#d9c2b2", - "#dac4b4", - "#dac6b6", - "#dbc7b8", - "#dcc9ba", - "#ddcabc", - "#deccbe", - "#dfcdc1", - "#dfcfc3", - "#e0d0c5", - "#e1d2c7", - "#e2d4c9", - "#e3d5cb", - "#e4d7cd", - "#e5d8cf", - "#e5dad1", - "#e6dcd3", - "#e7ddd5", - "#e8dfd7", - "#e9e0d9", - "#eae2db", - "#ebe4dd", - "#ece5df", - "#ede7e2", - "#eee8e4", - "#efeae6", - "#f0ece8", - "#f1edea", - "#f2efec", - "#f4f0ee", - "#f5f2f0", - "#f6f4f2", - "#f7f5f4", - "#f9f7f6", - "#faf8f7", - "#fbfaf9", - "#fcfcfb", - "#fefdfd", - "#FFFFFF", - "#fdfdfd", - "#fbfbfb", - "#f8f9f9", - "#f6f8f7", - "#f4f6f6", - "#f2f4f4", - "#f0f2f2", - "#eef0f0", - "#ebeeee", - "#e9edec", - "#e7ebeb", - "#e5e9e9", - "#e3e7e7", - "#e1e6e5", - "#dfe4e3", - "#dde2e2", - "#dbe0e0", - "#d9dfde", - "#d7dddc", - "#d5dbdb", - "#d3dad9", - "#d1d8d7", - "#cfd6d6", - "#cdd5d4", - "#cbd3d2", - "#c9d1d1", - "#c7d0cf", - "#c5cecd", - "#c3cccc", - "#c1cbca", - "#bfc9c8", - "#bdc8c7", - "#bbc6c5", - "#b9c5c3", - "#b7c3c2", - "#b5c1c0", - "#b3c0bf", - "#b1bebd", - "#afbdbb", - "#adbbba", - "#acbab8", - "#aab8b7", - "#a8b7b5", - "#a6b5b4", - "#a4b4b2", - "#a2b2b1", - "#a0b1af", - "#9eafae", - "#9daeac", - "#9bacab", - "#99aba9", - "#97a9a8", - "#95a8a6", - "#93a7a5", - "#92a5a3", - "#90a4a2", - "#8ea2a0", - "#8ca19f", - "#8a9f9d", - "#899e9c", - "#879d9b", - "#859b99", - "#839a98", - "#819896", - "#809795", - "#7e9693", - "#7c9492", - "#7a9391", - "#79928f", - "#77908e", - "#758f8c", - "#738e8b", - "#728c8a", - "#708b88", - "#6e8a87", - "#6c8886", - "#6b8784", - "#698683", - "#678482", - "#658380", - "#64827f", - "#62807d", - "#607f7c", - "#5e7e7b", - "#5d7c79", - "#5b7b78", - "#597a77", - "#577975", - "#557774", - "#547673", - "#527572", - "#507370", 
- "#4e726f", - "#4d716e", - "#4b706c", - "#496e6b", - "#476d6a", - "#456c68", - "#446b67", - "#426966", - "#406865", - "#3e6763", - "#3c6662", - "#3b6461", - "#39635f", - "#37625e", - "#35615d", - "#335f5c", - "#315e5a", - "#2f5d59", - "#2d5c58", - "#2b5a57", - "#295955", - "#275854", - "#255753", - "#225552", - "#205450", - "#1e534f", - "#1b524e", - "#19504d", - "#164f4b", - "#134e4a", - ], -} +from tidy3d._common.components.viz.flex_color_palettes import ( + CATEGORICAL_PALETTES_HEX, + DIVERGING_PALETTES_HEX, + SEQUENTIAL_PALETTES_HEX, +) diff --git a/tidy3d/components/viz/flex_style.py b/tidy3d/components/viz/flex_style.py index 0706826fca..c26686d494 100644 --- a/tidy3d/components/viz/flex_style.py +++ b/tidy3d/components/viz/flex_style.py @@ -1,46 +1,12 @@ -from __future__ import annotations - -from tidy3d.log import log - -_ORIGINAL_PARAMS = None - - -def apply_tidy3d_params() -> None: - """ - Applies a set of defaults to the matplotlib params that are following the tidy3d color palettes and design. - """ - global _ORIGINAL_PARAMS - try: - import matplotlib as mpl - import matplotlib.pyplot as plt +"""Compatibility shim for :mod:`tidy3d._common.components.viz.flex_style`.""" - _ORIGINAL_PARAMS = mpl.rcParams.copy() +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - try: - plt.style.use("tidy3d.style") - except Exception as e: - log.error(f"Failed to apply Tidy3D plotting style on import. Error: {e}") - _ORIGINAL_PARAMS = {} - except ImportError: - pass - - -def restore_matplotlib_rcparams() -> None: - """ - Resets matplotlib rcParams to the values they had before the Tidy3D - style was automatically applied on import. 
- """ - global _ORIGINAL_PARAMS - try: - import matplotlib.pyplot as plt - from matplotlib import style - - if not _ORIGINAL_PARAMS: - style.use("default") - return +# marked as migrated to _common +from __future__ import annotations - plt.rcParams.update(_ORIGINAL_PARAMS) - except ImportError: - log.error("Matplotlib is not installed on your system. Failed to reset to default styles.") - except Exception as e: - log.error(f"Failed to reset previous Matplotlib style. Error: {e}") +from tidy3d._common.components.viz.flex_style import ( + _ORIGINAL_PARAMS, + apply_tidy3d_params, + restore_matplotlib_rcparams, +) diff --git a/tidy3d/components/viz/plot_params.py b/tidy3d/components/viz/plot_params.py index aa46630c9e..e6bd14d668 100644 --- a/tidy3d/components/viz/plot_params.py +++ b/tidy3d/components/viz/plot_params.py @@ -1,93 +1,27 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Optional - -from numpy import inf -from pydantic import Field, NonNegativeFloat - -from tidy3d.components.base import Tidy3dBaseModel - -if TYPE_CHECKING: - from tidy3d.components.viz.visualization_spec import VisualizationSpec - - -class AbstractPlotParams(Tidy3dBaseModel): - """Abstract class for storing plotting parameters. - Corresponds with select properties of ``matplotlib.artist.Artist``. 
- """ - - alpha: Any = Field(1.0, title="Opacity") - zorder: Optional[float] = Field(None, title="Display Order") - - def include_kwargs(self, **kwargs: Any) -> AbstractPlotParams: - """Update the plot params with supplied kwargs.""" - update_dict = { - key: value - for key, value in kwargs.items() - if key not in ("type",) and value is not None and key in type(self).model_fields - } - return self.copy(update=update_dict) - - def override_with_viz_spec(self, viz_spec: VisualizationSpec) -> AbstractPlotParams: - """Override plot params with supplied VisualizationSpec.""" - return self.include_kwargs(**dict(viz_spec)) +"""Compatibility shim for :mod:`tidy3d._common.components.viz.plot_params`.""" - def to_kwargs(self) -> dict[str, Any]: - """Export the plot parameters as kwargs dict that can be supplied to plot function.""" - kwarg_dict = self.model_dump() - for ignore_key in ("type", "attrs"): - kwarg_dict.pop(ignore_key) - return kwarg_dict - - -class PathPlotParams(AbstractPlotParams): - """Stores plotting parameters / specifications for a path. - Corresponds with select properties of ``matplotlib.lines.Line2D``. - """ - - color: Optional[Any] = Field(None, title="Color", alias="c") - linewidth: NonNegativeFloat = Field(2, title="Line Width", alias="lw") - linestyle: str = Field("--", title="Line Style", alias="ls") - marker: Any = Field("o", title="Marker Style") - markeredgecolor: Optional[Any] = Field(None, title="Marker Edge Color", alias="mec") - markerfacecolor: Optional[Any] = Field(None, title="Marker Face Color", alias="mfc") - markersize: NonNegativeFloat = Field(10, title="Marker Size", alias="ms") - - -class PlotParams(AbstractPlotParams): - """Stores plotting parameters / specifications for a given model. - Corresponds with select properties of ``matplotlib.patches.Patch``. 
- """ - - edgecolor: Optional[Any] = Field(None, title="Edge Color", alias="ec") - facecolor: Optional[Any] = Field(None, title="Face Color", alias="fc") - fill: bool = Field(True, title="Is Filled") - hatch: Optional[str] = Field(None, title="Hatch Style") - linewidth: NonNegativeFloat = Field(1, title="Line Width", alias="lw") +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -# defaults for different tidy3d objects -plot_params_geometry = PlotParams() -plot_params_structure = PlotParams() -plot_params_source = PlotParams(alpha=0.4, facecolor="limegreen", edgecolor="limegreen", lw=3) -plot_params_absorber = PlotParams( - alpha=0.4, facecolor="lightskyblue", edgecolor="lightskyblue", lw=3 -) -plot_params_monitor = PlotParams(alpha=0.4, facecolor="orange", edgecolor="orange", lw=3) -plot_params_pml = PlotParams(alpha=0.7, facecolor="gray", edgecolor="gray", hatch="x", zorder=inf) -plot_params_pec = PlotParams(alpha=1.0, facecolor="gold", edgecolor="black", zorder=inf) -plot_params_pmc = PlotParams(alpha=1.0, facecolor="lightsteelblue", edgecolor="black", zorder=inf) -plot_params_bloch = PlotParams(alpha=1.0, facecolor="orchid", edgecolor="black", zorder=inf) -plot_params_abc = PlotParams(alpha=1.0, facecolor="lightskyblue", edgecolor="black", zorder=inf) -plot_params_symmetry = PlotParams(edgecolor="gray", facecolor="gray", alpha=0.6, zorder=inf) -plot_params_override_structures = PlotParams( - linewidth=0.4, edgecolor="black", fill=False, zorder=inf -) -plot_params_fluid = PlotParams(facecolor="white", edgecolor="lightsteelblue", lw=0.4, hatch="xx") -plot_params_grid = PlotParams(edgecolor="black", lw=0.2) -plot_params_lumped_element = PlotParams( - alpha=0.4, facecolor="mediumblue", edgecolor="mediumblue", lw=3 -) -plot_params_min_grid_size = PlotParams( - alpha=0.5, facecolor="gray", edgecolor="darkred", lw=0, fill=True, hatch=".", zorder=0 +from 
tidy3d._common.components.viz.plot_params import ( + AbstractPlotParams, + PathPlotParams, + PlotParams, + plot_params_abc, + plot_params_absorber, + plot_params_bloch, + plot_params_fluid, + plot_params_geometry, + plot_params_grid, + plot_params_lumped_element, + plot_params_monitor, + plot_params_override_structures, + plot_params_pec, + plot_params_pmc, + plot_params_pml, + plot_params_source, + plot_params_structure, + plot_params_symmetry, ) diff --git a/tidy3d/components/viz/plot_sim_3d.py b/tidy3d/components/viz/plot_sim_3d.py index 3b17db7828..e6de969fbd 100644 --- a/tidy3d/components/viz/plot_sim_3d.py +++ b/tidy3d/components/viz/plot_sim_3d.py @@ -1,193 +1,11 @@ -from __future__ import annotations - -from html import escape -from typing import TYPE_CHECKING - -from tidy3d.exceptions import SetupError - -if TYPE_CHECKING: - from typing import Union - - from IPython.core.display_functions import DisplayHandle - - from tidy3d import Scene, Simulation - - -def plot_scene_3d(scene: Scene, width: int = 800, height: int = 800) -> None: - import gzip - import json - from base64 import b64encode - from io import BytesIO - - import h5py - - # Serialize scene to HDF5 in-memory - buffer = BytesIO() - scene.to_hdf5(buffer) - buffer.seek(0) - - # Open source HDF5 for reading and prepare modified copy - with h5py.File(buffer, "r") as src: - buffer2 = BytesIO() - with h5py.File(buffer2, "w") as dst: - - def copy_item(name: str, obj: h5py.Group | h5py.Dataset) -> None: - if isinstance(obj, h5py.Group): - dst.create_group(name) - for k, v in obj.attrs.items(): - dst[name].attrs[k] = v - elif isinstance(obj, h5py.Dataset): - data = obj[()] - if name == "JSON_STRING": - # Parse and update JSON string - json_str = ( - data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data - ) - json_data = json.loads(json_str) - json_data["size"] = list(scene.size) - json_data["center"] = list(scene.center) - json_data["grid_spec"] = {} - new_str = json.dumps(json_data) - 
dst.create_dataset(name, data=new_str.encode("utf-8")) - else: - dst.create_dataset(name, data=data) - for k, v in obj.attrs.items(): - dst[name].attrs[k] = v - - src.visititems(copy_item) - buffer2.seek(0) - - # Gzip the modified HDF5 - gz_buffer = BytesIO() - with gzip.GzipFile(fileobj=gz_buffer, mode="wb") as gz: - gz.write(buffer2.read()) - gz_buffer.seek(0) - - # Base64 encode and display with gzipped flag - sim_base64 = b64encode(gz_buffer.read()).decode("utf-8") - plot_sim_3d(sim_base64, width=width, height=height, is_gz_base64=True) +"""Compatibility shim for :mod:`tidy3d._common.components.viz.plot_sim_3d`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def plot_sim_3d( - sim: Union[Simulation, str], width: int = 800, height: int = 800, is_gz_base64: bool = False -) -> DisplayHandle: - """Make 3D display of simulation in ipython notebook.""" - - try: - from IPython.display import HTML, display - except ImportError as e: - raise SetupError( - "3D plotting requires ipython to be installed " - "and the code to be running on a jupyter notebook." - ) from e - - from base64 import b64encode - from io import BytesIO - - if not is_gz_base64: - buffer = BytesIO() - sim.to_hdf5_gz(buffer) - buffer.seek(0) - base64 = b64encode(buffer.read()).decode("utf-8") - else: - base64 = sim - - js_code = """ - /** - * Simulation Viewer Injector - * - * Monitors the document for elements being added in the form: - * - *
- * - * This script will then inject an iframe to the viewer application, and pass it the simulation data - * via the postMessage API on request. The script may be safely included multiple times, with only the - * configuration of the first started script (e.g. viewer URL) applying. - * - */ - (function() { - const TARGET_CLASS = "simulation-viewer"; - const ACTIVE_CLASS = "simulation-viewer-active"; - const VIEWER_URL = "https://tidy3d.simulation.cloud/simulation-viewer"; - - class SimulationViewerInjector { - constructor() { - for (var node of document.getElementsByClassName(TARGET_CLASS)) { - this.injectViewer(node); - } - - // Monitor for newly added nodes to the DOM - this.observer = new MutationObserver(this.onMutations.bind(this)); - this.observer.observe(document.body, {childList: true, subtree: true}); - } - - onMutations(mutations) { - for (var mutation of mutations) { - if (mutation.type === 'childList') { - /** - * Have found that adding the element does not reliably trigger the mutation observer. - * It may be the case that setting content with innerHTML does not trigger. - * - * It seems to be sufficient to re-scan the document for un-activated viewers - * whenever an event occurs, as Jupyter triggers multiple events on cell evaluation. 
- */ - var viewers = document.getElementsByClassName(TARGET_CLASS); - for (var node of viewers) { - this.injectViewer(node); - } - } - } - } - - injectViewer(node) { - // (re-)check that this is a valid simulation container and has not already been injected - if (node.classList.contains(TARGET_CLASS) && !node.classList.contains(ACTIVE_CLASS)) { - // Mark node as injected, to prevent re-runs - node.classList.add(ACTIVE_CLASS); - - var uuid; - if (window.crypto && window.crypto.randomUUID) { - uuid = window.crypto.randomUUID(); - } else { - uuid = "" + Math.random(); - } - - var frame = document.createElement("iframe"); - frame.width = node.dataset.width || 800; - frame.height = node.dataset.height || 800; - frame.style.cssText = `width:${frame.width}px;height:${frame.height}px;max-width:none;border:0;display:block` - frame.src = VIEWER_URL + "?uuid=" + uuid; - - var postMessageToViewer; - postMessageToViewer = event => { - if(event.data.type === 'viewer' && event.data.uuid===uuid){ - frame.contentWindow.postMessage({ type: 'jupyter', uuid, value: node.dataset.simulation, fileType: 'hdf5'}, '*'); - - // Run once only - window.removeEventListener('message', postMessageToViewer); - } - }; - window.addEventListener( - 'message', - postMessageToViewer, - false - ); - - node.appendChild(frame); - } - } - } - - if (!window.simulationViewerInjector) { - window.simulationViewerInjector = new SimulationViewerInjector(); - } - })(); - """ - html_code = f""" -
- - """ +# marked as migrated to _common +from __future__ import annotations - return display(HTML(html_code)) +from tidy3d._common.components.viz.plot_sim_3d import ( + plot_scene_3d, + plot_sim_3d, +) diff --git a/tidy3d/components/viz/styles.py b/tidy3d/components/viz/styles.py index 067afa9327..77f0a87390 100644 --- a/tidy3d/components/viz/styles.py +++ b/tidy3d/components/viz/styles.py @@ -1,41 +1,21 @@ -from __future__ import annotations - -try: - from matplotlib.patches import ArrowStyle +"""Compatibility shim for :mod:`tidy3d._common.components.viz.styles`.""" - arrow_style = ArrowStyle.Simple(head_length=11, head_width=9, tail_width=4) -except ImportError: - arrow_style = None +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -FLEXCOMPUTE_COLORS = { - "brand_green": "#00643C", - "brand_tan": "#B8A18B", - "brand_blue": "#6DB5DD", - "brand_purple": "#8851AD", - "brand_black": "#000000", - "brand_orange": "#FC7A4C", -} -ARROW_COLOR_SOURCE = FLEXCOMPUTE_COLORS["brand_green"] -ARROW_COLOR_POLARIZATION = FLEXCOMPUTE_COLORS["brand_tan"] -ARROW_COLOR_MONITOR = FLEXCOMPUTE_COLORS["brand_orange"] -ARROW_COLOR_ABSORBER = FLEXCOMPUTE_COLORS["brand_blue"] -PLOT_BUFFER = 0.3 -ARROW_ALPHA = 0.8 -ARROW_LENGTH = 0.3 - -# stores color of simulation.structures for given index in simulation.medium_map -MEDIUM_CMAP = [ - "#689DBC", - "#D0698E", - "#5E6EAD", - "#C6224E", - "#BDB3E2", - "#9EC3E0", - "#616161", - "#877EBC", -] +# marked as migrated to _common +from __future__ import annotations -# colormap for structure's permittivity in plot_eps -STRUCTURE_EPS_CMAP = "gist_yarg" -STRUCTURE_EPS_CMAP_R = "gist_yarg_r" -STRUCTURE_HEAT_COND_CMAP = "gist_yarg" +from tidy3d._common.components.viz.styles import ( + ARROW_ALPHA, + ARROW_COLOR_ABSORBER, + ARROW_COLOR_MONITOR, + ARROW_COLOR_POLARIZATION, + ARROW_COLOR_SOURCE, + ARROW_LENGTH, + FLEXCOMPUTE_COLORS, + MEDIUM_CMAP, + PLOT_BUFFER, + STRUCTURE_EPS_CMAP, + STRUCTURE_EPS_CMAP_R, + STRUCTURE_HEAT_COND_CMAP, 
+) diff --git a/tidy3d/components/viz/visualization_spec.py b/tidy3d/components/viz/visualization_spec.py index f62dfc85a5..58070983c7 100644 --- a/tidy3d/components/viz/visualization_spec.py +++ b/tidy3d/components/viz/visualization_spec.py @@ -1,69 +1,12 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pydantic import Field, field_validator - -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.log import log - -if TYPE_CHECKING: - from pydantic import ValidationInfo - -MATPLOTLIB_IMPORTED = True -try: - from matplotlib.colors import is_color_like -except ImportError: - is_color_like = None - MATPLOTLIB_IMPORTED = False - +"""Compatibility shim for :mod:`tidy3d._common.components.viz.visualization_spec`.""" -def is_valid_color(value: str) -> str: - if not MATPLOTLIB_IMPORTED: - log.warning( - "matplotlib was not successfully imported, but is required " - "to validate colors in the VisualizationSpec. The specified colors " - "have not been validated." 
- ) - else: - if is_color_like is not None and not is_color_like(value): - raise ValueError(f"{value} is not a valid plotting color") +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - return value - - -class VisualizationSpec(Tidy3dBaseModel): - """Defines specification for visualization when used with plotting functions.""" - - facecolor: str = Field( - "", - title="Face color", - description="Color applied to the faces in visualization.", - ) - - edgecolor: str = Field( - "", - title="Edge color", - description="Color applied to the edges in visualization.", - ) - - alpha: float = Field( - 1.0, - title="Opacity", - description="Opacity/alpha value in plotting between 0 and 1.", - ge=0, - le=1, - ) - - @field_validator("facecolor") - @classmethod - def _validate_facecolor(cls, value: str) -> str: - return is_valid_color(value) +# marked as migrated to _common +from __future__ import annotations - @field_validator("edgecolor") - @classmethod - def _ensure_edgecolor(cls, value: str, info: ValidationInfo) -> str: - # if no explicit edgecolor given, fall back to facecolor - if (value == "") and "facecolor" in info.data: - return is_valid_color(info.data["facecolor"]) - return is_valid_color(value) +from tidy3d._common.components.viz.visualization_spec import ( + MATPLOTLIB_IMPORTED, + VisualizationSpec, + is_valid_color, +) diff --git a/tidy3d/config/__init__.py b/tidy3d/config/__init__.py index 8865c1ec95..90a9963354 100644 --- a/tidy3d/config/__init__.py +++ b/tidy3d/config/__init__.py @@ -1,69 +1,32 @@ -"""Tidy3D configuration system public API.""" +"""Compatibility shim for :mod:`tidy3d._common.config`.""" -from __future__ import annotations +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -from typing import Any +# marked as migrated to _common +from __future__ import annotations -from . 
import sections # noqa: F401 - ensure builtin sections register -from .legacy import LegacyConfigWrapper, LegacyEnvironment, LegacyEnvironmentConfig -from .manager import ConfigManager -from .registry import ( +import tidy3d._common.config as _common_config +from tidy3d.config import sections + +_common_config.initialize_env() + +from tidy3d._common.config import ( # noqa: E402 - import after Env setup + ConfigManager, + Env, + Environment, + EnvironmentConfig, + LegacyConfigWrapper, + LegacyEnvironment, + LegacyEnvironmentConfig, + _base_manager, + _config_wrapper, + _create_manager, + config, get_handlers, + get_manager, get_sections, register_handler, register_plugin, register_section, + reload_config, ) - -__all__ = [ - "ConfigManager", - "Env", - "Environment", - "EnvironmentConfig", - "config", - "get_handlers", - "get_sections", - "register_handler", - "register_plugin", - "register_section", -] - - -def _create_manager() -> ConfigManager: - return ConfigManager() - - -_base_manager = _create_manager() -# TODO(FXC-3827): Drop LegacyConfigWrapper once legacy accessors are removed in Tidy3D 2.12. -_config_wrapper = LegacyConfigWrapper(_base_manager) -config = _config_wrapper - -# TODO(FXC-3827): Remove legacy Env exports after deprecation window (planned 2.12). 
-Environment = LegacyEnvironment -EnvironmentConfig = LegacyEnvironmentConfig -Env = LegacyEnvironment(_base_manager) - - -def reload_config(*, profile: str | None = None) -> LegacyConfigWrapper: - """Recreate the global configuration manager (primarily for tests).""" - - global _base_manager, Env - if _base_manager is not None: - try: - _base_manager.apply_web_env({}) - except AttributeError: - pass - _base_manager = ConfigManager(profile=profile) - _config_wrapper.reset_manager(_base_manager) - Env.reset_manager(_base_manager) - return _config_wrapper - - -def get_manager() -> ConfigManager: - """Return the underlying configuration manager instance.""" - - return _base_manager - - -def __getattr__(name: str) -> Any: - return getattr(config, name) diff --git a/tidy3d/config/legacy.py b/tidy3d/config/legacy.py index 069d8bd246..1356bafe49 100644 --- a/tidy3d/config/legacy.py +++ b/tidy3d/config/legacy.py @@ -1,541 +1,16 @@ -"""Legacy compatibility layer for tidy3d.config. +"""Compatibility shim for :mod:`tidy3d._common.config.legacy`.""" -This module holds (most) of the compatibility layer to the pre-2.10 tidy3d config -and is intended to be removed in a future release. -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -import os -import warnings -from typing import TYPE_CHECKING, Any - -import toml - -from tidy3d._runtime import WASM_BUILD -from tidy3d.log import log - -# TODO(FXC-3827): Remove LegacyConfigWrapper/Environment shims and related helpers in Tidy3D 2.12. -from .manager import ConfigManager, normalize_profile_name -from .profiles import BUILTIN_PROFILES - -if TYPE_CHECKING: - from pathlib import Path - from typing import Optional - - from tidy3d.log import LogLevel - - -def _warn_env_deprecated() -> None: - message = "'tidy3d.config.Env' is deprecated; use 'config.switch_profile(...)' instead." 
- warnings.warn(message, DeprecationWarning, stacklevel=3) - log.warning(message, log_once=True) - - -# TODO(FXC-3827): Delete LegacyConfigWrapper once legacy attribute access is dropped. -class LegacyConfigWrapper: - """Provide attribute-level compatibility with the legacy config module.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - self._frozen = False # retained for backwards compatibility tests - - @property - def logging_level(self) -> LogLevel: - return self._manager.get_section("logging").level - - @logging_level.setter - def logging_level(self, value: LogLevel) -> None: - from warnings import warn - - warn( - "'config.logging_level' is deprecated; use 'config.logging.level' instead.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("logging", level=value) - - @property - def log_suppression(self) -> bool: - return self._manager.get_section("logging").suppression - - @log_suppression.setter - def log_suppression(self, value: bool) -> None: - from warnings import warn - - warn( - "'config.log_suppression' is deprecated; use 'config.logging.suppression'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("logging", suppression=value) - - @property - def use_local_subpixel(self) -> Optional[bool]: - return self._manager.get_section("simulation").use_local_subpixel - - @use_local_subpixel.setter - def use_local_subpixel(self, value: Optional[bool]) -> None: - from warnings import warn - - warn( - "'config.use_local_subpixel' is deprecated; use 'config.simulation.use_local_subpixel'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("simulation", use_local_subpixel=value) - - @property - def suppress_rf_license_warning(self) -> bool: - return self._manager.get_section("microwave").suppress_rf_license_warning - - @suppress_rf_license_warning.setter - def suppress_rf_license_warning(self, value: bool) -> None: - from warnings import warn - - warn( - 
"'config.suppress_rf_license_warning' is deprecated; " - "use 'config.microwave.suppress_rf_license_warning'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("microwave", suppress_rf_license_warning=value) - - @property - def frozen(self) -> bool: - return self._frozen - - @frozen.setter - def frozen(self, value: bool) -> None: - self._frozen = bool(value) - - def save(self, include_defaults: bool = False) -> None: - self._manager.save(include_defaults=include_defaults) - - def reset_manager(self, manager: ConfigManager) -> None: - """Swap the underlying manager instance.""" - - self._manager = manager - - def switch_profile(self, profile: str) -> None: - """Switch active profile and synchronize the legacy environment proxy.""" - - normalized = normalize_profile_name(profile) - self._manager.switch_profile(normalized) - try: - from tidy3d.config import Env as _legacy_env - except Exception: - _legacy_env = None - if _legacy_env is not None: - _legacy_env._sync_to_manager(apply_env=True) - - def __getattr__(self, name: str) -> Any: - return getattr(self._manager, name) - - def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_"): - object.__setattr__(self, name, value) - elif name in { - "logging_level", - "log_suppression", - "use_local_subpixel", - "suppress_rf_license_warning", - "frozen", - }: - prop = getattr(type(self), name) - prop.fset(self, value) - else: - setattr(self._manager, name, value) - - def __str__(self) -> str: - return self._manager.format() - - -# TODO(FXC-3827): Delete LegacyEnvironmentConfig once profile-based Env shim is removed. 
-class LegacyEnvironmentConfig: - """Backward compatible environment config wrapper that proxies ConfigManager.""" - - def __init__( - self, - manager: Optional[ConfigManager] = None, - name: Optional[str] = None, - *, - web_api_endpoint: Optional[str] = None, - website_endpoint: Optional[str] = None, - s3_region: Optional[str] = None, - ssl_verify: Optional[bool] = None, - enable_caching: Optional[bool] = None, - ssl_version: Optional[str] = None, - env_vars: Optional[dict[str, str]] = None, - environment: Optional[LegacyEnvironment] = None, - ) -> None: - if name is None: - raise ValueError("Environment name is required") - self._manager = manager - self._name = normalize_profile_name(name) - self._environment = environment - self._pending: dict[str, Any] = {} - if web_api_endpoint is not None: - self._pending["api_endpoint"] = web_api_endpoint - if website_endpoint is not None: - self._pending["website_endpoint"] = website_endpoint - if s3_region is not None: - self._pending["s3_region"] = s3_region - if ssl_verify is not None: - self._pending["ssl_verify"] = ssl_verify - if enable_caching is not None: - self._pending["enable_caching"] = enable_caching - if ssl_version is not None: - self._pending["ssl_version"] = ssl_version - if env_vars is not None: - self._pending["env_vars"] = dict(env_vars) - - def reset_manager(self, manager: ConfigManager) -> None: - self._manager = manager - - @property - def manager(self) -> Optional[ConfigManager]: - if self._manager is not None: - return self._manager - if self._environment is not None: - return self._environment._manager - return None - - def active(self) -> None: - _warn_env_deprecated() - environment = self._environment - if environment is None: - from tidy3d.config import Env # local import to avoid circular - - environment = Env - - environment.set_current(self) - - @property - def web_api_endpoint(self) -> Optional[str]: - value = self._value("api_endpoint") - return _maybe_str(value) - - @property - def 
website_endpoint(self) -> Optional[str]: - value = self._value("website_endpoint") - return _maybe_str(value) - - @property - def s3_region(self) -> Optional[str]: - return self._value("s3_region") - - @property - def ssl_verify(self) -> bool: - value = self._value("ssl_verify") - if value is None: - return True - return bool(value) - - @property - def enable_caching(self) -> bool: - value = self._value("enable_caching") - if value is None: - return True - return bool(value) - - @enable_caching.setter - def enable_caching(self, value: Optional[bool]) -> None: - self._set_pending("enable_caching", value) - - @property - def ssl_version(self) -> Optional[str]: - return self._value("ssl_version") - - @ssl_version.setter - def ssl_version(self, value: Optional[str]) -> None: - self._set_pending("ssl_version", value) - - @property - def env_vars(self) -> dict[str, str]: - value = self._value("env_vars") - if value is None: - return {} - return dict(value) - - @env_vars.setter - def env_vars(self, value: dict[str, str]) -> None: - self._set_pending("env_vars", dict(value)) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = normalize_profile_name(value) - - def copy_state_from(self, other: LegacyEnvironmentConfig) -> None: - if not isinstance(other, LegacyEnvironmentConfig): - raise TypeError("Expected LegacyEnvironmentConfig instance.") - for key, value in other._pending.items(): - if key == "env_vars" and value is not None: - self._pending[key] = dict(value) - else: - self._pending[key] = value - - def get_real_url(self, path: str) -> str: - manager = self.manager - if manager is not None and manager.profile == self._name: - web_section = manager.get_section("web") - if hasattr(web_section, "build_api_url"): - return web_section.build_api_url(path) - - endpoint = self.web_api_endpoint or "" - if not path: - return endpoint - return "/".join([endpoint.rstrip("/"), str(path).lstrip("/")]) - - 
def apply_pending_overrides(self) -> None: - manager = self.manager - if manager is None or manager.profile != self._name: - return - if not self._pending: - return - updates = dict(self._pending) - manager.update_section("web", **updates) - self._pending.clear() - - def _set_pending(self, key: str, value: Any) -> None: - if key == "env_vars" and value is not None: - self._pending[key] = dict(value) - else: - self._pending[key] = value - self.apply_pending_overrides() - - def _web_section(self) -> dict[str, Any]: - manager = self.manager - if manager is None or WASM_BUILD: - return {} - profile = normalize_profile_name(self._name) - if manager.profile == profile: - section = manager.get_section("web") - return section.model_dump(mode="python", exclude_unset=False) - preview = manager.preview_profile(profile) - source = preview.get("web", {}) - return dict(source) if isinstance(source, dict) else {} - - def _value(self, key: str) -> Any: - if key in self._pending: - return self._pending[key] - return self._web_section().get(key) - - -# TODO(FXC-3827): Delete LegacyEnvironment after deprecating `tidy3d.config.Env`. 
-class LegacyEnvironment: - """Legacy Env wrapper that maps to profiles.""" - - def __init__(self, manager: ConfigManager): - self._previous_env_vars: dict[str, Optional[str]] = {} - self.env_map: dict[str, LegacyEnvironmentConfig] = {} - self._current: Optional[LegacyEnvironmentConfig] = None - self._manager: Optional[ConfigManager] = None - self._applied_profile: Optional[str] = None - self.reset_manager(manager) - - def reset_manager(self, manager: ConfigManager) -> None: - self._manager = manager - self.env_map = {} - for name in BUILTIN_PROFILES: - key = normalize_profile_name(name) - self.env_map[key] = LegacyEnvironmentConfig(manager, key, environment=self) - self._applied_profile = None - self._current = None - self._sync_to_manager(apply_env=True) - - @property - def current(self) -> LegacyEnvironmentConfig: - self._sync_to_manager() - assert self._current is not None - return self._current - - def set_current(self, env_config: LegacyEnvironmentConfig) -> None: - _warn_env_deprecated() - key = normalize_profile_name(env_config.name) - stored = self._get_config(key) - stored.copy_state_from(env_config) - if self._manager and self._manager.profile != key: - self._manager.switch_profile(key) - self._sync_to_manager(apply_env=True) - - def enable_caching(self, enable_caching: Optional[bool] = True) -> None: - config = self.current - config.enable_caching = enable_caching - self._sync_to_manager() - - def set_ssl_version(self, ssl_version: Optional[str]) -> None: - config = self.current - config.ssl_version = ssl_version - self._sync_to_manager() - - def __getattr__(self, name: str) -> LegacyEnvironmentConfig: - return self._get_config(name) - - def _get_config(self, name: str) -> LegacyEnvironmentConfig: - key = normalize_profile_name(name) - config = self.env_map.get(key) - if config is None: - config = LegacyEnvironmentConfig(self._manager, key, environment=self) - self.env_map[key] = config - else: - manager = self._manager - if manager is not None: - 
config.reset_manager(manager) - config._environment = self - return config - - def _sync_to_manager(self, *, apply_env: bool = False) -> None: - if self._manager is None: - return - active = normalize_profile_name(self._manager.profile) - config = self._get_config(active) - config.apply_pending_overrides() - self._current = config - if apply_env or self._applied_profile != active: - self._apply_env_vars(config) - self._applied_profile = active - - def _apply_env_vars(self, config: LegacyEnvironmentConfig) -> None: - self._restore_env_vars() - env_vars = config.env_vars or {} - self._previous_env_vars = {} - for key, value in env_vars.items(): - self._previous_env_vars[key] = os.environ.get(key) - os.environ[key] = value - - def _restore_env_vars(self) -> None: - for key, previous in self._previous_env_vars.items(): - if previous is None: - os.environ.pop(key, None) - else: - os.environ[key] = previous - self._previous_env_vars = {} - - -def _maybe_str(value: Any) -> Optional[str]: - if value is None: - return None - return str(value) - - -def load_legacy_flat_config(config_dir: Path) -> dict[str, Any]: - """Load legacy flat configuration file (pre-migration format). - - This function now supports both the original flat config format and - Nexus custom deployment settings introduced in later versions. 
- - Legacy key mappings: - - apikey -> web.apikey - - web_api_endpoint -> web.api_endpoint - - website_endpoint -> web.website_endpoint - - s3_region -> web.s3_region - - s3_endpoint -> web.env_vars.AWS_ENDPOINT_URL_S3 - - ssl_verify -> web.ssl_verify - - enable_caching -> web.enable_caching - """ - - legacy_path = config_dir / "config" - if not legacy_path.exists(): - return {} - - try: - text = legacy_path.read_text(encoding="utf-8") - except Exception as exc: - log.warning(f"Failed to read legacy configuration file '{legacy_path}': {exc}") - return {} - - try: - parsed = toml.loads(text) - except Exception as exc: - log.warning(f"Failed to decode legacy configuration file '{legacy_path}': {exc}") - return {} - - legacy_data: dict[str, Any] = {} - - # Migrate API key (original functionality) - apikey = parsed.get("apikey") - if apikey is not None: - legacy_data.setdefault("web", {})["apikey"] = apikey - - # Migrate Nexus API endpoint - web_api = parsed.get("web_api_endpoint") - if web_api is not None: - legacy_data.setdefault("web", {})["api_endpoint"] = web_api - - # Migrate Nexus website endpoint - website = parsed.get("website_endpoint") - if website is not None: - legacy_data.setdefault("web", {})["website_endpoint"] = website - - # Migrate S3 region - s3_region = parsed.get("s3_region") - if s3_region is not None: - legacy_data.setdefault("web", {})["s3_region"] = s3_region - - # Migrate SSL verification setting - ssl_verify = parsed.get("ssl_verify") - if ssl_verify is not None: - legacy_data.setdefault("web", {})["ssl_verify"] = ssl_verify - - # Migrate caching setting - enable_caching = parsed.get("enable_caching") - if enable_caching is not None: - legacy_data.setdefault("web", {})["enable_caching"] = enable_caching - - # Migrate S3 endpoint to env_vars - s3_endpoint = parsed.get("s3_endpoint") - if s3_endpoint is not None: - env_vars = legacy_data.setdefault("web", {}).setdefault("env_vars", {}) - env_vars["AWS_ENDPOINT_URL_S3"] = s3_endpoint - - return 
legacy_data - - -__all__ = [ - "LegacyConfigWrapper", - "LegacyEnvironment", - "LegacyEnvironmentConfig", - "finalize_legacy_migration", - "load_legacy_flat_config", -] - - -def finalize_legacy_migration(config_dir: Path) -> None: - """Promote a copied legacy configuration tree into the structured format. - - Parameters - ---------- - config_dir : Path - Destination directory (typically the canonical config location). - """ - - legacy_data = load_legacy_flat_config(config_dir) - - from .manager import ConfigManager # local import to avoid circular dependency - - manager = ConfigManager(profile="default", config_dir=config_dir) - config_path = config_dir / "config.toml" - for section, values in legacy_data.items(): - if isinstance(values, dict): - manager.update_section(section, **values) - try: - manager.save(include_defaults=True) - except Exception: - if config_path.exists(): - try: - config_path.unlink() - except Exception: - pass - raise - - legacy_flat_path = config_dir / "config" - if legacy_flat_path.exists(): - try: - legacy_flat_path.unlink() - except Exception as exc: - log.warning(f"Failed to remove legacy configuration file '{legacy_flat_path}': {exc}") +from tidy3d._common.config.legacy import ( + LegacyConfigWrapper, + LegacyEnvironment, + LegacyEnvironmentConfig, + _maybe_str, + _warn_env_deprecated, + finalize_legacy_migration, + load_legacy_flat_config, +) diff --git a/tidy3d/config/loader.py b/tidy3d/config/loader.py index 952f440322..b614d63401 100644 --- a/tidy3d/config/loader.py +++ b/tidy3d/config/loader.py @@ -1,451 +1,23 @@ -"""Filesystem helpers and persistence utilities for the configuration system.""" +"""Compatibility shim for :mod:`tidy3d._common.config.loader`.""" -from __future__ import annotations - -import os -import shutil -import tempfile -from copy import deepcopy -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import toml -import tomlkit - -from tidy3d.log import log - -from .profiles import BUILTIN_PROFILES 
-from .serializer import build_document, collect_descriptions - -if TYPE_CHECKING: - from typing import Optional - - -class ConfigLoader: - """Handle reading and writing configuration files.""" - - def __init__(self, config_dir: Optional[Path] = None): - self.config_dir = config_dir or resolve_config_directory() - self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) - self._docs: dict[Path, tomlkit.TOMLDocument] = {} - - def load_base(self) -> dict[str, Any]: - """Load base configuration from config.toml. - - If config.toml doesn't exist but the legacy flat config does, - automatically migrate to the new format. - """ - - config_path = self.config_dir / "config.toml" - data = self._read_toml(config_path) - if data: - return data - - # Check for legacy flat config - from .legacy import load_legacy_flat_config - - legacy_path = self.config_dir / "config" - legacy = load_legacy_flat_config(self.config_dir) - - # Auto-migrate if legacy config exists - if legacy and legacy_path.exists(): - log.info( - f"Detected legacy configuration at '{legacy_path}'. " - "Automatically migrating to new format..." - ) - - try: - # Save in new format - self.save_base(legacy) - - # Rename old config to preserve it - backup_path = legacy_path.with_suffix(".migrated") - legacy_path.rename(backup_path) - - log.info( - f"Migration complete. Configuration saved to '{config_path}'. " - f"Legacy config backed up as '{backup_path.name}'." - ) - - # Re-read the newly created config - return self._read_toml(config_path) - except Exception as exc: - log.warning( - f"Failed to auto-migrate legacy configuration: {exc}. " - "Using legacy data without migration." 
- ) - return legacy - - if legacy: - return legacy - return {} - - def load_user_profile(self, profile: str) -> dict[str, Any]: - """Load user profile overrides (if any).""" - - if profile in ("default", "prod"): - # default and prod share the same baseline; user overrides live in config.toml - return {} - - profile_path = self.profile_path(profile) - return self._read_toml(profile_path) - - def get_builtin_profile(self, profile: str) -> dict[str, Any]: - """Return builtin profile data if available.""" - - return BUILTIN_PROFILES.get(profile, {}) - - def save_base(self, data: dict[str, Any]) -> None: - """Persist base configuration.""" - - config_path = self.config_dir / "config.toml" - self._atomic_write(config_path, data) - - def save_profile(self, profile: str, data: dict[str, Any]) -> None: - """Persist profile overrides (remove file if empty).""" - - profile_path = self.profile_path(profile) - if not data: - if profile_path.exists(): - profile_path.unlink() - self._docs.pop(profile_path, None) - return - profile_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) - self._atomic_write(profile_path, data) - - def profile_path(self, profile: str) -> Path: - """Return on-disk path for a profile.""" - - return self.config_dir / "profiles" / f"{profile}.toml" - - def get_default_profile(self) -> Optional[str]: - """Read the default_profile from config.toml. - - Returns - ------- - Optional[str] - The default profile name if set, None otherwise. - """ - - config_path = self.config_dir / "config.toml" - if not config_path.exists(): - return None - - try: - text = config_path.read_text(encoding="utf-8") - data = toml.loads(text) - return data.get("default_profile") - except Exception as exc: - log.warning(f"Failed to read default_profile from '{config_path}': {exc}") - return None - - def set_default_profile(self, profile: Optional[str]) -> None: - """Set the default_profile in config.toml. 
- - Parameters - ---------- - profile : Optional[str] - The profile name to set as default, or None to remove the setting. - """ - - config_path = self.config_dir / "config.toml" - data = self._read_toml(config_path) - - if profile is None: - # Remove default_profile if it exists - if "default_profile" in data: - del data["default_profile"] - else: - # Set default_profile as a top-level key - data["default_profile"] = profile - - self._atomic_write(config_path, data) - - def _read_toml(self, path: Path) -> dict[str, Any]: - if not path.exists(): - self._docs.pop(path, None) - return {} - - try: - text = path.read_text(encoding="utf-8") - except Exception as exc: - log.warning(f"Failed to read configuration file '{path}': {exc}") - self._docs.pop(path, None) - return {} - - try: - document = tomlkit.parse(text) - except Exception as exc: - log.warning(f"Failed to parse configuration file '{path}': {exc}") - document = tomlkit.document() - self._docs[path] = document - - try: - return toml.loads(text) - except Exception as exc: - log.warning(f"Failed to decode configuration file '{path}': {exc}") - return {} - - def _atomic_write(self, path: Path, data: dict[str, Any]) -> None: - path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) - tmp_dir = path.parent - - cleaned = _clean_data(deepcopy(data)) - descriptions = collect_descriptions() - - base_document = self._docs.get(path) - document = build_document(cleaned, base_document, descriptions) - toml_text = tomlkit.dumps(document) - - with tempfile.NamedTemporaryFile( - "w", dir=tmp_dir, delete=False, encoding="utf-8" - ) as handle: - tmp_path = Path(handle.name) - handle.write(toml_text) - handle.flush() - os.fsync(handle.fileno()) - - backup_path = path.with_suffix(path.suffix + ".bak") - try: - if path.exists(): - shutil.copy2(path, backup_path) - tmp_path.replace(path) - os.chmod(path, 0o600) - if backup_path.exists(): - backup_path.unlink() - except Exception: - if tmp_path.exists(): - tmp_path.unlink() - if 
backup_path.exists(): - try: - backup_path.replace(path) - except Exception: - log.warning("Failed to restore configuration backup") - raise +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - self._docs[path] = tomlkit.parse(toml_text) - - -def load_environment_overrides() -> dict[str, Any]: - """Parse environment variables into a nested configuration dict.""" - - overrides: dict[str, Any] = {} - for key, value in os.environ.items(): - if key == "SIMCLOUD_APIKEY": - _assign_path(overrides, ("web", "apikey"), value) - continue - if not key.startswith("TIDY3D_"): - continue - rest = key[len("TIDY3D_") :] - if "__" not in rest: - continue - segments = tuple(segment.lower() for segment in rest.split("__") if segment) - if not segments: - continue - if segments[0] == "auth": - segments = ("web",) + segments[1:] - _assign_path(overrides, segments, value) - return overrides - - -def deep_merge(*sources: dict[str, Any]) -> dict[str, Any]: - """Deep merge multiple dictionaries into a new dict.""" - - result: dict[str, Any] = {} - for source in sources: - _merge_into(result, source) - return result - - -def _merge_into(target: dict[str, Any], source: dict[str, Any]) -> None: - for key, value in source.items(): - if isinstance(value, dict): - node = target.setdefault(key, {}) - if isinstance(node, dict): - _merge_into(node, value) - else: - target[key] = deepcopy(value) - else: - target[key] = value - - -def deep_diff(base: dict[str, Any], target: dict[str, Any]) -> dict[str, Any]: - """Return keys from target that differ from base.""" - - diff: dict[str, Any] = {} - keys = set(base.keys()) | set(target.keys()) - for key in keys: - base_value = base.get(key) - target_value = target.get(key) - if isinstance(base_value, dict) and isinstance(target_value, dict): - nested = deep_diff(base_value, target_value) - if nested: - diff[key] = nested - elif target_value != base_value: - if isinstance(target_value, dict): - diff[key] = deepcopy(target_value) - 
else: - diff[key] = target_value - return diff - - -def _assign_path(target: dict[str, Any], path: tuple[str, ...], value: Any) -> None: - node = target - for segment in path[:-1]: - node = node.setdefault(segment, {}) - node[path[-1]] = value - - -def _clean_data(data: Any) -> Any: - if isinstance(data, dict): - cleaned: dict[str, Any] = {} - for key, value in data.items(): - cleaned_value = _clean_data(value) - if cleaned_value is None: - continue - cleaned[key] = cleaned_value - return cleaned - if isinstance(data, list): - cleaned_list = [_clean_data(item) for item in data] - return [item for item in cleaned_list if item is not None] - if data is None: - return None - return data - - -def legacy_config_directory() -> Path: - """Return the legacy configuration directory (~/.tidy3d).""" - - return Path.home() / ".tidy3d" - - -def canonical_config_directory() -> Path: - """Return the platform-dependent canonical configuration directory.""" - - return _xdg_config_home() / "tidy3d" - - -def _warn_legacy_dir_ignored(*, canonical_dir: Path, legacy_dir: Path) -> None: - if legacy_dir.exists(): - log.warning( - f"Using canonical configuration directory at '{canonical_dir}'. " - "Found legacy directory at '~/.tidy3d', which will be ignored. " - "Remove it manually or run 'tidy3d config migrate --delete-legacy' to clean up.", - log_once=True, - ) - - -def resolve_config_directory() -> Path: - """Determine the directory used to store tidy3d configuration files.""" - - base_override = os.getenv("TIDY3D_BASE_DIR") - if base_override: - base_path = Path(base_override).expanduser().resolve() - path = base_path / "config" - if path.is_dir(): - return path - if _is_writable(path.parent): - return path - log.warning( - "'TIDY3D_BASE_DIR' is not writable; using temporary configuration directory instead." 
- ) - return _temporary_config_dir() - - canonical_dir = canonical_config_directory() - legacy_dir = legacy_config_directory() - if canonical_dir.is_dir(): - _warn_legacy_dir_ignored(canonical_dir=canonical_dir, legacy_dir=legacy_dir) - return canonical_dir - if _is_writable(canonical_dir.parent): - _warn_legacy_dir_ignored(canonical_dir=canonical_dir, legacy_dir=legacy_dir) - return canonical_dir - - if legacy_dir.exists(): - log.warning( - "Configuration found in legacy location '~/.tidy3d'. Consider running 'tidy3d config migrate'.", - log_once=True, - ) - return legacy_dir - - log.warning(f"Unable to write to '{canonical_dir}'; falling back to temporary directory.") - return _temporary_config_dir() - - -def _xdg_config_home() -> Path: - xdg_home = os.getenv("XDG_CONFIG_HOME") - if xdg_home: - return Path(xdg_home).expanduser() - return Path.home() / ".config" - - -def _temporary_config_dir() -> Path: - base = Path(tempfile.gettempdir()) / "tidy3d" - base.mkdir(mode=0o700, exist_ok=True) - return base / "config" - - -def _is_writable(path: Path) -> bool: - try: - path.mkdir(parents=True, exist_ok=True) - fd, test_path = tempfile.mkstemp(dir=path, prefix=".tidy3d_write_test_") - os.close(fd) - try: - Path(test_path).unlink() - except FileNotFoundError: - pass - return True - except Exception: - return False - - -def migrate_legacy_config(*, overwrite: bool = False, remove_legacy: bool = False) -> Path: - """Copy configuration files from the legacy ``~/.tidy3d`` directory to the canonical location. - - Parameters - ---------- - overwrite : bool - If ``True``, existing files in the canonical directory will be replaced. - remove_legacy : bool - If ``True``, the legacy directory is removed after a successful migration. - - Returns - ------- - Path - The path of the canonical configuration directory. - - Raises - ------ - FileNotFoundError - If the legacy directory does not exist. - FileExistsError - If the destination already exists and ``overwrite`` is ``False``. 
- RuntimeError - If the legacy and canonical directories resolve to the same location. - """ - - legacy_dir = legacy_config_directory() - if not legacy_dir.exists(): - raise FileNotFoundError("Legacy configuration directory '~/.tidy3d' was not found.") - - canonical_dir = canonical_config_directory() - if canonical_dir.resolve() == legacy_dir.resolve(): - raise RuntimeError( - "Legacy and canonical configuration directories are the same path; nothing to migrate." - ) - - if canonical_dir.exists() and not overwrite: - raise FileExistsError( - f"Destination '{canonical_dir}' already exists. Pass overwrite=True to replace existing files." - ) - - canonical_dir.parent.mkdir(parents=True, exist_ok=True) - shutil.copytree(legacy_dir, canonical_dir, dirs_exist_ok=overwrite) - - from .legacy import finalize_legacy_migration # local import to avoid circular dependency - - finalize_legacy_migration(canonical_dir) - - if remove_legacy: - shutil.rmtree(legacy_dir) +# marked as migrated to _common +from __future__ import annotations - return canonical_dir +from tidy3d._common.config.loader import ( + ConfigLoader, + _assign_path, + _clean_data, + _is_writable, + _merge_into, + _temporary_config_dir, + _xdg_config_home, + canonical_config_directory, + deep_diff, + deep_merge, + legacy_config_directory, + load_environment_overrides, + migrate_legacy_config, + resolve_config_directory, +) diff --git a/tidy3d/config/manager.py b/tidy3d/config/manager.py index 0171372180..19dd6975b4 100644 --- a/tidy3d/config/manager.py +++ b/tidy3d/config/manager.py @@ -1,634 +1,24 @@ -"""Central configuration manager implementation.""" +"""Compatibility shim for :mod:`tidy3d._common.config.manager`.""" -from __future__ import annotations - -import os -import shutil -from collections import defaultdict -from copy import deepcopy -from enum import Enum -from io import StringIO -from pathlib import Path -from typing import TYPE_CHECKING, Any, get_args, get_origin - -from pydantic import BaseModel 
-from rich.console import Console -from rich.panel import Panel -from rich.pretty import Pretty -from rich.text import Text -from rich.tree import Tree - -from tidy3d.log import log - -from .loader import ConfigLoader, deep_diff, deep_merge, load_environment_overrides -from .profiles import BUILTIN_PROFILES -from .registry import attach_manager, get_handlers, get_sections - -if TYPE_CHECKING: - from collections.abc import Iterable, Mapping - from typing import Optional - - -def normalize_profile_name(name: str) -> str: - """Return a canonical profile name for builtin profiles.""" - - normalized = name.strip() - lowered = normalized.lower() - if lowered in BUILTIN_PROFILES: - return lowered - return normalized - - -class SectionAccessor: - """Attribute proxy that routes assignments back through the manager.""" - - def __init__(self, manager: ConfigManager, path: str): - self._manager = manager - self._path = path - - def __getattr__(self, name: str) -> Any: - model = self._manager._get_model(self._path) - if model is None: - raise AttributeError(f"Section '{self._path}' is not available") - return getattr(model, name) - - def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_"): - object.__setattr__(self, name, value) - return - self._manager.update_section(self._path, **{name: value}) - - def __repr__(self) -> str: - model = self._manager._get_model(self._path) - return f"SectionAccessor({self._path}={model!r})" - - def __rich__(self) -> Panel: - model = self._manager._get_model(self._path) - if model is None: - return Panel(Text(f"Section '{self._path}' is unavailable", style="red")) - data = _prepare_for_display(model.model_dump(exclude_unset=False)) - return _build_section_panel(self._path, data) - - def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - model = self._manager._get_model(self._path) - if model is None: - return {} - return model.model_dump(*args, **kwargs) - - def __str__(self) -> str: - return 
self._manager.format_section(self._path) - - -class PluginsAccessor: - """Provides access to registered plugin configurations.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - - def __getattr__(self, plugin: str) -> SectionAccessor: - if plugin not in self._manager._plugin_models: - raise AttributeError(f"Plugin '{plugin}' is not registered") - return SectionAccessor(self._manager, f"plugins.{plugin}") - - def list(self) -> Iterable[str]: - return sorted(self._manager._plugin_models.keys()) - - -class ProfilesAccessor: - """Read-only profile helper.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - - def list(self) -> dict[str, list[str]]: - return self._manager.list_profiles() - - def __getattr__(self, profile: str) -> dict[str, Any]: - return self._manager.preview_profile(profile) - - -class ConfigManager: - """High-level orchestrator for tidy3d configuration.""" - - def __init__( - self, - profile: Optional[str] = None, - config_dir: Optional[os.PathLike[str]] = None, - ): - loader_path = None if config_dir is None else Path(config_dir) - self._loader = ConfigLoader(loader_path) - self._runtime_overrides: dict[str, dict[str, Any]] = defaultdict(dict) - self._plugin_models: dict[str, BaseModel] = {} - self._section_models: dict[str, BaseModel] = {} - self._profile = self._resolve_initial_profile(profile) - self._builtin_data: dict[str, Any] = {} - self._base_data: dict[str, Any] = {} - self._profile_data: dict[str, Any] = {} - self._raw_tree: dict[str, Any] = {} - self._effective_tree: dict[str, Any] = {} - self._env_overrides: dict[str, Any] = load_environment_overrides() - self._web_env_previous: dict[str, Optional[str]] = {} - - attach_manager(self) - self._reload() - - # Notify users when using a non-default profile - if self._profile != "default": - log.info(f"Using configuration profile: '{self._profile}'", log_once=True) - - self._apply_handlers() - - @property - def profile(self) -> str: - return 
self._profile - - @property - def config_dir(self) -> Path: - return self._loader.config_dir - - @property - def plugins(self) -> PluginsAccessor: - return PluginsAccessor(self) - - @property - def profiles(self) -> ProfilesAccessor: - return ProfilesAccessor(self) - - def update_section(self, name: str, **updates: Any) -> None: - if not updates: - return - segments = name.split(".") - overrides = self._runtime_overrides[self._profile] - previous = deepcopy(overrides) - node = overrides - for segment in segments[:-1]: - node = node.setdefault(segment, {}) - section_key = segments[-1] - section_payload = node.setdefault(section_key, {}) - for key, value in updates.items(): - section_payload[key] = _serialize_value(value) - try: - self._reload() - except Exception: - self._runtime_overrides[self._profile] = previous - raise - self._apply_handlers(section=name) - - def switch_profile(self, profile: str) -> None: - if not profile: - raise ValueError("Profile name cannot be empty") - normalized = normalize_profile_name(profile) - if not normalized: - raise ValueError("Profile name cannot be empty") - self._profile = normalized - self._reload() - - # Notify users when switching to a non-default profile - if self._profile != "default": - log.info(f"Switched to configuration profile: '{self._profile}'") - - self._apply_handlers() - - def set_default_profile(self, profile: Optional[str]) -> None: - """Set the default profile to be used on startup. - - Parameters - ---------- - profile : Optional[str] - The profile name to use as default, or None to clear the default. - When set, this profile will be automatically loaded unless overridden - by environment variables (TIDY3D_CONFIG_PROFILE, TIDY3D_PROFILE, or TIDY3D_ENV). - - Notes - ----- - This setting is persisted to config.toml and survives across sessions. - Environment variables always take precedence over the default profile. 
- """ - - if profile is not None: - normalized = normalize_profile_name(profile) - if not normalized: - raise ValueError("Profile name cannot be empty") - self._loader.set_default_profile(normalized) - else: - self._loader.set_default_profile(None) - - def get_default_profile(self) -> Optional[str]: - """Get the currently configured default profile. - - Returns - ------- - Optional[str] - The default profile name if set, None otherwise. - """ - - return self._loader.get_default_profile() - - def save(self, include_defaults: bool = False) -> None: - if self._profile == "default": - # For base config: only save fields marked with persist=True - base_without_env = self._filter_persisted(self._compose_without_env()) - if include_defaults: - defaults = self._filter_persisted(self._default_tree()) - base_without_env = deep_merge(defaults, base_without_env) - self._loader.save_base(base_without_env) - else: - # For profile overrides: save any field that differs from baseline - # (don't filter by persist flag - profiles should save all customizations) - base_without_env = self._compose_without_env() - baseline = deep_merge(self._builtin_data, self._base_data) - diff = deep_diff(baseline, base_without_env) - self._loader.save_profile(self._profile, diff) - # refresh cached base/profile data after saving - self._base_data = self._loader.load_base() - self._profile_data = self._loader.load_user_profile(self._profile) - self._reload() - - def reset_to_defaults(self, *, include_profiles: bool = True) -> None: - """Reset configuration files to their default annotated state.""" - - self._runtime_overrides = defaultdict(dict) - defaults = self._filter_persisted(self._default_tree()) - self._loader.save_base(defaults) - - if include_profiles: - profiles_dir = self._loader.profile_path("_dummy").parent - if profiles_dir.exists(): - shutil.rmtree(profiles_dir) - loader_docs = getattr(self._loader, "_docs", {}) - for path in list(loader_docs.keys()): - try: - 
path.relative_to(profiles_dir) - except ValueError: - continue - loader_docs.pop(path, None) - self._profile = "default" - - self._reload() - self._apply_handlers() - - def apply_web_env(self, env_vars: Mapping[str, str]) -> None: - """Apply environment variable overrides for the web configuration section.""" - - self._restore_web_env() - for key, value in env_vars.items(): - self._web_env_previous[key] = os.environ.get(key) - os.environ[key] = value - - def _restore_web_env(self) -> None: - """Restore previously overridden environment variables.""" - - for key, previous in self._web_env_previous.items(): - if previous is None: - os.environ.pop(key, None) - else: - os.environ[key] = previous - self._web_env_previous.clear() - - def list_profiles(self) -> dict[str, list[str]]: - profiles_dir = self._loader.config_dir / "profiles" - user_profiles = [] - if profiles_dir.exists(): - for path in profiles_dir.glob("*.toml"): - user_profiles.append(path.stem) - built_in = sorted(name for name in BUILTIN_PROFILES.keys()) - return {"built_in": built_in, "user": sorted(user_profiles)} - - def preview_profile(self, profile: str) -> dict[str, Any]: - builtin = self._loader.get_builtin_profile(profile) - base = self._loader.load_base() - overrides = self._loader.load_user_profile(profile) - view = deep_merge(builtin, base, overrides) - return deepcopy(view) - - def get_section(self, name: str) -> BaseModel: - model = self._get_model(name) - if model is None: - raise AttributeError(f"Section '{name}' is not available") - return model - - def as_dict(self, include_env: bool = True) -> dict[str, Any]: - """Return the current configuration tree, including defaults for all sections.""" - - tree = self._compose_without_env() - if include_env: - tree = deep_merge(tree, self._env_overrides) - return deep_merge(self._default_tree(), tree) - - def __rich__(self) -> Panel: - """Return a rich renderable representation of the full configuration.""" - - return _build_config_panel( - 
title=f"Config (profile='{self._profile}')", - data=_prepare_for_display(self.as_dict(include_env=True)), - ) - - def format(self, *, include_env: bool = True) -> str: - """Return a human-friendly representation of the full configuration.""" - - panel = _build_config_panel( - title=f"Config (profile='{self._profile}')", - data=_prepare_for_display(self.as_dict(include_env=include_env)), - ) - return _render_panel(panel) - - def format_section(self, name: str) -> str: - """Return a string representation for an individual section.""" - - model = self._get_model(name) - if model is None: - raise AttributeError(f"Section '{name}' is not available") - data = _prepare_for_display(model.model_dump(exclude_unset=False)) - panel = _build_section_panel(name, data) - return _render_panel(panel) - - def on_section_registered(self, section: str) -> None: - self._reload() - self._apply_handlers(section=section) - - def on_handler_registered(self, section: str) -> None: - self._apply_handlers(section=section) - - def _resolve_initial_profile(self, profile: Optional[str]) -> str: - if profile: - return normalize_profile_name(str(profile)) - - # Check environment variables first (highest priority) - env_profile = ( - os.getenv("TIDY3D_CONFIG_PROFILE") - or os.getenv("TIDY3D_PROFILE") - or os.getenv("TIDY3D_ENV") - ) - if env_profile: - return normalize_profile_name(env_profile) - - # Check for default_profile in config file - config_default = self._loader.get_default_profile() - if config_default: - return normalize_profile_name(config_default) - - # Fall back to "default" profile - return "default" - - def _reload(self) -> None: - self._env_overrides = load_environment_overrides() - self._builtin_data = deepcopy(self._loader.get_builtin_profile(self._profile)) - self._base_data = deepcopy(self._loader.load_base()) - self._profile_data = deepcopy(self._loader.load_user_profile(self._profile)) - self._raw_tree = deep_merge(self._builtin_data, self._base_data, self._profile_data) - - 
runtime = deepcopy(self._runtime_overrides.get(self._profile, {})) - effective = deep_merge(self._raw_tree, self._env_overrides, runtime) - self._effective_tree = effective - self._build_models() - - def _build_models(self) -> None: - sections = get_sections() - new_sections: dict[str, BaseModel] = {} - new_plugins: dict[str, BaseModel] = {} - - errors: list[tuple[str, Exception]] = [] - for name, schema in sections.items(): - if name.startswith("plugins."): - plugin_name = name.split(".", 1)[1] - plugin_data = _deep_get(self._effective_tree, ("plugins", plugin_name)) or {} - try: - new_plugins[plugin_name] = schema(**plugin_data) - except Exception as exc: - log.error(f"Failed to load configuration for plugin '{plugin_name}': {exc}") - errors.append((name, exc)) - continue - if name == "plugins": - continue - section_data = self._effective_tree.get(name, {}) - try: - new_sections[name] = schema(**section_data) - except Exception as exc: - log.error(f"Failed to load configuration for section '{name}': {exc}") - errors.append((name, exc)) - - if errors: - # propagate the first error; others already logged - raise errors[0][1] - - self._section_models = new_sections - self._plugin_models = new_plugins - - def _get_model(self, name: str) -> Optional[BaseModel]: - if name.startswith("plugins."): - plugin = name.split(".", 1)[1] - return self._plugin_models.get(plugin) - return self._section_models.get(name) - - def _apply_handlers(self, section: Optional[str] = None) -> None: - handlers = get_handlers() - targets = [section] if section else handlers.keys() - for target in targets: - handler = handlers.get(target) - if handler is None: - continue - model = self._get_model(target) - if model is None: - continue - try: - handler(model) - except Exception as exc: - log.error(f"Failed to apply configuration handler for '{target}': {exc}") - - def _compose_without_env(self) -> dict[str, Any]: - runtime = self._runtime_overrides.get(self._profile, {}) - return 
deep_merge(self._raw_tree, runtime) - - def _default_tree(self) -> dict[str, Any]: - defaults: dict[str, Any] = {} - for name, schema in get_sections().items(): - if name.startswith("plugins."): - plugin = name.split(".", 1)[1] - defaults.setdefault("plugins", {})[plugin] = _model_dict(schema()) - elif name == "plugins": - defaults.setdefault("plugins", {}) - else: - defaults[name] = _model_dict(schema()) - return defaults - - def _filter_persisted(self, tree: dict[str, Any]) -> dict[str, Any]: - sections = get_sections() - filtered: dict[str, Any] = {} - plugins_source = tree.get("plugins", {}) - plugin_filtered: dict[str, Any] = {} - - for name, schema in sections.items(): - if name == "plugins": - continue - if name.startswith("plugins."): - plugin_name = name.split(".", 1)[1] - plugin_data = plugins_source.get(plugin_name, {}) - if not isinstance(plugin_data, dict): - continue - persisted_plugin = _extract_persisted(schema, plugin_data) - if persisted_plugin: - plugin_filtered[plugin_name] = persisted_plugin - continue - - section_data = tree.get(name, {}) - if not isinstance(section_data, dict): - continue - persisted_section = _extract_persisted(schema, section_data) - if persisted_section: - filtered[name] = persisted_section - - if plugin_filtered: - filtered["plugins"] = plugin_filtered - return filtered - - def __getattr__(self, name: str) -> Any: - if name in self._section_models: - return SectionAccessor(self, name) - if name == "plugins": - return self.plugins - raise AttributeError(f"Config has no section '{name}'") - - def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_"): - object.__setattr__(self, name, value) - return - if name in self._section_models: - if isinstance(value, BaseModel): - payload = value.model_dump(exclude_unset=False) - else: - payload = value - self.update_section(name, **payload) - return - object.__setattr__(self, name, value) - - def __str__(self) -> str: - return self.format() - - -def 
_deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, Any]]: - node: Any = tree - for segment in path: - if not isinstance(node, dict): - return None - node = node.get(segment) - if node is None: - return None - return node if isinstance(node, dict) else None - - -def _resolve_model_type(annotation: Any) -> Optional[type[BaseModel]]: - """Return the first BaseModel subclass found in an annotation (if any).""" - - if isinstance(annotation, type) and issubclass(annotation, BaseModel): - return annotation - - origin = get_origin(annotation) - if origin is None: - return None - - for arg in get_args(annotation): - nested = _resolve_model_type(arg) - if nested is not None: - return nested - return None - - -def _serialize_value(value: Any) -> Any: - if isinstance(value, BaseModel): - return value.model_dump(exclude_unset=False) - if hasattr(value, "get_secret_value"): - return value.get_secret_value() - return value - - -def _prepare_for_display(value: Any) -> Any: - if isinstance(value, BaseModel): - return { - k: _prepare_for_display(v) for k, v in value.model_dump(exclude_unset=False).items() - } - if isinstance(value, dict): - return {str(k): _prepare_for_display(v) for k, v in value.items()} - if isinstance(value, (list, tuple, set)): - return [_prepare_for_display(v) for v in value] - if isinstance(value, Path): - return str(value) - if isinstance(value, Enum): - return value.value - if hasattr(value, "get_secret_value"): - displayed = getattr(value, "display", None) - if callable(displayed): - return displayed() - return str(value) - return value - - -def _build_config_panel(title: str, data: dict[str, Any]) -> Panel: - tree = Tree(Text(title, style="bold cyan")) - if data: - for key in sorted(data.keys()): - branch = tree.add(Text(key, style="bold magenta")) - branch.add(Pretty(data[key], expand_all=True)) - else: - tree.add(Text("", style="dim")) - return Panel(tree, border_style="cyan", padding=(0, 1)) - - -def _build_section_panel(name: 
str, data: Any) -> Panel: - tree = Tree(Text(name, style="bold cyan")) - tree.add(Pretty(data, expand_all=True)) - return Panel(tree, border_style="cyan", padding=(0, 1)) - - -def _render_panel(renderable: Panel, *, width: int = 100) -> str: - buffer = StringIO() - console = Console(file=buffer, record=True, force_terminal=True, width=width, color_system=None) - console.print(renderable) - return buffer.getvalue().rstrip() - - -def _model_dict(model: BaseModel) -> dict[str, Any]: - data = model.model_dump(exclude_unset=False) - for key, value in list(data.items()): - if hasattr(value, "get_secret_value"): - data[key] = value.get_secret_value() - return data - - -def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]: - persisted: dict[str, Any] = {} - for field_name, field in schema.model_fields.items(): - schema_extra = field.json_schema_extra or {} - annotation = field.annotation - persist = bool(schema_extra.get("persist")) if isinstance(schema_extra, dict) else False - if not persist: - continue - if field_name not in data: - continue - value = data[field_name] - if value is None: - persisted[field_name] = None - continue - - nested_type = _resolve_model_type(annotation) - if nested_type is not None: - nested_source = value if isinstance(value, dict) else {} - nested_persisted = _extract_persisted(nested_type, nested_source) - if nested_persisted: - persisted[field_name] = nested_persisted - continue - - if hasattr(value, "get_secret_value"): - persisted[field_name] = value.get_secret_value() - else: - persisted[field_name] = deepcopy(value) - - return persisted +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "ConfigManager", - "PluginsAccessor", - "ProfilesAccessor", - "SectionAccessor", - "normalize_profile_name", -] +from tidy3d._common.config.manager import ( + BUILTIN_PROFILES, + ConfigManager, + PluginsAccessor, + 
ProfilesAccessor, + SectionAccessor, + _build_config_panel, + _build_section_panel, + _deep_get, + _extract_persisted, + _model_dict, + _prepare_for_display, + _render_panel, + _resolve_model_type, + _serialize_value, + normalize_profile_name, +) diff --git a/tidy3d/config/profiles.py b/tidy3d/config/profiles.py index 29bbb43180..a7870a6f1b 100644 --- a/tidy3d/config/profiles.py +++ b/tidy3d/config/profiles.py @@ -1,64 +1,10 @@ -"""Built-in configuration profiles for tidy3d.""" +"""Compatibility shim for :mod:`tidy3d._common.config.profiles`.""" -from __future__ import annotations - -from typing import Any +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -BUILTIN_PROFILES: dict[str, dict[str, Any]] = { - "default": { - "web": { - "api_endpoint": "https://tidy3d-api.simulation.cloud", - "website_endpoint": "https://tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "prod": { - "web": { - "api_endpoint": "https://tidy3d-api.simulation.cloud", - "website_endpoint": "https://tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "dev": { - "web": { - "api_endpoint": "https://tidy3d-api.dev-simulation.cloud", - "website_endpoint": "https://tidy3d.dev-simulation.cloud", - "s3_region": "us-east-1", - } - }, - "uat": { - "web": { - "api_endpoint": "https://tidy3d-api.uat-simulation.cloud", - "website_endpoint": "https://tidy3d.uat-simulation.cloud", - "s3_region": "us-west-2", - } - }, - "pre": { - "web": { - "api_endpoint": "https://preprod-tidy3d-api.simulation.cloud", - "website_endpoint": "https://preprod-tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "nexus": { - "web": { - "api_endpoint": "http://127.0.0.1:5000", - "website_endpoint": "http://127.0.0.1/tidy3d", - "ssl_verify": False, - "enable_caching": False, - "s3_region": "us-east-1", - "env_vars": { - "AWS_ENDPOINT_URL_S3": "http://127.0.0.1:9000", - }, - } - }, - "test": { - "web": { - "s3_region": "test", - "api_endpoint": 
"https://test", - "website_endpoint": "https://test", - } - }, -} +# marked as migrated to _common +from __future__ import annotations -__all__ = ["BUILTIN_PROFILES"] +from tidy3d._common.config.profiles import ( + BUILTIN_PROFILES, +) diff --git a/tidy3d/config/registry.py b/tidy3d/config/registry.py index 7c1b16b7a1..ad4a9bddee 100644 --- a/tidy3d/config/registry.py +++ b/tidy3d/config/registry.py @@ -1,83 +1,21 @@ -"""Registry utilities for tidy3d configuration sections and handlers.""" +"""Compatibility shim for :mod:`tidy3d._common.config.registry`.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar - -from pydantic import BaseModel - -if TYPE_CHECKING: - from typing import Callable, Optional - -T = TypeVar("T", bound=BaseModel) - -_SECTIONS: dict[str, type[BaseModel]] = {} -_HANDLERS: dict[str, Callable[[BaseModel], None]] = {} -_MANAGER: Optional[ConfigManagerProtocol] = None - - -class ConfigManagerProtocol: - """Protocol-like interface for manager notifications.""" - - def on_section_registered(self, section: str) -> None: - """Called when a new section schema is registered.""" - - def on_handler_registered(self, section: str) -> None: - """Called when a handler is registered.""" - - -def attach_manager(manager: ConfigManagerProtocol) -> None: - """Attach the active configuration manager for registry callbacks.""" - - global _MANAGER - _MANAGER = manager - - -def get_manager() -> Optional[ConfigManagerProtocol]: - """Return the currently attached configuration manager, if any.""" - - return _MANAGER +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def register_section(name: str) -> Callable[[type[T]], type[T]]: - """Decorator to register a configuration section schema.""" - - def decorator(cls: type[T]) -> type[T]: - _SECTIONS[name] = cls - if _MANAGER is not None: - _MANAGER.on_section_registered(name) - return cls - - return 
decorator - - -def register_plugin(name: str) -> Callable[[type[T]], type[T]]: - """Decorator to register a plugin configuration schema.""" - - return register_section(f"plugins.{name}") - - -def register_handler( - name: str, -) -> Callable[[Callable[[BaseModel], None]], Callable[[BaseModel], None]]: - """Decorator to register a handler for a configuration section.""" - - def decorator(func: Callable[[BaseModel], None]) -> Callable[[BaseModel], None]: - _HANDLERS[name] = func - if _MANAGER is not None: - _MANAGER.on_handler_registered(name) - return func - - return decorator - - -def get_sections() -> dict[str, type[BaseModel]]: - """Return registered section schemas.""" - - return dict(_SECTIONS) - - -def get_handlers() -> dict[str, Callable[[BaseModel], None]]: - """Return registered configuration handlers.""" - - return dict(_HANDLERS) +from tidy3d._common.config.registry import ( + _HANDLERS, + _MANAGER, + _SECTIONS, + ConfigManagerProtocol, + T, + attach_manager, + get_handlers, + get_manager, + get_sections, + register_handler, + register_plugin, + register_section, +) diff --git a/tidy3d/config/serializer.py b/tidy3d/config/serializer.py index 5db5dc5d97..e664881565 100644 --- a/tidy3d/config/serializer.py +++ b/tidy3d/config/serializer.py @@ -1,148 +1,16 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, get_args, get_origin - -import tomlkit -from pydantic import BaseModel -from tomlkit.items import Item, Table - -from .registry import get_sections - -if TYPE_CHECKING: - from collections.abc import Iterable - - from pydantic.fields import FieldInfo - -Path = tuple[str, ...] 
- - -def collect_descriptions() -> dict[Path, str]: - """Collect description strings for registered configuration fields.""" - - descriptions: dict[Path, str] = {} - for section_name, model in get_sections().items(): - base_path = tuple(segment for segment in section_name.split(".") if segment) - section_doc = (model.__doc__ or "").strip() - if section_doc and base_path: - descriptions[base_path] = descriptions.get( - base_path, section_doc.splitlines()[0].strip() - ) - for field_name, field in model.model_fields.items(): - descriptions.update(_describe_field(field, prefix=(*base_path, field_name))) - return descriptions - - -def _describe_field(field: FieldInfo, prefix: Path) -> dict[Path, str]: - descriptions: dict[Path, str] = {} - description = (field.description or "").strip() - if description: - descriptions[prefix] = description - - nested_models: Iterable[type[BaseModel]] = _iter_model_types(field.annotation) - for model in nested_models: - nested_doc = (model.__doc__ or "").strip() - if nested_doc: - descriptions[prefix] = descriptions.get(prefix, nested_doc.splitlines()[0].strip()) - for sub_name, sub_field in model.model_fields.items(): - descriptions.update(_describe_field(sub_field, prefix=(*prefix, sub_name))) - return descriptions - +"""Compatibility shim for :mod:`tidy3d._common.config.serializer`.""" -def _iter_model_types(annotation: Any) -> Iterable[type[BaseModel]]: - """Yield BaseModel subclasses referenced by a field annotation (if any).""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - if annotation is None: - return - - stack = [annotation] - seen: set[type[BaseModel]] = set() - - while stack: - current = stack.pop() - if isinstance(current, type) and issubclass(current, BaseModel): - if current not in seen: - seen.add(current) - yield current - continue - - origin = get_origin(current) - if origin is None: - continue - - stack.extend(get_args(current)) - - -def build_document( - data: dict[str, Any], - existing: 
tomlkit.TOMLDocument | None, - descriptions: dict[Path, str] | None = None, -) -> tomlkit.TOMLDocument: - """Return a TOML document populated with data and annotated comments.""" - - descriptions = descriptions or collect_descriptions() - document = existing if existing is not None else tomlkit.document() - _prune_missing_keys(document, data.keys()) - for key, value in data.items(): - _apply_value( - container=document, - key=key, - value=value, - path=(key,), - descriptions=descriptions, - is_new=key not in document, - ) - return document - - -def _prune_missing_keys(container: Table | tomlkit.TOMLDocument, keys: Iterable[str]) -> None: - desired = set(keys) - for existing_key in list(container.keys()): - if existing_key not in desired: - del container[existing_key] - - -def _apply_value( - container: Table | tomlkit.TOMLDocument, - key: str, - value: Any, - path: Path, - descriptions: dict[Path, str], - is_new: bool, -) -> None: - description = descriptions.get(path) - if isinstance(value, dict): - existing = container.get(key) - table = existing if isinstance(existing, Table) else tomlkit.table() - _prune_missing_keys(table, value.keys()) - for sub_key, sub_value in value.items(): - _apply_value( - container=table, - key=sub_key, - value=sub_value, - path=(*path, sub_key), - descriptions=descriptions, - is_new=not isinstance(existing, Table) or sub_key not in table, - ) - if key in container: - container[key] = table - else: - if isinstance(container, tomlkit.TOMLDocument) and len(container) > 0: - container.add(tomlkit.nl()) - container.add(key, table) - return - - if value is None: - return - - existing_item = container.get(key) - new_item = tomlkit.item(value) - if isinstance(existing_item, Item): - new_item.trivia.comment = existing_item.trivia.comment - new_item.trivia.comment_ws = existing_item.trivia.comment_ws - elif description: - new_item.comment(description) +# marked as migrated to _common +from __future__ import annotations - if key in container: - 
container[key] = new_item - else: - container.add(key, new_item) +from tidy3d._common.config.serializer import ( + Path, + _apply_value, + _describe_field, + _iter_model_types, + _prune_missing_keys, + build_document, + collect_descriptions, +) diff --git a/tidy3d/constants.py b/tidy3d/constants.py index 81b168cad5..15810fcca5 100644 --- a/tidy3d/constants.py +++ b/tidy3d/constants.py @@ -1,313 +1,65 @@ -"""Defines importable constants. +"""Compatibility shim for :mod:`tidy3d._common.constants`.""" -Attributes: - inf (float): Tidy3d representation of infinity. - C_0 (float): Speed of light in vacuum [um/s] - EPSILON_0 (float): Vacuum permittivity [F/um] - MU_0 (float): Vacuum permeability [H/um] - ETA_0 (float): Vacuum impedance - HBAR (float): reduced Planck constant [eV*s] - Q_e (float): funamental charge [C] -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -from types import MappingProxyType - -import numpy as np - -# fundamental constants (https://physics.nist.gov) -C_0 = 2.99792458e14 -""" -Speed of light in vacuum [um/s] -""" - -MU_0 = 1.25663706212e-12 -""" -Vacuum permeability [H/um] -""" - -EPSILON_0 = 1 / (MU_0 * C_0**2) -""" -Vacuum permittivity [F/um] -""" - -#: Free space impedance -ETA_0 = np.sqrt(MU_0 / EPSILON_0) -""" -Vacuum impedance in Ohms -""" - -Q_e = 1.602176634e-19 -""" -Fundamental charge [C] -""" - -HBAR = 6.582119569e-16 -""" -Reduced Planck constant [eV*s] -""" - -K_B = 8.617333262e-5 -""" -Boltzmann constant [eV/K] -""" - -GRAV_ACC = 9.80665 * 1e6 -""" -Gravitational acceleration (g) [um/s^2].", -""" - -M_E_C_SQUARE = 0.51099895069e6 -""" -Electron rest mass energy (m_e * c^2) [eV] -""" - -M_E_EV = M_E_C_SQUARE / C_0**2 -""" -Electron mass [eV*s^2/um^2] -""" - -# floating point precisions -dp_eps = np.finfo(np.float64).eps -""" -Double floating point precision. 
-""" - -fp_eps = np.float64(np.finfo(np.float32).eps) -""" -Floating point precision. -""" - -# values of PEC for mode solver -pec_val = -1e8 -""" -PEC values for mode solver -""" - -# unit labels -HERTZ = "Hz" -""" -One cycle per second. -""" - -TERAHERTZ = "THz" -""" -One trillion (10^12) cycles per second. -""" - -SECOND = "sec" -""" -SI unit of time. -""" - -PICOSECOND = "ps" -""" -One trillionth (10^-12) of a second. -""" - -METER = "m" -""" -SI unit of length. -""" - -PERMETER = "1/m" -""" -SI unit of inverse length. -""" - -MICROMETER = "um" -""" -One millionth (10^-6) of a meter. -""" - -NANOMETER = "nm" -""" -One billionth (10^-9) of a meter. -""" - -RADIAN = "rad" -""" -SI unit of angle. -""" - -CONDUCTIVITY = "S/um" -""" -Siemens per micrometer. -""" - -PERMITTIVITY = "None (relative permittivity)" -""" -Relative permittivity. -""" - -PML_SIGMA = "2*EPSILON_0/dt" -""" -2 times vacuum permittivity over time differential step. -""" - -RADPERSEC = "rad/sec" -""" -One radian per second. -""" - -RADPERMETER = "rad/m" -""" -One radian per meter. -""" - -NEPERPERMETER = "Np/m" -""" -SI unit for attenuation constant. -""" - - -ELECTRON_VOLT = "eV" -""" -Unit of energy. -""" - -KELVIN = "K" -""" -SI unit of temperature. -""" - -CMCUBE = "cm^3" -""" -Cubic centimeter unit of volume. -""" - -PERCMCUBE = "1/cm^3" -""" -Unit per centimeter cube. -""" - -WATT = "W" -""" -SI unit of power. -""" - -VOLT = "V" -""" -SI unit of electric potential. -""" - -PICOSECOND_PER_NANOMETER_PER_KILOMETER = "ps/(nm km)" -""" -Picosecond per (nanometer kilometer). -""" - -OHM = "ohm" -""" -SI unit of resistance. -""" - -FARAD = "farad" -""" -SI unit of capacitance. -""" - -HENRY = "henry" -""" -SI unit of inductance. -""" - -AMP = "A" -""" -SI unit of electric current. -""" - -THERMAL_CONDUCTIVITY = "W/(um*K)" -""" -Watts per (micrometer Kelvin). -""" - -SPECIFIC_HEAT_CAPACITY = "J/(kg*K)" -""" -Joules per (kilogram Kelvin). 
-""" - -DENSITY = "kg/um^3" -""" -Kilograms per cubic micrometer. -""" - -HEAT_FLUX = "W/um^2" -""" -Watts per square micrometer. -""" - -VOLUMETRIC_HEAT_RATE = "W/um^3" -""" -Watts per cube micrometer. -""" - -HEAT_TRANSFER_COEFF = "W/(um^2*K)" -""" -Watts per (square micrometer Kelvin). -""" - -CURRENT_DENSITY = "A/um^2" -""" -Amperes per square micrometer -""" - -DYNAMIC_VISCOSITY = "kg/(um*s)" -""" -Kilograms per (micrometer second) -""" - -SPECIFIC_HEAT = "um^2/(s^2*K)" -""" -Square micrometers per (square second Kelvin). -""" - -THERMAL_EXPANSIVITY = "1/K" -""" -Inverse Kelvin. -""" - -VELOCITY_SI = "m/s" -""" -SI unit of velocity -""" - -ACCELERATION = "um/s^2" -""" -Acceleration unit. -""" - -LARGE_NUMBER = 1e10 -""" -Large number used for comparing infinity. -""" - -LARGEST_FP_NUMBER = 1e38 -""" -Largest number used for single precision floating point number. -""" - -inf = np.inf -""" -Representation of infinity used within tidy3d. -""" - -# if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning -GLANCING_CUTOFF = 0.1 -""" -if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning. 
-""" - -UnitScaling = MappingProxyType( - { - "nm": 1e3, - "μm": 1e0, - "um": 1e0, - "mm": 1e-3, - "cm": 1e-4, - "m": 1e-6, - "mil": 1.0 / 25.4, - "in": 1.0 / 25400, - } +from tidy3d._common.constants import ( + ACCELERATION, + AMP, + C_0, + CMCUBE, + CONDUCTIVITY, + CURRENT_DENSITY, + DENSITY, + DYNAMIC_VISCOSITY, + ELECTRON_VOLT, + EPSILON_0, + ETA_0, + FARAD, + GLANCING_CUTOFF, + GRAV_ACC, + HBAR, + HEAT_FLUX, + HEAT_TRANSFER_COEFF, + HENRY, + HERTZ, + K_B, + KELVIN, + LARGE_NUMBER, + LARGEST_FP_NUMBER, + M_E_C_SQUARE, + M_E_EV, + METER, + MICROMETER, + MU_0, + NANOMETER, + NEPERPERMETER, + OHM, + PERCMCUBE, + PERMETER, + PERMITTIVITY, + PICOSECOND, + PICOSECOND_PER_NANOMETER_PER_KILOMETER, + PML_SIGMA, + RADIAN, + RADPERMETER, + RADPERSEC, + SECOND, + SPECIFIC_HEAT, + SPECIFIC_HEAT_CAPACITY, + TERAHERTZ, + THERMAL_CONDUCTIVITY, + THERMAL_EXPANSIVITY, + VELOCITY_SI, + VOLT, + VOLUMETRIC_HEAT_RATE, + WATT, + Q_e, + UnitScaling, + dp_eps, + fp_eps, + inf, + pec_val, ) -"""Immutable dictionary for converting microns to another spatial unit, eg. 
nm = um * UnitScaling["nm"].""" diff --git a/tidy3d/exceptions.py b/tidy3d/exceptions.py index 040396d685..4a6ff4822c 100644 --- a/tidy3d/exceptions.py +++ b/tidy3d/exceptions.py @@ -1,64 +1,21 @@ -"""Custom Tidy3D exceptions""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .log import log - -if TYPE_CHECKING: - from typing import Optional - - -class Tidy3dError(ValueError): - """Any error in tidy3d""" - - def __init__(self, message: Optional[str] = None, log_error: bool = True) -> None: - """Log just the error message and then raise the Exception.""" - super().__init__(message) - if log_error: - log.error(message) - - -class ConfigError(Tidy3dError): - """Error when configuring Tidy3d.""" - - -class Tidy3dKeyError(Tidy3dError): - """Could not find a key in a Tidy3d dictionary.""" - - -class ValidationError(Tidy3dError): - """Error when constructing Tidy3d components.""" +"""Compatibility shim for :mod:`tidy3d._common.exceptions`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -class SetupError(Tidy3dError): - """Error regarding the setup of the components (outside of domains, etc).""" - - -class FileError(Tidy3dError): - """Error reading or writing to file.""" - - -class WebError(Tidy3dError): - """Error with the webAPI.""" - - -class AuthenticationError(Tidy3dError): - """Error authenticating a user through webapi webAPI.""" - - -class DataError(Tidy3dError): - """Error accessing data.""" - - -class Tidy3dImportError(Tidy3dError): - """Error importing a package needed for tidy3d.""" - - -class Tidy3dNotImplementedError(Tidy3dError): - """Error when a functionality is not (yet) supported.""" - +# marked as migrated to _common +from __future__ import annotations -class AdjointError(Tidy3dError): - """An error in setting up the adjoint solver.""" +from tidy3d._common.exceptions import ( + AdjointError, + AuthenticationError, + ConfigError, + DataError, + FileError, + SetupError, + Tidy3dError, + 
Tidy3dImportError, + Tidy3dKeyError, + Tidy3dNotImplementedError, + ValidationError, + WebError, +) diff --git a/tidy3d/log.py b/tidy3d/log.py index 0b738d2865..ca4a776d3c 100644 --- a/tidy3d/log.py +++ b/tidy3d/log.py @@ -1,520 +1,30 @@ -"""Logging Configuration for Tidy3d.""" +"""Compatibility shim for :mod:`tidy3d._common.log`.""" -from __future__ import annotations - -import inspect -from contextlib import contextmanager -from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, Union - -from rich.console import Console -from rich.text import Text - -if TYPE_CHECKING: - from collections.abc import Iterator - from os import PathLike - from types import TracebackType - from typing import Callable, Optional - - from pydantic import BaseModel - from rich.progress import Progress as RichProgress - - from tidy3d.compat import Self -# Note: "SUPPORT" and "USER" levels are meant for backend runs only. -# Logging in frontend code should just use the standard debug/info/warning/error/critical. -LogLevel = Literal["DEBUG", "SUPPORT", "USER", "INFO", "WARNING", "ERROR", "CRITICAL"] -LogValue = Union[int, LogLevel] - -# Logging levels compatible with logging module -_level_value = { - "DEBUG": 10, - "SUPPORT": 12, - "USER": 15, - "INFO": 20, - "WARNING": 30, - "ERROR": 40, - "CRITICAL": 50, -} - -_level_name = {v: k for k, v in _level_value.items()} - -DEFAULT_LEVEL = "WARNING" - -DEFAULT_LOG_STYLES = { - "DEBUG": None, - "SUPPORT": None, - "USER": None, - "INFO": None, - "WARNING": "red", - "ERROR": "red bold", - "CRITICAL": "red bold", -} - -# Width of the console used for rich logging (in characters). 
-CONSOLE_WIDTH = 80 - - -def _default_log_level_format(level: str, message: str) -> tuple[str, str]: - """By default just return unformatted prefix and message.""" - return level, message - - -def _get_level_int(level: LogValue) -> int: - """Get the integer corresponding to the level string.""" - if isinstance(level, int): - return level - - if level not in _level_value: - # We don't want to import ConfigError to avoid a circular dependency - raise ValueError( - f"logging level {level} not supported, must be " - "'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', or 'CRITICAL'" - ) - return _level_value[level] - - -class LogHandler: - """Handle log messages depending on log level""" - - def __init__( - self, - console: Console, - level: LogValue, - log_level_format: Callable = _default_log_level_format, - prefix_every_line: bool = False, - ) -> None: - self.level = _get_level_int(level) - self.console = console - self.log_level_format = log_level_format - self.prefix_every_line = prefix_every_line - - def handle(self, level: int, level_name: str, message: str) -> None: - """Output log messages depending on log level""" - if level >= self.level: - stack = inspect.stack() - console = self.console - offset = 4 - if stack[offset - 1].filename.endswith("exceptions.py"): - # We want the calling site for exceptions.py - offset += 1 - prefix, msg = self.log_level_format(level_name, message) - if self.prefix_every_line: - wrapped_text = Text(msg, style="default") - msgs = wrapped_text.wrap(console=console, width=console.width - len(prefix) - 2) - else: - msgs = [msg] - for msg in msgs: - console.log( - prefix, - msg, - sep=": ", - style=DEFAULT_LOG_STYLES[level_name], - _stack_offset=offset, - ) - - -class Logger: - """Custom logger to avoid the complexities of the logging module. - - Notes - ----- - The logger can be used in a context manager to avoid the emission of multiple messages. 
In this - case, the first message in the context is emitted normally, but any others are discarded. When - the context is exited, the number of discarded messages of each level is displayed with the - highest level of the captures messages. - - Messages can also be captured for post-processing. That can be enabled through 'set_capture' to - record warnings emitted during model validation (and other explicit begin/end capture regions, - e.g. validation routines like ``validate_pre_upload``). A structured copy of captured warnings - can then be recovered through 'captured_warnings'. - """ - - _static_cache = set() - - def __init__(self) -> None: - self.handlers = {} - self.suppression = True - self.warn_once = False - self._counts = None - self._stack = None - self._capture = False - self._captured_warnings = [] - - def set_capture(self, capture: bool) -> None: - """Turn on/off tree-like capturing of log messages.""" - self._capture = capture - - def captured_warnings(self) -> list[dict[str, Any]]: - """Get the formatted list of captured log messages.""" - captured_warnings = self._captured_warnings - self._captured_warnings = [] - return captured_warnings - - def __enter__(self) -> Self: - """If suppression is enabled, enter a consolidation context (only a single message is - emitted).""" - if self.suppression and self._counts is None: - self._counts = {} - return self - - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> Literal[False]: - """Exist a consolidation context (report the number of messages discarded).""" - if self._counts is not None: - total = sum(v for v in self._counts.values()) - if total > 0: - max_level = max(k for k, v in self._counts.items() if v > 0) - counts = [f"{v} {_level_name[k]}" for k, v in self._counts.items() if v > 0] - self._counts = None - if total > 0: - noun = " messages." if total > 1 else " message." 
- # Temporarily prevent capturing messages to emit consolidated summary - stack = self._stack - self._stack = None - self.log(max_level, "Suppressed " + ", ".join(counts) + noun) - self._stack = stack - return False - - def begin_capture(self) -> None: - """Start capturing log stack for consolidated validation log. - - This method should be called before a validation routine starts. It must be followed by a - corresponding 'end_capture'. - """ - if not self._capture: - return - - stack_item = {"messages": [], "children": {}} - if self._stack: - self._stack.append(stack_item) - else: - self._stack = [stack_item] - - def abort_capture(self) -> None: - """Undo the last ``begin_capture()`` call. - - This is used when validation fails before reaching the corresponding ``end_capture()``. - """ - if not self._stack: - return - - self._stack.pop() - if len(self._stack) == 0: - self._stack = None - - def end_capture(self, model: BaseModel) -> None: - """End capturing log stack for consolidated validation log. - - This method should be called after a validation routine ends. It must follow a - corresponding 'begin_capture'. - """ - if not self._stack: - return - - stack_item = self._stack.pop() - if len(self._stack) == 0: - self._stack = None - - # Check if this stack item contains any messages or children - if len(stack_item["messages"]) > 0 or len(stack_item["children"]) > 0: - stack_item["type"] = model.__class__.__name__ - - # Set the path for each children - model_fields = model.get_submodels_by_hash() - for child_hash, child_dict in stack_item["children"].items(): - child_dict["parent_fields"] = model_fields.get(child_hash, []) - - # Are we at the bottom of the stack? 
- if self._stack is None: - # Yes, we're root - self._parse_warning_capture(current_loc=[], stack_item=stack_item) - else: - # No, we're someone else's child - hash_ = hash(model) - self._stack[-1]["children"][hash_] = stack_item - - def _parse_warning_capture(self, current_loc: list[Any], stack_item: dict[str, Any]) -> None: - """Process capture tree to compile formatted captured warnings.""" - - if "parent_fields" in stack_item: - for field in stack_item["parent_fields"]: - if isinstance(field, tuple): - # array field - new_loc = current_loc + list(field) - else: - # single field - new_loc = [*current_loc, field] - - # process current level warnings - for level, msg, custom_loc in stack_item["messages"]: - if level == "WARNING": - self._captured_warnings.append({"loc": new_loc + custom_loc, "msg": msg}) - - # initialize processing at children level - for child_stack in stack_item["children"].values(): - self._parse_warning_capture(current_loc=new_loc, stack_item=child_stack) - - else: # for root object - # process current level warnings - for level, msg, custom_loc in stack_item["messages"]: - if level == "WARNING": - self._captured_warnings.append({"loc": current_loc + custom_loc, "msg": msg}) - - # initialize processing at children level - for child_stack in stack_item["children"].values(): - self._parse_warning_capture(current_loc=current_loc, stack_item=child_stack) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - def _log( - self, - level: int, - level_name: str, - message: str, - *args: Any, - log_once: bool = False, - custom_loc: Optional[list] = None, - capture: bool = True, - ) -> None: - """Distribute log messages to all handlers""" - - # Check global cache if requested or if warn_once is enabled for warnings - # (before composing/capturing to avoid duplicates) - should_check_cache = log_once or (self.warn_once and level_name == "WARNING") - if should_check_cache: - # Use the message body before composition as key - if message 
in self._static_cache: - return - self._static_cache.add(message) - - # Compose message - if len(args) > 0: - try: - composed_message = str(message) % args - - except Exception as e: - composed_message = f"{message} % {args}\n{e}" - else: - composed_message = str(message) - - # Capture all messages (even if suppressed later) - if self._stack and capture: - if custom_loc is None: - custom_loc = [] - self._stack[-1]["messages"].append((level_name, composed_message, custom_loc)) - - # Context-local logger emits a single message and consolidates the rest - if self._counts is not None: - if len(self._counts) > 0: - self._counts[level] = 1 + self._counts.get(level, 0) - return - self._counts[level] = 0 - - # Forward message to handlers - for handler in self.handlers.values(): - handler.handle(level, level_name, composed_message) - - def log(self, level: LogValue, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) with given level""" - if isinstance(level, str): - level_name = level - level = _get_level_int(level) - else: - level_name = _level_name.get(level, "unknown") - self._log(level, level_name, message, *args, log_once=log_once) - - def debug(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at debug level""" - self._log(_level_value["DEBUG"], "DEBUG", message, *args, log_once=log_once) - - def support(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at support level""" - self._log(_level_value["SUPPORT"], "SUPPORT", message, *args, log_once=log_once) - - def user(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at user level""" - self._log(_level_value["USER"], "USER", message, *args, log_once=log_once) - - def info(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at info level""" - self._log(_level_value["INFO"], "INFO", message, *args, log_once=log_once) - 
- def warning( - self, - message: str, - *args: Any, - log_once: bool = False, - custom_loc: Optional[list] = None, - capture: bool = True, - ) -> None: - """Log (message) % (args) at warning level""" - self._log( - _level_value["WARNING"], - "WARNING", - message, - *args, - log_once=log_once, - custom_loc=custom_loc, - capture=capture, - ) - - def error(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at error level""" - self._log(_level_value["ERROR"], "ERROR", message, *args, log_once=log_once) - - def critical(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at critical level""" - self._log(_level_value["CRITICAL"], "CRITICAL", message, *args, log_once=log_once) - - -def set_logging_level(level: LogValue = DEFAULT_LEVEL) -> None: - """Set tidy3d console logging level priority. - - Parameters - ---------- - level : str - The lowest priority level of logging messages to display. One of ``{'DEBUG', 'SUPPORT', - 'USER', INFO', 'WARNING', 'ERROR', 'CRITICAL'}`` (listed in increasing priority). - """ - if "console" in log.handlers: - log.handlers["console"].level = _get_level_int(level) - - -def set_log_suppression(value: bool) -> None: - """Control log suppression for repeated messages.""" - log.suppression = value - - -def set_warn_once(value: bool) -> None: - """Control whether warnings are only shown once per unique message. - - Parameters - ---------- - value : bool - When True, each unique warning message is only shown once per process. - """ - log.warn_once = value - - -def get_aware_datetime() -> datetime: - """Get an aware current local datetime(with local timezone info)""" - return datetime.now().astimezone() - - -def set_logging_console(stderr: bool = False) -> None: - """Set stdout or stderr as console output - - Parameters - ---------- - stderr : bool - If False, logs are directed to stdout, otherwise to stderr. 
- """ - if "console" in log.handlers: - previous_level = log.handlers["console"].level - else: - previous_level = DEFAULT_LEVEL - log.handlers["console"] = LogHandler( - Console( - stderr=stderr, - width=CONSOLE_WIDTH, - log_path=False, - get_datetime=get_aware_datetime, - log_time_format="%X %Z", - ), - previous_level, - ) - - -def set_logging_file( - fname: PathLike, - filemode: str = "w", - level: LogValue = DEFAULT_LEVEL, - log_path: bool = False, -) -> None: - """Set a file to write log to, independently from the stdout and stderr - output chosen using :meth:`set_logging_level`. - - Parameters - ---------- - fname : PathLike - Path to file to direct the output to. If empty string, a previously set logging file will - be closed, if any, but nothing else happens. - filemode : str - 'w' or 'a', defining if the file should be overwritten or appended. - level : str - One of ``{'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}``. This is set - for the file independently of the console output level set by :meth:`set_logging_level`. - log_path : bool = False - Whether to log the path to the file that issued the message. 
- """ - if filemode not in "wa": - raise ValueError("filemode must be either 'w' or 'a'") - - # Close previous handler, if any - if "file" in log.handlers: - try: - log.handlers["file"].console.file.close() - except Exception: # TODO: catch specific exception - log.warning("Log file could not be closed") - finally: - del log.handlers["file"] - - if str(fname) == "": - # Empty string can be passed to just stop previously opened file handler - return - - try: - file = open(fname, filemode) - except Exception: # TODO: catch specific exception - log.error(f"File {fname} could not be opened") - return - - log.handlers["file"] = LogHandler( - Console(file=file, force_jupyter=False, log_path=log_path), level - ) - - -# Initialize Tidy3d's logger -log = Logger() - -# Set default logging output -set_logging_console() - - -def get_logging_console() -> Console: - """Get console from logging handlers.""" - if "console" not in log.handlers: - set_logging_console() - return log.handlers["console"].console - - -class NoOpProgress: - """Dummy progress manager that doesn't show any output.""" - - def __enter__(self) -> Self: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - pass - - def add_task(self, *args: Any, **kwargs: Any) -> None: - pass - - def update(self, *args: Any, **kwargs: Any) -> None: - pass - - -@contextmanager -def Progress(console: Console, show_progress: bool) -> Iterator[Union[RichProgress, NoOpProgress]]: - """Progress manager that wraps ``rich.Progress`` if ``show_progress`` is ``True``, - and ``NoOpProgress`` otherwise.""" - if show_progress: - from rich.progress import Progress +# marked as migrated to _common +from __future__ import annotations - with Progress(console=console) as progress: - yield progress - else: - with NoOpProgress() as progress: - yield progress +from tidy3d._common.log import ( + CONSOLE_WIDTH, + DEFAULT_LEVEL, + DEFAULT_LOG_STYLES, + Logger, + LogHandler, + LogLevel, + LogValue, + NoOpProgress, + Progress, + 
_default_log_level_format, + _get_level_int, + _level_name, + _level_value, + get_aware_datetime, + get_logging_console, + log, + set_log_suppression, + set_logging_console, + set_logging_file, + set_logging_level, + set_warn_once, +) diff --git a/tidy3d/material_library/material_library.py b/tidy3d/material_library/material_library.py index 78a83e728c..73481e811a 100644 --- a/tidy3d/material_library/material_library.py +++ b/tidy3d/material_library/material_library.py @@ -10,7 +10,12 @@ from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.material.multi_physics import MultiPhysicsMedium from tidy3d.components.material.tcad.charge import SemiconductorMedium -from tidy3d.components.medium import AnisotropicMedium, Medium2D, PoleResidue, Sellmeier +from tidy3d.components.medium import ( + AnisotropicMedium, + Medium2D, + PoleResidue, + Sellmeier, +) from tidy3d.components.tcad.bandgap_energy import ConstantEnergyBandGap from tidy3d.components.tcad.types import ( AugerRecombination, diff --git a/tidy3d/packaging.py b/tidy3d/packaging.py index 5a1f0fd0b9..5ef16c3c1e 100644 --- a/tidy3d/packaging.py +++ b/tidy3d/packaging.py @@ -1,288 +1,28 @@ -""" -This file contains a set of functions relating to packaging tidy3d for distribution. Sections of the codebase should depend on this file, but this file should not depend on any other part of the codebase. +"""Compatibility shim for :mod:`tidy3d._common.packaging`.""" -This section should only depend on the standard core installation in the pyproject.toml, and should not depend on any other part of the codebase optional imports. 
-""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as partially migrated to _common from __future__ import annotations import functools -from importlib import import_module -from importlib.util import find_spec -from typing import TYPE_CHECKING, Any, Callable, TypeVar - -import numpy as np - -from tidy3d.config import config - -from .exceptions import Tidy3dImportError -from .version import __version__ +from typing import TYPE_CHECKING, Any + +from tidy3d._common.config import config +from tidy3d._common.exceptions import Tidy3dImportError +from tidy3d._common.packaging import ( + _check_tidy3d_extras_available, + check_import, + check_tidy3d_extras_licensed_feature, + get_numpy_major_version, + requires_vtk, + tidy3d_extras, + verify_packages_import, + vtk, +) if TYPE_CHECKING: - from typing import Literal - -F = TypeVar("F", bound=Callable[..., Any]) - -vtk = { - "mod": None, - "id_type": np.int64, - "vtk_to_numpy": None, - "numpy_to_vtkIdTypeArray": None, - "numpy_to_vtk": None, -} - -tidy3d_extras = {"mod": None, "use_local_subpixel": None} - - -def check_import(module_name: str) -> bool: - """ - Check if a module or submodule section has been imported. This is a functional way of loading packages that will still load the corresponding module into the total space. - - Parameters - ---------- - module_name - - Returns - ------- - bool - True if the module has been imported, False otherwise. - - """ - try: - import_module(module_name) - return True - except ImportError: - return False - - -def verify_packages_import( - modules: list[str], required: Literal["any", "all"] = "all" -) -> Callable[[F], F]: - def decorator(func: F) -> F: - """ - When decorating a method, requires that the specified modules are available. It will raise an error if the - module is not available depending on the value of the 'required' parameter which represents the type of - import required. 
- - There are a few options to choose for the 'required' parameter: - - 'all': All the modules must be available for the operation to continue without raising an error - - 'any': At least one of the modules must be available for the operation to continue without raising an error - - Parameters - ---------- - func - The function to decorate. - - Returns - ------- - checks_modules_import - The decorated function. - - """ - - @functools.wraps(func) - def checks_modules_import(*args: Any, **kwargs: Any) -> Any: - """ - Checks if the modules are available. If they are not available, it will raise an error depending on the value. - """ - available_modules_status = [] - maximum_amount_modules = len(modules) - - module_id_i = 0 - for module in modules: - # Starts counting from one so that it can be compared to len(modules) - module_id_i += 1 - import_available = check_import(module) - available_modules_status.append( - import_available - ) # Stores the status of the module import - - if not import_available: - if required == "all": - raise Tidy3dImportError( - f"The package '{module}' is required for this operation, but it was not found. " - f"Please install the '{module}' dependencies using, for example, " - f"'pip install tidy3d[]" - ) - if required == "any": - # Means we need to verify that at least one of the modules is available - if ( - not any(available_modules_status) - ) and module_id_i == maximum_amount_modules: - # Means that we have reached the last module and none of them were available - raise Tidy3dImportError( - f"The package '{module}' is required for this operation, but it was not found. " - f"Please install the '{module}' dependencies using, for example, " - f"'pip install tidy3d[]" - ) - else: - raise ValueError( - f"The value '{required}' is not a valid value for the 'required' parameter. " - f"Please use any 'all' or 'any'." 
- ) - else: - # Means that the module is available, so we can just continue with the operation - pass - return func(*args, **kwargs) - - return checks_modules_import - - return decorator - - -def requires_vtk(fn: F) -> F: - """When decorating a method, requires that vtk is available.""" - - @functools.wraps(fn) - def _fn(*args: Any, **kwargs: Any) -> Any: - if vtk["mod"] is None: - try: - import vtk as vtk_mod - from vtk.util.numpy_support import ( - numpy_to_vtk, - numpy_to_vtkIdTypeArray, - vtk_to_numpy, - ) - from vtkmodules.vtkCommonCore import vtkLogger - - vtk["mod"] = vtk_mod - vtk["vtk_to_numpy"] = vtk_to_numpy - vtk["numpy_to_vtkIdTypeArray"] = numpy_to_vtkIdTypeArray - vtk["numpy_to_vtk"] = numpy_to_vtk - - vtkLogger.SetStderrVerbosity(vtkLogger.VERBOSITY_WARNING) - - if vtk["mod"].vtkIdTypeArray().GetDataTypeSize() == 4: - vtk["id_type"] = np.int32 - - except ImportError as exc: - raise Tidy3dImportError( - "The package 'vtk' is required for this operation, but it was not found. " - "Please install the 'vtk' dependencies using, for example, " - "'pip install .[vtk]'." - ) from exc - - return fn(*args, **kwargs) - - return _fn - - -def get_numpy_major_version(module: Any = np) -> int: - """ - Extracts the major version of the installed numpy accordingly. - - Parameters - ---------- - module : module - The module to extract the version from. Default is numpy. - - Returns - ------- - int - The major version of the module. - """ - # Get the version of the module - module_version = module.__version__ - - # Extract the major version number - major_version = int(module_version.split(".")[0]) - - return major_version - - -def _check_tidy3d_extras_available(quiet: bool = False) -> None: - """Helper function to check if 'tidy3d-extras' is available and version matched. - - Parameters - ---------- - quiet : bool - If True, suppress error logging when raising exceptions. 
- - Raises - ------ - Tidy3dImportError - If tidy3d-extras is not available or not properly initialized. - """ - if tidy3d_extras["mod"] is not None: - return - - module_exists = find_spec("tidy3d_extras") is not None - if not module_exists: - raise Tidy3dImportError( - "The package 'tidy3d-extras' is absent. " - "Please install the 'tidy3d-extras' package using, for " - r"example, 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - try: - import tidy3d_extras as tidy3d_extras_mod - - except ImportError as exc: - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly.", - log_error=not quiet, - ) from exc - - if not hasattr(tidy3d_extras_mod, "__version__"): - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly. " - "Please install the 'tidy3d-extras' package using, for " - r"example, 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - version = tidy3d_extras_mod.__version__ - - if version is None: - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly, " - "likely due to an invalid API key.", - log_error=not quiet, - ) - - if version != __version__: - raise Tidy3dImportError( - f"The version of 'tidy3d-extras' is {version}, but the version of 'tidy3d' is {__version__}. " - "They must match. You can install the correct " - r"version using 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - tidy3d_extras["mod"] = tidy3d_extras_mod - - -def check_tidy3d_extras_licensed_feature(feature_name: str, quiet: bool = False) -> None: - """Helper function to check if a specific feature is licensed in 'tidy3d-extras'. - - Parameters - ---------- - feature_name : str - The name of the feature to check for. - quiet : bool - If True, suppress error logging when raising exceptions. - - Raises - ------ - Tidy3dImportError - If the feature is not available with your license. 
- """ - - try: - _check_tidy3d_extras_available(quiet=quiet) - except Tidy3dImportError as exc: - raise Tidy3dImportError( - f"The package 'tidy3d-extras' is required for this feature '{feature_name}'.", - log_error=not quiet, - ) from exc - - features = tidy3d_extras["mod"].extension._features() - if feature_name not in features: - raise Tidy3dImportError( - f"The feature '{feature_name}' is not available with your license. " - "Please contact Tidy3D support, or upgrade your license.", - log_error=not quiet, - ) + from tidy3d._common.packaging import F def supports_local_subpixel(fn: F) -> F: diff --git a/tidy3d/plugins/invdes/design.py b/tidy3d/plugins/invdes/design.py index 0b403b51a7..7f1dc5dc70 100644 --- a/tidy3d/plugins/invdes/design.py +++ b/tidy3d/plugins/invdes/design.py @@ -22,6 +22,7 @@ import autograd.numpy as anp from tidy3d.compat import Self + from tidy3d.web import BatchData PostProcessFnType = Callable[[td.SimulationData], float] @@ -109,7 +110,7 @@ def run(self, simulation: td.Simulation, **kwargs: Any) -> td.SimulationData: kwargs.setdefault("task_name", self.task_name) return run(simulation, **kwargs) - def run_async(self, simulations: dict[str, td.Simulation], **kwargs: Any) -> td.web.BatchData: + def run_async(self, simulations: dict[str, td.Simulation], **kwargs: Any) -> BatchData: # type- """Run a batch of tidy3d simulations.""" from tidy3d.web import run_async @@ -333,7 +334,7 @@ def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]: simulation_list = [design.to_simulation(params) for design in self.designs] return dict(zip(self.task_names, simulation_list)) - def to_simulation_data(self, params: anp.ndarray, **kwargs: Any) -> td.web.BatchData: + def to_simulation_data(self, params: anp.ndarray, **kwargs: Any) -> BatchData: """Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async.""" simulations = self.to_simulation(params) return self.run_async(simulations, **kwargs) diff --git 
a/tidy3d/web/api/autograd/io_utils.py b/tidy3d/web/api/autograd/io_utils.py index a92f8228f9..549995dfe5 100644 --- a/tidy3d/web/api/autograd/io_utils.py +++ b/tidy3d/web/api/autograd/io_utils.py @@ -53,7 +53,6 @@ def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap: with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file: simulation = load_simulation(task_id_adj, path=tmp_file.name, verbose=False) simulation_cache.store_result( - stub_data=field_map, task_id=task_id_adj, path=fname, workflow_type=workflow_type, diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index e117e0acec..da174e9755 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -495,7 +495,6 @@ def load(self, path: PathLike = DEFAULT_DATA_PATH) -> WorkflowDataType: _store_mode_solver_in_cache( self.task_id, self.simulation, - data, path, ) self.simulation._patch_data(data=data) @@ -1443,7 +1442,7 @@ def load( job_data = data[task_name] if not loaded_from_cache[task_name]: _store_mode_solver_in_cache( - task_ids[task_name], job.simulation, job_data, task_paths[task_name] + task_ids[task_name], job.simulation, task_paths[task_name] ) job.simulation._patch_data(data=job_data) diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index 868f7a686c..42ce352339 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -419,7 +419,7 @@ def run( if isinstance(simulation, ModeSolver): if task_id is not None: - _store_mode_solver_in_cache(task_id, simulation, data, path) + _store_mode_solver_in_cache(task_id, simulation, path) simulation._patch_data(data=data) return data @@ -1180,8 +1180,9 @@ def load( except Exception as e: log.info(f"Failed to load simulation for storing results: {e}.") return stub_data + else: + simulation = stub_data.simulation simulation_cache.store_result( - stub_data=stub_data, task_id=task_id, path=path, workflow_type=workflow_type, diff --git a/tidy3d/web/cache.py b/tidy3d/web/cache.py 
index 92a889411c..1e0047934b 100644 --- a/tidy3d/web/cache.py +++ b/tidy3d/web/cache.py @@ -1,884 +1,56 @@ -"""Local simulation cache manager.""" +"""Compatibility shim for :mod:`tidy3d._common.web.cache`.""" -from __future__ import annotations - -import hashlib -import json -import os -import shutil -import tempfile -import threading -from contextlib import contextmanager -from dataclasses import dataclass -from datetime import datetime, timezone -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt +# marked as migrated to _common +from __future__ import annotations -from tidy3d import config -from tidy3d.log import log +from typing import TYPE_CHECKING + +from tidy3d._common.web.cache import ( + _CACHE, + CACHE_ARTIFACT_NAME, + CACHE_METADATA_NAME, + CACHE_STATS_NAME, + TMP_BATCH_PREFIX, + TMP_PREFIX, + CacheEntry, + CacheEntryMetadata, + CacheStats, + LocalCache, + _canonicalize, + _copy_and_hash, + _Hasher, + _now, + _read_metadata, + _timestamp_suffix, + _write_metadata, + build_cache_key, + build_entry_metadata, + clear, + get_cache_entry_dir, + register_get_workflow_type, + resolve_local_cache, +) +from tidy3d._common.web.core.types import TaskType from tidy3d.web.api.tidy3d_stub import Tidy3dStub -from tidy3d.web.core.http_util import get_version as _get_protocol_version -from tidy3d.web.core.types import TaskType if TYPE_CHECKING: - from collections.abc import Iterator + import os from tidy3d.components.mode.mode_solver import ModeSolver - from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType - from tidy3d.web.core.constants import TaskId - -CACHE_ARTIFACT_NAME = "simulation_data.hdf5" -CACHE_METADATA_NAME = "metadata.json" -CACHE_STATS_NAME = "stats.json" - -TMP_PREFIX = "tidy3d-cache-" -TMP_BATCH_PREFIX = "tmp_batch" - -_CACHE: Optional[LocalCache] = 
None - - -def _remove_cache_dir(path: os.PathLike, *, recreate: bool) -> None: - """Remove a cache directory and optionally recreate it.""" - cache_path = Path(path) - if cache_path.exists(): - try: - shutil.rmtree(cache_path) - except (FileNotFoundError, OSError): - return - if recreate: - cache_path.mkdir(parents=True, exist_ok=True) - - -def get_cache_entry_dir(root: os.PathLike, key: str) -> Path: - """ - Returns the cache directory for a given key. - A three-character prefix subdirectory is used to avoid hitting filesystem limits on the number of entries per folder. - """ - return Path(root) / key[:3] / key - - -class CacheStats(BaseModel): - """Lightweight summary of cache usage persisted in ``stats.json``.""" - - last_used: dict[str, str] = Field( - default_factory=dict, - description="Mapping from cache entry key to the most recent ISO-8601 access timestamp.", - ) - total_size: NonNegativeInt = Field( - default=0, - description="Aggregate size in bytes across cached artifacts captured in the stats file.", - ) - updated_at: Optional[datetime] = Field( - default=None, - description="UTC timestamp indicating when the statistics were last refreshed.", - ) - - model_config = ConfigDict(extra="allow", validate_assignment=True) - - @property - def total_entries(self) -> int: - return len(self.last_used) - - -class CacheEntryMetadata(BaseModel): - """Schema for cache entry metadata persisted on disk.""" - - cache_key: str - checksum: str - created_at: datetime - last_used: datetime - file_size: int = Field(ge=0) - simulation_hash: str - workflow_type: str - versions: Any - task_id: str - path: str - - model_config = ConfigDict(extra="allow", validate_assignment=True) - - def bump_last_used(self) -> None: - self.last_used = datetime.now(timezone.utc) - - def as_dict(self) -> dict[str, Any]: - return self.model_dump(mode="json") - - def get(self, key: str, default: Any = None) -> Any: - return self.as_dict().get(key, default) - - def __getitem__(self, key: str) -> 
Any: - data = self.as_dict() - if key not in data: - raise KeyError(key) - return data[key] - - -@dataclass -class CacheEntry: - """Internal representation of a cache entry.""" - - key: str - root: Path - metadata: CacheEntryMetadata - - @property - def path(self) -> Path: - return get_cache_entry_dir(self.root, self.key) - - @property - def artifact_path(self) -> Path: - return self.path / CACHE_ARTIFACT_NAME - - @property - def metadata_path(self) -> Path: - return self.path / CACHE_METADATA_NAME - - def exists(self) -> bool: - return self.path.exists() and self.artifact_path.exists() and self.metadata_path.exists() - - def verify(self) -> bool: - if not self.exists(): - return False - checksum = self.metadata.checksum - if not checksum: - return False - try: - actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) - except FileNotFoundError: - return False - if checksum != actual_checksum: - log.warning( - "Simulation cache checksum mismatch for key '%s'. Removing stale entry.", self.key - ) - return False - if self.metadata.file_size != file_size: - self.metadata.file_size = file_size - _write_metadata(self.metadata_path, self.metadata) - return True - - def materialize(self, target: Path) -> Path: - """Copy cached artifact to ``target`` and return the resulting path.""" - target = Path(target) - target.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(self.artifact_path, target) - return target - - -class LocalCache: - """Manages storing and retrieving cached simulation artifacts.""" - - def __init__(self, directory: os.PathLike, max_size_gb: float, max_entries: int) -> None: - self.max_size_gb = max_size_gb - self.max_entries = max_entries - self._root = Path(directory) - self._lock = threading.RLock() - self._syncing_stats = False - self._sync_pending = False - - @property - def _stats_path(self) -> Path: - return self._root / CACHE_STATS_NAME - - def _schedule_sync(self) -> None: - self._sync_pending = True - - def 
_run_pending_sync(self) -> None: - if self._sync_pending and not self._syncing_stats: - self._sync_pending = False - self.sync_stats() - - @contextmanager - def _with_lock(self) -> Iterator[None]: - self._run_pending_sync() - with self._lock: - yield - self._run_pending_sync() - - def _write_stats(self, stats: CacheStats) -> CacheStats: - updated = stats.model_copy(update={"updated_at": datetime.now(timezone.utc)}) - payload = updated.model_dump(mode="json") - payload["total_entries"] = updated.total_entries - self._stats_path.parent.mkdir(parents=True, exist_ok=True) - _write_metadata(self._stats_path, payload) - self._sync_pending = False - return updated - - def _load_stats(self, *, rebuild: bool = False) -> CacheStats: - path = self._stats_path - if not path.exists(): - if not self._syncing_stats: - self._schedule_sync() - return CacheStats() - try: - data = json.loads(path.read_text(encoding="utf-8")) - if "last_used" not in data and "entries" in data: - data["last_used"] = data.pop("entries") - stats = CacheStats.model_validate(data) - except Exception: - if rebuild and not self._syncing_stats: - self._schedule_sync() - return CacheStats() - if stats.total_size < 0: - self._schedule_sync() - return CacheStats() - return stats - - def _record_store_stats( - self, - key: str, - *, - last_used: str, - file_size: int, - previous_size: int, - ) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - entries[key] = last_used - total_size = stats.total_size - previous_size + file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - self._write_stats(updated) - - def _record_touch_stats( - self, key: str, last_used: str, *, file_size: Optional[int] = None - ) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - existed = key in entries - total_size = stats.total_size - if not existed and file_size is not None: - total_size += 
file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - entries[key] = last_used - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - self._write_stats(updated) - - def _record_remove_stats(self, key: str, file_size: int) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - entries.pop(key, None) - total_size = stats.total_size - file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - self._write_stats(updated) - - def _enforce_limits_post_sync(self, entries: list[CacheEntry]) -> None: - if not entries: - return - - entries_map = {entry.key: entry.metadata.last_used.isoformat() for entry in entries} - - if self.max_entries > 0 and len(entries) > self.max_entries: - excess = len(entries) - self.max_entries - self._evict(entries_map, remove_count=excess, exclude_keys=set()) - - max_size_bytes = int(self.max_size_gb * (1024**3)) - if max_size_bytes > 0: - total_size = sum(entry.metadata.file_size for entry in entries) - if total_size > max_size_bytes: - bytes_to_free = total_size - max_size_bytes - self._evict_by_size(entries_map, bytes_to_free, exclude_keys=set()) - - def sync_stats(self) -> CacheStats: - with self._lock: - self._syncing_stats = True - log.debug("Syncing stats.json of local cache") - try: - entries: list[CacheEntry] = [] - last_used_map: dict[str, str] = {} - total_size = 0 - for entry in self._iter_entries(): - entries.append(entry) - total_size += entry.metadata.file_size - last_used_map[entry.key] = entry.metadata.last_used.isoformat() - stats = CacheStats(last_used=last_used_map, total_size=total_size) - written = self._write_stats(stats) - self._enforce_limits_post_sync(entries) - return written - finally: - self._syncing_stats = False - - @property - def root(self) -> Path: - return self._root - - def list(self) -> list[dict[str, Any]]: - """Return metadata for all 
cache entries.""" - with self._with_lock(): - entries = [entry.metadata.model_dump(mode="json") for entry in self._iter_entries()] - return entries - - def clear(self, hard: bool = False) -> None: - """Remove all cache contents. If set to hard, root directory is removed.""" - with self._with_lock(): - _remove_cache_dir(self._root, recreate=not hard) - if not hard: - self._write_stats(CacheStats()) - - def _fetch(self, key: str) -> Optional[CacheEntry]: - """Retrieve an entry by key, verifying checksum.""" - with self._with_lock(): - entry = self._load_entry(key) - if not entry or not entry.exists(): - return None - if not entry.verify(): - self._remove_entry(entry) - return None - self._touch(entry) - return entry - - def __len__(self) -> int: - """Return number of valid cache entries.""" - with self._with_lock(): - count = self._load_stats().total_entries - return count - - def _store( - self, key: str, source_path: Path, metadata: CacheEntryMetadata - ) -> Optional[CacheEntry]: - """Store a new cache entry from ``source_path``. - - Parameters - ---------- - key : str - Cache key computed from simulation hash and runtime context. - source_path : Path - Location of the artifact to cache. - metadata : CacheEntryMetadata - Metadata describing the cache entry to be persisted. - - Returns - ------- - CacheEntry - Representation of the stored cache entry. 
- """ - source_path = Path(source_path) - if not source_path.exists(): - raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") - os.makedirs(self._root, exist_ok=True) - tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) - tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME - tmp_meta = tmp_dir / CACHE_METADATA_NAME - os.makedirs(tmp_dir, exist_ok=True) - - checksum, file_size = _copy_and_hash(source_path, tmp_artifact) - metadata.cache_key = key - metadata.created_at = datetime.now(timezone.utc) - metadata.last_used = metadata.created_at - metadata.checksum = checksum - metadata.file_size = file_size - - _write_metadata(tmp_meta, metadata) - entry: Optional[CacheEntry] = None - try: - with self._with_lock(): - self._root.mkdir(parents=True, exist_ok=True) - existing_entry = self._load_entry(key) - previous_size = ( - existing_entry.metadata.file_size if existing_entry is not None else 0 - ) - self._ensure_limits( - file_size, - incoming_key=key, - replacing_size=previous_size, - ) - final_dir = get_cache_entry_dir(self._root, key) - final_dir.parent.mkdir(parents=True, exist_ok=True) - if final_dir.exists(): - shutil.rmtree(final_dir) - os.replace(tmp_dir, final_dir) - entry = CacheEntry(key=key, root=self._root, metadata=metadata) - - self._record_store_stats( - key, - last_used=metadata.last_used.isoformat(), - file_size=file_size, - previous_size=previous_size, - ) - log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) - finally: - try: - if tmp_dir.exists(): - shutil.rmtree(tmp_dir, ignore_errors=True) - except FileNotFoundError: - pass - return entry - - def invalidate(self, key: str) -> None: - with self._with_lock(): - entry = self._load_entry(key) - if entry: - self._remove_entry(entry) - - def _ensure_limits( - self, - incoming_size: int, - *, - incoming_key: Optional[str] = None, - replacing_size: int = 0, - ) -> None: - max_entries = self.max_entries - max_size_bytes = int(self.max_size_gb * (1024**3)) 
- - try: - incoming_size_int = int(incoming_size) - except (TypeError, ValueError): - incoming_size_int = 0 - if incoming_size_int < 0: - incoming_size_int = 0 - - stats = self._load_stats() - entries_info = dict(stats.last_used) - existing_keys = set(entries_info) - projected_entries = stats.total_entries - if not incoming_key or incoming_key not in existing_keys: - projected_entries += 1 - - if projected_entries > max_entries > 0: - excess = projected_entries - max_entries - exclude = {incoming_key} if incoming_key else set() - self._evict(entries_info, remove_count=excess, exclude_keys=exclude) - stats = self._load_stats() - entries_info = dict(stats.last_used) - existing_keys = set(entries_info) - - if max_size_bytes == 0: # no limit - return + from tidy3d.components.types.workflow import WorkflowType - existing_size = stats.total_size - try: - replacing_size_int = int(replacing_size) - except (TypeError, ValueError): - replacing_size_int = 0 - if incoming_key and incoming_key in existing_keys: - projected_size = existing_size - replacing_size_int + incoming_size_int - else: - projected_size = existing_size + incoming_size_int - if max_size_bytes > 0 and projected_size > max_size_bytes: - bytes_to_free = projected_size - max_size_bytes - exclude = {incoming_key} if incoming_key else set() - self._evict_by_size(entries_info, bytes_to_free, exclude_keys=exclude) +def get_workflow_type(simulation: WorkflowType) -> str: + """Resolve workflow type name for cache logging.""" + return Tidy3dStub(simulation=simulation).get_type() - def _evict(self, entries: dict[str, str], *, remove_count: int, exclude_keys: set[str]) -> None: - if remove_count <= 0: - return - candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] - if not candidates: - return - candidates.sort(key=lambda item: item[1] or "") - for key, _ in candidates[:remove_count]: - self._remove_entry_by_key(key) - def _evict_by_size( - self, entries: dict[str, str], 
bytes_to_free: int, *, exclude_keys: set[str] - ) -> None: - if bytes_to_free <= 0: - return - candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] - if not candidates: - return - candidates.sort(key=lambda item: item[1] or "") - reclaimed = 0 - for key, _ in candidates: - if reclaimed >= bytes_to_free: - break - entry = self._load_entry(key) - if entry is None: - log.debug("Could not find entry for eviction.") - self._schedule_sync() - break - size = entry.metadata.file_size - self._remove_entry(entry) - reclaimed += size - log.info(f"Simulation cache evicted entry '{key}' to reclaim {size} bytes.") +register_get_workflow_type(get_workflow_type) - def _iter_entries(self) -> Iterator[CacheEntry]: - """Iterate lazily over all cache entries, including those in prefix subdirectories.""" - if not self._root.exists(): - return - for prefix_dir in self._root.iterdir(): - if not prefix_dir.is_dir() or prefix_dir.name.startswith( - (TMP_PREFIX, TMP_BATCH_PREFIX) - ): - continue - - # if cache is directly flat (no prefix directories), include that level too - subdirs = [prefix_dir] - if any((prefix_dir / name).is_dir() for name in prefix_dir.iterdir()): - subdirs = prefix_dir.iterdir() - - for child in subdirs: - if not child.is_dir(): - continue - if child.name.startswith((TMP_PREFIX, TMP_BATCH_PREFIX)): - continue - - meta_path = child / CACHE_METADATA_NAME - if not meta_path.exists(): - continue - - try: - metadata = _read_metadata(meta_path, child / CACHE_ARTIFACT_NAME) - except Exception: - log.debug( - "Failed to parse metadata for '%s'; scheduling stats sync.", child.name - ) - self._schedule_sync() - continue - - yield CacheEntry(key=child.name, root=self._root, metadata=metadata) - - def _load_entry(self, key: str) -> Optional[CacheEntry]: - entry = CacheEntry(key=key, root=self._root, metadata={}) - if not entry.metadata_path.exists() or not entry.artifact_path.exists(): - return None - try: - metadata = 
_read_metadata(entry.metadata_path, entry.artifact_path) - except Exception: - return None - return CacheEntry(key=key, root=self._root, metadata=metadata) - - def _touch(self, entry: CacheEntry) -> None: - entry.metadata.bump_last_used() - _write_metadata(entry.metadata_path, entry.metadata) - self._record_touch_stats( - entry.key, - entry.metadata.last_used.isoformat(), - file_size=entry.metadata.file_size, - ) - - def _remove_entry_by_key(self, key: str) -> None: - entry = self._load_entry(key) - if entry is None: - path = get_cache_entry_dir(self._root, key) - if path.exists(): - shutil.rmtree(path, ignore_errors=True) - else: - log.debug("Could not find entry for key '%s' to delete.", key) - self._record_remove_stats(key, 0) - return - self._remove_entry(entry) - - def _remove_entry(self, entry: CacheEntry) -> None: - file_size = entry.metadata.file_size - if entry.path.exists(): - shutil.rmtree(entry.path, ignore_errors=True) - self._record_remove_stats(entry.key, file_size) - - def try_fetch( - self, - simulation: WorkflowType, - verbose: bool = False, - ) -> Optional[CacheEntry]: - """ - Attempt to resolve and fetch a cached result entry for the given simulation context. - On miss or any cache error, returns None (the caller should proceed with upload/run). - """ - try: - simulation_hash = simulation._hash_self() - workflow_type = Tidy3dStub(simulation=simulation).get_type() - - versions = _get_protocol_version() - - cache_key = build_cache_key( - simulation_hash=simulation_hash, - version=versions, - ) - - entry = self._fetch(cache_key) - if not entry: - return None - - if verbose: - log.info( - f"Simulation cache hit for workflow '{workflow_type}'; using local results." 
- ) - - return entry - except Exception as e: - log.error("Failed to fetch cache results: " + str(e)) - return None - - def store_result( - self, - stub_data: WorkflowDataType, - task_id: TaskId, - path: str, - workflow_type: str, - simulation: Optional[WorkflowType] = None, - ) -> bool: - """ - Stores completed workflow results in the local cache using a canonical cache key. - - Parameters - ---------- - stub_data : :class:`.WorkflowDataType` - Object containing the workflow results, including references to the originating simulation. - task_id : str - Unique identifier of the finished workflow task. - path : str - Path to the results file on disk. - workflow_type : str - Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``). - simulation : Optional[:class:`.WorkflowDataType`] - Simulation object to use when computing the cache key. If not provided, - it will be inferred from ``stub_data.simulation`` when possible. - - Returns - ------- - bool - ``True`` if the result was successfully stored in the local cache, ``False`` otherwise. - - Notes - ----- - The cache entry is keyed by the simulation hash, workflow type, environment, and protocol version. - This enables automatic reuse of identical simulation results across future runs. - Legacy task ID mappings are recorded to support backward lookup compatibility. - """ - try: - if simulation is not None: - simulation_obj = simulation - else: - simulation_obj = getattr(stub_data, "simulation", None) - if simulation_obj is None: - log.debug( - "Failed storing local cache entry: Could not find simulation data in stub_data." 
- ) - return False - simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None - if not simulation_hash: - log.debug("Failed storing local cache entry: Could not hash simulation.") - return False - - version = _get_protocol_version() - - cache_key = build_cache_key( - simulation_hash=simulation_hash, - version=version, - ) - - metadata = build_entry_metadata( - simulation_hash=simulation_hash, - workflow_type=workflow_type, - task_id=task_id, - version=version, - path=Path(path), - ) - - self._store( - key=cache_key, - source_path=Path(path), - metadata=metadata, - ) - log.debug("Stored local cache entry for workflow type '%s'.", workflow_type) - except Exception as e: - log.error(f"Could not store cache entry: {e}") - return False - return True - - -def _copy_and_hash( - source: Path, dest: Optional[Path], existing_hash: Optional[str] = None -) -> tuple[str, int]: - """Copy ``source`` to ``dest`` while computing SHA256 checksum. - - Parameters - ---------- - source : Path - Source file path. - dest : Path or None - Destination file path. If ``None``, no copy is performed. - existing_hash : str, optional - If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match. - - Returns - ------- - tuple[str, int] - The hexadecimal digest and file size in bytes. 
- """ - source = Path(source) - if dest is not None: - dest = Path(dest) - sha256 = _Hasher() - size = 0 - with source.open("rb") as src: - if dest is None: - while chunk := src.read(1024 * 1024): - sha256.update(chunk) - size += len(chunk) - else: - dest.parent.mkdir(parents=True, exist_ok=True) - with dest.open("wb") as dst: - while chunk := src.read(1024 * 1024): - dst.write(chunk) - sha256.update(chunk) - size += len(chunk) - return sha256.hexdigest(), size - - -def _write_metadata(path: Path, metadata: CacheEntryMetadata | dict[str, Any]) -> None: - tmp_path = path.with_suffix(".tmp") - payload: dict[str, Any] - if isinstance(metadata, CacheEntryMetadata): - payload = metadata.model_dump(mode="json") - else: - payload = metadata - with tmp_path.open("w", encoding="utf-8") as fh: - json.dump(payload, fh, indent=2, sort_keys=True) - os.replace(tmp_path, path) - - -def _now() -> str: - return datetime.now(timezone.utc).isoformat() - - -def _timestamp_suffix() -> str: - return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f") - - -def _read_metadata(meta_path: Path, artifact_path: Path) -> CacheEntryMetadata: - raw = json.loads(meta_path.read_text(encoding="utf-8")) - if "file_size" not in raw: - try: - raw["file_size"] = artifact_path.stat().st_size - except FileNotFoundError: - raw["file_size"] = 0 - raw.setdefault("created_at", _now()) - raw.setdefault("last_used", raw["created_at"]) - raw.setdefault("cache_key", meta_path.parent.name) - return CacheEntryMetadata.model_validate(raw) - - -class _Hasher: - def __init__(self) -> None: - self._hasher = hashlib.sha256() - - def update(self, data: bytes) -> None: - self._hasher.update(data) - - def hexdigest(self) -> str: - return self._hasher.hexdigest() - - -def clear() -> None: - """Remove all cache entries.""" - cache = resolve_local_cache(use_cache=True) - if cache is not None: - cache.clear() - - -def _canonicalize(value: Any) -> Any: - """Convert value into a JSON-serializable object for 
hashing/metadata.""" - - if isinstance(value, dict): - return { - str(k): _canonicalize(v) - for k, v in sorted(value.items(), key=lambda item: str(item[0])) - } - if isinstance(value, (list, tuple)): - return [_canonicalize(v) for v in value] - if isinstance(value, set): - return sorted(_canonicalize(v) for v in value) - if isinstance(value, Enum): - return value.value - if isinstance(value, Path): - return str(value) - if isinstance(value, datetime): - return value.isoformat() - if isinstance(value, bytes): - return value.decode("utf-8", errors="ignore") - return value - - -def build_cache_key( - *, - simulation_hash: str, - version: str, -) -> str: - """Construct a deterministic cache key.""" - - payload = { - "simulation_hash": simulation_hash, - "versions": _canonicalize(version), - } - encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") - return hashlib.sha256(encoded).hexdigest() - - -def build_entry_metadata( - *, - simulation_hash: str, - workflow_type: str, - task_id: str, - version: str, - path: Path, -) -> CacheEntryMetadata: - """Create metadata object for a cache entry.""" - - now = datetime.now(timezone.utc) - return CacheEntryMetadata( - cache_key="", - checksum="", - created_at=now, - last_used=now, - file_size=0, - simulation_hash=simulation_hash, - workflow_type=workflow_type, - versions=_canonicalize(version), - task_id=task_id, - path=str(path), - ) - - -def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache]: - """ - Returns LocalCache instance if enabled. - Returns None if use_cached=False or config-fetched 'enabled' is False. - Deletes old cache directory if existing. 
- """ - global _CACHE - - if use_cache is False or (use_cache is not True and not config.local_cache.enabled): - return None - - if _CACHE is not None and _CACHE._root != Path(config.local_cache.directory): - old_root = _CACHE._root - new_root = Path(config.local_cache.directory) - log.debug(f"Moving cache directory from {old_root} → {new_root}") - try: - new_root.parent.mkdir(parents=True, exist_ok=True) - if old_root.exists(): - shutil.move(old_root, new_root) - except Exception as e: - log.warning(f"Failed to move cache directory: {e}. Delete old cache.") - _remove_cache_dir(old_root, recreate=False) - - _CACHE = LocalCache( - directory=config.local_cache.directory, - max_entries=config.local_cache.max_entries, - max_size_gb=config.local_cache.max_size_gb, - ) - - try: - return _CACHE - except Exception as err: - log.debug(f"Simulation cache unavailable: {err}") - return None - - -def _store_mode_solver_in_cache( - task_id: TaskId, simulation: ModeSolver, data: WorkflowDataType, path: os.PathLike -) -> bool: +def _store_mode_solver_in_cache(task_id: str, simulation: ModeSolver, path: os.PathLike) -> bool: """ Stores the results of a :class:`.ModeSolver` run in the local cache, if available. @@ -888,8 +60,6 @@ def _store_mode_solver_in_cache( Unique identifier of the mode solver task. simulation : :class:`.ModeSolver` Mode solver simulation object whose results should be cached. - data : :class:`.WorkflowDataType` - Data object containing the computed results to store. path : PathLike Path to the result file on disk. 
@@ -906,7 +76,6 @@ def _store_mode_solver_in_cache( simulation_cache = resolve_local_cache() if simulation_cache is not None: stored = simulation_cache.store_result( - stub_data=data, task_id=task_id, path=path, workflow_type=TaskType.MODE_SOLVER.name, @@ -914,6 +83,3 @@ def _store_mode_solver_in_cache( ) return stored return False - - -resolve_local_cache() diff --git a/tidy3d/web/core/__init__.py b/tidy3d/web/core/__init__.py index 281ccca03a..f353d09be6 100644 --- a/tidy3d/web/core/__init__.py +++ b/tidy3d/web/core/__init__.py @@ -1,8 +1,10 @@ -"""Tidy3d core package imports""" +"""Compatibility shim for :mod:`tidy3d._common.web.core`.""" -from __future__ import annotations +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -# TODO(FXC-3827): Drop this import once the legacy shim is removed in Tidy3D 2.12. -from . import environment +# marked as migrated to _common +from __future__ import annotations -__all__ = ["environment"] +from tidy3d._common.web.core import ( + environment, +) diff --git a/tidy3d/web/core/account.py b/tidy3d/web/core/account.py index 30cab0b392..fbc4c856da 100644 --- a/tidy3d/web/core/account.py +++ b/tidy3d/web/core/account.py @@ -1,66 +1,10 @@ -"""Tidy3d user account.""" - -from __future__ import annotations - -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from .http_util import http -from .types import Tidy3DResource +"""Compatibility shim for :mod:`tidy3d._common.web.core.account`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -class Account(Tidy3DResource, extra="allow"): - """Tidy3D User Account.""" - - allowance_cycle_type: Optional[str] = Field( - None, - title="AllowanceCycleType", - description="Daily or Monthly", - alias="allowanceCycleType", - ) - credit: Optional[float] = Field( - 0, title="credit", description="Current FlexCredit balance", alias="credit" - ) - credit_expiration: Optional[datetime] = Field( - None, - 
title="creditExpiration", - description="Expiration date", - alias="creditExpiration", - ) - allowance_current_cycle_amount: Optional[float] = Field( - 0, - title="allowanceCurrentCycleAmount", - description="Daily/Monthly free simulation balance", - alias="allowanceCurrentCycleAmount", - ) - allowance_current_cycle_end_date: Optional[datetime] = Field( - None, - title="allowanceCurrentCycleEndDate", - description="Daily/Monthly free simulation balance expiration date", - alias="allowanceCurrentCycleEndDate", - ) - daily_free_simulation_counts: Optional[int] = Field( - 0, - title="dailyFreeSimulationCounts", - description="Daily free simulation counts", - alias="dailyFreeSimulationCounts", - ) - - @classmethod - def get(cls) -> Optional[Account]: - """Get user account information. - - Parameters - ---------- +# marked as migrated to _common +from __future__ import annotations - Returns - ------- - account : Account - """ - resp = http.get("tidy3d/py/account") - if resp: - account = Account(**resp) - return account - return None +from tidy3d._common.web.core.account import ( + Account, +) diff --git a/tidy3d/web/core/cache.py b/tidy3d/web/core/cache.py index d83421ca21..115080c123 100644 --- a/tidy3d/web/core/cache.py +++ b/tidy3d/web/core/cache.py @@ -1,6 +1,11 @@ -"""Local caches.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.cache`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -FOLDER_CACHE = {} -S3_STS_TOKENS = {} +from tidy3d._common.web.core.cache import ( + FOLDER_CACHE, + S3_STS_TOKENS, +) diff --git a/tidy3d/web/core/constants.py b/tidy3d/web/core/constants.py index 623af2bba8..bb03702f91 100644 --- a/tidy3d/web/core/constants.py +++ b/tidy3d/web/core/constants.py @@ -1,38 +1,34 @@ -"""Defines constants for core.""" - -# HTTP Header key and value -from __future__ import annotations - -HEADER_APIKEY = "simcloud-api-key" -HEADER_VERSION = 
"tidy3d-python-version" -HEADER_SOURCE = "source" -HEADER_SOURCE_VALUE = "Python" -HEADER_USER_AGENT = "User-Agent" -HEADER_APPLICATION = "Application" -HEADER_APPLICATION_VALUE = "TIDY3D" +"""Compatibility shim for :mod:`tidy3d._common.web.core.constants`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -SIMCLOUD_APIKEY = "SIMCLOUD_APIKEY" -KEY_APIKEY = "apikey" -JSON_TAG = "JSON_STRING" -# type of the task_id -TaskId = str -# type of task_name -TaskName = str - - -SIMULATION_JSON = "simulation.json" -SIMULATION_DATA_HDF5 = "output/monitor_data.hdf5" -SIMULATION_DATA_HDF5_GZ = "output/simulation_data.hdf5.gz" -RUNNING_INFO = "output/solver_progress.csv" -SIM_LOG_FILE = "output/tidy3d.log" -SIM_FILE_HDF5 = "simulation.hdf5" -SIM_FILE_HDF5_GZ = "simulation.hdf5.gz" -MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz" -MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz" -SIM_ERROR_FILE = "output/tidy3d_error.json" -SIM_VALIDATION_FILE = "output/tidy3d_validation.json" +# marked as migrated to _common +from __future__ import annotations -# Component modeler specific artifacts -MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz" -CM_DATA_HDF5_GZ = "output/cm_data.hdf5.gz" +from tidy3d._common.web.core.constants import ( + CM_DATA_HDF5_GZ, + HEADER_APIKEY, + HEADER_APPLICATION, + HEADER_APPLICATION_VALUE, + HEADER_SOURCE, + HEADER_SOURCE_VALUE, + HEADER_USER_AGENT, + HEADER_VERSION, + JSON_TAG, + KEY_APIKEY, + MODE_DATA_HDF5_GZ, + MODE_FILE_HDF5_GZ, + MODELER_FILE_HDF5_GZ, + RUNNING_INFO, + SIM_ERROR_FILE, + SIM_FILE_HDF5, + SIM_FILE_HDF5_GZ, + SIM_LOG_FILE, + SIM_VALIDATION_FILE, + SIMCLOUD_APIKEY, + SIMULATION_DATA_HDF5, + SIMULATION_DATA_HDF5_GZ, + SIMULATION_JSON, + TaskId, + TaskName, +) diff --git a/tidy3d/web/core/core_config.py b/tidy3d/web/core/core_config.py index 313f1aff62..b3640c8df2 100644 --- a/tidy3d/web/core/core_config.py +++ b/tidy3d/web/core/core_config.py @@ -1,50 +1,14 @@ -"""Tidy3d core log, need init config from Tidy3d api""" 
+"""Compatibility shim for :mod:`tidy3d._common.web.core.core_config`.""" -from __future__ import annotations - -import logging as log -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from rich.console import Console - - from tidy3d.log import Logger - -# default setting -config_setting = { - "logger": log, - "logger_console": None, - "version": "", -} - - -def set_config(logger: Logger, logger_console: Console, version: str) -> None: - """Init tidy3d core logger and logger console. - - Parameters - ---------- - logger : :class:`.Logger` - Tidy3d log Logger. - logger_console : :class:`.Console` - Get console from logging handlers. - version : str - tidy3d version - """ - config_setting["logger"] = logger - config_setting["logger_console"] = logger_console - config_setting["version"] = version - - -def get_logger() -> Logger: - """Get logging handlers.""" - return config_setting["logger"] - - -def get_logger_console() -> Console: - """Get console from logging handlers.""" - return config_setting["logger_console"] +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def get_version() -> str: - """Get version from cache.""" - return config_setting["version"] +from tidy3d._common.web.core.core_config import ( + config_setting, + get_logger, + get_logger_console, + get_version, + set_config, +) diff --git a/tidy3d/web/core/environment.py b/tidy3d/web/core/environment.py index ffe86d89d4..7873b9b4d2 100644 --- a/tidy3d/web/core/environment.py +++ b/tidy3d/web/core/environment.py @@ -1,42 +1,20 @@ -"""Legacy re-export of configuration environment helpers.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.environment`.""" -from __future__ import annotations - -# TODO(FXC-3827): Remove this module-level legacy shim in Tidy3D 2.12. 
-import warnings -from typing import Any - -from tidy3d.config import Env, Environment, EnvironmentConfig +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -__all__ = [ # noqa: F822 - "Env", - "Environment", - "EnvironmentConfig", - "dev", - "nexus", - "pre", - "prod", - "uat", -] +# marked as migrated to _common +from __future__ import annotations -_LEGACY_ENV_NAMES = {"dev", "uat", "pre", "prod", "nexus"} -_DEPRECATION_MESSAGE = ( - "'tidy3d.web.core.environment.{name}' is deprecated and will be removed in " - "Tidy3D 2.12. Transition to 'tidy3d.config.Env.{name}' or " - "'tidy3d.config.config.switch_profile(...)'." +from tidy3d._common.web.core.environment import ( + _DEPRECATION_MESSAGE, + _LEGACY_ENV_NAMES, + Env, + Environment, + EnvironmentConfig, + _get_legacy_env, + dev, + nexus, + pre, + prod, + uat, ) - - -def _get_legacy_env(name: str) -> Any: - warnings.warn(_DEPRECATION_MESSAGE.format(name=name), DeprecationWarning, stacklevel=2) - return getattr(Env, name) - - -def __getattr__(name: str) -> Any: - if name in _LEGACY_ENV_NAMES: - return _get_legacy_env(name) - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - - -def __dir__() -> list[str]: - return sorted(set(__all__)) diff --git a/tidy3d/web/core/exceptions.py b/tidy3d/web/core/exceptions.py index 4733c4fe08..1900371e9b 100644 --- a/tidy3d/web/core/exceptions.py +++ b/tidy3d/web/core/exceptions.py @@ -1,24 +1,11 @@ -"""Custom Tidy3D exceptions""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.exceptions`.""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .core_config import get_logger - -if TYPE_CHECKING: - from typing import Optional - - -class WebError(Exception): - """Any error in tidy3d""" - - def __init__(self, message: Optional[str] = None) -> None: - """Log just the error message and then raise the Exception.""" - log = get_logger() - super().__init__(message) - log.error(message) +# ruff: noqa: F401 - 
ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class WebNotFoundError(WebError): - """A generic error indicating an HTTP 404 (resource not found).""" +from tidy3d._common.web.core.exceptions import ( + WebError, + WebNotFoundError, +) diff --git a/tidy3d/web/core/file_util.py b/tidy3d/web/core/file_util.py index a2160ec66b..b908de53a2 100644 --- a/tidy3d/web/core/file_util.py +++ b/tidy3d/web/core/file_util.py @@ -1,87 +1,15 @@ -"""File compression utilities""" - -from __future__ import annotations - -import gzip -import os -import shutil -import tempfile - -import h5py - -from tidy3d.web.core.constants import JSON_TAG - - -def compress_file_to_gzip(input_file: os.PathLike, output_gz_file: os.PathLike) -> None: - """Compresses a file using gzip. - - Parameters - ---------- - input_file : PathLike - The path of the input file. - output_gz_file : PathLike - The path of the output gzip file. - """ - with open(input_file, "rb") as file_in: - with gzip.open(output_gz_file, "wb") as file_out: - shutil.copyfileobj(file_in, file_out) - - -def extract_gzip_file(input_gz_file: os.PathLike, output_file: os.PathLike) -> None: - """Extract a gzip file. - - Parameters - ---------- - input_gz_file : PathLike - The path of the gzip input file. - output_file : PathLike - The path of the output file. 
- """ - with gzip.open(input_gz_file, "rb") as file_in: - with open(output_file, "wb") as file_out: - shutil.copyfileobj(file_in, file_out) +"""Compatibility shim for :mod:`tidy3d._common.web.core.file_util`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def read_simulation_from_hdf5_gz(file_name: os.PathLike) -> str: - """read simulation str from hdf5.gz""" - - hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5") - os.close(hdf5_file) - try: - extract_gzip_file(file_name, hdf5_file_path) - # Pass the uncompressed temporary file path to the reader - json_str = read_simulation_from_hdf5(hdf5_file_path) - finally: - os.unlink(hdf5_file_path) - return json_str - - -"""TODO: _json_string_key and read_simulation_from_hdf5 are duplicated functions that also exist -as methods in Tidy3dBaseModel. For consistency it would be best if this duplication is avoided.""" - - -def _json_string_key(index: int) -> str: - """Get json string key for string chunk number ``index``.""" - if index: - return f"{JSON_TAG}_{index}" - return JSON_TAG - - -def read_simulation_from_hdf5(file_name: os.PathLike) -> bytes: - """read simulation str from hdf5""" - with h5py.File(file_name, "r") as f_handle: - num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) - json_string = b"" - for ind in range(num_string_parts): - json_string += f_handle[_json_string_key(ind)][()] - return json_string - - -"""End TODO""" - +# marked as migrated to _common +from __future__ import annotations -def read_simulation_from_json(file_name: os.PathLike) -> str: - """read simulation str from json""" - with open(file_name, encoding="utf-8") as json_file: - json_data = json_file.read() - return json_data +from tidy3d._common.web.core.file_util import ( + _json_string_key, + compress_file_to_gzip, + extract_gzip_file, + read_simulation_from_hdf5, + read_simulation_from_hdf5_gz, + read_simulation_from_json, +) diff --git a/tidy3d/web/core/http_util.py 
b/tidy3d/web/core/http_util.py index dbf657d9b7..31170a69f5 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -1,286 +1,20 @@ -"""Http connection pool and authentication management.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.http_util`.""" -from __future__ import annotations - -import json -import os -import ssl -from enum import Enum -from functools import wraps -from typing import TYPE_CHECKING, Any +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.ssl_ import create_urllib3_context - -from tidy3d import log -from tidy3d.config import config +# marked as migrated to _common +from __future__ import annotations -from . import core_config -from .constants import ( - HEADER_APIKEY, - HEADER_APPLICATION, - HEADER_APPLICATION_VALUE, - HEADER_SOURCE, - HEADER_SOURCE_VALUE, - HEADER_USER_AGENT, - HEADER_VERSION, - SIMCLOUD_APIKEY, +from tidy3d._common.web.core.http_util import ( + HttpSessionManager, + JSONType, + ResponseCodes, + TLSAdapter, + api_key, + api_key_auth, + get_headers, + get_user_agent, + get_version, + http, + http_interceptor, ) -from .core_config import get_logger -from .exceptions import WebError, WebNotFoundError - -if TYPE_CHECKING: - from typing import Callable, Optional, TypeAlias - -JSONType: TypeAlias = dict[str, Any] | list[Any] | str | int - - -class ResponseCodes(Enum): - """HTTP response codes to handle individually.""" - - UNAUTHORIZED = 401 - OK = 200 - NOT_FOUND = 404 - - -def get_version() -> str: - """Get the version for the current environment.""" - return core_config.get_version() - # return "2.10.0rc2.1" - - -def get_user_agent() -> str: - """Get the user agent the current environment.""" - return os.environ.get("TIDY3D_AGENT", f"Python-Client/{get_version()}") - - -def api_key() -> Optional[str]: - """Get the api key for the current environment.""" - - if 
os.environ.get(SIMCLOUD_APIKEY): - return os.environ.get(SIMCLOUD_APIKEY) - - try: - apikey = config.web.apikey - except AttributeError: - return None - - if apikey is None: - return None - if hasattr(apikey, "get_secret_value"): - return apikey.get_secret_value() - return str(apikey) - - -def api_key_auth(request: requests.request) -> requests.request: - """Save the authentication info in a request. - - Parameters - ---------- - request : requests.request - The original request to set authentication for. - - Returns - ------- - requests.request - The request with authentication set. - """ - key = api_key() - version = get_version() - if not key: - raise ValueError( - "API key not found. To get your API key, sign into 'https://tidy3d.simulation.cloud' " - "and copy it from your 'Account' page. Then you can configure tidy3d through command " - "line 'tidy3d configure' and enter your API key when prompted. " - "Alternatively, especially if using windows, you can manually create the configuration " - "file by creating a file at their home directory '~/.tidy3d/config' (unix) or " - "'.tidy3d/config' (windows) containing the following line: " - "apikey = 'XXX'. Here XXX is your API key copied from your account page within quotes." - ) - if not version: - raise ValueError("version not found.") - - request.headers[HEADER_APIKEY] = key - request.headers[HEADER_VERSION] = version - request.headers[HEADER_SOURCE] = HEADER_SOURCE_VALUE - request.headers[HEADER_USER_AGENT] = get_user_agent() - return request - - -def get_headers() -> dict[str, Optional[str]]: - """get headers for http request. - - Returns - ------- - dict[str, str] - dictionary with "Authorization" and "Application" keys. 
- """ - return { - HEADER_APIKEY: api_key(), - HEADER_APPLICATION: HEADER_APPLICATION_VALUE, - HEADER_USER_AGENT: get_user_agent(), - } - - -def http_interceptor(func: Callable[..., Any]) -> Callable[..., JSONType]: - """Intercept the response and raise an exception if the status code is not 200.""" - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> JSONType: - """The wrapper function.""" - suppress_404 = kwargs.pop("suppress_404", False) - - # Extend some capabilities of func - resp = func(*args, **kwargs) - - if resp.status_code != ResponseCodes.OK.value: - if resp.status_code == ResponseCodes.NOT_FOUND.value: - if suppress_404: - return None - raise WebNotFoundError("Resource not found (HTTP 404).") - try: - json_resp = resp.json() - except Exception: - json_resp = None - - # Build a helpful error message using available fields - err_msg = None - if isinstance(json_resp, dict): - parts = [] - for key in ("error", "message", "msg", "detail", "code", "httpStatus", "warning"): - val = json_resp.get(key) - if not val: - continue - if key == "error": - # Always include the raw 'error' payload for debugging. Also try to extract a nested message. 
- if isinstance(val, str): - try: - nested = json.loads(val) - if isinstance(nested, dict): - nested_msg = ( - nested.get("message") - or nested.get("error") - or nested.get("msg") - ) - if nested_msg: - parts.append(str(nested_msg)) - except Exception: - pass - parts.append(f"error={val}") - else: - parts.append(f"error={val!s}") - continue - parts.append(str(val)) - if parts: - err_msg = "; ".join(parts) - if not err_msg: - # Fallback to response text or status code - err_msg = resp.text or f"HTTP {resp.status_code}" - - # Append request context to aid debugging - try: - method = getattr(resp.request, "method", "") - url = getattr(resp.request, "url", "") - err_msg = f"{err_msg} [HTTP {resp.status_code} {method} {url}]" - except Exception: - pass - - raise WebError(err_msg) - - if not resp.text: - return None - result = resp.json() - - if isinstance(result, dict): - warning = result.get("warning") - if warning: - log = get_logger() - log.warning(warning) - - if "data" in result: - return result["data"] - - return result - - return wrapper - - -class TLSAdapter(HTTPAdapter): - def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: - try: - ssl_version = ( - ssl.TLSVersion[config.web.ssl_version] - if config.web.ssl_version is not None - else None - ) - except KeyError: - log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default") - ssl_version = None - context = create_urllib3_context(ssl_version=ssl_version) - kwargs["ssl_context"] = context - return super().init_poolmanager(*args, **kwargs) - - -class HttpSessionManager: - """Http util class.""" - - def __init__(self, session: requests.Session) -> None: - """Initialize the session.""" - self.session = session - self._mounted_ssl_version = None - self._ensure_tls_adapter(config.web.ssl_version) - self.session.verify = config.web.ssl_verify - - def reinit(self) -> None: - """Reinitialize the session.""" - ssl_version = config.web.ssl_version - self._ensure_tls_adapter(ssl_version) - 
self.session.verify = config.web.ssl_verify - - def _ensure_tls_adapter(self, ssl_version: str) -> None: - if not ssl_version: - self._mounted_ssl_version = None - return - if self._mounted_ssl_version != ssl_version: - self.session.mount("https://", TLSAdapter()) - self._mounted_ssl_version = ssl_version - - @http_interceptor - def get( - self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None - ) -> requests.Response: - """Get the resource.""" - self.reinit() - return self.session.get( - url=config.web.build_api_url(path), auth=api_key_auth, json=json, params=params - ) - - @http_interceptor - def post(self, path: str, json: JSONType = None) -> requests.Response: - """Create the resource.""" - self.reinit() - return self.session.post(config.web.build_api_url(path), json=json, auth=api_key_auth) - - @http_interceptor - def put( - self, path: str, json: JSONType = None, files: Optional[dict[str, Any]] = None - ) -> requests.Response: - """Update the resource.""" - self.reinit() - return self.session.put( - config.web.build_api_url(path), json=json, auth=api_key_auth, files=files - ) - - @http_interceptor - def delete( - self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None - ) -> requests.Response: - """Delete the resource.""" - self.reinit() - return self.session.delete( - config.web.build_api_url(path), auth=api_key_auth, json=json, params=params - ) - - -http = HttpSessionManager(requests.Session()) diff --git a/tidy3d/web/core/s3utils.py b/tidy3d/web/core/s3utils.py index ebdf048733..401347b271 100644 --- a/tidy3d/web/core/s3utils.py +++ b/tidy3d/web/core/s3utils.py @@ -1,472 +1,22 @@ -"""handles filesystem, storage""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.s3utils`.""" -from __future__ import annotations +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import os -import tempfile -import urllib -from datetime import datetime -from enum import Enum -from pathlib import 
Path -from typing import TYPE_CHECKING, Any +# marked as migrated to _common +from __future__ import annotations -import boto3 -from boto3.s3.transfer import TransferConfig -from pydantic import BaseModel, Field -from rich.progress import ( - BarColumn, - DownloadColumn, - Progress, - TextColumn, - TimeRemainingColumn, - TransferSpeedColumn, +from tidy3d._common.web.core.s3utils import ( + IN_TRANSIT_SUFFIX, + DownloadProgress, + UploadProgress, + _get_progress, + _s3_config, + _s3_sts_tokens, + _S3Action, + _S3STSToken, + _UserCredential, + download_file, + download_gz_file, + get_s3_sts_token, + upload_file, ) - -from tidy3d.config import config - -from .core_config import get_logger_console -from .exceptions import WebError -from .file_util import extract_gzip_file -from .http_util import http - -if TYPE_CHECKING: - from collections.abc import Mapping - from os import PathLike - from typing import Callable, Optional - - import rich - -IN_TRANSIT_SUFFIX = ".tmp" - - -class _UserCredential(BaseModel): - """Stores information about user credentials.""" - - access_key_id: str = Field(alias="accessKeyId") - expiration: datetime - secret_access_key: str = Field(alias="secretAccessKey") - session_token: str = Field(alias="sessionToken") - - -class _S3STSToken(BaseModel): - """Stores information about S3 token.""" - - cloud_path: str = Field(alias="cloudpath") - user_credential: _UserCredential = Field(alias="userCredentials") - - def get_bucket(self) -> str: - """Get the bucket name for this token.""" - - r = urllib.parse.urlparse(self.cloud_path) - return r.netloc - - def get_s3_key(self) -> str: - """Get the s3 key for this token.""" - - r = urllib.parse.urlparse(self.cloud_path) - return r.path[1:] - - def get_client(self) -> boto3.client: - """Get the boto client for this token. - - Automatically configures custom S3 endpoint if specified in web.env_vars. 
- """ - - client_kwargs = { - "service_name": "s3", - "region_name": config.web.s3_region, - "aws_access_key_id": self.user_credential.access_key_id, - "aws_secret_access_key": self.user_credential.secret_access_key, - "aws_session_token": self.user_credential.session_token, - "verify": config.web.ssl_verify, - } - - # Add custom S3 endpoint if configured (e.g., for Nexus deployments) - if config.web.env_vars and "AWS_ENDPOINT_URL_S3" in config.web.env_vars: - s3_endpoint = config.web.env_vars["AWS_ENDPOINT_URL_S3"] - client_kwargs["endpoint_url"] = s3_endpoint - - return boto3.client(**client_kwargs) - - def is_expired(self) -> bool: - """True if token is expired.""" - - return ( - self.user_credential.expiration - - datetime.now(tz=self.user_credential.expiration.tzinfo) - ).total_seconds() < 300 - - -class UploadProgress: - """Updates progressbar with the upload status. - - Attributes - ---------- - progress : rich.progress.Progress() - Progressbar instance from rich - ul_task : rich.progress.Task - Progressbar task instance. - """ - - def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None: - """initialize with the size of file and rich.progress.Progress() instance. - - Parameters - ---------- - size_bytes: int - Number of total bytes to upload. - progress : rich.progress.Progress() - Progressbar instance from rich - """ - self.progress = progress - self.ul_task = self.progress.add_task("[red]Uploading...", total=size_bytes) - - def report(self, bytes_in_chunk: Any) -> None: - """Update the progressbar with the most recent chunk. - - Parameters - ---------- - bytes_in_chunk : int - Description - """ - self.progress.update(self.ul_task, advance=bytes_in_chunk) - - -class DownloadProgress: - """Updates progressbar using the download status. - - Attributes - ---------- - progress : rich.progress.Progress() - Progressbar instance from rich - ul_task : rich.progress.Task - Progressbar task instance. 
- """ - - def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None: - """initialize with the size of file and rich.progress.Progress() instance - - Parameters - ---------- - size_bytes: float - Number of total bytes to download. - progress : rich.progress.Progress() - Progressbar instance from rich - """ - self.progress = progress - self.dl_task = self.progress.add_task("[red]Downloading...", total=size_bytes) - - def report(self, bytes_in_chunk: int) -> None: - """Update the progressbar with the most recent chunk. - - Parameters - ---------- - bytes_in_chunk : float - Description - """ - self.progress.update(self.dl_task, advance=bytes_in_chunk) - - -class _S3Action(Enum): - UPLOADING = "↑" - DOWNLOADING = "↓" - - -def _get_progress(action: _S3Action) -> Progress: - """Get the progress of an action.""" - - col = ( - TextColumn(f"[bold green]{_S3Action.DOWNLOADING.value}") - if action == _S3Action.DOWNLOADING - else TextColumn(f"[bold red]{_S3Action.UPLOADING.value}") - ) - return Progress( - col, - TextColumn("[bold blue]{task.fields[filename]}"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - DownloadColumn(), - "•", - TransferSpeedColumn(), - "•", - TimeRemainingColumn(), - console=get_logger_console(), - ) - - -_s3_config = TransferConfig() - -_s3_sts_tokens: dict[str, _S3STSToken] = {} - - -def get_s3_sts_token( - resource_id: str, file_name: PathLike, extra_arguments: Optional[Mapping[str, str]] = None -) -> _S3STSToken: - """Get s3 sts token for the given resource id and file name. - - Parameters - ---------- - resource_id : str - The resource id, e.g. task id. - file_name : PathLike - The remote file name on S3. - extra_arguments : Mapping[str, str] - Additional arguments for the query url. - - Returns - ------- - _S3STSToken - The S3 STS token. 
- """ - file_name = str(Path(file_name).as_posix()) - cache_key = f"{resource_id}:{file_name}" - if cache_key not in _s3_sts_tokens or _s3_sts_tokens[cache_key].is_expired(): - method = f"tidy3d/py/tasks/{resource_id}/file?filename={file_name}" - if extra_arguments is not None: - method += "&" + "&".join(f"{k}={v}" for k, v in extra_arguments.items()) - resp = http.get(method) - token = _S3STSToken.model_validate(resp) - _s3_sts_tokens[cache_key] = token - return _s3_sts_tokens[cache_key] - - -def upload_file( - resource_id: str, - path: PathLike, - remote_filename: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - extra_arguments: Optional[Mapping[str, str]] = None, -) -> None: - """Upload a file to S3. - - Parameters - ---------- - resource_id : str - The resource id, e.g. task id. - path : PathLike - Path to the file to upload. - remote_filename : PathLike - The remote file name on S3 relative to the resource context root path. - verbose : bool = True - Whether to display a progressbar for the upload. - progress_callback : Callable[[float], None] = None - User-supplied callback function with ``bytes_in_chunk`` as argument. - extra_arguments : Mapping[str, str] - Additional arguments used to specify the upload bucket. - """ - - path = Path(path) - token = get_s3_sts_token(resource_id, remote_filename, extra_arguments) - - def _upload(_callback: Callable) -> None: - """Perform the upload with a callback function. 
- - Parameters - ---------- - _callback : Callable[[float], None] - Callback function for upload, accepts ``bytes_in_chunk`` - """ - - with path.open("rb") as data: - token.get_client().upload_fileobj( - data, - Bucket=token.get_bucket(), - Key=token.get_s3_key(), - Callback=_callback, - Config=_s3_config, - ExtraArgs={"ContentEncoding": "gzip"} - if token.get_s3_key().endswith(".gz") - else None, - ) - - if progress_callback is not None: - _upload(progress_callback) - else: - if verbose: - with _get_progress(_S3Action.UPLOADING) as progress: - total_size = path.stat().st_size - task_id = progress.add_task( - "upload", filename=str(remote_filename), total=total_size - ) - - def _callback(bytes_in_chunk: int) -> None: - progress.update(task_id, advance=bytes_in_chunk) - - _upload(_callback) - - progress.update(task_id, completed=total_size, refresh=True) - - else: - _upload(lambda bytes_in_chunk: None) - - -def download_file( - resource_id: str, - remote_filename: PathLike, - to_file: Optional[PathLike] = None, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, -) -> Path: - """Download file from S3. - - Parameters - ---------- - resource_id : str - The resource id, e.g. task id. - remote_filename : PathLike - Path to the remote file. - to_file : PathLike = None - Local filename to save to; if not specified, defaults to ``remote_filename`` in a - directory named after ``resource_id``. - verbose : bool = True - Whether to display a progressbar for the upload - progress_callback : Callable[[float], None] = None - User-supplied callback function with ``bytes_in_chunk`` as argument. 
- """ - - token = get_s3_sts_token(resource_id, remote_filename) - client = token.get_client() - meta_data = client.head_object(Bucket=token.get_bucket(), Key=token.get_s3_key()) - - # Get only last part of the remote file name - remote_basename = Path(remote_filename).name - - # set to_file if None - if to_file is None: - to_path = Path(resource_id) / remote_basename - else: - to_path = Path(to_file) - - # make the leading directories in the 'to_path', if any - to_path.parent.mkdir(parents=True, exist_ok=True) - - def _download(_callback: Callable) -> None: - """Perform the download with a callback function. - - Parameters - ---------- - _callback : Callable[[float], None] - Callback function for download, accepts ``bytes_in_chunk`` - """ - # Caller can assume the existence of the file means download succeeded. - # So make sure this file does not exist until that assumption is true. - to_path.unlink(missing_ok=True) - # Download to a temporary file. - try: - fd, tmp_file_path_str = tempfile.mkstemp(suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent) - os.close(fd) # `tempfile.mkstemp()` creates and opens a randomly named file. close it. - to_path_tmp = Path(tmp_file_path_str) - client.download_file( - Bucket=token.get_bucket(), - Filename=str(to_path_tmp), - Key=token.get_s3_key(), - Callback=_callback, - Config=_s3_config, - ) - to_path_tmp.rename(to_path) - except Exception as e: - to_path_tmp.unlink(missing_ok=True) # Delete incompletely downloaded file. 
- raise e - - if progress_callback is not None: - _download(progress_callback) - else: - if verbose: - with _get_progress(_S3Action.DOWNLOADING) as progress: - total_size = meta_data.get("ContentLength", 0) - progress.start() - task_id = progress.add_task("download", filename=remote_basename, total=total_size) - - def _callback(bytes_in_chunk: int) -> None: - progress.update(task_id, advance=bytes_in_chunk) - - _download(_callback) - - progress.update(task_id, completed=total_size, refresh=True) - - else: - _download(lambda bytes_in_chunk: None) - - return to_path - - -def download_gz_file( - resource_id: str, - remote_filename: PathLike, - to_file: Optional[PathLike] = None, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, -) -> Path: - """Download a ``.gz`` file and unzip it into ``to_file``, unless ``to_file`` itself - ends in .gz - - Parameters - ---------- - resource_id : str - The resource id, e.g. task id. - remote_filename : PathLike - Path to the remote file. - to_file : Optional[PathLike] = None - Local filename to save to; if not specified, defaults to ``remote_filename`` with the - ``.gz`` suffix removed in a directory named after ``resource_id``. - verbose : bool = True - Whether to display a progressbar for the upload - progress_callback : Callable[[float], None] = None - User-supplied callback function with ``bytes_in_chunk`` as argument. 
- """ - - # If to_file is a gzip extension, just download - if to_file is None: - remote_basename = Path(remote_filename).name - if remote_basename.endswith(".gz"): - remote_basename = remote_basename[:-3] - to_path = Path(resource_id) / remote_basename - else: - to_path = Path(to_file) - - suffixes = "".join(to_path.suffixes).lower() - if suffixes.endswith(".gz"): - return download_file( - resource_id, - remote_filename, - to_file=to_path, - verbose=verbose, - progress_callback=progress_callback, - ) - - # Otherwise, download and unzip - # The tempfile is set as ``hdf5.gz`` so that the mock download in the webapi tests works - tmp_file, tmp_file_path_str = tempfile.mkstemp(".hdf5.gz") - os.close(tmp_file) - - # make the leading directories in the 'to_file', if any - to_path.parent.mkdir(parents=True, exist_ok=True) - try: - download_file( - resource_id, - remote_filename, - to_file=Path(tmp_file_path_str), - verbose=verbose, - progress_callback=progress_callback, - ) - if not Path(tmp_file_path_str).exists(): - raise WebError(f"Failed to download and extract '{remote_filename}'.") - - tmp_out_fd, tmp_out_path_str = tempfile.mkstemp( - suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent - ) - os.close(tmp_out_fd) - tmp_out_path = Path(tmp_out_path_str) - try: - extract_gzip_file(Path(tmp_file_path_str), tmp_out_path) - tmp_out_path.replace(to_path) - except Exception as e: - tmp_out_path.unlink(missing_ok=True) - raise WebError( - f"Failed to extract '{remote_filename}' from '{tmp_file_path_str}' to '{to_path}'." 
- ) from e - finally: - Path(tmp_file_path_str).unlink(missing_ok=True) - return to_path diff --git a/tidy3d/web/core/stub.py b/tidy3d/web/core/stub.py index cebffd9ba0..91fc96ac90 100644 --- a/tidy3d/web/core/stub.py +++ b/tidy3d/web/core/stub.py @@ -1,84 +1,11 @@ -"""Defines interface that can be subclassed to use with the tidy3d webapi""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.stub`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from os import PathLike - - -class TaskStubData(ABC): - @abstractmethod - def from_file(self, file_path: PathLike) -> TaskStubData: - """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file. - - Parameters - ---------- - file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. - - Returns - ------- - :class:`Stub` - An instance of the component class calling ``load``. - - """ - - @abstractmethod - def to_file(self, file_path: PathLike) -> None: - """Loads a :class:`Stub` from .yaml, .json, or .hdf5 file. - - Parameters - ---------- - file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - Returns - ------- - :class:`Stub` - An instance of the component class calling ``load``. - """ - - -class TaskStub(ABC): - @abstractmethod - def from_file(self, file_path: PathLike) -> TaskStub: - """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file. - - Parameters - ---------- - file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from. - - Returns - ------- - :class:`TaskStubData` - An instance of the component class calling ``load``. - """ - - @abstractmethod - def to_file(self, file_path: PathLike) -> None: - """Loads a :class:`TaskStub` from .yaml, .json, .hdf5 or .hdf5.gz file. 
- - Parameters - ---------- - file_path : PathLike - Full path to the .yaml or .json or .hdf5 file to load the :class:`TaskStub` from. - - Returns - ------- - :class:`Stub` - An instance of the component class calling ``load``. - """ - - @abstractmethod - def to_hdf5_gz(self, fname: PathLike) -> None: - """Exports :class:`TaskStub` instance to .hdf5.gz file. +# marked as migrated to _common +from __future__ import annotations - Parameters - ---------- - fname : PathLike - Full path to the .hdf5.gz file to save the :class:`TaskStub` to. - """ +from tidy3d._common.web.core.stub import ( + TaskStub, + TaskStubData, +) diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py index 194e4235c7..704961bae8 100644 --- a/tidy3d/web/core/task_core.py +++ b/tidy3d/web/core/task_core.py @@ -1,1005 +1,14 @@ -"""Tidy3d webapi types.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.task_core`.""" -from __future__ import annotations - -import os -import pathlib -import tempfile -from datetime import datetime -from typing import TYPE_CHECKING, Optional - -from botocore.exceptions import ClientError -from pydantic import Field, TypeAdapter +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import tidy3d as td -from tidy3d.config import config -from tidy3d.exceptions import ValidationError +# marked as migrated to _common +from __future__ import annotations -from . 
import http_util -from .cache import FOLDER_CACHE -from .constants import ( - SIM_ERROR_FILE, - SIM_FILE_HDF5_GZ, - SIM_LOG_FILE, - SIM_VALIDATION_FILE, - SIMULATION_DATA_HDF5_GZ, +from tidy3d._common.web.core.task_core import ( + BatchTask, + Folder, + SimulationTask, + TaskFactory, + WebTask, ) -from .core_config import get_logger_console -from .exceptions import WebError, WebNotFoundError -from .file_util import read_simulation_from_hdf5 -from .http_util import get_version as _get_protocol_version -from .http_util import http -from .s3utils import download_file, download_gz_file, upload_file -from .task_info import BatchDetail, TaskInfo -from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource - -if TYPE_CHECKING: - from os import PathLike - from typing import Callable, Union - - import requests - - from .stub import TaskStub - - -class Folder(Tidy3DResource, Queryable, extra="allow"): - """Tidy3D Folder.""" - - folder_id: str = Field( - title="Folder id", - description="folder id", - alias="projectId", - ) - folder_name: str = Field( - title="Folder name", - description="folder name", - alias="projectName", - ) - - @classmethod - def list(cls, projects_endpoint: str = "tidy3d/projects") -> []: - """List all folders. - - Returns - ------- - folders : [Folder] - List of folders - """ - resp = http.get(projects_endpoint) - return TypeAdapter(list[Folder]).validate_python(resp) if resp else None - - @classmethod - def get( - cls, - folder_name: str, - create: bool = False, - projects_endpoint: str = "tidy3d/projects", - project_endpoint: str = "tidy3d/project", - ) -> Folder: - """Get folder by name. - - Parameters - ---------- - folder_name : str - Name of the folder. - create : str - If the folder doesn't exist, create it. 
- - Returns - ------- - folder : Folder - """ - folder = FOLDER_CACHE.get(folder_name) - if not folder: - resp = http.get(project_endpoint, params={"projectName": folder_name}) - if resp: - folder = Folder(**resp) - if create and not folder: - resp = http.post(projects_endpoint, {"projectName": folder_name}) - if resp: - folder = Folder(**resp) - FOLDER_CACHE[folder_name] = folder - return folder - - @classmethod - def create(cls, folder_name: str) -> Folder: - """Create a folder, return existing folder if there is one has the same name. - - Parameters - ---------- - folder_name : str - Name of the folder. - - Returns - ------- - folder : Folder - """ - return Folder.get(folder_name, True) - - def delete(self, projects_endpoint: str = "tidy3d/projects") -> None: - """Remove this folder.""" - - http.delete(f"{projects_endpoint}/{self.folder_id}") - - def delete_old(self, days_old: int) -> int: - """Remove folder contents older than ``days_old``.""" - - return http.delete( - f"tidy3d/tasks/{self.folder_id}/tasks", - params={"daysOld": days_old}, - ) - - def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DResource]: - """List all tasks in this folder. 
- - Returns - ------- - tasks : list[:class:`.SimulationTask`] - List of tasks in this folder - """ - resp = http.get(f"{projects_endpoint}/{self.folder_id}/tasks") - return TypeAdapter(list[SimulationTask]).validate_python(resp) if resp else None - - -class WebTask(ResourceLifecycle, Submittable, extra="allow"): - """Interface for managing the running a task on the server.""" - - task_id: Optional[str] = Field( - None, - title="task_id", - description="Task ID number, set when the task is uploaded, leave as None.", - alias="taskId", - ) - - @classmethod - def create( - cls, - task_type: str, - task_name: str, - folder_name: str = "default", - callback_url: Optional[str] = None, - simulation_type: str = "tidy3d", - parent_tasks: Optional[list[str]] = None, - file_type: str = "Gz", - projects_endpoint: str = "tidy3d/projects", - ) -> SimulationTask: - """Create a new task on the server. - - Parameters - ---------- - task_type: :class".TaskType" - The type of task. - task_name: str - The name of the task. - folder_name: str, - The name of the folder to store the task. Default is "default". - callback_url: str - Http PUT url to receive simulation finish event. The body content is a json file with - fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``. - simulation_type : str - Type of simulation being uploaded. - parent_tasks : list[str] - List of related task ids. - file_type: str - the simulation file type Json, Hdf5, Gz - - Returns - ------- - :class:`SimulationTask` - :class:`SimulationTask` object containing info about status, size, - credits of task and others. 
- """ - - # handle backwards compatibility, "tidy3d" is the default simulation_type - if simulation_type is None: - simulation_type = "tidy3d" - - folder = Folder.get(folder_name, create=True) - if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: - payload = { - "groupName": task_name, - "folderId": folder.folder_id, - "fileType": file_type, - "taskType": task_type, - } - resp = http.post("rf/task", payload) - else: - payload = { - "taskName": task_name, - "taskType": task_type, - "callbackUrl": callback_url, # type: ignore[dict-item] - "simulationType": simulation_type, - "parentTasks": parent_tasks, # type: ignore[dict-item] - "fileType": file_type, - } - resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) - return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) - - def get_url(self) -> str: - base = str(config.web.website_endpoint or "") - if isinstance(self, BatchTask): - return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"]) - return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"]) - - def get_folder_url(self) -> Optional[str]: - folder_id = getattr(self, "folder_id", None) - if not folder_id: - return None - base = str(config.web.website_endpoint or "") - return "/".join([base.rstrip("/"), f"folders/{folder_id}"]) - - def get_log( - self, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Get log file from Server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while downloading the data. - - Returns - ------- - path: pathlib.Path - Path to saved file. 
- """ - - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - return download_file( - self.task_id, - SIM_LOG_FILE, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - - def get_data_hdf5( - self, - to_file: PathLike, - remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Download data artifact (simulation or batch) with gz fallback handling. - - Parameters - ---------- - remote_data_file_gz : PathLike - Gzipped remote filename. - to_file : PathLike - Local target path. - verbose : bool - Whether to log progress. - progress_callback : Optional[Callable[[float], None]] - Progress callback. - - Returns - ------- - pathlib.Path - Saved local path. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - target_path = pathlib.Path(to_file) - file = None - try: - file = download_gz_file( - resource_id=self.task_id, - remote_filename=remote_data_file_gz, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file_gz}'.") - if not file: - try: - file = download_file( - resource_id=self.task_id, - remote_filename=str(remote_data_file_gz)[:-3], - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the data file from the server. " - "Please confirm that the task completed successfully." - ) from e - return file - - @staticmethod - def is_batch(resource_id: str) -> bool: - """Checks if a given resource ID corresponds to a valid batch task. - - This is a utility function to verify a batch task's existence before - instantiating the class. 
- - Parameters - ---------- - resource_id : str - The unique identifier for the resource. - - Returns - ------- - bool - ``True`` if the resource is a valid batch task, ``False`` otherwise. - """ - try: - # TODO PROPERLY FIXME - # Disable non critical logs due to check for resourceId, until we have a dedicated API for this - resp = http.get( - f"rf/task/{resource_id}/statistics", - suppress_404=True, - ) - status = bool(resp and isinstance(resp, dict) and "status" in resp) - return status - except Exception: - return False - - def delete(self, versions: bool = False) -> None: - """Delete current task from server. - - Parameters - ---------- - versions : bool = False - If ``True``, delete all versions of the task in the task group. Otherwise, delete only - the version associated with the current task ID. - """ - if not self.task_id: - raise ValueError("Task id not found.") - - task_details = self.detail().model_dump() - - if task_details and "groupId" in task_details: - group_id = task_details["groupId"] - if versions: - http.delete("tidy3d/group", json={"groupIds": [group_id]}) - return - elif "version" in task_details: - version = task_details["version"] - http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) - return - - # Fallback to old method if we can't get the groupId and version - http.delete(f"tidy3d/tasks/{self.task_id}") - - -class SimulationTask(WebTask): - """Interface for managing the running of solver tasks on the server.""" - - folder_id: Optional[str] = Field( - None, - title="folder_id", - description="Folder ID number, set when the task is uploaded, leave as None.", - alias="folderId", - ) - status: Optional[str] = Field(None, title="status", description="Simulation task status.") - - real_flex_unit: Optional[float] = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" - ) - - created_at: Optional[datetime] = Field( - None, - title="created_at", - description="Time at which this 
task was created.", - alias="createdAt", - ) - - task_type: Optional[str] = Field( - None, title="task_type", description="The type of task.", alias="taskType" - ) - - folder_name: Optional[str] = Field( - "default", - title="Folder Name", - description="Name of the folder associated with this task.", - alias="folderName", - ) - - callback_url: Optional[str] = Field( - None, - title="Callback URL", - description="Http PUT url to receive simulation finish event. " - "The body content is a json file with fields " - "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", - ) - - # simulation_type: str = Field( - # None, - # title="Simulation Type", - # description="Type of simulation, used internally only.", - # ) - - # parent_tasks: Tuple[TaskId, ...] = Field( - # None, - # title="Parent Tasks", - # description="List of parent task ids for the simulation, used internally only." - # ) - - @classmethod - def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: - """Get task from the server by id. - - Parameters - ---------- - task_id: str - Unique identifier of task on server. - verbose: - If `True`, will print progressbars and status, otherwise, will run silently. - - Returns - ------- - :class:`.SimulationTask` - :class:`.SimulationTask` object containing info about status, - size, credits of task and others. - """ - try: - resp = http.get(f"tidy3d/tasks/{task_id}/detail") - except WebNotFoundError as e: - td.log.error(f"The requested task ID '{task_id}' does not exist.") - raise e - - task = SimulationTask(**resp) if resp else None - return task - - @classmethod - def get_running_tasks(cls) -> list[SimulationTask]: - """Get a list of running tasks from the server" - - Returns - ------- - List[:class:`.SimulationTask`] - :class:`.SimulationTask` object containing info about status, - size, credits of task and others. 
- """ - resp = http.get("tidy3d/py/tasks") - if not resp: - return [] - return TypeAdapter(list[SimulationTask]).validate_python(resp) - - def detail(self) -> TaskInfo: - """Fetches the detailed information and status of the task. - - Returns - ------- - TaskInfo - An object containing the task's latest data. - """ - resp = http.get(f"tidy3d/tasks/{self.task_id}/detail") - return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp}) # type: ignore[dict-item] - - def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None: - """Get json file for a :class:`.Simulation` from server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - to_file = pathlib.Path(to_file) - - hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5") - os.close(hdf5_file) - try: - self.get_simulation_hdf5(hdf5_file_path) - if os.path.exists(hdf5_file_path): - json_string = read_simulation_from_hdf5(hdf5_file_path) - to_file.parent.mkdir(parents=True, exist_ok=True) - with to_file.open("w", encoding="utf-8") as file: - # Write the string to the file - file.write(json_string.decode("utf-8")) - if verbose: - console = get_logger_console() - console.log(f"Generate {to_file} successfully.") - else: - raise WebError("Failed to download simulation.json.") - finally: - os.unlink(hdf5_file_path) - - def upload_simulation( - self, - stub: TaskStub, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - remote_sim_file: PathLike = SIM_FILE_HDF5_GZ, - ) -> None: - """Upload :class:`.Simulation` object to Server. - - Parameters - ---------- - stub: :class:`TaskStub` - An instance of TaskStub. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while uploading the data. 
- """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - if not stub: - raise WebError("Expected field 'simulation' is unset.") - # Also upload hdf5.gz containing all data. - file, file_name = tempfile.mkstemp() - os.close(file) - try: - # upload simulation - # compress .hdf5 to .hdf5.gz - stub.to_hdf5_gz(file_name) - upload_file( - self.task_id, - file_name, - remote_sim_file, - verbose=verbose, - progress_callback=progress_callback, - ) - finally: - os.unlink(file_name) - - def upload_file( - self, - local_file: PathLike, - remote_filename: str, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> None: - """ - Upload file to platform. Using this method when the json file is too large to parse - as :class".simulation". - Parameters - ---------- - local_file: PathLike - Local file path. - remote_filename: str - file name on the server - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while uploading the data. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - upload_file( - self.task_id, - local_file, - remote_filename, - verbose=verbose, - progress_callback=progress_callback, - ) - - def submit( - self, - solver_version: Optional[str] = None, - worker_group: Optional[str] = None, - pay_type: Union[PayType, str] = PayType.AUTO, - priority: Optional[int] = None, - ) -> None: - """Kick off this task. - - It will be uploaded to server before - starting the task. Otherwise, this method assumes that the Simulation has been uploaded by - the upload_file function, so the task will be kicked off directly. - - Parameters - ---------- - solver_version: str = None - target solver version. - worker_group: str = None - worker group - pay_type: Union[PayType, str] = PayType.AUTO - Which method to pay the simulation. 
- priority: int = None - Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest). - It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits. - """ - pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type - - if solver_version: - protocol_version = None - else: - protocol_version = http_util.get_version() - - http.post( - f"tidy3d/tasks/{self.task_id}/submit", - { - "solverVersion": solver_version, - "workerGroup": worker_group, - "protocolVersion": protocol_version, - "enableCaching": config.web.enable_caching, - "payType": pay_type.value, - "priority": priority, - }, - ) - - def estimate_cost(self, solver_version: Optional[str] = None) -> float: - """Compute the maximum flex unit charge for a given task, assuming the simulation runs for - the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately. - - Parameters - ---------- - solver_version: str - target solver version. - - Returns - ------- - flex_unit_cost: float - estimated cost in FlexCredits - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - if solver_version: - protocol_version = None - else: - protocol_version = http_util.get_version() - - resp = http.post( - f"tidy3d/tasks/{self.task_id}/metadata", - { - "solverVersion": solver_version, - "protocolVersion": protocol_version, - }, - ) - return resp - - def get_simulation_hdf5( - self, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - remote_sim_file: PathLike = SIM_FILE_HDF5_GZ, - ) -> pathlib.Path: - """Get simulation.hdf5 file from Server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while downloading the data. 
- - Returns - ------- - path: pathlib.Path - Path to saved file. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - return download_gz_file( - resource_id=self.task_id, - remote_filename=remote_sim_file, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - - def get_running_info(self) -> tuple[float, float]: - """Gets the % done and field_decay for a running task. - - Returns - ------- - perc_done : float - Percentage of run done (in terms of max number of time steps). - Is ``None`` if run info not available. - field_decay : float - Average field intensity normalized to max value (1.0). - Is ``None`` if run info not available. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - resp = http.get(f"tidy3d/tasks/{self.task_id}/progress") - perc_done = resp.get("perc_done") - field_decay = resp.get("field_decay") - return perc_done, field_decay - - def get_log( - self, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Get log file from Server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while downloading the data. - - Returns - ------- - path: pathlib.Path - Path to saved file. - """ - - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - return download_file( - self.task_id, - SIM_LOG_FILE, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - - def get_error_json( - self, to_file: PathLike, verbose: bool = True, validation: bool = False - ) -> pathlib.Path: - """Get error json file for a :class:`.Simulation` from server. 
- - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - validation: bool = False - Whether to get a validation error file or a solver error file. - - Returns - ------- - path: pathlib.Path - Path to saved file. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE - - return download_file( - self.task_id, - target_file, - to_file=target_path, - verbose=verbose, - ) - - def abort(self) -> requests.Response: - """Abort the current task on the server.""" - if not self.task_id: - raise ValueError("Task id not found.") - return http.put( - "tidy3d/tasks/abort", json={"taskType": self.task_type, "taskId": self.task_id} - ) - - def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None: - """Perform checks after task is uploaded and metadata is processed.""" - if self.task_type == "HEAT_CHARGE" and parent_tasks: - try: - if len(parent_tasks) > 1: - raise ValueError( - "A single parent 'task_id' corresponding to the task in which the meshing " - "was run must be provided." - ) - try: - # get mesh task info - mesh_task = SimulationTask.get(parent_tasks[0], verbose=False) - assert mesh_task.task_type == "VOLUME_MESH" - assert mesh_task.status == "success" - # get up-to-date task info - task = SimulationTask.get(self.task_id, verbose=False) - if task.fileMd5 != mesh_task.childFileMd5: - raise ValidationError( - "Simulation stored in parent task 'VolumeMesher' does not match the " - "current simulation." - ) - except Exception as e: - raise ValidationError( - "The parent task must be a 'VolumeMesher' task which has been successfully " - "run and is associated to the same 'HeatChargeSimulation' as provided here." 
- ) from e - - except Exception as e: - raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e - - -class BatchTask(WebTask): - """Interface for managing a batch task on the server.""" - - task_type: Optional[str] = Field( - None, title="task_type", description="The type of task.", alias="taskType" - ) - - @classmethod - def get(cls, task_id: str, verbose: bool = True) -> BatchTask: - """Get batch task by id. - - Parameters - ---------- - task_id: str - Unique identifier of batch on server. - verbose: - If `True`, will print progressbars and status, otherwise, will run silently. - - Returns - ------- - :class:`.BatchTask` | None - BatchTask object if found, otherwise None. - """ - try: - resp = http.get(f"rf/task/{task_id}/statistics") - except WebNotFoundError as e: - td.log.error(f"The requested batch ID '{task_id}' does not exist.") - raise e - # Extract taskType from response if available - if resp: - task_type = resp.get("taskType") if isinstance(resp, dict) else None - return BatchTask(taskId=task_id, taskType=task_type) - return None - - def detail(self) -> BatchDetail: - """Fetches the detailed information and status of the batch. - - Returns - ------- - BatchDetail - An object containing the batch's latest data. - """ - resp = http.get( - f"rf/task/{self.task_id}/statistics", - ) - # Some backends may return null for collection fields; coerce to sensible defaults - if isinstance(resp, dict): - if resp.get("tasks") is None: - resp["tasks"] = [] - return BatchDetail(**(resp or {})) - - def check( - self, - check_task_type: str, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - ) -> requests.Response: - """Submits a request to validate the batch configuration on the server. - - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for validation. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. 
- - Returns - ------- - Any - The server's response to the check request. - """ - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"rf/task/{self.task_id}/check", - { - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "taskType": check_task_type, - }, - ) - - def submit( - self, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - worker_group: Optional[str] = None, - pay_type: Union[PayType, str] = PayType.AUTO, - priority: Optional[int] = None, - ) -> requests.Response: - """Submits the batch for execution on the server. - - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for execution. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. - worker_group : Optional[str], default=None - Optional identifier for a specific worker group to run on. - - Returns - ------- - Any - The server's response to the submit request. - """ - - # TODO: add support for pay_type and priority arguments - if pay_type != PayType.AUTO: - raise NotImplementedError( - "The 'pay_type' argument is not yet supported and will be ignored." - ) - if priority is not None: - raise NotImplementedError( - "The 'priority' argument is not yet supported and will be ignored." 
- ) - - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"rf/task/{self.task_id}/submit", - { - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "workerGroup": worker_group, - }, - ) - - def abort(self) -> requests.Response: - """Abort the current task on the server.""" - if not self.task_id: - raise ValueError("Batch id not found.") - return http.put(f"rf/task/{self.task_id}/abort", {}) - - -class TaskFactory: - """Factory for obtaining the correct task subclass.""" - - _REGISTRY: dict[str, type[WebTask]] = {} - - @classmethod - def reset(cls) -> None: - """Clear the cached task kind registry (used in tests).""" - cls._REGISTRY.clear() - - @classmethod - def register(cls, task_id: str, kind: type[WebTask]) -> None: - cls._REGISTRY[task_id] = kind - - @classmethod - def get_kind(cls, task_id: str, verbose: bool = True) -> type[WebTask]: - """Return cached task class, fetching and caching if needed.""" - kind = cls._REGISTRY.get(task_id) - if kind: - return kind - if WebTask.is_batch(task_id): - cls.register(task_id, BatchTask) - return BatchTask - task = SimulationTask.get(task_id, verbose=verbose) - if task: - cls.register(task_id, SimulationTask) - return SimulationTask - - @classmethod - def get(cls, task_id: str, verbose: bool = True) -> WebTask: - kind = cls._REGISTRY.get(task_id) - if kind is BatchTask: - return BatchTask.get(task_id, verbose=verbose) - if kind is SimulationTask: - task = SimulationTask.get(task_id, verbose=verbose) - return task - if WebTask.is_batch(task_id): - cls.register(task_id, BatchTask) - return BatchTask.get(task_id, verbose=verbose) - task = SimulationTask.get(task_id, verbose=verbose) - if task: - cls.register(task_id, SimulationTask) - return task diff --git a/tidy3d/web/core/task_info.py b/tidy3d/web/core/task_info.py index c42ba0f220..f902fc1bcd 100644 --- a/tidy3d/web/core/task_info.py +++ b/tidy3d/web/core/task_info.py @@ -1,328 +1,18 @@ -"""Defines 
information about a task""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.task_info`.""" -from __future__ import annotations - -from abc import ABC -from datetime import datetime -from enum import Enum -from typing import Annotated, Optional - -from pydantic import BaseModel, ConfigDict, Field - - -class TaskBase(BaseModel, ABC): - """Base configuration for all task objects.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class ChargeType(str, Enum): - """The payment method of the task.""" - - FREE = "free" - """No payment required.""" - - PAID = "paid" - """Payment required.""" - - -class TaskBlockInfo(TaskBase): - """Information about the task's block status. - - Notes - ----- - This includes details about how the task can be blocked by various features - such as user limits and insufficient balance. - """ - - chargeType: Optional[ChargeType] = None - """The type of charge applicable to the task (free or paid).""" - - maxFreeCount: Optional[int] = None - """The maximum number of free tasks allowed.""" - - maxGridPoints: Optional[int] = None - """The maximum number of grid points permitted.""" - - maxTimeSteps: Optional[int] = None - """The maximum number of time steps allowed.""" - - -class TaskInfo(TaskBase): - """General information about a task.""" - - taskId: str - """Unique identifier for the task.""" - - taskName: Optional[str] = None - """Name of the task.""" - - nodeSize: Optional[int] = None - """Size of the node allocated for the task.""" - - completedAt: Optional[datetime] = None - """Timestamp when the task was completed.""" - - status: Optional[str] = None - """Current status of the task.""" - - realCost: Optional[float] = None - """Actual cost incurred by the task.""" - - timeSteps: Optional[int] = None - """Number of time steps involved in the task.""" - - solverVersion: Optional[str] = None - """Version of the solver used for the task.""" - - createAt: Optional[datetime] = None - """Timestamp when the task was 
created.""" - - estCostMin: Optional[float] = None - """Estimated minimum cost for the task.""" - - estCostMax: Optional[float] = None - """Estimated maximum cost for the task.""" - - realFlexUnit: Optional[float] = None - """Actual flexible units used by the task.""" - - oriRealFlexUnit: Optional[float] = None - """Original real flexible units.""" - - estFlexUnit: Optional[float] = None - """Estimated flexible units for the task.""" - - estFlexCreditTimeStepping: Optional[float] = None - """Estimated flexible credits for time stepping.""" - - estFlexCreditPostProcess: Optional[float] = None - """Estimated flexible credits for post-processing.""" - - estFlexCreditMode: Optional[float] = None - """Estimated flexible credits based on the mode.""" - - s3Storage: Optional[float] = None - """Amount of S3 storage used by the task.""" - - startSolverTime: Optional[datetime] = None - """Timestamp when the solver started.""" - - finishSolverTime: Optional[datetime] = None - """Timestamp when the solver finished.""" - - totalSolverTime: Optional[int] = None - """Total time taken by the solver.""" - - callbackUrl: Optional[str] = None - """Callback URL for task notifications.""" - - taskType: Optional[str] = None - """Type of the task.""" - - metadataStatus: Optional[str] = None - """Status of the metadata for the task.""" - - taskBlockInfo: Optional[TaskBlockInfo] = None - """Blocking information for the task.""" - - version: Optional[str] = None - """Version of the task.""" - - -class RunInfo(TaskBase): - """Information about the run of a task.""" - - perc_done: Annotated[float, Field(ge=0.0, le=100.0)] - """Percentage of the task that is completed (0 to 100).""" - - field_decay: Annotated[float, Field(ge=0.0, le=1.0)] - """Field decay from the maximum value (0 to 1).""" - - def display(self) -> None: - """Print some info about the task's progress.""" - print(f" - {self.perc_done:.2f} (%) done") - print(f" - {self.field_decay:.2e} field decay from max") - - -# 
---------------------- Batch (Modeler) detail schema ---------------------- # - - -class BatchTaskBlockInfo(TaskBlockInfo): - """ - Extends `TaskBlockInfo` with specific details for batch task blocking. - - Attributes: - accountLimit: A usage or cost limit imposed by the user's account. - taskBlockMsg: A human-readable message describing the reason for the block. - taskBlockType: The specific type of block (e.g., 'balance', 'limit'). - blockStatus: The current blocking status for the batch. - taskStatus: The status of the task when it was blocked. - """ - - accountLimit: Optional[float] = None - taskBlockMsg: Optional[str] = None - taskBlockType: Optional[str] = None - blockStatus: Optional[str] = None - taskStatus: Optional[str] = None - - -class BatchMember(TaskBase): - """ - Represents a single task within a larger batch operation. - - Attributes: - refId: A reference identifier for the member task. - folderId: The identifier of the folder containing the task. - sweepId: The identifier for the parameter sweep, if applicable. - taskId: The unique identifier of the task. - linkedTaskId: The identifier of a task linked to this one. - groupId: The identifier of the group this task belongs to. - taskName: The name of the individual task. - status: The current status of this specific task. - sweepData: Data associated with a parameter sweep. - validateInfo: Information related to the task's validation. - replaceData: Data used for replacements or modifications. - protocolVersion: The version of the protocol used. - variable: The variable parameter for this task in a sweep. - createdAt: The timestamp when the member task was created. - updatedAt: The timestamp when the member task was last updated. - denormalizeStatus: The status of the data denormalization process. - summary: A dictionary containing summary information for the task. 
- """ - - refId: Optional[str] = None - folderId: Optional[str] = None - sweepId: Optional[str] = None - taskId: Optional[str] = None - linkedTaskId: Optional[str] = None - groupId: Optional[str] = None - taskName: Optional[str] = None - status: Optional[str] = None - sweepData: Optional[str] = None - validateInfo: Optional[str] = None - replaceData: Optional[str] = None - protocolVersion: Optional[str] = None - variable: Optional[str] = None - createdAt: Optional[datetime] = None - updatedAt: Optional[datetime] = None - denormalizeStatus: Optional[str] = None - summary: Optional[dict] = None - - -class BatchDetail(TaskBase): - """Provides a detailed, top-level view of a batch of tasks. - - Notes - ----- - This model serves as the main payload for retrieving comprehensive - information about a batch operation. - - Attributes - ---------- - refId - A reference identifier for the entire batch. - optimizationId - Identifier for the optimization process, if any. - groupId - Identifier for the group the batch belongs to. - name - The user-defined name of the batch. - status - The current status of the batch. - totalTask - The total number of tasks in the batch. - preprocessSuccess - The count of tasks that completed preprocessing. - postprocessStatus - The status of the batch's postprocessing stage. - validateSuccess - The count of tasks that passed validation. - runSuccess - The count of tasks that ran successfully. - postprocessSuccess - The count of tasks that completed postprocessing. - taskBlockInfo - Information on what might be blocking the batch. - estFlexUnit - The estimated total flexible compute units for the batch. - totalSeconds - The total time in seconds the batch has taken. - totalCheckMillis - Total time in milliseconds spent on checks. - message - A general message providing information about the batch status. - tasks - A list of `BatchMember` objects, one for each task in the batch. - taskType - The type of tasks contained in the batch. 
- """ - - refId: Optional[str] = None - optimizationId: Optional[str] = None - groupId: Optional[str] = None - name: Optional[str] = None - status: Optional[str] = None - totalTask: int = 0 - preprocessSuccess: int = 0 - postprocessStatus: Optional[str] = None - validateSuccess: int = 0 - runSuccess: int = 0 - postprocessSuccess: int = 0 - taskBlockInfo: Optional[BatchTaskBlockInfo] = None - estFlexUnit: Optional[float] = None - realFlexUnit: Optional[float] = None - totalSeconds: Optional[int] = None - totalCheckMillis: Optional[int] = None - message: Optional[str] = None - tasks: list[BatchMember] = [] - validateErrors: Optional[dict] = None - taskType: str = None - version: Optional[str] = None - - -class AsyncJobDetail(TaskBase): - """Provides a detailed view of an asynchronous job and its sub-tasks. - - Notes - ----- - This model represents a long-running operation. The 'result' attribute holds - the output of a completed job, which for orchestration jobs, is often a - JSON string mapping sub-task names to their unique IDs. - - Attributes - ---------- - asyncId - The unique identifier for the asynchronous job. - status - The current overall status of the job (e.g., 'RUNNING', 'COMPLETED'). - progress - The completion percentage of the job (from 0.0 to 100.0). - createdAt - The timestamp when the job was created. - completedAt - The timestamp when the job finished (successfully or not). - tasks - A dictionary mapping logical task keys to their unique task IDs. - This is often populated by parsing the 'result' of an orchestration task. - result - The raw string output of the completed job. If the job spawns other - tasks, this is expected to be a JSON string detailing those tasks. - taskBlockInfo - Information on any dependencies blocking the job from running. - message - A human-readable message about the job's status. 
- """ - - asyncId: str - status: str - progress: Optional[float] = None - createdAt: Optional[datetime] = None - completedAt: Optional[datetime] = None - tasks: Optional[dict[str, str]] = None - result: Optional[str] = None - taskBlockInfo: Optional[TaskBlockInfo] = None - message: Optional[str] = None +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -AsyncJobDetail.model_rebuild() +from tidy3d._common.web.core.task_info import ( + AsyncJobDetail, + BatchDetail, + BatchMember, + BatchTaskBlockInfo, + ChargeType, + RunInfo, + TaskBase, + TaskBlockInfo, + TaskInfo, +) diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index aaac18612a..51437dbbdf 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -1,73 +1,15 @@ -"""Tidy3d abstraction types for the core.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.types`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any - -from pydantic import BaseModel - - -class Tidy3DResource(BaseModel, ABC): - """Abstract base class / template for a webservice that implements resource query.""" - - @classmethod - @abstractmethod - def get(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: - """Get a resource from the server.""" - - -class ResourceLifecycle(Tidy3DResource, ABC): - """Abstract base class for a webservice that implements resource life cycle management.""" - - @classmethod - @abstractmethod - def create(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: - """Create a new resource and return it.""" - - @abstractmethod - def delete(self, *args: Any, **kwargs: Any) -> None: - """Delete the resource.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class Submittable(BaseModel, ABC): - """Abstract base class / template for a 
webservice that implements a submit method.""" - - @abstractmethod - def submit(self, *args: Any, **kwargs: Any) -> None: - """Submit the task to the webservice.""" - - -class Queryable(BaseModel, ABC): - """Abstract base class / template for a webservice that implements a query method.""" - - @classmethod - @abstractmethod - def list(cls, *args: Any, **kwargs: Any) -> list[Queryable]: - """List all resources of this type.""" - - -class TaskType(str, Enum): - FDTD = "FDTD" - MODE_SOLVER = "MODE_SOLVER" - HEAT = "HEAT" - HEAT_CHARGE = "HEAT_CHARGE" - EME = "EME" - MODE = "MODE" - VOLUME_MESH = "VOLUME_MESH" - MODAL_CM = "MODAL_CM" - TERMINAL_CM = "TERMINAL_CM" - - -class PayType(str, Enum): - CREDITS = "FLEX_CREDIT" - AUTO = "AUTO" - - @classmethod - def _missing_(cls, value: object) -> PayType: - if isinstance(value, str): - key = value.strip().replace(" ", "_").upper() - if key in cls.__members__: - return cls.__members__[key] - return super()._missing_(value) +from tidy3d._common.web.core.types import ( + PayType, + Queryable, + ResourceLifecycle, + Submittable, + TaskType, + Tidy3DResource, +)