"""
References
----------
.. [5] 'Sparse Decomposition of the GPR Useful Signal from Hyperbola Dictionary',
   Guillaume Terasse, Jean-Marie Nicolas, Emmanuel Trouvé and Émeline Drouet.
   Available at: https://hal.archives-ouvertes.fr/hal-01351242

.. warning::

   **To do**

   - Add logging for tracking.
   - Add a stop condition on the variation of the parameters, in addition to the number of iterations.
"""
import logging
import numpy as np
from tqdm import tqdm
from scipy import linalg,fft
from sklearn.utils.validation import _deprecate_positional_args
from pywt import threshold,threshold_firm
from . import admm_func
def cost_function_addm2(Y, rhoS, rhoL, old_auxiliary, old_dal, old_L, L, dal, primal, auxiliary):
    r"""Compute the convergence diagnostics of one ADMM iteration with 2 constraints.

    Parameters
    ----------
    Y : ndarray of shape (Nx, Ny, 1)
        original image.
    rhoS : float
        penalty parameter on the auxiliary variable S.
    rhoL : float
        penalty parameter on the hollow (low-rank) variable L.
    old_auxiliary : ndarray
        auxiliary variable of the previous iteration :math:`\mathbf{S}_k^{i-1}`.
    old_dal : ndarray of shape (Nx, Ny, 1)
        reconstruction of the previous iteration
        :math:`\sum_k{\mathbf{C}_k^{i-1}\star\mathbf{H}_k}`.
    old_L : ndarray of shape (Nx, Ny, 1)
        hollow matrix of the previous iteration :math:`\mathbf{L}^{i-1}`.
    L : ndarray of shape (Nx, Ny, 1)
        current hollow matrix :math:`\mathbf{L}^i`.
    dal : ndarray of shape (Nx, Ny, 1)
        current reconstruction.
    primal : ndarray (3-D)
        current primal variable :math:`\mathbf{C}_k^i`.
    auxiliary : ndarray (3-D)
        current auxiliary variable :math:`\mathbf{S}_k^i`.

    Returns
    -------
    var_prim_ : float
        Frobenius norm over axes (0, 1) of ``dal - old_dal`` (first channel).
    error_rec_ : float
        Frobenius norm over axes (0, 1) of ``Y - dal - L`` (first channel).
    error_primal_ : float
        norm of ``primal - auxiliary`` after summing over the third axis.
    error_dual_S : float
        norm of ``-rhoS * (auxiliary - old_auxiliary)`` after summing over the third axis.
    error_dual_L : float
        norm of ``-rhoL * (L - old_L)`` after summing over the third axis.
    """
    def spatial_fro(arr):
        # Frobenius norm over the two spatial axes; [0] keeps the first (single) channel.
        return np.linalg.norm(arr, 'fro', axis=(0, 1))[0]

    def channel_summed_norm(arr):
        # Collapse the third axis by summation, then take the Euclidean norm of the result.
        return np.linalg.norm(np.sum(arr, 2))

    var_prim_ = spatial_fro(dal - old_dal)
    error_rec_ = spatial_fro(Y - dal - L)
    error_primal_ = channel_summed_norm(primal - auxiliary)
    error_dual_L = channel_summed_norm(-rhoL * (L - old_L))
    error_dual_S = channel_summed_norm(-rhoS * (auxiliary - old_auxiliary))
    return var_prim_, error_rec_, error_primal_, error_dual_S, error_dual_L
