Source code for MIRAG.affichage

"""
Plotting utilities for displaying the results of the ADMM algorithm in its
sparse and separation forms.
"""

from matplotlib.colors import LinearSegmentedColormap
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
from scipy import fft
import numpy as np
import tikzplotlib

from .optim.admm_func import roll_fft


def cmap_perso(Q):
    r"""Create a divergent custom colormap centred on the mean of ``Q``.

    ``Q`` is min-max normalised to [0, 1]. The white node of the colormap sits
    at the normalised mean, with light transition colours at 0.95x and 1.05x
    the mean, blue at 0 and red at 1.

    Parameters
    ----------
    Q : numpy.ndarray
        Matrix to display (mainly the sum of the C_k), shape (Nx, Ny).

    Returns
    -------
    cmap : matplotlib.colors.LinearSegmentedColormap
        The associated colormap object.
    """
    q_min = Q.min()
    span = Q.max() - q_min
    # A constant matrix has no spread: avoid 0/0 (NaN colormap nodes) and put
    # the white node at 0 instead.
    normQ = (Q - q_min) / span if span else np.zeros_like(Q, dtype=float)
    mean = normQ.mean()

    colors = ["blue", "lightsteelblue", "white", "lightsalmon", "red"]
    upper = mean * 1.05
    if upper > 1:
        # Keep the "lightsalmon" node strictly below the final "red" node.
        upper = 0.9999
    # LinearSegmentedColormap requires non-decreasing node positions; the clamp
    # above could otherwise drop below the mean when mean > 0.9999.
    upper = max(upper, mean)
    node = [0, mean * 0.95, mean, upper, 1]
    return LinearSegmentedColormap.from_list("mycmap", list(zip(node, colors)))
def plot_ckmap(alpha, duo=False, t=60, x=128, title=("_", "_"), nfile="_", save=False):
    r"""Display one or two C_k maps with a custom divergent colormap.

    Each map is passed through ``roll_fft(map, t, x)`` before display, and its
    colormap is built by ``cmap_perso``.

    Parameters
    ----------
    alpha : numpy.ndarray or sequence of numpy.ndarray
        Matrix (M x N), mainly the sum of the C_k or the detail of one C_k.
        When ``duo`` is True, ``alpha[0]`` and ``alpha[1]`` are the two maps.
    duo : bool, optional
        Display two maps side by side instead of one. Default False.
    t : int, optional
        Central position (pixel) of the hyperbolas used (ordinate). Default 60.
    x : int, optional
        Central position (pixel) of the hyperbolas used (abscissa). Default 128.
    title : sequence of str, optional
        Titles of the graphs (``title[0]``, and ``title[1]`` when ``duo``).
    nfile : str, optional
        File name (without extension) for the record.
    save : bool, optional
        Save the figure as ``nfile + ".tex"`` with tikzplotlib. Default False.

    Returns
    -------
    None
    """
    if duo:
        fig, ax = plt.subplots(1, 2, figsize=(12, 12))
        for i in range(2):
            Q = roll_fft(alpha[i], t, x)
            # Use a separate name for the image handle so the `x` argument is
            # not clobbered before the next roll_fft call, and draw each map on
            # its own axis.
            im = ax[i].imshow(Q, cmap=cmap_perso(Q), aspect="equal")
            fig.colorbar(im, ax=ax[i], shrink=0.5)
            ax[i].set_title(title[i])
    else:
        Q = roll_fft(alpha, t, x)
        fig, ax = plt.subplots(figsize=(12, 12))
        im = ax.imshow(Q, cmap=cmap_perso(Q), aspect="equal")
        fig.colorbar(im, ax=ax, shrink=0.5)
        ax.set_title(title[0])
    if save:
        tikzplotlib.save(nfile + ".tex")
