1. 程式人生 > Python calculate and plot correlation between multiple variables

Python calculate and plot correlation between multiple variables

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import EllipseCollection

# Small example data set: a string identifier column followed by four
# numeric measurement columns.
data = pd.DataFrame(
    [
        ["A", 4, 0, 1, 27],
        ["B", 7, 1, 1, 29],
        ["C", 6, 1, 0, 23],
        ["D", 2, 0, 0, 20],
        ["etc.", 3, 0, 1, 21],
    ],
    columns=["ID", "score", "male", "age20", "BMI"],
)
# "ID" holds strings, so the correlation matrix is restricted to the numeric
# columns; recent pandas releases raise on non-numeric data instead of
# silently dropping it.
print(data.select_dtypes(include=np.number).corr())

def plot_corr_ellipses(data, ax=None, **kwargs):
    """Draw a matrix of correlation values as a grid of ellipses.

    Every cell of the matrix becomes one ellipse centred on the cell's
    (column, row) position.  Each ellipse has unit width and a height of
    ``1 - |value|``, is rotated by ``45 * sign(value)`` degrees, and is
    coloured by the value itself.

    Parameters
    ----------
    data : 2-D array-like or pandas.DataFrame
        The values to draw.  When a DataFrame is given, its column and index
        labels are used as tick labels.
    ax : matplotlib Axes, optional
        Axes to draw into.  When omitted, a new figure with a single
        equal-aspect Axes is created and its limits are set to frame the
        grid.  The limits of a caller-supplied Axes are left untouched.
    **kwargs
        Forwarded to ``EllipseCollection`` (for example ``cmap``).

    Returns
    -------
    EllipseCollection
        The collection that was added to the Axes.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional.
    """
    M = np.array(data)
    if M.ndim != 2:
        raise ValueError('data must be a 2D array')

    n_rows, n_cols = M.shape
    if ax is None:
        _, ax = plt.subplots(1, 1, subplot_kw={'aspect': 'equal'})
        ax.set_xlim(-0.5, n_cols - 0.5)
        ax.set_ylim(-0.5, n_rows - 0.5)

    # (x, y) centre of each ellipse: x is the column index, y the row index,
    # listed in the same row-major order as M.ravel().
    xy = np.indices(M.shape)[::-1].reshape(2, -1).T

    # Stronger correlations give flatter ellipses; the sign of the value
    # selects the tilt direction.
    w = np.ones_like(M).ravel()
    h = 1 - np.abs(M).ravel()
    a = 45 * np.sign(M).ravel()

    # Matplotlib renamed the ``transOffset`` keyword to ``offset_transform``
    # and later dropped the old spelling.  Releases that know the new name
    # also expose ``set_offset_transform``, so use that to pick the keyword.
    if hasattr(EllipseCollection, 'set_offset_transform'):
        transform_kw = 'offset_transform'
    else:
        transform_kw = 'transOffset'

    ec = EllipseCollection(widths=w, heights=h, angles=a, units='x',
                           offsets=xy, array=M.ravel(),
                           **{transform_kw: ax.transData}, **kwargs)
    ax.add_collection(ec)

    # if data is a DataFrame, use the row/column names as tick labels
    if isinstance(data, pd.DataFrame):
        ax.set_xticks(np.arange(n_cols))
        ax.set_xticklabels(data.columns, rotation=90)
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(data.index)

    return ec

# Plot the correlation matrix as ellipses and save it to my_0.pdf.
# Only the numeric columns are correlated: "ID" holds strings, and recent
# pandas releases raise on non-numeric data instead of silently dropping it.
corr_matrix = data.select_dtypes(include=np.number).corr()

fig, ax = plt.subplots(1, 1)
m = plot_corr_ellipses(corr_matrix, ax=ax, cmap='Greens')
cb = fig.colorbar(m)
cb.set_label('Correlation coefficient')
ax.margins(0.1)
current_fig = plt.gcf()
current_fig.savefig('my_0.pdf', bbox_inches='tight')

# Requires seaborn; the original note suggested installing it with
#   conda install -c anaconda seaborn=0.7.1
import seaborn as sns

# Draw an annotated, clustered heat map of the correlation matrix and save
# it to my_1.pdf.  Only the numeric columns are correlated: "ID" holds
# strings, and recent pandas releases raise on non-numeric data instead of
# silently dropping it.
sns.clustermap(
    data=data.select_dtypes(include=np.number).corr(),
    annot=True,
    cmap='Greens',
).savefig('my_1.pdf', bbox_inches='tight')