LinkToolAddin/host/Gateway.cs

406 lines
17 KiB
C#

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using System.Xml;
using System.Xml.Linq;
using ArcGIS.Desktop.Framework.Dialogs;
using LinkToolAddin.client;
using LinkToolAddin.client.prompt;
using LinkToolAddin.host.llm;
using LinkToolAddin.host.llm.entity;
using LinkToolAddin.host.mcp;
using LinkToolAddin.host.prompt;
using LinkToolAddin.message;
using LinkToolAddin.server;
using LinkToolAddin.ui.dockpane;
using log4net;
using Microsoft.Extensions.AI;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Types;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Newtonsoft.Json.Schema;
using Newtonsoft.Json.Schema.Generation;
using Tool = LinkToolAddin.host.mcp.Tool;
namespace LinkToolAddin.host;
/// <summary>
/// Bridges the chat UI, the Bailian LLM and the configured MCP servers / in-process tools:
/// it builds the tool catalogue for the system prompt, runs the LLM conversation loop and
/// executes the tool calls and prompt expansions the model asks for.
/// </summary>
public class Gateway
{
    private static readonly ILog log = LogManager.GetLogger(typeof(Gateway));

    /// <summary>Environment variable the Bailian (DashScope) API key is read from.</summary>
    private const string ApiKeyEnvironmentVariable = "DASHSCOPE_API_KEY";

    /// <summary>Reply the model sends to signal that the conversation is finished.</summary>
    private const string DoneMarker = "[DONE]";

    /// <summary>
    /// Upper bound on LLM round-trips per user message, so a model that never replies
    /// <see cref="DoneMarker"/> cannot loop (and issue API calls) forever.
    /// </summary>
    private const int MaxChatRounds = 20;

    /// <summary>A reply consisting solely of a &lt;tool_use&gt; element.</summary>
    private static readonly Regex ToolUsePattern =
        new Regex("^<tool_use>[\\s\\S]*?<\\/tool_use>$", RegexOptions.Compiled);

    /// <summary>A reply consisting solely of a &lt;prompt&gt; element.</summary>
    private static readonly Regex PromptPattern =
        new Regex("^<prompt>[\\s\\S]*?<\\/prompt>$", RegexOptions.Compiled);

    // Keep date-like strings as strings: the default DateParseHandling.DateTime would turn
    // e.g. "2024-01-01" into a DateTime before it reaches the tool.
    private static readonly JsonSerializerSettings ArgumentParseSettings = new JsonSerializerSettings
    {
        DateParseHandling = DateParseHandling.None
    };

    /// <summary>
    /// Runs one user turn against the LLM. The model receives a system prompt listing every
    /// available tool plus <paramref name="message"/>; each reply is then handled in a loop:
    /// a &lt;tool_use&gt; reply is executed and its result fed back, a &lt;prompt&gt; reply is
    /// expanded and fed back, and any other reply is forwarded to <paramref name="callback"/>.
    /// The loop ends when the model replies "[DONE]" or after <see cref="MaxChatRounds"/> rounds.
    /// </summary>
    /// <param name="message">The user's chat message.</param>
    /// <param name="model">Model identifier sent with every LLM request.</param>
    /// <param name="gdbPath">Geodatabase path embedded in the system prompt.</param>
    /// <param name="callback">Receives every chat and tool message to display; may be null.</param>
    /// <remarks>
    /// Kept <c>async void</c> so the public signature (shared with the Test* methods) is unchanged.
    /// Because no caller can observe a task, every exception is caught here, logged and reported
    /// through <paramref name="callback"/> instead of escaping onto the caller's synchronization context.
    /// The API key comes from the DASHSCOPE_API_KEY environment variable; the key that used to be
    /// hard-coded in this method was committed to source control and should be rotated.
    /// </remarks>
    public static async void SendMessage(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        try
        {
            await RunChatAsync(message, model, gdbPath, callback);
        }
        catch (Exception ex)
        {
            log.Error("SendMessage failed.", ex);
            callback?.Invoke(CreateAssistantChatItem("请求失败: " + ex.Message));
        }
    }

