Skip to content

Commit 9031f13

Browse files
committed
identityfederation: replace the HTTP ctor with additional client fields
Signed-off-by: mcoulombe <[email protected]>
1 parent 2645aaa commit 9031f13

File tree

6 files changed

+347
-166
lines changed

6 files changed

+347
-166
lines changed

README.md

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,12 @@ import (
4242

4343
func main() {
4444
client := &tailscale.Client{
45-
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
46-
HTTP: tailscale.OAuthConfig{
47-
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
48-
ClientSecret: os.Getenv("TAILSCALE_OAUTH_CLIENT_SECRET"),
49-
Scopes: []string{"all:write"},
50-
}.HTTPClient(),
45+
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
46+
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
47+
ClientSecret: os.Getenv("TAILSCALE_OAUTH_CLIENT_SECRET"),
48+
Scopes: []string{"all:write"},
5149
}
52-
50+
5351
devices, err := client.Devices().List(context.Background())
5452
}
5553
```
@@ -66,26 +64,18 @@ package main
6664

6765
import (
6866
"context"
69-
"log"
7067
"os"
7168

7269
"tailscale.com/client/tailscale/v2"
7370
)
7471

7572
func main() {
76-
httpClient, err := tailscale.IdentityFederationConfig{
73+
client := &tailscale.Client{
74+
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
7775
ClientID: os.Getenv("TAILSCALE_CLIENT_ID"),
7876
IDTokenFunc: func() (string, error) {
7977
return os.Getenv("ID_TOKEN"), nil
8078
},
81-
}.HTTPClient()
82-
if err != nil {
83-
log.Fatal(err)
84-
}
85-
86-
client := &tailscale.Client{
87-
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
88-
HTTP: httpClient,
8979
}
9080

9181
devices, err := client.Devices().List(context.Background())
@@ -104,7 +94,6 @@ import (
10494
"context"
10595
"fmt"
10696
"io"
107-
"log"
10897
"net/http"
10998
"os"
11099
"strings"
@@ -113,7 +102,8 @@ import (
113102
)
114103

115104
func main() {
116-
httpClient, err := tailscale.IdentityFederationConfig{
105+
client := &tailscale.Client{
106+
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
117107
ClientID: os.Getenv("TAILSCALE_CLIENT_ID"),
118108
IDTokenFunc: func() (string, error) {
119109
resp, err := http.Get("https://my-idp.com/id-token")
@@ -133,14 +123,6 @@ func main() {
133123

134124
return strings.TrimSpace(string(body)), nil
135125
},
136-
}.HTTPClient()
137-
if err != nil {
138-
log.Fatal(err)
139-
}
140-
141-
client := &tailscale.Client{
142-
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
143-
HTTP: httpClient,
144126
}
145127

146128
devices, err := client.Devices().List(context.Background())

client.go

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"io"
1616
"net/http"
1717
"net/url"
18+
"strings"
1819
"sync"
1920
"time"
2021

@@ -27,9 +28,6 @@ type Client struct {
2728
BaseURL *url.URL
2829
// UserAgent configures the User-Agent HTTP header for requests. Defaults to "tailscale-client-go".
2930
UserAgent string
30-
// APIKey allows specifying an APIKey to use for authentication.
31-
// To use OAuth Client credentials, construct an [http.Client] using [OAuthConfig] and specify that below.
32-
APIKey string
3331
// Tailnet allows specifying a specific tailnet by name, to which this Client will connect by default.
3432
// If Tailnet is left blank, the client will connect to default tailnet based on the client's credential,
3533
// using the "-" (dash) default tailnet path.
@@ -39,6 +37,29 @@ type Client struct {
3937
// If not specified, a new [http.Client] with a Timeout of 1 minute will be used.
4038
HTTP *http.Client
4139

40+
// APIKey allows specifying an APIKey to use for authentication.
41+
// To use OAuth Client credentials, set ClientID and ClientSecret instead.
42+
// To use identity federation, set ClientID and IDTokenFunc instead.
43+
APIKey string
44+
// ClientID is the ID of the Tailscale OAuth client.
45+
// When set along with ClientSecret, the client will use OAuth client credentials for authentication.
46+
// When set along with IDTokenFunc, the client will use identity federation for authentication.
47+
// For OAuth, if empty and ClientSecret is provided, the ClientID will be derived from the ClientSecret.
48+
// Overwrites APIKey if set.
49+
ClientID string
50+
// ClientSecret is the client secret of the OAuth client.
51+
// When set, the client will use OAuth client credentials for authentication.
52+
ClientSecret string
53+
// IDTokenFunc returns an identity token from the IdP to exchange for a Tailscale API token.
54+
// The client calls this function to obtain a fresh ID token and reauthenticate when the API token
55+
// and cached ID token have expired. For static tokens, return the token directly. If a static token
56+
// expires, the client cannot automatically refresh the API token; the consumer is responsible to create a new client
57+
// with a fresh ID token.
58+
IDTokenFunc func() (string, error)
59+
// Scopes are the scopes to request when generating tokens for the OAuth client.
60+
// Only used when ClientSecret is set.
61+
Scopes []string
62+
4263
initOnce sync.Once
4364

4465
// Specific resources
@@ -71,6 +92,9 @@ const defaultContentType = "application/json"
7192
const defaultHttpClientTimeout = time.Minute
7293
const defaultUserAgent = "tailscale-client-go"
7394

95+
// defaultClientID is a fallback if we failed to derive the real ClientID from the ClientSecret
96+
const defaultClientID = "k1234DEVEL"
97+
7498
var defaultBaseURL *url.URL
7599

76100
func init() {
@@ -97,6 +121,43 @@ func (c *Client) init() {
97121
if c.Tailnet == "" {
98122
c.Tailnet = "-"
99123
}
124+
125+
var underlyingTransport http.RoundTripper
126+
if c.HTTP.Transport != nil {
127+
underlyingTransport = c.HTTP.Transport
128+
} else {
129+
underlyingTransport = http.DefaultTransport
130+
}
131+
132+
if c.ClientID != "" && c.IDTokenFunc != nil {
133+
transport := &tokenTransport{
134+
transport: underlyingTransport,
135+
baseURL: c.BaseURL.String(),
136+
clientID: c.ClientID,
137+
idTokenFunc: c.IDTokenFunc,
138+
}
139+
140+
// Wrap the HTTP client's transport with the federated-identity-aware transport
141+
c.HTTP.Transport = transport
142+
} else if c.ClientSecret != "" {
143+
// Derive ClientID from ClientSecret if not provided
144+
clientID := c.ClientID
145+
if clientID == "" {
146+
clientID = deriveClientID(c.ClientSecret)
147+
}
148+
149+
transport := &oauthTransport{
150+
transport: underlyingTransport,
151+
baseURL: c.BaseURL.String(),
152+
clientID: clientID,
153+
clientSecret: c.ClientSecret,
154+
scopes: c.Scopes,
155+
}
156+
157+
// Wrap the HTTP client's transport with the OAuth-aware transport
158+
c.HTTP.Transport = transport
159+
}
160+
100161
c.contacts = &ContactsResource{c}
101162
c.devicePosture = &DevicePostureResource{c}
102163
c.devices = &DevicesResource{c}
@@ -379,3 +440,15 @@ func ErrorData(err error) []APIErrorData {
379440
func PointerTo[T any](value T) *T {
380441
return &value
381442
}
443+
444+
// deriveClientID extracts the ClientID from a ClientSecret.
445+
// It expects the ClientSecret to be in the format "tskey-client-{clientID}-{suffix}".
446+
// If the derived ClientID ends up equal to the ClientSecret because the value does not have the expecyed shape,
447+
// it returns a dummy defaultClientID to prevent logging the ClientSecret value by logging the ClientID by mistake.
448+
func deriveClientID(clientSecret string) string {
449+
clientID, _, _ := strings.Cut(strings.TrimPrefix(clientSecret, "tskey-client-"), "-")
450+
if clientID == clientSecret {
451+
return defaultClientID
452+
}
453+
return clientID
454+
}

client_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package tailscale
66
import (
77
_ "embed"
88
"io"
9+
"net/http"
910
"net/url"
1011
"testing"
1112

@@ -69,6 +70,118 @@ func Test_BuildTailnetURLDefault(t *testing.T) {
6970
assert.EqualValues(t, expected.String(), actual.String())
7071
}
7172

73+
func Test_ClientAuthentication(t *testing.T) {
74+
t.Parallel()
75+
76+
t.Run("OAuth transport is set when ClientSecret are provided", func(t *testing.T) {
77+
c := &Client{
78+
ClientSecret: "tskey-client-abc123-xyz789",
79+
Scopes: []string{"all:read"},
80+
}
81+
c.init()
82+
83+
assert.NotNil(t, c.HTTP)
84+
assert.NotNil(t, c.HTTP.Transport)
85+
transport, ok := c.HTTP.Transport.(*oauthTransport)
86+
assert.True(t, ok, "expected transport to be *oauthTransport")
87+
assert.Equal(t, "abc123", transport.clientID, "clientID should be derived from ClientSecret")
88+
})
89+
90+
t.Run("Identity federation transport takes precedence when ClientID and IDTokenFunc are set", func(t *testing.T) {
91+
c := &Client{
92+
ClientID: "test-client-id",
93+
IDTokenFunc: func() (string, error) {
94+
return "test-token", nil
95+
},
96+
}
97+
c.init()
98+
99+
assert.NotNil(t, c.HTTP)
100+
assert.NotNil(t, c.HTTP.Transport)
101+
_, ok := c.HTTP.Transport.(*tokenTransport)
102+
assert.True(t, ok, "expected transport to be *tokenTransport")
103+
})
104+
105+
t.Run("OAuth wraps custom transport preserving proxy settings", func(t *testing.T) {
106+
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
107+
customTransport := &http.Transport{
108+
Proxy: http.ProxyURL(proxyURL),
109+
}
110+
c := &Client{
111+
ClientSecret: "tskey-client-abc123-xyz789",
112+
HTTP: &http.Client{
113+
Transport: customTransport,
114+
},
115+
}
116+
c.init()
117+
118+
assert.NotNil(t, c.HTTP)
119+
assert.NotNil(t, c.HTTP.Transport)
120+
oauthTransport, ok := c.HTTP.Transport.(*oauthTransport)
121+
assert.True(t, ok, "expected transport to be *oauthTransport")
122+
123+
// Verify the custom transport with proxy is wrapped
124+
wrappedTransport, ok := oauthTransport.transport.(*http.Transport)
125+
assert.True(t, ok, "underlying transport should be *http.Transport")
126+
assert.NotNil(t, wrappedTransport.Proxy, "proxy setting should be preserved")
127+
})
128+
129+
t.Run("Identity federation wraps custom transport preserving proxy settings", func(t *testing.T) {
130+
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
131+
customTransport := &http.Transport{
132+
Proxy: http.ProxyURL(proxyURL),
133+
}
134+
c := &Client{
135+
ClientID: "test-client-id",
136+
IDTokenFunc: func() (string, error) {
137+
return "test-token", nil
138+
},
139+
HTTP: &http.Client{
140+
Transport: customTransport,
141+
},
142+
}
143+
c.init()
144+
145+
assert.NotNil(t, c.HTTP)
146+
assert.NotNil(t, c.HTTP.Transport)
147+
tokenTransport, ok := c.HTTP.Transport.(*tokenTransport)
148+
assert.True(t, ok, "expected transport to be *tokenTransport")
149+
150+
// Verify the custom transport with proxy is wrapped
151+
wrappedTransport, ok := tokenTransport.transport.(*http.Transport)
152+
assert.True(t, ok, "underlying transport should be *http.Transport")
153+
assert.NotNil(t, wrappedTransport.Proxy, "proxy setting should be preserved")
154+
})
155+
}
156+
157+
func Test_DeriveClientID(t *testing.T) {
158+
t.Parallel()
159+
160+
tests := []struct {
161+
name string
162+
clientSecret string
163+
want string
164+
}{
165+
{
166+
name: "Valid client secret with standard format",
167+
clientSecret: "tskey-client-abc123-xyz789",
168+
want: "abc123",
169+
},
170+
{
171+
name: "Client secret with unexpected shape",
172+
clientSecret: "plaintext",
173+
want: defaultClientID,
174+
},
175+
}
176+
177+
for _, tt := range tests {
178+
t.Run(tt.name, func(t *testing.T) {
179+
got := deriveClientID(tt.clientSecret)
180+
assert.Equal(t, tt.want, got)
181+
})
182+
}
183+
}
184+
72185
func ptrTo[T any](v T) *T {
73186
return &v
74187
}

0 commit comments

Comments
 (0)