1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+
19
+ import ast
20
+ import logging
21
+ import threading
22
+ import time
23
+ import uuid
24
+
25
+ import pyarrow
26
+ import pyarrow .flight
27
+
28
+
29
+ logger = logging .getLogger ('__main__.' + __name__ )
30
+
31
+ class FlightServer (pyarrow .flight .FlightServerBase ):
32
+ def __init__ (self , host = "localhost" , location = None ,
33
+ tls_certificates = None , verify_client = False ,
34
+ root_certificates = None , auth_handler = None , middleware = None ):
35
+ super (FlightServer , self ).__init__ (
36
+ location , auth_handler , tls_certificates , verify_client ,
37
+ root_certificates , middleware )
38
+ self .flights = {}
39
+ self .host = host
40
+ self .tls_certificates = tls_certificates
41
+ self .location = location
42
+
43
+ @classmethod
44
+ def descriptor_to_key (self , descriptor ):
45
+ return (descriptor .descriptor_type .value , descriptor .command ,
46
+ tuple (descriptor .path or tuple ()))
47
+
48
+ def _make_flight_info (self , key , descriptor , table ):
49
+ if self .tls_certificates :
50
+ location = pyarrow .flight .Location .for_grpc_tls (
51
+ self .host , self .port )
52
+ else :
53
+ location = pyarrow .flight .Location .for_grpc_tcp (
54
+ self .host , self .port )
55
+ endpoints = [pyarrow .flight .FlightEndpoint (repr (key ), [location ]), ]
56
+
57
+ mock_sink = pyarrow .MockOutputStream ()
58
+ stream_writer = pyarrow .RecordBatchStreamWriter (
59
+ mock_sink , table .schema )
60
+ stream_writer .write_table (table )
61
+ stream_writer .close ()
62
+ data_size = mock_sink .size ()
63
+
64
+ return pyarrow .flight .FlightInfo (table .schema ,
65
+ descriptor , endpoints ,
66
+ table .num_rows , data_size )
67
+
68
+ def list_flights (self , context , criteria ):
69
+ for key , table in self .flights .items ():
70
+ if key [1 ] is not None :
71
+ descriptor = \
72
+ pyarrow .flight .FlightDescriptor .for_command (key [1 ])
73
+ else :
74
+ descriptor = pyarrow .flight .FlightDescriptor .for_path (* key [2 ])
75
+
76
+ yield self ._make_flight_info (key , descriptor , table )
77
+
78
+ def get_flight_info (self , context , descriptor ):
79
+ key = FlightServer .descriptor_to_key (descriptor )
80
+ logger .info (f"get_flight_info: key={ key } " )
81
+ if key in self .flights :
82
+ table = self .flights [key ]
83
+ return self ._make_flight_info (key , descriptor , table )
84
+ raise KeyError ('Flight not found.' )
85
+
86
+ def do_put (self , context , descriptor , reader , writer ):
87
+ key = FlightServer .descriptor_to_key (descriptor )
88
+ logger .info (f"do_put: key={ key } " )
89
+ self .flights [key ] = reader .read_all ()
90
+
91
+ def do_get (self , context , ticket ):
92
+ logger .info (f"do_get: ticket={ ticket } " )
93
+ key = ast .literal_eval (ticket .ticket .decode ())
94
+ if key not in self .flights :
95
+ logger .warn (f"do_get: key={ key } not found" )
96
+ return None
97
+ logger .info (f"do_get: returning key={ key } " )
98
+ flight = self .flights .pop (key )
99
+ return pyarrow .flight .RecordBatchStream (flight )
100
+
101
+ def list_actions (self , context ):
102
+ return iter ([
103
+ ("getUniquePath" , "Get a unique FlightDescriptor path to put data to." ),
104
+ ("clear" , "Clear the stored flights." ),
105
+ ("shutdown" , "Shut down this server." ),
106
+ ])
107
+
108
+ def do_action (self , context , action ):
109
+ logger .info (f"do_action: action={ action .type } " )
110
+ if action .type == "getUniquePath" :
111
+ uniqueId = str (uuid .uuid4 ())
112
+ logger .info (f"getUniquePath id={ uniqueId } " )
113
+ yield uniqueId .encode ('utf-8' )
114
+ elif action .type == "clear" :
115
+ self ._clear ()
116
+ elif action .type == "healthcheck" :
117
+ pass
118
+ elif action .type == "shutdown" :
119
+ self ._clear ()
120
+ yield pyarrow .flight .Result (pyarrow .py_buffer (b'Shutdown!' ))
121
+ # Shut down on background thread to avoid blocking current
122
+ # request
123
+ threading .Thread (target = self ._shutdown ).start ()
124
+ else :
125
+ raise KeyError ("Unknown action {!r}" .format (action .type ))
126
+
127
+ def _clear (self ):
128
+ """Clear the stored flights."""
129
+ self .flights = {}
130
+
131
+ def _shutdown (self ):
132
+ """Shut down after a delay."""
133
+ logger .info ("Server is shutting down..." )
134
+ time .sleep (2 )
135
+ self .shutdown ()
136
+
137
+ def start (server ):
138
+ logger .info (f"Serving on { server .location } " )
139
+ server .serve ()
140
+
141
+
142
+ if __name__ == '__main__' :
143
+ start ()
0 commit comments