diff --git a/delete.go b/delete.go index f3f31e63..e0c8f8e4 100644 --- a/delete.go +++ b/delete.go @@ -30,7 +30,7 @@ func (d *deleteData) Exec() (sql.Result, error) { func (d *deleteData) ToSql() (sqlStr string, args []interface{}, err error) { if len(d.From) == 0 { - err = fmt.Errorf("delete statements must specify a From table") + err = ErrNoTable return } diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..92610a17 --- /dev/null +++ b/errors.go @@ -0,0 +1,8 @@ +package squirrel + +import "errors" + +var ( + ErrNoTable = errors.New("statement must specify a table") + ErrNoValues = errors.New("statement must have at least one set of values or select clause") +) diff --git a/insert.go b/insert.go index 97870783..98fbf9a7 100644 --- a/insert.go +++ b/insert.go @@ -52,11 +52,11 @@ func (d *insertData) QueryRow() RowScanner { func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) { if len(d.Into) == 0 { - err = errors.New("insert statements must specify a table") + err = ErrNoTable return } if len(d.Values) == 0 && d.Select == nil { - err = errors.New("insert statements must have at least one set of values or select clause") + err = ErrNoValues return } diff --git a/merge.go b/merge.go new file mode 100644 index 00000000..9768c0a2 --- /dev/null +++ b/merge.go @@ -0,0 +1,328 @@ +package squirrel + +import ( + "bytes" + "database/sql" + "errors" + "fmt" + "io" + "sort" + "strings" + + "github.com/lann/builder" +) + +type Typed struct { + Type string + Value interface{} +} + +type mergeData struct { + PlaceholderFormat PlaceholderFormat + RunWith BaseRunner + Prefixes []Sqlizer + Into string + ValuesAlias string + On string + When []string + Columns []string + Values [][]interface{} + Suffixes []Sqlizer + Select *SelectBuilder +} + +func (d *mergeData) Exec() (sql.Result, error) { + if d.RunWith == nil { + return nil, RunnerNotSet + } + return ExecWith(d.RunWith, d) +} + +func (d *mergeData) Query() (*sql.Rows, error) { + if d.RunWith == nil { + return nil, RunnerNotSet + } + return QueryWith(d.RunWith, d) +} + +func (d *mergeData) QueryRow() RowScanner { + if d.RunWith == nil { + return &Row{err: RunnerNotSet} + } + queryRower, ok := d.RunWith.(QueryRower) + if !ok { + return &Row{err: RunnerNotQueryRunner} + } + return QueryRowWith(queryRower, d) +} + +func (d *mergeData) ToSql() (sqlStr string, args []interface{}, err error) { + if len(d.Into) == 0 { + err = ErrNoTable + return + } + if len(d.Values) == 0 && d.Select == nil { + err = ErrNoValues + return + } + + sql := &bytes.Buffer{} + + if len(d.Prefixes) > 0 { + args, err = appendToSql(d.Prefixes, sql, " ", args) + if err != nil { + return + } + + sql.WriteString(" ") + } + + sql.WriteString("MERGE INTO ") + sql.WriteString(d.Into) + sql.WriteString(" ") + + sql.WriteString("USING ") + + sql.WriteString("(") + if d.Select != nil { + args, err = d.appendSelectToSQL(sql, args) + } else { + args, err = d.appendValuesToSQL(sql, args) + } + if err != nil { + return + } + + sql.WriteString(")") + + if d.ValuesAlias != "" { + sql.WriteString(" AS ") + sql.WriteString(d.ValuesAlias) + sql.WriteString(" ") + } + + if len(d.Columns) > 0 { + sql.WriteString("(") + sql.WriteString(strings.Join(d.Columns, ",")) + sql.WriteString(")") + } + + if d.On != "" { + sql.WriteString(" ON ") + sql.WriteString(d.On) + } + + if len(d.When) > 0 { + sql.WriteString(" WHEN ") + sql.WriteString(strings.Join(d.When, " WHEN ")) + } + + if len(d.Suffixes) > 0 { + sql.WriteString(" ") + args, err = appendToSql(d.Suffixes, sql, " ", args) + if err != nil { + return + } + } + + sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) + return +} + +func (d *mergeData) appendValuesToSQL(w io.Writer, args []interface{}) ([]interface{}, error) { + if len(d.Values) == 0 { + return args, errors.New("values for insert statements are not set") + } + + io.WriteString(w, "VALUES ") + + valuesStrings := make([]string, len(d.Values)) + for r, row := range d.Values { + valueStrings := make([]string, len(row)) + for v, val := range row { + var valueType string + switch rv := val.(type) { + case Typed: + valueType = rv.Type + val = rv.Value + } + if vs, ok := val.(Sqlizer); ok { + vsql, vargs, err := vs.ToSql() + if err != nil { + return nil, err + } + valueStrings[v] = vsql + args = append(args, vargs...) + } else { + valueStrings[v] = "?" + args = append(args, val) + } + if valueType != "" { + valueStrings[v] = fmt.Sprintf("%s::%s", valueStrings[v], valueType) + } + } + valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ",")) + } + + io.WriteString(w, strings.Join(valuesStrings, ",")) + + return args, nil +} + +func (d *mergeData) appendSelectToSQL(w io.Writer, args []interface{}) ([]interface{}, error) { + if d.Select == nil { + return args, errors.New("select clause for insert statements are not set") + } + + selectClause, sArgs, err := d.Select.ToSql() + if err != nil { + return args, err + } + + io.WriteString(w, selectClause) + args = append(args, sArgs...) + + return args, nil +} + +// Builder + +// MergeBuilder builds SQL INSERT statements. +type MergeBuilder builder.Builder + +func init() { + builder.Register(MergeBuilder{}, mergeData{}) +} + +// Format methods + +// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the +// query. +func (b MergeBuilder) PlaceholderFormat(f PlaceholderFormat) MergeBuilder { + return builder.Set(b, "PlaceholderFormat", f).(MergeBuilder) +} + +// Runner methods + +// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. +func (b MergeBuilder) RunWith(runner BaseRunner) MergeBuilder { + return setRunWith(b, runner).(MergeBuilder) +} + +// Exec builds and Execs the query with the Runner set by RunWith. +func (b MergeBuilder) Exec() (sql.Result, error) { + data := builder.GetStruct(b).(mergeData) + return data.Exec() +} + +// Query builds and Querys the query with the Runner set by RunWith. +func (b MergeBuilder) Query() (*sql.Rows, error) { + data := builder.GetStruct(b).(mergeData) + return data.Query() +} + +// QueryRow builds and QueryRows the query with the Runner set by RunWith. +func (b MergeBuilder) QueryRow() RowScanner { + data := builder.GetStruct(b).(mergeData) + return data.QueryRow() +} + +// Scan is a shortcut for QueryRow().Scan. +func (b MergeBuilder) Scan(dest ...interface{}) error { + return b.QueryRow().Scan(dest...) +} + +// SQL methods + +// ToSql builds the query into a SQL string and bound args. +func (b MergeBuilder) ToSql() (string, []interface{}, error) { + data := builder.GetStruct(b).(mergeData) + return data.ToSql() +} + +// MustSql builds the query into a SQL string and bound args. +// It panics if there are any errors. +func (b MergeBuilder) MustSql() (string, []interface{}) { + sql, args, err := b.ToSql() + if err != nil { + panic(err) + } + return sql, args +} + +// Prefix adds an expression to the beginning of the query +func (b MergeBuilder) Prefix(sql string, args ...interface{}) MergeBuilder { + return b.PrefixExpr(Expr(sql, args...)) +} + +// PrefixExpr adds an expression to the very beginning of the query +func (b MergeBuilder) PrefixExpr(expr Sqlizer) MergeBuilder { + return builder.Append(b, "Prefixes", expr).(MergeBuilder) +} + +// Into sets the INTO clause of the query. +func (b MergeBuilder) Into(into string) MergeBuilder { + return builder.Set(b, "Into", into).(MergeBuilder) +} + +// ValuesAlias sets the AS vals clause of the query. +func (b MergeBuilder) ValuesAlias(valuesAlias string) MergeBuilder { + return builder.Set(b, "ValuesAlias", valuesAlias).(MergeBuilder) +} + +// On sets the ON clause of the query. +func (b MergeBuilder) On(on string) MergeBuilder { + return builder.Set(b, "On", on).(MergeBuilder) +} + +// When sets the WHEN MATCHED/NOT MATCHED clause of the query. +func (b MergeBuilder) When(when string) MergeBuilder { + return builder.Append(b, "When", when).(MergeBuilder) +} + +// Columns adds insert columns to the query. +func (b MergeBuilder) Columns(columns ...string) MergeBuilder { + return builder.Extend(b, "Columns", columns).(MergeBuilder) +} + +// Values adds a single row's values to the query. +func (b MergeBuilder) Values(values ...interface{}) MergeBuilder { + return builder.Append(b, "Values", values).(MergeBuilder) +} + +// Suffix adds an expression to the end of the query +func (b MergeBuilder) Suffix(sql string, args ...interface{}) MergeBuilder { + return b.SuffixExpr(Expr(sql, args...)) +} + +// SuffixExpr adds an expression to the end of the query +func (b MergeBuilder) SuffixExpr(expr Sqlizer) MergeBuilder { + return builder.Append(b, "Suffixes", expr).(MergeBuilder) +} + +// SetMap set columns and values for insert builder from a map of column name and value +// note that it will reset all previous columns and values was set if any +func (b MergeBuilder) SetMap(clauses map[string]interface{}) MergeBuilder { + // Keep the columns in a consistent order by sorting the column key string. + cols := make([]string, 0, len(clauses)) + for col := range clauses { + cols = append(cols, col) + } + sort.Strings(cols) + + vals := make([]interface{}, 0, len(clauses)) + for _, col := range cols { + vals = append(vals, clauses[col]) + } + + b = builder.Set(b, "Columns", cols).(MergeBuilder) + b = builder.Set(b, "Values", [][]interface{}{vals}).(MergeBuilder) + + return b +} + +// Select set Select clause for insert query +// If Values and Select are used, then Select has higher priority +func (b MergeBuilder) Select(sb SelectBuilder) MergeBuilder { + return builder.Set(b, "Select", &sb).(MergeBuilder) +} diff --git a/merge_ctx.go b/merge_ctx.go new file mode 100644 index 00000000..331b27d5 --- /dev/null +++ b/merge_ctx.go @@ -0,0 +1,70 @@ +//go:build go1.8 +// +build go1.8 + +package squirrel + +import ( + "context" + "database/sql" + + "github.com/lann/builder" +) + +func (d *mergeData) ExecContext(ctx context.Context) (sql.Result, error) { + if d.RunWith == nil { + return nil, RunnerNotSet + } + ctxRunner, ok := d.RunWith.(ExecerContext) + if !ok { + return nil, NoContextSupport + } + return ExecContextWith(ctx, ctxRunner, d) +} + +func (d *mergeData) QueryContext(ctx context.Context) (*sql.Rows, error) { + if d.RunWith == nil { + return nil, RunnerNotSet + } + ctxRunner, ok := d.RunWith.(QueryerContext) + if !ok { + return nil, NoContextSupport + } + return QueryContextWith(ctx, ctxRunner, d) +} + +func (d *mergeData) QueryRowContext(ctx context.Context) RowScanner { + if d.RunWith == nil { + return &Row{err: RunnerNotSet} + } + queryRower, ok := d.RunWith.(QueryRowerContext) + if !ok { + if _, ok := d.RunWith.(QueryerContext); !ok { + return &Row{err: RunnerNotQueryRunner} + } + return &Row{err: NoContextSupport} + } + return QueryRowContextWith(ctx, queryRower, d) +} + +// ExecContext builds and ExecContexts the query with the Runner set by RunWith. +func (b MergeBuilder) ExecContext(ctx context.Context) (sql.Result, error) { + data := builder.GetStruct(b).(mergeData) + return data.ExecContext(ctx) +} + +// QueryContext builds and QueryContexts the query with the Runner set by RunWith. +func (b MergeBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) { + data := builder.GetStruct(b).(mergeData) + return data.QueryContext(ctx) +} + +// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith. +func (b MergeBuilder) QueryRowContext(ctx context.Context) RowScanner { + data := builder.GetStruct(b).(mergeData) + return data.QueryRowContext(ctx) +} + +// ScanContext is a shortcut for QueryRowContext().Scan. +func (b MergeBuilder) ScanContext(ctx context.Context, dest ...interface{}) error { + return b.QueryRowContext(ctx).Scan(dest...) +} diff --git a/merge_ctx_test.go b/merge_ctx_test.go new file mode 100644 index 00000000..d4d26768 --- /dev/null +++ b/merge_ctx_test.go @@ -0,0 +1,42 @@ +//go:build go1.8 +// +build go1.8 + +package squirrel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeBuilderContextRunners(t *testing.T) { + db := &DBStub{} + b := Merge("test").Values(1).RunWith(db) + + expectedSql := "MERGE INTO test USING (VALUES (?)) ()" + + b.ExecContext(ctx) + assert.Equal(t, expectedSql, db.LastExecSql) + + b.QueryContext(ctx) + assert.Equal(t, expectedSql, db.LastQuerySql) + + b.QueryRowContext(ctx) + assert.Equal(t, expectedSql, db.LastQueryRowSql) + + err := b.ScanContext(ctx) + assert.NoError(t, err) +} + +func TestMergeBuilderContextNoRunner(t *testing.T) { + b := Merge("test").Values(1) + + _, err := b.ExecContext(ctx) + assert.Equal(t, RunnerNotSet, err) + + _, err = b.QueryContext(ctx) + assert.Equal(t, RunnerNotSet, err) + + err = b.ScanContext(ctx) + assert.Equal(t, RunnerNotSet, err) +} diff --git a/merge_test.go b/merge_test.go new file mode 100644 index 00000000..05d1a2a3 --- /dev/null +++ b/merge_test.go @@ -0,0 +1,102 @@ +package squirrel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeBuilderToSql(t *testing.T) { + b := Merge("a"). + Prefix("WITH prefix AS ?", 0). + Values(1, 2). + ValuesAlias("vals"). + Columns("b", "c"). + On("a.b = vals.b AND a.c = vals.c"). + When("MATCHED THEN UPDATE SET b = vals.b, c = vals.c"). + When("NOT MATCHED THEN INSERT (b, c) VALUES (vals.b, vals.c)"). + Suffix("RETURNING a.b") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := + "WITH prefix AS ? MERGE INTO a USING (VALUES (?,?)) AS vals (b,c) ON a.b = vals.b AND a.c = vals.c " + + "WHEN MATCHED THEN UPDATE SET b = vals.b, c = vals.c " + + "WHEN NOT MATCHED THEN INSERT (b, c) VALUES (vals.b, vals.c) RETURNING a.b" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []interface{}{0, 1, 2} + assert.Equal(t, expectedArgs, args) +} + +func TestMergeBuilderToSqlErr(t *testing.T) { + _, _, err := Merge("").Values(1).ToSql() + assert.Error(t, err) + + _, _, err = Merge("x").ToSql() + assert.Error(t, err) +} + +func TestMergeBuilderMustSql(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("TestMergeBuilderMustSql should have panicked!") + } + }() + Merge("").MustSql() +} + +func TestMergeBuilderPlaceholders(t *testing.T) { + b := Merge("test").Values(1, 2) + + sql, _, _ := b.PlaceholderFormat(Question).ToSql() + assert.Equal(t, "MERGE INTO test USING (VALUES (?,?))", sql) + + sql, _, _ = b.PlaceholderFormat(Dollar).ToSql() + assert.Equal(t, "MERGE INTO test USING (VALUES ($1,$2))", sql) +} + +func TestMergeBuilderRunners(t *testing.T) { + db := &DBStub{} + b := Merge("test").Values(1).RunWith(db) + + expectedSQL := "MERGE INTO test USING (VALUES (?))" + + b.Exec() + assert.Equal(t, expectedSQL, db.LastExecSql) +} + +func TestMergeBuilderNoRunner(t *testing.T) { + b := Merge("test").Values(1) + + _, err := b.Exec() + assert.Equal(t, RunnerNotSet, err) +} + +func TestMergeBuilderSetMap(t *testing.T) { + b := Merge("table").SetMap(Eq{"field1": 1, "field2": 2, "field3": 3}) + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := "MERGE INTO table USING (VALUES (?,?,?))(field1,field2,field3)" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []interface{}{1, 2, 3} + assert.Equal(t, expectedArgs, args) +} + +func TestMergeBuilderSelect(t *testing.T) { + sb := Select("field1").From("table1").Where(Eq{"field1": 1}) + ib := Merge("table2").ValuesAlias("vals").On("table2.field1 = vals.field1").Columns("field1").Select(sb) + + sql, args, err := ib.ToSql() + assert.NoError(t, err) + + expectedSQL := "MERGE INTO table2 USING (SELECT field1 FROM table1 WHERE field1 = ?) AS vals (field1) ON table2.field1 = vals.field1" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []interface{}{1} + assert.Equal(t, expectedArgs, args) +} diff --git a/statement.go b/statement.go index 9420c67f..12ecf21b 100644 --- a/statement.go +++ b/statement.go @@ -15,6 +15,11 @@ func (b StatementBuilderType) Insert(into string) InsertBuilder { return InsertBuilder(b).Into(into) } +// Merge returns a MergeBuilder for this StatementBuilderType. +func (b StatementBuilderType) Merge(into string) MergeBuilder { + return MergeBuilder(b).Into(into) +} + // Replace returns a InsertBuilder for this StatementBuilderType with the // statement keyword set to "REPLACE". func (b StatementBuilderType) Replace(into string) InsertBuilder { @@ -65,6 +70,13 @@ func Insert(into string) InsertBuilder { return StatementBuilder.Insert(into) } +// Merge returns a new MergeBuilder with the given table name. +// +// See MergeBuilder.Into. +func Merge(into string) MergeBuilder { + return StatementBuilder.Merge(into) +} + // Replace returns a new InsertBuilder with the statement keyword set to // "REPLACE" and with the given table name. // diff --git a/update.go b/update.go index eb2a9c4d..9408e0de 100644 --- a/update.go +++ b/update.go @@ -56,7 +56,7 @@ func (d *updateData) QueryRow() RowScanner { func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) { if len(d.Table) == 0 { - err = fmt.Errorf("update statements must specify a table") + err = ErrNoTable return } if len(d.SetClauses) == 0 {