Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,12 @@ struct CommandEncoder {
enc_->updateFence(fence);
}

template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
}

// TODO: Code is duplicated but they should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}

Expand Down
12 changes: 12 additions & 0 deletions mlx/small_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,18 @@ class SmallVector {
std::is_trivially_destructible<T>::value;
};

template <typename>
struct is_vector : std::false_type {};

template <typename T, size_t Size, typename Allocator>
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};

template <typename T, typename Allocator>
struct is_vector<std::vector<T, Allocator>> : std::true_type {};

template <typename Vec>
inline constexpr bool is_vector_v = is_vector<Vec>::value;

#undef MLX_HAS_BUILTIN
#undef MLX_HAS_ATTRIBUTE
#undef MLX_LIKELY
Expand Down
37 changes: 0 additions & 37 deletions mlx/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}

std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}

std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}

// TODO: Code is duplicated but they should be deleted soon.
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}

std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}

namespace env {

int get_var(const char* name, int default_value) {
Expand Down
17 changes: 13 additions & 4 deletions mlx/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}
Expand All @@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}

template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
os << "(";
for (auto it = v.begin(); it != v.end(); ++it) {
os << *it;
if (it != std::prev(v.end())) {
os << ",";
}
}
os << ")";
return os;
}

inline bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
Expand Down