forked from deepjavalibrary/djl-serving
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrequest.py
More file actions
194 lines (162 loc) · 6.94 KB
/
request.py
File metadata and controls
194 lines (162 loc) · 6.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import inspect
import json
from typing import Union, Callable, Any, List, Dict, Optional
from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter
from djl_python.request_io import Token, TextGenerationOutput, TextInput, RequestOutput, RequestInput
from djl_python.utils import is_streaming
class Request(object):
    """
    Per-request state for the handler.

    In rolling batch, the handler is called once per forward pass, so an
    instance of this class carries the state of one request across those
    calls until its last token has been generated.
    """

    def __init__(self, request_input: RequestInput = None):
        """
        Initialize a request from its parsed input.

        Server-side parameters are normalized here: ``output_formatter``,
        ``stream`` and ``details`` are popped from them (they control how the
        output is rendered), the output formatter and content type are
        resolved, and a ``TextGenerationOutput`` is created for this request.

        :param request_input: the parsed input of this request. The ``None``
            default is kept only for signature compatibility; a value is
            required (``None`` fails on the first attribute access).
        """
        # TODO: Remove some of these redundant attributes and
        #  use request_input and request_output wherever necessary.
        self.id = request_input.request_id
        self.request_input = request_input
        self.input_text = request_input.input_text
        self.last_token = False
        self.adapter = request_input.adapters
        # server parameters may not be set, if custom input formatter is used.
        if not self.request_input.server_parameters:
            self.request_input.server_parameters = (
                self.request_input.parameters.copy())
        self.parameters = self.request_input.server_parameters
        # A per-request "output_formatter" parameter overrides the formatter
        # already set on the input; the key is removed from the parameters.
        request_input.output_formatter = self.parameters.pop(
            "output_formatter", request_input.output_formatter)
        # stream parameter is only used for determining the output.
        stream = self.parameters.pop("stream", False)
        # details is only used in output formatter for rolling batch
        self.parameters.pop("details", False)
        # Resolve the configured formatter into the callable that is used
        # below, together with the response content type.
        self.output_formatter, self.content_type = get_output_formatter(
            request_input.output_formatter, stream, request_input.tgi_compat)
        request_input.output_formatter = self.output_formatter
        self.legacy_formatter = self._is_output_formatter_legacy()
        self.request_output = TextGenerationOutput(request_id=self.id,
                                                   input=self.request_input)
        # Formatted output, cached by get_next_token until reset_next_token.
        self.next_token_str = ""
        self.error_message = None
        self.error_code = None

    def _is_output_formatter_legacy(self) -> bool:
        """
        Whether the resolved output formatter uses the legacy signature.

        New-style formatters annotate their first parameter as
        ``RequestOutput`` or ``TextGenerationOutput``; anything else is
        treated as legacy. Annotations stored as strings (modules using
        ``from __future__ import annotations``) are matched by their
        unqualified class name so such formatters are not misclassified.

        :return: True if the formatter must go through the legacy adapter.
        """
        signature_parameters = list(
            inspect.signature(self.output_formatter).parameters.values())
        annotation = signature_parameters[0].annotation
        if isinstance(annotation, str):
            # PEP 563 postpones evaluation, leaving e.g. "RequestOutput" or
            # "request_io.RequestOutput" instead of the class object.
            return annotation.rsplit(".", 1)[-1] not in (
                RequestOutput.__name__, TextGenerationOutput.__name__)
        return annotation not in (RequestOutput, TextGenerationOutput)

    def __repr__(self):
        return f"<Request id: {self.id} Input {self.input_text} Parameters {self.parameters} Finished {self.last_token}>"

    def set_next_token(self,
                       next_token: Union[Token, str],
                       last_token: bool = False,
                       finish_reason: Optional[str] = None):
        """
        Record a newly generated token for this request.

        A plain string is wrapped in a ``Token`` with id ``-1``. The token is
        tagged with this request's id and forwarded to ``self.request_output``;
        when ``last_token`` is true the output is also marked finished.

        :param next_token: next token to be set.
        :param last_token: whether this token is the last of the sequence.
        :param finish_reason: why generation ended. Current options:
            length: max_output_token size reached
            eos_token: end of sequence token found
            stop_sequence: preset stop sequence token found
        """
        if isinstance(next_token, str):
            next_token = Token(-1, next_token)
        next_token.request_id = self.id
        self.request_output.set_next_token(next_token,
                                           is_last_token=last_token,
                                           finish_reason=finish_reason)
        self.last_token = last_token
        if last_token:
            self.request_output.finished = True

    def get_next_token(self) -> str:
        """
        Render the output produced for this step.

        The rendered string is cached in ``self.next_token_str`` and returned
        unchanged on later calls until ``reset_next_token`` clears it.

        - cancelled requests always yield ``""``;
        - legacy formatters go through ``adapt_legacy_output_formatter``;
        - non-streaming requests call the formatter once on the whole output;
        - streaming requests call the formatter repeatedly while the best
          sequence reports pending tokens (the formatter is expected to
          consume them, otherwise the loop would not terminate) and
          concatenate the results.

        :return: the formatted output
        """
        if self.is_cancelled():
            return ""
        if self.next_token_str:
            return self.next_token_str
        if self.legacy_formatter:
            self.next_token_str = adapt_legacy_output_formatter(
                self.request_output)
        elif not is_streaming(self.request_output.input.parameters):
            # there is no need for iterators in non-streaming use-cases
            self.next_token_str = self.output_formatter(self.request_output)
        else:
            best_sequence = self.request_output.sequences[
                self.request_output.best_sequence_index]
            # Collect pieces and join once rather than repeated `+=`; the
            # cache is empty here, so this yields the same string.
            chunks = []
            while best_sequence.has_next_token():
                chunks.append(self.output_formatter(self.request_output))
            self.next_token_str = "".join(chunks)
        return self.next_token_str

    def reset_next_token(self):
        """
        Clear the cached output so the next ``get_next_token`` re-renders.
        """
        self.next_token_str = ""

    def is_last_token(self) -> bool:
        """
        Whether the generated token is the last one.

        :return: whether last token of the sequence.
        """
        return self.last_token

    def get_content_type(self) -> str:
        """
        Content type of this particular request, as resolved by
        ``get_output_formatter`` at construction.

        :return: content type
        """
        return self.content_type

    def get_error_message(self) -> Optional[str]:
        """
        Error message for the request if inference failed.

        :return: the error message, or None if none was set
        """
        return self.error_message

    def get_error_code(self) -> Optional[int]:
        """
        HTTP status code to return when inference fails.

        :return: the status code, or None if none was set
        """
        return self.error_code

    def set_error_message(self, error_message: str):
        """
        Set the error message for the request if inference failed.
        """
        self.error_message = error_message

    def set_error_code(self, code: int):
        """
        Set the HTTP status code to return when inference fails.
        """
        self.error_code = code

    def get_client_request_id(self) -> str:
        """
        Return the requestId specified in the HTTP request.

        :return: the requestId specified in the HTTP request
        """
        return self.request_input.client_request_id

    def is_cancelled(self) -> bool:
        """
        Return whether the request has been cancelled by the client.

        :return: true if the request is cancelled
        """
        return self.request_input.is_cancelled