Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4150,20 +4150,12 @@ def cloud_ai_100_feature_generate(
if self.qpc_session is None:
self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
self.batch_size = self.qpc_session.bindings[0].dims[0]

# Dynamic switching to closest seq_Len based on input_ids_len
inputs = processor(inputs, return_tensors="pt")
input_ids_len = inputs["input_values"].shape[-1]

for allowed_shape in self.qpc_session.allowed_shapes:
seq_len_allowed = allowed_shape[1][1][1]

if seq_len_allowed >= input_ids_len:
self.seq_len = seq_len_allowed
break
self.seq_len = self.qpc_session.bindings[0].dims[1]

# To handle single seq_len as we can't fetch allowed shapes for single seq_len
self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len
inputs = processor(inputs, return_tensors="pt", max_length=self.seq_len, truncation=True, padding="max_length")
input_ids_len = inputs["input_values"].shape[-1]
input_values = np.array(
torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
)
Expand Down
16 changes: 9 additions & 7 deletions examples/audio/wav2vec2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
def main():
parser = argparse.ArgumentParser(description="CTC speech recognition inference with Wav2Vec2")
parser.add_argument(
"--model-name",
"--model_name",
type=str,
default="facebook/wav2vec2-base-960h",
help="HuggingFace CTC model ID (e.g., Wav2Vec2)",
)

parser.add_argument("--num-cores", type=int, default=16, help="Number of cores")
parser.add_argument("--num_cores", type=int, default=16, help="Number of cores")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--seq_len", type=int, default=480000, help="Context length for generation")
parser.add_argument("--num-devices", type=int, default=1, help="Number of devices")
args = parser.parse_args()

print(f"Loading CTC model: {args.model_name}")
Expand All @@ -31,10 +34,7 @@ def main():
# Using a standard english dataset
print("Loading audio sample from dataset...")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
data = ds[0]["audio"]["array"]

# Reshape so shape corresponds to data with batch size 1
data = data.reshape(-1)
data = [ds[i]["audio"]["array"] for i in range(args.batch_size)]

# Load processor
processor = AutoProcessor.from_pretrained(args.model_name)
Expand All @@ -43,7 +43,9 @@ def main():
model = QEFFAutoModelForCTC.from_pretrained(args.model_name)

## STEP 3 -- Compile the model
model.compile(num_cores=args.num_cores)
model.compile(
batch_size=args.batch_size, num_devices=args.num_devices, seq_len=args.seq_len, num_cores=args.num_cores
)

## STEP 4 -- Run the model and generate the output
model_output = model.generate(processor, inputs=data)
Expand Down
Loading