
Commit ed9ed07

Implement SANSA model (scalable variant of EASE) (#661)
1 parent 167a212 commit ed9ed07

File tree

9 files changed: +406 −0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ The table below lists the recommendation models/algorithms featured in Cornac. E
  | :--: | --------------- | :--: | :---------: | :-----: |
  | 2024 | [Comparative Aspects and Opinions Ranking for Recommendation Explanations (Companion)](cornac/models/companion), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.companion.recom_companion), [paper](https://lthoang.com/assets/publications/mlj24.pdf) | Hybrid / Sentiment / Explainable | CPU | [quick-start](examples/companion_example.py)
  | | [Hypergraphs with Attention on Reviews (HypAR)](cornac/models/hypar), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.hypar.recom_hypar), [paper](https://doi.org/10.1007/978-3-031-56027-9_14)| Hybrid / Sentiment / Explainable | [requirements](cornac/models/hypar/requirements_cu116.txt), CPU / GPU | [quick-start](https://github.com/PreferredAI/HypAR)
+ | 2023 | [Scalable Approximate NonSymmetric Autoencoder (SANSA)](cornac/models/sansa), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.sansa.recom_sansa), [paper](https://dl.acm.org/doi/10.1145/3604915.3608827) | Collaborative Filtering | [requirements](cornac/models/sansa/requirements.txt), CPU | [quick-start](examples/sansa_movielens.py), [150k-items](examples/sansa_tradesy.py)
  | 2022 | [Disentangled Multimodal Representation Learning for Recommendation (DMRL)](cornac/models/dmrl), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.dmrl.recom_dmrl), [paper](https://arxiv.org/pdf/2203.05406.pdf) | Content-Based / Text & Image | [requirements](cornac/models/dmrl/requirements.txt), CPU / GPU | [quick-start](examples/dmrl_example.py)
  | 2021 | [Bilateral Variational Autoencoder for Collaborative Filtering (BiVAECF)](cornac/models/bivaecf), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.bivaecf.recom_bivaecf), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441759) | Collaborative Filtering / Content-Based | [requirements](cornac/models/bivaecf/requirements.txt), CPU / GPU | [quick-start](https://github.com/PreferredAI/bi-vae), [deep-dive](https://github.com/recommenders-team/recommenders/blob/main/examples/02_model_collaborative_filtering/cornac_bivae_deep_dive.ipynb)
  | | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [docs](https://cornac.readthedocs.io/en/stable/api_ref/models.html#module-cornac.models.causalrec.recom_causalrec), [paper](https://arxiv.org/abs/2107.02390) | Content-Based / Image | [requirements](cornac/models/causalrec/requirements.txt), CPU / GPU | [quick-start](examples/causalrec_clothing.py)

cornac/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
  from .pcrl import PCRL
  from .pmf import PMF
  from .recvae import RecVAE
+ from .sansa import SANSA
  from .sbpr import SBPR
  from .skm import SKMeans
  from .sorec import SoRec

cornac/models/sansa/README.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# Dependencies
Training of SANSA uses [scikit-sparse](https://github.com/scikit-sparse/scikit-sparse), which depends on the [SuiteSparse](https://github.com/DrTimothyAldenDavis/SuiteSparse) numerical library. To install SuiteSparse on Ubuntu or macOS, run the appropriate command below:

```
# Ubuntu
sudo apt-get install libsuitesparse-dev

# macOS
brew install suite-sparse
```

After installing SuiteSparse, install the Python dependencies with `pip install -r cornac/models/sansa/requirements.txt`.
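A quick sanity check that the native library is visible from Python (a sketch, assuming the requirements above are installed; `cholesky` is the CHOLMOD entry point exposed by scikit-sparse):

```
# Should import cleanly once SuiteSparse and scikit-sparse are installed
from sksparse.cholmod import cholesky
```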

cornac/models/sansa/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .recom_sansa import SANSA

cornac/models/sansa/recom_sansa.py

Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
import numpy as np
import scipy.sparse as sp

from ..recommender import Recommender
from ..recommender import ANNMixin, MEASURE_DOT
from ...exception import ScoreException


class SANSA(Recommender, ANNMixin):
    """Scalable Approximate NonSymmetric Autoencoder for Collaborative Filtering.

    Parameters
    ----------
    name: string, optional, default: 'SANSA'
        The name of the recommender model.

    l2: float, optional, default: 1.0
        L2-norm regularization parameter λ ∈ R+.

    weight_matrix_density: float, optional, default: 1e-3
        Density of weight matrices.

    compute_gramian: boolean, optional, default: True
        Indicates whether training input X is a user-item matrix (represents a bipartite graph) \
        or an item-item matrix (e.g., co-occurrence matrix; not a bipartite graph).

    factorizer_class: string, optional, default: 'ICF'
        Class of Cholesky factorizer. Supported values:
        - 'CHOLMOD' - exact Cholesky factorization using the CHOLMOD algorithm, followed by pruning.
        - 'ICF' - incomplete Cholesky factorization (i.e., pruning on the fly).
        CHOLMOD provides a higher-quality approximate factorization at an increased cost. \
        ICF is less accurate but more scalable (recommended when num_items >= ~50K-100K).
        Note that ICF uses additional matrix preprocessing and hence a different (smaller) l2 regularization.

    factorizer_shift_step: float, optional, default: 1e-3
        Used with the ICF factorizer.
        Incomplete factorization may break down (zero division), indicating the need for increased l2 regularization.
        'factorizer_shift_step' is the initial increase in l2 regularization (after the first breakdown).

    factorizer_shift_multiplier: float, optional, default: 2.0
        Used with the ICF factorizer.
        Multiplier for the factorizer shift. After the k-th breakdown, the additional l2 regularization is \
        'factorizer_shift_step' * 'factorizer_shift_multiplier'^(k-1).

    inverter_scans: integer, optional, default: 3
        Number of scans repairing the approximate inverse factor. Scans repair all columns with residual below \
        a certain threshold, and this threshold goes to 0 in later scans. More scans give more accurate results \
        but take longer. We recommend values between 0 and 5; use lower values if scans take too long.

    inverter_finetune_steps: integer, optional, default: 10
        Number of finetuning steps; each step repairs a small portion of the columns with the highest residuals.
        All finetune steps take (roughly) the same amount of time.
        We recommend values between 0 and 30.

    use_absolute_value_scores: boolean, optional, default: False
        Following https://dl.acm.org/doi/abs/10.1145/3640457.3688179, it is recommended for EASE-like models to consider \
        the absolute value of scores in situations when X^TX is sparse.

    trainable: boolean, optional, default: True
        When False, the model is not trained and Cornac assumes that the model is already \
        trained.

    verbose: boolean, optional, default: True
        When True, some running logs are displayed.

    seed: int, optional, default: None
        Random seed for parameters initialization.

    References
    ----------
    * Martin Spišák, Radek Bartyzal, Antonín Hoskovec, Ladislav Peska, and Miroslav Tůma. 2023. \
      Scalable Approximate NonSymmetric Autoencoder for Collaborative Filtering. \
      In Proceedings of the 17th ACM Conference on Recommender Systems (RecSys '23). \
      Association for Computing Machinery, New York, NY, USA, 763–770. https://doi.org/10.1145/3604915.3608827

    * SANSA GitHub Repository: https://github.com/glami/sansa
    """

    def __init__(
        self,
        name="SANSA",
        l2=1.0,
        weight_matrix_density=1e-3,
        compute_gramian=True,
        factorizer_class="ICF",
        factorizer_shift_step=1e-3,
        factorizer_shift_multiplier=2.0,
        inverter_scans=3,
        inverter_finetune_steps=10,
        use_absolute_value_scores=False,
        trainable=True,
        verbose=True,
        seed=None,
        W1=None,  # weights[0] (sp.csr_matrix)
        W2=None,  # weights[1] (sp.csr_matrix)
        X=None,  # user-item interaction matrix (sp.csr_matrix)
    ):
        Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose)
        self.l2 = l2
        self.weight_matrix_density = weight_matrix_density
        self.compute_gramian = compute_gramian
        self.factorizer_class = factorizer_class
        self.factorizer_shift_step = factorizer_shift_step
        self.factorizer_shift_multiplier = factorizer_shift_multiplier
        self.inverter_scans = inverter_scans
        self.inverter_finetune_steps = inverter_finetune_steps
        self.use_absolute_value_scores = use_absolute_value_scores
        self.verbose = verbose
        self.seed = seed
        self.X = X.astype(np.float32) if X is not None and X.dtype != np.float32 else X
        self.weights = (W1, W2)

    def fit(self, train_set, val_set=None):
        """Fit the model to observations.

        Parameters
        ----------
        train_set: :obj:`cornac.data.Dataset`, required
            User-Item preference data as well as additional modalities.

        val_set: :obj:`cornac.data.Dataset`, optional, default: None
            User-Item preference data for model selection purposes (e.g., early stopping).

        Returns
        -------
        self : object
        """
        Recommender.fit(self, train_set, val_set)

        from sansa.core import (
            FactorizationMethod,
            GramianFactorizer,
            CHOLMODGramianFactorizerConfig,
            ICFGramianFactorizerConfig,
            UnitLowerTriangleInverter,
            UMRUnitLowerTriangleInverterConfig,
        )
        from sansa.utils import (
            get_squared_norms_along_compressed_axis,
            inplace_scale_along_compressed_axis,
            inplace_scale_along_uncompressed_axis,
        )

        # User-item interaction matrix (sp.csr_matrix)
        self.X = train_set.matrix.astype(np.float32)

        if self.factorizer_class == "CHOLMOD":
            self.factorizer_config = CHOLMODGramianFactorizerConfig()
        else:
            self.factorizer_config = ICFGramianFactorizerConfig(
                factorization_shift_step=self.factorizer_shift_step,  # initial diagonal shift if incomplete factorization fails
                factorization_shift_multiplier=self.factorizer_shift_multiplier,  # multiplier for the shift for subsequent attempts
            )
        self.factorizer = GramianFactorizer.from_config(self.factorizer_config)
        self.factorization_method = self.factorizer_config.factorization_method

        self.inverter_config = UMRUnitLowerTriangleInverterConfig(
            scans=self.inverter_scans,  # number of scans through all columns of the matrix
            finetune_steps=self.inverter_finetune_steps,  # number of finetuning steps, targeting worst columns
        )
        self.inverter = UnitLowerTriangleInverter.from_config(self.inverter_config)

        # Create a working copy of the user-item matrix
        X = self.X.copy()

        if self.factorization_method == FactorizationMethod.ICF:
            # Scale matrix X
            if self.compute_gramian:
                # In-place scale columns of X by the square roots of the column norms of X^TX.
                da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X.T @ X)))
                da[da == 0] = 1  # ignore zero elements
                inplace_scale_along_uncompressed_axis(X, 1 / da)  # CSR column scaling
                del da
            else:
                # In-place scale rows and columns of X by the square roots of the row norms of X.
                da = np.sqrt(np.sqrt(get_squared_norms_along_compressed_axis(X)))
                da[da == 0] = 1  # ignore zero elements
                inplace_scale_along_uncompressed_axis(X, 1 / da)  # CSR column scaling
                inplace_scale_along_compressed_axis(X, 1 / da)  # CSR row scaling
                del da

        # Compute LDL^T decomposition of
        # - P(X^TX + self.l2 * I)P^T if compute_gramian=True
        # - P(X + self.l2 * I)P^T if compute_gramian=False
        if self.verbose:
            print("Computing LDL^T decomposition of permuted item-item matrix...")
        L, D, p = self.factorizer.approximate_ldlt(
            X,
            self.l2,
            self.weight_matrix_density,
            compute_gramian=self.compute_gramian,
        )
        del X

        # Compute approximate inverse of L using the selected method
        if self.verbose:
            print("Computing approximate inverse of L...")
        L_inv = self.inverter.invert(L)
        del L

        # Construct W = L_inv @ P
        inv_p = np.argsort(p)
        W = L_inv[:, inv_p]
        del L_inv

        # Construct W_r = D^{-1} W (so that A^{-1} = W.T @ W_r)
        W_r = W.copy()
        inplace_scale_along_uncompressed_axis(W_r, 1 / D.diagonal())

        # Extract the diagonal entries of A^{-1} = W.T @ D^{-1} @ W
        diag = W.copy()
        diag.data = diag.data**2
        inplace_scale_along_uncompressed_axis(diag, 1 / D.diagonal())
        diagsum = diag.sum(axis=0)  # sum over k of W[k, i]^2 / D[k] = diagonal of A^{-1}
        del diag
        diag = np.asarray(diagsum)[0]

        # Divide the columns of the inverse by the negative diagonal entries,
        # equivalent to dividing the columns of W_r by the negative diagonal entries
        inplace_scale_along_compressed_axis(W_r, -1 / diag)
        self.weights = (W.T.tocsr(), W_r.tocsr())

        return self

    def forward(self, X: sp.csr_matrix) -> sp.csr_matrix:
        """
        Forward pass.
        """
        latent = X @ self.weights[0]
        out = latent @ self.weights[1]
        return out

    def score(self, user_idx, item_idx=None):
        """Predict the scores/ratings of a user for an item.

        Parameters
        ----------
        user_idx: int, required
            The index of the user for whom to perform score prediction.

        item_idx: int, optional, default: None
            The index of the item for which to perform score prediction.
            If None, scores for all known items will be returned.

        Returns
        -------
        res : A scalar or a Numpy array
            Relative scores that the user gives to the item or to all known items
        """
        if self.is_unknown_user(user_idx):
            raise ScoreException("Can't make score prediction for user %d" % user_idx)

        if item_idx is not None and self.is_unknown_item(item_idx):
            raise ScoreException("Can't make score prediction for item %d" % item_idx)

        scores = self.forward(self.X[user_idx]).toarray().reshape(-1)
        if self.use_absolute_value_scores:
            scores = np.abs(scores)
        if item_idx is None:
            return scores
        return scores[item_idx]

    def get_vector_measure(self):
        """Getting a valid choice of vector measurement in ANNMixin._measures.

        Returns
        -------
        measure: MEASURE_DOT
            Dot product aka. inner product
        """
        return MEASURE_DOT

    def get_user_vectors(self):
        """Getting a matrix of user vectors serving as query for ANN search.

        Returns
        -------
        out: numpy.array
            Matrix of user vectors for all users available in the model.
        """
        return self.X @ self.weights[0]

    def get_item_vectors(self):
        """Getting a matrix of item vectors used for building the index for ANN search.

        Returns
        -------
        out: numpy.array
            Matrix of item vectors for all items available in the model.
        """
        return self.weights[1]
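For intuition about the two stored factors: `weights[0] @ weights[1]` approximates the EASE closed-form weight matrix minus the identity. Below is a small dense-algebra sketch of that target quantity (illustrative only; toy sizes and plain NumPy instead of the sparse approximate factorization that `fit()` actually computes):

```
import numpy as np

rng = np.random.default_rng(0)
X = (rng.random((100, 20)) > 0.8).astype(np.float32)  # toy implicit-feedback matrix
l2 = 1.0

A = X.T @ X + l2 * np.eye(20, dtype=np.float32)  # regularized Gramian
A_inv = np.linalg.inv(A)
# EASE closed form: B = I - A_inv / diag(A_inv)  (zero diagonal)
B = np.eye(20, dtype=np.float32) - A_inv / np.diag(A_inv)
# SANSA's factors satisfy weights[0] @ weights[1] ≈ B - I = -A_inv / diag(A_inv),
# so forward(X) ≈ X @ B - X: subtracting X only lowers the scores of items the
# user has already interacted with and leaves the ranking of unseen items unchanged.
target = B - np.eye(20, dtype=np.float32)
```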
cornac/models/sansa/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
sansa >= 1.1.0

examples/README.md

Lines changed: 4 additions & 0 deletions
@@ -104,6 +104,10 @@
  [recvae_example.py](recvae_example.py) - New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (RecVAE).

+ [sansa_movielens.py](sansa_movielens.py) - Scalable Approximate NonSymmetric Autoencoder (SANSA) with MovieLens 1M dataset.
+
+ [sansa_tradesy.py](sansa_tradesy.py) - Scalable Approximate NonSymmetric Autoencoder (SANSA) with Tradesy dataset.
+
  [skm_movielens.py](skm_movielens.py) - SKMeans vs BPR on MovieLens data.

  [svd_example.py](svd_example.py) - Singular Value Decomposition (SVD) with MovieLens dataset.

examples/sansa_movielens.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
"""Example SANSA (Scalable Approximate NonSymmetric Autoencoder for Collaborative Filtering) on MovieLens data"""

import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit


# Load user-item feedback
data = movielens.load_feedback(variant="1M")

# Instantiate an evaluation method to split data into train and test sets.
ratio_split = RatioSplit(
    data=data,
    test_size=0.2,
    exclude_unknowns=True,
    verbose=True,
    seed=123,
)

sansa_cholmod = cornac.models.SANSA(
    name="SANSA (CHOLMOD)",
    l2=500.0,
    weight_matrix_density=1e-2,
    compute_gramian=True,
    factorizer_class="CHOLMOD",
    factorizer_shift_step=1e-3,
    factorizer_shift_multiplier=2.0,
    inverter_scans=5,
    inverter_finetune_steps=20,
    use_absolute_value_scores=False,
)

sansa_icf = cornac.models.SANSA(
    name="SANSA (ICF)",
    l2=10.0,
    weight_matrix_density=1e-2,
    compute_gramian=True,
    factorizer_class="ICF",
    factorizer_shift_step=1e-3,
    factorizer_shift_multiplier=2.0,
    inverter_scans=5,
    inverter_finetune_steps=20,
    use_absolute_value_scores=False,
)


# Instantiate evaluation measures
rec_20 = cornac.metrics.Recall(k=20)
rec_50 = cornac.metrics.Recall(k=50)
ndcg_100 = cornac.metrics.NDCG(k=100)


# Put everything together into an experiment and run it
cornac.Experiment(
    eval_method=ratio_split,
    models=[sansa_cholmod, sansa_icf],
    metrics=[rec_20, rec_50, ndcg_100],
    user_based=True,  # If `False`, results will be averaged over the number of ratings.
    save_dir=None,
).run()
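After the experiment runs, the fitted models can also be queried directly via `score()`. A short sketch (assuming user index 0 exists in the training split; indices follow Cornac's internal mapping):

```
# Score all known items for one user with the fitted ICF model
scores = sansa_icf.score(user_idx=0)
top10 = scores.argsort()[-10:][::-1]  # item indices of the 10 highest scores
print(top10)
```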
