给rwkv_pytorch配了个流式服务和请求demo

项目地址

rwkv_pytorch

服务端

import json
import uuid
import time

import torch
from src.model import RWKV_RNN
from src.sampler import sample_logits
from src.rwkv_tokenizer import RWKV_TOKENIZER
from flask import Flask, request, jsonify, Response

app = Flask(__name__)


# 初始化模型和分词器
def init_model():
    # 模型参数配置
    args = {
        'MODEL_NAME': 'E:/RWKV_Pytorch/weight/RWKV-x060-World-1B6-v2-20240208-ctx4096',
        'vocab_size': 65536,
        'device': "cpu",
        'onnx_opset': '18',
    }
    device = args['device']
    assert device in ['cpu', 'cuda', 'musa', 'npu']

    if device == "musa":
        import torch_musa
    elif device == "npu":
        import torch_npu

    model = RWKV_RNN(args).to(device)
    tokenizer = RWKV_TOKENIZER("asset/rwkv_vocab_v20230424.txt")
    return model, tokenizer, device


def format_messages_to_prompt(messages):
    formatted_prompt = ""

    # 定义角色映射到期望的名称
    role_names = {
        "system": "System",
        "assistant": "Assistant",
        "user": "User"
    }

    # 遍历消息并格式化
    for message in messages:
        role = role_names.get(message['role'], 'Unknown')  # 获取角色名称,默认为'Unknown'
        content = message['content']
        formatted_prompt += f"{role}: {content}\n\n"  # 添加角色和内容到提示,并添加换行符

    formatted_prompt += "Assistant: "
    return formatted_prompt

def generate_text_stream(prompt: str, temperature=1.5, top_p=0.1, max_tokens=2048, stop=['\n\nUser']):
    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)

    del token

    out = token_out[:, -1]
    generated_tokens = ''
    completion_tokens = 0
    if_max_token = True
    for step in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)

        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        generated_tokens += last_token
        completion_tokens += 1

        if generated_tokens.endswith(tuple(stop)):
            if_max_token = False
            response = {
                "object": "chat.completion.chunk",
                "model": "rwkv",
                "choices": [{
                    "delta": "",
                    "index": 0,
                    "finish_reason": "stop"
                }]
            }

            yield f"data: {json.dumps(response)}\n\n"

        else:
            response = {
                "object": "chat.completion.chunk",
                "model": "rwkv",
                "choices": [{
                    "delta": {"content": last_token},
                    "index": 0,
                    "finish_reason": None
                }]
            }

            yield f"data: {json.dumps(response)}\n\n"

    if if_max_token:
        response = {
            "object": "chat.completion.chunk",
            "model": "rwkv",
            "choices": [{
                "delta": "",
                "index": 0,
                "finish_reason": "length"
            }]
        }

        yield f"data: {json.dumps(response)}\n\n"

    yield f"data:[DONE]\n\n"



def generate_text(prompt, temperature=1.5, top_p=0.1, max_tokens=2048, stop=['\n\nUser']):

    encoded_input = tokenizer.encode([prompt])
    token = torch.tensor(encoded_input).long().to(device)
    state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)
    prompt_tokens = len(encoded_input[0])

    with torch.no_grad():
        token_out, state_out = model.forward_parallel(token, state)

    del token

    out = token_out[:, -1]
    completion_tokens = 0
    if_max_token = True
    generated_tokens = ''
    for step in range(max_tokens):
        token_sampled = sample_logits(out, temperature, top_p)
        with torch.no_grad():
            out, state = model.forward(token_sampled, state)

        # 判断是否达到停止条件
        last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]
        completion_tokens += 1
        print(last_token, end='')

        generated_tokens += last_token

        for stop_token in stop:
            if generated_tokens.endswith(stop_token):
                generated_tokens = generated_tokens.replace(stop_token, "")  # 替换掉终止token
                if_max_token = False
                break
        # 如果末尾含有 stop 列表中的字符串,则停止生成
        if not if_max_token:
            break

    total_tokens = prompt_tokens + completion_tokens
    usage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
    return generated_tokens, if_max_token, usage


