Learning to Rank with Bandits: Multimodal Search Guide


TL;DR: Static rerankers give you good results. Learned rerankers give you results that get better every time someone clicks. We'll show you how to build one using Thompson Sampling—the same technique powering TikTok's "For You" page.


The Problem with Static Reranking

You've built a solid multimodal retrieval pipeline. Your system combines:

  • CLIP embeddings for visual similarity
  • OCR for text matching
  • Audio transcriptions
  • Object detection (SAM, YOLO)
  • Metadata signals

And you're fusing them with something like:

final_score = 0.4 * clip_score + 0.3 * ocr_score + 0.2 * audio_score + 0.1 * metadata_score

The question: How did you pick those weights?

More importantly: Why are they the same for everyone?

Your e-commerce customers care about visual similarity. Your edtech customers care about OCR and transcripts. Your ad-tech customers care about motion and audio. But your static weights treat them all the same.

This is the problem Thompson Sampling solves.


What is Online Bandit Learning?

Think of it like A/B testing, but continuous and automatic.

Traditional A/B test:

  1. Split traffic 50/50
  2. Wait 2 weeks
  3. Pick the winner
  4. Deploy

Online bandit:

  1. Try different feature weights
  2. Learn from every click
  3. Shift weights toward what works
  4. Never stop learning

The difference, sketched: a static pipeline (traditional) runs Query → Retrieve → Rerank → Results with no feedback loop. A learned pipeline (adaptive) runs Query → Retrieve → Rerank → Results, then feeds every Click through a Thompson update back into the rerank step.

The "bandit" learns which features predict clicks for different contexts, users, or domains.


Why Feature-Level Learning?

Wrong approach: Learn a quality score for each document

  • Creates millions of parameters (one per document)
  • New documents have no history (cold start)
  • Doesn't generalize

Right approach: Learn importance weights for each feature

  • 5-10 parameters total (one per feature)
  • New documents benefit immediately
  • Learns modality preferences, not document preferences

❌ Document-level learning (wrong)
   doc_1: α=1, β=0 (1 interaction)
   doc_2: α=0, β=1 (1 interaction)
   doc_3 … doc_9999: α=0, β=0 (no data)
   Problem: millions of parameters, no generalization

✅ Feature-level learning (right)
   clip:     α=12, β=3  → weight 0.80
   ocr:      α=2,  β=10 → weight 0.17
   audio:    α=8,  β=4  → weight 0.67
   metadata: α=6,  β=6  → weight 0.50
   Benefit: 4 parameters, works for ALL documents
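
Those weights are just the posterior means, α / (α + β). A quick sketch to reproduce them (same illustrative counts as above):

# Mean weight for each feature is alpha / (alpha + beta).
feature_params = {
    "clip":     {"alpha": 12, "beta": 3},
    "ocr":      {"alpha": 2,  "beta": 10},
    "audio":    {"alpha": 8,  "beta": 4},
    "metadata": {"alpha": 6,  "beta": 6},
}

for feature, p in feature_params.items():
    mean_weight = p["alpha"] / (p["alpha"] + p["beta"])
    print(f"{feature}: {mean_weight:.2f}")
# clip: 0.80, ocr: 0.17, audio: 0.67, metadata: 0.50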

Thompson Sampling in 3 Steps

Step 1: Maintain a Beta Distribution per Feature

For each feature (clip, ocr, audio, etc.), track:

  • α (alpha): Times this feature was strong in clicked results
  • β (beta): Times this feature was weak in clicked results

# Initial state: no knowledge
feature_params = {
    "clip": {"alpha": 1, "beta": 1},
    "ocr": {"alpha": 1, "beta": 1},
    "audio": {"alpha": 1, "beta": 1},
    "metadata": {"alpha": 1, "beta": 1}
}

Step 2: Sample Weights for Ranking

When a query arrives, sample a weight from each distribution:

import numpy as np

def sample_weights(feature_params):
    """Sample feature weights using Thompson Sampling."""
    weights = {}
    for feature, params in feature_params.items():
        # Sample from Beta(α, β)
        weights[feature] = np.random.beta(
            params["alpha"], 
            params["beta"]
        )
    return weights

# Example output:
# {"clip": 0.78, "ocr": 0.15, "audio": 0.62, "metadata": 0.45}

Then rank documents:

def compute_score(doc_features, weights):
    """Combine features with learned weights."""
    return sum(
        weights[f] * doc_features[f] 
        for f in weights
    )

Step 3: Update from Clicks

When a user clicks a result, check which features were strong:

def update_from_click(clicked_doc_features, feature_params, threshold=0.5):
    """Update α/β based on clicked document features."""
    for feature, value in clicked_doc_features.items():
        if value > threshold:
            # Feature was strong → credit it
            feature_params[feature]["alpha"] += 1
        else:
            # Feature was weak → penalize it
            feature_params[feature]["beta"] += 1

That's it. Three simple steps, zero ML infrastructure.


Handling Cold Start: Hierarchical Contexts

What about new users with no click history?

Solution: Learn at multiple levels simultaneously.

When a new user queries, the system picks the best available context:

  1. 5+ clicks of history? Use the PERSONAL context (e.g. user_123), the best match.
  2. No history, but demographics available (segment, device, etc.)? Use the DEMOGRAPHIC context (e.g. segment_tech_enthusiast), learned from similar users.
  3. Otherwise, fall back to the GLOBAL context.

Key insight: Every click updates:

  1. Personal context (if user has history)
  2. Demographic context (benefits similar users)
  3. Global context (benefits everyone)

New users get instant personalization from their demographic group.
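
Here's a minimal sketch of that fallback and the multi-level update. The interaction_count and get_segment helpers are placeholders for lookups against your own user store:

def interaction_count(user_id):
    # Placeholder: number of logged clicks for this user.
    return 0

def get_segment(user_id):
    # Placeholder: demographic segment, or None if unknown.
    return "tech_enthusiast"

def select_context(user_id, min_interactions=5):
    """Pick the most specific context that has enough data."""
    if interaction_count(user_id) >= min_interactions:
        return f"user_{user_id}"            # personal
    segment = get_segment(user_id)
    if segment:
        return f"segment_{segment}"         # demographic (similar users)
    return "global"                         # ultimate fallback

def contexts_to_update(user_id):
    """Every click credits personal, demographic, and global contexts."""
    contexts = [f"user_{user_id}", "global"]
    segment = get_segment(user_id)
    if segment:
        contexts.insert(1, f"segment_{segment}")
    return contexts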


A Worked Example: E-commerce vs. Edtech

Let's say you're building video search. Initial state:

clip:  Beta(1, 1)  → no idea if visual matters
ocr:   Beta(1, 1)  → no idea if text matters
audio: Beta(1, 1)  → no idea if audio matters

After 20 clicks in e-commerce:

clip:  Beta(18, 4)  → 0.82 avg weight (visual matters!)
ocr:   Beta(5, 17)  → 0.23 avg weight (text doesn't)
audio: Beta(3, 19)  → 0.14 avg weight (audio irrelevant)

After 20 clicks in edtech:

clip:  Beta(6, 16)  → 0.27 avg weight (visual less important)
ocr:   Beta(19, 3)  → 0.86 avg weight (transcripts crucial)
audio: Beta(17, 5)  → 0.77 avg weight (lecture audio key)

Same features. Different domains. Automatic adaptation.
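
To see the adaptation concretely, here's a small illustration that samples weights from those two learned posteriors (the α/β counts are the ones above; the random seed is arbitrary):

import numpy as np

# (alpha, beta) learned after ~20 clicks in each domain, from the example above.
ecommerce = {"clip": (18, 4), "ocr": (5, 17), "audio": (3, 19)}
edtech    = {"clip": (6, 16), "ocr": (19, 3), "audio": (17, 5)}

rng = np.random.default_rng(0)
for name, domain in [("e-commerce", ecommerce), ("edtech", edtech)]:
    weights = {f: round(rng.beta(a, b), 2) for f, (a, b) in domain.items()}
    print(name, weights)
# e-commerce samples lean on clip; edtech samples lean on ocr and audio.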


Architecture Overview

Learned reranking architecture (figure): Vector Search (top 100 candidates) → Static Rerank (cross-encoder scoring) → Learned Rerank (Thompson Sampling) → Results to User. User interactions (clicks, views, etc.) drive a learning loop that extracts features and updates α/β in Redis; ClickHouse keeps the full interaction log for analytics.

Key components:

  1. Redis: Stores current α/β for each feature (hot, fast)
  2. ClickHouse: Stores interaction history (cold, analytics)
  3. Learned Rerank Stage: Samples weights, computes scores
  4. Update Webhook: Processes clicks, updates Redis (sketched just below)
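
A minimal sketch of that update webhook, reusing the LearnedReranker built in the DIY guide below. The event payload shape and the contexts list are assumptions, and the ClickHouse write is left as a comment because the client setup depends on your stack:

def handle_click_event(event, reranker, contexts):
    """Process one click: update learned weights (hot path), log it (cold path).

    event    : assumed to carry the clicked document's feature scores
    reranker : the LearnedReranker from the DIY guide below
    contexts : context ids to credit (personal, demographic, global)
    """
    clicked_doc = {"features": event["doc_features"]}

    # Hot path: bump α/β in Redis for every relevant context.
    for context_id in contexts:
        reranker.update_from_click(clicked_doc, context_id)

    # Cold path: append the raw interaction to ClickHouse for analytics.
    # clickhouse.insert("interactions", [event])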

DIY Implementation Guide

Prerequisites

pip install numpy redis

1. Storage Layer

import redis

class BanditStorage:
    def __init__(self, redis_url="redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)
    
    def get_params(self, context_id, features):
        """Get α/β for all features in a context."""
        key = f"bandit:{context_id}"
        params = {}
        
        for feature in features:
            alpha = self.redis.hget(key, f"{feature}.alpha")
            beta = self.redis.hget(key, f"{feature}.beta")
            
            # Default to uniform prior
            params[feature] = {
                "alpha": float(alpha) if alpha else 1.0,
                "beta": float(beta) if beta else 1.0
            }
        
        return params
    
    def update_params(self, context_id, feature_updates):
        """Update α or β for features."""
        key = f"bandit:{context_id}"
        
        for feature, updates in feature_updates.items():
            if "alpha" in updates:
                self.redis.hincrbyfloat(key, f"{feature}.alpha", updates["alpha"])
            if "beta" in updates:
                self.redis.hincrbyfloat(key, f"{feature}.beta", updates["beta"])

2. Thompson Sampling Algorithm

import numpy as np

class ThompsonSampling:
    @staticmethod
    def sample_weights(feature_params, exploration_bonus=1.0):
        """Sample feature weights from Beta distributions.

        exploration_bonus > 1 scales α/β down (fewer effective
        observations), widening each Beta and increasing exploration.
        """
        weights = {}

        for feature, params in feature_params.items():
            alpha = params["alpha"] / exploration_bonus
            beta = params["beta"] / exploration_bonus
            
            # Sample from Beta(α, β)
            weights[feature] = np.random.beta(alpha, beta)
        
        return weights
    
    @staticmethod
    def compute_score(doc_features, weights):
        """Weighted sum of features."""
        return sum(
            weights.get(f, 0.5) * v 
            for f, v in doc_features.items()
        )

3. Learned Reranker

class LearnedReranker:
    def __init__(self, storage):
        self.storage = storage
        self.sampler = ThompsonSampling()
    
    def rerank(self, documents, context_id):
        """Rerank documents using learned feature weights."""
        if not documents:
            return documents

        # Extract feature names from first doc
        features = list(documents[0]["features"].keys())
        
        # Get current parameters
        params = self.storage.get_params(context_id, features)
        
        # Sample weights
        weights = self.sampler.sample_weights(params)
        
        # Score all documents
        for doc in documents:
            doc["learned_score"] = self.sampler.compute_score(
                doc["features"], 
                weights
            )
        
        # Sort by learned score
        documents.sort(key=lambda d: d["learned_score"], reverse=True)
        
        return documents
    
    def update_from_click(self, clicked_doc, context_id, threshold=0.5):
        """Update feature parameters from user click."""
        updates = {}
        
        for feature, value in clicked_doc["features"].items():
            if value > threshold:
                # Strong feature gets credit
                updates[feature] = {"alpha": 1}
            else:
                # Weak feature gets penalty
                updates[feature] = {"beta": 1}
        
        self.storage.update_params(context_id, updates)

4. Usage Example

# Initialize
storage = BanditStorage()
reranker = LearnedReranker(storage)

# Sample documents with multimodal features
documents = [
    {
        "id": "doc_1",
        "features": {
            "clip": 0.85,
            "ocr": 0.23,
            "audio": 0.67,
            "metadata": 0.91
        }
    },
    {
        "id": "doc_2",
        "features": {
            "clip": 0.45,
            "ocr": 0.89,
            "audio": 0.12,
            "metadata": 0.56
        }
    }
]

# Rerank for a specific context (e.g., user_123 or segment_techies)
context_id = "user_123"
ranked_docs = reranker.rerank(documents, context_id)

# User clicks the top result
clicked_doc = ranked_docs[0]
reranker.update_from_click(clicked_doc, context_id)

# Next query will use updated weights!

Performance Characteristics

Metric                 Static Rerank      Learned Rerank
Latency overhead       0 ms               ~1-2 ms (sampling)
Storage per context    0 bytes            ~200 bytes (α/β per feature)
Training required      None               None
Adaptation speed       Never              ~5-10 interactions
Cold start handling    N/A                Hierarchical fallback
Personalization        No                 Yes (per user/segment)

When to Use Bandit Learning

✅ Good fit:

  • Multimodal retrieval (images, video, audio, text)
  • User behavior varies by segment/domain
  • Need to optimize click-through or engagement
  • Fast iteration required (no retraining)
  • Cold start is important

❌ Not a fit:

  • Static, compliance-driven ranking
  • No user feedback loop
  • Need deep collaborative filtering
  • Regulatory constraints on personalization

Common Pitfalls

1. Learning per document instead of per feature

# ❌ DON'T: Creates millions of parameters
bandit_params["doc_12345"] = {"alpha": 3, "beta": 1}

# ✅ DO: Learn which features matter
bandit_params["clip"] = {"alpha": 12, "beta": 3}

2. Forgetting cold start

# ❌ DON'T: Fail for new users
if user_has_history(user_id):
    context = f"user_{user_id}"
else:
    raise Exception("No history!")

# ✅ DO: Hierarchical fallback
if user_has_history(user_id, min_interactions=5):
    context = f"user_{user_id}"
elif user_has_demographics(user_id):
    context = f"segment_{user_segment}"
else:
    context = "global"

3. Not updating multiple contexts

# ❌ DON'T: Only update personal
update_params(f"user_{user_id}", feature_updates)

# ✅ DO: Update personal + demographic + global
for context in [personal_ctx, demographic_ctx, "global"]:
    update_params(context, feature_updates)

Advanced: Context Strategies

Different use cases need different context strategies:

# E-commerce: User-level personalization
context = f"user_{user_id}"

# B2B SaaS: Team-level learning
context = f"team_{team_id}"

# Content platform: Demographic clustering
context = f"segment_{age_group}_{device_type}"

# Search engine: Query-type clustering
context = f"query_type_{intent_category}"

Monitoring Your Bandit

Key metrics to track:

def get_feature_stats(storage, context_id, features):
    """Get current state of feature learning."""
    params = storage.get_params(context_id, features)
    
    stats = {}
    for feature, p in params.items():
        alpha, beta = p["alpha"], p["beta"]
        
        stats[feature] = {
            "mean_weight": alpha / (alpha + beta),
            "confidence": alpha + beta,  # Higher = more certain
            "preference": "high" if alpha > beta else "low"
        }
    
    return stats

# Example output:
# {
#   "clip": {"mean_weight": 0.82, "confidence": 23, "preference": "high"},
#   "ocr": {"mean_weight": 0.19, "confidence": 21, "preference": "low"}
# }

Why This Beats Alternatives

Approach            Training           Latency   Cold Start     Explainability
Thompson Sampling   None               ~1 ms     Hierarchical   High
XGBoost reranker    Batch (weekly)     ~50 ms    Poor           Medium
Twin towers         Offline (daily)    ~10 ms    Poor           Low
Neural reranker     Continuous         ~20 ms    Medium         Low
LinUCB bandit       None               ~2 ms     Medium         High

Thompson Sampling gives you the best trade-off for most production systems.


Real-World Impact

At Mixpeek, we've seen customers using learned reranking achieve:

  • 23% improvement in click-through rate (e-commerce visual search)
  • 31% increase in watch time (video content discovery)
  • 40% faster convergence vs. batch retraining (ad creative ranking)
  • Cold start working in <5 interactions (new user onboarding)

The key insight: The system learns what matters for YOUR domain automatically.


Try It With Mixpeek

Building this from scratch is fun for learning, but production systems need:

  • Distributed state management
  • Interaction tracking
  • Context isolation per tenant
  • Monitoring dashboards
  • Fallback strategies

Mixpeek's retrieval API includes learned reranking as a single stage:

{
  "pipeline": [
    {
      "stage_name": "retriever",
      "parameters": { ... }
    },
    {
      "stage_name": "rerank",
      "parameters": {
        "inference_name": "bge-reranker"
      }
    },
    {
      "stage_name": "learned_rerank",
      "parameters": {
        "algorithm": "thompson_sampling",
        "feature_fields": ["features.*"],
        "context_features": ["INPUT.user_id"]
      }
    }
  ]
}

Learn more about Mixpeek's learned reranking →


Questions? Drop them in the comments or join our Discord.