@@ -891,6 +891,10 @@ JL_DLLEXPORT int jl_atomic_bool_cmpswap_bits(char *dst, const jl_value_t *expect
891891 // n.b.: this can spuriously fail if there are padding bits, the caller should deal with that
892892 int success ;
893893 switch (nb ) {
894+ case 0 : {
895+ success = 1 ;
896+ break ;
897+ }
894898 case 1 : {
895899 uint8_t y = * (uint8_t * )expected ;
896900 success = jl_atomic_cmpswap ((uint8_t * )dst , & y , * (uint8_t * )src );
@@ -941,53 +945,88 @@ JL_DLLEXPORT jl_value_t *jl_atomic_cmpswap_bits(jl_datatype_t *dt, char *dst, co
941945 jl_ptls_t ptls = jl_get_ptls_states ();
942946 jl_value_t * y = jl_gc_alloc (ptls , isptr ? nb : tuptyp -> size , isptr ? dt : tuptyp );
943947 int success ;
948+ jl_datatype_t * et = (jl_datatype_t * )jl_typeof (expected );
944949 switch (nb ) {
950+ case 0 : {
951+ success = (dt == et );
952+ break ;
953+ }
945954 case 1 : {
946955 uint8_t * y8 = (uint8_t * )y ;
947- * y8 = * (uint8_t * )expected ;
948- success = jl_atomic_cmpswap ((uint8_t * )dst , y8 , * (uint8_t * )src );
956+ if (dt == et ) {
957+ * y8 = * (uint8_t * )expected ;
958+ success = jl_atomic_cmpswap ((uint8_t * )dst , y8 , * (uint8_t * )src );
959+ }
960+ else {
961+ * y8 = jl_atomic_load ((uint8_t * )dst );
962+ success = 0 ;
963+ }
949964 break ;
950965 }
951966 case 2 : {
952967 uint16_t * y16 = (uint16_t * )y ;
953- * y16 = * (uint16_t * )expected ;
954- while (1 ) {
955- success = jl_atomic_cmpswap ((uint16_t * )dst , y16 , * (uint16_t * )src );
956- if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
957- break ;
968+ if (dt == et ) {
969+ * y16 = * (uint16_t * )expected ;
970+ while (1 ) {
971+ success = jl_atomic_cmpswap ((uint16_t * )dst , y16 , * (uint16_t * )src );
972+ if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
973+ break ;
974+ }
975+ }
976+ else {
977+ * y16 = jl_atomic_load ((uint16_t * )dst );
978+ success = 0 ;
958979 }
959980 break ;
960981 }
961982 case 4 : {
962983 uint32_t * y32 = (uint32_t * )y ;
963- * y32 = * (uint32_t * )expected ;
964- while (1 ) {
965- success = jl_atomic_cmpswap ((uint32_t * )dst , y32 , * (uint32_t * )src );
966- if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
967- break ;
984+ if (dt == et ) {
985+ * y32 = * (uint32_t * )expected ;
986+ while (1 ) {
987+ success = jl_atomic_cmpswap ((uint32_t * )dst , y32 , * (uint32_t * )src );
988+ if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
989+ break ;
990+ }
991+ }
992+ else {
993+ * y32 = jl_atomic_load ((uint32_t * )dst );
994+ success = 0 ;
968995 }
969996 break ;
970997 }
971998#if MAX_POINTERATOMIC_SIZE > 4
972999 case 8 : {
9731000 uint64_t * y64 = (uint64_t * )y ;
974- * y64 = * (uint64_t * )expected ;
975- while (1 ) {
976- success = jl_atomic_cmpswap ((uint64_t * )dst , y64 , * (uint64_t * )src );
977- if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
978- break ;
1001+ if (dt == et ) {
1002+ * y64 = * (uint64_t * )expected ;
1003+ while (1 ) {
1004+ success = jl_atomic_cmpswap ((uint64_t * )dst , y64 , * (uint64_t * )src );
1005+ if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
1006+ break ;
1007+ }
1008+ }
1009+ else {
1010+ * y64 = jl_atomic_load ((uint64_t * )dst );
1011+ success = 0 ;
9791012 }
9801013 break ;
9811014 }
9821015#endif
9831016#if MAX_POINTERATOMIC_SIZE > 8
9841017 case 16 : {
9851018 uint128_t * y128 = (uint128_t * )y ;
986- * y128 = * (uint128_t * )expected ;
987- while (1 ) {
988- success = jl_atomic_cmpswap ((uint128_t * )dst , y128 , * (uint128_t * )src );
989- if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
990- break ;
1019+ if (dt == et ) {
1020+ * y128 = * (uint128_t * )expected ;
1021+ while (1 ) {
1022+ success = jl_atomic_cmpswap ((uint128_t * )dst , y128 , * (uint128_t * )src );
1023+ if (success || !dt -> layout -> haspadding || !jl_egal__bits (y , expected , dt ))
1024+ break ;
1025+ }
1026+ }
1027+ else {
1028+ * y128 = jl_atomic_load ((uint128_t * )dst );
1029+ success = 0 ;
9911030 }
9921031 break ;
9931032 }
@@ -1569,12 +1608,9 @@ jl_value_t *modify_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_valu
15691608 if (isunion ) {
15701609 size_t fsz = jl_field_size (st , i );
15711610 uint8_t * psel = & ((uint8_t * )v )[offs + fsz - 1 ];
1572- unsigned nth = 0 ;
1573- if (!jl_find_union_component (ty , jl_typeof (r ), & nth ))
1574- assert (0 && "invalid field assignment to isbits union" );
1575- success = (* psel == nth );
1611+ success = (jl_typeof (r ) == jl_nth_union_component (ty , * psel ));
15761612 if (success ) {
1577- nth = 0 ;
1613+ unsigned nth = 0 ;
15781614 if (!jl_find_union_component (ty , yty , & nth ))
15791615 assert (0 && "invalid field assignment to isbits union" );
15801616 * psel = nth ;
@@ -1627,7 +1663,6 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16271663 JL_GC_POP ();
16281664 }
16291665 else {
1630- jl_value_t * rty = jl_typeof (r );
16311666 int hasptr ;
16321667 int isunion = jl_is_uniontype (ty );
16331668 if (isunion ) {
@@ -1637,7 +1672,10 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16371672 else {
16381673 hasptr = ((jl_datatype_t * )ty )-> layout -> npointers > 0 ;
16391674 }
1640- size_t fsz = jl_datatype_size ((jl_datatype_t * )rty ); // need to shrink-wrap the final copy
1675+ jl_value_t * rty = ty ;
1676+ size_t fsz ;
1677+ if (!isunion )
1678+ fsz = jl_datatype_size ((jl_datatype_t * )rty ); // need to shrink-wrap the final copy
16411679 int needlock = (isatomic && fsz > MAX_ATOMIC_SIZE );
16421680 if (isatomic && !needlock ) {
16431681 r = jl_atomic_cmpswap_bits ((jl_datatype_t * )rty , (char * )v + offs , r , rhs , fsz );
@@ -1646,45 +1684,37 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16461684 jl_gc_multi_wb (v , rhs ); // rhs is immutable
16471685 }
16481686 else {
1687+ jl_ptls_t ptls = jl_get_ptls_states ();
16491688 uint8_t * psel ;
1650- unsigned nth ;
1651- int success ;
16521689 if (isunion ) {
16531690 size_t fsz = jl_field_size (st , i );
16541691 psel = & ((uint8_t * )v )[offs + fsz - 1 ];
1655- nth = 0 ;
1656- success = jl_find_union_component (ty , rty , & nth );
1657- uint8_t sel = * psel ;
1658- if (success )
1659- success = nth == sel ;
1660- rty = jl_nth_union_component (ty , sel );
1661- }
1662- else {
1663- success = rty == ty ;
1692+ rty = jl_nth_union_component (rty , * psel );
16641693 }
16651694 jl_value_t * params [2 ];
16661695 params [0 ] = rty ;
16671696 params [1 ] = (jl_value_t * )jl_bool_type ;
16681697 jl_datatype_t * tuptyp = jl_apply_tuple_type_v (params , 2 );
16691698 JL_GC_PROMISE_ROOTED (tuptyp ); // (JL_ALWAYS_LEAFTYPE)
16701699 assert (!jl_field_isptr (tuptyp , 0 ));
1671- jl_ptls_t ptls = jl_get_ptls_states ();
16721700 r = jl_gc_alloc (ptls , tuptyp -> size , (jl_value_t * )tuptyp );
1701+ int success = (rty == jl_typeof (expected ));
16731702 if (needlock )
16741703 jl_lock_value (v );
1704+ size_t fsz = jl_datatype_size ((jl_datatype_t * )rty ); // need to shrink-wrap the final copy
1705+ memcpy ((char * )r , (char * )v + offs , fsz );
16751706 if (success ) {
1676- memcpy ((char * )r , (char * )v + offs , fsz );
16771707 if (((jl_datatype_t * )rty )-> layout -> haspadding )
16781708 success = jl_egal__bits (r , expected , (jl_datatype_t * )rty );
16791709 else
16801710 success = memcmp ((char * )r , (char * )expected , fsz ) == 0 ;
16811711 }
16821712 * ((uint8_t * )r + fsz ) = success ? 1 : 0 ;
16831713 if (success ) {
1684- rty = jl_typeof (rhs );
1685- fsz = jl_datatype_size ((jl_datatype_t * )rty ); // need to shrink-wrap the final copy
1714+ jl_value_t * rty = jl_typeof (rhs );
1715+ size_t fsz = jl_datatype_size ((jl_datatype_t * )rty ); // need to shrink-wrap the final copy
16861716 if (isunion ) {
1687- nth = 0 ;
1717+ unsigned nth = 0 ;
16881718 if (!jl_find_union_component (ty , rty , & nth ))
16891719 assert (0 && "invalid field assignment to isbits union" );
16901720 * psel = nth ;
@@ -1696,7 +1726,7 @@ jl_value_t *cmpswap_nth_field(jl_datatype_t *st, jl_value_t *v, size_t i, jl_val
16961726 if (needlock )
16971727 jl_unlock_value (v );
16981728 }
1699- r = undefref_check ((jl_datatype_t * )ty , r );
1729+ r = undefref_check ((jl_datatype_t * )rty , r );
17001730 if (__unlikely (r == NULL ))
17011731 jl_throw (jl_undefref_exception );
17021732 }
0 commit comments