Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions visualdl/logic/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,6 @@ struct HistogramBuilder {
T right_boundary{std::numeric_limits<T>::min()};
std::vector<int> buckets;

void Get(size_t n, T* left, T* right, int* frequency) {
CHECK(!buckets.empty()) << "need to CreateBuckets first.";
CHECK_LT(n, num_buckets_) << "n out of range.";
*left = left_boundary + span_ * n;
*right = *left + span_;
*frequency = buckets[n];
}

private:
// Get the left and right boundaries.
void UpdateBoundary(const std::vector<T>& data) {
Expand All @@ -106,9 +98,11 @@ struct HistogramBuilder {
(float)left_boundary / num_buckets_;
buckets.resize(num_buckets_);

// Go through the data, increase the item count in a bucket.
for (auto v : data) {
int offset = std::min(int((v - left_boundary) / span_), num_buckets_ - 1);
buckets[offset]++;
int bucket_group_index =
std::min(int((v - left_boundary) / span_), num_buckets_ - 1);
buckets[bucket_group_index]++;
}
}

Expand Down
29 changes: 26 additions & 3 deletions visualdl/logic/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,14 @@ PYBIND11_MODULE(core, m) {
auto tablet = self.tablet(tag);
return vs::components::TextReader(tablet);
})
.def("get_audio", [](vs::LogReader& self, const std::string& tag) {
.def("get_audio",
[](vs::LogReader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::AudioReader(self.mode(), tablet);
})
.def("get_embedding", [](vs::LogReader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::AudioReader(self.mode(), tablet);
return vs::components::EmbeddingReader(tablet);
});

// clang-format on
Expand Down Expand Up @@ -136,7 +141,11 @@ PYBIND11_MODULE(core, m) {
int step_cycle) {
auto tablet = self.AddTablet(tag);
return vs::components::Audio(tablet, num_samples, step_cycle);
});
})
.def("new_embedding", [](vs::LogWriter& self, const std::string& tag) {
auto tablet = self.AddTablet(tag);
return vs::components::Embedding(tablet);
});

//------------------- components --------------------
#define ADD_SCALAR_READER(T) \
Expand Down Expand Up @@ -233,6 +242,20 @@ PYBIND11_MODULE(core, m) {
.def("total_records", &cp::TextReader::total_records)
.def("size", &cp::TextReader::size);

py::class_<cp::Embedding>(m, "EmbeddingWriter")
.def("set_caption", &cp::Embedding::SetCaption)
.def("add_embeddings_with_word_list",
&cp::Embedding::AddEmbeddingsWithWordList);

py::class_<cp::EmbeddingReader>(m, "EmbeddingReader")
.def("get_all_labels", &cp::EmbeddingReader::get_all_labels)
.def("get_all_embeddings", &cp::EmbeddingReader::get_all_embeddings)
.def("ids", &cp::EmbeddingReader::ids)
.def("timestamps", &cp::EmbeddingReader::timestamps)
.def("caption", &cp::EmbeddingReader::caption)
.def("total_records", &cp::EmbeddingReader::total_records)
.def("size", &cp::EmbeddingReader::size);

py::class_<cp::Audio>(m, "AudioWriter", R"pbdoc(
PyBind class. Must instantiate through the LogWriter.
)pbdoc")
Expand Down
73 changes: 73 additions & 0 deletions visualdl/logic/sdk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,79 @@ std::string TextReader::caption() const {

size_t TextReader::size() const { return reader_.total_records(); }

/*
* Embedding functions
*/
void Embedding::AddEmbeddingsWithWordList(
const std::vector<std::vector<float>>& word_embeddings,
std::vector<std::string>& labels) {
for (int i = 0; i < word_embeddings.size(); i++) {
AddEmbedding(i, word_embeddings[i], labels[i]);
}
}

void Embedding::AddEmbedding(int item_id,
const std::vector<float>& one_hot_vector,
std::string& label) {
auto record = tablet_.AddRecord();
record.SetId(item_id);
time_t time = std::time(nullptr);
record.SetTimeStamp(time);
auto entry = record.AddData();
entry.SetMulti<float>(one_hot_vector);
entry.SetRaw(label);
}

/*
* EmbeddingReader functions
*/
std::vector<std::string> EmbeddingReader::get_all_labels() const {
std::vector<std::string> result;

for (int i = 0; i < total_records(); i++) {
auto record = reader_.record(i);
auto entry = record.data(0);
result.push_back(entry.GetRaw());
}
return result;
}

std::vector<std::vector<float>> EmbeddingReader::get_all_embeddings() const {
std::vector<std::vector<float>> result;

for (int i = 0; i < total_records(); i++) {
auto record = reader_.record(i);
auto entry = record.data(0);
auto tensors = entry.GetMulti<float>();

result.push_back(tensors);
}
return result;
}

std::vector<int> EmbeddingReader::ids() const {
std::vector<int> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).id());
}
return res;
}

