改造fastchat默认gradio页面以支持系统指令输入 - peter-xbs/CommonCodes GitHub Wiki
代码如下：用 gradio_web_server_sysmsg.py 替代原有的 gradio_web_server.py，运行后即可支持系统指令输入。
"""
The gradio demo server for chatting with a single model.
"""
import argparse
from collections import defaultdict
import datetime
import hashlib
import json
import os
import random
import time
import uuid
import gradio as gr
import requests
from fastchat.constants import (
LOGDIR,
WORKER_API_TIMEOUT,
ErrorCode,
MODERATION_MSG,
CONVERSATION_LIMIT_MSG,
RATE_LIMIT_MSG,
SERVER_ERROR_MSG,
INPUT_CHAR_LEN_LIMIT,
CONVERSATION_TURN_LIMIT,
SESSION_EXPIRATION_TIME,
)
from fastchat.model.model_adapter import (
get_conversation_template,
)
from fastchat.model.model_registry import get_model_info, model_info
from fastchat.serve.api_provider import get_api_provider_stream_iter
from fastchat.utils import (
build_logger,
get_window_url_params_js,
get_window_url_params_with_tos_js,
moderation_filter,
parse_gradio_auth_creds,
load_image,
)
# Module-level logger built with FastChat's build_logger (the second argument
# is presumably the log file name — confirm against fastchat.utils).
logger = build_logger("gradio_web_server", "gradio_web_server.log")
# HTTP headers sent with every request to a model worker (model_worker_stream_iter).
headers = {"User-Agent": "FastChat Client"}
# Reusable button updates returned by event handlers to toggle the five
# action buttons (upvote / downvote / flag / regenerate / clear).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True, visible=True)
disable_btn = gr.Button(interactive=False)
invisible_btn = gr.Button(interactive=False, visible=False)
# Filled in at startup by set_global_vars().
controller_url = None
# NOTE(review): only written by set_global_vars(); nothing else in this file
# reads it — add_text calls moderation_filter unconditionally.
enable_moderation = False
# Terms-of-service / sponsor markdown (with inline HTML logos) rendered below
# the chat UI when build_single_model_ui(add_promotion_links=True).
acknowledgment_md = """
### Terms of Service
Users are required to agree to the following terms before using the service:
The service is a research preview. It only provides limited safety measures and may generate offensive content.
It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
Please do not upload any private information.
The service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license.
### Acknowledgment
We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/).
<div class="sponsor-image-about">
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
</div>
"""
# JSON file format of API-based models:
# {
#     "gpt-3.5-turbo": {
#         "model_name": "gpt-3.5-turbo",
#         "api_type": "openai",
#         "api_base": "https://api.openai.com/v1",
#         "api_key": "sk-******",
#         "anony_only": false
#     }
# }
#
# - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly.
# - "anony_only" indicates whether to display this model in anonymous mode only
#   (such models are excluded from the visible list in get_model_list).
# - "multimodal" (optional, default false) decides whether the model is listed
#   in multimodal mode or in language-model mode.
# - "recommended_config" (optional) may override "temperature" / "top_p" when
#   bot_response is called with use_recommended_config=True.
#
# Loaded by get_model_list() from --register-api-endpoint-file; maps each
# model name to its endpoint configuration.
api_endpoint_info = {}
class State:
    """Per-session chat state: one conversation plus its bookkeeping."""

    def __init__(self, model_name):
        # Fresh conversation built from the model's prompt template.
        self.conv = get_conversation_template(model_name)
        # Random hex id recorded with every log entry of this conversation.
        self.conv_id = uuid.uuid4().hex
        # Set by add_text when the chained bot_response call must do nothing.
        self.skip_next = False
        self.model_name = model_name

    def to_gradio_chatbot(self):
        """Return the history in the structure gr.Chatbot displays."""
        return self.conv.to_gradio_chatbot()

    def dict(self):
        """Return the conversation snapshot extended with conv_id and model_name."""
        return {
            **self.conv.dict(),
            "conv_id": self.conv_id,
            "model_name": self.model_name,
        }
