diff --git a/dialer.go b/dialer.go index c5c6d363..967a688d 100644 --- a/dialer.go +++ b/dialer.go @@ -120,7 +120,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { return nil, cfg.err } } - if cfg.useIAMAuthN && cfg.setTokenSource && cfg.iamLoginTokenSource == nil { + if cfg.useIAMAuthN && cfg.setTokenSource && !cfg.setIAMAuthNTokenSource { return nil, errUseIAMTokenSource } if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN { @@ -132,7 +132,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // If callers have not provided a token source, either explicitly with // WithTokenSource or implicitly with WithCredentialsJSON etc, then use the // default token source. - if cfg.iamLoginTokenSource == nil { + if !cfg.setCredentials { ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope) if err != nil { return nil, fmt.Errorf("failed to create token source: %v", err) diff --git a/e2e_postgres_test.go b/e2e_postgres_test.go index a6efcac1..fa26728f 100644 --- a/e2e_postgres_test.go +++ b/e2e_postgres_test.go @@ -29,6 +29,8 @@ import ( "cloud.google.com/go/cloudsqlconn" "github.com/jackc/pgx/v4" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" "cloud.google.com/go/cloudsqlconn/postgres/pgxv4" ) @@ -184,3 +186,105 @@ func TestPostgresHook(t *testing.T) { defer db2.Close() testConn(db2) } + +// removeAuthEnvVar retrieves an OAuth2 token and a path to a service account key +// and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function +// that restores the original setup. +func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) { + ts, err := google.DefaultTokenSource(context.Background(), + "https://www.googleapis.com/auth/cloud-platform", + ) + if err != nil { + t.Errorf("failed to resolve token source: %v", err) + } + tok, err := ts.Token() + if err != nil { + t.Errorf("failed to get token: %v", err) + } + path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS") + if !ok { + t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment") + } + if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil { + t.Fatalf("failed to unset GOOGLE_APPLICATION_CREDENTIALS") + } + return tok, path, func() { + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", path) + } +} + +func keyfile(t *testing.T) string { + path := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") + if path == "" { + t.Fatal("GOOGLE_APPLICATION_CREDENTIALS not set") + } + creds, err := os.ReadFile(path) + if err != nil { + t.Fatalf("io.ReadAll(): %v", err) + } + return string(creds) +} + +func TestPostgresAuthentication(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + + creds := keyfile(t) + tok, path, cleanup := removeAuthEnvVar(t) + defer cleanup() + + tcs := []struct { + desc string + opts []cloudsqlconn.Option + }{ + { + desc: "with token", + opts: []cloudsqlconn.Option{cloudsqlconn.WithTokenSource( + oauth2.StaticTokenSource(tok), + )}, + }, + { + desc: "with credentials file", + opts: []cloudsqlconn.Option{cloudsqlconn.WithCredentialsFile(path)}, + }, + { + desc: "with credentials JSON", + opts: []cloudsqlconn.Option{cloudsqlconn.WithCredentialsJSON([]byte(creds))}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.Background() + + d, err := cloudsqlconn.NewDialer(ctx, tc.opts...) + if err != nil { + t.Fatalf("failed to init Dialer: %v", err) + } + + dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", postgresUser, postgresPass, postgresDB) + config, err := pgx.ParseConfig(dsn) + if err != nil { + t.Fatalf("failed to parse pgx config: %v", err) + } + + config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) { + return d.Dial(ctx, postgresConnName) + } + + conn, connErr := pgx.ConnectConfig(ctx, config) + if connErr != nil { + t.Fatalf("failed to connect: %s", connErr) + } + defer conn.Close(ctx) + + var now time.Time + err = conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&now) + if err != nil { + t.Fatalf("QueryRow failed: %s", err) + } + t.Log(now) + }) + } +} diff --git a/options.go b/options.go index 11160353..7fca8252 100644 --- a/options.go +++ b/options.go @@ -40,10 +40,11 @@ type dialerConfig struct { dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) refreshTimeout time.Duration useIAMAuthN bool - setTokenSource bool - setIAMAuthNTokenSource bool iamLoginTokenSource oauth2.TokenSource useragents []string + setCredentials bool + setTokenSource bool + setIAMAuthNTokenSource bool // err tracks any dialer options that may have failed. err error } @@ -90,6 +91,7 @@ func WithCredentialsJSON(b []byte) Option { return } d.iamLoginTokenSource = scoped.TokenSource + d.setCredentials = true } } @@ -117,6 +119,7 @@ func WithDefaultDialOptions(opts ...DialOption) Option { func WithTokenSource(s oauth2.TokenSource) Option { return func(d *dialerConfig) { d.setTokenSource = true + d.setCredentials = true d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(s)) } } @@ -138,6 +141,7 @@ func WithTokenSource(s oauth2.TokenSource) Option { func WithIAMAuthNTokenSources(apiTS, iamLoginTS oauth2.TokenSource) Option { return func(d *dialerConfig) { d.setIAMAuthNTokenSource = true + d.setCredentials = true d.iamLoginTokenSource = iamLoginTS d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(apiTS)) }