
Commit ab351a8 (1 parent: 4d3866a)

Add duration metric for model forward pass in isolation

Signed-off-by: Nick Hill <[email protected]>

10 files changed (+69 lines, -39 lines)

proto/generate.proto

Lines changed: 4 additions & 0 deletions
@@ -157,8 +157,12 @@ message PrefillRequest {
 message GenerateResult {
     /// Next tokens
     repeated Token output_tokens = 1;
+    /// Request-specific errors
     repeated GenerateError errors = 2;
     uint64 batch_id = 3;
+
+    /// Time taken by model forward pass in nanoseconds
+    uint64 forward_time_ns = 4;
 }

 message PrefillResponse {
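
The new field is an ordinary proto3 scalar, so it rides along with the existing GenerateResult fields with no special handling. As a minimal sketch of the round trip, assuming the compiled protobuf module is importable as generate_pb2 (the real import path inside the server package may differ):

# Sketch only: the import path below is an assumption, not the package's actual layout.
from pb import generate_pb2  # hypothetical import path

result = generate_pb2.GenerateResult(
    batch_id=7,
    forward_time_ns=1_250_000,  # a 1.25 ms forward pass, expressed in nanoseconds
)

# The uint64 field serializes and parses like any other scalar.
decoded = generate_pb2.GenerateResult.FromString(result.SerializeToString())
assert decoded.forward_time_ns == 1_250_000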

router/client/src/client.rs

Lines changed: 19 additions & 6 deletions
@@ -15,6 +15,8 @@ pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }

+pub type GenerateTokenResponse = (Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64, Duration);
+
 impl Client {
     /// Returns a client connected to the given url
     pub async fn connect(uri: Uri) -> Result<Self> {
@@ -116,7 +118,7 @@ impl Client {
     #[instrument(skip(self))]
     pub async fn prefill(
         &mut self, batch: Batch, to_prune: Vec<CachedBatch>,
-    ) -> Result<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)> {
+    ) -> Result<GenerateTokenResponse> {
         let request = tonic::Request::new(PrefillRequest{
             batch: Some(batch), to_prune,
         });
@@ -129,17 +131,22 @@
         let result = response
             .result
             .ok_or_else(|| ClientError::Generation("Unexpected empty response".into()))?;
-        Ok((result.output_tokens, response.input_tokens, result.errors, result.batch_id))
+        Ok((
+            result.output_tokens,
+            response.input_tokens,
+            result.errors,
+            result.batch_id,
+            Duration::from_nanos(result.forward_time_ns),
+        ))
     }

     /// Generate one token for each request in the given cached batch(es)
     ///
     /// Returns next generated token of each request in the batches and id of the next cached batch
     #[instrument(skip(self))]
     pub async fn next_token(
-        &mut self,
-        batches: Vec<CachedBatch>,
-    ) -> Result<Option<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)>> {
+        &mut self, batches: Vec<CachedBatch>,
+    ) -> Result<Option<GenerateTokenResponse>> {
         let request = tonic::Request::new(
             NextTokenRequest { batches }
         );
@@ -149,6 +156,12 @@ impl Client {
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok(response.result.map(|r| (r.output_tokens, vec![], r.errors, r.batch_id)))
+        Ok(response.result.map(|result| (
+            result.output_tokens,
+            vec![],
+            result.errors,
+            result.batch_id,
+            Duration::from_nanos(result.forward_time_ns),
+        )))
     }
 }

router/client/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ pub use pb::generate::v1::{
 };
 pub use pb::generate::v1::next_token_chooser_parameters::LengthPenalty;
 pub use sharded_client::ShardedClient;
+pub use client::GenerateTokenResponse;
 use thiserror::Error;
 use tonic::transport;
 use tonic::Status;

router/client/src/sharded_client.rs

Lines changed: 8 additions & 10 deletions
@@ -1,11 +1,12 @@
 /// Multi shard Client
-use crate::{ClientError, GenerateError, Result};
-use crate::{Batch, Client, HealthResponse, Token};
+use crate::{ClientError, Result};
+use crate::{Batch, Client, HealthResponse};
 use futures::future::join_all;
 use tokio::runtime::Handle;
 use tokio::sync::{broadcast, mpsc};
 use tonic::transport::Uri;
-use crate::pb::generate::v1::{CachedBatch, InputTokens};
+use crate::client::GenerateTokenResponse;
+use crate::pb::generate::v1::CachedBatch;
 use crate::pb::generate::v1::model_info_response::ModelType;
 use crate::sharded_client::Request::{NextToken, Prefill};

@@ -19,9 +20,7 @@ enum Request {
 #[derive(Debug)]
 pub struct ShardedClient {
     clients: Vec<Client>,
-    sender: broadcast::Sender<(Request, mpsc::Sender<
-        Result<Option<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)>>
-    >)>,
+    sender: broadcast::Sender<(Request, mpsc::Sender<Result<Option<GenerateTokenResponse>>>)>,
     handle: Handle,
 }

@@ -94,7 +93,7 @@
     /// Optionally prunes existing batches first to maximize available memory
     pub async fn prefill(
         &mut self, batch: Batch, to_prune: Vec<CachedBatch>,
-    ) -> Result<Option<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)>> {
+    ) -> Result<Option<GenerateTokenResponse>> {
         if batch.requests.is_empty() {
             return Ok(None);
         }
@@ -108,9 +107,8 @@
     ///
     /// Returns next generated token of each request in the batches and id of the next cached batch
     pub async fn next_token(
-        &mut self,
-        batches: Vec<CachedBatch>,
-    ) -> Result<Option<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)>> {
+        &mut self, batches: Vec<CachedBatch>,
+    ) -> Result<Option<GenerateTokenResponse>> {
         let (tx, mut rx) = mpsc::channel(1);
         self.sender.send((NextToken(batches), tx))
             .map_err(|e| ClientError::Generation(e.to_string()))?;

router/src/batcher.rs

Lines changed: 9 additions & 5 deletions
@@ -13,7 +13,7 @@ use std::task::{Context, Poll};
 use futures::{FutureExt, pin_mut, TryFutureExt};
 use futures::future::Map;
 use nohash_hasher::IntMap;
-use text_generation_client::{ClientError, Token, ShardedClient, CachedBatch, RequestsStatus, InputTokens, GenerateError, Batch};
+use text_generation_client::{ClientError, Token, ShardedClient, CachedBatch, RequestsStatus, InputTokens, GenerateError, Batch, GenerateTokenResponse};
 use thiserror::Error;
 use tokio::select;

@@ -547,9 +547,7 @@
     /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
     async fn _wrap_future<B: BatchType>(
         &mut self,
-        future: impl Future<Output = Result<
-            Option<(Vec<Token>, Vec<InputTokens>, Vec<GenerateError>, u64)>, ClientError
-        >>,
+        future: impl Future<Output = Result<Option<GenerateTokenResponse>, ClientError>>,
         method: &'static str,
         start_time: Instant,
         // First request id in this batch if it doesn't comprise all current entries
@@ -573,7 +571,7 @@

         match result {
             Ok(
-                Some((generated_tokens, input_tokens, errors, next_batch_id))
+                Some((generated_tokens, input_tokens, errors, next_batch_id, forward_duration))
             ) => {
                 self.process_input_tokens(input_tokens);
                 let completed_request_ids = self.process_next_tokens(
@@ -587,6 +585,12 @@
                     "method" => method,
                     "makeup" => "single_only", // later will possibly be beam_only or mixed
                 );
+                metrics::histogram!(
+                    "tgi_batch_inference_forward_duration",
+                    forward_duration,
+                    "method" => method,
+                    "makeup" => "single_only", // later will possibly be beam_only or mixed
+                );
                 // Probably don't need this additional counter because the duration histogram
                 // records a total count
                 metrics::increment_counter!("tgi_batch_inference_success", "method" => method);
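
The forward-pass time is recorded as its own histogram, next to the existing inference-duration histogram emitted just above it, with the same method and makeup labels. The router does this in Rust with the metrics crate, as shown above; purely as an illustration of the same idea in Python, a prometheus_client equivalent might look like this (metric name and labels taken from the diff, everything else assumed):

from prometheus_client import Histogram

# Illustration only: the real metric is emitted by the Rust router, not by Python code.
FORWARD_DURATION = Histogram(
    "tgi_batch_inference_forward_duration",
    "Time spent in the model forward pass, in seconds",
    ["method", "makeup"],
)

def record_forward(method: str, forward_time_ns: int) -> None:
    # Convert the nanosecond figure reported by the model server to seconds.
    FORWARD_DURATION.labels(method=method, makeup="single_only").observe(forward_time_ns / 1e9)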

server/text_generation_server/models/causal_lm.py

Lines changed: 9 additions & 8 deletions
@@ -1,4 +1,5 @@
 import logging
+import time
 from operator import itemgetter

 import torch
@@ -467,7 +468,7 @@ def __init__(

         # Perform a forward pass to determine the ordering of past key attention tensor dimensions
         one_token = torch.tensor([[1]], device=inference_engine.get_device())
-        _, past_key_values = self.forward(input_ids=one_token, attention_mask=one_token)
+        _, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
         key_past, value_past = past_key_values[0]
         keys_head_dim_last = key_past.shape[-1] == value_past.shape[-1]
         self.batch_type = CausalLMBatch if keys_head_dim_last else KeysDimTransposedCausalLMBatch
@@ -487,7 +488,7 @@ def forward(
         position_ids: Optional[torch.Tensor] = None,
         past_key_values: Optional = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]], int]:
         model_inputs = self.model.prepare_inputs_for_generation(
             input_ids, past_key_values,
             attention_mask=attention_mask,
@@ -506,18 +507,18 @@
             model_inputs["inputs_embeds"] = inputs_embeds

         # Model Forward
+        start_time = time.time_ns()
         outputs = self.model.forward(**model_inputs)
-        return (
-            outputs.logits, outputs.past_key_values,
-        )
+        took_ns = time.time_ns() - start_time
+        return outputs.logits, outputs.past_key_values, took_ns

     def generate_token(
         self, batch: CausalLMBatch, first: bool = False, for_concat: bool = False,
-    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError]]:
+    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError], int]:
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

-        logits, past = self.forward(
+        logits, past, forward_time_ns = self.forward(
            batch.input_ids, attention_mask, batch.position_ids, batch.past_key_values, batch.inputs_embeds,
         )

@@ -605,7 +606,7 @@ def generate_token(
         batch.max_sequence_length += 1
         batch.padding_right_offset -= 1

-        return generated_tokens, input_token_infos, decode_errors
+        return generated_tokens, input_token_infos, decode_errors, forward_time_ns


 class KeysDimTransposedCausalLMBatch(CausalLMBatch):
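
The measurement deliberately wraps only the call into the underlying model, so input preparation, token selection, and batch bookkeeping are excluded, and the result is an integer nanosecond count that maps directly onto the forward_time_ns proto field. A minimal standalone sketch of the same pattern, with a hypothetical timed_forward helper standing in for the inlined timing above:

import time
from typing import Any, Callable, Tuple

def timed_forward(model_forward: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, int]:
    """Run a forward pass and return (output, wall-clock duration in nanoseconds)."""
    start = time.time_ns()
    output = model_forward(*args, **kwargs)
    # Integer nanoseconds, the same unit carried by forward_time_ns.
    return output, time.time_ns() - start

Plain wall-clock timing around the call keeps the overhead negligible; it captures whatever work completes (or synchronizes) inside the call itself.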

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,5 @@
 import logging
+import time
 from operator import itemgetter

 import torch
@@ -319,7 +320,7 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:

     def generate_token(
         self, batch: FlashCausalLMBatch, first: bool = False, for_concat: bool = False,
-    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError]]:
+    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError], int]:

         batch_size = len(batch)
         past_key_values = batch.past_key_values if first or batch_size > 1 \
@@ -333,6 +334,7 @@
         else:
             prealloc_length = None

+        start_time = time.time_ns()
         out, present = self.model.forward(
             batch.input_ids,
             batch.position_ids,
@@ -342,6 +344,7 @@
             past_key_values,
             prealloc_length,
         )
+        forward_time_ns = time.time_ns() - start_time

         # Update present
         present_pad = self.present_pad
@@ -369,7 +372,7 @@
         batch.cu_seqlens.add_(batch.cu_seqlens_q)
         batch.max_seqlen += 1

-        return generated_tokens, input_token_infos, decode_errors
+        return generated_tokens, input_token_infos, decode_errors, forward_time_ns

     def _process_prefill(
         self, batch: FlashCausalLMBatch, out,

server/text_generation_server/models/model.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def batch_type(self) -> Type[B]:
     @abstractmethod
     def generate_token(
         self, batch: B, first: bool = False, for_concat: bool = False,
-    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError]]:
+    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError], int]:
         raise NotImplementedError

     @staticmethod

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 9 additions & 5 deletions
@@ -1,4 +1,5 @@
 import logging
+import time
 from operator import itemgetter

 import torch
@@ -492,7 +493,7 @@ def __init__(

         # Perform a forward pass to determine the ordering of past key attention tensor dimensions
         one_token = torch.tensor([[bos_token_id]], device=inference_engine.get_device())
-        _, _, past_key_values = self.forward(
+        _, _, past_key_values, _ = self.forward(
             input_ids=one_token,
             attention_mask=torch.ones_like(one_token),
             decoder_input_ids=one_token,
@@ -523,12 +524,14 @@ def forward(
         torch.Tensor,
         torch.Tensor,
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
+        int,
     ]:
         if inputs_embeds is not None:
             input_ids = None
         if decoder_inputs_embeds is not None:
             decoder_input_ids = None

+        start_time = time.time_ns()
         outputs = self.model.forward(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
@@ -541,21 +544,22 @@
             use_cache=True,
             return_dict=True,
         )
+        took_ns = time.time_ns() - start_time
         return (
-            outputs.logits, outputs.encoder_last_hidden_state, outputs.past_key_values,
+            outputs.logits, outputs.encoder_last_hidden_state, outputs.past_key_values, took_ns
         )

     def generate_token(
         self, batch: Seq2SeqLMBatch, first: bool = False, for_concat: bool = False,
-    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError]]:
+    ) -> Tuple[List[TokenInfo], Optional[List[InputTokens]], List[GenerateError], int]:
         # slice to the correct shape
         decoder_attention_mask = None if batch.decoder_attention_mask is None \
             else batch.decoder_attention_mask[:, : -batch.padding_right_offset]

         encoder_outputs = None if batch.encoder_last_hidden_state is None \
             else BaseModelOutput(last_hidden_state=batch.encoder_last_hidden_state)

-        logits, encoder_last_hidden_state, past = self.forward(
+        logits, encoder_last_hidden_state, past, forward_time_ns = self.forward(
             batch.input_ids,
             batch.attention_mask,
             batch.decoder_input_ids,
@@ -647,7 +651,7 @@ def generate_token(
         batch.max_decoder_input_length += 1
         batch.padding_right_offset -= 1

-        return generated_tokens, input_token_infos, decode_errors
+        return generated_tokens, input_token_infos, decode_errors, forward_time_ns


 class KeysDimTransposedSeq2SeqLMBatch(Seq2SeqLMBatch):

server/text_generation_server/server.py

Lines changed: 4 additions & 2 deletions
@@ -120,7 +120,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context) -> genera
         if batch is not None:
             for_concat = len(self.cache) > 0
             # Prefill and generate first token
-            output_tokens, input_token_info, decode_errors = self.model.generate_token(
+            output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
                 batch, first=True, for_concat=for_concat,
             )
             if not is_healthcheck:
@@ -140,6 +140,7 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context) -> genera
                 ],
                 errors=[err.to_pb() for err in errors] if errors else None,
                 batch_id=batch_id,
+                forward_time_ns=forward_time_ns,
             ),
             input_tokens=[
                 input_tokens.to_pb() for input_tokens in input_token_info
@@ -175,7 +176,7 @@ async def NextToken(self, request: generate_pb2.NextTokenRequest, context) -> ge
         # Ensure batches are garbage-collected post-concatenation
         del batches

-        output_tokens, _, errors = self.model.generate_token(batch)
+        output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
         self.cache.set(batch)

         return generate_pb2.NextTokenResponse(
@@ -185,6 +186,7 @@ async def NextToken(self, request: generate_pb2.NextTokenRequest, context) -> ge
                 ],
                 errors=[err.to_pb() for err in errors] if errors else None,
                 batch_id=batch.get_id(),
+                forward_time_ns=forward_time_ns,
             )
         )
