Tristen.

Bivariate Data Exploration with Matplotlib & Seaborn

January 13, 2024

Bivariate plots investigate relationships between pairs of variables in your data. They usually build on findings from univariate exploration, where you examine the distribution of each variable on its own.

There are three combinations of variables for bivariate plots:

  • Qualitative vs. Qualitative (cat|cat)
  • Qualitative vs. Quantitative (cat|#)
  • Quantitative vs. Quantitative (#|#)

The main plots I’m going to go over are:

  • Scatterplots (#|#)
  • Heatmaps (#|#)
  • Violin plots (cat|#)
  • Box plots (cat|#)
  • Clustered bar charts (cat|cat)
  • Line plots (#|#)

All of the example plots I use throughout this blog post are derived from the Prosper loan data provided to me by Udacity while taking their data analyst course. You can view the full project on GitHub.

Import Packages

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Import cleaned dataset
df = pd.read_csv('../data/prosperLoanData_clean_v1.csv')
df['LoanOriginationDate'] = pd.to_datetime(df['LoanOriginationDate'])

# Create primary color for use in plots
color = sns.color_palette("Blues")[4]

Scatterplots

Scatterplots are the most common choice for visualizing the relationship between two quantitative variables. Based on the density and direction of plotted (x,y) points, we can gauge how these two variables correlate with each other.

Scatterplots with Matplotlib

We create a scatterplot by calling plt.scatter and providing arguments for ‘x’ and ‘y’.

To deal with overplotting, we can adjust transparency and marker size. Transparency is controlled by the alpha= argument, which takes a value from 0 (invisible) to 1 (opaque). Marker size is set with s=, which is interpreted as the marker area in points**2 (a typographic point is 1/72 inch).

# Take sample of data to prevent overplotting
df_sample = df.sample(5000, random_state=5)

# Plot scatterplot
plt.scatter(data=df_sample, x='DebtToIncomeRatio', y='LoanOriginalAmount',
            alpha=0.2, s=25)
plt.xlim(0, 1)

# Labels & Title
plt.ylabel('Loan Amount ($)')
plt.xlabel('Debt to Income Ratio')
plt.title('Debt to Income Ratio vs. Loan Amount')

# Remove border box
sns.despine(trim=True, left=True);

[Figure: Debt to Income Ratio vs. Loan Amount, scatterplot drawn with Matplotlib]

Scatterplots with Seaborn

We can plot with Seaborn by using the same code as we did in Matplotlib and simply replacing plt.scatter with sns.scatterplot. We achieve a very similar graph, styled slightly differently due to Seaborn’s defaults.

[Figure: the same scatterplot drawn with sns.scatterplot()]

Regression plot with Seaborn

Seaborn has a second scatter function, sns.regplot(), which combines a scatterplot with a fitted regression line. Its syntax is very similar to the previous two, but styling is passed through dictionaries: scatter_kws={'alpha':0.2} sets the transparency of the points, and line_kws={} sets the formatting of the regression line.

# Take sample of data to prevent overplotting
df_sample = df.sample(1000, random_state=5)

# Plot scatterplot
sns.regplot(data=df_sample, x='DebtToIncomeRatio', y='LoanOriginalAmount',
                scatter_kws={'alpha':0.2}, line_kws={"color": "green"})
plt.xlim(0, 1)

# Labels & Title
plt.ylabel('Loan Amount ($)')
plt.xlabel('Debt to Income Ratio')
plt.title('Debt to Income Ratio vs. Loan Amount')

# Remove border box
sns.despine(trim=True, left=True);
plt.savefig('../plots/sns_scatter.png');

[Figure: Debt to Income Ratio vs. Loan Amount with a fitted regression line (sns.regplot)]

By default, sns.regplot() fits a linear regression and shades a 95% confidence interval around the regression estimate (adjustable with ci=). In this example, the fitted line is nearly horizontal, which suggests little to no linear correlation between debt-to-income ratio and loan amount.
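The "nearly horizontal" judgment can also be backed with numbers. The sketch below fits the same kind of straight least-squares line that sns.regplot() draws by default, using np.polyfit, and computes Pearson's r with pandas. It runs on hypothetical toy columns rather than the Prosper data. To use it on the real sample, you would pass df_sample after dropping rows with missing values, because np.polyfit does not skip NaNs. Note that the slope is expressed in dollars per unit of ratio and so depends on the axes' units, while r always lies between -1 and 1.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data standing in for the loan columns
rng = np.random.default_rng(0)
ratio = rng.uniform(0, 1, 1000)
toy = pd.DataFrame({
    'ratio': ratio,
    'linked': 20000 * ratio + rng.normal(0, 500, 1000),  # strong linear link to ratio
    'unrelated': rng.normal(10000, 3000, 1000),          # independent of ratio
})

def fit_line(df, x, y):
    """Least-squares slope and intercept (a straight regression line) plus Pearson's r."""
    slope, intercept = np.polyfit(df[x], df[y], deg=1)  # highest degree first
    r = df[x].corr(df[y])  # Series.corr uses Pearson by default
    return slope, intercept, r

for col in ['linked', 'unrelated']:
    slope, intercept, r = fit_line(toy, 'ratio', col)
    print(f'{col}: slope={slope:,.0f}, r={r:.2f}')
```

A slope near zero together with an r near zero is the numeric counterpart of a flat regression line.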

Heat Maps

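A heatmap is a 2D version of a histogram. It divides the plane into a grid of cells and colors each cell by the number of points that fall inside it. Heatmaps are often used as an alternative to a scatterplot when there are too many points to show individually.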

Default heat plot using Matplotlib

A simple heatmap just takes ‘x’ and ‘y’ arguments to plt.hist2d(). I’ve also defined explicit bin edges here so the data is grouped into neat, evenly sized cells.

# Set bin edges; each stop is one step past the last edge so that edge is included
bins_x = np.arange(0, 1 + 0.1, 0.1)
bins_y = np.arange(0, 35000 + 5000, 5000)

# Plot heatmap
plt.hist2d(data=df, x='DebtToIncomeRatio', y='LoanOriginalAmount',
            bins=[bins_x, bins_y]);

[Figure: default Matplotlib heatmap of Debt to Income Ratio vs. Loan Amount]

Polish Your Histogram

To change the color palette, set the cmap= parameter of plt.hist2d() to the name of one of Matplotlib's built-in colormaps, such as "BuPu" in our example. You can browse the other options in Matplotlib's colormap reference.

You can also use the cmin parameter to distinguish cells with zero counts from cells with non-zero counts. Cells with fewer than cmin points are left uncolored, so setting cmin=0.5 colors only the cells that contain at least one data point. Our dataset is large, so we use cmin=100 to hide cells with fewer than 100 data points.

To further enhance the plot, we add a colorbar with plt.colorbar(). The colorbar shows how counts map to colors in the plot.

# Set bin edges; each stop is one step past the last edge so that edge is included
bins_x = np.arange(0, 1 + 0.1, 0.1)
bins_y = np.arange(0, 35000 + 5000, 5000)

# Plot heatmap
plt.hist2d(data=df, x='DebtToIncomeRatio', y='LoanOriginalAmount', bins=[bins_x, bins_y],
            cmin=100, cmap='BuPu')

# Add colorbar
plt.colorbar()

# Labels & Title
plt.ylabel('Loan Amount ($)')
plt.xlabel('Debt to Income Ratio')
plt.title('Debt to Income Ratio vs. Total Amount Borrowed');

[Figure: polished heatmap with the BuPu colormap, cmin=100, and a colorbar]

Add Annotations

With large datasets, annotating each cell with its count shows how many data points fall in each area. With plt.hist2d(), this means placing each text element individually with plt.text(). Conveniently, plt.hist2d() itself provides the counts: it returns a tuple of the counts array, the x and y bin edges, and the plotted image (a QuadMesh). When cmin is set, cells below the threshold come back as NaN in the counts array.

Keep in mind that if your heatmap has many cells, annotations can make it cluttered and hard to read. In that case, leave them off and let the colors and the colorbar convey the information.

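# Plot the heatmap, keeping the returned (counts, xedges, yedges, image) tuple
g = plt.hist2d(data=df, x='DebtToIncomeRatio', y='LoanOriginalAmount', bins=[bins_x, bins_y],
               cmin=100, cmap='BuPu')
plt.colorbar()

# Loop through the cell counts and add a text annotation to each populated cell
counts = g[0]
for i in range(counts.shape[0]):        # i indexes the x (debt to income) bins
    for j in range(counts.shape[1]):    # j indexes the y (loan amount) bins
        c = counts[i, j]
        if np.isnan(c) or c == 0:       # cells below cmin come back as NaN
            continue
        text_color = 'white' if c >= 3500 else 'black'  # increase visibility on darker cells
        # Offsets of half a bin width (0.05 and 2500) center the text in each cell
        plt.text(bins_x[i] + 0.05, bins_y[j] + 2500, int(c),
                 ha='center', va='center', color=text_color, size='xx-small')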

[Figure: heatmap with count annotations in each cell]
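To check the annotation loop's indexing, it helps to know that plt.hist2d() computes its counts with NumPy's np.histogram2d(). That means counts[i, j] holds the number of points in x-bin i and y-bin j, which is why the loop pairs bins_x[i] with bins_y[j]. The sketch below runs np.histogram2d on a few hypothetical points and imitates cmin by setting low cells to NaN. masked_hist2d is an illustrative helper, not a Matplotlib function. np.histogram2d also ignores points that fall outside the outermost edges, so the edge arrays must extend to the largest value you want counted.

```python
import numpy as np

def masked_hist2d(x, y, edges_x, edges_y, cmin=None):
    """Hypothetical helper: 2D bin counts with cells below cmin set to NaN."""
    counts, xedges, yedges = np.histogram2d(x, y, bins=[edges_x, edges_y])
    counts = counts.astype(float)  # NaN requires a float array
    if cmin is not None:
        counts[counts < cmin] = np.nan
    return counts, xedges, yedges

# Hypothetical points: (debt to income ratio, loan amount)
x = np.array([0.05, 0.15, 0.15, 0.95])
y = np.array([1000, 6000, 7000, 34000])
edges_x = np.arange(0, 1 + 0.1, 0.1)        # 10 bins covering 0 to 1
edges_y = np.arange(0, 35000 + 5000, 5000)  # 7 bins covering 0 to 35,000

counts, _, _ = masked_hist2d(x, y, edges_x, edges_y)
print(counts.shape)       # rows are x bins, columns are y bins
print(counts[1, 1])       # x bin 0.1 to 0.2, y bin 5,000 to 10,000

masked, _, _ = masked_hist2d(x, y, edges_x, edges_y, cmin=2)
print(np.nansum(masked))  # only cells with at least 2 points keep their counts
```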

Box Plots

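Box plots display summary statistics (the median, the quartiles, and the interquartile range, or IQR) to compare a quantitative variable across the levels of a qualitative one. You may also see them called “box-and-whisker” plots due to their cat-like features 😸.

The Matplotlib diagram below shows what each part of the plot represents. The whiskers extend to the most extreme data points that still lie within 1.5 IQR of the box; points beyond them are drawn individually as outliers.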

Q1-1.5IQR   Q1   median  Q3   Q3+1.5IQR
                  |-----:-----|
  o      |--------|     :     |--------|    o  o
                  |-----:-----|
outliers          <----------->            outliers
                       IQR
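To see where those numbers come from, here is a minimal NumPy sketch that computes the quartiles, IQR, whisker ends, and outliers for a single group, following the 1.5 IQR rule in the diagram. box_stats is a hypothetical helper. It uses np.percentile's default linear interpolation, and other quartile methods can give slightly different boxes. Matplotlib exposes a similar calculation as matplotlib.cbook.boxplot_stats().

```python
import numpy as np

def box_stats(values, whis=1.5):
    """Hypothetical helper: the numbers a box plot draws for one group."""
    v = np.asarray(values, dtype=float)
    q1, median, q3 = np.percentile(v, [25, 50, 75])
    iqr = q3 - q1
    low_fence, high_fence = q1 - whis * iqr, q3 + whis * iqr
    inside = v[(v >= low_fence) & (v <= high_fence)]
    return {
        'q1': q1, 'median': median, 'q3': q3, 'iqr': iqr,
        # Whiskers stop at the most extreme observations inside the fences
        'whisker_low': inside.min(), 'whisker_high': inside.max(),
        'outliers': v[(v < low_fence) | (v > high_fence)],
    }

stats = box_stats([1, 2, 3, 4, 5, 6, 7, 8, 100])
print(stats)
```

Here the upper fence is 7 + 1.5 × 4 = 13, so the whisker stops at 8 and the value 100 is drawn as an outlier.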

Plotting Box Plots with Seaborn

# Plot boxplot
sns.boxplot(data=df, y='LoanStatus', x='LoanOriginalAmount', color=color);

[Figure: default Seaborn box plot of loan amount by loan status]

Customize Your Boxplots

Narrowing the boxes can make for a cleaner graph. You set the box width with the width= argument of sns.boxplot().

Gridlines along the value axis make the boxes easier to read. Add them with g.xaxis.grid(True) for a horizontal plot (or g.yaxis.grid(True) for a vertical one), where g is the Axes returned by sns.boxplot().

Axis spines are the lines that frame the plot area. In some cases, we can clean up our plot by removing them with sns.despine().

You can usually switch between horizontal and vertical orientation by swapping the ‘x’ and ‘y’ variables.

# Set figure size based on Medium's aspect ratio
f, ax = plt.subplots(figsize=(12.5, 6.563))

# Plot boxplot with specific width
g = sns.boxplot(data=df, y='LoanStatus', x='LoanOriginalAmount',
                color=color, width=.45)

# Add gridlines
g.xaxis.grid(True)

# Set labels and title
g.set(ylabel="", xlabel="Amount ($)", title="Loan Amount by Loan Status")

# Remove spines
sns.despine(trim=True, left=True);

[Figure: customized box plot, Loan Amount by Loan Status]

Superimpose A Strip Plot

A strip plot draws every individual observation as a point within its category, which shows how the data behind each box is actually spread out. It works well as a supporting layer superimposed on another plot, such as our box plot.

# Get smaller subset of dataset
df_sample = df.sample(1500)

# Set figure size based on Medium's aspect ratio
f, ax = plt.subplots(figsize=(12.5, 6.563))

# Plot boxplot with specific width
g = sns.boxplot(data=df_sample, y='LoanStatus', x='LoanOriginalAmount',
                color=color, width=.45)

# Add gridlines
g.xaxis.grid(True)

# Set labels and title
g.set(ylabel="", xlabel="Amount ($)", title="Loan Amount by Loan Status")

# Add points for observations
sns.stripplot(data=df_sample, y="LoanStatus", x="LoanOriginalAmount",
              size=4, color=".8")

# Remove border box
sns.despine(trim=True, left=True);

[Figure: box plot with a superimposed strip plot]

Violin Plots

Violin plots are similar to box plots in that they compare a qualitative variable with a quantitative one. However, instead of summary statistics, they display a kernel density estimate (KDE) for each category.

A KDE estimates the probability density of the data by placing a small smooth curve (a kernel) on each observation and averaging them. The total area under the resulting curve is 1, so the probability that a value falls within a given range is the area under the curve over that range (its integral).
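The sketch below makes this concrete with plain NumPy. It builds a Gaussian KDE whose bandwidth comes from Scott's rule of thumb, which is the default rule in scipy.stats.gaussian_kde and in Seaborn's violin plots. It then integrates the KDE numerically with the trapezoid rule. gaussian_kde_1d and area_between are hypothetical helpers, and the samples are simulated rather than loan data.

```python
import numpy as np

def gaussian_kde_1d(samples, grid):
    """Hypothetical helper: Gaussian KDE of `samples` evaluated at each grid point."""
    x = np.asarray(samples, dtype=float)
    bw = x.std(ddof=1) * len(x) ** (-1 / 5)  # Scott's rule bandwidth
    z = (grid[:, None] - x[None, :]) / bw
    kernels = np.exp(-0.5 * z**2) / (bw * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)  # average of one normal curve per observation

def area_between(grid, density, lo, hi):
    """Hypothetical helper: trapezoid-rule area under `density` between lo and hi."""
    mask = (grid >= lo) & (grid <= hi)
    g, d = grid[mask], density[mask]
    return float(np.sum((d[1:] + d[:-1]) / 2 * np.diff(g)))

rng = np.random.default_rng(1)
samples = rng.normal(0, 1, 1000)  # simulated data
grid = np.linspace(-6, 6, 1201)
density = gaussian_kde_1d(samples, grid)

total = area_between(grid, density, -6, 6)        # close to 1
within_one = area_between(grid, density, -1, 1)   # estimated P(-1 <= value <= 1)
print(round(total, 3), round(within_one, 3))
```

The smoothing slightly widens the curve, so the estimated probability for [-1, 1] comes out a little below the exact normal value of about 0.68.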

Plotting Violin Plots with Seaborn

The basic syntax is simple: supply the ‘x’ and ‘y’ variables to sns.violinplot().

sns.violinplot(data=df_sample, x='LoanStatus', y='LoanOriginalAmount',
                color=color);

[Figure: default Seaborn violin plot of loan amount by loan status]

Cleaning It Up

We can polish the violin plot the same way we polished the box plot: set the figure size, customize the axis labels and title, add gridlines, and remove the plot spines.

Violin plots, by default, contain a miniature box plot inside each curve. To simplify the look of our final plot, we can also remove this by providing inner=None when we call the sns.violinplot() function.

# Set figure size based on Medium's aspect ratio
f, ax = plt.subplots(figsize=(12.5, 6.563))

# Plot violin plot without the inner mini box plots
g = sns.violinplot(data=df, x='LoanStatus', y='LoanOriginalAmount',
                color=color, inner=None)

# Add gridlines
g.yaxis.grid(True)

# Set labels and title
g.set(xlabel="", ylabel="Amount ($)", title="Loan Amount by Loan Status")

# Remove border box
sns.despine(trim=True, left=True);

[Figure: polished violin plot, Loan Amount by Loan Status]

Clustered Bar Charts

Clustered bar charts are used to show the relationship between two qualitative variables.

Clustered Bar Charts with Seaborn

Like a typical bar chart, a clustered bar chart shows the count for each level of a categorical variable, but each bar is split into a cluster by a second categorical variable. Clustered bar charts work best when you want to see how the second variable changes within each level of the first, or how the first variable changes across levels of the second. Below, we look at the second case.

You’ll often have to prepare your categories first. Limiting the number of groups in the clustering variable improves readability, because too many bars per cluster are hard to interpret; 2–4 groups is a good range. You can also convert the categorical variable to an ordered dtype so that its levels appear in the same order across plots.

# Relabel credit types by the time period they apply to
df['CreditType'] = df['CreditType'].replace({'CreditGrade': 'Before 2009-07',
                                             'ProsperRating': 'After 2009-07'})
df = df.rename(columns={'CreditType': 'TimePeriod'})

# Convert LoanStatus dtype to ordered category (values outside the list become NaN)
LS_cat = pd.CategoricalDtype(categories=['Completed', 'Current', 'Past Due', 'Defaulted',
                                        'Chargedoff'], ordered=True)
df['LoanStatus'] = df['LoanStatus'].astype(LS_cat)

We then plot the same way we would a univariate bar plot, with sns.countplot(), and add the clustering variable with the hue= argument.

sns.countplot(data=df, x='LoanStatus', hue='TimePeriod');

[Figure: clustered bar chart of loan status by time period]
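The bar heights that sns.countplot() draws are simply counts for each (LoanStatus, TimePeriod) pair, so a table is a handy way to check them. The sketch below builds one with pd.crosstab on a few hypothetical rows, using the same ordered LS_cat dtype as the preparation step. It then reindexes the rows by LS_cat.categories so the table follows the category order and includes levels with no loans. The 'Unknown' status is made up to show that values outside the category list become NaN and drop out of the counts.

```python
import pandas as pd

# Hypothetical toy rows standing in for the cleaned Prosper data
toy = pd.DataFrame({
    'LoanStatus': ['Current', 'Completed', 'Defaulted', 'Completed', 'Current', 'Unknown'],
    'TimePeriod': ['After 2009-07', 'Before 2009-07', 'Before 2009-07',
                   'After 2009-07', 'After 2009-07', 'Before 2009-07'],
})

LS_cat = pd.CategoricalDtype(categories=['Completed', 'Current', 'Past Due',
                                        'Defaulted', 'Chargedoff'], ordered=True)
toy['LoanStatus'] = toy['LoanStatus'].astype(LS_cat)  # 'Unknown' becomes NaN

# Count each (status, period) pair, then lay the rows out in category order.
# Converting to str first sidesteps version-dependent handling of empty categories.
observed = toy.dropna(subset=['LoanStatus'])
counts = (pd.crosstab(observed['LoanStatus'].astype(str), observed['TimePeriod'])
          .reindex(LS_cat.categories, fill_value=0))
print(counts)
```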

Once again, grid lines and custom labels can help to clean up our final visual.

g = sns.countplot(data=df, x='LoanStatus', hue='TimePeriod')

# Add gridlines
g.yaxis.grid(True)

# Set labels and title (the y-axis shows counts, not dollar amounts)
g.set(xlabel="", ylabel="Number of Loans", title="Loan Status by Time Period")

# Remove border box
sns.despine(trim=True, left=True);

[Figure: clustered bar chart with gridlines and custom labels]

Line Plots

Plotting Lines with Matplotlib

Matplotlib’s plt.errorbar() draws a line through its data points in the order given, optionally with error bars. Applied to thousands of raw observations, that produces a tangled mess, so we aggregate the data first.

In the example below, we create bins to group the independent variable (debt to income ratio) into distinct intervals. We then calculate the mean and standard deviation of the dependent variable (loan amount) within each bin. To plot the line, we only need the means. However, plt.errorbar() also lets us attach error bars, and we use the standard deviations for those.

## Set bin edges, and compute center of each bin
bin_edges = np.arange(0, 1+0.1, 0.1)
bin_centers = bin_edges[:-1] + 0.05

## Use the bins to group DebtToIncomeRatio into discrete intervals. Returns a Series object.
dir_bins = pd.cut(df['DebtToIncomeRatio'], bin_edges, include_lowest=True)

# Calculate mean and standard dev. for loan amount in each bin;
# observed=False keeps empty bins so the results line up with bin_centers
amount_mean = df['LoanOriginalAmount'].groupby(dir_bins, observed=False).mean()
amount_std = df['LoanOriginalAmount'].groupby(dir_bins, observed=False).std()

## Plot the summarized data
plt.errorbar(x=bin_centers, y=amount_mean, yerr=amount_std)
plt.ylabel('Loan Amount ($)')
plt.xlabel('Debt to Income Ratio');

[Figure: mean loan amount by debt to income bin, with standard deviation error bars]
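The binning steps in the example can be wrapped in a small reusable function. In this sketch, binned_mean_std is a hypothetical helper. It computes the bin centers from the edges instead of hard-coding the 0.05 offset. It also passes observed=False so that empty bins stay in the result as NaN and the arrays stay the same length as the centers. The toy values are made up.

```python
import numpy as np
import pandas as pd

def binned_mean_std(x, y, bin_edges):
    """Hypothetical helper: mean and std of y within each bin of x (equal-length inputs)."""
    x = pd.Series(x, dtype=float)
    y = pd.Series(y, dtype=float)
    bins = pd.cut(x, bin_edges, include_lowest=True)
    grouped = y.groupby(bins, observed=False)  # keep empty bins as NaN rows
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2  # midpoint of each bin
    return centers, grouped.mean().to_numpy(), grouped.std().to_numpy()

edges = np.arange(0, 1 + 0.1, 0.1)
ratios = [0.05, 0.07, 0.15, 0.95]   # hypothetical debt to income ratios
amounts = [1000, 3000, 5000, 7000]  # hypothetical loan amounts
centers, means, stds = binned_mean_std(ratios, amounts, edges)
print(means)  # NaN for bins with no loans
```

Passing the results to plt.errorbar(x=centers, y=means, yerr=stds) draws the same kind of chart, and Matplotlib leaves gaps in the line wherever a value is NaN.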

Plotting Time-Series with Seaborn

A time-series plot is a type of line plot in which the x-axis represents time and the y-axis represents the variable being measured. Timestamps tend to be granular; you might have multiple records per second or minute. As you can see below, that granularity creates a lot of noise that makes trends harder to see, although it can be useful in some cases, such as when looking for outliers.

# Set figure size based on Medium's aspect ratio
f, ax = plt.subplots(figsize=(12.5, 6.563))

# Group mean loan amount by date
amount_by_date = df.groupby('LoanOriginationDate')['LoanOriginalAmount'].mean()
amount_by_date = amount_by_date.reset_index(drop=False)

# Plot lineplot
sns.lineplot(data=amount_by_date);

[Figure: noisy time series of mean loan amount per origination date]

We’ll fix two problems in the following graph. First, the x-axis is pulling from the index rather than our dates, so we pass the date column explicitly as x. Second, we smooth out the line by averaging per year instead of per day.

# Group mean loan amount by year
amount_by_date = df.groupby(df['LoanOriginationDate'].dt.to_period('Y'))\
                  ['LoanOriginalAmount'].mean()
amount_by_date = amount_by_date.reset_index(drop=False)
amount_by_date.rename(columns={'LoanOriginationDate':'Year',
                      'LoanOriginalAmount':'LoanAmount'}, inplace=True)

# sns.lineplot cannot handle the Period dtype, so convert to timestamps
amount_by_date.Year = amount_by_date.Year.dt.to_timestamp()

# Plot lineplot
g = sns.lineplot(x=amount_by_date.Year, y=amount_by_date.LoanAmount);

# Rotate xticks
plt.xticks(rotation='vertical');

[Figure: mean loan amount by year]

Yearly averages are great for long-term trends, but sometimes we need to zoom in. For instance, to see what was going on in 2009, we can switch to a monthly scale. We can also clean up this visualization the same way as in previous examples.

#Set figure size based on Medium aspect ratio
f, ax = plt.subplots(figsize=(12.5, 6.563))

# Group mean loan amount by month
amount_by_date = df.groupby(df['LoanOriginationDate'].dt.to_period('M'))\
                            ['LoanOriginalAmount'].mean()
amount_by_date = amount_by_date.reset_index(drop=False)
amount_by_date.rename(columns={'LoanOriginationDate':'Month',
                      'LoanOriginalAmount':'LoanAmount'}, inplace=True)

# sns.lineplot cannot handle the Period dtype, so convert to timestamps
amount_by_date.Month = amount_by_date.Month.dt.to_timestamp()

# Plot line plot
g = sns.lineplot(x=amount_by_date.Month, y=amount_by_date.LoanAmount,
                  linewidth = 2.5)

# Rotate x ticks
plt.xticks(rotation=90)

# Add gridlines
g.yaxis.grid(True)

# Set labels and title
g.set(xlabel="", ylabel="Amount ($)", title="Loan Amount by Month")

#Remove border box
sns.despine(left=True);
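An alternative to grouping on Period values is to resample on a DatetimeIndex. This returns timestamps directly, with no to_timestamp() step. One difference is worth knowing: resample() creates a row for every month in the range, so months without loans come back as NaN. The groupby approach, by contrast, only includes months that actually have loans. The sketch below compares the two on hypothetical toy rows; 'MS' asks for bins labeled by the first day of each month.

```python
import pandas as pd

# Hypothetical toy loans, not the Prosper data
toy = pd.DataFrame({
    'LoanOriginationDate': pd.to_datetime(['2009-01-05', '2009-01-20', '2009-03-02',
                                           '2010-06-15', '2010-06-30']),
    'LoanOriginalAmount': [1000, 3000, 2000, 4000, 6000],
})

# Approach used in the post: group on monthly periods, then convert to timestamps
by_period = (toy.groupby(toy['LoanOriginationDate'].dt.to_period('M'))
             ['LoanOriginalAmount'].mean())
by_period.index = by_period.index.to_timestamp()

# Alternative: resample on a DatetimeIndex; empty months become NaN
by_resample = (toy.set_index('LoanOriginationDate')['LoanOriginalAmount']
               .resample('MS').mean())

print(by_period)
print(by_resample.dropna())
```

When the result is passed to sns.lineplot(), the NaN rows from resample() typically appear as gaps. In contrast, the groupby version draws a straight segment across missing months, which can hide the fact that there was no data.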

Conclusion

In summary, bivariate plots are vital for understanding relationships between dataset features. Using Matplotlib and Seaborn, we explored various plot types, including scatterplots, box plots, and clustered bar charts, to visualize and interpret how variables interact. Our exploration of the Prosper loan dataset highlights how these tools can uncover patterns and trends that support informed, data-driven decisions.