"""Decoder for grib data."""
# Standard library
import dataclasses as dc
import datetime as dt
import io
import logging
import typing
from collections import UserDict
from collections.abc import Mapping, Sequence
from enum import Enum
from itertools import product
from pathlib import Path
from warnings import warn
# Third-party
import earthkit.data as ekd # type: ignore
import numpy as np
import pandas as pd
import xarray as xr
from numpy.typing import DTypeLike
# Local
from . import data_source, icon_grid, mars, metadata, tasking
logger = logging.getLogger(__name__)
# Maps output dimension names to the metadata keys used to index incoming
# fields ("ref_time" and "step" are synthesised in _get_key from
# dataDate/dataTime rather than read directly from the GRIB metadata).
DIM_MAP = {
"eps": "perturbationNumber",
"ref_time": "ref_time",
"lead_time": "step",
"z": "level",
}
NAME_KEY = "shortName"
Request = str | tuple | dict | mars.Request
class ChainMap(UserDict):
    """Chain of mappings; lookups return the first match in order."""

    def __init__(self, *maps):
        super().__init__()  # initialise the UserDict backing store
        self._maps = maps
def __getitem__(self, key):
for mapping in self._maps:
try:
return mapping[key]
except KeyError:
pass
raise KeyError(f"{key} not found")
class GribField(typing.Protocol):
def metadata(self, *args, **kwargs) -> typing.Any: ...
def message(self) -> bytes: ...
def to_numpy(self, dtype: DTypeLike) -> np.ndarray: ...
def to_latlon(self) -> dict[str, np.ndarray]: ...
class MissingData(RuntimeError):
pass
class UnitOfTime(Enum):
    """Unit of time as encoded in GRIB code table 4.4."""

MINUTE = 0
HOUR = 1
DAY = 2
SECOND = 13
MISSING = 255
@property
def unit(self):
if self.name == "MISSING":
return None
return self.name.lower()
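
# Sketch of the UnitOfTime helper: the enum value is the GRIB code and the
# `unit` property yields a pandas-compatible unit string (or None if missing).
def _example_unit_of_time() -> None:
    assert UnitOfTime(1) is UnitOfTime.HOUR
    assert UnitOfTime.HOUR.unit == "hour"  # accepted by pd.to_timedelta
    assert UnitOfTime.MISSING.unit is None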
def _is_ensemble(field) -> bool:
    try:
        # 192 lies in the local-use range of GRIB code table 4.6
        # (type of ensemble forecast).
        return field.metadata("typeOfEnsembleForecast") == 192
except KeyError:
return False
def _get_hcoords(field):
if field.metadata("gridType") == "unstructured_grid":
grid_uuid = field.metadata("uuidOfHGrid")
return icon_grid.get_icon_grid(grid_uuid)
return {
dim: xr.DataArray(dims=("y", "x"), data=values)
for dim, values in field.to_latlon().items()
}
def _parse_datetime(date, time) -> dt.datetime:
return dt.datetime.strptime(f"{date}{time:04d}", "%Y%m%d%H%M")
def _to_timedelta(value, unit) -> np.timedelta64:
return pd.to_timedelta(value, unit).to_numpy()
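
# Small sketch of the two time helpers, using hypothetical GRIB header values:
# dataDate/dataTime encode the reference time, integer steps are in hours.
def _example_time_helpers() -> None:
    ref = _parse_datetime(20240101, 600)
    assert ref == dt.datetime(2024, 1, 1, 6, 0)
    assert _to_timedelta(6, "h") == np.timedelta64(6, "h")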
def _get_key(field, dims):
md = field.metadata()
step = md["step"]
unit = "h" if isinstance(step, int) else None
extra = {
"ref_time": _parse_datetime(md["dataDate"], md["dataTime"]),
"step": _to_timedelta(step, unit),
}
dim_keys = (DIM_MAP[dim] for dim in dims)
mapping = ChainMap(extra, md)
return tuple(mapping[key] for key in dim_keys)
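
# Sketch of the key construction in `_get_key`, using a stand-in object whose
# `metadata()` returns a plain dict (hypothetical values; real fields come
# from earthkit.data).
def _example_get_key() -> None:
    class FakeField:
        def metadata(self):
            return {
                "dataDate": 20240101,
                "dataTime": 600,
                "step": 6,  # integer steps are interpreted as hours
                "perturbationNumber": 3,
                "level": 850,
            }

    key = _get_key(FakeField(), dims=("eps", "ref_time", "lead_time", "z"))
    assert key == (
        3,
        dt.datetime(2024, 1, 1, 6, 0),
        _to_timedelta(6, "h"),
        850,
    )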
@dc.dataclass
class _FieldBuffer:
is_ensemble: dc.InitVar[bool]
dims: tuple[str, ...] = tuple(DIM_MAP)
hcoords: dict[str, xr.DataArray] = dc.field(default_factory=dict)
metadata: dict[str, typing.Any] = dc.field(default_factory=dict)
values: dict[tuple[int, ...], np.ndarray] = dc.field(default_factory=dict)
def __post_init__(self, is_ensemble: bool):
if not is_ensemble:
self.dims = self.dims[1:]
def load(self, field: GribField) -> None:
key = _get_key(field, self.dims)
name = field.metadata(NAME_KEY)
logger.debug("Received field for param: %s, key: %s", name, key)
if key in self.values:
            logger.warning("Key collision for param: %s, key: %s", name, key)
self.values[key] = field.to_numpy(dtype=np.float32)
if not self.metadata:
self.metadata = {
"message": field.message(), # try field.metadata.override()
**metadata.extract(field.metadata()),
}
if not self.hcoords:
self.hcoords = _get_hcoords(field)
def _gather_coords(self):
coord_values = zip(*self.values)
unique = (sorted(set(values)) for values in coord_values)
coords = {dim: c for dim, c in zip(self.dims, unique)}
if missing := [
combination
for combination in product(*coords.values())
if combination not in self.values
]:
msg = f"Missing combinations: {missing}"
            logger.error(msg)
raise RuntimeError(msg)
field_shape = next(iter(self.values.values())).shape
shape = tuple(len(v) for v in coords.values()) + field_shape
return coords, shape
def to_xarray(self) -> xr.DataArray:
if not self.values:
raise MissingData("No values.")
coords, shape = self._gather_coords()
ref_time = xr.DataArray(coords["ref_time"], dims="ref_time")
lead_time = xr.DataArray(coords["lead_time"], dims="lead_time")
tcoords = {"valid_time": ref_time + lead_time}
hdims = self.hcoords["lon"].dims
array = xr.DataArray(
data=np.array(
[self.values.pop(key) for key in sorted(self.values)]
).reshape(shape),
coords=coords | self.hcoords | tcoords,
dims=self.dims + hdims,
attrs=self.metadata,
)
if array.vcoord_type != "surface":
return array
return array.squeeze("z", drop=True)
def _load_buffer_map(
source: data_source.DataSource,
request: Request,
) -> dict[str, _FieldBuffer]:
logger.info("Retrieving request: %s", request)
fs = source.retrieve(request)
buffer_map: dict[str, _FieldBuffer] = {}
for field in fs:
name = field.metadata(NAME_KEY)
if name in buffer_map:
buffer = buffer_map[name]
else:
buffer = buffer_map[name] = _FieldBuffer(_is_ensemble(field))
buffer.load(field)
return buffer_map
def load_single_param(
source: data_source.DataSource,
request: Request,
) -> xr.DataArray:
"""Request data from a data source for a single parameter.
Parameters
----------
source : data_source.DataSource
Source to request the data from.
request : str | tuple[str, str] | dict[str, Any] | meteodatalab.mars.Request
Request for data from the source in the mars language.
Raises
------
ValueError
if more than one param is present in the request.
    RuntimeError
        if not all of the requested data is returned from the data source.
Returns
-------
xarray.DataArray
A data array of the requested field.
"""
if (
isinstance(request, dict)
and isinstance(request["param"], Sequence)
and len(request["param"]) > 1
):
raise ValueError("Only one param is supported.")
buffer_map = _load_buffer_map(source, request)
[buffer] = buffer_map.values()
return buffer.to_xarray()
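
# Hedged usage sketch for `load_single_param`, assuming a local GRIB file at a
# hypothetical path and "T" as a hypothetical shortName present in that file.
def _example_load_single_param() -> None:
    source = data_source.FileDataSource(datafiles=["/path/to/data.grib"])
    temperature = load_single_param(source, {"param": "T"})
    print(temperature.dims)  # e.g. ("ref_time", "lead_time", "z", "y", "x")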
def load(
source: data_source.DataSource,
request: Request,
) -> dict[str, xr.DataArray]:
"""Request data from a data source.
Parameters
----------
source : data_source.DataSource
Source to request the data from.
request : str | tuple[str, str] | dict[str, Any] | meteodatalab.mars.Request
Request for data from the source in the mars language.
Raises
------
    RuntimeError
        if not all of the requested data is returned from the data source.
Returns
-------
dict[str, xarray.DataArray]
A mapping of shortName to data arrays of the requested fields.
"""
buffer_map = _load_buffer_map(source, request)
result = {}
for name, buffer in buffer_map.items():
try:
result[name] = buffer.to_xarray()
except MissingData as e:
raise RuntimeError(f"Missing data for param: {name}") from e
return result
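
# Hedged sketch for `load`: multiple parameters in one request come back as a
# mapping from shortName to DataArray (path and shortNames are hypothetical).
def _example_load() -> None:
    source = data_source.FileDataSource(datafiles=["/path/to/data.grib"])
    fields = load(source, {"param": ["U", "V"]})
    wind_speed = np.hypot(fields["U"], fields["V"])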
class GribReader:
def __init__(
self,
source: data_source.DataSource,
ref_param: Request | None = None,
):
"""Initialize a grib reader from a data source.
Parameters
----------
source : data_source.DataSource
Data source from which to retrieve the grib fields
        ref_param : Request | None
            Deprecated; name of the parameter formerly used to construct a
            reference grid.
Raises
------
ValueError
if the grid can not be constructed from the ref_param
"""
self.data_source = source
if ref_param is not None:
warn("GribReader: ref_param is deprecated.")
@classmethod
def from_files(cls, datafiles: list[Path], ref_param: Request | None = None):
"""Initialize a grib reader from a list of grib files.
Parameters
----------
datafiles : list[Path]
List of grib input filenames
        ref_param : Request | None
            Deprecated; name of the parameter formerly used to construct a
            reference grid.
Raises
------
ValueError
if the grid can not be constructed from the ref_param
"""
return cls(
data_source.FileDataSource(datafiles=[str(p) for p in datafiles]), ref_param
)
def load(
self,
requests: Mapping[str, Request],
extract_pv: str | None = None,
) -> dict[str, xr.DataArray]:
"""Load a dataset with the requested parameters.
Parameters
----------
requests : Mapping[str, Request]
Mapping of label to request for a given field from the data source.
extract_pv: str | None
Optionally extract hybrid level coefficients from the field referenced by
the given label.
Raises
------
RuntimeError
if not all fields are found in the data source.
Returns
-------
dict[str, xr.DataArray]
Mapping of fields by label
"""
result = {
name: tasking.delayed(load_single_param)(self.data_source, req)
for name, req in requests.items()
}
if extract_pv is not None:
if extract_pv not in requests:
msg = f"{extract_pv=} was not a key of the given requests."
raise RuntimeError(msg)
return result | metadata.extract_pv(result[extract_pv].message)
return result
def load_fieldnames(
self,
params: list[str],
extract_pv: str | None = None,
) -> dict[str, xr.DataArray]:
reqs = {param: param for param in params}
return self.load(reqs, extract_pv)
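
# Hedged usage sketch for GribReader, assuming local GRIB files at
# hypothetical paths and hypothetical shortNames; each label maps to a
# request in the mars language (here the shortName itself).
def _example_grib_reader() -> None:
    reader = GribReader.from_files([Path("/path/to/data.grib")])
    fields = reader.load_fieldnames(["T", "P"])
    # fields maps each shortName to its assembled xr.DataArray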
def save(
field: xr.DataArray,
file_handle: io.BufferedWriter | io.BytesIO,
bits_per_value: int = 16,
):
"""Write field to file in GRIB format.
Parameters
----------
field : xarray.DataArray
Field to write into the output file.
    file_handle : io.BufferedWriter | io.BytesIO
        File handle for the output file, opened in binary mode.
bits_per_value : int, optional
Bits per value encoded in the output file. (Default: 16)
Raises
------
ValueError
If the field does not have a message attribute.
"""
if not hasattr(field, "message"):
msg = "The message attribute is required to write to the GRIB format."
raise ValueError(msg)
stream = io.BytesIO(field.message)
[md] = (f.metadata() for f in ekd.from_source("stream", stream))
idx = {
dim: field.coords[key]
for key in field.dims
if (dim := str(key)) not in {"x", "y"}
}
step_unit = UnitOfTime.MINUTE
time_range_unit = UnitOfTime(md.get("indicatorOfUnitForTimeRange", 255)).unit
time_range = _to_timedelta(md.get("lengthOfTimeRange", 0), unit=time_range_unit)
if md.get("numberOfTimeRange", 1) != 1:
raise NotImplementedError("Unsupported value for numberOfTimeRange")
def to_grib(loc: dict[str, xr.DataArray]):
grib_loc = {
DIM_MAP[key]: value.item()
for key, value in loc.items()
if key not in {"ref_time", "lead_time"}
}
step_end = np.timedelta64(loc["lead_time"].item(), "ns")
step_begin = step_end - time_range
return grib_loc | {
"indicatorOfUnitOfTimeRange": step_unit.value,
"forecastTime": step_begin / _to_timedelta(1, step_unit.unit),
"dataDate": loc["ref_time"].dt.strftime("%Y%m%d").item(),
"dataTime": loc["ref_time"].dt.strftime("%H%M").item(),
}
for idx_slice in product(*idx.values()):
loc = {dim: value for dim, value in zip(idx.keys(), idx_slice)}
array = field.sel(loc).values
metadata = md.override(to_grib(loc))
fs = ekd.FieldList.from_numpy(array, metadata)
fs.write(file_handle, bits_per_value=bits_per_value)
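
# Hedged round-trip sketch: load a field (which carries the original GRIB
# message in its attrs) and write it back out; paths and the shortName are
# hypothetical.
def _example_save() -> None:
    source = data_source.FileDataSource(datafiles=["/path/to/data.grib"])
    field = load_single_param(source, {"param": "T"})
    with open("/tmp/out.grib", "wb") as f:
        save(field, f, bits_per_value=16)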
def get_code_flag(value: int, indices: Sequence[int]) -> list[bool]:
"""Get the code flag value at the given indices.
Parameters
----------
value : int
The code flag as an integer in the [0, 255] range.
indices : Sequence[int]
Indices at which to get the flag values. Left to right, 1-based.
Returns
-------
list[bool]
The code flag values at the given indices.
"""
if not 0 <= value <= 255:
raise ValueError("value must be a single byte integer")
result = []
for index in indices:
if not 1 <= index <= 8:
raise ValueError("index must in range [1,8]")
shift = 8 - index
result.append(bool(value >> shift & 1))
return result
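
# Worked example for `get_code_flag`: 160 is 0b10100000, so bits 1 and 3
# (counted from the left, 1-based) are set.
def _example_get_code_flag() -> None:
    assert get_code_flag(0b10100000, [1, 2, 3]) == [True, False, True]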
def set_code_flag(indices: Sequence[int]) -> int:
"""Create code flag by setting bits at the given indices.
Parameters
----------
indices : Sequence[int]
Indices at which to set the flag values. Left to right, 1-based.
Returns
-------
int
Code flag with bits set at the given indices.
"""
value = 0
for index in indices:
if not 1 <= index <= 8:
raise ValueError("index must in range [1,8]")
shift = 8 - index
value |= 1 << shift
return value
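
# Round-trip sketch for the code-flag helpers: setting bits 1 and 3 yields
# 0b10100000 == 160, and `get_code_flag` reads the same bits back.
def _example_set_code_flag() -> None:
    value = set_code_flag([1, 3])
    assert value == 0b10100000 == 160
    assert get_code_flag(value, [1, 3]) == [True, True]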