Skip to content

Commit f56ddb3

Browse files
committed
add tests and fixes
1 parent 298c8f5 commit f56ddb3

File tree

6 files changed

+251
-85
lines changed

6 files changed

+251
-85
lines changed

doc/src/base/multi-threading.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,25 @@ See also [Synchronization](@ref lib-task-sync).
1919

2020
## Atomic operations
2121

22+
```@docs
23+
Base.@atomic
24+
```
25+
26+
!!! note
27+
28+
The following APIs are fairly primitive, and will likely be exposed through an `unsafe_*`-like wrapper.
29+
30+
```
31+
Core.Intrinsics.atomic_pointerref(pointer::Ptr{T}, order::Symbol) --> T
32+
Core.Intrinsics.atomic_pointerset(pointer::Ptr{T}, new::T, order::Symbol) --> pointer
33+
Core.Intrinsics.atomic_pointerswap(pointer::Ptr{T}, new::T, order::Symbol) --> old
34+
Core.Intrinsics.atomic_pointermodify(pointer::Ptr{T}, function::(old::T,arg::S)->T, arg::S, order::Symbol) --> old
35+
Core.Intrinsics.atomic_pointercmpswap(pointer::Ptr{T}, expected::Any, new::T, success_order::Symbol, failure_order::Symbol) --> (old, cmp)
36+
```
37+
2238
!!! warning
2339

24-
The API for atomic operations has not yet been finalized and is likely to change.
40+
The following APIs are deprecated, though support for them is likely to remain for several releases.
2541