def addm_iteration_norm2(lambdaS, c, Y, Dal, DF, DF_H, Us, Uy, rhoS, rhoL, S, dim1,
                         over_relax, penalty="l1", m=-1):
    r"""Compute one iteration of `ADMM` with norm 2 and the hollow matrix.

    Based on [4]_ and [5]_.

    Parameters
    ----------
    lambdaS : float
        parsimony (sparsity) parameter.
    c : array_like
        pre-calculated term passed to ``admm_func.Sherman_MorrisonF``.
    Y : ndarray of shape (Nx, Ny, 1)
        original image.
    Dal : ndarray of shape (Nx, Ny, 1)
        convolution product dictionary * coefficient maps from the previous iteration.
    DF : complex ndarray of shape (Nx, Ny, K)
        FFT of the dictionary.
    DF_H : complex ndarray of shape (Nx, Ny, K)
        complex conjugate of the dictionary FFT.
    Us : ndarray
        dual variable of S.
    Uy : ndarray
        dual variable of L.
    rhoS : float
        penalty parameter on S.
    rhoL : float
        penalty parameter on L.
    S : ndarray
        auxiliary variable :math:`\mathbf{S}_k` of the previous iteration.
    dim1 : tuple of int
        shape given to the hollow matrix L (e.g. ``(256, 256, 1)``).
    over_relax : float
        over-relaxation parameter :math:`\alpha`; applied only when > 0
        (improves convergence for :math:`\alpha\sim 1.6`).
    penalty : {"l1", "l0", "FT", "FirmThresholding"}, optional
        penalty on S: "l1" -> soft thresholding, "l0" -> hard thresholding,
        "FT" or "FirmThresholding" -> firm thresholding. Default "l1".
    m : int, optional
        number of workers passed to ``scipy.fft`` (-1: all CPUs).

    Returns
    -------
    primal_new : ndarray
        primal variable :math:`\mathbf{C}_k^i` (over-relaxed when ``over_relax > 0``).
    auxiliary_new : ndarray
        auxiliary variable :math:`\mathbf{S}_k^i`.
    L : ndarray of shape ``dim1``
        hollow matrix.
    Dal : ndarray of shape (Nx, Ny, 1)
        reconstruction :math:`\sum_k{\mathbf{C}_k^i\star\mathbf{H}_k}`.

    Raises
    ------
    ValueError
        if ``penalty`` is not one of the supported values.
    """
    # Validate first: an unknown penalty used to leave `auxiliary_new` unbound
    # (UnboundLocalError) only after all the SVD/FFT work had been done.
    if penalty not in ("l1", "l0", "FT", "FirmThresholding"):
        raise ValueError(
            f"Unknown penalty {penalty!r}; expected 'l1', 'l0', 'FT' or 'FirmThresholding'."
        )

    # L-update: soft-threshold the singular values of the residual by 1/rhoL.
    # Assumes admm_func.diag_thresh arranges the singular values into a
    # u.shape[0] x vh.shape[0] matrix suitable for `u @ ... @ vh` — TODO confirm.
    u, sig, vh = linalg.svd((Y - Dal + Uy)[:, :, 0], check_finite=False)
    sig = admm_func.diag_thresh(u.shape[0], vh.shape[0], sig)
    L = (u @ threshold(sig, (1 / rhoL), 'soft') @ vh).reshape(dim1)

    # C-update, solved in the Fourier domain.
    x_b = fft.fft2((Y - L + Uy), axes=(0, 1), workers=m)
    z_b = fft.fft2(S - Us, axes=(0, 1), workers=m)
    b = rhoL * DF_H * x_b + rhoS * z_b
    alphaf = admm_func.Sherman_MorrisonF(DF_H, b, c, rhoS)
    primal_new = fft.ifft2(alphaf, axes=(0, 1), workers=m)
    if over_relax > 0:
        # Over-relaxation: alpha * C + (1 - alpha) * S_previous.
        primal_new = over_relax * primal_new - ((over_relax - 1) * S)

    # S-update: proximal step of the chosen penalty, threshold lambdaS / rhoS.
    to_shrink = primal_new + Us
    tau = lambdaS / rhoS
    if penalty == "l1":
        auxiliary_new = threshold(to_shrink, tau, 'soft')
    elif penalty == "l0":
        auxiliary_new = threshold(to_shrink, tau, 'hard')
    else:  # "FT" / "FirmThresholding"
        # NOTE(review): the upper firm threshold is fixed at 2; pywt expects
        # value_low <= value_high, so tau > 2 is presumably invalid — confirm.
        auxiliary_new = threshold_firm(to_shrink, tau, 2)

    # Reconstruction from the (non-relaxed) Fourier-domain primal.
    # NOTE: an alternative is to rebuild Dal from fft2(auxiliary_new) instead.
    Dal = fft.ifft2(np.sum(DF * alphaf, 2, keepdims=True), axes=(0, 1), workers=m)
    return primal_new, auxiliary_new, L, Dal
