Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RecommenderSystems/dlrm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def str_list(x):
parser.add_argument('--eval_batch_size', type=int, default=512)
parser.add_argument("--eval_batch_size_per_proc", type=int, default=None)
parser.add_argument('--eval_interval', type=int, default=1000)
parser.add_argument("--eval_save_dir", type=str, default='', help="eval AUC offline if available")
parser.add_argument("--batch_size", type=int, default=16384)
parser.add_argument("--batch_size_per_proc", type=int, default=None)
parser.add_argument("--learning_rate", type=float, default=1e-3)
Expand Down
26 changes: 20 additions & 6 deletions RecommenderSystems/dlrm/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import pickle

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
Expand All @@ -15,6 +16,7 @@
from graph import DLRMValGraph, DLRMTrainGraph
import warnings
import utils.logger as log
from utils.auc_calculater import calculate_auc_from_dir


class Trainer(object):
Expand Down Expand Up @@ -153,7 +155,6 @@ def __call__(self):
self.train()

def train(self):
losses = []
self.dlrm_module.train()
for _ in range(self.max_iter):
self.cur_iter += 1
Expand All @@ -168,21 +169,34 @@ def train(self):
if self.eval_after_training:
self.eval(True)

if self.args.eval_save_dir != '' and self.eval_after_training:
calculate_auc_from_dir(self.args.eval_save_dir)

def eval(self, save_model=False):
if self.eval_batchs <= 0:
return
self.dlrm_module.eval()
labels = np.array([[0]])
preds = np.array([[0]])
labels = []
preds = []
for _ in range(self.eval_batchs):
if self.execution_mode == "graph":
pred, label = self.eval_graph()
else:
pred, label = self.inference()
label_ = label.numpy().astype(np.float32)
labels = np.concatenate((labels, label_), axis=0)
preds = np.concatenate((preds, pred.numpy()), axis=0)
auc = roc_auc_score(labels[1:], preds[1:])
labels.append(label_)
preds.append(pred.numpy())
if self.args.eval_save_dir != '':
pf = os.path.join(self.args.eval_save_dir, f'eval_results_iter_{self.cur_iter}.pkl')
with open(pf, 'wb') as f:
obj = {'labels': labels, 'preds': preds, 'iter': self.cur_iter}
pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
# auc = roc_auc_score(label_, pred.numpy())
auc = 'nc'
else:
labels = np.concatenate(labels, axis=0)
preds = np.concatenate(preds, axis=0)
auc = roc_auc_score(labels, preds)
self.meter_eval(auc)
if save_model:
sub_save_dir = f"iter_{self.cur_iter}_val_auc_{auc}"
Expand Down
30 changes: 30 additions & 0 deletions RecommenderSystems/dlrm/utils/auc_calculater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import sys
import time
import pickle
import numpy as np
from sklearn.metrics import roc_auc_score


def calculate_auc_from_file(pkl):
    """Load one pickled eval-results file, compute its AUC, and print a report line.

    The pickle is expected to hold a dict with keys ``labels`` and ``preds``
    (each a list of per-batch numpy arrays) and ``iter`` (the training
    iteration the snapshot was taken at) -- the layout written by the
    trainer's eval loop.

    Args:
        pkl: path to a single ``eval_results_iter_*.pkl`` file.

    Returns:
        The computed AUC as a float.
    """
    # Context manager instead of a bare open() so the handle is closed
    # even if unpickling raises mid-read.
    with open(pkl, 'rb') as f:
        results = pickle.load(f)
    # `iteration` rather than `iter` -- avoid shadowing the builtin.
    iteration = results['iter']
    labels = np.concatenate(results['labels'], axis=0)
    preds = np.concatenate(results['preds'], axis=0)
    auc = roc_auc_score(labels, preds)
    print('iter', iteration, auc, time.time(), labels.shape[0])
    return auc


def calculate_auc_from_dir(directory, startswith='eval_results_iter'):
    """Compute and print the AUC for every eval-results pickle in *directory*.

    Selects files whose name starts with *startswith* and ends with ``.pkl``
    and hands each one to ``calculate_auc_from_file``.

    Args:
        directory: folder containing pickled eval results (see
            ``calculate_auc_from_file`` for the expected file layout).
        startswith: filename prefix used to select result files.
    """
    print('calculate AUC from folder:', directory)
    # sorted(): os.listdir order is arbitrary, so sort for a deterministic
    # report. NOTE(review): this is lexicographic, so iter_1000 sorts before
    # iter_200 -- deterministic, but not numeric iteration order.
    for file in sorted(os.listdir(directory)):
        filename = os.fsdecode(file)
        if filename.startswith(startswith) and filename.endswith(".pkl"):
            calculate_auc_from_file(os.path.join(directory, filename))


if __name__ == "__main__":
    # Usage: python auc_calculater.py <eval_save_dir>
    # sys.exit with a message instead of assert: asserts are stripped under
    # `python -O`, so they must not be used for CLI argument validation.
    if len(sys.argv) != 2:
        sys.exit('please input directory')
    calculate_auc_from_dir(sys.argv[1])