How to convert ChatRWKV to a Web API and talk to it from a Web chat UI

riversun

Posted on March 30, 2023

Summary

(This article has been translated from Japanese)

A chat UI displayed in a web browser lets you talk to ChatRWKV.

ChatRWKV is an open-source program that aims to work like OpenAI's ChatGPT and runs quite comfortably on a local PC.

I turned it into a chat engine and called it from the chat UI library chatux; this article walks through how.

Now you can chat with ChatRWKV like this:
(demo animation: demo2.gif)

The full source code is here
https://github.com/riversun/chatux-server-rwkv

Target

  • The code in this article assumes a GPU (CUDA) environment.
    • As described later, if you use smaller trained data, it also works well on older GPUs.
      • For example, using rwkv-4-pile-3b (https://huggingface.co/BlinkDL/rwkv-4-pile-3b) as the trained data with the fp16i8 strategy worked on a GeForce RTX 2060 (6GB memory).

Environment

We tested in the following environment:

  • Python: 3.9
  • GPU CUDA Version: 12
  • GPU memory: 24GB
  • OS: Ubuntu Desktop 22.04 / Windows 11

Main part

STEP0: Approach to turning ChatRWKV into a server

  • Since chatux will be used as the front-end chat UI, we want to run ChatRWKV as a server.

  • To keep things as simple as possible, we modify chat.py under the v2 directory of the ChatRWKV repository (below) and turn it into a server.

https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py

STEP1: Install related packages

Install the related packages (*):

pip install rwkv fastapi uvicorn

(*) This assumes you already have PyTorch with CUDA working. If you get runtime errors, install any missing packages as needed.
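If you want to double-check that PyTorch can see your GPU before going further, a quick sanity check like the following (plain PyTorch calls, nothing specific to this project) should print True and your GPU name:

# Sanity check: confirm PyTorch was built with CUDA and can see a GPU
import torch

print(torch.__version__)              # installed PyTorch version
print(torch.version.cuda)             # CUDA version PyTorch was built against
print(torch.cuda.is_available())      # should print True
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 2060"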

STEP2: Edit chat.py to make it a chat server

Let's modify the chat.py above into a server.

The points in making it a chat server are as follows.

1. Convert to a Web API

  • Use a package called fastapi to turn it into a Web API server
  • Use a package called uvicorn to start the Web API server built with fastapi
  • Accept input on the path /chat_api with the @app.get("/chat_api") decorator
  • Next, pass the input text from the user to ChatRWKV's rnn model with reply = handle_message(text).replace('\n', '<br>'), as shown below
  • Pack the results from the rnn model into outJson and return it. That's all :)

@app.get("/chat_api")
async def chat(text: str = ""):
    reply = handle_message(text).replace('\n', '<br>')
    print(f'input:{text} reply:{reply}')

    outJson = {
        "output": [
            {
                "type": "text",
                "value": reply
            }
        ]
    }
    return outJson


app.mount("/", StaticFiles(directory="html", html=True), name="html")

2. Handling of the rnn model's generation results

  • The command-line version of chat.py prints the generated response to the console bit by bit, but the current chatux architecture does not support sequential (streaming) output, so for now the server waits until the model has finished generating the sentence and then returns the whole response at once
    • The send_msg = pipeline.decode(model_tokens[begin:]).strip() part is returned as the response
    • Returning the reply from the rnn model bit by bit would make the results easier to read, so an architecture that can realize this is homework on the chatux side

3. + command processing

  • The +gen, +i, +qa, etc. commands can be passed through, but the response from the model is returned unprocessed, so the output may look strange if you use them. Also, free generation takes a while, so you may have to wait.

Below is the full source code of the chat server

########################################################################################################
# This program is based on The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
#
# The original file of this source code can be found here.
# https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
#
# Based on the source code above, I added the functionality required for a chatbot Web API server:
# - web server functionality
# - use the query received in a GET request as input to the rnn, and return the result generated by the rnn as the HTTP response
#
########################################################################################################

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

import uvicorn

import os, copy, types, gc, sys

current_path = os.path.dirname(os.path.abspath(__file__))

import numpy as np

args = types.SimpleNamespace()

# specify chat server
HOST = 'localhost'
PORT = 8001
URL = f'http://{HOST}:{PORT}'

# specify RWKV strategy,model(weight data)
STRATEGY = 'cuda fp16i8'
MODEL_NAME = 'RWKV-4-Pile-14B-20230313-ctx8192-test1050.pth'

# specify params for weight data
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 4096

CHAT_LANG = "English"
PROMPT_FILE = f'{current_path}/init_prompt/English-1.py'

try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
np.set_printoptions(precision=4, suppress=True, linewidth=200)

print('\n\nChatRWKV v2 https://github.com/BlinkDL/ChatRWKV')

import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

args.strategy = STRATEGY

os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0', please use torch 1.13+ and benchmark speed
os.environ["RWKV_CUDA_ON"] = '0'  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
args.MODEL_NAME = f'{current_path}/data/{MODEL_NAME}'

CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
FREE_GEN_LEN = 200

GEN_TEMP = 1.0  # sometimes it's a good idea to increase temp. try it
GEN_TOP_P = 0.8
GEN_alpha_presence = 0.2  # Presence Penalty
GEN_alpha_frequency = 0.2  # Frequency Penalty
AVOID_REPEAT = ',:?!'

CHUNK_LEN = 256  # split input into chunks to save VRAM (shorter -> slower)

PILE_v2_MODEL = False  # ONLY FOR MY OWN TESTING. STILL TRAINING PILE_v2_MODELs

all_state = {}

print(f'\n{CHAT_LANG} - {args.strategy} - {PROMPT_FILE}')
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

with open(PROMPT_FILE, 'rb') as file:
    user = None
    bot = None
    interface = None
    init_prompt = None
    exec(compile(file.read(), PROMPT_FILE, 'exec'))
init_prompt = init_prompt.strip().split('\n')
for c in range(len(init_prompt)):
    init_prompt[c] = init_prompt[c].strip().strip('\u3000').strip('\r')
init_prompt = '\n' + ('\n'.join(init_prompt)).strip() + '\n\n'

print(f'Loading model - {args.MODEL_NAME}')

model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)

if not PILE_v2_MODEL:
    pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")
    END_OF_TEXT = 0
    END_OF_LINE = 187
else:
    pipeline = PIPELINE(model, "cl100k_base")
    END_OF_TEXT = 100257
    END_OF_LINE = 198

model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
    dd = pipeline.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd


def run_rnn(tokens, newline_adj=0):
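    """Feed tokens to the RWKV model in CHUNK_LEN chunks (to limit VRAM use),
    update the global model state and token history, and return the logits for
    the next token. newline_adj biases the newline token to control reply length."""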
    global model_tokens, model_state

    tokens = [int(x) for x in tokens]

    model_tokens += tokens

    while len(tokens) > 0:
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]

    out[END_OF_LINE] += newline_adj
    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
        out[model_tokens[-1]] = -999999999
    return out


def save_all_stat(srv, name, last_out):
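    """Snapshot the current logits, rnn state and token history under a named
    slot (e.g. 'chat', 'chat_pre', 'gen_0') so the conversation context can be
    restored later with load_all_stat()."""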
    n = f'{name}_{srv}'
    all_state[n] = {}
    all_state[n]['out'] = last_out
    all_state[n]['rnn'] = copy.deepcopy(model_state)
    all_state[n]['token'] = copy.deepcopy(model_tokens)


def load_all_stat(srv, name):
    global model_tokens, model_state
    n = f'{name}_{srv}'
    model_state = copy.deepcopy(all_state[n]['rnn'])
    model_tokens = copy.deepcopy(all_state[n]['token'])
    return all_state[n]['out']


out = run_rnn(pipeline.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']

for s in srv_list:
    save_all_stat(s, 'chat', out)


def reply_msg(msg):
    print(f"bot's response: {bot}{interface} {msg}\n")


def handle_message(message):
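    """Core chat logic carried over from the original chat.py: parse optional
    -temp=/-top_p= overrides and the +reset/+gen/+i/+qq/+qa/++/+++/+ commands,
    run the rnn to generate a reply, and return the reply text."""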
    global model_tokens, model_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()

    x_temp = GEN_TEMP
    x_top_p = GEN_TOP_P
    if ("-temp=" in msg):
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f'{x_temp:g}', "")
    if ("-top_p=" in msg):
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
        # print(f"top_p: {x_top_p}")
    if x_temp <= 0.2:
        x_temp = 0.2
    if x_temp >= 5:
        x_temp = 5
    if x_top_p <= 0:
        x_top_p = 0

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return "Chat Reset"

    elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' \
            or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

        if msg[:5].lower() == '+gen ':
            new = '\n' + msg[5:].strip()
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:3].lower() == '+i ':
            new = f'''
Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{msg[3:].strip()}

# Response:
'''
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qq ':
            new = '\nQ: ' + msg[4:].strip() + '\nA:'
            model_state = None
            model_tokens = []
            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qa ':
            out = load_all_stat('', 'chat_init')

            real_msg = msg[4:].strip()
            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"

            out = run_rnn(pipeline.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg.lower() == '+++':
            try:
                out = load_all_stat(srv, 'gen_1')
                save_all_stat(srv, 'gen_0', out)
            except:
                return

        elif msg.lower() == '++':
            try:
                out = load_all_stat(srv, 'gen_0')
            except:
                return

        begin = len(model_tokens)
        out_last = begin
        occurrence = {}
        for i in range(FREE_GEN_LEN + 100):
            for n in occurrence:
                out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
            token = pipeline.sample_logits(
                out,
                temperature=x_temp,
                top_p=x_top_p,
            )
            if token == END_OF_TEXT:
                break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            if msg[:4].lower() == '+qa ':  # or msg[:4].lower() == '+qq ':
                out = run_rnn([token], newline_adj=-2)
            else:
                out = run_rnn([token])

            xxx = pipeline.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                print(xxx, end='', flush=True)
                out_last = begin + i + 1
                if i >= FREE_GEN_LEN:
                    break
        print('\n')
        send_msg = pipeline.decode(model_tokens[begin:]).strip()
        # print(f'### send ###\n[{send_msg}]')
        save_all_stat(srv, 'gen_1', out)
        return send_msg

    else:
        if msg.lower() == '+':
            try:
                out = load_all_stat(srv, 'chat_pre')
            except:
                return
        else:
            out = load_all_stat(srv, 'chat')
            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
            # print(f'### add ###\n[{new}]')
            out = run_rnn(pipeline.encode(new), newline_adj=-999999999)
            save_all_stat(srv, 'chat_pre', out)

        begin = len(model_tokens)
        out_last = begin

        occurrence = {}
        for i in range(999):
            if i <= 0:
                newline_adj = -999999999
            elif i <= CHAT_LEN_SHORT:
                newline_adj = (i - CHAT_LEN_SHORT) / 10
            elif i <= CHAT_LEN_LONG:
                newline_adj = 0
            else:
                newline_adj = min(2, (i - CHAT_LEN_LONG) * 0.25)  # MUST END THE GENERATION

            for n in occurrence:
                out[n] -= (GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency)
            token = pipeline.sample_logits(
                out,
                temperature=x_temp,
                top_p=x_top_p,
            )
            # if token == END_OF_TEXT:
            #     break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            out = run_rnn([token], newline_adj=newline_adj)
            out[END_OF_TEXT] = -999999999  # disable <|endoftext|>

            xxx = pipeline.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                out_last = begin + i + 1

            send_msg = pipeline.decode(model_tokens[begin:])
            if '\n\n' in send_msg:
                send_msg = send_msg.strip()
                # print(f'send_msg={send_msg}')
                break

        save_all_stat(srv, 'chat', out)
        return send_msg


app = FastAPI()


@app.get("/chat_api")
async def chat(text: str = ""):
    reply = handle_message(text).replace('\n', '<br>')
    print(f'input:{text} reply:{reply}')

    outJson = {
        "output": [
            {
                "type": "text",
                "value": reply
            }
        ]
    }
    return outJson


# Serve the chat UI (the html/ folder containing index.html) at the root path
app.mount("/", StaticFiles(directory="html", html=True), name="html")


def start_server():
    uvicorn.run(app, host=HOST, port=PORT)


def main():
    start_server()


if __name__ == "__main__":
    main()


STEP3: Obtain trained model data (weights)

The code is completed, but we do not yet have the trained weights, so we download them from the ChatRWKV author's page on Hugging Face.

There are several types of model data: 3b, 7b, 14b, etc. The higher the number, the more parameters you have, the more expressive your model is.

However, the higher the number of parameters, the more GPU memory is consumed, so download the data according to the execution environment.

Below is a table of typical model data types and GPU memory consumption

strategy   rwkv-4-pile-14b   rwkv-4-pile-7b   rwkv-4-pile-3b
fp16       28GB              16GB             6GB
fp16i8     14GB              8.6GB            3GB

By the way, if you use fp16i8 (which appears to mean quantizing the fp16 weights to int8), you can reduce GPU memory usage, although accuracy may be slightly lower.

(When specifying the strategy in the code, use cuda fp16 or cuda fp16i8.)

Download the weight data (*.pth) files from the pages below.

Get BlinkDL/rwkv-4-pile-14b

https://huggingface.co/BlinkDL/rwkv-4-pile-14b

  • GPU memory consumption

    • About 28GB (fp16) ... out of memory even on a 24GB-class GPU
    • About 14GB (fp16i8) ... accuracy is said to degrade a little, but it works on a T4
  • Parameters

args.n_layer = 40
args.n_embd = 5120
args.ctx_len = 8192

Get BlinkDL/rwkv-4-pile-7b

https://huggingface.co/BlinkDL/rwkv-4-pile-7b

  • Parameters

args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 4096

Get BlinkDL/rwkv-4-pile-3b

https://huggingface.co/BlinkDL/rwkv-4-pile-3b

  • Parameters

args.n_layer = 32
args.n_embd = 2560
args.ctx_len = 4096

Place the downloaded file in the [project]/data folder.

If this doesn't ring a bell, refer to the folder structure at the bottom of the following page:
https://github.com/riversun/chatux-server-rwkv
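For reference, piecing together the paths used in this article, the project layout should look roughly like this (the file under data/ depends on which model you downloaded):

[project]/
├── chatux-server-rwkv.py                     # chat server from STEP 2
├── 20B_tokenizer.json                        # tokenizer (STEP 5-3)
├── data/
│   └── RWKV-4-Pile-7B-20230109-ctx4096.pth   # downloaded weights (STEP 3)
├── html/
│   └── index.html                            # chat UI (STEP 5-1)
└── init_prompt/
    └── English-1.py                          # initial prompt (STEP 5-2)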

STEP4: Set the trained model data (weights) in the code

Open the chatux-server-rwkv.py file you just created.

  • Around the # specify RWKV strategy,model(weight data) comment, set STRATEGY and MODEL_NAME as shown below. MODEL_NAME is just the file name.
  • Around # specify params for weight data, enter the parameters for the trained model data (you can copy and paste the values shown above).
# specify RWKV strategy,model(weight data)
STRATEGY = 'cuda fp16i8'
MODEL_NAME = 'RWKV-4-Pile-7B-20230109-ctx4096.pth'

# specify params for weight data
args.n_layer = 32
args.n_embd = 4096
args.ctx_len = 4096

STEP5: Place other files

5-1 Create the chat UI (HTML)

Prepare index.html as follows and save it as [project]/html/index.html.

chatux is loaded from chatux.min.js.

With this, the chat UI appears in a small window at the right edge of the screen on a PC, or full screen on a smartphone, and you can chat.

index.html


<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1">
    <title>chatRWKV </title>
</head>
<body>
<script src="https://riversun.github.io/chatux/chatux.min.js"></script>
<script>
    const chatux = new ChatUx();

    // initializing param for chatux
    const initParam =
        {
            renderMode: 'auto',
            api: {
                //echo chat server
                endpoint:'/chat_api',
                method: 'GET',
                dataType: 'json'
            },
            bot: {
                botPhoto: 'https://riversun.github.io/chatbot/bot_icon_operator.png',
                humanPhoto: null,
                widget: {
                    sendLabel: 'SEND',
                    placeHolder: 'Say something'
                }
            },
            window: {
                title: 'chatRWKV',
                infoUrl: 'https://github.com/riversun/chatux'
            }
        };
    chatux.init(initParam);
    chatux.start(true);


</script>
</body>
</html>

The following API settings are specified to match the chat server we just created:

api: {
    // echo chat server
    endpoint: '/chat_api',
    method: 'GET',
    dataType: 'json'
},

5-2 Place the initial prompt

Create a folder [project]/init_prompt and place the initial prompt (the first conversation context prompt) there.

The initial prompt is specified in the source code at PROMPT_FILE = f'{current_path}/init_prompt/English-1.py'.

Here, we use https://github.com/BlinkDL/ChatRWKV/blob/main/v2/prompt/default/English-1.py.

Incidentally, if you want the bot to chat in Japanese, you can use an initial prompt written in Japanese, such as the one below.

https://github.com/riversun/chatux-server-rwkv/blob/main/init_prompt/Japanese-2.py

5-3 Place 20B_tokenizer.json

Place https://github.com/BlinkDL/ChatRWKV/blob/main/v2/20B_tokenizer.json as [project]/20B_tokenizer.json.

STEP6: Start the chat server

You can start the chat server with the following command:

python chatux-server-rwkv.py

STEP7: Execution

Once the server is started, open http://localhost:8001 in your browser.
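Alternatively, if you want to check just the API before opening the UI, a minimal test client like the one below (standard library only; it assumes the server is running on localhost:8001 as configured above, and the question text is just an example) should print the model's reply:

# Minimal test client for the /chat_api endpoint (standard library only)
import json
import urllib.parse
import urllib.request

text = "What is the tallest mountain in the world?"   # example input
url = "http://localhost:8001/chat_api?" + urllib.parse.urlencode({"text": text})

with urllib.request.urlopen(url) as res:
    body = json.loads(res.read().decode("utf-8"))

# The server returns {"output": [{"type": "text", "value": "<reply with <br> tags>"}]}
print(body["output"][0]["value"])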

Now you can talk to ChatRWKV from the web chat UI (chatux)!

Demo video:

https://www.youtube.com/embed/t3vuNmIYXBo

ChatRWKV does a good job.

As you can see in the video, I asked it the following questions and it responded nicely.

"The tallest mountain in the world"
"Who was the lead actor in Titanic?"
"Other films made by the director of Titanic"

Summary

  • I explained how to talk to ChatRWKV from a web chat UI

    • FastAPI is used to turn it into a Web API
    • The messages generated sequentially by the rnn model are collected into one response before being returned
    • To do
      • chatux should support sequential (streaming) callbacks like ChatGPT and ChatRWKV
      • Better handling of free generation
      • Factor the engine out into a class
      • Persist the in-memory conversation context (the state passed to the rnn)
  • The tested sample code is here: https://github.com/riversun/chatux-server-rwkv

Thank you for reading to the end!
