Cersei

Cookbook: ML Coding Agent

Build a coding agent specialized for machine learning workflows — training, evaluation, and model iteration.

Build an agent that understands ML codebases, runs training scripts, evaluates models, and iterates on architectures — all from natural language instructions.

Why Cersei for ML?

Unlike generic agent SDKs, Cersei gives you:

  • Full control over tool execution: run python train.py with a custom timeout and capture GPU metrics
  • Graph memory: track experiment results across sessions (loss curves, hyperparameters, what worked)
  • Tree-sitter code intelligence: parse Python ML code to understand model architectures before modifying them
  • Provider agnostic: use GPT-5.3 for cheap iteration, Claude Opus for complex architecture decisions

The Agent

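use cersei::prelude::*;
use cersei_tools::tool_primitives::process; // used by the custom tools below
// Note: `stream.next()` needs a `StreamExt` trait and `#[async_trait]` needs the
// async_trait macro in scope. If the prelude does not re-export them, import
// them explicitly (for example `use futures::StreamExt;`).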

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let provider = cersei_provider::from_model_string("auto")?.0;

    // Custom tools for ML workflows
    let mut tools = cersei_tools::coding(); // Read, Write, Edit, Glob, Grep, Bash
    tools.push(Box::new(TrainTool));
    tools.push(Box::new(EvalTool)); // not shown on this page; implement it like TrainTool
    tools.push(Box::new(GpuStatusTool));

    let agent = Agent::builder()
        .provider(provider)
        .tools(tools)
        .system_prompt(ML_SYSTEM_PROMPT)
        .max_turns(30) // ML tasks need many iterations
        .max_tokens(16384)
        .working_dir(".")
        .build()?;

    let agent = std::sync::Arc::new(agent);
    let mut stream = agent.run_stream(
        "Analyze the model architecture in model.py, then train for 5 epochs \
         and report the loss curve. If val_loss plateaus, suggest architecture changes."
    );

    while let Some(event) = stream.next().await {
        match event {
            AgentEvent::TextDelta(t) => print!("{t}"),
            AgentEvent::ToolStart { name, .. } => eprintln!("\n  > {name}"),
            AgentEvent::ToolEnd { name, duration, .. } => {
                eprintln!("  < {name} ({:.1}s)", duration.as_secs_f64());
            }
            AgentEvent::Complete(_) => break,
            _ => {}
        }
    }

    Ok(())
}

const ML_SYSTEM_PROMPT: &str = "You are an ML engineering agent. You can:
- Read and understand model architectures (PyTorch, TensorFlow, MLX)
- Run training scripts and capture metrics
- Analyze loss curves and suggest improvements
- Modify hyperparameters and model code
- Track experiments across sessions via memory

Always check GPU status before training. Report metrics in tables.";
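
The prompt above asks the agent to react when val_loss plateaus, but it leaves "plateau" undefined. One possible definition, sketched here in plain Rust with no Cersei dependency: the loss has plateaued when the best value over the last `patience` epochs improves on the best earlier value by less than `min_delta`. The function name and both parameters are illustrative assumptions, not part of Cersei.

```rust
/// Returns true when the most recent `patience` epochs failed to improve on
/// the best earlier validation loss by at least `min_delta`.
/// Hypothetical helper: the name and thresholds are illustrative only.
fn has_plateaued(val_losses: &[f64], patience: usize, min_delta: f64) -> bool {
    // At least one earlier epoch is needed as a baseline.
    if patience == 0 || val_losses.len() <= patience {
        return false;
    }
    let split = val_losses.len() - patience;
    let best_before = val_losses[..split]
        .iter()
        .copied()
        .fold(f64::INFINITY, f64::min);
    let best_recent = val_losses[split..]
        .iter()
        .copied()
        .fold(f64::INFINITY, f64::min);
    // Improvement is how far the recent best dropped below the earlier best.
    best_before - best_recent < min_delta
}
```

A check like this could run inside a tool or after a training run, so the agent receives an explicit "plateaued" signal instead of inferring it from raw logs.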

Custom Tools

Train Tool

struct TrainTool;

#[async_trait]
impl Tool for TrainTool {
    fn name(&self) -> &str { "Train" }
    fn description(&self) -> &str {
        "Run a training script with GPU monitoring. Returns stdout, stderr, and training metrics."
    }
    fn permission_level(&self) -> PermissionLevel { PermissionLevel::Execute }

    fn input_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "script": { "type": "string", "description": "Python script to run" },
                "args": { "type": "string", "description": "CLI arguments", "default": "" },
                "timeout_secs": { "type": "integer", "default": 600 }
            },
            "required": ["script"]
        })
    }

    async fn execute(&self, input: serde_json::Value, ctx: &ToolContext) -> ToolResult {
        let script = input["script"].as_str().unwrap_or("train.py");
        let args = input["args"].as_str().unwrap_or("");
        let timeout = input["timeout_secs"].as_u64().unwrap_or(600);

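        // The script and args are interpolated unquoted. If process::exec hands
        // this string to a shell, validate or quote untrusted input first.
        let cmd = format!("python {} {}", script, args);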
        let opts = process::ExecOptions {
            cwd: Some(ctx.working_dir.clone()),
            timeout: Some(std::time::Duration::from_secs(timeout)),
            ..Default::default()
        };

        match process::exec(&cmd, opts).await {
            Ok(output) => {
                let mut result = String::new();
                if !output.stdout.is_empty() {
                    result.push_str(&output.stdout);
                }
                if !output.stderr.is_empty() {
                    result.push_str("\n--- stderr ---\n");
                    result.push_str(&output.stderr);
                }
                if output.timed_out {
                    ToolResult::error(format!("Training timed out after {timeout}s\n{result}"))
                } else if output.exit_code != 0 {
                    ToolResult::error(format!("Training failed (exit {})\n{result}", output.exit_code))
                } else {
                    ToolResult::success(result)
                }
            }
            Err(e) => ToolResult::error(format!("Failed to start training: {e}")),
        }
    }
}
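
The tool description promises training metrics, while `execute` returns raw stdout and stderr. A minimal sketch of pulling `key=value` metrics out of log lines follows. It assumes the training script prints lines such as `epoch 10/10, val_loss=0.24, val_acc=0.89`; real scripts may log differently. Both `parse_metrics` and `last_metrics` are hypothetical helpers, not Cersei APIs.

```rust
use std::collections::BTreeMap;

/// Extracts numeric `key=value` pairs from one log line.
/// Assumes pairs are separated by commas or whitespace; tokens without `=`
/// or with non-numeric values are skipped. Hypothetical helper.
fn parse_metrics(line: &str) -> BTreeMap<String, f64> {
    let mut metrics = BTreeMap::new();
    for token in line.split(|c: char| c == ',' || c.is_whitespace()) {
        if let Some((key, value)) = token.split_once('=') {
            if let Ok(v) = value.parse::<f64>() {
                metrics.insert(key.to_string(), v);
            }
        }
    }
    metrics
}

/// Returns the metrics from the last stdout line that contains any.
fn last_metrics(stdout: &str) -> Option<BTreeMap<String, f64>> {
    stdout
        .lines()
        .rev()
        .map(parse_metrics)
        .find(|m| !m.is_empty())
}
```

Inside `execute`, the result of `last_metrics(&output.stdout)` could be appended to the tool output so the model sees the final metrics without scanning the whole log.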

GPU Status Tool

struct GpuStatusTool;

#[async_trait]
impl Tool for GpuStatusTool {
    fn name(&self) -> &str { "GpuStatus" }
    fn description(&self) -> &str { "Check GPU availability and memory usage." }
    fn permission_level(&self) -> PermissionLevel { PermissionLevel::ReadOnly }

    fn input_schema(&self) -> serde_json::Value {
        serde_json::json!({ "type": "object", "properties": {} })
    }

    async fn execute(&self, _input: serde_json::Value, _ctx: &ToolContext) -> ToolResult {
        // Try nvidia-smi first, fall back to Apple MPS check
        let nvidia = process::exec(
            "nvidia-smi --query-gpu=name,memory.used,memory.total,utilization.gpu --format=csv,noheader",
            Default::default(),
        ).await;

        if let Ok(output) = nvidia {
            if output.exit_code == 0 {
                return ToolResult::success(format!("NVIDIA GPU:\n{}", output.stdout));
            }
        }

        // Check for Apple Silicon MPS
        let mps = process::exec(
            "python -c \"import torch; print(f'MPS available: {torch.backends.mps.is_available()}')\"",
            Default::default(),
        ).await;

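        match mps {
            // Exit code 0 only proves torch imported, so also check the printed flag.
            Ok(output) if output.exit_code == 0 && output.stdout.contains("True") => {
                ToolResult::success(format!("Apple Silicon:\n{}", output.stdout))
            }
            _ => ToolResult::success("No GPU detected. Training will use CPU."),
        }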
    }
}
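
The tool above returns the nvidia-smi output as raw text. If a structured form is useful (for example, to refuse training when memory is nearly full), the CSV rows can be parsed as sketched below. This assumes rows shaped like `NVIDIA A100-SXM4-40GB, 1024 MiB, 40960 MiB, 12 %`, which is what the query in the tool produces with `--format=csv,noheader`. `GpuInfo` and the parse functions are hypothetical names.

```rust
/// One row of the nvidia-smi query used by GpuStatusTool.
/// Hypothetical struct for illustration.
#[derive(Debug, PartialEq)]
struct GpuInfo {
    name: String,
    memory_used_mib: u64,
    memory_total_mib: u64,
    utilization_pct: u64,
}

/// Parses one CSV row. Returns None if the row does not have four fields
/// or a numeric field is unreadable (for example `[N/A]`).
fn parse_gpu_row(row: &str) -> Option<GpuInfo> {
    let fields: Vec<&str> = row.split(',').map(str::trim).collect();
    if fields.len() != 4 {
        return None;
    }
    // Take the leading number and ignore the unit suffix ("MiB", "%").
    let number = |s: &str| s.split_whitespace().next()?.parse::<u64>().ok();
    Some(GpuInfo {
        name: fields[0].to_string(),
        memory_used_mib: number(fields[1])?,
        memory_total_mib: number(fields[2])?,
        utilization_pct: number(fields[3])?,
    })
}

/// Parses every row of the command output; unparseable rows are dropped.
fn parse_gpu_status(stdout: &str) -> Vec<GpuInfo> {
    stdout.lines().filter_map(parse_gpu_row).collect()
}
```

Dropping unparseable rows keeps the sketch short; a stricter version might report them so a driver problem is not mistaken for a missing GPU.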

Using Tree-Sitter for Model Analysis

Before modifying ML code, use tree-sitter to understand the architecture:

use cersei_tools::tool_primitives::code_intel;
use std::path::Path;

// Scan the ML project
let intels = code_intel::scan_project(Path::new("./"), 20);

// Find model definition files
for intel in &intels {
    let has_model = intel.symbols.iter().any(|s| {
        s.name.contains("Model") || s.name.contains("Network") || s.name.contains("Net")
    });
    if has_model {
        println!("Model file: {} — symbols:", intel.path.display());
        for sym in &intel.symbols {
            println!("  {} {} (line {})", sym.kind.label(), sym.name, sym.line);
        }
    }
}
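
The substring check above is permissive: `contains("Net")` already covers `contains("Network")`, and it also matches names such as `NetworkError` or `InternetClient`. A stricter heuristic, assuming model classes follow the common convention of ending in `Model`, `Network`, or `Net`, might look like this. `looks_like_model` and its suffix list are hypothetical.

```rust
/// Heuristic: treats a symbol as a model definition when its name ends with
/// a common model suffix. Hypothetical helper; the suffix list is an
/// assumption about naming conventions, not a Cersei feature.
fn looks_like_model(name: &str) -> bool {
    const SUFFIXES: [&str; 3] = ["Model", "Network", "Net"];
    SUFFIXES.iter().any(|suffix| name.ends_with(suffix))
}
```

It can replace the closure body in the loop above, for example `intel.symbols.iter().any(|s| looks_like_model(&s.name))`, assuming `s.name` is a string type as the original check implies.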

Experiment Tracking with Graph Memory

Store experiment results in Cersei's graph memory for cross-session recall:

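use cersei_memory::graph::GraphMemory;
use std::path::PathBuf;

// Rust does not expand "~", so build the path from $HOME explicitly.
let home = std::env::var("HOME")?;
let graph = GraphMemory::open(&PathBuf::from(home).join(".abstract/graph.db"))?;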

// After each training run, store results
graph.store_memory(
    "experiment: lr=1e-4, epochs=10, val_loss=0.23, val_acc=0.91",
    cersei_memory::memdir::MemoryType::Project,
    0.95, // high confidence
)?;

// Tag with topic
let mem_id = graph.store_memory(/* ... */)?;
graph.tag_memory(&mem_id, "experiment-results")?;

// Later: recall what worked
let results = graph.by_topic("experiment-results");
// Returns: all experiment results sorted by confidence
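
Free-form strings are easy to store but hard to compare later. One option is to generate the memory text from a struct and parse it back on recall. The sketch below uses only the standard library and reproduces the example string stored above; `Experiment`, its methods, and `field` are hypothetical names, and the text format is an assumption rather than anything Cersei requires.

```rust
/// A single training run. Hypothetical struct; fields mirror the example
/// memory string.
#[derive(Debug, PartialEq)]
struct Experiment {
    lr: f64,
    epochs: u32,
    val_loss: f64,
    val_acc: f64,
}

/// Looks up the value for `key` in a comma-separated `key=value` list.
fn field<'a>(body: &'a str, key: &str) -> Option<&'a str> {
    body.split(',')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(k, _)| *k == key)
        .map(|(_, v)| v)
}

impl Experiment {
    /// Renders the record in the `experiment: key=value, ...` form.
    fn to_memory_text(&self) -> String {
        format!(
            "experiment: lr={:e}, epochs={}, val_loss={}, val_acc={}",
            self.lr, self.epochs, self.val_loss, self.val_acc
        )
    }

    /// Parses text produced by `to_memory_text`; returns None on any
    /// missing or malformed field.
    fn from_memory_text(text: &str) -> Option<Self> {
        let body = text.strip_prefix("experiment:")?;
        Some(Experiment {
            lr: field(body, "lr")?.parse::<f64>().ok()?,
            epochs: field(body, "epochs")?.parse::<u32>().ok()?,
            val_loss: field(body, "val_loss")?.parse::<f64>().ok()?,
            val_acc: field(body, "val_acc")?.parse::<f64>().ok()?,
        })
    }
}
```

The output of `to_memory_text` would be passed as the first argument to `store_memory`, and recalled entries could be run through `from_memory_text` before sorting or comparing them.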

Full Example Flow

User: "Train the model on the new dataset and compare with last run"

Agent:
1. [GpuStatus] → MPS available: True
2. [Read] model.py → Understands architecture (ResNet variant, 8M params)
3. [Grep] "val_loss" in training logs → Finds last run: val_loss=0.31
4. [Train] python train.py --data new_dataset --epochs 10
   → Training output: epoch 10/10, val_loss=0.24, val_acc=0.89
5. [Memory] Stores: "new_dataset run: val_loss=0.24 (improved from 0.31)"
6. Synthesizes comparison table with improvement analysis
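
Step 6 leaves the comparison to the model. If a deterministic table is preferred, the arithmetic is simple enough to do in a helper, sketched here with the values from the flow above (val_loss 0.31 for the last run, 0.24 for the new one). `compare_runs` and its column layout are hypothetical.

```rust
/// Renders a two-run comparison for one metric as a header row plus a
/// data row. Hypothetical helper; the column layout is illustrative.
fn compare_runs(metric: &str, previous: f64, current: f64) -> String {
    let delta = current - previous;
    // Relative change in percent; undefined when the previous value is zero.
    let relative = if previous != 0.0 {
        format!("{:+.1}%", delta / previous * 100.0)
    } else {
        "n/a".to_string()
    };
    format!(
        "{:<10} {:>8} {:>8} {:>8} {:>8}\n{:<10} {:>8.2} {:>8.2} {:>+8.2} {:>8}",
        "metric", "previous", "current", "delta", "change",
        metric, previous, current, delta, relative
    )
}
```

Calling `compare_runs("val_loss", 0.31, 0.24)` yields a delta of -0.07 and a relative change of -22.6%. For a loss, a negative delta is an improvement; for accuracy the sign reads the other way, so the caller should label the direction.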