def plot_ckmap_img(T, ck=False, title=("_", "_"), nfile="_", save=False, t=60, x=128):
    r"""Display the sum of the C_k (or an image) next to a second image.

    Parameters
    ----------
    T : sequence of two numpy.ndarray
        ``[C_k, img]`` ([(Nx * Ny * K), (Nx * Ny)]) or [(Nx * Ny), (Nx * Ny)].
    ck : bool, optional
        If True, ``T[0]`` is passed through ``roll_fft`` and shown with the
        custom divergent colormap; otherwise the real part of ``T[0]`` is shown
        in grey levels. Default False.
    title : sequence of str, optional
        Titles of the two panels, optionally followed by a third entry used
        as the figure super-title.
    nfile : str, optional
        File name (without extension) for the record.
    save : bool, optional
        Save the figure as ``nfile + ".png"``. Default False.
    t : int, optional
        Central position (pixel) of the hyperbolas used (ordinate). Default 60.
    x : int, optional
        Central position (pixel) of the hyperbolas used (abscissa). Default 128.

    Returns
    -------
    None
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    if ck:
        Q = roll_fft(T[0], t, x)
        ax[0].imshow(Q, cmap=cmap_perso(Q), aspect="equal")
    else:
        ax[0].imshow(np.real(T[0]), aspect="equal", cmap="gray")
    ax[0].set_title(title[0])
    ax[1].imshow(np.real(T[1]), aspect="equal", cmap="gray")
    ax[1].set_title(title[1])
    for axis in ax:
        axis.set_axis_off()
    # The super-title is optional: the default `title` only has two entries,
    # and indexing title[2] unconditionally raised IndexError.
    if len(title) > 2:
        fig.suptitle(title[2], fontsize=16, fontweight="bold")
    plt.tight_layout()
    if save:
        fig.savefig(nfile + ".png")
def plot_recon(Dal, Dico=None, name="_", save=False, compute=True):
    r"""Display the reconstruction of an image from a dictionary and C_k maps.

    When ``compute`` is True the reconstruction is obtained in the 2-D Fourier
    domain: the FFTs of ``Dal`` and ``Dico`` are multiplied, summed over the
    third axis (the atoms), then transformed back. This is a sum over k of
    circular convolutions of each C_k with its atom.

    Parameters
    ----------
    Dal : numpy.ndarray
        Either the already computed reconstruction (Nx * Ny) or the C_k
        tensor (Nx * Ny * K).
    Dico : numpy.ndarray, optional
        Dictionary (Nx * Ny * K); required when ``compute`` is True.
    name : str, optional
        File name (without extension) used when saving.
    save : bool, optional
        Save the figure as ``name + ".tex"`` with tikzplotlib. Default False.
    compute : bool, optional
        True if the reconstruction still has to be computed from ``Dal`` and
        ``Dico``. Default True.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``compute`` is True and ``Dico`` is None.
    """
    if compute:
        if Dico is None:
            raise ValueError("Dico is required to compute the reconstruction (compute=True).")
        spectrum = np.sum(
            fft.fft2(Dal, axes=(0, 1), workers=-1) * fft.fft2(Dico, axes=(0, 1), workers=-1),
            axis=2,
        )
        Dal1 = fft.ifft2(spectrum, axes=(0, 1), workers=-1)
    else:
        Dal1 = Dal
    _, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(np.real(Dal1), aspect="equal", cmap="gray")
    if save:
        tikzplotlib.save(name + ".tex")
        # To also keep the raw reconstruction: np.savez(name, Dal1)
def plot_atomNSM(atm2, paraDic):
    r"""Display an atom of the physical dictionary with physical axis units.

    The image axes are labelled in physical units instead of pixels, using a
    constant pixel pitch of ``size / number_of_pixels`` along each axis.

    Parameters
    ----------
    atm2 : numpy.ndarray
        Matrix of the atom (Nx * Ny).
    paraDic : dict
        Parameters of the atom. Required key: ``"size_ns_m"``, the physical
        extent ``[size_ns, size_m]`` of the rows (ns) and columns (m).

    Returns
    -------
    None

    Examples
    --------
    >>> paraDic = {}
    >>> paraDic["size_ns_m"] = [900, 45]
    >>> plot_atomNSM(atoms, paraDic)
    """
    fig, ax = plt.subplots(1, figsize=(12, 3))
    im = ax.imshow(atm2, cmap="gray", aspect="equal")
    n_rows, n_cols = np.shape(atm2)[:2]
    # NOTE(review): assumes pixels are evenly spaced over the physical extent.
    ns_per_px = paraDic["size_ns_m"][0] / n_rows
    m_per_px = paraDic["size_ns_m"][1] / n_cols
    # Relabel the automatically placed ticks with a formatter. The previous
    # approach replaced the labels with a shifted list ("-1" prepended) whose
    # length did not match the ticks, so labels were misaligned.
    ax.yaxis.set_major_formatter(lambda val, _pos: f"{np.round(val * ns_per_px):g}")
    ax.xaxis.set_major_formatter(lambda val, _pos: f"{np.round(val * m_per_px):g}")
    fig.colorbar(im, ax=ax)
    fig.show()
def scale_0_1(img):
    """Scale an image linearly so its values span [0, 1].

    Parameters
    ----------
    img : numpy.ndarray
        Image to scale.

    Returns
    -------
    scaled : numpy.ndarray
        ``(img - img.min()) / (img.max() - img.min())``. A constant image
        (zero range) is mapped to all zeros instead of NaN.
    """
    lo = img.min()
    span = img.max() - lo
    # Dividing by 1 when the range is zero keeps the same output dtype as the
    # normal path while avoiding a 0/0 that would yield NaN everywhere.
    return (img - lo) / (span if span else 1)
def roc_curve_plot(mask, img, name, threshold=128):
    """Plot ROC curves of several score images against a reference mask.

    Each score image is squared and flattened before being compared with the
    binarised mask. Every curve is labelled with its AUC.

    Parameters
    ----------
    mask : numpy.ndarray
        Reference mask; pixels strictly greater than ``threshold`` are the
        positives.
    img : sequence of numpy.ndarray
        Images to evaluate, one ROC curve each.
    name : sequence of str
        Legend name of each image (``name[i]`` labels ``img[i]``).
    threshold : float, optional
        Binarisation threshold applied to ``mask``. Default 128.

    Returns
    -------
    None
    """
    _, ax = plt.subplots(1, figsize=(6.5, 6.5))
    # The binarised ground truth is the same for every curve: compute it once.
    y_true = np.where(np.asarray(mask) > threshold, 1, 0).ravel()
    # Chance-level diagonal, drawn once.
    ax.plot([0, 1], [0, 1], "--", color="black")
    plotted = False
    for i, ref in enumerate(img):
        scores = (np.asarray(ref, dtype=np.float64) ** 2).ravel()
        auc_score = roc_auc_score(y_true, scores)
        fpr, tpr, _ = roc_curve(y_true, scores)
        # Single f-string: mixing f-string and %-formatting broke on names
        # that contain a '%' character.
        ax.plot(fpr, tpr, label=f"{name[i]} - ROC curve (area = {auc_score:0.2f})")
        plotted = True
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    if plotted:
        ax.legend()
    plt.show()