@app.route('/events', methods=['POST'])
def sse_request():
    try:
        # 从查询字符串中获取参数
        data = request.json
        messages = data.get('messages', [])
        stream = data.get('stream', True) == True
        temperature = float(data.get('temperature', 0.5))
        top_p = float(data.get('top_p', 0.9))
        max_tokens = int(data.get('max_tokens', 100))
        stop = data.get('stop', ['\n\nUser'])

        prompt = format_messages_to_prompt(messages)

        if stream:
            return Response(generate_text_stream(prompt=prompt, temperature=temperature, top_p=top_p,
                                          max_tokens=max_tokens, stop=stop),
                            content_type='text/event-stream')
        else:
            completion, if_max_token, usage = generate_text(prompt, temperature=temperature, top_p=top_p,
                                                            max_tokens=max_tokens, stop=stop)
            finish_reason = "stop" if if_max_token else "length"
            unique_id = str(uuid.uuid4())
            current_timestamp = int(time.time())
            response = {
                "id": unique_id,
                "object": "chat.completion",
                "created": current_timestamp,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion,
                    },
                    "finish_reason": finish_reason
                }],
                "usage": usage
            }



            return json.dumps(response)
    except Exception as e:
        return json.dumps({"error": str(e)}), 500


if __name__ == '__main__':
    model, tokenizer, device = init_model()
    app.run(debug=False)


解释

  • 首先引入了需要的库,包括json用于处理JSON数据,uuid用于生成唯一标识符,time用于获取当前时间戳,torch用于构建和运行模型,Flask用于构建API。
  • 定义了一个名为app的Flask应用。
  • init_model函数用于初始化模型和分词器。其中,模型参数通过字典args指定。
  • format_messages_to_prompt函数用于将消息格式化为提示字符串,以便于模型生成回复。遍历消息列表,获取每个消息的角色和内容,并添加到提示字符串中。
  • generate_text_stream函数用于以流的形式生成文本。首先将输入的提示字符串编码为张量,然后利用模型生成回复,并利用yield关键字将回复以SSE(服务器发送事件)的形式返回。
  • generate_text函数用于一次性生成完整的文本回复。与generate_text_stream函数类似,不同的是返回的是完整的回复字符串。
  • sse_request函数是Flask应用的主要逻辑,用于处理POST请求。从请求的JSON数据中获取参数,并根据参数的设置调用相应的生成函数。如果参数中设置了stream=True,则返回流式生成的回复;否则返回一次性生成的回复。
  • __main__函数中初始化模型和分词器,然后运行Flask应用。

客户端

import json

import requests
from requests import RequestException

# 配置服务器URL
url = 'http://localhost:5000/events'  # 假设您的Flask应用运行在本地端口5000上


# POST请求示例
def post_request_stream():
    # 构造请求数据
    data = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告诉我今天的天气吗?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream':True
    }

    # 使用 requests 库来连接服务器,并传递参数
    try:
        with  requests.post(url, json=data, stream=True) as r:
            for line in r.iter_lines():
                if line:
                    # 当服务器发送消息时,解码并打印出来
                    decoded_line = line.decode('utf-8')
                    print(json.loads(decoded_line[5:])["choices"][0]["delta"], end="")
    except RequestException as e:
        print(f'An error occurred: {e}')

def post_request():
    # 构造请求数据
    data = {
        'messages': [
            {'role': 'system', 'content': '你好!'},
            {'role': 'user', 'content': '你能告诉我今天的天气吗?'}
        ],
        'temperature': 0.5,
        'top_p': 0.9,
        'max_tokens': 100,
        'stop': ['\n\nUser'],
        'stream':False
    }

    # 使用 requests 库来连接服务器,并传递参数
    try:
        with  requests.post(url, json=data, stream=True) as r:
            for line in r.iter_lines():
                if line:
                    # 当服务器发送消息时,解码并打印出来
                    decoded_line = line.decode('utf-8')
                    res=json.loads(decoded_line)
                    print(res)
    except RequestException as e:
        print(f'An error occurred: {e}')


