Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/bub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
from bub.framework import DEFAULT_HOME, BubFramework
from bub.hookspecs import hookimpl
from bub.tools import tool
from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot
from bub.turn_admission import AdmitDecision, TurnSnapshot

__all__ = [
"AdmitDecision",
"BubFramework",
"Settings",
"SteeringBuffer",
"TurnSnapshot",
"config",
"ensure_config",
Expand Down
20 changes: 20 additions & 0 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import inspect
import re
import shlex
Expand All @@ -22,6 +23,7 @@
)
from bub.builtin.settings import load_settings
from bub.builtin.tape import Tape
from bub.envelope import field_of
from bub.framework import BubFramework
from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState
from bub.skills import discover_skills, render_skills_prompt
Expand Down Expand Up @@ -95,6 +97,7 @@ async def run_stream(
StreamEvent("final", {"text": "error: empty prompt", "ok": False}),
])

state.setdefault("session_id", session_id)
tape = self.tape.session_tape(
session_id, workspace_from_state(state), context=replace(self.tape.context, state=state)
)
Expand All @@ -117,6 +120,7 @@ async def run_stream(
allowed_skills=allowed_skills,
allowed_tools=allowed_tools,
)

return self._events_with_callback(events, callback=stack.aclose)

async def _run_command(self, tape: Tape, *, line: str) -> str:
Expand Down Expand Up @@ -273,6 +277,7 @@ async def _stream_events_with_auto_handoff(
state.error = output.error
state.usage = output.usage
elapsed_ms = int((time.monotonic() - start) * 1000)
should_continue = should_continue or self._has_steering_messages(tape.context.state)
if not should_continue:
await tape.append_event(
"loop.step",
Expand Down Expand Up @@ -355,12 +360,23 @@ async def _run_once_stream(
resolved_model = model or self.settings.model

model_tools_for_call = model_tools(tools)
steering_inbox = self.framework.get_steering_inbox()
steering_envelopes = await steering_inbox.drain_messages(tape.context.state) if steering_inbox else []
steering_messages = list(
await asyncio.gather(*[
self.framework.build_prompt(
message, session_id=field_of(message, "session_id"), state=tape.context.state
)
for message in steering_envelopes
])
)
return self.model_runner.run(
tape=tape,
model=resolved_model,
tools=model_tools_for_call,
system_prompt=system_prompt,
prompt=prompt,
steering_messages=steering_messages,
)

def _system_prompt(
Expand All @@ -384,6 +400,10 @@ def _continue_prompt(self, tape: Tape) -> str:
return f"{CONTINUE_PROMPT} [context: {tape.context.state['context']}]"
return CONTINUE_PROMPT

def _has_steering_messages(self, state: State) -> bool:
steering_inbox = self.framework.get_steering_inbox()
return bool(steering_inbox and steering_inbox.message_count(state) > 0)


@dataclass(frozen=True)
class Args:
Expand Down
13 changes: 10 additions & 3 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from bub.builtin.agent import Agent
from bub.builtin.context import default_tape_context
from bub.builtin.settings import DEFAULT_MODEL
from bub.builtin.steering import InMemorySteeringInbox
from bub.channels.base import Channel
from bub.channels.message import ChannelMessage, MediaItem
from bub.envelope import content_of, field_of
Expand All @@ -18,7 +19,7 @@
from bub.runtime import AsyncStreamEvents
from bub.tape import TapeContext, TapeStore
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler, State
from bub.types import Envelope, MessageHandler, State, SteeringInboxProtocol

AGENTS_FILE_NAME = "AGENTS.md"
MODEL_PROVIDER_CHOICES: tuple[str, ...] = (
Expand Down Expand Up @@ -140,6 +141,8 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State:
# fresh/unknown session never inherits another session's model.
if model := await self._recover_session_model(session_id):
state["model"] = model
if thread_id := field_of(message, "context", {}).get("thread_id"):
state["_runtime_thread_id"] = thread_id
return state

@hookimpl
Expand Down Expand Up @@ -324,7 +327,11 @@ def build_tape_context(self) -> TapeContext:
return default_tape_context()

@hookimpl
def admit_message(
def provide_steering_inbox(self) -> SteeringInboxProtocol:
return InMemorySteeringInbox()

@hookimpl
async def admit_message(
self,
session_id: str,
message: Envelope,
Expand All @@ -333,4 +340,4 @@ def admit_message(
outbound_router = self.framework._outbound_router
if outbound_router is None:
return None
return outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn)
return await outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn)
9 changes: 7 additions & 2 deletions src/bub/builtin/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def run(
tools: list[Tool],
system_prompt: str | None,
prompt: str | list[dict],
steering_messages: list[list[dict[str, Any]] | str] | None = None,
) -> AsyncStreamEvents:
state = StreamState()

Expand All @@ -96,6 +97,7 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
system_prompt=system_prompt,
prompt=prompt,
model=model,
steering_messages=steering_messages,
)
output = ModelOutputAccumulator()
async with asyncio.timeout(self.settings.model_timeout_seconds):
Expand Down Expand Up @@ -159,6 +161,7 @@ async def build_messages(
system_prompt: str | None,
prompt: str | list[dict],
model: str,
steering_messages: list[list[dict[str, Any]] | str] | None = None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
prompt_message: dict[str, Any] = {"role": "user", "content": prompt}
try:
Expand All @@ -172,10 +175,12 @@ async def build_messages(
model=model,
)
raise
steering_messages_native = [{"role": "user", "content": message} for message in (steering_messages or [])]
if system_prompt:
messages = [{"role": "system", "content": system_prompt}, *messages]
messages.append(prompt_message)
return messages, [prompt_message]
new_messages = [*steering_messages_native, prompt_message]
messages.extend(new_messages)
return messages, new_messages

async def record_context_error(
self,
Expand Down
36 changes: 36 additions & 0 deletions src/bub/builtin/steering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Steering inbox implementations."""

from __future__ import annotations

from collections import defaultdict, deque
from collections.abc import Hashable

from bub.types import Envelope, State


class InMemorySteeringInbox:
"""Process-local steering inbox keyed by runtime thread or session."""

def __init__(self) -> None:
self._messages: defaultdict[Hashable, deque[Envelope]] = defaultdict(deque)

async def enqueue_message(self, message: Envelope, state: State) -> None:
self._messages[self._key(state)].append(message)

async def drain_messages(self, state: State) -> list[Envelope]:
key = self._key(state)
messages = list(self._messages.pop(key, ()))
return messages

def message_count(self, state: State) -> int:
return len(self._messages.get(self._key(state), ()))

@staticmethod
def _key(state: State) -> Hashable:
thread_id = state.get("_runtime_thread_id")
if isinstance(thread_id, Hashable) and thread_id:
return thread_id
session_id = state.get("session_id")
if isinstance(session_id, Hashable) and session_id:
return session_id
return "default"
2 changes: 1 addition & 1 deletion src/bub/channels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEve
"""Optionally wrap the output stream for this channel."""
return stream

def admit_message(
async def admit_message(
self,
session_id: str,
message: Envelope,
Expand Down
11 changes: 6 additions & 5 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,12 @@ async def _main_loop(self) -> None:
continue

request = self._normalize_input(raw)
await self._echo_input(raw)

message = ChannelMessage(
session_id=self._message_template["session_id"],
channel=self._message_template["channel"],
chat_id=self._message_template["chat_id"],
context={"thread_id": self._message_template["session_id"]}, # use the same thread_id for all messages
content=request,
lifespan=self.message_lifespan(),
)
Expand Down Expand Up @@ -332,11 +332,11 @@ def _prompt_label(self) -> str:
symbol = ">" if self._mode == "agent" else ","
return f"{cwd} {symbol} "

async def _echo_input(self, raw: str) -> None:
async def _echo_input(self, raw: str, steering: bool = False) -> None:
stream_printer = getattr(self, "_stream_printer", None)
if stream_printer is not None:
await stream_printer.commit_live_text()
self._renderer.input_echo(self._prompt_label(), raw)
self._renderer.input_echo(self._prompt_label(), raw, steering=steering)

async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
Expand Down Expand Up @@ -416,12 +416,13 @@ def _history_file(home: Path, workspace: Path) -> Path:
workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest()
return home / "history" / f"{workspace_hash}.history"

def admit_message(
async def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
await self._echo_input(message.content, steering=turn.is_running)
if not turn.is_running:
return None
return AdmitDecision("follow_up", reason="cli session is already generating")
return AdmitDecision("steer", reason="cli session is already generating")
5 changes: 3 additions & 2 deletions src/bub/channels/cli/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ def error(self, text: str) -> None:
return
self.console.print(f"[red bold]Error >[/]\n{text}")

def input_echo(self, prompt: str, text: str) -> None:
def input_echo(self, prompt: str, text: str, steering: bool = False) -> None:
if not text.strip():
return
self.console.print(f"[bold]{prompt}[/]{text}", new_line_start=True)
mid = "[grey](steering)[/] " if steering else ""
self.console.print(f"[dim][bold]{prompt}[/]{mid}{text}[/]", new_line_start=True)

def tool_call_start(self, *, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
self.console.print(Text(_format_tool_call(name, args, kwargs), style="magenta"), new_line_start=True)
Expand Down
Loading
Loading