diff --git a/.github/workflows/ci-dgraph-oss-build.yml b/.github/workflows/ci-dgraph-oss-build.yml deleted file mode 100644 index 93904d1385b..00000000000 --- a/.github/workflows/ci-dgraph-oss-build.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: ci-dgraph-oss-build - -on: - pull_request: - paths: - - "**/*.go" - - "**/go.mod" - - "**/*.yml" - - "**/Dockerfile" - - "**/Makefile" - types: - - opened - - reopened - - synchronize - - ready_for_review - branches: - - main - - release/** - -permissions: - contents: read - -jobs: - dgraph-oss-build: - if: github.event.pull_request.draft == false - runs-on: warp-ubuntu-latest-x64-4x - timeout-minutes: 10 - steps: - - uses: actions/checkout@v4 - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version-file: go.mod - - name: Make OSS Linux Build - run: make oss # verify that we can make OSS build - - name: Run OSS Unit Tests - run: | - #!/bin/bash - # go env settings - export GOPATH=~/go - # move the binary - cp dgraph/dgraph ~/go/bin/dgraph - # run OSS unit tests - go test -v -timeout=60m -failfast -tags=oss -count=1 ./... 
diff --git a/Makefile b/Makefile index a3734ff631a..0d0e10dcdfc 100644 --- a/Makefile +++ b/Makefile @@ -30,10 +30,6 @@ dgraph: dgraph-coverage: $(MAKE) -w -C dgraph test-coverage-binary -.PHONY: oss -oss: - $(MAKE) BUILD_TAGS=oss - .PHONY: version version: @echo Dgraph: ${BUILD_VERSION} @@ -48,10 +44,6 @@ install: @echo "Installing Dgraph..."; \ $(MAKE) -C dgraph install; \ -.PHONY: install_oss oss_install -install_oss oss_install: - $(MAKE) BUILD_TAGS=oss,jemalloc install - .PHONY: uninstall uninstall: @echo "Uninstalling Dgraph ..."; \ @@ -105,7 +97,6 @@ help: @echo @echo Build commands: @echo " make [all] - Build all targets [EE]" - @echo " make oss - Build all targets [OSS]" @echo " make dgraph - Build dgraph binary" @echo " make install - Install all targets" @echo " make uninstall - Uninstall known targets" diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index 5b996e1f742..12f58796675 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "google.golang.org/grpc/metadata" jsonpb "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/hypermodeinc/dgraph/v24/dql" @@ -133,6 +134,50 @@ func parseDuration(r *http.Request, name string) (time.Duration, error) { return durationValue, nil } +func loginHandler(w http.ResponseWriter, r *http.Request) { + if commonHandler(w, r) { + return + } + + // Pass in PoorMan's auth, IP information if present. 
+ ctx := x.AttachRemoteIP(context.Background(), r) + ctx = x.AttachAuthToken(ctx, r) + + body := readRequest(w, r) + loginReq := api.LoginRequest{} + if err := json.Unmarshal(body, &loginReq); err != nil { + x.SetStatusWithData(w, x.Error, err.Error()) + return + } + + resp, err := (&edgraph.Server{}).Login(ctx, &loginReq) + if err != nil { + x.SetStatusWithData(w, x.ErrorInvalidRequest, err.Error()) + return + } + + jwt := &api.Jwt{} + if err := proto.Unmarshal(resp.Json, jwt); err != nil { + x.SetStatusWithData(w, x.Error, err.Error()) + } + + response := map[string]interface{}{} + mp := make(map[string]string) + mp["accessJWT"] = jwt.AccessJwt + mp["refreshJWT"] = jwt.RefreshJwt + response["data"] = mp + + js, err := json.Marshal(response) + if err != nil { + x.SetStatusWithData(w, x.Error, err.Error()) + return + } + + if _, err := x.WriteResponse(w, r, js); err != nil { + glog.Errorf("Error while writing response: %v", err) + } +} + // This method should just build the request and proxy it to the Query method of dgraph.Server. // It can then encode the response as appropriate before sending it back to the user. func queryHandler(w http.ResponseWriter, r *http.Request) { diff --git a/dgraph/cmd/alpha/login_ee.go b/dgraph/cmd/alpha/login_ee.go deleted file mode 100644 index 38cc1614e4c..00000000000 --- a/dgraph/cmd/alpha/login_ee.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package alpha - -import ( - "context" - "encoding/json" - "net/http" - - "github.com/golang/glog" - "google.golang.org/protobuf/proto" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/hypermodeinc/dgraph/v24/edgraph" - "github.com/hypermodeinc/dgraph/v24/x" -) - -func loginHandler(w http.ResponseWriter, r *http.Request) { - if commonHandler(w, r) { - return - } - - // Pass in PoorMan's auth, IP information if present. 
- ctx := x.AttachRemoteIP(context.Background(), r) - ctx = x.AttachAuthToken(ctx, r) - - body := readRequest(w, r) - loginReq := api.LoginRequest{} - if err := json.Unmarshal(body, &loginReq); err != nil { - x.SetStatusWithData(w, x.Error, err.Error()) - return - } - - resp, err := (&edgraph.Server{}).Login(ctx, &loginReq) - if err != nil { - x.SetStatusWithData(w, x.ErrorInvalidRequest, err.Error()) - return - } - - jwt := &api.Jwt{} - if err := proto.Unmarshal(resp.Json, jwt); err != nil { - x.SetStatusWithData(w, x.Error, err.Error()) - } - - response := map[string]interface{}{} - mp := make(map[string]string) - mp["accessJWT"] = jwt.AccessJwt - mp["refreshJWT"] = jwt.RefreshJwt - response["data"] = mp - - js, err := json.Marshal(response) - if err != nil { - x.SetStatusWithData(w, x.Error, err.Error()) - return - } - - if _, err := x.WriteResponse(w, r, js); err != nil { - glog.Errorf("Error while writing response: %v", err) - } -} - -func init() { - http.HandleFunc("/login", loginHandler) -} diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index d0355cb6f97..297e97aff16 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -500,6 +500,7 @@ func setupServer(closer *z.Closer) { baseMux := http.NewServeMux() http.Handle("/", audit.AuditRequestHttp(baseMux)) + http.HandleFunc("/login", loginHandler) baseMux.HandleFunc("/query", queryHandler) baseMux.HandleFunc("/query/", queryHandler) baseMux.HandleFunc("/mutate", mutationHandler) diff --git a/dgraph/cmd/root.go b/dgraph/cmd/root.go index 4554ed6d136..03909f871bd 100644 --- a/dgraph/cmd/root.go +++ b/dgraph/cmd/root.go @@ -34,6 +34,9 @@ import ( "github.com/hypermodeinc/dgraph/v24/dgraph/cmd/migrate" "github.com/hypermodeinc/dgraph/v24/dgraph/cmd/version" "github.com/hypermodeinc/dgraph/v24/dgraph/cmd/zero" + "github.com/hypermodeinc/dgraph/v24/ee/acl" + "github.com/hypermodeinc/dgraph/v24/ee/audit" + "github.com/hypermodeinc/dgraph/v24/ee/backup" 
"github.com/hypermodeinc/dgraph/v24/upgrade" "github.com/hypermodeinc/dgraph/v24/x" ) @@ -74,7 +77,8 @@ var rootConf = viper.New() var subcommands = []*x.SubCommand{ &bulk.Bulk, &cert.Cert, &conv.Conv, &live.Live, &alpha.Alpha, &zero.Zero, &version.Version, &debug.Debug, &migrate.Migrate, &debuginfo.DebugInfo, &upgrade.Upgrade, &decrypt.Decrypt, &increment.Increment, - &checkupgrade.CheckUpgrade, + &checkupgrade.CheckUpgrade, &backup.Restore, &backup.LsBackup, &backup.ExportBackup, &acl.CmdAcl, + &audit.CmdAudit, } func initCmds() { diff --git a/dgraph/cmd/root_ee.go b/dgraph/cmd/root_ee.go deleted file mode 100644 index 8e98ad03805..00000000000 --- a/dgraph/cmd/root_ee.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package cmd - -import ( - acl "github.com/hypermodeinc/dgraph/v24/ee/acl" - "github.com/hypermodeinc/dgraph/v24/ee/audit" - "github.com/hypermodeinc/dgraph/v24/ee/backup" -) - -func init() { - // subcommands already has the default subcommands, we append to EE ones to that. - subcommands = append(subcommands, - &backup.Restore, - &backup.LsBackup, - &backup.ExportBackup, - &acl.CmdAcl, - &audit.CmdAudit, - ) -} diff --git a/edgraph/access.go b/edgraph/access.go index 406d2e032aa..4ea51f3b35a 100644 --- a/edgraph/access.go +++ b/edgraph/access.go @@ -1,6 +1,3 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. 
* SPDX-License-Identifier: Apache-2.0 @@ -10,85 +7,1411 @@ package edgraph import ( "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + "github.com/golang-jwt/jwt/v5" "github.com/golang/glog" + "github.com/pkg/errors" + otrace "go.opencensus.io/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + bpb "github.com/dgraph-io/badger/v4/pb" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/ristretto/v2/z" "github.com/hypermodeinc/dgraph/v24/dql" + "github.com/hypermodeinc/dgraph/v24/ee/acl" + "github.com/hypermodeinc/dgraph/v24/protos/pb" "github.com/hypermodeinc/dgraph/v24/query" + "github.com/hypermodeinc/dgraph/v24/schema" + "github.com/hypermodeinc/dgraph/v24/worker" "github.com/hypermodeinc/dgraph/v24/x" ) -// Login handles login requests from clients. This version rejects all requests -// since ACL is only supported in the enterprise version. +type predsAndvars struct { + preds []string + vars map[string]string +} + +// Login handles login requests from clients. 
func (s *Server) Login(ctx context.Context, request *api.LoginRequest) (*api.Response, error) { + + if !shouldAllowAcls(request.GetNamespace()) { + return nil, errors.New("operation is not allowed in shared cloud mode") + } + if err := x.HealthCheck(); err != nil { return nil, err } - glog.Warningf("Login failed: %s", x.ErrNotSupported) - return &api.Response{}, x.ErrNotSupported + ctx, span := otrace.StartSpan(ctx, "server.Login") + defer span.End() + + // record the client ip for this login request + var addr string + if ipAddr, err := hasAdminAuth(ctx, "Login"); err != nil { + return nil, err + } else { + addr = ipAddr.String() + span.Annotate([]otrace.Attribute{ + otrace.StringAttribute("client_ip", addr), + }, "client ip for login") + } + + user, err := s.authenticateLogin(ctx, request) + if err != nil { + glog.Errorf("Authentication from address %s failed: %v", addr, err) + return nil, x.ErrorInvalidLogin + } + glog.Infof("%s logged in successfully", user.UserID) + + resp := &api.Response{} + accessJwt, err := getAccessJwt(user.UserID, user.Groups, user.Namespace) + if err != nil { + errMsg := fmt.Sprintf("unable to get access jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) + glog.Errorf(errMsg) + return nil, errors.Errorf(errMsg) + } + + refreshJwt, err := getRefreshJwt(user.UserID, user.Namespace) + if err != nil { + errMsg := fmt.Sprintf("unable to get refresh jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) + glog.Errorf(errMsg) + return nil, errors.Errorf(errMsg) + } + + loginJwt := api.Jwt{ + AccessJwt: accessJwt, + RefreshJwt: refreshJwt, + } + + jwtBytes, err := proto.Marshal(&loginJwt) + if err != nil { + errMsg := fmt.Sprintf("unable to marshal jwt (userid=%s,addr=%s):%v", + user.UserID, addr, err) + glog.Errorf(errMsg) + return nil, errors.Errorf(errMsg) + } + resp.Json = jwtBytes + return resp, nil } -// ResetAcl is an empty method since ACL is only supported in the enterprise version. 
-func InitializeAcl(closer *z.Closer) { - // do nothing +// authenticateLogin authenticates the login request using either the refresh token if present, or +// the pair. If authentication passes, it queries the user's uid and associated +// groups from DB and returns the user object +func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginRequest) (*acl.User, error) { + if err := validateLoginRequest(request); err != nil { + return nil, errors.Wrapf(err, "invalid login request") + } + + var user *acl.User + if len(request.RefreshToken) > 0 { + userData, err := validateToken(request.RefreshToken) + if err != nil { + return nil, errors.Wrapf(err, "unable to authenticate the refresh token %v", + request.RefreshToken) + } + + userId := userData.userId + ctx = x.AttachNamespace(ctx, userData.namespace) + user, err = authorizeUser(ctx, userId, "") + if err != nil { + return nil, errors.Wrapf(err, "while querying user with id %v", userId) + } + + if user == nil { + return nil, errors.Errorf("unable to authenticate: invalid credentials") + } + + user.Namespace = userData.namespace + glog.Infof("Authenticated user %s through refresh token", userId) + return user, nil + } + + // In case of login, we can't extract namespace from JWT because we have not yet given JWT + // to the user, so the login request should contain the namespace, which is then set to ctx. 
+ ctx = x.AttachNamespace(ctx, request.Namespace) + + // authorize the user using password + var err error + user, err = authorizeUser(ctx, request.Userid, request.Password) + if err != nil { + return nil, errors.Wrapf(err, "while querying user with id %v", + request.Userid) + } + + if user == nil { + return nil, errors.Errorf("unable to authenticate: invalid credentials") + } + if !user.PasswordMatch { + return nil, x.ErrorInvalidLogin + } + user.Namespace = request.Namespace + return user, nil } -func upsertGuardianAndGroot(closer *z.Closer, ns uint64) { - // do nothing +type userData struct { + namespace uint64 + userId string + groupIds []string +} + +// validateToken verifies the signature and expiration of the jwt, and if validation passes, +// returns a slice of strings, where the first element is the extracted userId +// and the rest are groupIds encoded in the jwt. +func validateToken(jwtStr string) (*userData, error) { + claims, err := x.ParseJWT(jwtStr) + if err != nil { + return nil, err + } + // by default, the MapClaims.Valid will return true if the exp field is not set + // here we enforce the checking to make sure that the refresh token has not expired + if exp, err := claims.GetExpirationTime(); err != nil || exp == nil { + return nil, errors.Errorf("Token is expired") // the same error msg that's used inside jwt-go + } + + userId, ok := claims["userid"].(string) + if !ok { + return nil, errors.Errorf("userid in claims is not a string:%v", userId) + } + + /* + * Since, JSON numbers follow JavaScript's double-precision floating-point + * format . . . + * -- references: https://restfulapi.net/json-data-types/ + * -- https://www.tutorialspoint.com/json/json_data_types.htm + * . . . and fraction in IEEE 754 double precision binary floating-point + * format has 52 bits, . . . + * -- references: https://en.wikipedia.org/wiki/Double-precision_floating-point_format + * . . . 
the namespace field of the struct userData below can + * only accomodate a maximum value of (1 << 52) despite it being declared as + * uint64. Numbers bigger than this are likely to fail the test. + */ + namespace, ok := claims["namespace"].(float64) + if !ok { + return nil, errors.Errorf("namespace in claims is not valid:%v", namespace) + } + + groups, ok := claims["groups"].([]interface{}) + var groupIds []string + if ok { + groupIds = make([]string, 0, len(groups)) + for _, group := range groups { + groupId, ok := group.(string) + if !ok { + // This shouldn't happen. So, no need to make the client try to refresh the tokens. + return nil, errors.Errorf("unable to convert group to string:%v", group) + } + + groupIds = append(groupIds, groupId) + } + } + return &userData{namespace: uint64(namespace), userId: userId, groupIds: groupIds}, nil +} + +// validateLoginRequest validates that the login request has either the refresh token or the +// pair +func validateLoginRequest(request *api.LoginRequest) error { + if request == nil { + return errors.Errorf("the request should not be nil") + } + // we will use the refresh token for authentication if it's set + if len(request.RefreshToken) > 0 { + return nil + } + + // otherwise make sure both userid and password are set + if len(request.Userid) == 0 { + return errors.Errorf("the userid should not be empty") + } + if len(request.Password) == 0 { + return errors.Errorf("the password should not be empty") + } + return nil +} + +// getAccessJwt constructs an access jwt with the given user id, groupIds, namespace +// and expiration TTL specified by worker.Config.AccessJwtTtl +func getAccessJwt(userId string, groups []acl.Group, namespace uint64) (string, error) { + token := jwt.NewWithClaims(worker.Config.AclJwtAlg, jwt.MapClaims{ + "userid": userId, + "groups": acl.GetGroupIDs(groups), + "namespace": namespace, + // set the jwt exp according to the ttl + "exp": time.Now().Add(worker.Config.AccessJwtTtl).Unix(), + }) + + 
jwtString, err := token.SignedString(x.MaybeKeyToBytes(worker.Config.AclSecretKey)) + if err != nil { + return "", errors.Errorf("unable to encode jwt to string: %v", err) + } + return jwtString, nil +} + +// getRefreshJwt constructs a refresh jwt with the given user id, namespace and expiration ttl +// specified by worker.Config.RefreshJwtTtl +func getRefreshJwt(userId string, namespace uint64) (string, error) { + token := jwt.NewWithClaims(worker.Config.AclJwtAlg, jwt.MapClaims{ + "userid": userId, + "namespace": namespace, + "exp": time.Now().Add(worker.Config.RefreshJwtTtl).Unix(), + }) + + jwtString, err := token.SignedString(x.MaybeKeyToBytes(worker.Config.AclSecretKey)) + if err != nil { + return "", errors.Errorf("unable to encode jwt to string: %v", err) + } + return jwtString, nil +} + +const queryUser = ` + query search($userid: string, $password: string){ + user(func: eq(dgraph.xid, $userid)) @filter(type(dgraph.type.User)) { + uid + dgraph.xid + password_match: checkpwd(dgraph.password, $password) + dgraph.user.group { + uid + dgraph.xid + } + } + }` + +// authorizeUser queries the user with the given user id, and returns the associated uid, +// acl groups, and whether the password stored in DB matches the supplied password +func authorizeUser(ctx context.Context, userid string, password string) ( + *acl.User, error) { + + queryVars := map[string]string{ + "$userid": userid, + "$password": password, + } + req := &Request{ + req: &api.Request{ + Query: queryUser, + Vars: queryVars, + }, + doAuth: NoAuthorize, + } + queryResp, err := (&Server{}).doQuery(ctx, req) + if err != nil { + glog.Errorf("Error while query user with id %s: %v", userid, err) + return nil, err + } + user, err := acl.UnmarshalUser(queryResp, "user") + if err != nil { + return nil, err + } + return user, nil +} + +func refreshAclCache(ctx context.Context, ns, refreshTs uint64) error { + req := &Request{ + req: &api.Request{ + Query: queryAcls, + ReadOnly: true, + StartTs: refreshTs, + 
}, + doAuth: NoAuthorize, + } + + ctx = x.AttachNamespace(ctx, ns) + queryResp, err := (&Server{}).doQuery(ctx, req) + if err != nil { + return errors.Errorf("unable to retrieve acls: %v", err) + } + groups, err := acl.UnmarshalGroups(queryResp.GetJson(), "allAcls") + if err != nil { + return err + } + + worker.AclCachePtr.Update(ns, groups) + glog.V(2).Infof("Updated the ACL cache for namespace: %#x", ns) + return nil + +} + +func RefreshACLs(ctx context.Context) { + for ns := range schema.State().Namespaces() { + if err := refreshAclCache(ctx, ns, 0); err != nil { + glog.Errorf("Error while retrieving acls for namespace %#x: %v", ns, err) + } + } + worker.AclCachePtr.Set() } -// SubscribeForAclUpdates is an empty method since ACL is only supported in the enterprise version. +// SubscribeForAclUpdates subscribes for ACL predicates and updates the acl cache. func SubscribeForAclUpdates(closer *z.Closer) { - // do nothing + defer func() { + glog.Infoln("RefreshAcls closed") + closer.Done() + }() + if worker.Config.AclSecretKey == nil { + // the acl feature is not turned on + return + } + + var maxRefreshTs uint64 + retrieveAcls := func(ns uint64, refreshTs uint64) error { + if refreshTs <= maxRefreshTs { + return nil + } + maxRefreshTs = refreshTs + return refreshAclCache(closer.Ctx(), ns, refreshTs) + } + + closer.AddRunning(1) + go worker.SubscribeForUpdates(aclPrefixes, x.IgnoreBytes, func(kvs *bpb.KVList) { + if kvs == nil || len(kvs.Kv) == 0 { + return + } + kv := x.KvWithMaxVersion(kvs, aclPrefixes) + pk, err := x.Parse(kv.GetKey()) + if err != nil { + glog.Fatalf("Got a key from subscription which is not parsable: %s", err) + } + glog.V(3).Infof("Got ACL update via subscription for attr: %s", pk.Attr) + + ns, _ := x.ParseNamespaceAttr(pk.Attr) + if err := retrieveAcls(ns, kv.GetVersion()); err != nil { + glog.Errorf("Error while retrieving acls: %v", err) + } + }, 1, closer) + <-closer.HasBeenClosed() - closer.Done() } -// RefreshACLs is an empty method since 
ACL is only supported in the enterprise version. -func RefreshACLs(ctx context.Context) { - return +const queryAcls = ` +{ + allAcls(func: type(dgraph.type.Group)) { + dgraph.xid + dgraph.acl.rule { + dgraph.rule.predicate + dgraph.rule.permission + } + ~dgraph.user.group{ + dgraph.xid + } + } } +` -func authorizeAlter(ctx context.Context, op *api.Operation) error { +var aclPrefixes = [][]byte{ + x.PredicatePrefix(x.GalaxyAttr("dgraph.rule.permission")), + x.PredicatePrefix(x.GalaxyAttr("dgraph.rule.predicate")), + x.PredicatePrefix(x.GalaxyAttr("dgraph.acl.rule")), + x.PredicatePrefix(x.GalaxyAttr("dgraph.user.group")), + x.PredicatePrefix(x.GalaxyAttr("dgraph.type.Group")), + x.PredicatePrefix(x.GalaxyAttr("dgraph.xid")), +} + +// upserts the Groot account. +func InitializeAcl(closer *z.Closer) { + defer func() { + glog.Infof("InitializeAcl closed") + closer.Done() + }() + + if worker.Config.AclSecretKey == nil { + // The acl feature is not turned on. + return + } + upsertGuardianAndGroot(closer, x.GalaxyNamespace) +} + +// Note: The handling of closer should be done by caller. +func upsertGuardianAndGroot(closer *z.Closer, ns uint64) { + if worker.Config.AclSecretKey == nil { + // The acl feature is not turned on. + return + } + for closer.Ctx().Err() == nil { + ctx, cancel := context.WithTimeout(closer.Ctx(), time.Minute) + defer cancel() + ctx = x.AttachNamespace(ctx, ns) + if err := upsertGuardian(ctx); err != nil { + glog.Infof("Unable to upsert the guardian group. Error: %v", err) + time.Sleep(100 * time.Millisecond) + continue + } + break + } + + for closer.Ctx().Err() == nil { + ctx, cancel := context.WithTimeout(closer.Ctx(), time.Minute) + defer cancel() + ctx = x.AttachNamespace(ctx, ns) + if err := upsertGroot(ctx, "password"); err != nil { + glog.Infof("Unable to upsert the groot account. Error: %v", err) + time.Sleep(100 * time.Millisecond) + continue + } + break + } +} + +// upsertGuardian must be called after setting the namespace in the context. 
+func upsertGuardian(ctx context.Context) error { + query := fmt.Sprintf(` + { + guid as guardians(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.Group)) { + uid + } + } + `, x.GuardiansId) + groupNQuads := acl.CreateGroupNQuads(x.GuardiansId) + req := &Request{ + req: &api.Request{ + CommitNow: true, + Query: query, + Mutations: []*api.Mutation{ + { + Set: groupNQuads, + Cond: "@if(eq(len(guid), 0))", + }, + }, + }, + doAuth: NoAuthorize, + } + + resp, err := (&Server{}).doQuery(ctx, req) + + // Structs to parse guardians group uid from query response + type groupNode struct { + Uid string `json:"uid"` + } + + type groupQryResp struct { + GuardiansGroup []groupNode `json:"guardians"` + } + + if err != nil { + return errors.Wrapf(err, "while upserting group with id %s", x.GuardiansId) + } + var groupResp groupQryResp + var guardiansUidStr string + if err := json.Unmarshal(resp.GetJson(), &groupResp); err != nil { + return errors.Wrap(err, "Couldn't unmarshal response from guardians group query") + } + + if len(groupResp.GuardiansGroup) == 0 { + // no guardians group found + // Extract guardians group uid from mutation + newGroupUidMap := resp.GetUids() + guardiansUidStr = newGroupUidMap["newgroup"] + } else if len(groupResp.GuardiansGroup) == 1 { + // we found a guardians group + guardiansUidStr = groupResp.GuardiansGroup[0].Uid + } else { + return errors.Wrap(err, "Multiple guardians group found") + } + + uid, err := strconv.ParseUint(guardiansUidStr, 0, 64) + if err != nil { + return errors.Wrapf(err, "Error while parsing Uid: %s of guardians Group", guardiansUidStr) + } + ns, err := x.ExtractNamespace(ctx) + if err != nil { + return errors.Wrapf(err, "While upserting group with id %s", x.GuardiansId) + } + x.GuardiansUid.Store(ns, uid) + glog.V(2).Infof("Successfully upserted the guardian of namespace: %d\n", ns) return nil } -func authorizeMutation(ctx context.Context, gmu *dql.Mutation) error { +// upsertGroot must be called after setting the namespace 
in the context. +func upsertGroot(ctx context.Context, passwd string) error { + // groot is the default user of guardians group. + query := fmt.Sprintf(` + { + grootid as grootUser(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.User)) { + uid + } + guid as var(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.Group)) + } + `, x.GrootId, x.GuardiansId) + userNQuads := acl.CreateUserNQuads(x.GrootId, passwd) + userNQuads = append(userNQuads, &api.NQuad{ + Subject: "_:newuser", + Predicate: "dgraph.user.group", + ObjectId: "uid(guid)", + }) + req := &Request{ + req: &api.Request{ + CommitNow: true, + Query: query, + Mutations: []*api.Mutation{ + { + Set: userNQuads, + // Assuming that if groot exists, it is in guardian group + Cond: "@if(eq(len(grootid), 0) and gt(len(guid), 0))", + }, + }, + }, + doAuth: NoAuthorize, + } + + resp, err := (&Server{}).doQuery(ctx, req) + if err != nil { + return errors.Wrapf(err, "while upserting user with id %s", x.GrootId) + } + + // Structs to parse groot user uid from query response + type userNode struct { + Uid string `json:"uid"` + } + + type userQryResp struct { + GrootUser []userNode `json:"grootUser"` + } + + var grootUserUid string + var userResp userQryResp + if err := json.Unmarshal(resp.GetJson(), &userResp); err != nil { + return errors.Wrap(err, "Couldn't unmarshal response from groot user query") + } + if len(userResp.GrootUser) == 0 { + // no groot user found from query + // Extract uid of created groot user from mutation + newUserUidMap := resp.GetUids() + grootUserUid = newUserUidMap["newuser"] + } else if len(userResp.GrootUser) == 1 { + // we found a groot user + grootUserUid = userResp.GrootUser[0].Uid + } else { + return errors.Wrap(err, "Multiple groot users found") + } + + uid, err := strconv.ParseUint(grootUserUid, 0, 64) + if err != nil { + return errors.Wrapf(err, "Error while parsing Uid: %s of groot user", grootUserUid) + } + ns, err := x.ExtractNamespace(ctx) + if err != nil { + return 
errors.Wrapf(err, "While upserting user with id %s", x.GrootId) + } + x.GrootUid.Store(ns, uid) + glog.V(2).Infof("Successfully upserted groot account for namespace %d\n", ns) return nil } +// extract the userId, groupIds from the accessJwt in the context +func extractUserAndGroups(ctx context.Context) (*userData, error) { + accessJwt, err := x.ExtractJwt(ctx) + if err != nil { + return nil, err + } + return validateToken(accessJwt) +} + +type authPredResult struct { + allowed []string + blocked map[string]struct{} +} + +func authorizePreds(ctx context.Context, userData *userData, preds []string, + aclOp *acl.Operation) *authPredResult { + + if !worker.AclCachePtr.Loaded() { + RefreshACLs(ctx) + } + + userId := userData.userId + groupIds := userData.groupIds + ns := userData.namespace + blockedPreds := make(map[string]struct{}) + for _, pred := range preds { + nsPred := x.NamespaceAttr(ns, pred) + if err := worker.AclCachePtr.AuthorizePredicate(groupIds, nsPred, aclOp); err != nil { + logAccess(&accessEntry{ + userId: userId, + groups: groupIds, + preds: preds, + operation: aclOp, + allowed: false, + }) + blockedPreds[pred] = struct{}{} + } + } + if worker.HasAccessToAllPreds(ns, groupIds, aclOp) { + // Setting allowed to nil allows access to all predicates. Note that the access to ACL + // predicates will still be blocked. 
+ return &authPredResult{allowed: nil, blocked: blockedPreds} + } + // User can have multiple permission for same predicate, add predicate + allowedPreds := make([]string, 0, len(worker.AclCachePtr.GetUserPredPerms(userId))) + // only if the acl.Op is covered in the set of permissions for the user + for predicate, perm := range worker.AclCachePtr.GetUserPredPerms(userId) { + if (perm & aclOp.Code) > 0 { + allowedPreds = append(allowedPreds, predicate) + } + } + return &authPredResult{allowed: allowedPreds, blocked: blockedPreds} +} + +// authorizeAlter parses the Schema in the operation and authorizes the operation +// using the worker.AclCachePtr. It will return error if any one of the predicates +// specified in alter are not authorized. +func authorizeAlter(ctx context.Context, op *api.Operation) error { + if worker.Config.AclSecretKey == nil { + // the user has not turned on the acl feature + return nil + } + + // extract the list of predicates from the operation object + var preds []string + switch { + case len(op.DropAttr) > 0: + preds = []string{op.DropAttr} + case op.DropOp == api.Operation_ATTR && len(op.DropValue) > 0: + preds = []string{op.DropValue} + default: + update, err := schema.Parse(op.Schema) + if err != nil { + return err + } + + for _, u := range update.Preds { + preds = append(preds, x.ParseAttr(u.Predicate)) + } + } + var userId string + var groupIds []string + + // doAuthorizeAlter checks if alter of all the predicates are allowed + // as a byproduct, it also sets the userId, groups variables + doAuthorizeAlter := func() error { + userData, err := extractUserAndGroups(ctx) + if err != nil { + // We don't follow fail open approach anymore. + return status.Error(codes.Unauthenticated, err.Error()) + } + + userId = userData.userId + groupIds = userData.groupIds + + if x.IsGuardian(groupIds) { + // Members of guardian group are allowed to alter anything. + return nil + } + + // if we get here, we know the user is not a guardian. 
+ if isDropAll(op) || op.DropOp == api.Operation_DATA { + return errors.Errorf( + "only guardians are allowed to drop all data, but the current user is %s", userId) + } + + result := authorizePreds(ctx, userData, preds, acl.Modify) + if len(result.blocked) > 0 { + var msg strings.Builder + for key := range result.blocked { + x.Check2(msg.WriteString(key)) + x.Check2(msg.WriteString(" ")) + } + return status.Errorf(codes.PermissionDenied, + "unauthorized to alter following predicates: %s\n", msg.String()) + } + return nil + } + + err := doAuthorizeAlter() + span := otrace.FromContext(ctx) + if span != nil { + span.Annotatef(nil, (&accessEntry{ + userId: userId, + groups: groupIds, + preds: preds, + operation: acl.Modify, + allowed: err == nil, + }).String()) + } + + return err +} + +// parsePredsFromMutation returns a union set of all the predicate names in the input nquads +func parsePredsFromMutation(nquads []*api.NQuad) []string { + // use a map to dedup predicates + predsMap := make(map[string]struct{}) + for _, nquad := range nquads { + // _STAR_ALL is not a predicate in itself. 
+ if nquad.Predicate != "_STAR_ALL" { + predsMap[nquad.Predicate] = struct{}{} + } + } + + preds := make([]string, 0, len(predsMap)) + for pred := range predsMap { + preds = append(preds, pred) + } + + return preds +} + +func isAclPredMutation(nquads []*api.NQuad) bool { + for _, nquad := range nquads { + if nquad.Predicate == "dgraph.group.acl" && nquad.ObjectValue != nil { + // this mutation is trying to change the permission of some predicate + // check if the predicate list contains an ACL predicate + if _, ok := nquad.ObjectValue.Val.(*api.Value_BytesVal); ok { + aclBytes := nquad.ObjectValue.Val.(*api.Value_BytesVal) + var aclsToChange []acl.Acl + err := json.Unmarshal(aclBytes.BytesVal, &aclsToChange) + if err != nil { + glog.Errorf(fmt.Sprintf("Unable to unmarshal bytes under the dgraph.group.acl "+ + "predicate: %v", err)) + continue + } + for _, aclToChange := range aclsToChange { + if x.IsAclPredicate(aclToChange.Predicate) { + return true + } + } + } + } + } + return false +} + +// authorizeMutation authorizes the mutation using the worker.AclCachePtr. It will return permission +// denied error if any one of the predicates in mutation(set or delete) is unauthorized. +// At this stage, namespace is not attached in the predicates. +func authorizeMutation(ctx context.Context, gmu *dql.Mutation) error { + if worker.Config.AclSecretKey == nil { + // the user has not turned on the acl feature + return nil + } + + preds := parsePredsFromMutation(gmu.Set) + // Del predicates weren't included before. + // A bug probably since f115de2eb6a40d882a86c64da68bf5c2a33ef69a + preds = append(preds, parsePredsFromMutation(gmu.Del)...) + + var userId string + var groupIds []string + // doAuthorizeMutation checks if modification of all the predicates are allowed + // as a byproduct, it also sets the userId and groups + doAuthorizeMutation := func() error { + userData, err := extractUserAndGroups(ctx) + if err != nil { + // We don't follow fail open approach anymore. 
+ return status.Error(codes.Unauthenticated, err.Error()) + } + + userId = userData.userId + groupIds = userData.groupIds + + if x.IsGuardian(groupIds) { + // Members of guardians group are allowed to mutate anything + // (including delete) except the permission of the acl predicates. + switch { + case isAclPredMutation(gmu.Set): + return errors.Errorf("the permission of ACL predicates can not be changed") + case isAclPredMutation(gmu.Del): + return errors.Errorf("ACL predicates can't be deleted") + } + if !shouldAllowAcls(userData.namespace) { + for _, pred := range preds { + if x.IsAclPredicate(pred) { + return status.Errorf(codes.PermissionDenied, + "unauthorized to mutate acl predicates: %s\n", pred) + } + } + } + return nil + } + result := authorizePreds(ctx, userData, preds, acl.Write) + if len(result.blocked) > 0 { + var msg strings.Builder + for key := range result.blocked { + x.Check2(msg.WriteString(key)) + x.Check2(msg.WriteString(" ")) + } + return status.Errorf(codes.PermissionDenied, + "unauthorized to mutate following predicates: %s\n", msg.String()) + } + gmu.AllowedPreds = result.allowed + return nil + } + + err := doAuthorizeMutation() + + span := otrace.FromContext(ctx) + if span != nil { + span.Annotatef(nil, (&accessEntry{ + userId: userId, + groups: groupIds, + preds: preds, + operation: acl.Write, + allowed: err == nil, + }).String()) + } + + return err +} + +func parsePredsFromQuery(dqls []*dql.GraphQuery) predsAndvars { + predsMap := make(map[string]struct{}) + varsMap := make(map[string]string) + for _, gq := range dqls { + if gq.Func != nil { + predsMap[gq.Func.Attr] = struct{}{} + } + if len(gq.Var) > 0 { + varsMap[gq.Var] = gq.Attr + } + if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" && gq.Attr != "val" { + predsMap[gq.Attr] = struct{}{} + + } + for _, ord := range gq.Order { + predsMap[ord.Attr] = struct{}{} + } + for _, gbAttr := range gq.GroupbyAttrs { + predsMap[gbAttr.Attr] = struct{}{} + } + for _, pred := range 
parsePredsFromFilter(gq.Filter) { + predsMap[pred] = struct{}{} + } + childPredandVars := parsePredsFromQuery(gq.Children) + for _, childPred := range childPredandVars.preds { + predsMap[childPred] = struct{}{} + } + for childVar := range childPredandVars.vars { + varsMap[childVar] = childPredandVars.vars[childVar] + } + } + preds := make([]string, 0, len(predsMap)) + for pred := range predsMap { + if len(pred) > 0 { + if _, found := varsMap[pred]; !found { + preds = append(preds, pred) + } + } + } + + pv := predsAndvars{preds: preds, vars: varsMap} + return pv +} + +func parsePredsFromFilter(f *dql.FilterTree) []string { + var preds []string + if f == nil { + return preds + } + if f.Func != nil && len(f.Func.Attr) > 0 { + preds = append(preds, f.Func.Attr) + } + for _, ch := range f.Child { + preds = append(preds, parsePredsFromFilter(ch)...) + } + return preds +} + +type accessEntry struct { + userId string + groups []string + preds []string + operation *acl.Operation + allowed bool +} + +func (log *accessEntry) String() string { + return fmt.Sprintf("ACL-LOG Authorizing user %q with groups %q on predicates %q "+ + "for %q, allowed:%v", log.userId, strings.Join(log.groups, ","), + strings.Join(log.preds, ","), log.operation.Name, log.allowed) +} + +func logAccess(log *accessEntry) { + if glog.V(1) { + glog.Info(log.String()) + } +} + +func blockedPreds(preds []string) map[string]struct{} { + blocked := make(map[string]struct{}) + for _, pred := range preds { + if x.IsAclPredicate(pred) { + blocked[pred] = struct{}{} + } + } + return blocked +} + +// With shared instance enabled, we don't allow ACL operations from any of the non-galaxy namespace. +func shouldAllowAcls(ns uint64) bool { + return !x.Config.SharedInstance || ns == x.GalaxyNamespace +} + +// authorizeQuery authorizes the query using the aclCachePtr. It will silently drop all +// unauthorized predicates from query. +// At this stage, namespace is not attached in the predicates. 
func authorizeQuery(ctx context.Context, parsedReq *dql.Result, graphql bool) error { - // always allow access + if worker.Config.AclSecretKey == nil { + // the user has not turned on the acl feature + return nil + } + + var userId string + var groupIds []string + var namespace uint64 + predsAndvars := parsePredsFromQuery(parsedReq.Query) + preds := predsAndvars.preds + varsToPredMap := predsAndvars.vars + + // Need this to efficiently identify blocked variables from the + // list of blocked predicates + predToVarsMap := make(map[string]string) + for k, v := range varsToPredMap { + predToVarsMap[v] = k + } + + doAuthorizeQuery := func() (map[string]struct{}, []string, error) { + userData, err := extractUserAndGroups(ctx) + if err != nil { + return nil, nil, status.Error(codes.Unauthenticated, err.Error()) + } + + userId = userData.userId + groupIds = userData.groupIds + namespace = userData.namespace + + if x.IsGuardian(groupIds) { + if shouldAllowAcls(userData.namespace) { + // Members of guardian groups are allowed to query anything. + return nil, nil, nil + } + return blockedPreds(preds), nil, nil + } + + result := authorizePreds(ctx, userData, preds, acl.Read) + return result.blocked, result.allowed, nil + } + + blockedPreds, allowedPreds, err := doAuthorizeQuery() + if err != nil { + return err + } + + if span := otrace.FromContext(ctx); span != nil { + span.Annotatef(nil, (&accessEntry{ + userId: userId, + groups: groupIds, + preds: preds, + operation: acl.Read, + allowed: err == nil, + }).String()) + } + + if len(blockedPreds) != 0 { + // For GraphQL requests, we allow filtered access to the ACL predicates. + // Filter for user_id and group_id is applied for the currently logged in user. + if graphql && shouldAllowAcls(namespace) { + for _, gq := range parsedReq.Query { + addUserFilterToQuery(gq, userId, groupIds) + } + // blockedPreds might have acl predicates which we want to allow access through + // graphql, so deleting those from here. 
+ for _, pred := range x.AllACLPredicates() { + delete(blockedPreds, pred) + } + // In query context ~predicate and predicate are considered different. + delete(blockedPreds, "~dgraph.user.group") + } + + blockedVars := make(map[string]struct{}) + for predicate := range blockedPreds { + if variable, found := predToVarsMap[predicate]; found { + // Add variables to blockedPreds to delete from Query + blockedPreds[variable] = struct{}{} + // Collect blocked Variables to remove from QueryVars + blockedVars[variable] = struct{}{} + } + } + parsedReq.Query = removePredsFromQuery(parsedReq.Query, blockedPreds) + parsedReq.QueryVars = removeVarsFromQueryVars(parsedReq.QueryVars, blockedVars) + } + for i := range parsedReq.Query { + parsedReq.Query[i].AllowedPreds = allowedPreds + } + return nil } func authorizeSchemaQuery(ctx context.Context, er *query.ExecutionResult) error { - // always allow schema access + if worker.Config.AclSecretKey == nil { + // the user has not turned on the acl feature + return nil + } + + // find the predicates being sent in response + preds := make([]string, 0) + predsMap := make(map[string]struct{}) + for _, predNode := range er.SchemaNode { + preds = append(preds, predNode.Predicate) + predsMap[predNode.Predicate] = struct{}{} + } + for _, typeNode := range er.Types { + for _, field := range typeNode.Fields { + if _, ok := predsMap[field.Predicate]; !ok { + preds = append(preds, field.Predicate) + } + } + } + + doAuthorizeSchemaQuery := func() (map[string]struct{}, error) { + userData, err := extractUserAndGroups(ctx) + if err != nil { + return nil, status.Error(codes.Unauthenticated, err.Error()) + } + + groupIds := userData.groupIds + if x.IsGuardian(groupIds) { + if shouldAllowAcls(userData.namespace) { + // Members of guardian groups are allowed to query anything. 
+ return nil, nil + } + return blockedPreds(preds), nil + } + result := authorizePreds(ctx, userData, preds, acl.Read) + return result.blocked, nil + } + + // find the predicates which are blocked for the schema query + blockedPreds, err := doAuthorizeSchemaQuery() + if err != nil { + return err + } + + // remove those predicates from response + if len(blockedPreds) > 0 { + respPreds := make([]*pb.SchemaNode, 0) + for _, predNode := range er.SchemaNode { + if _, ok := blockedPreds[predNode.Predicate]; !ok { + respPreds = append(respPreds, predNode) + } + } + er.SchemaNode = respPreds + + for _, typeNode := range er.Types { + respFields := make([]*pb.SchemaUpdate, 0) + for _, field := range typeNode.Fields { + if _, ok := blockedPreds[field.Predicate]; !ok { + respFields = append(respFields, field) + } + } + typeNode.Fields = respFields + } + } + return nil } -func AuthorizeGuardians(ctx context.Context) error { - // always allow access +// AuthGuardianOfTheGalaxy authorizes the operations for the users who belong to the guardians +// group in the galaxy namespace. This authorization is used for admin usages like creation and +// deletion of a namespace, resetting passwords across namespaces etc. +// NOTE: The caller should not wrap the error returned. If needed, propagate the GRPC error code. +func AuthGuardianOfTheGalaxy(ctx context.Context) error { + if !x.WorkerConfig.AclEnabled { + return nil + } + ns, err := x.ExtractNamespaceFrom(ctx) + if err != nil { + return errors.Wrap(err, "Authorize guardian of the galaxy, extracting jwt token, error:") + } + if ns != 0 { + return status.Error( + codes.PermissionDenied, "Only guardian of galaxy is allowed to do this operation") + } + // AuthorizeGuardians will extract (user, []groups) from the JWT claims and will check if + // any of the group to which the user belongs is "guardians" or not. 
+	if err := AuthorizeGuardians(ctx); err != nil {
+		s := status.Convert(err)
+		return status.Error(
+			s.Code(), "AuthGuardianOfTheGalaxy: failed to authorize guardians. "+s.Message())
+	}
+	glog.V(3).Info("Successfully authorised guardian of the galaxy")
 	return nil
 }
 
