From a5ce04bea6a46123ac1d84cbd868ce0c226a015c Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Thu, 25 Jun 2026 17:57:23 +0800 Subject: [PATCH 1/6] Normalize model rate limits into AG-UI run errors Keep the change inside the SDK error-to-AG-UI boundary: detect structured 429/rate-limit signals, preserve RATE_LIMITED retry metadata on RUN_ERROR, and keep non-rate errors on their existing codes. Constraint: Aone 83566999 requires model limit errors to end the AG-UI stream with RUN_ERROR rather than client-delivered text. Rejected: funagent-core/front-end rewrite | downstream AG-UI frames are already pass-through and the SDK owns event semantics. Rejected: broad generic error framework | current need is a small rate-limit normalizer. Confidence: high Scope-risk: narrow Directive: Keep RUN_ERROR payload extensions whitelisted and preserve SSE event framing when adding fields. Tested: uv run --extra server pytest tests/unittests/server/test_invoker.py tests/unittests/server/test_agui_protocol.py tests/unittests/integration/test_langgraph_events.py -q => 101 passed, 1 warning Tested: uv run --extra server pytest tests/unittests/integration/test_langgraph_to_agent_event.py -q => 32 passed Tested: git diff --check && uv run ruff check agentrun/server/error_utils.py agentrun/server/invoker.py agentrun/server/agui_protocol.py agentrun/integration/langgraph/agent_converter.py => passed Tested: UltraQA dynamic AG-UI harness UQA-1..UQA-4 => passed Change-Id: Ice926e2c21201071713ac39faeed736078fc5823 Co-developed-by: Codex Not-tested: full repository test suite and remote CI pending Signed-off-by: congxiao.wxx --- .../integration/langgraph/agent_converter.py | 58 +++---- agentrun/server/agui_protocol.py | 20 ++- agentrun/server/error_utils.py | 159 ++++++++++++++++++ agentrun/server/invoker.py | 7 +- .../integration/test_langgraph_events.py | 41 ++++- .../test_langgraph_to_agent_event.py | 41 ++++- tests/unittests/server/test_agui_protocol.py | 89 ++++++++++ tests/unittests/server/test_invoker.py | 136 +++++++++++++++ 8 files changed, 507 insertions(+), 44 deletions(-) create mode 100644 agentrun/server/error_utils.py diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index f00a46b..d69c8c9 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -29,6 +29,10 @@ import json from typing import Any, Dict, Iterator, List, Optional, Union +from agentrun.server.error_utils import ( + build_error_event_data, + format_error_message, +) from agentrun.server.model import AgentResult, EventType from agentrun.utils.log import logger @@ -941,71 +945,53 @@ def _convert_astream_events_event( # 7. LLM 错误 elif event_type == "on_llm_error": error = data.get("error") - error_message = "" - if error is not None: - if isinstance(error, Exception): - error_message = f"{type(error).__name__}: {str(error)}" - elif isinstance(error, str): - error_message = error - else: - error_message = str(error) + error_message = format_error_message(error) yield AgentResult( event=EventType.ERROR, - data={ - "message": f"LLM error: {error_message}", - "code": "LLM_ERROR", - }, + data=build_error_event_data( + error, + fallback_code="LLM_ERROR", + fallback_message=f"LLM error: {error_message}", + ), ) # 8. Chain 错误 elif event_type == "on_chain_error": error = data.get("error") chain_name = event_dict.get("name", "") - error_message = "" - if error is not None: - if isinstance(error, Exception): - error_message = f"{type(error).__name__}: {str(error)}" - elif isinstance(error, str): - error_message = error - else: - error_message = str(error) + error_message = format_error_message(error) yield AgentResult( event=EventType.ERROR, - data={ - "message": ( + data=build_error_event_data( + error, + fallback_code="CHAIN_ERROR", + fallback_message=( f"Chain '{chain_name}' error: {error_message}" if chain_name else error_message ), - "code": "CHAIN_ERROR", - }, + ), ) # 9. Retriever 错误 elif event_type == "on_retriever_error": error = data.get("error") retriever_name = event_dict.get("name", "") - error_message = "" - if error is not None: - if isinstance(error, Exception): - error_message = f"{type(error).__name__}: {str(error)}" - elif isinstance(error, str): - error_message = error - else: - error_message = str(error) + error_message = format_error_message(error) yield AgentResult( event=EventType.ERROR, - data={ - "message": ( + data=build_error_event_data( + error, + fallback_code="RETRIEVER_ERROR", + fallback_message=( f"Retriever '{retriever_name}' error: {error_message}" if retriever_name else error_message ), - "code": "RETRIEVER_ERROR", - }, + ), ) # ========================================================================= diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index 5e8ccb4..70a6bb8 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -53,6 +53,7 @@ # ============================================================================ DEFAULT_PREFIX = "/ag-ui/agent" +RUN_ERROR_EXTRA_FIELDS = ("retryable", "retryAfterMs", "traceId") @dataclass @@ -743,12 +744,21 @@ def _process_event_with_boundaries( # ERROR 事件 if event.event == EventType.ERROR: - yield self._encoder.encode( - RunErrorEvent( - message=event.data.get("message", ""), - code=event.data.get("code"), - ) + agui_event = RunErrorEvent( + message=event.data.get("message", ""), + code=event.data.get("code"), ) + event_dict = agui_event.model_dump(by_alias=True, exclude_none=True) + for key in RUN_ERROR_EXTRA_FIELDS: + value = event.data.get(key) + if value is not None: + event_dict[key] = value + elif event.addition: + value = event.addition.get(key) + if value is not None: + event_dict[key] = value + json_str = json.dumps(event_dict, ensure_ascii=False) + yield f"event: RUN_ERROR\ndata: {json_str}\n\n" return # STATE 事件 diff --git a/agentrun/server/error_utils.py b/agentrun/server/error_utils.py new file mode 100644 index 0000000..60f9bb0 --- /dev/null +++ b/agentrun/server/error_utils.py @@ -0,0 +1,159 @@ +"""Error helpers for AgentRun server streams.""" + +import re +from typing import Any, Dict, Optional + +RATE_LIMITED_CODE = "RATE_LIMITED" +RATE_LIMITED_MESSAGE = "模型当前请求过多,请稍后再试" +RATE_LIMITED_RETRY_AFTER_MS = 2000 + +_RATE_LIMIT_CODES = { + "ratelimitexceeded", + "ratelimited", + "resourcethrottled", + "throttling", + "throttlingquota", + "throttlingratequota", + "throttlingexception", + "toomanyrequests", +} + +_RATE_LIMIT_TEXT_PATTERNS = [ + re.compile(r"\btoo[-_\s]*many[-_\s]*requests\b", re.IGNORECASE), + re.compile( + r"\brate[-_\s]*limit(?:ed|[-_\s]*exceeded)\b", + re.IGNORECASE, + ), + re.compile(r"\bresource[-_\s]*throttled\b", re.IGNORECASE), + re.compile(r"\bthrottling(?:exception| exception| error)\b", re.IGNORECASE), + re.compile( + r"\b(?:http|status|status code|code)\s*[:=]?\s*429\b", + re.IGNORECASE, + ), +] + + +def format_error_message(error: Any) -> str: + """Format errors consistently with existing LangGraph conversion.""" + if error is None: + return "" + if isinstance(error, Exception): + return f"{type(error).__name__}: {str(error)}" + return str(error) + + +def build_error_event_data( + error: Any, + *, + fallback_code: str, + fallback_message: str, +) -> Dict[str, Any]: + """Build AgentEvent ERROR data, normalizing model rate limits.""" + if not is_rate_limited_error(error): + return {"message": fallback_message, "code": fallback_code} + + data: Dict[str, Any] = { + "message": RATE_LIMITED_MESSAGE, + "code": RATE_LIMITED_CODE, + "retryable": True, + "retryAfterMs": RATE_LIMITED_RETRY_AFTER_MS, + } + trace_id = _extract_trace_id(error) + if trace_id: + data["traceId"] = str(trace_id) + return data + + +def is_rate_limited_error(error: Any) -> bool: + """Return whether an error carries an explicit rate-limit signal.""" + if _extract_status_code(error) == 429: + return True + + if _has_rate_limit_code(error): + return True + + message = str(error or "") + return any(pattern.search(message) for pattern in _RATE_LIMIT_TEXT_PATTERNS) + + +def _extract_status_code(error: Any) -> Optional[int]: + fallback = None + for obj in (error, _get_value(error, "response")): + if obj is None: + continue + for name in ("status_code", "status", "http_status", "statusCode"): + status_code = _to_int(_get_value(obj, name)) + if status_code == 429: + return status_code + if fallback is None and status_code is not None: + fallback = status_code + return fallback + + +def _has_rate_limit_code(error: Any) -> bool: + for obj in (error, _get_value(error, "response")): + if obj is None: + continue + for name in ("code", "error_code", "errorCode"): + error_code = _get_value(obj, name) + if ( + error_code is not None + and _normalize_code(error_code) in _RATE_LIMIT_CODES + ): + return True + return False + + +def _extract_trace_id(error: Any) -> Optional[Any]: + for name in ("trace_id", "traceId", "request_id", "requestId"): + trace_id = _get_value(error, name) + if trace_id: + return trace_id + + response = _get_value(error, "response") + headers = _get_value(response, "headers") + if not headers: + return None + + for name in ("x-acs-request-id", "x-request-id", "x-trace-id", "trace-id"): + trace_id = _get_header(headers, name) + if trace_id: + return trace_id + return None + + +def _get_value(obj: Any, name: str) -> Optional[Any]: + if obj is None: + return None + if isinstance(obj, dict): + return obj.get(name) + return getattr(obj, name, None) + + +def _get_header(headers: Any, name: str) -> Optional[Any]: + if isinstance(headers, dict): + for key, value in headers.items(): + if str(key).lower() == name: + return value + return None + get = getattr(headers, "get", None) + if callable(get): + return get(name) + return None + + +def _normalize_code(code: Any) -> str: + normalized = re.sub(r"[^a-z0-9]", "", str(code).lower()) + for suffix in ("exception", "error"): + if normalized.endswith(suffix): + return normalized[: -len(suffix)] + return normalized + + +def _to_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 763e6a0..6f56ff7 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,6 +24,7 @@ ) import uuid +from .error_utils import build_error_event_data from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, @@ -142,7 +143,11 @@ async def invoke_stream( logger.error(f"Agent 调用出错: {e}", exc_info=True) yield AgentEvent( event=EventType.ERROR, - data={"message": str(e), "code": type(e).__name__}, + data=build_error_event_data( + e, + fallback_code=type(e).__name__, + fallback_message=str(e), + ), ) def _process_user_event( diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 0e51714..a8bd11b 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -816,7 +816,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -828,6 +828,45 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("API rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "模型当前请求过多,请稍后再试" + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + + def test_on_llm_error_rate_limit_text_false_positive(self): + """测试说明性 rate limit/429 文本不会误判""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError( + "ticket 429 mentions rate limit dashboard, auth failed" + ), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "LLM_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 74933d1..491d4ff 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -814,7 +814,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -826,6 +826,45 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("API rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "模型当前请求过多,请稍后再试" + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + + def test_on_llm_error_rate_limit_text_false_positive(self): + """测试说明性 rate limit/429 文本不会误判""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError( + "ticket 429 mentions rate limit dashboard, auth failed" + ), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "LLM_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 0896a08..6dcf32a 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -101,6 +101,95 @@ def invoke_agent(request: AgentRequest): assert "RUN_ERROR" in types + @pytest.mark.asyncio + async def test_error_event_addition_fields_preserved(self): + """测试 RUN_ERROR 编码保留扩展字段""" + + def invoke_agent(request: AgentRequest): + yield AgentEvent( + event=EventType.ERROR, + data={ + "message": "模型当前请求过多,请稍后再试", + "code": "RATE_LIMITED", + "retryable": True, + "type": "BROKEN", + }, + addition={ + "retryAfterMs": 2000, + "traceId": "trace-xyz", + "type": "BROKEN", + }, + ) + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + assert "event: RUN_ERROR" in response.text + events = _agui_sse_events(response) + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert run_error["type"] == "RUN_ERROR" + assert run_error["code"] == "RATE_LIMITED" + assert run_error["retryable"] is True + assert run_error["retryAfterMs"] == 2000 + assert run_error["traceId"] == "trace-xyz" + + @pytest.mark.asyncio + async def test_rate_limit_error_stream_payload(self): + """测试 429 错误输出结构化 RUN_ERROR 且无 RUN_FINISHED""" + + class RateLimitError(RuntimeError): + status_code = 429 + + def invoke_agent(request: AgentRequest): + raise RateLimitError("model overloaded") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + assert "event: RUN_ERROR" in response.text + events = _agui_sse_events(response) + types = [event.get("type") for event in events] + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert "RUN_FINISHED" not in types + assert run_error["message"] == "模型当前请求过多,请稍后再试" + assert run_error["code"] == "RATE_LIMITED" + assert run_error["retryable"] is True + assert run_error["retryAfterMs"] == 2000 + + @pytest.mark.asyncio + async def test_non_rate_limit_error_stream_payload(self): + """测试普通错误不会被误标为限流""" + + def invoke_agent(request: AgentRequest): + raise RuntimeError("boom") + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + events = _agui_sse_events(response) + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert run_error["code"] == "RuntimeError" + assert "retryable" not in run_error + assert "retryAfterMs" not in run_error + @pytest.mark.asyncio async def test_exception_in_parse_request(self): """测试 parse_request 中的异常处理(覆盖 155-156 行) diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index 76d8c65..b3e97f3 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -188,6 +188,142 @@ async def invoke_agent(req: AgentRequest) -> str: assert "Test error" in error_event.data["message"] assert error_event.data["code"] == "ValueError" + @pytest.mark.asyncio + async def test_invoke_stream_rate_limit_error(self, req): + """测试模型限流错误被归一化""" + + class RateLimitError(RuntimeError): + status_code = 429 + trace_id = "trace-123" + + async def invoke_agent(req: AgentRequest) -> str: + raise RateLimitError("model overloaded") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["message"] == "模型当前请求过多,请稍后再试" + assert error_event.data["code"] == "RATE_LIMITED" + assert error_event.data["retryable"] is True + assert error_event.data["retryAfterMs"] == 2000 + assert error_event.data["traceId"] == "trace-123" + + @pytest.mark.asyncio + async def test_invoke_stream_response_rate_limit_error(self, req): + """测试 response.status_code=429 不被顶层非 429 状态掩盖""" + + class RateLimitError(RuntimeError): + status_code = 0 + code = "Other" + response = {"status_code": 429} + + async def invoke_agent(req: AgentRequest) -> str: + raise RateLimitError("model overloaded") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["code"] == "RATE_LIMITED" + assert error_event.data["retryable"] is True + + @pytest.mark.asyncio + async def test_invoke_stream_rate_limit_code_exception(self, req): + """测试带 Exception 后缀的限流错误码被识别""" + + class RateLimitError(RuntimeError): + code = "TooManyRequestsException" + + async def invoke_agent(req: AgentRequest) -> str: + raise RateLimitError("model overloaded") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["code"] == "RATE_LIMITED" + assert error_event.data["retryAfterMs"] == 2000 + + @pytest.mark.asyncio + async def test_invoke_stream_response_rate_limit_code(self, req): + """测试 response.code 不被顶层非限流 code 掩盖""" + + class RateLimitError(RuntimeError): + code = "Other" + response = {"code": "TooManyRequests"} + + async def invoke_agent(req: AgentRequest) -> str: + raise RateLimitError("model overloaded") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["code"] == "RATE_LIMITED" + assert error_event.data["retryAfterMs"] == 2000 + + @pytest.mark.asyncio + async def test_invoke_stream_rate_limit_snake_case_text(self, req): + """测试明确 snake_case 限流文本被识别""" + + async def invoke_agent(req: AgentRequest) -> str: + raise RuntimeError("rate_limit_exceeded: retry later") + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["code"] == "RATE_LIMITED" + assert error_event.data["retryAfterMs"] == 2000 + + @pytest.mark.asyncio + async def test_invoke_stream_rate_limit_text_false_positive(self, req): + """测试说明性 rate limit/429 文本不会被误判为限流""" + + async def invoke_agent(req: AgentRequest) -> str: + raise RuntimeError( + "ticket 429 mentions rate limit dashboard, auth failed" + ) + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + error_event = next( + item for item in items if item.event == EventType.ERROR + ) + assert error_event.data["code"] == "RuntimeError" + assert "retryable" not in error_event.data + assert "retryAfterMs" not in error_event.data + class TestInvokerSync: """同步调用测试""" From 89fbadbffa6d687cd8fd572d201920356f92a61e Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Thu, 25 Jun 2026 19:32:59 +0800 Subject: [PATCH 2/6] Make header lookup case-insensitive Normalize the requested header name before comparing so mixed-case callers still match HTTP headers case-insensitively. Constraint: Copilot review comment on PR #128 flagged _get_header as unexpectedly case-sensitive for mixed-case lookup names. Rejected: broader header abstraction | a one-line normalization preserves the current helper boundary. Confidence: high Scope-risk: narrow Tested: uv run --extra server pytest tests/unittests/server/test_error_utils.py tests/unittests/server/test_invoker.py tests/unittests/server/test_agui_protocol.py tests/unittests/integration/test_langgraph_events.py -q => 102 passed, 1 warning Tested: uv run --extra server pytest tests/unittests/integration/test_langgraph_to_agent_event.py -q => 32 passed Tested: git diff --check && uv run ruff check agentrun/server/error_utils.py tests/unittests/server/test_error_utils.py => passed Change-Id: I23e62a3b01e88c1b39f8669f89490b1c7f5e9ddf Co-developed-by: Codex Not-tested: full repository test suite and remote CI pending Signed-off-by: congxiao.wxx --- agentrun/server/error_utils.py | 3 ++- tests/unittests/server/test_error_utils.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 tests/unittests/server/test_error_utils.py diff --git a/agentrun/server/error_utils.py b/agentrun/server/error_utils.py index 60f9bb0..e0e6931 100644 --- a/agentrun/server/error_utils.py +++ b/agentrun/server/error_utils.py @@ -131,9 +131,10 @@ def _get_value(obj: Any, name: str) -> Optional[Any]: def _get_header(headers: Any, name: str) -> Optional[Any]: + target = str(name).lower() if isinstance(headers, dict): for key, value in headers.items(): - if str(key).lower() == name: + if str(key).lower() == target: return value return None get = getattr(headers, "get", None) diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py new file mode 100644 index 0000000..6b13b11 --- /dev/null +++ b/tests/unittests/server/test_error_utils.py @@ -0,0 +1,9 @@ +"""Tests for server error helpers.""" + +from agentrun.server.error_utils import _get_header + + +def test_get_header_matches_name_case_insensitively(): + headers = {"x-trace-id": "trace-123"} + + assert _get_header(headers, "X-Trace-ID") == "trace-123" From ace0a4a93ae5a151f8e186bb11f294bcc23cd1b4 Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Thu, 25 Jun 2026 20:55:07 +0800 Subject: [PATCH 3/6] Constrain rate-limit normalization to explicit signals Close the review gaps by moving the helper to a shared utils layer, preserving AG-UI encoder framing, limiting LangGraph normalization to LLM errors, and adding false-positive plus throttling regression coverage. Constraint: Aone 83566999 requires model rate limits to surface as stable AG-UI RUN_ERROR without leaking raw provider errors. Rejected: Matching generic HTTP/status/code 429 text | it misclassified explanatory non-rate-limit errors. Rejected: Hand-written RUN_ERROR SSE framing | it forked the AG-UI encoder contract. Confidence: high Scope-risk: narrow Directive: Keep future rate-limit text matching limited to explicit provider throttle/rate-limit semantics. Tested: uv run --extra server pytest tests/unittests/server/test_error_utils.py tests/unittests/server/test_invoker.py tests/unittests/server/test_agui_protocol.py tests/unittests/integration/test_langgraph_events.py tests/unittests/integration/test_langgraph_to_agent_event.py -q (143 passed, 1 warning); git diff --check; focused ruff; local uvicorn /ag-ui/agent harness for structured 429, throttling text, and false-positive code 429. Change-Id: Ie114d646380623dc8546f082cf2b0776035ccade Co-developed-by: Codex Not-tested: GitHub CI status due local gh auth/API limitations. Signed-off-by: congxiao.wxx --- .../integration/langgraph/agent_converter.py | 20 ++++---- agentrun/server/agui_protocol.py | 11 +++-- agentrun/server/invoker.py | 2 +- agentrun/{server => utils}/error_utils.py | 5 +- .../integration/test_langgraph_events.py | 46 +++++++++++++++++++ .../test_langgraph_to_agent_event.py | 46 +++++++++++++++++++ tests/unittests/server/test_agui_protocol.py | 4 +- tests/unittests/server/test_error_utils.py | 31 ++++++++++++- 8 files changed, 142 insertions(+), 23 deletions(-) rename agentrun/{server => utils}/error_utils.py (95%) diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index d69c8c9..f4487c5 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -29,7 +29,7 @@ import json from typing import Any, Dict, Iterator, List, Optional, Union -from agentrun.server.error_utils import ( +from agentrun.utils.error_utils import ( build_error_event_data, format_error_message, ) @@ -964,15 +964,14 @@ def _convert_astream_events_event( yield AgentResult( event=EventType.ERROR, - data=build_error_event_data( - error, - fallback_code="CHAIN_ERROR", - fallback_message=( + data={ + "message": ( f"Chain '{chain_name}' error: {error_message}" if chain_name else error_message ), - ), + "code": "CHAIN_ERROR", + }, ) # 9. Retriever 错误 @@ -983,15 +982,14 @@ def _convert_astream_events_event( yield AgentResult( event=EventType.ERROR, - data=build_error_event_data( - error, - fallback_code="RETRIEVER_ERROR", - fallback_message=( + data={ + "message": ( f"Retriever '{retriever_name}' error: {error_message}" if retriever_name else error_message ), - ), + "code": "RETRIEVER_ERROR", + }, ) # ========================================================================= diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index 70a6bb8..b6012e9 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -748,17 +748,18 @@ def _process_event_with_boundaries( message=event.data.get("message", ""), code=event.data.get("code"), ) - event_dict = agui_event.model_dump(by_alias=True, exclude_none=True) + extra_fields = {} for key in RUN_ERROR_EXTRA_FIELDS: value = event.data.get(key) if value is not None: - event_dict[key] = value + extra_fields[key] = value elif event.addition: value = event.addition.get(key) if value is not None: - event_dict[key] = value - json_str = json.dumps(event_dict, ensure_ascii=False) - yield f"event: RUN_ERROR\ndata: {json_str}\n\n" + extra_fields[key] = value + if extra_fields: + agui_event = agui_event.model_copy(update=extra_fields) + yield self._encoder.encode(agui_event) return # STATE 事件 diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 6f56ff7..cf78dc9 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,7 +24,7 @@ ) import uuid -from .error_utils import build_error_event_data +from agentrun.utils.error_utils import build_error_event_data from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, diff --git a/agentrun/server/error_utils.py b/agentrun/utils/error_utils.py similarity index 95% rename from agentrun/server/error_utils.py rename to agentrun/utils/error_utils.py index e0e6931..90fad9f 100644 --- a/agentrun/server/error_utils.py +++ b/agentrun/utils/error_utils.py @@ -1,4 +1,4 @@ -"""Error helpers for AgentRun server streams.""" +"""Error helpers for AgentRun event streams.""" import re from typing import Any, Dict, Optional @@ -25,9 +25,8 @@ re.IGNORECASE, ), re.compile(r"\bresource[-_\s]*throttled\b", re.IGNORECASE), - re.compile(r"\bthrottling(?:exception| exception| error)\b", re.IGNORECASE), re.compile( - r"\b(?:http|status|status code|code)\s*[:=]?\s*429\b", + r"\b(?:throttling|throttlingexception|throttled)\b", re.IGNORECASE, ), ] diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index a8bd11b..6eef62f 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -890,6 +890,29 @@ def test_on_chain_error(self): assert "KeyError" in results[0].data["message"] assert results[0].data["code"] == "CHAIN_ERROR" + def test_on_chain_error_status_429_keeps_chain_error(self): + """测试 chain 429 不套用模型限流语义""" + + class ChainStatusError(RuntimeError): + status_code = 429 + + event = { + "event": "on_chain_error", + "name": "agent_chain", + "run_id": "run_chain", + "data": { + "error": ChainStatusError("chain quota exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "CHAIN_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_on_retriever_error(self): """测试 on_retriever_error 事件 @@ -913,6 +936,29 @@ def test_on_retriever_error(self): assert "ConnectionError" in results[0].data["message"] assert results[0].data["code"] == "RETRIEVER_ERROR" + def test_on_retriever_error_status_429_keeps_retriever_error(self): + """测试 retriever 429 不套用模型限流语义""" + + class RetrieverStatusError(RuntimeError): + status_code = 429 + + event = { + "event": "on_retriever_error", + "name": "vector_store", + "run_id": "run_retriever", + "data": { + "error": RetrieverStatusError("retriever quota exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "RETRIEVER_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_tool_error_in_complete_flow(self): """测试完整流程中的工具错误 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 491d4ff..73e4d78 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -888,6 +888,29 @@ def test_on_chain_error(self): assert "KeyError" in results[0].data["message"] assert results[0].data["code"] == "CHAIN_ERROR" + def test_on_chain_error_status_429_keeps_chain_error(self): + """测试 chain 429 不套用模型限流语义""" + + class ChainStatusError(RuntimeError): + status_code = 429 + + event = { + "event": "on_chain_error", + "name": "agent_chain", + "run_id": "run_chain", + "data": { + "error": ChainStatusError("chain quota exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "CHAIN_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_on_retriever_error(self): """测试 on_retriever_error 事件 @@ -911,6 +934,29 @@ def test_on_retriever_error(self): assert "ConnectionError" in results[0].data["message"] assert results[0].data["code"] == "RETRIEVER_ERROR" + def test_on_retriever_error_status_429_keeps_retriever_error(self): + """测试 retriever 429 不套用模型限流语义""" + + class RetrieverStatusError(RuntimeError): + status_code = 429 + + event = { + "event": "on_retriever_error", + "name": "vector_store", + "run_id": "run_retriever", + "data": { + "error": RetrieverStatusError("retriever quota exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["code"] == "RETRIEVER_ERROR" + assert "retryable" not in results[0].data + assert "retryAfterMs" not in results[0].data + def test_tool_error_in_complete_flow(self): """测试完整流程中的工具错误 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 6dcf32a..c7aa123 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -128,7 +128,7 @@ def invoke_agent(request: AgentRequest): ) assert response.status_code == 200 - assert "event: RUN_ERROR" in response.text + assert "event: RUN_ERROR" not in response.text events = _agui_sse_events(response) run_error = next( event for event in events if event.get("type") == "RUN_ERROR" @@ -156,7 +156,7 @@ def invoke_agent(request: AgentRequest): ) assert response.status_code == 200 - assert "event: RUN_ERROR" in response.text + assert "event: RUN_ERROR" not in response.text events = _agui_sse_events(response) types = [event.get("type") for event in events] run_error = next( diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py index 6b13b11..ae56ef8 100644 --- a/tests/unittests/server/test_error_utils.py +++ b/tests/unittests/server/test_error_utils.py @@ -1,9 +1,38 @@ """Tests for server error helpers.""" -from agentrun.server.error_utils import _get_header +from agentrun.utils.error_utils import _get_header, is_rate_limited_error def test_get_header_matches_name_case_insensitively(): headers = {"x-trace-id": "trace-123"} assert _get_header(headers, "X-Trace-ID") == "trace-123" + + +def test_explanatory_code_429_text_is_not_rate_limited(): + error = RuntimeError("validation failed for field code 429") + + assert not is_rate_limited_error(error) + + +def test_explanatory_http_429_text_is_not_rate_limited(): + error = RuntimeError( + "docs mention HTTP 429 means rate limit; actual error is 401" + ) + + assert not is_rate_limited_error(error) + + +def test_explicit_throttling_text_is_rate_limited(): + assert is_rate_limited_error(RuntimeError("Throttling: model overloaded")) + + +def test_explicit_throttled_text_is_rate_limited(): + assert is_rate_limited_error(RuntimeError("request throttled by provider")) + + +def test_structured_status_429_is_rate_limited(): + class RateLimitError(RuntimeError): + status_code = 429 + + assert is_rate_limited_error(RateLimitError("model overloaded")) From 48ccc11ec42e108fe3168c8513cce3aca19345e8 Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Thu, 25 Jun 2026 21:03:16 +0800 Subject: [PATCH 4/6] Preserve model rate-limit error messages Constraint: Users need the original provider error text in AG-UI RUN_ERROR.message while still normalizing retry metadata.\nRejected: Keep fixed Chinese rate-limit copy | It hides the actionable upstream model error from clients.\nConfidence: high\nScope-risk: narrow\nDirective: Keep RATE_LIMITED code and retry fields stable, but do not replace message with generic copy.\nTested: uv run --extra server pytest tests/unittests/server/test_error_utils.py tests/unittests/server/test_invoker.py tests/unittests/server/test_agui_protocol.py tests/unittests/integration/test_langgraph_events.py tests/unittests/integration/test_langgraph_to_agent_event.py -q; uv run ruff check targeted files; git diff --check; uvicorn local harness for status 429/throttling/false-positive 429.\nNot-tested: GitHub CI not yet read back after this commit. Change-Id: I5a1b893d34dd936bf172ca980d3e7d3a496a60d4 Co-developed-by: Codex Signed-off-by: congxiao.wxx --- agentrun/utils/error_utils.py | 3 +-- tests/unittests/integration/test_langgraph_events.py | 2 +- tests/unittests/integration/test_langgraph_to_agent_event.py | 2 +- tests/unittests/server/test_agui_protocol.py | 2 +- tests/unittests/server/test_invoker.py | 2 +- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/agentrun/utils/error_utils.py b/agentrun/utils/error_utils.py index 90fad9f..e17a741 100644 --- a/agentrun/utils/error_utils.py +++ b/agentrun/utils/error_utils.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Optional RATE_LIMITED_CODE = "RATE_LIMITED" -RATE_LIMITED_MESSAGE = "模型当前请求过多,请稍后再试" RATE_LIMITED_RETRY_AFTER_MS = 2000 _RATE_LIMIT_CODES = { @@ -52,7 +51,7 @@ def build_error_event_data( return {"message": fallback_message, "code": fallback_code} data: Dict[str, Any] = { - "message": RATE_LIMITED_MESSAGE, + "message": fallback_message, "code": RATE_LIMITED_CODE, "retryable": True, "retryAfterMs": RATE_LIMITED_RETRY_AFTER_MS, diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 6eef62f..00503ea 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -842,7 +842,7 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "模型当前请求过多,请稍后再试" + assert results[0].data["message"] == "LLM error: RuntimeError: API rate limit exceeded" assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 73e4d78..d67ba9f 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -840,7 +840,7 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "模型当前请求过多,请稍后再试" + assert results[0].data["message"] == "LLM error: RuntimeError: API rate limit exceeded" assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index c7aa123..6d1b158 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -163,7 +163,7 @@ def invoke_agent(request: AgentRequest): event for event in events if event.get("type") == "RUN_ERROR" ) assert "RUN_FINISHED" not in types - assert run_error["message"] == "模型当前请求过多,请稍后再试" + assert run_error["message"] == "model overloaded" assert run_error["code"] == "RATE_LIMITED" assert run_error["retryable"] is True assert run_error["retryAfterMs"] == 2000 diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index b3e97f3..bb156be 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -208,7 +208,7 @@ async def invoke_agent(req: AgentRequest) -> str: error_event = next( item for item in items if item.event == EventType.ERROR ) - assert error_event.data["message"] == "模型当前请求过多,请稍后再试" + assert error_event.data["message"] == "model overloaded" assert error_event.data["code"] == "RATE_LIMITED" assert error_event.data["retryable"] is True assert error_event.data["retryAfterMs"] == 2000 From e666ee716cfd44c6231d535e5cda829ec5a17c91 Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Thu, 25 Jun 2026 21:15:36 +0800 Subject: [PATCH 5/6] Simplify model rate-limit run errors Constraint: Aone 83566999 only needs model 429 output to terminate as AG-UI RUN_ERROR without hardcoded user-facing copy.\nRejected: fixed localized message or broad error framework | original provider error text is required and a small regex/status-code guard is enough.\nConfidence: high\nScope-risk: narrow\nDirective: Keep RUN_ERROR.message sourced from the original error text; do not introduce generic friendly copy.\nTested: uv run --extra server pytest targeted files -q; uv run ruff check targeted source/server-test files; git diff --check; uvicorn local harness for text 429, structured 429, and normal text.\nNot-tested: GitHub CI final status and DCO fix; earlier pushed commits still require sign-off remediation. Signed-off-by: congxiao.wxx Change-Id: I248d3f29cfd95560edc48f22d2bd014cc5762cef Co-developed-by: Codex Signed-off-by: congxiao.wxx --- .../integration/langgraph/agent_converter.py | 37 ++++- agentrun/server/agui_protocol.py | 14 +- agentrun/server/invoker.py | 51 ++++--- agentrun/utils/error_utils.py | 133 ++++------------- .../integration/test_langgraph_events.py | 72 +--------- .../test_langgraph_to_agent_event.py | 72 +--------- tests/unittests/server/test_agui_protocol.py | 72 +--------- tests/unittests/server/test_error_utils.py | 52 ++++--- tests/unittests/server/test_invoker.py | 134 ++---------------- 9 files changed, 134 insertions(+), 503 deletions(-) diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index f4487c5..8f1d069 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -29,11 +29,11 @@ import json from typing import Any, Dict, Iterator, List, Optional, Union +from agentrun.server.model import AgentResult, EventType from agentrun.utils.error_utils import ( build_error_event_data, - format_error_message, + is_rate_limited_error, ) -from agentrun.server.model import AgentResult, EventType from agentrun.utils.log import logger # 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象) @@ -945,14 +945,25 @@ def _convert_astream_events_event( # 7. LLM 错误 elif event_type == "on_llm_error": error = data.get("error") - error_message = format_error_message(error) + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) yield AgentResult( event=EventType.ERROR, data=build_error_event_data( error, fallback_code="LLM_ERROR", - fallback_message=f"LLM error: {error_message}", + fallback_message=( + error_message + if is_rate_limited_error(error) + else f"LLM error: {error_message}" + ), ), ) @@ -960,7 +971,14 @@ def _convert_astream_events_event( elif event_type == "on_chain_error": error = data.get("error") chain_name = event_dict.get("name", "") - error_message = format_error_message(error) + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) yield AgentResult( event=EventType.ERROR, @@ -978,7 +996,14 @@ def _convert_astream_events_event( elif event_type == "on_retriever_error": error = data.get("error") retriever_name = event_dict.get("name", "") - error_message = format_error_message(error) + error_message = "" + if error is not None: + if isinstance(error, Exception): + error_message = f"{type(error).__name__}: {str(error)}" + elif isinstance(error, str): + error_message = error + else: + error_message = str(error) yield AgentResult( event=EventType.ERROR, diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index b6012e9..a9d6d43 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -748,15 +748,11 @@ def _process_event_with_boundaries( message=event.data.get("message", ""), code=event.data.get("code"), ) - extra_fields = {} - for key in RUN_ERROR_EXTRA_FIELDS: - value = event.data.get(key) - if value is not None: - extra_fields[key] = value - elif event.addition: - value = event.addition.get(key) - if value is not None: - extra_fields[key] = value + extra_fields = { + key: event.data[key] + for key in RUN_ERROR_EXTRA_FIELDS + if key in event.data + } if extra_fields: agui_event = agui_event.model_copy(update=extra_fields) yield self._encoder.encode(agui_event) diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index cf78dc9..26f8956 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,7 +24,10 @@ ) import uuid -from agentrun.utils.error_utils import build_error_event_data +from agentrun.utils.error_utils import ( + build_error_event_data, + is_rate_limited_error, +) from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, @@ -118,10 +121,7 @@ async def invoke_stream( if isinstance(item, str): if not item: # 跳过空字符串 continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): # 处理用户返回的事件 @@ -232,12 +232,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: return results if isinstance(result, str): - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": result}, - ) - ) + results.append(self._wrap_text(result)) elif isinstance(result, AgentEvent): # 处理可能的 TOOL_CALL 展开 @@ -248,12 +243,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: if isinstance(item, AgentEvent): results.extend(self._process_user_event(item)) elif isinstance(item, str) and item: - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) - ) + results.append(self._wrap_text(item)) else: results.extend(self._wrap_model_chunk(item)) @@ -280,10 +270,7 @@ async def _wrap_stream( if isinstance(item, str): if not item: continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): for processed_event in self._process_user_event(item): @@ -351,15 +338,25 @@ def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]: content = self._read_attr_or_key(item, "content") if isinstance(content, str) and content: - events.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": content}, - ) - ) + events.append(self._wrap_text(content)) return events + def _wrap_text(self, text: str) -> AgentEvent: + if is_rate_limited_error(text): + return AgentEvent( + event=EventType.ERROR, + data=build_error_event_data( + text, + fallback_code=type(text).__name__, + fallback_message=text, + ), + ) + return AgentEvent( + event=EventType.TEXT, + data={"delta": text}, + ) + def _read_attr_or_key(self, obj: Any, key: str) -> Any: if isinstance(obj, dict): return obj.get(key) diff --git a/agentrun/utils/error_utils.py b/agentrun/utils/error_utils.py index e17a741..f7ebee6 100644 --- a/agentrun/utils/error_utils.py +++ b/agentrun/utils/error_utils.py @@ -1,44 +1,21 @@ -"""Error helpers for AgentRun event streams.""" +"""Small helpers for model rate-limit errors.""" import re from typing import Any, Dict, Optional RATE_LIMITED_CODE = "RATE_LIMITED" RATE_LIMITED_RETRY_AFTER_MS = 2000 - +_RATE_LIMIT_TEXT_RE = re.compile( + r"429|too[-_\s]*many[-_\s]*requests|rate[-_\s]*limit|throttl", + re.IGNORECASE, +) _RATE_LIMIT_CODES = { "ratelimitexceeded", "ratelimited", - "resourcethrottled", "throttling", - "throttlingquota", - "throttlingratequota", - "throttlingexception", "toomanyrequests", } -_RATE_LIMIT_TEXT_PATTERNS = [ - re.compile(r"\btoo[-_\s]*many[-_\s]*requests\b", re.IGNORECASE), - re.compile( - r"\brate[-_\s]*limit(?:ed|[-_\s]*exceeded)\b", - re.IGNORECASE, - ), - re.compile(r"\bresource[-_\s]*throttled\b", re.IGNORECASE), - re.compile( - r"\b(?:throttling|throttlingexception|throttled)\b", - re.IGNORECASE, - ), -] - - -def format_error_message(error: Any) -> str: - """Format errors consistently with existing LangGraph conversion.""" - if error is None: - return "" - if isinstance(error, Exception): - return f"{type(error).__name__}: {str(error)}" - return str(error) - def build_error_event_data( error: Any, @@ -46,7 +23,7 @@ def build_error_event_data( fallback_code: str, fallback_message: str, ) -> Dict[str, Any]: - """Build AgentEvent ERROR data, normalizing model rate limits.""" + """Keep the original message; add rate-limit metadata only when matched.""" if not is_rate_limited_error(error): return {"message": fallback_message, "code": fallback_code} @@ -56,70 +33,42 @@ def build_error_event_data( "retryable": True, "retryAfterMs": RATE_LIMITED_RETRY_AFTER_MS, } - trace_id = _extract_trace_id(error) + trace_id = _get_value(error, "trace_id") or _get_value(error, "traceId") if trace_id: data["traceId"] = str(trace_id) return data def is_rate_limited_error(error: Any) -> bool: - """Return whether an error carries an explicit rate-limit signal.""" - if _extract_status_code(error) == 429: + if error is None: + return False + if _status_code(error) == 429 or _status_code(_get_value(error, "response")) == 429: return True - - if _has_rate_limit_code(error): + if _rate_limit_code(error) or _rate_limit_code(_get_value(error, "response")): return True - - message = str(error or "") - return any(pattern.search(message) for pattern in _RATE_LIMIT_TEXT_PATTERNS) + return bool(_RATE_LIMIT_TEXT_RE.search(str(error))) -def _extract_status_code(error: Any) -> Optional[int]: - fallback = None - for obj in (error, _get_value(error, "response")): - if obj is None: +def _status_code(obj: Any) -> Optional[int]: + for name in ("status_code", "status", "http_status", "statusCode"): + value = _get_value(obj, name) + if value is None: continue - for name in ("status_code", "status", "http_status", "statusCode"): - status_code = _to_int(_get_value(obj, name)) - if status_code == 429: - return status_code - if fallback is None and status_code is not None: - fallback = status_code - return fallback + try: + return int(value) + except (TypeError, ValueError): + return None + return None -def _has_rate_limit_code(error: Any) -> bool: - for obj in (error, _get_value(error, "response")): - if obj is None: - continue - for name in ("code", "error_code", "errorCode"): - error_code = _get_value(obj, name) - if ( - error_code is not None - and _normalize_code(error_code) in _RATE_LIMIT_CODES - ): - return True +def _rate_limit_code(obj: Any) -> bool: + for name in ("code", "error_code", "errorCode"): + code = _get_value(obj, name) + if code and _normalize_code(code) in _RATE_LIMIT_CODES: + return True return False -def _extract_trace_id(error: Any) -> Optional[Any]: - for name in ("trace_id", "traceId", "request_id", "requestId"): - trace_id = _get_value(error, name) - if trace_id: - return trace_id - - response = _get_value(error, "response") - headers = _get_value(response, "headers") - if not headers: - return None - - for name in ("x-acs-request-id", "x-request-id", "x-trace-id", "trace-id"): - trace_id = _get_header(headers, name) - if trace_id: - return trace_id - return None - - def _get_value(obj: Any, name: str) -> Optional[Any]: if obj is None: return None @@ -128,31 +77,9 @@ def _get_value(obj: Any, name: str) -> Optional[Any]: return getattr(obj, name, None) -def _get_header(headers: Any, name: str) -> Optional[Any]: - target = str(name).lower() - if isinstance(headers, dict): - for key, value in headers.items(): - if str(key).lower() == target: - return value - return None - get = getattr(headers, "get", None) - if callable(get): - return get(name) - return None - - def _normalize_code(code: Any) -> str: - normalized = re.sub(r"[^a-z0-9]", "", str(code).lower()) + value = "".join(ch for ch in str(code).lower() if ch.isalnum()) for suffix in ("exception", "error"): - if normalized.endswith(suffix): - return normalized[: -len(suffix)] - return normalized - - -def _to_int(value: Any) -> Optional[int]: - if value is None: - return None - try: - return int(value) - except (TypeError, ValueError): - return None + if value.endswith(suffix): + return value[: -len(suffix)] + return value diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 00503ea..7d88111 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -829,12 +829,12 @@ def test_on_llm_error(self): assert results[0].data["code"] == "LLM_ERROR" def test_on_llm_error_rate_limited(self): - """测试 on_llm_error 限流错误归一化""" + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" event = { "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Error code: 429 - rate limit exceeded"), }, } @@ -842,31 +842,11 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "LLM error: RuntimeError: API rate limit exceeded" + assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 - def test_on_llm_error_rate_limit_text_false_positive(self): - """测试说明性 rate limit/429 文本不会误判""" - event = { - "event": "on_llm_error", - "run_id": "run_llm", - "data": { - "error": RuntimeError( - "ticket 429 mentions rate limit dashboard, auth failed" - ), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "LLM_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_on_chain_error(self): """测试 on_chain_error 事件 @@ -890,29 +870,6 @@ def test_on_chain_error(self): assert "KeyError" in results[0].data["message"] assert results[0].data["code"] == "CHAIN_ERROR" - def test_on_chain_error_status_429_keeps_chain_error(self): - """测试 chain 429 不套用模型限流语义""" - - class ChainStatusError(RuntimeError): - status_code = 429 - - event = { - "event": "on_chain_error", - "name": "agent_chain", - "run_id": "run_chain", - "data": { - "error": ChainStatusError("chain quota exceeded"), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "CHAIN_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_on_retriever_error(self): """测试 on_retriever_error 事件 @@ -936,29 +893,6 @@ def test_on_retriever_error(self): assert "ConnectionError" in results[0].data["message"] assert results[0].data["code"] == "RETRIEVER_ERROR" - def test_on_retriever_error_status_429_keeps_retriever_error(self): - """测试 retriever 429 不套用模型限流语义""" - - class RetrieverStatusError(RuntimeError): - status_code = 429 - - event = { - "event": "on_retriever_error", - "name": "vector_store", - "run_id": "run_retriever", - "data": { - "error": RetrieverStatusError("retriever quota exceeded"), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "RETRIEVER_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_tool_error_in_complete_flow(self): """测试完整流程中的工具错误 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index d67ba9f..19bbdba 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -827,12 +827,12 @@ def test_on_llm_error(self): assert results[0].data["code"] == "LLM_ERROR" def test_on_llm_error_rate_limited(self): - """测试 on_llm_error 限流错误归一化""" + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" event = { "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Error code: 429 - rate limit exceeded"), }, } @@ -840,31 +840,11 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "LLM error: RuntimeError: API rate limit exceeded" + assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 - def test_on_llm_error_rate_limit_text_false_positive(self): - """测试说明性 rate limit/429 文本不会误判""" - event = { - "event": "on_llm_error", - "run_id": "run_llm", - "data": { - "error": RuntimeError( - "ticket 429 mentions rate limit dashboard, auth failed" - ), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "LLM_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_on_chain_error(self): """测试 on_chain_error 事件 @@ -888,29 +868,6 @@ def test_on_chain_error(self): assert "KeyError" in results[0].data["message"] assert results[0].data["code"] == "CHAIN_ERROR" - def test_on_chain_error_status_429_keeps_chain_error(self): - """测试 chain 429 不套用模型限流语义""" - - class ChainStatusError(RuntimeError): - status_code = 429 - - event = { - "event": "on_chain_error", - "name": "agent_chain", - "run_id": "run_chain", - "data": { - "error": ChainStatusError("chain quota exceeded"), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "CHAIN_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_on_retriever_error(self): """测试 on_retriever_error 事件 @@ -934,29 +891,6 @@ def test_on_retriever_error(self): assert "ConnectionError" in results[0].data["message"] assert results[0].data["code"] == "RETRIEVER_ERROR" - def test_on_retriever_error_status_429_keeps_retriever_error(self): - """测试 retriever 429 不套用模型限流语义""" - - class RetrieverStatusError(RuntimeError): - status_code = 429 - - event = { - "event": "on_retriever_error", - "name": "vector_store", - "run_id": "run_retriever", - "data": { - "error": RetrieverStatusError("retriever quota exceeded"), - }, - } - - results = list(AgentRunConverter().to_agui_events(event)) - - assert len(results) == 1 - assert results[0].event == EventType.ERROR - assert results[0].data["code"] == "RETRIEVER_ERROR" - assert "retryable" not in results[0].data - assert "retryAfterMs" not in results[0].data - def test_tool_error_in_complete_flow(self): """测试完整流程中的工具错误 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 6d1b158..3188aab 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -102,52 +102,11 @@ def invoke_agent(request: AgentRequest): assert "RUN_ERROR" in types @pytest.mark.asyncio - async def test_error_event_addition_fields_preserved(self): - """测试 RUN_ERROR 编码保留扩展字段""" + async def test_text_rate_limit_error_stream_payload(self): + """测试文本形式的 429 错误输出 RUN_ERROR 且无 RUN_FINISHED""" def invoke_agent(request: AgentRequest): - yield AgentEvent( - event=EventType.ERROR, - data={ - "message": "模型当前请求过多,请稍后再试", - "code": "RATE_LIMITED", - "retryable": True, - "type": "BROKEN", - }, - addition={ - "retryAfterMs": 2000, - "traceId": "trace-xyz", - "type": "BROKEN", - }, - ) - - client = self.get_client(invoke_agent) - response = client.post( - "/ag-ui/agent", - json={"messages": [{"role": "user", "content": "Hello"}]}, - ) - - assert response.status_code == 200 - assert "event: RUN_ERROR" not in response.text - events = _agui_sse_events(response) - run_error = next( - event for event in events if event.get("type") == "RUN_ERROR" - ) - assert run_error["type"] == "RUN_ERROR" - assert run_error["code"] == "RATE_LIMITED" - assert run_error["retryable"] is True - assert run_error["retryAfterMs"] == 2000 - assert run_error["traceId"] == "trace-xyz" - - @pytest.mark.asyncio - async def test_rate_limit_error_stream_payload(self): - """测试 429 错误输出结构化 RUN_ERROR 且无 RUN_FINISHED""" - - class RateLimitError(RuntimeError): - status_code = 429 - - def invoke_agent(request: AgentRequest): - raise RateLimitError("model overloaded") + return "Error code: 429 - rate limit exceeded" client = self.get_client(invoke_agent) response = client.post( @@ -156,40 +115,17 @@ def invoke_agent(request: AgentRequest): ) assert response.status_code == 200 - assert "event: RUN_ERROR" not in response.text events = _agui_sse_events(response) types = [event.get("type") for event in events] run_error = next( event for event in events if event.get("type") == "RUN_ERROR" ) assert "RUN_FINISHED" not in types - assert run_error["message"] == "model overloaded" + assert run_error["message"] == "Error code: 429 - rate limit exceeded" assert run_error["code"] == "RATE_LIMITED" assert run_error["retryable"] is True assert run_error["retryAfterMs"] == 2000 - @pytest.mark.asyncio - async def test_non_rate_limit_error_stream_payload(self): - """测试普通错误不会被误标为限流""" - - def invoke_agent(request: AgentRequest): - raise RuntimeError("boom") - - client = self.get_client(invoke_agent) - response = client.post( - "/ag-ui/agent", - json={"messages": [{"role": "user", "content": "Hello"}]}, - ) - - assert response.status_code == 200 - events = _agui_sse_events(response) - run_error = next( - event for event in events if event.get("type") == "RUN_ERROR" - ) - assert run_error["code"] == "RuntimeError" - assert "retryable" not in run_error - assert "retryAfterMs" not in run_error - @pytest.mark.asyncio async def test_exception_in_parse_request(self): """测试 parse_request 中的异常处理(覆盖 155-156 行) diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py index ae56ef8..0299fc1 100644 --- a/tests/unittests/server/test_error_utils.py +++ b/tests/unittests/server/test_error_utils.py @@ -1,38 +1,36 @@ -"""Tests for server error helpers.""" +"""Tests for model rate-limit helpers.""" -from agentrun.utils.error_utils import _get_header, is_rate_limited_error +from agentrun.utils.error_utils import ( + build_error_event_data, + is_rate_limited_error, +) -def test_get_header_matches_name_case_insensitively(): - headers = {"x-trace-id": "trace-123"} +def test_text_429_rate_limit_is_rate_limited(): + assert is_rate_limited_error("Error code: 429 - rate limit exceeded") - assert _get_header(headers, "X-Trace-ID") == "trace-123" - - -def test_explanatory_code_429_text_is_not_rate_limited(): - error = RuntimeError("validation failed for field code 429") - - assert not is_rate_limited_error(error) - - -def test_explanatory_http_429_text_is_not_rate_limited(): - error = RuntimeError( - "docs mention HTTP 429 means rate limit; actual error is 401" - ) - - assert not is_rate_limited_error(error) +def test_structured_status_429_is_rate_limited(): + class RateLimitError(RuntimeError): + status_code = 429 -def test_explicit_throttling_text_is_rate_limited(): - assert is_rate_limited_error(RuntimeError("Throttling: model overloaded")) + assert is_rate_limited_error(RateLimitError("provider overloaded")) -def test_explicit_throttled_text_is_rate_limited(): - assert is_rate_limited_error(RuntimeError("request throttled by provider")) +def test_non_rate_limit_text_is_not_rate_limited(): + assert not is_rate_limited_error("normal response") -def test_structured_status_429_is_rate_limited(): - class RateLimitError(RuntimeError): - status_code = 429 +def test_rate_limit_event_uses_original_message(): + data = build_error_event_data( + "Error code: 429 - rate limit exceeded", + fallback_code="str", + fallback_message="Error code: 429 - rate limit exceeded", + ) - assert is_rate_limited_error(RateLimitError("model overloaded")) + assert data == { + "message": "Error code: 429 - rate limit exceeded", + "code": "RATE_LIMITED", + "retryable": True, + "retryAfterMs": 2000, + } diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index bb156be..8337db3 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -189,64 +189,11 @@ async def invoke_agent(req: AgentRequest) -> str: assert error_event.data["code"] == "ValueError" @pytest.mark.asyncio - async def test_invoke_stream_rate_limit_error(self, req): - """测试模型限流错误被归一化""" - - class RateLimitError(RuntimeError): - status_code = 429 - trace_id = "trace-123" - - async def invoke_agent(req: AgentRequest) -> str: - raise RateLimitError("model overloaded") - - invoker = AgentInvoker(invoke_agent) - - items: List[AgentEvent] = [] - async for item in invoker.invoke_stream(req): - items.append(item) - - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["message"] == "model overloaded" - assert error_event.data["code"] == "RATE_LIMITED" - assert error_event.data["retryable"] is True - assert error_event.data["retryAfterMs"] == 2000 - assert error_event.data["traceId"] == "trace-123" - - @pytest.mark.asyncio - async def test_invoke_stream_response_rate_limit_error(self, req): - """测试 response.status_code=429 不被顶层非 429 状态掩盖""" - - class RateLimitError(RuntimeError): - status_code = 0 - code = "Other" - response = {"status_code": 429} - - async def invoke_agent(req: AgentRequest) -> str: - raise RateLimitError("model overloaded") - - invoker = AgentInvoker(invoke_agent) - - items: List[AgentEvent] = [] - async for item in invoker.invoke_stream(req): - items.append(item) - - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["code"] == "RATE_LIMITED" - assert error_event.data["retryable"] is True - - @pytest.mark.asyncio - async def test_invoke_stream_rate_limit_code_exception(self, req): - """测试带 Exception 后缀的限流错误码被识别""" - - class RateLimitError(RuntimeError): - code = "TooManyRequestsException" + async def test_invoke_stream_text_rate_limit_error(self, req): + """测试字符串形式的模型限流错误被转成 ERROR""" async def invoke_agent(req: AgentRequest) -> str: - raise RateLimitError("model overloaded") + return "Error code: 429 - rate limit exceeded" invoker = AgentInvoker(invoke_agent) @@ -254,75 +201,12 @@ async def invoke_agent(req: AgentRequest) -> str: async for item in invoker.invoke_stream(req): items.append(item) - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["code"] == "RATE_LIMITED" - assert error_event.data["retryAfterMs"] == 2000 - - @pytest.mark.asyncio - async def test_invoke_stream_response_rate_limit_code(self, req): - """测试 response.code 不被顶层非限流 code 掩盖""" - - class RateLimitError(RuntimeError): - code = "Other" - response = {"code": "TooManyRequests"} - - async def invoke_agent(req: AgentRequest) -> str: - raise RateLimitError("model overloaded") - - invoker = AgentInvoker(invoke_agent) - - items: List[AgentEvent] = [] - async for item in invoker.invoke_stream(req): - items.append(item) - - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["code"] == "RATE_LIMITED" - assert error_event.data["retryAfterMs"] == 2000 - - @pytest.mark.asyncio - async def test_invoke_stream_rate_limit_snake_case_text(self, req): - """测试明确 snake_case 限流文本被识别""" - - async def invoke_agent(req: AgentRequest) -> str: - raise RuntimeError("rate_limit_exceeded: retry later") - - invoker = AgentInvoker(invoke_agent) - - items: List[AgentEvent] = [] - async for item in invoker.invoke_stream(req): - items.append(item) - - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["code"] == "RATE_LIMITED" - assert error_event.data["retryAfterMs"] == 2000 - - @pytest.mark.asyncio - async def test_invoke_stream_rate_limit_text_false_positive(self, req): - """测试说明性 rate limit/429 文本不会被误判为限流""" - - async def invoke_agent(req: AgentRequest) -> str: - raise RuntimeError( - "ticket 429 mentions rate limit dashboard, auth failed" - ) - - invoker = AgentInvoker(invoke_agent) - - items: List[AgentEvent] = [] - async for item in invoker.invoke_stream(req): - items.append(item) - - error_event = next( - item for item in items if item.event == EventType.ERROR - ) - assert error_event.data["code"] == "RuntimeError" - assert "retryable" not in error_event.data - assert "retryAfterMs" not in error_event.data + assert len(items) == 1 + assert items[0].event == EventType.ERROR + assert items[0].data["message"] == "Error code: 429 - rate limit exceeded" + assert items[0].data["code"] == "RATE_LIMITED" + assert items[0].data["retryable"] is True + assert items[0].data["retryAfterMs"] == 2000 class TestInvokerSync: From f15513c493954e0d55bb9eb86ac40acebd7195d7 Mon Sep 17 00:00:00 2001 From: "congxiao.wxx" Date: Tue, 30 Jun 2026 14:52:30 +0800 Subject: [PATCH 6/6] Surface model provider errors as run errors Extend the existing model-error boundary beyond rate limits so known Bailian provider failures are emitted as AG-UI RUN_ERROR with original message text and optional diagnostic metadata. Constraint: Aone 83711637 asks model-side stream errors to surface as run errors instead of text messages. Rejected: status-code-only text promotion | generic assistant text can mention HTTP/status codes and must not terminate the stream. Rejected: broad provider error framework | the current PR only needs a compact SDK-side classifier for known Bailian model errors. Confidence: high Scope-risk: moderate Directive: Keep raw-text model-error patterns narrow; add negative tests whenever expanding provider text matching. Tested: uv run ruff check changed files; uv run pytest tests/unittests/server tests/unittests/integration => 781 passed, 2 skipped; uv run mypy --config-file mypy.ini . => 377 source files clean; independent code-reviewer APPROVE; architect WATCH with no blockers after heuristic tightening. Change-Id: Ibf1cbd9505b7064cbebd804383710fd7333455cd Not-tested: GitHub CI and MR review-comment readback are pending because push/comment touch shared remote state. Signed-off-by: congxiao.wxx --- .../integration/langgraph/agent_converter.py | 7 +- agentrun/server/agui_protocol.py | 9 +- agentrun/server/invoker.py | 10 +- agentrun/utils/error_utils.py | 324 ++++++++++++++++-- .../integration/test_langgraph_events.py | 38 +- .../test_langgraph_to_agent_event.py | 34 +- tests/unittests/server/test_agui_protocol.py | 28 ++ tests/unittests/server/test_error_utils.py | 138 +++++++- tests/unittests/server/test_invoker.py | 46 ++- 9 files changed, 579 insertions(+), 55 deletions(-) diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index 8f1d069..04cdd4f 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -30,10 +30,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union from agentrun.server.model import AgentResult, EventType -from agentrun.utils.error_utils import ( - build_error_event_data, - is_rate_limited_error, -) +from agentrun.utils.error_utils import build_error_event_data, is_model_error from agentrun.utils.log import logger # 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象) @@ -961,7 +958,7 @@ def _convert_astream_events_event( fallback_code="LLM_ERROR", fallback_message=( error_message - if is_rate_limited_error(error) + if is_model_error(error) else f"LLM error: {error_message}" ), ), diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index a9d6d43..4c891c3 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -53,7 +53,14 @@ # ============================================================================ DEFAULT_PREFIX = "/ag-ui/agent" -RUN_ERROR_EXTRA_FIELDS = ("retryable", "retryAfterMs", "traceId") +RUN_ERROR_EXTRA_FIELDS = ( + "retryable", + "retryAfterMs", + "traceId", + "requestId", + "statusCode", + "providerCode", +) @dataclass diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 26f8956..a34bb7a 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,17 +24,15 @@ ) import uuid -from agentrun.utils.error_utils import ( - build_error_event_data, - is_rate_limited_error, -) +from agentrun.utils.error_utils import build_error_event_data, is_model_error +from agentrun.utils.reasoning import get_reasoning_content + from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, InvokeAgentHandler, SyncInvokeAgentHandler, ) -from agentrun.utils.reasoning import get_reasoning_content class AgentInvoker: @@ -343,7 +341,7 @@ def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]: return events def _wrap_text(self, text: str) -> AgentEvent: - if is_rate_limited_error(text): + if is_model_error(text): return AgentEvent( event=EventType.ERROR, data=build_error_event_data( diff --git a/agentrun/utils/error_utils.py b/agentrun/utils/error_utils.py index f7ebee6..3754ec1 100644 --- a/agentrun/utils/error_utils.py +++ b/agentrun/utils/error_utils.py @@ -1,20 +1,143 @@ -"""Small helpers for model rate-limit errors.""" +"""Small helpers for model-side errors.""" +from dataclasses import dataclass import re -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterator, Optional, Tuple RATE_LIMITED_CODE = "RATE_LIMITED" RATE_LIMITED_RETRY_AFTER_MS = 2000 +_KNOWN_STATUS_CODES = (400, 401, 403, 429, 500, 503) +_HTTP_STATUS_TEXT_RE = re.compile( + r"(?:error\s+code|status\s+code|http\s+status|http)" + r"\D{0,20}(400|401|403|429|500|503)\b" + r"|\b(400|401|403|429|500|503)\s*(?:[-:—]|$)", + re.IGNORECASE, +) _RATE_LIMIT_TEXT_RE = re.compile( r"429|too[-_\s]*many[-_\s]*requests|rate[-_\s]*limit|throttl", re.IGNORECASE, ) -_RATE_LIMIT_CODES = { - "ratelimitexceeded", - "ratelimited", - "throttling", - "toomanyrequests", + + +@dataclass(frozen=True) +class _ModelErrorSpec: + code: str + status_code: int + retryable: bool = False + retry_after_ms: Optional[int] = None + + +@dataclass(frozen=True) +class ModelErrorInfo: + code: str + status_code: Optional[int] = None + provider_code: Optional[str] = None + retryable: bool = False + retry_after_ms: Optional[int] = None + + +_RATE_LIMIT_SPEC = _ModelErrorSpec( + RATE_LIMITED_CODE, + 429, + retryable=True, + retry_after_ms=RATE_LIMITED_RETRY_AFTER_MS, +) +_MODEL_ERROR_SPECS = { + "arrearage": _ModelErrorSpec("MODEL_ARREARAGE", 400), + "datainspectionfailed": _ModelErrorSpec( + "MODEL_DATA_INSPECTION_FAILED", 400 + ), + "invalidapikey": _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + "authentication": _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + "accessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "accessdeniedunpurchased": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "modelaccessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "workspaceaccessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "throttling": _RATE_LIMIT_SPEC, + "throttlingratequota": _RATE_LIMIT_SPEC, + "throttlingallocationquota": _RATE_LIMIT_SPEC, + "limitrequests": _RATE_LIMIT_SPEC, + "ratelimit": _RATE_LIMIT_SPEC, + "ratelimitexceeded": _RATE_LIMIT_SPEC, + "ratelimited": _RATE_LIMIT_SPEC, + "toomanyrequests": _RATE_LIMIT_SPEC, + "internal": _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + "internalerroralgo": _ModelErrorSpec( + "MODEL_INTERNAL_ERROR", 500, retryable=True + ), + "system": _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + "modelserving": _ModelErrorSpec( + "MODEL_SERVICE_UNAVAILABLE", 503, retryable=True + ), + "serviceunavailable": _ModelErrorSpec( + "MODEL_SERVICE_UNAVAILABLE", 503, retryable=True + ), +} +_STATUS_CODE_SPECS = { + 400: _ModelErrorSpec("MODEL_BAD_REQUEST", 400), + 401: _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + 403: _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + 429: _RATE_LIMIT_SPEC, + 500: _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + 503: _ModelErrorSpec("MODEL_SERVICE_UNAVAILABLE", 503, retryable=True), } +_TEXT_PROVIDER_PATTERNS: Tuple[Tuple[re.Pattern[str], Optional[str]], ...] = ( + (re.compile(r"\bArrearage\b", re.IGNORECASE), "Arrearage"), + ( + re.compile( + r"access\s+denied.*account.*good\s+standing", + re.IGNORECASE, + ), + "Arrearage", + ), + ( + re.compile( + r"\bDataInspectionFailed\b|\bdata[_-]inspection[_-]failed\b", + re.IGNORECASE, + ), + "DataInspectionFailed", + ), + (re.compile(r"\bInvalidApiKey\b", re.IGNORECASE), None), + ( + re.compile(r"\bAuthenticationError\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bModel\.AccessDenied\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bAccessDenied\.Unpurchased\b", re.IGNORECASE), + None, + ), + (re.compile(r"\bAccessDenied\b", re.IGNORECASE), None), + ( + re.compile(r"\bWorkspaceAccessDenied\b", re.IGNORECASE), + None, + ), + ( + re.compile( + r"\bThrottling(?:\.(?:RateQuota|AllocationQuota))?\b", + re.IGNORECASE, + ), + None, + ), + (re.compile(r"\bLimitRequests\b", re.IGNORECASE), None), + (re.compile(r"\bRateLimit\b", re.IGNORECASE), None), + ( + re.compile(r"\bInternalError(?:\.Algo)?\b", re.IGNORECASE), + None, + ), + (re.compile(r"\bSystemError\b", re.IGNORECASE), None), + ( + re.compile(r"\bModelServingError\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bServiceUnavailable\b", re.IGNORECASE), + None, + ), +) def build_error_event_data( @@ -23,60 +146,193 @@ def build_error_event_data( fallback_code: str, fallback_message: str, ) -> Dict[str, Any]: - """Keep the original message; add rate-limit metadata only when matched.""" - if not is_rate_limited_error(error): + """Keep the original message; add model-error metadata when matched.""" + model_error = classify_model_error(error) + if model_error is None: return {"message": fallback_message, "code": fallback_code} data: Dict[str, Any] = { "message": fallback_message, - "code": RATE_LIMITED_CODE, - "retryable": True, - "retryAfterMs": RATE_LIMITED_RETRY_AFTER_MS, + "code": model_error.code, } - trace_id = _get_value(error, "trace_id") or _get_value(error, "traceId") + if model_error.retryable: + data["retryable"] = True + if model_error.retry_after_ms is not None: + data["retryAfterMs"] = model_error.retry_after_ms + if model_error.status_code is not None: + data["statusCode"] = model_error.status_code + if model_error.provider_code: + data["providerCode"] = model_error.provider_code + trace_id = _first_value(error, "trace_id", "traceId") if trace_id: data["traceId"] = str(trace_id) + request_id = _first_value(error, "request_id", "requestId") + if request_id: + data["requestId"] = str(request_id) return data -def is_rate_limited_error(error: Any) -> bool: +def classify_model_error(error: Any) -> Optional[ModelErrorInfo]: if error is None: - return False - if _status_code(error) == 429 or _status_code(_get_value(error, "response")) == 429: - return True - if _rate_limit_code(error) or _rate_limit_code(_get_value(error, "response")): - return True - return bool(_RATE_LIMIT_TEXT_RE.search(str(error))) + return None + + status_code = _status_code(error) + provider_code = _provider_code(error) + spec = _spec_from_provider_code(provider_code) + if spec is not None: + return _to_model_error_info(spec, status_code, provider_code) + + text = str(error) + provider_code = _provider_code_from_text(text) + spec = _spec_from_provider_code(provider_code) + if spec is not None: + return _to_model_error_info( + spec, + status_code or _status_code_from_text(text), + provider_code, + ) + + if status_code in _STATUS_CODE_SPECS: + return _to_model_error_info( + _STATUS_CODE_SPECS[status_code], status_code + ) + + if _RATE_LIMIT_TEXT_RE.search(text): + return _to_model_error_info(_RATE_LIMIT_SPEC, status_code) + + return None + + +def is_model_error(error: Any) -> bool: + return classify_model_error(error) is not None + + +def is_rate_limited_error(error: Any) -> bool: + model_error = classify_model_error(error) + return bool(model_error and model_error.code == RATE_LIMITED_CODE) + + +def _to_model_error_info( + spec: _ModelErrorSpec, + status_code: Optional[int], + provider_code: Optional[str] = None, +) -> ModelErrorInfo: + return ModelErrorInfo( + code=spec.code, + status_code=status_code or spec.status_code, + provider_code=provider_code, + retryable=spec.retryable, + retry_after_ms=spec.retry_after_ms, + ) + + +def _spec_from_provider_code( + provider_code: Optional[str], +) -> Optional[_ModelErrorSpec]: + if not provider_code: + return None + return _MODEL_ERROR_SPECS.get(_normalize_code(provider_code)) def _status_code(obj: Any) -> Optional[int]: - for name in ("status_code", "status", "http_status", "statusCode"): - value = _get_value(obj, name) - if value is None: - continue - try: + for part in _iter_error_parts(obj): + for name in ("status_code", "status", "http_status", "statusCode"): + value = _get_value(part, name) + if value is None: + continue + try: + status_code = int(value) + except (TypeError, ValueError): + continue + if status_code in _KNOWN_STATUS_CODES: + return status_code + return None + + +def _status_code_from_text(text: str) -> Optional[int]: + match = _HTTP_STATUS_TEXT_RE.search(text) + if not match: + return None + for value in match.groups(): + if value: return int(value) - except (TypeError, ValueError): - return None return None -def _rate_limit_code(obj: Any) -> bool: - for name in ("code", "error_code", "errorCode"): - code = _get_value(obj, name) - if code and _normalize_code(code) in _RATE_LIMIT_CODES: - return True - return False +def _provider_code(obj: Any) -> Optional[str]: + for part in _iter_error_parts(obj): + for name in ("code", "error_code", "errorCode"): + code = _get_value(part, name) + if code is None or isinstance(code, (dict, list, tuple)): + continue + value = str(code) + if not value.isdigit(): + return value + return None + + +def _provider_code_from_text(text: str) -> Optional[str]: + for pattern, provider_code in _TEXT_PROVIDER_PATTERNS: + match = pattern.search(text) + if match: + return provider_code or match.group(0) + return None + + +def _iter_error_parts(obj: Any) -> Iterator[Any]: + parts = [obj] + seen = set() + index = 0 + while index < len(parts) and index < 20: + part = parts[index] + index += 1 + if part is None or id(part) in seen: + continue + seen.add(id(part)) + yield part + for name in ("response", "body", "data", "error"): + value = _get_value(part, name) + if value is not None and value is not part: + parts.append(value) + json_body = _json_body(part) + if json_body is not None and json_body is not part: + parts.append(json_body) + + +def _json_body(obj: Any) -> Optional[Any]: + json_method = getattr(obj, "json", None) + if not callable(json_method): + return None + try: + return json_method() + except Exception: + return None def _get_value(obj: Any, name: str) -> Optional[Any]: if obj is None: return None if isinstance(obj, dict): - return obj.get(name) + value = obj.get(name) + if value is not None: + return value + lower_name = name.lower() + for key, value in obj.items(): + if isinstance(key, str) and key.lower() == lower_name: + return value + return None return getattr(obj, name, None) +def _first_value(obj: Any, *names: str) -> Optional[Any]: + for part in _iter_error_parts(obj): + for name in names: + value = _get_value(part, name) + if value: + return value + return None + + def _normalize_code(code: Any) -> str: value = "".join(ch for ch in str(code).lower() if ch.isalnum()) for suffix in ("exception", "error"): diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 7d88111..a626a2b 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -11,11 +11,9 @@ 边界事件(如 TOOL_CALL_START/END)由协议层自动生成,转换器不再输出这些事件。 """ -from typing import Dict, List, Union +from typing import Dict from unittest.mock import MagicMock -import pytest - from agentrun.integration.langgraph import AgentRunConverter from agentrun.server.model import AgentEvent, EventType @@ -555,7 +553,7 @@ def test_converter_maintains_state(self): }, } - results1 = list(converter.convert(event1)) + list(converter.convert(event1)) assert converter._tool_call_id_map[0] == "call_stateful" # 第二个 chunk 使用映射 @@ -842,11 +840,41 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" + assert ( + results[0].data["message"] + == "RuntimeError: Error code: 429 - rate limit exceeded" + ) assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 + def test_on_llm_error_model_service_unavailable(self): + """测试 on_llm_error 模型服务不可用错误归一化""" + + class ModelError(RuntimeError): + status_code = 503 + code = "ModelServingError" + request_id = "req-503" + + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": ModelError("model overloaded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "ModelError: model overloaded" + assert results[0].data["code"] == "MODEL_SERVICE_UNAVAILABLE" + assert results[0].data["retryable"] is True + assert results[0].data["statusCode"] == 503 + assert results[0].data["providerCode"] == "ModelServingError" + assert results[0].data["requestId"] == "req-503" + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 19bbdba..a626a2b 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -553,7 +553,7 @@ def test_converter_maintains_state(self): }, } - results1 = list(converter.convert(event1)) + list(converter.convert(event1)) assert converter._tool_call_id_map[0] == "call_stateful" # 第二个 chunk 使用映射 @@ -840,11 +840,41 @@ def test_on_llm_error_rate_limited(self): assert len(results) == 1 assert results[0].event == EventType.ERROR - assert results[0].data["message"] == "RuntimeError: Error code: 429 - rate limit exceeded" + assert ( + results[0].data["message"] + == "RuntimeError: Error code: 429 - rate limit exceeded" + ) assert results[0].data["code"] == "RATE_LIMITED" assert results[0].data["retryable"] is True assert results[0].data["retryAfterMs"] == 2000 + def test_on_llm_error_model_service_unavailable(self): + """测试 on_llm_error 模型服务不可用错误归一化""" + + class ModelError(RuntimeError): + status_code = 503 + code = "ModelServingError" + request_id = "req-503" + + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": ModelError("model overloaded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "ModelError: model overloaded" + assert results[0].data["code"] == "MODEL_SERVICE_UNAVAILABLE" + assert results[0].data["retryable"] is True + assert results[0].data["statusCode"] == 503 + assert results[0].data["providerCode"] == "ModelServingError" + assert results[0].data["requestId"] == "req-503" + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 3188aab..d87d611 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -125,6 +125,34 @@ def invoke_agent(request: AgentRequest): assert run_error["code"] == "RATE_LIMITED" assert run_error["retryable"] is True assert run_error["retryAfterMs"] == 2000 + assert run_error["statusCode"] == 429 + + @pytest.mark.asyncio + async def test_text_model_error_stream_payload(self): + """测试模型侧错误文本输出 RUN_ERROR 且透出分类字段""" + + def invoke_agent(request: AgentRequest): + return "Error code: 400 - data_inspection_failed: unsafe content" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + events = _agui_sse_events(response) + types = [event.get("type") for event in events] + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert "RUN_FINISHED" not in types + assert run_error["message"] == ( + "Error code: 400 - data_inspection_failed: unsafe content" + ) + assert run_error["code"] == "MODEL_DATA_INSPECTION_FAILED" + assert run_error["statusCode"] == 400 + assert run_error["providerCode"] == "DataInspectionFailed" @pytest.mark.asyncio async def test_exception_in_parse_request(self): diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py index 0299fc1..93bcca0 100644 --- a/tests/unittests/server/test_error_utils.py +++ b/tests/unittests/server/test_error_utils.py @@ -1,7 +1,13 @@ -"""Tests for model rate-limit helpers.""" +"""Tests for model-side error helpers.""" + +from types import SimpleNamespace + +import pytest from agentrun.utils.error_utils import ( build_error_event_data, + classify_model_error, + is_model_error, is_rate_limited_error, ) @@ -10,6 +16,67 @@ def test_text_429_rate_limit_is_rate_limited(): assert is_rate_limited_error("Error code: 429 - rate limit exceeded") +@pytest.mark.parametrize( + ("text", "code", "status_code", "retryable"), + [ + ( + "Error code: 400 - Arrearage: account is not in good standing", + "MODEL_ARREARAGE", + 400, + False, + ), + ( + "Error code: 400 - data_inspection_failed: unsafe content", + "MODEL_DATA_INSPECTION_FAILED", + 400, + False, + ), + ( + "Error code: 401 - InvalidApiKey: invalid api key", + "MODEL_AUTHENTICATION_ERROR", + 401, + False, + ), + ( + "Error code: 403 - AccessDenied.Unpurchased: model disabled", + "MODEL_ACCESS_DENIED", + 403, + False, + ), + ( + "Error code: 429 - Throttling.RateQuota: too many requests", + "RATE_LIMITED", + 429, + True, + ), + ( + "Error code: 500 - InternalError.Algo: backend failed", + "MODEL_INTERNAL_ERROR", + 500, + True, + ), + ( + "Error code: 503 - ModelServingError: model overloaded", + "MODEL_SERVICE_UNAVAILABLE", + 503, + True, + ), + ], +) +def test_common_bailian_text_errors_are_model_errors( + text, + code, + status_code, + retryable, +): + model_error = classify_model_error(text) + + assert model_error is not None + assert model_error.code == code + assert model_error.status_code == status_code + assert model_error.retryable is retryable + + def test_structured_status_429_is_rate_limited(): class RateLimitError(RuntimeError): status_code = 429 @@ -17,10 +84,78 @@ class RateLimitError(RuntimeError): assert is_rate_limited_error(RateLimitError("provider overloaded")) +def test_structured_response_body_code_is_model_error(): + error = SimpleNamespace( + response=SimpleNamespace( + status_code=403, + json=lambda: { + "code": "WorkspaceAccessDenied", + "requestId": "req-123", + }, + ) + ) + + data = build_error_event_data( + error, + fallback_code="RuntimeError", + fallback_message="workspace denied", + ) + + assert data == { + "message": "workspace denied", + "code": "MODEL_ACCESS_DENIED", + "statusCode": 403, + "providerCode": "WorkspaceAccessDenied", + "requestId": "req-123", + } + + +@pytest.mark.parametrize( + ("text", "provider_code"), + [ + ( + "Error code: 403 - AccessDenied.Unpurchased: model disabled", + "AccessDenied.Unpurchased", + ), + ( + "Error code: 429 - Throttling.RateQuota: too many requests", + "Throttling.RateQuota", + ), + ( + "Error code: 500 - InternalError.Algo: backend failed", + "InternalError.Algo", + ), + ], +) +def test_text_provider_code_preserves_dotted_suffix(text, provider_code): + data = build_error_event_data( + text, + fallback_code="str", + fallback_message=text, + ) + + assert data["providerCode"] == provider_code + + def test_non_rate_limit_text_is_not_rate_limited(): assert not is_rate_limited_error("normal response") +def test_plain_internal_error_is_not_model_error(): + assert not is_model_error("Internal error") + + +@pytest.mark.parametrize( + "text", + [ + "Error code: 400 - validation failed", + "HTTP 500 - Internal Server Error means the service failed", + ], +) +def test_status_code_only_text_is_not_model_error(text): + assert not is_model_error(text) + + def test_rate_limit_event_uses_original_message(): data = build_error_event_data( "Error code: 429 - rate limit exceeded", @@ -33,4 +168,5 @@ def test_rate_limit_event_uses_original_message(): "code": "RATE_LIMITED", "retryable": True, "retryAfterMs": 2000, + "statusCode": 429, } diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index 8337db3..06e7312 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -203,10 +203,54 @@ async def invoke_agent(req: AgentRequest) -> str: assert len(items) == 1 assert items[0].event == EventType.ERROR - assert items[0].data["message"] == "Error code: 429 - rate limit exceeded" + assert ( + items[0].data["message"] == "Error code: 429 - rate limit exceeded" + ) assert items[0].data["code"] == "RATE_LIMITED" assert items[0].data["retryable"] is True assert items[0].data["retryAfterMs"] == 2000 + assert items[0].data["statusCode"] == 429 + + @pytest.mark.asyncio + async def test_invoke_stream_text_model_error(self, req): + """测试字符串形式的模型权限错误被转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "Error code: 403 - AccessDenied: model is not authorized" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.ERROR + assert items[0].data["message"] == ( + "Error code: 403 - AccessDenied: model is not authorized" + ) + assert items[0].data["code"] == "MODEL_ACCESS_DENIED" + assert items[0].data["statusCode"] == 403 + assert items[0].data["providerCode"] == "AccessDenied" + + @pytest.mark.asyncio + async def test_invoke_stream_status_code_text_stays_text(self, req): + """测试普通状态码说明文本不会被误转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "HTTP 500 - Internal Server Error means the service failed" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.TEXT + assert items[0].data["delta"] == ( + "HTTP 500 - Internal Server Error means the service failed" + ) class TestInvokerSync: