"""Open Government Data API helpers."""
# Standard library
import dataclasses as dc
import datetime as dt
import enum
import hashlib
import logging
import os
import re
import typing
from functools import lru_cache
from pathlib import Path
from urllib.parse import urlparse
from uuid import UUID
# Third-party
import earthkit.data as ekd  # type: ignore
import pydantic
import pydantic.dataclasses as pdc
import xarray as xr
# Local
from . import data_source, grib_decoder, icon_grid, util
# Root of the STAC API; the search and collection-asset URLs used in this
# module are built by appending to it.
API_URL = "https://data.geo.admin.ch/api/stac/v1"
# Module-level logger, also handed to the session factory below.
logger = logging.getLogger(__name__)
# HTTP session shared by every request issued from this module
# (created once, at import time).
session = util.init_session(logger)
[docs]
class Collection(str, enum.Enum):
    """Identifiers of the OGD forecast collections known to this module.

    Members are plain strings as well as enum members, so they compare
    equal to their value.
    """

    #: Collection of icon-ch1-eps model outputs
    ICON_CH1 = "ogd-forecasting-icon-ch1"
    #: Collection of icon-ch2-eps model outputs
    ICON_CH2 = "ogd-forecasting-icon-ch2"
def _forecast_prefix(field_name):
    if field_name in {
        "variable",
        "reference_datetime",
        "perturbed",
        "horizon",
    }:
        return f"forecast:{field_name}"
    return field_name
def _parse_datetime(value: str) -> dt.datetime:
    return dt.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S%z")
[docs]
@pdc.dataclass(
    frozen=True,
    config=pydantic.ConfigDict(
        use_enum_values=True,
        ser_json_timedelta="iso8601",
        alias_generator=pydantic.AliasGenerator(serialization_alias=_forecast_prefix),
    ),
)
class Request:
    """Define filters for the STAC Search API.
    Parameters
    ----------
    collection : Collection
        Name of the STAC collection.
    variable : str
        Name of the variable following the DWD convention.
    reference_datetime : str
        Forecast reference datetime in ISO 8601 format.
        Alias: ref_time
    perturbed : bool
        If true, retrieve ensemble forecast members.
        Otherwise, retrieve deterministic (control) forecast.
    horizon : datetime.timedelta or list[datetime.timedelta]
        Lead time of the requested data.
        Can be supplied as string in ISO 8601 format.
        Alias: lead_time
    """
    collection: Collection = dc.field(metadata=dict(exclude=True))
    variable: str
    reference_datetime: str = dc.field(
        metadata=dict(
            validation_alias=pydantic.AliasChoices("reference_datetime", "ref_time")
        )
    )
    perturbed: bool
    horizon: dt.timedelta | list[dt.timedelta] = dc.field(
        metadata=dict(validation_alias=pydantic.AliasChoices("horizon", "lead_time"))
    )
    if typing.TYPE_CHECKING:
        # https://github.com/pydantic/pydantic/issues/10266
        def __init__(self, *args: typing.Any, **kwargs: typing.Any): ...
    @pydantic.computed_field  # type: ignore[misc]
    @property
    def collections(self) -> list[str]:
        return ["ch.meteoschweiz." + str(self.collection)]
    @pydantic.field_validator("reference_datetime", mode="wrap")
    @classmethod
    def valid_reference_datetime(
        cls, input_value: typing.Any, handler: pydantic.ValidatorFunctionWrapHandler
    ) -> str:
        if isinstance(input_value, dt.datetime):
            if input_value.tzinfo is None:
                logger.warn("Converting naive datetime from local time to UTC")
            fmt = "%Y-%m-%dT%H:%M:%SZ"  # Zulu isoformat
            return input_value.astimezone(dt.timezone.utc).strftime(fmt)
        value = handler(input_value)
        if value == "latest":
            return value
        parts = value.split("/")
        match parts:
            case [v, ".."] | ["..", v] | [v]:
                # open ended or single value
                _parse_datetime(v)
            case [v1, v2]:
                # range
                d1 = _parse_datetime(v1)
                d2 = _parse_datetime(v2)
                if d2 < d1:
                    raise ValueError("reference_datetime bounds inverted")
            case _:
                raise ValueError(f"Unable to parse reference_datetime: {value}")
        return value
    @pydantic.field_serializer("reference_datetime")
    def serialize_reference_datetime(self, value: str):
        if value == "latest":
            cutoff = dt.datetime.now(tz=dt.timezone.utc) - dt.timedelta(hours=48)
            fmt = "%Y-%m-%dT%H:%M:%SZ"  # Zulu isoformat
            return f"{cutoff.strftime(fmt)}/.."
        return value
    def dump(self):
        exclude_fields = {}
        if isinstance(self.horizon, list):
            exclude_fields["horizon"] = True
        root = pydantic.RootModel(self)
        return root.model_dump(mode="json", by_alias=True, exclude=exclude_fields) 
