diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index c5f433bfa..d69bec07c 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -3,10 +3,10 @@ import static dev.langchain4j.service.IllegalConfigurationException.illegalConfiguration; import static io.quarkiverse.langchain4j.deployment.ExceptionUtil.illegalConfigurationForMethod; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.BEAN_IF_EXISTS_RETRIEVAL_AUGMENTOR_SUPPLIER; -import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.INPUT_GUARDRAILS; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MEMORY_ID; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER; -import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.QUARKUS_INPUT_GUARDRAILS; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.QUARKUS_OUTPUT_GUARDRAILS; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V; @@ -38,6 +38,9 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.Stream; + +import javax.tools.Tool; import jakarta.annotation.PreDestroy; import jakarta.enterprise.context.Dependent; @@ -46,6 +49,7 @@ import jakarta.inject.Inject; import jakarta.interceptor.InterceptorBinding; +import 
org.eclipse.microprofile.config.ConfigProvider; import org.eclipse.microprofile.rest.client.inject.RestClient; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTarget; @@ -66,7 +70,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; -import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.guardrail.OutputGuardrail; import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.service.IllegalConfigurationException; import dev.langchain4j.service.Moderate; @@ -75,6 +79,8 @@ import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.deployment.DeclarativeAiServiceBuildItem.DeclarativeAiServiceInputGuardrails; +import io.quarkiverse.langchain4j.deployment.DeclarativeAiServiceBuildItem.DeclarativeAiServiceOutputGuardrails; import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig; import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo; import io.quarkiverse.langchain4j.deployment.items.AiServicesMethodBuildItem; @@ -83,8 +89,9 @@ import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailsLiteral; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailsLiteral; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; @@ -101,6 +108,9 @@ import io.quarkiverse.langchain4j.runtime.aiservice.MetricsTimedWrapper; import 
io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper; +import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig; +import io.quarkiverse.langchain4j.runtime.types.TypeSignatureParser; +import io.quarkiverse.langchain4j.runtime.types.TypeUtil; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; @@ -440,7 +450,10 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, toolProviderClassName, beanName(declarativeAiServiceClassInfo), toolHallucinationStrategy(instance), + classInputGuardrails(declarativeAiServiceClassInfo, index), + classOutputGuardrails(declarativeAiServiceClassInfo, index), maxSequentialToolInvocations)); + } toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos)); @@ -521,6 +534,40 @@ private DotName chatMemoryProviderSupplierClassDotName(BuildProducer a.value("maxRetries")) + .map(AnnotationValue::asInt) + .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT); + + return new DeclarativeAiServiceOutputGuardrails(classGuardrails(outputGuardrailsAnnotation, index), maxRetries); + } + + private static List classGuardrails(Optional annotation, IndexView index) { + return gatherGuardrailsStream(annotation) + .map(Type::name) + .map(index::getClassByName) + .toList(); + } + + private static InputGuardrailsLiteral classInputGuardrails(DeclarativeAiServiceBuildItem declarativeAiServiceBuildItem) { + return new InputGuardrailsLiteral(declarativeAiServiceBuildItem.getInputGuardrails().asClassNames()); + } + + private static OutputGuardrailsLiteral classOutputGuardrails(DeclarativeAiServiceBuildItem declarativeAiServiceBuildItem) { + return new OutputGuardrailsLiteral( + declarativeAiServiceBuildItem.getOutputGuardrails().asClassNames(), + declarativeAiServiceBuildItem.getOutputGuardrails().maxRetries()); + } + private static List tools(AnnotationInstance instance, 
IndexView index) { AnnotationValue toolsInstance = instance.value("tools"); if (toolsInstance != null) { @@ -756,6 +803,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, injectModerationModelBean, injectImageModel, toolHallucinationStrategyClassName, + classInputGuardrails(bi), + classOutputGuardrails(bi), maxSequentialToolInvocations))) .setRuntimeInit() .addQualifier() @@ -866,6 +915,11 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, allToolProviders.add(toolProvider); } + configurator + .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, + new Type[] { ClassType.create(io.quarkiverse.langchain4j.guardrails.OutputGuardrail.class) }, null)) + .done(); + configurator .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, new Type[] { ClassType.create(OutputGuardrail.class) }, null)) @@ -913,8 +967,23 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, public void markUsedGuardRailsUnremovable(List methods, BuildProducer unremovableProducer) { for (AiServicesMethodBuildItem method : methods) { - List list = new ArrayList<>(method.getOutputGuardrails()); - list.addAll(method.getInputGuardrails()); + List list = new ArrayList<>(method.getQuarkusOutputGuardrailClassNames()); + list.addAll(method.getQuarkusInputGuardrailClassNames()); + + method.getInputGuardrails() + .map(InputGuardrailsLiteral::value) + .map(Arrays::stream) + .orElseGet(Stream::of) + .map(Class::getName) + .forEach(list::add); + + method.getOutputGuardrails() + .map(OutputGuardrailsLiteral::value) + .map(Arrays::stream) + .orElseGet(Stream::of) + .map(Class::getName) + .forEach(list::add); + for (String cn : list) { unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn))); } @@ -994,8 +1063,23 @@ public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished, BuildProducer errors) { for (AiServicesMethodBuildItem method : methods) { - List list = new 
ArrayList<>(method.getOutputGuardrails()); - list.addAll(method.getInputGuardrails()); + List list = new ArrayList<>(method.getQuarkusOutputGuardrailClassNames()); + list.addAll(method.getQuarkusInputGuardrailClassNames()); + + method.getInputGuardrails() + .map(InputGuardrailsLiteral::value) + .map(Arrays::stream) + .orElseGet(Stream::of) + .map(Class::getName) + .forEach(list::add); + + method.getOutputGuardrails() + .map(OutputGuardrailsLiteral::value) + .map(Arrays::stream) + .orElseGet(Stream::of) + .map(Class::getName) + .forEach(list::add); + for (String cn : list) { if (synthesisFinished.beanStream().withBeanType(DotName.createSimple(cn)).isEmpty()) { errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( @@ -1023,12 +1107,23 @@ public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished, method.getMethodInfo().name())))); } - // Check that the method have output guardrails - if (method.getOutputGuardrails().isEmpty()) { - errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( - new DeploymentException("OutputGuardrailAccumulator used without OutputGuardrails in method `%s.%s`" - .formatted(method.getMethodInfo().declaringClass().toString(), - method.getMethodInfo().name())))); + // Check that the method has output guardrails + if (method.getQuarkusOutputGuardrailClassNames().isEmpty() && !method.hasOutputGuardrails()) { + if (method.getQuarkusOutputGuardrailClassNames().isEmpty()) { + errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new DeploymentException( + "OutputGuardrailAccumulator used without io.quarkiverse.langchain4j.guardrails.OutputGuardrails in method `%s.%s`" + .formatted(method.getMethodInfo().declaringClass().toString(), + method.getMethodInfo().name())))); + } + + if (!method.hasOutputGuardrails()) { + errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new DeploymentException( + "OutputGuardrailAccumulator used without 
dev.langchain4j.service.guardrail.OutputGuardrails in method `%s.%s`" + .formatted(method.getMethodInfo().declaringClass().toString(), + method.getMethodInfo().name())))); + } } } } @@ -1315,8 +1410,10 @@ public void handleAiServices( mc.returnValue(resultHandle); aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo, - methodCreateInfo.getInputGuardrailsClassNames(), - methodCreateInfo.getOutputGuardrailsClassNames(), + methodCreateInfo.getQuarkusInputGuardrailsClassNames(), + methodCreateInfo.getQuarkusOutputGuardrailsClassNames(), + methodCreateInfo.getInputGuardrails(), + methodCreateInfo.getOutputGuardrails(), methodCreateInfo.getResponseAugmenterClassName(), methodCreateInfo)); } @@ -1341,7 +1438,21 @@ public void handleAiServices( } } - perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName)); + + var aiServiceBuildItem = declarativeAiServiceItems.stream() + .filter(bi -> bi.getServiceClassInfo().equals(iface)) + .findFirst(); + + var inputGuardrails = aiServiceBuildItem + .map(AiServicesProcessor::classInputGuardrails) + .orElse(null); + + var outputGuardrails = aiServiceBuildItem + .map(AiServicesProcessor::classOutputGuardrails) + .orElse(null); + + perClassMetadata.put(ifaceName, + new AiServiceClassCreateInfo(perMethodMetadata, implClassName, inputGuardrails, outputGuardrails)); // make the constructor accessible reflectively since that is how we create the instance reflectiveClassProducer.produce(ReflectiveClassBuildItem.builder(implClassName).build()); } @@ -1480,8 +1591,36 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( List methodMcpClientNames = gatherMethodMcpClientNames(method); - List outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS); - List inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS); + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + 
*/ + @Deprecated(forRemoval = true) + List quarkusOutputGuardrailClasses = AiServicesMethodBuildItem.gatherGuardrails(method, + QUARKUS_OUTPUT_GUARDRAILS); + + var guardrailDeprecationWarning = """ + + ================== DEPRECATION WARNING ================== + The following Quarkus-specific %s guardrail classes have been discovered on the method (%s) in the class (%s). Please move to the new upstream guardrails. + %s + """; + + if (!quarkusOutputGuardrailClasses.isEmpty()) { + log.warnf(guardrailDeprecationWarning, "output", method, method.declaringClass(), + String.join("\n", quarkusOutputGuardrailClasses)); + } + + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + List quarkusInputGuardrailClassess = AiServicesMethodBuildItem.gatherGuardrails(method, + QUARKUS_INPUT_GUARDRAILS); + + if (!quarkusInputGuardrailClassess.isEmpty()) { + log.warnf(guardrailDeprecationWarning, "input", method, method.declaringClass(), + String.join("\n", quarkusInputGuardrailClassess)); + } String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method); @@ -1491,12 +1630,64 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassInfo.keySet()); + var methodReturnTypeSignature = returnTypeSignature(method.returnType(), + new TypeArgMapper(method.declaringClass(), index)); + return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, - userMessageInfo, memoryIdParamPosition, requiresModeration, - returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)), + userMessageInfo, memoryIdParamPosition, requiresModeration, methodReturnTypeSignature, overrideChatModelParamPosition, metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, - methodToolClassInfo, 
methodMcpClientNames, switchToWorkerThreadForToolExecution, inputGuardrails, - outputGuardrails, accumulatorClassName, responseAugmenterClassName); + methodToolClassInfo, methodMcpClientNames, switchToWorkerThreadForToolExecution, quarkusInputGuardrailClassess, + quarkusOutputGuardrailClasses, + accumulatorClassName, responseAugmenterClassName, gatherInputGuardrails(method), + gatherOutputGuardrails(method, methodReturnTypeSignature)); + } + + private static InputGuardrailsLiteral gatherInputGuardrails(MethodInfo method) { + return new InputGuardrailsLiteral( + gatherGuardrails(getGuardrailsAnnotation(method, LangChain4jDotNames.INPUT_GUARDRAILS))); + } + + private static OutputGuardrailsLiteral gatherOutputGuardrails(MethodInfo method, String methodReturnTypeSignature) { + var annotationInstance = getGuardrailsAnnotation(method, LangChain4jDotNames.OUTPUT_GUARDRAILS); + var methodReturnsMulti = TypeUtil.isMulti(TypeSignatureParser.parse(methodReturnTypeSignature)); + var maxRetriesAsSetByConfig = annotationInstance + .map(v -> v.value("maxRetries")) + .map(AnnotationValue::asInt) + .or(() -> ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.guardrails.max-retries", + Integer.class)) + .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT); + + // If the method returns a Multi, then we don't want the guardrail service to perform any retries on its own + // Instead we'll store the value as a config value and we'll have the multi itself perform the retries + // based on the number of retries configured, either on the annotation or set through config + return new OutputGuardrailsLiteral( + gatherGuardrails(annotationInstance), + methodReturnsMulti ? 
0 : maxRetriesAsSetByConfig, + maxRetriesAsSetByConfig); + } + + private static Optional getGuardrailsAnnotation(MethodInfo methodInfo, DotName annotation) { + return Optional.ofNullable(methodInfo.annotation(annotation)) + .or(() -> getGuardrailsAnnotation(methodInfo.declaringClass(), annotation)); + } + + private static Optional getGuardrailsAnnotation(ClassInfo classInfo, DotName annotation) { + return Optional.ofNullable(classInfo.declaredAnnotation(annotation)); + } + + private static Stream gatherGuardrailsStream(Optional annotation) { + return annotation + .map(AnnotationInstance::value) + .map(AnnotationValue::asClassArray) + .map(Arrays::stream) + .orElseGet(Stream::of) + .distinct(); + } + + private static List gatherGuardrails(Optional annotation) { + return gatherGuardrailsStream(annotation) + .map(t -> t.name().toString()) + .toList(); } private Optional jsonSchemaFrom(java.lang.reflect.Type returnType) { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 84af71fd7..fee2affd3 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -31,6 +31,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final String moderationModelName; private final String imageModelName; private final Optional beanName; + private final DeclarativeAiServiceInputGuardrails inputGuardrails; + private final DeclarativeAiServiceOutputGuardrails outputGuardrails; private final Integer maxSequentialToolInvocations; public DeclarativeAiServiceBuildItem( @@ -51,6 +53,8 @@ public DeclarativeAiServiceBuildItem( DotName toolProviderClassDotName, Optional beanName, DotName toolHallucinationStrategyClassDotName, + 
DeclarativeAiServiceInputGuardrails inputGuardrails, + DeclarativeAiServiceOutputGuardrails outputGuardrails, Integer maxSequentialToolInvocations) { this.serviceClassInfo = serviceClassInfo; this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName; @@ -69,6 +73,8 @@ public DeclarativeAiServiceBuildItem( this.toolProviderClassDotName = toolProviderClassDotName; this.beanName = beanName; this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName; + this.inputGuardrails = inputGuardrails; + this.outputGuardrails = outputGuardrails; this.maxSequentialToolInvocations = maxSequentialToolInvocations; } @@ -140,6 +146,35 @@ public DotName getToolHallucinationStrategyClassDotName() { return toolHallucinationStrategyClassDotName; } + public DeclarativeAiServiceInputGuardrails getInputGuardrails() { + return inputGuardrails; + } + + public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() { + return outputGuardrails; + } + + public record DeclarativeAiServiceInputGuardrails(List inputGuardrailClassInfos) { + public List asClassNames() { + return this.inputGuardrailClassInfos.stream() + .map(classInfo -> classInfo.name().toString()) + .toList(); + } + } + + public record DeclarativeAiServiceOutputGuardrails(List outputGuardrailClassInfos, int maxRetries, + int actualMaxRetries) { + public DeclarativeAiServiceOutputGuardrails(List outputGuardrailClassInfos, int maxRetries) { + this(outputGuardrailClassInfos, maxRetries, maxRetries); + } + + public List asClassNames() { + return this.outputGuardrailClassInfos.stream() + .map(classInfo -> classInfo.name().toString()) + .toList(); + } + } + public Integer getMaxSequentialToolInvocations() { return maxSequentialToolInvocations; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupport.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupport.java index 
483d251f6..d7d7fed0b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupport.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupport.java @@ -10,21 +10,64 @@ import org.jboss.logging.Logger; import dev.langchain4j.data.message.UserMessage; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; final class GuardrailObservabilityProcessorSupport { private static final Logger LOG = Logger.getLogger(GuardrailObservabilityProcessorSupport.class); - private static final DotName INPUT_GUARDRAIL_PARAMS = DotName.createSimple(InputGuardrailParams.class); + + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_INPUT_GUARDRAIL_PARAMS = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.class); + + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_INPUT_GUARDRAIL_RESULT = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailResult.class); + + /** + * 
@deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_OUTPUT_GUARDRAIL_PARAMS = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams.class); + + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_OUTPUT_GUARDRAIL_RESULT = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.class); + + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_INPUT_GUARDRAIL = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrail.class); + + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private static final DotName QUARKUS_OUTPUT_GUARDRAIL = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrail.class); + private static final DotName INPUT_GUARDRAIL_REQUEST = DotName.createSimple(InputGuardrailRequest.class); private static final DotName INPUT_GUARDRAIL_RESULT = DotName.createSimple(InputGuardrailResult.class); - private static final DotName OUTPUT_GUARDRAIL_PARAMS = DotName.createSimple(OutputGuardrailParams.class); + private static final DotName OUTPUT_GUARDRAIL_REQUEST = DotName.createSimple(OutputGuardrailRequest.class); private static final DotName OUTPUT_GUARDRAIL_RESULT = DotName.createSimple(OutputGuardrailResult.class); private static final DotName INPUT_GUARDRAIL = DotName.createSimple(InputGuardrail.class); private static final DotName OUTPUT_GUARDRAIL = DotName.createSimple(OutputGuardrail.class); + static final DotName MICROMETER_TIMED = 
DotName.createSimple("io.micrometer.core.annotation.Timed"); static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted"); static final DotName WITH_SPAN = DotName.createSimple("io.opentelemetry.instrumentation.annotations.WithSpan"); @@ -37,11 +80,17 @@ enum TransformType { } enum GuardrailType { + QUARKUS_INPUT, + QUARKUS_OUTPUT, INPUT, OUTPUT; static Optional from(IndexView indexView, ClassInfo classToCheck) { - if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) { + if (indexView.getAllKnownImplementors(QUARKUS_INPUT_GUARDRAIL).contains(classToCheck)) { + return Optional.of(QUARKUS_INPUT); + } else if (indexView.getAllKnownImplementors(QUARKUS_OUTPUT_GUARDRAIL).contains(classToCheck)) { + return Optional.of(QUARKUS_OUTPUT); + } else if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) { return Optional.of(INPUT); } else if (indexView.getAllKnownImplementors(OUTPUT_GUARDRAIL).contains(classToCheck)) { return Optional.of(OUTPUT); @@ -102,16 +151,18 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI } var isOtherValidateMethodVariant = switch (guardrailType) { - case INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo); - case OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo); + case QUARKUS_INPUT, INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo); + case QUARKUS_OUTPUT, OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo); }; if (isOtherValidateMethodVariant && !doesMethodAlreadyHaveTransformationAnnotation(methodInfo, transformType)) { // If this is the other method variant, we need to ensure that the // variant with the params isn't also present on the method's declaring class var paramType = switch (guardrailType) { - case INPUT -> Type.parse(INPUT_GUARDRAIL_PARAMS.toString()); - case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_PARAMS.toString()); + case QUARKUS_INPUT -> 
Type.parse(QUARKUS_INPUT_GUARDRAIL_PARAMS.toString()); + case QUARKUS_OUTPUT -> Type.parse(QUARKUS_OUTPUT_GUARDRAIL_PARAMS.toString()); + case INPUT -> Type.parse(INPUT_GUARDRAIL_REQUEST.toString()); + case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_REQUEST.toString()); }; var otherValidateMethod = methodDeclaringClass.method("validate", paramType); @@ -129,9 +180,16 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI * Checks the method meets ALL the following conditions: *
    *
  • The method's name is {@link #VALIDATE_METHOD_NAME}
  • - *
  • IF the method's single parameter's type is {@link InputGuardrailParams} then the return type must be + *
  • IF the method's single parameter's type is + * {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be + * {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}
  • + *
  • IF the method's single parameter's type is + * {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must + * be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}
  • + *
  • IF the method's single parameter's type is {@link InputGuardrailRequest} then the return type must + * be * {@link InputGuardrailResult}
  • - *
  • IF the method's single parameter's type is {@link OutputGuardrailParams} then the return type must + *
  • IF the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must * be {@link OutputGuardrailResult}
  • *
*/ @@ -143,7 +201,8 @@ static boolean isGuardrailValidateMethodWithParams(MethodInfo methodInfo) { * Checks the method meets ALL the following conditions: *
    *
  • The method's name is {@link #VALIDATE_METHOD_NAME}
  • - *
  • The method's return type is {@link InputGuardrailResult}
  • + *
  • The method's return type is {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult} or + * {@link InputGuardrailResult}
  • *
  • The method's single parameter's type is {@link dev.langchain4j.data.message.UserMessage}
  • *
*/ @@ -156,7 +215,8 @@ private static boolean isInputGuardrailValidateMethodWithUserMessage(MethodInfo * Checks the method meets ALL the following conditions: *
    *
  • The method's name is {@link #VALIDATE_METHOD_NAME}
  • - *
  • The method's return type is {@link OutputGuardrailResult}
  • + *
  • The method's return type is {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult} or + * {@link OutputGuardrailResult}
  • *
  • The method's single parameter's type is {@link dev.langchain4j.data.message.AiMessage}
  • *
*/ @@ -168,9 +228,16 @@ private static boolean isOutputGuardrailValidateMethodWithAiMessage(MethodInfo m /** * Checks the method meets ALL the following conditions: *
    - *
  • IF the method's single parameter's type is {@link InputGuardrailParams} then the return type must be + *
  • IF the method's single parameter's type is + * {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be + * {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}
  • + *
  • IF the method's single parameter's type is + * {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must + * be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}
  • + *
  • IF the method's single parameter's type is {@link InputGuardrailRequest} then the return type must + * be * {@link InputGuardrailResult}
  • - *
  • IF the method's single parameter's type is {@link OutputGuardrailParams} then the return type must + *
  • IF the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must * be {@link OutputGuardrailResult}
  • *
*/ @@ -187,8 +254,13 @@ private static boolean doesValidateMethodWithParamsHaveCorrectSignature(MethodIn // Also check the return type var returnType = methodInfo.returnType().name(); - return (INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) || - (OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType)); + return (QUARKUS_INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType)) + || + (QUARKUS_OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName) + && QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType)) + || + (INPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) || + (OUTPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType)); } return false; @@ -207,7 +279,8 @@ private static boolean doesValidateMethodWithoutParamsHaveCorrectSignature(Metho var returnType = methodInfo.returnType().name(); return paramType.equals(paramTypeName) && - (INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType)); + (QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType) || QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType) || + INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType)); } return false; diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java index 84d286656..6e03e748a 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java @@ -26,6 +26,8 @@ import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.UserName; +import dev.langchain4j.service.guardrail.InputGuardrails; +import 
dev.langchain4j.service.guardrail.OutputGuardrails; import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.web.search.WebSearchEngine; import dev.langchain4j.web.search.WebSearchTool; @@ -36,8 +38,6 @@ import io.quarkiverse.langchain4j.PdfUrl; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.SeedMemory; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier; public class LangChain4jDotNames { @@ -49,6 +49,18 @@ public class LangChain4jDotNames { public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class); public static final DotName CHAT_MESSAGE = DotName.createSimple(ChatMessage.class); public static final DotName TOKEN_STREAM = DotName.createSimple(TokenStream.class); + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public static final DotName QUARKUS_OUTPUT_GUARDRAILS = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrails.class); + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public static final DotName QUARKUS_INPUT_GUARDRAILS = DotName + .createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrails.class); public static final DotName OUTPUT_GUARDRAILS = DotName.createSimple(OutputGuardrails.class); public static final DotName INPUT_GUARDRAILS = DotName.createSimple(InputGuardrails.class); static final DotName AI_SERVICES = DotName.createSimple(AiServices.class); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AiServicesMethodBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AiServicesMethodBuildItem.java index 
d25c52663..f8c2e9ef1 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AiServicesMethodBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/AiServicesMethodBuildItem.java @@ -2,13 +2,16 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.DotName; import org.jboss.jandex.MethodInfo; import org.jboss.jandex.Type; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailsLiteral; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailsLiteral; import io.quarkiverse.langchain4j.response.ResponseAugmenter; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkus.builder.item.MultiBuildItem; @@ -19,27 +22,70 @@ public final class AiServicesMethodBuildItem extends MultiBuildItem { private final MethodInfo methodInfo; - private final List outputGuardrails; - private final List inputGuardrails; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private final List quarkusOutputGuardrailClassNames; + + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private final List quarkusInputGuardrailClassNames; + private final InputGuardrailsLiteral inputGuardrails; + private final OutputGuardrailsLiteral outputGuardrails; private final AiServiceMethodCreateInfo methodCreateInfo; private final String responseAugmenter; - public AiServicesMethodBuildItem(MethodInfo methodInfo, List inputGuardrails, List outputGuardrails, + public AiServicesMethodBuildItem(MethodInfo methodInfo, List quarkusInputGuardrailClassNames, + List quarkusOutputGuardrailClassNames, + InputGuardrailsLiteral inputGuardrails, 
OutputGuardrailsLiteral outputGuardrails, String responseAugmenter, AiServiceMethodCreateInfo methodCreateInfo) { this.methodInfo = methodInfo; + this.quarkusInputGuardrailClassNames = quarkusInputGuardrailClassNames; + this.quarkusOutputGuardrailClassNames = quarkusOutputGuardrailClassNames; this.inputGuardrails = inputGuardrails; this.outputGuardrails = outputGuardrails; this.responseAugmenter = responseAugmenter; this.methodCreateInfo = methodCreateInfo; } - public List getOutputGuardrails() { - return outputGuardrails; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List getQuarkusOutputGuardrailClassNames() { + return quarkusOutputGuardrailClassNames; + } + + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List getQuarkusInputGuardrailClassNames() { + return quarkusInputGuardrailClassNames; + } + + public Optional getInputGuardrails() { + return Optional.ofNullable(inputGuardrails); + } + + public Optional getOutputGuardrails() { + return Optional.ofNullable(outputGuardrails); + } + + public boolean hasInputGuardrails() { + return getInputGuardrails() + .map(InputGuardrailsLiteral::hasGuardrails) + .orElse(false); } - public List getInputGuardrails() { - return inputGuardrails; + public boolean hasOutputGuardrails() { + return getOutputGuardrails() + .map(OutputGuardrailsLiteral::hasGuardrails) + .orElse(false); } public MethodInfo getMethodInfo() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupportTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupportTest.java index 927172b02..ed62fa2a5 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupportTest.java +++ 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/GuardrailObservabilityProcessorSupportTest.java @@ -19,22 +19,22 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.Guardrail; +import dev.langchain4j.guardrail.GuardrailRequest; +import dev.langchain4j.guardrail.GuardrailResult; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.service.guardrail.InputGuardrails; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.micrometer.core.annotation.Counted; import io.micrometer.core.annotation.Timed; import io.opentelemetry.instrumentation.annotations.WithSpan; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.TransformType; -import io.quarkiverse.langchain4j.guardrails.Guardrail; -import io.quarkiverse.langchain4j.guardrails.GuardrailParams; -import io.quarkiverse.langchain4j.guardrails.GuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; class GuardrailObservabilityProcessorSupportTest { private record ClassInfoMapping(Class clazz, boolean shouldHaveGuardrailValidateMethodWithParams, @@ 
-86,11 +86,11 @@ boolean shouldHaveAtLeastOneMethodRewritten() { new ClassInfoMapping(IGRedefiningBothValidateMethods.class, true, true, false), new ClassInfoMapping(OGRedefiningBothValidateMethods.class, true, false, true), ClassInfoMapping.somethingElse(Guardrail.class), - ClassInfoMapping.somethingElse(GuardrailParams.class), + ClassInfoMapping.somethingElse(GuardrailRequest.class), ClassInfoMapping.somethingElse(GuardrailResult.class), - ClassInfoMapping.somethingElse(OutputGuardrailParams.class), + ClassInfoMapping.somethingElse(OutputGuardrailRequest.class), ClassInfoMapping.somethingElse(OutputGuardrailResult.class), - ClassInfoMapping.somethingElse(InputGuardrailParams.class), + ClassInfoMapping.somethingElse(InputGuardrailRequest.class), ClassInfoMapping.somethingElse(InputGuardrailResult.class), ClassInfoMapping.somethingElse(InputGuardrails.class), ClassInfoMapping.somethingElse(OutputGuardrails.class), @@ -230,7 +230,7 @@ public InputGuardrailResult validate(UserMessage userMessage) { @ApplicationScoped public static class IGDirectlyImplementInputGuardrailWithParams implements InputGuardrail { @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { return success(); } } @@ -238,7 +238,7 @@ public InputGuardrailResult validate(InputGuardrailParams params) { @ApplicationScoped public static class OGDirectlyImplementOutputGuardrailWithParams implements OutputGuardrail { @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { return success(); } } @@ -275,7 +275,7 @@ public InputGuardrailResult validate(UserMessage userMessage) { } @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { return success(); } } @@ -288,14 +288,14 @@ public OutputGuardrailResult 
validate(AiMessage responseFromLLM) { } @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { return success(); } } public static abstract class AbstractOGImplementingValidateWithParams implements OutputGuardrail { @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { return success(); } } @@ -309,7 +309,7 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) { public static abstract class AbstractIGImplementingValidateWithParams implements InputGuardrail { @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest params) { return success(); } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/QuarkusGuardrailObservabilityProcessorSupportTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/QuarkusGuardrailObservabilityProcessorSupportTest.java new file mode 100644 index 000000000..92a102419 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/deployment/QuarkusGuardrailObservabilityProcessorSupportTest.java @@ -0,0 +1,348 @@ +package io.quarkiverse.langchain4j.deployment; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Stream; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.Index; +import org.jboss.jandex.MethodInfo; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import dev.langchain4j.data.message.AiMessage; +import 
dev.langchain4j.data.message.UserMessage; +import io.micrometer.core.annotation.Counted; +import io.micrometer.core.annotation.Timed; +import io.opentelemetry.instrumentation.annotations.WithSpan; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.deployment.GuardrailObservabilityProcessorSupport.TransformType; +import io.quarkiverse.langchain4j.guardrails.Guardrail; +import io.quarkiverse.langchain4j.guardrails.GuardrailParams; +import io.quarkiverse.langchain4j.guardrails.GuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +class QuarkusGuardrailObservabilityProcessorSupportTest { + private record ClassInfoMapping(Class clazz, boolean shouldHaveGuardrailValidateMethodWithParams, + boolean shouldHaveInputGuardrailValidateMethodWithUserMessage, + boolean shouldHaveOutputGuardrailValidateMethodWithAiMessage) { + + static ClassInfoMapping guardrailWithValidateMethodWithParams(Class clazz) { + return new ClassInfoMapping(clazz, true, false, false); + } + + static ClassInfoMapping inputGuardrailWithValidateMethodWithUserMessage(Class clazz) { + return new ClassInfoMapping(clazz, false, true, false); + } + + static ClassInfoMapping outputGuardrailWithValidateMethodWithAiMessage(Class clazz) { + return new ClassInfoMapping(clazz, false, false, true); + } + + static ClassInfoMapping 
somethingElse(Class clazz) { + return new ClassInfoMapping(clazz, false, false, false); + } + + boolean shouldHaveAtLeastOneMethodRewritten() { + return !clazz.isInterface() + && (shouldHaveGuardrailValidateMethodWithParams || shouldHaveInputGuardrailValidateMethodWithUserMessage + || shouldHaveOutputGuardrailValidateMethodWithAiMessage); + } + } + + private static final List CLASS_INFO_MAPPINGS = List.of( + ClassInfoMapping.somethingElse(Assistant.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(IGDirectlyImplementInputGuardrailWithParams.class), + ClassInfoMapping + .inputGuardrailWithValidateMethodWithUserMessage(IGDirectlyImplementInputGuardrailWithUserMessage.class), + ClassInfoMapping.somethingElse(IGExtendingValidateWithParams.class), + ClassInfoMapping.somethingElse(IGExtendingValidateWithUserMessage.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(OGDirectlyImplementOutputGuardrailWithParams.class), + ClassInfoMapping + .outputGuardrailWithValidateMethodWithAiMessage(OGDirectlyImplementOutputGuardrailWithAiMessage.class), + ClassInfoMapping.somethingElse(OGExtendingValidateWithParams.class), + ClassInfoMapping.somethingElse(OGExtendingValidateWithAiMessage.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(AbstractOGImplementingValidateWithParams.class), + ClassInfoMapping.outputGuardrailWithValidateMethodWithAiMessage(AbstractOGImplementingValidateWithAiMessage.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(AbstractIGImplementingValidateWithParams.class), + ClassInfoMapping + .inputGuardrailWithValidateMethodWithUserMessage(AbstractIGImplementingValidateWithUserMessage.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(InputGuardrail.class), + ClassInfoMapping.guardrailWithValidateMethodWithParams(OutputGuardrail.class), + new ClassInfoMapping(IGRedefiningBothValidateMethods.class, true, true, false), + new ClassInfoMapping(OGRedefiningBothValidateMethods.class, true, false, 
true), + ClassInfoMapping.somethingElse(Guardrail.class), + ClassInfoMapping.somethingElse(GuardrailParams.class), + ClassInfoMapping.somethingElse(GuardrailResult.class), + ClassInfoMapping.somethingElse(OutputGuardrailParams.class), + ClassInfoMapping.somethingElse(OutputGuardrailResult.class), + ClassInfoMapping.somethingElse(InputGuardrailParams.class), + ClassInfoMapping.somethingElse(InputGuardrailResult.class), + ClassInfoMapping.somethingElse(InputGuardrails.class), + ClassInfoMapping.somethingElse(OutputGuardrails.class), + ClassInfoMapping.somethingElse(ClassWithTimedAnnotation.class), + ClassInfoMapping.somethingElse(ClassWithCountedAnnotation.class), + ClassInfoMapping.somethingElse(ClassWithSpanAnnotation.class)); + + private Index index; + + @BeforeEach + public void setUp() throws IOException { + this.index = Index.of( + CLASS_INFO_MAPPINGS.stream() + .map(ClassInfoMapping::clazz) + .toArray(Class[]::new)); + } + + @ParameterizedTest + @ValueSource(classes = { ClassWithTimedAnnotation.class, ClassWithCountedAnnotation.class }) + void hasMetricsAnnotations(Class clazz) { + var methodInfo = getClassInfo(clazz).firstMethod("someMethod"); + + assertThat(methodInfo) + .isNotNull() + .extracting(GuardrailObservabilityProcessorSupport::doesMethodHaveMetricsAnnotations) + .isEqualTo(true); + } + + @Test + void hasSpanAnnotation() { + var methodInfo = getClassInfo(ClassWithSpanAnnotation.class).firstMethod("someMethod"); + + assertThat(methodInfo) + .isNotNull() + .extracting(GuardrailObservabilityProcessorSupport::doesMethodHaveSpanAnnotation) + .isEqualTo(true); + } + + @ParameterizedTest + @MethodSource("allClassInfoMappingsWithoutClassesWithAnnotations") + void isGuardrailValidateMethodWithParams(ClassInfoMapping classInfoMapping) { + var allMethods = getAllMethods(classInfoMapping.clazz()); + + assertThat(allMethods) + .allSatisfy(method -> { + // None of the methods already have the metrics annotations + assertThat(method) + .isNotNull() + 
.extracting(GuardrailObservabilityProcessorSupport::doesMethodHaveMetricsAnnotations) + .isEqualTo(false); + + // None of the methods already have the span annotation + assertThat(method) + .isNotNull() + .extracting(GuardrailObservabilityProcessorSupport::doesMethodHaveSpanAnnotation) + .isEqualTo(false); + }); + + var hasGuardrailValidateMethodWithParams = allMethods.stream() + .filter(GuardrailObservabilityProcessorSupport::isGuardrailValidateMethodWithParams) + .count() > 0; + + assertThat(hasGuardrailValidateMethodWithParams) + .isEqualTo(classInfoMapping.shouldHaveGuardrailValidateMethodWithParams); + } + + @ParameterizedTest + @MethodSource("allClassInfoMappings") + void shouldTransformMethod(ClassInfoMapping classInfoMapping) { + var allMethods = getAllMethods(classInfoMapping.clazz()); + var allMethodsThatShouldHaveMetricsTransformed = allMethods.stream() + .filter(methodInfo -> GuardrailObservabilityProcessorSupport.shouldTransformMethod(methodInfo, this.index, + TransformType.METRICS)) + .toList(); + var allMethodsThatShouldHaveSpanTransformed = allMethods.stream() + .filter(methodInfo -> GuardrailObservabilityProcessorSupport.shouldTransformMethod(methodInfo, this.index, + TransformType.OTEL)) + .toList(); + + if (classInfoMapping.shouldHaveAtLeastOneMethodRewritten()) { + assertThat(allMethodsThatShouldHaveMetricsTransformed) + .isNotNull() + .hasSize(1); + + assertThat(allMethodsThatShouldHaveSpanTransformed) + .isNotNull() + .hasSize(1); + } else { + assertThat(allMethodsThatShouldHaveMetricsTransformed) + .isNotNull() + .isEmpty(); + + assertThat(allMethodsThatShouldHaveSpanTransformed) + .isNotNull() + .isEmpty(); + } + } + + static Stream allClassInfoMappingsWithoutClassesWithAnnotations() { + return CLASS_INFO_MAPPINGS.stream() + .filter(classInfo -> !classInfo.clazz().equals(ClassWithTimedAnnotation.class) && + !classInfo.clazz().equals(ClassWithCountedAnnotation.class) && + !classInfo.clazz().equals(ClassWithSpanAnnotation.class)); + } + + 
static List allClassInfoMappings() { + return CLASS_INFO_MAPPINGS; + } + + private List getAllMethods(Class clazz) { + return getClassInfo(clazz).methods(); + } + + private ClassInfo getClassInfo(Class clazz) { + return this.index.getClassByName(clazz); + } + + @RegisterAiService + interface Assistant { + @InputGuardrails({ IGDirectlyImplementInputGuardrailWithParams.class, + IGDirectlyImplementInputGuardrailWithUserMessage.class, IGExtendingValidateWithParams.class, + IGExtendingValidateWithUserMessage.class, IGRedefiningBothValidateMethods.class }) + @OutputGuardrails({ OGDirectlyImplementOutputGuardrailWithParams.class, + OGDirectlyImplementOutputGuardrailWithAiMessage.class, OGExtendingValidateWithParams.class, + OGExtendingValidateWithAiMessage.class, OGRedefiningBothValidateMethods.class }) + String chat(String message); + } + + @ApplicationScoped + public static class IGDirectlyImplementInputGuardrailWithUserMessage implements InputGuardrail { + @Override + public InputGuardrailResult validate(UserMessage userMessage) { + return success(); + } + } + + @ApplicationScoped + public static class IGDirectlyImplementInputGuardrailWithParams implements InputGuardrail { + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + return success(); + } + } + + @ApplicationScoped + public static class OGDirectlyImplementOutputGuardrailWithParams implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + return success(); + } + } + + @ApplicationScoped + public static class OGDirectlyImplementOutputGuardrailWithAiMessage implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + } + + @ApplicationScoped + public static class OGExtendingValidateWithParams extends AbstractOGImplementingValidateWithParams { + } + + @ApplicationScoped + public static class OGExtendingValidateWithAiMessage extends 
AbstractOGImplementingValidateWithAiMessage { + } + + @ApplicationScoped + public static class IGExtendingValidateWithParams extends AbstractIGImplementingValidateWithParams { + } + + @ApplicationScoped + public static class IGExtendingValidateWithUserMessage extends AbstractIGImplementingValidateWithUserMessage { + } + + @ApplicationScoped + public static class IGRedefiningBothValidateMethods implements InputGuardrail { + @Override + public InputGuardrailResult validate(UserMessage userMessage) { + return success(); + } + + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + return success(); + } + } + + @ApplicationScoped + public static class OGRedefiningBothValidateMethods implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + return success(); + } + } + + public static abstract class AbstractOGImplementingValidateWithParams implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + return success(); + } + } + + public static abstract class AbstractOGImplementingValidateWithAiMessage implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + } + + public static abstract class AbstractIGImplementingValidateWithParams implements InputGuardrail { + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + return success(); + } + } + + public static abstract class AbstractIGImplementingValidateWithUserMessage implements InputGuardrail { + @Override + public InputGuardrailResult validate(UserMessage userMessage) { + return success(); + } + } + + public static class ClassWithTimedAnnotation { + @Timed + public void someMethod() { + + } + } + + public static class ClassWithCountedAnnotation { + @Counted + public void 
someMethod() { + + } + } + + public static class ClassWithSpanAnnotation { + @WithSpan + public void someMethod() { + + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryNoConfigTests.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryNoConfigTests.java new file mode 100644 index 000000000..e42d1ff12 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryNoConfigTests.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig; +import io.quarkiverse.langchain4j.runtime.config.LangChain4jConfig; +import io.quarkus.test.QuarkusUnitTest; + +class QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryNoConfigTests { + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Inject + LangChain4jConfig config; + + @Test + void hasCorrectConfigSettings() { + assertThat(this.config) + .isNotNull() + .extracting(LangChain4jConfig::guardrails) + .isNotNull() + .extracting(GuardrailsConfig::maxRetries) + .isEqualTo(GuardrailsConfig.MAX_RETRIES_DEFAULT); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryTests.java 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryTests.java new file mode 100644 index 000000000..e2ce10d26 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/guardrails/QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryTests.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig; +import io.quarkiverse.langchain4j.runtime.config.LangChain4jConfig; +import io.quarkus.test.QuarkusUnitTest; + +class QuarkusDeclarativeAiServiceOutputGuardrailsConfigBuilderFactoryTests { + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideConfigKey("quarkus.langchain4j.guardrails.max-retries", "10"); + + @Inject + LangChain4jConfig config; + + @Test + void hasCorrectConfigSettings() { + assertThat(this.config) + .isNotNull() + .extracting(LangChain4jConfig::guardrails) + .isNotNull() + .extracting(GuardrailsConfig::maxRetries) + .isEqualTo(10); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java index 3fc957837..f214e5ae5 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java @@ -17,6 +17,12 @@ import 
org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -30,23 +36,18 @@ import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.Content; import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; /** * Verify that the input and output guardrails can access the augmentation results. 
*/ -public class GuardrailWithAugmentationTest { +public class GuardrailWithAugmentationTest extends TokenStreamExecutor { @RegisterExtension static final QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -72,6 +73,16 @@ void testInputOnly() { assertThat(outputGuardrail.getSpy()).isEqualTo(0); } + @Test + @ActivateRequestContext + void testInputOnlyTokenStream() throws InterruptedException { + var result = execute(() -> service.inputOnlyTokenStream("2", "foo")); + + assertThat(inputGuardrail.getSpy()).isEqualTo(1); + assertThat(outputGuardrail.getSpy()).isEqualTo(0); + assertThat(result).isEqualTo("Streaming hi !"); + } + @Test @ActivateRequestContext void testInputOnlyMulti() { @@ -111,6 +122,9 @@ public interface MyAiService { @InputGuardrails(MyInputGuardrail.class) Multi inputOnlyMulti(@MemoryId String id, @UserMessage String message); + @InputGuardrails(MyInputGuardrail.class) + TokenStream inputOnlyTokenStream(@MemoryId String id, @UserMessage String message); + @OutputGuardrails(MyOutputGuardrail.class) String outputOnly(@MemoryId String id, @UserMessage String message); @@ -125,9 +139,9 @@ public static class MyInputGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { spy.incrementAndGet(); - assertThat(params.augmentationResult().contents()).hasSize(2); + assertThat(request.requestParams().augmentationResult().contents()).hasSize(2); return InputGuardrailResult.success(); } @@ -142,9 +156,9 @@ public static class MyOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { spy.incrementAndGet(); - assertThat(params.augmentationResult().contents()).hasSize(2); + 
assertThat(request.requestParams().augmentationResult().contents()).hasSize(2); return OutputGuardrailResult.success(); } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndDeclarativeAiServiceOutputGuardrailsTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndDeclarativeAiServiceOutputGuardrailsTest.java new file mode 100644 index 000000000..6d7cafa78 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndDeclarativeAiServiceOutputGuardrailsTest.java @@ -0,0 +1,272 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import 
dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class InputAndDeclarativeAiServiceOutputGuardrailsTest { // input and output guardrails declared together on the same AI service methods + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class, + ValidationException.class)); + + @Inject + MyOkInputGuardrail okIn; + + @Inject + MyKoInputGuardrail koIn; + + @Inject + MyOkOutputGuardrail okOut; + + @Inject + MyKoOutputGuardrail koOut; + + @Inject + MyKoWithRetryOutputGuardrail koOutWithRetry; + + @Inject + MyKoWithRepromprOutputGuardrail koOutWithReprompt; + + @Inject + MyAiService service; + + @Test + @ActivateRequestContext + void testOk() { + assertThat(okIn.getSpy()).isEqualTo(0); + assertThat(okOut.getSpy()).isEqualTo(0); + service.bothOk("1", "foo"); + assertThat(okIn.getSpy()).isEqualTo(1); + assertThat(okOut.getSpy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testInKo() { + assertThat(koIn.getSpy()).isEqualTo(0); + assertThat(okOut.getSpy()).isEqualTo(0); + assertThatThrownBy(() -> service.inKo("2", "foo")) + .hasRootCauseMessage("boom"); + assertThat(koIn.getSpy()).isEqualTo(1); + assertThat(okOut.getSpy()).isEqualTo(0); + } + + @Test + @ActivateRequestContext + void testOutKo() { + assertThat(okIn.getSpy()).isEqualTo(0); + assertThat(koOut.getSpy()).isEqualTo(0); + assertThatThrownBy(() -> service.outKo("2", "foo")) + .hasRootCauseMessage("boom"); + assertThat(okIn.getSpy()).isEqualTo(1); + assertThat(koOut.getSpy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testRetry() { + assertThat(okIn.getSpy()).isEqualTo(0); + assertThat(koOutWithRetry.getSpy()).isEqualTo(0); + 
service.outKoWithRetry("2", "foo"); + assertThat(okIn.getSpy()).isEqualTo(1); + assertThat(koOutWithRetry.getSpy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testReprompt() { + assertThat(okIn.getSpy()).isEqualTo(0); + assertThat(koOutWithReprompt.getSpy()).isEqualTo(0); + service.outKoWithReprompt("2", "foo"); + assertThat(okIn.getSpy()).isEqualTo(1); + assertThat(koOutWithReprompt.getSpy()).isEqualTo(2); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @InputGuardrails(MyOkInputGuardrail.class) + @OutputGuardrails(MyOkOutputGuardrail.class) + String bothOk(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyKoInputGuardrail.class) + @OutputGuardrails(MyOkOutputGuardrail.class) + String inKo(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyOkInputGuardrail.class) + @OutputGuardrails(MyKoOutputGuardrail.class) + String outKo(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyOkInputGuardrail.class) + @OutputGuardrails(MyKoWithRetryOutputGuardrail.class) + String outKoWithRetry(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyOkInputGuardrail.class) + @OutputGuardrails(MyKoWithRepromprOutputGuardrail.class) + String outKoWithReprompt(@MemoryId String id, @UserMessage String message); + + } + + @RequestScoped + public static class MyOkInputGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public InputGuardrailResult validate(InputGuardrailRequest request) { + spy.incrementAndGet(); + return success(); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyKoInputGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public InputGuardrailResult validate(InputGuardrailRequest request) { + 
spy.incrementAndGet(); + return failure("boom", new ValidationException("boom")); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyOkOutputGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public OutputGuardrailResult validate(OutputGuardrailRequest request) { + spy.incrementAndGet(); + return OutputGuardrailResult.success(); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyKoOutputGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public OutputGuardrailResult validate(OutputGuardrailRequest request) { + spy.incrementAndGet(); + return failure("boom", new ValidationException("boom")); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyKoWithRetryOutputGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public OutputGuardrailResult validate(OutputGuardrailRequest request) { + if (spy.incrementAndGet() == 1) { // only the first invocation asks for a retry + return retry("KO"); + } + return success(); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyKoWithRepromprOutputGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public OutputGuardrailResult validate(OutputGuardrailRequest request) { + if (spy.incrementAndGet() == 1) { // only the first invocation asks for a reprompt + return reprompt("KO", "retry"); + } + return success(); + } + + public int getSpy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier<ChatModel> { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class 
MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new MessageWindowChatMemory.Builder().maxMessages(5).build(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java index 1a5798d9a..aca58f849 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java @@ -17,6 +17,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -24,11 +27,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java index 3e37ec10b..b8903325a 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java @@ -14,6 +14,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -21,10 +23,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java index ee19110a5..59581f869 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java @@ -15,6 +15,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; 
+import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -22,10 +24,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java index f496f4ab4..8e6117b78 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java @@ -15,6 +15,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -22,10 +24,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import 
io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java index c66acda9f..388268cf7 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java @@ -16,17 +16,17 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -189,23 +189,23 @@ public interface MyAiService { @RequestScoped public static class GuardrailValidation implements InputGuardrail { - InputGuardrailParams params; + InputGuardrailRequest 
request; - public InputGuardrailResult validate(InputGuardrailParams params) { - this.params = params; + public InputGuardrailResult validate(InputGuardrailRequest request) { + this.request = request; return success(); } public String spyUserMessageTemplate() { - return params.userMessageTemplate(); + return request.requestParams().userMessageTemplate(); } public String spyUserMessageText() { - return params.userMessage().singleText(); + return request.userMessage().singleText(); } public Map spyVariables() { - return params.variables(); + return request.requestParams().variables(); } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java index 1c2326775..3bff71a60 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java @@ -14,16 +14,16 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import 
io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java index 85e9bf7f7..e22a45cc0 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java @@ -17,6 +17,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -24,10 +26,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.arc.Arc; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java index 155a6f921..eec41ccc0 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java @@ -19,6 +19,10 @@ import 
org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailException; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -28,17 +32,14 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.InputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; -public class InputGuardrailValidationTest { +public class InputGuardrailValidationTest extends TokenStreamExecutor { @RegisterExtension static final QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -60,7 +61,7 @@ void testOk() { @ActivateRequestContext void testKo() { assertThatThrownBy(() -> aiService.ko("2")) - .isInstanceOf(GuardrailException.class) + .isInstanceOf(InputGuardrailException.class) .hasMessageContaining("KO"); } @@ -77,7 +78,23 @@ void testOkMulti() { @ActivateRequestContext void testKoMulti() { assertThatThrownBy(() -> aiService.koMulti("2").subscribe().asIterable()) - .isInstanceOf(GuardrailException.class) + .isInstanceOf(InputGuardrailException.class) + .hasMessageContaining("KO"); + } + + 
@Test + @ActivateRequestContext + void testOkTokenStream() throws InterruptedException { + var strings = execute(() -> aiService.okTokenStream("1")); + + assertThat(strings).isEqualTo("Streaming hi !"); + } + + @Test + @ActivateRequestContext + void testKoTokenStream() { + assertThatThrownBy(() -> aiService.koTokenStream("2")) // must target the TokenStream method, not koMulti + .isInstanceOf(InputGuardrailException.class) .hasMessageContaining("KO"); } @@ -88,7 +105,7 @@ void testKoMulti() { @ActivateRequestContext void testFatalException() { assertThatThrownBy(() -> aiService.fatal("5")) - .isInstanceOf(GuardrailException.class) + .isInstanceOf(InputGuardrailException.class) .hasMessageContaining("Fatal"); assertThat(fatal.spy()).isEqualTo(1); } @@ -119,6 +136,14 @@ public interface MyAiService { @InputGuardrails(KOGuardrail.class) Multi koMulti(@MemoryId String mem); + @UserMessage("Say Hi!") + @InputGuardrails(OKGuardrail.class) + TokenStream okTokenStream(@MemoryId String mem); + + @UserMessage("Say Hi!") + @InputGuardrails(KOGuardrail.class) + TokenStream koTokenStream(@MemoryId String mem); + @UserMessage("Say Hi!") @InputGuardrails(KOFatalGuardrail.class) String fatal(@MemoryId String mem); @@ -181,15 +206,16 @@ public static class MemoryCheck implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { spy.incrementAndGet(); - if (params.memory().messages().isEmpty()) { - assertThat(params.userMessage().singleText()).isEqualTo("foo"); + var messages = request.requestParams().chatMemory().messages(); + if (messages.isEmpty()) { + assertThat(request.userMessage().singleText()).isEqualTo("foo"); } - if (params.memory().messages().size() == 2) { - assertThat(chatMessageToText(params.memory().messages().get(0))).isEqualTo("foo"); - assertThat(chatMessageToText(params.memory().messages().get(1))).isEqualTo("Hi!"); - 
assertThat(params.userMessage().singleText()).isEqualTo("bar"); + if (messages.size() == 2) { + assertThat(chatMessageToText(messages.get(0))).isEqualTo("foo"); + assertThat(chatMessageToText(messages.get(1))).isEqualTo("Hi!"); + assertThat(request.userMessage().singleText()).isEqualTo("bar"); } return success(); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InvalidOutputGuardrailAccumulatorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InvalidOutputGuardrailAccumulatorTest.java index 94cd06e31..f254a6558 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InvalidOutputGuardrailAccumulatorTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InvalidOutputGuardrailAccumulatorTest.java @@ -17,6 +17,8 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -25,11 +27,9 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; diff --git 
a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/JsonGuardrailsUtilsTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/JsonGuardrailsUtilsTest.java index 034cdbff6..315c296fd 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/JsonGuardrailsUtilsTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/JsonGuardrailsUtilsTest.java @@ -16,6 +16,10 @@ import io.quarkiverse.langchain4j.guardrails.JsonGuardrailsUtils; import io.quarkus.test.QuarkusUnitTest; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) class JsonGuardrailsUtilsTest { @RegisterExtension diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorNotFoundTest.java index b1654c2d2..25c478870 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorNotFoundTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorNotFoundTest.java @@ -17,6 +17,8 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -25,11 +27,9 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import 
io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorTest.java index 83b770f38..e52647d77 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorTest.java @@ -18,6 +18,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -27,12 +30,9 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import 
io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; @@ -188,9 +188,9 @@ public void reset() { } @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { count.incrementAndGet(); - chunks.add(params.responseFromLLM().text()); + chunks.add(request.responseFromLLM().aiMessage().text()); return success(); } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnStreamedResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnStreamedResponseTest.java index efb639411..e9562e9cd 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnStreamedResponseTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnStreamedResponseTest.java @@ -17,6 +17,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.StreamingChatModel; @@ -25,11 +27,9 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import 
io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -75,7 +75,6 @@ void testThatGuardrailOrderIsCorrect() { @ActivateRequestContext void testThatRetryRestartTheChain() { aiService.failingFirstTwo("1", "foo").collect().asList().await().indefinitely(); - ; assertThat(firstGuardrail.spy()).isEqualTo(2); assertThat(secondGuardrail.spy()).isEqualTo(1); assertThat(failingGuardrail.spy()).isEqualTo(2); @@ -105,7 +104,6 @@ void testThatGuardrailOrderIsCorrectWithPassThroughAccumulator() { @ActivateRequestContext void testThatRetryRestartTheChainWithPassThroughAccumulator() { aiService.failingFirstTwoWithPassThroughAccumulator("1", "foo").collect().asList().await().indefinitely(); - ; assertThat(firstGuardrail.spy()).isEqualTo(4); assertThat(secondGuardrail.spy()).isEqualTo(3); assertThat(failingGuardrail.spy()).isEqualTo(4); diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnTokenStreamResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnTokenStreamResponseTest.java new file mode 100644 index 000000000..f2a1aa9be --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainOnTokenStreamResponseTest.java @@ -0,0 +1,194 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class OutputGuardrailChainOnTokenStreamResponseTest extends TokenStreamExecutor { // output guardrail chains on TokenStream-returning AI service methods + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Inject + MyAiService aiService; + + @Inject + FirstGuardrail firstGuardrail; + + @Inject + SecondGuardrail secondGuardrail; + + @Inject + FailingGuardrail failingGuardrail; + + @Test + @ActivateRequestContext + void testThatGuardrailChainsAreInvoked() throws InterruptedException { + execute(() -> aiService.firstOneTwo("1", "foo")); + assertThat(firstGuardrail.spy()).isEqualTo(1); + assertThat(secondGuardrail.spy()).isEqualTo(1); + assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess()); + } + + @Test + @ActivateRequestContext + void testThatGuardrailOrderIsCorrect() throws InterruptedException { + execute(() -> aiService.twoAndFirst("1", "foo")); + assertThat(firstGuardrail.spy()).isEqualTo(1); + assertThat(secondGuardrail.spy()).isEqualTo(1); + 
assertThat(secondGuardrail.lastAccess()).isLessThan(firstGuardrail.lastAccess()); + } + + @Test + @ActivateRequestContext + void testThatRetryRestartTheChain() throws InterruptedException { + execute(() -> aiService.failingFirstTwo("1", "foo")); + assertThat(firstGuardrail.spy()).isEqualTo(2); + assertThat(secondGuardrail.spy()).isEqualTo(1); + assertThat(failingGuardrail.spy()).isEqualTo(2); + assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess()); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + @OutputGuardrails({ FirstGuardrail.class, SecondGuardrail.class }) + TokenStream firstOneTwo(@MemoryId String mem, @UserMessage String message); + + @OutputGuardrails({ SecondGuardrail.class, FirstGuardrail.class }) + TokenStream twoAndFirst(@MemoryId String mem, @UserMessage String message); + + @OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class }) + TokenStream failingFirstTwo(@MemoryId String mem, @UserMessage String message); + } + + @RequestScoped + public static class FirstGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + AtomicLong lastAccess = new AtomicLong(); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + lastAccess.set(System.nanoTime()); + try { + Thread.sleep(1); // presumably spaces out the lastAccess timestamps so the ordering assertions are strict + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it + } + return success(); + } + + public int spy() { + return spy.get(); + } + + public long lastAccess() { + return lastAccess.get(); + } + } + + @RequestScoped + public static class SecondGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + AtomicLong lastAccess = new AtomicLong(); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + 
lastAccess.set(System.nanoTime()); + try { + Thread.sleep(1); // presumably spaces out the lastAccess timestamps so the ordering assertions are strict + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it + } + return success(); + } + + public int spy() { + return spy.get(); + } + + public long lastAccess() { + return lastAccess.get(); + } + } + + @RequestScoped + public static class FailingGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + if (spy.incrementAndGet() == 1) { // only the first invocation asks for a reprompt + return reprompt("Retry", "Retry"); + } + return success(); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier<StreamingChatModel> { + + @Override + public StreamingChatModel get() { + return new MyStreamedChatModel(); + } + } + + public static class MyStreamedChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + handler.onPartialResponse("Hi!"); + handler.onPartialResponse(" "); + handler.onPartialResponse("World!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return MessageWindowChatMemory.withMaxMessages(10); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java index 19cf4a545..2efb09f99 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java @@ -17,6 +17,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; 
import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -24,11 +27,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java index 76416f672..dd07aeaf6 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java @@ -14,6 +14,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -21,10 +23,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import 
dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java index 931c6f502..3a0236736 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java @@ -15,6 +15,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -22,10 +24,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git 
a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java index 860d4fd48..d3756f938 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java @@ -15,6 +15,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -22,10 +24,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseTest.java index a0f9e5a7c..7676f24a2 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseTest.java @@ -17,6 +17,8 @@ import 
org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.StreamingChatModel; @@ -25,11 +27,9 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.arc.Arc; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java index bc48fa959..c95f839b0 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java @@ -17,6 +17,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import 
dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.StreamingChatModel; @@ -25,13 +28,10 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; @@ -46,16 +46,24 @@ public class OutputGuardrailOnStreamedResponseValidationTest { @Inject MyAiService aiService; + @Inject + OKGuardrail okGuardrail; + + @Inject + RewritingGuardrail rewriting; + @Test @ActivateRequestContext void testOk() { aiService.ok("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(1); } @Test @ActivateRequestContext void testOkWithPassThroughAccumulator() { aiService.okWithPassThroughAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(3); } @Test @@ -80,14 +88,17 @@ void testKOWithPassThroughAccumulator() { @Test @ActivateRequestContext void testRetryOk() { - aiService.retry("3").collect().asList().await().indefinitely(); + assertThat(aiService.retry("3").collect().asList().await().indefinitely()) + .singleElement() + .isEqualTo("Hi! 
World!"); assertThat(retry.spy()).isEqualTo(2); } @Test @ActivateRequestContext void testRetryOkWithPassThroughAccumulator() { - aiService.retryWithPassThroughAccumulator("3").collect().asList().await().indefinitely(); + assertThat(aiService.retryWithPassThroughAccumulator("3").collect().asList().await().indefinitely()) + .containsExactly("Hi!", " ", "World!"); assertThat(retry.spy()).isEqualTo(4); // "Hi!", "Hi!" (retry), " ", "World!" } @@ -140,10 +151,11 @@ void testFatalExceptionWithPassThroughAccumulator() { @Test @ActivateRequestContext - void testRewritingWhileStreamingIsNotAllowed() { - assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely()) - .isInstanceOf(GuardrailException.class) - .hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed"); + void rewritingWhileStreaming() throws InterruptedException { + assertThat(aiService.rewriting("1").collect().asList().await().indefinitely()) + .singleElement() + .isEqualTo("Hi! 
World!,1"); + assertThat(rewriting.spy()).isEqualTo(1); } @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) @@ -268,7 +280,6 @@ public int spy() { @RequestScoped public static class KOFatalGuardrail implements OutputGuardrail { - AtomicInteger spy = new AtomicInteger(0); @Override @@ -284,12 +295,18 @@ public int spy() { @RequestScoped public static class RewritingGuardrail implements OutputGuardrail { + AtomicInteger spy = new AtomicInteger(0); @Override public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); String text = responseFromLLM.text(); return successWith(text + ",1"); } + + public int spy() { + return spy.get(); + } } public static class MyChatModelSupplier implements Supplier { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseTest.java new file mode 100644 index 000000000..342d18f89 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseTest.java @@ -0,0 +1,183 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import 
dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailException; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; + +public class OutputGuardrailOnTokenStreamedResponseTest extends TokenStreamExecutor { + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, + MyChatModelSupplier.class, MyMemoryProviderSupplier.class, ValidationException.class)); + + @Inject + MyAiService aiService; + + @Inject + OKGuardrail okGuardrail; + + @Inject + KOGuardrail koGuardrail; + + @Test + void testThatOutputGuardrailsAreInvoked() throws InterruptedException { + assertThat(Arc.container().requestContext().isActive()).isFalse(); + Arc.container().requestContext().activate(); + try { + assertThat(okGuardrail.spy()).isEqualTo(0); + execute(() -> aiService.hi("1")); + assertThat(okGuardrail.spy()).isEqualTo(1); + execute(() -> aiService.hi("2")); + assertThat(okGuardrail.spy()).isEqualTo(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } finally { + Arc.container().requestContext().deactivate(); + } + + 
Arc.container().requestContext().activate(); + try { + // New request scope - the value should be back to 0 + assertThat(okGuardrail.spy()).isEqualTo(0); + execute(() -> aiService.hi("1")); + assertThat(okGuardrail.spy()).isEqualTo(1); + execute(() -> aiService.hi("1")); + assertThat(okGuardrail.spy()).isEqualTo(2); + } finally { + Arc.container().requestContext().deactivate(); + } + + assertThat(Arc.container().requestContext().isActive()).isFalse(); + } + + @Test + @ActivateRequestContext + void testThatGuardrailCanThrowValidationException() { + assertThat(koGuardrail.spy()).isEqualTo(0); + assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> execute(() -> aiService.ko("1"))) + .havingCause() + .isInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(1); + assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> execute(() -> aiService.ko("1"))) + .havingCause() + .isInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(2); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + TokenStream hi(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + TokenStream ko(@MemoryId String mem); + + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + if 
(responseFromLLM.text().length() > 3) { // Accumulated response. + return failure("KO", new ValidationException("KO")); + } else { // Chunk, do not fail on the first chunk + if (responseFromLLM.text().contains("Hi!")) { + return success(); + } else { + return failure("KO", new ValidationException("KO")); + } + } + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatModel get() { + return new MyStreamedChatModel(); + } + } + + public static class MyStreamedChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + handler.onPartialResponse("Hi!"); + handler.onPartialResponse(" "); + handler.onPartialResponse("World!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseValidationTest.java new file mode 100644 index 000000000..f2de0c414 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnTokenStreamedResponseValidationTest.java @@ -0,0 +1,262 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import 
jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailException; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class OutputGuardrailOnTokenStreamedResponseValidationTest extends TokenStreamExecutor { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyStreamedChatModel.class)); + + @Inject + MyAiService aiService; + + @Inject + RetryingGuardrail retry; + + @Inject + RewritingGuardrail rewriting; + + @Inject + RetryingButFailGuardrail retryFail; + + @Inject + KOFatalGuardrail fatal; + + @Test + @ActivateRequestContext + void testOk() throws InterruptedException { + execute(() -> aiService.ok("1")); + } + + @Test + @ActivateRequestContext + void testKO() { + 
assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> execute(() -> aiService.ko("2"))) + .withMessageContaining("KO"); + } + + @Test + @ActivateRequestContext + void testRetryOk() throws InterruptedException { + execute(() -> aiService.retry("3")); + assertThat(retry.spy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testRetryFail() { + assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> execute(() -> aiService.retryButFail("4"))) + .withMessageContaining("maximum number of retries"); + assertThat(retryFail.spy()).isEqualTo(3); + } + + @Test + @ActivateRequestContext + void testFatalException() { + assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> execute(() -> aiService.fatal("5"))) + .withMessageContaining("Fatal"); + assertThat(fatal.spy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testRewritingWhileStreaming() throws InterruptedException { + assertThat(execute(() -> aiService.rewriting("1"))).isEqualTo("Hi! World! 
,1"); + assertThat(rewriting.spy()).isEqualTo(1); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + TokenStream ok(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + TokenStream ko(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingGuardrail.class) + TokenStream retry(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingButFailGuardrail.class) + TokenStream retryButFail(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOFatalGuardrail.class) + TokenStream fatal(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails({ RewritingGuardrail.class }) + TokenStream rewriting(@MemoryId String mem); + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return failure("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class RetryingGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + int v = spy.incrementAndGet(); + if (v >= 2) { + return OutputGuardrailResult.success(); + } + return retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class 
RetryingButFailGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class KOFatalGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + throw new IllegalArgumentException("Fatal"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class RewritingGuardrail implements OutputGuardrail { + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + String text = responseFromLLM.text(); + return successWith(text + ",1"); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatModel get() { + return new MyStreamedChatModel(); + } + } + + public static class MyStreamedChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + handler.onPartialResponse("Hi!"); + handler.onPartialResponse(" "); + handler.onPartialResponse("World!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return MessageWindowChatMemory.withMaxMessages(10); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java index 5b8397c83..65d084ded 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java @@ -16,17 +16,17 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -186,19 +186,19 @@ public interface MyAiService { @RequestScoped public static class GuardrailValidation implements OutputGuardrail { - OutputGuardrailParams params; + OutputGuardrailRequest params; - public OutputGuardrailResult validate(OutputGuardrailParams params) { - this.params = params; + public OutputGuardrailResult validate(OutputGuardrailRequest request) { + this.params = request; return success(); } public String spyUserMessageTemplate() { - return params.userMessageTemplate(); + return 
params.requestParams().userMessageTemplate(); } public Map spyVariables() { - return params.variables(); + return params.requestParams().variables(); } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java index 307f2c7c9..0081680b2 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java @@ -18,6 +18,10 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -26,12 +30,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; public class OutputGuardrailRepromptingRetryDisabledTest { @@ -114,7 +114,7 @@ public static class RetryGuardrail implements OutputGuardrail 
{ private final AtomicInteger spy = new AtomicInteger(0); @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { int v = spy.incrementAndGet(); return retry("Retry"); } @@ -130,7 +130,7 @@ public static class RepromptingGuardrail implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { int v = spy.incrementAndGet(); return reprompt("Retry", "reprompt"); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java index 397f9e95f..e32b374bf 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java @@ -3,6 +3,7 @@ import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.atIndex; import java.util.HashMap; import java.util.List; @@ -22,6 +23,10 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import 
dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -30,12 +35,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; public class OutputGuardrailRepromptingTest { @@ -125,26 +126,31 @@ public static class RepromptingTwo implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { int v = spy.incrementAndGet(); - List messages = params.memory().messages(); - if (v == 1) { - ChatMessage last = messages.get(messages.size() - 1); - assertThat(last).isInstanceOf(AiMessage.class); - assertThat(((AiMessage) last).text()).isEqualTo("Nope"); - assertThat(params.responseFromLLM().text()).isEqualTo("Nope"); - return reprompt("Retry", "Retry"); - } - if (v == 2) { - // Check that it's in memory - ChatMessage last = messages.get(messages.size() - 1); - ChatMessage beforeLast = messages.get(messages.size() - 2); - - assertThat(last).isInstanceOf(AiMessage.class); - assertThat(((AiMessage) last).text()).isEqualTo("Hello"); - assertThat(params.responseFromLLM().text()).isEqualTo("Hello"); - assertThat(beforeLast).isInstanceOf(UserMessage.class); - assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry"); + List messages = request.requestParams().chatMemory().messages(); + + if ((v == 1) || (v == 2)) { + 
assertThat(messages) + .hasSize(3) + .satisfies(message -> assertThat(message) + .isNotNull() + .isInstanceOf(dev.langchain4j.data.message.SystemMessage.class) + .extracting(m -> ((dev.langchain4j.data.message.SystemMessage) m).text()) + .isEqualTo("Say Hi!"), + atIndex(0)) + .satisfies(message -> assertThat(message) + .isNotNull() + .isInstanceOf(UserMessage.class) + .extracting(m -> ((UserMessage) m).singleText()) + .isEqualTo("foo"), + atIndex(1)) + .satisfies(message -> assertThat(message) + .isNotNull() + .isInstanceOf(AiMessage.class) + .extracting(m -> ((AiMessage) m).text()) + .isEqualTo("Nope"), + atIndex(2)); return reprompt("Retry", "Retry"); } @@ -165,9 +171,9 @@ public static class RepromptingFailed implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { int v = spy.incrementAndGet(); - List messages = params.memory().messages(); + List messages = request.requestParams().chatMemory().messages(); if (v == 1) { ChatMessage last = messages.get(messages.size() - 1); assertThat(last).isInstanceOf(AiMessage.class); @@ -175,14 +181,13 @@ public OutputGuardrailResult validate(OutputGuardrailParams params) { return reprompt("Retry", "Retry Once"); } if (v == 2) { - // Check that it's in memory + // Check that it's not in memory ChatMessage last = messages.get(messages.size() - 1); ChatMessage beforeLast = messages.get(messages.size() - 2); assertThat(last).isInstanceOf(AiMessage.class); - assertThat(((AiMessage) last).text()).isEqualTo("Hello"); + assertThat(((AiMessage) last).text()).isEqualTo("Nope"); assertThat(beforeLast).isInstanceOf(UserMessage.class); - assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry Once"); return reprompt("Retry", "Retry Twice"); } return reprompt("Retry", "Retry Again"); diff --git 
a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java index 27ed96830..f11dde864 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java @@ -17,6 +17,8 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -24,10 +26,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.arc.Arc; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java index 83d8806ea..52caf00db 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.test.guardrails; import static 
org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.concurrent.atomic.AtomicInteger; @@ -17,6 +18,10 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailException; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatModel; @@ -24,11 +29,8 @@ import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -91,6 +93,19 @@ void testFatalException() { assertThat(fatal.spy()).isEqualTo(1); } + @Test + @ActivateRequestContext + void noRetries() { + MyChatModelSupplier.CHAT_MODEL.spy.set(0); + assertThatExceptionOfType(OutputGuardrailException.class) + .isThrownBy(() -> aiService.noRetry("6")) + .withMessageContaining( + "Output validation failed. 
The guardrails have reached the maximum number of retries."); + + assertThat(MyChatModelSupplier.CHAT_MODEL.spy()).isEqualTo(1); + assertThat(retry.spy()).isEqualTo(1); + } + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) public interface MyAiService { @@ -110,6 +125,10 @@ public interface MyAiService { @OutputGuardrails(RetryingButFailGuardrail.class) String retryButFail(@MemoryId String mem); + @UserMessage("Say Hi!") + @OutputGuardrails(value = RetryingGuardrail.class, maxRetries = 0) + String noRetry(@MemoryId String mem); + @UserMessage("Say Hi!") @OutputGuardrails(KOFatalGuardrail.class) String fatal(@MemoryId String mem); @@ -147,7 +166,7 @@ public int spy() { } } - @ApplicationScoped + @RequestScoped public static class RetryingGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @@ -199,19 +218,26 @@ public int spy() { } public static class MyChatModelSupplier implements Supplier { + static final MyChatModel CHAT_MODEL = new MyChatModel(); @Override public ChatModel get() { - return new MyChatModel(); + return CHAT_MODEL; } } public static class MyChatModel implements ChatModel { + private final AtomicInteger spy = new AtomicInteger(0); @Override public ChatResponse doChat(ChatRequest request) { + spy.incrementAndGet(); return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); } + + public int spy() { + return spy.get(); + } } public static class MyMemoryProviderSupplier implements Supplier { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusGuardrailWithAugmentationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusGuardrailWithAugmentationTest.java new file mode 100644 index 000000000..c277b36f9 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusGuardrailWithAugmentationTest.java @@ -0,0 +1,219 @@ 
+package io.quarkiverse.langchain4j.test.guardrails; + +import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import 
io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +/** + * Verify that the input and output guardrails can access the augmentation results. + * + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusGuardrailWithAugmentationTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyInputGuardrail inputGuardrail; + @Inject + MyOutputGuardrail outputGuardrail; + + @Inject + MyAiService service; + + @Test + @ActivateRequestContext + void testInputOnly() { + String s = service.inputOnly("1", "foo"); + + assertThat(s).isEqualTo("Hi!"); + assertThat(inputGuardrail.getSpy()).isEqualTo(1); + assertThat(outputGuardrail.getSpy()).isEqualTo(0); + } + + @Test + @ActivateRequestContext + void testInputOnlyMulti() { + List list = service.inputOnlyMulti("2", "foo").collect().asList().await().indefinitely(); + + assertThat(inputGuardrail.getSpy()).isEqualTo(1); + assertThat(outputGuardrail.getSpy()).isEqualTo(0); + assertThat(String.join(" ", list)).isEqualTo("Streaming hi !"); + } + + @Test + @ActivateRequestContext + void testOutputOnly() { + String s = service.outputOnly("3", "foo"); + + assertThat(s).isEqualTo("Hi!"); + assertThat(inputGuardrail.getSpy()).isEqualTo(0); + assertThat(outputGuardrail.getSpy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testInputAndOutput() { + String s = service.inputAndOutput("4", "foo"); + + assertThat(s).isEqualTo("Hi!"); + assertThat(inputGuardrail.getSpy()).isEqualTo(1); + assertThat(outputGuardrail.getSpy()).isEqualTo(1); + } + + @RegisterAiService(chatLanguageModelSupplier = 
MyChatModelSupplier.class, streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class, retrievalAugmentor = MyRetrievalAugmentor.class) + public interface MyAiService { + + @InputGuardrails(MyInputGuardrail.class) + String inputOnly(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyInputGuardrail.class) + Multi inputOnlyMulti(@MemoryId String id, @UserMessage String message); + + @OutputGuardrails(MyOutputGuardrail.class) + String outputOnly(@MemoryId String id, @UserMessage String message); + + @InputGuardrails(MyInputGuardrail.class) + @OutputGuardrails(MyOutputGuardrail.class) + String inputAndOutput(@MemoryId String id, @UserMessage String message); + } + + @RequestScoped + public static class MyInputGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + spy.incrementAndGet(); + assertThat(params.augmentationResult().contents()).hasSize(2); + return InputGuardrailResult.success(); + } + + public int getSpy() { + return spy.get(); + } + } + + @RequestScoped + public static class MyOutputGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(); + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + spy.incrementAndGet(); + assertThat(params.augmentationResult().contents()).hasSize(2); + return OutputGuardrailResult.success(); + } + + public int getSpy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyStreamingChatModelSupplier implements Supplier { + + @Override + public StreamingChatModel get() { + return new MyStreamingChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse 
doChat(ChatRequest chatRequest) { + assertThat(chatMessageToText(chatRequest.messages().get(chatRequest.messages().size() - 1))).isEqualTo("augmented"); + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyStreamingChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + assertThat(chatMessageToText(chatRequest.messages().get(chatRequest.messages().size() - 1))).isEqualTo("augmented"); + handler.onPartialResponse("Streaming hi"); + handler.onPartialResponse("!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new MessageWindowChatMemory.Builder().maxMessages(5).build(); + } + }; + } + } + + public static class MyRetrievalAugmentor implements Supplier { + @Override + public RetrievalAugmentor get() { + return new RetrievalAugmentor() { + @Override + public AugmentationResult augment(AugmentationRequest augmentationRequest) { + List content = List.of(Content.from("content1"), Content.from("content2")); + return new AugmentationResult(dev.langchain4j.data.message.UserMessage.userMessage("augmented"), content); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputAndDeclarativeAiServiceOutputGuardrailsTest.java similarity index 97% rename from core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java rename to core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputAndDeclarativeAiServiceOutputGuardrailsTest.java 
index 2f983dff2..69fa148b5 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputAndDeclarativeAiServiceOutputGuardrailsTest.java @@ -35,7 +35,11 @@ import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkus.test.QuarkusUnitTest; -public class InputAndOutputGuardrailsTest { +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusInputAndDeclarativeAiServiceOutputGuardrailsTest { @RegisterExtension static final QuarkusUnitTest unitTest = new QuarkusUnitTest() diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailChainTest.java new file mode 100644 index 000000000..e84cd8ba6 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailChainTest.java @@ -0,0 +1,202 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import 
dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusInputGuardrailChainTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class, + ValidationException.class)); + + @Inject + MyAiService aiService; + + @Inject + FirstGuardrail firstGuardrail; + @Inject + SecondGuardrail secondGuardrail; + + @Inject + FailingGuardrail failingGuardrail; + + @Test + @ActivateRequestContext + void testThatGuardrailChainsAreInvoked() { + aiService.firstOneTwo("1", "foo"); + assertThat(firstGuardrail.spy()).isEqualTo(1); + assertThat(secondGuardrail.spy()).isEqualTo(1); + assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess()); + } + + @Test + @ActivateRequestContext + void testThatGuardrailOrderIsCorrect() { + aiService.twoAndFirst("1", "foo"); + assertThat(firstGuardrail.spy()).isEqualTo(1); + assertThat(secondGuardrail.spy()).isEqualTo(1); + 
assertThat(secondGuardrail.lastAccess()).isLessThan(firstGuardrail.lastAccess()); + } + + @Test + @ActivateRequestContext + void testFailureTheChain() { + assertThatThrownBy(() -> aiService.failingFirstTwo("1", "foo")) + .isInstanceOf(GuardrailException.class) + .hasCauseInstanceOf(ValidationException.class) + .hasRootCauseMessage("boom"); + assertThat(firstGuardrail.spy()).isEqualTo(1); + assertThat(secondGuardrail.spy()).isEqualTo(0); + assertThat(failingGuardrail.spy()).isEqualTo(1); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @InputGuardrails({ FirstGuardrail.class, SecondGuardrail.class }) + String firstOneTwo(@MemoryId String mem, @UserMessage String message); + + @InputGuardrails({ SecondGuardrail.class, FirstGuardrail.class }) + String twoAndFirst(@MemoryId String mem, @UserMessage String message); + + @InputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class }) + String failingFirstTwo(@MemoryId String mem, @UserMessage String message); + + } + + @RequestScoped + public static class FirstGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + AtomicLong lastAccess = new AtomicLong(); + + @Override + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { + spy.incrementAndGet(); + lastAccess.set(System.nanoTime()); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + // Ignore me + } + return success(); + } + + public int spy() { + return spy.get(); + } + + public long lastAccess() { + return lastAccess.get(); + } + } + + @RequestScoped + public static class SecondGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + volatile AtomicLong lastAccess = new AtomicLong(); + + @Override + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { + spy.incrementAndGet(); + 
lastAccess.set(System.nanoTime()); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + // Ignore me + } + return success(); + } + + public int spy() { + return spy.get(); + } + + public long lastAccess() { + return lastAccess.get(); + } + } + + @RequestScoped + public static class FailingGuardrail implements InputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { + if (spy.incrementAndGet() == 1) { + return fatal("boom", new ValidationException("boom")); + } + return success(); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailNotFoundTest.java new file mode 100644 index 000000000..80050e559 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailNotFoundTest.java @@ -0,0 +1,99 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Fail.fail; + +import java.util.function.Supplier; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import 
jakarta.enterprise.inject.spi.DeploymentException;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Build-failure test: {@link MyAiService#hi(String)} references {@link MissingGuardRail} through
+ * {@code @InputGuardrails}, but that class is not listed in the test archive. The deployment is expected to fail
+ * with a {@link DeploymentException} whose message names the missing guardrail class.
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailNotFoundTest {
+
+    // MissingGuardRail is intentionally absent from addClasses(...): its absence is what the test is about.
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class))
+            .assertException(t -> {
+                assertThat(t).isInstanceOf(DeploymentException.class);
+                assertThat(t).hasMessageContaining(
+                        "io.quarkiverse.langchain4j.test.guardrails.QuarkusInputGuardrailNotFoundTest$MissingGuardRail");
+            });
+
+    /**
+     * Placeholder test method; it fails unconditionally because the expected deployment error should prevent it
+     * from ever running.
+     */
+    @Test
+    @ActivateRequestContext
+    void testThatNotFoundGuardrailsAreReported() {
+        fail("Should not be called");
+    }
+
+    /** AI service whose only method is guarded by the guardrail that is missing from the archive. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(MissingGuardRail.class)
+        String hi(@MemoryId String mem);
+
+    }
+
+    /** Guardrail referenced by {@link MyAiService} but never registered; invoking it is treated as an error. */
+    public static class MissingGuardRail implements InputGuardrail {
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            throw new RuntimeException("Should not be invoked");
+        }
+
+    }
+
+    /** Supplies the stub chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    /** Stub chat model that answers every request with the fixed message "Hi!". */
+    public static class MyChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) {
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassAndMethodTest.java
new file mode 100644
index 000000000..5d281e12b
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassAndMethodTest.java
@@ -0,0 +1,134 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import
dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks guardrail selection when {@code @InputGuardrails} is declared on both the AI service interface and one
+ * of its methods. {@link MyAiService} declares {@link KOGuardrail} at type level while
+ * {@link MyAiService#hi(String)} declares {@link OKGuardrail}; the test expects only the method-level guardrail
+ * to be invoked.
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailOnClassAndMethodTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    OKGuardrail okGuardrail;
+
+    @Inject
+    KOGuardrail koGuardrail;
+
+    /**
+     * Calls the service twice and expects the method-level {@link OKGuardrail} to run once per call, while the
+     * type-level {@link KOGuardrail} is never invoked.
+     */
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailsFromTheClassAreInvoked() {
+        assertThat(okGuardrail.spy()).isEqualTo(0);
+        aiService.hi("1");
+        assertThat(okGuardrail.spy()).isEqualTo(1);
+        aiService.hi("2");
+        assertThat(okGuardrail.spy()).isEqualTo(2);
+
+        // The type-level guardrail must not have run: the method-level declaration is the one being exercised.
+        assertThat(koGuardrail.spy()).isEqualTo(0);
+    }
+
+    /** AI service with a failing guardrail on the type and a passing guardrail on its single method. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    @InputGuardrails(KOGuardrail.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(OKGuardrail.class)
+        String hi(@MemoryId String mem);
+
+    }
+
+    /** Guardrail that always succeeds and counts how many times it was invoked. */
+    @ApplicationScoped
+    public static class OKGuardrail implements InputGuardrail {
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return success();
+        }
+
+        /** @return the number of {@code validate} calls seen so far */
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    /** Guardrail that always fails with the message "KO" and counts how many times it was invoked. */
+    @ApplicationScoped
+    public static class KOGuardrail implements InputGuardrail {
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return failure("KO");
+        }
+
+        /** @return the number of {@code validate} calls seen so far */
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    /** Supplies the stub chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    /** Stub chat model that answers every request with the fixed message "Hi!". */
+    public static class MyChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) {
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassTest.java
new file mode 100644
index 000000000..43e2335f3
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailOnClassTest.java
@@ -0,0 +1,112 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import
jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks that an {@code @InputGuardrails} annotation placed on the AI service interface applies to a method that
+ * declares no guardrails of its own.
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailOnClassTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, OKGuardrail.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    OKGuardrail okGuardrail;
+
+    /** Calls the service twice and expects the type-level guardrail's counter to advance by one per call. */
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailsFromTheClassAreInvoked() {
+        assertThat(okGuardrail.spy()).isEqualTo(0);
+        aiService.hi("1");
+        assertThat(okGuardrail.spy()).isEqualTo(1);
+        aiService.hi("2");
+        assertThat(okGuardrail.spy()).isEqualTo(2);
+    }
+
+    /** AI service whose guardrail is declared on the interface only. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    @InputGuardrails(OKGuardrail.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        String hi(@MemoryId String mem);
+
+    }
+
+    /** Guardrail that always succeeds and counts how many times it was invoked. */
+    @ApplicationScoped
+    public static class OKGuardrail implements InputGuardrail {
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage ignored) {
+            spy.incrementAndGet();
+            return success();
+        }
+
+        /** @return the number of {@code validate} calls seen so far */
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    /** Supplies the stub chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    /** Stub chat model that answers every request with the fixed message "Hi!". */
+    public static class MyChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) {
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailPromptTemplateTest.java
new file mode 100644
index 000000000..cc33f6f55
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailPromptTemplateTest.java
@@ -0,0 +1,238 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import
org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks what an input guardrail receives through {@link InputGuardrailParams}: the raw user-message template,
+ * the rendered user-message text, and the template variables, for AI service methods with different parameter
+ * shapes (no parameters, a memory id, one or several arguments, lists, and no {@code @UserMessage} at all).
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailPromptTemplateTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, GuardrailValidation.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    GuardrailValidation guardrailValidation;
+
+    /** A template without placeholders and a method without parameters yields no variables. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkNoParameters() {
+        aiService.getJoke();
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke");
+        assertThat(guardrailValidation.spyVariables()).isEmpty();
+    }
+
+    /** A lone memory-id parameter is exposed under its own name and under {@code it}. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithMemoryId() {
+        aiService.getAnotherJoke("memory-id-001");
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
+        assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
+                "memoryId", "memory-id-001",
+                "it", "memory-id-001"));
+    }
+
+    /** A single plain parameter is exposed under its own name and under {@code it}. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithNoMemoryIdAndOneParameter() {
+        aiService.sayHiToMyFriendNoMemory("Rambo");
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "friend", "Rambo",
+                        "it", "Rambo"));
+    }
+
+    /** With a memory id plus one argument, both are exposed by name and no {@code it} entry is expected. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithMemoryIdAndOneParameter() {
+        aiService.sayHiToMyFriend("1", "Chuck Norris");
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "friend", "Chuck Norris",
+                        "mem", "1"));
+    }
+
+    /** Three plain parameters are exposed by name only. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithNoMemoryIdAndThreeParameters() {
+        aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone");
+        assertThat(guardrailValidation.spyUserMessageTemplate())
+                .isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "topic1", "Chuck Norris",
+                        "topic2", "Jean-Claude Van Damme",
+                        "topic3", "Silvester Stallone"));
+    }
+
+    /** A single list parameter is rendered into the text and exposed under its name and under {@code it}. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithNoMemoryIdAndList() {
+        aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
+        assertThat(guardrailValidation.spyUserMessageText())
+                .isEqualTo("Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]!");
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
+                        "it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")));
+    }
+
+    /** A memory id and a list are both rendered into the text and exposed by name. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithMemoryIdAndList() {
+        aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
+        assertThat(guardrailValidation.spyUserMessageText()).isEqualTo(
+                "Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]! This is my memory id: memory-id-007");
+        assertThat(guardrailValidation.spyUserMessageTemplate())
+                .isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
+                        "memoryId", "memory-id-007"));
+    }
+
+    /** An indexed placeholder renders a single list element, while the variables still hold the whole list. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithMemoryIdAndOneItemFromList() {
+        aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
+        assertThat(guardrailValidation.spyUserMessageText())
+                .isEqualTo("Tell me something about Chuck Norris! This is my memory id: memory-id-007");
+        assertThat(guardrailValidation.spyUserMessageTemplate())
+                .isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}");
+        assertThat(guardrailValidation.spyVariables())
+                .containsExactlyInAnyOrderEntriesOf(Map.of(
+                        "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
+                        "memoryId", "memory-id-007"));
+    }
+
+    /** Without a {@code @UserMessage} annotation the guardrail sees an empty template and no variables. */
+    @Test
+    @ActivateRequestContext
+    void shouldWorkWithNoUserMessage() {
+        // UserMessage annotation is not provided, then no user message template should be available
+        aiService.saySomething("Is this a parameter or a prompt?");
+        assertThat(guardrailValidation.spyUserMessageTemplate()).isEmpty();
+        assertThat(guardrailValidation.spyVariables()).isEmpty();
+    }
+
+    /** AI service whose methods are all guarded by {@link GuardrailValidation} and differ in parameter shape. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @InputGuardrails(GuardrailValidation.class)
+        @UserMessage("Tell me a joke")
+        String getJoke();
+
+        @UserMessage("Tell me another joke")
+        @InputGuardrails(GuardrailValidation.class)
+        String getAnotherJoke(@MemoryId String memoryId);
+
+        @UserMessage("Say hi to my friend {friend}!")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriendNoMemory(String friend);
+
+        @UserMessage("Say hi to my friend {friend}!")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriend(@MemoryId String mem, String friend);
+
+        @UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriends(String topic1, String topic2, String topic3);
+
+        @UserMessage("Tell me something about {topics}!")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriends(List<String> topics);
+
+        @UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriends(@MemoryId String memoryId, List<String> topics);
+
+        @UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}")
+        @InputGuardrails(GuardrailValidation.class)
+        String sayHiToMyFriend(@MemoryId String memoryId, List<String> topics);
+
+        @InputGuardrails(GuardrailValidation.class)
+        String saySomething(String isThisAPromptOrAParameter);
+
+    }
+
+    /**
+     * Guardrail that always succeeds and remembers the parameters of its most recent invocation so the tests can
+     * inspect them. The accessors dereference the stored parameters and therefore must only be called after the
+     * guardrail has run.
+     */
+    @RequestScoped
+    public static class GuardrailValidation implements InputGuardrail {
+
+        InputGuardrailParams params;
+
+        @Override
+        public InputGuardrailResult validate(InputGuardrailParams params) {
+            this.params = params;
+            return success();
+        }
+
+        /** @return the user-message template captured by the last invocation */
+        public String spyUserMessageTemplate() {
+            return params.userMessageTemplate();
+        }
+
+        /** @return the rendered user-message text captured by the last invocation */
+        public String spyUserMessageText() {
+            return params.userMessage().singleText();
+        }
+
+        /** @return the template variables captured by the last invocation */
+        public Map<String, Object> spyVariables() {
+            return params.variables();
+        }
+    }
+
+    /** Supplies the stub chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    /** Stub chat model that answers every request with the fixed message "Hi!". */
+    public static class MyChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return memoryId -> new NoopChatMemory();
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailRewritingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailRewritingTest.java
new file mode 100644
index 000000000..8d4dc25a2
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailRewritingTest.java
@@ -0,0 +1,102 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import
jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks that an input guardrail can rewrite the user message: {@link MessageTruncatingGuardrail} truncates the
+ * prompt, {@link EchoChatModel} echoes back whatever it receives, and the test asserts on the length of the
+ * echoed text.
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailRewritingTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, MessageTruncatingGuardrail.class, EchoChatModel.class,
+                            MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
+
+    @Inject
+    MyAiService aiService;
+
+    /** The echoed answer must be exactly {@link MessageTruncatingGuardrail#MAX_LENGTH} characters long. */
+    @Test
+    @ActivateRequestContext
+    void testRewriting() {
+        assertEquals(MessageTruncatingGuardrail.MAX_LENGTH, aiService.test("first prompt", "second prompt").length());
+    }
+
+    /** AI service whose prompt is longer than the truncation limit once its placeholders are filled in. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @UserMessage("Given {first} and {second} do something")
+        @InputGuardrails(MessageTruncatingGuardrail.class)
+        String test(String first, String second);
+
+    }
+
+    /** Guardrail that succeeds with a rewritten message holding at most {@link #MAX_LENGTH} characters. */
+    @RequestScoped
+    public static class MessageTruncatingGuardrail implements InputGuardrail {
+
+        static final int MAX_LENGTH = 20;
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            String text = um.singleText();
+            // Clamp the end index so a prompt shorter than MAX_LENGTH is passed through unchanged instead of
+            // failing with StringIndexOutOfBoundsException.
+            return successWith(text.substring(0, Math.min(MAX_LENGTH, text.length())));
+        }
+    }
+
+    /** Supplies the echoing chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new EchoChatModel();
+        }
+    }
+
+    /** Chat model that replies with the text of the first message of the request, which it casts to a user message. */
+    public static class EchoChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder()
+                    .aiMessage(
+                            new AiMessage(((dev.langchain4j.data.message.UserMessage) request.messages().get(0)).singleText()))
+                    .build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) {
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailTest.java
new file mode 100644
index 000000000..e78d9dce3
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailTest.java
@@ -0,0 +1,160 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import
org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.arc.Arc;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks that method-level input guardrails are invoked on every call, that a request-scoped guardrail gets
+ * fresh state in each request context, and that an exception attached to a guardrail failure surfaces as the
+ * cause of the exception thrown by the AI service.
+ *
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class,
+                            ValidationException.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    OKGuardrail okGuardrail;
+    @Inject
+    KOGuardrail koGuardrail;
+
+    /**
+     * Drives two request contexts by hand and expects the request-scoped {@link OKGuardrail} counter to advance
+     * within a context and to start again from zero in the next one.
+     */
+    @Test
+    void testThatInputGuardrailsAreInvoked() {
+        assertThat(Arc.container().requestContext().isActive()).isFalse();
+
+        // Each manually activated context is ended in a finally block so that a failing assertion cannot leave
+        // an active request context behind for the tests that run afterwards.
+        Arc.container().requestContext().activate();
+        try {
+            assertThat(okGuardrail.spy()).isEqualTo(0);
+            aiService.hi("1");
+            assertThat(okGuardrail.spy()).isEqualTo(1);
+            aiService.hi("2");
+            assertThat(okGuardrail.spy()).isEqualTo(2);
+        } finally {
+            Arc.container().requestContext().terminate();
+        }
+
+        Arc.container().requestContext().activate();
+        try {
+            // New request scope - the value should be back to 0
+            assertThat(okGuardrail.spy()).isEqualTo(0);
+            aiService.hi("1");
+            assertThat(okGuardrail.spy()).isEqualTo(1);
+            aiService.hi("2");
+            assertThat(okGuardrail.spy()).isEqualTo(2);
+        } finally {
+            Arc.container().requestContext().terminate();
+        }
+    }
+
+    /**
+     * Expects each call guarded by {@link KOGuardrail} to throw with a {@link ValidationException} as its cause,
+     * and the application-scoped counter to keep growing across calls.
+     */
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailCanThrowValidationException() {
+        assertThat(koGuardrail.spy()).isEqualTo(0);
+        assertThatThrownBy(() -> aiService.ko("1"))
+                .hasCauseExactlyInstanceOf(ValidationException.class);
+        assertThat(koGuardrail.spy()).isEqualTo(1);
+        assertThatThrownBy(() -> aiService.ko("1"))
+                .hasCauseExactlyInstanceOf(ValidationException.class);
+        assertThat(koGuardrail.spy()).isEqualTo(2);
+    }
+
+    /** AI service with one method guarded by a passing guardrail and one guarded by a failing guardrail. */
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(OKGuardrail.class)
+        String hi(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(KOGuardrail.class)
+        String ko(@MemoryId String mem);
+
+    }
+
+    /** Request-scoped guardrail that always succeeds and counts its invocations. */
+    @RequestScoped
+    public static class OKGuardrail implements InputGuardrail {
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return success();
+        }
+
+        /** @return the number of {@code validate} calls seen so far */
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    /** Application-scoped guardrail that always fails with a {@link ValidationException} and counts its invocations. */
+    @ApplicationScoped
+    public static class KOGuardrail implements InputGuardrail {
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return failure("KO", new ValidationException("KO"));
+        }
+
+        /** @return the number of {@code validate} calls seen so far */
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    /** Supplies the stub chat model used by {@link MyAiService}. */
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    /** Stub chat model that answers every request with the fixed message "Hi!". */
+    public static class MyChatModel implements ChatModel {
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    /** Supplies a memory provider that hands out a new {@link NoopChatMemory} for every memory id. */
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) {
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailValidationTest.java
new file mode 100644
index 000000000..d6d26fbb5
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInputGuardrailValidationTest.java
@@ -0,0 +1,251 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.StreamingChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+/** Exercises passing, failing, throwing and memory-inspecting input guardrails on blocking and streaming AI service methods.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInputGuardrailValidationTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class,
+                            MyChatModel.class, MyStreamingChatModel.class, MyChatModelSupplier.class,
+                            MyMemoryProviderSupplier.class, ValidationException.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Test
+    @ActivateRequestContext
+    void testOk() { // a guardrail returning success() must let the call complete without throwing
+        aiService.ok("1");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testKo() { // failure("KO") must surface as a GuardrailException mentioning "KO"
+        assertThatThrownBy(() -> aiService.ko("2"))
+                .isInstanceOf(GuardrailException.class)
+                .hasMessageContaining("KO");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testOkMulti() { // streaming variant: both tokens emitted by MyStreamingChatModel must come through
+        List<String> strings = aiService.okMulti("1").log()
+                .collect().asList().await().indefinitely();
+
+        assertThat(String.join(" ", strings)).isEqualTo("Streaming hi !"); // "Streaming hi" + " " + "!"
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testKoMulti() { // the lambda never iterates, so the exception is expected before any item is consumed
+        assertThatThrownBy(() -> aiService.koMulti("2").subscribe().asIterable())
+                .isInstanceOf(GuardrailException.class)
+                .hasMessageContaining("KO");
+    }
+
+    @Inject
+    KOFatalGuardrail fatal;
+
+    @Test
+    @ActivateRequestContext
+    void testFatalException() { // an exception thrown by the guardrail itself is reported as a GuardrailException
+        assertThatThrownBy(() -> aiService.fatal("5"))
+                .isInstanceOf(GuardrailException.class)
+                .hasMessageContaining("Fatal");
+        assertThat(fatal.spy()).isEqualTo(1); // the throwing guardrail ran exactly once
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testMemory() { // two calls on the same memory id; the assertions live in MemoryCheck
+        aiService.test("1", "foo");
+        aiService.test("1", "bar");
+    }
+
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService { // one method per guardrail scenario under test
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(OKGuardrail.class)
+        String ok(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(KOGuardrail.class)
+        String ko(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(OKGuardrail.class)
+        Multi<String> okMulti(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(KOGuardrail.class)
+        Multi<String> koMulti(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @InputGuardrails(KOFatalGuardrail.class)
+        String fatal(@MemoryId String mem);
+
+        @InputGuardrails(MemoryCheck.class)
+        String test(@MemoryId String name, @UserMessage String message);
+    }
+
+    @RequestScoped
+    public static class OKGuardrail implements InputGuardrail { // always succeeds; counts invocations
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    @ApplicationScoped
+    public static class KOGuardrail implements InputGuardrail { // always returns failure("KO"); counts invocations
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            return failure("KO");
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    @ApplicationScoped
+    public static class KOFatalGuardrail implements InputGuardrail { // always throws; counts invocations
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
+            spy.incrementAndGet();
+            throw new IllegalArgumentException("Fatal");
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    @RequestScoped
+    public static class MemoryCheck implements InputGuardrail { // asserts on the chat memory it is handed
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public InputGuardrailResult validate(InputGuardrailParams params) {
+            spy.incrementAndGet();
+            if (params.memory().messages().isEmpty()) { // empty memory: the current message must be "foo"
+                assertThat(params.userMessage().singleText()).isEqualTo("foo");
+            }
+            if (params.memory().messages().size() == 2) { // memory holds "foo" and the reply "Hi!"; current message must be "bar"
+                assertThat(chatMessageToText(params.memory().messages().get(0))).isEqualTo("foo");
+                assertThat(chatMessageToText(params.memory().messages().get(1))).isEqualTo("Hi!");
+                assertThat(params.userMessage().singleText()).isEqualTo("bar");
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatModel> {
+
+        @Override
+        public StreamingChatModel get() {
+            return new MyStreamingChatModel();
+        }
+    }
+
+    public static class MyChatModel implements ChatModel { // canned blocking model: always answers "Hi!"
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    public static class MyStreamingChatModel implements StreamingChatModel { // canned streaming model: two tokens, then completion
+
+        @Override
+        public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
+            handler.onPartialResponse("Streaming hi");
+            handler.onPartialResponse("!");
+            handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
+        }
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // a fresh 5-message window on every call to this provider
+                    return new MessageWindowChatMemory.Builder().maxMessages(5).build();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInvalidOutputGuardrailAccumulatorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInvalidOutputGuardrailAccumulatorTest.java
new file mode 100644
index 000000000..d7491bb12
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusInvalidOutputGuardrailAccumulatorTest.java
@@ -0,0 +1,119 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.enterprise.inject.spi.DeploymentException;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.chat.StreamingChatModel;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
+import
io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
+import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+/** Expects deployment to fail with a message naming {@code MyAiService.hi}, which declares an accumulator but returns {@code String}.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusInvalidOutputGuardrailAccumulatorTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class,
+                            MyMemoryProviderSupplier.class))
+            .assertException(t -> { // the failure must be a DeploymentException that names the offending method
+                assertThat(t).isInstanceOf(DeploymentException.class);
+                assertThat(t).hasMessageContaining(
+                        "io.quarkiverse.langchain4j.test.guardrails.QuarkusInvalidOutputGuardrailAccumulatorTest$MyAiService.hi");
+            });
+
+    @Test
+    @ActivateRequestContext
+    void testThatInvalidAccumulatorAreReported() { // placeholder: fails outright if the test body is ever reached
+        fail("Should not be called");
+    }
+
+    @RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardRail.class)
+        @OutputGuardrailAccumulator(MyAccumulator.class)
+        String hi(@MemoryId String mem); // NOTE(review): presumably rejected because an accumulator only applies to Multi-returning methods - confirm
+
+    }
+
+    @ApplicationScoped
+    public static class MyAccumulator implements OutputTokenAccumulator { // pass-through: returns the token stream unchanged
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens;
+        }
+    }
+
+    @ApplicationScoped
+    public static class MyGuardRail implements OutputGuardrail { // throws if it is ever invoked
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            throw new RuntimeException("Should not be invoked");
+        }
+
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // a fresh 5-message window on every call to this provider
+                    return new MessageWindowChatMemory.Builder().maxMessages(5).build();
+                }
+            };
+        }
+    }
+
+    public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatModel> {
+
+        @Override
+        public StreamingChatModel get() {
+            return new StreamingChatModel() {
+                @Override
+                public void chat(List<ChatMessage> messages, StreamingChatResponseHandler handler) { // five canned tokens, then completion
+                    handler.onPartialResponse("Stream");
+                    handler.onPartialResponse("ing");
+                    handler.onPartialResponse(" ");
+                    handler.onPartialResponse("world");
+                    handler.onPartialResponse("!");
+                    handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
+                }
+            };
+        }
+    }
+
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorNotFoundTest.java
new file mode 100644
index 000000000..4a0f102cc
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorNotFoundTest.java
@@ -0,0 +1,119 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Fail.fail;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.enterprise.inject.spi.DeploymentException;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import
dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.chat.StreamingChatModel;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
+import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+/** Expects deployment to fail with a message naming {@code MissingAccumulator}, an accumulator class that is not a CDI bean.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusOutputGuardrailAccumulatorNotFoundTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class,
+                            MyMemoryProviderSupplier.class))
+            .assertException(t -> { // the failure must be a DeploymentException that names the missing accumulator
+                assertThat(t).isInstanceOf(DeploymentException.class);
+                assertThat(t).hasMessageContaining(
+                        "io.quarkiverse.langchain4j.test.guardrails.QuarkusOutputGuardrailAccumulatorNotFoundTest$MissingAccumulator");
+            });
+
+    @Test
+    @ActivateRequestContext
+    void testThatNotFoundAccumulatorAreReported() { // placeholder: fails outright if the test body is ever reached
+        fail("Should not be called");
+    }
+
+    @RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService {
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardRail.class)
+        @OutputGuardrailAccumulator(MissingAccumulator.class)
+        Multi<String> hi(@MemoryId String mem); // references an accumulator that carries no bean-defining annotation
+
+    }
+
+    // Not a bean
+    public static class MissingAccumulator implements OutputTokenAccumulator {
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens;
+        }
+    }
+
+    @ApplicationScoped
+    public static class MyGuardRail implements OutputGuardrail { // throws if it is ever invoked
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            throw new RuntimeException("Should not be invoked");
+        }
+
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // a fresh 5-message window on every call to this provider
+                    return new MessageWindowChatMemory.Builder().maxMessages(5).build();
+                }
+            };
+        }
+    }
+
+    public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatModel> {
+
+        @Override
+        public StreamingChatModel get() {
+            return new StreamingChatModel() {
+                @Override
+                public void chat(List<ChatMessage> messages, StreamingChatResponseHandler handler) { // five canned tokens, then completion
+                    handler.onPartialResponse("Stream");
+                    handler.onPartialResponse("ing");
+                    handler.onPartialResponse(" ");
+                    handler.onPartialResponse("world");
+                    handler.onPartialResponse("!");
+                    handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
+                }
+            };
+        }
+    }
+
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorTest.java
new file mode 100644
index 000000000..0f79ab1ae
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailAccumulatorTest.java
@@ -0,0 +1,288 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import
jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.chat.StreamingChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
+import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+/** Checks which chunks an output guardrail receives when a streamed response goes through different token accumulators.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusOutputGuardrailAccumulatorTest {
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyStreamingChatModelSupplier.class, MyMemoryProviderSupplier.class, MyGuardrail.class,
+                            SizeBasedAccumulator.class, SeparatorAccumulator.class));
+
+    @Inject
+    MyAiService ai;
+    @Inject
+    MyGuardrail guardrail;
+
+    @Test
+    @ActivateRequestContext
+    void testWithoutAccumulator() { // no accumulator declared: the guardrail sees the whole response as one chunk
+        guardrail.reset();
+        var list = ai.usingDefaultAccumulator("Roger")
+                .collect().asList()
+                .await().indefinitely();
+        assertThat(guardrail.count()).isEqualTo(1);
+        assertThat(guardrail.chunks()).containsExactly("Streaming world!");
+        assertThat(list).containsExactly("Streaming world!");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testWithSizeBasedAccumulator() { // 5-character chunks plus the 1-character remainder
+        guardrail.reset();
+        var list = ai.usingSizeAccumulator("Roger")
+                .collect().asList()
+                .await().indefinitely();
+        assertThat(guardrail.count()).isEqualTo(4);
+        assertThat(guardrail.chunks()).containsExactly("Strea", "ming ", "world", "!");
+        assertThat(list).containsExactly("Strea", "ming ", "world", "!");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testWithSeparatorBasedAccumulator() { // chunks are split on the space, which is dropped
+        guardrail.reset();
+        var list = ai.usingSeparatorAccumulator("Roger")
+                .collect().asList()
+                .await().indefinitely();
+        assertThat(guardrail.count()).isEqualTo(2);
+        assertThat(guardrail.chunks()).containsExactly("Streaming", "world!");
+        assertThat(list).containsExactly("Streaming", "world!");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testWithFailingAccumulator() { // the accumulator fails on the token containing "w"; earlier tokens were already validated
+        guardrail.reset();
+        assertThatThrownBy(() -> {
+            ai.usingFailingAccumulator("Roger")
+                    .collect().asList()
+                    .await().indefinitely();
+        }).isInstanceOf(IllegalArgumentException.class);
+        assertThat(guardrail.count()).isEqualTo(3);
+        assertThat(guardrail.chunks()).containsExactly("Stream", "ing", " ");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testWithThrowingAccumulator() { // accumulate() itself throws, so the guardrail never runs
+        guardrail.reset();
+        assertThatThrownBy(() -> {
+            ai.usingThrowingAccumulator("Roger")
+                    .collect().asList()
+                    .await().indefinitely();
+        }).isInstanceOf(IllegalArgumentException.class);
+        assertThat(guardrail.count()).isEqualTo(0);
+    }
+
+    @RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService { // one streaming method per accumulator under test
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardrail.class)
+        Multi<String> usingDefaultAccumulator(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardrail.class)
+        @OutputGuardrailAccumulator(SizeBasedAccumulator.class)
+        Multi<String> usingSizeAccumulator(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardrail.class)
+        @OutputGuardrailAccumulator(SeparatorAccumulator.class)
+        Multi<String> usingSeparatorAccumulator(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardrail.class)
+        @OutputGuardrailAccumulator(FailingAccumulator.class)
+        Multi<String> usingFailingAccumulator(@MemoryId String mem);
+
+        @UserMessage("Say Hi!")
+        @OutputGuardrails(MyGuardrail.class)
+        @OutputGuardrailAccumulator(ThrowingAccumulator.class)
+        Multi<String> usingThrowingAccumulator(@MemoryId String mem);
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // a fresh 5-message window on every call to this provider
+                    return new MessageWindowChatMemory.Builder().maxMessages(5).build();
+                }
+            };
+        }
+    }
+
+    public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatModel> {
+
+        @Override
+        public StreamingChatModel get() {
+            return new StreamingChatModel() {
+                @Override
+                public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { // five canned tokens, then completion
+                    handler.onPartialResponse("Stream");
+                    handler.onPartialResponse("ing");
+                    handler.onPartialResponse(" ");
+                    handler.onPartialResponse("world");
+                    handler.onPartialResponse("!");
+                    handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
+                }
+            };
+        }
+    }
+
+    @ApplicationScoped
+    public static class MyGuardrail implements OutputGuardrail { // always succeeds; records every chunk it validates
+
+        AtomicInteger count = new AtomicInteger();
+        List<String> chunks = new ArrayList<>();
+
+        public int count() {
+            return count.get();
+        }
+
+        public List<String> chunks() {
+            return chunks;
+        }
+
+        public void reset() { // called at the start of each test; the bean is application scoped
+            count.set(0);
+            chunks.clear();
+        }
+
+        @Override
+        public OutputGuardrailResult validate(OutputGuardrailParams params) {
+            count.incrementAndGet();
+            chunks.add(params.responseFromLLM().text());
+            return success();
+        }
+    }
+
+    @ApplicationScoped
+    public static class SizeBasedAccumulator implements OutputTokenAccumulator { // re-chunks the stream into 5-character items
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens
+                    .withContext((multi, context) -> {
+                        context.put("buffer", new StringBuffer()); // pending characters, kept in the subscription context
+                        return multi
+                                .flatMap(token -> {
+                                    StringBuffer buffer = context.get("buffer");
+                                    buffer.append(token);
+                                    if (buffer.length() >= 5) { // emits at most one 5-character item per incoming token
+                                        var item = buffer.substring(0, 5);
+                                        buffer.delete(0, 5);
+                                        return Multi.createFrom().item(item);
+                                    }
+                                    return Multi.createFrom().empty();
+                                })
+                                .onCompletion().switchTo(() -> { // flush whatever is left once upstream completes
+                                    StringBuffer buffer = context.get("buffer");
+                                    if (!buffer.isEmpty()) {
+                                        return Multi.createFrom().item(buffer.toString());
+                                    }
+                                    return Multi.createFrom().empty();
+                                });
+                    });
+
+        }
+    }
+
+    @ApplicationScoped
+    public static class SeparatorAccumulator implements OutputTokenAccumulator { // re-chunks the stream at each space
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens
+                    .withContext((multi, context) -> {
+                        context.put("buffer", new StringBuffer()); // pending characters, kept in the subscription context
+                        return multi
+                                .flatMap(token -> {
+                                    StringBuffer buffer = context.get("buffer");
+                                    buffer.append(token);
+                                    var idx = buffer.indexOf(" ");
+                                    if (idx != -1) {
+                                        var item = buffer.substring(0, idx + 1).trim(); // Drop the space
+                                        buffer.delete(0, idx + 1);
+                                        return Multi.createFrom().item(item);
+                                    }
+                                    return Multi.createFrom().empty();
+                                })
+                                .onCompletion().switchTo(() -> { // flush whatever is left once upstream completes
+                                    StringBuffer buffer = context.get("buffer");
+                                    if (!buffer.isEmpty()) {
+                                        return Multi.createFrom().item(buffer.toString().trim());
+                                    }
+                                    return Multi.createFrom().empty();
+                                });
+                    });
+
+        }
+    }
+
+    @ApplicationScoped
+    public static class FailingAccumulator implements OutputTokenAccumulator { // passes tokens through until one contains "w"
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens
+                    .map(s -> {
+                        if (s.contains("w")) {
+                            throw new IllegalArgumentException("I don't like W");
+                        }
+                        return s;
+                    });
+        }
+    }
+
+    @ApplicationScoped
+    public static class ThrowingAccumulator implements OutputTokenAccumulator { // throws as soon as accumulate() is called
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            throw new IllegalArgumentException("Boom");
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainOnStreamedResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainOnStreamedResponseTest.java
new file mode 100644
index 000000000..4526f534c
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainOnStreamedResponseTest.java
@@ -0,0 +1,256 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.StreamingChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
+import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+/** Checks invocation count and ordering of chained output guardrails on streamed responses, with and without a pass-through accumulator.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusOutputGuardrailChainOnStreamedResponseTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    FirstGuardrail firstGuardrail;
+    @Inject
+    SecondGuardrail secondGuardrail;
+
+    @Inject
+    FailingGuardrail failingGuardrail;
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailChainsAreInvoked() { // each guardrail runs once, in declaration order
+        aiService.firstOneTwo("1", "foo").collect().asList().await().indefinitely();
+        assertThat(firstGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailOrderIsCorrect() {
+        aiService.twoAndFirst("1", "foo").collect().asList().await().indefinitely();
+        // Reversed declaration order: SecondGuardrail must be stamped before FirstGuardrail.
+        assertThat(firstGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.lastAccess()).isLessThan(firstGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRetryRestartTheChain() {
+        aiService.failingFirstTwo("1", "foo").collect().asList().await().indefinitely();
+        // FailingGuardrail reprompts on its first call, so the chain restarts from FirstGuardrail.
+        assertThat(firstGuardrail.spy()).isEqualTo(2);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(failingGuardrail.spy()).isEqualTo(2);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailChainsAreInvokedWithPassThroughAccumulator() { // pass-through: the chain runs once per streamed token (3)
+        aiService.firstOneTwoWithPassThroughAccumulator("1", "foo").collect().asList().await().indefinitely();
+        assertThat(firstGuardrail.spy()).isEqualTo(3);
+        assertThat(secondGuardrail.spy()).isEqualTo(3);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailOrderIsCorrectWithPassThroughAccumulator() {
+        aiService.twoAndFirstWithPassThroughAccumulator("1", "foo").collect().asList().await().indefinitely();
+        // Reversed declaration order, checked on the last of the three tokens.
+        assertThat(firstGuardrail.spy()).isEqualTo(3);
+        assertThat(secondGuardrail.spy()).isEqualTo(3);
+        assertThat(secondGuardrail.lastAccess()).isLessThan(firstGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRetryRestartTheChainWithPassThroughAccumulator() {
+        aiService.failingFirstTwoWithPassThroughAccumulator("1", "foo").collect().asList().await().indefinitely();
+        // One reprompt adds a single extra pass for FirstGuardrail and FailingGuardrail on top of the three tokens.
+        assertThat(firstGuardrail.spy()).isEqualTo(4);
+        assertThat(secondGuardrail.spy()).isEqualTo(3);
+        assertThat(failingGuardrail.spy()).isEqualTo(4);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService { // each chain is declared once without and once with the pass-through accumulator
+
+        @OutputGuardrails({ FirstGuardrail.class, SecondGuardrail.class })
+        Multi<String> firstOneTwo(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ SecondGuardrail.class, FirstGuardrail.class })
+        Multi<String> twoAndFirst(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
+        Multi<String> failingFirstTwo(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstGuardrail.class, SecondGuardrail.class })
+        @OutputGuardrailAccumulator(PassThroughAccumulator.class)
+        Multi<String> firstOneTwoWithPassThroughAccumulator(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ SecondGuardrail.class, FirstGuardrail.class })
+        @OutputGuardrailAccumulator(PassThroughAccumulator.class)
+        Multi<String> twoAndFirstWithPassThroughAccumulator(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
+        @OutputGuardrailAccumulator(PassThroughAccumulator.class)
+        Multi<String> failingFirstTwoWithPassThroughAccumulator(@MemoryId String mem, @UserMessage String message);
+
+    }
+
+    @RequestScoped
+    public static class FirstGuardrail implements OutputGuardrail { // always succeeds; records call count and last call time
+
+        AtomicInteger spy = new AtomicInteger(0);
+        AtomicLong lastAccess = new AtomicLong();
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            spy.incrementAndGet();
+            lastAccess.set(System.nanoTime());
+            try {
+                Thread.sleep(1); // presumably keeps consecutive guardrail timestamps strictly apart
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+
+        public long lastAccess() {
+            return lastAccess.get();
+        }
+    }
+
+    @RequestScoped
+    public static class SecondGuardrail implements OutputGuardrail { // same behavior as FirstGuardrail, tracked separately
+
+        AtomicInteger spy = new AtomicInteger(0);
+        AtomicLong lastAccess = new AtomicLong(); // never reassigned, so the reference needs no volatile
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            spy.incrementAndGet();
+            lastAccess.set(System.nanoTime());
+            try {
+                Thread.sleep(1); // presumably keeps consecutive guardrail timestamps strictly apart
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+
+        public long lastAccess() {
+            return lastAccess.get();
+        }
+    }
+
+    @RequestScoped
+    public static class FailingGuardrail implements OutputGuardrail { // reprompts on its first call only, then succeeds
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            if (spy.incrementAndGet() == 1) {
+                return reprompt("Retry", "Retry");
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    public static class MyChatModelSupplier implements Supplier<StreamingChatModel> {
+
+        @Override
+        public StreamingChatModel get() {
+            return new MyStreamedChatModel();
+        }
+    }
+
+    public static class MyStreamedChatModel implements StreamingChatModel { // canned streaming model: three tokens, then completion
+
+        @Override
+        public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
+            handler.onPartialResponse("Hi!");
+            handler.onPartialResponse(" ");
+            handler.onPartialResponse("World!");
+            handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
+        }
+    }
+
+    @ApplicationScoped
+    public static class PassThroughAccumulator implements OutputTokenAccumulator { // returns the token stream unchanged
+
+        @Override
+        public Multi<String> accumulate(Multi<String> tokens) {
+            return tokens;
+        }
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // every memory id gets a NoopChatMemory
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainTest.java
new file mode 100644
index 000000000..fb4a978b5
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailChainTest.java
@@ -0,0 +1,276 @@
+package io.quarkiverse.langchain4j.test.guardrails;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import
java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatModel;
+import dev.langchain4j.model.chat.request.ChatRequest;
+import dev.langchain4j.model.chat.response.ChatResponse;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
+import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
+import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkus.test.QuarkusUnitTest;
+
+/** Checks ordering, reprompt restarts and output rewriting of chained output guardrails on blocking AI service methods.
+ * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
+ */
+@Deprecated(forRemoval = true)
+public class QuarkusOutputGuardrailChainTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+                    .addClasses(MyAiService.class,
+                            MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
+
+    @Inject
+    MyAiService aiService;
+
+    @Inject
+    FirstGuardrail firstGuardrail;
+    @Inject
+    SecondGuardrail secondGuardrail;
+
+    @Inject
+    FailingGuardrail failingGuardrail;
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailChainsAreInvoked() { // each guardrail runs once, in declaration order
+        aiService.firstOneTwo("1", "foo");
+        assertThat(firstGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatGuardrailOrderIsCorrect() { // reversed declaration order: SecondGuardrail is stamped first
+        aiService.twoAndFirst("1", "foo");
+        assertThat(firstGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(secondGuardrail.lastAccess()).isLessThan(firstGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRetryRestartTheChain() { // FailingGuardrail reprompts once, so the chain restarts from FirstGuardrail
+        aiService.failingFirstTwo("1", "foo");
+        assertThat(firstGuardrail.spy()).isEqualTo(2);
+        assertThat(secondGuardrail.spy()).isEqualTo(1);
+        assertThat(failingGuardrail.spy()).isEqualTo(2);
+        assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRewritesTheOutputTwiceInTheChain() { // "Hi!" from the model, then ",1" and ",2" appended in order
+        assertThat(aiService.rewritingSuccess("1", "foo")).isEqualTo("Hi!,1,2");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRepromptAfterRewriteIsNotAllowed() { // a reprompt following a rewriting guardrail must be rejected
+        assertThatExceptionOfType(GuardrailException.class)
+                .isThrownBy(() -> aiService.repromptAfterRewrite("1", "foo"))
+                .withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
+    }
+
+    @Test
+    @ActivateRequestContext
+    void testThatRewritesTheOutputWithAResult() { // the guardrail-supplied result object is returned as-is (same instance)
+        assertThat(aiService.rewritingSuccessWithResult("1", "foo")).isSameAs(RewritingGuardrailWithResult.RESULT);
+    }
+
+    @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+    public interface MyAiService { // one method per guardrail chain under test
+
+        @OutputGuardrails({ FirstGuardrail.class, SecondGuardrail.class })
+        String firstOneTwo(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ SecondGuardrail.class, FirstGuardrail.class })
+        String twoAndFirst(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
+        String failingFirstTwo(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstRewritingGuardrail.class, SecondRewritingGuardrail.class })
+        String rewritingSuccess(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
+        String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);
+
+        @OutputGuardrails({ FirstRewritingGuardrail.class, RewritingGuardrailWithResult.class })
+        Integer rewritingSuccessWithResult(@MemoryId String mem, @UserMessage String message);
+
+    }
+
+    @RequestScoped
+    public static class FirstGuardrail implements OutputGuardrail { // always succeeds; records call count and last call time
+
+        AtomicInteger spy = new AtomicInteger(0);
+        AtomicLong lastAccess = new AtomicLong();
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            spy.incrementAndGet();
+            lastAccess.set(System.nanoTime());
+            try {
+                Thread.sleep(1); // presumably keeps consecutive guardrail timestamps strictly apart
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+
+        public long lastAccess() {
+            return lastAccess.get();
+        }
+    }
+
+    @RequestScoped
+    public static class SecondGuardrail implements OutputGuardrail { // same behavior as FirstGuardrail, tracked separately
+
+        AtomicInteger spy = new AtomicInteger(0);
+        AtomicLong lastAccess = new AtomicLong(); // never reassigned, so the reference needs no volatile
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            spy.incrementAndGet();
+            lastAccess.set(System.nanoTime());
+            try {
+                Thread.sleep(1); // presumably keeps consecutive guardrail timestamps strictly apart
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt(); // restore the interrupt flag instead of swallowing it
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+
+        public long lastAccess() {
+            return lastAccess.get();
+        }
+    }
+
+    @RequestScoped
+    public static class FailingGuardrail implements OutputGuardrail { // reprompts on its first call only, then succeeds
+
+        AtomicInteger spy = new AtomicInteger(0);
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            if (spy.incrementAndGet() == 1) {
+                return reprompt("Retry", "Retry");
+            }
+            return success();
+        }
+
+        public int spy() {
+            return spy.get();
+        }
+    }
+
+    @RequestScoped
+    public static class FirstRewritingGuardrail implements OutputGuardrail { // rewrites the output by appending ",1"
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            String text = responseFromLLM.text();
+            return successWith(text + ",1");
+        }
+    }
+
+    @RequestScoped
+    public static class SecondRewritingGuardrail implements OutputGuardrail { // rewrites the output by appending ",2"
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            String text = responseFromLLM.text();
+            return successWith(text + ",2");
+        }
+    }
+
+    @RequestScoped
+    public static class RewritingGuardrailWithResult implements OutputGuardrail { // rewrites the text and supplies RESULT alongside it
+
+        static final Integer RESULT = 1_000;
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            String text = responseFromLLM.text();
+            return successWith(text + ",2", RESULT);
+        }
+    }
+
+    @RequestScoped
+    public static class RepromptingGuardrail implements OutputGuardrail { // reprompts on its first call only, then succeeds
+
+        private boolean firstCall = true;
+
+        @Override
+        public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+            if (firstCall) {
+                firstCall = false;
+                String text = responseFromLLM.text();
+                return reprompt("Wrong message", text + ", " + text);
+            }
+            return success();
+        }
+    }
+
+    public static class MyChatModelSupplier implements Supplier<ChatModel> {
+
+        @Override
+        public ChatModel get() {
+            return new MyChatModel();
+        }
+    }
+
+    public static class MyChatModel implements ChatModel { // canned blocking model: always answers "Hi!"
+
+        @Override
+        public ChatResponse doChat(ChatRequest request) {
+            return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
+        }
+    }
+
+    public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
+        @Override
+        public ChatMemoryProvider get() {
+            return new ChatMemoryProvider() {
+                @Override
+                public ChatMemory get(Object memoryId) { // every memory id gets a NoopChatMemory
+                    return new NoopChatMemory();
+                }
+            };
+        }
+    }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailNotFoundTest.java new file mode 100644 index 000000000..aeb7ef02e --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailNotFoundTest.java @@ -0,0 +1,99 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Fail.fail; + +import java.util.function.Supplier; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.enterprise.inject.spi.DeploymentException; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailNotFoundTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new 
QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)) + .assertException(t -> { + assertThat(t).isInstanceOf(DeploymentException.class); + assertThat(t).hasMessageContaining( + "io.quarkiverse.langchain4j.test.guardrails.QuarkusOutputGuardrailNotFoundTest$MissingGuardRail"); + }); + + @Test + @ActivateRequestContext + void testThatNotFoundGuardrailsAreReported() { + fail("Should not be called"); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(MissingGuardRail.class) + String hi(@MemoryId String mem); + + } + + public static class MissingGuardRail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + throw new RuntimeException("Should not be invoked"); + } + + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassAndMethodTest.java new file mode 100644 index 000000000..6b8ff16b1 --- /dev/null +++ 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassAndMethodTest.java @@ -0,0 +1,134 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailOnClassAndMethodTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyAiService 
aiService; + + @Inject + OKGuardrail okGuardrail; + + @Inject + KOGuardrail koGuardrail; + + @Test + @ActivateRequestContext + void testThatGuardrailsFromTheClassAreInvoked() { + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hi("1"); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hi("2"); + assertThat(okGuardrail.spy()).isEqualTo(2); + + assertThat(koGuardrail.spy()).isEqualTo(0); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + @OutputGuardrails(KOGuardrail.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + String hi(@MemoryId String mem); + + } + + @ApplicationScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return failure("KO"); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git 
a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassTest.java new file mode 100644 index 000000000..d8977ad80 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnClassTest.java @@ -0,0 +1,112 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailOnClassTest { + + @RegisterExtension + static final QuarkusUnitTest 
unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyAiService aiService; + + @Inject + OKGuardrail okGuardrail; + + @Test + @ActivateRequestContext + void testThatGuardrailsFromTheClassAreInvoked() { + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hi("1"); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hi("2"); + assertThat(okGuardrail.spy()).isEqualTo(2); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + @OutputGuardrails(OKGuardrail.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + String hi(@MemoryId String mem); + + } + + @ApplicationScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseTest.java 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseTest.java new file mode 100644 index 000000000..a97792469 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseTest.java @@ -0,0 +1,237 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.arc.Arc; +import 
io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailOnStreamedResponseTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, + MyChatModelSupplier.class, MyMemoryProviderSupplier.class, ValidationException.class)); + + @Inject + MyAiService aiService; + + @Inject + OKGuardrail okGuardrail; + @Inject + KOGuardrail koGuardrail; + + @Test + void testThatOutputGuardrailsAreInvoked() { + assertThat(Arc.container().requestContext().isActive()).isFalse(); + Arc.container().requestContext().activate(); + try { + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hiUsingDefaultAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hiUsingDefaultAccumulator("2").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(2); + } finally { + Arc.container().requestContext().deactivate(); + } + + Arc.container().requestContext().activate(); + try { + // New request scope - the value should be back to 0 + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hiUsingDefaultAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hiUsingDefaultAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(2); + } finally { + Arc.container().requestContext().deactivate(); + } + + assertThat(Arc.container().requestContext().isActive()).isFalse(); + + Arc.container().requestContext().activate(); + try { + assertThat(okGuardrail.spy()).isEqualTo(0); + 
aiService.hiUsingPassThroughAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(3); // 3 chunks + aiService.hiUsingPassThroughAccumulator("2").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(6); // 3+3 chunks + } finally { + Arc.container().requestContext().deactivate(); + } + + Arc.container().requestContext().activate(); + try { + // New request scope - the value should be back to 0 + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hiUsingPassThroughAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(3); + aiService.hiUsingPassThroughAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(6); + } finally { + Arc.container().requestContext().deactivate(); + } + } + + @Test + @ActivateRequestContext + void testThatGuardrailCanThrowValidationException() { + assertThat(koGuardrail.spy()).isEqualTo(0); + assertThatThrownBy(() -> aiService.koUsingDefaultAccumulator("1").collect().asList().await().indefinitely()) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(1); + assertThatThrownBy(() -> aiService.koUsingDefaultAccumulator("1").collect().asList().await().indefinitely()) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testThatGuardrailCanThrowValidationExceptionWhenUsingPassThroughAccumulator() { + assertThat(koGuardrail.spy()).isEqualTo(0); + assertThatThrownBy(() -> aiService.koUsingPassThroughAccumulator("1").collect().asList().await().indefinitely()) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(2); // First chunk is ok, the second one fails. 
+ assertThatThrownBy(() -> aiService.koUsingPassThroughAccumulator("1").collect().asList().await().indefinitely()) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(4); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + Multi hiUsingDefaultAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + Multi koUsingDefaultAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi hiUsingPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi koUsingPassThroughAccumulator(@MemoryId String mem); + + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + if (responseFromLLM.text().length() > 3) { // Accumulated response. 
+ return failure("KO", new ValidationException("KO")); + } else { // Chunk, do not fail on the first chunk + if (responseFromLLM.text().contains("Hi!")) { + return success(); + } else { + return failure("KO", new ValidationException("KO")); + } + } + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatModel get() { + return new MyStreamedChatModel(); + } + } + + public static class MyStreamedChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + handler.onPartialResponse("Hi!"); + handler.onPartialResponse(" "); + handler.onPartialResponse("World!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + @ApplicationScoped + public static class PassThroughAccumulator implements OutputTokenAccumulator { + + @Override + public Multi accumulate(Multi tokens) { + return tokens; + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseValidationTest.java new file mode 100644 index 000000000..edbd96bc3 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailOnStreamedResponseValidationTest.java @@ -0,0 +1,351 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import 
java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailOnStreamedResponseValidationTest.OKGuardrail; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailOnStreamedResponseValidationTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + 
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyStreamedChatModel.class)); + + @Inject + MyAiService aiService; + + @Inject + OKGuardrail okGuardrail; + + @Test + @ActivateRequestContext + void testOk() { + aiService.ok("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testOkWithPassThroughAccumulator() { + aiService.okWithPassThroughAccumulator("1").collect().asList().await().indefinitely(); + assertThat(okGuardrail.spy()).isEqualTo(3); + } + + @Test + @ActivateRequestContext + void testKO() { + assertThatThrownBy(() -> aiService.ko("2").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("KO"); + } + + @Test + @ActivateRequestContext + void testKOWithPassThroughAccumulator() { + assertThatThrownBy(() -> aiService.koWithPassThroughAccumulator("2").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("KO"); + } + + @Inject + RetryingGuardrail retry; + + @Test + @ActivateRequestContext + void testRetryOk() { + aiService.retry("3").collect().asList().await().indefinitely(); + assertThat(retry.spy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testRetryOkWithPassThroughAccumulator() { + aiService.retryWithPassThroughAccumulator("3").collect().asList().await().indefinitely(); + assertThat(retry.spy()).isEqualTo(4); // "Hi!", "Hi!" (retry), " ", "World!" 
+ } + + @Inject + RetryingButFailGuardrail retryFail; + + @Test + @ActivateRequestContext + void testRetryFail() { + assertThatThrownBy(() -> aiService.retryButFail("4").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("maximum number of retries"); + assertThat(retryFail.spy()).isEqualTo(4); + } + + @Inject + PassThroughAccumulator accumulator; + + @Test + @ActivateRequestContext + void testRetryFailWithPassThroughAccumulator() { + assertThatThrownBy( + () -> aiService.retryButFailWithPassThroughAccumulator("4").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("maximum number of retries"); + assertThat(retryFail.spy()).isEqualTo(4); + assertThat(accumulator.spy()).isEqualTo(4); + } + + @Inject + KOFatalGuardrail fatal; + + @Test + @ActivateRequestContext + void testFatalException() { + assertThatThrownBy(() -> aiService.fatal("5").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("Fatal"); + assertThat(fatal.spy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testFatalExceptionWithPassThroughAccumulator() { + assertThatThrownBy(() -> aiService.fatalWithPassThroughAccumulator("5").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("Fatal"); + assertThat(fatal.spy()).isEqualTo(1); + } + + @Test + @ActivateRequestContext + void testRewritingWhileStreamingIsNotAllowed() { + assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed"); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + 
@OutputGuardrails(OKGuardrail.class) + Multi ok(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + Multi ko(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingGuardrail.class) + Multi retry(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingButFailGuardrail.class) + Multi retryButFail(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOFatalGuardrail.class) + Multi fatal(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi okWithPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi koWithPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi retryWithPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingButFailGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi retryButFailWithPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOFatalGuardrail.class) + @OutputGuardrailAccumulator(PassThroughAccumulator.class) + Multi fatalWithPassThroughAccumulator(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails({ RewritingGuardrail.class }) + Multi rewriting(@MemoryId String mem); + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOGuardrail implements 
OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return failure("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class RetryingGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + int v = spy.incrementAndGet(); + if (v >= 2) { + return OutputGuardrailResult.success(); + } + return retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class RetryingButFailGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class KOFatalGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + throw new IllegalArgumentException("Fatal"); + } + + public int spy() { + return spy.get(); + } + } + + @RequestScoped + public static class RewritingGuardrail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String text = responseFromLLM.text(); + return successWith(text + ",1"); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatModel get() { + return new MyStreamedChatModel(); + } + } + + public static class MyStreamedChatModel implements StreamingChatModel { + + @Override + public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { + handler.onPartialResponse("Hi!"); + handler.onPartialResponse(" "); + 
handler.onPartialResponse("World!"); + handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build()); + } + } + + @RequestScoped + public static class PassThroughAccumulator implements OutputTokenAccumulator { + + AtomicInteger spy = new AtomicInteger(); + + public int spy() { + return spy.get(); + } + + @Override + public Multi accumulate(Multi tokens) { + return tokens + .onSubscription().invoke(() -> spy.incrementAndGet()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailPromptTemplateTest.java new file mode 100644 index 000000000..d39b661f2 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailPromptTemplateTest.java @@ -0,0 +1,231 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; 
+import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailPromptTemplateTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + @Inject + MyAiService aiService; + + @Inject + GuardrailValidation guardrailValidation; + + @Test + @ActivateRequestContext + void shouldWorkNoParameters() { + aiService.getJoke(); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke"); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryId() { + aiService.getAnotherJoke("memory-id-001"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke"); + assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of( + "memoryId", "memory-id-001", + "it", "memory-id-001")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndOneParameter() { + aiService.sayHiToMyFriendNoMemory("Rambo"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + 
assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Rambo", + "it", "Rambo")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneParameter() { + aiService.sayHiToMyFriend("1", "Chuck Norris"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Chuck Norris", + "mem", "1")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndThreeParameters() { + aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"); + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topic1", "Chuck Norris", + "topic2", "Jean-Claude Van Damme", + "topic3", "Silvester Stallone")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndList() { + aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"))); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndList() { + aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics}! 
This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneItemFromList() { + aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoUserMessage() { + // UserMessage annotation is not provided, then no user message template should be available + aiService.saySomething("Is this a parameter or a prompt?"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEmpty(); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @OutputGuardrails(GuardrailValidation.class) + @UserMessage("Tell me a joke") + String getJoke(); + + @UserMessage("Tell me another joke") + @OutputGuardrails(GuardrailValidation.class) + String getAnotherJoke(@MemoryId String memoryId); + + @UserMessage("Say hi to my friend {friend}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriendNoMemory(String friend); + + @UserMessage("Say hi to my friend {friend}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String mem, String friend); + + @UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!") + 
@OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(String topic1, String topic2, String topic3); + + @UserMessage("Tell me something about {topics}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(List topics); + + @UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(@MemoryId String memoryId, List topics); + + @UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String memoryId, List topics); + + @OutputGuardrails(GuardrailValidation.class) + String saySomething(String isThisAPromptOrAParameter); + + } + + @RequestScoped + public static class GuardrailValidation implements OutputGuardrail { + + OutputGuardrailParams params; + + public OutputGuardrailResult validate(OutputGuardrailParams params) { + this.params = params; + return success(); + } + + public String spyUserMessageTemplate() { + return params.userMessageTemplate(); + } + + public Map spyVariables() { + return params.variables(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return memoryId -> new NoopChatMemory(); + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingRetryDisabledTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingRetryDisabledTest.java new file mode 100644 index 
000000000..35fc26eec --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingRetryDisabledTest.java @@ -0,0 +1,167 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.SystemMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailRepromptingRetryDisabledTest { + + @RegisterExtension + 
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.guardrails.max-retries", "0"); + + @Inject + MyAiService aiService; + + @Test + @ActivateRequestContext + void testOk() { + aiService.ok("1", "foo"); + } + + @Inject + RetryGuardrail retryGuardrail; + + @Test + @ActivateRequestContext + void testRetryFailing() { + assertThatThrownBy(() -> aiService.retry("1", "foo")) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("maximum number of retries"); + assertThat(retryGuardrail.getSpy()).isEqualTo(1); // No retry + } + + @Inject + RepromptingGuardrail repromptingGuardrail; + + @Test + @ActivateRequestContext + void testRepromptingFailing() { + assertThatThrownBy(() -> aiService.reprompting("1", "foo")) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("maximum number of retries"); + assertThat(repromptingGuardrail.getSpy()).isEqualTo(1); // No retry + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + @SystemMessage("Say Hi!") + public interface MyAiService { + + @OutputGuardrails(OkGuardrail.class) + String ok(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + + @OutputGuardrails(RetryGuardrail.class) + String retry(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + + @OutputGuardrails(RepromptingGuardrail.class) + String reprompting(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + @ApplicationScoped + public static class OkGuardrail implements OutputGuardrail { + + @Override + public 
OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + + } + + @ApplicationScoped + public static class RetryGuardrail implements OutputGuardrail { + + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + int v = spy.incrementAndGet(); + return retry("Retry"); + } + + public int getSpy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class RepromptingGuardrail implements OutputGuardrail { + + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + int v = spy.incrementAndGet(); + return reprompt("Retry", "reprompt"); + } + + public int getSpy() { + return spy.get(); + } + } + + public static class MyChatModel implements ChatModel { + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hello")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + private final Map memories = new HashMap<>(); + + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return memories.computeIfAbsent(memoryId.toString(), k -> MessageWindowChatMemory.withMaxMessages(10)); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingTest.java new file mode 100644 index 000000000..121d91d9d --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailRepromptingTest.java @@ -0,0 +1,228 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText; 
+import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.SystemMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailRepromptingTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + 
.addClasses(MyAiService.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyAiService aiService; + + @Inject + RepromptingOne repromptingOne; + @Inject + RepromptingTwo repromptingTwo; + @Inject + RepromptingFailed repromptingFailed; + + @Test + @ActivateRequestContext + void testRepromptingOkAfterOneRetry() { + aiService.one("1", "foo"); + assertThat(repromptingOne.getSpy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testRepromptingOkAfterTwoRetries() { + aiService.two("2", "foo"); + assertThat(repromptingTwo.getSpy()).isEqualTo(3); + } + + @Test + @ActivateRequestContext + void testRepromptingFailing() { + assertThatThrownBy(() -> aiService.fail("3", "foo")) + .isInstanceOf(GuardrailException.class); + assertThat(repromptingFailed.getSpy()).isEqualTo(3); + + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + @SystemMessage("Say Hi!") + public interface MyAiService { + + @OutputGuardrails(RepromptingOne.class) + String one(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + + @OutputGuardrails(RepromptingTwo.class) + String two(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + + @OutputGuardrails(RepromptingFailed.class) + String fail(@MemoryId String mem, @dev.langchain4j.service.UserMessage String message); + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + @ApplicationScoped + public static class RepromptingOne implements OutputGuardrail { + + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + if (spy.incrementAndGet() == 1) { + return reprompt("Retry", "Retry"); + } + return success(); + } + + public int getSpy() { + return spy.get(); + } + } + + 
@ApplicationScoped + public static class RepromptingTwo implements OutputGuardrail { + + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + int v = spy.incrementAndGet(); + List messages = params.memory().messages(); + if (v == 1) { + ChatMessage last = messages.get(messages.size() - 1); + assertThat(last).isInstanceOf(AiMessage.class); + assertThat(((AiMessage) last).text()).isEqualTo("Nope"); + assertThat(params.responseFromLLM().text()).isEqualTo("Nope"); + return reprompt("Retry", "Retry"); + } + if (v == 2) { + // Check that it's in memory + ChatMessage last = messages.get(messages.size() - 1); + ChatMessage beforeLast = messages.get(messages.size() - 2); + + assertThat(last).isInstanceOf(AiMessage.class); + assertThat(((AiMessage) last).text()).isEqualTo("Hello"); + assertThat(params.responseFromLLM().text()).isEqualTo("Hello"); + assertThat(beforeLast).isInstanceOf(UserMessage.class); + assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry"); + + return reprompt("Retry", "Retry"); + } + if (v != 3) { + throw new IllegalArgumentException("Unexpected call"); + } + return success(); + } + + public int getSpy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class RepromptingFailed implements OutputGuardrail { + + private final AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + int v = spy.incrementAndGet(); + List messages = params.memory().messages(); + if (v == 1) { + ChatMessage last = messages.get(messages.size() - 1); + assertThat(last).isInstanceOf(AiMessage.class); + assertThat(((AiMessage) last).text()).isEqualTo("Nope"); + return reprompt("Retry", "Retry Once"); + } + if (v == 2) { + // Check that it's in memory + ChatMessage last = messages.get(messages.size() - 1); + ChatMessage beforeLast = messages.get(messages.size() - 2); + + 
assertThat(last).isInstanceOf(AiMessage.class); + assertThat(((AiMessage) last).text()).isEqualTo("Hello"); + assertThat(beforeLast).isInstanceOf(UserMessage.class); + assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry Once"); + return reprompt("Retry", "Retry Twice"); + } + return reprompt("Retry", "Retry Again"); + } + + public int getSpy() { + return spy.get(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + ChatMessage last = request.messages().get(request.messages().size() - 1); + if (last instanceof UserMessage && chatMessageToText(last).equals("foo")) { + return ChatResponse.builder().aiMessage(new AiMessage("Nope")).build(); + } + if (last instanceof UserMessage && chatMessageToText(last).contains("Retry")) { + return ChatResponse.builder().aiMessage(new AiMessage("Hello")).build(); + } + throw new IllegalArgumentException("Unexpected message: " + request.messages()); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + private final Map memories = new HashMap<>(); + + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return memories.computeIfAbsent(memoryId.toString(), k -> MessageWindowChatMemory.withMaxMessages(10)); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailTest.java new file mode 100644 index 000000000..d15503b56 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailTest.java @@ -0,0 +1,159 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import 
java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, MyChatModel.class, + MyChatModelSupplier.class, MyMemoryProviderSupplier.class, ValidationException.class)); + + @Inject + MyAiService aiService; + + @Inject + OKGuardrail okGuardrail; + @Inject + KOGuardrail koGuardrail; + + @Test + void testThatOutputGuardrailsAreInvoked() { + 
assertThat(Arc.container().requestContext().isActive()).isFalse(); + Arc.container().requestContext().activate(); + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hi("1"); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hi("2"); + assertThat(okGuardrail.spy()).isEqualTo(2); + Arc.container().requestContext().deactivate(); + + Arc.container().requestContext().activate(); + // New request scope - the value should be back to 0 + assertThat(okGuardrail.spy()).isEqualTo(0); + aiService.hi("1"); + assertThat(okGuardrail.spy()).isEqualTo(1); + aiService.hi("2"); + assertThat(okGuardrail.spy()).isEqualTo(2); + } + + @Test + @ActivateRequestContext + void testThatGuardrailCanThrowValidationException() { + assertThat(koGuardrail.spy()).isEqualTo(0); + assertThatThrownBy(() -> aiService.ko("1")) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(1); + assertThatThrownBy(() -> aiService.ko("1")) + .hasCauseExactlyInstanceOf(ValidationException.class); + assertThat(koGuardrail.spy()).isEqualTo(2); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + String hi(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + String ko(@MemoryId String mem); + + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + 
spy.incrementAndGet(); + return failure("KO", new ValidationException("KO")); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailValidationTest.java new file mode 100644 index 000000000..a119b90fc --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/QuarkusOutputGuardrailValidationTest.java @@ -0,0 +1,232 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.memory.ChatMemory; +import 
dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusOutputGuardrailValidationTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyAiService aiService; + + @Test + @ActivateRequestContext + void testOk() { + aiService.ok("1"); + } + + @Test + @ActivateRequestContext + void testKo() { + assertThatThrownBy(() -> aiService.ko("2")) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("KO"); + } + + @Inject + RetryingGuardrail retry; + + @Test + @ActivateRequestContext + void testRetryOk() { + aiService.retry("3"); + assertThat(retry.spy()).isEqualTo(2); + } + + @Inject + RetryingButFailGuardrail retryFail; + + @Test + @ActivateRequestContext + void testRetryFail() { + assertThatThrownBy(() -> aiService.retryButFail("4")) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("maximum number of retries"); + 
assertThat(retryFail.spy()).isEqualTo(3); + } + + @Inject + KOFatalGuardrail fatal; + + @Test + @ActivateRequestContext + void testFatalException() { + assertThatThrownBy(() -> aiService.fatal("5")) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("Fatal"); + assertThat(fatal.spy()).isEqualTo(1); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hi!") + @OutputGuardrails(OKGuardrail.class) + String ok(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOGuardrail.class) + String ko(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingGuardrail.class) + String retry(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(RetryingButFailGuardrail.class) + String retryButFail(@MemoryId String mem); + + @UserMessage("Say Hi!") + @OutputGuardrails(KOFatalGuardrail.class) + String fatal(@MemoryId String mem); + } + + @RequestScoped + public static class OKGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return success(); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + return failure("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class RetryingGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + int v = spy.incrementAndGet(); + if (v == 2) { + return OutputGuardrailResult.success(); + } + return 
retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class RetryingButFailGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + int v = spy.incrementAndGet(); + return retry("KO"); + } + + public int spy() { + return spy.get(); + } + } + + @ApplicationScoped + public static class KOFatalGuardrail implements OutputGuardrail { + + AtomicInteger spy = new AtomicInteger(0); + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + spy.incrementAndGet(); + throw new IllegalArgumentException("Fatal"); + } + + public int spy() { + return spy.get(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatModel { + + @Override + public ChatResponse doChat(ChatRequest request) { + return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build(); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/TokenStreamExecutor.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/TokenStreamExecutor.java new file mode 100644 index 000000000..8619274ce --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/TokenStreamExecutor.java @@ -0,0 +1,31 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import 
dev.langchain4j.service.TokenStream; + +public abstract class TokenStreamExecutor { + protected String execute(Supplier aiServiceInvocation) throws InterruptedException { + var latch = new CountDownLatch(1); + var values = new ArrayList(); + + aiServiceInvocation + .get() + .onError(t -> { + throw new RuntimeException(t); + }) + .onPartialResponse(values::add) + .onCompleteResponse(response -> { + values.add(response.aiMessage().text()); + latch.countDown(); + }) + .start(); + + latch.await(10, TimeUnit.SECONDS); + + return String.join(" ", values).strip(); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/ImperativeResponseAugmenterWithOutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/ImperativeResponseAugmenterWithOutputGuardrailTest.java index 72715a650..d203d846e 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/ImperativeResponseAugmenterWithOutputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/ImperativeResponseAugmenterWithOutputGuardrailTest.java @@ -12,11 +12,11 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.response.ResponseAugmenter; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusImperativeResponseAugmenterWithOutputGuardrailTest.java 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusImperativeResponseAugmenterWithOutputGuardrailTest.java new file mode 100644 index 000000000..a2e4e3fda --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusImperativeResponseAugmenterWithOutputGuardrailTest.java @@ -0,0 +1,64 @@ +package io.quarkiverse.langchain4j.test.response; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.response.ResponseAugmenter; +import io.quarkus.test.QuarkusUnitTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusImperativeResponseAugmenterWithOutputGuardrailTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(ResponseAugmenterTestUtils.FakeChatModelSupplier.class, + ResponseAugmenterTestUtils.FakeChatModel.class, + ResponseAugmenterTestUtils.UppercaseAugmenter.class, + ResponseAugmenterTestUtils.AppenderAugmenter.class)); + + @Inject + MyAiService ai; + + @Test + @ActivateRequestContext + void test() { + 
assertThat(ai.hi()).isEqualTo("BONJOUR!"); + } + + @RegisterAiService(chatLanguageModelSupplier = ResponseAugmenterTestUtils.FakeChatModelSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hello World!") + @ResponseAugmenter(ResponseAugmenterTestUtils.UppercaseAugmenter.class) + @OutputGuardrails(ChangingOutputGuardRail.class) + String hi(); + + } + + @ApplicationScoped + public static class ChangingOutputGuardRail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return successWith("Bonjour!"); + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusStreamingResponseAugmenterWithOutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusStreamingResponseAugmenterWithOutputGuardrailTest.java new file mode 100644 index 000000000..4938064ff --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/QuarkusStreamingResponseAugmenterWithOutputGuardrailTest.java @@ -0,0 +1,109 @@ +package io.quarkiverse.langchain4j.test.response; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import 
io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; +import io.quarkiverse.langchain4j.response.ResponseAugmenter; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +public class QuarkusStreamingResponseAugmenterWithOutputGuardrailTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(ResponseAugmenterTestUtils.FakeStreamedChatModel.class, + ResponseAugmenterTestUtils.FakeStreamedChatModelSupplier.class, + ResponseAugmenterTestUtils.StreamedUppercaseAugmenter.class, + ResponseAugmenterTestUtils.AppenderAugmenter.class)); + + @Inject + MyAiService ai; + + @Inject + OkGuardrail gr; + + @Test + @ActivateRequestContext + void test() { + List list = ai.hi().collect().asList() + .await().indefinitely(); + assertThat(list).containsExactly("HI! ", "WORLD!"); + assertThat(gr.isCalled()).isTrue(); + } + + @Test + @ActivateRequestContext + void testWithAnAugmenterHandlingBothStreamingAndImperative() { + List list = ai.hiAppend().collect().asList() + .await().indefinitely(); + assertThat(list).containsExactly("Hi! 
", "World!", " WONDERFUL!"); + assertThat(gr.isCalled()).isTrue(); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = ResponseAugmenterTestUtils.FakeStreamedChatModelSupplier.class) + public interface MyAiService { + + @UserMessage("Say Hello World!") + @ResponseAugmenter(ResponseAugmenterTestUtils.StreamedUppercaseAugmenter.class) + @OutputGuardrailAccumulator(MyAccumulator.class) + @OutputGuardrails(OkGuardrail.class) + Multi hi(); + + @UserMessage("Say Hello World!") + @ResponseAugmenter(ResponseAugmenterTestUtils.AppenderAugmenter.class) + @OutputGuardrailAccumulator(MyAccumulator.class) + @OutputGuardrails(OkGuardrail.class) + Multi hiAppend(); + + } + + @ApplicationScoped + public static class MyAccumulator implements OutputTokenAccumulator { + + @Override + public Multi accumulate(Multi tokens) { + return tokens.group().intoLists().of(2) + .map(l -> String.join(" ", l)); + } + } + + @RequestScoped + public static class OkGuardrail implements OutputGuardrail { + + volatile boolean called = false; + + public boolean isCalled() { + return called; + } + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + called = true; + return success(); + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/StreamingResponseAugmenterWithOutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/StreamingResponseAugmenterWithOutputGuardrailTest.java index 602200a14..95bb30a11 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/StreamingResponseAugmenterWithOutputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/StreamingResponseAugmenterWithOutputGuardrailTest.java @@ -15,12 +15,12 @@ import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.guardrail.OutputGuardrail; +import 
dev.langchain4j.guardrail.OutputGuardrailResult; import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkiverse.langchain4j.response.ResponseAugmenter; import io.quarkus.test.QuarkusUnitTest; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassInstanceFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassInstanceFactory.java new file mode 100644 index 000000000..13cc8c7a0 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassInstanceFactory.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j; + +import jakarta.enterprise.inject.spi.CDI; + +import dev.langchain4j.spi.classloading.ClassInstanceFactory; + +public class QuarkusClassInstanceFactory implements ClassInstanceFactory { + @Override + public T getInstanceOfClass(Class clazz) { + return CDI.current().select(clazz).get(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassMetadataProviderFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassMetadataProviderFactory.java new file mode 100644 index 000000000..706d9ee62 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusClassMetadataProviderFactory.java @@ -0,0 +1,85 @@ +package io.quarkiverse.langchain4j; + +import java.lang.annotation.Annotation; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +import dev.langchain4j.service.guardrail.InputGuardrails; +import 
dev.langchain4j.service.guardrail.OutputGuardrails; +import dev.langchain4j.spi.classloading.ClassMetadataProviderFactory; +import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; +import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; +import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; + +public class QuarkusClassMetadataProviderFactory implements ClassMetadataProviderFactory { + @Override + public Optional getAnnotation(AiServiceMethodCreateInfo method, Class annotationClass) { + return GuardrailType.fromAnnotationClass(annotationClass) + .getAnnotation(method); + } + + @Override + public Optional getAnnotation(Class clazz, Class annotationClass) { + return GuardrailType.fromAnnotationClass(annotationClass) + .getAnnotation(clazz); + } + + @Override + public Iterable getNonStaticMethodsOnClass(Class aiServiceClass) { + return getClassMetadata(aiServiceClass) + .map(AiServiceClassCreateInfo::methodMap) + .map(Map::values) + .orElseGet(Collections::emptyList); + } + + private static Optional getClassMetadata(Class aiServiceClass) { + return Optional.ofNullable(AiServicesRecorder.getMetadata().get(aiServiceClass.getName())); + } + + private enum GuardrailType { + INPUT(InputGuardrails.class) { + @Override + protected Optional getAnnotation(AiServiceMethodCreateInfo method) { + return Optional.ofNullable((T) method.getInputGuardrails()); + } + + @Override + protected Optional getAnnotation(Class clazz) { + return getClassMetadata(clazz) + .map(classCreateInfo -> (T) classCreateInfo.inputGuardrails()); + } + }, + OUTPUT(OutputGuardrails.class) { + @Override + protected Optional getAnnotation(AiServiceMethodCreateInfo method) { + return Optional.ofNullable((T) method.getOutputGuardrails()); + } + + @Override + protected Optional getAnnotation(Class clazz) { + return getClassMetadata(clazz) + .map(classCreateInfo -> (T) classCreateInfo.outputGuardrails()); + } + }; + + private final Class annotationClass; + + 
GuardrailType(Class annotationClass) { + this.annotationClass = annotationClass; + } + + protected abstract Optional getAnnotation(AiServiceMethodCreateInfo method); + + protected abstract Optional getAnnotation(Class clazz); + + public static GuardrailType fromAnnotationClass(Class annotationClass) { + return Arrays.stream(values()) + .filter(type -> type.annotationClass.equals(annotationClass)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException( + "Unsupported guardrail annotation: %s".formatted(annotationClass.getName()))); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/GuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/GuardrailExecutedEvent.java index c09330c85..86ecdc98a 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/GuardrailExecutedEvent.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/GuardrailExecutedEvent.java @@ -11,7 +11,9 @@ * * @param

the type of guardrail parameters used in the validation process * @param the type of guardrail result produced by the validation process + * @deprecated This will be replaced with an alternate version when the upstream guardrail implementation is merged */ +@Deprecated(forRemoval = true) public interface GuardrailExecutedEvent

, G extends Guardrail> extends LLMInteractionEvent { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/InputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/InputGuardrailExecutedEvent.java index a8a552363..4d37cdaaf 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/InputGuardrailExecutedEvent.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/InputGuardrailExecutedEvent.java @@ -5,6 +5,10 @@ import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +/** + * @deprecated This will be replaced with an alternate version when the upstream guardrail implementation is merged + */ +@Deprecated(forRemoval = true) public interface InputGuardrailExecutedEvent extends GuardrailExecutedEvent { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/OutputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/OutputGuardrailExecutedEvent.java index ef15d013a..13d71e49b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/OutputGuardrailExecutedEvent.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/OutputGuardrailExecutedEvent.java @@ -4,6 +4,10 @@ import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +/** + * @deprecated This will be replaced with an alternate version when the upstream guardrail implementation is merged + */ +@Deprecated(forRemoval = true) public interface OutputGuardrailExecutedEvent extends GuardrailExecutedEvent { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/GuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/GuardrailExecutedEvent.java new file mode 100644 index 000000000..be0ff2601 --- /dev/null +++ 
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/GuardrailExecutedEvent.java @@ -0,0 +1,42 @@ +package io.quarkiverse.langchain4j.audit.guardrails; + +import dev.langchain4j.guardrail.Guardrail; +import dev.langchain4j.guardrail.GuardrailRequest; +import dev.langchain4j.guardrail.GuardrailResult; +import io.quarkiverse.langchain4j.audit.LLMInteractionEvent; + +/** + * Represents an event that is executed when a guardrail validation occurs. + * This interface serves as a marker for events that contain both parameters + * and results associated with guardrail validation. + * + * @param

the type of guardrail parameters used in the validation process + * @param the type of guardrail result produced by the validation process + */ +public interface GuardrailExecutedEvent

, G extends Guardrail> + extends LLMInteractionEvent { + /** + * Retrieves the request used for input guardrail validation. + * + * @return the parameters containing user message, memory, augmentation result, user message template, + * and associated variables for input guardrail validation. + */ + P request(); + + /** + * Retrieves the result of the input guardrail validation process. + * + * @return the result of the input guardrail validation, including the validation outcome + * and any associated failures, if present. + */ + R result(); + + /** + * Retrieves the guardrail class associated with the validation process. + * + * @return the guardrail class that implements the logic for validating + * the interaction between user and LLM, represented as an instance + * of the type extending {@code Guardrail}. + */ + Class guardrailClass(); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/InputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/InputGuardrailExecutedEvent.java new file mode 100644 index 000000000..bb5e09777 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/InputGuardrailExecutedEvent.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.audit.guardrails; + +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; + +public interface InputGuardrailExecutedEvent + extends GuardrailExecutedEvent { + /** + * Retrieves a rewritten user message if a successful rewritten result exists. + * If the result contains a rewritten message, it constructs a new user message + * with the rewritten text; otherwise, it returns the original user message. + * + * @return The rewritten user message if a rewritten result exists; otherwise, the original user message. 
+ */ + UserMessage rewrittenUserMessage(); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/OutputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/OutputGuardrailExecutedEvent.java new file mode 100644 index 000000000..68b7f4f83 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/OutputGuardrailExecutedEvent.java @@ -0,0 +1,9 @@ +package io.quarkiverse.langchain4j.audit.guardrails; + +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; + +public interface OutputGuardrailExecutedEvent + extends GuardrailExecutedEvent { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultInputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultInputGuardrailExecutedEvent.java new file mode 100644 index 000000000..32021a3fb --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultInputGuardrailExecutedEvent.java @@ -0,0 +1,17 @@ +package io.quarkiverse.langchain4j.audit.guardrails.internal; + +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import io.quarkiverse.langchain4j.audit.AuditSourceInfo; +import io.quarkiverse.langchain4j.audit.guardrails.InputGuardrailExecutedEvent; + +public record DefaultInputGuardrailExecutedEvent(AuditSourceInfo sourceInfo, InputGuardrailRequest request, + InputGuardrailResult result, Class guardrailClass) implements InputGuardrailExecutedEvent { + + @Override + public UserMessage rewrittenUserMessage() { + return result.hasRewrittenResult() ? 
request().withText(result.successfulText()).userMessage() : request.userMessage(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultOutputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultOutputGuardrailExecutedEvent.java new file mode 100644 index 000000000..874086128 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/guardrails/internal/DefaultOutputGuardrailExecutedEvent.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j.audit.guardrails.internal; + +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import io.quarkiverse.langchain4j.audit.AuditSourceInfo; +import io.quarkiverse.langchain4j.audit.guardrails.OutputGuardrailExecutedEvent; + +public record DefaultOutputGuardrailExecutedEvent(AuditSourceInfo sourceInfo, OutputGuardrailRequest request, + OutputGuardrailResult result, Class guardrailClass) implements OutputGuardrailExecutedEvent { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultInputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultInputGuardrailExecutedEvent.java index 4b9cfaad5..04192f7e6 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultInputGuardrailExecutedEvent.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultInputGuardrailExecutedEvent.java @@ -7,6 +7,10 @@ import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +/** + * @deprecated This will be replaced with an alternate version when the upstream guardrail implementation is merged + */ +@Deprecated(forRemoval = true) public record 
DefaultInputGuardrailExecutedEvent(AuditSourceInfo sourceInfo, InputGuardrailParams params, InputGuardrailResult result, Class guardrailClass) implements InputGuardrailExecutedEvent { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultOutputGuardrailExecutedEvent.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultOutputGuardrailExecutedEvent.java index 151c55244..f81b2040d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultOutputGuardrailExecutedEvent.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/internal/DefaultOutputGuardrailExecutedEvent.java @@ -6,6 +6,10 @@ import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +/** + * @deprecated This will be replaced with an alternate version when the upstream guardrail implementation is merged + */ +@Deprecated(forRemoval = true) public record DefaultOutputGuardrailExecutedEvent(AuditSourceInfo sourceInfo, OutputGuardrailParams params, OutputGuardrailResult result, Class guardrailClass) implements OutputGuardrailExecutedEvent { } \ No newline at end of file diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java index ff25c40ab..f1a297597 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java @@ -8,6 +8,10 @@ import dev.langchain4j.data.message.AiMessage; +/** + * @deprecated Use {@link dev.langchain4j.guardrail.JsonExtractorOutputGuardrail} instead + */ +@Deprecated(forRemoval = true) public abstract class AbstractJsonExtractorOutputGuardrail implements 
OutputGuardrail { @Inject diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ClassProvidingAnnotationLiteral.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ClassProvidingAnnotationLiteral.java new file mode 100644 index 000000000..31cf196b9 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/ClassProvidingAnnotationLiteral.java @@ -0,0 +1,77 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.List; + +import jakarta.enterprise.util.AnnotationLiteral; + +import dev.langchain4j.guardrail.Guardrail; + +public abstract sealed class ClassProvidingAnnotationLiteral + extends AnnotationLiteral permits InputGuardrailsLiteral, OutputGuardrailsLiteral { + private final List classNames = new ArrayList<>(); + private transient volatile List> guardrailClasses; + + protected ClassProvidingAnnotationLiteral(List classNames) { + if (classNames != null) { + this.classNames.addAll(classNames); + } + } + + /** + * Needed because this class will be serialized & deserialized + */ + public List getClassNames() { + return classNames; + } + + /** + * Needed because this class will be serialized & deserialized + */ + public void setClassNames(List classNames) { + this.classNames.clear(); + + if (classNames != null) { + this.classNames.addAll(classNames); + } + } + + public Class[] value() { + return getClasses(); + } + + private boolean neesCacheInitialization() { + return (this.guardrailClasses == null) || this.guardrailClasses.size() != this.classNames.size(); + } + + private void checkClassCache() { + // Using double-checked locking pattern for cache initialization + if (neesCacheInitialization()) { + synchronized (this) { + if (neesCacheInitialization()) { + var classLoader = Thread.currentThread().getContextClassLoader(); + this.guardrailClasses = this.classNames.stream() + .map(className -> { + try { + return 
(Class) Class.forName(className, false, classLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + }) + .toList(); + } + } + } + } + + protected final Class[] getClasses() { + checkClassCache(); + return this.guardrailClasses.toArray(Class[]::new); + } + + public boolean hasGuardrails() { + checkClassCache(); + return !this.guardrailClasses.isEmpty(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java index 06bf0d48b..4ceda4440 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java @@ -3,7 +3,10 @@ /** * A guardrail is a rule that is applied when interacting with an LLM either to the input (the user message) or to the output of * the model to ensure that they are safe and meet the expectations of the model. + * + * @deprecated Use {@link dev.langchain4j.guardrail.Guardrail} instead */ +@Deprecated(forRemoval = true) public interface Guardrail

> { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java index e04fa76b8..ea7047010 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java @@ -6,7 +6,10 @@ /** * Represents the parameter passed to {@link Guardrail#validate(GuardrailParams)}} in order to validate an interaction between a * user and the LLM. + * + * @deprecated Use {@link dev.langchain4j.guardrail.GuardrailRequestParams} instead */ +@Deprecated(forRemoval = true) public interface GuardrailParams { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java index cb9da4da2..69391346f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java @@ -4,7 +4,10 @@ /** * The result of the validation of an interaction between a user and the LLM. + * + * @deprecated Use {@link dev.langchain4j.guardrail.GuardrailResult} instead */ +@Deprecated(forRemoval = true) public interface GuardrailResult { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java index 6c15d4f08..c8333aafc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java @@ -10,8 +10,11 @@ * safe and meets the expectations of the model. *

* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation. + * + * @deprecated Use {@link dev.langchain4j.guardrail.InputGuardrail} instead */ @Experimental("This feature is experimental and the API is subject to change") +@Deprecated(forRemoval = true) public interface InputGuardrail extends Guardrail { /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java index 01e5bc1de..a78327e44 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java @@ -19,7 +19,9 @@ * @param augmentationResult the augmentation result, can be {@code null} * @param userMessageTemplate the user message template, cannot be {@code null} * @param variables the variable to be used with userMessageTemplate, cannot be {@code null} + * @deprecated Use {@link dev.langchain4j.guardrail.InputGuardrailRequest} instead */ +@Deprecated(forRemoval = true) public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, Map variables) implements GuardrailParams { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java index 1fda84526..0ea59393e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java @@ -9,7 +9,9 @@ * * @param result The result of the input guardrail validation. * @param failures The list of failures, empty if the validation succeeded. 
+ * @deprecated Use {@link dev.langchain4j.guardrail.InputGuardrailResult} instead */ +@Deprecated(forRemoval = true) public record InputGuardrailResult(Result result, String successfulText, List failures) implements GuardrailResult { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java index 2269675eb..7117eb31e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java @@ -24,7 +24,10 @@ *

* When several guardrails are applied, the order of the guardrails is important, as the guardrails are applied in the * order they are listed. + * + * @deprecated Use {@link dev.langchain4j.service.guardrail.InputGuardrails} instead */ +@Deprecated(forRemoval = true) @Retention(RUNTIME) @Target({ ElementType.TYPE, ElementType.METHOD }) public @interface InputGuardrails { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailsLiteral.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailsLiteral.java new file mode 100644 index 000000000..e2d33ea02 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailsLiteral.java @@ -0,0 +1,20 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.List; + +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.service.guardrail.InputGuardrails; + +public final class InputGuardrailsLiteral extends ClassProvidingAnnotationLiteral + implements InputGuardrails { + /** + * Needed because this class will be serialized & deserialized + */ + public InputGuardrailsLiteral() { + this(List.of()); + } + + public InputGuardrailsLiteral(List guardrailClassNames) { + super(guardrailClassNames); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java index 0064448c5..f7bbfdc94 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java @@ -7,6 +7,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +@Deprecated(forRemoval = true) @ApplicationScoped public class JsonGuardrailsUtils { diff --git 
a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/NoopChatExecutor.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/NoopChatExecutor.java new file mode 100644 index 000000000..3e2067115 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/NoopChatExecutor.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.List; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.guardrail.ChatExecutor; +import dev.langchain4j.model.chat.response.ChatResponse; + +/** + * This is needed for output guardrails on a method returning Multi. + * The LC4J api requires a {@link ChatExecutor} passed into the {@link dev.langchain4j.guardrail.OutputGuardrailExecutor}, + * but in the case of a Multi, we do not want the {@link dev.langchain4j.guardrail.OutputGuardrailExecutor} to re-execute the + * request. + * Instead, we will fail and retry the Multi itself + */ +public final class NoopChatExecutor implements ChatExecutor { + @Override + public ChatResponse execute() { + return execute(List.of()); + } + + @Override + public ChatResponse execute(List chatMessages) { + return null; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java index 3d3489f3f..2730e9881 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java @@ -14,7 +14,10 @@ * In the case of reprompting, the reprompt message is added to the LLM context and the request is retried. *

* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3. + * + * @deprecated Use {@link dev.langchain4j.guardrail.OutputGuardrail} instead */ +@Deprecated(forRemoval = true) @Experimental("This feature is experimental and the API is subject to change") public interface OutputGuardrail extends Guardrail { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java index b4799bac2..d6b9c8189 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java @@ -17,7 +17,9 @@ * @param augmentationResult the augmentation result, can be {@code null} * @param userMessageTemplate the user message template, cannot be {@code null} * @param variables the variable to be used with userMessageTemplate, cannot be {@code null} + * @deprecated Use {@link dev.langchain4j.guardrail.OutputGuardrailRequest} instead */ +@Deprecated(forRemoval = true) public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, Map variables) implements GuardrailParams { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java index d58a04515..ac0966649 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java @@ -9,7 +9,9 @@ * * @param result The result of the output guardrail validation. * @param failures The list of failures, empty if the validation succeeded. 
+ * @deprecated Use {@link dev.langchain4j.guardrail.OutputGuardrailResult} instead */ +@Deprecated(forRemoval = true) public record OutputGuardrailResult(Result result, String successfulText, Object successfulResult, List failures) implements GuardrailResult { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java index 906afefbd..64d17fb6a 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java @@ -23,7 +23,10 @@ *

* When several guardrails are applied, the order of the guardrails is important. Note that in case of retry or reprompt, * all the guardrails will be re-applied to the new response. + * + * @deprecated Use {@link dev.langchain4j.service.guardrail.OutputGuardrails} instead */ +@Deprecated(forRemoval = true) @Retention(RUNTIME) @Target({ ElementType.TYPE, ElementType.METHOD }) public @interface OutputGuardrails { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailsLiteral.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailsLiteral.java new file mode 100644 index 000000000..d771de01c --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailsLiteral.java @@ -0,0 +1,69 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.List; + +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig; + +public final class OutputGuardrailsLiteral extends ClassProvidingAnnotationLiteral + implements OutputGuardrails { + + private int maxRetriesToPerform; + private int maxRetriesAsSetByConfig; + + /** + * Needed because this class will be serialized & deserialized + */ + public OutputGuardrailsLiteral() { + this(List.of(), GuardrailsConfig.MAX_RETRIES_DEFAULT, GuardrailsConfig.MAX_RETRIES_DEFAULT); + } + + public OutputGuardrailsLiteral(List guardrailsClasses, int maxRetries) { + this(guardrailsClasses, maxRetries, maxRetries); + } + + /** + * + * @param guardrailsClasses The guardrail classes + * @param maxRetriesToPerform How many retries we want the {@link dev.langchain4j.service.guardrail.GuardrailService} to + * perform + * @param maxRetriesAsSetByConfig The actual number of max retries as set on the annotation. Used in case the method's + * return type is {@link io.smallrye.mutiny.Multi}. 
+ */ + public OutputGuardrailsLiteral(List guardrailsClasses, int maxRetriesToPerform, int maxRetriesAsSetByConfig) { + super(guardrailsClasses); + this.maxRetriesToPerform = maxRetriesToPerform; + this.maxRetriesAsSetByConfig = maxRetriesAsSetByConfig; + } + + /** + * Needed because this class will be serialized & deserialized + */ + public int getMaxRetriesToPerform() { + return maxRetriesToPerform; + } + + /** + * Needed because this class will be serialized & deserialized + */ + public void setMaxRetriesToPerform(int maxRetriesToPerform) { + this.maxRetriesToPerform = maxRetriesToPerform; + } + + @Override + public int maxRetries() { + return this.maxRetriesToPerform; + } + + public int getMaxRetriesAsSetByConfig() { + return maxRetriesAsSetByConfig; + } + + /** + * Needed because this class will be serialized & deserialized + */ + public void setMaxRetriesAsSetByConfig(int maxRetriesAsSetByConfig) { + this.maxRetriesAsSetByConfig = maxRetriesAsSetByConfig; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfig.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfig.java new file mode 100644 index 000000000..af718b532 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfig.java @@ -0,0 +1,7 @@ +package io.quarkiverse.langchain4j.guardrails; + +import dev.langchain4j.guardrail.config.InputGuardrailsConfig; + +public class QuarkusInputGuardrailsConfig implements InputGuardrailsConfig { + +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfigBuilderFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfigBuilderFactory.java new file mode 100644 index 000000000..3475f48f3 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusInputGuardrailsConfigBuilderFactory.java 
@@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.guardrails; + +import dev.langchain4j.guardrail.config.InputGuardrailsConfig; +import dev.langchain4j.guardrail.config.InputGuardrailsConfig.InputGuardrailsConfigBuilder; +import dev.langchain4j.spi.guardrail.config.InputGuardrailsConfigBuilderFactory; + +public class QuarkusInputGuardrailsConfigBuilderFactory implements InputGuardrailsConfigBuilderFactory { + @Override + public InputGuardrailsConfigBuilder get() { + return new QuarkusInputGuardrailsConfigBuilder(); + } + + private static class QuarkusInputGuardrailsConfigBuilder implements InputGuardrailsConfigBuilder { + @Override + public InputGuardrailsConfig build() { + return new QuarkusInputGuardrailsConfig(); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfig.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfig.java new file mode 100644 index 000000000..ebba10149 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfig.java @@ -0,0 +1,42 @@ +package io.quarkiverse.langchain4j.guardrails; + +import org.eclipse.microprofile.config.ConfigProvider; + +import dev.langchain4j.guardrail.config.OutputGuardrailsConfig; +import io.quarkiverse.langchain4j.runtime.config.GuardrailsConfig; +import io.quarkiverse.langchain4j.runtime.config.LangChain4jConfig; +import io.smallrye.config.SmallRyeConfig; + +public class QuarkusOutputGuardrailsConfig implements OutputGuardrailsConfig { + private final int maxRetries; + private final GuardrailsConfig guardrailsConfig; + + public QuarkusOutputGuardrailsConfig(QuarkusOutputGuardrailsConfigBuilder builder) { + this.maxRetries = builder.maxRetries; + this.guardrailsConfig = ConfigProvider.getConfig() + .unwrap(SmallRyeConfig.class) + .getConfigMapping(LangChain4jConfig.class) + .guardrails(); + } + + @Override + public int maxRetries() { + return 
(this.guardrailsConfig.maxRetries() == GuardrailsConfig.MAX_RETRIES_DEFAULT) ? this.maxRetries + : this.guardrailsConfig.maxRetries(); + } + + static class QuarkusOutputGuardrailsConfigBuilder implements OutputGuardrailsConfigBuilder { + private int maxRetries = GuardrailsConfig.MAX_RETRIES_DEFAULT; + + @Override + public OutputGuardrailsConfigBuilder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + @Override + public OutputGuardrailsConfig build() { + return new QuarkusOutputGuardrailsConfig(this); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfigBuilderFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfigBuilderFactory.java new file mode 100644 index 000000000..ea5ab523e --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/QuarkusOutputGuardrailsConfigBuilderFactory.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j.guardrails; + +import dev.langchain4j.guardrail.config.OutputGuardrailsConfig.OutputGuardrailsConfigBuilder; +import dev.langchain4j.spi.guardrail.config.OutputGuardrailsConfigBuilderFactory; +import io.quarkiverse.langchain4j.guardrails.QuarkusOutputGuardrailsConfig.QuarkusOutputGuardrailsConfigBuilder; + +public class QuarkusOutputGuardrailsConfigBuilderFactory implements OutputGuardrailsConfigBuilderFactory { + @Override + public OutputGuardrailsConfigBuilder get() { + return new QuarkusOutputGuardrailsConfigBuilder(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceClassCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceClassCreateInfo.java index 45f394bd7..86da7131e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceClassCreateInfo.java +++ 
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceClassCreateInfo.java @@ -2,8 +2,12 @@ import java.util.Map; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailsLiteral; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailsLiteral; + /** * @param methodMap the key is a methodId generated at build time */ -public record AiServiceClassCreateInfo(Map methodMap, String implClassName) { +public record AiServiceClassCreateInfo(Map methodMap, String implClassName, + InputGuardrailsLiteral inputGuardrails, OutputGuardrailsLiteral outputGuardrails) { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index 5d3d6f94a..750abcb6c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -15,8 +15,8 @@ import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.service.tool.ToolExecutor; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailsLiteral; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailsLiteral; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; import io.quarkiverse.langchain4j.response.AiResponseAugmenter; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; @@ -45,8 +45,20 @@ public final class AiServiceMethodCreateInfo { private final ResponseSchemaInfo responseSchemaInfo; // support for guardrails - private final List outputGuardrailsClassNames; - private final List inputGuardrailsClassNames; + /** + * @deprecated Will go away once 
the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private final List quarkusOutputGuardrailsClassNames; + + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private final List quarkusInputGuardrailsClassNames; + + private final InputGuardrailsLiteral inputGuardrails; + private final OutputGuardrailsLiteral outputGuardrails; // support for response augmenter, potentially null private final String responseAugmenterClassName; @@ -56,14 +68,22 @@ public final class AiServiceMethodCreateInfo { private transient final Map toolExecutors = new ConcurrentHashMap<>(); // Don't cache the instances, because of scope issues (some will need to be re-queried) - private transient final List> outputGuardrails = new CopyOnWriteArrayList<>(); - private transient final List> inputGuardrails = new CopyOnWriteArrayList<>(); + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private transient final List> quarkusOutputGuardrailsClasses = new CopyOnWriteArrayList<>(); + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + private transient final List> quarkusInputGuardrailsClasses = new CopyOnWriteArrayList<>(); private transient Class> augmenter; private final String outputTokenAccumulatorClassName; private OutputTokenAccumulator accumulator; - private final LazyValue guardrailsMaxRetry; + private final LazyValue quarkusGuardrailsMaxRetry; private final boolean switchToWorkerThreadForToolExecution; @RecordableConstructor @@ -81,10 +101,12 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName, Map> toolClassInfo, List mcpClientNames, boolean switchToWorkerThreadForToolExecution, - List inputGuardrailsClassNames, - List 
outputGuardrailsClassNames, + List quarkusInputGuardrailsClassNames, + List quarkusOutputGuardrailsClassNames, String outputTokenAccumulatorClassName, - String responseAugmenterClassName) { + String responseAugmenterClassName, + InputGuardrailsLiteral inputGuardrails, + OutputGuardrailsLiteral outputGuardrails) { this.interfaceName = interfaceName; this.methodName = methodName; this.systemMessageInfo = systemMessageInfo; @@ -105,11 +127,13 @@ public Type get() { this.responseSchemaInfo = responseSchemaInfo; this.toolClassInfo = toolClassInfo; this.mcpClientNames = mcpClientNames; - this.inputGuardrailsClassNames = inputGuardrailsClassNames; - this.outputGuardrailsClassNames = outputGuardrailsClassNames; + this.quarkusInputGuardrailsClassNames = quarkusInputGuardrailsClassNames; + this.quarkusOutputGuardrailsClassNames = quarkusOutputGuardrailsClassNames; + this.inputGuardrails = inputGuardrails; + this.outputGuardrails = outputGuardrails; this.outputTokenAccumulatorClassName = outputTokenAccumulatorClassName; // Use a lazy value to get the value at runtime. 
- this.guardrailsMaxRetry = new LazyValue(new Supplier() { + this.quarkusGuardrailsMaxRetry = new LazyValue(new Supplier() { @Override public Integer get() { return ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.guardrails.max-retries", Integer.class) @@ -188,12 +212,24 @@ public Map getToolExecutors() { return toolExecutors; } - public List getOutputGuardrailsClassNames() { - return outputGuardrailsClassNames; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List getQuarkusOutputGuardrailsClassNames() { + return quarkusOutputGuardrailsClassNames; } - public List> getOutputGuardrailsClasses() { - return outputGuardrails; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List> getQuarkusOutputGuardrailsClasses() { + return quarkusOutputGuardrailsClasses; + } + + public InputGuardrailsLiteral getInputGuardrails() { + return inputGuardrails; } public String getResponseAugmenterClassName() { @@ -223,16 +259,28 @@ public Class> getResponseAugmenter() { } } - public int getGuardrailsMaxRetry() { - return guardrailsMaxRetry.get(); + public int getQuarkusGuardrailsMaxRetry() { + return quarkusGuardrailsMaxRetry.get(); } - public List getInputGuardrailsClassNames() { - return inputGuardrailsClassNames; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List getQuarkusInputGuardrailsClassNames() { + return quarkusInputGuardrailsClassNames; } - public List> getInputGuardrailsClasses() { - return inputGuardrails; + /** + * @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public List> getQuarkusInputGuardrailsClasses() { + return quarkusInputGuardrailsClasses; 
+ } + + public OutputGuardrailsLiteral getOutputGuardrails() { + return outputGuardrails; } public String getOutputTokenAccumulatorClassName() { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index ad1b378d6..d19ac0e1a 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -50,6 +50,8 @@ import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.pdf.PdfFile; +import dev.langchain4j.guardrail.ChatExecutor; +import dev.langchain4j.guardrail.GuardrailRequestParams; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -68,13 +70,11 @@ import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.rag.AugmentationRequest; import dev.langchain4j.rag.AugmentationResult; -import dev.langchain4j.rag.content.Content; import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.service.AiServiceContext; import dev.langchain4j.service.AiServiceTokenStream; import dev.langchain4j.service.AiServiceTokenStreamParameters; import dev.langchain4j.service.Result; -import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.output.ServiceOutputParser; import dev.langchain4j.service.tool.ToolExecutor; import dev.langchain4j.service.tool.ToolProviderRequest; @@ -97,15 +97,13 @@ import io.quarkiverse.langchain4j.runtime.ContextLocals; import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; +import 
io.quarkiverse.langchain4j.runtime.aiservice.GuardrailsSupport.GuardrailRetryException; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailsSupport.OutputGuardrailStreamingMapper; import io.quarkiverse.langchain4j.runtime.types.TypeUtil; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; import io.quarkus.arc.Arc; -import io.smallrye.common.vertx.VertxContext; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.infrastructure.Infrastructure; -import io.smallrye.mutiny.operators.AbstractMulti; -import io.smallrye.mutiny.operators.multi.processors.UnicastProcessor; -import io.smallrye.mutiny.subscription.MultiSubscriber; import io.vertx.core.Context; /** @@ -243,25 +241,31 @@ public Flow.Publisher apply(AugmentationResult ar) { ChatMessage augmentedUserMessage = ar.chatMessage(); ChatMemory memory = context.chatMemoryService.getChatMemory(memoryId); - UserMessage guardrailsMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, + /** + * @deprecated Deprecated in favor of upstream implementation + */ + UserMessage guardrailsMessage = GuardrailsSupport.invokeInputGuardRails(methodCreateInfo, (UserMessage) augmentedUserMessage, memory, ar, templateVariables, beanManager, auditSourceInfo); + guardrailsMessage = GuardrailsSupport.executeInputGuardrails(context.guardrailService(), + guardrailsMessage, + methodCreateInfo, memory, ar, templateVariables); List messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed); var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, finalToolExecutors, ar.contents(), context, memoryId, methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread); + return stream - .filter(event -> { - return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent; - }).map(event -> { + .filter(event -> !isStringMulti || event instanceof ChatEvent.PartialResponseEvent) + .map(event -> { if (isStringMulti && event instanceof 
ChatEvent.PartialResponseEvent) { return ((ChatEvent.PartialResponseEvent) event).getChunk(); } return event; - }).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, - new ResponseAugmenterParams((UserMessage) augmentedUserMessage, - memory, ar, methodCreateInfo.getUserMessageTemplate(), - templateVariables))); + }) + .plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, + new ResponseAugmenterParams((UserMessage) augmentedUserMessage, memory, ar, + methodCreateInfo.getUserMessageTemplate(), templateVariables))); } private List messagesToSend(UserMessage augmentedUserMessage, @@ -277,27 +281,38 @@ private List messagesToSend(UserMessage augmentedUserMessage, } } - userMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, - context.hasChatMemory() ? context.chatMemoryService.getChatMemory(memoryId) : null, - augmentationResult, templateVariables, beanManager, auditSourceInfo); + var guardrailService = context.guardrailService(); + var chatMemory = context.hasChatMemory() ? 
context.chatMemoryService.getChatMemory(memoryId) : null; + + /** + * @deprecated Deprecated in favor of upstream implementation + */ + userMessage = GuardrailsSupport.invokeInputGuardRails(methodCreateInfo, userMessage, chatMemory, augmentationResult, + templateVariables, beanManager, auditSourceInfo); + + userMessage = GuardrailsSupport.executeInputGuardrails(guardrailService, userMessage, methodCreateInfo, chatMemory, + augmentationResult, templateVariables); - CommittableChatMemory chatMemory; + CommittableChatMemory committableChatMemory; List messagesToSend; if (context.hasChatMemory()) { // we want to defer saving the new messages because the service could fail and be retried - chatMemory = new DefaultCommittableChatMemory(context.chatMemoryService.getChatMemory(memoryId)); - messagesToSend = createMessagesToSendForExistingMemory(systemMessage, userMessage, chatMemory, needsMemorySeed, + committableChatMemory = new DefaultCommittableChatMemory(chatMemory); + messagesToSend = createMessagesToSendForExistingMemory(systemMessage, userMessage, committableChatMemory, + needsMemorySeed, context, methodCreateInfo); } else { - chatMemory = new NoopChatMemory(); + committableChatMemory = new NoopChatMemory(); messagesToSend = createMessagesToSendForNoMemory(systemMessage, userMessage, needsMemorySeed, context, methodCreateInfo); } if (TypeUtil.isTokenStream(returnType)) { // TODO Indicate the output guardrails cannot be used when using token stream. 
- chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse + // NOTE - only the quarkus-specific output guardrails aren't implemented using a TokenStream + // Upstream supports it + committableChatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse var aiServiceTokenStreamParams = AiServiceTokenStreamParameters.builder() .messages(messagesToSend) .toolSpecifications(toolSpecifications) @@ -305,40 +320,41 @@ private List messagesToSend(UserMessage augmentedUserMessage, .retrievedContents((augmentationResult != null ? augmentationResult.contents() : null)) .context(context) .memoryId(memoryId) + .methodKey(methodCreateInfo) + .commonGuardrailParams( + GuardrailRequestParams.builder() + .chatMemory(committableChatMemory) + .augmentationResult(augmentationResult) + .userMessageTemplate(methodCreateInfo.getUserMessageTemplate()) + .variables(templateVariables) + .build()) .build(); return new AiServiceTokenStream(aiServiceTokenStreamParams); } var actualAugmentationResult = augmentationResult; var actualUserMessage = userMessage; - if (isMulti) { - chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse - if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { - var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, - (augmentationResult != null ? 
augmentationResult.contents() : null), context, memoryId, - methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread); - return stream.filter(event -> { - return !isStringMulti || event instanceof ChatEvent.PartialResponseEvent; - }).map(event -> { - if (isStringMulti && event instanceof ChatEvent.PartialResponseEvent) { - return ((ChatEvent.PartialResponseEvent) event).getChunk(); - } - return event; - }).plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, - new ResponseAugmenterParams(actualUserMessage, - chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(), - Collections.unmodifiableMap(templateVariables)))); - } - return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, + if (isMulti) { + committableChatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse + var hasQuarkusOutputGuardrails = !methodCreateInfo.getQuarkusOutputGuardrailsClassNames().isEmpty(); + var hasUpstreamGuardrails = methodCreateInfo.getOutputGuardrails().hasGuardrails(); + Multi stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, (augmentationResult != null ? 
augmentationResult.contents() : null), context, memoryId, - methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread) - .plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo)) - .map(chunk -> { + methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread); + + if (hasQuarkusOutputGuardrails || hasUpstreamGuardrails) { + stream = stream.filter(o -> o instanceof ChatEvent) + .map(ChatEvent.class::cast) + .plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo)); + + if (hasQuarkusOutputGuardrails) { + stream = stream.map(chunk -> { + ChatEvent.AccumulatedResponseEvent accumulatedChunk = (ChatEvent.AccumulatedResponseEvent) chunk; OutputGuardrailResult result; try { - result = GuardrailsSupport.invokeOutputGuardrailsForStream(methodCreateInfo, - new OutputGuardrailParams(AiMessage.from(chunk.getMessage()), chatMemory, + result = GuardrailsSupport.invokeOutputGuardRails(methodCreateInfo, + new OutputGuardrailParams(AiMessage.from(accumulatedChunk.getMessage()), chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(), Collections.unmodifiableMap(templateVariables)), @@ -351,7 +367,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, if (!result.isRetry()) { throw new GuardrailException(result.toString(), result.getFirstFailureException()); } else if (result.getReprompt() != null) { - chatMemory.add(new UserMessage(result.getReprompt())); + committableChatMemory.add(new UserMessage(result.getReprompt())); throw new GuardrailsSupport.GuardrailRetryException(); } else { // Retry without re-prompting @@ -362,31 +378,64 @@ private List messagesToSend(UserMessage augmentedUserMessage, throw new GuardrailException( "Attempting to rewrite the LLM output while streaming is not allowed"); } + if (isStringMulti) { - return chunk.getMessage(); + return accumulatedChunk.getMessage(); } + return chunk; } }) - // Retry logic: - // 1. retry only on the custom RetryException - // 2. 
If we still have a RetryException afterward, we fail. - .onFailure(GuardrailsSupport.GuardrailRetryException.class).retry() - .atMost(methodCreateInfo.getGuardrailsMaxRetry()) - .onFailure(GuardrailsSupport.GuardrailRetryException.class) - .transform(t -> new GuardrailException( - "Output validation failed. The guardrails have reached the maximum number of retries")) - .plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, - new ResponseAugmenterParams(actualUserMessage, - chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(), - Collections.unmodifiableMap(templateVariables)))); + // Retry logic: + // 1. retry only on the custom RetryException + // 2. If we still have a RetryException afterward, we fail. + .onFailure(GuardrailRetryException.class) + .retry() + .atMost(methodCreateInfo.getQuarkusGuardrailsMaxRetry()) + .onFailure(GuardrailRetryException.class) + .transform(t -> new GuardrailException( + "Output validation failed. The guardrails have reached the maximum number of retries")); + } + + if (hasUpstreamGuardrails) { + stream = stream.map( + new OutputGuardrailStreamingMapper( + guardrailService, + methodCreateInfo, + committableChatMemory, + actualAugmentationResult, + templateVariables, + isStringMulti)) + .onFailure(GuardrailsSupport::isOutputGuardrailRetry) + .retry() + .atMost(methodCreateInfo.getOutputGuardrails().getMaxRetriesAsSetByConfig()); + } + } else { + stream = stream.filter(event -> !isStringMulti || event instanceof ChatEvent.PartialResponseEvent) + .map(event -> { + if (isStringMulti && (event instanceof ChatEvent.PartialResponseEvent)) { + return ((ChatEvent.PartialResponseEvent) event).getChunk(); + } + + return event; + }); + } + + return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, + new ResponseAugmenterParams(actualUserMessage, chatMemory, actualAugmentationResult, + methodCreateInfo.getUserMessageTemplate(), templateVariables))); } Future moderationFuture = 
triggerModerationIfNeeded(context, methodCreateInfo, messagesToSend); log.debug("Attempting to obtain AI response"); - ChatResponse response = executeRequest(context, methodCreateInfo, methodArgs, messagesToSend, toolSpecifications); + ChatRequest chatRequest = createChatRequest(context, methodCreateInfo, methodArgs, messagesToSend, toolSpecifications); + ChatExecutor chatExecutor = ChatExecutor.builder(context.effectiveChatModel(methodCreateInfo, methodArgs)) + .chatRequest(chatRequest) + .build(); + + ChatResponse response = chatExecutor.execute(); log.debug("AI response obtained"); @@ -402,14 +451,13 @@ private List messagesToSend(UserMessage augmentedUserMessage, : getMaxSequentialToolExecutions(); int executionsLeft = maxSequentialToolExecutions; while (true) { - if (executionsLeft-- == 0) { throw runtime("Something is wrong, exceeded %s sequential tool executions", maxSequentialToolExecutions); } AiMessage aiMessage = response.aiMessage(); - chatMemory.add(aiMessage); + committableChatMemory.add(aiMessage); if (!aiMessage.hasToolExecutionRequests()) { break; @@ -423,12 +471,12 @@ private List messagesToSend(UserMessage augmentedUserMessage, ? 
context.toolService.applyToolHallucinationStrategy(toolExecutionRequest) : executeTool(auditSourceInfo, toolExecutionRequest, toolExecutor, memoryId, beanManager); - chatMemory.add(toolExecutionResultMessage); + committableChatMemory.add(toolExecutionResultMessage); } log.debug("Attempting to obtain AI response"); ChatModel effectiveChatModel = context.effectiveChatModel(methodCreateInfo, methodArgs); - ChatRequest.Builder chatRequestBuilder = ChatRequest.builder().messages(chatMemory.messages()); + ChatRequest.Builder chatRequestBuilder = ChatRequest.builder().messages(committableChatMemory.messages()); DefaultChatRequestParameters.Builder parametersBuilder = ChatRequestParameters.builder(); if (supportsJsonSchema(effectiveChatModel)) { Optional jsonSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema(); @@ -464,31 +512,38 @@ private List messagesToSend(UserMessage augmentedUserMessage, String userMessageTemplate = methodCreateInfo.getUserMessageTemplate(); - var guardrailResponse = GuardrailsSupport.invokeOutputGuardrails(methodCreateInfo, chatMemory, + /** + * @deprecated Deprecated in favor of upstream implementation + */ + var guardrailResponse = GuardrailsSupport.invokeOutputGuardRails(methodCreateInfo, committableChatMemory, context.effectiveChatModel(methodCreateInfo, methodArgs), response, toolSpecifications, - new OutputGuardrailParams(response.aiMessage(), chatMemory, augmentationResult, userMessageTemplate, + new OutputGuardrailParams(response.aiMessage(), committableChatMemory, augmentationResult, userMessageTemplate, Collections.unmodifiableMap(templateVariables)), beanManager, auditSourceInfo); response = guardrailResponse.response(); + Object guardrailResult = guardrailResponse + .getRewrittenResult(); + guardrailResult = GuardrailsSupport.executeOutputGuardrails(guardrailService, methodCreateInfo, response, chatExecutor, + committableChatMemory, augmentationResult, templateVariables, guardrailResult); // everything worked as 
expected so let's commit the messages - chatMemory.commit(); + committableChatMemory.commit(); - var responseAugmenterParam = new ResponseAugmenterParams(userMessage, chatMemory, augmentationResult, + var responseAugmenterParam = new ResponseAugmenterParams(userMessage, committableChatMemory, augmentationResult, userMessageTemplate, templateVariables); - Object guardrailResult = guardrailResponse.getRewrittenResult(); - if (guardrailResult != null && TypeUtil.isTypeOf(returnType, guardrailResult.getClass())) { + if ((guardrailResult != null) && TypeUtil.isTypeOf(returnType, guardrailResult.getClass())) { return ResponseAugmenterSupport.invoke(guardrailResult, methodCreateInfo, responseAugmenterParam); } - response = ChatResponse.builder() - .aiMessage(response.aiMessage()) - .metadata(response.metadata()) - .build(); + if (guardrailResult instanceof ChatResponse) { + response = (ChatResponse) guardrailResult; + } + + response = ChatResponse.builder().aiMessage(response.aiMessage()).metadata(response.metadata()).build(); if (TypeUtil.isResult(returnType)) { var parsedResponse = SERVICE_OUTPUT_PARSER.parse(ChatResponse.builder().aiMessage(response.aiMessage()).build(), @@ -520,6 +575,10 @@ private static ToolExecutionResultMessage executeTool(AuditSourceInfo auditSourc return toolExecutionResultMessage; } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) private static ChatResponse executeRequest(JsonSchema jsonSchema, List messagesToSend, ChatModel chatModel, List toolSpecifications) { var chatRequest = ChatRequest.builder() @@ -530,6 +589,18 @@ private static ChatResponse executeRequest(JsonSchema jsonSchema, List messagesToSend, ChatModel chatModel, + List toolSpecifications) { + return ChatRequest.builder() + .messages(messagesToSend) + .parameters(constructStructuredResponseParams(toolSpecifications, jsonSchema).build()) + .build(); + } + + /** + * @deprecated Deprecated in favor of upstream implementation 
+ */ + @Deprecated(forRemoval = true) private static ChatResponse executeRequest(List messagesToSend, ChatModel chatModel, List toolSpecifications) { var chatRequest = ChatRequest.builder() @@ -540,6 +611,21 @@ private static ChatResponse executeRequest(List messagesToSend, Cha return chatModel.chat(chatRequest.build()); } + static ChatRequest createChatRequest(List messagesToSend, ChatModel chatModel, + List toolSpecifications) { + var chatRequest = ChatRequest.builder() + .messages(messagesToSend); + + if (toolSpecifications != null) { + chatRequest.toolSpecifications(toolSpecifications); + } + return chatRequest.build(); + } + + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) static ChatResponse executeRequest(AiServiceMethodCreateInfo methodCreateInfo, List messagesToSend, ChatModel chatModel, List toolSpecifications) { var jsonSchema = supportsJsonSchema(chatModel) ? methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema() @@ -549,11 +635,20 @@ static ChatResponse executeRequest(AiServiceMethodCreateInfo methodCreateInfo, L : executeRequest(messagesToSend, chatModel, toolSpecifications); } - static ChatResponse executeRequest(QuarkusAiServiceContext context, + static ChatRequest createChatRequest(AiServiceMethodCreateInfo methodCreateInfo, List messagesToSend, + ChatModel chatModel, List toolSpecifications) { + var jsonSchema = supportsJsonSchema(chatModel) ? methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema() + : Optional. empty(); + + return jsonSchema.isPresent() ? 
createChatRequest(jsonSchema.get(), messagesToSend, chatModel, toolSpecifications) + : createChatRequest(messagesToSend, chatModel, toolSpecifications); + } + + static ChatRequest createChatRequest(QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs, List messagesToSend, List toolSpecifications) { - return executeRequest(methodCreateInfo, messagesToSend, - context.effectiveChatModel(methodCreateInfo, methodArgs), + + return createChatRequest(methodCreateInfo, messagesToSend, context.effectiveChatModel(methodCreateInfo, methodArgs), toolSpecifications); } @@ -578,7 +673,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC AugmentationResult augmentationResult = null; // TODO: we can only support input guardrails for now as it is tied to AiMessage - GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, + GuardrailsSupport.invokeInputGuardRails(methodCreateInfo, userMessage, context.hasChatMemory() ? 
context.chatMemoryService.getChatMemory(memoryId) : null, augmentationResult, templateVariables, beanManager, auditSourceInfo); @@ -967,76 +1062,4 @@ public interface Wrapper { Object wrap(Input input, Function fun); } - - private static class TokenStreamMulti extends AbstractMulti implements Multi { - private final List messagesToSend; - private final List toolSpecifications; - private final Map toolsExecutors; - private final List contents; - private final QuarkusAiServiceContext context; - private final Object memoryId; - private final boolean switchToWorkerThreadForToolExecution; - private final boolean isCallerRunningOnWorkerThread; - - public TokenStreamMulti(List messagesToSend, List toolSpecifications, - Map toolExecutors, - List contents, QuarkusAiServiceContext context, Object memoryId, - boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) { - // We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription. - this.messagesToSend = messagesToSend; - this.toolSpecifications = toolSpecifications; - this.toolsExecutors = toolExecutors; - this.contents = contents; - this.context = context; - this.memoryId = memoryId; - this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution; - this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread; - } - - @Override - public void subscribe(MultiSubscriber subscriber) { - UnicastProcessor processor = UnicastProcessor.create(); - processor.subscribe(subscriber); - - createTokenStream(processor); - } - - private void createTokenStream(UnicastProcessor processor) { - Context ctxt = null; - if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) { - // we create or retrieve the current context, to use `executeBlocking` when required. 
- ctxt = VertxContext.getOrCreateDuplicatedContext(); - } - - var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, - toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution, - isCallerRunningOnWorkerThread); - TokenStream tokenStream = stream - .onPartialResponse(chunk -> processor - .onNext(new ChatEvent.PartialResponseEvent(chunk))) - .onCompleteResponse(message -> { - processor.onNext(new ChatEvent.ChatCompletedEvent(message)); - processor.onComplete(); - }) - .onRetrieved(content -> { - processor.onNext(new ChatEvent.ContentFetchedEvent(content)); - }) - .onToolExecuted(execution -> { - processor.onNext(new ChatEvent.ToolExecutedEvent(execution)); - }) - .onError(processor::onError); - // This is equivalent to "run subscription on worker thread" - if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) { - ctxt.executeBlocking(new Callable() { - @Override - public Void call() { - tokenStream.start(); - return null; - } - }); - } else { - tokenStream.start(); - } - } - } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index 0b770257e..61fd36ff3 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -4,6 +4,9 @@ import jakarta.enterprise.util.AnnotationLiteral; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailsLiteral; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailsLiteral; + public record DeclarativeAiServiceCreateInfo( String serviceClassName, String languageModelSupplierClassName, @@ -22,5 +25,7 @@ public record DeclarativeAiServiceCreateInfo( boolean needsModerationModel, boolean 
needsImageModel, String toolHallucinationStrategyClassName, + InputGuardrailsLiteral inputGuardrails, + OutputGuardrailsLiteral outputGuardrails, Integer maxSequentialToolInvocations) { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java index 131769eba..db0c70b5e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java @@ -4,7 +4,10 @@ * Exception thrown when an input or output guardrail validation fails. *

* This exception is not intended to be used in guardrail implementation. + * + * @deprecated Use {@link dev.langchain4j.guardrail.GuardrailException} instead */ +@Deprecated(forRemoval = true) public class GuardrailException extends RuntimeException { public GuardrailException(String message) { super(message); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index c828b7f86..de07e6962 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -16,12 +16,19 @@ import jakarta.enterprise.inject.spi.CDI; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.ChatExecutor; +import dev.langchain4j.guardrail.GuardrailRequestParams; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailException; +import dev.langchain4j.guardrail.OutputGuardrailRequest; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.ChatResponseMetadata; import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.service.guardrail.GuardrailService; import io.quarkiverse.langchain4j.audit.AuditSourceInfo; import io.quarkiverse.langchain4j.audit.InputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.OutputGuardrailExecutedEvent; @@ -33,15 +40,166 @@ import io.quarkiverse.langchain4j.guardrails.InputGuardrail; import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import 
io.quarkiverse.langchain4j.guardrails.NoopChatExecutor; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; +import io.quarkiverse.langchain4j.runtime.aiservice.ChatEvent.AccumulatedResponseEvent; +import io.quarkiverse.langchain4j.runtime.aiservice.ChatEvent.ChatEventType; import io.smallrye.mutiny.Multi; public class GuardrailsSupport { + static UserMessage executeInputGuardrails(GuardrailService guardrailService, UserMessage userMessage, + AiServiceMethodCreateInfo methodCreateInfo, ChatMemory chatMemory, AugmentationResult augmentationResult, + Map templateVariables) { + var um = userMessage; + + if (guardrailService.hasInputGuardrails(methodCreateInfo)) { + var request = InputGuardrailRequest.builder() + .userMessage(userMessage) + .commonParams( + GuardrailRequestParams.builder() + .chatMemory(chatMemory) + .augmentationResult(augmentationResult) + .userMessageTemplate(methodCreateInfo.getUserMessageTemplate()) + .variables(templateVariables) + .build()) + .build(); + + um = guardrailService.executeGuardrails(methodCreateInfo, request); + } + + return um; + } + + static T executeOutputGuardrails(GuardrailService guardrailService, AiServiceMethodCreateInfo methodCreateInfo, + ChatResponse response, ChatExecutor chatExecutor, CommittableChatMemory committableChatMemory, + AugmentationResult augmentationResult, Map templateVariables, + @Deprecated(forRemoval = true) T quarkusSpecificGuardrailResult) { + + /** + * The quarkusSpecificGuardrailResult will be removed when the quarkus-specific guardrails implementation is removed + */ + T result = quarkusSpecificGuardrailResult; + + if (guardrailService.hasOutputGuardrails(methodCreateInfo)) { + var request = OutputGuardrailRequest.builder() + .responseFromLLM(response) + .chatExecutor(chatExecutor) + 
.requestParams( + GuardrailRequestParams.builder() + .chatMemory(committableChatMemory) + .augmentationResult(augmentationResult) + .userMessageTemplate(methodCreateInfo.getUserMessageTemplate()) + .variables(templateVariables) + .build()) + .build(); + + result = guardrailService.executeGuardrails(methodCreateInfo, request); + } + + return result; + } + + static boolean isOutputGuardrailRetry(Throwable t) { + return (t instanceof OutputGuardrailException) && + t.getMessage().toLowerCase().contains("the guardrails have reached the maximum number of retries."); + } + + static class OutputGuardrailStreamingMapper + implements Function { + private final GuardrailService guardrailService; + private final AiServiceMethodCreateInfo methodCreateInfo; + private final CommittableChatMemory committableChatMemory; + private final AugmentationResult augmentationResult; + private final Map templateVariables; + private final boolean isStringMulti; + + OutputGuardrailStreamingMapper(GuardrailService guardrailService, AiServiceMethodCreateInfo methodCreateInfo, + CommittableChatMemory committableChatMemory, AugmentationResult augmentationResult, + Map templateVariables, boolean isStringMulti) { + this.guardrailService = guardrailService; + this.methodCreateInfo = methodCreateInfo; + this.committableChatMemory = committableChatMemory; + this.augmentationResult = augmentationResult; + this.templateVariables = templateVariables; + this.isStringMulti = isStringMulti; + } + + private Object apply(ChatEvent chunk) { + if (chunk.getEventType() == ChatEventType.AccumulatedResponse) { + var accumulatedChunk = (ChatEvent.AccumulatedResponseEvent) chunk; + var metadata = accumulatedChunk.getMetadata(); + var guardrailResult = executeOutputGuardrails( + guardrailService, + methodCreateInfo, + ChatResponse.builder() + .aiMessage(AiMessage.from(accumulatedChunk.getMessage())) + .build(), + new NoopChatExecutor(), + committableChatMemory, + augmentationResult, + templateVariables, + null); + + 
if (guardrailResult instanceof ChatResponse) { + String message = ((ChatResponse) guardrailResult).aiMessage().text(); + return isStringMulti ? message : new AccumulatedResponseEvent(message, metadata); + } else if (guardrailResult instanceof String) { + return isStringMulti ? (String) guardrailResult + : new AccumulatedResponseEvent((String) guardrailResult, metadata); + } else if (guardrailResult != null) { + return isStringMulti ? guardrailResult.toString() + : new AccumulatedResponseEvent(guardrailResult.toString(), metadata); + } + } + + return chunk; + } + + private Object apply(String chunk) { + var guardrailResult = executeOutputGuardrails( + guardrailService, + methodCreateInfo, + ChatResponse.builder() + .aiMessage(AiMessage.from(chunk)) + .build(), + new NoopChatExecutor(), + committableChatMemory, + augmentationResult, + templateVariables, + null); + + if (guardrailResult instanceof ChatResponse) { + return ((ChatResponse) guardrailResult).aiMessage().text(); + } else if (guardrailResult instanceof String) { + return (String) guardrailResult; + } else if (guardrailResult != null) { + // TODO is this really needed + return guardrailResult.toString(); + } + + return chunk; + } + + @Override + public Object apply(Object chunk) { + if (chunk instanceof ChatEvent) { + return apply((ChatEvent) chunk); + } else if (chunk instanceof String) { + return apply((String) chunk); + } + + return chunk; + } + } - public static UserMessage invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) + public static UserMessage invokeInputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, ChatMemory chatMemory, AugmentationResult augmentationResult, Map templateVariables, BeanManager beanManager, AuditSourceInfo auditSourceInfo) { InputGuardrailResult result; @@ -66,14 +224,18 @@ public static UserMessage 
invokeInputGuardrails(AiServiceMethodCreateInfo method return userMessage; } - public static OutputGuardrailResponse invokeOutputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) + public static OutputGuardrailResponse invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, ChatMemory chatMemory, ChatModel chatModel, ChatResponse response, List toolSpecifications, OutputGuardrailParams output, BeanManager beanManager, AuditSourceInfo auditSourceInfo) { int attempt = 0; - int max = methodCreateInfo.getGuardrailsMaxRetry(); + int max = methodCreateInfo.getQuarkusGuardrailsMaxRetry(); if (max <= 0) { max = 1; } @@ -124,6 +286,10 @@ public static OutputGuardrailResponse invokeOutputGuardrails(AiServiceMethodCrea return new OutputGuardrailResponse(response, result); } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) public record OutputGuardrailResponse(ChatResponse response, OutputGuardrailResult result) { public boolean hasRewrittenResult() { @@ -135,17 +301,21 @@ public Object getRewrittenResult() { } } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) @SuppressWarnings("unchecked") - private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, + static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, OutputGuardrailParams params, BeanManager beanManager, AuditSourceInfo auditSourceInfo) { - if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { + if (methodCreateInfo.getQuarkusOutputGuardrailsClassNames().isEmpty()) { return OutputGuardrailResult.success(); } List> classes; synchronized (AiServiceMethodImplementationSupport.class) { - classes = methodCreateInfo.getOutputGuardrailsClasses(); + classes = 
methodCreateInfo.getQuarkusOutputGuardrailsClasses(); if (classes.isEmpty()) { - for (String className : methodCreateInfo.getOutputGuardrailsClassNames()) { + for (String className : methodCreateInfo.getQuarkusOutputGuardrailsClassNames()) { try { classes.add((Class) Class.forName(className, true, Thread.currentThread().getContextClassLoader())); @@ -163,17 +333,21 @@ private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreat beanManager, auditSourceInfo); } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) @SuppressWarnings("unchecked") private static InputGuardrailResult invokeInputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, InputGuardrailParams params, BeanManager beanManager, AuditSourceInfo auditSourceInfo) { - if (methodCreateInfo.getInputGuardrailsClassNames().isEmpty()) { + if (methodCreateInfo.getQuarkusInputGuardrailsClassNames().isEmpty()) { return InputGuardrailResult.success(); } List> classes; synchronized (AiServiceMethodImplementationSupport.class) { - classes = methodCreateInfo.getInputGuardrailsClasses(); + classes = methodCreateInfo.getQuarkusInputGuardrailsClasses(); if (classes.isEmpty()) { - for (String className : methodCreateInfo.getInputGuardrailsClassNames()) { + for (String className : methodCreateInfo.getQuarkusInputGuardrailsClassNames()) { try { classes.add((Class) Class.forName(className, true, Thread.currentThread().getContextClassLoader())); @@ -191,6 +365,10 @@ private static InputGuardrailResult invokeInputGuardRails(AiServiceMethodCreateI beanManager, auditSourceInfo); } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) private static GR guardrailResult(GuardrailParams params, List> classes, GR accumulatedResults, Function, GR> producer, BeanManager beanManager, @@ -221,6 +399,10 @@ private static GR guardrailResult(GuardrailParams p return accumulatedResults; } + /** + * @deprecated 
Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) private static GR compose(GR oldResult, GR newResult, Function, GR> producer) { if (oldResult.isSuccess()) { @@ -246,28 +428,29 @@ private static class ChatResponseAccumulator { } - public static Multi accumulate( - Multi upstream, AiServiceMethodCreateInfo methodCreateInfo) { + public static Multi accumulate(Multi upstream, + AiServiceMethodCreateInfo methodCreateInfo) { OutputTokenAccumulator accumulator; synchronized (AiServiceMethodImplementationSupport.class) { accumulator = methodCreateInfo.getOutputTokenAccumulator(); if (accumulator == null) { String cn = methodCreateInfo.getOutputTokenAccumulatorClassName(); if (cn == null) { - return upstream.collect().in(ChatResponseAccumulator::new, (chatResponseAccumulator, chatEvent) -> { - if (chatEvent - .getEventType() == ChatEvent.ChatEventType.PartialResponse) { - chatResponseAccumulator.stringBuilder.append( - ((ChatEvent.PartialResponseEvent) chatEvent) - .getChunk()); - } - if (chatEvent - .getEventType() == ChatEvent.ChatEventType.Completed) { - chatResponseAccumulator.metadata = ((ChatEvent.ChatCompletedEvent) chatEvent) - .getChatResponse().metadata(); - } - }) - .map(acc -> new ChatEvent.AccumulatedResponseEvent( + return upstream.collect() + .in(ChatResponseAccumulator::new, (chatResponseAccumulator, chatEvent) -> { + if (chatEvent + .getEventType() == ChatEventType.PartialResponse) { + chatResponseAccumulator.stringBuilder.append( + ((ChatEvent.PartialResponseEvent) chatEvent) + .getChunk()); + } + if (chatEvent + .getEventType() == ChatEventType.Completed) { + chatResponseAccumulator.metadata = ((ChatEvent.ChatCompletedEvent) chatEvent) + .getChatResponse().metadata(); + } + }) + .map(acc -> new AccumulatedResponseEvent( acc.stringBuilder.toString(), acc.metadata)) .toMulti(); } @@ -299,11 +482,19 @@ public static Multi accumulate( } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + 
@Deprecated(forRemoval = true) public static OutputGuardrailResult invokeOutputGuardrailsForStream(AiServiceMethodCreateInfo methodCreateInfo, OutputGuardrailParams outputGuardrailParams, BeanManager beanManager, AuditSourceInfo auditSourceInfo) { return invokeOutputGuardRails(methodCreateInfo, outputGuardrailParams, beanManager, auditSourceInfo); } + /** + * @deprecated Deprecated in favor of upstream implementation + */ + @Deprecated(forRemoval = true) static class GuardrailRetryException extends RuntimeException { // Marker class to indicate a retry to the downstream consumer. } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/TokenStreamMulti.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/TokenStreamMulti.java new file mode 100644 index 000000000..275863dc2 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/TokenStreamMulti.java @@ -0,0 +1,84 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.tool.ToolExecutor; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.operators.AbstractMulti; +import io.smallrye.mutiny.operators.multi.processors.UnicastProcessor; +import io.smallrye.mutiny.subscription.MultiSubscriber; +import io.vertx.core.Context; + +class TokenStreamMulti extends AbstractMulti implements Multi { + private final List messagesToSend; + private final List toolSpecifications; + private final Map toolsExecutors; + private final List contents; + private final QuarkusAiServiceContext context; + private final Object memoryId; + private final boolean 
switchToWorkerThreadForToolExecution; + private final boolean isCallerRunningOnWorkerThread; + + TokenStreamMulti(List messagesToSend, List toolSpecifications, + Map toolExecutors, + List contents, QuarkusAiServiceContext context, Object memoryId, + boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) { + // We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription. + this.messagesToSend = messagesToSend; + this.toolSpecifications = toolSpecifications; + this.toolsExecutors = toolExecutors; + this.contents = contents; + this.context = context; + this.memoryId = memoryId; + this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution; + this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread; + } + + @Override + public void subscribe(MultiSubscriber subscriber) { + UnicastProcessor processor = UnicastProcessor.create(); + processor.subscribe(subscriber); + + createTokenStream(processor); + } + + private void createTokenStream(UnicastProcessor processor) { + Context ctxt = null; + if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) { + // we create or retrieve the current context, to use `executeBlocking` when required. 
+ ctxt = VertxContext.getOrCreateDuplicatedContext(); + } + + var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, + toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution, + isCallerRunningOnWorkerThread); + TokenStream tokenStream = stream + .onPartialResponse(chunk -> processor.onNext(new ChatEvent.PartialResponseEvent(chunk))) + .onCompleteResponse(message -> { + processor.onNext(new ChatEvent.ChatCompletedEvent(message)); + processor.onComplete(); + }) + .onRetrieved(content -> processor.onNext(new ChatEvent.ContentFetchedEvent(content))) + .onToolExecuted(execution -> processor.onNext(new ChatEvent.ToolExecutedEvent(execution))) + .onError(processor::onError); + // This is equivalent to "run subscription on worker thread" + if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) { + ctxt.executeBlocking(new Callable() { + @Override + public Void call() { + tokenStream.start(); + return null; + } + }); + } else { + tokenStream.start(); + } + } +} diff --git a/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassInstanceFactory b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassInstanceFactory new file mode 100644 index 000000000..9fe7a50ca --- /dev/null +++ b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassInstanceFactory @@ -0,0 +1 @@ +io.quarkiverse.langchain4j.QuarkusClassInstanceFactory diff --git a/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassMetadataProviderFactory b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassMetadataProviderFactory new file mode 100644 index 000000000..3d73afaab --- /dev/null +++ b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.classloading.ClassMetadataProviderFactory @@ -0,0 +1 @@ 
+io.quarkiverse.langchain4j.QuarkusClassMetadataProviderFactory diff --git a/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.InputGuardrailsConfigBuilderFactory b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.InputGuardrailsConfigBuilderFactory new file mode 100644 index 000000000..8ffee6cb3 --- /dev/null +++ b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.InputGuardrailsConfigBuilderFactory @@ -0,0 +1 @@ +io.quarkiverse.langchain4j.guardrails.QuarkusInputGuardrailsConfigBuilderFactory diff --git a/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.OutputGuardrailsConfigBuilderFactory b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.OutputGuardrailsConfigBuilderFactory new file mode 100644 index 000000000..f8e12e3b9 --- /dev/null +++ b/core/runtime/src/main/resources/META-INF/services/dev.langchain4j.spi.guardrail.config.OutputGuardrailsConfigBuilderFactory @@ -0,0 +1 @@ +io.quarkiverse.langchain4j.guardrails.QuarkusOutputGuardrailsConfigBuilderFactory diff --git a/docs/modules/ROOT/pages/guardrails.adoc b/docs/modules/ROOT/pages/guardrails.adoc index bf27763d1..06b8891c6 100644 --- a/docs/modules/ROOT/pages/guardrails.adoc +++ b/docs/modules/ROOT/pages/guardrails.adoc @@ -3,6 +3,89 @@ include::./includes/attributes.adoc[] include::./includes/customization.adoc[] +[IMPORTANT] +.Quarkus-specific implementation deprecation +==== +The Quarkus-specific guardrail implementation is deprecated in favor of the https://docs.langchain4j.dev/tutorials/guardrails[LangChain4j-specific implementation]. If you are currently +using the Quarkus specific guardrail implementation, you should migrate to the upstream implementation. 
+ +In most cases, switching simply means changing the package imports from `io.quarkiverse.langchain4j.guardrails` to `dev.langchain4j.guardrail` or `dev.langchain4j.service.guardrail`. +There are a few differences in the APIs, but the class/interface names are mostly the same. + +In fact, the https://docs.langchain4j.dev/tutorials/guardrails[upstream implementation] was ported from the Quarkus implementation! + +Both the Quarkus-specific implementation and the upstream implementation will be supported in Quarkus for a few releases, but eventually the Quarkus-specific implementation will be removed. + +There will also be deprecation warnings printed to the console that look similar to this: + +``` +WARN [io.qua.lan.dep.AiServicesProcessor] (build-44) +================== DEPRECATION WARNING ================== +The following Quarkus-specific output guardrail classes have been discovered on the method (java.lang.String hi(java.lang.String mem)) in the class (io.quarkiverse.langchain4j.test.guardrails.QuarkusOutputGuardrailTest$MyAiService). Please move to the new upstream guardrails. +io.quarkiverse.langchain4j.test.guardrails.QuarkusOutputGuardrailTest$OKGuardrail +``` +==== + +== Guardrail Scopes + +Guardrails MUST be CDI beans. They can be in any CDI scope, including request scope, application scope, or session scope. + +The scope of the guardrail is important as it defines the lifecycle of the guardrail, especially when the guardrail is stateful. + +== Output Guardrails configuration + +By default, Quarkus LangChain4j will limit the number of retries to `3` (the default in upstream LangChain4j is `2`). +This is configurable using the `quarkus.langchain4j.guardrails.max-retries` configuration property: + +[source,properties] +---- +quarkus.langchain4j.guardrails.max-retries=5 +---- + +NOTE: Setting `quarkus.langchain4j.guardrails.max-retries` to 0 disables retries. 
+ +Configuration can also be https://docs.langchain4j.dev/tutorials/guardrails#configuration[set in the `@OutputGuardrails` annotation directly], which will override any defaults set +for a specific operation. + +== Output Guardrails for Streamed Responses + +Output guardrails can be applied to methods that return `Multi` or `TokenStream`. +By default, Quarkus will automatically assemble the full response before executing the guardrail chain. +Keep in mind that this may have a performance impact when handling large responses. + +To control when the guardrail chain is invoked during streaming, configure an accumulator: + +[source, java] +---- +@UserMessage("...") +@OutputGuardrails(MyGuardrail.class) +@OutputGuardrailAccumulator(PassThroughAccumulator.class) // Defines the accumulator +Multi ask(); +---- + +The `@OutputGuardrailAccumulator` annotation allows you to specify a custom accumulator. +The accumulator must implement the `io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator` interface and be exposed as a CDI bean. +The following is an example of a pass-through accumulator that does not accumulate tokens: + +[source,java] +---- +@ApplicationScoped +public class PassThroughAccumulator implements OutputTokenAccumulator { + + @Override + public Multi accumulate(Multi tokens) { + return tokens; // Passes the tokens through without accumulating + } +} +---- + +You can create accumulators based on various criteria, such as the number of tokens, a specific separator, or time intervals. + +When an accumulator is set, the output guardrail chain is invoked for **each item** emitted by the `Multi` returned by the `accumulate` method. + +In the case of a retry, the accumulator is called again with the new response, restarting the stream from the beginning. The same behavior applies for reprompts. + +== Deprecated Quarkus-specific features Guardrails are mechanisms that validate the user input and the output of an LLM to ensure it meets your expectations. 
Typically, you can: @@ -16,11 +99,11 @@ If an output guardrail fails, the system may retry or reprompt to improve the re image::guardrails.png[width=600,align="center"] -== Input Guardrails +=== Input Guardrails Input guardrails are _functions_ invoked before the LLM is called. -=== Implementing Input Guardrails +==== Implementing Input Guardrails Input guardrails are implemented as CDI beans and must implement the `io.quarkiverse.langchain4j.guardrails.InputGuardrail` interface: @@ -91,7 +174,7 @@ Simple guardrails can use this method. The second one is used for more complex guardrails that need more information, like the memory or the augmentation results. For example, they can check whether enough documents are present in the augmentation results, or whether the user is repeating the same question. -==== Input Guardrails Outcome +===== Input Guardrails Outcome Input guardrails can have three outcomes: @@ -99,14 +182,14 @@ Input guardrails can have three outcomes: - _fail_ – The input is invalid, but the next guardrail is still executed to accumulate all potential validation issues. - _fatal_ - The input is invalid, the next guardrail is **not** executed, and the error is rethrown. The LLM is not called. -==== Input Guardrails Scopes +===== Input Guardrails Scopes Input guardrails are CDI beans. They can be in any CDI scope, including request scope, application scope, or session scope. The CDI scope of a guardrail defines its lifecycle, which is particularly important when the guardrail maintains state. -=== Declaring Input Guardrails +==== Declaring Input Guardrails Input guardrails are declared on the AI Service interface. You can declare input guardrails in two ways: @@ -138,7 +221,7 @@ public interface ChatBot { } ---- -==== Input Guardrail Chain +===== Input Guardrail Chain You can declare multiple guardrails. In this case, a chain is created, and the guardrails are executed in the order they are declared. 
@@ -194,11 +277,11 @@ public class VerifyHeroFormat implements InputGuardrail { } ---- -== Output Guardrails +=== Output Guardrails Output guardrails are _functions_ invoked once the LLM has produced its output. -=== Implementing Output Guardrails +==== Implementing Output Guardrails Output guardrails are implemented as CDI beans and must implement the `io.quarkiverse.langchain4j.guardrails.OutputGuardrail` interface: @@ -301,7 +384,7 @@ The second signature is used when the guardrail needs more information, like the Note that the guardrail cannot modify the memory or the augmentation results. The <<_detecting_hallucinations_in_the_rag_context>> section gives an example of guardrail using the augmented results. -==== Output Guardrails Outcome +===== Output Guardrails Outcome Output guardrails can have six outcomes: @@ -327,24 +410,7 @@ return retry("Invalid JSON"); return reprompt("Invalid JSON", "Make sure you return a valid JSON object"); ---- -By default, Quarkus Langchain4J will limit the number of retries to 3. -This is configurable using the `quarkus.langchain4j.guardrails.max-retries` configuration property: - -[source,properties] ----- -quarkus.langchain4j.guardrails.max-retries=5 ----- - -NOTE: Setting `quarkus.langchain4j.guardrails.max-retries` to 0 disables retries. - -==== Output Guardrails Scopes - -Output guardrails are CDI beans. -They can be in any CDI scope, including request scope, application scope, or session scope. - -The scope of the guardrail is important as it defines the lifecycle of the guardrail, especially when the guardrail is stateful. - -=== Declaring Output Guardrails +==== Declaring Output Guardrails Output guardrails are declared on the AI Service interface. You can declare output guardrails in two ways: @@ -375,7 +441,7 @@ public interface ChatBot { } ---- -==== Output Guardrail Chain +===== Output Guardrail Chain You can declare multiple guardrails. 
In this case, a chain is created, and the guardrails are executed in the order they are declared. @@ -408,45 +474,6 @@ Then, the `ConsistentStoryGuardrail` is executed to check that the story is cons If the `JsonGuardrail` fails, the `ConsistentStoryGuardrail` is not executed. However, if the `ConsistentStoryGuardrail` fails with a retry or reprompt, the `JsonGuardrail` is executed again with the new response. -=== Output Guardrails for Streamed Responses - -Output guardrails can be applied to methods that return `Multi`, but not to those returning `TokenStream`. -By default, Quarkus will automatically assemble the full response before executing the guardrail chain. -Keep in mind that this may have a performance impact when handling large responses. - -To control when the guardrail chain is invoked during streaming, configure an accumulator: - -[source, java] ----- -@UserMessage("...") -@OutputGuardrails(MyGuardrail.class) -@OutputGuardrailAccumulator(PassThroughAccumulator.class) // Defines the accumulator -Multi ask(); ----- - -The `@OutputGuardrailAccumulator` annotation allows you to specify a custom accumulator. -The accumulator must implement the `io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator` interface and be exposed as a CDI bean. -The following is an example of a pass-through accumulator that does not accumulate tokens: - -[source,java] ----- -@ApplicationScoped -public class PassThroughAccumulator implements OutputTokenAccumulator { - - @Override - public Multi accumulate(Multi tokens) { - return tokens; // Passes the tokens through without accumulating - } -} ----- - -You can create accumulators based on various criteria, such as the number of tokens, a specific separator, or time intervals. - -When an accumulator is set, the output guardrail chain is invoked for **each item** emitted by the `Multi` returned by the `accumulate` method. 
- -In the case of a retry, the accumulator is called again with the new response, restarting the stream from the beginning. -The same behavior applies for reprompts. - [#_detecting_hallucinations_in_the_rag_context] === Detecting Hallucinations in the RAG Context @@ -515,7 +542,7 @@ public class HallucinationGuard implements OutputGuardrail { } ---- -=== Rewriting the LLM output +==== Rewriting the LLM output It may happen that the output generated by the LLM is not completely satisfying, but it can be programmatically adjusted instead of attempting a retry or a reprompt, both implying a costly, time consuming and less reliable new interaction with the LLM. For instance, it is quite common that an LLM produces the json of the data object that it is required to extract from the user prompt, but appends some unwanted explanation to it, making the json unparsable, something like [source] @@ -629,7 +656,7 @@ public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOut } ---- -=== Unit testing +==== Unit testing Output guardrails can also be unit tested using provided AssertJ-based helpers. A set of AssertJ custom assertions (following https://assertj.github.io/doc/#assertj-core-custom-assertions-entry-point[AssertJ's custom assertions pattern]) are available to help you unit test your guardrails. @@ -856,7 +883,7 @@ class EmailContainsRequiredInformationOutputGuardrailTests { See https://github.com/quarkiverse/quarkus-langchain4j/blob/main/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java[`OutputGuardrailResultAssert.java`] for more information about the different kinds of asserts you can do. -== Summary +=== Summary Guardrails in Quarkus LangChain4j provide a declarative and extensible way to validate both input and output around LLM invocations. 
diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservability.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservability.java index 047aba125..60748108d 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservability.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservability.java @@ -6,15 +6,15 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailRequest; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailRequest; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.service.guardrail.InputGuardrails; +import dev.langchain4j.service.guardrail.OutputGuardrails; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; @Path("assistant-with-guardrails-observability") public class AssistantResourceWithGuardrailsAndObservability { @@ -51,7 +51,7 @@ public InputGuardrailResult validate(UserMessage userMessage) { @ApplicationScoped public static class IGDirectlyImplementInputGuardrailWithParams implements 
InputGuardrail { @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { return success(); } } @@ -59,7 +59,7 @@ public InputGuardrailResult validate(InputGuardrailParams params) { @ApplicationScoped public static class OGDirectlyImplementOutputGuardrailWithParams implements OutputGuardrail { @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { return success(); } } @@ -90,7 +90,7 @@ public static class IGExtendingValidateWithUserMessage extends AbstractIGImpleme public static abstract class AbstractOGImplementingValidateWithParams implements OutputGuardrail { @Override - public OutputGuardrailResult validate(OutputGuardrailParams params) { + public OutputGuardrailResult validate(OutputGuardrailRequest request) { return success(); } } @@ -104,7 +104,7 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) { public static abstract class AbstractIGImplementingValidateWithParams implements InputGuardrail { @Override - public InputGuardrailResult validate(InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrailRequest request) { return success(); } } diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservability.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservability.java new file mode 100644 index 000000000..79d3e5823 --- /dev/null +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservability.java @@ -0,0 +1,122 @@ +package org.acme.example.openai.aiservices; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import 
dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +@Path("quarkus-assistant-with-guardrails-observability") +public class QuarkusAssistantResourceWithGuardrailsAndObservability { + private final Assistant assistant; + + public QuarkusAssistantResourceWithGuardrailsAndObservability(Assistant assistant) { + this.assistant = assistant; + } + + @GET + public String assistant() { + return assistant.chat("test"); + } + + @RegisterAiService + interface Assistant { + @InputGuardrails({ IGDirectlyImplementInputGuardrailWithParams.class, + IGDirectlyImplementInputGuardrailWithUserMessage.class, IGExtendingValidateWithParams.class, + IGExtendingValidateWithUserMessage.class }) + @OutputGuardrails({ OGDirectlyImplementOutputGuardrailWithParams.class, + OGDirectlyImplementOutputGuardrailWithAiMessage.class, OGExtendingValidateWithParams.class, + OGExtendingValidateWithAiMessage.class }) + String chat(String message); + } + + @ApplicationScoped + public static class IGDirectlyImplementInputGuardrailWithUserMessage implements InputGuardrail { + @Override + public InputGuardrailResult validate(UserMessage userMessage) { + return success(); + } + } + + @ApplicationScoped + public static class 
IGDirectlyImplementInputGuardrailWithParams implements InputGuardrail { + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + return success(); + } + } + + @ApplicationScoped + public static class OGDirectlyImplementOutputGuardrailWithParams implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + return success(); + } + } + + @ApplicationScoped + public static class OGDirectlyImplementOutputGuardrailWithAiMessage implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + } + + @ApplicationScoped + public static class OGExtendingValidateWithParams extends AbstractOGImplementingValidateWithParams { + } + + @ApplicationScoped + public static class OGExtendingValidateWithAiMessage extends AbstractOGImplementingValidateWithAiMessage { + } + + @ApplicationScoped + public static class IGExtendingValidateWithParams extends AbstractIGImplementingValidateWithParams { + } + + @ApplicationScoped + public static class IGExtendingValidateWithUserMessage extends AbstractIGImplementingValidateWithUserMessage { + } + + public static abstract class AbstractOGImplementingValidateWithParams implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(OutputGuardrailParams params) { + return success(); + } + } + + public static abstract class AbstractOGImplementingValidateWithAiMessage implements OutputGuardrail { + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); + } + } + + public static abstract class AbstractIGImplementingValidateWithParams implements InputGuardrail { + @Override + public InputGuardrailResult validate(InputGuardrailParams params) { + return success(); + } + } + + public static abstract class AbstractIGImplementingValidateWithUserMessage implements InputGuardrail { + @Override + public InputGuardrailResult validate(UserMessage 
userMessage) { + return success(); + } + } +} diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservabilityTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservabilityTest.java index c5bbcad6f..6443aea7b 100644 --- a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservabilityTest.java +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithGuardrailsAndObservabilityTest.java @@ -35,11 +35,12 @@ class AssistantResourceWithGuardrailsAndObservabilityTest { @BeforeAll static void addSimpleRegistry() { Metrics.globalRegistry.add(new SimpleMeterRegistry()); + Metrics.globalRegistry.clear(); } @Test void guardrailMetricsAvailable() { - get("/assistant-with-guardrails-observability").then() + get("assistant-with-guardrails-observability").then() .statusCode(200) .body(TestUtils.containsStringOrMock("test")); diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservabilityTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservabilityTest.java new file mode 100644 index 000000000..11547c514 --- /dev/null +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/QuarkusAssistantResourceWithGuardrailsAndObservabilityTest.java @@ -0,0 +1,124 @@ +package org.acme.example.openai.aiservices; + +import static io.restassured.RestAssured.get; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.hamcrest.Matchers.containsString; + +import java.time.Duration; + +import jakarta.inject.Inject; + +import 
org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.AbstractIGImplementingValidateWithParams; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.AbstractIGImplementingValidateWithUserMessage; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.AbstractOGImplementingValidateWithAiMessage; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.AbstractOGImplementingValidateWithParams; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.IGDirectlyImplementInputGuardrailWithParams; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.IGDirectlyImplementInputGuardrailWithUserMessage; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.OGDirectlyImplementOutputGuardrailWithAiMessage; +import org.acme.example.openai.aiservices.QuarkusAssistantResourceWithGuardrailsAndObservability.OGDirectlyImplementOutputGuardrailWithParams; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.quarkus.test.junit.QuarkusTest; + +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) +@QuarkusTest +class QuarkusAssistantResourceWithGuardrailsAndObservabilityTest { + @Inject + MeterRegistry registry; + + @BeforeAll + static void addSimpleRegistry() { + Metrics.globalRegistry.add(new SimpleMeterRegistry()); + Metrics.globalRegistry.clear(); + } + + @Test + void guardrailMetricsAvailable() { + 
get("/quarkus-assistant-with-guardrails-observability").then() + .statusCode(200) + .body(containsString("MockGPT")); + + await() + .atMost(Duration.ofSeconds(10)) + .untilAsserted(() -> assertThat( + registry + .find("guardrail.invoked") + .tag("method", "validate") + .counters()) + .hasSize(8) + .allSatisfy(c -> assertThat(c) + .isNotNull() + .extracting( + Counter::count, + ct -> ct.getId().getDescription()) + .containsExactly( + 1.0, "Measures the number of times this guardrail was invoked")) + .map(c -> c.getId().getTag("class")) + .containsOnlyOnce(AbstractIGImplementingValidateWithUserMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithParams.class.getName()) + .containsOnlyOnce(AbstractIGImplementingValidateWithParams.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithAiMessage.class.getName()) + .containsOnlyOnce(AbstractOGImplementingValidateWithParams.class.getName()) + .containsOnlyOnce(AbstractOGImplementingValidateWithAiMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithUserMessage.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithParams.class.getName())); + + await() + .atMost(Duration.ofSeconds(10)) + .untilAsserted(() -> assertThat( + registry + .find("guardrail.timed") + .tag("method", "validate") + .timers()) + .hasSize(8) + .allSatisfy(t -> assertThat(t) + .isNotNull() + .extracting( + Timer::count, + ti -> ti.getId().getDescription()) + .containsExactly(1L, "Measures the runtime of this guardrail")) + .map(c -> c.getId().getTag("class")) + .containsOnlyOnce(AbstractIGImplementingValidateWithUserMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithParams.class.getName()) + .containsOnlyOnce(AbstractIGImplementingValidateWithParams.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithAiMessage.class.getName()) + .containsOnlyOnce(AbstractOGImplementingValidateWithParams.class.getName()) + 
.containsOnlyOnce(AbstractOGImplementingValidateWithAiMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithUserMessage.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithParams.class.getName())); + + await() + .atMost(Duration.ofSeconds(10)) + .untilAsserted(() -> assertThat( + registry + .find("guardrail.timed") + .tag("method", "validate") + .timers()) + .hasSize(8) + .allSatisfy(t -> assertThat(t) + .isNotNull() + .extracting( + Timer::count, + ti -> ti.getId().getDescription()) + .containsExactly(1L, "Measures the runtime of this guardrail")) + .map(c -> c.getId().getTag("class")) + .containsOnlyOnce(AbstractIGImplementingValidateWithUserMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithParams.class.getName()) + .containsOnlyOnce(AbstractIGImplementingValidateWithParams.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithAiMessage.class.getName()) + .containsOnlyOnce(AbstractOGImplementingValidateWithParams.class.getName()) + .containsOnlyOnce(AbstractOGImplementingValidateWithAiMessage.class.getName()) + .containsOnlyOnce(IGDirectlyImplementInputGuardrailWithUserMessage.class.getName()) + .containsOnlyOnce(OGDirectlyImplementOutputGuardrailWithParams.class.getName())); + } +} diff --git a/model-providers/openai/openai-vanilla/deployment/pom.xml b/model-providers/openai/openai-vanilla/deployment/pom.xml index 1c2e22466..db56852b6 100644 --- a/model-providers/openai/openai-vanilla/deployment/pom.xml +++ b/model-providers/openai/openai-vanilla/deployment/pom.xml @@ -63,6 +63,11 @@ ${project.version} test + + dev.langchain4j + langchain4j-test + test + diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingTests.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingTests.java index b95b8d6d0..ba0afc906 100644 --- 
a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingTests.java +++ b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingTests.java @@ -3,7 +3,7 @@ import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat; +import static dev.langchain4j.test.guardrail.GuardrailAssertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.atIndex; @@ -30,31 +30,32 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.guardrail.GuardrailException; +import dev.langchain4j.guardrail.InputGuardrail; +import dev.langchain4j.guardrail.InputGuardrailResult; +import dev.langchain4j.guardrail.OutputGuardrail; +import dev.langchain4j.guardrail.OutputGuardrailResult; +import dev.langchain4j.service.guardrail.InputGuardrails; +import dev.langchain4j.service.guardrail.OutputGuardrails; +import dev.langchain4j.test.guardrail.GuardrailAssertions; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditSourceInfo; import io.quarkiverse.langchain4j.audit.InitialMessagesCreatedEvent; -import io.quarkiverse.langchain4j.audit.InputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.LLMInteractionCompleteEvent; import io.quarkiverse.langchain4j.audit.LLMInteractionEvent; import io.quarkiverse.langchain4j.audit.LLMInteractionFailureEvent; -import io.quarkiverse.langchain4j.audit.OutputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.ResponseFromLLMReceivedEvent; import 
io.quarkiverse.langchain4j.audit.ToolExecutedEvent; +import io.quarkiverse.langchain4j.audit.guardrails.InputGuardrailExecutedEvent; +import io.quarkiverse.langchain4j.audit.guardrails.OutputGuardrailExecutedEvent; +import io.quarkiverse.langchain4j.audit.guardrails.internal.DefaultInputGuardrailExecutedEvent; +import io.quarkiverse.langchain4j.audit.guardrails.internal.DefaultOutputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.internal.DefaultInitialMessagesCreatedEvent; -import io.quarkiverse.langchain4j.audit.internal.DefaultInputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.internal.DefaultLLMInteractionCompleteEvent; import io.quarkiverse.langchain4j.audit.internal.DefaultLLMInteractionFailureEvent; -import io.quarkiverse.langchain4j.audit.internal.DefaultOutputGuardrailExecutedEvent; import io.quarkiverse.langchain4j.audit.internal.DefaultResponseFromLLMReceivedEvent; import io.quarkiverse.langchain4j.audit.internal.DefaultToolExecutedEvent; -import io.quarkiverse.langchain4j.guardrails.InputGuardrail; -import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.InputGuardrails; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest; -import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.testing.internal.WiremockAware; import io.quarkus.logging.Log; import io.quarkus.test.QuarkusUnitTest; @@ -85,90 +86,117 @@ class AuditingTests extends OpenAiBaseTest { OutputGuardrailAuditor outputGuardrailAuditor; @Test - void should_audit_input_guardrail_events() { + void should_audit_quarkus_input_guardrail_events() { setupWiremock(); - assertThatExceptionOfType(GuardrailException.class) - .isThrownBy(() -> 
assistant.chatWithInputGuardrails(USER_MESSAGE)) + assertThatExceptionOfType(io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException.class) + .isThrownBy(() -> assistant.chatWithQuarkusInputGuardrails(USER_MESSAGE)) .withMessage("The guardrail %s failed with this message: User message is not valid", - FailureInputGuardrail.class.getName()); + FailureQuarkusInputGuardrail.class.getName()); - assertThat(inputGuardrailAuditor.inputGuardrailExecutedEvents) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions + .assertThat(inputGuardrailAuditor.quarkusInputGuardrailExecutedEvents) .hasSize(2) .satisfies(inputGuardrailExecutedEvent -> { - assertThat(inputGuardrailExecutedEvent) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(inputGuardrailExecutedEvent) .extracting( e -> e.sourceInfo().methodName(), e -> e.params().userMessage().singleText(), e -> e.rewrittenUserMessage().singleText(), - InputGuardrailExecutedEvent::guardrailClass) + io.quarkiverse.langchain4j.audit.InputGuardrailExecutedEvent::guardrailClass) .containsExactly( - "chatWithInputGuardrails", + "chatWithQuarkusInputGuardrails", USER_MESSAGE, "Success!!", - SuccessInputGuardrail.class); + SuccessQuarkusInputGuardrail.class); - assertThat(inputGuardrailExecutedEvent.result()).isSuccessful(); + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(inputGuardrailExecutedEvent.result()) + .isSuccessful(); }, atIndex(0)) .satisfies(inputGuardrailExecutedEvent -> { - assertThat(inputGuardrailExecutedEvent) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(inputGuardrailExecutedEvent) .extracting( e -> e.sourceInfo().methodName(), e -> e.params().userMessage().singleText(), e -> e.rewrittenUserMessage().singleText(), - InputGuardrailExecutedEvent::guardrailClass) + io.quarkiverse.langchain4j.audit.InputGuardrailExecutedEvent::guardrailClass) .containsExactly( - "chatWithInputGuardrails", + "chatWithQuarkusInputGuardrails", "Success!!", 
"Success!!", - FailureInputGuardrail.class); + FailureQuarkusInputGuardrail.class); - assertThat(inputGuardrailExecutedEvent.result()) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(inputGuardrailExecutedEvent.result()) .hasSingleFailureWithMessage("User message is not valid"); }, atIndex(1)); } @Test - void should_audit_output_guardrail_events() { + void should_audit_quarkus_output_guardrail_events() { setupWiremock(); - assertThatExceptionOfType(GuardrailException.class) - .isThrownBy(() -> assistant.chatWithOutputGuardrails(USER_MESSAGE)) + assertThatExceptionOfType(io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException.class) + .isThrownBy(() -> assistant.chatWithQuarkusOutputGuardrails(USER_MESSAGE)) .withMessage("The guardrail %s failed with this message: LLM response is not valid", - FailureOutputGuardrail.class.getName()); + FailureQuarkusOutputGuardrail.class.getName()); - assertThat(outputGuardrailAuditor.outputGuardrailExecutedEvents) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions + .assertThat(outputGuardrailAuditor.quarkusOutputGuardrailExecutedEvents) .hasSize(2) .satisfies(outputGuardrailExecutedEvent -> { - assertThat(outputGuardrailExecutedEvent) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(outputGuardrailExecutedEvent) .extracting( e -> e.sourceInfo().methodName(), e -> e.params().responseFromLLM().text(), - OutputGuardrailExecutedEvent::guardrailClass) + io.quarkiverse.langchain4j.audit.OutputGuardrailExecutedEvent::guardrailClass) .containsExactly( - "chatWithOutputGuardrails", + "chatWithQuarkusOutputGuardrails", EXPECTED_RESPONSE, - SuccessOutputGuardrail.class); + SuccessQuarkusOutputGuardrail.class); - assertThat(outputGuardrailExecutedEvent.result()) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(outputGuardrailExecutedEvent.result()) .hasSuccess("Success!!", "Success!!"); }, atIndex(0)) .satisfies(outputGuardrailExecutedEvent -> { - 
assertThat(outputGuardrailExecutedEvent) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(outputGuardrailExecutedEvent) .extracting( e -> e.sourceInfo().methodName(), e -> e.params().responseFromLLM().text(), - OutputGuardrailExecutedEvent::guardrailClass) + io.quarkiverse.langchain4j.audit.OutputGuardrailExecutedEvent::guardrailClass) .containsExactly( - "chatWithOutputGuardrails", + "chatWithQuarkusOutputGuardrails", "Success!!", - FailureOutputGuardrail.class); + FailureQuarkusOutputGuardrail.class); - assertThat(outputGuardrailExecutedEvent.result()) + io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat(outputGuardrailExecutedEvent.result()) .hasSingleFailureWithMessage("LLM response is not valid"); }, atIndex(1)); } + @Test + void should_audit_input_guardrail_events() { + setupWiremock(); + + assertThatExceptionOfType(GuardrailException.class) + .isThrownBy(() -> assistant.chatWithInputGuardrails(USER_MESSAGE)) + .withMessage("The guardrail %s failed with this message: User message is not valid", + FailureInputGuardrail.class.getName()); + + assertThat(inputGuardrailAuditor.inputGuardrailExecutedEvents).isEmpty(); + } + + @Test + void should_audit_output_guardrail_events() { + setupWiremock(); + + assertThatExceptionOfType(GuardrailException.class) + .isThrownBy(() -> assistant.chatWithOutputGuardrails(USER_MESSAGE)) + .withMessage("The guardrail %s failed with this message: LLM response is not valid", + FailureOutputGuardrail.class.getName()); + + assertThat(outputGuardrailAuditor.outputGuardrailExecutedEvents).isEmpty(); + } + @Test void should_execute_tool_then_answer() throws IOException { setupWiremock(); @@ -176,9 +204,12 @@ void should_execute_tool_then_answer() throws IOException { var answer = assistant.chat(USER_MESSAGE); assertThat(answer).isEqualTo(EXPECTED_RESPONSE); - assertThat(wiremock().getServeEvents()).hasSize(2); - 
assertMultipleRequestMessage(getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(0))), + var numServeEvents = wiremock().getServeEvents().size(); + assertThat(numServeEvents).isGreaterThanOrEqualTo(2); + + assertMultipleRequestMessage( + getRequestAsMap(getRequestBody(wiremock().getServeEvents().get((numServeEvents == 2) ? 0 : 2))), List.of( new MessageContent("system", "You are a chat bot that answers questions"), new MessageContent("user", @@ -278,6 +309,38 @@ private void setupWiremock() { wiremock().setSingleScenarioState(SCENARIO, Scenario.STARTED); } + @Singleton + static class SuccessQuarkusInputGuardrail implements io.quarkiverse.langchain4j.guardrails.InputGuardrail { + @Override + public io.quarkiverse.langchain4j.guardrails.InputGuardrailResult validate(UserMessage userMessage) { + return successWith("Success!!"); + } + } + + @Singleton + static class FailureQuarkusInputGuardrail implements io.quarkiverse.langchain4j.guardrails.InputGuardrail { + @Override + public io.quarkiverse.langchain4j.guardrails.InputGuardrailResult validate(UserMessage userMessage) { + return failure("User message is not valid"); + } + } + + @Singleton + static class SuccessQuarkusOutputGuardrail implements io.quarkiverse.langchain4j.guardrails.OutputGuardrail { + @Override + public io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult validate(AiMessage responseFromLLM) { + return successWith("Success!!"); + } + } + + @Singleton + static class FailureQuarkusOutputGuardrail implements io.quarkiverse.langchain4j.guardrails.OutputGuardrail { + @Override + public io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult validate(AiMessage responseFromLLM) { + return failure("LLM response is not valid"); + } + } + @Singleton static class SuccessInputGuardrail implements InputGuardrail { @Override @@ -331,6 +394,14 @@ public void run() { interface Assistant { String chat(String message); + @io.quarkiverse.langchain4j.guardrails.InputGuardrails({ 
SuccessQuarkusInputGuardrail.class, + FailureQuarkusInputGuardrail.class }) + String chatWithQuarkusInputGuardrails(String message); + + @io.quarkiverse.langchain4j.guardrails.OutputGuardrails({ SuccessQuarkusOutputGuardrail.class, + FailureQuarkusOutputGuardrail.class }) + String chatWithQuarkusOutputGuardrails(String message); + @InputGuardrails({ SuccessInputGuardrail.class, FailureInputGuardrail.class }) String chatWithInputGuardrails(String message); @@ -456,15 +527,36 @@ private static boolean captureEvent(AuditSourceInfo sourceInfo) { @Singleton static class InputGuardrailAuditor { + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + List quarkusInputGuardrailExecutedEvents = new ArrayList<>(); List inputGuardrailExecutedEvents = new ArrayList<>(); + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public void inputGuardrailExecuted( + @Observes io.quarkiverse.langchain4j.audit.InputGuardrailExecutedEvent inputGuardrailExecutedEvent) { + assertThat(inputGuardrailExecutedEvent) + .isNotNull() + .isExactlyInstanceOf(io.quarkiverse.langchain4j.audit.internal.DefaultInputGuardrailExecutedEvent.class); + handle(inputGuardrailExecutedEvent); + + if ("chatWithQuarkusInputGuardrails".equals(inputGuardrailExecutedEvent.sourceInfo().methodName())) { + this.quarkusInputGuardrailExecutedEvents.add(inputGuardrailExecutedEvent); + } + } + public void inputGuardrailExecuted(@Observes InputGuardrailExecutedEvent inputGuardrailExecutedEvent) { assertThat(inputGuardrailExecutedEvent) .isNotNull() .isExactlyInstanceOf(DefaultInputGuardrailExecutedEvent.class); handle(inputGuardrailExecutedEvent); - if ("chatWithInputGuardrails".equals(inputGuardrailExecutedEvent.sourceInfo().methodName())) { + if 
("chatWithQuarkusInputGuardrails".equals(inputGuardrailExecutedEvent.sourceInfo().methodName())) { this.inputGuardrailExecutedEvents.add(inputGuardrailExecutedEvent); } } @@ -476,15 +568,36 @@ private static void handle(LLMInteractionEvent llmInteractionEvent) { @Singleton static class OutputGuardrailAuditor { + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + List quarkusOutputGuardrailExecutedEvents = new ArrayList<>(); List outputGuardrailExecutedEvents = new ArrayList<>(); + /** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ + @Deprecated(forRemoval = true) + public void outputGuardrailExecuted( + @Observes io.quarkiverse.langchain4j.audit.OutputGuardrailExecutedEvent outputGuardrailExecutedEvent) { + assertThat(outputGuardrailExecutedEvent) + .isNotNull() + .isExactlyInstanceOf(io.quarkiverse.langchain4j.audit.internal.DefaultOutputGuardrailExecutedEvent.class); + handle(outputGuardrailExecutedEvent); + + if ("chatWithQuarkusOutputGuardrails".equals(outputGuardrailExecutedEvent.sourceInfo().methodName())) { + this.quarkusOutputGuardrailExecutedEvents.add(outputGuardrailExecutedEvent); + } + } + public void outputGuardrailExecuted(@Observes OutputGuardrailExecutedEvent outputGuardrailExecutedEvent) { assertThat(outputGuardrailExecutedEvent) .isNotNull() .isExactlyInstanceOf(DefaultOutputGuardrailExecutedEvent.class); handle(outputGuardrailExecutedEvent); - if ("chatWithOutputGuardrails".equals(outputGuardrailExecutedEvent.sourceInfo().methodName())) { + if ("chatWithQuarkusOutputGuardrails".equals(outputGuardrailExecutedEvent.sourceInfo().methodName())) { this.outputGuardrailExecutedEvents.add(outputGuardrailExecutedEvent); } } diff --git a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailAssertions.java 
b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailAssertions.java index 31f897780..f1f58c0f2 100644 --- a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailAssertions.java +++ b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailAssertions.java @@ -7,7 +7,10 @@ *

* This follows the pattern described in https://assertj.github.io/doc/#assertj-core-custom-assertions-entry-point *

+ * + * @deprecated Use {@link dev.langchain4j.test.guardrail.GuardrailAssertions} instead */ +@Deprecated(forRemoval = true) public class GuardrailAssertions extends Assertions { /** * Returns an {@link OutputGuardrailResultAssert} for assertions on an {@link OutputGuardrailResult} diff --git a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResultAssert.java b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResultAssert.java index 2538fadde..38da84764 100644 --- a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResultAssert.java +++ b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResultAssert.java @@ -22,7 +22,9 @@ * @param The type of {@link GuardrailResultAssert} * @param The type of {@link GuardrailResult} * @param The type of {@link Failure} + * @deprecated Use {@link dev.langchain4j.test.guardrail.GuardrailResultAssert} instead */ +@Deprecated(forRemoval = true) public sealed abstract class GuardrailResultAssert, R extends GuardrailResult, F extends Failure> extends AbstractObjectAssert permits InputGuardrailResultAssert, OutputGuardrailResultAssert { diff --git a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResultAssert.java b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResultAssert.java index b071e55dc..681971922 100644 --- a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResultAssert.java +++ b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResultAssert.java @@ -7,7 +7,10 @@ *

* This follows the pattern described in https://assertj.github.io/doc/#assertj-core-custom-assertions-creation *

+ * + * @deprecated Use {@link dev.langchain4j.test.guardrail.InputGuardrailResultAssert} instead */ +@Deprecated(forRemoval = true) public final class InputGuardrailResultAssert extends GuardrailResultAssert { diff --git a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java index 81b38804f..d058700cc 100644 --- a/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java +++ b/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java @@ -7,7 +7,10 @@ *

* This follows the pattern described in https://assertj.github.io/doc/#assertj-core-custom-assertions-creation *

+ * + * @deprecated Use {@link dev.langchain4j.test.guardrail.OutputGuardrailResultAssert} instead */ +@Deprecated(forRemoval = true) public final class OutputGuardrailResultAssert extends GuardrailResultAssert { diff --git a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java index 247980f0b..140eadea8 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java @@ -4,6 +4,10 @@ import dev.langchain4j.data.message.AiMessage; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) public class EmailContainsRequiredInformationOutputGuardrail implements OutputGuardrail { static final String NO_RESPONSE_MESSAGE = "No response found"; static final String NO_RESPONSE_PROMPT = "The response was empty. 
Please try again."; diff --git a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java index bd586550d..d36cf8b61 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java @@ -17,6 +17,10 @@ import io.quarkiverse.langchain4j.guardrails.GuardrailResult.Result; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.Failure; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) class EmailContainsRequiredInformationOutputGuardrailTests { private static final String CLAIM_NUMBER = "CLM195501"; private static final String CLAIM_STATUS = "denied"; diff --git a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrail.java b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrail.java index 5f84e5ceb..a7ef89acf 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrail.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrail.java @@ -2,6 +2,10 @@ import dev.langchain4j.data.message.AiMessage; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) public class EmailEndsAppropriatelyOutputGuardrail implements OutputGuardrail { static final String REPROMPT_MESSAGE = "Invalid email"; static final String EMAIL_ENDING = """ diff --git 
a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrailTests.java b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrailTests.java index 5c4902c79..ad156d2c6 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrailTests.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailEndsAppropriatelyOutputGuardrailTests.java @@ -10,6 +10,10 @@ import dev.langchain4j.data.message.AiMessage; import io.quarkiverse.langchain4j.guardrails.GuardrailResult.Result; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) class EmailEndsAppropriatelyOutputGuardrailTests { EmailEndsAppropriatelyOutputGuardrail guardrail = spy(new EmailEndsAppropriatelyOutputGuardrail()); diff --git a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrail.java b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrail.java index a3a073395..d34165daa 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrail.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrail.java @@ -2,6 +2,10 @@ import dev.langchain4j.data.message.AiMessage; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) public class EmailStartsAppropriatelyOutputGuardrail implements OutputGuardrail { static final String REPROMPT_MESSAGE = "Invalid email"; static final String REPROMPT_PROMPT = """ diff --git a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrailTests.java 
b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrailTests.java index 04c3c33fa..d2888e3be 100644 --- a/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrailTests.java +++ b/testing/core/src/test/java/io/quarkiverse/langchain4j/guardrails/EmailStartsAppropriatelyOutputGuardrailTests.java @@ -10,6 +10,10 @@ import dev.langchain4j.data.message.AiMessage; import io.quarkiverse.langchain4j.guardrails.GuardrailResult.Result; +/** + * @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed + */ +@Deprecated(forRemoval = true) class EmailStartsAppropriatelyOutputGuardrailTests { EmailStartsAppropriatelyOutputGuardrail guardrail = spy(new EmailStartsAppropriatelyOutputGuardrail());