Skip to content

Commit a140723

Browse files
update-returning
Summary: - Support for `UPDATE RETURNING`.
1 parent 414a233 commit a140723

File tree

8 files changed

+283
-21
lines changed

8 files changed

+283
-21
lines changed

.vscode/launch.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@
178178
"insert into google.storage.buckets( project, data__name) select 'testing-project', 'silly-bucket' returning projectNumber;",
179179
"insert /*+ AWAIT */ into google.compute.networks(project, data__name, data__autoCreateSubnetworks) select 'mutable-project', 'auto-test-01', false returning creationTimestamp, name;",
180180
"registry pull google 'v0.1.2'; show resources in google.storage; registry pull google 'v0.1.1-alpha01'; show resources in google.storage; registry pull google 'v0.1.0'; show resources in google.storage;",
181+
"update google.storage.buckets set data__labels = '{ \"app_stub\": \"factory\" }' where bucket = 'demo-app-bucket1' returning labels, projectNumber;",
181182
],
182183
"default": "show providers;"
183184
},

internal/stackql/parserutil/parser_util.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,26 @@ func ExtractInsertReturningColumnNames(
108108
return colNames, err
109109
}
110110

111+
func ExtractUpdateReturningColumnNames(
112+
updateStmt *sqlparser.Update,
113+
formatter sqlparser.NodeFormatter,
114+
) ([]ColumnHandle, error) {
115+
var colNames []ColumnHandle
116+
var err error
117+
for _, node := range updateStmt.SelectExprs {
118+
switch node := node.(type) {
119+
case *sqlparser.AliasedExpr:
120+
cn, cErr := inferColNameFromExpr(node.Expr, formatter, node.As.GetRawVal())
121+
if cErr != nil {
122+
return nil, cErr
123+
}
124+
colNames = append(colNames, cn)
125+
case *sqlparser.StarExpr:
126+
}
127+
}
128+
return colNames, err
129+
}
130+
111131
func ExtractInsertColumnNames(insertStmt *sqlparser.Insert) ([]string, error) {
112132
var colNames []string
113133
var err error
@@ -254,6 +274,10 @@ func ExtractInsertValColumnsPlusPlaceHolders(insStmt *sqlparser.Insert) (map[int
254274
return extractInsertValColumns(insStmt, false)
255275
}
256276

277+
// func ExtractUpdateValColumnsPlusPlaceHolders(updateStmt *sqlparser.Update) (map[int]map[int]interface{}, int, error) {
278+
// return extractUpdateValColumnsArray(updateStmt, false)
279+
// }
280+
257281
func extractInsertValColumns(
258282
insStmt *sqlparser.Insert,
259283
includePlaceholders bool,
@@ -288,6 +312,40 @@ func extractInsertValColumns(
288312
return nil, nonValCount, err
289313
}
290314

315+
// func extractUpdateValColumnsArray(
316+
// updateStmt *sqlparser.Update,
317+
// includePlaceholders bool,
318+
// ) (map[int]map[int]interface{}, int, error) {
319+
// precursor, nonVal, precursorErr := extractUpdateValColumns(updateStmt, includePlaceholders)
320+
// if precursorErr != nil {
321+
// return nil, len(nonVal), precursorErr
322+
// }
323+
// if len(nonVal) > 0 {
324+
// return nil, len(nonVal), fmt.Errorf("disallowed non val updates for colums: %v", nonVal)
325+
// }
326+
// retVal := make(map[int]map[int]interface{})
327+
// firstRow := make(map[int]interface{})
328+
// i := 0
329+
// for idx, col := range updateStmt.Exprs {
330+
// if col == nil {
331+
// if includePlaceholders {
332+
// retVal[i] = map[int]interface{}{"$placeholder": nil}
333+
// } else {
334+
// retVal[i] = nil
335+
// }
336+
// }
337+
// for k, v := range col {
338+
// if includePlaceholders {
339+
// retVal[i] = map[int]interface{}{
340+
// constants.PlaceholderKey: v,
341+
// }
342+
// }
343+
// retVal[i][k.GetRawVal()] = v
344+
// }
345+
// }
346+
// return retVal, len(nonVal), nil
347+
// }
348+
291349
//nolint:gocognit,gocritic // not overly complex
292350
func extractUpdateValColumns(
293351
updateStmt *sqlparser.Update,

internal/stackql/planbuilder/plan_builder.go

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,9 +1072,68 @@ func (pgb *standardPlanGraphBuilder) handleUpdate(pbi planbuilderinput.PlanBuild
10721072
bldrInput.SetTxnCtrlCtrs(pbi.GetTxnCtrlCtrs())
10731073
bldrInput.SetIsTargetPhysicalTable(true)
10741074
}
1075-
bldr := primitivebuilder.NewInsertOrUpdate(
1076-
bldrInput,
1077-
)
1075+
isAwait := primitiveGenerator.GetPrimitiveComposer().IsAwait()
1076+
var bldr primitivebuilder.Builder
1077+
if len(node.SelectExprs) > 0 {
1078+
// Two cases:
1079+
// 1. Synchronous. Equivalent to select.
1080+
// 2. Asynchronous. Whole other story.
1081+
tableMeta, tableMetaExists := bldrInput.GetTableMetadata()
1082+
if !tableMetaExists {
1083+
return fmt.Errorf("could not obtain table metadata for node '%s'", node.Action)
1084+
}
1085+
rc, rcErr := tableinsertioncontainer.NewTableInsertionContainer(
1086+
tableMeta,
1087+
handlerCtx.GetSQLEngine(),
1088+
handlerCtx.GetTxnCounterMgr(),
1089+
)
1090+
if rcErr != nil {
1091+
return rcErr
1092+
}
1093+
bldrInput.SetTableInsertionContainer(rc)
1094+
bldrInput.SetIsReturning(true)
1095+
if !isAwait {
1096+
bldr = primitivebuilder.NewSingleAcquireAndSelect(
1097+
bldrInput,
1098+
primitiveGenerator.GetPrimitiveComposer().GetInsertPreparedStatementCtx(),
1099+
primitiveGenerator.GetPrimitiveComposer().GetSelectPreparedStatementCtx(),
1100+
nil,
1101+
)
1102+
} else {
1103+
bldrInput.SetIsAwait(true)
1104+
bldrInput.SetIsReturning(true)
1105+
bldrInput.SetInsertCtx(primitiveGenerator.GetPrimitiveComposer().GetInsertPreparedStatementCtx())
1106+
lhsBldr := primitivebuilder.NewInsertOrUpdate(
1107+
bldrInput,
1108+
)
1109+
newBldrInput := builder_input.NewBuilderInput(
1110+
pgb.planGraphHolder,
1111+
handlerCtx,
1112+
tbl,
1113+
)
1114+
newBldrInput.SetParserNode(node)
1115+
newBldrInput.SetAnnotatedAST(pbi.GetAnnotatedAST())
1116+
newBldrInput.SetTxnCtrlCtrs(pbi.GetTxnCtrlCtrs())
1117+
newBldrInput.SetTableInsertionContainer(rc)
1118+
newBldrInput.SetDependencyNode(selectPrimitiveNode)
1119+
newBldrInput.SetIsAwait(isAwait)
1120+
rhsBldr := primitivebuilder.NewSingleSelect(
1121+
pgb.planGraphHolder, handlerCtx, primitiveGenerator.GetPrimitiveComposer().GetSelectPreparedStatementCtx(),
1122+
[]tableinsertioncontainer.TableInsertionContainer{rc},
1123+
nil,
1124+
streaming.NewNopMapStream(),
1125+
)
1126+
bldr = primitivebuilder.NewDependencySubDAGBuilder(
1127+
pgb.planGraphHolder,
1128+
[]primitivebuilder.Builder{lhsBldr},
1129+
rhsBldr,
1130+
)
1131+
}
1132+
} else {
1133+
bldr = primitivebuilder.NewInsertOrUpdate(
1134+
bldrInput,
1135+
)
1136+
}
10781137
err = bldr.Build()
10791138
if err != nil {
10801139
return err

internal/stackql/primitivebuilder/insert_or_update.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ func (ss *insertOrUpdate) Build() error {
4444
}
4545
case *sqlparser.Update:
4646
mutableInput.SetVerb("update")
47+
if len(node.SelectExprs) > 0 {
48+
mutableInput.SetIsReturning(true)
49+
}
4750
default:
4851
return fmt.Errorf("mutation executor: cannnot accomodate node of type '%T'", node)
4952
}

internal/stackql/primitivegenerator/statement_analyzer.go

Lines changed: 97 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,24 @@ func (pb *standardPrimitiveGenerator) buildRequestContext(
960960
meta tablemetadata.ExtendedTableMetadata,
961961
execContext anysdk.ExecContext,
962962
rowsToInsert map[int]map[int]interface{},
963+
) error {
964+
paramMapArray, paramErr := util.ExtractSQLNodeParams(node, rowsToInsert)
965+
if paramErr != nil {
966+
return paramErr
967+
}
968+
return pb.buildRequestContextFromMapArray(
969+
node,
970+
meta,
971+
execContext,
972+
paramMapArray,
973+
)
974+
}
975+
976+
func (pb *standardPrimitiveGenerator) buildRequestContextFromMapArray(
977+
node sqlparser.SQLNode,
978+
meta tablemetadata.ExtendedTableMetadata,
979+
execContext anysdk.ExecContext,
980+
paramMapArray map[int]map[string]interface{},
963981
) error {
964982
m, err := meta.GetMethod()
965983
if err != nil {
@@ -977,17 +995,13 @@ func (pb *standardPrimitiveGenerator) buildRequestContext(
977995
if prErr != nil {
978996
return prErr
979997
}
980-
paramMap, paramErr := util.ExtractSQLNodeParams(node, rowsToInsert)
981-
if paramErr != nil {
982-
return paramErr
983-
}
984998
meta.WithGetHTTPArmoury(
985999
func() (anysdk.HTTPArmoury, error) {
9861000
httpPreparator := anysdk.NewHTTPPreparator(
9871001
pr,
9881002
svc,
9891003
m,
990-
paramMap,
1004+
paramMapArray,
9911005
nil,
9921006
execContext,
9931007
logging.GetLogger(),
@@ -1179,35 +1193,101 @@ func (pb *standardPrimitiveGenerator) AnalyzeUpdate(pbi planbuilderinput.PlanBui
11791193
return nil
11801194
}
11811195

1182-
prov, err := tbl.GetProvider()
1196+
// prov, err := tbl.GetProvider()
1197+
// if err != nil {
1198+
// return err
1199+
// }
1200+
// currentService, err := tbl.GetServiceStr()
1201+
// if err != nil {
1202+
// return err
1203+
// }
1204+
// currentResource, err := tbl.GetResourceStr()
1205+
// if err != nil {
1206+
// return err
1207+
// }
1208+
1209+
pb.parseComments(node.Comments)
1210+
1211+
method, err := tbl.GetMethod()
11831212
if err != nil {
11841213
return err
11851214
}
1186-
currentService, err := tbl.GetServiceStr()
1187-
if err != nil {
1188-
return err
1215+
1216+
if pb.PrimitiveComposer.IsAwait() && !method.IsAwaitable() {
1217+
return fmt.Errorf("method %s is not awaitable", method.GetName())
11891218
}
1190-
currentResource, err := tbl.GetResourceStr()
1219+
1220+
if tbl.IsPhysicalTable() {
1221+
return nil
1222+
}
1223+
svc, err := tbl.GetService()
11911224
if err != nil {
11921225
return err
11931226
}
1194-
1195-
pb.parseComments(node.Comments)
1196-
1197-
method, err := tbl.GetMethod()
1227+
updateValOnlyRows, _, err := parserutil.ExtractUpdateValColumns(node)
11981228
if err != nil {
11991229
return err
12001230
}
1231+
firstRow := make(map[string]interface{})
1232+
for k, v := range updateValOnlyRows {
1233+
firstRow[k.GetRawVal()] = v
1234+
}
1235+
updateValOnlyRowsMap := map[int]map[string]interface{}{0: firstRow}
1236+
_, isOpenapi := svc.(anysdk.OpenAPIService)
1237+
if !isOpenapi {
1238+
err = pb.buildRequestContextFromMapArray(
1239+
node,
1240+
tbl,
1241+
nil,
1242+
updateValOnlyRowsMap,
1243+
)
1244+
if err != nil {
1245+
return err
1246+
}
1247+
return nil
1248+
}
12011249

12021250
if pb.PrimitiveComposer.IsAwait() && !method.IsAwaitable() {
12031251
return fmt.Errorf("method %s is not awaitable", method.GetName())
12041252
}
1205-
1206-
_, err = checkResource(handlerCtx, prov, currentService, currentResource)
1253+
if pb.PrimitiveComposer.IsAwait() && !method.IsAwaitable() {
1254+
return fmt.Errorf("method %s is not awaitable", method.GetName())
1255+
}
1256+
analysisInput := anysdk.NewMethodAnalysisInput(
1257+
method,
1258+
svc,
1259+
true,
1260+
[]anysdk.ColumnDescriptor{},
1261+
)
1262+
analyser := anysdk.NewMethodAnalyzer()
1263+
methodAnalysisOutput, analysisErr := analyser.AnalyzeUnaryAction(analysisInput)
1264+
if analysisErr != nil {
1265+
return analysisErr
1266+
}
1267+
err = pb.buildRequestContextFromMapArray(node, tbl, nil, updateValOnlyRowsMap)
12071268
if err != nil {
12081269
return err
12091270
}
1210-
1271+
columnHandles := []parserutil.ColumnHandle{}
1272+
if len(node.SelectExprs) > 0 {
1273+
columnHandles, err = parserutil.ExtractUpdateReturningColumnNames(node, handlerCtx.GetASTFormatter())
1274+
if err != nil {
1275+
return err
1276+
}
1277+
}
1278+
err = pb.analyzeUnaryAction(
1279+
pbi,
1280+
handlerCtx,
1281+
node,
1282+
nil,
1283+
tbl,
1284+
columnHandles,
1285+
methodAnalysisOutput,
1286+
)
1287+
if err != nil {
1288+
return err
1289+
}
1290+
pb.PrimitiveComposer.SetTable(node, tbl)
12111291
return nil
12121292
}
12131293

internal/stackql/taxonomy/hierarchy.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ func GetAliasFromStatement(node sqlparser.SQLNode) string {
125125
}
126126
}
127127

128+
func GetUpdateTargetTableName(node *sqlparser.Update) (string, error) {
129+
return getUpdateTargetTableName(node)
130+
}
131+
132+
func getUpdateTargetTableName(node *sqlparser.Update) (string, error) {
133+
if len(node.TableExprs) != 1 {
134+
return "", fmt.Errorf("update statement must have exactly one table expression, not %d", len(node.TableExprs))
135+
}
136+
switch t := node.TableExprs[0].(type) {
137+
case *sqlparser.AliasedTableExpr:
138+
switch et := t.Expr.(type) {
139+
case sqlparser.TableName:
140+
return et.GetRawVal(), nil
141+
default:
142+
return "", fmt.Errorf("update statement must have exactly one table expression, not %T", et)
143+
}
144+
default:
145+
return "", fmt.Errorf("update statement must have exactly one table expression, not %T", t)
146+
}
147+
}
148+
128149
func GetTableNameFromStatement(node sqlparser.SQLNode, formatter sqlparser.NodeFormatter) string {
129150
switch n := node.(type) {
130151
case *sqlparser.AliasedTableExpr:
@@ -138,6 +159,9 @@ func GetTableNameFromStatement(node sqlparser.SQLNode, formatter sqlparser.NodeF
138159
return n.MethodName.GetRawVal()
139160
case *sqlparser.Insert:
140161
return n.Table.GetRawVal()
162+
case *sqlparser.Update:
163+
tableName, _ := getUpdateTargetTableName(n)
164+
return tableName
141165
case *sqlparser.Delete:
142166
if len(n.TableExprs) != 1 {
143167
return astformat.String(n, formatter)

test/python/stackql_test_tooling/flask/gcp/app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import logging
33
from flask import Flask, render_template, request, jsonify
4+
import json
45

56
import os
67

@@ -14,7 +15,7 @@
1415

1516
@app.before_request
1617
def log_request_info():
17-
logger.info(f"Request: {request.method} {request.path} - Query: {request.args}")
18+
logger.info(f"Request: {request.method} {request.path} - Query: {request.args} -- Body: {request.get_data(as_text=True)}")
1819

1920
@app.route('/storage/v1/b', methods=['GET'])
2021
def v1_storage_buckets_list():
@@ -36,6 +37,12 @@ def v1_storage_buckets_insert():
3637
return render_template('buckets-insert-generic.jinja.json', bucket_name=bucket_name), 200, {'Content-Type': 'application/json'}
3738
return '{"msg": "Disallowed"}', 401, {'Content-Type': 'application/json'}
3839

40+
@app.route('/storage/v1/b/<bucket_name>', methods=['PATCH'])
41+
def v1_storage_buckets_update(bucket_name: str):
42+
body = request.get_json()
43+
labels = json.dumps(body.get('labels', {}))
44+
return render_template('buckets-update-generic.jinja.json', bucket_name=bucket_name, labels=labels), 200, {'Content-Type': 'application/json'}
45+
3946
@app.route('/compute/v1/projects/testing-project/global/networks', methods=['GET'])
4047
def projects_testing_project_global_networks():
4148
return render_template('route_27_template.json'), 200, {'Content-Type': 'application/json'}

0 commit comments

Comments
 (0)