Skip to content

Commit ff8818f

Browse files
Fix SSH agent forwarding to agentless nodes (#22567)
* lazily forward SSH agents when connecting to agentless nodes When connecting to unregistered OpenSSH nodes, the SSH agent is always forwarded. When connecting to registered OpenSSH (agentless) nodes however, the SSH agent doesn't *need* to be forwarded, so only do so if the user explicitly asks to. * test SSH agent forwarding in agentless integration test
1 parent 3d02011 commit ff8818f

File tree

4 files changed

+65
-37
lines changed

4 files changed

+65
-37
lines changed

integration/integration_test.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ import (
4949
"github.com/stretchr/testify/assert"
5050
"github.com/stretchr/testify/require"
5151
"golang.org/x/crypto/ssh"
52+
"golang.org/x/crypto/ssh/agent"
5253
"golang.org/x/exp/slices"
5354
"google.golang.org/grpc"
5455
"google.golang.org/grpc/connectivity"
@@ -7333,8 +7334,15 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) {
73337334
}, tc.Username)
73347335
require.NoError(t, err)
73357336

7336-
_, _, err = nodeClient.Client.Client.SendRequest("test-request", true, nil)
7337+
// forward SSH agent
7338+
sshClient := nodeClient.Client.Client
7339+
session, err := sshClient.NewSession()
73377340
require.NoError(t, err)
7341+
t.Cleanup(func() {
7342+
require.NoError(t, session.Close())
7343+
})
7344+
require.NoError(t, agent.ForwardToAgent(sshClient, tc.LocalAgent()))
7345+
require.NoError(t, agent.RequestAgentForwarding(session))
73387346

73397347
require.NoError(t, nodeClient.Close())
73407348
}
@@ -7366,21 +7374,41 @@ func startSSHServer(t *testing.T, caPubKeys []ssh.PublicKey, hostKey ssh.Signer)
73667374

73677375
go func() {
73687376
nConn, err := lis.Accept()
7369-
require.NoError(t, err)
7377+
assert.NoError(t, err)
73707378
t.Cleanup(func() {
7379+
// the error is ignored here to avoid failing on net.ErrClosed
73717380
_ = nConn.Close()
73727381
})
73737382

7374-
conn, _, reqs, err := ssh.NewServerConn(nConn, &sshCfg)
7375-
require.NoError(t, err)
7383+
conn, channels, reqs, err := ssh.NewServerConn(nConn, &sshCfg)
7384+
assert.NoError(t, err)
73767385
t.Cleanup(func() {
7386+
// the error is ignored here to avoid failing on net.ErrClosed
73777387
_ = conn.Close()
73787388
})
7389+
go ssh.DiscardRequests(reqs)
73797390

7380-
req := <-reqs
7381-
require.NoError(t, req.Reply(true, nil))
7391+
var agentForwarded bool
7392+
for channelReq := range channels {
7393+
assert.Equal(t, "session", channelReq.ChannelType())
7394+
channel, reqs, err := channelReq.Accept()
7395+
assert.NoError(t, err)
7396+
t.Cleanup(func() {
7397+
// the error is ignored here to avoid failing on net.ErrClosed
7398+
_ = channel.Close()
7399+
})
73827400

7383-
require.NoError(t, conn.Close())
7401+
for req := range reqs {
7402+
if req.WantReply {
7403+
assert.NoError(t, req.Reply(true, nil))
7404+
}
7405+
if req.Type == sshutils.AgentForwardRequest {
7406+
agentForwarded = true
7407+
break
7408+
}
7409+
}
7410+
}
7411+
assert.True(t, agentForwarded)
73847412
}()
73857413

73867414
return lis.Addr().String()

lib/proxy/router.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
285285
if err != nil {
286286
return nil, "", trace.Wrap(err)
287287
}
288-
// TODO(capnspacehook): remove when forwarding SSH agent to agentless node works
289288
agentGetter = nil
290289
}
291290

lib/srv/forward/sshserver.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,10 @@ func (s *Server) GetLockWatcher() *services.LockWatcher {
514514
}
515515

