Skip to content

Commit 41f41d8

Browse files
committed
Output guardrails on streaming
1 parent 39092c2 commit 41f41d8

16 files changed

+553
-156
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER;
66
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MEMORY_ID;
77
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER;
8-
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS;
98
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.QUARKUS_INPUT_GUARDRAILS;
109
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.QUARKUS_OUTPUT_GUARDRAILS;
1110
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES;
@@ -41,13 +40,16 @@
4140
import java.util.stream.Collectors;
4241
import java.util.stream.Stream;
4342

43+
import javax.tools.Tool;
44+
4445
import jakarta.annotation.PreDestroy;
4546
import jakarta.enterprise.context.Dependent;
4647
import jakarta.enterprise.inject.spi.DeploymentException;
4748
import jakarta.enterprise.util.AnnotationLiteral;
4849
import jakarta.inject.Inject;
4950
import jakarta.interceptor.InterceptorBinding;
5051

52+
import org.eclipse.microprofile.config.ConfigProvider;
5153
import org.eclipse.microprofile.rest.client.inject.RestClient;
5254
import org.jboss.jandex.AnnotationInstance;
5355
import org.jboss.jandex.AnnotationTarget;
@@ -69,7 +71,6 @@
6971
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
7072

7173
import dev.langchain4j.guardrail.OutputGuardrail;
72-
import dev.langchain4j.agent.tool.Tool;
7374
import dev.langchain4j.model.chat.request.json.JsonSchema;
7475
import dev.langchain4j.service.IllegalConfigurationException;
7576
import dev.langchain4j.service.Moderate;
@@ -108,6 +109,8 @@
108109
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
109110
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
110111
import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig;
112+
import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser;
113+
import io.quarkiverse.langchain4j.runtime.types.TypeUtil;
111114
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
112115
import io.quarkus.arc.Arc;
113116
import io.quarkus.arc.ArcContainer;
@@ -1602,29 +1605,40 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
16021605
boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools,
16031606
methodToolClassInfo.keySet());
16041607

1608+
var methodReturnTypeSignature = returnTypeSignature(method.returnType(),
1609+
new TypeArgMapper(method.declaringClass(), index));
1610+
16051611
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
1606-
userMessageInfo, memoryIdParamPosition, requiresModeration,
1607-
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
1612+
userMessageInfo, memoryIdParamPosition, requiresModeration, methodReturnTypeSignature,
16081613
overrideChatModelParamPosition, metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo,
16091614
methodToolClassInfo, methodMcpClientNames, switchToWorkerThreadForToolExecution, quarkusInputGuardrailClassess,
16101615
quarkusOutputGuardrailClasses,
16111616
accumulatorClassName, responseAugmenterClassName, gatherInputGuardrails(method),
1612-
gatherOutputGuardrails(method));
1617+
gatherOutputGuardrails(method, methodReturnTypeSignature));
16131618
}
16141619

16151620
private static InputGuardrailsLiteral gatherInputGuardrails(MethodInfo method) {
16161621
return new InputGuardrailsLiteral(
16171622
gatherGuardrails(getGuardrailsAnnotation(method, LangChain4jDotNames.INPUT_GUARDRAILS)));
16181623
}
16191624

1620-
private static OutputGuardrailsLiteral gatherOutputGuardrails(MethodInfo method) {
1625+
private static OutputGuardrailsLiteral gatherOutputGuardrails(MethodInfo method, String methodReturnTypeSignature) {
16211626
var annotationInstance = getGuardrailsAnnotation(method, LangChain4jDotNames.OUTPUT_GUARDRAILS);
1622-
var maxRetries = annotationInstance
1627+
var methodReturnsMulti = TypeUtil.isMulti(TypeSignatureParser.parse(methodReturnTypeSignature));
1628+
var maxRetriesAsSetByConfig = annotationInstance
16231629
.map(v -> v.value("maxRetries"))
16241630
.map(AnnotationValue::asInt)
1631+
.or(() -> ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.guardrails.max-retries",
1632+
Integer.class))
16251633
.orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
16261634

1627-
return new OutputGuardrailsLiteral(gatherGuardrails(annotationInstance), maxRetries);
1635+
// If the method returns a Multi, we don't want the guardrail service to perform any retries on its own,
1636+
// so we pass 0 as its maxRetries. The configured value is still carried separately so that the Multi itself
1637+
// can perform the retries, using the count set on the annotation or, failing that, through config.
1638+
return new OutputGuardrailsLiteral(
1639+
gatherGuardrails(annotationInstance),
1640+
methodReturnsMulti ? 0 : maxRetriesAsSetByConfig,
1641+
maxRetriesAsSetByConfig);
16281642
}
16291643

