diff --git a/common/common.cpp b/common/common.cpp index 4fc0b9e5b..04ddc26e4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1490,6 +1490,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mmap = false; return true; } + if (arg == "-dio" || arg == "--direct-io") { + params.use_direct_io = true; + params.use_mmap = false; + return true; + } if (arg == "-rtr" || arg == "--run-time-repack") { params.repack_tensors = true; params.use_mmap = false; @@ -2421,6 +2426,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param if (llama_supports_mmap()) { options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); } + options.push_back({ "*", "-dio, --direct-io", "use DirectIO if available (disables mmap)"}); options.push_back({ "*", " --run-time-repack", "repack tensors if interleaved variant is available"}); options.push_back({ "*", " --cpu-moe", "keep all MoE weights in CPU memory"}); options.push_back({ "*", " --n-cpu-moe N", "keep MoE weights of the first N layers in CPU memory"}); @@ -3200,6 +3206,7 @@ struct llama_model_params common_model_params_to_llama(const gpt_params & params mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; + mparams.use_direct_io = params.use_direct_io; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.repack_tensors = params.repack_tensors; @@ -4286,6 +4293,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "direct-io: %s # default: false\n", params.use_direct_io ? 
"true" : "false"); fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false"); fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false"); fprintf(stream, "validate_quants: %s # default: false\n", params.validate_quants ? "true" : "false"); diff --git a/common/common.h b/common/common.h index 8a958cdd8..fe78fe068 100644 --- a/common/common.h +++ b/common/common.h @@ -337,6 +337,7 @@ struct gpt_params { bool ignore_eos = false; // ignore generated EOS tokens bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads + bool use_direct_io = false; // read from disk without buffering bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a4420fd61..5743169a9 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -254,6 +254,7 @@ struct cmd_params { std::vector reuse; std::vector> tensor_split; std::vector use_mmap; + std::vector use_direct_io; std::vector embeddings; std::vector buft_overrides; ggml_numa_strategy numa; @@ -299,6 +300,7 @@ static const cmd_params cmd_params_defaults = { /* reuse */ {true}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, + /* use_direct_io */ {false}, /* embeddings */ {false}, /* buft_overrides */ {}, /* numa */ GGML_NUMA_STRATEGY_DISABLED, @@ -349,6 +351,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ser, --smart-expert-reduction (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -gr, --graph-reuse <0|1> (default: %s)\n", join(cmd_params_defaults.reuse, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, 
",").c_str()); + printf(" -dio, --direct-io <0|1> (default: %s)\n", join(cmd_params_defaults.use_direct_io, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -725,6 +728,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end()); + } else if (arg == "-dio" || arg == "--direct-io") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.use_direct_io.insert(params.use_direct_io.end(), p.begin(), p.end()); } else if (arg == "-embd" || arg == "--embeddings") { if (++i >= argc) { invalid_param = true; @@ -904,6 +914,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } + if (params.use_direct_io.empty()) { params.use_direct_io = cmd_params_defaults.use_direct_io; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } if (!params.buft_overrides.empty()) params.buft_overrides.emplace_back(llama_model_tensor_buft_override{nullptr, nullptr}); @@ -945,6 +956,7 @@ struct cmd_params_instance { std::vector tensor_split; std::string cuda_params; bool use_mmap; + bool use_direct_io = false; bool embeddings; bool repack = false; bool fmoe = true; @@ -969,6 +981,7 @@ struct cmd_params_instance { mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; + mparams.use_direct_io = use_direct_io; mparams.repack_tensors = repack; 
mparams.use_thp = use_thp; mparams.merge_qkv = mqkv; @@ -986,6 +999,7 @@ struct cmd_params_instance { split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && + use_direct_io == other.use_direct_io && repack == other.repack && mqkv == other.mqkv && muge == other.muge && @@ -1032,6 +1046,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & mg : params.main_gpu) for (const auto & ts : params.tensor_split) for (const auto & mmp : params.use_mmap) + for (const auto & dio : params.use_direct_io) for (const auto & embd : params.embeddings) for (const auto & nb : params.n_batch) for (const auto & nub : params.n_ubatch) @@ -1071,6 +1086,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, /* .use_mmap = */ mmp, + /* .use_direct_io= */ dio, /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, @@ -1114,6 +1130,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, /* .use_mmap = */ mmp, + /* .use_direct_io= */ dio, /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, @@ -1157,6 +1174,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, /* .use_mmap = */ mmp, + /* .use_direct_io= */ dio, /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, @@ -1200,6 +1218,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .cuda_params = */ params.cuda_params, /* .use_mmap = */ mmp, + /* .use_direct_io= */ dio, /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, @@ -1254,6 +1273,7 @@ struct test { std::vector tensor_split; std::string cuda_params; bool use_mmap; + bool use_direct_io = false; bool embeddings; bool repack = false; 
bool fmoe = false; @@ -1298,6 +1318,7 @@ struct test { tensor_split = inst.tensor_split; cuda_params = inst.cuda_params; use_mmap = inst.use_mmap; + use_direct_io = inst.use_direct_io; embeddings = inst.embeddings; repack = inst.repack; mqkv = inst.mqkv; @@ -1415,8 +1436,9 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || - field == "fused_moe" || field == "grouped_er" || field == "no_fused_up_gate" || field == "no_ooae" || field == "mqkv" || + field == "flash_attn" || field == "use_mmap" || field == "use_direct_io" || field == "embeddings" || + field == "repack" || field == "use_thp" || field == "fused_moe" || field == "grouped_er" || + field == "no_fused_up_gate" || field == "no_ooae" || field == "mqkv" || field == "rcache" || field == "reuse" || field == "muge" || field == "sas") { return BOOL; } @@ -1459,7 +1481,7 @@ struct test { std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + tensor_split_str, std::to_string(use_mmap), std::to_string(use_direct_io), std::to_string(embeddings), std::to_string(repack), std::to_string(mqkv), std::to_string(muge), std::to_string(fmoe), std::to_string(ger), std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas), cuda_params, override_tensor, @@ -1481,7 +1503,8 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse", - "tensor_split", 
"use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er", + "tensor_split", "use_mmap", "use_direct_io", "embeddings", + "repack", "mqkv", "muge", "fused_moe", "grouped_er", "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "cuda_params", "override_tensor", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1660,6 +1683,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return 4; } + if (field == "use_direct_io") { + return 3; + } if (field == "repack") { return 3; } @@ -1733,6 +1759,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return "mmap"; } + if (field == "use_direct_io") { + return "dio"; + } if (field == "repack") { return "rtr"; } @@ -1833,6 +1862,9 @@ struct markdown_printer : public printer { if (params.use_mmap.size() > 1 || params.use_mmap != cmd_params_defaults.use_mmap) { fields.emplace_back("use_mmap"); } + if (params.use_direct_io.size() > 1 || params.use_direct_io != cmd_params_defaults.use_direct_io) { + fields.emplace_back("use_direct_io"); + } if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) { fields.emplace_back("embeddings"); } diff --git a/include/llama.h b/include/llama.h index 5332d8103..e1da80723 100644 --- a/include/llama.h +++ b/include/llama.h @@ -387,6 +387,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. 
bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool repack_tensors;// repack if available diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 4a65c9cb6..7a907a294 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -15,9 +15,10 @@ #ifdef __has_include #if __has_include() #include + #include + #include #if defined(_POSIX_MAPPED_FILES) #include - #include #endif #if defined(_POSIX_MEMLOCK_RANGE) #include @@ -76,7 +77,7 @@ struct llama_file::impl { return ret; } - impl(const char * fname, const char * mode) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { fp = ggml_fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); @@ -111,7 +112,7 @@ struct llama_file::impl { } } - void read_raw(void * ptr, size_t len) const { + void read_raw(void * ptr, size_t len) { size_t bytes_read = 0; while (bytes_read < len) { size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); @@ -128,7 +129,7 @@ struct llama_file::impl { } } - uint32_t read_u32() const { + uint32_t read_u32() { uint32_t val; read_raw(&val, sizeof(val)); return val; @@ -155,16 +156,55 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return true; + } + ~impl() { if (fp) { std::fclose(fp); } } #else - impl(const char * fname, const char * mode) { - fp = ggml_fopen(fname, mode); + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) { +#ifdef __linux__ + // Try unbuffered I/O for read only + if (use_direct_io && std::strcmp(mode, "rb") == 0) { + if (init_fd()) { + return; + } + LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. 
Falling back to buffered I/O", + fname, strerror(errno)); + } +#endif + init_fp(mode); + } + +#ifdef __linux__ + bool init_fd() { + fd = open(fname.c_str(), O_RDONLY | O_DIRECT); + + if (fd != -1) { + struct stat file_stats{}; + fstat(fd, &file_stats); + + size = file_stats.st_size; + alignment = file_stats.st_blksize; + + off_t ret = lseek(fd, 0, SEEK_SET); + if (ret == -1) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + return true; + } + return false; + } +#endif + + void init_fp(const char * mode) { + fp = ggml_fopen(fname.c_str(), mode); if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno))); } seek(0, SEEK_END); size = tell(); @@ -172,46 +212,122 @@ struct llama_file::impl { } size_t tell() const { -// TODO: this ifdef is never true? -#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - if (ret == -1) { - throw std::runtime_error(format("ftell error: %s", strerror(errno))); + if (fd == -1) { + long ret = std::ftell(fp); + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; } - return (size_t) ret; + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos == -1) { + throw std::runtime_error(format("lseek error: %s", strerror(errno))); + } + return (size_t) pos; } void seek(size_t offset, int whence) const { -// TODO: this ifdef is never true? 
-#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - if (ret != 0) { + off_t ret = 0; + if (fd == -1) { + ret = std::fseek(fp, (long) offset, whence); + } else { + ret = lseek(fd, offset, whence); + } + if (ret == -1) { throw std::runtime_error(format("seek error: %s", strerror(errno))); } } - void read_raw(void * ptr, size_t len) const { + void read_raw_unsafe(void * ptr, size_t len) { if (len == 0) { return; } errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); + if (fd == -1) { + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (to_read > 0 && ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } else { + size_t bytes_read = 0; + while (bytes_read < len) { + const size_t to_read = len - bytes_read; + ssize_t ret = ::read(fd, reinterpret_cast(ptr) + bytes_read, to_read); + + if (ret == -1) { + if (errno == EINTR) { + continue; // Interrupted by signal, retry + } + // Fallback to std::fread in case the DMA controller cannot access the buffer + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); + auto curr_off = tell(); + close(fd); + fd = -1; + alignment = 1; + init_fp("rb"); + seek(curr_off, SEEK_SET); + read_raw_unsafe(ptr, len); + return; + } + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret == 0) { + // EOF: allow if this read was only pulling alignment padding past file end + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos != -1 && (size_t) pos == size) { + std::memset(reinterpret_cast(ptr) + bytes_read, 0, len - bytes_read); + return; + } + 
throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += (size_t) ret; + } } - if (ret != 1) { - throw std::runtime_error("unexpectedly reached end of file"); + } + + void read_aligned_chunk(void * dest, size_t len) { + size_t offset = tell(); + off_t aligned_offset = offset & ~(alignment - 1); + off_t offset_from_alignment = offset - aligned_offset; + size_t bytes_to_read = (offset_from_alignment + len + alignment - 1) & ~(alignment - 1); + + void * raw_buffer = nullptr; + int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read); + if (ret != 0) { + throw std::runtime_error(format("posix_memalign failed with error %d", ret)); } + + struct aligned_buffer_deleter { + void operator()(void * p) const { free(p); } + }; + std::unique_ptr buffer(raw_buffer); + + seek(aligned_offset, SEEK_SET); + read_raw_unsafe(buffer.get(), bytes_to_read); + + uintptr_t actual_data = reinterpret_cast(buffer.get()) + offset_from_alignment; + memcpy(dest, reinterpret_cast(actual_data), len); } - uint32_t read_u32() const { + void read_raw(void * ptr, size_t len) { + if (has_direct_io()) { + read_aligned_chunk(ptr, len); + } else { + read_raw_unsafe(ptr, len); + } + } + + uint32_t read_u32() { uint32_t ret; read_raw(&ret, sizeof(ret)); return ret; @@ -232,27 +348,48 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return fd != -1 && alignment > 1; + } + ~impl() { - if (fp) { + if (fd != -1) { + close(fd); + } else { std::fclose(fp); } } + int fd = -1; + std::string fname; #endif - FILE * fp; - size_t size; + size_t read_alignment() const { + return alignment; + } + + size_t alignment = 1; + + FILE * fp{}; + size_t size{}; }; -llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : + pimpl(std::make_unique(fname, mode, use_direct_io)) {} llama_file::~llama_file() = default; 
size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } +size_t llama_file::read_alignment() const { return pimpl->read_alignment(); } +bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); } + int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -262,9 +399,14 @@ int llama_file::file_id() const { } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } -void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } +void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#ifdef _WIN32 +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#else +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); } +#endif -uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } +uint32_t llama_file::read_u32() { return pimpl->read_u32(); } void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } diff --git a/src/llama-mmap.h b/src/llama-mmap.h index a1efa068f..0c658082c 100644 --- a/src/llama-mmap.h +++ b/src/llama-mmap.h @@ -3,6 +3,7 @@ #include #include #include +#include struct llama_file; struct llama_mmap; @@ -13,7 +14,7 @@ using llama_mmaps = std::vector>; using llama_mlocks = std::vector>; struct llama_file { - llama_file(const char * fname, const char * mode); + llama_file(const char * fname, const char * mode, bool use_direct_io = false); ~llama_file(); size_t tell() const; @@ -23,12 +24,17 @@ struct llama_file { void seek(size_t offset, int whence) const; - void read_raw(void * ptr, size_t len) const; - uint32_t read_u32() const; + void read_raw(void * ptr, size_t len); + void 
read_raw_unsafe(void * ptr, size_t len); + void read_aligned_chunk(void * dest, size_t size); + uint32_t read_u32(); void write_raw(const void * ptr, size_t len) const; void write_u32(uint32_t val) const; + size_t read_alignment() const; + bool has_direct_io() const; + private: struct impl; std::unique_ptr pimpl; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4437c1ecf..c13e6f3c8 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -204,7 +204,7 @@ namespace GGUFMeta { }; } -llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, +llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool use_direct_io, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv, bool merge_up_gate_exps, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { @@ -253,9 +253,23 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - files.emplace_back(new llama_file(fname.c_str(), "rb")); + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; + + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } + } + // Save tensors data offset of the main file. // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. 
@@ -295,7 +309,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); } - files.emplace_back(new llama_file(split_path, "rb")); + files.emplace_back(new llama_file(split_path, "rb", use_direct_io)); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -494,6 +508,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, } this->use_mmap = use_mmap; + this->use_direct_io = use_direct_io; this->check_tensors = check_tensors; this->repack_tensors = repack_tensors; this->use_thp = use_thp; @@ -903,7 +918,15 @@ bool llama_model_loader::load_all_data( // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. // NVMe raid configurations might require more / larger buffers. constexpr size_t n_buffers = 4; - constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + size_t alignment = 1; + for (const auto & file : files) { + alignment = std::max(file->read_alignment(), alignment); + } + + // Buffer size: balance between memory usage and I/O efficiency + // 64MB works well for NVMe drives + const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024; std::vector host_buffers; std::vector host_ptrs; @@ -995,19 +1018,54 @@ bool llama_model_loader::load_all_data( #if defined(GGML_USE_CUDA) // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. 
if (cuda_backend) { - file->seek(weight->offs, SEEK_SET); + size_t offset = weight->offs; + alignment = file->read_alignment(); + size_t aligned_offset = offset & ~(alignment - 1); + size_t offset_from_alignment = offset - aligned_offset; + file->seek(aligned_offset, SEEK_SET); + + // Calculate aligned read boundaries + size_t read_start = aligned_offset; + size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1); size_t bytes_read = 0; + size_t data_read = 0; // Actual tensor data copied (excluding padding) + + while (bytes_read < read_end - read_start) { + size_t read_size = std::min(buffer_size - 2 * alignment, read_end - read_start - bytes_read); // cap at the aligned payload size so the aligned dest offset cannot overrun the pinned buffer - while (bytes_read < n_size) { - size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + // Align the destination pointer within the pinned buffer + uintptr_t ptr_dest_aligned = (reinterpret_cast(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1); + // Wait for previous upload to complete before reusing buffer ggml_backend_event_synchronize(events[buffer_idx]); - file->read_raw(host_ptrs[buffer_idx], read_iteration); - ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + + // Read aligned chunk from file + file->read_raw_unsafe(reinterpret_cast(ptr_dest_aligned), read_size); + + // Calculate actual data portion (excluding alignment padding) + uintptr_t ptr_data = ptr_dest_aligned; + size_t data_to_copy = read_size; + + // Skip alignment padding at start of first chunk + if (bytes_read == 0) { + ptr_data += offset_from_alignment; + data_to_copy -= offset_from_alignment; + } + + // Trim alignment padding at end of last chunk + if (aligned_offset + bytes_read + read_size > offset + n_size) { + data_to_copy -= (read_end - (offset + n_size)); + } + + // Async upload actual data to GPU + ggml_backend_tensor_set_async(cuda_backend, cur, + reinterpret_cast(ptr_data), data_read, data_to_copy); ggml_backend_event_record(events[buffer_idx]); - bytes_read += 
read_iteration; + data_read += data_to_copy; + bytes_read += read_size; + ++buffer_idx; buffer_idx %= n_buffers; } diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index c59eaf4f3..1dc274d80 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -41,6 +41,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool use_direct_io = false; bool check_tensors; bool repack_tensors = false; bool use_thp = false; @@ -80,7 +81,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, + llama_model_loader(const std::string & fname, bool use_mmap, bool use_direct_io, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv, bool merge_up_gate_exps, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); diff --git a/src/llama-quantize.cpp b/src/llama-quantize.cpp index 42e3fd75b..d06968677 100644 --- a/src/llama-quantize.cpp +++ b/src/llama-quantize.cpp @@ -1008,7 +1008,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, + llama_model_loader ml(fname_inp, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, /* merge_qkv */ false, /* merge_up_gate_exps */ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching diff --git a/src/llama.cpp b/src/llama.cpp index 3d22de1e3..069123850 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2246,7 +2246,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & 
fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, + llama_model_loader ml(fname, params.use_mmap, params.use_direct_io, params.check_tensors, params.repack_tensors, params.use_thp, params.merge_qkv, params.merge_up_gate_exps, params.kv_overrides, params.tensor_buft_overrides); @@ -4216,6 +4216,7 @@ struct llama_model_params llama_model_default_params() { /*.tensor_buft_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.repack_tensors =*/ false,