@@ -10,6 +10,7 @@ import (
1010 "io"
1111 "net"
1212 "runtime"
13+ "strings"
1314 "sync"
1415 "time"
1516
@@ -37,13 +38,12 @@ func init() {
3738
3839type ClientInstance struct {
3940 sync.RWMutex
40- nfsEKey * mlkem.EncapsulationKey768
41- nfsEKeySha256 [32 ]byte
42- xor uint32
43- minutes time.Duration
44- expire time.Time
45- baseKey []byte
46- ticket []byte
41+ nfsEKey * mlkem.EncapsulationKey768
42+ xorKey []byte
43+ minutes time.Duration
44+ expire time.Time
45+ baseKey []byte
46+ ticket []byte
4747}
4848
4949type ClientConn struct {
@@ -60,10 +60,17 @@ type ClientConn struct {
6060}
6161
6262func (i * ClientInstance ) Init (nfsEKeyBytes []byte , xor uint32 , minutes time.Duration ) (err error ) {
63+ if i .nfsEKey != nil {
64+ err = errors .New ("already initialized" )
65+ return
66+ }
6367 i .nfsEKey , err = mlkem .NewEncapsulationKey768 (nfsEKeyBytes )
68+ if err != nil {
69+ return
70+ }
6471 if xor > 0 {
65- i . nfsEKeySha256 = sha256 .Sum256 (nfsEKeyBytes )
66- i .xor = xor
72+ xorKey : = sha256 .Sum256 (nfsEKeyBytes )
73+ i .xorKey = xorKey [:]
6774 }
6875 i .minutes = minutes
6976 return
@@ -73,8 +80,8 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
7380 if i .nfsEKey == nil {
7481 return nil , errors .New ("uninitialized" )
7582 }
76- if i .xor > 0 {
77- conn = NewXorConn (conn , i .nfsEKeySha256 [:] )
83+ if i .xorKey != nil {
84+ conn = NewXorConn (conn , i .xorKey )
7885 }
7986 c := & ClientConn {Conn : conn }
8087
@@ -110,14 +117,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
110117 }
111118 // client can send more padding / NFS AEAD messages if needed
112119
113- _ , t , l , err := ReadAndDecodeHeader (c .Conn )
120+ _ , t , l , err := ReadAndDiscardPaddings (c .Conn )
114121 if err != nil {
115122 return nil , err
116123 }
124+
117125 if t != 1 {
118126 return nil , fmt .Errorf ("unexpected type %v, expect random hello" , t )
119127 }
120-
121128 peerRandomHello := make ([]byte , 1088 + 21 )
122129 if l != len (peerRandomHello ) {
123130 return nil , fmt .Errorf ("unexpected length %v for random hello" , l )
@@ -194,34 +201,17 @@ func (c *ClientConn) Read(b []byte) (int, error) {
194201 return 0 , nil
195202 }
196203 if c .peerAead == nil {
197- var t byte
198- var l int
199- var err error
200- if c .instance == nil { // from 1-RTT
201- for {
202- if _ , t , l , err = ReadAndDecodeHeader (c .Conn ); err != nil {
203- return 0 , err
204- }
205- if t != 23 {
206- break
207- }
208- if _ , err := io .ReadFull (c .Conn , make ([]byte , l )); err != nil {
209- return 0 , err
210- }
211- }
212- } else {
213- h := make ([]byte , 5 )
214- if _ , err := io .ReadFull (c .Conn , h ); err != nil {
215- return 0 , err
216- }
217- if t , l , err = DecodeHeader (h ); err != nil {
204+ _ , t , l , err := ReadAndDiscardPaddings (c .Conn )
205+ if err != nil {
206+ if c .instance != nil && strings .HasPrefix (err .Error (), "invalid header: " ) { // from 0-RTT
218207 c .instance .Lock ()
219208 if bytes .Equal (c .ticket , c .instance .ticket ) {
220209 c .instance .expire = time .Now () // expired
221210 }
222211 c .instance .Unlock ()
223212 return 0 , errors .New ("new handshake needed" )
224213 }
214+ return 0 , err
225215 }
226216 if t != 0 {
227217 return 0 , fmt .Errorf ("unexpected type %v, expect server random" , t )
0 commit comments