if __name__ == '__main__':
    # post_request()
    post_request_stream()

解释

这段代码是一个用于向服务器发送POST请求的示例代码。

首先,我们需要导入一些必要的库。json库用于处理JSON数据,requests库用于发送HTTP请求,RequestException用于处理请求异常。

接下来,我们需要配置服务器的URL。在这个示例中,假设服务器运行在本地端口5000上。

代码中定义了两个函数post_request_streampost_request,分别用于发送带有流式响应和非流式响应的POST请求。

post_request_stream函数构造了一个包含各种参数的数据字典,并使用requests.post方法发送POST请求。在请求的参数中,stream参数被设置为True,表示我们希望获得一个流式的响应。接着,我们使用r.iter_lines()方法来迭代获取服务器发送的消息。每收到一行消息,我们将其解码并打印出来。

post_request函数的代码结构与post_request_stream函数相似,不同之处在于stream参数被设置为False,表示我们希望获得一个非流式的响应。

最后,在程序的主体部分,我们调用post_request_stream函数来发送流式的POST请求,并注释掉了post_request函数的调用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/559249.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

美化博客文章(持续更新)

🎁个人主页:我们的五年 🔍系列专栏:游戏实现:贪吃蛇​​​​​​ 🌷追光的人,终会万丈光芒 前言: 该文提供我的一些文章设计的一些方法 目录 1.应用超链接 1.应用超链接

差速机器人模型LQR 控制仿真——路径模拟

LQR路径跟踪要求路径中带角度,即坐标(x,y,yaw),而一般我们的规划出来的路径不带角度。这里通过总结相关方法,并提供一个案例。 将点路径拟合成一条完整的线路径算法 将点路径拟合成一条完整的线路径是一个常见的问题…

【Java开发指南 | 第十五篇】Java Character 类、String 类

读者可订阅专栏:Java开发指南 |【CSDN秋说】 文章目录 Java Character 类转义序列 Java String 类连接字符串 Java Character 类 Character 类是 Java 中用来表示字符的包装类,它提供了一系列静态方法用于对字符进行操作,其主要分为静态方法…

06 JavaScript学习:语句

JavaScript 语句是用来执行特定任务或操作的一组指令。它可以包括变量声明、条件语句、循环语句、函数调用等。JavaScript 语句以分号结尾,每个语句都会被解释器执行。 分号 ; 在JavaScript中,分号(;)用于表示语句的结束。尽管在…

python爬虫-----深入了解 requests 库(第二十五天)

🎈🎈作者主页: 喔的嘛呀🎈🎈 🎈🎈所属专栏:python爬虫学习🎈🎈 ✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天…

【汇编语言】初识汇编

【汇编语言】初识汇编 文章目录 【汇编语言】初识汇编前言由机器语言到汇编语言机器语言与机器指令汇编语言与汇编指令汇编语言程序示例 计算机组成指令和数据的表示计算机的存储单元计算机的总线 内存读写与地址空间CPU对存储器的读写内存地址空间 总结 前言 为什么要学习汇编…

Numpy重修系列(一) --- 初识Numpy

