"""Open Government Data API helpers."""
# Standard library
import dataclasses as dc
import datetime as dt
import enum
import hashlib
import logging
import os
import re
import typing
from functools import lru_cache
from pathlib import Path
from urllib.parse import urlparse
from uuid import UUID
# Third-party
import earthkit.data as ekd # type: ignore
import pydantic
import pydantic.dataclasses as pdc
import xarray as xr
# Local
from . import data_source, grib_decoder, icon_grid, util
API_URL = "https://data.geo.admin.ch/api/stac/v1"
logger = logging.getLogger(__name__)
session = util.init_session(logger)
class Collection(str, enum.Enum):
#: Collection of icon-ch1-eps model outputs
ICON_CH1 = "ogd-forecasting-icon-ch1"
#: Collection of icon-ch2-eps model outputs
ICON_CH2 = "ogd-forecasting-icon-ch2"
def _forecast_prefix(field_name):
if field_name in {
"variable",
"reference_datetime",
"perturbed",
"horizon",
}:
return f"forecast:{field_name}"
return field_name
def _parse_datetime(value: str) -> dt.datetime:
return dt.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S%z")
@pdc.dataclass(
frozen=True,
config=pydantic.ConfigDict(
use_enum_values=True,
ser_json_timedelta="iso8601",
alias_generator=pydantic.AliasGenerator(serialization_alias=_forecast_prefix),
),
)
class Request:
"""Define filters for the STAC Search API.
Parameters
----------
collection : Collection
Name of the STAC collection.
variable : str
Name of the variable following the DWD convention.
reference_datetime : str
Forecast reference datetime in ISO 8601 format.
Alias: ref_time
perturbed : bool
If true, retrieve ensemble forecast members.
Otherwise, retrieve deterministic (control) forecast.
horizon : datetime.timedelta or list[datetime.timedelta]
Lead time of the requested data.
Can be supplied as string in ISO 8601 format.
Alias: lead_time
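Examples
--------
A minimal sketch of a request; the variable name, reference time and
lead time below are illustrative placeholders:
>>> req = Request(
...     collection=Collection.ICON_CH1,
...     variable="T",
...     reference_datetime="2025-01-01T00:00:00Z",
...     perturbed=False,
...     horizon="PT6H",
... )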
"""
collection: Collection = dc.field(metadata=dict(exclude=True))
variable: str
reference_datetime: str = dc.field(
metadata=dict(
validation_alias=pydantic.AliasChoices("reference_datetime", "ref_time")
)
)
perturbed: bool
horizon: dt.timedelta | list[dt.timedelta] = dc.field(
metadata=dict(validation_alias=pydantic.AliasChoices("horizon", "lead_time"))
)
if typing.TYPE_CHECKING:
# https://github.com/pydantic/pydantic/issues/10266
def __init__(self, *args: typing.Any, **kwargs: typing.Any): ...
@pydantic.computed_field # type: ignore[misc]
@property
def collections(self) -> list[str]:
return ["ch.meteoschweiz." + str(self.collection)]
@pydantic.field_validator("reference_datetime", mode="wrap")
@classmethod
def valid_reference_datetime(
cls, input_value: typing.Any, handler: pydantic.ValidatorFunctionWrapHandler
) -> str:
if isinstance(input_value, dt.datetime):
if input_value.tzinfo is None:
logger.warning("Converting naive datetime from local time to UTC")
fmt = "%Y-%m-%dT%H:%M:%SZ" # Zulu isoformat
return input_value.astimezone(dt.timezone.utc).strftime(fmt)
value = handler(input_value)
if value == "latest":
return value
parts = value.split("/")
match parts:
case [v, ".."] | ["..", v] | [v]:
# open ended or single value
_parse_datetime(v)
case [v1, v2]:
# range
d1 = _parse_datetime(v1)
d2 = _parse_datetime(v2)
if d2 < d1:
raise ValueError("reference_datetime bounds inverted")
case _:
raise ValueError(f"Unable to parse reference_datetime: {value}")
return value
@pydantic.field_serializer("reference_datetime")
def serialize_reference_datetime(self, value: str):
if value == "latest":
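# "latest" is approximated by an open-ended range starting 48 hours in the past;
# the newest complete forecast is then selected in get_asset_urls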
cutoff = dt.datetime.now(tz=dt.timezone.utc) - dt.timedelta(hours=48)
fmt = "%Y-%m-%dT%H:%M:%SZ" # Zulu isoformat
return f"{cutoff.strftime(fmt)}/.."
return value
def dump(self):
exclude_fields = {}
if isinstance(self.horizon, list):
exclude_fields["horizon"] = True
root = pydantic.RootModel(self)
return root.model_dump(mode="json", by_alias=True, exclude=exclude_fields)
def _search(url: str, body: dict | None = None):
response = session.post(url, json=body)
response.raise_for_status()
obj = response.json()
result = []
for item in obj["features"]:
for asset in item["assets"].values():
result.append(asset["href"])
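# follow pagination: the search API returns a "next" link whose body is
# merged with the current request body for the follow-up POST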
for link in obj["links"]:
if link["rel"] == "next":
if link["method"] != "POST" or not link["merge"]:
raise RuntimeError(f"Bad link: {link}")
result.extend(_search(link["href"], body | link["body"]))
return result
def get_asset_urls(request: Request) -> list[str]:
"""Get asset URLs from OGD.
The request attributes define filters for the STAC search API according
to the forecast extension. Reference datetimes for which not all requested
lead times are available are excluded from the result.
Parameters
----------
request : Request
Asset search filters
Raises
------
ValueError
when no datetime can be found in the asset URL for 'latest' requests.
Returns
-------
list[str]
URLs of the selected assets.
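Examples
--------
A sketch of a request for several lead times; the values are illustrative
and the call requires network access:
>>> req = Request(
...     collection=Collection.ICON_CH1,
...     variable="T",
...     reference_datetime="latest",
...     perturbed=False,
...     horizon=["PT1H", "PT2H"],
... )
>>> urls = get_asset_urls(req)  # doctest: +SKIP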
"""
result = _search(f"{API_URL}/search", request.dump())
if len(result) == 1:
return result
lead_times = (
request.horizon if isinstance(request.horizon, list) else [request.horizon]
)
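# asset filenames embed the reference time (YYYYMMDDHHMM) and the lead time in hours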
pattern = re.compile(r"-(?P<ref_time>\d{12})-(?P<lead_time>\d+)-")
def extract_key(url: str) -> tuple[dt.datetime, dt.timedelta]:
path = urlparse(url).path
match = pattern.search(path)
if not match:
raise ValueError(f"No valid datetime found in URL path: {url}")
val = match.group("ref_time")
fmt = "%Y%m%d%H%M"
utc = dt.timezone.utc
ref_time = dt.datetime.strptime(val, fmt).replace(tzinfo=utc)
lead_time = dt.timedelta(hours=float(match.group("lead_time")))
return ref_time, lead_time
asset_map = {extract_key(url): url for url in result}
# gather reference times for which all requested lead times are present
tmp: dict[dt.datetime, list[dt.timedelta]] = {}
for ref_time, lead_time in asset_map:
tmp.setdefault(ref_time, []).append(lead_time)
required = set(lead_times)
complete = [ref_time for ref_time in tmp if set(tmp[ref_time]) >= required]
if request.reference_datetime == "latest":
ref_time = max(complete)
return [asset_map[(ref_time, lead_time)] for lead_time in lead_times]
return [
asset_map[(ref_time, lead_time)]
for lead_time in lead_times
for ref_time in complete
]
@lru_cache
def _get_collection_assets(collection_id: str):
url = f"{API_URL}/collections/{collection_id}/assets"
response = session.get(url)
response.raise_for_status()
return {asset["id"]: asset for asset in response.json().get("assets", [])}
def get_collection_asset_url(collection_id: str, asset_id: str) -> str:
"""Get collection asset URL from OGD.
Query the STAC collection assets and return the URL for the given asset ID.
Parameters
----------
collection_id : str
Full STAC collection ID
asset_id : str
The ID of the static asset to retrieve.
Returns
-------
str
The pre-signed URL of the requested static asset.
Raises
------
KeyError
If the asset is not found in the collection.
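Examples
--------
Illustrative call requiring network access; the IDs below follow the
naming pattern used elsewhere in this module:
>>> url = get_collection_asset_url(
...     "ch.meteoschweiz.ogd-forecasting-icon-ch1",
...     "horizontal_constants_icon-ch1-eps.grib2",
... )  # doctest: +SKIP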
"""
assets = _get_collection_assets(collection_id)
asset_info = assets.get(asset_id)
if not asset_info or "href" not in asset_info:
raise KeyError(f"Asset '{asset_id}' not found in collection '{collection_id}'.")
return asset_info["href"]
def _get_geo_coord_url(uuid: UUID) -> str:
if (var := os.environ.get("MDL_GEO_COORD_URL")) is not None:
return var
model_name = icon_grid.GRID_UUID_TO_MODEL.get(uuid)
if model_name is None:
raise KeyError("Grid UUID not found")
base_model = model_name.removesuffix("-eps")
collection_id = f"ch.meteoschweiz.ogd-forecasting-{base_model}"
asset_id = f"horizontal_constants_{model_name}.grib2"
return get_collection_asset_url(collection_id, asset_id)
def _no_coords(uuid: UUID) -> dict[str, xr.DataArray]:
return {}
def _geo_coords(uuid: UUID) -> dict[str, xr.DataArray]:
url = _get_geo_coord_url(uuid)
source = data_source.URLDataSource(urls=[url])
ds = grib_decoder.load(source, {"param": ["CLON", "CLAT"]}, geo_coords=_no_coords)
return {"lon": ds["CLON"].squeeze(), "lat": ds["CLAT"].squeeze()}
def get_from_ogd(request: Request) -> xr.DataArray:
"""Get data from OGD.
The request attributes define filters for the STAC search API according
to the forecast extension. It is recommended to enable caching through
earthkit-data. A warning message is emitted if the cache is disabled.
Parameters
----------
request : Request
Asset search filters.
Raises
------
ValueError
when no valid reference datetime can be found in an asset URL.
Returns
-------
xarray.DataArray
A data array of the selected asset including GRIB metadata and coordinates.
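Examples
--------
Illustrative usage requiring network access; the variable name and lead
time are placeholders:
>>> req = Request(
...     collection=Collection.ICON_CH1,
...     variable="T",
...     reference_datetime="latest",
...     perturbed=False,
...     horizon="PT0H",
... )
>>> da = get_from_ogd(req)  # doctest: +SKIP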
"""
if ekd.settings.get("cache-policy") == "off":
doc = "https://earthkit-data.readthedocs.io/en/latest/examples/cache.html"
logger.warning("Earthkit-data caching is recommended. See: %s", doc)
asset_urls = get_asset_urls(request)
source = data_source.URLDataSource(urls=asset_urls)
return grib_decoder.load(
source,
{"param": request.variable},
geo_coords=_geo_coords,
)[request.variable]
def download_from_ogd(request: Request, target: Path) -> None:
"""Download forecast asset and its static coordinate files from OGD.
The request attributes define filters for the STAC search API according
to the forecast extension.
In addition to the main asset, this function downloads the static files
containing the horizontal and vertical coordinates, since the forecast
items do not include them.
Parameters
----------
request : Request
Asset search filters.
target : Path
Target path where to save the asset, must be a directory.
Raises
------
ValueError
when the target exists but is not a directory.
RuntimeError
if the checksum verification fails.
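Examples
--------
Illustrative usage requiring network access; the request values and the
target directory are placeholders:
>>> from pathlib import Path
>>> req = Request(
...     collection=Collection.ICON_CH2,
...     variable="T",
...     reference_datetime="latest",
...     perturbed=False,
...     horizon="PT6H",
... )
>>> download_from_ogd(req, Path("forecasts"))  # doctest: +SKIP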
"""
if target.exists() and not target.is_dir():
raise ValueError(f"target: {target} must be a directory")
if not target.exists():
target.mkdir(parents=True)
# Download main forecast asset
asset_urls = get_asset_urls(request)
for asset_url in asset_urls:
_download_with_checksum(asset_url, target)
model_suffix = request.collection.removeprefix("ogd-forecasting-")
collection_id = f"ch.meteoschweiz.{request.collection}"
# Download coordinate files
for prefix in ["horizontal", "vertical"]:
asset_id = f"{prefix}_constants_{model_suffix}-eps.grib2"
url = get_collection_asset_url(collection_id, asset_id)
_download_with_checksum(url, target)
def _file_hash(path: Path):
hash = hashlib.sha256()
with path.open("rb") as f:
while chunk := f.read(16 * 1024):
hash.update(chunk)
return hash.hexdigest()
def _download_with_checksum(url: str, target: Path) -> None:
filename = Path(urlparse(url).path).name
path = target / filename if target.is_dir() else target
hash_path = path.with_suffix(".sha256")
if path.exists():
if hash_path.exists() and hash_path.read_text() == _file_hash(path):
logger.info(f"File already exists, skipping download: {path}")
return
response = session.get(url, stream=True)
response.raise_for_status()
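# the asset's expected SHA-256 digest may be published as S3 object metadata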
hash = response.headers.get("X-Amz-Meta-Sha256")
if hash is not None:
hash_path.write_text(hash)
hasher = hashlib.sha256()
with path.open("wb") as f:
for chunk in response.iter_content(16 * 1024):
f.write(chunk)
hasher.update(chunk)
if hash is not None and hash != hasher.hexdigest():
raise RuntimeError(f"Checksum verification failed for {filename}")