Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 132 additions & 95 deletions cmd/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,137 +13,128 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package client creates a tunnel client to proxy incoming connections
//
// This binary creates a tunnel client to proxy incoming connections
// over a grpc transport.
package client
// Exmaples to use this binary with ssh's ProxyCommand option:
// TLS:
// ssh -o ProxyCommand="client
// --tunnel_server_address=localhost:$PORT \
// --cert_file=$CERT_FILE \
// --dial_target=target1 \
// --dial_target_type=SSH" $USER@localhost
// mTLS:
// ssh -o ProxyCommand="client
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whitespace to be consistent with above TLS example.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// --tunnel_server_address=localhost:$PORT \
// --cert_file=$CERT_FILE \
// --key_file=$KEY_FILE \
// --ca_file=$CA_FILE \
// --dial_target=target1 \
// --dial_target_type=SSH" $USER@localhost
package main

import (
"context"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/openconfig/grpctunnel/bidi"
"github.com/openconfig/grpctunnel/tunnel"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

tpb "github.com/openconfig/grpctunnel/proto/tunnel"
)

// Config defines the parameters to run a tunnel client.
type Config struct {
TunnelAddress, DialAddress, ListenAddress, CertFile, Target, TargetType string
}
var (
tunnelAddress = flag.String("tunnel_server_address", "", "The address of the tunnel")
dialTarget = flag.String("dial_target", "", "The client uses target to register at the server.")
dialTargetType = flag.String("dial_target_type", "", "The type of target protocol, e.g. GNMI or SSH.")
certFile = flag.String("cert_file", "", "The certificate file location")
keyFile = flag.String("key_file", "", "The private key file location")
caFile = flag.String("ca_file", "", "The CA file location (for mTLS). If provided, it will be handled as mTLS")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a TLS expert, so I may be getting something wrong here, but these options are a bit confusing:
In the TLS case -cert_file is used to specify the CA to verify the server.
In the mTLS case -ca_file is used to specify the CA to verify the server, and -cert_file, -key_file are used to specify the client's certificate and key.

What would be more clear is for the -ca_file to always be used to specify the CA to verify the server. You enable mTLS by passing in cert_file, key_file.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. Updated.


// for setting retry backoff when waiting for target.
retryBaseDelay = time.Second
retryMaxDelay = time.Minute
retryRandomization = 0.5
)

func listen(ctx context.Context, c *tunnel.Client, listenAddress string, target tunnel.Target) error {
l, err := net.Listen("tcp", listenAddress)
if err != nil {
return fmt.Errorf("failed to listen: %s: %v", listenAddress, err)
}
defer l.Close()
// config defines the parameters to run a tunnel client.
type config struct {
tunnelAddress,
caFile,
keyFile,
certFile,
dialTarget, // The remote target to dial
dialTargetType string // The remote target type to dial
}

errCh := make(chan error)
go func() {
for {
conn, err := l.Accept()
if err != nil {
select {
case errCh <- fmt.Errorf("failed to accept connection: %v", err):
default:
}
return
}
// Errors from this goroutine will be logged only, because we don't want an
// underlying stream issue to tear the server down
go func(conn net.Conn) {
defer conn.Close()
session, err := c.NewSession(target)
if err != nil {
log.Printf("error from new session: %v", err)
return
}
if err = bidi.Copy(session, conn); err != nil {
log.Printf("error from bidi copy: %v", err)
}
}(conn)
}
}()
type stdIOConn struct {
io.Reader
io.WriteCloser
}

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
return err
}
func getBackOff() *backoff.ExponentialBackOff {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0 // Retry Subscribe indefinitely.
bo.InitialInterval = retryBaseDelay
bo.MaxInterval = retryMaxDelay
bo.RandomizationFactor = retryRandomization
return bo
}

