Skip to content

Commit 2efb349

Browse files
Update torch indexing and use Tensor.cpu() (#91)
* Update dan.py * Update rnn.py * Update elmo.py * Update rnn.py * Update dan.py * Update elmo.py
1 parent 2c08401 commit 2efb349

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

qanta/guesser/dan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def run_epoch(self, iterator: Iterator):
466466

467467
out = self.model(input_dict, lengths_dict, qanta_ids)
468468
_, preds = torch.max(out, 1)
469-
accuracy = torch.mean(torch.eq(preds, page).float()).data[0]
469+
accuracy = torch.mean(torch.eq(preds, page).float()).cpu().data
470470
batch_loss = self.criterion(out, page)
471471
if is_train:
472472
batch_loss.backward()
@@ -476,7 +476,7 @@ def run_epoch(self, iterator: Iterator):
476476
self.optimizer.step()
477477

478478
batch_accuracies.append(accuracy)
479-
batch_losses.append(batch_loss.data[0])
479+
batch_losses.append(batch_loss.cpu().data)
480480

481481
epoch_end = time.time()
482482

qanta/guesser/elmo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,14 @@ def run_epoch(self, batches, train=True):
193193
self.model.zero_grad()
194194
out = self.model(x_batch.cuda(), length_batch.cuda())
195195
_, preds = torch.max(out, 1)
196-
accuracy = torch.mean(torch.eq(preds, y_batch).float()).data[0]
196+
accuracy = torch.mean(torch.eq(preds, y_batch).float()).cpu().data
197197
batch_loss = self.criterion(out, y_batch)
198198
if train:
199199
batch_loss.backward()
200200
torch.nn.utils.clip_grad_norm(self.model.parameters(), 0.25)
201201
self.optimizer.step()
202202
batch_accuracies.append(accuracy)
203-
batch_losses.append(batch_loss.data[0])
203+
batch_losses.append(batch_loss.cpu().data)
204204
epoch_end = time.time()
205205

206206
return np.mean(batch_accuracies), np.mean(batch_losses), epoch_end - epoch_start

qanta/guesser/rnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def run_epoch(self, iterator: Iterator):
346346

347347
out, hidden = self.model(text, lengths, hidden_init, qanta_ids)
348348
_, preds = torch.max(out, 1)
349-
accuracy = torch.mean(torch.eq(preds, page).float()).data[0]
349+
accuracy = torch.mean(torch.eq(preds, page).float()).cpu().data
350350
batch_loss = self.criterion(out, page)
351351
if is_train:
352352
batch_loss.backward()
@@ -356,7 +356,7 @@ def run_epoch(self, iterator: Iterator):
356356
self.optimizer.step()
357357

358358
batch_accuracies.append(accuracy)
359-
batch_losses.append(batch_loss.data[0])
359+
batch_losses.append(batch_loss.cpu().data)
360360

361361
epoch_end = time.time()
362362

0 commit comments

Comments
 (0)