7 changes: 7 additions & 0 deletions .gitignore
@@ -55,3 +55,10 @@ coverage.xml

# Sphinx documentation
docs/_build/
+
+
+# Environment
+env
+venv
+.env
+.venv
87 changes: 49 additions & 38 deletions cornac/models/lightgcn/lightgcn.py
@@ -1,10 +1,11 @@
+from typing import Union, List
+
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn


USER_KEY = "user"
ITEM_KEY = "item"

@@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items):
    }
    num_dict = {USER_KEY: total_users, ITEM_KEY: total_items}

-    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    norm_dict = {}
+    for srctype, etype, dsttype in g.canonical_etypes:
+        src, dst = g.edges(etype=(srctype, etype, dsttype))
+        dst_degree = g.in_degrees(
+            dst, etype=(srctype, etype, dsttype)
+        ).float()  # obtain degrees
+        src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+        norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+        g.edata['norm'] = {etype: norm}
+
+    return g
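With this patch, construct_graph both builds the bipartite user-item graph and precomputes LightGCN's symmetric normalization, one coefficient 1/sqrt(deg(src) * deg(dst)) per edge, stored as the edge feature `norm`. A minimal, self-contained sketch of that computation on a made-up toy graph (the interactions below are illustrative, not from the PR):

```python
import dgl
import torch

# Hypothetical interactions: user 0 rated items 0 and 1; user 1 rated item 2.
g = dgl.heterograph(
    {
        ("user", "user_item", "item"): ([0, 0, 1], [0, 1, 2]),
        ("item", "item_user", "user"): ([0, 1, 2], [0, 0, 1]),
    },
    num_nodes_dict={"user": 2, "item": 3},
)

for srctype, etype, dsttype in g.canonical_etypes:
    src, dst = g.edges(etype=(srctype, etype, dsttype))
    dst_degree = g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()
    src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
    # one coefficient per edge: 1 / sqrt(deg(src) * deg(dst))
    norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)
    g.edata["norm"] = {(srctype, etype, dsttype): norm}

# Edge (user 0 -> item 0): deg(user 0) = 2, deg(item 0) = 1, so norm = 1/sqrt(2).
print(g.edata["norm"][("user", "user_item", "item")][0])  # tensor([0.7071])
```

Keeping the norm on the graph rather than in a Python dict means subgraphs and sampled blocks should inherit it as ordinary edge data, which is what lets GCNLayer below drop its norm_dict argument.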


class GCNLayer(nn.Module):
-    def __init__(self, norm_dict):
+    def __init__(self):
        super(GCNLayer, self).__init__()

-        # norm
-        self.norm_dict = norm_dict
-
    def forward(self, g, feat_dict):
        funcs = {}  # message and reduce functions dict
        # for each type of edges, compute messages and reduce them all
+        g.ndata["h"] = feat_dict
        for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            norm = self.norm_dict[(srctype, etype, dsttype)]
-            # TODO: CHECK HERE
-            messages = norm * feat_dict[srctype][src]  # compute messages
-            g.edges[(srctype, etype, dsttype)].data[
-                etype
-            ] = messages  # store in edata
            funcs[(srctype, etype, dsttype)] = (
-                fn.copy_e(etype, "m"),
-                fn.sum("m", "h"),
+                fn.u_mul_e("h", "norm", "m"),
+                fn.sum("m", "h_n"),
            )  # define message and reduce functions

        g.multi_update_all(
            funcs, "sum"
        )  # update all, reduce by first type-wisely then across different types
-        feature_dict = {}
-        for ntype in g.ntypes:
-            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
-            feature_dict[ntype] = h
-        return feature_dict
+        return g.dstdata["h_n"]
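The new GCNLayer body is exactly LightGCN's propagation rule, no weights and no nonlinearity: multiply each source feature by the per-edge norm (fn.u_mul_e) and sum at the destination (fn.sum). The old per-edge Python loop and the trailing l2 normalization are gone. A self-contained check on a toy graph with made-up constant norms, so the result can be verified by hand:

```python
import dgl
import dgl.function as fn
import torch

g = dgl.heterograph(
    {
        ("user", "user_item", "item"): ([0, 0, 1], [0, 1, 2]),
        ("item", "item_user", "user"): ([0, 1, 2], [0, 0, 1]),
    },
    num_nodes_dict={"user": 2, "item": 3},
)
# Constant edge norms (0.5) keep the arithmetic easy to check by hand.
g.edata["norm"] = {cet: torch.full((3, 1), 0.5) for cet in g.canonical_etypes}

feat = {"user": torch.randn(2, 4), "item": torch.randn(3, 4)}
g.ndata["h"] = feat

funcs = {
    cet: (fn.u_mul_e("h", "norm", "m"), fn.sum("m", "h_n"))
    for cet in g.canonical_etypes
}
g.multi_update_all(funcs, "sum")
out = g.dstdata["h_n"]

# Item 0's only neighbour is user 0, so its output is 0.5 * h_user0;
# user 0 aggregates items 0 and 1.
assert torch.allclose(out["item"][0], 0.5 * feat["user"][0])
assert torch.allclose(out["user"][0], 0.5 * (feat["item"][0] + feat["item"][1]))
```

Reading the result from g.dstdata also makes the same layer work unchanged on message-flow blocks, where source and destination node sets differ.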


class Model(nn.Module):
@@ -69,16 +68,7 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
        self.lambda_reg = lambda_reg
        self.device = device

-        for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            dst_degree = g.in_degrees(
-                dst, etype=(srctype, etype, dsttype)
-            ).float()  # obtain degrees
-            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
-            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
-            self.norm_dict[(srctype, etype, dsttype)] = norm
-
-        self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])

        self.initializer = nn.init.xavier_uniform_

@@ -92,16 +82,37 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
            }
        )

-    def forward(self, g, users=None, pos_items=None, neg_items=None):
-        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
-        # obtain features of each layer and concatenate them all
-        user_embeds = h_dict[USER_KEY]
-        item_embeds = h_dict[ITEM_KEY]
-
-        user_embeds = user_embeds * (1 / (len(self.layers) + 1))
-        item_embeds = item_embeds * (1 / (len(self.layers) + 1))
-
-        for k, layer in enumerate(self.layers):
+    def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None):
+        if isinstance(in_g, list):
+            h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes}
+            user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)]
+            item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)]
+            iterator = enumerate(zip(in_g, self.layers))
+        else:
+            h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes}
+            # obtain features of each layer and concatenate them all
+            user_embeds = h_dict[USER_KEY]
+            item_embeds = h_dict[ITEM_KEY]
+            iterator = enumerate(zip([in_g] * len(self.layers), self.layers))
+
+        for k, (g, layer) in iterator:
            h_dict = layer(g, h_dict)
-            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
-            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+            ue = h_dict[USER_KEY]
+            ie = h_dict[ITEM_KEY]
+
+            if isinstance(in_g, list):
+                ue = ue[in_g[-1].dstnodes(USER_KEY)]
+                ie = ie[in_g[-1].dstnodes(ITEM_KEY)]
+
+            user_embeds = user_embeds + ue
+            item_embeds = item_embeds + ie
+
+        user_embeds = user_embeds / (len(self.layers) + 1)
+        item_embeds = item_embeds / (len(self.layers) + 1)

        u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
        pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
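Besides the mini-batch path, the rewrite changes the layer combination. As read from the diff, the old forward seeded the sum with e0/(K+1) but scaled layer k's output by 1/(k+1); the new one accumulates the ego embedding and all K layer outputs unweighted and divides once by K+1, which matches the uniform alpha_k = 1/(K+1) combination of the LightGCN paper. The same combination in isolation (tensor sizes are illustrative):

```python
from typing import List

import torch


def combine_layers(layer_embeds: List[torch.Tensor]) -> torch.Tensor:
    # layer_embeds = [e0, e1, ..., eK]: the free embeddings plus each
    # propagated layer, combined with uniform weight 1 / (K + 1)
    return torch.stack(layer_embeds, dim=0).mean(dim=0)


e0, e1, e2 = (torch.randn(5, 8) for _ in range(3))  # K = 2 layers, 5 nodes
final = combine_layers([e0, e1, e2])
assert torch.allclose(final, (e0 + e1 + e2) / 3)
```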
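forward now also accepts a list of DGL message-flow blocks, one per layer: the embedding table is gathered through blocks[0].ndata[dgl.NID] and the outputs are read at the destination nodes of blocks[-1]. A rough usage sketch of how such a list could be produced; this is an assumption for illustration, not the training loop the PR ships (g and model stand for the graph and Model above; sampler choice and batch size are made up):

```python
import dgl
import torch

num_layers = len(model.layers)  # one block per GCN layer
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(num_layers)
dataloader = dgl.dataloading.DataLoader(
    g,
    {   # seed both node types so user and item embeddings come out
        "user": torch.arange(g.num_nodes("user")),
        "item": torch.arange(g.num_nodes("item")),
    },
    sampler,
    batch_size=256,
    shuffle=True,
)

for input_nodes, output_nodes, blocks in dataloader:
    out = model(blocks)  # embeddings for this batch's destination nodes
```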
4 changes: 2 additions & 2 deletions cornac/models/lightgcn/requirements.txt
@@ -1,2 +1,2 @@
-torch>=2.0.0
-dgl>=1.1.0
+torch==2.0.0
+dgl==1.1.0