Skip to content

Commit be4b6a2

Browse files
authored
Merge pull request #34 from epfl-lts2/new-projections
feat: add spsd and lineq projection operators
2 parents b0eb893 + f809a01 commit be4b6a2

File tree

3 files changed

+224
-28
lines changed

3 files changed

+224
-28
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ and this project adheres to `Semantic Versioning <https://semver.org>`_.
99
Unreleased
1010
----------
1111

12+
* New function: proj_linalg.
13+
* New function: proj_sdsp.
1214
* New function: proj_positive.
1315
* New function: structured_sparsity.
1416
* Continuous integration with Python 3.6, 3.7, 3.8, 3.9. Dropped 2.7, 3.4, 3.5.

pyunlocbox/functions.py

Lines changed: 148 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
3737
proj_positive
3838
proj_b2
39+
proj_lineq
40+
proj_spsd
3941
4042
**Miscellaneous**
4143
@@ -803,24 +805,13 @@ class proj(func):
803805
See generic attributes descriptions of the
804806
:class:`pyunlocbox.functions.func` base class.
805807
806-
Parameters
807-
----------
808-
epsilon : float, optional
809-
The radius of the ball. Default is 1.
810-
method : {'FISTA', 'ISTA'}, optional
811-
The method used to solve the problem. It can be 'FISTA' or 'ISTA'.
812-
Default is 'FISTA'.
813-
814808
Notes
815809
-----
816810
* All indicator functions (projections) evaluate to zero by definition.
817811
818812
"""
819-
820-
def __init__(self, epsilon=1, method='FISTA', **kwargs):
813+
def __init__(self, **kwargs):
821814
super(proj, self).__init__(**kwargs)
822-
self.epsilon = epsilon
823-
self.method = method
824815

825816
def _eval(self, x):
826817
# Matlab version returns a small delta to avoid division by 0 when
@@ -835,9 +826,9 @@ class proj_positive(proj):
835826
r"""
836827
Projection on the positive octant (eval, prox).
837828
838-
This function is the indicator function :math:`i_S(z)` of the set S which
839-
is zero if `z` is in the set and infinite otherwise. The set S is defined
840-
by :math:`\left\{z \in \mathbb{R}^N \mid z \leq 0 \right\}`.
829+
This function is the indicator function :math:`i_S(z)` of the set
830+
:math:`S = \left\{z \in \mathbb{R}^N \mid z \leq 0 \right\}`
831+
that is zero if :math:`z` is in the set and infinite otherwise.
841832
842833
See generic attributes descriptions of the
843834
:class:`pyunlocbox.functions.proj` base class. Note that the constructor
@@ -867,19 +858,78 @@ def _prox(self, x, T):
867858
return np.clip(x, 0, np.inf)
868859

869860

