
Commit a2ee37e

Lightgcn fix (#602)
* Add git ignore

* Lightgcn fix

  Removed the per-layer normalization; it is not used in LightGCN. Fixed the sum weight to a constant based on the number of layers instead of the current layer index. Allow LightGCN to take blocks. Fixed a requirement error caused by newer dgl versions. Moved edge normalization onto the graph for easier use.

* Lightgcn debug error fix

* Simplified layer normalization and readability

* Easier support of cuda
1 parent c603598 commit a2ee37e
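The weighting fix named above changes how LightGCN combines the per-layer embeddings. A minimal sketch on toy tensors (sizes and names are made up for illustration; this shows the uniform weighting from the LightGCN paper that the message describes, not the diff verbatim):

import torch

num_layers = 3  # L
e = [torch.randn(5, 8) for _ in range(num_layers + 1)]  # e[0]: input embeddings, e[k]: layer-k output

# before the fix: the weight depended on the running layer index,
# giving layer k the factor 1 / k and the input embedding the factor 1
before = e[0] + sum(e[k] * (1 / k) for k in range(1, num_layers + 1))

# after the fix: every term, input embedding included, is weighted by the
# constant 1 / (L + 1), i.e. a uniform mean over all L + 1 embeddings
after = sum(e) / (num_layers + 1)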

File tree

3 files changed (+60, -40 lines)

.gitignore
cornac/models/lightgcn/lightgcn.py
cornac/models/lightgcn/requirements.txt


.gitignore

Lines changed: 7 additions & 0 deletions
@@ -55,3 +55,10 @@ coverage.xml
 
 # Sphinx documentation
 docs/_build/
+
+
+# Environment
+env
+venv
+.env
+.venv

cornac/models/lightgcn/lightgcn.py

Lines changed: 49 additions & 38 deletions
@@ -1,10 +1,11 @@
+from typing import Union, List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 
-
 USER_KEY = "user"
 ITEM_KEY = "item"
 

@@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items):
     }
     num_dict = {USER_KEY: total_users, ITEM_KEY: total_items}
 
-    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)
+    norm_dict = {}
+    for srctype, etype, dsttype in g.canonical_etypes:
+        src, dst = g.edges(etype=(srctype, etype, dsttype))
+        dst_degree = g.in_degrees(
+            dst, etype=(srctype, etype, dsttype)
+        ).float()  # obtain degrees
+        src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+        norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+        g.edata['norm'] = {etype: norm}
+
+    return g
 
 
 class GCNLayer(nn.Module):
-    def __init__(self, norm_dict):
+    def __init__(self):
         super(GCNLayer, self).__init__()
 
-        # norm
-        self.norm_dict = norm_dict
-
     def forward(self, g, feat_dict):
         funcs = {}  # message and reduce functions dict
         # for each type of edges, compute messages and reduce them all
+        g.ndata["h"] = feat_dict
         for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            norm = self.norm_dict[(srctype, etype, dsttype)]
-            # TODO: CHECK HERE
-            messages = norm * feat_dict[srctype][src]  # compute messages
-            g.edges[(srctype, etype, dsttype)].data[
-                etype
-            ] = messages  # store in edata
             funcs[(srctype, etype, dsttype)] = (
-                fn.copy_e(etype, "m"),
-                fn.sum("m", "h"),
+                fn.u_mul_e("h", "norm", "m"),
+                fn.sum("m", "h_n"),
             )  # define message and reduce functions
 
         g.multi_update_all(
             funcs, "sum"
         )  # update all, reduce by first type-wisely then across different types
-        feature_dict = {}
-        for ntype in g.ntypes:
-            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
-            feature_dict[ntype] = h
-        return feature_dict
+        return g.dstdata["h_n"]
 
 
 class Model(nn.Module):
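The rewritten GCNLayer leans on the "norm" edge feature that construct_graph now precomputes. A hedged, self-contained sketch of the same pattern on a toy graph (the graph, feature size, and names like dst_deg are illustrative, not from the commit):

import torch
import dgl
import dgl.function as fn

# toy bipartite interaction graph: 2 users, 2 items, edges in both directions
g = dgl.heterograph(
    {
        ("user", "user_item", "item"): ([0, 0, 1], [0, 1, 1]),
        ("item", "item_user", "user"): ([0, 1, 1], [0, 0, 1]),
    },
    num_nodes_dict={"user": 2, "item": 2},
)

# symmetric normalization per edge, 1 / sqrt(deg(src) * deg(dst)),
# computed once and stored on the graph, as construct_graph above now does
for srctype, etype, dsttype in g.canonical_etypes:
    src, dst = g.edges(etype=(srctype, etype, dsttype))
    dst_deg = g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()
    src_deg = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
    norm = torch.pow(src_deg * dst_deg, -0.5).unsqueeze(1)
    g.edata["norm"] = {(srctype, etype, dsttype): norm}

# one parameter-free propagation step, as in GCNLayer.forward:
# u_mul_e scales each neighbor feature by the edge norm, sum aggregates
g.ndata["h"] = {"user": torch.randn(2, 4), "item": torch.randn(2, 4)}
funcs = {
    cet: (fn.u_mul_e("h", "norm", "m"), fn.sum("m", "h_n"))
    for cet in g.canonical_etypes
}
g.multi_update_all(funcs, "sum")
out = g.dstdata["h_n"]  # dict of propagated "user" and "item" features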
@@ -69,16 +68,7 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
         self.lambda_reg = lambda_reg
         self.device = device
 
-        for srctype, etype, dsttype in g.canonical_etypes:
-            src, dst = g.edges(etype=(srctype, etype, dsttype))
-            dst_degree = g.in_degrees(
-                dst, etype=(srctype, etype, dsttype)
-            ).float()  # obtain degrees
-            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
-            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
-            self.norm_dict[(srctype, etype, dsttype)] = norm
-
-        self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
+        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
 
         self.initializer = nn.init.xavier_uniform_
 
@@ -92,16 +82,37 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
             }
         )
 
-    def forward(self, g, users=None, pos_items=None, neg_items=None):
-        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
-        # obtain features of each layer and concatenate them all
-        user_embeds = h_dict[USER_KEY]
-        item_embeds = h_dict[ITEM_KEY]
+    def forward(self, in_g: Union[dgl.DGLGraph, List[dgl.DGLGraph]], users=None, pos_items=None, neg_items=None):
+
+        if isinstance(in_g, list):
+            h_dict = {ntype: self.feature_dict[ntype][in_g[0].ndata[dgl.NID][ntype]] for ntype in in_g[0].ntypes}
+            user_embeds = h_dict[USER_KEY][in_g[-1].dstnodes(USER_KEY)]
+            item_embeds = h_dict[ITEM_KEY][in_g[-1].dstnodes(ITEM_KEY)]
+            iterator = enumerate(zip(in_g, self.layers))
+        else:
+            h_dict = {ntype: self.feature_dict[ntype] for ntype in in_g.ntypes}
+            # obtain features of each layer and concatenate them all
+            user_embeds = h_dict[USER_KEY]
+            item_embeds = h_dict[ITEM_KEY]
+            iterator = enumerate(zip([in_g] * len(self.layers), self.layers))
+
+        user_embeds = user_embeds * (1 / (len(self.layers) + 1))
+        item_embeds = item_embeds * (1 / (len(self.layers) + 1))
 
-        for k, layer in enumerate(self.layers):
+        for k, (g, layer) in iterator:
             h_dict = layer(g, h_dict)
-            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
-            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+            ue = h_dict[USER_KEY]
+            ie = h_dict[ITEM_KEY]
+
+            if isinstance(in_g, list):
+                ue = ue[in_g[-1].dstnodes(USER_KEY)]
+                ie = ie[in_g[-1].dstnodes(ITEM_KEY)]
+
+            user_embeds = user_embeds + ue
+            item_embeds = item_embeds + ie
+
+        user_embeds = user_embeds / (len(self.layers) + 1)
+        item_embeds = item_embeds / (len(self.layers) + 1)
 
         u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
         pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
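forward now also accepts a list of DGL blocks, which opens the door to mini-batch training with neighbor sampling. A hedged sketch of how that path might be driven (the sampler choice, batch size, and hyperparameters are assumptions for illustration; data_set, total_users, and total_items are stand-ins for real inputs):

import torch
import dgl

g = construct_graph(data_set, total_users, total_items)  # as defined above
model = Model(g, in_size=64, num_layers=3, lambda_reg=1e-4)

# one block per GCN layer; full neighborhoods here, any DGL sampler would do
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
seeds = {
    "user": torch.arange(g.num_nodes("user")),
    "item": torch.arange(g.num_nodes("item")),
}
loader = dgl.dataloading.DataLoader(g, seeds, sampler, batch_size=1024, shuffle=True)

for input_nodes, output_nodes, blocks in loader:
    # blocks is a List[dgl.DGLGraph]; forward gathers its input embeddings via
    # blocks[0].ndata[dgl.NID] and returns embeddings for the destination
    # (seed) nodes of blocks[-1]
    out = model(blocks)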
cornac/models/lightgcn/requirements.txt

Lines changed: 4 additions & 2 deletions
@@ -1,2 +1,4 @@
-torch>=2.0.0
-dgl>=1.1.0
+# Comment in to use cuda 11.X
+#-f https://data.dgl.ai/wheels/cu11X/repo.html
+torch==2.0.0
+dgl==1.1.0
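Pinning torch==2.0.0 and dgl==1.1.0 sidesteps the requirement error the commit message attributes to newer dgl releases. The commented line uses pip's -f/--find-links option: uncommenting it makes pip install -r requirements.txt resolve dgl against DGL's CUDA wheel index, where cu11X is a placeholder for a concrete CUDA tag (e.g. cu117), not a literal URL.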
