Skip to content

Commit 9250c1b

Browse files
authored
[mypyc] Add get/set item, len and truncate to BytesWriter (#20298)
Add support for `b[i]`, `b[i] = x`, `len(b)` and `b.truncate(n)` to `librt.strings.BytesWriter`. The implementations are still unoptimized. Get item and set item don't have primitives yet, but I'll add them in a follow-up PR.
1 parent 688ea0c commit 9250c1b

File tree

8 files changed

+299
-22
lines changed

8 files changed

+299
-22
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from typing import final
22

3+
from mypy_extensions import i64, u8
4+
35
@final
46
class BytesWriter:
57
def append(self, /, x: int) -> None: ...
68
def write(self, /, b: bytes) -> None: ...
79
def getvalue(self) -> bytes: ...
10+
def truncate(self, /, size: i64) -> None: ...
11+
def __len__(self) -> i64: ...
12+
def __getitem__(self, /, i: i64) -> u8: ...
13+
def __setitem__(self, /, i: i64, x: u8) -> None: ...

mypyc/irbuild/ll_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
ERR_NEG_INT,
180180
CFunctionDescription,
181181
binary_ops,
182+
function_ops,
182183
method_call_ops,
183184
unary_ops,
184185
)
@@ -2489,7 +2490,11 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val
24892490
self.activate_block(ok)
24902491
return length
24912492

2492-
# generic case
2493+
op = self.matching_primitive_op(function_ops["builtins.len"], [val], line)
2494+
if op is not None:
2495+
return op
2496+
2497+
# Fallback generic case
24932498
if use_pyssize_t:
24942499
return self.call_c(generic_ssize_t_len_op, [val], line)
24952500
else:

