# -*- coding: utf-8 -*-
"""
Functions (accessors) provided by snowtools adaptation of xarray
----------------------------------------------------------------
The module xarray_snowtools_accessor aims at wrapping and extending the xarray module for snowtools-specific usage.
The wrapping of existing methods is designed to reduce dependency to native xarray method changes (in order to
centralise required adaptations).
Following the xarray project's recommendations, it is based on the use of accessors:
https://tutorial.xarray.dev/advanced/accessors/01_accessor_examples.html
This accessor is automatically made available when you import ``snowtools.utils.xarray_snowtools``.
Usage examples
^^^^^^^^^^^^^^
.. code-block:: python
from snowtools.utils import xarray_snowtools
import xarray as xr
ds = xr.open_dataset('INPUT.nc', decode_times=False)
ds = xarray_snowtools.preprocess(ds)
1. Select subset of points from a S2M file in the massif geometry, based on the massif number,
elevation, slope and aspect
.. code-block:: python
ds.semidistributed.sel_points(massif_num=3, ZS=[900, 1800, 2700, 3600], slope=40)
2. Project gridded data from Lambert-93 to lat/lon :
.. code-block:: python
ds.distributed.proj(crs_in="EPSG:2154", crs_out="EPSG:4326")
3. Interpret time-like variable or dimension as datetime :
.. code-block:: python
ds.surfex.decode_time_variable(varname)
4. Compute 24-hour precipitation accumulations from 6h J to 6h J+1 from hourly precipitation dataset:
.. code-block:: python
ds.Precipitation.meteo.daily_accumulation()
5. Example of groupby with the first-test PRO file:
.. code-block:: python
from snowtools.utils import xarray_snowtools
import xarray as xr
import matplotlib.pyplot as plt
ds = xr.open_dataset('PRO_2010080106_2011080106.nc', decode_times=False)
ds = xarray_snowtools.preprocess(ds)
dszs = ds.semidistributed.sel_points(ZS=2400)
meanmonthgroup = dszs.groupby("time.month").mean() # mean the variables on a monthly base
meanmonthgroup.TG1.plot() # choose one variable to plot
plt.show()
6. Example of resample with the first-test PRO file:
.. code-block:: python
from snowtools.utils import xarray_snowtools
import xarray as xr
ds = xr.open_dataset('PRO_2010080106_2011080106.nc', decode_times=False)
ds = xarray_snowtools.preprocess(ds)
dszs = ds.semidistributed.sel_points(ZS=2400)
dszs.resample(time='12h').mean() # time resampling to 12h timestep
7. Use of custom "daily_accumulation" method
Resample Rainf variable from hourly values to daily accumulations, starting at 03:00 :
.. code-block:: python
from snowtools.utils import xarray_snowtools
import xarray as xr
ds_hourly = xr.open_dataset('FORCING_test_2d.nc', decode_times=False)
ds_hourly = xarray_snowtools.preprocess(ds_hourly)
ds_daily = ds_hourly.Rainf.snowtools.daily_accumulation(start_hour=3)
New features integration rules:
Any native xarray function/method NOT part of xarray’s public API can be overwritten in these accessors in order to
centralise required adaptations in case of any change of behavior of the native method.
Information on which xarray functions/methods are considered public API can be found in the xarray documentation:
- https://docs.xarray.dev/en/v2023.09.0/getting-started-guide/faq.html (section "What parts of xarray are considered
public API?")
- https://docs.xarray.dev/en/v2023.09.0/api.html#api
"""
from typing import Union
import xarray as xr
[docs]
@xr.register_dataset_accessor("meteo")
@xr.register_dataarray_accessor("meteo")
class MeteoAccessor(SnowtoolsAccessor):
    """
    Accessor dedicated to meteorological files.

    It is registered under the name ``meteo`` on both Dataset and DataArray objects. No method is defined
    in this class body: whatever it exposes comes from its ``SnowtoolsAccessor`` parent.

    Usage example:

    .. code-block:: python

        from snowtools.utils import xarray_snowtools
        import xarray as xr

        ds = xr.open_dataset('FORCING.nc', decode_times=False)
        ds = xarray_snowtools.preprocess(ds)
        ds.meteo.[...]
    """
