Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
ignore = E501
200 changes: 200 additions & 0 deletions examples/event_streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""
Example usage of the EventStreamer class

NOTE: quick and dirty implementation
"""

import asyncio
import os
from datetime import datetime
from decimal import Decimal
from typing import TypeVar

import dotenv

from tastytrade.dxfeed import (
Candle,
Greeks,
Profile,
Quote,
Summary,
TheoPrice,
TimeAndSale,
Trade,
Underlying,
)
from tastytrade.session import Session
from tastytrade.streamer import EventStreamer

dotenv.load_dotenv()

session = Session(
provider_secret=os.getenv("TT_API_CLIENT_SECRET", ""),
refresh_token=os.getenv("TT_REFRESH_TOKEN", ""),
)


E = TypeVar(
"E",
Candle,
Greeks,
Profile,
Quote,
Summary,
TheoPrice,
TimeAndSale,
Trade,
Underlying,
)


class OHLCVBar:
def __init__(self, symbol: str, open_price: Decimal, volume: int | None, timestamp: datetime):
self.symbol: str = symbol
self.open: Decimal = open_price
self.high: Decimal = open_price
self.low: Decimal = open_price
self.close: Decimal = open_price
self.volume: int | None = volume
self.tick_count: int = 1
self.timestamp: datetime = timestamp

def update(self, price: Decimal, volume: int | None) -> None:
self.high = max(self.high, price)
self.low = min(self.low, price)
self.close = price
self.volume = volume if volume is not None else self.volume
self.tick_count += 1

def __str__(self) -> str:
return (
f"OHLCVBar(symbol={self.symbol}, open={self.open}, high={self.high}, low={self.low},"
f" close={self.close}, volume={self.volume}, "
f"tick_count={self.tick_count}, timestamp={self.timestamp})"
)

def __repr__(self) -> str:
return self.__str__()


class OHLCVBars:
def __init__(self, timeframe: str, symbol: str):
self.ohlcv_bars: list[OHLCVBar] = []
self.current_bar: OHLCVBar | None = None
self.timeframe: str = timeframe
self.symbol: str = symbol

def update(self, event: Trade) -> None:
ts_ms = event.time

# --- TICK BARS ---
if self.timeframe.endswith("t"):
n = int(self.timeframe[:-1]) # "200t" -> 200 (also works for "1t")

if self.current_bar is None:
self.current_bar = OHLCVBar(self.symbol, event.price, event.size, timestamp=datetime.fromtimestamp(ts_ms / 1000))
else:
self.current_bar.update(event.price, event.size)

# close AFTER update so tick_count includes this trade
if self.current_bar.tick_count >= n:
self.ohlcv_bars.append(self.current_bar)
# start a new bar on next tick (or immediately if you prefer)
self.current_bar = None
print(self.ohlcv_bars[-1])
return

# --- TIME BARS (bucketed) ---
bucket_ts = self.bucket_timestamp(ts_ms, self.timeframe)

if self.current_bar is None:
self.current_bar = OHLCVBar(
self.symbol, event.price, event.size, timestamp=bucket_ts
)
return

# If this tick belongs to a new time bucket, finalize old bar and start new one
if bucket_ts != self.current_bar.timestamp:
self.ohlcv_bars.append(self.current_bar)
self.current_bar = OHLCVBar(
self.symbol, event.price, event.size, timestamp=bucket_ts
)
print(self.ohlcv_bars[-1])
else:
self.current_bar.update(event.price, event.size)

@staticmethod
def bucket_timestamp(ts_ms: int, timeframe: str) -> datetime:
# floor ts into the timeframe bucket so you never get two candles for the
# same bucket
if timeframe == "5s":
start = (ts_ms // 5_000) * 5_000
elif timeframe == "1m":
start = (ts_ms // 60_000) * 60_000
elif timeframe == "5m":
start = (ts_ms // 300_000) * 300_000
elif timeframe == "15m":
start = (ts_ms // 900_000) * 900_000
elif timeframe == "30m":
start = (ts_ms // 1_800_000) * 1_800_000
elif timeframe == "1h":
start = (ts_ms // 3_600_000) * 3_600_000
elif timeframe == "4h":
start = (ts_ms // 14_400_000) * 14_400_000
elif timeframe == "1d":
start = (ts_ms // 86_400_000) * 86_400_000
elif timeframe == "1w":
start = (ts_ms // 604_800_000) * 604_800_000
else:
raise ValueError(f"Invalid timeframe: {timeframe}")

return datetime.fromtimestamp(start / 1000)


ohlcv_bars_dict: dict[str, OHLCVBars] = {}


def ohlcv_bars(symbols: list[str]) -> None:
global ohlcv_bars_list
for symbol in symbols:
ohlcv_bars_dict[symbol] = OHLCVBars("5s", symbol)


async def update_ohlcv_bars(trade: Trade) -> None:
global ohlcv_bars_dict
ohlcv_bars_dict[trade.event_symbol].update(trade)


async def handle_quote(quote: Quote) -> None:
print(f"QUOTE: {quote}")


async def monitor_tasks(tasks: list[asyncio.Task[None]], stop_time: int) -> None:
"""
Monitor the tasks and stop them after the given time
"""
await asyncio.sleep(stop_time)
for t in tasks:
t.cancel()


async def main() -> None:
quote_streamer = EventStreamer(session, ["SPY"], Quote, handle_quote)
symbols = ["SPY", "AAPL"]
ohlcv_bars(symbols=symbols) # initialize the ohlcv bars in a global dictionary

ohlcv_bars_streamers = [EventStreamer(session, [symbol], Trade, update_ohlcv_bars) for symbol in symbols]

tasks: list[asyncio.Task[None]] = []
tasks.append(asyncio.create_task(quote_streamer.start()))
for t in ohlcv_bars_streamers:
tasks.append(asyncio.create_task(t.start()))

try:
await asyncio.gather(*tasks, monitor_tasks(tasks, 20))
except asyncio.CancelledError:
pass


if __name__ == "__main__":
asyncio.run(main())
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"httpx>=0.28.1",
"pandas-market-calendars>=5.1.1",
"pydantic>=2.11.9",
"python-dotenv>=1.0.0",
"websockets>=15.0.1",
]
dynamic = ["version"]
Expand Down
52 changes: 52 additions & 0 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,3 +884,55 @@ async def _map_message(self, message: list[Any]) -> None:
results = cls.from_stream(data)
for r in results:
await self._queues[msg_type].put(r) # type: ignore


class EventStreamer:
"""
Utility class to streamline the process of subscribing to events and
calling a callback when an event is received.

