
Commit c6312e8

[Example] ggml: add json-schema example (#151)
Signed-off-by: dm4 <[email protected]>
1 parent 616562d commit c6312e8

File tree

6 files changed: +286 -0 lines changed


.github/workflows/llama.yml

Lines changed: 13 additions & 0 deletions
@@ -327,6 +327,19 @@ jobs:
           default \
           $'<|user|>\nWhat is the capital of Japan?<|end|>\n<|assistant|>'
 
+    - name: JSON Schema
+      run: |
+        test -f ~/.wasmedge/env && source ~/.wasmedge/env
+        cd wasmedge-ggml/json-schema
+        curl -LO https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
+        cargo build --target wasm32-wasi --release
+        time wasmedge --dir .:. \
+          --env n_gpu_layers="$NGL" \
+          --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
+          target/wasm32-wasi/release/wasmedge-ggml-json-schema.wasm \
+          default \
+          $'[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always output JSON format string.\n<</SYS>>\nGive me a JSON array of Apple products.[/INST]'
+
     - name: Build llama-stream
       run: |
         cd wasmedge-ggml/llama-stream
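The `--env` flags in this step become ordinary environment variables inside the Wasm guest; that is how `n_gpu_layers` reaches the example's option parsing (see `get_options_from_env` in the source later in this commit). A minimal sketch of that pattern, assuming the same variable name and the `serde_json` crate the example already depends on:

```rust
use serde_json::json;
use std::env;

fn main() {
    // Options passed as `wasmedge --env key=value` show up via std::env in the guest.
    let mut options = json!({});
    // `n_gpu_layers` arrives as a string such as "0"; parse it into a number.
    let ngl: u64 = env::var("n_gpu_layers")
        .unwrap_or_else(|_| "0".to_string())
        .parse()
        .expect("n_gpu_layers must be an unsigned integer");
    options["n-gpu-layers"] = json!(ngl);
    println!("{}", options); // e.g. {"n-gpu-layers":0}
}
```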

wasmedge-ggml/README.md

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,8 @@ Supported parameters include:
 - `threads`: Set the number of threads for the inference, the same as the `--threads` parameter in llama.cpp.
 - `mmproj`: Set the path to the multimodal projector file for llava, the same as the `--mmproj` parameter in llama.cpp.
 - `image`: Set the path to the image file for llava, the same as the `--image` parameter in llama.cpp.
+- `grammar`: Specify a grammar to constrain model output to a specific format, the same as the `--grammar` parameter in llama.cpp.
+- `json-schema`: Specify a JSON schema to constrain model output, the same as the `--json-schema` parameter in llama.cpp.
 
 (For more detailed instructions on usage or default values for the parameters, please refer to [WasmEdge](https://github.com/WasmEdge/WasmEdge/blob/master/plugins/wasi_nn/ggml.cpp).)
 
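These parameters are handed to the GGML plugin as a JSON config string when the graph is built, which is how the new example wires up `json-schema`. A minimal sketch of that plumbing, assuming the `wasmedge-wasi-nn` 0.8 API used in this commit, an illustrative schema, and a model preloaded under the alias `default`:

```rust
use serde_json::json;
use wasmedge_wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // Option keys mirror the parameter names documented above.
    let options = json!({
        "n-gpu-layers": 0,
        // The schema is passed as a string value under "json-schema".
        "json-schema": r#"{"items": {"type": "object"}, "minItems": 1}"#,
    });

    // "default" must match the `--nn-preload default:GGML:AUTO:<model>.gguf` alias.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .config(options.to_string())
        .build_from_cache("default")
        .expect("Failed to build graph");
    let _context = graph
        .init_execution_context()
        .expect("Failed to init context");
}
```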

wasmedge-ggml/json-schema/Cargo.toml

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
[package]
name = "wasmedge-ggml-json-schema"
version = "0.1.0"
edition = "2021"

[dependencies]
serde_json = "1.0"
wasmedge-wasi-nn = "0.8.0"
wasmedge-ggml/json-schema/README.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# JSON Schema Example For WASI-NN with GGML Backend

> [!NOTE]
> Please refer to the [wasmedge-ggml/README.md](../README.md) for the general introduction and the setup of the WASI-NN plugin with GGML backend. This document focuses on the specific example of using a JSON schema with the GGML backend.

## Get the Model

In this example, we are going to use the [llama-2-7b-chat](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF) model. Please note that this is the chat fine-tuned variant of Llama 2.

```bash
curl -LO https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
```

## Parameters

> [!NOTE]
> Please check the parameters section of [wasmedge-ggml/README.md](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters) first.

In this example, we use the `json-schema` option to constrain the model to generate JSON output in a specific format.

You can check [the llama.cpp documentation](https://github.com/ggerganov/llama.cpp/tree/master/examples/main#grammars--json-schemas) for more details.

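The example embeds its schema as a raw string constant in `src/main.rs` (shown later in this commit). An equivalent, non-authoritative sketch that builds the same kind of product-list schema with `serde_json::json!` and drops it into the plugin options:

```rust
use serde_json::json;

fn main() {
    // Same shape as the schema in main.rs: an array of product objects,
    // each with productId, productName and price, with at least five items.
    let schema = json!({
        "items": {
            "type": "object",
            "properties": {
                "productId":   { "type": "integer" },
                "productName": { "type": "string" },
                "price":       { "type": "number", "exclusiveMinimum": 0 }
            },
            "required": ["productId", "productName", "price"]
        },
        "minItems": 5
    });

    // The plugin expects the schema as a string value under "json-schema".
    let mut options = json!({});
    options["json-schema"] = json!(schema.to_string());
    println!("{}", options);
}
```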
## Execute

```console
$ wasmedge --dir .:. \
    --env n_predict=99 \
    --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf \
    wasmedge-ggml-json-schema.wasm default

USER:
Give me a JSON array of Apple products.
ASSISTANT:
[
  {
    "productId": 1,
    "productName": "iPhone 12 Pro",
    "price": 799.99
  },
  {
    "productId": 2,
    "productName": "iPad Air",
    "price": 599.99
  },
  {
    "productId": 3,
    "productName": "MacBook Air",
    "price": 999.99
  },
  {
    "productId": 4,
    "productName": "Apple Watch Series 7",
    "price": 399.99
  },
  {
    "productId": 5,
    "productName": "AirPods Pro",
    "price": 249.99
  }
]
```
wasmedge-ggml/json-schema/src/main.rs

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
use serde_json::json;
use serde_json::Value;
use std::env;
use std::io;
use wasmedge_wasi_nn::{
    self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
    TensorType,
};

// Read a non-empty line from stdin for the interactive loop.
fn read_input() -> String {
    loop {
        let mut answer = String::new();
        io::stdin()
            .read_line(&mut answer)
            .expect("Failed to read line");
        if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
            return answer.trim().to_string();
        }
    }
}

// Collect plugin options from environment variables (set via `wasmedge --env`).
fn get_options_from_env() -> Value {
    let mut options = json!({});
    if let Ok(val) = env::var("enable_log") {
        options["enable-log"] = serde_json::from_str(val.as_str())
            .expect("invalid value for enable-log option (true/false)")
    } else {
        options["enable-log"] = serde_json::from_str("false").unwrap()
    }
    if let Ok(val) = env::var("n_gpu_layers") {
        options["n-gpu-layers"] =
            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
    } else {
        options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
    }
    if let Ok(val) = env::var("n_predict") {
        options["n-predict"] =
            serde_json::from_str(val.as_str()).expect("invalid n-predict value (unsigned integer)")
    }
    if let Ok(val) = env::var("json_schema") {
        options["json-schema"] =
            serde_json::from_str(val.as_str()).expect("invalid json_schema value (JSON)")
    }

    options
}

fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
    context.set_input(0, TensorType::U8, &[1], &data)
}

#[allow(dead_code)]
fn set_metadata_to_context(
    context: &mut GraphExecutionContext,
    data: Vec<u8>,
) -> Result<(), Error> {
    context.set_input(1, TensorType::U8, &[1], &data)
}

fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Reserve room for 4096 tokens with an average token length of 6 bytes.
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
}

fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
}

const JSON_SCHEMA: &str = r#"
{
    "items": {
        "title": "Product",
        "description": "A product from the catalog",
        "type": "object",
        "properties": {
            "productId": {
                "description": "The unique identifier for a product",
                "type": "integer"
            },
            "productName": {
                "description": "Name of the product",
                "type": "string"
            },
            "price": {
                "description": "The price of the product",
                "type": "number",
                "exclusiveMinimum": 0
            }
        },
        "required": [
            "productId",
            "productName",
            "price"
        ]
    },
    "minItems": 5
}
"#;

fn main() {
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];

    // Set options for the graph. Check our README for more details:
    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
    let mut options = get_options_from_env();

    // Add a JSON schema to constrain the output to JSON.
    // Check [here](https://github.com/ggerganov/llama.cpp/tree/master/grammars) for more details.
    options["json-schema"] = JSON_SCHEMA.into();

    // Make the output more consistent.
    options["temp"] = json!(0.1);

    // Create graph and initialize context.
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&options).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");

    // If there is a third argument, use it as the prompt and enter non-interactive mode.
    // This is mainly for the CI workflow.
    if args.len() >= 3 {
        let prompt = &args[2];
        // Set the prompt.
        println!("Prompt:\n{}", prompt);
        let tensor_data = prompt.as_bytes().to_vec();
        context
            .set_input(0, TensorType::U8, &[1], &tensor_data)
            .expect("Failed to set input");
        println!("Response:");

        // Get the number of input tokens and llama.cpp versions.
        let input_metadata = get_metadata_from_context(&context);
        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
        println!(
            "[INFO] llama_build_number: {}",
            input_metadata["llama_build_number"]
        );
        println!(
            "[INFO] Number of input tokens: {}",
            input_metadata["input_tokens"]
        );

        // Get the output.
        context.compute().expect("Failed to compute");
        let output = get_output_from_context(&context);
        println!("{}", output.trim());

        // Retrieve the output metadata.
        let metadata = get_metadata_from_context(&context);
        println!(
            "[INFO] Number of input tokens: {}",
            metadata["input_tokens"]
        );
        println!(
            "[INFO] Number of output tokens: {}",
            metadata["output_tokens"]
        );
        std::process::exit(0);
    }

    loop {
        println!("USER:");
        let input = read_input();

        // Set prompt to the input tensor.
        set_data_to_context(&mut context, input.as_bytes().to_vec()).expect("Failed to set input");

        // Execute the inference.
        match context.compute() {
            Ok(_) => (),
            Err(Error::BackendError(BackendError::ContextFull)) => {
                println!("\n[INFO] Context full, we'll reset the context and continue.");
            }
            Err(Error::BackendError(BackendError::PromptTooLong)) => {
                println!("\n[INFO] Prompt too long, we'll reset the context and continue.");
            }
            Err(err) => {
                println!("\n[ERROR] {}", err);
            }
        }

        // Retrieve the output.
        let output = get_output_from_context(&context);
        println!("ASSISTANT:\n{}", output.trim());
    }
}
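Because generation is constrained by the schema, the text returned by `get_output_from_context` should already be valid JSON. A hedged sketch, not part of this commit, of how a caller might parse and sanity-check that output with `serde_json` (the `parse_products` helper is hypothetical):

```rust
use serde_json::Value;

// Hypothetical helper: parse the model output and return the array items.
fn parse_products(output: &str) -> Result<Vec<Value>, serde_json::Error> {
    let value: Value = serde_json::from_str(output.trim())?;
    Ok(value.as_array().cloned().unwrap_or_default())
}

fn main() {
    let sample = r#"[{"productId": 1, "productName": "iPhone 12 Pro", "price": 799.99}]"#;
    let products = parse_products(sample).expect("output should be valid JSON");
    println!("parsed {} product(s)", products.len());
}
```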
Binary file not shown (1.74 MB).