    /// <summary>Body of <see cref="SendMessage"/>; exceptions propagate to the caller.</summary>
    private static async Task RunChatAsync(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        string apiKey = Environment.GetEnvironmentVariable(ApiKeyEnvironmentVariable);
        if (string.IsNullOrWhiteSpace(apiKey))
        {
            log.Error($"Environment variable {ApiKeyEnvironmentVariable} is not set.");
            callback?.Invoke(CreateAssistantChatItem("未配置 API Key，请设置环境变量 " + ApiKeyEnvironmentVariable));
            return;
        }
        Llm bailian = new Bailian
        {
            api_key = apiKey
        };

        // A single server list serves both tool discovery and tool execution.
        McpServerList mcpServerList = new McpServerList();
        string toolInfos = await GetToolInfos(mcpServerList);
        string systemPrompt = SystemPrompt.SysPrompt(gdbPath, toolInfos);
        log.Info(systemPrompt);

        List<Message> messages = new List<Message>
        {
            new Message { Role = "system", Content = systemPrompt },
            new Message { Role = "user", Content = message }
        };

        for (int round = 0; round < MaxChatRounds; round++)
        {
            string response = await bailian.SendChatAsync(new LlmJsonContent
            {
                Model = model,
                Messages = messages,
                Temperature = 0.7,
                TopP = 1,
                MaxTokens = 1000,
            }) ?? string.Empty;
            log.Info(response);
            messages.Add(new Message { Role = "assistant", Content = response });

            // Trim so surrounding whitespace/newlines do not stop the anchored patterns from matching.
            string trimmed = response.Trim();
            if (ToolUsePattern.IsMatch(trimmed))
            {
                await HandleToolUseAsync(trimmed, mcpServerList, messages, callback);
            }
            else if (PromptPattern.IsMatch(trimmed))
            {
                HandlePrompt(trimmed, messages);
            }
            else
            {
                callback?.Invoke(CreateAssistantChatItem(response));
            }

            if (trimmed == DoneMarker)
            {
                return;
            }
        }

        log.Warn($"Chat stopped after {MaxChatRounds} rounds without \"{DoneMarker}\".");
        MessageBox.Show("达到最大循环次数", "退出循环");
    }

    /// <summary>
    /// Executes the tool named in a &lt;tool_use&gt; reply and feeds the outcome back into
    /// <paramref name="messages"/>. Any failure (malformed XML or JSON, unknown server or tool,
    /// exception thrown by the tool) is reported to the model as a failed tool call so it can
    /// react, rather than aborting the conversation.
    /// </summary>
    private static async Task HandleToolUseAsync(string toolUseXml, McpServerList mcpServerList,
        List<Message> messages, Action<MessageListItem> callback)
    {
        string toolName = string.Empty;
        Dictionary<string, object> toolParams = new Dictionary<string, object>();
        bool failed;
        string display;
        string payload;
        try
        {
            XElement toolUse = XElement.Parse(toolUseXml);
            (string serverName, string name) = SplitQualifiedName(toolUse.Element("name")?.Value);
            toolName = name;
            toolParams = ParseArguments(toolUse.Element("arguments")?.Value);
            McpServer mcpServer = mcpServerList.GetServer(serverName);
            (failed, display, payload) = await ExecuteToolAsync(mcpServer, serverName, toolName, toolParams);
        }
        catch (Exception ex)
        {
            // Reflection wraps the tool's own exception; report the real cause.
            Exception cause = ex is TargetInvocationException { InnerException: { } inner } ? inner : ex;
            log.Error("Tool call failed: " + toolUseXml, cause);
            failed = true;
            display = cause.Message;
            payload = JsonConvert.SerializeObject(new { error = cause.Message });
        }
        RecordToolResult(toolName, toolParams, failed, display, payload, messages, callback);
    }