mypyc/lib-rt/librt_strings.c

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ _grow_buffer(BytesWriterObject *data, Py_ssize_t n) {
4141
do {
4242
size *= 2;
4343
} while (target >= size);
44-
if (old_size == WRITER_EMBEDDED_BUF_LEN) {
44+
if (data->buf == data->data) {
4545
// Move from embedded buffer to heap-allocated buffer
4646
data->buf = PyMem_Malloc(size);
4747
if (data->buf != NULL) {
@@ -155,8 +155,65 @@ BytesWriter_getvalue(BytesWriterObject *self, PyObject *Py_UNUSED(ignored))
155155
return PyBytes_FromStringAndSize(self->buf, self->len);
156156
}
157157

158+
static Py_ssize_t
159+
BytesWriter_length(BytesWriterObject *self)
160+
{
161+
return self->len;
162+
}
163+
164+
static PyObject*
165+
BytesWriter_item(BytesWriterObject *self, Py_ssize_t index)
166+
{
167+
Py_ssize_t length = self->len;
168+
169+
// Check bounds
170+
if (index < 0 || index >= length) {
171+
PyErr_SetString(PyExc_IndexError, "BytesWriter index out of range");
172+
return NULL;
173+
}
174+
175+
// Return the byte at the given index as a Python int
176+
return PyLong_FromLong((unsigned char)self->buf[index]);
177+
}
178+
179+
static int
180+
BytesWriter_ass_item(BytesWriterObject *self, Py_ssize_t index, PyObject *value)
181+
{
182+
Py_ssize_t length = self->len;
183+
184+
// Check bounds
185+
if (index < 0 || index >= length) {
186+
PyErr_SetString(PyExc_IndexError, "BytesWriter index out of range");
187+
return -1;
188+
}
189+
190+
// Check that value is not NULL (deletion not supported)
191+
if (value == NULL) {
192+
PyErr_SetString(PyExc_TypeError, "BytesWriter does not support item deletion");
193+
return -1;
194+
}
195+
196+
// Convert value to uint8
197+
uint8_t byte_value = CPyLong_AsUInt8(value);
198+
if (unlikely(byte_value == CPY_LL_UINT_ERROR && PyErr_Occurred())) {
199+
CPy_TypeError("u8", value);
200+
return -1;
201+
}
202+
203+
// Assign the byte
204+
self->buf[index] = (char)byte_value;
205+
return 0;
206+
}
207+
208+
static PySequenceMethods BytesWriter_as_sequence = {
209+
.sq_length = (lenfunc)BytesWriter_length,
210+
.sq_item = (ssizeargfunc)BytesWriter_item,
211+
.sq_ass_item = (ssizeobjargproc)BytesWriter_ass_item,
212+
};
213+
158214
static PyObject* BytesWriter_append(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames);
159215
static PyObject* BytesWriter_write(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames);
216+
static PyObject* BytesWriter_truncate(PyObject *self, PyObject *const *args, size_t nargs);
160217

161218
static PyMethodDef BytesWriter_methods[] = {
162219
{"append", (PyCFunction) BytesWriter_append, METH_FASTCALL | METH_KEYWORDS,
@@ -168,6 +225,9 @@ static PyMethodDef BytesWriter_methods[] = {
168225
{"getvalue", (PyCFunction) BytesWriter_getvalue, METH_NOARGS,
169226
"Return the buffer content as bytes object"
170227
},
228+
{"truncate", (PyCFunction) BytesWriter_truncate, METH_FASTCALL,
229+
PyDoc_STR("Truncate the buffer to the specified size")
230+
},
171231
{NULL} /* Sentinel */
172232
};
173233

@@ -182,6 +242,7 @@ static PyTypeObject BytesWriterType = {
182242
.tp_init = (initproc) BytesWriter_init,
183243
.tp_dealloc = (destructor) BytesWriter_dealloc,
184244
.tp_methods = BytesWriter_methods,
245+
.tp_as_sequence = &BytesWriter_as_sequence,
185246
.tp_repr = (reprfunc)BytesWriter_repr,
186247
};
187248

@@ -268,11 +329,68 @@ BytesWriter_append(PyObject *self, PyObject *const *args, size_t nargs, PyObject
268329
return Py_None;
269330
}
270331

332+
static char
333+
BytesWriter_truncate_internal(PyObject *self, int64_t size) {
334+
BytesWriterObject *writer = (BytesWriterObject *)self;
335+
Py_ssize_t current_size = writer->len;
336+
337+
// Validate size is non-negative
338+
if (size < 0) {
339+
PyErr_SetString(PyExc_ValueError, "size must be non-negative");
340+
return CPY_NONE_ERROR;
341+
}
342+
343+
// Validate size doesn't exceed current size
344+
if (size > current_size) {
345+
PyErr_SetString(PyExc_ValueError, "size cannot be larger than current buffer size");
346+
return CPY_NONE_ERROR;
347+
}
348+
349+
writer->len = size;
350+
return CPY_NONE;
351+
}
352+
353+
static PyObject*
354+
BytesWriter_truncate(PyObject *self, PyObject *const *args, size_t nargs) {
355+
if (unlikely(nargs != 1)) {
356+
PyErr_Format(PyExc_TypeError,
357+
"truncate() takes exactly 1 argument (%zu given)", nargs);
358+
return NULL;
359+
}
360+
if (!check_bytes_writer(self)) {
361+
return NULL;
362+
}
363+
364+
PyObject *size_obj = args[0];
365+
int overflow;
366+
long long size = PyLong_AsLongLongAndOverflow(size_obj, &overflow);
367+
368+
if (size == -1 && PyErr_Occurred()) {
369+
return NULL;
370+
}
371+
if (overflow != 0) {
372+
PyErr_SetString(PyExc_ValueError, "integer out of range");
373+
return NULL;
374+
}
375+
376+
if (unlikely(BytesWriter_truncate_internal(self, size) == CPY_NONE_ERROR)) {
377+
return NULL;
378+
}
379+
Py_INCREF(Py_None);
380+
return Py_None;
381+
}
382+
271383
static PyTypeObject *
272384
BytesWriter_type_internal(void) {
273385
return &BytesWriterType; // Return borrowed reference
274386
};
275387

388+
static CPyTagged
389+
BytesWriter_len_internal(PyObject *self) {
390+
BytesWriterObject *writer = (BytesWriterObject *)self;
391+
return writer->len << 1;
392+
}
393+
276394
static PyMethodDef librt_strings_module_methods[] = {
277395
{NULL, NULL, 0, NULL}
278396
};
@@ -311,6 +429,8 @@ librt_strings_module_exec(PyObject *m)
311429
(void *)BytesWriter_append_internal,
312430
(void *)BytesWriter_write_internal,
313431
(void *)BytesWriter_type_internal,
432+
(void *)BytesWriter_len_internal,
433+
(void *)BytesWriter_truncate_internal,
314434
};
315435
PyObject *c_api_object = PyCapsule_New((void *)librt_strings_api, "librt.strings._C_API", NULL);
316436
if (PyModule_Add(m, "_C_API", c_api_object) < 0) {

mypyc/lib-rt/librt_strings.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ import_librt_strings(void)
2020
// API version -- more recent versions must maintain backward compatibility, i.e.
2121
// we can add new features but not remove or change existing features (unless
2222
// ABI version is changed, but see the comment above).
23-
#define LIBRT_STRINGS_API_VERSION 0
23+
#define LIBRT_STRINGS_API_VERSION 1
2424

2525
// Number of functions in the capsule API. If you add a new function, also increase
2626
// LIBRT_STRINGS_API_VERSION.
27-
#define LIBRT_STRINGS_API_LEN 7
27+
#define LIBRT_STRINGS_API_LEN 9
2828

2929
static void *LibRTStrings_API[LIBRT_STRINGS_API_LEN];
3030

@@ -35,6 +35,8 @@ static void *LibRTStrings_API[LIBRT_STRINGS_API_LEN];
3535
#define LibRTStrings_BytesWriter_append_internal (*(char (*)(PyObject *source, uint8_t value)) LibRTStrings_API[4])
3636
#define LibRTStrings_BytesWriter_write_internal (*(char (*)(PyObject *source, PyObject *value)) LibRTStrings_API[5])
3737
#define LibRTStrings_BytesWriter_type_internal (*(PyTypeObject* (*)(void)) LibRTStrings_API[6])
38+
#define LibRTStrings_BytesWriter_len_internal (*(CPyTagged (*)(PyObject *self)) LibRTStrings_API[7])
39+
#define LibRTStrings_BytesWriter_truncate_internal (*(char (*)(PyObject *self, int64_t size)) LibRTStrings_API[8])
3840

3941
static int
4042
import_librt_strings(void)

mypyc/primitives/librt_strings_ops.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from typing import Final
22

3-
from mypyc.ir.ops import ERR_MAGIC
4-
from mypyc.ir.rtypes import KNOWN_NATIVE_TYPES, bytes_rprimitive, none_rprimitive, uint8_rprimitive
3+
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
4+
from mypyc.ir.rtypes import (
5+
KNOWN_NATIVE_TYPES,
6+
bytes_rprimitive,
7+
int64_rprimitive,
8+
none_rprimitive,
9+
short_int_rprimitive,
10+
uint8_rprimitive,
11+
)
512
from mypyc.primitives.registry import function_op, method_op
613

714
bytes_writer_rprimitive: Final = KNOWN_NATIVE_TYPES["librt.strings.BytesWriter"]
@@ -45,3 +52,21 @@
4552
experimental=True,
4653
capsule="librt.strings",
4754
)
55+
56+
method_op(
57+
name="truncate",
58+
arg_types=[bytes_writer_rprimitive, int64_rprimitive],
59+
return_type=none_rprimitive,
60+
c_function_name="LibRTStrings_BytesWriter_truncate_internal",
61+
error_kind=ERR_MAGIC,
62+
)
63+
64+
function_op(
65+
name="builtins.len",
66+
arg_types=[bytes_writer_rprimitive],
67+
return_type=short_int_rprimitive,
68+
c_function_name="LibRTStrings_BytesWriter_len_internal",
69+
error_kind=ERR_NEVER,
70+
experimental=True,
71+
capsule="librt.strings",
72+
)

mypyc/primitives/registry.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77
88
Example op definition:
99
10-
list_len_op = func_op(name='builtins.len',
11-
arg_types=[list_rprimitive],
12-
result_type=short_int_rprimitive,
13-
error_kind=ERR_NEVER,
14-
emit=emit_len)
10+
list_len_op = function_op(name='builtins.len',
11+
arg_types=[list_rprimitive],
12+
result_type=short_int_rprimitive,
13+
error_kind=ERR_NEVER,
14+
c_function_name="...")
1515
1616
This op is automatically generated for calls to len() with a single
1717
list argument. The result type is short_int_rprimitive, and this
18-
never raises an exception (ERR_NEVER). The function emit_len is used
19-
to generate C for this op. The op can also be manually generated using
20-
"list_len_op". Ops that are only generated automatically don't need to
18+
never raises an exception (ERR_NEVER). The function c_function_name is
19+
called when generating C for this op. The op can also be manually generated
20+
using "list_len_op". Ops that are only generated automatically don't need to
2121
be assigned to a module attribute.
2222
23-
Ops defined with custom_op are only explicitly generated in
23+
Ops defined with custom[_primitive]_op are only explicitly generated in
2424
mypyc.irbuild and won't be generated automatically. They are always
2525
assigned to a module attribute, as otherwise they won't be accessible.
2626
@@ -306,6 +306,7 @@ def custom_primitive_op(
306306
steals: StealsDescription = False,
307307
is_borrowed: bool = False,
308308
is_pure: bool = False,
309+
experimental: bool = False,
309310
capsule: str | None = None,
310311
) -> PrimitiveDescription:
311312
"""Define a primitive op that can't be automatically generated based on the AST.
@@ -328,7 +329,7 @@ def custom_primitive_op(
328329
extra_int_constants=extra_int_constants,
329330
priority=0,
330331
is_pure=is_pure,
331-
experimental=False,
332+
experimental=experimental,
332333
capsule=capsule,
333334
)
334335

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,43 @@
11
[case testLibrtStrings_experimental]
22
from librt.strings import BytesWriter
3-
from mypy_extensions import u8
3+
from mypy_extensions import u8, i64
44

5-
def bytes_builder_basics() -> bytes:
5+
def bytes_writer_basics() -> bytes:
66
b = BytesWriter()
77
x: u8 = 1
88
b.append(x)
99
b.write(b'foo')
10+
n: i64 = 2
11+
b.truncate(n)
1012
return b.getvalue()
13+
def bytes_writer_len(b: BytesWriter) -> i64:
14+
return len(b)
1115
[out]
12-
def bytes_builder_basics():
16+
def bytes_writer_basics():
1317
r0, b :: librt.strings.BytesWriter
1418
x :: u8
1519
r1 :: None
1620
r2 :: bytes
1721
r3 :: None
18-
r4 :: bytes
22+
n :: i64
23+
r4 :: None
24+
r5 :: bytes
1925
L0:
2026
r0 = LibRTStrings_BytesWriter_internal()
2127
b = r0
2228
x = 1
2329
r1 = LibRTStrings_BytesWriter_append_internal(b, x)
2430
r2 = b'foo'
2531
r3 = LibRTStrings_BytesWriter_write_internal(b, r2)
26-
r4 = LibRTStrings_BytesWriter_getvalue_internal(b)
27-
return r4
32+
n = 2
33+
r4 = LibRTStrings_BytesWriter_truncate_internal(b, n)
34+
r5 = LibRTStrings_BytesWriter_getvalue_internal(b)
35+
return r5
36+
def bytes_writer_len(b):
37+
b :: librt.strings.BytesWriter
38+
r0 :: short_int
39+
r1 :: i64
40+
L0:
41+
r0 = LibRTStrings_BytesWriter_len_internal(b)
42+
r1 = r0 >> 1
43+
return r1

0 commit comments

Comments
 (0)