metya пре 3 година
родитељ
комит
1dbb315ae8
1 измењених фајлова са 265 додато и 240 уклоњено
  1. 265 240
      confusion_matrix.py

+ 265 - 240
confusion_matrix.py

@@ -9,246 +9,10 @@ from itertools import product
 from matplotlib.collections import QuadMesh
 from scipy.spatial.distance import cdist
 from matplotlib.figure import Figure
+from matplotlib.text import Text
 from numpy.typing import NDArray, ArrayLike
 
-
-def get_new_fig(fn, figsize=[9, 9]):
-    """Init graphics"""
-    fig1 = plt.figure(fn, figsize)
-    ax1 = fig1.gca()  # Get Current Axis
-    ax1.cla()  # clear existing plot
-    return fig1, ax1
-
-
-def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=False):
-    """
-    config cell text and colors
-    and return text elements to add and to dell
-    @TODO: use fmt
-    """
-    text_add = []
-    text_del = []
-    cell_val = array_df[lin][col]
-    tot_all = array_df[-1][-1]
-    per = (float(cell_val) / tot_all) * 100
-    curr_column = array_df[:, col]
-    ccl = len(curr_column)
-
-    # last line  and/or last column
-    if (col == (ccl - 1)) or (lin == (ccl - 1)):
-        # tots and percents
-        if cell_val != 0:
-            if (col == ccl - 1) and (lin == ccl - 1):
-                tot_rig = 0
-                for i in range(array_df.shape[0] - 1):
-                    tot_rig += array_df[i][i]
-                per_ok = (float(tot_rig) / cell_val) * 100
-            elif col == ccl - 1:
-                tot_rig = array_df[lin][lin]
-                per_ok = (float(tot_rig) / cell_val) * 100
-            elif lin == ccl - 1:
-                tot_rig = array_df[col][col]
-                per_ok = (float(tot_rig) / cell_val) * 100
-            per_err = 100 - per_ok  # type: ignore
-        else:
-            per_ok = per_err = 0
-
-        per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100]  # type: ignore
-
-        # text to DEL
-        text_del.append(oText)
-
-        # text to ADD
-        font_prop = fm.FontProperties(weight="bold", size=fz)
-        text_kwargs = dict(
-            color="w",
-            ha="center",
-            va="center",
-            gid="sum",
-            fontproperties=font_prop,
-        )
-        lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
-        lis_kwa = [text_kwargs]
-        dic = text_kwargs.copy()
-        dic["color"] = "g"
-        lis_kwa.append(dic)
-        dic = text_kwargs.copy()
-        dic["color"] = "r"
-        lis_kwa.append(dic)
-        lis_pos = [
-            (oText._x, oText._y - 0.3),
-            (oText._x, oText._y),
-            (oText._x, oText._y + 0.3),
-        ]
-        for i in range(len(lis_txt)):
-            newText = dict(
-                x=lis_pos[i][0],
-                y=lis_pos[i][1],
-                text=lis_txt[i],
-                kw=lis_kwa[i],
-            )
-            text_add.append(newText)
-
-        # set background color for sum cells (last line and last column)
-        carr = [0.27, 0.30, 0.27, 1.0]
-        if (col == ccl - 1) and (lin == ccl - 1):
-            carr = [0.17, 0.20, 0.17, 1.0]
-        facecolors[posi] = carr
-
-    else:
-        if per > 0:
-            txt = "%s\n%.1f%%" % (cell_val, per)
-        else:
-            if show_null_values == False:
-                txt = ""
-            elif show_null_values == True:
-                txt = "0"
-            else:
-                txt = "0\n0.0%"
-        oText.set_text(txt)
-
-        # main diagonal
-        if col == lin:
-            # set color of the textin the diagonal to white
-            oText.set_color("w")
-            # set background color in the diagonal to blue
-            facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
-        else:
-            oText.set_color("r")
-
-    return text_add, text_del
-
-
-def insert_totals(df_cm):
-    """insert total column and line (the last ones)"""
-    sum_col = []
-    for c in df_cm.columns:
-        sum_col.append(df_cm[c].sum())
-    sum_lin = []
-    for item_line in df_cm.iterrows():
-        sum_lin.append(item_line[1].sum())
-    df_cm["sum_lin"] = sum_lin
-    sum_col.append(np.sum(sum_lin))
-    df_cm.loc["sum_col"] = sum_col
-
-
-def pp_matrix(
-    df_cm: NDArray[np.float64] | pd.DataFrame,
-    annot=True,
-    cmap="viridis",
-    fmt=".2f",
-    fz=10,
-    lw=1,
-    cbar=False,
-    figsize=[9, 9],
-    show_null_values=False,
-    pred_val_axis="x",
-    show=False,
-    rotation=True,
-    display_labels=None,
-):
-    """
-    print conf matrix with default layout (like matlab)
-    params:
-      df_cm          dataframe (pandas) without totals
-      annot          print text in each cell
-      cmap           Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
-      fz             fontsize
-      lw             linewidth
-      pred_val_axis  where to show the prediction values (x or y axis)
-                      'col' or 'x': show predicted values in columns (x axis) instead lines
-                      'lin' or 'y': show predicted values in lines   (y axis)
-      show           show the plot or not
-      rotation       rotate or not labels on figure
-      display_labels None, list of labels that display on figure
-    """
-    if not isinstance(df_cm, pd.DataFrame):
-        df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
-
-    if pred_val_axis in ("col", "x"):
-        xlbl = "Predicted"
-        ylbl = "Actual"
-    else:
-        xlbl = "Actual"
-        ylbl = "Predicted"
-        df_cm = df_cm.T
-
-    # create "Total" column
-    insert_totals(df_cm)
-
-    # this is for print allways in the same window
-    fig, ax1 = get_new_fig("Conf matrix default", figsize)
-
-    ax = sn.heatmap(
-        df_cm,
-        annot=annot,
-        annot_kws={"size": fz},
-        linewidths=lw,
-        ax=ax1,
-        cbar=cbar,
-        cmap=cmap,
-        linecolor="w",
-        fmt=fmt,
-    )
-
-    # set ticklabels rotation
-    if rotation:
-        rotation_x = 45
-        rotation_y = 25
-    else:
-        rotation_x = 0
-        rotation_y = 90
-
-    ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
-    ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
-
-    # Turn off all the ticks
-    for t in ax.xaxis.get_major_ticks():
-        t.tick1On = False
-        t.tick2On = False
-    for t in ax.yaxis.get_major_ticks():
-        t.tick1On = False
-        t.tick2On = False
-
-    # face colors list
-    quadmesh = ax.findobj(QuadMesh)[0]
-    facecolors = quadmesh.get_facecolors()
-
-    # iter in text elements
-    array_df = np.array(df_cm.to_records(index=False).tolist())
-    text_add = []
-    text_del = []
-    posi = -1  # from left to right, bottom to top.
-    for t in ax.collections[0].axes.texts:  # ax.texts:
-        pos = np.array(t.get_position()) - [0.5, 0.5]
-        lin = int(pos[1])
-        col = int(pos[0])
-        posi += 1
-
-        # set text
-        txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
-
-        text_add.extend(txt_res[0])
-        text_del.extend(txt_res[1])
-
-    # remove the old ones
-    for item in text_del:
-        item.remove()
-    # append the new ones
-    for item in text_add:
-        ax.text(item["x"], item["y"], item["text"], **item["kw"])
-
-    # titles and legends
-    ax.set_title("Confusion matrix")
-    ax.set_xlabel(xlbl)
-    ax.set_ylabel(ylbl)
-    # set layout slim
-    plt.tight_layout()
-    if show:
-        plt.show()
-
-    return plt.gcf()
-
+Color = tuple[float, float, float, float]
 
 class ConfusionMatrix:
     def __init__(self, thrs_config: dict, class_names: dict, iou_thr=0.5,):
