# Source code for parcels.plotting

from datetime import datetime
from datetime import timedelta as delta

import numpy as np
import copy

from parcels.field import Field
from parcels.field import VectorField
from parcels.grid import CurvilinearGrid
from parcels.grid import GridCode
from parcels.tools.statuscodes import TimeExtrapolationError
from parcels.tools.loggers import logger


def plotparticles(particles, with_particles=True, show_time=None, field=None, domain=None, projection=None,
                  land=True, vmin=None, vmax=None, savefile=None, animation=False, **kwargs):
    """Function to plot a Parcels ParticleSet

    :param show_time: Time at which to show the ParticleSet (seconds, datetime, np.datetime64 or timedelta)
    :param with_particles: Boolean whether particles are also plotted on Field
    :param field: Field to plot under particles (either None, a Field object, a VectorField object,
           'vector' for the FieldSet's UV field, or the name of a Field in the FieldSet)
    :param domain: dictionary (with keys 'N', 'S', 'E', 'W') defining domain to show
    :param projection: type of cartopy projection to use (default PlateCarree)
    :param land: Boolean whether to show land. This is ignored for flat meshes
    :param vmin: minimum colour scale (only in single-plot mode)
    :param vmax: maximum colour scale (only in single-plot mode)
    :param savefile: Name of a file to save the plot to
    :param animation: Boolean whether result is a single plot, or an animation
    :raises NotImplementedError: if show_time is a date but the ParticleSet has no time_origin
    """
    # Normalise show_time to seconds relative to the particles' time origin
    show_time = particles[0].time if show_time is None else show_time
    if isinstance(show_time, datetime):
        show_time = np.datetime64(show_time)
    if isinstance(show_time, np.datetime64):
        if not particles.time_origin:
            raise NotImplementedError(
                'If fieldset.time_origin is not a date, showtime cannot be a date in particleset.show()')
        show_time = particles.time_origin.reltime(show_time)
    if isinstance(show_time, delta):
        show_time = show_time.total_seconds()
    if np.isnan(show_time):
        # Fall back to the first value reported by the gridset's full time range
        show_time, _ = particles.fieldset.gridset.dimrange('time_full')

    if field is None:
        U = particles.fieldset.U
        spherical = U.grid.mesh == 'spherical'
        plt, fig, ax, cartopy = create_parcelsfig_axis(spherical, land, projection,
                                                       cartopy_features=kwargs.pop('cartopy_features', []))
        if plt is None:
            return  # creating axes was not possible
        ax.set_title('Particles' + parsetimestr(U.grid.time_origin, show_time))
        latN, latS, lonE, lonW = parsedomain(domain, U)
        grid = U.grid
        if cartopy is None or projection is None:
            if domain is not None:
                if isinstance(grid, CurvilinearGrid):
                    ax.set_xlim(grid.lon[latS, lonW], grid.lon[latN, lonE])
                    ax.set_ylim(grid.lat[latS, lonW], grid.lat[latN, lonE])
                else:
                    ax.set_xlim(grid.lon[lonW], grid.lon[lonE])
                    ax.set_ylim(grid.lat[latS], grid.lat[latN])
            else:
                ax.set_xlim(np.nanmin(grid.lon), np.nanmax(grid.lon))
                ax.set_ylim(np.nanmin(grid.lat), np.nanmax(grid.lat))
        elif domain is not None:
            if isinstance(grid, CurvilinearGrid):
                ax.set_extent([grid.lon[latS, lonW], grid.lon[latN, lonE],
                               grid.lat[latS, lonW], grid.lat[latN, lonE]])
            else:
                ax.set_extent([grid.lon[lonW], grid.lon[lonE], grid.lat[latS], grid.lat[latN]])
    else:
        if field == 'vector':
            field = particles.fieldset.UV
        elif not isinstance(field, (Field, VectorField)):
            # VectorField is accepted as-is too; anything else is treated as a Field name on the FieldSet
            field = getattr(particles.fieldset, field)

        depth_level = kwargs.pop('depth_level', 0)
        plt, fig, ax, cartopy = plotfield(field=field, animation=animation, show_time=show_time, domain=domain,
                                          projection=projection, land=land, vmin=vmin, vmax=vmax, savefile=None,
                                          titlestr='Particles and ', depth_level=depth_level, **kwargs)
        if plt is None:
            return  # creating axes was not possible

    if with_particles:
        plon = np.array([p.lon for p in particles])
        plat = np.array([p.lat for p in particles])
        scatter_kwargs = {'s': 20, 'color': 'black', 'zorder': 20}
        if cartopy:
            # particle positions are given in lon/lat, so tell cartopy their coordinate system
            scatter_kwargs['transform'] = cartopy.crs.PlateCarree()
        ax.scatter(plon, plat, **scatter_kwargs)

    if animation:
        plt.draw()
        plt.pause(0.0001)
    elif savefile is None:
        plt.show()
    else:
        plt.savefig(savefile)
        logger.info('Plot saved to ' + savefile + '.png')
        plt.close()
