Source code for plotlp.modules.StyledAxes_LP.StyledAxes

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Date          : 2025-12-11
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet
# Library       : plotLP
# Module        : StyledAxes

"""
A class using stored styles inside.
"""



# %% Libraries
from matplotlib.axes import Axes
from matplotlib import pyplot as plt
import matplotlib.projections as projections
import inspect
from corelp import prop
import numpy as np
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from plotlp import scalebar as scbar



# %% Class
class StyledAxes(Axes):
    """
    A class using stored styles inside.

    The plotting methods overridden here run inside
    ``plt.style.context(self.style)``, where ``self.style`` is the style
    stored on the parent figure.

    Attributes
    ----------
    style : dict
        associated figure style.
    polish_imscale : bool
        True to polish automatically imscale
    polish_grids : bool
        True to polish automatically grids
    polish_noborders : bool
        True to polish automatically no borders
    polish_equiscale : bool
        True to polish automatically equiscale

    Examples
    --------
    >>> from plotlp import subplots
    ...
    >>> fig, axes = subplots(nrows=2, ncols=2) # StyledFigure, StyledAxes
    >>> axis = fig.axis # StyledAxes, current axis of fig
    >>> axis.polish() # applies all the polish methods for all the "polish_attr" attributes
    >>> axis.grids() # applies the polish grids, adds major and minor grids
    >>> axis.imscale() # applies the polish imscale, sets x and y limits to images borders
    >>> axis.noborders() # applies the polish no borders, removes axis borders
    >>> axis.equiscale() # applies the polish equiscale, x and y unit have same size on plot
    >>> axis.implot(image, x, y, w, h) # plots an image to the coordinates, deforms to fit the box defined
    """

    # Projection name; used when the class is registered with matplotlib.
    name = "styled"

    @property
    def style(self):
        """Style stored on the parent figure (``self.figure.style``)."""
        return self.figure.style

    # Legend
    def legend(self, *args, **kwargs):
        """Same as ``Axes.legend``, executed inside the figure style."""
        with plt.style.context(self.style):
            return super().legend(*args, **kwargs)

    # Imshow
    def imshow(self, X, *args, barname=None, coordinates=None, **kwargs):
        """
        Same as ``Axes.imshow``, executed inside the figure style.

        Parameters
        ----------
        X : array-like
            Image data; only its first two dimensions (rows, columns) are
            used below, so RGB(A) images are accepted.
        barname : optional
            If not None, ``self.figure.colorbar(im, barname=barname)`` is
            called (presumably the StyledFigure colorbar helper).
        coordinates : tuple (x, y), optional
            1D sequences of pixel-center positions along x and y, assumed
            evenly spaced. They define the image extent (half a pixel beyond
            the first/last centers), force ``aspect='auto'`` and
            ``origin='lower'``, and disable automatic polishing of this axis.

        Returns
        -------
        The image artist returned by ``Axes.imshow``.
        """
        with plt.style.context(self.style):
            if coordinates is not None:
                x, y = coordinates
                dx, dy = self._half_step(x), self._half_step(y)
                extent = [x[0] - dx, x[-1] + dx, y[0] - dy, y[-1] + dy]
                kwargs.update(dict(extent=extent, aspect='auto', origin='lower'))
            im = super().imshow(X, *args, **kwargs)
            if coordinates is not None:
                self.invert_yaxis()
                # np.shape(...)[:2] works for 2D and RGB(A) (Ny, Nx, C) images
                Ny, Nx = np.shape(X)[:2]
                self.set_box_aspect(Ny / Nx)
                # imscale would reset the limits to pixel indices, which is
                # wrong once a physical extent has been set
                self.polish_axis = False
            if barname is not None:
                self.figure.colorbar(im, barname=barname)
            return im

    @staticmethod
    def _half_step(values):
        """Half of the pixel spacing of an evenly spaced 1D coordinate array (0.5 if a single sample)."""
        if len(values) < 2:
            return 0.5
        return (values[-1] - values[0]) / (len(values) - 1) / 2

    # Pcolormesh
    def pcolormesh(self, *args, cmap=None, **kwargs):
        """Same as ``Axes.pcolormesh``; a missing cmap is resolved from the figure style."""
        with plt.style.context(self.style):
            if cmap is None:
                # Read 'image.cmap' while the style context is active
                cmap = plt.get_cmap(plt.rcParams['image.cmap'])
            return super().pcolormesh(*args, cmap=cmap, **kwargs)

    # Bar
    def bar(self, *args, **kwargs):
        """Same as ``Axes.bar``, executed inside the figure style; enables grid polishing."""
        with plt.style.context(self.style):
            # Bars are patches, which the default polish_grids rule ignores
            self.polish_grids = True
            return super().bar(*args, **kwargs)

    # Implot
    def implot(self, img, x, y, w, h, zorder=3, **kwargs):
        """
        Plot an image stretched into the data-coordinate box (x, y, w, h).

        The image is drawn in a borderless inset StyledAxes placed on that
        box, and clipped to the area of this axis.

        Parameters
        ----------
        img : array-like
            Image data passed to ``imshow``.
        x, y, w, h : float
            Lower-left corner, width and height of the box, in data units.
        zorder : float
            Z-order of the inset axes.
        **kwargs
            Extra ``imshow`` keywords; they override the defaults
            ``aspect='auto'``, ``extent=[x, x+w, y, y+h]``, ``origin='lower'``.

        Returns
        -------
        The image artist.
        """
        # Axes.inset_axes takes (bounds, transform=, zorder=, axes_class=); the
        # mpl_toolkits inset_axes function has a (width, height, loc, ...) API
        newaxe = self.inset_axes([x, y, w, h], transform=self.transData,
                                 zorder=zorder, axes_class=StyledAxes)
        newaxe.set_axis_off()
        kw = {'aspect': 'auto', 'extent': [x, x + w, y, y + h], 'origin': 'lower'}
        kw.update(kwargs)
        im = newaxe.imshow(img, **kw)
        # Clip the image to this (parent) axis area
        clip_rect = Rectangle((0, 0), 1, 1, transform=self.transAxes, facecolor="none")
        im.set_clip_path(clip_rect)
        return im

    # Scalebar
    def scalebar(self, *args, **kwargs):
        """Draw a scalebar with ``plotlp.scalebar`` inside the figure style; returns its result."""
        with plt.style.context(self.style):
            return scbar(self, *args, **kwargs)

    # Facecolor
    def set_facecolor(self, *args, **kwargs):
        """Same as ``Axes.set_facecolor``, executed inside the figure style."""
        with plt.style.context(self.style):
            return super().set_facecolor(*args, **kwargs)

    ### --- Polish functions ---

    # Set to False (per instance) to skip polish() entirely
    polish_axis = True

    def polish(self):
        """Apply every polish method whose ``polish_*`` flag is True."""
        if not self.polish_axis:
            return
        if self.polish_grids:
            self.grids()
        if self.polish_imscale:
            self.imscale()
        if self.polish_noborders:
            self.noborders()
        if self.polish_equiscale:
            self.equiscale()

    # grids
    @prop()
    def polish_grids(self):
        """Default: True when the axis holds lines or collections."""
        return len(self.lines) > 0 or len(self.collections) > 0  # no grid with patches here

    # Keyword arguments for Axes.grid; None or empty disables that grid
    grid_major = {'linestyle': '-', 'alpha': 1}
    grid_minor = {'linestyle': '--', 'alpha': 0.5}

    def grids(self):
        """Draw the major and minor grids described by grid_major / grid_minor."""
        with plt.style.context(self.style):
            if self.grid_major:
                self.grid(which='major', **self.grid_major)
            if self.grid_minor:
                self.minorticks_on()  # force enabling minor ticks
                self.grid(which='minor', **self.grid_minor)

    # imscale
    @prop()
    def polish_imscale(self):
        """Default: True when the axis holds images."""
        return len(self.get_images()) > 0

    def imscale(self):
        """Fit autoscaled limits to the largest image, in pixel units (y pointing down)."""
        with plt.style.context(self.style):
            images = self.get_images()
            if not images:
                # Nothing to fit; avoid setting identical (-0.5, -0.5) limits
                return
            xmax, ymax = 0, 0
            for image in images:
                # Get maximum image coordinates
                y, x = np.shape(image.get_array())[0:2]
                ymax, xmax = max(y, ymax), max(x, xmax)
            if self.get_autoscalex_on():
                self.set_xlim(-0.5, xmax - 0.5)
            if self.get_autoscaley_on():
                self.set_ylim(ymax - 0.5, -0.5)

    # noborders
    @prop()
    def polish_noborders(self):
        """Default: True when grids are not polished."""
        return not self.polish_grids

    def noborders(self):
        """Hide the axis (spines, ticks and labels) via ``set_axis_off``."""
        with plt.style.context(self.style):
            self.set_axis_off()

    # equiscale
    @prop()
    def polish_equiscale(self):
        """Default: False."""
        return False

    def equiscale(self):
        """Give x and y units the same on-screen size by adjusting the box."""
        with plt.style.context(self.style):
            self.set_aspect(aspect='equal', adjustable='box')
