1- //! Global shell lock
1+ //! Global shell lock.
2+
23use std:: {
34 cell:: Cell ,
45 mem:: MaybeUninit ,
@@ -7,29 +8,77 @@ use std::{
78 sync:: { RwLock , RwLockReadGuard , RwLockWriteGuard } ,
89} ;
910
11+ /// If, on the same thread, there are multiple calls to [`read`] or [`write`],
12+ /// then the `Guard`s returned should be dropped in the reverse order that they
13+ /// were acquired.
14+ ///
15+ /// If this is violated, e.g. in
16+ ///
17+ /// ```ignore
18+ /// let w1 = write();
19+ /// let _w2 = write();
20+ /// drop(w1);
21+ /// // try to use a global resource with write privileges
22+ /// ```
23+ ///
24+ /// then things won't turn out well, because `_w2` doesn't actually contain a
25+ /// lock guard.
26+ #[ derive( Debug ) ]
27+ pub ( crate ) struct Guard ( Option < Repr > ) ;
28+
1029#[ derive( Debug ) ]
11- pub ( crate ) struct Guard {
12- r_guard : Option < RwLockReadGuard < ' static , ( ) > > ,
13- w_guard : Option < RwLockWriteGuard < ' static , ( ) > > ,
30+ enum Repr {
31+ Read ( RwLockReadGuard < ' static , ( ) > ) ,
32+ Write ( RwLockWriteGuard < ' static , ( ) > ) ,
1433}
1534
35+ /// Returns a [`Guard`] for write access to global resources.
1636pub ( crate ) fn write ( ) -> Guard {
17- if LOCKED . with ( |it| it. get ( ) ) {
18- return Guard { r_guard : None , w_guard : None } ;
37+ match CACHE . with ( Cell :: get) {
38+ Cache :: Write => {
39+ // this thread (and only this thread) can already write. don't try to
40+ // acquire another write guard.
41+ Guard ( None )
42+ }
43+ Cache :: Read ( readers) => {
44+ assert_eq ! (
45+ readers, 0 ,
46+ "calling write() with an active read guard on the same thread would deadlock"
47+ ) ;
48+ let w_guard = static_rw_lock ( ) . write ( ) . unwrap_or_else ( |err| err. into_inner ( ) ) ;
49+ // note that we have a writer.
50+ CACHE . with ( |it| it. set ( Cache :: Write ) ) ;
51+ Guard ( Some ( Repr :: Write ( w_guard) ) )
52+ }
1953 }
20-
21- let w_guard = static_rw_lock ( ) . write ( ) . unwrap_or_else ( |err| err. into_inner ( ) ) ;
22- LOCKED . with ( |it| it. set ( true ) ) ;
23- Guard { w_guard : Some ( w_guard) , r_guard : None }
2454}
2555
56+ /// Returns a [`Guard`] for read access to global resources.
2657pub ( crate ) fn read ( ) -> Guard {
27- if LOCKED . with ( |it| it. get ( ) ) {
28- return Guard { r_guard : None , w_guard : None } ;
58+ match CACHE . with ( Cell :: get) {
59+ Cache :: Write => {
60+ // this thread (and only this thread) can already write. it's safe
61+ // to allow this thread to read as well, because we won't have
62+ // concurrent reads and writes, because we're only working on this
63+ // thread.
64+ Guard ( None )
65+ }
66+ Cache :: Read ( readers) => {
67+ if readers == 0 {
68+ // this thread has no readers or writers. try to acquire the
69+ // lock for reading.
70+ let r_guard = static_rw_lock ( ) . read ( ) . unwrap_or_else ( |err| err. into_inner ( ) ) ;
71+ // note that we now have 1 reader.
72+ CACHE . with ( |it| it. set ( Cache :: Read ( 1 ) ) ) ;
73+ Guard ( Some ( Repr :: Read ( r_guard) ) )
74+ } else {
75+ // this thread can already read. don't try to acquire another
76+ // read guard. also, note that we have another reader.
77+ CACHE . with ( |it| it. set ( Cache :: Read ( readers + 1 ) ) ) ;
78+ Guard ( None )
79+ }
80+ }
2981 }
30-
31- let r_guard = static_rw_lock ( ) . read ( ) . unwrap_or_else ( |err| err. into_inner ( ) ) ;
32- Guard { w_guard : None , r_guard : Some ( r_guard) }
3382}
3483
3584fn static_rw_lock ( ) -> & ' static RwLock < ( ) > {
@@ -41,15 +90,69 @@ fn static_rw_lock() -> &'static RwLock<()> {
4190 }
4291}
4392
93+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
94+ enum Cache {
95+ Read ( usize ) ,
96+ Write ,
97+ }
98+
4499thread_local ! {
45- static LOCKED : Cell <bool > = Cell :: new( false ) ;
100+ static CACHE : Cell <Cache > = Cell :: new( Cache :: Read ( 0 ) ) ;
46101}
47102
48103impl Drop for Guard {
49104 fn drop ( & mut self ) {
50- if self . w_guard . is_some ( ) {
51- LOCKED . with ( |it| it. set ( false ) )
105+ match self . 0 {
106+ Some ( Repr :: Read ( _) ) => CACHE . with ( |it| {
107+ let n = match it. get ( ) {
108+ Cache :: Read ( n) => n,
109+ Cache :: Write => unreachable ! ( "had both a reader and a writer" ) ,
110+ } ;
111+ it. set ( Cache :: Read ( n - 1 ) ) ;
112+ } ) ,
113+ Some ( Repr :: Write ( _) ) => CACHE . with ( |it| {
114+ assert_eq ! ( it. get( ) , Cache :: Write ) ;
115+ it. set ( Cache :: Read ( 0 ) ) ;
116+ } ) ,
117+ None => { }
52118 }
53- let _ = self . r_guard ;
54119 }
55120}
121+
122+ #[ test]
123+ fn read_write_read ( ) {
124+ eprintln ! ( "get r1" ) ;
125+ let r1 = read ( ) ;
126+ eprintln ! ( "got r1" ) ;
127+ let h = std:: thread:: spawn ( || {
128+ eprintln ! ( "get w1" ) ;
129+ let w1 = write ( ) ;
130+ eprintln ! ( "got w1" ) ;
131+ drop ( w1) ;
132+ eprintln ! ( "gave w1" ) ;
133+ } ) ;
134+ std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 300 ) ) ;
135+ eprintln ! ( "get r2" ) ;
136+ let r2 = read ( ) ;
137+ eprintln ! ( "got r2" ) ;
138+ drop ( r1) ;
139+ eprintln ! ( "gave r1" ) ;
140+ drop ( r2) ;
141+ eprintln ! ( "gave r2" ) ;
142+ h. join ( ) . unwrap ( ) ;
143+ }
144+
145+ #[ test]
146+ fn write_read ( ) {
147+ let _w = write ( ) ;
148+ let _r = read ( ) ;
149+ }
150+
151+ #[ test]
152+ #[ should_panic(
153+ expected = "calling write() with an active read guard on the same thread would deadlock"
154+ ) ]
155+ fn read_write ( ) {
156+ let _r = read ( ) ;
157+ let _w = write ( ) ;
158+ }
0 commit comments