Writing a tool handler for the Experiment Engine
A "tool" is a capability like L2 Quantization or L5 Batching — a
self-contained optimization that can run as an experiment over a given
(model, hardware, inference engine) triple. Each tool registers
exactly one ToolHandler at import time; the experiment engine takes
care of queueing, quota enforcement, Celery dispatch, lifecycle state
transitions, cancellation, SSE event fan-out, and Prometheus metrics.
What the engine gives you
- Per-org quota: monthly GPU-hour budget + concurrent-experiment cap,
enforced at create time (402 over-budget, 429 over-concurrency).
- RLS isolation: rows are auto-scoped to the caller's org. Workers
bypass via SET LOCAL app.rls_bypass='true'.
- State machine: `pending → queued → running → {completed, failed,
cancelled, timed_out}`. The runner is the only thing allowed to
transition to a terminal state.
- Event stream: every `ctx.emit(...)` writes to the
`experiment_events` table and publishes on a Redis channel. An SSE
client connecting mid-run replays history, then tails live events.
- Cooperative cancellation: `POST /experiments/{id}/cancel` sets the
row's status to cancelled; your handler checks
`await ctx.check_cancelled()` between stages and returns.
- Metrics: `experiments_created_total`, `experiments_completed_total{status}`,
`experiment_duration_seconds`, `experiment_queue_wait_seconds`, all
ready for Grafana (dashboard UID `iiq-experiments`).
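The state machine in the list above can be sketched as a small transition table. This is a minimal illustration, not the engine's code: the `Status` enum, the `ALLOWED` table, and the `transition` helper are hypothetical names, and the table models only the arrows shown in the text (the real engine may permit more, for example cancelling a run that is still queued).

```python
from enum import Enum


class Status(str, Enum):
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    TIMED_OUT = "timed_out"


TERMINAL = {Status.COMPLETED, Status.FAILED, Status.CANCELLED, Status.TIMED_OUT}

# Hypothetical transition table: only the forward path described in the text.
ALLOWED = {
    Status.PENDING: {Status.QUEUED},
    Status.QUEUED: {Status.RUNNING},
    Status.RUNNING: set(TERMINAL),
}


class InvalidTransitionError(Exception):
    """Stand-in for the engine's error of the same name."""


def transition(current: Status, new: Status) -> Status:
    """Return `new` if the move is allowed, otherwise raise."""
    # Terminal states have no entry in ALLOWED, so nothing can follow them.
    if new not in ALLOWED.get(current, set()):
        raise InvalidTransitionError(f"{current.value} -> {new.value}")
    return new
```

Because terminal states have no outgoing entries, any write after completion is rejected, which is the behaviour a second terminal call would run into.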
Minimal handler
```python
from app.services.inferenceiq.experiments.registry import ToolHandler, register


def _validate(cfg: dict) -> list[str]:
    # Called by POST /experiments BEFORE the row is created. Return a
    # list of human-readable error strings; empty list = valid. cfg
    # is a merged view: {**input_config, "target": target}.
    errors = []
    if not cfg.get("target", {}).get("model_id"):
        errors.append("target.model_id required")
    if not cfg.get("candidates"):
        errors.append("input_config.candidates must be non-empty")
    return errors


def _estimate_cost(cfg: dict) -> float:
    # GPU-hours this experiment will consume. Used by the quota check
    # before enqueue. Return 0 if the tool is measurement-free (e.g.,
    # L0 Hardware Profiling doesn't hold a GPU).
    n_candidates = len(cfg.get("candidates", []))
    return 0.1 * n_candidates  # ~6 minutes of a GPU per candidate


async def _run(ctx) -> None:
    # Your work goes here. Emit progress events between stages. Call
    # ctx.save_results(...) exactly once at the end to transition
    # queued → running → completed cleanly.
    await ctx.emit("progress", stage="profiling", payload={})
    for i, candidate in enumerate(ctx.input_config["candidates"]):
        if await ctx.check_cancelled():
            return  # Cooperative exit; the runner marks CANCELLED
        await ctx.emit("progress", stage=f"candidate-{i}",
                       payload={"name": candidate["name"]})
        # ... do the actual work ...
    await ctx.save_results({
        "winner": "fp8-awq",
        "delta_tps_pct": 38,
        "artifact_s3_url": "s3://inwire-experiments/.../winner.safetensors",
    })


register(ToolHandler(
    tool_id="L42",
    validate=_validate,
    estimate_cost=_estimate_cost,
    run=_run,
))
```
Context API
The ctx passed to your run function is an ExperimentContext with:
| Attribute / Method | Purpose |
|---|---|
| `ctx.experiment_id` | UUID: the row this run is writing to |
| `ctx.target` | the `target` JSON dict from the request |
| `ctx.input_config` | the `input_config` JSON dict from the request |
| `ctx.goal` | the `goal` enum string from the request (may be `None`) |
| `await ctx.emit(event_type, stage, payload)` | Persist one row to `experiment_events` and publish on the Redis pub/sub channel; SSE subscribers see it immediately |
| `await ctx.check_cancelled()` | Returns `True` if the row's status is now `cancelled`. Call between every slow stage. |
| `await ctx.save_results(summary)` | Terminal call. Transitions the row to `completed` and writes `summary` to `results_summary`. Raises `InvalidTransitionError` if called twice. |
Later specs (02 Knowledge Base, 03 Vidur Sweeper, 04 IGT Hardware
Profiler) will also inject ctx.vidur, ctx.igt, ctx.kb, and
ctx.benchmark — stable collaborator objects your handler can call
without reaching into service code.
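To exercise a handler without a database or Redis, an in-memory stand-in that mirrors the attributes and methods listed above may be enough. The sketch below is an assumption-heavy illustration: `FakeContext`, its `cancel_after_events` knob, and the `run_sweep` handler are hypothetical and not part of the engine, and the double-call check raises a plain `RuntimeError` rather than the engine's `InvalidTransitionError`.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeContext:
    """Hypothetical in-memory stand-in for ExperimentContext (tests only)."""
    input_config: dict
    target: dict = field(default_factory=dict)
    goal: Optional[str] = None
    experiment_id: str = "00000000-0000-0000-0000-000000000000"
    cancel_after_events: Optional[int] = None  # simulate a cancel request
    events: list = field(default_factory=list)
    results: Optional[dict] = None

    async def emit(self, event_type: str, stage: str, payload: dict) -> None:
        # The real context persists and publishes; here we only record.
        self.events.append((event_type, stage, payload))

    async def check_cancelled(self) -> bool:
        return (self.cancel_after_events is not None
                and len(self.events) >= self.cancel_after_events)

    async def save_results(self, summary: dict) -> None:
        if self.results is not None:
            raise RuntimeError("save_results called twice")
        self.results = summary


async def run_sweep(ctx) -> None:
    """Toy handler: one progress event per candidate, then one summary."""
    for i, candidate in enumerate(ctx.input_config["candidates"]):
        if await ctx.check_cancelled():
            return
        await ctx.emit("progress", stage=f"candidate-{i}",
                       payload={"name": candidate["name"]})
    await ctx.save_results({"n_candidates": len(ctx.input_config["candidates"])})


ctx = FakeContext(input_config={"candidates": [{"name": "fp8"}, {"name": "int4"}]})
asyncio.run(run_sweep(ctx))
```

After the run, `ctx.events` holds the emitted tuples in order and `ctx.results` holds the summary, so both can be asserted on directly.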
Common patterns
- Fatal errors: raise an exception. The runner catches it, sets
status=failed, persists error_message, and emits a failed
event. Don't swallow exceptions inside _run.
- Check cancellation before every slow op: benchmark launches,
gRPC calls, S3 uploads. The cancellation is cooperative — a hard
kill would skip cleanup.
- Sweeps: emit one `progress` event per candidate so the UI can
show per-item state in the timeline.
- Big artifacts: write to S3 (via `inwire-agent` or the object
storage integration) and put only a reference / presigned URL in
`results_summary`. The DB row is designed to stay small.
- Deterministic handlers: if your tool has a recipe file, include
a content-hash in results_summary so downstream consumers can
cache-key on it.
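For the content-hash pattern, one possible approach is to hash the recipe file's raw bytes with SHA-256 and store the hex digest alongside the other results. The helper name `recipe_content_hash`, the `recipe_sha256` key, and the sample recipe bytes are illustrative assumptions; the engine does not prescribe a key name or a hash function.

```python
import hashlib


def recipe_content_hash(recipe_bytes: bytes) -> str:
    """Hex SHA-256 digest of the recipe file's raw bytes."""
    return hashlib.sha256(recipe_bytes).hexdigest()


# Stand-in for the bytes you would read from the recipe file on disk.
recipe_bytes = b"quant: fp8\ncalibration_samples: 512\n"

# Hypothetical summary layout; only the hash field is the point here.
summary = {
    "winner": "fp8-awq",
    "recipe_sha256": recipe_content_hash(recipe_bytes),
}
```

Hashing raw bytes means that whitespace-only edits to the recipe change the key. If that is too strict for your tool, hashing a canonicalised form of the parsed recipe is an alternative.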
What NOT to do
- Don't call `ctx.save_results` twice. It raises
`InvalidTransitionError`: the state machine rejects writes to a
terminal state.
- Don't catch and swallow exceptions inside `_run`. The runner
needs them to mark the experiment failed and record the message.
- Don't assume an HTTP request context. Your handler runs inside a
Celery worker. There's no UserContext, no request headers. If you
need the caller's identity it's recorded on the row as created_by.
- Don't do your own DB commits unless you understand RLS. The
runner opens a worker session with app.rls_bypass='true' set; any
raw SQL you issue will not be org-scoped. Use the ORM and let the
engine manage boundaries.
- Don't emit PII in event payloads. Events are persisted and
broadcast on a Redis channel — they are not the right place for
user-identifying data.
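If a stage needs cleanup, a `try`/`finally` block lets the handler tidy up without swallowing the error. The sketch below is illustrative only: `run_stage_with_cleanup` and the demo stages are hypothetical, and the `except` in `_demo` stands in for the runner, which in the real engine is the component that catches the exception.

```python
import asyncio


async def run_stage_with_cleanup(stage, cleanup) -> None:
    """Run `stage`, always run `cleanup`, and let any exception propagate."""
    try:
        await stage()
    finally:
        # finally runs on success and on failure, and it does not
        # suppress the exception, so the caller still sees it.
        await cleanup()


async def _demo() -> list:
    log = []

    async def failing_stage() -> None:
        log.append("stage started")
        raise RuntimeError("benchmark crashed")

    async def cleanup() -> None:
        log.append("scratch dir removed")

    try:
        await run_stage_with_cleanup(failing_stage, cleanup)
    except RuntimeError as exc:
        # Stand-in for the runner recording error_message.
        log.append(f"runner saw: {exc}")
    return log


demo_log = asyncio.run(_demo())
```

The cleanup step runs before the exception reaches the caller, so temporary resources are released and the failure is still recorded.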
Testing your handler
Every handler should ship with:
- A validation unit test: a valid cfg returns `[]`, missing fields
return the expected error strings.
- A cost estimation unit test — known inputs map to known GPU-hour
values (regressions break quota accounting).
- A happy-path lifecycle test — register the handler, call
run_experiment(db, id, bus) with a mocked event bus, assert
terminal state + results_summary. See
tests/unit/test_experiment_lifecycle.py for the pattern.
Metrics
All metric names start with experiment_ and live in
app.services.inferenceiq.experiments.metrics. The Grafana dashboard
(infra/monitoring/grafana/dashboards/experiments.json, UID
iiq-experiments) visualises them out of the box. If your tool needs
a custom histogram — e.g., per-candidate eval latency — add it to
that module and add a panel to the dashboard. Don't invent new metric
namespaces.
Spec 01 v1.0 scope — what's in, what's deferred
In v1.0 (PR #217 + PR #[follow-up]):
- 3 tables (`experiments`, `experiment_events`, `experiment_quota`) with RLS
- State machine, tool handler registry, ExperimentContext
- 6 REST endpoints (POST/GET/{id}/LIST/cancel/events/quota)
- Celery task wrapper with 30-min `soft_time_limit` watchdog
- Per-org GPU-hour quota + concurrency cap, charged on COMPLETED /
CANCELLED / TIMED_OUT / FAILED when actual_gpu_hours > 0
- Monthly quota reset: counter zeros when `month_resets_at` passes
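The quota rules (402 over budget, 429 over concurrency, charge only when actual GPU-hours are positive, reset when `month_resets_at` passes) can be sketched in memory as follows. Everything here is an assumption made for illustration: the `OrgQuota` fields other than `month_resets_at`, the order of the two checks, the 201 success code, and rolling the reset date to the first of the next month.

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class OrgQuota:
    """Hypothetical in-memory mirror of an experiment_quota row."""
    monthly_gpu_hours: float
    used_gpu_hours: float
    max_concurrent: int
    month_resets_at: datetime


def next_month_start(now: datetime) -> datetime:
    year, month = (now.year + 1, 1) if now.month == 12 else (now.year, now.month + 1)
    return datetime(year, month, 1, tzinfo=now.tzinfo)


def reset_if_due(q: OrgQuota, now: datetime) -> None:
    # Zero the counter once the reset timestamp has passed.
    if now >= q.month_resets_at:
        q.used_gpu_hours = 0.0
        q.month_resets_at = next_month_start(now)


def admission_status(q: OrgQuota, estimated_gpu_hours: float,
                     running_now: int, now: datetime) -> int:
    """HTTP status a create request would get under these assumed rules."""
    reset_if_due(q, now)
    if q.used_gpu_hours + estimated_gpu_hours > q.monthly_gpu_hours:
        return 402  # over monthly GPU-hour budget
    if running_now >= q.max_concurrent:
        return 429  # over concurrent-experiment cap
    return 201


def charge(q: OrgQuota, actual_gpu_hours: float) -> None:
    # Called on any terminal state; zero-cost runs leave the counter alone.
    if actual_gpu_hours > 0:
        q.used_gpu_hours += actual_gpu_hours
```

Admission uses the handler's estimate while charging uses the measured value, so an inaccurate `estimate_cost` affects who gets admitted but not what is billed.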
Deferred to v1.1:
- Experiment templates — save a configured experiment as a template,
re-run it later with different inputs. Spec §Open Questions proposed
YES for v1.1. Tracking: new experiment_templates table + `POST
/experiments/from-template/{id}` endpoint. No code yet.
- Per-tool `timeout_seconds` override: currently all experiments
use the module-wide 30 min soft_time_limit. Tools that need longer
(L8 Distillation can take hours) will need a per-task override via
run_experiment_task.apply_async(soft_time_limit=...).
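Since the per-tool override has no code yet, the following is only one way it could be shaped: a lookup that produces the keyword arguments for `apply_async`. The override table, the 4-hour value for L8, and the `dispatch_options` helper are all hypothetical.

```python
DEFAULT_SOFT_TIME_LIMIT_S = 30 * 60  # module-wide 30 min limit from the text

# Hypothetical per-tool overrides, in seconds.
TOOL_TIMEOUT_OVERRIDES_S = {
    "L8": 4 * 60 * 60,  # assumed value; distillation "can take hours"
}


def dispatch_options(tool_id: str) -> dict:
    """Keyword arguments to pass to apply_async for this tool."""
    limit = TOOL_TIMEOUT_OVERRIDES_S.get(tool_id, DEFAULT_SOFT_TIME_LIMIT_S)
    return {"soft_time_limit": limit}


# Intended call site (not executed here, needs Celery and the task):
#   run_experiment_task.apply_async(args=[experiment_id],
#                                   **dispatch_options(tool_id))
```

Keeping the lookup separate from the Celery call makes it testable without a broker.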
Not in scope for this spec at all:
- Cost modelling (`$/hr` instead of raw GPU-hours): that's spec 13
Apply+Measure territory
- Cross-org experiment sharing (public experiment catalog) — that's a
separate product decision