class ADMMSourceSep(admm_func.ConvolutionalSparseCoding):
    r"""Convolutional sparse coding with a hollow (low-rank) matrix, solved by ADMM.

    Each iteration (``addm_iteration_norm2``) updates a hollow matrix
    :math:`\mathbf{L}`, the coefficient maps :math:`\mathbf{C}_k` and the
    auxiliary variables :math:`\mathbf{S}_k`, so that the image is approximated
    by :math:`\mathbf{L} + \sum_k \mathbf{C}_k \star \mathbf{H}_k`.

    Parameters
    ----------
    dictionary : array_like of shape (n_pixels_x, n_pixels_y, n_atoms)
        Dictionary for sparse coding.
    eps : float
        Sparsity coefficient (``lambdaS`` of ``addm_iteration_norm2``); must be greater than 0.
    n_iter : int
        Number of iterations.
    delta : float
        Tolerance for a stopping criterion. NOTE: stored but not used yet; the
        loop stops only after ``n_iter`` iterations.
    rhoS : float
        Initial penalty parameter on the auxiliary variable S.
    rhoL : float
        Initial penalty parameter on the hollow matrix L.
    over_relax : float, optional
        Over-relaxation parameter, applied only when > 0.
    update_rho : str, optional
        "adaptive": rhoS and rhoL are updated by ``admm_func.update_rhoLS_adp``
        after each iteration; "increase": both are multiplied by 1.1; any other
        value keeps them constant.
    penalty : {"l1", "l0", "FT"}, optional
        Penalty on S, forwarded to ``addm_iteration_norm2``.
    norm_optim : str, optional
        Data-fidelity norm; only "Frobenius" is implemented.
    save_iterations : bool, optional
        If True, a shallow snapshot of the estimator's attributes is stored in
        ``iterations_save_`` after each iteration. It is a dict of attribute
        name -> value, the form accepted by ``fit(initial_solution=...)``.
    verbosity : int, optional
        Verbosity level (not read by this class's own methods).

    Attributes
    ----------
    iterations_ : int
        Number of iterations performed.
    iterations_save_ : dict or list
        Latest snapshot (see ``save_iterations``), or an empty list.
    error_rec_, error_primal_, error_dual_S_, error_dual_L_, var_prim_ : list of float
        Per-iteration diagnostics from ``cost_function_addm2``.
    L_ : ndarray
        Current hollow matrix.
    dal_ : ndarray
        Current reconstruction :math:`\sum_k \mathbf{C}_k \star \mathbf{H}_k`.
    primal_, auxiliary_ : ndarray
        Current primal (C) and auxiliary (S) variables.
    converged_ : bool
        True once the iteration budget has been reached.
    """

    @_deprecate_positional_args
    def __init__(self, dictionary, eps, n_iter, delta, rhoS, rhoL, over_relax=0,
                 update_rho="adaptive", penalty="l1", norm_optim="Frobenius",
                 save_iterations=False, verbosity=0) -> None:
        super().__init__(dictionary)
        self.eps = eps
        self.normalize = True  # rescale the input image to [0, 1] in fit
        self.over_relax = over_relax
        self.penalty = penalty
        self.norm_optim = norm_optim
        # NOTE(review): the next three are not read by the methods of this class.
        self.max_grad_iter = 50
        self.alpha = 0.001
        self.threshold_H = 250
        self.n_iter = n_iter
        self.iter_init_ = n_iter  # requested budget; used to extend n_iter when resuming
        self.delta = delta
        self.rhoS = rhoS
        self.rhoL = rhoL
        self.workers = -1  # scipy.fft workers (-1: all CPUs)
        self.update_rho = update_rho
        self.save_iterations = save_iterations
        self.verbosity = verbosity
        self.converged_ = False
        self.iterations_ = None
        self.error_primal_ = []
        self.error_dual_S_ = []
        self.error_dual_L_ = []
        self.error_rec_ = []
        self.var_prim_ = []
        self.auxiliary_ = 0
        self.primal_ = 0
        self.dual_S = 0
        self.dual_L = 0
        self.dal_ = 0
        self.L_ = 0
        self.c_ = 0
        self.iterations_save_ = []

    def _update_c(self):
        """Recompute ``c_`` (the term passed to ``Sherman_MorrisonF``) for the current rhos."""
        self.c_ = self.rhoL * self.DF_ / (self.rhoS + self.rhoL * self.c_inter)

    def _update_rho(self):
        """Update the penalty parameters, then take the (scaled) dual ascent step.

        When a penalty parameter is multiplied by a factor k, its scaled dual
        variable must be divided by that same factor. L and S have independent
        factors, so each dual is rescaled by its own one.
        """
        if self.update_rho == "adaptive":
            # Assumes update_rhoLS_adp returns (new_rho, new_rho / old_rho),
            # consistent with k = 1.1 in the "increase" branch — TODO confirm.
            self.rhoL, k_L = admm_func.update_rhoLS_adp(
                self.error_primal_[-1], self.error_dual_L_[-1], self.rhoL)
            self.rhoS, k_S = admm_func.update_rhoLS_adp(
                self.error_primal_[-1], self.error_dual_S_[-1], self.rhoS)
            if self.norm_optim == "Frobenius":
                self._update_c()
        elif self.update_rho == "increase":
            self.rhoL = 1.1 * self.rhoL
            self.rhoS = 1.1 * self.rhoS
            if self.norm_optim == "Frobenius":
                self._update_c()
            k_L = k_S = 1.1
        else:
            k_L = k_S = 1
        self.dual_L = (self.dual_L + self.Y - self.dal_ - self.L_) / k_L
        self.dual_S = (self.dual_S + self.primal_ - self.auxiliary_) / k_S

    def _iteration_addm(self):
        """Run one ADMM iteration and store its result.

        Raises
        ------
        ValueError
            if ``norm_optim`` is not "Frobenius" (the only implemented norm).
        """
        if self.norm_optim != "Frobenius":
            # Previously this only logged and then crashed with UnboundLocalError.
            raise ValueError(
                f"Unsupported norm_optim {self.norm_optim!r}; only 'Frobenius' is implemented."
            )
        primal_new, auxiliary_new, L_new, dal_new = addm_iteration_norm2(
            self.eps, self.c_, self.Y, self.dal_, self.DF_, self.DF_H_,
            self.dual_S, self.dual_L, self.rhoS, self.rhoL, self.auxiliary_, self.dim1,
            self.over_relax, penalty=self.penalty, m=self.workers
        )
        self._save_iteration(L_new, primal_new, auxiliary_new, dal_new)

    def _cost_function(self, L, primal, auxiliary, dal):
        """Append this iteration's diagnostics (previous state taken from ``self``)."""
        var_prim, err_rec, err_primal, err_dual_S, err_dual_L = cost_function_addm2(
            self.Y, self.rhoS, self.rhoL, self.auxiliary_, self.dal_, self.L_,
            L, dal, primal, auxiliary)
        self.var_prim_.append(var_prim)
        self.error_rec_.append(err_rec)
        self.error_primal_.append(err_primal)
        self.error_dual_S_.append(err_dual_S)
        self.error_dual_L_.append(err_dual_L)

    def _save_iteration(self, L, primal, auxiliary, dal):
        """Record diagnostics, commit the new iterate, update rho/duals and optionally snapshot."""
        # Diagnostics must be computed before the old state is overwritten.
        self._cost_function(L, primal, auxiliary, dal)
        self.primal_ = primal
        self.auxiliary_ = auxiliary
        self.dal_ = dal
        self.L_ = L
        self._update_rho()
        if self.save_iterations:
            # Shallow snapshot of the current state. The previous snapshot is
            # excluded: including it chained every snapshot (and all of its
            # arrays) together, so memory grew with the number of iterations.
            self.iterations_save_ = {
                key: value for key, value in self.__dict__.items()
                if key != "iterations_save_"
            }

    def _set_solution(self, dictionary_solution):
        """Restore attributes from a snapshot and extend the budget by ``iter_init_``."""
        vars(self).update(dictionary_solution)
        self.n_iter = self.iter_init_ + self.iterations_

    def _initialize_for_estimation(self, X, initial_solution):
        """Prepare the variables for ``fit``: fresh start, or resume from ``initial_solution``."""
        if self.converged_:
            logging.warning("Overwriting past estimation.")
        self.dim1 = (X.shape[0], X.shape[1], 1)
        self.dimK = self.dictionary.shape
        self.dual_L = np.zeros(self.dim1)
        self.dual_S = np.zeros(self.dimK)
        self.auxiliary_ = np.zeros(self.dimK)
        self.primal_tilde_ = np.zeros(self.dimK)
        self.S_tilde = fft.fft2(X, axes=(0, 1), workers=-1)
        self.Y = X.reshape(self.dim1)
        if self.normalize:
            y_min = self.Y.min()
            y_span = self.Y.max() - y_min
            # Guard against a constant image, which previously produced NaNs (0/0).
            self.Y = (self.Y - y_min) / y_span if y_span != 0 else self.Y - y_min
        if initial_solution is not None:
            self._set_solution(initial_solution)
            self.converged_ = False
        else:
            self._precompute()
            self.converged_ = False
            self.iterations_ = 0
            # Fresh start: drop state left over from a previous fit. dal_ and L_
            # are read by the first iteration; stale values gave a wrong start
            # or a shape mismatch when the new image had a different size.
            self.dal_ = 0
            self.L_ = 0
            self.primal_ = 0
            self.error_primal_ = []
            self.error_dual_S_ = []
            self.error_dual_L_ = []
            self.error_rec_ = []
            self.var_prim_ = []
            if self.save_iterations:
                self.iterations_save_ = []

    def _precompute(self):
        """Precompute the dictionary FFTs (and ``c_`` for the Frobenius norm)."""
        self.DF_ = fft.fft2(self.dictionary, axes=(0, 1), workers=self.workers)
        self.DF_H_ = np.conj(self.DF_)
        self.YF = fft.fft2(self.Y, axes=(0, 1), workers=self.workers)
        if self.norm_optim == "Frobenius":
            # Sum over atoms of conj(D_k) * D_k.
            self.c_inter = np.sum(self.DF_H_ * self.DF_, 2, keepdims=True)
            self._update_c()

    def fit(self, X, y=None, initial_solution=None):
        """Run ADMM on the image ``X``.

        Parameters
        ----------
        X : array_like
            Input image; ``X.shape[0]`` and ``X.shape[1]`` give its size.
        y : None
            Ignored.
        initial_solution : dict, optional
            Attribute snapshot (e.g. a previous ``iterations_save_``) to resume from.

        Returns
        -------
        self
        """
        X = self._validate_data(X)
        self._initialize_for_estimation(X, initial_solution)
        # Context manager closes the progress bar even if an iteration raises.
        with tqdm(total=self.n_iter, leave=True) as pbar:
            pbar.n = self.iterations_
            while not self.converged_:
                self._iteration_addm()
                info = f"rec : {float(self.error_rec_[-1]):.4} ||duaS : {float(self.error_dual_S_[-1]):.3} ||duaL : {float(self.error_dual_L_[-1]):.3} ||pri : {float(self.error_primal_[-1]):.4} ||rhoS : {float(self.rhoS):.3} ||rhoL : {float(self.rhoL):.3}"
                pbar.set_description(info)
                pbar.update(1)
                self.iterations_ += 1
                # TODO: also stop on residual / parameter variation below self.delta.
                self.converged_ = self.iterations_ >= self.n_iter
        return self