def set_global_vars(controller_url_, enable_moderation_):
    """Store the controller address and the moderation flag as module globals."""
    global controller_url, enable_moderation
    controller_url, enable_moderation = controller_url_, enable_moderation_
def get_conv_log_filename():
    """Return today's conversation-log path, LOGDIR/<YYYY>-<MM>-<DD>-conv.json (local time)."""
    today = datetime.datetime.now()
    basename = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"
    return os.path.join(LOGDIR, basename)
def get_model_list(controller_url, register_api_endpoint_file, multimodal):
    """Collect the model names this server can offer.

    Models come from two sources: the FastChat controller (when
    ``controller_url`` is truthy) and the API-endpoint JSON file (when
    ``register_api_endpoint_file`` is given). The parsed file is also stored
    in the module-level ``api_endpoint_info``.

    Args:
        controller_url: Base URL of the controller, or a falsy value to skip it.
        register_api_endpoint_file: Path to the API-endpoint JSON file, or None.
        multimodal: If True list only multimodal models, otherwise only
            language models.

    Returns:
        ``(visible_models, models)``. ``models`` holds every model found;
        ``visible_models`` additionally drops API models marked
        ``"anony_only"``. Both are de-duplicated and sorted: names present in
        ``model_info`` sort by their registry position, the rest by name.
    """
    global api_endpoint_info

    # Add models from the controller
    if controller_url:
        ret = requests.post(controller_url + "/refresh_all_workers")
        assert ret.status_code == 200
        endpoint = "/list_multimodal_models" if multimodal else "/list_language_models"
        ret = requests.post(controller_url + endpoint)
        models = ret.json()["models"]
    else:
        models = []

    # Add models from the API providers
    if register_api_endpoint_file:
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding,
        # and close the file instead of leaking the handle.
        with open(register_api_endpoint_file, encoding="utf-8") as f:
            api_endpoint_info = json.load(f)
        for mdl, mdl_dict in api_endpoint_info.items():
            # Keep only entries whose modality matches the requested one.
            if bool(mdl_dict.get("multimodal", False)) == bool(multimodal):
                models += [mdl]

    # Remove duplicates, then hide anonymous-only API models. Build a new list
    # instead of calling .remove() on the list being iterated, which would
    # silently skip the element following each removed model.
    models = list(set(models))
    visible_models = [
        mdl
        for mdl in models
        if not api_endpoint_info.get(mdl, {}).get("anony_only", False)
    ]

    # Sort models; registry entries get a "___NNN" key so they keep registry order.
    priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
    visible_models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"All models: {models}")
    logger.info(f"Visible models: {visible_models}")
    return visible_models, models
def load_demo_single(models, url_params):
    """Return the initial ``(state, model dropdown update)`` for a page load.

    The dropdown preselects ``url_params["model"]`` when it names one of
    ``models``; otherwise the first model, or "" when there are none. The
    session state starts out as None (add_text creates it lazily).
    """
    if "model" in url_params and url_params["model"] in models:
        chosen = url_params["model"]
    elif models:
        chosen = models[0]
    else:
        chosen = ""
    return None, gr.Dropdown(choices=models, value=chosen, visible=True)
def load_demo(url_params, request: gr.Request):
    """``demo.load`` handler: optionally refresh the models, then init the page.

    With ``--model-list-mode reload`` the module-level ``models`` list is
    re-fetched from the controller on every page load.
    """
    global models

    logger.info(f"load_demo. ip: {get_ip(request)}. params: {url_params}")

    if args.model_list_mode == "reload":
        models, _all_models = get_model_list(
            controller_url, args.register_api_endpoint_file, False
        )

    return load_demo_single(models, url_params)
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON-line record of a vote/flag to today's conversation log."""
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=get_ip(request),
        )
        print(json.dumps(record), file=log_file)
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote; clear the textbox and disable the three vote buttons."""
    logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote; clear the textbox and disable the three vote buttons."""
    logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag; clear the textbox and disable the three vote buttons."""
    logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def regenerate(state, request: gr.Request):
    """Reset the last message (the model reply) to None so it can be regenerated.

    Returns ``(state, chatbot, textbox, image) + 5 disabled buttons``.
    """
    logger.info(f"regenerate. ip: {get_ip(request)}")
    state.conv.update_last_message(None)
    buttons = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None, *buttons)
def clear_history(request: gr.Request):
    """Drop the session state and empty the chatbot, textbox and image slot."""
    logger.info(f"clear_history. ip: {get_ip(request)}")
    buttons = (disable_btn,) * 5
    return (None, [], "", None) + buttons
def get_ip(request: gr.Request):
    """Return the client IP, preferring Cloudflare's ``cf-connecting-ip`` header."""
    forwarded = request.headers
    if "cf-connecting-ip" not in forwarded:
        return request.client.host
    return forwarded["cf-connecting-ip"]
