|
4 | 4 | package azuread |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "context" |
| 8 | + "errors" |
| 9 | + "io/fs" |
| 10 | + "net/url" |
| 11 | + "os" |
7 | 12 | "reflect" |
| 13 | + "strings" |
8 | 14 | "testing" |
9 | 15 |
|
10 | 16 | mssql "github.com/microsoft/go-mssqldb" |
@@ -137,3 +143,73 @@ func TestValidateParameters(t *testing.T) { |
137 | 143 | } |
138 | 144 | } |
139 | 145 | } |
| 146 | + |
| 147 | +func TestProvideActiveDirectoryTokenValidations(t *testing.T) { |
| 148 | + nonExistentCertPath := os.TempDir() + "non_existent_cert.pem" |
| 149 | + |
| 150 | + f, err := os.CreateTemp("", "malformed_cert.pem") |
| 151 | + if err != nil { |
| 152 | + t.Fatalf("create temporary file: %v", err) |
| 153 | + } |
| 154 | + if err = f.Truncate(0); err != nil { |
| 155 | + t.Fatalf("truncate temporary file: %v", err) |
| 156 | + } |
| 157 | + if _, err = f.Write([]byte("malformed")); err != nil { |
| 158 | + t.Fatalf("write to temporary file: %v", err) |
| 159 | + } |
| 160 | + if err = f.Close(); err != nil { |
| 161 | + t.Fatalf("close temporary file: %v", err) |
| 162 | + } |
| 163 | + malformedCertPath := f.Name() |
| 164 | + t.Cleanup(func() { _ = os.Remove(malformedCertPath) }) |
| 165 | + |
| 166 | + tests := []struct { |
| 167 | + name string |
| 168 | + dsn string |
| 169 | + expectedErr error |
| 170 | + expectedErrContains string |
| 171 | + }{ |
| 172 | + { |
| 173 | + name: "ActiveDirectoryServicePrincipal_cert_not_found", |
| 174 | + dsn: `sqlserver://someserver.database.windows.net?` + |
| 175 | + `user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" + |
| 176 | + `fedauth=ActiveDirectoryServicePrincipal` + "&" + |
| 177 | + `clientcertpath=` + nonExistentCertPath + "&" + |
| 178 | + `applicationclientid=someguid`, |
| 179 | + expectedErr: fs.ErrNotExist, |
| 180 | + }, |
| 181 | + { |
| 182 | + name: "ActiveDirectoryServicePrincipal_cert_malformed", |
| 183 | + dsn: `sqlserver://someserver.database.windows.net?` + |
| 184 | + `user id=` + url.QueryEscape("my-app-id@my-tenant-id") + "&" + |
| 185 | + `fedauth=ActiveDirectoryServicePrincipal` + "&" + |
| 186 | + `clientcertpath=` + malformedCertPath + "&" + |
| 187 | + `applicationclientid=someguid`, |
| 188 | + expectedErrContains: "error reading P12 data", |
| 189 | + }, |
| 190 | + } |
| 191 | + for _, tst := range tests { |
| 192 | + t.Run(tst.name, func(t *testing.T) { |
| 193 | + config, err := parse(tst.dsn) |
| 194 | + if err != nil { |
| 195 | + t.Errorf("Unexpected parse error: %v", err) |
| 196 | + return |
| 197 | + } |
| 198 | + _, err = config.provideActiveDirectoryToken(context.Background(), "", "authority/tenant") |
| 199 | + if err == nil { |
| 200 | + t.Errorf("Expected error but got nil") |
| 201 | + return |
| 202 | + } |
| 203 | + if tst.expectedErr != nil { |
| 204 | + if !errors.Is(err, tst.expectedErr) { |
| 205 | + t.Errorf("Expected error '%v' but got err = %v", tst.expectedErr, err) |
| 206 | + } |
| 207 | + } |
| 208 | + if tst.expectedErrContains != "" { |
| 209 | + if !strings.Contains(err.Error(), tst.expectedErrContains) { |
| 210 | + return |
| 211 | + } |
| 212 | + } |
| 213 | + }) |
| 214 | + } |
| 215 | +} |
0 commit comments