    /// <summary>
    /// Dispatches a tool call to the server kind it belongs to.
    /// Returns whether the call failed, the text to display in the UI and the JSON fed back to the model.
    /// </summary>
    /// <exception cref="InvalidOperationException">No server matched <paramref name="serverName"/>.</exception>
    private static async Task<(bool Failed, string Display, string Payload)> ExecuteToolAsync(
        McpServer mcpServer, string serverName, string toolName, Dictionary<string, object> toolParams)
    {
        switch (mcpServer)
        {
            case SseMcpServer sseServer:
            {
                // NOTE(review): a new client is created per call and never disposed; if the client
                // holds a connection (or, for stdio, a child process) this leaks one per call — confirm.
                SseMcpClient client = new SseMcpClient(sseServer.BaseUrl);
                return DescribeMcpResponse(await client.CallToolAsync(toolName, toolParams));
            }
            case StdioMcpServer stdioServer:
            {
                StdioMcpClient client = new StdioMcpClient(stdioServer.Command, stdioServer.Args);
                return DescribeMcpResponse(await client.CallToolAsync(toolName, toolParams));
            }
            case InnerMcpServer _:
            {
                JsonRpcResultEntity innerResult = await InvokeInnerToolAsync(serverName, toolName, toolParams);
                string json = JsonConvert.SerializeObject(innerResult);
                bool failed = innerResult is null || innerResult is JsonRpcErrorEntity;
                return (failed, json, json);
            }
            default:
                throw new InvalidOperationException($"MCP server '{serverName}' was not found.");
        }
    }

    /// <summary>
    /// Maps an MCP tool response to (failed, display, payload). The UI shows the serialized
    /// content list; the model receives the whole serialized response.
    /// </summary>
    private static (bool Failed, string Display, string Payload) DescribeMcpResponse(CallToolResponse toolResponse)
    {
        return (toolResponse.IsError,
            JsonConvert.SerializeObject(toolResponse.Content),
            JsonConvert.SerializeObject(toolResponse));
    }

    /// <summary>
    /// Invokes the public static method <paramref name="toolName"/> of the class
    /// <c>LinkToolAddin.client.tool.{serverName}</c>, binding JSON arguments by parameter name.
    /// </summary>
    /// <exception cref="InvalidOperationException">The class or method is missing, or the method does not return <c>Task&lt;JsonRpcResultEntity&gt;</c>.</exception>
    private static async Task<JsonRpcResultEntity> InvokeInnerToolAsync(string serverName, string toolName,
        Dictionary<string, object> toolParams)
    {
        Type type = Type.GetType("LinkToolAddin.client.tool." + serverName)
            ?? throw new InvalidOperationException($"Inner tool class '{serverName}' was not found.");
        MethodInfo method = type.GetMethod(toolName, BindingFlags.Public | BindingFlags.Static)
            ?? throw new InvalidOperationException($"Tool '{toolName}' was not found on '{serverName}'.");
        if (method.Invoke(null, BindArguments(method, toolParams)) is not Task<JsonRpcResultEntity> task)
        {
            throw new InvalidOperationException($"Tool '{toolName}' does not return Task<JsonRpcResultEntity>.");
        }
        return await task;
    }

    /// <summary>
    /// Builds the positional argument array for <paramref name="method"/> by looking each
    /// parameter up by name in <paramref name="toolParams"/>. This matches the property names
    /// advertised by <see cref="GenerateMethodParamSchema"/>, independent of JSON key order.
    /// Omitted parameters fall back to their declared default value.
    /// </summary>
    /// <exception cref="ArgumentException">A parameter without a default value was not supplied.</exception>
    private static object[] BindArguments(MethodInfo method, Dictionary<string, object> toolParams)
    {
        ParameterInfo[] parameters = method.GetParameters();
        object[] args = new object[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
        {
            ParameterInfo parameter = parameters[i];
            if (toolParams.TryGetValue(parameter.Name, out object value))
            {
                args[i] = ConvertArgument(value, parameter.ParameterType);
            }
            else if (parameter.HasDefaultValue)
            {
                args[i] = parameter.DefaultValue;
            }
            else
            {
                throw new ArgumentException($"Missing required argument '{parameter.Name}' for tool '{method.Name}'.");
            }
        }
        return args;
    }