def _search(url: str, body: dict | None = None):
    """POST a STAC search and return the hrefs of all matching assets.

    Result pages are followed through their ``rel == "next"`` links until
    none is left, so the returned list covers the whole result set.

    Parameters
    ----------
    url : str
        Search endpoint to post to.
    body : dict, optional
        JSON body holding the search filters.

    Raises
    ------
    RuntimeError
        if a "next" link is not a POST link with ``merge`` set.

    Returns
    -------
    list[str]
        Asset URLs in the order they are delivered, page after page.

    """
    result = []
    # An explicit stack replaces recursion so that the number of pages is
    # not bounded by the interpreter's recursion limit.
    pending = [(url, body)]
    while pending:
        page_url, page_body = pending.pop()
        response = session.post(page_url, json=page_body)
        response.raise_for_status()
        obj = response.json()
        for item in obj["features"]:
            result.extend(asset["href"] for asset in item["assets"].values())
        next_pages = []
        for link in obj["links"]:
            if link["rel"] != "next":
                continue
            if link["method"] != "POST" or not link["merge"]:
                raise RuntimeError(f"Bad link: {link}")
            # The link body is merged into the body of the page it came
            # from; a missing body (None) counts as empty instead of
            # failing on the dict union.
            merged = (page_body or {}) | link["body"]
            next_pages.append((link["href"], merged))
        # reversed() so the first "next" link is popped, and thus followed,
        # first - the same order the recursive version produced
        pending.extend(reversed(next_pages))
    return result
[docs]
def get_asset_urls(request: Request) -> list[str]:
    """Get asset URLs from OGD.

    The request attributes define filters for the STAC search API according
    to the forecast extension. Forecasts reference datetimes for which not all
    requested lead times are present are excluded from the result. For
    'latest' requests only the most recent complete reference datetime is
    returned. If the search yields exactly one asset, it is returned as is.

    Parameters
    ----------
    request : Request
        Asset search filters

    Raises
    ------
    ValueError
        when no reference datetime and lead time can be found in an asset
        URL, or when a 'latest' request finds no reference datetime for
        which all requested lead times are present.

    Returns
    -------
    list[str]
        URLs of the selected assets.

    """
    result = _search(f"{API_URL}/search", request.dump())
    if len(result) == 1:
        return result

    lead_times = (
        request.horizon if isinstance(request.horizon, list) else [request.horizon]
    )
    pattern = re.compile(r"-(?P<ref_time>\d{12})-(?P<lead_time>\d+)-")

    def extract_key(url: str) -> tuple[dt.datetime, dt.timedelta]:
        # The URL path encodes the reference time as 12 digits
        # (YYYYmmddHHMM, taken as UTC) followed by the lead time, which is
        # read as a number of hours.
        path = urlparse(url).path
        match = pattern.search(path)
        if not match:
            raise ValueError(f"No valid datetime found in URL path: {url}")
        val = match.group("ref_time")
        fmt = "%Y%m%d%H%M"
        utc = dt.timezone.utc
        ref_time = dt.datetime.strptime(val, fmt).replace(tzinfo=utc)
        lead_time = dt.timedelta(hours=float(match.group("lead_time")))
        return ref_time, lead_time

    asset_map = {extract_key(url): url for url in result}

    # gather reference times for which all requested lead times are present
    available: dict[dt.datetime, set[dt.timedelta]] = {}
    for ref_time, lead_time in asset_map:
        available.setdefault(ref_time, set()).add(lead_time)
    required = set(lead_times)
    complete = [
        ref_time for ref_time, found in available.items() if found >= required
    ]

    if request.reference_datetime == "latest":
        if not complete:
            # max() on an empty list would raise an uninformative ValueError
            raise ValueError(
                "No reference datetime found with all requested lead times"
            )
        ref_time = max(complete)
        return [asset_map[(ref_time, lead_time)] for lead_time in lead_times]

    return [
        asset_map[(ref_time, lead_time)]
        for lead_time in lead_times
        for ref_time in complete
    ]
