|
#include <functional>
#include <numeric>
#include <optional>
#include <unordered_set>
#include <utility>
32 | 33 |
|
33 | 34 | namespace cudf::io::parquet::experimental::detail { |
34 | 35 |
|
35 | 36 | using aggregate_reader_metadata_base = parquet::detail::aggregate_reader_metadata; |
36 | 37 | using metadata_base = parquet::detail::metadata; |
37 | 38 |
|
| 39 | +using io::detail::inline_column_buffer; |
38 | 40 | using parquet::detail::CompactProtocolReader; |
39 | 41 | using parquet::detail::equality_literals_collector; |
40 | 42 | using parquet::detail::input_column_info; |
41 | 43 | using parquet::detail::row_group_info; |
42 | 44 |
|
| 45 | +namespace { |
| 46 | + |
| 47 | +[[nodiscard]] auto all_row_group_indices( |
| 48 | + host_span<std::vector<cudf::size_type> const> row_group_indices) |
| 49 | +{ |
| 50 | + return std::vector<std::vector<cudf::size_type>>(row_group_indices.begin(), |
| 51 | + row_group_indices.end()); |
| 52 | +} |
| 53 | + |
| 54 | +} // namespace |
| 55 | + |
43 | 56 | metadata::metadata(cudf::host_span<uint8_t const> footer_bytes) |
44 | 57 | { |
45 | 58 | CompactProtocolReader cp(footer_bytes.data(), footer_bytes.size()); |
@@ -137,4 +150,117 @@ void aggregate_reader_metadata::setup_page_index(cudf::host_span<uint8_t const> |
137 | 150 | } |
138 | 151 | } |
139 | 152 |
|
| 153 | +std::tuple<std::vector<input_column_info>, |
| 154 | + std::vector<inline_column_buffer>, |
| 155 | + std::vector<cudf::size_type>> |
| 156 | +aggregate_reader_metadata::select_payload_columns( |
| 157 | + std::optional<std::vector<std::string>> const& payload_column_names, |
| 158 | + std::optional<std::vector<std::string>> const& filter_column_names, |
| 159 | + bool include_index, |
| 160 | + bool strings_to_categorical, |
| 161 | + type_id timestamp_type_id) |
| 162 | +{ |
| 163 | + // If neither payload nor filter columns are specified, select all columns |
| 164 | + if (not payload_column_names.has_value() and not filter_column_names.has_value()) { |
| 165 | + // Call the base `select_columns()` method without specifying any columns |
| 166 | + return select_columns({}, {}, include_index, strings_to_categorical, timestamp_type_id); |
| 167 | + } |
| 168 | + |
| 169 | + std::vector<std::string> valid_payload_columns; |
| 170 | + |
| 171 | + // If payload columns are specified, only select payload columns that do not appear in the filter |
| 172 | + // expression |
| 173 | + if (payload_column_names.has_value()) { |
| 174 | + valid_payload_columns = *payload_column_names; |
| 175 | + // Remove filter columns from the provided payload column names |
| 176 | + if (filter_column_names.has_value() and not filter_column_names->empty()) { |
| 177 | + // Add filter column names to a hash set for faster lookup |
| 178 | + std::unordered_set<std::string> filter_columns_set(filter_column_names->begin(), |
| 179 | + filter_column_names->end()); |
| 180 | + // Remove a payload column name if it is also present in the hash set |
| 181 | + valid_payload_columns.erase(std::remove_if(valid_payload_columns.begin(), |
| 182 | + valid_payload_columns.end(), |
| 183 | + [&filter_columns_set](auto const& col) { |
| 184 | + return filter_columns_set.count(col) > 0; |
| 185 | + }), |
| 186 | + valid_payload_columns.end()); |
| 187 | + } |
| 188 | + // Call the base `select_columns()` method with valid payload columns |
| 189 | + return select_columns( |
| 190 | + valid_payload_columns, {}, include_index, strings_to_categorical, timestamp_type_id); |
| 191 | + } |
| 192 | + |
| 193 | + // Else if only filter columns are specified, select all columns that do not appear in the |
| 194 | + // filter expression |
| 195 | + |
| 196 | + // Add filter column names to a hash set for faster lookup |
| 197 | + std::unordered_set<std::string> filter_columns_set(filter_column_names->begin(), |
| 198 | + filter_column_names->end()); |
| 199 | + |
| 200 | + std::function<void(std::string, int)> add_column_path = [&](std::string path_till_now, |
| 201 | + int schema_idx) { |
| 202 | + auto const& schema_elem = get_schema(schema_idx); |
| 203 | + std::string const curr_path = path_till_now + schema_elem.name; |
| 204 | + // If the current path is not a filter column, then add it and its children to the list of valid |
| 205 | + // payload columns |
| 206 | + if (filter_columns_set.count(curr_path) == 0) { |
| 207 | + valid_payload_columns.push_back(curr_path); |
| 208 | + // Add all children as well |
| 209 | + for (auto const& child_idx : schema_elem.children_idx) { |
| 210 | + add_column_path(curr_path + ".", child_idx); |
| 211 | + } |
| 212 | + } |
| 213 | + }; |
| 214 | + |
| 215 | + // Add all but filter columns to valid payload columns |
| 216 | + if (not filter_column_names->empty()) { |
| 217 | + for (auto const& child_idx : get_schema(0).children_idx) { |
| 218 | + add_column_path("", child_idx); |
| 219 | + } |
| 220 | + } |
| 221 | + |
| 222 | + // Call the base `select_columns()` method with all but filter columns |
| 223 | + return select_columns( |
| 224 | + valid_payload_columns, {}, include_index, strings_to_categorical, timestamp_type_id); |
| 225 | +} |
| 226 | + |
| 227 | +std::vector<std::vector<cudf::size_type>> aggregate_reader_metadata::filter_row_groups_with_stats( |
| 228 | + host_span<std::vector<cudf::size_type> const> row_group_indices, |
| 229 | + host_span<data_type const> output_dtypes, |
| 230 | + host_span<int const> output_column_schemas, |
| 231 | + std::optional<std::reference_wrapper<ast::expression const>> filter, |
| 232 | + rmm::cuda_stream_view stream) const |
| 233 | +{ |
| 234 | + // Return all row groups if no filter expression |
| 235 | + if (not filter.has_value()) { return all_row_group_indices(row_group_indices); } |
| 236 | + |
| 237 | + // Compute total number of input row groups |
| 238 | + cudf::size_type total_row_groups = [&]() { |
| 239 | + if (not row_group_indices.empty()) { |
| 240 | + size_t const total_row_groups = |
| 241 | + std::accumulate(row_group_indices.begin(), |
| 242 | + row_group_indices.end(), |
| 243 | + size_t{0}, |
| 244 | + [](auto sum, auto const& pfm) { return sum + pfm.size(); }); |
| 245 | + |
| 246 | + // Check if we have less than 2B total row groups. |
| 247 | + CUDF_EXPECTS(total_row_groups <= std::numeric_limits<cudf::size_type>::max(), |
| 248 | + "Total number of row groups exceed the cudf::size_type's limit"); |
| 249 | + return static_cast<cudf::size_type>(total_row_groups); |
| 250 | + } else { |
| 251 | + return num_row_groups; |
| 252 | + } |
| 253 | + }(); |
| 254 | + |
| 255 | + // Filter stats table with StatsAST expression and collect filtered row group indices |
| 256 | + auto const stats_filtered_row_group_indices = apply_stats_filters(row_group_indices, |
| 257 | + total_row_groups, |
| 258 | + output_dtypes, |
| 259 | + output_column_schemas, |
| 260 | + filter.value(), |
| 261 | + stream); |
| 262 | + |
| 263 | + return stats_filtered_row_group_indices.value_or(all_row_group_indices(row_group_indices)); |
| 264 | +} |
| 265 | + |
140 | 266 | } // namespace cudf::io::parquet::experimental::detail |
0 commit comments