-
Notifications
You must be signed in to change notification settings - Fork 42
feat: add ca end-to-end tests #409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
3c30138
2ca2671
e8d35f4
c111706
f8baba1
9f2abb7
1b8d5c9
208796a
2ad3579
b6c0efb
7bcab4a
e45dc24
e4f8cbe
5f1403f
888a8e2
b7aa258
e829ffe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.0005050505278632045,"action":1.014871597290039},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"91f71c8","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"},"_skipLearn":true} | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.464624404907227},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"75d50657","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.43958568572998},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"e28a9ae6","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.0005050505278632045,"action":1.014871597290039},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"91f71c8","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.464624404907227},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"75d50657","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} | ||
| {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.43958568572998},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"e28a9ae6","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,30 +21,19 @@ namespace po = boost::program_options; | |
| //global var, yeah ugg | ||
| bool enable_dedup = false; | ||
|
|
||
| static const char *options[] = { | ||
| "cb", | ||
| "invalid-cb", | ||
| "ccb", | ||
| "ccb-with-slot-id", | ||
| "ccb-baseline", | ||
| "slates", | ||
| "ca", | ||
| "f-reward", | ||
| "fi-reward", | ||
| "fi-out-of-bound-reward", | ||
| "fs-reward", | ||
| "fmix-reward", | ||
| "s-reward", | ||
| "si-reward", | ||
| "ss-reward", | ||
| "action-taken", | ||
| "cb-loop", | ||
| "ccb-loop", | ||
| "ccb-baseline-loop", | ||
| nullptr | ||
| }; | ||
|
|
||
| enum options{ | ||
| static const char *options[] = {"cb", "invalid-cb", | ||
| "ccb", "ccb-with-slot-id", | ||
| "ccb-baseline", "slates", | ||
| "ca", "f-reward", | ||
| "fi-reward", "fi-out-of-bound-reward", | ||
| "fs-reward", "fmix-reward", | ||
| "s-reward", "si-reward", | ||
| "ss-reward", "action-taken", | ||
| "cb-loop", "ca-loop", | ||
| "ccb-loop", "ccb-baseline-loop", | ||
| nullptr}; | ||
|
|
||
| enum options { | ||
| CB_ACTION, | ||
| INVALID_CB_ACTION, | ||
| CCB_ACTION, | ||
|
|
@@ -63,6 +52,7 @@ enum options{ | |
| S_S_REWARD, | ||
| ACTION_TAKEN, | ||
| CB_LOOP, | ||
| CA_LOOP, | ||
| CCB_LOOP, | ||
| CCB_BASELINE_ACTION_LOOP | ||
| }; | ||
|
|
@@ -115,9 +105,7 @@ void load_config_from_json(int action, u::configuration& config, bool enable_app | |
| } else if (action == SLATES_ACTION) { | ||
| std::string args = "--slates --ccb_explore_adf --json --quiet --epsilon " + std::to_string(epsilon) + " --first_only --id N/A"; | ||
| config.set(r::name::MODEL_VW_INITIAL_COMMAND_LINE, args.c_str()); | ||
| } | ||
| else if (action == CA_ACTION) | ||
| { | ||
| } else if (action == CA_ACTION || action == CA_LOOP) { | ||
| config.set(r::name::MODEL_VW_INITIAL_COMMAND_LINE, "--cats 4 --min_value 1 --max_value 100 --bandwidth 1 --json --quiet --id N/A"); | ||
| } | ||
| } | ||
|
|
@@ -248,7 +236,9 @@ void send_ccb_outcome(std::mt19937& rng, bool gen_random_reward, const char * ev | |
| } | ||
| } | ||
|
|
||
| int take_action(r::live_model& rl, const char *event_id, int action, unsigned int action_flag, bool gen_random_reward, std::mt19937& rng) { | ||
| int take_action(r::live_model &rl, const char *event_id, int action, | ||
| unsigned int action_flag, bool gen_random_reward, | ||
| std::mt19937 &rng, bool no_loop_actions) { | ||
| r::api_status status; | ||
| float reward = gen_random_reward ? get_random_number(rng) : 1.5f; | ||
|
|
||
|
|
@@ -407,6 +397,34 @@ int take_action(r::live_model& rl, const char *event_id, int action, unsigned in | |
|
|
||
| break; | ||
| }; | ||
| case CA_LOOP: { // "ca_loop", | ||
| r::continuous_action_response response; | ||
| if (rl.request_continuous_action(event_id, JSON_CA_CONTEXT, action_flag, | ||
| response, &status) != err::success) | ||
| std::cout << status.get_error_msg() << std::endl; | ||
| size_t num_of_rewards = get_random_number(rng); | ||
| for (size_t i = 0; i < num_of_rewards; i++) { | ||
| float reward = gen_random_reward ? get_random_number(rng, 0) : 1.5f; | ||
| std::cout << "report outcome: " << reward << " for event: " << event_id | ||
| << std::endl; | ||
| if (rl.report_outcome(event_id, reward, &status) != err::success) | ||
| std::cout << status.get_error_msg() << std::endl; | ||
| } | ||
|
|
||
| if (action_flag == r::action_flags::DEFERRED && !no_loop_actions) { | ||
| size_t rand_num = get_random_number(rng, 0 /*min*/); | ||
| if (rand_num % 2) { | ||
| // send activation | ||
| std::cout << "sending activation for event_id: " << event_id | ||
| << std::endl; | ||
| if (rl.report_action_taken(event_id, &status) != err::success) { | ||
| std::cout << status.get_error_msg() << std::endl; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| break; | ||
| }; | ||
| case CCB_LOOP: { // "ccb action and random number of float rewards and mix of slot ids / non slot ids / float / string rewards" | ||
| // randomly decide to send either ccb with slot id's provided or random slot id's | ||
| // the ccb interactions that are non-random are the ones we can use to send observations for the slot id using the slot-id string | ||
|
|
@@ -495,7 +513,10 @@ int pseudo_random(int seed) { | |
| return (int)(val & 0xFFFFFFFF); | ||
| } | ||
|
|
||
| int run_config(int action, int count, int initial_seed, bool gen_random_reward, bool enable_apprentice_mode, int deferred_action_count, std::string config_file, std::mt19937& rng, float epsilon = 0.0f) { | ||
| int run_config(int action, int count, int initial_seed, bool gen_random_reward, | ||
| bool enable_apprentice_mode, int deferred_action_count, | ||
| std::string config_file, std::mt19937 &rng, bool no_loop_actions, | ||
| float epsilon = 0.0f) { | ||
| u::configuration config; | ||
|
|
||
| if (config_file.empty()) | ||
|
|
@@ -528,7 +549,8 @@ int run_config(int action, int count, int initial_seed, bool gen_random_reward, | |
| auto action_flag = i < deferred_action_count | ||
| ? r::action_flags::DEFERRED : r::action_flags::DEFAULT; | ||
|
|
||
| int r = take_action(rl, event_id, action, action_flag, gen_random_reward, rng); | ||
| int r = take_action(rl, event_id, action, action_flag, gen_random_reward, | ||
| rng, no_loop_actions); | ||
| if(r) | ||
| return r; | ||
| } | ||
|
|
@@ -547,19 +569,28 @@ int main(int argc, char *argv[]) { | |
| bool enable_apprentice_mode = false; | ||
| int deferred_action_count = 0; | ||
| float epsilon = 0.f; | ||
|
|
||
| desc.add_options() | ||
| ("help", "Produce help message") | ||
| ("all", "use all args") | ||
| ("dedup", "Enable dedup/zstd") | ||
| ("count", po::value<int>(), "Number of events to produce") | ||
| ("seed", po::value<int>(), "Initial seed used to produce event ids") | ||
| ("epsilon", po::value<float>(), "epsilon to be used in command line args for VW") | ||
| ("kind", po::value<std::string>(), "which kind of example to generate (cb,invalid-cb,ccb,ccb-with-slot-id,ccb-baseline,slates,ca,cb-loop,ccb-loop,ccb-baseline-loop,(f|s)(s|i|mix|i-out-of-bound)?-reward,action-taken)") | ||
| ("random_reward", "Generate random float reward for observation event") | ||
| ("config_file", po::value<std::string>(), "json config file for rlclinetlib") | ||
| ("apprentice", "Enable apprentice mode") | ||
| ("deferred_action_count", po::value<int>(), "Number of deferred action for interaction events. Set the deferred_action flag to true for first deferred_action_count number of actions"); | ||
| bool no_loop_actions = false; | ||
|
|
||
| desc.add_options()("help", "Produce help message")("all", "use all args")( | ||
| "dedup", "Enable dedup/zstd")("count", po::value<int>(), | ||
| "Number of events to produce")( | ||
| "seed", po::value<int>(), "Initial seed used to produce event ids")( | ||
| "epsilon", po::value<float>(), | ||
| "epsilon to be used in command line args for VW")( | ||
| "kind", po::value<std::string>(), | ||
| "which kind of example to generate " | ||
| "(cb,invalid-cb,ccb,ccb-with-slot-id,ccb-baseline,slates,ca,cb-loop,ca-" | ||
| "loop,ccb-loop,ccb-baseline-loop,(f|s)(s|i|mix|i-out-of-bound)?-reward," | ||
| "action-taken)")("random_reward", | ||
| "Generate random float reward for observation event")( | ||
| "config_file", po::value<std::string>(), | ||
| "json config file for rlclinetlib")("apprentice", | ||
| "Enable apprentice mode")( | ||
| "deferred_action_count", po::value<int>(), | ||
| "Number of deferred action for interaction events. Set the " | ||
| "deferred_action flag to true for first deferred_action_count number of " | ||
| "actions")("no_loop_actions", | ||
| "Flag to disable actions being taken for all outcome events"); | ||
|
||
|
|
||
| po::positional_options_description pd; | ||
| pd.add("kind", 1); | ||
|
|
@@ -572,11 +603,13 @@ int main(int argc, char *argv[]) { | |
| gen_random_reward = vm.count("random_reward"); | ||
| enable_apprentice_mode = vm.count("apprentice"); | ||
| enable_dedup = vm.count("dedup"); | ||
| no_loop_actions = vm.count("no_loop_actions"); | ||
|
|
||
| std::vector<std::string> deferrable_interactions { | ||
| "cb", "invalid-cb", "ccb", "ccb-baseline", "slates", "ca", "cb-loop", | ||
| "ccb-with-slot-id", "ccb-loop", "ccb-baseline-loop" | ||
| }; | ||
| std::vector<std::string> deferrable_interactions{ | ||
| "cb", "invalid-cb", "ccb", | ||
| "ccb-baseline", "slates", "ca", | ||
| "cb-loop", "ca-loop", "ccb-with-slot-id", | ||
| "ccb-loop", "ccb-baseline-loop"}; | ||
|
|
||
| if(vm.count("kind") > 0) | ||
| action_name = vm["kind"].as<std::string>(); | ||
|
|
@@ -614,7 +647,9 @@ int main(int argc, char *argv[]) { | |
|
|
||
| if(gen_all) { | ||
| for(int i = 0; options[i]; ++i) { | ||
| if(run_config(i, count, seed, gen_random_reward, enable_apprentice_mode, deferred_action_count, config_file, rng, epsilon)) | ||
| if (run_config(i, count, seed, gen_random_reward, enable_apprentice_mode, | ||
| deferred_action_count, config_file, rng, no_loop_actions, | ||
| epsilon)) | ||
| return -1; | ||
| } | ||
| return 0; | ||
|
|
@@ -634,5 +669,7 @@ int main(int argc, char *argv[]) { | |
| return -1; | ||
| } | ||
|
|
||
| return run_config(action, count, seed, gen_random_reward, enable_apprentice_mode, deferred_action_count, config_file, rng, epsilon); | ||
| return run_config(action, count, seed, gen_random_reward, | ||
| enable_apprentice_mode, deferred_action_count, config_file, | ||
| rng, no_loop_actions, epsilon); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.