Skip to content

Commit 2a8831b

Browse files
committed
chore: sync vless encryption code
1 parent cdf5e0c commit 2a8831b

4 files changed

Lines changed: 46 additions & 37 deletions

File tree

transport/vless/encryption/client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ func (i *ClientInstance) Init(nfsEKeyBytes, xorPKeyBytes []byte, xorMode, minute
7070
if i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes); err != nil {
7171
return
7272
}
73-
hash32 := sha3.Sum256(nfsEKeyBytes)
74-
copy(i.hash11[:], hash32[:])
7573
if xorMode > 0 {
7674
i.xorMode = xorMode
7775
if i.xorPKey, err = ecdh.X25519().NewPublicKey(xorPKeyBytes); err != nil {
7876
return
7977
}
78+
hash32 := sha3.Sum256(nfsEKeyBytes)
79+
copy(i.hash11[:], hash32[:])
8080
}
8181
i.minutes = time.Duration(minutes) * time.Minute
8282
return
@@ -126,7 +126,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*ClientConn, error) {
126126
}
127127
// client can send more NFS AEAD paddings / messages if needed
128128

129-
_, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before server hello
129+
_, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before server hello
130130
if err != nil {
131131
return nil, err
132132
}
@@ -209,9 +209,9 @@ func (c *ClientConn) Read(b []byte) (int, error) {
209209
return 0, nil
210210
}
211211
if c.peerAEAD == nil {
212-
_, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before random hello
212+
_, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before random hello
213213
if err != nil {
214-
if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // 0-RTT's 0-RTT
214+
if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // 0-RTT
215215
c.instance.Lock()
216216
if bytes.Equal(c.ticket, c.instance.ticket) {
217217
c.instance.expire = time.Now() // expired

transport/vless/encryption/common.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,21 @@ func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) {
6262
return
6363
}
6464

65-
func ReadAndDiscardPaddings(conn net.Conn) (h []byte, t byte, l int, err error) {
65+
func ReadAndDiscardPaddings(conn net.Conn, aead cipher.AEAD, nonce []byte) (h []byte, t byte, l int, err error) {
6666
for {
6767
if h, t, l, err = ReadAndDecodeHeader(conn); err != nil || t != 23 {
6868
return
6969
}
70-
if _, err = io.ReadFull(conn, make([]byte, l)); err != nil {
70+
padding := make([]byte, l)
71+
if _, err = io.ReadFull(conn, padding); err != nil {
7172
return
7273
}
74+
if aead != nil {
75+
if _, err := aead.Open(nil, nonce, padding, h); err != nil {
76+
return h, t, l, err
77+
}
78+
IncreaseNonce(nonce)
79+
}
7380
}
7481
}
7582

transport/vless/encryption/doc.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
// https://github.com/XTLS/Xray-core/commit/84835bec7d0d8555d0dd30953ed26a272de814c4
1717
// https://github.com/XTLS/Xray-core/commit/373558ed7abdbac3de41745cf30ec04c9adde604
1818
// https://github.com/XTLS/Xray-core/commit/38cc306c955c362f044e074049a5e67b6b9fb389
19+
// https://github.com/XTLS/Xray-core/commit/b33555cc0a52d0af3c23d2af8fca42f8a685d9af
1920
package encryption

transport/vless/encryption/server.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ func (i *ServerInstance) Init(nfsDKeySeed, xorSKeyBytes []byte, xorMode, minutes
5555
if i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed); err != nil {
5656
return
5757
}
58-
hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes())
59-
copy(i.hash11[:], hash32[:])
6058
if xorMode > 0 {
6159
i.xorMode = xorMode
6260
if i.xorSKey, err = ecdh.X25519().NewPrivateKey(xorSKeyBytes); err != nil {
6361
return
6462
}
63+
hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes())
64+
copy(i.hash11[:], hash32[:])
6565
}
6666
if minutes > 0 {
6767
i.minutes = time.Duration(minutes) * time.Minute
@@ -106,7 +106,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) {
106106
}
107107
c := &ServerConn{Conn: conn}
108108

109-
_, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before client/ticket hello
109+
_, t, l, err := ReadAndDiscardPaddings(c.Conn, nil, nil) // allow paddings before client/ticket hello
110110
if err != nil {
111111
return nil, err
112112
}
@@ -170,11 +170,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) {
170170
if err != nil {
171171
return nil, err
172172
}
173+
nfsAEAD := NewAEAD(c.cipher, nfsKey, pfsEKeyBytes, encapsulatedNfsKey)
174+
nfsNonce := append([]byte{}, peerClientHello[:11+1]...)
173175
pfsKey, encapsulatedPfsKey := pfsEKey.Encapsulate()
174176
c.baseKey = append(pfsKey, nfsKey...)
175177
pfsAEAD := NewAEAD(c.cipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey)
176-
c.ticket = append(i.hash11[:], pfsAEAD.Seal(nil, peerClientHello[:11+1], []byte("VLESS"), pfsEKeyBytes)...)
177-
IncreaseNonce(peerClientHello[:11+1])
178+
pfsNonce := append([]byte{}, peerClientHello[:11+1]...)
179+
c.ticket = append(i.hash11[:], pfsAEAD.Seal(nil, pfsNonce, []byte("VLESS"), pfsEKeyBytes)...)
180+
IncreaseNonce(pfsNonce)
178181

179182
serverHello := make([]byte, 5+1088+21+randBetween(100, 1000))
180183
EncodeHeader(serverHello, 1, 1088+21)
@@ -183,20 +186,41 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) {
183186
padding := serverHello[5+1088+21:]
184187
rand.Read(padding) // important
185188
EncodeHeader(padding, 23, len(padding)-5)
186-
pfsAEAD.Seal(padding[:5], peerClientHello[:11+1], padding[5:len(padding)-16], padding[:5])
189+
pfsAEAD.Seal(padding[:5], pfsNonce, padding[5:len(padding)-16], padding[:5])
187190

188191
if _, err := c.Conn.Write(serverHello); err != nil {
189192
return nil, err
190193
}
191194
// server can send more PFS AEAD paddings / messages if needed
192195

196+
_, t, l, err = ReadAndDiscardPaddings(c.Conn, nfsAEAD, nfsNonce) // allow paddings before ticket hello
197+
if err != nil {
198+
return nil, err
199+
}
200+
if t != 0 {
201+
return nil, fmt.Errorf("unexpected type %v, expect ticket hello", t)
202+
}
203+
peerTicketHello := make([]byte, 32+32)
204+
if l != len(peerTicketHello) {
205+
return nil, fmt.Errorf("unexpected length %v for ticket hello", l)
206+
}
207+
if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil {
208+
return nil, err
209+
}
210+
if !bytes.Equal(peerTicketHello[:32], c.ticket) {
211+
return nil, errors.New("naughty boy")
212+
}
213+
c.peerRandom = peerTicketHello[32:]
214+
193215
if i.minutes > 0 {
194216
i.Lock()
195-
i.sessions[[32]byte(c.ticket)] = &ServerSession{
217+
s := &ServerSession{
196218
expire: time.Now().Add(i.minutes),
197219
cipher: c.cipher,
198220
baseKey: c.baseKey,
199221
}
222+
s.randoms.Store([32]byte(c.peerRandom), true)
223+
i.sessions[[32]byte(c.ticket)] = s
200224
i.Unlock()
201225
}
202226

@@ -208,26 +232,6 @@ func (c *ServerConn) Read(b []byte) (int, error) {
208232
return 0, nil
209233
}
210234
if c.peerAEAD == nil {
211-
if c.peerRandom == nil { // 1-RTT's 0-RTT
212-
_, t, l, err := ReadAndDiscardPaddings(c.Conn) // allow paddings before ticket hello
213-
if err != nil {
214-
return 0, err
215-
}
216-
if t != 0 {
217-
return 0, fmt.Errorf("unexpected type %v, expect ticket hello", t)
218-
}
219-
peerTicketHello := make([]byte, 32+32)
220-
if l != len(peerTicketHello) {
221-
return 0, fmt.Errorf("unexpected length %v for ticket hello", l)
222-
}
223-
if _, err := io.ReadFull(c.Conn, peerTicketHello); err != nil {
224-
return 0, err
225-
}
226-
if !bytes.Equal(peerTicketHello[:32], c.ticket) {
227-
return 0, errors.New("naughty boy")
228-
}
229-
c.peerRandom = peerTicketHello[32:]
230-
}
231235
c.peerAEAD = NewAEAD(c.cipher, c.baseKey, c.peerRandom, c.ticket)
232236
c.peerNonce = make([]byte, 12)
233237
}
@@ -280,9 +284,6 @@ func (c *ServerConn) Write(b []byte) (int, error) {
280284
}
281285
n += len(b)
282286
if c.aead == nil {
283-
if c.peerRandom == nil {
284-
return 0, errors.New("empty c.peerRandom")
285-
}
286287
data = make([]byte, 5+32+5+len(b)+16)
287288
EncodeHeader(data, 0, 32)
288289
rand.Read(data[5 : 5+32])

0 commit comments

Comments
 (0)