diff --git a/docs/src/usage/saving_and_loading.rst b/docs/src/usage/saving_and_loading.rst
index 43f2a79990..d31e34c935 100644
--- a/docs/src/usage/saving_and_loading.rst
+++ b/docs/src/usage/saving_and_loading.rst
@@ -79,3 +79,11 @@ The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
     >>> a = mx.array([1.0])
     >>> b = mx.array([2.0])
     >>> mx.save_safetensors("arrays", {"a": a, "b": b})
+
+.. note::
+
+   When loading files from untrusted sources, MLX validates tensor metadata
+   before use. GGUF files are checked for valid dimension counts (max 8).
+   SafeTensors files are checked for consistent ``data_offsets`` (exactly 2
+   entries, correct ordering, and byte range matching the declared shape and
+   dtype). Invalid files will raise an error.
diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index 206f6fb31f..cd52ed32d1 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -4,6 +4,7 @@
 #include <cstdint>
 #include <numeric>
 #include <optional>
+#include <sstream>
 
 #include "mlx/io/gguf.h"
 #include "mlx/ops.h"
@@ -48,6 +49,13 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
 }
 
 Shape get_shape(const gguf_tensor& tensor) {
+  if (tensor.ndim > MLX_GGUF_MAX_DIMS) {
+    std::ostringstream msg;
+    msg << "[load_gguf] Tensor has " << tensor.ndim
+        << " dimensions, but the maximum supported is " << MLX_GGUF_MAX_DIMS
+        << ". The file may be corrupt or malicious.";
+    throw std::runtime_error(msg.str());
+  }
   Shape shape;
   // The dimension order in GGML is the reverse of the order used in MLX.
   for (int i = tensor.ndim - 1; i >= 0; i--) {
diff --git a/mlx/io/gguf.h b/mlx/io/gguf.h
index fa5bc458de..418c4fbf3d 100644
--- a/mlx/io/gguf.h
+++ b/mlx/io/gguf.h
@@ -10,6 +10,13 @@ extern "C" {
 #include <gguflib.h>
 }
 
+// Maximum number of tensor dimensions supported by the GGUF format.
+// Mirrors GGUF_TENSOR_MAX_DIM from gguflib.h. Override at compile time
+// with -DMLX_GGUF_MAX_DIMS=<N> if the upstream format changes.
+#ifndef MLX_GGUF_MAX_DIMS
+#define MLX_GGUF_MAX_DIMS GGUF_TENSOR_MAX_DIM
+#endif
+
 namespace mlx::core {
 
 Shape get_shape(const gguf_tensor& tensor);
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index b8f91ba9be..b0d3f14834 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -147,6 +147,30 @@ SafetensorsLoad load_safetensors(
     const Shape& shape = item.value().at("shape");
     const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
     Dtype type = dtype_from_safetensor_str(dtype);
+    if (data_offsets.size() != 2) {
+      throw std::runtime_error(
+          "[load_safetensors] Tensor \"" + item.key() +
+          "\" data_offsets must have exactly 2 entries");
+    }
+    if (data_offsets[0] > data_offsets[1]) {
+      throw std::runtime_error(
+          "[load_safetensors] Tensor \"" + item.key() +
+          "\" data_offsets[0] > data_offsets[1]");
+    }
+    {
+      size_t expected_nbytes = type.size();
+      for (auto dim : shape) {
+        expected_nbytes *= static_cast<size_t>(dim);
+      }
+      if ((data_offsets[1] - data_offsets[0]) != expected_nbytes) {
+        throw std::runtime_error(
+            "[load_safetensors] Tensor \"" + item.key() +
+            "\" data_offsets range " +
+            std::to_string(data_offsets[1] - data_offsets[0]) +
+            " does not match expected byte size " +
+            std::to_string(expected_nbytes));
+      }
+    }
     res.insert(
         {item.key(),
          array(
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 97fbf6c4c5..495cbeb61e 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -127,6 +127,58 @@ def test_save_and_load_safetensors(self):
         self.assertTrue(
             mx.array_equal(load_dict["test"], save_dict["test"])
         )
+
+    def test_safetensors_rejects_mismatched_data_offsets(self):
+        """Verify that data_offsets inconsistent with shape/dtype raise an error."""
+        import json
+        import struct
+
+        # shape=[1000,1000] F32 = 4,000,000 bytes, but data_offsets say 4
+        header = json.dumps(
+            {"t": {"dtype": "F32", "shape": [1000, 1000], "data_offsets": [0, 4]}}
+        )
+        buf = struct.pack("<Q", len(header)) + header.encode() + b"\x00" * 4
+        path = os.path.join(self.test_dir, "bad_offsets.safetensors")
+        with open(path, "wb") as f:
+            f.write(buf)
+        self.assertRaises(Exception, mx.load, path)
+
+    def test_safetensors_rejects_inverted_data_offsets(self):
+        """Verify that data_offsets[0] > data_offsets[1] raises an error."""
+        import json
+        import struct
+
+        header = json.dumps(
+            {"t": {"dtype": "F32", "shape": [1], "data_offsets": [4, 0]}}
+        )
+        buf = struct.pack("<Q", len(header)) + header.encode() + b"\x00" * 4
+        path = os.path.join(self.test_dir, "inverted_offsets.safetensors")
+        with open(path, "wb") as f:
+            f.write(buf)
+        self.assertRaises(Exception, mx.load, path)
+
+    def test_gguf_rejects_oversized_ndim(self):
+        """Verify that ndim > 8 raises an error."""
+        import struct
+        malicious_ndim = 32
+        tensor_name = b"weight"
+        alignment = 32
+        buf = bytearray()
+        # GGUF header (v3)
+        buf += b"GGUF"
+        buf += struct.pack("<IQQ", 3, 1, 0)  # version, tensor_count, kv_count
+        # Tensor info: name, ndim (malicious), dims, type = F32, offset
+        buf += struct.pack("<Q", len(tensor_name)) + tensor_name
+        buf += struct.pack("<I", malicious_ndim)
+        buf += struct.pack("<%dQ" % malicious_ndim, *([1] * malicious_ndim))
+        buf += struct.pack("<IQ", 0, 0)
+        buf += b"\x00" * (-len(buf) % alignment) + struct.pack("<f", 1.0)
+        path = os.path.join(self.test_dir, "bad_ndim.gguf")
+        with open(path, "wb") as f:
+            f.write(buf)
+        self.assertRaises(Exception, mx.load, path)
 
     def test_save_and_load_gguf(self):
         # TODO: Add support for other dtypes
diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp
--- a/tests/load_tests.cpp
+++ b/tests/load_tests.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
 #include <cstdint>
+#include <fstream>
 #include <stdexcept>
 #include <string>
 
@@ -40,6 +41,67 @@ TEST_CASE("test save_safetensors") {
   CHECK(array_equal(test2, ones({2, 2})).item<bool>());
 }
 
+TEST_CASE("test safetensors rejects mismatched data_offsets") {
+  // Build a minimal safetensors file where data_offsets claim 4 bytes
+  // but shape declares 1000x1000 float32 (4,000,000 bytes).
+  // Verifies that load_safetensors() catches the mismatch.
+  std::string file_path = get_temp_file("test_bad_offsets.safetensors");
+
+  std::string header =
+      R"({"t":{"dtype":"F32","shape":[1000,1000],"data_offsets":[0,4]}})";
+  uint64_t header_len = header.size();
+
+  {
+    std::ofstream f(file_path, std::ios::binary);
+    f.write(reinterpret_cast<const char*>(&header_len), 8);
+    f.write(header.c_str(), header_len);
+    // Write only 4 bytes of data (the offsets claim [0,4])
+    float one = 1.0f;
+    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
+  }
+
+  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
+}
+
+TEST_CASE("test safetensors rejects bad data_offsets count") {
+  // data_offsets has 3 entries instead of the required 2.
+  std::string file_path = get_temp_file("test_bad_offsets_count.safetensors");
+
+  std::string header =
+      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[0,4,8]}})";
+  uint64_t header_len = header.size();
+
+  {
+    std::ofstream f(file_path, std::ios::binary);
+    f.write(reinterpret_cast<const char*>(&header_len), 8);
+    f.write(header.c_str(), header_len);
+    float one = 1.0f;
+    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
+  }
+
+  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
+}
+
+TEST_CASE("test safetensors rejects inverted data_offsets") {
+  // data_offsets[0] > data_offsets[1]
+  std::string file_path =
+      get_temp_file("test_bad_offsets_inverted.safetensors");
+
+  std::string header =
+      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[4,0]}})";
+  uint64_t header_len = header.size();
+
+  {
+    std::ofstream f(file_path, std::ios::binary);
+    f.write(reinterpret_cast<const char*>(&header_len), 8);
+    f.write(header.c_str(), header_len);
+    float one = 1.0f;
+    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
+  }
+
+  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
+}
+
 TEST_CASE("test gguf") {
   std::string file_path = get_temp_file("test_arr.gguf");
   using dict = std::unordered_map<std::string, array>;
@@ -201,6 +263,61 @@ TEST_CASE("test gguf metadata") {
   }
 }
 
+TEST_CASE("test gguf rejects oversized ndim") {
+  // Build a minimal GGUF v3 file with ndim=32 (exceeds GGUF_TENSOR_MAX_DIM=8).
+  // Verifies the bounds check in get_shape() catches malicious files.
+  std::string file_path = get_temp_file("test_bad_ndim.gguf");
+
+  constexpr uint32_t malicious_ndim = 32;
+  constexpr const char* tensor_name = "weight";
+  constexpr size_t name_len = 6;
+  constexpr uint32_t alignment = 32;
+
+  auto write_le32 = [](std::ofstream& f, uint32_t v) {
+    char buf[4];
+    for (int i = 0; i < 4; i++)
+      buf[i] = (v >> (i * 8)) & 0xFF;
+    f.write(buf, 4);
+  };
+  auto write_le64 = [](std::ofstream& f, uint64_t v) {
+    char buf[8];
+    for (int i = 0; i < 8; i++)
+      buf[i] = (v >> (i * 8)) & 0xFF;
+    f.write(buf, 8);
+  };
+
+  {
+    std::ofstream f(file_path, std::ios::binary);
+    // GGUF header
+    f.write("GGUF", 4);
+    write_le32(f, 3); // version
+    write_le64(f, 1); // tensor_count = 1
+    write_le64(f, 0); // metadata_kv_count = 0
+
+    // Tensor info
+    write_le64(f, name_len);
+    f.write(tensor_name, name_len);
+    write_le32(f, malicious_ndim); // ndim = 32 (malicious)
+    for (uint32_t i = 0; i < malicious_ndim; i++) {
+      write_le64(f, 1); // each dim = 1
+    }
+    write_le32(f, 0); // type = F32
+    write_le64(f, 0); // offset = 0
+
+    // Pad to alignment, then write tensor data
+    auto pos = f.tellp();
+    size_t data_start =
+        ((size_t)pos + alignment - 1) & ~(size_t)(alignment - 1);
+    for (size_t i = (size_t)pos; i < data_start; i++) {
+      f.put(0);
+    }
+    float one = 1.0f;
+    f.write(reinterpret_cast<const char*>(&one), sizeof(float));
+  }
+
+  CHECK_THROWS_AS(load_gguf(file_path), std::runtime_error);
+}
+
 TEST_CASE("test single array serialization") {
   // Basic test
   {