Skip to content
Merged
6 changes: 3 additions & 3 deletions test/unit_test/namespaced_features_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ BOOST_AUTO_TEST_CASE(namespaced_features_test)
BOOST_REQUIRE_THROW(feature_groups[1], VW::vw_exception);
BOOST_REQUIRE_NO_THROW(feature_groups[123]);

check_collections_exact(feature_groups.get_indices(), std::vector<namespace_index>{'a'});
check_collections_exact(feature_groups.get_indices(), std::set<namespace_index>{'a'});

feature_groups.remove_feature_group(123);
begin_end = feature_groups.get_namespace_index_groups('a');
BOOST_CHECK(begin_end.second - begin_end.first == 1);

check_collections_exact(feature_groups.get_indices(), std::vector<namespace_index>{'a'});
check_collections_exact(feature_groups.get_indices(), std::set<namespace_index>{'a'});
feature_groups.remove_feature_group(1234);
check_collections_exact(feature_groups.get_indices(), std::vector<namespace_index>{});
check_collections_exact(feature_groups.get_indices(), std::set<namespace_index>{});
}
33 changes: 26 additions & 7 deletions vowpalwabbit/namespaced_features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ const features* namespaced_features::get_feature_group(uint64_t hash) const
return &_feature_groups[it->second];
}

std::vector<namespace_index> namespaced_features::get_indices() const
const std::set<namespace_index>& namespaced_features::get_indices() const { return _contained_indices; }

namespace_index namespaced_features::get_index_for_hash(uint64_t hash) const
{
auto indices_copy = _namespace_indices;
std::sort(indices_copy.begin(), indices_copy.end());
auto last = std::unique(indices_copy.begin(), indices_copy.end());
indices_copy.erase(last, indices_copy.end());
return indices_copy;
auto it = _hash_to_index_mapping.find(hash);
#ifndef VW_NOEXCEPT
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is used with noexcept it will fail miserably, maybe add:

#else 
 if (it == end())
    return {};

and return an empty set

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed

if (it == _hash_to_index_mapping.end()) { THROW("No index found for hash: " << hash); }
#endif
return _namespace_indices[it->second];
}

std::pair<namespaced_features::indexed_iterator, namespaced_features::indexed_iterator>
Expand All @@ -55,6 +57,8 @@ features& namespaced_features::get_or_create_feature_group(uint64_t hash, namesp
auto new_index = _feature_groups.size() - 1;
_hash_to_index_mapping[hash] = new_index;
_legacy_indices_to_index_mapping[ns_index].push_back(new_index);
// If size is 1, that means this is the first time the ns_index is added and we should add it to the set.
if (_legacy_indices_to_index_mapping[ns_index].size() == 1) { _contained_indices.insert(ns_index); }
existing_group = &_feature_groups.back();
}

Expand Down Expand Up @@ -106,7 +110,12 @@ void namespaced_features::remove_feature_group(uint64_t hash)
// If any groups are left empty, remove them.
for (auto it = _legacy_indices_to_index_mapping.begin(); it != _legacy_indices_to_index_mapping.end();)
{
if (it->second.empty()) { it = _legacy_indices_to_index_mapping.erase(it); }
if (it->second.empty())
{
// There are no more feature groups which correspond to this index.
_contained_indices.erase(it->first);
it = _legacy_indices_to_index_mapping.erase(it);
}
else
{
++it;
Expand All @@ -119,6 +128,16 @@ void namespaced_features::remove_feature_group(uint64_t hash)
}
}

void namespaced_features::clear()
{
_feature_groups.clear();
_namespace_indices.clear();
_namespace_hashes.clear();
_legacy_indices_to_index_mapping.clear();
_hash_to_index_mapping.clear();
_contained_indices.clear();
}

