diff --git a/client/prompt/DynamicPrompt.cs b/client/prompt/DynamicPrompt.cs index 2a342e7..52d695c 100644 --- a/client/prompt/DynamicPrompt.cs +++ b/client/prompt/DynamicPrompt.cs @@ -4,10 +4,14 @@ namespace LinkToolAddin.client.prompt; public class DynamicPrompt { - public static string GetPrompt(string name,Dictionary args) + public static string GetPrompt(string name,Dictionary args = null) { PromptTemplates promptTemplate = new PromptTemplates(); string template = promptTemplate.GetPrompt(name); + if (args == null) + { + return template; + } foreach (KeyValuePair pair in args) { string replaceKey = "{{"+pair.Key+"}}"; @@ -15,4 +19,11 @@ public class DynamicPrompt } return template; } + + public static Dictionary GetAllPrompts() + { + PromptTemplates promptTemplate = new PromptTemplates(); + Dictionary template = promptTemplate.GetPromptsDict(); + return template; + } } \ No newline at end of file diff --git a/client/prompt/PromptTemplates.cs b/client/prompt/PromptTemplates.cs index 3992482..9ef4ed3 100644 --- a/client/prompt/PromptTemplates.cs +++ b/client/prompt/PromptTemplates.cs @@ -17,4 +17,9 @@ public class PromptTemplates { return prompts[name]; } + + public Dictionary GetPromptsDict() + { + return prompts; + } } \ No newline at end of file diff --git a/client/tool/ArcGisPro.cs b/client/tool/ArcGisPro.cs index c94f2e3..da2e604 100644 --- a/client/tool/ArcGisPro.cs +++ b/client/tool/ArcGisPro.cs @@ -1,6 +1,11 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.ComponentModel; using System.Threading.Tasks; +using ArcGIS.Core.Data; +using ArcGIS.Core.Data.Raster; +using ArcGIS.Core.Geometry; +using ArcGIS.Desktop.Framework.Threading.Tasks; using LinkToolAddin.server; using ModelContextProtocol.Server; using Newtonsoft.Json; @@ -9,7 +14,7 @@ namespace LinkToolAddin.client.tool; public class ArcGisPro { - [McpServerTool, Description("ArcGIS Pro Tool")] + [McpServerTool, Description("可以通过调用ArcGIS Pro的地理处理工具实现一些数据处理功能。")] public static async Task ArcGisProTool(string toolName, List toolParams) { // Call the ArcGIS Pro method and get the result @@ -18,4 +23,85 @@ public class ArcGisPro // Serialize the result back to a JSON string return result; } + + [McpServerTool, Description("查看指定数据的坐标系、范围、几何类型、是否有Z坐标和M坐标,获取字段列表等")] + public static async Task DataProperty(string datasetPath,string dataName) + { + JsonRpcResultEntity result = new JsonRpcResultEntity(); + await QueuedTask.Run(() => + { + try + { + using Geodatabase gdb = new Geodatabase(new FileGeodatabaseConnectionPath(new Uri(datasetPath))); + FeatureClass featureClass = gdb.OpenDataset(dataName); + FeatureClassDefinition featureClassDefinition = featureClass.GetDefinition(); + SpatialReference spatialReference = featureClassDefinition.GetSpatialReference(); + GeometryType geometryType = featureClassDefinition.GetShapeType(); + result = new JsonRpcSuccessEntity() + { + Id = 1, + Result = JsonConvert.SerializeObject(new Dictionary() + { + {"spatialReference", spatialReference.Name+"(WKID:"+spatialReference.Wkid+")"}, + {"dataName", dataName}, + {"geometryType", geometryType.ToString()}, + {"hasZValue", featureClassDefinition.HasZ()}, + {"hasMValue", featureClassDefinition.HasM()}, + {"fields",featureClassDefinition.GetFields()} + }) + }; + return result; + } + catch (Exception ex) + { + result = new JsonRpcErrorEntity() + { + Error = new Error() + { + Message = ex.Message + }, + Id = 1 + }; + return result; + } + }); + return result; + } + + [McpServerTool, Description("列出gdb数据库中的所有数据名称")] + public static async Task ListData(string gdbPath) + { + var datasets = new List(); + await QueuedTask.Run(() => + { + using (Geodatabase gdb = new Geodatabase(new FileGeodatabaseConnectionPath(new Uri(gdbPath)))) + { + // 获取所有要素类(Feature Classes) + var featureClasses = gdb.GetDefinitions(); + foreach (var fc in featureClasses) + datasets.Add($"要素类: {fc.GetName()}"); + + // 获取所有表格(Tables) + var tables = gdb.GetDefinitions(); + foreach (var table in tables) + datasets.Add($"表格: {table.GetName()}"); + + // 获取所有要素数据集(Feature Datasets) + var featureDatasets = gdb.GetDefinitions(); + foreach (var fd in featureDatasets) + datasets.Add($"要素数据集: {fd.GetName()}"); + + // 获取所有栅格数据集(Raster Datasets) + var rasterDatasets = gdb.GetDefinitions(); + foreach (var raster in rasterDatasets) + datasets.Add($"栅格数据: {raster.GetName()}"); + } + }); + JsonRpcResultEntity result = new JsonRpcSuccessEntity() + { + Id = 1, + Result = JsonConvert.SerializeObject(datasets) + }; + return result; + } } \ No newline at end of file diff --git a/client/tool/KnowledgeBase.cs b/client/tool/KnowledgeBase.cs new file mode 100644 index 0000000..9d2b162 --- /dev/null +++ b/client/tool/KnowledgeBase.cs @@ -0,0 +1,25 @@ +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading.Tasks; +using LinkToolAddin.host.llm.entity; +using LinkToolAddin.resource; +using LinkToolAddin.server; +using ModelContextProtocol.Server; +using Newtonsoft.Json; + +namespace LinkToolAddin.client.tool; + +public class KnowledgeBase +{ + [McpServerTool, Description("可以查询ArcGIS Pro的帮助文档获取关于地理处理工具使用参数的说明")] + public static async Task QueryArcgisHelpDoc(string query) + { + DocDb docDb = new DocDb("sk-db177155677e438f832860e7f4da6afc", DocDb.KnowledgeBase.ArcGISProHelpDoc); + KnowledgeResult knowledgeResult = await docDb.Retrieve(query); + JsonRpcResultEntity result = new JsonRpcSuccessEntity() + { + Result = JsonConvert.SerializeObject(knowledgeResult.ChunkList), + }; + return result; + } +} \ No newline at end of file diff --git a/common/HttpRequest.cs b/common/HttpRequest.cs index a38807c..16c9383 100644 --- a/common/HttpRequest.cs +++ b/common/HttpRequest.cs @@ -1,8 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; using System.Net.Http; using System.Net.Http.Headers; using System.Text; using System.Threading.Tasks; using LinkToolAddin.host.llm.entity; +using LinkToolAddin.host.llm.entity.stream; using Newtonsoft.Json; namespace LinkToolAddin.common; @@ -18,4 +23,104 @@ public class HttpRequest var responseBody = await response.Content.ReadAsStringAsync(); return responseBody; } + + public static async IAsyncEnumerable SendStreamPostRequestAsync(string url, string jsonContent, string apiKey) + { + using var client = new HttpClient(); + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); + + // 发送 POST 请求并获取响应流 + var response = await client.PostAsync(url,new StringContent(jsonContent, Encoding.UTF8, "application/json")); + + // 验证响应状态 + response.EnsureSuccessStatusCode(); + + // 获取响应流 + using var stream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(stream); + + // 流式读取 + string line; + while ((line = await reader.ReadLineAsync()) != null) + { + string content = ProcessPartialResponse(line); + yield return content; + } + } + + private static string ProcessPartialResponse(string rawData) + { + try + { + var lines = rawData.Split(new[] { "data: " }, StringSplitOptions.RemoveEmptyEntries); + foreach (var line in lines) + { + var trimmedLine = line.Trim(); + if (!string.IsNullOrEmpty(trimmedLine)) + { + var result = JsonConvert.DeserializeObject(trimmedLine); + return result.Choices[0].Delta.Content; + } + } + } + catch { /* 处理解析异常 */ } + return null; + } + + public static async IAsyncEnumerable PostWithStreamingResponseAsync( + string url, + string body, + string apiKey, + string contentType = "application/json", + Action configureHeaders = null) + { + using (var client = new HttpClient()) + { + // 设置超时时间为30分钟 + client.Timeout = TimeSpan.FromMinutes(30); + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); + using (var request = new HttpRequestMessage(HttpMethod.Post, url)) + { + // 设置请求头和Body + Console.WriteLine("开始请求..."); + configureHeaders?.Invoke(request); + request.Content = new StringContent(body, Encoding.UTF8, contentType); + + // 发送请求并立即开始读取响应流 + using (var response = await client.SendAsync( + request, + HttpCompletionOption.ResponseHeadersRead)) + { + response.EnsureSuccessStatusCode(); + + // 获取响应流 + using (var stream = await response.Content.ReadAsStreamAsync()) + using (var reader = new StreamReader(stream)) + { + string line; + + StringBuilder incompleteJsonBuffer = new StringBuilder(); + + // 流式读取并输出到控制台 + while ((line = await reader.ReadLineAsync()) != null) + { + foreach (var chunk in line.Split(new[] { "data: " }, StringSplitOptions.RemoveEmptyEntries)) + { + LlmStreamChat dataObj = null; + try + { + dataObj = JsonConvert.DeserializeObject(chunk); + }catch{/*process exception*/} + + if (dataObj is not null) + { + yield return dataObj.Choices[0].Delta.Content; + } + } + } + } + } + } + } + } } \ No newline at end of file diff --git a/common/JsonSchemaGenerator.cs b/common/JsonSchemaGenerator.cs new file mode 100644 index 0000000..ff799c4 --- /dev/null +++ b/common/JsonSchemaGenerator.cs @@ -0,0 +1,150 @@ +using System.Collections.ObjectModel; + +namespace LinkToolAddin.common; + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; + +public static class JsonSchemaGenerator +{ + public static string GenerateJsonSchema(MethodInfo methodInfo) + { + var parameters = methodInfo.GetParameters(); + var properties = new Dictionary(); + var required = new List(); + + foreach (var param in parameters) + { + var paramName = param.Name ?? throw new InvalidOperationException("参数没有名称。"); + var paramSchema = GenerateSchemaForType(param.ParameterType); + properties[paramName] = paramSchema; + + if (!param.IsOptional) + { + required.Add(paramName); + } + } + + var schemaRoot = new Dictionary + { + { "$schema", "http://json-schema.org/draft-07/schema#" }, + { "type", "object" }, + { "properties", properties } + }; + + if (required.Count > 0) + { + schemaRoot["required"] = required; + } + + var options = new JsonSerializerOptions { WriteIndented = true }; + return JsonSerializer.Serialize(schemaRoot, options); + } + + private static object GenerateSchemaForType(Type type) + { + // 处理可空类型 + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + var underlyingType = Nullable.GetUnderlyingType(type); + return new[] { GenerateSchemaForType(underlyingType), "null" }; + } + + // 处理集合类型(数组或IEnumerable) + if (IsCollectionType(type, out Type elementType)) + { + return new Dictionary + { + { "type", "array" }, + { "items", GenerateSchemaForType(elementType) } + }; + } + + // 处理基本类型(int, string, bool, etc.) + if (IsPrimitiveType(type)) + { + string jsonType = MapClrTypeToJsonType(type); + var schema = new Dictionary { { "type", jsonType } }; + + if (type == typeof(DateTime)) + schema["format"] = "date-time"; + else if (type == typeof(Guid)) + schema["format"] = "uuid"; + + return schema; + } + + // 处理复杂类型(类、结构体) + if (type.IsClass || type.IsValueType) + { + var props = new Dictionary(); + foreach (var prop in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + props[prop.Name] = GenerateSchemaForType(prop.PropertyType); + } + + return new Dictionary + { + { "type", "object" }, + { "properties", props } + }; + } + + // 默认情况 + return new Dictionary { { "type", "any" } }; + } + + private static bool IsCollectionType(Type type, out Type elementType) + { + if (type == typeof(string)) + { + elementType = null; + return false; + } + + if (type.IsArray) + { + elementType = type.GetElementType(); + return true; + } + + if (type.IsGenericType) + { + var genericTypeDef = type.GetGenericTypeDefinition(); + if (genericTypeDef == typeof(IEnumerable<>) || + genericTypeDef == typeof(List<>) || + genericTypeDef == typeof(Collection<>)) + { + elementType = type.GetGenericArguments()[0]; + return true; + } + } + + elementType = null; + return false; + } + + private static bool IsPrimitiveType(Type type) + { + return type.IsPrimitive || type == typeof(string) || type == typeof(decimal) || type == typeof(DateTime) || type == typeof(Guid); + } + + private static string MapClrTypeToJsonType(Type type) + { + if (type == typeof(int) || type == typeof(short) || type == typeof(long) || + type == typeof(uint) || type == typeof(ushort) || type == typeof(ulong)) + return "integer"; + if (type == typeof(float) || type == typeof(double) || type == typeof(decimal)) + return "number"; + if (type == typeof(bool)) + return "boolean"; + if (type == typeof(string)) + return "string"; + if (type == typeof(DateTime) || type == typeof(Guid)) + return "string"; + return "any"; + } +} \ No newline at end of file diff --git a/host/Gateway.cs b/host/Gateway.cs index 2380d4f..fca37cc 100644 --- a/host/Gateway.cs +++ b/host/Gateway.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using System.Xml; using System.Xml.Linq; using ArcGIS.Desktop.Framework.Dialogs; +using ArcGIS.Desktop.Internal.Mapping.Locate; using LinkToolAddin.client; using LinkToolAddin.client.prompt; using LinkToolAddin.host.llm; @@ -29,6 +30,7 @@ using Newtonsoft.Json.Linq; using Newtonsoft.Json.Schema; using Newtonsoft.Json.Schema.Generation; using Tool = LinkToolAddin.host.mcp.Tool; +using LinkToolAddin.common; namespace LinkToolAddin.host; @@ -68,6 +70,7 @@ public class Gateway TopP = 1, MaxTokens = 1000, }); + long timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); log.Info(reponse); messages.Add(new Message { @@ -95,7 +98,8 @@ public class Gateway toolParams = toolParams, type = MessageType.TOOL_MESSAGE, status = toolResponse.IsError ? "fail" : "success", - content = toolResponse.Content.ToString() + content = JsonConvert.SerializeObject(toolResponse), + id = timestamp.ToString() }; messages.Add(new Message { @@ -119,7 +123,8 @@ public class Gateway toolParams = toolParams, type = MessageType.TOOL_MESSAGE, status = toolResponse.IsError ? "fail" : "success", - content = toolResponse.Content.ToString() + content = JsonConvert.SerializeObject(toolResponse), + id = timestamp.ToString() }; messages.Add(new Message { @@ -204,7 +209,8 @@ public class Gateway { content = reponse, role = "assistant", - type = MessageType.CHAT_MESSAGE + type = MessageType.CHAT_MESSAGE, + id = timestamp.ToString() }; callback?.Invoke(chatMessageListItem); } @@ -214,7 +220,235 @@ public class Gateway } } } - + + public static async void SendMessageStream(string message, string model, string gdbPath, Action callback) + { + Llm bailian = new Bailian + { + api_key = "sk-db177155677e438f832860e7f4da6afc" + }; + List messages = new List(); + string toolInfos = await GetToolInfos(new McpServerList()); + log.Info(SystemPrompt.SysPrompt(gdbPath, toolInfos)); + messages.Add(new Message + { + Role = "system", + Content = SystemPrompt.SysPrompt(gdbPath, toolInfos) + }); + messages.Add(new Message + { + Role = "user", + Content = message + }); + bool goOn = true; + string toolPattern = "^[\\s\\S]*?<\\/tool_use>$"; + string promptPattern = "^[\\s\\S]*?<\\/prompt>$"; + McpServerList mcpServerList = new McpServerList(); + while (goOn) + { + LlmJsonContent jsonContent = new LlmJsonContent() + { + Model = model, + Messages = messages, + Temperature = 0.7, + TopP = 1, + MaxTokens = 1000, + }; + long timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + string messageContent = ""; + await foreach(var chunk in bailian.SendChatStreamAsync(jsonContent)) + { + if (chunk == "[DONE]") + { + goOn = false; + }else if (chunk.StartsWith("")) + { + if (Regex.IsMatch(chunk, toolPattern)) + { + //返回工具卡片 + XElement toolUse = XElement.Parse(chunk); + string fullToolName = toolUse.Element("name")?.Value; + string toolArgs = toolUse.Element("arguments")?.Value; + Dictionary toolParams = JsonConvert.DeserializeObject>(toolArgs); + string serverName = fullToolName.Contains(":") ? fullToolName.Split(':')[0] : fullToolName; + string toolName = fullToolName.Contains(":") ? fullToolName.Split(':')[1] : fullToolName; + McpServer mcpServer = mcpServerList.GetServer(serverName); + if (mcpServer is SseMcpServer) + { + SseMcpServer sseMcpServer = mcpServer as SseMcpServer; + SseMcpClient client = new SseMcpClient(sseMcpServer.BaseUrl); + CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = toolResponse.IsError ? "fail" : "success", + content = JsonConvert.SerializeObject(toolResponse), + id = timestamp.ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate + }); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(toolResponse) + }); + callback?.Invoke(toolMessageItem); + }else if (mcpServer is StdioMcpServer) + { + StdioMcpServer stdioMcpServer = mcpServer as StdioMcpServer; + StdioMcpClient client = new StdioMcpClient(stdioMcpServer.Command, stdioMcpServer.Args); + CallToolResponse toolResponse = await client.CallToolAsync(toolName,toolParams); + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = toolResponse.IsError ? "fail" : "success", + content = JsonConvert.SerializeObject(toolResponse), + id = timestamp.ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = toolResponse.IsError ? SystemPrompt.ErrorPromptTemplate : SystemPrompt.ContinuePromptTemplate + }); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(toolResponse) + }); + callback?.Invoke(toolMessageItem); + }else if (mcpServer is InnerMcpServer) + { + Type type = Type.GetType("LinkToolAddin.client.tool."+serverName); + MethodInfo method = type.GetMethod(toolName,BindingFlags.Public | BindingFlags.Static); + var task = method.Invoke(null, toolParams.Values.ToArray()) as Task; + JsonRpcResultEntity innerResult = await task; + if (innerResult is JsonRpcErrorEntity) + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = "fail", + content = JsonConvert.SerializeObject(innerResult), + id = timestamp.ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = SystemPrompt.ErrorPromptTemplate + }); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(innerResult) + }); + callback?.Invoke(toolMessageItem); + }else if (innerResult is JsonRpcSuccessEntity) + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = toolName, + toolParams = toolParams, + type = MessageType.TOOL_MESSAGE, + status = "success", + content = JsonConvert.SerializeObject(innerResult), + id = timestamp.ToString() + }; + messages.Add(new Message + { + Role = "user", + Content = SystemPrompt.ContinuePromptTemplate + }); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(innerResult) + }); + callback?.Invoke(toolMessageItem); + } + } + } + else + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = "", + toolParams = new Dictionary(), + type = MessageType.TOOL_MESSAGE, + status = "loading", + content = "正在生成工具调用参数", + id = timestamp.ToString() + }; + callback?.Invoke(toolMessageItem); + continue; + } + }else if (chunk.StartsWith("")) + { + if (Regex.IsMatch(chunk, promptPattern)) + { + XElement promptUse = XElement.Parse(chunk); + string promptKey = promptUse.Element("name")?.Value; + string promptContent = DynamicPrompt.GetPrompt(promptKey,null); + messages.Add(new Message + { + Role = "user", + Content = JsonConvert.SerializeObject(promptContent) + }); + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = "调用提示词", + toolParams = null, + type = MessageType.TOOL_MESSAGE, + status = "success", + content = promptKey, + id = timestamp.ToString() + }; + callback?.Invoke(toolMessageItem); + } + else + { + MessageListItem toolMessageItem = new ToolMessageItem + { + toolName = "调用提示词", + toolParams = null, + type = MessageType.TOOL_MESSAGE, + status = "loading", + content = "正在调用提示词", + id = timestamp.ToString() + }; + callback?.Invoke(toolMessageItem); + } + } + else + { + //普通流式消息卡片 + MessageListItem chatMessageListItem = new ChatMessageItem() + { + content = chunk, + role = "assistant", + type = MessageType.CHAT_MESSAGE, + id = timestamp.ToString() + }; + messageContent = chunk; + callback?.Invoke(chatMessageListItem); + } + } + messages.Add(new Message + { + Role = "assistant", + Content = messageContent + }); + } + } + private static async Task GetToolInfos(McpServerList mcpServerList) { int loop = 0; @@ -225,7 +459,7 @@ public class Gateway if (loop > 3) { MessageBox.Show("达到最大循环次数", "退出循环"); - break; + break; } if (mcpServer is InnerMcpServer) { @@ -238,7 +472,7 @@ public class Gateway { string methodName = method.Name; string methodDescription = method.GetCustomAttribute()?.Description; - string methodParamSchema = GenerateMethodParamSchema(method); + string methodParamSchema = LinkToolAddin.common.JsonSchemaGenerator.GenerateJsonSchema(method); McpToolDefinition toolDefinition = new McpToolDefinition { Tool = new Tool @@ -298,6 +532,21 @@ public class Gateway } } } + + Dictionary prompts = DynamicPrompt.GetAllPrompts(); + foreach (KeyValuePair prompt in prompts) + { + McpPromptDefinition promptDefinition = new McpPromptDefinition + { + Prompt = new LinkToolAddin.host.mcp.Prompt + { + Name = prompt.Key + } + }; + XNode node = JsonConvert.DeserializeXNode(JsonConvert.SerializeObject(promptDefinition)); + toolInfos.AppendLine(node.ToString()); + toolInfos.AppendLine(); + } return toolInfos.ToString(); } diff --git a/host/McpServerList.cs b/host/McpServerList.cs index 52dffd7..2c03a41 100644 --- a/host/McpServerList.cs +++ b/host/McpServerList.cs @@ -21,7 +21,7 @@ public class McpServerList {"Content-Type","application/json"} } }); - servers.Add("arcgis", new InnerMcpServer + servers.Add("ArcGisPro", new InnerMcpServer { Name = "ArcGisPro", Type = "inner", diff --git a/host/llm/Bailian.cs b/host/llm/Bailian.cs index bd444c7..b43eed0 100644 --- a/host/llm/Bailian.cs +++ b/host/llm/Bailian.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Net.Http; using System.Net.Http.Headers; @@ -17,9 +18,18 @@ public class Bailian : Llm public string max_tokens { get; set; } public string app_id { get; set; } public string api_key { get; set; } - public IAsyncEnumerable SendChatStreamAsync(string message) + public async IAsyncEnumerable SendChatStreamAsync(LlmJsonContent jsonContent) { - throw new System.NotImplementedException(); + jsonContent.Stream = true; + StringBuilder builder = new StringBuilder(); + await foreach (var chunk in HttpRequest.PostWithStreamingResponseAsync( + "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", + JsonConvert.SerializeObject(jsonContent), + api_key)) + { + builder.Append(chunk); + yield return builder.ToString(); + } } public IAsyncEnumerable SendApplicationStreamAsync(string message) diff --git a/host/llm/Llm.cs b/host/llm/Llm.cs index 62a11f4..1807b9d 100644 --- a/host/llm/Llm.cs +++ b/host/llm/Llm.cs @@ -11,7 +11,7 @@ public interface Llm public string top_p { get; set; } public string max_tokens { get; set; } - public IAsyncEnumerable SendChatStreamAsync(string message); + public IAsyncEnumerable SendChatStreamAsync(LlmJsonContent jsonContent); public IAsyncEnumerable SendApplicationStreamAsync(string message); public Task SendChatAsync(LlmJsonContent jsonContent); public Task SendApplicationAsync(CommonInput commonInput); diff --git a/host/llm/entity/stream/LlmStreamChat.cs b/host/llm/entity/stream/LlmStreamChat.cs new file mode 100644 index 0000000..6991982 --- /dev/null +++ b/host/llm/entity/stream/LlmStreamChat.cs @@ -0,0 +1,54 @@ +namespace LinkToolAddin.host.llm.entity.stream +{ + using System; + using System.Collections.Generic; + + using System.Globalization; + using Newtonsoft.Json; + using Newtonsoft.Json.Converters; + + public partial class LlmStreamChat + { + [JsonProperty("choices")] + public Choice[] Choices { get; set; } + + [JsonProperty("object")] + public string Object { get; set; } + + [JsonProperty("usage")] + public object Usage { get; set; } + + [JsonProperty("created")] + public long Created { get; set; } + + [JsonProperty("system_fingerprint")] + public object SystemFingerprint { get; set; } + + [JsonProperty("model")] + public string Model { get; set; } + + [JsonProperty("id")] + public string Id { get; set; } + } + + public partial class Choice + { + [JsonProperty("finish_reason")] + public string FinishReason { get; set; } + + [JsonProperty("delta")] + public Delta Delta { get; set; } + + [JsonProperty("index")] + public long Index { get; set; } + + [JsonProperty("logprobs")] + public object Logprobs { get; set; } + } + + public partial class Delta + { + [JsonProperty("content")] + public string Content { get; set; } + } +} \ No newline at end of file diff --git a/host/mcp/McpPromptDefinition.cs b/host/mcp/McpPromptDefinition.cs new file mode 100644 index 0000000..83ee903 --- /dev/null +++ b/host/mcp/McpPromptDefinition.cs @@ -0,0 +1,27 @@ +namespace LinkToolAddin.host.mcp +{ + using System; + using System.Collections.Generic; + + using System.Globalization; + using Newtonsoft.Json; + using Newtonsoft.Json.Converters; + + public partial class McpPromptDefinition + { + [JsonProperty("prompt")] + public Prompt Prompt { get; set; } + } + + public partial class Prompt + { + [JsonProperty("name")] + public string Name { get; set; } + + [JsonProperty("description")] + public string Description { get; set; } + + [JsonProperty("arguments")] + public string Arguments { get; set; } + } +} \ No newline at end of file diff --git a/ui/dockpane/TestDockpane.xaml b/ui/dockpane/TestDockpane.xaml index e550084..fd7d8c8 100644 --- a/ui/dockpane/TestDockpane.xaml +++ b/ui/dockpane/TestDockpane.xaml @@ -21,6 +21,7 @@ + @@ -33,5 +34,6 @@ + \ No newline at end of file diff --git a/ui/dockpane/TestDockpane.xaml.cs b/ui/dockpane/TestDockpane.xaml.cs index d67e5c9..abf8b4a 100644 --- a/ui/dockpane/TestDockpane.xaml.cs +++ b/ui/dockpane/TestDockpane.xaml.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Linq; +using System.Text; using System.Windows; using System.Windows.Controls; using LinkToolAddin.client; @@ -27,6 +28,9 @@ namespace LinkToolAddin.ui.dockpane { private static ILog log = LogManager.GetLogger(typeof(TestDockpaneView)); + private List idList = new List(); + private Dictionary messageDict = new Dictionary(); + public TestDockpaneView() { InitLogger(); @@ -139,6 +143,35 @@ namespace LinkToolAddin.ui.dockpane }); log.Info(reponse); } + + private async void Request_Bailian_Stream_Test() + { + LlmJsonContent jsonContent = new LlmJsonContent() + { + Model = "qwen-max", + Messages = new List() + { + new Message() + { + Role = "user", + Content = "给我写一篇1000字的高考议论文" + } + }, + Temperature = 0.7, + TopP = 1, + MaxTokens = 1000, + Stream = true + }; + Llm bailian = new Bailian + { + api_key = "sk-db177155677e438f832860e7f4da6afc", + app_id = "6a77c5a68de64f469b79fcdcde9d5001", + }; + await foreach (var chunk in bailian.SendChatStreamAsync(jsonContent)) + { + log.Info(chunk); + } + } private void TestButton_OnClick(object sender, RoutedEventArgs e) { @@ -159,7 +192,31 @@ namespace LinkToolAddin.ui.dockpane private void PromptTestButton_OnClick(object sender, RoutedEventArgs e) { string userPrompt = PromptTestTextBox.Text; - Gateway.SendMessage(userPrompt,"qwen-max","C:/Project/test.gdb",AddReply); + // Gateway.SendMessage(userPrompt,"qwen-max","C:/Project/test.gdb",AddReply); + Gateway.SendMessageStream(userPrompt,"qwen-max","D:\\01_Project\\20250305_LinkTool\\20250420_AiDemoProject\\20250420_AiDemoProject.gdb",AddReplyStream); + } + + public void AddReplyStream(MessageListItem msg) + { + string id = msg.id; + if (idList.Contains(id)) + { + messageDict[id] = msg; + } + else + { + idList.Add(id); + messageDict.Add(msg.id, msg); + } + ReplyTextBox.Clear(); + StringBuilder builder = new StringBuilder(); + foreach (KeyValuePair pair in messageDict) + { + MessageListItem msgItem = pair.Value; + builder.AppendLine(msgItem.content); + ReplyTextBox.Text = builder.ToString(); + ReplyTextBox.ScrollToEnd(); + } } public void AddReply(MessageListItem msg) @@ -169,5 +226,10 @@ namespace LinkToolAddin.ui.dockpane string originContent = ReplyTextBox.Text; ReplyTextBox.Text = originContent + content; } + + private void TestStream_OnClick(object sender, RoutedEventArgs e) + { + Request_Bailian_Stream_Test(); + } } }