Skip to content

Commit 033b598

Browse files
authored
Merge pull request #80 from ipa-lab/development_without_spacy
Development without spacy
2 parents 70a9018 + 88fcf70 commit 033b598

15 files changed

+731
-97
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ wintermute.py: error: the following arguments are required: {linux_privesc,windo
179179

180180
# start wintermute, i.e., attack the configured virtual machine
181181
$ python wintermute.py minimal_linux_privesc
182+
183+
# install dependencies for testing if you want to run the tests
184+
$ pip install .[testing]
182185
~~~
183186

184187
## Publications about hackingBuddyGPT

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ pythonpath = "src"
5757
addopts = [
5858
"--import-mode=importlib",
5959
]
60+
[project.optional-dependencies]
61+
testing = [
62+
'pytest',
63+
'pytest-mock'
64+
]
6065

6166
[project.scripts]
6267
wintermute = "hackingBuddyGPT.cli.wintermute:main"

src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py

Lines changed: 58 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import nltk
22
from nltk.tokenize import word_tokenize
3-
from nltk.corpus import stopwords
43
from instructor.retry import InstructorRetryException
54

65

@@ -37,13 +36,11 @@ def __init__(self, strategy, llm_handler, history, schemas, response_handler):
3736
self.found_endpoints = ["/"]
3837
self.endpoint_methods = {}
3938
self.endpoint_found_methods = {}
40-
model_name = "en_core_web_sm"
41-
4239
# Check if the models are already installed
4340
nltk.download('punkt')
4441
nltk.download('stopwords')
4542
self._prompt_history = history
46-
self.prompt = self._prompt_history
43+
self.prompt = {self.round: {"content": "initial_prompt"}}
4744
self.previous_prompt = self._prompt_history[self.round]["content"]
4845
self.schemas = schemas
4946

@@ -77,10 +74,6 @@ def generate_prompt(self, doc=False):
7774
self.round = self.round +1
7875
return self._prompt_history
7976

80-
81-
82-
83-
8477
def in_context_learning(self, doc=False, hint=""):
8578
"""
8679
Generates a prompt for in-context learning.
@@ -91,7 +84,14 @@ def in_context_learning(self, doc=False, hint=""):
9184
Returns:
9285
str: The generated prompt.
9386
"""
94-
return str("\n".join(self._prompt_history[self.round]["content"] + [self.prompt]))
87+
history_content = [entry["content"] for entry in self._prompt_history]
88+
prompt_content = self.prompt.get(self.round, {}).get("content", "")
89+
90+
# Add hint if provided
91+
if hint:
92+
prompt_content += f"\n{hint}"
93+
94+
return "\n".join(history_content + [prompt_content])
9595

9696
def get_http_action_template(self, method):
9797
"""Helper to construct a consistent HTTP action description."""
@@ -103,13 +103,45 @@ def get_http_action_template(self, method):
103103
else:
104104
return (
105105
f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests.")
106-
107-
106+
def get_initial_steps(self, common_steps):
    """
    Builds the steps that ask for endpoint discovery and per-endpoint documentation.

    Args:
        common_steps (list): Steps appended after the discovery steps.

    Returns:
        list: The three discovery/documentation steps followed by `common_steps`.
    """
    return [
        # Must be an f-string: otherwise the literal text "{self.found_endpoints}"
        # is sent instead of the endpoints that should be excluded.
        f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}",
        "Note down the response structures, status codes, and headers for each endpoint.",
        "For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses.",
    ] + common_steps
112+
113+
def get_phase_steps(self, phase, common_steps):
    """
    Builds the steps for probing one HTTP method across the known endpoints.

    Args:
        phase (str): The HTTP method to focus on, e.g. "PUT" or "DELETE".
        common_steps (list): Steps appended after the phase-specific steps.

    Returns:
        list: The phase-specific instruction, the HTTP action template for
        `phase`, and then `common_steps`.
    """
    if phase == "DELETE":
        instruction = "Check for all endpoints the DELETE method. Delete the first instance for all endpoints."
    else:
        # The mapping may not contain this method yet (it starts out empty),
        # so fall back to an empty list instead of raising KeyError.
        already_found = self.endpoint_found_methods.get(phase, [])
        instruction = (
            f"Identify for all endpoints {self.found_endpoints} excluding {already_found} "
            f"a valid HTTP method {phase} call."
        )

    return [instruction, self.get_http_action_template(phase)] + common_steps
124+
125+
def get_endpoints_needing_help(self):
    """
    Builds a single hint step for the first endpoint that still lacks a standard HTTP method.

    Endpoints are visited in the iteration order of `self.endpoint_methods`.
    An endpoint needs help when at least one of GET, POST, PUT or DELETE is
    absent from its recorded methods.

    Returns:
        list: A one-element list with the hint for the first such endpoint,
        or an empty list if every endpoint already has all four methods.
    """
    # A tuple (not a set) so the suggested method is the same on every run.
    standard_methods = ("GET", "POST", "PUT", "DELETE")

    for endpoint, methods in self.endpoint_methods.items():
        recorded = set(methods)
        # Decide by what is actually missing, not by len(methods): duplicates or
        # non-standard verbs can push the length to 4 while methods are still missing.
        missing_methods = [method for method in standard_methods if method not in recorded]
        if missing_methods:
            return [
                f"For endpoint {endpoint} find this missing method: {missing_methods[0]}. If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search."
            ]
    return []
108142
def chain_of_thought(self, doc=False, hint=""):
109143
"""
110144
Generates a prompt using the chain-of-thought strategy.
111-
If 'doc' is True, it follows a detailed documentation-oriented prompt strategy based on the round number.
112-
If 'doc' is False, it provides general guidance for early round numbers and focuses on HTTP methods for later rounds.
113145
114146
Args:
115147
doc (bool): Determines whether the documentation-oriented chain of thought should be used.
@@ -126,70 +158,30 @@ def chain_of_thought(self, doc=False, hint=""):
126158
"Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes."
127159
]
128160

129-
http_methods = [ "PUT", "DELETE"]
130-
http_phase = {
131-
5: http_methods[0],
132-
10: http_methods[1]
133-
}
134-
161+
http_methods = ["PUT", "DELETE"]
162+
http_phase = {10: http_methods[0], 15: http_methods[1]}
135163
if doc:
136-
if self.round < 5:
137-
138-
chain_of_thought_steps = [
139-
f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", f"Note down the response structures, status codes, and headers for each endpoint.",
140-
f"For each endpoint, document the following details: URL, HTTP method, "
141-
f"query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
142-
] + common_steps
164+
if self.round <= 5:
165+
chain_of_thought_steps = self.get_initial_steps(common_steps)
166+
elif self.round <= 10:
167+
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
168+
chain_of_thought_steps = self.get_phase_steps(phase, common_steps)
143169
else:
144-
if self.round <= 10:
145-
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
146-
print(f'phase:{phase}')
147-
if phase != "DELETE":
148-
chain_of_thought_steps = [
149-
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
150-
self.get_http_action_template(phase)
151-
] + common_steps
152-
else:
153-
chain_of_thought_steps = [
154-
f"Check for all endpoints the DELETE method. Delete the first instance for all endpoints. ",
155-
self.get_http_action_template(phase)
156-
] + common_steps
157-
else:
158-
endpoints_needing_help = []
159-
endpoints_and_needed_methods = {}
160-
161-
# Standard HTTP methods
162-
http_methods = {"GET", "POST", "PUT", "DELETE"}
163-
164-
for endpoint in self.endpoint_methods:
165-
# Calculate the missing methods for the current endpoint
166-
missing_methods = http_methods - set(self.endpoint_methods[endpoint])
167-
168-
if len(self.endpoint_methods[endpoint]) < 4:
169-
endpoints_needing_help.append(endpoint)
170-
# Add the missing methods to the dictionary
171-
endpoints_and_needed_methods[endpoint] = list(missing_methods)
172-
173-
print(f'endpoints_and_needed_methods: {endpoints_and_needed_methods}')
174-
print(f'first endpoint in list: {endpoints_needing_help[0]}')
175-
print(f'methods needed for first endpoint: {endpoints_and_needed_methods[endpoints_needing_help[0]][0]}')
176-
177-
chain_of_thought_steps = [f"For enpoint {endpoints_needing_help[0]} find this missing method :{endpoints_and_needed_methods[endpoints_needing_help[0]][0]} "
178-
f"If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search. ",]
179-
170+
chain_of_thought_steps = self.get_endpoints_needing_help()
180171
else:
181172
if self.round == 0:
182-
chain_of_thought_steps = ["Let's think step by step."] # Zero shot prompt
173+
chain_of_thought_steps = ["Let's think step by step."]
183174
elif self.round <= 20:
184175
focus_phases = ["endpoints", "HTTP method GET", "HTTP method POST and PUT", "HTTP method DELETE"]
185176
focus_phase = focus_phases[self.round // 5]
186177
chain_of_thought_steps = [f"Just focus on the {focus_phase} for now."]
187178
else:
188179
chain_of_thought_steps = ["Look for exploits."]
189180

190-
print(f'chain of thought steps: {chain_of_thought_steps}')
191-
prompt = self.check_prompt(self.previous_prompt,
192-
chain_of_thought_steps + [hint] if hint else chain_of_thought_steps)
181+
if hint:
182+
chain_of_thought_steps.append(hint)
183+
184+
prompt = self.check_prompt(self.previous_prompt, chain_of_thought_steps)
193185
return prompt
194186

195187
def token_count(self, text):

src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from hackingBuddyGPT.capabilities.http_request import HTTPRequest
1010
from hackingBuddyGPT.capabilities.record_note import RecordNote
1111
from hackingBuddyGPT.usecases.agents import Agent
12-
from hackingBuddyGPT.usecases.web_api_testing.utils.documentation_handler import DocumentationHandler
12+
from hackingBuddyGPT.usecases.web_api_testing.utils.openapi_specification_manager import OpenAPISpecificationManager
1313
from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler
1414
from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptEngineer, PromptStrategy
1515
from hackingBuddyGPT.usecases.web_api_testing.utils.response_handler import ResponseHandler
@@ -52,7 +52,7 @@ def init(self):
5252
self.llm_handler = LLMHandler(self.llm, self._capabilities)
5353
self.response_handler = ResponseHandler(self.llm_handler)
5454
self._setup_initial_prompt()
55-
self.documentation_handler = DocumentationHandler(self.llm_handler, self.response_handler)
55+
self.documentation_handler = OpenAPISpecificationManager(self.llm_handler, self.response_handler)
5656

5757
def _setup_capabilities(self):
5858
notes = self._context["notes"]
@@ -74,7 +74,7 @@ def _setup_initial_prompt(self):
7474
response_handler=self.response_handler)
7575

7676

77-
def all_http_methods_found(self):
77+
def all_http_methods_found(self,turn):
7878
print(f'found endpoints:{self.documentation_handler.endpoint_methods.items()}')
7979
print(f'found endpoints values:{self.documentation_handler.endpoint_methods.values()}')
8080

@@ -83,17 +83,20 @@ def all_http_methods_found(self):
8383
print(f'found endpoints:{found_endpoints}')
8484
print(f'expected endpoints:{expected_endpoints}')
8585
print(f'correct? {found_endpoints== expected_endpoints}')
86-
if found_endpoints== expected_endpoints or found_endpoints == expected_endpoints -1:
86+
if found_endpoints > 0 and (found_endpoints== expected_endpoints) :
8787
return True
8888
else:
89+
if turn == 20:
90+
if found_endpoints > 0 and (found_endpoints == expected_endpoints):
91+
return True
8992
return False
9093

9194
def perform_round(self, turn: int):
9295
prompt = self.prompt_engineer.generate_prompt(doc=True)
9396
response, completion = self.llm_handler.call_llm(prompt)
94-
return self._handle_response(completion, response)
97+
return self._handle_response(completion, response, turn)
9598

96-
def _handle_response(self, completion, response):
99+
def _handle_response(self, completion, response, turn):
97100
message = completion.choices[0].message
98101
tool_call_id = message.tool_calls[0].id
99102
command = pydantic_core.to_json(response).decode()
@@ -106,7 +109,6 @@ def _handle_response(self, completion, response):
106109
result_str = self.response_handler.parse_http_status_line(result)
107110
self._prompt_history.append(tool_message(result_str, tool_call_id))
108111
invalid_flags = ["recorded","Not a valid HTTP method", "404" ,"Client Error: Not Found"]
109-
print(f'result_str:{result_str}')
110112
if not result_str in invalid_flags or any(item in result_str for item in invalid_flags):
111113
self.prompt_engineer.found_endpoints = self.documentation_handler.update_openapi_spec(response, result)
112114
self.documentation_handler.write_openapi_to_yaml()
@@ -120,8 +122,7 @@ def _handle_response(self, completion, response):
120122
http_methods_dict[method].append(endpoint)
121123
self.prompt_engineer.endpoint_found_methods = http_methods_dict
122124
self.prompt_engineer.endpoint_methods = self.documentation_handler.endpoint_methods
123-
print(f'SCHEMAS:{self.prompt_engineer.schemas}')
124-
return self.all_http_methods_found()
125+
return self.all_http_methods_found(turn)
125126

126127

127128

src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .documentation_handler import DocumentationHandler
1+
from .openapi_specification_manager import OpenAPISpecificationManager
22
from .llm_handler import LLMHandler
33
from .response_handler import ResponseHandler
44
from .openapi_parser import OpenAPISpecificationParser

src/hackingBuddyGPT/usecases/web_api_testing/utils/documentation_handler.py renamed to src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_specification_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime
44
from hackingBuddyGPT.capabilities.yamlFile import YAMLFile
55

6-
class DocumentationHandler:
6+
class OpenAPISpecificationManager:
77
"""
88
Handles the generation and updating of an OpenAPI specification document based on dynamic API responses.
99
@@ -51,7 +51,7 @@ def __init__(self, llm_handler, response_handler):
5151
"yaml": YAMLFile()
5252
}
5353

