Skip to content

Commit 8fafbe5

Browse files
Merge pull request #1669 from bruin-data/feature/show-estimated-cost-for-data-diff
Add cost estimation via dry-run support to data-diff command
2 parents 4909e8a + c4f88e0 commit 8fafbe5

5 files changed

Lines changed: 402 additions & 1 deletion

File tree

cmd/datadiff.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,67 @@ func generateAlterStatements(schemaComparison diff.SchemaComparisonResult, conn1
8383
return generator.GenerateAlterStatements(&schemaComparison)
8484
}
8585

86+
// estimateDiffCost estimates the cost of running a data-diff operation.
87+
func estimateDiffCost(ctx context.Context, conn1, conn2 any, table1Name, table2Name, _, _ string, table1Identifier, table2Identifier string, schemaOnly bool) (*diff.DiffCostEstimate, error) {
88+
var estimate1, estimate2 *diff.TableDiffCostEstimate
89+
var err1, err2 error
90+
var wg conc.WaitGroup
91+
92+
// Try to estimate cost for table 1
93+
wg.Go(func() {
94+
if estimator, ok := conn1.(diff.CostEstimator); ok {
95+
estimate1, err1 = estimator.EstimateTableDiffCost(ctx, table1Name, schemaOnly)
96+
if estimate1 != nil {
97+
estimate1.TableName = table1Identifier
98+
}
99+
} else {
100+
estimate1 = &diff.TableDiffCostEstimate{
101+
TableName: table1Identifier,
102+
Supported: false,
103+
UnsupportedReason: fmt.Sprintf("connection type %T does not support cost estimation", conn1),
104+
}
105+
}
106+
})
107+
108+
// Try to estimate cost for table 2
109+
wg.Go(func() {
110+
if estimator, ok := conn2.(diff.CostEstimator); ok {
111+
estimate2, err2 = estimator.EstimateTableDiffCost(ctx, table2Name, schemaOnly)
112+
if estimate2 != nil {
113+
estimate2.TableName = table2Identifier
114+
}
115+
} else {
116+
estimate2 = &diff.TableDiffCostEstimate{
117+
TableName: table2Identifier,
118+
Supported: false,
119+
UnsupportedReason: fmt.Sprintf("connection type %T does not support cost estimation", conn2),
120+
}
121+
}
122+
})
123+
124+
wg.Wait()
125+
126+
if err1 != nil {
127+
return nil, fmt.Errorf("failed to estimate cost for table '%s': %w", table1Identifier, err1)
128+
}
129+
if err2 != nil {
130+
return nil, fmt.Errorf("failed to estimate cost for table '%s': %w", table2Identifier, err2)
131+
}
132+
133+
// Calculate totals
134+
totalBytesProcessed := estimate1.TotalBytesProcessed + estimate2.TotalBytesProcessed
135+
totalBytesBilled := estimate1.TotalBytesBilled + estimate2.TotalBytesBilled
136+
137+
result := &diff.DiffCostEstimate{
138+
SourceTable: estimate1,
139+
TargetTable: estimate2,
140+
TotalBytesProcessed: totalBytesProcessed,
141+
TotalBytesBilled: totalBytesBilled,
142+
}
143+
144+
return result, nil
145+
}
146+
86147
// DataDiffCmd defines the 'data-diff' command.
87148
func DataDiffCmd() *cli.Command {
88149
var connectionName string
@@ -94,6 +155,7 @@ func DataDiffCmd() *cli.Command {
94155
var targetDialect string
95156
var reverse bool
96157
var outputFormat string
158+
var dryRun bool
97159

98160
return &cli.Command{
99161
Name: "data-diff",
@@ -147,6 +209,11 @@ func DataDiffCmd() *cli.Command {
147209
Destination: &outputFormat,
148210
Value: "plain",
149211
},
212+
&cli.BoolFlag{
213+
Name: "dry-run",
214+
Usage: "Estimate the cost of the comparison without executing it (outputs JSON). Only supported for BigQuery connections.",
215+
Destination: &dryRun,
216+
},
150217
},
151218
Action: func(ctx context.Context, c *cli.Command) error {
152219
if c.NArg() != 2 {
@@ -239,6 +306,27 @@ func DataDiffCmd() *cli.Command {
239306
return fmt.Errorf("connection type %T for '%s' does not support table summarization", conn2, conn2Name)
240307
}
241308

309+
// Handle dry-run mode for cost estimation
310+
if dryRun {
311+
costEstimate, err := estimateDiffCost(ctx, conn1, conn2, table1Name, table2Name, conn1Name, conn2Name, table1Identifier, table2Identifier, !full)
312+
if err != nil {
313+
jsonErr := map[string]string{"error": fmt.Sprintf("failed to estimate cost: %v", err)}
314+
jsonBytes, marshalErr := json.Marshal(jsonErr)
315+
if marshalErr != nil {
316+
return fmt.Errorf("failed to marshal JSON error output: %w", marshalErr)
317+
}
318+
fmt.Fprintln(c.Writer, string(jsonBytes))
319+
return cli.Exit("", 1)
320+
}
321+
322+
jsonBytes, err := json.MarshalIndent(costEstimate, "", " ")
323+
if err != nil {
324+
return fmt.Errorf("failed to marshal cost estimate JSON: %w", err)
325+
}
326+
fmt.Fprintln(c.Writer, string(jsonBytes))
327+
return nil
328+
}
329+
242330
schemaComparison, err := compareTables(ctx, s1, s2, table1Name, table2Name, !full)
243331
if err != nil {
244332
if outputFormat == "json" {

cmd/datadiff_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cmd
22

33
import (
4+
"context"
45
"strings"
56
"testing"
67

@@ -536,3 +537,100 @@ func TestGenerateAlterStatements(t *testing.T) {
536537
assert.Contains(t, allStatements, "ALTER COLUMN")
537538
})
538539
}
540+
541+
// mockCostEstimator implements diff.CostEstimator for testing.
542+
type mockCostEstimator struct {
543+
estimate *diff.TableDiffCostEstimate
544+
err error
545+
}
546+
547+
func (m *mockCostEstimator) EstimateTableDiffCost(_ context.Context, tableName string, _ bool) (*diff.TableDiffCostEstimate, error) {
548+
if m.err != nil {
549+
return nil, m.err
550+
}
551+
result := *m.estimate
552+
result.TableName = tableName
553+
return &result, nil
554+
}
555+
556+
// mockNonCostEstimator is a type that doesn't implement CostEstimator.
557+
type mockNonCostEstimator struct{}
558+
559+
func TestEstimateDiffCost(t *testing.T) {
560+
t.Parallel()
561+
562+
t.Run("both connections support cost estimation", func(t *testing.T) {
563+
t.Parallel()
564+
conn1 := &mockCostEstimator{
565+
estimate: &diff.TableDiffCostEstimate{
566+
TotalBytesProcessed: 1000000,
567+
TotalBytesBilled: 10485760, // 10MB minimum
568+
Supported: true,
569+
Queries: []*diff.QueryCostEstimate{
570+
{QueryType: "schema", BytesProcessed: 0},
571+
{QueryType: "rowCount", BytesProcessed: 1000000},
572+
},
573+
},
574+
}
575+
conn2 := &mockCostEstimator{
576+
estimate: &diff.TableDiffCostEstimate{
577+
TotalBytesProcessed: 2000000,
578+
TotalBytesBilled: 10485760,
579+
Supported: true,
580+
Queries: []*diff.QueryCostEstimate{
581+
{QueryType: "schema", BytesProcessed: 0},
582+
{QueryType: "rowCount", BytesProcessed: 2000000},
583+
},
584+
},
585+
}
586+
587+
result, err := estimateDiffCost(t.Context(), conn1, conn2, "table1", "table2", "conn1", "conn2", "conn1:table1", "conn2:table2", true)
588+
589+
require.NoError(t, err)
590+
assert.NotNil(t, result)
591+
assert.Equal(t, int64(3000000), result.TotalBytesProcessed)
592+
assert.Equal(t, int64(20971520), result.TotalBytesBilled)
593+
assert.Equal(t, "conn1:table1", result.SourceTable.TableName)
594+
assert.Equal(t, "conn2:table2", result.TargetTable.TableName)
595+
assert.True(t, result.SourceTable.Supported)
596+
assert.True(t, result.TargetTable.Supported)
597+
})
598+
599+
t.Run("connection does not support cost estimation", func(t *testing.T) {
600+
t.Parallel()
601+
conn1 := &mockCostEstimator{
602+
estimate: &diff.TableDiffCostEstimate{
603+
TotalBytesProcessed: 1000000,
604+
TotalBytesBilled: 10485760,
605+
Supported: true,
606+
},
607+
}
608+
conn2 := &mockNonCostEstimator{}
609+
610+
result, err := estimateDiffCost(t.Context(), conn1, conn2, "table1", "table2", "conn1", "conn2", "conn1:table1", "conn2:table2", true)
611+
612+
require.NoError(t, err)
613+
assert.NotNil(t, result)
614+
assert.True(t, result.SourceTable.Supported)
615+
assert.False(t, result.TargetTable.Supported)
616+
assert.Contains(t, result.TargetTable.UnsupportedReason, "does not support cost estimation")
617+
})
618+
619+
t.Run("cost estimation error", func(t *testing.T) {
620+
t.Parallel()
621+
conn1 := &mockCostEstimator{
622+
err: assert.AnError,
623+
}
624+
conn2 := &mockCostEstimator{
625+
estimate: &diff.TableDiffCostEstimate{
626+
Supported: true,
627+
},
628+
}
629+
630+
result, err := estimateDiffCost(t.Context(), conn1, conn2, "table1", "table2", "conn1", "conn2", "conn1:table1", "conn2:table2", true)
631+
632+
require.Error(t, err)
633+
assert.Nil(t, result)
634+
assert.Contains(t, err.Error(), "failed to estimate cost for table")
635+
})
636+
}

pkg/bigquery/db.go

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ func (d *Client) fetchDateTimeStats(ctx context.Context, tableName, columnName s
10941094

10951095
func (d *Client) fetchJSONStats(ctx context.Context, tableName, columnName string) (*diff.JSONStatistics, error) {
10961096
statsQuery := fmt.Sprintf(`
1097-
SELECT
1097+
SELECT
10981098
COUNT(*) as count_val,
10991099
COUNTIF(%s IS NULL) as null_count
11001100
FROM %s`,
@@ -1122,6 +1122,145 @@ func (d *Client) fetchJSONStats(ctx context.Context, tableName, columnName strin
11221122
return stats, nil
11231123
}
11241124

1125+
// Estimates the cost of running a table diff operation without executing the queries.
1126+
func (d *Client) EstimateTableDiffCost(ctx context.Context, tableName string, schemaOnly bool) (*diff.TableDiffCostEstimate, error) {
1127+
result := &diff.TableDiffCostEstimate{
1128+
TableName: tableName,
1129+
Queries: make([]*diff.QueryCostEstimate, 0),
1130+
Supported: true,
1131+
}
1132+
1133+
// Parse table name to get dataset reference
1134+
tableComponents := strings.Split(tableName, ".")
1135+
var datasetRef string
1136+
var targetTable string
1137+
1138+
switch len(tableComponents) {
1139+
case 2:
1140+
datasetRef = fmt.Sprintf("%s.%s", d.config.ProjectID, tableComponents[0])
1141+
targetTable = tableComponents[1]
1142+
case 3:
1143+
datasetRef = fmt.Sprintf("%s.%s", tableComponents[0], tableComponents[1])
1144+
targetTable = tableComponents[2]
1145+
default:
1146+
return nil, fmt.Errorf("table name must be in dataset.table or project.dataset.table format, '%s' given", tableName)
1147+
}
1148+
1149+
schemaQuery := buildSchemaQuery(datasetRef, targetTable)
1150+
result.Queries = append(result.Queries, &diff.QueryCostEstimate{
1151+
QueryType: "schema",
1152+
Query: truncateQuery(schemaQuery),
1153+
BytesProcessed: 0, // INFORMATION_SCHEMA queries are free
1154+
BytesBilled: 0,
1155+
})
1156+
1157+
if schemaOnly {
1158+
// In schema-only mode, we only run the schema query
1159+
return result, nil
1160+
}
1161+
1162+
// 2. Row count query - dry run to estimate cost
1163+
countQuery := fmt.Sprintf("SELECT COUNT(*) as row_count FROM `%s`", tableName)
1164+
countEstimate, err := d.estimateQueryCost(ctx, countQuery, "rowCount")
1165+
if err != nil {
1166+
return nil, fmt.Errorf("failed to estimate row count query cost: %w", err)
1167+
}
1168+
result.Queries = append(result.Queries, countEstimate)
1169+
1170+
// 3. Get schema to determine column types (this is free since we're querying INFORMATION_SCHEMA)
1171+
schemaResult, err := d.Select(ctx, &query.Query{Query: schemaQuery})
1172+
if err != nil {
1173+
return nil, fmt.Errorf("failed to get schema for cost estimation: %w", err)
1174+
}
1175+
1176+
// 4. For each column, estimate the statistics query cost
1177+
for _, row := range schemaResult {
1178+
if len(row) < 2 {
1179+
continue
1180+
}
1181+
1182+
columnName, ok := row[0].(string)
1183+
if !ok {
1184+
continue
1185+
}
1186+
1187+
dataType, ok := row[1].(string)
1188+
if !ok {
1189+
continue
1190+
}
1191+
1192+
normalizedType := d.typeMapper.MapType(dataType)
1193+
var statsQuery string
1194+
1195+
switch normalizedType {
1196+
case diff.CommonTypeNumeric:
1197+
statsQuery = fmt.Sprintf(`SELECT MIN(%s), MAX(%s), AVG(%s), SUM(%s), COUNT(%s), COUNTIF(%s IS NULL), STDDEV(%s) FROM %s`,
1198+
columnName, columnName, columnName, columnName, columnName, columnName, columnName, "`"+tableName+"`")
1199+
case diff.CommonTypeString:
1200+
statsQuery = fmt.Sprintf(`SELECT MIN(LENGTH(%s)), MAX(LENGTH(%s)), AVG(LENGTH(%s)), COUNT(DISTINCT %s), COUNT(*), COUNTIF(%s IS NULL), COUNTIF(%s = '') FROM %s`,
1201+
columnName, columnName, columnName, columnName, columnName, columnName, "`"+tableName+"`")
1202+
case diff.CommonTypeBoolean:
1203+
statsQuery = fmt.Sprintf(`SELECT COUNTIF(%s = true), COUNTIF(%s = false), COUNT(*), COUNTIF(%s IS NULL) FROM %s`,
1204+
columnName, columnName, columnName, "`"+tableName+"`")
1205+
case diff.CommonTypeDateTime:
1206+
statsQuery = fmt.Sprintf(`SELECT MIN(%s), MAX(%s), COUNT(DISTINCT %s), COUNT(*), COUNTIF(%s IS NULL) FROM %s`,
1207+
columnName, columnName, columnName, columnName, "`"+tableName+"`")
1208+
case diff.CommonTypeJSON:
1209+
statsQuery = fmt.Sprintf(`SELECT COUNT(*), COUNTIF(%s IS NULL) FROM %s`,
1210+
columnName, "`"+tableName+"`")
1211+
case diff.CommonTypeBinary, diff.CommonTypeUnknown:
1212+
// Skip binary and unknown types
1213+
continue
1214+
}
1215+
1216+
estimate, err := d.estimateQueryCost(ctx, statsQuery, "statistics:"+columnName)
1217+
if err != nil {
1218+
return nil, fmt.Errorf("failed to estimate statistics query cost for column '%s': %w", columnName, err)
1219+
}
1220+
result.Queries = append(result.Queries, estimate)
1221+
}
1222+
1223+
// Calculate totals
1224+
for _, q := range result.Queries {
1225+
result.TotalBytesProcessed += q.BytesProcessed
1226+
result.TotalBytesBilled += q.BytesBilled
1227+
}
1228+
1229+
return result, nil
1230+
}
1231+
1232+
// estimateQueryCost runs a dry-run for a query and returns the bytes estimate.
1233+
func (d *Client) estimateQueryCost(ctx context.Context, queryStr string, queryType string) (*diff.QueryCostEstimate, error) {
1234+
stats, err := d.QueryDryRun(ctx, &query.Query{Query: queryStr})
1235+
if err != nil {
1236+
return nil, err
1237+
}
1238+
1239+
bytesProcessed := stats.TotalBytesProcessed
1240+
// BigQuery has a minimum billing of 10 MB per query
1241+
bytesBilled := bytesProcessed
1242+
if bytesBilled < 10*1024*1024 && bytesBilled > 0 {
1243+
bytesBilled = 10 * 1024 * 1024
1244+
}
1245+
1246+
return &diff.QueryCostEstimate{
1247+
QueryType: queryType,
1248+
Query: truncateQuery(queryStr),
1249+
BytesProcessed: bytesProcessed,
1250+
BytesBilled: bytesBilled,
1251+
}, nil
1252+
}
1253+
1254+
// truncateQuery shortens a query string for display purposes.
//
// Runs of whitespace (including newlines) are collapsed to single spaces. If
// the result is longer than 100 bytes it is cut to at most 97 bytes and "..."
// is appended. The cut always lands on a rune boundary, so a multi-byte UTF-8
// character is never split and valid input stays valid.
func truncateQuery(q string) string {
	const (
		maxLen   = 100
		ellipsis = "..."
	)

	// Remove extra whitespace and newlines
	q = strings.Join(strings.Fields(q), " ")
	if len(q) <= maxLen {
		return q
	}

	// Ranging over a string yields the byte offset of each rune start; keep
	// the last one that still leaves room for the ellipsis. For ASCII input
	// this is exactly maxLen-len(ellipsis), matching a plain byte slice.
	cut := 0
	for i := range q {
		if i > maxLen-len(ellipsis) {
			break
		}
		cut = i
	}
	return q[:cut] + ellipsis
}
1263+
11251264
// tableMetadataResult holds the result of fetching table metadata.
11261265
type tableMetadataResult struct {
11271266
Columns []*ansisql.DBColumn

0 commit comments

Comments
 (0)