diff --git a/cola_inference.py b/cola_inference.py index da54b8ef2..7c978ae21 100644 --- a/cola_inference.py +++ b/cola_inference.py @@ -148,11 +148,11 @@ def main(cl_arguments): args.pool_type = select_pool_type(args) # Prepare data # - _, target_tasks, vocab, word_embs = build_tasks(args) + cuda_device = parse_cuda_list_arg(args.cuda) + _, target_tasks, vocab, word_embs = build_tasks(args, cuda_device) tasks = sorted(set(target_tasks), key=lambda x: x.name) # Build or load model # - cuda_device = parse_cuda_list_arg(args.cuda) model = build_model(args, vocab, word_embs, tasks, cuda_device) log.info("Loading existing model from %s...", cl_args.model_file_path) load_model_state(model, cl_args.model_file_path, args.cuda, [], strict=False)