Skip to content

Commit 5de0a8e

Browse files
authored
Remove self.device to prevent import requirement of torch for prediction (#617)
1 parent ae7ba86 commit 5de0a8e

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

cornac/models/lightgcn/lightgcn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,10 @@ def forward(self, g, feat_dict):
6262

6363

6464
class Model(nn.Module):
65-
def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
65+
def __init__(self, g, in_size, num_layers, lambda_reg):
6666
super(Model, self).__init__()
6767
self.norm_dict = dict()
6868
self.lambda_reg = lambda_reg
69-
self.device = device
7069

7170
self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
7271

cornac/models/lightgcn/recom_lightgcn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,21 +124,21 @@ def fit(self, train_set, val_set=None):
124124
from .lightgcn import Model
125125
from .lightgcn import construct_graph
126126

127-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128128
if self.seed is not None:
129129
torch.manual_seed(self.seed)
130130
if torch.cuda.is_available():
131131
torch.cuda.manual_seed_all(self.seed)
132132

133133
graph = construct_graph(train_set, self.total_users, self.total_items).to(
134-
self.device
134+
device
135135
)
136136
model = Model(
137137
graph,
138138
self.emb_size,
139139
self.num_layers,
140140
self.lambda_reg,
141-
).to(self.device)
141+
).to(device)
142142

143143
optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
144144

0 commit comments

Comments
 (0)