[docs]
@xr.register_dataset_accessor("surfex")
@xr.register_dataarray_accessor("surfex")
class SurfexAccessor(SnowtoolsAccessor):
    """
    Accessor designed to deal with SURFEX output files.

    It is registered under the name ``surfex`` on both Dataset and DataArray objects. The methods below
    read and update ``self.ds``, the wrapped xarray object (presumably set by ``SnowtoolsAccessor`` —
    to be confirmed against the parent class).

    Usage example:

    .. code-block:: python

        from snowtools.utils import xarray_snowtools
        import xarray as xr

        ds = xr.open_dataset('PRO.nc', decode_times=False)
        ds = xarray_snowtools.preprocess(ds)
        ds.surfex.decode_time_variable('time')
    """

    def decode_time_variable(self, varname):
        """
        Manually decode any time-like variable from a SURFEX output.

        The variable is isolated in a temporary Dataset, decoded with ``xr.decode_cf`` and written back
        into ``self.ds`` under the same name. Nothing is returned.

        :param varname: Name of the variable to decode
        :type varname: str
        """
        # Decode a single-variable Dataset so that only the requested variable is touched
        timevar = xr.Dataset({varname: self.ds[varname]})
        timevar = xr.decode_cf(timevar)
        self.ds[varname] = timevar[varname]

    def drop_tile_dimension(self, tile=0):
        """
        Select a single value for "tile" dimension (or equivalent) and squeeze the dataset to drop the dimension.

        The dimension names checked are 'Number_of_patches', 'tile' and 'Number_of_Tile'. When none of them
        is present, the object is returned unchanged.

        :param tile: Value of the "tile" dimension to select
        :type tile: int
        :returns: The object stored in ``self.ds`` after the selection
        """
        for drop_dim in ['Number_of_patches', 'tile', 'Number_of_Tile']:
            if drop_dim in self.ds.dims:
                # The dimension name is held in a variable, so it must be given as a dict key:
                # "sel(drop_dim=tile)" would look for a dimension literally called "drop_dim".
                # NOTE: squeeze() without argument also removes any other dimension of length one.
                self.ds = self.ds.sel({drop_dim: tile}).squeeze()
        return self.ds
[docs]
@xr.register_dataset_accessor("semidistributed")
@xr.register_dataarray_accessor("semidistributed")
class SemiDistributedAccessor(SnowtoolsAccessor):
    """
    Additional methods in semi-distributed geometry (ex: S2M simulations).

    It is registered under the name ``semidistributed`` on both Dataset and DataArray objects.

    Usage example:

    .. code-block:: python

        from snowtools.utils import xarray_snowtools
        import xarray as xr

        ds = xr.open_dataset('INPUT.nc', decode_times=False)
        ds = xarray_snowtools.preprocess(ds)
        ds.semidistributed.sel_points(massif_num=3, ZS=[900, 1800, 2700, 3600], slope=40)
    """

    def sel_points(self, massif_num=None, ZS=None, slope=None, aspect=None):
        """
        Method used to select a user-defined list of points in semi-distributed geometry (SAFRAN massifs geometry)
        from their elevation (ZS), massif number (massif_num), slope and aspect.

        Each provided criterion is turned into a boolean mask on the variable of the same name; the masks
        are combined with a logical AND and applied with ``Dataset.where(..., drop=True)``.

        **NB :**
        For more advanced indexing (for example to select all elevations above 1800m or use a slice as
        argument), use the native xarray "where" method directly.

        :param massif_num: Massif number(s) of points to select
        :type massif_num: list, range or int
        :param ZS: Elevation(s) of points to select
        :type ZS: list, range or int
        :param slope: Slope(s) of points to select
        :type slope: list, range or int
        :param aspect: Aspects(s) of points to select
        :type aspect: list, range or int
        :returns: The selected subset; the unchanged dataset when no criterion is given; an empty Dataset
                  when no point matches
        :raises TypeError: If the accessor wraps a DataArray, or if a criterion is not a list, range or int
        :raises ValueError: If a criterion is given for a variable that is absent from the dataset
        """
        if isinstance(self.ds, xr.DataArray):
            raise TypeError("This method only applies to Dataset objects")

        # Explicit name -> value mapping (replaces the former eval() calls on the argument names)
        criteria = {'massif_num': massif_num, 'ZS': ZS, 'slope': slope, 'aspect': aspect}

        indexer = None
        for var, value in criteria.items():
            if value is None:
                continue
            if var not in list(self.ds.keys()):
                raise ValueError(f'Variable "{var}" does not exist')
            if isinstance(value, (list, range)):
                tmp = self.ds[var].isin(list(value))
            elif isinstance(value, int):
                tmp = self.ds[var] == value
            else:
                raise TypeError(f"{var} should be a list, range or int")
            indexer = tmp if indexer is None else indexer & tmp

        if indexer is None:
            print("WARNING : arguments where empty or could not be interpreted, nothing changed.")
            return self.ds

        indexer = indexer.compute()
        # When all elements of the indexer are "False", calling "where" raises the following error:
        # IndexError: The indexing operation you are attempting to perform is not valid on netCDF4.Variable object.
        # Try loading your data into memory first by calling .load().
        # The array-level reduction is used instead of the builtin any(), which iterates element by element
        # and fails on masks that are not one-dimensional.
        if not bool(indexer.any()):
            print("WARNING : No entry found with the given arguments, returning an empty Dataset")
            return xr.Dataset()
        return self.ds.where(indexer, drop=True)
