diff --git a/docs/adding_components.md b/docs/adding_components.md index bfcb6f6a..6571cb2f 100644 --- a/docs/adding_components.md +++ b/docs/adding_components.md @@ -147,4 +147,3 @@ pytest tests/my_component_test.py -v - [ ] Create `docs/my_component.md` - [ ] Add to `docs/_toc.yml` - [ ] Update reference tables in `hybrid_plant.md` and `component_types.md` - diff --git a/hercules/resource/resource_utilities.py b/hercules/resource/resource_utilities.py new file mode 100644 index 00000000..bcdf09ed --- /dev/null +++ b/hercules/resource/resource_utilities.py @@ -0,0 +1,570 @@ +"""Shared utilities for resource data downloading and visualization. + +This module provides common functions used by the NSRDB, WTK, and Open-Meteo +resource downloaders, including time parameter validation, data I/O, +elapsed time formatting, and plotting. +""" + +import math +import os +import time +from typing import List, Optional + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from hercules.utilities import hercules_float_type +from rex import ResourceX +from scipy.interpolate import griddata + + +def validate_time_params( + year: Optional[int], + start_date: Optional[str], + end_date: Optional[str], +) -> dict: + """Validate time parameters and compute derived time information. + + Ensures that either ``year`` or both ``start_date`` and ``end_date`` are + provided (but not both). Returns file_years, time_suffix, + time_description, and resolved start_date/end_date values. + + Args: + year (int, optional): Year of data to download. + start_date (str, optional): Start date in 'YYYY-MM-DD' format. + end_date (str, optional): End date in 'YYYY-MM-DD' format. + + Returns: + dict: Dictionary with keys: + - file_years (list[int]): Years spanned by the time range. + - time_suffix (str): Filename-safe suffix for the time range. + - time_description (str): Human-readable time range description. 
+ - start_date (str): Resolved start date string. + - end_date (str): Resolved end date string. + + Raises: + ValueError: If the parameter combination is invalid or if + start_date > end_date. + """ + if year is not None and (start_date is not None or end_date is not None): + raise ValueError( + "Please provide either 'year' OR both 'start_date' and 'end_date', not both approaches." + ) + + if year is None and (start_date is None or end_date is None): + raise ValueError("Please provide either 'year' OR both 'start_date' and 'end_date'.") + + if year is not None: + return { + "file_years": [year], + "time_suffix": str(year), + "time_description": f"year {year}", + "start_date": f"{year}-01-01", + "end_date": f"{year}-12-31", + } + + start_dt = pd.to_datetime(start_date) + end_dt = pd.to_datetime(end_date) + + if start_dt > end_dt: + raise ValueError("start_date must be before end_date") + + return { + "file_years": list(range(start_dt.year, end_dt.year + 1)), + "time_suffix": f"{start_date}_to_{end_date}".replace("-", ""), + "time_description": f"period {start_date} to {end_date}", + "start_date": start_date, + "end_date": end_date, + } + + +def create_bounding_box( + target_lat: float, + target_lon: float, + coord_delta: float, +) -> tuple: + """Create a bounding box from a center point and coordinate delta. + + Args: + target_lat (float): Center latitude coordinate. + target_lon (float): Center longitude coordinate. + coord_delta (float): Half-width of the bounding box in degrees. + + Returns: + tuple: (llcrn_lat, llcrn_lon, urcrn_lat, urcrn_lon) lower-left and + upper-right corners of the bounding box. 
+ """ + return ( + target_lat - coord_delta, + target_lon - coord_delta, + target_lat + coord_delta, + target_lon + coord_delta, + ) + + +def download_nrel_rex_data( + dataset_path: str, + dataset_filename_prefix: str, + source_name: str, + target_lat: float, + target_lon: float, + variables: List[str], + bounding_box: tuple, + file_years: List[int], + start_date: Optional[str], + end_date: Optional[str], + output_dir: str, + filename_prefix: str, + time_suffix: str, + time_description: str, + os_error_hint: str = "This could be caused by an invalid API key or date range.", +) -> dict: + """Download data from an NLR rex-based dataset (NSRDB or WTK). + + Handles the complete download workflow: fetching data via ResourceX for + each year, concatenating across years, converting to the hercules float + type, and saving to feather format. + + Args: + dataset_path (str): Base path of the dataset on the NLR HSDS server. + dataset_filename_prefix (str): Filename prefix for the HDF5 files + in the format ``{dataset_filename_prefix}_{year}.h5``. + source_name (str): Human-readable data source name (e.g., "NSRDB", + "WTK") used in log messages. + target_lat (float): Target latitude coordinate. + target_lon (float): Target longitude coordinate. + variables (list[str]): List of variables to download. + bounding_box (tuple): (llcrn_lat, llcrn_lon, urcrn_lat, urcrn_lon) + corners of the spatial bounding box. + file_years (list[int]): List of years to download. + start_date (str, optional): Start date for filtering. If None, no + date filtering is applied. + end_date (str, optional): End date for filtering. If None, no date + filtering is applied. + output_dir (str): Directory to save output feather files. + filename_prefix (str): Prefix for output filenames. + time_suffix (str): Suffix for output filenames encoding the time + range. + time_description (str): Human-readable time range for log messages. + os_error_hint (str, optional): Additional context for OSError + messages. 
Defaults to "This could be caused by an invalid API + key or date range." + + Returns: + dict: Dictionary containing DataFrames for each variable and a + "coordinates" key with lat/lon data. + + Raises: + OSError: If there is an error accessing the NLR HSDS server. + """ + llcrn_lat, llcrn_lon, urcrn_lat, urcrn_lon = bounding_box + + print(f"Downloading {source_name} data for {time_description}") + print(f"Target coordinates: ({target_lat}, {target_lon})") + print(f"Bounding box: ({llcrn_lat}, {llcrn_lon}) to ({urcrn_lat}, {urcrn_lon})") + print(f"Variables: {variables}") + print(f"Years to process: {file_years}") + + t0 = time.time() + + data_dict = {} + all_dataframes = {var: [] for var in variables} + + try: + for file_year in file_years: + print(f"\nProcessing year {file_year}...") + fp = f"{dataset_path}/{dataset_filename_prefix}_{file_year}.h5" + + with ResourceX(fp) as res: + for var in variables: + print(f" Downloading {var} for {file_year}...") + df_year = res.get_box_df( + var, + lat_lon_1=[llcrn_lat, llcrn_lon], + lat_lon_2=[urcrn_lat, urcrn_lon], + ) + + if start_date is not None and end_date is not None: + df_year = df_year.loc[start_date:end_date] + + all_dataframes[var].append(df_year) + + if "coordinates" not in data_dict: + gids = df_year.columns.values + coordinates = res.lat_lon[gids] + df_coords = pd.DataFrame(coordinates, index=gids, columns=["lat", "lon"]) + data_dict["coordinates"] = df_coords + + for var in variables: + if all_dataframes[var]: + print(f"Concatenating {var} data across {len(all_dataframes[var])} years...") + data_dict[var] = pd.concat(all_dataframes[var], axis=0).sort_index() + + for col in data_dict[var].columns: + if pd.api.types.is_numeric_dtype(data_dict[var][col]): + data_dict[var][col] = data_dict[var][col].astype(hercules_float_type) + + all_dataframes[var].clear() + + save_variable_to_feather( + data_dict[var], + output_dir, + filename_prefix, + var, + time_suffix, + ) + + save_coords_to_feather( + 
data_dict["coordinates"], + output_dir, + filename_prefix, + time_suffix, + ) + + except OSError as e: + print(f"Error downloading {source_name} data: {e}") + print(os_error_hint) + raise + except Exception as e: + print(f"Error downloading {source_name} data: {e}") + raise + + print_elapsed_time(t0, source_name) + + return data_dict + + +def save_variable_to_feather( + df: pd.DataFrame, + output_dir: str, + filename_prefix: str, + var_name: str, + time_suffix: str, +) -> str: + """Save a variable DataFrame to feather format. + + Args: + df (pd.DataFrame): DataFrame to save. + output_dir (str): Directory to save the file in. + filename_prefix (str): Prefix for the filename. + var_name (str): Variable name included in the filename. + time_suffix (str): Time range suffix included in the filename. + + Returns: + str: Path to the saved feather file. + """ + output_file = os.path.join(output_dir, f"{filename_prefix}_{var_name}_{time_suffix}.feather") + df.reset_index().to_feather(output_file) + print(f"Saved {var_name} data to {output_file}") + return output_file + + +def save_coords_to_feather( + df_coords: pd.DataFrame, + output_dir: str, + filename_prefix: str, + time_suffix: str, +) -> str: + """Save a coordinates DataFrame to feather format. + + Args: + df_coords (pd.DataFrame): Coordinates DataFrame with 'lat' and + 'lon' columns. + output_dir (str): Directory to save the file in. + filename_prefix (str): Prefix for the filename. + time_suffix (str): Time range suffix included in the filename. + + Returns: + str: Path to the saved feather file. + """ + coords_file = os.path.join(output_dir, f"{filename_prefix}_coords_{time_suffix}.feather") + df_coords.reset_index().to_feather(coords_file) + print(f"Saved coordinates to {coords_file}") + return coords_file + + +def print_elapsed_time(t0: float, source_name: str) -> None: + """Print elapsed time since t0 in minutes:seconds format. + + Args: + t0 (float): Start time from ``time.time()``. 
+ source_name (str): Name of the data source for the log message. + """ + total_time = (time.time() - t0) / 60 + decimal_part = math.modf(total_time)[0] * 60 + print( + f"{source_name} download completed in " + f"{int(np.floor(total_time))}:{int(np.round(decimal_part, 0)):02d}" + " minutes" + ) + + +def dispatch_plots( + data_dict: dict, + variables: List[str], + plot_data: bool, + plot_type: str, + title: str, +) -> None: + """Dispatch plotting based on the plot_data flag and plot_type. + + Args: + data_dict (dict): Dictionary containing DataFrames for each variable + and a "coordinates" key. + variables (list[str]): List of variable names to plot. + plot_data (bool): Whether to create plots. + plot_type (str): Type of plot: 'timeseries' or 'map'. + title (str): Title for the plots. + """ + if plot_data and data_dict and "coordinates" in data_dict: + coordinates_array = data_dict["coordinates"][["lat", "lon"]].values + if plot_type == "timeseries": + plot_timeseries(data_dict, variables, coordinates_array, title) + elif plot_type == "map": + plot_spatial_map(data_dict, variables, coordinates_array, title) + + +# --------------------------------------------------------------------------- +# Plotting functions +# --------------------------------------------------------------------------- + + +def plot_timeseries( + data_dict: dict, + variables: List[str], + coordinates: np.ndarray, + title: str, +): + """Create time-series plots for the downloaded data. + + Args: + data_dict (dict): Dictionary containing DataFrames for each variable. + variables (list[str]): List of variables to plot. + coordinates (np.ndarray): Array of coordinates for the data points. + title (str): Title for the plots. 
+ """ + n_vars = len(variables) + if n_vars == 0: + return + + fig, axes = plt.subplots(n_vars, 1, figsize=(12, 4 * n_vars), sharex=True) + if n_vars == 1: + axes = [axes] + + for i, var in enumerate(variables): + if var in data_dict: + df = data_dict[var] + + for col in df.columns: + axes[i].plot(df.index, df[col], alpha=0.7, linewidth=0.8) + + axes[i].set_ylabel(get_variable_label(var)) + axes[i].set_title(f"{var.replace('_', ' ').title()}") + axes[i].grid(True, alpha=0.3) + + axes[-1].set_xlabel("Time") + plt.suptitle(f"{title} - Time Series", fontsize=14, fontweight="bold") + plt.tight_layout() + + +def plot_spatial_map( + data_dict: dict, + variables: List[str], + coordinates: np.ndarray, + title: str, +): + """Create spatial maps showing the mean values across the region. + + Args: + data_dict (dict): Dictionary containing DataFrames for each variable. + variables (list[str]): List of variables to plot. + coordinates (np.ndarray): Array of coordinates for the data points. + title (str): Title for the plots. 
+ """ + n_vars = len(variables) + if n_vars == 0: + return + + n_cols = min(2, n_vars) + n_rows = math.ceil(n_vars / n_cols) + + plt.figure(figsize=(8 * n_cols, 6 * n_rows)) + + for i, var in enumerate(variables): + if var in data_dict: + df = data_dict[var] + + lats = coordinates[:, 0] + lons = coordinates[:, 1] + + mean_values = df.mean(axis=0).values + + ax = plt.subplot(n_rows, n_cols, i + 1, projection=ccrs.PlateCarree()) + + ax.add_feature(cfeature.COASTLINE, alpha=0.5) + ax.add_feature(cfeature.BORDERS, linestyle=":", alpha=0.5) + ax.add_feature( + cfeature.LAND, + edgecolor="black", + facecolor="lightgray", + alpha=0.3, + ) + ax.add_feature(cfeature.OCEAN, facecolor="lightblue", alpha=0.3) + + if len(lats) > 4: + grid_lon = np.linspace(min(lons), max(lons), 50) + grid_lat = np.linspace(min(lats), max(lats), 50) + grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat) + + try: + grid_values = griddata( + (lons, lats), + mean_values, + (grid_lon, grid_lat), + method="cubic", + ) + contour = ax.contourf( + grid_lon, + grid_lat, + grid_values, + levels=15, + cmap=get_variable_colormap(var), + transform=ccrs.PlateCarree(), + ) + plt.colorbar( + contour, + ax=ax, + orientation="vertical", + label=get_variable_label(var), + shrink=0.8, + ) + except Exception: + sc = ax.scatter( + lons, + lats, + c=mean_values, + s=100, + cmap=get_variable_colormap(var), + transform=ccrs.PlateCarree(), + ) + plt.colorbar( + sc, + ax=ax, + orientation="vertical", + label=get_variable_label(var), + shrink=0.8, + ) + else: + sc = ax.scatter( + lons, + lats, + c=mean_values, + s=100, + cmap=get_variable_colormap(var), + transform=ccrs.PlateCarree(), + ) + plt.colorbar( + sc, + ax=ax, + orientation="vertical", + label=get_variable_label(var), + shrink=0.8, + ) + + ax.scatter( + lons, + lats, + c="black", + s=20, + transform=ccrs.PlateCarree(), + alpha=0.8, + ) + + ax.set_title(f"{var.replace('_', ' ').title()}") + + ax.set_xticks(np.linspace(min(lons), max(lons), 5)) + 
ax.set_yticks(np.linspace(min(lats), max(lats), 5)) + ax.set_xticklabels( + [f"{lon:.2f}°" for lon in np.linspace(min(lons), max(lons), 5)], + fontsize=8, + ) + ax.set_yticklabels( + [f"{lat:.2f}°" for lat in np.linspace(min(lats), max(lats), 5)], + fontsize=8, + ) + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + + plt.suptitle( + f"{title} - Spatial Distribution (Time-Averaged)", + fontsize=14, + fontweight="bold", + ) + plt.tight_layout() + + +# --------------------------------------------------------------------------- +# Variable metadata helpers +# --------------------------------------------------------------------------- + + +def get_variable_label(variable: str) -> str: + """Get appropriate axis label with units for a variable. + + Args: + variable (str): Variable name. + + Returns: + str: Label with units for the variable. + """ + labels = { + "ghi": "GHI (W/m²)", + "dni": "DNI (W/m²)", + "dhi": "DHI (W/m²)", + "windspeed_100m": "Wind Speed at 100m (m/s)", + "winddirection_100m": "Wind Direction at 100m (°)", + "turbulent_kinetic_energy_100m": "TKE at 100m (m²/s²)", + "temperature_100m": "Temperature at 100m (°C)", + "pressure_100m": "Pressure at 100m (Pa)", + "wind_speed_80m": "Wind Speed at 80m (m/s)", + "windspeed_80m": "Wind Speed at 80m (m/s)", + "wind_direction_80m": "Wind Direction at 80m (°)", + "winddirection_80m": "Wind Direction at 80m (°)", + "temperature_2m": "Temperature at 2m (°C)", + "shortwave_radiation_instant": "Shortwave Radiation (W/m²)", + "diffuse_radiation_instant": "Diffuse Radiation (W/m²)", + "direct_normal_irradiance_instant": "Direct Normal Irradiance (W/m²)", + } + return labels.get(variable, variable.replace("_", " ").title()) + + +def get_variable_colormap(variable: str) -> str: + """Get appropriate matplotlib colormap name for a variable. + + Args: + variable (str): Variable name. + + Returns: + str: Matplotlib colormap name for the variable. 
+ """ + colormaps = { + "ghi": "plasma", + "dni": "plasma", + "dhi": "plasma", + "windspeed_100m": "viridis", + "winddirection_100m": "hsv", + "turbulent_kinetic_energy_100m": "cividis", + "temperature_100m": "RdYlBu_r", + "pressure_100m": "coolwarm", + "wind_speed_80m": "viridis", + "windspeed_80m": "viridis", + "wind_direction_80m": "hsv", + "winddirection_80m": "hsv", + "temperature_2m": "RdYlBu_r", + "shortwave_radiation_instant": "plasma", + "diffuse_radiation_instant": "plasma", + "direct_normal_irradiance_instant": "plasma", + } + return colormaps.get(variable, "viridis") diff --git a/hercules/resource/wind_solar_resource_downloader.py b/hercules/resource/wind_solar_resource_downloader.py index 51870963..610015cc 100644 --- a/hercules/resource/wind_solar_resource_downloader.py +++ b/hercules/resource/wind_solar_resource_downloader.py @@ -1,7 +1,6 @@ -""" -WTK, NSRDB, and Open-Meteo Data Downloader +"""WTK, NSRDB, and Open-Meteo Data Downloader -This script provides functions to download weather data from multiple sources: +This module provides functions to download weather data from multiple sources: - NLR's Wind Toolkit (WTK) for high-resolution wind data - NLR's National Solar Radiation Database (NSRDB) for solar irradiance data - Open-Meteo API for historical weather data with global coverage @@ -14,23 +13,40 @@ Updated: September 2025 (Added Open-Meteo support) """ -import math import os import time import warnings from typing import List, Optional -import cartopy.crs as ccrs -import cartopy.feature as cfeature -import matplotlib.pyplot as plt -import numpy as np import openmeteo_requests import pandas as pd import requests_cache +from hercules.resource.resource_utilities import ( + create_bounding_box, + dispatch_plots, + download_nrel_rex_data, + get_variable_colormap, + get_variable_label, + plot_spatial_map, + plot_timeseries, + print_elapsed_time, + save_coords_to_feather, + save_variable_to_feather, + validate_time_params, +) from 
hercules.utilities import hercules_float_type from retry_requests import retry -from rex import ResourceX -from scipy.interpolate import griddata + +# Re-export plotting utilities so existing callers can still import them here +__all__ = [ + "download_nsrdb_data", + "download_wtk_data", + "download_openmeteo_data", + "plot_timeseries", + "plot_spatial_map", + "get_variable_label", + "get_variable_colormap", +] def download_nsrdb_data( @@ -90,136 +106,32 @@ def download_nsrdb_data( allows for more flexible time periods than full year. Plots are not automatically shown. If plot_data is True, call matplotlib.pyplot.show() to display the figure. """ - - # Create output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) - # Validate input parameters - if year is not None and (start_date is not None or end_date is not None): - raise ValueError( - "Please provide either 'year' OR both 'start_date' and 'end_date', not both approaches." - ) - - if year is None and (start_date is None or end_date is None): - raise ValueError("Please provide either 'year' OR both 'start_date' and 'end_date'.") - - # Determine the approach and set up file paths and time info - if year is not None: - # Full year approach - file_years = [year] - time_suffix = str(year) - time_description = f"year {year}" - else: - # Date range approach - - start_dt = pd.to_datetime(start_date) - end_dt = pd.to_datetime(end_date) - - if start_dt > end_dt: - raise ValueError("start_date must be before end_date") - - # Get all years in the date range - file_years = list(range(start_dt.year, end_dt.year + 1)) - time_suffix = f"{start_date}_to_{end_date}".replace("-", "") - time_description = f"period {start_date} to {end_date}" - - # Create the bounding box - llcrn_lat = target_lat - coord_delta - llcrn_lon = target_lon - coord_delta - urcrn_lat = target_lat + coord_delta - urcrn_lon = target_lon + coord_delta - - print(f"Downloading NSRDB data for {time_description}") - print(f"Target coordinates: 
({target_lat}, {target_lon})") - print(f"Bounding box: ({llcrn_lat}, {llcrn_lon}) to ({urcrn_lat}, {urcrn_lon})") - print(f"Variables: {variables}") - print(f"Years to process: {file_years}") - - t0 = time.time() - - data_dict = {} - all_dataframes = {var: [] for var in variables} - - try: - # Process each year in the range - for file_year in file_years: - print(f"\nProcessing year {file_year}...") - fp = f"{nsrdb_dataset_path}/{nsrdb_filename_prefix}_{file_year}.h5" - - with ResourceX(fp) as res: - # Download each variable for this year - for var in variables: - print(f" Downloading {var} for {file_year}...") - df_year = res.get_box_df( - var, lat_lon_1=[llcrn_lat, llcrn_lon], lat_lon_2=[urcrn_lat, urcrn_lon] - ) - - # Filter by date range if using date range approach - if start_date is not None and end_date is not None: - # Filter the DataFrame to the specified date range - df_year = df_year.loc[start_date:end_date] - - all_dataframes[var].append(df_year) - - # Get coordinates (only need to do this once) - if "coordinates" not in data_dict: - gids = df_year.columns.values - coordinates = res.lat_lon[gids] - df_coords = pd.DataFrame(coordinates, index=gids, columns=["lat", "lon"]) - data_dict["coordinates"] = df_coords - - # Concatenate all years for each variable - for var in variables: - if all_dataframes[var]: - print(f"Concatenating {var} data across {len(all_dataframes[var])} years...") - data_dict[var] = pd.concat(all_dataframes[var], axis=0).sort_index() - - # Convert numeric columns to float32 for memory efficiency - for col in data_dict[var].columns: - if pd.api.types.is_numeric_dtype(data_dict[var][col]): - data_dict[var][col] = data_dict[var][col].astype(hercules_float_type) - - # Clear intermediate DataFrames to free memory - all_dataframes[var].clear() - - # Save to feather format - output_file = os.path.join( - output_dir, f"{filename_prefix}_{var}_{time_suffix}.feather" - ) - data_dict[var].reset_index().to_feather(output_file) - print(f"Saved {var} 
data to {output_file}") - - # Save coordinates - coords_file = os.path.join(output_dir, f"{filename_prefix}_coords_{time_suffix}.feather") - data_dict["coordinates"].reset_index().to_feather(coords_file) - print(f"Saved coordinates to {coords_file}") - - except OSError as e: - print(f"Error downloading NSRDB data: {e}") - print("This could be caused by an invalid API key, NSRDB dataset path, or date range.") - raise - except Exception as e: - print(f"Error downloading NSRDB data: {e}") - raise - - total_time = (time.time() - t0) / 60 - decimal_part = math.modf(total_time)[0] * 60 - print( - "NSRDB download completed in " - f"{int(np.floor(total_time))}:{int(np.round(decimal_part, 0)):02d} minutes" + time_params = validate_time_params(year, start_date, end_date) + bounding_box = create_bounding_box(target_lat, target_lon, coord_delta) + + data_dict = download_nrel_rex_data( + dataset_path=nsrdb_dataset_path, + dataset_filename_prefix=nsrdb_filename_prefix, + source_name="NSRDB", + target_lat=target_lat, + target_lon=target_lon, + variables=variables, + bounding_box=bounding_box, + file_years=time_params["file_years"], + start_date=start_date, + end_date=end_date, + output_dir=output_dir, + filename_prefix=filename_prefix, + time_suffix=time_params["time_suffix"], + time_description=time_params["time_description"], + os_error_hint=( + "This could be caused by an invalid API key, NSRDB dataset path, or date range." 
+ ), ) - # Create plots if requested - if plot_data and data_dict and "coordinates" in data_dict: - coordinates_array = data_dict["coordinates"][["lat", "lon"]].values - if plot_type == "timeseries": - plot_timeseries( - data_dict, variables, coordinates_array, f"{filename_prefix} NSRDB Data" - ) - elif plot_type == "map": - plot_spatial_map( - data_dict, variables, coordinates_array, f"{filename_prefix} NSRDB Data" - ) + dispatch_plots(data_dict, variables, plot_data, plot_type, f"{filename_prefix} NSRDB Data") return data_dict @@ -273,134 +185,53 @@ def download_wtk_data( allows for more flexible time periods than full year. Plots are not automatically shown. If plot_data is True, call matplotlib.pyplot.show() to display the figure. """ - - # Create output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) - # Validate input parameters - if year is not None and (start_date is not None or end_date is not None): - raise ValueError( - "Please provide either 'year' OR both 'start_date' and 'end_date', not both approaches." 
- ) - - if year is None and (start_date is None or end_date is None): - raise ValueError("Please provide either 'year' OR both 'start_date' and 'end_date'.") - - # Determine the approach and set up file paths and time info - if year is not None: - # Full year approach - file_years = [year] - time_suffix = str(year) - time_description = f"year {year}" - else: - # Date range approach - - start_dt = pd.to_datetime(start_date) - end_dt = pd.to_datetime(end_date) - - if start_dt > end_dt: - raise ValueError("start_date must be before end_date") - - # Get all years in the date range - file_years = list(range(start_dt.year, end_dt.year + 1)) - time_suffix = f"{start_date}_to_{end_date}".replace("-", "") - time_description = f"period {start_date} to {end_date}" - - # Create the bounding box - llcrn_lat = target_lat - coord_delta - llcrn_lon = target_lon - coord_delta - urcrn_lat = target_lat + coord_delta - urcrn_lon = target_lon + coord_delta - - print(f"Downloading WTK data for {time_description}") - print(f"Target coordinates: ({target_lat}, {target_lon})") - print(f"Bounding box: ({llcrn_lat}, {llcrn_lon}) to ({urcrn_lat}, {urcrn_lon})") - print(f"Variables: {variables}") - print(f"Years to process: {file_years}") - - t0 = time.time() + time_params = validate_time_params(year, start_date, end_date) + bounding_box = create_bounding_box(target_lat, target_lon, coord_delta) + + data_dict = download_nrel_rex_data( + dataset_path="/nrel/wtk/wtk-led/conus/v1.0.0/5min", + dataset_filename_prefix="wtk_conus", + source_name="WTK", + target_lat=target_lat, + target_lon=target_lon, + variables=variables, + bounding_box=bounding_box, + file_years=time_params["file_years"], + start_date=start_date, + end_date=end_date, + output_dir=output_dir, + filename_prefix=filename_prefix, + time_suffix=time_params["time_suffix"], + time_description=time_params["time_description"], + os_error_hint="This could be caused by an invalid API key or date range.", + ) - data_dict = {} - 
all_dataframes = {var: [] for var in variables} + dispatch_plots(data_dict, variables, plot_data, plot_type, f"{filename_prefix} WTK Data") - try: - # Process each year in the range - for file_year in file_years: - print(f"\nProcessing year {file_year}...") - fp = f"/nrel/wtk/wtk-led/conus/v1.0.0/5min/wtk_conus_{file_year}.h5" - - with ResourceX(fp) as res: - # Download each variable for this year - for var in variables: - print(f" Downloading {var} for {file_year}...") - df_year = res.get_box_df( - var, lat_lon_1=[llcrn_lat, llcrn_lon], lat_lon_2=[urcrn_lat, urcrn_lon] - ) - - # Filter by date range if using date range approach - if start_date is not None and end_date is not None: - # Filter the DataFrame to the specified date range - df_year = df_year.loc[start_date:end_date] - - all_dataframes[var].append(df_year) - - # Get coordinates (only need to do this once) - if "coordinates" not in data_dict: - gids = df_year.columns.values - coordinates = res.lat_lon[gids] - df_coords = pd.DataFrame(coordinates, index=gids, columns=["lat", "lon"]) - data_dict["coordinates"] = df_coords - - # Concatenate all years for each variable - for var in variables: - if all_dataframes[var]: - print(f"Concatenating {var} data across {len(all_dataframes[var])} years...") - data_dict[var] = pd.concat(all_dataframes[var], axis=0).sort_index() - - # Convert numeric columns to float32 for memory efficiency - for col in data_dict[var].columns: - if pd.api.types.is_numeric_dtype(data_dict[var][col]): - data_dict[var][col] = data_dict[var][col].astype(hercules_float_type) - - # Clear intermediate DataFrames to free memory - all_dataframes[var].clear() - - # Save to feather format - output_file = os.path.join( - output_dir, f"{filename_prefix}_{var}_{time_suffix}.feather" - ) - data_dict[var].reset_index().to_feather(output_file) - print(f"Saved {var} data to {output_file}") - - # Save coordinates - coords_file = os.path.join(output_dir, f"{filename_prefix}_coords_{time_suffix}.feather") - 
data_dict["coordinates"].reset_index().to_feather(coords_file) - print(f"Saved coordinates to {coords_file}") - - except OSError as e: - print(f"Error downloading WTK data: {e}") - print("This could be caused by an invalid API key or date range.") - raise - except Exception as e: - print(f"Error downloading WTK data: {e}") - raise + return data_dict - total_time = (time.time() - t0) / 60 - decimal_part = math.modf(total_time)[0] * 60 - print( - "WTK download completed in " - f"{int(np.floor(total_time))}:{int(np.round(decimal_part, 0)):02d} minutes" - ) - # Create plots if requested - if plot_data and data_dict and "coordinates" in data_dict: - coordinates_array = data_dict["coordinates"][["lat", "lon"]].values - if plot_type == "timeseries": - plot_timeseries(data_dict, variables, coordinates_array, f"{filename_prefix} WTK Data") - elif plot_type == "map": - plot_spatial_map(data_dict, variables, coordinates_array, f"{filename_prefix} WTK Data") +# --------------------------------------------------------------------------- +# Open-Meteo variable mapping +# --------------------------------------------------------------------------- - return data_dict +OPENMETEO_VARIABLE_MAPPING = { + "wind_speed_80m": "wind_speed_80m", + "wind_direction_80m": "wind_direction_80m", + "temperature_2m": "temperature_2m", + "shortwave_radiation_instant": "shortwave_radiation_instant", + "diffuse_radiation_instant": "diffuse_radiation_instant", + "direct_normal_irradiance_instant": "direct_normal_irradiance_instant", + "ghi": "shortwave_radiation_instant", + "dni": "direct_normal_irradiance_instant", + "dhi": "diffuse_radiation_instant", + "windspeed_80m": "wind_speed_80m", + "winddirection_80m": "wind_direction_80m", +} +"""Mapping from user-facing variable names (including aliases) to Open-Meteo API parameter +names.""" def download_openmeteo_data( @@ -462,421 +293,219 @@ def download_openmeteo_data( spans from 1940 to present. Plots are not automatically shown. 
If plot_data is True, call matplotlib.pyplot.show() to display the figure. """ - - # Create output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) - # Validate input parameters - if year is not None and (start_date is not None or end_date is not None): - raise ValueError( - "Please provide either 'year' OR both 'start_date' and 'end_date', not both approaches." - ) - - if year is None and (start_date is None or end_date is None): - raise ValueError("Please provide either 'year' OR both 'start_date' and 'end_date'.") - - # Determine the approach and set up time info - if year is not None: - start_date = f"{year}-01-01" - end_date = f"{year}-12-31" - time_suffix = str(year) - time_description = f"year {year}" - else: - start_dt = pd.to_datetime(start_date) - end_dt = pd.to_datetime(end_date) - - if start_dt > end_dt: - raise ValueError("start_date must be before end_date") - - time_suffix = f"{start_date}_to_{end_date}".replace("-", "") - time_description = f"period {start_date} to {end_date}" + time_params = validate_time_params(year, start_date, end_date) + time_suffix = time_params["time_suffix"] + time_description = time_params["time_description"] + api_start_date = time_params["start_date"] + api_end_date = time_params["end_date"] print(f"Downloading Open-Meteo data for {time_description}") print(f"Target coordinates: ({target_lat}, {target_lon})") print(f"Variables: {variables}") print("Note: Open-Meteo provides point data (coord_delta ignored)") - # Map variable names to Open-Meteo API parameters - variable_mapping = { - "wind_speed_80m": "wind_speed_80m", - "wind_direction_80m": "wind_direction_80m", - "temperature_2m": "temperature_2m", - "shortwave_radiation_instant": "shortwave_radiation_instant", - "diffuse_radiation_instant": "diffuse_radiation_instant", - "direct_normal_irradiance_instant": "direct_normal_irradiance_instant", - "ghi": "shortwave_radiation_instant", # Alias for solar users - "dni": "direct_normal_irradiance_instant", 
# Alias for solar users - "dhi": "diffuse_radiation_instant", # Alias for solar users - "windspeed_80m": "wind_speed_80m", # Alias for wind users - "winddirection_80m": "wind_direction_80m", # Alias for wind users - } - - # Validate variables and map them - mapped_variables = [] - for var in variables: - if var in variable_mapping: - mapped_variables.append(variable_mapping[var]) - else: - print(f"Warning: Variable '{var}' not available in Open-Meteo. Skipping.") - - if not mapped_variables: - raise ValueError("No valid variables found for Open-Meteo download.") + mapped_variables = _map_openmeteo_variables(variables) t0 = time.time() try: - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) - - # Setup API parameters - url = "https://historical-forecast-api.open-meteo.com/v1/forecast" - params = { - "latitude": target_lat, - "longitude": target_lon, - "start_date": start_date, - "end_date": end_date, - "minutely_15": mapped_variables, - "wind_speed_unit": "ms", - } - - # Try to make the API request with SSL verification first, then fallback to no verification - try: - responses = openmeteo.weather_api(url, params=params) - print("API request successful with SSL verification.") - except Exception as e: - print(f"SSL verification failed: {str(e)[:100]}...") - print("Trying with SSL verification disabled...") - - # Suppress SSL warnings since we're intentionally disabling verification - warnings.filterwarnings("ignore", message="Unverified HTTPS request") - - # Create a new session with SSL verification disabled - cache_session_no_ssl = requests_cache.CachedSession(".cache", expire_after=3600) - cache_session_no_ssl.verify = False - retry_session_no_ssl = retry(cache_session_no_ssl, retries=5, backoff_factor=0.2) - openmeteo_no_ssl = 
openmeteo_requests.Client(session=retry_session_no_ssl) - - responses = openmeteo_no_ssl.weather_api(url, params=params) - print("API request successful with SSL verification disabled.") - - # Create data dictionary in the same format as WTK/NSRDB and initialize dataframes - data_dict = {} - data_dict["coordinates"] = pd.DataFrame() - - # Initialize for each variable - original_var_names = [] - for var in mapped_variables: - # Use original variable name (not mapped name) for consistency - original_var_name = None - for orig, mapped in variable_mapping.items(): - if mapped == var and orig in variables: - original_var_name = orig - break - - var_name = original_var_name if original_var_name else var - data_dict[var_name] = pd.DataFrame() - - original_var_names.append(var_name) - - # Process the responses for each lat/lon - for gid, response in enumerate(responses): - print(f"Coordinates retrieved: {response.Latitude()}°N {response.Longitude()}°E") - print(f"Elevation: {response.Elevation()} m asl") - - # Process minutely_15 data - minutely_15 = response.Minutely15() - - # Create the date range - date_range = pd.date_range( - start=pd.to_datetime(minutely_15.Time(), unit="s", utc=True), - end=pd.to_datetime(minutely_15.TimeEnd(), unit="s", utc=True), - freq=pd.Timedelta(seconds=minutely_15.Interval()), - inclusive="left", - ) - - # Create coordinates DataFrame (single point, but match the format) - # Use a synthetic GID (grid ID) to match WTK/NSRDB format - df_coords = pd.DataFrame( - [[response.Latitude(), response.Longitude()]], index=[gid], columns=["lat", "lon"] - ) - data_dict["coordinates"] = pd.concat([data_dict["coordinates"], df_coords], axis=0) - - # Process each requested variable - for i, var_name in enumerate(original_var_names): - var_data = minutely_15.Variables(i).ValuesAsNumpy() - - # Create DataFrame with same structure as WTK/NSRDB (datetime index, gid columns) - # Convert to float32 for memory efficiency - df_var = pd.DataFrame( - 
var_data.astype(hercules_float_type), index=date_range, columns=[gid] - ) - df_var.index.name = "time_index" - - data_dict[var_name] = pd.concat([data_dict[var_name], df_var], axis=1) - - # Check for duplicates, remove if any exist, and rename locations indices consecutively - if remove_duplicate_coords & (len(data_dict["coordinates"]) > 1): - duplicate_mask = data_dict["coordinates"].duplicated( - subset=["lat", "lon"], keep="first" - ) - data_dict["coordinates"] = data_dict["coordinates"][~duplicate_mask] + responses = _fetch_openmeteo_responses( + target_lat, target_lon, api_start_date, api_end_date, mapped_variables + ) - for var_name in original_var_names: - data_dict[var_name] = data_dict[var_name][ - [c for c in data_dict["coordinates"].index] - ] - data_dict[var_name].columns = range(len(data_dict["coordinates"])) + data_dict, original_var_names = _process_openmeteo_responses( + responses, mapped_variables, variables + ) - data_dict["coordinates"] = data_dict["coordinates"].reset_index(drop=True) + if remove_duplicate_coords and len(data_dict["coordinates"]) > 1: + _remove_duplicate_coordinates(data_dict, original_var_names) - # Save variables to feather format for var_name in original_var_names: - output_file = os.path.join( - output_dir, f"{filename_prefix}_{var_name}_{time_suffix}.feather" + save_variable_to_feather( + data_dict[var_name], output_dir, filename_prefix, var_name, time_suffix ) - data_dict[var_name].reset_index().to_feather(output_file) - print(f"Saved {var_name} data to {output_file}") - # Save coordinates - coords_file = os.path.join(output_dir, f"{filename_prefix}_coords_{time_suffix}.feather") - data_dict["coordinates"].reset_index().to_feather(coords_file) - print(f"Saved coordinates to {coords_file}") + save_coords_to_feather(data_dict["coordinates"], output_dir, filename_prefix, time_suffix) except Exception as e: print(f"Error downloading Open-Meteo data: {e}") raise - total_time = (time.time() - t0) / 60 - decimal_part = 
def _map_openmeteo_variables(variables: List[str]) -> List[str]:
    """Map user-facing variable names to Open-Meteo API parameter names.

    Unknown variables are skipped with a printed warning rather than raising,
    so a request mixing valid and invalid names still proceeds with the valid
    subset. Aliases (e.g. "ghi") resolve to their Open-Meteo parameter names
    via the module-level OPENMETEO_VARIABLE_MAPPING table.

    Args:
        variables (list[str]): List of user-facing variable names.

    Returns:
        list[str]: Mapped Open-Meteo API parameter names, in the same order
            as the recognized entries of ``variables``.

    Raises:
        ValueError: If no valid variables remain after mapping.
    """
    mapped_variables = []
    for var in variables:
        if var in OPENMETEO_VARIABLE_MAPPING:
            mapped_variables.append(OPENMETEO_VARIABLE_MAPPING[var])
        else:
            print(f"Warning: Variable '{var}' not available in Open-Meteo. Skipping.")

    if not mapped_variables:
        raise ValueError("No valid variables found for Open-Meteo download.")

    return mapped_variables
def _fetch_openmeteo_responses(
    target_lat: float | List[float],
    target_lon: float | List[float],
    start_date: str,
    end_date: str,
    mapped_variables: list,
) -> list:
    """Fetch 15-minute data from the Open-Meteo API with an SSL fallback.

    The first attempt uses a cached, retrying session with SSL verification
    enabled. If it raises, a second attempt is made with verification
    disabled.

    NOTE(review): the fallback triggers on *any* exception, not just SSL
    errors, so e.g. an HTTP failure is also retried without verification and
    reported as "SSL verification failed" — confirm this is intended.

    Args:
        target_lat (float | list[float]): Target latitude(s).
        target_lon (float | list[float]): Target longitude(s).
        start_date (str): Start date in 'YYYY-MM-DD' format.
        end_date (str): End date in 'YYYY-MM-DD' format.
        mapped_variables (list): List of Open-Meteo API parameter names.

    Returns:
        list: List of Open-Meteo API response objects.
    """

    def _build_client(verify_ssl: bool):
        # Cached session (1 h expiry) wrapped with exponential-backoff retries.
        session = requests_cache.CachedSession(".cache", expire_after=3600)
        if not verify_ssl:
            session.verify = False
        return openmeteo_requests.Client(
            session=retry(session, retries=5, backoff_factor=0.2)
        )

    endpoint = "https://historical-forecast-api.open-meteo.com/v1/forecast"
    request_params = {
        "latitude": target_lat,
        "longitude": target_lon,
        "start_date": start_date,
        "end_date": end_date,
        "minutely_15": mapped_variables,
        "wind_speed_unit": "ms",
    }

    try:
        responses = _build_client(verify_ssl=True).weather_api(endpoint, params=request_params)
        print("API request successful with SSL verification.")
    except Exception as e:
        print(f"SSL verification failed: {str(e)[:100]}...")
        print("Trying with SSL verification disabled...")

        # Suppress the warning urllib3 emits for intentionally unverified requests.
        warnings.filterwarnings("ignore", message="Unverified HTTPS request")

        responses = _build_client(verify_ssl=False).weather_api(endpoint, params=request_params)
        print("API request successful with SSL verification disabled.")

    return responses
def _process_openmeteo_responses(
    responses: list,
    mapped_variables: list,
    original_variables: List[str],
) -> tuple:
    """Process Open-Meteo API responses into a data dictionary.

    Builds one DataFrame per variable (datetime index, one column per grid
    point) plus a "coordinates" DataFrame, matching the WTK/NSRDB format.

    Args:
        responses (list): List of Open-Meteo API response objects.
        mapped_variables (list): List of mapped Open-Meteo API parameter names.
        original_variables (list[str]): Original user-facing variable names.

    Returns:
        tuple: (data_dict, original_var_names) where data_dict contains DataFrames for each
            variable and coordinates, and original_var_names is the list of variable names used
            as keys in data_dict.
    """
    data_dict = {"coordinates": pd.DataFrame()}

    # Resolve each mapped API parameter back to the user-facing name that
    # requested it, so output keys match what the caller asked for. The first
    # alias found in OPENMETEO_VARIABLE_MAPPING that is present in
    # original_variables wins.
    # NOTE(review): if two requested aliases map to the same API parameter,
    # both resolve to the same key here — confirm duplicate aliases are not
    # expected in practice.
    original_var_names = []
    for var in mapped_variables:
        original_var_name = None
        for orig, mapped in OPENMETEO_VARIABLE_MAPPING.items():
            if mapped == var and orig in original_variables:
                original_var_name = orig
                break

        # Fall back to the mapped name if no requested alias matched.
        var_name = original_var_name if original_var_name else var
        data_dict[var_name] = pd.DataFrame()
        original_var_names.append(var_name)

    # One response per requested point; gid is a synthetic grid ID used as the
    # column label, mirroring the WTK/NSRDB layout.
    for gid, response in enumerate(responses):
        print(f"Coordinates retrieved: {response.Latitude()}°N {response.Longitude()}°E")
        print(f"Elevation: {response.Elevation()} m asl")

        minutely_15 = response.Minutely15()

        # Build the UTC timestamp index from the response's start/end epoch
        # seconds and interval; 'left' excludes the end bound.
        date_range = pd.date_range(
            start=pd.to_datetime(minutely_15.Time(), unit="s", utc=True),
            end=pd.to_datetime(minutely_15.TimeEnd(), unit="s", utc=True),
            freq=pd.Timedelta(seconds=minutely_15.Interval()),
            inclusive="left",
        )

        df_coords = pd.DataFrame(
            [[response.Latitude(), response.Longitude()]], index=[gid], columns=["lat", "lon"]
        )
        data_dict["coordinates"] = pd.concat([data_dict["coordinates"], df_coords], axis=0)

        # NOTE(review): assumes the response returns variables positionally in
        # the same order they were requested (mapped_variables), so index i
        # lines up with original_var_names[i] — confirm against the
        # openmeteo_requests client contract.
        for i, var_name in enumerate(original_var_names):
            var_data = minutely_15.Variables(i).ValuesAsNumpy()

            # Cast to the project float type for memory efficiency.
            df_var = pd.DataFrame(
                var_data.astype(hercules_float_type), index=date_range, columns=[gid]
            )
            df_var.index.name = "time_index"

            data_dict[var_name] = pd.concat([data_dict[var_name], df_var], axis=1)

    return data_dict, original_var_names
""" - colormaps = { - "ghi": "plasma", - "dni": "plasma", - "dhi": "plasma", - "windspeed_100m": "viridis", - "winddirection_100m": "hsv", - "turbulent_kinetic_energy_100m": "cividis", - "temperature_100m": "RdYlBu_r", - "pressure_100m": "coolwarm", - # Open-Meteo variables - "wind_speed_80m": "viridis", - "windspeed_80m": "viridis", - "wind_direction_80m": "hsv", - "winddirection_80m": "hsv", - "temperature_2m": "RdYlBu_r", - "shortwave_radiation_instant": "plasma", - "diffuse_radiation_instant": "plasma", - "direct_normal_irradiance_instant": "plasma", - } - return colormaps.get(variable, "viridis") + duplicate_mask = data_dict["coordinates"].duplicated(subset=["lat", "lon"], keep="first") + data_dict["coordinates"] = data_dict["coordinates"][~duplicate_mask] + + for var_name in original_var_names: + data_dict[var_name] = data_dict[var_name][[c for c in data_dict["coordinates"].index]] + data_dict[var_name].columns = range(len(data_dict["coordinates"])) + + data_dict["coordinates"] = data_dict["coordinates"].reset_index(drop=True)