Skip to content

Commit 7f9f43a

Browse files
authored
Merge pull request #73 from ipa-lab/explorative_refactoring
Explorative refactoring
2 parents 4ea46ea + 48f7852 commit 7f9f43a

File tree

15 files changed

+361
-320
lines changed

15 files changed

+361
-320
lines changed

src/hackingBuddyGPT/usecases/agents.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,27 @@
44
from rich.panel import Panel
55
from typing import Dict
66

7+
from hackingBuddyGPT.usecases.base import Logger
78
from hackingBuddyGPT.utils import llm_util
8-
99
from hackingBuddyGPT.capabilities.capability import Capability, capabilities_to_simple_text_handler
10-
from .common_patterns import RoundBasedUseCase
10+
from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection
11+
1112

1213
@dataclass
13-
class Agent(RoundBasedUseCase, ABC):
14+
class Agent(ABC):
1415
_capabilities: Dict[str, Capability] = field(default_factory=dict)
1516
_default_capability: Capability = None
17+
_log: Logger = None
18+
19+
llm: OpenAIConnection = None
1620

1721
def init(self):
18-
super().init()
22+
pass
23+
24+
# callback
25+
@abstractmethod
26+
def perform_round(self, turn: int) -> bool:
27+
pass
1928

2029
def add_capability(self, cap: Capability, default: bool = False):
2130
self._capabilities[cap.get_name()] = cap
@@ -29,6 +38,7 @@ def get_capability_block(self) -> str:
2938
capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities)
3039
return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values())
3140

41+
3242
@dataclass
3343
class AgentWorldview(ABC):
3444

@@ -40,6 +50,7 @@ def to_template(self):
4050
def update(self, capability, cmd, result):
4151
pass
4252

53+
4354
class TemplatedAgent(Agent):
4455

4556
_state: AgentWorldview = None
@@ -59,7 +70,7 @@ def set_template(self, template:str):
5970
def perform_round(self, turn:int) -> bool:
6071
got_root : bool = False
6172

