Semantic Caching for LLMs: Cutting API Costs by 60%
Learn how to implement semantic caching for LLM applications to dramatically reduce API costs and latency. Covers embedding-based cache keys, TTL strategies, cache invalidation, and production deployment patterns with Redis and vector databases.
The Cost Problem
LLM API costs scale linearly with usage. Every identical or near-identical query costs the same as the first time it was asked. In production systems, a significant percentage of queries are semantically equivalent -- "What's your refund policy?" and "How do I get a refund?" should produce the same answer.
Traditional caching uses exact string matching, which misses these semantic duplicates. Semantic caching embeds queries into vectors and matches based on similarity, catching both exact and near-duplicate queries.
Real-world deployments report 40-70% cache hit rates with semantic caching, translating directly to proportional cost savings.
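A quick way to see the difference: the sketch below hashes two paraphrased refund questions and compares their embeddings, using the same `sentence-transformers` model as the implementation later in this post. The hashes differ (an exact-match cache misses), while the cosine similarity is typically high enough for a semantic cache to serve the same answer.

```python
import hashlib

import numpy as np
from sentence_transformers import SentenceTransformer

q1 = "What's your refund policy?"
q2 = "How do I get a refund?"

# Exact-match caching: different strings -> different keys -> guaranteed miss
print(hashlib.sha256(q1.encode()).hexdigest() == hashlib.sha256(q2.encode()).hexdigest())  # False

# Semantic caching: normalized embeddings -> cosine similarity via dot product
embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")
v1, v2 = embedder.encode([q1, q2], normalize_embeddings=True)
print(float(np.dot(v1, v2)))  # high for close paraphrases; compare against your threshold
```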
Architecture
User Query
     |
     v
[Embed Query] --> query_vector
     |
     v
[Search Cache] -- similarity > threshold?
     |                        |
     No                      Yes
     |                        |
     v                        v
[Call LLM]          [Return Cached Response]
     |
     v
[Store in Cache]
     |
     v
Return Response
Implementation
Core Semantic Cache with Redis
import hashlib
import json
import time
import numpy as np
import redis.asyncio as redis
from sentence_transformers import SentenceTransformer
class SemanticCache:
def __init__(
self,
redis_url: str = "redis://localhost:6379",
similarity_threshold: float = 0.92,
ttl_seconds: int = 3600,
embedding_model: str = "BAAI/bge-small-en-v1.5",
):
self.redis = redis.from_url(redis_url)
self.threshold = similarity_threshold
self.ttl = ttl_seconds
self.embedder = SentenceTransformer(embedding_model)
def _embed(self, text: str) -> np.ndarray:
return self.embedder.encode(text, normalize_embeddings=True)
def _cache_key(self, query_hash: str) -> str:
return f"sem_cache:{query_hash}"
async def get(self, query: str, system_prompt: str = "") -> dict | None:
"""Look up a semantically similar cached response"""
# Include system prompt in the cache context
cache_context = f"{system_prompt}||{query}" if system_prompt else query
query_vector = self._embed(cache_context)
# Search for similar cached queries
best_match = None
best_score = 0.0
# In production, use a vector index (Qdrant, Redis VSS) instead of brute force
keys = await self.redis.keys("sem_cache:*")
for key in keys:
entry = await self.redis.hgetall(key)
if not entry:
continue
cached_vector = np.frombuffer(entry[b"vector"], dtype=np.float32)
similarity = np.dot(query_vector, cached_vector)
if similarity > self.threshold and similarity > best_score:
best_score = similarity
best_match = entry
if best_match:
return {
"response": best_match[b"response"].decode(),
"similarity": best_score,
"cached_at": float(best_match[b"timestamp"]),
"cache_hit": True,
}
return None
async def set(self, query: str, response: str, system_prompt: str = ""):
"""Cache a query-response pair"""
cache_context = f"{system_prompt}||{query}" if system_prompt else query
query_vector = self._embed(cache_context)
query_hash = hashlib.sha256(cache_context.encode()).hexdigest()[:16]
key = self._cache_key(query_hash)
await self.redis.hset(key, mapping={
"query": query,
"response": response,
"vector": query_vector.tobytes(),
"timestamp": str(time.time()),
"system_prompt_hash": hashlib.sha256(
system_prompt.encode()
).hexdigest()[:8],
})
await self.redis.expire(key, self.ttl)
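A minimal usage sketch of the class above, assuming a Redis instance running on the default local URL:

```python
import asyncio

async def main():
    cache = SemanticCache(similarity_threshold=0.92, ttl_seconds=3600)

    # First request: nothing cached yet, so this returns None -> call the LLM
    print(await cache.get("How do I get a refund?"))  # None

    # Store the LLM's answer
    await cache.set("How do I get a refund?", "Refunds are issued within 5-7 business days.")

    # A close paraphrase may now hit the cache if it clears the similarity threshold
    hit = await cache.get("What's your refund policy?")
    if hit:
        print(hit["similarity"], hit["response"])

asyncio.run(main())
```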
Production Implementation with Redis Vector Search
For production workloads, use Redis Stack with its built-in vector similarity search instead of brute-force comparison:
from redis.commands.search.field import VectorField, TextField, NumericField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
class ProductionSemanticCache:
def __init__(self, redis_client, embedding_dim: int = 384):
self.redis = redis_client
self.dim = embedding_dim
self.index_name = "semantic_cache_idx"
async def create_index(self):
"""Create the vector search index (run once)"""
schema = [
TextField("query"),
TextField("response"),
NumericField("timestamp"),
VectorField("vector",
"FLAT",
{"TYPE": "FLOAT32", "DIM": self.dim, "DISTANCE_METRIC": "COSINE"}),
]
definition = IndexDefinition(prefix=["sem_cache:"], index_type=IndexType.HASH)
await self.redis.ft(self.index_name).create_index(schema, definition=definition)
async def search(self, query_vector: bytes, top_k: int = 1) -> list:
"""Vector similarity search using Redis VSS"""
q = (
Query(f"*=>[KNN {top_k} @vector $vec AS score]")
.sort_by("score")
.return_fields("query", "response", "timestamp", "score")
.dialect(2)
)
results = await self.redis.ft(self.index_name).search(
q, query_params={"vec": query_vector}
)
return results.docs
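One detail worth calling out: with `DISTANCE_METRIC` set to `COSINE`, the `score` Redis returns is a cosine distance, so convert it back to a similarity before comparing against your threshold. Below is a sketch of a lookup helper, assuming entries are written as hashes under the `sem_cache:` prefix with a float32 `vector` field (the same layout `SemanticCache.set` uses above) and `embedder` is the same `SentenceTransformer` instance:

```python
import numpy as np

async def get_cached(prod_cache: ProductionSemanticCache, embedder, query: str,
                     threshold: float = 0.92) -> str | None:
    # Redis VSS expects the query vector as raw float32 bytes
    vec = embedder.encode(query, normalize_embeddings=True).astype(np.float32)
    docs = await prod_cache.search(vec.tobytes(), top_k=1)
    if not docs:
        return None
    best = docs[0]
    similarity = 1.0 - float(best.score)  # COSINE metric returns a distance, not a similarity
    return best.response if similarity >= threshold else None
```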
Middleware Integration
Wrap your LLM client with caching middleware for transparent integration:
from dataclasses import dataclass

@dataclass
class CachedResponse:
    """Lightweight stand-in for a model response when serving from cache."""
    content: str
    from_cache: bool = True
    similarity: float = 0.0

class CachedLLMClient:
def __init__(self, llm_client, cache: SemanticCache):
self.llm = llm_client
self.cache = cache
self.stats = {"hits": 0, "misses": 0, "total_saved_tokens": 0}
async def generate(
self,
messages: list[dict],
system: str = "",
model: str = "claude-sonnet-4-20250514",
**kwargs,
):
# Extract the user query (last user message)
user_query = next(
(m["content"] for m in reversed(messages) if m["role"] == "user"),
""
)
# Check cache
cached = await self.cache.get(user_query, system_prompt=system)
        if cached:
            self.stats["hits"] += 1
            # Credit the average size of an uncached call as tokens saved by this hit
            avg_tokens = self.stats["total_llm_tokens"] / max(self.stats["misses"], 1)
            self.stats["total_saved_tokens"] += avg_tokens
            return CachedResponse(
                content=cached["response"],
                from_cache=True,
                similarity=cached["similarity"],
            )
# Cache miss -- call the LLM
self.stats["misses"] += 1
response = await self.llm.messages.create(
model=model,
system=system,
messages=messages,
**kwargs,
)
# Store in cache
response_text = response.content[0].text
await self.cache.set(user_query, response_text, system_prompt=system)
self.stats["total_saved_tokens"] += (
response.usage.input_tokens + response.usage.output_tokens
) * (self.stats["hits"] / max(self.stats["hits"] + self.stats["misses"], 1))
return response
@property
def hit_rate(self) -> float:
total = self.stats["hits"] + self.stats["misses"]
return self.stats["hits"] / total if total > 0 else 0
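Wiring the pieces together might look like the sketch below, assuming the `anthropic` SDK's `AsyncAnthropic` client and the `SemanticCache` defined earlier:

```python
import asyncio

from anthropic import AsyncAnthropic

async def main():
    cache = SemanticCache(similarity_threshold=0.95)
    client = CachedLLMClient(AsyncAnthropic(), cache)

    response = await client.generate(
        messages=[{"role": "user", "content": "How do I get a refund?"}],
        system="You are a support assistant for an online store.",
        max_tokens=512,
    )
    # A miss returns the provider's Message object; a hit returns CachedResponse
    print(type(response).__name__, f"hit_rate={client.hit_rate:.0%}")

asyncio.run(main())
```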
Threshold Tuning
The similarity threshold is the most critical parameter. Too low and you return irrelevant cached responses. Too high and you miss valid cache hits.
| Threshold | Hit Rate | Error Rate | Best For |
|---|---|---|---|
| 0.98+ | 5-15% | <0.1% | Safety-critical (medical, legal) |
| 0.95 | 15-30% | <0.5% | Factual Q&A |
| 0.92 | 30-50% | 1-2% | Customer support, general Q&A |
| 0.88 | 50-70% | 3-5% | Informal chat, recommendations |
Threshold Calibration Process
async def calibrate_threshold(eval_pairs: list[tuple[str, str, bool]]):
"""
eval_pairs: [(query1, query2, should_match), ...]
Find the threshold that maximizes F1 score
"""
embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")
similarities = []
for q1, q2, label in eval_pairs:
v1 = embedder.encode(q1, normalize_embeddings=True)
v2 = embedder.encode(q2, normalize_embeddings=True)
sim = np.dot(v1, v2)
similarities.append((sim, label))
# Test thresholds from 0.80 to 0.99
best_f1 = 0
best_threshold = 0.92
for threshold in np.arange(0.80, 0.99, 0.01):
tp = sum(1 for s, l in similarities if s >= threshold and l)
fp = sum(1 for s, l in similarities if s >= threshold and not l)
fn = sum(1 for s, l in similarities if s < threshold and l)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
if f1 > best_f1:
best_f1 = f1
best_threshold = threshold
return best_threshold, best_f1
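Feed it a small labeled set of query pairs from your own traffic (the pairs below are purely illustrative) to get a starting threshold, then validate against live hit and error rates:

```python
import asyncio

eval_pairs = [
    ("What's your refund policy?", "How do I get a refund?", True),
    ("How do I get a refund?", "Can I return something I bought?", True),
    ("What's your refund policy?", "What are your store hours?", False),
    ("Reset my password", "Delete my account", False),
]

threshold, f1 = asyncio.run(calibrate_threshold(eval_pairs))
print(f"best threshold={threshold:.2f}, F1={f1:.2f}")
```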
Cache Invalidation Strategies
- TTL-based: Set a time-to-live on each cache entry. Simple and effective for data that changes infrequently.
- Version-based: Include a version number in the cache key. Increment the version when your system prompt, model, or knowledge base changes.
- Event-based: Invalidate specific cache entries when the underlying data changes (e.g., when a product price updates); a sketch of this appears after the versioned cache example below.
class VersionedSemanticCache(SemanticCache):
def __init__(self, *args, cache_version: str = "v1", **kwargs):
super().__init__(*args, **kwargs)
self.version = cache_version
def _cache_key(self, query_hash: str) -> str:
return f"sem_cache:{self.version}:{query_hash}"
async def invalidate_version(self, old_version: str):
"""Delete all cache entries for an old version"""
keys = await self.redis.keys(f"sem_cache:{old_version}:*")
if keys:
await self.redis.delete(*keys)
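The event-based strategy from the list above needs a way to find the entries affected by a data change. One possible approach (a sketch, not part of the classes above) is to tag each entry with a `topic` field when it is written and evict by topic when that topic's source data changes:

```python
import hashlib

class EventAwareSemanticCache(VersionedSemanticCache):
    async def set(self, query: str, response: str, system_prompt: str = "",
                  topic: str = "general"):
        # Reuse the parent write, then tag the entry for targeted eviction
        await super().set(query, response, system_prompt=system_prompt)
        cache_context = f"{system_prompt}||{query}" if system_prompt else query
        query_hash = hashlib.sha256(cache_context.encode()).hexdigest()[:16]
        await self.redis.hset(self._cache_key(query_hash), "topic", topic)

    async def invalidate_topic(self, topic: str):
        """Delete every entry tagged with a topic, e.g. after a price update."""
        async for key in self.redis.scan_iter(f"sem_cache:{self.version}:*"):
            if await self.redis.hget(key, "topic") == topic.encode():
                await self.redis.delete(key)
```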
Cost Impact Analysis
For a customer support bot handling 100,000 queries per day using Claude Sonnet:
| Metric | Without Cache | With Semantic Cache (45% hit rate) |
|---|---|---|
| Daily LLM calls | 100,000 | 55,000 |
| Daily input tokens | 50M | 27.5M |
| Daily output tokens | 15M | 8.25M |
| Daily cost | $375 | $206 |
| Monthly cost | $11,250 | $6,187 |
| Monthly savings | -- | $5,063 (45%) |
| Cache infrastructure cost | -- | ~$50/month (Redis) |
The cache infrastructure cost is negligible compared to the LLM API savings. Even a modest 30% hit rate saves thousands of dollars monthly at scale.
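The table's arithmetic is straightforward to reproduce. A quick sketch, using the Claude Sonnet list pricing the table assumes ($3 per million input tokens, $15 per million output tokens) and a 30-day month:

```python
INPUT_PRICE, OUTPUT_PRICE = 3.00, 15.00  # USD per million tokens (assumed list pricing)
daily_input, daily_output = 50_000_000, 15_000_000
hit_rate = 0.45

def daily_cost(in_tok: float, out_tok: float) -> float:
    return in_tok / 1e6 * INPUT_PRICE + out_tok / 1e6 * OUTPUT_PRICE

without_cache = daily_cost(daily_input, daily_output)                                 # $375.00
with_cache = daily_cost(daily_input * (1 - hit_rate), daily_output * (1 - hit_rate))  # $206.25
print(f"monthly savings: ${(without_cache - with_cache) * 30:,.2f}")                  # $5,062.50
```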
Key Takeaways
Semantic caching is one of the highest-ROI optimizations for production LLM applications. The implementation is straightforward, the cost savings are immediate and measurable, and the latency improvement (cached responses in 5-20ms vs 500-2000ms for LLM calls) improves user experience. Start with a conservative similarity threshold (0.95), measure your actual hit rate and error rate, and tune from there.