Skip to content

Commit 81aa9b5

Browse files
authored
[Observation] Add some behaviroal tests for changes, transactions, and tracking (#64179)
* [Observation] Add some behaviroal tests for changes, transactions, and tracking * Correct transactions to properly suspend when awaiting for changes
1 parent 0ff5cd9 commit 81aa9b5

File tree

2 files changed

+298
-17
lines changed

2 files changed

+298
-17
lines changed

stdlib/public/Observation/Sources/Observation/ObservationRegistrar.swift

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ public struct ObservationRegistrar<Subject: Observable>: Sendable {
5454
fileprivate struct Next {
5555
enum Kind {
5656
case transaction(TrackedProperties<Subject>)
57+
case pendingTransaction(TrackedProperties<Subject>, UnsafeContinuation<TrackedProperties<Subject>?, Never>)
5758
case change(TrackedProperties<Subject>, UnsafeContinuation<Subject?, Never>)
5859
case tracking(TrackedProperties<Subject>, @Sendable () -> Void)
5960
case cancelled
6061
}
6162

62-
fileprivate let kind: Kind
63+
fileprivate var kind: Kind
6364
fileprivate var collected: TrackedProperties<Subject>?
6465
}
6566

@@ -139,8 +140,10 @@ extension ObservationRegistrar.Context {
139140
_ generation: Int,
140141
isolation: isolated Delivery
141142
) async -> TrackedProperties<Subject>? {
142-
return state.withCriticalRegion { state in
143-
state.complete(generation)
143+
return await withUnsafeContinuation { continuation in
144+
state.withCriticalRegion { state in
145+
state.complete(generation, continuation: continuation)
146+
}
144147
}
145148
}
146149

@@ -218,8 +221,8 @@ extension ObservationRegistrar.Next {
218221

219222
@available(SwiftStdlib 5.9, *)
220223
extension ObservationRegistrar.Next.Kind {
221-
fileprivate func resume(
222-
keyPath: PartialKeyPath<Subject>,
224+
fileprivate mutating func resume(
225+
keyPath: PartialKeyPath<Subject>,
223226
phase: ObservationRegistrar.Phase,
224227
collected: inout TrackedProperties<Subject>?
225228
) -> ObservationRegistrar.ResumeAction? {
@@ -235,6 +238,20 @@ extension ObservationRegistrar.Next.Kind {
235238
collected = properties
236239
}
237240
}
241+
return nil
242+
case (.pendingTransaction(let observedTrackedProperties, let continuation), .willSet):
243+
if observedTrackedProperties.contains(keyPath) {
244+
if var properties = collected {
245+
properties.insert(keyPath)
246+
continuation.resume(returning: properties)
247+
} else {
248+
var properties = TrackedProperties<Subject>()
249+
properties.insert(keyPath)
250+
continuation.resume(returning: properties)
251+
}
252+
self = .transaction(observedTrackedProperties)
253+
}
254+
238255
return nil
239256
case (.change(let observedTrackedProperties, let continuation), .didSet):
240257
if observedTrackedProperties.contains(keyPath) {
@@ -292,6 +309,8 @@ extension ObservationRegistrar.Next.Kind {
292309
switch self {
293310
case .transaction(let properties):
294311
invalidate(properties: properties, from: &lookup, generation: generation)
312+
case .pendingTransaction(let properties, _):
313+
invalidate(properties: properties, from: &lookup, generation: generation)
295314
case .change(let properties, _):
296315
invalidate(properties: properties, from: &lookup, generation: generation)
297316
case .tracking(let properties, _):
@@ -303,6 +322,8 @@ extension ObservationRegistrar.Next.Kind {
303322

304323
fileprivate func deinitialize() {
305324
switch self {
325+
case .pendingTransaction(_, let continuation):
326+
continuation.resume(returning: nil)
306327
case .change(_, let continuation):
307328
continuation.resume(returning: nil)
308329
default:
@@ -383,17 +404,21 @@ extension ObservationRegistrar.State {
383404
}
384405

385406
fileprivate mutating func complete(
386-
_ generation: Int
387-
) -> TrackedProperties<Subject>? {
407+
_ generation: Int,
408+
continuation: UnsafeContinuation<TrackedProperties<Subject>?, Never>
409+
){
388410
if let existing = nexts.removeValue(forKey: generation) {
389411
switch existing.kind {
390-
case .transaction:
391-
return existing.collected
412+
case .transaction(let properties):
413+
if let collected = existing.collected {
414+
continuation.resume(returning: collected)
415+
} else {
416+
nexts[generation] = ObservationRegistrar.Next(kind: .pendingTransaction(properties, continuation))
417+
}
392418
default:
393-
return nil
419+
continuation.resume(returning: nil)
394420
}
395421
}
396-
return nil
397422
}
398423

399424
fileprivate mutating func willSet<Member>(
Lines changed: 262 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,274 @@
1-
// RUN: %target-run-simple-swift
1+
// RUN: %target-run-simple-swift( -Xfrontend -disable-availability-checking -parse-as-library -enable-experimental-feature Macros -Xfrontend -plugin-path -Xfrontend %swift-host-lib-dir/plugins)
2+
23
// REQUIRES: executable_test
34
// REQUIRES: observation
5+
// REQUIRES: concurrency
46
// REQUIRES: objc_interop
5-
// UNSUPPORTED: freestanding
7+
// REQUIRES: executable_test
68

79
import StdlibUnittest
810
import Observation
11+
import _Concurrency
12+
13+
@usableFromInline
14+
@inline(never)
15+
func _blackHole<T>(_ value: T) { }
16+
17+
final class UnsafeBox<Contents>: @unchecked Sendable {
18+
var contents: Contents
19+
20+
init(_ contents: Contents) {
21+
self.contents = contents
22+
}
23+
}
24+
25+
@available(SwiftStdlib 5.9, *)
26+
final class TestWithoutMacro: Observable {
27+
let _registrar = ObservationRegistrar<TestWithoutMacro>()
28+
29+
public nonisolated func transactions<Delivery>(
30+
for properties: TrackedProperties<TestWithoutMacro>,
31+
isolation: Delivery
32+
) -> ObservedTransactions<TestWithoutMacro, Delivery> where Delivery: Actor {
33+
_registrar.transactions(for: properties, isolation: isolation)
34+
}
35+
36+
public nonisolated func changes<Member>(
37+
for keyPath: KeyPath<TestWithoutMacro, Member>
38+
) -> ObservedChanges<TestWithoutMacro, Member> where Member: Sendable {
39+
_registrar.changes(for: keyPath)
40+
}
941

10-
let suite = TestSuite("Observable")
42+
private struct _Storage {
43+
var field1 = "test"
44+
var field2 = "test"
45+
var field3 = 0
46+
}
47+
48+
private var _storage = _Storage()
49+
50+
var field1: String {
51+
get {
52+
_registrar.access(self, keyPath: \.field1)
53+
return _storage.field1
54+
}
55+
set {
56+
_registrar.withMutation(of: self, keyPath: \.field1) {
57+
_storage.field1 = newValue
58+
}
59+
}
60+
}
61+
62+
var field2: String {
63+
get {
64+
_registrar.access(self, keyPath: \.field2)
65+
return _storage.field2
66+
}
67+
set {
68+
_registrar.withMutation(of: self, keyPath: \.field2) {
69+
_storage.field2 = newValue
70+
}
71+
}
72+
}
1173

12-
if #available(SwiftStdlib 9999, *) {
13-
suite.test("Basic") {
74+
var field3: Int {
75+
get {
76+
_registrar.access(self, keyPath: \.field3)
77+
return _storage.field3
78+
}
79+
set {
80+
_registrar.withMutation(of: self, keyPath: \.field3) {
81+
_storage.field3 = newValue
82+
}
83+
}
84+
}
85+
}
86+
87+
@available(SwiftStdlib 5.9, *)
88+
@Observable final class TestWithMacro {
89+
var field1 = "test"
90+
var field2 = "test"
91+
var field3 = 0
92+
}
1493

94+
extension AsyncSequence {
95+
func triggerIteration(
96+
_ continuation: UnsafeContinuation<Void, Never>
97+
) -> TriggerSequence<Self> {
98+
TriggerSequence(self, continuation: continuation)
1599
}
16100
}
17101

18-
runAllTests()
102+
struct TriggerSequence<Base: AsyncSequence> {
103+
let base: Base
104+
let continuation: UnsafeContinuation<Void, Never>
105+
106+
init(_ base: Base, continuation: UnsafeContinuation<Void, Never>) {
107+
self.base = base
108+
self.continuation = continuation
109+
}
110+
}
111+
112+
extension TriggerSequence: AsyncSequence {
113+
typealias Element = Base.Element
114+
115+
struct Iterator: AsyncIteratorProtocol {
116+
var continuation: UnsafeContinuation<Void, Never>?
117+
var base: Base.AsyncIterator
118+
119+
init(
120+
_ base: Base.AsyncIterator,
121+
continuation: UnsafeContinuation<Void, Never>
122+
) {
123+
self.base = base
124+
self.continuation = continuation
125+
}
126+
127+
mutating func next() async rethrows -> Base.Element? {
128+
if let continuation {
129+
self.continuation = nil
130+
continuation.resume()
131+
}
132+
return try await base.next()
133+
}
134+
}
135+
136+
func makeAsyncIterator() -> Iterator {
137+
Iterator(base.makeAsyncIterator(), continuation: continuation)
138+
}
139+
}
140+
141+
@main struct Main {
142+
@MainActor
143+
static func main() async {
144+
let suite = TestSuite("Observable")
145+
146+
suite.test("unobserved value changes (macro)") {
147+
let subject = TestWithMacro()
148+
for i in 0..<100 {
149+
subject.field3 = i
150+
}
151+
}
152+
153+
suite.test("unobserved value changes (nonmacro)") {
154+
let subject = TestWithoutMacro()
155+
for i in 0..<100 {
156+
subject.field3 = i
157+
}
158+
}
159+
160+
suite.test("changes emit values (macro)") { @MainActor in
161+
let subject = TestWithMacro()
162+
var t: Task<String?, Never>?
163+
await withUnsafeContinuation { continuation in
164+
t = Task { @MainActor in
165+
// Note: this must be fully established
166+
// so we must await the trigger to fire
167+
let changes = subject.changes(for: \.field1)
168+
.triggerIteration(continuation)
169+
for await value in changes {
170+
return value
171+
}
172+
return nil
173+
}
174+
}
175+
subject.field1 = "a"
176+
let value = await t!.value
177+
expectEqual(value, "a")
178+
}
179+
180+
suite.test("changes emit values (nonmacro)") { @MainActor in
181+
let subject = TestWithoutMacro()
182+
var t: Task<String?, Never>?
183+
await withUnsafeContinuation { continuation in
184+
t = Task { @MainActor in
185+
// Note: this must be fully established
186+
// so we must await the trigger to fire
187+
let changes = subject.changes(for: \.field1)
188+
.triggerIteration(continuation)
189+
for await value in changes {
190+
return value
191+
}
192+
return nil
193+
}
194+
}
195+
subject.field1 = "a"
196+
let value = await t!.value
197+
expectEqual(value, "a")
198+
}
199+
200+
201+
suite.test("changes cancellation terminates") { @MainActor in
202+
let subject = TestWithMacro()
203+
var finished = false
204+
let t = Task { @MainActor in
205+
for await _ in subject.changes(for: \.field1) {
206+
207+
}
208+
finished = true
209+
}
210+
try? await Task.sleep(for: .seconds(0.1))
211+
expectEqual(finished, false)
212+
t.cancel()
213+
try? await Task.sleep(for: .seconds(0.1))
214+
expectEqual(finished, true)
215+
}
216+
217+
suite.test("transactions emit values (macro)") { @MainActor in
218+
let subject = TestWithMacro()
219+
var t: Task<TrackedProperties<TestWithMacro>?, Never>?
220+
await withUnsafeContinuation { continuation in
221+
t = Task { @MainActor in
222+
// Note: this must be fully established
223+
// so we must await the trigger to fire
224+
let transactions = subject.transactions(for: \.field1)
225+
.triggerIteration(continuation)
226+
for await value in transactions {
227+
return value
228+
}
229+
return nil
230+
}
231+
}
232+
subject.field1 = "a"
233+
let value = await t!.value
234+
expectEqual(value?.contains(\.field1), true)
235+
}
236+
237+
suite.test("transactions emit values (nonmacro)") { @MainActor in
238+
let subject = TestWithoutMacro()
239+
var t: Task<TrackedProperties<TestWithoutMacro>?, Never>?
240+
await withUnsafeContinuation { continuation in
241+
t = Task { @MainActor in
242+
// Note: this must be fully established
243+
// so we must await the trigger to fire
244+
let transactions = subject.transactions(for: \.field1)
245+
.triggerIteration(continuation)
246+
for await value in transactions {
247+
return value
248+
}
249+
return nil
250+
}
251+
}
252+
subject.field1 = "a"
253+
let value = await t!.value
254+
expectEqual(value?.contains(\.field1), true)
255+
}
256+
257+
suite.test("tracking") { @MainActor in
258+
let subject = TestWithMacro()
259+
let changed = UnsafeBox(false)
260+
ObservationTracking.withTracking {
261+
_blackHole(subject.field1)
262+
} onChange: {
263+
changed.contents = true
264+
}
265+
expectEqual(changed.contents, false)
266+
subject.field2 = "asdf"
267+
expectEqual(changed.contents, false)
268+
subject.field1 = "asdf"
269+
expectEqual(changed.contents, true)
270+
}
271+
272+
await runAllTestsAsync()
273+
}
274+
}

0 commit comments

Comments
 (0)