4 changes: 3 additions & 1 deletion .gitignore
@@ -161,4 +161,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
.run
test*.py
11 changes: 11 additions & 0 deletions README.md
@@ -53,6 +53,17 @@ VisualGLM-6B 可以进行图像的描述的相关知识的问答。
</details>

## 使用
通过测试:

CentOS Linux release 7.9.2009 (Core)

NVIDIA Driver Version: 525.105.17

cuda_11.7

cuDNN v8.9.3 (July 11th, 2023), for CUDA 11.x

Python 3.10.12

### 模型推理

13 changes: 13 additions & 0 deletions README_en.md
@@ -38,6 +38,19 @@ VisualGLM-6B can answer questions related to image description.

## Usage


test passed:

CentOS Linux release 7.9.2009 (Core)

NVIDIA Driver Version: 525.105.17

cuda_11.7

cuDNN v8.9.3 (July 11th, 2023), for CUDA 11.x

Python 3.10.12
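
For reference, a minimal sketch to confirm these versions on the target machine. It assumes a CUDA-enabled PyTorch build is already installed; the CUDA and cuDNN numbers it reports are the ones PyTorch was built against, which may differ slightly from the system install listed above.

```python
# check_env.py: sketch to confirm the runtime environment (assumes CUDA-enabled PyTorch).
import platform
import torch

print("Python       :", platform.python_version())        # expected: 3.10.x
print("PyTorch      :", torch.__version__)
print("CUDA (torch) :", torch.version.cuda)                # CUDA version PyTorch was built with
print("cuDNN (torch):", torch.backends.cudnn.version())    # cuDNN bundled with PyTorch
print("GPU visible  :", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device       :", torch.cuda.get_device_name(0))
```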

### Model Inference

Install dependencies with pip
21 changes: 15 additions & 6 deletions api.py
@@ -6,10 +6,19 @@
import datetime
import torch

# 命令行参数
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
args = parser.parse_args()

gpu_number = 0
model, tokenizer = get_infer_setting(gpu_device=gpu_number)
model, tokenizer = get_infer_setting(gpu_device=gpu_number, quant=args.quant)

app = FastAPI()


@app.post('/')
async def visual_glm(request: Request):
json_post_raw = await request.json()
@@ -30,12 +39,12 @@ async def visual_glm(request: Request):

is_zh = is_chinese(input_text)
input_data = generate_input(input_text, input_image_encoded, history, input_para)
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
with torch.no_grad():
answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)

now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
response = {
@@ -48,4 +57,4 @@ async def visual_glm(request: Request):


if __name__ == '__main__':
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
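
For reference, a hedged client sketch against this endpoint. The JSON field names below are an assumption borrowed from api_hf.py (which reads "text", "image" and "history"); verify them against the request parsing in api.py before relying on this.

```python
# Hypothetical client for api.py: POST a prompt plus a base64-encoded image.
# Field names ("text"/"image"/"history") are assumed, not confirmed for api.py.
import base64
import requests

with open("examples/1.jpeg", "rb") as f:        # any local image; the path is illustrative
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "text": "Describe this image.",
    "image": image_b64,
    "history": [],                              # prior (question, answer) turns, if any
}
resp = requests.post("http://127.0.0.1:8080/", json=payload)
print(resp.json())
```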
13 changes: 9 additions & 4 deletions api_hf.py
@@ -10,8 +10,9 @@
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()


app = FastAPI()


@app.post('/')
async def visual_glm(request: Request):
json_post_raw = await request.json()
@@ -23,9 +24,13 @@ async def visual_glm(request: Request):
history = request_data.get("history")
image_encoded = request_data.get("image")
query = request_data.get("text")
image_path = process_image(image_encoded)

with torch.no_grad():
if image_encoded is not None:
image_path = process_image(image_encoded)
else:
image_path = ''

with torch.no_grad():
result = model.stream_chat(tokenizer, image_path, query, history=history)
last_result = None
for value in result:
@@ -46,4 +51,4 @@


if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
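
With the guard above, the "image" field can be omitted entirely, so text-only queries are possible. A minimal sketch, assuming the api_hf.py server is running locally on port 8080:

```python
# Text-only request: no "image" field, so the server falls back to image_path = ''.
import requests

payload = {"text": "Hello, what can you do?", "history": []}
resp = requests.post("http://127.0.0.1:8080/", json=payload)
print(resp.json())
```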
5 changes: 3 additions & 2 deletions model/infer_util.py
@@ -33,9 +33,10 @@ def is_chinese(text):
return zh_pattern.search(text)

def generate_input(input_text, input_image_prompt, history=[], input_para=None, image_is_encoded=True):
image = None
if not image_is_encoded:
image = input_image_prompt
else:
elif input_image_prompt:
decoded_image = base64.b64decode(input_image_prompt)
image = Image.open(BytesIO(decoded_image))

@@ -50,4 +51,4 @@ def process_image(image_encoded):
image_path = f'./examples/{image_hash}.png'
if not os.path.isfile(image_path):
image.save(image_path)
return os.path.abspath(image_path)
return os.path.abspath(image_path)
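
A small sketch of the new generate_input behaviour: with an empty image prompt the base64 decode is skipped and the image stays None. The 'input_image' and 'gen_kwargs' keys are the ones the callers in api.py and web_demo.py use; running from the repository root is assumed.

```python
# Hedged sketch: text-only input no longer trips over base64 decoding.
from model import generate_input

input_data = generate_input("Hello!", "", history=[], input_para={"temperature": 0.8})
print(input_data["gen_kwargs"])            # whatever was passed as input_para
assert input_data["input_image"] is None   # empty prompt -> no image object
```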
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,6 +1,7 @@
SwissArmyTransformer>=0.3.6
torch>1.10.0
torchvision
torch==1.13.0
torchvision==0.14.0
torchaudio==0.13.0
transformers>=4.27.1
mdtex2html
gradio
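
A quick way to confirm that the pinned versions are the ones actually installed (a sketch; importlib.metadata is in the standard library from Python 3.8 onwards):

```python
# Print the installed versions of the pinned requirements.
from importlib.metadata import version

for pkg in ("torch", "torchvision", "torchaudio", "transformers", "SwissArmyTransformer"):
    print(f"{pkg:22s} {version(pkg)}")
```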
63 changes: 39 additions & 24 deletions web_demo.py
@@ -1,12 +1,12 @@
#!/usr/bin/env python

import gradio as gr
import requests
from PIL import Image
import os
import json
from model import is_chinese, get_infer_setting, generate_input, chat
import torch


def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
input_para = {
"max_length": 2048,
@@ -19,34 +19,42 @@ def generate_text_with_image(input_text, image, history=[], request_data=dict(),
input_para.update(request_data)

input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
with torch.no_grad():
answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
return answer


def request_model(input_text, temperature, top_p, image_prompt, result_previous):
result_text = [(ele[0], ele[1]) for ele in result_previous]
for i in range(len(result_text)-1, -1, -1):
for i in range(len(result_text) - 1, -1, -1):
if result_text[i][0] == "" or result_text[i][1] == "":
del result_text[i]
print(f"history {result_text}")

is_zh = is_chinese(input_text)
if image_prompt is None:
if is_zh:
result_text.append((input_text, '图片为空!请上传图片并重试。'))
else:
result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
return input_text, result_text
elif input_text == "":
result_text.append((input_text, 'Text empty! Please enter text and retry.'))
return "", result_text
# if image_prompt is None:
# ...
# if is_zh:
# result_text.append((input_text, '图片为空!请上传图片并重试。'))
# else:
# result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
# return input_text, result_text
# elif input_text == "":
# result_text.append((input_text, 'Text empty! Please enter text and retry.'))
# return "", result_text

if not (image_prompt or input_text):
result_text.append((input_text, 'Please enter text or/and upload a image, then retry.'))
return "", result_text

request_para = {"temperature": temperature, "top_p": top_p}
image = Image.open(image_prompt)
if image_prompt is not None:
image = Image.open(image_prompt)
else:
image = None
try:
answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
except Exception as e:
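
The net effect of the reworked check above: a request now proceeds when either text or an image is supplied, and is rejected only when both are missing. A standalone sketch of that decision (behaviour only, not the repository's code):

```python
# Hedged sketch of the new input check in request_model.
def should_reject(input_text: str, image_prompt) -> bool:
    return not (image_prompt or input_text)

assert should_reject("", None)                          # nothing provided -> error message
assert not should_reject("Describe the image.", None)   # text-only is now accepted
assert not should_reject("", "/tmp/cat.png")            # image-only is accepted too
```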
@@ -73,6 +81,7 @@ def request_model(input_text, temperature, top_p, image_prompt, result_previous)
def clear_fn(value):
return "", [("", "Hi, What do you want to know about this image?")], None


def clear_fn2(value):
return [("", "Hi, What do you want to know about this image?")]

@@ -81,13 +90,14 @@ def main(args):
gr.close_all()
global model, tokenizer
model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

with gr.Blocks(css='style.css') as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=4.5):
with gr.Group():
input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
input_text = gr.Textbox(label='Input Text',
placeholder='Please enter text prompt below and press ENTER.')
with gr.Row():
run_button = gr.Button('Generate')
clear_button = gr.Button('Clear')
@@ -100,30 +110,35 @@ def main(args):
with gr.Row():
maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
with gr.Column(scale=5.5):
result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)
result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[
("", "Hi, What do you want to know about this image?")]).style(height=550)

gr.Markdown(NOTES)

print(gr.__version__)
run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
outputs=[input_text, result_text])
input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
run_button.click(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
outputs=[input_text, result_text])
input_text.submit(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
outputs=[input_text, result_text])
clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

print(gr.__version__)

demo.queue(concurrency_count=10)
demo.launch(share=args.share)
print(f"Public IP address:{requests.get('https://api.ipify.org').text}")
demo.launch(share=args.share, server_name=args.server_name, server_port=args.server_port)


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
parser.add_argument("--share", action="store_true")
parser.add_argument("--server_name", type=str, default="127.0.0.1")
parser.add_argument("--server_port", type=int, default=7860)
args = parser.parse_args()

main(args)
main(args)
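
A hedged sketch of driving the new flags programmatically rather than from the command line; it assumes web_demo.py is importable from the repository root and that the model weights can be loaded on the local GPU.

```python
# Equivalent of: python web_demo.py --quant 4 --server_name 0.0.0.0 --server_port 8080
import argparse
import web_demo

args = argparse.Namespace(quant=4, share=False,
                          server_name="0.0.0.0", server_port=8080)
web_demo.main(args)   # builds the Gradio UI and serves it on the given host/port
```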