Skip to content
82 changes: 54 additions & 28 deletions giskard/models/base/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Iterable, List, Optional, Tuple, Type, Union

import builtins
import importlib
import logging
Expand All @@ -15,9 +13,12 @@
import numpy as np
import pandas as pd
import yaml
from typing import Iterable, List, Optional, Tuple, Type, Union

from giskard.client.dtos import ModelMetaInfo

from .model_prediction import ModelPredictionResults
from ..cache import get_cache_enabled
from ..utils import np_types_to_native
from ...client.giskard_client import GiskardClient
from ...core.core import ModelMeta, ModelType, SupportedModelTypes
from ...core.validation import configured_validate_arguments
Expand All @@ -27,9 +28,6 @@
from ...models.cache import ModelCache
from ...path_utils import get_size
from ...settings import settings
from ..cache import get_cache_enabled
from ..utils import np_types_to_native
from .model_prediction import ModelPredictionResults

META_FILENAME = "giskard-model-meta.yaml"

Expand Down Expand Up @@ -174,14 +172,42 @@ def __init__(
def name(self):
    """Return the model's display name.

    Returns:
        The name stored in the model metadata (``self.meta.name``) when it
        is not None, otherwise the name of the instance's class.
    """
    meta_name = self.meta.name
    # Only None triggers the fallback; an empty string is returned as-is.
    if meta_name is None:
        return self.__class__.__name__
    return meta_name

@property
def description(self):
    """Return the description stored in the model metadata.

    Read-only shortcut for ``self.meta.description``.
    """
    return self.meta.description

@property
def model_type(self):
    """Return the model type stored in the model metadata.

    Read-only shortcut for ``self.meta.model_type``.
    """
    return self.meta.model_type

@property
def feature_names(self):
    """Return the feature names stored in the model metadata.

    Read-only shortcut for ``self.meta.feature_names``. The very same object
    held by the metadata is returned (no copy is made), so in-place changes
    made by a caller are visible through ``self.meta`` as well.
    """
    return self.meta.feature_names

@property
def classification_labels(self):
    """Return the classification labels stored in the model metadata.

    Read-only shortcut for ``self.meta.classification_labels``.
    """
    return self.meta.classification_labels

@property
def loader_class(self):
    """Return the loader class recorded in the model metadata.

    Read-only shortcut for ``self.meta.loader_class``.
    """
    return self.meta.loader_class

@property
def loader_module(self):
    """Return the loader module recorded in the model metadata.

    Read-only shortcut for ``self.meta.loader_module``.
    """
    return self.meta.loader_module

@property
def classification_threshold(self):
    """Return the classification threshold stored in the model metadata.

    Read-only shortcut for ``self.meta.classification_threshold``.
    """
    return self.meta.classification_threshold

@property
def is_classification(self) -> bool:
    """Compute if the model is of type classification.

    Returns:
        bool: True if the model is of type classification, False otherwise
    """
    # ``model_type`` reads ``self.meta.model_type``, so this is the same
    # comparison expressed through the public accessor.
    return self.model_type == SupportedModelTypes.CLASSIFICATION

@property
def is_binary_classification(self) -> bool:
Expand All @@ -191,7 +217,7 @@ def is_binary_classification(self) -> bool:
bool: True if the model is of type binary classification, False otherwise.
"""

return self.is_classification and len(self.meta.classification_labels) == 2
return self.is_classification and len(self.classification_labels) == 2

@property
def is_regression(self) -> bool:
Expand All @@ -200,7 +226,7 @@ def is_regression(self) -> bool:
Returns:
bool: True if the model is of type regression, False otherwise.
"""
return self.meta.model_type == SupportedModelTypes.REGRESSION
return self.model_type == SupportedModelTypes.REGRESSION

@property
def is_text_generation(self) -> bool:
Expand All @@ -209,7 +235,7 @@ def is_text_generation(self) -> bool:
Returns:
bool: True if the model is of type text generation, False otherwise.
"""
return self.meta.model_type == SupportedModelTypes.TEXT_GENERATION
return self.model_type == SupportedModelTypes.TEXT_GENERATION

@classmethod
def determine_model_class(
Expand All @@ -236,15 +262,15 @@ def save_meta(self, local_path, *_args, **_kwargs):
{
"language_version": platform.python_version(),
"language": "PYTHON",
"model_type": self.meta.model_type.name.upper(),
"threshold": self.meta.classification_threshold,
"feature_names": self.meta.feature_names,
"classification_labels": self.meta.classification_labels,
"loader_module": self.meta.loader_module,
"loader_class": self.meta.loader_class,
"model_type": self.model_type.name.upper(),
"threshold": self.classification_threshold,
"feature_names": self.feature_names,
"classification_labels": self.classification_labels,
"loader_module": self.loader_module,
"loader_class": self.loader_class,
"id": str(self.id),
"name": self.meta.name,
"description": self.meta.description,
"name": self.name,
"description": self.description,
"size": get_size(local_path),
},
f,
Expand Down Expand Up @@ -288,18 +314,18 @@ def prepare_dataframe(self, df, column_dtypes=None, target=None, *_args, **_kwar
df.drop(target, axis=1, inplace=True)
if column_dtypes and target in column_dtypes:
del column_dtypes[target]
if target and self.meta.feature_names and target in self.meta.feature_names:
self.meta.feature_names.remove(target)
if target and self.feature_names and target in self.feature_names:
self.feature_names.remove(target)

if self.meta.feature_names:
if set(self.meta.feature_names) > set(df.columns):
column_names = set(self.meta.feature_names) - set(df.columns)
if self.feature_names:
if set(self.feature_names) > set(df.columns):
column_names = set(self.feature_names) - set(df.columns)
raise ValueError(
f"The following columns are not found in the dataset: {', '.join(sorted(column_names))}"
)
df = df[self.meta.feature_names]
df = df[self.feature_names]
if column_dtypes:
column_dtypes = {k: v for k, v in column_dtypes.items() if k in self.meta.feature_names}
column_dtypes = {k: v for k, v in column_dtypes.items() if k in self.feature_names}

for cname, ctype in column_dtypes.items():
if cname not in df:
Expand Down Expand Up @@ -348,8 +374,8 @@ def predict(self, dataset: Dataset, *_args, **_kwargs) -> ModelPredictionResults
prediction=raw_prediction, raw_prediction=raw_prediction, raw=raw_prediction
)
elif self.is_classification:
labels = np.array(self.meta.classification_labels)
threshold = self.meta.classification_threshold
labels = np.array(self.classification_labels)
threshold = self.classification_threshold

if threshold is not None and len(labels) == 2:
predicted_lbl_idx = (raw_prediction[:, 1] > threshold).astype(int)
Expand All @@ -369,7 +395,7 @@ def predict(self, dataset: Dataset, *_args, **_kwargs) -> ModelPredictionResults
all_predictions=all_predictions,
)
else:
raise ValueError(f"Prediction task is not supported: {self.meta.model_type}")
raise ValueError(f"Prediction task is not supported: {self.model_type}")
timer.stop(f"Predicted dataset with shape {dataset.df.shape}")
return result

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_catboost_changed_column_order(german_credit_test_data, german_credit_ca
german_credit_test_data.df = df.reindex(df.columns[::-1], axis=1)

# reset feature names to test the behaviour when they're not provided
german_credit_catboost.feature_names = None
german_credit_catboost.meta.feature_names = None

res = german_credit_catboost.predict(german_credit_test_data)
assert len(res.prediction) == len(german_credit_test_data.df)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_model_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_explain(ds_name: str, model_name: str, include_feature_names: bool, req

# Try without feature names, it should also work
if not include_feature_names:
model.feature_names = None
model.meta.feature_names = None

explanations = explain(model, ds, ds.df.iloc[0].to_dict())

Expand All @@ -51,7 +51,7 @@ def test_explain(ds_name: str, model_name: str, include_feature_names: bool, req


def test_explain_shuffle_columns(german_credit_test_data, german_credit_model):
german_credit_model.feature_names = None
german_credit_model.meta.feature_names = None
ds = german_credit_test_data
# change column order
res = explain(german_credit_model, ds, ds.df.iloc[0][ds.df.columns[::-1]].to_dict())
Expand Down