generic_range<namespaced_features::indexed_iterator> namespaced_features::namespace_index_range(
namespace_index ns_index)
{
Expand Down
42 changes: 25 additions & 17 deletions vowpalwabbit/namespaced_features.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>
#include <unordered_map>
#include <cassert>
#include <set>

#include "feature_group.h"
#include "generic_range.h"
Expand Down Expand Up @@ -129,36 +130,38 @@ struct namespaced_features
// Returns nullptr if not found.
const features* get_feature_group(uint64_t hash) const;

// TODO - don't generate this per call.
std::vector<namespace_index> get_indices() const;
const std::set<namespace_index>& get_indices() const;
namespace_index get_index_for_hash(uint64_t hash) const;

// Returns empty range if not found
std::pair<indexed_iterator, indexed_iterator> get_namespace_index_groups(namespace_index index);
std::pair<indexed_iterator, indexed_iterator> get_namespace_index_groups(namespace_index ns_index);
// Returns empty range if not found
std::pair<const_indexed_iterator, const_indexed_iterator> get_namespace_index_groups(namespace_index index) const;
std::pair<const_indexed_iterator, const_indexed_iterator> get_namespace_index_groups(namespace_index ns_index) const;

// If a feature group already exists in this "slot" it will be merged
template <typename FeaturesT>
features& merge_feature_group(FeaturesT&& ftrs, uint64_t hash, namespace_index index);
features& merge_feature_group(FeaturesT&& ftrs, uint64_t hash, namespace_index ns_index);

// If no feature group already exists a default one will be created.
// Creating new feature groups will invalidate any pointers or references held.
features& get_or_create_feature_group(uint64_t hash, namespace_index index);
features& get_or_create_feature_group(uint64_t hash, namespace_index ns_index);

const features& operator[](uint64_t hash) const;
features& operator[](uint64_t hash);

// Removing a feature group will invalidate any pointers or references held.
void remove_feature_group(uint64_t hash);

generic_range<indexed_iterator> namespace_index_range(namespace_index index);
generic_range<const_indexed_iterator> namespace_index_range(namespace_index index) const;
indexed_iterator namespace_index_begin(namespace_index index);
indexed_iterator namespace_index_end(namespace_index index);
const_indexed_iterator namespace_index_begin(namespace_index index) const;
const_indexed_iterator namespace_index_end(namespace_index index) const;
const_indexed_iterator namespace_index_cbegin(namespace_index index) const;
const_indexed_iterator namespace_index_cend(namespace_index index) const;
void clear();

generic_range<indexed_iterator> namespace_index_range(namespace_index ns_index);
generic_range<const_indexed_iterator> namespace_index_range(namespace_index ns_index) const;
indexed_iterator namespace_index_begin(namespace_index ns_index);
indexed_iterator namespace_index_end(namespace_index ns_index);
const_indexed_iterator namespace_index_begin(namespace_index ns_index) const;
const_indexed_iterator namespace_index_end(namespace_index ns_index) const;
const_indexed_iterator namespace_index_cbegin(namespace_index ns_index) const;
const_indexed_iterator namespace_index_cend(namespace_index ns_index) const;

iterator begin();
iterator end();
Expand All @@ -169,11 +172,14 @@ struct namespaced_features

private:
std::vector<features> _feature_groups;
// Can have duplicate values.
std::vector<namespace_index> _namespace_indices;
// Should never have duplicate values.
std::vector<uint64_t> _namespace_hashes;

std::unordered_map<namespace_index, std::vector<size_t>> _legacy_indices_to_index_mapping;
std::unordered_map<uint64_t, size_t> _hash_to_index_mapping;
std::set<namespace_index> _contained_indices;
};

// If a feature group already exists in this "slot" it will be merged
Expand All @@ -190,6 +196,8 @@ features& namespaced_features::merge_feature_group(FeaturesT&& ftrs, uint64_t ha
auto new_index = _feature_groups.size() - 1;
_hash_to_index_mapping[hash] = new_index;
_legacy_indices_to_index_mapping[ns_index].push_back(new_index);
// If size is 1, that means this is the first time the ns_index is added and we should add it to the set.
if (_legacy_indices_to_index_mapping[ns_index].size() == 1) { _contained_indices.insert(ns_index); }
existing_group = &_feature_groups.back();
}
else
Expand All @@ -198,9 +206,9 @@ features& namespaced_features::merge_feature_group(FeaturesT&& ftrs, uint64_t ha
auto existing_index = _hash_to_index_mapping[hash];
// Should we ensure that this doesnt already exist under a DIFFERENT namespace_index?
// However, his shouldn't be possible as ns_index depends on hash.
auto& indices_list = _legacy_indices_to_index_mapping[ns_index];
if (std::find(indices_list.begin(), indices_list.end(), ns_index) == indices_list.end())
{ indices_list.push_back(existing_index); }
auto& ns_indices_list = _legacy_indices_to_index_mapping[ns_index];
if (std::find(ns_indices_list.begin(), ns_indices_list.end(), ns_index) == ns_indices_list.end())
{ ns_indices_list.push_back(existing_index); }
}
return *existing_group;
}
Expand Down