    /// <summary>
    /// Converts a value deserialized from JSON to <paramref name="targetType"/>. Newtonsoft yields
    /// long/double for numbers and JObject/JArray for nested values, so e.g. an <c>int</c>
    /// parameter needs conversion rather than a direct pass-through.
    /// </summary>
    private static object ConvertArgument(object value, Type targetType)
    {
        if (value == null)
        {
            return targetType.IsValueType ? Activator.CreateInstance(targetType) : null;
        }
        if (targetType.IsInstanceOfType(value))
        {
            return value;
        }
        JToken token = value as JToken ?? JToken.FromObject(value);
        return token.ToObject(targetType);
    }

    /// <summary>
    /// Appends the continue/error prompt plus the tool payload to the conversation, then
    /// notifies the UI with a tool message.
    /// </summary>
    private static void RecordToolResult(string toolName, Dictionary<string, object> toolParams, bool failed,
        string display, string payload, List<Message> messages, Action<MessageListItem> callback)
    {
        MessageListItem toolMessageItem = new ToolMessageItem
        {
            toolName = toolName,
            toolParams = toolParams,
            type = MessageType.TOOL_MESSAGE,
            status = failed ? "fail" : "success",
            content = display
        };
        messages.Add(new Message
        {
            Role = "user",
            Content = failed ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate
        });
        messages.Add(new Message
        {
            Role = "user",
            Content = payload
        });
        callback?.Invoke(toolMessageItem);
    }

    /// <summary>
    /// Expands a &lt;prompt&gt; reply through <see cref="DynamicPrompt.GetPrompt"/> and appends the
    /// result to the conversation as a user message.
    /// </summary>
    private static void HandlePrompt(string promptXml, List<Message> messages)
    {
        XElement prompt = XElement.Parse(promptXml);
        string promptName = SplitQualifiedName(prompt.Element("name")?.Value).Name;
        Dictionary<string, object> promptParams = ParseArguments(prompt.Element("arguments")?.Value);
        messages.Add(new Message
        {
            Role = "user",
            Content = DynamicPrompt.GetPrompt(promptName, promptParams)
        });
    }

    /// <summary>
    /// Parses a JSON object of arguments. A missing/blank value or a JSON <c>null</c> yields an
    /// empty dictionary instead of throwing or returning null.
    /// </summary>
    private static Dictionary<string, object> ParseArguments(string json)
    {
        if (string.IsNullOrWhiteSpace(json))
        {
            return new Dictionary<string, object>();
        }
        return JsonConvert.DeserializeObject<Dictionary<string, object>>(json, ArgumentParseSettings)
               ?? new Dictionary<string, object>();
    }

    /// <summary>
    /// Splits a "server:tool" name at its first colon. A name without a colon is used as both
    /// the server and the tool name.
    /// </summary>
    /// <exception cref="FormatException">The name is missing or empty.</exception>
    private static (string Server, string Name) SplitQualifiedName(string fullName)
    {
        if (string.IsNullOrEmpty(fullName))
        {
            throw new FormatException("The <name> element is missing or empty.");
        }
        int separator = fullName.IndexOf(':');
        return separator < 0
            ? (fullName, fullName)
            : (fullName.Substring(0, separator), fullName.Substring(separator + 1));
    }

