diff --git a/cpp/src/arrow/util/string_test.cc b/cpp/src/arrow/util/string_test.cc index 8988eb9996c0..750f7e372935 100644 --- a/cpp/src/arrow/util/string_test.cc +++ b/cpp/src/arrow/util/string_test.cc @@ -28,6 +28,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/util/regex.h" #include "arrow/util/string.h" +#include "arrow/util/base64.h" namespace arrow { namespace internal { @@ -238,6 +239,49 @@ TEST(ToChars, FloatingPoint) { } } +TEST(Base64DecodeTest, ValidInputs) { + EXPECT_EQ(arrow::util::base64_decode("Zg=="), "f"); + EXPECT_EQ(arrow::util::base64_decode("Zm8="), "fo"); + EXPECT_EQ(arrow::util::base64_decode("Zm9v"), "foo"); + EXPECT_EQ(arrow::util::base64_decode("aGVsbG8gd29ybGQ="), "hello world"); +} + +TEST(Base64DecodeTest, InvalidLength) { + EXPECT_EQ(arrow::util::base64_decode("abc"), ""); + EXPECT_EQ(arrow::util::base64_decode("abcde"), ""); +} + +TEST(Base64DecodeTest, InvalidCharacters) { + EXPECT_EQ(arrow::util::base64_decode("ab$="), ""); + EXPECT_EQ(arrow::util::base64_decode("Zm9v*"), ""); + EXPECT_EQ(arrow::util::base64_decode("abcd$AAA"), ""); +} + +TEST(Base64DecodeTest, InvalidPadding) { + EXPECT_EQ(arrow::util::base64_decode("ab=c"), ""); + EXPECT_EQ(arrow::util::base64_decode("abc==="), ""); + EXPECT_EQ(arrow::util::base64_decode("abcd=AAA"), ""); + EXPECT_EQ(arrow::util::base64_decode("Zm=9v"), ""); +} + +TEST(Base64DecodeTest, EdgeCases) { + EXPECT_EQ(arrow::util::base64_decode("===="), ""); + EXPECT_EQ(arrow::util::base64_decode("TQ=="), "M"); +} + +TEST(Base64DecodeTest, EmptyInput) { + EXPECT_EQ(arrow::util::base64_decode(""), ""); +} + +TEST(Base64DecodeTest, NonAsciiInput) { + std::string input = std::string("abcd") + char(0xFF) + "=="; + EXPECT_EQ(arrow::util::base64_decode(input), ""); +} + +TEST(Base64DecodeTest, PartialCorruption) { + EXPECT_EQ(arrow::util::base64_decode("aGVs$G8gd29ybGQ="), ""); +} + #if !defined(_WIN32) || defined(NDEBUG) TEST(ToChars, LocaleIndependent) { diff --git a/cpp/src/arrow/vendored/base64.cpp b/cpp/src/arrow/vendored/base64.cpp index 6f53c0524e71..7a768b25de8c 100644 --- a/cpp/src/arrow/vendored/base64.cpp +++ b/cpp/src/arrow/vendored/base64.cpp @@ -30,7 +30,9 @@ */ #include "arrow/util/base64.h" +#include "arrow/util/logging.h" #include +#include namespace arrow { namespace util { @@ -101,7 +103,46 @@ std::string base64_decode(std::string_view encoded_string) { unsigned char char_array_4[4], char_array_3[3]; std::string ret; - while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + auto is_base64 = [](unsigned char c) -> bool { + return (std::isalnum(c) || (c == '+') || (c == '/')); + }; + + if (encoded_string.size() % 4 != 0) { + ARROW_LOG(ERROR) << "Invalid base64 input: length is not a multiple of 4"; + return ""; + } + + size_t padding_start = encoded_string.find('='); + + if (padding_start != std::string::npos) { + for (size_t k = padding_start; k < encoded_string.size(); ++k) { + if (encoded_string[k] != '=') { + ARROW_LOG(ERROR) << "Invalid base64 input: padding character '=' found at invalid position"; + return ""; + } + } + + size_t padding_count = encoded_string.size() - padding_start; + + if (padding_count > 2) { + ARROW_LOG(ERROR) << "Invalid base64 input: too many padding characters"; + return ""; + } + } + + for (char c : encoded_string) { + if (c != '=' && !is_base64(c)) { + ARROW_LOG(ERROR) << "Invalid base64 input: contains non-base64 character '" << c << "'"; + return ""; + } + } + + while (in_len-- && encoded_string[in_] != '=') { char_array_4[i++] = encoded_string[in_]; in_++; if (i ==4) { for (i = 0; i <4; i++)