diff --git a/dql/parser.go b/dql/parser.go
index 0dcde37c853..19ffdc30e13 100644
--- a/dql/parser.go
+++ b/dql/parser.go
@@ -2691,7 +2691,7 @@ func validKeyAtRoot(k string) bool {
 	switch k {
 	case "func", "orderasc", "orderdesc", "first", "offset", "after":
 		return true
-	case "from", "to", "numpaths", "minweight", "maxweight":
+	case "from", "to", "numpaths", "minweight", "maxweight", "maxfrontiersize":
 		// Specific to shortest path
 		return true
 	case "depth":
diff --git a/dql/parser_test.go b/dql/parser_test.go
index 11913eb7e25..d764fe9060e 100644
--- a/dql/parser_test.go
+++ b/dql/parser_test.go
@@ -1302,7 +1302,7 @@ func TestParseQueryWithMultipleVar(t *testing.T) {
 func TestParseShortestPath(t *testing.T) {
 	query := `
 	{
-		shortest(from:0x0a, to:0x0b, numpaths: 3, minweight: 3, maxweight: 6) {
+		shortest(from:0x0a, to:0x0b, numpaths: 3, minweight: 3, maxweight: 6, maxfrontiersize: 1) {
 			friends
 			name
 		}
@@ -1317,6 +1317,7 @@ func TestParseShortestPath(t *testing.T) {
 	require.Equal(t, "3", res.Query[0].Args["numpaths"])
 	require.Equal(t, "3", res.Query[0].Args["minweight"])
 	require.Equal(t, "6", res.Query[0].Args["maxweight"])
+	require.Equal(t, "1", res.Query[0].Args["maxfrontiersize"])
 }
 
 func TestParseShortestPathWithUidVars(t *testing.T) {
diff --git a/query/query.go b/query/query.go
index 7084fffd3f2..46422cbaa92 100644
--- a/query/query.go
+++ b/query/query.go
@@ -165,6 +165,10 @@ type params struct {
 	MaxWeight float64
 	// MinWeight is the min weight allowed in a path returned by the shortest path algorithm.
 	MinWeight float64
+	// MaxFrontierSize limits the number of candidate paths stored in the priority queue
+	// during shortest path computation. This prevents out-of-memory errors on large graphs
+	// but may affect solution optimality if set too low. Must be >= 1; unset means no limit.
+	MaxFrontierSize int64
 	// ExploreDepth is used by recurse and shortest path queries to specify the maximum graph
 	// depth to explore.
@@ -714,6 +718,21 @@ func (args *params) fill(gq *dql.GraphQuery) error {
 			args.MinWeight = -math.MaxFloat64
 		}
 
+		if v, ok := gq.Args["maxfrontiersize"]; ok {
+			maxfrontiersize, err := strconv.ParseInt(v, 0, 64)
+			if err != nil {
+				return err
+			}
+			// A limit below 1 would evict from an empty frontier in the shortest path loops.
+			if maxfrontiersize < 1 {
+				return errors.Errorf("maxfrontiersize must be at least 1, got: %d", maxfrontiersize)
+			}
+			args.MaxFrontierSize = maxfrontiersize
+		} else {
+			// No limit requested: leave the frontier effectively unbounded.
+			args.MaxFrontierSize = math.MaxInt64
+		}
+
 		if gq.ShortestPathArgs.From == nil || gq.ShortestPathArgs.To == nil {
 			return errors.Errorf("from/to can't be nil for shortest path")
 		}
@@ -2640,7 +2659,7 @@ func (sg *SubGraph) sortAndPaginateUsingVar(ctx context.Context) error {
 func isValidArg(a string) bool {
 	switch a {
 	case "numpaths", "from", "to", "orderasc", "orderdesc", "first", "offset", "after", "depth",
-		"minweight", "maxweight":
+		"minweight", "maxweight", "maxfrontiersize":
 		return true
 	}
 	return false
diff --git a/query/shortest.go b/query/shortest.go
index e6bb16661c7..12a107a4a43 100644
--- a/query/shortest.go
+++ b/query/shortest.go
@@ -405,6 +405,11 @@ func runKShortestPaths(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) {
 					hop:  item.hop + 1,
 					path: route{route: curPath},
 				}
+				// Make room before pushing so the frontier never exceeds MaxFrontierSize.
+				// NOTE(review): pq.Pop() is not heap.Pop; confirm which entry it evicts.
+				if int64(pq.Len()) >= sg.Params.MaxFrontierSize {
+					pq.Pop()
+				}
 				heap.Push(&pq, node)
 			}
 			// Return the popped nodes path to pool.
@@ -558,6 +563,11 @@ func shortestPath(ctx context.Context, sg *SubGraph) ([]*SubGraph, error) {
 						cost: nodeCost,
 						hop:  item.hop + 1,
 					}
+					// Make room before pushing so the frontier never exceeds MaxFrontierSize.
+					// NOTE(review): pq.Pop() is not heap.Pop; confirm which entry it evicts.
+					if int64(pq.Len()) >= sg.Params.MaxFrontierSize {
+						pq.Pop()
+					}
 					heap.Push(&pq, node)
 				} else {
 					// We've already seen this node. So, just update the cost
diff --git a/systest/shortest-path/graph.rdf.gz b/systest/shortest-path/graph.rdf.gz
new file mode 100644
index 00000000000..3efc95fa89a
Binary files /dev/null and b/systest/shortest-path/graph.rdf.gz differ
diff --git a/systest/shortest-path/graph.schema.gz b/systest/shortest-path/graph.schema.gz
new file mode 100644
index 00000000000..7dfb1a6d2d4
Binary files /dev/null and b/systest/shortest-path/graph.schema.gz differ
diff --git a/systest/shortest-path/shortest_test.go b/systest/shortest-path/shortest_test.go
new file mode 100644
index 00000000000..c1a50de72e9
--- /dev/null
+++ b/systest/shortest-path/shortest_test.go
@@ -0,0 +1,63 @@
+//go:build integration2
+
+/*
+ * SPDX-FileCopyrightText: © Hypermode Inc.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package main
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/hypermodeinc/dgraph/v25/dgraphapi"
+	"github.com/hypermodeinc/dgraph/v25/dgraphtest"
+	"github.com/hypermodeinc/dgraph/v25/x"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestShortestPath live-loads the bundled graph and checks that a shortest path query
+// with numpaths and a bounded maxfrontiersize completes without returning an error.
+func TestShortestPath(t *testing.T) {
+	conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
+	c, err := dgraphtest.NewLocalCluster(conf)
+	require.NoError(t, err)
+	defer func() { c.Cleanup(t.Failed()) }()
+	require.NoError(t, c.Start())
+
+	err = c.LiveLoad(dgraphtest.LiveOpts{
+		DataFiles:      []string{"graph.rdf.gz"},
+		SchemaFiles:    []string{"graph.schema.gz"},
+		GqlSchemaFiles: []string{},
+	})
+	require.NoError(t, err)
+
+	gc, cleanup, err := c.Client()
+	require.NoError(t, err)
+	defer cleanup()
+	require.NoError(t, gc.LoginIntoNamespace(context.Background(),
+		dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))
+
+	_, err = gc.Query(`
+	{
+		q(func: eq(guid, "85270d10-560e-4cc8-8703-4b4c563a2f4e")) {
+			a as uid
+		}
+		q1(func: eq(guid, "4a520068-80b6-42f2-9019-4e6ef8a02bb3")) {
+			b as uid
+		}
+
+		path as shortest(from: uid(a), to: uid(b), numpaths: 5, maxfrontiersize: 10000) {
+			connected_to @facets(weight)
+		}
+
+		path(func: uid(path)) {
+			uid
+		}
+	}
+	`)
+	require.NoError(t, err)
+}