Skip to content

Commit 3b2543b

Browse files
authored
Support Literal[True] and Literal[False] types (#1004)
Closes #859
1 parent 754e636 commit 3b2543b

6 files changed

Lines changed: 141 additions & 12 deletions

File tree

docs/supported-types.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,7 @@ purpose, but with a `typing.Literal` the decoded values are literal `int` or
11431143
A literal can be composed of any of the following objects:
11441144

11451145
- `None`
1146+
- `bool` values (`True` and `False`)
11461147
- `int` values
11471148
- `str` values
11481149
- Nested `typing.Literal` types
@@ -1170,6 +1171,12 @@ values, or doesn't match any of their component types.
11701171
File "<stdin>", line 1, in <module>
11711172
msgspec.ValidationError: Expected `int`, got `str`
11721173
1174+
>>> msgspec.json.decode(b'true', type=Literal[True])
1175+
True
1176+
1177+
>>> msgspec.json.decode(b'false', type=Literal[True, False])
1178+
False
1179+
11731180
``NewType``
11741181
-----------
11751182

src/msgspec/_core.c

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,8 @@ AssocList_Sort(AssocList* list) {
28232823
#define MS_TYPE_TYPEDDICT (1ull << 33)
28242824
#define MS_TYPE_DATACLASS (1ull << 34)
28252825
#define MS_TYPE_NAMEDTUPLE (1ull << 35)
2826+
#define MS_TYPE_BOOLLITERAL_TRUE (1ull << 36)
2827+
#define MS_TYPE_BOOLLITERAL_FALSE (1ull << 37)
28262828
/* Constraints */
28272829
#define MS_CONSTR_INT_MIN (1ull << 42)
28282830
#define MS_CONSTR_INT_MAX (1ull << 43)
@@ -2946,6 +2948,8 @@ typedef struct {
29462948
PyObject *int_lookup;
29472949
PyObject *str_lookup;
29482950
bool literal_none;
2951+
bool literal_bool_true;
2952+
bool literal_bool_false;
29492953
} LiteralInfo;
29502954

29512955
typedef struct {
@@ -3452,7 +3456,7 @@ typenode_simple_repr(TypeNode *self) {
34523456
if (self->types & (MS_TYPE_ANY | MS_TYPE_CUSTOM | MS_TYPE_CUSTOM_GENERIC) || self->types == 0) {
34533457
return PyUnicode_FromString("any");
34543458
}
3455-
if (self->types & MS_TYPE_BOOL) {
3459+
if (self->types & (MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
34563460
if (!strbuilder_extend_literal(&builder, "bool")) return NULL;
34573461
}
34583462
if (self->types & (MS_TYPE_INT | MS_TYPE_INTENUM | MS_TYPE_INTLITERAL)) {
@@ -3546,6 +3550,8 @@ typedef struct {
35463550
PyObject *literal_str_values;
35473551
PyObject *literal_str_lookup;
35483552
bool literal_none;
3553+
bool literal_bool_true;
3554+
bool literal_bool_false;
35493555
/* Constraints */
35503556
int64_t c_int_min;
35513557
int64_t c_int_max;
@@ -4436,6 +4442,14 @@ typenode_collect_literal(TypeNodeCollectState *state, PyObject *literal) {
44364442
if (obj == Py_None || obj == NONE_TYPE) {
44374443
state->literal_none = true;
44384444
}
4445+
else if (type == &PyBool_Type) {
4446+
if (obj == Py_True) {
4447+
state->literal_bool_true = true;
4448+
}
4449+
else {
4450+
state->literal_bool_false = true;
4451+
}
4452+
}
44394453
else if (type == &PyLong_Type) {
44404454
if (state->literal_int_values == NULL) {
44414455
state->literal_int_values = PySet_New(NULL);
@@ -4472,7 +4486,7 @@ typenode_collect_literal(TypeNodeCollectState *state, PyObject *literal) {
44724486
invalid:
44734487
PyErr_Format(
44744488
PyExc_TypeError,
4475-
"Literal may only contain None/integers/strings - %R is not supported",
4489+
"Literal may only contain None/booleans/integers/strings - %R is not supported",
44764490
literal
44774491
);
44784492

@@ -4510,6 +4524,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
45104524
if (info->literal_none) {
45114525
state->types |= MS_TYPE_NONE;
45124526
}
4527+
if (info->literal_bool_true) {
4528+
state->types |= MS_TYPE_BOOLLITERAL_TRUE;
4529+
}
4530+
if (info->literal_bool_false) {
4531+
state->types |= MS_TYPE_BOOLLITERAL_FALSE;
4532+
}
45134533
Py_DECREF(cached);
45144534
return 0;
45154535
}
@@ -4538,6 +4558,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
45384558
if (state->literal_none) {
45394559
state->types |= MS_TYPE_NONE;
45404560
}
4561+
if (state->literal_bool_true) {
4562+
state->types |= MS_TYPE_BOOLLITERAL_TRUE;
4563+
}
4564+
if (state->literal_bool_false) {
4565+
state->types |= MS_TYPE_BOOLLITERAL_FALSE;
4566+
}
45414567

45424568
if (n == 1) {
45434569
/* A single `Literal` object, cache the lookups on it */
@@ -4548,6 +4574,8 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) {
45484574
Py_XINCREF(state->literal_str_lookup);
45494575
info->str_lookup = state->literal_str_lookup;
45504576
info->literal_none = state->literal_none;
4577+
info->literal_bool_true = state->literal_bool_true;
4578+
info->literal_bool_false = state->literal_bool_false;
45514579
PyObject_GC_Track(info);
45524580
PyObject *literal = PyList_GET_ITEM(state->literals, 0);
45534581
int status = PyObject_SetAttr(
@@ -15353,6 +15381,18 @@ mpack_decode_bool(DecoderState *self, PyObject *val, TypeNode *type, PathNode *p
1535315381
Py_INCREF(val);
1535415382
return val;
1535515383
}
15384+
if (val == Py_True && (type->types & MS_TYPE_BOOLLITERAL_TRUE)) {
15385+
Py_INCREF(Py_True);
15386+
return Py_True;
15387+
}
15388+
if (val == Py_False && (type->types & MS_TYPE_BOOLLITERAL_FALSE)) {
15389+
Py_INCREF(Py_False);
15390+
return Py_False;
15391+
}
15392+
if (type->types & (MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
15393+
ms_raise_validation_error(path, "Invalid enum value %R%U", val);
15394+
return NULL;
15395+
}
1535615396
return ms_validation_error("bool", type, path);
1535715397
}
1535815398

@@ -16978,10 +17018,14 @@ json_decode_true(JSONDecoderState *self, TypeNode *type, PathNode *path) {
1697817018
if (MS_UNLIKELY(c1 != 'r' || c2 != 'u' || c3 != 'e')) {
1697917019
return json_err_invalid(self, "invalid character");
1698017020
}
16981-
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL)) {
17021+
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_TRUE)) {
1698217022
Py_INCREF(Py_True);
1698317023
return Py_True;
1698417024
}
17025+
if (type->types & MS_TYPE_BOOLLITERAL_FALSE) {
17026+
ms_raise_validation_error(path, "Invalid enum value %R%U", Py_True);
17027+
return NULL;
17028+
}
1698517029
return ms_validation_error("bool", type, path);
1698617030
}
1698717031

@@ -16999,10 +17043,14 @@ json_decode_false(JSONDecoderState *self, TypeNode *type, PathNode *path) {
1699917043
if (MS_UNLIKELY(c1 != 'a' || c2 != 'l' || c3 != 's' || c4 != 'e')) {
1700017044
return json_err_invalid(self, "invalid character");
1700117045
}
17002-
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL)) {
17046+
if (type->types & (MS_TYPE_ANY | MS_TYPE_BOOL | MS_TYPE_BOOLLITERAL_FALSE)) {
1700317047
Py_INCREF(Py_False);
1700417048
return Py_False;
1700517049
}
17050+
if (type->types & MS_TYPE_BOOLLITERAL_TRUE) {
17051+
ms_raise_validation_error(path, "Invalid enum value %R%U", Py_False);
17052+
return NULL;
17053+
}
1700617054
return ms_validation_error("bool", type, path);
1700717055
}
1700817056

@@ -20651,6 +20699,18 @@ convert_bool(
2065120699
Py_INCREF(obj);
2065220700
return obj;
2065320701
}
20702+
if (obj == Py_True && (type->types & MS_TYPE_BOOLLITERAL_TRUE)) {
20703+
Py_INCREF(Py_True);
20704+
return Py_True;
20705+
}
20706+
if (obj == Py_False && (type->types & MS_TYPE_BOOLLITERAL_FALSE)) {
20707+
Py_INCREF(Py_False);
20708+
return Py_False;
20709+
}
20710+
if (type->types & (MS_TYPE_BOOLLITERAL_TRUE | MS_TYPE_BOOLLITERAL_FALSE)) {
20711+
ms_raise_validation_error(path, "Invalid enum value %R%U", obj);
20712+
return NULL;
20713+
}
2065420714
return ms_validation_error("bool", type, path);
2065520715
}
2065620716

src/msgspec/inspect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ class LiteralType(Type):
315315
Parameters
316316
----------
317317
values: tuple
318-
A tuple of possible values for this literal instance. Only `str` or
319-
`int` literals are supported.
318+
A tuple of possible values for this literal instance. Only `bool`,
319+
`str`, or `int` literals are supported.
320320
"""
321321

322-
values: Union[Tuple[str, ...], Tuple[int, ...]]
322+
values: Union[Tuple[bool, ...], Tuple[str, ...], Tuple[int, ...]]
323323

324324

325325
class CustomType(Type):

tests/unit/test_common.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -869,13 +869,8 @@ def test_int_literal_values_out_of_range(self, values):
869869
@pytest.mark.parametrize(
870870
"typ",
871871
[
872-
Literal[1, False],
873872
Literal["ok", b"bad"],
874873
Literal[1, object()],
875-
Union[Literal[1, 2], Literal[3, False]],
876-
Union[Literal["one", "two"], Literal[3, False]],
877-
Literal[Literal[1, 2], Literal[3, False]],
878-
Literal[Literal["one", "two"], Literal[3, False]],
879874
Literal[1, 2, List[int]],
880875
Literal[1, 2, List],
881876
],
@@ -952,6 +947,34 @@ def test_nested_literals(self):
952947
with pytest.raises(ValidationError, match="Invalid enum value 'carrot'"):
953948
dec.decode(msgspec.msgpack.encode("carrot"))
954949

950+
@pytest.mark.parametrize(
951+
"typ, good, bad",
952+
[
953+
(Literal[True], [True], [False]),
954+
(Literal[False], [False], [True]),
955+
(Literal[True, False], [True, False], []),
956+
(Literal[1, False], [1, False], [True]),
957+
(Literal[True, "yes", None], [True, "yes", None], [False]),
958+
],
959+
)
960+
def test_literal_bool(self, typ, good, bad):
961+
dec = msgspec.msgpack.Decoder(typ)
962+
for val in good:
963+
assert dec.decode(msgspec.msgpack.encode(val)) == val
964+
for val in bad:
965+
with pytest.raises(ValidationError):
966+
dec.decode(msgspec.msgpack.encode(val))
967+
968+
def test_literal_bool_error_message(self):
969+
dec = msgspec.msgpack.Decoder(Literal[True])
970+
with pytest.raises(ValidationError, match="Invalid enum value False"):
971+
dec.decode(msgspec.msgpack.encode(False))
972+
973+
def test_mix_bool_and_bool_literal(self):
974+
dec = msgspec.msgpack.Decoder(Union[Literal[True], bool])
975+
assert dec.decode(msgspec.msgpack.encode(True)) is True
976+
assert dec.decode(msgspec.msgpack.encode(False)) is False
977+
955978
def test_mix_int_and_int_literal(self):
956979
dec = msgspec.msgpack.Decoder(Union[Literal[-1, 1], int])
957980
for x in [-1, 1, 10]:

tests/unit/test_convert.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,18 @@ def test_int_literal(self):
921921
with pytest.raises(ValidationError, match="Expected `int`, got `str`"):
922922
convert("A", typ)
923923

924+
def test_bool_literal(self):
925+
assert convert(True, Literal[True]) is True
926+
assert convert(False, Literal[False]) is False
927+
assert convert(True, Literal[True, False]) is True
928+
assert convert(False, Literal[True, False]) is False
929+
with pytest.raises(ValidationError, match="Invalid enum value False"):
930+
convert(False, Literal[True])
931+
with pytest.raises(ValidationError, match="Invalid enum value True"):
932+
convert(True, Literal[False])
933+
with pytest.raises(ValidationError, match="Expected `bool`, got `str`"):
934+
convert("yes", Literal[True])
935+
924936

925937
class TestSequences:
926938
def test_any_sequence(self):

tests/unit/test_json.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,10 @@ class TestLiteral:
11991199
(1, 2, "three", "four"),
12001200
(1, None),
12011201
("one", None),
1202+
(True,),
1203+
(False,),
1204+
(True, False),
1205+
(True, 1, "yes"),
12021206
],
12031207
)
12041208
def test_literal(self, values):
@@ -1229,6 +1233,29 @@ def test_str_literal_errors(self):
12291233
with pytest.raises(msgspec.ValidationError, match="Invalid enum value 'bad'"):
12301234
dec.decode(b'"bad"')
12311235

1236+
def test_bool_literal_true_only(self):
1237+
dec = msgspec.json.Decoder(Literal[True])
1238+
assert dec.decode(b"true") is True
1239+
with pytest.raises(msgspec.ValidationError, match="Invalid enum value False"):
1240+
dec.decode(b"false")
1241+
1242+
def test_bool_literal_false_only(self):
1243+
dec = msgspec.json.Decoder(Literal[False])
1244+
assert dec.decode(b"false") is False
1245+
with pytest.raises(msgspec.ValidationError, match="Invalid enum value True"):
1246+
dec.decode(b"true")
1247+
1248+
def test_bool_literal_errors(self):
1249+
dec = msgspec.json.Decoder(Literal[True])
1250+
with pytest.raises(msgspec.ValidationError, match="Expected `bool`, got `int`"):
1251+
dec.decode(b"42")
1252+
with pytest.raises(msgspec.ValidationError, match="Expected `bool`, got `str`"):
1253+
dec.decode(b'"hello"')
1254+
with pytest.raises(
1255+
msgspec.ValidationError, match="Expected `bool`, got `null`"
1256+
):
1257+
dec.decode(b"null")
1258+
12321259

12331260
class TestFloat:
12341261
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)