Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, cfg.err
}
}
if cfg.useIAMAuthN && cfg.setTokenSource && cfg.iamLoginTokenSource == nil {
if cfg.useIAMAuthN && cfg.setTokenSource && !cfg.setIAMAuthNTokenSource {
return nil, errUseIAMTokenSource
}
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
Expand All @@ -132,7 +132,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc, then use the
// default token source.
if cfg.iamLoginTokenSource == nil {
if !cfg.setCredentials {
ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to create token source: %v", err)
Expand Down
104 changes: 104 additions & 0 deletions e2e_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (

"cloud.google.com/go/cloudsqlconn"
"github.com/jackc/pgx/v4"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"

"cloud.google.com/go/cloudsqlconn/postgres/pgxv4"
)
Expand Down Expand Up @@ -184,3 +186,105 @@ func TestPostgresHook(t *testing.T) {
defer db2.Close()
testConn(db2)
}

// removeAuthEnvVar retrieves an OAuth2 token and a path to a service account key
// and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function
// that restores the original setup.
func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) {
ts, err := google.DefaultTokenSource(context.Background(),
"https://www.googleapis.com/auth/cloud-platform",
)
if err != nil {
t.Errorf("failed to resolve token source: %v", err)
}
tok, err := ts.Token()
if err != nil {
t.Errorf("failed to get token: %v", err)
}
path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS")
if !ok {
t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment")
}
if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil {
t.Fatalf("failed to unset GOOGLE_APPLICATION_CREDENTIALS")
}
return tok, path, func() {
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", path)
}
}

func keyfile(t *testing.T) string {
path := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS")
if path == "" {
t.Fatal("GOOGLE_APPLICATION_CREDENTIALS not set")
}
creds, err := os.ReadFile(path)
if err != nil {
t.Fatalf("io.ReadAll(): %v", err)
}
return string(creds)
}

func TestPostgresAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("skipping Postgres integration tests")
}
requirePostgresVars(t)

creds := keyfile(t)
tok, path, cleanup := removeAuthEnvVar(t)
defer cleanup()

tcs := []struct {
desc string
opts []cloudsqlconn.Option
}{
{
desc: "with token",
opts: []cloudsqlconn.Option{cloudsqlconn.WithTokenSource(
oauth2.StaticTokenSource(tok),
)},
},
{
desc: "with credentials file",
opts: []cloudsqlconn.Option{cloudsqlconn.WithCredentialsFile(path)},
},
{
desc: "with credentials JSON",
opts: []cloudsqlconn.Option{cloudsqlconn.WithCredentialsJSON([]byte(creds))},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background()

d, err := cloudsqlconn.NewDialer(ctx, tc.opts...)
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}

dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", postgresUser, postgresPass, postgresDB)
config, err := pgx.ParseConfig(dsn)
if err != nil {
t.Fatalf("failed to parse pgx config: %v", err)
}

config.DialFunc = func(ctx context.Context, network string, instance string) (net.Conn, error) {
return d.Dial(ctx, postgresConnName)
}

conn, connErr := pgx.ConnectConfig(ctx, config)
if connErr != nil {
t.Fatalf("failed to connect: %s", connErr)
}
defer conn.Close(ctx)

var now time.Time
err = conn.QueryRow(context.Background(), "SELECT NOW()").Scan(&now)
if err != nil {
t.Fatalf("QueryRow failed: %s", err)
}
t.Log(now)
})
}
}
8 changes: 6 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ type dialerConfig struct {
dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
refreshTimeout time.Duration
useIAMAuthN bool
setTokenSource bool
setIAMAuthNTokenSource bool
iamLoginTokenSource oauth2.TokenSource
useragents []string
setCredentials bool
setTokenSource bool
setIAMAuthNTokenSource bool
// err tracks any dialer options that may have failed.
err error
}
Expand Down Expand Up @@ -90,6 +91,7 @@ func WithCredentialsJSON(b []byte) Option {
return
}
d.iamLoginTokenSource = scoped.TokenSource
d.setCredentials = true
}
}

Expand Down Expand Up @@ -117,6 +119,7 @@ func WithDefaultDialOptions(opts ...DialOption) Option {
func WithTokenSource(s oauth2.TokenSource) Option {
return func(d *dialerConfig) {
d.setTokenSource = true
d.setCredentials = true
d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(s))
}
}
Expand All @@ -138,6 +141,7 @@ func WithTokenSource(s oauth2.TokenSource) Option {
func WithIAMAuthNTokenSources(apiTS, iamLoginTS oauth2.TokenSource) Option {
return func(d *dialerConfig) {
d.setIAMAuthNTokenSource = true
d.setCredentials = true
d.iamLoginTokenSource = iamLoginTS
d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(apiTS))
}
Expand Down