Commit ca9aecf

Merge branch 'refs/heads/dev'
2 parents 106a9d1 + b3e07ee


59 files changed (+1191, -290 lines)

eval/humaneval.py

Lines changed: 14 additions & 1 deletion
@@ -5,7 +5,7 @@
 from exllamav2 import model_init
 from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-import argparse, contextlib
+import argparse, contextlib, subprocess
 import util

 # Args
@@ -20,6 +20,7 @@
 parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
 parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
 parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
+parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
 model_init.add_args(parser)
 args = parser.parse_args()

@@ -52,6 +53,13 @@
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
         " "
+    ),
+    "gemma": (
+        "<bos><start_of_turn>user\n"
+        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
+        "<start_of_turn>model\n"
+        "```python\n{{problem}} ",
+        " "
     )
 }

@@ -192,3 +200,8 @@
 print(f" -- Saving: {args.output}")
 write_jsonl(args.output, samples)

+# Optionally launch eval script
+
+if args.eval:
+    subprocess.run(["evaluate_functional_correctness", args.output])
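
Taken together, the humaneval.py changes add a Gemma-style prompt template and a -e/--eval flag that shells out to the HumanEval scorer once samples are written. A rough sketch of the intended end-to-end flow (the filename below is a placeholder, and the evaluate_functional_correctness entry point is assumed to come from the installed human-eval package, as the subprocess call implies):

    import subprocess

    # Placeholder values, for illustration only; the script reads these from argparse.
    output_file = "humaneval_samples.jsonl"   # args.output
    run_eval = True                           # args.eval, i.e. the new -e / --eval flag

    # ... sampling loop runs and write_jsonl(output_file, samples) is called ...

    if run_eval:
        # Assumes the human-eval package's evaluate_functional_correctness
        # console script is available on PATH.
        subprocess.run(["evaluate_functional_correctness", output_file])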

examples/chat.py

Lines changed: 6 additions & 3 deletions
@@ -61,7 +61,7 @@

 parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding")

-parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings after each prompt")
+parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings/stats after each prompt")
 parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context after every response")

 # Arrrgs
@@ -235,7 +235,9 @@ def get_tokenized_context(max_len):

 # Stop conditions

-generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer))
+sc = prompt_format.stop_conditions(tokenizer)
+sc = [x for x in sc if x]
+generator.set_stop_conditions(sc)

 # ANSI color codes

@@ -393,8 +395,9 @@ def get_tokenized_context(max_len):
     else:
         sd_stats = ""

+    ctx_tokens = active_context.shape[-1]
     print()
-    print(col_sysprompt + f"(Response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)
+    print(col_sysprompt + f"(Context: {ctx_tokens} tokens, response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)

     # Optionally forget context after each response
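
The stop-condition change is a small defensive step: whatever list the prompt format returns is filtered so that only truthy entries reach the generator. This pairs with the chat_prompts.py change below, which adds tokenizer.single_id("<|im_end|>") to the ChatML stop conditions. A minimal illustration with made-up values (the idea that a format can hand back None for an unavailable special token is my assumption, not stated in the diff):

    # Hypothetical list as a prompt format might return it: a token id, a literal
    # string, and a None placeholder for a special token the tokenizer lacks.
    sc = [128009, "<|im_end|>", None]

    # The new filtering step keeps only real stop conditions.
    sc = [x for x in sc if x]
    assert sc == [128009, "<|im_end|>"]

    # generator.set_stop_conditions(sc) then receives a clean list.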

examples/chat_prompts.py

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ def subs_prompt(self):
     def stop_conditions(self, tokenizer):
         return \
             [tokenizer.eos_token_id,
+            tokenizer.single_id("<|im_end|>"),
             """<|im_end|>"""]

     def encoding_options(self):

examples/dynamic_gen.py

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,7 @@ def main():
     if use_draft_model:

         draft_config = ExLlamaV2Config(draft_model_dir)
+        draft_config.arch_compat_overrides()
         draft_model = ExLlamaV2(draft_config)

         draft_cache = ExLlamaV2Cache(
@@ -155,6 +156,7 @@ def main():
     # 2048, which will also be the limit of the chunk size for prefill used by the dynamic generator.

     config = ExLlamaV2Config(model_dir)
+    config.arch_compat_overrides()
     config.max_input_len = max_chunk_size
     config.max_attention_size = max_chunk_size ** 2
     model = ExLlamaV2(config)
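
The remaining example changes in this commit all follow one pattern, visible here and repeated in each of the inference_*.py files below: arch_compat_overrides() is called on the config right after construction, before the model is built. The minimal loading sequence as it now appears in the examples (the model path is a placeholder):

    from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config

    model_dir = "/path/to/model-exl2"      # placeholder path
    config = ExLlamaV2Config(model_dir)
    config.arch_compat_overrides()         # apply architecture compatibility overrides before loading
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache, progress = True)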

examples/inference.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@

 model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_async.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 async def main():
     model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw"
     config = ExLlamaV2Config(model_dir)
+    config.arch_compat_overrides()
     model = ExLlamaV2(config)
     cache = ExLlamaV2Cache(model, lazy = True)
     model.load_autosplit(cache, progress = True)

examples/inference_banned_strings.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/6.0bpw/"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_cfg.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_dedup.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_json.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@

 model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)
