bff.plot.plot_true_vs_pred
bff.plot.plot_true_vs_pred(y_true, y_pred, with_correlation=False, with_determination=True, with_histograms=False, with_identity=False, label_x='Ground truth', label_y='Prediction', title='Predicted vs Actual', ax=None, lim_x=None, lim_y=None, grid='both', figsize=(14, 7), dpi=80, style='default', **kwargs)
Plot the ground truth against the predictions of the model.
If a DataFrame is provided, it must only contain one column.
- Parameters
y_true (np.array or pd.DataFrame) – Actual values.
y_pred (np.array or pd.DataFrame) – Predicted values by the model.
with_correlation (bool, default False) – If true, print the correlation coefficient in the top left corner.
with_determination (bool, default True) – If true, print the determination coefficient in the top left corner. If both with_correlation and with_determination are set to true, the correlation coefficient is printed.
with_histograms (bool, default False) – If true, plot histograms of y_true and y_pred on the sides. Not possible if ax is provided.
with_identity (bool, default False) – If true, plot the identity line on the scatter plot.
label_x (str, default 'Ground truth') – Label for x axis.
label_y (str, default 'Prediction') – Label for y axis.
title (str, default 'Predicted vs Actual') – Title for the plot (axis level).
ax (plt.axes, optional) – Axes from matplotlib. If None, a new figure and axes are created.
lim_x (Tuple[TNum, TNum], optional) – Limit for the x axis. If None, automatically calculated according to the limits of the data, with an extra 5% for readability.
lim_y (Tuple[TNum, TNum], optional) – Limit for the y axis. If None, automatically calculated according to the limits of the data, with an extra 5% for readability.
grid (str or None, default 'both') – Axis where to activate the grid (‘both’, ‘x’, ‘y’). To turn off, set to None.
figsize (Tuple[int, int], default (14, 7)) – Size of the figure to plot.
dpi (int, default 80) – Resolution of the figure.
style (str, default 'default') – Style to use for matplotlib.pyplot. The style is used only in this context and is not applied globally.
**kwargs – Additional keyword arguments to be passed to the plt.plot function from matplotlib.
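The two coefficients behind with_correlation and with_determination can be illustrated with a small pure-Python sketch. It assumes the usual definitions (Pearson r, and R² = 1 − SS_res / SS_tot as in scikit-learn's r2_score); bff may compute them differently. The helper names pearson_r, r2 and annotation are hypothetical, and annotation mirrors the documented rule that the correlation coefficient wins when both flags are true.

```python
from math import sqrt
from typing import Optional, Sequence


def pearson_r(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov / sqrt(var_t * var_p)


def r2(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination, assumed here to be 1 - SS_res / SS_tot."""
    mean_t = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def annotation(
    y_true: Sequence[float],
    y_pred: Sequence[float],
    with_correlation: bool = False,
    with_determination: bool = True,
) -> Optional[str]:
    """Hypothetical helper: text for the top left corner of the plot.

    The correlation coefficient takes precedence when both flags are set.
    """
    if with_correlation:
        return f"r = {pearson_r(y_true, y_pred):.3f}"
    if with_determination:
        return f"R² = {r2(y_true, y_pred):.3f}"
    return None
```

With a constant offset, for example y_pred = [2, 3, 4, 5] against y_true = [1, 2, 3, 4], the correlation is perfect (r = 1) while R² drops to 0.2. This is why the two options can tell different stories about the same scatter plot.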
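When lim_x or lim_y is None, the limits are derived from the data with an extra 5% for readability. The sketch below assumes that the 5% is taken from the data span and added on each side. The exact rule used by bff is not stated here, and padded_limits is a hypothetical name.

```python
from typing import Sequence, Tuple


def padded_limits(values: Sequence[float], margin: float = 0.05) -> Tuple[float, float]:
    """Return (low, high) covering the data, widened by `margin` times the span on each side.

    Assumption: the padding is relative to the data span (max - min).
    """
    low, high = min(values), max(values)
    pad = (high - low) * margin
    return low - pad, high + pad
```

For values ranging from 0 to 10, this gives limits of (-0.5, 10.5). Such a tuple matches the Tuple[TNum, TNum] shape expected by lim_x and lim_y.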
- Returns
Axes returned by the plt.subplots function. If with_histograms is true, the three axes are returned.
- Return type
plt.axes
Examples
>>> y_pred = model.predict(x_test, ...)
>>> plot_true_vs_pred(y_true, y_pred, title='MyTitle', linestyle=':')