5 changes: 0 additions & 5 deletions .flake8

This file was deleted.

53 changes: 14 additions & 39 deletions .github/workflows/python-package.yml
@@ -19,7 +19,7 @@ jobs:
         python-version: ["3.10", "3.11", "3.12", "3.13"]
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
        with:
          submodules: 'recursive'
 
@@ -30,46 +30,21 @@ jobs:
 
      - run: python --version
 
-      # Cache the installation of Poetry itself, e.g. the next step. This prevents the workflow
-      # from installing Poetry every time, which can be slow. Note the use of the Poetry version
-      # number in the cache key, and the "-0" suffix: this allows you to invalidate the cache
-      # manually if/when you want to upgrade Poetry, or if something goes wrong.
-      - name: Cache poetry
-        uses: actions/cache@v4.2.0
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6.1.0
        with:
-          path: ~/.local
-          key: poetry-2.0.1-0
+          version: 0.7.8
+          checksum: "285981409c746508c1fd125f66a1ea654e487bf1e4d9f45371a062338f788adb"
+          enable-cache: true
 
-      - name: Install poetry
-        uses: snok/install-poetry@v1
-        with:
-          version: 2.0.1
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-
-      - run: poetry --version
-      - run: poetry run python --version
-
-      # Cache your dependencies (i.e. all the stuff in your `pyproject.toml`). Note the cache
-      # key: if you're using multiple Python versions, or multiple OSes, you'd need to include
-      # them in the cache key. I'm not, so it can be simple and just depend on the poetry.lock.
-      - name: Cache dependencies
-        id: cache-deps
-        uses: actions/cache@v4.2.0
-        with:
-          path: .venv
-          key: pydeps-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
-
-      - name: Install dependencies
-        # The `if` statement ensures this only runs on a cache miss.
-        run: poetry install --no-interaction
-        if: steps.cache-deps.outputs.cache-hit != 'true'
+      - run: uv --version
+      - run: uv run python --version
 
-      - name: Check code formatting with black
-        run: poetry run black . --check
+      - name: Check code linting
+        run: uv run --extra dev ruff check --no-fix
 
-      - name: Lint with flake8
-        run: poetry run flake8 . --count --show-source
+      - name: Check code formatting
+        run: uv run --extra dev ruff format --check
 
-      - name: Test with pytest
-        run: poetry run python${{ matrix.python-version }} -m pytest
+      - name: Run unit tests
+        run: uv run --extra dev python${{ matrix.python-version }} -m pytest
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,6 +1,6 @@
 __pycache__/
 .illustris_api_key.txt
-.venv/
+*.lock
 *.egg-info/
 /data/
 /output/
2 changes: 0 additions & 2 deletions devel/show_parquet.ipynb
@@ -148,8 +148,6 @@
    }
   ],
   "source": [
-    "import pandas as pd\n",
-    "\n",
     "df = dataset.to_table().to_pandas()\n",
     "df"
   ]
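Aside: dropping `import pandas as pd` here is safe because `Table.to_pandas()` imports pandas internally. A minimal sketch of how such a `dataset` handle is typically built with `pyarrow.dataset` (the path is made up for illustration):

    import pyarrow.dataset as ds

    # Hypothetical directory holding the parquet parts written by the converters.
    dataset = ds.dataset("output/", format="parquet")

    # to_pandas() pulls in pandas itself, so the notebook needs no explicit import.
    df = dataset.to_table().to_pandas()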
2 changes: 1 addition & 1 deletion devel/test_chunks.ipynb
@@ -706,7 +706,7 @@
    }
   ],
   "source": [
-    "pq.read_table('part-0.parquet').to_pandas()"
+    "pq.read_table(\"part-0.parquet\").to_pandas()"
   ]
  }
 ],
11 changes: 8 additions & 3 deletions pyproject.toml
@@ -21,16 +21,21 @@ dependencies = [
     "requests>=2.32",
     "scikit-image>=0.25",
     "scipy>=1.15",
-    "pytest>=8.3",
 ]
 
 [project.optional-dependencies]
 dev = [
-    "black>=25.1",
-    "flake8>=7.0",
     "ipykernel>=6.29",
+    "pytest>=8.3",
+    "ruff>=0.11.11",
 ]
 
 [build-system]
 requires = ["setuptools >= 61.0"]
 build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint.mccabe]
+max-complexity = 10
8 changes: 2 additions & 6 deletions src/pest/fits_converter.py
@@ -83,9 +83,7 @@ def convert_all(
                     "data": [data.flatten()],
                     "simulation": splits[-5],
                     "snapshot": np.int32(splits[-3].split("_")[1].lstrip("0")),
-                    "subhalo_id": np.int32(
-                        splits[-1].split("_")[1].lstrip("0")
-                    ),
+                    "subhalo_id": np.int32(splits[-1].split("_")[1].lstrip("0")),
                 }
             )

@@ -102,9 +100,7 @@ def convert_all(
         table = pa.Table.from_pandas(df, schema=schema)
 
         # Add shape metadata to the schema
-        table = table.replace_schema_metadata(
-            metadata={"data_shape": str(data.shape)}
-        )
+        table = table.replace_schema_metadata(metadata={"data_shape": str(data.shape)})
 
         if writer is None:
             writer = pq.ParquetWriter(
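Since each row stores a flattened image, the `data_shape` metadata written above is what lets a reader restore the original dimensions. A minimal read-side sketch, assuming a file produced by this converter (the filename is made up; Arrow schema metadata keys and values are bytes):

    import ast

    import numpy as np
    import pyarrow.parquet as pq

    table = pq.read_table("converted.parquet")  # hypothetical output file
    # str(data.shape) round-trips through ast.literal_eval back to a tuple.
    shape = ast.literal_eval(table.schema.metadata[b"data_shape"].decode())
    # Rebuild the first image from its flattened representation.
    image = np.array(table.column("data")[0].as_py()).reshape(shape)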
16 changes: 4 additions & 12 deletions src/pest/gaia_converter.py
@@ -67,33 +67,25 @@ def convert(
             lambda x: np.fromstring(x[1:-1], dtype=np.float32, sep=",")
         )
 
-        calibrated_data, _ = calibrate(
-            continuous_data, sampling=self.sampling, save_file=False
-        )
+        calibrated_data, _ = calibrate(continuous_data, sampling=self.sampling, save_file=False)
 
         if self.with_flux_error:
             # Convert 'flux' column to float32
-            calibrated_data["flux_error"] = calibrated_data["flux_error"].apply(
-                lambda x: np.array(x, dtype=np.float32)
-            )
+            calibrated_data["flux_error"] = calibrated_data["flux_error"].apply(lambda x: np.array(x, dtype=np.float32))
         else:
             # Remove the 'flux_error' column from the calibrated data
             if "flux_error" in calibrated_data.columns:
                 calibrated_data.drop(columns=["flux_error"], inplace=True)
 
         # Convert 'flux' column to float32
-        calibrated_data["flux"] = calibrated_data["flux"].apply(
-            lambda x: np.array(x, dtype=np.float32)
-        )
+        calibrated_data["flux"] = calibrated_data["flux"].apply(lambda x: np.array(x, dtype=np.float32))
 
         # Use pyarrow to write the data to a parquet file
         table = pa.Table.from_pandas(calibrated_data)
 
         # Add shape metadata to the schema
         data_shape = f"(1, {len(calibrated_data['flux'][0])})"
-        table = table.replace_schema_metadata(
-            metadata={"flux_shape": data_shape, "flux_error_shape": data_shape}
-        )
+        table = table.replace_schema_metadata(metadata={"flux_shape": data_shape, "flux_error_shape": data_shape})
 
         parquet.write_table(
             table,
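For reference, the `np.fromstring(x[1:-1], ...)` call near the top of this hunk parses flux arrays stored as bracketed text. A toy round trip with a made-up value (note that text-mode `np.fromstring` is deprecated in recent NumPy, which may be worth a follow-up):

    import numpy as np

    raw = "[1.5,2.0,3.25]"  # made-up serialized flux array
    # Strip the brackets, then parse the comma-separated values as float32.
    flux = np.fromstring(raw[1:-1], dtype=np.float32, sep=",")
    # flux -> array([1.5 , 2.  , 3.25], dtype=float32)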
5 changes: 1 addition & 4 deletions src/pest/preprocessing.py
@@ -42,10 +42,7 @@ def __call__(self, images) -> np.ndarray:
 
         mean = np.mean(resulting_image, axis=0)
         resulting_image = (
-            resulting_image
-            * np.asinh(self.stretch * self.range * (mean - self.lower_limit))
-            / self.range
-            / mean
+            resulting_image * np.asinh(self.stretch * self.range * (mean - self.lower_limit)) / self.range / mean
        )
 
        resulting_image = np.nan_to_num(resulting_image, nan=0, posinf=0, neginf=0)
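The reflowed expression is an arcsinh stretch: each pixel is scaled by asinh(stretch * range * (mean - lower_limit)) / (range * mean), which compresses bright regions while preserving faint structure. A self-contained sketch with made-up parameter values:

    import numpy as np

    stretch, value_range, lower_limit = 0.1, 5.0, 0.0  # made-up parameters
    images = np.random.rand(3, 64, 64)  # e.g. three bands of one image stack

    # The per-pixel mean across bands drives the stretch, mirroring the code above.
    mean = np.mean(images, axis=0)
    stretched = images * np.asinh(stretch * value_range * (mean - lower_limit)) / value_range / mean
    stretched = np.nan_to_num(stretched, nan=0, posinf=0, neginf=0)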