@@ -18,17 +18,66 @@ package thrift
1818
1919import (
2020 "context"
21+ "crypto/tls"
2122 "maps"
2223 "net"
2324 "time"
2425)
2526
27+ type contextKey int
28+
29+ const (
30+ connInfoKey contextKey = 0
31+ reqContextKey contextKey = 2
32+ )
33+
2634// Identity represents a secure peer identity
2735type Identity struct {
2836 Type string
2937 Data string
3038}
3139
40+ // ConnInfo contains connection information from clients of the Server.
41+ type ConnInfo struct {
42+ RemoteAddr net.Addr
43+ tlsState tlsConnectionStater // set by thrift tcp servers
44+ }
45+
46+ // tlsConnectionStater is an abstract interface for types that can return
47+ // the state of TLS connections. This is used to support not only tls.Conn
48+ // but also custom wrappers such as permissive TLS/non-TLS sockets.
49+ //
50+ // Caveat: this interface has to support at least tls.Conn, which has
51+ // the current signature for ConnectionState. Because of that, wrappers
52+ // for permissive TLS/non-TLS may return an empty tls.ConnectionState.
53+ type tlsConnectionStater interface {
54+ ConnectionState () tls.ConnectionState
55+ }
56+
57+ // tlsConnectionStaterHandshaker is an abstract interface that allows
58+ // custom "TLS-like" connections to be used with Thrift ALPN logic.
59+ type tlsConnectionStaterHandshaker interface {
60+ tlsConnectionStater
61+ HandshakeContext (context.Context ) error
62+ }
63+
64+ // Compile time interface enforcer
65+ var _ tlsConnectionStater = (* tls .Conn )(nil )
66+ var _ tlsConnectionStaterHandshaker = (* tls .Conn )(nil )
67+
68+ // TLS returns the TLS connection state.
69+ func (c ConnInfo ) TLS () * tls.ConnectionState {
70+ if c .tlsState == nil {
71+ return nil
72+ }
73+ cs := c .tlsState .ConnectionState ()
74+ // See the caveat in tlsConnectionStater.
75+ if cs .Version == 0 {
76+ return nil
77+ }
78+ return & cs
79+ }
80+
3281// RequestContext is a mirror of C++ apache::thrift::RequestContext
3382// Not all options are guaranteed to be implemented by a client
3483type RequestContext struct {
@@ -50,11 +99,24 @@ type RequestContext struct {
5099 contextHeaders
51100}
52101
53- type requestContextKey int
102+ // withConnInfo adds connection info (from a thrift.Transport) to context, if applicable
103+ func withConnInfo (ctx context.Context , conn net.Conn ) context.Context {
104+ var tlsState tlsConnectionStater
105+ if t , ok := conn .(tlsConnectionStater ); ok {
106+ tlsState = t
107+ }
108+ ctx = context .WithValue (ctx , connInfoKey , ConnInfo {
109+ RemoteAddr : conn .RemoteAddr (),
110+ tlsState : tlsState ,
111+ })
112+ return ctx
113+ }
54114
55- const (
56- reqContextKey requestContextKey = 2
57- )
115+ // connInfoFromContext extracts and returns ConnInfo from context.
116+ func connInfoFromContext (ctx context.Context ) (ConnInfo , bool ) {
117+ v , ok := ctx .Value (connInfoKey ).(ConnInfo )
118+ return v , ok
119+ }
58120
59121// GetRequestContext returns the RequestContext in a go context, or nil if there is nothing
60122func GetRequestContext (ctx context.Context ) * RequestContext {
0 commit comments