861+
class proj_spsd(proj):
862+
r"""
863+
Projection on symmetric positive semi-definite matrices (eval, prox).
864+
865+
This function is the indicator function :math:`i_S(M)` of the set
866+
:math:`S = \left\{M \in \mathbb{R}^{N \times N}
867+
\mid M \succeq 0, M=M^T \right\}`
868+
that is zero if :math:`M` is in the set and infinite otherwise.
869+
870+
See generic attributes descriptions of the
871+
:class:`pyunlocbox.functions.proj` base class. Note that the constructor
872+
takes keyword-only parameters.
873+
874+
Notes
875+
-----
876+
* The evaluation of this function is zero.
877+
878+
Examples
879+
--------
880+
>>> from pyunlocbox import functions
881+
>>> f = functions.proj_spsd()
882+
>>> A = np.array([[0, -1] , [-1, 1]])
883+
>>> A = (A + A.T) / 2 # Symmetrize the matrix.
884+
>>> np.linalg.eig(A)[0]
885+
array([-0.61803399, 1.61803399])
886+
>>> f.eval(A)
887+
0
888+
>>> Aproj = f.prox(A, 0)
889+
>>> np.linalg.eig(Aproj)[0]
890+
array([0. , 1.61803399])
891+
892+
"""
893+
def __init__(self, **kwargs):
894+
# Constructor takes keyword-only parameters to prevent user errors.
895+
super(proj_spsd, self).__init__(**kwargs)
896+
897+
def _prox(self, x, T):
898+
isreal = np.isreal(x).all()
899+
900+
# 1. make it symmetric.
901+
sol = (x + np.conj(x.T)) / 2
902+
903+
# 2. make it semi-positive.
904+
D, V = np.linalg.eig(sol)
905+
D = np.real(D)
906+
if isreal:
907+
V = np.real(V)
908+
D = np.clip(D, 0, np.inf)
909+
sol = V @ np.diag(D) @ np.conj(V.T)
910+
return sol
911+
912+
870913
class proj_b2(proj):
871914
r"""
872915
Projection on the L2-ball (eval, prox).
873916
874-
This function is the indicator function :math:`i_S(z)` of the set S which
875-
is zero if `z` is in the set and infinite otherwise. The set S is defined
876-
by :math:`\left\{z \in \mathbb{R}^N \mid \|A(z)-y\|_2 \leq \epsilon
877-
\right\}`.
917+
This function is the indicator function :math:`i_S(z)` of the set
918+
:math:`S= \left\{z \in \mathbb{R}^N \mid \|Az-y\|_2 \leq \epsilon \right\}`
919+
that is zero if :math:`z` is in the set and infinite otherwise.
878920
879921
See generic attributes descriptions of the
880922
:class:`pyunlocbox.functions.proj` base class. Note that the constructor
881923
takes keyword-only parameters.
882924
925+
Parameters
926+
----------
927+
epsilon : float, optional
928+
The radius of the ball. Default is 1.
929+
method : {'FISTA', 'ISTA'}, optional
930+
The method used to solve the problem. It can be 'FISTA' or 'ISTA'.
931+
Default is 'FISTA'.
932+
883933
Notes
884934
-----
885935
* The `tol` parameter is defined as the tolerance for the projection on the
@@ -893,6 +943,10 @@ class proj_b2(proj):
893943
:math:`\|A(z)-y\|_2 \leq \epsilon`. It is thus a projection of the vector
894944
`x` onto an L2-ball of diameter `epsilon`.
895945
946+
See Also
947+
--------
948+
proj_lineq : use instead of ``epsilon=0``
949+
896950
Examples
897951
--------
898952
>>> from pyunlocbox import functions
@@ -904,10 +958,11 @@ class proj_b2(proj):
904958
array([1.70710678, 1.70710678])
905959
906960
"""
907-
908-
def __init__(self, **kwargs):
961+
def __init__(self, epsilon=1, method='FISTA', **kwargs):
909962
# Constructor takes keyword-only parameters to prevent user errors.
910963
super(proj_b2, self).__init__(**kwargs)
964+
self.epsilon = epsilon
965+
self.method = method
911966

912967
def _prox(self, x, T):
913968

@@ -993,6 +1048,78 @@ def _prox(self, x, T):
9931048
return sol
9941049

9951050

