|
15 | 15 |
|
16 | 16 | package org.sqlflow.client; |
17 | 17 |
|
| 18 | +import java.util.Iterator; |
| 19 | +import java.util.List; |
| 20 | +import java.util.NoSuchElementException; |
| 21 | +import java.util.concurrent.TimeUnit; |
| 22 | + |
| 23 | +import com.google.protobuf.Any; |
| 24 | +import io.grpc.ManagedChannel; |
| 25 | +import io.grpc.ManagedChannelBuilder; |
18 | 26 | import io.grpc.StatusRuntimeException; |
19 | | -import proto.Sqlflow.JobStatus; |
| 27 | +import org.apache.commons.lang3.StringUtils; |
| 28 | +import org.sqlflow.client.utils.HTMLDetector; |
| 29 | +import proto.SQLFlowGrpc; |
| 30 | +import proto.Sqlflow.FetchRequest; |
| 31 | +import proto.Sqlflow.FetchResponse; |
| 32 | +import proto.Sqlflow.FetchResponse.Logs; |
| 33 | +import proto.Sqlflow.Job; |
| 34 | +import proto.Sqlflow.Request; |
| 35 | +import proto.Sqlflow.Response; |
20 | 36 | import proto.Sqlflow.Session; |
21 | 37 |
|
22 | | -public interface SQLFlow { |
23 | | - /** |
24 | | - * Submit a task to SQLFlow server. This method return immediately. |
25 | | - * |
26 | | - * @param session: specify dbConnStr(datasource), user Id ... |
27 | | - * mysql://root:root@tcp(localhost)/iris |
28 | | - * @param sql: sql program. |
29 | | - * <p>Example: "SELECT * FROM iris.test; SELECT * FROM iris.iris TO TRAIN DNNClassifier |
30 | | - * COLUMN..." * |
31 | | - * @return return a job id for tracking. |
32 | | - * @throws IllegalArgumentException header or sql error |
33 | | - * @throws StatusRuntimeException |
34 | | - */ |
35 | | - String submit(Session session, String sql) |
36 | | - throws IllegalArgumentException, StatusRuntimeException; |
37 | | - |
38 | | - /** |
39 | | - * Fetch the job status by job id. The job id always returned by submit. By fetch(), we are able |
40 | | - * to tracking the job status |
41 | | - * |
42 | | - * @param jobId specific the job we are going to track |
43 | | - * @return see @code proto.JobStatus.Code |
44 | | - * @throws StatusRuntimeException |
45 | | - */ |
46 | | - JobStatus fetch(String jobId) throws StatusRuntimeException; |
47 | | - |
48 | | - /** |
49 | | - * Close the opened channel to SQLFlow server. Waits for the channel to become terminated, giving |
50 | | - * up if the timeout is reached. |
51 | | - * |
52 | | - * @throws InterruptedException thrown by awaitTermination |
53 | | - */ |
54 | | - void release() throws InterruptedException; |
| 38 | +public class SQLFlow { |
| 39 | + private Builder builder; |
| 40 | + |
| 41 | + private SQLFlowGrpc.SQLFlowBlockingStub blockingStub; |
| 42 | + // TODO(weiguo): It looks we need the futureStub to handle a large data set. |
| 43 | + // private SQLFlowGrpc.SQLFlowFutureStub futureStub; |
| 44 | + |
| 45 | + private SQLFlow(Builder builder) { |
| 46 | + this.builder = builder; |
| 47 | + blockingStub = SQLFlowGrpc.newBlockingStub(builder.channel); |
| 48 | + } |
| 49 | + |
| 50 | + public void run(String sql) |
| 51 | + throws IllegalArgumentException, StatusRuntimeException, NoSuchElementException { |
| 52 | + if (StringUtils.isBlank(sql)) { |
| 53 | + throw new IllegalArgumentException("sql is empty"); |
| 54 | + } |
| 55 | + |
| 56 | + Request req = Request.newBuilder().setSession(builder.session).setSql(sql).build(); |
| 57 | + try { |
| 58 | + Iterator<Response> responses = blockingStub.run(req); |
| 59 | + handleSQLFlowResponses(responses); |
| 60 | + } catch (StatusRuntimeException e) { |
| 61 | + // TODO(weiguo) logger.error |
| 62 | + throw e; |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + public void release() throws InterruptedException { |
| 67 | + builder.channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); |
| 68 | + } |
| 69 | + |
| 70 | + private void handleSQLFlowResponses(Iterator<Response> responses) { |
| 71 | + if (responses == null || !responses.hasNext()) { |
| 72 | + throw new NoSuchElementException("bad response"); |
| 73 | + } |
| 74 | + while (responses.hasNext()) { |
| 75 | + Response response = responses.next(); |
| 76 | + if (response == null) { |
| 77 | + break; |
| 78 | + } |
| 79 | + if (response.hasHead()) { |
| 80 | + builder.handler.handleHeader(response.getHead().getColumnNamesList()); |
| 81 | + } else if (response.hasRow()) { |
| 82 | + List<Any> rows = response.getRow().getDataList(); |
| 83 | + builder.handler.handleRows(rows); |
| 84 | + } else if (response.hasMessage()) { |
| 85 | + String msg = response.getMessage().getMessage(); |
| 86 | + if (HTMLDetector.validate(msg)) { |
| 87 | + builder.handler.handleHTML(msg); |
| 88 | + } else { |
| 89 | + builder.handler.handleText(msg); |
| 90 | + } |
| 91 | + } else if (response.hasEoe()) { |
| 92 | + builder.handler.handleEOE(); |
| 93 | + // assert(!responses.hasNext()) |
| 94 | + } else if (response.hasJob()) { |
| 95 | + trackingJobStatus(response.getJob().getId()); |
| 96 | + // assert(!responses.hasNext()) |
| 97 | + } else { |
| 98 | + break; |
| 99 | + } |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + private void trackingJobStatus(String jobId) { |
| 104 | + Job job = Job.newBuilder().setId(jobId).build(); |
| 105 | + FetchRequest req = FetchRequest.newBuilder().setJob(job).build(); |
| 106 | + while (true) { |
| 107 | + FetchResponse response = blockingStub.fetch(req); |
| 108 | + Logs logs = response.getLogs(); |
| 109 | + logs.getContentList().forEach(this.builder.handler::handleText); |
| 110 | + if (response.getEof()) { |
| 111 | + this.builder.handler.handleEOE(); |
| 112 | + break; |
| 113 | + } |
| 114 | + req = response.getUpdatedFetchSince(); |
| 115 | + |
| 116 | + try { |
| 117 | + Thread.sleep(builder.intervalFetching); |
| 118 | + } catch (InterruptedException e) { |
| 119 | + break; |
| 120 | + } |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + public static class Builder { |
| 125 | + private ManagedChannel channel; |
| 126 | + private MessageHandler handler; |
| 127 | + private Session session; |
| 128 | + private long intervalFetching = 2000L; // millis |
| 129 | + |
| 130 | + public static Builder newInstance() { |
| 131 | + return new Builder(); |
| 132 | + } |
| 133 | + |
| 134 | + public Builder withChannel(ManagedChannel channel) { |
| 135 | + this.channel = channel; |
| 136 | + return this; |
| 137 | + } |
| 138 | + |
| 139 | + public Builder withMessageHandler(MessageHandler handler) { |
| 140 | + this.handler = handler; |
| 141 | + return this; |
| 142 | + } |
| 143 | + |
| 144 | + public Builder withIntervalFetching(long mills) { |
| 145 | + if (mills > 0) { |
| 146 | + this.intervalFetching = mills; |
| 147 | + } |
| 148 | + return this; |
| 149 | + } |
| 150 | + |
| 151 | + public Builder withSession(Session session) { |
| 152 | + if (session == null || StringUtils.isAnyBlank(session.getDbConnStr(), session.getUserId())) { |
| 153 | + throw new IllegalArgumentException("data source and userId are not allowed to be empty"); |
| 154 | + } |
| 155 | + this.session = session; |
| 156 | + return this; |
| 157 | + } |
| 158 | + |
| 159 | + /** |
| 160 | + * Open a channel to the SQLFlow server. The serverUrl argument always ends with a port. |
| 161 | + * |
| 162 | + * @param serverUrl an address the SQLFlow server exposed. |
| 163 | + * <p>Example: "localhost:50051" |
| 164 | + */ |
| 165 | + public Builder forTarget(String serverUrl) { |
| 166 | + return withChannel(ManagedChannelBuilder.forTarget(serverUrl).usePlaintext().build()); |
| 167 | + } |
| 168 | + |
| 169 | + public SQLFlow build() { |
| 170 | + return new SQLFlow(this); |
| 171 | + } |
| 172 | + } |
55 | 173 | } |
0 commit comments