54-
def partial_match(self, element, string_list):
54+
def is_partial_match(self, element, string_list):
5555
return any(element in string or string in element for string in string_list)
5656

5757
def update_openapi_spec(self, resp, result):
@@ -66,7 +66,7 @@ def update_openapi_spec(self, resp, result):
6666

6767
if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work
6868
self.check_openapi_spec(resp)
69-
if request.__class__.__name__ == 'HTTPRequest':
69+
elif request.__class__.__name__ == 'HTTPRequest':
7070
path = request.path
7171
method = request.method
7272
print(f'method: {method}')
@@ -107,7 +107,7 @@ def update_openapi_spec(self, resp, result):
107107

108108
if '1' not in path and x != "":
109109
endpoint_methods[path].append(method)
110-
elif self.partial_match(x, endpoints.keys()):
110+
elif self.is_partial_match(x, endpoints.keys()):
111111
path = f"/{x}"
112112
print(f'endpoint methods = {endpoint_methods}')
113113
print(f'new path:{path}')

src/hackingBuddyGPT/usecases/web_api_testing/utils/response_handler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,14 @@ def parse_http_status_line(self, status_line):
5050
"""
5151
if status_line == "Not a valid HTTP method":
5252
return status_line
53-
if status_line and " " in status_line:
54-
protocol, status_code, status_message = status_line.split(' ', 2)
55-
status_message = status_message.split("\r\n")[0]
53+
status_line = status_line.split('\r\n')[0]
54+
# Regular expression to match valid HTTP status lines
55+
match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line)
56+
if match:
57+
protocol, status_code, status_message = match.groups()
5658
return f'{status_code} {status_message}'
57-
raise ValueError("Invalid HTTP status line")
59+
else:
60+
raise ValueError("Invalid HTTP status line")
5861

5962
def extract_response_example(self, html_content):
6063
"""

