Commit 5eb25dc

transformer model example for testing
Signed-off-by: sa-faizal <[email protected]>
1 parent cbb42dd commit 5eb25dc

File tree

3 files changed: +465 −0 lines changed


examples/transformer/client.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
import numpy as np
import sys
import argparse
from typing import List, Tuple
import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException


class SimpleTokenizer:
    """
    A simple character-level tokenizer for demo purposes.
    """

    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.unk_token_id = 1

    def encode(self, text: str, max_length: int = 128) -> Tuple[List[int], List[int]]:
        """
        Encode text to token IDs and create attention mask.

        Args:
            text: Input text string
            max_length: Maximum sequence length

        Returns:
            Tuple of (input_ids, attention_mask)
        """
        # Simple character-level encoding that maps each character to an ID based on its ASCII value
        input_ids = [min(ord(c), self.vocab_size - 1) for c in text.lower()]

        # Truncate if too long
        if len(input_ids) > max_length:
            input_ids = input_ids[:max_length]

        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = [1] * len(input_ids)

        # Pad to max_length
        padding_length = max_length - len(input_ids)
        input_ids.extend([self.pad_token_id] * padding_length)
        attention_mask.extend([0] * padding_length)

        return input_ids, attention_mask


class SentimentClient:
    """
    Client for the Transformer Sentiment Classifier on Triton Inference Server.
    """

    def __init__(self, url: str = "localhost:8000", model_name: str = "transformer"):
        """
        Initialize the client.

        Args:
            url: Triton server URL (e.g., "localhost:8000")
            model_name: Name of the model
        """
        self.url = url
        self.model_name = model_name
        self.client = httpclient.InferenceServerClient(url=url, verbose=False)
        self.tokenizer = SimpleTokenizer()
        self.max_seq_length = 128
        self.class_names = ["Negative", "Neutral", "Positive"]

    def check_server_ready(self) -> bool:
        """Check if the Triton server is ready."""
        try:
            if self.client.is_server_ready():
                print(f"Server at {self.url} is ready")
                return True
            else:
                print(f"Server at {self.url} is not ready")
                return False
        except InferenceServerException as e:
            print(f"Failed to connect to server at {self.url}")
            print(f"  Error: {e}")
            return False

    def check_model_ready(self) -> bool:
        """Check if the model is ready."""
        try:
            if self.client.is_model_ready(self.model_name):
                print(f"Model '{self.model_name}' is ready")
                return True
            else:
                print(f"Model '{self.model_name}' is not ready")
                return False
        except InferenceServerException as e:
            print("Failed to check model status")
            print(f"  Error: {e}")
            return False

    def predict(self, text: str) -> Tuple[np.ndarray, int, str]:
        """
        Run inference on a single text input.

        Args:
            text: Input text string

        Returns:
            Tuple of (probabilities, predicted_class, class_name)
        """
        # Tokenize input
        input_ids, attention_mask = self.tokenizer.encode(text, self.max_seq_length)

        # Convert to numpy arrays with batch dimension
        input_ids_np = np.array([input_ids], dtype=np.int64)
        attention_mask_np = np.array([attention_mask], dtype=np.int64)

        # Create input objects
        inputs = [
            httpclient.InferInput("INPUT_IDS", input_ids_np.shape, "INT64"),
            httpclient.InferInput("ATTENTION_MASK", attention_mask_np.shape, "INT64")
        ]

        # Set data
        inputs[0].set_data_from_numpy(input_ids_np)
        inputs[1].set_data_from_numpy(attention_mask_np)

        # Create output object
        outputs = [httpclient.InferRequestedOutput("OUTPUT")]

        # Send inference request
        try:
            response = self.client.infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs
            )

            # Get output
            output = response.as_numpy("OUTPUT")[0]  # Remove batch dimension
            predicted_class = int(np.argmax(output))
            class_name = self.class_names[predicted_class]

            return output, predicted_class, class_name

        except InferenceServerException as e:
            print(f"Inference failed: {e}")
            raise

    def predict_batch(self, texts: List[str]) -> List[Tuple[np.ndarray, int, str]]:
        """
        Run inference on a batch of text inputs.

        Args:
            texts: List of input text strings

        Returns:
            List of tuples (probabilities, predicted_class, class_name) for each input
        """
        # Tokenize all inputs
        input_ids_batch = []
        attention_mask_batch = []

        for text in texts:
            input_ids, attention_mask = self.tokenizer.encode(text, self.max_seq_length)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)

        # Convert to numpy arrays
        input_ids_np = np.array(input_ids_batch, dtype=np.int64)
        attention_mask_np = np.array(attention_mask_batch, dtype=np.int64)

        # Create input objects
        inputs = [
            httpclient.InferInput("INPUT_IDS", input_ids_np.shape, "INT64"),
            httpclient.InferInput("ATTENTION_MASK", attention_mask_np.shape, "INT64")
        ]

        # Set data
        inputs[0].set_data_from_numpy(input_ids_np)
        inputs[1].set_data_from_numpy(attention_mask_np)

        # Create output object
        outputs = [httpclient.InferRequestedOutput("OUTPUT")]

        # Send inference request
        try:
            response = self.client.infer(
                model_name=self.model_name,
                inputs=inputs,
                outputs=outputs
            )

            # Get outputs
            outputs_np = response.as_numpy("OUTPUT")

            results = []
            for output in outputs_np:
                predicted_class = int(np.argmax(output))
                class_name = self.class_names[predicted_class]
                results.append((output, predicted_class, class_name))

            return results

        except InferenceServerException as e:
            print(f"Batch inference failed: {e}")
            raise
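
Note: as committed, client.py imports argparse and sys but defines no entry point. A minimal CLI wrapper is sketched below for illustration; the flag names and defaults are assumptions, not part of this commit.

def main():
    # Hypothetical entry point wiring up the classes defined above.
    parser = argparse.ArgumentParser(
        description="Sentiment client for the Triton 'transformer' example"
    )
    parser.add_argument("--url", default="localhost:8000", help="Triton HTTP endpoint")
    parser.add_argument("--model", default="transformer", help="Model name in the repository")
    parser.add_argument("text", nargs="+", help="One or more texts to classify")
    args = parser.parse_args()

    client = SentimentClient(url=args.url, model_name=args.model)
    if not (client.check_server_ready() and client.check_model_ready()):
        sys.exit(1)

    for probabilities, _, class_name in client.predict_batch(args.text):
        print(f"{class_name}: {probabilities}")

if __name__ == "__main__":
    main()

Example invocation: python client.py --url localhost:8000 "great product" "terrible service"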

examples/transformer/config.pbtxt

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
name: "transformer"
backend: "python"
max_batch_size: 8  # maximum batch size the model supports for any type of batching performed by Triton

# Input tensor specifications
input [
  {
    name: "INPUT_IDS"
    data_type: TYPE_INT64
    dims: [ 128 ]  # max_seq_length
  },
  {
    name: "ATTENTION_MASK"
    data_type: TYPE_INT64
    dims: [ 128 ]  # max_seq_length
  }
]

# Output tensor specifications
output [
  {
    name: "OUTPUT"
    data_type: TYPE_FP32
    dims: [ 3 ]  # num_classes (Negative, Neutral, Positive)
  }
]

# Instance group configuration
# For GPUs: use KIND_GPU
# For CPU-only: use KIND_CPU
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]

# Dynamic batching configuration for better throughput
dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 100
}

# Model version policy - serve the latest version
version_policy: { latest: { num_versions: 1 } }
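
A quick sanity check that the server actually loaded this configuration is to read it back over the same HTTP connection. The sketch below reuses the SentimentClient from client.py; get_model_config and get_model_metadata are standard tritonclient calls, but treat the exact response fields shown here as assumptions.

# Sketch: read back the loaded configuration and metadata (assumes client.py is importable).
client = SentimentClient(url="localhost:8000", model_name="transformer")

config = client.client.get_model_config("transformer")      # parsed config.pbtxt as a dict
metadata = client.client.get_model_metadata("transformer")  # tensor names, datatypes, shapes

print(config.get("max_batch_size"))             # expected: 8
print([t["name"] for t in metadata["inputs"]])  # expected: ['INPUT_IDS', 'ATTENTION_MASK']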
