Skip to content

Commit 3bdfcd1

Browse files
antesajhlee12Chinmay Gore
authored
TabPy Arrow Support (#595)
* Adds Arrow Flight support via configuration parameter. When enabled, sets up a flight server that processes requests with certain payloads. Processing incoming data in batches is supported. * Adds Flight documentation to relevant READMEs. * Adds Arrow server integration tests. * Leverages existing auth infrastructure to apply auth to the Flight connection. * Retains backward compatibility with existing functionality, whether or not Flight is enabled. --------- Co-authored-by: Ho Sanlok Lee <[email protected]> Co-authored-by: Chinmay Gore <[email protected]>
1 parent fad6807 commit 3bdfcd1

File tree

12 files changed

+464
-9
lines changed

12 files changed

+464
-9
lines changed

docs/server-config.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ at [`logging.config` documentation page](https://docs.python.org/3.6/library/log
9595
through the `/query` method, or using the `tabpy.query(...)` syntax with
9696
the `/evaluate` method.
9797
- `TABPY_GZIP_ENABLE` - Enable Gzip support for requests. Enabled by default.
98+
- `TABPY_ARROW_ENABLE` - Enable Arrow connection for data streaming. Default
99+
value is False.
100+
- `TABPY_ARROWFLIGHT_PORT` - port for
101+
[Arrow Flight](https://arrow.apache.org/docs/format/Flight.html)
102+
connection used in streaming mode. Default value is 13622.
98103

99104
### Configuration File Example
100105

@@ -138,6 +143,13 @@ settings._
138143
# The value should be a float representing the timeout time in seconds.
139144
# TABPY_EVALUATE_TIMEOUT = 30
140145

146+
# Configure TabPy to support streaming data via Arrow Flight.
147+
# This will cause an Arrow Flight server start up. The Arrow
148+
# Flight port defaults to 13622 if not set here.
149+
# TABPY_ARROW_ENABLE = True
150+
# TABPY_ARROWFLIGHT_PORT = 13622
151+
152+
141153
[loggers]
142154
keys=root
143155

@@ -257,6 +269,33 @@ line with the user name.
257269

258270
All endpoints require authentication if it is enabled for the server.
259271

272+
## Arrow Flight
273+
274+
TabPy can be configured to enable Arrow Flight. This will cause a Flight
275+
server to start up alongside the HTTP server and will allow for handling
276+
incoming streamed data in the Arrow columnar format.
277+
278+
**As of May 2023, the Arrow Flight feature can only be used by compatible
279+
versions of Tableau Prep. The Arrow Flight feature is not used by Tableau
280+
Desktop, Tableau Server, or Tableau Cloud, regardless of the
281+
`TABPY_ARROW_ENABLE` setting. In other words, those products will continue
282+
to send the data in a single payload when Arrow Flight is both enabled
283+
and disabled.**
284+
285+
To leverage the Flight server, use an existing Flight Client API. There
286+
are implementations available in C++, Java, and Python. To begin streaming
287+
data to the server, a Flight Descriptor (data path) must be generated.
288+
One can be obtained via the TabPy Flight server by using the client to
289+
submit a `getUniquePath` Action to the server or it can be randomly generated
290+
locally. The client's `do_put` interface can then be used to begin sending
291+
data to the server.
292+
293+
Structure the data payload in Arrow format according to the client's API
294+
requirements. Continue using the client to append the data path with the
295+
data stream.
296+
297+
The mechanism for sending the Python script to the server does not change.
298+
260299
## Logging
261300

262301
Logging for TabPy is implemented with Python's standard logger and can be configured

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def read(fname):
9090
"tornado",
9191
"twisted",
9292
"urllib3",
93+
"pyarrow",
9394
],
9495
entry_points={
9596
"console_scripts": [

tabpy/tabpy_server/app/app.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from tabpy.tabpy import __version__
1111
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
1212
from tabpy.tabpy_server.app.util import parse_pwd_file
13+
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
14+
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
1315
from tabpy.tabpy_server.management.state import TabPyState
1416
from tabpy.tabpy_server.management.util import _get_state_from_file
1517
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
@@ -25,11 +27,11 @@
2527
UploadDestinationHandler,
2628
)
2729
import tornado
28-
30+
import tabpy.tabpy_server.app.arrow_server as pa
31+
import _thread
2932

3033
logger = logging.getLogger(__name__)
3134

32-
3335
def _init_asyncio_patch():
3436
"""
3537
Select compatible event loop for Tornado 5+.
@@ -59,6 +61,7 @@ class TabPyApp:
5961
tabpy_state = None
6062
python_service = None
6163
credentials = {}
64+
arrow_server = None
6265

6366
def __init__(self, config_file):
6467
if config_file is None:
@@ -75,6 +78,42 @@ def __init__(self, config_file):
7578

7679
self._parse_config(config_file)
7780

81+
def _get_tls_certificates(self, config):
82+
tls_certificates = []
83+
cert = config[SettingsParameters.CertificateFile]
84+
key = config[SettingsParameters.KeyFile]
85+
with open(cert, "rb") as cert_file:
86+
tls_cert_chain = cert_file.read()
87+
with open(key, "rb") as key_file:
88+
tls_private_key = key_file.read()
89+
tls_certificates.append((tls_cert_chain, tls_private_key))
90+
return tls_certificates
91+
92+
def _get_arrow_server(self, config):
93+
verify_client = None
94+
tls_certificates = None
95+
scheme = "grpc+tcp"
96+
if config[SettingsParameters.TransferProtocol] == "https":
97+
scheme = "grpc+tls"
98+
tls_certificates = self._get_tls_certificates(config)
99+
100+
host = "localhost"
101+
port = config.get(SettingsParameters.ArrowFlightPort)
102+
location = "{}://{}:{}".format(scheme, host, port)
103+
104+
auth_middleware = None
105+
if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
106+
_, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
107+
auth_middleware = {
108+
"basic": BasicAuthServerMiddlewareFactory(creds)
109+
}
110+
111+
server = pa.FlightServer(host, location,
112+
tls_certificates=tls_certificates,
113+
verify_client=verify_client, auth_handler=NoOpAuthHandler(),
114+
middleware=auth_middleware)
115+
return server
116+
78117
def run(self):
79118
application = self._create_tornado_web_app()
80119
max_request_size = (
@@ -99,18 +138,30 @@ def run(self):
99138
settings = {}
100139
if self.settings[SettingsParameters.GzipEnabled] is True:
101140
settings["decompress_request"] = True
141+
102142
application.listen(
103143
self.settings[SettingsParameters.Port],
104144
ssl_options=ssl_options,
105145
max_buffer_size=max_request_size,
106146
max_body_size=max_request_size,
107147
**settings,
108-
)
148+
)
109149

110150
logger.info(
111151
"Web service listening on port "
112152
f"{str(self.settings[SettingsParameters.Port])}"
113153
)
154+
155+
if self.settings[SettingsParameters.ArrowEnabled]:
156+
def start_pyarrow():
157+
self.arrow_server = self._get_arrow_server(self.settings)
158+
pa.start(self.arrow_server)
159+
160+
try:
161+
_thread.start_new_thread(start_pyarrow, ())
162+
except Exception as e:
163+
logger.critical(f"Failed to start PyArrow server: {e}")
164+
114165
tornado.ioloop.IOLoop.instance().start()
115166

116167
def _create_tornado_web_app(self):
@@ -287,6 +338,8 @@ def _parse_config(self, config_file):
287338
100, None),
288339
(SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
289340
True, parser.getboolean),
341+
(SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE, False, parser.getboolean),
342+
(SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT, 13622, parser.getint),
290343
]
291344

292345
for setting, parameter, default_val, parse_function in settings_parameters:
@@ -424,6 +477,7 @@ def _get_features(self):
424477

425478
features["evaluate_enabled"] = self.settings[SettingsParameters.EvaluateEnabled]
426479
features["gzip_enabled"] = self.settings[SettingsParameters.GzipEnabled]
480+
features["arrow_enabled"] = self.settings[SettingsParameters.ArrowEnabled]
427481
return features
428482

429483
def _build_tabpy_state(self):

tabpy/tabpy_server/app/app_parameters.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@ class ConfigParameters:
1010
TABPY_TRANSFER_PROTOCOL = "TABPY_TRANSFER_PROTOCOL"
1111
TABPY_CERTIFICATE_FILE = "TABPY_CERTIFICATE_FILE"
1212
TABPY_KEY_FILE = "TABPY_KEY_FILE"
13-
TABPY_PWD_FILE = "TABPY_PWD_FILE"
1413
TABPY_LOG_DETAILS = "TABPY_LOG_DETAILS"
1514
TABPY_STATIC_PATH = "TABPY_STATIC_PATH"
1615
TABPY_MAX_REQUEST_SIZE_MB = "TABPY_MAX_REQUEST_SIZE_MB"
1716
TABPY_EVALUATE_ENABLE = "TABPY_EVALUATE_ENABLE"
1817
TABPY_EVALUATE_TIMEOUT = "TABPY_EVALUATE_TIMEOUT"
1918
TABPY_GZIP_ENABLE = "TABPY_GZIP_ENABLE"
2019

20+
# Arrow specific settings
21+
TABPY_ARROW_ENABLE = "TABPY_ARROW_ENABLE"
22+
TABPY_ARROWFLIGHT_PORT = "TABPY_ARROWFLIGHT_PORT"
23+
2124

2225
class SettingsParameters:
2326
"""
@@ -38,3 +41,7 @@ class SettingsParameters:
3841
EvaluateTimeout = "evaluate_timeout"
3942
EvaluateEnabled = "evaluate_enabled"
4043
GzipEnabled = "gzip_enabled"
44+
45+
# Arrow specific settings
46+
ArrowEnabled = "arrow_enabled"
47+
ArrowFlightPort = "arrowflight_port"
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
19+
import ast
20+
import logging
21+
import threading
22+
import time
23+
import uuid
24+
25+
import pyarrow
26+
import pyarrow.flight
27+
28+
29+
logger = logging.getLogger('__main__.' + __name__)
30+
31+
class FlightServer(pyarrow.flight.FlightServerBase):
32+
def __init__(self, host="localhost", location=None,
33+
tls_certificates=None, verify_client=False,
34+
root_certificates=None, auth_handler=None, middleware=None):
35+
super(FlightServer, self).__init__(
36+
location, auth_handler, tls_certificates, verify_client,
37+
root_certificates, middleware)
38+
self.flights = {}
39+
self.host = host
40+
self.tls_certificates = tls_certificates
41+
self.location = location
42+
43+
@classmethod
44+
def descriptor_to_key(self, descriptor):
45+
return (descriptor.descriptor_type.value, descriptor.command,
46+
tuple(descriptor.path or tuple()))
47+
48+
def _make_flight_info(self, key, descriptor, table):
49+
if self.tls_certificates:
50+
location = pyarrow.flight.Location.for_grpc_tls(
51+
self.host, self.port)
52+
else:
53+
location = pyarrow.flight.Location.for_grpc_tcp(
54+
self.host, self.port)
55+
endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]
56+
57+
mock_sink = pyarrow.MockOutputStream()
58+
stream_writer = pyarrow.RecordBatchStreamWriter(
59+
mock_sink, table.schema)
60+
stream_writer.write_table(table)
61+
stream_writer.close()
62+
data_size = mock_sink.size()
63+
64+
return pyarrow.flight.FlightInfo(table.schema,
65+
descriptor, endpoints,
66+
table.num_rows, data_size)
67+
68+
def list_flights(self, context, criteria):
69+
for key, table in self.flights.items():
70+
if key[1] is not None:
71+
descriptor = \
72+
pyarrow.flight.FlightDescriptor.for_command(key[1])
73+
else:
74+
descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])
75+
76+
yield self._make_flight_info(key, descriptor, table)
77+
78+
def get_flight_info(self, context, descriptor):
79+
key = FlightServer.descriptor_to_key(descriptor)
80+
logger.info(f"get_flight_info: key={key}")
81+
if key in self.flights:
82+
table = self.flights[key]
83+
return self._make_flight_info(key, descriptor, table)
84+
raise KeyError('Flight not found.')
85+
86+
def do_put(self, context, descriptor, reader, writer):
87+
key = FlightServer.descriptor_to_key(descriptor)
88+
logger.info(f"do_put: key={key}")
89+
self.flights[key] = reader.read_all()
90+
91+
def do_get(self, context, ticket):
92+
logger.info(f"do_get: ticket={ticket}")
93+
key = ast.literal_eval(ticket.ticket.decode())
94+
if key not in self.flights:
95+
logger.warn(f"do_get: key={key} not found")
96+
return None
97+
logger.info(f"do_get: returning key={key}")
98+
flight = self.flights.pop(key)
99+
return pyarrow.flight.RecordBatchStream(flight)
100+
101+
def list_actions(self, context):
102+
return iter([
103+
("getUniquePath", "Get a unique FlightDescriptor path to put data to."),
104+
("clear", "Clear the stored flights."),
105+
("shutdown", "Shut down this server."),
106+
])
107+
108+
def do_action(self, context, action):
109+
logger.info(f"do_action: action={action.type}")
110+
if action.type == "getUniquePath":
111+
uniqueId = str(uuid.uuid4())
112+
logger.info(f"getUniquePath id={uniqueId}")
113+
yield uniqueId.encode('utf-8')
114+
elif action.type == "clear":
115+
self._clear()
116+
elif action.type == "healthcheck":
117+
pass
118+
elif action.type == "shutdown":
119+
self._clear()
120+
yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
121+
# Shut down on background thread to avoid blocking current
122+
# request
123+
threading.Thread(target=self._shutdown).start()
124+
else:
125+
raise KeyError("Unknown action {!r}".format(action.type))
126+
127+
def _clear(self):
128+
"""Clear the stored flights."""
129+
self.flights = {}
130+
131+
def _shutdown(self):
132+
"""Shut down after a delay."""
133+
logger.info("Server is shutting down...")
134+
time.sleep(2)
135+
self.shutdown()
136+
137+
def start(server):
138+
logger.info(f"Serving on {server.location}")
139+
server.serve()
140+
141+
142+
if __name__ == '__main__':
143+
start()

tabpy/tabpy_server/handlers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111
from tabpy.tabpy_server.handlers.upload_destination_handler import (
1212
UploadDestinationHandler,
1313
)
14+
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
15+
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import (
16+
BasicAuthServerMiddlewareFactory,
17+
)

0 commit comments

Comments
 (0)