Skip to content

Commit fac3ae2

Browse files
committed
Merge remote-tracking branch 'ctes/master'
This merges the fix for Masterminds#320.
2 parents e1c3903 + 4ded8f4 commit fac3ae2

File tree

4 files changed

+218
-0
lines changed

4 files changed

+218
-0
lines changed

cte.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package squirrel
2+
3+
import (
4+
"bytes"
5+
"strings"
6+
)
7+
8+
// CTE represents a single common table expression. They are composed of an alias, a few optional components, and a data manipulation statement, though exactly what sort of statement depends on the database system you're using. MySQL, for example, only allows SELECT statements; others, like PostgreSQL, permit INSERTs, UPDATEs, and DELETEs.
9+
// The optional components supported by this fork of Squirrel include:
10+
// * a list of columns
11+
// * the keyword RECURSIVE, the use of which may place additional constraints on the data manipulation statement
12+
type CTE struct {
13+
Alias string
14+
ColumnList []string
15+
Recursive bool
16+
Expression Sqlizer
17+
}
18+
19+
// ToSql builds the SQL for a CTE
20+
func (c CTE) ToSql() (string, []interface{}, error) {
21+
22+
var buf bytes.Buffer
23+
24+
if c.Recursive {
25+
buf.WriteString("RECURSIVE ")
26+
}
27+
28+
buf.WriteString(c.Alias)
29+
30+
if len(c.ColumnList) > 0 {
31+
buf.WriteString("(")
32+
buf.WriteString(strings.Join(c.ColumnList, ", "))
33+
buf.WriteString(")")
34+
}
35+
36+
buf.WriteString(" AS (")
37+
sql, args, err := c.Expression.ToSql()
38+
if err != nil {
39+
return "", []interface{}{}, err
40+
}
41+
buf.WriteString(sql)
42+
buf.WriteString(")")
43+
44+
return buf.String(), args, nil
45+
}

cte_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package squirrel
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestNormalCTE(t *testing.T) {
10+
11+
cte := CTE{
12+
Alias: "cte",
13+
ColumnList: []string{"abc", "def"},
14+
Recursive: false,
15+
Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}),
16+
}
17+
18+
sql, args, err := cte.ToSql()
19+
20+
assert.Equal(t, "cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql)
21+
assert.Equal(t, []interface{}{1}, args)
22+
assert.Nil(t, err)
23+
24+
}
25+
26+
func TestRecursiveCTE(t *testing.T) {
27+
28+
// this isn't usually valid SQL, but the point is to test the RECURSIVE part
29+
cte := CTE{
30+
Alias: "cte",
31+
ColumnList: []string{"abc", "def"},
32+
Recursive: true,
33+
Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}),
34+
}
35+
36+
sql, args, err := cte.ToSql()
37+
38+
assert.Equal(t, "RECURSIVE cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql)
39+
assert.Equal(t, []interface{}{1}, args)
40+
assert.Nil(t, err)
41+
42+
}

select.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ type selectData struct {
1313
PlaceholderFormat PlaceholderFormat
1414
RunWith BaseRunner
1515
Prefixes []Sqlizer
16+
CTEs []Sqlizer
17+
Union Sqlizer
18+
UnionAll Sqlizer
1619
Options []string
1720
Columns []Sqlizer
1821
From Sqlizer
@@ -78,6 +81,15 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
7881
sql.WriteString(" ")
7982
}
8083

84+
if len(d.CTEs) > 0 {
85+
sql.WriteString("WITH ")
86+
args, err = appendToSql(d.CTEs, sql, ", ", args)
87+
if err != nil {
88+
return
89+
}
90+
sql.WriteString(" ")
91+
}
92+
8193
sql.WriteString("SELECT ")
8294

8395
if len(d.Options) > 0 {
@@ -116,6 +128,22 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
116128
}
117129
}
118130

131+
if d.Union != nil {
132+
sql.WriteString(" UNION ")
133+
args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args)
134+
if err != nil {
135+
return
136+
}
137+
}
138+
139+
if d.UnionAll != nil {
140+
sql.WriteString(" UNION ALL ")
141+
args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args)
142+
if err != nil {
143+
return
144+
}
145+
}
146+
119147
if len(d.GroupBys) > 0 {
120148
sql.WriteString(" GROUP BY ")
121149
sql.WriteString(strings.Join(d.GroupBys, ", "))
@@ -253,6 +281,22 @@ func (b SelectBuilder) Options(options ...string) SelectBuilder {
253281
return builder.Extend(b, "Options", options).(SelectBuilder)
254282
}
255283

284+
// With adds a non-recursive CTE to the query.
285+
func (b SelectBuilder) With(alias string, expr Sqlizer) SelectBuilder {
286+
return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: false, Expression: expr})
287+
}
288+
289+
// WithRecursive adds a recursive CTE to the query.
290+
func (b SelectBuilder) WithRecursive(alias string, expr Sqlizer) SelectBuilder {
291+
return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: true, Expression: expr})
292+
}
293+
294+
// WithCTE adds an arbitrary Sqlizer to the query.
295+
// The sqlizer will be sandwiched between the keyword WITH and, if there's more than one CTE, a comma.
296+
func (b SelectBuilder) WithCTE(cte Sqlizer) SelectBuilder {
297+
return builder.Append(b, "CTEs", cte).(SelectBuilder)
298+
}
299+
256300
// Columns adds result columns to the query.
257301
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
258302
parts := make([]interface{}, 0, len(columns))
@@ -289,6 +333,20 @@ func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilde
289333
return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder)
290334
}
291335

