Skip to content

Commit a4f46e4

Browse files
committed
database: only fetch relevant files for conflict check
This should let us avoid dealing with most of the files most of the time. Signed-off-by: Stephan Renatus <[email protected]>
1 parent ffbdaf0 commit a4f46e4

File tree

1 file changed

+165
-31
lines changed

1 file changed

+165
-31
lines changed

internal/database/database.go

Lines changed: 165 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import (
3434
"github.com/open-policy-agent/opa-control-plane/internal/aws"
3535
"github.com/open-policy-agent/opa-control-plane/internal/config"
3636
"github.com/open-policy-agent/opa-control-plane/internal/database/sourcedatafs"
37-
ocp_fs "github.com/open-policy-agent/opa-control-plane/internal/fs"
3837
"github.com/open-policy-agent/opa-control-plane/internal/fs/mountfs"
3938
"github.com/open-policy-agent/opa-control-plane/internal/jsonpatch"
4039
"github.com/open-policy-agent/opa-control-plane/internal/logging"
@@ -404,6 +403,125 @@ WHERE source_name = %s AND path = %s AND (`+conditions+")", d.arg(0), d.arg(1)),
404403
}
405404
}
406405

406+
// checkSourceDataConflict checks whether storing data (pre-serialized as bs)
// at path within source sourceName would conflict with data already stored
// for that source. It fetches only the rows that can possibly conflict,
// rather than every file of the source. data is the decoded payload (used to
// enumerate object keys for prefix checks); principal scopes all lookups via
// partial-evaluated authz conditions, so a principal whose conditions are
// unsatisfiable sees no rows and therefore no conflicts (see caller).
// Returns an error wrapping ErrDataConflict on conflict, or any database or
// authorization error encountered.
func (d *Database) checkSourceDataConflict(ctx context.Context, tx *sql.Tx, sourceName, path string, data any, bs []byte, principal string) error {
	// Partially evaluate the authz policy into SQL conditions restricting
	// which sources_data rows this principal may read.
	expr, err := authz.Partial(ctx, authz.Access{
		Principal:  principal,
		Permission: "sources.data.read",
		Resource:   "sources",
		Name:       sourceName,
	}, nil)
	if err != nil {
		return err
	}

	// args starts with [sourceName, path]; conditions may reference them via
	// d.arg(0)/d.arg(1). offset is where our own placeholders begin.
	conditions, args := expr.SQL(d.arg, []any{sourceName, path})
	offset := len(args)
	prefixPath := filepath.Dir(path)

	// We're only fetching relevant data "blobs", based on their filesystem location:
	// 1. anything clashing with the included object/non-object
	// 2. anything "upwards" the prefixPath.
	//
	// For example, to update corp/users/internal/alice/data.json, we'll fetch
	// 1a. path LIKE corp/users/internal/alice/KEY/% for each KEY in the object (len(keys) > 0)
	// 1b. path LIKE corp/users/internal/alice/% if data.json is not an object
	// --> either non-empty result is a conflict
	// 2. path IN (corp/users/internal/data.json corp/users/data.json corp/data.json)
	// --> these need to be fed into the loader to check

	// deal with (1.), downwards conflicts
	// NB(sr): We don't need to consult the data itself to determine errors here,
	// so we only check existing paths.
	{
		var keys []string
		// NB: d shadows the receiver inside this type switch; only the map
		// value is used in the case body.
		switch d := data.(type) {
		case map[string]any:
			keys = slices.Collect(maps.Keys(d))
			if keys == nil {
				keys = []string{} // we use keys == nil to signal non-object data
			}
		}

		prefixes := make([]any, max(len(keys), 1))
		if keys == nil { // (1b)
			prefixes = []any{prefixPath + "/%"} // any prefix path is a conflict
		} else { // (1a)
			// NB(review): for an empty object (keys == []string{}), prefixes
			// keeps its single nil entry; "path LIKE NULL" matches nothing,
			// so an empty object never conflicts downwards — confirm intended.
			for i := range keys {
				prefixes[i] = prefixPath + "/" + keys[i] + "/%"
			}
		}
		prefixArgs := make([]string, len(prefixes))
		for i := range prefixes {
			prefixArgs[i] = "path LIKE " + d.arg(offset+i)
		}

		values := make([]any, 0, len(prefixes)+len(args)+2)
		// The second bound value is a dummy ("true"): the "%[2]s = %[2]s"
		// tautology below consumes the path placeholder so the numbering in
		// conditions stays aligned, since this query doesn't match on an
		// exact path. NOTE(review): assumes conditions only reference the
		// path via d.arg(1) — verify against authz.Partial's SQL output.
		values = append(values, sourceName, "true")
		values = append(values, args[2:]...) // conditions
		values = append(values, prefixes...)

		// LIMIT 4 so the error below can show up to three conflicting paths
		// plus an ellipsis.
		query := fmt.Sprintf(`SELECT path FROM sources_data
WHERE source_name = %[1]s
AND (%[2]s = %[2]s)
AND (%[3]s)
AND (%[4]s)
ORDER BY path LIMIT 4`,
			d.arg(0), d.arg(1), strings.Join(prefixArgs, " OR "), conditions)

		files, err := queryPaths(ctx, tx, query, values...)
		if err != nil {
			if !errors.Is(err, sql.ErrNoRows) { // no rows, no conflict
				return err
			}
		}

		if len(files) > 0 {
			if len(files) == 4 {
				files[3] = "..." // truncation marker: there may be more conflicts
			}
			return fmt.Errorf("%w: conflict with %v", ErrDataConflict, files)
		}
	}

	// deal with (2.)
	{
		upwardsPaths := upwardsPaths(prefixPath)
		if len(upwardsPaths) == 0 {
			return nil // top-level path: nothing upwards to merge against
		}
		inParams := make([]string, len(upwardsPaths))
		for i := range upwardsPaths {
			inParams[i] = d.arg(i + offset)
		}

		// Fetch ancestor data.json files plus anything under prefixPath:
		// these must be merged with the new payload to detect overlaps.
		query := fmt.Sprintf(`SELECT path FROM sources_data
WHERE source_name = %s AND ((path LIKE %s) OR (path in (%s))) AND (%s)`,
			d.arg(0), d.arg(1), strings.Join(inParams, ","), conditions)
		values := make([]any, 0, 2+len(upwardsPaths)+len(args))
		values = append(values, sourceName, prefixPath+"/%")
		values = append(values, args[2:]...) // conditions
		values = append(values, upwardsPaths...)
		files, err := queryPaths(ctx, tx, query, values...)
		if err != nil {
			return err
		}
		if len(files) == 0 {
			return nil
		}

		// Attempt to load, i.e. merge with existing data. If it fails, don't upsert.
		// fs0 serves the new payload at its target location, fs1 lazily serves
		// the relevant existing rows; the loader flags any merge conflict.
		fs0 := mountfs.New(map[string]fs.FS{filepath.Dir(path): sourcedatafs.NewSingleFS(ctx, func(context.Context) ([]byte, error) { return bs, nil })})
		fs1 := sourcedatafs.New(ctx, files, func(file string) func(context.Context) ([]byte, error) {
			return d.sourceData(tx, sourceName, file)
		})
		fs2 := merged_fs.NewMergedFS(fs0, fs1)
		if _, err := loader.NewFileLoader().WithFS(fs2).All([]string{"."}); err != nil {
			return fmt.Errorf("%w: %w", ErrDataConflict, err)
		}
	}
	return nil
}
524+
407525
func (d *Database) SourcesDataPut(ctx context.Context, sourceName, path string, data any, principal string) error {
408526
path = filepath.ToSlash(path)
409527
return tx1(ctx, d, d.sourcesDataPut(ctx, sourceName, path, data, principal))
@@ -430,28 +548,13 @@ func (d *Database) sourcesDataPut(ctx context.Context, sourceName, path string,
430548
return err
431549
}
432550

433-
// Attempt to load, i.e. merge with existing data. If it fails, don't upsert.
434-
var files []string
435-
type Path struct {
436-
P string `sql:"path"`
437-
}
438-
for row, err := range sqlrange.QueryContext[Path](ctx, tx,
439-
`SELECT path FROM sources_data WHERE source_name = `+d.arg(0),
440-
sourceName) {
441-
if err != nil {
442-
return err
443-
}
444-
files = append(files, row.P)
445-
}
446-
447-
fs0 := mountfs.New(map[string]fs.FS{filepath.Dir(path): sourcedatafs.NewSingleFS(ctx, func(context.Context) ([]byte, error) { return bs, nil })})
448-
fs1 := sourcedatafs.New(ctx, files, func(file string) func(context.Context) ([]byte, error) {
449-
return d.sourceData(tx, sourceName, file)
450-
})
451-
fs2 := merged_fs.NewMergedFS(fs0, fs1)
452-
ocp_fs.Walk(fs2)
453-
if _, err := loader.NewFileLoader().WithFS(fs2).All([]string{"."}); err != nil {
454-
return fmt.Errorf("%w: %w", ErrDataConflict, err)
551+
// NB: We only check for conflicts if the principal has the right to read source data.
552+
// (Otherwise, write access could be abused to guess the data or its layout? Let's
553+
// err on the side of caution.)
554+
// This is done implicitly: If the conditions are not satisfiable, none of the file
555+
// lookups will yield anything.
556+
if err := d.checkSourceDataConflict(ctx, tx, sourceName, path, data, bs, principal); err != nil {
557+
return err
455558
}
456559

457560
return d.upsert(ctx, tx, "sources_data", []string{"source_name", "path", "data"}, []string{"source_name", "path"}, sourceName, path, bs)
@@ -1308,18 +1411,11 @@ func (d *Database) iterSourceFiles(ctx context.Context, dbish sqlrange.Queryable
13081411
sourceName)
13091412
}
13101413

1311-
func (d *Database) iterSourceFilenames(ctx context.Context, dbish sqlrange.Queryable, sourceName string) iter.Seq2[Data, error] {
1312-
return sqlrange.QueryContext[Data](ctx,
1313-
dbish,
1314-
`SELECT path FROM sources_data WHERE source_name = `+d.arg(0),
1315-
sourceName)
1316-
}
1317-
13181414
func (d *Database) sourceData(tx *sql.Tx, sourceName, path string) func(context.Context) ([]byte, error) {
13191415
return func(ctx context.Context) ([]byte, error) {
13201416
var data []byte
13211417
err := tx.QueryRowContext(ctx,
1322-
`SELECT data FROM sources_data WHERE source_name = `+d.arg(0)+`AND path = `+d.arg(1),
1418+
`SELECT data FROM sources_data WHERE source_name = `+d.arg(0)+` AND path = `+d.arg(1),
13231419
sourceName, path,
13241420
).Scan(&data)
13251421
return data, err
@@ -1766,3 +1862,41 @@ func tx3[T any, U bool | string](ctx context.Context, db *Database, f func(*sql.
17661862

17671863
return result, result2, nil
17681864
}
1865+
1866+
// upwardsPaths returns the "data.json" path of every strict ancestor
// directory of basePath, ordered root-first. For example,
// upwardsPaths("corp/users/internal/alice") returns
// ["corp/data.json", "corp/users/data.json", "corp/users/internal/data.json"].
// basePath itself is excluded. The result is typed []any so it can be passed
// directly as SQL query arguments; it is empty (but non-nil) when basePath
// has no ancestors (a single segment such as "." or "top").
func upwardsPaths(basePath string) []any {
	parts := strings.Split(basePath, "/")
	// strings.Split never returns an empty slice, so len(parts)-1 >= 0.
	prefixes := make([]any, 0, len(parts)-1)
	for i := 1; i < len(parts); i++ {
		// Join the first i segments to form the ancestor directory.
		prefixes = append(prefixes, strings.Join(parts[:i], "/")+"/data.json")
	}
	return prefixes
}
1879+
1880+
// queryPaths executes query (whose result set must consist of a single
// string column) against tx with the given arguments and returns all values
// in result-set order. It returns a nil slice (and nil error) when the query
// matches no rows.
func queryPaths(ctx context.Context, tx *sql.Tx, query string, values ...any) ([]string, error) {
	rows, err := tx.QueryContext(ctx, query, values...)
	if err != nil {
		return nil, err
	}
	// Fix: the original both deferred rows.Close() and called it again with
	// an error check, closing the rows twice. The idiomatic pattern is a
	// single deferred Close plus a rows.Err() check after iteration, which
	// surfaces any error that terminated the loop early.
	defer rows.Close()

	var files []string
	for rows.Next() {
		var file string
		if err := rows.Scan(&file); err != nil {
			return nil, err
		}
		files = append(files, file)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return files, nil
}

0 commit comments

Comments
 (0)