Source code for meteodatalab.operators.regrid

"""Regridding operator."""

# Standard library
import dataclasses as dc
import typing
import warnings

# Third-party
import numpy as np
import xarray as xr
from rasterio import transform, warp
from rasterio.crs import CRS

# Local
from .. import icon_grid, metadata
from ..grib_decoder import set_code_flag

Resampling: typing.TypeAlias = warp.Resampling

# For more information: check https://epsg.io/<id>
CRS_ALIASES = {
    "geolatlon": "epsg:4326",  # WGS84
    "swiss": "epsg:21781",  # Swiss CH1903 / LV03
    "swiss03": "epsg:21781",  # Swiss CH1903 / LV03
    "swiss95": "epsg:2056",  # Swiss CH1903+ / LV95
    "boaga-west": "epsg:3003",  # Monte Mario / Italy zone 1
    "boaga-east": "epsg:3004",  # Monte Mario / Italy zone 2
}


def _get_crs(geo):
    if geo["gridType"] != "rotated_ll":
        raise NotImplementedError("Unsupported grid type")

    lon = geo["longitudeOfSouthernPoleInDegrees"]
    lat = -1 * geo["latitudeOfSouthernPoleInDegrees"]

    return CRS.from_string(
        f"+proj=ob_tran +o_proj=longlat +o_lat_p={lat} +lon_0={lon} +datum=WGS84"
    )


def _normalise(angle: float) -> float:
    return np.fmod(angle + 180, 360) - 180


