WebSocket Agent Endpoints with FastAPI: Bidirectional Real-Time Communication

Build bidirectional WebSocket endpoints for AI agents in FastAPI. Learn connection lifecycle management, message routing, heartbeat mechanisms, and handling multiple concurrent agent sessions.

When to Use WebSockets Instead of SSE

Server-Sent Events work well for one-directional streaming where the client sends a request and receives a stream of tokens. But many AI agent scenarios need bidirectional communication: the user sends follow-up messages while the agent is still responding, the agent asks for clarification mid-conversation, or the frontend sends real-time signals like "stop generating" or "the user is typing."

WebSockets provide a persistent, full-duplex connection where both client and server can send messages at any time. FastAPI supports WebSockets natively through Starlette, making it straightforward to build real-time agent communication channels.

Basic WebSocket Agent Endpoint

Here is a minimal WebSocket endpoint that receives user messages and streams agent responses:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

# `agent` is assumed to be your application's agent object, exposing an
# async generator method stream(text) that yields response tokens.

@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(
    websocket: WebSocket,
    session_id: str,
):
    await websocket.accept()

    try:
        while True:
            # Receive message from client
            data = await websocket.receive_json()

            # .get() avoids a KeyError when the "type" field is missing
            if data.get("type") == "message":
                # Stream agent response back
                async for token in agent.stream(data["content"]):
                    await websocket.send_json({
                        "type": "token",
                        "content": token,
                    })

                await websocket.send_json({
                    "type": "message_complete",
                    "session_id": session_id,
                })

    except WebSocketDisconnect:
        print(f"Client {session_id} disconnected")

The endpoint accepts a connection, then enters an infinite loop that reads messages and sends responses. The WebSocketDisconnect exception is raised when the client closes the connection.

Connection Manager for Multiple Sessions

Production AI agents need to track multiple concurrent connections. A connection manager handles this:

from dataclasses import dataclass, field
import asyncio
import time

from fastapi import WebSocket

@dataclass
class AgentSession:
    websocket: WebSocket
    session_id: str
    user_id: str
    created_at: float = field(default_factory=time.time)
    is_generating: bool = False

class ConnectionManager:
    def __init__(self):
        self._sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()

    async def connect(
        self, websocket: WebSocket, session_id: str, user_id: str
    ) -> AgentSession:
        await websocket.accept()
        session = AgentSession(
            websocket=websocket,
            session_id=session_id,
            user_id=user_id,
        )
        async with self._lock:
            self._sessions[session_id] = session
        return session

    async def disconnect(self, session_id: str):
        async with self._lock:
            self._sessions.pop(session_id, None)

    async def send_to_session(
        self, session_id: str, message: dict
    ):
        session = self._sessions.get(session_id)
        if session:
            await session.websocket.send_json(message)

    def get_session(self, session_id: str):
        return self._sessions.get(session_id)

manager = ConnectionManager()

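The asyncio.Lock serializes changes to the session registry. The dictionary updates shown here contain no await, so they are already atomic on a single event loop; the lock matters once connect or disconnect awaits inside the critical section, for example to close a connection that is being replaced.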
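To make the role of the lock concrete, here is a minimal sketch of a manager that keeps one live socket per session ID and closes the old socket when the same session reconnects. It is an illustration under stated assumptions, not part of the manager above: StubWebSocket is a hypothetical stand-in for fastapi.WebSocket so the example runs with the standard library alone, ReplacingManager is a hypothetical name, and close code 4000 is an arbitrary application-defined value.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class StubWebSocket:
    """Hypothetical stand-in for fastapi.WebSocket; records calls, no I/O."""
    name: str
    events: list = field(default_factory=list)

    async def accept(self) -> None:
        self.events.append("accept")

    async def close(self, code: int = 1000) -> None:
        self.events.append(f"close:{code}")


class ReplacingManager:
    """Keeps at most one live socket per session_id."""

    def __init__(self) -> None:
        self._sessions = {}
        self._lock = asyncio.Lock()

    async def connect(self, websocket: StubWebSocket, session_id: str) -> None:
        await websocket.accept()
        async with self._lock:
            old = self._sessions.get(session_id)
            if old is not None:
                # This await is a suspension point. Without the lock, another
                # connect() could run between the lookup and the write below.
                await old.close(code=4000)  # assumed application-defined code
            self._sessions[session_id] = websocket

    async def disconnect(self, websocket: StubWebSocket, session_id: str) -> None:
        async with self._lock:
            # Remove the entry only if it still belongs to this socket, so a
            # late disconnect from a replaced socket cannot evict the new one.
            if self._sessions.get(session_id) is websocket:
                del self._sessions[session_id]

    def get(self, session_id: str):
        return self._sessions.get(session_id)


async def demo():
    manager = ReplacingManager()
    first, second = StubWebSocket("first"), StubWebSocket("second")
    await manager.connect(first, "s1")
    await manager.connect(second, "s1")     # replaces and closes `first`
    await manager.disconnect(first, "s1")   # late cleanup from the old socket
    return manager, first, second
```

Running asyncio.run(demo()) leaves the second socket registered under "s1" and the first one closed. Whether to replace the old connection or reject the new one is a design choice that depends on the product.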

Structured Message Protocol

Define a clear message protocol with typed messages for both directions:

from pydantic import BaseModel
from enum import Enum
from typing import Optional

class ClientMessageType(str, Enum):
    MESSAGE = "message"
    STOP = "stop"
    PING = "ping"
    TOOL_RESPONSE = "tool_response"

class ServerMessageType(str, Enum):
    TOKEN = "token"
    COMPLETE = "complete"
    ERROR = "error"
    PONG = "pong"
    TOOL_REQUEST = "tool_request"

class ClientMessage(BaseModel):
    type: ClientMessageType
    content: Optional[str] = None
    metadata: Optional[dict] = None

class ServerMessage(BaseModel):
    type: ServerMessageType
    content: Optional[str] = None
    metadata: Optional[dict] = None

Validate incoming messages against this schema to catch malformed data early:

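@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(websocket: WebSocket, session_id: str):
    # "user1" is a placeholder; take the real user ID from your auth layer
    session = await manager.connect(websocket, session_id, "user1")
    generation = None  # asyncio.Task for the in-flight response, if any

    try:
        while True:
            raw = await websocket.receive_json()
            try:
                msg = ClientMessage(**raw)
            except (TypeError, ValueError):
                # ValueError covers pydantic's ValidationError;
                # TypeError covers payloads that are not JSON objects
                await websocket.send_json(
                    {"type": "error", "content": "Invalid message format"}
                )
                continue

            if msg.type == ClientMessageType.PING:
                await websocket.send_json({"type": "pong"})

            elif msg.type == ClientMessageType.STOP:
                session.is_generating = False

            elif msg.type == ClientMessageType.MESSAGE:
                if session.is_generating:
                    await websocket.send_json(
                        {"type": "error", "content": "Generation already in progress"}
                    )
                    continue
                # Run generation as a task so this loop keeps reading
                # and can receive STOP while tokens are streaming
                session.is_generating = True
                generation = asyncio.create_task(
                    handle_agent_message(session, msg.content or "")
                )

    except WebSocketDisconnect:
        pass
    finally:
        # Clean up on disconnect and on unexpected errors alike
        if generation is not None:
            generation.cancel()
        await manager.disconnect(session_id)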
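As the number of message types grows, the if/elif chain can be swapped for a dispatch table. The sketch below shows that routing idea using only the standard library: StubSession and its outbox list are hypothetical stand-ins for AgentSession and websocket.send_json, the handler names are illustrative, and the enum is a trimmed copy of ClientMessageType.

```python
import asyncio
from dataclasses import dataclass, field
from enum import Enum


class ClientMessageType(str, Enum):
    MESSAGE = "message"
    STOP = "stop"
    PING = "ping"


@dataclass
class StubSession:
    """Hypothetical stand-in for AgentSession; outbox replaces send_json."""
    outbox: list = field(default_factory=list)
    is_generating: bool = False


async def on_ping(session: StubSession, payload: dict) -> None:
    session.outbox.append({"type": "pong"})


async def on_stop(session: StubSession, payload: dict) -> None:
    session.is_generating = False


async def on_message(session: StubSession, payload: dict) -> None:
    # A real handler would start the agent; this one echoes a single token.
    session.outbox.append({"type": "token", "content": payload.get("content") or ""})


# One entry per message type; adding a type means adding one handler.
HANDLERS = {
    ClientMessageType.PING: on_ping,
    ClientMessageType.STOP: on_stop,
    ClientMessageType.MESSAGE: on_message,
}


async def route(session: StubSession, raw: object) -> None:
    """Validate the type field, then dispatch to the matching handler."""
    try:
        msg_type = ClientMessageType(raw["type"])
    except (KeyError, TypeError, ValueError):
        # Missing field, non-object payload, or unknown type
        session.outbox.append({"type": "error", "content": "Invalid message format"})
        return
    await HANDLERS[msg_type](session, raw)
```

In the endpoint, the body of the receive loop would shrink to a single await route(session, raw) call, with the Pydantic model still doing the full validation.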

Heartbeat Mechanism

WebSocket connections can silently die due to network issues, proxy timeouts, or mobile devices going to sleep. Implement a heartbeat to detect dead connections:

import time

async def heartbeat_task(
    websocket: WebSocket, session_id: str, interval: int = 30
):
    try:
        while True:
            await asyncio.sleep(interval)
            try:
                await websocket.send_json({
                    "type": "ping",
                    "timestamp": time.time(),
                })
            except Exception:
                # The send failed, so treat the connection as dead
                await manager.disconnect(session_id)
                break
    except asyncio.CancelledError:
        pass

@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(websocket: WebSocket, session_id: str):
    session = await manager.connect(websocket, session_id, "user1")

    # Start heartbeat as a background task
    heartbeat = asyncio.create_task(
        heartbeat_task(websocket, session_id)
    )

    try:
        while True:
            raw = await websocket.receive_json()
            # handle_message is assumed to be the application's message router
            await handle_message(session, raw)
    except WebSocketDisconnect:
        pass
    finally:
        # Runs on disconnect and on unexpected errors, so the task never leaks
        heartbeat.cancel()
        await manager.disconnect(session_id)
Handling Stop Generation

A critical feature for AI agents is letting the user stop generation mid-stream. Use a cancellation flag on the session:

async def handle_agent_message(session: AgentSession, content: str):
    session.is_generating = True

    try:
        # llm_service is assumed to be the application's streaming LLM client
        async for token in llm_service.stream_generate(content):
            if not session.is_generating:
                await session.websocket.send_json({
                    "type": "complete",
                    "content": "Generation stopped by user.",
                })
                return

            await session.websocket.send_json({
                "type": "token",
                "content": token,
            })

        await session.websocket.send_json({"type": "complete"})
    finally:
        # Reset the flag even if the stream or a send raises
        session.is_generating = False

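When the client sends a stop message, the main message loop sets session.is_generating = False, and the generator checks this flag on each iteration. This only works if the receive loop keeps running while tokens stream: launch handle_agent_message with asyncio.create_task rather than awaiting it inline, or the stop message will not be read until generation has already finished.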
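The interaction between the receive loop and the generator can be reproduced without a network. The following sketch is a simulation under stated assumptions: fake_token_stream stands in for the LLM stream, an asyncio.Queue stands in for the WebSocket's incoming messages, and Session.sent stands in for send_json. All of these names are hypothetical.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Session:
    sent: list = field(default_factory=list)  # stand-in for send_json
    is_generating: bool = False


async def fake_token_stream(n: int):
    for i in range(n):
        await asyncio.sleep(0)  # yield control, like a network-bound stream
        yield f"tok{i}"


async def generate(session: Session, n: int) -> None:
    session.is_generating = True
    try:
        async for token in fake_token_stream(n):
            if not session.is_generating:
                session.sent.append(
                    {"type": "complete", "content": "Generation stopped by user."}
                )
                return
            session.sent.append({"type": "token", "content": token})
        session.sent.append({"type": "complete"})
    finally:
        session.is_generating = False


async def receive_loop(session: Session, inbox: asyncio.Queue) -> None:
    task = None
    while True:
        msg = await inbox.get()
        if msg["type"] == "message":
            # Background task: the loop returns to inbox.get() immediately
            task = asyncio.create_task(generate(session, msg["n"]))
        elif msg["type"] == "stop":
            session.is_generating = False
        elif msg["type"] == "close":
            break
    if task is not None:
        await task


async def demo() -> Session:
    session = Session()
    inbox: asyncio.Queue = asyncio.Queue()
    loop_task = asyncio.create_task(receive_loop(session, inbox))
    await inbox.put({"type": "message", "n": 1000})
    while len(session.sent) < 3:   # let a few tokens stream first
        await asyncio.sleep(0)
    await inbox.put({"type": "stop"})
    await inbox.put({"type": "close"})
    await loop_task
    return session
```

asyncio.run(demo()) ends with the "Generation stopped by user." message well before all 1000 tokens are produced. If receive_loop awaited generate(...) directly instead of creating a task, the stop message would stay in the queue until the stream had finished.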

FAQ

How many concurrent WebSocket connections can a single FastAPI worker handle?

A single async FastAPI worker can hold thousands of concurrent WebSocket connections because an idle connection consumes little memory and no CPU. The bottleneck is usually the LLM API calls, not the WebSocket connections themselves. With proper async patterns, a single Uvicorn worker can typically manage several thousand idle connections (on the order of 5,000), though the real ceiling depends on available memory, file descriptor limits, and how much work each active session does.
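Since the LLM calls are the likely bottleneck, one option is to cap how many run at once while leaving the connections themselves unbounded. The sketch below uses asyncio.Semaphore for that. BoundedLLM is a hypothetical wrapper, and the sleep stands in for a real LLM API call.

```python
import asyncio


class BoundedLLM:
    """Hypothetical wrapper that limits concurrent LLM calls per worker."""

    def __init__(self, max_concurrent: int) -> None:
        self._sem = asyncio.Semaphore(max_concurrent)
        self.active = 0   # calls currently in flight
        self.peak = 0     # highest concurrency observed

    async def generate(self, prompt: str) -> str:
        async with self._sem:  # waits here when the limit is reached
            self.active += 1
            self.peak = max(self.peak, self.active)
            try:
                await asyncio.sleep(0.01)  # stand-in for a real LLM API call
                return f"echo: {prompt}"
            finally:
                self.active -= 1


async def demo(sessions: int = 50, limit: int = 5):
    llm = BoundedLLM(limit)
    # 50 sessions ask at once, but at most `limit` calls run concurrently
    results = await asyncio.gather(
        *(llm.generate(f"q{i}") for i in range(sessions))
    )
    return llm, results
```

Sessions that exceed the limit simply wait their turn, so the WebSocket stays open and responsive to pings while the request is queued.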

Should I use WebSockets or SSE for my AI agent?

Use SSE if your agent follows a simple request-response-stream pattern where the client sends a message and receives a streamed response. Use WebSockets if you need bidirectional communication such as stop-generation signals, agent-initiated clarification questions, real-time typing indicators, or multiple interleaved conversations. WebSockets add complexity in terms of connection management and error handling, so choose SSE unless you need the bidirectional capability.

How do I handle authentication with WebSocket connections?

The browser WebSocket API does not let you set custom headers, so the usual Authorization header is not available. A common approach is to pass a token as a query parameter (/ws/agent?token=xxx), validate it when the connection is opened, and reject the connection if it is invalid. Query strings often end up in server and proxy logs, so prefer short-lived tokens. Alternatively, authenticate via a regular HTTP endpoint first, set a session cookie, and validate that cookie when the WebSocket connects. Always validate the credential before calling websocket.accept().
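A minimal sketch of the query-parameter check, using only the standard library so it can be tested without a server. VALID_TOKENS is a hypothetical in-memory store and authenticate is a hypothetical helper; a real service would more likely verify a signed, short-lived token or look up a session.

```python
import hmac
from typing import Optional
from urllib.parse import parse_qs

# Hypothetical token store mapping token -> user ID, for illustration only
VALID_TOKENS = {"secret-token-123": "user1"}


def authenticate(query_string: str) -> Optional[str]:
    """Return the user ID for a valid ?token=... value, else None."""
    token = parse_qs(query_string).get("token", [""])[0]
    for known, user_id in VALID_TOKENS.items():
        # compare_digest avoids leaking how much of the token matched
        if hmac.compare_digest(token.encode(), known.encode()):
            return user_id
    return None
```

In the endpoint, this would run before websocket.accept(): pass websocket.url.query to authenticate, and if it returns None, call await websocket.close(code=1008) and return. Closing before accepting causes the handshake to be rejected rather than opening and then dropping the connection.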


#FastAPI #WebSocket #RealTime #AIAgents #Python #AgenticAI #LearnAI #AIEngineering

CallSphere Team
