Skip to content

Commit 9e0d59c

Browse files
authored
fix(trimmer): trim last message when allowPartial: true and strategy: first (#8287)
1 parent 9e1832a commit 9e0d59c

File tree

2 files changed

+65
-30
lines changed

2 files changed

+65
-30
lines changed

langchain-core/src/messages/tests/message_utils.test.ts

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import { it, describe, test, expect } from "@jest/globals";
1+
import { describe, expect, it, test } from "@jest/globals";
22
import { v4 } from "uuid";
3-
import {
4-
filterMessages,
5-
mergeMessageRuns,
6-
trimMessages,
7-
} from "../transformers.js";
83
import { AIMessage, AIMessageChunk } from "../ai.js";
4+
import { BaseMessage, MessageContent } from "../base.js";
95
import { ChatMessage } from "../chat.js";
106
import { HumanMessage } from "../human.js";
117
import { SystemMessage } from "../system.js";
128
import { ToolMessage } from "../tool.js";
13-
import { BaseMessage } from "../base.js";
9+
import {
10+
filterMessages,
11+
mergeMessageRuns,
12+
trimMessages,
13+
} from "../transformers.js";
1414
import {
1515
getBufferString,
1616
mapChatMessagesToStoredMessages,
@@ -141,7 +141,26 @@ describe("mergeMessageRuns", () => {
141141
});
142142

143143
describe("trimMessages can trim", () => {
144-
const messagesAndTokenCounterFactory = () => {
144+
const defaultCountTokensByMessageContent = (
145+
content: MessageContent
146+
): number => {
147+
// Model each message as 3 framing tokens before its content and 3 after it,
148+
// with every content block costing 4 tokens. A plain-string message is
149+
// therefore 3 + 4 + 3 = 10 tokens; an array adds 4 tokens per block.
150+
const defaultMsgPrefixLen = 3;
151+
const defaultContentLen = 4;
152+
const defaultMsgSuffixLen = 3;
153+
154+
const contentLen = Array.isArray(content)
155+
? content.length * defaultContentLen
156+
: defaultContentLen;
157+
158+
return defaultMsgPrefixLen + contentLen + defaultMsgSuffixLen;
159+
};
160+
161+
const messagesAndTokenCounterFactory = ({
162+
countTokensByMessageContent = defaultCountTokensByMessageContent,
163+
} = {}) => {
145164
const messages = [
146165
new SystemMessage(
147166
"This is a 4 token text. The full message is 10 tokens."
@@ -168,27 +187,10 @@ describe("trimMessages can trim", () => {
168187
];
169188

170189
const dummyTokenCounter = (messages: BaseMessage[]): number => {
171-
// treat each message like it adds 3 default tokens at the beginning
172-
// of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
173-
// per message.
174-
175-
const defaultContentLen = 4;
176-
const defaultMsgPrefixLen = 3;
177-
const defaultMsgSuffixLen = 3;
178-
179-
let count = 0;
180-
for (const msg of messages) {
181-
if (typeof msg.content === "string") {
182-
count +=
183-
defaultMsgPrefixLen + defaultContentLen + defaultMsgSuffixLen;
184-
}
185-
if (Array.isArray(msg.content)) {
186-
count +=
187-
defaultMsgPrefixLen +
188-
msg.content.length * defaultContentLen +
189-
defaultMsgSuffixLen;
190-
}
191-
}
190+
const count = messages.reduce(
191+
(count, msg) => count + countTokensByMessageContent(msg.content),
192+
0
193+
);
192194
console.log(count);
193195
return count;
194196
};
@@ -376,6 +378,39 @@ describe("trimMessages can trim", () => {
376378
]);
377379
});
378380

381+
it("First tokens, allowing partial messages, have to trim the last 10 characters of the last message", async () => {
382+
// For the purpose of this test, we'll override the dummy token counter to use `content.length`: characters for string content, number of blocks for array content.
383+
const { messages, dummyTokenCounter } = messagesAndTokenCounterFactory({
384+
countTokensByMessageContent: (content: MessageContent): number =>
385+
content.length,
386+
});
387+
388+
const totalCharacters = messages.reduce(
389+
(count, msg) => count + msg.content.length,
390+
0
391+
);
392+
393+
const trimmedMessages = await trimMessages(messages, {
394+
maxTokens: totalCharacters - 10,
395+
tokenCounter: dummyTokenCounter,
396+
strategy: "first",
397+
allowPartial: true,
398+
textSplitter: (text: string) => text.split(""),
399+
});
400+
401+
const trimmedMessagesContent = trimmedMessages.map((msg) => msg.content);
402+
expect(trimmedMessagesContent).toEqual([
403+
"This is a 4 token text. The full message is 10 tokens.",
404+
"This is a 4 token text. The full message is 10 tokens.",
405+
[
406+
{ type: "text", text: "This is the FIRST 4 token block." },
407+
{ type: "text", text: "This is the SECOND 4 token block." },
408+
],
409+
"This is a 4 token text. The full message is 10 tokens.",
410+
"This is a 4 token text. The full message is ",
411+
]);
412+
});
413+
379414
it("Last 30 tokens, including system message, not allowing partial messages", async () => {
380415
const { messages, dummyTokenCounter } = messagesAndTokenCounterFactory();
381416
const trimmedMessages = await trimMessages(messages, {

langchain-core/src/messages/transformers.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ async function _firstMaxTokens(
745745
break;
746746
}
747747
}
748-
if (idx < messagesCopy.length - 1 && partialStrategy) {
748+
if (idx < messagesCopy.length && partialStrategy) {
749749
let includedPartial = false;
750750
if (Array.isArray(messagesCopy[idx].content)) {
751751
const excluded = messagesCopy[idx];

0 commit comments

Comments
 (0)