// Run starts a tunnel client, connecting to the tunnel server via the provided tunnel address.
// The client uses the target to identify whether it can handle the target (u) sent by the server.
func Run(ctx context.Context, conf Config) error {
opts := []grpc.DialOption{grpc.WithDefaultCallOptions()}
if conf.CertFile == "" {
opts = append(opts, grpc.WithInsecure())
func run(ctx context.Context, conf config) error {
var opts []grpc.DialOption
var err error
if len(conf.caFile) == 0 {
opts, err = tunnel.DialTLSCredsOpts(conf.certFile)
} else {
creds, err := credentials.NewClientTLSFromFile(conf.CertFile, "")
if err != nil {
return fmt.Errorf("failed to load credentials: %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
opts, err = tunnel.DialmTLSCredsOpts(conf.certFile, conf.keyFile, conf.caFile)
}

if err != nil {
return err
}
clientConn, err := grpc.Dial(conf.TunnelAddress, opts...)
clientConn, err := grpc.Dial(conf.tunnelAddress, opts...)
if err != nil {
return fmt.Errorf("grpc dial error: %v", err)
}
defer clientConn.Close()

registerHandler := func(t tunnel.Target) error {
if t.ID != conf.Target {
return fmt.Errorf("client cannot handle: %s", t.ID)
}
return nil
}
peers := make(map[tunnel.Target]struct{})
var peerMux sync.Mutex

handler := func(_ tunnel.Target, i io.ReadWriteCloser) error {
conn, err := net.Dial("tcp", conf.DialAddress)
if err != nil {
log.Printf("Error dialing client: %v", err)
return err
}

if err = bidi.Copy(i, conn); err != nil {
// Logging this error only as we don't want the client to stop because an
// underlying stream had an issue
log.Printf("Copy error: %v", err)
}

return nil
}

peerAddCh := make(chan tunnel.Target, 1)
peerAddHandler := func(t tunnel.Target) error {
peerAddCh <- t
peerMux.Lock()
defer peerMux.Unlock()
peers[t] = struct{}{}
log.Printf("peer target %s added\n", t)
return nil
}

peerDelCh := make(chan tunnel.Target, 1)
peerDelHandler := func(t tunnel.Target) error {
peerDelCh <- t
peerMux.Lock()
defer peerMux.Unlock()
if _, ok := peers[t]; ok {
delete(peers, t)
log.Printf("peer target %s deleted\n", t)
}
return nil
}

targets := make(map[tunnel.Target]struct{})
t := tunnel.Target{ID: conf.Target, Type: conf.TargetType}
targets[t] = struct{}{}
client, err := tunnel.NewClient(tpb.NewTunnelClient(clientConn), tunnel.ClientConfig{
RegisterHandler: registerHandler,
Handler: handler,
Subscriptions: []string{conf.TargetType},
PeerAddHandler: peerAddHandler,
PeerDelHandler: peerDelHandler,
PeerAddHandler: peerAddHandler,
PeerDelHandler: peerDelHandler,
Subscriptions: []string{conf.dialTargetType},
}, targets)

if err != nil {
Expand All @@ -165,13 +156,59 @@ func Run(ctx context.Context, conf Config) error {
}
}()

// listen for any request to create a new session
dialTarget := tunnel.Target{ID: conf.dialTarget, Type: conf.dialTargetType}
foundDialTarget := func() bool {
peerMux.Lock()
defer peerMux.Unlock()
_, ok := peers[dialTarget]
return ok
}

// Dial the target with retry.
go func() {
bo := getBackOff()
for !foundDialTarget() {
wait := bo.NextBackOff()
log.Printf("dial target %s (type: %s) not found. reconnecting in %s (all targets found: %s) \n", conf.dialTarget, conf.dialTargetType, wait, peers)
time.Sleep(wait)
}

session, err := client.NewSession(dialTarget)
if err != nil {
log.Printf("error from new session: %v", err)
errCh <- err
return
}
log.Printf("new session established for target: %s\n", dialTarget)

// Once a tunnel session is established, it connects it to a stdio.
stdio := &stdIOConn{Reader: os.Stdin, WriteCloser: os.Stdout}
if err = bidi.Copy(session, stdio); err != nil {
log.Printf("error from bidi copy: %v\n", err)
return
}

}()

// Listen for any request to create a new session.
select {
case target := <-peerAddCh:
return listen(ctx, client, conf.ListenAddress, target)
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
return err
return fmt.Errorf("exiting: %s", err)
}
}

func main() {
flag.Parse()
if err := run(context.Background(), config{
tunnelAddress: *tunnelAddress,
dialTarget: *dialTarget,
dialTargetType: *dialTargetType,
certFile: *certFile,
keyFile: *keyFile,
caFile: *caFile,
}); err != nil {
log.Fatal(err)
}
}
65 changes: 0 additions & 65 deletions cmd/client/client_test.go

This file was deleted.

Loading