diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index e6b269e4c..080effbea 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -46,7 +46,6 @@ from configuration import configuration from constants import ( ENDPOINT_PATH_STREAMING_QUERY, - INTERRUPTED_RESPONSE_MESSAGE, LLM_TOKEN_EVENT, LLM_TOOL_CALL_EVENT, LLM_TOOL_RESULT_EVENT, @@ -122,6 +121,7 @@ validate_shield_ids_override, ) from utils.stream_interrupts import ( + build_interrupted_response, deregister_stream, persist_interrupted_turn, register_interrupt_callback, @@ -634,9 +634,10 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi current_task = asyncio.current_task() if current_task is not None: current_task.uncancel() + full_text, suffix = build_interrupted_response(turn_summary.partial_tokens) if not persist_guard[0]: persist_guard[0] = True - turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE + turn_summary.llm_response = full_text await persist_interrupted_turn( context, responses_params, @@ -644,6 +645,11 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi _background_topic_summary_tasks, original_input, ) + yield stream_event( + {"id": turn_summary.next_chunk_id, "token": suffix}, + LLM_TOKEN_EVENT, + context.query_request.media_type or MEDIA_TYPE_JSON, + ) yield stream_interrupted_event(context.request_id) finally: deregister_stream(context.request_id) @@ -774,6 +780,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat media_type, ) chunk_id += 1 + turn_summary.next_chunk_id = chunk_id # Store MCP call item info for later lookup when arguments.done event occurs elif event_type == "response.output_item.added": @@ -789,6 +796,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat elif event_type == "response.output_text.delta": delta_chunk = cast(TextDeltaChunk, chunk) 
text_parts.append(delta_chunk.delta) + turn_summary.partial_tokens.append(delta_chunk.delta) yield stream_event( { "id": chunk_id, @@ -798,6 +806,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat media_type, ) chunk_id += 1 + turn_summary.next_chunk_id = chunk_id # Final text of the output (capture, but emit at response.completed) elif event_type == "response.output_text.done": @@ -886,6 +895,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat media_type, ) chunk_id += 1 + turn_summary.next_chunk_id = chunk_id # Incomplete or failed response - emit error elif event_type in ("response.incomplete", "response.failed"): diff --git a/src/constants.py b/src/constants.py index ebd209a04..a9e522ae9 100644 --- a/src/constants.py +++ b/src/constants.py @@ -12,7 +12,7 @@ UNABLE_TO_PROCESS_RESPONSE: Final[str] = "Unable to process this request" # Response stored in the conversation when the user interrupts a streaming request -INTERRUPTED_RESPONSE_MESSAGE: Final[str] = "You interrupted this request." +INTERRUPTED_RESPONSE_MESSAGE: Final[str] = "Response stopped by the user." # Max seconds to wait for topic summary in background task after interrupt persist. TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS: Final[float] = 30.0 diff --git a/src/models/common/turn_summary.py b/src/models/common/turn_summary.py index f09e24845..2b342b758 100644 --- a/src/models/common/turn_summary.py +++ b/src/models/common/turn_summary.py @@ -114,6 +114,16 @@ class TurnSummary(BaseModel): description="Structured response output items, captured for compacted-mode " "turn persistence (LCORE-1572). 
Empty on the non-compacted path.", ) + partial_tokens: list[str] = Field( + default_factory=list, + description="Accumulated text deltas during streaming, used to reconstruct " + "partial content on interruption.", + ) + next_chunk_id: int = Field( + default=0, + description="Next monotonic SSE chunk index, kept in sync with the inner " + "generator so the interrupt handler can emit a sequentially valid id.", + ) class ToolInfoSummary(BaseModel): diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py index 138852bc2..df4fdd669 100644 --- a/src/utils/agents/streaming.py +++ b/src/utils/agents/streaming.py @@ -25,7 +25,7 @@ ) from configuration import configuration -from constants import INTERRUPTED_RESPONSE_MESSAGE, MEDIA_TYPE_JSON +from constants import MEDIA_TYPE_JSON from log import get_logger from models.common.agents import ( AgentTurnAccumulator, @@ -65,6 +65,7 @@ maybe_get_topic_summary, ) from utils.stream_interrupts import ( + build_interrupted_response, deregister_stream, persist_interrupted_turn, register_interrupt_callback, @@ -197,9 +198,10 @@ async def generate_agent_response( current_task = asyncio.current_task() if current_task is not None: current_task.uncancel() + full_text, suffix = build_interrupted_response(turn_summary.partial_tokens) if not persist_guard[0]: persist_guard[0] = True - turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE + turn_summary.llm_response = full_text await persist_interrupted_turn( context, responses_params, @@ -207,6 +209,12 @@ async def generate_agent_response( background_topic_summary_tasks, original_input, ) + yield serialize_event( + TokenStreamPayload.create( + chunk_id=turn_summary.next_chunk_id, token=suffix + ), + media_type, + ) yield serialize_event( InterruptedStreamPayload.create(request_id=context.request_id), media_type, @@ -347,11 +355,13 @@ def _process_token( Token stream payload containing the emitted token chunk. 
""" state.text_parts.append(text) + state.turn_summary.partial_tokens.append(text) payload = TokenStreamPayload.create( chunk_id=state.chunk_id, token=text, ) state.chunk_id += 1 + state.turn_summary.next_chunk_id = state.chunk_id return payload @@ -402,6 +412,7 @@ def _( token=final_text, ) state.chunk_id += 1 + state.turn_summary.next_chunk_id = state.chunk_id return payload diff --git a/src/utils/markdown_repair.py b/src/utils/markdown_repair.py new file mode 100644 index 000000000..b3aa59475 --- /dev/null +++ b/src/utils/markdown_repair.py @@ -0,0 +1,105 @@ +"""Utilities for repairing truncated markdown content. + +Used when a streaming response is interrupted mid-content to close +any open markdown constructs (code fences, HTML block tags) that +would otherwise break rendering. +""" + +import re +from typing import Final + +BLOCK_HTML_TAGS: Final[frozenset[str]] = frozenset( + { + "div", + "table", + "tr", + "td", + "th", + "thead", + "tbody", + "details", + "summary", + "pre", + } +) + +_FENCE_RE: Final[re.Pattern[str]] = re.compile(r"^(\s{0,3})((`{3,})|(~{3,}))") +_TAG_RE: Final[re.Pattern[str]] = re.compile(r"<(/?)(\w+)([^>]*?)(/?)>") + + +def _process_html_tags(line: str, html_stack: list[str]) -> None: + """Update *html_stack* with block-level HTML open/close tags found in *line*. + + Parameters: + line: A single line of text to scan for HTML tags. + html_stack: Mutable stack tracking open block-level tags. + """ + for tag_match in _TAG_RE.finditer(line): + is_closing = tag_match.group(1) == "/" + tag_name = tag_match.group(2).lower() + is_self_closing = tag_match.group(4) == "/" + + if tag_name not in BLOCK_HTML_TAGS or is_self_closing: + continue + + if is_closing: + if html_stack and html_stack[-1] == tag_name: + html_stack.pop() + else: + html_stack.append(tag_name) + + +def close_open_markdown(text: str) -> str: + """Return a suffix that closes any open markdown constructs in *text*. 
+ + Scans for unclosed fenced code blocks and unclosed HTML block-level + tags. Returns only the closing characters (callers append the result). + Returns an empty string when nothing needs closing. + + Parameters: + text: Partial markdown content that may contain open constructs. + + Returns: + A suffix string to append that closes open constructs. + """ + if not text or not text.strip(): + return "" + + lines = text.split("\n") + in_code_fence = False + fence_char = "" + fence_len = 0 + html_stack: list[str] = [] + + for line in lines: + fence_match = _FENCE_RE.match(line) + if not fence_match: + if not in_code_fence: + _process_html_tags(line, html_stack) + continue + + group_3 = fence_match.group(3) + group_4 = fence_match.group(4) + matched_group = group_3 or group_4 + char = "`" if group_3 else "~" + if not in_code_fence: + in_code_fence = True + fence_char = char + fence_len = len(matched_group) + elif ( + char == fence_char + and len(matched_group) >= fence_len + and line[fence_match.end() :].strip(" \t") == "" + ): + in_code_fence = False + fence_char = "" + fence_len = 0 + + suffix_parts: list[str] = [] + if in_code_fence: + suffix_parts.append(f"\n{fence_char * fence_len}") + + for tag in reversed(html_stack): + suffix_parts.append(f"\n{tag}>") + + return "".join(suffix_parts) diff --git a/src/utils/stream_interrupts.py b/src/utils/stream_interrupts.py index 522a4e52f..5afaf92f8 100644 --- a/src/utils/stream_interrupts.py +++ b/src/utils/stream_interrupts.py @@ -20,6 +20,7 @@ from models.common.responses.types import ResponseInput from models.common.turn_summary import TurnSummary from utils.conversations import append_turn_items_to_conversation +from utils.markdown_repair import close_open_markdown from utils.query import store_query_results, update_conversation_topic_summary from utils.responses import get_topic_summary from utils.shields import append_turn_to_conversation @@ -215,6 +216,28 @@ async def background_update_topic_summary( ) +def 
def build_interrupted_response(partial_tokens: list[str]) -> tuple[str, str]:
    """Assemble the text to persist and the text to stream after an interrupt.

    The accumulated deltas are concatenated, the closing markup returned by
    ``close_open_markdown`` is computed for that text, and an italicized
    interruption notice is placed after it.

    Parameters:
        partial_tokens: List of text deltas accumulated during streaming.

    Returns:
        A tuple of (full_response_text, suffix_to_emit). The first element is
        the streamed text followed by the suffix; the second is the suffix
        alone (closing markup plus the interruption notice).
    """
    streamed_text = "".join(partial_tokens)
    # close_open_markdown returns only the closing characters, not the
    # repaired document, so it is kept separate from the streamed text.
    closing_markup = close_open_markdown(streamed_text)
    interrupted_indicator = f"\n\n*{INTERRUPTED_RESPONSE_MESSAGE}*"
    suffix = "".join((closing_markup, interrupted_indicator))
    return streamed_text + suffix, suffix
b/tests/unit/app/endpoints/test_streaming_query.py @@ -51,6 +51,7 @@ ) from configuration import AppConfig from constants import ( + INTERRUPTED_RESPONSE_MESSAGE, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT, ) @@ -70,6 +71,8 @@ from utils.stream_interrupts import StreamInterruptRegistry from utils.token_counter import TokenCounter +INTERRUPTED_INDICATOR = f"\n\n*{INTERRUPTED_RESPONSE_MESSAGE}*" + MOCK_AUTH_STREAMING = ( "00000001-0001-0001-0001-000000000001", "mock_username", @@ -1379,6 +1382,7 @@ async def mock_generator() -> AsyncIterator[str]: result.append(item) assert any("start" in item for item in result) + assert any('"event": "token"' in item for item in result) assert any('"event": "interrupted"' in item for item in result) assert not any('"event": "end"' in item for item in result) consume_query_tokens_mock.assert_not_called() @@ -1387,13 +1391,13 @@ async def mock_generator() -> AsyncIterator[str]: mock_context.client, existing_conv_id, "test", - "You interrupted this request.", + INTERRUPTED_INDICATOR, ) store_query_results_mock.assert_called_once() call_kwargs = store_query_results_mock.call_args[1] assert call_kwargs["user_id"] == "user_123" assert call_kwargs["conversation_id"] == existing_conv_id - assert call_kwargs["summary"].llm_response == "You interrupted this request." 
+ assert call_kwargs["summary"].llm_response == INTERRUPTED_INDICATOR assert call_kwargs["topic_summary"] is None isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( diff --git a/tests/unit/utils/agents/test_streaming.py b/tests/unit/utils/agents/test_streaming.py index b93c314cc..6c34c0645 100644 --- a/tests/unit/utils/agents/test_streaming.py +++ b/tests/unit/utils/agents/test_streaming.py @@ -64,6 +64,8 @@ ) from utils.token_counter import TokenCounter +INTERRUPTED_INDICATOR = f"\n\n*{INTERRUPTED_RESPONSE_MESSAGE}*" + TEST_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" @@ -715,9 +717,9 @@ async def inner() -> AsyncIterator[str]: ) ] - assert _sse_event_types(result) == ["start", "token", "interrupted"] + assert _sse_event_types(result) == ["start", "token", "token", "interrupted"] persist_mock.assert_awaited_once() - assert turn_summary.llm_response == INTERRUPTED_RESPONSE_MESSAGE + assert turn_summary.llm_response == INTERRUPTED_INDICATOR stream_interrupt_mocks["deregister"].assert_called_once_with(context.request_id) @pytest.mark.asyncio @@ -808,7 +810,7 @@ async def inner() -> AsyncIterator[str]: ) ] - assert _sse_event_types(result) == ["start", "token", "interrupted"] + assert _sse_event_types(result) == ["start", "token", "token", "interrupted"] persist_mock.assert_not_awaited() @@ -961,6 +963,147 @@ async def test_no_run_result_logs_and_returns_early( assert turn_summary.token_usage.input_tokens == 0 +class TestInterruptPartialTokenAccumulation: + """Tests verifying real partial-text accumulation through the streaming pipeline on interrupt.""" + + @pytest.mark.asyncio + async def test_interrupt_accumulates_partial_tokens_and_persists( + self, + mocker: MockerFixture, + make_generator_context: Callable[..., ResponseGeneratorContext], + responses_params: ResponsesApiParams, + ) -> None: + """Cancel mid-stream through agent_response_generator and verify partial content is accumulated, repaired, and persisted.""" + context = 
make_generator_context() + turn_summary = TurnSummary() + background_tasks: list[asyncio.Task[None]] = [] + + events_before_cancel = [ + PartStartEvent(index=0, part=TextPart(content="Hello")), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=" world")), + ] + + def _cancelling_run_stream( + events: list[Any], + ) -> Any: + async def _event_stream() -> AsyncIterator[Any]: + for event in events: + yield event + raise asyncio.CancelledError() + + class _Ctx: + """Async context manager that cancels after yielding events.""" + + async def __aenter__(self) -> AsyncIterator[Any]: + return _event_stream() + + async def __aexit__(self, *_args: object) -> None: + return None + + return _Ctx() + + mock_agent = mocker.Mock() + mock_agent.run_stream_events.return_value = _cancelling_run_stream( + events_before_cancel + ) + + persist_mock = mocker.patch( + "utils.agents.streaming.persist_interrupted_turn", + new=mocker.AsyncMock(), + ) + mocker.patch( + "utils.agents.streaming.register_interrupt_callback", + return_value=[False], + ) + + inner = agent_response_generator( + mock_agent, + responses_params, + context, + turn_summary, + ENDPOINT_PATH_STREAMING_QUERY, + ) + + result = [ + event + async for event in generate_agent_response( + inner, + context, + responses_params, + turn_summary, + background_tasks, + ) + ] + + event_types = _sse_event_types(result) + assert event_types == ["start", "token", "token", "token", "interrupted"] + + assert turn_summary.partial_tokens == ["Hello", " world"] + + assert "Hello world" in turn_summary.llm_response + assert INTERRUPTED_RESPONSE_MESSAGE in turn_summary.llm_response + + persist_mock.assert_awaited_once() + + token_events = [ + json.loads(e.removeprefix("data: ").strip()) + for e in result + if e.startswith("data: ") + and json.loads(e.removeprefix("data: ").strip())["event"] == "token" + ] + chunk_ids = [t["data"]["id"] for t in token_events] + assert chunk_ids == sorted(chunk_ids), "chunk_ids must be monotonically 
ordered" + assert all(cid >= 0 for cid in chunk_ids), "all chunk_ids must be non-negative" + assert chunk_ids[-1] == len(chunk_ids) - 1 + + @pytest.mark.asyncio + async def test_interrupt_with_no_tokens_uses_zero_chunk_id( + self, + mocker: MockerFixture, + make_generator_context: Callable[..., ResponseGeneratorContext], + responses_params: ResponsesApiParams, + ) -> None: + """Cancel before any tokens are emitted; interrupt suffix should use chunk_id 0.""" + context = make_generator_context() + turn_summary = TurnSummary() + background_tasks: list[asyncio.Task[None]] = [] + + async def inner() -> AsyncIterator[str]: + raise asyncio.CancelledError() + yield "" # pragma: no cover + + persist_mock = mocker.patch( + "utils.agents.streaming.persist_interrupted_turn", + new=mocker.AsyncMock(), + ) + mocker.patch( + "utils.agents.streaming.register_interrupt_callback", + return_value=[False], + ) + + result = [ + event + async for event in generate_agent_response( + inner(), + context, + responses_params, + turn_summary, + background_tasks, + ) + ] + + token_events = [ + json.loads(e.removeprefix("data: ").strip()) + for e in result + if e.startswith("data: ") + and json.loads(e.removeprefix("data: ").strip())["event"] == "token" + ] + assert len(token_events) == 1 + assert token_events[0]["data"]["id"] == 0 + + persist_mock.assert_awaited_once() + + def _sse_event_types(events: list[str]) -> list[str]: """Extract SSE event types from serialized stream lines.""" types: list[str] = [] diff --git a/tests/unit/utils/test_markdown_repair.py b/tests/unit/utils/test_markdown_repair.py new file mode 100644 index 000000000..1ad3908a4 --- /dev/null +++ b/tests/unit/utils/test_markdown_repair.py @@ -0,0 +1,174 @@ +"""Unit tests for markdown repair utilities.""" + +from utils.markdown_repair import close_open_markdown + + +class TestCloseOpenMarkdownCodeFences: + """Tests for closing unclosed code fences.""" + + def test_unclosed_backtick_fence(self) -> None: + """Unclosed 
triple-backtick fence gets closed.""" + text = "Some text\n```\ncode here" + result = close_open_markdown(text) + assert result == "\n```" + + def test_unclosed_tilde_fence(self) -> None: + """Unclosed tilde fence gets closed with tildes.""" + text = "Some text\n~~~\ncode here" + result = close_open_markdown(text) + assert result == "\n~~~" + + def test_unclosed_fence_with_language(self) -> None: + """Unclosed fence with language specifier gets closed.""" + text = "Some text\n```python\ndef foo():\n pass" + result = close_open_markdown(text) + assert result == "\n```" + + def test_closed_fence_no_repair(self) -> None: + """Properly closed fence needs no repair.""" + text = "Some text\n```\ncode\n```\nmore text" + result = close_open_markdown(text) + assert result == "" + + def test_multiple_fences_last_unclosed(self) -> None: + """Multiple fences where only the last is unclosed.""" + text = "```\nfirst\n```\n\n```\nsecond block" + result = close_open_markdown(text) + assert result == "\n```" + + def test_backticks_mid_line_not_fence(self) -> None: + """Triple backticks not at line start are not fences.""" + text = "Use ```code``` for inline code" + result = close_open_markdown(text) + assert result == "" + + def test_fence_no_trailing_newline(self) -> None: + """Unclosed fence at end of text without trailing newline.""" + text = "```\ncode" + result = close_open_markdown(text) + assert result == "\n```" + + def test_fence_with_text_after_backticks(self) -> None: + """Backticks at line start followed by text is a fence with info string.""" + text = "```this is my sentence" + result = close_open_markdown(text) + assert result == "\n```" + + def test_shorter_fence_inside_longer_fence_not_closer(self) -> None: + """A 3-backtick line inside a 4-backtick fence is content, not a closer.""" + text = "````python\ndef foo():\n```\n pass" + result = close_open_markdown(text) + assert result == "\n````" + + def test_fence_with_trailing_text_not_closer(self) -> None: + """A 
fence marker with trailing non-whitespace inside a fence is content.""" + text = "```python\nprint('x')\n```not a closer" + result = close_open_markdown(text) + assert result == "\n```" + + def test_fence_with_trailing_whitespace_is_closer(self) -> None: + """A fence marker with only trailing whitespace is a valid closer.""" + text = "```python\nprint('x')\n``` " + result = close_open_markdown(text) + assert result == "" + + +class TestCloseOpenMarkdownHtmlTags: + """Tests for closing unclosed HTML block tags.""" + + def test_unclosed_div(self) -> None: + """Single unclosed div tag gets closed.""" + text = "
| cell |
\nformatted text" + result = close_open_markdown(text) + assert result == "\n" + + def test_multiple_unclosed_tags_reversed(self) -> None: + """Multiple unclosed tags are closed in reverse order.""" + text = "