def _prepare_text_with_image(state, text, image):
if image is not None:
if len(state.conv.get_images()) > 0:
# reset convo with new image
state.conv = get_conversation_template(state.model_name)
image = state.conv.convert_image_to_base64(
image
) # PIL type is not JSON serializable
text = text, [image]
return text
def add_text(state, model_selector, text, image, request: gr.Request):
    """Append the user's message plus an empty reply slot to the conversation.

    Creates the session ``State`` on first use. Empty input, or a conversation
    that already reached CONVERSATION_TURN_LIMIT, sets ``state.skip_next`` so
    the chained bot_response does nothing. Text flagged by moderation_filter
    is replaced by MODERATION_MSG; input is cut to INPUT_CHAR_LEN_LIMIT chars.

    Returns ``(state, chatbot, textbox, image) + 5 button updates``.
    """
    ip = get_ip(request)
    logger.info(f"add_text. ip: {ip}. len: {len(text)}")

    if state is None:
        state = State(model_selector)

    if not text:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    if moderation_filter(text, [state.model_name]):
        logger.info(f"violate moderation. ip: {ip}. text: {text}")
        # Overwrite the offending input with the moderation notice.
        text = MODERATION_MSG

    conv = state.conv
    completed_turns = (len(conv.messages) - conv.offset) // 2
    if completed_turns >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
        state.skip_next = True
        limit_buttons = (no_change_btn,) * 5
        return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + limit_buttons

    # Hard cut-off on length, then attach the image (may reset state.conv).
    text = _prepare_text_with_image(state, text[:INPUT_CHAR_LEN_LIMIT], image)
    user_role, bot_role = state.conv.roles[0], state.conv.roles[1]
    state.conv.append_message(user_role, text)
    state.conv.append_message(bot_role, None)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
    images,
):
    """Stream generation results from a FastChat model worker.

    POSTs the generation parameters to ``{worker_addr}/worker_generate_stream``
    and yields every NUL-delimited JSON chunk of the streamed reply as a dict
    (bot_response reads its ``"error_code"`` and ``"text"`` keys).

    Args:
        conv: Conversation supplying ``stop_str`` / ``stop_token_ids``.
        model_name: Name of the model served by the worker.
        worker_addr: Base URL of the worker.
        prompt: Fully rendered prompt string.
        temperature, repetition_penalty, top_p, max_new_tokens: Sampling params.
        images: List of images (base64) to send; omitted from the payload when empty.
    """
    # Make requests
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    # Logged before images are attached, which keeps image payloads out of the log.
    logger.info(f"==== request ====\n{gen_params}")
    if len(images) > 0:
        gen_params["images"] = images

    # Stream output. The response is used as a context manager so the
    # streaming connection is released even when the consumer stops early
    # (e.g. bot_response returns on an error chunk): closing this generator
    # runs the `with` exit, which closes the response.
    with requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode())
def is_limit_reached(model_name, ip, monitor_url="http://localhost:9090"):
    """Ask the rate-limit monitor whether ``ip`` may currently query ``model_name``.

    Args:
        model_name: Model being queried.
        ip: Client identifier sent to the monitor as ``user_id``.
        monitor_url: Base URL of the monitor service (default keeps the
            previous hard-coded address).

    Returns:
        The monitor's decoded JSON reply (bot_response reads its
        ``"is_limit_reached"`` and ``"reason"`` keys), or None when the
        monitor is unreachable or returns invalid JSON. Rate limiting is
        deliberately best-effort, so errors are logged, not raised.
    """
    try:
        ret = requests.get(
            f"{monitor_url}/is_limit_reached",
            # Let requests URL-encode the values; names containing '&', '+',
            # '#' or spaces would otherwise corrupt the query string.
            params={"model": model_name, "user_id": ip},
            timeout=1,
        )
        return ret.json()
    except Exception as e:
        logger.info(f"monitor error: {e}")
        return None
