confusion_matrix.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. import torch
  2. import numpy as np
  3. import pandas as pd
  4. import seaborn as sn
  5. import matplotlib.pyplot as plt
  6. import matplotlib.font_manager as fm
  7. from itertools import product
  8. from matplotlib.collections import QuadMesh
  9. from scipy.spatial.distance import cdist
  10. from matplotlib.figure import Figure
  11. from numpy.typing import NDArray, ArrayLike
  12. def get_new_fig(fn, figsize=[9, 9]):
  13. """Init graphics"""
  14. fig1 = plt.figure(fn, figsize)
  15. ax1 = fig1.gca() # Get Current Axis
  16. ax1.cla() # clear existing plot
  17. return fig1, ax1
  18. def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=False):
  19. """
  20. config cell text and colors
  21. and return text elements to add and to dell
  22. @TODO: use fmt
  23. """
  24. text_add = []
  25. text_del = []
  26. cell_val = array_df[lin][col]
  27. tot_all = array_df[-1][-1]
  28. per = (float(cell_val) / tot_all) * 100
  29. curr_column = array_df[:, col]
  30. ccl = len(curr_column)
  31. # last line and/or last column
  32. if (col == (ccl - 1)) or (lin == (ccl - 1)):
  33. # tots and percents
  34. if cell_val != 0:
  35. if (col == ccl - 1) and (lin == ccl - 1):
  36. tot_rig = 0
  37. for i in range(array_df.shape[0] - 1):
  38. tot_rig += array_df[i][i]
  39. per_ok = (float(tot_rig) / cell_val) * 100
  40. elif col == ccl - 1:
  41. tot_rig = array_df[lin][lin]
  42. per_ok = (float(tot_rig) / cell_val) * 100
  43. elif lin == ccl - 1:
  44. tot_rig = array_df[col][col]
  45. per_ok = (float(tot_rig) / cell_val) * 100
  46. per_err = 100 - per_ok # type: ignore
  47. else:
  48. per_ok = per_err = 0
  49. per_ok_s = ["%.2f%%" % (per_ok), "100%"][per_ok == 100] # type: ignore
  50. # text to DEL
  51. text_del.append(oText)
  52. # text to ADD
  53. font_prop = fm.FontProperties(weight="bold", size=fz)
  54. text_kwargs = dict(
  55. color="w",
  56. ha="center",
  57. va="center",
  58. gid="sum",
  59. fontproperties=font_prop,
  60. )
  61. lis_txt = ["%d" % (cell_val), per_ok_s, "%.2f%%" % (per_err)]
  62. lis_kwa = [text_kwargs]
  63. dic = text_kwargs.copy()
  64. dic["color"] = "g"
  65. lis_kwa.append(dic)
  66. dic = text_kwargs.copy()
  67. dic["color"] = "r"
  68. lis_kwa.append(dic)
  69. lis_pos = [
  70. (oText._x, oText._y - 0.3),
  71. (oText._x, oText._y),
  72. (oText._x, oText._y + 0.3),
  73. ]
  74. for i in range(len(lis_txt)):
  75. newText = dict(
  76. x=lis_pos[i][0],
  77. y=lis_pos[i][1],
  78. text=lis_txt[i],
  79. kw=lis_kwa[i],
  80. )
  81. text_add.append(newText)
  82. # set background color for sum cells (last line and last column)
  83. carr = [0.27, 0.30, 0.27, 1.0]
  84. if (col == ccl - 1) and (lin == ccl - 1):
  85. carr = [0.17, 0.20, 0.17, 1.0]
  86. facecolors[posi] = carr
  87. else:
  88. if per > 0:
  89. txt = "%s\n%.1f%%" % (cell_val, per)
  90. else:
  91. if show_null_values == False:
  92. txt = ""
  93. elif show_null_values == True:
  94. txt = "0"
  95. else:
  96. txt = "0\n0.0%"
  97. oText.set_text(txt)
  98. # main diagonal
  99. if col == lin:
  100. # set color of the textin the diagonal to white
  101. oText.set_color("w")
  102. # set background color in the diagonal to blue
  103. facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
  104. else:
  105. oText.set_color("r")
  106. return text_add, text_del
  107. def insert_totals(df_cm):
  108. """insert total column and line (the last ones)"""
  109. sum_col = []
  110. for c in df_cm.columns:
  111. sum_col.append(df_cm[c].sum())
  112. sum_lin = []
  113. for item_line in df_cm.iterrows():
  114. sum_lin.append(item_line[1].sum())
  115. df_cm["sum_lin"] = sum_lin
  116. sum_col.append(np.sum(sum_lin))
  117. df_cm.loc["sum_col"] = sum_col
  118. def pp_matrix(
  119. df_cm: NDArray[np.float64] | pd.DataFrame,
  120. annot=True,
  121. cmap="viridis",
  122. fmt=".2f",
  123. fz=10,
  124. lw=1,
  125. cbar=False,
  126. figsize=[9, 9],
  127. show_null_values=False,
  128. pred_val_axis="x",
  129. show=False,
  130. rotation=True,
  131. display_labels=None,
  132. ):
  133. """
  134. print conf matrix with default layout (like matlab)
  135. params:
  136. df_cm dataframe (pandas) without totals
  137. annot print text in each cell
  138. cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:
  139. fz fontsize
  140. lw linewidth
  141. pred_val_axis where to show the prediction values (x or y axis)
  142. 'col' or 'x': show predicted values in columns (x axis) instead lines
  143. 'lin' or 'y': show predicted values in lines (y axis)
  144. show show the plot or not
  145. rotation rotate or not labels on figure
  146. display_labels None, list of labels that display on figure
  147. """
  148. if not isinstance(df_cm, pd.DataFrame):
  149. df_cm = pd.DataFrame(df_cm, index=display_labels, columns=display_labels)
  150. if pred_val_axis in ("col", "x"):
  151. xlbl = "Predicted"
  152. ylbl = "Actual"
  153. else:
  154. xlbl = "Actual"
  155. ylbl = "Predicted"
  156. df_cm = df_cm.T
  157. # create "Total" column
  158. insert_totals(df_cm)
  159. # this is for print allways in the same window
  160. fig, ax1 = get_new_fig("Conf matrix default", figsize)
  161. ax = sn.heatmap(
  162. df_cm,
  163. annot=annot,
  164. annot_kws={"size": fz},
  165. linewidths=lw,
  166. ax=ax1,
  167. cbar=cbar,
  168. cmap=cmap,
  169. linecolor="w",
  170. fmt=fmt,
  171. )
  172. # set ticklabels rotation
  173. if rotation:
  174. rotation_x = 45
  175. rotation_y = 25
  176. else:
  177. rotation_x = 0
  178. rotation_y = 90
  179. ax.set_xticklabels(ax.get_xticklabels(), rotation=rotation_y, fontsize=10)
  180. ax.set_yticklabels(ax.get_yticklabels(), rotation=rotation_x, fontsize=10)
  181. # Turn off all the ticks
  182. for t in ax.xaxis.get_major_ticks():
  183. t.tick1On = False
  184. t.tick2On = False
  185. for t in ax.yaxis.get_major_ticks():
  186. t.tick1On = False
  187. t.tick2On = False
  188. # face colors list
  189. quadmesh = ax.findobj(QuadMesh)[0]
  190. facecolors = quadmesh.get_facecolors()
  191. # iter in text elements
  192. array_df = np.array(df_cm.to_records(index=False).tolist())
  193. text_add = []
  194. text_del = []
  195. posi = -1 # from left to right, bottom to top.
  196. for t in ax.collections[0].axes.texts: # ax.texts:
  197. pos = np.array(t.get_position()) - [0.5, 0.5]
  198. lin = int(pos[1])
  199. col = int(pos[0])
  200. posi += 1
  201. # set text
  202. txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
  203. text_add.extend(txt_res[0])
  204. text_del.extend(txt_res[1])
  205. # remove the old ones
  206. for item in text_del:
  207. item.remove()
  208. # append the new ones
  209. for item in text_add:
  210. ax.text(item["x"], item["y"], item["text"], **item["kw"])
  211. # titles and legends
  212. ax.set_title("Confusion matrix")
  213. ax.set_xlabel(xlbl)
  214. ax.set_ylabel(ylbl)
  215. # set layout slim
  216. plt.tight_layout()
  217. if show:
  218. plt.show()
  219. return plt.gcf()
  220. class ConfusionMatrix:
  221. def __init__(self, thrs_config: dict, class_names: dict, iou_thr=0.5,):
  222. """
  223. Class to create and dislpay confusion matrix for MaskRCNN
  224. or any type of instance segmentation and detection architectures
  225. Parameters
  226. ----------
  227. - iou_thr: IOU threshold
  228. - thrs_config: dict of thresholds for every class
  229. - class_names: dict of class names accordingly to class numbers
  230. Attributes
  231. ----------
  232. - box_matrix: confusion matrix that contains results from boxes
  233. - mask_matrix: confusion matrix that contains results from masks
  234. - display_labels: labels according to classes + miss
  235. - figure_: contains last plot of confusion matrix
  236. Examples
  237. --------
  238. >>> from confusion_matrix import MaskRcnnConfusionMatrix
  239. >>> confusion_matrix = MaskRcnnConfusionMatrix(class_names={0: 'class1', 1: 'class2'},
  240. ... thrs_config={0: 0.5, 1: 0.5})
  241. >>> for images, targets in test_dataloader:
  242. >>> outputs = model(images)
  243. >>> confusion_matrix.update(outputs, targets)
  244. >>> confusion_matrix.plot(show=True)
  245. """
  246. self.iou_thr = iou_thr
  247. self.num_classes = len(thrs_config)
  248. self.classes = class_names.values()
  249. self.box_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
  250. self.mask_matrix = np.zeros((self.num_classes + 1, self.num_classes + 1))
  251. self.thrs_config = thrs_config
  252. self.display_labels = list(self.classes) + ["Miss"]
  253. def update(self, predictions: dict, targets: dict, after_nms=False):
  254. """
  255. Update confusion matrix for masks and boxes.
  256. It is not very performative and effective implementations.
  257. Needs to be rewritten in vectorize style, cause currently it's loops
  258. Arguments:
  259. ---------
  260. predictions: dict of prediction from mask_rcnn
  261. targets: dict of targets from dataloader
  262. after_nms: it is after nms already or it needs to be thresholded here
  263. Returns:
  264. -------
  265. None, updates confusion matrix accordingly
  266. """
  267. if isinstance(targets["labels"], torch.Tensor):
  268. targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
  269. l_classes = targets["labels"]
  270. l_bboxs = targets["boxes"]
  271. l_masks = targets["masks"]
  272. d_confs = predictions["scores"]
  273. d_bboxs = predictions["boxes"]
  274. d_masks = predictions["masks"]
  275. d_classes = predictions["labels"]
  276. if not after_nms:
  277. box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in d_classes]
  278. mask_thrs = [self.thrs_config[label_id]["mask_thr"] for label_id in d_classes]
  279. ids = np.where(d_confs > box_thrs)[0]
  280. d_classes = d_classes[ids]
  281. d_bboxs = d_bboxs[ids]
  282. d_masks = d_masks[ids]
  283. box_labels_detected = np.zeros(len(l_classes))
  284. mask_labels_detected = np.zeros(len(l_classes))
  285. box_detections_matched = np.zeros(len(d_classes))
  286. mask_detections_matched = np.zeros(len(d_classes))
  287. for l_idx, (l_class, l_bbox, l_mask) in enumerate(zip(l_classes, l_bboxs, l_masks)):
  288. for d_idx, (d_class, d_bbox, d_mask) in enumerate(zip(d_classes, d_bboxs, d_masks)):
  289. box_iou = self.box_pairwise_iou(l_bbox, d_bbox)
  290. mask_iou = self.mask_iou((l_mask, l_class), (d_mask, d_class))
  291. if box_iou >= self.iou_thr:
  292. self.box_matrix[l_class, d_class] += 1
  293. box_labels_detected[l_idx] = 1
  294. box_detections_matched[d_idx] = 1
  295. if mask_iou >= self.iou_thr:
  296. self.mask_matrix[l_class, d_class] += 1
  297. mask_labels_detected[l_idx] = 1
  298. mask_detections_matched[d_idx] = 1
  299. for i in np.where(box_labels_detected == 0)[0]:
  300. self.box_matrix[l_classes[i], -1] += 1
  301. for i in np.where(box_detections_matched == 0)[0]:
  302. self.box_matrix[-1, d_classes[i]] += 1
  303. for i in np.where(mask_labels_detected == 0)[0]:
  304. self.mask_matrix[l_classes[i], -1] += 1
  305. for i in np.where(mask_detections_matched == 0)[0]:
  306. self.mask_matrix[-1, d_classes[i]] += 1
  307. def process_batch(self, predictions: dict, targets: dict, after_nms=True):
  308. """
  309. Process batch of predictons and targets from model and dataloader
  310. to update confusion matrix.
  311. This is supposed to be effective vectorized implementations. Half of that have done, but not masks.
  312. It means that this implementation only for boxes confusion matrix
  313. Arguments:
  314. predictions: dict of prediction from mask_rcnn
  315. targets: dict of targets from dataloader
  316. after_nms: it is after nms already or it needs to be thresholded here
  317. Returns:
  318. None, updates confusion matrix accordingly
  319. """
  320. if isinstance(targets["labels"], torch.Tensor):
  321. targets = {k: v.to("cpu").numpy() for k, v in targets.items() if type(v) is not str}
  322. gt_classes = targets["labels"]
  323. box_thrs = [self.thrs_config[label_id]["box_thr"] for label_id in predictions["labels"]]
  324. try:
  325. prediction_indexes = np.where(predictions["scores"] > box_thrs)[0]
  326. prediction_classes = predictions["labels"][prediction_indexes]
  327. except IndexError or TypeError as e:
  328. # detections are empty, end of process
  329. print("Какая то хуйня произошла!")
  330. raise e
  331. if len(prediction_classes) == 0 and len(gt_classes) > 0:
  332. for gt_class in gt_classes:
  333. self.box_matrix[self.num_classes, gt_class] += 1
  334. return
  335. elif len(prediction_classes) == 0 and len(gt_classes) == 0:
  336. return
  337. all_ious = self.box_pairwise_iou(targets["boxes"], predictions["boxes"])
  338. want_idx = np.where(all_ious > self.iou_thr)
  339. all_matches = [
  340. [want_idx[0][i], want_idx[1][i], all_ious[want_idx[0][i], want_idx[1][i]]]
  341. for i in range(want_idx[0].shape[0])
  342. ]
  343. all_matches = np.array(all_matches)
  344. if all_matches.shape[0] > 0: # if there is match
  345. all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
  346. all_matches = all_matches[np.unique(all_matches[:, 1], return_index=True)[1]]
  347. all_matches = all_matches[all_matches[:, 2].argsort()[::-1]]
  348. all_matches = all_matches[np.unique(all_matches[:, 0], return_index=True)[1]]
  349. for i, gt_class in enumerate(gt_classes):
  350. if all_matches.shape[0] > 0 and all_matches[all_matches[:, 0] == i].shape[0] == 1:
  351. detection_class = prediction_classes[int(all_matches[all_matches[:, 0] == i, 1][0])]
  352. self.box_matrix[detection_class, gt_class] += 1
  353. else:
  354. self.box_matrix[self.num_classes, gt_class] += 1
  355. for i, detection_class in enumerate(prediction_classes):
  356. if not all_matches.shape[0] or (all_matches.shape[0] and all_matches[all_matches[:, 1] == i].shape[0] == 0):
  357. detection_class = prediction_classes[i]
  358. self.box_matrix[detection_class, self.num_classes] += 1
  359. def box_pairwise_iou(self, boxes1: NDArray[np.float32], boxes2: NDArray[np.float32]) -> NDArray[np.float32]:
  360. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  361. """
  362. Return intersection-over-union (Jaccard index) of boxes.
  363. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  364. Arguments:
  365. boxes1 (Array[N, 4])
  366. boxes2 (Array[M, 4])
  367. Returns:
  368. iou (Array[N, M]): the NxM matrix containing the pairwise
  369. IoU values for every element in boxes1 and boxes2
  370. This implementation is taken from the above link and changed so that it only uses numpy..
  371. """
  372. if len(boxes1.shape) < 2:
  373. boxes1 = boxes1.reshape(1, -1)
  374. if len(boxes2.shape) < 2:
  375. boxes2 = boxes2.reshape(1, -1)
  376. def box_area(box):
  377. # box = 4xn
  378. return (box[2] - box[0]) * (box[3] - box[1])
  379. area1 = box_area(boxes1.T)
  380. area2 = box_area(boxes2.T)
  381. lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  382. rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  383. inter = np.prod(np.clip(rb - lt, a_min=0, a_max=None), 2) # type: ignore
  384. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  385. def mask_iou(self, mask_and_label1: tuple[NDArray, NDArray], mask_and_label2: tuple[NDArray, NDArray]):
  386. """
  387. Return intersection-over-union (Jaccard index) of masks.
  388. Masks should be pixel arrays
  389. Arguments:
  390. two tuples of mask and label
  391. """
  392. mask1, label1 = mask_and_label1
  393. mask2, label2 = mask_and_label2
  394. thrs1 = self.thrs_config[label1]["mask_thr"]
  395. thrs2 = self.thrs_config[label2]["mask_thr"]
  396. mask1_area = np.count_nonzero(mask1 >= thrs1)
  397. mask2_area = np.count_nonzero(mask2 >= thrs2)
  398. intersection = np.count_nonzero(np.logical_and(mask1, mask2))
  399. iou = intersection / (mask1_area + mask2_area - intersection)
  400. return iou
  401. def mask_pairwise_iou(self, masks1: np.ndarray, masks2: np.ndarray, labels1: np.ndarray, labels2: np.ndarray):
  402. """Need to have been imnplemented eventually and tested"""
  403. f1 = np.array(zip(masks1, labels1))
  404. f2 = np.array(zip(masks2, labels2))
  405. return cdist(f1, f2, metric=self.mask_iou) # TODO: implement it finally!
  406. def return_matrix(self):
  407. """Returns tuple of box and mask confusion matrix."""
  408. return self.box_matrix, self.mask_matrix
  409. def get_matrix_figure(self, type="box", pretty=True):
  410. """
  411. Returns figure of confusion matrix of either box or mask type
  412. Parameters
  413. ----------
  414. type: str, either box or mask, default `box`
  415. pretty: bool, default `True`
  416. plot pretty, featurize confusion matrix or just regular
  417. """
  418. if type == "box":
  419. if pretty:
  420. return pp_matrix(
  421. self.box_matrix,
  422. figsize=(14, 14),
  423. rotation=False,
  424. display_labels=self.display_labels,
  425. )
  426. else:
  427. return self.plot(figsize=(10, 10), type_matrix="boxes")
  428. else:
  429. if pretty:
  430. return pp_matrix(
  431. self.mask_matrix,
  432. figsize=(14, 14),
  433. rotation=False,
  434. display_labels=self.display_labels,
  435. )
  436. else:
  437. return self.plot(figsize=(10, 10), type_matrix="masks")
  438. def print_matrix(self):
  439. for i in range(self.num_classes + 1):
  440. print(" ".join(map(str, self.box_matrix[i])))
  441. def pretty_plot(
  442. self,
  443. type="box",
  444. figsize=(14, 14),
  445. rotation=False,
  446. cmap="viridis",
  447. ) -> Figure:
  448. """Plot feature rich confusion matrix.
  449. """
  450. if type=="box":
  451. return pp_matrix(
  452. self.box_matrix,
  453. figsize=figsize,
  454. rotation=rotation,
  455. display_labels=self.display_labels,
  456. cmap=cmap,
  457. show=True,
  458. )
  459. else:
  460. return pp_matrix(
  461. self.mask_matrix,
  462. figsize=figsize,
  463. rotation=rotation,
  464. display_labels=self.display_labels,
  465. cmap=cmap,
  466. show=True,
  467. )
  468. def plot(
  469. self,
  470. include_values=True,
  471. cmap="viridis",
  472. xticks_rotation="vertical",
  473. values_format=None,
  474. ax=None,
  475. colorbar=False,
  476. type_matrix="boxes",
  477. figsize=(9, 9),
  478. ) -> Figure:
  479. """Plot visualization of confusion matrix.
  480. Parameters
  481. ----------
  482. include_values : bool, default=True
  483. Includes values in confusion matrix.
  484. cmap : str or matplotlib Colormap, default='viridis'
  485. Colormap recognized by matplotlib.
  486. xticks_rotation : {'vertical', 'horizontal'} or float, \
  487. default='horizontal'
  488. Rotation of xtick labels.
  489. values_format : str, default=None
  490. Format specification for values in confusion matrix. If `None`,
  491. the format specification is 'd' or '.2g' whichever is shorter.
  492. ax : matplotlib axes, default=None
  493. Axes object to plot on. If `None`, a new figure and axes is
  494. created.
  495. colorbar : bool, default=True
  496. Whether or not to add a colorbar to the plot.
  497. figsize : tuple, default (9,9)
  498. Size of figure.
  499. type_matrix : str, ether box or mask
  500. Type of matrix that need to plot.
  501. Returns
  502. -------
  503. display : :firuge:`plt.figure`
  504. """
  505. if ax is None:
  506. fig, ax = plt.subplots(figsize=figsize)
  507. else:
  508. fig = ax.figure
  509. cm = self.box_matrix if type_matrix == "boxes" else self.mask_matrix
  510. n_classes = cm.shape[0]
  511. self.im_ = ax.imshow(cm, interpolation="nearest", cmap=cmap)
  512. self.text_ = None
  513. cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
  514. if include_values:
  515. self.text_ = np.empty_like(cm, dtype=object)
  516. # print text with appropriate color depending on background
  517. thresh = (cm.max() + cm.min()) / 2.0
  518. for i, j in product(range(n_classes), range(n_classes)):
  519. color = cmap_max if cm[i, j] < thresh else cmap_min
  520. if values_format is None:
  521. text_cm = format(cm[i, j], ".2g")
  522. if cm.dtype.kind != "f":
  523. text_d = format(cm[i, j], "d")
  524. if len(text_d) < len(text_cm):
  525. text_cm = text_d
  526. else:
  527. text_cm = format(cm[i, j], values_format)
  528. self.text_[i, j] = ax.text(j, i, text_cm, ha="center", va="center", color=color)
  529. if self.display_labels is None:
  530. display_labels = np.arange(n_classes)
  531. else:
  532. display_labels = self.display_labels
  533. if colorbar:
  534. fig.colorbar(self.im_, ax=ax)
  535. ax.set(
  536. xticks=np.arange(n_classes),
  537. yticks=np.arange(n_classes),
  538. xticklabels=display_labels,
  539. yticklabels=display_labels,
  540. ylabel="True label",
  541. xlabel="Predicted label",
  542. )
  543. ax.set_ylim((n_classes - 0.5, -0.5))
  544. plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)
  545. plt.tight_layout()
  546. plt.grid(False)
  547. self.figure_ = fig
  548. self.ax_ = ax
  549. return fig