Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this TODO

// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
94 changes: 94 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -59,3 +63,93 @@ func TestTLSClone(t *testing.T) {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}
}

func TestTLSClientHandshakeReturnsAuthInfo(t *testing.T) {
localPort := ":5050"
tlsDir := "../test/testdata/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define this as a file level const
const tlsDir := ...

lis, err := net.Listen("tcp", localPort)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do listen() on localhost:0 and get the listening address from lis

if err != nil {
t.Fatalf("Failed to start local server. Listener error: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify the error message to be just
t.Fatalf("Failed to listen: %v", err)

}
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo AuthInfo
done := make(chan bool)
go func() {
defer func() {
done <- true
}()
serverRawConn, _ := lis.Accept()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check the error here?

serverConn := tls.Server(serverRawConn, serverTLS.(*tlsCreds).config)
serverErr := serverConn.Handshake()
if serverErr != nil {
t.Fatalf("Error on server while handshake. Error: %v", serverErr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t.Errorf

}
serverAuthInfo = TLSInfo{serverConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The body can be written in a separate function like:

func serverHandle(hs func(net.Conn) (net.Conn, AuthInfo, error)) {
    ...
}

}()
defer lis.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the place right after listen().
If something fails before this line, lis will not be closed().

conn, err := net.Dial("tcp", localPort)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defer close this connection

if err != nil {
t.Fatalf("Client failed to connect to local server. Error: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print lis.Addr().String() instead of "local server".

}
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := c.ClientHandshake(context.Background(), localPort, conn)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
select {
case <-done:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the select is not necessary, just <-done.

// wait until server has populated the serverAuthInfo struct.
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

if authInfo.AuthType() != serverAuthInfo.AuthType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TLSInfo.AuthType always returns "tls".
How about comparing something inside ConnectionState instead? like Version?

t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", localPort, authInfo, serverAuthInfo)
}
}

func TestTLSServerHandshakeReturnsAuthInfo(t *testing.T) {
localPort := ":5050"
tlsDir := "../test/testdata/"
lis, err := net.Listen("tcp", localPort)
if err != nil {
t.Fatalf("Failed to start local server. Listener error: %v", err)
}
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo AuthInfo
done := make(chan bool)
go func() {
defer func() {
done <- true
}()
serverRawConn, _ := lis.Accept()
var serverErr error
_, serverAuthInfo, serverErr = serverTLS.ServerHandshake(serverRawConn)
if serverErr != nil {
t.Fatalf("Error on server while handshake. Error: %v", serverErr)
}
}()
defer lis.Close()
conn, err := net.Dial("tcp", localPort)
if err != nil {
t.Fatalf("Client failed to connect to local server. Error: %v", err)
}
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
clientConn := tls.Client(conn, c.(*tlsCreds).config)
err = clientConn.Handshake()
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
authInfo := TLSInfo{clientConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use tls.ConnectionState directly?

select {
case <-done:
// wait until server has populated the serverAuthInfo struct.
}
if authInfo.AuthType() != serverAuthInfo.AuthType() {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, authInfo)
}

}