tests/test_llm_handler.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model
4+
from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler
5+
6+
7+
class TestLLMHandler(unittest.TestCase):
    """Unit tests for LLMHandler, mainly its bookkeeping of created objects."""

    def setUp(self):
        """Create an LLMHandler backed by a mocked LLM and two mocked capabilities."""
        self.llm_mock = MagicMock()
        self.capabilities = {'cap1': MagicMock(), 'cap2': MagicMock()}
        self.llm_handler = LLMHandler(self.llm_mock, self.capabilities)

    # This test was previously disabled by wrapping it in a bare string literal,
    # which hides it from test reports. It is kept as an explicit skip so it
    # stays visible. NOTE(review): the reason it was disabled is not recorded;
    # verify the patch target and expected call arguments before re-enabling.
    @unittest.skip("disabled in the original test file; re-enable once verified")
    @patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model')
    def test_call_llm(self, mock_capabilities_to_action_model):
        """call_llm should forward the prompt and the action model to the LLM client."""
        prompt = [{'role': 'user', 'content': 'Hello, LLM!'}]
        response_mock = MagicMock()
        self.llm_mock.instructor.chat.completions.create_with_completion.return_value = response_mock

        # Mock the capabilities_to_action_model to return a dummy Pydantic model
        mock_model = MagicMock()
        mock_capabilities_to_action_model.return_value = mock_model

        response = self.llm_handler.call_llm(prompt)

        self.llm_mock.instructor.chat.completions.create_with_completion.assert_called_once_with(
            model=self.llm_mock.model,
            messages=prompt,
            response_model=mock_model
        )
        self.assertEqual(response, response_mock)

    def test_add_created_object(self):
        """An added object is stored under its object type."""
        created_object = MagicMock()
        object_type = 'test_type'

        self.llm_handler.add_created_object(created_object, object_type)

        self.assertIn(object_type, self.llm_handler.created_objects)
        self.assertIn(created_object, self.llm_handler.created_objects[object_type])

    def test_add_created_object_limit(self):
        """No more than 7 objects are kept per object type."""
        created_object = MagicMock()
        object_type = 'test_type'

        for _ in range(8):  # Exceed the limit of 7 objects
            self.llm_handler.add_created_object(created_object, object_type)

        self.assertEqual(len(self.llm_handler.created_objects[object_type]), 7)

    def test_get_created_objects(self):
        """get_created_objects returns the handler's created_objects mapping."""
        created_object = MagicMock()
        object_type = 'test_type'
        self.llm_handler.add_created_object(created_object, object_type)

        created_objects = self.llm_handler.get_created_objects()

        self.assertIn(object_type, created_objects)
        self.assertIn(created_object, created_objects[object_type])
        self.assertEqual(created_objects, self.llm_handler.created_objects)
59+
60+
if __name__ == "__main__":
61+
unittest.main()

0 commit comments

Comments
 (0)