@@ -77,7 +77,7 @@ def test_clear_attributes(self):
7777 self .assertEqual (len (labeler ), 0 )
7878
7979 def test_thread_safety (self ):
80- labeler = Labeler (max_custom_attrs = 1000 )
80+ labeler = Labeler (max_custom_attrs = 1100 ) # 11 * 100
8181 num_threads = 10
8282 num_operations = 100
8383
@@ -87,6 +87,8 @@ def worker(thread_id):
8787 f"thread_{ thread_id } _key_{ i_operation } " ,
8888 f"value_{ i_operation } " ,
8989 )
90+ # "shared" key that all 10 threads compete to write to
91+ labeler .add ("shared" , thread_id )
9092
9193 # Start multiple threads
9294 threads = []
@@ -99,10 +101,45 @@ def worker(thread_id):
99101 for thread in threads :
100102 thread .join ()
101103
102- # Check that all attributes were added
103104 attributes = labeler .get_attributes ()
104- expected_count = num_threads * num_operations
105- self .assertEqual (len (attributes ), expected_count )
105+ # Should have all unique keys plus "shared"
106+ expected_unique_keys = num_threads * num_operations
107+ self .assertEqual (len (attributes ), expected_unique_keys + 1 )
108+ # "shared" key should exist and have some valid thread_id
109+ self .assertIn ("shared" , attributes )
110+ self .assertIn (attributes ["shared" ], range (num_threads ))
111+
112+ def test_thread_safety_atomic_increment (self ):
113+ """More non-atomic operations than test_thread_safety"""
114+ labeler = Labeler (max_custom_attrs = 100 )
115+ labeler .add ("counter" , 0 )
116+ num_threads = 100
117+ increments_per_thread = 50
118+ expected_final_value = num_threads * increments_per_thread
119+
120+ def increment_worker ():
121+ for _ in range (increments_per_thread ):
122+ # read-modify-write to increase contention
123+ attrs = labeler .get_attributes () # Read
124+ current = attrs ["counter" ] # Extract
125+ new_value = current + 1 # Modify
126+ labeler .add ("counter" , new_value ) # Write
127+
128+ threads = []
129+ for _ in range (num_threads ):
130+ thread = threading .Thread (target = increment_worker )
131+ threads .append (thread )
132+ for thread in threads :
133+ thread .start ()
134+ for thread in threads :
135+ thread .join ()
136+
137+ final_value = labeler .get_attributes ()["counter" ]
138+ self .assertEqual (
139+ final_value , expected_final_value ,
140+ f"Expected { expected_final_value } , got { final_value } . "
141+ f"Lost { expected_final_value - final_value } updates due to race conditions."
142+ )
106143
107144
108145class TestLabelerContext (unittest .TestCase ):
0 commit comments