336+
// UnionSelect sets a union SelectBuilder which removes duplicate rows
337+
// --> UNION combines the result from multiple SELECT statements into a single result set
338+
func (b SelectBuilder) UnionSelect(union SelectBuilder) SelectBuilder {
339+
union = union.PlaceholderFormat(Question)
340+
return builder.Set(b, "Union", union).(SelectBuilder)
341+
}
342+
343+
// UnionAllSelect sets a union SelectBuilder which includes all matching rows
344+
// --> UNION combines the result from multiple SELECT statements into a single result set
345+
func (b SelectBuilder) UnionAllSelect(union SelectBuilder) SelectBuilder {
346+
union = union.PlaceholderFormat(Question)
347+
return builder.Set(b, "UnionAll", union).(SelectBuilder)
348+
}
349+
292350
// JoinClause adds a join clause to the query.
293351
func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder {
294352
return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder)
@@ -319,6 +377,16 @@ func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder
319377
return b.JoinClause("CROSS JOIN "+join, rest...)
320378
}
321379

380+
// Union adds UNION to the query. (duplicate rows are removed)
381+
func (b SelectBuilder) Union(join string, rest ...interface{}) SelectBuilder {
382+
return b.JoinClause("UNION "+join, rest...)
383+
}
384+
385+
// UnionAll adds UNION ALL to the query. (includes all matching rows)
386+
func (b SelectBuilder) UnionAll(join string, rest ...interface{}) SelectBuilder {
387+
return b.JoinClause("UNION ALL "+join, rest...)
388+
}
389+
322390
// Where adds an expression to the WHERE clause of the query.
323391
//
324392
// Expressions are ANDed together in the generated SQL.

select_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,30 @@ func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) {
279279
assert.Equal(t, []interface{}{1, 2}, args)
280280
}
281281

282+
func TestOneCTE(t *testing.T) {
283+
sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).ToSql()
284+
285+
assert.NoError(t, err)
286+
287+
assert.Equal(t, "WITH cte AS (SELECT abc FROM def) SELECT * FROM cte", sql)
288+
}
289+
290+
func TestTwoCTEs(t *testing.T) {
291+
sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).With("cte2", Select("ghi").From("jkl")).ToSql()
292+
293+
assert.NoError(t, err)
294+
295+
assert.Equal(t, "WITH cte AS (SELECT abc FROM def), cte2 AS (SELECT ghi FROM jkl) SELECT * FROM cte", sql)
296+
}
297+
298+
func TestCTEErrorBubblesUp(t *testing.T) {
299+
300+
// a SELECT with no columns raises an error
301+
_, _, err := Select("*").From("cte").With("cte", SelectBuilder{}.From("def")).ToSql()
302+
303+
assert.Error(t, err)
304+
}
305+
282306
func TestSelectJoinClausePlaceholderNumbering(t *testing.T) {
283307
subquery := Select("a").Where(Eq{"b": 2}).PlaceholderFormat(Dollar)
284308

@@ -461,3 +485,42 @@ func TestRemoveColumns(t *testing.T) {
461485
assert.NoError(t, err)
462486
assert.Equal(t, "SELECT name FROM users", sql)
463487
}
488+
489+
func TestSelectBuilderUnionToSql(t *testing.T) {
490+
multi := Select("column1", "column2").
491+
From("table1").
492+
Where(Eq{"column1": "test"}).
493+
UnionSelect(Select("column3", "column4").From("table2").Where(Lt{"column4": 5}).
494+
UnionSelect(Select("column5", "column6").From("table3").Where(LtOrEq{"column5": 6})))
495+
sql, args, err := multi.ToSql()
496+
assert.NoError(t, err)
497+
498+
expectedSql := `SELECT column1, column2 FROM table1 WHERE column1 = ? ` +
499+
"UNION SELECT column3, column4 FROM table2 WHERE column4 < ? " +
500+
"UNION SELECT column5, column6 FROM table3 WHERE column5 <= ?"
501+
assert.Equal(t, expectedSql, sql)
502+
503+
expectedArgs := []interface{}{"test", 5, 6}
504+
assert.Equal(t, expectedArgs, args)
505+
506+
sql, _, err = multi.PlaceholderFormat(Dollar).ToSql()
507+
assert.NoError(t, err)
508+
expectedSql = `SELECT column1, column2 FROM table1 WHERE column1 = $1 ` +
509+
"UNION SELECT column3, column4 FROM table2 WHERE column4 < $2 " +
510+
"UNION SELECT column5, column6 FROM table3 WHERE column5 <= $3"
511+
assert.Equal(t, expectedSql, sql)
512+
513+
unionAll := Select("count(true) as C").
514+
From("table1").
515+
Where(Eq{"column1": []string{"test", "tester"}}).
516+
UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where("id=table2.column3")))
517+
sql, args, err = unionAll.ToSql()
518+
assert.NoError(t, err)
519+
520+
expectedSql = `SELECT count(true) as C FROM table1 WHERE column1 IN (?,?) ` +
521+
"UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id=table2.column3 )"
522+
assert.Equal(t, expectedSql, sql)
523+
524+
expectedArgs = []interface{}{"test", "tester"}
525+
assert.Equal(t, expectedArgs, args)
526+
}

0 commit comments

Comments
 (0)