// LinkToolAddin/host/Gateway.cs — 210 lines, 8.3 KiB, C# (file-viewer header captured with the source)

using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using LinkToolAddin.client;
using LinkToolAddin.client.prompt;
using LinkToolAddin.host.llm;
using LinkToolAddin.host.llm.entity;
using LinkToolAddin.host.mcp;
using LinkToolAddin.host.prompt;
using LinkToolAddin.message;
using ModelContextProtocol.Protocol.Types;
using Newtonsoft.Json;
namespace LinkToolAddin.host;
/// <summary>
/// Static entry points that relay a user message to the LLM, dispatch any
/// tool or prompt request found in the model's reply, and report every
/// visible step to the caller through a callback.
/// </summary>
public class Gateway
{
    /// <summary>Environment variable that, when set, overrides the built-in API key.</summary>
    private const string ApiKeyEnvironmentVariable = "DASHSCOPE_API_KEY";

    // SECURITY(review): this key is committed to source control. It is kept only as a
    // fallback so existing installations keep working; rotate it and supply the key
    // through the environment variable (or a secure store) instead.
    private const string LegacyApiKey = "sk-db177155677e438f832860e7f4da6afc";

    /// <summary>Reply text that ends the conversation loop.</summary>
    private const string DoneSentinel = "[DONE]";

    // Built once instead of on every loop iteration.
    private static readonly Regex ToolUsePattern = new Regex("<tool_use>[\\s\\S]*?<\\/tool_use>");
    private static readonly Regex PromptPattern = new Regex("<prompt>[\\s\\S]*?<\\/prompt>");

    /// <summary>
    /// Sends <paramref name="message"/> (preceded by the system prompt) to the model and
    /// keeps exchanging messages until the model replies with exactly <c>[DONE]</c>.
    /// Replies containing a <c>&lt;tool_use&gt;</c> block trigger an MCP tool call, replies
    /// containing a <c>&lt;prompt&gt;</c> block are expanded through <c>DynamicPrompt</c>,
    /// and every other reply is forwarded to <paramref name="callback"/> as a chat message.
    /// </summary>
    /// <param name="message">User text that starts the conversation.</param>
    /// <param name="model">Model identifier passed to the LLM request.</param>
    /// <param name="gdbPath">Not used by this method; kept for interface compatibility.</param>
    /// <param name="callback">Receives chat and tool items as they are produced; may be null.</param>
    /// <remarks>
    /// The method is <c>async void</c>, so nothing can observe an exception that escapes it.
    /// Every failure is therefore caught here, reported through the callback as a chat
    /// message, and the conversation ends.
    /// </remarks>
    public static async void SendMessage(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        try
        {
            await RunConversationAsync(message, model, callback);
        }
        catch (Exception ex)
        {
            callback?.Invoke(new ChatMessageItem
            {
                content = $"Request failed: {ex.Message}",
                role = "assistant",
                type = MessageType.CHAT_MESSAGE
            });
        }
    }

    /// <summary>Runs the request/dispatch loop behind <see cref="SendMessage"/>.</summary>
    private static async Task RunConversationAsync(string message, string model, Action<MessageListItem> callback)
    {
        string configuredKey = Environment.GetEnvironmentVariable(ApiKeyEnvironmentVariable);
        Llm bailian = new Bailian
        {
            api_key = string.IsNullOrWhiteSpace(configuredKey) ? LegacyApiKey : configuredKey
        };

        List<Message> messages = new List<Message>
        {
            new Message { Role = "system", Content = SystemPrompt.SysPromptTemplate },
            new Message { Role = "user", Content = message }
        };

        // NOTE(review): nothing in this class registers servers, so every tool call
        // currently ends in the "unknown MCP server" error below. Wire up the registry.
        Dictionary<string, McpServer> servers = new Dictionary<string, McpServer>();

        while (true)
        {
            string response = await bailian.SendChatAsync(new LlmJsonContent
            {
                Model = model,
                Messages = messages,
                Temperature = 0.7,
                TopP = 1,
                MaxTokens = 1000,
            });
            messages.Add(new Message { Role = "assistant", Content = response });

            Match toolMatch = ToolUsePattern.Match(response);
            if (toolMatch.Success)
            {
                // Tool-type message. Only the matched fragment is parsed, so text
                // around the block no longer breaks the XML parser.
                await HandleToolUseAsync(toolMatch.Value, servers, messages, callback);
            }
            else
            {
                Match promptMatch = PromptPattern.Match(response);
                if (promptMatch.Success)
                {
                    HandlePrompt(promptMatch.Value, messages);
                }
                else
                {
                    callback?.Invoke(new ChatMessageItem
                    {
                        content = response,
                        role = "assistant",
                        type = MessageType.CHAT_MESSAGE
                    });
                }
            }

            // The sentinel itself has already been forwarded to the callback above.
            if (response == DoneSentinel)
            {
                break;
            }
        }
    }