@lru_cache
def _get_collection_assets(collection_id: str):
    """Fetch the static assets of a collection, keyed by asset ID.

    Successful results are memoised per ``collection_id`` by ``lru_cache``,
    so each collection is requested at most once while it stays cached.
    """
    response = session.get(f"{API_URL}/collections/{collection_id}/assets")
    response.raise_for_status()
    payload = response.json()
    assets_by_id = {}
    for asset in payload.get("assets", []):
        assets_by_id[asset["id"]] = asset
    return assets_by_id
[docs]
def get_collection_asset_url(collection_id: str, asset_id: str) -> str:
    """Get collection asset URL from OGD.

    Look up the assets of the given STAC collection and return the ``href``
    of the one matching the asset ID.

    Parameters
    ----------
    collection_id : str
        Full STAC collection ID
    asset_id : str
        The ID of the static asset to retrieve.

    Returns
    -------
    str
        The URL (``href``) of the requested static asset.

    Raises
    ------
    KeyError
        If the asset is not found in the collection or has no ``href``.

    """
    entry = _get_collection_assets(collection_id).get(asset_id)
    if entry and "href" in entry:
        return entry["href"]
    raise KeyError(f"Asset '{asset_id}' not found in collection '{collection_id}'.")
def _get_geo_coord_url(uuid: UUID) -> str:
    if (var := os.environ.get("MDL_GEO_COORD_URL")) is not None:
        return var
    model_name = icon_grid.GRID_UUID_TO_MODEL.get(uuid)
    if model_name is None:
        raise KeyError("Grid UUID not found")
    base_model = model_name.removesuffix("-eps")
    collection_id = f"ch.meteoschweiz.ogd-forecasting-{base_model}"
    asset_id = f"horizontal_constants_{model_name}.grib2"
    return get_collection_asset_url(collection_id, asset_id)
def _no_coords(uuid: UUID) -> dict[str, xr.DataArray]:
    """Coordinate provider that supplies nothing, whatever the grid UUID."""
    empty: dict[str, xr.DataArray] = {}
    return empty
def _geo_coords(uuid: UUID) -> dict[str, xr.DataArray]:
    """Load the longitude and latitude fields of the grid identified by *uuid*.

    The ``CLON`` and ``CLAT`` parameters are decoded from the file located by
    ``_get_geo_coord_url`` and returned squeezed under ``lon`` and ``lat``.
    """
    source = data_source.URLDataSource(urls=[_get_geo_coord_url(uuid)])
    wanted = {"param": ["CLON", "CLAT"]}
    # NOTE(review): _no_coords is presumably passed so that decoding the
    # coordinate file does not itself trigger a coordinate lookup - confirm.
    fields = grib_decoder.load(source, wanted, geo_coords=_no_coords)
    lon = fields["CLON"].squeeze()
    lat = fields["CLAT"].squeeze()
    return {"lon": lon, "lat": lat}
