Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion code/numpy/compare.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*
* The MIT License (MIT)
*
* Copyright (c) 2020-2021 Zoltán Vörös
* Copyright (c) 2020-2025 Zoltán Vörös
* 2020 Jeff Epler for Adafruit Industries
*/

Expand All @@ -23,6 +23,78 @@
#include "carray/carray_tools.h"
#include "compare.h"

#ifdef ULAB_NUMPY_HAS_BINCOUNT
mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
{ MP_QSTR_weights, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_minlength, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray"));
}
ndarray_obj_t *input = MP_OBJ_TO_PTR(args[0].u_obj);

#if ULAB_MAX_DIMS > 1
// no need to check anything, if the maximum number of dimensions is 1
if(input->ndim != 1) {
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired arrayy"));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo with arrayy

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, thanks for catching it!

}
#endif
if((input->dtype != NDARRAY_UINT8) && (input->dtype != NDARRAY_UINT16)) {
mp_raise_TypeError(MP_ERROR_TEXT("cannot cast array data from dtype"));
}

// first find the maximum of the array, and figure out how long the result should be
uint16_t max = 0;
int32_t stride = input->strides[ULAB_MAX_DIMS - 1];
if(input->dtype == NDARRAY_UINT8) {
uint8_t *iarray = (uint8_t *)input->array;
for(size_t i = 0; i < input->len; i++) {
if(*iarray > max) {
max = *iarray;
}
iarray += stride;
}
} else if(input->dtype == NDARRAY_UINT16) {
stride /= 2;
uint16_t *iarray = (uint16_t *)input->array;
for(size_t i = 0; i < input->len; i++) {
if(*iarray > max) {
max = *iarray;
}
iarray += stride;
}
}
ndarray_obj_t *result = ndarray_new_linear_array(max + 1, NDARRAY_UINT16);

// now we can do the binning
uint16_t *rarray = (uint16_t *)result->array;

if(input->dtype == NDARRAY_UINT8) {
uint8_t *iarray = (uint8_t *)input->array;
for(size_t i = 0; i < input->len; i++) {
rarray[*iarray] += 1;
iarray += stride;
}
} else if(input->dtype == NDARRAY_UINT16) {
uint16_t *iarray = (uint16_t *)input->array;
for(size_t i = 0; i < input->len; i++) {
rarray[*iarray] += 1;
iarray += stride;
}
}

return MP_OBJ_FROM_PTR(result);
}

MP_DEFINE_CONST_FUN_OBJ_KW(compare_bincount_obj, 1, compare_bincount);
#endif /* ULAB_NUMPY_HAS_BINCOUNT */

static mp_obj_t compare_function(mp_obj_t x1, mp_obj_t x2, uint8_t op) {
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1, 0);
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2, 0);
Expand Down
3 changes: 2 additions & 1 deletion code/numpy/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*
* The MIT License (MIT)
*
* Copyright (c) 2020-2021 Zoltán Vörös
* Copyright (c) 2020-2025 Zoltán Vörös
*/

#ifndef _COMPARE_
Expand All @@ -23,6 +23,7 @@ enum COMPARE_FUNCTION_TYPE {
COMPARE_CLIP,
};

MP_DECLARE_CONST_FUN_OBJ_KW(compare_bincount_obj);
MP_DECLARE_CONST_FUN_OBJ_3(compare_clip_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_equal_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_isfinite_obj);
Expand Down
3 changes: 3 additions & 0 deletions code/numpy/numpy.c
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
#if ULAB_NUMPY_HAS_ZEROS
{ MP_ROM_QSTR(MP_QSTR_zeros), MP_ROM_PTR(&create_zeros_obj) },
#endif
#if ULAB_NUMPY_HAS_BINCOUNT
{ MP_ROM_QSTR(MP_QSTR_bincount), MP_ROM_PTR(&compare_bincount_obj) },
#endif
#if ULAB_NUMPY_HAS_CLIP
{ MP_ROM_QSTR(MP_QSTR_clip), MP_ROM_PTR(&compare_clip_obj) },
#endif
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.9.0
#define ULAB_VERSION 6.10.0
#define xstr(s) str(s)
#define str(s) #s

Expand Down
6 changes: 5 additions & 1 deletion code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@
#endif

// functions that compare arrays
#ifndef ULAB_NUMPY_HAS_BINCOUNT
#define ULAB_NUMPY_HAS_BINCOUNT (1)
#endif

#ifndef ULAB_NUMPY_HAS_CLIP
#define ULAB_NUMPY_HAS_CLIP (1)
#endif
Expand Down Expand Up @@ -413,7 +417,7 @@
// the integrate module; functions of the integrate module still have
// to be defined separately
#ifndef ULAB_SCIPY_HAS_INTEGRATE_MODULE
#define ULAB_SCIPY_HAS_INTEGRATE_MODULE (1)
#define ULAB_SCIPY_HAS_INTEGRATE_MODULE (1)
#endif

#ifndef ULAB_INTEGRATE_HAS_TANHSINH
Expand Down
Loading