Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlx/io/gguf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <fstream>
#include <numeric>

#include <fmt/format.h>

#include "mlx/io/gguf.h"
#include "mlx/ops.h"

Expand Down Expand Up @@ -48,6 +50,14 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
}

Shape get_shape(const gguf_tensor& tensor) {
if (tensor.ndim > GGUF_TENSOR_MAX_DIM) {
throw std::runtime_error(
fmt::format(
"[load_gguf] Tensor has {} dimensions, but the maximum supported is {}."
" The file may be corrupt or malicious.",
tensor.ndim,
GGUF_TENSOR_MAX_DIM));
}
Shape shape;
// The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) {
Expand Down
56 changes: 56 additions & 0 deletions tests/load_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.

#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -201,6 +202,61 @@ TEST_CASE("test gguf metadata") {
}
}

TEST_CASE("test gguf rejects oversized ndim") {
// Build a minimal GGUF v3 file with ndim=32 (exceeds GGUF_TENSOR_MAX_DIM=8).
// Verifies the bounds check in get_shape() catches malicious files.
std::string file_path = get_temp_file("test_bad_ndim.gguf");

constexpr uint32_t malicious_ndim = 32;
constexpr const char* tensor_name = "weight";
constexpr size_t name_len = 6;
constexpr uint32_t alignment = 32;

auto write_le32 = [](std::ofstream& f, uint32_t v) {
char buf[4];
for (int i = 0; i < 4; i++)
buf[i] = (v >> (i * 8)) & 0xFF;
f.write(buf, 4);
};
auto write_le64 = [](std::ofstream& f, uint64_t v) {
char buf[8];
for (int i = 0; i < 8; i++)
buf[i] = (v >> (i * 8)) & 0xFF;
f.write(buf, 8);
};

{
std::ofstream f(file_path, std::ios::binary);
// GGUF header
f.write("GGUF", 4);
write_le32(f, 3); // version
write_le64(f, 1); // tensor_count = 1
write_le64(f, 0); // metadata_kv_count = 0

// Tensor info
write_le64(f, name_len);
f.write(tensor_name, name_len);
write_le32(f, malicious_ndim); // ndim = 32 (malicious)
for (uint32_t i = 0; i < malicious_ndim; i++) {
write_le64(f, 1); // each dim = 1
}
write_le32(f, 0); // type = F32
write_le64(f, 0); // offset = 0

// Pad to alignment, then write tensor data
auto pos = f.tellp();
size_t data_start =
((size_t)pos + alignment - 1) & ~(size_t)(alignment - 1);
for (size_t i = (size_t)pos; i < data_start; i++) {
f.put(0);
}
float one = 1.0f;
f.write(reinterpret_cast<const char*>(&one), sizeof(float));
}

CHECK_THROWS_AS(load_gguf(file_path), std::runtime_error);
}

TEST_CASE("test single array serialization") {
// Basic test
{
Expand Down
Loading