|
@@ -0,0 +1,630 @@
|
|
|
|
|
+import torch
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+import pandas as pd
|
|
|
|
|
+import seaborn as sn
|
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
|
+import matplotlib.font_manager as fm
|
|
|
|
|
+from itertools import product
|
|
|
|
|
+from matplotlib.collections import QuadMesh
|
|
|
|
|
+from scipy.spatial.distance import cdist
|
|
|
|
|
+from matplotlib.figure import Figure
|
|
|
|
|
+from numpy.typing import NDArray, ArrayLike
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_new_fig(fn, figsize=[9, 9]):
|
|
|
|
|
+ """Init graphics"""
|
|
|
|
|
+ fig1 = plt.figure(fn, figsize)
|
|
|
|
|
+ ax1 = fig1.gca() # Get Current Axis
|
|
|
|
|
+ ax1.cla() # clear existing plot
|
|
|
|
|
+ return fig1, ax1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=False):
|
|
|
|
|
+ """
|
|
|
|
|
+ config cell text and colors
|
|
|
|
|
+ and return text elements to add and to dell
|
|
|
|
|
+ @TODO: use fmt
|
|
|
|
|
+ """
|
|
|
|
|
+ text_add = []
|
|
|
|
|
+ text_del = []
|
|
|
|
|
+ cell_val = array_df[lin][col]
|
|
|
|
|
+ tot_all = array_df[-1][-1]
|
|
|
|
|
+ per = (float(cell_val) / tot_all) * 100
|
|
|
|
|
+ curr_column = array_df[:, col]
|
|
|
|
|
+ ccl = len(curr_column)
|
|
|
|
|
+
|
|
|
|
|
+ # last line and/or last column
|
|
|
|
|
+ if (col == (ccl - 1)) or (lin == (ccl - 1)):
|
|
|
|
|
+ # tots and percents
|
|
|
|
|
+ if cell_val != 0:
|
|
|
|
|
+ if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
|
|
+ tot_rig = 0
|
|
|
|
|
+ for i in range(array_df.shape[0] - 1):
|
|
|
|
|
+ tot_rig += array_df[i][i]
|
|
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
|
|
+ elif col == ccl - 1:
|
|
|
|
|
+ tot_rig = array_df[lin][lin]
|
|
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
|
|
+ elif lin == ccl - 1:
|
|
|
|
|
+ tot_rig = array_df[col][col]
|
|
|
|
|
+ per_ok = (float(tot_rig) / cell_val) * 100
|
|
|
|
|
+ per_err = 100 - per_ok # type: ignore
|
|
|
|
|
+ else:
|
|
|
|
|
+ per_ok = per_err = 0
|
|
|
|
|
+
|
|
|
|
|
+ per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100] # type: ignore
|
|
|
|
|
+
|
|
|
|
|
+ # text to DEL
|
|
|
|
|
+ text_del.append(oText)
|
|
|
|
|
+
|
|
|
|
|
+ # text to ADD
|
|
|
|
|
+ font_prop = fm.FontProperties(weight="bold", size=fz)
|
|
|
|
|
+ text_kwargs = dict(
|
|
|
|
|
+ color="w",
|
|
|
|
|
+ ha="center",
|
|
|
|
|
+ va="center",
|
|
|
|
|
+ gid="sum",
|
|
|
|
|
+ fontproperties=font_prop,
|
|
|
|
|
+ )
|
|
|
|
|
+ lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
|
|
|
|
|
+ lis_kwa = [text_kwargs]
|
|
|
|
|
+ dic = text_kwargs.copy()
|
|
|
|
|
+ dic["color"] = "g"
|
|
|
|
|
+ lis_kwa.append(dic)
|
|
|
|
|
+ dic = text_kwargs.copy()
|
|
|
|
|
+ dic["color"] = "r"
|
|
|
|
|
+ lis_kwa.append(dic)
|
|
|
|
|
+ lis_pos = [
|
|
|
|
|
+ (oText._x, oText._y - 0.3),
|
|
|
|
|
+ (oText._x, oText._y),
|
|
|
|
|
+ (oText._x, oText._y + 0.3),
|
|
|
|
|
+ ]
|
|
|
|
|
+ for i in range(len(lis_txt)):
|
|
|
|
|
+ newText = dict(
|
|
|
|
|
+ x=lis_pos[i][0],
|
|
|
|
|
+ y=lis_pos[i][1],
|
|
|
|
|
+ text=lis_txt[i],
|
|
|
|
|
+ kw=lis_kwa[i],
|
|
|
|
|
+ )
|
|
|
|
|
+ text_add.append(newText)
|
|
|
|
|
+
|
|
|
|
|
+ # set background color for sum cells (last line and last column)
|
|
|
|
|
+ carr = [0.27, 0.30, 0.27, 1.0]
|
|
|
|
|
+ if (col == ccl - 1) and (lin == ccl - 1):
|
|
|
|
|
+ carr = [0.17, 0.20, 0.17, 1.0]
|
|
|
|
|
+ facecolors[posi] = carr
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ if per > 0:
|
|
|
|
|
+ txt = "%s\n%.1f%%" % (cell_val, per)
|
|
|
|
|
+ else:
|
|
|
|
|
+ if show_null_values == False:
|
|
|
|
|
+ txt = ""
|
|
|
|
|
+ elif show_null_values == True:
|
|
|
|
|
+ txt = "0"
|
|
|
|
|
+ else:
|
|
|
|
|
+ txt = "0\n0.0%"
|
|
|
|
|
+ oText.set_text(txt)
|
|
|
|
|
+
|
|
|
|
|
+ # main diagonal
|
|
|
|
|
+ if col == lin:
|
|
|
|
|
+ # set color of the textin the diagonal to white
|
|
|
|
|
+ oText.set_color("w")
|
|
|
|
|
+ # set background color in the diagonal to blue
|
|
|
|
|
+ facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
|
|
|
|
|
+ else:
|
|
|
|
|
+ oText.set_color("r")
|
|
|
|
|
+
|
|
|
|
|
+ return text_add, text_del
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def insert_totals(df_cm):
|
|
|
|
|
+ """insert total column and line (the last ones)"""
|
|
|
|
|
+ sum_col = []
|
|
|
|
|
+ for c in df_cm.columns:
|
|
|
|
|
+ sum_col.append(df_cm[c].sum())
|
|
|
|
|
+ sum_lin = []
|
|
|
|
|
+ for item_line in df_cm.iterrows():
|
|
|
|
|
+ sum_lin.append(item_line[1].sum())
|
|
|
|
|
+ df_cm["sum_lin"] = sum_lin
|
|
|
|
|
+ sum_col.append(np.sum(sum_lin))
|
|
|
|
|
+ df_cm.loc["sum_col"] = sum_col
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def pp_matrix(
|
|
|
|
|
+ df_cm: NDArray[np.float64] | pd.DataFrame,
|
|
|
|
|
+ annot=True,
|
|
|
|
|
+ cmap="viridis",
|
|
|
|
|
+ fmt=".2f",
|
|
|
|
|
+ fz=10,
|
|
|
|
|
+ lw=1,
|
|
|
|
|
+ cbar=False,
|
|
|
|
|
+ figsize=[9, 9],
|
|
|
|
|
+ show_null_values=False,
|
|
|
|
|
+ pred_val_axis="x",
|
|
|
|
|
+ show=False,
|
|
|
|
|
+ rotation=True,
|
|
|
|
|
+ display_labels=None,
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ print conf matrix with default layout (like matlab)
|
|
|
|
|
+ params:
|
|
|
|
|
+ df_cm dataframe (pandas) without totals
|
|
|
|
|
+ annot print text in each cell
|
|
|
|
|
+ cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
|
|
|
|
|
+ fz fontsize
|
|
|
|
|
+ lw linewidth
|
|
|
|
|
+ pred_val_axis where to show the prediction values (x or y axis)
|
|
|
|
|
+ 'col' or 'x': show predicted values in columns (x axis) instead lines
|
|
|
|
|
+ 'lin' or 'y': show predicted values in lines (y axis)
|
|
|
|
|
+ show show the plot or not
|
|
|
|
|
+ rotation rotate or not labels on figure
|
|
|
|
|
+ display_labels None, list of labels that display on figure
|
|
|
|
|
+ """
|
|
|
|
|
+ if not isinstance(df_cm, pd.DataFrame):
|
|
|
|
|
+ df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
|
|
|
|
|
+
|
|
|
|
|
+ if pred_val_axis in ("col", "x"):
|
|
|
|
|
+ xlbl = "Predicted"
|
|
|
|
|
+ ylbl = "Actual"
|
|
|
|
|
+ else:
|
|
|
|
|
+ xlbl = "Actual"
|
|
|
|
|
+ ylbl = "Predicted"
|
|
|
|
|
+ df_cm = df_cm.T
|
|
|
|
|
+
|
|
|
|
|
+ # create "Total" column
|
|
|
|
|
+ insert_totals(df_cm)
|
|
|
|
|
+
|
|
|
|
|
+ # this is for print allways in the same window
|
|
|
|
|
+ fig, ax1 = get_new_fig("Conf matrix default", figsize)
|
|
|
|
|
+
|
|
|
|
|
+ ax = sn.heatmap(
|
|
|
|
|
+ df_cm,
|
|
|
|
|
+ annot=annot,
|
|
|
|
|
+ annot_kws={"size": fz},
|
|
|
|
|
+ linewidths=lw,
|
|
|
|
|
+ ax=ax1,
|
|
|
|
|
+ cbar=cbar,
|
|
|
|
|
+ cmap=cmap,
|
|
|
|
|
+ linecolor="w",
|
|
|
|
|
+ fmt=fmt,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # set ticklabels rotation
|
|
|
|
|
+ if rotation:
|
|
|
|
|
+ rotation_x = 45
|
|
|
|
|
+ rotation_y = 25
|
|
|
|
|
+ else:
|
|
|
|
|
+ rotation_x = 0
|
|
|
|
|
+ rotation_y = 90
|
|
|
|
|
+
|
|
|
|
|
+ ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
|
|
|
|
|
+ ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
|
|
|
|
|
+
|
|
|
|
|
+ # Turn off all the ticks
|
|
|
|
|
+ for t in ax.xaxis.get_major_ticks():
|
|
|
|
|
+ t.tick1On = False
|
|
|
|
|
+ t.tick2On = False
|
|
|
|
|
+ for t in ax.yaxis.get_major_ticks():
|
|
|
|
|
+ t.tick1On = False
|
|
|
|
|
+ t.tick2On = False
|
|
|
|
|
+
|
|
|
|
|
+ # face colors list
|
|
|
|
|
+ quadmesh = ax.findobj(QuadMesh)[0]
|
|
|
|
|
+ facecolors = quadmesh.get_facecolors()
|
|
|
|
|
+
|
|
|
|
|
+ # iter in text elements
|
|
|
|
|
+ array_df = np.array(df_cm.to_records(index=False).tolist())
|
|
|
|
|
+ text_add = []
|
|
|
|
|
+ text_del = []
|
|
|
|
|
+ posi = -1 # from left to right, bottom to top.
|
|
|
|
|
+ for t in ax.collections[0].axes.texts: # ax.texts:
|
|
|
|
|
+ pos = np.array(t.get_position()) - [0.5, 0.5]
|
|
|
|
|
+ lin = int(pos[1])
|
|
|
|
|
+ col = int(pos[0])
|
|
|
|
|
+ posi += 1
|
|
|
|
|
+
|
|
|
|
|
+ # set text
|
|
|
|
|
+ txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
|
|
|
|
|
+
|
|
|
|
|
+ text_add.extend(txt_res[0])
|
|
|
|
|
+ text_del.extend(txt_res[1])
|
|
|
|
|
+
|
|
|
|
|
+ # remove the old ones
|
|
|
|
|
+ for item in text_del:
|
|
|
|
|
+ item.remove()
|
|
|
|
|
+ # append the new ones
|
|
|
|
|
+ for item in text_add:
|
|
|
|
|
+ ax.text(item["x"], item["y"], item["text"], **item["kw"])
|
|
|
|
|
+
|
|
|
|
|
+ # titles and legends
|
|
|
|
|
+ ax.set_title("Confusion matrix")
|
|
|
|
|
+ ax.set_xlabel(xlbl)
|
|
|
|
|
+ ax.set_ylabel(ylbl)
|
|
|
|
|
+ # set layout slim
|
|
|
|
|
+ plt.tight_layout()
|
|
|
|
|
+ if show:
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+
|
|
|
|
|
+ return plt.gcf()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ConfusionMatrix:
|
|
|
|
|
+ def __init__(self, thrs_config: dict, class_names: dict, iou_thr=0.5,):
|
|
|
|
|
+ """
|
|
|
|
|
+ Class to create and dislpay confusion matrix for MaskRCNN
|
|
|
|
|
+ or any type of instance segmentation and detection architectures
|
|
|
|
|
+
|
|
|
|
|
+ Parameters
|
|
|
|
|
+ ----------
|
|
|
|
|
+ - iou_thr: IOU threshold
|
|
|
|
|
+ - thrs_config: dict of thresholds for every class
|
|
|
|
|
+ - class_names: dict of class names accordingly to class numbers
|
|
|
|
|
+
|
|
|
|
|
+ Attributes
|
|
|
|
|
+ ----------
|
|
|
|
|
+ - box_matrix: confusion matrix that contains results from boxes
|
|
|
|
|
+ - mask_matrix: confusion matrix that contains results from masks
|
|
|
|
|
+ - display_labels: labels according to classes + miss
|
|
|
|
|
+ - figure_: contains last plot of confusion matrix
|
|
|
|
|
+
|
|
|
|
|
+ Examples
|
|
|
|
|
+ --------
|
|
|
|
|
+ >>> from confusion_matrix import MaskRcnnConfusionMatrix
|
|
|
|
|
+
|
|
|
|
|
+ >>> confusion_matrix = MaskRcnnConfusionMatrix(class_names={0: 'class1', 1: 'class2'},
|
|
|
|
|
+ ... thrs_config={0: 0.5, 1: 0.5})
|
|
|
|
|
+ >>> for images, targets in test_dataloader:
|
|
|
|
|
+ >>> outputs = model(images)
|
|
|
|
|
+ >>> confusion_matrix.update(outputs, targets)
|
|
|
|
|
+
|
|
|
|
|
+ >>> confusion_matrix.plot(show=True)
|
|
|
|
|
+
|
|
|
|
|
+ """
|
|
|
|
|
+ self.iou_thr = iou_thr
|
|
|
|
|
+ self.num_classes = len(thrs_config)
|
|
|
|
|
+ self.classes = class_names.values()
|
|
|
|
|
+ self.box_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
|
|
|
|
|
+ self.mask_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
|
|
|
|
|
+ self.thrs_config = thrs_config
|
|
|
|
|
+ self.display_labels = list(self.classes) + ["Miss"]
|
|
|
|
|
+
|
|
|
|
|
+ def update(self, predictions: dict, targets: dict, after_nms=False):
|
|
|
|
|
+ """
|
|
|
|
|
+ Update confusion matrix for masks and boxes.
|
|
|
|
|
+ It is not very performative and effective implementations.
|
|
|
|
|
+ Needs to be rewritten in vectorize style, cause currently it's loops
|
|
|
|
|
+ Arguments:
|
|
|
|
|
+ ---------
|
|
|
|
|
+ predictions: dict of prediction from mask_rcnn
|
|
|
|
|
+ targets: dict of targets from dataloader
|
|
|
|
|
+ after_nms: it is after nms already or it needs to be thresholded here
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ -------
|
|
|
|
|
+ None, updates confusion matrix accordingly
|
|
|
|
|
+ """
|
|
|
|
|
+ if isinstance(targets["labels"], torch.Tensor):
|
|
|
|
|
+ targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
|
|
|
|
|
+ l_classes = targets["labels"]
|
|
|
|
|
+ l_bboxs = targets["boxes"]
|
|
|
|
|
+ l_masks = targets["masks"]
|
|
|
|
|
+ d_confs = predictions["scores"]
|
|
|
|
|
+ d_bboxs = predictions["boxes"]
|
|
|
|
|
+ d_masks = predictions["masks"]
|
|
|
|
|
+ d_classes = predictions["labels"]
|
|
|
|
|
+ if not after_nms:
|
|
|
|
|
+ box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in d_classes]
|
|
|
|
|
+ mask_thrs = [self.thrs_config[label_id]["mask_thr"] for label_id in d_classes]
|
|
|
|
|
+ ids = np.where(d_confs > box_thrs)[0]
|
|
|
|
|
+ d_classes = d_classes[ids]
|
|
|
|
|
+ d_bboxs = d_bboxs[ids]
|
|
|
|
|
+ d_masks = d_masks[ids]
|
|
|
|
|
+ box_labels_detected = np.zeros(len(l_classes))
|
|
|
|
|
+ mask_labels_detected = np.zeros(len(l_classes))
|
|
|
|
|
+ box_detections_matched = np.zeros(len(d_classes))
|
|
|
|
|
+ mask_detections_matched = np.zeros(len(d_classes))
|
|
|
|
|
+ for l_idx, (l_class, l_bbox, l_mask) in enumerate(zip(l_classes, l_bboxs, l_masks)):
|
|
|
|
|
+ for d_idx, (d_class, d_bbox, d_mask) in enumerate(zip(d_classes, d_bboxs, d_masks)):
|
|
|
|
|
+ box_iou = self.box_pairwise_iou(l_bbox, d_bbox)
|
|
|
|
|
+ mask_iou = self.mask_iou((l_mask, l_class), (d_mask, d_class))
|
|
|
|
|
+ if box_iou >= self.iou_thr:
|
|
|
|
|
+ self.box_matrix[l_class, d_class] += 1
|
|
|
|
|
+ box_labels_detected[l_idx] = 1
|
|
|
|
|
+ box_detections_matched[d_idx] = 1
|
|
|
|
|
+ if mask_iou >= self.iou_thr:
|
|
|
|
|
+ self.mask_matrix[l_class, d_class] += 1
|
|
|
|
|
+ mask_labels_detected[l_idx] = 1
|
|
|
|
|
+ mask_detections_matched[d_idx] = 1
|
|
|
|
|
+ for i in np.where(box_labels_detected == 0)[0]:
|
|
|
|
|
+ self.box_matrix[l_classes[i], -1] += 1
|
|
|
|
|
+ for i in np.where(box_detections_matched == 0)[0]:
|
|
|
|
|
+ self.box_matrix[-1, d_classes[i]] += 1
|
|
|
|
|
+ for i in np.where(mask_labels_detected == 0)[0]:
|
|
|
|
|
+ self.mask_matrix[l_classes[i], -1] += 1
|
|
|
|
|
+ for i in np.where(mask_detections_matched == 0)[0]:
|
|
|
|
|
+ self.mask_matrix[-1, d_classes[i]] += 1
|
|
|
|
|
+
|
|
|
|
|
+ def process_batch(self, predictions: dict, targets: dict, after_nms=True):
|
|
|
|
|
+ """
|
|
|
|
|
+ Process batch of predictons and targets from model and dataloader
|
|
|
|
|
+ to update confusion matrix.
|
|
|
|
|
+ This is supposed to be effective vectorized implementations. Half of that have done, but not masks.
|
|
|
|
|
+ It means that this implementation only for boxes confusion matrix
|
|
|
|
|
+
|
|
|
|
|
+ Arguments:
|
|
|
|
|
+ predictions: dict of prediction from mask_rcnn
|
|
|
|
|
+ targets: dict of targets from dataloader
|
|
|
|
|
+ after_nms: it is after nms already or it needs to be thresholded here
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ None, updates confusion matrix accordingly
|
|
|
|
|
+ """
|
|
|
|
|
+ if isinstance(targets["labels"], torch.Tensor):
|
|
|
|
|
+ targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
|
|
|
|
|
+
|
|
|
|
|
+ gt_classes = targets["labels"]
|
|
|
|
|
+ box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in predictions["labels"]]
|
|
|
|
|
+ try:
|
|
|
|
|
+ prediction_indexes = np.where(predictions["scores"] > box_thrs)[0]
|
|
|
|
|
+ prediction_classes = predictions["labels"][prediction_indexes]
|
|
|
|
|
+ except IndexError or TypeError as e:
|
|
|
|
|
+ # detections are empty, end of process
|
|
|
|
|
+ print("Какая то хуйня произошла!")
|
|
|
|
|
+ raise e
|
|
|
|
|
+ if len(prediction_classes) == 0 and len(gt_classes) > 0:
|
|
|
|
|
+ for gt_class in gt_classes:
|
|
|
|
|
+ self.box_matrix[self.num_classes, gt_class] += 1
|
|
|
|
|
+ return
|
|
|
|
|
+ elif len(prediction_classes) == 0 and len(gt_classes) == 0:
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ all_ious = self.box_pairwise_iou(targets["boxes"], predictions["boxes"])
|
|
|
|
|
+ want_idx = np.where(all_ious > self.iou_thr)
|
|
|
|
|
+
|
|
|
|
|
+ all_matches = [
|
|
|
|
|
+ [want_idx[0][i], want_idx[1][i], all_ious[want_idx[0][i], want_idx[1][i]]]
|
|
|
|
|
+ for i in range(want_idx[0].shape[0])
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ all_matches = np.array(all_matches)
|
|
|
|
|
+ if all_matches.shape[0] > 0: # if there is match
|
|
|
|
|
+ all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
|
|
|
|
|
+ all_matches = all_matches[np.unique(all_matches[:, 1], return_index=True)[1]]
|
|
|
|
|
+ all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
|
|
|
|
|
+ all_matches = all_matches[np.unique(all_matches[:, 0], return_index=True)[1]]
|
|
|
|
|
+
|
|
|
|
|
+ for i, gt_class in enumerate(gt_classes):
|
|
|
|
|
+ if all_matches.shape[0] > 0 and all_matches[all_matches[:, 0] == i].shape[0] == 1:
|
|
|
|
|
+ detection_class = prediction_classes[int(all_matches[all_matches[:, 0] == i, 1][0])]
|
|
|
|
|
+ self.box_matrix[detection_class, gt_class] += 1
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.box_matrix[self.num_classes, gt_class] += 1
|
|
|
|
|
+
|
|
|
|
|
+ for i, detection_class in enumerate(prediction_classes):
|
|
|
|
|
+ if not all_matches.shape[0] or (all_matches.shape[0] and all_matches[all_matches[:, 1] == i].shape[0] == 0):
|
|
|
|
|
+ detection_class = prediction_classes[i]
|
|
|
|
|
+ self.box_matrix[detection_class, self.num_classes] += 1
|
|
|
|
|
+
|
|
|
|
|
+ def box_pairwise_iou(self, boxes1: NDArray[np.float32], boxes2: NDArray[np.float32]) -> NDArray[np.float32]:
|
|
|
|
|
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
|
|
|
|
+ """
|
|
|
|
|
+ Return intersection-over-union (Jaccard index) of boxes.
|
|
|
|
|
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
|
|
|
|
+ Arguments:
|
|
|
|
|
+ boxes1 (Array[N, 4])
|
|
|
|
|
+ boxes2 (Array[M, 4])
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ iou (Array[N, M]): the NxM matrix containing the pairwise
|
|
|
|
|
+ IoU values for every element in boxes1 and boxes2
|
|
|
|
|
+ This implementation is taken from the above link and changed so that it only uses numpy..
|
|
|
|
|
+ """
|
|
|
|
|
+ if len(boxes1.shape) < 2:
|
|
|
|
|
+ boxes1 = boxes1.reshape(1, -1)
|
|
|
|
|
+ if len(boxes2.shape) < 2:
|
|
|
|
|
+ boxes2 = boxes2.reshape(1, -1)
|
|
|
|
|
+
|
|
|
|
|
+ def box_area(box):
|
|
|
|
|
+ # box = 4xn
|
|
|
|
|
+ return (box[2] - box[0]) * (box[3] - box[1])
|
|
|
|
|
+
|
|
|
|
|
+ area1 = box_area(boxes1.T)
|
|
|
|
|
+ area2 = box_area(boxes2.T)
|
|
|
|
|
+
|
|
|
|
|
+ lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
|
|
|
|
+ rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
|
|
|
|
+
|
|
|
|
|
+ inter = np.prod(np.clip(rb - lt, a_min=0, a_max=None), 2) # type: ignore
|
|
|
|
|
+ return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
|
|
|
|
+
|
|
|
|
|
+ def mask_iou(self, mask_and_label1: tuple[NDArray, NDArray], mask_and_label2: tuple[NDArray, NDArray]):
|
|
|
|
|
+ """
|
|
|
|
|
+ Return intersection-over-union (Jaccard index) of masks.
|
|
|
|
|
+ Masks should be pixel arrays
|
|
|
|
|
+ Arguments:
|
|
|
|
|
+ two tuples of mask and label
|
|
|
|
|
+ """
|
|
|
|
|
+ mask1, label1 = mask_and_label1
|
|
|
|
|
+ mask2, label2 = mask_and_label2
|
|
|
|
|
+ thrs1 = self.thrs_config[label1]["mask_thr"]
|
|
|
|
|
+ thrs2 = self.thrs_config[label2]["mask_thr"]
|
|
|
|
|
+ mask1_area = np.count_nonzero(mask1 >= thrs1)
|
|
|
|
|
+ mask2_area = np.count_nonzero(mask2 >= thrs2)
|
|
|
|
|
+ intersection = np.count_nonzero(np.logical_and(mask1, mask2))
|
|
|
|
|
+ iou = intersection / (mask1_area + mask2_area - intersection)
|
|
|
|
|
+ return iou
|
|
|
|
|
+
|
|
|
|
|
+ def mask_pairwise_iou(self, masks1: np.ndarray, masks2: np.ndarray, labels1: np.ndarray, labels2: np.ndarray):
|
|
|
|
|
+
|
|
|
|
|
+ """Need to have been imnplemented eventually and tested"""
|
|
|
|
|
+
|
|
|
|
|
+ f1 = np.array(zip(masks1, labels1))
|
|
|
|
|
+ f2 = np.array(zip(masks2, labels2))
|
|
|
|
|
+ return cdist(f1, f2, metric=self.mask_iou) # TODO: implement it finally!
|
|
|
|
|
+
|
|
|
|
|
+ def return_matrix(self):
|
|
|
|
|
+ """Returns tuple of box and mask confusion matrix."""
|
|
|
|
|
+ return self.box_matrix, self.mask_matrix
|
|
|
|
|
+
|
|
|
|
|
+ def get_matrix_figure(self, type="box", pretty=True):
|
|
|
|
|
+ """
|
|
|
|
|
+ Returns figure of confusion matrix of either box or mask type
|
|
|
|
|
+
|
|
|
|
|
+ Parameters
|
|
|
|
|
+ ----------
|
|
|
|
|
+ type: str, either box or mask, default `box`
|
|
|
|
|
+ pretty: bool, default `True`
|
|
|
|
|
+ plot pretty, featurize confusion matrix or just regular
|
|
|
|
|
+ """
|
|
|
|
|
+ if type == "box":
|
|
|
|
|
+ if pretty:
|
|
|
|
|
+ return pp_matrix(
|
|
|
|
|
+ self.box_matrix,
|
|
|
|
|
+ figsize=(14, 14),
|
|
|
|
|
+ rotation=False,
|
|
|
|
|
+ display_labels=self.display_labels,
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ return self.plot(figsize=(10, 10), type_matrix="boxes")
|
|
|
|
|
+ else:
|
|
|
|
|
+ if pretty:
|
|
|
|
|
+ return pp_matrix(
|
|
|
|
|
+ self.mask_matrix,
|
|
|
|
|
+ figsize=(14, 14),
|
|
|
|
|
+ rotation=False,
|
|
|
|
|
+ display_labels=self.display_labels,
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ return self.plot(figsize=(10, 10), type_matrix="masks")
|
|
|
|
|
+
|
|
|
|
|
+ def print_matrix(self):
|
|
|
|
|
+ for i in range(self.num_classes + 1):
|
|
|
|
|
+ print(" ".join(map(str, self.box_matrix[i])))
|
|
|
|
|
+
|
|
|
|
|
+ def pretty_plot(
|
|
|
|
|
+ self,
|
|
|
|
|
+ type="box",
|
|
|
|
|
+ figsize=(14, 14),
|
|
|
|
|
+ rotation=False,
|
|
|
|
|
+ cmap="viridis",
|
|
|
|
|
+ ) -> Figure:
|
|
|
|
|
+ """Plot feature rich confusion matrix.
|
|
|
|
|
+ """
|
|
|
|
|
+ if type=="box":
|
|
|
|
|
+ return pp_matrix(
|
|
|
|
|
+ self.box_matrix,
|
|
|
|
|
+ figsize=figsize,
|
|
|
|
|
+ rotation=rotation,
|
|
|
|
|
+ display_labels=self.display_labels,
|
|
|
|
|
+ cmap=cmap,
|
|
|
|
|
+ show=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ return pp_matrix(
|
|
|
|
|
+ self.mask_matrix,
|
|
|
|
|
+ figsize=figsize,
|
|
|
|
|
+ rotation=rotation,
|
|
|
|
|
+ display_labels=self.display_labels,
|
|
|
|
|
+ cmap=cmap,
|
|
|
|
|
+ show=True,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def plot(
|
|
|
|
|
+ self,
|
|
|
|
|
+ include_values=True,
|
|
|
|
|
+ cmap="viridis",
|
|
|
|
|
+ xticks_rotation="vertical",
|
|
|
|
|
+ values_format=None,
|
|
|
|
|
+ ax=None,
|
|
|
|
|
+ colorbar=False,
|
|
|
|
|
+ type_matrix="boxes",
|
|
|
|
|
+ figsize=(9, 9),
|
|
|
|
|
+ ) -> Figure:
|
|
|
|
|
+ """Plot visualization of confusion matrix.
|
|
|
|
|
+
|
|
|
|
|
+ Parameters
|
|
|
|
|
+ ----------
|
|
|
|
|
+ include_values : bool, default=True
|
|
|
|
|
+ Includes values in confusion matrix.
|
|
|
|
|
+
|
|
|
|
|
+ cmap : str or matplotlib Colormap, default='viridis'
|
|
|
|
|
+ Colormap recognized by matplotlib.
|
|
|
|
|
+
|
|
|
|
|
+ xticks_rotation : {'vertical', 'horizontal'} or float, \
|
|
|
|
|
+ default='horizontal'
|
|
|
|
|
+ Rotation of xtick labels.
|
|
|
|
|
+
|
|
|
|
|
+ values_format : str, default=None
|
|
|
|
|
+ Format specification for values in confusion matrix. If `None`,
|
|
|
|
|
+ the format specification is 'd' or '.2g' whichever is shorter.
|
|
|
|
|
+
|
|
|
|
|
+ ax : matplotlib axes, default=None
|
|
|
|
|
+ Axes object to plot on. If `None`, a new figure and axes is
|
|
|
|
|
+ created.
|
|
|
|
|
+
|
|
|
|
|
+ colorbar : bool, default=True
|
|
|
|
|
+ Whether or not to add a colorbar to the plot.
|
|
|
|
|
+
|
|
|
|
|
+ figsize : tuple, default (9,9)
|
|
|
|
|
+ Size of figure.
|
|
|
|
|
+
|
|
|
|
|
+ type_matrix : str, ether box or mask
|
|
|
|
|
+ Type of matrix that need to plot.
|
|
|
|
|
+
|
|
|
|
|
+ Returns
|
|
|
|
|
+ -------
|
|
|
|
|
+ display : :firuge:`plt.figure`
|
|
|
|
|
+ """
|
|
|
|
|
+ if ax is None:
|
|
|
|
|
+ fig, ax = plt.subplots(figsize=figsize)
|
|
|
|
|
+ else:
|
|
|
|
|
+ fig = ax.figure
|
|
|
|
|
+
|
|
|
|
|
+ cm = self.box_matrix if type_matrix == "boxes" else self.mask_matrix
|
|
|
|
|
+ n_classes = cm.shape[0]
|
|
|
|
|
+ self.im_ = ax.imshow(cm, interpolation="nearest", cmap=cmap)
|
|
|
|
|
+ self.text_ = None
|
|
|
|
|
+ cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
|
|
|
|
|
+
|
|
|
|
|
+ if include_values:
|
|
|
|
|
+ self.text_ = np.empty_like(cm, dtype=object)
|
|
|
|
|
+
|
|
|
|
|
+ # print text with appropriate color depending on background
|
|
|
|
|
+ thresh = (cm.max() + cm.min()) / 2.0
|
|
|
|
|
+
|
|
|
|
|
+ for i, j in product(range(n_classes), range(n_classes)):
|
|
|
|
|
+ color = cmap_max if cm[i, j] < thresh else cmap_min
|
|
|
|
|
+
|
|
|
|
|
+ if values_format is None:
|
|
|
|
|
+ text_cm = format(cm[i, j], ".2g")
|
|
|
|
|
+ if cm.dtype.kind != "f":
|
|
|
|
|
+ text_d = format(cm[i, j], "d")
|
|
|
|
|
+ if len(text_d) < len(text_cm):
|
|
|
|
|
+ text_cm = text_d
|
|
|
|
|
+ else:
|
|
|
|
|
+ text_cm = format(cm[i, j], values_format)
|
|
|
|
|
+
|
|
|
|
|
+ self.text_[i, j] = ax.text(j, i, text_cm, ha="center", va="center", color=color)
|
|
|
|
|
+
|
|
|
|
|
+ if self.display_labels is None:
|
|
|
|
|
+ display_labels = np.arange(n_classes)
|
|
|
|
|
+ else:
|
|
|
|
|
+ display_labels = self.display_labels
|
|
|
|
|
+ if colorbar:
|
|
|
|
|
+ fig.colorbar(self.im_, ax=ax)
|
|
|
|
|
+ ax.set(
|
|
|
|
|
+ xticks=np.arange(n_classes),
|
|
|
|
|
+ yticks=np.arange(n_classes),
|
|
|
|
|
+ xticklabels=display_labels,
|
|
|
|
|
+ yticklabels=display_labels,
|
|
|
|
|
+ ylabel="True label",
|
|
|
|
|
+ xlabel="Predicted label",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ ax.set_ylim((n_classes - 0.5, -0.5))
|
|
|
|
|
+ plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
|
|
|
|
|
+ plt.tight_layout()
|
|
|
|
|
+ plt.grid(False)
|
|
|
|
|
+
|
|
|
|
|
+ self.figure_ = fig
|
|
|
|
|
+ self.ax_ = ax
|
|
|
|
|
+ return fig
|