Functions for plotting and mapping
This is a Markdown rendering of the `plotting_utils` module used in the notebooks. It is provided for user reference and may not reflect changes made to the code after January 2026. The code can be viewed and downloaded from the GitHub repository.
""" plotting_utils.py
Helper functions for generating maps and plots
"""
import xarray as xr
import numpy as np
import numpy.ma as ma
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from textwrap import wrap
import hvplot.xarray
import holoviews as hv
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from cartopy.mpl.geoaxes import GeoAxes
# Replace cartopy's GeoAxes `_pcolormesh_patched` hook with plain matplotlib Axes.pcolormesh.
# The original author notes this "helps avoid some weird issues with the polar projection".
# NOTE(review): presumably this bypasses cartopy's extra handling of cells that wrap across
# the map boundary; `_pcolormesh_patched` is a private cartopy attribute, so confirm it still
# exists (and is still needed) for the installed cartopy version.
GeoAxes._pcolormesh_patched = Axes.pcolormesh
def get_winter_data(da, year_start=None, start_month="Sep", end_month="Apr", force_complete_season=False):
    """ Select data for the winter season that starts in a given year.

    The season runs from the first day of `start_month` in `year_start` to the first day of
    `end_month` in the following calendar year (both ends inclusive). Only time steps that fall
    exactly on those month-start dates are selected.

    Args:
        da (xr.Dataset or xr.DataArray): data to restrict by time; must contain "time" as a coordinate
        year_start (str or int, optional): year in which the season starts; for Sep 2019 - Apr 2020,
            set year_start="2019" (default to the year of the first time step in the dataset)
        start_month (str, optional): abbreviated name of the first month in winter (default to "Sep")
        end_month (str, optional): abbreviated name of the last month in winter, taken in the calendar
            year after start_month (default to "Apr")
        force_complete_season (bool, optional): if True, return None unless every month of the season is
            present in da; e.g. if Sep and Oct are missing, return None even if Nov-Apr are present
            (default to False)

    Returns:
        da_winter (xr.Dataset, xr.DataArray, or None): da restricted to the winter months it contains,
            or None when no month of the season is present (or, with force_complete_season, when any
            month is missing)
    """
    if year_start is None:
        print("No start year specified. Getting winter data for first year in the dataset")
        year_start = str(pd.to_datetime(da.time.values[0]).year)

    start_timestep = start_month + " " + str(year_start)  # e.g. "Sep 2019"
    end_timestep = end_month + " " + str(int(year_start) + 1)  # e.g. "Apr 2020"
    # Month-start timestamps spanning the whole season
    winter = pd.date_range(start=start_timestep, end=end_timestep, freq="MS")

    time_values = da.time.values  # look the coordinate up once rather than once per month
    months_in_da = [mon for mon in winter if mon in time_values]

    if not months_in_da:
        return None
    # `winter` holds unique months, so the season is complete exactly when every month was found
    if force_complete_season and len(months_in_da) < len(winter):
        return None
    return da.sel(time=months_in_da)
def compute_gridcell_winter_means(da, years=None, start_month="Nov", end_month="Apr", force_complete_season=False):
    """ Compute per-gridcell winter means over the time dimension. Useful for plotting as the grid is maintained.

    Args:
        da (xr.Dataset or xr.DataArray): data to average; must contain "time" as a coordinate
        years (list of str, optional): start years of the winters to average; for Nov 2019 - Apr 2020
            use "2019" (default to unique years in the dataset)
        start_month (str, optional): first month in winter (default to "Nov")
        end_month (str, optional): last month in winter, taken in the calendar year after start_month
            (default to "Apr")
        force_complete_season (bool, optional): skip any winter that is missing one or more months
            (default to False)

    Returns:
        merged (xr.DataArray): winter means stacked along a "time" dimension whose labels are strings such as
            "Nov 2019 - Apr 2020" (first and last month actually averaged); when da is a Dataset only its
            first data variable is returned

    Raises:
        ValueError: if none of the requested winters has any data
    """
    if years is None:
        years = np.unique(pd.to_datetime(da.time.values).strftime("%Y"))  # Unique years in the dataset

    winter_means = []
    for year in years:
        da_winter_i = get_winter_data(da, year_start=year, start_month=start_month, end_month=end_month,
                                      force_complete_season=force_complete_season)
        if da_winter_i is None:
            continue
        da_mean_i = da_winter_i.mean(dim="time", keep_attrs=True)
        # Label the mean with the first and last month that were actually averaged
        time_arr = pd.to_datetime(da_winter_i.time.values)
        label = time_arr[0].strftime("%b %Y") + " - " + time_arr[-1].strftime("%b %Y")
        winter_means.append(da_mean_i.assign_coords({"time": label}).expand_dims("time"))

    if not winter_means:
        raise ValueError("No winter data found for the requested years; cannot compute winter means")

    # xr.concat (unlike xr.merge) accepts unnamed DataArrays and keeps the winters in the order computed
    merged = xr.concat(winter_means, dim="time")
    if isinstance(merged, xr.Dataset):
        merged = merged[list(merged.data_vars)[0]]  # Convert to DataArray
    merged.time.attrs["description"] = "Time period over which mean was computed"  # Add descriptive attribute
    return merged