# Built-in system prompt for the "问诊" (pre-consultation interview) mode.
CONSULTATION_SYSTEM_PROMPT = """你现在扮演一名富有同情心的临床医生,为患者提供专业预问诊服务,你事先并不了解你的患者的基本信息,你需要通过和患者的持续沟通来获取患者病情概况,并给出诊疗意见。
你在和患者对话时,应符合以下基本要求:
1. 首轮对话你应该询问患者病情信息;
2. 你和患者沟通的对话主题应丰富,问诊应有层次和顺序,按下述顺序逐步深入,重点包含主诉、现病史、既往史,如果患者病情可能与遗传相关可询问家族史相关信息;根据病情需要决定是否询问个人史、婚育史信息。
3. 现病史部分问诊顺序和内容要求如下:
a. 按下述顺序分别询问患者发病情况:
a1. 发病时间、地点;
a2. 起病缓急;
a3. 前驱症状;
a4. 可能诱因等;
b. 按下述顺序分别询问患者主要症状特点及其发展变化:
b1. 症状部位;
b2. 症状性、质;
b3. 症状持续时间、程度;
b4. 症状缓解、加剧因素;
b5. 症状演变发展等;
c. 按下述顺序分别询问伴随症状:
c1. 如果患者描述了伴随症状,可询问伴随症状情况;
c2. 并询问伴随症状与主要症状之间的关系
d. 询问患者发病后诊治经过及结果:
d1. 本次就诊前是否有相关检查;
d2. 本次就诊前治疗详细情况及效果;
e. 如对后续诊断和治疗有所帮助,可选择性询问下述患者发病以来一般情况:
e1. 精神、睡眠状况;
e2. 食欲变化情况;
e3. 大小便情况;
e4. 体重变化情况;
e5. 体力变化情况等;
4. 既往史部分问诊顺序和内容:
a. 询问患者与本次疾病、症状相关的疾病史;
b. 如对后续诊疗有益,可选择性询问下述病史:
b1. 传染病史;
b2. 预防接种史;
b3. 手术外伤史;
b4. 过敏史等信息;
5. 再次注意,你对患者信息事先并不了解,只能通过和患者持续沟通才能掌握患者信息;
6. 每次只能问一个问题,再次强调一次只能向患者询问一个问题,需考虑到患者同时回答多个问题压力会比较大;
7. 你的提问是层层递进式的,比如在患者描述病情中涉及到的关键点后,你应该继续深入挖掘详情信息,为后续准确的诊断提供足够信息,特别注意一次提问一个问题,根据患者的回复逐步深化追问,最终获取患者全面信息;
8. 你应该发起尽可能多的问诊轮次,以收集足够的信息支持至少2种鉴别诊断判定;
9. 对于患者提出的问题,你应该给与恰当的回复;
10. 不要重复提问相同或相似的问题;
11. 最后对话结束时,你需提醒患者你已收集相关信息并同步给接诊医生。"""


def set_system_msg(conv, mode, sys_gr):
    """Apply the system instruction selected in the UI to ``conv``.

    bot_response calls this on every turn, and the conversation object lives
    for the whole session. Each mode therefore sets the system message
    explicitly instead of leaving whatever an earlier turn installed.

    Args:
        conv: The session's conversation (must expose ``system_message`` and
            ``set_system_message``).
        mode: Value of the "Mode" radio — "默认" (template default), "问诊"
            (built-in consultation prompt) or "自定义" (use ``sys_gr``).
            Any other value leaves ``conv`` untouched.
        sys_gr: Text of the "系统指令" textbox; only used in "自定义" mode.
            Blank or None falls back to the template default.

    Returns:
        ``conv``, modified in place.
    """
    # Remember the template's own system message the first time this
    # conversation is seen, so "默认" (or an empty "自定义") can restore it
    # after the user switched modes mid-session.
    if not hasattr(conv, "_default_system_message"):
        conv._default_system_message = conv.system_message
    default_msg = conv._default_system_message

    if mode == "问诊":
        conv.set_system_message(CONSULTATION_SYSTEM_PROMPT)
    elif mode == "自定义" and (sys_gr or "").strip():
        conv.set_system_message(sys_gr)
    elif mode in ("默认", "自定义"):
        conv.set_system_message(default_msg)
    return conv
