diff --git a/docs/myst.yml b/docs/myst.yml index 6e8cf59..47adaef 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -30,6 +30,9 @@ project: - file: notebooks/reading_params.ipynb - file: notebooks/crossovers.ipynb - file: notebooks/bed_picks.ipynb + - file: notebooks/qc_demo.ipynb + - file: notebooks/nonstandard_data_products.ipynb + - file: notebooks/repicking.ipynb site: template: book-theme diff --git a/docs/notebooks/nonstandard_data_products.ipynb b/docs/notebooks/nonstandard_data_products.ipynb new file mode 100644 index 0000000..87172e2 --- /dev/null +++ b/docs/notebooks/nonstandard_data_products.ipynb @@ -0,0 +1,201 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1772a2d7", + "metadata": {}, + "source": [ + "---\n", + "title: Loading non-standard data products\n", + "description: How to use xOPR to load unlisted (non-standard) data products\n", + "date: 2026-04-23\n", + "---\n", + "\n", + ":::{warning}\n", + "This is an advanced feature and most users can probably ignore it. If you want to load non-standard data products, such as either a custom processing type or individual \"images\" (often corresponding to high and low gain channels), read on.\n", + ":::\n", + "\n", + "By default, xOPR uses a pre-generated STAC catalog of radar data segments and loads certain pre-defined data product types into it, assuming they are available. These include OPR data products such as `CSARP_qlook` (unfocused SAR) and `CSARP_standard` (focused SAR). If you want to work with another data product that is publicly available from the OPR servers but hasn't been indexed by the xOPR STAC catalog (either because you just created it or because it's a non-standard data product that we don't index), you can do this by setting the `allow_unlisted_products=True` flag when loading frames.\n", + "\n", + "This is a bit of a your-mileage-may-vary feature. xOPR makes a best effort and will probably load most non-standard data products just fine, but a non-standard data product with a completely different format will certainly break things.\n", + "\n", + "Feel free to reach out if you have a use case that we don't support (or put in a pull request!).\n", + "\n", + "Keep in mind that any data product must be available on the public side of the OPR data website. If you can't find your desired data product at https://data.cresis.ku.edu/data/rds/, then xOPR can't load it either." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d5a6e4b", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import numpy as np\n", + "import xarray as xr\n", + "import geoviews as gv\n", + "import geoviews.feature as gf\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import xopr\n", + "\n", + "import holoviews as hv\n", + "import hvplot.xarray\n", + "import hvplot.pandas\n", + "hvplot.extension('bokeh')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a567aae", + "metadata": {}, + "outputs": [], + "source": [ + "# Establish an OPR session\n", + "opr = xopr.OPRConnection(cache_dir=\"radar_cache\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "089dd1ac", + "metadata": {}, + "outputs": [], + "source": [ + "collection, segment = \"2019_Antarctica_GV\", \"20191105_01\"\n", + "\n", + "stac_items = opr.query_frames(collections=collection, segment_paths=[segment], max_items=5)\n", + "stac_items" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2df228d2", + "metadata": {}, + "outputs": [], + "source": [ + "background_map = gf.ocean.opts(projection=ccrs.SouthPolarStereo(), scale='50m') * gf.coastline.opts(projection=ccrs.SouthPolarStereo(), scale='50m')\n", + "background_map * stac_items.to_crs('EPSG:3031').hvplot(aspect='equal')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "affbb5d2", + "metadata": {}, + "outputs": [], + "source": [ + "frames = opr.load_frames(stac_items, data_product=\"CSARP_mvdr\", allow_unlisted_products=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d075dd2", + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect an individual frame\n", + "frames[0].xopr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b65d9bf4", + "metadata": {}, + "outputs": [], + "source": [ + "flight_line = xopr.merge_frames(frames)\n", + "\n", + "stacked = flight_line.resample(slow_time='2s').mean()\n", + "stacked.xopr\n", + "\n", + "fig, ax = plt.subplots(figsize=(15, 4))\n", + "radargram = 10*np.log10(np.abs(stacked['Data']))\n", + "radargram.plot.imshow(x='slow_time', cmap='gray', ax=ax)\n", + "ax.invert_yaxis()\n", + "\n", + "ax.set_title(f\"{stacked.attrs['collection']} - {stacked.attrs['segment_path']} - {stacked.attrs['data_product']}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "03ad795e", + "metadata": {}, + "source": [ + "## Loading individual images\n", + "\n", + "The `allow_unlisted_products` flag can also be used to enable loading individual \"images\".\n", + "\n", + ":::{note}\n", + "\"Images\" are views of the same segment made from different channels and/or waveforms. They are generally used to capture the full dynamic range between the surface and the bed across different parameters. When you load a default data product, you're actually loading the \"combined\" image, which has been stitched together to build a (hopefully) radiometrically consistent data product.\n", + "\n", + "Depending on the system, images usually correspond to either a high and low gain channel and/or a set of different transmit waveforms optimized for shallow or deep sounding.\n", + ":::" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aef81f32", + "metadata": {}, + "outputs": [], + "source": [ + "ds_images = {}\n", + "for img_idx in range(5):\n", + " try:\n", + " if img_idx == 0:\n", + " img_idx = None\n", + " img_key = \"combined\"\n", + " else:\n", + " img_key = f\"img_{img_idx:02d}\"\n", + " \n", + " ds_images[img_key] = opr.load_frame(stac_items.iloc[0], data_product=\"CSARP_qlook\", image=img_idx, allow_unlisted_products=True)\n", + " except FileNotFoundError:\n", + " break\n", + "\n", + "slow_time_idx = 100\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 4))\n", + "\n", + "for img_key, ds_img in ds_images.items():\n", + " img = 10*np.log10(np.abs(ds_img['Data'].isel(slow_time=slow_time_idx)))\n", + " img.plot(x='twtt', ax=ax, label=img_key, alpha=0.6)\n", + "\n", + "ax.legend()\n", + "ax.grid()\n", + "ax.set_ylabel(\"return power [dB]\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xopr", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/qc_demo.ipynb b/docs/notebooks/qc_demo.ipynb new file mode 100644 index 0000000..abb4b5d --- /dev/null +++ b/docs/notebooks/qc_demo.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5pmfqbfwoat", + "metadata": {}, + "source": [ + "---\n", + "title: Quality Control Checks\n", + "date: 2026-04-23\n", + "---\n", + "\n", + "This notebook demonstrates the `xopr.qc` module, which provides standardized quality control checks for radar datasets. Each check produces a per-trace boolean mask that can be used to filter or visualize data quality.\n", + "\n", + "The built-in checks are:\n", + "\n", + "- **ice_thickness_threshold** — flags traces where ice thickness is below a minimum\n", + "- **snr_bed_pick** — flags traces where the bed return is weak relative to the noise floor\n", + "- **heading_change** — flags traces with rapid aircraft heading changes\n", + "- **minimum_agl** — flags traces where the platform is too close to the surface\n", + "\n", + "We'll load a flight line, run all checks, and visualize which checks pass or fail along the radargram." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6vcofhmokm", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.colors as mcolors\n", + "\n", + "import xopr\n", + "from xopr.qc import run_qc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "vogg93lpnwa", + "metadata": {}, + "outputs": [], + "source": [ + "opr = xopr.OPRConnection(cache_dir=\"radar_cache\")" + ] + }, + { + "cell_type": "markdown", + "id": "3oeyowkytm5", + "metadata": {}, + "source": [ + "## Load an example flight\n", + "\n", + "We'll start by querying the first few frames of an example flight." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8371c82c", + "metadata": {}, + "outputs": [], + "source": [ + "collection, segment = \"2019_Antarctica_GV\", \"20191105_01\"\n", + "\n", + "stac_items = opr.query_frames(collections=collection, segment_paths=[segment], max_items=5)\n", + "stac_items" + ] + }, + { + "cell_type": "markdown", + "id": "fjis3q38mjr", + "metadata": {}, + "source": [ + "We'll load all frames from the results, merge them into one dataset, and resample to uniform 2-second spacing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ie0s9nxvi1", + "metadata": {}, + "outputs": [], + "source": [ + "frames = opr.load_frames(stac_items)\n", + "flight_line = xopr.merge_frames(frames)\n", + "flight_line = flight_line.resample(slow_time='2s').mean()\n", + "flight_line.xopr" + ] + }, + { + "cell_type": "markdown", + "id": "lzxhcbmqkqk", + "metadata": {}, + "source": [ + "## Run QC checks\n", + "\n", + "`run_qc()` automatically loads layer picks (`standard:surface` and `standard:bottom`) when they are not already in the dataset — just pass the `opr` connection.\n", + "\n", + "Passing `checks=None` (the default) runs all registered checks with their default parameters. You can override parameters for specific checks by passing a dict:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "y0dl5my224s", + "metadata": {}, + "outputs": [], + "source": [ + "# Run all default checks (with a custom ice thickness threshold)\n", + "qc_ds = run_qc(\n", + " flight_line,\n", + " opr=opr,\n", + " checks={\n", + " \"ice_thickness_threshold\": {\"min_thickness_m\": 500},\n", + " \"snr_bed_pick\": {\"min_snr_db\": 5.0},\n", + " \"heading_change\": {\"max_deg_per_km\": 2.0},\n", + " \"minimum_agl\": {\"min_agl_m\": 100.0},\n", + " },\n", + ")\n", + "\n", + "# Summary\n", + "qc_vars = [v for v in qc_ds.data_vars if v.startswith(\"qc_\")]\n", + "for v in qc_vars:\n", + " n = int(qc_ds[v].sum())\n", + " print(f\" {v}: {n}/{qc_ds.sizes['slow_time']} passed\")\n", + "\n", + "n_pass = int(qc_ds[\"qc\"].sum())\n", + "n_total = qc_ds.sizes[\"slow_time\"]\n", + "print(f\"\\nCombined: {n_pass}/{n_total} traces passed all checks ({100*n_pass/n_total:.1f}%)\")" + ] + }, + { + "cell_type": "markdown", + "id": "lsqa76pxm8", + "metadata": {}, + "source": [ + "## Visualize the radargram with per-check QC flags\n", + "\n", + "The top panel shows the radargram with layer picks overlaid and red shading on traces that fail the combined QC. The bottom panel shows a stacked plot of each individual check. Each check's trace is high when the check passes and low when it fails." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4621b84", + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_rg, ax_qc) = plt.subplots(\n", + " 2, 1, figsize=(15, 7),\n", + " gridspec_kw={\"height_ratios\": [2, 1], \"hspace\": 0.05},\n", + " sharex=True,\n", + ")\n", + "\n", + "# --- Top panel: radargram with red shading ---\n", + "radargram = 10 * np.log10(np.abs(qc_ds[\"Data\"]))\n", + "radargram.plot.imshow(x=\"slow_time\", cmap=\"gray\", ax=ax_rg, add_colorbar=False)\n", + "ax_rg.invert_yaxis()\n", + "\n", + "if \"standard:surface\" in qc_ds:\n", + " qc_ds[\"standard:surface\"].plot(ax=ax_rg, x=\"slow_time\", color=\"cyan\", linewidth=0.5, label=\"Surface\")\n", + "if \"standard:bottom\" in qc_ds:\n", + " qc_ds[\"standard:bottom\"].plot(ax=ax_rg, x=\"slow_time\", color=\"yellow\", linewidth=0.5, label=\"Bottom\")\n", + "\n", + "ax_rg.set_ylabel(\"TWTT [s]\")\n", + "\n", + "# Shade failing traces\n", + "fail_mask = ~qc_ds[\"qc\"].values\n", + "if fail_mask.any():\n", + " slow_times = qc_ds.slow_time.values\n", + " half_dt = (slow_times[1] - slow_times[0]) / 2 if len(slow_times) > 1 else np.timedelta64(1, \"s\")\n", + " diff = np.diff(fail_mask.astype(int))\n", + " starts = np.where(np.concatenate(([fail_mask[0]], diff == 1)))[0]\n", + " ends = np.where(np.concatenate((diff == -1, [fail_mask[-1]])))[0]\n", + " for s, e in zip(starts, ends):\n", + " ax_rg.axvspan(slow_times[s] - half_dt, slow_times[e] + half_dt, color=\"tab:red\", alpha=0.25)\n", + "\n", + "ax_rg.legend(loc=\"lower right\")\n", + "ax_rg.set_title(f\"{flight_line.attrs['collection']} — {flight_line.attrs['segment_path']}\")\n", + "ax_rg.set_xlabel(\"\")\n", + "\n", + "# --- Bottom panel: per-check pass/fail ---\n", + "qc_vars = sorted(v for v in qc_ds.data_vars if v.startswith(\"qc_\")) + [\"qc\"]\n", + "for idx, var in enumerate(qc_vars):\n", + " (qc_ds[var]+(1.5*idx)).plot(ax=ax_qc, x=\"slow_time\", linewidth=0.5, label=var)\n", + "\n", + "ax_qc.set_yticks(np.arange(len(qc_vars))*1.5)\n", + "ax_qc.set_yticklabels(qc_vars, fontsize=8)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "w2ino3e8thi", + "metadata": {}, + "source": [ + "## Flight track map\n", + "\n", + "The flight track projected onto EPSG:3031 with coastline context. Traces that pass all QC checks are shown in green; traces filtered out by any check are shown in red." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6hm0gq6c9ka", + "metadata": {}, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\n", + "import cartopy.feature as cfeature\n", + "\n", + "proj = xopr.geometry.project_dataset(qc_ds, target_crs=\"EPSG:3031\")\n", + "passed = proj[\"qc\"].values\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={\"projection\": ccrs.SouthPolarStereo()})\n", + "ax.add_feature(cfeature.COASTLINE, linewidth=0.5)\n", + "ax.add_feature(cfeature.OCEAN)\n", + "\n", + "ax.scatter(proj[\"x\"].values[~passed], proj[\"y\"].values[~passed],\n", + " c=\"tab:red\", s=4, label=\"Failed QC\", transform=ccrs.epsg(3031))\n", + "ax.scatter(proj[\"x\"].values[passed], proj[\"y\"].values[passed],\n", + " c=\"tab:green\", s=4, label=\"Passed QC\", transform=ccrs.epsg(3031))\n", + "\n", + "ax.legend(loc=\"upper right\")\n", + "ax.set_title(f\"{collection} — {segment}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "psn49srwffs", + "metadata": {}, + "source": [ + "## Running a subset of checks\n", + "\n", + "You can run only specific checks by passing a subset of the registry keys. You can also mix in custom check functions as callable keys:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "z9dq18ocq", + "metadata": {}, + "outputs": [], + "source": [ + "# Run only ice thickness and SNR checks with stricter thresholds\n", + "strict_ds = run_qc(\n", + " flight_line,\n", + " opr=opr,\n", + " checks={\n", + " \"ice_thickness_threshold\": {\"min_thickness_m\": 1000},\n", + " \"snr_bed_pick\": {\"min_snr_db\": 10.0},\n", + " },\n", + ")\n", + "\n", + "n_pass_strict = int(strict_ds[\"qc\"].sum())\n", + "print(f\"Strict thresholds: {n_pass_strict}/{n_total} traces passed ({100*n_pass_strict/n_total:.1f}%)\")\n", + "print(f\"Default thresholds: {n_pass}/{n_total} traces passed ({100*n_pass/n_total:.1f}%)\")" + ] + }, + { + "cell_type": "markdown", + "id": "foe47n0cera", + "metadata": {}, + "source": [ + "## Custom check functions\n", + "\n", + "You can also pass a callable as a check key. The function must accept a dataset as its first argument and return a dataset with QC mask variables added (via `_apply_qc_mask`).\n", + "\n", + "Here we define a simple roll-angle check that flags traces where the platform roll exceeds 5 degrees:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dmxpiagh9f", + "metadata": {}, + "outputs": [], + "source": [ + "from xopr.qc.checks import _apply_qc_mask\n", + "\n", + "def roll_check(ds, max_roll_deg=5.0):\n", + " \"\"\"Flag traces where abs(Roll) exceeds a threshold.\"\"\"\n", + " roll_deg = np.degrees(ds[\"Roll\"].values)\n", + " mask = xr.DataArray(np.abs(roll_deg) <= max_roll_deg, dims=\"slow_time\")\n", + " return _apply_qc_mask(ds, mask, \"roll\")\n", + "\n", + "custom_ds = run_qc(\n", + " flight_line,\n", + " opr=opr,\n", + " checks={\n", + " \"ice_thickness_threshold\": {\"min_thickness_m\": 500},\n", + " roll_check: {\"max_roll_deg\": 5.0},\n", + " },\n", + ")\n", + "\n", + "for v in sorted(v for v in custom_ds.data_vars if v.startswith(\"qc\")):\n", + " n = int(custom_ds[v].sum())\n", + " print(f\" {v}: {n}/{custom_ds.sizes['slow_time']} passed\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xopr", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/repicking.ipynb b/docs/notebooks/repicking.ipynb new file mode 100644 index 0000000..cb1416a --- /dev/null +++ b/docs/notebooks/repicking.ipynb @@ -0,0 +1,460 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9hq12dwq12r", + "metadata": {}, + "source": [ + "---\n", + "title: Repicking layers\n", + "date: 2026-03-17\n", + "---\n", + "\n", + "Layer picks in OPR data (such as the surface and bed layers) should be pretty good, but they're not always perfect. Sometimes the pick is offset from the true reflection peak by a few samples, and occasionally it's just wrong. If your work relies on having exactly the right surface or bed point, you may want to consider locally repicking.\n", + "\n", + "\"Repicking\" refers to refining these layer picks by examining the radar data in a local window around each existing pick and selecting a better value. This notebook demonstrates two approaches:\n", + "\n", + "1. **Local maximum**: Simply take the strongest return within a window around the existing pick.\n", + "2. **`scipy.signal.find_peaks`**: Use peak detection with custom parameters for more control over the selection criteria." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "q5bg7bavw1", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "from scipy.signal import find_peaks\n", + "\n", + "import xopr\n", + "\n", + "import holoviews as hv\n", + "import hvplot.xarray\n", + "hvplot.extension('bokeh')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2822a420", + "metadata": {}, + "outputs": [], + "source": [ + "opr = xopr.OPRConnection(cache_dir=\"radar_cache\")" + ] + }, + { + "cell_type": "markdown", + "id": "yda60b2hrdd", + "metadata": {}, + "source": [ + "## Load a flight line and its layer picks\n", + "\n", + "We'll start by loading a flight segment and its existing layer picks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fbf47b4", + "metadata": {}, + "outputs": [], + "source": [ + "collection = \"2008_Antarctica_BaslerJKB\"\n", + "segment = \"20090126_02\"\n", + "\n", + "stac_items = opr.query_frames(collections=collection, segment_paths=[segment])\n", + "frames = opr.load_frames(stac_items)\n", + "flight_line = xopr.merge_frames(frames)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1c63964", + "metadata": {}, + "outputs": [], + "source": [ + "layers = opr.get_layers(flight_line)\n", + "print(\"Available layers:\", list(layers.keys()))\n", + "\n", + "# Prepare data for repicking: linear power with dims (twtt, slow_time)\n", + "linear_data = np.abs(flight_line['Data']).transpose('twtt', 'slow_time')\n", + "\n", + "# Prepare dB version for display only\n", + "radargram_db = 10 * np.log10(linear_data)\n", + "radargram_db.name = 'Power (dB)'" + ] + }, + { + "cell_type": "markdown", + "id": "c9zddyl8bcj", + "metadata": {}, + "source": [ + "## Defining repick functions\n", + "\n", + "The idea behind repicking is simple: for each trace, look at the radar data in a small window around the existing pick and choose a better `twtt` value. The choice of *how* to pick within that window is where the two approaches differ.\n", + "\n", + "We define two small functions that operate on a single trace. Both have the same signature — they take a 1D trace array, a scalar pick value, and the `twtt` coordinate array, and return a scalar `twtt` value. This means either one can be plugged into the same `repick()` driver via [`xr.apply_ufunc`](https://docs.xarray.dev/en/stable/generated/xarray.apply_ufunc.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "l2mihepombp", + "metadata": {}, + "outputs": [], + "source": [ + "def local_max(trace, pick_twtt_val, twtt_vals, window_size):\n", + " \"\"\"Take the maximum amplitude within ±window_size samples of the pick.\"\"\"\n", + " if np.isnan(pick_twtt_val):\n", + " return np.nan\n", + " dt = twtt_vals[1] - twtt_vals[0]\n", + " idx = int(np.round((pick_twtt_val - twtt_vals[0]) / dt))\n", + " if idx < 0 or idx >= len(twtt_vals):\n", + " return np.nan\n", + " lo, hi = max(0, idx - window_size), min(len(twtt_vals), idx + window_size + 1)\n", + " window = trace[lo:hi]\n", + " if np.all(np.isnan(window)):\n", + " return pick_twtt_val\n", + " return twtt_vals[lo + np.nanargmax(window)]\n", + "\n", + "\n", + "def find_peaks_local(trace, pick_twtt_val, twtt_vals, window_size, rel_prominence=0.1, **kwargs):\n", + " \"\"\"Run find_peaks in a window around the pick; return the closest prominent peak.\"\"\"\n", + " if np.isnan(pick_twtt_val):\n", + " return np.nan\n", + " dt = twtt_vals[1] - twtt_vals[0]\n", + " idx = int(np.round((pick_twtt_val - twtt_vals[0]) / dt))\n", + " if idx < 0 or idx >= len(twtt_vals):\n", + " return np.nan\n", + " lo, hi = max(0, idx - window_size), min(len(twtt_vals), idx + window_size + 1)\n", + " window = np.nan_to_num(trace[lo:hi], nan=0.0)\n", + " peaks, _ = find_peaks(window, prominence=rel_prominence * np.max(window), **kwargs)\n", + " if len(peaks) == 0:\n", + " return pick_twtt_val\n", + " closest = peaks[np.argmin(np.abs(peaks - (idx - lo)))]\n", + " return twtt_vals[lo + closest]" + ] + }, + { + "cell_type": "markdown", + "id": "6eyv6dvtjv", + "metadata": {}, + "source": [ + "The driver function `repick()` uses `xr.apply_ufunc` to apply either of the above functions across all traces. The `input_core_dims=[['twtt'], []]` tells xarray to pass each trace as a 1D array and each pick as a scalar. You swap the repicking strategy simply by passing a different `func`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ngpnv0g7px", + "metadata": {}, + "outputs": [], + "source": [ + "def repick(data, pick_twtt, func, window_size=50, **kwargs):\n", + " \"\"\"\n", + " Repick a layer by applying func to each trace.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : xr.DataArray\n", + " Radar data with a 'twtt' dimension (linear power, not dB).\n", + " pick_twtt : xr.DataArray\n", + " Existing layer picks as twtt values, indexed by slow_time.\n", + " func : callable\n", + " Per-trace repick function with signature (trace, pick_twtt_val, twtt_vals, window_size, **kwargs) -> float.\n", + " Use `local_max` or `find_peaks_local`.\n", + " window_size : int\n", + " Half-width of the search window in number of samples.\n", + " **kwargs\n", + " Extra keyword arguments forwarded to func (e.g. rel_prominence, distance).\n", + " \"\"\"\n", + " pick_aligned = pick_twtt.interp(slow_time=data['slow_time'], method='nearest')\n", + " return xr.apply_ufunc(\n", + " func, data, pick_aligned,\n", + " input_core_dims=[['twtt'], []],\n", + " output_core_dims=[[]],\n", + " vectorize=True,\n", + " kwargs=dict(twtt_vals=data['twtt'].values, window_size=window_size, **kwargs),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "4vwkzbhv8p", + "metadata": {}, + "source": [ + "## Running the repick\n", + "\n", + "Both functions work on the linear-power data (not the dB version). We call `repick()` with each function — the only thing that changes is which `func` we pass and any extra keyword arguments it needs.\n", + "\n", + "`find_peaks_local` takes a `rel_prominence` parameter that sets the prominence threshold as a fraction of each window's maximum. This is important because the surface return is orders of magnitude stronger than the bed — an absolute prominence threshold would either accept every noise peak near the surface or reject every real peak near the bed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rnmda8fqu8", + "metadata": {}, + "outputs": [], + "source": [ + "# Method 1: local maximum — just takes the strongest return in the window\n", + "bed_repicked_max = repick(linear_data, layers['standard:bottom']['twtt'], local_max, window_size=50)\n", + "surface_repicked_max = repick(linear_data, layers['standard:surface']['twtt'], local_max, window_size=30)\n", + "\n", + "# Method 2: find_peaks — uses prominence to pick the best real peak\n", + "bed_repicked_peaks = repick(linear_data, layers['standard:bottom']['twtt'], find_peaks_local,\n", + " window_size=50, rel_prominence=0.1, distance=5)\n", + "surface_repicked_peaks = repick(linear_data, layers['standard:surface']['twtt'], find_peaks_local,\n", + " window_size=30, rel_prominence=0.1)" + ] + }, + { + "cell_type": "markdown", + "id": "4rtdswic2xl", + "metadata": {}, + "source": [ + "## Single-trace diagnostic\n", + "\n", + "Before looking at the full flight line, it's helpful to examine a single trace to understand what the repicking is doing. Here we plot the radar amplitude in a window around the bed pick and mark the peaks that `find_peaks` identifies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2q527wu3pb3", + "metadata": {}, + "outputs": [], + "source": [ + "# Look at a single trace near the middle of the flight\n", + "trace_idx = len(linear_data['slow_time']) // 2\n", + "trace = linear_data.isel(slow_time=trace_idx).values\n", + "\n", + "bed_pick_aligned = layers['standard:bottom']['twtt'].interp(slow_time=linear_data['slow_time'], method='nearest')\n", + "bed_twtt = float(bed_pick_aligned.isel(slow_time=trace_idx).values)\n", + "twtt_vals = linear_data['twtt'].values\n", + "dt = twtt_vals[1] - twtt_vals[0]\n", + "\n", + "bed_idx = int(np.round((bed_twtt - twtt_vals[0]) / dt))\n", + "window_size = 50\n", + "lo, hi = max(0, bed_idx - window_size), min(len(twtt_vals), bed_idx + window_size + 1)\n", + "\n", + "window = np.nan_to_num(trace[lo:hi], nan=0.0)\n", + "peaks_all, _ = find_peaks(window)\n", + "peaks_prominent, _ = find_peaks(window, prominence=0.1 * np.max(window))\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 4))\n", + "ax.plot(twtt_vals[lo:hi] * 1e6, trace[lo:hi], 'k-', linewidth=0.8)\n", + "ax.axvline(bed_twtt * 1e6, color='tab:blue', linestyle=':', label='Original pick')\n", + "ax.plot(twtt_vals[lo + peaks_all] * 1e6, trace[lo + peaks_all], 'v', color='tab:gray', markersize=6, label=f'All peaks ({len(peaks_all)})')\n", + "ax.plot(twtt_vals[lo + peaks_prominent] * 1e6, trace[lo + peaks_prominent], '^', color='tab:red', markersize=8, label=f'Prominent peaks ({len(peaks_prominent)})')\n", + "ax.set_xlabel('TWTT (μs)')\n", + "ax.set_ylabel('Amplitude')\n", + "ax.set_title('Single trace: peaks around the bed pick')\n", + "ax.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "zgahqwgsexf", + "metadata": {}, + "source": [ + "## Comparing the results\n", + "\n", + "Let's compare the original bed picks with the two repicked versions. First a full flight line view, then a zoom:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "z2k8h41fc1q", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(15, 4))\n", + "radargram_db.plot.pcolormesh(x='slow_time', cmap='gray', ax=ax)\n", + "ax.invert_yaxis()\n", + "\n", + "layers['standard:bottom']['twtt'].plot(ax=ax, x='slow_time', linewidth=0.5, linestyle=':', color='tab:blue', label='Original')\n", + "bed_repicked_max.plot(ax=ax, linewidth=0.5, color='tab:cyan', label='Local max')\n", + "bed_repicked_peaks.plot(ax=ax, linewidth=0.5, color='tab:red', label='find_peaks')\n", + "ax.legend(loc='upper right')\n", + "ax.set_title(\"Bed picks: original vs repicked\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gbzff86p6d", + "metadata": {}, + "outputs": [], + "source": [ + "# Zoom in on ~200 traces near the middle, tight twtt range around the bed\n", + "n_traces = len(radargram_db['slow_time'])\n", + "zmid = n_traces // 2\n", + "zlo, zhi = zmid - 100, zmid + 100\n", + "\n", + "zoom_st = radargram_db['slow_time'].values[zlo:zhi]\n", + "st_lo, st_hi = zoom_st[0], zoom_st[-1]\n", + "\n", + "bed_aligned = layers['standard:bottom']['twtt'].interp(slow_time=radargram_db['slow_time'], method='nearest')\n", + "bed_slice = bed_aligned.values[zlo:zhi]\n", + "bed_valid = bed_slice[np.isfinite(bed_slice)]\n", + "twtt_pad = 3e-6\n", + "twtt_lo, twtt_hi = bed_valid.min() - twtt_pad, bed_valid.max() + twtt_pad\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n", + "\n", + "for i, (ax, title, repicked) in enumerate(zip(\n", + " axes,\n", + " ['Original only', 'Local max', 'find_peaks'],\n", + " [None, bed_repicked_max, bed_repicked_peaks],\n", + ")):\n", + " im = radargram_db.plot.pcolormesh(x='slow_time', cmap='gray', ax=ax, add_colorbar=False)\n", + " ax.invert_yaxis()\n", + " layers['standard:bottom']['twtt'].plot(\n", + " ax=ax, x='slow_time', linewidth=1.5, linestyle='--', color='tab:blue', label='Original'\n", + " )\n", + " if repicked is not None:\n", + " repicked.plot(ax=ax, linewidth=1.5, color='tab:red', marker='.', markersize=3, label=title)\n", + " ax.set_xlim(st_lo, st_hi)\n", + " ax.set_ylim(twtt_hi, twtt_lo)\n", + " ax.set_title(title)\n", + " ax.legend(loc='lower right', fontsize=8)\n", + " if ax != axes[0]:\n", + " ax.set_ylabel('')\n", + "\n", + "fig.colorbar(im, ax=axes, label='Power (dB)', shrink=0.8)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "mcgmnkmaxtb", + "metadata": {}, + "source": [ + "## Quantifying the difference\n", + "\n", + "It can be helpful to see how much the repicked values differ from the originals. Large differences might indicate either a successful correction or a bad pick that needs manual review." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9y00w7vjj3f", + "metadata": {}, + "outputs": [], + "source": [ + "# Align original picks to the same slow_time grid as the repicked values\n", + "bed_original = layers['standard:bottom']['twtt'].interp(slow_time=linear_data['slow_time'], method='nearest')\n", + "\n", + "bed_diff_max = (bed_repicked_max - bed_original) * 1e6\n", + "bed_diff_peaks = (bed_repicked_peaks - bed_original) * 1e6\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(14, 3), sharey=True)\n", + "\n", + "axes[0].hist(bed_diff_max.values[np.isfinite(bed_diff_max.values)], bins=100, color='tab:cyan')\n", + "axes[0].set_xlabel('ΔTWTT (μs)')\n", + "axes[0].set_ylabel('Count')\n", + "axes[0].set_title('Bed: local max − original')\n", + "\n", + "axes[1].hist(bed_diff_peaks.values[np.isfinite(bed_diff_peaks.values)], bins=100, color='tab:red')\n", + "axes[1].set_xlabel('ΔTWTT (μs)')\n", + "axes[1].set_title('Bed: find_peaks − original')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "for name, diff in [(\"Local max\", bed_diff_max), (\"find_peaks\", bed_diff_peaks)]:\n", + " valid = diff.values[np.isfinite(diff.values)]\n", + " if len(valid) == 0:\n", + " print(f\"{name}: no valid comparisons\")\n", + " continue\n", + " print(f\"{name}: median shift = {np.median(valid):.4f} μs, \"\n", + " f\"mean |shift| = {np.mean(np.abs(valid)):.4f} μs, \"\n", + " f\"max |shift| = {np.max(np.abs(valid)):.4f} μs\")" + ] + }, + { + "cell_type": "markdown", + "id": "cc9ayr7x7aj", + "metadata": {}, + "source": [ + "## Interactive radargram\n", + "\n", + "The plot below uses HoloViews/Bokeh — use the scroll wheel to zoom and drag to pan.\n", + "\n", + "Original picks are shown as dashed lines. Repicked versions are solid: cyan for local max, red for find_peaks." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "icpmvalpo1f", + "metadata": {}, + "outputs": [], + "source": [ + "# Subsample for interactive plotting (full resolution is too large without datashader)\n", + "step = max(1, len(radargram_db['slow_time']) // 500)\n", + "radargram_sub = radargram_db.isel(slow_time=slice(None, None, step))\n", + "\n", + "radargram_hv = radargram_sub.hvplot.quadmesh(\n", + " x='slow_time', y='twtt', cmap='gray',\n", + " width=900, height=500, colorbar=True, clabel='Power (dB)',\n", + ").opts(invert_yaxis=True)\n", + "\n", + "def _curve(da, label, color, dash='solid'):\n", + " \"\"\"Helper to build an hv.Curve from an xr.DataArray or layer dataset.\"\"\"\n", + " return hv.Curve(\n", + " (da['slow_time'].values, da.values), 'slow_time', 'twtt', label=label\n", + " ).opts(color=color, line_dash=dash, line_width=1.5)\n", + "\n", + "# Original picks (dashed)\n", + "surf_orig = _curve(layers['standard:surface']['twtt'], 'Surface (original)', 'blue', 'dashed')\n", + "bed_orig = _curve(layers['standard:bottom']['twtt'], 'Bed (original)', 'blue', 'dashed')\n", + "\n", + "# Local max repicks (cyan solid)\n", + "surf_max = _curve(surface_repicked_max, 'Surface (local max)', 'cyan')\n", + "bed_max = _curve(bed_repicked_max, 'Bed (local max)', 'cyan')\n", + "\n", + "# find_peaks repicks (red solid)\n", + "surf_fp = _curve(surface_repicked_peaks, 'Surface (find_peaks)', 'red')\n", + "bed_fp = _curve(bed_repicked_peaks, 'Bed (find_peaks)', 'red')\n", + "\n", + "(radargram_hv * surf_orig * bed_orig * surf_max * bed_max * surf_fp * bed_fp).opts(\n", + " title='Interactive comparison — zoom in to see differences',\n", + " active_tools=['wheel_zoom'],\n", + " legend_position='top_left',\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "xopr", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/xopr/__init__.py b/src/xopr/__init__.py index 8bbd39e..adc56a8 100644 --- a/src/xopr/__init__.py +++ b/src/xopr/__init__.py @@ -34,6 +34,7 @@ __version__ = "unknown" from . import geometry as geometry +from . import qc as qc from .opr_access import OPRConnection as OPRConnection from .opr_tools import find_intersections as find_intersections from .opr_tools import merge_frames as merge_frames diff --git a/src/xopr/opr_access.py b/src/xopr/opr_access.py index 386c3cb..dbd18b2 100644 --- a/src/xopr/opr_access.py +++ b/src/xopr/opr_access.py @@ -311,8 +311,10 @@ def query_frames(self, collections: list[str] = None, segment_paths: list[str] = def load_frames(self, stac_items: gpd.GeoDataFrame, data_product: str = "CSARP_standard", + image: Union[int, None] = None, merge_flights: bool = False, skip_errors: bool = False, + allow_unlisted_products: bool = False ) -> Union[list[xr.Dataset], xr.Dataset]: """ Load multiple radar frames from STAC items. @@ -323,10 +325,16 @@ def load_frames(self, stac_items: gpd.GeoDataFrame, STAC items returned from query_frames. data_product : str, optional Data product to load (e.g., "CSARP_standard", "CSARP_qlook"). + image : int or None, optional + The image number to load for each frame. If None (default), loads the + combined image. If specified, `allow_unlisted_products` must be True. merge_flights : bool, optional If True, merge frames from the same segment into single Datasets. skip_errors : bool, optional If True, skip failed frames and continue loading others. + allow_unlisted_products : bool, optional + If True, attempt to load the specified data product even if it's not + listed in the item's assets. See `load_frame` for details. Returns ------- @@ -338,7 +346,7 @@ def load_frames(self, stac_items: gpd.GeoDataFrame, for idx, item in stac_items.iterrows(): try: - frame = self.load_frame(item, data_product) + frame = self.load_frame(item, data_product, image=image, allow_unlisted_products=allow_unlisted_products) frames.append(frame) except Exception as e: print(f"Error loading frame for item {item.get('id', 'unknown')}: {e}") @@ -352,7 +360,9 @@ def load_frames(self, stac_items: gpd.GeoDataFrame, else: return frames - def load_frame(self, stac_item, data_product: str = "CSARP_standard") -> xr.Dataset: + def load_frame(self, stac_item, data_product: str = "CSARP_standard", + image: Union[int, None] = None, + allow_unlisted_products: bool = False) -> xr.Dataset: """ Load a single radar frame from a STAC item. @@ -362,7 +372,16 @@ def load_frame(self, stac_item, data_product: str = "CSARP_standard") -> xr.Data STAC item containing asset URLs. data_product : str, optional Data product to load (e.g., "CSARP_standard", "CSARP_qlook"). - + image : int or None, optional + The image number to load for this frame. If None (default), loads the + combined image. If specified, `allow_unlisted_products` must be True. + allow_unlisted_products : bool, optional + If True, attempt to load the specified data product even if it's not + listed in the item's assets. This can be useful for loading non-standard + products. If set to True and the data product is not found in the STAC + item assets, the method will attempt to construct the URL based on any + available CSARP_* asset. (If the frame is entirely unlisted, you can use + `load_frame_url` instead.) Returns ------- xr.Dataset @@ -375,13 +394,36 @@ def load_frame(self, stac_item, data_product: str = "CSARP_standard") -> xr.Data # Get the data asset data_asset = assets.get(data_product) if not data_asset: - available_assets = list(assets.keys()) - raise ValueError(f"No {data_product} asset found. Available assets: {available_assets}") + if not allow_unlisted_products: + available_assets = list(assets.keys()) + raise ValueError(f"{data_product} asset not found in STAC item. Available assets: {available_assets}") + else: + # Find any available CSARP_* asset to use as a template for constructing the URL + template_asset_name = None + template_asset_url = None + for asset_name, asset_info in assets.items(): + if asset_name.startswith('CSARP_'): + template_asset_name = asset_name + template_asset_url = asset_info.get('href') + break + + if not template_asset_url: + available_assets = list(assets.keys()) + raise ValueError(f"No CSARP_* asset found in STAC item to use as a template for constructing URL for {data_product}. Available assets: {available_assets}") + + url = template_asset_url.replace(template_asset_name, data_product) + else: + # The asset does exist in the STAC item, so just get the URL from the asset + url = data_asset.get('href') + + if not url: + raise ValueError(f"No href found in {data_product} asset") - # Get the URL from the asset - url = data_asset.get('href') - if not url: - raise ValueError(f"No href found in {data_product} asset") + # If a specific image is requested, modify the URL to point to that image + if image is not None: + if not allow_unlisted_products: + raise ValueError("Specifying an image number requires allow_unlisted_products=True to construct the URL") + url = url.replace("Data_", f"Data_img_{image:02d}_") # Load the frame using the existing method return self.load_frame_url(url) diff --git a/src/xopr/qc/__init__.py b/src/xopr/qc/__init__.py new file mode 100644 index 0000000..ac6ad3a --- /dev/null +++ b/src/xopr/qc/__init__.py @@ -0,0 +1,17 @@ +"""Quality control module for polar radar datasets.""" + +from .checks import ensure_picks as ensure_picks +from .checks import heading_change as heading_change +from .checks import ice_thickness_threshold as ice_thickness_threshold +from .checks import minimum_agl as minimum_agl +from .checks import snr_bed_pick as snr_bed_pick +from .runner import run_qc as run_qc + +__all__ = [ + "ensure_picks", + "heading_change", + "ice_thickness_threshold", + "run_qc", + "snr_bed_pick", + "minimum_agl", +] diff --git a/src/xopr/qc/checks.py b/src/xopr/qc/checks.py new file mode 100644 index 0000000..6d425d2 --- /dev/null +++ b/src/xopr/qc/checks.py @@ -0,0 +1,316 @@ +""" +Quality control checks for polar radar datasets. + +Each check function takes an xarray Dataset and returns a modified copy +with a per-trace boolean mask added as a new variable. +""" + +import numpy as np +import xarray as xr +from pyproj import Transformer +from scipy.constants import c as speed_of_light + +_REQUIRED_LAYERS = { + "standard:surface": [":surface"], + "standard:bottom": [":bottom"], +} + + +def _resolve_layer(name, aliases, available_keys): + """Return the first key in *available_keys* that matches *name* or an alias.""" + if name in available_keys: + return name + for alias in aliases: + if alias in available_keys: + return alias + return None + + +def ensure_picks(ds, opr=None): + """ + Ensure ``standard:surface`` and ``standard:bottom`` variables exist. + + If either variable is missing, layer picks are loaded via + ``opr.get_layers()`` and their twtt values are assigned to the + dataset. The layer names ``":surface"`` and ``":bottom"`` are + accepted as aliases and are stored under their canonical + ``standard:`` names. + + Parameters + ---------- + ds : xarray.Dataset + Radar dataset, potentially missing pick variables. + opr : xopr.OPRConnection, optional + An OPR connection used to fetch layers. Required only when + pick variables are missing from *ds*. + + Returns + ------- + xarray.Dataset + Copy of *ds* with ``standard:surface`` and ``standard:bottom`` + variables present. + + Raises + ------ + ValueError + If picks are missing and *opr* is ``None``, or if the required + layers cannot be loaded. + """ + if all(v in ds for v in _REQUIRED_LAYERS): + return ds + + if opr is None: + missing = [v for v in _REQUIRED_LAYERS if v not in ds] + raise ValueError( + f"Dataset is missing {missing} and no OPRConnection was " + "provided to load layer picks. Pass an opr= argument or " + "add the variables to the dataset manually." + ) + + ds = ds.copy() + layers = opr.get_layers(ds) + if layers is None: + raise ValueError("No layer data found for this dataset.") + + for canonical, aliases in _REQUIRED_LAYERS.items(): + if canonical in ds: + continue + matched = _resolve_layer(canonical, aliases, layers) + if matched is None: + raise ValueError( + f"Layer '{canonical}' (or aliases {aliases}) not found. " + f"Available layers: {list(layers.keys())}" + ) + twtt = layers[matched]["twtt"] + twtt = twtt.reindex( + slow_time=ds.slow_time, + method="nearest", + tolerance=np.timedelta64(5, "s"), + fill_value=np.nan, + ) + ds[canonical] = twtt + + return ds + + +def _apply_qc_mask(ds, mask, name): + """ + Add a QC mask to a dataset and update the combined ``qc`` variable. + + Parameters + ---------- + ds : xarray.Dataset + Input dataset. + mask : xarray.DataArray + Boolean DataArray with dimension ``slow_time``. True means the + trace passed the check. + name : str + Check name; stored as ``qc_{name}`` in the returned dataset. + + Returns + ------- + xarray.Dataset + Copy of *ds* with ``qc_{name}`` added. The combined ``qc`` + variable is the element-wise AND of all individual masks. + """ + ds = ds.copy() + var_name = f"qc_{name}" + ds[var_name] = mask + if "qc" in ds: + ds["qc"] = ds["qc"] & mask + else: + ds["qc"] = mask.copy() + return ds + + +# ---- Individual checks ----------------------------------------------- + + +def ice_thickness_threshold(ds, min_thickness_m=500.0, epsilon_ice=3.15): + """ + Flag traces where ice thickness is below a minimum threshold. + + Ice thickness is computed from the ``standard:surface`` and + ``standard:bottom`` two-way travel time picks using the radar wave + speed in ice. + + Parameters + ---------- + ds : xarray.Dataset + Must contain ``standard:surface`` and ``standard:bottom`` + variables (units: seconds). + min_thickness_m : float, optional + Minimum ice thickness in metres. Default 500. + epsilon_ice : float, optional + Relative permittivity of ice. Default 3.15. + + Returns + ------- + xarray.Dataset + Copy with ``qc_ice_thickness_threshold`` and ``qc`` variables. + + Raises + ------ + ValueError + If ``standard:surface`` or ``standard:bottom`` is missing. + """ + for var in _REQUIRED_LAYERS: + if var not in ds: + raise ValueError(f"Dataset is missing required variable '{var}'") + + v_ice = speed_of_light / np.sqrt(epsilon_ice) + thickness = (ds["standard:bottom"] - ds["standard:surface"]) * v_ice / 2.0 + mask = xr.DataArray( + thickness.values >= min_thickness_m, + dims="slow_time", + ) + # NaN picks produce NaN thickness → comparison is False + return _apply_qc_mask(ds, mask, "ice_thickness_threshold") + + +def snr_bed_pick(ds, min_snr_db=5.0, noise_region_samples=50): + """ + Flag traces where the bed pick signal-to-noise ratio is too low. + + SNR is estimated as the ratio of bed-pick power to the noise floor. + The noise floor is taken as the mean power over the last + *noise_region_samples* fast-time samples (assumed to be below the + bed return). + + Parameters + ---------- + ds : xarray.Dataset + Must contain ``Data`` (dims ``slow_time``, ``twtt``) and + ``standard:bottom`` (units: seconds). + min_snr_db : float, optional + Minimum acceptable SNR in dB. Default 5. + noise_region_samples : int, optional + Number of fast-time samples at the end of each trace to use + for the noise-floor estimate. Default 50. + + Returns + ------- + xarray.Dataset + Copy with ``qc_snr_bed_pick`` and ``qc`` variables. + + Raises + ------ + ValueError + If ``Data`` or ``standard:bottom`` is missing. + """ + for var in ("Data", "standard:bottom"): + if var not in ds: + raise ValueError(f"Dataset is missing required variable '{var}'") + + data = np.abs(ds["Data"].values) # (slow_time, twtt) or (twtt, slow_time) + # Ensure shape is (slow_time, twtt) + if ds["Data"].dims[0] == "twtt": + data = data.T + + twtt = ds.twtt.values + bottom_twtt = ds["standard:bottom"].values + + n_traces = data.shape[0] + idx = np.clip(np.searchsorted(twtt, bottom_twtt), 0, data.shape[1] - 1) + bed_power = np.where(np.isnan(bottom_twtt), np.nan, data[np.arange(n_traces), idx]) + + noise_floor = np.nanmean(data[:, -noise_region_samples:], axis=1) + + with np.errstate(divide="ignore", invalid="ignore"): + snr_db = 10.0 * np.log10(bed_power / noise_floor) + + mask = xr.DataArray(snr_db >= min_snr_db, dims="slow_time") + return _apply_qc_mask(ds, mask, "snr_bed_pick") + + +def heading_change(ds, max_deg_per_km=2.0): + """ + Flag traces with rapid aircraft heading changes. + + The heading rate of change is estimated from consecutive traces and + normalised by along-track distance. + + Parameters + ---------- + ds : xarray.Dataset + Must contain ``Heading`` (radians) and ``Latitude`` / + ``Longitude`` (degrees). + max_deg_per_km : float, optional + Maximum acceptable heading change in degrees per kilometre. + Default 2. + + Returns + ------- + xarray.Dataset + Copy with ``qc_heading_change`` and ``qc`` variables. + + Raises + ------ + ValueError + If ``Heading``, ``Latitude``, or ``Longitude`` is missing. + """ + for var in ("Heading", "Latitude", "Longitude"): + if var not in ds: + raise ValueError(f"Dataset is missing required variable '{var}'") + + heading_rad = ds["Heading"].values + + # Compute along-track distances in metres via projected coords + lat = ds["Latitude"].values + lon = ds["Longitude"].values + mean_lat = np.nanmean(lat) + epsg = "EPSG:3031" if mean_lat < 0 else "EPSG:3413" + transformer = Transformer.from_crs("EPSG:4326", epsg, always_xy=True) + x, y = transformer.transform(lon, lat) + + dx = np.diff(x) + dy = np.diff(y) + dist_m = np.sqrt(dx**2 + dy**2) + + # Heading change per step (handle wraparound at ±π) + dh = np.diff(heading_rad) + dh = (dh + np.pi) % (2 * np.pi) - np.pi # wrap to [-π, π] + dh_deg = np.abs(np.degrees(dh)) + + with np.errstate(divide="ignore", invalid="ignore"): + deg_per_km = np.where(dist_m > 0, dh_deg / (dist_m / 1000.0), 0.0) + + # First trace has no predecessor → passes by default + rate = np.empty(len(heading_rad)) + rate[0] = 0.0 + rate[1:] = deg_per_km + + mask = xr.DataArray(rate <= max_deg_per_km, dims="slow_time") + return _apply_qc_mask(ds, mask, "heading_change") + + +def minimum_agl(ds, min_agl_m=100.0): + """ + Flag traces where the above-ground level is too low. + + AGL is estimated from the ``standard:surface`` two-way travel time + as the one-way range from the platform to the surface. + + Parameters + ---------- + ds : xarray.Dataset + Must contain ``standard:surface`` (units: seconds). + min_agl_m : float, optional + Minimum above-ground level in metres. Default 100. + + Returns + ------- + xarray.Dataset + Copy with ``qc_minimum_agl`` and ``qc`` variables. + + Raises + ------ + ValueError + If ``standard:surface`` is missing. + """ + if "standard:surface" not in ds: + raise ValueError("Dataset is missing required variable 'standard:surface'") + + agl = ds["standard:surface"].values * speed_of_light / 2.0 + mask = xr.DataArray(agl >= min_agl_m, dims="slow_time") + return _apply_qc_mask(ds, mask, "minimum_agl") diff --git a/src/xopr/qc/runner.py b/src/xopr/qc/runner.py new file mode 100644 index 0000000..f28e982 --- /dev/null +++ b/src/xopr/qc/runner.py @@ -0,0 +1,93 @@ +""" +QC runner — orchestrates multiple quality control checks on a dataset. +""" + +from .checks import ( + ensure_picks, + heading_change, + ice_thickness_threshold, + minimum_agl, + snr_bed_pick, +) + +_CHECKS = { + "ice_thickness_threshold": ice_thickness_threshold, + "snr_bed_pick": snr_bed_pick, + "heading_change": heading_change, + "minimum_agl": minimum_agl, +} + + +def run_qc(ds, checks=None, opr=None): + """ + Run one or more QC checks on a dataset. + + If ``standard:surface`` or ``standard:bottom`` variables are missing + from the dataset, they are automatically loaded from layer picks via + *opr*. + + Parameters + ---------- + ds : xarray.Dataset + Input radar dataset. + checks : dict, optional + Mapping of checks to run. Keys are either a registered check + name (string matching a key in ``_CHECKS``) or a callable. Values + are dicts of keyword arguments passed to the check function + (use ``{}`` for defaults). + + Examples:: + + # Registered check with default params + {"ice_thickness_threshold": {}} + + # Registered check with custom params + {"ice_thickness_threshold": {"min_thickness_m": 300}} + + # Custom function + {my_custom_check: {"threshold": 0.5}} + + # Mix of both + {"ice_thickness_threshold": {}, my_custom_check: {}} + + ``None`` (default) runs all registered checks with default + parameters. + opr : xopr.OPRConnection, optional + Connection used to load layer picks when ``standard:surface`` or + ``standard:bottom`` is not already in *ds*. + + Returns + ------- + xarray.Dataset + Dataset with QC mask variables added. + + Raises + ------ + ValueError + If a string key does not match any registered check, or if picks + are missing and *opr* is ``None``. + """ + if checks is None: + checks = {name: {} for name in _CHECKS} + + # Resolve string keys to callables; validate up front + resolved = [] + for key, kwargs in checks.items(): + if callable(key): + resolved.append((key, kwargs)) + elif isinstance(key, str): + if key not in _CHECKS: + raise ValueError( + f"Unknown QC check '{key}'. " + f"Registered checks: {list(_CHECKS.keys())}" + ) + resolved.append((_CHECKS[key], kwargs)) + else: + raise TypeError(f"Check keys must be strings or callables, got {type(key)}") + + ds = ensure_picks(ds, opr=opr) + + for fn, kwargs in resolved: + ds = fn(ds, **kwargs) + + return ds diff --git a/src/xopr/qc/tests/__init__.py b/src/xopr/qc/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/xopr/qc/tests/test_checks.py b/src/xopr/qc/tests/test_checks.py new file mode 100644 index 0000000..4b62fdb --- /dev/null +++ b/src/xopr/qc/tests/test_checks.py @@ -0,0 +1,402 @@ +"""Tests for QC check functions.""" + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import xarray as xr +from scipy.constants import c as speed_of_light + +from xopr.qc.checks import ( + ensure_picks, + heading_change, + ice_thickness_threshold, + minimum_agl, + snr_bed_pick, +) + + +@pytest.fixture +def synthetic_ds(): + """Synthetic radar dataset with 100 traces and 200 fast-time samples.""" + n_traces = 100 + n_samples = 200 + twtt = np.linspace(0, 50e-6, n_samples) + slow_time = np.arange(n_traces, dtype=float) + + # Build a Data array where the bed-pick sample is bright + data = np.random.rand(n_traces, n_samples) * 0.01 + bottom_twtt_val = 40e-6 + bed_idx = np.searchsorted(twtt, bottom_twtt_val) + data[:, bed_idx] = 10.0 # strong bed return + + ds = xr.Dataset( + { + "Data": (["slow_time", "twtt"], data), + "standard:surface": ("slow_time", np.full(n_traces, 10e-6)), + "standard:bottom": ("slow_time", np.full(n_traces, bottom_twtt_val)), + "Latitude": ("slow_time", np.linspace(-75, -74, n_traces)), + "Longitude": ("slow_time", np.linspace(100, 101, n_traces)), + "Heading": ("slow_time", np.zeros(n_traces)), + }, + coords={"slow_time": slow_time, "twtt": twtt}, + ) + return ds + + +# ---- ice_thickness_threshold ----------------------------------------- + + +def _synthetic_thickness(surface_twtt=10e-6, bottom_twtt=40e-6, epsilon=3.15): + v_ice = speed_of_light / np.sqrt(epsilon) + return (bottom_twtt - surface_twtt) * v_ice / 2.0 + + +def test_ice_thickness_all_pass(synthetic_ds): + thickness = _synthetic_thickness() + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=thickness - 1) + assert result["qc_ice_thickness_threshold"].all() + assert result["qc"].all() + + +def test_ice_thickness_all_fail(synthetic_ds): + thickness = _synthetic_thickness() + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=thickness + 1) + assert not result["qc_ice_thickness_threshold"].any() + + +def test_ice_thickness_nan_bottom(synthetic_ds): + synthetic_ds["standard:bottom"].values[50:] = np.nan + thickness = _synthetic_thickness() + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=thickness - 1) + assert result["qc_ice_thickness_threshold"][:50].all() + assert not result["qc_ice_thickness_threshold"][50:].any() + + +def test_qc_variable_created(synthetic_ds): + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=0) + assert "qc" in result + np.testing.assert_array_equal( + result["qc"].values, result["qc_ice_thickness_threshold"].values + ) + + +def test_qc_variable_accumulation(synthetic_ds): + """Running a check twice with different data should AND the masks.""" + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=0) + result["standard:bottom"].values[50:] = result["standard:surface"].values[50:] + thickness = _synthetic_thickness() + result = ice_thickness_threshold(result, min_thickness_m=thickness - 1) + assert result["qc"][:50].all() + assert not result["qc"][50:].any() + + +def test_copy_semantics(synthetic_ds): + result = ice_thickness_threshold(synthetic_ds, min_thickness_m=0) + assert "qc" not in synthetic_ds + assert "qc_ice_thickness_threshold" not in synthetic_ds + assert "qc" in result + + +def test_missing_variable(): + ds = xr.Dataset({"standard:surface": ("slow_time", [1.0])}) + with pytest.raises(ValueError, match="standard:bottom"): + ice_thickness_threshold(ds) + + +# ---- snr_bed_pick ---------------------------------------------------- + + +def test_snr_bed_pick_all_pass(synthetic_ds): + """Strong bed return should pass even a moderate SNR threshold.""" + result = snr_bed_pick(synthetic_ds, min_snr_db=5.0) + assert result["qc_snr_bed_pick"].all() + + +def test_snr_bed_pick_all_fail(synthetic_ds): + """Extremely high threshold should fail everything.""" + result = snr_bed_pick(synthetic_ds, min_snr_db=100.0) + assert not result["qc_snr_bed_pick"].any() + + +def test_snr_bed_pick_nan_bottom(synthetic_ds): + """NaN bottom picks should be flagged as failing.""" + synthetic_ds["standard:bottom"].values[50:] = np.nan + result = snr_bed_pick(synthetic_ds, min_snr_db=5.0) + assert not result["qc_snr_bed_pick"][50:].any() + + +def test_snr_bed_pick_missing_data(): + ds = xr.Dataset({"standard:bottom": ("slow_time", [1.0])}) + with pytest.raises(ValueError, match="Data"): + snr_bed_pick(ds) + + +# ---- heading_change -------------------------------------------------- + + +def test_heading_change_straight_pass(synthetic_ds): + """Constant heading should always pass.""" + result = heading_change(synthetic_ds, max_deg_per_km=2.0) + assert result["qc_heading_change"].all() + + +def test_heading_change_sharp_turn(synthetic_ds): + """A sudden 90-degree turn should fail.""" + synthetic_ds["Heading"].values[50] = np.pi / 2 + result = heading_change(synthetic_ds, max_deg_per_km=0.01) + # Trace 50 (and 51 due to backward diff) should fail + assert not result["qc_heading_change"].values[50] + + +def test_heading_change_wraparound(synthetic_ds): + """Heading wrapping from +π to −π should not produce a large change.""" + # Set all headings to just below +π then jump to just above -π + synthetic_ds["Heading"].values[:] = np.pi - 0.001 + synthetic_ds["Heading"].values[50] = -np.pi + 0.001 + result = heading_change(synthetic_ds, max_deg_per_km=2.0) + # The actual change is ~0.002 rad ≈ 0.1°, should pass + assert result["qc_heading_change"].values[50] + + +def test_heading_change_missing_var(): + ds = xr.Dataset({"Latitude": ("slow_time", [1.0])}) + with pytest.raises(ValueError, match="Heading"): + heading_change(ds) + + +# ---- minimum_agl ----------------------------------------------------- + + +def test_minimum_agl_all_pass(synthetic_ds): + """Surface TWTT of 10μs gives AGL ~1500m, should pass 100m threshold.""" + result = minimum_agl(synthetic_ds, min_agl_m=100.0) + assert result["qc_minimum_agl"].all() + + +def test_minimum_agl_all_fail(synthetic_ds): + """Setting threshold above the AGL should fail everything.""" + agl = 10e-6 * speed_of_light / 2.0 # ~1500m + result = minimum_agl(synthetic_ds, min_agl_m=agl + 1) + assert not result["qc_minimum_agl"].any() + + +def test_minimum_agl_partial(synthetic_ds): + """Low surface TWTT traces should fail.""" + low_twtt = 2 * 50.0 / speed_of_light # 50m AGL + synthetic_ds["standard:surface"].values[50:] = low_twtt + result = minimum_agl(synthetic_ds, min_agl_m=100.0) + assert result["qc_minimum_agl"][:50].all() + assert not result["qc_minimum_agl"][50:].any() + + +def test_minimum_agl_nan_surface(synthetic_ds): + """NaN surface picks should fail.""" + synthetic_ds["standard:surface"].values[50:] = np.nan + result = minimum_agl(synthetic_ds, min_agl_m=100.0) + assert result["qc_minimum_agl"][:50].all() + assert not result["qc_minimum_agl"][50:].any() + + +def test_minimum_agl_missing_var(): + ds = xr.Dataset({"Data": ("slow_time", [1.0])}) + with pytest.raises(ValueError, match="standard:surface"): + minimum_agl(ds) + + +# ---- ensure_picks ---------------------------------------------------- + + +def test_ensure_picks_already_present(synthetic_ds): + result = ensure_picks(synthetic_ds) + assert "standard:surface" in result + assert "standard:bottom" in result + + +def test_ensure_picks_no_opr_raises(): + ds = xr.Dataset({"Data": ("slow_time", [1.0])}, coords={"slow_time": [0.0]}) + with pytest.raises(ValueError, match="OPRConnection"): + ensure_picks(ds) + + +def test_ensure_picks_loads_bottom(): + n = 100 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"standard:surface": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ) + bottom_vals = np.full(n, 40e-6) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", bottom_vals)}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + assert "standard:bottom" in result + np.testing.assert_array_equal(result["standard:bottom"].values, bottom_vals) + mock_opr.get_layers.assert_called_once() + + +def test_ensure_picks_loads_both(): + n = 50 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"Data": ("slow_time", np.ones(n))}, + coords={"slow_time": slow_time}, + ) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + assert "standard:surface" in result + assert "standard:bottom" in result + + +def test_ensure_picks_no_layers_raises(): + ds = xr.Dataset({"Data": ("slow_time", [1.0])}, coords={"slow_time": [0.0]}) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = None + with pytest.raises(ValueError, match="No layer data"): + ensure_picks(ds, opr=mock_opr) + + +def test_ensure_picks_copy_semantics(): + n = 100 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"standard:surface": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + ensure_picks(ds, opr=mock_opr) + assert "standard:bottom" not in ds + + +def test_ensure_picks_alias_bottom(): + """Layers with ':bottom' alias are stored as 'standard:bottom'.""" + n = 50 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"standard:surface": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ) + bottom_vals = np.full(n, 40e-6) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ), + ":bottom": xr.Dataset( + {"twtt": ("slow_time", bottom_vals)}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + assert "standard:bottom" in result + assert ":bottom" not in result + np.testing.assert_array_equal(result["standard:bottom"].values, bottom_vals) + + +def test_ensure_picks_alias_surface(): + """Layers with ':surface' alias are stored as 'standard:surface'.""" + n = 50 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"standard:bottom": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ) + surface_vals = np.full(n, 10e-6) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + ":surface": xr.Dataset( + {"twtt": ("slow_time", surface_vals)}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + assert "standard:surface" in result + assert ":surface" not in result + np.testing.assert_array_equal(result["standard:surface"].values, surface_vals) + + +def test_ensure_picks_alias_both(): + """Both layers via aliases are stored under canonical names.""" + n = 30 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"Data": ("slow_time", np.ones(n))}, + coords={"slow_time": slow_time}, + ) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + ":surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 10e-6))}, + coords={"slow_time": slow_time}, + ), + ":bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + assert "standard:surface" in result + assert "standard:bottom" in result + + +def test_ensure_picks_prefers_canonical(): + """Canonical name is preferred over alias when both exist.""" + n = 30 + slow_time = np.arange(n, dtype=float) + ds = xr.Dataset( + {"Data": ("slow_time", np.ones(n))}, + coords={"slow_time": slow_time}, + ) + canonical_vals = np.full(n, 10e-6) + alias_vals = np.full(n, 99e-6) + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", canonical_vals)}, + coords={"slow_time": slow_time}, + ), + ":surface": xr.Dataset( + {"twtt": ("slow_time", alias_vals)}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + result = ensure_picks(ds, opr=mock_opr) + np.testing.assert_array_equal(result["standard:surface"].values, canonical_vals) diff --git a/src/xopr/qc/tests/test_runner.py b/src/xopr/qc/tests/test_runner.py new file mode 100644 index 0000000..cbffc3d --- /dev/null +++ b/src/xopr/qc/tests/test_runner.py @@ -0,0 +1,128 @@ +"""Tests for QC runner.""" + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import xarray as xr + +from xopr.qc.checks import _apply_qc_mask +from xopr.qc.runner import run_qc + + +@pytest.fixture +def synthetic_ds(): + n_traces = 100 + n_samples = 200 + twtt = np.linspace(0, 50e-6, n_samples) + slow_time = np.arange(n_traces, dtype=float) + + return xr.Dataset( + { + "Data": (["slow_time", "twtt"], np.random.rand(n_traces, n_samples)), + "standard:surface": ("slow_time", np.full(n_traces, 10e-6)), + "standard:bottom": ("slow_time", np.full(n_traces, 40e-6)), + "Latitude": ("slow_time", np.linspace(-75, -74, n_traces)), + "Longitude": ("slow_time", np.linspace(100, 101, n_traces)), + "Heading": ("slow_time", np.zeros(n_traces)), + }, + coords={"slow_time": slow_time, "twtt": twtt}, + ) + + +def test_run_qc_default(synthetic_ds): + result = run_qc(synthetic_ds) + assert "qc" in result + assert "qc_ice_thickness_threshold" in result + assert "qc_snr_bed_pick" in result + assert "qc_heading_change" in result + assert "qc_minimum_agl" in result + + +def test_run_qc_with_params(synthetic_ds): + result = run_qc( + synthetic_ds, + checks={"ice_thickness_threshold": {"min_thickness_m": 0}}, + ) + assert result["qc"].all() + + +def test_run_qc_invalid_check(synthetic_ds): + with pytest.raises(ValueError, match="Unknown QC check"): + run_qc(synthetic_ds, checks={"nonexistent_check": {}}) + + +def test_run_qc_callable_key(synthetic_ds): + """A callable can be used as a check key.""" + + def always_pass(ds): + mask = xr.DataArray( + np.ones(ds.sizes["slow_time"], dtype=bool), dims="slow_time" + ) + return _apply_qc_mask(ds, mask, "always_pass") + + result = run_qc(synthetic_ds, checks={always_pass: {}}) + assert "qc_always_pass" in result + assert result["qc"].all() + + +def test_run_qc_mixed_keys(synthetic_ds): + """String and callable keys can be mixed.""" + + def flag_none(ds): + mask = xr.DataArray( + np.zeros(ds.sizes["slow_time"], dtype=bool), dims="slow_time" + ) + return _apply_qc_mask(ds, mask, "flag_none") + + result = run_qc( + synthetic_ds, + checks={ + "ice_thickness_threshold": {"min_thickness_m": 0}, + flag_none: {}, + }, + ) + # ice_thickness passes all, flag_none fails all → AND is all False + assert not result["qc"].any() + + +def test_run_qc_bad_key_type(synthetic_ds): + with pytest.raises(TypeError, match="strings or callables"): + run_qc(synthetic_ds, checks={42: {}}) + + +def test_run_qc_auto_loads_picks(): + """run_qc loads picks via opr when they are missing.""" + n_traces = 50 + n_samples = 100 + slow_time = np.arange(n_traces, dtype=float) + ds = xr.Dataset( + { + "Data": (["slow_time", "twtt"], np.random.rand(n_traces, n_samples)), + "Latitude": ("slow_time", np.linspace(-75, -74, n_traces)), + "Longitude": ("slow_time", np.linspace(100, 101, n_traces)), + "Heading": ("slow_time", np.zeros(n_traces)), + }, + coords={ + "slow_time": slow_time, + "twtt": np.linspace(0, 50e-6, n_samples), + }, + ) + + mock_opr = MagicMock() + mock_opr.get_layers.return_value = { + "standard:surface": xr.Dataset( + {"twtt": ("slow_time", np.full(n_traces, 10e-6))}, + coords={"slow_time": slow_time}, + ), + "standard:bottom": xr.Dataset( + {"twtt": ("slow_time", np.full(n_traces, 40e-6))}, + coords={"slow_time": slow_time}, + ), + } + + result = run_qc(ds, opr=mock_opr) + assert "qc" in result + assert "standard:bottom" in result + assert "standard:surface" in result + mock_opr.get_layers.assert_called_once()