-func AuthGuardianOfTheGalaxy(ctx context.Context) error {
-	// always allow access
+// AuthorizeGuardians authorizes the operation for users which belong to Guardians group.
+// NOTE: The caller should not wrap the error returned. If needed, propagate the GRPC error code.
+func AuthorizeGuardians(ctx context.Context) error {
+	if worker.Config.AclSecretKey == nil {
+		// the user has not turned on the acl feature
+		return nil
+	}
+
+	userData, err := extractUserAndGroups(ctx)
+	switch {
+	case err == x.ErrNoJwt:
+		return status.Error(codes.PermissionDenied, err.Error())
+	case err != nil:
+		return status.Error(codes.Unauthenticated, err.Error())
+	default:
+		userId := userData.userId
+		groupIds := userData.groupIds
+
+		if !x.IsGuardian(groupIds) {
+			// Deny access for members of non-guardian groups
+			return status.Error(codes.PermissionDenied, fmt.Sprintf("Only guardians are "+
+				"allowed access. User '%v' is not a member of guardians group.", userId))
+		}
+	}
+	return nil
 }
 
-func validateToken(jwtStr string) ([]string, error) {
-	return nil, nil
+/*
+addUserFilterToQuery makes sure that a user can access only its own
+acl info by applying filter of userid and groupid to acl predicates. 
A query like
+Conversion pattern:
+  - me(func: type(dgraph.type.Group)) ->
+    me(func: type(dgraph.type.Group)) @filter(eq("dgraph.xid", groupIds...))
+  - me(func: type(dgraph.type.User)) ->
+    me(func: type(dgraph.type.User)) @filter(eq("dgraph.xid", userId))
+*/
+func addUserFilterToQuery(gq *dql.GraphQuery, userId string, groupIds []string) {
+	if gq.Func != nil && gq.Func.Name == "type" {
+		// type function only supports one argument
+		if len(gq.Func.Args) != 1 {
+			return
+		}
+		arg := gq.Func.Args[0]
+		// The case where value of some variable v (say) is "dgraph.type.Group" and a
+		// query comes like `eq(dgraph.type, val(v))`, will be ignored here.
+		if arg.Value == "dgraph.type.User" {
+			newFilter := userFilter(userId)
+			gq.Filter = parentFilter(newFilter, gq.Filter)
+		} else if arg.Value == "dgraph.type.Group" {
+			newFilter := groupFilter(groupIds)
+			gq.Filter = parentFilter(newFilter, gq.Filter)
+		}
+	}
+
+	gq.Filter = addUserFilterToFilter(gq.Filter, userId, groupIds)
+
+	switch gq.Attr {
+	case "dgraph.user.group":
+		newFilter := groupFilter(groupIds)
+		gq.Filter = parentFilter(newFilter, gq.Filter)
+	case "~dgraph.user.group":
+		newFilter := userFilter(userId)
+		gq.Filter = parentFilter(newFilter, gq.Filter)
+	}
+
+	for _, ch := range gq.Children {
+		addUserFilterToQuery(ch, userId, groupIds)
+	}
 }
 
-func upsertGuardian(ctx context.Context) error {
-	return nil
+func parentFilter(newFilter, filter *dql.FilterTree) *dql.FilterTree {
+	if filter == nil {
+		return newFilter
+	}
+	parentFilter := &dql.FilterTree{
+		Op:    "AND",
+		Child: []*dql.FilterTree{filter, newFilter},
+	}
+	return parentFilter
 }
 
-func upsertGroot(ctx context.Context) error {
-	return nil
+func userFilter(userId string) *dql.FilterTree {
+	// A logged in user should always have a userId. 
+	return &dql.FilterTree{
+		Func: &dql.Function{
+			Attr: "dgraph.xid",
+			Name: "eq",
+			Args: []dql.Arg{{Value: userId}},
+		},
+	}
+}
+
+func groupFilter(groupIds []string) *dql.FilterTree {
+	// The user doesn't have any groups, so add an empty filter @filter(uid([])) so that all
+	// groups are filtered out.
+	if len(groupIds) == 0 {
+		filter := &dql.FilterTree{
+			Func: &dql.Function{
+				Name: "uid",
+				UID:  []uint64{},
+			},
+		}
+		return filter
+	}
+
+	filter := &dql.FilterTree{
+		Func: &dql.Function{
+			Attr: "dgraph.xid",
+			Name: "eq",
+		},
+	}
+
+	for _, gid := range groupIds {
+		filter.Func.Args = append(filter.Func.Args,
+			dql.Arg{Value: gid})
+	}
+
+	return filter
+}
+
+/*
+	addUserFilterToFilter makes sure that a user can't misuse filters to access other users' info.
+	If the *filter* has type(dgraph.type.Group) or type(dgraph.type.User) functions,
+	it generates a *newFilter* with a function like eq(dgraph.xid, userId) or eq(dgraph.xid,groupId...)
+	and returns a filter of the form
+
+	&dql.FilterTree{
+		Op: "AND",
+		Child: []dql.FilterTree{
+			{filter, newFilter}
+		}
+	}
+*/
+func addUserFilterToFilter(filter *dql.FilterTree, userId string,
+	groupIds []string) *dql.FilterTree {
+
+	if filter == nil {
+		return nil
+	}
+
+	if filter.Func != nil && filter.Func.Name == "type" {
+
+		// type function supports only one argument
+		if len(filter.Func.Args) != 1 {
+			return nil
+		}
+		arg := filter.Func.Args[0]
+		var newFilter *dql.FilterTree
+		switch arg.Value {
+		case "dgraph.type.User":
+			newFilter = userFilter(userId)
+		case "dgraph.type.Group":
+			newFilter = groupFilter(groupIds)
+		}
+
+		// If the filter has a function, it can't have children.
+		return parentFilter(newFilter, filter)
+	}
+
+	for idx, child := range filter.Child {
+		filter.Child[idx] = addUserFilterToFilter(child, userId, groupIds)
+	}
+
+	return filter
+}
+
+// removePredsFromQuery removes all the predicates in blockedPreds
+// from all the queries in gqs. 
+func removePredsFromQuery(gqs []*dql.GraphQuery, + blockedPreds map[string]struct{}) []*dql.GraphQuery { + + filteredGQs := gqs[:0] +L: + for _, gq := range gqs { + if gq.Func != nil && len(gq.Func.Attr) > 0 { + if _, ok := blockedPreds[gq.Func.Attr]; ok { + continue + } + } + if len(gq.Attr) > 0 { + if _, ok := blockedPreds[gq.Attr]; ok { + continue + } + if gq.Attr == "val" { + // TODO (Anurag): If val supports multiple variables, this would + // need an upgrade + for _, variable := range gq.NeedsVar { + if _, ok := blockedPreds[variable.Name]; ok { + continue L + } + } + } + } + + order := gq.Order[:0] + for _, ord := range gq.Order { + if _, ok := blockedPreds[ord.Attr]; ok { + continue + } + order = append(order, ord) + } + + gq.Order = order + gq.Filter = removeFilters(gq.Filter, blockedPreds) + gq.GroupbyAttrs = removeGroupBy(gq.GroupbyAttrs, blockedPreds) + gq.Children = removePredsFromQuery(gq.Children, blockedPreds) + filteredGQs = append(filteredGQs, gq) + } + + return filteredGQs +} + +func removeVarsFromQueryVars(gqs []*dql.Vars, + blockedVars map[string]struct{}) []*dql.Vars { + + filteredGQs := gqs[:0] + for _, gq := range gqs { + var defines []string + var needs []string + for _, variable := range gq.Defines { + if _, ok := blockedVars[variable]; !ok { + defines = append(defines, variable) + } + } + for _, variable := range gq.Needs { + if _, ok := blockedVars[variable]; !ok { + needs = append(needs, variable) + } + } + gq.Defines = defines + gq.Needs = needs + filteredGQs = append(filteredGQs, gq) + } + return filteredGQs +} + +func removeFilters(f *dql.FilterTree, blockedPreds map[string]struct{}) *dql.FilterTree { + if f == nil { + return nil + } + if f.Func != nil && len(f.Func.Attr) > 0 { + if _, ok := blockedPreds[f.Func.Attr]; ok { + return nil + } + } + + filteredChildren := f.Child[:0] + for _, ch := range f.Child { + child := removeFilters(ch, blockedPreds) + if child != nil { + filteredChildren = append(filteredChildren, child) + } + } + 
if len(filteredChildren) != len(f.Child) && (f.Op == "AND" || f.Op == "NOT") { + return nil + } + f.Child = filteredChildren + return f +} + +func removeGroupBy(gbAttrs []dql.GroupByAttr, + blockedPreds map[string]struct{}) []dql.GroupByAttr { + + filteredGbAttrs := gbAttrs[:0] + for _, gbAttr := range gbAttrs { + if _, ok := blockedPreds[gbAttr.Attr]; ok { + continue + } + filteredGbAttrs = append(filteredGbAttrs, gbAttr) + } + return filteredGbAttrs } diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go deleted file mode 100644 index 0e04907f105..00000000000 --- a/edgraph/access_ee.go +++ /dev/null @@ -1,1419 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package edgraph - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/golang/glog" - "github.com/pkg/errors" - otrace "go.opencensus.io/trace" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - - bpb "github.com/dgraph-io/badger/v4/pb" - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/ristretto/v2/z" - "github.com/hypermodeinc/dgraph/v24/dql" - "github.com/hypermodeinc/dgraph/v24/ee/acl" - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/query" - "github.com/hypermodeinc/dgraph/v24/schema" - "github.com/hypermodeinc/dgraph/v24/worker" - "github.com/hypermodeinc/dgraph/v24/x" -) - -type predsAndvars struct { - preds []string - vars map[string]string -} - -// Login handles login requests from clients. 
-func (s *Server) Login(ctx context.Context, - request *api.LoginRequest) (*api.Response, error) { - - if !shouldAllowAcls(request.GetNamespace()) { - return nil, errors.New("operation is not allowed in shared cloud mode") - } - - if err := x.HealthCheck(); err != nil { - return nil, err - } - - ctx, span := otrace.StartSpan(ctx, "server.Login") - defer span.End() - - // record the client ip for this login request - var addr string - if ipAddr, err := hasAdminAuth(ctx, "Login"); err != nil { - return nil, err - } else { - addr = ipAddr.String() - span.Annotate([]otrace.Attribute{ - otrace.StringAttribute("client_ip", addr), - }, "client ip for login") - } - - user, err := s.authenticateLogin(ctx, request) - if err != nil { - glog.Errorf("Authentication from address %s failed: %v", addr, err) - return nil, x.ErrorInvalidLogin - } - glog.Infof("%s logged in successfully", user.UserID) - - resp := &api.Response{} - accessJwt, err := getAccessJwt(user.UserID, user.Groups, user.Namespace) - if err != nil { - errMsg := fmt.Sprintf("unable to get access jwt (userid=%s,addr=%s):%v", - user.UserID, addr, err) - glog.Errorf(errMsg) - return nil, errors.Errorf(errMsg) - } - - refreshJwt, err := getRefreshJwt(user.UserID, user.Namespace) - if err != nil { - errMsg := fmt.Sprintf("unable to get refresh jwt (userid=%s,addr=%s):%v", - user.UserID, addr, err) - glog.Errorf(errMsg) - return nil, errors.Errorf(errMsg) - } - - loginJwt := api.Jwt{ - AccessJwt: accessJwt, - RefreshJwt: refreshJwt, - } - - jwtBytes, err := proto.Marshal(&loginJwt) - if err != nil { - errMsg := fmt.Sprintf("unable to marshal jwt (userid=%s,addr=%s):%v", - user.UserID, addr, err) - glog.Errorf(errMsg) - return nil, errors.Errorf(errMsg) - } - resp.Json = jwtBytes - return resp, nil -} - -// authenticateLogin authenticates the login request using either the refresh token if present, or -// the pair. 
If authentication passes, it queries the user's uid and associated -// groups from DB and returns the user object -func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginRequest) (*acl.User, error) { - if err := validateLoginRequest(request); err != nil { - return nil, errors.Wrapf(err, "invalid login request") - } - - var user *acl.User - if len(request.RefreshToken) > 0 { - userData, err := validateToken(request.RefreshToken) - if err != nil { - return nil, errors.Wrapf(err, "unable to authenticate the refresh token %v", - request.RefreshToken) - } - - userId := userData.userId - ctx = x.AttachNamespace(ctx, userData.namespace) - user, err = authorizeUser(ctx, userId, "") - if err != nil { - return nil, errors.Wrapf(err, "while querying user with id %v", userId) - } - - if user == nil { - return nil, errors.Errorf("unable to authenticate: invalid credentials") - } - - user.Namespace = userData.namespace - glog.Infof("Authenticated user %s through refresh token", userId) - return user, nil - } - - // In case of login, we can't extract namespace from JWT because we have not yet given JWT - // to the user, so the login request should contain the namespace, which is then set to ctx. 
- ctx = x.AttachNamespace(ctx, request.Namespace) - - // authorize the user using password - var err error - user, err = authorizeUser(ctx, request.Userid, request.Password) - if err != nil { - return nil, errors.Wrapf(err, "while querying user with id %v", - request.Userid) - } - - if user == nil { - return nil, errors.Errorf("unable to authenticate: invalid credentials") - } - if !user.PasswordMatch { - return nil, x.ErrorInvalidLogin - } - user.Namespace = request.Namespace - return user, nil -} - -type userData struct { - namespace uint64 - userId string - groupIds []string -} - -// validateToken verifies the signature and expiration of the jwt, and if validation passes, -// returns a slice of strings, where the first element is the extracted userId -// and the rest are groupIds encoded in the jwt. -func validateToken(jwtStr string) (*userData, error) { - claims, err := x.ParseJWT(jwtStr) - if err != nil { - return nil, err - } - // by default, the MapClaims.Valid will return true if the exp field is not set - // here we enforce the checking to make sure that the refresh token has not expired - if exp, err := claims.GetExpirationTime(); err != nil || exp == nil { - return nil, errors.Errorf("Token is expired") // the same error msg that's used inside jwt-go - } - - userId, ok := claims["userid"].(string) - if !ok { - return nil, errors.Errorf("userid in claims is not a string:%v", userId) - } - - /* - * Since, JSON numbers follow JavaScript's double-precision floating-point - * format . . . - * -- references: https://restfulapi.net/json-data-types/ - * -- https://www.tutorialspoint.com/json/json_data_types.htm - * . . . and fraction in IEEE 754 double precision binary floating-point - * format has 52 bits, . . . - * -- references: https://en.wikipedia.org/wiki/Double-precision_floating-point_format - * . . . the namespace field of the struct userData below can - * only accomodate a maximum value of (1 << 52) despite it being declared as - * uint64. 
Numbers bigger than this are likely to fail the test. - */ - namespace, ok := claims["namespace"].(float64) - if !ok { - return nil, errors.Errorf("namespace in claims is not valid:%v", namespace) - } - - groups, ok := claims["groups"].([]interface{}) - var groupIds []string - if ok { - groupIds = make([]string, 0, len(groups)) - for _, group := range groups { - groupId, ok := group.(string) - if !ok { - // This shouldn't happen. So, no need to make the client try to refresh the tokens. - return nil, errors.Errorf("unable to convert group to string:%v", group) - } - - groupIds = append(groupIds, groupId) - } - } - return &userData{namespace: uint64(namespace), userId: userId, groupIds: groupIds}, nil -} - -// validateLoginRequest validates that the login request has either the refresh token or the -// pair -func validateLoginRequest(request *api.LoginRequest) error { - if request == nil { - return errors.Errorf("the request should not be nil") - } - // we will use the refresh token for authentication if it's set - if len(request.RefreshToken) > 0 { - return nil - } - - // otherwise make sure both userid and password are set - if len(request.Userid) == 0 { - return errors.Errorf("the userid should not be empty") - } - if len(request.Password) == 0 { - return errors.Errorf("the password should not be empty") - } - return nil -} - -// getAccessJwt constructs an access jwt with the given user id, groupIds, namespace -// and expiration TTL specified by worker.Config.AccessJwtTtl -func getAccessJwt(userId string, groups []acl.Group, namespace uint64) (string, error) { - token := jwt.NewWithClaims(worker.Config.AclJwtAlg, jwt.MapClaims{ - "userid": userId, - "groups": acl.GetGroupIDs(groups), - "namespace": namespace, - // set the jwt exp according to the ttl - "exp": time.Now().Add(worker.Config.AccessJwtTtl).Unix(), - }) - - jwtString, err := token.SignedString(x.MaybeKeyToBytes(worker.Config.AclSecretKey)) - if err != nil { - return "", errors.Errorf("unable to encode 
jwt to string: %v", err) - } - return jwtString, nil -} - -// getRefreshJwt constructs a refresh jwt with the given user id, namespace and expiration ttl -// specified by worker.Config.RefreshJwtTtl -func getRefreshJwt(userId string, namespace uint64) (string, error) { - token := jwt.NewWithClaims(worker.Config.AclJwtAlg, jwt.MapClaims{ - "userid": userId, - "namespace": namespace, - "exp": time.Now().Add(worker.Config.RefreshJwtTtl).Unix(), - }) - - jwtString, err := token.SignedString(x.MaybeKeyToBytes(worker.Config.AclSecretKey)) - if err != nil { - return "", errors.Errorf("unable to encode jwt to string: %v", err) - } - return jwtString, nil -} - -const queryUser = ` - query search($userid: string, $password: string){ - user(func: eq(dgraph.xid, $userid)) @filter(type(dgraph.type.User)) { - uid - dgraph.xid - password_match: checkpwd(dgraph.password, $password) - dgraph.user.group { - uid - dgraph.xid - } - } - }` - -// authorizeUser queries the user with the given user id, and returns the associated uid, -// acl groups, and whether the password stored in DB matches the supplied password -func authorizeUser(ctx context.Context, userid string, password string) ( - *acl.User, error) { - - queryVars := map[string]string{ - "$userid": userid, - "$password": password, - } - req := &Request{ - req: &api.Request{ - Query: queryUser, - Vars: queryVars, - }, - doAuth: NoAuthorize, - } - queryResp, err := (&Server{}).doQuery(ctx, req) - if err != nil { - glog.Errorf("Error while query user with id %s: %v", userid, err) - return nil, err - } - user, err := acl.UnmarshalUser(queryResp, "user") - if err != nil { - return nil, err - } - return user, nil -} - -func refreshAclCache(ctx context.Context, ns, refreshTs uint64) error { - req := &Request{ - req: &api.Request{ - Query: queryAcls, - ReadOnly: true, - StartTs: refreshTs, - }, - doAuth: NoAuthorize, - } - - ctx = x.AttachNamespace(ctx, ns) - queryResp, err := (&Server{}).doQuery(ctx, req) - if err != nil { - return 
errors.Errorf("unable to retrieve acls: %v", err) - } - groups, err := acl.UnmarshalGroups(queryResp.GetJson(), "allAcls") - if err != nil { - return err - } - - worker.AclCachePtr.Update(ns, groups) - glog.V(2).Infof("Updated the ACL cache for namespace: %#x", ns) - return nil - -} - -func RefreshACLs(ctx context.Context) { - for ns := range schema.State().Namespaces() { - if err := refreshAclCache(ctx, ns, 0); err != nil { - glog.Errorf("Error while retrieving acls for namespace %#x: %v", ns, err) - } - } - worker.AclCachePtr.Set() -} - -// SubscribeForAclUpdates subscribes for ACL predicates and updates the acl cache. -func SubscribeForAclUpdates(closer *z.Closer) { - defer func() { - glog.Infoln("RefreshAcls closed") - closer.Done() - }() - if worker.Config.AclSecretKey == nil { - // the acl feature is not turned on - return - } - - var maxRefreshTs uint64 - retrieveAcls := func(ns uint64, refreshTs uint64) error { - if refreshTs <= maxRefreshTs { - return nil - } - maxRefreshTs = refreshTs - return refreshAclCache(closer.Ctx(), ns, refreshTs) - } - - closer.AddRunning(1) - go worker.SubscribeForUpdates(aclPrefixes, x.IgnoreBytes, func(kvs *bpb.KVList) { - if kvs == nil || len(kvs.Kv) == 0 { - return - } - kv := x.KvWithMaxVersion(kvs, aclPrefixes) - pk, err := x.Parse(kv.GetKey()) - if err != nil { - glog.Fatalf("Got a key from subscription which is not parsable: %s", err) - } - glog.V(3).Infof("Got ACL update via subscription for attr: %s", pk.Attr) - - ns, _ := x.ParseNamespaceAttr(pk.Attr) - if err := retrieveAcls(ns, kv.GetVersion()); err != nil { - glog.Errorf("Error while retrieving acls: %v", err) - } - }, 1, closer) - - <-closer.HasBeenClosed() -} - -const queryAcls = ` -{ - allAcls(func: type(dgraph.type.Group)) { - dgraph.xid - dgraph.acl.rule { - dgraph.rule.predicate - dgraph.rule.permission - } - ~dgraph.user.group{ - dgraph.xid - } - } -} -` - -var aclPrefixes = [][]byte{ - x.PredicatePrefix(x.GalaxyAttr("dgraph.rule.permission")), - 
x.PredicatePrefix(x.GalaxyAttr("dgraph.rule.predicate")), - x.PredicatePrefix(x.GalaxyAttr("dgraph.acl.rule")), - x.PredicatePrefix(x.GalaxyAttr("dgraph.user.group")), - x.PredicatePrefix(x.GalaxyAttr("dgraph.type.Group")), - x.PredicatePrefix(x.GalaxyAttr("dgraph.xid")), -} - -// upserts the Groot account. -func InitializeAcl(closer *z.Closer) { - defer func() { - glog.Infof("InitializeAcl closed") - closer.Done() - }() - - if worker.Config.AclSecretKey == nil { - // The acl feature is not turned on. - return - } - upsertGuardianAndGroot(closer, x.GalaxyNamespace) -} - -// Note: The handling of closer should be done by caller. -func upsertGuardianAndGroot(closer *z.Closer, ns uint64) { - if worker.Config.AclSecretKey == nil { - // The acl feature is not turned on. - return - } - for closer.Ctx().Err() == nil { - ctx, cancel := context.WithTimeout(closer.Ctx(), time.Minute) - defer cancel() - ctx = x.AttachNamespace(ctx, ns) - if err := upsertGuardian(ctx); err != nil { - glog.Infof("Unable to upsert the guardian group. Error: %v", err) - time.Sleep(100 * time.Millisecond) - continue - } - break - } - - for closer.Ctx().Err() == nil { - ctx, cancel := context.WithTimeout(closer.Ctx(), time.Minute) - defer cancel() - ctx = x.AttachNamespace(ctx, ns) - if err := upsertGroot(ctx, "password"); err != nil { - glog.Infof("Unable to upsert the groot account. Error: %v", err) - time.Sleep(100 * time.Millisecond) - continue - } - break - } -} - -// upsertGuardian must be called after setting the namespace in the context. 
-func upsertGuardian(ctx context.Context) error { - query := fmt.Sprintf(` - { - guid as guardians(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.Group)) { - uid - } - } - `, x.GuardiansId) - groupNQuads := acl.CreateGroupNQuads(x.GuardiansId) - req := &Request{ - req: &api.Request{ - CommitNow: true, - Query: query, - Mutations: []*api.Mutation{ - { - Set: groupNQuads, - Cond: "@if(eq(len(guid), 0))", - }, - }, - }, - doAuth: NoAuthorize, - } - - resp, err := (&Server{}).doQuery(ctx, req) - - // Structs to parse guardians group uid from query response - type groupNode struct { - Uid string `json:"uid"` - } - - type groupQryResp struct { - GuardiansGroup []groupNode `json:"guardians"` - } - - if err != nil { - return errors.Wrapf(err, "while upserting group with id %s", x.GuardiansId) - } - var groupResp groupQryResp - var guardiansUidStr string - if err := json.Unmarshal(resp.GetJson(), &groupResp); err != nil { - return errors.Wrap(err, "Couldn't unmarshal response from guardians group query") - } - - if len(groupResp.GuardiansGroup) == 0 { - // no guardians group found - // Extract guardians group uid from mutation - newGroupUidMap := resp.GetUids() - guardiansUidStr = newGroupUidMap["newgroup"] - } else if len(groupResp.GuardiansGroup) == 1 { - // we found a guardians group - guardiansUidStr = groupResp.GuardiansGroup[0].Uid - } else { - return errors.Wrap(err, "Multiple guardians group found") - } - - uid, err := strconv.ParseUint(guardiansUidStr, 0, 64) - if err != nil { - return errors.Wrapf(err, "Error while parsing Uid: %s of guardians Group", guardiansUidStr) - } - ns, err := x.ExtractNamespace(ctx) - if err != nil { - return errors.Wrapf(err, "While upserting group with id %s", x.GuardiansId) - } - x.GuardiansUid.Store(ns, uid) - glog.V(2).Infof("Successfully upserted the guardian of namespace: %d\n", ns) - return nil -} - -// upsertGroot must be called after setting the namespace in the context. 
-func upsertGroot(ctx context.Context, passwd string) error { - // groot is the default user of guardians group. - query := fmt.Sprintf(` - { - grootid as grootUser(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.User)) { - uid - } - guid as var(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.Group)) - } - `, x.GrootId, x.GuardiansId) - userNQuads := acl.CreateUserNQuads(x.GrootId, passwd) - userNQuads = append(userNQuads, &api.NQuad{ - Subject: "_:newuser", - Predicate: "dgraph.user.group", - ObjectId: "uid(guid)", - }) - req := &Request{ - req: &api.Request{ - CommitNow: true, - Query: query, - Mutations: []*api.Mutation{ - { - Set: userNQuads, - // Assuming that if groot exists, it is in guardian group - Cond: "@if(eq(len(grootid), 0) and gt(len(guid), 0))", - }, - }, - }, - doAuth: NoAuthorize, - } - - resp, err := (&Server{}).doQuery(ctx, req) - if err != nil { - return errors.Wrapf(err, "while upserting user with id %s", x.GrootId) - } - - // Structs to parse groot user uid from query response - type userNode struct { - Uid string `json:"uid"` - } - - type userQryResp struct { - GrootUser []userNode `json:"grootUser"` - } - - var grootUserUid string - var userResp userQryResp - if err := json.Unmarshal(resp.GetJson(), &userResp); err != nil { - return errors.Wrap(err, "Couldn't unmarshal response from groot user query") - } - if len(userResp.GrootUser) == 0 { - // no groot user found from query - // Extract uid of created groot user from mutation - newUserUidMap := resp.GetUids() - grootUserUid = newUserUidMap["newuser"] - } else if len(userResp.GrootUser) == 1 { - // we found a groot user - grootUserUid = userResp.GrootUser[0].Uid - } else { - return errors.Wrap(err, "Multiple groot users found") - } - - uid, err := strconv.ParseUint(grootUserUid, 0, 64) - if err != nil { - return errors.Wrapf(err, "Error while parsing Uid: %s of groot user", grootUserUid) - } - ns, err := x.ExtractNamespace(ctx) - if err != nil { - return errors.Wrapf(err, "While 
upserting user with id %s", x.GrootId) - } - x.GrootUid.Store(ns, uid) - glog.V(2).Infof("Successfully upserted groot account for namespace %d\n", ns) - return nil -} - -// extract the userId, groupIds from the accessJwt in the context -func extractUserAndGroups(ctx context.Context) (*userData, error) { - accessJwt, err := x.ExtractJwt(ctx) - if err != nil { - return nil, err - } - return validateToken(accessJwt) -} - -type authPredResult struct { - allowed []string - blocked map[string]struct{} -} - -func authorizePreds(ctx context.Context, userData *userData, preds []string, - aclOp *acl.Operation) *authPredResult { - - if !worker.AclCachePtr.Loaded() { - RefreshACLs(ctx) - } - - userId := userData.userId - groupIds := userData.groupIds - ns := userData.namespace - blockedPreds := make(map[string]struct{}) - for _, pred := range preds { - nsPred := x.NamespaceAttr(ns, pred) - if err := worker.AclCachePtr.AuthorizePredicate(groupIds, nsPred, aclOp); err != nil { - logAccess(&accessEntry{ - userId: userId, - groups: groupIds, - preds: preds, - operation: aclOp, - allowed: false, - }) - blockedPreds[pred] = struct{}{} - } - } - if worker.HasAccessToAllPreds(ns, groupIds, aclOp) { - // Setting allowed to nil allows access to all predicates. Note that the access to ACL - // predicates will still be blocked. - return &authPredResult{allowed: nil, blocked: blockedPreds} - } - // User can have multiple permission for same predicate, add predicate - allowedPreds := make([]string, 0, len(worker.AclCachePtr.GetUserPredPerms(userId))) - // only if the acl.Op is covered in the set of permissions for the user - for predicate, perm := range worker.AclCachePtr.GetUserPredPerms(userId) { - if (perm & aclOp.Code) > 0 { - allowedPreds = append(allowedPreds, predicate) - } - } - return &authPredResult{allowed: allowedPreds, blocked: blockedPreds} -} - -// authorizeAlter parses the Schema in the operation and authorizes the operation -// using the worker.AclCachePtr. 
It will return error if any one of the predicates -// specified in alter are not authorized. -func authorizeAlter(ctx context.Context, op *api.Operation) error { - if worker.Config.AclSecretKey == nil { - // the user has not turned on the acl feature - return nil - } - - // extract the list of predicates from the operation object - var preds []string - switch { - case len(op.DropAttr) > 0: - preds = []string{op.DropAttr} - case op.DropOp == api.Operation_ATTR && len(op.DropValue) > 0: - preds = []string{op.DropValue} - default: - update, err := schema.Parse(op.Schema) - if err != nil { - return err - } - - for _, u := range update.Preds { - preds = append(preds, x.ParseAttr(u.Predicate)) - } - } - var userId string - var groupIds []string - - // doAuthorizeAlter checks if alter of all the predicates are allowed - // as a byproduct, it also sets the userId, groups variables - doAuthorizeAlter := func() error { - userData, err := extractUserAndGroups(ctx) - if err != nil { - // We don't follow fail open approach anymore. - return status.Error(codes.Unauthenticated, err.Error()) - } - - userId = userData.userId - groupIds = userData.groupIds - - if x.IsGuardian(groupIds) { - // Members of guardian group are allowed to alter anything. - return nil - } - - // if we get here, we know the user is not a guardian. 
- if isDropAll(op) || op.DropOp == api.Operation_DATA { - return errors.Errorf( - "only guardians are allowed to drop all data, but the current user is %s", userId) - } - - result := authorizePreds(ctx, userData, preds, acl.Modify) - if len(result.blocked) > 0 { - var msg strings.Builder - for key := range result.blocked { - x.Check2(msg.WriteString(key)) - x.Check2(msg.WriteString(" ")) - } - return status.Errorf(codes.PermissionDenied, - "unauthorized to alter following predicates: %s\n", msg.String()) - } - return nil - } - - err := doAuthorizeAlter() - span := otrace.FromContext(ctx) - if span != nil { - span.Annotatef(nil, (&accessEntry{ - userId: userId, - groups: groupIds, - preds: preds, - operation: acl.Modify, - allowed: err == nil, - }).String()) - } - - return err -} - -// parsePredsFromMutation returns a union set of all the predicate names in the input nquads -func parsePredsFromMutation(nquads []*api.NQuad) []string { - // use a map to dedup predicates - predsMap := make(map[string]struct{}) - for _, nquad := range nquads { - // _STAR_ALL is not a predicate in itself. 
- if nquad.Predicate != "_STAR_ALL" { - predsMap[nquad.Predicate] = struct{}{} - } - } - - preds := make([]string, 0, len(predsMap)) - for pred := range predsMap { - preds = append(preds, pred) - } - - return preds -} - -func isAclPredMutation(nquads []*api.NQuad) bool { - for _, nquad := range nquads { - if nquad.Predicate == "dgraph.group.acl" && nquad.ObjectValue != nil { - // this mutation is trying to change the permission of some predicate - // check if the predicate list contains an ACL predicate - if _, ok := nquad.ObjectValue.Val.(*api.Value_BytesVal); ok { - aclBytes := nquad.ObjectValue.Val.(*api.Value_BytesVal) - var aclsToChange []acl.Acl - err := json.Unmarshal(aclBytes.BytesVal, &aclsToChange) - if err != nil { - glog.Errorf(fmt.Sprintf("Unable to unmarshal bytes under the dgraph.group.acl "+ - "predicate: %v", err)) - continue - } - for _, aclToChange := range aclsToChange { - if x.IsAclPredicate(aclToChange.Predicate) { - return true - } - } - } - } - } - return false -} - -// authorizeMutation authorizes the mutation using the worker.AclCachePtr. It will return permission -// denied error if any one of the predicates in mutation(set or delete) is unauthorized. -// At this stage, namespace is not attached in the predicates. -func authorizeMutation(ctx context.Context, gmu *dql.Mutation) error { - if worker.Config.AclSecretKey == nil { - // the user has not turned on the acl feature - return nil - } - - preds := parsePredsFromMutation(gmu.Set) - // Del predicates weren't included before. - // A bug probably since f115de2eb6a40d882a86c64da68bf5c2a33ef69a - preds = append(preds, parsePredsFromMutation(gmu.Del)...) - - var userId string - var groupIds []string - // doAuthorizeMutation checks if modification of all the predicates are allowed - // as a byproduct, it also sets the userId and groups - doAuthorizeMutation := func() error { - userData, err := extractUserAndGroups(ctx) - if err != nil { - // We don't follow fail open approach anymore. 
- return status.Error(codes.Unauthenticated, err.Error()) - } - - userId = userData.userId - groupIds = userData.groupIds - - if x.IsGuardian(groupIds) { - // Members of guardians group are allowed to mutate anything - // (including delete) except the permission of the acl predicates. - switch { - case isAclPredMutation(gmu.Set): - return errors.Errorf("the permission of ACL predicates can not be changed") - case isAclPredMutation(gmu.Del): - return errors.Errorf("ACL predicates can't be deleted") - } - if !shouldAllowAcls(userData.namespace) { - for _, pred := range preds { - if x.IsAclPredicate(pred) { - return status.Errorf(codes.PermissionDenied, - "unauthorized to mutate acl predicates: %s\n", pred) - } - } - } - return nil - } - result := authorizePreds(ctx, userData, preds, acl.Write) - if len(result.blocked) > 0 { - var msg strings.Builder - for key := range result.blocked { - x.Check2(msg.WriteString(key)) - x.Check2(msg.WriteString(" ")) - } - return status.Errorf(codes.PermissionDenied, - "unauthorized to mutate following predicates: %s\n", msg.String()) - } - gmu.AllowedPreds = result.allowed - return nil - } - - err := doAuthorizeMutation() - - span := otrace.FromContext(ctx) - if span != nil { - span.Annotatef(nil, (&accessEntry{ - userId: userId, - groups: groupIds, - preds: preds, - operation: acl.Write, - allowed: err == nil, - }).String()) - } - - return err -} - -func parsePredsFromQuery(dqls []*dql.GraphQuery) predsAndvars { - predsMap := make(map[string]struct{}) - varsMap := make(map[string]string) - for _, gq := range dqls { - if gq.Func != nil { - predsMap[gq.Func.Attr] = struct{}{} - } - if len(gq.Var) > 0 { - varsMap[gq.Var] = gq.Attr - } - if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" && gq.Attr != "val" { - predsMap[gq.Attr] = struct{}{} - - } - for _, ord := range gq.Order { - predsMap[ord.Attr] = struct{}{} - } - for _, gbAttr := range gq.GroupbyAttrs { - predsMap[gbAttr.Attr] = struct{}{} - } - for _, pred := range 
parsePredsFromFilter(gq.Filter) { - predsMap[pred] = struct{}{} - } - childPredandVars := parsePredsFromQuery(gq.Children) - for _, childPred := range childPredandVars.preds { - predsMap[childPred] = struct{}{} - } - for childVar := range childPredandVars.vars { - varsMap[childVar] = childPredandVars.vars[childVar] - } - } - preds := make([]string, 0, len(predsMap)) - for pred := range predsMap { - if len(pred) > 0 { - if _, found := varsMap[pred]; !found { - preds = append(preds, pred) - } - } - } - - pv := predsAndvars{preds: preds, vars: varsMap} - return pv -} - -func parsePredsFromFilter(f *dql.FilterTree) []string { - var preds []string - if f == nil { - return preds - } - if f.Func != nil && len(f.Func.Attr) > 0 { - preds = append(preds, f.Func.Attr) - } - for _, ch := range f.Child { - preds = append(preds, parsePredsFromFilter(ch)...) - } - return preds -} - -type accessEntry struct { - userId string - groups []string - preds []string - operation *acl.Operation - allowed bool -} - -func (log *accessEntry) String() string { - return fmt.Sprintf("ACL-LOG Authorizing user %q with groups %q on predicates %q "+ - "for %q, allowed:%v", log.userId, strings.Join(log.groups, ","), - strings.Join(log.preds, ","), log.operation.Name, log.allowed) -} - -func logAccess(log *accessEntry) { - if glog.V(1) { - glog.Info(log.String()) - } -} - -func blockedPreds(preds []string) map[string]struct{} { - blocked := make(map[string]struct{}) - for _, pred := range preds { - if x.IsAclPredicate(pred) { - blocked[pred] = struct{}{} - } - } - return blocked -} - -// With shared instance enabled, we don't allow ACL operations from any of the non-galaxy namespace. -func shouldAllowAcls(ns uint64) bool { - return !x.Config.SharedInstance || ns == x.GalaxyNamespace -} - -// authorizeQuery authorizes the query using the aclCachePtr. It will silently drop all -// unauthorized predicates from query. -// At this stage, namespace is not attached in the predicates. 
-func authorizeQuery(ctx context.Context, parsedReq *dql.Result, graphql bool) error { - if worker.Config.AclSecretKey == nil { - // the user has not turned on the acl feature - return nil - } - - var userId string - var groupIds []string - var namespace uint64 - predsAndvars := parsePredsFromQuery(parsedReq.Query) - preds := predsAndvars.preds - varsToPredMap := predsAndvars.vars - - // Need this to efficiently identify blocked variables from the - // list of blocked predicates - predToVarsMap := make(map[string]string) - for k, v := range varsToPredMap { - predToVarsMap[v] = k - } - - doAuthorizeQuery := func() (map[string]struct{}, []string, error) { - userData, err := extractUserAndGroups(ctx) - if err != nil { - return nil, nil, status.Error(codes.Unauthenticated, err.Error()) - } - - userId = userData.userId - groupIds = userData.groupIds - namespace = userData.namespace - - if x.IsGuardian(groupIds) { - if shouldAllowAcls(userData.namespace) { - // Members of guardian groups are allowed to query anything. - return nil, nil, nil - } - return blockedPreds(preds), nil, nil - } - - result := authorizePreds(ctx, userData, preds, acl.Read) - return result.blocked, result.allowed, nil - } - - blockedPreds, allowedPreds, err := doAuthorizeQuery() - if err != nil { - return err - } - - if span := otrace.FromContext(ctx); span != nil { - span.Annotatef(nil, (&accessEntry{ - userId: userId, - groups: groupIds, - preds: preds, - operation: acl.Read, - allowed: err == nil, - }).String()) - } - - if len(blockedPreds) != 0 { - // For GraphQL requests, we allow filtered access to the ACL predicates. - // Filter for user_id and group_id is applied for the currently logged in user. - if graphql && shouldAllowAcls(namespace) { - for _, gq := range parsedReq.Query { - addUserFilterToQuery(gq, userId, groupIds) - } - // blockedPreds might have acl predicates which we want to allow access through - // graphql, so deleting those from here. 
- for _, pred := range x.AllACLPredicates() { - delete(blockedPreds, pred) - } - // In query context ~predicate and predicate are considered different. - delete(blockedPreds, "~dgraph.user.group") - } - - blockedVars := make(map[string]struct{}) - for predicate := range blockedPreds { - if variable, found := predToVarsMap[predicate]; found { - // Add variables to blockedPreds to delete from Query - blockedPreds[variable] = struct{}{} - // Collect blocked Variables to remove from QueryVars - blockedVars[variable] = struct{}{} - } - } - parsedReq.Query = removePredsFromQuery(parsedReq.Query, blockedPreds) - parsedReq.QueryVars = removeVarsFromQueryVars(parsedReq.QueryVars, blockedVars) - } - for i := range parsedReq.Query { - parsedReq.Query[i].AllowedPreds = allowedPreds - } - - return nil -} - -func authorizeSchemaQuery(ctx context.Context, er *query.ExecutionResult) error { - if worker.Config.AclSecretKey == nil { - // the user has not turned on the acl feature - return nil - } - - // find the predicates being sent in response - preds := make([]string, 0) - predsMap := make(map[string]struct{}) - for _, predNode := range er.SchemaNode { - preds = append(preds, predNode.Predicate) - predsMap[predNode.Predicate] = struct{}{} - } - for _, typeNode := range er.Types { - for _, field := range typeNode.Fields { - if _, ok := predsMap[field.Predicate]; !ok { - preds = append(preds, field.Predicate) - } - } - } - - doAuthorizeSchemaQuery := func() (map[string]struct{}, error) { - userData, err := extractUserAndGroups(ctx) - if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) - } - - groupIds := userData.groupIds - if x.IsGuardian(groupIds) { - if shouldAllowAcls(userData.namespace) { - // Members of guardian groups are allowed to query anything. 
- return nil, nil - } - return blockedPreds(preds), nil - } - result := authorizePreds(ctx, userData, preds, acl.Read) - return result.blocked, nil - } - - // find the predicates which are blocked for the schema query - blockedPreds, err := doAuthorizeSchemaQuery() - if err != nil { - return err - } - - // remove those predicates from response - if len(blockedPreds) > 0 { - respPreds := make([]*pb.SchemaNode, 0) - for _, predNode := range er.SchemaNode { - if _, ok := blockedPreds[predNode.Predicate]; !ok { - respPreds = append(respPreds, predNode) - } - } - er.SchemaNode = respPreds - - for _, typeNode := range er.Types { - respFields := make([]*pb.SchemaUpdate, 0) - for _, field := range typeNode.Fields { - if _, ok := blockedPreds[field.Predicate]; !ok { - respFields = append(respFields, field) - } - } - typeNode.Fields = respFields - } - } - - return nil -} - -// AuthGuardianOfTheGalaxy authorizes the operations for the users who belong to the guardians -// group in the galaxy namespace. This authorization is used for admin usages like creation and -// deletion of a namespace, resetting passwords across namespaces etc. -// NOTE: The caller should not wrap the error returned. If needed, propagate the GRPC error code. -func AuthGuardianOfTheGalaxy(ctx context.Context) error { - if !x.WorkerConfig.AclEnabled { - return nil - } - ns, err := x.ExtractNamespaceFrom(ctx) - if err != nil { - return errors.Wrap(err, "Authorize guardian of the galaxy, extracting jwt token, error:") - } - if ns != 0 { - return status.Error( - codes.PermissionDenied, "Only guardian of galaxy is allowed to do this operation") - } - // AuthorizeGuardians will extract (user, []groups) from the JWT claims and will check if - // any of the group to which the user belongs is "guardians" or not. - if err := AuthorizeGuardians(ctx); err != nil { - s := status.Convert(err) - return status.Error( - s.Code(), "AuthGuardianOfTheGalaxy: failed to authorize guardians. 
"+s.Message()) - } - glog.V(3).Info("Successfully authorised guardian of the galaxy") - return nil -} - -// AuthorizeGuardians authorizes the operation for users which belong to Guardians group. -// NOTE: The caller should not wrap the error returned. If needed, propagate the GRPC error code. -func AuthorizeGuardians(ctx context.Context) error { - if worker.Config.AclSecretKey == nil { - // the user has not turned on the acl feature - return nil - } - - userData, err := extractUserAndGroups(ctx) - switch { - case err == x.ErrNoJwt: - return status.Error(codes.PermissionDenied, err.Error()) - case err != nil: - return status.Error(codes.Unauthenticated, err.Error()) - default: - userId := userData.userId - groupIds := userData.groupIds - - if !x.IsGuardian(groupIds) { - // Deny access for members of non-guardian groups - return status.Error(codes.PermissionDenied, fmt.Sprintf("Only guardians are "+ - "allowed access. User '%v' is not a member of guardians group.", userId)) - } - } - - return nil -} - -/* -addUserFilterToQuery applies makes sure that a user can access only its own -acl info by applying filter of userid and groupid to acl predicates. A query like -Conversion pattern: - - me(func: type(dgraph.type.Group)) -> - me(func: type(dgraph.type.Group)) @filter(eq("dgraph.xid", groupIds...)) - - me(func: type(dgraph.type.User)) -> - me(func: type(dgraph.type.User)) @filter(eq("dgraph.xid", userId)) -*/ -func addUserFilterToQuery(gq *dql.GraphQuery, userId string, groupIds []string) { - if gq.Func != nil && gq.Func.Name == "type" { - // type function only supports one argument - if len(gq.Func.Args) != 1 { - return - } - arg := gq.Func.Args[0] - // The case where value of some varialble v (say) is "dgraph.type.Group" and a - // query comes like `eq(dgraph.type, val(v))`, will be ignored here. 
- if arg.Value == "dgraph.type.User" { - newFilter := userFilter(userId) - gq.Filter = parentFilter(newFilter, gq.Filter) - } else if arg.Value == "dgraph.type.Group" { - newFilter := groupFilter(groupIds) - gq.Filter = parentFilter(newFilter, gq.Filter) - } - } - - gq.Filter = addUserFilterToFilter(gq.Filter, userId, groupIds) - - switch gq.Attr { - case "dgraph.user.group": - newFilter := groupFilter(groupIds) - gq.Filter = parentFilter(newFilter, gq.Filter) - case "~dgraph.user.group": - newFilter := userFilter(userId) - gq.Filter = parentFilter(newFilter, gq.Filter) - } - - for _, ch := range gq.Children { - addUserFilterToQuery(ch, userId, groupIds) - } -} - -func parentFilter(newFilter, filter *dql.FilterTree) *dql.FilterTree { - if filter == nil { - return newFilter - } - parentFilter := &dql.FilterTree{ - Op: "AND", - Child: []*dql.FilterTree{filter, newFilter}, - } - return parentFilter -} - -func userFilter(userId string) *dql.FilterTree { - // A logged in user should always have a userId. - return &dql.FilterTree{ - Func: &dql.Function{ - Attr: "dgraph.xid", - Name: "eq", - Args: []dql.Arg{{Value: userId}}, - }, - } -} - -func groupFilter(groupIds []string) *dql.FilterTree { - // The user doesn't have any groups, so add an empty filter @filter(uid([])) so that all - // groups are filtered out. - if len(groupIds) == 0 { - filter := &dql.FilterTree{ - Func: &dql.Function{ - Name: "uid", - UID: []uint64{}, - }, - } - return filter - } - - filter := &dql.FilterTree{ - Func: &dql.Function{ - Attr: "dgraph.xid", - Name: "eq", - }, - } - - for _, gid := range groupIds { - filter.Func.Args = append(filter.Func.Args, - dql.Arg{Value: gid}) - } - - return filter -} - -/* - addUserFilterToFilter makes sure that user can't misue filters to access other user's info. - If the *filter* have type(dgraph.type.Group) or type(dgraph.type.User) functions, - it generate a *newFilter* with function like eq(dgraph.xid, userId) or eq(dgraph.xid,groupId...) 
- and return a filter of the form - - &dql.FilterTree{ - Op: "AND", - Child: []dql.FilterTree{ - {filter, newFilter} - } - } -*/ -func addUserFilterToFilter(filter *dql.FilterTree, userId string, - groupIds []string) *dql.FilterTree { - - if filter == nil { - return nil - } - - if filter.Func != nil && filter.Func.Name == "type" { - - // type function supports only one argument - if len(filter.Func.Args) != 1 { - return nil - } - arg := filter.Func.Args[0] - var newFilter *dql.FilterTree - switch arg.Value { - case "dgraph.type.User": - newFilter = userFilter(userId) - case "dgraph.type.Group": - newFilter = groupFilter(groupIds) - } - - // If filter have function, it can't have children. - return parentFilter(newFilter, filter) - } - - for idx, child := range filter.Child { - filter.Child[idx] = addUserFilterToFilter(child, userId, groupIds) - } - - return filter -} - -// removePredsFromQuery removes all the predicates in blockedPreds -// from all the queries in gqs. -func removePredsFromQuery(gqs []*dql.GraphQuery, - blockedPreds map[string]struct{}) []*dql.GraphQuery { - - filteredGQs := gqs[:0] -L: - for _, gq := range gqs { - if gq.Func != nil && len(gq.Func.Attr) > 0 { - if _, ok := blockedPreds[gq.Func.Attr]; ok { - continue - } - } - if len(gq.Attr) > 0 { - if _, ok := blockedPreds[gq.Attr]; ok { - continue - } - if gq.Attr == "val" { - // TODO (Anurag): If val supports multiple variables, this would - // need an upgrade - for _, variable := range gq.NeedsVar { - if _, ok := blockedPreds[variable.Name]; ok { - continue L - } - } - } - } - - order := gq.Order[:0] - for _, ord := range gq.Order { - if _, ok := blockedPreds[ord.Attr]; ok { - continue - } - order = append(order, ord) - } - - gq.Order = order - gq.Filter = removeFilters(gq.Filter, blockedPreds) - gq.GroupbyAttrs = removeGroupBy(gq.GroupbyAttrs, blockedPreds) - gq.Children = removePredsFromQuery(gq.Children, blockedPreds) - filteredGQs = append(filteredGQs, gq) - } - - return filteredGQs -} - 
-func removeVarsFromQueryVars(gqs []*dql.Vars, - blockedVars map[string]struct{}) []*dql.Vars { - - filteredGQs := gqs[:0] - for _, gq := range gqs { - var defines []string - var needs []string - for _, variable := range gq.Defines { - if _, ok := blockedVars[variable]; !ok { - defines = append(defines, variable) - } - } - for _, variable := range gq.Needs { - if _, ok := blockedVars[variable]; !ok { - needs = append(needs, variable) - } - } - gq.Defines = defines - gq.Needs = needs - filteredGQs = append(filteredGQs, gq) - } - return filteredGQs -} - -func removeFilters(f *dql.FilterTree, blockedPreds map[string]struct{}) *dql.FilterTree { - if f == nil { - return nil - } - if f.Func != nil && len(f.Func.Attr) > 0 { - if _, ok := blockedPreds[f.Func.Attr]; ok { - return nil - } - } - - filteredChildren := f.Child[:0] - for _, ch := range f.Child { - child := removeFilters(ch, blockedPreds) - if child != nil { - filteredChildren = append(filteredChildren, child) - } - } - if len(filteredChildren) != len(f.Child) && (f.Op == "AND" || f.Op == "NOT") { - return nil - } - f.Child = filteredChildren - return f -} - -func removeGroupBy(gbAttrs []dql.GroupByAttr, - blockedPreds map[string]struct{}) []dql.GroupByAttr { - - filteredGbAttrs := gbAttrs[:0] - for _, gbAttr := range gbAttrs { - if _, ok := blockedPreds[gbAttr.Attr]; ok { - continue - } - filteredGbAttrs = append(filteredGbAttrs, gbAttr) - } - return filteredGbAttrs -} diff --git a/edgraph/access_ee_test.go b/edgraph/access_test.go similarity index 98% rename from edgraph/access_ee_test.go rename to edgraph/access_test.go index 4f9d24b6ca5..c63fe225beb 100644 --- a/edgraph/access_ee_test.go +++ b/edgraph/access_test.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package edgraph diff --git a/edgraph/multi_tenancy.go b/edgraph/multi_tenancy.go index 18ceba2e8ee..3d7e5615c10 100644 --- a/edgraph/multi_tenancy.go +++ b/edgraph/multi_tenancy.go @@ -1,6 +1,3 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. * SPDX-License-Identifier: Apache-2.0 @@ -8,7 +5,22 @@ package edgraph -import "context" +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/golang/glog" + "github.com/pkg/errors" + + "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/hypermodeinc/dgraph/v24/protos/pb" + "github.com/hypermodeinc/dgraph/v24/query" + "github.com/hypermodeinc/dgraph/v24/schema" + "github.com/hypermodeinc/dgraph/v24/worker" + "github.com/hypermodeinc/dgraph/v24/x" +) type ResetPasswordInput struct { UserID string @@ -16,18 +28,110 @@ type ResetPasswordInput struct { Namespace uint64 } -func (s *Server) CreateNamespaceInternal(ctx context.Context, passwd string) (uint64, error) { - return 0, nil -} +func (s *Server) ResetPassword(ctx context.Context, inp *ResetPasswordInput) error { + query := fmt.Sprintf(`{ + x as updateUser(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.User)) { + uid + } + }`, inp.UserID) -func (s *Server) DeleteNamespace(ctx context.Context, namespace uint64) error { + userNQuads := []*api.NQuad{ + { + Subject: "uid(x)", + Predicate: "dgraph.password", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: inp.Password}}, + }, + } + req := &Request{ + req: &api.Request{ + CommitNow: true, + Query: query, + Mutations: []*api.Mutation{ + { + Set: userNQuads, + Cond: "@if(gt(len(x), 0))", + }, + }, + }, + doAuth: NoAuthorize, + } + ctx = x.AttachNamespace(ctx, inp.Namespace) + resp, err := (&Server{}).doQuery(ctx, req) + if err != nil { + return errors.Wrapf(err, "Reset password for user %s in namespace %d, got error:", + inp.UserID, inp.Namespace) + } + + type userNode struct { + Uid string `json:"uid"` + } + + type 
userQryResp struct { + User []userNode `json:"updateUser"` + } + var userResp userQryResp + if err := json.Unmarshal(resp.GetJson(), &userResp); err != nil { + return errors.Wrap(err, "Reset password failed with error") + } + + if len(userResp.User) == 0 { + return errors.New("Failed to reset password, user doesn't exist") + } return nil } -func (s *Server) ResetPassword(ctx context.Context, ns *ResetPasswordInput) error { - return nil +// CreateNamespaceInternal creates a new namespace. Only guardian of galaxy is authorized to do so. +// Authorization is handled by middlewares. +func (s *Server) CreateNamespaceInternal(ctx context.Context, passwd string) (uint64, error) { + glog.V(2).Info("Got create namespace request.") + + num := &pb.Num{Val: 1, Type: pb.Num_NS_ID} + ids, err := worker.AssignNsIdsOverNetwork(ctx, num) + if err != nil { + return 0, errors.Wrapf(err, "Creating namespace, got error:") + } + + ns := ids.StartId + glog.V(2).Infof("Got a lease for NsID: %d", ns) + + // Attach the newly leased NsID in the context in order to create guardians/groot for it. + ctx = x.AttachNamespace(ctx, ns) + m := &pb.Mutations{StartTs: worker.State.GetTimestamp(false)} + m.Schema = schema.InitialSchema(ns) + m.Types = schema.InitialTypes(ns) + _, err = query.ApplyMutations(ctx, m) + if err != nil { + return 0, err + } + + err = x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { + return createGuardianAndGroot(ctx, ids.StartId, passwd) + }) + if err != nil { + return 0, errors.Wrapf(err, "Failed to create guardian and groot: ") + } + glog.V(2).Infof("Created namespace: %d", ns) + return ns, nil } +// This function is used while creating new namespace. New namespace creation is only allowed +// by the guardians of the galaxy group. 
func createGuardianAndGroot(ctx context.Context, namespace uint64, passwd string) error { + if err := upsertGuardian(ctx); err != nil { + return errors.Wrap(err, "While creating Guardian") + } + if err := upsertGroot(ctx, passwd); err != nil { + return errors.Wrap(err, "While creating Groot") + } return nil } + +// DeleteNamespace deletes a new namespace. Only guardian of galaxy is authorized to do so. +// Authorization is handled by middlewares. +func (s *Server) DeleteNamespace(ctx context.Context, namespace uint64) error { + glog.Info("Deleting namespace", namespace) + if _, ok := schema.State().Namespaces()[namespace]; !ok { + return errors.Errorf("error deleting non-existing namespace %#x", namespace) + } + return worker.ProcessDeleteNsRequest(ctx, namespace) +} diff --git a/edgraph/multi_tenancy_ee.go b/edgraph/multi_tenancy_ee.go deleted file mode 100644 index 58c30c34f19..00000000000 --- a/edgraph/multi_tenancy_ee.go +++ /dev/null @@ -1,139 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package edgraph - -import ( - "context" - "encoding/json" - "fmt" - "time" - - "github.com/golang/glog" - "github.com/pkg/errors" - - "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/query" - "github.com/hypermodeinc/dgraph/v24/schema" - "github.com/hypermodeinc/dgraph/v24/worker" - "github.com/hypermodeinc/dgraph/v24/x" -) - -type ResetPasswordInput struct { - UserID string - Password string - Namespace uint64 -} - -func (s *Server) ResetPassword(ctx context.Context, inp *ResetPasswordInput) error { - query := fmt.Sprintf(`{ - x as updateUser(func: eq(dgraph.xid, "%s")) @filter(type(dgraph.type.User)) { - uid - } - }`, inp.UserID) - - userNQuads := []*api.NQuad{ - { - Subject: "uid(x)", - Predicate: "dgraph.password", - ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: inp.Password}}, - }, - } - req := &Request{ - req: &api.Request{ - CommitNow: true, - Query: query, - Mutations: []*api.Mutation{ - { - Set: userNQuads, - Cond: "@if(gt(len(x), 0))", - }, - }, - }, - doAuth: NoAuthorize, - } - ctx = x.AttachNamespace(ctx, inp.Namespace) - resp, err := (&Server{}).doQuery(ctx, req) - if err != nil { - return errors.Wrapf(err, "Reset password for user %s in namespace %d, got error:", - inp.UserID, inp.Namespace) - } - - type userNode struct { - Uid string `json:"uid"` - } - - type userQryResp struct { - User []userNode `json:"updateUser"` - } - var userResp userQryResp - if err := json.Unmarshal(resp.GetJson(), &userResp); err != nil { - return errors.Wrap(err, "Reset password failed with error") - } - - if len(userResp.User) == 0 { - return errors.New("Failed to reset password, user doesn't exist") - } - return nil -} - -// CreateNamespaceInternal creates a new namespace. Only guardian of galaxy is authorized to do so. -// Authorization is handled by middlewares. 
-func (s *Server) CreateNamespaceInternal(ctx context.Context, passwd string) (uint64, error) { - glog.V(2).Info("Got create namespace request.") - - num := &pb.Num{Val: 1, Type: pb.Num_NS_ID} - ids, err := worker.AssignNsIdsOverNetwork(ctx, num) - if err != nil { - return 0, errors.Wrapf(err, "Creating namespace, got error:") - } - - ns := ids.StartId - glog.V(2).Infof("Got a lease for NsID: %d", ns) - - // Attach the newly leased NsID in the context in order to create guardians/groot for it. - ctx = x.AttachNamespace(ctx, ns) - m := &pb.Mutations{StartTs: worker.State.GetTimestamp(false)} - m.Schema = schema.InitialSchema(ns) - m.Types = schema.InitialTypes(ns) - _, err = query.ApplyMutations(ctx, m) - if err != nil { - return 0, err - } - - err = x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { - return createGuardianAndGroot(ctx, ids.StartId, passwd) - }) - if err != nil { - return 0, errors.Wrapf(err, "Failed to create guardian and groot: ") - } - glog.V(2).Infof("Created namespace: %d", ns) - return ns, nil -} - -// This function is used while creating new namespace. New namespace creation is only allowed -// by the guardians of the galaxy group. -func createGuardianAndGroot(ctx context.Context, namespace uint64, passwd string) error { - if err := upsertGuardian(ctx); err != nil { - return errors.Wrap(err, "While creating Guardian") - } - if err := upsertGroot(ctx, passwd); err != nil { - return errors.Wrap(err, "While creating Groot") - } - return nil -} - -// DeleteNamespace deletes a new namespace. Only guardian of galaxy is authorized to do so. -// Authorization is handled by middlewares. 
-func (s *Server) DeleteNamespace(ctx context.Context, namespace uint64) error { - glog.Info("Deleting namespace", namespace) - if _, ok := schema.State().Namespaces()[namespace]; !ok { - return errors.Errorf("error deleting non-existing namespace %#x", namespace) - } - return worker.ProcessDeleteNsRequest(ctx, namespace) -} diff --git a/ee/acl/acl.go b/ee/acl/acl.go index a446c952cdc..3f2cdc47ebb 100644 --- a/ee/acl/acl.go +++ b/ee/acl/acl.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/acl_curl_test.go b/ee/acl/acl_curl_test.go index fda700fdd41..2d7f5a6ce6d 100644 --- a/ee/acl/acl_curl_test.go +++ b/ee/acl/acl_curl_test.go @@ -3,6 +3,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/acl_integration_test.go b/ee/acl/acl_integration_test.go index 8ff101f94a5..f88dc1b8cf4 100644 --- a/ee/acl/acl_integration_test.go +++ b/ee/acl/acl_integration_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index 0b67fcc2cc0..aa25405304e 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -3,6 +3,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/integration_test.go b/ee/acl/integration_test.go index ea599fbfb00..edf950a3ebe 100644 --- a/ee/acl/integration_test.go +++ b/ee/acl/integration_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/jwt_algo_test.go b/ee/acl/jwt_algo_test.go index 68c5e3d7426..28e83d99345 100644 --- a/ee/acl/jwt_algo_test.go +++ b/ee/acl/jwt_algo_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/run.go b/ee/acl/run.go index cc893436ec2..fd621019acd 100644 --- a/ee/acl/run.go +++ b/ee/acl/run.go @@ -1,37 +1,131 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl import ( + "fmt" + "os" + + "github.com/golang/glog" "github.com/spf13/cobra" + "github.com/spf13/viper" - "github.com/dgraph-io/dgo/v240/protos/api" "github.com/hypermodeinc/dgraph/v24/x" ) -var CmdAcl x.SubCommand +var ( + // CmdAcl is the sub-command used to manage the ACL system. + CmdAcl x.SubCommand +) + +const defaultGroupList = "dgraph-unused-group" func init() { CmdAcl.Cmd = &cobra.Command{ Use: "acl", - Short: "Enterprise feature. Not supported in oss version", + Short: "Run the Dgraph Enterprise Edition ACL tool", Annotations: map[string]string{"group": "security"}, } CmdAcl.Cmd.SetHelpTemplate(x.NonRootTemplate) -} + flag := CmdAcl.Cmd.PersistentFlags() + flag.StringP("alpha", "a", "127.0.0.1:9080", "Dgraph Alpha gRPC server address") + flag.String("guardian-creds", "", `Login credentials for the guardian + user defines the username to login. + password defines the password of the user. + namespace defines the namespace to log into. + Sample flag could look like --guardian-creds user=username;password=mypass;namespace=2`) -// CreateGroupNQuads cretes NQuads needed to store a group with the give ID. 
-func CreateGroupNQuads(groupId string) []*api.NQuad { - return nil + // --tls SuperFlag + x.RegisterClientTLSFlags(flag) + + subcommands := initSubcommands() + for _, sc := range subcommands { + CmdAcl.Cmd.AddCommand(sc.Cmd) + sc.Conf = viper.New() + if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { + glog.Fatalf("Unable to bind flags for command %v: %v", sc, err) + } + if err := sc.Conf.BindPFlags(CmdAcl.Cmd.PersistentFlags()); err != nil { + glog.Fatalf("Unable to bind persistent flags from acl for command %v: %v", sc, err) + } + sc.Conf.SetEnvPrefix(sc.EnvPrefix) + } } -// CreateUserNQuads creates the NQuads needed to store a user with the given ID and -// password in the ACL system. -func CreateUserNQuads(userId, password string) []*api.NQuad { - return nil +func initSubcommands() []*x.SubCommand { + var cmdAdd x.SubCommand + cmdAdd.Cmd = &cobra.Command{ + Use: "add", + Short: "Run Dgraph acl tool to add a user or group", + Run: func(cmd *cobra.Command, args []string) { + if err := add(cmdAdd.Conf); err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + }, + } + + addFlags := cmdAdd.Cmd.Flags() + addFlags.StringP("user", "u", "", "The user id to be created") + addFlags.StringP("password", "p", "", "The password for the user") + addFlags.StringP("group", "g", "", "The group id to be created") + + var cmdDel x.SubCommand + cmdDel.Cmd = &cobra.Command{ + Use: "del", + Short: "Run Dgraph acl tool to delete a user or group", + Run: func(cmd *cobra.Command, args []string) { + if err := del(cmdDel.Conf); err != nil { + fmt.Printf("Unable to delete the user: %v\n", err) + os.Exit(1) + } + }, + } + + delFlags := cmdDel.Cmd.Flags() + delFlags.StringP("user", "u", "", "The user id to be deleted") + delFlags.StringP("group", "g", "", "The group id to be deleted") + + var cmdMod x.SubCommand + cmdMod.Cmd = &cobra.Command{ + Use: "mod", + Short: "Run Dgraph acl tool to modify a user's password, a user's group list, or a" + + "group's predicate permissions", + 
Run: func(cmd *cobra.Command, args []string) { + if err := mod(cmdMod.Conf); err != nil { + fmt.Printf("Unable to modify: %v\n", err) + os.Exit(1) + } + }, + } + + modFlags := cmdMod.Cmd.Flags() + modFlags.StringP("user", "u", "", "The user id to be changed") + modFlags.BoolP("new_password", "n", false, "Whether to reset password for the user") + modFlags.StringP("group_list", "l", defaultGroupList, + "The list of groups to be set for the user") + modFlags.StringP("group", "g", "", "The group whose permission is to be changed") + modFlags.StringP("pred", "p", "", "The predicates whose acls are to be changed") + modFlags.IntP("perm", "m", 0, "The acl represented using "+ + "an integer: 4 for read, 2 for write, and 1 for modify. Use a negative value to remove a "+ + "predicate from the group") + + var cmdInfo x.SubCommand + cmdInfo.Cmd = &cobra.Command{ + Use: "info", + Short: "Show info about a user or group", + Run: func(cmd *cobra.Command, args []string) { + if err := info(cmdInfo.Conf); err != nil { + fmt.Printf("Unable to show info: %v\n", err) + os.Exit(1) + } + }, + } + infoFlags := cmdInfo.Cmd.Flags() + infoFlags.StringP("user", "u", "", "The user to be shown") + infoFlags.StringP("group", "g", "", "The group to be shown") + return []*x.SubCommand{&cmdAdd, &cmdDel, &cmdMod, &cmdInfo} } diff --git a/ee/acl/run_ee.go b/ee/acl/run_ee.go deleted file mode 100644 index dc705054129..00000000000 --- a/ee/acl/run_ee.go +++ /dev/null @@ -1,133 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package acl - -import ( - "fmt" - "os" - - "github.com/golang/glog" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/hypermodeinc/dgraph/v24/x" -) - -var ( - // CmdAcl is the sub-command used to manage the ACL system. 
- CmdAcl x.SubCommand -) - -const defaultGroupList = "dgraph-unused-group" - -func init() { - CmdAcl.Cmd = &cobra.Command{ - Use: "acl", - Short: "Run the Dgraph Enterprise Edition ACL tool", - Annotations: map[string]string{"group": "security"}, - } - CmdAcl.Cmd.SetHelpTemplate(x.NonRootTemplate) - flag := CmdAcl.Cmd.PersistentFlags() - flag.StringP("alpha", "a", "127.0.0.1:9080", "Dgraph Alpha gRPC server address") - flag.String("guardian-creds", "", `Login credentials for the guardian - user defines the username to login. - password defines the password of the user. - namespace defines the namespace to log into. - Sample flag could look like --guardian-creds user=username;password=mypass;namespace=2`) - - // --tls SuperFlag - x.RegisterClientTLSFlags(flag) - - subcommands := initSubcommands() - for _, sc := range subcommands { - CmdAcl.Cmd.AddCommand(sc.Cmd) - sc.Conf = viper.New() - if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { - glog.Fatalf("Unable to bind flags for command %v: %v", sc, err) - } - if err := sc.Conf.BindPFlags(CmdAcl.Cmd.PersistentFlags()); err != nil { - glog.Fatalf("Unable to bind persistent flags from acl for command %v: %v", sc, err) - } - sc.Conf.SetEnvPrefix(sc.EnvPrefix) - } -} - -func initSubcommands() []*x.SubCommand { - var cmdAdd x.SubCommand - cmdAdd.Cmd = &cobra.Command{ - Use: "add", - Short: "Run Dgraph acl tool to add a user or group", - Run: func(cmd *cobra.Command, args []string) { - if err := add(cmdAdd.Conf); err != nil { - fmt.Printf("%v\n", err) - os.Exit(1) - } - }, - } - - addFlags := cmdAdd.Cmd.Flags() - addFlags.StringP("user", "u", "", "The user id to be created") - addFlags.StringP("password", "p", "", "The password for the user") - addFlags.StringP("group", "g", "", "The group id to be created") - - var cmdDel x.SubCommand - cmdDel.Cmd = &cobra.Command{ - Use: "del", - Short: "Run Dgraph acl tool to delete a user or group", - Run: func(cmd *cobra.Command, args []string) { - if err := del(cmdDel.Conf); 
err != nil { - fmt.Printf("Unable to delete the user: %v\n", err) - os.Exit(1) - } - }, - } - - delFlags := cmdDel.Cmd.Flags() - delFlags.StringP("user", "u", "", "The user id to be deleted") - delFlags.StringP("group", "g", "", "The group id to be deleted") - - var cmdMod x.SubCommand - cmdMod.Cmd = &cobra.Command{ - Use: "mod", - Short: "Run Dgraph acl tool to modify a user's password, a user's group list, or a" + - "group's predicate permissions", - Run: func(cmd *cobra.Command, args []string) { - if err := mod(cmdMod.Conf); err != nil { - fmt.Printf("Unable to modify: %v\n", err) - os.Exit(1) - } - }, - } - - modFlags := cmdMod.Cmd.Flags() - modFlags.StringP("user", "u", "", "The user id to be changed") - modFlags.BoolP("new_password", "n", false, "Whether to reset password for the user") - modFlags.StringP("group_list", "l", defaultGroupList, - "The list of groups to be set for the user") - modFlags.StringP("group", "g", "", "The group whose permission is to be changed") - modFlags.StringP("pred", "p", "", "The predicates whose acls are to be changed") - modFlags.IntP("perm", "m", 0, "The acl represented using "+ - "an integer: 4 for read, 2 for write, and 1 for modify. Use a negative value to remove a "+ - "predicate from the group") - - var cmdInfo x.SubCommand - cmdInfo.Cmd = &cobra.Command{ - Use: "info", - Short: "Show info about a user or group", - Run: func(cmd *cobra.Command, args []string) { - if err := info(cmdInfo.Conf); err != nil { - fmt.Printf("Unable to show info: %v\n", err) - os.Exit(1) - } - }, - } - infoFlags := cmdInfo.Cmd.Flags() - infoFlags.StringP("user", "u", "", "The user to be shown") - infoFlags.StringP("group", "g", "", "The group to be shown") - return []*x.SubCommand{&cmdAdd, &cmdDel, &cmdMod, &cmdInfo} -} diff --git a/ee/acl/upgrade_test.go b/ee/acl/upgrade_test.go index 5ba613a1ec2..c950d6dffc5 100644 --- a/ee/acl/upgrade_test.go +++ b/ee/acl/upgrade_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/acl/utils.go b/ee/acl/utils.go index 1319da7fe75..59e8dc6a77f 100644 --- a/ee/acl/utils.go +++ b/ee/acl/utils.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/ee/audit/audit.go b/ee/audit/audit.go index 2715abb3ff2..d07d90a567a 100644 --- a/ee/audit/audit.go +++ b/ee/audit/audit.go @@ -1,30 +1,142 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit -import "github.com/hypermodeinc/dgraph/v24/x" +import ( + "fmt" + "math" + "os" + "sync/atomic" + + "github.com/golang/glog" + + "github.com/dgraph-io/ristretto/v2/z" + "github.com/hypermodeinc/dgraph/v24/worker" + "github.com/hypermodeinc/dgraph/v24/x" +) + +const ( + defaultAuditFilenameF = "%s_audit_%d_%d.log" + NodeTypeAlpha = "alpha" + NodeTypeZero = "zero" +) -type AuditConf struct { - Dir string +var auditEnabled uint32 + +type AuditEvent struct { + User string + Namespace uint64 + ServerHost string + ClientHost string + Endpoint string + ReqType string + Req string + Status string + QueryParams map[string][]string +} + +const ( + UnauthorisedUser = "UnauthorisedUser" + UnknownUser = "UnknownUser" + UnknownNamespace = math.MaxUint64 + PoorManAuth = "PoorManAuth" + Grpc = "Grpc" + Http = "Http" + WebSocket = "Websocket" +) + +var auditor = &auditLogger{} + +type auditLogger struct { + log *x.Logger } func GetAuditConf(conf string) *x.LoggerConf { - return nil + if conf == "" || conf == worker.AuditDefaults { + return nil + } + auditFlag := z.NewSuperFlag(conf).MergeAndCheckDefault(worker.AuditDefaults) + out := auditFlag.GetString("output") + if out != "stdout" { + out = auditFlag.GetPath("output") + } + x.AssertTruef(out != "", "out flag is not provided for the audit logs") + encBytes, err := readAuditEncKey(auditFlag) + x.Check(err) + 
return &x.LoggerConf{ + Compress: auditFlag.GetBool("compress"), + Output: out, + EncryptionKey: encBytes, + Days: auditFlag.GetInt64("days"), + Size: auditFlag.GetInt64("size"), + MessageKey: "endpoint", + } +} + +func readAuditEncKey(conf *z.SuperFlag) ([]byte, error) { + encFile := conf.GetPath("encrypt-file") + if encFile == "" { + return nil, nil + } + encKey, err := os.ReadFile(encFile) + if err != nil { + return nil, err + } + return encKey, nil } +// InitAuditorIfNecessary accepts conf and enterprise edition check function. +// This method keep tracks whether cluster is part of enterprise edition or not. +// It pools eeEnabled function every five minutes to check if the license is still valid or not. func InitAuditorIfNecessary(conf *x.LoggerConf) error { - return nil + if conf == nil { + return nil + } + return InitAuditor(conf, uint64(worker.GroupId()), worker.NodeId()) } +// InitAuditor initializes the auditor. +// This method doesnt keep track of whether cluster is part of enterprise edition or not. +// Client has to keep track of that. func InitAuditor(conf *x.LoggerConf, gId, nId uint64) error { + ntype := NodeTypeAlpha + if gId == 0 { + ntype = NodeTypeZero + } + var err error + filename := fmt.Sprintf(defaultAuditFilenameF, ntype, gId, nId) + if auditor.log, err = x.InitLogger(conf, filename); err != nil { + return err + } + atomic.StoreUint32(&auditEnabled, 1) + glog.Infoln("audit logs are enabled") return nil } +// Close stops the ticker and sync the pending logs in buffer. +// It also sets the log to nil, because its being called by zero when license expires. +// If license added, InitLogger will take care of the file. 
func Close() { - return + if atomic.LoadUint32(&auditEnabled) == 0 { + return + } + auditor.log.Sync() + auditor.log = nil + glog.Infoln("audit logs are closed.") +} + +func (a *auditLogger) Audit(event *AuditEvent) { + a.log.AuditI(event.Endpoint, + "level", "AUDIT", + "user", event.User, + "namespace", event.Namespace, + "server", event.ServerHost, + "client", event.ClientHost, + "req_type", event.ReqType, + "req_body", event.Req, + "query_param", event.QueryParams, + "status", event.Status) } diff --git a/ee/audit/audit_ee.go b/ee/audit/audit_ee.go deleted file mode 100644 index f86de316edd..00000000000 --- a/ee/audit/audit_ee.go +++ /dev/null @@ -1,144 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package audit - -import ( - "fmt" - "math" - "os" - "sync/atomic" - - "github.com/golang/glog" - - "github.com/dgraph-io/ristretto/v2/z" - "github.com/hypermodeinc/dgraph/v24/worker" - "github.com/hypermodeinc/dgraph/v24/x" -) - -const ( - defaultAuditFilenameF = "%s_audit_%d_%d.log" - NodeTypeAlpha = "alpha" - NodeTypeZero = "zero" -) - -var auditEnabled uint32 - -type AuditEvent struct { - User string - Namespace uint64 - ServerHost string - ClientHost string - Endpoint string - ReqType string - Req string - Status string - QueryParams map[string][]string -} - -const ( - UnauthorisedUser = "UnauthorisedUser" - UnknownUser = "UnknownUser" - UnknownNamespace = math.MaxUint64 - PoorManAuth = "PoorManAuth" - Grpc = "Grpc" - Http = "Http" - WebSocket = "Websocket" -) - -var auditor = &auditLogger{} - -type auditLogger struct { - log *x.Logger -} - -func GetAuditConf(conf string) *x.LoggerConf { - if conf == "" || conf == worker.AuditDefaults { - return nil - } - auditFlag := z.NewSuperFlag(conf).MergeAndCheckDefault(worker.AuditDefaults) - out := auditFlag.GetString("output") - if out != "stdout" { - out = auditFlag.GetPath("output") - } - x.AssertTruef(out != "", "out flag is not provided for the audit logs") - 
encBytes, err := readAuditEncKey(auditFlag) - x.Check(err) - return &x.LoggerConf{ - Compress: auditFlag.GetBool("compress"), - Output: out, - EncryptionKey: encBytes, - Days: auditFlag.GetInt64("days"), - Size: auditFlag.GetInt64("size"), - MessageKey: "endpoint", - } -} - -func readAuditEncKey(conf *z.SuperFlag) ([]byte, error) { - encFile := conf.GetPath("encrypt-file") - if encFile == "" { - return nil, nil - } - encKey, err := os.ReadFile(encFile) - if err != nil { - return nil, err - } - return encKey, nil -} - -// InitAuditorIfNecessary accepts conf and enterprise edition check function. -// This method keep tracks whether cluster is part of enterprise edition or not. -// It pools eeEnabled function every five minutes to check if the license is still valid or not. -func InitAuditorIfNecessary(conf *x.LoggerConf) error { - if conf == nil { - return nil - } - return InitAuditor(conf, uint64(worker.GroupId()), worker.NodeId()) -} - -// InitAuditor initializes the auditor. -// This method doesnt keep track of whether cluster is part of enterprise edition or not. -// Client has to keep track of that. -func InitAuditor(conf *x.LoggerConf, gId, nId uint64) error { - ntype := NodeTypeAlpha - if gId == 0 { - ntype = NodeTypeZero - } - var err error - filename := fmt.Sprintf(defaultAuditFilenameF, ntype, gId, nId) - if auditor.log, err = x.InitLogger(conf, filename); err != nil { - return err - } - atomic.StoreUint32(&auditEnabled, 1) - glog.Infoln("audit logs are enabled") - return nil -} - -// Close stops the ticker and sync the pending logs in buffer. -// It also sets the log to nil, because its being called by zero when license expires. -// If license added, InitLogger will take care of the file. 
-func Close() { - if atomic.LoadUint32(&auditEnabled) == 0 { - return - } - auditor.log.Sync() - auditor.log = nil - glog.Infoln("audit logs are closed.") -} - -func (a *auditLogger) Audit(event *AuditEvent) { - a.log.AuditI(event.Endpoint, - "level", "AUDIT", - "user", event.User, - "namespace", event.Namespace, - "server", event.ServerHost, - "client", event.ClientHost, - "req_type", event.ReqType, - "req_body", event.Req, - "query_param", event.QueryParams, - "status", event.Status) -} diff --git a/ee/audit/interceptor.go b/ee/audit/interceptor.go index 5433166ab0f..3b1cf40d79a 100644 --- a/ee/audit/interceptor.go +++ b/ee/audit/interceptor.go @@ -1,32 +1,393 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit import ( + "bytes" + "compress/gzip" "context" + "encoding/json" + "fmt" + "io" + "net" "net/http" + "regexp" + "strconv" + "strings" + "sync/atomic" + "github.com/dgraph-io/gqlparser/v2/ast" + "github.com/dgraph-io/gqlparser/v2/parser" "github.com/hypermodeinc/dgraph/v24/graphql/schema" + "github.com/hypermodeinc/dgraph/v24/x" + "github.com/golang/glog" + "github.com/gorilla/websocket" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const ( + maxReqLength = 4 << 10 // 4 KB ) +var skipApis = map[string]bool{ + // raft server + "Heartbeat": true, + "RaftMessage": true, + "JoinCluster": true, + "IsPeer": true, + // zero server + "StreamMembership": true, + "UpdateMembership": true, + "Oracle": true, + "Timestamps": true, + "ShouldServe": true, + "Connect": true, + // health server + "Check": true, + "Watch": true, +} + +var skipEPs = map[string]bool{ + // list of endpoints that needs to be skipped + "/health": true, + "/state": true, + "/probe/graphql": true, +} + func AuditRequestGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 
handler grpc.UnaryHandler) (interface{}, error) { - return handler(ctx, req) + skip := func(method string) bool { + return skipApis[info.FullMethod[strings.LastIndex(info.FullMethod, "/")+1:]] + } + + if atomic.LoadUint32(&auditEnabled) == 0 || skip(info.FullMethod) { + return handler(ctx, req) + } + response, err := handler(ctx, req) + auditGrpc(ctx, req, info) + return response, err } func AuditRequestHttp(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) + skip := func(method string) bool { + return skipEPs[r.URL.Path] + } + + if atomic.LoadUint32(&auditEnabled) == 0 || skip(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + + // Websocket connection in graphQl happens differently. We only get access tokens and + // metadata in payload later once the connection is upgraded to correct protocol. + // Doc: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md + // + // Auditing for websocket connections will be handled by graphql/admin/http.go:154#Subscribe + for _, subprotocol := range websocket.Subprotocols(r) { + if subprotocol == "graphql-ws" { + next.ServeHTTP(w, r) + return + } + } + + rw := NewResponseWriter(w) + var buf bytes.Buffer + tee := io.TeeReader(r.Body, &buf) + r.Body = io.NopCloser(tee) + next.ServeHTTP(rw, r) + r.Body = io.NopCloser(bytes.NewReader(buf.Bytes())) + auditHttp(rw, r) }) } func AuditWebSockets(ctx context.Context, req *schema.Request) { - return + if atomic.LoadUint32(&auditEnabled) == 0 { + return + } + + namespace := uint64(0) + var err error + var user string + // TODO(anurag): X-Dgraph-AccessToken should be exported as a constant + if token := req.Header.Get("X-Dgraph-AccessToken"); token != "" { + user = getUser(token, false) + namespace, err = x.ExtractNamespaceFromJwt(token) + if err != nil { + glog.Warningf("Error while auditing websockets: %s", err) + } + } else if token := req.Header.Get("X-Dgraph-AuthToken"); token 
!= "" { + user = getUser(token, true) + } else { + user = getUser("", false) + } + + ip := "" + if peerInfo, ok := peer.FromContext(ctx); ok { + ip, _, _ = net.SplitHostPort(peerInfo.Addr.String()) + } + + auditor.Audit(&AuditEvent{ + User: user, + Namespace: namespace, + ServerHost: x.WorkerConfig.MyAddr, + ClientHost: ip, + Endpoint: "/graphql", + ReqType: WebSocket, + Req: truncate(req.Query, maxReqLength), + Status: http.StatusText(http.StatusOK), + QueryParams: nil, + }) +} + +func auditGrpc(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) { + clientHost := "" + if p, ok := peer.FromContext(ctx); ok { + clientHost = p.Addr.String() + } + var user string + var namespace uint64 + var err error + extractUser := func(md metadata.MD) { + if t := md.Get("accessJwt"); len(t) > 0 { + user = getUser(t[0], false) + } else if t := md.Get("auth-token"); len(t) > 0 { + user = getUser(t[0], true) + } else { + user = getUser("", false) + } + } + + extractNamespace := func(md metadata.MD) { + ns := md.Get("namespace") + if len(ns) == 0 { + namespace = UnknownNamespace + } else { + if namespace, err = strconv.ParseUint(ns[0], 10, 64); err != nil { + namespace = UnknownNamespace + } + } + } + + if md, ok := metadata.FromIncomingContext(ctx); ok { + extractUser(md) + extractNamespace(md) + } + + cd := codes.Unknown + if serr, ok := status.FromError(err); ok { + cd = serr.Code() + } + + reqBody := checkRequestBody(Grpc, info.FullMethod[strings.LastIndex(info.FullMethod, + "/")+1:], fmt.Sprintf("%+v", req)) + auditor.Audit(&AuditEvent{ + User: user, + Namespace: namespace, + ServerHost: x.WorkerConfig.MyAddr, + ClientHost: clientHost, + Endpoint: info.FullMethod, + ReqType: Grpc, + Req: truncate(reqBody, maxReqLength), + Status: cd.String(), + }) +} + +func auditHttp(w *ResponseWriter, r *http.Request) { + body := getRequestBody(r) + var user string + if token := r.Header.Get("X-Dgraph-AccessToken"); token != "" { + user = getUser(token, false) + } else if token 
:= r.Header.Get("X-Dgraph-AuthToken"); token != "" { + user = getUser(token, true) + } else { + user = getUser("", false) + } + + auditor.Audit(&AuditEvent{ + User: user, + Namespace: x.ExtractNamespaceHTTP(r), + ServerHost: x.WorkerConfig.MyAddr, + ClientHost: r.RemoteAddr, + Endpoint: r.URL.Path, + ReqType: Http, + Req: truncate(checkRequestBody(Http, r.URL.Path, string(body)), maxReqLength), + Status: http.StatusText(w.statusCode), + QueryParams: r.URL.Query(), + }) +} + +// password fields are accessible only via /admin endpoint hence, +// this will be only called with /admin endpoint +func maskPasswordFieldsInGQL(req string) string { + var gqlReq schema.Request + err := json.Unmarshal([]byte(req), &gqlReq) + if err != nil { + glog.Errorf("unable to unmarshal gql request %v", err) + return req + } + query, gErr := parser.ParseQuery(&ast.Source{ + Input: gqlReq.Query, + }) + if gErr != nil { + glog.Errorf("unable to parse gql request %+v", gErr) + return req + } + if len(query.Operations) == 0 { + return req + } + var variableName string + for _, op := range query.Operations { + if op.Operation != ast.Mutation || len(op.SelectionSet) == 0 { + continue + } + + for _, ss := range op.SelectionSet { + if f, ok := ss.(*ast.Field); ok && len(f.Arguments) > 0 { + variableName = getMaskedFieldVarName(f) + } + } + } + + // no variable present + if variableName == "" { + regex, err := regexp.Compile( + `password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) + if err != nil { + return req + } + return regex.ReplaceAllString(req, "*******") + } + regex, err := regexp.Compile( + fmt.Sprintf(`"%s[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`, + variableName[1:])) + if err != nil { + return req + } + return regex.ReplaceAllString(req, "*******") +} + +func getMaskedFieldVarName(f *ast.Field) string { + switch f.Name { + case "resetPassword": + for _, a := range f.Arguments { + if a.Name != "input" || a.Value == nil || a.Value.Children == nil { + continue 
+ } + + for _, c := range a.Value.Children { + if c.Name == "password" && c.Value.Kind == ast.Variable { + return c.Value.String() + } + } + } + case "login": + for _, a := range f.Arguments { + if a.Name == "password" && a.Value.Kind == ast.Variable { + return a.Value.String() + } + } + } + return "" +} + +var skipReqBodyGrpc = map[string]bool{ + "Login": true, +} + +func checkRequestBody(reqType string, path string, body string) string { + switch reqType { + case Grpc: + if skipReqBodyGrpc[path] { + regex, err := regexp.Compile( + `password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) + if err != nil { + return body + } + body = regex.ReplaceAllString(body, "*******") + } + case Http: + if path == "/admin" { + return maskPasswordFieldsInGQL(body) + } else if path == "/grapqhl" { + regex, err := regexp.Compile( + `check[\s]?(.*?)[\s]?Password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) + if err != nil { + return body + } + body = regex.ReplaceAllString(body, "*******") + } + } + return body +} + +func getRequestBody(r *http.Request) []byte { + var in io.Reader = r.Body + if enc := r.Header.Get("Content-Encoding"); enc != "" && enc != "identity" { + if enc == "gzip" { + gz, err := gzip.NewReader(r.Body) + if err != nil { + return []byte(err.Error()) + } + defer gz.Close() + in = gz + } else { + return []byte("unknown encoding") + } + } + + body, err := io.ReadAll(in) + if err != nil { + return []byte(err.Error()) + } + return body +} + +func getUser(token string, poorman bool) string { + if poorman { + return PoorManAuth + } + var user string + var err error + if token == "" { + if x.WorkerConfig.AclEnabled { + user = UnauthorisedUser + } + } else { + if user, err = x.ExtractUserName(token); err != nil { + user = UnknownUser + } + } + return user +} + +type ResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { + // WriteHeader(int) is not called if our response implicitly 
returns 200 OK, so + // we default to that status code. + return &ResponseWriter{w, http.StatusOK} +} + +func (rw *ResponseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func truncate(s string, l int) string { + if len(s) > l { + return s[:l] + } + return s } diff --git a/ee/audit/interceptor_ee.go b/ee/audit/interceptor_ee.go deleted file mode 100644 index 25fe6b91212..00000000000 --- a/ee/audit/interceptor_ee.go +++ /dev/null @@ -1,395 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package audit - -import ( - "bytes" - "compress/gzip" - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "regexp" - "strconv" - "strings" - "sync/atomic" - - "github.com/dgraph-io/gqlparser/v2/ast" - "github.com/dgraph-io/gqlparser/v2/parser" - "github.com/hypermodeinc/dgraph/v24/graphql/schema" - "github.com/hypermodeinc/dgraph/v24/x" - - "github.com/golang/glog" - "github.com/gorilla/websocket" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" -) - -const ( - maxReqLength = 4 << 10 // 4 KB -) - -var skipApis = map[string]bool{ - // raft server - "Heartbeat": true, - "RaftMessage": true, - "JoinCluster": true, - "IsPeer": true, - // zero server - "StreamMembership": true, - "UpdateMembership": true, - "Oracle": true, - "Timestamps": true, - "ShouldServe": true, - "Connect": true, - // health server - "Check": true, - "Watch": true, -} - -var skipEPs = map[string]bool{ - // list of endpoints that needs to be skipped - "/health": true, - "/state": true, - "/probe/graphql": true, -} - -func AuditRequestGRPC(ctx context.Context, req interface{}, - info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - skip := func(method string) bool { - return skipApis[info.FullMethod[strings.LastIndex(info.FullMethod, "/")+1:]] - } - - if 
atomic.LoadUint32(&auditEnabled) == 0 || skip(info.FullMethod) { - return handler(ctx, req) - } - response, err := handler(ctx, req) - auditGrpc(ctx, req, info) - return response, err -} - -func AuditRequestHttp(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - skip := func(method string) bool { - return skipEPs[r.URL.Path] - } - - if atomic.LoadUint32(&auditEnabled) == 0 || skip(r.URL.Path) { - next.ServeHTTP(w, r) - return - } - - // Websocket connection in graphQl happens differently. We only get access tokens and - // metadata in payload later once the connection is upgraded to correct protocol. - // Doc: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.9.4/PROTOCOL.md - // - // Auditing for websocket connections will be handled by graphql/admin/http.go:154#Subscribe - for _, subprotocol := range websocket.Subprotocols(r) { - if subprotocol == "graphql-ws" { - next.ServeHTTP(w, r) - return - } - } - - rw := NewResponseWriter(w) - var buf bytes.Buffer - tee := io.TeeReader(r.Body, &buf) - r.Body = io.NopCloser(tee) - next.ServeHTTP(rw, r) - r.Body = io.NopCloser(bytes.NewReader(buf.Bytes())) - auditHttp(rw, r) - }) -} - -func AuditWebSockets(ctx context.Context, req *schema.Request) { - if atomic.LoadUint32(&auditEnabled) == 0 { - return - } - - namespace := uint64(0) - var err error - var user string - // TODO(anurag): X-Dgraph-AccessToken should be exported as a constant - if token := req.Header.Get("X-Dgraph-AccessToken"); token != "" { - user = getUser(token, false) - namespace, err = x.ExtractNamespaceFromJwt(token) - if err != nil { - glog.Warningf("Error while auditing websockets: %s", err) - } - } else if token := req.Header.Get("X-Dgraph-AuthToken"); token != "" { - user = getUser(token, true) - } else { - user = getUser("", false) - } - - ip := "" - if peerInfo, ok := peer.FromContext(ctx); ok { - ip, _, _ = net.SplitHostPort(peerInfo.Addr.String()) - } - - 
auditor.Audit(&AuditEvent{ - User: user, - Namespace: namespace, - ServerHost: x.WorkerConfig.MyAddr, - ClientHost: ip, - Endpoint: "/graphql", - ReqType: WebSocket, - Req: truncate(req.Query, maxReqLength), - Status: http.StatusText(http.StatusOK), - QueryParams: nil, - }) -} - -func auditGrpc(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) { - clientHost := "" - if p, ok := peer.FromContext(ctx); ok { - clientHost = p.Addr.String() - } - var user string - var namespace uint64 - var err error - extractUser := func(md metadata.MD) { - if t := md.Get("accessJwt"); len(t) > 0 { - user = getUser(t[0], false) - } else if t := md.Get("auth-token"); len(t) > 0 { - user = getUser(t[0], true) - } else { - user = getUser("", false) - } - } - - extractNamespace := func(md metadata.MD) { - ns := md.Get("namespace") - if len(ns) == 0 { - namespace = UnknownNamespace - } else { - if namespace, err = strconv.ParseUint(ns[0], 10, 64); err != nil { - namespace = UnknownNamespace - } - } - } - - if md, ok := metadata.FromIncomingContext(ctx); ok { - extractUser(md) - extractNamespace(md) - } - - cd := codes.Unknown - if serr, ok := status.FromError(err); ok { - cd = serr.Code() - } - - reqBody := checkRequestBody(Grpc, info.FullMethod[strings.LastIndex(info.FullMethod, - "/")+1:], fmt.Sprintf("%+v", req)) - auditor.Audit(&AuditEvent{ - User: user, - Namespace: namespace, - ServerHost: x.WorkerConfig.MyAddr, - ClientHost: clientHost, - Endpoint: info.FullMethod, - ReqType: Grpc, - Req: truncate(reqBody, maxReqLength), - Status: cd.String(), - }) -} - -func auditHttp(w *ResponseWriter, r *http.Request) { - body := getRequestBody(r) - var user string - if token := r.Header.Get("X-Dgraph-AccessToken"); token != "" { - user = getUser(token, false) - } else if token := r.Header.Get("X-Dgraph-AuthToken"); token != "" { - user = getUser(token, true) - } else { - user = getUser("", false) - } - - auditor.Audit(&AuditEvent{ - User: user, - Namespace: 
x.ExtractNamespaceHTTP(r), - ServerHost: x.WorkerConfig.MyAddr, - ClientHost: r.RemoteAddr, - Endpoint: r.URL.Path, - ReqType: Http, - Req: truncate(checkRequestBody(Http, r.URL.Path, string(body)), maxReqLength), - Status: http.StatusText(w.statusCode), - QueryParams: r.URL.Query(), - }) -} - -// password fields are accessible only via /admin endpoint hence, -// this will be only called with /admin endpoint -func maskPasswordFieldsInGQL(req string) string { - var gqlReq schema.Request - err := json.Unmarshal([]byte(req), &gqlReq) - if err != nil { - glog.Errorf("unable to unmarshal gql request %v", err) - return req - } - query, gErr := parser.ParseQuery(&ast.Source{ - Input: gqlReq.Query, - }) - if gErr != nil { - glog.Errorf("unable to parse gql request %+v", gErr) - return req - } - if len(query.Operations) == 0 { - return req - } - var variableName string - for _, op := range query.Operations { - if op.Operation != ast.Mutation || len(op.SelectionSet) == 0 { - continue - } - - for _, ss := range op.SelectionSet { - if f, ok := ss.(*ast.Field); ok && len(f.Arguments) > 0 { - variableName = getMaskedFieldVarName(f) - } - } - } - - // no variable present - if variableName == "" { - regex, err := regexp.Compile( - `password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) - if err != nil { - return req - } - return regex.ReplaceAllString(req, "*******") - } - regex, err := regexp.Compile( - fmt.Sprintf(`"%s[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`, - variableName[1:])) - if err != nil { - return req - } - return regex.ReplaceAllString(req, "*******") -} - -func getMaskedFieldVarName(f *ast.Field) string { - switch f.Name { - case "resetPassword": - for _, a := range f.Arguments { - if a.Name != "input" || a.Value == nil || a.Value.Children == nil { - continue - } - - for _, c := range a.Value.Children { - if c.Name == "password" && c.Value.Kind == ast.Variable { - return c.Value.String() - } - } - } - case "login": - for _, a := range 
f.Arguments { - if a.Name == "password" && a.Value.Kind == ast.Variable { - return a.Value.String() - } - } - } - return "" -} - -var skipReqBodyGrpc = map[string]bool{ - "Login": true, -} - -func checkRequestBody(reqType string, path string, body string) string { - switch reqType { - case Grpc: - if skipReqBodyGrpc[path] { - regex, err := regexp.Compile( - `password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) - if err != nil { - return body - } - body = regex.ReplaceAllString(body, "*******") - } - case Http: - if path == "/admin" { - return maskPasswordFieldsInGQL(body) - } else if path == "/grapqhl" { - regex, err := regexp.Compile( - `check[\s]?(.*?)[\s]?Password[\s]?(.*?)[\s]?:[\s]?(.*?)[\s]?"[\s]?(.*?)[\s]?"`) - if err != nil { - return body - } - body = regex.ReplaceAllString(body, "*******") - } - } - return body -} - -func getRequestBody(r *http.Request) []byte { - var in io.Reader = r.Body - if enc := r.Header.Get("Content-Encoding"); enc != "" && enc != "identity" { - if enc == "gzip" { - gz, err := gzip.NewReader(r.Body) - if err != nil { - return []byte(err.Error()) - } - defer gz.Close() - in = gz - } else { - return []byte("unknown encoding") - } - } - - body, err := io.ReadAll(in) - if err != nil { - return []byte(err.Error()) - } - return body -} - -func getUser(token string, poorman bool) string { - if poorman { - return PoorManAuth - } - var user string - var err error - if token == "" { - if x.WorkerConfig.AclEnabled { - user = UnauthorisedUser - } - } else { - if user, err = x.ExtractUserName(token); err != nil { - user = UnknownUser - } - } - return user -} - -type ResponseWriter struct { - http.ResponseWriter - statusCode int -} - -func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { - // WriteHeader(int) is not called if our response implicitly returns 200 OK, so - // we default to that status code. 
- return &ResponseWriter{w, http.StatusOK} -} - -func (rw *ResponseWriter) WriteHeader(code int) { - rw.statusCode = code - rw.ResponseWriter.WriteHeader(code) -} - -func truncate(s string, l int) string { - if len(s) > l { - return s[:l] - } - return s -} diff --git a/ee/audit/run.go b/ee/audit/run.go index 48d91fdb644..c8c31bec353 100644 --- a/ee/audit/run.go +++ b/ee/audit/run.go @@ -1,14 +1,22 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "fmt" + "io" + "os" + + "github.com/golang/glog" + "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/hypermodeinc/dgraph/v24/x" ) @@ -17,7 +25,249 @@ var CmdAudit x.SubCommand func init() { CmdAudit.Cmd = &cobra.Command{ - Use: "audit", - Short: "Enterprise feature. Not supported in oss version", + Use: "audit", + Short: "Dgraph audit tool", + Annotations: map[string]string{"group": "security"}, + } + CmdAudit.Cmd.SetHelpTemplate(x.NonRootTemplate) + + subcommands := initSubcommands() + for _, sc := range subcommands { + CmdAudit.Cmd.AddCommand(sc.Cmd) + sc.Conf = viper.New() + if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { + glog.Fatalf("Unable to bind flags for command %v: %v", sc, err) + } + if err := sc.Conf.BindPFlags(CmdAudit.Cmd.PersistentFlags()); err != nil { + glog.Fatalf( + "Unable to bind persistent flags from audit for command %v: %v", sc, err) + } + sc.Conf.SetEnvPrefix(sc.EnvPrefix) + } +} + +var decryptCmd x.SubCommand + +func initSubcommands() []*x.SubCommand { + decryptCmd.Cmd = &cobra.Command{ + Use: "decrypt", + Short: "Run Dgraph Audit tool to decrypt audit files", + Run: func(cmd *cobra.Command, args []string) { + if err := run(); err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + }, + } + + decFlags := decryptCmd.Cmd.Flags() + decFlags.String("in", "", "input file that needs to decrypted.") + 
decFlags.String("out", "audit_log_out.log", + "output file to which decrypted output will be dumped.") + decFlags.String("encryption_key_file", "", "path to encrypt files.") + return []*x.SubCommand{&decryptCmd} +} + +func run() error { + key, err := os.ReadFile(decryptCmd.Conf.GetString("encryption_key_file")) + x.Check(err) + if key == nil { + return errors.New("no encryption key provided") + } + + file, err := os.Open(decryptCmd.Conf.GetString("in")) + x.Check(err) + defer func() { + if err := file.Close(); err != nil { + glog.Warningf("error closing file: %v", err) + } + }() + + outfile, err := os.OpenFile(decryptCmd.Conf.GetString("out"), + os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.ModePerm) + x.Check(err) + defer func() { + if err := outfile.Close(); err != nil { + glog.Warningf("error closing file: %v", err) + } + }() + block, err := aes.NewCipher(key) + x.Check(err) + + stat, err := os.Stat(decryptCmd.Conf.GetString("in")) + x.Check(err) + if stat.Size() == 0 { + glog.Info("audit file is empty") + return nil + } + + if err := decrypt(file, outfile, block, stat.Size()); err != nil { + return errors.Wrap(err, "could not decrypt audit log") + } + + glog.Infof("decryption of audit file %s is done: decrypted file is %s", + decryptCmd.Conf.GetString("in"), + decryptCmd.Conf.GetString("out")) + return nil +} + +func decrypt(file io.ReaderAt, outfile io.Writer, block cipher.Block, sz int64) error { + // decrypt header in audit log to verify encryption key + // [16]byte IV + [4]byte len(x.VerificationText) + [11]byte x.VerificationText + decryptHeader := func() ([]byte, int64, error) { + var iterator int64 + iv := make([]byte, aes.BlockSize) + n, err := file.ReadAt(iv, iterator) // get first iv + if err != nil { + return nil, 0, errors.Wrap(err, "unable to read IV") + } + iterator = iterator + int64(n) + 4 // length of verification text encoded in uint32 + + ct := make([]byte, len(x.VerificationText)) + n, err = file.ReadAt(ct, iterator) + if err != nil { + return nil, 
0, errors.Wrap(err, "unable to read verification text") + } + iterator = iterator + int64(n) + + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(ct, ct) + if string(ct) != x.VerificationText { + return nil, 0, errors.New("invalid encryption key provided. Please check your encryption key") + } + return iv, iterator, nil } + + // [12]byte baseIV + [4]byte len(x.VerificationTextDeprecated) + [11]byte x.VerificationTextDeprecated + decryptHeaderDeprecated := func() ([]byte, int64, error) { + var iterator int64 = 0 + + iv := make([]byte, aes.BlockSize) + n, err := file.ReadAt(iv, iterator) + if err != nil { + return nil, 0, errors.Wrap(err, "unable to read IV") + } + iterator = iterator + int64(n) + + ct := make([]byte, len(x.VerificationTextDeprecated)) + n, err = file.ReadAt(ct, iterator) + if err != nil { + return nil, 0, errors.Wrap(err, "unable to read verification text") + } + iterator = iterator + int64(n) + + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(ct, ct) + if string(ct) != x.VerificationTextDeprecated { + return nil, 0, errors.New("invalid encryption key provided. Please check your encryption key") + } + return iv, iterator, nil + } + + useDeprecated := false + iv, iterator, err := decryptHeader() + if err != nil { + // might have an old audit log + iv2, iterator2, err := decryptHeaderDeprecated() + if err != nil { + return errors.New("invalid encryption key provided. Please check your encryption key") + } + // found old audit log + useDeprecated = true + iv, iterator = iv2, iterator2 + } + + // encrypted writes each have the form below + // IV generated for each write + // ################################################################# + // ##### [16]byte IV + [4]byte uint32(len(p)) + [:]byte p ##### + // ################################################################# + decryptBody := func() { + for { + // if its the end of data. 
finish decrypting + if iterator >= sz { + break + } + n, err := file.ReadAt(iv, iterator) + if err != nil { + glog.Warningf("received %v while decrypting audit log", err) + glog.Warningf("read %v bytes, expected %v", n, len(iv)) + break + } + iterator = iterator + 16 + length := make([]byte, 4) + n, err = file.ReadAt(length, iterator) + if err != nil { + glog.Warningf("received %v while decrypting audit log", err) + glog.Warningf("read %v bytes, expected %v", n, len(length)) + break + } + iterator = iterator + int64(n) + + content := make([]byte, binary.BigEndian.Uint32(length)) + n, err = file.ReadAt(content, iterator) + if err != nil { + glog.Warningf("received %v while decrypting audit log", err) + glog.Warningf("read %v bytes, expected %v", n, len(content)) + break + } + iterator = iterator + int64(n) + + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(content, content) + n, err = outfile.Write(content) + if err != nil { + glog.Warningf("received %v while writing decrypted audit log", err) + glog.Warningf("wrote %v bytes, expected to write %v", n, len(content)) + break + } + } + } + + // encrypted writes in body have the form + // baseIV is constant, last 4 bytes vary + // ######################################################## + // ##### [4]byte uint32(len(p)) + [:]byte p ##### + // ######################################################## + decryptBodyDeprecated := func() { + for { + // if its the end of data. 
finish decrypting + if iterator >= sz { + break + } + n, err := file.ReadAt(iv[12:], iterator) + if err != nil { + glog.Warningf("received %v while decrypting audit log", err) + glog.Warningf("read %v bytes, expected %v", n, len(iv[12:])) + break + } + iterator = iterator + int64(n) + + content := make([]byte, binary.BigEndian.Uint32(iv[12:])) + n, err = file.ReadAt(content, iterator) + if err != nil { + glog.Warningf("received %v while decrypting audit log", err) + glog.Warningf("read %v bytes, expected %v", n, len(content)) + break + } + iterator = iterator + int64(n) + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(content, content) + n, err = outfile.Write(content) + if err != nil { + glog.Warningf("received %v while writing decrypted audit log", err) + glog.Warningf("wrote %v bytes, expected to write %v", n, len(content)) + break + } + } + } + + if useDeprecated { + decryptBodyDeprecated() + } else { + decryptBody() + } + + return nil + } diff --git a/ee/audit/run_ee.go b/ee/audit/run_ee.go deleted file mode 100644 index 770ccde957f..00000000000 --- a/ee/audit/run_ee.go +++ /dev/null @@ -1,275 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package audit - -import ( - "crypto/aes" - "crypto/cipher" - "encoding/binary" - "fmt" - "io" - "os" - - "github.com/golang/glog" - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/hypermodeinc/dgraph/v24/x" -) - -var CmdAudit x.SubCommand - -func init() { - CmdAudit.Cmd = &cobra.Command{ - Use: "audit", - Short: "Dgraph audit tool", - Annotations: map[string]string{"group": "security"}, - } - CmdAudit.Cmd.SetHelpTemplate(x.NonRootTemplate) - - subcommands := initSubcommands() - for _, sc := range subcommands { - CmdAudit.Cmd.AddCommand(sc.Cmd) - sc.Conf = viper.New() - if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { - glog.Fatalf("Unable to bind flags for command %v: %v", sc, err) - } - if err := sc.Conf.BindPFlags(CmdAudit.Cmd.PersistentFlags()); err != nil { - glog.Fatalf( - "Unable to bind persistent flags from audit for command %v: %v", sc, err) - } - sc.Conf.SetEnvPrefix(sc.EnvPrefix) - } -} - -var decryptCmd x.SubCommand - -func initSubcommands() []*x.SubCommand { - decryptCmd.Cmd = &cobra.Command{ - Use: "decrypt", - Short: "Run Dgraph Audit tool to decrypt audit files", - Run: func(cmd *cobra.Command, args []string) { - if err := run(); err != nil { - fmt.Printf("%v\n", err) - os.Exit(1) - } - }, - } - - decFlags := decryptCmd.Cmd.Flags() - decFlags.String("in", "", "input file that needs to decrypted.") - decFlags.String("out", "audit_log_out.log", - "output file to which decrypted output will be dumped.") - decFlags.String("encryption_key_file", "", "path to encrypt files.") - return []*x.SubCommand{&decryptCmd} -} - -func run() error { - key, err := os.ReadFile(decryptCmd.Conf.GetString("encryption_key_file")) - x.Check(err) - if key == nil { - return errors.New("no encryption key provided") - } - - file, err := os.Open(decryptCmd.Conf.GetString("in")) - x.Check(err) - defer func() { - if err := file.Close(); err != nil { - glog.Warningf("error closing file: %v", err) - } - }() - - outfile, 
err := os.OpenFile(decryptCmd.Conf.GetString("out"), - os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.ModePerm) - x.Check(err) - defer func() { - if err := outfile.Close(); err != nil { - glog.Warningf("error closing file: %v", err) - } - }() - block, err := aes.NewCipher(key) - x.Check(err) - - stat, err := os.Stat(decryptCmd.Conf.GetString("in")) - x.Check(err) - if stat.Size() == 0 { - glog.Info("audit file is empty") - return nil - } - - if err := decrypt(file, outfile, block, stat.Size()); err != nil { - return errors.Wrap(err, "could not decrypt audit log") - } - - glog.Infof("decryption of audit file %s is done: decrypted file is %s", - decryptCmd.Conf.GetString("in"), - decryptCmd.Conf.GetString("out")) - return nil -} - -func decrypt(file io.ReaderAt, outfile io.Writer, block cipher.Block, sz int64) error { - // decrypt header in audit log to verify encryption key - // [16]byte IV + [4]byte len(x.VerificationText) + [11]byte x.VerificationText - decryptHeader := func() ([]byte, int64, error) { - var iterator int64 - iv := make([]byte, aes.BlockSize) - n, err := file.ReadAt(iv, iterator) // get first iv - if err != nil { - return nil, 0, errors.Wrap(err, "unable to read IV") - } - iterator = iterator + int64(n) + 4 // length of verification text encoded in uint32 - - ct := make([]byte, len(x.VerificationText)) - n, err = file.ReadAt(ct, iterator) - if err != nil { - return nil, 0, errors.Wrap(err, "unable to read verification text") - } - iterator = iterator + int64(n) - - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(ct, ct) - if string(ct) != x.VerificationText { - return nil, 0, errors.New("invalid encryption key provided. 
Please check your encryption key") - } - return iv, iterator, nil - } - - // [12]byte baseIV + [4]byte len(x.VerificationTextDeprecated) + [11]byte x.VerificationTextDeprecated - decryptHeaderDeprecated := func() ([]byte, int64, error) { - var iterator int64 = 0 - - iv := make([]byte, aes.BlockSize) - n, err := file.ReadAt(iv, iterator) - if err != nil { - return nil, 0, errors.Wrap(err, "unable to read IV") - } - iterator = iterator + int64(n) - - ct := make([]byte, len(x.VerificationTextDeprecated)) - n, err = file.ReadAt(ct, iterator) - if err != nil { - return nil, 0, errors.Wrap(err, "unable to read verification text") - } - iterator = iterator + int64(n) - - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(ct, ct) - if string(ct) != x.VerificationTextDeprecated { - return nil, 0, errors.New("invalid encryption key provided. Please check your encryption key") - } - return iv, iterator, nil - } - - useDeprecated := false - iv, iterator, err := decryptHeader() - if err != nil { - // might have an old audit log - iv2, iterator2, err := decryptHeaderDeprecated() - if err != nil { - return errors.New("invalid encryption key provided. Please check your encryption key") - } - // found old audit log - useDeprecated = true - iv, iterator = iv2, iterator2 - } - - // encrypted writes each have the form below - // IV generated for each write - // ################################################################# - // ##### [16]byte IV + [4]byte uint32(len(p)) + [:]byte p ##### - // ################################################################# - decryptBody := func() { - for { - // if its the end of data. 
finish decrypting - if iterator >= sz { - break - } - n, err := file.ReadAt(iv, iterator) - if err != nil { - glog.Warningf("received %v while decrypting audit log", err) - glog.Warningf("read %v bytes, expected %v", n, len(iv)) - break - } - iterator = iterator + 16 - length := make([]byte, 4) - n, err = file.ReadAt(length, iterator) - if err != nil { - glog.Warningf("received %v while decrypting audit log", err) - glog.Warningf("read %v bytes, expected %v", n, len(length)) - break - } - iterator = iterator + int64(n) - - content := make([]byte, binary.BigEndian.Uint32(length)) - n, err = file.ReadAt(content, iterator) - if err != nil { - glog.Warningf("received %v while decrypting audit log", err) - glog.Warningf("read %v bytes, expected %v", n, len(content)) - break - } - iterator = iterator + int64(n) - - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(content, content) - n, err = outfile.Write(content) - if err != nil { - glog.Warningf("received %v while writing decrypted audit log", err) - glog.Warningf("wrote %v bytes, expected to write %v", n, len(content)) - break - } - } - } - - // encrypted writes in body have the form - // baseIV is constant, last 4 bytes vary - // ######################################################## - // ##### [4]byte uint32(len(p)) + [:]byte p ##### - // ######################################################## - decryptBodyDeprecated := func() { - for { - // if its the end of data. 
finish decrypting - if iterator >= sz { - break - } - n, err := file.ReadAt(iv[12:], iterator) - if err != nil { - glog.Warningf("received %v while decrypting audit log", err) - glog.Warningf("read %v bytes, expected %v", n, len(iv[12:])) - break - } - iterator = iterator + int64(n) - - content := make([]byte, binary.BigEndian.Uint32(iv[12:])) - n, err = file.ReadAt(content, iterator) - if err != nil { - glog.Warningf("received %v while decrypting audit log", err) - glog.Warningf("read %v bytes, expected %v", n, len(content)) - break - } - iterator = iterator + int64(n) - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(content, content) - n, err = outfile.Write(content) - if err != nil { - glog.Warningf("received %v while writing decrypted audit log", err) - glog.Warningf("wrote %v bytes, expected to write %v", n, len(content)) - break - } - } - } - - if useDeprecated { - decryptBodyDeprecated() - } else { - decryptBody() - } - - return nil - -} diff --git a/ee/audit/run_ee_test.go b/ee/audit/run_test.go similarity index 98% rename from ee/audit/run_ee_test.go rename to ee/audit/run_test.go index f8f1f292cfb..5ee26cd1877 100644 --- a/ee/audit/run_ee_test.go +++ b/ee/audit/run_test.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit diff --git a/ee/backup/run.go b/ee/backup/run.go index 2773cce4f14..5f08c02553b 100644 --- a/ee/backup/run.go +++ b/ee/backup/run.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package backup diff --git a/ee/enc/util.go b/ee/enc/util.go index 43f6d933aed..93233aee480 100644 --- a/ee/enc/util.go +++ b/ee/enc/util.go @@ -1,25 +1,65 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package enc import ( + "crypto/aes" + "crypto/cipher" "io" + + "github.com/pkg/errors" + + "github.com/dgraph-io/badger/v4/y" + "github.com/hypermodeinc/dgraph/v24/x" ) -// Eebuild indicates if this is a Enterprise build. -var EeBuild = false +// EeBuild indicates if this is a Enterprise build. +var EeBuild = true -// GetWriter returns the Writer as is for OSS Builds. -func GetWriter(_ []byte, w io.Writer) (io.Writer, error) { - return w, nil +// GetWriter wraps a crypto StreamWriter using the input key on the input Writer. +func GetWriter(key x.Sensitive, w io.Writer) (io.Writer, error) { + // No encryption, return the input writer as is. + if key == nil { + return w, nil + } + // Encryption, wrap crypto StreamWriter on the input Writer. + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + iv, err := y.GenerateIV() + if err != nil { + return nil, err + } + if iv != nil { + if _, err = w.Write(iv); err != nil { + return nil, err + } + } + return cipher.StreamWriter{S: cipher.NewCTR(c, iv), W: w}, nil } -// GetReader returns the reader as is for OSS Builds. -func GetReader(_ []byte, r io.Reader) (io.Reader, error) { - return r, nil +// GetReader wraps a crypto StreamReader using the input key on the input Reader. +func GetReader(key x.Sensitive, r io.Reader) (io.Reader, error) { + // No encryption, return input reader as is. + if key == nil { + return r, nil + } + + // Encryption, wrap crypto StreamReader on input Reader. + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + var iv []byte = make([]byte, 16) + cnt, err := r.Read(iv) + if cnt != 16 || err != nil { + err = errors.Errorf("unable to get IV from encrypted backup. 
Read %v bytes, err %v ", + cnt, err) + return nil, err + } + return cipher.StreamReader{S: cipher.NewCTR(c, iv), R: r}, nil } diff --git a/ee/enc/util_ee.go b/ee/enc/util_ee.go deleted file mode 100644 index 29b95512c48..00000000000 --- a/ee/enc/util_ee.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package enc - -import ( - "crypto/aes" - "crypto/cipher" - "io" - - "github.com/pkg/errors" - - "github.com/dgraph-io/badger/v4/y" - "github.com/hypermodeinc/dgraph/v24/x" -) - -// EeBuild indicates if this is a Enterprise build. -var EeBuild = true - -// GetWriter wraps a crypto StreamWriter using the input key on the input Writer. -func GetWriter(key x.Sensitive, w io.Writer) (io.Writer, error) { - // No encryption, return the input writer as is. - if key == nil { - return w, nil - } - // Encryption, wrap crypto StreamWriter on the input Writer. - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - iv, err := y.GenerateIV() - if err != nil { - return nil, err - } - if iv != nil { - if _, err = w.Write(iv); err != nil { - return nil, err - } - } - return cipher.StreamWriter{S: cipher.NewCTR(c, iv), W: w}, nil -} - -// GetReader wraps a crypto StreamReader using the input key on the input Reader. -func GetReader(key x.Sensitive, r io.Reader) (io.Reader, error) { - // No encryption, return input reader as is. - if key == nil { - return r, nil - } - - // Encryption, wrap crypto StreamReader on input Reader. - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - var iv []byte = make([]byte, 16) - cnt, err := r.Read(iv) - if cnt != 16 || err != nil { - err = errors.Errorf("unable to get IV from encrypted backup. 
Read %v bytes, err %v ", - cnt, err) - return nil, err - } - return cipher.StreamReader{S: cipher.NewCTR(c, iv), R: r}, nil -} diff --git a/ee/flags.go b/ee/flags.go index 4d55815d665..374afd0fc3f 100644 --- a/ee/flags.go +++ b/ee/flags.go @@ -1,5 +1,6 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package ee diff --git a/ee/keys.go b/ee/keys.go index 66555b13e11..a27984e6732 100644 --- a/ee/keys.go +++ b/ee/keys.go @@ -1,22 +1,134 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package ee import ( + "crypto" + "crypto/ed25519" "fmt" + "os" + "strconv" + "strings" + "github.com/golang-jwt/jwt/v5" + "github.com/pkg/errors" "github.com/spf13/viper" + + "github.com/dgraph-io/ristretto/v2/z" + "github.com/hypermodeinc/dgraph/v24/x" ) // GetKeys returns the ACL and encryption keys as configured by the user // through the --acl, --encryption, and --vault flags. On OSS builds, // this function always returns an error. 
func GetKeys(config *viper.Viper) (*Keys, error) { - return nil, fmt.Errorf( - "flags: acl / encryption is an enterprise-only feature") + aclSuperFlag := z.NewSuperFlag(config.GetString("acl")).MergeAndCheckDefault(AclDefaults) + encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(EncDefaults) + + // Get SecretKey and EncKey from vault / acl / encryption SuperFlags + aclKey, encKey := vaultGetKeys(config) + + encKeyFile := encSuperFlag.GetPath(flagEncKeyFile) + if encKeyFile != "" { + if encKey != nil { + return nil, fmt.Errorf("flags: Encryption key set in both vault and encryption flags") + } + var err error + if encKey, err = os.ReadFile(encKeyFile); err != nil { + return nil, fmt.Errorf("error reading encryption key from file: %s: %s", encKeyFile, err) + } + } + if l := len(encKey); encKey != nil && l != 16 && l != 32 && l != 64 { + return nil, fmt.Errorf("encryption key must have length of 16, 32, or 64 bytes, got %d bytes instead", l) + } + + aclSecretFile := aclSuperFlag.GetPath(flagAclKeyFile) + if aclSecretFile != "" { + if aclKey != nil { + return nil, fmt.Errorf("flags: ACL secret key set in both vault and acl flags") + } + var err error + if aclKey, err = os.ReadFile(aclSecretFile); err != nil { + return nil, fmt.Errorf("error reading ACL secret key from file: %s: %s", aclSecretFile, err) + } + } + + keys := &Keys{ + AclSecretKeyBytes: aclKey, + AclAccessTtl: aclSuperFlag.GetDuration(flagAclAccessTtl), + AclRefreshTtl: aclSuperFlag.GetDuration(flagAclRefreshTtl), + EncKey: encKey, + } + + if aclKey != nil { + algStr := aclSuperFlag.GetString(flagAclJwtAlg) + aclAlg := jwt.GetSigningMethod(algStr) + if aclAlg == nil { + return nil, fmt.Errorf("Unsupported JWT signing algorithm for ACL: %v", algStr) + } + if err := checkAclKeyLength(aclAlg, aclKey); err != nil { + return nil, err + } + privKey, pubKey, err := parseJWTKey(aclAlg, aclKey) + if err != nil { + return nil, err + } + + keys.AclJwtAlg = aclAlg + keys.AclSecretKey 
= privKey + keys.AclPublicKey = pubKey + } + + return keys, nil +} + +func parseJWTKey(alg jwt.SigningMethod, key x.Sensitive) (interface{}, interface{}, error) { + switch { + case strings.HasPrefix(alg.Alg(), "HS"): + return key, key, nil + + case strings.HasPrefix(alg.Alg(), "ES"): + pk, err := jwt.ParseECPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as ECDSA private key") + } + return pk, &pk.PublicKey, nil + + case strings.HasPrefix(alg.Alg(), "RS") || strings.HasPrefix(alg.Alg(), "PS"): + pk, err := jwt.ParseRSAPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as RSA private key") + } + return pk, &pk.PublicKey, nil + + case alg.Alg() == "EdDSA": + pk, err := jwt.ParseEdPrivateKeyFromPEM(key) + if err != nil { + return nil, nil, errors.Wrapf(err, "error parsing ACL key as EdDSA private key") + } + return pk.(crypto.Signer), pk.(ed25519.PrivateKey).Public(), nil + + default: + return nil, nil, errors.Errorf("unsupported signing algorithm: %v", alg.Alg()) + } +} + +func checkAclKeyLength(alg jwt.SigningMethod, key x.Sensitive) error { + if !strings.HasPrefix(alg.Alg(), "HS") { + return nil + } + + sl, err := strconv.Atoi(strings.TrimPrefix(alg.Alg(), "HS")) + if err != nil { + return errors.Wrapf(err, "error finding sha length for algo %v", alg.Alg()) + } + + // SHA length has to be smaller or equal to the key length + if sl > len(key)*8 { + return errors.Errorf("ACL key length [%v <= %v] bits for JWT algorithm [%v]", sl, len(key)*8, alg.Alg()) + } + return nil } diff --git a/ee/keys_ee.go b/ee/keys_ee.go deleted file mode 100644 index cb23698f13b..00000000000 --- a/ee/keys_ee.go +++ /dev/null @@ -1,136 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package ee - -import ( - "crypto" - "crypto/ed25519" - "fmt" - "os" - "strconv" - "strings" - - "github.com/golang-jwt/jwt/v5" - "github.com/pkg/errors" - "github.com/spf13/viper" - - "github.com/dgraph-io/ristretto/v2/z" - "github.com/hypermodeinc/dgraph/v24/x" -) - -// GetKeys returns the ACL and encryption keys as configured by the user -// through the --acl, --encryption, and --vault flags. On OSS builds, -// this function always returns an error. -func GetKeys(config *viper.Viper) (*Keys, error) { - aclSuperFlag := z.NewSuperFlag(config.GetString("acl")).MergeAndCheckDefault(AclDefaults) - encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(EncDefaults) - - // Get SecretKey and EncKey from vault / acl / encryption SuperFlags - aclKey, encKey := vaultGetKeys(config) - - encKeyFile := encSuperFlag.GetPath(flagEncKeyFile) - if encKeyFile != "" { - if encKey != nil { - return nil, fmt.Errorf("flags: Encryption key set in both vault and encryption flags") - } - var err error - if encKey, err = os.ReadFile(encKeyFile); err != nil { - return nil, fmt.Errorf("error reading encryption key from file: %s: %s", encKeyFile, err) - } - } - if l := len(encKey); encKey != nil && l != 16 && l != 32 && l != 64 { - return nil, fmt.Errorf("encryption key must have length of 16, 32, or 64 bytes, got %d bytes instead", l) - } - - aclSecretFile := aclSuperFlag.GetPath(flagAclKeyFile) - if aclSecretFile != "" { - if aclKey != nil { - return nil, fmt.Errorf("flags: ACL secret key set in both vault and acl flags") - } - var err error - if aclKey, err = os.ReadFile(aclSecretFile); err != nil { - return nil, fmt.Errorf("error reading ACL secret key from file: %s: %s", aclSecretFile, err) - } - } - - keys := &Keys{ - AclSecretKeyBytes: aclKey, - AclAccessTtl: aclSuperFlag.GetDuration(flagAclAccessTtl), - AclRefreshTtl: aclSuperFlag.GetDuration(flagAclRefreshTtl), - EncKey: encKey, - } - - if aclKey != nil { - algStr := 
aclSuperFlag.GetString(flagAclJwtAlg) - aclAlg := jwt.GetSigningMethod(algStr) - if aclAlg == nil { - return nil, fmt.Errorf("Unsupported JWT signing algorithm for ACL: %v", algStr) - } - if err := checkAclKeyLength(aclAlg, aclKey); err != nil { - return nil, err - } - privKey, pubKey, err := parseJWTKey(aclAlg, aclKey) - if err != nil { - return nil, err - } - - keys.AclJwtAlg = aclAlg - keys.AclSecretKey = privKey - keys.AclPublicKey = pubKey - } - - return keys, nil -} - -func parseJWTKey(alg jwt.SigningMethod, key x.Sensitive) (interface{}, interface{}, error) { - switch { - case strings.HasPrefix(alg.Alg(), "HS"): - return key, key, nil - - case strings.HasPrefix(alg.Alg(), "ES"): - pk, err := jwt.ParseECPrivateKeyFromPEM(key) - if err != nil { - return nil, nil, errors.Wrapf(err, "error parsing ACL key as ECDSA private key") - } - return pk, &pk.PublicKey, nil - - case strings.HasPrefix(alg.Alg(), "RS") || strings.HasPrefix(alg.Alg(), "PS"): - pk, err := jwt.ParseRSAPrivateKeyFromPEM(key) - if err != nil { - return nil, nil, errors.Wrapf(err, "error parsing ACL key as RSA private key") - } - return pk, &pk.PublicKey, nil - - case alg.Alg() == "EdDSA": - pk, err := jwt.ParseEdPrivateKeyFromPEM(key) - if err != nil { - return nil, nil, errors.Wrapf(err, "error parsing ACL key as EdDSA private key") - } - return pk.(crypto.Signer), pk.(ed25519.PrivateKey).Public(), nil - - default: - return nil, nil, errors.Errorf("unsupported signing algorithm: %v", alg.Alg()) - } -} - -func checkAclKeyLength(alg jwt.SigningMethod, key x.Sensitive) error { - if !strings.HasPrefix(alg.Alg(), "HS") { - return nil - } - - sl, err := strconv.Atoi(strings.TrimPrefix(alg.Alg(), "HS")) - if err != nil { - return errors.Wrapf(err, "error finding sha length for algo %v", alg.Alg()) - } - - // SHA length has to be smaller or equal to the key length - if sl > len(key)*8 { - return errors.Errorf("ACL key length [%v <= %v] bits for JWT algorithm [%v]", sl, len(key)*8, alg.Alg()) - } - 
return nil -} diff --git a/ee/vault_ee.go b/ee/vault.go similarity index 99% rename from ee/vault_ee.go rename to ee/vault.go index 6dac24878db..85be2b8409a 100644 --- a/ee/vault_ee.go +++ b/ee/vault.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package ee diff --git a/ee/vault/vault.go b/ee/vault/vault.go index d52403aa39f..d980e9439d2 100644 --- a/ee/vault/vault.go +++ b/ee/vault/vault.go @@ -3,6 +3,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package vault diff --git a/graphql/admin/endpoints.go b/graphql/admin/endpoints.go index 0dedc5ad92d..7cd3dd0f48e 100644 --- a/graphql/admin/endpoints.go +++ b/graphql/admin/endpoints.go @@ -1,6 +1,3 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. * SPDX-License-Identifier: Apache-2.0 @@ -8,8 +5,546 @@ package admin -const adminTypes = `` +const adminTypes = ` + input BackupInput { + + """ + Destination for the backup: e.g. Minio or S3 bucket. + """ + destination: String! + + """ + Access key credential for the destination. + """ + accessKey: String + + """ + Secret key credential for the destination. + """ + secretKey: String + + """ + AWS session token, if required. + """ + sessionToken: String + + """ + Set to true to allow backing up to S3 or Minio bucket that requires no credentials. + """ + anonymous: Boolean + + """ + Force a full backup instead of an incremental backup. + """ + forceFull: Boolean + } + + type BackupPayload { + response: Response + taskId: String + } + + input RestoreTenantInput { + """ + restoreInput contains fields that are required for the restore operation, + i.e., location, backupId, and backupNum + """ + restoreInput: RestoreInput + + """ + fromNamespace is the namespace of the tenant that needs to be restored into namespace 0 of the new cluster. + """ + fromNamespace: Int! 
+ } + + input RestoreInput { + + """ + Destination for the backup: e.g. Minio or S3 bucket. + """ + location: String! + + """ + Backup ID of the backup series to restore. This ID is included in the manifest.json file. + If missing, it defaults to the latest series. + """ + backupId: String + + """ + If backupNum is 0 or missing, the entire series will be restored i.e all the incremental + backups available in the series as well as the full backup will be restored. If backupNum is + non-zero, we restore all the backups in the series that have backupNum smaller or equal to + the backupNum provided here. Backups that have backupNum higher than this will be ignored. + In simple words, all the backups with backupNum of backup <= backupNum will be restored. + """ + backupNum: Int + + """ + All the backups with backupNum >= incrementalFrom will be restored. + """ + incrementalFrom: Int + + """ + If isPartial is set to true then the cluster will be kept in draining mode after + restore. This makes sure that the db is not corrupted by any mutations or tablet + moves in between two restores. + """ + isPartial: Boolean + + """ + Path to the key file needed to decrypt the backup. This file should be accessible + by all alphas in the group. The backup will be written using the encryption key + with which the cluster was started, which might be different than this key. + """ + encryptionKeyFile: String + + """ + Vault server address where the key is stored. This server must be accessible + by all alphas in the group. Default "http://localhost:8200". + """ + vaultAddr: String + + """ + Path to the Vault RoleID file. + """ + vaultRoleIDFile: String + + """ + Path to the Vault SecretID file. + """ + vaultSecretIDFile: String + + """ + Vault kv store path where the key lives. Default "secret/data/dgraph". + """ + vaultPath: String + + """ + Vault kv store field whose value is the key. Default "enc_key". + """ + vaultField: String + + """ + Vault kv store field's format. 
Must be "base64" or "raw". Default "base64". + """ + vaultFormat: String + + """ + Access key credential for the destination. + """ + accessKey: String + + """ + Secret key credential for the destination. + """ + secretKey: String + + """ + AWS session token, if required. + """ + sessionToken: String + + """ + Set to true to allow backing up to S3 or Minio bucket that requires no credentials. + """ + anonymous: Boolean + } + + type RestorePayload { + """ + A short string indicating whether the restore operation was successfully scheduled. + """ + code: String + + """ + Includes the error message if the operation failed. + """ + message: String + } + + input ListBackupsInput { + """ + Destination for the backup: e.g. Minio or S3 bucket. + """ + location: String! + + """ + Access key credential for the destination. + """ + accessKey: String + + """ + Secret key credential for the destination. + """ + secretKey: String + + """ + AWS session token, if required. + """ + sessionToken: String + + """ + Whether the destination doesn't require credentials (e.g. S3 public bucket). + """ + anonymous: Boolean + + } + + type BackupGroup { + """ + The ID of the cluster group. + """ + groupId: UInt64 + + """ + List of predicates assigned to the group. + """ + predicates: [String] + } + + type Manifest { + """ + Unique ID for the backup series. + """ + backupId: String + + """ + Number of this backup within the backup series. The full backup always has a value of one. + """ + backupNum: UInt64 + + """ + Whether this backup was encrypted. + """ + encrypted: Boolean + + """ + List of groups and the predicates they store in this backup. + """ + groups: [BackupGroup] + + """ + Path to the manifest file. + """ + path: String + + """ + The timestamp at which this backup was taken. The next incremental backup will + start from this timestamp. + """ + since: UInt64 + + """ + The type of backup, either full or incremental. 
+ """ + type: String + } + + type LoginResponse { + + """ + JWT token that should be used in future requests after this login. + """ + accessJWT: String + + """ + Refresh token that can be used to re-login after accessJWT expires. + """ + refreshJWT: String + } + + type LoginPayload { + response: LoginResponse + } + + type User @dgraph(type: "dgraph.type.User") @secret(field: "password", pred: "dgraph.password") { + + """ + Username for the user. Dgraph ensures that usernames are unique. + """ + name: String! @id @dgraph(pred: "dgraph.xid") + + groups: [Group] @dgraph(pred: "dgraph.user.group") + } + + type Group @dgraph(type: "dgraph.type.Group") { + + """ + Name of the group. Dgraph ensures uniqueness of group names. + """ + name: String! @id @dgraph(pred: "dgraph.xid") + users: [User] @dgraph(pred: "~dgraph.user.group") + rules: [Rule] @dgraph(pred: "dgraph.acl.rule") + } + + type Rule @dgraph(type: "dgraph.type.Rule") { + + """ + Predicate to which the rule applies. + """ + predicate: String! @dgraph(pred: "dgraph.rule.predicate") + + """ + Permissions that apply for the rule. Represented following the UNIX file permission + convention. That is, 4 (binary 100) represents READ, 2 (binary 010) represents WRITE, + and 1 (binary 001) represents MODIFY (the permission to change a predicate’s schema). + + The options are: + * 1 (binary 001) : MODIFY + * 2 (010) : WRITE + * 3 (011) : WRITE+MODIFY + * 4 (100) : READ + * 5 (101) : READ+MODIFY + * 6 (110) : READ+WRITE + * 7 (111) : READ+WRITE+MODIFY + + Permission 0, which is equal to no permission for a predicate, blocks all read, + write and modify operations. + """ + permission: Int! @dgraph(pred: "dgraph.rule.permission") + } + + input StringHashFilter { + eq: String + } + + enum UserOrderable { + name + } + + enum GroupOrderable { + name + } + + input AddUserInput { + name: String! + password: String! + groups: [GroupRef] + } + + input AddGroupInput { + name: String! 
+ rules: [RuleRef] + } + + input UserRef { + name: String! + } + + input GroupRef { + name: String! + } + + input RuleRef { + """ + Predicate to which the rule applies. + """ + predicate: String! + + """ + Permissions that apply for the rule. Represented following the UNIX file permission + convention. That is, 4 (binary 100) represents READ, 2 (binary 010) represents WRITE, + and 1 (binary 001) represents MODIFY (the permission to change a predicate’s schema). + + The options are: + * 1 (binary 001) : MODIFY + * 2 (010) : WRITE + * 3 (011) : WRITE+MODIFY + * 4 (100) : READ + * 5 (101) : READ+MODIFY + * 6 (110) : READ+WRITE + * 7 (111) : READ+WRITE+MODIFY + + Permission 0, which is equal to no permission for a predicate, blocks all read, + write and modify operations. + """ + permission: Int! + } + + input UserFilter { + name: StringHashFilter + and: UserFilter + or: UserFilter + not: UserFilter + } + + input UserOrder { + asc: UserOrderable + desc: UserOrderable + then: UserOrder + } + + input GroupOrder { + asc: GroupOrderable + desc: GroupOrderable + then: GroupOrder + } + + input UserPatch { + password: String + groups: [GroupRef] + } + + input UpdateUserInput { + filter: UserFilter! + set: UserPatch + remove: UserPatch + } + + input GroupFilter { + name: StringHashFilter + and: UserFilter + or: UserFilter + not: UserFilter + } + + input SetGroupPatch { + rules: [RuleRef!]! + } + + input RemoveGroupPatch { + rules: [String!]! + } + + input UpdateGroupInput { + filter: GroupFilter! + set: SetGroupPatch + remove: RemoveGroupPatch + } + + type AddUserPayload { + user: [User] + } + + type AddGroupPayload { + group: [Group] + } + + type DeleteUserPayload { + msg: String + numUids: Int + } + + type DeleteGroupPayload { + msg: String + numUids: Int + } + + input AddNamespaceInput { + """ + Enter a new password for groot in that namespace. If you leave it blank, the password will be the default. 
+ """ + password: String + } + + input DeleteNamespaceInput { + namespaceId: Int! + } + + type NamespacePayload { + namespaceId: UInt64 + message: String + } + + input ResetPasswordInput { + userId: String! + password: String! + namespace: Int! + } + + type ResetPasswordPayload { + userId: String + message: String + namespace: UInt64 + } + ` + +const adminMutations = ` + + """ + Start a binary backup. See : https://dgraph.io/docs/enterprise-features/#binary-backups + """ + backup(input: BackupInput!) : BackupPayload + + """ + Start restoring a binary backup. See : + https://dgraph.io/docs/enterprise-features/#binary-backups + """ + restore(input: RestoreInput!) : RestorePayload + + """ + Restore given tenant into namespace 0 of the cluster + """ + restoreTenant(input: RestoreTenantInput!) : RestorePayload + + """ + Login to Dgraph. Successful login results in a JWT that can be used in future requests. + If login is not successful an error is returned. + """ + login(userId: String, password: String, namespace: Int, refreshToken: String): LoginPayload + + """ + Add a user. When linking to groups: if the group doesn't exist it is created; if the group + exists, the new user is linked to the existing group. It's possible to both create new + groups and link to existing groups in the one mutation. + + Dgraph ensures that usernames are unique, hence attempting to add an existing user results + in an error. + """ + addUser(input: [AddUserInput!]!): AddUserPayload + + """ + Add a new group and (optionally) set the rules for the group. + """ + addGroup(input: [AddGroupInput!]!): AddGroupPayload + + """ + Update users, their passwords and groups. As with AddUser, when linking to groups: if the + group doesn't exist it is created; if the group exists, the new user is linked to the existing + group. If the filter doesn't match any users, the mutation has no effect. + """ + updateUser(input: UpdateUserInput!): AddUserPayload + + """ + Add or remove rules for groups. 
If the filter doesn't match any groups, + the mutation has no effect. + """ + updateGroup(input: UpdateGroupInput!): AddGroupPayload + + deleteGroup(filter: GroupFilter!): DeleteGroupPayload + deleteUser(filter: UserFilter!): DeleteUserPayload + + """ + Add a new namespace. + """ + addNamespace(input: AddNamespaceInput): NamespacePayload + + """ + Delete a namespace. + """ + deleteNamespace(input: DeleteNamespaceInput!): NamespacePayload + + """ + Reset password can only be used by the Guardians of the galaxy to reset password of + any user in any namespace. + """ + resetPassword(input: ResetPasswordInput!): ResetPasswordPayload + ` + +const adminQueries = ` + getUser(name: String!): User + getGroup(name: String!): Group + + """ + Get the currently logged in user. + """ + getCurrentUser: User -const adminMutations = `` + queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] + queryGroup(filter: GroupFilter, order: GroupOrder, first: Int, offset: Int): [Group] -const adminQueries = `` + """ + Get the information about the backups at a given location. + """ + listBackups(input: ListBackupsInput!) : [Manifest] + ` diff --git a/graphql/admin/endpoints_ee.go b/graphql/admin/endpoints_ee.go deleted file mode 100644 index 3eb02ac9d62..00000000000 --- a/graphql/admin/endpoints_ee.go +++ /dev/null @@ -1,552 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - */ - -package admin - -const adminTypes = ` - input BackupInput { - - """ - Destination for the backup: e.g. Minio or S3 bucket. - """ - destination: String! - - """ - Access key credential for the destination. - """ - accessKey: String - - """ - Secret key credential for the destination. - """ - secretKey: String - - """ - AWS session token, if required. - """ - sessionToken: String - - """ - Set to true to allow backing up to S3 or Minio bucket that requires no credentials. 
- """ - anonymous: Boolean - - """ - Force a full backup instead of an incremental backup. - """ - forceFull: Boolean - } - - type BackupPayload { - response: Response - taskId: String - } - - input RestoreTenantInput { - """ - restoreInput contains fields that are required for the restore operation, - i.e., location, backupId, and backupNum - """ - restoreInput: RestoreInput - - """ - fromNamespace is the namespace of the tenant that needs to be restored into namespace 0 of the new cluster. - """ - fromNamespace: Int! - } - - input RestoreInput { - - """ - Destination for the backup: e.g. Minio or S3 bucket. - """ - location: String! - - """ - Backup ID of the backup series to restore. This ID is included in the manifest.json file. - If missing, it defaults to the latest series. - """ - backupId: String - - """ - If backupNum is 0 or missing, the entire series will be restored i.e all the incremental - backups available in the series as well as the full backup will be restored. If backupNum is - non-zero, we restore all the backups in the series that have backupNum smaller or equal to - the backupNum provided here. Backups that have backupNum higher than this will be ignored. - In simple words, all the backups with backupNum of backup <= backupNum will be restored. - """ - backupNum: Int - - """ - All the backups with backupNum >= incrementalFrom will be restored. - """ - incrementalFrom: Int - - """ - If isPartial is set to true then the cluster will be kept in draining mode after - restore. This makes sure that the db is not corrupted by any mutations or tablet - moves in between two restores. - """ - isPartial: Boolean - - """ - Path to the key file needed to decrypt the backup. This file should be accessible - by all alphas in the group. The backup will be written using the encryption key - with which the cluster was started, which might be different than this key. - """ - encryptionKeyFile: String - - """ - Vault server address where the key is stored. 
This server must be accessible - by all alphas in the group. Default "http://localhost:8200". - """ - vaultAddr: String - - """ - Path to the Vault RoleID file. - """ - vaultRoleIDFile: String - - """ - Path to the Vault SecretID file. - """ - vaultSecretIDFile: String - - """ - Vault kv store path where the key lives. Default "secret/data/dgraph". - """ - vaultPath: String - - """ - Vault kv store field whose value is the key. Default "enc_key". - """ - vaultField: String - - """ - Vault kv store field's format. Must be "base64" or "raw". Default "base64". - """ - vaultFormat: String - - """ - Access key credential for the destination. - """ - accessKey: String - - """ - Secret key credential for the destination. - """ - secretKey: String - - """ - AWS session token, if required. - """ - sessionToken: String - - """ - Set to true to allow backing up to S3 or Minio bucket that requires no credentials. - """ - anonymous: Boolean - } - - type RestorePayload { - """ - A short string indicating whether the restore operation was successfully scheduled. - """ - code: String - - """ - Includes the error message if the operation failed. - """ - message: String - } - - input ListBackupsInput { - """ - Destination for the backup: e.g. Minio or S3 bucket. - """ - location: String! - - """ - Access key credential for the destination. - """ - accessKey: String - - """ - Secret key credential for the destination. - """ - secretKey: String - - """ - AWS session token, if required. - """ - sessionToken: String - - """ - Whether the destination doesn't require credentials (e.g. S3 public bucket). - """ - anonymous: Boolean - - } - - type BackupGroup { - """ - The ID of the cluster group. - """ - groupId: UInt64 - - """ - List of predicates assigned to the group. - """ - predicates: [String] - } - - type Manifest { - """ - Unique ID for the backup series. - """ - backupId: String - - """ - Number of this backup within the backup series. The full backup always has a value of one. 
- """ - backupNum: UInt64 - - """ - Whether this backup was encrypted. - """ - encrypted: Boolean - - """ - List of groups and the predicates they store in this backup. - """ - groups: [BackupGroup] - - """ - Path to the manifest file. - """ - path: String - - """ - The timestamp at which this backup was taken. The next incremental backup will - start from this timestamp. - """ - since: UInt64 - - """ - The type of backup, either full or incremental. - """ - type: String - } - - type LoginResponse { - - """ - JWT token that should be used in future requests after this login. - """ - accessJWT: String - - """ - Refresh token that can be used to re-login after accessJWT expires. - """ - refreshJWT: String - } - - type LoginPayload { - response: LoginResponse - } - - type User @dgraph(type: "dgraph.type.User") @secret(field: "password", pred: "dgraph.password") { - - """ - Username for the user. Dgraph ensures that usernames are unique. - """ - name: String! @id @dgraph(pred: "dgraph.xid") - - groups: [Group] @dgraph(pred: "dgraph.user.group") - } - - type Group @dgraph(type: "dgraph.type.Group") { - - """ - Name of the group. Dgraph ensures uniqueness of group names. - """ - name: String! @id @dgraph(pred: "dgraph.xid") - users: [User] @dgraph(pred: "~dgraph.user.group") - rules: [Rule] @dgraph(pred: "dgraph.acl.rule") - } - - type Rule @dgraph(type: "dgraph.type.Rule") { - - """ - Predicate to which the rule applies. - """ - predicate: String! @dgraph(pred: "dgraph.rule.predicate") - - """ - Permissions that apply for the rule. Represented following the UNIX file permission - convention. That is, 4 (binary 100) represents READ, 2 (binary 010) represents WRITE, - and 1 (binary 001) represents MODIFY (the permission to change a predicate’s schema). 
- - The options are: - * 1 (binary 001) : MODIFY - * 2 (010) : WRITE - * 3 (011) : WRITE+MODIFY - * 4 (100) : READ - * 5 (101) : READ+MODIFY - * 6 (110) : READ+WRITE - * 7 (111) : READ+WRITE+MODIFY - - Permission 0, which is equal to no permission for a predicate, blocks all read, - write and modify operations. - """ - permission: Int! @dgraph(pred: "dgraph.rule.permission") - } - - input StringHashFilter { - eq: String - } - - enum UserOrderable { - name - } - - enum GroupOrderable { - name - } - - input AddUserInput { - name: String! - password: String! - groups: [GroupRef] - } - - input AddGroupInput { - name: String! - rules: [RuleRef] - } - - input UserRef { - name: String! - } - - input GroupRef { - name: String! - } - - input RuleRef { - """ - Predicate to which the rule applies. - """ - predicate: String! - - """ - Permissions that apply for the rule. Represented following the UNIX file permission - convention. That is, 4 (binary 100) represents READ, 2 (binary 010) represents WRITE, - and 1 (binary 001) represents MODIFY (the permission to change a predicate’s schema). - - The options are: - * 1 (binary 001) : MODIFY - * 2 (010) : WRITE - * 3 (011) : WRITE+MODIFY - * 4 (100) : READ - * 5 (101) : READ+MODIFY - * 6 (110) : READ+WRITE - * 7 (111) : READ+WRITE+MODIFY - - Permission 0, which is equal to no permission for a predicate, blocks all read, - write and modify operations. - """ - permission: Int! - } - - input UserFilter { - name: StringHashFilter - and: UserFilter - or: UserFilter - not: UserFilter - } - - input UserOrder { - asc: UserOrderable - desc: UserOrderable - then: UserOrder - } - - input GroupOrder { - asc: GroupOrderable - desc: GroupOrderable - then: GroupOrder - } - - input UserPatch { - password: String - groups: [GroupRef] - } - - input UpdateUserInput { - filter: UserFilter! 
- set: UserPatch - remove: UserPatch - } - - input GroupFilter { - name: StringHashFilter - and: UserFilter - or: UserFilter - not: UserFilter - } - - input SetGroupPatch { - rules: [RuleRef!]! - } - - input RemoveGroupPatch { - rules: [String!]! - } - - input UpdateGroupInput { - filter: GroupFilter! - set: SetGroupPatch - remove: RemoveGroupPatch - } - - type AddUserPayload { - user: [User] - } - - type AddGroupPayload { - group: [Group] - } - - type DeleteUserPayload { - msg: String - numUids: Int - } - - type DeleteGroupPayload { - msg: String - numUids: Int - } - - input AddNamespaceInput { - """ - Enter a new password for groot in that namespace. If you leave it blank, the password will be the default. - """ - password: String - } - - input DeleteNamespaceInput { - namespaceId: Int! - } - - type NamespacePayload { - namespaceId: UInt64 - message: String - } - - input ResetPasswordInput { - userId: String! - password: String! - namespace: Int! - } - - type ResetPasswordPayload { - userId: String - message: String - namespace: UInt64 - } - ` - -const adminMutations = ` - - """ - Start a binary backup. See : https://dgraph.io/docs/enterprise-features/#binary-backups - """ - backup(input: BackupInput!) : BackupPayload - - """ - Start restoring a binary backup. See : - https://dgraph.io/docs/enterprise-features/#binary-backups - """ - restore(input: RestoreInput!) : RestorePayload - - """ - Restore given tenant into namespace 0 of the cluster - """ - restoreTenant(input: RestoreTenantInput!) : RestorePayload - - """ - Login to Dgraph. Successful login results in a JWT that can be used in future requests. - If login is not successful an error is returned. - """ - login(userId: String, password: String, namespace: Int, refreshToken: String): LoginPayload - - """ - Add a user. When linking to groups: if the group doesn't exist it is created; if the group - exists, the new user is linked to the existing group. 
It's possible to both create new - groups and link to existing groups in the one mutation. - - Dgraph ensures that usernames are unique, hence attempting to add an existing user results - in an error. - """ - addUser(input: [AddUserInput!]!): AddUserPayload - - """ - Add a new group and (optionally) set the rules for the group. - """ - addGroup(input: [AddGroupInput!]!): AddGroupPayload - - """ - Update users, their passwords and groups. As with AddUser, when linking to groups: if the - group doesn't exist it is created; if the group exists, the new user is linked to the existing - group. If the filter doesn't match any users, the mutation has no effect. - """ - updateUser(input: UpdateUserInput!): AddUserPayload - - """ - Add or remove rules for groups. If the filter doesn't match any groups, - the mutation has no effect. - """ - updateGroup(input: UpdateGroupInput!): AddGroupPayload - - deleteGroup(filter: GroupFilter!): DeleteGroupPayload - deleteUser(filter: UserFilter!): DeleteUserPayload - - """ - Add a new namespace. - """ - addNamespace(input: AddNamespaceInput): NamespacePayload - - """ - Delete a namespace. - """ - deleteNamespace(input: DeleteNamespaceInput!): NamespacePayload - - """ - Reset password can only be used by the Guardians of the galaxy to reset password of - any user in any namespace. - """ - resetPassword(input: ResetPasswordInput!): ResetPasswordPayload - ` - -const adminQueries = ` - getUser(name: String!): User - getGroup(name: String!): Group - - """ - Get the currently logged in user. - """ - getCurrentUser: User - - queryUser(filter: UserFilter, order: UserOrder, first: Int, offset: Int): [User] - queryGroup(filter: GroupFilter, order: GroupOrder, first: Int, offset: Int): [Group] - - """ - Get the information about the backups at a given location. - """ - listBackups(input: ListBackupsInput!) 
: [Manifest] - ` diff --git a/graphql/e2e/multi_tenancy/multi_tenancy_test.go b/graphql/e2e/multi_tenancy/multi_tenancy_test.go index f409902ffff..89534505b6e 100644 --- a/graphql/e2e/multi_tenancy/multi_tenancy_test.go +++ b/graphql/e2e/multi_tenancy/multi_tenancy_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ //nolint:lll diff --git a/systest/acl/restore/acl_restore_test.go b/systest/acl/restore/acl_restore_test.go index b1fc35710b6..d6048258dc5 100644 --- a/systest/acl/restore/acl_restore_test.go +++ b/systest/acl/restore/acl_restore_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package main diff --git a/systest/audit/audit_test.go b/systest/audit/audit_test.go index 6213cc99d9a..30a202fe805 100644 --- a/systest/audit/audit_test.go +++ b/systest/audit/audit_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit diff --git a/systest/audit_encrypted/audit_test.go b/systest/audit_encrypted/audit_test.go index 7ee00a18066..3ca535c0db0 100644 --- a/systest/audit_encrypted/audit_test.go +++ b/systest/audit_encrypted/audit_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package audit_encrypted diff --git a/systest/backup/common/utils.go b/systest/backup/common/utils.go index 51d288b9c62..0bf2c8c4894 100644 --- a/systest/backup/common/utils.go +++ b/systest/backup/common/utils.go @@ -1,6 +1,3 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. * SPDX-License-Identifier: Apache-2.0 diff --git a/systest/cdc/cdc_test.go b/systest/cdc/cdc_test.go index 0157d6bd836..64fff6909b5 100644 --- a/systest/cdc/cdc_test.go +++ b/systest/cdc/cdc_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package cdc diff --git a/systest/cloud/cloud_test.go b/systest/cloud/cloud_test.go index 753fb840e5b..b8a3e125a82 100644 --- a/systest/cloud/cloud_test.go +++ b/systest/cloud/cloud_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package main diff --git a/tlstest/acl/acl_over_tls_test.go b/tlstest/acl/acl_over_tls_test.go index 394ab03d001..ca3ee08a402 100644 --- a/tlstest/acl/acl_over_tls_test.go +++ b/tlstest/acl/acl_over_tls_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package acl diff --git a/tlstest/certrequest/certrequest_test.go b/tlstest/certrequest/certrequest_test.go index 8bca101e1ef..268d54fc11a 100644 --- a/tlstest/certrequest/certrequest_test.go +++ b/tlstest/certrequest/certrequest_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package certrequest diff --git a/tlstest/certrequireandverify/certrequireandverify_test.go b/tlstest/certrequireandverify/certrequireandverify_test.go index 93e9cc5c90c..48ddff0dd5f 100644 --- a/tlstest/certrequireandverify/certrequireandverify_test.go +++ b/tlstest/certrequireandverify/certrequireandverify_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package certrequireandverify diff --git a/tlstest/certverifyifgiven/certverifyifgiven_test.go b/tlstest/certverifyifgiven/certverifyifgiven_test.go index d900d75c1b3..612ea1cfdba 100644 --- a/tlstest/certverifyifgiven/certverifyifgiven_test.go +++ b/tlstest/certverifyifgiven/certverifyifgiven_test.go @@ -2,6 +2,7 @@ /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package certverifyifgiven diff --git a/worker/acl_cache.go b/worker/acl_cache.go index 2955c26a4bf..c397d231c17 100644 --- a/worker/acl_cache.go +++ b/worker/acl_cache.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/acl_cache_test.go b/worker/acl_cache_test.go index b3a4fc23e87..83b7bfe9b16 100644 --- a/worker/acl_cache_test.go +++ b/worker/acl_cache_test.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/backup.go b/worker/backup.go index b55de102f37..ced668d41e5 100644 --- a/worker/backup.go +++ b/worker/backup.go @@ -7,13 +7,31 @@ package worker import ( "context" + "encoding/binary" + "encoding/hex" + "fmt" + "io" "math" + "net/url" + "reflect" + "strings" + "sync" + "time" "github.com/golang/glog" + "github.com/klauspost/compress/s2" "github.com/pkg/errors" + ostats "go.opencensus.io/stats" + "google.golang.org/protobuf/proto" "github.com/dgraph-io/badger/v4" + bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/badger/v4/y" + "github.com/dgraph-io/ristretto/v2/z" + "github.com/hypermodeinc/dgraph/v24/ee/enc" + "github.com/hypermodeinc/dgraph/v24/posting" "github.com/hypermodeinc/dgraph/v24/protos/pb" + "github.com/hypermodeinc/dgraph/v24/tok/hnsw" "github.com/hypermodeinc/dgraph/v24/x" ) @@ -112,3 +130,682 @@ func StoreExport(request *pb.ExportRequest, dir string, key x.Sensitive) error { _, err = exportInternal(context.Background(), request, db, true) return errors.Wrapf(err, "cannot export data inside DB at %s", dir) } + +// Backup handles a request coming from another node. 
+func (w *grpcWorker) Backup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { + glog.V(2).Infof("Received backup request via Grpc: %+v", req) + return backupCurrentGroup(ctx, req) +} + +func backupCurrentGroup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { + glog.Infof("Backup request: group %d at %d", req.GroupId, req.ReadTs) + if err := ctx.Err(); err != nil { + glog.Errorf("Context error during backup: %v\n", err) + return nil, err + } + + g := groups() + if g.groupId() != req.GroupId { + return nil, errors.Errorf("Backup request group mismatch. Mine: %d. Requested: %d\n", + g.groupId(), req.GroupId) + } + + if err := posting.Oracle().WaitForTs(ctx, req.ReadTs); err != nil { + return nil, err + } + + closer, err := g.Node.startTaskAtTs(opBackup, req.ReadTs) + if err != nil { + return nil, errors.Wrapf(err, "cannot start backup operation") + } + defer closer.Done() + + bp := NewBackupProcessor(pstore, req) + defer bp.Close() + + return bp.WriteBackup(closer.Ctx()) +} + +// BackupGroup backs up the group specified in the backup request. +func BackupGroup(ctx context.Context, in *pb.BackupRequest) (*pb.BackupResponse, error) { + glog.V(2).Infof("Sending backup request: %+v\n", in) + if groups().groupId() == in.GroupId { + return backupCurrentGroup(ctx, in) + } + + // This node is not part of the requested group, send the request over the network. + pl := groups().AnyServer(in.GroupId) + if pl == nil { + return nil, errors.Errorf("Couldn't find a server in group %d", in.GroupId) + } + res, err := pb.NewWorkerClient(pl.Get()).Backup(ctx, in) + if err != nil { + glog.Errorf("Backup error group %d: %s", in.GroupId, err) + return nil, err + } + + return res, nil +} + +// backupLock is used to synchronize backups to avoid more than one backup request +// to be processed at the same time. Multiple requests could lead to multiple +// backups with the same backupNum in their manifest. 
+var backupLock sync.Mutex + +// BackupRes is used to represent the response and error of the Backup gRPC call together to be +// transported via a channel. +type BackupRes struct { + res *pb.BackupResponse + err error +} + +func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { + if err := x.HealthCheck(); err != nil { + glog.Errorf("Backup canceled, not ready to accept requests: %s", err) + return err + } + + // Grab the lock here to avoid more than one request to be processed at the same time. + backupLock.Lock() + defer backupLock.Unlock() + + backupSuccessful := false + ostats.Record(ctx, x.NumBackups.M(1), x.PendingBackups.M(1)) + defer func() { + if backupSuccessful { + ostats.Record(ctx, x.NumBackupsSuccess.M(1), x.PendingBackups.M(-1)) + } else { + ostats.Record(ctx, x.NumBackupsFailed.M(1), x.PendingBackups.M(-1)) + } + }() + + ts, err := Timestamps(ctx, &pb.Num{ReadOnly: true}) + if err != nil { + glog.Errorf("Unable to retrieve readonly timestamp for backup: %s", err) + return err + } + + req.ReadTs = ts.ReadOnly + req.UnixTs = time.Now().UTC().Format("20060102.150405.000") + + // Read the manifests to get the right timestamp from which to start the backup. + uri, err := url.Parse(req.Destination) + if err != nil { + return err + } + handler, err := NewUriHandler(uri, GetCredentialsFromRequest(req)) + if err != nil { + return err + } + if !handler.DirExists("./") { + if err := handler.CreateDir("./"); err != nil { + return errors.Wrap(err, "while creating backup directory") + } + } + latestManifest, err := GetLatestManifest(handler, uri) + if err != nil { + return err + } + + req.SinceTs = latestManifest.ValidReadTs() + // To force a full backup we'll set the sinceTs to zero. + if req.ForceFull { + req.SinceTs = 0 + } else { + if x.WorkerConfig.EncryptionKey != nil { + // If encryption key given, latest backup should be encrypted. 
+ if latestManifest.Type != "" && !latestManifest.Encrypted { + err = errors.Errorf("latest manifest indicates the last backup was not encrypted " + + "but this instance has encryption turned on. Try \"forceFull\" flag.") + return err + } + } else { + // If encryption turned off, latest backup should be unencrypted. + if latestManifest.Type != "" && latestManifest.Encrypted { + err = errors.Errorf("latest manifest indicates the last backup was encrypted " + + "but this instance has encryption turned off. Try \"forceFull\" flag.") + return err + } + } + } + + // Update the membership state to get the latest mapping of groups to predicates. + if err := UpdateMembershipState(ctx); err != nil { + return err + } + + // Get the current membership state and parse it for easier processing. + state := GetMembershipState() + var groups []uint32 + predMap := make(map[uint32][]string) + for gid, group := range state.Groups { + groups = append(groups, gid) + predMap[gid] = make([]string, 0) + for pred := range group.Tablets { + predMap[gid] = append(predMap[gid], pred) + } + + } + + // see if any of the predicates are vector predicates and add the supporting + // vector predicates to the backup request. + vecPredMap := make(map[uint32][]string) + for gid, preds := range predMap { + schema, err := GetSchemaOverNetwork(ctx, &pb.SchemaRequest{Predicates: preds}) + if err != nil { + return err + } + + for _, pred := range schema { + if pred.Type == "float32vector" && len(pred.IndexSpecs) != 0 { + vecPredMap[gid] = append(predMap[gid], pred.Predicate+hnsw.VecEntry, pred.Predicate+hnsw.VecKeyword, + pred.Predicate+hnsw.VecDead) + } + } + } + + for gid, preds := range vecPredMap { + predMap[gid] = append(predMap[gid], preds...) + } + + glog.Infof( + "Created backup request: read_ts:%d since_ts:%d unix_ts:\"%s\" destination:\"%s\" . 
Groups=%v\n", + req.ReadTs, + req.SinceTs, + req.UnixTs, + req.Destination, + groups, + ) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + resCh := make(chan BackupRes, len(state.Groups)) + for _, gid := range groups { + br := proto.Clone(req).(*pb.BackupRequest) + br.GroupId = gid + br.Predicates = predMap[gid] + go func(req *pb.BackupRequest) { + res, err := BackupGroup(ctx, req) + resCh <- BackupRes{res: res, err: err} + }(br) + } + + var dropOperations []*pb.DropOperation + for range groups { + backupRes := <-resCh + if backupRes.err != nil { + glog.Errorf("Error received during backup: %v", backupRes.err) + return backupRes.err + } + dropOperations = append(dropOperations, backupRes.res.GetDropOperations()...) + } + + dir := fmt.Sprintf(backupPathFmt, req.UnixTs) + m := Manifest{ + ReadTs: req.ReadTs, + Groups: predMap, + Version: x.ManifestVersion, + DropOperations: dropOperations, + Path: dir, + Compression: "snappy", + } + if req.SinceTs == 0 { + m.Type = "full" + m.BackupId = x.GetRandomName(1) + m.BackupNum = 1 + } else { + m.Type = "incremental" + m.BackupId = latestManifest.BackupId + m.BackupNum = latestManifest.BackupNum + 1 + } + m.Encrypted = x.WorkerConfig.EncryptionKey != nil + + bp := NewBackupProcessor(nil, req) + defer bp.Close() + err = bp.CompleteBackup(ctx, &m) + + if err != nil { + return err + } + + backupSuccessful = true + return nil +} + +func ProcessListBackups(ctx context.Context, location string, creds *x.MinioCredentials) ( + []*Manifest, error) { + + manifests, err := ListBackupManifests(location, creds) + if err != nil { + return nil, errors.Wrapf(err, "cannot read manifests at location %s", location) + } + + res := make([]*Manifest, 0, len(manifests)) + res = append(res, manifests...) + return res, nil +} + +// BackupProcessor handles the different stages of the backup process. +type BackupProcessor struct { + // DB is the Badger pstore managed by this node. 
+ DB *badger.DB + // Request stores the backup request containing the parameters for this backup. + Request *pb.BackupRequest + + // txn is used for the iterators in the threadLocal + txn *badger.Txn + threads []*threadLocal +} + +type threadLocal struct { + Request *pb.BackupRequest + // pre-allocated pb.BackupPostingList object. + bpl pb.BackupPostingList + alloc *z.Allocator + itr *badger.Iterator + buf *z.Buffer +} + +func NewBackupProcessor(db *badger.DB, req *pb.BackupRequest) *BackupProcessor { + bp := &BackupProcessor{ + DB: db, + Request: req, + threads: make([]*threadLocal, x.WorkerConfig.Badger.NumGoroutines), + } + if req.SinceTs > 0 && db != nil { + bp.txn = db.NewTransactionAt(req.ReadTs, false) + } + for i := range bp.threads { + buf := z.NewBuffer(32<<20, "Worker.BackupProcessor") + + bp.threads[i] = &threadLocal{ + Request: bp.Request, + buf: buf, + } + if bp.txn != nil { + iopt := badger.DefaultIteratorOptions + iopt.AllVersions = true + bp.threads[i].itr = bp.txn.NewIterator(iopt) + } + } + return bp +} + +func (pr *BackupProcessor) Close() { + for _, th := range pr.threads { + if pr.txn != nil { + th.itr.Close() + } + _ = th.buf.Release() + } + if pr.txn != nil { + pr.txn.Discard() + } +} + +// LoadResult holds the output of a Load operation. +type LoadResult struct { + // Version is the timestamp at which the database is after loading a backup. + Version uint64 + // MaxLeaseUid is the max UID seen by the load operation. Needed to request zero + // for the proper number of UIDs. + MaxLeaseUid uint64 + // MaxLeaseNsId is the max namespace ID seen by the load operation. + MaxLeaseNsId uint64 + // The error, if any, of the load operation. + Err error +} + +// WriteBackup uses the request values to create a stream writer then hand off the data +// retrieval to stream.Orchestrate. The writer will create all the fd's needed to +// collect the data and later move to the target. +// Returns errors on failure, nil on success. 
+func (pr *BackupProcessor) WriteBackup(ctx context.Context) (*pb.BackupResponse, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + uri, err := url.Parse(pr.Request.Destination) + if err != nil { + return nil, err + } + handler, err := NewUriHandler(uri, GetCredentialsFromRequest(pr.Request)) + if err != nil { + return nil, err + } + w, err := createBackupFile(handler, uri, pr.Request) + if err != nil { + return nil, err + } + glog.V(3).Infof("Backup manifest version: %d", pr.Request.SinceTs) + + eWriter, err := enc.GetWriter(x.WorkerConfig.EncryptionKey, w) + if err != nil { + return nil, err + } + + // Snappy is much faster than gzip compression, even with the BestSpeed + // gzip option. In fact, in my experiments, gzip compression caused the + // output speed to be ~30 MBps. Snappy can write at ~90 MBps, and overall + // the speed is similar to writing uncompressed data on disk. + // + // These are the times I saw: + // Without compression: 7m2s 33GB output. + // With snappy: 7m11s 9.5GB output. + // With snappy + S3: 7m54s 9.5GB output. + cWriter := s2.NewWriter(eWriter) + + stream := pr.DB.NewStreamAt(pr.Request.ReadTs) + stream.LogPrefix = "Dgraph.Backup" + // Ignore versions less than given sinceTs timestamp, or skip older versions of + // the given key by returning an empty list. + // Do not do this for schema and type keys. Those keys always have a + // version of one. They're handled separately. + stream.SinceTs = pr.Request.SinceTs + stream.Prefix = []byte{x.ByteData} + + var response pb.BackupResponse + stream.KeyToList = func(key []byte, itr *badger.Iterator) (*bpb.KVList, error) { + tl := pr.threads[itr.ThreadId] + tl.alloc = itr.Alloc + + bitr := itr + // Use the threadlocal iterator because "itr" has the sinceTs set and + // it will not be able to read all the data. 
+ if tl.itr != nil { + bitr = tl.itr + bitr.Seek(key) + } + + kvList, dropOp, err := tl.toBackupList(key, bitr) + if err != nil { + return nil, err + } + // we don't want to append a nil value to the slice, so need to check. + if dropOp != nil { + response.DropOperations = append(response.DropOperations, dropOp) + } + return kvList, nil + } + + predMap := make(map[string]struct{}) + for _, pred := range pr.Request.Predicates { + predMap[pred] = struct{}{} + } + stream.ChooseKey = func(item *badger.Item) bool { + parsedKey, err := x.Parse(item.Key()) + if err != nil { + glog.Errorf("error %v while parsing key %v during backup. Skipping...", + err, hex.EncodeToString(item.Key())) + return false + } + + // Do not choose keys that contain parts of a multi-part list. These keys + // will be accessed from the main list. + if parsedKey.HasStartUid { + return false + } + + // Skip backing up the schema and type keys. They will be backed up separately. + if parsedKey.IsSchema() || parsedKey.IsType() { + return false + } + _, ok := predMap[parsedKey.Attr] + return ok + } + + var maxVersion uint64 + stream.Send = func(buf *z.Buffer) error { + list, err := badger.BufferToKVList(buf) + if err != nil { + return err + } + for _, kv := range list.Kv { + if maxVersion < kv.Version { + maxVersion = kv.Version + } + } + return writeKVList(list, cWriter) + } + + // This is where the execution happens. + if err := stream.Orchestrate(ctx); err != nil { + glog.Errorf("While taking backup: %v", err) + return &response, err + } + + // This is used to backup the schema and types. + writePrefix := func(prefix byte) error { + tl := threadLocal{ + alloc: z.NewAllocator(1<<10, "BackupProcessor.WritePrefix"), + } + defer tl.alloc.Release() + + txn := pr.DB.NewTransactionAt(pr.Request.ReadTs, false) + defer txn.Discard() + // We don't need to iterate over all versions. 
+ iopts := badger.DefaultIteratorOptions + iopts.Prefix = []byte{prefix} + + itr := txn.NewIterator(iopts) + defer itr.Close() + + list := &bpb.KVList{} + for itr.Rewind(); itr.Valid(); itr.Next() { + item := itr.Item() + // Don't export deleted items. + if item.IsDeletedOrExpired() { + continue + } + parsedKey, err := x.Parse(item.Key()) + if err != nil { + glog.Errorf("error %v while parsing key %v during backup. Skipping...", + err, hex.EncodeToString(item.Key())) + continue + } + // This check makes sense only for the schema keys. The types are not stored in it. + if _, ok := predMap[parsedKey.Attr]; !parsedKey.IsType() && !ok { + continue + } + kv := y.NewKV(tl.alloc) + if err := item.Value(func(val []byte) error { + kv.Value = append(kv.Value, val...) + return nil + }); err != nil { + return errors.Wrapf(err, "while copying value") + } + + backupKey, err := tl.toBackupKey(item.Key()) + if err != nil { + return err + } + kv.Key = backupKey + kv.UserMeta = tl.alloc.Copy([]byte{item.UserMeta()}) + kv.Version = item.Version() + kv.ExpiresAt = item.ExpiresAt() + list.Kv = append(list.Kv, kv) + } + return writeKVList(list, cWriter) + } + + for _, prefix := range []byte{x.ByteSchema, x.ByteType} { + if err := writePrefix(prefix); err != nil { + glog.Errorf("While writing prefix %d to backup: %v", prefix, err) + return &response, err + } + } + + if maxVersion > pr.Request.ReadTs { + glog.Errorf("Max timestamp seen during backup (%d) is greater than readTs (%d)", + maxVersion, pr.Request.ReadTs) + } + + glog.V(2).Infof("Backup group %d version: %d", pr.Request.GroupId, pr.Request.ReadTs) + if err = cWriter.Close(); err != nil { + glog.Errorf("While closing gzipped writer: %v", err) + return &response, err + } + + if err = w.Close(); err != nil { + glog.Errorf("While closing handler: %v", err) + return &response, err + } + glog.Infof("Backup complete: group %d at %d", pr.Request.GroupId, pr.Request.ReadTs) + return &response, nil +} + +// CompleteBackup will finalize a 
backup by writing the manifest at the backup destination. +func (pr *BackupProcessor) CompleteBackup(ctx context.Context, m *Manifest) error { + if err := ctx.Err(); err != nil { + return err + } + uri, err := url.Parse(pr.Request.Destination) + if err != nil { + return err + } + handler, err := NewUriHandler(uri, GetCredentialsFromRequest(pr.Request)) + if err != nil { + return err + } + + manifest, err := GetManifestNoUpgrade(handler, uri) + if err != nil { + return err + } + manifest.Manifests = append(manifest.Manifests, m) + + if err := CreateManifest(handler, uri, manifest); err != nil { + return errors.Wrap(err, "complete backup failed") + } + glog.Infof("Backup completed OK.") + return nil +} + +// GoString implements the GoStringer interface for Manifest. +func (m *Manifest) GoString() string { + return fmt.Sprintf(`Manifest{Since: %d, ReadTs: %d, Groups: %v, Encrypted: %v}`, + m.SinceTsDeprecated, m.ReadTs, m.Groups, m.Encrypted) +} + +func (tl *threadLocal) toBackupList(key []byte, itr *badger.Iterator) ( + *bpb.KVList, *pb.DropOperation, error) { + list := &bpb.KVList{} + var dropOp *pb.DropOperation + + item := itr.Item() + if item.Version() < tl.Request.SinceTs { + return list, nil, + errors.Errorf("toBackupList: Item.Version(): %d should be less than sinceTs: %d", + item.Version(), tl.Request.SinceTs) + } + if item.IsDeletedOrExpired() { + return list, nil, nil + } + + switch item.UserMeta() { + case posting.BitEmptyPosting, posting.BitCompletePosting, posting.BitDeltaPosting: + l, err := posting.ReadPostingList(key, itr) + if err != nil { + return nil, nil, errors.Wrapf(err, "while reading posting list") + } + + // Don't allocate kv on tl.alloc, because we don't need it by the end of this func. 
+ kv, err := l.ToBackupPostingList(&tl.bpl, tl.alloc, tl.buf) + if err != nil { + return nil, nil, errors.Wrapf(err, "while rolling up list") + } + + backupKey, err := tl.toBackupKey(kv.Key) + if err != nil { + return nil, nil, err + } + + // check if this key was storing a DROP operation record. If yes, get the drop operation. + dropOp, err = checkAndGetDropOp(key, l, tl.Request.ReadTs) + if err != nil { + return nil, nil, err + } + + kv.Key = backupKey + list.Kv = append(list.Kv, kv) + default: + return nil, nil, errors.Errorf( + "Unexpected meta: %d for key: %s", item.UserMeta(), hex.Dump(key)) + } + return list, dropOp, nil +} + +func (tl *threadLocal) toBackupKey(key []byte) ([]byte, error) { + parsedKey, err := x.Parse(key) + if err != nil { + return nil, errors.Wrapf(err, "could not parse key %s", hex.Dump(key)) + } + bk := parsedKey.ToBackupKey() + + out := tl.alloc.Allocate(proto.Size(bk)) + return x.MarshalToSizedBuffer(out, bk) +} + +func writeKVList(list *bpb.KVList, w io.Writer) error { + if err := binary.Write(w, binary.LittleEndian, uint64(proto.Size(list))); err != nil { + return err + } + buf, err := proto.Marshal(list) + if err != nil { + return err + } + _, err = w.Write(buf) + return err +} + +func checkAndGetDropOp(key []byte, l *posting.List, readTs uint64) (*pb.DropOperation, error) { + isDropOpKey, err := x.IsDropOpKey(key) + if err != nil || !isDropOpKey { + return nil, err + } + + vals, err := l.AllValues(readTs) + if err != nil { + return nil, errors.Wrapf(err, "cannot read value of dgraph.drop.op") + } + switch len(vals) { + case 0: + // do nothing, it means this one was deleted with S * * deletion. + // So, no need to consider it. 
+ return nil, nil + case 1: + val, ok := vals[0].Value.([]byte) + if !ok { + return nil, errors.Errorf("cannot convert value of dgraph.drop.op to byte array, "+ + "got type: %s, value: %v, tid: %v", reflect.TypeOf(vals[0].Value), vals[0].Value, + vals[0].Tid) + } + // A dgraph.drop.op record can have values in only one of the following formats: + // * DROP_ALL; + // * DROP_DATA;ns + // * DROP_ATTR;attrName + // * DROP_NS;ns + // So, accordingly construct the *pb.DropOperation. + dropOp := &pb.DropOperation{} + dropInfo := strings.Split(string(val), ";") + if len(dropInfo) != 2 { + return nil, errors.Errorf("Unexpected value: %s for dgraph.drop.op", val) + } + switch dropInfo[0] { + case "DROP_ALL": + dropOp.DropOp = pb.DropOperation_ALL + case "DROP_DATA": + dropOp.DropOp = pb.DropOperation_DATA + dropOp.DropValue = dropInfo[1] // contains namespace. + case "DROP_ATTR": + dropOp.DropOp = pb.DropOperation_ATTR + dropOp.DropValue = dropInfo[1] + case "DROP_NS": + dropOp.DropOp = pb.DropOperation_NS + dropOp.DropValue = dropInfo[1] // contains namespace. + } + return dropOp, nil + default: + // getting more than one values for a non-list predicate is an error + return nil, errors.Errorf("found multiple values for dgraph.drop.op: %v", vals) + } +} diff --git a/worker/backup_ee.go b/worker/backup_ee.go deleted file mode 100644 index d3bf785ca11..00000000000 --- a/worker/backup_ee.go +++ /dev/null @@ -1,716 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package worker - -import ( - "context" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "net/url" - "reflect" - "strings" - "sync" - "time" - - "github.com/golang/glog" - "github.com/klauspost/compress/s2" - "github.com/pkg/errors" - ostats "go.opencensus.io/stats" - "google.golang.org/protobuf/proto" - - "github.com/dgraph-io/badger/v4" - bpb "github.com/dgraph-io/badger/v4/pb" - "github.com/dgraph-io/badger/v4/y" - "github.com/dgraph-io/ristretto/v2/z" - "github.com/hypermodeinc/dgraph/v24/ee/enc" - "github.com/hypermodeinc/dgraph/v24/posting" - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/tok/hnsw" - "github.com/hypermodeinc/dgraph/v24/x" -) - -// Backup handles a request coming from another node. -func (w *grpcWorker) Backup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { - glog.V(2).Infof("Received backup request via Grpc: %+v", req) - return backupCurrentGroup(ctx, req) -} - -func backupCurrentGroup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { - glog.Infof("Backup request: group %d at %d", req.GroupId, req.ReadTs) - if err := ctx.Err(); err != nil { - glog.Errorf("Context error during backup: %v\n", err) - return nil, err - } - - g := groups() - if g.groupId() != req.GroupId { - return nil, errors.Errorf("Backup request group mismatch. Mine: %d. Requested: %d\n", - g.groupId(), req.GroupId) - } - - if err := posting.Oracle().WaitForTs(ctx, req.ReadTs); err != nil { - return nil, err - } - - closer, err := g.Node.startTaskAtTs(opBackup, req.ReadTs) - if err != nil { - return nil, errors.Wrapf(err, "cannot start backup operation") - } - defer closer.Done() - - bp := NewBackupProcessor(pstore, req) - defer bp.Close() - - return bp.WriteBackup(closer.Ctx()) -} - -// BackupGroup backs up the group specified in the backup request. 
-func BackupGroup(ctx context.Context, in *pb.BackupRequest) (*pb.BackupResponse, error) { - glog.V(2).Infof("Sending backup request: %+v\n", in) - if groups().groupId() == in.GroupId { - return backupCurrentGroup(ctx, in) - } - - // This node is not part of the requested group, send the request over the network. - pl := groups().AnyServer(in.GroupId) - if pl == nil { - return nil, errors.Errorf("Couldn't find a server in group %d", in.GroupId) - } - res, err := pb.NewWorkerClient(pl.Get()).Backup(ctx, in) - if err != nil { - glog.Errorf("Backup error group %d: %s", in.GroupId, err) - return nil, err - } - - return res, nil -} - -// backupLock is used to synchronize backups to avoid more than one backup request -// to be processed at the same time. Multiple requests could lead to multiple -// backups with the same backupNum in their manifest. -var backupLock sync.Mutex - -// BackupRes is used to represent the response and error of the Backup gRPC call together to be -// transported via a channel. -type BackupRes struct { - res *pb.BackupResponse - err error -} - -func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { - if err := x.HealthCheck(); err != nil { - glog.Errorf("Backup canceled, not ready to accept requests: %s", err) - return err - } - - // Grab the lock here to avoid more than one request to be processed at the same time. 
- backupLock.Lock() - defer backupLock.Unlock() - - backupSuccessful := false - ostats.Record(ctx, x.NumBackups.M(1), x.PendingBackups.M(1)) - defer func() { - if backupSuccessful { - ostats.Record(ctx, x.NumBackupsSuccess.M(1), x.PendingBackups.M(-1)) - } else { - ostats.Record(ctx, x.NumBackupsFailed.M(1), x.PendingBackups.M(-1)) - } - }() - - ts, err := Timestamps(ctx, &pb.Num{ReadOnly: true}) - if err != nil { - glog.Errorf("Unable to retrieve readonly timestamp for backup: %s", err) - return err - } - - req.ReadTs = ts.ReadOnly - req.UnixTs = time.Now().UTC().Format("20060102.150405.000") - - // Read the manifests to get the right timestamp from which to start the backup. - uri, err := url.Parse(req.Destination) - if err != nil { - return err - } - handler, err := NewUriHandler(uri, GetCredentialsFromRequest(req)) - if err != nil { - return err - } - if !handler.DirExists("./") { - if err := handler.CreateDir("./"); err != nil { - return errors.Wrap(err, "while creating backup directory") - } - } - latestManifest, err := GetLatestManifest(handler, uri) - if err != nil { - return err - } - - req.SinceTs = latestManifest.ValidReadTs() - // To force a full backup we'll set the sinceTs to zero. - if req.ForceFull { - req.SinceTs = 0 - } else { - if x.WorkerConfig.EncryptionKey != nil { - // If encryption key given, latest backup should be encrypted. - if latestManifest.Type != "" && !latestManifest.Encrypted { - err = errors.Errorf("latest manifest indicates the last backup was not encrypted " + - "but this instance has encryption turned on. Try \"forceFull\" flag.") - return err - } - } else { - // If encryption turned off, latest backup should be unencrypted. - if latestManifest.Type != "" && latestManifest.Encrypted { - err = errors.Errorf("latest manifest indicates the last backup was encrypted " + - "but this instance has encryption turned off. 
Try \"forceFull\" flag.") - return err - } - } - } - - // Update the membership state to get the latest mapping of groups to predicates. - if err := UpdateMembershipState(ctx); err != nil { - return err - } - - // Get the current membership state and parse it for easier processing. - state := GetMembershipState() - var groups []uint32 - predMap := make(map[uint32][]string) - for gid, group := range state.Groups { - groups = append(groups, gid) - predMap[gid] = make([]string, 0) - for pred := range group.Tablets { - predMap[gid] = append(predMap[gid], pred) - } - - } - - // see if any of the predicates are vector predicates and add the supporting - // vector predicates to the backup request. - vecPredMap := make(map[uint32][]string) - for gid, preds := range predMap { - schema, err := GetSchemaOverNetwork(ctx, &pb.SchemaRequest{Predicates: preds}) - if err != nil { - return err - } - - for _, pred := range schema { - if pred.Type == "float32vector" && len(pred.IndexSpecs) != 0 { - vecPredMap[gid] = append(predMap[gid], pred.Predicate+hnsw.VecEntry, pred.Predicate+hnsw.VecKeyword, - pred.Predicate+hnsw.VecDead) - } - } - } - - for gid, preds := range vecPredMap { - predMap[gid] = append(predMap[gid], preds...) - } - - glog.Infof( - "Created backup request: read_ts:%d since_ts:%d unix_ts:\"%s\" destination:\"%s\" . 
Groups=%v\n", - req.ReadTs, - req.SinceTs, - req.UnixTs, - req.Destination, - groups, - ) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - resCh := make(chan BackupRes, len(state.Groups)) - for _, gid := range groups { - br := proto.Clone(req).(*pb.BackupRequest) - br.GroupId = gid - br.Predicates = predMap[gid] - go func(req *pb.BackupRequest) { - res, err := BackupGroup(ctx, req) - resCh <- BackupRes{res: res, err: err} - }(br) - } - - var dropOperations []*pb.DropOperation - for range groups { - backupRes := <-resCh - if backupRes.err != nil { - glog.Errorf("Error received during backup: %v", backupRes.err) - return backupRes.err - } - dropOperations = append(dropOperations, backupRes.res.GetDropOperations()...) - } - - dir := fmt.Sprintf(backupPathFmt, req.UnixTs) - m := Manifest{ - ReadTs: req.ReadTs, - Groups: predMap, - Version: x.ManifestVersion, - DropOperations: dropOperations, - Path: dir, - Compression: "snappy", - } - if req.SinceTs == 0 { - m.Type = "full" - m.BackupId = x.GetRandomName(1) - m.BackupNum = 1 - } else { - m.Type = "incremental" - m.BackupId = latestManifest.BackupId - m.BackupNum = latestManifest.BackupNum + 1 - } - m.Encrypted = x.WorkerConfig.EncryptionKey != nil - - bp := NewBackupProcessor(nil, req) - defer bp.Close() - err = bp.CompleteBackup(ctx, &m) - - if err != nil { - return err - } - - backupSuccessful = true - return nil -} - -func ProcessListBackups(ctx context.Context, location string, creds *x.MinioCredentials) ( - []*Manifest, error) { - - manifests, err := ListBackupManifests(location, creds) - if err != nil { - return nil, errors.Wrapf(err, "cannot read manifests at location %s", location) - } - - res := make([]*Manifest, 0, len(manifests)) - res = append(res, manifests...) - return res, nil -} - -// BackupProcessor handles the different stages of the backup process. -type BackupProcessor struct { - // DB is the Badger pstore managed by this node. 
- DB *badger.DB - // Request stores the backup request containing the parameters for this backup. - Request *pb.BackupRequest - - // txn is used for the iterators in the threadLocal - txn *badger.Txn - threads []*threadLocal -} - -type threadLocal struct { - Request *pb.BackupRequest - // pre-allocated pb.BackupPostingList object. - bpl pb.BackupPostingList - alloc *z.Allocator - itr *badger.Iterator - buf *z.Buffer -} - -func NewBackupProcessor(db *badger.DB, req *pb.BackupRequest) *BackupProcessor { - bp := &BackupProcessor{ - DB: db, - Request: req, - threads: make([]*threadLocal, x.WorkerConfig.Badger.NumGoroutines), - } - if req.SinceTs > 0 && db != nil { - bp.txn = db.NewTransactionAt(req.ReadTs, false) - } - for i := range bp.threads { - buf := z.NewBuffer(32<<20, "Worker.BackupProcessor") - - bp.threads[i] = &threadLocal{ - Request: bp.Request, - buf: buf, - } - if bp.txn != nil { - iopt := badger.DefaultIteratorOptions - iopt.AllVersions = true - bp.threads[i].itr = bp.txn.NewIterator(iopt) - } - } - return bp -} - -func (pr *BackupProcessor) Close() { - for _, th := range pr.threads { - if pr.txn != nil { - th.itr.Close() - } - _ = th.buf.Release() - } - if pr.txn != nil { - pr.txn.Discard() - } -} - -// LoadResult holds the output of a Load operation. -type LoadResult struct { - // Version is the timestamp at which the database is after loading a backup. - Version uint64 - // MaxLeaseUid is the max UID seen by the load operation. Needed to request zero - // for the proper number of UIDs. - MaxLeaseUid uint64 - // MaxLeaseNsId is the max namespace ID seen by the load operation. - MaxLeaseNsId uint64 - // The error, if any, of the load operation. - Err error -} - -// WriteBackup uses the request values to create a stream writer then hand off the data -// retrieval to stream.Orchestrate. The writer will create all the fd's needed to -// collect the data and later move to the target. -// Returns errors on failure, nil on success. 
-func (pr *BackupProcessor) WriteBackup(ctx context.Context) (*pb.BackupResponse, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - uri, err := url.Parse(pr.Request.Destination) - if err != nil { - return nil, err - } - handler, err := NewUriHandler(uri, GetCredentialsFromRequest(pr.Request)) - if err != nil { - return nil, err - } - w, err := createBackupFile(handler, uri, pr.Request) - if err != nil { - return nil, err - } - glog.V(3).Infof("Backup manifest version: %d", pr.Request.SinceTs) - - eWriter, err := enc.GetWriter(x.WorkerConfig.EncryptionKey, w) - if err != nil { - return nil, err - } - - // Snappy is much faster than gzip compression, even with the BestSpeed - // gzip option. In fact, in my experiments, gzip compression caused the - // output speed to be ~30 MBps. Snappy can write at ~90 MBps, and overall - // the speed is similar to writing uncompressed data on disk. - // - // These are the times I saw: - // Without compression: 7m2s 33GB output. - // With snappy: 7m11s 9.5GB output. - // With snappy + S3: 7m54s 9.5GB output. - cWriter := s2.NewWriter(eWriter) - - stream := pr.DB.NewStreamAt(pr.Request.ReadTs) - stream.LogPrefix = "Dgraph.Backup" - // Ignore versions less than given sinceTs timestamp, or skip older versions of - // the given key by returning an empty list. - // Do not do this for schema and type keys. Those keys always have a - // version of one. They're handled separately. - stream.SinceTs = pr.Request.SinceTs - stream.Prefix = []byte{x.ByteData} - - var response pb.BackupResponse - stream.KeyToList = func(key []byte, itr *badger.Iterator) (*bpb.KVList, error) { - tl := pr.threads[itr.ThreadId] - tl.alloc = itr.Alloc - - bitr := itr - // Use the threadlocal iterator because "itr" has the sinceTs set and - // it will not be able to read all the data. 
- if tl.itr != nil { - bitr = tl.itr - bitr.Seek(key) - } - - kvList, dropOp, err := tl.toBackupList(key, bitr) - if err != nil { - return nil, err - } - // we don't want to append a nil value to the slice, so need to check. - if dropOp != nil { - response.DropOperations = append(response.DropOperations, dropOp) - } - return kvList, nil - } - - predMap := make(map[string]struct{}) - for _, pred := range pr.Request.Predicates { - predMap[pred] = struct{}{} - } - stream.ChooseKey = func(item *badger.Item) bool { - parsedKey, err := x.Parse(item.Key()) - if err != nil { - glog.Errorf("error %v while parsing key %v during backup. Skipping...", - err, hex.EncodeToString(item.Key())) - return false - } - - // Do not choose keys that contain parts of a multi-part list. These keys - // will be accessed from the main list. - if parsedKey.HasStartUid { - return false - } - - // Skip backing up the schema and type keys. They will be backed up separately. - if parsedKey.IsSchema() || parsedKey.IsType() { - return false - } - _, ok := predMap[parsedKey.Attr] - return ok - } - - var maxVersion uint64 - stream.Send = func(buf *z.Buffer) error { - list, err := badger.BufferToKVList(buf) - if err != nil { - return err - } - for _, kv := range list.Kv { - if maxVersion < kv.Version { - maxVersion = kv.Version - } - } - return writeKVList(list, cWriter) - } - - // This is where the execution happens. - if err := stream.Orchestrate(ctx); err != nil { - glog.Errorf("While taking backup: %v", err) - return &response, err - } - - // This is used to backup the schema and types. - writePrefix := func(prefix byte) error { - tl := threadLocal{ - alloc: z.NewAllocator(1<<10, "BackupProcessor.WritePrefix"), - } - defer tl.alloc.Release() - - txn := pr.DB.NewTransactionAt(pr.Request.ReadTs, false) - defer txn.Discard() - // We don't need to iterate over all versions. 
- iopts := badger.DefaultIteratorOptions - iopts.Prefix = []byte{prefix} - - itr := txn.NewIterator(iopts) - defer itr.Close() - - list := &bpb.KVList{} - for itr.Rewind(); itr.Valid(); itr.Next() { - item := itr.Item() - // Don't export deleted items. - if item.IsDeletedOrExpired() { - continue - } - parsedKey, err := x.Parse(item.Key()) - if err != nil { - glog.Errorf("error %v while parsing key %v during backup. Skipping...", - err, hex.EncodeToString(item.Key())) - continue - } - // This check makes sense only for the schema keys. The types are not stored in it. - if _, ok := predMap[parsedKey.Attr]; !parsedKey.IsType() && !ok { - continue - } - kv := y.NewKV(tl.alloc) - if err := item.Value(func(val []byte) error { - kv.Value = append(kv.Value, val...) - return nil - }); err != nil { - return errors.Wrapf(err, "while copying value") - } - - backupKey, err := tl.toBackupKey(item.Key()) - if err != nil { - return err - } - kv.Key = backupKey - kv.UserMeta = tl.alloc.Copy([]byte{item.UserMeta()}) - kv.Version = item.Version() - kv.ExpiresAt = item.ExpiresAt() - list.Kv = append(list.Kv, kv) - } - return writeKVList(list, cWriter) - } - - for _, prefix := range []byte{x.ByteSchema, x.ByteType} { - if err := writePrefix(prefix); err != nil { - glog.Errorf("While writing prefix %d to backup: %v", prefix, err) - return &response, err - } - } - - if maxVersion > pr.Request.ReadTs { - glog.Errorf("Max timestamp seen during backup (%d) is greater than readTs (%d)", - maxVersion, pr.Request.ReadTs) - } - - glog.V(2).Infof("Backup group %d version: %d", pr.Request.GroupId, pr.Request.ReadTs) - if err = cWriter.Close(); err != nil { - glog.Errorf("While closing gzipped writer: %v", err) - return &response, err - } - - if err = w.Close(); err != nil { - glog.Errorf("While closing handler: %v", err) - return &response, err - } - glog.Infof("Backup complete: group %d at %d", pr.Request.GroupId, pr.Request.ReadTs) - return &response, nil -} - -// CompleteBackup will finalize a 
backup by writing the manifest at the backup destination. -func (pr *BackupProcessor) CompleteBackup(ctx context.Context, m *Manifest) error { - if err := ctx.Err(); err != nil { - return err - } - uri, err := url.Parse(pr.Request.Destination) - if err != nil { - return err - } - handler, err := NewUriHandler(uri, GetCredentialsFromRequest(pr.Request)) - if err != nil { - return err - } - - manifest, err := GetManifestNoUpgrade(handler, uri) - if err != nil { - return err - } - manifest.Manifests = append(manifest.Manifests, m) - - if err := CreateManifest(handler, uri, manifest); err != nil { - return errors.Wrap(err, "complete backup failed") - } - glog.Infof("Backup completed OK.") - return nil -} - -// GoString implements the GoStringer interface for Manifest. -func (m *Manifest) GoString() string { - return fmt.Sprintf(`Manifest{Since: %d, ReadTs: %d, Groups: %v, Encrypted: %v}`, - m.SinceTsDeprecated, m.ReadTs, m.Groups, m.Encrypted) -} - -func (tl *threadLocal) toBackupList(key []byte, itr *badger.Iterator) ( - *bpb.KVList, *pb.DropOperation, error) { - list := &bpb.KVList{} - var dropOp *pb.DropOperation - - item := itr.Item() - if item.Version() < tl.Request.SinceTs { - return list, nil, - errors.Errorf("toBackupList: Item.Version(): %d should be less than sinceTs: %d", - item.Version(), tl.Request.SinceTs) - } - if item.IsDeletedOrExpired() { - return list, nil, nil - } - - switch item.UserMeta() { - case posting.BitEmptyPosting, posting.BitCompletePosting, posting.BitDeltaPosting: - l, err := posting.ReadPostingList(key, itr) - if err != nil { - return nil, nil, errors.Wrapf(err, "while reading posting list") - } - - // Don't allocate kv on tl.alloc, because we don't need it by the end of this func. 
- kv, err := l.ToBackupPostingList(&tl.bpl, tl.alloc, tl.buf) - if err != nil { - return nil, nil, errors.Wrapf(err, "while rolling up list") - } - - backupKey, err := tl.toBackupKey(kv.Key) - if err != nil { - return nil, nil, err - } - - // check if this key was storing a DROP operation record. If yes, get the drop operation. - dropOp, err = checkAndGetDropOp(key, l, tl.Request.ReadTs) - if err != nil { - return nil, nil, err - } - - kv.Key = backupKey - list.Kv = append(list.Kv, kv) - default: - return nil, nil, errors.Errorf( - "Unexpected meta: %d for key: %s", item.UserMeta(), hex.Dump(key)) - } - return list, dropOp, nil -} - -func (tl *threadLocal) toBackupKey(key []byte) ([]byte, error) { - parsedKey, err := x.Parse(key) - if err != nil { - return nil, errors.Wrapf(err, "could not parse key %s", hex.Dump(key)) - } - bk := parsedKey.ToBackupKey() - - out := tl.alloc.Allocate(proto.Size(bk)) - return x.MarshalToSizedBuffer(out, bk) -} - -func writeKVList(list *bpb.KVList, w io.Writer) error { - if err := binary.Write(w, binary.LittleEndian, uint64(proto.Size(list))); err != nil { - return err - } - buf, err := proto.Marshal(list) - if err != nil { - return err - } - _, err = w.Write(buf) - return err -} - -func checkAndGetDropOp(key []byte, l *posting.List, readTs uint64) (*pb.DropOperation, error) { - isDropOpKey, err := x.IsDropOpKey(key) - if err != nil || !isDropOpKey { - return nil, err - } - - vals, err := l.AllValues(readTs) - if err != nil { - return nil, errors.Wrapf(err, "cannot read value of dgraph.drop.op") - } - switch len(vals) { - case 0: - // do nothing, it means this one was deleted with S * * deletion. - // So, no need to consider it. 
- return nil, nil - case 1: - val, ok := vals[0].Value.([]byte) - if !ok { - return nil, errors.Errorf("cannot convert value of dgraph.drop.op to byte array, "+ - "got type: %s, value: %v, tid: %v", reflect.TypeOf(vals[0].Value), vals[0].Value, - vals[0].Tid) - } - // A dgraph.drop.op record can have values in only one of the following formats: - // * DROP_ALL; - // * DROP_DATA;ns - // * DROP_ATTR;attrName - // * DROP_NS;ns - // So, accordingly construct the *pb.DropOperation. - dropOp := &pb.DropOperation{} - dropInfo := strings.Split(string(val), ";") - if len(dropInfo) != 2 { - return nil, errors.Errorf("Unexpected value: %s for dgraph.drop.op", val) - } - switch dropInfo[0] { - case "DROP_ALL": - dropOp.DropOp = pb.DropOperation_ALL - case "DROP_DATA": - dropOp.DropOp = pb.DropOperation_DATA - dropOp.DropValue = dropInfo[1] // contains namespace. - case "DROP_ATTR": - dropOp.DropOp = pb.DropOperation_ATTR - dropOp.DropValue = dropInfo[1] - case "DROP_NS": - dropOp.DropOp = pb.DropOperation_NS - dropOp.DropValue = dropInfo[1] // contains namespace. - } - return dropOp, nil - default: - // getting more than one values for a non-list predicate is an error - return nil, errors.Errorf("found multiple values for dgraph.drop.op: %v", vals) - } -} diff --git a/worker/backup_handler.go b/worker/backup_handler.go index da975372b60..fb87e5101ef 100644 --- a/worker/backup_handler.go +++ b/worker/backup_handler.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/backup_manifest.go b/worker/backup_manifest.go index 47880ea41fa..912e2240273 100644 --- a/worker/backup_manifest.go +++ b/worker/backup_manifest.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. 
+ * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/backup_oss.go b/worker/backup_oss.go deleted file mode 100644 index c98b550df8b..00000000000 --- a/worker/backup_oss.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build oss -// +build oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package worker - -import ( - "context" - - "github.com/golang/glog" - - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/x" -) - -// Backup implements the Worker interface. -func (w *grpcWorker) Backup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { - glog.Warningf("Backup failed: %v", x.ErrNotSupported) - return nil, x.ErrNotSupported -} - -func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { - glog.Warningf("Backup failed: %v", x.ErrNotSupported) - return x.ErrNotSupported -} - -func ProcessListBackups(ctx context.Context, location string, creds *x.MinioCredentials) ([]*Manifest, error) { - return nil, x.ErrNotSupported -} diff --git a/worker/cdc.go b/worker/cdc.go index dc45b8abb49..4782f6d10c9 100644 --- a/worker/cdc.go +++ b/worker/cdc.go @@ -1,6 +1,3 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. * SPDX-License-Identifier: Apache-2.0 @@ -9,39 +6,487 @@ package worker import ( + "bytes" + "encoding/binary" + "encoding/json" "math" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/golang/glog" + "github.com/pkg/errors" + "go.etcd.io/etcd/raft/v3/raftpb" + "google.golang.org/protobuf/proto" + "github.com/dgraph-io/ristretto/v2/z" + "github.com/hypermodeinc/dgraph/v24/posting" "github.com/hypermodeinc/dgraph/v24/protos/pb" + "github.com/hypermodeinc/dgraph/v24/types" + "github.com/hypermodeinc/dgraph/v24/x" +) + +const ( + defaultEventTopic = "dgraph-cdc" ) +// CDC struct is being used to send out change data capture events. There are two ways to do this: +// 1. 
Use Badger Subscribe. +// 2. Use Raft WAL. +// We chose to go with Raft WAL because in case we lose connection to the sink (say Kafka), we can +// resume from the last sent event and ensure there's continuity in event sending. Note the events +// would sent in the same order as they're being committed. +// With Badger Subscribe, if we lose the connection, we would have no way to send over the "missed" +// events. Even if we scan over Badger, we'd still not get those events in the right order, i.e. +// order of their commit timestamp. So, this approach would be tricky to get right. type CDC struct { + sync.Mutex + sink Sink + closer *z.Closer + pendingTxnEvents map[uint64][]CDCEvent + + // dont use mutex, use atomic for the following. + + // seenIndex is the Raft index till which we have read the raft logs, and + // put the events in our pendingTxnEvents. This does NOT mean that we have + // sent them yet. + seenIndex uint64 + sentTs uint64 // max commit ts for which we have send the events. 
} func newCDC() *CDC { - return nil + if Config.ChangeDataConf == "" || Config.ChangeDataConf == CDCDefaults { + return nil + } + + cdcFlag := z.NewSuperFlag(Config.ChangeDataConf).MergeAndCheckDefault(CDCDefaults) + sink, err := GetSink(cdcFlag) + x.Check(err) + cdc := &CDC{ + sink: sink, + closer: z.NewCloser(1), + pendingTxnEvents: make(map[uint64][]CDCEvent), + } + return cdc } -func (cd *CDC) getTs() uint64 { - return math.MaxUint64 +func (cdc *CDC) getSeenIndex() uint64 { + if cdc == nil { + return math.MaxUint64 + } + return atomic.LoadUint64(&cdc.seenIndex) } -func (cd *CDC) updateTs(ts uint64) { - return +func (cdc *CDC) getTs() uint64 { + if cdc == nil { + return math.MaxUint64 + } + cdc.Lock() + defer cdc.Unlock() + min := uint64(math.MaxUint64) + for startTs := range cdc.pendingTxnEvents { + min = x.Min(min, startTs) + } + return min } -func (cdc *CDC) getSeenIndex() uint64 { - return math.MaxUint64 +func (cdc *CDC) resetPendingEvents() { + if cdc == nil { + return + } + cdc.Lock() + defer cdc.Unlock() + cdc.pendingTxnEvents = make(map[uint64][]CDCEvent) +} + +func (cdc *CDC) resetPendingEventsForNs(ns uint64) { + if cdc == nil { + return + } + cdc.Lock() + defer cdc.Unlock() + for ts, events := range cdc.pendingTxnEvents { + if len(events) > 0 && binary.BigEndian.Uint64(events[0].Meta.Namespace) == ns { + delete(cdc.pendingTxnEvents, ts) + } + } +} + +func (cdc *CDC) hasPending(attr string) bool { + if cdc == nil { + return false + } + cdc.Lock() + defer cdc.Unlock() + for _, events := range cdc.pendingTxnEvents { + for _, e := range events { + if me, ok := e.Event.(*MutationEvent); ok && me.Attr == attr { + return true + } + } + } + return false +} + +func (cdc *CDC) addToPending(ts uint64, events []CDCEvent) { + if cdc == nil { + return + } + cdc.Lock() + defer cdc.Unlock() + cdc.pendingTxnEvents[ts] = append(cdc.pendingTxnEvents[ts], events...) 
+} + +func (cdc *CDC) removeFromPending(ts uint64) { + if cdc == nil { + return + } + cdc.Lock() + defer cdc.Unlock() + delete(cdc.pendingTxnEvents, ts) +} + +func (cdc *CDC) updateSeenIndex(index uint64) { + if cdc == nil { + return + } + idx := atomic.LoadUint64(&cdc.seenIndex) + if idx >= index { + return + } + atomic.CompareAndSwapUint64(&cdc.seenIndex, idx, index) } func (cdc *CDC) updateCDCState(state *pb.CDCState) { - return + if cdc == nil { + return + } + + // Dont try to update seen index in case of default mode else cdc job will not + // be able to build the complete pending txns in case of membership changes. + ts := atomic.LoadUint64(&cdc.sentTs) + if ts >= state.SentTs { + return + } + atomic.CompareAndSwapUint64(&cdc.sentTs, ts, state.SentTs) } -func (cd *CDC) Close() { - return +func (cdc *CDC) Close() { + if cdc == nil { + return + } + glog.Infof("closing CDC events...") + cdc.closer.SignalAndWait() + err := cdc.sink.Close() + glog.Errorf("error while closing sink %v", err) } -// todo: test cases old cluster restart, live loader, bulk loader, backup restore etc -func (cd *CDC) processCDCEvents() { - return +func (cdc *CDC) processCDCEvents() { + if cdc == nil { + return + } + + sendToSink := func(pending []CDCEvent, commitTs uint64) error { + batch := make([]SinkMessage, 0) + for _, e := range pending { + e.Meta.CommitTs = commitTs + b, err := json.Marshal(e) + if err != nil { + glog.Errorf("error while marshalling batch for event [%+v]: %v\n", e.Event, err) + continue + } + batch = append(batch, SinkMessage{ + Meta: SinkMeta{ + Topic: defaultEventTopic, + }, + Key: e.Meta.Namespace, + Value: b, + }) + } + if err := cdc.sink.Send(batch); err != nil { + glog.Errorf("error while sending cdc event to sink %+v", err) + return err + } + // We successfully sent messages to sink. 
+ atomic.StoreUint64(&cdc.sentTs, commitTs) + return nil + } + + handleEntry := func(entry raftpb.Entry) (rerr error) { + defer func() { + // Irrespective of whether we act on this entry or not, we should + // always update the seenIndex. Otherwise, we'll loop over these + // entries over and over again. However, if we encounter an error, + // we should not update the index. + if rerr == nil { + cdc.updateSeenIndex(entry.Index) + } + }() + + if entry.Type != raftpb.EntryNormal || len(entry.Data) == 0 { + return + } + + var proposal pb.Proposal + if err := proto.Unmarshal(entry.Data[8:], &proposal); err != nil { + glog.Warningf("CDC: unmarshal failed with error %v. Ignoring.", err) + return + } + if proposal.Mutations != nil { + events := toCDCEvent(entry.Index, proposal.Mutations) + if len(events) == 0 { + return + } + edges := proposal.Mutations.Edges + switch { + case proposal.Mutations.DropOp != pb.Mutations_NONE: // this means its a drop operation + // if there is DROP ALL or DROP DATA operation, clear pending events also. + if proposal.Mutations.DropOp == pb.Mutations_ALL { + cdc.resetPendingEvents() + } else if proposal.Mutations.DropOp == pb.Mutations_DATA { + ns, err := strconv.ParseUint(proposal.Mutations.DropValue, 0, 64) + if err != nil { + glog.Warningf("CDC: parsing namespace failed with error %v. Ignoring.", err) + return + } + cdc.resetPendingEventsForNs(ns) + } + if err := sendToSink(events, proposal.Mutations.StartTs); err != nil { + rerr = errors.Wrapf(err, "unable to send messages to sink") + return + } + // If drop predicate, then mutation only succeeds if there were no pending txn + // This check ensures then event will only be send if there were no pending txns + case len(edges) == 1 && + edges[0].Entity == 0 && + bytes.Equal(edges[0].Value, []byte(x.Star)): + // If there are no pending txn send the events else + // return as the mutation must have errored out in that case. 
+ if !cdc.hasPending(x.ParseAttr(edges[0].Attr)) { + if err := sendToSink(events, proposal.Mutations.StartTs); err != nil { + rerr = errors.Wrapf(err, "unable to send messages to sink") + } + } + return + default: + cdc.addToPending(proposal.Mutations.StartTs, events) + } + } + + if proposal.Delta != nil { + for _, ts := range proposal.Delta.Txns { + // This ensures we dont send events again in case of membership changes. + if ts.CommitTs > 0 && atomic.LoadUint64(&cdc.sentTs) < ts.CommitTs { + events := cdc.pendingTxnEvents[ts.StartTs] + if err := sendToSink(events, ts.CommitTs); err != nil { + rerr = errors.Wrapf(err, "unable to send messages to sink") + return + } + } + // Delete from pending events once events are sent. + cdc.removeFromPending(ts.StartTs) + } + } + return + } + + // This will always run on leader node only. For default mode, Leader will + // check the Raft logs and keep in memory events that are pending. Once + // Txn is done, it will send events to sink, and update sentTs locally. + sendEvents := func() error { + first, err := groups().Node.Store.FirstIndex() + x.Check(err) + cdcIndex := x.Max(atomic.LoadUint64(&cdc.seenIndex)+1, first) + + last := groups().Node.Applied.DoneUntil() + if cdcIndex > last { + return nil + } + for batchFirst := cdcIndex; batchFirst <= last; { + entries, err := groups().Node.Store.Entries(batchFirst, last+1, 256<<20) + if err != nil { + return errors.Wrapf(err, + "CDC: failed to retrieve entries from Raft. 
Start: %d End: %d", + batchFirst, last+1) + } + if len(entries) == 0 { + return nil + } + batchFirst = entries[len(entries)-1].Index + 1 + for _, entry := range entries { + if err := handleEntry(entry); err != nil { + return errors.Wrapf(err, "CDC: unable to process raft entry") + } + } + } + return nil + } + + jobTick := time.NewTicker(time.Second) + proposalTick := time.NewTicker(3 * time.Minute) + defer cdc.closer.Done() + defer jobTick.Stop() + defer proposalTick.Stop() + var lastSent uint64 + for { + select { + case <-cdc.closer.HasBeenClosed(): + return + case <-jobTick.C: + if groups().Node.AmLeader() { + if err := sendEvents(); err != nil { + glog.Errorf("unable to send events %+v", err) + } + } + case <-proposalTick.C: + // The leader would propose the max sentTs over to the group. + // So, in case of a crash or a leadership change, the new leader + // would know where to send the cdc events from the Raft logs. + if groups().Node.AmLeader() { + sentTs := atomic.LoadUint64(&cdc.sentTs) + if lastSent == sentTs { + // No need to propose anything. 
+ continue + } + if err := groups().Node.proposeCDCState(atomic.LoadUint64(&cdc.sentTs)); err != nil { + glog.Errorf("unable to propose cdc state %+v", err) + } else { + lastSent = sentTs + } + } + } + } +} + +type CDCEvent struct { + Meta *EventMeta `json:"meta"` + Type string `json:"type"` + Event interface{} `json:"event"` +} + +type EventMeta struct { + RaftIndex uint64 `json:"-"` + Namespace []byte `json:"-"` + CommitTs uint64 `json:"commit_ts"` +} + +type MutationEvent struct { + Operation string `json:"operation"` + Uid uint64 `json:"uid"` + Attr string `json:"attr"` + Value interface{} `json:"value"` + ValueType string `json:"value_type"` +} + +type DropEvent struct { + Operation string `json:"operation"` + Type string `json:"type"` + Pred string `json:"pred"` +} + +const ( + EventTypeDrop = "drop" + EventTypeMutation = "mutation" + OpDropPred = "predicate" +) + +func toCDCEvent(index uint64, mutation *pb.Mutations) []CDCEvent { + // todo(Aman): we are skipping schema updates for now. Fix this later. + if len(mutation.Schema) > 0 || len(mutation.Types) > 0 { + return nil + } + + // If drop operation + if mutation.DropOp != pb.Mutations_NONE { + namespace := make([]byte, 8) + var t string + switch mutation.DropOp { + case pb.Mutations_ALL: + // Drop all is cluster wide. + binary.BigEndian.PutUint64(namespace, x.GalaxyNamespace) + case pb.Mutations_DATA: + ns, err := strconv.ParseUint(mutation.DropValue, 0, 64) + if err != nil { + glog.Warningf("CDC: parsing namespace failed with error %v. 
Ignoring.", err) + return nil + } + binary.BigEndian.PutUint64(namespace, ns) + case pb.Mutations_TYPE: + namespace, t = x.ParseNamespaceBytes(mutation.DropValue) + default: + glog.Error("CDC: got unhandled drop operation") + } + + return []CDCEvent{ + { + Type: EventTypeDrop, + Event: &DropEvent{ + Operation: strings.ToLower(mutation.DropOp.String()), + Type: t, + }, + Meta: &EventMeta{ + RaftIndex: index, + Namespace: namespace, + }, + }, + } + } + + cdcEvents := make([]CDCEvent, 0) + for _, edge := range mutation.Edges { + if x.IsReservedPredicate(edge.Attr) { + continue + } + ns, attr := x.ParseNamespaceBytes(edge.Attr) + // Handle drop attr event. + if edge.Entity == 0 && bytes.Equal(edge.Value, []byte(x.Star)) { + return []CDCEvent{ + { + Type: EventTypeDrop, + Event: &DropEvent{ + Operation: OpDropPred, + Pred: attr, + }, + Meta: &EventMeta{ + RaftIndex: index, + Namespace: ns, + }, + }, + } + } + + var val interface{} + switch { + case posting.TypeID(edge) == types.UidID: + val = edge.ValueId + case posting.TypeID(edge) == types.PasswordID: + val = "****" + default: + // convert to correct type + src := types.Val{Tid: types.BinaryID, Value: edge.Value} + if v, err := types.Convert(src, posting.TypeID(edge)); err == nil { + val = v.Value + } else { + glog.Errorf("error while converting value %v", err) + } + } + cdcEvents = append(cdcEvents, CDCEvent{ + Meta: &EventMeta{ + RaftIndex: index, + Namespace: ns, + }, + Type: EventTypeMutation, + Event: &MutationEvent{ + Operation: strings.ToLower(edge.Op.String()), + Uid: edge.Entity, + Attr: attr, + Value: val, + ValueType: posting.TypeID(edge).Name(), + }, + }) + } + + return cdcEvents } diff --git a/worker/cdc_ee.go b/worker/cdc_ee.go deleted file mode 100644 index bf700bafd38..00000000000 --- a/worker/cdc_ee.go +++ /dev/null @@ -1,494 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package worker - -import ( - "bytes" - "encoding/binary" - "encoding/json" - "math" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/golang/glog" - "github.com/pkg/errors" - "go.etcd.io/etcd/raft/v3/raftpb" - "google.golang.org/protobuf/proto" - - "github.com/dgraph-io/ristretto/v2/z" - "github.com/hypermodeinc/dgraph/v24/posting" - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/types" - "github.com/hypermodeinc/dgraph/v24/x" -) - -const ( - defaultEventTopic = "dgraph-cdc" -) - -// CDC struct is being used to send out change data capture events. There are two ways to do this: -// 1. Use Badger Subscribe. -// 2. Use Raft WAL. -// We chose to go with Raft WAL because in case we lose connection to the sink (say Kafka), we can -// resume from the last sent event and ensure there's continuity in event sending. Note the events -// would sent in the same order as they're being committed. -// With Badger Subscribe, if we lose the connection, we would have no way to send over the "missed" -// events. Even if we scan over Badger, we'd still not get those events in the right order, i.e. -// order of their commit timestamp. So, this approach would be tricky to get right. -type CDC struct { - sync.Mutex - sink Sink - closer *z.Closer - pendingTxnEvents map[uint64][]CDCEvent - - // dont use mutex, use atomic for the following. - - // seenIndex is the Raft index till which we have read the raft logs, and - // put the events in our pendingTxnEvents. This does NOT mean that we have - // sent them yet. - seenIndex uint64 - sentTs uint64 // max commit ts for which we have send the events. 
-} - -func newCDC() *CDC { - if Config.ChangeDataConf == "" || Config.ChangeDataConf == CDCDefaults { - return nil - } - - cdcFlag := z.NewSuperFlag(Config.ChangeDataConf).MergeAndCheckDefault(CDCDefaults) - sink, err := GetSink(cdcFlag) - x.Check(err) - cdc := &CDC{ - sink: sink, - closer: z.NewCloser(1), - pendingTxnEvents: make(map[uint64][]CDCEvent), - } - return cdc -} - -func (cdc *CDC) getSeenIndex() uint64 { - if cdc == nil { - return math.MaxUint64 - } - return atomic.LoadUint64(&cdc.seenIndex) -} - -func (cdc *CDC) getTs() uint64 { - if cdc == nil { - return math.MaxUint64 - } - cdc.Lock() - defer cdc.Unlock() - min := uint64(math.MaxUint64) - for startTs := range cdc.pendingTxnEvents { - min = x.Min(min, startTs) - } - return min -} - -func (cdc *CDC) resetPendingEvents() { - if cdc == nil { - return - } - cdc.Lock() - defer cdc.Unlock() - cdc.pendingTxnEvents = make(map[uint64][]CDCEvent) -} - -func (cdc *CDC) resetPendingEventsForNs(ns uint64) { - if cdc == nil { - return - } - cdc.Lock() - defer cdc.Unlock() - for ts, events := range cdc.pendingTxnEvents { - if len(events) > 0 && binary.BigEndian.Uint64(events[0].Meta.Namespace) == ns { - delete(cdc.pendingTxnEvents, ts) - } - } -} - -func (cdc *CDC) hasPending(attr string) bool { - if cdc == nil { - return false - } - cdc.Lock() - defer cdc.Unlock() - for _, events := range cdc.pendingTxnEvents { - for _, e := range events { - if me, ok := e.Event.(*MutationEvent); ok && me.Attr == attr { - return true - } - } - } - return false -} - -func (cdc *CDC) addToPending(ts uint64, events []CDCEvent) { - if cdc == nil { - return - } - cdc.Lock() - defer cdc.Unlock() - cdc.pendingTxnEvents[ts] = append(cdc.pendingTxnEvents[ts], events...) 
-} - -func (cdc *CDC) removeFromPending(ts uint64) { - if cdc == nil { - return - } - cdc.Lock() - defer cdc.Unlock() - delete(cdc.pendingTxnEvents, ts) -} - -func (cdc *CDC) updateSeenIndex(index uint64) { - if cdc == nil { - return - } - idx := atomic.LoadUint64(&cdc.seenIndex) - if idx >= index { - return - } - atomic.CompareAndSwapUint64(&cdc.seenIndex, idx, index) -} - -func (cdc *CDC) updateCDCState(state *pb.CDCState) { - if cdc == nil { - return - } - - // Dont try to update seen index in case of default mode else cdc job will not - // be able to build the complete pending txns in case of membership changes. - ts := atomic.LoadUint64(&cdc.sentTs) - if ts >= state.SentTs { - return - } - atomic.CompareAndSwapUint64(&cdc.sentTs, ts, state.SentTs) -} - -func (cdc *CDC) Close() { - if cdc == nil { - return - } - glog.Infof("closing CDC events...") - cdc.closer.SignalAndWait() - err := cdc.sink.Close() - glog.Errorf("error while closing sink %v", err) -} - -func (cdc *CDC) processCDCEvents() { - if cdc == nil { - return - } - - sendToSink := func(pending []CDCEvent, commitTs uint64) error { - batch := make([]SinkMessage, 0) - for _, e := range pending { - e.Meta.CommitTs = commitTs - b, err := json.Marshal(e) - if err != nil { - glog.Errorf("error while marshalling batch for event [%+v]: %v\n", e.Event, err) - continue - } - batch = append(batch, SinkMessage{ - Meta: SinkMeta{ - Topic: defaultEventTopic, - }, - Key: e.Meta.Namespace, - Value: b, - }) - } - if err := cdc.sink.Send(batch); err != nil { - glog.Errorf("error while sending cdc event to sink %+v", err) - return err - } - // We successfully sent messages to sink. - atomic.StoreUint64(&cdc.sentTs, commitTs) - return nil - } - - handleEntry := func(entry raftpb.Entry) (rerr error) { - defer func() { - // Irrespective of whether we act on this entry or not, we should - // always update the seenIndex. Otherwise, we'll loop over these - // entries over and over again. 
However, if we encounter an error, - // we should not update the index. - if rerr == nil { - cdc.updateSeenIndex(entry.Index) - } - }() - - if entry.Type != raftpb.EntryNormal || len(entry.Data) == 0 { - return - } - - var proposal pb.Proposal - if err := proto.Unmarshal(entry.Data[8:], &proposal); err != nil { - glog.Warningf("CDC: unmarshal failed with error %v. Ignoring.", err) - return - } - if proposal.Mutations != nil { - events := toCDCEvent(entry.Index, proposal.Mutations) - if len(events) == 0 { - return - } - edges := proposal.Mutations.Edges - switch { - case proposal.Mutations.DropOp != pb.Mutations_NONE: // this means its a drop operation - // if there is DROP ALL or DROP DATA operation, clear pending events also. - if proposal.Mutations.DropOp == pb.Mutations_ALL { - cdc.resetPendingEvents() - } else if proposal.Mutations.DropOp == pb.Mutations_DATA { - ns, err := strconv.ParseUint(proposal.Mutations.DropValue, 0, 64) - if err != nil { - glog.Warningf("CDC: parsing namespace failed with error %v. Ignoring.", err) - return - } - cdc.resetPendingEventsForNs(ns) - } - if err := sendToSink(events, proposal.Mutations.StartTs); err != nil { - rerr = errors.Wrapf(err, "unable to send messages to sink") - return - } - // If drop predicate, then mutation only succeeds if there were no pending txn - // This check ensures then event will only be send if there were no pending txns - case len(edges) == 1 && - edges[0].Entity == 0 && - bytes.Equal(edges[0].Value, []byte(x.Star)): - // If there are no pending txn send the events else - // return as the mutation must have errored out in that case. 
- if !cdc.hasPending(x.ParseAttr(edges[0].Attr)) { - if err := sendToSink(events, proposal.Mutations.StartTs); err != nil { - rerr = errors.Wrapf(err, "unable to send messages to sink") - } - } - return - default: - cdc.addToPending(proposal.Mutations.StartTs, events) - } - } - - if proposal.Delta != nil { - for _, ts := range proposal.Delta.Txns { - // This ensures we dont send events again in case of membership changes. - if ts.CommitTs > 0 && atomic.LoadUint64(&cdc.sentTs) < ts.CommitTs { - events := cdc.pendingTxnEvents[ts.StartTs] - if err := sendToSink(events, ts.CommitTs); err != nil { - rerr = errors.Wrapf(err, "unable to send messages to sink") - return - } - } - // Delete from pending events once events are sent. - cdc.removeFromPending(ts.StartTs) - } - } - return - } - - // This will always run on leader node only. For default mode, Leader will - // check the Raft logs and keep in memory events that are pending. Once - // Txn is done, it will send events to sink, and update sentTs locally. - sendEvents := func() error { - first, err := groups().Node.Store.FirstIndex() - x.Check(err) - cdcIndex := x.Max(atomic.LoadUint64(&cdc.seenIndex)+1, first) - - last := groups().Node.Applied.DoneUntil() - if cdcIndex > last { - return nil - } - for batchFirst := cdcIndex; batchFirst <= last; { - entries, err := groups().Node.Store.Entries(batchFirst, last+1, 256<<20) - if err != nil { - return errors.Wrapf(err, - "CDC: failed to retrieve entries from Raft. 
Start: %d End: %d", - batchFirst, last+1) - } - if len(entries) == 0 { - return nil - } - batchFirst = entries[len(entries)-1].Index + 1 - for _, entry := range entries { - if err := handleEntry(entry); err != nil { - return errors.Wrapf(err, "CDC: unable to process raft entry") - } - } - } - return nil - } - - jobTick := time.NewTicker(time.Second) - proposalTick := time.NewTicker(3 * time.Minute) - defer cdc.closer.Done() - defer jobTick.Stop() - defer proposalTick.Stop() - var lastSent uint64 - for { - select { - case <-cdc.closer.HasBeenClosed(): - return - case <-jobTick.C: - if groups().Node.AmLeader() { - if err := sendEvents(); err != nil { - glog.Errorf("unable to send events %+v", err) - } - } - case <-proposalTick.C: - // The leader would propose the max sentTs over to the group. - // So, in case of a crash or a leadership change, the new leader - // would know where to send the cdc events from the Raft logs. - if groups().Node.AmLeader() { - sentTs := atomic.LoadUint64(&cdc.sentTs) - if lastSent == sentTs { - // No need to propose anything. 
- continue - } - if err := groups().Node.proposeCDCState(atomic.LoadUint64(&cdc.sentTs)); err != nil { - glog.Errorf("unable to propose cdc state %+v", err) - } else { - lastSent = sentTs - } - } - } - } -} - -type CDCEvent struct { - Meta *EventMeta `json:"meta"` - Type string `json:"type"` - Event interface{} `json:"event"` -} - -type EventMeta struct { - RaftIndex uint64 `json:"-"` - Namespace []byte `json:"-"` - CommitTs uint64 `json:"commit_ts"` -} - -type MutationEvent struct { - Operation string `json:"operation"` - Uid uint64 `json:"uid"` - Attr string `json:"attr"` - Value interface{} `json:"value"` - ValueType string `json:"value_type"` -} - -type DropEvent struct { - Operation string `json:"operation"` - Type string `json:"type"` - Pred string `json:"pred"` -} - -const ( - EventTypeDrop = "drop" - EventTypeMutation = "mutation" - OpDropPred = "predicate" -) - -func toCDCEvent(index uint64, mutation *pb.Mutations) []CDCEvent { - // todo(Aman): we are skipping schema updates for now. Fix this later. - if len(mutation.Schema) > 0 || len(mutation.Types) > 0 { - return nil - } - - // If drop operation - if mutation.DropOp != pb.Mutations_NONE { - namespace := make([]byte, 8) - var t string - switch mutation.DropOp { - case pb.Mutations_ALL: - // Drop all is cluster wide. - binary.BigEndian.PutUint64(namespace, x.GalaxyNamespace) - case pb.Mutations_DATA: - ns, err := strconv.ParseUint(mutation.DropValue, 0, 64) - if err != nil { - glog.Warningf("CDC: parsing namespace failed with error %v. 
Ignoring.", err) - return nil - } - binary.BigEndian.PutUint64(namespace, ns) - case pb.Mutations_TYPE: - namespace, t = x.ParseNamespaceBytes(mutation.DropValue) - default: - glog.Error("CDC: got unhandled drop operation") - } - - return []CDCEvent{ - { - Type: EventTypeDrop, - Event: &DropEvent{ - Operation: strings.ToLower(mutation.DropOp.String()), - Type: t, - }, - Meta: &EventMeta{ - RaftIndex: index, - Namespace: namespace, - }, - }, - } - } - - cdcEvents := make([]CDCEvent, 0) - for _, edge := range mutation.Edges { - if x.IsReservedPredicate(edge.Attr) { - continue - } - ns, attr := x.ParseNamespaceBytes(edge.Attr) - // Handle drop attr event. - if edge.Entity == 0 && bytes.Equal(edge.Value, []byte(x.Star)) { - return []CDCEvent{ - { - Type: EventTypeDrop, - Event: &DropEvent{ - Operation: OpDropPred, - Pred: attr, - }, - Meta: &EventMeta{ - RaftIndex: index, - Namespace: ns, - }, - }, - } - } - - var val interface{} - switch { - case posting.TypeID(edge) == types.UidID: - val = edge.ValueId - case posting.TypeID(edge) == types.PasswordID: - val = "****" - default: - // convert to correct type - src := types.Val{Tid: types.BinaryID, Value: edge.Value} - if v, err := types.Convert(src, posting.TypeID(edge)); err == nil { - val = v.Value - } else { - glog.Errorf("error while converting value %v", err) - } - } - cdcEvents = append(cdcEvents, CDCEvent{ - Meta: &EventMeta{ - RaftIndex: index, - Namespace: ns, - }, - Type: EventTypeMutation, - Event: &MutationEvent{ - Operation: strings.ToLower(edge.Op.String()), - Uid: edge.Entity, - Attr: attr, - Value: val, - ValueType: posting.TypeID(edge).Name(), - }, - }) - } - - return cdcEvents -} diff --git a/worker/multi_tenancy.go b/worker/multi_tenancy.go index b1cbe5c5657..ffd1d423c41 100644 --- a/worker/multi_tenancy.go +++ b/worker/multi_tenancy.go @@ -1,6 +1,3 @@ -//go:build oss -// +build oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. 
 * SPDX-License-Identifier: Apache-2.0 @@ -10,20 +7,81 @@ package worker import ( "context" + "time" + + "github.com/golang/glog" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "github.com/hypermodeinc/dgraph/v24/conn" "github.com/hypermodeinc/dgraph/v24/protos/pb" "github.com/hypermodeinc/dgraph/v24/x" ) -func (w *grpcWorker) DeleteNamespace(ctx context.Context, - req *pb.DeleteNsRequest) (*pb.Status, error) { - return nil, x.ErrNotSupported +func (w *grpcWorker) DeleteNamespace(ctx context.Context, req *pb.DeleteNsRequest) (*pb.Status, error) { + var emptyRes pb.Status + if !groups().ServesGroup(req.GroupId) { + return &emptyRes, errors.Errorf("The server doesn't serve group id: %v", req.GroupId) + } + + if err := groups().Node.proposeAndWait(ctx, &pb.Proposal{DeleteNs: req}); err != nil { + return &emptyRes, errors.Wrapf(err, "Delete namespace failed for namespace %d on group %d", + req.Namespace, req.GroupId) + } + return &emptyRes, nil } func ProcessDeleteNsRequest(ctx context.Context, ns uint64) error { - return x.ErrNotSupported + // Update the membership state to get the latest mapping of groups to predicates. + if err := UpdateMembershipState(ctx); err != nil { + return errors.Wrapf(err, "Failed to update membership state while deleting namespace") + } + + state := GetMembershipState() + g := new(errgroup.Group) + + for gid := range state.Groups { + req := &pb.DeleteNsRequest{Namespace: ns, GroupId: gid} + g.Go(func() error { + return x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { + return proposeDeleteOrSend(ctx, req) + }) + }) + } + + if err := g.Wait(); err != nil { + return errors.Wrap(err, "Failed to process delete request") + } + + // Now propose the change to zero. 
+ return x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { + return sendDeleteToZero(ctx, ns) + }) +} + +func sendDeleteToZero(ctx context.Context, ns uint64) error { + gr := groups() + pl := gr.connToZeroLeader() + if pl == nil { + return conn.ErrNoConnection + } + zc := pb.NewZeroClient(pl.Get()) + _, err := zc.DeleteNamespace(gr.Ctx(), &pb.DeleteNsRequest{Namespace: ns}) + return err } func proposeDeleteOrSend(ctx context.Context, req *pb.DeleteNsRequest) error { - return nil + glog.V(2).Infof("Sending delete namespace request: %+v", req) + if groups().ServesGroup(req.GetGroupId()) && groups().Node.AmLeader() { + _, err := (&grpcWorker{}).DeleteNamespace(ctx, req) + return err + } + + pl := groups().Leader(req.GetGroupId()) + if pl == nil { + return conn.ErrNoConnection + } + c := pb.NewWorkerClient(pl.Get()) + _, err := c.DeleteNamespace(ctx, req) + return err } diff --git a/worker/multi_tenancy_ee.go b/worker/multi_tenancy_ee.go deleted file mode 100644 index 4e7cf4d32c2..00000000000 --- a/worker/multi_tenancy_ee.go +++ /dev/null @@ -1,89 +0,0 @@ -//go:build !oss -// +build !oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. 
- */ - -package worker - -import ( - "context" - "time" - - "github.com/golang/glog" - "github.com/pkg/errors" - "golang.org/x/sync/errgroup" - - "github.com/hypermodeinc/dgraph/v24/conn" - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/x" -) - -func (w *grpcWorker) DeleteNamespace(ctx context.Context, req *pb.DeleteNsRequest) (*pb.Status, error) { - var emptyRes pb.Status - if !groups().ServesGroup(req.GroupId) { - return &emptyRes, errors.Errorf("The server doesn't serve group id: %v", req.GroupId) - } - - if err := groups().Node.proposeAndWait(ctx, &pb.Proposal{DeleteNs: req}); err != nil { - return &emptyRes, errors.Wrapf(err, "Delete namespace failed for namespace %d on group %d", - req.Namespace, req.GroupId) - } - return &emptyRes, nil -} - -func ProcessDeleteNsRequest(ctx context.Context, ns uint64) error { - // Update the membership state to get the latest mapping of groups to predicates. - if err := UpdateMembershipState(ctx); err != nil { - return errors.Wrapf(err, "Failed to update membership state while deleting namesapce") - } - - state := GetMembershipState() - g := new(errgroup.Group) - - for gid := range state.Groups { - req := &pb.DeleteNsRequest{Namespace: ns, GroupId: gid} - g.Go(func() error { - return x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { - return proposeDeleteOrSend(ctx, req) - }) - }) - } - - if err := g.Wait(); err != nil { - return errors.Wrap(err, "Failed to process delete request") - } - - // Now propose the change to zero. 
- return x.RetryUntilSuccess(10, 100*time.Millisecond, func() error { - return sendDeleteToZero(ctx, ns) - }) -} - -func sendDeleteToZero(ctx context.Context, ns uint64) error { - gr := groups() - pl := gr.connToZeroLeader() - if pl == nil { - return conn.ErrNoConnection - } - zc := pb.NewZeroClient(pl.Get()) - _, err := zc.DeleteNamespace(gr.Ctx(), &pb.DeleteNsRequest{Namespace: ns}) - return err -} - -func proposeDeleteOrSend(ctx context.Context, req *pb.DeleteNsRequest) error { - glog.V(2).Infof("Sending delete namespace request: %+v", req) - if groups().ServesGroup(req.GetGroupId()) && groups().Node.AmLeader() { - _, err := (&grpcWorker{}).DeleteNamespace(ctx, req) - return err - } - - pl := groups().Leader(req.GetGroupId()) - if pl == nil { - return conn.ErrNoConnection - } - c := pb.NewWorkerClient(pl.Get()) - _, err := c.DeleteNamespace(ctx, req) - return err -} diff --git a/worker/online_restore.go b/worker/online_restore.go index ffebdb98bc0..10dae3658b0 100644 --- a/worker/online_restore.go +++ b/worker/online_restore.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/online_restore_oss.go b/worker/online_restore_oss.go deleted file mode 100644 index 416fcf31eab..00000000000 --- a/worker/online_restore_oss.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build oss -// +build oss - -/* - * SPDX-FileCopyrightText: © Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package worker - -import ( - "context" - "sync" - - "github.com/golang/glog" - - "github.com/hypermodeinc/dgraph/v24/protos/pb" - "github.com/hypermodeinc/dgraph/v24/x" -) - -func ProcessRestoreRequest(ctx context.Context, req *pb.RestoreRequest, wg *sync.WaitGroup) error { - glog.Warningf("Restore failed: %v", x.ErrNotSupported) - return x.ErrNotSupported -} - -// Restore implements the Worker interface. 
-func (w *grpcWorker) Restore(ctx context.Context, req *pb.RestoreRequest) (*pb.Status, error) { - glog.Warningf("Restore failed: %v", x.ErrNotSupported) - return &pb.Status{}, x.ErrNotSupported -} - -func handleRestoreProposal(ctx context.Context, req *pb.RestoreRequest, pidx uint64) error { - return nil -} diff --git a/worker/restore_map.go b/worker/restore_map.go index d2e277d13d9..211806fdf22 100644 --- a/worker/restore_map.go +++ b/worker/restore_map.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker diff --git a/worker/restore_reduce.go b/worker/restore_reduce.go index 67841244c0b..14d20ba9308 100644 --- a/worker/restore_reduce.go +++ b/worker/restore_reduce.go @@ -1,8 +1,6 @@ -//go:build !oss -// +build !oss - /* * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 */ package worker