    /// <summary>
    /// Builds the tool catalogue embedded in the system prompt: one XML-rendered
    /// <see cref="McpToolDefinition"/> per tool, named "server:tool", each followed by a blank line.
    /// </summary>
    /// <remarks>
    /// A server whose tools cannot be listed is logged and skipped, so the remaining servers
    /// are still offered to the model.
    /// </remarks>
    private static async Task<string> GetToolInfos(McpServerList mcpServerList)
    {
        StringBuilder toolInfos = new StringBuilder();
        foreach (McpServer mcpServer in mcpServerList.GetAllServers())
        {
            // Render each server into its own buffer so a failure midway leaves no partial output.
            StringBuilder serverTools = new StringBuilder();
            try
            {
                switch (mcpServer)
                {
                    case InnerMcpServer innerServer:
                    {
                        AppendInnerToolDefinitions(serverTools, innerServer);
                        break;
                    }
                    case SseMcpServer sseServer:
                    {
                        SseMcpClient client = new SseMcpClient(sseServer.BaseUrl);
                        AppendRemoteToolDefinitions(serverTools, sseServer.Name, await client.GetToolListAsync());
                        break;
                    }
                    case StdioMcpServer stdioServer:
                    {
                        StdioMcpClient client = new StdioMcpClient(stdioServer.Command, stdioServer.Args);
                        AppendRemoteToolDefinitions(serverTools, stdioServer.Name, await client.GetToolListAsync());
                        break;
                    }
                }
                toolInfos.Append(serverTools.ToString());
            }
            catch (Exception ex)
            {
                log.Error("Failed to list the tools of an MCP server; they will not be offered to the model.", ex);
            }
        }
        return toolInfos.ToString();
    }

    /// <summary>
    /// Appends one definition per public static method of <c>LinkToolAddin.client.tool.{server name}</c>.
    /// Each definition uses the method's <see cref="DescriptionAttribute"/> and a schema generated
    /// from its parameters.
    /// </summary>
    private static void AppendInnerToolDefinitions(StringBuilder builder, InnerMcpServer innerServer)
    {
        Type type = Type.GetType("LinkToolAddin.client.tool." + innerServer.Name)
            ?? throw new InvalidOperationException($"Inner tool class '{innerServer.Name}' was not found.");
        foreach (MethodInfo method in type.GetMethods(BindingFlags.Public | BindingFlags.Static))
        {
            // Skip compiler-generated accessors and operators (get_X, op_Addition, ...): they are not tools.
            if (method.IsSpecialName)
            {
                continue;
            }
            AppendToolDefinition(builder,
                innerServer.Name + ":" + method.Name,
                method.GetCustomAttribute<DescriptionAttribute>()?.Description,
                GenerateMethodParamSchema(method));
        }
    }

    /// <summary>Appends one definition per tool reported by a remote (SSE or stdio) MCP server.</summary>
    private static void AppendRemoteToolDefinitions(StringBuilder builder, string serverName, IList<McpClientTool> tools)
    {
        foreach (McpClientTool tool in tools)
        {
            // Compact the schema to keep the system prompt short (applied to both SSE and stdio servers).
            AppendToolDefinition(builder,
                serverName + ":" + tool.Name,
                tool.Description,
                CompressJson(tool.JsonSchema.ToString()));
        }
    }

    /// <summary>Renders a single tool definition as XML, followed by a blank line.</summary>
    private static void AppendToolDefinition(StringBuilder builder, string name, string description, string arguments)
    {
        McpToolDefinition toolDefinition = new McpToolDefinition
        {
            Tool = new Tool
            {
                Name = name,
                Description = description,
                Arguments = arguments
            }
        };
        XNode node = JsonConvert.DeserializeXNode(JsonConvert.SerializeObject(toolDefinition));
        builder.AppendLine(node.ToString());
        builder.AppendLine();
    }

    /// <summary>
    /// Re-serializes JSON without indentation or line breaks.
    /// </summary>
    /// <param name="json">JSON text; null or whitespace is returned unchanged.</param>
    /// <returns>The compact form of <paramref name="json"/>.</returns>
    public static string CompressJson(string json)
    {
        if (string.IsNullOrWhiteSpace(json))
        {
            return json;
        }
        // Parsing drops insignificant whitespace; Formatting.None writes the compact form.
        return JToken.Parse(json).ToString(Newtonsoft.Json.Formatting.None);
    }