1051+
class proj_lineq(proj):
1052+
r"""
1053+
Projection on the plane satisfying the linear equality Az = y (eval, prox).
1054+
1055+
This function is the indicator function :math:`i_S(z)` of the set
1056+
:math:`S = \left\{z \in \mathbb{R}^N \mid Az = y \right\}`
1057+
that is zero if :math:`z` is in the set and infinite otherwise.
1058+
1059+
The proximal operator is
1060+
:math:`\operatorname{arg\,min}_z \| z - x \|_2 \text{ s.t. } Az = y`.
1061+
1062+
See generic attributes descriptions of the
1063+
:class:`pyunlocbox.functions.proj` base class. Note that the constructor
1064+
takes keyword-only parameters.
1065+
1066+
Notes
1067+
-----
1068+
* A parameter `pinvA`, the pseudo-inverse of `A`, must be provided if the
1069+
parameter `A` is provided as an operator/callable (not a matrix).
1070+
* The evaluation of this function is zero.
1071+
1072+
See Also
1073+
--------
1074+
proj_b2 : quadratic case
1075+
1076+
Examples
1077+
--------
1078+
>>> from pyunlocbox import functions
1079+
>>> import numpy as np
1080+
>>> x = np.array([0, 0])
1081+
>>> A = np.array([[1, 1]])
1082+
>>> pinvA = np.linalg.pinv(A)
1083+
>>> y = np.array([1])
1084+
>>> f = functions.proj_lineq(A=A, pinvA=pinvA, y=y)
1085+
>>> sol = f.prox(x, 0)
1086+
>>> sol
1087+
array([0.5, 0.5])
1088+
>>> np.abs(A.dot(sol) - y) < 1e-15
1089+
array([ True])
1090+
1091+
"""
1092+
def __init__(self, A=None, pinvA=None, **kwargs):
1093+
# Constructor takes keyword-only parameters to prevent user errors.
1094+
super(proj_lineq, self).__init__(A=A, **kwargs)
1095+
1096+
if pinvA is None:
1097+
if A is None:
1098+
print("Are you sure about the parameters?" +
1099+
"The projection will return y.")
1100+
self.pinvA = lambda x: x
1101+
else:
1102+
if callable(A):
1103+
raise ValueError(
1104+
"Provide A as a numpy array or provide pinvA.")
1105+
else:
1106+
# Transform matrix form to operator form.
1107+
self._pinvA = np.linalg.pinv(A)
1108+
self.pinvA = lambda x: self._pinvA.dot(x)
1109+
else:
1110+
if callable(pinvA):
1111+
self.pinvA = pinvA
1112+
else:
1113+
self.pinvA = lambda x: pinvA.dot(x)
1114+
1115+
def _prox(self, x, T):
1116+
# Applying the projection formula.
1117+
# (for now, only the non scalable version)
1118+
residue = self.A(x) - self.y()
1119+
sol = x - self.pinvA(residue)
1120+
return sol
1121+
1122+
9961123
class structured_sparsity(func):
9971124
r"""
9981125
Structured sparsity (eval, prox).

pyunlocbox/tests/test_functions.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ def assert_equivalent(param1, param2):
5151
assert_equivalent({'y': 3.2}, {'y': lambda: 3.2})
5252
assert_equivalent({'A': None}, {'A': np.identity(3)})
5353
A = np.array([[-4, 2, 5], [1, 3, -7], [2, -1, 0]])
54+
pinvA = np.linalg.pinv(A) # For proj_linalg.
5455
assert_equivalent({'A': A}, {'A': A, 'At': A.T})
55-
assert_equivalent({'A': lambda x: A.dot(x)}, {'A': A, 'At': A})
56+
assert_equivalent({'A': lambda x: A.dot(x), 'pinvA': pinvA},
57+
{'A': A, 'At': A})
5658

5759
def test_dummy(self):
5860
"""
@@ -95,16 +97,16 @@ def test_norm_l2(self):
9597
self.assertEqual(f.eval([4, 6]), 0)
9698
self.assertEqual(f.eval([5, -2]), 256 + 4)
9799
nptest.assert_allclose(f.grad([4, 6]), 0)
98-
# nptest.assert_allclose(f.grad([5, -2]), [8, -64])
100+
# nptest.assert_allclose(f.grad([5, -2]), [8, -64])
99101
nptest.assert_allclose(f.prox([4, 6], 1), [4, 6])
100102

101103
f = functions.norm_l2(lambda_=2, y=np.fft.fft([2, 4]) / np.sqrt(2),
102104
A=lambda x: np.fft.fft(x) / np.sqrt(x.size),
103105
At=lambda x: np.fft.ifft(x) * np.sqrt(x.size))
104-
# self.assertEqual(f.eval(np.fft.ifft([2, 4])*np.sqrt(2)), 0)
105-
# self.assertEqual(f.eval([3, 5]), 2*np.sqrt(25+81))
106+
# self.assertEqual(f.eval(np.fft.ifft([2, 4])*np.sqrt(2)), 0)
107+
# self.assertEqual(f.eval([3, 5]), 2*np.sqrt(25+81))
106108
nptest.assert_allclose(f.grad([2, 4]), 0)
107-
# nptest.assert_allclose(f.grad([3, 5]), [4*np.sqrt(5), 4*3])
109+
# nptest.assert_allclose(f.grad([3, 5]), [4*np.sqrt(5), 4*3])
108110
nptest.assert_allclose(f.prox([2, 4], 1), [2, 4])
109111
nptest.assert_allclose(f.prox([3, 5], 1), [2.2, 4.2])
110112
nptest.assert_allclose(f.prox([2.2, 4.2], 1), [2.04, 4.04])
@@ -417,6 +419,46 @@ def test_proj_b2(self):
417419
f.method = 'NOT_A_VALID_METHOD'
418420
self.assertRaises(ValueError, f.prox, x, 0)
419421

