diff --git a/services/register/all.go b/services/register/all.go index 8db6c5ca36d..e51e153cf5f 100644 --- a/services/register/all.go +++ b/services/register/all.go @@ -10,4 +10,5 @@ import ( _ "go.viam.com/rdk/services/shell/register" _ "go.viam.com/rdk/services/slam/register" _ "go.viam.com/rdk/services/vision/register" + _ "go.viam.com/rdk/services/worldstatestore/register" ) diff --git a/services/register_apis/all.go b/services/register_apis/all.go index 74cc41534b9..1a051f581bf 100644 --- a/services/register_apis/all.go +++ b/services/register_apis/all.go @@ -12,4 +12,5 @@ import ( _ "go.viam.com/rdk/services/navigation" _ "go.viam.com/rdk/services/shell" _ "go.viam.com/rdk/services/slam" + _ "go.viam.com/rdk/services/worldstatestore" ) diff --git a/services/worldstatestore/client.go b/services/worldstatestore/client.go new file mode 100644 index 00000000000..61e2e4f04ad --- /dev/null +++ b/services/worldstatestore/client.go @@ -0,0 +1,143 @@ +package worldstatestore + +import ( + "context" + "errors" + "io" + + "go.opencensus.io/trace" + commonPB "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + "go.viam.com/utils/protoutils" + "go.viam.com/utils/rpc" + + "go.viam.com/rdk/logging" + rprotoutils "go.viam.com/rdk/protoutils" + "go.viam.com/rdk/resource" +) + +type client struct { + resource.Named + resource.TriviallyReconfigurable + resource.TriviallyCloseable + name string + client pb.WorldStateStoreServiceClient + logger logging.Logger +} + +// NewClientFromConn constructs a new Client from the connection passed in. +func NewClientFromConn( + ctx context.Context, + conn rpc.ClientConn, + remoteName string, + name resource.Name, + logger logging.Logger, +) (Service, error) { + grpcClient := pb.NewWorldStateStoreServiceClient(conn) + c := &client{ + Named: name.PrependRemote(remoteName).AsNamed(), + name: name.ShortName(), + client: grpcClient, + logger: logger, + } + return c, nil +} + +// ListUUIDs lists all UUIDs in the world state store. +func (c *client) ListUUIDs(ctx context.Context, extra map[string]interface{}) ([][]byte, error) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::client::ListUUIDs") + defer span.End() + ext, err := protoutils.StructToStructPb(extra) + if err != nil { + return nil, err + } + + req := &pb.ListUUIDsRequest{Name: c.name, Extra: ext} + resp, err := c.client.ListUUIDs(ctx, req) + if err != nil { + return nil, err + } + uuids := resp.GetUuids() + if uuids == nil { + return nil, ErrNilResponse + } + + return uuids, nil +} + +// GetTransform gets the transform for a given UUID. +func (c *client) GetTransform(ctx context.Context, uuid []byte, extra map[string]interface{}) (*commonPB.Transform, error) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::client::GetTransform") + defer span.End() + ext, err := protoutils.StructToStructPb(extra) + if err != nil { + return nil, err + } + + req := &pb.GetTransformRequest{Name: c.name, Uuid: uuid, Extra: ext} + resp, err := c.client.GetTransform(ctx, req) + if err != nil { + return nil, err + } + obj := resp.GetTransform() + if obj == nil { + return nil, ErrNilResponse + } + + return obj, nil +} + +// StreamTransformChanges streams transform changes. +func (c *client) StreamTransformChanges(ctx context.Context, extra map[string]interface{}) (*TransformChangeStream, error) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::client::StreamTransformChanges") + defer span.End() + + ext, err := protoutils.StructToStructPb(extra) + if err != nil { + return nil, err + } + + req := &pb.StreamTransformChangesRequest{Name: c.name, Extra: ext} + stream, err := c.client.StreamTransformChanges(ctx, req) + if err != nil { + return nil, err + } + // Check the initial response immediately to catch early errors. + _, err = stream.Recv() + if err != nil { + return nil, err + } + + iter := &TransformChangeStream{ + next: func() (TransformChange, error) { + resp, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return TransformChange{}, io.EOF + } + if ctx.Err() != nil || errors.Is(err, context.Canceled) { + return TransformChange{}, ctx.Err() + } + return TransformChange{}, err + } + change := TransformChange{ + ChangeType: resp.ChangeType, + Transform: resp.Transform, + } + if resp.UpdatedFields != nil { + change.UpdatedFields = resp.UpdatedFields.Paths + } + return change, nil + }, + } + + return iter, nil +} + +// DoCommand handles arbitrary commands. +func (c *client) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::client::DoCommand") + defer span.End() + + return rprotoutils.DoFromResourceClient(ctx, c.client, c.name, cmd) +} diff --git a/services/worldstatestore/client_test.go b/services/worldstatestore/client_test.go new file mode 100644 index 00000000000..7f23e495c55 --- /dev/null +++ b/services/worldstatestore/client_test.go @@ -0,0 +1,281 @@ +package worldstatestore_test + +import ( + "context" + "net" + "testing" + + "github.com/pkg/errors" + commonpb "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + "go.viam.com/test" + "go.viam.com/utils/rpc" + + viamgrpc "go.viam.com/rdk/grpc" + "go.viam.com/rdk/logging" + "go.viam.com/rdk/resource" + _ "go.viam.com/rdk/services/register" + "go.viam.com/rdk/services/worldstatestore" + "go.viam.com/rdk/testutils/inject" +) + +func TestClient(t *testing.T) { + logger := logging.NewTestLogger(t) + listener1, err := net.Listen("tcp", "localhost:0") + test.That(t, err, test.ShouldBeNil) + rpcServer, err := rpc.NewServer(logger, rpc.WithUnauthenticated()) + test.That(t, err, test.ShouldBeNil) + + srv := &inject.WorldStateStoreService{} + srv.ListUUIDsFunc = func(ctx context.Context, extra map[string]any) ([][]byte, error) { + return [][]byte{[]byte("uuid1"), []byte("uuid2")}, nil + } + srv.GetTransformFunc = func(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) { + return &commonpb.Transform{ + ReferenceFrame: "test-frame", + Uuid: uuid, + }, nil + } + srv.StreamTransformChangesFunc = func(ctx context.Context, extra map[string]any) (*worldstatestore.TransformChangeStream, error) { + changesChan := make(chan worldstatestore.TransformChange, 1) + changesChan <- worldstatestore.TransformChange{ + ChangeType: pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_ADDED, + Transform: &commonpb.Transform{ + ReferenceFrame: "test-frame", + Uuid: []byte("test-uuid"), + }, + } + close(changesChan) + return worldstatestore.NewTransformChangeStreamFromChannel(ctx, changesChan), nil + } + srv.DoFunc = func(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + return cmd, nil + } + + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): srv, + } + svc, err := resource.NewAPIResourceCollection(worldstatestore.API, m) + test.That(t, err, test.ShouldBeNil) + resourceAPI, ok, err := resource.LookupAPIRegistration[worldstatestore.Service](worldstatestore.API) + test.That(t, err, test.ShouldBeNil) + test.That(t, ok, test.ShouldBeTrue) + test.That(t, resourceAPI.RegisterRPCService(context.Background(), rpcServer, svc), test.ShouldBeNil) + + go rpcServer.Serve(listener1) + defer rpcServer.Stop() + + t.Run("Failing client", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = viamgrpc.Dial(cancelCtx, listener1.Addr().String(), logger) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "canceled") + }) + + t.Run("ListUUIDs", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + uuids, err := client.ListUUIDs(context.Background(), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, len(uuids), test.ShouldEqual, 2) + test.That(t, uuids[0], test.ShouldResemble, []byte("uuid1")) + test.That(t, uuids[1], test.ShouldResemble, []byte("uuid2")) + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("GetTransform", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + transform, err := client.GetTransform(context.Background(), []byte("test-uuid"), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, transform.ReferenceFrame, test.ShouldEqual, "test-frame") + test.That(t, transform.Uuid, test.ShouldResemble, []byte("test-uuid")) + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("GetTransform returns ErrNilResponse when not found", func(t *testing.T) { + srv.GetTransformFunc = func(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) { + return nil, nil + } + + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + obj, err := client.GetTransform(context.Background(), []byte("missing-uuid"), nil) + test.That(t, err, test.ShouldEqual, worldstatestore.ErrNilResponse) + test.That(t, obj, test.ShouldBeNil) + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("StreamTransformChanges", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + stream, err := client.StreamTransformChanges(context.Background(), nil) + test.That(t, err, test.ShouldBeNil) + + change, err := stream.Next() + test.That(t, err, test.ShouldBeNil) + test.That(t, change.ChangeType, test.ShouldEqual, pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_ADDED) + test.That(t, change.Transform.ReferenceFrame, test.ShouldEqual, "test-frame") + test.That(t, change.Transform.Uuid, test.ShouldResemble, []byte("test-uuid")) + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("DoCommand", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + cmd := map[string]interface{}{"test": "command"} + resp, err := client.DoCommand(context.Background(), cmd) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp, test.ShouldResemble, cmd) + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) +} + +func TestClientFailures(t *testing.T) { + logger := logging.NewTestLogger(t) + listener1, err := net.Listen("tcp", "localhost:0") + test.That(t, err, test.ShouldBeNil) + rpcServer, err := rpc.NewServer(logger, rpc.WithUnauthenticated()) + test.That(t, err, test.ShouldBeNil) + + srv := &inject.WorldStateStoreService{} + expectedErr := errors.New("fake error") + srv.ListUUIDsFunc = func(ctx context.Context, extra map[string]any) ([][]byte, error) { + return nil, expectedErr + } + srv.GetTransformFunc = func(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) { + return nil, expectedErr + } + srv.StreamTransformChangesFunc = func(ctx context.Context, extra map[string]any) (*worldstatestore.TransformChangeStream, error) { + return nil, expectedErr + } + + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): srv, + } + svc, err := resource.NewAPIResourceCollection(worldstatestore.API, m) + test.That(t, err, test.ShouldBeNil) + resourceAPI, ok, err := resource.LookupAPIRegistration[worldstatestore.Service](worldstatestore.API) + test.That(t, err, test.ShouldBeNil) + test.That(t, ok, test.ShouldBeTrue) + test.That(t, resourceAPI.RegisterRPCService(context.Background(), rpcServer, svc), test.ShouldBeNil) + + go rpcServer.Serve(listener1) + defer rpcServer.Stop() + + t.Run("ListUUIDs with error", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + _, err = client.ListUUIDs(context.Background(), nil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "fake error") + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("GetTransform with error", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + _, err = client.GetTransform(context.Background(), []byte("test-uuid"), nil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "fake error") + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) + + t.Run("StreamTransformChanges with error", func(t *testing.T) { + conn, err := viamgrpc.Dial(context.Background(), listener1.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := worldstatestore.NewClientFromConn( + context.Background(), + conn, + "", + worldstatestore.Named(testWorldStateStoreServiceName), + logger, + ) + test.That(t, err, test.ShouldBeNil) + + _, err = client.StreamTransformChanges(context.Background(), nil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "fake error") + + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) + }) +} diff --git a/services/worldstatestore/fake/fake.go b/services/worldstatestore/fake/fake.go new file mode 100644 index 00000000000..1224b3aed64 --- /dev/null +++ b/services/worldstatestore/fake/fake.go @@ -0,0 +1,528 @@ +// Package fake provides a fake implementation of the worldstatestore.Service interface. +package fake + +import ( + "context" + "errors" + "fmt" + "math" + "strings" + "sync" + "time" + + commonpb "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + "google.golang.org/protobuf/types/known/structpb" + + "go.viam.com/rdk/logging" + "go.viam.com/rdk/resource" + "go.viam.com/rdk/services/worldstatestore" +) + +// WorldStateStore implements the worldstatestore.Service interface. +type WorldStateStore struct { + resource.Named + resource.TriviallyReconfigurable + resource.TriviallyCloseable + mu sync.RWMutex + + transforms map[string]*commonpb.Transform + fps float64 + + startTime time.Time + activeBackgroundWorkers sync.WaitGroup + + changeChan chan worldstatestore.TransformChange + streamCtx context.Context + cancel context.CancelFunc + + logger logging.Logger +} + +func init() { + resource.RegisterService( + worldstatestore.API, + resource.DefaultModelFamily.WithModel("fake"), + resource.Registration[worldstatestore.Service, resource.NoNativeConfig]{Constructor: func( + ctx context.Context, + deps resource.Dependencies, + conf resource.Config, + logger logging.Logger, + ) (worldstatestore.Service, error) { + return newFakeWorldStateStore(conf.ResourceName(), logger), nil + }}) +} + +// ListUUIDs returns all transform UUIDs currently in the store. +func (f *WorldStateStore) ListUUIDs(ctx context.Context, extra map[string]any) ([][]byte, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + uuids := make([][]byte, 0, len(f.transforms)) + for _, transform := range f.transforms { + uuids = append(uuids, transform.Uuid) + } + + return uuids, nil +} + +// GetTransform returns the transform for the given UUID. +func (f *WorldStateStore) GetTransform(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + transform, exists := f.transforms[string(uuid)] + if !exists { + return nil, errors.New("transform not found") + } + + return transform, nil +} + +// StreamTransformChanges returns a channel of transform changes. +func (f *WorldStateStore) StreamTransformChanges( + ctx context.Context, + extra map[string]any, +) (*worldstatestore.TransformChangeStream, error) { + return worldstatestore.NewTransformChangeStreamFromChannel(ctx, f.changeChan), nil +} + +// DoCommand handles arbitrary commands. Currently accepts "fps": float64 to set the animation rate. +func (f *WorldStateStore) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + if fps, ok := cmd["fps"].(float64); ok { + if fps <= 0 { + return nil, errors.New("fps must be greater than 0") + } + f.mu.Lock() + f.fps = float64(fps) + f.mu.Unlock() + return map[string]any{ + "status": "fps set to " + fmt.Sprintf("%.2f", fps), + }, nil + } + + return map[string]any{ + "status": "command not implemented", + }, nil +} + +// Close stops the fake service and cleans up resources. +func (f *WorldStateStore) Close(ctx context.Context) error { + f.cancel() + + done := make(chan struct{}) + go func() { + f.activeBackgroundWorkers.Wait() + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + // proceed even if workers did not exit in time + } + + close(f.changeChan) + return nil +} + +func newFakeWorldStateStore(name resource.Name, logger logging.Logger) worldstatestore.Service { + ctx, cancel := context.WithCancel(context.Background()) + + fake := &WorldStateStore{ + Named: name.AsNamed(), + TriviallyReconfigurable: resource.TriviallyReconfigurable{}, + TriviallyCloseable: resource.TriviallyCloseable{}, + transforms: make(map[string]*commonpb.Transform), + fps: 10, + startTime: time.Now(), + changeChan: make(chan worldstatestore.TransformChange, 100), + streamCtx: ctx, + cancel: cancel, + logger: logger, + } + + fake.initializeStaticTransforms() + fake.activeBackgroundWorkers.Add(1) + go func() { + defer fake.activeBackgroundWorkers.Done() + fake.animationLoop() + }() + fake.activeBackgroundWorkers.Add(1) + go func() { + defer fake.activeBackgroundWorkers.Done() + fake.dynamicBoxSequence() + }() + + return fake +} + +// initializeStaticTransforms creates the initial three transforms in the world. +func (f *WorldStateStore) initializeStaticTransforms() { + f.mu.Lock() + defer f.mu.Unlock() + + // Create initial transforms + boxUUID := "box-001" + sphereUUID := "sphere-001" + capsuleUUID := "capsule-001" + + boxMetadata, err := structpb.NewStruct(map[string]any{ + "color": map[string]any{ + "r": 255, + "g": 0, + "b": 0, + }, + "opacity": 0.3, + }) + if err != nil { + panic(err) + } + + sphereMetadata, err := structpb.NewStruct(map[string]any{ + "color": map[string]any{ + "r": 0, + "g": 0, + "b": 255, + }, + "opacity": 0.7, + }) + if err != nil { + panic(err) + } + + capsuleMetadata, err := structpb.NewStruct(map[string]any{ + "color": map[string]any{ + "r": 0, + "g": 255, + "b": 0, + }, + "opacity": 1.0, + }) + if err != nil { + panic(err) + } + + f.transforms[boxUUID] = &commonpb.Transform{ + ReferenceFrame: "static-box", + PoseInObserverFrame: &commonpb.PoseInFrame{ + ReferenceFrame: "world", + Pose: &commonpb.Pose{ + X: -5, Y: 0, Z: 0, Theta: 0, OX: 0, OY: 0, OZ: 1, + }, + }, + PhysicalObject: &commonpb.Geometry{ + GeometryType: &commonpb.Geometry_Box{ + Box: &commonpb.RectangularPrism{ + DimsMm: &commonpb.Vector3{ + X: 100, + Y: 100, + Z: 100, + }, + }, + }, + }, + Uuid: []byte(boxUUID), + Metadata: boxMetadata, + } + + f.transforms[sphereUUID] = &commonpb.Transform{ + ReferenceFrame: "static-sphere", + PoseInObserverFrame: &commonpb.PoseInFrame{ + ReferenceFrame: "world", + Pose: &commonpb.Pose{ + X: 0, Y: 0, Z: 0, Theta: 0, OX: 0, OY: 0, OZ: 1, + }, + }, + PhysicalObject: &commonpb.Geometry{ + GeometryType: &commonpb.Geometry_Sphere{ + Sphere: &commonpb.Sphere{ + RadiusMm: 100, + }, + }, + }, + Uuid: []byte(sphereUUID), + Metadata: sphereMetadata, + } + + f.transforms[capsuleUUID] = &commonpb.Transform{ + ReferenceFrame: "static-capsule", + PoseInObserverFrame: &commonpb.PoseInFrame{ + ReferenceFrame: "world", + Pose: &commonpb.Pose{ + X: 5, Y: 0, Z: 0, Theta: 0, OX: 0, OY: 0, OZ: 1, + }, + }, + PhysicalObject: &commonpb.Geometry{ + GeometryType: &commonpb.Geometry_Capsule{ + Capsule: &commonpb.Capsule{ + RadiusMm: 100, + LengthMm: 100, + }, + }, + }, + Uuid: []byte(capsuleUUID), + Metadata: capsuleMetadata, + } +} + +func (f *WorldStateStore) updateBoxTransform(elapsed time.Duration) { + rotationSpeed := 2 * math.Pi / 5.0 // radians per second + angle := rotationSpeed * elapsed.Seconds() + + f.mu.Lock() + if transform, exists := f.transforms["box-001"]; exists { + transform.PoseInObserverFrame.Pose.Theta = angle * 180 / math.Pi + uuid := transform.Uuid + f.mu.Unlock() + partial := &commonpb.Transform{ + Uuid: uuid, + PoseInObserverFrame: &commonpb.PoseInFrame{ + Pose: &commonpb.Pose{Theta: angle * 180 / math.Pi}, + }, + } + f.emitTransformUpdate(partial, []string{"poseInObserverFrame.pose.theta"}) + return + } + f.mu.Unlock() +} + +func (f *WorldStateStore) updateSphereTransform(elapsed time.Duration) { + frequency := 2 * math.Pi / 5.0 // radians per second + height := math.Sin(frequency*elapsed.Seconds()) * 2.0 // ±2 units + + f.mu.Lock() + if transform, exists := f.transforms["sphere-001"]; exists { + transform.PoseInObserverFrame.Pose.Y = height + uuid := transform.Uuid + f.mu.Unlock() + partial := &commonpb.Transform{ + Uuid: uuid, + PoseInObserverFrame: &commonpb.PoseInFrame{ + Pose: &commonpb.Pose{Y: height}, + }, + } + f.emitTransformUpdate(partial, []string{"poseInObserverFrame.pose.y"}) + return + } + f.mu.Unlock() +} + +func (f *WorldStateStore) updateCapsuleTransform(elapsed time.Duration) { + frequency := 2 * math.Pi / 5.0 // radians per second + scale := 1.0 + 0.5*math.Sin(frequency*elapsed.Seconds()) // 0.5x to 1.5x + r := 100 * scale + l := 100 * scale + + f.mu.Lock() + if transform, exists := f.transforms["capsule-001"]; exists { + transform.PhysicalObject.GetCapsule().RadiusMm = r + transform.PhysicalObject.GetCapsule().LengthMm = l + uuid := transform.Uuid + f.mu.Unlock() + partial := &commonpb.Transform{ + Uuid: uuid, + PhysicalObject: &commonpb.Geometry{ + GeometryType: &commonpb.Geometry_Capsule{ + Capsule: &commonpb.Capsule{RadiusMm: r, LengthMm: l}, + }, + }, + } + f.emitTransformUpdate(partial, []string{"physicalObject.geometryType.value.radiusMm", "physicalObject.geometryType.value.lengthMm"}) + return + } + f.mu.Unlock() +} + +func (f *WorldStateStore) emitTransformChange(uuid string, changeType pb.TransformChangeType, updatedFields []string) { + var transformCopy *commonpb.Transform + + f.mu.RLock() + if transform, exists := f.transforms[uuid]; exists { + transformCopy = &commonpb.Transform{ + ReferenceFrame: transform.ReferenceFrame, + PoseInObserverFrame: transform.PoseInObserverFrame, + Uuid: transform.Uuid, + } + } + f.mu.RUnlock() + + if transformCopy == nil { + return + } + + change := worldstatestore.TransformChange{ + ChangeType: changeType, + Transform: transformCopy, + UpdatedFields: updatedFields, + } + + select { + case f.changeChan <- change: + case <-f.streamCtx.Done(): + default: + // Channel is full, skip this update + } +} + +// emitTransformUpdate emits a change with a partial transform payload for UPDATE events. +func (f *WorldStateStore) emitTransformUpdate(partial *commonpb.Transform, updatedFields []string) { + if partial == nil || len(partial.GetUuid()) == 0 { + return + } + change := worldstatestore.TransformChange{ + ChangeType: pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_UPDATED, + Transform: partial, + UpdatedFields: updatedFields, + } + select { + case f.changeChan <- change: + case <-f.streamCtx.Done(): + default: + // Channel is full, skip this update + } +} + +func (f *WorldStateStore) animationLoop() { + f.mu.RLock() + curFPS := f.fps + f.mu.RUnlock() + if curFPS <= 0 { + curFPS = 1 + } + interval := time.Duration(float64(time.Second) / curFPS) + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-f.streamCtx.Done(): + return + case <-ticker.C: + f.updateTransforms() + // Reconfigure ticker if FPS changed + f.mu.RLock() + newFPS := f.fps + f.mu.RUnlock() + if newFPS != curFPS && newFPS > 0 { + ticker.Stop() + curFPS = newFPS + interval = time.Duration(float64(time.Second) / curFPS) + ticker = time.NewTicker(interval) + } + } + } +} + +func (f *WorldStateStore) dynamicBoxSequence() { + sequence := []struct { + action string + name string + delay time.Duration + }{ + {"add", "box-front-box", 3 * time.Second}, + {"remove", "box-front-box", 0}, + {"add", "box-front-sphere", 3 * time.Second}, + {"remove", "box-front-sphere", 0}, + {"add", "box-front-capsule", 3 * time.Second}, + {"remove", "box-front-capsule", 0}, + } + + for { + for _, step := range sequence { + select { + case <-f.streamCtx.Done(): + return + default: + } + + switch step.action { + case "add": + f.addDynamicBox(step.name) + case "remove": + f.removeDynamicBox(step.name) + } + + if step.delay > 0 { + select { + case <-f.streamCtx.Done(): + return + case <-time.After(step.delay): + } + } + } + } +} + +func (f *WorldStateStore) addDynamicBox(name string) { + var xOffset float64 + switch name { + case "box-front-box": + xOffset = -5 - 2 // In front of the main box + case "box-front-sphere": + xOffset = 0 - 2 // In front of the sphere + case "box-front-capsule": + xOffset = 5 - 2 // In front of the capsule + } + + uuid := name + "-" + time.Now().Format("20060102150405") + transform := &commonpb.Transform{ + ReferenceFrame: "dynamic-box", + PoseInObserverFrame: &commonpb.PoseInFrame{ + ReferenceFrame: "world", + Pose: &commonpb.Pose{ + X: xOffset, Y: 0, Z: 2, Theta: 0, OX: 0, OY: 0, OZ: 1, + }, + }, + Uuid: []byte(uuid), + } + + f.mu.Lock() + f.transforms[uuid] = transform + f.mu.Unlock() + + f.emitTransformChange(uuid, pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_ADDED, nil) +} + +func (f *WorldStateStore) removeDynamicBox(name string) { + f.mu.Lock() + + var uuidToRemove string + for uuid := range f.transforms { + if strings.HasPrefix(uuid, name) { + uuidToRemove = uuid + break + } + } + + if uuidToRemove == "" { + f.mu.Unlock() + return + } + + transform := f.transforms[uuidToRemove] + delete(f.transforms, uuidToRemove) + f.mu.Unlock() + + change := worldstatestore.TransformChange{ + ChangeType: pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_REMOVED, + Transform: &commonpb.Transform{ + Uuid: transform.Uuid, + }, + } + + select { + case f.changeChan <- change: + case <-f.streamCtx.Done(): + default: + // Channel is full, skip this update + } +} + +func (f *WorldStateStore) updateTransforms() { + elapsed := time.Since(f.startTime) + + f.updateBoxTransform(elapsed) + f.updateSphereTransform(elapsed) + f.updateCapsuleTransform(elapsed) +} diff --git a/services/worldstatestore/fake/fake_test.go b/services/worldstatestore/fake/fake_test.go new file mode 100644 index 00000000000..60ce16a3c3d --- /dev/null +++ b/services/worldstatestore/fake/fake_test.go @@ -0,0 +1,140 @@ +package fake + +import ( + "context" + "testing" + "time" + + "go.viam.com/test" + + "go.viam.com/rdk/logging" + "go.viam.com/rdk/resource" + "go.viam.com/rdk/services/worldstatestore" +) + +func TestFakeWorldStateStore(t *testing.T) { + // Create a new fake service + fake := newFakeWorldStateStore(resource.Name{Name: "test"}, nil) + defer fake.Close(context.Background()) + + // Test ListUUIDs + uuids, err := fake.ListUUIDs(context.Background(), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, len(uuids), test.ShouldEqual, 3) // box, sphere, capsule + + // Test GetTransform for each static transform + boxTransform, err := fake.GetTransform(context.Background(), []byte("box-001"), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, boxTransform, test.ShouldNotBeNil) + test.That(t, boxTransform.Uuid, test.ShouldResemble, []byte("box-001")) + test.That(t, boxTransform.Metadata, test.ShouldNotBeNil) + + // Test color metadata - should be a structpb.Value containing a StructValue + colorField := boxTransform.Metadata.Fields["color"] + test.That(t, colorField, test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue(), test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue().Fields["r"].GetNumberValue(), test.ShouldEqual, 255) + test.That(t, colorField.GetStructValue().Fields["g"].GetNumberValue(), test.ShouldEqual, 0) + test.That(t, colorField.GetStructValue().Fields["b"].GetNumberValue(), test.ShouldEqual, 0) + + test.That(t, boxTransform.Metadata.Fields["opacity"].GetNumberValue(), test.ShouldEqual, 0.3) + + sphereTransform, err := fake.GetTransform(context.Background(), []byte("sphere-001"), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, sphereTransform, test.ShouldNotBeNil) + test.That(t, sphereTransform.Uuid, test.ShouldResemble, []byte("sphere-001")) + test.That(t, sphereTransform.Metadata, test.ShouldNotBeNil) + + // Test color metadata for sphere + colorField = sphereTransform.Metadata.Fields["color"] + test.That(t, colorField, test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue(), test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue().Fields["r"].GetNumberValue(), test.ShouldEqual, 0) + test.That(t, colorField.GetStructValue().Fields["g"].GetNumberValue(), test.ShouldEqual, 0) + test.That(t, colorField.GetStructValue().Fields["b"].GetNumberValue(), test.ShouldEqual, 255) + + test.That(t, sphereTransform.Metadata.Fields["opacity"].GetNumberValue(), test.ShouldEqual, 0.7) + + capsuleTransform, err := fake.GetTransform(context.Background(), []byte("capsule-001"), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, capsuleTransform, test.ShouldNotBeNil) + test.That(t, capsuleTransform.Uuid, test.ShouldResemble, []byte("capsule-001")) + test.That(t, capsuleTransform.Metadata, test.ShouldNotBeNil) + + // Test color metadata for capsule + colorField = capsuleTransform.Metadata.Fields["color"] + test.That(t, colorField, test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue(), test.ShouldNotBeNil) + test.That(t, colorField.GetStructValue().Fields["r"].GetNumberValue(), test.ShouldEqual, 0) + test.That(t, colorField.GetStructValue().Fields["g"].GetNumberValue(), test.ShouldEqual, 255) + test.That(t, colorField.GetStructValue().Fields["b"].GetNumberValue(), test.ShouldEqual, 0) + + test.That(t, capsuleTransform.Metadata.Fields["opacity"].GetNumberValue(), test.ShouldEqual, 1.0) + + // Test StreamTransformChanges + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + stream, err := fake.StreamTransformChanges(ctx, nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, stream, test.ShouldNotBeNil) + + // Wait a bit for some changes to occur + time.Sleep(200 * time.Millisecond) + + // Check that we've received some changes + changeCount := 0 + if _, err := stream.Next(); err == nil { + changeCount++ + } + + // We should have at least some changes after 200ms + test.That(t, changeCount, test.ShouldBeGreaterThanOrEqualTo, 0) +} + +func TestFakeWorldStateStoreClose(t *testing.T) { + fake := newFakeWorldStateStore(resource.Name{Name: "test"}, nil) + + // Test that Close works + err := fake.Close(context.Background()) + test.That(t, err, test.ShouldBeNil) + + // Test that ListUUIDs still works after close (should return empty) + uuids, err := fake.ListUUIDs(context.Background(), nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, len(uuids), test.ShouldBeGreaterThanOrEqualTo, 3) // Static transforms are still available + test.That(t, len(uuids), test.ShouldBeLessThanOrEqualTo, 4) // Dynamic transform may be available +} + +func TestDoCommandSetFPS(t *testing.T) { + logger := logging.NewTestLogger(t) + name := resource.NewName(worldstatestore.API, "fake1") + svc := newFakeWorldStateStore(name, logger) + wss := svc.(*WorldStateStore) + defer func() { _ = wss.Close(context.Background()) }() + + test.That(t, wss.fps, test.ShouldEqual, 10) + + // set fps via DoCommand + resp, err := wss.DoCommand(context.Background(), map[string]any{"fps": float64(20)}) + test.That(t, err, test.ShouldBeNil) + test.That(t, wss.fps, test.ShouldEqual, 20) + test.That(t, resp["status"], test.ShouldEqual, "fps set to 20.00") + + // attempt to set invalid fps + _, err = wss.DoCommand(context.Background(), map[string]any{"fps": float64(0)}) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "fps must be greater than 0") + test.That(t, wss.fps, test.ShouldEqual, 20) +} + +func TestDoCommandUnknownCommand(t *testing.T) { + logger := logging.NewTestLogger(t) + name := resource.NewName(worldstatestore.API, "fake3") + svc := newFakeWorldStateStore(name, logger) + wss := svc.(*WorldStateStore) + defer func() { _ = wss.Close(context.Background()) }() + + resp, err := wss.DoCommand(context.Background(), map[string]any{"noop": true}) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp["status"], test.ShouldEqual, "command not implemented") +} diff --git a/services/worldstatestore/register/register.go b/services/worldstatestore/register/register.go new file mode 100644 index 00000000000..b0e3a54bdf1 --- /dev/null +++ b/services/worldstatestore/register/register.go @@ -0,0 +1,7 @@ +// Package register registers all relevant world object store models and also API specific functions +package register + +import ( + // for world state store models. + _ "go.viam.com/rdk/services/worldstatestore/fake" +) diff --git a/services/worldstatestore/server.go b/services/worldstatestore/server.go new file mode 100644 index 00000000000..e23fd072c22 --- /dev/null +++ b/services/worldstatestore/server.go @@ -0,0 +1,138 @@ +package worldstatestore + +import ( + "context" + "errors" + "io" + + "go.opencensus.io/trace" + commonpb "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + "go.viam.com/rdk/protoutils" + "go.viam.com/rdk/resource" +) + +type serviceServer struct { + pb.UnimplementedWorldStateStoreServiceServer + coll resource.APIResourceCollection[Service] +} + +// NewRPCServiceServer constructs a the world state store gRPC service server. +func NewRPCServiceServer(coll resource.APIResourceCollection[Service]) interface{} { + return &serviceServer{coll: coll} +} + +// ListUUIDs returns a list of world state uuids. +func (server *serviceServer) ListUUIDs(ctx context.Context, req *pb.ListUUIDsRequest) ( + *pb.ListUUIDsResponse, error, +) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::server::ListUUIDs") + defer span.End() + + svc, err := server.coll.Resource(req.Name) + if err != nil { + return nil, err + } + + uuids, err := svc.ListUUIDs(ctx, req.Extra.AsMap()) + if err != nil { + return nil, err + } + if uuids == nil { + return nil, ErrNilResponse + } + + return &pb.ListUUIDsResponse{Uuids: uuids}, nil +} + +// GetTransform returns a world state object by uuid. +func (server *serviceServer) GetTransform(ctx context.Context, req *pb.GetTransformRequest) ( + *pb.GetTransformResponse, error, +) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::server::GetTransform") + defer span.End() + + svc, err := server.coll.Resource(req.Name) + if err != nil { + return nil, err + } + + obj, err := svc.GetTransform(ctx, req.Uuid, req.Extra.AsMap()) + if err != nil { + return nil, err + } + if obj == nil { + return &pb.GetTransformResponse{}, nil + } + + return &pb.GetTransformResponse{Transform: obj}, nil +} + +// DoCommand receives arbitrary commands. +func (server *serviceServer) DoCommand(ctx context.Context, + req *commonpb.DoCommandRequest, +) (*commonpb.DoCommandResponse, error) { + ctx, span := trace.StartSpan(ctx, "worldstatestore::server::DoCommand") + defer span.End() + + svc, err := server.coll.Resource(req.Name) + if err != nil { + return nil, err + } + return protoutils.DoFromResourceServer(ctx, svc, req) +} + +// StreamTransformChanges streams changes to world state transforms to the client. +func (server *serviceServer) StreamTransformChanges( + req *pb.StreamTransformChangesRequest, + stream pb.WorldStateStoreService_StreamTransformChangesServer, +) error { + ctx, span := trace.StartSpan(stream.Context(), "worldstatestore::server::StreamTransformChanges") + defer span.End() + + svc, err := server.coll.Resource(req.Name) + if err != nil { + return err + } + + changesStream, err := svc.StreamTransformChanges(ctx, req.Extra.AsMap()) + if err != nil { + return err + } + + // Send an empty response first so the client doesn't block while checking for errors. + err = stream.Send(&pb.StreamTransformChangesResponse{}) + if err != nil { + return err + } + + for { + change, err := changesStream.Next() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + // Convert the internal TransformChange to protobuf response + resp := &pb.StreamTransformChangesResponse{ + ChangeType: change.ChangeType, + Transform: change.Transform, + } + + // Convert UpdatedFields to FieldMask if present + if len(change.UpdatedFields) > 0 { + fieldMask := &fieldmaskpb.FieldMask{ + Paths: change.UpdatedFields, + } + resp.UpdatedFields = fieldMask + } + + if err := stream.Send(resp); err != nil { + return err + } + } +} diff --git a/services/worldstatestore/server_test.go b/services/worldstatestore/server_test.go new file mode 100644 index 00000000000..12bf781cf8c --- /dev/null +++ b/services/worldstatestore/server_test.go @@ -0,0 +1,297 @@ +package worldstatestore_test + +import ( + "context" + "testing" + + "github.com/pkg/errors" + commonpb "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + "go.viam.com/test" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/structpb" + + "go.viam.com/rdk/resource" + "go.viam.com/rdk/services/worldstatestore" + "go.viam.com/rdk/testutils/inject" +) + +const testWorldStateStoreServiceName = "worldstatestore1" + +func newServer(m map[resource.Name]worldstatestore.Service) (pb.WorldStateStoreServiceServer, error) { + coll, err := resource.NewAPIResourceCollection(worldstatestore.API, m) + if err != nil { + return nil, err + } + return worldstatestore.NewRPCServiceServer(coll).(pb.WorldStateStoreServiceServer), nil +} + +func TestWorldStateStoreServerFailures(t *testing.T) { + // Test with no service + m := map[resource.Name]worldstatestore.Service{} + server, err := newServer(m) + test.That(t, err, test.ShouldBeNil) + + // Test ListUUIDs with no service + _, err = server.ListUUIDs(context.Background(), &pb.ListUUIDsRequest{Name: testWorldStateStoreServiceName}) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "not found") + + // Test GetTransform with no service + _, err = server.GetTransform( + context.Background(), + &pb.GetTransformRequest{Name: testWorldStateStoreServiceName, Uuid: []byte("test-uuid")}, + ) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "not found") + + // Test StreamTransformChanges with no service + req := &pb.StreamTransformChangesRequest{Name: testWorldStateStoreServiceName} + mockStream := &mockStreamTransformChangesServer{ctx: context.Background()} + err = server.StreamTransformChanges(req, mockStream) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "not found") + + // Test DoCommand with no service + _, err = server.DoCommand(context.Background(), &commonpb.DoCommandRequest{Name: testWorldStateStoreServiceName}) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "not found") +} + +func TestServerListUUIDs(t *testing.T) { + injectWSS := &inject.WorldStateStoreService{} + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): injectWSS, + } + server, err := newServer(m) + test.That(t, err, test.ShouldBeNil) + + t.Run("successful ListUUIDs", func(t *testing.T) { + expectedUUIDs := [][]byte{[]byte("uuid1"), []byte("uuid2"), []byte("uuid3")} + extra := map[string]interface{}{"foo": "bar"} + ext, err := structpb.NewStruct(extra) + test.That(t, err, test.ShouldBeNil) + + injectWSS.ListUUIDsFunc = func(ctx context.Context, extra map[string]any) ([][]byte, error) { + return expectedUUIDs, nil + } + + req := &pb.ListUUIDsRequest{ + Name: testWorldStateStoreServiceName, + Extra: ext, + } + + resp, err := server.ListUUIDs(context.Background(), req) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp.Uuids, test.ShouldResemble, expectedUUIDs) + }) + + t.Run("ListUUIDs with error", func(t *testing.T) { + expectedErr := errors.New("fake error") + injectWSS.ListUUIDsFunc = func(ctx context.Context, extra map[string]any) ([][]byte, error) { + return nil, expectedErr + } + + req := &pb.ListUUIDsRequest{Name: testWorldStateStoreServiceName} + _, err := server.ListUUIDs(context.Background(), req) + test.That(t, err, test.ShouldEqual, expectedErr) + }) + + t.Run("ListUUIDs with nil response", func(t *testing.T) { + injectWSS.ListUUIDsFunc = func(ctx context.Context, extra map[string]any) ([][]byte, error) { + return nil, nil + } + + req := &pb.ListUUIDsRequest{Name: testWorldStateStoreServiceName} + _, err := server.ListUUIDs(context.Background(), req) + test.That(t, err, test.ShouldEqual, worldstatestore.ErrNilResponse) + }) +} + +func TestServerGetTransform(t *testing.T) { + injectWSS := &inject.WorldStateStoreService{} + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): injectWSS, + } + server, err := newServer(m) + test.That(t, err, test.ShouldBeNil) + + t.Run("successful GetTransform", func(t *testing.T) { + expectedTransform := &commonpb.Transform{ + ReferenceFrame: "test-frame", + Uuid: []byte("test-uuid"), + } + extra := map[string]interface{}{"foo": "bar"} + ext, err := structpb.NewStruct(extra) + test.That(t, err, test.ShouldBeNil) + + injectWSS.GetTransformFunc = func( + ctx context.Context, + uuid []byte, + extra map[string]any, + ) (*commonpb.Transform, error) { + return expectedTransform, nil + } + + req := &pb.GetTransformRequest{ + Name: testWorldStateStoreServiceName, + Uuid: []byte("test-uuid"), + Extra: ext, + } + + resp, err := server.GetTransform(context.Background(), req) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp.Transform, test.ShouldResemble, expectedTransform) + }) + + t.Run("GetTransform with error", func(t *testing.T) { + expectedErr := errors.New("fake error") + injectWSS.GetTransformFunc = func( + ctx context.Context, + uuid []byte, + extra map[string]any, + ) (*commonpb.Transform, error) { + return nil, expectedErr + } + + req := &pb.GetTransformRequest{ + Name: testWorldStateStoreServiceName, + Uuid: []byte("test-uuid"), + } + _, err := server.GetTransform(context.Background(), req) + test.That(t, err, test.ShouldEqual, expectedErr) + }) + + t.Run("GetTransform with nil response", func(t *testing.T) { + injectWSS.GetTransformFunc = func( + ctx context.Context, + uuid []byte, + extra map[string]any, + ) (*commonpb.Transform, error) { + return nil, nil + } + + req := &pb.GetTransformRequest{ + Name: testWorldStateStoreServiceName, + Uuid: []byte("test-uuid"), + } + resp, err := server.GetTransform(context.Background(), req) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp.Transform, test.ShouldBeNil) + }) +} + +func TestServerStreamTransformChanges(t *testing.T) { + injectWSS := &inject.WorldStateStoreService{} + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): injectWSS, + } + server, err := newServer(m) + test.That(t, err, test.ShouldBeNil) + + t.Run("successful StreamTransformChanges", func(t *testing.T) { + extra := map[string]interface{}{"foo": "bar"} + ext, err := structpb.NewStruct(extra) + test.That(t, err, test.ShouldBeNil) + + changesChan := make(chan worldstatestore.TransformChange, 2) + changesChan <- worldstatestore.TransformChange{ + ChangeType: pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_ADDED, + Transform: &commonpb.Transform{ + ReferenceFrame: "test-frame", + Uuid: []byte("test-uuid"), + }, + } + changesChan <- worldstatestore.TransformChange{ + ChangeType: pb.TransformChangeType_TRANSFORM_CHANGE_TYPE_UPDATED, + Transform: &commonpb.Transform{ + ReferenceFrame: "test-frame", + Uuid: []byte("test-uuid"), + }, + UpdatedFields: []string{"pose_in_observer_frame"}, + } + close(changesChan) + + injectWSS.StreamTransformChangesFunc = func( + ctx context.Context, + extra map[string]any, + ) (*worldstatestore.TransformChangeStream, error) { + return worldstatestore.NewTransformChangeStreamFromChannel(ctx, changesChan), nil + } + + req := &pb.StreamTransformChangesRequest{ + Name: testWorldStateStoreServiceName, + Extra: ext, + } + + // Create a mock stream + mockStream := &mockStreamTransformChangesServer{ + ctx: context.Background(), + changes: make([]*pb.StreamTransformChangesResponse, 0), + } + + err = server.StreamTransformChanges(req, mockStream) + test.That(t, err, test.ShouldBeNil) + test.That(t, len(mockStream.changes), test.ShouldEqual, 3) // 1 empty + 2 changes + if len(mockStream.changes) == 3 { + updated := mockStream.changes[2] + test.That(t, updated.UpdatedFields, test.ShouldNotBeNil) + test.That(t, updated.UpdatedFields.Paths, test.ShouldResemble, []string{"pose_in_observer_frame"}) + } + }) + + t.Run("StreamTransformChanges with error", func(t *testing.T) { + expectedErr := errors.New("fake error") + injectWSS.StreamTransformChangesFunc = func( + ctx context.Context, + extra map[string]any, + ) (*worldstatestore.TransformChangeStream, error) { + return nil, expectedErr + } + + req := &pb.StreamTransformChangesRequest{Name: testWorldStateStoreServiceName} + mockStream := &mockStreamTransformChangesServer{ + ctx: context.Background(), + } + + err := server.StreamTransformChanges(req, mockStream) + test.That(t, err, test.ShouldEqual, expectedErr) + }) +} + +func TestServerDoCommand(t *testing.T) { + injectWSS := &inject.WorldStateStoreService{} + m := map[resource.Name]worldstatestore.Service{ + worldstatestore.Named(testWorldStateStoreServiceName): injectWSS, + } + server, err := newServer(m) + test.That(t, err, test.ShouldBeNil) + + t.Run("successful DoCommand", func(t *testing.T) { + expectedResponse := map[string]interface{}{"result": "success"} + injectWSS.DoFunc = func(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + return expectedResponse, nil + } + + req := &commonpb.DoCommandRequest{Name: testWorldStateStoreServiceName} + resp, err := server.DoCommand(context.Background(), req) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp, test.ShouldNotBeNil) + }) +} + +// mockStreamTransformChangesServer implements pb.WorldStateStoreService_StreamTransformChangesServer for testing. +type mockStreamTransformChangesServer struct { + grpc.ServerStream + ctx context.Context + changes []*pb.StreamTransformChangesResponse +} + +func (m *mockStreamTransformChangesServer) Context() context.Context { + return m.ctx +} + +func (m *mockStreamTransformChangesServer) Send(resp *pb.StreamTransformChangesResponse) error { + m.changes = append(m.changes, resp) + return nil +} diff --git a/services/worldstatestore/world_state_store.go b/services/worldstatestore/world_state_store.go new file mode 100644 index 00000000000..1de3a45f9a9 --- /dev/null +++ b/services/worldstatestore/world_state_store.go @@ -0,0 +1,147 @@ +// Package worldstatestore implements the world state store service, which lets users +// create custom visualizers to be rendered in the client. +// For more information, see the [WorldStateStore service docs]. +// +// [WorldStateStore service docs]: https://docs.viam.com/dev/reference/apis/services/world-state-store/ +package worldstatestore + +import ( + "context" + "errors" + "io" + + commonpb "go.viam.com/api/common/v1" + pb "go.viam.com/api/service/worldstatestore/v1" + + "go.viam.com/rdk/resource" + "go.viam.com/rdk/robot" +) + +func init() { + resource.RegisterAPI(API, resource.APIRegistration[Service]{ + RPCServiceServerConstructor: NewRPCServiceServer, + RPCServiceHandler: pb.RegisterWorldStateStoreServiceHandlerFromEndpoint, + RPCServiceDesc: &pb.WorldStateStoreService_ServiceDesc, + RPCClient: NewClientFromConn, + }) +} + +const ( + // SubtypeName is the name of the type of service. + SubtypeName = "world_state_store" +) + +// API is a variable that identifies the world state store resource API. +var API = resource.APINamespaceRDK.WithServiceType(SubtypeName) + +// ErrNilResponse is the error for when a nil response is returned from a world object store service. +var ErrNilResponse = errors.New("world state store service returned a nil response") + +// Named is a helper for getting the named service's typed resource name. +func Named(name string) resource.Name { + return resource.NewName(API, name) +} + +// FromRobot is a helper for getting the named world state store service from the given Robot. +func FromRobot(r robot.Robot, name string) (Service, error) { + return robot.ResourceFromRobot[Service](r, Named(name)) +} + +// FromDependencies is a helper for getting the named world state store service from a collection of +// dependencies. +func FromDependencies(deps resource.Dependencies, name string) (Service, error) { + return resource.FromDependencies[Service](deps, Named(name)) +} + +// Service describes the functions that are available to the service. +// +// For more information, see the [WorldStateStore service docs]. +// +// ListUUIDs example: +// +// // List the world state uuids of a WorldStateStore Service. +// uuids, err := myWorldStateStoreService.ListUUIDs(ctx, nil) +// if err != nil { +// logger.Fatal(err) +// } +// // Print out the world state +// for _, uuid := range uuids { +// fmt.Printf("UUID: %v", uuid) +// } +// +// For more information, see the [list uuids method docs]. +// +// GetTransform example: +// +// // Get the transform by uuid. +// obj, err := myWorldStateStoreService.GetTransform(ctx, myUUID, nil) +// if err != nil { +// logger.Fatal(err) +// } +// // Print out the transform. +// fmt.Printf("Name: %v\nPose: %+v\nMetadata: %+v\nGeometry: %+v", obj.Name, obj.Pose, obj.Metadata, obj.Geometry) +// +// For more information, see the [get transform method docs]. +// +// StreamTransformChanges example: +// +// // Stream transform changes. +// changes, err := myWorldStateStoreService.StreamTransformChanges(ctx, nil) +// if err != nil { +// logger.Fatal(err) +// } +// for change := range changes { +// fmt.Printf("Change: %v\n", change) +// } +// +// For more information, see the [stream transform changes method docs]. +// +// [WorldStateStore service docs]: https://docs.viam.com/dev/reference/apis/services/world-state-store/ +// [list uuids method docs]: https://docs.viam.com/dev/reference/apis/services/list-uuids/ +// [get transform method docs]: https://docs.viam.com/dev/reference/apis/services/get-transform/ +// [stream transform changes method docs]: https://docs.viam.com/dev/reference/apis/services/stream-transform-changes/ +type Service interface { + resource.Resource + ListUUIDs(ctx context.Context, extra map[string]any) ([][]byte, error) + GetTransform(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) + StreamTransformChanges(ctx context.Context, extra map[string]any) (*TransformChangeStream, error) +} + +// TransformChange represents a change to a world state transform. +type TransformChange struct { + ChangeType pb.TransformChangeType + Transform *commonpb.Transform + UpdatedFields []string +} + +// TransformChangeStream provides an iterator interface for receiving transform changes. +// Call Next repeatedly until it returns io.EOF. +type TransformChangeStream struct { + next func() (TransformChange, error) +} + +// Next returns the next TransformChange, or io.EOF when the stream ends. +func (s *TransformChangeStream) Next() (TransformChange, error) { + if s == nil || s.next == nil { + return TransformChange{}, io.EOF + } + return s.next() +} + +// NewTransformChangeStreamFromChannel wraps a channel of TransformChange as a TransformChangeStream. +// The provided context is used to cancel iteration; when ctx is done, Next returns ctx.Err(). +func NewTransformChangeStreamFromChannel(ctx context.Context, ch <-chan TransformChange) *TransformChangeStream { + return &TransformChangeStream{ + next: func() (TransformChange, error) { + select { + case <-ctx.Done(): + return TransformChange{}, ctx.Err() + case change, ok := <-ch: + if !ok { + return TransformChange{}, io.EOF + } + return change, nil + } + }, + } +} diff --git a/testutils/inject/worldstatestore_service.go b/testutils/inject/worldstatestore_service.go new file mode 100644 index 00000000000..9f0ad12b571 --- /dev/null +++ b/testutils/inject/worldstatestore_service.go @@ -0,0 +1,68 @@ +package inject + +import ( + "context" + "errors" + + commonpb "go.viam.com/api/common/v1" + + "go.viam.com/rdk/resource" + "go.viam.com/rdk/services/worldstatestore" +) + +// WorldStateStoreService is an injectable world object store service. +type WorldStateStoreService struct { + resource.Named + resource.TriviallyReconfigurable + resource.TriviallyCloseable + name resource.Name + ListUUIDsFunc func(ctx context.Context, extra map[string]any) ([][]byte, error) + GetTransformFunc func(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) + StreamTransformChangesFunc func(ctx context.Context, extra map[string]any) (*worldstatestore.TransformChangeStream, error) + DoFunc func(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) +} + +// NewWorldStateStoreService returns a new injected world state store service. +func NewWorldStateStoreService(name string) *WorldStateStoreService { + return &WorldStateStoreService{name: worldstatestore.Named(name)} +} + +// Name returns the name of the resource. +func (wosSvc *WorldStateStoreService) Name() resource.Name { + return wosSvc.name +} + +// ListUUIDs calls the injected ListUUIDsFunc or the real version. +func (wosSvc *WorldStateStoreService) ListUUIDs(ctx context.Context, extra map[string]any) ([][]byte, error) { + if wosSvc.ListUUIDsFunc == nil { + return nil, errors.New("ListUUIDsFunc not set") + } + return wosSvc.ListUUIDsFunc(ctx, extra) +} + +// GetTransform calls the injected GetTransformFunc or the real version. +func (wosSvc *WorldStateStoreService) GetTransform(ctx context.Context, uuid []byte, extra map[string]any) (*commonpb.Transform, error) { + if wosSvc.GetTransformFunc == nil { + return nil, errors.New("GetTransformFunc not set") + } + return wosSvc.GetTransformFunc(ctx, uuid, extra) +} + +// DoCommand calls the injected DoCommand or the real version. +func (wosSvc *WorldStateStoreService) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + if wosSvc.DoFunc == nil { + return nil, errors.New("DoCommandFunc not set") + } + return wosSvc.DoFunc(ctx, cmd) +} + +// StreamTransformChanges calls the injected StreamTransformChangesFunc or the real version. +func (wosSvc *WorldStateStoreService) StreamTransformChanges( + ctx context.Context, + extra map[string]any, +) (*worldstatestore.TransformChangeStream, error) { + if wosSvc.StreamTransformChangesFunc == nil { + return nil, errors.New("StreamTransformChangesFunc not set") + } + return wosSvc.StreamTransformChangesFunc(ctx, extra) +} diff --git a/web/server/entrypoint_test.go b/web/server/entrypoint_test.go index 3452c2406f4..8d0b373fdd9 100644 --- a/web/server/entrypoint_test.go +++ b/web/server/entrypoint_test.go @@ -112,10 +112,10 @@ func TestEntrypoint(t *testing.T) { err = json.Unmarshal(outputBytes, ®istrations) test.That(t, err, test.ShouldBeNil) - numReg := 52 + numReg := 53 if runtime.GOOS == "windows" { // windows build excludes builtin models that use cgo - numReg = 43 + numReg = 44 } test.That(t, registrations, test.ShouldHaveLength, numReg)