diff --git a/client/tool/ArcGisPro.cs b/client/tool/ArcGisPro.cs index da2e604..163f253 100644 --- a/client/tool/ArcGisPro.cs +++ b/client/tool/ArcGisPro.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; using System.Threading.Tasks; using ArcGIS.Core.Data; using ArcGIS.Core.Data.Raster; @@ -9,6 +10,7 @@ using ArcGIS.Desktop.Framework.Threading.Tasks; using LinkToolAddin.server; using ModelContextProtocol.Server; using Newtonsoft.Json; +using Newtonsoft.Json.Linq; namespace LinkToolAddin.client.tool; diff --git a/host/Gateway.cs b/host/Gateway.cs index fca37cc..28935f0 100644 --- a/host/Gateway.cs +++ b/host/Gateway.cs @@ -8,6 +8,7 @@ using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; +using System.Windows.Documents; using System.Xml; using System.Xml.Linq; using ArcGIS.Desktop.Framework.Dialogs; @@ -37,6 +38,13 @@ namespace LinkToolAddin.host; public class Gateway { private static ILog log = LogManager.GetLogger(typeof(Gateway)); + private static bool goOn = true; + + public static void StopConversation() + { + goOn = false; + } + public static async void SendMessage(string message, string model, string gdbPath, Action callback) { Llm bailian = new Bailian @@ -240,12 +248,19 @@ public class Gateway Role = "user", Content = message }); - bool goOn = true; + goOn = true; string toolPattern = "^[\\s\\S]*?<\\/tool_use>$"; string promptPattern = "^[\\s\\S]*?<\\/prompt>$"; McpServerList mcpServerList = new McpServerList(); + int loop = 0; while (goOn) { + loop++; + if (loop > 20) + { + MessageBox.Show("达到最大循环次数", "退出循环"); + break; + } LlmJsonContent jsonContent = new LlmJsonContent() { Model = model, @@ -266,6 +281,11 @@ public class Gateway if (Regex.IsMatch(chunk, toolPattern)) { //返回工具卡片 + messages.Add(new Message + { + Role = "assistant", + Content = chunk + }); XElement toolUse = XElement.Parse(chunk); string fullToolName = toolUse.Element("name")?.Value; string toolArgs = toolUse.Element("arguments")?.Value; @@ -290,13 +310,14 @@ public class Gateway messages.Add(new Message { Role = "user", - Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate - }); - messages.Add(new Message - { - Role = "user", - Content = JsonConvert.SerializeObject(toolResponse) + // Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) }); + // messages.Add(new Message + // { + // Role = "user", + // Content = JsonConvert.SerializeObject(toolResponse) + // }); callback?.Invoke(toolMessageItem); }else if (mcpServer is StdioMcpServer) { @@ -315,19 +336,35 @@ public class Gateway messages.Add(new Message { Role = "user", - Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate - }); - messages.Add(new Message - { - Role = "user", - Content = JsonConvert.SerializeObject(toolResponse) + // Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(toolResponse)) }); + // messages.Add(new Message + // { + // Role = "user", + // Content = JsonConvert.SerializeObject(toolResponse) + // }); callback?.Invoke(toolMessageItem); }else if (mcpServer is InnerMcpServer) { Type type = Type.GetType("LinkToolAddin.client.tool."+serverName); MethodInfo method = type.GetMethod(toolName,BindingFlags.Public | BindingFlags.Static); - var task = method.Invoke(null, toolParams.Values.ToArray()) as Task; + var methodParams = toolParams.Values.ToArray(); + object[] args = new object[methodParams.Length]; + for (int i = 0; i < methodParams.Length; i++) + { + if (methodParams[i].GetType() == typeof(JArray)) + { + List list = new List(); + list = (methodParams[i] as JArray).Select(token => token.ToString()).ToList(); + args[i] = list; + } + else + { + args[i] = methodParams[i]; + } + } + var task = method.Invoke(null, args) as Task; JsonRpcResultEntity innerResult = await task; if (innerResult is JsonRpcErrorEntity) { @@ -365,13 +402,14 @@ public class Gateway messages.Add(new Message { Role = "user", - Content = SystemPrompt.ContinuePromptTemplate - }); - messages.Add(new Message - { - Role = "user", - Content = JsonConvert.SerializeObject(innerResult) + // Content = SystemPrompt.ContinuePromptTemplate + Content = SystemPrompt.ContinuePrompt(JsonConvert.SerializeObject(innerResult)) }); + // messages.Add(new Message + // { + // Role = "user", + // Content = JsonConvert.SerializeObject(innerResult) + // }); callback?.Invoke(toolMessageItem); } } @@ -451,16 +489,9 @@ public class Gateway private static async Task GetToolInfos(McpServerList mcpServerList) { - int loop = 0; StringBuilder toolInfos = new StringBuilder(); foreach (McpServer mcpServer in mcpServerList.GetAllServers()) { - loop++; - if (loop > 3) - { - MessageBox.Show("达到最大循环次数", "退出循环"); - break; - } if (mcpServer is InnerMcpServer) { InnerMcpServer innerMcpServer = (InnerMcpServer)mcpServer; diff --git a/server/CallArcGISPro.cs b/server/CallArcGISPro.cs index 049fd3f..4d5c611 100644 --- a/server/CallArcGISPro.cs +++ b/server/CallArcGISPro.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using ArcGIS.Desktop.Core.Geoprocessing; using ArcGIS.Desktop.Framework.Dialogs; using ArcGIS.Desktop.Framework.Threading.Tasks; +using Newtonsoft.Json; namespace LinkToolAddin.server; @@ -22,7 +23,7 @@ public class CallArcGISPro Error = new Error() { Code = results.ErrorCode, - Message = results.ErrorMessages.ToString() + Message = JsonConvert.SerializeObject(results.ErrorMessages) } }; } @@ -30,7 +31,7 @@ public class CallArcGISPro { jsonRpcResultEntity = new JsonRpcSuccessEntity { - Result = results.Messages.ToString() + Result = JsonConvert.SerializeObject(results.Messages) }; } return jsonRpcResultEntity;