The methodology implemented in this package is described in:
Scaling Laws for Optimal Data Mixtures
Mustafa Shukor, Louis Bethune, Dan Busbridge, David Grangier, Enrico Fini, Alaaeldin El-Nouby, and Pierre Ablin
NeurIPS 2025
ScaleFit is a Python package for fitting parametric scaling laws to empirical data using JAX. It is designed for multi-dimensional scaling functions used in Large Language Model (LLM) analysis, providing tools for parameter estimation and uncertainty quantification via bootstrapping.
Instead of the widely used random-restart initialization strategy, it uses the basin-hopping method, which performs a more efficient global search.
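Basin-hopping alternates a local minimization with a random perturbation of the current best point, so it reuses information between restarts instead of starting from scratch each time. ScaleFit's internal implementation is not shown here; the idea can be illustrated with SciPy's `basinhopping` on a toy multi-modal objective (the objective, starting point, and settings below are purely illustrative):

```python
import numpy as np
from scipy.optimize import basinhopping

# A simple multi-modal objective: many local minima, global minimum at x = 0.
def objective(x):
    return x[0] ** 2 + 5.0 * np.sin(x[0]) ** 2

result = basinhopping(
    objective,
    x0=np.array([8.0]),                    # deliberately far from the optimum
    niter=100,                             # perturb-then-minimize cycles
    stepsize=2.0,                          # hop size between basins
    minimizer_kwargs={"method": "L-BFGS-B"},
    seed=0,
)
print(result.x, result.fun)  # best minimum found across the basins visited
```

Each hop is either accepted or rejected with a Metropolis criterion, so the search can escape a poor basin without discarding the best solution seen so far.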
ScaleFit requires Python 3.9 or higher. To install the package and its dependencies (JAX, Pandas, Joblib), run:
```shell
git clone https://github.com/your-username/scalefit.git
cd scalefit
pip install .
```

This example fits a basic power law to synthetic data.
```python
import jax.numpy as jnp
import pandas as pd

from scalefit import ScalingLaw

# 1. Define the parametric form of the scaling law.
def power_law(params, inputs):
    # `inputs` holds the input observations; `params` holds the law's parameters.
    return params["a"] * (inputs["n_tokens"] ** params["b"]) + params["c"]

# 2. Define parameter search bounds.
# The initializations of the search algorithm are drawn uniformly from these
# bounds. Note: this is only for initialization; the final optimized
# parameters may fall outside these ranges if a better fit is found.
bounds = {
    "a": (0.0, 10.0),
    "b": (-1.0, 0.0),
    "c": (0.0, 2.0),
}

# 3. Fake data; here you would load your own observations.
df = pd.DataFrame({
    "n_tokens": [10, 100, 1000, 10000],
    "loss": [2.5, 1.2, 0.75, 0.6],
})

# 4. Fit and predict.
model = ScalingLaw(model_fn=power_law, bounds=bounds)
model.fit(df, df["loss"])

# Prediction returns a Pandas Series matching the input index.
y_pred = model.predict(df)

# Get the optimal parameters found during the search.
print(model.optimal_params_)
```

More complete examples, including the fit for the data mixture scaling law paper, are in the `examples/` folder.
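ScaleFit's uncertainty quantification is described above as bootstrap-based, but the exact API is not shown here. The underlying idea can be sketched independently of ScaleFit: refit the law on observations resampled with replacement, then report percentile intervals over the refitted parameters. The sketch below uses SciPy's `curve_fit` on synthetic data; the data, helper names, and settings are all illustrative:

```python
import numpy as np
from scipy.optimize import curve_fit

# Synthetic observations following loss = a * n^b + c, plus a little noise.
rng = np.random.default_rng(0)
n_tokens = np.logspace(1, 6, 20)
true_a, true_b, true_c = 5.0, -0.3, 0.5
loss = true_a * n_tokens**true_b + true_c + rng.normal(0.0, 0.01, n_tokens.size)

def power_law(n, a, b, c):
    return a * n**b + c

# Bootstrap: refit on rows resampled with replacement.
boot_params = []
for _ in range(200):
    idx = rng.integers(0, n_tokens.size, size=n_tokens.size)
    p, _ = curve_fit(power_law, n_tokens[idx], loss[idx],
                     p0=[1.0, -0.5, 1.0], maxfev=10000)
    boot_params.append(p)
boot_params = np.array(boot_params)  # shape (200, 3)

# Percentile confidence intervals for each parameter (a, b, c).
lo, hi = np.percentile(boot_params, [2.5, 97.5], axis=0)
```

The spread of `boot_params` reflects how sensitive the fitted law is to which observations happened to be collected, which matters when extrapolating a scaling law far beyond the measured range.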
In `data/`, there are sample CSV files containing the observations from:

- The original Chinchilla paper, where the data is collected from the figures (`data/chinchilla.csv`)
- The data mixture scaling law paper, Shukor et al. (2025) (`data/dmsl_*.csv`)
The project is organized to separate core optimization logic from application-specific scaling models:

- `src/scalefit/`: Core library source code.
  - `scaling.py`: Implementation of the `ScalingLaw` class. This handles data sanitization (Pandas/NumPy to JAX) and bootstrapping.
  - `optim.py`: Internal utilities for grid search initialization and basin-hopping / L-BFGS integration.
- `examples/`: Real-world application scripts using real datasets.
  - `fit_chinchilla.py`: A re-implementation of the Chinchilla scaling laws (Hoffmann et al., 2022). Fits the equation $L(N, D) = E + \frac{A}{N^\alpha} + \frac{B}{D^\beta}$ to predict compute-optimal model sizes.
  - `fit_dmsl.py`: Implements the data mixture scaling law of Shukor et al. (2025). This example demonstrates how to model loss when training on data mixtures with varying weights for different domains.
- `data/`: Sample CSV files (e.g., `chinchilla.csv`, `dmsl_nmm.csv`) used by the example scripts.
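The Chinchilla form $L(N, D) = E + A/N^\alpha + B/D^\beta$ maps directly onto the `model_fn` convention from the quick-start example. A sketch of how it might be written (parameter names, input column names, and bounds are illustrative, not the exact contents of `fit_chinchilla.py`):

```python
import jax.numpy as jnp

# L(N, D) = E + A / N^alpha + B / D^beta  (Hoffmann et al., 2022)
def chinchilla_law(params, inputs):
    n_term = params["A"] / (inputs["n_params"] ** params["alpha"])
    d_term = params["B"] / (inputs["n_tokens"] ** params["beta"])
    return params["E"] + n_term + d_term

# Illustrative initialization bounds, in the same format as the quick-start.
bounds = {
    "E": (0.0, 3.0),
    "A": (1.0, 1000.0),
    "alpha": (0.0, 1.0),
    "B": (1.0, 1000.0),
    "beta": (0.0, 1.0),
}
```

Here `E` is the irreducible loss, while the `N` and `D` terms capture the penalty for limited model size and limited training data, respectively.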
If you use this package in your research, please cite:
```bibtex
@inproceedings{shukor2025scaling,
  title={Scaling Laws for Optimal Data Mixtures},
  author={Shukor, Mustafa and Bethune, Louis and Busbridge, Dan and Grangier, David and Fini, Enrico and El-Nouby, Alaaeldin and Ablin, Pierre},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025},
  url={https://arxiv.org/abs/2507.09404}
}
```

This project is licensed under the Apple Sample Code License - see the LICENSE file for details.
Third-party dependencies and their licenses are documented in the ACKNOWLEDGMENTS file.