-
Notifications
You must be signed in to change notification settings - Fork 13
Elliotschu patch 1 #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
elliotschu
wants to merge
3
commits into
main
Choose a base branch
from
elliotschu-patch-1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| # CONSCENDI: A Contrastive and Scenario-Guided Distillation Approach to Guardrail Models for Virtual Assistants | ||
|
|
||
| This repository contains the associated dataset for our work, CONSCENDI. Dataset and details coming soon. | ||
| This repository contains the associated dataset for our work, CONSCENDI (https://arxiv.org/abs/2304.14364) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| import time | ||
| import openai | ||
| import os | ||
| from jinja2 import Template | ||
| import logging | ||
|
|
||
# Retry/back-off tuning for the OpenAI helpers below.
SLEEP_TIME_SHORT = 0.33  # seconds — short pause between successive API calls (short_sleep)
SLEEP_TIME_INCREASE_COUNT = 0  # NOTE(review): appears unused in this file — confirm before removing
SLEEP_TIME_LONG = 60  # seconds — long back-off after rate-limit / service-unavailable errors (long_sleep)
KEY_INDEX=0  # NOTE(review): appears unused; cycle_api_key rotates the KEYS list instead

# Pool of OpenAI API keys rotated round-robin by cycle_api_key.
# Intentionally left empty in the repository — populate locally, never commit real keys.
KEYS = [
]
|
|
||
def short_sleep(logger):
    """Briefly pause the process, announcing the pause on ``logger``."""
    pause = SLEEP_TIME_SHORT
    logger.info(f"Sleeping process for {pause} seconds.")
    time.sleep(pause)
|
|
||
def long_sleep(logger):
    """Pause the process for the long back-off interval, announcing it on ``logger``."""
    pause = SLEEP_TIME_LONG
    logger.info(f"Sleeping process for {pause} seconds.")
    time.sleep(pause)
|
|
||
def cycle_api_key(keys, logger):
    """Install the first key in ``keys`` as the active OpenAI key and rotate the pool.

    Parameters:
        keys: non-empty list of API key strings; the head becomes the active key.
        logger: logger that receives a warning noting the key change.

    Returns:
        The rotated list (previous head moved to the tail), so repeated calls
        walk round-robin through the pool.
    """
    openai.api_key = keys[0]
    # Fix: the original logged the full secret in plaintext. Log only a short
    # suffix so the active key is identifiable without leaking it.
    masked = f"...{keys[0][-4:]}" if keys[0] else "<empty>"
    logger.warning(f"Current OpenAI key changed to: {masked}")
    keys = keys[1:] + [keys[0]]
    return keys
|
|
||
def invalid_request_error(logger, index):
    """Record that the request at ``index`` exceeded the token limit and is skipped."""
    message = f"Too many tokens in request, skipping this index: {index}"
    logger.warning(message)
|
|
||
def handle_rate_limit_error(logger, index):
    """Warn about an OpenAI rate limit, then back off before the caller retries ``index``."""
    warning = f"Rate limit for OpenAI reached, increasing sleeps between calls and retrying index: {index}"
    logger.warning(warning)
    long_sleep(logger)
|
|
||
def handle_service_unavailable_error(logger, index):
    """Warn that the OpenAI service is down, then back off before the caller retries ``index``."""
    warning = f"OpenAI service unavailable waiting and retrying index: {index}"
    logger.warning(warning)
    long_sleep(logger)
|
|
||
def query_and_retry(formatted_prompt, temperature, max_tokens, engine, logger, stop=None, max_retries=3):
    """Send a prompt to an OpenAI chat model, retrying on transient API errors.

    Parameters:
        formatted_prompt: fully rendered prompt, sent as a single user message.
        temperature: sampling temperature forwarded to the API.
        max_tokens: completion-token limit forwarded to the API.
        engine: model name; only "gpt-4" and "gpt-3.5-turbo" are supported here.
        logger: logger used for warnings and back-off notices.
        stop: optional stop sequence(s) forwarded to the API.
        max_retries: attempts before the last transient error is re-raised.
            (Fix: the original ignored this parameter and retried forever.)

    Returns:
        The raw ChatCompletion response, or None for an unsupported engine
        (preserving the original's implicit-None fall-through for callers).

    Raises:
        The last transient ``openai.error`` exception once retries are exhausted.
    """
    if engine != "gpt-4" and engine != "gpt-3.5-turbo":
        # The original silently fell off the function and returned None here;
        # keep that contract but make the situation visible in the logs.
        logger.warning(f"query_and_retry called with unsupported engine: {engine}")
        return None
    for attempt in range(max_retries):
        try:
            output = openai.ChatCompletion.create(
                model=engine,
                messages=[
                    {"role": "user", "content": formatted_prompt},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
            )
            return output
        except (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout,
                openai.error.APIConnectionError,
                openai.error.APIError,
                openai.error.TryAgain) as e:
            if isinstance(e, openai.error.RateLimitError):
                handle_rate_limit_error(logger, 0)
            elif isinstance(e, openai.error.ServiceUnavailableError):
                handle_service_unavailable_error(logger, 0)
            if attempt == max_retries - 1:
                # Out of attempts: surface the failure instead of looping forever.
                raise
            # Short pause before retrying the same request.
            time.sleep(1)
|
|
||
def query_and_retry_completion(formatted_prompt, temperature, max_tokens, engine, logger, stop=None):
    """Query OpenAI (chat or legacy completion API, chosen by ``engine``), retrying forever.

    Fixes vs. the original:
      - the chat branch hardcoded ``model="gpt-3.5-turbo"`` even when ``engine``
        was "gpt-4" — the requested engine is now used;
      - the chat branch dropped ``temperature``/``max_tokens``/``stop`` — they
        are now forwarded like in the completion branch;
      - the second ``except openai.error.ServiceUnavailableError`` clauses were
        unreachable (that class is already in the tuple above) — removed;
      - the chat branch returned None after a single failure despite the
        function's name — it now retries like the completion branch.

    Parameters mirror :func:`query_and_retry`; returns the raw API response.
    """
    if engine == "gpt-4" or engine == "gpt-3.5-turbo":
        while True:
            try:
                output = openai.ChatCompletion.create(
                    model=engine,
                    messages=[
                        {"role": "user", "content": formatted_prompt},
                    ],
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stop=stop,
                )
            except (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout,
                    openai.error.APIConnectionError,
                    openai.error.APIError,
                    openai.error.TryAgain) as e:
                if isinstance(e, openai.error.ServiceUnavailableError):
                    handle_service_unavailable_error(logger, -1)
                else:
                    handle_rate_limit_error(logger, -1)
            else:
                return output
    else:
        while True:
            try:
                output = openai.Completion.create(
                    prompt=formatted_prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    engine=engine,
                    stop=stop)
            except (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout,
                    openai.error.APIConnectionError,
                    openai.error.APIError,
                    openai.error.TryAgain) as e:
                if isinstance(e, openai.error.ServiceUnavailableError):
                    handle_service_unavailable_error(logger, -1)
                else:
                    handle_rate_limit_error(logger, -1)
            else:
                return output
|
|
||
def load_template(file):
    """Read ``file`` from disk and compile its contents into a jinja2 Template."""
    with open(file, 'r') as f:
        template_source = f.read()
    return Template(template_source)
|
|
||
def create_logger():
    """Return the module's console logger, configured exactly once.

    Fixes vs. the original:
      - the logger's own level was never set, so the default WARNING threshold
        suppressed the ``logger.info`` sleep notices emitted by this module —
        the logger is now set to DEBUG;
      - every call appended a fresh StreamHandler to the same named logger,
        duplicating each log line — handlers are now added only on first call.

    Returns:
        logging.Logger: the shared "test_logger" instance with one console handler.
    """
    logger = logging.getLogger("test_logger")
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        # Create a handler to output log messages to the console
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)

        # Create a formatter to specify the format of the log messages
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        console_handler.setFormatter(formatter)

        # Add the handler to the logger
        logger.addHandler(console_handler)
    return logger
|
|
||
|
|
||
|
|
||
| # output = query_and_retry("test", 0.5, 100, "gpt-4", create_logger(), stop=["\n"]) | ||
|
|
||
| # make formatted prompt from template using multi-rule-guardrail/sgd/prompts/violation.jinja | ||
| # formatted_prompt = load_template("multi-rule-guardrail/sgd/prompts/violation.jinja").render( | ||
| # rule = "Do not discuss takeout orders for restaurants.", | ||
| # scenario = "The user asks for recommendations on the best takeout meals in San Francisco." | ||
| # ) | ||
| # output = query_and_retry(formatted_prompt, 1, 300, "gpt-4", create_logger(), stop=["[STOP]", "STOP"]) | ||
| # print(output['choices'][0]['message']['content']) | ||
| # print(output['usage']['total_tokens']) | ||
| # print("dog") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| import pandas as pd | ||
| from collections import Counter | ||
| from nltk.util import ngrams | ||
| import re | ||
| import os | ||
| import openai | ||
# NOTE(review): the name refers to a specific person but the value comes from
# the standard OPENAI_API_KEY environment variable — consider a neutral name
# (flagged in PR review: 'do we want to say "OPENAI_KEY"?').
VARUN_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = VARUN_KEY
import pandas as pd  # NOTE(review): duplicate — pandas is already imported above
|
|
||
# CSV splits consumed by the (currently commented-out) distinct@k analysis below:
# per-variant in-distribution (_id) and out-of-distribution (_ood) files, plus
# the combined final_* sets.
# NOTE(review): the meaning of the b/f/r prefixes is not shown in this file —
# confirm against the dataset generation code.
datasets = [
    "multi-rule-guardrail/sgd/output/data/final/b_id.csv",
    "multi-rule-guardrail/sgd/output/data/final/b_ood.csv",
    "multi-rule-guardrail/sgd/output/data/final/f_id.csv",
    "multi-rule-guardrail/sgd/output/data/final/f_ood.csv",
    "multi-rule-guardrail/sgd/output/data/final/r_id.csv",
    "multi-rule-guardrail/sgd/output/data/final/r_ood.csv",
    "multi-rule-guardrail/sgd/output/data/final_r.csv",
    "multi-rule-guardrail/sgd/output/data/final_b.csv",
    "multi-rule-guardrail/sgd/output/data/final_f.csv",
]
|
|
||
|
|
||
|
|
||
|
|
||
| # for dataset in datasets: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: "do we need this commented code?" |
||
| # df = pd.read_csv(dataset) | ||
| # # Function to preprocess the text | ||
| # def preprocess(text): | ||
| # text = text.lower() | ||
| # text = re.sub(r'[^\w\s]', '', text) | ||
| # return text | ||
|
|
||
| # # Function to extract both user and virtual assistant messages | ||
| # def extract_responses(conversation, who): | ||
| # turns = conversation.split('\n\n') | ||
| # if who=="user": | ||
| # responses = [turn for idx, turn in enumerate(turns) if idx % 2 == 0] | ||
| # elif who=="assistant": | ||
| # responses = [turn for idx, turn in enumerate(turns) if idx % 2 == 1] | ||
| # elif who=="all": | ||
| # responses = [turn for idx, turn in enumerate(turns)] | ||
| # return ' '.join(responses) | ||
|
|
||
| # # Function to calculate distinct@k | ||
| # def calculate_distinct_k(conversation, k, who): | ||
| # conversation = extract_responses(conversation, who) | ||
| # preprocessed_text = preprocess(conversation) | ||
| # preprocessed_text = re.sub(r'user|assistant', '', preprocessed_text) | ||
| # tokens = preprocessed_text.split() | ||
| # if len(tokens) < k: | ||
| # return 0 | ||
| # k_grams = list(ngrams(tokens, k)) | ||
| # count_unique_k_grams = len(set(k_grams)) | ||
| # distinct_k = count_unique_k_grams / len(k_grams) | ||
| # return distinct_k | ||
|
|
||
| # df['distinct_1_user'] = df['conversation'].apply(calculate_distinct_k, k = 1, who = "user") | ||
| # average_distinct_1_user = df['distinct_1_user'].mean() | ||
|
|
||
| # df['distinct_2_user'] = df['conversation'].apply(calculate_distinct_k, k = 2, who = "user") | ||
| # average_distinct_2_user = df['distinct_2_user'].mean() | ||
|
|
||
| # df['distinct_3_user'] = df['conversation'].apply(calculate_distinct_k, k = 3, who = "user") | ||
| # average_distinct_3_user = df['distinct_1_user'].mean() | ||
|
|
||
| # df['distinct_1_assistant'] = df['conversation'].apply(calculate_distinct_k, k = 1, who = "assistant") | ||
| # average_distinct_1_assistant = df['distinct_1_assistant'].mean() | ||
|
|
||
| # df['distinct_2_assistant'] = df['conversation'].apply(calculate_distinct_k, k = 2, who = "assistant") | ||
| # average_distinct_2_assistant = df['distinct_2_assistant'].mean() | ||
|
|
||
| # df['distinct_3_assistant'] = df['conversation'].apply(calculate_distinct_k, k = 3, who = "assistant") | ||
| # average_distinct_3_assistant = df['distinct_3_assistant'].mean() | ||
|
|
||
| # df['distinct_1_all'] = df['conversation'].apply(calculate_distinct_k, k = 1, who = "all") | ||
| # average_distinct_1_all = df['distinct_1_all'].mean() | ||
|
|
||
| # df['distinct_2_all'] = df['conversation'].apply(calculate_distinct_k, k = 2, who = "all") | ||
| # average_distinct_2_all = df['distinct_2_all'].mean() | ||
|
|
||
| # df['distinct_3_all'] = df['conversation'].apply(calculate_distinct_k, k = 3, who = "all") | ||
| # average_distinct_3_all = df['distinct_3_all'].mean() | ||
|
|
||
| # print("===\n") | ||
| # print(dataset) | ||
| # print(f'distinct@1/2/3 User: \t{average_distinct_1_user:.2f} / {average_distinct_2_user:.2f} / {average_distinct_3_user:.2f}') | ||
| # print(f'distinct@1/2/3 Assi: \t{average_distinct_1_assistant:.2f} / {average_distinct_2_assistant:.2f} / {average_distinct_3_assistant:.2f}') | ||
| # print(f'distinct@1/2/3 Both: \t{average_distinct_1_all:.2f} / {average_distinct_2_all:.2f} / {average_distinct_3_all:.2f}') | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to say "OPENAI_KEY"?