Skip to content

Commit 4b95a0f

Browse files
authored
Fix error checks during certificatePath reading and parsing in azuread (#227)
1 parent 573423d commit 4b95a0f

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

azuread/configuration.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,11 @@ func (p *azureFedAuthConfig) provideActiveDirectoryToken(ctx context.Context, se
184184
case p.certificatePath != "":
185185
var certData []byte
186186
certData, err = os.ReadFile(p.certificatePath)
187-
if err != nil {
187+
if err == nil {
188188
var certs []*x509.Certificate
189189
var key crypto.PrivateKey
190190
certs, key, err = azidentity.ParseCertificates(certData, []byte(p.clientSecret))
191-
if err != nil {
191+
if err == nil {
192192
cred, err = azidentity.NewClientCertificateCredential(tenant, p.clientID, certs, key, nil)
193193
}
194194
}

azuread/configuration_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
package azuread
55

66
import (
7+
"context"
8+
"errors"
9+
"io/fs"
10+
"net/url"
11+
"os"
712
"reflect"
13+
"strings"
814
"testing"
915

1016
mssql "github.com/microsoft/go-mssqldb"
@@ -137,3 +143,73 @@ func TestValidateParameters(t *testing.T) {
137143
}
138144
}
139145
}
146+
147+
func TestProvideActiveDirectoryTokenValidations(t *testing.T) {
148+
nonExistentCertPath := os.TempDir() + "non_existent_cert.pem"
149+
150+
f, err := os.CreateTemp("", "malformed_cert.pem")
151+
if err != nil {
152+
t.Fatalf("create temporary file: %v", err)
153+
}
154+
if err = f.Truncate(0); err != nil {
155+
t.Fatalf("truncate temporary file: %v", err)
156+
}
157+
if _, err = f.Write([]byte("malformed")); err != nil {
158+
t.Fatalf("write to temporary file: %v", err)
159+
}
160+
if err = f.Close(); err != nil {
161+
t.Fatalf("close temporary file: %v", err)
162+
}
163+
malformedCertPath := f.Name()
164+
t.Cleanup(func() { _ = os.Remove(malformedCertPath) })
165+
166+
tests := []struct {
167+
name string
168+
dsn string
169+
expectedErr error
170+
expectedErrContains string
171+
}{
172+
{
173+
name: "ActiveDirectoryServicePrincipal_cert_not_found",
174+
dsn: `sqlserver://someserver.database.windows.net?` +
175+
`user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" +
176+
`fedauth=ActiveDirectoryServicePrincipal` + "&" +
177+
`clientcertpath=` + nonExistentCertPath + "&" +
178+
`applicationclientid=someguid`,
179+
expectedErr: fs.ErrNotExist,
180+
},
181+
{
182+
name: "ActiveDirectoryServicePrincipal_cert_malformed",
183+
dsn: `sqlserver://someserver.database.windows.net?` +
184+
`user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" +
185+
`fedauth=ActiveDirectoryServicePrincipal` + "&" +
186+
`clientcertpath=` + malformedCertPath + "&" +
187+
`applicationclientid=someguid`,
188+
expectedErrContains: "error reading P12 data",
189+
},
190+
}
191+
for _, tst := range tests {
192+
t.Run(tst.name, func(t *testing.T) {
193+
config, err := parse(tst.dsn)
194+
if err != nil {
195+
t.Errorf("Unexpected parse error: %v", err)
196+
return
197+
}
198+
_, err = config.provideActiveDirectoryToken(context.Background(), "", "authority/tenant")
199+
if err == nil {
200+
t.Errorf("Expected error but got nil")
201+
return
202+
}
203+
if tst.expectedErr != nil {
204+
if !errors.Is(err, tst.expectedErr) {
205+
t.Errorf("Expected error '%v' but got err = %v", tst.expectedErr, err)
206+
}
207+
}
208+
if tst.expectedErrContains != "" {
209+
if !strings.Contains(err.Error(), tst.expectedErrContains) {
210+
return
211+
}
212+
}
213+
})
214+
}
215+
}

0 commit comments

Comments
 (0)