Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3468,6 +3468,14 @@ array cumsum(
{a});
}

array cumsum(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
}

array cumprod(
const array& a,
int axis,
Expand All @@ -3490,6 +3498,14 @@ array cumprod(
{a});
}

array cumprod(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
return cumprod(flatten(a, s), 0, reverse, inclusive, s);
}

array cummax(
const array& a,
int axis,
Expand All @@ -3512,6 +3528,14 @@ array cummax(
{a});
}

array cummax(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
return cummax(flatten(a, s), 0, reverse, inclusive, s);
}

array cummin(
const array& a,
int axis,
Expand All @@ -3534,6 +3558,14 @@ array cummin(
{a});
}

array cummin(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
return cummin(flatten(a, s), 0, reverse, inclusive, s);
}

array logcumsumexp(
const array& a,
int axis,
Expand All @@ -3556,6 +3588,15 @@ array logcumsumexp(
{a});
}

array logcumsumexp(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
StreamOrDevice s /* = {}*/) {
return logcumsumexp(
flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
}

/** Convolution operations */

namespace {
Expand Down
35 changes: 35 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,13 @@ array topk(const array& a, int k, StreamOrDevice s = {});
array topk(const array& a, int k, int axis, StreamOrDevice s = {});

/** Cumulative logsumexp of an array. */
array logcumsumexp(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});

/** Cumulative logsumexp of an array along the given axis. */
array logcumsumexp(
const array& a,
int axis,
Expand Down Expand Up @@ -1186,6 +1193,13 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
array power(const array& a, const array& b, StreamOrDevice s = {});

/** Cumulative sum of an array. */
array cumsum(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});

/** Cumulative sum of an array along the given axis. */
array cumsum(
const array& a,
int axis,
Expand All @@ -1194,6 +1208,13 @@ array cumsum(
StreamOrDevice s = {});

/** Cumulative product of an array. */
array cumprod(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});

/** Cumulative product of an array along the given axis. */
array cumprod(
const array& a,
int axis,
Expand All @@ -1202,6 +1223,13 @@ array cumprod(
StreamOrDevice s = {});

/** Cumulative max of an array. */
array cummax(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});

/** Cumulative max of an array along the given axis. */
array cummax(
const array& a,
int axis,
Expand All @@ -1210,6 +1238,13 @@ array cummax(
StreamOrDevice s = {});

/** Cumulative min of an array. */
array cummin(
const array& a,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});

/** Cumulative min of an array along the given axis. */
array cummin(
const array& a,
int axis,
Expand Down
24 changes: 5 additions & 19 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1275,10 +1275,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::logcumsumexp(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::logcumsumexp(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
Expand Down Expand Up @@ -1408,9 +1405,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cumsum(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumsum(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
Expand All @@ -1429,10 +1424,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cumprod(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cumprod(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumprod(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
Expand All @@ -1451,10 +1443,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cummax(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cummax(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cummax(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
Expand All @@ -1473,10 +1462,7 @@ void init_array(nb::module_& m) {
if (axis) {
return mx::cummin(a, *axis, reverse, inclusive, s);
} else {
// TODO: Implement that in the C++ API as well. See concatenate
// above.
return mx::cummin(
mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cummin(a, reverse, inclusive, s);
}
},
"axis"_a = nb::none(),
Expand Down