281 lines
9.5 KiB
Rust
281 lines
9.5 KiB
Rust
use std::sync::Arc;
|
|
use actix_web::{App, HttpServer, middleware::Logger, web};
|
|
use log::{LevelFilter, debug, trace};
|
|
use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, handler::server::{router::prompt::PromptRouter, tool::ToolRouter, wrapper::Parameters}, model::*, prompt, prompt_handler, prompt_router, service::RequestContext, tool, tool_handler, tool_router, transport::streamable_http_server::session::local::LocalSessionManager};
|
|
use rmcp_actix_web::transport::StreamableHttpService;
|
|
use serde_json::json;
|
|
use simplelog::{ColorChoice, ConfigBuilder, TermLogger, TerminalMode};
|
|
use tokio::sync::Mutex;
|
|
|
|
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
|
pub struct StructRequest {
|
|
pub a: i32,
|
|
pub b: i32,
|
|
}
|
|
|
|
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
|
|
pub struct ExamplePromptArgs {
|
|
/// A message to put in the prompt
|
|
pub message: String,
|
|
}
|
|
|
|
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
|
|
pub struct CounterAnalysisArgs {
|
|
/// The target value you're trying to reach
|
|
pub goal: i32,
|
|
/// Preferred strategy: 'fast' or 'careful'
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub strategy: Option<String>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Counter {
|
|
counter: Arc<Mutex<i32>>,
|
|
tool_router: ToolRouter<Counter>,
|
|
prompt_router: PromptRouter<Counter>,
|
|
}
|
|
|
|
#[tool_router]
|
|
impl Counter {
|
|
#[allow(dead_code)]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
counter: Arc::new(Mutex::new(0)),
|
|
tool_router: Self::tool_router(),
|
|
prompt_router: Self::prompt_router(),
|
|
}
|
|
}
|
|
|
|
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
|
|
RawResource::new(uri, name.to_string()).no_annotation()
|
|
}
|
|
|
|
#[tool(description = "Increment the counter by 1")]
|
|
async fn increment(&self) -> Result<CallToolResult, McpError> {
|
|
let mut counter = self.counter.lock().await;
|
|
*counter += 1;
|
|
debug!("incrementing with 1, result: {counter}");
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Decrement the counter by 1")]
|
|
async fn decrement(&self) -> Result<CallToolResult, McpError> {
|
|
let mut counter = self.counter.lock().await;
|
|
*counter -= 1;
|
|
debug!("decrementing with 1, result: {counter}");
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Get the current counter value")]
|
|
async fn get_value(&self) -> Result<CallToolResult, McpError> {
|
|
let counter = self.counter.lock().await;
|
|
debug!("current value: {counter}");
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
counter.to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Say hello to the client")]
|
|
fn say_hello(&self) -> Result<CallToolResult, McpError> {
|
|
debug!("hello!");
|
|
Ok(CallToolResult::success(vec![Content::text("hello")]))
|
|
}
|
|
|
|
#[tool(description = "Repeat what you say")]
|
|
fn echo(&self, Parameters(object): Parameters<JsonObject>) -> Result<CallToolResult, McpError> {
|
|
debug!("{object:?}");
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
serde_json::Value::Object(object).to_string(),
|
|
)]))
|
|
}
|
|
|
|
#[tool(description = "Calculate the sum of two numbers")]
|
|
fn sum(
|
|
&self,
|
|
Parameters(StructRequest { a, b }): Parameters<StructRequest>,
|
|
) -> Result<CallToolResult, McpError> {
|
|
debug!("the sum of {a} + {b} = {}", a + b);
|
|
Ok(CallToolResult::success(vec![Content::text(
|
|
(a + b).to_string(),
|
|
)]))
|
|
}
|
|
}
|
|
|
|
#[prompt_router]
|
|
impl Counter {
|
|
/// This is an example prompt that takes one required argument, message
|
|
#[prompt(
|
|
name = "example_prompt",
|
|
meta = Meta(rmcp::object!({"meta_key": "meta_value"}))
|
|
)]
|
|
async fn example_prompt(
|
|
&self,
|
|
Parameters(args): Parameters<ExamplePromptArgs>,
|
|
_ctx: RequestContext<RoleServer>,
|
|
) -> Result<Vec<PromptMessage>, McpError> {
|
|
let prompt = format!(
|
|
"This is an example prompt with your message here: '{}'",
|
|
args.message
|
|
);
|
|
Ok(vec![PromptMessage {
|
|
role: PromptMessageRole::User,
|
|
content: PromptMessageContent::text(prompt),
|
|
}])
|
|
}
|
|
|
|
/// Analyze the current counter value and suggest next steps
|
|
#[prompt(name = "counter_analysis")]
|
|
async fn counter_analysis(
|
|
&self,
|
|
Parameters(args): Parameters<CounterAnalysisArgs>,
|
|
_ctx: RequestContext<RoleServer>,
|
|
) -> Result<GetPromptResult, McpError> {
|
|
let strategy = args.strategy.unwrap_or_else(|| "careful".to_string());
|
|
let current_value = *self.counter.lock().await;
|
|
let difference = args.goal - current_value;
|
|
|
|
let messages = vec![
|
|
PromptMessage::new_text(
|
|
PromptMessageRole::Assistant,
|
|
"I'll analyze the counter situation and suggest the best approach.",
|
|
),
|
|
PromptMessage::new_text(
|
|
PromptMessageRole::User,
|
|
format!(
|
|
"Current counter value: {}\nGoal value: {}\nDifference: {}\nStrategy preference: {}\n\nPlease analyze the situation and suggest the best approach to reach the goal.",
|
|
current_value, args.goal, difference, strategy
|
|
),
|
|
),
|
|
];
|
|
|
|
Ok(GetPromptResult {
|
|
description: Some(format!(
|
|
"Counter analysis for reaching {} from {}",
|
|
args.goal, current_value
|
|
)),
|
|
messages,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[tool_handler(meta = Meta(rmcp::object!({"tool_meta_key": "tool_meta_value"})))]
|
|
#[prompt_handler(meta = Meta(rmcp::object!({"router_meta_key": "router_meta_value"})))]
|
|
impl ServerHandler for Counter {
|
|
fn get_info(&self) -> ServerInfo {
|
|
ServerInfo {
|
|
protocol_version: ProtocolVersion::V_2024_11_05,
|
|
capabilities: ServerCapabilities::builder()
|
|
.enable_prompts()
|
|
.enable_resources()
|
|
.enable_tools()
|
|
.build(),
|
|
server_info: Implementation::from_build_env(),
|
|
instructions: Some("This server provides counter tools and prompts. Tools: increment, decrement, get_value, say_hello, echo, sum. Prompts: example_prompt (takes a message), counter_analysis (analyzes counter state with a goal).".to_string()),
|
|
}
|
|
}
|
|
|
|
async fn list_resources(
|
|
&self,
|
|
_request: Option<PaginatedRequestParam>,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ListResourcesResult, McpError> {
|
|
Ok(ListResourcesResult {
|
|
resources: vec![
|
|
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
|
|
self._create_resource_text("memo://insights", "memo-name"),
|
|
],
|
|
next_cursor: None,
|
|
meta: None,
|
|
})
|
|
}
|
|
|
|
async fn read_resource(
|
|
&self,
|
|
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ReadResourceResult, McpError> {
|
|
match uri.as_str() {
|
|
"str:////Users/to/some/path/" => {
|
|
let cwd = "/Users/to/some/path/";
|
|
Ok(ReadResourceResult {
|
|
contents: vec![ResourceContents::text(cwd, uri)],
|
|
})
|
|
}
|
|
"memo://insights" => {
|
|
let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
|
|
Ok(ReadResourceResult {
|
|
contents: vec![ResourceContents::text(memo, uri)],
|
|
})
|
|
}
|
|
_ => Err(McpError::resource_not_found(
|
|
"resource_not_found",
|
|
Some(json!({
|
|
"uri": uri
|
|
})),
|
|
)),
|
|
}
|
|
}
|
|
|
|
async fn list_resource_templates(
|
|
&self,
|
|
_request: Option<PaginatedRequestParam>,
|
|
_: RequestContext<RoleServer>,
|
|
) -> Result<ListResourceTemplatesResult, McpError> {
|
|
Ok(ListResourceTemplatesResult {
|
|
next_cursor: None,
|
|
resource_templates: Vec::new(),
|
|
meta: None,
|
|
})
|
|
}
|
|
|
|
async fn initialize(
|
|
&self,
|
|
request: InitializeRequestParam,
|
|
context: RequestContext<RoleServer>,
|
|
) -> Result<InitializeResult, McpError> {
|
|
trace!("received request: {request:?}");
|
|
trace!("with context: {context:?}");
|
|
Ok(self.get_info())
|
|
}
|
|
}
|
|
|
|
const BIND_ADDRESS: &str = "0.0.0.0:8000";
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
TermLogger::init(
|
|
LevelFilter::Trace,
|
|
ConfigBuilder::default()
|
|
.add_filter_ignore_str("actix_http")
|
|
.add_filter_ignore_str("mio")
|
|
.add_filter_ignore_str("actix_server")
|
|
.add_filter_ignore_str("actix_web")
|
|
.build(),
|
|
TerminalMode::Stderr,
|
|
ColorChoice::Auto,
|
|
)
|
|
.unwrap();
|
|
|
|
let http_service = StreamableHttpService::builder()
|
|
.service_factory(Arc::new(|| Ok(Counter::new())))
|
|
.session_manager(Arc::new(LocalSessionManager::default()))
|
|
.stateful_mode(true)
|
|
.build();
|
|
|
|
HttpServer::new(move || {
|
|
App::new()
|
|
.wrap(Logger::default())
|
|
.route("/health", web::get().to(|| async { "OK" }))
|
|
.service(web::scope("/api/v1/mcp").service(http_service.clone().scope()))
|
|
})
|
|
.bind(BIND_ADDRESS)?
|
|
.run()
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|