Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 81 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,91 @@ import (
"tailscale.com/client/tailscale/v2"
)

func main() {
client := &tailscale.Client{
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
ClientSecret: os.Getenv("TAILSCALE_OAUTH_CLIENT_SECRET"),
Scopes: []string{"all:write"},
}

devices, err := client.Devices().List(context.Background())
}
```

## Example (Using Identity Federation)

### With a static ID token:

For static ID tokens, simply return the same token value each time. Note that if both the Tailscale API access token
and the ID token expire, the client must be recreated with a fresh ID token to reauthenticate.

```go
package main

import (
"context"
"os"

"tailscale.com/client/tailscale/v2"
)

func main() {
client := &tailscale.Client{
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
HTTP: tailscale.OAuthConfig{
ClientID: os.Getenv("TAILSCALE_OAUTH_CLIENT_ID"),
ClientSecret: os.Getenv("TAILSCALE_OAUTH_CLIENT_SECRET"),
Scopes: []string{"all:write"},
}.HTTPClient(),
ClientID: os.Getenv("TAILSCALE_CLIENT_ID"),
IDTokenFunc: func() (string, error) {
return os.Getenv("ID_TOKEN"), nil
},
}


devices, err := client.Devices().List(context.Background())
}
```

### With a dynamic ID token generator

For long-running applications, instruct the client on how to fetch ID tokens from your IdP so the client can reauthenticate
automatically when the Tailscale API access token and ID token expire:

```go
package main

import (
"context"
"fmt"
"io"
"net/http"
"os"
"strings"

"tailscale.com/client/tailscale/v2"
)