@@ -454,12 +218,12 @@ class ConfusionMatrix:
         return iou
 
     def mask_pairwise_iou(self, masks1: np.ndarray, masks2: np.ndarray, labels1: np.ndarray, labels2: np.ndarray):
-
+        # TODO: implement it finally!
         """Need to have been imnplemented eventually and tested"""
 
         f1 = np.array(zip(masks1, labels1))
         f2 = np.array(zip(masks2, labels2))
-        return cdist(f1, f2, metric=self.mask_iou)  # TODO: implement it finally!
+        return cdist(f1, f2, metric=self.mask_iou) # type: ignore
 
     def return_matrix(self):
         """Returns tuple of box and mask confusion matrix."""
@@ -628,3 +392,264 @@ class ConfusionMatrix:
         self.figure_ = fig
         self.ax_ = ax
         return fig
+
+
+# Helper function to draw pretty figure
+# -------------------------------------
+
+
+# This is main function to draw pretty confusion matrix
+
+def pp_matrix(
+    df_cm: NDArray[np.float64] | pd.DataFrame,
+    annot=True,
+    cmap="viridis",
+    fmt=".2f",
+    fz=10,
+    lw=1,
+    cbar=False,
+    figsize=[9, 9],
+    show_null_values=False,
+    pred_val_axis="x",
+    show=False,
+    rotation=True,
+    display_labels=None,
+):
+    """
+    print conf matrix with default layout (like matlab)
+    params:
+      df_cm          dataframe (pandas) without totals
+      annot          print text in each cell
+      cmap           Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
+      fz             fontsize
+      lw             linewidth
+      pred_val_axis  where to show the prediction values (x or y axis)
+                      'col' or 'x': show predicted values in columns (x axis) instead lines
+                      'lin' or 'y': show predicted values in lines   (y axis)
+      show           show the plot or not
+      rotation       rotate or not labels on figure
+      display_labels None, list of labels that display on figure
+    """
+    if not isinstance(df_cm, pd.DataFrame):
+        df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
+
+    if pred_val_axis in ("col", "x"):
+        xlbl = "Predicted"
+        ylbl = "Actual"
+    else:
+        xlbl = "Actual"
+        ylbl = "Predicted"
+        df_cm = df_cm.T
+
+    # create "Total" column
+    insert_totals(df_cm)
+
+    # this is for print allways in the same window
+    fig, ax1 = get_new_fig("Conf matrix default", figsize)
+
+    ax = sn.heatmap(
+        df_cm,
+        annot=annot,
+        annot_kws={"size": fz},
+        linewidths=lw,
+        ax=ax1,
+        cbar=cbar,
+        cmap=cmap,
+        linecolor="w",
+        fmt=fmt,
+    )
+
+    # set ticklabels rotation
+    if rotation:
+        rotation_x = 45
+        rotation_y = 25
+    else:
+        rotation_x = 0
+        rotation_y = 90
+
+    ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
+    ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
+
+    # Turn off all the ticks
+    for t in ax.xaxis.get_major_ticks():
+        t.tick1On = False
+        t.tick2On = False
+    for t in ax.yaxis.get_major_ticks():
+        t.tick1On = False
+        t.tick2On = False
+
+    # face colors list
+    quadmesh = ax.findobj(QuadMesh)[0]
+    facecolors = quadmesh.get_facecolors()
+
+    # iter in text elements
+    array_df = np.array(df_cm.to_records(index=False).tolist())
+    text_add = []
+    text_del = []
+    posi = -1  # from left to right, bottom to top.
+    for t in ax.collections[0].axes.texts:  # ax.texts:
+        pos = np.array(t.get_position()) - [0.5, 0.5]
+        lin = int(pos[1])
+        col = int(pos[0])
+        posi += 1
+
+        # set text
+        txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
+
+        text_add.extend(txt_res[0])
+        text_del.extend(txt_res[1])
+
+    # remove the old ones
+    for item in text_del:
+        item.remove()
+    # append the new ones
+    for item in text_add:
+        ax.text(item["x"], item["y"], item["text"], **item["kw"])
+
+    # titles and legends
+    ax.set_title("Confusion matrix")
+    ax.set_xlabel(xlbl)
+    ax.set_ylabel(ylbl)
+    # set layout slim
+    plt.tight_layout()
+    if show:
+        plt.show()
+
+    return plt.gcf()
+
+
+# This is helper functions for pp_print function
+
+def get_new_fig(fn, figsize=[9, 9]):
+    """Init graphics"""
+    fig1 = plt.figure(fn, figsize)
+    ax1 = fig1.gca()  # Get Current Axis
+    ax1.cla()  # clear existing plot
+    return fig1, ax1
+
+
+def configcell_text_and_colors(
+    array_df: np.ndarray,
+    lin: int,
+    col: int,
+    oText: Text,
+    facecolors: list[Color],
+    posi: int,
+    fz: int,
+    fmt: str,
+    show_null_values=False,
+):
+    """
+    config cell text and colors
+    and return text elements to add and to dell
+    """
+    text_add = []
+    text_del = []
+    cell_val = array_df[lin][col]
+    tot_all = array_df[-1][-1]
+    per = (float(cell_val) / tot_all) * 100
+    curr_column = array_df[:, col]
+    ccl = len(curr_column)
+
+    # last line  and/or last column
+    if (col == (ccl - 1)) or (lin == (ccl - 1)):
+        # tots and percents
+        if cell_val != 0:
+            if (col == ccl - 1) and (lin == ccl - 1):
+                tot_rig = 0
+                for i in range(array_df.shape[0] - 1):
+                    tot_rig += array_df[i][i]
+                per_ok = (float(tot_rig) / cell_val) * 100
+            elif col == ccl - 1:
+                tot_rig = array_df[lin][lin]
+                per_ok = (float(tot_rig) / cell_val) * 100
+            elif lin == ccl - 1:
+                tot_rig = array_df[col][col]
+                per_ok = (float(tot_rig) / cell_val) * 100
+            per_err = 100 - per_ok  # type: ignore
+        else:
+            per_ok = per_err = 0
+
+        per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100]  # type: ignore
+
+        # text to DEL
+        text_del.append(oText)
+
+        # text to ADD
+        font_prop = fm.FontProperties(weight="bold", size=fz)
+        text_kwargs = dict(
+            color="w",
+            ha="center",
+            va="center",
+            gid="sum",
+            fontproperties=font_prop,
+        )
+        lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
+        lis_kwa = [text_kwargs]
+        dic = text_kwargs.copy()
+        dic["color"] = "g"
+        lis_kwa.append(dic)
+        dic = text_kwargs.copy()
+        dic["color"] = "r"
+        lis_kwa.append(dic)
+        lis_pos = [
+            (oText._x, oText._y - 0.3),
+            (oText._x, oText._y),
+            (oText._x, oText._y + 0.3),
+        ]
+        for i in range(len(lis_txt)):
+            new_text = dict(
+                x=lis_pos[i][0],
+                y=lis_pos[i][1],
+                text=lis_txt[i],
+                kw=lis_kwa[i],
+            )
+            text_add.append(new_text)
+
+        # set background color for sum cells (last line and last column)
+        carr = (0.27, 0.30, 0.27, 1.0)
+        if (col == ccl - 1) and (lin == ccl - 1):
+            carr = (0.17, 0.20, 0.17, 1.0)
+        facecolors[posi] = carr
+
+    else:
+        if per > 0:
+            txt = "%s\n%.1f%%" % (cell_val, per)
+        else:
+            if show_null_values == False:
+                txt = ""
+            elif show_null_values == True:
+                txt = "0"
+            else:
+                txt = "0\n0.0%"
+        oText.set_text(txt)
+
+        # main diagonal
+        if col == lin:
+            # set color of the textin the diagonal to white
+            oText.set_color("w")
+            # set background color in the diagonal to blue
+            facecolors[posi] = (0.35, 0.8, 0.55, 1.0)
+        else:
+            oText.set_color("r")
+
+    return text_add, text_del
+
+
+def insert_totals(df_cm):
+    """insert total column and line (the last ones)"""
+    sum_col = []
+    for c in df_cm.columns:
+        sum_col.append(df_cm[c].sum())
+    sum_lin = []
+    for item_line in df_cm.iterrows():
+        sum_lin.append(item_line[1].sum())
+    df_cm["sum_lin"] = sum_lin
+    sum_col.append(np.sum(sum_lin))
+    df_cm.loc["sum_col"] = sum_col
+
+
+
+
+
+