diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 61fa16a326e..78d7dc5abc1 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1909,7 +1909,6 @@ class GenericLlmRequest std::shared_ptr> mSequenceFinalVec; std::optional mSkipCrossAttnBlocks{std::nullopt}; - SizeType32 mNumVocabs; // Performance metrics. bool mReturnPerfMetrics{false}; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 1cd10526632..4bb439193eb 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -711,12 +711,7 @@ class Request [[nodiscard]] std::optional getLanguageAdapterUid() const; [[nodiscard]] std::optional getAllottedTimeMs() const; [[nodiscard]] std::optional> getAdditionalOutputNames() const; -<<<<<<< HEAD - -======= - [[nodiscard]] std::optional getLanguageAdapterUid() const; [[nodiscard]] SizeType32 getNumVocabs() const; ->>>>>>> Fixes to compilation void setStreaming(bool streaming); void setSamplingConfig(SamplingConfig const& config); void setOutputConfig(OutputConfig const& outputConfig); @@ -748,12 +743,8 @@ class Request void setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks); void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams); void setLanguageAdapterUid(SizeType32 languageAdapterUid); -<<<<<<< HEAD void setAllottedTimeMs(MillisecondsType allottedTimeMs); - -======= void setNumVocabs(SizeType32 numVocabs); ->>>>>>> Fixes to compilation private: friend class Serialization; class Impl; diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index de53eb6ed5e..ba57d503689 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -2055,8 +2055,6 @@ runtime::CudaEvent TrtGptModelInflightBatching::updateDecoderBuffers( if (returnLogProbs) { - mDecoderBuffers[vocabId]->cumLogProbs = mDecoders[vocabId]->getDecoderState().getCumLogProbs(); - mDecoderBuffers[vocabId]->logProbs = mDecoders[vocabId]->getDecoderState().getLogProbs(); mCopyBufferManager.copy( *mDecoders[vocabId]->getDecoderState().getCumLogProbs(), *mDecoderBuffers[vocabId]->cumLogProbsHost diff --git a/examples/models/core/t5tts/README.md b/examples/models/core/t5tts/README.md new file mode 100644 index 00000000000..11017eea0e5 --- /dev/null +++ b/examples/models/core/t5tts/README.md @@ -0,0 +1,88 @@ + +# Build TRTLLM + +This describes how to run the t5tts in TRTLLM. 
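+
+As background for the rest of this README (taken from `convert_checkpoint.py` and `run.py` below): the encoder consumes phoneme tokens (vocab size 98), and the decoder autoregressively predicts audio tokens over 8 parallel codebooks of 2048 entries each, giving a flat decoder vocab size of 16384 = 8 * 2048. A minimal sketch of the flat-token layout, using hypothetical helper names that are not part of the example scripts:
+
+```python
+NUM_CODEBOOKS = 8      # config["decoder"]["num_vocabs"] in convert_checkpoint.py
+CODEBOOK_SIZE = 2048   # 16384 // NUM_CODEBOOKS
+
+
+def pack_token(codebook: int, code: int) -> int:
+    """Flat decoder token id, as built in run.py: decoder_input_ids[:, i] += 2048 * i."""
+    return codebook * CODEBOOK_SIZE + code
+
+
+def split_token(token_id: int) -> tuple[int, int]:
+    """Inverse mapping; run_tts.py recovers the per-codebook codes with output_ids % 2048."""
+    return token_id // CODEBOOK_SIZE, token_id % CODEBOOK_SIZE
+```
+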
+Build docker and compile TRTLLM as usual: + +```bash +make -C docker build IMAGE_NAME=t5tts +make -C docker run LOCAL_USER=1 IMAGE_NAME=t5tts CONTAINER_NAME=t5tts +# 90-real - for H100 +python3 ./scripts/build_wheel.py --cuda_architectures "90-real" --benchmarks --trt_root /usr/local/tensorrt +pip install build/tensorrt_llm-0.20.0rc0-cp312-cp312-linux_x86_64.whl +``` + +# Build Engine + +Convert the checkpoint and build the engine: +```bash +# required to pip install omegaconf +# md5sum newmodels/t5tts.ckpt: fb177acdc447af56c8bbfa9d17c75f45 +python examples/models/core/t5tts/convert_checkpoint.py \ + --model_path newmodels/t5tts.ckpt --output_dir newmodels/t5tts_convert + +trtllm-build --checkpoint_dir newmodels/t5tts_convert/encoder/ \ +--output_dir newmodels/t5tts_engine/encoder \ +--paged_kv_cache enable --moe_plugin disable --max_beam_width 1 \ +--max_batch_size 256 --max_input_len 128 --gemm_plugin float16 \ +--bert_attention_plugin float16 --gpt_attention_plugin float16 \ +--remove_input_padding enable --use_paged_context_fmha enable + +trtllm-build --checkpoint_dir newmodels/t5tts_convert/decoder \ + --output_dir newmodels/t5tts_engine/decoder \ + --moe_plugin disable \ + --max_beam_width 1 \ + --max_batch_size 64 \ + --max_input_len 192 \ + --max_seq_len 512 \ + --max_encoder_input_len 512 \ + --gemm_plugin float16 \ + --bert_attention_plugin float16 \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --use_paged_context_fmha enable +``` + +# Toy inference + +Finally run the model on the dummy input: +```bash +python examples/models/core/t5tts/run.py +``` + +# Benchmark + +gpt manager benchmark is modified to run benchmark with context for decoder. + +```bash +# prepare dummy inputs for inference +# 128 - number of phonemes in avergage sentence +# 160 - context length in frames, corresponds to 160 / 21.5 = 7.44 seconds +# 640 - total sequence length in frames, means 640 - 160 = 480 frames of audio generated, +# which corresponds to 480 / 21.5 = 22.33 seconds +# 768 - batch_size * 3, measure performance on 3 batches at max utilization +python examples/models/core/enc_dec/prepare_benchmark.py --output benchmark.json \ + --samples 768 \ + --max_input_id 98 \ + --num_vocabs 8 \ + --input_len 128 0 128 128 \ + --context_len 160 0 160 160 \ + --output_len 640 0 640 640 + +# run benchmark using generated dummy inputs +./cpp/build/benchmarks/gptManagerBenchmark \ + --dataset benchmark.json \ + --output_csv res.csv \ + --max_batch_size 256 \ + --concurrency 256 \ + --streaming \ + --num_vocabs 8 \ + --enable_chunked_context \ + --encoder_engine_dir newmodels/t5tts_engine/encoder \ + --decoder_engine_dir newmodels/t5tts_engine/decoder 2>&1 > /dev/null + +# print results from res.csv +python3 -c "import csv; f=open('res.csv'); r=csv.reader(f); h=next(r); v=next(r); [print(f'{h[i]:<50}: {v[i]}') for i in range(len(h))]" +``` + + diff --git a/examples/models/core/t5tts/convert_checkpoint.py b/examples/models/core/t5tts/convert_checkpoint.py new file mode 100644 index 00000000000..ab1b929e4e1 --- /dev/null +++ b/examples/models/core/t5tts/convert_checkpoint.py @@ -0,0 +1,614 @@ +import argparse +import configparser +import json +import logging +import os +import types +from datetime import datetime +from pathlib import Path + +import safetensors +import torch + +from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, + MLPType) + +dir_path = os.path.dirname(os.path.realpath(__file__)) +LOGGER = logging.getLogger(__name__) + +layernorm_type_map = {i.name: 
i.value for i in LayerNormType} +layernorm_position_map = {i.name: i.value for i in LayerNormPositionType} +mlp_type_map = {i.name: i.value for i in MLPType} + +TORCH_DTYPES = { + 'float32': torch.float32, + 'float64': torch.float64, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, +} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--quant_ckpt_path', type=str, default=None) + parser.add_argument('--model_name', type=str) + parser.add_argument( + '--model_path', + type=str, + default=None, + ) + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--logits_dtype', + type=str, + default='float16', + choices=['float16', 'float32']) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument('engine_dir') + args = parser.parse_args() + return args + + +def copy_args_to_component_config(component_config, args): + for arg in vars(args): + setattr(component_config, arg, getattr(args, arg)) + return component_config + + +def parse_model_config(args, ): + config = configparser.ConfigParser() + + config["encoder"] = {} + config["decoder"] = {} + + config["encoder"]["num_heads"] = "12" + config["encoder"]['d_model'] = "768" #hidden_size + config["encoder"]['d_ffn'] = "3072" #ffn_hidden_size + config["encoder"]['vocab_size'] = "98" # used to be 106339 in the branch + config["encoder"]['n_positions'] = "2048" + config["encoder"]['has_position_embedding'] = "true" + #config["encoder"]['has_token_type_embedding'] = + config["encoder"]['layernorm_position'] = "pre_layernorm" + + config["encoder"]['layernorm_type'] = "RmsNorm" + config["encoder"]['num_layers'] = "6" + # config["encoder"]['d_model'] /config["encoder"]["num_heads"] + config["encoder"]['d_kv'] = f"{int(768/12)}" + + config["decoder"]["num_heads"] = "12" + config["decoder"]['d_model'] = "768" #hidden_size + config["decoder"]['d_ffn'] = "3072" #ffn_hidden_size + config["decoder"]['vocab_size'] = "16384" # 8 * 2048 + config["decoder"]['n_positions'] = "2048" + config["decoder"]['has_position_embedding'] = "true" + config["decoder"]['layernorm_position'] = "pre_layernorm" + + config["decoder"]['layernorm_type'] = "RmsNorm" + config["decoder"]['num_layers'] = "12" + config["decoder"]["num_vocabs"] = "8" + + # manually set q_scaling to offset attention scaling's effect. 
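+    # Note: the TRT-LLM attention kernels scale QK^T by 1 / (q_scaling * sqrt(head_size)),
+    # so q_scaling = 1 / sqrt(head_size) cancels the built-in 1/sqrt(head_size) factor
+    # (this model, like T5, expects unscaled attention scores).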
+ # TODO: modify kernels to control whether to disable attention scaling + def get_offset_q_scaling(config): + scaling = 1 / config.head_size**.5 + return scaling + + config["structure"] = dict() + config["structure"]["t5_with_bias"] = "false" + #config["structure"]["use_gated_activation"] = str(hf_model.encoder.config.is_gated_act) + config["structure"]["position_embedding_type"] = "learned_absolute" + config["structure"]["model_type"] = "T5TTS" + + def parse_t5_config_by_component(config, component, args): + component_config = types.SimpleNamespace() + component_config = copy_args_to_component_config(component_config, args) + component_config.n_head = config.getint(component, 'num_heads') + component_config.hidden_size = config.getint(component, 'd_model') + component_config.head_size = component_config.hidden_size // component_config.n_head + + component_config.ffn_hidden_size = config.getint(component, 'd_ffn') + component_config.vocab_size = config.getint(component, 'vocab_size') + component_config.n_positions = config.getint(component, + 'n_positions', + fallback=2048) + + component_config.has_position_embedding = config.getboolean( + component, 'has_position_embedding', + fallback=False) # TODO: hardcoded here + + component_config.has_token_type_embedding = config.getboolean( + component, 'has_token_type_embedding', fallback=False) + component_config.has_embedding_layernorm = config.getboolean( + component, 'has_embedding_layernorm', fallback=False) + component_config.has_embedding_scale = config.getboolean( + component, 'has_embedding_scale', fallback=False) + component_config.q_scaling = get_offset_q_scaling(component_config) + component_config.has_attention_qkvo_bias = config.getboolean( + component, 'has_attention_qkvo_bias', + fallback=False) # TODO: hardcoded here + component_config.has_mlp_bias = config.getboolean(component, + 'has_mlp_bias', + fallback=False) + component_config.has_model_final_layernorm = config.getboolean( + component, 'has_model_final_layernorm', fallback=True) + component_config.layernorm_eps = config.getfloat(component, + 'layer_norm_epsilon', + fallback=1e-5) + component_config.layernorm_position = layernorm_position_map[config.get( + component, 'layernorm_position', + fallback='pre_layernorm')] # TODO: hardcoded here + component_config.layernorm_type = layernorm_type_map[config.get( + component, 'layernorm_type', fallback='RmsNorm')] + component_config.hidden_act = config.get(component, + 'dense_act_fn', + fallback="gelu") + component_config.gated_act = config.getboolean(component, + 'is_gated_act', + fallback=True) + #component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.gated_act else 'MLP'] + component_config.num_buckets = config.getint( + component, 'relative_attention_num_buckets', fallback=0) + component_config.max_distance = config.getint( + component, 'relative_attention_max_distance', fallback=0) + component_config.position_embedding_type = config.get( + 'structure', 'position_embedding_type') + component_config.logits_dtype = config.get(component, + 'logits_dtype', + fallback='float16') + + if component == 'encoder': + component_config.n_layer = config.getint(component, 'num_layers') + + component_config.relative_attention = config.get( + 'structure', 'position_embedding_type') == 'relative' + + elif component == 'decoder': + component_config.n_layer = config.getint(component, 'num_layers') + component_config.has_lm_head_bias = config.getboolean( + component, # TODO: T5 with bias + 'has_lm_head_bias', + fallback=False) + 
component_config.relative_attention = config.getboolean( + component, 'relative_attention', fallback=False) + component_config.rescale_before_lm_head = config.getboolean( + component, + 'tie_word_embeddings', + fallback=True, + ) # default is True (for T5), but False for Flan-T5 + component_config.encoder_hidden_size = config.getint( + 'encoder', 'd_model') + component_config.encoder_num_heads = config.getint( + 'encoder', 'num_heads') + component_config.encoder_head_size = config.getint( + 'encoder', 'd_kv') + #FIXME: check what is the correct generation process for the given checkpoint + component_config.decoder_start_token_id = config.getint( + 'decoder', 'decoder_start_token_id', fallback=106339 - 2) + component_config.eos_token_id = config.getint('decoder', + 'eos_token_id', + fallback=2048 - 1) + bos_token_id = config.get('decoder', + 'bos_token_id', + fallback=2048 - 2) + # T5 does not have bos_token_id + component_config.bos_token_id = int( + bos_token_id) if bos_token_id != "None" else None + component_config.pad_token_id = config.getint('decoder', + 'pad_token_id', + fallback=0) + + vocab_size = config.getint('decoder', 'vocab_size') + num_vocabs = config.getint('decoder', 'num_vocabs') + component_config.vocab_sizes = [vocab_size // num_vocabs] * num_vocabs + + else: + assert False, 'Unsupported component!' + + return component_config + + encoder_config = parse_t5_config_by_component(config, "encoder", args) + decoder_config = parse_t5_config_by_component(config, "decoder", args) + + return encoder_config, decoder_config + + +def convert_t5tts_encoder( + config, + model_dict, + quant_algo: str = None, +): + weights = {} + weights['embedding.vocab_embedding.weight'] = model_dict[ + 'text_embedding.weight'].contiguous() + weights['embedding.position_embedding.weight'] = model_dict[ + 't5_encoder.position_embeddings.weight'].contiguous() + + num_layers = config.n_layer + for i in range(num_layers): + weights[f'encoder_layers.{i}.attention_layernorm.weight'] = model_dict[ + f't5_encoder.layers.{i}.norm_self.weight'].contiguous() + weights[f'encoder_layers.{i}.attention.qkv.weight'] = model_dict[ + f't5_encoder.layers.{i}.self_attention.qkv_net.weight'].contiguous( + ) + weights[f'encoder_layers.{i}.attention.dense.weight'] = model_dict[ + f't5_encoder.layers.{i}.self_attention.o_net.weight'].contiguous() + weights[f'encoder_layers.{i}.pos_ff_layernorm.weight'] = model_dict[ + f't5_encoder.layers.{i}.norm_pos_ff.weight'].contiguous() + weights[f'encoder_layers.{i}.pos_ff.proj.weight'] = model_dict[ + f't5_encoder.layers.{i}.pos_ff.proj.conv.weight'].unsqueeze( + 3).contiguous() + weights[f'encoder_layers.{i}.pos_ff.o_net.weight'] = model_dict[ + f't5_encoder.layers.{i}.pos_ff.o_net.conv.weight'].unsqueeze( + 3).contiguous() + + weights['final_layernorm.weight'] = model_dict[ + f't5_encoder.norm_out.weight'].contiguous() + + return weights + + +def convert_t5tts_decoder( + config, + model_dict, + quant_algo: str = None, +): + weights = {} + #weights['embedding.vocab_embedding.weight'] = model_dict['final_proj.weight'].clone().contiguous() + + weights['lm_head.weight'] = model_dict['final_proj.weight'].clone( + ).contiguous() + + weights['embedding.position_embedding.weight'] = model_dict[ + 't5_decoder.position_embeddings.weight'].contiguous() + + weights[f'embedding.vocab_embedding.weight'] = torch.cat( + [model_dict[f'audio_embeddings.{i}.weight'] for i in range(len(config.vocab_sizes))], dim=0 + ).contiguous() + + num_layers = config.n_layer + for i in range(num_layers): + weights[ 
+ f'decoder_layers.{i}.self_attention_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_self.weight'].contiguous() + weights[f'decoder_layers.{i}.self_attention.qkv.weight'] = model_dict[ + f't5_decoder.layers.{i}.self_attention.qkv_net.weight'].contiguous( + ) + weights[f'decoder_layers.{i}.self_attention.dense.weight'] = model_dict[ + f't5_decoder.layers.{i}.self_attention.o_net.weight'].contiguous() + weights[ + f'decoder_layers.{i}.cross_attention_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_xattn_query.weight'].contiguous() + + t = torch.cat([ + model_dict[f't5_decoder.layers.{i}.cross_attention.q_net.weight'], + model_dict[f't5_decoder.layers.{i}.cross_attention.kv_net.weight'] + ], dim=0).contiguous() + + weights[f'decoder_layers.{i}.cross_attention.qkv.weight'] = t + weights[f'decoder_layers.{i}.cross_attention.dense.weight'] = model_dict[ + f't5_decoder.layers.{i}.cross_attention.o_net.weight'].contiguous() + weights[f'decoder_layers.{i}.pos_ff_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_pos_ff.weight'].contiguous() + weights[ + f'decoder_layers.{i}.cross_attention_memory_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_xattn_memory.weight'].contiguous() + weights[f'decoder_layers.{i}.pos_ff.proj.weight'] = model_dict[ + f't5_decoder.layers.{i}.pos_ff.proj.conv.weight'].unsqueeze( + 3).contiguous() + weights[f'decoder_layers.{i}.pos_ff.o_net.weight'] = model_dict[ + f't5_decoder.layers.{i}.pos_ff.o_net.conv.weight'].unsqueeze( + 3).contiguous() + + weights['final_layernorm.weight'] = model_dict[ + f't5_decoder.norm_out.weight'].contiguous() + + component_save_dir = os.path.join(args.output_dir, "decoder") + os.makedirs(component_save_dir, exist_ok=True) + return weights + + +def get_obj_dict(obj): + return obj.__dict__ + + +def convert_checkpoint(args, model): + + saved_dir = Path(args.output_dir) + saved_dir.mkdir(parents=True, exist_ok=True) + + encoder_saved_dir = saved_dir / "encoder" + encoder_saved_dir.mkdir(parents=True, exist_ok=True) + decoder_saved_dir = saved_dir / "decoder" + decoder_saved_dir.mkdir(parents=True, exist_ok=True) + + world_size = args.tp_size * args.pp_size + + kv_cache_quant_algo = None + quant_algo = None + + encoder_config, decoder_config = parse_model_config(args, ) + + additional_settings = ["gated_act"] + + tllm_encoder_config = { + 'architecture': "T5TTSEncoderModel", + 'dtype': args.dtype, + 'logits_dtype': encoder_config.logits_dtype, + 'num_hidden_layers': encoder_config.n_layer, + 'num_attention_heads': encoder_config.n_head, + 'hidden_size': encoder_config.hidden_size, + 'norm_epsilon': encoder_config.layernorm_eps, + 'vocab_size': encoder_config.vocab_size, + 'position_embedding_type': encoder_config.position_embedding_type, + 'hidden_act': encoder_config.hidden_act, + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'max_position_embeddings': encoder_config.n_positions, + 'num_key_value_heads': encoder_config.n_head, + 'head_size': encoder_config.head_size, + 'has_position_embedding': encoder_config.has_position_embedding, + 'layernorm_type': encoder_config.layernorm_type, + 'has_attention_qkvo_bias': encoder_config.has_attention_qkvo_bias, + 'has_mlp_bias': encoder_config.has_mlp_bias, + 
'has_model_final_layernorm': encoder_config.has_model_final_layernorm, + 'has_embedding_layernorm': encoder_config.has_embedding_layernorm, + 'has_embedding_scale': encoder_config.has_embedding_scale, + 'intermediate_size': encoder_config.ffn_hidden_size, + 'q_scaling': encoder_config.q_scaling, + 'layernorm_position': encoder_config.layernorm_position, + 'relative_attention': encoder_config.relative_attention, + 'max_distance': encoder_config.max_distance, + 'num_buckets': encoder_config.num_buckets, + 'model_type': "t5tts" + } + + for additional_setting in additional_settings: + if hasattr(encoder_config, additional_setting): + tllm_encoder_config.update({ + additional_setting: + getattr(encoder_config, additional_setting) + }) + + tllm_decoder_config = { + 'architecture': "T5TTSDecoderModel", + 'dtype': args.dtype, + 'logits_dtype': decoder_config.logits_dtype, + 'num_hidden_layers': decoder_config.n_layer, + 'num_attention_heads': decoder_config.n_head, + 'hidden_size': decoder_config.hidden_size, + 'norm_epsilon': decoder_config.layernorm_eps, + 'vocab_size': decoder_config.vocab_size, + 'vocab_sizes': decoder_config.vocab_sizes, + 'position_embedding_type': decoder_config.position_embedding_type, + 'hidden_act': decoder_config.hidden_act, + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'max_position_embeddings': decoder_config.n_positions, + 'head_size': decoder_config.head_size, + 'has_position_embedding': decoder_config.has_position_embedding, + 'layernorm_type': decoder_config.layernorm_type, + 'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias, + 'has_mlp_bias': decoder_config.has_mlp_bias, + 'has_model_final_layernorm': decoder_config.has_model_final_layernorm, + 'has_embedding_layernorm': decoder_config.has_embedding_layernorm, + 'has_embedding_scale': decoder_config.has_embedding_scale, + 'intermediate_size': decoder_config.ffn_hidden_size, + 'q_scaling': decoder_config.q_scaling, + 'layernorm_position': decoder_config.layernorm_position, + 'relative_attention': decoder_config.relative_attention, + 'max_distance': decoder_config.max_distance, + 'num_buckets': decoder_config.num_buckets, + 'model_type': "t5tts", + 'rescale_before_lm_head': decoder_config.rescale_before_lm_head, + 'encoder_hidden_size': decoder_config.encoder_hidden_size, + 'encoder_num_heads': decoder_config.encoder_num_heads, + 'encoder_head_size': decoder_config.encoder_head_size, + 'skip_cross_kv': args.skip_cross_kv, + 'use_implicit_relative_attention': args.use_implicit_relative_attention, + 'decoder_start_token_id': decoder_config.decoder_start_token_id, + 'eos_token_id': decoder_config.eos_token_id, + 'bos_token_id': decoder_config.bos_token_id, + 'pad_token_id': decoder_config.pad_token_id, + 'cross_attention': True, # this has to be provided explicitely + } + for additional_setting in additional_settings: + if hasattr(decoder_config, additional_setting): + tllm_decoder_config.update({ + additional_setting: + getattr(decoder_config, additional_setting) + }) + + def convert_and_save(component: str = "encoder", ): + # call get_encoder_config or get_decoder_config according to component + if component == "encoder": + config = tllm_encoder_config + else: + config = tllm_decoder_config + + component_save_dir = 
os.path.join(args.output_dir, component) + if not os.path.exists(component_save_dir): + os.makedirs(component_save_dir) + + with open(os.path.join(component_save_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4, default=get_obj_dict) + + if args.use_weight_only and args.weight_only_precision == 'int4_gptq': + config['quantization'].update({ + 'has_zero_point': True, + }) + + quant_algo = None + """ + plugin_weight_only_quant_type = None + if args.use_weight_only and args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + quant_algo = QuantAlgo.W8A16 + elif args.use_weight_only and args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + quant_algo = QuantAlgo.W4A16 + elif args.use_weight_only and args.weight_only_precision == 'int4_gptq': + quant_algo = QuantAlgo.W4A16_GPTQ + """ + + if component == "encoder": + + weights = convert_t5tts_encoder(encoder_config, + model_state_dict, + quant_algo=quant_algo) + else: + assert component == "decoder" + weights = convert_t5tts_decoder(decoder_config, + model_state_dict, + quant_algo=quant_algo) + + safetensors.torch.save_file( + weights, os.path.join(component_save_dir, f'rank0.safetensors')) + + convert_and_save(component="encoder") + convert_and_save(component="decoder") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--model_path', type=str, required=True) + + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + + parser.add_argument( + "--workers", + type=int, + help="How many workers to spawn for conversion (default: 4)", + default=4) + + parser.add_argument("--verbose", + action="store_true", + help="Provide verbose messages") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharding is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--dtype', + type=str, + default='float16', + choices=['float16', 'float32', 'bfloat16'], + help= + 'Target inference dtype. Weights and Computation will be in this dtype, no matter what original dtype the weight checkpoint has.' 
+ ) + parser.add_argument( + '--skip_cross_kv', + action='store_true', + help= + 'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).' + ) + parser.add_argument( + '--use_implicit_relative_attention', + action='store_true', + help= + 'Compute relative attention bias on the fly instead of pre-compute a relative attention bias table.' + ) + args = parser.parse_args() + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, + format=log_format) + LOGGER.info("\n=============== Argument ===============") + for key in vars(args): + LOGGER.info(f"{key}: {vars(args)[key]}") + LOGGER.info("========================================") + + start_time = datetime.now() + + model_metadata = {} + model_state_dict = torch.load(args.model_path, + weights_only=False)['state_dict'] + for k in model_state_dict: + model_state_dict[k] = model_state_dict[k].to( + dtype=TORCH_DTYPES[args.dtype]) + convert_checkpoint(args, model_state_dict) + + stop_time = datetime.now() + run_time = (stop_time - start_time) + LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time)) \ No newline at end of file diff --git a/examples/models/core/t5tts/run.py b/examples/models/core/t5tts/run.py new file mode 100644 index 00000000000..b3472e7e85f --- /dev/null +++ b/examples/models/core/t5tts/run.py @@ -0,0 +1,57 @@ +from tensorrt_llm.runtime import ModelRunnerCpp +import torch + + +def main(): + + runner = ModelRunnerCpp.from_dir( + engine_dir='newmodels/t5tts_engine/', + is_enc_dec=True, + max_input_len=512, + cross_kv_cache_fraction=0.5, + rank=0, + ) + + encoder_tokens = torch.tensor([50] * 128, dtype=torch.int32) + print(f"Encoder input tokens: {str(encoder_tokens)}") + + # Create batch_size=512 by repeating the same input + batch_size = 3 + + # Use [0] as decoder input (decoder start token) and repeat for batch_size + decoder_input_ids = [0] * 16 + decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.int32) + decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, 8) + for i in range(8): + decoder_input_ids[:, i] += 2048 * i + decoder_input_ids = decoder_input_ids.flatten() + print(f"Decoder input tokens: {str(len(decoder_input_ids))}") + decoder_input_ids = [ + decoder_input_ids + for _ in range(batch_size) + ] + + # Use the tokenized input as encoder input and repeat for batch_size + encoder_input_ids = [ + encoder_tokens + for _ in range(batch_size) + ] + + with torch.no_grad(): + outputs = runner.generate( + batch_input_ids=decoder_input_ids, + encoder_input_ids=encoder_input_ids, + max_new_tokens=32, + end_id=1, + pad_id=0, + streaming=False, + ) + torch.cuda.synchronize() + + output_ids = outputs.cpu().numpy()[0][0] + print(f"Output tokens (len {len(output_ids)}):") + print(output_ids) + + +if __name__ == "__main__": + main() diff --git a/examples/t5tts/README.md b/examples/t5tts/README.md new file mode 100644 index 00000000000..dfab26585c7 --- /dev/null +++ b/examples/t5tts/README.md @@ -0,0 +1,86 @@ + +# Build TRTLLM + +This describes how to run the t5tts in TRTLLM. 
+Build docker and compile TRTLLM as usual: + +```bash +make -C docker build IMAGE_NAME=t5tts +make -C docker run LOCAL_USER=1 IMAGE_NAME=t5tts CONTAINER_NAME=t5tts +# 90-real - for H100 +python3 ./scripts/build_wheel.py --cuda_architectures "90-real" --benchmarks --trt_root /usr/local/tensorrt +pip install build/tensorrt_llm-0.20.0rc0-cp312-cp312-linux_x86_64.whl +``` + +# Build Engine + +Convert the checkpoint and build the engine: +```bash +# required to pip install omegaconf +# md5sum newmodels/t5tts.ckpt: fb177acdc447af56c8bbfa9d17c75f45 +python examples/models/core/t5tts/convert_checkpoint.py \ + --model_path newmodels/t5tts.ckpt --output_dir newmodels/t5tts_convert + +trtllm-build --checkpoint_dir newmodels/t5tts_convert/encoder/ \ +--output_dir newmodels/t5tts_engine/encoder \ +--paged_kv_cache enable --moe_plugin disable --max_beam_width 1 \ +--max_batch_size 256 --max_input_len 128 --gemm_plugin float16 \ +--bert_attention_plugin float16 --gpt_attention_plugin float16 \ +--remove_input_padding enable --use_paged_context_fmha enable + +trtllm-build --checkpoint_dir newmodels/t5tts_convert/decoder \ + --output_dir newmodels/t5tts_engine/decoder \ + --moe_plugin disable \ + --max_beam_width 1 \ + --max_batch_size 64 \ + --max_input_len 192 \ + --max_seq_len 512 \ + --max_encoder_input_len 512 \ + --gemm_plugin float16 \ + --bert_attention_plugin float16 \ + --gpt_attention_plugin float16 \ + --remove_input_padding enable \ + --use_paged_context_fmha enable +``` + +# Toy inference + +Finally run the model on the dummy input: +```bash +python examples/models/core/t5tts/run.py +``` + +# Benchmark + +gpt manager benchmark is modified to run benchmark with context for decoder. + +```bash +# prepare dummy inputs for inference +# 128 - number of phonemes in avergage sentence +# 160 - context length in frames, corresponds to 160 / 21.5 = 7.44 seconds +# 640 - total sequence length in frames, means 640 - 160 = 480 frames of audio generated, +# which corresponds to 480 / 21.5 = 22.33 seconds +# 768 - batch_size * 3, measure performance on 3 batches at max utilization +python examples/models/core/enc_dec/prepare_benchmark.py --output benchmark.json \ + --samples 768 \ + --max_input_id 98 \ + --num_vocabs 8 \ + --input_len 128 0 128 128 \ + --context_len 160 0 160 160 \ + --output_len 640 0 640 640 + +# run benchmark using generated dummy inputs +./cpp/build/benchmarks/gptManagerBenchmark \ + --dataset benchmark.json \ + --output_csv res.csv \ + --max_batch_size 256 \ + --concurrency 256 \ + --streaming \ + --num_vocabs 8 \ + --enable_chunked_context \ + --encoder_engine_dir newmodels/t5tts_engine/encoder \ + --decoder_engine_dir newmodels/t5tts_engine/decoder 2>&1 > /dev/null + +# print results from res.csv +python3 -c "import csv; f=open('res.csv'); r=csv.reader(f); h=next(r); v=next(r); [print(f'{h[i]:<50}: {v[i]}') for i in range(len(h))]" +``` diff --git a/examples/t5tts/convert_checkpoint.py b/examples/t5tts/convert_checkpoint.py new file mode 100644 index 00000000000..9c6b51d2411 --- /dev/null +++ b/examples/t5tts/convert_checkpoint.py @@ -0,0 +1,626 @@ +import argparse +import configparser +import json +import logging +import os +import types +from datetime import datetime +from pathlib import Path + +import safetensors +import torch + +from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType, + MLPType) + +dir_path = os.path.dirname(os.path.realpath(__file__)) +LOGGER = logging.getLogger(__name__) + +layernorm_type_map = {i.name: i.value for i in LayerNormType} 
+layernorm_position_map = {i.name: i.value for i in LayerNormPositionType} +mlp_type_map = {i.name: i.value for i in MLPType} + +TORCH_DTYPES = { + 'float32': torch.float32, + 'float64': torch.float64, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, +} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--quant_ckpt_path', type=str, default=None) + parser.add_argument('--model_name', type=str) + parser.add_argument( + '--model_path', + type=str, + default=None, + ) + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--logits_dtype', + type=str, + default='float16', + choices=['float16', 'float32']) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument('engine_dir') + args = parser.parse_args() + print(args) + return args + + +def copy_args_to_component_config(component_config, args): + for arg in vars(args): + setattr(component_config, arg, getattr(args, arg)) + return component_config + + +def parse_model_config(args, ): + config = configparser.ConfigParser() + + config["encoder"] = {} + config["decoder"] = {} + + config["encoder"]["num_heads"] = "12" + config["encoder"]['d_model'] = "768" #hidden_size + config["encoder"]['d_ffn'] = "3072" #ffn_hidden_size + config["encoder"]['vocab_size'] = "98" # used to be 106339 in the branch + config["encoder"]['n_positions'] = "2048" + config["encoder"]['has_position_embedding'] = "true" + #config["encoder"]['has_token_type_embedding'] = + config["encoder"]['layernorm_position'] = "pre_layernorm" + + config["encoder"]['layernorm_type'] = "LayerNorm" + config["encoder"]['num_layers'] = "6" + # config["encoder"]['d_model'] /config["encoder"]["num_heads"] + config["encoder"]['d_kv'] = f"{int(768/12)}" + + config["decoder"]["num_heads"] = "12" + config["decoder"]['d_model'] = "768" #hidden_size + config["decoder"]['d_ffn'] = "3072" #ffn_hidden_size + config["decoder"]['vocab_size'] = "16384" # 8 * 2048 + config["decoder"]['n_positions'] = "2048" + config["decoder"]['has_position_embedding'] = "true" + config["decoder"]['layernorm_position'] = "pre_layernorm" + + config["decoder"]['layernorm_type'] = "LayerNorm" + config["decoder"]['num_layers'] = "12" + config["decoder"]["num_vocabs"] = "8" + + # manually set q_scaling to offset attention scaling's effect. 
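+    # Note: the TRT-LLM attention kernels scale QK^T by 1 / (q_scaling * sqrt(head_size)),
+    # so q_scaling = 1 / sqrt(head_size) cancels the built-in 1/sqrt(head_size) factor
+    # (this model, like T5, expects unscaled attention scores).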
+ # TODO: modify kernels to control whether to disable attention scaling + def get_offset_q_scaling(config): + scaling = 1 / config.head_size**.5 + return scaling + + config["structure"] = dict() + config["structure"]["t5_with_bias"] = "false" + #config["structure"]["use_gated_activation"] = str(hf_model.encoder.config.is_gated_act) + config["structure"]["position_embedding_type"] = "learned_absolute" + config["structure"]["model_type"] = "T5TTS" + + def parse_t5_config_by_component(config, component, args): + component_config = types.SimpleNamespace() + component_config = copy_args_to_component_config(component_config, args) + component_config.n_head = config.getint(component, 'num_heads') + component_config.hidden_size = config.getint(component, 'd_model') + component_config.head_size = component_config.hidden_size // component_config.n_head + + component_config.ffn_hidden_size = config.getint(component, 'd_ffn') + component_config.vocab_size = config.getint(component, 'vocab_size') + component_config.n_positions = config.getint(component, + 'n_positions', + fallback=2048) + + component_config.has_position_embedding = config.getboolean( + component, 'has_position_embedding', + fallback=False) # TODO: hardcoded here + + component_config.has_token_type_embedding = config.getboolean( + component, 'has_token_type_embedding', fallback=False) + component_config.has_embedding_layernorm = config.getboolean( + component, 'has_embedding_layernorm', fallback=False) + component_config.has_embedding_scale = config.getboolean( + component, 'has_embedding_scale', fallback=False) + component_config.q_scaling = get_offset_q_scaling(component_config) + component_config.has_attention_qkvo_bias = config.getboolean( + component, 'has_attention_qkvo_bias', + fallback=False) # TODO: hardcoded here + component_config.has_mlp_bias = config.getboolean(component, + 'has_mlp_bias', + fallback=False) + component_config.has_model_final_layernorm = config.getboolean( + component, 'has_model_final_layernorm', fallback=True) + component_config.layernorm_eps = config.getfloat(component, + 'layer_norm_epsilon', + fallback=1e-5) + component_config.layernorm_position = layernorm_position_map[config.get( + component, 'layernorm_position', + fallback='pre_layernorm')] # TODO: hardcoded here + component_config.layernorm_type = layernorm_type_map[config.get( + component, 'layernorm_type', fallback='RmsNorm')] + component_config.hidden_act = config.get(component, + 'dense_act_fn', + fallback="gelu") + component_config.gated_act = config.getboolean(component, + 'is_gated_act', + fallback=True) + #component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.gated_act else 'MLP'] + component_config.num_buckets = config.getint( + component, 'relative_attention_num_buckets', fallback=0) + component_config.max_distance = config.getint( + component, 'relative_attention_max_distance', fallback=0) + component_config.position_embedding_type = config.get( + 'structure', 'position_embedding_type') + component_config.logits_dtype = config.get(component, + 'logits_dtype', + fallback='float16') + + if component == 'encoder': + component_config.n_layer = config.getint(component, 'num_layers') + + component_config.relative_attention = config.get( + 'structure', 'position_embedding_type') == 'relative' + + elif component == 'decoder': + component_config.n_layer = config.getint(component, 'num_layers') + component_config.has_lm_head_bias = config.getboolean( + component, # TODO: T5 with bias + 'has_lm_head_bias', + fallback=True) + 
component_config.relative_attention = config.getboolean( + component, 'relative_attention', fallback=False) + component_config.rescale_before_lm_head = config.getboolean( + component, + 'tie_word_embeddings', + fallback=True, + ) # default is True (for T5), but False for Flan-T5 + component_config.encoder_hidden_size = config.getint( + 'encoder', 'd_model') + component_config.encoder_num_heads = config.getint( + 'encoder', 'num_heads') + component_config.encoder_head_size = config.getint( + 'encoder', 'd_kv') + #FIXME: check what is the correct generation process for the given checkpoint + component_config.decoder_start_token_id = config.getint( + 'decoder', 'decoder_start_token_id', fallback=106339 - 2) + component_config.eos_token_id = config.getint('decoder', + 'eos_token_id', + fallback=2048 - 1) + + bos_token_id = config.get('decoder', + 'bos_token_id', + fallback=2048 - 2) + # T5 does not have bos_token_id + component_config.bos_token_id = int( + bos_token_id) if bos_token_id != "None" else None + component_config.pad_token_id = config.getint('decoder', + 'pad_token_id', + fallback=0) + + vocab_size = config.getint('decoder', 'vocab_size') + num_vocabs = config.getint('decoder', 'num_vocabs') + component_config.vocab_sizes = [vocab_size // num_vocabs + ] * num_vocabs + + else: + assert False, 'Unsupported component!' + + return component_config + + encoder_config = parse_t5_config_by_component(config, "encoder", args) + decoder_config = parse_t5_config_by_component(config, "decoder", args) + + return encoder_config, decoder_config + + +def convert_t5tts_encoder( + config, + model_dict, + quant_algo: str = None, +): + weights = {} + weights['embedding.vocab_embedding.weight'] = model_dict[ + 'text_embedding.weight'].contiguous() + weights['embedding.position_embedding.weight'] = model_dict[ + 't5_encoder.position_embeddings.weight'].contiguous() + + num_layers = config.n_layer + for i in range(num_layers): + weights[f'encoder_layers.{i}.attention_layernorm.weight'] = model_dict[ + f't5_encoder.layers.{i}.norm_self.weight'].contiguous() + weights[f'encoder_layers.{i}.attention.qkv.weight'] = model_dict[ + f't5_encoder.layers.{i}.self_attention.qkv_net.weight'].contiguous( + ) + weights[f'encoder_layers.{i}.attention.dense.weight'] = model_dict[ + f't5_encoder.layers.{i}.self_attention.o_net.weight'].contiguous() + weights[f'encoder_layers.{i}.pos_ff_layernorm.weight'] = model_dict[ + f't5_encoder.layers.{i}.norm_pos_ff.weight'].contiguous() + weights[f'encoder_layers.{i}.pos_ff.proj.weight'] = model_dict[ + f't5_encoder.layers.{i}.pos_ff.proj.conv.weight'].unsqueeze( + 3).contiguous() + weights[f'encoder_layers.{i}.pos_ff.o_net.weight'] = model_dict[ + f't5_encoder.layers.{i}.pos_ff.o_net.conv.weight'].unsqueeze( + 3).contiguous() + + weights['final_layernorm.weight'] = model_dict[ + f't5_encoder.norm_out.weight'].contiguous() + + return weights + + +def convert_t5tts_decoder( + config, + model_dict, + quant_algo: str = None, +): + weights = {} + #weights['embedding.vocab_embedding.weight'] = model_dict['final_proj.weight'].clone().contiguous() + + weights['lm_head.weight'] = model_dict['final_proj.weight'].clone( + ).contiguous() + weights['lm_head.bias'] = model_dict['final_proj.bias'].clone().contiguous() + + weights['embedding.position_embedding.weight'] = model_dict[ + 't5_decoder.position_embeddings.weight'].contiguous() + + weights[f'embedding.vocab_embedding.weight'] = torch.cat( + [ + model_dict[f'audio_embeddings.{i}.weight'] + for i in range(len(config.vocab_sizes)) + ], + 
dim=0).contiguous() + + num_layers = config.n_layer + for i in range(num_layers): + weights[ + f'decoder_layers.{i}.self_attention_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_self.weight'].contiguous() + weights[f'decoder_layers.{i}.self_attention.qkv.weight'] = model_dict[ + f't5_decoder.layers.{i}.self_attention.qkv_net.weight'].contiguous( + ) + weights[f'decoder_layers.{i}.self_attention.dense.weight'] = model_dict[ + f't5_decoder.layers.{i}.self_attention.o_net.weight'].contiguous() + weights[ + f'decoder_layers.{i}.cross_attention_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_xattn_query.weight'].contiguous() + + t = torch.cat([ + model_dict[f't5_decoder.layers.{i}.cross_attention.q_net.weight'], + model_dict[f't5_decoder.layers.{i}.cross_attention.kv_net.weight'] + ], + dim=0).contiguous() + + weights[f'decoder_layers.{i}.cross_attention.qkv.weight'] = t + weights[f'decoder_layers.{i}.cross_attention.dense.weight'] = model_dict[ + f't5_decoder.layers.{i}.cross_attention.o_net.weight'].contiguous() + weights[f'decoder_layers.{i}.pos_ff_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_pos_ff.weight'].contiguous() + weights[ + f'decoder_layers.{i}.cross_attention_memory_layernorm.weight'] = model_dict[ + f't5_decoder.layers.{i}.norm_xattn_memory.weight'].contiguous() + weights[f'decoder_layers.{i}.pos_ff.proj.weight'] = model_dict[ + f't5_decoder.layers.{i}.pos_ff.proj.conv.weight'].unsqueeze( + 3).contiguous() + weights[f'decoder_layers.{i}.pos_ff.o_net.weight'] = model_dict[ + f't5_decoder.layers.{i}.pos_ff.o_net.conv.weight'].unsqueeze( + 3).contiguous() + + weights['final_layernorm.weight'] = model_dict[ + f't5_decoder.norm_out.weight'].contiguous() + + component_save_dir = os.path.join(args.output_dir, "decoder") + os.makedirs(component_save_dir, exist_ok=True) + return weights + + +def get_obj_dict(obj): + return obj.__dict__ + + +def convert_checkpoint(args, model): + + saved_dir = Path(args.output_dir) + saved_dir.mkdir(parents=True, exist_ok=True) + + encoder_saved_dir = saved_dir / "encoder" + encoder_saved_dir.mkdir(parents=True, exist_ok=True) + decoder_saved_dir = saved_dir / "decoder" + decoder_saved_dir.mkdir(parents=True, exist_ok=True) + + world_size = args.tp_size * args.pp_size + + kv_cache_quant_algo = None + quant_algo = None + + encoder_config, decoder_config = parse_model_config(args, ) + + additional_settings = ["gated_act"] + + tllm_encoder_config = { + 'architecture': "T5TTSEncoderModel", + 'dtype': args.dtype, + 'logits_dtype': encoder_config.logits_dtype, + 'num_hidden_layers': encoder_config.n_layer, + 'num_attention_heads': encoder_config.n_head, + 'hidden_size': encoder_config.hidden_size, + 'norm_epsilon': encoder_config.layernorm_eps, + 'vocab_size': encoder_config.vocab_size, + 'position_embedding_type': encoder_config.position_embedding_type, + 'hidden_act': encoder_config.hidden_act, + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'max_position_embeddings': encoder_config.n_positions, + 'num_key_value_heads': encoder_config.n_head, + 'head_size': encoder_config.head_size, + 'has_position_embedding': encoder_config.has_position_embedding, + 'layernorm_type': encoder_config.layernorm_type, + 'has_attention_qkvo_bias': 
encoder_config.has_attention_qkvo_bias, + 'has_mlp_bias': encoder_config.has_mlp_bias, + 'has_model_final_layernorm': encoder_config.has_model_final_layernorm, + 'has_embedding_layernorm': encoder_config.has_embedding_layernorm, + 'has_embedding_scale': encoder_config.has_embedding_scale, + 'intermediate_size': encoder_config.ffn_hidden_size, + 'q_scaling': encoder_config.q_scaling, + 'layernorm_position': encoder_config.layernorm_position, + 'relative_attention': encoder_config.relative_attention, + 'max_distance': encoder_config.max_distance, + 'num_buckets': encoder_config.num_buckets, + 'model_type': "t5tts" + } + + for additional_setting in additional_settings: + if hasattr(encoder_config, additional_setting): + tllm_encoder_config.update({ + additional_setting: + getattr(encoder_config, additional_setting) + }) + + tllm_decoder_config = { + 'architecture': "T5TTSDecoderModel", + 'dtype': args.dtype, + 'logits_dtype': decoder_config.logits_dtype, + 'num_hidden_layers': decoder_config.n_layer, + 'num_attention_heads': decoder_config.n_head, + 'hidden_size': decoder_config.hidden_size, + 'norm_epsilon': decoder_config.layernorm_eps, + 'vocab_size': decoder_config.vocab_size, + 'vocab_sizes': decoder_config.vocab_sizes, + 'position_embedding_type': decoder_config.position_embedding_type, + 'hidden_act': decoder_config.hidden_act, + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo, + }, + 'mapping': { + 'world_size': world_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + }, + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'max_position_embeddings': decoder_config.n_positions, + 'head_size': decoder_config.head_size, + 'has_position_embedding': decoder_config.has_position_embedding, + 'layernorm_type': decoder_config.layernorm_type, + 'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias, + 'has_mlp_bias': decoder_config.has_mlp_bias, + 'has_model_final_layernorm': decoder_config.has_model_final_layernorm, + 'has_embedding_layernorm': decoder_config.has_embedding_layernorm, + 'has_embedding_scale': decoder_config.has_embedding_scale, + 'intermediate_size': decoder_config.ffn_hidden_size, + 'q_scaling': decoder_config.q_scaling, + 'layernorm_position': decoder_config.layernorm_position, + 'relative_attention': decoder_config.relative_attention, + 'max_distance': decoder_config.max_distance, + 'num_buckets': decoder_config.num_buckets, + 'model_type': "t5tts", + 'rescale_before_lm_head': decoder_config.rescale_before_lm_head, + 'encoder_hidden_size': decoder_config.encoder_hidden_size, + 'encoder_num_heads': decoder_config.encoder_num_heads, + 'encoder_head_size': decoder_config.encoder_head_size, + 'skip_cross_kv': args.skip_cross_kv, + 'use_implicit_relative_attention': args.use_implicit_relative_attention, + 'decoder_start_token_id': decoder_config.decoder_start_token_id, + 'eos_token_id': decoder_config.eos_token_id, + 'bos_token_id': decoder_config.bos_token_id, + 'pad_token_id': decoder_config.pad_token_id, + 'cross_attention': True, # this has to be provided explicitly + } + for additional_setting in additional_settings: + if hasattr(decoder_config, additional_setting): + tllm_decoder_config.update({ + additional_setting: + getattr(decoder_config, additional_setting) + }) + + def convert_and_save(component: str = "encoder", ): + # call get_encoder_config or get_decoder_config according to component + if component == "encoder": + config = 
tllm_encoder_config + else: + config = tllm_decoder_config + + component_save_dir = os.path.join(args.output_dir, component) + if not os.path.exists(component_save_dir): + os.makedirs(component_save_dir) + + with open(os.path.join(component_save_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4, default=get_obj_dict) + + if args.use_weight_only and args.weight_only_precision == 'int4_gptq': + config['quantization'].update({ + 'has_zero_point': True, + }) + + quant_algo = None + """ + plugin_weight_only_quant_type = None + if args.use_weight_only and args.weight_only_precision == 'int8': + plugin_weight_only_quant_type = torch.int8 + quant_algo = QuantAlgo.W8A16 + elif args.use_weight_only and args.weight_only_precision == 'int4': + plugin_weight_only_quant_type = torch.quint4x2 + quant_algo = QuantAlgo.W4A16 + elif args.use_weight_only and args.weight_only_precision == 'int4_gptq': + quant_algo = QuantAlgo.W4A16_GPTQ + """ + + if component == "encoder": + + weights = convert_t5tts_encoder(encoder_config, + model_state_dict, + quant_algo=quant_algo) + else: + assert component == "decoder" + weights = convert_t5tts_decoder(decoder_config, + model_state_dict, + quant_algo=quant_algo) + + safetensors.torch.save_file( + weights, os.path.join(component_save_dir, f'rank0.safetensors')) + + convert_and_save(component="encoder") + convert_and_save(component="decoder") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--model_path', type=str, required=True) + + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + + parser.add_argument( + "--workers", + type=int, + help="How many workers to spawn for conversion (default: 4)", + default=4) + + parser.add_argument("--verbose", + action="store_true", + help="Provide verbose messages") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharding is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--dtype', + type=str, + default='float16', + choices=['float16', 'float32', 'bfloat16'], + help= + 'Target inference dtype. Weights and Computation will be in this dtype, no matter what original dtype the weight checkpoint has.' 
+ ) + parser.add_argument('--logits_dtype', + type=str, + default='float16', + choices=['float16', 'float32']) + parser.add_argument( + '--skip_cross_kv', + action='store_true', + help= + 'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).' + ) + parser.add_argument( + '--use_implicit_relative_attention', + action='store_true', + help= + 'Compute relative attention bias on the fly instead of pre-compute a relative attention bias table.' + ) + args = parser.parse_args() + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, + format=log_format) + LOGGER.info("\n=============== Argument ===============") + for key in vars(args): + LOGGER.info(f"{key}: {vars(args)[key]}") + LOGGER.info("========================================") + + start_time = datetime.now() + + model_metadata = {} + model_state_dict = torch.load(args.model_path, + weights_only=False)['state_dict'] + for k in model_state_dict: + model_state_dict[k] = model_state_dict[k].to( + dtype=TORCH_DTYPES[args.dtype]) + convert_checkpoint(args, model_state_dict) + + stop_time = datetime.now() + run_time = (stop_time - start_time) + LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time)) diff --git a/examples/t5tts/run_tts.py b/examples/t5tts/run_tts.py new file mode 100644 index 00000000000..d7a287e1ac3 --- /dev/null +++ b/examples/t5tts/run_tts.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
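+
+# Example offline t5tts inference with ModelRunnerCpp: load the encoder/decoder
+# engines from --engine_dir, feed a hard-coded phoneme sequence plus an audio
+# context prompt (context_codes_bos_scaled.pt), generate up to 1024 audio
+# tokens, fold the flat ids back into 8 codebooks (token % 2048), trim at the
+# first frame where all 8 codebooks emit EOS (2047), and save output_ids.pt.
+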
+import argparse +import json +import time +from collections import OrderedDict + +import torch + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch +from tensorrt_llm.runtime import ModelRunnerCpp +from tensorrt_llm.runtime.session import Session, TensorInfo + + +def read_config(component, engine_dir): + config_path = f"{engine_dir}/{component}/config.json" + with open(config_path, 'r') as f: + config = json.load(f) + model_config = OrderedDict() + model_config.update(config['pretrained_config']) + model_config.update(config['build_config']) + return model_config + + +class MagpieEncoding: + + def __init__(self, engine_dir): + self.session = self.get_session(engine_dir) + config = read_config('encoder', engine_dir) + self.dtype = config['dtype'] + self.encoder_config = config + + def get_session(self, engine_dir): + serialize_path = f"{engine_dir}/encoder/rank0.engine" + with open(serialize_path, 'rb') as f: + session = Session.from_serialized_engine(f.read()) + return session + + def get_encoder_feature(self): + text_encodings = torch.IntTensor([[ + 96, 40, 29, 26, 93, 90, 55, 74, 52, 93, 29, 26, 39, 93, 90, 52, 77, + 85, 58, 93, 90, 64, 66, 65, 93, 84, 61, 93, 28, 39, 26, 22, 40, 46, + 93, 90, 68, 77, 86, 93, 44, 22, 41, 26, 39, 93, 90, 78, 59, 93, 90, + 57, 84, 85, 97 + ]]) + encoder_input_lengths = torch.IntTensor([text_encodings.shape[1]]) + encoder_input_ids = remove_tensor_padding( + text_encodings, encoder_input_lengths).flatten() + output_list = [ + TensorInfo('input_ids', str_dtype_to_trt('int32'), + encoder_input_ids.shape), + TensorInfo('input_lengths', str_dtype_to_trt('int32'), + encoder_input_lengths.shape), + ] + inputs = OrderedDict() + inputs['input_ids'] = encoder_input_ids + inputs['input_lengths'] = encoder_input_lengths + output_info = (self.session).infer_shapes(output_list) + for k in inputs: + print(f"{k}: {inputs[k].shape}") + print(f"output_info: {output_info}") + logger.debug(f'output info {output_info}') + outputs = { + t.name: + torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + stream = torch.cuda.current_stream() + ok = self.session.run(inputs=inputs, + outputs=outputs, + stream=stream.cuda_stream) + assert ok, 'Engine execution failed' + stream.synchronize() + encoder_output = outputs['encoder_output'] + return encoder_output + + #inputs['position_ids'] = position_ids + return inputs + + +def remove_tensor_padding(input_tensor, + input_tensor_lengths=None, + pad_value=None): + if pad_value: + assert input_tensor_lengths is None, "input_tensor_lengths should be None when pad_value is provided" + # Text tensor case: batch, seq_len + assert torch.all( + input_tensor[:, 0] != + pad_value), "First token in each sequence should not be pad_value" + assert input_tensor_lengths is None + + # Create a mask for all non-pad tokens + mask = input_tensor != pad_value + + # Apply the mask to input_tensor to remove pad tokens + output_tensor = input_tensor[mask].view(1, -1) + + else: + # Audio tensor case: batch, seq_len, feature_len + # position_ids case: batch, seq_len + assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" + + # Initialize a list to collect valid sequences + valid_sequences = [] + + for i in range(input_tensor.shape[0]): + valid_length = input_tensor_lengths[i] + valid_sequences.append(input_tensor[i, :valid_length]) + + # Concatenate all valid sequences along the batch 
dimension + output_tensor = torch.cat(valid_sequences, dim=0) + return output_tensor + + +def evaluate(args): + # She had her dark suit in greasy wash water all year. + audio_context_num_tokens = 2048 + audio_num_codebooks = 8 + batch_size = 1 + + text_encodings = [ + 96, 40, 29, 26, 93, 90, 55, 74, 52, 93, 29, 26, 39, 93, 90, 52, 77, 85, + 58, 93, 90, 64, 66, 65, 93, 84, 61, 93, 28, 39, 26, 22, 40, 46, 93, 90, + 68, 77, 86, 93, 44, 22, 41, 26, 39, 93, 90, 78, 59, 93, 90, 57, 84, 85, + 97 + ] + text_encodings = torch.IntTensor(text_encodings) + + audio_context = torch.load('context_codes_bos_scaled.pt').flatten().cuda() + + eos_token_id = 2047 + #encoder = MagpieEncoding(args.engine_dir) + #encoder_output = encoder.get_encoder_feature() + #torch.save(encoder_output, "encoder_output.pt") + + runner_kwargs = dict( + engine_dir=args.engine_dir, + is_enc_dec=True, + max_input_len=1024, + cross_kv_cache_fraction=0.5, + rank=0, + ) + + tllm_model = ModelRunnerCpp.from_dir(**runner_kwargs) + + #inference_dtype = tllm_model.encoder_model_config.dtype + batch_input_ids = [audio_context] * batch_size + encoder_input_ids = [text_encodings] * batch_size + print(f"{audio_context.shape=}, {text_encodings.shape=}") + return_dict = False # when set return_dict=True, get outputs by key + tik = time.time() + return_dict = True + + tllm_output = tllm_model.generate( + batch_input_ids=batch_input_ids, + encoder_input_ids=encoder_input_ids, + max_new_tokens=1024, + bos_token_id=2046, + pad_token_id=0, + eos_token_id=eos_token_id, + streaming=False, + return_dict=return_dict, + ) + + torch.cuda.synchronize() + tok = time.time() + batch_size = len(batch_input_ids) + if return_dict: + tllm_output_ids = tllm_output['output_ids'] + else: + tllm_output_ids = tllm_output + tllm_output_ids = tllm_output_ids % audio_context_num_tokens + + print(f"{tllm_output_ids.shape=}") + + if tensorrt_llm.mpi_rank() == 0: + __output_ids__ = tllm_output_ids.reshape(tllm_output_ids.shape[0], -1, + audio_num_codebooks) + + output_ids_is_eos = torch.where(__output_ids__ == eos_token_id, 1, 0) + trim_output_idx = torch.argmin(torch.where( + torch.sum(output_ids_is_eos, dim=-1) == 8, 0, 1), + dim=1) + + output_ids = [ + __output_ids__[i, :trim_output_idx[i], :] for i in range(batch_size) + ] + + print("--------------------------------------") + print("TRT-LLM output_ids: ", output_ids) + print(f"TRT-LLM E2E time {(tok-tik)*1000}ms") + print("--------------------------------------") + torch.save(output_ids, "output_ids.pt") + return output_ids + + +def print_tensor(tensor_name, tensor, num_elements=10): + if tensor.dtype in (torch.int32, torch.int64): + tensor = tensor.to(dtype=float) + print( + f'{tensor_name}: mean={tensor.abs().mean().item():.3f}, sum={tensor.abs().sum().item():.3f}, max={tensor.abs().max().item():.3f}' + ) + # Pass num_elements=-1 will print the whole tensor + if num_elements < 0: + num_elements = torch.numel(tensor) + print(f'{tensor.flatten()[:num_elements]}') + print("Tensor Shape: ", tensor.size()) + print("") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--max_new_tokens", type=int, default=64) + parser.add_argument("--log_level", type=str, default="error") + parser.add_argument("--engine_dir", "-i", type=str, default="engines") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--model_name", + type=str, + help="HuggingFace model name or FairSeq model path", + default="t5tts") + parser.add_argument("--num_beams", + type=int, + help="Use 
beam search if num_beams >1", + default=1) + parser.add_argument("--debug_mode", + help="Whether or not to turn on the debug mode", + action='store_true') + parser.add_argument("--compare_hf_fp32", + help="Compare results with HuggingFace FP32", + action='store_true') + parser.add_argument('--lora_dir', type=str, default=None, nargs="+") + parser.add_argument('--lora_task_uids', type=str, default=None, nargs="+") + parser.add_argument("--output_npy", + type=str, + default=None, + help="Store input/output tensors C++ runtime testing") + return parser.parse_args() + + +if __name__ == "__main__": + import os + + os.environ["TOKENIZERS_PARALLELISM"] = "false" + args = parse_arguments() + logger.set_level(args.log_level) + evaluate(args) diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 831df55f803..f6886aea2af 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -830,7 +830,7 @@ def optimize_model_with_config(model: PretrainedModel, if build_config.plugin_config.lora_plugin is not None: model.use_lora(build_config.lora_config) - is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel"] + is_enc_dec = model.config.architecture in ["EncoderModel", "DecoderModel", "T5TTSEncoderModel", "T5TTSDecoderModel"] # FusedMLP does not support RecurrentGemma FP8 currently. is_recurrent_gemma = model.config.architecture in [ "RecurrentGemmaForCausalLM" @@ -1270,8 +1270,10 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: build_config.lora_config.lora_target_modules } - if model.config.architecture == "DecoderModel" or "mllama" in model.config.architecture.lower( - ): + is_mllama = "mllama" in model.config.architecture.lower() + is_decoder_model = model.config.architecture == "DecoderModel" + is_t5tts_model = model.config.architecture == "T5TTSDecoderModel" + if is_mllama or is_decoder_model or is_t5tts_model: prepare_input_args["max_seq_len"] = build_config.max_seq_len prepare_input_args[ "max_decoder_input_len"] = build_config.max_input_len diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index b97966823d7..dab5dcfe17d 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -62,6 +62,7 @@ from .recurrentgemma.model import RecurrentGemmaForCausalLM from .redrafter.model import ReDrafterForCausalLM from .stdit.model import STDiT3Model +from .t5tts.model import T5TTSDecoderModel, T5TTSEncoderModel __all__ = [ 'BertModel', @@ -133,6 +134,8 @@ 'SpeculativeDecodingMode', 'CohereForCausalLM', 'MLLaMAForCausalLM', + 'T5TTSEncoderModel', + 'T5TTSDecoderModel', ] MODEL_MAP = { @@ -216,4 +219,6 @@ 'RobertaModel': RobertaModel, 'RobertaForQuestionAnswering': RobertaForQuestionAnswering, 'RobertaForSequenceClassification': RobertaForSequenceClassification, + 'T5TTSEncoderModel': T5TTSEncoderModel, + 'T5TTSDecoderModel': T5TTSDecoderModel, } diff --git a/tensorrt_llm/models/t5tts/__init__.py b/tensorrt_llm/models/t5tts/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/models/t5tts/model.py b/tensorrt_llm/models/t5tts/model.py new file mode 100644 index 00000000000..902a2d4d2e6 --- /dev/null +++ b/tensorrt_llm/models/t5tts/model.py @@ -0,0 +1,1872 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections import OrderedDict +from typing import Optional + +import tensorrt as trt +import torch + +from tensorrt_llm._common import default_net +from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch +from tensorrt_llm.functional import (ACT2FN, LayerNormPositionType, + LayerNormType, MLPType, + PositionEmbeddingType, Tensor, assertion, + concat, gather_last_token_logits, maximum, + mean, minimum, recv, send, shape, squeeze, + unsqueeze, view) +from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, + AttentionMaskType, AttentionParams, + BertAttention, ColumnLinear, Conv1d, Embedding, + FusedGatedMLP, GatedMLP, GroupNorm, + KeyValueCacheParams, LayerNorm, + PromptTuningEmbedding, RmsNorm) +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel +from tensorrt_llm.module import Module, ModuleList +from tensorrt_llm.parameter import Parameter +from tensorrt_llm.plugin.plugin import current_all_reduce_helper + +layernorm_map = { + LayerNormType.LayerNorm: LayerNorm, + LayerNormType.RmsNorm: RmsNorm, + LayerNormType.GroupNorm: GroupNorm, +} + +mlp_map = { + MLPType.MLP: MLP, + MLPType.GatedMLP: GatedMLP, + MLPType.FusedGatedMLP: FusedGatedMLP, +} + + +class PositionwiseConvFF(Module): + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + kernel_size: int = 3, + has_bias: bool = False, + is_causal: bool = True, + hidden_act: str = 'gelu', + padding: Optional[int] = None, + dilation: int = 1, + dtype=None, + groups: int = 1, + ): + super().__init__() + + self.is_causal = is_causal + self.hidden_size = hidden_size + self.pos_ffn_hidden_size = ffn_hidden_size + + self.hidden_act = ACT2FN[hidden_act] + + if self.is_causal: + self.causal_padding = ((kernel_size - 1) * dilation, 0) + + padding = 0 + + elif padding is None: + if kernel_size % 2 == 0: + raise ValueError( + "`kernel_size` must be odd when `padding` is None.") + + padding = int(dilation * (kernel_size - 1) / 2) + + self.proj = Conv1d(hidden_size, + ffn_hidden_size, + kernel_size=kernel_size, + padding=padding, + bias=has_bias, + dilation=dilation, + dtype=dtype) + self.o_net = Conv1d(ffn_hidden_size, + hidden_size, + kernel_size=kernel_size, + padding=padding, + bias=has_bias, + dilation=dilation, + dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + dim_sqz = x.ndim() + + if self.is_causal: + # currently for T5TTS causal padding is only required for decoder + # but is 0 because (kernel_size-1)=0 for decoder, padding=(kernel_size - 1) * dilation + pass + x = unsqueeze(x, dim_sqz) + + x = self.proj(x) + x = self.o_net(self.hidden_act(x)) + x = squeeze(x, dim_sqz) + return x + + +class PositionalEmbedding(Module): + + def __init__(self, + max_position_embeddings, + hidden_size, + has_embedding_layernorm=False, + has_embedding_scale=False, + layernorm_eps=1e-5, + layernorm_type=LayerNormType.LayerNorm, + dtype=None, + use_parallel_embedding=False, + embedding_sharding_dim=0, + mapping=Mapping()): + super().__init__() + + self.layernorm_type = layernorm_type + ln_type = 
layernorm_map[layernorm_type] + + self.max_position_embeddings = max_position_embeddings + self.position_embedding = None + self.position_embedding = Embedding( + max_position_embeddings, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.embedding_layernorm = None + if has_embedding_layernorm: + self.embedding_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype) + + self.embedding_scale = 1.0 + if has_embedding_scale: + self.embedding_scale = math.sqrt(hidden_size) + + def forward(self, + input_ids, + position_ids=None, + prompt_tasks=None, + prompt_vocab_size=None): + pos_emb = self.position_embedding(position_ids) + x = input_ids + pos_emb + if self.embedding_layernorm: + x = self.embedding_layernorm(x) + return x + + +class EncoderDecoderEmbedding(Module): + + def __init__(self, + vocab_size, + num_vocabs, + hidden_size, + max_position_embeddings=None, + has_position_embedding=False, + type_vocab_size=None, + has_embedding_layernorm=False, + has_embedding_scale=False, + layernorm_eps=1e-5, + layernorm_type=LayerNormType.LayerNorm, + dtype=None, + use_parallel_embedding=False, + embedding_sharding_dim=0, + mapping=Mapping()): + super().__init__() + + self.num_vocabs = num_vocabs + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + self.vocab_embedding = Embedding( + vocab_size, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.position_embedding = None + self.max_position_embeddings = max_position_embeddings + if has_position_embedding: + self.position_embedding = Embedding( + max_position_embeddings, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + self.token_type_embedding = None + if type_vocab_size: + self.token_type_embedding = Embedding( + type_vocab_size, + hidden_size, + dtype=dtype, + tp_size=mapping.tp_size if use_parallel_embedding else 1, + tp_group=mapping.tp_group if use_parallel_embedding else None, + sharding_dim=embedding_sharding_dim, + tp_rank=mapping.tp_rank) + + # e.g. BART true, T5 false + self.embedding_layernorm = None + if has_embedding_layernorm: + self.embedding_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype) + + # e.g. BART true, T5 false + self.embedding_scale = 1.0 + if has_embedding_scale: + self.embedding_scale = math.sqrt(hidden_size) + + # Note: embedding offset in BART is not considered as a standard. 
For the specific case, + # we just need to shrink its position embedding table by [offset:] during weight loading + + def forward(self, + input_ids, + position_ids=None, + token_type_ids=None, + prompt_embedding_table=None, + prompt_tasks=None, + prompt_vocab_size=None): + # position_ids and token_type_ids are provided inputs + # and should not be formulated deterministically + + args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + + x = self.vocab_embedding(input_ids, *args) * self.embedding_scale + if self.num_vocabs > 1: + x = view(x, + concat( + [shape(x, 0) / self.num_vocabs, self.num_vocabs, + -1])) # shape [totalSeqLen, nVocab, embDim] + # average across vocabs + x = mean(x, 1) # shape [totalSeqLen, embDim] + + if self.position_embedding: + pos_emb = self.position_embedding(position_ids) + x = x + pos_emb + if self.token_type_embedding: + x = x + self.token_type_embedding(token_type_ids) + + if self.embedding_layernorm: + x = self.embedding_layernorm(x) + + return x + + +class T5TTSEncoderLayer(Module): + + def __init__(self, + hidden_size, + ffn_hidden_size, + num_attention_heads, + num_kv_heads, + head_size, + max_position_embeddings=None, + q_scaling=1.0, + has_attention_qkvo_bias=False, + has_pos_ff_bias=False, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + layernorm_eps=1e-5, + hidden_act="gelu", + mapping=Mapping(), + dtype=None, + residual_scaling=1.0, + relative_attention=False, + max_distance=0, + num_buckets=0, + fp16_clamping=False, + conv_is_causal=False): + super().__init__() + + # e.g. BART regular, T5 RMS + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + # e.g. BART post, T5 pre + self.layernorm_position = layernorm_position + + # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.attention = BertAttention( + hidden_size, + num_attention_heads, + attention_head_size=head_size, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + q_scaling=q_scaling, + bias=has_attention_qkvo_bias, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + relative_attention=relative_attention, + max_distance=max_distance, + num_buckets=num_buckets) + + self.attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.pos_ff = PositionwiseConvFF( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + hidden_act=hidden_act, + has_bias=has_pos_ff_bias, + kernel_size=3, + padding=1, + groups=mapping.tp_group, + dtype=dtype, + is_causal=conv_is_causal, + ) + + self.pos_ff_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.residual_scaling = residual_scaling + + # T5-series model(e.g. t5-large, t5-3b, flan-t5-small) has accuracy issue due to fp16 overflow + # after residual add. We add workaround for clamping fp16 range [-64000, 64000] after every + # residual add to avoid accuracy drop. 
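+        # For T5TTS this flag comes from the model constructor
+        # (dtype == 'float16' and model_type == 't5'); the clamp itself is
+        # applied in forward() right after each residual add.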
+ self.fp16_clamping = fp16_clamping + + def forward(self, + hidden_states: Tensor, + attention_mask=None, + input_lengths=None, + max_input_length=None): + assert isinstance(hidden_states, Tensor) + + # self attention + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.attention_layernorm(hidden_states) + + attention_output = self.attention(hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths, + max_input_length=max_input_length) + + self.register_network_output('attention_output', attention_output) + + hidden_states = residual + attention_output + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.attention_layernorm(hidden_states) + + # MLP + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.pos_ff_layernorm(hidden_states) + + hidden_states = self.pos_ff(hidden_states) + + self.register_network_output('pos_ff_output', hidden_states) + + hidden_states = residual + hidden_states + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.pos_ff_layernorm(hidden_states) + + return hidden_states + + +class T5TTSDecoderLayer(Module): + + def __init__(self, + *, + local_layer_idx, + hidden_size, + ffn_hidden_size, + num_attention_heads, + num_kv_heads, + head_size, + max_position_embeddings=None, + q_scaling=1.0, + has_attention_qkvo_bias=False, + has_pos_ff_bias=False, + has_encoder_input_layernorm=False, + layernorm_position=LayerNormPositionType.pre_layernorm, + layernorm_type=LayerNormType.LayerNorm, + layernorm_eps=1e-5, + hidden_act="gelu", + mapping=Mapping(), + dtype=None, + residual_scaling=1.0, + relative_attention=False, + max_distance=0, + num_buckets=0, + fp16_clamping=False, + skip_cross_kv=False, + use_implicit_relative_attention=False): + super().__init__() + + self.has_encoder_input_layernorm = has_encoder_input_layernorm + + # e.g. BART regular, T5 RMS + self.layernorm_type = layernorm_type + ln_type = layernorm_map[layernorm_type] + + # e.g. BART post, T5 pre + self.layernorm_position = layernorm_position + + # e.g. 
BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.self_attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_size=head_size, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + q_scaling=q_scaling, + bias=has_attention_qkvo_bias, + attention_mask_type=AttentionMaskType.causal, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + cross_attention=False, + relative_attention=relative_attention, + max_distance=max_distance if use_implicit_relative_attention else 0, + num_buckets=num_buckets, + position_embedding_type=PositionEmbeddingType.relative + if relative_attention else PositionEmbeddingType.learned_absolute, + use_implicit_relative_attention=use_implicit_relative_attention) + + self.self_attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + # Note: self attn uses MMHA, mask is always causal triangular + # cross attn has two scenarios: + # - in context phase, all ones mask, same as padding type + # - in generation phase, same causal triangular mask as MMHA + # - context phase special handling is done in plugin by resetting mask type + # + # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) + self.cross_attention = Attention( + local_layer_idx=local_layer_idx, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_size=head_size, + num_kv_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + q_scaling=q_scaling, + bias=has_attention_qkvo_bias, + attention_mask_type=AttentionMaskType.causal, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + tp_rank=mapping.tp_rank, + dtype=dtype, + cross_attention=True, + relative_attention= + False, # Cross attention has no relative attention bias + max_distance=max_distance, + num_buckets=num_buckets, + position_embedding_type=PositionEmbeddingType.learned_absolute, + skip_cross_kv=skip_cross_kv) + + self.cache_cross_attention_memory = None + if has_encoder_input_layernorm: + self.cross_attention_memory_layernorm = ln_type( + normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.cross_attention_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.pos_ff = PositionwiseConvFF( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + kernel_size=1, + padding=0, + hidden_act=hidden_act, + has_bias=has_pos_ff_bias, + groups=mapping.tp_group, + dtype=dtype, + is_causal=True, + ) + + self.pos_ff_layernorm = ln_type(normalized_shape=hidden_size, + eps=layernorm_eps, + dtype=dtype, + bias=False) + + self.residual_scaling = residual_scaling + + # T5-series model(e.g. t5-large, t5-3b, flan-t5-small) has accuracy issue due to fp16 overflow + # after residual add. We add workaround for clamping fp16 range [-64000, 64000] after every + # residual add to avoid accuracy drop. 
+ self.fp16_clamping = fp16_clamping + + def forward(self, + hidden_states: Tensor, + encoder_output: Optional[Tensor] = None, + attention_mask_params=None, + use_cache=False, + kv_cache_params=None, + attention_params=None, + cross_kv_cache_gen: Optional[Tensor] = None, + cross_kv_reuse: Optional[Tensor] = None): + assert isinstance(hidden_states, Tensor) + + if encoder_output: + assert isinstance(encoder_output, Tensor) + + # self-attention + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.self_attention_layernorm(hidden_states) + + attention_output = self.self_attention( + hidden_states=hidden_states, + attention_mask=attention_mask_params.self_attention_mask, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params) + + if use_cache: + attention_output, presents_self = attention_output + + self.register_network_output('self_attention_output', attention_output) + + hidden_states = residual + attention_output + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.self_attention_layernorm(hidden_states) + + # cross attention + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.cross_attention_layernorm(hidden_states) + + attention_output = self.cross_attention( + hidden_states=hidden_states, + attention_mask=attention_mask_params.cross_attention_mask, + attention_packed_mask=attention_mask_params. + cross_attention_packed_mask, + encoder_output=encoder_output, + use_cache=use_cache, + kv_cache_params=kv_cache_params, + attention_params=attention_params, + cross_kv_cache_gen=cross_kv_cache_gen, + cross_kv_reuse=cross_kv_reuse) + + if use_cache: + attention_output, presents_cross = attention_output + + self.register_network_output('cross_attention_output', attention_output) + + hidden_states = residual + attention_output + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.cross_attention_layernorm(hidden_states) + + # MLP + residual = hidden_states * self.residual_scaling + + if self.layernorm_position == LayerNormPositionType.pre_layernorm: + hidden_states = self.pos_ff_layernorm(hidden_states) + + hidden_states = self.pos_ff(hidden_states) + self.register_network_output('pos_ff_output', hidden_states) + + hidden_states = residual + hidden_states + + if self.fp16_clamping: + hidden_states = maximum(-64000.0, hidden_states) + hidden_states = minimum(64000.0, hidden_states) + + if self.layernorm_position == LayerNormPositionType.post_layernorm: + hidden_states = self.mlp_layernorm(hidden_states) + + if use_cache: + return (hidden_states, presents_self, presents_cross) + return hidden_states + + +class T5TTSEncoderModel(PretrainedModel): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + super().__init__(config) + self.mapping = self.config.mapping + + self.has_position_embedding = self.config.has_position_embedding + type_vocab_size = self.config.type_vocab_size + self.has_token_type_embedding = False if type_vocab_size is None else True + + # e.g. 
BART regular, T5 RMS + self.layernorm_type = self.config.layernorm_type + ln_type = layernorm_map[self.layernorm_type] + + # e.g. BART true, T5 false + self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias + self.has_pos_ff_bias = self.config.has_pos_ff_bias + + # e.g. BART false, T5 true + self.has_model_final_layernorm = self.config.has_model_final_layernorm + + self._dtype = self.config.dtype + + self.total_num_layers = self.config.num_hidden_layers + self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + num_kv_heads = self.num_heads + if num_kv_heads is None or num_kv_heads <= 0: + num_kv_heads = self.config.num_attention_heads + self.num_kv_heads = num_kv_heads + self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size + + self.fp16_clamping = (self.config.dtype + == 'float16') and (self.config.model_type == 't5') + self.mlp_type = MLPType.MLP if not hasattr( + self.config, "mlp_type") else self.config.mlp_type + + if self.mapping.is_first_pp_rank(): + self.embedding = EncoderDecoderEmbedding( + self.config.vocab_size, + 1, # number of vocabs + self.config.hidden_size, + max_position_embeddings=self.config.max_position_embeddings, + has_position_embedding=self.has_position_embedding, + type_vocab_size=type_vocab_size, + has_embedding_layernorm=self.config.has_embedding_layernorm, + has_embedding_scale=self.config.has_embedding_scale, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + dtype=self.config.dtype, + use_parallel_embedding=self.config.use_parallel_embedding, + embedding_sharding_dim=self.config.embedding_sharding_dim, + mapping=self.mapping) + + self.encoder_layers = ModuleList([ + T5TTSEncoderLayer( + hidden_size=self.hidden_size, + ffn_hidden_size=self.config.intermediate_size, + num_attention_heads=self.num_heads, + num_kv_heads=num_kv_heads, + head_size=self.head_size, + max_position_embeddings=self.config.max_position_embeddings, + q_scaling=self.config.q_scaling, + has_attention_qkvo_bias=self.has_attention_qkvo_bias, + has_pos_ff_bias=self.has_pos_ff_bias, + layernorm_position=self.config.layernorm_position, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + hidden_act=self.config.hidden_act, + mapping=self.mapping, + dtype=self.config.dtype, + residual_scaling=1.0 + if not hasattr(self.config, "residual_scaling") else + self.config.residual_scaling, + relative_attention=self.config.relative_attention, + max_distance=self.config.max_distance, + num_buckets=self.config.num_buckets, + fp16_clamping=self.fp16_clamping) + for _ in self.mapping.pp_layers(self.total_num_layers) + ]) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + self.final_layernorm = ln_type( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + bias=self.config.has_final_layernorm_bias) + + def check_config(self, config: PretrainedConfig): + config.set_if_not_exist('has_position_embedding', False) + config.set_if_not_exist('type_vocab_size', None) + config.set_if_not_exist('rescale_before_lm_head', False) + config.set_if_not_exist('layernorm_type', LayerNormType.LayerNorm) + config.set_if_not_exist('layernorm_position', + LayerNormPositionType.pre_layernorm) + config.set_if_not_exist('has_attention_qkvo_bias', False) + config.set_if_not_exist('has_pos_ff_bias', False) + 
config.set_if_not_exist('has_model_final_layernorm', False) + config.set_if_not_exist('encoder_hidden_size', None) + config.set_if_not_exist('encoder_num_heads', None) + config.set_if_not_exist('encoder_num_kv_heads', None) + config.set_if_not_exist('encoder_head_size', None) + config.set_if_not_exist('model_type', 't5') + config.set_if_not_exist('skip_cross_kv', False) + config.set_if_not_exist('has_embedding_scale', False) + config.set_if_not_exist('residual_scaling', 1.0) + config.set_if_not_exist('has_lm_head_bias', False) + config.set_if_not_exist('has_final_layernorm_bias', False) + config.set_if_not_exist('num_buckets', None) + config.set_if_not_exist('max_distance', None) + config.set_if_not_exist('relative_attention', False) + config.set_if_not_exist('residual_scaling', 1.0) + + def forward(self, + input_ids: Tensor, + input_lengths=None, + position_ids=None, + token_type_ids=None, + hidden_states=None, + max_input_length=None, + prompt_embedding_table=None, + prompt_tasks=None, + prompt_vocab_size=None, + attention_mask=None): + + # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs + if self.mapping.is_first_pp_rank(): + ptuning_args = [ + prompt_embedding_table, prompt_tasks, prompt_vocab_size + ] if prompt_embedding_table is not None else [] + + hidden_states = self.embedding(input_ids, position_ids, + token_type_ids, *ptuning_args) + self.register_network_output('embedding_layer_output', + hidden_states) + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + for layer_idx, encoder_layer in enumerate(self.encoder_layers): + + hidden_states = encoder_layer(hidden_states=hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths, + max_input_length=max_input_length) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + hidden_states = self.final_layernorm(hidden_states) + hidden_states.mark_output('encoder_output', self._dtype) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + hidden_states.mark_output('hidden_states_output', self._dtype) + self.register_network_output('hidden_states_output', hidden_states) + + return hidden_states + + def prepare_inputs(self, + max_batch_size, + max_input_len, + prompt_embedding_table_size: int = 0, + *args, + **kwargs): + '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the + ranges of the dimensions of when using TRT dynamic shapes. 
+ + @return: a list contains values which can be fed into the self.forward() + ''' + + hidden_size = self.hidden_size + + bs_range = [1, (max_batch_size + 1) // 2, max_batch_size] + inlen_range = [1, (max_input_len + 1) // 2, max_input_len] + num_tokens_range = [ + 1, + (max_input_len * max_batch_size + 1) // 2, + max_input_len * max_batch_size, + ] + + input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None + remove_input_padding = default_net().plugin_config.remove_input_padding + + attention_mask = None + if remove_input_padding: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor( + name="input_ids", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("num_tokens", [num_tokens_range])]), + ) + if self.has_position_embedding: + position_ids = Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', + [num_tokens_range])]), + ) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', + [num_tokens_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, hidden_size], + dim_range=OrderedDict([ + ('num_tokens', [num_tokens_range]), + ('hidden_size', [hidden_size]), + ])) + else: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor( + name="input_ids", + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([("batch_size", [bs_range]), + ("input_len", [inlen_range])]), + ) + if self.has_position_embedding: + position_ids = Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, -1, hidden_size], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]), + ('hidden_size', [hidden_size]), + ])) + + if not default_net().plugin_config.bert_attention_plugin: + attention_mask = Tensor( + name='attention_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]), + ]), + ) + + # if self.mapping.tp_size > 1: + # current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) + # FIXME(TRTLLM-996): Support custom allreduce for encoder models on C++ runtime + + input_lengths = Tensor( + name="input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size", [bs_range])]), + ) + max_input_length = Tensor( + name="max_input_length", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("max_input_length", [inlen_range])]), + ) + + prompt_embedding_table = None + tasks = None + prompt_vocab_size = None + + if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0: + p_embedding_range = [[ + 1, prompt_embedding_table_size // 2, prompt_embedding_table_size + ]] + + prompt_embedding_table = Tensor(name='prompt_embedding_table', + dtype=self._dtype, + shape=[-1, hidden_size], + dim_range=OrderedDict([ + ('prompt_embedding_table_size', + p_embedding_range), + ('hidden_size', [hidden_size]), + ])) + if remove_input_padding: + tasks = Tensor(name='tasks', + dtype=trt.int32, + shape=[-1], + 
dim_range=OrderedDict([('input_len_task', + [num_tokens_range])])) + else: + tasks = Tensor(name='tasks', + dtype=trt.int32, + shape=[-1, 1], + dim_range=OrderedDict([ + ('batch_size', bs_range), + ('broadcast_dim', [1]), + ])) + prompt_vocab_size = Tensor(name='prompt_vocab_size', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([('size', [1])])) + + result = { + 'input_ids': input_ids, + 'input_lengths': input_lengths, + 'position_ids': position_ids, + 'token_type_ids': token_type_ids, + 'hidden_states': hidden_states, + 'max_input_length': max_input_length, + 'prompt_embedding_table': prompt_embedding_table, + 'prompt_tasks': tasks, + 'prompt_vocab_size': prompt_vocab_size, + 'attention_mask': attention_mask, + } + + return result + + def use_prompt_tuning(self): + embedding = self.embedding.vocab_embedding + self.embedding.vocab_embedding = PromptTuningEmbedding( + num_embeddings=embedding.num_embeddings, + embedding_dim=embedding.embedding_dim, + dtype=embedding.dtype, + tp_size=embedding.tp_size, + tp_group=embedding.tp_group, + sharding_dim=embedding.sharding_dim, + tp_rank=embedding.tp_rank) + + self.embedding.vocab_embedding.weight.value = embedding.weight.raw_value + + def precompute_relative_attention_bias(self, build_config): + pass + + +class T5TTSDecoderModel(PretrainedModel): + + def __init__(self, config: PretrainedConfig): + self.check_config(config) + super().__init__(config) + + self.mapping = self.config.mapping + self.num_vocabs = len(self.config.vocab_sizes) + + self.has_position_embedding = self.config.has_position_embedding + type_vocab_size = self.config.type_vocab_size + self.has_token_type_embedding = (type_vocab_size is not None) + self.rescale_before_lm_head = self.config.rescale_before_lm_head + + # e.g. BART regular, T5 RMS + self.layernorm_type = self.config.layernorm_type + ln_type = layernorm_map[self.layernorm_type] + + # e.g. BART true, T5 false + self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias + self.has_pos_ff_bias = self.config.has_pos_ff_bias + self.has_encoder_input_layernorm = self.config.has_encoder_input_layernorm + + # e.g. 
BART false, T5 true + self.has_model_final_layernorm = self.config.has_model_final_layernorm + self._dtype = self.config.dtype + # no quantization considered for now + self._kv_dtype = self._dtype + self._logits_dtype = self.config.logits_dtype + + self.total_num_layers = self.config.num_hidden_layers + self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + + num_kv_heads = self.num_heads + if num_kv_heads is None or num_kv_heads <= 0: + num_kv_heads = self.num_heads + self.num_kv_heads = num_kv_heads + self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size + + self.encoder_hidden_size = self.config.encoder_hidden_size + self.encoder_num_heads = self.config.encoder_num_heads + encoder_num_kv_heads = None if not hasattr( + self.config, + "encoder_num_kv_heads") else self.config.encoder_num_kv_heads + if encoder_num_kv_heads is None or encoder_num_kv_heads <= 0: + encoder_num_kv_heads = self.encoder_num_heads + self.encoder_num_kv_heads = encoder_num_kv_heads + self.encoder_head_size = self.encoder_hidden_size // self.num_heads if self.config.encoder_head_size is None else self.config.encoder_head_size + + self.has_position_embedding = self.config.has_position_embedding + self.has_token_type_embedding = type_vocab_size is not None + + self.fp16_clamping = (self.config.dtype + == 'float16') and (self.config.model_type + in ['t5', 'pix2struct']) + + self.skip_cross_kv = self.config.skip_cross_kv + self.mlp_type = MLPType.MLP if not hasattr( + self.config, "mlp_type") else self.config.mlp_type + self.use_implicit_relative_attention = self.config.use_implicit_relative_attention if hasattr( + self.config, "use_implicit_relative_attention") else False + + if self.mapping.is_first_pp_rank(): + self.embedding = EncoderDecoderEmbedding( + self.config.vocab_size, + self.num_vocabs, + self.config.hidden_size, + max_position_embeddings=self.config.max_position_embeddings, + has_position_embedding=self.has_position_embedding, + type_vocab_size=type_vocab_size, + has_embedding_layernorm=self.config.has_embedding_layernorm, + has_embedding_scale=self.config.has_embedding_scale, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.layernorm_type, + dtype=self.config.dtype, + use_parallel_embedding=self.config.use_parallel_embedding, + embedding_sharding_dim=self.config.embedding_sharding_dim, + mapping=self.mapping) + + layers_range = self.mapping.pp_layers(self.total_num_layers) + self.decoder_layers = ModuleList([ + T5TTSDecoderLayer( + local_layer_idx=layer_idx - layers_range[0], + hidden_size=self.config.hidden_size, + ffn_hidden_size=self.config.intermediate_size, + num_attention_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + max_position_embeddings=self.config.max_position_embeddings, + q_scaling=self.config.q_scaling, + has_attention_qkvo_bias=self.config.has_attention_qkvo_bias, + has_pos_ff_bias=self.config.has_pos_ff_bias, + has_encoder_input_layernorm=self.config. 
+ has_encoder_input_layernorm, + layernorm_position=self.config.layernorm_position, + layernorm_eps=self.config.norm_epsilon, + layernorm_type=self.config.layernorm_type, + hidden_act=self.config.hidden_act, + mapping=self.mapping, + dtype=self._dtype, + residual_scaling=self.config.residual_scaling, + relative_attention=self.config.relative_attention, + max_distance=self.config.max_distance, + num_buckets=self.config.num_buckets, + fp16_clamping=self.fp16_clamping, + skip_cross_kv=self.skip_cross_kv, + use_implicit_relative_attention=self. + use_implicit_relative_attention) for layer_idx in layers_range + ]) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + self.final_layernorm = ln_type( + normalized_shape=self.config.hidden_size, + eps=self.config.norm_epsilon, + dtype=self.config.dtype, + bias=self.config.has_final_layernorm_bias) + + self.lm_head = ColumnLinear( + self.config.hidden_size, + self.config.vocab_size, + bias=False if not hasattr(self.config, "has_lm_head_bias") else + self.config.has_lm_head_bias, + dtype=self.config.dtype, + tp_group=self.config.mapping.tp_group, + tp_size=self.config.mapping.tp_size, + gather_output=True, + ) + + if self.config.relative_attention and not self.use_implicit_relative_attention: + self.rel_attn_table = Parameter( + shape=(self.config.num_attention_heads // self.mapping.tp_size, + self.config.num_buckets), + dtype=self._dtype) + + def check_config(self, config: PretrainedConfig): + config.set_if_not_exist('has_position_embedding', False) + config.set_if_not_exist('type_vocab_size', None) + config.set_if_not_exist('rescale_before_lm_head', False) + config.set_if_not_exist('layernorm_type', LayerNormType.LayerNorm) + config.set_if_not_exist('layernorm_position', + LayerNormPositionType.pre_layernorm) + config.set_if_not_exist('has_attention_qkvo_bias', False) + config.set_if_not_exist('has_pos_ff_bias', False) + config.set_if_not_exist('has_encoder_input_layernorm', True) + config.set_if_not_exist('has_model_final_layernorm', False) + config.set_if_not_exist('audio_embedding_dim', 768) + config.set_if_not_exist('has_final_layernorm_bias', False) + config.set_if_not_exist('encoder_hidden_size', None) + config.set_if_not_exist('encoder_num_heads', None) + config.set_if_not_exist('encoder_num_kv_heads', None) + config.set_if_not_exist('encoder_head_size', None) + config.set_if_not_exist('model_type', 't5') + config.set_if_not_exist('skip_cross_kv', False) + config.set_if_not_exist('has_embedding_scale', False) + config.set_if_not_exist('residual_scaling', 1.0) + config.set_if_not_exist('has_lm_head_bias', False) + config.set_if_not_exist('num_buckets', None) + config.set_if_not_exist('max_distance', None) + config.set_if_not_exist('relative_attention', False) + config.set_if_not_exist('residual_scaling', 1.0) + + def forward(self, + decoder_input_ids: Tensor, + encoder_output: Tensor, + position_ids=None, + token_type_ids=None, + use_cache=False, + attention_mask_params=None, + last_token_ids=None, + kv_cache_params=None, + attention_params=None, + hidden_states=None, + cross_kv_cache_gen: Optional[Tensor] = None, + cross_kv_reuse: Optional[Tensor] = None): + if self.mapping.is_first_pp_rank(): + assert isinstance(decoder_input_ids, Tensor) + else: + assert isinstance(hidden_states, Tensor) + + # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs + if self.mapping.is_first_pp_rank(): + hidden_states = self.embedding(decoder_input_ids, position_ids, + None) + 
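+            # With num_vocabs > 1 the embedding averages the per-codebook
+            # embeddings of each audio frame into a single hidden state
+            # (see EncoderDecoderEmbedding.forward).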
self.register_network_output('embedding_layer_output', + hidden_states) + else: + hidden_states = recv(hidden_states, self.mapping.prev_pp_rank()) + + kv_cache_params.fill_none_tensor_list(len(self.decoder_layers)) + + if use_cache: + presents = [] + + for i, (decoder_layer, past) in enumerate( + zip(self.decoder_layers, kv_cache_params.past_key_value)): + + hidden_states = decoder_layer( + hidden_states, + encoder_output=encoder_output, + attention_mask_params=attention_mask_params, + use_cache=use_cache, + kv_cache_params=KeyValueCacheParams( + past_key_value=past, + host_past_key_value_lengths=kv_cache_params. + host_past_key_value_lengths, + host_max_attention_window_sizes=kv_cache_params. + host_max_attention_window_sizes, + host_sink_token_length=kv_cache_params. + host_sink_token_length, + cache_indirection=kv_cache_params.cache_indirection, + kv_cache_block_offsets=kv_cache_params. + kv_cache_block_offsets, + host_kv_cache_block_offsets=kv_cache_params. + host_cross_kv_cache_block_offsets, + host_kv_cache_pool_pointers=kv_cache_params. + host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=kv_cache_params. + host_kv_cache_pool_mapping, + cross_kv_cache_block_offsets=kv_cache_params. + cross_kv_cache_block_offsets, + host_cross_kv_cache_block_offsets=kv_cache_params. + host_cross_kv_cache_block_offsets, + host_cross_kv_cache_pool_pointers=kv_cache_params. + host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping=kv_cache_params. + host_cross_kv_cache_pool_mapping), + attention_params=attention_params, + cross_kv_cache_gen=cross_kv_cache_gen, + cross_kv_reuse=cross_kv_reuse) + + if use_cache: + presents_self, presents_cross = hidden_states[1], hidden_states[ + 2] + presents.append((presents_self, presents_cross)) + hidden_states = hidden_states[0] + self.register_network_output(f'decoder_layer_{i}_output', + hidden_states) + + if self.mapping.is_last_pp_rank(): + if self.has_model_final_layernorm: + hidden_states = self.final_layernorm(hidden_states) + + # [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size] + hidden_states = gather_last_token_logits( + hidden_states, last_token_ids, + default_net().plugin_config.remove_input_padding) + self.register_network_output('logits_before_lmhead', hidden_states) + + # Rescale output before projecting on vocab (for T5) + # See https://github.com/huggingface/transformers/blob/0b192de1f353b0e04dad4813e02e2c672de077be/src/transformers/models/t5/modeling_t5.py#L1769-L1772 + # Note: this is specific for T5, to make it more generic, one can pass in a config: + # self.config.tie_word_embeddings - default to be True for T5 + # openai whisper model didn't use this rescale + if self.rescale_before_lm_head: + hidden_states = hidden_states * (self.hidden_size**-0.5) + + # [bs, hidden_size] -> [bs, vocab_size] + lm_logits = self.lm_head(hidden_states) + lm_logits.mark_output('logits', self._logits_dtype) + else: + hidden_states = send(hidden_states, self.mapping.next_pp_rank()) + hidden_states.mark_output('hidden_states_output', self._dtype) + + if use_cache and default_net().plugin_config.paged_kv_cache == False: + for i, present in zip(self.mapping.pp_layers(self.total_num_layers), + presents): + present[0].mark_output(f'present_key_value_{i}', self._kv_dtype) + if default_net().plugin_config.gpt_attention_plugin: + present[1].mark_output(f'cross_present_key_value_{i}', + self._kv_dtype) + if self.mapping.is_last_pp_rank(): + return (lm_logits, tuple(presents)) + return (hidden_states, tuple(presents)) + else: + if 
self.mapping.is_last_pp_rank(): + return lm_logits + return hidden_states + + def prepare_inputs(self, + max_batch_size, + max_decoder_input_len, + max_seq_len, + max_encoder_input_len, + gather_context_logits: bool = False, + gather_generation_logits: bool = False, + use_cache=True, + max_beam_width=1, + *args, + **kwargs): + '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the + ranges of the dimensions of when using TRT dynamic shapes. + + @return: a list contains values which can be fed into the self.forward() + ''' + # Prepare inputs + max_output_len = max_decoder_input_len + max_seq_len + + head_size = self.head_size + num_kv_heads = (self.num_kv_heads + self.mapping.tp_size - + 1) // self.mapping.tp_size + + encoder_head_size = self.encoder_head_size + encoder_num_kv_heads = (self.encoder_num_kv_heads + self.mapping.tp_size + - 1) // self.mapping.tp_size + + bb_range = [ + 1, (max_batch_size * max_beam_width + 1) // 2, + max_batch_size * max_beam_width + ] + bs_range = [1, (max_batch_size + 1) // 2, max_batch_size] + beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width] + inlen_range = [ + 1, 1, max_decoder_input_len + ] # context phase >= 1 (if forced_input_ids), generation phase = 1 + multivocab_inlen_range = [x * self.num_vocabs for x in inlen_range] + encoder_inlen_range = [ + 1, (max_encoder_input_len + 1) // 2, max_encoder_input_len + ] + mask_len_range = [1, (max_output_len + 1) // 2 + 1, max_output_len + 1] + max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len] + + encoder_num_tokens_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_encoder_input_len * max_batch_size + 1) // 2, + max_encoder_input_len * max_batch_size, + ] + decoder_num_tokens_range = [ + 1, + max_batch_size * max_beam_width, + max(max_decoder_input_len * max_batch_size, + max_beam_width * max_batch_size), + ] + multivocab_decoder_num_tokens_range = [ + x * self.num_vocabs for x in decoder_num_tokens_range + ] + + # No enable_two_optimization_profiles support yet + + encoder_input_len_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_encoder_input_len + 1) // 2, + max_encoder_input_len + ] + max_cross_packed_mask_dim0 = max_batch_size * ( + (max_decoder_input_len + 128 - 1) // 128) * 128 + max_cross_packed_mask_dim1 = ( + (max_encoder_input_len + 256 - 1) // 256) * 256 // 32 + cross_packed_mask_dim0_range = [ + 1, (max_cross_packed_mask_dim0 + 1) // 2, max_cross_packed_mask_dim0 + ] + cross_packed_mask_dim1_range = [ + 0, # 0 for generation phase, >0 for context phase + (max_cross_packed_mask_dim1 + 1) // 2, + max_cross_packed_mask_dim1 + ] + + past_key_value = [] + sequence_length = None + host_past_key_value_lengths = None + runtime_perf_knobs = None + context_progress = None + attention_mask = None + cross_attention_mask = None + cross_attention_packed_mask = None + attention_mask_params = AttentionMaskParams() + use_gpt_attention_plugin = default_net( + ).plugin_config.gpt_attention_plugin + remove_input_padding = default_net().plugin_config.remove_input_padding + paged_kv_cache = default_net().plugin_config.paged_kv_cache + tokens_per_block = default_net().plugin_config.tokens_per_block + + input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None + if remove_input_padding: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('multivocab_decoder_num_tokens', + 
[multivocab_decoder_num_tokens_range]) + ])) + if self.has_position_embedding: + position_ids = Tensor(name='position_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('decoder_num_tokens', + [decoder_num_tokens_range]), + ])) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('decoder_num_tokens', + [decoder_num_tokens_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, self.hidden_size], + dim_range=OrderedDict([ + ('decoder_num_tokens', + [decoder_num_tokens_range]), + ('hidden_size', [self.hidden_size]), + ])) + else: + if self.mapping.is_first_pp_rank(): + input_ids = Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('multivocab_input_len', + [multivocab_inlen_range]), + ])) + if self.has_position_embedding: + position_ids = Tensor(name='position_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', + [bb_range]), + ('input_len', [inlen_range]), + ])) + if self.has_token_type_embedding: + token_type_ids = Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size_beam_width', + [bb_range]), + ('input_len', [inlen_range])]), + ) + else: + hidden_states = Tensor(name='hidden_states_input', + dtype=self._dtype, + shape=[-1, -1, self.hidden_size], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range + ]), + ('input_len', [inlen_range]), + ('hidden_size', [self.hidden_size]), + ])) + + encoder_input_lengths = Tensor( + name="encoder_input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size_beam_width", [bb_range])]), + ) + encoder_max_input_length = Tensor( + name="encoder_max_input_length", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("encoder_max_input_length", + [encoder_inlen_range])]), + ) + encoder_output = None + if remove_input_padding: + encoder_output = Tensor( + name="encoder_output", + dtype=self._dtype, + shape=[-1, self.encoder_hidden_size], + dim_range=OrderedDict([ + ("encoder_num_tokens", [encoder_num_tokens_range]), + ("encoder_hidden_size", [self.encoder_hidden_size]), + ]), + ) + else: + encoder_output = Tensor( + name="encoder_output", + dtype=self._dtype, + shape=[-1, -1, self.encoder_hidden_size], + dim_range=OrderedDict([ + ("batch_size_beam_width_encoder", [bb_range]), + ("encoder_input_len", [encoder_input_len_range]), + ("encoder_hidden_size", [self.encoder_hidden_size]), + ]), + ) + + if use_gpt_attention_plugin: + host_past_key_value_lengths = Tensor( + name='host_past_key_value_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]), + ) + + context_lengths = None + host_context_lengths = None + host_request_types = None + if use_gpt_attention_plugin and remove_input_padding: + host_context_lengths = Tensor(name='host_context_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size_beam_width', + [bb_range]) + ])) + + if use_gpt_attention_plugin: + sequence_length = Tensor( + name='sequence_length', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('batch_size_beam_width', [bb_range])]), + ) + + context_lengths = Tensor(name='context_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]) + ])) + host_request_types = 
Tensor(name='host_request_types', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size_beam_width', + [bb_range]) + ])) + runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs', + dtype=trt.int64, + shape=[16], + dim_range=OrderedDict([ + ('perf_knob_size', [16]) + ])) + context_progress = Tensor(name='host_context_progress', + dtype=trt.int64, + shape=[1], + dim_range=OrderedDict([ + ('context_progress_size', [1]) + ])) + + last_token_ids = None + if self.mapping.is_last_pp_rank() and not gather_context_logits: + last_token_ids = Tensor( + name="last_token_ids", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size_last_token_ids", [bb_range]) + ]), + ) + + if not use_gpt_attention_plugin: + attention_mask = Tensor( + name='attention_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('mask_len', [mask_len_range]), + ]), + ) + + cross_attention_mask = Tensor( + name='cross_attention_mask', + dtype=trt.int32, + shape=[-1, -1, -1], + dim_range=OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('query_len', [1]), + ('encoder_input_len_2', [encoder_input_len_range]), + ]), + ) + else: + cross_attention_mask = Tensor( + name='cross_attention_mask', + dtype=trt.bool, + shape=[-1, -1], + dim_range=OrderedDict([ + ('decoder_num_tokens_2', + [decoder_num_tokens_range + ]), # TODO (bhsueh) should use same name as input_ids + ('encoder_input_len_2', [encoder_input_len_range]), + ]), + ) + + cross_attention_packed_mask = Tensor( + name='cross_attention_packed_mask', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('cross_packed_mask_dim0', [cross_packed_mask_dim0_range]), + ('cross_packed_mask_dim1', [cross_packed_mask_dim1_range]), + ]), + ) + + # create the attention_mask_params. 
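+        # Bundle the masks built above: the dense self-attention mask
+        # (non-plugin path only) plus the dense and packed cross-attention masks.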
+ attention_mask_params = AttentionMaskParams( + attention_mask, None, cross_attention_mask, + cross_attention_packed_mask) + + cache_indirection = Tensor( + name='cache_indirection', + dtype=trt.int32, + shape=[-1, -1, -1], + dim_range=OrderedDict([ + ('batch_size_cache', [bs_range]), + ('beam_width', [beam_width_range]), + ('max_seq_len', [max_output_len_range]), + ]), + ) + + if self.mapping.tp_size > 1: + current_all_reduce_helper().set_workspace_tensor(self.mapping, 1) + + layers_range = self.mapping.pp_layers(self.total_num_layers) + num_pp_layers = len(layers_range) + + host_max_attention_window_sizes = None + host_sink_token_length = None + if use_gpt_attention_plugin: + host_max_attention_window_sizes = Tensor( + name=f'host_max_attention_window_sizes', + dtype=trt.int32, + shape=[num_pp_layers], + dim_range=OrderedDict([('num_layers', [num_pp_layers])])) + host_sink_token_length = Tensor(name='host_sink_token_length', + dtype=trt.int32, + shape=[1], + dim_range=OrderedDict([('scalar', + [1])])) + + kv_cache_block_offsets = None + host_kv_cache_block_offsets = None + host_kv_cache_pool_pointers = None + host_kv_cache_pool_mapping = None + + cross_kv_cache_block_offsets = None + host_cross_kv_cache_block_offsets = None + host_cross_kv_cache_pool_pointers = None + host_cross_kv_cache_pool_mapping = None + + if use_cache: + if not paged_kv_cache: + for i in layers_range: + kv_dim_range = OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('num_heads', [num_kv_heads]), + ('past_key_len', [max_output_len_range]), + ('head_size', [head_size]), + ]) + kv = Tensor(name=f'past_key_value_{i}', + dtype=self._kv_dtype, + shape=[-1, 2, num_kv_heads, -1, head_size], + dim_range=kv_dim_range) + + if use_gpt_attention_plugin: + cross_kv_dim_range = OrderedDict([ + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('cross_num_heads', [encoder_num_kv_heads]), + ('cross_past_key_len', [encoder_input_len_range]), + ('cross_head_size', [encoder_head_size]), + ]) + cross_kv = Tensor(name=f'cross_past_key_value_{i}', + dtype=self._kv_dtype, + shape=[ + -1, 2, encoder_num_kv_heads, -1, + encoder_head_size + ], + dim_range=cross_kv_dim_range) + past_key_value.append((kv, cross_kv)) + else: + # use encoder_output directly, no need to save cross_past_key_value + past_key_value.append((kv, )) + + # TODO: Remove this when TRT fix the named dimension + if not remove_input_padding: + assertion( + shape( + input_ids if self.mapping.is_first_pp_rank() else + hidden_states, 0) == shape(kv, 0), 'batch size') + + else: # paged_kv_cache == True + # PagedKV setup for KV cache of self-attention + max_blocks_per_seq_range = [[ + math.ceil(max_output_len_range[0] / tokens_per_block), + math.ceil(max_output_len_range[1] / tokens_per_block), + math.ceil(max_output_len_range[2] / tokens_per_block) + ]] + max_blocks_per_seq_range = [[ + x for x in max_blocks_per_seq_range[0] + ]] + + # PagedKV setup for KV cache of cross-attention + max_cross_blocks_per_seq_range = [[ + math.ceil(encoder_input_len_range[0] / tokens_per_block), + math.ceil(encoder_input_len_range[1] / tokens_per_block), + math.ceil(encoder_input_len_range[2] / tokens_per_block) + ]] + max_cross_blocks_per_seq_range = [[ + x for x in max_cross_blocks_per_seq_range[0] + ]] + + # TODO(oargov): add support for vgqa, meanwhile assume a single kv cache pool + num_kv_cache_pools = 1 + + kv_cache_block_offsets = Tensor( + name=f'kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + 
('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) + host_kv_cache_block_offsets = Tensor( + name=f'host_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_blocks_per_seq', max_blocks_per_seq_range), + ])) + host_kv_cache_pool_pointers = Tensor( + name=f'host_kv_cache_pool_pointers', + dtype=trt.int64, + shape=[num_kv_cache_pools, 2], + dim_range=OrderedDict([ + ('num_pools_layers', [num_kv_cache_pools]), + ('num_pools_kv', [2]), + ])) + host_kv_cache_pool_mapping = Tensor( + name=f"host_kv_cache_pool_mapping", + dtype=trt.int32, + # 2: (Index of pool, Index of layer within pool) + shape=[num_pp_layers, 2], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), + ('layer_cache_pool_locator', [2]), + ])) + + # paged blocks for cross kv + cross_kv_cache_block_offsets = Tensor( + name=f'cross_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_cross_blocks_per_seq', + max_cross_blocks_per_seq_range), + ])) + host_cross_kv_cache_block_offsets = Tensor( + name=f'host_cross_kv_cache_block_offsets', + dtype=trt.int32, + shape=[num_kv_cache_pools, -1, 2, -1], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('batch_size_beam_width', [bb_range]), + ('kv', [2]), + ('max_cross_blocks_per_seq', + max_cross_blocks_per_seq_range), + ])) + host_cross_kv_cache_pool_pointers = Tensor( + name=f'host_cross_kv_cache_pool_pointers', + dtype=trt.int64, + shape=[num_kv_cache_pools, 2], + dim_range=OrderedDict([ + ('num_kv_cache_pools', [num_kv_cache_pools]), + ('num_pools', [2]), + ])) + host_cross_kv_cache_pool_mapping = Tensor( + name=f"host_cross_kv_cache_pool_mapping", + dtype=trt.int32, + # 2: (Index of pool, Index of layer within pool) + shape=[num_pp_layers, 2], + dim_range=OrderedDict([ + ('pools_mapping', [num_pp_layers]), + ('layer_cache_pool_locator', [2]), + ])) + + for i in layers_range: + past_key_value.append(None) + + kv_cache_params = KeyValueCacheParams( + past_key_value=past_key_value, + host_past_key_value_lengths=host_past_key_value_lengths, + host_max_attention_window_sizes=host_max_attention_window_sizes, + host_sink_token_length=host_sink_token_length, + cache_indirection=cache_indirection, + kv_cache_block_offsets=kv_cache_block_offsets, + host_kv_cache_block_offsets=host_kv_cache_block_offsets, + host_kv_cache_pool_pointers=host_kv_cache_pool_pointers, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping, + cross_kv_cache_block_offsets=cross_kv_cache_block_offsets, + host_cross_kv_cache_block_offsets= + host_cross_kv_cache_block_offsets, + host_cross_kv_cache_pool_pointers= + host_cross_kv_cache_pool_pointers, + host_cross_kv_cache_pool_mapping= + host_cross_kv_cache_pool_mapping, + ) + + attention_params = AttentionParams( + sequence_length=sequence_length, + context_lengths=context_lengths, + host_context_lengths=host_context_lengths, + max_context_length=max_decoder_input_len, + host_request_types=host_request_types, + encoder_input_lengths=encoder_input_lengths, + encoder_max_input_length=encoder_max_input_length, + host_runtime_perf_knobs=runtime_perf_knobs, + host_context_progress=context_progress) 
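+        # cross_kv_cache_gen below is a boolean runtime input: it is meant to be
+        # True while the cross-attention KV cache is populated from encoder_output
+        # and False on later steps, when the cached values are reused.
+        # cross_kv_reuse is only declared when skip_cross_kv is enabled, so an
+        # already-computed cross KV tensor can be passed in instead of recomputed.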
+ + cross_kv_cache_gen = Tensor(name='cross_kv_cache_gen', + dtype=trt.bool, + shape=[1], + dim_range=OrderedDict([ + ('boolean', [1]), + ])) + cross_kv_reuse = None + num_heads = (self.num_heads + self.mapping.tp_size - + 1) // self.mapping.tp_size + cross_kv_out_dim = 2 * num_kv_heads * self.head_size + if self.skip_cross_kv: + if remove_input_padding: + cross_kv_reuse = Tensor( + name="cross_kv_reuse", + dtype=self._dtype, + shape=[-1, cross_kv_out_dim], + dim_range=OrderedDict([ + ("encoder_num_tokens", [encoder_num_tokens_range]), + ("encoder_kv_size", [cross_kv_out_dim]), + ]), + ) + else: + cross_kv_reuse = Tensor( + name="cross_kv_reuse", + dtype=self._dtype, + shape=[-1, -1, cross_kv_out_dim], + dim_range=OrderedDict([ + ("batch_size_beam_width_encoder", [bb_range]), + ("encoder_input_len", [encoder_input_len_range]), + ("encoder_kv_size", [cross_kv_out_dim]), + ]), + ) + + result = { + 'decoder_input_ids': input_ids, + 'encoder_output': encoder_output, + 'position_ids': position_ids, + 'token_type_ids': token_type_ids, + 'use_cache': True, + 'attention_mask_params': attention_mask_params, + 'last_token_ids': last_token_ids, + 'kv_cache_params': kv_cache_params, + 'attention_params': attention_params, + 'hidden_states': hidden_states, + 'cross_kv_cache_gen': cross_kv_cache_gen, + 'cross_kv_reuse': cross_kv_reuse, + } + + return result + + def precompute_relative_attention_bias(self, build_config): + if self.config.relative_attention and not self.use_implicit_relative_attention: + relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias + rel_attn_precomputed = torch.zeros( + (self.config.num_attention_heads // self.mapping.tp_size, + build_config.max_seq_len + 1, build_config.max_seq_len + 1), + dtype=str_dtype_to_torch(self.config.dtype), + device='cuda') + rel_attn_table = numpy_to_torch( + self.rel_attn_table.raw_value).to('cuda') + relative_attention_bias_builder( + rel_attn_precomputed, + rel_attn_table, + self.config.num_attention_heads // self.mapping.tp_size, + build_config.max_seq_len, + self.config.num_buckets, + False, + self.config.max_distance, + ) + for layer_idx in range(self.num_layers): + self.decoder_layers[ + layer_idx].self_attention.set_rel_attn_table( + build_config.max_seq_len, rel_attn_precomputed) diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index f6cef268dcf..9fb6c971fa3 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -107,6 +107,7 @@ def from_dir( eagle_dynamic_tree_max_top_k: Optional[int] = None, lookahead_config: list[int] | None = None, debug_mode: bool = False, + debug_tensor_names: List[str] = [], lora_ckpt_source: str = "hf", gpu_weights_percent: float = 1, max_tokens_in_paged_kv_cache: int | None = None, @@ -159,6 +160,8 @@ def from_dir( KV Cache fraction reserved for cross attention, should only be used with enc-dec models. debug_mode (bool): Whether or not to turn on the debug mode. + debug_tensor_names (List[str]): + A list of tensor names to be debugged. medusa_choices (List[List[int]]): Medusa choices to use when in Medusa decoding. 
             eagle_choices (List[List[int]]):
@@ -367,16 +370,13 @@ def from_dir(
         assert max_beam_width <= model_config.max_beam_width

         debug_config = None
         if debug_mode:
-            # To debug specific tensors, add tensor names in the following list
-            # if none provided, all input and output tensors will be dumped
-            # if not none, it will disable all input/output dump
-            debug_tensor_names: List[str] = [
-            ]  # modify this list for specific tensor dump
+            # To debug specific tensors, pass their names via debug_tensor_names;
+            # if the list is empty, all input and output tensors are dumped.
             debug_config = trtllm.DebugConfig(
                 debug_input_tensors=True,
                 debug_output_tensors=True,
                 debug_tensor_names=debug_tensor_names)

         trtllm_config = trtllm.ExecutorConfig(
             max_batch_size=max_batch_size,
@@ -461,7 +461,8 @@ def _check_inputs(self, batch_input_ids: List[List[int]],
                     f"Decoder prefix tokens ({decoder_max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})"
                 )
         else:
-            if max_length + max_new_tokens > self.max_seq_len * len(self.model_config.vocab_sizes):
+            if max_length + max_new_tokens > self.max_seq_len * len(
+                    self.model_config.vocab_sizes):
                 raise RuntimeError(
                     f"Maximum input length ({max_length}) + maximum new tokens ({max_new_tokens}) exceeds the engine or specified limit ({self.max_seq_len})"
                 )
@@ -1056,7 +1057,8 @@ def fill_output_ids(result_token_ids, batch_idx, seq_idx):
         input_lengths = torch.tensor([x.size(0) for x in batch_input_ids],
                                      dtype=torch.int32,
-                                     device=cuda_device) // len(self.model_config.vocab_sizes)
+                                     device=cuda_device) // len(
+                                         self.model_config.vocab_sizes)
         if output_sequence_lengths:
             outputs['sequence_lengths'] = torch.tensor(sequence_lengths,