16301644
private static Optional<AnnotationInstance> getGuardrailsAnnotation(MethodInfo methodInfo, DotName annotation) {

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ public List<String> asClassNames() {
159159
}
160160
}
161161

162-
public record DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries) {
162+
public record DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries,
163+
int actualMaxRetries) {
164+
public DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries) {
165+
this(outputGuardrailClassInfos, maxRetries, maxRetries);
166+
}
167+
163168
public List<String> asClassNames() {
164169
return this.outputGuardrailClassInfos.stream()
165170
.map(classInfo -> classInfo.name().toString())

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnStreamedResponseTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ void testThatGuardrailOrderIsCorrect() {
7575
@ActivateRequestContext
7676
void testThatRetryRestartTheChain() {
7777
aiService.failingFirstTwo("1", "foo").collect().asList().await().indefinitely();
78-
;
7978
assertThat(firstGuardrail.spy()).isEqualTo(2);
8079
assertThat(secondGuardrail.spy()).isEqualTo(1);
8180
assertThat(failingGuardrail.spy()).isEqualTo(2);
@@ -105,7 +104,6 @@ void testThatGuardrailOrderIsCorrectWithPassThroughAccumulator() {
105104
@ActivateRequestContext
106105
void testThatRetryRestartTheChainWithPassThroughAccumulator() {
107106
aiService.failingFirstTwoWithPassThroughAccumulator("1", "foo").collect().asList().await().indefinitely();
108-
;
109107
assertThat(firstGuardrail.spy()).isEqualTo(4);
110108
assertThat(secondGuardrail.spy()).isEqualTo(3);
111109
assertThat(failingGuardrail.spy()).isEqualTo(4);

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ public class OutputGuardrailOnStreamedResponseValidationTest {
4949
@Inject
5050
OKGuardrail okGuardrail;
5151

52+
@Inject
53+
RewritingGuardrail rewriting;
54+
5255
@Test
5356
@ActivateRequestContext
5457
void testOk() {
@@ -60,7 +63,7 @@ void testOk() {
6063
@ActivateRequestContext
6164
void testOkWithPassThroughAccumulator() {
6265
aiService.okWithPassThroughAccumulator("1").collect().asList().await().indefinitely();
63-
assertThat(okGuardrail.spy()).isEqualTo(1);
66+
assertThat(okGuardrail.spy()).isEqualTo(3);
6467
}
6568

6669
@Test
@@ -85,14 +88,17 @@ void testKOWithPassThroughAccumulator() {
8588
@Test
8689
@ActivateRequestContext
8790
void testRetryOk() {
88-
aiService.retry("3").collect().asList().await().indefinitely();
91+
assertThat(aiService.retry("3").collect().asList().await().indefinitely())
92+
.singleElement()
93+
.isEqualTo("Hi! World!");
8994
assertThat(retry.spy()).isEqualTo(2);
9095
}
9196

9297
@Test
9398
@ActivateRequestContext
9499
void testRetryOkWithPassThroughAccumulator() {
95-
aiService.retryWithPassThroughAccumulator("3").collect().asList().await().indefinitely();
100+
assertThat(aiService.retryWithPassThroughAccumulator("3").collect().asList().await().indefinitely())
101+
.containsExactly("Hi!", " ", "World!");
96102
assertThat(retry.spy()).isEqualTo(4); // "Hi!", "Hi!" (retry), " ", "World!"
97103
}
98104

@@ -145,10 +151,11 @@ void testFatalExceptionWithPassThroughAccumulator() {
145151

146152
@Test
147153
@ActivateRequestContext
148-
void testRewritingWhileStreamingIsNotAllowed() {
149-
assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely())
150-
.isInstanceOf(GuardrailException.class)
151-
.hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed");
154+
void rewritingWhileStreaming() throws InterruptedException {
155+
assertThat(aiService.rewriting("1").collect().asList().await().indefinitely())
156+
.singleElement()
157+
.isEqualTo("Hi! World!,1");
158+
assertThat(rewriting.spy()).isEqualTo(1);
152159
}
153160

154161
@RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
@@ -273,7 +280,6 @@ public int spy() {
273280

274281
@RequestScoped
275282
public static class KOFatalGuardrail implements OutputGuardrail {
276-
277283
AtomicInteger spy = new AtomicInteger(0);
278284

279285
@Override
@@ -289,12 +295,18 @@ public int spy() {
289295

290296
@RequestScoped
291297
public static class RewritingGuardrail implements OutputGuardrail {
298+
AtomicInteger spy = new AtomicInteger(0);
292299

293300
@Override
294301
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
302+
spy.incrementAndGet();
295303
String text = responseFromLLM.text();
296304
return successWith(text + ",1");
297305
}
306+
307+
public int spy() {
308+
return spy.get();
309+
}
298310
}
299311

300312
public static class MyChatModelSupplier implements Supplier<StreamingChatModel> {

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.test.guardrails;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
45
import static org.assertj.core.api.Assertions.assertThatThrownBy;
56

67
import java.util.concurrent.atomic.AtomicInteger;
@@ -19,6 +20,7 @@
1920
import dev.langchain4j.data.message.AiMessage;
2021
import dev.langchain4j.guardrail.GuardrailException;
2122
import dev.langchain4j.guardrail.OutputGuardrail;
23+
import dev.langchain4j.guardrail.OutputGuardrailException;
2224
import dev.langchain4j.guardrail.OutputGuardrailResult;
2325
import dev.langchain4j.memory.ChatMemory;
2426
import dev.langchain4j.memory.chat.ChatMemoryProvider;
@@ -91,6 +93,19 @@ void testFatalException() {
9193
assertThat(fatal.spy()).isEqualTo(1);
9294
}
9395

96+
@Test
97+
@ActivateRequestContext
98+
void noRetries() {
99+
MyChatModelSupplier.CHAT_MODEL.spy.set(0);
100+
assertThatExceptionOfType(OutputGuardrailException.class)
101+
.isThrownBy(() -> aiService.noRetry("6"))
102+
.withMessageContaining(
103+
"Output validation failed. The guardrails have reached the maximum number of retries.");
104+
105+
assertThat(MyChatModelSupplier.CHAT_MODEL.spy()).isEqualTo(1);
106+
assertThat(retry.spy()).isEqualTo(1);
107+
}
108+
94109
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
95110
public interface MyAiService {
96111

@@ -110,6 +125,10 @@ public interface MyAiService {
110125
@OutputGuardrails(RetryingButFailGuardrail.class)
111126
String retryButFail(@MemoryId String mem);
112127

128+
@UserMessage("Say Hi!")
129+
@OutputGuardrails(value = RetryingGuardrail.class, maxRetries = 0)
130+
String noRetry(@MemoryId String mem);
131+
113132
@UserMessage("Say Hi!")
114133
@OutputGuardrails(KOFatalGuardrail.class)
115134
String fatal(@MemoryId String mem);
@@ -147,7 +166,7 @@ public int spy() {
147166
}
148167
}
149168

150-
@ApplicationScoped
169+
@RequestScoped
151170
public static class RetryingGuardrail implements OutputGuardrail {
152171

153172
AtomicInteger spy = new AtomicInteger(0);
@@ -199,19 +218,26 @@ public int spy() {
199218
}
200219

201220
public static class MyChatModelSupplier implements Supplier<ChatModel> {
221+
static final MyChatModel CHAT_MODEL = new MyChatModel();
202222

203223
@Override
204224
public ChatModel get() {
205-
return new MyChatModel();
225+
return CHAT_MODEL;
206226
}
207227
}
208228

209229
public static class MyChatModel implements ChatModel {
230+
private final AtomicInteger spy = new AtomicInteger(0);
210231

211232
@Override
212233
public ChatResponse doChat(ChatRequest request) {
234+
spy.incrementAndGet();
213235
return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
214236
}
237+
238+
public int spy() {
239+
return spy.get();
240+
}
215241
}
216242

217243
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseValidationTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
3434
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
3535
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
36+
import io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailOnStreamedResponseValidationTest.OKGuardrail;
3637
import io.quarkus.test.QuarkusUnitTest;
3738
import io.smallrye.mutiny.Multi;
3839

@@ -50,16 +51,21 @@ public class QuarkusOutputGuardrailOnStreamedResponseValidationTest {
5051
@Inject
5152
MyAiService aiService;
5253

54+
@Inject
55+
OKGuardrail okGuardrail;
56+
5357
@Test
5458
@ActivateRequestContext
5559
void testOk() {
5660
aiService.ok("1").collect().asList().await().indefinitely();
61+
assertThat(okGuardrail.spy()).isEqualTo(1);
5762
}
5863

5964
@Test
6065
@ActivateRequestContext
6166
void testOkWithPassThroughAccumulator() {
6267
aiService.okWithPassThroughAccumulator("1").collect().asList().await().indefinitely();
68+
assertThat(okGuardrail.spy()).isEqualTo(3);
6369
}
6470

6571
@Test

0 commit comments

Comments
 (0)