Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 84 additions & 81 deletions cornac/models/ncf/backend_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,8 @@
# limitations under the License.
# ============================================================================

import warnings

# disable annoying tensorflow deprecated API warnings
warnings.filterwarnings("ignore", category=UserWarning)

import tensorflow.compat.v1 as tf

tf.logging.set_verbosity(tf.logging.ERROR)
tf.disable_v2_behavior()
import tensorflow as tf


act_functions = {
Expand All @@ -35,88 +28,98 @@
}


def loss_fn(labels, logits):
    """Return the training loss: mean sigmoid cross-entropy plus regularization.

    The per-example sigmoid cross-entropy between `labels` and `logits` is
    averaged, then the regularization loss reported by
    `tf.losses.get_regularization_loss()` is added on top.
    """
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits
    )
    data_term = tf.reduce_mean(per_example)
    penalty = tf.losses.get_regularization_loss()
    return data_term + penalty


# NOTE(review): this span is a rendered diff, not runnable code. Removed lines
# (old `train_fn`, `emb`, `gmf`) and added lines (`get_optimizer`, `GMFLayer`)
# are interleaved without +/- markers and the indentation has been stripped.
#
# Old `train_fn`: selects a `tf.train.*Optimizer` from the `learner` name and
# returns `opt.minimize(loss)`.
def train_fn(loss, learning_rate, learner):
# New `get_optimizer`: returns a `tf.keras.optimizers` instance chosen by the
# lower-cased `learner` name ("adagrad", "rmsprop", "adam"). Any other value
# reaches the `else` branch, whose `return ... SGD(...)` line appears further
# down, after the header of the old `emb` function.
def get_optimizer(learning_rate, learner):
if learner.lower() == "adagrad":
opt = tf.train.AdagradOptimizer(learning_rate=learning_rate, name="optimizer")
return tf.keras.optimizers.Adagrad(learning_rate=learning_rate)
elif learner.lower() == "rmsprop":
opt = tf.train.RMSPropOptimizer(learning_rate=learning_rate, name="optimizer")
return tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
elif learner.lower() == "adam":
opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name="optimizer")
return tf.keras.optimizers.Adam(learning_rate=learning_rate)
else:
opt = tf.train.GradientDescentOptimizer(
learning_rate=learning_rate, name="optimizer"
)

return opt.minimize(loss)


# Old `emb`: under `tf.variable_scope(scope)` it creates a `user_emb` variable
# of shape [num_users, emb_size] (random-normal init, stddev 0.01, L2
# regularizer). Its remaining lines resume after the first `GMFLayer` embedding.
def emb(
uid, iid, num_users, num_items, emb_size, reg_user, reg_item, seed=None, scope="emb"
):
with tf.variable_scope(scope):
user_emb = tf.get_variable(
"user_emb",
shape=[num_users, emb_size],
dtype=tf.float32,
initializer=tf.random_normal_initializer(stddev=0.01, seed=seed),
regularizer=tf.keras.regularizers.L2(reg_user),
# Fallback of `get_optimizer`: plain SGD for any unrecognized `learner` name.
return tf.keras.optimizers.SGD(learning_rate=learning_rate)


# New `GMFLayer`: a Keras layer owning one user and one item `Embedding` table
# of width `emb_size`. Both use RandomNormal(stddev=0.01, seed=seed)
# initialization; the user table is L2-regularized with `reg_user`, the item
# table with `reg_item`.
class GMFLayer(tf.keras.layers.Layer):
def __init__(self, num_users, num_items, emb_size, reg_user, reg_item, seed=None, **kwargs):
super(GMFLayer, self).__init__(**kwargs)
self.num_users = num_users
self.num_items = num_items
self.emb_size = emb_size
self.reg_user = reg_user
self.reg_item = reg_item
self.seed = seed

# Initialize embeddings
self.user_embedding = tf.keras.layers.Embedding(
num_users,
emb_size,
embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed),
embeddings_regularizer=tf.keras.regularizers.L2(reg_user),
name="user_embedding"
)
# Remainder of old `emb`: the `item_emb` variable and the pair of
# `tf.nn.embedding_lookup` results that `emb` returned.
item_emb = tf.get_variable(
"item_emb",
shape=[num_items, emb_size],
dtype=tf.float32,
initializer=tf.random_normal_initializer(stddev=0.01, seed=seed),
regularizer=tf.keras.regularizers.L2(reg_item),
)