def bot_response(
    state,
    temperature,
    top_p,
    max_new_tokens,
    mode,
    sys_gr,
    request: gr.Request,
    apply_rate_limit=True,
    use_recommended_config=False,
):
    """Generate the assistant reply for the pending turn, streaming it to the UI.

    Generator used as a Gradio event handler. Every yield is
    ``(state, chatbot_history) + 5 button updates`` (upvote, downvote, flag,
    regenerate, clear).

    Steps:
    1. Honour ``state.skip_next``.
    2. Optionally consult the rate-limit monitor.
    3. Apply the UI-selected system instruction via set_system_msg.
    4. Stream tokens from a FastChat worker (found through the controller) or
       from a registered API provider.
    5. Save any images and append a "chat" record to the conversation log.

    Args:
        state: Session ``State`` prepared by add_text / regenerate.
        temperature, top_p, max_new_tokens: Sampling parameters from the UI.
        mode: System-instruction mode from the "Mode" radio.
        sys_gr: Custom system-instruction text (used in "自定义" mode).
        request: Gradio request, used for the client IP.
        apply_rate_limit: Whether to query is_limit_reached first.
        use_recommended_config: For API models, let the endpoint's
            ``recommended_config`` override temperature / top_p.
    """
    ip = get_ip(request)
    logger.info(f"bot_response. ip: {ip}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    # Button updates after a failure: only Regenerate and Clear stay usable.
    error_btns = (disable_btn,) * 3 + (enable_btn,) * 2

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if apply_rate_limit:
        ret = is_limit_reached(state.model_name, ip)
        if ret is not None and ret["is_limit_reached"]:
            error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
            logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
            state.conv.update_last_message(error_msg)
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return

    conv, model_name = state.conv, state.model_name
    # Apply the UI-selected system instruction before the prompt is built.
    conv = set_system_msg(conv, mode, sys_gr)
    model_api_dict = api_endpoint_info.get(model_name)
    images = conv.get_images()

    if model_api_dict is None:
        # Query worker address
        ret = requests.post(
            controller_url + "/get_worker_address", json={"model": model_name}
        )
        worker_addr = ret.json()["address"]
        logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

        # No available worker
        if worker_addr == "":
            conv.update_last_message(SERVER_ERROR_MSG)
            yield (state, state.to_gradio_chatbot()) + error_btns
            return

        # Construct prompt.
        # We need to call it here, so it will not be affected by "▌".
        prompt = conv.get_prompt()

        # Set repetition_penalty
        repetition_penalty = 1.2 if "t5" in model_name else 1.0

        stream_iter = model_worker_stream_iter(
            conv,
            model_name,
            worker_addr,
            prompt,
            temperature,
            repetition_penalty,
            top_p,
            max_new_tokens,
            images,
        )
    else:
        if use_recommended_config:
            recommended_config = model_api_dict.get("recommended_config", None)
            if recommended_config is not None:
                temperature = recommended_config.get("temperature", temperature)
                top_p = recommended_config.get("top_p", top_p)

        stream_iter = get_api_provider_stream_iter(
            conv,
            model_name,
            model_api_dict,
            temperature,
            top_p,
            max_new_tokens,
        )

    conv.update_last_message("▌")
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                output = data["text"].strip()
                conv.update_last_message(output + "▌")
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + error_btns
                return
        # NOTE: if the stream yielded nothing, `data` is unbound here; the
        # resulting error is reported by the generic handler below.
        output = data["text"].strip()
        conv.update_last_message(output)
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + error_btns
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # We load the image because gradio accepts base64 but that increases file size by ~1.33x
    loaded_images = [load_image(image) for image in images]
    images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images]
    for image, hash_str in zip(loaded_images, images_hash):
        filename = os.path.join(LOGDIR, "serve_images", f"{hash_str}.jpg")
        # Images are content-addressed, so an existing file is already saved.
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
            "images": images_hash,
        }
        fout.write(json.dumps(data) + "\n")
