confusion_matrix.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. import torch
  2. import numpy as np
  3. import pandas as pd
  4. import seaborn as sn
  5. import matplotlib.pyplot as plt
  6. import matplotlib.font_manager as fm
  7. from itertools import product
  8. from matplotlib.collections import QuadMesh
  9. from scipy.spatial.distance import cdist
  10. from matplotlib.figure import Figure
  11. from matplotlib.text import Text
  12. from numpy.typing import NDArray, ArrayLike
  13. Color = tuple[float, float, float, float]
  14. class ConfusionMatrix:
  15. def __init__(self, thrs_config: dict, class_names: dict, iou_thr=0.5,):
  16. """
  17. Class to create and dislpay confusion matrix for MaskRCNN
  18. or any type of instance segmentation and detection architectures
  19. Parameters
  20. ----------
  21. - iou_thr: IOU threshold
  22. - thrs_config: dict of thresholds for every class
  23. - class_names: dict of class names accordingly to class numbers
  24. Attributes
  25. ----------
  26. - box_matrix: confusion matrix that contains results from boxes
  27. - mask_matrix: confusion matrix that contains results from masks
  28. - display_labels: labels according to classes + miss
  29. - figure_: contains last plot of confusion matrix
  30. Examples
  31. --------
  32. >>> from confusion_matrix import MaskRcnnConfusionMatrix
  33. >>> confusion_matrix = MaskRcnnConfusionMatrix(class_names={0: 'class1', 1: 'class2'},
  34. ... thrs_config={0: 0.5, 1: 0.5})
  35. >>> for images, targets in test_dataloader:
  36. >>> outputs = model(images)
  37. >>> confusion_matrix.update(outputs, targets)
  38. >>> confusion_matrix.plot(show=True)
  39. """
  40. self.iou_thr = iou_thr
  41. self.num_classes = len(thrs_config)
  42. self.classes = class_names.values()
  43. self.box_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
  44. self.mask_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
  45. self.thrs_config = thrs_config
  46. self.display_labels = list(self.classes) + ["Miss"]
  47. def update(self, predictions: dict, targets: dict, after_nms=False):
  48. """
  49. Update confusion matrix for masks and boxes.
  50. It is not very performative and effective implementations.
  51. Needs to be rewritten in vectorize style, cause currently it's loops
  52. Arguments:
  53. ---------
  54. predictions: dict of prediction from mask_rcnn
  55. targets: dict of targets from dataloader
  56. after_nms: it is after nms already or it needs to be thresholded here
  57. Returns:
  58. -------
  59. None, updates confusion matrix accordingly
  60. """
  61. if isinstance(targets["labels"], torch.Tensor):
  62. targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
  63. l_classes = targets["labels"]
  64. l_bboxs = targets["boxes"]
  65. l_masks = targets["masks"]
  66. d_confs = predictions["scores"]
  67. d_bboxs = predictions["boxes"]
  68. d_masks = predictions["masks"]
  69. d_classes = predictions["labels"]
  70. if not after_nms:
  71. box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in d_classes]
  72. mask_thrs = [self.thrs_config[label_id]["mask_thr"] for label_id in d_classes]
  73. ids = np.where(d_confs > box_thrs)[0]
  74. d_classes = d_classes[ids]
  75. d_bboxs = d_bboxs[ids]
  76. d_masks = d_masks[ids]
  77. box_labels_detected = np.zeros(len(l_classes))
  78. mask_labels_detected = np.zeros(len(l_classes))
  79. box_detections_matched = np.zeros(len(d_classes))
  80. mask_detections_matched = np.zeros(len(d_classes))
  81. for l_idx, (l_class, l_bbox, l_mask) in enumerate(zip(l_classes, l_bboxs, l_masks)):
  82. for d_idx, (d_class, d_bbox, d_mask) in enumerate(zip(d_classes, d_bboxs, d_masks)):
  83. box_iou = self.box_pairwise_iou(l_bbox, d_bbox)
  84. mask_iou = self.mask_iou((l_mask, l_class), (d_mask, d_class))
  85. if box_iou >= self.iou_thr:
  86. self.box_matrix[l_class, d_class] += 1
  87. box_labels_detected[l_idx] = 1
  88. box_detections_matched[d_idx] = 1
  89. if mask_iou >= self.iou_thr:
  90. self.mask_matrix[l_class, d_class] += 1
  91. mask_labels_detected[l_idx] = 1
  92. mask_detections_matched[d_idx] = 1
  93. for i in np.where(box_labels_detected == 0)[0]:
  94. self.box_matrix[l_classes[i], -1] += 1
  95. for i in np.where(box_detections_matched == 0)[0]:
  96. self.box_matrix[-1, d_classes[i]] += 1
  97. for i in np.where(mask_labels_detected == 0)[0]:
  98. self.mask_matrix[l_classes[i], -1] += 1
  99. for i in np.where(mask_detections_matched == 0)[0]:
  100. self.mask_matrix[-1, d_classes[i]] += 1
  101. def process_batch(self, predictions: dict, targets: dict, after_nms=True):
  102. """
  103. Process batch of predictons and targets from model and dataloader
  104. to update confusion matrix.
  105. This is supposed to be effective vectorized implementations. Half of that have done, but not masks.
  106. It means that this implementation only for boxes confusion matrix
  107. Arguments:
  108. predictions: dict of prediction from mask_rcnn
  109. targets: dict of targets from dataloader
  110. after_nms: it is after nms already or it needs to be thresholded here
  111. Returns:
  112. None, updates confusion matrix accordingly
  113. """
  114. if isinstance(targets["labels"], torch.Tensor):
  115. targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
  116. gt_classes = targets["labels"]
  117. box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in predictions["labels"]]
  118. try:
  119. prediction_indexes = np.where(predictions["scores"] > box_thrs)[0]
  120. prediction_classes = predictions["labels"][prediction_indexes]
  121. except IndexError or TypeError as e:
  122. # detections are empty, end of process
  123. print("Какая то хуйня произошла!")
  124. raise e
  125. if len(prediction_classes) == 0 and len(gt_classes) > 0:
  126. for gt_class in gt_classes:
  127. self.box_matrix[self.num_classes, gt_class] += 1
  128. return
  129. elif len(prediction_classes) == 0 and len(gt_classes) == 0:
  130. return
  131. all_ious = self.box_pairwise_iou(targets["boxes"], predictions["boxes"])
  132. want_idx = np.where(all_ious > self.iou_thr)
  133. all_matches = [
  134. [want_idx[0][i], want_idx[1][i], all_ious[want_idx[0][i], want_idx[1][i]]]
  135. for i in range(want_idx[0].shape[0])
  136. ]
  137. all_matches = np.array(all_matches)
  138. if all_matches.shape[0] > 0: # if there is match
  139. all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
  140. all_matches = all_matches[np.unique(all_matches[:, 1], return_index=True)[1]]
  141. all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
  142. all_matches = all_matches[np.unique(all_matches[:, 0], return_index=True)[1]]
  143. for i, gt_class in enumerate(gt_classes):
  144. if all_matches.shape[0] > 0 and all_matches[all_matches[:, 0] == i].shape[0] == 1:
  145. detection_class = prediction_classes[int(all_matches[all_matches[:, 0] == i, 1][0])]
  146. self.box_matrix[detection_class, gt_class] += 1
  147. else:
  148. self.box_matrix[self.num_classes, gt_class] += 1
  149. for i, detection_class in enumerate(prediction_classes):
  150. if not all_matches.shape[0] or (all_matches.shape[0] and all_matches[all_matches[:, 1] == i].shape[0] == 0):
  151. detection_class = prediction_classes[i]
  152. self.box_matrix[detection_class, self.num_classes] += 1
  153. def box_pairwise_iou(self, boxes1: NDArray[np.float32], boxes2: NDArray[np.float32]) -> NDArray[np.float32]:
  154. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  155. """
  156. Return intersection-over-union (Jaccard index) of boxes.
  157. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  158. Arguments:
  159. boxes1 (Array[N, 4])
  160. boxes2 (Array[M, 4])
  161. Returns:
  162. iou (Array[N, M]): the NxM matrix containing the pairwise
  163. IoU values for every element in boxes1 and boxes2
  164. This implementation is taken from the above link and changed so that it only uses numpy..
  165. """
  166. if len(boxes1.shape) < 2:
  167. boxes1 = boxes1.reshape(1, -1)
  168. if len(boxes2.shape) < 2:
  169. boxes2 = boxes2.reshape(1, -1)
  170. def box_area(box):
  171. # box = 4xn
  172. return (box[2] - box[0]) * (box[3] - box[1])
  173. area1 = box_area(boxes1.T)
  174. area2 = box_area(boxes2.T)
  175. lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  176. rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  177. inter = np.prod(np.clip(rb - lt, a_min=0, a_max=None), 2) # type: ignore
  178. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  179. def mask_iou(self, mask_and_label1: tuple[NDArray, NDArray], mask_and_label2: tuple[NDArray, NDArray]):
  180. """
  181. Return intersection-over-union (Jaccard index) of masks.
  182. Masks should be pixel arrays
  183. Arguments:
  184. two tuples of mask and label
  185. """
  186. mask1, label1 = mask_and_label1
  187. mask2, label2 = mask_and_label2
  188. thrs1 = self.thrs_config[label1]["mask_thr"]
  189. thrs2 = self.thrs_config[label2]["mask_thr"]
  190. mask1_area = np.count_nonzero(mask1 >= thrs1)
  191. mask2_area = np.count_nonzero(mask2 >= thrs2)
  192. intersection = np.count_nonzero(np.logical_and(mask1, mask2))
  193. iou = intersection / (mask1_area + mask2_area - intersection)
  194. return iou
  195. def mask_pairwise_iou(self, masks1: np.ndarray, masks2: np.ndarray, labels1: np.ndarray, labels2: np.ndarray):
  196. # TODO: implement it finally!
  197. """Need to have been imnplemented eventually and tested"""
  198. f1 = np.array(zip(masks1, labels1))
  199. f2 = np.array(zip(masks2, labels2))
  200. return cdist(f1, f2, metric=self.mask_iou) # type: ignore
  201. def return_matrix(self):
  202. """Returns tuple of box and mask confusion matrix."""
  203. return self.box_matrix, self.mask_matrix
  204. def get_matrix_figure(self, type="box", pretty=True):
  205. """
  206. Returns figure of confusion matrix of either box or mask type
  207. Parameters
  208. ----------
  209. type: str, either box or mask, default `box`
  210. pretty: bool, default `True`
  211. plot pretty, featurize confusion matrix or just regular
  212. """
  213. if type == "box":
  214. if pretty:
  215. return pp_matrix(
  216. self.box_matrix,
  217. figsize=(14, 14),
  218. rotation=False,
  219. display_labels=self.display_labels,
  220. )
  221. else:
  222. return self.plot(figsize=(10, 10), type_matrix="boxes")
  223. else:
  224. if pretty:
  225. return pp_matrix(
  226. self.mask_matrix,
  227. figsize=(14, 14),
  228. rotation=False,
  229. display_labels=self.display_labels,
  230. )
  231. else:
  232. return self.plot(figsize=(10, 10), type_matrix="masks")
  233. def print_matrix(self):
  234. for i in range(self.num_classes + 1):
  235. print(" ".join(map(str, self.box_matrix[i])))
  236. def pretty_plot(
  237. self,
  238. type="box",
  239. figsize=(14, 14),
  240. rotation=False,
  241. cmap="viridis",
  242. ) -> Figure:
  243. """Plot feature rich confusion matrix.
  244. """
  245. if type=="box":
  246. return pp_matrix(
  247. self.box_matrix,
  248. figsize=figsize,
  249. rotation=rotation,
  250. display_labels=self.display_labels,
  251. cmap=cmap,
  252. show=True,
  253. )
  254. else:
  255. return pp_matrix(
  256. self.mask_matrix,
  257. figsize=figsize,
  258. rotation=rotation,
  259. display_labels=self.display_labels,
  260. cmap=cmap,
  261. show=True,
  262. )
  263. def plot(
  264. self,
  265. include_values=True,
  266. cmap="viridis",
  267. xticks_rotation="vertical",
  268. values_format=None,
  269. ax=None,
  270. colorbar=False,
  271. type_matrix="boxes",
  272. figsize=(9, 9),
  273. ) -> Figure:
  274. """Plot visualization of confusion matrix.
  275. Parameters
  276. ----------
  277. include_values : bool, default=True
  278. Includes values in confusion matrix.
  279. cmap : str or matplotlib Colormap, default='viridis'
  280. Colormap recognized by matplotlib.
  281. xticks_rotation : {'vertical', 'horizontal'} or float, \
  282. default='horizontal'
  283. Rotation of xtick labels.
  284. values_format : str, default=None
  285. Format specification for values in confusion matrix. If `None`,
  286. the format specification is 'd' or '.2g' whichever is shorter.
  287. ax : matplotlib axes, default=None
  288. Axes object to plot on. If `None`, a new figure and axes is
  289. created.
  290. colorbar : bool, default=True
  291. Whether or not to add a colorbar to the plot.
  292. figsize : tuple, default (9,9)
  293. Size of figure.
  294. type_matrix : str, ether box or mask
  295. Type of matrix that need to plot.
  296. Returns
  297. -------
  298. display : :firuge:`plt.figure`
  299. """
  300. if ax is None:
  301. fig, ax = plt.subplots(figsize=figsize)
  302. else:
  303. fig = ax.figure
  304. cm = self.box_matrix if type_matrix == "boxes" else self.mask_matrix
  305. n_classes = cm.shape[0]
  306. self.im_ = ax.imshow(cm, interpolation="nearest", cmap=cmap)
  307. self.text_ = None
  308. cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
  309. if include_values:
  310. self.text_ = np.empty_like(cm, dtype=object)
  311. # print text with appropriate color depending on background
  312. thresh = (cm.max() + cm.min()) / 2.0
  313. for i, j in product(range(n_classes), range(n_classes)):
  314. color = cmap_max if cm[i, j] < thresh else cmap_min
  315. if values_format is None:
  316. text_cm = format(cm[i, j], ".2g")
  317. if cm.dtype.kind != "f":
  318. text_d = format(cm[i, j], "d")
  319. if len(text_d) < len(text_cm):
  320. text_cm = text_d
  321. else:
  322. text_cm = format(cm[i, j], values_format)
  323. self.text_[i, j] = ax.text(j, i, text_cm, ha="center", va="center", color=color)
  324. if self.display_labels is None:
  325. display_labels = np.arange(n_classes)
  326. else:
  327. display_labels = self.display_labels
  328. if colorbar:
  329. fig.colorbar(self.im_, ax=ax)
  330. ax.set(
  331. xticks=np.arange(n_classes),
  332. yticks=np.arange(n_classes),
  333. xticklabels=display_labels,
  334. yticklabels=display_labels,
  335. ylabel="True label",
  336. xlabel="Predicted label",
  337. )
  338. ax.set_ylim((n_classes - 0.5, -0.5))
  339. plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
  340. plt.tight_layout()
  341. plt.grid(False)
  342. self.figure_ = fig
  343. self.ax_ = ax
  344. return fig
  345. # Helper function to draw pretty figure
  346. # -------------------------------------
  347. # This is main function to draw pretty confusion matrix
  348. def pp_matrix(
  349. df_cm: NDArray[np.float64] | pd.DataFrame,
  350. annot=True,
  351. cmap="viridis",
  352. fmt=".2f",
  353. fz=10,
  354. lw=1,
  355. cbar=False,
  356. figsize=[9, 9],
  357. show_null_values=False,
  358. pred_val_axis="x",
  359. show=False,
  360. rotation=True,
  361. display_labels=None,
  362. ):
  363. """
  364. print conf matrix with default layout (like matlab)
  365. params:
  366. df_cm dataframe (pandas) without totals
  367. annot print text in each cell
  368. cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
  369. fz fontsize
  370. lw linewidth
  371. pred_val_axis where to show the prediction values (x or y axis)
  372. 'col' or 'x': show predicted values in columns (x axis) instead lines
  373. 'lin' or 'y': show predicted values in lines (y axis)
  374. show show the plot or not
  375. rotation rotate or not labels on figure
  376. display_labels None, list of labels that display on figure
  377. """
  378. if not isinstance(df_cm, pd.DataFrame):
  379. df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
  380. if pred_val_axis in ("col", "x"):
  381. xlbl = "Predicted"
  382. ylbl = "Actual"
  383. else:
  384. xlbl = "Actual"
  385. ylbl = "Predicted"
  386. df_cm = df_cm.T
  387. # create "Total" column
  388. insert_totals(df_cm)
  389. # this is for print allways in the same window
  390. fig, ax1 = get_new_fig("Conf matrix default", figsize)
  391. ax = sn.heatmap(
  392. df_cm,
  393. annot=annot,
  394. annot_kws={"size": fz},
  395. linewidths=lw,
  396. ax=ax1,
  397. cbar=cbar,
  398. cmap=cmap,
  399. linecolor="w",
  400. fmt=fmt,
  401. )
  402. # set ticklabels rotation
  403. if rotation:
  404. rotation_x = 45
  405. rotation_y = 25
  406. else:
  407. rotation_x = 0
  408. rotation_y = 90
  409. ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
  410. ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
  411. # Turn off all the ticks
  412. for t in ax.xaxis.get_major_ticks():
  413. t.tick1On = False
  414. t.tick2On = False
  415. for t in ax.yaxis.get_major_ticks():
  416. t.tick1On = False
  417. t.tick2On = False
  418. # face colors list
  419. quadmesh = ax.findobj(QuadMesh)[0]
  420. facecolors = quadmesh.get_facecolors()
  421. # iter in text elements
  422. array_df = np.array(df_cm.to_records(index=False).tolist())
  423. text_add = []
  424. text_del = []
  425. posi = -1 # from left to right, bottom to top.
  426. for t in ax.collections[0].axes.texts: # ax.texts:
  427. pos = np.array(t.get_position()) - [0.5, 0.5]
  428. lin = int(pos[1])
  429. col = int(pos[0])
  430. posi += 1
  431. # set text
  432. txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
  433. text_add.extend(txt_res[0])
  434. text_del.extend(txt_res[1])
  435. # remove the old ones
  436. for item in text_del:
  437. item.remove()
  438. # append the new ones
  439. for item in text_add:
  440. ax.text(item["x"], item["y"], item["text"], **item["kw"])
  441. # titles and legends
  442. ax.set_title("Confusion matrix")
  443. ax.set_xlabel(xlbl)
  444. ax.set_ylabel(ylbl)
  445. # set layout slim
  446. plt.tight_layout()
  447. if show:
  448. plt.show()
  449. return plt.gcf()
  450. # This is helper functions for pp_print function
  451. def get_new_fig(fn, figsize=[9, 9]):
  452. """Init graphics"""
  453. fig1 = plt.figure(fn, figsize)
  454. ax1 = fig1.gca() # Get Current Axis
  455. ax1.cla() # clear existing plot
  456. return fig1, ax1
  457. def configcell_text_and_colors(
  458. array_df: np.ndarray,
  459. lin: int,
  460. col: int,
  461. oText: Text,
  462. facecolors: list[Color],
  463. posi: int,
  464. fz: int,
  465. fmt: str,
  466. show_null_values=False,
  467. ):
  468. """
  469. config cell text and colors
  470. and return text elements to add and to dell
  471. """
  472. text_add = []
  473. text_del = []
  474. cell_val = array_df[lin][col]
  475. tot_all = array_df[-1][-1]
  476. per = (float(cell_val) / tot_all) * 100
  477. curr_column = array_df[:, col]
  478. ccl = len(curr_column)
  479. # last line and/or last column
  480. if (col == (ccl - 1)) or (lin == (ccl - 1)):
  481. # tots and percents
  482. if cell_val != 0:
  483. if (col == ccl - 1) and (lin == ccl - 1):
  484. tot_rig = 0
  485. for i in range(array_df.shape[0] - 1):
  486. tot_rig += array_df[i][i]
  487. per_ok = (float(tot_rig) / cell_val) * 100
  488. elif col == ccl - 1:
  489. tot_rig = array_df[lin][lin]
  490. per_ok = (float(tot_rig) / cell_val) * 100
  491. elif lin == ccl - 1:
  492. tot_rig = array_df[col][col]
  493. per_ok = (float(tot_rig) / cell_val) * 100
  494. per_err = 100 - per_ok # type: ignore
  495. else:
  496. per_ok = per_err = 0
  497. per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100] # type: ignore
  498. # text to DEL
  499. text_del.append(oText)
  500. # text to ADD
  501. font_prop = fm.FontProperties(weight="bold", size=fz)
  502. text_kwargs = dict(
  503. color="w",
  504. ha="center",
  505. va="center",
  506. gid="sum",
  507. fontproperties=font_prop,
  508. )
  509. lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
  510. lis_kwa = [text_kwargs]
  511. dic = text_kwargs.copy()
  512. dic["color"] = "g"
  513. lis_kwa.append(dic)
  514. dic = text_kwargs.copy()
  515. dic["color"] = "r"
  516. lis_kwa.append(dic)
  517. lis_pos = [
  518. (oText._x, oText._y - 0.3),
  519. (oText._x, oText._y),
  520. (oText._x, oText._y + 0.3),
  521. ]
  522. for i in range(len(lis_txt)):
  523. new_text = dict(
  524. x=lis_pos[i][0],
  525. y=lis_pos[i][1],
  526. text=lis_txt[i],
  527. kw=lis_kwa[i],
  528. )
  529. text_add.append(new_text)
  530. # set background color for sum cells (last line and last column)
  531. carr = (0.27, 0.30, 0.27, 1.0)
  532. if (col == ccl - 1) and (lin == ccl - 1):
  533. carr = (0.17, 0.20, 0.17, 1.0)
  534. facecolors[posi] = carr
  535. else:
  536. if per > 0:
  537. txt = "%s\n%.1f%%" % (cell_val, per)
  538. else:
  539. if show_null_values == False:
  540. txt = ""
  541. elif show_null_values == True:
  542. txt = "0"
  543. else:
  544. txt = "0\n0.0%"
  545. oText.set_text(txt)
  546. # main diagonal
  547. if col == lin:
  548. # set color of the textin the diagonal to white
  549. oText.set_color("w")
  550. # set background color in the diagonal to blue
  551. facecolors[posi] = (0.35, 0.8, 0.55, 1.0)
  552. else:
  553. oText.set_color("r")
  554. return text_add, text_del
  555. def insert_totals(df_cm):
  556. """insert total column and line (the last ones)"""
  557. sum_col = []
  558. for c in df_cm.columns:
  559. sum_col.append(df_cm[c].sum())
  560. sum_lin = []
  561. for item_line in df_cm.iterrows():
  562. sum_lin.append(item_line[1].sum())
  563. df_cm["sum_lin"] = sum_lin
  564. sum_col.append(np.sum(sum_lin))
  565. df_cm.loc["sum_col"] = sum_col