From 53baf51e158c54b26ca3d0b8b768443ad858d64a Mon Sep 17 00:00:00 2001 From: Matteo Dora Date: Thu, 6 Jul 2023 12:40:41 +0200 Subject: [PATCH] Split train and test set in fraud detection fixtures --- .../fraud_detection__binary_classification.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/python-client/tests/fixtures/fraud_detection__binary_classification.py b/python-client/tests/fixtures/fraud_detection__binary_classification.py index 7623b7c707..8d1c10d1cd 100644 --- a/python-client/tests/fixtures/fraud_detection__binary_classification.py +++ b/python-client/tests/fixtures/fraud_detection__binary_classification.py @@ -1,9 +1,10 @@ import os from pathlib import Path -import pytest import pandas as pd +import pytest from pandas.api.types import union_categoricals +from sklearn.model_selection import train_test_split from giskard import Dataset, Model from tests.url_utils import fetch_from_ftp @@ -135,25 +136,36 @@ def preprocess_dataset(train_set, test_set): # Remove useless columns. united.drop("TransactionDT", axis=1, inplace=True) - return united + # Split in train/test sets + train_set, test_set = train_test_split(united, test_size=0.5, random_state=41) + + return train_set, test_set @pytest.fixture() def fraud_detection_data() -> Dataset: - # Download dataset. - raw_data = preprocess_dataset(*read_dataset()) + _, test_set = preprocess_dataset(*read_dataset()) + wrapped_dataset = Dataset( + test_set, name="fraud_detection_adversarial_dataset", target=TARGET_COLUMN, cat_columns=CATEGORICALS + ) + return wrapped_dataset + + +@pytest.fixture() +def fraud_detection_train_data() -> Dataset: + train_set, _ = preprocess_dataset(*read_dataset()) wrapped_dataset = Dataset( - raw_data, name="fraud_detection_adversarial_dataset", target=TARGET_COLUMN, cat_columns=CATEGORICALS + train_set, name="fraud_detection_adversarial_dataset", target=TARGET_COLUMN, cat_columns=CATEGORICALS ) return wrapped_dataset @pytest.fixture() -def fraud_detection_model(fraud_detection_data: Dataset) -> Model: +def fraud_detection_model(fraud_detection_train_data: Dataset) -> Model: from lightgbm import LGBMClassifier - x = fraud_detection_data.df.drop(TARGET_COLUMN, axis=1) - y = fraud_detection_data.df[TARGET_COLUMN] + x = fraud_detection_train_data.df.drop(TARGET_COLUMN, axis=1) + y = fraud_detection_train_data.df[TARGET_COLUMN] estimator = LGBMClassifier() estimator.fit(x, y)