Commit f0067b8

Add TextClassification, UMAP, DBSCAN and TextClustering tasks (#948)

* Redirect import of task
* Add icon for text classification
* Add text classification task
* Add tests for text classification
* Continue with this problematic thing until we merge it in one of the PRs
* Port itertools.batched function for python<3.12 (see the sketch below)
* Make the template for text classification more generic
* Add tests for the extra flexibility in the template
* Fix condition to determine the backend for the structured output
* Simplify condition for json schema in structured output
* Add folder for clustering related steps
* Fix default structured output for inference endpoints
* Add examples to the docstrings
* Add icon for clustering steps/tasks
* Add umap step
* Add dbscan step
* Redirect import of steps
* Add text clustering task
* Set default value for repo_id to avoid potential errors when loading the dataset
* Change example dataset in docstrings to one with more information
* Add unit tests for clustering steps
* Remove unnecessary extra log message
* Add tests for text clustering process
* Update pyproject with dependencies of text_clustering
* Set internal variables to None on unload to clean up
1 parent 28ecbc4 commit f0067b8
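
One of the bullets above ports `itertools.batched` for Python < 3.12. The exact helper added in this commit is not shown in this view, but a backport along the lines of the recipe in the Python docs would look roughly like this (name and placement are assumptions):

from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batched(iterable: Iterable[T], n: int) -> Iterator[Tuple[T, ...]]:
    # Mirrors itertools.batched (Python 3.12+): yield successive n-sized
    # tuples; the final batch may be shorter than n.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch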

File tree

18 files changed: +1406 −6 lines


pyproject.toml

Lines changed: 5 additions & 0 deletions

@@ -95,6 +95,11 @@ vllm = [
 sentence-transformers = ["sentence-transformers >= 3.0.0"]
 faiss-cpu = ["faiss-cpu >= 1.8.0"]
 faiss-gpu = ["faiss-gpu >= 1.7.2"]
+text-clustering = [
+    "umap-learn >= 0.5.6",
+    "scikit-learn >= 1.4.1",
+    "matplotlib >= 3.8.3",  # For the figure (even though it's optional)
+]
 
 # minhash
 minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
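
With this extra defined, the clustering dependencies can be pulled in through the package extra, e.g. pip install "distilabel[text-clustering]" (a standard pip invocation, assuming a released or locally checked-out package). matplotlib is listed only for the optional cluster figure.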

scripts/install_dependencies.sh

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")
 
 python -m pip install uv
 
-uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash]"
+uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"
 
 if [ "${python_version}" != "(3, 12)" ]; then
     uv pip install --system -e .[ray]

src/distilabel/steps/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -21,6 +21,9 @@
     StepInput,
     StepResources,
 )
+from distilabel.steps.clustering.dbscan import DBSCAN
+from distilabel.steps.clustering.text_clustering import TextClustering
+from distilabel.steps.clustering.umap import UMAP
 from distilabel.steps.columns.combine import CombineOutputs
 from distilabel.steps.columns.expand import ExpandColumns
 from distilabel.steps.columns.group import CombineColumns, GroupColumns
@@ -67,6 +70,9 @@
     "GroupColumns",
     "KeepColumns",
     "MergeColumns",
+    "DBSCAN",
+    "UMAP",
+    "TextClustering",
     "step",
     "DeitaFiltering",
     "EmbeddingGeneration",

src/distilabel/steps/clustering/__init__.py

Lines changed: 14 additions & 0 deletions

# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

src/distilabel/steps/clustering/dbscan.py

Lines changed: 177 additions & 0 deletions

# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np
from pydantic import Field, PrivateAttr

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps import (
    GlobalStep,
    StepInput,
)

if TYPE_CHECKING:
    from sklearn.cluster import DBSCAN as _DBSCAN

    from distilabel.steps.typing import StepOutput


class DBSCAN(GlobalStep):
    r"""DBSCAN (Density-Based Spatial Clustering of Applications with Noise) finds core
    samples in regions of high density and expands clusters from them. This algorithm
    is good for data which contains clusters of similar density.

    This is a `GlobalStep` that clusters the embeddings using the DBSCAN algorithm
    from `sklearn`. Visit `TextClustering` step for an example of use.
    The trained model is saved as an artifact when creating a distiset
    and pushing it to the Hugging Face Hub.

    Input columns:
        - projection (`List[float]`): Vector representation of the text to cluster,
            normally the output from the `UMAP` step.

    Output columns:
        - cluster_label (`int`): Integer representing the label of a given cluster. -1
            means it wasn't clustered.

    Categories:
        - clustering
        - text-classification

    References:
        - [`DBSCAN demo of sklearn`](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#demo-of-dbscan-clustering-algorithm)
        - [`sklearn dbscan`](https://scikit-learn.org/stable/modules/clustering.html#dbscan)

    Attributes:
        - eps: The maximum distance between two samples for one to be considered as in the
            neighborhood of the other. This is not a maximum bound on the distances of
            points within a cluster. This is the most important DBSCAN parameter to
            choose appropriately for your data set and distance function.
        - min_samples: The number of samples (or total weight) in a neighborhood for a point
            to be considered as a core point. This includes the point itself. If `min_samples`
            is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
            to a lower value, the found clusters will be more sparse.
        - metric: The metric to use when calculating distance between instances in a feature
            array. If metric is a string or callable, it must be one of the options allowed
            by `sklearn.metrics.pairwise_distances` for its metric parameter.
        - n_jobs: The number of parallel jobs to run.

    Runtime parameters:
        - `eps`: The maximum distance between two samples for one to be considered as in the
            neighborhood of the other. This is not a maximum bound on the distances of
            points within a cluster. This is the most important DBSCAN parameter to
            choose appropriately for your data set and distance function.
        - `min_samples`: The number of samples (or total weight) in a neighborhood for a point
            to be considered as a core point. This includes the point itself. If `min_samples`
            is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
            to a lower value, the found clusters will be more sparse.
        - `metric`: The metric to use when calculating distance between instances in a feature
            array. If metric is a string or callable, it must be one of the options allowed
            by `sklearn.metrics.pairwise_distances` for its metric parameter.
        - `n_jobs`: The number of parallel jobs to run.
    """

    eps: Optional[RuntimeParameter[float]] = Field(
        default=0.3,
        description=(
            "The maximum distance between two samples for one to be considered "
            "as in the neighborhood of the other. This is not a maximum bound "
            "on the distances of points within a cluster. This is the most "
            "important DBSCAN parameter to choose appropriately for your data set "
            "and distance function."
        ),
    )
    min_samples: Optional[RuntimeParameter[int]] = Field(
        default=30,
        description=(
            "The number of samples (or total weight) in a neighborhood for a point to "
            "be considered as a core point. This includes the point itself. If "
            "`min_samples` is set to a higher value, DBSCAN will find denser clusters, "
            "whereas if it is set to a lower value, the found clusters will be more "
            "sparse."
        ),
    )
    metric: Optional[RuntimeParameter[str]] = Field(
        default="euclidean",
        description=(
            "The metric to use when calculating distance between instances in a "
            "feature array. If metric is a string or callable, it must be one of "
            "the options allowed by `sklearn.metrics.pairwise_distances` for "
            "its metric parameter."
        ),
    )
    n_jobs: Optional[RuntimeParameter[int]] = Field(
        default=8, description="The number of parallel jobs to run."
    )

    _clusterer: Optional["_DBSCAN"] = PrivateAttr(None)

    def load(self) -> None:
        super().load()
        if importlib.util.find_spec("sklearn") is None:
            raise ImportError(
                "`sklearn` package is not installed. Please install it using `pip install scikit-learn`."
            )
        from sklearn.cluster import DBSCAN as _DBSCAN

        self._clusterer = _DBSCAN(
            eps=self.eps,
            min_samples=self.min_samples,
            metric=self.metric,
            n_jobs=self.n_jobs,
        )

    def unload(self) -> None:
        self._clusterer = None

    @property
    def inputs(self) -> List[str]:
        return ["projection"]

    @property
    def outputs(self) -> List[str]:
        return ["cluster_label"]

    def _save_model(self, model: Any) -> None:
        import joblib

        def save_model(path):
            with open(str(path / "DBSCAN.joblib"), "wb") as f:
                joblib.dump(model, f)

        self.save_artifact(
            name="DBSCAN_model",
            write_function=lambda path: save_model(path),
            metadata={
                "eps": self.eps,
                "min_samples": self.min_samples,
                "metric": self.metric,
            },
        )

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        projections = np.array([input["projection"] for input in inputs])

        self._logger.info("🏋️‍♀️ Start training DBSCAN...")
        fitted_clusterer = self._clusterer.fit(projections)
        cluster_labels = fitted_clusterer.labels_
        # Sets the cluster labels for each input, -1 means it wasn't clustered
        for input, cluster_label in zip(inputs, cluster_labels):
            input["cluster_label"] = cluster_label
        self._logger.info(f"DBSCAN labels assigned: {len(set(cluster_labels))}")
        self._save_model(fitted_clusterer)
        yield inputs
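
The docstring defers to `TextClustering` for a worked example, which is not part of this view. As a minimal standalone sketch (assuming the step can be driven directly, the way distilabel docstring examples elsewhere do; the toy data is hypothetical, and note that `process` also calls `_save_model` to store the fitted clusterer as an artifact):

import numpy as np

from distilabel.steps import DBSCAN

dbscan = DBSCAN(eps=0.3, min_samples=3, metric="euclidean", n_jobs=1)
dbscan.load()  # imports sklearn and instantiates the underlying clusterer

# Hypothetical projections: two tight groups of five points plus one outlier.
rng = np.random.default_rng(0)
rows = [
    {"projection": (center + rng.normal(0, 0.05, 2)).tolist()}
    for center in (np.array([0.0, 0.0]), np.array([5.0, 5.0]))
    for _ in range(5)
]
rows.append({"projection": [20.0, 20.0]})

results = next(dbscan.process(rows))
print({row["cluster_label"] for row in results})  # e.g. {0, 1, -1}; -1 means noise

In a real pipeline, `DBSCAN` would sit downstream of the `UMAP` step, which produces the `projection` column this step consumes.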
