
Commit 0103d69

tjohnson31415 authored and njhill committed
feat: implement CombinedKVCausalLMBatch to support GPTBigCode
Signed-off-by: Travis Johnson <[email protected]>
1 parent ad3fb90 commit 0103d69
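
GPTBigCode uses multi-query attention, so each layer's KV cache is a single fused tensor rather than the separate key/value pair other causal LMs carry, which is why a dedicated batch class is needed. The sketch below is illustrative only, not the code from this commit; the shapes follow the Hugging Face GPTBigCode implementation's multi-query path.

```python
import torch

batch_size, num_heads, seq_len, head_dim = 4, 16, 32, 64

# Typical causal LM: per-layer past is a (key, value) pair of tensors,
# each shaped (batch, num_heads, seq_len, head_dim).
key = torch.zeros(batch_size, num_heads, seq_len, head_dim)
value = torch.zeros(batch_size, num_heads, seq_len, head_dim)
past_layer = (key, value)

# GPTBigCode with multi-query attention: one shared KV head, with key and
# value fused along the last dimension of a single tensor per layer,
# shaped (batch, seq_len, 2 * head_dim).
combined_past_layer = torch.zeros(batch_size, seq_len, 2 * head_dim)
k, v = combined_past_layer.split(head_dim, dim=-1)  # views, no copy

# Batch concatenation and filtering therefore manipulate one tensor per
# layer instead of two, which is what a combined-KV batch class must handle.
merged = torch.cat([combined_past_layer, combined_past_layer], dim=0)
```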

File tree

4 files changed: +414 -84 lines

Binary file not shown.

integration_tests/text_generation_tests/test_cases_tinystarcoderpy.yaml

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
# Test empty requests
- name: Empty 1
  request: {}
  response: {}
- name: Empty 2
  request:
    params: {}
    requests: []
  response: {}

# Simple
- name: Simple
  request:
    requests:
      - {"text": "def hello_world():\n"}
  response:
    responses:
      - generatedTokenCount: 14
        inputTokenCount: 6
        stopReason: EOS_TOKEN
        text: "\tprint(\"Hello World!\")\n\nhello_world()\n"

# Basic Greedy (implicit)
- name: Basic Greedy, max new tokens (implicit)
  request:
    requests:
      - {"text": "'''Implement the class Shape'''\n"}
  response:
    responses:
      - generatedTokenCount: 20
        inputTokenCount: 7
        stopReason: MAX_TOKENS
        text: "\nclass Shape(object):\n    '''Shape class'''\n\n    def __init__(self, x,"

# Basic Greedy (explicit)
- name: Basic Greedy, max new tokens (explicit)
  request:
    params:
      method: GREEDY
      stopping: {"maxNewTokens": 24}
    requests:
      - {"text": "'''Implement the class Shape'''\n"}
  response:
    responses:
      - generatedTokenCount: 24
        inputTokenCount: 7
        stopReason: MAX_TOKENS
        text: "\nclass Shape(object):\n    '''Shape class'''\n\n    def __init__(self, x, y, z):"

# Multiple inputs with token info
- name: Multiple inputs with token info
  request:
    params:
      method: GREEDY
      stopping: {"maxNewTokens": 2}
      response:
        generatedTokens: true
        tokenLogprobs: true
        topNTokens: 2
    requests:
      - {"text": "def hello_world():\n"}
      - {"text": "def merge_lists("}
      - {"text": "if __name__ == \""}
  response:
    responses:
      - generatedTokenCount: 2
        inputTokenCount: 6
        stopReason: MAX_TOKENS
        text: "\tprint"
        tokens:
          - logprob: -0.08069111
            text: "\u0109"
            topTokens:
              - logprob: -0.08069111
                text: "\u0109"
              - logprob: -3.2008388
                text: '#'
          - logprob: -0.89866674
            text: print
            topTokens:
              - logprob: -0.89866674
                text: print
              - logprob: -1.8317665
                text: return
      - generatedTokenCount: 2
        inputTokenCount: 5
        stopReason: MAX_TOKENS
        text: l1
        tokens:
          - logprob: -1.9720234
            text: l
            topTokens:
              - logprob: -1.9720234
                text: l
              - logprob: -2.3360019
                text: list
          - logprob: -0.24351147
            text: '1'
            topTokens:
              - logprob: -0.24351147
                text: '1'
              - logprob: -2.4751484
                text: ','
      - generatedTokenCount: 2
        inputTokenCount: 6
        stopReason: MAX_TOKENS
        text: 'main":'
        tokens:
          - logprob: -1.5838054
            text: main
            topTokens:
              - logprob: -1.5838054
                text: main
              - logprob: -3.0222993
                text: test
          - logprob: -0.18766436
            text: '":'
            topTokens:
              - logprob: -0.18766436
                text: '":'
              - logprob: -2.5319178
                text: '"'


# Prompt prefix
- name: Greedy with tuned prompt prefix
  # Prompt prefixes with multi-shard not yet supported
  singleShardOnly: true
  request:
    # Prefix is "def hello_world():\n"
    prefixId: tiny_starcoder
    params:
      method: GREEDY
    requests:
      - {"text": "\tprint"}
  response:
    responses:
      - generatedTokenCount: 12
        inputTokenCount: 2
        stopReason: EOS_TOKEN
        text: "(\"Hello World!\")\n\nhello_world()\n"


# Prompt prefix returning input and generated tokens
- name: Greedy with tuned prompt prefix and returned tokens
  # Prompt prefixes with multi-shard not yet supported
  singleShardOnly: true
  request:
    # Prefix is "def hello_world():\n"
    prefixId: tiny_starcoder
    params:
      method: GREEDY
      stopping: {"maxNewTokens": 2}
      response:
        inputTokens: true
        generatedTokens: true
        tokenLogprobs: true
        tokenRanks: true
        topNTokens: 2
    requests:
      - {"text": "\tprint(\"Hello"}
  response:
    responses:
      - generatedTokenCount: 2
        inputTokenCount: 4
        text: ' World!")'
        stopReason: MAX_TOKENS
        inputTokens:
          - logprob: NaN
            text: <|endoftext|>
          - logprob: -10.14109
            rank: 2574
            text: <|endoftext|>
            topTokens:
              - logprob: -3.447822
                text: "\u0120_"
              - logprob: -3.672276
                text: "\u0120__"
          - logprob: -12.594888
            rank: 1165
            text: <|endoftext|>
            topTokens:
              - logprob: -1.1129533
                text: _
              - logprob: -1.2004529
                text: (
          - logprob: -13.206944
            rank: 4837
            text: <|endoftext|>
            topTokens:
              - logprob: -0.32641557
                text: world
              - logprob: -4.8018546
                text: server
          - logprob: -11.724733
            rank: 76
            text: <|endoftext|>
            topTokens:
              - logprob: -0.70839006
                text: '():'
              - logprob: -0.9568966
                text: (
          - logprob: -11.811299
            rank: 122
            text: <|endoftext|>
            topTokens:
              - logprob: -0.15292865
                text: "\u010A\u0120\u0120\u0120"
              - logprob: -3.31403
                text: "\u010D\u010A\u0120\u0120\u0120"
          - logprob: -0.080691434
            rank: 1
            text: "\u0109"
            topTokens:
              - logprob: -0.080691434
                text: "\u0109"
              - logprob: -3.2008343
                text: '#'
          - logprob: -0.8986669
            rank: 1
            text: print
            topTokens:
              - logprob: -0.8986669
                text: print
              - logprob: -1.8317685
                text: return
          - logprob: -0.67005044
            rank: 1
            text: ("
            topTokens:
              - logprob: -0.67005044
                text: ("
              - logprob: -1.3652618
                text: ('
          - logprob: -0.6229511
            rank: 1
            text: Hello
            topTokens:
              - logprob: -0.6229511
                text: Hello
              - logprob: -1.4623008
                text: hello
        tokens:
          - logprob: -0.61369985
            rank: 1
            text: "\u0120World"
            topTokens:
              - logprob: -0.61369985
                text: "\u0120World"
              - logprob: -1.7381792
                text: ','
          - logprob: -0.7115159
            rank: 1
            text: '!")'
            topTokens:
              - logprob: -0.7115159
                text: '!")'
              - logprob: -1.0358996
                text: '")'

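Each case above pairs a request dict with the expected response dict. The sketch below shows roughly how such a case maps onto the generation gRPC API, mirroring the `json_format.ParseDict` pattern visible in `test_server.py` below; the stub, generated `pb2` module, and the `Generate` RPC name are assumptions here, not shown in this commit.

```python
import yaml
from google.protobuf import json_format


def run_case(stub, pb2, case):
    """Drive one YAML test case through the generation service.

    `stub` is a connected gRPC stub and `pb2` the generated protobuf module;
    both come from harness setup not shown in this commit, and the RPC name
    `Generate` is an assumption.
    """
    # Build the request message straight from the YAML "request" mapping.
    message = json_format.ParseDict(case.get("request", {}), pb2.BatchedGenerationRequest())
    response = stub.Generate(message)
    # The YAML "response" mapping is the expected JSON form of the reply.
    assert json_format.MessageToDict(response) == case.get("response", {})


def load_cases(path="test_cases_tinystarcoderpy.yaml"):
    with open(path) as f:
        return yaml.load(f, Loader=yaml.Loader)
```
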
integration_tests/text_generation_tests/test_server.py

Lines changed: 13 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 INCLUDE_STREAMING = True
 TESTS_TIMEOUT = 300.0  # 5 mins
+TESTS_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
 
 
 def start_server(
@@ -65,6 +66,7 @@ def start_server(
 
     env = os.environ.copy()
     env["RUST_BACKTRACE"] = "full"
+    env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
     if not include_cache_env_vars:
         env.pop("TRANSFORMERS_CACHE", None)
         env.pop("HUGGING_FACE_HUB_CACHE", None)
@@ -122,7 +124,7 @@ def server_fixture(request):
 @pytest.fixture
 def test_cases(request):
     filename = request.node.get_closest_marker("test_case_file").args[0]
-    with open(filename) as f:
+    with open(os.path.join(TESTS_DIR, filename)) as f:
         return yaml.load(f, Loader=yaml.Loader)
 
 
@@ -290,7 +292,7 @@ async def run_test_cases_async(test_cases, seq2seq_model=False, sharded=False):
 async def _test_multi_input_seeds(stub):
     # Ensure that sending a batch of identical inputs in sampling mode results
     # in different output seeds and texts
-    with open("test_cases_common.yaml") as f:
+    with open(os.path.join(TESTS_DIR, "test_cases_common.yaml")) as f:
         test_case = yaml.load(f, Loader=yaml.Loader)
     request = test_case["seed_test"]["request"]
     message = json_format.ParseDict(request, pb2.BatchedGenerationRequest())
@@ -326,6 +328,15 @@ async def test_bloom(server_fixture, test_cases):
 async def test_mt0(server_fixture, test_cases):
     await run_test_cases_async(test_cases, seq2seq_model=True)
 
+# test with tiny GPTBigCode model for the merged kv cache
+@pytest.mark.model("bigcode/tiny_starcoder_py")
+@pytest.mark.extensions(".safetensors,.json")
+@pytest.mark.shards(1)
+@pytest.mark.test_case_file("test_cases_tinystarcoderpy.yaml")
+@pytest.mark.asyncio
+async def test_gptbigcode(server_fixture, test_cases):
+    await run_test_cases_async(test_cases)
+
 
 # Test distributed inference - two shards
 @pytest.mark.model("bigscience/bloom-560m")
@@ -375,4 +386,3 @@ def event_loop():
     loop = asyncio.new_event_loop()
     yield loop
     loop.close()
-
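
Together, these path changes make the suite location-independent: test-case YAML files are opened relative to `TESTS_DIR`, and the server is pointed at the bundled `prompt_prefixes` store, which is where `prefixId: tiny_starcoder` in the YAML above resolves. A minimal sketch of that lookup under an assumed store layout (one entry per prefix id; the actual file format is the binary not shown above):

```python
import os


def resolve_prefix_path(prefix_id: str) -> str:
    # PREFIX_STORE_PATH is exported by start_server() in the diff above.
    store = os.environ["PREFIX_STORE_PATH"]
    # Assumed layout: one entry per prefixId inside the store directory.
    path = os.path.join(store, prefix_id)
    if not os.path.exists(path):
        raise ValueError(f"unknown prompt prefix: {prefix_id}")
    return path
```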
