@@ -34,7 +34,6 @@ import (
3434 "github.com/open-policy-agent/opa-control-plane/internal/aws"
3535 "github.com/open-policy-agent/opa-control-plane/internal/config"
3636 "github.com/open-policy-agent/opa-control-plane/internal/database/sourcedatafs"
37- ocp_fs "github.com/open-policy-agent/opa-control-plane/internal/fs"
3837 "github.com/open-policy-agent/opa-control-plane/internal/fs/mountfs"
3938 "github.com/open-policy-agent/opa-control-plane/internal/jsonpatch"
4039 "github.com/open-policy-agent/opa-control-plane/internal/logging"
@@ -404,6 +403,125 @@ WHERE source_name = %s AND path = %s AND (`+conditions+")", d.arg(0), d.arg(1)),
404403 }
405404}
406405
// checkSourceDataConflict verifies that writing `data` (already serialized as
// `bs`) under sourceName/path does not collide with data rows already stored
// for that source. It returns an error wrapping ErrDataConflict on a clash,
// nil otherwise. Conflict visibility is limited to rows the principal may
// read: the authz partial-eval conditions are folded into every query below,
// so unreadable rows can neither be probed nor reported.
func (d *Database) checkSourceDataConflict(ctx context.Context, tx *sql.Tx, sourceName, path string, data any, bs []byte, principal string) error {
	// Partially evaluate the read permission into SQL conditions for this source.
	expr, err := authz.Partial(ctx, authz.Access{
		Principal:  principal,
		Permission: "sources.data.read",
		Resource:   "sources",
		Name:       sourceName,
	}, nil)
	if err != nil {
		return err
	}

	// args carries {sourceName, path} plus whatever the partial-eval conditions
	// bind; placeholder indices below are computed relative to len(args).
	// NOTE(review): the placeholder layout (conditions referenced after the
	// LIKE clauses in query 1) only works with *indexed* placeholders
	// ($1/?1-style) from d.arg, not bare positional `?` — presumably that is
	// what d.arg emits; confirm against the driver config.
	conditions, args := expr.SQL(d.arg, []any{sourceName, path})
	offset := len(args)
	prefixPath := filepath.Dir(path)

	// We're only fetching relevant data "blobs", based on their filesystem location:
	// 1. anything clashing with the included object/non-object
	// 2. anything "upwards" the prefixPath.
	//
	// For example, to update corp/users/internal/alice/data.json, we'll fetch
	// 1a. path LIKE corp/users/internal/alice/KEY/% for each KEY in the object (len(keys) > 0)
	// 1b. path LIKE corp/users/internal/alice/% if data.json is not an object
	// --> either non-empty result is a conflict
	// 2. path IN (corp/users/internal/data.json corp/users/data.json corp/data.json)
	// --> these need to be fed into the loader to check

	// deal with (1.), downwards conflicts
	// NB(sr): We don't need to consult the data itself to determine errors here,
	// so we only check existing paths.
	{
		var keys []string
		// NB: this `d` shadows the receiver for the duration of the switch only.
		switch d := data.(type) {
		case map[string]any:
			keys = slices.Collect(maps.Keys(d))
			if keys == nil {
				keys = []string{} // we use keys == nil to signal non-object data
			}
		}

		// NOTE(review): for an *empty* object (keys == []string{}), the else
		// branch below leaves the single pre-allocated slot of `prefixes` nil,
		// producing `path LIKE NULL` — which matches nothing. That is the
		// semantically correct outcome (an empty object has no keys to clash),
		// but it happens by accident; verify against the SQL backends in use.
		prefixes := make([]any, max(len(keys), 1))
		if keys == nil { // (1b)
			prefixes = []any{prefixPath + "/%"} // any prefix path is a conflict
		} else { // (1a)
			for i := range keys {
				prefixes[i] = prefixPath + "/" + keys[i] + "/%"
			}
		}
		// One `path LIKE $n` term per prefix; indices start after args.
		prefixArgs := make([]string, len(prefixes))
		for i := range prefixes {
			prefixArgs[i] = "path LIKE " + d.arg(offset+i)
		}

		// Value order mirrors placeholder indices: arg(0)=sourceName,
		// arg(1)="true" (consumed by the tautology below), then the
		// partial-eval condition args, then the LIKE prefixes.
		values := make([]any, 0, len(prefixes)+len(args)+2)
		values = append(values, sourceName, "true")
		values = append(values, args[2:]...) // conditions
		values = append(values, prefixes...)

		// LIMIT 4 caps the work; a fourth row is only used as an ellipsis
		// marker in the error message below.
		query := fmt.Sprintf(`SELECT path FROM sources_data
WHERE source_name = %[1]s
AND (%[2]s = %[2]s)
AND (%[3]s)
AND (%[4]s)
ORDER BY path LIMIT 4`,
			d.arg(0), d.arg(1), strings.Join(prefixArgs, " OR "), conditions)

		files, err := queryPaths(ctx, tx, query, values...)
		if err != nil {
			// NOTE(review): queryPaths iterates sql.Rows and should never
			// yield sql.ErrNoRows (that sentinel comes from QueryRow.Scan);
			// this guard looks like dead code — confirm and simplify.
			if !errors.Is(err, sql.ErrNoRows) { // no rows, no conflict
				return err
			}
		}

		if len(files) > 0 {
			if len(files) == 4 {
				files[3] = "..." // more conflicts exist than we list
			}
			return fmt.Errorf("%w: conflict with %v", ErrDataConflict, files)
		}
	}

	// deal with (2.)
	{
		// NB: this local shadows the package-level upwardsPaths function.
		upwardsPaths := upwardsPaths(prefixPath)
		if len(upwardsPaths) == 0 {
			return nil
		}
		inParams := make([]string, len(upwardsPaths))
		for i := range upwardsPaths {
			inParams[i] = d.arg(i + offset)
		}

		// Fetch everything below prefixPath plus the ancestor data.json files;
		// these are merged with the incoming blob via the loader below.
		query := fmt.Sprintf(`SELECT path FROM sources_data
WHERE source_name = %s AND ((path LIKE %s) OR (path in (%s))) AND (%s)`,
			d.arg(0), d.arg(1), strings.Join(inParams, ","), conditions)
		values := make([]any, 0, 2+len(upwardsPaths)+len(args))
		values = append(values, sourceName, prefixPath+"/%")
		values = append(values, args[2:]...) // conditions
		values = append(values, upwardsPaths...)
		files, err := queryPaths(ctx, tx, query, values...)
		if err != nil {
			return err
		}
		if len(files) == 0 {
			return nil
		}

		// Attempt to load, i.e. merge with existing data. If it fails, don't upsert.
		// fs0 serves the incoming blob at its target location; fs1 lazily serves
		// the stored rows fetched above; the merged view is fed to the loader.
		fs0 := mountfs.New(map[string]fs.FS{filepath.Dir(path): sourcedatafs.NewSingleFS(ctx, func(context.Context) ([]byte, error) { return bs, nil })})
		fs1 := sourcedatafs.New(ctx, files, func(file string) func(context.Context) ([]byte, error) {
			return d.sourceData(tx, sourceName, file)
		})
		fs2 := merged_fs.NewMergedFS(fs0, fs1)
		if _, err := loader.NewFileLoader().WithFS(fs2).All([]string{"."}); err != nil {
			return fmt.Errorf("%w: %w", ErrDataConflict, err)
		}
	}
	return nil
}
524+
407525func (d * Database ) SourcesDataPut (ctx context.Context , sourceName , path string , data any , principal string ) error {
408526 path = filepath .ToSlash (path )
409527 return tx1 (ctx , d , d .sourcesDataPut (ctx , sourceName , path , data , principal ))
@@ -430,28 +548,13 @@ func (d *Database) sourcesDataPut(ctx context.Context, sourceName, path string,
430548 return err
431549 }
432550
433- // Attempt to load, i.e. merge with existing data. If it fails, don't upsert.
434- var files []string
435- type Path struct {
436- P string `sql:"path"`
437- }
438- for row , err := range sqlrange .QueryContext [Path ](ctx , tx ,
439- `SELECT path FROM sources_data WHERE source_name = ` + d .arg (0 ),
440- sourceName ) {
441- if err != nil {
442- return err
443- }
444- files = append (files , row .P )
445- }
446-
447- fs0 := mountfs .New (map [string ]fs.FS {filepath .Dir (path ): sourcedatafs .NewSingleFS (ctx , func (context.Context ) ([]byte , error ) { return bs , nil })})
448- fs1 := sourcedatafs .New (ctx , files , func (file string ) func (context.Context ) ([]byte , error ) {
449- return d .sourceData (tx , sourceName , file )
450- })
451- fs2 := merged_fs .NewMergedFS (fs0 , fs1 )
452- ocp_fs .Walk (fs2 )
453- if _ , err := loader .NewFileLoader ().WithFS (fs2 ).All ([]string {"." }); err != nil {
454- return fmt .Errorf ("%w: %w" , ErrDataConflict , err )
551+ // NB: We only check for conflicts if the principal has the right to read source data.
552+ // (Otherwise, write access could be abused to guess the data or its layout? Let's
553+ // err on the side of caution.)
554+ // This is done implicitly: If the conditions are not satisfiable, none of the file
555+ // lookups will yield anything.
556+ if err := d .checkSourceDataConflict (ctx , tx , sourceName , path , data , bs , principal ); err != nil {
557+ return err
455558 }
456559
457560 return d .upsert (ctx , tx , "sources_data" , []string {"source_name" , "path" , "data" }, []string {"source_name" , "path" }, sourceName , path , bs )
@@ -1308,18 +1411,11 @@ func (d *Database) iterSourceFiles(ctx context.Context, dbish sqlrange.Queryable
13081411 sourceName )
13091412}
13101413
1311- func (d * Database ) iterSourceFilenames (ctx context.Context , dbish sqlrange.Queryable , sourceName string ) iter.Seq2 [Data , error ] {
1312- return sqlrange .QueryContext [Data ](ctx ,
1313- dbish ,
1314- `SELECT path FROM sources_data WHERE source_name = ` + d .arg (0 ),
1315- sourceName )
1316- }
1317-
13181414func (d * Database ) sourceData (tx * sql.Tx , sourceName , path string ) func (context.Context ) ([]byte , error ) {
13191415 return func (ctx context.Context ) ([]byte , error ) {
13201416 var data []byte
13211417 err := tx .QueryRowContext (ctx ,
1322- `SELECT data FROM sources_data WHERE source_name = ` + d .arg (0 )+ `AND path = ` + d .arg (1 ),
1418+ `SELECT data FROM sources_data WHERE source_name = ` + d .arg (0 )+ ` AND path = ` + d .arg (1 ),
13231419 sourceName , path ,
13241420 ).Scan (& data )
13251421 return data , err
@@ -1766,3 +1862,41 @@ func tx3[T any, U bool | string](ctx context.Context, db *Database, f func(*sql.
17661862
17671863 return result , result2 , nil
17681864}
1865+
// upwardsPaths returns the "data.json" path of every strict ancestor
// directory of basePath, shallowest first. For "corp/users/internal/alice"
// it yields ["corp/data.json", "corp/users/data.json",
// "corp/users/internal/data.json"]; a single-segment basePath (or ".")
// yields an empty slice. The result is []any so callers can splice it
// directly into query-argument slices.
func upwardsPaths(basePath string) []any {
	parts := strings.Split(basePath, "/")
	// Pre-size: one entry per strict ancestor. The original code carried a
	// dead `if i > 0` guard (the loop starts at 1) and an unused initial
	// currentPath value; both are removed here.
	prefixes := make([]any, 0, len(parts)-1)
	for i := 1; i < len(parts); i++ {
		prefixes = append(prefixes, strings.Join(parts[:i], "/")+"/data.json")
	}
	return prefixes
}
1879+
// queryPaths executes a query whose result set is a single string column and
// collects every row into a slice. Returns (nil, err) on the first failure;
// an empty result yields (nil, nil).
func queryPaths(ctx context.Context, tx *sql.Tx, query string, values ...any) ([]string, error) {
	rows, err := tx.QueryContext(ctx, query, values...)
	if err != nil {
		return nil, err
	}
	// The deferred Close is sufficient: Close is idempotent, and any error
	// encountered while iterating is reported by rows.Err() below. The
	// original additionally called Close explicitly and checked its error —
	// redundant, and removed here.
	defer rows.Close()

	var files []string
	for rows.Next() {
		var file string
		if err := rows.Scan(&file); err != nil {
			return nil, err
		}
		files = append(files, file)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return files, nil
}
0 commit comments