Skip to content

Commit 4c0af56

Browse files
authored
fix: Reset lens state when calling Reset (#148)
1 parent f0fb7f7 commit 4c0af56

File tree

12 files changed

+401
-2
lines changed

12 files changed

+401
-2
lines changed

host-go/engine/module/instance.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ type Instance struct {
2222
// module after this function has been called are not guaranteed to be visible to the returned io.Reader.
2323
Memory func() Memory
2424

25+
// Reset resets the wasm memory to the state it was in immediately after configuration, before any items were
26+
// pulled through it.
27+
Reset func()
28+
2529
// OwnedBy hosts a reference to any object(s) that may be required to live in memory for the lifetime of this Module.
2630
//
2731
// This is very important when working with some libraries (such as wasmer-go), as without this, dependencies of other members

host-go/engine/pipes/fromPipe.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func (p *fromPipe[TSource, TResult]) Bytes() ([]byte, error) {
9090
}
9191

9292
func (p *fromPipe[TSource, TResult]) Reset() {
93+
p.instance.Reset()
9394
p.source.Reset()
9495
}
9596

host-go/engine/pipes/fromSource.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func (s *fromSource[TSource, TResult]) Bytes() ([]byte, error) {
9191
}
9292

9393
func (s *fromSource[TSource, TResult]) Reset() {
94+
s.instance.Reset()
9495
s.source.Reset()
9596
}
9697

host-go/runtimes/js/runtime.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
156156
}
157157
}
158158

159+
jsMemory := js.Global().Get("Uint8Array").New(memory.Get("buffer"))
160+
initialLen := jsMemory.Get("length").Int()
161+
initialState := make([]byte, initialLen)
162+
js.CopyBytesToGo(initialState, jsMemory)
163+
159164
return module.Instance{
160165
Alloc: func(u module.MemSize) (module.MemSize, error) {
161166
result := alloc.Invoke(int32(u))
@@ -173,6 +178,16 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
173178
buffer := memory.Get("buffer")
174179
return newMemory(buffer)
175180
},
181+
Reset: func() {
182+
initialLen := len(initialState)
183+
currentLen := jsMemory.Get("length").Int()
184+
185+
js.CopyBytesToJS(jsMemory, initialState)
186+
187+
for i := initialLen; i < currentLen; i++ {
188+
jsMemory.SetIndex(i, js.ValueOf(0))
189+
}
190+
},
176191
OwnedBy: instance,
177192
}, nil
178193
}

host-go/runtimes/wasmer/runtime.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
151151
}
152152
}
153153

154+
memSlice := memory.Data()
155+
initialState := make([]byte, len(memSlice))
156+
copy(initialState, memSlice)
157+
154158
return module.Instance{
155159
Alloc: func(u module.MemSize) (module.MemSize, error) {
156160
r, err := alloc.Call(u)
@@ -173,6 +177,16 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
173177
Memory: func() module.Memory {
174178
return module.NewBytesMemory(memory.Data())
175179
},
180+
Reset: func() {
181+
initialLen := len(initialState)
182+
currentMemory := memory.Data()
183+
184+
copy(currentMemory, initialState)
185+
186+
for i := initialLen; i < len(currentMemory); i++ {
187+
currentMemory[i] = 0
188+
}
189+
},
176190
OwnedBy: instance,
177191
}, nil
178192
}

host-go/runtimes/wasmtime/runtime.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
144144
}
145145
}
146146

147+
memSlice := memory.UnsafeData(m.rt.store)
148+
initialState := make([]byte, len(memSlice))
149+
copy(initialState, memSlice)
150+
147151
return module.Instance{
148152
Alloc: func(u module.MemSize) (module.MemSize, error) {
149153
r, err := alloc.Call(m.rt.store, u)
@@ -166,6 +170,16 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
166170
Memory: func() module.Memory {
167171
return module.NewBytesMemory(memory.UnsafeData(m.rt.store))
168172
},
173+
Reset: func() {
174+
initialLen := len(initialState)
175+
currentMemory := memory.UnsafeData(m.rt.store)
176+
177+
copy(currentMemory, initialState)
178+
179+
for i := initialLen; i < len(currentMemory); i++ {
180+
currentMemory[i] = 0
181+
}
182+
},
169183
OwnedBy: instance,
170184
}, nil
171185
}

