Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions docs/reference/notebooks/wage_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"collapsed": false
},
"source": [
"# German credit scoring [scikit-learn]\n",
"# Wage classification [scikit-learn]\n",
"\n",
"Giskard is an open-source framework for testing all ML models, from LLMs to tabular models. Don’t hesitate to give the project a [star on GitHub](https://github.com/Giskard-AI/giskard) ⭐️ if you find it useful!\n",
"\n",
Expand Down Expand Up @@ -75,11 +75,11 @@
"from urllib.request import urlretrieve\n",
"\n",
"import pandas as pd\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"\n",
"from giskard import Model, Dataset, scan, testing, GiskardClient, Suite"
Expand Down Expand Up @@ -111,8 +111,8 @@
"TEST_RATIO = 0.2\n",
"\n",
"DROP_FEATURES = [\n",
" 'education', \n",
" 'native-country', \n",
" 'education',\n",
" 'native-country',\n",
" 'occupation',\n",
" 'marital-status',\n",
" 'educational-num'\n",
Expand Down Expand Up @@ -229,7 +229,7 @@
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(income_df.drop(columns=TARGET_COLUMN), income_df[TARGET_COLUMN], \n",
"X_train, X_test, y_train, y_test = train_test_split(income_df.drop(columns=TARGET_COLUMN), income_df[TARGET_COLUMN],\n",
" test_size=TEST_RATIO, random_state=RANDOM_SEED)"
]
},
Expand Down Expand Up @@ -257,10 +257,12 @@
"source": [
"raw_data = pd.concat([X_test, y_test], axis=1)\n",
"giskard_dataset = Dataset(\n",
" df=raw_data, # A pandas.DataFrame that contains the raw data (before all the pre-processing steps) and the actual ground truth variable (target).\n",
" df=raw_data,\n",
" # A pandas.DataFrame that contains the raw data (before all the pre-processing steps) and the actual ground truth variable (target).\n",
" target=TARGET_COLUMN, # Ground truth variable.\n",
" name=\"salary_data\", # Optional.\n",
" cat_columns=CATEGORICAL_FEATURES # List of categorical columns. Optional, but is a MUST if available. Inferred automatically if not.\n",
" cat_columns=CATEGORICAL_FEATURES\n",
" # List of categorical columns. Optional, but is a MUST if available. Inferred automatically if not.\n",
")"
]
},
Expand Down Expand Up @@ -351,7 +353,8 @@
"outputs": [],
"source": [
"giskard_model = Model(\n",
" model=pipeline, # A prediction function that encapsulates all the data pre-processing steps and that could be executed with the dataset used by the scan.\n",
" model=pipeline,\n",
" # A prediction function that encapsulates all the data pre-processing steps and that could be executed with the dataset used by the scan.\n",
" model_type=\"classification\", # Either regression, classification or text_generation.\n",
" name=\"salary_cls\", # Optional.\n",
" classification_labels=pipeline.classes_, # Their order MUST be identical to the prediction_function's output order.\n",
Expand Down Expand Up @@ -618,7 +621,7 @@
"outputs": [],
"source": [
"# Create a Giskard client after having install the Giskard server (see documentation)\n",
"api_key = \"<Giskard API key>\" #This can be found in the Settings tab of the Giskard hub\n",
"api_key = \"<Giskard API key>\" #This can be found in the Settings tab of the Giskard hub\n",
"#hf_token = \"<Your Giskard Space token>\" #If the Giskard Hub is installed on HF Space, this can be found on the Settings tab of the Giskard Hub\n",
"\n",
"client = GiskardClient(\n",
Expand Down
2 changes: 0 additions & 2 deletions giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ class Dataset(ColumnMetadataMixin):
column_types (Optional[Dict[str, str]]):
A dictionary of column names and their types (numeric, category or text) for all columns of df. If not provided,
the categorical columns will be automatically inferred.
data_processor (DataProcessor):
An instance of the `DataProcessor` class used for data processing.
"""

name: Optional[str]
Expand Down