Patch summary (free text before the first "diff --git" header; patch tools skip it).

mlx/io/gguf.cpp
  get_shape() now throws std::runtime_error when tensor.ndim is greater than
  GGUF_TENSOR_MAX_DIM. The check runs before the loop that walks the
  dimensions from tensor.ndim - 1 down to 0, so an oversized ndim is rejected
  before any dimension is read. The message is built with fmt::format.

tests/load_tests.cpp
  Adds TEST_CASE "test gguf rejects oversized ndim". It writes a GGUF file by
  hand: magic "GGUF", version 3, one tensor, zero metadata entries, then a
  tensor named "weight" declaring ndim = 32 with every dimension set to 1,
  type 0, offset 0, zero padding up to the next 32-byte boundary, and a single
  float. The test expects load_gguf() on that file to throw
  std::runtime_error. name_len is hard-coded to 6, the length of "weight".

Review notes
  - The float is written through reinterpret_cast<const char*>, the pointer
    type std::ofstream::write takes; a cast with no target type does not
    compile.
  - NOTE(review): this copy has its line breaks collapsed, and the bare
    "#include" lines and "std::optional" appear to have lost their angle-bracket
    arguments. Those cannot be recovered from the text here - restore them and
    the per-line hunk layout from the original patch before applying.

diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 206f6fb31f..c4dd93651b 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -5,6 +5,8 @@ #include #include +#include + #include "mlx/io/gguf.h" #include "mlx/ops.h" @@ -48,6 +50,14 @@ std::optional gguf_type_to_dtype(const uint32_t& gguf_type) { } Shape get_shape(const gguf_tensor& tensor) { + if (tensor.ndim > GGUF_TENSOR_MAX_DIM) { + throw std::runtime_error( + fmt::format( + "[load_gguf] Tensor has {} dimensions, but the maximum supported is {}." + " The file may be corrupt or malicious.", + tensor.ndim, + GGUF_TENSOR_MAX_DIM)); + } Shape shape; // The dimension order in GGML is the reverse of the order used in MLX. for (int i = tensor.ndim - 1; i >= 0; i--) { diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index 1531ce060c..172fe02c1e 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include @@ -201,6 +202,61 @@ TEST_CASE("test gguf metadata") { } } +TEST_CASE("test gguf rejects oversized ndim") { + // Build a minimal GGUF v3 file with ndim=32 (exceeds GGUF_TENSOR_MAX_DIM=8). + // Verifies the bounds check in get_shape() catches malicious files. 
+ std::string file_path = get_temp_file("test_bad_ndim.gguf"); + + constexpr uint32_t malicious_ndim = 32; + constexpr const char* tensor_name = "weight"; + constexpr size_t name_len = 6; + constexpr uint32_t alignment = 32; + + auto write_le32 = [](std::ofstream& f, uint32_t v) { + char buf[4]; + for (int i = 0; i < 4; i++) + buf[i] = (v >> (i * 8)) & 0xFF; + f.write(buf, 4); + }; + auto write_le64 = [](std::ofstream& f, uint64_t v) { + char buf[8]; + for (int i = 0; i < 8; i++) + buf[i] = (v >> (i * 8)) & 0xFF; + f.write(buf, 8); + }; + + { + std::ofstream f(file_path, std::ios::binary); + // GGUF header + f.write("GGUF", 4); + write_le32(f, 3); // version + write_le64(f, 1); // tensor_count = 1 + write_le64(f, 0); // metadata_kv_count = 0 + + // Tensor info + write_le64(f, name_len); + f.write(tensor_name, name_len); + write_le32(f, malicious_ndim); // ndim = 32 (malicious) + for (uint32_t i = 0; i < malicious_ndim; i++) { + write_le64(f, 1); // each dim = 1 + } + write_le32(f, 0); // type = F32 + write_le64(f, 0); // offset = 0 + + // Pad to alignment, then write tensor data + auto pos = f.tellp(); + size_t data_start = + ((size_t)pos + alignment - 1) & ~(size_t)(alignment - 1); + for (size_t i = (size_t)pos; i < data_start; i++) { + f.put(0); + } + float one = 1.0f; + f.write(reinterpret_cast<const char*>(&one), sizeof(float)); + } + + CHECK_THROWS_AS(load_gguf(file_path), std::runtime_error); +} + TEST_CASE("test single array serialization") { // Basic test {