diff --git a/host/Gateway.cs b/host/Gateway.cs index d20b0a8..3984613 100644 --- a/host/Gateway.cs +++ b/host/Gateway.cs @@ -229,6 +229,27 @@ public class Gateway } } + private static (string Matched, string Remaining) ExtractMatchedPart(string input, string toolPattern) + { + if (string.IsNullOrEmpty(input) || string.IsNullOrEmpty(toolPattern)) + return (string.Empty, input); + + Regex regex = new Regex(toolPattern); + Match match = regex.Match(input); + + if (!match.Success) + return (string.Empty, input); + + string matched = match.Value; + int startIndex = match.Index; + int length = match.Length; + + // 构造剩余字符串 + string remaining = input.Substring(0, startIndex) + input.Substring(startIndex + length); + + return (matched, remaining); + } + public static async void SendMessageStream(string message, string model, string gdbPath, Action callback) { Llm bailian = new Bailian @@ -249,8 +270,8 @@ public class Gateway Content = message }); goOn = true; - string toolPattern = "^[\\s\\S]*?<\\/tool_use>$"; - string promptPattern = "^[\\s\\S]*?<\\/prompt>$"; + string toolPattern = "([\\s\\S]*?)([\\s\\S]*?)<\\/name>([\\s\\S]*?)([\\s\\S]*?)<\\/arguments>([\\s\\S]*?)<\\/tool_use>"; + string promptPattern = "([\\s\\S]*?)([\\s\\S]*?)<\\/name>([\\s\\S]*?)<\\/prompt>"; McpServerList mcpServerList = new McpServerList(); int loop = 0; while (goOn) @@ -270,169 +291,49 @@ public class Gateway MaxTokens = 1000, }; long timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); - string messageContent = ""; + string messageContent = ""; //一次请求下完整的response + var (toolMatched, toolRemaining) = ExtractMatchedPart(messageContent, toolPattern); + var (promptMatched, promptRemaining) = ExtractMatchedPart(messageContent, promptPattern); + if (toolMatched == "" && promptMatched == "" && messageContent != "") + { + //如果本次回复不包含任何工具的调用或提示词的调用,则不再请求 + break; + } await foreach(var chunk in bailian.SendChatStreamAsync(jsonContent)) { - if (chunk == "[DONE]") + if (!goOn) { - goOn = false; - }else if (chunk.StartsWith("")) + break; + } + var (matched, remaining) = ExtractMatchedPart(chunk, toolPattern); + if (matched == "") { - if (Regex.IsMatch(chunk, toolPattern)) + var (matchedPrompt, remainingPrompt) = ExtractMatchedPart(chunk, promptPattern); + if (matchedPrompt == "") { - //返回工具卡片 - messages.Add(new Message + //普通消息文本 + MessageListItem chatMessageListItem = new ChatMessageItem() { - Role = "assistant", - Content = chunk - }); - XElement toolUse = XElement.Parse(chunk); - string fullToolName = toolUse.Element("name")?.Value; - string toolArgs = toolUse.Element("arguments")?.Value; - Dictionary toolParams = JsonConvert.DeserializeObject>(toolArgs); - string serverName = fullToolName.Contains(":") ? fullToolName.Split(':')[0] : fullToolName; - string toolName = fullToolName.Contains(":") ? fullToolName.Split(':')[1] : fullToolName; - McpServer mcpServer = mcpServerList.GetServer(serverName); - if (mcpServer is SseMcpServer) - { - SseMcpServer sseMcpServer = mcpServer as SseMcpServer; - SseMcpClient client = new SseMcpClient(sseMcpServer.BaseUrl); - CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); - MessageListItem toolMessageItem = new ToolMessageItem - { - toolName = toolName, - toolParams = toolParams, - type = MessageType.TOOL_MESSAGE, - status = toolResponse.IsError ? "fail" : "success", - content = JsonConvert.SerializeObject(toolResponse), - id = timestamp.ToString() - }; - messages.Add(new Message - { - Role = "user", - // Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate - Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) - }); - // messages.Add(new Message - // { - // Role = "user", - // Content = JsonConvert.SerializeObject(toolResponse) - // }); - callback?.Invoke(toolMessageItem); - }else if (mcpServer is StdioMcpServer) - { - StdioMcpServer stdioMcpServer = mcpServer as StdioMcpServer; - StdioMcpClient client = new StdioMcpClient(stdioMcpServer.Command, stdioMcpServer.Args); - CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); - MessageListItem toolMessageItem = new ToolMessageItem - { - toolName = toolName, - toolParams = toolParams, - type = MessageType.TOOL_MESSAGE, - status = toolResponse.IsError ? "fail" : "success", - content = JsonConvert.SerializeObject(toolResponse), - id = timestamp.ToString() - }; - messages.Add(new Message - { - Role = "user", - // Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate - Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) - }); - // messages.Add(new Message - // { - // Role = "user", - // Content = JsonConvert.SerializeObject(toolResponse) - // }); - callback?.Invoke(toolMessageItem); - }else if (mcpServer is InnerMcpServer) - { - Type type = Type.GetType("LinkToolAddin.client.tool."+serverName); - MethodInfo method = type.GetMethod(toolName,BindingFlags.Public | BindingFlags.Static); - var methodParams = toolParams.Values.ToArray(); - object[] args = new object[methodParams.Length]; - for (int i = 0; i < methodParams.Length; i++) - { - if (methodParams[i].GetType() == typeof(JArray)) - { - List list = new List(); - list = (methodParams[i] as JArray).Select(token => token.ToString()).ToList(); - args[i] = list; - } - else - { - args[i] = methodParams[i]; - } - } - var task = method.Invoke(null, args) as Task; - JsonRpcResultEntity innerResult = await task; - if (innerResult is JsonRpcErrorEntity) - { - MessageListItem toolMessageItem = new ToolMessageItem - { - toolName = toolName, - toolParams = toolParams, - type = MessageType.TOOL_MESSAGE, - status = "fail", - content = JsonConvert.SerializeObject(innerResult), - id = timestamp.ToString() - }; - messages.Add(new Message - { - Role = "user", - Content = SystemPrompt.ErrorPromptTemplate - }); - messages.Add(new Message - { - Role = "user", - Content = JsonConvert.SerializeObject(innerResult) - }); - callback?.Invoke(toolMessageItem); - }else if (innerResult is JsonRpcSuccessEntity) - { - MessageListItem toolMessageItem = new ToolMessageItem - { - toolName = toolName, - toolParams = toolParams, - type = MessageType.TOOL_MESSAGE, - status = "success", - content = JsonConvert.SerializeObject(innerResult), - id = timestamp.ToString() - }; - messages.Add(new Message - { - Role = "user", - // Content = SystemPrompt.ContinuePromptTemplate - Content = SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(innerResult)) - }); - // messages.Add(new Message - // { - // Role = "user", - // Content = JsonConvert.SerializeObject(innerResult) - // }); - callback?.Invoke(toolMessageItem); - } - } + content = remainingPrompt, + role = "assistant", + type = MessageType.CHAT_MESSAGE, + id = timestamp.ToString() + }; + messageContent = remainingPrompt; + callback?.Invoke(chatMessageListItem); } else { - MessageListItem toolMessageItem = new ToolMessageItem + //包含Prompt调用请求的消息 + MessageListItem chatMessageListItem = new ChatMessageItem() { - toolName = "", - toolParams = new Dictionary(), - type = MessageType.TOOL_MESSAGE, - status = "loading", - content = "正在生成工具调用参数", + content = remainingPrompt, + role = "assistant", + type = MessageType.CHAT_MESSAGE, id = timestamp.ToString() }; - callback?.Invoke(toolMessageItem); - continue; - } - }else if (chunk.StartsWith("")) - { - if (Regex.IsMatch(chunk, promptPattern)) - { - XElement promptUse = XElement.Parse(chunk); + callback?.Invoke(chatMessageListItem); + XElement promptUse = XElement.Parse(matchedPrompt); string promptKey = promptUse.Element("name")?.Value; string promptContent = DynamicPrompt.GetPrompt(promptKey,null); messages.Add(new Message @@ -446,49 +347,134 @@ public class Gateway toolParams = null, type = MessageType.TOOL_MESSAGE, status = "success", - content = promptKey, - id = timestamp.ToString() - }; - callback?.Invoke(toolMessageItem); - } - else - { - MessageListItem toolMessageItem = new ToolMessageItem - { - toolName = "调用提示词", - toolParams = null, - type = MessageType.TOOL_MESSAGE, - status = "loading", - content = "正在调用提示词", - id = timestamp.ToString() + content = "成功调用提示词:"+promptKey, + id = (timestamp+1).ToString() }; callback?.Invoke(toolMessageItem); } } else { - string content = chunk; - if (content.EndsWith("[DONE]")) - { - content = content.Substring(0, content.Length - 6); - } - //普通流式消息卡片 + //包含工具调用请求的消息 MessageListItem chatMessageListItem = new ChatMessageItem() { - content = content, + content = remaining, role = "assistant", type = MessageType.CHAT_MESSAGE, id = timestamp.ToString() }; - messageContent = content; callback?.Invoke(chatMessageListItem); + XElement toolUse = XElement.Parse(matched); + string fullToolName = toolUse.Element("name")?.Value; + string toolArgs = toolUse.Element("arguments")?.Value; + Dictionary toolParams = JsonConvert.DeserializeObject>(toolArgs); + string serverName = fullToolName.Contains(":") ? fullToolName.Split(':')[0] : fullToolName; + string toolName = fullToolName.Contains(":") ? fullToolName.Split(':')[1] : fullToolName; + McpServer mcpServer = mcpServerList.GetServer(serverName); + if (mcpServer is SseMcpServer) + { + SseMcpServer sseMcpServer = mcpServer as SseMcpServer; + SseMcpClient client = new SseMcpClient(sseMcpServer.BaseUrl); + CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = toolResponse.IsError ? "fail" : "success", + content = JsonConvert.SerializeObject(toolResponse), + id = (timestamp+1).ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) + }); + callback?.Invoke(toolMessageItem); + }else if (mcpServer is StdioMcpServer) + { + StdioMcpServer stdioMcpServer = mcpServer as StdioMcpServer; + StdioMcpClient client = new StdioMcpClient(stdioMcpServer.Command, stdioMcpServer.Args); + CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = toolResponse.IsError ? "fail" : "success", + content = JsonConvert.SerializeObject(toolResponse), + id = (timestamp+1).ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) + }); + callback?.Invoke(toolMessageItem); + }else if (mcpServer is InnerMcpServer) + { + Type type = Type.GetType("LinkToolAddin.client.tool."+serverName); + MethodInfo method = type.GetMethod(toolName,BindingFlags.Public | BindingFlags.Static); + var methodParams = toolParams.Values.ToArray(); + object[] args = new object[methodParams.Length]; + for (int i = 0; i < methodParams.Length; i++) + { + if (methodParams[i].GetType() == typeof(JArray)) + { + List list = new List(); + list = (methodParams[i] as JArray).Select(token => token.ToString()).ToList(); + args[i] = list; + } + else + { + args[i] = methodParams[i]; + } + } + var task = method.Invoke(null, args) as Task; + JsonRpcResultEntity innerResult = await task; + if (innerResult is JsonRpcErrorEntity) + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = "fail", + content = JsonConvert.SerializeObject(innerResult), + id = (timestamp+1).ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = SystemPrompt.ErrorPromptTemplate + }); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(innerResult) + }); + callback?.Invoke(toolMessageItem); + }else if (innerResult is JsonRpcSuccessEntity) + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = "success", + content = JsonConvert.SerializeObject(innerResult), + id = (timestamp+1).ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(innerResult)) + }); + callback?.Invoke(toolMessageItem); + } + } } } - messages.Add(new Message - { - Role = "assistant", - Content = messageContent - }); } }