# Custom CSS passed to gr.Blocks(css=...) in build_demo. Selectors match the
# elem_id values used when building the UI; some (e.g. #leaderboard_*) target
# elements that are not created in this file. The footer rule hides Gradio's footer.
block_css = """
#notice_markdown .prose {
font-size: 120% !important;
}
#notice_markdown th {
display: none;
}
#notice_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#model_description_markdown {
font-size: 120% !important;
}
#leaderboard_markdown .prose {
font-size: 120% !important;
}
#leaderboard_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#leaderboard_dataframe td {
line-height: 0.1em;
}
#about_markdown .prose {
font-size: 120% !important;
}
#ack_markdown .prose {
font-size: 120% !important;
}
footer {
display:none !important;
}
.sponsor-image-about img {
margin: 0 20px;
margin-top: 20px;
height: 40px;
max-height: 100%;
width: auto;
float: left;
}
"""
def get_model_description_md(models):
    """Render a three-column markdown table describing each distinct model.

    Models that share a ``simple_name`` (per get_model_info) are listed once.
    """
    model_description_md = """
| | | |
| ---- | ---- | ---- |
"""
    seen = set()
    cells = 0
    for name in models:
        info = get_model_info(name)
        if info.simple_name in seen:
            continue
        seen.add(info.simple_name)
        column = cells % 3
        if column == 0:
            # Start a new table row.
            model_description_md += "|"
        model_description_md += f" [{info.simple_name}]({info.link}): {info.description} |"
        if column == 2:
            # Close the row after the third cell.
            model_description_md += "\n"
        cells += 1
    return model_description_md
def build_about():
    """Render the static "About Us" markdown section (not wired into this page's UI)."""
    # The contact address below was mangled to "[email protected]" by a
    # web-scraper's email obfuscation; restored to the real address.
    about_markdown = """
# About Us
Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
## Read More
- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
## Core Members
[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
## Advisors
[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
## Contact Us
- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com
- File issues on [GitHub](https://github.com/lm-sys/FastChat)
- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
## Acknowledgment
We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support.
We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/).
<div class="sponsor-image-about">
<img src="https://storage.googleapis.com/public-arena-asset/kaggle.png" alt="Kaggle">
<img src="https://storage.googleapis.com/public-arena-asset/mbzuai.jpeg" alt="MBZUAI">
<img src="https://storage.googleapis.com/public-arena-asset/a16z.jpeg" alt="a16z">
<img src="https://storage.googleapis.com/public-arena-asset/together.png" alt="Together AI">
<img src="https://storage.googleapis.com/public-arena-asset/anyscale.png" alt="AnyScale">
<img src="https://storage.googleapis.com/public-arena-asset/huggingface.png" alt="HuggingFace">
</div>
"""
    gr.Markdown(about_markdown, elem_id="about_markdown")
