Merged
39 changes: 20 additions & 19 deletions python-client/giskard/models/base/wrapper.py
@@ -6,10 +6,10 @@
from typing import Any, Callable, Iterable, Optional, Union

import cloudpickle
import mlflow
import numpy as np
import pandas as pd
import yaml
import mlflow

from ...core.core import ModelType
from ...core.validation import configured_validate_arguments
@@ -31,18 +31,18 @@ class WrapperModel(BaseModel, ABC):

@configured_validate_arguments
def __init__(
self,
model: Any,
model_type: ModelType,
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
name: Optional[str] = None,
feature_names: Optional[Iterable] = None,
classification_threshold: Optional[float] = 0.5,
classification_labels: Optional[Iterable] = None,
id: Optional[str] = None,
batch_size: Optional[int] = None,
**kwargs,
self,
model: Any,
model_type: ModelType,
data_preprocessing_function: Optional[Callable[[pd.DataFrame], Any]] = None,
model_postprocessing_function: Optional[Callable[[Any], Any]] = None,
name: Optional[str] = None,
feature_names: Optional[Iterable] = None,
classification_threshold: Optional[float] = 0.5,
classification_labels: Optional[Iterable] = None,
id: Optional[str] = None,
batch_size: Optional[int] = None,
**kwargs,
) -> None:
"""
Parameters
@@ -130,7 +130,12 @@ def predict_df(self, df: pd.DataFrame):
output = self._postprocess(output)
outputs.append(output)

return np.concatenate(outputs)
raw_prediction = np.concatenate(outputs)

if self.is_regression:
return raw_prediction.astype(float)

return raw_prediction

def _possibly_fix_predictions_shape(self, raw_predictions: np.ndarray):
if not self.is_classification:
@@ -292,15 +297,11 @@ def load_wrapper_meta(cls, local_dir):
# ensuring backward compatibility
return {"batch_size": None}

def to_mlflow(self,
artifact_path: str = "prediction-function-from-giskard",
**kwargs):

def to_mlflow(self, artifact_path: str = "prediction-function-from-giskard", **kwargs):
def _giskard_predict(df):
return self.predict(df)

class MLflowModel(mlflow.pyfunc.PythonModel):

def predict(self, df):
return _giskard_predict(df)

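The to_mlflow change above wraps the model's predict function in an mlflow.pyfunc.PythonModel subclass. Below is a minimal sketch of that wrapping pattern, not Giskard's exact implementation: my_predict and PredictionFunctionModel are placeholder names, and the classic (context, model_input) pyfunc signature is used for portability across MLflow versions.

import mlflow
import pandas as pd


def my_predict(df: pd.DataFrame) -> pd.Series:
    # Placeholder standing in for WrapperModel.predict: one value per row.
    return df.sum(axis=1)


class PredictionFunctionModel(mlflow.pyfunc.PythonModel):
    # pyfunc adapter: delegate MLflow's predict call to the wrapped function.
    def predict(self, context, model_input):
        return my_predict(model_input)


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="prediction-function-from-giskard",  # default artifact path used in the diff above
        python_model=PredictionFunctionModel(),
    )
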
68 changes: 41 additions & 27 deletions python-client/giskard/models/cache/cache.py
@@ -1,14 +1,13 @@
import csv
import os
import uuid
from pathlib import Path
from typing import Dict, List, Any, Iterable, Optional
from typing import Any, Iterable, List, Optional

import numpy as np
import pandas as pd

from giskard.core.core import SupportedModelTypes
from giskard.settings import settings
from ...client.python_utils import warning
from ...core.core import SupportedModelTypes
from ...settings import settings

NaN = float("NaN")

@@ -24,43 +23,58 @@ def flatten(xs):


class ModelCache:
id: Optional[str] = None
prediction_cache: Dict[str, Any] = None
_default_cache_dir_prefix = Path(settings.home_dir / settings.cache_dir / "global" / "prediction_cache")

vectorized_get_cache_or_na = None

def __init__(self, model_type: SupportedModelTypes, id: Optional[str] = None, cache_dir: Path = None):
self.id = id or str(uuid.uuid4())
def __init__(self, model_type: SupportedModelTypes, id: Optional[str] = None, cache_dir: Optional[Path] = None):
self.id = id
self.prediction_cache = dict()
self.cache_dir = cache_dir or Path(settings.home_dir / settings.cache_dir / "global/prediction_cache" / self.id)

if id is not None:
if (self.cache_dir / CACHE_CSV_FILENAME).exists():
with open(self.cache_dir / CACHE_CSV_FILENAME, "r") as pred_f:
reader = csv.reader(pred_f)
for row in reader:
if model_type == SupportedModelTypes.TEXT_GENERATION:
self.prediction_cache[row[0]] = row[1:]
elif model_type == SupportedModelTypes.REGRESSION:
self.prediction_cache[row[0]] = float(row[1])
else:
self.prediction_cache[row[0]] = [float(i) for i in row[1:]]

if cache_dir is None and self.id:
cache_dir = self._default_cache_dir_prefix.joinpath(self.id)

self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None

self.vectorized_get_cache_or_na = np.vectorize(self.get_cache_or_na, otypes=[object])
self.model_type = model_type
self._warmed_up = False

def warm_up_from_disk(self):
Review thread on warm_up_from_disk (a usage sketch of this lazy warm-up follows the cache.py diff below):

Contributor: Why not warm up at creation time instead of read_from_cache?

Member Author (mattbit): Because I don't want to warm up if the cache is disabled. But it's still better to create the ModelCache object in the model, so that it can be used if the cache is enabled at a later time.

Contributor: I think in case the cache is disabled the cleanest will be not to initialize the _cache property at all and then warm it up in the ModelCache constructor depending on the cache type (in-memory only vs FS-backed), WDYT?

Member Author (mattbit, Aug 25, 2023): But the cache can be disabled temporarily. So if we don't initialize the _cache property at all, we would need to check if it exists at every prediction, initializing it there and possibly warming it up at prediction time, in addition to checking that the cache is enabled. That's because the cache could have been disabled when the model was initialized, but enabled later. In any case I would avoid doing expensive and persistent operations upon construction of the instance.

Contributor: ok, let's keep it this way

if self.cache_file is None or not self.cache_file.exists():
return

try:
with self.cache_file.open("r", newline="") as pred_f:
reader = csv.reader(pred_f)
for row in reader:
if self.model_type == SupportedModelTypes.TEXT_GENERATION:
# Text generation models output should be a single string
self.prediction_cache[row[0]] = row[1]
elif self.model_type == SupportedModelTypes.REGRESSION:
# Regression model output is always cast to float
self.prediction_cache[row[0]] = float(row[1])
else:
# Classification models return list of probabilities
self.prediction_cache[row[0]] = [float(i) for i in row[1:]]
except Exception as e:
warning(f"Failed to load cache from disk for model {self.id}: {e}")

def get_cache_or_na(self, key: str):
return self.prediction_cache.get(key, NaN)

def read_from_cache(self, keys: pd.Series):
if self.id and not self._warmed_up:
self.warm_up_from_disk()
self._warmed_up = True

return pd.Series(self.vectorized_get_cache_or_na(keys), index=keys.index)

def set_cache(self, keys: pd.Series, values: List[Any]):
for i in range(len(keys)):
self.prediction_cache[keys.iloc[i]] = values[i]

if self.id:
os.makedirs(self.cache_dir, exist_ok=True)
with open(self.cache_dir / CACHE_CSV_FILENAME, "a") as pred_f:
if self.cache_file is not None:
self.cache_file.parent.mkdir(parents=True, exist_ok=True)
with self.cache_file.open("a", newline="") as pred_f:
writer = csv.writer(pred_f)
for i in range(len(keys)):
writer.writerow(flatten([keys.iloc[i], values[i]]))
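The review thread above settles on lazy warm-up: the cache file is only read on the first read_from_cache call, not in the constructor. Here is a minimal usage sketch of that behavior, assuming the giskard python-client from this PR is installed; the id string and cached values are placeholders, and the default cache directory under the Giskard settings home is used.

import pandas as pd

from giskard.core.core import SupportedModelTypes
from giskard.models.cache import ModelCache

# Constructing the cache does no disk reads: warm_up_from_disk is deferred.
cache = ModelCache(SupportedModelTypes.REGRESSION, id="my-model-id")

keys = pd.Series(["key-1", "key-2"])
cache.set_cache(keys, [1.0, 2.0])  # kept in memory and appended to the CSV cache file

# The first read triggers warm_up_from_disk() once; later reads hit memory only.
print(cache.read_from_cache(keys).tolist())  # [1.0, 2.0]

# Unknown keys come back as NaN instead of raising.
print(cache.read_from_cache(pd.Series(["missing-key"])).tolist())  # [nan]
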
31 changes: 30 additions & 1 deletion python-client/tests/models/test_model_cache.py
@@ -4,13 +4,42 @@
import numpy as np
import pandas as pd
import xxhash
from langchain import LLMChain, PromptTemplate
from langchain.llms.fake import FakeListLLM

import giskard
from giskard import Model, Dataset
from giskard import Dataset, Model
from giskard.core.core import SupportedModelTypes
from giskard.models.cache import ModelCache


def test_model_prediction_is_cached_on_text_generation_model():
llm = FakeListLLM(responses=['This is my text with special chars" → ,.!? # and \n\nnewlines', "This is my text"])

prompt = PromptTemplate(template="{instruct}", input_variables=["instruct"])
chain = LLMChain(llm=llm, prompt=prompt)
model = Model(chain, model_type="text_generation")
dataset = Dataset(
pd.DataFrame({"instruct": ["Test 1", "Test 2"]}),
column_types={
"instruct": "text",
},
)
model.predict(dataset)

# Should load from cache
assert model.predict(dataset).raw_prediction.tolist() == llm.responses
assert model.predict(dataset).raw_prediction.tolist() == llm.responses

# Test cache persistence
model_id = model.id

del model
model = Model(chain, model_type="text_generation", id=model_id.hex)

assert model.predict(dataset).raw_prediction.tolist() == llm.responses


def test_model_prediction_is_cached_on_regression_model():
called_indexes = []
