
Conversation

@rabah-khalek (Contributor) commented on Jun 21, 2023

1. MLflow plug-in via evaluate

Description

Integration of mlflow via the model_evaluator plugin.

Installation requirements:

pip install mlflow # the full mlflow package is needed for .evaluate(); mlflow-skinny, which ships with giskard, is not enough
pip install "git+https://github.com/Giskard-AI/giskard.git@gsk-1321/mlflow-integration#subdirectory=python-client" -q

Code example:

import mlflow

from giskard import demo

# Two demo Titanic classifiers that differ only in the number of training iterations
model1, df = demo.titanic(max_iter=5)
model2, df = demo.titanic(max_iter=100)

with mlflow.start_run(run_name="model1") as run1:
    # Log the sklearn model, then let the giskard evaluator plug-in scan it via mlflow.evaluate
    model1_uri = mlflow.sklearn.log_model(model1, "sklearn_model1", pyfunc_predict_fn="predict_proba").model_uri
    mlflow.evaluate(model=model1_uri, model_type="classifier", data=df, targets="Survived", evaluators="giskard", evaluator_config={"classification_labels": ["no", "yes"]})

with mlflow.start_run(run_name="model2") as run2:
    model2_uri = mlflow.sklearn.log_model(model2, "sklearn_model2", pyfunc_predict_fn="predict_proba").model_uri
    mlflow.evaluate(model=model2_uri, model_type="classifier", data=df, targets="Survived", evaluators="giskard", evaluator_config={"classification_labels": ["no", "yes"]})

Running mlflow ui in the terminal, one gets:

  • the HTML results of the scan embedded as artifacts,
  • the test-suite results of the scan logged as metrics,
  • the model and dataset artifacts (a sketch for pulling these back programmatically follows this list).
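
For reference, a minimal sketch of how the logged metrics and artifacts could be pulled back programmatically once the two runs above have finished. It assumes the run1/run2 objects from the example and the default local tracking store; the exact metric and artifact names depend on what the giskard evaluator logs:

from mlflow import MlflowClient

client = MlflowClient()

for run in (run1, run2):
    run_id = run.info.run_id
    run_data = client.get_run(run_id).data
    print(run_id, run_data.metrics)                  # test-suite results logged as metrics
    for artifact in client.list_artifacts(run_id):   # e.g. the embedded HTML scan report
        print("  -", artifact.path)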

Two notebooks to test this feature:

Run them locally in order to launch the mlflow ui.

2. Giskard API via to_mlflow

Description

Logging artifacts and metrics from giskard to MLflow.

Code example:

import giskard

# giskard_model and giskard_dataset are assumed to be wrapped giskard objects (see the sketch further below)
scan_results = giskard.scan(giskard_model, giskard_dataset)
test_suite = scan_results.generate_test_suite("My first test suite")
test_suite_results = test_suite.run()

import mlflow

# Option 1 (via the fluent API)
with mlflow.start_run() as run:
    giskard_model.to_mlflow()
    giskard_dataset.to_mlflow()
    scan_results.to_mlflow()
    test_suite_results.to_mlflow()

# Option 2 (via MlflowClient)
from mlflow import MlflowClient

client = MlflowClient()
experiment_id = "0"
run = client.create_run(experiment_id)

giskard_model.to_mlflow(client, run.info.run_id) 
giskard_dataset.to_mlflow(client, run.info.run_id) 
scan_results.to_mlflow(client, run.info.run_id) 
test_suite_results.to_mlflow(client, run.info.run_id) 
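
The snippet above assumes giskard_model and giskard_dataset already exist. A minimal sketch of one way to build them from the Titanic demo used earlier -- the exact giskard.Model / giskard.Dataset constructor arguments may differ between giskard versions, so treat this as an assumption rather than the canonical API:

import giskard
from giskard import demo

model, df = demo.titanic(max_iter=100)

giskard_dataset = giskard.Dataset(df=df, target="Survived", name="titanic")
giskard_model = giskard.Model(
    model=model.predict_proba,  # prediction function wrapped by giskard (assumed wrapping style)
    model_type="classification",
    classification_labels=["no", "yes"],
    feature_names=[col for col in df.columns if col != "Survived"],
    name="titanic-classifier",
)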

Notebook to test this feature:

Run it locally in order to launch the mlflow ui:

Open questions

Todo:

Tech

  • better HTML rendering -- fixed in "Make scan widget optionally embeddable" (#1209)
  • better metrics logging
  • use native mlflow saving when possible -- implemented in #1189, commit e83667df6674c229bb922c797f338b2a5d2b4bd3
    • remove save_model from MLflowBasedModel and introduce a dataclass with flags to disable the saving/loading validation during the scan (PyFunc models are not meant to be saved, and the model going through mlflow.evaluate is a PyFunc model)
  • implement the above notebooks as functional-tests
  • tempfile.NamedTemporaryFile doesn't seem to delete files after context
  • add text as model_type
  • ability to compare two models (via scan_summary.json)
  • log scan results as a JSON file to be rendered in the artifact view of mlflow and allow comparison (a rough sketch follows this list)
  • implement to_mlflow() on the following (see doc):
    • ScanResult
    • TestSuiteResult
    • giskard.Model
    • giskard.Dataset
  • artifact view similar to doc
  • Allow people to add any Model.__init__() argument to evaluator_config, not only classification_labels
  • Better error handling and corresponding unit tests
  • run pdm lock -G:all
  • Implement EvaluationResult
  • implement telemetry tracking
  • Implement pushing an output message to the mlflow ui
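
On the scan_summary.json idea above, a rough sketch of what logging a comparable JSON summary per run could look like. mlflow.log_dict is an existing MLflow API; the scan_summary dictionary below is purely illustrative and its keys are not an actual giskard schema:

import mlflow

# Illustrative summary only -- in practice this would be derived from the scan results
scan_summary = {
    "model": "sklearn_model1",
    "issues": [
        {"detector": "PerformanceBiasDetector", "level": "major", "slice": "Sex == 'male'"},
    ],
}

with mlflow.start_run(run_name="model1-scan-summary"):
    # Stored under the run's artifacts so the MLflow UI can render it and diff it across runs
    mlflow.log_dict(scan_summary, "scan_summary.json")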

Writing

LLM model comparison
import pandas as pd
from langchain import PromptTemplate, LLMChain
from langchain.llms import OpenAI
import mlflow
import openai
import os

df = pd.read_csv('https://raw.githubusercontent.com/sunnysai12345/News_Summary/master/news_summary_more.csv')
df_filtered = pd.DataFrame(df["text"].sample(10, random_state=11))

prompt = PromptTemplate(template="Create a reader comment according to the following article summary: '{text}'",
                        input_variables=["text"])

openai.api_key = os.getenv("OPENAI_API_KEY")

llm1 = OpenAI(openai_api_key=openai.api_key,
              request_timeout=20,
              max_retries=100,
              temperature=0,
              model_name="text-ada-001")  # Possibility to select another model

chain1 = LLMChain(prompt=prompt, llm=llm1)

with mlflow.start_run(run_name="text-ada-001") as run1:
    model_uri = mlflow.langchain.log_model(chain1, "langchain").model_uri
    mlflow.evaluate(model=model_uri, model_type="text", data=df_filtered, evaluators="giskard")

llm2 = OpenAI(openai_api_key=openai.api_key,
              request_timeout=20,
              max_retries=100,
              temperature=0,
              model_name="text-embedding-ada-002")  # Possibility to select another model

chain2 = LLMChain(prompt=prompt, llm=llm2)

with mlflow.start_run(run_name="text-embedding-ada-002") as run2:
    model_uri = mlflow.langchain.log_model(chain2, "langchain").model_uri
    mlflow.evaluate(model=model_uri, model_type="text", data=df_filtered, evaluators="giskard")
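
To compare the two runs side by side without opening the UI, one possible follow-up is to query the tracking store directly. This assumes both runs were logged to the active (default) experiment; metric columns appear with a "metrics." prefix:

import mlflow

runs = mlflow.search_runs()  # one row per run in the active experiment (pandas DataFrame)
cols = ["tags.mlflow.runName"] + [c for c in runs.columns if c.startswith("metrics.")]
print(runs[cols])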

Screenshot 2023-07-10 at 15 01 19

  • doc
  • Write an article

Optional

  • Populate MLflow's evaluation examples with the giskard evaluator.
  • Capture the error log giskard.scanner.logger ERROR "Detector LLMToxicityDetector failed with error: 'PyFuncModel' object has no attribute 'rewrite_prompt'" and output instead something like giskard.scanner.logger ERROR "Detector LLMToxicityDetector is not supported by the giskard plug-in through mlflow.evaluate"

@rabah-khalek rabah-khalek self-assigned this Jun 21, 2023
@linear (bot) commented Jun 21, 2023

@rabah-khalek rabah-khalek marked this pull request as draft June 21, 2023 14:49
rabah-khalek and others added 15 commits July 11, 2023 22:40
# Conflicts:
#	python-client/giskard/core/suite.py
#	python-client/giskard/models/__init__.py
#	python-client/giskard/models/base/__init__.py
#	python-client/giskard/models/catboost/__init__.py
#	python-client/giskard/models/huggingface/__init__.py
#	python-client/giskard/models/langchain.py
#	python-client/giskard/models/sklearn/__init__.py
#	python-client/giskard/scanner/result.py
#	python-client/tests/models/automodel/test_infer_giskard_cls.py
@andreybavt (Contributor) commented:

> @andreybavt I am currently wrapping the PyFunc model I have access to in mlflow.evaluate with a custom giskard CloudPickleBasedModel. I only needed to customise model_predict, which was the advantage, since PyFunc standardises predict().
>
> The drawback of not having access to the underlying model object is that we can't run all scan detectors. This has mainly happened only once, for the toxicity detector, where we needed the underlying langchain model object to regenerate prompts, see:
>
> https://github.com/Giskard-AI/giskard/blob/116e1fd4f652cedcc0497a54eb7a78ecba089c89/python-client/giskard/models/langchain/__init__.py#L51-L71
>
> We can put a pin in this, but I think we might be forced to re-think the unwrapping of the model object from the artifact logged by MLflow for our own purposes (scan).

I think we can look at the model_type argument (or the value of model_type in model_config) to wrap it either with just a PyFuncModel or with a LangchainModel, and then be able to call rewrite_prompt in case it's a text_generation model (a rough sketch of this dispatch follows).
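
A rough illustration of the dispatch suggested here -- the helper name and how model_type would be read from the evaluator config are assumptions, not the actual implementation; PyFuncModel and LangchainModel refer to the wrappers discussed above:

def wrap_for_scan(pyfunc_model, model_type, **kwargs):
    # Hypothetical helper: choose the giskard wrapper based on the declared model type,
    # so text-generation models keep access to prompt rewriting (rewrite_prompt).
    if model_type == "text_generation":
        return LangchainModel(pyfunc_model, **kwargs)  # wrapper exposing rewrite_prompt
    return PyFuncModel(pyfunc_model, **kwargs)         # generic wrapper, predict() only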

@rabah-khalek (Contributor, Author) commented:

> @andreybavt I am currently wrapping the PyFunc model I have access to in mlflow.evaluate with a custom giskard CloudPickleBasedModel. I only needed to customise model_predict, which was the advantage, since PyFunc standardises predict().
>
> The drawback of not having access to the underlying model object is that we can't run all scan detectors. This has mainly happened only once, for the toxicity detector, where we needed the underlying langchain model object to regenerate prompts, see:
> https://github.com/Giskard-AI/giskard/blob/116e1fd4f652cedcc0497a54eb7a78ecba089c89/python-client/giskard/models/langchain/__init__.py#L51-L71
>
> We can put a pin in this, but I think we might be forced to re-think the unwrapping of the model object from the artifact logged by MLflow for our own purposes (scan).
>
> I think we can look at the model_type argument (or the value of model_type in model_config) to wrap it either with just a PyFuncModel or with a LangchainModel, and then be able to call rewrite_prompt in case it's a text_generation model.

As it's not a blocker, I think it's a good idea to merge this branch and to take care of this point in a new one. This will unlock some marketing actions. WDYT?

@sonarqubecloud commented:

Kudos, SonarCloud Quality Gate passed!

  • Bugs: 0 (rated A)
  • Vulnerabilities: 0 (rated A)
  • Security Hotspots: 0 (rated A)
  • Code Smells: 1 (rated A)
  • Coverage: 80.4%
  • Duplication: 0.0%

@rabah-khalek rabah-khalek merged commit 38da0e5 into main Jul 26, 2023
@rabah-khalek rabah-khalek added Python Pull requests that update Python code Integrations labels Aug 2, 2023
@Hartorn Hartorn deleted the gsk-1321/mlflow-integration branch September 22, 2023 10:47