func main() {
client := &tailscale.Client{
Tailnet: os.Getenv("TAILSCALE_TAILNET"),
ClientID: os.Getenv("TAILSCALE_CLIENT_ID"),
IDTokenFunc: func() (string, error) {
resp, err := http.Get("https://my-idp.com/id-token")
if err != nil {
return "", fmt.Errorf("failed to fetch ID token: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to fetch ID token: status %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read ID token response: %w", err)
}

return strings.TrimSpace(string(body)), nil
},
}

devices, err := client.Devices().List(context.Background())
}
```
Expand Down
71 changes: 68 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"

Expand All @@ -27,9 +28,6 @@ type Client struct {
BaseURL *url.URL
// UserAgent configures the User-Agent HTTP header for requests. Defaults to "tailscale-client-go".
UserAgent string
// APIKey allows specifying an APIKey to use for authentication.
// To use OAuth Client credentials, construct an [http.Client] using [OAuthConfig] and specify that below.
APIKey string
// Tailnet allows specifying a specific tailnet by name, to which this Client will connect by default.
// If Tailnet is left blank, the client will connect to default tailnet based on the client's credential,
// using the "-" (dash) default tailnet path.
Expand All @@ -39,6 +37,29 @@ type Client struct {
// If not specified, a new [http.Client] with a Timeout of 1 minute will be used.
HTTP *http.Client

// APIKey allows specifying an APIKey to use for authentication.
// To use OAuth Client credentials, set ClientID and ClientSecret instead.
// To use identity federation, set ClientID and IDTokenFunc instead.
APIKey string
// ClientID is the ID of the Tailscale OAuth client.
// When set along with ClientSecret, the client will use OAuth client credentials for authentication.
// When set along with IDTokenFunc, the client will use identity federation for authentication.
// For OAuth, if empty and ClientSecret is provided, the ClientID will be derived from the ClientSecret.
// Overwrites APIKey if set.
ClientID string
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding more and more parameters to the Client type, let's just make different pluggable authentication mechanisms and plug them into a single place like here.

// ClientSecret is the client secret of the OAuth client.
// When set, the client will use OAuth client credentials for authentication.
ClientSecret string
// IDTokenFunc returns an identity token from the IdP to exchange for a Tailscale API token.
// The client calls this function to obtain a fresh ID token and reauthenticate when the API token
// and cached ID token have expired. For static tokens, return the token directly. If a static token
// expires, the client cannot automatically refresh the API token; the consumer is responsible to create a new client
// with a fresh ID token.
IDTokenFunc func() (string, error)
// Scopes are the scopes to request when generating tokens for the OAuth client.
// Only used when ClientSecret is set.
Scopes []string

initOnce sync.Once

// Specific resources
Expand Down Expand Up @@ -71,6 +92,9 @@ const defaultContentType = "application/json"
const defaultHttpClientTimeout = time.Minute
const defaultUserAgent = "tailscale-client-go"

// defaultClientID is a fallback if we failed to derive the real ClientID from the ClientSecret
const defaultClientID = "k1234DEVEL"

var defaultBaseURL *url.URL

func init() {
Expand All @@ -97,6 +121,35 @@ func (c *Client) init() {
if c.Tailnet == "" {
c.Tailnet = "-"
}

var underlyingTransport http.RoundTripper
if c.HTTP.Transport != nil {
underlyingTransport = c.HTTP.Transport
} else {
underlyingTransport = http.DefaultTransport
}

if c.ClientID != "" && c.IDTokenFunc != nil {
c.HTTP.Transport = newIdentityFederationTransport(
underlyingTransport,
c.BaseURL.String(),
c.ClientID,
c.IDTokenFunc,
)
} else if c.ClientSecret != "" {
if c.ClientID == "" {
c.ClientID = deriveClientID(c.ClientSecret)
}

c.HTTP.Transport = newOAuthTransport(
underlyingTransport,
c.BaseURL.String(),
c.ClientID,
c.ClientSecret,
c.Scopes,
)
}

c.contacts = &ContactsResource{c}
c.devicePosture = &DevicePostureResource{c}
c.devices = &DevicesResource{c}
Expand Down Expand Up @@ -379,3 +432,15 @@ func ErrorData(err error) []APIErrorData {
func PointerTo[T any](value T) *T {
return &value
}

// deriveClientID extracts the ClientID from a ClientSecret.
// It expects the ClientSecret to be in the format "tskey-client-{clientID}-{suffix}".
// If the derived ClientID ends up equal to the ClientSecret because the value does not have the expecyed shape,
// it returns a dummy defaultClientID to prevent logging the ClientSecret value by logging the ClientID by mistake.
func deriveClientID(clientSecret string) string {
clientID, _, _ := strings.Cut(strings.TrimPrefix(clientSecret, "tskey-client-"), "-")
if clientID == clientSecret {
return defaultClientID
}
return clientID
}
113 changes: 113 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package tailscale
import (
_ "embed"
"io"
"net/http"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestErrorData(t *testing.T) {
Expand Down Expand Up @@ -69,6 +71,117 @@ func Test_BuildTailnetURLDefault(t *testing.T) {
assert.EqualValues(t, expected.String(), actual.String())
}

func Test_ClientAuthentication(t *testing.T) {
t.Parallel()

t.Run("OAuth transport is set when ClientSecret are provided", func(t *testing.T) {
c := &Client{
ClientSecret: "tskey-client-abc123-xyz789",
Scopes: []string{"all:read"},
}
c.init()

assert.NotNil(t, c.HTTP)
assert.NotNil(t, c.HTTP.Transport)
_, ok := c.HTTP.Transport.(*oauth2.Transport)
assert.True(t, ok, "expected transport to be *oauth2.Transport")
})

t.Run("OAuth transport is set when Federated Identity config is provided", func(t *testing.T) {
c := &Client{
ClientID: "test-client-id",
IDTokenFunc: func() (string, error) {
return "test-token", nil
},
}
c.init()

assert.NotNil(t, c.HTTP)
assert.NotNil(t, c.HTTP.Transport)
_, ok := c.HTTP.Transport.(*oauth2.Transport)
assert.True(t, ok, "expected transport to be *oauth2.Transport")
})

t.Run("OAuth wraps custom transport preserving proxy settings", func(t *testing.T) {
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
customTransport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
c := &Client{
ClientSecret: "tskey-client-abc123-xyz789",
HTTP: &http.Client{
Transport: customTransport,
},
}
c.init()

assert.NotNil(t, c.HTTP)
assert.NotNil(t, c.HTTP.Transport)
oauthTransport, ok := c.HTTP.Transport.(*oauth2.Transport)
assert.True(t, ok, "expected transport to be *oauth2.Transport")

// Verify the custom transport with proxy is wrapped
wrappedTransport, ok := oauthTransport.Base.(*http.Transport)
assert.True(t, ok, "underlying transport should be *http.Transport")
assert.NotNil(t, wrappedTransport.Proxy, "proxy setting should be preserved")
})

t.Run("Identity federation wraps custom transport preserving proxy settings", func(t *testing.T) {
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
customTransport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
c := &Client{
ClientID: "test-client-id",
IDTokenFunc: func() (string, error) {
return "test-token", nil
},
HTTP: &http.Client{
Transport: customTransport,
},
}
c.init()

assert.NotNil(t, c.HTTP)
assert.NotNil(t, c.HTTP.Transport)
tokenTransport, ok := c.HTTP.Transport.(*oauth2.Transport)
assert.True(t, ok, "expected transport to be *oauth2.Transport")

// Verify the custom transport with proxy is wrapped
wrappedTransport, ok := tokenTransport.Base.(*http.Transport)
assert.True(t, ok, "underlying transport should be *http.Transport")
assert.NotNil(t, wrappedTransport.Proxy, "proxy setting should be preserved")
})
}

func Test_DeriveClientID(t *testing.T) {
t.Parallel()

tests := []struct {
name string
clientSecret string
want string
}{
{
name: "Valid client secret with standard format",
clientSecret: "tskey-client-abc123-xyz789",
want: "abc123",
},
{
name: "Client secret with unexpected shape",
clientSecret: "plaintext",
want: defaultClientID,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := deriveClientID(tt.clientSecret)
assert.Equal(t, tt.want, got)
})
}
}

func ptrTo[T any](v T) *T {
return &v
}
Loading