2642
```@docs
2743
Base.Threads.Atomic

src/builtins.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,8 +969,11 @@ JL_CALLABLE(jl_f_cmpswapfield)
969969
if (isatomic == (success_order == jl_memory_order_notatomic))
970970
jl_atomic_error(isatomic ? "cmpswapfield!: atomic field cannot be written non-atomically"
971971
: "cmpswapfield!: non-atomic field cannot be written atomically");
972+
if (isatomic == (failure_order == jl_memory_order_notatomic))
973+
jl_atomic_error(isatomic ? "cmpswapfield!: atomic field cannot be accessed non-atomically"
974+
: "cmpswapfield!: non-atomic field cannot be accessed atomically");
972975
if (failure_order > success_order)
973-
jl_atomic_error("cmpswapfield!: invalid atomic ordering");
976+
jl_atomic_error("invalid atomic ordering");
974977
v = cmpswap_nth_field(st, v, idx, args[2], args[3], isatomic); // always seq_cst, if isatomic needed at all
975978
return v;
976979
}

src/datatype.c

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,10 @@ JL_DLLEXPORT int jl_atomic_bool_cmpswap_bits(char *dst, const jl_value_t *expect
891891
// n.b.: this can spuriously fail if there are padding bits, the caller should deal with that
892892
int success;
893893
switch (nb) {
894+
case 0: {
895+
success = 1;
896+
break;
897+
}
894898
case 1: {
895899
uint8_t y = *(uint8_t*)expected;
896900
success = jl_atomic_cmpswap((uint8_t*)dst, &y, *(uint8_t*)src);
@@ -941,53 +945,88 @@ JL_DLLEXPORT jl_value_t *jl_atomic_cmpswap_bits(jl_datatype_t *dt, char *dst, co
941945
jl_ptls_t ptls = jl_get_ptls_states();
942946
jl_value_t *y = jl_gc_alloc(ptls, isptr ? nb : tuptyp->size, isptr ? dt : tuptyp);
943947
int success;
948+
jl_datatype_t *et = (jl_datatype_t*)jl_typeof(expected);
944949
switch (nb) {
950+
case 0: {
951+
success = (dt == et);
952+
break;
953+
}
945954
case 1: {
946955
uint8_t *y8 = (uint8_t*)y;
947-
*y8 = *(uint8_t*)expected;
948-
success = jl_atomic_cmpswap((uint8_t*)dst, y8, *(uint8_t*)src);
956+
if (dt == et) {
957+
*y8 = *(uint8_t*)expected;
958+
success = jl_atomic_cmpswap((uint8_t*)dst, y8, *(uint8_t*)src);
959+
}
960+
else {
961+
*y8 = jl_atomic_load((uint8_t*)dst);
962+
success = 0;
963+
}
949964
break;
950965
}
951966
case 2: {
952967
uint16_t *y16 = (uint16_t*)y;
953-
*y16 = *(uint16_t*)expected;
954-
while (1) {
955-
success = jl_atomic_cmpswap((uint16_t*)dst, y16, *(uint16_t*)src);
956-
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
957-
break;
968+
if (dt == et) {
969+
*y16 = *(uint16_t*)expected;
970+
while (1) {
971+
success = jl_atomic_cmpswap((uint16_t*)dst, y16, *(uint16_t*)src);
972+
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
973+
break;
974+
}
975+
}
976+
else {
977+
*y16 = jl_atomic_load((uint16_t*)dst);
978+
success = 0;
958979
}
959980
break;
960981
}
961982
case 4: {
962983
uint32_t *y32 = (uint32_t*)y;
963-
*y32 = *(uint32_t*)expected;
964-
while (1) {
965-
success = jl_atomic_cmpswap((uint32_t*)dst, y32, *(uint32_t*)src);
966-
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
967-
break;
984+
if (dt == et) {
985+
*y32 = *(uint32_t*)expected;
986+
while (1) {
987+
success = jl_atomic_cmpswap((uint32_t*)dst, y32, *(uint32_t*)src);
988+
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
989+
break;
990+
}
991+
}
992+
else {
993+
*y32 = jl_atomic_load((uint32_t*)dst);
994+
success = 0;
968995
}
969996
break;
970997
}
971998
#if MAX_POINTERATOMIC_SIZE > 4
972999
case 8: {
9731000
uint64_t *y64 = (uint64_t*)y;
974-
*y64 = *(uint64_t*)expected;
975-
while (1) {
976-
success = jl_atomic_cmpswap((uint64_t*)dst, y64, *(uint64_t*)src);
977-
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
978-
break;
1001+
if (dt == et) {
1002+
*y64 = *(uint64_t*)expected;
1003+
while (1) {
1004+
success = jl_atomic_cmpswap((uint64_t*)dst, y64, *(uint64_t*)src);
1005+
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
1006+
break;
1007+
}
1008+
}
1009+
else {
1010+
*y64 = jl_atomic_load((uint64_t*)dst);
1011+
success = 0;
9791012
}
9801013
break;
9811014
}
9821015
#endif
9831016
#if MAX_POINTERATOMIC_SIZE > 8
9841017
case 16: {
9851018
uint128_t *y128 = (uint128_t*)y;
986-
*y128 = *(uint128_t*)expected;
987-
while (1) {
988-
success = jl_atomic_cmpswap((uint128_t*)dst, y128, *(uint128_t*)src);
989-
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
990-
break;
1019+
if (dt == et) {
1020+
*y128 = *(uint128_t*)expected;
1021+
while (1) {
1022+
success = jl_atomic_cmpswap((uint128_t*)dst, y128, *(uint128_t*)src);
1023+
if (success || !dt->layout->haspadding || !jl_egal__bits(y, expected, dt))
1024+
break;
1025+
}
1026+
}
1027+
else {
1028+
*y128 = jl_atomic_load((uint128_t*)dst);
1029+
success = 0;
9911030
}
9921031
break;
9931032
}
@@ -1569,12 +1608,9 @@ jl_value_t *modify_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_valu
15691608
if (isunion) {
15701609
size_t fsz = jl_field_size(st, i);
15711610
uint8_t *psel = &((uint8_t*)v)[offs + fsz - 1];
1572-
unsigned nth = 0;
1573-
if (!jl_find_union_component(ty, jl_typeof(r), &nth))
1574-
assert(0 && "invalid field assignment to isbits union");
1575-
success = (*psel == nth);
1611+
success = (jl_typeof(r) == jl_nth_union_component(ty, *psel));
15761612
if (success) {
1577-
nth = 0;
1613+
unsigned nth = 0;
15781614
if (!jl_find_union_component(ty, yty, &nth))
15791615
assert(0 && "invalid field assignment to isbits union");
15801616
*psel = nth;
@@ -1627,7 +1663,6 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16271663
JL_GC_POP();
16281664
}
16291665
else {
1630-
jl_value_t *rty = jl_typeof(r);
16311666
int hasptr;
16321667
int isunion = jl_is_uniontype(ty);
16331668
if (isunion) {
@@ -1637,7 +1672,10 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16371672
else {
16381673
hasptr = ((jl_datatype_t*)ty)->layout->npointers > 0;
16391674
}
1640-
size_t fsz = jl_datatype_size((jl_datatype_t*)rty); // need to shrink-wrap the final copy
1675+
jl_value_t *rty = ty;
1676+
size_t fsz;
1677+
if (!isunion)
1678+
fsz = jl_datatype_size((jl_datatype_t*)rty); // need to shrink-wrap the final copy
16411679
int needlock = (isatomic && fsz > MAX_ATOMIC_SIZE);
16421680
if (isatomic && !needlock) {
16431681
r = jl_atomic_cmpswap_bits((jl_datatype_t*)rty, (char*)v + offs, r, rhs, fsz);
@@ -1646,45 +1684,37 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16461684
jl_gc_multi_wb(v, rhs); // rhs is immutable
16471685
}
16481686
else {
1687+
jl_ptls_t ptls = jl_get_ptls_states();
16491688
uint8_t *psel;
1650-
unsigned nth;
1651-
int success;
16521689
if (isunion) {
16531690
size_t fsz = jl_field_size(st, i);
16541691
psel = &((uint8_t*)v)[offs + fsz - 1];
1655-
nth = 0;
1656-
success = jl_find_union_component(ty, rty, &nth);
1657-
uint8_t sel = *psel;
1658-
if (success)
1659-
success = nth == sel;
1660-
rty = jl_nth_union_component(ty, sel);
1661-
}
1662-
else {
1663-
success = rty == ty;
1692+
rty = jl_nth_union_component(rty, *psel);
16641693
}
16651694
jl_value_t *params[2];
16661695
params[0] = rty;
16671696
params[1] = (jl_value_t*)jl_bool_type;
16681697
jl_datatype_t *tuptyp = jl_apply_tuple_type_v(params, 2);
16691698
JL_GC_PROMISE_ROOTED(tuptyp); // (JL_ALWAYS_LEAFTYPE)
16701699
assert(!jl_field_isptr(tuptyp, 0));
1671-
jl_ptls_t ptls = jl_get_ptls_states();
16721700
r = jl_gc_alloc(ptls, tuptyp->size, (jl_value_t*)tuptyp);
1701+
int success = (rty == jl_typeof(expected));
16731702
if (needlock)
16741703
jl_lock_value(v);
1704+
size_t fsz = jl_datatype_size((jl_datatype_t*)rty); // need to shrink-wrap the final copy
1705+
memcpy((char*)r, (char*)v + offs, fsz);
16751706
if (success) {
1676-
memcpy((char*)r, (char*)v + offs, fsz);
16771707
if (((jl_datatype_t*)rty)->layout->haspadding)
16781708
success = jl_egal__bits(r, expected, (jl_datatype_t*)rty);
16791709
else
16801710
success = memcmp((char*)r, (char*)expected, fsz) == 0;
16811711
}
16821712
*((uint8_t*)r + fsz) = success ? 1 : 0;
16831713
if (success) {
1684-
rty = jl_typeof(rhs);
1685-
fsz = jl_datatype_size((jl_datatype_t*)rty); // need to shrink-wrap the final copy
1714+
jl_value_t *rty = jl_typeof(rhs);
1715+
size_t fsz = jl_datatype_size((jl_datatype_t*)rty); // need to shrink-wrap the final copy
16861716
if (isunion) {
1687-
nth = 0;
1717+
unsigned nth = 0;
16881718
if (!jl_find_union_component(ty, rty, &nth))
16891719
assert(0 && "invalid field assignment to isbits union");
16901720
*psel = nth;
@@ -1696,7 +1726,7 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16961726
if (needlock)
16971727
jl_unlock_value(v);
16981728
}
1699-
r = undefref_check((jl_datatype_t*)ty, r);
1729+
r = undefref_check((jl_datatype_t*)rty, r);
17001730
if (__unlikely(r == NULL))
17011731
jl_throw(jl_undefref_exception);
17021732
}

src/runtime_intrinsics.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,6 @@ JL_DLLEXPORT jl_value_t *jl_atomic_pointercmpswap(jl_value_t *p, jl_value_t *exp
209209
jl_error("pointercmpswap: invalid pointer");
210210
if (jl_typeof(x) != ety)
211211
jl_type_error("pointercmpswap", ety, x);
212-
if (jl_typeof(expected) != ety)
213-
jl_type_error("pointercmpswap", ety, expected);
214212
size_t nb = jl_datatype_size(ety);
215213
if ((nb & (nb - 1)) != 0 || nb > MAX_POINTERATOMIC_SIZE)
216214
jl_error("pointercmpswap: invalid atomic operation");

0 commit comments

Comments
 (0)