diff --git a/agentrun/integration/langgraph/agent_converter.py b/agentrun/integration/langgraph/agent_converter.py index f00a46b..04cdd4f 100644 --- a/agentrun/integration/langgraph/agent_converter.py +++ b/agentrun/integration/langgraph/agent_converter.py @@ -30,6 +30,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union from agentrun.server.model import AgentResult, EventType +from agentrun.utils.error_utils import build_error_event_data, is_model_error from agentrun.utils.log import logger # 需要从工具输入中过滤掉的内部字段(LangGraph/MCP 注入的运行时对象) @@ -952,10 +953,15 @@ def _convert_astream_events_event( yield AgentResult( event=EventType.ERROR, - data={ - "message": f"LLM error: {error_message}", - "code": "LLM_ERROR", - }, + data=build_error_event_data( + error, + fallback_code="LLM_ERROR", + fallback_message=( + error_message + if is_model_error(error) + else f"LLM error: {error_message}" + ), + ), ) # 8. Chain 错误 diff --git a/agentrun/server/agui_protocol.py b/agentrun/server/agui_protocol.py index 5e8ccb4..4c891c3 100644 --- a/agentrun/server/agui_protocol.py +++ b/agentrun/server/agui_protocol.py @@ -53,6 +53,14 @@ # ============================================================================ DEFAULT_PREFIX = "/ag-ui/agent" +RUN_ERROR_EXTRA_FIELDS = ( + "retryable", + "retryAfterMs", + "traceId", + "requestId", + "statusCode", + "providerCode", +) @dataclass @@ -743,12 +751,18 @@ def _process_event_with_boundaries( # ERROR 事件 if event.event == EventType.ERROR: - yield self._encoder.encode( - RunErrorEvent( - message=event.data.get("message", ""), - code=event.data.get("code"), - ) + agui_event = RunErrorEvent( + message=event.data.get("message", ""), + code=event.data.get("code"), ) + extra_fields = { + key: event.data[key] + for key in RUN_ERROR_EXTRA_FIELDS + if key in event.data + } + if extra_fields: + agui_event = agui_event.model_copy(update=extra_fields) + yield self._encoder.encode(agui_event) return # STATE 事件 diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index 763e6a0..a34bb7a 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -24,13 +24,15 @@ ) import uuid +from agentrun.utils.error_utils import build_error_event_data, is_model_error +from agentrun.utils.reasoning import get_reasoning_content + from .model import AgentEvent, AgentRequest, EventType from .protocol import ( AsyncInvokeAgentHandler, InvokeAgentHandler, SyncInvokeAgentHandler, ) -from agentrun.utils.reasoning import get_reasoning_content class AgentInvoker: @@ -117,10 +119,7 @@ async def invoke_stream( if isinstance(item, str): if not item: # 跳过空字符串 continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): # 处理用户返回的事件 @@ -142,7 +141,11 @@ async def invoke_stream( logger.error(f"Agent 调用出错: {e}", exc_info=True) yield AgentEvent( event=EventType.ERROR, - data={"message": str(e), "code": type(e).__name__}, + data=build_error_event_data( + e, + fallback_code=type(e).__name__, + fallback_message=str(e), + ), ) def _process_user_event( @@ -227,12 +230,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: return results if isinstance(result, str): - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": result}, - ) - ) + results.append(self._wrap_text(result)) elif isinstance(result, AgentEvent): # 处理可能的 TOOL_CALL 展开 @@ -243,12 +241,7 @@ def _wrap_non_stream(self, result: Any) -> List[AgentEvent]: if isinstance(item, AgentEvent): results.extend(self._process_user_event(item)) elif isinstance(item, str) and item: - results.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) - ) + results.append(self._wrap_text(item)) else: results.extend(self._wrap_model_chunk(item)) @@ -275,10 +268,7 @@ async def _wrap_stream( if isinstance(item, str): if not item: continue - yield AgentEvent( - event=EventType.TEXT, - data={"delta": item}, - ) + yield self._wrap_text(item) elif isinstance(item, AgentEvent): for processed_event in self._process_user_event(item): @@ -346,15 +336,25 @@ def _wrap_model_chunk(self, item: Any) -> List[AgentEvent]: content = self._read_attr_or_key(item, "content") if isinstance(content, str) and content: - events.append( - AgentEvent( - event=EventType.TEXT, - data={"delta": content}, - ) - ) + events.append(self._wrap_text(content)) return events + def _wrap_text(self, text: str) -> AgentEvent: + if is_model_error(text): + return AgentEvent( + event=EventType.ERROR, + data=build_error_event_data( + text, + fallback_code=type(text).__name__, + fallback_message=text, + ), + ) + return AgentEvent( + event=EventType.TEXT, + data={"delta": text}, + ) + def _read_attr_or_key(self, obj: Any, key: str) -> Any: if isinstance(obj, dict): return obj.get(key) diff --git a/agentrun/utils/error_utils.py b/agentrun/utils/error_utils.py new file mode 100644 index 0000000..3754ec1 --- /dev/null +++ b/agentrun/utils/error_utils.py @@ -0,0 +1,341 @@ +"""Small helpers for model-side errors.""" + +from dataclasses import dataclass +import re +from typing import Any, Dict, Iterator, Optional, Tuple + +RATE_LIMITED_CODE = "RATE_LIMITED" +RATE_LIMITED_RETRY_AFTER_MS = 2000 +_KNOWN_STATUS_CODES = (400, 401, 403, 429, 500, 503) +_HTTP_STATUS_TEXT_RE = re.compile( + r"(?:error\s+code|status\s+code|http\s+status|http)" + r"\D{0,20}(400|401|403|429|500|503)\b" + r"|\b(400|401|403|429|500|503)\s*(?:[-:—]|$)", + re.IGNORECASE, +) +_RATE_LIMIT_TEXT_RE = re.compile( + r"429|too[-_\s]*many[-_\s]*requests|rate[-_\s]*limit|throttl", + re.IGNORECASE, +) + + +@dataclass(frozen=True) +class _ModelErrorSpec: + code: str + status_code: int + retryable: bool = False + retry_after_ms: Optional[int] = None + + +@dataclass(frozen=True) +class ModelErrorInfo: + code: str + status_code: Optional[int] = None + provider_code: Optional[str] = None + retryable: bool = False + retry_after_ms: Optional[int] = None + + +_RATE_LIMIT_SPEC = _ModelErrorSpec( + RATE_LIMITED_CODE, + 429, + retryable=True, + retry_after_ms=RATE_LIMITED_RETRY_AFTER_MS, +) +_MODEL_ERROR_SPECS = { + "arrearage": _ModelErrorSpec("MODEL_ARREARAGE", 400), + "datainspectionfailed": _ModelErrorSpec( + "MODEL_DATA_INSPECTION_FAILED", 400 + ), + "invalidapikey": _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + "authentication": _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + "accessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "accessdeniedunpurchased": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "modelaccessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "workspaceaccessdenied": _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + "throttling": _RATE_LIMIT_SPEC, + "throttlingratequota": _RATE_LIMIT_SPEC, + "throttlingallocationquota": _RATE_LIMIT_SPEC, + "limitrequests": _RATE_LIMIT_SPEC, + "ratelimit": _RATE_LIMIT_SPEC, + "ratelimitexceeded": _RATE_LIMIT_SPEC, + "ratelimited": _RATE_LIMIT_SPEC, + "toomanyrequests": _RATE_LIMIT_SPEC, + "internal": _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + "internalerroralgo": _ModelErrorSpec( + "MODEL_INTERNAL_ERROR", 500, retryable=True + ), + "system": _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + "modelserving": _ModelErrorSpec( + "MODEL_SERVICE_UNAVAILABLE", 503, retryable=True + ), + "serviceunavailable": _ModelErrorSpec( + "MODEL_SERVICE_UNAVAILABLE", 503, retryable=True + ), +} +_STATUS_CODE_SPECS = { + 400: _ModelErrorSpec("MODEL_BAD_REQUEST", 400), + 401: _ModelErrorSpec("MODEL_AUTHENTICATION_ERROR", 401), + 403: _ModelErrorSpec("MODEL_ACCESS_DENIED", 403), + 429: _RATE_LIMIT_SPEC, + 500: _ModelErrorSpec("MODEL_INTERNAL_ERROR", 500, retryable=True), + 503: _ModelErrorSpec("MODEL_SERVICE_UNAVAILABLE", 503, retryable=True), +} +_TEXT_PROVIDER_PATTERNS: Tuple[Tuple[re.Pattern[str], Optional[str]], ...] = ( + (re.compile(r"\bArrearage\b", re.IGNORECASE), "Arrearage"), + ( + re.compile( + r"access\s+denied.*account.*good\s+standing", + re.IGNORECASE, + ), + "Arrearage", + ), + ( + re.compile( + r"\bDataInspectionFailed\b|\bdata[_-]inspection[_-]failed\b", + re.IGNORECASE, + ), + "DataInspectionFailed", + ), + (re.compile(r"\bInvalidApiKey\b", re.IGNORECASE), None), + ( + re.compile(r"\bAuthenticationError\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bModel\.AccessDenied\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bAccessDenied\.Unpurchased\b", re.IGNORECASE), + None, + ), + (re.compile(r"\bAccessDenied\b", re.IGNORECASE), None), + ( + re.compile(r"\bWorkspaceAccessDenied\b", re.IGNORECASE), + None, + ), + ( + re.compile( + r"\bThrottling(?:\.(?:RateQuota|AllocationQuota))?\b", + re.IGNORECASE, + ), + None, + ), + (re.compile(r"\bLimitRequests\b", re.IGNORECASE), None), + (re.compile(r"\bRateLimit\b", re.IGNORECASE), None), + ( + re.compile(r"\bInternalError(?:\.Algo)?\b", re.IGNORECASE), + None, + ), + (re.compile(r"\bSystemError\b", re.IGNORECASE), None), + ( + re.compile(r"\bModelServingError\b", re.IGNORECASE), + None, + ), + ( + re.compile(r"\bServiceUnavailable\b", re.IGNORECASE), + None, + ), +) + + +def build_error_event_data( + error: Any, + *, + fallback_code: str, + fallback_message: str, +) -> Dict[str, Any]: + """Keep the original message; add model-error metadata when matched.""" + model_error = classify_model_error(error) + if model_error is None: + return {"message": fallback_message, "code": fallback_code} + + data: Dict[str, Any] = { + "message": fallback_message, + "code": model_error.code, + } + if model_error.retryable: + data["retryable"] = True + if model_error.retry_after_ms is not None: + data["retryAfterMs"] = model_error.retry_after_ms + if model_error.status_code is not None: + data["statusCode"] = model_error.status_code + if model_error.provider_code: + data["providerCode"] = model_error.provider_code + trace_id = _first_value(error, "trace_id", "traceId") + if trace_id: + data["traceId"] = str(trace_id) + request_id = _first_value(error, "request_id", "requestId") + if request_id: + data["requestId"] = str(request_id) + return data + + +def classify_model_error(error: Any) -> Optional[ModelErrorInfo]: + if error is None: + return None + + status_code = _status_code(error) + provider_code = _provider_code(error) + spec = _spec_from_provider_code(provider_code) + if spec is not None: + return _to_model_error_info(spec, status_code, provider_code) + + text = str(error) + provider_code = _provider_code_from_text(text) + spec = _spec_from_provider_code(provider_code) + if spec is not None: + return _to_model_error_info( + spec, + status_code or _status_code_from_text(text), + provider_code, + ) + + if status_code in _STATUS_CODE_SPECS: + return _to_model_error_info( + _STATUS_CODE_SPECS[status_code], status_code + ) + + if _RATE_LIMIT_TEXT_RE.search(text): + return _to_model_error_info(_RATE_LIMIT_SPEC, status_code) + + return None + + +def is_model_error(error: Any) -> bool: + return classify_model_error(error) is not None + + +def is_rate_limited_error(error: Any) -> bool: + model_error = classify_model_error(error) + return bool(model_error and model_error.code == RATE_LIMITED_CODE) + + +def _to_model_error_info( + spec: _ModelErrorSpec, + status_code: Optional[int], + provider_code: Optional[str] = None, +) -> ModelErrorInfo: + return ModelErrorInfo( + code=spec.code, + status_code=status_code or spec.status_code, + provider_code=provider_code, + retryable=spec.retryable, + retry_after_ms=spec.retry_after_ms, + ) + + +def _spec_from_provider_code( + provider_code: Optional[str], +) -> Optional[_ModelErrorSpec]: + if not provider_code: + return None + return _MODEL_ERROR_SPECS.get(_normalize_code(provider_code)) + + +def _status_code(obj: Any) -> Optional[int]: + for part in _iter_error_parts(obj): + for name in ("status_code", "status", "http_status", "statusCode"): + value = _get_value(part, name) + if value is None: + continue + try: + status_code = int(value) + except (TypeError, ValueError): + continue + if status_code in _KNOWN_STATUS_CODES: + return status_code + return None + + +def _status_code_from_text(text: str) -> Optional[int]: + match = _HTTP_STATUS_TEXT_RE.search(text) + if not match: + return None + for value in match.groups(): + if value: + return int(value) + return None + + +def _provider_code(obj: Any) -> Optional[str]: + for part in _iter_error_parts(obj): + for name in ("code", "error_code", "errorCode"): + code = _get_value(part, name) + if code is None or isinstance(code, (dict, list, tuple)): + continue + value = str(code) + if not value.isdigit(): + return value + return None + + +def _provider_code_from_text(text: str) -> Optional[str]: + for pattern, provider_code in _TEXT_PROVIDER_PATTERNS: + match = pattern.search(text) + if match: + return provider_code or match.group(0) + return None + + +def _iter_error_parts(obj: Any) -> Iterator[Any]: + parts = [obj] + seen = set() + index = 0 + while index < len(parts) and index < 20: + part = parts[index] + index += 1 + if part is None or id(part) in seen: + continue + seen.add(id(part)) + yield part + for name in ("response", "body", "data", "error"): + value = _get_value(part, name) + if value is not None and value is not part: + parts.append(value) + json_body = _json_body(part) + if json_body is not None and json_body is not part: + parts.append(json_body) + + +def _json_body(obj: Any) -> Optional[Any]: + json_method = getattr(obj, "json", None) + if not callable(json_method): + return None + try: + return json_method() + except Exception: + return None + + +def _get_value(obj: Any, name: str) -> Optional[Any]: + if obj is None: + return None + if isinstance(obj, dict): + value = obj.get(name) + if value is not None: + return value + lower_name = name.lower() + for key, value in obj.items(): + if isinstance(key, str) and key.lower() == lower_name: + return value + return None + return getattr(obj, name, None) + + +def _first_value(obj: Any, *names: str) -> Optional[Any]: + for part in _iter_error_parts(obj): + for name in names: + value = _get_value(part, name) + if value: + return value + return None + + +def _normalize_code(code: Any) -> str: + value = "".join(ch for ch in str(code).lower() if ch.isalnum()) + for suffix in ("exception", "error"): + if value.endswith(suffix): + return value[: -len(suffix)] + return value diff --git a/tests/unittests/integration/test_langgraph_events.py b/tests/unittests/integration/test_langgraph_events.py index 0e51714..a626a2b 100644 --- a/tests/unittests/integration/test_langgraph_events.py +++ b/tests/unittests/integration/test_langgraph_events.py @@ -11,11 +11,9 @@ 边界事件(如 TOOL_CALL_START/END)由协议层自动生成,转换器不再输出这些事件。 """ -from typing import Dict, List, Union +from typing import Dict from unittest.mock import MagicMock -import pytest - from agentrun.integration.langgraph import AgentRunConverter from agentrun.server.model import AgentEvent, EventType @@ -555,7 +553,7 @@ def test_converter_maintains_state(self): }, } - results1 = list(converter.convert(event1)) + list(converter.convert(event1)) assert converter._tool_call_id_map[0] == "call_stateful" # 第二个 chunk 使用映射 @@ -816,7 +814,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -828,6 +826,55 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("Error code: 429 - rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert ( + results[0].data["message"] + == "RuntimeError: Error code: 429 - rate limit exceeded" + ) + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + + def test_on_llm_error_model_service_unavailable(self): + """测试 on_llm_error 模型服务不可用错误归一化""" + + class ModelError(RuntimeError): + status_code = 503 + code = "ModelServingError" + request_id = "req-503" + + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": ModelError("model overloaded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "ModelError: model overloaded" + assert results[0].data["code"] == "MODEL_SERVICE_UNAVAILABLE" + assert results[0].data["retryable"] is True + assert results[0].data["statusCode"] == 503 + assert results[0].data["providerCode"] == "ModelServingError" + assert results[0].data["requestId"] == "req-503" + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/integration/test_langgraph_to_agent_event.py b/tests/unittests/integration/test_langgraph_to_agent_event.py index 74933d1..a626a2b 100644 --- a/tests/unittests/integration/test_langgraph_to_agent_event.py +++ b/tests/unittests/integration/test_langgraph_to_agent_event.py @@ -553,7 +553,7 @@ def test_converter_maintains_state(self): }, } - results1 = list(converter.convert(event1)) + list(converter.convert(event1)) assert converter._tool_call_id_map[0] == "call_stateful" # 第二个 chunk 使用映射 @@ -814,7 +814,7 @@ def test_on_llm_error(self): "event": "on_llm_error", "run_id": "run_llm", "data": { - "error": RuntimeError("API rate limit exceeded"), + "error": RuntimeError("Model backend failed"), }, } @@ -826,6 +826,55 @@ def test_on_llm_error(self): assert "RuntimeError" in results[0].data["message"] assert results[0].data["code"] == "LLM_ERROR" + def test_on_llm_error_rate_limited(self): + """测试 on_llm_error 限流错误归一化且 message 保留原始错误""" + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": RuntimeError("Error code: 429 - rate limit exceeded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert ( + results[0].data["message"] + == "RuntimeError: Error code: 429 - rate limit exceeded" + ) + assert results[0].data["code"] == "RATE_LIMITED" + assert results[0].data["retryable"] is True + assert results[0].data["retryAfterMs"] == 2000 + + def test_on_llm_error_model_service_unavailable(self): + """测试 on_llm_error 模型服务不可用错误归一化""" + + class ModelError(RuntimeError): + status_code = 503 + code = "ModelServingError" + request_id = "req-503" + + event = { + "event": "on_llm_error", + "run_id": "run_llm", + "data": { + "error": ModelError("model overloaded"), + }, + } + + results = list(AgentRunConverter().to_agui_events(event)) + + assert len(results) == 1 + assert results[0].event == EventType.ERROR + assert results[0].data["message"] == "ModelError: model overloaded" + assert results[0].data["code"] == "MODEL_SERVICE_UNAVAILABLE" + assert results[0].data["retryable"] is True + assert results[0].data["statusCode"] == 503 + assert results[0].data["providerCode"] == "ModelServingError" + assert results[0].data["requestId"] == "req-503" + def test_on_chain_error(self): """测试 on_chain_error 事件 diff --git a/tests/unittests/server/test_agui_protocol.py b/tests/unittests/server/test_agui_protocol.py index 0896a08..d87d611 100644 --- a/tests/unittests/server/test_agui_protocol.py +++ b/tests/unittests/server/test_agui_protocol.py @@ -101,6 +101,59 @@ def invoke_agent(request: AgentRequest): assert "RUN_ERROR" in types + @pytest.mark.asyncio + async def test_text_rate_limit_error_stream_payload(self): + """测试文本形式的 429 错误输出 RUN_ERROR 且无 RUN_FINISHED""" + + def invoke_agent(request: AgentRequest): + return "Error code: 429 - rate limit exceeded" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + events = _agui_sse_events(response) + types = [event.get("type") for event in events] + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert "RUN_FINISHED" not in types + assert run_error["message"] == "Error code: 429 - rate limit exceeded" + assert run_error["code"] == "RATE_LIMITED" + assert run_error["retryable"] is True + assert run_error["retryAfterMs"] == 2000 + assert run_error["statusCode"] == 429 + + @pytest.mark.asyncio + async def test_text_model_error_stream_payload(self): + """测试模型侧错误文本输出 RUN_ERROR 且透出分类字段""" + + def invoke_agent(request: AgentRequest): + return "Error code: 400 - data_inspection_failed: unsafe content" + + client = self.get_client(invoke_agent) + response = client.post( + "/ag-ui/agent", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + events = _agui_sse_events(response) + types = [event.get("type") for event in events] + run_error = next( + event for event in events if event.get("type") == "RUN_ERROR" + ) + assert "RUN_FINISHED" not in types + assert run_error["message"] == ( + "Error code: 400 - data_inspection_failed: unsafe content" + ) + assert run_error["code"] == "MODEL_DATA_INSPECTION_FAILED" + assert run_error["statusCode"] == 400 + assert run_error["providerCode"] == "DataInspectionFailed" + @pytest.mark.asyncio async def test_exception_in_parse_request(self): """测试 parse_request 中的异常处理(覆盖 155-156 行) diff --git a/tests/unittests/server/test_error_utils.py b/tests/unittests/server/test_error_utils.py new file mode 100644 index 0000000..93bcca0 --- /dev/null +++ b/tests/unittests/server/test_error_utils.py @@ -0,0 +1,172 @@ +"""Tests for model-side error helpers.""" + +from types import SimpleNamespace + +import pytest + +from agentrun.utils.error_utils import ( + build_error_event_data, + classify_model_error, + is_model_error, + is_rate_limited_error, +) + + +def test_text_429_rate_limit_is_rate_limited(): + assert is_rate_limited_error("Error code: 429 - rate limit exceeded") + + +@pytest.mark.parametrize( + ("text", "code", "status_code", "retryable"), + [ + ( + "Error code: 400 - Arrearage: account is not in good standing", + "MODEL_ARREARAGE", + 400, + False, + ), + ( + "Error code: 400 - data_inspection_failed: unsafe content", + "MODEL_DATA_INSPECTION_FAILED", + 400, + False, + ), + ( + "Error code: 401 - InvalidApiKey: invalid api key", + "MODEL_AUTHENTICATION_ERROR", + 401, + False, + ), + ( + "Error code: 403 - AccessDenied.Unpurchased: model disabled", + "MODEL_ACCESS_DENIED", + 403, + False, + ), + ( + "Error code: 429 - Throttling.RateQuota: too many requests", + "RATE_LIMITED", + 429, + True, + ), + ( + "Error code: 500 - InternalError.Algo: backend failed", + "MODEL_INTERNAL_ERROR", + 500, + True, + ), + ( + "Error code: 503 - ModelServingError: model overloaded", + "MODEL_SERVICE_UNAVAILABLE", + 503, + True, + ), + ], +) +def test_common_bailian_text_errors_are_model_errors( + text, + code, + status_code, + retryable, +): + model_error = classify_model_error(text) + + assert model_error is not None + assert model_error.code == code + assert model_error.status_code == status_code + assert model_error.retryable is retryable + + +def test_structured_status_429_is_rate_limited(): + class RateLimitError(RuntimeError): + status_code = 429 + + assert is_rate_limited_error(RateLimitError("provider overloaded")) + + +def test_structured_response_body_code_is_model_error(): + error = SimpleNamespace( + response=SimpleNamespace( + status_code=403, + json=lambda: { + "code": "WorkspaceAccessDenied", + "requestId": "req-123", + }, + ) + ) + + data = build_error_event_data( + error, + fallback_code="RuntimeError", + fallback_message="workspace denied", + ) + + assert data == { + "message": "workspace denied", + "code": "MODEL_ACCESS_DENIED", + "statusCode": 403, + "providerCode": "WorkspaceAccessDenied", + "requestId": "req-123", + } + + +@pytest.mark.parametrize( + ("text", "provider_code"), + [ + ( + "Error code: 403 - AccessDenied.Unpurchased: model disabled", + "AccessDenied.Unpurchased", + ), + ( + "Error code: 429 - Throttling.RateQuota: too many requests", + "Throttling.RateQuota", + ), + ( + "Error code: 500 - InternalError.Algo: backend failed", + "InternalError.Algo", + ), + ], +) +def test_text_provider_code_preserves_dotted_suffix(text, provider_code): + data = build_error_event_data( + text, + fallback_code="str", + fallback_message=text, + ) + + assert data["providerCode"] == provider_code + + +def test_non_rate_limit_text_is_not_rate_limited(): + assert not is_rate_limited_error("normal response") + + +def test_plain_internal_error_is_not_model_error(): + assert not is_model_error("Internal error") + + +@pytest.mark.parametrize( + "text", + [ + "Error code: 400 - validation failed", + "HTTP 500 - Internal Server Error means the service failed", + ], +) +def test_status_code_only_text_is_not_model_error(text): + assert not is_model_error(text) + + +def test_rate_limit_event_uses_original_message(): + data = build_error_event_data( + "Error code: 429 - rate limit exceeded", + fallback_code="str", + fallback_message="Error code: 429 - rate limit exceeded", + ) + + assert data == { + "message": "Error code: 429 - rate limit exceeded", + "code": "RATE_LIMITED", + "retryable": True, + "retryAfterMs": 2000, + "statusCode": 429, + } diff --git a/tests/unittests/server/test_invoker.py b/tests/unittests/server/test_invoker.py index 76d8c65..06e7312 100644 --- a/tests/unittests/server/test_invoker.py +++ b/tests/unittests/server/test_invoker.py @@ -188,6 +188,70 @@ async def invoke_agent(req: AgentRequest) -> str: assert "Test error" in error_event.data["message"] assert error_event.data["code"] == "ValueError" + @pytest.mark.asyncio + async def test_invoke_stream_text_rate_limit_error(self, req): + """测试字符串形式的模型限流错误被转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "Error code: 429 - rate limit exceeded" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.ERROR + assert ( + items[0].data["message"] == "Error code: 429 - rate limit exceeded" + ) + assert items[0].data["code"] == "RATE_LIMITED" + assert items[0].data["retryable"] is True + assert items[0].data["retryAfterMs"] == 2000 + assert items[0].data["statusCode"] == 429 + + @pytest.mark.asyncio + async def test_invoke_stream_text_model_error(self, req): + """测试字符串形式的模型权限错误被转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "Error code: 403 - AccessDenied: model is not authorized" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.ERROR + assert items[0].data["message"] == ( + "Error code: 403 - AccessDenied: model is not authorized" + ) + assert items[0].data["code"] == "MODEL_ACCESS_DENIED" + assert items[0].data["statusCode"] == 403 + assert items[0].data["providerCode"] == "AccessDenied" + + @pytest.mark.asyncio + async def test_invoke_stream_status_code_text_stays_text(self, req): + """测试普通状态码说明文本不会被误转成 ERROR""" + + async def invoke_agent(req: AgentRequest) -> str: + return "HTTP 500 - Internal Server Error means the service failed" + + invoker = AgentInvoker(invoke_agent) + + items: List[AgentEvent] = [] + async for item in invoker.invoke_stream(req): + items.append(item) + + assert len(items) == 1 + assert items[0].event == EventType.TEXT + assert items[0].data["delta"] == ( + "HTTP 500 - Internal Server Error means the service failed" + ) + class TestInvokerSync: """同步调用测试"""