Skip to content

Commit

Permalink
feat: support multi-conversation memory #28
Browse files Browse the repository at this point in the history
  • Loading branch information
zhayujie committed Feb 5, 2023
1 parent b0e81ae commit 7425d90
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
- [x] **多账号:** 支持多微信账号同时运行
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下文会话


# 更新日志
>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
>**2022.12.19:** 引入 [itchat-uos](https://github.com/why2lyj/ItChat-UOS) 替换 itchat,解决由于不能登录网页微信而无法使用的问题,且解决Python3.9的兼容问题
Expand Down Expand Up @@ -85,7 +87,8 @@ cp config-template.json config.json
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"image_create_prefix": ["", "", ""] # 开启图片回复的前缀
"image_create_prefix": ["", "", ""], # 开启图片回复的前缀
"conversation_max_tokens": 3000 # 支持上下文记忆的最多字符数
}
```
**配置说明:**
Expand All @@ -105,6 +108,7 @@ cp config-template.json config.json

+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/images) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)


## 运行
Expand Down
87 changes: 81 additions & 6 deletions bot/openai/open_ai_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,34 @@
from common.log import logger
import openai

user_session = dict()

# OpenAI对话模型API (可用)
class OpenAIBot(Bot):
def __init__(self):
    # Configure the global OpenAI client with the API key read from the
    # project configuration (config.json, key "open_ai_api_key").
    openai.api_key = conf().get('open_ai_api_key')

def reply(self, query, context=None):
    """Produce a reply for an incoming message.

    :param query: the message text from the user
    :param context: optional dict with routing info; recognized keys are
        'type' ('TEXT' or 'IMAGE_CREATE') and 'from_user_id'
    :return: reply text for TEXT messages, or an image result for
        IMAGE_CREATE messages; None for unrecognized types
    """
    # auto append question mark so the model treats the input as a question
    query = self.append_question_mark(query)

    # acquire reply content
    if not context or not context.get('type') or context.get('type') == 'TEXT':
        logger.info("[OPEN_AI] query={}".format(query))
        # NOTE(review): context may be None on this branch, in which case the
        # subscript below raises TypeError — confirm callers always pass a
        # context for text messages.
        from_user_id = context['from_user_id']
        # magic command: wipe this user's conversation memory
        if query == '#清除记忆':
            Session.clear_session(from_user_id)
            return '记忆已清除'

        # prepend the user's conversation history to the prompt
        new_query = Session.build_session_query(query, from_user_id)
        logger.debug("[OPEN_AI] session query={}".format(new_query))

        reply_content = self.reply_text(new_query, query)
        # remember this Q/A pair for the next turn
        Session.save_session(query, reply_content, from_user_id)
        return reply_content

    elif context.get('type', None) == 'IMAGE_CREATE':
        return self.create_img(query)

def reply_text(self, query):
logger.info("[OPEN_AI] query={}".format(query))
def reply_text(self, query, origin_query):
try:
response = openai.Completion.create(
model="text-davinci-003", # 对话模型的名称
Expand All @@ -34,7 +44,7 @@ def reply_text(self, query):
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
stop=["#"]
)
res_content = response.choices[0]["text"].strip()
res_content = response.choices[0]["text"].strip().rstrip("<|im_end|>")
except Exception as e:
logger.exception(e)
return None
Expand Down Expand Up @@ -93,3 +103,68 @@ def append_question_mark(self, query):
if query.endswith(symbol):
return query
return query + "?"


class Session(object):
    """Per-user conversation memory.

    Sessions are stored in the module-level ``user_session`` dict, keyed by
    user id; each session is a list of ``{"question": ..., "answer": ...}``
    dicts, oldest first.
    """

    @staticmethod
    def build_session_query(query, user_id):
        """Build a prompt that includes the user's conversation history.

        e.g. Q: xxx
             A: xxx
             Q: xxx

        :param query: current question text
        :param user_id: id of the user whose session to use
        :return: prompt string with prior Q/A pairs prepended to the query
        """
        prompt = ""
        session = user_session.get(user_id)
        if session:
            for conversation in session:
                prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|im_end|>\n"
        prompt += "Q: " + query + "\nA: "
        return prompt

    @staticmethod
    def save_session(query, answer, user_id):
        """Append one Q/A pair to the user's session, then trim to the limit.

        :param query: the question asked
        :param answer: the model's reply
        :param user_id: id of the user whose session to update
        """
        # fall back to 3000 when the config key is absent or falsy
        max_tokens = conf().get("conversation_max_tokens") or 3000
        conversation = {"question": query, "answer": answer}
        # create the session list on first use, then append
        user_session.setdefault(user_id, []).append(conversation)

        # discard conversations exceeding the size limit
        Session.discard_exceed_conversation(user_session[user_id], max_tokens)

    @staticmethod
    def discard_exceed_conversation(session, max_tokens):
        """Drop oldest conversations until the session fits within the limit.

        NOTE: despite the name, the limit counts characters of question +
        answer text, not model tokens (matches the README's 字符数 wording).

        :param session: list of conversation dicts, oldest first (mutated)
        :param max_tokens: maximum total character count to retain
        """
        total = sum(len(c["question"]) + len(c["answer"]) for c in session)
        while total > max_tokens and session:
            # pop the earliest conversation first
            dropped = session.pop(0)
            total -= len(dropped["question"]) + len(dropped["answer"])

    @staticmethod
    def clear_session(user_id):
        """Reset the user's session to an empty history."""
        user_session[user_id] = []
3 changes: 2 additions & 1 deletion config-template.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"single_chat_reply_prefix": "[bot] ",
"group_chat_prefix": ["@bot"],
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
"image_create_prefix": ["", "", ""]
"image_create_prefix": ["", "", ""],
"conversation_max_tokens": 3000
}

0 comments on commit 7425d90

Please sign in to comment.