NOTE: Instantiate an instance of EventStreamer for each event class
you want to stream

NOTE: stop the streamer by cancelling the task
"""

def __init__(
self,
session: Session,
symbols: list[str],
event_class: type[U],
callback: Callable[[U], Coroutine[Any, Any, None]],
):
"""
Initializes the event streamer

:param session: The session to use for the streamer
:param symbols: The symbols to subscribe to
:param event_class: The event class to subscribe to
:param callback: The callback to call when an event is received
"""

self._session = session
self._symbols = symbols
self._event_class = event_class
# callback can be set to different function by consumer after initialization
self.callback = callback

async def start(self) -> None:
"""
starts the streamer and subscribes to the given symbols and calls
the callback when an event is received
"""

if self.callback is None:
raise TastytradeError("No callback provided")

self._stop = False

async with DXLinkStreamer(self._session) as streamer:
await streamer.subscribe(self._event_class, self._symbols) # type: ignore
async for event in streamer.listen(self._event_class):
await self.callback(event) # type: ignore

self._stop = False
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import date, datetime
from unittest.mock import patch
from zoneinfo import ZoneInfo

from tastytrade.utils import (
TZ,
get_future_fx_monthly,
get_future_grain_monthly,
get_future_index_monthly,
Expand All @@ -15,6 +15,8 @@
today_in_new_york,
)

TZ = ZoneInfo("US/Eastern")


def test_get_third_friday():
assert get_third_friday(date(2024, 3, 2)) == date(2024, 3, 15)
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading