Skip to content

Commit 3de9b37

Browse files
committed
Update cli.py with new params and "instruct" mode
1 parent 3d454a7 commit 3de9b37

File tree

1 file changed

+67
-40
lines changed

1 file changed

+67
-40
lines changed

llamacpp/cli.py

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ def parse_args_into_params(argv) -> llamacpp.gpt_params:
99
parser = argparse.ArgumentParser(description="llama.cpp CLI")
1010
parser.add_argument("-i", "--interactive", action="store_true", help="run in interactive mode")
1111
parser.add_argument(
12-
"--interactive-start",
12+
"-ins", "--instruct",
1313
action="store_true",
14-
help="run in interactive mode and poll user input at startup",
14+
help="run in 'instruct mode' where the user is prompted to enter a command",
1515
default=False,
1616
)
1717
parser.add_argument(
@@ -39,11 +39,10 @@ def parse_args_into_params(argv) -> llamacpp.gpt_params:
3939
"--prompt",
4040
type=str,
4141
help="prompt to start generation with (default: random)",
42-
required=True,
4342
)
44-
# parser.add_argument(
45-
# "-f", "--file", type=str, default="", help="prompt file to start generation."
46-
# )
43+
parser.add_argument(
44+
"-f", "--file", type=str, default="", help="prompt file to start generation."
45+
)
4746
parser.add_argument(
4847
"-n", "--n_predict", type=int, default=128, help="number of tokens to predict (default: 128)"
4948
)
@@ -81,29 +80,7 @@ def parse_args_into_params(argv) -> llamacpp.gpt_params:
8180

8281
args = parser.parse_args(argv[1:])
8382

84-
# Add a space in front of the first character to match OG llama tokenizer behavior
85-
args.prompt = " " + args.prompt
86-
87-
# Initialize gpt_params object
88-
params = llamacpp.gpt_params(
89-
args.model,
90-
args.prompt,
91-
args.reverse_prompt,
92-
args.ctx_size,
93-
args.n_predict,
94-
args.top_k,
95-
args.top_p,
96-
args.temp,
97-
args.repeat_penalty,
98-
args.seed,
99-
args.threads,
100-
args.repeat_last_n,
101-
args.batch_size,
102-
args.color,
103-
args.interactive,
104-
)
105-
106-
return params
83+
return args
10784

10885

10986
def process_interactive_input(model: llamacpp.PyLLAMA):
@@ -121,24 +98,57 @@ def process_interactive_input(model: llamacpp.PyLLAMA):
12198
break
12299

123100

124-
def main(params):
101+
def main(args):
125102
"""Main function"""
103+
104+
# if args.file is specified, read the file and set the prompt to the contents
105+
if args.file:
106+
with open(args.file, "r") as f:
107+
args.prompt = f.read().strip()
108+
109+
# Add a space in front of the first character to match OG llama tokenizer behavior
110+
args.prompt = " " + args.prompt
111+
112+
# Initialize the gpt_params object
113+
params = llamacpp.gpt_params(
114+
args.model,
115+
args.ctx_size,
116+
args.n_predict,
117+
args.top_k,
118+
args.top_p,
119+
args.temp,
120+
args.repeat_penalty,
121+
args.seed,
122+
args.threads,
123+
args.repeat_last_n,
124+
args.batch_size,
125+
)
126+
126127
model = llamacpp.PyLLAMA(params)
127128
model.add_bos()
128-
model.update_input(params.prompt)
129+
model.update_input(args.prompt)
129130
model.print_startup_stats()
130131
model.prepare_context()
131132

133+
inp_pfx = model.tokenize("\n\n### Instruction:\n\n", True)
134+
inp_sfx = model.tokenize("\n\n### Response:\n\n", False)
135+
136+
if args.instruct:
137+
args.interactive = True
138+
args.antiprompt = "### Instruction:\n\n"
139+
132140
# Set the antiprompt if one was given; having an antiprompt implies interactive mode
133-
if params.interactive:
134-
model.set_antiprompt(params.antiprompt)
141+
if args.antiprompt:
142+
args.interactive = True
143+
model.set_antiprompt(args.antiprompt)
135144

136-
if params.interactive:
145+
if args.interactive:
137146
print("== Running in interactive mode. ==")
138147
print(" - Press Ctrl+C to interject at any time.")
139148
print(" - Press Return to return control to LLaMa.")
140149
print(" - If you want to submit another line, end your input in '\\'.")
141150
print()
151+
is_interacting = True
142152

143153
input_noecho = False
144154
is_finished = False
@@ -147,33 +157,50 @@ def main(params):
147157
if model.has_unconsumed_input():
148158
model.ingest_all_pending_input(not input_noecho)
149159
# # reset color to default if there is no pending user input
150-
# if (!input_noecho && params.use_color) {
160+
# if (!input_noecho && args.use_color) {
151161
# printf(ANSI_COLOR_RESET);
152162
# }
153163
else:
154164
text, is_finished = model.infer_text()
155165
print(text, end="")
156166
input_noecho = False
157167

158-
if params.interactive:
168+
if args.interactive:
159169
if model.is_antiprompt_present():
160170
# reverse prompt found
161171
is_interacting = True
162172
if is_interacting:
173+
if args.instruct:
174+
model.update_input_tokens(inp_pfx)
175+
print("\n> ", end="")
176+
163177
process_interactive_input(model)
178+
179+
if args.instruct:
180+
model.update_input_tokens(inp_sfx)
181+
164182
input_noecho = True
165183
is_interacting = False
166-
184+
185+
# end of text token was found
167186
if is_finished:
168-
break
187+
if args.interactive:
188+
is_interacting = True
189+
else:
190+
print(" [end of text]")
191+
break
192+
193+
if args.interactive and model.is_finished():
194+
model.reset_remaining_tokens()
195+
is_interacting = True
169196

170197
return 0
171198

172199

173200
def run():
174201
# Parse command-line arguments; main() builds the gpt_params object from them
175-
params = parse_args_into_params(sys.argv)
176-
return main(params)
202+
args = parse_args_into_params(sys.argv)
203+
return main(args)
177204

178205
if __name__ == "__main__":
179206
sys.exit(run())

0 commit comments

Comments
 (0)