|
|
@@ -9,246 +9,10 @@ from itertools import product
|
|
|
from matplotlib.collections import QuadMesh
|
|
|
from scipy.spatial.distance import cdist
|
|
|
from matplotlib.figure import Figure
|
|
|
+from matplotlib.text import Text
|
|
|
from numpy.typing import NDArray, ArrayLike
|
|
|
|
|
|
-
|
|
|
-def get_new_fig(fn, figsize=[9, 9]):
|
|
|
- """Init graphics"""
|
|
|
- fig1 = plt.figure(fn, figsize)
|
|
|
- ax1 = fig1.gca() # Get Current Axis
|
|
|
- ax1.cla() # clear existing plot
|
|
|
- return fig1, ax1
|
|
|
-
|
|
|
-
|
|
|
-def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=False):
|
|
|
- """
|
|
|
- config cell text and colors
|
|
|
- and return text elements to add and to dell
|
|
|
- @TODO: use fmt
|
|
|
- """
|
|
|
- text_add = []
|
|
|
- text_del = []
|
|
|
- cell_val = array_df[lin][col]
|
|
|
- tot_all = array_df[-1][-1]
|
|
|
- per = (float(cell_val) / tot_all) * 100
|
|
|
- curr_column = array_df[:, col]
|
|
|
- ccl = len(curr_column)
|
|
|
-
|
|
|
- # last line and/or last column
|
|
|
- if (col == (ccl - 1)) or (lin == (ccl - 1)):
|
|
|
- # tots and percents
|
|
|
- if cell_val != 0:
|
|
|
- if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
- tot_rig = 0
|
|
|
- for i in range(array_df.shape[0] - 1):
|
|
|
- tot_rig += array_df[i][i]
|
|
|
- per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
- elif col == ccl - 1:
|
|
|
- tot_rig = array_df[lin][lin]
|
|
|
- per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
- elif lin == ccl - 1:
|
|
|
- tot_rig = array_df[col][col]
|
|
|
- per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
- per_err = 100 - per_ok # type: ignore
|
|
|
- else:
|
|
|
- per_ok = per_err = 0
|
|
|
-
|
|
|
- per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100] # type: ignore
|
|
|
-
|
|
|
- # text to DEL
|
|
|
- text_del.append(oText)
|
|
|
-
|
|
|
- # text to ADD
|
|
|
- font_prop = fm.FontProperties(weight="bold", size=fz)
|
|
|
- text_kwargs = dict(
|
|
|
- color="w",
|
|
|
- ha="center",
|
|
|
- va="center",
|
|
|
- gid="sum",
|
|
|
- fontproperties=font_prop,
|
|
|
- )
|
|
|
- lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
|
|
|
- lis_kwa = [text_kwargs]
|
|
|
- dic = text_kwargs.copy()
|
|
|
- dic["color"] = "g"
|
|
|
- lis_kwa.append(dic)
|
|
|
- dic = text_kwargs.copy()
|
|
|
- dic["color"] = "r"
|
|
|
- lis_kwa.append(dic)
|
|
|
- lis_pos = [
|
|
|
- (oText._x, oText._y - 0.3),
|
|
|
- (oText._x, oText._y),
|
|
|
- (oText._x, oText._y + 0.3),
|
|
|
- ]
|
|
|
- for i in range(len(lis_txt)):
|
|
|
- newText = dict(
|
|
|
- x=lis_pos[i][0],
|
|
|
- y=lis_pos[i][1],
|
|
|
- text=lis_txt[i],
|
|
|
- kw=lis_kwa[i],
|
|
|
- )
|
|
|
- text_add.append(newText)
|
|
|
-
|
|
|
- # set background color for sum cells (last line and last column)
|
|
|
- carr = [0.27, 0.30, 0.27, 1.0]
|
|
|
- if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
- carr = [0.17, 0.20, 0.17, 1.0]
|
|
|
- facecolors[posi] = carr
|
|
|
-
|
|
|
- else:
|
|
|
- if per > 0:
|
|
|
- txt = "%s\n%.1f%%" % (cell_val, per)
|
|
|
- else:
|
|
|
- if show_null_values == False:
|
|
|
- txt = ""
|
|
|
- elif show_null_values == True:
|
|
|
- txt = "0"
|
|
|
- else:
|
|
|
- txt = "0\n0.0%"
|
|
|
- oText.set_text(txt)
|
|
|
-
|
|
|
- # main diagonal
|
|
|
- if col == lin:
|
|
|
- # set color of the textin the diagonal to white
|
|
|
- oText.set_color("w")
|
|
|
- # set background color in the diagonal to blue
|
|
|
- facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
|
|
|
- else:
|
|
|
- oText.set_color("r")
|
|
|
-
|
|
|
- return text_add, text_del
|
|
|
-
|
|
|
-
|
|
|
-def insert_totals(df_cm):
|
|
|
- """insert total column and line (the last ones)"""
|
|
|
- sum_col = []
|
|
|
- for c in df_cm.columns:
|
|
|
- sum_col.append(df_cm[c].sum())
|
|
|
- sum_lin = []
|
|
|
- for item_line in df_cm.iterrows():
|
|
|
- sum_lin.append(item_line[1].sum())
|
|
|
- df_cm["sum_lin"] = sum_lin
|
|
|
- sum_col.append(np.sum(sum_lin))
|
|
|
- df_cm.loc["sum_col"] = sum_col
|
|
|
-
|
|
|
-
|
|
|
-def pp_matrix(
|
|
|
- df_cm: NDArray[np.float64] | pd.DataFrame,
|
|
|
- annot=True,
|
|
|
- cmap="viridis",
|
|
|
- fmt=".2f",
|
|
|
- fz=10,
|
|
|
- lw=1,
|
|
|
- cbar=False,
|
|
|
- figsize=[9, 9],
|
|
|
- show_null_values=False,
|
|
|
- pred_val_axis="x",
|
|
|
- show=False,
|
|
|
- rotation=True,
|
|
|
- display_labels=None,
|
|
|
-):
|
|
|
- """
|
|
|
- print conf matrix with default layout (like matlab)
|
|
|
- params:
|
|
|
- df_cm dataframe (pandas) without totals
|
|
|
- annot print text in each cell
|
|
|
- cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
|
|
|
- fz fontsize
|
|
|
- lw linewidth
|
|
|
- pred_val_axis where to show the prediction values (x or y axis)
|
|
|
- 'col' or 'x': show predicted values in columns (x axis) instead lines
|
|
|
- 'lin' or 'y': show predicted values in lines (y axis)
|
|
|
- show show the plot or not
|
|
|
- rotation rotate or not labels on figure
|
|
|
- display_labels None, list of labels that display on figure
|
|
|
- """
|
|
|
- if not isinstance(df_cm, pd.DataFrame):
|
|
|
- df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
|
|
|
-
|
|
|
- if pred_val_axis in ("col", "x"):
|
|
|
- xlbl = "Predicted"
|
|
|
- ylbl = "Actual"
|
|
|
- else:
|
|
|
- xlbl = "Actual"
|
|
|
- ylbl = "Predicted"
|
|
|
- df_cm = df_cm.T
|
|
|
-
|
|
|
- # create "Total" column
|
|
|
- insert_totals(df_cm)
|
|
|
-
|
|
|
- # this is for print allways in the same window
|
|
|
- fig, ax1 = get_new_fig("Conf matrix default", figsize)
|
|
|
-
|
|
|
- ax = sn.heatmap(
|
|
|
- df_cm,
|
|
|
- annot=annot,
|
|
|
- annot_kws={"size": fz},
|
|
|
- linewidths=lw,
|
|
|
- ax=ax1,
|
|
|
- cbar=cbar,
|
|
|
- cmap=cmap,
|
|
|
- linecolor="w",
|
|
|
- fmt=fmt,
|
|
|
- )
|
|
|
-
|
|
|
- # set ticklabels rotation
|
|
|
- if rotation:
|
|
|
- rotation_x = 45
|
|
|
- rotation_y = 25
|
|
|
- else:
|
|
|
- rotation_x = 0
|
|
|
- rotation_y = 90
|
|
|
-
|
|
|
- ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
|
|
|
- ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
|
|
|
-
|
|
|
- # Turn off all the ticks
|
|
|
- for t in ax.xaxis.get_major_ticks():
|
|
|
- t.tick1On = False
|
|
|
- t.tick2On = False
|
|
|
- for t in ax.yaxis.get_major_ticks():
|
|
|
- t.tick1On = False
|
|
|
- t.tick2On = False
|
|
|
-
|
|
|
- # face colors list
|
|
|
- quadmesh = ax.findobj(QuadMesh)[0]
|
|
|
- facecolors = quadmesh.get_facecolors()
|
|
|
-
|
|
|
- # iter in text elements
|
|
|
- array_df = np.array(df_cm.to_records(index=False).tolist())
|
|
|
- text_add = []
|
|
|
- text_del = []
|
|
|
- posi = -1 # from left to right, bottom to top.
|
|
|
- for t in ax.collections[0].axes.texts: # ax.texts:
|
|
|
- pos = np.array(t.get_position()) - [0.5, 0.5]
|
|
|
- lin = int(pos[1])
|
|
|
- col = int(pos[0])
|
|
|
- posi += 1
|
|
|
-
|
|
|
- # set text
|
|
|
- txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
|
|
|
-
|
|
|
- text_add.extend(txt_res[0])
|
|
|
- text_del.extend(txt_res[1])
|
|
|
-
|
|
|
- # remove the old ones
|
|
|
- for item in text_del:
|
|
|
- item.remove()
|
|
|
- # append the new ones
|
|
|
- for item in text_add:
|
|
|
- ax.text(item["x"], item["y"], item["text"], **item["kw"])
|
|
|
-
|
|
|
- # titles and legends
|
|
|
- ax.set_title("Confusion matrix")
|
|
|
- ax.set_xlabel(xlbl)
|
|
|
- ax.set_ylabel(ylbl)
|
|
|
- # set layout slim
|
|
|
- plt.tight_layout()
|
|
|
- if show:
|
|
|
- plt.show()
|
|
|
-
|
|
|
- return plt.gcf()
|
|
|
-
|
|
|
+Color = tuple[float, float, float, float]
|
|
|
|
|
|
class ConfusionMatrix:
|
|
|
def __init__(self, thrs_config: dict, class_names: dict, iou_thr=0.5,):
|
|
|
@@ -454,12 +218,12 @@ class ConfusionMatrix:
|
|
|
return iou
|
|
|
|
|
|
def mask_pairwise_iou(self, masks1: np.ndarray, masks2: np.ndarray, labels1: np.ndarray, labels2: np.ndarray):
|
|
|
-
|
|
|
+ # TODO: implement it finally!
|
|
|
"""Need to have been imnplemented eventually and tested"""
|
|
|
|
|
|
f1 = np.array(zip(masks1, labels1))
|
|
|
f2 = np.array(zip(masks2, labels2))
|
|
|
- return cdist(f1, f2, metric=self.mask_iou) # TODO: implement it finally!
|
|
|
+ return cdist(f1, f2, metric=self.mask_iou) # type: ignore
|
|
|
|
|
|
def return_matrix(self):
|
|
|
"""Returns tuple of box and mask confusion matrix."""
|
|
|
@@ -628,3 +392,264 @@ class ConfusionMatrix:
|
|
|
self.figure_ = fig
|
|
|
self.ax_ = ax
|
|
|
return fig
|
|
|
+
|
|
|
+
|
|
|
+# Helper function to draw pretty figure
|
|
|
+# -------------------------------------
|
|
|
+
|
|
|
+
|
|
|
+# This is main function to draw pretty confusion matrix
|
|
|
+
|
|
|
+def pp_matrix(
|
|
|
+ df_cm: NDArray[np.float64] | pd.DataFrame,
|
|
|
+ annot=True,
|
|
|
+ cmap="viridis",
|
|
|
+ fmt=".2f",
|
|
|
+ fz=10,
|
|
|
+ lw=1,
|
|
|
+ cbar=False,
|
|
|
+ figsize=[9, 9],
|
|
|
+ show_null_values=False,
|
|
|
+ pred_val_axis="x",
|
|
|
+ show=False,
|
|
|
+ rotation=True,
|
|
|
+ display_labels=None,
|
|
|
+):
|
|
|
+ """
|
|
|
+ print conf matrix with default layout (like matlab)
|
|
|
+ params:
|
|
|
+ df_cm dataframe (pandas) without totals
|
|
|
+ annot print text in each cell
|
|
|
+ cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
|
|
|
+ fz fontsize
|
|
|
+ lw linewidth
|
|
|
+ pred_val_axis where to show the prediction values (x or y axis)
|
|
|
+ 'col' or 'x': show predicted values in columns (x axis) instead lines
|
|
|
+ 'lin' or 'y': show predicted values in lines (y axis)
|
|
|
+ show show the plot or not
|
|
|
+ rotation rotate or not labels on figure
|
|
|
+ display_labels None, list of labels that display on figure
|
|
|
+ """
|
|
|
+ if not isinstance(df_cm, pd.DataFrame):
|
|
|
+ df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
|
|
|
+
|
|
|
+ if pred_val_axis in ("col", "x"):
|
|
|
+ xlbl = "Predicted"
|
|
|
+ ylbl = "Actual"
|
|
|
+ else:
|
|
|
+ xlbl = "Actual"
|
|
|
+ ylbl = "Predicted"
|
|
|
+ df_cm = df_cm.T
|
|
|
+
|
|
|
+ # create "Total" column
|
|
|
+ insert_totals(df_cm)
|
|
|
+
|
|
|
+ # this is for print allways in the same window
|
|
|
+ fig, ax1 = get_new_fig("Conf matrix default", figsize)
|
|
|
+
|
|
|
+ ax = sn.heatmap(
|
|
|
+ df_cm,
|
|
|
+ annot=annot,
|
|
|
+ annot_kws={"size": fz},
|
|
|
+ linewidths=lw,
|
|
|
+ ax=ax1,
|
|
|
+ cbar=cbar,
|
|
|
+ cmap=cmap,
|
|
|
+ linecolor="w",
|
|
|
+ fmt=fmt,
|
|
|
+ )
|
|
|
+
|
|
|
+ # set ticklabels rotation
|
|
|
+ if rotation:
|
|
|
+ rotation_x = 45
|
|
|
+ rotation_y = 25
|
|
|
+ else:
|
|
|
+ rotation_x = 0
|
|
|
+ rotation_y = 90
|
|
|
+
|
|
|
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
|
|
|
+ ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
|
|
|
+
|
|
|
+ # Turn off all the ticks
|
|
|
+ for t in ax.xaxis.get_major_ticks():
|
|
|
+ t.tick1On = False
|
|
|
+ t.tick2On = False
|
|
|
+ for t in ax.yaxis.get_major_ticks():
|
|
|
+ t.tick1On = False
|
|
|
+ t.tick2On = False
|
|
|
+
|
|
|
+ # face colors list
|
|
|
+ quadmesh = ax.findobj(QuadMesh)[0]
|
|
|
+ facecolors = quadmesh.get_facecolors()
|
|
|
+
|
|
|
+ # iter in text elements
|
|
|
+ array_df = np.array(df_cm.to_records(index=False).tolist())
|
|
|
+ text_add = []
|
|
|
+ text_del = []
|
|
|
+ posi = -1 # from left to right, bottom to top.
|
|
|
+ for t in ax.collections[0].axes.texts: # ax.texts:
|
|
|
+ pos = np.array(t.get_position()) - [0.5, 0.5]
|
|
|
+ lin = int(pos[1])
|
|
|
+ col = int(pos[0])
|
|
|
+ posi += 1
|
|
|
+
|
|
|
+ # set text
|
|
|
+ txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
|
|
|
+
|
|
|
+ text_add.extend(txt_res[0])
|
|
|
+ text_del.extend(txt_res[1])
|
|
|
+
|
|
|
+ # remove the old ones
|
|
|
+ for item in text_del:
|
|
|
+ item.remove()
|
|
|
+ # append the new ones
|
|
|
+ for item in text_add:
|
|
|
+ ax.text(item["x"], item["y"], item["text"], **item["kw"])
|
|
|
+
|
|
|
+ # titles and legends
|
|
|
+ ax.set_title("Confusion matrix")
|
|
|
+ ax.set_xlabel(xlbl)
|
|
|
+ ax.set_ylabel(ylbl)
|
|
|
+ # set layout slim
|
|
|
+ plt.tight_layout()
|
|
|
+ if show:
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ return plt.gcf()
|
|
|
+
|
|
|
+
|
|
|
+# This is helper functions for pp_print function
|
|
|
+
|
|
|
+def get_new_fig(fn, figsize=[9, 9]):
|
|
|
+ """Init graphics"""
|
|
|
+ fig1 = plt.figure(fn, figsize)
|
|
|
+ ax1 = fig1.gca() # Get Current Axis
|
|
|
+ ax1.cla() # clear existing plot
|
|
|
+ return fig1, ax1
|
|
|
+
|
|
|
+
|
|
|
+def configcell_text_and_colors(
|
|
|
+ array_df: np.ndarray,
|
|
|
+ lin: int,
|
|
|
+ col: int,
|
|
|
+ oText: Text,
|
|
|
+ facecolors: list[Color],
|
|
|
+ posi: int,
|
|
|
+ fz: int,
|
|
|
+ fmt: str,
|
|
|
+ show_null_values=False,
|
|
|
+):
|
|
|
+ """
|
|
|
+ config cell text and colors
|
|
|
+ and return text elements to add and to dell
|
|
|
+ """
|
|
|
+ text_add = []
|
|
|
+ text_del = []
|
|
|
+ cell_val = array_df[lin][col]
|
|
|
+ tot_all = array_df[-1][-1]
|
|
|
+ per = (float(cell_val) / tot_all) * 100
|
|
|
+ curr_column = array_df[:, col]
|
|
|
+ ccl = len(curr_column)
|
|
|
+
|
|
|
+ # last line and/or last column
|
|
|
+ if (col == (ccl - 1)) or (lin == (ccl - 1)):
|
|
|
+ # tots and percents
|
|
|
+ if cell_val != 0:
|
|
|
+ if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
+ tot_rig = 0
|
|
|
+ for i in range(array_df.shape[0] - 1):
|
|
|
+ tot_rig += array_df[i][i]
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
+ elif col == ccl - 1:
|
|
|
+ tot_rig = array_df[lin][lin]
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
+ elif lin == ccl - 1:
|
|
|
+ tot_rig = array_df[col][col]
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
+ per_err = 100 - per_ok # type: ignore
|
|
|
+ else:
|
|
|
+ per_ok = per_err = 0
|
|
|
+
|
|
|
+ per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100] # type: ignore
|
|
|
+
|
|
|
+ # text to DEL
|
|
|
+ text_del.append(oText)
|
|
|
+
|
|
|
+ # text to ADD
|
|
|
+ font_prop = fm.FontProperties(weight="bold", size=fz)
|
|
|
+ text_kwargs = dict(
|
|
|
+ color="w",
|
|
|
+ ha="center",
|
|
|
+ va="center",
|
|
|
+ gid="sum",
|
|
|
+ fontproperties=font_prop,
|
|
|
+ )
|
|
|
+ lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
|
|
|
+ lis_kwa = [text_kwargs]
|
|
|
+ dic = text_kwargs.copy()
|
|
|
+ dic["color"] = "g"
|
|
|
+ lis_kwa.append(dic)
|
|
|
+ dic = text_kwargs.copy()
|
|
|
+ dic["color"] = "r"
|
|
|
+ lis_kwa.append(dic)
|
|
|
+ lis_pos = [
|
|
|
+ (oText._x, oText._y - 0.3),
|
|
|
+ (oText._x, oText._y),
|
|
|
+ (oText._x, oText._y + 0.3),
|
|
|
+ ]
|
|
|
+ for i in range(len(lis_txt)):
|
|
|
+ new_text = dict(
|
|
|
+ x=lis_pos[i][0],
|
|
|
+ y=lis_pos[i][1],
|
|
|
+ text=lis_txt[i],
|
|
|
+ kw=lis_kwa[i],
|
|
|
+ )
|
|
|
+ text_add.append(new_text)
|
|
|
+
|
|
|
+ # set background color for sum cells (last line and last column)
|
|
|
+ carr = (0.27, 0.30, 0.27, 1.0)
|
|
|
+ if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
+ carr = (0.17, 0.20, 0.17, 1.0)
|
|
|
+ facecolors[posi] = carr
|
|
|
+
|
|
|
+ else:
|
|
|
+ if per > 0:
|
|
|
+ txt = "%s\n%.1f%%" % (cell_val, per)
|
|
|
+ else:
|
|
|
+ if show_null_values == False:
|
|
|
+ txt = ""
|
|
|
+ elif show_null_values == True:
|
|
|
+ txt = "0"
|
|
|
+ else:
|
|
|
+ txt = "0\n0.0%"
|
|
|
+ oText.set_text(txt)
|
|
|
+
|
|
|
+ # main diagonal
|
|
|
+ if col == lin:
|
|
|
+ # set color of the textin the diagonal to white
|
|
|
+ oText.set_color("w")
|
|
|
+ # set background color in the diagonal to blue
|
|
|
+ facecolors[posi] = (0.35, 0.8, 0.55, 1.0)
|
|
|
+ else:
|
|
|
+ oText.set_color("r")
|
|
|
+
|
|
|
+ return text_add, text_del
|
|
|
+
|
|
|
+
|
|
|
+def insert_totals(df_cm):
|
|
|
+ """insert total column and line (the last ones)"""
|
|
|
+ sum_col = []
|
|
|
+ for c in df_cm.columns:
|
|
|
+ sum_col.append(df_cm[c].sum())
|
|
|
+ sum_lin = []
|
|
|
+ for item_line in df_cm.iterrows():
|
|
|
+ sum_lin.append(item_line[1].sum())
|
|
|
+ df_cm["sum_lin"] = sum_lin
|
|
|
+ sum_col.append(np.sum(sum_lin))
|
|
|
+ df_cm.loc["sum_col"] = sum_col
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|