def staticArcticMaps(da, title=None, dates=None, out_str="out", cmap="viridis", col=None, col_wrap=3, vmin=None, vmax=None, set_cbarlabel='', min_lat=50, savefig=True):
    """ Show data on a basemap of the Arctic. Can be one month or multiple months of data.

    When the facet coordinate has more than one value an xarray facet grid is created. For more info, see:
    https://docs.xarray.dev/en/stable/user-guide/plotting.html

    Args:
        da (xr.DataArray): data to plot; must have "longitude" and "latitude" coordinates
        title (str, optional): title string for plot (axes title for a single panel, figure suptitle otherwise)
        dates (list of str, optional): subtitles for each panel in panel order (default to xarray's titles)
        out_str (str, optional): suffix of the output filename ./figs/maps_<out_str>.png
        cmap (str, optional): colormap to use (default to viridis)
        col (str, optional): coordinate to use for creating facet plot (default to "time")
        col_wrap (int, optional): number of columns of plots to display (default to 3; ignored for a single panel)
        vmin (float, optional): minimum on colorbar (default to 1st percentile of da)
        vmax (float, optional): maximum on colorbar (default to 99th percentile of da)
        set_cbarlabel (str, optional): colorbar label (default to "long_name [units]" from da.attrs, when present)
        min_lat (float, optional): minimum latitude to set extent of plot (default to 50 deg lat)
        savefig (bool, optional): save the figure to ./figs and close it so it is not displayed again (default to True)

    Returns:
        fig (matplotlib.figure.Figure): the figure containing the map(s)
    """
    import os

    dates = [] if dates is None else dates

    # Default colour limits: 1st/99th percentiles so outliers don't stretch the colorbar
    if vmin is None:
        vmin = np.nanpercentile(da.values, 1)
    if vmax is None:
        vmax = np.nanpercentile(da.values, 99)

    # Facet over `col` (default "time") only when it has more than one value, so the function also works
    # for data without a time coordinate or with a different facet coordinate
    if col is None:
        col = "time"
        if col not in da.coords:  # da["time"] raises KeyError when missing, so test membership instead
            da = da.assign_coords({col: "unknown"})
    if da[col].size <= 1:
        col = None
        col_wrap = None  # a single panel has nothing to wrap

    if not set_cbarlabel:
        long_name = da.attrs.get("long_name", da.name if da.name is not None else "")
        units = da.attrs.get("units")
        set_cbarlabel = str(long_name) if units is None else f"{long_name} [{units}]"

    im = da.plot(x="longitude", y="latitude", col_wrap=col_wrap, col=col, transform=ccrs.PlateCarree(), cmap=cmap, zorder=8,
                 cbar_kwargs={'pad': 0.02, 'shrink': 0.8, 'extend': 'both', 'label': set_cbarlabel, 'location': 'left'},
                 vmin=vmin, vmax=vmax,
                 subplot_kws={'projection': ccrs.NorthPolarStereo(central_longitude=-45)})

    # A facet grid exposes an array of axes; a single plot returns an artist whose .axes is one Axes,
    # which np.asarray wraps in a 0-d object array
    ax_array = np.asarray(im.axes)
    single_panel = ax_array.ndim == 0
    for i, ax in enumerate(ax_array.flatten()):
        ax.coastlines(linewidth=0.15, color='black', zorder=10)  # Coastlines
        ax.add_feature(cfeature.LAND, color='0.95', zorder=5)  # Land
        ax.add_feature(cfeature.LAKES, color='grey', zorder=5)  # Lakes
        ax.gridlines(draw_labels=False, linewidth=0.25, color='gray', alpha=0.7, linestyle='--', zorder=6)  # Gridlines
        ax.set_extent([-179, 179, min_lat, 90], crs=ccrs.PlateCarree())  # Zoom in on the Arctic
        if dates:
            if i < len(dates):
                ax.set_title(dates[i], fontsize=10, horizontalalignment="center", verticalalignment="bottom",
                             x=0.5, y=1.01, fontweight='medium')
            else:
                print('no date')

    fig = ax_array.flat[0].figure

    if title is not None:
        if single_panel:
            ax_array.flat[0].set_title(title, fontsize=10, horizontalalignment="center", x=0.5, y=1.06, fontweight='medium')
        else:
            fig.suptitle(title, fontsize=10, horizontalalignment="center", x=0.5, y=1.06, fontweight='medium')

    if savefig:
        os.makedirs('./figs', exist_ok=True)  # savefig fails if the output directory does not exist
        fig.savefig('./figs/maps_' + out_str + '.png', dpi=300, facecolor="white", bbox_inches='tight')
        plt.close(fig)  # Close so it doesn't automatically display in notebook
    return fig