host-go/runtimes/wazero/runtime.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
143143
}
144144
}
145145

146+
initialState, _ := memory.Read(0, memory.Size())
147+
146148
return module.Instance{
147149
Alloc: func(u module.MemSize) (module.MemSize, error) {
148150
r, err := alloc.Call(ctx, uint64(u))
@@ -165,6 +167,15 @@ func (m *wModule) NewInstance(functionName string, paramSets ...map[string]any)
165167
Memory: func() module.Memory {
166168
return newMemory(memory)
167169
},
170+
Reset: func() {
171+
initialLen := uint32(len(initialState))
172+
currentLen := memory.Size()
173+
memory.Write(0, initialState)
174+
175+
for i := initialLen; i < currentLen; i++ {
176+
memory.WriteByte(i, 0)
177+
}
178+
},
168179
OwnedBy: instance,
169180
}, nil
170181
}

tests/action/transform.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// This Source Code Form is subject to the terms of the Mozilla Public
2+
// License, v. 2.0. If a copy of the MPL was not distributed with this
3+
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4+
5+
package action
6+
7+
import (
8+
"fmt"
9+
10+
"github.com/sourcenetwork/immutable/enumerable"
11+
"github.com/sourcenetwork/lens/host-go/store"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// Transform action executes the transform of the given ID with the given input
16+
// and asserts that the output matches the given expected output.
17+
type Transform struct {
18+
Nodeful
19+
20+
LensID string
21+
Input enumerable.Enumerable[store.Document]
22+
Expected enumerable.Enumerable[store.Document]
23+
}
24+
25+
var _ Action = (*Add)(nil)
26+
var _ Stateful = (*Add)(nil)
27+
28+
func (a *Transform) Execute() {
29+
lensID := replace(a.s, a.LensID)
30+
for nodeIndex, n := range a.Nodes() {
31+
output, err := n.Store.Transform(a.s.Ctx, a.Input, lensID)
32+
require.NoError(a.s.T, err)
33+
34+
n.Transforms = append(n.Transforms, output)
35+
36+
for i := 0; true; i++ {
37+
println(fmt.Sprintf("TransformEval: Node: {%v} item: {%v}", nodeIndex, i))
38+
39+
hasNext, err := output.Next()
40+
require.NoError(a.s.T, err)
41+
42+
expectedHasNext, err := a.Expected.Next()
43+
require.NoError(a.s.T, err)
44+
45+
require.Equal(a.s.T, expectedHasNext, hasNext)
46+
47+
if !hasNext {
48+
break
49+
}
50+
51+
value, err := output.Value()
52+
require.NoError(a.s.T, err)
53+
54+
expectedValue, err := a.Expected.Value()
55+
require.NoError(a.s.T, err)
56+
57+
require.Equal(a.s.T, expectedValue, value)
58+
}
59+
60+
expectedHasNext, err := a.Expected.Next()
61+
require.NoError(a.s.T, err)
62+
require.False(a.s.T, expectedHasNext)
63+
64+
a.Expected.Reset()
65+
}
66+
}

tests/action/transform_reset.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// This Source Code Form is subject to the terms of the Mozilla Public
2+
// License, v. 2.0. If a copy of the MPL was not distributed with this
3+
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4+
5+
package action
6+
7+
// TransformReset resets the given transform.
8+
type TransformReset struct {
9+
Nodeful
10+
11+
TransformIndex int
12+
}
13+
14+
var _ Action = (*Add)(nil)
15+
var _ Stateful = (*Add)(nil)
16+
17+
func (a *TransformReset) Execute() {
18+
for _, n := range a.Nodes() {
19+
n.Transforms[a.TransformIndex].Reset()
20+
}
21+
}

0 commit comments

Comments
 (0)