Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from constants import GRAD_NORM_CLIP
from constants import USE_GPU
from constants import USE_LSTM
from argparse import ArgumentParser

arg_parser = ArgumentParser(description='The launchpad for all performance scripts.')
arg_parser.add_argument('-ia', "--num_intra_threads", help='The intra size', type=int, dest="intra", default=0)
arg_parser.add_argument('-ie', "--num_inter_threads", help='The inter size', type=int, dest="inter", default=0)
intra = arg_parser.parse_args().intra
inter = arg_parser.parse_args().inter

def log_uniform(lo, hi, rate):
log_lo = math.log(lo)
Expand Down Expand Up @@ -72,7 +78,8 @@ def log_uniform(lo, hi, rate):

# prepare session
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
allow_soft_placement=True))
allow_soft_placement=True,
intra_op_parallelism_threads=int(intra/PARALLEL_SIZE/2), inter_op_parallelism_threads=inter))

init = tf.global_variables_initializer()
sess.run(init)
Expand Down
13 changes: 11 additions & 2 deletions a3c_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from constants import GRAD_NORM_CLIP
from constants import USE_GPU
from constants import USE_LSTM
from argparse import ArgumentParser

arg_parser = ArgumentParser(description='The launchpad for all performance scripts.')
arg_parser.add_argument('-ia', "--num_intra_threads", help='The intra size', type=int, dest="intra", default=0)
arg_parser.add_argument('-ie', "--num_inter_threads", help='The inter size', type=int, dest="inter", default=0)
intra = arg_parser.parse_args().intra
inter = arg_parser.parse_args().inter

def choose_action(pi_values):
return np.random.choice(range(len(pi_values)), p=pi_values)
Expand All @@ -36,8 +43,10 @@ def choose_action(pi_values):
epsilon = RMSP_EPSILON,
clip_norm = GRAD_NORM_CLIP,
device = device)

sess = tf.Session()
config = tf.ConfigProto(allow_soft_placement=True,
intra_op_parallelism_threads=intra,inter_op_parallelism_threads=inter)
sess = tf.Session(config=config)
# sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

Expand Down