std::vector<time_t> EmbeddingReader::timestamps() const {
std::vector<time_t> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).timestamp());
}
return res;
}

std::string EmbeddingReader::caption() const {
CHECK(!reader_.captions().empty()) << "no caption";
return reader_.captions().front();
}

size_t EmbeddingReader::size() const { return reader_.total_records(); }

void Audio::StartSampling() {
if (!ToSampleThisStep()) return;

Expand Down
46 changes: 46 additions & 0 deletions visualdl/logic/sdk.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,52 @@ struct TextReader {
TabletReader reader_;
};

/*
* Embedding component writer
*/
struct Embedding {
Embedding(Tablet tablet) : tablet_(tablet) {
tablet_.SetType(Tablet::Type::kEmbedding);
}
void SetCaption(const std::string cap) {
tablet_.SetCaptions(std::vector<std::string>({cap}));
}

// Add all word vectors along with all labels
// The index of labels should match with the index of word_embeddings
// EX: ["Apple", "Orange"] means the first item in word_embeddings represents
// "Apple"
void AddEmbeddingsWithWordList(
const std::vector<std::vector<float>>& word_embeddings,
std::vector<std::string>& labels);
// TODO: Create another function that takes 'word_embeddings' and 'word_dict'
private:
void AddEmbedding(int item_id,
const std::vector<float>& one_hot_vector,
std::string& label);

Tablet tablet_;
};

/*
* Embedding Reader.
*/
struct EmbeddingReader {
EmbeddingReader(TabletReader reader) : reader_(reader) {}

std::vector<int> ids() const;
std::vector<std::string> get_all_labels() const;
std::vector<std::vector<float>> get_all_embeddings() const;

std::vector<time_t> timestamps() const;
std::string caption() const;
size_t total_records() const { return reader_.total_records(); }
size_t size() const;

private:
TabletReader reader_;
};

/*
* Image component writer.
*/
Expand Down
8 changes: 8 additions & 0 deletions visualdl/python/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def text(self, tag):
check_tag_name_valid(tag)
return self.reader.get_text(tag)

def embedding(self, tag):
check_tag_name_valid(tag)
return self.reader.get_embedding(tag)

def audio(self, tag):
"""
Get an audio reader with tag
Expand Down Expand Up @@ -256,6 +260,10 @@ def text(self, tag):
check_tag_name_valid(tag)
return self.writer.new_text(tag)

def embedding(self, tag):
check_tag_name_valid(tag)
return self.writer.new_embedding(tag)

def save(self):
self.writer.save()

Expand Down
1 change: 1 addition & 0 deletions visualdl/storage/storage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ message Tablet {
kImage = 2;
kText = 3;
kAudio = 4;
kEmbedding = 5;
}
// The unique identification for this `Tablet`. VisualDL will have no the
// concept of FileWriter like TB. It will store all the tablets in a single
Expand Down
4 changes: 4 additions & 0 deletions visualdl/storage/tablet.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct Tablet {
kImage = 2,
kText = 3,
kAudio = 4,
kEmbedding = 5,
kUnknown = -1
};

Expand All @@ -59,6 +60,9 @@ struct Tablet {
if (name == "audio") {
return kAudio;
}
if (name == "embedding") {
return kEmbedding;
}
LOG(ERROR) << "unknown component: " << name;
return kUnknown;
}
Expand Down