422+
def test_proj_lineq(self):
423+
"""
424+
Test the projection on Ax = y
425+
426+
"""
427+
x = np.zeros([10])
428+
A = np.ones([1, 10])
429+
y = np.array([10])
430+
f = functions.proj_lineq(A=A, y=y)
431+
sol = f.prox(x, 0)
432+
np.testing.assert_allclose(sol, np.ones([10]))
433+
np.testing.assert_allclose(A.dot(sol), y)
434+
435+
f = functions.proj_lineq(A=A)
436+
sol = f.prox(x, 0)
437+
np.testing.assert_allclose(sol, np.zeros([10]))
438+
439+
for i in range(1, 15):
440+
x = np.random.randn(10)
441+
y = np.random.randn(i)
442+
A = np.random.randn(i, 10)
443+
pinvA = np.linalg.pinv(A)
444+
f1 = functions.proj_lineq(A=A, y=y)
445+
f2 = functions.proj_lineq(A=lambda x: A.dot(x), pinvA=pinvA, y=y)
446+
f3 = functions.proj_lineq(A=A, pinvA=lambda x: pinvA.dot(x), y=y)
447+
f4 = functions.proj_lineq(A=A, pinvA=pinvA, y=y)
448+
sol1 = f1.prox(x, 0)
449+
sol2 = f2.prox(x, 0)
450+
sol3 = f3.prox(x, 0)
451+
sol4 = f4.prox(x, 0)
452+
np.testing.assert_allclose(sol1, sol2)
453+
np.testing.assert_allclose(sol1, sol3)
454+
np.testing.assert_allclose(sol1, sol4)
455+
if i <= x.size:
456+
np.testing.assert_allclose(A.dot(sol1), y)
457+
if i >= x.size:
458+
np.testing.assert_allclose(sol1, pinvA.dot(y))
459+
460+
self.assertRaises(ValueError, functions.proj_lineq, A=lambda x: x)
461+
420462
def test_proj_positive(self):
421463
"""
422464
Test the projection on the positive octant.
@@ -430,6 +472,30 @@ def test_proj_positive(self):
430472
nptest.assert_equal(res[x > 0], x[x > 0]) # Positives are unchanged.
431473
self.assertEqual(fpos.eval(x), 0)
432474

475+
def test_proj_spsd(self):
476+
"""
477+
Test the projection on symmetric positive semi-definite matrices.
478+
479+
"""
480+
f_spds = functions.proj_spsd()
481+
A = np.random.randn(10, 10)
482+
A = A + A.T
483+
eig1 = np.sort(np.real(np.linalg.eig(A)[0]))
484+
res = f_spds.prox(A, T=1)
485+
eig2 = np.sort(np.real(np.linalg.eig(res)[0]))
486+
# All eigenvalues are positive.
487+
assert ((eig2 > -1e-13).all())
488+
489+
# Positive value are unchanged.
490+
np.testing.assert_allclose(eig2[eig1 > 0], eig1[eig1 > 0])
491+
492+
# The symmetrization works.
493+
A = np.random.rand(10, 10) + 10 * np.eye(10)
494+
res = f_spds.prox(A, T=1)
495+
np.testing.assert_allclose(res, (A + A.T) / 2)
496+
497+
self.assertEqual(f_spds.eval(A), 0)
498+
433499
def test_structured_sparsity(self):
434500
"""
435501
Test the structured sparsity function.
@@ -503,8 +569,9 @@ def test_independent_problems(self):
503569
if name == 'norm_tv':
504570
# Each column is one-dimensional.
505571
f = func(dim=1, maxit=20, tol=0)
506-
elif name == 'norm_nuclear':
507-
# TODO: make this test two dimensional for the norm nuclear?
572+
elif name in ['norm_nuclear', 'proj_spsd']:
573+
# TODO: make this test two dimensional for the norm nuclear
574+
# and the spsd projection?
508575
continue
509576
else:
510577
f = func()

0 commit comments

Comments
 (0)