diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 72a6dfc9b..99fd82fe8 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -388,6 +388,8 @@ def callback(x, *, done_generating=False): device_sync(device=self.builder_args.device) + buffer = [] + ILLEGAL_CHAR = '\ufffd' # Process each token, metrics tuple yielded by Generator.generate. for y, _ in self.generate( model=self.model, @@ -413,10 +415,15 @@ def callback(x, *, done_generating=False): break y = y.view(-1) + buffer.append(y.item()) # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token. content = "".join( - self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:] + self.tokenizer.decode([self.tokenizer.encode(".")[0]] + buffer)[1:] ) + # Skip content while illegal characters appear. + if ILLEGAL_CHAR in content: + continue + buffer.clear() # Package the sequence into a CompletionChunkResponse and yield it. chunk_delta = ChunkDelta(