Skip to content

Commit f50475e

Browse files
authored
Merge pull request #49 from codelion/feat-add-wim
feat: add privacy plugin
2 parents d4b1fe2 + 357a112 commit f50475e

File tree

11 files changed

+381
-82
lines changed

11 files changed

+381
-82
lines changed

litellm_wrapper.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
from litellm import completion
44
from typing import List, Dict, Any, Optional
55

6+
# Harm-category safety settings forwarded to litellm.completion: every
# category is set to "BLOCK_NONE" so the upstream provider does not filter
# model responses before they reach this proxy.
_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)

SAFETY_SETTINGS = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _HARM_CATEGORIES
]
15+
616
class LiteLLMWrapper:
717
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
818
self.api_key = api_key
@@ -14,7 +24,7 @@ class Chat:
1424
class Completions:
1525
@staticmethod
1626
def create(model: str, messages: List[Dict[str, str]], **kwargs):
17-
response = completion(model=model, messages=messages, **kwargs)
27+
response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS)
1828
# Convert LiteLLM response to match OpenAI response structure
1929
return response
2030

@@ -28,8 +38,8 @@ def list():
2838
# This list can be expanded as needed.
2939
return {
3040
"data": [
31-
{"id": "gpt-3.5-turbo"},
32-
{"id": "gpt-4"},
41+
{"id": "gpt-4o-mini"},
42+
{"id": "gpt-4o"},
3343
{"id": "command-nightly"},
3444
# Add more models as needed
3545
]

optillm.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from optillm.rto import round_trip_optimization
2323
from optillm.self_consistency import advanced_self_consistency_approach
2424
from optillm.pvg import inference_time_pv_game
25-
from optillm.z3_solver import Z3SolverSystem
25+
from optillm.z3_solver import Z3SymPySolverSystem
2626
from optillm.rstar import RStar
2727
from optillm.cot_reflection import cot_reflection
2828
from optillm.plansearch import plansearch
@@ -44,31 +44,34 @@
4444
# Initialize Flask app
4545
app = Flask(__name__)
4646

47-
# OpenAI, Azure, or LiteLLM API configuration
48-
if os.environ.get("OPENAI_API_KEY"):
49-
API_KEY = os.environ.get("OPENAI_API_KEY")
50-
default_client = OpenAI(api_key=API_KEY)
51-
elif os.environ.get("AZURE_OPENAI_API_KEY"):
52-
API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
53-
API_VERSION = os.environ.get("AZURE_API_VERSION")
54-
AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE")
55-
if API_KEY is not None:
56-
default_client = AzureOpenAI(
57-
api_key=API_KEY,
58-
api_version=API_VERSION,
59-
azure_endpoint=AZURE_ENDPOINT,
60-
)
47+
def get_config():
    """Build the default LLM client from environment variables.

    Resolution order: OPENAI_API_KEY -> Azure OpenAI (API key, with an
    Azure AD token-provider path) -> LiteLLMWrapper when neither key is set.

    Returns:
        tuple: (default_client, API_KEY). API_KEY is None when no provider
        key was found in the environment.
    """
    API_KEY = None
    # OpenAI, Azure, or LiteLLM API configuration
    if os.environ.get("OPENAI_API_KEY"):
        API_KEY = os.environ.get("OPENAI_API_KEY")
        default_client = OpenAI(api_key=API_KEY)
    elif os.environ.get("AZURE_OPENAI_API_KEY"):
        API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
        API_VERSION = os.environ.get("AZURE_API_VERSION")
        AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE")
        if API_KEY is not None:
            default_client = AzureOpenAI(
                api_key=API_KEY,
                api_version=API_VERSION,
                azure_endpoint=AZURE_ENDPOINT,
            )
        else:
            # NOTE(review): this branch is unreachable as written — the
            # enclosing elif already guarantees a truthy AZURE_OPENAI_API_KEY,
            # so API_KEY can never be None here. Presumably the AD-token path
            # was meant to trigger when only AZURE_API_BASE is configured;
            # confirm intent before restructuring.
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
            azure_credential = DefaultAzureCredential()
            token_provider = get_bearer_token_provider(
                azure_credential, "https://cognitiveservices.azure.com/.default"
            )
            default_client = AzureOpenAI(
                api_version=API_VERSION,
                azure_endpoint=AZURE_ENDPOINT,
                azure_ad_token_provider=token_provider,
            )
    else:
        # No provider credentials at all: fall back to the LiteLLM wrapper.
        default_client = LiteLLMWrapper()
    return default_client, API_KEY
7275

7376
# Server configuration
7477
server_config = {
@@ -156,7 +159,7 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
156159
elif approach == 'rto':
157160
return round_trip_optimization(system_prompt, initial_query, client, model)
158161
elif approach == 'z3':
159-
z3_solver = Z3SolverSystem(system_prompt, client, model)
162+
z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
160163
return z3_solver.process_query(initial_query)
161164
elif approach == "self_consistency":
162165
return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
@@ -263,6 +266,14 @@ def check_api_key():
263266
def proxy():
264267
logger.info('Received request to /v1/chat/completions')
265268
data = request.get_json()
269+
auth_header = request.headers.get("Authorization")
270+
bearer_token = ""
271+
272+
if auth_header and auth_header.startswith("Bearer "):
273+
# Extract the bearer token
274+
bearer_token = auth_header.split("Bearer ")[1].strip()
275+
logger.debug(f"Intercepted Bearer Token: {bearer_token}")
276+
266277
logger.debug(f'Request data: {data}')
267278

268279
stream = data.get('stream', False)
@@ -281,15 +292,20 @@ def proxy():
281292
model = f"{optillm_approach}-{model}"
282293

283294
base_url = server_config['base_url']
284-
285-
if base_url != "":
286-
client = OpenAI(api_key=API_KEY, base_url=base_url)
287-
else:
288-
client = default_client
295+
default_client, api_key = get_config()
289296

290297
operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
291298
logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')
292299

300+
if bearer_token != "" and bearer_token.startswith("sk-") and model.startswith("gpt"):
301+
api_key = bearer_token
302+
if base_url != "":
303+
client = OpenAI(api_key=api_key, base_url=base_url)
304+
else:
305+
client = OpenAI(api_key=api_key)
306+
else:
307+
client = default_client
308+
293309
try:
294310
if operation == 'SINGLE':
295311
final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
@@ -342,7 +358,7 @@ def proxy():
342358
@app.route('/v1/models', methods=['GET'])
343359
def proxy_models():
344360
logger.info('Received request to /v1/models')
345-
361+
default_client, API_KEY = get_config()
346362
try:
347363
if server_config['base_url']:
348364
client = OpenAI(api_key=API_KEY, base_url=server_config['base_url'])

optillm/plugins/memory_plugin.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,23 @@ def extract_query(text: str) -> Tuple[str, str]:
5050
return query, context
5151

5252
def extract_key_information(text: str, client, model: str) -> Tuple[List[str], int]:
    """Ask the model to extract key facts from *text*.

    Args:
        text: Raw text chunk to distill.
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: Model identifier to query.

    Returns:
        A tuple ``(facts, completion_tokens)`` where ``facts`` is a list of
        extracted fact strings with leading bullet markers removed. On any
        API or parsing error the extraction is best-effort and ``([], 0)``
        is returned instead of raising.
    """
    prompt = f"""Extract key information from the following text. Provide a list of important facts or concepts, each on a new line:

{text}

Key information:"""

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
        )
        key_info = response.choices[0].message.content.strip().split('\n')
    except Exception as e:
        # Best-effort: a failed chunk should not abort the whole memory run.
        print(f"Error parsing content: {str(e)}")
        return [], 0

    # Strip "- " bullet prefixes/suffixes and drop blank lines.
    return [info.strip('- ') for info in key_info if info.strip()], response.usage.completion_tokens
6872

@@ -75,14 +79,16 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
7579
chunk_size = 10000
7680
for i in range(0, len(context), chunk_size):
7781
chunk = context[i:i+chunk_size]
82+
# print(f"chunk: {chunk}")
7883
key_info, tokens = extract_key_information(chunk, client, model)
84+
#print(f"key info: {key_info}")
7985
completion_tokens += tokens
8086
for info in key_info:
8187
memory.add(info)
82-
88+
# print(f"query : {query}")
8389
# Retrieve relevant information from memory
8490
relevant_info = memory.get_relevant(query)
85-
91+
# print(f"relevant_info : {relevant_info}")
8692
# Generate response using relevant information
8793
prompt = f"""System: {system_prompt}
8894
@@ -96,8 +102,8 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
96102
messages=[{"role": "user", "content": prompt}],
97103
max_tokens=1000
98104
)
99-
105+
print(f"response : {response}")
100106
final_response = response.choices[0].message.content.strip()
101107
completion_tokens += response.usage.completion_tokens
102-
108+
print(f"final_response: {final_response}")
103109
return final_response, completion_tokens

optillm/plugins/privacy_plugin.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
import re
from typing import Dict, Tuple

import spacy
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine, OperatorConfig
from presidio_anonymizer.operators import Operator, OperatorType
7+
8+
SLUG = "privacy"
9+
10+
class InstanceCounterAnonymizer(Operator):
    """
    Anonymizer which replaces the entity value
    with an instance counter per entity.

    Each distinct value of a given entity type is mapped to a stable
    placeholder such as "<PERSON_0>", "<PERSON_1>", counted separately
    per entity type, and recorded in the caller-supplied entity_mapping.
    """

    REPLACING_FORMAT = "<{entity_type}_{index}>"

    def operate(self, text: str, params: Dict = None) -> str:
        """Anonymize the input text.

        Returns the placeholder for *text*, reusing an existing placeholder
        when this exact value was seen before, and records new values in
        params["entity_mapping"].
        """
        entity_type: str = params["entity_type"]

        # entity_mapping is a dict of dicts containing mappings per entity
        # type: {entity_type: {original_value: placeholder}}.
        # (Fixed annotation: the original `Dict[Dict:str]` was a slice
        # inside the subscript, not a valid type expression.)
        entity_mapping: Dict[str, Dict[str, str]] = params["entity_mapping"]

        entity_mapping_for_type = entity_mapping.get(entity_type)
        if not entity_mapping_for_type:
            # First value of this entity type: start its counter at 0.
            new_text = self.REPLACING_FORMAT.format(
                entity_type=entity_type, index=0
            )
            entity_mapping[entity_type] = {}

        else:
            if text in entity_mapping_for_type:
                # Same value seen before: reuse its placeholder.
                return entity_mapping_for_type[text]

            previous_index = self._get_last_index(entity_mapping_for_type)
            new_text = self.REPLACING_FORMAT.format(
                entity_type=entity_type, index=previous_index + 1
            )

        entity_mapping[entity_type][text] = new_text
        return new_text

    @staticmethod
    def _get_last_index(entity_mapping_for_type: Dict) -> int:
        """Get the last (highest) index used for a given entity type."""

        def get_index(value: str) -> int:
            # Placeholders look like "<TYPE_3>": take the digits before ">".
            return int(value.split("_")[-1][:-1])

        indices = [get_index(v) for v in entity_mapping_for_type.values()]
        return max(indices)

    def validate(self, params: Dict = None) -> None:
        """Validate operator parameters."""

        if "entity_mapping" not in params:
            raise ValueError("An input Dict called `entity_mapping` is required.")
        if "entity_type" not in params:
            raise ValueError("An entity_type param is required.")

    def operator_name(self) -> str:
        return "entity_counter"

    def operator_type(self) -> OperatorType:
        return OperatorType.Anonymize
68+
69+
def download_model(model_name):
    """Ensure the given spaCy model is installed, downloading it if absent."""
    if spacy.util.is_package(model_name):
        print(f"{model_name} model already downloaded.")
    else:
        print(f"Downloading {model_name} model...")
        spacy.cli.download(model_name)
75+
76+
def replace_entities(entity_map, text):
    """Restore original entity values in *text*.

    Args:
        entity_map: Mapping of entity_type -> {original_value: placeholder},
            as produced by InstanceCounterAnonymizer.
        text: Model output containing placeholders like "<PERSON_0>".

    Returns:
        *text* with every known "<TYPE_N>" placeholder replaced by its
        original value; unknown placeholders are left untouched.
    """
    # Invert the mapping: placeholder -> original entity value.
    reverse_map = {
        placeholder: entity_name
        for entities in entity_map.values()
        for entity_name, placeholder in entities.items()
    }

    def replace_placeholder(match):
        placeholder = match.group(0)
        return reverse_map.get(placeholder, placeholder)

    # Replace all "<TYPE_N>" tokens in a single regex pass.
    return re.sub(r'<[A-Z_]+_\d+>', replace_placeholder, text)
94+
95+
def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
    """Privacy plugin entry point.

    Detects PII in the query with Presidio, swaps each value for a numbered
    placeholder before calling the model, then restores the original values
    in the model's answer.

    Returns:
        (de-anonymized response text, completion tokens used)
    """
    # spaCy NER model backing Presidio's analyzer.
    download_model("en_core_web_lg")

    pii_results = AnalyzerEngine().analyze(text=initial_query, language="en")

    # Anonymizer engine with the custom instance-counter operator registered.
    engine = AnonymizerEngine()
    engine.add_anonymizer(InstanceCounterAnonymizer)

    # Filled in by the operator: {entity_type: {original: placeholder}}.
    placeholder_map = {}

    anonymized = engine.anonymize(
        initial_query,
        pii_results,
        {
            "DEFAULT": OperatorConfig(
                "entity_counter", {"entity_mapping": placeholder_map}
            )
        },
    )

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": anonymized.text},
        ],
    )

    answer = response.choices[0].message.content.strip()
    # Swap placeholders back for the real entity values before returning.
    answer = replace_entities(placeholder_map, answer)

    return answer, response.usage.completion_tokens

optillm/rto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
5959
c2 = extract_code_from_prompt(c2)
6060

6161
if c1.strip() == c2.strip():
62-
return c1
62+
return c1, rto_completion_tokens
6363

6464
messages = [{"role": "system", "content": system_prompt},
6565
{"role": "user", "content": f"Initial query: {initial_query}\n\nFirst generated code (C1):\n{c1}\n\nSecond generated code (C2):\n{c2}\n\nBased on the initial query and these two different code implementations, generate a final, optimized version of the code. Only respond with the final code, do not return anything else."}]

0 commit comments

Comments
 (0)