1+ from typing import Union , List
2+
13import torch
24import torch .nn as nn
35import torch .nn .functional as F
46import dgl
57import dgl .function as fn
68
7-
89USER_KEY = "user"
910ITEM_KEY = "item"
1011
@@ -26,40 +27,38 @@ def construct_graph(data_set, total_users, total_items):
2627 }
2728 num_dict = {USER_KEY : total_users , ITEM_KEY : total_items }
2829
29- return dgl .heterograph (data_dict , num_nodes_dict = num_dict )
30+ g = dgl .heterograph (data_dict , num_nodes_dict = num_dict )
31+ norm_dict = {}
32+ for srctype , etype , dsttype in g .canonical_etypes :
33+ src , dst = g .edges (etype = (srctype , etype , dsttype ))
34+ dst_degree = g .in_degrees (
35+ dst , etype = (srctype , etype , dsttype )
36+ ).float () # obtain degrees
37+ src_degree = g .out_degrees (src , etype = (srctype , etype , dsttype )).float ()
38+ norm = torch .pow (src_degree * dst_degree , - 0.5 ).unsqueeze (1 ) # compute norm
39+ g .edata ['norm' ] = {etype : norm }
40+
41+ return g
3042
3143
class GCNLayer(nn.Module):
    """One LightGCN-style propagation layer over a heterogeneous user-item graph.

    The layer has no learnable parameters: it performs a single hop of
    symmetric-normalized neighborhood aggregation. It assumes every
    canonical edge type of the input graph carries a precomputed ``'norm'``
    edge feature (the 1/sqrt(deg_src * deg_dst) coefficient computed at
    graph-construction time).
    """

    # NOTE: no __init__ needed — the layer holds no state, so the default
    # nn.Module constructor is sufficient.

    def forward(self, g, feat_dict):
        """Propagate node features one hop along every edge type.

        Parameters
        ----------
        g : dgl.DGLGraph
            Heterogeneous graph (or sampled block) whose edges carry a
            ``'norm'`` feature for each canonical edge type.
        feat_dict : dict[str, torch.Tensor]
            Mapping from node type to its feature matrix.

        Returns
        -------
        dict[str, torch.Tensor]
            Aggregated features for the destination nodes, keyed by node type.
        """
        # local_scope() keeps the temporary "h" / "h_n" node features from
        # leaking into (and permanently mutating) the caller's graph —
        # without it, repeated forward passes leave stale features on g.
        with g.local_scope():
            g.ndata["h"] = feat_dict
            # For each canonical edge type: message m = h_src * norm,
            # then sum the messages into "h_n" on the destination nodes.
            funcs = {
                cetype: (fn.u_mul_e("h", "norm", "m"), fn.sum("m", "h_n"))
                for cetype in g.canonical_etypes
            }
            # Reduce per edge type first, then sum across edge types.
            g.multi_update_all(funcs, "sum")
            return g.dstdata["h_n"]
6362
6463
6564class Model (nn .Module ):
@@ -69,16 +68,7 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
6968 self .lambda_reg = lambda_reg
7069 self .device = device
7170
72- for srctype , etype , dsttype in g .canonical_etypes :
73- src , dst = g .edges (etype = (srctype , etype , dsttype ))
74- dst_degree = g .in_degrees (
75- dst , etype = (srctype , etype , dsttype )
76- ).float () # obtain degrees
77- src_degree = g .out_degrees (src , etype = (srctype , etype , dsttype )).float ()
78- norm = torch .pow (src_degree * dst_degree , - 0.5 ).unsqueeze (1 ) # compute norm
79- self .norm_dict [(srctype , etype , dsttype )] = norm
80-
81- self .layers = nn .ModuleList ([GCNLayer (self .norm_dict ) for _ in range (num_layers )])
71+ self .layers = nn .ModuleList ([GCNLayer () for _ in range (num_layers )])
8272
8373 self .initializer = nn .init .xavier_uniform_
8474
@@ -92,16 +82,37 @@ def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
9282 }
9383 )
9484
95- def forward (self , g , users = None , pos_items = None , neg_items = None ):
96- h_dict = {ntype : self .feature_dict [ntype ] for ntype in g .ntypes }
97- # obtain features of each layer and concatenate them all
98- user_embeds = h_dict [USER_KEY ]
99- item_embeds = h_dict [ITEM_KEY ]
85+ def forward (self , in_g : Union [dgl .DGLGraph , List [dgl .DGLGraph ]], users = None , pos_items = None , neg_items = None ):
86+
87+ if isinstance (in_g , list ):
88+ h_dict = {ntype : self .feature_dict [ntype ][in_g [0 ].ndata [dgl .NID ][ntype ]] for ntype in in_g [0 ].ntypes }
89+ user_embeds = h_dict [USER_KEY ][in_g [- 1 ].dstnodes (USER_KEY )]
90+ item_embeds = h_dict [ITEM_KEY ][in_g [- 1 ].dstnodes (ITEM_KEY )]
91+ iterator = enumerate (zip (in_g , self .layers ))
92+ else :
93+ h_dict = {ntype : self .feature_dict [ntype ] for ntype in in_g .ntypes }
94+ # obtain features of each layer and concatenate them all
95+ user_embeds = h_dict [USER_KEY ]
96+ item_embeds = h_dict [ITEM_KEY ]
97+ iterator = enumerate (zip ([in_g ] * len (self .layers ), self .layers ))
98+
99+ user_embeds = user_embeds * (1 / (len (self .layers ) + 1 ))
100+ item_embeds = item_embeds * (1 / (len (self .layers ) + 1 ))
100101
101- for k , layer in enumerate ( self . layers ) :
102+ for k , ( g , layer ) in iterator :
102103 h_dict = layer (g , h_dict )
103- user_embeds = user_embeds + (h_dict [USER_KEY ] * 1 / (k + 1 ))
104- item_embeds = item_embeds + (h_dict [ITEM_KEY ] * 1 / (k + 1 ))
104+ ue = h_dict [USER_KEY ]
105+ ie = h_dict [ITEM_KEY ]
106+
107+ if isinstance (in_g , list ):
108+ ue = ue [in_g [- 1 ].dstnodes (USER_KEY )]
109+ ie = ie [in_g [- 1 ].dstnodes (ITEM_KEY )]
110+
111+ user_embeds = user_embeds + ue
112+ item_embeds = item_embeds + ie
113+
114+ user_embeds = user_embeds / (len (self .layers ) + 1 )
115+ item_embeds = item_embeds / (len (self .layers ) + 1 )
105116
106117 u_g_embeddings = user_embeds if users is None else user_embeds [users , :]
107118 pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds [pos_items , :]
0 commit comments