return tf.nn.embedding_lookup(user_emb, uid), tf.nn.embedding_lookup(item_emb, iid)


# Old `gmf`: obtained the user/item embeddings from `emb` inside the "GMF"
# variable scope.
def gmf(uid, iid, num_users, num_items, emb_size, reg_user, reg_item, seed=None):
with tf.variable_scope("GMF") as scope:
user_emb, item_emb = emb(
uid=uid,
iid=iid,
num_users=num_users,
num_items=num_items,
emb_size=emb_size,
reg_user=reg_user,
reg_item=reg_item,
seed=seed,
scope=scope,

self.item_embedding = tf.keras.layers.Embedding(
num_items,
emb_size,
embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed),
embeddings_regularizer=tf.keras.regularizers.L2(reg_item),
name="item_embedding"
)

# `call` unpacks `inputs` into (user_ids, item_ids), looks both up in their
# embedding tables and returns the element-wise product of the two results.
def call(self, inputs):
user_ids, item_ids = inputs
user_emb = self.user_embedding(user_ids)
item_emb = self.item_embedding(item_ids)
return tf.multiply(user_emb, item_emb)


# NOTE(review): this span is a rendered diff, not runnable code. Removed lines
# (old `mlp`) and added lines (`MLPLayer`) are interleaved without +/- markers
# and the indentation has been stripped.
#
# Old `mlp`: fetched user/item embeddings of width int(layers[0] / 2) from
# `emb` inside the "MLP" variable scope, both regularized with reg_layers[0].
def mlp(uid, iid, num_users, num_items, layers, reg_layers, act_fn, seed=None):
with tf.variable_scope("MLP") as scope:
user_emb, item_emb = emb(
uid=uid,
iid=iid,
num_users=num_users,
num_items=num_items,
emb_size=int(layers[0] / 2),
reg_user=reg_layers[0],
reg_item=reg_layers[0],
seed=seed,
scope=scope,
# New `MLPLayer`: a Keras layer owning user and item `Embedding` tables, each
# of width int(layers[0] / 2) and L2-regularized with reg_layers[0], plus one
# `Dense` layer per entry of layers[1:].
class MLPLayer(tf.keras.layers.Layer):
def __init__(self, num_users, num_items, layers, reg_layers, act_fn, seed=None, **kwargs):
super(MLPLayer, self).__init__(**kwargs)
self.num_users = num_users
self.num_items = num_items
self.layers = layers
self.reg_layers = reg_layers
self.act_fn = act_fn
self.seed = seed

# Initialize embeddings
self.user_embedding = tf.keras.layers.Embedding(
num_users,
int(layers[0] / 2),
embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed),
embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]),
name="user_embedding"
)
# Remainder of old `mlp`: concatenated the two embeddings and stacked
# `tf.layers.dense` layers named "layer1", "layer2", ... on top.
interaction = tf.concat([user_emb, item_emb], axis=-1)
for i, layer in enumerate(layers[1:]):
interaction = tf.layers.dense(
interaction,
units=layer,
name="layer{}".format(i + 1),
activation=act_functions.get(act_fn, tf.nn.relu),
kernel_initializer=tf.initializers.lecun_uniform(seed),
kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]),

self.item_embedding = tf.keras.layers.Embedding(
num_items,
int(layers[0] / 2),
embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed),
embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]),
name="item_embedding"
)

# Define dense layers
# Each Dense layer takes its width from layers[1:], its activation from
# `act_functions` (falling back to tf.nn.relu when `act_fn` is not a key),
# LecunUniform initialization and an L2 penalty of reg_layers[i + 1].
self.dense_layers = []
for i, layer_size in enumerate(layers[1:]):
self.dense_layers.append(
tf.keras.layers.Dense(
layer_size,
activation=act_functions.get(act_fn, tf.nn.relu),
kernel_initializer=tf.keras.initializers.LecunUniform(seed=seed),
kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]),
name=f"layer{i+1}"
)
)

# `call` unpacks `inputs` into (user_ids, item_ids), concatenates the two
# embedding lookups along the last axis and feeds the result through the
# dense layers in the order they were created.
def call(self, inputs):
user_ids, item_ids = inputs
user_emb = self.user_embedding(user_ids)
item_emb = self.item_embedding(item_ids)
interaction = tf.concat([user_emb, item_emb], axis=-1)

