WebSocket Agent Endpoints with FastAPI: Bidirectional Real-Time Communication
Build bidirectional WebSocket endpoints for AI agents in FastAPI. Learn connection lifecycle management, message routing, heartbeat mechanisms, and handling multiple concurrent agent sessions.
When to Use WebSockets Instead of SSE
Server-Sent Events work well for one-directional streaming where the client sends a request and receives a stream of tokens. But many AI agent scenarios need bidirectional communication: the user sends follow-up messages while the agent is still responding, the agent asks for clarification mid-conversation, or the frontend sends real-time signals like "stop generating" or "the user is typing."
WebSockets provide a persistent, full-duplex connection where both client and server can send messages at any time. FastAPI supports WebSockets natively through Starlette, making it straightforward to build real-time agent communication channels.
Basic WebSocket Agent Endpoint
Here is a minimal WebSocket endpoint that receives user messages and streams agent responses:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(
    websocket: WebSocket,
    session_id: str,
):
    await websocket.accept()
    try:
        while True:
            # Receive message from client
            data = await websocket.receive_json()
            if data["type"] == "message":
                # Stream agent response back (agent is your LLM wrapper)
                async for token in agent.stream(data["content"]):
                    await websocket.send_json({
                        "type": "token",
                        "content": token,
                    })
                await websocket.send_json({
                    "type": "message_complete",
                    "session_id": session_id,
                })
    except WebSocketDisconnect:
        print(f"Client {session_id} disconnected")
The endpoint accepts a connection, then enters an infinite loop that reads messages and sends responses. The WebSocketDisconnect exception is raised when the client closes the connection.
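On the client side, those frames reduce to a single reply string. Here is a minimal, framework-free sketch of that folding step; the frame shapes match the endpoint above, and a real client would feed it events read off the socket (for example via a WebSocket client library, which is not shown here):

```python
def collect_reply(events) -> str:
    """Fold the server's frames into the final assistant message.

    Appends every "token" frame's content and stops at the first
    "message_complete" frame, mirroring the server protocol above.
    """
    parts = []
    for event in events:
        if event["type"] == "token":
            parts.append(event["content"])
        elif event["type"] == "message_complete":
            break
    return "".join(parts)
```

Keeping this reduction pure makes it trivial to unit-test without a live connection.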
Connection Manager for Multiple Sessions
Production AI agents need to track multiple concurrent connections. A connection manager handles this:
import asyncio
import time
from dataclasses import dataclass, field

from fastapi import WebSocket

@dataclass
class AgentSession:
    websocket: WebSocket
    session_id: str
    user_id: str
    created_at: float = field(default_factory=lambda: time.time())
    is_generating: bool = False

class ConnectionManager:
    def __init__(self):
        self._sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()

    async def connect(
        self, websocket: WebSocket, session_id: str, user_id: str
    ) -> AgentSession:
        await websocket.accept()
        session = AgentSession(
            websocket=websocket,
            session_id=session_id,
            user_id=user_id,
        )
        async with self._lock:
            self._sessions[session_id] = session
        return session

    async def disconnect(self, session_id: str):
        async with self._lock:
            self._sessions.pop(session_id, None)

    async def send_to_session(
        self, session_id: str, message: dict
    ):
        session = self._sessions.get(session_id)
        if session:
            await session.websocket.send_json(message)

    def get_session(self, session_id: str):
        return self._sessions.get(session_id)

manager = ConnectionManager()
The asyncio.Lock prevents race conditions when multiple connections are added or removed simultaneously.
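A natural extension is broadcasting to every live session, for example for maintenance notices. The sketch below is hypothetical and uses a FakeSocket stub in place of a real Starlette WebSocket so it runs without a server; the pattern to note is snapshotting the session dict under the lock, then sending outside it so one slow socket cannot block connects and disconnects:

```python
import asyncio


class FakeSocket:
    """Stand-in for a WebSocket; send_json is the only method the manager needs."""
    def __init__(self):
        self.sent: list[dict] = []

    async def send_json(self, message: dict):
        self.sent.append(message)


class BroadcastManager:
    def __init__(self):
        self._sessions: dict[str, FakeSocket] = {}
        self._lock = asyncio.Lock()

    async def connect(self, session_id: str, ws: FakeSocket):
        async with self._lock:
            self._sessions[session_id] = ws

    async def broadcast(self, message: dict):
        # Snapshot under the lock, send outside it
        async with self._lock:
            targets = list(self._sessions.values())
        for ws in targets:
            try:
                await ws.send_json(message)
            except Exception:
                pass  # a dead socket should not break the broadcast


async def demo() -> int:
    mgr = BroadcastManager()
    a, b = FakeSocket(), FakeSocket()
    await mgr.connect("a", a)
    await mgr.connect("b", b)
    await mgr.broadcast({"type": "notice", "content": "maintenance soon"})
    return len(a.sent) + len(b.sent)
```

The same stub technique works for unit-testing the ConnectionManager above without opening real connections.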
Structured Message Protocol
Define a clear message protocol with typed messages for both directions:
from pydantic import BaseModel
from enum import Enum
from typing import Optional

class ClientMessageType(str, Enum):
    MESSAGE = "message"
    STOP = "stop"
    PING = "ping"
    TOOL_RESPONSE = "tool_response"

class ServerMessageType(str, Enum):
    TOKEN = "token"
    COMPLETE = "complete"
    ERROR = "error"
    PONG = "pong"
    TOOL_REQUEST = "tool_request"

class ClientMessage(BaseModel):
    type: ClientMessageType
    content: Optional[str] = None
    metadata: Optional[dict] = None

class ServerMessage(BaseModel):
    type: ServerMessageType
    content: Optional[str] = None
    metadata: Optional[dict] = None
Validate incoming messages against this schema to catch malformed data early:
@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(websocket: WebSocket, session_id: str):
    session = await manager.connect(websocket, session_id, "user1")
    try:
        while True:
            raw = await websocket.receive_json()
            try:
                msg = ClientMessage(**raw)
            except ValueError:
                # Pydantic's ValidationError subclasses ValueError
                await websocket.send_json(
                    {"type": "error", "content": "Invalid message format"}
                )
                continue
            if msg.type == ClientMessageType.PING:
                await websocket.send_json({"type": "pong"})
            elif msg.type == ClientMessageType.STOP:
                session.is_generating = False
            elif msg.type == ClientMessageType.MESSAGE:
                # Run generation in the background so the loop stays
                # free to receive stop and ping frames mid-stream
                asyncio.create_task(
                    handle_agent_message(session, msg.content)
                )
    except WebSocketDisconnect:
        pass
    finally:
        await manager.disconnect(session_id)
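The dispatch step itself is easy to verify in isolation. A stdlib-only sketch (the handler bodies are stubbed out as return strings purely for testability) shows the same pattern: an invalid type value raises ValueError from the Enum lookup, just as a malformed payload raises Pydantic's ValidationError, which subclasses ValueError:

```python
from enum import Enum


class ClientMessageType(str, Enum):
    MESSAGE = "message"
    STOP = "stop"
    PING = "ping"


def route(raw: dict) -> str:
    """Route a raw client frame by type; stub handlers return strings."""
    try:
        kind = ClientMessageType(raw.get("type"))
    except ValueError:
        # Unknown or missing type: same fallback the endpoint uses
        return "error: invalid message format"
    if kind is ClientMessageType.PING:
        return "pong"
    if kind is ClientMessageType.STOP:
        return "stopped"
    return f"handle: {raw.get('content')}"
```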
Heartbeat Mechanism
WebSocket connections can silently die due to network issues, proxy timeouts, or mobile devices going to sleep. Implement a heartbeat to detect dead connections:
import time

async def heartbeat_task(
    websocket: WebSocket, session_id: str, interval: int = 30
):
    try:
        while True:
            await asyncio.sleep(interval)
            try:
                await websocket.send_json({
                    "type": "ping",
                    "timestamp": time.time(),
                })
            except Exception:
                await manager.disconnect(session_id)
                break
    except asyncio.CancelledError:
        pass

@app.websocket("/ws/agent/{session_id}")
async def agent_websocket(websocket: WebSocket, session_id: str):
    session = await manager.connect(websocket, session_id, "user1")
    # Start heartbeat as a background task
    heartbeat = asyncio.create_task(
        heartbeat_task(websocket, session_id)
    )
    try:
        while True:
            raw = await websocket.receive_json()
            await handle_message(session, raw)
    except WebSocketDisconnect:
        pass
    finally:
        # Cancel the heartbeat on any exit path, not just clean disconnects
        heartbeat.cancel()
        await manager.disconnect(session_id)
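The client-side counterpart to a heartbeat is reconnecting when the connection is declared dead. A common companion is exponential backoff with a cap, sketched here as a pure schedule generator (jitter, which production clients usually add, is omitted for determinism):

```python
def backoff_delays(base: float = 1.0, cap: float = 30.0, attempts: int = 6):
    """Yield reconnect delays: base, base*2, base*4, ... capped at `cap`."""
    delay = base
    for _ in range(attempts):
        yield min(delay, cap)
        delay *= 2
```

The client sleeps for each yielded delay between reconnect attempts and gives up (or surfaces an error to the user) when the schedule is exhausted.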
Handling Stop Generation
A critical feature for AI agents is letting the user stop generation mid-stream. Use a cancellation flag on the session:
async def handle_agent_message(session: AgentSession, content: str):
    session.is_generating = True
    async for token in llm_service.stream_generate(content):
        if not session.is_generating:
            await session.websocket.send_json({
                "type": "complete",
                "content": "Generation stopped by user.",
            })
            return
        await session.websocket.send_json({
            "type": "token",
            "content": token,
        })
    session.is_generating = False
    await session.websocket.send_json({"type": "complete"})
When the client sends a stop message, the main message loop sets session.is_generating = False, and the generator checks this flag on each iteration. For this to work, generation must run as a background task (for example via asyncio.create_task); if the receive loop awaits the handler directly, it cannot read the stop message until generation has already finished.
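The flag mechanism can be exercised without a socket or an LLM. In this self-contained sketch a fake token generator streams while another coroutine flips is_generating, and the stream halts on the next iteration:

```python
import asyncio


class Session:
    def __init__(self):
        self.is_generating = False


async def fake_stream(n: int):
    """Stand-in for an LLM token stream; yields control on each token."""
    for i in range(n):
        await asyncio.sleep(0)  # like awaiting the next LLM chunk
        yield f"tok{i}"


async def generate(session: Session, out: list):
    session.is_generating = True
    async for token in fake_stream(100):
        if not session.is_generating:
            out.append("[stopped]")  # would send the "complete" frame here
            return
        out.append(token)
    session.is_generating = False


async def demo() -> list:
    session = Session()
    out: list = []
    task = asyncio.create_task(generate(session, out))
    await asyncio.sleep(0)         # let generation start
    session.is_generating = False  # simulate the client's stop frame
    await task
    return out
```

Because the flag is checked once per token, a stop request takes effect within one token of arriving rather than interrupting a send mid-frame.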
FAQ
How many concurrent WebSocket connections can a single FastAPI worker handle?
A single async FastAPI worker can handle thousands of concurrent WebSocket connections because each connection consumes very little memory when idle. The bottleneck is usually the LLM API calls, not the WebSocket connections themselves. With proper async patterns, a single Uvicorn worker can manage 5000 or more idle connections comfortably.
Should I use WebSockets or SSE for my AI agent?
Use SSE if your agent follows a simple request-response-stream pattern where the client sends a message and receives a streamed response. Use WebSockets if you need bidirectional communication such as stop-generation signals, agent-initiated clarification questions, real-time typing indicators, or multiple interleaved conversations. WebSockets add complexity in terms of connection management and error handling, so choose SSE unless you need the bidirectional capability.
How do I handle authentication with WebSocket connections?
WebSocket connections do not support custom headers in the browser WebSocket API. Two common approaches: pass a token as a query parameter (/ws/agent?token=xxx) and validate it during the handshake, rejecting the connection if it is invalid; or authenticate via a regular HTTP endpoint first, set a session cookie, and validate that cookie when the WebSocket connects. Either way, validate the credential before calling websocket.accept().
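One way to make the query-parameter token verifiable is a stdlib HMAC signature over the session ID (a sketch; real deployments would more often verify a JWT, and the SECRET here stands in for a value loaded from configuration):

```python
import hashlib
import hmac

SECRET = b"server-side-secret"  # assumption: loaded from config in practice


def sign_session(session_id: str) -> str:
    """Issue a token the client appends as ?token=... on the WebSocket URL."""
    return hmac.new(SECRET, session_id.encode(), hashlib.sha256).hexdigest()


def verify_token(session_id: str, token: str) -> bool:
    """Constant-time check of the token before the handshake is accepted."""
    expected = sign_session(session_id)
    return hmac.compare_digest(expected, token)
```

In the endpoint, read the token with websocket.query_params.get("token") and, if verify_token fails, close the socket before ever calling accept() so the handshake is rejected.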
#FastAPI #WebSocket #RealTime #AIAgents #Python #AgenticAI #LearnAI #AIEngineering
CallSphere Team