
Commit 6e68de5

tdoublep, JRosenkranz, daviswer, prashantgupta24, and njhill authored and committed
Speculative decoding for llama and gpt_bigcode (#79)
This PR adds support for speculative decoding for `llama` and `gpt_bigcode` models. It introduces a new model type and batch type (following the same pattern as for the Flash models). The speculator and the KV cache manager are imported from the `fms_extras` package.

(screenshot: https://github.com/IBM/text-generation-inference/assets/7945038/1889b44e-2330-467a-b838-935979c838df)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: JOSHUA ROSENKRANZ <[email protected]>
Co-authored-by: Davis Wertheimer <[email protected]>
Co-authored-by: Prashant Gupta <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 3c432fb commit 6e68de5
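
For context on the technique itself: speculative decoding uses a small, fast speculator to draft several candidate tokens, which the main model then verifies in a single forward pass, accepting the longest matching prefix. The sketch below is a minimal, hypothetical illustration of that draft/verify loop in plain Python; it is not the `fms_extras` API used by this PR, and the names (`speculative_step`, `draft_next`, `target_argmax`) are stand-ins.

# Hypothetical, minimal sketch of greedy speculative decoding; not the fms_extras API.
from typing import Callable, List

def speculative_step(
    prefix: List[int],
    draft_next: Callable[[List[int]], int],           # cheap speculator: next-token guess
    target_argmax: Callable[[List[int]], List[int]],  # target model: argmax at each position after the prefix
    num_speculative: int = 3,
) -> List[int]:
    # 1) The speculator proposes `num_speculative` tokens autoregressively (cheap).
    draft = []
    ctx = list(prefix)
    for _ in range(num_speculative):
        t = draft_next(ctx)
        draft.append(t)
        ctx.append(t)

    # 2) The target model scores prefix + draft in a single forward pass and
    #    returns its own prediction for every position after the prefix
    #    (len(draft) + 1 predictions in total).
    target_pred = target_argmax(prefix + draft)

    # 3) Accept the longest prefix of the draft that matches the target,
    #    then append the target's own next token, so at least one token is
    #    produced per step and the output matches plain greedy decoding.
    accepted = []
    for i, t in enumerate(draft):
        if t == target_pred[i]:
            accepted.append(t)
        else:
            break
    accepted.append(target_pred[len(accepted)])
    return prefix + accepted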

File tree

9 files changed: +2010 / -15 lines


integration_tests/sample_client.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import time
import grpc
from google.protobuf import json_format
from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2


def get_streaming_response_tgis(response):
    stop = False
    generated_tokens = 0
    while not stop:
        try:
            x = next(response)
            timestamp = time.time_ns()
            data = json_format.MessageToDict(x)
            # skip first response (tokenizer output only)
            if "inputTokenCount" not in data:
                n_tokens = data["generatedTokenCount"] - generated_tokens
                generated_tokens = data["generatedTokenCount"]
                yield data, n_tokens, timestamp, True, None
        except Exception as e:
            timestamp = time.time_ns()
            yield None, 0, timestamp, False, e


channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)
max_new_tokens = 100

template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
num_req = 0
while True:
    prompt_input = input(f"\n{num_req}) Enter a prompt:\n")

    print("-" * 40)
    print("Output:")
    prompt = template.format(prompt_input)
    sample_request = {
        "model_id": "dummy-model-name",
        "request": {"text": prompt},
        "params": {
            "method": "GREEDY",
            "stopping": {
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max_new_tokens,
            },
        },
    }
    message = json_format.ParseDict(sample_request, pb2.SingleGenerationRequest())
    output = []
    total_time = 0
    response = stub.GenerateStream(message)
    response_generator = get_streaming_response_tgis(response)
    t0 = time.time_ns()
    response = ""
    stop = False
    while not stop:
        r, n_tokens, t, ok, err = next(response_generator)

        if not ok:
            stop = True
            # check if we have reached end of stream
            if type(err) is StopIteration:
                continue
        duration = (t - t0) / 1000.0 / 1000.0
        record = {
            "response": r,
            "ok": ok,
            "error": str(err),
            "timestamp": t,
            "duration_ms": duration,
            "n_tokens": n_tokens,
        }
        total_time += duration
        response += r["text"]
        output.append(record)
        t0 = t

    # print(json.dumps(output, indent=4))
    print("-" * 40)
    print(response)
    print("-" * 40)
    print(f"Total_time : {total_time}ms")
    print(f"Time_per_token : {total_time/max_new_tokens}ms")
    print("-" * 40)
    num_req += 1
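
Assuming a TGIS endpoint is listening on localhost:8033 and the generated gRPC stubs from `text_generation_tests.pb` are importable, the same stubs can also be driven non-interactively. A rough sketch (the prompt and token budget are illustrative); it uses only the request fields already exercised by the sample client above:

# Hypothetical non-interactive variant of the sample client above.
import time
import grpc
from google.protobuf import json_format
from text_generation_tests.pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2

channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)

request = json_format.ParseDict(
    {
        "model_id": "dummy-model-name",
        "request": {"text": "Tell me a story about a duck."},
        "params": {"method": "GREEDY", "stopping": {"max_new_tokens": 20}},
    },
    pb2.SingleGenerationRequest(),
)

t0 = time.time_ns()
text, n_tokens = "", 0
for msg in stub.GenerateStream(request):
    data = json_format.MessageToDict(msg)
    if "inputTokenCount" in data:
        continue  # first response carries tokenizer output only
    text += data.get("text", "")
    n_tokens = data.get("generatedTokenCount", n_tokens)
elapsed_ms = (time.time_ns() - t0) / 1e6

print(text)
print(f"{n_tokens} tokens in {elapsed_ms:.1f} ms")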

router/src/batcher.rs

Lines changed: 13 additions & 2 deletions
@@ -723,8 +723,19 @@ impl<'a> TokenProcessor<'a> {
             let request_id = output.request_id;
             let next_token_id = output.token_id;
 
-            let e = self.entries.get_mut(&request_id)
-                .expect("ID not found. This is a bug.");
+            let e = self.entries.get_mut(&request_id);
+
+            // If a client cancelled a request and speculative decoding is
+            // enabled, it's possible that the request will be removed from
+            // the entries table while there are still tokens in the outputs
+            // stream corresponding to that request. Ideally we would defer
+            // removing the request_id from the entries table until all of its
+            // tokens have been processed, but for now just ignore them.
+            if e.is_none() {
+                continue;
+            }
+
+            let e = e.unwrap();
 
             let is_stream = e.stream_tx.is_some();
             let stop_seqs = &e.request.parameters.stop_seqs;
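
For readers less familiar with Rust's `Option`, here is the same idea in Python terms (the names below are illustrative, not from the codebase): a token whose request is no longer tracked is skipped instead of being treated as a bug.

# Illustrative Python analogue of the batcher change; names are hypothetical.
entries: dict[int, dict] = {}  # request_id -> per-request state; removed on client cancel

def process_token(request_id: int, token_id: int) -> None:
    entry = entries.get(request_id)  # previously: entries[request_id] (missing id == "bug")
    if entry is None:
        # With speculative decoding, tokens for a cancelled request can still be
        # in flight in the output stream; ignore them rather than failing.
        return
    entry.setdefault("token_ids", []).append(token_id)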

server/text_generation_server/inference_engine/tgis_native.py

Lines changed: 13 additions & 5 deletions
@@ -7,7 +7,7 @@
 
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
-from text_generation_server.models import FLASH_ATTENTION
+from text_generation_server.models import FLASH_ATTENTION, PAGED_ATTENTION
 from text_generation_server.utils import Weights
 
 from text_generation_server.inference_engine import BaseInferenceEngine
@@ -80,8 +80,12 @@ def __init__(
         elif model_type == "gpt_bigcode":
             self._config.transpose = self._config.architectures[0].startswith("GPT2")
             aliases = {"transformer.wte.weight": ["lm_head.weight"]}
-            from text_generation_server.models.custom_modeling.flash_santacoder_modeling import FlashSantacoderForCausalLM
-            model_class = FlashSantacoderForCausalLM
+            if PAGED_ATTENTION:
+                from text_generation_server.models.custom_modeling.paged_santacoder_modeling import PagedSantacoderForCausalLM
+                model_class = PagedSantacoderForCausalLM
+            else:
+                from text_generation_server.models.custom_modeling.flash_santacoder_modeling import FlashSantacoderForCausalLM
+                model_class = FlashSantacoderForCausalLM
 
         elif model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
             if sharded and self._config.alibi:
@@ -94,8 +98,12 @@ def __init__(
                 model_class = FlashRWForCausalLM
 
         elif model_type == "llama":
-            from text_generation_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
-            model_class = FlashLlamaForCausalLM
+            if PAGED_ATTENTION:
+                from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
+                model_class = PagedLlamaForCausalLM
+            else:
+                from text_generation_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
+                model_class = FlashLlamaForCausalLM
 
         self._config.quantize = quantize
 
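
To summarize the dispatch introduced above: the selected model class now depends on both the model type and the PAGED_ATTENTION flag. The mapping below is derived directly from this diff; the small helper around it is just for reference.

# Class selection summary derived from the diff above; for reference only.
MODEL_CLASS_BY_TYPE = {
    # model_type:   (PAGED_ATTENTION=true,         PAGED_ATTENTION=false)
    "gpt_bigcode": ("PagedSantacoderForCausalLM", "FlashSantacoderForCausalLM"),
    "llama":       ("PagedLlamaForCausalLM",      "FlashLlamaForCausalLM"),
}

def selected_class(model_type: str, paged: bool) -> str:
    paged_cls, flash_cls = MODEL_CLASS_BY_TYPE[model_type]
    return paged_cls if paged else flash_cls

assert selected_class("llama", True) == "PagedLlamaForCausalLM"
assert selected_class("gpt_bigcode", False) == "FlashSantacoderForCausalLM"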

server/text_generation_server/models/__init__.py

Lines changed: 34 additions & 1 deletion
@@ -13,8 +13,9 @@
 from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM, PretrainedConfig
 
 FLASH_ATTENTION = os.getenv("FLASH_ATTENTION", "false").lower() == "true"
+PAGED_ATTENTION = os.getenv("PAGED_ATTENTION", "false").lower() == "true"
 
-__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION", "PT2_COMPILE"]
+__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION", "PAGED_ATTENTION", "PT2_COMPILE"]
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -43,6 +44,38 @@ def get_model(
     model_config_dict, _kwargs = PretrainedConfig.get_config_dict(model_path)
     model_type = model_config_dict["model_type"]
 
+    if PAGED_ATTENTION:
+        print(f"Using Paged Attention")
+
+        if deployment_framework != "tgis_native":
+            print_rank_n(
+                f"WARNING: Using deployment engine tgis_native rather than {deployment_framework} "
+                "because PAGED_ATTENTION is enabled"
+            )
+            deployment_framework = "tgis_native"
+
+        if model_type == "llama":
+            # Custom config type for LLaMA models
+            from text_generation_server.models.custom_modeling.paged_llama_modeling import LlamaConfig
+            model_config = LlamaConfig.from_pretrained(model_path)
+        elif model_type == "gpt_bigcode":
+            from transformers import GPTBigCodeConfig
+            model_config = GPTBigCodeConfig.from_pretrained(model_path)
+            # num_key_value_heads is used in creating cache, here we add that attribute based on mqa
+            model_config.num_key_value_heads = 1 if model_config.multi_query else model_config.num_attention_heads
+        else:
+            raise NotImplementedError("PAGED_ATTENTION only supported for gpt_bigcode and llama for now")
+
+        from text_generation_server.models.paged_causal_lm import PagedCausalLM
+        return PagedCausalLM(
+            model_name,
+            revision,
+            deployment_framework,
+            dtype, quantize,
+            model_config,
+            max_sequence_length=max_sequence_length,
+        )
+
     if FLASH_ATTENTION:
         # This will raise an exception if flash attention is not supported by the device
         import text_generation_server.utils.flash_attn as flash_attn
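
A note on the `num_key_value_heads` attribute added above for `gpt_bigcode`: the paged KV cache is sized per key/value head, and multi-query models keep a single shared KV head, which is why `multi_query=True` maps to `num_key_value_heads = 1`. The sketch below shows the effect on per-token cache size; the layer, head, and dimension numbers are illustrative, not taken from a specific checkpoint.

# Why num_key_value_heads matters for the paged KV cache: per-token cache size
# scales with the number of KV heads, and multi-query attention (MQA) uses one.
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    # 2x for key + value, per layer, per KV head, times the bytes per element
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

mha = kv_cache_bytes_per_token(num_layers=40, num_kv_heads=48, head_dim=128)  # multi-head attention
mqa = kv_cache_bytes_per_token(num_layers=40, num_kv_heads=1,  head_dim=128)  # multi-query (gpt_bigcode)
print(mha, mqa, mha // mqa)  # MQA shrinks the per-token KV cache by the head count (48x here)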