def staticArcticMaps_2025(da, title=None, dates=None, out_str="out", cmap="viridis", col=None, vmin=None, vmax=None, set_cbarlabel='', min_lat=None, savefig=True):
    """ Show data on a basemap of the Arctic with a special 2025 layout for 7 winters.

    The first 6 winters are drawn as regular panels (2 rows x 3 columns) and the 7th winter as a
    larger panel spanning 2 rows and 2 columns on the right, with the colorbar inset in its lower left.

    Args:
        da (xr.DataArray): data to plot; needs at least 7 values along `col` (only the first 7 are drawn).
            If `col` has a single value, the same data is drawn in every panel.
        title (str, optional): figure suptitle
        dates (list of str, optional): subtitles for the panels in order (the 7th goes on the large panel)
        out_str (str, optional): suffix of the output filename ./figs/maps_<out_str>.png
        cmap (str, optional): colormap to use (default to viridis)
        col (str, optional): coordinate that indexes the panels (default to "time")
        vmin (float, optional): minimum on colorbar (default to 1st percentile of da)
        vmax (float, optional): maximum on colorbar (default to 99th percentile of da)
        set_cbarlabel (str, optional): colorbar label (default to "long_name [units]" from da.attrs, when present)
        min_lat (float, optional): minimum latitude of the map extent (default to 54 deg lat, the extent this layout
            has always used)
        savefig (bool, optional): save the figure to ./figs and close it so it is not displayed again (default to True)

    Returns:
        fig (matplotlib.figure.Figure): the figure containing the maps

    Raises:
        ValueError: if `col` has more than one but fewer than 7 values
    """
    import os
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    n_small = 6  # regular panels, laid out 2 rows x 3 columns
    n_panels = n_small + 1  # plus the large panel for the 7th winter

    dates = [] if dates is None else dates
    min_lat = 54 if min_lat is None else min_lat

    # Default colour limits: 1st/99th percentiles so outliers don't stretch the colorbar
    if vmin is None:
        vmin = np.nanpercentile(da.values, 1)
    if vmax is None:
        vmax = np.nanpercentile(da.values, 99)

    if col is None:
        col = "time"
        if col not in da.coords:  # da["time"] raises KeyError when missing, so test membership instead
            da = da.assign_coords({col: "unknown"})
    if da[col].size <= 1:
        col = None  # single field: it is repeated in every panel
    elif da[col].size < n_panels:
        raise ValueError(f"staticArcticMaps_2025 needs at least {n_panels} values along '{col}', got {da[col].size}")

    if not set_cbarlabel:
        long_name = da.attrs.get("long_name", da.name if da.name is not None else "")
        units = da.attrs.get("units")
        set_cbarlabel = str(long_name) if units is None else f"{long_name} [{units}]"

    fig = plt.figure(figsize=(15, 6))
    # 2 x 5 grid: columns 0-2 hold the six regular panels, columns 3-4 (both rows) hold the large panel
    gs = fig.add_gridspec(2, 5, width_ratios=[1, 1, 1, 1, 1], height_ratios=[1, 1])

    def _draw_panel(ax, index, title_y):
        """Plot panel `index` of da on `ax`, add map features and its subtitle; return the plotted mesh."""
        data_to_plot = da.isel({col: index}) if col is not None else da
        mesh = data_to_plot.plot(ax=ax, x="longitude", y="latitude", transform=ccrs.PlateCarree(),
                                 cmap=cmap, zorder=8, vmin=vmin, vmax=vmax, add_colorbar=False)
        ax.coastlines(linewidth=0.15, color='black', zorder=10)
        ax.add_feature(cfeature.LAND, color='0.95', zorder=5)
        ax.add_feature(cfeature.LAKES, color='grey', zorder=5)
        ax.gridlines(draw_labels=False, linewidth=0.25, color='gray', alpha=0.7, linestyle='--', zorder=6)
        ax.set_extent([-179, 179, min_lat, 90], crs=ccrs.PlateCarree())
        if len(dates) > index:
            ax.set_title(dates[index], fontsize=10, horizontalalignment="center", verticalalignment="bottom",
                         x=0.5, y=title_y, fontweight='medium')
        return mesh

    for i in range(n_small):
        ax = fig.add_subplot(gs[i // 3, i % 3], projection=ccrs.NorthPolarStereo(central_longitude=0))
        _draw_panel(ax, i, title_y=0.97)

    ax_large = fig.add_subplot(gs[:, 3:], projection=ccrs.NorthPolarStereo(central_longitude=0))
    im_large = _draw_panel(ax_large, n_small, title_y=0.99)

    # Colorbar inset in the lower left of the large panel so it doesn't shift the panel layout
    cbar_ax = inset_axes(ax_large, width="30%", height="3%", loc='lower left')
    cbar = fig.colorbar(im_large, cax=cbar_ax, orientation='horizontal', extend='both')
    cbar.set_label(set_cbarlabel, fontsize=10, labelpad=10)
    cbar.ax.xaxis.set_ticks_position('top')
    cbar.ax.xaxis.set_label_position('top')
    cbar.set_ticks(np.linspace(vmin, vmax, 6))  # six evenly spaced ticks spanning [vmin, vmax]

    if title is not None:
        fig.suptitle(title, fontsize=12, horizontalalignment="center", x=0.5, y=0.95, fontweight='medium')

    fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.02, wspace=0.03, hspace=0.05)

    if savefig:
        os.makedirs('./figs', exist_ok=True)  # savefig fails if the output directory does not exist
        fig.savefig('./figs/maps_' + out_str + '.png', dpi=300, facecolor="white", bbox_inches='tight')
        plt.close(fig)  # Close so it doesn't automatically display in notebook
    return fig
def staticArcticMaps_overlayDrifts(da, drifts_x, drifts_y, alpha=1, vector_val=0.1, scale_vec=0.5, res=6, units_vec=r'm s$^{-1}$', title=None, out_str="out", dates=None, cmap="viridis", col=None, col_wrap=3, vmin=None, vmax=None, set_cbarlabel='', min_lat=50, savefig=True, figsize=(6,6)):
    """ Show data on a basemap of the Arctic with drift vectors overlaid. Can be one month or multiple months of data.

    When the facet coordinate has more than one value an xarray facet grid is created. For more info, see:
    https://docs.xarray.dev/en/stable/user-guide/plotting.html

    Args:
        da (xr.DataArray): data to plot; must have "longitude", "latitude" and "time" coordinates
        drifts_x (xr.DataArray): along-x component of the ice motion; must have the same time values as da and an
            `xgrid` coordinate used as the vector x positions
        drifts_y (xr.DataArray): along-y component of the ice motion; must have the same time values as da and a
            `ygrid` coordinate used as the vector y positions
        alpha (float 0-1, optional): opacity of da (default to 1); values > 1 are reset to 1, values < 0 to 0.5
        vector_val (float, optional): length of the reference arrow shown in the quiver key (default to 0.1)
        scale_vec (float, optional): `scale` passed to matplotlib quiver with units='inches' (default to 0.5)
        res (int, optional): plot 1 out of every `res` vectors along each grid axis (default to 6)
        units_vec (str, optional): units shown in the quiver key label (default to m s^-1)
        title (str, optional): title string for plot (axes title for a single panel, figure suptitle otherwise)
        out_str (str, optional): suffix of the output filename ./figs/maps_<out_str>.png
        dates (list of str, optional): subtitles for each panel in panel order
        cmap (str, optional): colormap to use (default to viridis)
        col (str, optional): coordinate to use for creating facet plot (default to "time")
        col_wrap (int, optional): number of columns of plots to display (default to 3; ignored for a single panel)
        vmin (float, optional): minimum on colorbar (default to 1st percentile of da)
        vmax (float, optional): maximum on colorbar (default to 99th percentile of da)
        set_cbarlabel (str, optional): colorbar label (default to "long_name [units]" from da.attrs, when present)
        min_lat (float, optional): minimum latitude to set extent of plot (default to 50 deg lat)
        savefig (bool, optional): save the figure to ./figs and close it so it is not displayed again (default to True)
        figsize (tuple, optional): currently unused; kept for backward compatibility

    Returns:
        fig (matplotlib.figure.Figure): the figure containing the map(s)

    Raises:
        ValueError: if drifts_x or drifts_y do not have exactly the same time values as da
    """
    import os

    dates = [] if dates is None else dates

    # Make sure alpha is between 0 and 1
    if alpha > 1:
        print("Argument alpha must be between 0 and 1. You inputted " + str(alpha) + ". Setting alpha to 1.")
        alpha = 1
    elif alpha < 0:
        print("Argument alpha must be between 0 and 1. You inputted " + str(alpha) + ". Setting alpha to 0.5.")
        alpha = 0.5
    elif alpha == 0:
        print("You set alpha=0. This indicates full transparency of the input data. No data will be displayed on the map.")

    # Each panel's vectors are picked by time index, so the time axes must match exactly
    # (np.array_equal also handles scalar times and length mismatches, where `==` would not)
    for drift in [drifts_x, drifts_y]:
        if not np.array_equal(da.time.values, drift.time.values):
            raise ValueError("Drifts vectors and input DataArray must have the same time coordinates")

    # Default colour limits: 1st/99th percentiles so outliers don't stretch the colorbar
    if vmin is None:
        vmin = np.nanpercentile(da.values, 1)
    if vmax is None:
        vmax = np.nanpercentile(da.values, 99)

    # Facet over `col` (default "time") only when it has more than one value
    if col is None:
        col = "time"
        if col not in da.coords:  # da["time"] raises KeyError when missing, so test membership instead
            da = da.assign_coords({col: "unknown"})
    if da[col].size <= 1:
        col = None
        col_wrap = None  # a single panel has nothing to wrap

    if not set_cbarlabel:
        long_name = da.attrs.get("long_name", da.name if da.name is not None else "")
        units = da.attrs.get("units")
        set_cbarlabel = str(long_name) if units is None else f"{long_name} [{units}]"

    im = da.plot(x="longitude", y="latitude", col_wrap=col_wrap, col=col, transform=ccrs.PlateCarree(), cmap=cmap,
                 cbar_kwargs={'pad': 0.02, 'shrink': 0.8, 'extend': 'both', 'label': set_cbarlabel},
                 vmin=vmin, vmax=vmax, zorder=2, alpha=alpha,
                 subplot_kws={'projection': ccrs.NorthPolarStereo(central_longitude=-45)})

    # A facet grid exposes an array of axes; a single plot returns an artist whose .axes is one Axes
    ax_array = np.asarray(im.axes)
    single_panel = ax_array.ndim == 0
    num_timesteps = da.time.size  # a scalar time coordinate counts as one time step

    for i, ax in enumerate(ax_array.flatten()[:num_timesteps]):
        # This panel's 2-D drift field; select by time only when time is an actual dimension
        # (a length-1 time dimension would otherwise be sliced as if it were a spatial axis)
        drifts_xi = drifts_x.isel(time=i) if "time" in drifts_x.dims else drifts_x
        drifts_yi = drifts_y.isel(time=i) if "time" in drifts_y.dims else drifts_y
        u = drifts_xi.values[::res, ::res]
        v = drifts_yi.values[::res, ::res]
        Q = ax.quiver(drifts_x.xgrid[::res, ::res], drifts_y.ygrid[::res, ::res],
                      ma.masked_where(np.isnan(u), u), ma.masked_where(np.isnan(v), v),
                      units='inches', scale=scale_vec, zorder=10)
        ax.quiverkey(Q, 0.85, 0.88, vector_val, str(vector_val) + ' ' + units_vec, coordinates='axes', zorder=11)
        ax.coastlines(linewidth=0.15, color='black', zorder=8)  # Coastlines
        ax.add_feature(cfeature.LAND, color='0.95', zorder=5)  # Land
        ax.add_feature(cfeature.LAKES, color='grey', zorder=5)  # Lakes
        ax.gridlines(draw_labels=False, linewidth=0.25, color='gray', alpha=0.7, linestyle='--', zorder=6)  # Gridlines
        ax.set_extent([-179, 179, min_lat, 90], crs=ccrs.PlateCarree())  # Zoom in on the Arctic
        if i < len(dates):
            ax.set_title(dates[i], fontsize=10, horizontalalignment="center", verticalalignment="bottom",
                         x=0.5, y=1.01, fontweight='medium')

    fig = ax_array.flat[0].figure

    if title is not None:
        if single_panel:
            ax_array.flat[0].set_title(title, fontsize=10, horizontalalignment="center", x=0.45, y=1.06, fontweight='medium')
        else:
            fig.suptitle(title, fontsize=10, horizontalalignment="center", x=0.45, y=1.06, fontweight='medium')

    if savefig:
        os.makedirs('./figs', exist_ok=True)  # savefig fails if the output directory does not exist
        fig.savefig('./figs/maps_' + out_str + '.png', dpi=400, facecolor="white", bbox_inches='tight')
        plt.close(fig)  # Close so it doesn't automatically display in notebook
    return fig
def interactiveArcticMaps(da, clabel=None, cmap="viridis", colorbar=True, vmin=None, vmax=None, title="", ylim=(60,90), frame_width=500, slider=True, cols=3):
    """ Generate one or more interactive maps.

    With the argument `slider`, the user chooses whether multiple maps are shown with a slider
    (slider=True) or all together in a grid layout (slider=False).

    Args:
        da (xr.Dataset or xr.DataArray): data
        clabel (str, optional): colorbar label (default to "long_name" and "units" if given in attributes of da)
        cmap (str, optional): matplotlib colormap to use (default to "viridis")
        colorbar (bool, optional): show colorbar? (default to True)
        vmin (float, optional): minimum on colorbar (default to 1st percentile)
        vmax (float, optional): maximum on colorbar (default to 99th percentile)
        title (str, optional): main title to give plot; an empty string keeps the per-map titles that hvplot
            generates (default to no title)
        ylim (tuple, optional): limits of yaxis in the form min latitude, max latitude (default to (60,90))
        frame_width (int, optional): width of frame; sets figure size of each map (default to 500)
        slider (bool, optional): if da has more than one time coordinate, display maps with a slider? (default to True)
        cols (int, optional): how many columns to show before wrapping when slider=False (default to 3)

    Returns:
        pl (Holoviews map)
    """
    # Default colour limits: 1st/99th percentiles so outliers don't stretch the colorbar
    if vmin is None:
        vmin = np.nanpercentile(da.values, 1)
    if vmax is None:
        vmax = np.nanpercentile(da.values, 99)

    # https://hvplot.holoviz.org/user_guide/Subplots.html
    # Several time steps -> one map per time step sharing axes
    multiple_times = ("time" in da.coords) and (da["time"].size > 1)
    subplots = multiple_times
    shared_axes = multiple_times

    if clabel is None and ("long_name" in da.attrs):  # Add a logical colorbar label
        clabel = da.attrs["long_name"]
        if "units" in da.attrs:
            clabel += " (" + da.attrs["units"] + ")"

    pl = da.hvplot.quadmesh(y="latitude", x="longitude",
                            projection=ccrs.NorthPolarStereo(central_longitude=-45),
                            features=["coastline"],  # Add coastlines
                            colorbar=colorbar, clim=(vmin, vmax), cmap=cmap, clabel=clabel,  # Colorbar settings
                            project=True, ylim=ylim, frame_width=frame_width,
                            subplots=subplots, shared_axes=shared_axes,
                            dynamic=False, rasterize=True)

    # Lay the maps out in a grid instead of a slider; only multi-map containers can be laid out
    if not slider and hasattr(pl, "layout"):
        pl = pl.layout().cols(cols)

    # Apply the title only when one was given: an empty title would replace the per-map titles
    # (e.g. the time shown on each slider frame)
    if title:
        pl.opts(title=title)

    hv.output(widget_location="bottom")
    return pl
def interactive_winter_mean_maps(da, years=None, end_year=None, start_month="Sep", end_month="Apr", force_complete_season=False, clabel=None, cmap="viridis", colorbar=True, vmin=0, vmax=4, title="", ylim=(60,90), frame_width=250, slider=True, cols=3):
    """ Generate interactive maps of per-gridcell winter mean data.

    One mean field per winter is computed with compute_gridcell_winter_means, and the resulting
    stack is drawn with interactiveArcticMaps.

    Args:
        da (xr.Dataset or xr.DataArray): data; must contain "time" coordinate
        years (list of str, optional): start years of the winters to average (default to unique years in the dataset)
        end_year: not used by this function
        start_month (str, optional): first month in winter (default to "Sep")
        end_month (str, optional): last month in winter, taken in the calendar year after start_month (default to "Apr")
        force_complete_season (bool, optional): skip any winter that is missing one or more months (default to False)
        clabel (str, optional): colorbar label (default to "long_name" and "units" if given in attributes of da)
        cmap (str, optional): matplotlib colormap to use (default to "viridis")
        colorbar (bool, optional): show colorbar? (default to True)
        vmin (float, optional): minimum on colorbar (default to 0)
        vmax (float, optional): maximum on colorbar (default to 4)
        title (str, optional): main title to give plot (default to no title)
        ylim (tuple, optional): latitude limits of the y axis as (min, max) (default to (60,90))
        frame_width (int, optional): width of each map frame (default to 250)
        slider (bool, optional): if there is more than one winter, display maps with a slider? (default to True)
        cols (int, optional): how many columns to show before wrapping when slider=False (default to 3)

    Returns:
        pl_means (Holoviews map)
    """
    season_kwargs = dict(years=years, start_month=start_month, end_month=end_month,
                         force_complete_season=force_complete_season)
    map_kwargs = dict(clabel=clabel, cmap=cmap, colorbar=colorbar,
                      vmin=vmin, vmax=vmax, title=title,
                      ylim=ylim, frame_width=frame_width, slider=slider, cols=cols)

    means_by_winter = compute_gridcell_winter_means(da, **season_kwargs)
    pl_means = interactiveArcticMaps(means_by_winter, **map_kwargs)
    hv.output(widget_location="bottom")
    return pl_means
def static_winter_comparison_lineplot(da, da_unc=None, years=None, figsize=(5,3), start_month="Sep",
                                      end_month="Apr", title="", set_ylabel='', set_units='', legend=True, savefig=True, save_label='',
                                      annotation='', force_complete_season=False, loc_pos=0, fmts=('mo-.', 'cs-.', 'yv-.', 'k*-', 'r.-', 'gD--', 'b-.'),
                                      reanalysis_option=None):
    """ Make a lineplot with markers comparing monthly mean data across winter seasons.

    For each winter, da is averaged over the "x" and "y" dimensions and plotted against month
    abbreviations, so different winters overlay on the same x-axis.

    Args:
        da (xr.DataArray or xr.Dataset): data to plot and compute mean for; must contain "time" as a coordinate
            (a Dataset is only useful together with reanalysis_option)
        da_unc (xr.DataArray, optional): uncertainty data, drawn as a shaded band of +/- its domain mean
        years (list of str or int): start years of the winters to plot; 2020 corresponds to start month 2020 -
            end month 2021 (default to all unique years in da)
        figsize (tuple, optional): figure size to display in notebook (default to (5,3))
        start_month (str, optional): first month in winter (default to "Sep")
        end_month (str, optional): last month in winter, taken in the calendar year after start_month (default to "Apr")
        title (str, optional): title to give plot (default to no title)
        set_ylabel (str, optional): y label string (default to "long_name (units)" from da.attrs)
        set_units (str, optional): currently unused
        legend (bool, optional): show legend (default to True)
        savefig (bool, optional): save the figure as a PDF in ./figs (default to True)
        save_label (str, optional): additional string for the output filename
        annotation (str, optional): text placed in the top left corner of the axes
        force_complete_season (bool, optional): skip any winter that is missing one or more months (default to False)
        loc_pos (int, optional): matplotlib legend location code when greater than 0, otherwise "best"
        fmts (sequence of str, optional): format strings cycled across years; the first character of each is also
            used as the uncertainty band colour
        reanalysis_option (str, optional): which reanalysis snow depth variable to take from da ('m2' or 'e5');
            if None, da is used as given

    Returns:
        None; the figure is displayed with plt.show()
    """
    import itertools
    import os

    if years is None:
        years = np.unique(pd.to_datetime(da.time.values).strftime("%Y"))  # Unique years in the dataset
        print("No years specified. Using " + ", ".join(years))

    # Optionally swap in a reanalysis-specific snow depth variable (first matching name wins)
    if reanalysis_option is not None:
        reanalysis_vars = {
            'm2': ('snow_depth_sm_m2', 'snow_depth_sm_m2_int'),
            'e5': ('snow_depth_sm_e5', 'snow_depth_sm_e5_int'),
        }
        option = reanalysis_option.lower()
        if option not in reanalysis_vars:
            print(f"Warning: Invalid reanalysis option '{reanalysis_option}'. Using default snow depth.")
        else:
            for var_name in reanalysis_vars[option]:
                if hasattr(da, var_name):
                    da = getattr(da, var_name)
                    break
            else:
                print(f"Warning: {option.upper()} snow depth variable not found in dataset. Using default snow depth.")

    # Preset the categorical x-axis with every winter month (in order), even when some months have no data.
    # "MS" (month start) keeps end_month itself in the range; month-end frequency would drop it.
    yr = 2000
    yr_end = yr if end_month in ["Oct", "Nov", "Dec"] else yr + 1
    xaxis_months = pd.date_range(start_month + "-" + str(yr), end_month + "-" + str(yr_end), freq="MS").strftime("%b")

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(xaxis_months, np.full(len(xaxis_months), np.nan), color=None, label=None)  # invisible line sets the axis

    # `visible` replaced the `b` keyword in newer matplotlib; try the modern spelling first
    try:
        ax.grid(visible=True, linestyle='-', alpha=0.2)
    except (TypeError, ValueError, AttributeError):
        try:
            ax.grid(b=True, linestyle='-', alpha=0.2)
        except (TypeError, ValueError, AttributeError):
            print("No gridlines")

    for year, fmt in zip(years, itertools.cycle(fmts)):
        winter_da = get_winter_data(da, year_start=year, start_month=start_month, end_month=end_month,
                                    force_complete_season=force_complete_season)
        if winter_da is None:  # No data for this winter
            continue
        y = winter_da.mean(dim=["x", "y"], keep_attrs=True)
        x = pd.to_datetime(y.time.values)
        label = f"{x.year[0]}-{str(x.year[-1])[2:]}"
        if reanalysis_option is not None:
            label += f" ({reanalysis_option.upper()})"
        ax.plot(x.strftime("%b"), y, fmt, label=label, markersize=4)

        if da_unc is not None:
            winter_da_unc = get_winter_data(da_unc, year_start=year, start_month=start_month, end_month=end_month,
                                            force_complete_season=force_complete_season)
            if winter_da_unc is None:
                continue
            yu = winter_da_unc.mean(dim=["x", "y"], keep_attrs=True)
            ax.fill_between(x.strftime("%b"), y - yu, y + yu, facecolor=fmt[0], alpha=0.1, edgecolor='none')

    if legend:
        ax.legend(fontsize=8, frameon=False, loc=loc_pos if loc_pos > 0 else "best")

    ax.annotate(annotation, xy=(0.02, 0.98), xycoords='axes fraction', horizontalalignment='left',
                verticalalignment='top', fontsize=8, zorder=2)
    ax.set_title(title, fontsize=9)

    if len(set_ylabel) > 0:
        ylabel = set_ylabel
    elif "long_name" in da.attrs:
        ylabel = da.attrs["long_name"]
        if "units" in da.attrs:
            ylabel += " (" + da.attrs["units"] + ")"
        ylabel = "\n".join(wrap(ylabel, 35))
    else:
        ylabel = None
    ax.set_ylabel(ylabel, fontsize=8)
    ax.tick_params(axis='both', which='major', labelsize=8)

    fig.tight_layout()  # reduce white space

    if savefig:
        # Years may be strings (the documented form) or ints; int() makes the "+1" work for both
        season_span = f"{years[0]}-{int(years[-1]) + 1}"
        suffix = f"_{reanalysis_option}" if reanalysis_option is not None else ""
        filename = f'./figs/{da.attrs.get("long_name", "data")}{start_month}{end_month}{season_span}{save_label}{suffix}.pdf'
        os.makedirs('./figs', exist_ok=True)  # savefig fails if the output directory does not exist
        fig.savefig(filename, dpi=300, facecolor="white", bbox_inches='tight')
    plt.show()
def static_winter_comparison_lineplot_with_reanalysis(da, reanalysis_option='m2', da_unc=None, years=None, figsize=(5,3), start_month="Sep",
                                                      end_month="Apr", title="", set_ylabel = '', set_units = '', legend=True, savefig=True, save_label='',
                                                      annotation = '', force_complete_season=False, loc_pos=0, fmts = ['mo-.','cs-.','yv-.','k*-','r.-','gD--','b-.']):
    """ Compare monthly mean data across winter seasons using a reanalysis snow depth variable.

    Convenience wrapper: forwards every argument unchanged to static_winter_comparison_lineplot,
    with `reanalysis_option` defaulting to 'm2' instead of None.

    Args:
        da (xr.DataArray or xr.Dataset): data to plot and compute mean for; must contain "time" as a coordinate
        reanalysis_option (str): which reanalysis snow depth variable to use ('m2' or 'e5'; default to 'm2')
        da_unc (xr.DataArray, optional): uncertainty data to plot as a shaded band
        years (list of str): start years of the winters to plot (default to all unique years in da)
        figsize (tuple, optional): figure size to display in notebook (default to (5,3))
        start_month (str, optional): first month in winter (default to "Sep")
        end_month (str, optional): last month in winter, taken in the calendar year after start_month (default to "Apr")
        title (str, optional): title to give plot (default to no title)
        set_ylabel (str, optional): prescribed y label string
        set_units (str, optional): prescribed y label unit string
        legend (bool): show legend
        savefig (bool): save the figure
        save_label (str, optional): additional string for the output filename
        annotation (str, optional): text placed in the top left corner of the axes
        force_complete_season (bool, optional): skip any winter that is missing one or more months (default to False)
        loc_pos (int, optional): legend location code when greater than 0, otherwise "best"
        fmts (list, optional): format strings for the different years

    Returns:
        Whatever static_winter_comparison_lineplot returns
    """
    forwarded = dict(
        da_unc=da_unc, years=years, figsize=figsize, start_month=start_month,
        end_month=end_month, title=title, set_ylabel=set_ylabel, set_units=set_units,
        legend=legend, savefig=savefig, save_label=save_label, annotation=annotation,
        force_complete_season=force_complete_season, loc_pos=loc_pos, fmts=fmts,
    )
    return static_winter_comparison_lineplot(da, reanalysis_option=reanalysis_option, **forwarded)
def interactive_winter_comparison_lineplot(da, years=None, title="Winter comparison", frame_width=600, frame_height=350, start_month="Sep", end_month="Apr", force_complete_season=False):
    """ Make an interactive lineplot comparing domain-mean data across winter seasons.

    Each winter is averaged over the "x" and "y" dimensions and drawn as one line, grouped by a
    "year" coordinate holding the winter's start year.

    Args:
        da (xr.DataArray): data to plot and compute mean for; must contain "time" as a coordinate
        years (list of str): start years of the winters to plot; 2020 corresponds to start month 2020 -
            end month 2021 (default to all unique years in da)
        title (str, optional): title to give plot (default to "Winter comparison")
        frame_width (int, optional): width of plot frame (default to 600)
        frame_height (int, optional): height of plot frame (default to 350)
        start_month (str, optional): first month in winter (default to "Sep")
        end_month (str, optional): last month in winter, taken in the calendar year after start_month (default to "Apr")
        force_complete_season (bool, optional): skip any winter that is missing one or more months (default to False)

    Returns:
        Interactive hvplot line plot, or None if no winter has data
    """
    if years is None:
        years = np.unique(pd.to_datetime(da.time.values).strftime("%Y"))  # Unique years in the dataset
        print("No years specified. Using " + ", ".join(years))

    winter_data = []
    years_with_data = []  # labels for the winters actually found, kept in step with winter_data
    for year in years:
        winter_da = get_winter_data(da, year_start=year, start_month=start_month, end_month=end_month,
                                    force_complete_season=force_complete_season)
        if winter_da is None:
            continue
        winter_data.append(winter_da.mean(dim=["x", "y"], keep_attrs=True))
        years_with_data.append(year)

    if not winter_data:
        print("No winter data found for the specified years")
        return None

    # Stack winters along a new "year" dimension labelled with the start years that produced data
    # (labelling with years[:n] would mislabel seasons whenever an earlier year was skipped)
    combined_data = xr.concat(winter_data, dim=pd.Index(years_with_data, name="year"))

    return combined_data.hvplot.line(x='time', by='year', title=title, frame_width=frame_width, frame_height=frame_height)