Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/src/usage/saving_and_loading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,11 @@ The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})
.. note::

When loading files from untrusted sources, MLX validates tensor metadata
before use. GGUF files are checked for valid dimension counts (max 8).
SafeTensors files are checked for consistent ``data_offsets`` (exactly 2
entries, correct ordering, and byte range matching the declared shape and
dtype). Invalid files will raise an error.
8 changes: 8 additions & 0 deletions mlx/io/gguf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstring>
#include <fstream>
#include <numeric>
#include <sstream>

#include "mlx/io/gguf.h"
#include "mlx/ops.h"
Expand Down Expand Up @@ -48,6 +49,13 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
}

Shape get_shape(const gguf_tensor& tensor) {
if (tensor.ndim > MLX_GGUF_MAX_DIMS) {
std::ostringstream msg;
msg << "[load_gguf] Tensor has " << tensor.ndim
<< " dimensions, but the maximum supported is " << MLX_GGUF_MAX_DIMS
<< ". The file may be corrupt or malicious.";
throw std::runtime_error(msg.str());
}
Shape shape;
// The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) {
Expand Down
7 changes: 7 additions & 0 deletions mlx/io/gguf.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ extern "C" {
#include <gguflib.h>
}

// Maximum number of tensor dimensions supported by the GGUF format.
// Mirrors GGUF_TENSOR_MAX_DIM from gguflib.h. Override at compile time
// with -DMLX_GGUF_MAX_DIMS=<value> if the upstream format changes.
#ifndef MLX_GGUF_MAX_DIMS
#define MLX_GGUF_MAX_DIMS GGUF_TENSOR_MAX_DIM
#endif

namespace mlx::core {

Shape get_shape(const gguf_tensor& tensor);
Expand Down
24 changes: 24 additions & 0 deletions mlx/io/safetensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,30 @@ SafetensorsLoad load_safetensors(
const Shape& shape = item.value().at("shape");
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
if (data_offsets.size() != 2) {
throw std::runtime_error(
"[load_safetensors] Tensor \"" + item.key() +
"\" data_offsets must have exactly 2 entries");
}
if (data_offsets[0] > data_offsets[1]) {
throw std::runtime_error(
"[load_safetensors] Tensor \"" + item.key() +
"\" data_offsets[0] > data_offsets[1]");
}
{
size_t expected_nbytes = type.size();
for (auto dim : shape) {
expected_nbytes *= static_cast<size_t>(dim);
}
if ((data_offsets[1] - data_offsets[0]) != expected_nbytes) {
throw std::runtime_error(
"[load_safetensors] Tensor \"" + item.key() +
"\" data_offsets range " +
std::to_string(data_offsets[1] - data_offsets[0]) +
" does not match expected byte size " +
std::to_string(expected_nbytes));
}
}
res.insert(
{item.key(),
array(
Expand Down
89 changes: 89 additions & 0 deletions python/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,58 @@ def test_save_and_load_safetensors(self):
mx.array_equal(load_dict["test"], save_dict["test"])
)

def test_safetensors_rejects_mismatched_data_offsets(self):
    """Verify that data_offsets inconsistent with shape/dtype raise an error."""
    import json
    import struct

    # Declared shape 1000x1000 F32 needs 4,000,000 bytes, yet the
    # data_offsets span only 4 bytes -- the loader must reject this.
    tensor_meta = {
        "t": {"dtype": "F32", "shape": [1000, 1000], "data_offsets": [0, 4]}
    }
    header_bytes = json.dumps(tensor_meta).encode()
    payload = (
        struct.pack("<Q", len(header_bytes)) + header_bytes + struct.pack("<f", 1.0)
    )

    bad_file = os.path.join(self.test_dir, "bad_offsets.safetensors")
    with open(bad_file, "wb") as fh:
        fh.write(payload)

    with self.assertRaises(RuntimeError):
        mx.load(bad_file)

def test_safetensors_rejects_bad_data_offsets_count(self):
    """Verify that data_offsets with != 2 entries raises an error."""
    import json
    import struct

    # data_offsets must be a [begin, end] pair; three entries is malformed.
    tensor_meta = {"t": {"dtype": "F32", "shape": [1], "data_offsets": [0, 4, 8]}}
    header_bytes = json.dumps(tensor_meta).encode()
    payload = (
        struct.pack("<Q", len(header_bytes)) + header_bytes + struct.pack("<f", 1.0)
    )

    bad_file = os.path.join(self.test_dir, "bad_offsets_count.safetensors")
    with open(bad_file, "wb") as fh:
        fh.write(payload)

    with self.assertRaises(RuntimeError):
        mx.load(bad_file)

def test_safetensors_rejects_inverted_data_offsets(self):
    """Verify that data_offsets[0] > data_offsets[1] raises an error."""
    import json
    import struct

    # An inverted byte range (begin after end) can never be valid.
    tensor_meta = {"t": {"dtype": "F32", "shape": [1], "data_offsets": [4, 0]}}
    header_bytes = json.dumps(tensor_meta).encode()
    payload = (
        struct.pack("<Q", len(header_bytes)) + header_bytes + struct.pack("<f", 1.0)
    )

    bad_file = os.path.join(self.test_dir, "bad_offsets_inverted.safetensors")
    with open(bad_file, "wb") as fh:
        fh.write(payload)

    with self.assertRaises(RuntimeError):
        mx.load(bad_file)

@unittest.skipIf(platform.system() == "Windows", "GGUF is disabled on Windows")
def test_save_and_load_gguf(self):
if not os.path.isdir(self.test_dir):
Expand Down Expand Up @@ -300,6 +352,43 @@ def test_save_and_load_gguf_metadata_mixed(self):
else:
self.assertEqual(meta_load_dict[k], v)

@unittest.skipIf(platform.system() == "Windows", "GGUF is disabled on Windows")
def test_gguf_rejects_oversized_ndim(self):
    """Verify that loading a GGUF file with ndim > 8 raises an error."""
    import struct

    bad_ndim = 32
    name = b"weight"
    align = 32

    # Assemble a minimal GGUF v3 file whose single tensor claims 32
    # dimensions -- well past the format's maximum of 8.
    parts = [
        b"GGUF",                       # magic
        struct.pack("<I", 3),          # version
        struct.pack("<Q", 1),          # tensor_count
        struct.pack("<Q", 0),          # metadata_kv_count
        struct.pack("<Q", len(name)),  # tensor-name length
        name,
        struct.pack("<I", bad_ndim),   # ndim = 32 (malicious)
    ]
    parts.extend(struct.pack("<Q", 1) for _ in range(bad_ndim))  # each dim = 1
    parts.append(struct.pack("<I", 0))  # type = F32
    parts.append(struct.pack("<Q", 0))  # offset = 0
    blob = b"".join(parts)

    # Pad to the alignment boundary, then append one float of tensor data.
    blob += b"\x00" * (-len(blob) % align)
    blob += struct.pack("<f", 1.0)

    bad_file = os.path.join(self.test_dir, "bad_ndim.gguf")
    with open(bad_file, "wb") as fh:
        fh.write(blob)

    with self.assertRaises(RuntimeError):
        mx.load(bad_file)

def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
Expand Down
117 changes: 117 additions & 0 deletions tests/load_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.

#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -40,6 +41,67 @@ TEST_CASE("test save_safetensors") {
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
}

TEST_CASE("test safetensors rejects mismatched data_offsets") {
  // Build a minimal safetensors file where data_offsets claim 4 bytes
  // but shape declares 1000x1000 float32 (4,000,000 bytes).
  // Verifies that load_safetensors() catches the mismatch.
  std::string file_path = get_temp_file("test_bad_offsets.safetensors");

  std::string header =
      R"({"t":{"dtype":"F32","shape":[1000,1000],"data_offsets":[0,4]}})";

  {
    std::ofstream f(file_path, std::ios::binary);
    // The safetensors header length is an 8-byte little-endian integer.
    // Serialize it byte-by-byte so the file is well-formed regardless of
    // host endianness (dumping the uint64_t's memory would emit the
    // wrong byte order on big-endian platforms).
    uint64_t header_len = header.size();
    char len_bytes[8];
    for (int i = 0; i < 8; i++) {
      len_bytes[i] = static_cast<char>((header_len >> (i * 8)) & 0xFF);
    }
    f.write(len_bytes, 8);
    f.write(header.data(), static_cast<std::streamsize>(header.size()));
    // Write only 4 bytes of data (the offsets claim [0,4]); the loader
    // should reject the file before interpreting these bytes.
    float one = 1.0f;
    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test safetensors rejects bad data_offsets count") {
  // data_offsets has 3 entries instead of the required 2.
  // Verifies that load_safetensors() rejects the malformed metadata.
  std::string file_path = get_temp_file("test_bad_offsets_count.safetensors");

  std::string header =
      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[0,4,8]}})";

  {
    std::ofstream f(file_path, std::ios::binary);
    // Serialize the 8-byte header length explicitly little-endian so the
    // test file is valid on big-endian hosts as well (the safetensors
    // format mandates a little-endian length field).
    uint64_t header_len = header.size();
    char len_bytes[8];
    for (int i = 0; i < 8; i++) {
      len_bytes[i] = static_cast<char>((header_len >> (i * 8)) & 0xFF);
    }
    f.write(len_bytes, 8);
    f.write(header.data(), static_cast<std::streamsize>(header.size()));
    // Payload contents are irrelevant; the loader must throw on the
    // metadata before reading them.
    float one = 1.0f;
    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test safetensors rejects inverted data_offsets") {
  // data_offsets[0] > data_offsets[1] -- an impossible byte range.
  // Verifies that load_safetensors() rejects the malformed metadata.
  std::string file_path =
      get_temp_file("test_bad_offsets_inverted.safetensors");

  std::string header =
      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[4,0]}})";

  {
    std::ofstream f(file_path, std::ios::binary);
    // Serialize the 8-byte header length explicitly little-endian so the
    // test file is valid on big-endian hosts as well (the safetensors
    // format mandates a little-endian length field).
    uint64_t header_len = header.size();
    char len_bytes[8];
    for (int i = 0; i < 8; i++) {
      len_bytes[i] = static_cast<char>((header_len >> (i * 8)) & 0xFF);
    }
    f.write(len_bytes, 8);
    f.write(header.data(), static_cast<std::streamsize>(header.size()));
    // Payload contents are irrelevant; the loader must throw on the
    // metadata before reading them.
    float one = 1.0f;
    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test gguf") {
std::string file_path = get_temp_file("test_arr.gguf");
using dict = std::unordered_map<std::string, array>;
Expand Down Expand Up @@ -201,6 +263,61 @@ TEST_CASE("test gguf metadata") {
}
}

TEST_CASE("test gguf rejects oversized ndim") {
  // Build a minimal GGUF v3 file with ndim=32 (exceeds GGUF_TENSOR_MAX_DIM=8).
  // Verifies the bounds check in get_shape() catches malicious files.
  std::string file_path = get_temp_file("test_bad_ndim.gguf");

  constexpr uint32_t malicious_ndim = 32;
  // Use std::string so the name length is derived from the literal rather
  // than kept in sync by hand (the old hard-coded `name_len = 6` would
  // silently corrupt the file if the name ever changed).
  const std::string tensor_name = "weight";
  constexpr uint32_t alignment = 32;

  // GGUF is a little-endian format: emit all multi-byte fields
  // byte-by-byte so the file is correct regardless of host endianness.
  auto write_le32 = [](std::ofstream& f, uint32_t v) {
    char buf[4];
    for (int i = 0; i < 4; i++)
      buf[i] = (v >> (i * 8)) & 0xFF;
    f.write(buf, 4);
  };
  auto write_le64 = [](std::ofstream& f, uint64_t v) {
    char buf[8];
    for (int i = 0; i < 8; i++)
      buf[i] = (v >> (i * 8)) & 0xFF;
    f.write(buf, 8);
  };

  {
    std::ofstream f(file_path, std::ios::binary);
    // GGUF header
    f.write("GGUF", 4);
    write_le32(f, 3); // version
    write_le64(f, 1); // tensor_count = 1
    write_le64(f, 0); // metadata_kv_count = 0

    // Tensor info
    write_le64(f, tensor_name.size());
    f.write(tensor_name.data(), static_cast<std::streamsize>(tensor_name.size()));
    write_le32(f, malicious_ndim); // ndim = 32 (malicious)
    for (uint32_t i = 0; i < malicious_ndim; i++) {
      write_le64(f, 1); // each dim = 1
    }
    write_le32(f, 0); // type = F32
    write_le64(f, 0); // offset = 0

    // Pad to alignment, then write tensor data
    auto pos = f.tellp();
    size_t data_start =
        ((size_t)pos + alignment - 1) & ~(size_t)(alignment - 1);
    for (size_t i = (size_t)pos; i < data_start; i++) {
      f.put(0);
    }
    float one = 1.0f;
    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
  }

  CHECK_THROWS_AS(load_gguf(file_path), std::runtime_error);
}

TEST_CASE("test single array serialization") {
// Basic test
{
Expand Down