# Source code for hydrobox.stats.regression

"""Regression Module

Common tools for calculating linear regressions and producing scatter plots.

"""
from itertools import product

from scipy.stats import linregress
import matplotlib.pyplot as plt
import numpy as np

from hydrobox.toolbox import merge


def linear_regression(*x, df=None, plot=False, ax=None, notext=False):
    """Linear Regression tool

    This tool can be used for a number of regression related tasks. It can
    calculate a linear regression between two observables and also return a
    scatter plot including the regression parameters and function.

    In case more than two ``Series`` or ``arrays`` are passed, they will be
    merged into a ``DataFrame`` and a linear regression between all
    combinations will be calculated and plotted if desired.

    Parameters
    ----------
    x : pandas.Series, numpy.ndarray
        If df is None, at least two Series or arrays have to be passed. If
        more are passed, a multi output will be produced.
    df : pandas.DataFrame
        If df is set, all x occurrences will be ignored. DataFrame of the
        input to be used for calculating the linear regression. This
        attribute can be useful whenever a multi input to x does not get
        merged correctly. Note that linear_regression will only use the
        ``DataFrame.values`` array (and the column labels to build the
        combinations) and ignore all other structural elements.
    plot : bool
        If True, the function will output a matplotlib Figure or plot into an
        existing instance. If False (default) the regression results will be
        returned.
    ax : matplotlib.axes.Axes
        Has to be a single matplotlib Axes instance if two data sets are
        passed, or a list / array of Axes if more than two data sets are
        passed.
    notext : bool
        If True, the output of the fitting parameters as a text into the plot
        will be suppressed. This setting is ignored if plot is set to False.

    Returns
    -------
    dict
        If plot is False and exactly two data sets are given: the keys
        ``slope``, ``intercept``, ``rvalue``, ``pvalue`` and ``stderr`` as
        returned by :func:`scipy.stats.linregress`.
    tuple of numpy.ndarray
        If plot is False and the input does not consist of exactly two data
        sets: ``(slope, intercept, rvalue, pvalue, stderr)``, each of shape
        ``(N, N)``. Element ``[i, j]`` is the regression with column ``i`` as
        independent and column ``j`` as dependent variable. The diagonal is
        NaN.
    matplotlib.Figure
        If plot is True.

    Raises
    ------
    ValueError
        If plot is True, more than two data sets are given and the number of
        passed Axes does not match the number of combinations.

    Notes
    -----
    If plot is True and ax is not None, the number of passed Axes has to
    match the total combinations between the data sets. This is

    .. math::

        N^2

    where N is the length of x, or the length of df.columns.

    .. warning::

        This function does just calculate a linear regression. It handles a
        multi input recursively and has some data wrangling overhead. If you
        are seeking a fast linear regression tool, use the
        scipy.stats.linregress function directly.

    """
    # combine all inputs
    if df is None:
        df = merge(*x, dropna=True)

    # get the number of data sets
    n = len(df.columns)

    # exactly two data sets: this is the actual regression
    if n == 2:
        values = df.values
        independent = values[:, 0]
        dependent = values[:, 1]

        # calculate regression
        slope, intercept, rvalue, pvalue, stderr = linregress(
            independent, dependent
        )
        reg = dict(slope=slope, intercept=intercept, rvalue=rvalue,
                   pvalue=pvalue, stderr=stderr)

        # return data
        if not plot:
            return reg

        # return plot
        else:  # pragma: no cover
            # build the ax if necessary
            if ax is None:
                fig, ax = plt.subplots(1, 1, figsize=(6, 6))
            else:
                fig = ax.get_figure()

            ax.scatter(independent, dependent, 25, marker='.', color='blue')

            if not notext:
                ax.text(
                    0.05, 0.95,
                    '\n'.join('%s: %.2f' % (k, v) for k, v in reg.items()),
                    ha='left', va='top', transform=ax.transAxes
                )

            # regression line
            ax.plot(independent, independent * slope + intercept, '-g')

            return fig

    # more than two are given, call recursively
    # create a plot
    if plot:  # pragma: no cover
        if ax is None:
            fig, axes = plt.subplots(n, n, figsize=(n * 5, n * 5),
                                     sharex=True, sharey=True)
        else:
            axes = ax

        # accept an ndarray, a (nested) list or a single Axes alike
        flat_axes = np.asarray(axes, dtype=object).ravel()

        # make sure there are enough AxesSubplots; raised explicitly as an
        # assert would be stripped when running with -O
        if len(flat_axes) != n * n:
            raise ValueError(
                'Expected %d Axes for %d data sets, got %d.'
                % (n * n, n, len(flat_axes))
            )

        if ax is not None:
            fig = flat_axes[0].get_figure()

        # call linear regression for all combinations
        for axis, (first, second) in zip(
                flat_axes, product(df.columns, repeat=2)
        ):
            if first == second:
                # TODO: as soon as implemented, use the histogram tool here
                continue
            linear_regression(df=df[[first, second]], plot=True, ax=axis,
                              notext=notext)

        # return the whole figure; use the figure itself, as the current
        # pyplot figure is not necessarily the one owning the passed Axes
        fig.tight_layout()
        return fig

    # The result order is spelled out explicitly. Before, the .values method
    # of the resulting dictionary was used, but Python 3.5 and 3.6 return the
    # dictionary with a different order.
    keys = ('slope', 'intercept', 'rvalue', 'pvalue', 'stderr')

    # one (n, n) matrix per parameter, NaN on the diagonal
    results = np.full((len(keys), n, n), np.nan)

    columns = list(df.columns)
    for (i, first), (j, second) in product(enumerate(columns), repeat=2):
        # a data set is not regressed against itself
        if i == j:
            continue

        # The pair is passed positionally on purpose: that way it is routed
        # through merge(..., dropna=True) above.
        # NOTE(review): presumably this drops NaN rows per pair - confirm
        # against hydrobox.toolbox.merge before changing it to df=...
        reg = linear_regression(df[[first, second]])
        results[:, i, j] = [reg[key] for key in keys]

    # extract and return
    slope, intercept, rvalue, pvalue, stderr = results
    return slope, intercept, rvalue, pvalue, stderr