- 
                Notifications
    
You must be signed in to change notification settings  - Fork 16
 
Add prox operators for simple constraints #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
572b2e0
              b47bfeb
              21b3e70
              7aad227
              308ff28
              2e8cbf6
              f102465
              bb6e608
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import numpy as np | ||
| 
     | 
||
| from numpy.linalg import norm | ||
| from sklearn.isotonic import isotonic_regression | ||
| from ya_glm.opt.base import Func | ||
| 
     | 
||
| 
     | 
||
| 
        
          
        
         | 
    @@ -25,6 +26,22 @@ def _prox(self, x, step=1): | |
| def is_proximable(self): | ||
| return True | ||
| 
     | 
||
| class LinearEquality(Constraint): | ||
| # credited to PyUNLocBoX | ||
| # https://github.com/epfl-lts2/pyunlocbox/ | ||
| def __init__(self, A, b): | ||
| self.A = A | ||
| self.b = b | ||
| self.pinvA = np.linalg.pinv(A) | ||
| 
     | 
||
| def _prox(self, x, step=1): | ||
| residue = self.A@x - self.b | ||
| sol = x - self.pinvA @ residue | ||
| return sol | ||
| 
     | 
||
| @property | ||
| def is_proximable(self): | ||
| return True | ||
| 
     | 
||
| class Simplex(Constraint): | ||
| 
     | 
||
| 
          
            
          
           | 
    @@ -58,8 +75,9 @@ def is_proximable(self): | |
| # See https://gist.github.com/mblondel/6f3b7aaad90606b98f71 | ||
| # for more algorithms. | ||
| def project_simplex(v, z=1): | ||
| if np.sum(v) <= z: | ||
| return v | ||
| # z is what the entries need to add up to, e.g. z=1 for probability simplex | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes this is true.  | 
||
| if np.sum(v) <= z: # don't we want the simplex to mean sum == z not sum <= z? | ||
                
       | 
||
| return v # also this doesn't work when v has, say, all negative entries | ||
                
       | 
||
| 
     | 
||
| n_features = v.shape[0] | ||
| u = np.sort(v)[::-1] | ||
| 
        
          
        
         | 
    @@ -74,3 +92,35 @@ def project_simplex(v, z=1): | |
| 
     | 
||
| def project_l1_ball(v, z=1): | ||
| return np.sign(v) * project_simplex(np.abs(v), z) | ||
| 
     | 
||
| class L2Ball(Constraint): | ||
| 
     | 
||
| def __init__(self, mult=1): | ||
| assert mult > 0 | ||
| self.mult = mult | ||
| 
     | 
||
| def _prox(self, x, step=1): | ||
| return x / np.max([norm(x)/self.mult, 1]) | ||
| 
     | 
||
| @property | ||
| def is_proximable(self): | ||
| return True | ||
| 
     | 
||
| class Isotonic(Constraint): | ||
| """Constraint for x1 <= ... <= xn or | ||
| x1 >= ... >= xn """ | ||
| # TODO: allow for general isotonic regression | ||
| # where the order relations are a simple directed | ||
| # graph. For an algorithm see Nemeth and Nemeth, "How to project onto an | ||
| # isotone projection cone", JLLA 2010 | ||
| def __init__(self, increasing=True): | ||
| self.increasing = increasing | ||
| 
     | 
||
| def _prox(self, x, step=1): | ||
| # computes the projection of x onto the monotone cone | ||
| # using the PAVA algorithm | ||
| return isotonic_regression(x, increasing=self.increasing) | ||
| 
     | 
||
| @property | ||
| def is_proximable(self): | ||
| return True | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| import numpy as np | ||
| from unittest import TestCase | ||
| from ya_glm.opt.constraint import convex | ||
| 
     | 
||
| class TestProjectionsOnConstraints(TestCase): | ||
| def setUp(self): | ||
| pass | ||
| 
     | 
||
| def assert_arrays_close(self, test_array, ref_array): | ||
| "Custom assertion for arrays being almost equal" | ||
| try: | ||
| np.testing.assert_allclose(test_array, ref_array) | ||
| except AssertionError: | ||
| self.fail() | ||
| 
     | 
||
| def test_Positive(self): | ||
| cons = convex.Positive() | ||
| v = np.array([-1, 0, 2, 3, -2]) | ||
| self.assert_arrays_close(cons.prox(v), [0, 0, 2, 3, 0]) | ||
| self.assertEqual(cons.prox(-2), 0) | ||
| 
     | 
||
| def test_LinearEquality(self): | ||
| A = np.identity(2) | ||
| b = np.array([1,1]) | ||
| cons = convex.LinearEquality(A, b) | ||
| proj = cons.prox(b) # the proj of b should just be b | ||
| self.assert_arrays_close(proj, b) | ||
| 
     | 
||
| def test_L2Ball(self): | ||
| cons1 = convex.L2Ball(1) | ||
| self.assert_arrays_close(cons1.prox([0,0,0]), [0,0,0]) | ||
| self.assert_arrays_close(cons1.prox([1,0,0]), [1,0,0]) | ||
| self.assert_arrays_close(cons1.prox([0.5,0,0]), [0.5,0,0]) | ||
| self.assert_arrays_close(cons1.prox([1,1,1]), np.array([1,1,1])/np.sqrt(3)) | ||
| self.assert_arrays_close(cons1.prox([1,-1,1]), np.array([1,-1,1])/np.sqrt(3)) | ||
| 
     | 
||
| cons4 = convex.L2Ball(4) | ||
| self.assert_arrays_close(cons4.prox([0,0,0]), [0,0,0]) | ||
| self.assert_arrays_close(cons4.prox([1,0,0]), [1,0,0]) | ||
| self.assert_arrays_close(cons4.prox([0.5,0,0]), [0.5,0,0]) | ||
| self.assert_arrays_close(cons4.prox([-4,3,0]), np.array([-4,3,0])/(5/4)) | ||
| 
     | 
||
| def test_Isotonic(self): | ||
| cons = convex.Isotonic(increasing=True) | ||
| for v in [ | ||
| np.arange(5), | ||
| np.array([-1, 0, 2, 3, -2]), | ||
| np.array([-1, 3, 0, 3, 2]) | ||
| ]: | ||
| result = cons.prox(v) | ||
| lags = result[1:] - result[:-1] | ||
| self.assertTrue((lags >= 0).all()) | ||
| 
     | 
||
| cons = convex.Isotonic(increasing=False) | ||
| for v in [ | ||
| np.arange(5), | ||
| np.array([-1, 0, 2, 3, -2]), | ||
| np.array([-1, 3, 0, 3, 2]) | ||
| ]: | ||
| result = cons.prox(v) | ||
| lags = result[1:] - result[:-1] | ||
| self.assertTrue((lags <= 0).all()) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we should avoid computing this in init (sometimes we initialize this object just to check if it is proximable). Maybe we should cache it the first time we call prox? I've also though about having a
.setup()method that precomputes required data for these functions.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'll do it as a
@propertyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed