Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
173 changes: 173 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -58,4 +62,173 @@ func TestTLSClone(t *testing.T) {
if c.Info().ServerName != expectedServerName {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}

}

const tlsDir = "../test/testdata/"

type serverHandshake func(net.Conn, *tls.ConnectionState) error

func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
var serverConnState tls.ConnectionState
errChan := make(chan error, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more like a channel to notify the completion of serverHandle. Can you rename it to "done"? Applies to all the tests you added.

lisAddr, err := launchServer(t, &serverConnState, tlsServerHandshake, errChan)
if err != nil {
return
}
clientConnState, err := clientHandle(t, gRPCClientHandshake, lisAddr)
if err != nil {
return
}
// wait until server has populated the serverAuthInfo struct or failed.
if err = <-errChan; err != nil {
return
}
if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientConnState, serverConnState)
}
}

func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
var serverConnState tls.ConnectionState
errChan := make(chan error, 1)
lisAddr, err := launchServer(t, &serverConnState, gRPCServerHandshake, errChan)
if err != nil {
return
}
clientConnState, err := clientHandle(t, tlsClientHandshake, lisAddr)
if err != nil {
return
}
// wait until server has populated the serverAuthInfo struct or failed.
if err = <-errChan; err != nil {
return
}
if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverConnState, clientConnState)
}
}

func TestServerAndClientHandshake(t *testing.T) {
var serverConnState tls.ConnectionState
errChan := make(chan error, 1)
lisAddr, err := launchServer(t, &serverConnState, gRPCServerHandshake, errChan)
if err != nil {
return
}
clientConnState, err := clientHandle(t, gRPCClientHandshake, lisAddr)
if err != nil {
return
}
// wait until server has populated the serverAuthInfo struct or failed.
if err = <-errChan; err != nil {
return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("Connection states returened by server: %v and client: %v aren't same", serverConnState, clientConnState)
}
}

func isEqualState(state1, state2 tls.ConnectionState) bool {
if state1.Version == state2.Version &&
state1.HandshakeComplete == state2.HandshakeComplete &&
state1.CipherSuite == state2.CipherSuite &&
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
return true
}
return false
}

func launchServer(t *testing.T, serverConnState *tls.ConnectionState, hs serverHandshake, errChan chan error) (string, error) {
lis, err := net.Listen("tcp", "localhost:0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above comments apply to this test case too.

if err != nil {
t.Errorf("Failed to listen: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can t.Fatalf directly so that this function only needs to return lis.Addr().String().

return "", err
}
go serverHandle(t, hs, serverConnState, errChan, lis)
return lis.Addr().String(), nil
}

// Is run in a seperate go routine.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goroutine is one word.

func serverHandle(t *testing.T, hs func(net.Conn, *tls.ConnectionState) error, serverConnState *tls.ConnectionState, errChan chan error, lis net.Listener) {
defer lis.Close()
var err error
defer func() {
errChan <- err
}()
serverRawConn, err := lis.Accept()
if err != nil {
t.Errorf("Server failed to accept connection: %v", err)
return
}
err = hs(serverRawConn, serverConnState)
if err != nil {
t.Errorf("Error at server-side while handshake. Error: %v", err)
return
}
}

func clientHandle(t *testing.T, hs func(net.Conn, string) (tls.ConnectionState, error), lisAddr string) (tls.ConnectionState, error) {
conn, err := net.Dial("tcp", lisAddr)
if err != nil {
t.Errorf("Client failed to connect to %s. Error: %v", lisAddr, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use t.Fatalf and this function only returns tls.ConnectionState.

return tls.ConnectionState{}, err
}
defer conn.Close()
clientConnState, err := hs(conn, lisAddr)
if err != nil {
t.Errorf("Error on client while handshake. Error: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t.Fatalf

}
return clientConnState, err
}

// Server handshake implementation using gRPC.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/using/in

func gRPCServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error {
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return err
}
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
if err != nil {
return err
}
*serverConnState = serverAuthInfo.(TLSInfo).State
return nil
}

// Client handshake implementation using gRPC.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/using/in

func gRPCClientHandshake(conn net.Conn, lisAddr string) (tls.ConnectionState, error) {
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return tls.ConnectionState{}, err
}
return authInfo.(TLSInfo).State, nil
}

// Server handshake implementation using tls.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the comment.

func tlsServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error {
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return err
}
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
serverConn := tls.Server(conn, serverTLSConfig)
err = serverConn.Handshake()
if err != nil {
return err
}
*serverConnState = serverConn.ConnectionState()
return nil
}

// Client handskae implementation using tls.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the comment.

func tlsClientHandshake(conn net.Conn, _ string) (tls.ConnectionState, error) {
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
clientConn := tls.Client(conn, clientTLSConfig)
err := clientConn.Handshake()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err := clientConn.Handshake(); err != nil {
  ...
}

if err != nil {
return tls.ConnectionState{}, err
}
return clientConn.ConnectionState(), nil
}