def plotfield(field, show_time=None, domain=None, depth_level=0, projection=None, land=True,
              vmin=None, vmax=None, savefile=None, **kwargs):
    """Function to plot a Parcels Field

    :param field: Field or VectorField object to plot
    :param show_time: Time at which to show the Field (defaults to the first time of the Field's grid)
    :param domain: dictionary (with keys 'N', 'S', 'E', 'W') defining domain to show
    :param depth_level: depth level to be plotted (default 0)
    :param projection: type of cartopy projection to use (default PlateCarree)
    :param land: Boolean whether to show land. This is ignored for flat meshes
    :param vmin: minimum colour scale (only in single-plot mode)
    :param vmax: maximum colour scale (only in single-plot mode)
    :param savefile: Name of a file to save the plot to
    :param kwargs: 'cartopy_features' (passed to create_parcelsfig_axis) and 'titlestr' (prefix of the
           plot title) are used; other keys (e.g. 'animation') are accepted but not used here
    :return: (plt, fig, ax, cartopy), or (None, None, None, None) if the axes could not be created
    :raises RuntimeError: if field is not a Field/VectorField, or a VectorField's components are not
            on the same grid, or the data slice is not 2D
    :raises TimeExtrapolationError: if show_time lies outside the Field's time range
    """
    if type(field) is VectorField:
        spherical = field.U.grid.mesh == 'spherical'
        field = [field.U, field.V]
        plottype = 'vector'
    elif type(field) is Field:
        spherical = field.grid.mesh == 'spherical'
        field = [field]
        plottype = 'scalar'
    else:
        raise RuntimeError('field needs to be a Field or VectorField object')

    if field[0].grid.gtype in [GridCode.CurvilinearZGrid, GridCode.CurvilinearSGrid]:
        logger.warning('Field.show() does not always correctly determine the domain for curvilinear grids. '
                       'Use plotting with caution and perhaps use domain argument as in the NEMO 3D tutorial')

    plt, fig, ax, cartopy = create_parcelsfig_axis(spherical, land, projection=projection,
                                                   cartopy_features=kwargs.pop('cartopy_features', []))
    if plt is None:
        return None, None, None, None  # creating axes was not possible

    # Extract the 2D slice (and its coordinates) to plot for every component
    data = {}
    plotlon = {}
    plotlat = {}
    for i, fld in enumerate(field):
        show_time = fld.grid.time[0] if show_time is None else show_time
        if fld.grid.defer_load:
            fld.fieldset.computeTimeChunk(show_time, 1)
        (idx, periods) = fld.time_index(show_time)
        # map show_time back into the loaded time range for periodic fields
        show_time -= periods * (fld.grid.time_full[-1] - fld.grid.time_full[0])
        if show_time > fld.grid.time[-1] or show_time < fld.grid.time[0]:
            raise TimeExtrapolationError(show_time, field=fld, msg='show_time')

        latN, latS, lonE, lonW = parsedomain(domain, fld)
        if isinstance(fld.grid, CurvilinearGrid):
            plotlon[i] = fld.grid.lon[latS:latN, lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN, lonW:lonE]
        else:
            plotlon[i] = fld.grid.lon[lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN]
        if i > 0 and not np.allclose(plotlon[i], plotlon[0]):
            raise RuntimeError('VectorField needs to be on an A-grid for plotting')

        if fld.grid.time.size > 1:
            fulldata = np.squeeze(fld.temporal_interpolate_fullfield(idx, show_time))
        else:
            fulldata = np.squeeze(fld.data)
        if fld.grid.zdim > 1:
            data[i] = fulldata[depth_level, latS:latN, lonW:lonE]
        else:
            data[i] = fulldata[latS:latN, lonW:lonE]

    if plottype == 'vector':
        if field[0].interp_method == 'cgrid_velocity':
            logger.warning_once('Plotting a C-grid velocity field is achieved via an A-grid projection, reducing the plot accuracy')
            # U: average neighbouring rows (axis 0); the last row is copied unchanged
            d = np.empty_like(data[0])
            d[:-1, :] = (data[0][:-1, :] + data[0][1:, :]) / 2.
            d[-1, :] = data[0][-1, :]
            data[0] = d
            # V: average neighbouring columns (axis 1); the last column is copied unchanged.
            # This must be computed from V itself (data[1]), not from the averaged U.
            d = np.empty_like(data[1])
            d[:, :-1] = (data[1][:, :-1] + data[1][:, 1:]) / 2.
            d[:, -1] = data[1][:, -1]
            data[1] = d
        spd = data[0] ** 2 + data[1] ** 2
        speed = np.where(spd > 0, np.sqrt(spd), 0)
        vmin = speed.min() if vmin is None else vmin
        vmax = speed.max() if vmax is None else vmax
        ncar_cmap = copy.copy(plt.get_cmap('gist_ncar'))
        ncar_cmap.set_over('k')
        ncar_cmap.set_under('w')
        if isinstance(field[0].grid, CurvilinearGrid):
            x, y = plotlon[0], plotlat[0]
        else:
            x, y = np.meshgrid(plotlon[0], plotlat[0])
        # unit vectors for the arrows (colour carries the speed); zero-speed points get zero arrows,
        # so the division warnings from those points are irrelevant
        with np.errstate(divide='ignore', invalid='ignore'):
            u = np.where(speed > 0., data[0]/speed, 0)
            v = np.where(speed > 0., data[1]/speed, 0)
        if cartopy:
            cs = ax.quiver(np.asarray(x), np.asarray(y), np.asarray(u), np.asarray(v), speed, cmap=ncar_cmap,
                           clim=[vmin, vmax], scale=50, transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.quiver(x, y, u, v, speed, cmap=ncar_cmap, clim=[vmin, vmax], scale=50)
    else:
        vmin = data[0].min() if vmin is None else vmin
        vmax = data[0].max() if vmax is None else vmax
        # plt.get_cmap works on all matplotlib versions (plt.cm.get_cmap was removed in 3.9)
        pc_cmap = copy.copy(plt.get_cmap('viridis'))
        pc_cmap.set_over('k')
        pc_cmap.set_under('w')
        if data[0].ndim != 2:
            raise RuntimeError('Field data to plot must be 2D, got shape %s' % (data[0].shape,))
        # Reduce node data to one value per cell (one fewer row and column than the coordinates)
        if field[0].interp_method == 'cgrid_tracer':
            d = data[0][1:, 1:]
        elif field[0].interp_method == 'cgrid_velocity':
            if field[0].fieldtype == 'U':
                d = (data[0][1:, :-1] + data[0][1:, 1:]) / 2.
            elif field[0].fieldtype == 'V':
                d = (data[0][:-1, 1:] + data[0][1:, 1:]) / 2.
            else:  # W
                d = data[0][1:, 1:]
        else:  # if A-grid
            d = (data[0][:-1, :-1] + data[0][1:, :-1] + data[0][:-1, 1:] + data[0][1:, 1:])/4.
            # a cell touching any exactly-zero node is shown as zero
            d = np.where(data[0][:-1, :-1] == 0, 0, d)
            d = np.where(data[0][1:, :-1] == 0, 0, d)
            d = np.where(data[0][1:, 1:] == 0, 0, d)
            d = np.where(data[0][:-1, 1:] == 0, 0, d)
        if cartopy:
            cs = ax.pcolormesh(plotlon[0], plotlat[0], d, cmap=pc_cmap, transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.pcolormesh(plotlon[0], plotlat[0], d, cmap=pc_cmap)

    if cartopy is None:
        ax.set_xlim(np.nanmin(plotlon[0]), np.nanmax(plotlon[0]))
        ax.set_ylim(np.nanmin(plotlat[0]), np.nanmax(plotlat[0]))
    elif domain is not None:
        ax.set_extent([np.nanmin(plotlon[0]), np.nanmax(plotlon[0]), np.nanmin(plotlat[0]), np.nanmax(plotlat[0])],
                      crs=cartopy.crs.PlateCarree())
    cs.set_clim(vmin, vmax)

    cartopy_colorbar(cs, plt, fig, ax)

    timestr = parsetimestr(field[0].grid.time_origin, show_time)
    titlestr = kwargs.pop('titlestr', '')
    if field[0].grid.zdim > 1:
        if field[0].grid.gtype in [GridCode.CurvilinearZGrid, GridCode.RectilinearZGrid]:
            gphrase = 'depth'
            depth_or_level = field[0].grid.depth[depth_level]
        else:
            gphrase = 'level'
            depth_or_level = depth_level
        depthstr = ' at %s %g ' % (gphrase, depth_or_level)
    else:
        depthstr = ''
    if plottype == 'vector':
        ax.set_title(titlestr + 'Velocity field' + depthstr + timestr)
    else:
        ax.set_title(titlestr + field[0].name + depthstr + timestr)

    if not spherical:
        ax.set_xlabel('Zonal distance [m]')
        ax.set_ylabel('Meridional distance [m]')

    plt.draw()

    if savefile:
        plt.savefig(savefile)
        logger.info('Plot saved to ' + savefile + '.png')
        plt.close()

    return plt, fig, ax, cartopy
def create_parcelsfig_axis(spherical, land=True, projection=None, central_longitude=0, cartopy_features=None):
    """Create a matplotlib figure with a single axis for plotting Parcels objects

    For spherical meshes the axis is a cartopy GeoAxes (with gridlines, optional extra
    cartopy features and coastlines); for flat meshes it is a plain matplotlib axis with a grid.

    :param spherical: Boolean whether the mesh has geographic coordinates
    :param land: Boolean whether to draw coastlines (spherical meshes only); a string is passed
           on as the first positional argument of ``ax.coastlines`` (cartopy's resolution argument)
    :param projection: cartopy projection to use (default PlateCarree(central_longitude));
           only accepted for spherical meshes
    :param central_longitude: central longitude of the default PlateCarree projection
    :param cartopy_features: iterable of cartopy features to add to the axis (spherical meshes only)
    :return: (plt, fig, ax, cartopy), where cartopy is None for flat meshes; or
             (None, None, None, None) if matplotlib (or, for spherical meshes, cartopy) cannot be imported
    :raises RuntimeError: if a projection is given for a non-spherical mesh
    """
    if cartopy_features is None:  # avoid a shared mutable default
        cartopy_features = []

    try:
        import matplotlib.pyplot as plt
    except Exception:  # plotting is optional: log and let the caller skip it instead of failing
        logger.info("Visualisation is not possible. Matplotlib not found.")
        return None, None, None, None  # creating axes was not possible

    if projection is not None and not spherical:
        raise RuntimeError("projection not accepted when Field doesn't have geographic coordinates")

    if not spherical:
        fig, ax = plt.subplots(1, 1)
        ax.grid()
        return plt, fig, ax, None

    try:
        import cartopy
    except Exception:  # same best-effort policy as for matplotlib
        logger.info("Visualisation of field with geographic coordinates is not possible. Cartopy not found.")
        return None, None, None, None  # creating axes was not possible

    projection = cartopy.crs.PlateCarree(central_longitude) if projection is None else projection
    fig, ax = plt.subplots(1, 1, subplot_kw={'projection': projection})
    try:  # gridlines not supported for all projections
        if isinstance(projection, cartopy.crs.PlateCarree) and central_longitude != 0:
            # central_lon=0 necessary for correct xlabels
            gl = ax.gridlines(crs=cartopy.crs.PlateCarree(), draw_labels=True)
        else:
            gl = ax.gridlines(crs=projection, draw_labels=True)
        gl.top_labels, gl.right_labels = (False, False)
        gl.xformatter = cartopy.mpl.gridliner.LONGITUDE_FORMATTER
        gl.yformatter = cartopy.mpl.gridliner.LATITUDE_FORMATTER
    except Exception:
        pass  # deliberate best-effort: the map is still usable without labelled gridlines

    for feature in cartopy_features:
        ax.add_feature(feature)

    if isinstance(land, str):
        ax.coastlines(land)
    elif land:
        ax.coastlines()

    return plt, fig, ax, cartopy
def parsedomain(domain, field):
    """Convert a plotting domain into index bounds on the grid of ``field``

    Calls ``field.grid.check_zonal_periodic()`` first.

    :param domain: None for the full grid, a dictionary with keys 'N', 'S', 'E', 'W', or
           (for backward compatibility with <v2.0.0) a length-4 sequence in the order N, S, E, W
    :param field: Field whose grid the indices refer to
    :return: tuple (latN, latS, lonE, lonW) of indices, where latN and lonE are one past
             the last index (i.e. suitable as slice stops)
    """
    field.grid.check_zonal_periodic()
    if domain is not None:
        if not isinstance(domain, dict) and len(domain) == 4:  # for backward compatibility with <v2.0.0
            domain = {'N': domain[0], 'S': domain[1], 'E': domain[2], 'W': domain[3]}
        _, _, _, lonW, latS, _ = field.search_indices(domain['W'], domain['S'], 0, 0, 0, search2D=True)
        _, _, _, lonE, latN, _ = field.search_indices(domain['E'], domain['N'], 0, 0, 0, search2D=True)
        return latN+1, latS, lonE+1, lonW

    lon = field.grid.lon
    if np.ndim(lon) == 2:
        # 2D coordinates (curvilinear grids): axis 0 indexes latitude, axis 1 longitude.
        # Dispatching on the array shape (instead of the grid code) works for both
        # CurvilinearZGrid and CurvilinearSGrid, and never indexes shape[1] of a 1D array.
        return lon.shape[0], 0, lon.shape[1], 0
    return len(field.grid.lat), 0, len(lon), 0
def parsetimestr(time_origin, show_time):
    """Format a time as a suffix for plot titles

    :param time_origin: time origin of the grid; if its ``calendar`` is None the time is shown
           as an elapsed duration, otherwise ``time_origin.fulltime`` converts it to a date
    :param show_time: time in seconds relative to ``time_origin``
    :return: ' after <duration>' (e.g. ' after 2 days, 0:01:30') without a calendar,
             otherwise ' on YYYY-MM-DD HH:MM:SS'
    """
    if time_origin.calendar is None:
        # str(timedelta) already carries its own units ('H:MM:SS' or 'N days, H:MM:SS'),
        # so appending ' hours' would mislabel it
        return ' after ' + str(delta(seconds=show_time))
    date_str = str(time_origin.fulltime(show_time))
    return ' on ' + date_str[:10] + ' ' + date_str[11:19]
def cartopy_colorbar(cs, plt, fig, ax):
    """Attach a colorbar for the mappable ``cs`` in a separate axes right of ``ax``

    The colorbar axes is re-aligned with ``ax`` once immediately and again on every
    'resize_event' of the figure canvas.

    :param cs: mappable returned by a plotting call (e.g. pcolormesh or quiver)
    :param plt: the matplotlib.pyplot module
    :param fig: figure that contains ``ax``
    :param ax: main axes the colorbar is placed next to
    """
    cbar_axes = fig.add_axes([0, 0, 0.1, 0.1])
    fig.subplots_adjust(hspace=0, wspace=0, top=0.925, left=0.1)
    plt.colorbar(cs, cax=cbar_axes)

    def _align_with_main_axes(_event):
        # draw first so ax.get_position() reflects the current layout
        plt.draw()
        box = ax.get_position()
        left_edge = box.x0 + box.width + 0.01
        cbar_axes.set_position([left_edge, box.y0, 0.04, box.height])

    fig.canvas.mpl_connect('resize_event', _align_with_main_axes)
    _align_with_main_axes(None)