Skip to content

Commit 45df461

Browse files
authored
fix(tool/bigquery): prevent allowedDatasets bypass in forecast query (#3324)
Addressing vulnerabilities in `bigquery-analyze-contribution` and `bigquery-forecast` tools. - Updated the tool options (like column names and metrics) to automatically wrap in single quotes, making it impossible for external users to inject malicious SQL code. - For query inputs, the tool now dry-runs the entire fully assembled statement against BigQuery to inspect every dataset it will access, guaranteeing that hidden accesses (like those inside SQL Views) are caught and blocked. Reported by: Matteo Panzeri
1 parent 1d8df0d commit 45df461

8 files changed

Lines changed: 902 additions & 118 deletions

File tree

internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
102102
}
103103

104104
inputDataParameter := parameters.NewStringParameter("input_data", inputDataDescription)
105-
contributionMetricParameter := parameters.NewStringParameter("contribution_metric",
105+
contributionMetricParameter := parameters.NewStringParameterWithEscape("contribution_metric",
106106
`The name of the column that contains the metric to analyze.
107107
Provides the expression to use to calculate the metric you are analyzing.
108108
To calculate a summable metric, the expression must be in the form SUM(metric_column_name),
@@ -114,11 +114,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
114114
115115
To calculate a summable by category metric, the expression must be in the form
116116
SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name). The summed column must be a numeric data type.
117-
The categorical column must have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.`)
118-
isTestColParameter := parameters.NewStringParameter("is_test_col",
119-
"The name of the column that identifies whether a row is in the test or control group.")
117+
The categorical column must have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.`, "single-quotes")
118+
isTestColParameter := parameters.NewStringParameterWithEscape("is_test_col",
119+
"The name of the column that identifies whether a row is in the test or control group.", "single-quotes")
120120
dimensionIDColsParameter := parameters.NewArrayParameterWithRequired("dimension_id_cols",
121-
"An array of column names that uniquely identify each dimension.", false, parameters.NewStringParameter("dimension_id_col", "A dimension column name."))
121+
"An array of column names that uniquely identify each dimension.", false, parameters.NewStringParameterWithEscape("dimension_id_col", "A dimension column name.", "single-quotes"))
122122
topKInsightsParameter := parameters.NewIntParameterWithDefault("top_k_insights_by_apriori_support", 30,
123123
"The number of top insights to return, ranked by apriori support.")
124124
pruningMethodParameter := parameters.NewStringParameterWithDefault("pruning_method", "PRUNE_REDUNDANT_INSIGHTS",
@@ -166,7 +166,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
166166
return nil, util.NewAgentError(fmt.Sprintf("unable to cast input_data parameter %s", paramsMap["input_data"]), nil)
167167
}
168168

169-
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
169+
bqClient, _, err := source.RetrieveClientAndService(accessToken)
170170
if err != nil {
171171
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
172172
}
@@ -177,22 +177,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
177177
if !ok {
178178
return nil, util.NewAgentError(fmt.Sprintf("unable to cast contribution_metric parameter %v", paramsMap["contribution_metric"]), nil)
179179
}
180-
if strings.ContainsRune(contributionMetric, '\'') {
180+
if !bqutil.ValidContributionMetricParam(contributionMetric) {
181181
return nil, util.NewAgentError("invalid 'contribution_metric': must not contain single quotes", nil)
182182
}
183183

184184
isTestCol, ok := paramsMap["is_test_col"].(string)
185185
if !ok {
186186
return nil, util.NewAgentError(fmt.Sprintf("unable to cast is_test_col parameter %v", paramsMap["is_test_col"]), nil)
187187
}
188-
if !bqutil.ValidColumnName(isTestCol) {
188+
if !bqutil.ValidColumnParam(isTestCol) {
189189
return nil, util.NewAgentError(fmt.Sprintf("invalid column name for 'is_test_col': %q; must match [a-zA-Z_][a-zA-Z0-9_]*", isTestCol), nil)
190190
}
191191

192192
var options []string
193193
options = append(options, "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'")
194-
options = append(options, fmt.Sprintf("CONTRIBUTION_METRIC = '%s'", contributionMetric))
195-
options = append(options, fmt.Sprintf("IS_TEST_COL = '%s'", isTestCol))
194+
options = append(options, fmt.Sprintf("CONTRIBUTION_METRIC = %s", contributionMetric))
195+
options = append(options, fmt.Sprintf("IS_TEST_COL = %s", isTestCol))
196196

197197
if val, ok := paramsMap["dimension_id_cols"]; ok {
198198
if cols, ok := val.([]any); ok {
@@ -202,10 +202,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
202202
if !ok {
203203
return nil, util.NewAgentError(fmt.Sprintf("dimension_id_cols contains non-string value: %v", c), nil)
204204
}
205-
if !bqutil.ValidColumnName(colStr) {
205+
if !bqutil.ValidColumnParam(colStr) {
206206
return nil, util.NewAgentError(fmt.Sprintf("invalid column name in 'dimension_id_cols': %q; must match [a-zA-Z_][a-zA-Z0-9_]*", colStr), nil)
207207
}
208-
strCols = append(strCols, fmt.Sprintf("'%s'", colStr))
208+
strCols = append(strCols, colStr)
209209
}
210210
options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
211211
} else {
@@ -226,37 +226,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
226226
var inputDataSource string
227227
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
228228
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
229-
if len(source.BigQueryAllowedDatasets()) > 0 {
230-
var connProps []*bigqueryapi.ConnectionProperty
231-
session, err := source.BigQuerySession()(ctx)
232-
if err != nil {
233-
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
234-
}
235-
if session != nil {
236-
connProps = []*bigqueryapi.ConnectionProperty{
237-
{Key: "session_id", Value: session.ID},
238-
}
239-
}
240-
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps, source.GetMaximumBytesBilled())
241-
if err != nil {
242-
return nil, util.ProcessGcpError(err)
243-
}
244-
statementType := dryRunJob.Statistics.Query.StatementType
245-
if statementType != "SELECT" {
246-
return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil)
247-
}
248-
249-
queryStats := dryRunJob.Statistics.Query
250-
if queryStats != nil {
251-
for _, tableRef := range queryStats.ReferencedTables {
252-
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
253-
return nil, util.NewAgentError(fmt.Sprintf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil)
254-
}
255-
}
256-
} else {
257-
return nil, util.NewAgentError("could not analyze query in input_data to validate against allowed datasets", nil)
258-
}
259-
}
260229
inputDataSource = fmt.Sprintf("(%s)", inputData)
261230
} else {
262231
if !bqutil.ValidTableID(inputData) {
@@ -305,6 +274,29 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
305274
// If not in protected mode, create a session for this invocation.
306275
createModelQuery.CreateSession = true
307276
}
277+
if len(source.BigQueryAllowedDatasets()) > 0 {
278+
createModelQuery.DryRun = true
279+
dryRunJob, err := createModelQuery.Run(ctx)
280+
if err != nil {
281+
return nil, util.ProcessGcpError(err)
282+
}
283+
status := dryRunJob.LastStatus()
284+
if status.Statistics != nil {
285+
if qStats, ok := status.Statistics.Details.(*bigqueryapi.QueryStatistics); ok {
286+
for _, tableRef := range qStats.ReferencedTables {
287+
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
288+
return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectID, tableRef.DatasetID), nil)
289+
}
290+
}
291+
} else {
292+
return nil, util.NewAgentError("could not get query statistics details during dry run validation", nil)
293+
}
294+
} else {
295+
return nil, util.NewAgentError("could not dry run model creation query to validate allowed datasets", nil)
296+
}
297+
createModelQuery.DryRun = false
298+
}
299+
308300
createModelJob, err := createModelQuery.Run(ctx)
309301
if err != nil {
310302
return nil, util.ProcessGcpError(err)

0 commit comments

Comments
 (0)