for layer in self.dense_layers:
interaction = layer(interaction)

return interaction
88 changes: 39 additions & 49 deletions cornac/models/ncf/recom_gmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,55 +111,45 @@ def __init__(
########################
## TensorFlow backend ##
########################
def _build_graph_tf(self):
# Builds the TF1 (compat.v1) graph for the GMF model and stores the
# placeholders, prediction, loss, train op, initializer and saver on `self`.
# NOTE(review): indentation is stripped in this view, so the exact extent of
# the `with self.graph.as_default():` block cannot be confirmed from here.
import tensorflow.compat.v1 as tf
from .backend_tf import gmf, loss_fn, train_fn

self.graph = tf.Graph()
with self.graph.as_default():
tf.set_random_seed(self.seed)

# Inputs: 1-D int32 user/item ids and a [None, 1] float32 label column.
self.user_id = tf.placeholder(shape=[None], dtype=tf.int32, name="user_id")
self.item_id = tf.placeholder(shape=[None], dtype=tf.int32, name="item_id")
self.labels = tf.placeholder(
shape=[None, 1], dtype=tf.float32, name="labels"
)

# The same `self.reg` value regularizes both the user and the item side.
self.interaction = gmf(
uid=self.user_id,
iid=self.item_id,
num_users=self.num_users,
num_items=self.num_items,
emb_size=self.num_factors,
reg_user=self.reg,
reg_item=self.reg,
seed=self.seed,
)

# Single-unit dense layer producing the raw score; the sigmoid is applied
# separately so the loss can consume the un-squashed logits.
logits = tf.layers.dense(
self.interaction,
units=1,
name="logits",
kernel_initializer=tf.initializers.lecun_uniform(self.seed),
)
self.prediction = tf.nn.sigmoid(logits)

self.loss = loss_fn(labels=self.labels, logits=logits)
self.train_op = train_fn(
self.loss, learning_rate=self.lr, learner=self.learner
)

self.initializer = tf.global_variables_initializer()
self.saver = tf.train.Saver()

# Session setup is delegated to a helper defined outside this view.
self._sess_init_tf()

def _score_tf(self, user_idx, item_idx):
    """Run the prediction op for one user against one item or all items.

    When `item_idx` is None the item placeholder is fed every index in
    `range(self.num_items)`; otherwise only `item_idx`. The user placeholder
    always receives the single-element list `[user_idx]`.
    """
    if item_idx is None:
        candidates = np.arange(self.num_items)
    else:
        candidates = [item_idx]

    return self.sess.run(
        self.prediction,
        feed_dict={self.user_id: [user_idx], self.item_id: candidates},
    )
def _build_model_tf(self):
    """Assemble and return the Keras functional model for GMF.

    Two int32 inputs of shape (1,) feed a `GMFLayer`; its output goes through
    a single-unit dense layer named "logits" and a sigmoid activation named
    "prediction". The returned model is named "GMF".
    """
    import tensorflow as tf
    from .backend_tf import GMFLayer

    keras_layers = tf.keras.layers

    # One integer id per example for each side of the interaction.
    user_input = keras_layers.Input(shape=(1,), dtype=tf.int32, name="user_input")
    item_input = keras_layers.Input(shape=(1,), dtype=tf.int32, name="item_input")

    # The same regularization weight is used for users and items.
    gmf_vector = GMFLayer(
        num_users=self.num_users,
        num_items=self.num_items,
        emb_size=self.num_factors,
        reg_user=self.reg,
        reg_item=self.reg,
        seed=self.seed,
        name="gmf_layer",
    )([user_input, item_input])

    # Scalar score followed by a sigmoid.
    output_init = tf.keras.initializers.LecunUniform(seed=self.seed)
    logits = keras_layers.Dense(
        1, kernel_initializer=output_init, name="logits"
    )(gmf_vector)
    prediction = keras_layers.Activation("sigmoid", name="prediction")(logits)

    # Only the sigmoid output is exposed by the model; the logits tensor is
    # an intermediate and is not part of `outputs`.
    return tf.keras.Model(
        inputs=[user_input, item_input],
        outputs=prediction,
        name="GMF",
    )

#####################
## PyTorch backend ##
Expand Down
93 changes: 39 additions & 54 deletions cornac/models/ncf/recom_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,60 +116,45 @@ def __init__(
########################
## TensorFlow backend ##
########################
def _build_graph_tf(self):
# Builds the TF1 (compat.v1) graph for the MLP model and stores the
# placeholders, prediction, loss, train op, initializer and saver on `self`.
# NOTE(review): indentation is stripped in this view, so the exact extent of
# the `with self.graph.as_default():` block cannot be confirmed from here.
import tensorflow.compat.v1 as tf
from .backend_tf import mlp, loss_fn, train_fn

self.graph = tf.Graph()
with self.graph.as_default():
tf.set_random_seed(self.seed)

# Inputs: 1-D int32 user/item ids and a [None, 1] float32 label column.
self.user_id = tf.placeholder(shape=[None], dtype=tf.int32, name="user_id")
self.item_id = tf.placeholder(shape=[None], dtype=tf.int32, name="item_id")
self.labels = tf.placeholder(
shape=[None, 1], dtype=tf.float32, name="labels"
)

# `reg_layers` repeats `self.reg` once per entry of `self.layers`.
self.interaction = mlp(
uid=self.user_id,
iid=self.item_id,
num_users=self.num_users,
num_items=self.num_items,
layers=self.layers,
reg_layers=[self.reg] * len(self.layers),
act_fn=self.act_fn,
seed=self.seed,
)
# Single-unit dense layer producing the raw score; the sigmoid is applied
# separately so the loss can consume the un-squashed logits.
logits = tf.layers.dense(
self.interaction,
units=1,
name="logits",
kernel_initializer=tf.initializers.lecun_uniform(self.seed),
)
self.prediction = tf.nn.sigmoid(logits)

self.loss = loss_fn(labels=self.labels, logits=logits)
self.train_op = train_fn(
self.loss, learning_rate=self.lr, learner=self.learner
)

self.initializer = tf.global_variables_initializer()
self.saver = tf.train.Saver()

# Session setup is delegated to a helper defined outside this view.
self._sess_init_tf()

def _score_tf(self, user_idx, item_idx):
    """Run the prediction op for one user against one item or all items.

    With `item_idx` given, both placeholders are fed single-element lists.
    With `item_idx` None, the user placeholder is fed `user_idx` repeated
    `self.num_items` times and the item placeholder every item index.
    """
    if item_idx is not None:
        users = [user_idx]
        items = [item_idx]
    else:
        # The user id is tiled so both feeds have one entry per item.
        users = np.ones(self.num_items) * user_idx
        items = np.arange(self.num_items)

    return self.sess.run(
        self.prediction,
        feed_dict={self.user_id: users, self.item_id: items},
    )
def _build_model_tf(self):
    """Assemble and return the Keras functional model for MLP.

    Two int32 inputs of shape (1,) feed an `MLPLayer`; its output goes through
    a single-unit dense layer named "logits" and a sigmoid activation named
    "prediction". The returned model is named "MLP".
    """
    import tensorflow as tf
    from .backend_tf import MLPLayer

    keras_layers = tf.keras.layers

    # One integer id per example for each side of the interaction.
    user_input = keras_layers.Input(shape=(1,), dtype=tf.int32, name="user_input")
    item_input = keras_layers.Input(shape=(1,), dtype=tf.int32, name="item_input")

    # Every entry of `self.layers` gets the same regularization weight.
    per_layer_reg = [self.reg] * len(self.layers)
    mlp_vector = MLPLayer(
        num_users=self.num_users,
        num_items=self.num_items,
        layers=self.layers,
        reg_layers=per_layer_reg,
        act_fn=self.act_fn,
        seed=self.seed,
        name="mlp_layer",
    )([user_input, item_input])

    # Scalar score followed by a sigmoid.
    output_init = tf.keras.initializers.LecunUniform(seed=self.seed)
    logits = keras_layers.Dense(
        1, kernel_initializer=output_init, name="logits"
    )(mlp_vector)
    prediction = keras_layers.Activation("sigmoid", name="prediction")(logits)

    return tf.keras.Model(
        inputs=[user_input, item_input],
        outputs=prediction,
        name="MLP",
    )

#####################
## PyTorch backend ##
Expand Down
Loading
Loading