    /// <summary>
    /// Executes the tool described by a <c>&lt;tool_use&gt;</c> fragment, appends the
    /// follow-up prompt and the serialized result to the conversation, and reports the
    /// outcome to the callback.
    /// </summary>
    private static async Task HandleToolUseAsync(
        string toolUseXml,
        Dictionary<string, McpServer> servers,
        List<Message> messages,
        Action<MessageListItem> callback)
    {
        XElement toolUse = XElement.Parse(toolUseXml);
        string fullToolName = toolUse.Element("name")?.Value;
        if (string.IsNullOrWhiteSpace(fullToolName))
        {
            throw new InvalidOperationException("The <tool_use> block has no <name> element.");
        }

        Dictionary<string, object> toolParams = ParseArguments(toolUse.Element("arguments")?.Value);
        (string serverName, string toolName) = SplitQualifiedName(fullToolName);

        if (!servers.TryGetValue(serverName, out McpServer mcpServer))
        {
            throw new InvalidOperationException($"Unknown MCP server '{serverName}'.");
        }

        CallToolResponse toolResponse;
        if (mcpServer is SseMcpServer sseServer)
        {
            SseMcpClient sseClient = new SseMcpClient(sseServer.BaseUrl);
            toolResponse = await sseClient.CallToolAsync(toolName, toolParams);
        }
        else if (mcpServer is StdioMcpServer stdioServer)
        {
            StdioMcpClient stdioClient = new StdioMcpClient(stdioServer.Command, stdioServer.Args);
            toolResponse = await stdioClient.CallToolAsync(toolName, toolParams);
        }
        else
        {
            // Previously this case did nothing and the loop re-sent the same conversation.
            throw new NotSupportedException($"MCP server '{serverName}' uses an unsupported transport.");
        }

        MessageListItem toolMessageItem = new ToolMessageItem
        {
            toolName = toolName,
            toolParams = toolParams,
            type = MessageType.TOOL_MESSAGE,
            status = toolResponse.IsError ? "fail" : "success",
            // NOTE(review): if Content is a collection, ToString() yields its type name
            // rather than the tool output — confirm and format the items if so.
            content = toolResponse.Content.ToString()
        };

        // The model first gets the matching follow-up instruction, then the raw result.
        messages.Add(new Message
        {
            Role = "user",
            Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate
        });
        messages.Add(new Message
        {
            Role = "user",
            Content = JsonConvert.SerializeObject(toolResponse)
        });
        callback?.Invoke(toolMessageItem);
    }

    /// <summary>
    /// Expands the prompt described by a <c>&lt;prompt&gt;</c> fragment and appends the
    /// result to the conversation as a user message.
    /// </summary>
    private static void HandlePrompt(string promptXml, List<Message> messages)
    {
        XElement prompt = XElement.Parse(promptXml);
        string fullPromptName = prompt.Element("name")?.Value;
        if (string.IsNullOrWhiteSpace(fullPromptName))
        {
            throw new InvalidOperationException("The <prompt> block has no <name> element.");
        }

        Dictionary<string, object> promptParams = ParseArguments(prompt.Element("arguments")?.Value);
        string promptName = SplitQualifiedName(fullPromptName).ItemName;

        messages.Add(new Message
        {
            Role = "user",
            Content = DynamicPrompt.GetPrompt(promptName, promptParams)
        });
    }