def build_single_model_ui(models, add_promotion_links=False):
    """Build the single-model chat UI inside the enclosing ``gr.Blocks`` context.

    Components are created (and therefore displayed) in this order:
    - page header;
    - model dropdown plus an accordion of model descriptions;
    - the "系统指令设置" (system instruction settings) accordion with the Mode
      radio and the custom-instruction textbox;
    - chatbot, input row, action buttons and sampling sliders.
    The event handlers are wired last.

    Args:
        models: Model names for the dropdown (the first one is preselected).
        add_promotion_links: If True, add FastChat promotion links to the
            header and render ``acknowledgment_md`` below the UI.

    Returns:
        ``[state, model_selector]``, used by build_demo as demo.load outputs.
    """
    promotion = (
        """
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
## 🤖 Choose any model to chat
"""
        if add_promotion_links
        else ""
    )
    notice_markdown = f"""
# 🏔️ 九天●医疗大模型测试页面
{promotion}
"""
    # Per-session State; stays None until add_text creates it.
    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")
    with gr.Group(elem_id="share-region-named"):
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False,
            )
        with gr.Row():
            with gr.Accordion(
                f"🔍 Expand to see the descriptions of {len(models)} models",
                open=False,
            ):
                model_description_md = get_model_description_md(models)
                gr.Markdown(model_description_md, elem_id="model_description_markdown")
        # System-instruction controls; their values are passed to bot_response,
        # which applies them via set_system_msg. (sys_set itself is unused.)
        with gr.Accordion("系统指令设置", open=False) as sys_set:
            with gr.Row():
                # "默认" = template default, "问诊" = built-in consultation
                # prompt, "自定义" = text typed in sys_gr.
                mode = gr.Radio(
                    ["默认", "问诊", "自定义"],
                    label="Mode",
                    value="默认",
                )
                sys_gr = gr.Textbox(
                    label="系统指令",
                    lines=2,
                )
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            label="Scroll down and start chatting",
            height=550,
            show_copy_button=True,
        )
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)
    with gr.Row() as button_row:
        upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )
    if add_promotion_links:
        gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
    # Register listeners
    # This page has no image upload: imagebox is a State holding None. It
    # fills the image slot that add_text takes as input and that add_text,
    # regenerate and clear_history return as output.
    imagebox = gr.State(None)
    # Order must match the 5 button updates returned by the handlers.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    upvote_btn.click(
        upvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state, model_selector],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    # Regenerate: blank the last reply, then stream a new one with the current
    # system-instruction mode/text.
    regenerate_btn.click(
        regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens, mode, sys_gr],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
    # Switching models discards the current conversation.
    model_selector.change(
        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
    )
    # Enter in the textbox and the Send button share one chain: append the
    # user turn (add_text), then stream the reply (bot_response).
    textbox.submit(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens, mode, sys_gr],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text,
        [state, model_selector, textbox, imagebox],
        [state, chatbot, textbox, imagebox] + btn_list,
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens, mode, sys_gr],
        [state, chatbot] + btn_list,
    )
    return [state, model_selector]
def build_demo(models):
    """Create the Gradio Blocks app for single-model chat.

    Reads the module-level ``args`` (parsed in ``__main__``) for
    ``model_list_mode`` and ``show_terms_of_use``.
    """
    with gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Soft(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        state, model_selector = build_single_model_ui(models)

        if args.model_list_mode not in ("once", "reload"):
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

        # With --show-terms-of-use, the page-load JS also shows the terms of use.
        load_js = (
            get_window_url_params_with_tos_js
            if args.show_terms_of_use
            else get_window_url_params_js
        )
        demo.load(load_demo, [url_params], [state, model_selector], js=load_js)

    return demo
if __name__ == "__main__":
    # Command-line options for the web server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--controller-url",
        type=str,
        default="http://localhost:21001",
        help="The address of the controller",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--model-list-mode",
        type=str,
        default="once",
        choices=["once", "reload"],
        help="Whether to load the model list once or reload the model list every time",
    )
    parser.add_argument(
        "--moderate",
        action="store_true",
        help="Enable content moderation to block unsafe inputs",
    )
    parser.add_argument(
        "--show-terms-of-use",
        action="store_true",
        help="Shows term of use before loading the demo",
    )
    parser.add_argument(
        "--register-api-endpoint-file",
        type=str,
        help="Register API-based model endpoints from a JSON file",
    )
    parser.add_argument(
        "--gradio-auth-path",
        type=str,
        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
    )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    # `args` is a module-level global read by build_demo() and load_demo().
    args = parser.parse_args()
    logger.info(f"args: {args}")
    # Set global variables
    set_global_vars(args.controller_url, args.moderate)
    # `models` is a module-level global; load_demo() rebinds it in "reload" mode.
    models, all_models = get_model_list(
        args.controller_url, args.register_api_endpoint_file, False
    )
    # Set authorization credentials
    auth = None
    if args.gradio_auth_path is not None:
        auth = parse_gradio_auth_creds(args.gradio_auth_path)
    # Launch the demo
    demo = build_demo(models)
    demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=auth,
        root_path=args.gradio_root_path,
    )