diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py
index 55001cf3722a..22bdbc3f1e89 100644
--- a/benchmarks/benchmark_serving_structured_output.py
+++ b/benchmarks/benchmark_serving_structured_output.py
@@ -115,6 +115,43 @@ class SampleRequest:
 def sample_requests(
     tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
 ) -> list[SampleRequest]:
+    def _apply_random_prefix(
+        tokenizer: PreTrainedTokenizerBase,
+        requests: list[SampleRequest],
+        prefix_len: int,
+        seed: int,
+    ) -> list[SampleRequest]:
+        if prefix_len <= 0:
+            return requests
+        rng = np.random.default_rng(seed)
+        vocab_size = tokenizer.vocab_size
+        prohibited = getattr(tokenizer, "all_special_ids", None) or []
+        allowed = np.array([i for i in range(vocab_size) if i not in prohibited])
+        if len(allowed) == 0:
+            return requests
+        prefix_ids = rng.integers(0, len(allowed), size=prefix_len)
+        prefix_token_ids = allowed[prefix_ids].tolist()
+        out = []
+        for req in requests:
+            prompt_ids = tokenizer(
+                req.prompt, add_special_tokens=False
+            ).input_ids
+            full_ids = prefix_token_ids + prompt_ids
+            full_prompt = tokenizer.decode(
+                full_ids, skip_special_tokens=False
+            )
+            out.append(
+                SampleRequest(
+                    prompt=full_prompt,
+                    prompt_len=len(full_ids),
+                    expected_output_len=req.expected_output_len,
+                    schema=req.schema,
+                    structure_type=req.structure_type,
+                    completion=req.completion,
+                )
+            )
+        return out
+
     if args.dataset == "json" or args.dataset == "json-unique":
         if args.json_schema_path is None:
             dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -261,6 +298,9 @@ def _filter_func(item):
             )
         )
 
+    requests = _apply_random_prefix(
+        tokenizer, requests, args.random_prefix_len, args.seed
+    )
     return requests
 
 
@@ -945,6 +985,15 @@ def create_argument_parser():
         "results in a more uniform arrival of requests.",
     )
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--random-prefix-len",
+        type=int,
+        default=0,
+        help=(
+            "Number of prefix tokens to prepend to every prompt. "
+            "The same prefix is used for all prompts to enable prefix caching."
+        ),
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
diff --git a/tt_metal/README.md b/tt_metal/README.md
index ec298ef0e286..54314f3aace3 100644
--- a/tt_metal/README.md
+++ b/tt_metal/README.md
@@ -124,7 +124,7 @@ MESH_DEVICE=T3K python examples/offline_inference_tt.py --measure_perf
 
 Example commands:
 
-- To run the Llama70B model on Galaxy: `MESH_DEVICE=TG LLAMA_DIR= TT_LLAMA_TEXT_VER="llama3_70b_galaxy" python examples/offline_inference_tt.py --model "meta-llama/Llama-3.1-70B-Instruct" --override_tt_config '{"dispatch_core_axis": "col", "sample_on_device_mode": "all", "fabric_config": "FABRIC_1D_RING", "worker_l1_size": 1344544, "trace_region_size": 184915840}'`
+- To run the Llama70B model on Galaxy: `MESH_DEVICE=TG LLAMA_DIR= TT_LLAMA_TEXT_VER="llama3_70b_galaxy" python examples/offline_inference_tt.py --model "meta-llama/Llama-3.1-70B-Instruct" --override_tt_config '{"dispatch_core_axis": "col", "sample_on_device_mode": "all", "fabric_config": "FABRIC_1D_RING", "worker_l1_size": 1344544, "trace_region_size": 216580672}'`
 - To run the 20B gpt-oss model on Galaxy: `MESH_DEVICE="(4,8)" python examples/offline_inference_tt.py --model "openai/gpt-oss-20b" --max_seqs_in_batch 1 --override_tt_config '{"fabric_config": "FABRIC_1D_RING"}'`
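
A minimal sketch of what the new --random-prefix-len flag does, not part of the patch itself: the prefix token ids are drawn once from the seeded generator, outside the per-request loop, so every prompt in a run shares the same prefix and reruns with the same --seed reproduce it. The snippet assumes numpy and transformers are installed; "gpt2" is only an illustrative tokenizer and a prefix length of 8 stands in for the CLI value.

    # Sketch of the prefix sampling in _apply_random_prefix (illustrative).
    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model
    rng = np.random.default_rng(0)  # corresponds to --seed 0

    # Exclude special tokens so the random prefix stays plain text.
    special = set(getattr(tokenizer, "all_special_ids", None) or [])
    allowed = np.array(
        [i for i in range(tokenizer.vocab_size) if i not in special]
    )

    # Sampled once: identical for every prompt in the run.
    prefix_token_ids = allowed[rng.integers(0, len(allowed), size=8)].tolist()

    prompt_ids = tokenizer("Hello", add_special_tokens=False).input_ids
    full_prompt = tokenizer.decode(prefix_token_ids + prompt_ids)
    print(full_prompt)  # same prefix text on every run with the same seed

With the patch applied, passing e.g. --random-prefix-len 512 to benchmark_serving_structured_output.py prepends the same 512 sampled tokens to every prompt, so a serving engine with prefix caching enabled can reuse the cached prefix across requests, as the flag's help text describes.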