Skip to content
This repository was archived by the owner on Aug 19, 2022. It is now read-only.

Commit ef83371

Browse files
implement a cache for session tickets
1 parent 80ca73e commit ef83371

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

session_cache.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package libp2ptls
2+
3+
import (
4+
"crypto/tls"
5+
6+
ci "github.com/libp2p/go-libp2p-core/crypto"
7+
)
8+
9+
const cacheSize = 3
10+
11+
type clientSessionCache struct {
12+
cache []*tls.ClientSessionState
13+
14+
// When using session resumption, the server won't send its certificate chain.
15+
// We therefore need to save its public key when storing a session ticket,
16+
// so we can return it on conn.RemotePublicKey().
17+
pubKey ci.PubKey
18+
}
19+
20+
var _ tls.ClientSessionCache = &clientSessionCache{}
21+
22+
func newClientSessionCache() *clientSessionCache {
23+
return &clientSessionCache{}
24+
}
25+
26+
func (c *clientSessionCache) Put(_ string, cs *tls.ClientSessionState) {
27+
if len(c.cache) == cacheSize {
28+
c.cache = c.cache[1:]
29+
}
30+
c.cache = append(c.cache, cs)
31+
}
32+
33+
func (c *clientSessionCache) Get(_ string) (*tls.ClientSessionState, bool) {
34+
if len(c.cache) == 0 {
35+
return nil, false
36+
}
37+
ticket := c.cache[len(c.cache)-1]
38+
c.cache = c.cache[:len(c.cache)-1]
39+
return ticket, true
40+
}
41+
42+
func (c *clientSessionCache) SetPubKey(k ci.PubKey) {
43+
if c.pubKey != nil && !c.pubKey.Equals(k) {
44+
panic("mismatching public key")
45+
}
46+
c.pubKey = k
47+
}
48+
49+
func (c *clientSessionCache) GetPubKey() ci.PubKey {
50+
return c.pubKey
51+
}

session_cache_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package libp2ptls
2+
3+
import (
4+
"crypto/rand"
5+
"crypto/tls"
6+
"encoding/binary"
7+
"unsafe"
8+
9+
ci "github.com/libp2p/go-libp2p-core/crypto"
10+
11+
. "github.com/onsi/ginkgo"
12+
. "github.com/onsi/gomega"
13+
)
14+
15+
var _ = Describe("Session Ticket Cache", func() {
16+
var cache *clientSessionCache
17+
const key = "irrelevant"
18+
ticketSize := unsafe.Sizeof(&tls.ClientSessionState{})
19+
20+
toSessionTicket := func(n int) *tls.ClientSessionState {
21+
b := make([]byte, ticketSize)
22+
binary.BigEndian.PutUint32(b, uint32(n))
23+
return (*tls.ClientSessionState)(unsafe.Pointer(&b))
24+
}
25+
26+
fromSessionTicket := func(t *tls.ClientSessionState) int {
27+
b := (*[]byte)(unsafe.Pointer(t))
28+
return int(binary.BigEndian.Uint32(*b))
29+
}
30+
31+
BeforeEach(func() {
32+
cache = newClientSessionCache()
33+
})
34+
35+
It("encodes and decodes values from session tickets", func() {
36+
Expect(fromSessionTicket(toSessionTicket(1337))).To(Equal(1337))
37+
})
38+
39+
It("doesn't return a session ticket if there's none", func() {
40+
t, ok := cache.Get(key)
41+
Expect(ok).To(BeFalse())
42+
Expect(t).To(BeNil())
43+
})
44+
45+
It("saves and retrieves session tickets", func() {
46+
cache.Put(key, toSessionTicket(42))
47+
ticket, ok := cache.Get(key)
48+
Expect(ok).To(BeTrue())
49+
Expect(fromSessionTicket(ticket)).To(Equal(42))
50+
_, ok = cache.Get(key)
51+
Expect(ok).To(BeFalse())
52+
})
53+
54+
It("returns the most recent ticket first", func() {
55+
cache.Put(key, toSessionTicket(1))
56+
cache.Put(key, toSessionTicket(2))
57+
ticket, ok := cache.Get(key)
58+
Expect(ok).To(BeTrue())
59+
Expect(fromSessionTicket(ticket)).To(Equal(2))
60+
ticket, ok = cache.Get(key)
61+
Expect(ok).To(BeTrue())
62+
Expect(fromSessionTicket(ticket)).To(Equal(1))
63+
})
64+
65+
It("limits the number of tickets saved", func() {
66+
Expect(cacheSize).To(Equal(3))
67+
cache.Put(key, toSessionTicket(1))
68+
cache.Put(key, toSessionTicket(2))
69+
cache.Put(key, toSessionTicket(3))
70+
cache.Put(key, toSessionTicket(4))
71+
ticket, ok := cache.Get(key)
72+
Expect(ok).To(BeTrue())
73+
Expect(fromSessionTicket(ticket)).To(Equal(4))
74+
ticket, ok = cache.Get(key)
75+
Expect(ok).To(BeTrue())
76+
Expect(fromSessionTicket(ticket)).To(Equal(3))
77+
ticket, ok = cache.Get(key)
78+
Expect(ok).To(BeTrue())
79+
Expect(fromSessionTicket(ticket)).To(Equal(2))
80+
_, ok = cache.Get(key)
81+
Expect(ok).To(BeFalse())
82+
})
83+
84+
It("sets and gets the public key", func() {
85+
_, pub, err := ci.GenerateEd25519Key(rand.Reader)
86+
Expect(err).ToNot(HaveOccurred())
87+
cache.SetPubKey(pub)
88+
Expect(cache.GetPubKey()).To(Equal(pub))
89+
})
90+
91+
It("doesn't allow setting of different public keys", func() {
92+
_, pub1, err := ci.GenerateEd25519Key(rand.Reader)
93+
Expect(err).ToNot(HaveOccurred())
94+
_, pub2, err := ci.GenerateEd25519Key(rand.Reader)
95+
Expect(err).ToNot(HaveOccurred())
96+
cache.SetPubKey(pub1)
97+
Expect(func() { cache.SetPubKey(pub2) }).To(Panic())
98+
})
99+
})

0 commit comments

Comments
 (0)