Skip to content

Commit da2fb1f

Browse files
committed
Refine token size calculation and model selection in Coder class
Resolves #25. In this commit, we've made several adjustments to the `Coder` class in `aicodebot/coder.py` and `aicodebot/cli.py`. The token size calculation now includes a 5% buffer, down from 10%, to account for the occasional underestimation by the `tiktoken` library. The `get_token_length` method now defaults to the `gpt-4` model for token counting, and the debug output has been improved for readability. In `aicodebot/cli.py`, we've adjusted the `model_name` calculation in several methods to include `response_token_size` in the token count. This ensures that the selected model can handle the combined size of the request and response. In the `sidekick` method, we've also introduced a `memory_token_size` to allow for a decent history. These changes should improve the accuracy of model selection and prevent errors when the token count exceeds the model's limit.
1 parent df0eda4 commit da2fb1f

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

aicodebot/cli.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def alignment(response_token_size, verbose):
5959
logger.trace(f"Prompt: {prompt}")
6060

6161
# Set up the language model
62-
model_name = Coder.get_llm_model_name(Coder.get_token_length(prompt.template))
62+
model_name = Coder.get_llm_model_name(Coder.get_token_length(prompt.template) + response_token_size)
6363

6464
with Live(Markdown(""), auto_refresh=True) as live:
6565
llm = Coder.get_llm(
@@ -142,7 +142,7 @@ def commit(verbose, response_token_size, yes, skip_pre_commit, files): # noqa:
142142

143143
# Check the size of the diff context and adjust accordingly
144144
request_token_size = Coder.get_token_length(diff_context) + Coder.get_token_length(prompt.template)
145-
model_name = Coder.get_llm_model_name(request_token_size)
145+
model_name = Coder.get_llm_model_name(request_token_size + response_token_size)
146146
if model_name is None:
147147
raise click.ClickException(
148148
f"The diff is too large to generate a commit message ({request_token_size} tokens). 😢"
@@ -303,7 +303,7 @@ def debug(command, verbose):
303303

304304
# Set up the language model
305305
request_token_size = Coder.get_token_length(error_output) + Coder.get_token_length(prompt.template)
306-
model_name = Coder.get_llm_model_name(request_token_size)
306+
model_name = Coder.get_llm_model_name(request_token_size + DEFAULT_MAX_TOKENS)
307307
if model_name is None:
308308
raise click.ClickException(f"The output is too large to debug ({request_token_size} tokens). 😢")
309309

@@ -379,9 +379,8 @@ def review(commit, verbose, output_format, response_token_size, files):
379379
logger.trace(f"Prompt: {prompt}")
380380

381381
# Check the size of the diff context and adjust accordingly
382-
response_token_size = DEFAULT_MAX_TOKENS * 2
383382
request_token_size = Coder.get_token_length(diff_context) + Coder.get_token_length(prompt.template)
384-
model_name = Coder.get_llm_model_name(request_token_size)
383+
model_name = Coder.get_llm_model_name(request_token_size + response_token_size)
385384
if model_name is None:
386385
raise click.ClickException(f"The diff is too large to review ({request_token_size} tokens). 😢")
387386

@@ -432,8 +431,9 @@ def sidekick(request, verbose, response_token_size, files):
432431

433432
# Generate the prompt and set up the model
434433
prompt = get_prompt("sidekick")
434+
memory_token_size = response_token_size * 2 # Allow decent history
435435
request_token_size = Coder.get_token_length(prompt.template) + Coder.get_token_length(context)
436-
model_name = Coder.get_llm_model_name(request_token_size)
436+
model_name = Coder.get_llm_model_name(request_token_size + response_token_size + memory_token_size)
437437
if model_name is None:
438438
raise click.ClickException(
439439
f"The file context you supplied is too large ({request_token_size} tokens). 😢 Try again with less files."
@@ -446,7 +446,7 @@ def sidekick(request, verbose, response_token_size, files):
446446

447447
# Set up the chain
448448
memory = ConversationTokenBufferMemory(
449-
memory_key="chat_history", input_key="task", llm=llm, max_token_limit=DEFAULT_MAX_TOKENS
449+
memory_key="chat_history", input_key="task", llm=llm, max_token_limit=memory_token_size
450450
)
451451
chain = LLMChain(llm=llm, prompt=prompt, memory=memory, verbose=verbose)
452452
history_file = Path.home() / ".aicodebot_request_history"
@@ -457,8 +457,11 @@ def sidekick(request, verbose, response_token_size, files):
457457
if request:
458458
human_input = request
459459
else:
460-
human_input = input_prompt("🤖 ➤ ", history=FileHistory(history_file))
461-
if len(human_input) == 1:
460+
human_input = input_prompt("🤖 ➤ ", history=FileHistory(history_file)).strip()
461+
if not human_input:
462+
# Must have been spaces or blank line
463+
continue
464+
elif len(human_input) == 1:
462465
if human_input.lower() == "q":
463466
break
464467
elif human_input.lower() == "e":

aicodebot/coder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def get_llm_model_name(token_size=0):
115115
# Pull the list of supported engines from the OpenAI API for this key
116116
supported_engines = Coder.get_openai_supported_engines()
117117

118-
# For some unknown reason, tiktoken often underestimates the token size by ~10%, so let's buffer
119-
token_size = int(token_size * 1.1)
118+
# For some unknown reason, tiktoken often underestimates the token size by ~5%, so let's buffer
119+
token_size = int(token_size * 1.05)
120120

121121
for model, max_tokens in model_options.items():
122122
if model in supported_engines and token_size <= max_tokens:
@@ -130,12 +130,13 @@ def get_llm_model_name(token_size=0):
130130
return None
131131

132132
@staticmethod
133-
def get_token_length(text, model="gpt-3.5-turbo"):
133+
def get_token_length(text, model="gpt-4"):
134134
"""Get the number of tokens in a string using the tiktoken library."""
135135
encoding = tiktoken.encoding_for_model(model)
136136
tokens = encoding.encode(text)
137137
token_length = len(tokens)
138-
logger.debug(f"Token length for text {text[0:10]}...: {token_length}")
138+
short_text = text.strip()[0:20] + "..." if len(text) > 10 else text
139+
logger.debug(f"Token length for {short_text}: {token_length}")
139140
return token_length
140141

141142
@staticmethod

0 commit comments

Comments (0)