diff --git a/.gitignore b/.gitignore
index 3bdb9ed..58c92cb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,6 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
\ No newline at end of file
+.idea/
+.run
+test*.py
\ No newline at end of file
diff --git a/README.md b/README.md
index f706f60..511f673 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,17 @@ VisualGLM-6B 可以进行图像的描述的相关知识的问答。
 
 ## 使用
 
+通过测试:
+
+CentOS Linux release 7.9.2009 (Core)
+
+NVIDIA Driver Version: 525.105.17
+
+cuda_11.7
+
+cuDNN v8.9.3 (July 11th, 2023), for CUDA 11.x
+
+Python 3.10.12
 
 ### 模型推理
 
diff --git a/README_en.md b/README_en.md
index 8b47449..8b3745a 100644
--- a/README_en.md
+++ b/README_en.md
@@ -38,6 +38,19 @@ VisualGLM-6B can answer questions related to image description.
 
 ## Usage
 
+
+test passed:
+
+CentOS Linux release 7.9.2009 (Core)
+
+NVIDIA Driver Version: 525.105.17
+
+cuda_11.7
+
+cuDNN v8.9.3 (July 11th, 2023), for CUDA 11.x
+
+Python 3.10.12
+
 ### Model Inference
 
 Install dependencies with pip
diff --git a/api.py b/api.py
index 3d15db1..7a0caef 100644
--- a/api.py
+++ b/api.py
@@ -6,10 +6,19 @@
 import datetime
 import torch
 
+# 命令行参数
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
+args = parser.parse_args()
+
 gpu_number = 0
-model, tokenizer = get_infer_setting(gpu_device=gpu_number)
+model, tokenizer = get_infer_setting(gpu_device=gpu_number, quant=args.quant)
 
 app = FastAPI()
+
+
 @app.post('/')
 async def visual_glm(request: Request):
     json_post_raw = await request.json()
@@ -30,12 +39,12 @@ async def visual_glm(request: Request):
 
     is_zh = is_chinese(input_text)
     input_data = generate_input(input_text, input_image_encoded, history, input_para)
-    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs'] 
+    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
     with torch.no_grad():
         answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
-            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
-            top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
-
+                                  max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
+                                  top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
+
     now = datetime.datetime.now()
     time = now.strftime("%Y-%m-%d %H:%M:%S")
     response = {
@@ -48,4 +57,4 @@
 
 
 if __name__ == '__main__':
-    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
\ No newline at end of file
+    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
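
Reviewer note: a quick way to exercise the new `--quant` flag in `api.py` end to end. This is only a sketch — the request fields (`text`, base64-encoded `image`, `history`) are assumed from the API usage described in the repo README, and the example image path is illustrative.

```python
# Start the server with quantization first, e.g.:
#     python api.py --quant 4
# Then post a request; the field names ("text", "image", "history") are assumed
# from the existing README, and the image path below is only an example.
import base64
import json

import requests

with open("examples/1.jpeg", "rb") as f:  # any local image works here
    image_b64 = base64.b64encode(f.read()).decode()

payload = {"text": "Describe the image.", "image": image_b64, "history": []}
resp = requests.post("http://127.0.0.1:8080", json=payload, timeout=300)
print(json.dumps(resp.json(), ensure_ascii=False, indent=2))
```
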
diff --git a/api_hf.py b/api_hf.py
index fa1a635..caa7214 100644
--- a/api_hf.py
+++ b/api_hf.py
@@ -10,8 +10,9 @@
 
 tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
 model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
-
 app = FastAPI()
+
+
 @app.post('/')
 async def visual_glm(request: Request):
     json_post_raw = await request.json()
@@ -23,9 +24,13 @@ async def visual_glm(request: Request):
     history = request_data.get("history")
     image_encoded = request_data.get("image")
     query = request_data.get("text")
 
-    image_path = process_image(image_encoded)
-    with torch.no_grad():
+    if image_encoded is not None:
+        image_path = process_image(image_encoded)
+    else:
+        image_path = ''
+
+    with torch.no_grad():
         result = model.stream_chat(tokenizer, image_path, query, history=history)
         last_result = None
         for value in result:
@@ -46,4 +51,4 @@ async def visual_glm(request: Request):
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
\ No newline at end of file
+    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
diff --git a/model/infer_util.py b/model/infer_util.py
index 342bd4f..84044ee 100644
--- a/model/infer_util.py
+++ b/model/infer_util.py
@@ -33,9 +33,10 @@ def is_chinese(text):
     return zh_pattern.search(text)
 
 def generate_input(input_text, input_image_prompt, history=[], input_para=None, image_is_encoded=True):
+    image = None
     if not image_is_encoded:
         image = input_image_prompt
-    else:
+    elif input_image_prompt:
         decoded_image = base64.b64decode(input_image_prompt)
         image = Image.open(BytesIO(decoded_image))
 
@@ -50,4 +51,4 @@ def process_image(image_encoded):
     image_path = f'./examples/{image_hash}.png'
     if not os.path.isfile(image_path):
         image.save(image_path)
-    return os.path.abspath(image_path)
\ No newline at end of file
+    return os.path.abspath(image_path)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f594164..45205c4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 SwissArmyTransformer>=0.3.6
-torch>1.10.0
-torchvision
+torch==1.13.0
+torchvision==0.14.0
+torchaudio==0.13.0
 transformers>=4.27.1
 mdtex2html
 gradio
\ No newline at end of file
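
Note: the pinned `torch`/`torchvision`/`torchaudio` versions above correspond to the CUDA 11.7 environment listed in both READMEs. With the `api_hf.py` and `generate_input` changes, the image becomes optional; below is a minimal sketch of a text-only request, assuming `api_hf.py` is already running on its default port 8080.

```python
# Text-only request against api_hf.py: with this change, omitting the "image" field
# yields an empty image path instead of an error inside process_image().
import requests

payload = {"text": "Hello, please introduce yourself.", "history": []}  # no "image" key on purpose
resp = requests.post("http://127.0.0.1:8080", json=payload, timeout=300)
print(resp.json())
```
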
diff --git a/web_demo.py b/web_demo.py
index d0f705a..b5b1929 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 
 import gradio as gr
+import requests
 from PIL import Image
-import os
-import json
 from model import is_chinese, get_infer_setting, generate_input, chat
 import torch
 
+
 def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
     input_para = {
         "max_length": 2048,
@@ -19,34 +19,42 @@ def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
     input_para.update(request_data)
 
     input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
-    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs'] 
+    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
     with torch.no_grad():
         answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
-            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
-            top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
+                                  max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
+                                  top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
     return answer
 
 
 def request_model(input_text, temperature, top_p, image_prompt, result_previous):
     result_text = [(ele[0], ele[1]) for ele in result_previous]
-    for i in range(len(result_text)-1, -1, -1):
+    for i in range(len(result_text) - 1, -1, -1):
         if result_text[i][0] == "" or result_text[i][1] == "":
             del result_text[i]
     print(f"history {result_text}")
 
     is_zh = is_chinese(input_text)
-    if image_prompt is None:
-        if is_zh:
-            result_text.append((input_text, '图片为空!请上传图片并重试。'))
-        else:
-            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
-        return input_text, result_text
-    elif input_text == "":
-        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
-        return "", result_text
+    # if image_prompt is None:
+    #     ...
+    #     if is_zh:
+    #         result_text.append((input_text, '图片为空!请上传图片并重试。'))
+    #     else:
+    #         result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
+    #     return input_text, result_text
+    # elif input_text == "":
+    #     result_text.append((input_text, 'Text empty! Please enter text and retry.'))
+    #     return "", result_text
+
+    if not (image_prompt or input_text):
+        result_text.append((input_text, 'Please enter text and/or upload an image, then retry.'))
+        return "", result_text
 
     request_para = {"temperature": temperature, "top_p": top_p}
-    image = Image.open(image_prompt)
+    if image_prompt is not None:
+        image = Image.open(image_prompt)
+    else:
+        image = None
     try:
         answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
     except Exception as e:
@@ -73,6 +81,7 @@ def request_model(input_text, temperature, top_p, image_prompt, result_previous):
 def clear_fn(value):
     return "", [("", "Hi, What do you want to know about this image?")], None
 
+
 def clear_fn2(value):
     return [("", "Hi, What do you want to know about this image?")]
 
@@ -81,13 +90,14 @@ def main(args):
     gr.close_all()
     global model, tokenizer
     model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)
-    
+
     with gr.Blocks(css='style.css') as demo:
         gr.Markdown(DESCRIPTION)
         with gr.Row():
             with gr.Column(scale=4.5):
                 with gr.Group():
-                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
+                    input_text = gr.Textbox(label='Input Text',
+                                            placeholder='Please enter text prompt below and press ENTER.')
                 with gr.Row():
                     run_button = gr.Button('Generate')
                     clear_button = gr.Button('Clear')
@@ -100,15 +110,16 @@
             with gr.Row():
                 maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
             with gr.Column(scale=5.5):
-                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)
+                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[
+                    ("", "Hi, What do you want to know about this image?")]).style(height=550)
 
         gr.Markdown(NOTES)
 
         print(gr.__version__)
-        run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
-                         outputs=[input_text, result_text])
-        input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
+        run_button.click(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
                          outputs=[input_text, result_text])
+        input_text.submit(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
+                          outputs=[input_text, result_text])
         clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
         image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
         image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
@@ -116,14 +127,18 @@
 
     print(gr.__version__)
     demo.queue(concurrency_count=10)
-    demo.launch(share=args.share)
+    print(f"Public IP address:{requests.get('https://api.ipify.org').text}")
+    demo.launch(share=args.share, server_name=args.server_name, server_port=args.server_port)
 
 
 if __name__ == '__main__':
     import argparse
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
     parser.add_argument("--share", action="store_true")
parser.add_argument("--server_name", type=str, default="127.0.0.1") + parser.add_argument("--server_port", type=int, default=7860) args = parser.parse_args() - main(args) \ No newline at end of file + main(args)