@@ -972,11 +972,12 @@ def _get_data_series(self, call_func: str, symbol_list: Union[str, List[str]], d
972972 if adj_type not in [None , "F" , "B" ]:
973973 raise Exception ("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权) " )
974974 ds = DataSeries (self , symbol_list , dur_nano , start_dt_nano , end_dt_nano , adj_type )
975- while not self ._loop .is_running () and not ds .is_ready :
976- deadline = time .time () + 30
977- if not self .wait_update (deadline = deadline ):
978- raise TqTimeoutError (
979- f"{ call_func } 获取数据 ({ symbol_list , duration_seconds , start_dt , end_dt } ) 超时,请检查客户端及网络是否正常。" )
975+ if not self ._loop .is_running ():
976+ while not ds ._task .done ():
977+ deadline = time .time () + 30
978+ if not self .wait_update (deadline = deadline , _task = ds ._task ):
979+ raise TqTimeoutError (
980+ f"{ call_func } 获取数据 ({ symbol_list , duration_seconds , start_dt , end_dt } ) 超时,请检查客户端及网络是否正常。" )
980981 return ds .df
981982
982983 # ----------------------------------------------------------------------
@@ -1811,13 +1812,12 @@ def set_risk_management_rule(self, exchange_id: str, enable: bool, count_limit:
18111812 rule = _get_obj (self ._data , ["trade" , self ._account ._get_account_key (account ), "risk_management_rule" , exchange_id ], RiskManagementRule (self ))
18121813 if not self ._loop .is_running ():
18131814 deadline = time .time () + 30
1814- while not (rule_pack ['enable' ] == rule ['enable' ]
1815- and rule_pack ['self_trade' ].items () <= rule ['self_trade' ].items ()
1816- and rule_pack ['frequent_cancellation' ].items () <= rule ['frequent_cancellation' ].items ()
1817- and rule_pack ['trade_position_ratio' ].items () <= rule ['trade_position_ratio' ].items ()):
1818- # @todo: merge diffs
1819- if not self .wait_update (deadline = deadline ):
1820- raise TqTimeoutError ("设置风控规则超时请检查客户端及网络是否正常" )
1815+ cond = lambda : (rule_pack ['enable' ] == rule ['enable' ]
1816+ and rule_pack ['self_trade' ].items () <= rule ['self_trade' ].items ()
1817+ and rule_pack ['frequent_cancellation' ].items () <= rule ['frequent_cancellation' ].items ()
1818+ and rule_pack ['trade_position_ratio' ].items () <= rule ['trade_position_ratio' ].items ())
1819+ if not self ._wait_update_until (cond = cond , deadline = deadline ):
1820+ raise TqTimeoutError ("设置风控规则超时请检查客户端及网络是否正常" )
18211821 return rule
18221822
18231823 # ----------------------------------------------------------------------
@@ -1933,6 +1933,38 @@ def wait_update(self, deadline: Optional[float] = None, _task: Union[asyncio.Tas
19331933 else : # 订阅多个合约
19341934 self ._update_serial_multi (serial )
19351935
1936+ def _wait_update_until (self , cond : Callable [[], bool ], deadline : Optional [float ] = None ) -> bool :
1937+ """
1938+ TqApi 内部使用,用于等待某个条件满足。持续调用 wait_update(),直到 cond() 返回 True。
1939+
1940+ Args:
1941+ cond (Callable[[], bool]): 条件函数
1942+ deadline (float): [可选]指定截止时间,自unix epoch(1970-01-01 00:00:00 GMT)以来的秒数(time.time())。默认没有超时(无限等待)
1943+
1944+ Returns:
1945+ bool: 当 cond() 为 True 时返回 True, 如果到截止时间 cond() 依然为 False 则返回 False
1946+
1947+ 注:用于 tqsdk 内部,某些地方会用到 api.wait_update(),等待数据更新后再返回给用户,
1948+ * 简单调用 wait_update() 导致 api._sync_diffs 丢失变更
1949+ * 为了避免这种情况,内部调用 wait_update() 应该传入 _task 参数,这样 api._sync_diffs 不会丢失变更
1950+ """
1951+ if cond ():
1952+ return True
1953+
1954+ async def _async_wait_task ():
1955+ async with self .register_update_notify () as update_chan :
1956+ async for _ in update_chan :
1957+ if cond ():
1958+ break
1959+
1960+ _task = self .create_task (_async_wait_task ())
1961+
1962+ while not cond ():
1963+ data_updated = self .wait_update (deadline = deadline , _task = _task )
1964+ if data_updated is False :
1965+ return False # TimeoutError
1966+ return True
1967+
19361968 # ----------------------------------------------------------------------
19371969 def is_changing (self , obj : Any , key : Union [str , List [str ], None ] = None ) -> bool :
19381970 """
@@ -2226,9 +2258,9 @@ def query_graphql(self, query: str, variables: dict, query_id: Optional[str] = N
22262258 })
22272259 deadline = time .time () + 60
22282260 if not self ._loop .is_running ():
2229- while query_id not in symbols :
2230- if not self . wait_update ( deadline = deadline ):
2231- raise TqTimeoutError ("查询合约服务 %s 超时,请检查客户端及网络是否正常 %s" % (query , query_id ))
2261+ if not self . _wait_update_until ( cond = lambda : query_id in symbols , deadline = deadline ) :
2262+ # 使用 _task 参数,确保不会丢掉 _sync_diffs 里的变更
2263+ raise TqTimeoutError ("查询合约服务 %s 超时,请检查客户端及网络是否正常 %s" % (query , query_id ))
22322264 if isinstance (self ._backtest , TqBacktest ):
22332265 self ._send_pack ({
22342266 "aid" : "ins_query" ,
0 commit comments