Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion giskard-ml-worker/ml_worker/utils/grpc_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ml_worker.core.giskard_dataset import GiskardDataset
from ml_worker.core.model import GiskardModel
from ml_worker_pb2 import SerializedGiskardModel, SerializedGiskardDataset
from generated.ml_worker_pb2 import SerializedGiskardModel, SerializedGiskardDataset


def deserialize_model(serialized_model: SerializedGiskardModel) -> GiskardModel:
Expand Down
209 changes: 67 additions & 142 deletions giskard-ml-worker/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions giskard-ml-worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ numpy = "^1.21.6"
#torch = ">=1.10.0"
#tensorflow = ">=2.0.0"
#catboost = ">=1.0.6"
nlpaug = "^1.1.11"

[tool.poetry.dev-dependencies]
grpcio-tools = ">=1.46.3"
Expand Down
98 changes: 98 additions & 0 deletions giskard-ml-worker/test/fixtures/enron_multilabel_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging
import time

import pandas as pd
import pytest
from sklearn import model_selection
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

from ml_worker.core.giskard_dataset import GiskardDataset
from ml_worker.core.model import GiskardModel
from test import path

# Feature name -> feature kind for the Enron e-mail dataset.
# Insertion order matters: the fixtures below derive the model's feature-name
# list from this mapping via ``list(input_types)``.
input_types = {
    # free-text features
    "Subject": "text",
    "Content": "text",
    # categorical features
    "Week_day": "category",
    "Month": "category",
    # numeric features
    "Hour": "numeric",
    "Nb_of_forwarded_msg": "numeric",
    "Year": "numeric",
}


@pytest.fixture()
def enron_data() -> GiskardDataset:
    """Load the Enron CSV and wrap it as a labelled ``GiskardDataset``.

    The frame is read from ``test_data/enron_data.csv`` (resolved through the
    test package's ``path`` helper). ``Target`` is declared as the target
    column and the module-level ``input_types`` mapping as the feature types.
    """
    logging.info("Fetching Enron Data")
    frame = pd.read_csv(path('test_data/enron_data.csv'))
    return GiskardDataset(df=frame, target='Target', feature_types=input_types)


@pytest.fixture()
def enron_test_data(enron_data):
    """Return the Enron dataset with the ``Target`` column removed.

    Builds a new ``GiskardDataset`` from a copy of ``enron_data.df`` without
    its ``Target`` column, keeping the same feature types and declaring no
    target.
    """
    unlabelled = pd.DataFrame(enron_data.df).drop(columns=['Target'])
    return GiskardDataset(df=unlabelled, feature_types=input_types, target=None)


@pytest.fixture()
def enron_model(enron_data) -> GiskardModel:
    """Train a logistic-regression pipeline on the Enron data and wrap it.

    The pipeline routes columns through a ``ColumnTransformer``:

    * columns typed ``"numeric"`` in ``input_types``: median imputation, then
      standard scaling;
    * columns typed ``"category"``: constant ``'missing'`` imputation, then
      one-hot encoding;
    * the ``"Content"`` column: bag-of-words counts, then TF-IDF weighting.

    ``"Subject"`` is declared in ``input_types`` but is not routed to any
    transformer here.

    The data is split 80/20 (stratified on ``Target``, fixed random state).
    The classifier is fitted on the training part, and its score on the
    held-out part is logged together with the elapsed time.

    Returns:
        A ``GiskardModel`` whose prediction function is the fitted pipeline's
        ``predict_proba``.
    """
    # The timer deliberately covers pipeline construction as well as fitting.
    start = time.time()

    numeric_columns = [name for name, kind in input_types.items() if kind == "numeric"]
    numeric_steps = [
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ]
    numeric_transformer = Pipeline(numeric_steps)

    categorical_columns = [name for name, kind in input_types.items() if kind == "category"]
    categorical_steps = [
        ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
        ('onehot', OneHotEncoder(handle_unknown='ignore', sparse=False)),
    ]
    categorical_transformer = Pipeline(categorical_steps)

    text_transformer = Pipeline([('vect', CountVectorizer()), ('tfidf', TfidfTransformer())])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numeric_columns),
            ('cat', categorical_transformer, categorical_columns),
            ('text_Mail', text_transformer, "Content"),
        ]
    )
    classifier = LogisticRegression(max_iter=100)
    clf = Pipeline(steps=[('preprocessor', preprocessor), ('classifier', classifier)])

    labels = enron_data.df['Target']
    features = enron_data.df.drop(columns="Target")
    split = model_selection.train_test_split(  # NOSONAR
        features,
        labels,
        test_size=0.20,
        random_state=30,
        stratify=labels,
    )
    train_features, test_features, train_labels, test_labels = split
    clf.fit(train_features, train_labels)

    train_time = time.time() - start
    model_score = clf.score(test_features, test_labels)
    logging.info(f"Trained model with score: {model_score} in {round(train_time * 1000)} ms")

    return GiskardModel(
        prediction_function=clf.predict_proba,
        model_type='classification',
        feature_names=list(input_types),
        classification_threshold=0.5,
        classification_labels=clf.classes_,
    )
Loading