Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/reqstool/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
import sys
from typing import Optional, TextIO, Union
from typing import Literal, Optional, TextIO, Union, cast

if __package__ is None or len(__package__) == 0:
_script_dir = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -360,10 +360,27 @@ class ComboRawTextandArgsDefaultUltimateHelpFormatter(
mcp_parser = subparsers.add_parser(
"mcp",
help=(
"Start the Model Context Protocol server (stdio). "
"Start the Model Context Protocol server. "
"With no source, auto-detects the dataset from .reqstool-ai.yaml in cwd or an ancestor directory."
),
)
mcp_parser.add_argument(
"--transport",
choices=["stdio", "sse", "streamable-http"],
default="stdio",
help="Transport to use (default: %(default)s)",
)
mcp_parser.add_argument(
"--host",
default="127.0.0.1",
help="Host for HTTP transports (default: %(default)s)",
)
mcp_parser.add_argument(
"--port",
type=int,
default=8000,
help="Port for HTTP transports (default: %(default)s)",
)
mcp_source_subparsers = mcp_parser.add_subparsers(dest="source", required=False)
self._add_subparsers_source(mcp_source_subparsers, include_report_options=False, include_filter_options=False)

Expand Down Expand Up @@ -511,7 +528,12 @@ def command_mcp(self, mcp_args: argparse.Namespace):
location = self._get_initial_source(mcp_args)

try:
start_server(location=location)
start_server(
location=location,
transport=cast(Literal["stdio", "sse", "streamable-http"], mcp_args.transport),
host=mcp_args.host,
port=mcp_args.port,
)
except Exception as exc:
logging.fatal("reqstool MCP server crashed: %s", exc)
sys.exit(1)
Expand Down
16 changes: 14 additions & 2 deletions src/reqstool/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


import logging
from typing import Literal

from reqstool.common.project_session import ProjectSession
from reqstool.common.enrichment.enricher import BUILT_IN_PRESETS, enrich_text
Expand All @@ -21,7 +22,12 @@
logger = logging.getLogger(__name__)


def start_server(location: LocationInterface) -> None: # noqa: C901
def start_server( # noqa: C901
location: LocationInterface,
transport: Literal["stdio", "sse", "streamable-http"] = "stdio",
host: str = "127.0.0.1",
port: int = 8000,
) -> None:
try:
from mcp.server.fastmcp import FastMCP
except ImportError as exc:
Expand All @@ -39,6 +45,11 @@ def start_server(location: LocationInterface) -> None: # noqa: C901
urn_source_paths = session.urn_source_paths

mcp = FastMCP("reqstool")
mcp.settings.host = host
mcp.settings.port = port
if transport == "streamable-http":
mcp.settings.json_response = True
mcp.settings.stateless_http = True

@mcp.tool()
def list_requirements(urn: str | None = None, lifecycle_state: str | None = None) -> list[dict]:
Expand Down Expand Up @@ -147,6 +158,7 @@ def enrich_document(content: str, preset: str) -> str:
return enrich_text(content, repo.get_all_requirements(), repo.get_all_svcs(), repo.get_all_mvrs(), config)

try:
mcp.run()
logger.info("Starting reqstool MCP server (transport=%s, host=%s, port=%s)", transport, host, port)
mcp.run(transport=transport)
finally:
session.close()