Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
96 changes: 96 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -59,3 +63,95 @@ func TestTLSClone(t *testing.T) {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}
}

const tlsDir = "../test/testdata/"

func TestTLSClientHandshakeReturnsAuthInfo(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo TLSInfo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use tls.ConnectionState directly?

done := make(chan bool)
go func() {
defer func() {
done <- true
}()
serverRawConn, err := lis.Accept()
if err != nil {
t.Fatalf("Server failed to accept connection: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use t.Errorf

}
serverConn := tls.Server(serverRawConn, serverTLS.(*tlsCreds).config)
serverErr := serverConn.Handshake()
if serverErr != nil {
t.Fatalf("Error on server while handshake. Error: %v", serverErr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t.Errorf

}
serverAuthInfo = TLSInfo{serverConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The body can be written in a separate function like:

func serverHandle(hs func(net.Conn) (net.Conn, AuthInfo, error)) {
    ...
}

}()
conn, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
t.Fatalf("Client failed to connect to local server. Error: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print lis.Addr().String() instead of "local server".

}
defer conn.Close()
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := c.ClientHandshake(context.Background(), lis.Addr().String(), conn)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
// wait until server has populated the serverAuthInfo struct.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.... populated the server AuthInfo or failed.

<-done
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err := <-done; err != nil {
  return
}

if authInfo.(TLSInfo).State.Version != serverAuthInfo.State.Version {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lis.Addr().String(), authInfo, serverAuthInfo)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%s for list.Addr().String()

}
}

func TestTLSServerHandshakeReturnsAuthInfo(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above comments apply to this test case too.

if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo AuthInfo
done := make(chan bool)
go func() {
defer func() {
done <- true
}()
serverRawConn, err := lis.Accept()
if err != nil {
t.Fatalf("Server failed to accept connection: %v", err)
}
var serverErr error
_, serverAuthInfo, serverErr = serverTLS.ServerHandshake(serverRawConn)
if serverErr != nil {
t.Fatalf("Error on server while handshake. Error: %v", serverErr)
}
}()
conn, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
t.Fatalf("Client failed to connect to local server. Error: %v", err)
}
defer conn.Close()
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
clientConn := tls.Client(conn, c.(*tlsCreds).config)
err = clientConn.Handshake()
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
authInfo := TLSInfo{clientConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use tls.ConnectionState directly?

// wait until server has populated the serverAuthInfo struct.
<-done
if authInfo.State.Version != serverAuthInfo.(TLSInfo).State.Version {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, authInfo)
}

}