Matplotlib: avoiding overlapping datapoints in a "scatter/dot/beeswarm" plot
Extending the answer by @user2467675, here’s how I did it:
    import numpy as np
    import matplotlib.pyplot as plt

    def rand_jitter(arr):
        # Noise scale: 1% of the data range.
        arr = np.asarray(arr, dtype=float)
        stdev = .01 * (arr.max() - arr.min())
        return arr + np.random.randn(len(arr)) * stdev

    def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, **kwargs):
        # Like plt.scatter, but with both coordinates jittered.
        return plt.scatter(rand_jitter(x), rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, **kwargs)
The stdev variable scales the jitter to 1% of the data range so that it stays visible at different scales, but it assumes that the axis limits roughly match the range of the data.
You can then call jitter instead of scatter.
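To see what the jitter does, here is a minimal, self-contained sketch on integer-valued data, where many points share exactly the same coordinates. The `scale` and `rng` parameters are my own additions for illustration (they are not in the answer above); `rng` just makes the noise reproducible.

```python
import numpy as np
import matplotlib.pyplot as plt

def rand_jitter(arr, scale=0.01, rng=None):
    """Return arr plus Gaussian noise with stdev = scale * (data range)."""
    arr = np.asarray(arr, dtype=float)
    rng = np.random.default_rng() if rng is None else rng
    stdev = scale * (arr.max() - arr.min())
    return arr + rng.normal(0.0, 1.0, size=arr.shape) * stdev

# Integer-valued data: 200 points fall on at most 25 distinct positions.
rng = np.random.default_rng(0)
x = rng.integers(0, 5, size=200)
y = rng.integers(0, 5, size=200)

fig, (ax_raw, ax_jit) = plt.subplots(1, 2, sharex=True, sharey=True)
ax_raw.scatter(x, y, alpha=0.5)
ax_raw.set_title("raw")
ax_jit.scatter(rand_jitter(x, rng=rng), rand_jitter(y, rng=rng), alpha=0.5)
ax_jit.set_title("jittered")
# Call plt.show() or fig.savefig(...) to view the result.
```

Note that if all values in an array are identical, the range is zero and no jitter is applied, so you may want a fixed fallback stdev in that case.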
Seaborn provides histogram-like categorical dot-plots through sns.swarmplot() and jittered categorical dot-plots via sns.stripplot():
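    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set(style='ticks', context='talk')
    iris = sns.load_dataset('iris')

    # Swarm plot: points are shifted sideways to avoid overlap.
    sns.swarmplot(x='species', y='sepal_length', data=iris)
    sns.despine()

    # Strip plot on a new figure: random horizontal jitter.
    plt.figure()
    sns.stripplot(x='species', y='sepal_length', data=iris, jitter=0.2)
    sns.despine()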
I used numpy.random to scatter the data along the x-axis around a fixed point for each category, beeswarm-style, and then called pyplot.scatter() once per category:
    import matplotlib.pyplot as plt
    import numpy as np

    # random data for category A, B, with B "taller"
    yA, yB = np.random.randn(100), 5.0 + np.random.randn(1000)
    xA, xB = np.random.normal(1, 0.1, len(yA)), np.random.normal(3, 0.1, len(yB))

    plt.scatter(xA, yA)
    plt.scatter(xB, yB)
    plt.show()
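The same idea can be wrapped in a small helper for any number of categories. This is only a sketch: `category_positions` and its `spread` parameter are hypothetical names of my own, and it places the categories at x = 1, 2, 3, ... rather than at the hand-picked positions 1 and 3 used above.

```python
import numpy as np
import matplotlib.pyplot as plt

def category_positions(groups, spread=0.1, rng=None):
    """Map {label: values} to {label: (x, y)}, with x jittered around 1, 2, 3, ..."""
    rng = np.random.default_rng() if rng is None else rng
    out = {}
    for i, (label, values) in enumerate(groups.items(), start=1):
        y = np.asarray(values, dtype=float)
        # Normal noise around the category's integer position.
        x = rng.normal(i, spread, size=y.shape)
        out[label] = (x, y)
    return out

# Random data for categories A and B, with B "taller".
rng = np.random.default_rng(0)
groups = {"A": rng.normal(0.0, 1.0, 100), "B": rng.normal(5.0, 1.0, 1000)}
positions = category_positions(groups, rng=rng)

fig, ax = plt.subplots()
for label, (x, y) in positions.items():
    ax.scatter(x, y, s=10, alpha=0.5, label=label)
# Label each cluster with its category name instead of a number.
ax.set_xticks(list(range(1, len(groups) + 1)))
ax.set_xticklabels(list(groups))
# Call plt.show() or fig.savefig(...) to view the result.
```

Because the horizontal offsets are normally distributed, points can still overlap in dense regions; lowering `alpha` or the marker size helps there.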