Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
[submodule "docs/notebooks"]
path = docs/notebooks
url = git@github.com:flexcompute/tidy3d-notebooks.git
[submodule "docs/faq"]
path = docs/faq
url = https://github.com/flexcompute/tidy3d-faq
114 changes: 114 additions & 0 deletions scripts/ensure_imports_from_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Ensure tidy3d._common modules avoid importing from tidy3d outside tidy3d._common."""

from __future__ import annotations

import argparse
import ast
import sys
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ImportViolation:
file: str
line: int
statement: str


def parse_args(argv: Iterable[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Ensure tidy3d._common does not import from tidy3d modules outside tidy3d._common."
)
)
parser.add_argument(
"--root",
default="tidy3d/_common",
help="Root directory to scan (relative to repo root).",
)
return parser.parse_args(argv)


def main(argv: Iterable[str]) -> None:
args = parse_args(argv)
repo_root = Path.cwd().resolve()
root = (repo_root / args.root).resolve()
if not root.exists():
print(f"No directory found at {root}. Skipping check.")
return

violations: list[ImportViolation] = []
for path in sorted(root.rglob("*.py")):
violations.extend(_violations_in_file(path, repo_root))

if violations:
print("Invalid tidy3d imports found in tidy3d._common:")
for violation in violations:
print(f"{violation.file}:{violation.line}: {violation.statement}")
raise SystemExit(1)

print("No invalid tidy3d imports found in tidy3d._common.")


def _violations_in_file(path: Path, repo_root: Path) -> list[ImportViolation]:
source = path.read_text(encoding="utf-8")
try:
tree = ast.parse(source)
except SyntaxError as exc:
raise SystemExit(f"Syntax error parsing {path}: {exc}") from exc

rel_path = str(path.relative_to(repo_root))
violations: list[ImportViolation] = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.name
if name == "tidy3d" or (
name.startswith("tidy3d.") and not name.startswith("tidy3d._common")
):
violations.append(
ImportViolation(
file=rel_path,
line=node.lineno,
statement=_statement(source, node),
)
)
elif isinstance(node, ast.ImportFrom):
if node.level:
continue
module = node.module
if not module:
continue
if module == "tidy3d":
for alias in node.names:
if alias.name != "_common":
violations.append(
ImportViolation(
file=rel_path,
line=node.lineno,
statement=_statement(source, node),
)
)
continue
if module.startswith("tidy3d.") and not module.startswith("tidy3d._common"):
violations.append(
ImportViolation(
file=rel_path,
line=node.lineno,
statement=_statement(source, node),
)
)
return violations


def _statement(source: str, node: ast.AST) -> str:
segment = ast.get_source_segment(source, node)
if segment:
return " ".join(segment.strip().splitlines())
return node.__class__.__name__


if __name__ == "__main__":
main(sys.argv[1:])
4 changes: 1 addition & 3 deletions tests/config/test_legacy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,9 @@ def test_env_vars_follow_profile_switch(

def test_web_core_environment_reexports():
"""Legacy `tidy3d.web.core.environment` exports remain available via config shim."""

import tidy3d.web as web
from tidy3d._common.web.core import environment
from tidy3d.config import Env as ConfigEnv

environment = web.core.environment
assert environment.Env is ConfigEnv

with warnings.catch_warnings(record=True) as caught:
Expand Down
2 changes: 1 addition & 1 deletion tests/config/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from click.testing import CliRunner
from pydantic import Field

from tidy3d._common.config import loader as config_loader # import from common as it is patched
from tidy3d.config import get_manager, reload_config
from tidy3d.config import loader as config_loader
from tidy3d.config import registry as config_registry
from tidy3d.config.legacy import finalize_legacy_migration
from tidy3d.config.loader import migrate_legacy_config
Expand Down
3 changes: 2 additions & 1 deletion tests/test_components/test_IO.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

import tidy3d as td
from tidy3d import __version__
from tidy3d.components.base import DATA_ARRAY_MAP, Tidy3dBaseModel
from tidy3d.components.base import Tidy3dBaseModel
from tidy3d.components.data.data_array import DATA_ARRAY_MAP
from tidy3d.components.data.sim_data import DATA_TYPE_MAP

from ..test_data.test_monitor_data import make_flux_data
Expand Down
4 changes: 3 additions & 1 deletion tests/test_components/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def plot_with_multi_viz_spec(alphas, facecolors, edgecolors, rng, use_viz_spec=T

def test_no_matlab_install(monkeypatch):
"""Test that the `VisualizationSpec` only throws a warning on validation if matplotlib is not installed."""
monkeypatch.setattr("tidy3d.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False)
monkeypatch.setattr(
"tidy3d._common.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False
)

EXPECTED_WARNING_MSG_PIECE = (
"matplotlib was not successfully imported, but is required to validate colors"
Expand Down
22 changes: 10 additions & 12 deletions tests/test_web/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
@pytest.fixture(autouse=True)
def _isolate_local_cache(tmp_path, monkeypatch):
"""Keep cache operations in a temp dir and avoid moving/deleting real cache."""
import tidy3d.web.cache as cache_mod
import tidy3d._common.web.cache as cache_mod
from tidy3d.config import get_manager

real_remove_cache_dir = cache_mod._remove_cache_dir
Expand Down Expand Up @@ -611,13 +611,13 @@ def test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulati

file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
file1.write_text("a")
cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation)
assert len(cache) == 1

sim2 = basic_simulation.updated_copy(shutoff=1e-4)
file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
file2.write_text("b")
cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")
cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2)

entries = cache.list()
assert len(entries) == 1
Expand All @@ -631,13 +631,13 @@ def test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation)

file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
file1.write_text("a" * 8_000)
cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation)
assert len(cache) == 1

sim2 = basic_simulation.updated_copy(shutoff=1e-4)
file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
file2.write_text("b" * 8_000)
cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")
cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2)

entries = cache.list()
assert len(cache) == 1
Expand All @@ -653,7 +653,7 @@ def test_cache_stats_tracking(monkeypatch, tmp_path_factory, basic_simulation):
payload = "stats-payload"
artifact.write_text(payload)

cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD")
cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation)

stats_path = cache.root / CACHE_STATS_NAME
assert stats_path.exists()
Expand Down Expand Up @@ -692,12 +692,12 @@ def test_cache_stats_sync(monkeypatch, tmp_path_factory, basic_simulation):
artifact1 = tmp_path_factory.mktemp("artifact_sync1") / CACHE_ARTIFACT_NAME
payload1 = "sync-one"
artifact1.write_text(payload1)
cache.store_result(_FakeStubData(sim1), f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD")
cache.store_result(f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD", simulation=sim1)

artifact2 = tmp_path_factory.mktemp("artifact_sync2") / CACHE_ARTIFACT_NAME
payload2 = "sync-two"
artifact2.write_text(payload2)
cache.store_result(_FakeStubData(sim2), f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD")
cache.store_result(f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD", simulation=sim2)

stats_path = cache.root / CACHE_STATS_NAME
assert stats_path.exists()
Expand Down Expand Up @@ -734,7 +734,7 @@ def _counting_iter():
artifact = tmp_path / "iter_guard.hdf5"
artifact.write_text("payload")

cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD")
cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation)
assert iter_calls["count"] == 0

entry_dirs = []
Expand Down Expand Up @@ -809,9 +809,7 @@ def test_cache_cli_commands(monkeypatch, tmp_path_factory, basic_simulation, tmp

artifact = artifact_dir / CACHE_ARTIFACT_NAME
artifact.write_text("payload_cli")
cache.store_result(
_FakeStubData(basic_simulation), f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD"
)
cache.store_result(f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD", simulation=basic_simulation)

info_result = runner.invoke(tidy3d_cli, ["cache", "info"])
assert info_result.exit_code == 0
Expand Down
22 changes: 12 additions & 10 deletions tests/test_web/test_s3utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import pytest

import tidy3d
from tidy3d._common.web.core import s3utils as s3utils_common
from tidy3d.web.core import s3utils

s3_utils_path = "tidy3d._common.web.core.s3utils"


@pytest.fixture
def mock_S3STSToken(monkeypatch):
Expand All @@ -16,20 +18,20 @@ def mock_S3STSToken(monkeypatch):
mock_token.get_bucket = lambda: ""
mock_token.get_s3_key = lambda: ""
mock_token.is_expired = lambda: False
mock_token.get_client = lambda: tidy3d.web.core.s3utils.boto3.client()
mock_token.get_client = lambda: s3utils_common.boto3.client()
monkeypatch.setattr(
target=tidy3d.web.core.s3utils, name="_S3STSToken", value=MagicMock(return_value=mock_token)
target=s3utils_common, name="_S3STSToken", value=MagicMock(return_value=mock_token)
)
return mock_token


@pytest.fixture
def mock_get_s3_sts_token(monkeypatch):
def _mock_get_s3_sts_token(resource_id, remote_filename):
return s3utils._S3STSToken(resource_id, remote_filename)
return s3utils_common._S3STSToken(resource_id, remote_filename)

monkeypatch.setattr(
target=tidy3d.web.core.s3utils, name="get_s3_sts_token", value=_mock_get_s3_sts_token
target=s3utils_common, name="get_s3_sts_token", value=_mock_get_s3_sts_token
)
return _mock_get_s3_sts_token

Expand All @@ -44,7 +46,7 @@ def mock_s3_client(monkeypatch):
# Patch the `client` as it is imported within `tidy3d.web.core.s3utils.boto3` so that
# whenever it's invoked (for example with "s3"), it returns our `mock_client`.
monkeypatch.setattr(
target=tidy3d.web.core.s3utils.boto3,
target=s3utils_common.boto3,
name="client",
value=MagicMock(return_value=mock_client),
)
Expand Down Expand Up @@ -148,11 +150,11 @@ def test_s3_token_get_client_with_custom_endpoint(tmp_path, monkeypatch):

# Mock boto3.client
mock_boto_client = MagicMock()
monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client)
monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client)

# Test 1: Without custom endpoint - use fresh config
test_config = ConfigManager(config_dir=tmp_path)
monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config)
monkeypatch.setattr(f"{s3_utils_path}.config", test_config)
token.get_client()

# Verify boto3.client was called without endpoint_url
Expand Down Expand Up @@ -195,12 +197,12 @@ def test_s3_token_get_client_respects_ssl_verify(tmp_path, monkeypatch):
token = _S3STSToken(**token_data)

mock_boto_client = MagicMock()
monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client)
monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client)

# Use fresh config with ssl_verify=False
test_config = ConfigManager(config_dir=tmp_path)
test_config.update_section("web", ssl_verify=False)
monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config)
monkeypatch.setattr(f"{s3_utils_path}.config", test_config)

token.get_client()

Expand Down
9 changes: 5 additions & 4 deletions tests/test_web/test_tidy3d_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from responses import matchers

import tidy3d as td
from tests.test_web.test_webapi import task_core_path
from tidy3d import config
from tidy3d.web.core import http_util
from tidy3d.web.core.environment import Env
Expand All @@ -31,7 +32,7 @@ def make_sim():
@pytest.fixture
def set_api_key(monkeypatch):
"""Set the api key."""
import tidy3d.web.core.http_util as httputil
import tidy3d._common.web.core.http_util as httputil

monkeypatch.setattr(httputil, "api_key", lambda: "apikey")
monkeypatch.setattr(httputil, "get_version", lambda: td.version.__version__)
Expand Down Expand Up @@ -85,7 +86,7 @@ def mock_download(*args, **kwargs):
to_file = kwargs["to_file"]
sim.to_file(to_file)

monkeypatch.setattr("tidy3d.web.core.task_core.download_gz_file", mock_download)
monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download)

responses.add(
responses.GET,
Expand Down Expand Up @@ -121,7 +122,7 @@ def test_upload(monkeypatch, set_api_key):
def mock_download(*args, **kwargs):
pass

monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_download)
monkeypatch.setattr(f"{task_core_path}.upload_file", mock_download)
task = SimulationTask.get("3eb06d16-208b-487b-864b-e9b1d3e010a7")
with tempfile.NamedTemporaryFile() as temp:
task.upload_file(temp.name, "temp.json")
Expand Down Expand Up @@ -353,7 +354,7 @@ def mock(*args, **kwargs):
with open(file_path, "w") as f:
f.write("0.3,5.7")

monkeypatch.setattr("tidy3d.web.core.task_core.download_file", mock)
monkeypatch.setattr(f"{task_core_path}.download_file", mock)
responses.add(
responses.GET,
f"{Env.current.web_api_endpoint}/tidy3d/tasks/3eb06d16-208b-487b-864b-e9b1d3e010a7/detail",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
common.CONNECTION_RETRY_TIME = 0.1
INVALID_TASK_ID = "INVALID_TASK_ID"

task_core_path = "tidy3d.web.core.task_core"
task_core_path = "tidy3d._common.web.core.task_core"
api_path = "tidy3d.web.api.webapi"

config.switch_profile("dev")
Expand Down Expand Up @@ -204,7 +204,7 @@ def mock_upload(monkeypatch, set_api_key):
def mock_upload_file(*args, **kwargs):
pass

monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file)


@pytest.fixture
Expand Down
1 change: 0 additions & 1 deletion tests/test_web/test_webapi_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from tidy3d.web.core.environment import Env

task_core_path = "tidy3d.web.core.task_core"
api_path = "tidy3d.web.api.webapi"

config.switch_profile("dev")
Expand Down
Loading
Loading