62-
with self.console.status("[bold green]Asking LLM for a new command..."):
73+
with self._log.console.status("[bold green]Asking LLM for a new command..."):
6374
# TODO output/log state
6475
options = self._state.to_template()
6576
options.update({
@@ -70,16 +81,16 @@ def perform_round(self, turn:int) -> bool:
7081
answer = self.llm.get_response(self._template, **options)
7182
cmd = llm_util.cmd_output_fixer(answer.result)
7283

73-
with self.console.status("[bold green]Executing that command..."):
74-
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
84+
with self._log.console.status("[bold green]Executing that command..."):
85+
self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
7586
capability = self.get_capability(cmd.split(" ", 1)[0])
7687
result, got_root = capability(cmd)
7788

7889
# log and output the command and its result
79-
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
90+
self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
8091
self._state.update(capability, cmd, result)
8192
# TODO output/log new state
82-
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
93+
self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
8394

8495
# if we got root, we can stop the loop
8596
return got_root

src/hackingBuddyGPT/usecases/base.py

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
import abc
22
import argparse
3-
from dataclasses import dataclass, field
3+
import typing
4+
from dataclasses import dataclass
5+
from rich.panel import Panel
46
from typing import Dict, Type
57

6-
from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters
8+
from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, transparent
9+
from hackingBuddyGPT.utils.console.console import Console
10+
from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage
711

12+
13+
@dataclass
14+
class Logger:
15+
log_db: DbStorage
16+
console: Console
17+
tag: str = ""
18+
run_id: int = 0
19+
20+
21+
@dataclass
822
class UseCase(abc.ABC):
923
"""
1024
A UseCase is the combination of tools and capabilities to solve a specific problem.
@@ -16,13 +30,21 @@ class UseCase(abc.ABC):
1630
so that they can be automatically discovered and run from the command line.
1731
"""
1832

33+
log_db: DbStorage
34+
console: Console
35+
tag: str = ""
36+
37+
_run_id: int = 0
38+
_log: Logger = None
39+
1940
def init(self):
2041
"""
2142
The init method is called before the run method. It is used to initialize the UseCase, and can be used to
2243
perform any dynamic setup that is needed before the run method is called. One of the most common use cases is
2344
setting up the llm capabilities from the tools that were injected.
2445
"""
25-
pass
46+
self._run_id = self.log_db.create_new_run(self.get_name(), self.tag)
47+
self._log = Logger(self.log_db, self.console, self.tag, self._run_id)
2648

2749
@abc.abstractmethod
2850
def run(self):
@@ -33,6 +55,46 @@ def run(self):
3355
"""
3456
pass
3557

58+
@abc.abstractmethod
59+
def get_name(self) -> str:
60+
"""
61+
This method should return the name of the use case. It is used for logging and debugging purposes.
62+
"""
63+
pass
64+
65+
66+
# this runs the main loop for a bounded amount of turns or until root was achieved
67+
@dataclass
68+
class AutonomousUseCase(UseCase, abc.ABC):
69+
max_turns: int = 10
70+
71+
_got_root: bool = False
72+
73+
@abc.abstractmethod
74+
def perform_round(self, turn: int):
75+
pass
76+
77+
def run(self):
78+
turn = 1
79+
while turn <= self.max_turns and not self._got_root:
80+
self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}")
81+
82+
self._got_root = self.perform_round(turn)
83+
84+
# finish turn and commit logs to storage
85+
self._log.log_db.commit()
86+
turn += 1
87+
88+
# write the final result to the database and console
89+
if self._got_root:
90+
self._log.log_db.run_was_success(self._run_id, turn)
91+
self._log.console.print(Panel("[bold green]Got Root!", title="Run finished"))
92+
else:
93+
self._log.log_db.run_was_failure(self._run_id, turn)
94+
self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished"))
95+
96+
return self._got_root
97+
3698

3799
@dataclass
38100
class _WrappedUseCase:
@@ -56,17 +118,55 @@ def __call__(self, args: argparse.Namespace):
56118
use_cases: Dict[str, _WrappedUseCase] = dict()
57119

58120

59-
def use_case(name: str, desc: str):
60-
"""
61-
By wrapping a UseCase with this decorator, it will be automatically discoverable and can be run from the command
62-
line.
63-
"""
121+
T = typing.TypeVar("T")
122+
123+
124+
class AutonomousAgentUseCase(AutonomousUseCase, typing.Generic[T]):
125+
agent: T = None
126+
127+
def perform_round(self, turn: int):
128+
raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic")
129+
130+
def get_name(self) -> str:
131+
raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic")
132+
133+
@classmethod
134+
def __class_getitem__(cls, item):
135+
item = dataclass(item)
136+
item.__parameters__ = get_class_parameters(item)
64137

65-
def inner(cls: Type[UseCase]):
138+
class AutonomousAgentUseCase(AutonomousUseCase):
139+
agent: transparent(item) = None
140+
141+
def init(self):
142+
super().init()
143+
self.agent._log = self._log
144+
self.agent.init()
145+
146+
def get_name(self) -> str:
147+
return self.__class__.__name__
148+
149+
def perform_round(self, turn: int):
150+
return self.agent.perform_round(turn)
151+
152+
constructed_class = dataclass(AutonomousAgentUseCase)
153+
154+
return constructed_class
155+
156+
157+
def use_case(description):
158+
def inner(cls):
159+
name = cls.__name__.removesuffix("UseCase")
66160
if name in use_cases:
67161
raise IndexError(f"Use case with name {name} already exists")
68-
use_cases[name] = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name))
162+
use_cases[name] = _WrappedUseCase(name, description, cls, get_class_parameters(cls))
163+
return inner
69164

70-
return cls
71165

72-
return inner
166+
def register_use_case(name: str, description: str, use_case: Type[UseCase]):
167+
"""
168+
This function is used to register a UseCase that was created manually, and not through the use_case decorator.
169+
"""
170+
if name in use_cases:
171+
raise IndexError(f"Use case with name {name} already exists")
172+
use_cases[name] = _WrappedUseCase(name, description, use_case, get_class_parameters(use_case))

src/hackingBuddyGPT/usecases/common_patterns.py

Lines changed: 0 additions & 62 deletions
This file was deleted.
Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
import pathlib
2-
from dataclasses import dataclass, field
32
from mako.template import Template
43
from rich.panel import Panel
54

65
from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential
76
from hackingBuddyGPT.utils import SSHConnection, llm_util
8-
from hackingBuddyGPT.usecases.base import use_case
7+
from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase
98
from hackingBuddyGPT.usecases.agents import Agent
109
from hackingBuddyGPT.utils.cli_history import SlidingCliHistory
1110

1211
template_dir = pathlib.Path(__file__).parent
1312
template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt"))
1413

15-
@use_case("minimal_linux_privesc", "Showcase Minimal Linux Priv-Escalation")
16-
@dataclass
14+
1715
class MinimalLinuxPrivesc(Agent):
1816

1917
conn: SSHConnection = None
20-
2118
_sliding_history: SlidingCliHistory = None
2219

2320
def init(self):
@@ -27,25 +24,30 @@ def init(self):
2724
self.add_capability(SSHTestCredential(conn=self.conn))
2825
self._template_size = self.llm.count_tokens(template_next_cmd.source)
2926

30-
def perform_round(self, turn):
31-
got_root : bool = False
27+
def perform_round(self, turn: int) -> bool:
28+
got_root: bool = False
3229

33-
with self.console.status("[bold green]Asking LLM for a new command..."):
30+
with self._log.console.status("[bold green]Asking LLM for a new command..."):
3431
# get as much history as fits into the target context size
3532
history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size)
3633

3734
# get the next command from the LLM
3835
answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn)
3936
cmd = llm_util.cmd_output_fixer(answer.result)
4037

41-
with self.console.status("[bold green]Executing that command..."):
42-
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
43-
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)
38+
with self._log.console.status("[bold green]Executing that command..."):
39+
self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
40+
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)
4441

4542
# log and output the command and its result
46-
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
43+
self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
4744
self._sliding_history.add_command(cmd, result)
48-
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
45+
self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
4946

5047
# if we got root, we can stop the loop
5148
return got_root
49+
50+
51+
@use_case("Showcase Minimal Linux Priv-Escalation")
52+
class MinimalLinuxPrivescUseCase(AutonomousAgentUseCase[MinimalLinuxPrivesc]):
53+
pass

0 commit comments

Comments
 (0)