From e907782de12120e6fd18a8f4e467bd03d4fb4bbd Mon Sep 17 00:00:00 2001 From: joshua Date: Fri, 26 Sep 2025 14:34:37 -0700 Subject: [PATCH 1/8] Improve Rust on_close_callback WIP Finally fixed on_close callback - threading issue Convert close to promise and resolve it in Python and Java Add Java onclose callback test Cleanup test Add mutex --- Cargo.lock | 107 ++++++++++++------ c/src/transaction.rs | 4 +- c/typedb_driver.i | 71 +++++++++++- dependencies/typedb/artifacts.bzl | 2 +- java/connection/TransactionImpl.java | 2 +- java/test/integration/BUILD | 23 ++++ java/test/integration/DriverTest.java | 101 +++++++++++++++++ python/tests/integration/BUILD | 11 ++ python/tests/integration/test_driver.py | 53 +++++++++ python/typedb/connection/transaction.py | 16 ++- rust/src/common/error.rs | 2 +- rust/src/connection/network/channel.rs | 1 - rust/src/connection/network/proto/common.rs | 2 +- rust/src/connection/network/stub.rs | 4 +- .../connection/network/transmitter/import.rs | 2 +- .../network/transmitter/response_sink.rs | 3 +- .../network/transmitter/transaction.rs | 107 +++++++++++++----- rust/src/connection/server_connection.rs | 7 +- rust/src/connection/transaction_stream.rs | 4 +- rust/src/database/database.rs | 5 +- rust/src/database/database_manager.rs | 3 +- rust/src/database/migration.rs | 2 +- rust/src/transaction.rs | 7 +- .../behaviour/steps/connection/database.rs | 2 +- rust/tests/behaviour/steps/connection/mod.rs | 2 - rust/tests/integration/BUILD | 18 +++ rust/tests/integration/driver.rs | 84 ++++++++++++++ rust/tests/integration/example.rs | 2 + rust/tests/integration/mod.rs | 1 - 29 files changed, 548 insertions(+), 100 deletions(-) create mode 100644 java/test/integration/DriverTest.java create mode 100644 python/tests/integration/test_driver.py create mode 100644 rust/tests/integration/driver.rs diff --git a/Cargo.lock b/Cargo.lock index 4f93e243f..66f26e573 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,9 +276,9 @@ dependencies = [ [[package]] name = "async-std" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "730294c1c08c2e0f85759590518f6333f0d5a0a766a27d519c1b244c3dfd8a24" +checksum = "2c8e079a4ab67ae52b7403632e4618815d6db36d2a010cfe41b02c1b1578f93b" dependencies = [ "async-attributes", "async-channel 1.9.0", @@ -1354,6 +1354,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "io-uring" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "libc", +] + [[package]] name = "is-terminal" version = "0.4.13" @@ -1397,10 +1408,11 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -1421,9 +1433,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" [[package]] name = "linked-hash-map" @@ -1464,9 +1476,9 @@ dependencies = [ [[package]] name = "macro_rules_attribute" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" dependencies = [ "macro_rules_attribute-proc_macro", "paste", @@ -1474,9 +1486,9 @@ dependencies = [ [[package]] name = "macro_rules_attribute-proc_macro" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" [[package]] name = "matchit" @@ -2207,18 +2219,28 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2227,15 +2249,16 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "indexmap 2.5.0", "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] @@ -2354,6 +2377,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -2522,20 +2555,22 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.1" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.7", + "slab", + "socket2 0.6.0", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2820,13 +2855,15 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ "getrandom 0.3.1", + "js-sys", "rand 0.9.0", "serde", + "wasm-bindgen", ] [[package]] @@ -2883,24 +2920,25 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn 2.0.87", @@ -2921,9 +2959,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2931,9 +2969,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", @@ -2944,9 +2982,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" diff --git a/c/src/transaction.rs b/c/src/transaction.rs index 8e4f31d7d..8f05e5b19 100644 --- a/c/src/transaction.rs +++ b/c/src/transaction.rs @@ -66,8 +66,8 @@ pub extern "C" fn transaction_close(txn: *mut Transaction) { /// Forcibly closes this transaction. To be used in exceptional cases. #[no_mangle] -pub extern "C" fn transaction_force_close(txn: *mut Transaction) { - borrow_mut(txn).force_close(); +pub extern "C" fn transaction_force_close(txn: *mut Transaction) -> *mut VoidPromise { + release(VoidPromise(Box::new(borrow_mut(txn).force_close()))) } /// Commits the changes made via this transaction to the TypeDB database. diff --git a/c/typedb_driver.i b/c/typedb_driver.i index f95d87410..19cfc1511 100644 --- a/c/typedb_driver.i +++ b/c/typedb_driver.i @@ -17,6 +17,7 @@ * under the License. */ +%module(threads=1) typedb_driver %module(directors="1") typedb_driver %{ extern "C" { @@ -117,11 +118,73 @@ struct TransactionCallbackDirector { #include #include #include -static std::unordered_map transactionOnCloseCallbacks {}; + +class ThreadSafeTransactionCallbacks { +private: + // 1. The static map to protect + static std::unordered_map s_transactionOnCloseCallbacks; + + // 2. The static mutex to manage access + static std::mutex s_mutex; + +public: + // Delete copy/move constructors and assignment operators + // to prevent accidental copying of the singleton-like structure + ThreadSafeTransactionCallbacks(const ThreadSafeTransactionCallbacks&) = delete; + ThreadSafeTransactionCallbacks& operator=(const ThreadSafeTransactionCallbacks&) = delete; + + // --- Core Operations --- + + /** + * @brief Inserts a key-value pair into the map in a thread-safe manner. + */ + static void insert(size_t key, TransactionCallbackDirector* value) { + // Lock the mutex for the duration of this scope + std::lock_guard lock(s_mutex); + + // Thread-safe insertion + s_transactionOnCloseCallbacks[key] = value; + } + + /** + * @brief Retrieves a value associated with a key in a thread-safe manner. + * @returns The value pointer, or nullptr if the key is not found. + */ + static TransactionCallbackDirector* find(size_t key) { + // Lock the mutex for the duration of this scope + std::lock_guard lock(s_mutex); + + // Thread-safe lookup + auto it = s_transactionOnCloseCallbacks.find(key); + if (it != s_transactionOnCloseCallbacks.end()) { + return it->second; + } + return nullptr; // Return nullptr if not found + } + + /** + * @brief Removes a key-value pair from the map in a thread-safe manner. + */ + static void remove(size_t key) { + // Lock the mutex for the duration of this scope + std::lock_guard lock(s_mutex); + + // Thread-safe removal + s_transactionOnCloseCallbacks.erase(key); + } + + // Add other necessary map operations (e.g., size(), contains(), clear()) here... +}; + +// Initialize the static members +std::unordered_map ThreadSafeTransactionCallbacks::s_transactionOnCloseCallbacks; +std::mutex ThreadSafeTransactionCallbacks::s_mutex; + static void transaction_callback_execute(size_t ID, Error* error) { try { - transactionOnCloseCallbacks.at(ID)->callback(error); - transactionOnCloseCallbacks.erase(ID); + auto cb = ThreadSafeTransactionCallbacks::find(ID); + cb->callback(error); + ThreadSafeTransactionCallbacks::remove(ID); } catch (std::exception const& e) { std::cerr << "[ERROR] " << e.what() << std::endl; } @@ -135,7 +198,7 @@ static void transaction_callback_execute(size_t ID, Error* error) { void transaction_on_close_register(const Transaction* transaction, TransactionCallbackDirector* handler) { static std::atomic_size_t nextID; std::size_t ID = nextID.fetch_add(1); - transactionOnCloseCallbacks.insert({ID, handler}); + ThreadSafeTransactionCallbacks::insert(ID, handler); transaction_on_close(transaction, ID, &transaction_callback_execute); } %} diff --git a/dependencies/typedb/artifacts.bzl b/dependencies/typedb/artifacts.bzl index 1a3db21f7..00b1c26be 100644 --- a/dependencies/typedb/artifacts.bzl +++ b/dependencies/typedb/artifacts.bzl @@ -25,7 +25,7 @@ def typedb_artifact(): artifact_name = "typedb-all-{platform}-{version}.{ext}", tag_source = deployment["artifact"]["release"]["download"], commit_source = deployment["artifact"]["snapshot"]["download"], - tag = "3.5.0-rc0" + tag = "3.5.0" ) #def typedb_cloud_artifact(): diff --git a/java/connection/TransactionImpl.java b/java/connection/TransactionImpl.java index a18bf40b4..4301986a3 100644 --- a/java/connection/TransactionImpl.java +++ b/java/connection/TransactionImpl.java @@ -133,7 +133,7 @@ public void rollback() throws TypeDBDriverException { public void close() throws TypeDBDriverException { if (nativeObject.isOwned()) { try { - transaction_force_close(nativeObject); + transaction_force_close(nativeObject).get(); } catch (com.typedb.driver.jni.Error error) { throw new TypeDBDriverException(error); } finally { diff --git a/java/test/integration/BUILD b/java/test/integration/BUILD index 972f84b62..688245847 100644 --- a/java/test/integration/BUILD +++ b/java/test/integration/BUILD @@ -46,6 +46,29 @@ typedb_java_test( ], ) +typedb_java_test( + name = "test-driver", + srcs = ["DriverTest.java"], + server_artifacts = { + "@typedb_bazel_distribution//platform:is_linux_arm64": "@typedb_artifact_linux-arm64//file", + "@typedb_bazel_distribution//platform:is_linux_x86_64": "@typedb_artifact_linux-x86_64//file", + "@typedb_bazel_distribution//platform:is_mac_arm64": "@typedb_artifact_mac-arm64//file", + "@typedb_bazel_distribution//platform:is_mac_x86_64": "@typedb_artifact_mac-x86_64//file", +# "@typedb_bazel_distribution//platform:is_windows_x86_64": "@typedb_artifact_windows-x86_64//file", + }, + test_class = "com.typedb.driver.test.integration.DriverTest", + deps = [ + # Internal dependencies + "//java:driver-java", + "//java/api", + "//java/common", + + # External dependencies from @typedb + "@maven//:org_slf4j_slf4j_api", +# "@maven//:com_typedb_typedb_runner", + ], +) + typedb_java_test( name = "test-value", srcs = ["ValueTest.java"], diff --git a/java/test/integration/DriverTest.java b/java/test/integration/DriverTest.java new file mode 100644 index 000000000..18014d195 --- /dev/null +++ b/java/test/integration/DriverTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.typedb.driver.test.integration; + +import com.typedb.driver.TypeDB; +import com.typedb.driver.api.Credentials; +import com.typedb.driver.api.Driver; +import com.typedb.driver.api.DriverOptions; +import com.typedb.driver.api.Transaction; +import com.typedb.driver.api.answer.ConceptRow; +import com.typedb.driver.api.answer.QueryAnswer; +import com.typedb.driver.api.concept.Concept; +import com.typedb.driver.api.concept.instance.Attribute; +import com.typedb.driver.api.concept.type.AttributeType; +import com.typedb.driver.api.concept.value.Value; +import com.typedb.driver.api.database.Database; +import com.typedb.driver.common.Duration; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("Duplicates") +public class DriverTest { + private static final String DB_NAME = "typedb"; + private static final String ADDRESS = "0.0.0.0:1729"; + private static Driver typedbDriver; + + @BeforeClass + public static void setUpClass() { + typedbDriver = TypeDB.driver(ADDRESS, new Credentials("admin", "password"), new DriverOptions(false, null)); + if (typedbDriver.databases().contains(DB_NAME)) typedbDriver.databases().get(DB_NAME).delete(); + typedbDriver.databases().create(DB_NAME); + } + + @AfterClass + public static void close() { + typedbDriver.close(); + } + + @Test + public void transaction_on_close() { + Database db = typedbDriver.databases().get(DB_NAME); + db.delete(); + typedbDriver.databases().create(DB_NAME); + + AtomicBoolean calledOnClose = new AtomicBoolean(false); + + localhostTypeDBTX(transaction -> { + + transaction.onClose(error -> { + calledOnClose.set(true); + }); + + transaction.close(); + assertTrue(calledOnClose.get()); + }, Transaction.Type.READ); + } + + private void localhostTypeDBTX(Consumer fn, Transaction.Type type/*, Options options*/) { + try (Transaction transaction = typedbDriver.transaction(DB_NAME, type/*, options*/)) { + fn.accept(transaction); + } + } +} diff --git a/python/tests/integration/BUILD b/python/tests/integration/BUILD index 4f70e6437..42a6c41d9 100644 --- a/python/tests/integration/BUILD +++ b/python/tests/integration/BUILD @@ -35,6 +35,17 @@ py_test( python_version = "PY3" ) +py_test( + name = "test_driver", + srcs = ["test_driver.py"], + deps = [ + "//python:driver_python", + requirement("PyHamcrest"), + ], + data = ["//python:native-driver-binary-link", "//python:native-driver-wrapper-link"], + python_version = "PY3" +) + py_test( name = "test_values", srcs = ["test_values.py"], diff --git a/python/tests/integration/test_driver.py b/python/tests/integration/test_driver.py new file mode 100644 index 000000000..77a661d75 --- /dev/null +++ b/python/tests/integration/test_driver.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from unittest import TestCase +import time + +from hamcrest import * +from typedb.driver import * + + +class TestExample(TestCase): + + def setUp(self): + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: + if driver.databases.contains("typedb"): + driver.databases.get("typedb").delete() + + + def test_on_close_callback(self): + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: + driver.databases.create("typedb") + database = driver.databases.get("typedb") + assert_that(database.name, is_("typedb")) + + tx = driver.transaction(database.name, TransactionType.READ) + + transaction_closed = {"closed": False} + def callback(_error): + transaction_closed.update({"closed": True}) + tx.on_close(callback) + + tx.close() + + assert_that(transaction_closed["closed"], is_(True)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/typedb/connection/transaction.py b/python/typedb/connection/transaction.py index 69c372301..1486141ce 100644 --- a/python/typedb/connection/transaction.py +++ b/python/typedb/connection/transaction.py @@ -78,7 +78,8 @@ def is_open(self) -> bool: return transaction_is_open(self.native_object) def on_close(self, function: callable): - transaction_on_close(self.native_object, _Transaction.TransactionOnClose(function).__disown__()) + callback = _Transaction.TransactionOnClose(function) + transaction_on_close(self.native_object, callback.__disown__()) class TransactionOnClose(TransactionCallbackDirector): @@ -87,7 +88,16 @@ def __init__(self, function: callable): self._function = function def callback(self, error: NativeError) -> None: - self._function(TypeDBException(error_code(error), error_message(error))) + try: + if error: + self._function(TypeDBException(error_code(error), error_message(error))) + else: + self._function(None) + except Exception as e: + # WARNING: SWIG will not propagate any errors (including syntax!) to the user without more work so we can at least log + import sys + print("Error invoking onclose callback: ", e, file=sys.stderr) + raise e def commit(self): try: @@ -104,7 +114,7 @@ def rollback(self): def close(self): if self._native_object.thisown: - transaction_force_close(self._native_object) + void_promise_resolve(transaction_force_close(self._native_object)) def __enter__(self): return self diff --git a/rust/src/common/error.rs b/rust/src/common/error.rs index 207376352..96dad3fd5 100644 --- a/rust/src/common/error.rs +++ b/rust/src/common/error.rs @@ -21,7 +21,7 @@ use std::{collections::HashSet, error::Error as StdError, fmt}; use itertools::Itertools; use tonic::{Code, Status}; -use tonic_types::{ErrorDetails, ErrorInfo, StatusExt}; +use tonic_types::StatusExt; use super::{address::Address, RequestID}; diff --git a/rust/src/connection/network/channel.rs b/rust/src/connection/network/channel.rs index a96ec9bdd..b7c47c092 100644 --- a/rust/src/connection/network/channel.rs +++ b/rust/src/connection/network/channel.rs @@ -22,7 +22,6 @@ use std::sync::{Arc, RwLock}; use tonic::{ body::BoxBody, client::GrpcService, - metadata::MetadataValue, service::{ interceptor::{InterceptedService, ResponseFuture as InterceptorResponseFuture}, Interceptor, diff --git a/rust/src/connection/network/proto/common.rs b/rust/src/connection/network/proto/common.rs index b4b9f4348..602581a8c 100644 --- a/rust/src/connection/network/proto/common.rs +++ b/rust/src/connection/network/proto/common.rs @@ -19,7 +19,7 @@ use typedb_protocol::{ options::{Query as QueryOptionsProto, Transaction as TransactionOptionsProto}, - transaction, Options, + transaction, }; use super::{IntoProto, TryFromProto}; diff --git a/rust/src/connection/network/stub.rs b/rust/src/connection/network/stub.rs index 9da178c79..3e1c7d530 100644 --- a/rust/src/connection/network/stub.rs +++ b/rust/src/connection/network/stub.rs @@ -20,12 +20,12 @@ use std::sync::Arc; use futures::{future::BoxFuture, FutureExt, TryFutureExt}; -use log::{debug, trace, warn}; +use log::{debug, trace}; use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Response, Status, Streaming}; use typedb_protocol::{ - authentication, connection, database, database_manager, migration, server_manager, transaction, + connection, database, database_manager, migration, server_manager, transaction, type_db_client::TypeDbClient as GRPC, user, user_manager, }; diff --git a/rust/src/connection/network/transmitter/import.rs b/rust/src/connection/network/transmitter/import.rs index 892d2595c..ffe062ed2 100644 --- a/rust/src/connection/network/transmitter/import.rs +++ b/rust/src/connection/network/transmitter/import.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::{ops::Deref, sync::Arc, thread::sleep, time::Duration}; +use std::{sync::Arc, thread::sleep, time::Duration}; use futures::StreamExt; #[cfg(not(feature = "sync"))] diff --git a/rust/src/connection/network/transmitter/response_sink.rs b/rust/src/connection/network/transmitter/response_sink.rs index 561876e33..4c7d5dd22 100644 --- a/rust/src/connection/network/transmitter/response_sink.rs +++ b/rust/src/connection/network/transmitter/response_sink.rs @@ -20,13 +20,12 @@ use std::{fmt, fmt::Formatter, sync::Arc}; use crossbeam::channel::Sender as SyncOneshotSender; -use itertools::Either; use log::{debug, error}; use tokio::sync::{mpsc::UnboundedSender, oneshot::Sender as AsyncOneshotSender}; use crate::{ common::{RequestID, Result}, - error::{ConnectionError, InternalError}, + error::InternalError, Error, }; diff --git a/rust/src/connection/network/transmitter/transaction.rs b/rust/src/connection/network/transmitter/transaction.rs index 81465efad..8d9ece061 100644 --- a/rust/src/connection/network/transmitter/transaction.rs +++ b/rust/src/connection/network/transmitter/transaction.rs @@ -46,11 +46,10 @@ use tokio::{ }; use tonic::Streaming; use typedb_protocol::transaction::{self, res_part::ResPart, server::Server, stream_signal::res_part::State}; -use uuid::Uuid; #[cfg(feature = "sync")] use super::oneshot_blocking as oneshot; -use super::response_sink::{ImmediateHandler, ResponseSink, StreamResponse}; +use super::response_sink::{ResponseSink, StreamResponse}; use crate::{ common::{ box_promise, @@ -59,10 +58,9 @@ use crate::{ Callback, Promise, RequestID, Result, }, connection::{ - message::{QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, + message::{QueryResponse, TransactionRequest, TransactionResponse}, network::proto::{FromProto, IntoProto, TryFromProto}, runtime::BackgroundRuntime, - server_connection::LatencyTracker, }, Error, }; @@ -71,7 +69,7 @@ pub(in crate::connection) struct TransactionTransmitter { request_sink: UnboundedSender<(TransactionRequest, Option>)>, is_open: Arc>, error: Arc>>, - on_close_register_sink: UnboundedSender) + Send + Sync>>, + on_close_register_sink: UnboundedSender<(Box) + Send + Sync>, UnboundedSender<()>)>, shutdown_sink: UnboundedSender<()>, // runtime is alive as long as the transaction transmitter is alive: background_runtime: Arc, @@ -79,7 +77,11 @@ pub(in crate::connection) struct TransactionTransmitter { impl Drop for TransactionTransmitter { fn drop(&mut self) { - self.force_close(); + // TODO: in the async context, this now returns a promise... do we care? + // ---> Basically, this is now a network round trip operation and can take time + // ---> Decision 1: should drop be blocking like this + // ---> Decision 2: if it should, we need to use async drop for async contexts? or poll? + let _ = self.force_close(); } } @@ -118,15 +120,52 @@ impl TransactionTransmitter { &self.shutdown_sink } - pub(in crate::connection) fn force_close(&self) { + #[cfg(not(feature = "sync"))] + pub(in crate::connection) fn force_close(&self) -> impl Promise<'static, Result<()>> { + if self.is_open.compare_exchange(true, false).is_ok() { + let (closed_sink, mut closed_source) = unbounded_async(); + let close_notifier_callback = Box::new(move |error | { + closed_sink.send(()).unwrap(); + }); + self.on_close(close_notifier_callback); + *self.error.write().unwrap() = Some(ConnectionError::TransactionIsClosed.into()); + self.shutdown_sink.send(()).ok(); + box_promise(async move { + closed_source.await.ok(); + Ok(()) + }) + } else { + box_promise(async move { + Ok(()) + }) + } + } + + #[cfg(feature = "sync")] + pub(in crate::connection) fn force_close(&self) -> impl Promise<'static, Result<()>> { if self.is_open.compare_exchange(true, false).is_ok() { + let (closed_sink, closed_source) = oneshot(); + let close_notifier_callback = Box::new(move |error| { + closed_sink.send(()).unwrap(); + }); + self.on_close(close_notifier_callback); *self.error.write().unwrap() = Some(ConnectionError::TransactionIsClosed.into()); self.shutdown_sink.send(()).ok(); + box_promise(move || { + closed_source.recv().ok(); + Ok(()) + }) + } else { + box_promise(move || { + Ok(()) + }) } } pub(in crate::connection) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) { - self.on_close_register_sink.send(Box::new(callback)).ok(); + let (sender, mut sink) = unbounded_async(); + self.on_close_register_sink.send((Box::new(callback), sender)).ok(); + sink.blocking_recv().expect("Did not receive on_close registration success signal"); } #[cfg(not(feature = "sync"))] @@ -230,7 +269,10 @@ impl TransactionTransmitter { response_source: Streaming, is_open: Arc>, error: Arc>>, - on_close_callback_source: UnboundedReceiver) + Send + Sync>>, + on_close_callback_source: UnboundedReceiver<( + Box) + Send + Sync>, + UnboundedSender<()>, + )>, callback_handler_sink: Sender<(Callback, AsyncOneshotSender<()>)>, shutdown_sink: UnboundedSender<()>, shutdown_signal: UnboundedReceiver<()>, @@ -244,22 +286,20 @@ impl TransactionTransmitter { }; tokio::task::spawn_blocking({ let collector = collector.clone(); - move || { - Self::dispatch_loop(queue_source, request_sink, collector, on_close_callback_source, shutdown_signal) - } + move || Self::sync_dispatch_loop(queue_source, request_sink, collector, shutdown_signal) }); - tokio::spawn(Self::listen_loop(response_source, collector, shutdown_sink)); + tokio::spawn(Self::async_listen_loop(response_source, collector, on_close_callback_source, shutdown_sink)); } - fn dispatch_loop( + const DISPATCH_INTERVAL: Duration = Duration::from_micros(50); + + fn sync_dispatch_loop( mut request_source: UnboundedReceiver<(TransactionRequest, Option>)>, request_sink: UnboundedSender, mut collector: ResponseCollector, - mut on_close_callback_source: UnboundedReceiver) + Send + Sync>>, mut shutdown_signal: UnboundedReceiver<()>, ) { const MAX_GRPC_MESSAGE_LEN: usize = 1_000_000; - const DISPATCH_INTERVAL: Duration = Duration::from_micros(50); let mut request_buffer = TransactionRequestBuffer::default(); loop { @@ -269,11 +309,8 @@ impl TransactionTransmitter { } break; } - if let Ok(callback) = on_close_callback_source.try_recv() { - collector.on_close.write().unwrap().push(callback) - } // sleep, then take all messages off the request queue and dispatch them - sleep(DISPATCH_INTERVAL); + sleep(Self::DISPATCH_INTERVAL); while let Ok(recv) = request_source.try_recv() { let (request, callback) = recv; let request = request.into_proto(); @@ -291,17 +328,31 @@ impl TransactionTransmitter { } } - async fn listen_loop( + async fn async_listen_loop( mut grpc_source: Streaming, collector: ResponseCollector, + mut on_close_callback_source: UnboundedReceiver<( + Box) + Send + Sync>, + UnboundedSender<()>, + )>, shutdown_sink: UnboundedSender<()>, ) { loop { - match grpc_source.next().await { - Some(Ok(message)) => collector.collect(message).await, - Some(Err(status)) => break collector.close_with_error(status.into()).await, - None => break collector.close().await, - } + let result = tokio::select! { biased; + message = grpc_source.next() => { + match message { + Some(Ok(message)) => collector.collect(message).await, + Some(Err(status)) => break collector.close_with_error(status.into()).await, + None => break collector.close().await + } + } + callback_option = on_close_callback_source.recv() => { + if let Some((callback, recorded_signal)) = callback_option { + collector.on_close.write().unwrap().push(callback); + recorded_signal.send(()).expect("Failed to signal back that on_close callback was recorded.") + } + } + }; } shutdown_sink.send(()).ok(); } @@ -438,8 +489,8 @@ impl ResponseCollector { for (_, listener) in listeners.drain() { listener.finish(Ok(TransactionResponse::Close)); } - let callbacks = std::mem::take(&mut *self.on_close.write().unwrap()); - for callback in callbacks { + let on_close_callbacks = std::mem::take(&mut *self.on_close.write().unwrap()); + for callback in on_close_callbacks { let (response_sink, response) = oneshot_async(); self.callback_handler_sink.send((Box::new(move || callback(None)), response_sink)).unwrap(); response.await.ok(); diff --git a/rust/src/connection/server_connection.rs b/rust/src/connection/server_connection.rs index 8e805dd34..1574f66cb 100644 --- a/rust/src/connection/server_connection.rs +++ b/rust/src/connection/server_connection.rs @@ -18,7 +18,6 @@ */ use std::{ - collections::{HashMap, HashSet}, fmt, sync::{ atomic::{AtomicU64, Ordering}, @@ -31,10 +30,10 @@ use tokio::{sync::mpsc::UnboundedSender, time::Instant}; use uuid::Uuid; use crate::{ - common::{address::Address, RequestID}, + common::address::Address, connection::{ database::{export_stream::DatabaseExportStream, import_stream::DatabaseImportStream}, - message::{DatabaseImportRequest, Request, Response, TransactionRequest, TransactionResponse}, + message::{DatabaseImportRequest, Request, Response, TransactionRequest}, network::transmitter::{ DatabaseExportTransmitter, DatabaseImportTransmitter, RPCTransmitter, TransactionTransmitter, }, @@ -43,7 +42,7 @@ use crate::{ }, error::{ConnectionError, InternalError}, info::{DatabaseInfo, UserInfo}, - Credentials, DriverOptions, TransactionOptions, TransactionType, User, + Credentials, DriverOptions, TransactionOptions, TransactionType, }; #[derive(Clone)] diff --git a/rust/src/connection/transaction_stream.rs b/rust/src/connection/transaction_stream.rs index 44aca8275..c3b2f8bf4 100644 --- a/rust/src/connection/transaction_stream.rs +++ b/rust/src/connection/transaction_stream.rs @@ -78,8 +78,8 @@ impl TransactionStream { self.transaction_transmitter.is_open() } - pub(crate) fn force_close(&self) { - self.transaction_transmitter.force_close(); + pub(crate) fn force_close(&self) -> impl Promise<'static, Result<()>> { + self.transaction_transmitter.force_close() } pub(crate) fn type_(&self) -> TransactionType { diff --git a/rust/src/database/database.rs b/rust/src/database/database.rs index 59a315ac2..9c2a51db9 100644 --- a/rust/src/database/database.rs +++ b/rust/src/database/database.rs @@ -25,10 +25,7 @@ use std::{ fs::File, io::{BufWriter, Write}, path::Path, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, RwLock, - }, + sync::{Arc, RwLock}, thread::sleep, time::Duration, }; diff --git a/rust/src/database/database_manager.rs b/rust/src/database/database_manager.rs index 40ca83794..387679c66 100644 --- a/rust/src/database/database_manager.rs +++ b/rust/src/database/database_manager.rs @@ -21,12 +21,11 @@ use std::future::Future; use std::{ collections::HashMap, - io::{BufReader, BufWriter, Cursor, Read}, + io::{BufReader}, path::Path, sync::{Arc, RwLock}, }; -use prost::{decode_length_delimiter, Message}; use typedb_protocol::migration::Item; use super::Database; diff --git a/rust/src/database/migration.rs b/rust/src/database/migration.rs index c86a1c66e..4f4780224 100644 --- a/rust/src/database/migration.rs +++ b/rust/src/database/migration.rs @@ -20,7 +20,7 @@ use std::{ cmp::max, fs::{File, OpenOptions}, - io::{BufRead, Read, Write}, + io::{BufRead, Read}, marker::PhantomData, path::Path, }; diff --git a/rust/src/transaction.rs b/rust/src/transaction.rs index 93c9daccb..b362faa57 100644 --- a/rust/src/transaction.rs +++ b/rust/src/transaction.rs @@ -104,10 +104,11 @@ impl Transaction { /// # Examples /// /// ```rust - /// transaction.force_close() + #[cfg_attr(feature = "sync", doc = "transaction.force_close().resolve()")] + #[cfg_attr(not(feature = "sync"), doc = "transaction.force_close().await")] /// ``` - pub fn force_close(&self) { - self.transaction_stream.force_close(); + pub fn force_close(&self) -> impl Promise<'static, Result<()>> { + self.transaction_stream.force_close() } /// Commits the changes made via this transaction to the TypeDB database. Whether or not the transaction is commited successfully, it gets closed after the commit call. diff --git a/rust/tests/behaviour/steps/connection/database.rs b/rust/tests/behaviour/steps/connection/database.rs index 6a42aa20c..9ed45176b 100644 --- a/rust/tests/behaviour/steps/connection/database.rs +++ b/rust/tests/behaviour/steps/connection/database.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::{collections::HashSet, fs::File, io::Read}; +use std::io::Read; use cucumber::{gherkin::Step, given, then, when}; use futures::{ diff --git a/rust/tests/behaviour/steps/connection/mod.rs b/rust/tests/behaviour/steps/connection/mod.rs index 293a88108..67d85db5c 100644 --- a/rust/tests/behaviour/steps/connection/mod.rs +++ b/rust/tests/behaviour/steps/connection/mod.rs @@ -19,8 +19,6 @@ use cucumber::{given, then, when}; use macro_rules_attribute::apply; -use tokio::time::sleep; -use typedb_driver::{Credentials, TypeDBDriver}; use crate::{assert_with_timeout, generic_step, params, params::check_boolean, Context}; diff --git a/rust/tests/integration/BUILD b/rust/tests/integration/BUILD index 11ba6b476..b3af28ff8 100644 --- a/rust/tests/integration/BUILD +++ b/rust/tests/integration/BUILD @@ -43,6 +43,24 @@ rust_test( ], ) +rust_test( + name = "test_driver", + srcs = ["driver.rs"], + deps = [ + "//rust:typedb_driver", + "@crates//:async-std", + "@crates//:chrono", + "@crates//:futures", + "@crates//:itertools", + "@crates//:regex", + "@crates//:serde_json", + "@crates//:serial_test", + "@crates//:smol", + "@crates//:tokio", + "@crates//:uuid", + ], +) + checkstyle_test( name = "checkstyle", include = glob(["*"]), diff --git a/rust/tests/integration/driver.rs b/rust/tests/integration/driver.rs new file mode 100644 index 000000000..1242b0d7c --- /dev/null +++ b/rust/tests/integration/driver.rs @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + } + , + time::Instant, +}; + +use serial_test::serial; +use typedb_driver::{ + Credentials, DriverOptions, TransactionType, TypeDBDriver, +}; + +// EXAMPLE END MARKER + +async fn cleanup() { + let driver = TypeDBDriver::new( + TypeDBDriver::DEFAULT_ADDRESS, + Credentials::new("admin", "password"), + DriverOptions::new(false, None).unwrap(), + ) + .await + .unwrap(); + if driver.databases().contains("typedb").await.unwrap() { + driver.databases().get("typedb").await.unwrap().delete().await.unwrap(); + } +} + +#[test] +#[serial] +fn transaction_callback() { + async_std::task::block_on(async { + cleanup().await; + let driver = TypeDBDriver::new( + TypeDBDriver::DEFAULT_ADDRESS, + Credentials::new("admin", "password"), + DriverOptions::new(false, None).unwrap(), + ) + .await + .unwrap(); + + driver.databases().create("typedb").await.unwrap(); + let database = driver.databases().get("typedb").await.unwrap(); + assert_eq!(database.name(), "typedb"); + + let close_called = Arc::new(AtomicBool::new(false)); + let transaction = driver.transaction(database.name(), TransactionType::Read).await.unwrap(); + transaction.on_close(Box::new({ + let clone = close_called.clone(); + move |error| { + clone.store(true, Ordering::SeqCst); + } + })); + + drop(transaction); // TODO: drop isn't blocking... so we need to spin? or is there an alternative? + + while !close_called.load(Ordering::Acquire) { + // Yield the current time slice to the OS scheduler. + // This prevents the loop from consuming 100% of a CPU core. + std::thread::yield_now(); + } + assert!(close_called.load(Ordering::SeqCst)) + }) +} diff --git a/rust/tests/integration/example.rs b/rust/tests/integration/example.rs index fee42911a..1706d82a0 100644 --- a/rust/tests/integration/example.rs +++ b/rust/tests/integration/example.rs @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +mod driver; + // EXAMPLE START MARKER use std::time::Duration; diff --git a/rust/tests/integration/mod.rs b/rust/tests/integration/mod.rs index 22e8bfb34..f4890ab29 100644 --- a/rust/tests/integration/mod.rs +++ b/rust/tests/integration/mod.rs @@ -18,4 +18,3 @@ */ mod cluster; -mod cluster; From a87a9d77de14898776ecd4c79d7a4a087fedb429 Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 12:31:55 -0700 Subject: [PATCH 2/8] Trigger CI From df3ba3aa90c63bb1a4bb27e87a5f55d43229bf73 Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 14:48:22 -0700 Subject: [PATCH 3/8] Deploy snapshot --- .circleci/config.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 43e6d2f36..588801a09 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -850,32 +850,32 @@ workflows: - deploy-snapshot-linux-arm64: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] - deploy-snapshot-linux-x86_64: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] - deploy-snapshot-mac-arm64: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] - deploy-snapshot-mac-x86_64: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] - deploy-snapshot-windows-x86_64: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] - deploy-snapshot-any: filters: branches: - only: [development, master, "3.0"] + only: [development, master, improve-on-close-callback] requires: - deploy-snapshot-linux-arm64 - deploy-snapshot-linux-x86_64 From b1b3f5b8346bd4aac82e5e9fecd7d67c6526c559 Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 14:59:24 -0700 Subject: [PATCH 4/8] Fix build --- rust/src/connection/network/transmitter/transaction.rs | 2 +- rust/tests/behaviour/steps/connection/transaction.rs | 2 +- rust/tests/behaviour/steps/lib.rs | 4 ++-- rust/tests/integration/example.rs | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/src/connection/network/transmitter/transaction.rs b/rust/src/connection/network/transmitter/transaction.rs index 8d9ece061..3f488c940 100644 --- a/rust/src/connection/network/transmitter/transaction.rs +++ b/rust/src/connection/network/transmitter/transaction.rs @@ -131,7 +131,7 @@ impl TransactionTransmitter { *self.error.write().unwrap() = Some(ConnectionError::TransactionIsClosed.into()); self.shutdown_sink.send(()).ok(); box_promise(async move { - closed_source.await.ok(); + closed_source.recv().await; Ok(()) }) } else { diff --git a/rust/tests/behaviour/steps/connection/transaction.rs b/rust/tests/behaviour/steps/connection/transaction.rs index ede25f09c..23e0684d2 100644 --- a/rust/tests/behaviour/steps/connection/transaction.rs +++ b/rust/tests/behaviour/steps/connection/transaction.rs @@ -156,7 +156,7 @@ pub async fn transaction_commits(context: &mut Context, may_error: params::MayEr #[apply(generic_step)] #[step(expr = "transaction closes")] pub async fn transaction_closes(context: &mut Context) { - context.take_transaction().force_close(); + context.take_transaction().force_close().await.ok(); } #[apply(generic_step)] diff --git a/rust/tests/behaviour/steps/lib.rs b/rust/tests/behaviour/steps/lib.rs index a33b5412f..2117c67fb 100644 --- a/rust/tests/behaviour/steps/lib.rs +++ b/rust/tests/behaviour/steps/lib.rs @@ -234,13 +234,13 @@ impl Context { pub async fn cleanup_transactions(&mut self) { while let Some(transaction) = self.try_take_transaction() { - transaction.force_close(); + transaction.force_close().await.ok(); } } pub async fn cleanup_background_transactions(&mut self) { while let Some(background_transaction) = self.try_take_background_transaction() { - background_transaction.force_close(); + background_transaction.force_close().await.ok(); } } diff --git a/rust/tests/integration/example.rs b/rust/tests/integration/example.rs index 1706d82a0..b59914267 100644 --- a/rust/tests/integration/example.rs +++ b/rust/tests/integration/example.rs @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -mod driver; // EXAMPLE START MARKER use std::time::Duration; From ebf41f51970a0352625b778c784066eab0c0d6cf Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 15:09:07 -0700 Subject: [PATCH 5/8] Include mutex --- c/typedb_driver.i | 1 + 1 file changed, 1 insertion(+) diff --git a/c/typedb_driver.i b/c/typedb_driver.i index 19cfc1511..2cc43181b 100644 --- a/c/typedb_driver.i +++ b/c/typedb_driver.i @@ -117,6 +117,7 @@ struct TransactionCallbackDirector { %{ #include #include +#include #include class ThreadSafeTransactionCallbacks { From cbf1d843afdcd842fa9a3e4bc59bfa732314020e Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 18:55:49 -0700 Subject: [PATCH 6/8] Fix drop behaviour --- .../network/transmitter/transaction.rs | 73 ++++++++++--------- rust/src/connection/transaction_stream.rs | 8 +- rust/src/database/migration.rs | 2 +- rust/src/transaction.rs | 4 +- rust/tests/integration/example.rs | 6 +- 5 files changed, 52 insertions(+), 41 deletions(-) diff --git a/rust/src/connection/network/transmitter/transaction.rs b/rust/src/connection/network/transmitter/transaction.rs index 3f488c940..7faee4000 100644 --- a/rust/src/connection/network/transmitter/transaction.rs +++ b/rust/src/connection/network/transmitter/transaction.rs @@ -36,34 +36,26 @@ use log::{debug, error}; use prost::Message; #[cfg(not(feature = "sync"))] use tokio::sync::oneshot::channel as oneshot; -use tokio::{ - select, - sync::{ - mpsc::{error::SendError, unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, - oneshot::{channel as oneshot_async, Sender as AsyncOneshotSender}, - }, - time::{sleep_until, Instant}, -}; +use tokio::{select, sync::{ + mpsc::{error::SendError, unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + oneshot::{channel as oneshot_async, Sender as AsyncOneshotSender}, +}, task, time::{sleep_until, Instant}}; use tonic::Streaming; use typedb_protocol::transaction::{self, res_part::ResPart, server::Server, stream_signal::res_part::State}; #[cfg(feature = "sync")] use super::oneshot_blocking as oneshot; use super::response_sink::{ResponseSink, StreamResponse}; -use crate::{ - common::{ - box_promise, - error::ConnectionError, - stream::{NetworkStream, Stream}, - Callback, Promise, RequestID, Result, - }, - connection::{ - message::{QueryResponse, TransactionRequest, TransactionResponse}, - network::proto::{FromProto, IntoProto, TryFromProto}, - runtime::BackgroundRuntime, - }, - Error, -}; +use crate::{common::{ + box_promise, + error::ConnectionError, + stream::{NetworkStream, Stream}, + Callback, Promise, RequestID, Result, +}, connection::{ + message::{QueryResponse, TransactionRequest, TransactionResponse}, + network::proto::{FromProto, IntoProto, TryFromProto}, + runtime::BackgroundRuntime, +}, resolve, Error}; pub(in crate::connection) struct TransactionTransmitter { request_sink: UnboundedSender<(TransactionRequest, Option>)>, @@ -77,11 +69,10 @@ pub(in crate::connection) struct TransactionTransmitter { impl Drop for TransactionTransmitter { fn drop(&mut self) { - // TODO: in the async context, this now returns a promise... do we care? - // ---> Basically, this is now a network round trip operation and can take time - // ---> Decision 1: should drop be blocking like this - // ---> Decision 2: if it should, we need to use async drop for async contexts? or poll? - let _ = self.force_close(); + // fire and forget shutdown + if self.is_open.compare_exchange(true, false).is_ok() { + self.shutdown_sink.send(()).ok(); + } } } @@ -124,13 +115,15 @@ impl TransactionTransmitter { pub(in crate::connection) fn force_close(&self) -> impl Promise<'static, Result<()>> { if self.is_open.compare_exchange(true, false).is_ok() { let (closed_sink, mut closed_source) = unbounded_async(); - let close_notifier_callback = Box::new(move |error | { + let close_notifier_callback = Box::new(move |error| { closed_sink.send(()).unwrap(); }); - self.on_close(close_notifier_callback); + let on_close_submit_promise = self.on_close(close_notifier_callback); *self.error.write().unwrap() = Some(ConnectionError::TransactionIsClosed.into()); - self.shutdown_sink.send(()).ok(); + let shutdown_sink = self.shutdown_sink.clone(); box_promise(async move { + resolve!(on_close_submit_promise); + shutdown_sink.send(()).ok(); closed_source.recv().await; Ok(()) }) @@ -142,13 +135,13 @@ impl TransactionTransmitter { } #[cfg(feature = "sync")] - pub(in crate::connection) fn force_close(&self) -> impl Promise<'static, Result<()>> { + pub(in crate::connection) fn force_close(&self) -> impl Promise<'_, Result<()>> { if self.is_open.compare_exchange(true, false).is_ok() { let (closed_sink, closed_source) = oneshot(); let close_notifier_callback = Box::new(move |error| { closed_sink.send(()).unwrap(); }); - self.on_close(close_notifier_callback); + resolve!(self.on_close(close_notifier_callback)); *self.error.write().unwrap() = Some(ConnectionError::TransactionIsClosed.into()); self.shutdown_sink.send(()).ok(); box_promise(move || { @@ -162,10 +155,22 @@ impl TransactionTransmitter { } } - pub(in crate::connection) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) { + #[cfg(not(feature = "sync"))] + pub(in crate::connection) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) -> impl Promise<'static, ()> { + let (sender, mut sink) = unbounded_async(); + self.on_close_register_sink.send((Box::new(callback), sender)).ok(); + box_promise(async move { + sink.recv().await.expect("Did not receive on_close registration success signal"); + }) + } + + #[cfg(feature = "sync")] + pub(in crate::connection) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) -> impl Promise<'static, ()> { let (sender, mut sink) = unbounded_async(); self.on_close_register_sink.send((Box::new(callback), sender)).ok(); - sink.blocking_recv().expect("Did not receive on_close registration success signal"); + box_promise(move || { + sink.blocking_recv().expect("Did not receive on_close registration success signal"); + }) } #[cfg(not(feature = "sync"))] diff --git a/rust/src/connection/transaction_stream.rs b/rust/src/connection/transaction_stream.rs index c3b2f8bf4..31c3a96fd 100644 --- a/rust/src/connection/transaction_stream.rs +++ b/rust/src/connection/transaction_stream.rs @@ -78,8 +78,10 @@ impl TransactionStream { self.transaction_transmitter.is_open() } - pub(crate) fn force_close(&self) -> impl Promise<'static, Result<()>> { - self.transaction_transmitter.force_close() + pub(crate) fn force_close(&self) -> impl Promise<'_, Result<()>> { + promisify! { + resolve!(self.transaction_transmitter.force_close()) + } } pub(crate) fn type_(&self) -> TransactionType { @@ -90,7 +92,7 @@ impl TransactionStream { self.options } - pub(crate) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) { + pub(crate) fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) -> impl Promise<'static, ()> { self.transaction_transmitter.on_close(callback) } diff --git a/rust/src/database/migration.rs b/rust/src/database/migration.rs index 4f4780224..c572da184 100644 --- a/rust/src/database/migration.rs +++ b/rust/src/database/migration.rs @@ -20,7 +20,7 @@ use std::{ cmp::max, fs::{File, OpenOptions}, - io::{BufRead, Read}, + io::{BufRead}, marker::PhantomData, path::Path, }; diff --git a/rust/src/transaction.rs b/rust/src/transaction.rs index b362faa57..84b739ebe 100644 --- a/rust/src/transaction.rs +++ b/rust/src/transaction.rs @@ -95,7 +95,7 @@ impl Transaction { /// ```rust /// transaction.on_close(function) /// ``` - pub fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) { + pub fn on_close(&self, callback: impl FnOnce(Option) + Send + Sync + 'static) -> impl Promise<'static, ()> { self.transaction_stream.on_close(callback) } @@ -107,7 +107,7 @@ impl Transaction { #[cfg_attr(feature = "sync", doc = "transaction.force_close().resolve()")] #[cfg_attr(not(feature = "sync"), doc = "transaction.force_close().await")] /// ``` - pub fn force_close(&self) -> impl Promise<'static, Result<()>> { + pub fn force_close(&self) -> impl Promise<'_, Result<()>> { self.transaction_stream.force_close() } diff --git a/rust/tests/integration/example.rs b/rust/tests/integration/example.rs index b59914267..b6932e895 100644 --- a/rust/tests/integration/example.rs +++ b/rust/tests/integration/example.rs @@ -44,7 +44,11 @@ async fn cleanup() { .await .unwrap(); if driver.databases().contains("typedb").await.unwrap() { - driver.databases().get("typedb").await.unwrap().delete().await.unwrap(); + println!("Confirmed DB contains, going to get..."); + let db = driver.databases().get("typedb").await.unwrap(); + println!("Got DB"); + db.delete().await.unwrap(); + println!("Deleted db"); } } From 9e94c901d7bcffe0adfaa9e8fefe2215884555c3 Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 18:57:44 -0700 Subject: [PATCH 7/8] Trigger CI From a3bbf3a897b060054cb3d64cca65d621dbd5b58d Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 30 Sep 2025 19:06:49 -0700 Subject: [PATCH 8/8] Use built-in Python 3.11 in mac jobs --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 588801a09..96f4ff0f1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -100,7 +100,7 @@ commands: type: string steps: - run: | - brew install python@3.9 +# brew install python@3.9 curl -OL "https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-darwin-<>" sudo mv "bazelisk-darwin-<>" /usr/local/bin/bazel chmod a+x /usr/local/bin/bazel @@ -149,7 +149,7 @@ commands: steps: - install-brew-rosetta - run: | - /usr/local/bin/brew install python@3.9 +# /usr/local/bin/brew install python@3.9 tool/test/start-community-server.sh /usr/local/bin/python3.9 -m pip install wheel /usr/local/bin/python3.9 -m pip install pip==21.3.1