Skip to content

Commit eada3e7

Browse files
- Manually completes witout error.
1 parent e59aa6a commit eada3e7

File tree

9 files changed

+159
-8
lines changed

9 files changed

+159
-8
lines changed

.vscode/launch.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@
150150
"select listing.Identifier as key_policy_id, listing.region from aws.cloud_control.resources listing where listing.data__TypeName = 'AWS::KMS::Key' and listing.region in ('us-east-1', 'ap-southeast-1') order by key_policy_id;",
151151
"create or replace materialized view de_gen_01 as select json_extract_path_text(detail.Properties, 'KeyPolicy', 'Id') as key_policy_id, json_extract_path_text(detail.Properties, 'Tags') as key_tags, json_extract_path_text(detail.Properties, 'KeyUsage') as key_usage, json_extract_path_text(detail.Properties, 'Origin') as key_origin, case when json_extract_path_text(detail.Properties, 'MultiRegion') = 'true' then 1 else 0 end as key_is_multi_region, detail.region from aws.cloud_control.resources listing inner join aws.cloud_control.resource detail on detail.data__Identifier = listing.Identifier and detail.region = listing.region where listing.data__TypeName = 'AWS::KMS::Key' and listing.region IN ('us-east-1', 'ap-southeast-1') and detail.data__TypeName = 'AWS::KMS::Key' order by key_policy_id ASC ; select key_policy_id, key_tags, key_usage, key_origin, key_is_multi_region, region from de_gen_01 order by key_policy_id ASC ; drop materialized view if exists de_gen_01 ;",
152152
"show methods in azure.dev_center.customization_tasks;",
153+
"set global \"a.b\"=1;",
154+
"set global \"$.auth.google.sub\"='[email protected]';",
153155
],
154156
"default": "show providers;"
155157
},

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ require (
2121
github.com/spf13/cobra v1.4.0
2222
github.com/spf13/pflag v1.0.5
2323
github.com/spf13/viper v1.10.1
24-
github.com/stackql/any-sdk v0.0.3-beta11
24+
github.com/stackql/any-sdk v0.0.3-beta16
2525
github.com/stackql/go-suffix-map v0.0.1-alpha01
2626
github.com/stackql/psql-wire v0.1.1-alpha07
2727
github.com/stackql/stackql-parser v0.0.14-alpha04

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
471471
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
472472
github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk=
473473
github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU=
474-
github.com/stackql/any-sdk v0.0.3-beta11 h1:9cqA3Rzwkkwb4kupO95sa0FK5pBvRWJW4AbqFK7u2Xk=
475-
github.com/stackql/any-sdk v0.0.3-beta11/go.mod h1:CIMFo3fC2ScpqzkzeCkzUQQuzYA1VuqpG0p1EZXN+wY=
474+
github.com/stackql/any-sdk v0.0.3-beta16 h1:G4qD2n9GS+2fWZ7V5u/OqGnuaWqY62GHj9Q2TMQ5unU=
475+
github.com/stackql/any-sdk v0.0.3-beta16/go.mod h1:CIMFo3fC2ScpqzkzeCkzUQQuzYA1VuqpG0p1EZXN+wY=
476476
github.com/stackql/go-suffix-map v0.0.1-alpha01 h1:TDUDS8bySu41Oo9p0eniUeCm43mnRM6zFEd6j6VUaz8=
477477
github.com/stackql/go-suffix-map v0.0.1-alpha01/go.mod h1:QAi+SKukOyf4dBtWy8UMy+hsXXV+yyEE4vmBkji2V7g=
478478
github.com/stackql/psql-wire v0.1.1-alpha07 h1:LQWVUlx4Bougk6dztDNG5tmXxpIVeeTSsInTj801xCs=

internal/stackql/dto/auth_ctx.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ import (
77
"strings"
88
)
99

10+
type AuthContexts map[string]*AuthCtx
11+
12+
func (as AuthContexts) Clone() AuthContexts {
13+
rv := make(AuthContexts)
14+
for k, v := range as {
15+
rv[k] = v.Clone()
16+
}
17+
return rv
18+
}
19+
1020
type AuthCtx struct {
1121
Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"`
1222
SQLCfg *SQLBackendCfg `json:"sqlDataSource" yaml:"sqlDataSource"`

internal/stackql/handler/handler.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"sync"
1010

1111
"github.com/stackql/any-sdk/anysdk"
12+
"github.com/stackql/any-sdk/pkg/jsonpath"
1213
"github.com/stackql/any-sdk/pkg/nomenclature"
1314
"github.com/stackql/stackql/internal/stackql/acid/tsm"
1415
"github.com/stackql/stackql/internal/stackql/acid/txn_context"
@@ -54,7 +55,7 @@ type HandlerContext interface { //nolint:revive // don't mind stuttering this on
5455
GetProviders() map[string]provider.IProvider
5556
GetControlAttributes() sqlcontrol.ControlAttributes
5657
GetCurrentProvider() string
57-
GetAuthContexts() map[string]*dto.AuthCtx
58+
GetAuthContexts() dto.AuthContexts
5859
GetRegistry() anysdk.RegistryAPI
5960
GetErrorPresentation() string
6061
GetOutfile() io.Writer
@@ -94,6 +95,9 @@ type HandlerContext interface { //nolint:revive // don't mind stuttering this on
9495
GetExportNamespace() string
9596

9697
GetDataFlowCfg() dto.DataFlowCfg
98+
99+
//
100+
SetConfigAtPath(path string, rhs interface{}, scope string) error
97101
}
98102

99103
type standardHandlerContext struct {
@@ -108,7 +112,7 @@ type standardHandlerContext struct {
108112
providers map[string]provider.IProvider
109113
controlAttributes sqlcontrol.ControlAttributes
110114
currentProvider string
111-
authContexts map[string]*dto.AuthCtx
115+
authContexts dto.AuthContexts
112116
sqlDataSources map[string]sql_datasource.SQLDataSource
113117
registry anysdk.RegistryAPI
114118
errorPresentation string
@@ -190,7 +194,7 @@ func (hc *standardHandlerContext) GetControlAttributes() sqlcontrol.ControlAttri
190194
}
191195
func (hc *standardHandlerContext) GetCurrentProvider() string { return hc.currentProvider }
192196

193-
func (hc *standardHandlerContext) GetAuthContexts() map[string]*dto.AuthCtx {
197+
func (hc *standardHandlerContext) GetAuthContexts() dto.AuthContexts {
194198
hc.authMapMutex.Lock()
195199
defer hc.authMapMutex.Unlock()
196200
return hc.authContexts
@@ -405,6 +409,40 @@ func (hc *standardHandlerContext) updateAuthContextIfNotExists(providerName stri
405409
hc.authContexts[providerName] = authCtx
406410
}
407411

412+
func (hc *standardHandlerContext) SetConfigAtPath(path string, rhs interface{}, scope string) error {
413+
return hc.setConfigAtPath(path, rhs, scope)
414+
}
415+
416+
func (hc *standardHandlerContext) setConfigAtPath(path string, rhs interface{}, scope string) error {
417+
searchPath, searchPathErr := composeSystemSearchPath(path)
418+
if searchPathErr != nil {
419+
return searchPathErr
420+
}
421+
system := searchPath.GetSystem()
422+
remainder := searchPath.GetRemainder()
423+
switch system {
424+
case dto.AuthCtxKey:
425+
return hc.setAuthContextAtPath(remainder, rhs, scope)
426+
default:
427+
return fmt.Errorf("system '%s' not supported", system)
428+
}
429+
}
430+
431+
func (hc *standardHandlerContext) setAuthContextAtPath(path string, rhs interface{}, scope string) error {
432+
hc.authMapMutex.Lock()
433+
defer hc.authMapMutex.Unlock()
434+
if scope == "" || scope == "default" {
435+
hc.authContexts = hc.authContexts.Clone()
436+
}
437+
searchPath, searchPathErr := composeSystemSearchPath(path)
438+
if searchPathErr != nil {
439+
return searchPathErr
440+
}
441+
authCtx := hc.authContexts[searchPath.GetSystem()]
442+
rv := jsonpath.Set(authCtx, searchPath.GetRemainder(), rhs)
443+
return rv
444+
}
445+
408446
func (hc *standardHandlerContext) GetNamespaceCollection() tablenamespace.Collection {
409447
return hc.namespaceCollection
410448
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package handler
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/stackql/any-sdk/pkg/jsonpath"
8+
)
9+
10+
type simpleSystemPathSearcher struct {
11+
system string
12+
remainder string
13+
}
14+
15+
type systemPathSearcher interface {
16+
GetSystem() string
17+
GetRemainder() string
18+
}
19+
20+
func composeSystemSearchPath(path string) (systemPathSearcher, error) {
21+
pSplit, splitErr := jsonpath.SplitSearchPath(path)
22+
if splitErr != nil {
23+
return nil, splitErr
24+
}
25+
if len(pSplit) < 1 {
26+
return nil, fmt.Errorf("path '%s' is insufficient", path)
27+
}
28+
remainder := ""
29+
if len(pSplit) > 1 {
30+
remainder = strings.TrimPrefix(path, pSplit[0]+".")
31+
}
32+
return &simpleSystemPathSearcher{
33+
system: pSplit[0],
34+
remainder: remainder,
35+
}, nil
36+
}
37+
38+
func (ss *simpleSystemPathSearcher) GetSystem() string {
39+
return ss.system
40+
}
41+
42+
func (ss *simpleSystemPathSearcher) GetRemainder() string {
43+
return ss.remainder
44+
}

internal/stackql/planbuilder/plan_builder.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package planbuilder
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"sort"
67
"strconv"
78
"strings"
89

910
"github.com/stackql/any-sdk/anysdk"
11+
1012
"github.com/stackql/stackql/internal/stackql/acid/txn_context"
1113
"github.com/stackql/stackql/internal/stackql/astanalysis/routeanalysis"
1214
"github.com/stackql/stackql/internal/stackql/handler"
@@ -123,7 +125,7 @@ func (pgb *standardPlanGraphBuilder) createInstructionFor(pbi planbuilderinput.P
123125
_, _, err := pgb.handleSelect(pbi)
124126
return err
125127
case *sqlparser.Set:
126-
return pgb.nop(pbi)
128+
return pgb.handleSet(pbi)
127129
case *sqlparser.SetTransaction:
128130
return pgb.nop(pbi)
129131
case *sqlparser.Show:
@@ -159,6 +161,46 @@ func (pgb *standardPlanGraphBuilder) nop(pbi planbuilderinput.PlanBuilderInput)
159161
return err
160162
}
161163

164+
func setLogic(pbi planbuilderinput.PlanBuilderInput, setExpr *sqlparser.SetExpr) error {
165+
lhsRaw := setExpr.Name.GetRawVal()
166+
lhsTrimmed := strings.TrimPrefix(lhsRaw, "$.")
167+
if lhsTrimmed == lhsRaw {
168+
return nil
169+
}
170+
exprStr := strings.Trim(sqlparser.String(setExpr.Expr), "'")
171+
exprObj := map[string]interface{}{}
172+
deserErr := json.Unmarshal([]byte(exprStr), &exprObj)
173+
if deserErr != nil {
174+
rawRv := pbi.GetHandlerCtx().SetConfigAtPath(lhsTrimmed, exprStr, setExpr.Scope)
175+
return rawRv
176+
}
177+
rv := pbi.GetHandlerCtx().SetConfigAtPath(lhsTrimmed, exprObj, setExpr.Scope)
178+
return rv
179+
}
180+
181+
func (pgb *standardPlanGraphBuilder) handleSet(pbi planbuilderinput.PlanBuilderInput) error {
182+
primitiveGenerator := pgb.rootPrimitiveGenerator
183+
err := primitiveGenerator.AnalyzeStatement(pbi)
184+
if err != nil {
185+
return err
186+
}
187+
setStmt, ok := pbi.GetSet()
188+
if !ok {
189+
return fmt.Errorf("could not cast node of type '%T' to required Set", pbi.GetStatement())
190+
}
191+
if len(setStmt.Exprs) < 1 {
192+
return fmt.Errorf("no set expressions found")
193+
}
194+
pr := primitive.NewLocalPrimitive(
195+
//nolint:revive // acceptable for now
196+
func(pc primitive.IPrimitiveCtx) internaldto.ExecutorOutput {
197+
err = setLogic(pbi, setStmt.Exprs[0])
198+
return internaldto.NewExecutorOutput(nil, nil, nil, nil, err)
199+
})
200+
pgb.planGraphHolder.GetPrimitiveGraph().CreatePrimitiveNode(pr)
201+
return nil
202+
}
203+
162204
func (pgb *standardPlanGraphBuilder) pgInternal(pbi planbuilderinput.PlanBuilderInput) error {
163205
primitiveGenerator := pgb.rootPrimitiveGenerator
164206
err := primitiveGenerator.AnalyzePGInternal(pbi)

internal/stackql/planbuilderinput/plan_builder_input.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type PlanBuilderInput interface {
4545
GetUnion() (*sqlparser.Union, bool)
4646
GetUpdate() (*sqlparser.Update, bool)
4747
GetUse() (*sqlparser.Use, bool)
48+
GetSet() (*sqlparser.Set, bool)
4849
IsTccSetAheadOfTime() bool
4950
SetIsTccSetAheadOfTime(bool)
5051
SetPrepStmtOffset(int)
@@ -384,6 +385,11 @@ func (pbi *StandardPlanBuilderInput) GetUse() (*sqlparser.Use, bool) {
384385
return rv, ok
385386
}
386387

388+
func (pbi *StandardPlanBuilderInput) GetSet() (*sqlparser.Set, bool) {
389+
rv, ok := pbi.stmt.(*sqlparser.Set)
390+
return rv, ok
391+
}
392+
387393
func (pbi *StandardPlanBuilderInput) GetUpdate() (*sqlparser.Update, bool) {
388394
rv, ok := pbi.stmt.(*sqlparser.Update)
389395
return rv, ok

internal/stackql/primitivegenerator/statement_analyzer.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (pb *standardPrimitiveGenerator) AnalyzeStatement(
8282
case *sqlparser.Select:
8383
return pb.analyzeSelect(pbi)
8484
case *sqlparser.Set:
85-
return iqlerror.GetStatementNotSupportedError("SET")
85+
return pb.analyzeSet(pbi)
8686
case *sqlparser.SetTransaction:
8787
return iqlerror.GetStatementNotSupportedError("SET TRANSACTION")
8888
case *sqlparser.Show:
@@ -118,6 +118,15 @@ func (pb *standardPrimitiveGenerator) analyzeUse(
118118
return nil
119119
}
120120

121+
func (pb *standardPrimitiveGenerator) analyzeSet(
122+
pbi planbuilderinput.PlanBuilderInput) error {
123+
_, ok := pbi.GetSet()
124+
if !ok {
125+
return fmt.Errorf("could not cast node of type '%T' to required Set", pbi.GetStatement())
126+
}
127+
return nil
128+
}
129+
121130
//nolint:govet,funlen // this is a beast
122131
func (pb *standardPrimitiveGenerator) analyzeUnion(
123132
pbi planbuilderinput.PlanBuilderInput) error {

0 commit comments

Comments
 (0)