一、为什么使用Numpy? 1.1、简介 Python科学计算基础包,提供 多维数组对象 、派生对象(掩码数组、矩阵) 数组的快速操作(数学计算、逻辑、形状变化、排序、选择、输入输出、离散傅里叶变换、基本线性代数、基本统计运…

数据分析案例-中国黄金股票市场的EDA与价格预测

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

【数据结构】单链表经典算法题的巧妙解题思路

目录 题目 1.移除链表元素 2.反转链表 3.链表的中间节点 4.合并两个有序链表 5.环形链表的约瑟夫问题 解析 题目1:创建新链表 题目2:巧用三个指针 题目3:快慢指针 题目4:哨兵位节点 题目5:环形链表 介绍完了…

Activity——spring方式创建activiti所需数据表结构

文章目录 前言依赖引入编写数据库连接等配置配置日志文件编写java代码生成数据库表结构问题反馈与解决思路问题一:Cause: java.sql.SQLSyntaxErrorException: Table activiti_02.act_ge_property doesnt exist 为什么文件名必须写死? 前言 在之前创建ac…

循序渐进丨使用 Python 向 MogDB 数据库批量操作数据的方法

当我们有时候需要向数据库里批量插入数据,或者批量导出数据时,除了使用传统的gsql copy命令,也可以通过Python的驱动psycopg2进行批量操作。本文介绍了使用psycopg2里的executemany、copy_from、copy_to、copy_expert等方式来批量操作 MogDB …

js-pytorch:开启前端+AI新世界

嗨, 大家好, 我是 徐小夕。最近在 github 上发现一款非常有意思的框架—— js-pytorch。它可以让前端轻松使用 javascript 来运行深度学习框架。作为一名资深前端技术玩家, 今天就和大家分享一下这款框架。 往期精彩 Nocode/Doc,可…

python爬虫之爬取携程景点评价(5)

一、景点部分评价爬取 【携程攻略】携程旅游攻略,自助游,自驾游,出游,自由行攻略指南 (ctrip.com) import requests from bs4 import BeautifulSoupif __name__ __main__:url https://m.ctrip.com/webapp/you/commentWeb/commentList?seo0&businessId22176&busines…

“中医显示器”是人体健康监测器

随着科技的进步,现代医学设备已经深入到了人们的日常生活中。然而,在这个过程中,我们不应忘记我们的医学根源,中医。我们将中医的望、闻、问、切四诊与现代科技相结合,通过一系列的传感器和算法将人体的生理状态以数字…

3、MYSQL-一条sql如何在MYSQL中执行的

MySQL的内部组件结构 大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。 Server层 主要包括连接器、查询缓存、分析器、优化器、执行器等,涵盖 MySQL 的大多数核心服务功能,以及所有的内置函数(如日期、时间、数学和加密函…

[Algorithm][滑动窗口][无重复字符的最长字串][最大连续的一个数 Ⅲ][将x减到0的最小操作数]详细讲解

目录 1.无重复字符的最长字串1.题目链接2.算法原理详解3.代码实现 2.最大连续的一个数 Ⅲ1.题目链接2.算法原理详解3.代码实现 3.将x减到0的最小操作数1.题目链接2.算法原理详解3.代码实现 1.无重复字符的最长字串 1.题目链接 无重复字符的最长字串 2.算法原理详解 研究的对…

算法打卡day39

今日任务: 1)卡码网57. 爬楼梯(70. 爬楼梯进阶版) 2)322.零钱兑换 3)279.完全平方数 4)复习day14 卡码网57. 爬楼梯(70. 爬楼梯进阶版) 题目链接:57. 爬楼梯…

数据结构从入门到实战——顺序表的应用

目录 一、基于动态顺序表实现通讯录 二、代码实现 2.1 通讯录的初始化 2.2 通讯录的销毁 2.3 通讯录的展示 2.4 通讯录添加联系人信息 2.5 通讯录删除联系人信息 2.6 通讯录修改联系人信息 2.7 通讯录的查找联系人信息 2.8 将通讯录中联系人信息保存到文件中 2.9…

乡政府管理系统|基于Springboot的乡政府管理系统设计与实现(源码+数据库+文档)

乡政府管理系统目录 目录 基于Springboot的乡政府管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、活动信息管理 3、新闻类型管理 4、新闻动态管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推…

考研党们,搭子们,打打鸡血!刷视频免疫了,时间竟然多了起来!——早读(逆天打工人爬取热门微信文章解读)

断舍离,断的是过去 引言Python 代码第一篇 人民日报 一个班级,29人全部“上岸”! 第二篇 人民日报 来了!新闻早班车要闻社会政策 结尾 时间就像河流 它带来一切 也带走一切 不打游戏不刷视频 时间的河流便能带来更丰富的体验 引言…
最新文章