### --- Regenerate parent class methods in the given style ---

def is_plottable(method_name, method_obj):
    """
    Returns True if this Axes method should be wrapped automatically.

    Parameters
    ----------
    method_name : str
        Attribute name in ``Axes.__dict__``.
    method_obj : object
        The raw attribute value from ``Axes.__dict__``.
    """
    # `hasattr(StyledAxes, ...)` follows inheritance and is True for every
    # Axes attribute, which made this always return False; only skip names
    # that StyledAxes itself defines.
    if method_name in vars(StyledAxes):  # Already overridden
        return False
    if method_name.startswith("_"):  # private, handled separately
        return False
    if not callable(method_obj):
        return False
    if not inspect.ismethoddescriptor(method_obj) and not inspect.isfunction(method_obj):
        return False
    # Heuristic: methods returning artists / sequences of artists
    # Usually are plotting methods; we accept them all here
    return True


def wrap_method(method_name):
    """
    Returns a wrapper method that applies the style context then calls
    the parent Axes method of the same name.

    The wrapper keeps the parent's name, docstring and signature.
    """
    from functools import wraps

    @wraps(getattr(Axes, method_name))
    def wrapper(self, *args, **kwargs):
        with plt.style.context(self.style):
            method = getattr(super(StyledAxes, self), method_name)
            return method(*args, **kwargs)

    wrapper.__qualname__ = f"StyledAxes.{method_name}"
    return wrapper


# Dynamically inject each wrapper method into StyledAxes
for name, obj in Axes.__dict__.items():
    if is_plottable(name, obj):
        setattr(StyledAxes, name, wrap_method(name))

# Finally register projection
projections.register_projection(StyledAxes)


# %% Test function run
if __name__ == "__main__":
    from corelp import test
    test(__file__)