How to plot multiple Seaborn Jointplot in Subplot
Moving axes in matplotlib is not as easy as it used to be in previous versions. The solution below works with the current version of matplotlib (at the time of writing).
As has been pointed out in several places (this question, also this issue), several seaborn commands create their own figure automatically. This is hard-coded into the seaborn code, so there is currently no way to produce such plots in existing figures. Those commands are PairGrid, FacetGrid, JointGrid, pairplot, jointplot and lmplot.
There is a seaborn fork available which allows supplying a subplot grid to the respective classes, so that the plot is created in a preexisting figure. To use it, you would need to copy axisgrid.py from the fork into the seaborn folder. Note that this currently only works with matplotlib 2.1 (possibly 2.0 as well).
An alternative could be to create a seaborn figure and copy the axes to another figure. The principle of this is shown in this answer and can be extended to seaborn plots. The implementation is a bit more complicated than I had initially expected. The following is a class SeabornFig2Grid that can be called with a seaborn grid instance (the return value of any of the above commands), a matplotlib figure and a subplot_spec, which is a position in a gridspec grid.
Note: This is a proof of concept; it may work for most simple cases, but I would not recommend using it in production code.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import numpy as np


class SeabornFig2Grid():

    def __init__(self, seaborngrid, fig, subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, sns.axisgrid.FacetGrid) or \
            isinstance(self.sg, sns.axisgrid.PairGrid):
            self._movegrid()
        elif isinstance(self.sg, sns.axisgrid.JointGrid):
            self._movejointgrid()
        self._finalize()

    def _movegrid(self):
        """ Move PairGrid or FacetGrid """
        self._resize()
        n = self.sg.axes.shape[0]
        m = self.sg.axes.shape[1]
        # one inner grid cell per seaborn axes, inside the target subplot_spec
        self.subgrid = gridspec.GridSpecFromSubplotSpec(n, m, subplot_spec=self.subplot)
        for i in range(n):
            for j in range(m):
                self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j])

    def _movejointgrid(self):
        """ Move JointGrid """
        # estimate the joint/marginal size ratio from the current axes heights
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        r = int(np.round(h / h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=self.subplot)

        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        # https://stackoverflow.com/a/46906599/4124317
        # detach the axes from the seaborn figure and attach it to the target figure
        ax.remove()
        ax.figure = self.fig
        self.fig.axes.append(ax)
        self.fig.add_axes(ax)

        # place the axes at the given gridspec position
        ax._subplotspec = gs
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        # the original seaborn figure is no longer needed
        plt.close(self.sg.fig)
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
The usage of this class would look like this:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns; sns.set()
import SeabornFig2Grid as sfg  # the class above, saved as SeabornFig2Grid.py

iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
                palette=dict(Yes="g", No="m"))
# A PairGrid
g1 = sns.PairGrid(iris, hue="species")
g1.map(plt.scatter, s=5)
# A FacetGrid
g2 = sns.FacetGrid(tips, col="time", hue="smoker")
g2.map(plt.scatter, "total_bill", "tip", edgecolor="w")
# A JointGrid
g3 = sns.jointplot(x="sepal_width", y="petal_length", data=iris,
                   kind="kde", space=0, color="g")

fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(2, 2)

mg0 = sfg.SeabornFig2Grid(g0, fig, gs[0])
mg1 = sfg.SeabornFig2Grid(g1, fig, gs[1])
mg2 = sfg.SeabornFig2Grid(g2, fig, gs[3])
mg3 = sfg.SeabornFig2Grid(g3, fig, gs[2])

gs.tight_layout(fig)
#gs.update(top=0.7)

plt.show()
Note that copying axes may have several drawbacks, and the above has not (yet) been tested thoroughly.
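The core of _movejointgrid is a nested layout: the target subplot_spec is subdivided with GridSpecFromSubplotSpec into an (r+1) x (r+1) grid, and the joint and marginal axes are placed on slices of it. The sketch below reproduces only that layout with plain matplotlib (no seaborn objects), so you can inspect where each axes would end up. The value r = 5 is an assumption (seaborn's default JointGrid ratio); the class estimates r from the axes heights instead.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; only positions are computed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(8, 6))
outer = gridspec.GridSpec(2, 2, figure=fig)  # a 2x2 grid like the usage example

# Assumed joint/marginal ratio (seaborn's default JointGrid ratio).
r = 5
# Subdivide the bottom-left cell (outer[2]) into an (r+1) x (r+1) grid,
# the way _movejointgrid subdivides self.subplot.
inner = gridspec.GridSpecFromSubplotSpec(r + 1, r + 1, subplot_spec=outer[2])

# Same slices as in _movejointgrid (row 0 is the top row).
joint_spec = inner[1:, :-1]   # joint axes: all rows but the top, all columns but the last
marg_x_spec = inner[0, :-1]   # top marginal: top row, same columns as the joint axes
marg_y_spec = inner[1:, -1]   # right marginal: last column, same rows as the joint axes

# Bounding boxes in figure coordinates (0..1).
cell = outer[2].get_position(fig)
joint = joint_spec.get_position(fig)
marg_x = marg_x_spec.get_position(fig)
marg_y = marg_y_spec.get_position(fig)

for name, bb in [("cell", cell), ("joint", joint), ("marg_x", marg_x), ("marg_y", marg_y)]:
    print(f"{name:7s} x0={bb.x0:.3f} y0={bb.y0:.3f} w={bb.width:.3f} h={bb.height:.3f}")
```

All three boxes stay inside the outer cell, and the marginals share the joint axes' left/right (top marginal) or bottom/top (right marginal) edges, which is what keeps the moved JointGrid aligned.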
It cannot easily be done without hacking. jointplot creates a JointGrid, which in turn creates a new figure object every time it is called.
Therefore, the hack is to make two jointplots (JG1 and JG2), then make a new figure, then migrate the axes objects from JG1 and JG2 to the new figure.
Finally, we adjust the sizes and the positions of subplots in the new figure we just created.
import matplotlib.pyplot as plt
import seaborn as sns

# df is the question's DataFrame with numeric columns "C1" and "C2"
JG1 = sns.jointplot(x="C1", y="C2", data=df, kind='reg')
JG2 = sns.jointplot(x="C1", y="C2", data=df, kind='kde')

#subplots migration (relies on private matplotlib internals)
f = plt.figure()
for J in [JG1, JG2]:
    for A in J.fig.axes:
        f._axstack.add(f._make_key(A), A)

#subplots size adjustment: [left, bottom, width, height] in figure fractions
f.axes[0].set_position([0.05, 0.05, 0.4,  0.4])
f.axes[1].set_position([0.05, 0.45, 0.4,  0.05])
f.axes[2].set_position([0.45, 0.05, 0.05, 0.4])
f.axes[3].set_position([0.55, 0.05, 0.4,  0.4])
f.axes[4].set_position([0.55, 0.45, 0.4,  0.05])
f.axes[5].set_position([0.95, 0.05, 0.05, 0.4])
It is a hack because we are now using the private attribute _axstack and the private method _make_key, which may or may not stay the same in future versions of matplotlib.
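The hard-coded rectangles in the size-adjustment step follow a simple pattern: each jointplot gets a 0.4 x 0.4 joint axes, a 0.05-high strip above it and a 0.05-wide strip to its right, and each panel is shifted 0.5 to the right of the previous one. The hypothetical helper joint_layout below generates those [left, bottom, width, height] rectangles; its defaults reproduce the six positions above. It assumes, as the snippet's indexing does, that each JointGrid contributes its axes in the order joint, top marginal, right marginal. To stay runnable without the private-API migration, the sketch attaches fresh empty axes with add_axes instead of migrated ones.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt


def joint_layout(n_panels, left=0.05, bottom=0.05, size=0.4, marg=0.05, step=0.5):
    """Hypothetical helper: [left, bottom, width, height] rectangles (figure
    fractions) for n_panels jointplot-style panels placed side by side.

    Per panel, in order: joint axes, top marginal, right marginal.
    For more than two panels, reduce size, marg and step so that every
    rectangle stays within the 0..1 figure range.
    """
    rects = []
    for i in range(n_panels):
        x0 = left + i * step
        rects.append([x0, bottom, size, size])          # joint axes
        rects.append([x0, bottom + size, size, marg])   # top marginal (x)
        rects.append([x0 + size, bottom, marg, size])   # right marginal (y)
    return rects


rects = joint_layout(2)
fig = plt.figure()
for rect in rects:
    fig.add_axes(rect)  # empty placeholder axes at each computed position
```

With migrated axes, the same list can be applied in a loop, for example `for ax, rect in zip(f.axes, joint_layout(2)): ax.set_position(rect)`.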
If you run into trouble despite the elegant solution of @ImportanceOfBeingErnest, you can still save the seaborn plots as image files and use them to build your custom figure. If you need a higher resolution, pass a larger dpi to savefig, since the pixel size of a PNG is set by the figure size and dpi.
Here is the example shown above, rebuilt with this nasty (but working) approach:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import seaborn as sns

# data
iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

############### 1. CREATE PLOTS
# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
                palette=dict(Yes="g", No="m"))
# A PairGrid
g1 = sns.PairGrid(iris, hue="species")
g1.map(plt.scatter, s=5)
# A FacetGrid
g2 = sns.FacetGrid(tips, col="time", hue="smoker")
g2.map(plt.scatter, "total_bill", "tip", edgecolor="w")
# A JointGrid
g3 = sns.jointplot(x="sepal_width", y="petal_length", data=iris,
                   kind="kde", space=0, color="g")

############### 2. SAVE PLOTS TO TEMPORARY IMAGE FILES
g0.savefig('g0.png')
plt.close(g0.fig)

g1.savefig('g1.png')
plt.close(g1.fig)

g2.savefig('g2.png')
plt.close(g2.fig)

g3.savefig('g3.png')
plt.close(g3.fig)

############### 3. CREATE YOUR SUBPLOTS FROM THE TEMPORARY IMAGES
f, axarr = plt.subplots(2, 2, figsize=(25, 16))

axarr[0, 0].imshow(mpimg.imread('g0.png'))
axarr[0, 1].imshow(mpimg.imread('g1.png'))
axarr[1, 0].imshow(mpimg.imread('g3.png'))
axarr[1, 1].imshow(mpimg.imread('g2.png'))

# turn off x and y axis
for ax in axarr.ravel():
    ax.set_axis_off()

plt.tight_layout()
plt.show()