Cookbook: ML Coding Agent
Build a coding agent specialized for machine learning workflows — training, evaluation, and model iteration.
Build an agent that understands ML codebases, runs training scripts, evaluates models, and iterates on architectures — all from natural language instructions.
Why Cersei for ML?
Unlike generic agent SDKs, Cersei gives you:
- Full control over tool execution: run python train.py with a custom timeout and capture GPU metrics
- Graph memory: track experiment results across sessions (loss curves, hyperparams, what worked)
- Tree-sitter code intelligence: parse Python ML code to understand model architectures before modifying them
- Provider agnostic: use GPT-5.3 for cheap iteration and Claude Opus for complex architecture decisions
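The first point is about owning the process lifecycle. Here is a rough, std-only illustration of what a "custom timeout" involves: spawn the run, poll it, and kill it once a deadline passes. It is not Cersei's `process::exec`, whose internals aren't shown here. `run_with_timeout` and `RunOutcome` are hypothetical names, and the demo in `main` assumes a Unix `sleep` binary.

```rust
use std::io;
use std::process::Command;
use std::thread::sleep;
use std::time::{Duration, Instant};

/// How a supervised child process ended (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
enum RunOutcome {
    /// Exited on its own; `None` means it was terminated by a signal (Unix).
    Exited(Option<i32>),
    /// Still running at the deadline, so it was killed.
    TimedOut,
}

/// Spawn `cmd`, poll it every 50 ms, and kill it if it outlives `timeout`.
fn run_with_timeout(mut cmd: Command, timeout: Duration) -> io::Result<RunOutcome> {
    let mut child = cmd.spawn()?;
    let deadline = Instant::now() + timeout;
    loop {
        // try_wait never blocks: it returns Some(status) once the child has exited.
        if let Some(status) = child.try_wait()? {
            return Ok(RunOutcome::Exited(status.code()));
        }
        if Instant::now() >= deadline {
            child.kill()?;
            child.wait()?; // reap the killed child so it does not linger as a zombie
            return Ok(RunOutcome::TimedOut);
        }
        sleep(Duration::from_millis(50));
    }
}

fn main() -> io::Result<()> {
    // In practice this would be Command::new("python") with args like
    // ["train.py", "--epochs", "5"] and a 600 s budget. Here `sleep 5` with a
    // 200 ms budget stands in for a training run that overruns its timeout.
    let mut slow = Command::new("sleep");
    slow.arg("5");
    let outcome = run_with_timeout(slow, Duration::from_millis(200))?;
    println!("{outcome:?}");
    Ok(())
}
```

Output capture is left out for brevity. If you pipe stdout, drain it on a separate thread, or a chatty training script can block on a full pipe before the timeout ever fires.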
The Agent
use cersei::prelude::*;
use cersei_tools::tool_primitives::{code_intel, process};
use std::path::Path;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let provider = cersei_provider::from_model_string("auto")?.0;
// Custom tools for ML workflows
let mut tools = cersei_tools::coding(); // Read, Write, Edit, Glob, Grep, Bash
tools.push(Box::new(TrainTool));
tools.push(Box::new(EvalTool)); // EvalTool is not shown in this cookbook; implement it like TrainTool
tools.push(Box::new(GpuStatusTool));
let agent = Agent::builder()
.provider(provider)
.tools(tools)
.system_prompt(ML_SYSTEM_PROMPT)
.max_turns(30) // ML tasks need many iterations
.max_tokens(16384)
.working_dir(".")
.build()?;
let agent = std::sync::Arc::new(agent);
let mut stream = agent.run_stream(
"Analyze the model architecture in model.py, then train for 5 epochs \
and report the loss curve. If val_loss plateaus, suggest architecture changes."
);
while let Some(event) = stream.next().await {
match event {
AgentEvent::TextDelta(t) => print!("{t}"),
AgentEvent::ToolStart { name, .. } => eprintln!("\n > {name}"),
AgentEvent::ToolEnd { name, duration, .. } => {
eprintln!(" < {name} ({:.1}s)", duration.as_secs_f64());
}
AgentEvent::Complete(_) => break,
_ => {}
}
}
Ok(())
}
const ML_SYSTEM_PROMPT: &str = "You are an ML engineering agent. You can:
- Read and understand model architectures (PyTorch, TensorFlow, MLX)
- Run training scripts and capture metrics
- Analyze loss curves and suggest improvements
- Modify hyperparameters and model code
- Track experiments across sessions via memory
Always check GPU status before training. Report metrics in tables.";
Custom Tools
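The task in `main` asks the agent to notice when val_loss plateaus. A custom tool can make that call deterministically instead of leaving it to the model. Below is a minimal sketch of the check, similar in spirit to the patience/min-delta rule used by early-stopping callbacks. `is_plateau`, `patience`, and `min_delta` are hypothetical names, not part of Cersei.

```rust
/// True when the best val_loss in the last `patience` epochs fails to beat
/// the best loss seen before that window by at least `min_delta`.
fn is_plateau(val_losses: &[f64], patience: usize, min_delta: f64) -> bool {
    if patience == 0 || val_losses.len() <= patience {
        return false; // not enough history to judge
    }
    let split = val_losses.len() - patience;
    let best_before = val_losses[..split].iter().copied().fold(f64::INFINITY, f64::min);
    let best_recent = val_losses[split..].iter().copied().fold(f64::INFINITY, f64::min);
    // Improvement means a lower loss; a gain smaller than min_delta (or a
    // regression) counts as a stall.
    best_before - best_recent < min_delta
}

fn main() {
    let improving = [0.90, 0.70, 0.55, 0.45, 0.38];
    let stalled = [0.90, 0.60, 0.41, 0.409, 0.410, 0.408];
    println!("improving: {}", is_plateau(&improving, 2, 0.01)); // false
    println!("stalled:   {}", is_plateau(&stalled, 3, 0.01)); // true
}
```

Wrapping this in a `Tool` impl follows the same pattern as TrainTool: read the loss history from the input JSON and return the verdict as text.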
Train Tool
struct TrainTool;
#[async_trait]
impl Tool for TrainTool {
fn name(&self) -> &str { "Train" }
fn description(&self) -> &str {
"Run a Python training script with a timeout. Returns stdout and stderr, including any metrics the script prints."
}
fn permission_level(&self) -> PermissionLevel { PermissionLevel::Execute }
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"script": { "type": "string", "description": "Python script to run" },
"args": { "type": "string", "description": "CLI arguments", "default": "" },
"timeout_secs": { "type": "integer", "default": 600 }
},
"required": ["script"]
})
}
async fn execute(&self, input: serde_json::Value, ctx: &ToolContext) -> ToolResult {
let script = input["script"].as_str().unwrap_or("train.py");
let args = input["args"].as_str().unwrap_or("");
let timeout = input["timeout_secs"].as_u64().unwrap_or(600);
let cmd = format!("python {} {}", script, args);
let opts = process::ExecOptions {
cwd: Some(ctx.working_dir.clone()),
timeout: Some(std::time::Duration::from_secs(timeout)),
..Default::default()
};
match process::exec(&cmd, opts).await {
Ok(output) => {
let mut result = String::new();
if !output.stdout.is_empty() {
result.push_str(&output.stdout);
}
if !output.stderr.is_empty() {
result.push_str("\n--- stderr ---\n");
result.push_str(&output.stderr);
}
if output.timed_out {
ToolResult::error(format!("Training timed out after {timeout}s\n{result}"))
} else if output.exit_code != 0 {
ToolResult::error(format!("Training failed (exit {})\n{result}", output.exit_code))
} else {
ToolResult::success(result)
}
}
Err(e) => ToolResult::error(format!("Failed to start training: {e}")),
}
}
}
GPU Status Tool
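The tool below returns nvidia-smi's CSV as raw text. If you want to act on free memory in code, for example to refuse a run that won't fit, here is a minimal std-only parser for one row of the `--query-gpu=name,memory.used,memory.total,utilization.gpu --format=csv,noheader` output. It assumes the default unit suffixes (`MiB`, `%`) that appear when `nounits` is not requested. `GpuRow` and `parse_gpu_row` are hypothetical, and the sample row is illustrative rather than captured output.

```rust
/// One parsed nvidia-smi row (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
struct GpuRow {
    name: String,
    mem_used_mib: u64,
    mem_total_mib: u64,
    util_pct: u32,
}

impl GpuRow {
    fn free_mib(&self) -> u64 {
        self.mem_total_mib.saturating_sub(self.mem_used_mib)
    }
}

/// Strip a unit suffix such as "MiB" or "%" and parse the number before it.
fn parse_with_unit<T: std::str::FromStr>(field: &str, unit: &str) -> Option<T> {
    field.trim().strip_suffix(unit)?.trim().parse().ok()
}

/// Parse "name, used MiB, total MiB, util %"; returns None on any mismatch.
fn parse_gpu_row(line: &str) -> Option<GpuRow> {
    let fields: Vec<&str> = line.split(',').map(str::trim).collect();
    if fields.len() != 4 {
        return None;
    }
    Some(GpuRow {
        name: fields[0].to_string(),
        mem_used_mib: parse_with_unit(fields[1], "MiB")?,
        mem_total_mib: parse_with_unit(fields[2], "MiB")?,
        util_pct: parse_with_unit(fields[3], "%")?,
    })
}

fn main() {
    // nvidia-smi prints one line per GPU; this sample row is made up.
    let sample = "Example GPU, 1024 MiB, 40960 MiB, 3 %";
    for row in sample.lines().filter_map(parse_gpu_row) {
        println!("{}: {} MiB free, {}% busy", row.name, row.free_mib(), row.util_pct);
    }
}
```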
struct GpuStatusTool;
#[async_trait]
impl Tool for GpuStatusTool {
fn name(&self) -> &str { "GpuStatus" }
fn description(&self) -> &str { "Check GPU availability and memory usage." }
fn permission_level(&self) -> PermissionLevel { PermissionLevel::ReadOnly }
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({ "type": "object", "properties": {} })
}
async fn execute(&self, _input: serde_json::Value, _ctx: &ToolContext) -> ToolResult {
// Try nvidia-smi first, fall back to Apple MPS check
let nvidia = process::exec(
"nvidia-smi --query-gpu=name,memory.used,memory.total,utilization.gpu --format=csv,noheader",
Default::default(),
).await;
if let Ok(output) = nvidia {
if output.exit_code == 0 {
return ToolResult::success(format!("NVIDIA GPU:\n{}", output.stdout));
}
}
// Check for Apple Silicon MPS
let mps = process::exec(
"python -c \"import torch; print(f'MPS available: {torch.backends.mps.is_available()}')\"",
Default::default(),
).await;
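match mps {
// The Python check exits 0 wherever torch is installed, so also require "True"
Ok(output) if output.exit_code == 0 && output.stdout.contains("True") => {
ToolResult::success(format!("Apple Silicon:\n{}", output.stdout))
}
_ => ToolResult::success("No GPU detected. Training will use CPU."),
}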
}
}
Using Tree-Sitter for Model Analysis
Before modifying ML code, use tree-sitter to understand the architecture:
use cersei_tools::tool_primitives::code_intel;
use std::path::Path;
// Scan the ML project
let intels = code_intel::scan_project(Path::new("./"), 20);
// Find model definition files
for intel in &intels {
// "Net" also matches "Network" and names such as "ResNet"
let has_model = intel.symbols.iter().any(|s| s.name.contains("Model") || s.name.contains("Net"));
if has_model {
println!("Model file: {} — symbols:", intel.path.display());
for sym in &intel.symbols {
println!(" {} {} (line {})", sym.kind.label(), sym.name, sym.line);
}
}
}
Experiment Tracking with Graph Memory
Store experiment results in Cersei's graph memory for cross-session recall:
use cersei_memory::graph::GraphMemory;
use std::path::PathBuf;
// Path::new does not expand "~", so build the path from $HOME explicitly
let db_path = PathBuf::from(std::env::var("HOME")?).join(".abstract/graph.db");
let graph = GraphMemory::open(&db_path)?;
// After each training run, store the result and keep its id
let mem_id = graph.store_memory(
"experiment: lr=1e-4, epochs=10, val_loss=0.23, val_acc=0.91",
cersei_memory::memdir::MemoryType::Project,
0.95, // high confidence
)?;
// Tag that same memory with a topic (a second store_memory call would create a duplicate)
graph.tag_memory(&mem_id, "experiment-results")?;
// Later: recall what worked
let results = graph.by_topic("experiment-results");
// Returns: all experiment results sorted by confidence
Full Example Flow
User: "Train the model on the new dataset and compare with last run"
Agent:
1. [GpuStatus] → MPS available: True
2. [Read] model.py → Understands architecture (ResNet variant, 8M params)
3. [Grep] "val_loss" in training logs → Finds last run: val_loss=0.31
4. [Train] python train.py --data new_dataset --epochs 10
→ Training output: epoch 10/10, val_loss=0.24, val_acc=0.89
5. [Memory] Stores: "new_dataset run: val_loss=0.24 (improved from 0.31)"
6. Synthesizes comparison table with improvement analysis
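Steps 3 to 6 come down to pulling metrics out of text and comparing them with the previous run. The agent does this in natural language. If you would rather have the comparison table computed than generated, a deterministic helper is easy to sketch. It assumes the training script prints `key=value` pairs like `val_loss=0.24`, as in the flow above. `parse_metrics` and `compare` are hypothetical helpers, not Cersei APIs.

```rust
use std::collections::BTreeMap;

/// Collect `key=value` tokens whose value parses as a float.
/// Later occurrences overwrite earlier ones, so the final epoch wins.
fn parse_metrics(log: &str) -> BTreeMap<String, f64> {
    let mut metrics = BTreeMap::new();
    for token in log.split(|c: char| c.is_whitespace() || c == ',') {
        if let Some((key, value)) = token.split_once('=') {
            if let Ok(v) = value.parse::<f64>() {
                metrics.insert(key.to_string(), v);
            }
        }
    }
    metrics
}

/// Build table rows for every metric present in both runs.
fn compare(prev: &BTreeMap<String, f64>, curr: &BTreeMap<String, f64>) -> Vec<String> {
    let mut rows = vec![
        "| metric | previous | current | delta |".to_string(),
        "|---|---|---|---|".to_string(),
    ];
    for (name, &now) in curr {
        if let Some(&before) = prev.get(name) {
            rows.push(format!("| {name} | {before:.3} | {now:.3} | {:+.3} |", now - before));
        }
    }
    rows
}

fn main() {
    // Values taken from the example flow above.
    let prev = parse_metrics("epoch 10/10, val_loss=0.31");
    let curr = parse_metrics("epoch 10/10, val_loss=0.24, val_acc=0.89");
    for row in compare(&prev, &curr) {
        println!("{row}");
    }
}
```

Only metrics present in both runs get a row, so val_acc is skipped here because the previous run recorded only val_loss.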