bff.plot.plot_true_vs_pred

bff.plot.plot_true_vs_pred(y_true, y_pred, with_correlation=False, with_determination=True, with_histograms=False, with_identity=False, label_x='Ground truth', label_y='Prediction', title='Predicted vs Actual', ax=None, lim_x=None, lim_y=None, grid='both', figsize=(14, 7), dpi=80, style='default', **kwargs)

Plot the ground truth against the predictions of the model.

If a DataFrame is provided, it must contain exactly one column.
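Because of the one-column rule, both inputs can be treated as flat 1-D sequences. The following shows one possible form of such a check. `as_1d` is a hypothetical helper, not part of bff, and rejecting multi-column 2-D arrays is an assumption.

```python
import numpy as np
import pandas as pd


def as_1d(values):
    """Hypothetical helper: flatten an array or a one-column DataFrame to 1-D."""
    if isinstance(values, pd.DataFrame):
        # A DataFrame is only accepted when it holds exactly one column.
        if values.shape[1] != 1:
            raise ValueError(
                f"Expected a single-column DataFrame, got {values.shape[1]} columns."
            )
        return values.iloc[:, 0].to_numpy()
    arr = np.asarray(values)
    # Assumption: an (n, 1) array is treated like a one-column DataFrame.
    if arr.ndim == 2 and arr.shape[1] == 1:
        arr = arr[:, 0]
    if arr.ndim != 1:
        raise ValueError(f"Expected 1-D values, got shape {arr.shape}.")
    return arr


df = pd.DataFrame({"target": [1.0, 2.0, 3.0]})
print(as_1d(df))  # [1. 2. 3.]
```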

Parameters
  • y_true (np.array or pd.DataFrame) – Actual values.

  • y_pred (np.array or pd.DataFrame) – Values predicted by the model.

  • with_correlation (bool, default False) – If True, print the correlation coefficient in the top left corner.

  • with_determination (bool, default True) – If True, print the coefficient of determination in the top left corner. If both with_correlation and with_determination are True, the correlation coefficient takes precedence and is printed.

  • with_histograms (bool, default False) – If True, plot histograms of y_true and y_pred on the sides of the scatter plot. Not available when ax is provided.

  • with_identity (bool, default False) – If True, plot the identity line on the scatter plot.

  • label_x (str, default 'Ground truth') – Label for x axis.

  • label_y (str, default 'Prediction') – Label for y axis.

  • title (str, default 'Predicted vs Actual') – Title of the plot, set at the axes level.

  • ax (plt.axes, optional) – Matplotlib axes to draw on. If None, a new figure and axes are created.

  • lim_x (Tuple[TNum, TNum], optional) – Limits of the x axis. If None, they are computed from the range of the data, with an extra 5% margin for readability.

  • lim_y (Tuple[TNum, TNum], optional) – Limits of the y axis. If None, they are computed from the range of the data, with an extra 5% margin for readability.

  • grid (str or None, default 'both') – Axis on which to activate the grid ('both', 'x' or 'y'). Set to None to turn the grid off.

  • figsize (Tuple[int, int], default (14, 7)) – Size of the figure to plot.

  • dpi (int, default 80) – Resolution of the figure, in dots per inch.

  • style (str, default 'default') – Matplotlib style to use. The style is applied only within this function and is not set globally.

  • **kwargs – Additional keyword arguments to be passed to the plt.plot function from matplotlib.
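The corner annotation and the default axis limits described above can be sketched with NumPy. This is not bff's implementation: `annotation_text` and `padded_limits` are hypothetical names, the three-decimal formatting is an assumption, and the 5% margin is assumed to be added on each side of the data range.

```python
import numpy as np


def annotation_text(y_true, y_pred, with_correlation=False, with_determination=True):
    """Hypothetical sketch of the text shown in the top left corner."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if with_correlation:
        # Pearson correlation coefficient; takes precedence when both flags are set.
        r = np.corrcoef(y_true, y_pred)[0, 1]
        return f"r = {r:.3f}"
    if with_determination:
        # Coefficient of determination: 1 - SS_res / SS_tot.
        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - y_true.mean()) ** 2)
        return f"R² = {1 - ss_res / ss_tot:.3f}"
    return None


def padded_limits(values, margin=0.05):
    """Data range widened by `margin` of its span on each side (assumed behavior)."""
    lo, hi = float(np.min(values)), float(np.max(values))
    pad = (hi - lo) * margin
    return lo - pad, hi + pad


print(annotation_text([1, 2, 3], [1, 2, 4]))  # R² = 0.500
```

Here both flags are checked in order, so the correlation branch wins when both are set, matching the precedence stated for with_determination.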

Returns

The axes returned by plt.subplots. If with_histograms is True, the three axes are returned.

Return type

plt.axes
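With with_histograms=True, three axes are returned. The following Matplotlib code sketches one way to build such a layout, assuming a scatter plot with one marginal histogram per variable. The function name, grid ratios and order of the returned axes are assumptions, not bff's documented behavior.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def true_vs_pred_with_histograms(y_true, y_pred, figsize=(14, 7), dpi=80):
    """Hypothetical layout: scatter plot plus two marginal histograms."""
    fig = plt.figure(figsize=figsize, dpi=dpi)
    grid = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                            wspace=0.05, hspace=0.05)
    ax_scatter = fig.add_subplot(grid[1, 0])
    # Top histogram shares the x axis (ground truth) with the scatter plot.
    ax_hist_x = fig.add_subplot(grid[0, 0], sharex=ax_scatter)
    # Right histogram shares the y axis (predictions) with the scatter plot.
    ax_hist_y = fig.add_subplot(grid[1, 1], sharey=ax_scatter)
    ax_scatter.scatter(y_true, y_pred, s=10)
    ax_hist_x.hist(y_true, bins=30)
    ax_hist_y.hist(y_pred, bins=30, orientation="horizontal")
    return ax_scatter, ax_hist_x, ax_hist_y


rng = np.random.default_rng(0)
y_true = rng.normal(size=200)
ax, ax_top, ax_right = true_vs_pred_with_histograms(y_true, y_true + rng.normal(scale=0.2, size=200))
```

Sharing the x and y axes keeps each histogram aligned with the corresponding axis of the scatter plot when limits change.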

Examples

>>> from bff.plot import plot_true_vs_pred
>>> y_pred = model.predict(x_test, ...)
>>> plot_true_vs_pred(y_true, y_pred, title='MyTitle', linestyle=':')