    /// <summary>
    /// Deserializes a JSON argument object; a missing or empty payload yields an empty dictionary.
    /// </summary>
    private static Dictionary<string, object> ParseArguments(string json)
    {
        if (string.IsNullOrWhiteSpace(json))
        {
            return new Dictionary<string, object>();
        }

        return JsonConvert.DeserializeObject<Dictionary<string, object>>(json)
               ?? new Dictionary<string, object>();
    }

    /// <summary>
    /// Splits a <c>server:item</c> name into its first two colon-separated parts.
    /// A name without a colon is returned as both the server and the item name.
    /// </summary>
    private static (string ServerName, string ItemName) SplitQualifiedName(string qualifiedName)
    {
        if (!qualifiedName.Contains(":"))
        {
            return (qualifiedName, qualifiedName);
        }

        string[] parts = qualifiedName.Split(':');
        return (parts[0], parts[1]);
    }

    /// <summary>
    /// Test helper: immediately reports <paramref name="message"/> as an assistant chat item.
    /// </summary>
    /// <param name="message">Text placed in the chat item.</param>
    /// <param name="model">Not used.</param>
    /// <param name="gdbPath">Not used.</param>
    /// <param name="callback">Receives the chat item; may be null.</param>
    public static void TestChatMessage(string message, string model, string gdbPath,
        Action<MessageListItem> callback)
    {
        callback?.Invoke(CreateSampleChatMessage(message, "testmsg12345"));
    }

    /// <summary>
    /// Test helper: immediately reports a canned, successful tool item carrying <paramref name="message"/>.
    /// </summary>
    /// <param name="message">Text placed in the tool item's content.</param>
    /// <param name="model">Not used.</param>
    /// <param name="gdbPath">Not used.</param>
    /// <param name="callback">Receives the tool item; may be null.</param>
    public static void TestToolMessage(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        callback?.Invoke(CreateSampleToolMessage(message));
    }

    /// <summary>
    /// Test helper: reports a chat item after 2 s and a canned tool item 1.5 s later.
    /// </summary>
    /// <param name="message">Text placed in both items.</param>
    /// <param name="model">Not used.</param>
    /// <param name="gdbPath">Not used.</param>
    /// <param name="callback">Receives both items; may be null.</param>
    public static async void TestWOrkflow(string message, string model, string gdbPath, Action<MessageListItem> callback)
    {
        // Task.Delay keeps the calling thread free; Thread.Sleep blocked it for 3.5 s.
        await Task.Delay(2000);
        callback?.Invoke(CreateSampleChatMessage(message, "testid12345"));

        await Task.Delay(1500);
        callback?.Invoke(CreateSampleToolMessage(message));
    }

    /// <summary>Builds the assistant chat item used by the test helpers.</summary>
    private static MessageListItem CreateSampleChatMessage(string message, string id)
    {
        return new ChatMessageItem
        {
            content = message,
            role = "assistant",
            type = MessageType.CHAT_MESSAGE,
            id = id
        };
    }

    /// <summary>Builds the canned successful tool item used by the test helpers.</summary>
    private static MessageListItem CreateSampleToolMessage(string message)
    {
        return new ToolMessageItem
        {
            content = message,
            type = MessageType.TOOL_MESSAGE,
            toolName = "arcgis_pro.executeTool",
            toolParams = new Dictionary<string, object>
            {
                {"gp_name","analysis.Buffer"},
                {"gp_params","[\"C:\\test.gdb\\river\",\"30 Meters\"]"}
            },
            id = "testtool123456",
            status = "success",
            role = "user",
            result = "成功创建缓冲区"
        };
    }
}