
Draw dividing line across subplots in matplotlib


Borrowing from Pablo's helpful answer, you can combine each subplot's `transData` with the inverse of `fig.transFigure` to convert a point from any subplot's data coordinates into figure coordinates, and then draw lines between those points on the figure itself. This is probably the best method, because it makes the starting and ending points of each segment straightforward to determine. Since your x-coordinates conveniently run from 1 to 12, you can also plot each series in two parts, leaving a gap between two consecutive points for the annotation line to pass through.
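As a minimal sketch of the conversion this relies on (assuming a recent Matplotlib and the headless Agg backend; `data_to_figure` is a hypothetical helper name, not a Matplotlib API), a data point goes through `ax.transData` to display pixels and then through the inverse of `fig.transFigure` to figure fractions. The round trip at the end should recover the original data point.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is opened
import matplotlib.pyplot as plt


def data_to_figure(ax, xy):
    """Hypothetical helper: map a point from ax's data coordinates to figure coordinates."""
    display_xy = ax.transData.transform(xy)                         # data -> display (pixels)
    return ax.figure.transFigure.inverted().transform(display_xy)  # display -> figure (0-1)


fig, axs = plt.subplots(2, sharex=True)
for ax in axs:
    ax.plot([1, 12], [0, 3])
fig.canvas.draw()  # finalize autoscaled limits before using transData

# The same data point, expressed in figure coordinates for each subplot
top = data_to_figure(axs[0], (6.5, 1.5))
bottom = data_to_figure(axs[1], (6.5, 1.5))

# Round trip: figure -> display -> data of the top subplot
back = axs[0].transData.inverted().transform(fig.transFigure.transform(top))
```

Because every subplot's points end up in the same figure coordinate system, a single `Line2D` created with `transform=fig.transFigure` can connect them across subplot boundaries.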

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

y = np.array([0, 1, 2, 3, 4])

## recreate your data
df = pd.DataFrame({
    'A': [0, 1, 1, 1, 2, 2, 3, 2, 3] + [float("nan")] * 3,
    'N': [1, 0, 0, 2, 1, 1, 2, 3, 3, 3, 3, 3],
    'P': [0, 1, 1, 1, 2, 1, 1, 1, 2, 3, 3, 3],
    },
    index=range(1, 13))

fig, axs = plt.subplots(3, sharex=True, sharey=True)
fig.suptitle("I1 - Reakce na změnu prvku")

## create a gap in the line (.iloc slices by position, not by index label)
axs[0].plot(df.index[0:3], df['A'].iloc[0:3], color='lightblue', label="A", marker='.')
axs[0].plot(df.index[3:12], df['A'].iloc[3:12], color='lightblue', label="A", marker='.')
## create a gap in the line
axs[1].plot(df.index[0:8], df['N'].iloc[0:8], color='darkblue', label="N", marker='.')
axs[1].plot(df.index[8:12], df['N'].iloc[8:12], color='darkblue', label="N", marker='.')
## create a gap in the line
axs[2].plot(df.index[0:10], df['P'].iloc[0:10], color='blue', label="P", marker='.')
axs[2].plot(df.index[10:12], df['P'].iloc[10:12], color='blue', label="P", marker='.')

plt.yticks(np.arange(y.min(), y.max(), 1))

## draw once so the axis limits are final before converting coordinates
fig.canvas.draw()

transFigure = fig.transFigure.inverted()

## Since your subplots have a ymax value of 3, setting the y-coordinate of the
## horizontal segments to just above that value (3.5) places them in the space
## between the subplots instead of on top of the data
coord1 = transFigure.transform(axs[0].transData.transform([3.5, 3]))
coord2 = transFigure.transform(axs[1].transData.transform([3.5, 3.5]))
coord3 = transFigure.transform(axs[1].transData.transform([8.5, 3.5]))
coord4 = transFigure.transform(axs[2].transData.transform([8.5, 3.5]))
coord5 = transFigure.transform(axs[2].transData.transform([10.5, 3.5]))
coord6 = transFigure.transform(axs[2].transData.transform([10.5, 0]))

## add a vertical dashed line
line1 = Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
               transform=fig.transFigure, ls='--', color='grey')
## add a horizontal dashed line
line2 = Line2D((coord2[0], coord3[0]), (coord2[1], coord3[1]),
               transform=fig.transFigure, ls='--', color='grey')
## add a vertical dashed line
line3 = Line2D((coord3[0], coord4[0]), (coord3[1], coord4[1]),
               transform=fig.transFigure, ls='--', color='grey')
## add a horizontal dashed line
line4 = Line2D((coord4[0], coord5[0]), (coord4[1], coord5[1]),
               transform=fig.transFigure, ls='--', color='grey')
## add a vertical dashed line
line5 = Line2D((coord5[0], coord6[0]), (coord5[1], coord6[1]),
               transform=fig.transFigure, ls='--', color='grey')

## attach the lines to the figure itself, not to any single axes
for line in (line1, line2, line3, line4, line5):
    fig.add_artist(line)
plt.show()
```
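The five nearly identical `Line2D` calls above could also be collapsed into a single polyline. The following is a sketch under stated assumptions: `figure_polyline` is a hypothetical helper (not part of Matplotlib), the demo data is made up, and the Agg backend is used so it runs headless. It draws one continuous dashed path instead of five segments, which should look the same except that the dash pattern does not restart at each corner.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the demo
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def figure_polyline(fig, points, **line_kw):
    """Hypothetical helper: one Line2D through (axes, (x, y)) points.

    Each (x, y) is in the data coordinates of its own axes; the resulting
    line lives in figure coordinates, so it can cross between subplots.
    """
    to_fig = fig.transFigure.inverted()
    xy = [to_fig.transform(ax.transData.transform(p)) for ax, p in points]
    xs, ys = zip(*xy)
    line = Line2D(xs, ys, transform=fig.transFigure, **line_kw)
    fig.add_artist(line)
    return line


fig, axs = plt.subplots(3, sharex=True, sharey=True)
for ax in axs:
    ax.plot(range(1, 13), [i % 4 for i in range(12)], marker='.')  # made-up data
fig.canvas.draw()  # make sure the limits are final before converting coordinates

# Same corner points as in the answer above
line = figure_polyline(fig, [
    (axs[0], (3.5, 3)), (axs[1], (3.5, 3.5)), (axs[1], (8.5, 3.5)),
    (axs[2], (8.5, 3.5)), (axs[2], (10.5, 3.5)), (axs[2], (10.5, 0)),
], ls='--', color='grey')
```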



My instinct for this kind of problem is to draw a line in figure coordinates. The one issue I had was finding the position of the center of the gap between consecutive axes. My code is ugly, but it works, and it is independent of the relative size of each axes and of the spacing between them, as demonstrated below:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    from itertools import zip_longest
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)


xconn = [0.15, 0.35, 0.7]  # position of the vertical lines in each subplot, in data coordinates

fig, axs = plt.subplots(3, 1, gridspec_kw=dict(hspace=0.6, height_ratios=[2, 0.5, 1]))

## Draw the separation line. This should be done at the very end, once the limits of the axes have been set etc.
## convert the value of xconn in each axis to figure coordinates
xconn = [fig.transFigure.inverted().transform(ax.transData.transform([x, 0]))[0]
         for x, ax in zip(xconn, axs)]

yconn = []  # y-values of the connecting lines, in figure coordinates
for ax in axs:
    bbox = ax.get_position()
    yconn.extend([bbox.y1, bbox.y0])

# replace each pair of values corresponding to the bottom and top of consecutive axes by their average
yconn[1:-1] = np.ravel([[np.mean(ys)] * 2 for ys in grouper(yconn[1:-1], 2)]).tolist()

l = Line2D(np.repeat(xconn, 2), yconn, transform=fig.transFigure, ls='--', lw=1, c='k')
fig.add_artist(l)
plt.show()
```
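If the `grouper` recipe feels heavy, the gap midpoints can also be computed by pairing each axes with the one below it. A minimal sketch, assuming the same three-row layout as above, the Agg backend, and that nothing moves the axes after `get_position()` is read (for example `tight_layout` or constrained layout applied at draw time):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for the demo
import matplotlib.pyplot as plt

fig, axs = plt.subplots(3, 1, gridspec_kw=dict(hspace=0.6, height_ratios=[2, 0.5, 1]))

# Axes positions in figure coordinates (Bbox objects: y0 = bottom, y1 = top)
boxes = [ax.get_position() for ax in axs]

# Midpoint of the gap between each axes and the one directly below it
gap_mid = [(upper.y0 + lower.y1) / 2 for upper, lower in zip(boxes, boxes[1:])]

# Same y sequence the grouper-based code builds: top of the first axes,
# each gap midpoint twice, then the bottom of the last axes
yconn = [boxes[0].y1] + [y for m in gap_mid for y in (m, m)] + [boxes[-1].y0]
```

The resulting `yconn` can be passed to `Line2D` together with `np.repeat(xconn, 2)` exactly as in the code above.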