    /// <summary>
    /// Generates a compact JSON Schema of type "object" describing the parameters of
    /// <paramref name="method"/>. It contains one property per parameter, carrying the parameter's
    /// <see cref="DescriptionAttribute"/> text, and lists parameters without a default value as required.
    /// </summary>
    private static string GenerateMethodParamSchema(MethodInfo method)
    {
        var generator = new JSchemaGenerator
        {
            // Treat generated types as non-nullable by default.
            DefaultRequired = Required.DisallowNull,
            // Inline nested type schemas instead of emitting "$ref" entries.
            SchemaReferenceHandling = SchemaReferenceHandling.None
        };
        var paramSchema = new JSchema { Type = JSchemaType.Object };
        foreach (ParameterInfo param in method.GetParameters())
        {
            JSchema typeSchema = generator.Generate(param.ParameterType);
            string description = param.GetCustomAttribute<DescriptionAttribute>()?.Description;
            if (description != null)
            {
                typeSchema.Description = description;
            }
            paramSchema.Properties.Add(param.Name, typeSchema);
            // Parameters without a default must be supplied: BindArguments has nothing to fall back on.
            if (!param.HasDefaultValue)
            {
                paramSchema.Required.Add(param.Name);
            }
        }
        var settings = new JsonSerializerSettings
        {
            Formatting = Newtonsoft.Json.Formatting.None, // no indentation or line breaks
            NullValueHandling = NullValueHandling.Ignore
        };
        return JsonConvert.SerializeObject(paramSchema, settings);
    }

    /// <summary>Creates an assistant chat message for the UI.</summary>
    private static MessageListItem CreateAssistantChatItem(string content)
    {
        return new ChatMessageItem
        {
            content = content,
            role = "assistant",
            type = MessageType.CHAT_MESSAGE
        };
    }

    /// <summary>Creates the canned buffer-tool message used by the UI test helpers.</summary>
    private static MessageListItem CreateSampleToolItem(string message)
    {
        return new ToolMessageItem
        {
            content = message,
            type = MessageType.TOOL_MESSAGE,
            toolName = "arcgis_pro.executeTool",
            toolParams = new Dictionary<string, object>
            {
                {"gp_name","analysis.Buffer"},
                {"gp_params","[\"C:\\test.gdb\\river\",\"30 Meters\"]"}
            },
            id = "testtool123456",
            status = "success",
            role = "user",
            result = "成功创建缓冲区"
        };
    }

    /// <summary>UI test helper: echoes <paramref name="message"/> back as an assistant chat message.</summary>
    public static void TestChatMessage(string message, string model, string gdbPath,
        Action<MessageListItem> callback)
    {
        MessageListItem chatListItem = new ChatMessageItem
        {
            content = message,
            role = "assistant",
            type = MessageType.CHAT_MESSAGE,
            id = "testmsg12345"
        };
        callback?.Invoke(chatListItem);
    }

    /// <summary>UI test helper: emits a canned successful buffer-tool message.</summary>
    public static void TestToolMessage(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        callback?.Invoke(CreateSampleToolItem(message));
    }

    /// <summary>
    /// UI test helper: simulates a short workflow. It emits an assistant message after 2 s and a
    /// canned tool message 1.5 s later.
    /// </summary>
    public static async void TestWorkflow(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        // Task.Delay instead of Thread.Sleep: with no await, the sleeps blocked the calling thread
        // (the UI thread, when invoked from the UI) for the whole 3.5 s.
        await Task.Delay(2000);
        MessageListItem chatListItem = new ChatMessageItem
        {
            content = message,
            role = "assistant",
            type = MessageType.CHAT_MESSAGE,
            id = "testid12345"
        };
        callback?.Invoke(chatListItem);
        await Task.Delay(1500);
        callback?.Invoke(CreateSampleToolItem(message));
    }
}