[docs]
@xr.register_dataset_accessor("distributed")
@xr.register_dataarray_accessor("distributed")
class DistributedAccessor(SnowtoolsAccessor):
    """
    Additional methods in distributed geometry (ex: EDELWEISS).

    It is registered under the name ``distributed`` on both Dataset and DataArray objects.

    Usage example:

    .. code-block:: python

        from snowtools.utils import xarray_snowtools
        import xarray as xr

        ds = xr.open_dataset('INPUT.nc', decode_times=False)
        ds = xarray_snowtools.preprocess(ds)
        ds.distributed.proj("EPSG:4326", "EPSG:2154")
    """

    def proj(self, crs_in="EPSG:4326", crs_out="EPSG:2154"):
        """
        Projection of an xarray dataset or dataarray into a new CRS.

        This method implies a dependency to rioxarray. The spatial dimensions of the wrapped object are
        expected to be named 'xx' and 'yy'. The spatial dimensions and the input CRS are written on the
        wrapped object with ``inplace=True``.

        :param crs_in: CRS of the input object
        :type crs_in: str
        :param crs_out: CRS of the output object
        :type crs_out: str
        :returns: The reprojected object, with its 'x' / 'y' dimensions renamed to 'xx' / 'yy'
        """
        import rioxarray  # noqa
        # TODO extract from rioxarray documentation :
        # "If you use one of xarray’s open methods such as xarray.open_dataset to load netCDF files with the default
        # engine, it is recommended to use decode_coords="all". This will load the grid mapping variable into
        # coordinates for compatibility with rioxarray."
        # TODO : check crs_in ?
        self.ds.rio.set_spatial_dims(x_dim='xx', y_dim='yy', inplace=True)
        self.ds.rio.write_crs(crs_in, inplace=True)
        out = self.ds.rio.reproject(crs_out).rename(x='xx', y='yy')
        return out

    def plot_ensemble(self, variable=None, vmin=None, vmax=None, cmap=None, dem=None, isolevels=None,
                      members: Union[str, int] = 'all', projection=None):
        """
        Plot field(s) from an ensemble. To control the members of the ensemble to be plotted, use the "members"
        argument.

        The dataset must have a 'member' dimension, and the spatial dimensions ('xx', 'yy'). This implies that
        the selection of the time step must be done before calling the method.

        :param variable: Variable name to plot (Dataset only)
        :type variable: str
        :param vmin: Min colorbar value (defaults to the minimum of the plotted data).
        :type vmin: float
        :param vmax: Max colorbar value (defaults to the maximum of the plotted data).
        :type vmax: float
        :param cmap: Matplotlib colormap
        :type cmap: str
        :param dem: Digital Elevation Model covering the dataset area.
        :type dem: DataArray
        :param isolevels: List of iso-levels to plot
        :type isolevels: list
        :param members: How to plot the ensemble data:
            'all': Plot all ensemble members on the same figure;
            'mean': Plot the mean ensemble field;
            int: Plot the given ensemble member only.
        :type members: str or int
        :param projection: Forwarded to ``plot2D.plot_field`` for single-field plots ('mean' or int); grid lines
            are requested whenever it is not None. Not used when members='all'.
        :returns: The matplotlib figure when members='all', otherwise the value returned by
            ``plot2D.plot_field``
        :raises ValueError: If the variable is missing or unknown (Dataset), if "members" cannot be
            interpreted, or if there is no member to plot with members='all'
        :raises AttributeError: If the dimensions are not exactly ('xx', 'yy', 'member')
        """
        if isinstance(self.ds, xr.Dataset):
            if variable is None:
                raise ValueError("A variable name should be provided")
            elif variable not in list(self.ds.keys()):
                raise ValueError(f"Variable {variable} not in Dataset")
            else:
                ensemble = self.ds[variable]
        else:
            ensemble = self.ds

        if 'member' not in self.ds.dims:
            raise AttributeError("The 'member' dimension is missing")
        if not set(list(self.ds.dims)) == set(['xx', 'yy', 'member']):
            raise AttributeError("The dimensions of the dataset must be exactly ('xx', 'yy', 'member')")

        # Plotting dependencies are only imported once the input has been validated
        import matplotlib.pyplot as plt
        from snowtools.plots.maps import plot2D

        if members == 'all':
            ensemble.load()
            if vmin is None:
                vmin = ensemble.min().data
            if vmax is None:
                vmax = ensemble.max().data

            # NOTE(review): the first member is skipped, as in the original implementation — presumably a
            # control / deterministic member; to be confirmed.
            plotted = ensemble.member.data[1:]
            if len(plotted) == 0:
                raise ValueError("No ensemble member to plot (the first member is not plotted)")

            # Historical layout is a 4x4 grid (up to 16 panels); rows are only added beyond that, so that
            # larger ensembles no longer overflow the axes array.
            ncols = 4
            nrows = max(4, (len(plotted) + ncols - 1) // ncols)
            fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(22, 15 * nrows / 4))

            # Panels are filled row by row, which is the order of the flattened axes array
            for axis, mb in zip(ax.flatten(), plotted):
                tmp = ensemble.sel({'member': mb})
                im = plot2D.plot_field(tmp, ax=axis, vmin=vmin, vmax=vmax, cmap=cmap, dem=dem,
                                       isolevels=isolevels, add_colorbar=False)

            # Strip decorations from every panel (including unused ones)
            for axis in ax.flatten():
                axis.margins(0.02)
                axis.set_title('')
                axis.set_xticks([])
                axis.set_yticks([])
                axis.set_xlabel('')
                axis.set_ylabel('')

            # Single colorbar on the right-hand side, shared by all panels (same vmin / vmax everywhere)
            fig.subplots_adjust(left=0.01, top=0.99, bottom=0.01, right=0.85, wspace=0.02, hspace=0.02)
            cax = fig.add_axes([0.86, 0.02, 0.05, 0.96])
            cb = fig.colorbar(im, cax=cax)
            # cb.ax.tick_params(labelsize=20)
            cb.set_label(ensemble.name, size=24)
            return fig
        elif members == 'mean':
            field = ensemble.mean(dim='member')
        elif isinstance(members, int):
            field = ensemble.sel(member=members)
        else:
            raise ValueError(f'Could not interpret argument members ({members}). '
                             'Should be an interger, "mean" or "all".')

        field.load()
        if vmin is None:
            vmin = field.min().data
        if vmax is None:
            vmax = field.max().data
        gridlines = projection is not None
        im = plot2D.plot_field(field, vmin=vmin, vmax=vmax, cmap=cmap, dem=dem,
                               isolevels=isolevels, gridlines=gridlines, projection=projection)
        return im