[docs] @dc.dataclass class RegularGrid: """Class defining a regular grid. Attributes ---------- crs : CRS Coordinate reference system. nx : int Number of grid points in the x direction. ny : int Number of grid points in the y direction. xmin : float Coordinate of the first grid point in the x direction. xmax : float Coordinate of the last grid point in the x direction. ymin : float Coordinate of the first grid point in the y direction. ymax : float Coordinate of the last grid point in the y direction. """ crs: CRS nx: int ny: int xmin: float xmax: float ymin: float ymax: float
[docs] @classmethod def from_field(cls, field: xr.DataArray): """Extract grid parameters from grib metadata. Parameters ---------- field : xarray.DataArray field containing the relevant metadata. """ geo = field.geography obj = cls( crs=_get_crs(geo), nx=geo["Ni"], ny=geo["Nj"], xmin=_normalise(geo["longitudeOfFirstGridPointInDegrees"]), xmax=_normalise(geo["longitudeOfLastGridPointInDegrees"]), ymin=geo["latitudeOfFirstGridPointInDegrees"], ymax=geo["latitudeOfLastGridPointInDegrees"], ) if abs(obj.dx - geo["iDirectionIncrementInDegrees"]) > 1e-5: raise ValueError("Inconsistent grid parameters") if abs(obj.dy - geo["jDirectionIncrementInDegrees"]) > 1e-5: raise ValueError("Inconsistent grid parameters") return obj
[docs] @classmethod def parse_regrid_operator(cls, op: str): """Parse fieldextra out_regrid_target string. Parameters ---------- op : str fieldextra out_regrid_target definition i.e. crs,xmin,ymin,xmay,ymax,dx,dy. """ crs_str, *grid_params = op.split(",") crs = CRS.from_string(CRS_ALIASES[crs_str]) xmin, ymin, xmax, ymax, dx, dy = map(float, grid_params) if abs(dx) < 1e-10 or abs(dy) < 1e-10: raise ValueError("Inconsistent regrid parameters") nx = (xmax - xmin) / dx + 1 ny = (ymax - ymin) / dy + 1 if nx != int(nx) or ny != int(ny): raise ValueError("Inconsistent regrid parameters") return cls(crs, int(nx), int(ny), xmin, xmax, ymin, ymax)
[docs] def to_crs(self, crs: str, **kwargs): """Return a new grid in the given coordinate reference system. Parameters ---------- crs : str The coordinate reference system in which the output grid should be defined. Returns ------- RegularGrid or subtype Output grid in the given coordinate reference system. """ dst_crs = CRS.from_string(crs) tx, width, height = typing.cast( tuple[transform.Affine, int, int], warp.calculate_default_transform( src_crs=self.crs, dst_crs=dst_crs, width=self.nx, height=self.ny, left=self.xmin, bottom=self.ymin, right=self.xmax, top=self.ymax, **kwargs, ), ) a = transform.AffineTransformer(tx) xmin, ymax = a.xy(-0.5, -0.5) xmax, ymin = a.xy(height + 0.5, width + 0.5) # row, col return type(self)(dst_crs, width, height, xmin, xmax, ymin, ymax)
@property def x(self) -> np.ndarray: return np.arange(self.nx) * self.dx + self.xmin @property def y(self) -> np.ndarray: return np.arange(self.ny) * self.dy + self.ymin @property def dx(self) -> float: return (self.xmax - self.xmin) / (self.nx - 1) @property def dy(self) -> float: return (self.ymax - self.ymin) / (self.ny - 1) @property def transform(self) -> transform.Affine: return transform.from_origin( west=self.xmin - self.dx / 2, north=self.ymax + self.dy / 2, xsize=self.dx, ysize=self.dy, )
def _udeg(value): return int(round(value * 1e6)) def _get_metadata(grid: RegularGrid): if grid.crs.to_epsg() == 4326: # geolatlon # https://codes.ecmwf.int/grib/format/grib2/ctables/3/4/ scanning_mode = set_code_flag([2]) # positive y # i, j direction increments given # bit 5 left unset since the target grid is geolatlon # both values are equivalent in this case resolution_components_flags = set_code_flag([3, 4]) return { "numberOfDataPoints": grid.nx * grid.ny, "sourceOfGridDefinition": 0, # defined by template number "numberOfOctectsForNumberOfPoints": 0, "interpretationOfNumberOfPoints": 0, "gridDefinitionTemplateNumber": 0, # latlon "shapeOfTheEarth": 5, # WGS 84 "Ni": grid.nx, "Nj": grid.ny, "latitudeOfFirstGridPoint": _udeg(grid.ymin), "longitudeOfFirstGridPoint": _udeg(grid.xmin), "resolutionAndComponentFlags": resolution_components_flags, "latitudeOfLastGridPoint": _udeg(grid.ymax), "longitudeOfLastGridPoint": _udeg(grid.xmax), "iDirectionIncrement": _udeg(grid.dx), "jDirectionIncrement": _udeg(grid.dy), "scanningMode": scanning_mode, } elif grid.crs.get("proj") == "ob_tran": # rotlatlon # https://codes.ecmwf.int/grib/format/grib2/ctables/3/4/ scanning_mode = set_code_flag([2]) # positive y # i, j direction increments given resolution_components_flags = set_code_flag([3, 4]) return { "numberOfDataPoints": grid.nx * grid.ny, "sourceOfGridDefinition": 0, # defined by template number "numberOfOctectsForNumberOfPoints": 0, "interpretationOfNumberOfPoints": 0, "gridDefinitionTemplateNumber": 1, # rotlatlon "shapeOfTheEarth": 5, # WGS 84 "Ni": grid.nx, "Nj": grid.ny, "latitudeOfFirstGridPoint": _udeg(grid.ymin), "longitudeOfFirstGridPoint": _udeg(grid.xmin), "resolutionAndComponentFlags": resolution_components_flags, "latitudeOfLastGridPoint": _udeg(grid.ymax), "longitudeOfLastGridPoint": _udeg(grid.xmax), "iDirectionIncrement": _udeg(grid.dx), "jDirectionIncrement": _udeg(grid.dy), "scanningMode": scanning_mode, "latitudeOfSouthernPole": _udeg(-1 * grid.crs.get("o_lat_p")), "longitudeOfSouthernPole": _udeg(grid.crs.get("lon_0")), "angleOfRotation": 0.0, }
[docs] def regrid( field: xr.DataArray, dst: RegularGrid, resampling: Resampling, src: RegularGrid | None = None, ): """Regrid a field. Parameters ---------- field : xarray.DataArray Input field defined on a regular grid in rotated latlon coordinates. dst : RegularGrid Destination grid onto which to project the field. resampling : Resampling Resampling method, alias of rasterio.warp.Resampling. src : RegularGrid, optional Definition of the input field grid Raises ------ ValueError If the input field is not defined on a regular grid in rotated latlon or if the input field geography metadata does not have consistent grid parameters. Returns ------- xarray.DataArray Field regridded in the destination grid. """ if src is None: src = RegularGrid.from_field(field) def reproject_layer(field): output = np.zeros((dst.ny, dst.nx)) warp.reproject( source=field[::-1], destination=output, src_crs=src.crs, src_transform=src.transform, dst_crs=dst.crs, dst_transform=dst.transform, resampling=resampling, ) return output[::-1] # output dims renamed to workaround limitation that overlapping dims in the input # must not change in size data = xr.apply_ufunc( reproject_layer, field, input_core_dims=[["y", "x"]], output_core_dims=[["y1", "x1"]], vectorize=True, ).rename({"x1": "x", "y1": "y"}) attrs = field.attrs if md := _get_metadata(dst): attrs |= metadata.override(field.message, **md) return xr.DataArray(data, attrs=attrs)
def _icon2regular( field: xr.DataArray, dst: RegularGrid, indices: np.ndarray, weights: np.ndarray ) -> xr.DataArray: mask = np.all(indices != 0, axis=-1) def reproject_layer(field): out_shape = field.shape[:-1] + (dst.ny, dst.nx) values = np.take(field, indices, axis=-1) if np.any(np.isnan(values)): warnings.warn("Interpolation of missing values is not supported.") vmin = np.min(values, axis=-1) vmax = np.max(values, axis=-1) result = np.einsum("...ij,ij->...i", values, weights) masked = np.where(mask, result, np.nan) return np.clip(masked, vmin, vmax).reshape(out_shape) data = xr.apply_ufunc( reproject_layer, field, input_core_dims=[["cell"]], output_core_dims=[["y", "x"]], ) attrs = field.attrs if md := _get_metadata(dst): attrs |= metadata.override(field.message, **md) return xr.DataArray(data, attrs=attrs)
[docs] def icon2geolatlon(field: xr.DataArray) -> xr.DataArray: """Remap ICON native grid data to the geolatlon grid. Parameters ---------- field : xarray.DataArray A field with data in the ICON native grid. Returns ------- xarray.DataArray Field with data remapped to the geolatlon grid. """ gid = metadata.extract_keys(field.message, "uuidOfHGrid") coeffs = icon_grid.get_remap_coeffs(gid, "geolatlon") indices = coeffs["rbf_B_glbidx"].values weights = coeffs["rbf_B_wgt"].values dst = RegularGrid( crs=CRS.from_string("epsg:4326"), nx=coeffs.nx, ny=coeffs.ny, xmin=coeffs.xmin, ymin=coeffs.ymin, xmax=coeffs.xmax, ymax=coeffs.ymax, ) return _icon2regular(field, dst, indices, weights)
[docs] def icon2rotlatlon(field: xr.DataArray) -> xr.DataArray: """Remap ICON native grid data to the rotated latlon grid. Parameters ---------- field : xarray.DataArray A field with data in the ICON native grid. Returns ------- xarray.DataArray Field with data remapped to the rotated latlon grid. """ gid = metadata.extract_keys(field.message, "uuidOfHGrid") coeffs = icon_grid.get_remap_coeffs(gid, "rotlatlon") indices = coeffs["rbf_B_glbidx"].values weights = coeffs["rbf_B_wgt"].values geo = { "gridType": "rotated_ll", "longitudeOfSouthernPoleInDegrees": coeffs.north_pole_lon - 180, "latitudeOfSouthernPoleInDegrees": -1 * coeffs.north_pole_lat, } dst = RegularGrid( crs=_get_crs(geo), nx=coeffs.nx, ny=coeffs.ny, xmin=coeffs.xmin, ymin=coeffs.ymin, xmax=coeffs.xmax, ymax=coeffs.ymax, ) return _icon2regular(field, dst, indices, weights)