Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions marimo/_snippets/data/altair-15.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# /// script
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on renaming these series to embedding-1.py? maybe for more discoverability (also not teaching you altair, but rather using some tools for embedding)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. In my mind this is the main visualisation technique for embeddings, but I agree it's less of an altair thing at this point

# requires-python = ">=3.12"
# dependencies = [
# "altair==5.5.0",
# "marimo",
# "model2vec==0.6.0",
# "polars==1.31.0",
# "scikit-learn==1.7.1",
# "umap-learn==0.5.9.post2",
# ]
# ///

import marimo

# NOTE(review): presumably the marimo version that generated this file — confirm before hand-editing.
__generated_with = "0.14.13"
# Notebook application object; every cell below registers itself via `@app.cell`.
app = marimo.App(width="columns")


# Intro cell: renders the notebook title and a short description as Markdown.
@app.cell
def _(mo):
    mo.md(
        r"""
    # Visualization: Embedding Summary and Bulk Selection

    Create interactive dashboards using `mo.ui.altair_chart` and `UMAP`.
    This technique is generally useful for any kind of embedding, but we're demonstrating it with text embeddings below.
    """
    )
    return



@app.cell
def _():
    # Shared imports: every name returned here is requested by a later cell
    # through its parameter list.
    import altair as alt
    import marimo as mo
    import polars as pl
    from model2vec import StaticModel
    from umap import UMAP

    return (StaticModel, UMAP, alt, mo, pl)


@app.cell
def _(StaticModel, UMAP, pl):
    # Load a text dataset, embed every row, and project the embeddings to 2D.
    DATASET = "https://calmcode.io/static/data/clinc.csv"
    TEXT_COL = "text"
    MAX_ROWS = 10_000

    # Clamp the sample size to the rows actually available so that swapping in
    # a smaller DATASET does not make `sample` raise.
    df = pl.read_csv(DATASET)
    df = df.sample(min(MAX_ROWS, df.height))

    # We're using Model2Vec because it is so lightweight; sentence-transformers will also work!
    tfm = StaticModel.from_pretrained("minishlab/potion-base-8M")
    df = df.with_columns(emb=tfm.encode(df[TEXT_COL].to_list()))

    # UMAP turns the high-dimensional embeddings into 2D points which are easier to visualize.
    xy_umap = UMAP(n_components=2).fit_transform(df["emb"].to_numpy())
    # Keep only the text and its 2D coordinates; the embedding column is dropped here.
    df = df.with_columns(x=xy_umap[:, 0], y=xy_umap[:, 1]).select(TEXT_COL, "x", "y")
    return TEXT_COL, df


@app.cell
def _(alt, df, mo):
    # Scatter plot of the 2D projection, wrapped in a marimo UI element so a
    # later cell can read `chart.value`.
    _points = alt.Chart(df).mark_point()
    _scatter = _points.encode(x="x", y="y").properties(width=500)
    chart = mo.ui.altair_chart(_scatter)
    return (chart,)


@app.cell
def _(TEXT_COL, chart, mo):
    # Lay out the chart side by side with the text column of `chart.value`.
    _picked_text = chart.value.select(TEXT_COL)
    mo.hstack([chart, _picked_text])
    return


# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()

Loading