[docs]
def get_from_ogd(request: Request) -> xr.DataArray:
    """Get data from OGD.

    The request attributes define filters for the STAC search API according
    to the forecast extension. Enabling the earthkit-data cache is
    recommended; a warning is logged when its cache policy is "off".

    Parameters
    ----------
    request : Request
        Asset search filters. Documented as having to select a single
        asset; the asset count is not checked in this function itself.

    Returns
    -------
    xarray.DataArray
        The requested variable, decoded from the selected asset(s) with the
        horizontal coordinates supplied by ``_geo_coords``.

    """
    cache_disabled = ekd.settings.get("cache-policy") == "off"
    if cache_disabled:
        logger.warning(
            "Earthkit-data caching is recommended. See: %s",
            "https://earthkit-data.readthedocs.io/en/latest/examples/cache.html",
        )

    source = data_source.URLDataSource(urls=get_asset_urls(request))
    variable = request.variable
    dataset = grib_decoder.load(source, {"param": variable}, geo_coords=_geo_coords)
    return dataset[variable]
[docs]
def download_from_ogd(request: Request, target: Path) -> None:
    """Download forecast assets and their static coordinate files from OGD.

    The request attributes define filters for the STAC search API according
    to the forecast extension. Every asset matched by the request is
    downloaded.

    In addition to the forecast assets, this function downloads static files
    with horizontal and vertical coordinates, as the forecast item
    does not include the horizontal or vertical coordinates.

    Parameters
    ----------
    request : Request
        Asset search filters.
    target : Path
        Target path where to save the assets, must be a directory.
        It is created, including parents, if it does not exist.

    Raises
    ------
    ValueError
        when the target exists and is not a directory.
    KeyError
        if a coordinate file is not found in the collection.
    RuntimeError
        if the checksum verification fails.

    """
    if target.exists() and not target.is_dir():
        raise ValueError(f"target: {target} must be a directory")
    # exist_ok avoids a failure when the directory is created by someone
    # else between an existence check and the mkdir call
    target.mkdir(parents=True, exist_ok=True)

    # Download main forecast assets
    for asset_url in get_asset_urls(request):
        _download_with_checksum(asset_url, target)

    model_suffix = request.collection.removeprefix("ogd-forecasting-")
    collection_id = f"ch.meteoschweiz.{request.collection}"

    # Download coordinate files
    for prefix in ("horizontal", "vertical"):
        asset_id = f"{prefix}_constants_{model_suffix}-eps.grib2"
        url = get_collection_asset_url(collection_id, asset_id)
        _download_with_checksum(url, target)
def _file_hash(path: Path):
    hash = hashlib.sha256()
    with path.open("rb") as f:
        while chunk := f.read(16 * 1024):
            hash.update(chunk)
    return hash.hexdigest()
def _download_with_checksum(url: str, target: Path) -> None:
    """Download *url* into *target* and verify its SHA-256 checksum.

    If *target* is a directory, the file is named after the last component
    of the URL path; otherwise *target* itself is the destination. The
    expected checksum is taken from the ``X-Amz-Meta-Sha256`` response header
    and stored next to the file with a ``.sha256`` suffix. The download is
    skipped when the file already exists and matches that stored checksum.
    Without the header, nothing is stored and nothing is verified.

    Parameters
    ----------
    url : str
        URL of the file to download.
    target : Path
        Destination directory or file path.

    Raises
    ------
    RuntimeError
        if the downloaded content does not match the advertised checksum.

    """
    filename = Path(urlparse(url).path).name
    path = target / filename if target.is_dir() else target
    hash_path = path.with_suffix(".sha256")
    if path.exists():
        if hash_path.exists() and hash_path.read_text() == _file_hash(path):
            logger.info(f"File already exists, skipping download: {path}")
            return

    response = session.get(url, stream=True)
    try:
        response.raise_for_status()
        expected = response.headers.get("X-Amz-Meta-Sha256")
        if expected is not None:
            hash_path.write_text(expected)
        hasher = hashlib.sha256()
        with path.open("wb") as f:
            for chunk in response.iter_content(16 * 1024):
                f.write(chunk)
                hasher.update(chunk)
    finally:
        # A streamed response holds its connection until the body is fully
        # read; close it explicitly so an error part-way does not leak it.
        response.close()

    if expected is not None and expected != hasher.hexdigest():
        raise RuntimeError(f"Checksum verification failed for {filename}")