diff --git a/backend/src/main/java/ai/giskard/domain/Callable.java b/backend/src/main/java/ai/giskard/domain/Callable.java index 530ca26944..9ceca7622e 100644 --- a/backend/src/main/java/ai/giskard/domain/Callable.java +++ b/backend/src/main/java/ai/giskard/domain/Callable.java @@ -1,10 +1,10 @@ package ai.giskard.domain; import ai.giskard.utils.SimpleJSONStringAttributeConverter; +import jakarta.persistence.*; import lombok.Getter; import lombok.Setter; -import jakarta.persistence.*; import java.io.Serializable; import java.util.List; import java.util.UUID; @@ -13,9 +13,6 @@ @Setter @Inheritance(strategy = InheritanceType.SINGLE_TABLE) @Entity(name = "callable_functions") -@Table(uniqueConstraints = { - @UniqueConstraint(columnNames = {"name", "module", "version"}) -}) @DiscriminatorColumn(name = "callable_type", discriminatorType = DiscriminatorType.STRING) public class Callable implements Serializable { @Id diff --git a/backend/src/main/java/ai/giskard/domain/ml/CallToActionKind.java b/backend/src/main/java/ai/giskard/domain/ml/CallToActionKind.java new file mode 100644 index 0000000000..8eb91b7e1d --- /dev/null +++ b/backend/src/main/java/ai/giskard/domain/ml/CallToActionKind.java @@ -0,0 +1,25 @@ +package ai.giskard.domain.ml; + +import com.dataiku.j2ts.annotations.UIModel; +import com.fasterxml.jackson.annotation.JsonValue; + +@UIModel +public enum CallToActionKind { + NONE, + CREATE_SLICE, + CREATE_TEST, + CREATE_PERTURBATION, + SAVE_PERTURBATION, + CREATE_ROBUSTNESS_TEST, + CREATE_SLICE_OPEN_DEBUGGER, + OPEN_DEBUGGER_BORDERLINE, + ADD_TEST_TO_CATALOG, + SAVE_EXAMPLE, + OPEN_DEBUGGER_OVERCONFIDENCE, + CREATE_UNIT_TEST; + + @JsonValue + public int toValue() { + return ordinal(); + } +} diff --git a/backend/src/main/java/ai/giskard/domain/ml/PushKind.java b/backend/src/main/java/ai/giskard/domain/ml/PushKind.java new file mode 100644 index 0000000000..b6c8ddc5a8 --- /dev/null +++ b/backend/src/main/java/ai/giskard/domain/ml/PushKind.java @@ -0,0 +1,18 @@ +package ai.giskard.domain.ml; + +import com.dataiku.j2ts.annotations.UIModel; +import com.fasterxml.jackson.annotation.JsonValue; + +@UIModel +public enum PushKind { + INVALID, + PERTURBATION, + CONTRIBUTION, + OVERCONFIDENCE, + BORDERLINE; + + @JsonValue + public int toValue() { + return ordinal(); + } +} diff --git a/backend/src/main/java/ai/giskard/domain/ml/SuiteTestExecution.java b/backend/src/main/java/ai/giskard/domain/ml/SuiteTestExecution.java index 3b41666c04..6afb5d53a1 100644 --- a/backend/src/main/java/ai/giskard/domain/ml/SuiteTestExecution.java +++ b/backend/src/main/java/ai/giskard/domain/ml/SuiteTestExecution.java @@ -7,11 +7,11 @@ import ai.giskard.web.dto.ml.TestResultMessageDTO; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.persistence.*; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import jakarta.persistence.*; import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/backend/src/main/java/ai/giskard/ml/MLWorkerWSAction.java b/backend/src/main/java/ai/giskard/ml/MLWorkerWSAction.java index 63fc49e819..ee1eb53279 100644 --- a/backend/src/main/java/ai/giskard/ml/MLWorkerWSAction.java +++ b/backend/src/main/java/ai/giskard/ml/MLWorkerWSAction.java @@ -13,9 +13,11 @@ public enum MLWorkerWSAction { GENERATE_TEST_SUITE("generateTestSuite"), STOP_WORKER("stopWorker"), GET_CATALOG("getCatalog"), - GENERATE_QUERY_BASED_SLICING_FUNCTION("generateQueryBasedSlicingFunction"); + GENERATE_QUERY_BASED_SLICING_FUNCTION("generateQueryBasedSlicingFunction"), + GET_PUSH("getPush"); private final String actionName; + MLWorkerWSAction(String name) { actionName = name; } diff --git a/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushDTO.java b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushDTO.java new file mode 100644 index 0000000000..71c397cbc9 --- /dev/null +++ b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushDTO.java @@ -0,0 +1,35 @@ +package ai.giskard.ml.dto; + +import ai.giskard.domain.ml.CallToActionKind; +import ai.giskard.domain.ml.PushKind; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + +import javax.annotation.Nullable; +import java.util.Map; + +@Getter +@Setter +@Builder +public class MLWorkerWSGetPushDTO implements MLWorkerWSBaseDTO { + private MLWorkerWSArtifactRefDTO dataset; + private MLWorkerWSArtifactRefDTO model; + private int rowIdx; + private MLWorkerWSDataFrameDTO dataframe; + + private String target; + + @JsonProperty("column_types") + private Map columnTypes; + + @JsonProperty("column_dtypes") + private Map columnDtypes; + @Nullable + @JsonProperty("push_kind") + private PushKind pushKind; + @Nullable + @JsonProperty("cta_kind") + private CallToActionKind ctaKind; +} diff --git a/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushResultDTO.java b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushResultDTO.java new file mode 100644 index 0000000000..1fed7e79cf --- /dev/null +++ b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSGetPushResultDTO.java @@ -0,0 +1,21 @@ +package ai.giskard.ml.dto; + +import lombok.Getter; +import lombok.Setter; + +import javax.annotation.Nullable; + +@Getter +@Setter +public class MLWorkerWSGetPushResultDTO implements MLWorkerWSBaseDTO { + @Nullable + private MLWorkerWSPushDTO perturbation; + @Nullable + private MLWorkerWSPushDTO contribution; + @Nullable + private MLWorkerWSPushDTO borderline; + @Nullable + private MLWorkerWSPushDTO overconfidence; + @Nullable + private MLWorkerWSPushActionDTO action; +} diff --git a/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushActionDTO.java b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushActionDTO.java new file mode 100644 index 0000000000..0a79ac69fe --- /dev/null +++ b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushActionDTO.java @@ -0,0 +1,17 @@ +package ai.giskard.ml.dto; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +import java.util.List; + +@Getter +@Setter +public class MLWorkerWSPushActionDTO implements MLWorkerWSBaseDTO { + + @JsonProperty("object_uuid") + private String objectUuid; + + private List arguments; +} diff --git a/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDTO.java b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDTO.java new file mode 100644 index 0000000000..26cd761cbd --- /dev/null +++ b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDTO.java @@ -0,0 +1,24 @@ +package ai.giskard.ml.dto; + +import ai.giskard.domain.ml.PushKind; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +import javax.annotation.Nullable; +import java.util.List; + +@Getter +@Setter +public class MLWorkerWSPushDTO implements MLWorkerWSBaseDTO { + private PushKind kind; + @Nullable + private String key; + @Nullable + private String value; + @JsonProperty("push_title") + private String pushTitle; + + @JsonProperty("push_details") + private List pushDetails; +} diff --git a/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDetailsDTO.java b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDetailsDTO.java new file mode 100644 index 0000000000..bd39367299 --- /dev/null +++ b/backend/src/main/java/ai/giskard/ml/dto/MLWorkerWSPushDetailsDTO.java @@ -0,0 +1,14 @@ +package ai.giskard.ml.dto; + +import ai.giskard.domain.ml.CallToActionKind; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class MLWorkerWSPushDetailsDTO implements MLWorkerWSBaseDTO { + private String action; + private String explanation; + private String button; + private CallToActionKind cta; +} diff --git a/backend/src/main/java/ai/giskard/repository/ml/CallableRepository.java b/backend/src/main/java/ai/giskard/repository/ml/CallableRepository.java index 5dc682037d..965a3e344e 100644 --- a/backend/src/main/java/ai/giskard/repository/ml/CallableRepository.java +++ b/backend/src/main/java/ai/giskard/repository/ml/CallableRepository.java @@ -9,6 +9,6 @@ @NoRepositoryBean public interface CallableRepository extends MappableJpaRepository { - int countByNameAndModule(String name, String module); + int countByDisplayName(String displayName); } diff --git a/backend/src/main/java/ai/giskard/service/SlicingFunctionService.java b/backend/src/main/java/ai/giskard/service/SlicingFunctionService.java index 458e93f42b..9a52583f42 100644 --- a/backend/src/main/java/ai/giskard/service/SlicingFunctionService.java +++ b/backend/src/main/java/ai/giskard/service/SlicingFunctionService.java @@ -9,6 +9,7 @@ import ai.giskard.web.dto.SlicingFunctionDTO; import ai.giskard.web.dto.mapper.GiskardMapper; import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.logging.log4j.util.Strings; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -42,10 +43,16 @@ public SlicingFunctionDTO save(SlicingFunctionDTO slicingFunction) { protected SlicingFunction create(SlicingFunctionDTO dto) { SlicingFunction function = giskardMapper.fromDTO(dto); function.setProjectKey(dto.getProjectKey()); + if (function.getArgs() != null) { function.getArgs().forEach(arg -> arg.setFunction(function)); } - function.setVersion(slicingFunctionRepository.countByNameAndModule(function.getName(), function.getModule()) + 1); + + if (Strings.isBlank(function.getDisplayName())) { + function.setDisplayName(function.getModule() + "." + function.getName()); + } + + function.setVersion(slicingFunctionRepository.countByDisplayName(function.getDisplayName()) + 1); return function; } @@ -82,7 +89,7 @@ public SlicingFunctionDTO generate(List comparisonClauses) slicingFunction.setModuleDoc(""); slicingFunction.setName(name); slicingFunction.setTags(List.of("pickle", "ui")); - slicingFunction.setVersion(slicingFunctionRepository.countByNameAndModule(slicingFunction.getName(), slicingFunction.getModule()) + 1); + slicingFunction.setVersion(slicingFunctionRepository.countByDisplayName(name) + 1); slicingFunction.setCellLevel(false); slicingFunction.setColumnType(""); slicingFunction.setProcessType(DatasetProcessFunctionType.CLAUSES); diff --git a/backend/src/main/java/ai/giskard/service/TestFunctionService.java b/backend/src/main/java/ai/giskard/service/TestFunctionService.java index 89ac61dd9f..fea39b9637 100644 --- a/backend/src/main/java/ai/giskard/service/TestFunctionService.java +++ b/backend/src/main/java/ai/giskard/service/TestFunctionService.java @@ -4,6 +4,7 @@ import ai.giskard.repository.ml.TestFunctionRepository; import ai.giskard.web.dto.TestFunctionDTO; import ai.giskard.web.dto.mapper.GiskardMapper; +import org.apache.logging.log4j.util.Strings; import org.springframework.stereotype.Service; @Service @@ -29,7 +30,13 @@ protected TestFunction create(TestFunctionDTO dto) { if (function.getArgs() != null) { function.getArgs().forEach(arg -> arg.setFunction(function)); } - function.setVersion(testFunctionRepository.countByNameAndModule(function.getName(), function.getModule()) + 1); + + if (Strings.isBlank(function.getDisplayName())) { + function.setDisplayName(function.getModule() + "." + function.getName()); + } + + function.setVersion(testFunctionRepository.countByDisplayName(function.getDisplayName()) + 1); + return function; } diff --git a/backend/src/main/java/ai/giskard/service/TransformationFunctionService.java b/backend/src/main/java/ai/giskard/service/TransformationFunctionService.java index 75d5f9eb03..1ea694723a 100644 --- a/backend/src/main/java/ai/giskard/service/TransformationFunctionService.java +++ b/backend/src/main/java/ai/giskard/service/TransformationFunctionService.java @@ -4,6 +4,7 @@ import ai.giskard.repository.ml.TransformationFunctionRepository; import ai.giskard.web.dto.TransformationFunctionDTO; import ai.giskard.web.dto.mapper.GiskardMapper; +import org.apache.logging.log4j.util.Strings; import org.springframework.stereotype.Service; @Service @@ -29,7 +30,13 @@ protected TransformationFunction create(TransformationFunctionDTO dto) { if (function.getArgs() != null) { function.getArgs().forEach(arg -> arg.setFunction(function)); } - function.setVersion(transformationFunctionRepository.countByNameAndModule(function.getName(), function.getModule()) + 1); + + if (Strings.isBlank(function.getDisplayName())) { + function.setDisplayName(function.getModule() + "." + function.getName()); + } + + function.setVersion(transformationFunctionRepository.countByDisplayName(function.getDisplayName()) + 1); + return function; } diff --git a/backend/src/main/java/ai/giskard/service/ml/MLWorkerWSCommService.java b/backend/src/main/java/ai/giskard/service/ml/MLWorkerWSCommService.java index c0432da39a..858b6210c9 100644 --- a/backend/src/main/java/ai/giskard/service/ml/MLWorkerWSCommService.java +++ b/backend/src/main/java/ai/giskard/service/ml/MLWorkerWSCommService.java @@ -83,6 +83,7 @@ public MLWorkerWSBaseDTO performAction(MLWorkerID workerID, MLWorkerWSAction act case ECHO -> parseReplyDTO(result, MLWorkerWSEchoMsgDTO.class); case GENERATE_TEST_SUITE -> parseReplyDTO(result, MLWorkerWSGenerateTestSuiteDTO.class); case GET_CATALOG -> parseReplyDTO(result, MLWorkerWSCatalogDTO.class); + case GET_PUSH -> parseReplyDTO(result, MLWorkerWSGetPushResultDTO.class); }; } catch (JsonProcessingException e) { return parseReplyErrorDTO(result); diff --git a/backend/src/main/java/ai/giskard/web/dto/ApplyPushDTO.java b/backend/src/main/java/ai/giskard/web/dto/ApplyPushDTO.java new file mode 100644 index 0000000000..28d6e094ff --- /dev/null +++ b/backend/src/main/java/ai/giskard/web/dto/ApplyPushDTO.java @@ -0,0 +1,24 @@ +package ai.giskard.web.dto; + +import ai.giskard.domain.ml.CallToActionKind; +import ai.giskard.domain.ml.PushKind; +import com.dataiku.j2ts.annotations.UIModel; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import java.util.Map; +import java.util.UUID; + +@Getter +@Setter +@UIModel +@NoArgsConstructor +public class ApplyPushDTO { + private UUID modelId; + private UUID datasetId; + private int rowIdx; + private PushKind pushKind; + private CallToActionKind ctaKind; + private Map features; +} diff --git a/backend/src/main/java/ai/giskard/web/dto/PushActionDTO.java b/backend/src/main/java/ai/giskard/web/dto/PushActionDTO.java new file mode 100644 index 0000000000..de9e294de0 --- /dev/null +++ b/backend/src/main/java/ai/giskard/web/dto/PushActionDTO.java @@ -0,0 +1,17 @@ +package ai.giskard.web.dto; + +import com.dataiku.j2ts.annotations.UIModel; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import java.util.Map; + +@Getter +@Setter +@UIModel +@NoArgsConstructor +public class PushActionDTO { + private String objectUuid; + private Map parameters; +} diff --git a/backend/src/main/java/ai/giskard/web/rest/controllers/PushController.java b/backend/src/main/java/ai/giskard/web/rest/controllers/PushController.java new file mode 100644 index 0000000000..d9dcbd3c98 --- /dev/null +++ b/backend/src/main/java/ai/giskard/web/rest/controllers/PushController.java @@ -0,0 +1,175 @@ +package ai.giskard.web.rest.controllers; + +import ai.giskard.domain.ColumnType; +import ai.giskard.domain.ml.Dataset; +import ai.giskard.domain.ml.ProjectModel; +import ai.giskard.exception.MLWorkerIllegalReplyException; +import ai.giskard.exception.MLWorkerNotConnectedException; +import ai.giskard.ml.MLWorkerID; +import ai.giskard.ml.MLWorkerWSAction; +import ai.giskard.ml.dto.*; +import ai.giskard.repository.ml.DatasetRepository; +import ai.giskard.repository.ml.ModelRepository; +import ai.giskard.security.PermissionEvaluator; +import ai.giskard.service.ml.MLWorkerWSCommService; +import ai.giskard.service.ml.MLWorkerWSService; +import ai.giskard.web.dto.ApplyPushDTO; +import ai.giskard.web.dto.PredictionInputDTO; +import com.google.common.collect.Maps; +import jakarta.validation.constraints.NotNull; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.util.Strings; +import org.springframework.web.bind.annotation.*; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.stream.Collectors; + +@RestController +@RequiredArgsConstructor +@RequestMapping("/api/v2/") +public class PushController { + private final ModelRepository modelRepository; + private final PermissionEvaluator permissionEvaluator; + private final DatasetRepository datasetRepository; + + private final MLWorkerWSCommService mlWorkerWSCommService; + private final MLWorkerWSService mlWorkerWSService; + + + @PostMapping("/pushes/{modelId}/{datasetId}/{idx}") + public MLWorkerWSGetPushResultDTO getPushes(@PathVariable @NotNull UUID modelId, + @PathVariable @NotNull UUID datasetId, + @PathVariable @NotNull int idx, + @RequestBody @NotNull PredictionInputDTO data) { + ProjectModel model = modelRepository.getMandatoryById(modelId); + Dataset dataset = datasetRepository.getMandatoryById(datasetId); + permissionEvaluator.validateCanReadProject(model.getProject().getId()); + Map features = data.getFeatures(); + + MLWorkerWSArtifactRefDTO datasetRef = MLWorkerWSArtifactRefDTO.fromDataset(dataset); + MLWorkerWSArtifactRefDTO modelRef = MLWorkerWSArtifactRefDTO.fromModel(model); + + MLWorkerWSGetPushDTO.MLWorkerWSGetPushDTOBuilder paramBuilder = MLWorkerWSGetPushDTO.builder() + .dataset(datasetRef) + .model(modelRef) + .rowIdx(idx); + + if (features != null) { + MLWorkerWSDataFrameDTO dataframe = MLWorkerWSDataFrameDTO.builder() + .rows( + List.of( + MLWorkerWSDataRowDTO.builder().columns( + features.entrySet().stream() + .filter(entry -> !shouldDrop( + dataset.getColumnDtypes().get(entry.getKey()), + entry.getValue() + )).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ).build() + ) + ).build(); + + paramBuilder.dataframe(dataframe); + } + + if (dataset.getTarget() != null) { + paramBuilder.target(dataset.getTarget()); + } + if (dataset.getColumnTypes() != null) { + paramBuilder.columnTypes(Maps.transformValues(dataset.getColumnTypes(), ColumnType::getName)); + } + if (dataset.getColumnDtypes() != null) { + paramBuilder.columnDtypes(dataset.getColumnDtypes()); + } + + if (mlWorkerWSService.isWorkerConnected(MLWorkerID.EXTERNAL)) { + MLWorkerWSBaseDTO result = mlWorkerWSCommService.performAction( + MLWorkerID.EXTERNAL, + MLWorkerWSAction.GET_PUSH, + paramBuilder.build() + ); + + if (result instanceof MLWorkerWSGetPushResultDTO response) { + return response; + } else if (result instanceof MLWorkerWSErrorDTO error) { + throw new MLWorkerIllegalReplyException(error); + } + throw new MLWorkerIllegalReplyException("Cannot get ML Worker GetPushResult reply"); + } + + throw new MLWorkerNotConnectedException(MLWorkerID.EXTERNAL); + } + + @PostMapping("/push/apply") + public MLWorkerWSGetPushResultDTO applyPushSuggestion(@RequestBody ApplyPushDTO applyPushDTO) { + ProjectModel model = modelRepository.getMandatoryById(applyPushDTO.getModelId()); + Dataset dataset = datasetRepository.getMandatoryById(applyPushDTO.getDatasetId()); + permissionEvaluator.validateCanReadProject(model.getProject().getId()); + Map features = applyPushDTO.getFeatures(); + + MLWorkerWSArtifactRefDTO datasetRef = MLWorkerWSArtifactRefDTO.fromDataset(dataset); + MLWorkerWSArtifactRefDTO modelRef = MLWorkerWSArtifactRefDTO.fromModel(model); + + MLWorkerWSGetPushDTO.MLWorkerWSGetPushDTOBuilder paramBuilder = MLWorkerWSGetPushDTO.builder() + .dataset(datasetRef) + .model(modelRef) + .rowIdx(applyPushDTO.getRowIdx()) + .pushKind(applyPushDTO.getPushKind()) + .ctaKind(applyPushDTO.getCtaKind()); + + if (features != null) { + MLWorkerWSDataFrameDTO dataframe = MLWorkerWSDataFrameDTO.builder() + .rows( + List.of( + MLWorkerWSDataRowDTO.builder().columns( + features.entrySet().stream() + .filter(entry -> !shouldDrop( + dataset.getColumnDtypes().get(entry.getKey()), + entry.getValue() + )).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ).build() + ) + ).build(); + + paramBuilder.dataframe(dataframe); + } + + if (dataset.getTarget() != null) { + paramBuilder.target(dataset.getTarget()); + } + if (dataset.getColumnTypes() != null) { + paramBuilder.columnTypes(Maps.transformValues(dataset.getColumnTypes(), ColumnType::getName)); + } + if (dataset.getColumnDtypes() != null) { + paramBuilder.columnDtypes(dataset.getColumnDtypes()); + } + + paramBuilder.pushKind(applyPushDTO.getPushKind()); + paramBuilder.ctaKind(applyPushDTO.getCtaKind()); + + if (mlWorkerWSService.isWorkerConnected(MLWorkerID.EXTERNAL)) { + MLWorkerWSBaseDTO result = mlWorkerWSCommService.performAction( + MLWorkerID.EXTERNAL, + MLWorkerWSAction.GET_PUSH, + paramBuilder.build() + ); + + if (result instanceof MLWorkerWSGetPushResultDTO response) { + return response; + } else if (result instanceof MLWorkerWSErrorDTO error) { + throw new MLWorkerIllegalReplyException(error); + } + throw new MLWorkerIllegalReplyException("Cannot get ML Worker GetPushResult reply"); + } + + throw new MLWorkerNotConnectedException(MLWorkerID.EXTERNAL); + } + + // Probably move this to a util class. + public boolean shouldDrop(String columnDtype, String value) { + return Objects.isNull(columnDtype) || Objects.isNull(value) || + ((columnDtype.startsWith("int") || columnDtype.startsWith("float")) && Strings.isBlank(value)); + } +} diff --git a/backend/src/main/resources/config/liquibase/changelog/20230829180047_changelog.xml b/backend/src/main/resources/config/liquibase/changelog/20230829180047_changelog.xml new file mode 100644 index 0000000000..455c74f14c --- /dev/null +++ b/backend/src/main/resources/config/liquibase/changelog/20230829180047_changelog.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/backend/src/main/resources/config/liquibase/master.xml b/backend/src/main/resources/config/liquibase/master.xml index 6cd6552a50..62af303942 100644 --- a/backend/src/main/resources/config/liquibase/master.xml +++ b/backend/src/main/resources/config/liquibase/master.xml @@ -34,4 +34,6 @@ + + diff --git a/common/proto/ml-worker.proto b/common/proto/ml-worker.proto new file mode 100644 index 0000000000..e69de29bb2 diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 17dab94795..9c4c842c6a 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -7,6 +7,8 @@ import { AdminUserDTO, ApiKeyDTO, AppConfigDTO, + ApplyPushDTO, + CallToActionKind, CatalogDTO, ComparisonClauseDTO, CreateFeedbackDTO, @@ -39,6 +41,7 @@ import { PrepareImportProjectDTO, ProjectDTO, ProjectPostDTO, + PushKind, RoleDTO, RowFilterDTO, SetupDTO, @@ -204,7 +207,7 @@ function downloadURL(urlString) { export const api = { async getHuggingFaceSpacesToken(spaceId: string) { - return await huggingface.get(`https://huggingface.co/api/spaces/${spaceId}/jwt`); + return await huggingface.get(`https://huggingface.co/api/spaces/${spaceId}/jwt`); }, async logInGetToken(username: string, password: string) { return apiV2.post(`/authenticate`, {username, password}); @@ -259,7 +262,7 @@ export const api = { return apiV2.delete(`/admin/users/${login}`); }, async disableUser(login: string) { - return apiV2.patch(`/admin/users/${login}/disable`); + return apiV2.patch(`/admin/users/${login}/disable`); }, async enableUser(login: string) { return apiV2.patch(`/admin/users/${login}/enable`); @@ -441,82 +444,97 @@ export const api = { return apiV2.post(`/models/${modelId}/predict`, data, {signal: controller.signal}); }, - async prepareInspection(payload: InspectionCreateDTO) { - return apiV2.post(`/inspection`, payload); - }, - async explain(modelId: string, datasetId: string, inputData: object, controller: AbortController) { - return apiV2.post( - `/models/${modelId}/explain/${datasetId}`, - { features: inputData }, - { signal: controller.signal } - ); - }, - async explainText(modelId: string, datasetId: string, inputData: object, featureName: string) { - return apiV2.post( - `/models/explain-text/${featureName}`, - { - features: inputData, - }, - { params: { modelId, datasetId } } - ); - }, - // feedbacks - async submitFeedback(payload: CreateFeedbackDTO, projectId: number) { - return apiV2.post(`/feedbacks/${projectId}`, payload); - }, - async getProjectFeedbacks(projectId: number) { - return apiV2.get(`/feedbacks/all/${projectId}`); - }, - async getFeedback(id: number) { - return apiV2.get(`/feedbacks/${id}`); - }, - async replyToFeedback(feedbackId: number, content: string, replyToId: number | null = null) { - return apiV2.post(`/feedbacks/${feedbackId}/reply`, { - content, - replyToReply: replyToId - }); - }, - async deleteFeedback(id: number) { - return apiV2.delete(`/feedbacks/${id}`); - }, - async deleteFeedbackReply(feedbackId: number, replyId: number) { - return apiV2.delete(`/feedbacks/${feedbackId}/replies/${replyId}`); - }, - async runAdHocTest(projectId: number, testUuid: string, inputs: Array, sample: boolean, debug: boolean = false) { - return apiV2.post(`/testing/tests/run-test?sample=${sample}`, { - projectId, - testUuid, - inputs, - debug - }); - }, - async getCatalog(projectId: number) { - return apiV2.get(`/catalog`, { - params: { - projectId - } - }); - }, - async createSlicingFunction(comparisonClauses: Array) { - return apiV2.post(`/slices/no-code`, comparisonClauses); - }, - async uploadLicense(form: FormData) { - return apiV2.post(`/ee/license`, form, { - headers: { - 'Content-Type': 'multipart/form-data', - }, - }); - }, - async finalizeSetup(allowAnalytics: boolean, license: string) { - return apiV2.post(`/setup`, { - allowAnalytics: allowAnalytics, - license: license, - }); - }, - async datasetProcessing(projectId: number, datasetUuid: string, functions: Array, sample: boolean = true) { - return apiV2.post( - `/project/${projectId}/datasets/${encodeURIComponent(datasetUuid)}/process?sample=${sample}`, - functions - ); - }, + async prepareInspection(payload: InspectionCreateDTO) { + return apiV2.post(`/inspection`, payload); + }, + async explain(modelId: string, datasetId: string, inputData: object, controller: AbortController) { + return apiV2.post( + `/models/${modelId}/explain/${datasetId}`, + {features: inputData}, + {signal: controller.signal} + ); + }, + async explainText(modelId: string, datasetId: string, inputData: object, featureName: string) { + return apiV2.post( + `/models/explain-text/${featureName}`, + { + features: inputData, + }, + {params: {modelId, datasetId}} + ); + }, + // feedbacks + async submitFeedback(payload: CreateFeedbackDTO, projectId: number) { + return apiV2.post(`/feedbacks/${projectId}`, payload); + }, + async getProjectFeedbacks(projectId: number) { + return apiV2.get(`/feedbacks/all/${projectId}`); + }, + async getFeedback(id: number) { + return apiV2.get(`/feedbacks/${id}`); + }, + async replyToFeedback(feedbackId: number, content: string, replyToId: number | null = null) { + return apiV2.post(`/feedbacks/${feedbackId}/reply`, { + content, + replyToReply: replyToId + }); + }, + async deleteFeedback(id: number) { + return apiV2.delete(`/feedbacks/${id}`); + }, + async deleteFeedbackReply(feedbackId: number, replyId: number) { + return apiV2.delete(`/feedbacks/${feedbackId}/replies/${replyId}`); + }, + async runAdHocTest(projectId: number, testUuid: string, inputs: Array, sample: boolean, debug: boolean = false) { + return apiV2.post(`/testing/tests/run-test?sample=${sample}`, { + projectId, + testUuid, + inputs, + debug + }); + }, + async getCatalog(projectId: number) { + return apiV2.get(`/catalog`, { + params: { + projectId + } + }); + }, + async createSlicingFunction(comparisonClauses: Array) { + return apiV2.post(`/slices/no-code`, comparisonClauses); + }, + async uploadLicense(form: FormData) { + return apiV2.post(`/ee/license`, form, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + }); + }, + async finalizeSetup(allowAnalytics: boolean, license: string) { + return apiV2.post(`/setup`, { + allowAnalytics: allowAnalytics, + license: license, + }); + }, + async datasetProcessing(projectId: number, datasetUuid: string, functions: Array, sample: boolean = true) { + return apiV2.post( + `/project/${projectId}/datasets/${encodeURIComponent(datasetUuid)}/process?sample=${sample}`, + functions + ); + }, + async getPushes(modelId: string, datasetId: string, idx: number, features: any) { + return apiV2.post(`/pushes/${modelId}/${datasetId}/${idx}`, { + features: features + }); + }, + async applyPush(modelId: string, datasetId: string, idx: number, pushKind: PushKind, ctaKind: CallToActionKind, features: any) { + return apiV2.post(`/push/apply`, { + modelId, + datasetId, + rowIdx: idx, + pushKind: pushKind, + ctaKind: ctaKind, + features: features + }); + }, }; diff --git a/frontend/src/components/PushPopover.vue b/frontend/src/components/PushPopover.vue new file mode 100644 index 0000000000..85c3f08d18 --- /dev/null +++ b/frontend/src/components/PushPopover.vue @@ -0,0 +1,254 @@ + + + + + \ No newline at end of file diff --git a/frontend/src/components/TransformationPopover.vue b/frontend/src/components/TransformationPopover.vue index 13b8c59672..bd79ac9696 100644 --- a/frontend/src/components/TransformationPopover.vue +++ b/frontend/src/components/TransformationPopover.vue @@ -1,32 +1,37 @@ @@ -63,40 +68,40 @@ function handleOnChanged() { diff --git a/frontend/src/generated-sources/ai/giskard/domain/ml/call-to-action-kind.ts b/frontend/src/generated-sources/ai/giskard/domain/ml/call-to-action-kind.ts new file mode 100644 index 0000000000..a201bea138 --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/domain/ml/call-to-action-kind.ts @@ -0,0 +1,17 @@ +/** + * Generated from ai.giskard.domain.ml.CallToActionKind + */ +export enum CallToActionKind { + None = 0, + CreateSlice = 1, + CreateTest = 2, + CreatePerturbation = 3, + SavePerturbation = 4, + CreateRobustnessTest = 5, + CreateSliceOpenDebugger = 6, + OpenDebuggerBorderline = 7, + AddTestToCatalog = 8, + SaveExample = 9, + OpenDebuggerOverconfidence = 10, + CreateUnitTest = 11 +} \ No newline at end of file diff --git a/frontend/src/generated-sources/ai/giskard/domain/ml/push-kind.ts b/frontend/src/generated-sources/ai/giskard/domain/ml/push-kind.ts new file mode 100644 index 0000000000..92e5aa60c1 --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/domain/ml/push-kind.ts @@ -0,0 +1,10 @@ +/** + * Generated from ai.giskard.domain.ml.PushKind + */ +export enum PushKind { + Invalid = 'Invalid', + Perturbation = 'Perturbation', + Contribution = 'Contribution', + Overconfidence = 'Overconfidence', + Borderline = 'Borderline' +} \ No newline at end of file diff --git a/frontend/src/generated-sources/ai/giskard/web/dto/apply-push-dto.ts b/frontend/src/generated-sources/ai/giskard/web/dto/apply-push-dto.ts new file mode 100644 index 0000000000..f0470663b2 --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/web/dto/apply-push-dto.ts @@ -0,0 +1,14 @@ +import type {CallToActionKind} from './../../domain/ml/call-to-action-kind'; +import type {PushKind} from './../../domain/ml/push-kind'; + +/** + * Generated from ai.giskard.web.dto.ApplyPushDTO + */ +export interface ApplyPushDTO { + ctaKind: CallToActionKind; + datasetId: string; + features: {[key: string]: string}; + modelId: string; + pushKind: PushKind; + rowIdx: number; +} \ No newline at end of file diff --git a/frontend/src/generated-sources/ai/giskard/web/dto/push-action-dto.ts b/frontend/src/generated-sources/ai/giskard/web/dto/push-action-dto.ts new file mode 100644 index 0000000000..89e2f6c447 --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/web/dto/push-action-dto.ts @@ -0,0 +1,7 @@ +/** + * Generated from ai.giskard.web.dto.PushActionDTO + */ +export interface PushActionDTO { + object_uuid: string; + arguments: { [key: string]: string }; +} \ No newline at end of file diff --git a/frontend/src/generated-sources/ai/giskard/web/dto/push-details-dto.ts b/frontend/src/generated-sources/ai/giskard/web/dto/push-details-dto.ts new file mode 100644 index 0000000000..3e5f50e1e3 --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/web/dto/push-details-dto.ts @@ -0,0 +1,8 @@ +/** + * Generated from ai.giskard.web.dto.PushDetailsDTO + */ +export interface PushDetailsDTO { + action: string; + button: string; + explanation: string; +} \ No newline at end of file diff --git a/frontend/src/generated-sources/ai/giskard/web/dto/push-dto.ts b/frontend/src/generated-sources/ai/giskard/web/dto/push-dto.ts new file mode 100644 index 0000000000..0fd5741c9c --- /dev/null +++ b/frontend/src/generated-sources/ai/giskard/web/dto/push-dto.ts @@ -0,0 +1,14 @@ +import type {PushDetailsDTO} from './push-details-dto'; +import {PushKind} from "@/generated-sources"; + +/** + * Generated from ai.giskard.web.dto.PushDTO + */ +export interface PushDTO { + details: PushDetailsDTO[]; + key: string; + perturbationValue: string; + push_title: string; + value: string; + kind: PushKind; +} \ No newline at end of file diff --git a/frontend/src/generated-sources/index.ts b/frontend/src/generated-sources/index.ts index 3534c90829..a8795150b4 100644 --- a/frontend/src/generated-sources/index.ts +++ b/frontend/src/generated-sources/index.ts @@ -1,9 +1,11 @@ export const GENERATED_MAPPING = { 'ai.giskard.domain.ColumnType' : 'ColumnType', 'ai.giskard.domain.GeneralSettings' : 'GeneralSettings', + 'ai.giskard.domain.ml.CallToActionKind' : 'CallToActionKind', 'ai.giskard.domain.ml.CodeLanguage' : 'CodeLanguage', 'ai.giskard.domain.ml.ModelLanguage' : 'ModelLanguage', 'ai.giskard.domain.ml.ModelType' : 'ModelType', + 'ai.giskard.domain.ml.PushKind' : 'PushKind', 'ai.giskard.domain.ml.table.Filter' : 'Filter', 'ai.giskard.domain.ml.table.RegressionUnit' : 'RegressionUnit', 'ai.giskard.domain.ml.table.RowFilterType' : 'RowFilterType', @@ -14,6 +16,7 @@ export const GENERATED_MAPPING = { 'ai.giskard.ml.dto.MLWorkerWSTestMessageType' : 'MLWorkerWSTestMessageType', 'ai.giskard.service.ee.FeatureFlag' : 'FeatureFlag', 'ai.giskard.web.dto.ApiKeyDTO' : 'ApiKeyDTO', + 'ai.giskard.web.dto.ApplyPushDTO' : 'ApplyPushDTO', 'ai.giskard.web.dto.CallableDTO' : 'CallableDTO', 'ai.giskard.web.dto.CatalogDTO' : 'CatalogDTO', 'ai.giskard.web.dto.ComparisonClauseDTO' : 'ComparisonClauseDTO', @@ -73,6 +76,9 @@ export const GENERATED_MAPPING = { 'ai.giskard.web.dto.PrepareDeleteDTO$LightFeedback' : 'PrepareDeleteDTO.LightFeedback', 'ai.giskard.web.dto.PrepareDeleteDTO' : 'PrepareDeleteDTO', 'ai.giskard.web.dto.PrepareImportProjectDTO' : 'PrepareImportProjectDTO', + 'ai.giskard.web.dto.PushActionDTO' : 'PushActionDTO', + 'ai.giskard.web.dto.PushDetailsDTO' : 'PushDetailsDTO', + 'ai.giskard.web.dto.PushDTO' : 'PushDTO', 'ai.giskard.web.dto.RequiredInputDTO' : 'RequiredInputDTO', 'ai.giskard.web.dto.RowFilterDTO' : 'RowFilterDTO', 'ai.giskard.web.dto.RunAdhocTestRequest' : 'RunAdhocTestRequest', @@ -98,9 +104,11 @@ export const GENERATED_MAPPING = { }; export * from './ai/giskard/domain/column-type'; export * from './ai/giskard/domain/general-settings'; +export * from './ai/giskard/domain/ml/call-to-action-kind'; export * from './ai/giskard/domain/ml/code-language'; export * from './ai/giskard/domain/ml/model-language'; export * from './ai/giskard/domain/ml/model-type'; +export * from './ai/giskard/domain/ml/push-kind'; export * from './ai/giskard/domain/ml/table/filter'; export * from './ai/giskard/domain/ml/table/regression-unit'; export * from './ai/giskard/domain/ml/table/row-filter-type'; @@ -111,6 +119,7 @@ export * from './ai/giskard/jobs/job-type'; export * from './ai/giskard/ml/dto/mlworker-wstest-message-type'; export * from './ai/giskard/service/ee/feature-flag'; export * from './ai/giskard/web/dto/api-key-dto'; +export * from './ai/giskard/web/dto/apply-push-dto'; export * from './ai/giskard/web/dto/callable-dto'; export * from './ai/giskard/web/dto/catalog-dto'; export * from './ai/giskard/web/dto/comparison-clause-dto'; @@ -166,6 +175,9 @@ export * from './ai/giskard/web/dto/prediction-dto'; export * from './ai/giskard/web/dto/prediction-input-dto'; export * from './ai/giskard/web/dto/prepare-delete-dto'; export * from './ai/giskard/web/dto/prepare-import-project-dto'; +export * from './ai/giskard/web/dto/push-action-dto'; +export * from './ai/giskard/web/dto/push-details-dto'; +export * from './ai/giskard/web/dto/push-dto'; export * from './ai/giskard/web/dto/required-input-dto'; export * from './ai/giskard/web/dto/row-filter-dto'; export * from './ai/giskard/web/dto/run-adhoc-test-request'; diff --git a/frontend/src/generated-sources/j2ts-generated-metadata.json b/frontend/src/generated-sources/j2ts-generated-metadata.json index a79e31d94f..660dac7aee 100644 --- a/frontend/src/generated-sources/j2ts-generated-metadata.json +++ b/frontend/src/generated-sources/j2ts-generated-metadata.json @@ -6,70 +6,42 @@ "ai.giskard.domain.ml.table.RowFilterType": "RowFilterType", "ai.giskard.jobs.JobType": "JobType", "ai.giskard.web.dto.GenerateTestSuiteInputDTO": "GenerateTestSuiteInputDTO", - "ai.giskard.web.dto.ml.TestTemplateExecutionResultDTO": "TestTemplateExecutionResultDTO", "ai.giskard.web.dto.user.PasswordChangeDTO": "PasswordChangeDTO", "ai.giskard.web.dto.FeedbackReplyDTO": "FeedbackReplyDTO", - "ai.giskard.web.dto.ml.write.FilePostDTO": "FilePostDTO", - "ai.giskard.domain.ColumnType": "ColumnType", "ai.giskard.web.dto.ml.ExecuteTestSuiteRequest": "ExecuteTestSuiteRequest", - "ai.giskard.web.dto.ml.SuiteTestExecutionDTO": "SuiteTestExecutionDTO", - "ai.giskard.web.dto.TransformationResultMessageDTO": "TransformationResultMessageDTO", - "ai.giskard.web.rest.vm.TokenAndPasswordVM": "TokenAndPasswordVM", - "ai.giskard.web.dto.PrepareDeleteDTO": "PrepareDeleteDTO", - "ai.giskard.web.dto.ml.ProjectDTO": "ProjectDTO", - "ai.giskard.web.dto.CreateFeedbackReplyDTO": "CreateFeedbackReplyDTO", "ai.giskard.web.dto.PasswordResetRequest": "PasswordResetRequest", "ai.giskard.web.dto.user.UserMinimalDTO": "UserMinimalDTO", - "ai.giskard.web.dto.ApiKeyDTO": "ApiKeyDTO", "ai.giskard.web.dto.ml.ModelDTO": "ModelDTO", "ai.giskard.web.dto.ml.SingleTestResultDTO": "SingleTestResultDTO", "ai.giskard.web.dto.config.AppConfigDTO$AppInfoDTO": "AppConfigDTO.AppInfoDTO", "ai.giskard.domain.ml.table.Filter": "Filter", "ai.giskard.web.dto.TestFunctionDTO": "TestFunctionDTO", - "ai.giskard.web.dto.user.UpdateMeDTO": "UpdateMeDTO", "ai.giskard.web.dto.user.AdminUserDTO": "AdminUserDTO", - "ai.giskard.web.dto.config.MLWorkerInfoDTO": "MLWorkerInfoDTO", - "ai.giskard.web.dto.PostImportProjectDTO": "PostImportProjectDTO", "ai.giskard.web.dto.DatasetMetadataDTO": "DatasetMetadataDTO", "ai.giskard.web.dto.ml.CodeBasedTestPresetDTO": "CodeBasedTestPresetDTO", - "ai.giskard.web.dto.TestSuiteCompleteDTO": "TestSuiteCompleteDTO", "ai.giskard.web.dto.SuiteTestDTO": "SuiteTestDTO", "ai.giskard.web.dto.DatasetProcessFunctionType": "DatasetProcessFunctionType", "ai.giskard.jobs.JobState": "JobState", - "ai.giskard.web.dto.SlicingFunctionDTO": "SlicingFunctionDTO", + "ai.giskard.domain.ml.PushKind": "PushKind", "ai.giskard.web.dto.GenerateTestSuiteDTO": "GenerateTestSuiteDTO", "ai.giskard.web.dto.ml.TestExecutionStatusDTO": "TestExecutionStatusDTO", - "ai.giskard.web.dto.ExplainTextResponseDTO": "ExplainTextResponseDTO", - "ai.giskard.web.dto.JWTToken": "JWTToken", - "ai.giskard.web.dto.ModelUploadParamsDTO": "ModelUploadParamsDTO", "ai.giskard.web.dto.RequiredInputDTO": "RequiredInputDTO", + "ai.giskard.web.dto.ApplyPushDTO": "ApplyPushDTO", "ai.giskard.web.dto.CreateFeedbackDTO": "CreateFeedbackDTO", - "ai.giskard.web.dto.TransformationFunctionDTO": "TransformationFunctionDTO", - "ai.giskard.web.dto.ml.write.ModelPostDTO": "ModelPostDTO", "ai.giskard.web.rest.vm.ManagedUserVM": "ManagedUserVM", - "ai.giskard.web.dto.CatalogDTO": "CatalogDTO", "ai.giskard.web.dto.PrepareDeleteDTO$LightTestSuite": "PrepareDeleteDTO.LightTestSuite", - "ai.giskard.web.dto.ml.NamedSingleTestResultDTO": "NamedSingleTestResultDTO", "ai.giskard.web.dto.user.RoleDTO": "RoleDTO", "ai.giskard.web.dto.FunctionInputDTO": "FunctionInputDTO", "ai.giskard.web.dto.config.LicenseDTO": "LicenseDTO", "ai.giskard.web.dto.ComparisonType": "ComparisonType", "ai.giskard.web.dto.FeedbackMinimalDTO": "FeedbackMinimalDTO", - "ai.giskard.web.dto.ml.DatasetDTO": "DatasetDTO", "ai.giskard.web.dto.PrepareDeleteDTO$LightFeedback": "PrepareDeleteDTO.LightFeedback", - "ai.giskard.web.dto.PredictionDTO": "PredictionDTO", - "ai.giskard.web.dto.user.AdminUserDTO$AdminUserDTOWithPassword": "AdminUserDTO.AdminUserDTOWithPassword", - "ai.giskard.web.dto.PrepareImportProjectDTO": "PrepareImportProjectDTO", "ai.giskard.web.dto.TestSuiteDTO": "TestSuiteDTO", - "ai.giskard.web.dto.FilterDatasetDTO": "FilterDatasetDTO", - "ai.giskard.web.dto.ml.TestResultMessageDTO": "TestResultMessageDTO", - "ai.giskard.web.dto.DataUploadParamsDTO": "DataUploadParamsDTO", "ai.giskard.web.dto.ml.write.TestSuitePostDTO": "TestSuitePostDTO", "ai.giskard.domain.ml.ModelLanguage": "ModelLanguage", "ai.giskard.web.dto.ml.TestSuiteExecutionDTO": "TestSuiteExecutionDTO", "ai.giskard.web.dto.FeedbackDTO": "FeedbackDTO", "ai.giskard.web.dto.JobDTO": "JobDTO", - "ai.giskard.web.dto.RowFilterDTO": "RowFilterDTO", "ai.giskard.domain.ml.ModelType": "ModelType", "ai.giskard.web.dto.ml.ProjectPostDTO": "ProjectPostDTO", "ai.giskard.web.dto.config.MLWorkerInfoDTO$PlatformInfoDTO": "MLWorkerInfoDTO.PlatformInfoDTO", @@ -77,24 +49,58 @@ "ai.giskard.web.dto.SetupDTO": "SetupDTO", "ai.giskard.domain.ml.CodeLanguage": "CodeLanguage", "ai.giskard.web.dto.DatasetPageDTO": "DatasetPageDTO", - "ai.giskard.web.dto.ExplainResponseDTO": "ExplainResponseDTO", - "ai.giskard.web.dto.ModelMetadataDTO": "ModelMetadataDTO", + "ai.giskard.web.dto.PushDTO": "PushDTO", "ai.giskard.web.dto.DatasetProcessingResultDTO": "DatasetProcessingResultDTO", - "ai.giskard.web.dto.MessageDTO": "MessageDTO", "ai.giskard.domain.MLWorkerType": "MLWorkerType", "ai.giskard.web.dto.PredictionInputDTO": "PredictionInputDTO", "ai.giskard.web.dto.TestFunctionArgumentDTO": "TestFunctionArgumentDTO", "ai.giskard.web.dto.config.AppConfigDTO": "AppConfigDTO", - "ai.giskard.web.dto.ml.InspectionDTO": "InspectionDTO", - "ai.giskard.ml.dto.MLWorkerWSTestMessageType": "MLWorkerWSTestMessageType", - "ai.giskard.domain.GeneralSettings": "GeneralSettings", "ai.giskard.web.dto.ParameterizedCallableDTO": "ParameterizedCallableDTO", "ai.giskard.domain.ml.table.RegressionUnit": "RegressionUnit", "ai.giskard.web.dto.ComparisonClauseDTO": "ComparisonClauseDTO", "ai.giskard.web.dto.ml.TestEditorConfigDTO": "TestEditorConfigDTO", - "ai.giskard.web.dto.CallableDTO": "CallableDTO", "ai.giskard.service.ee.FeatureFlag": "FeatureFlag", "ai.giskard.domain.ml.TestResult": "TestResult", - "ai.giskard.web.dto.user.UserDTO": "UserDTO" + "ai.giskard.web.dto.user.UserDTO": "UserDTO", + "ai.giskard.web.dto.ml.TestTemplateExecutionResultDTO": "TestTemplateExecutionResultDTO", + "ai.giskard.web.dto.ml.write.FilePostDTO": "FilePostDTO", + "ai.giskard.domain.ColumnType": "ColumnType", + "ai.giskard.web.dto.ml.SuiteTestExecutionDTO": "SuiteTestExecutionDTO", + "ai.giskard.web.dto.TransformationResultMessageDTO": "TransformationResultMessageDTO", + "ai.giskard.web.rest.vm.TokenAndPasswordVM": "TokenAndPasswordVM", + "ai.giskard.web.dto.PrepareDeleteDTO": "PrepareDeleteDTO", + "ai.giskard.web.dto.ml.ProjectDTO": "ProjectDTO", + "ai.giskard.web.dto.CreateFeedbackReplyDTO": "CreateFeedbackReplyDTO", + "ai.giskard.web.dto.ApiKeyDTO": "ApiKeyDTO", + "ai.giskard.web.dto.user.UpdateMeDTO": "UpdateMeDTO", + "ai.giskard.web.dto.config.MLWorkerInfoDTO": "MLWorkerInfoDTO", + "ai.giskard.web.dto.PostImportProjectDTO": "PostImportProjectDTO", + "ai.giskard.web.dto.TestSuiteCompleteDTO": "TestSuiteCompleteDTO", + "ai.giskard.web.dto.SlicingFunctionDTO": "SlicingFunctionDTO", + "ai.giskard.domain.ml.CallToActionKind": "CallToActionKind", + "ai.giskard.web.dto.ExplainTextResponseDTO": "ExplainTextResponseDTO", + "ai.giskard.web.dto.JWTToken": "JWTToken", + "ai.giskard.web.dto.ModelUploadParamsDTO": "ModelUploadParamsDTO", + "ai.giskard.web.dto.TransformationFunctionDTO": "TransformationFunctionDTO", + "ai.giskard.web.dto.ml.write.ModelPostDTO": "ModelPostDTO", + "ai.giskard.web.dto.CatalogDTO": "CatalogDTO", + "ai.giskard.web.dto.ml.NamedSingleTestResultDTO": "NamedSingleTestResultDTO", + "ai.giskard.web.dto.PushActionDTO": "PushActionDTO", + "ai.giskard.web.dto.ml.DatasetDTO": "DatasetDTO", + "ai.giskard.web.dto.PredictionDTO": "PredictionDTO", + "ai.giskard.web.dto.user.AdminUserDTO$AdminUserDTOWithPassword": "AdminUserDTO.AdminUserDTOWithPassword", + "ai.giskard.web.dto.PrepareImportProjectDTO": "PrepareImportProjectDTO", + "ai.giskard.web.dto.FilterDatasetDTO": "FilterDatasetDTO", + "ai.giskard.web.dto.ml.TestResultMessageDTO": "TestResultMessageDTO", + "ai.giskard.web.dto.DataUploadParamsDTO": "DataUploadParamsDTO", + "ai.giskard.web.dto.RowFilterDTO": "RowFilterDTO", + "ai.giskard.web.dto.ExplainResponseDTO": "ExplainResponseDTO", + "ai.giskard.web.dto.ModelMetadataDTO": "ModelMetadataDTO", + "ai.giskard.web.dto.MessageDTO": "MessageDTO", + "ai.giskard.web.dto.PushDetailsDTO": "PushDetailsDTO", + "ai.giskard.web.dto.ml.InspectionDTO": "InspectionDTO", + "ai.giskard.ml.dto.MLWorkerWSTestMessageType": "MLWorkerWSTestMessageType", + "ai.giskard.domain.GeneralSettings": "GeneralSettings", + "ai.giskard.web.dto.CallableDTO": "CallableDTO" } } \ No newline at end of file diff --git a/frontend/src/stores/catalog.ts b/frontend/src/stores/catalog.ts index 5fd0ed3aec..73eccaa801 100644 --- a/frontend/src/stores/catalog.ts +++ b/frontend/src/stores/catalog.ts @@ -6,10 +6,10 @@ import { DatasetDTO, DatasetProcessFunctionDTO } from '@/generated-sources'; -import {defineStore} from 'pinia'; -import {api} from '@/api'; -import {chain} from "lodash"; -import {getColumnType} from "@/utils/column-type-utils"; +import { defineStore } from 'pinia'; +import { api } from '@/api'; +import { chain } from 'lodash'; +import { getColumnType } from '@/utils/column-type-utils'; interface State { catalog: CatalogDTO | null @@ -18,7 +18,7 @@ interface State { function latestVersions(data?: Array): Array { return chain(data ?? []) - .groupBy(func => `${func.module}.${func.name}`) + .groupBy(func => func.module ?? `${func.module}.${func.name}`) .mapValues(functions => chain(functions) .maxBy(func => func.version ?? 1) .value()) diff --git a/frontend/src/stores/debugging-sessions.ts b/frontend/src/stores/debugging-sessions.ts index f8d8554280..25510e3419 100644 --- a/frontend/src/stores/debugging-sessions.ts +++ b/frontend/src/stores/debugging-sessions.ts @@ -1,51 +1,84 @@ -import { InspectionDTO, InspectionCreateDTO } from '@/generated-sources'; -import { defineStore } from 'pinia'; -import { api } from '@/api'; +import {InspectionCreateDTO, InspectionDTO, ParameterizedCallableDTO, RowFilterType} from '@/generated-sources'; +import {defineStore} from 'pinia'; +import {api} from '@/api'; + +interface FilterType { + label: string; + value: RowFilterType; + disabled?: boolean; + description?: string; +} interface State { - projectId: number | null; - debuggingSessions: Array; - currentDebuggingSessionId: number | null; + projectId: number | null; + debuggingSessions: Array; + currentDebuggingSessionId: number | null; + selectedSlicingFunction: Partial; + selectedFilter: FilterType | null } export const useDebuggingSessionsStore = defineStore('debuggingSessions', { - state: (): State => ({ - projectId: null, - debuggingSessions: [], - currentDebuggingSessionId: null, - }), - getters: {}, - actions: { - async reload() { - if (this.projectId !== null) { - await this.loadDebuggingSessions(this.projectId); - } - }, - async loadDebuggingSessions(projectId: number) { - if (this.projectId !== projectId) { - this.projectId = projectId; - this.currentDebuggingSessionId = null; - } - this.debuggingSessions = await api.getProjectInspections(this.projectId); - }, - async createDebuggingSession(inspection: InspectionCreateDTO) { - const newDebuggingSession = await api.prepareInspection(inspection); - await this.reload(); - return newDebuggingSession; - }, - async deleteDebuggingSession(inspectionId: number) { - await api.deleteInspection(inspectionId); - if (this.currentDebuggingSessionId === inspectionId) { - this.currentDebuggingSessionId = null; - } - await this.reload(); - }, - async updateDebuggingSessionName(inspectionId: number, inspection: InspectionCreateDTO) { - await api.updateInspectionName(inspectionId, inspection); - await this.reload(); - }, - setCurrentDebuggingSessionId(inspectionId: number | null) { - this.currentDebuggingSessionId = inspectionId; + state: (): State => ({ + projectId: null, + debuggingSessions: [], + currentDebuggingSessionId: null, + selectedSlicingFunction: { + uuid: undefined, + params: [], + type: 'SLICING' + }, + selectedFilter: null + }), + getters: {}, + actions: { + async reload() { + if (this.projectId !== null) { + await this.loadDebuggingSessions(this.projectId); + } + }, + clear() { + this.selectedSlicingFunction = { + uuid: undefined, + params: [], + type: 'SLICING' + } + }, + async loadDebuggingSessions(projectId: number) { + if (this.projectId !== projectId) { + this.projectId = projectId; + this.currentDebuggingSessionId = null; + } + this.debuggingSessions = await api.getProjectInspections(this.projectId); + }, + async createDebuggingSession(inspection: InspectionCreateDTO) { + const newDebuggingSession = await api.prepareInspection(inspection); + await this.reload(); + return newDebuggingSession; + }, + async deleteDebuggingSession(inspectionId: number) { + await api.deleteInspection(inspectionId); + if (this.currentDebuggingSessionId === inspectionId) { + this.currentDebuggingSessionId = null; + } + await this.reload(); + }, + async updateDebuggingSessionName(inspectionId: number, inspection: InspectionCreateDTO) { + await api.updateInspectionName(inspectionId, inspection); + await this.reload(); + }, + setCurrentDebuggingSessionId(inspectionId: number | null) { + this.currentDebuggingSessionId = inspectionId; + this.clear(); + }, + setCurrentSlicingFunctionUuid(uuid: string) { + this.selectedSlicingFunction = { + uuid: uuid, + params: [], + type: 'SLICING' + }; + }, + setSelectedFilter(filter: FilterType | null) { + this.selectedFilter = filter; + } }, - }, }); diff --git a/frontend/src/stores/push.ts b/frontend/src/stores/push.ts new file mode 100644 index 0000000000..07d457eb34 --- /dev/null +++ b/frontend/src/stores/push.ts @@ -0,0 +1,61 @@ +import {CallToActionKind, PushDTO, PushKind} from "@/generated-sources"; +import {defineStore} from "pinia"; +import {api} from "@/api"; + +interface State { + pushes: { [key: string]: Pushes }; + current: Pushes | undefined; + identifier: PushIdentifier | undefined; +} + +interface Pushes { + perturbation: PushDTO, + contribution: PushDTO, + borderline: PushDTO, + overconfidence: PushDTO +} + +interface PushIdentifier { + modelId: string, + datasetId: string, + rowNb: number, + inputData: any, + modelFeatures: string[] +} + +export const usePushStore = defineStore('push', { + state: (): State => ({ + pushes: {}, + current: undefined, + identifier: undefined, + }), + getters: {}, + actions: { + async fetchPushSuggestions(modelId: string, datasetId: string, rowNb: number, inputData: any, modelFeatures: string[]) { + this.identifier = {modelId, datasetId, rowNb, inputData, modelFeatures}; + let identifierString = JSON.stringify({modelId, datasetId, rowNb, inputData, modelFeatures}); + let previous = this.current; + this.current = undefined; + + let result = await api.getPushes(modelId, datasetId, rowNb, inputData); + + if (previous == result) { + return; + } + + let currentIdentifier = JSON.stringify(this.identifier); + if (currentIdentifier === identifierString) { + console.log("Setting pushes") + // @ts-ignore + this.current = result; + } + + return this.current; + }, + async applyPush(pushKind: PushKind, ctaKind: CallToActionKind): Promise { + let result = await api.applyPush(this.identifier!.modelId, this.identifier!.datasetId, this.identifier!.rowNb, pushKind, ctaKind, this.identifier!.inputData); + // @ts-ignore + return result; + } + } +}) \ No newline at end of file diff --git a/frontend/src/utils/nametags.utils.ts b/frontend/src/utils/nametags.utils.ts new file mode 100644 index 0000000000..716b8aea4b --- /dev/null +++ b/frontend/src/utils/nametags.utils.ts @@ -0,0 +1,35 @@ +import {useProjectArtifactsStore} from "@/stores/project-artifacts"; + +export function $tags(input: string): string { + if (input == undefined) { + return input; + } + + const store = useProjectArtifactsStore(); + + // In the input, there will be tags such as: , + // The uuid is like model:b1d7dd0e-400c-421d-8b0e-721f852e77a8 + // Grab their names from the stored data and replace it! + let result = input; + const modelRegex = //g; + // Can match: , <*_dataset:uuid> + const datasetRegex = /<([a-z_]*_)?dataset:(.*?)>/g; + const modelMatches = input.matchAll(modelRegex); + const datasetMatches = input.matchAll(datasetRegex); + + for (const match of modelMatches) { + const model = store.models.find(model => model.id === match[1]); + if (model) { + result = result.replace(match[0], $tags(model.name ?? "Unnamed model")); + } + } + + for (const match of datasetMatches) { + const dataset = store.datasets.find(dataset => dataset.id === match[2]); + if (dataset) { + result = result.replace(match[0], $tags(dataset.name ?? "Unnamed dataset")); + } + } + + return result; +} \ No newline at end of file diff --git a/frontend/src/views/main/project/Datasets.vue b/frontend/src/views/main/project/Datasets.vue index 406c0e6807..0f869e424a 100644 --- a/frontend/src/views/main/project/Datasets.vue +++ b/frontend/src/views/main/project/Datasets.vue @@ -2,7 +2,8 @@
- + @@ -32,7 +33,7 @@ - @@ -101,18 +102,19 @@