516516
func (s *Server) Serve() {
517-
config := &ssh.ServerConfig{}
517+
var (
518+
succeeded bool
519+
config = &ssh.ServerConfig{}
520+
)
518521

519522
// Configure callback for user certificate authentication.
520523
config.PublicKeyCallback = s.authHandlers.UserKeyAuth
@@ -538,15 +541,22 @@ func (s *Server) Serve() {
538541
s.log.Debugf("Supported KEX algorithms: %q.", s.kexAlgorithms)
539542
s.log.Debugf("Supported MAC algorithms: %q.", s.macAlgorithms)
540543

541-
sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config)
542-
if err != nil {
544+
// close
545+
defer func() {
546+
if succeeded {
547+
return
548+
}
549+
543550
if s.userAgent != nil {
544551
s.userAgent.Close()
545552
}
546553
s.targetConn.Close()
547554
s.clientConn.Close()
548555
s.serverConn.Close()
556+
}()
549557

558+
sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config)
559+
if err != nil {
550560
s.log.Errorf("Unable to create server connection: %v.", err)
551561
return
552562
}
@@ -558,13 +568,6 @@ func (s *Server) Serve() {
558568
// Take connection and extract identity information for the user from it.
559569
s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn)
560570
if err != nil {
561-
if s.userAgent != nil {
562-
s.userAgent.Close()
563-
}
564-
s.targetConn.Close()
565-
s.clientConn.Close()
566-
s.serverConn.Close()
567-
568571
s.log.Errorf("Unable to create server connection: %v.", err)
569572
return
570573
}
@@ -578,17 +581,12 @@ func (s *Server) Serve() {
578581
s.rejectChannel(chans, err.Error())
579582
sconn.Close()
580583

581-
if s.userAgent != nil {
582-
s.userAgent.Close()
583-
}
584-
s.targetConn.Close()
585-
s.clientConn.Close()
586-
s.serverConn.Close()
587-
588584
s.log.Errorf("Unable to create remote connection: %v", err)
589585
return
590586
}
591587

588+
succeeded = true
589+
592590
// The keep-alive loop will keep pinging the remote server and after it has
593591
// missed a certain number of keep-alive requests it will cancel the
594592
// closeContext which signals the server to shutdown.
@@ -646,12 +644,12 @@ func (s *Server) newRemoteClient(ctx context.Context, systemLogin string) (*trac
646644
var signers []ssh.Signer
647645
if s.agentlessSigner != nil {
648646
signers = []ssh.Signer{s.agentlessSigner}
649-
} else if s.userAgent != nil {
650-
s, err := s.userAgent.Signers()
647+
} else {
648+
var err error
649+
signers, err = s.userAgent.Signers()
651650
if err != nil {
652651
return nil, trace.Wrap(err)
653652
}
654-
signers = s
655653
}
656654
authMethod := ssh.PublicKeysCallback(signersWithSHA1Fallback(signers))
657655

@@ -1140,19 +1138,24 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request,
11401138
}
11411139

11421140
func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error {
1143-
// TODO(capnspacehook): remove once SSH agent forwarding issue is fixed
1144-
if s.userAgent == nil {
1145-
return trace.BadParameter("SSH agent is not set")
1146-
}
1147-
11481141
// Check if the user's RBAC role allows agent forwarding.
11491142
err := s.authHandlers.CheckAgentForward(ctx)
11501143
if err != nil {
11511144
return trace.Wrap(err)
11521145
}
11531146

11541147
// Route authentication requests to the agent that was forwarded to the proxy.
1155-
err = agent.ForwardToAgent(ctx.RemoteClient.Client, s.userAgent)
1148+
// If no agent was forwarded to the proxy, create one now.
1149+
userAgent := s.userAgent
1150+
if userAgent == nil {
1151+
ctx.ConnectionContext.SetForwardAgent(true)
1152+
userAgent, err = ctx.StartAgentChannel()
1153+
if err != nil {
1154+
return trace.Wrap(err)
1155+
}
1156+
}
1157+
1158+
err = agent.ForwardToAgent(ctx.RemoteClient.Client, userAgent)
11561159
if err != nil {
11571160
return trace.Wrap(err)
11581161
}

lib/sshutils/ctx.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,7 @@ func (a *agentChannel) Close() error {
132132
func (c *ConnectionContext) StartAgentChannel() (teleagent.Agent, error) {
133133
// refuse to start an agent if forwardAgent has not yet been set.
134134
if !c.GetForwardAgent() {
135-
// TODO(capnspacehook): update SSH agent in forwarding SSH server
136-
// when connecting to agentless nodes
137-
return nil, trace.AccessDenied("agent forwarding required in proxy recording mode")
135+
return nil, trace.AccessDenied("agent forwarding has not been requested")
138136
}
139137
// open a agent channel to client
140138
ch, _, err := c.ServerConn.OpenChannel(AuthAgentRequest, nil)

0 commit comments

Comments
 (0)