plot different color for different categorical levels using matplotlib


Imports and Sample DataFrame

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns  # for sample data
    from matplotlib.lines import Line2D  # for legend handle

    # DataFrame used for all options
    df = sns.load_dataset('diamonds')

       carat      cut color clarity  depth  table  price     x     y     z
    0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
    1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
    2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31

With matplotlib

You can pass plt.scatter a c argument, which allows you to select the colors. The following code defines a colors dictionary to map the diamond colors to the plotting colors.

    fig, ax = plt.subplots(figsize=(6, 6))

    colors = {'D':'tab:blue', 'E':'tab:orange', 'F':'tab:green', 'G':'tab:red', 'H':'tab:purple', 'I':'tab:brown', 'J':'tab:pink'}

    ax.scatter(df['carat'], df['price'], c=df['color'].map(colors))

    # add a legend
    handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=v, label=k, markersize=8) for k, v in colors.items()]
    ax.legend(title='color', handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.show()

df['color'].map(colors) effectively maps the colors from "diamond" to "plotting".
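
To see what that mapping step produces on its own, here is a minimal sketch on a tiny hand-made frame (the `toy` DataFrame and its values are made up for illustration, not taken from the diamonds data). Categories missing from the dictionary come back as NaN, so it can be worth checking for them before handing the result to `c`.

```python
import pandas as pd

# Hypothetical toy data standing in for the diamonds DataFrame
toy = pd.DataFrame({
    'carat': [0.2, 0.3, 0.5, 0.7],
    'price': [300, 400, 900, 2000],
    'color': ['D', 'E', 'D', 'Z'],  # 'Z' is deliberately absent from the dict
})

colors = {'D': 'tab:blue', 'E': 'tab:orange'}

# Series.map looks each category up in the dict; unknown keys become NaN
mapped = toy['color'].map(colors)

# Collect the categories that did not get a plotting color
unmapped = toy.loc[mapped.isna(), 'color'].unique().tolist()

print(mapped.tolist())
print(unmapped)
```

If `unmapped` is not empty, add those categories to the dictionary (or filter the rows out) before plotting.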

(Forgive me for not putting another example image up, I think 2 is enough :P)

With seaborn

You can use seaborn, a wrapper around matplotlib that looks prettier by default (rather opinion-based, I know :P) and also adds some plotting functions.

For this you could use seaborn.lmplot with fit_reg=False (which stops it from fitting and drawing a regression line).

  • sns.scatterplot(x='carat', y='price', data=df, hue='color', ec=None) also does the same thing.

Selecting hue='color' tells seaborn to split and plot the data based on the unique values in the 'color' column.

    sns.lmplot(x='carat', y='price', data=df, hue='color', fit_reg=False)

With pandas.DataFrame.groupby & pandas.DataFrame.plot

If you don't want to use seaborn, use pandas.DataFrame.groupby to split the data by color and plot each group with pandas and matplotlib alone. You'll have to assign the colors manually as you go; I've added an example below:

    fig, ax = plt.subplots(figsize=(6, 6))

    grouped = df.groupby('color')
    for key, group in grouped:
        group.plot(ax=ax, kind='scatter', x='carat', y='price', label=key, color=colors[key])
    plt.show()

This code assumes the same DataFrame as above, and then groups it based on color. It then iterates over these groups, plotting for each one. To select a color, I've created a colors dictionary, which can map the diamond color (for instance D) to a real color (for instance tab:blue).
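
The same loop can also be written with ax.scatter directly instead of DataFrame.plot; each call that gets a `label` becomes one legend entry. A rough sketch, assuming a small made-up frame (`toy`) and the non-interactive Agg backend so that it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data, not the real diamonds dataset
toy = pd.DataFrame({
    'carat': [0.2, 0.3, 0.5, 0.7, 0.9],
    'price': [300, 400, 900, 2000, 3500],
    'color': ['E', 'D', 'E', 'F', 'D'],
})
colors = {'D': 'tab:blue', 'E': 'tab:orange', 'F': 'tab:green'}

fig, ax = plt.subplots(figsize=(6, 6))

# One scatter call per group; groupby yields the keys in sorted order
for key, group in toy.groupby('color'):
    ax.scatter(group['carat'], group['price'], color=colors[key], label=key)

ax.set_xlabel('carat')
ax.set_ylabel('price')
ax.legend(title='color')

# Read the legend entries back to confirm there is one per category
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
print(legend_labels)
```

Compared with the `c=df['color'].map(colors)` approach, this version gets its legend for free, at the cost of one scatter call per category.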


Here's a succinct and generic solution to use a seaborn color palette.

First find a color palette you like and optionally visualize it:

    sns.palplot(sns.color_palette("Set2", 8))

Then you can use it with matplotlib doing this:

    # Unique category labels: 'D', 'F', 'G', ...
    color_labels = df['color'].unique()

    # List of RGB triplets
    rgb_values = sns.color_palette("Set2", 8)

    # Map label to RGB
    color_map = dict(zip(color_labels, rgb_values))

    # Finally use the mapped values
    plt.scatter(df['carat'], df['price'], c=df['color'].map(color_map))
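
If you'd rather not depend on seaborn for the palette, matplotlib ships a qualitative colormap that is also called Set2. The sketch below swaps `plt.get_cmap` in for `sns.color_palette`, which is my substitution rather than part of the approach above, and the `toy` frame is made up. One thing to watch in either variant: `zip` stops at the shorter input, so with more categories than palette entries some labels would silently go unmapped. The length check guards against that.

```python
import matplotlib
matplotlib.use('Agg')  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data, not the real diamonds dataset
toy = pd.DataFrame({
    'carat': [0.2, 0.3, 0.5, 0.7, 0.9],
    'price': [300, 400, 900, 2000, 3500],
    'color': ['E', 'D', 'E', 'F', 'D'],
})

# Unique category labels, in order of first appearance
color_labels = toy['color'].unique()

# RGB triplets from matplotlib's own 'Set2' colormap (8 entries)
palette = plt.get_cmap('Set2').colors

# zip would silently drop labels beyond the palette length, so fail loudly
if len(color_labels) > len(palette):
    raise ValueError('more categories than palette colors')

# Map label to RGB, then look up one color per row
color_map = dict(zip(color_labels, palette))
point_colors = toy['color'].map(color_map).tolist()

fig, ax = plt.subplots()
ax.scatter(toy['carat'], toy['price'], c=point_colors)
```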


I had the same question, and have spent all day trying out different packages.

I had originally used matplotlib and was not happy with either mapping categories to predefined colors or grouping/aggregating and then iterating through the groups (and still having to map colors). I just felt it was a poor package implementation.

Seaborn wouldn't work in my case, and Altair ONLY works inside of a Jupyter Notebook.

The best solution for me was PlotNine, which "is an implementation of a grammar of graphics in Python, and based on ggplot2".

Below is the plotnine code to replicate your R example in Python:

    from plotnine import *
    from plotnine.data import diamonds

    g = ggplot(diamonds, aes(x='carat', y='price', color='color')) + geom_point(stat='summary')
    print(g)

plotnine diamonds example

So clean and simple :)