Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
161 changes: 161 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -58,4 +62,161 @@ func TestTLSClone(t *testing.T) {
if c.Info().ServerName != expectedServerName {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}

}

const tlsDir = "../test/testdata/"

type serverHandshake func(net.Conn) (AuthInfo, error)

func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, tlsServerHandshake, done)
defer lis.Close()
lisAddr := lis.Addr().String()
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
}
}

func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, gRPCServerHandshake, done)
defer lis.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the place right after listen().
If something fails before this line, lis will not be closed().

clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
}
}

func TestServerAndClientHandshake(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, gRPCServerHandshake, done)
defer lis.Close()
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

}

func compare(a1, a2 AuthInfo) bool {
if a1.AuthType() != a2.AuthType() {
return false
}
switch a1.AuthType() {
case "tls":
state1 := a1.(TLSInfo).State
state2 := a2.(TLSInfo).State
if state1.Version == state2.Version &&
state1.HandshakeComplete == state2.HandshakeComplete &&
state1.CipherSuite == state2.CipherSuite &&
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
return true
}
return false
default:
return false
}
}

func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
go serverHandle(t, hs, done, lis)
return lis
}

// Is run in a seperate goroutine.
func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
serverRawConn, err := lis.Accept()
if err != nil {
t.Errorf("Server failed to accept connection: %v", err)
close(done)
return
}
serverAuthInfo, err := hs(serverRawConn)
if err != nil {
t.Errorf("Server failed while handshake. Error: %v", err)
close(done)
return
}
done <- serverAuthInfo
}

func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
conn, err := net.Dial("tcp", lisAddr)
if err != nil {
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
}
defer conn.Close()
clientAuthInfo, err := hs(conn, lisAddr)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
return clientAuthInfo
}

// Server handshake implementation in gRPC.
func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return nil, err?

}
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return nil, err?

}
return serverAuthInfo, nil
}

// Client handshake implementation in gRPC.
func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
return authInfo, nil
}

func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
serverConn := tls.Server(conn, serverTLSConfig)
err = serverConn.Handshake()
if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
return TLSInfo{State: serverConn.ConnectionState()}, nil
}

func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
clientConn := tls.Client(conn, clientTLSConfig)
err := clientConn.Handshake()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err := clientConn.Handshake(); err != nil {
  ...
}

if err != nil {
return TLSInfo{}, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
return TLSInfo{State: clientConn.ConnectionState()}, nil
}