CALM Training

Overview

This document covers training and fine-tuning CALM models: model creation, optimizer configuration, loss function, EWC regularization, learning rate schedules, weight initialization, and the calm train CLI.

Quick Start

Train a model on synthetic data in under 20 lines:

local calm = require("calm")

-- 1. Create a model with random weights
local model = calm.new_model({
    d_model = 64, n_layers = 3, ffn_expand = 2,
    expand = 2, d_state = 16, d_conv = 4,
    l_max = 64, seed = 42,
})

-- 2. Create a trainer
local trainer = calm.trainer(model, {
    lr = 1e-3, optimizer = "adam",
})

-- 3. Train on a sequence (tokens after ATN are the prediction target)
local seq = { calm.BOS, calm.ATN, calm.CMD, 103, 105, 116, calm.EOS }
local cmd_pos = 1  -- 0-indexed position of ATN

for step = 1, 200 do
    local loss = trainer:step({ seq }, { cmd_pos })
    if step % 50 == 0 then print("step " .. step .. "  loss=" .. loss) end
end

-- 4. Save the trained model
trainer:save("/tmp/my_model.cwgt")
trainer:close()
model:close()

Model Creation

From scratch (random weights):

local model = calm.new_model({
    d_model    = 128,
    n_layers   = 6,
    ffn_expand = 4,
    expand     = 2,        -- d_inner = d_model × expand (default 2)
    d_state    = 16,       -- SSM state dimensions (default 16)
    d_conv     = 4,        -- short conv kernel size (default 4)
    l_max      = 768,
    seed       = 42,
    domain     = "shell",
    template   = "BOS;CWD:cwd;GIT:git;...;ATN;CMD:input",
    stop_conditions = "| ; && ||",
    sampler_defaults = {
        temperature = 0.8, top_k = 5,
        max_tokens = 20, num_candidates = 5,
    },
})

From existing weights:

local model = calm.load_model("/path/to/weights.cwgt")

This loads the weight file, validates the header, allocates activations, and prepares the model. The model is immediately ready for inference or further training.

Trainer Configuration

local trainer = calm.trainer(model, {
    lr           = 1e-4,   -- learning rate
    optimizer    = "adam",  -- "adam" or "sgd"
    beta1        = 0.9,    -- Adam first moment decay
    beta2        = 0.999,  -- Adam second moment decay
    eps          = 1e-8,   -- Adam epsilon
    weight_decay = 0.01,   -- decoupled weight decay (AdamW)
    grad_clip    = 1.0,    -- max gradient norm (0 = no clipping)
    ewc_lambda   = 0.0,    -- EWC regularization strength (0 = disabled)
})

The trainer allocates all required buffers on creation:

| Buffer | Size (Mini config, L=768) |
|---|---|
| Gradient buffer | ~7 MB |
| Adam m + v | ~14 MB |
| Training activations | ~15 MB |
| Backward scratch | ~3 MB |

All buffers are freed when trainer:close() is called.

Adam (default) converges faster and is recommended for most training. Its two moment buffers (m and v) add about 2x the weight memory on top of the gradient buffer, so gradients plus optimizer state come to roughly 3x the weight memory.

SGD is simpler and keeps no state beyond the gradient buffer. It requires higher learning rates (typically 10-100x Adam's LR) and more steps to converge, and is useful for quick fine-tuning under tight memory constraints.
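To make the difference in optimizer state concrete, here is a minimal pure-Lua sketch of one update for each optimizer over a flat parameter array. It uses the textbook AdamW rule (bias correction plus decoupled weight decay) and plain SGD; whether CALM's internal update matches this in every detail is an assumption, and the names adamw_step, sgd_step, and new_adam_state are hypothetical helpers, not part of the calm API.

```lua
-- Hypothetical helper: Adam keeps two extra arrays (m and v) per parameter.
local function new_adam_state(n)
    local s = { t = 0, m = {}, v = {} }
    for i = 1, n do s.m[i] = 0; s.v[i] = 0 end
    return s
end

-- Textbook AdamW step (assumed, not confirmed to be CALM's exact rule).
local function adamw_step(params, grads, state, cfg)
    state.t = state.t + 1
    local bc1 = 1 - cfg.beta1 ^ state.t   -- bias correction for m
    local bc2 = 1 - cfg.beta2 ^ state.t   -- bias correction for v
    for i = 1, #params do
        local g = grads[i]
        state.m[i] = cfg.beta1 * state.m[i] + (1 - cfg.beta1) * g
        state.v[i] = cfg.beta2 * state.v[i] + (1 - cfg.beta2) * g * g
        local m_hat = state.m[i] / bc1
        local v_hat = state.v[i] / bc2
        -- Decoupled decay: shrink the weight directly, not through the gradient.
        params[i] = params[i] - cfg.lr * cfg.weight_decay * params[i]
        params[i] = params[i] - cfg.lr * m_hat / (math.sqrt(v_hat) + cfg.eps)
    end
end

-- Plain SGD needs nothing beyond the gradient itself.
local function sgd_step(params, grads, lr)
    for i = 1, #params do
        params[i] = params[i] - lr * grads[i]
    end
end

local params = { 1.0, -2.0 }
local grads  = { 0.5, -0.5 }
local state  = new_adam_state(#params)
adamw_step(params, grads, state, {
    lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, weight_decay = 0.0,
})

local sgd_params = { 1.0, -2.0 }
sgd_step(sgd_params, grads, 0.01)
```

On the first step Adam's normalized update has magnitude close to lr regardless of gradient scale, while the SGD update scales with the gradient. That is one reason the two optimizers want different learning rates.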

Loss Function

Standard cross-entropy loss over next-token prediction:

loss = CrossEntropy(logits[:-1, :], target_ids[1:])

Loss is computed on all tokens after the <ATN> boundary. The cmd_pos value stores the 0-indexed position of <ATN> in the sequence. The loss loop starts at this position, predicting tokens[cmd_pos + 1] onward — so every token after <ATN> is a prediction target. If <ATN> is absent, cmd_pos is 0 and loss covers the entire sequence (full-sequence mode).

For shell datasets, <ATN> appears immediately before <CMD>, so <CMD> itself is the first prediction target, followed by the command bytes. For other domains (e.g. dictionary), the first content token after <ATN> is the first target. The mechanism is the same — <ATN> marks where loss begins.
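The masking rule can be sketched in pure Lua without the calm module. The token IDs, the tiny vocabulary, and the helper names (find_cmd_pos, nll, masked_loss) are hypothetical, and averaging over the target positions is an assumption, since the text does not say whether the loss is summed or averaged.

```lua
-- Hypothetical token ID for illustration; real code would use calm.ATN.
local ATN = 3

-- Return the 0-indexed position of ATN, or 0 if absent (full-sequence mode).
local function find_cmd_pos(tokens)
    for i = 1, #tokens do
        if tokens[i] == ATN then return i - 1 end
    end
    return 0
end

-- Numerically stable -log softmax(logits)[target]; target is a 0-based token ID.
local function nll(logits, target)
    local max = -math.huge
    for _, x in ipairs(logits) do if x > max then max = x end end
    local sum = 0
    for _, x in ipairs(logits) do sum = sum + math.exp(x - max) end
    return math.log(sum) + max - logits[target + 1]
end

-- logits[p + 1] is the model output at 0-indexed position p, which predicts
-- the token at position p + 1. Loss starts at cmd_pos, so the first target
-- is the token right after ATN.
local function masked_loss(logits, tokens, cmd_pos)
    local total, count = 0, 0
    for p = cmd_pos, #tokens - 2 do
        local target = tokens[p + 2]   -- token at 0-indexed position p + 1
        total = total + nll(logits[p + 1], target)
        count = count + 1
    end
    if count == 0 then return 0, 0 end
    return total / count, count       -- mean over targets (assumed)
end

-- Four tokens over a 4-entry vocabulary, ATN at 0-indexed position 1.
local tokens = { 0, ATN, 1, 2 }
local logits = {}
for p = 1, #tokens do logits[p] = { 0, 0, 0, 0 } end   -- uniform predictions

local cmd_pos = find_cmd_pos(tokens)
local loss, count = masked_loss(logits, tokens, cmd_pos)
```

With uniform logits every target costs log(vocabulary size), and only the two tokens after ATN contribute.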

Training Loop

Single batch step:

local loss = trainer:step(sequences, cmd_positions)

Each step() call:

  1. Zeros all gradients

  2. Runs forward + backward for each sequence, accumulating gradients

  3. Averages gradients over the batch

  4. Adds EWC penalty gradient (if ewc_lambda > 0 and Fisher is computed)

  5. Clips gradients to grad_clip norm

  6. Applies optimizer update (Adam or SGD)
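The steps above can be sketched over plain Lua tables. This is an illustration under stated assumptions, not CALM's implementation: the per-sequence gradients are passed in precomputed (standing in for forward + backward), the EWC term is omitted, clipping is by global L2 norm, and the update is plain SGD. The names clip_by_norm and batch_step are hypothetical.

```lua
-- Scale grad in place so its global L2 norm does not exceed max_norm.
-- max_norm <= 0 disables clipping, mirroring grad_clip = 0.
local function clip_by_norm(grad, max_norm)
    if max_norm <= 0 then return 1.0 end
    local sq = 0
    for i = 1, #grad do sq = sq + grad[i] * grad[i] end
    local norm = math.sqrt(sq)
    if norm <= max_norm then return 1.0 end
    local scale = max_norm / norm
    for i = 1, #grad do grad[i] = grad[i] * scale end
    return scale
end

local function batch_step(params, per_seq_grads, cfg)
    local n = #params
    local grad = {}
    for i = 1, n do grad[i] = 0 end                      -- 1. zero gradients
    for _, g in ipairs(per_seq_grads) do                 -- 2. accumulate
        for i = 1, n do grad[i] = grad[i] + g[i] end
    end
    local inv = 1 / #per_seq_grads                       -- 3. average over batch
    for i = 1, n do grad[i] = grad[i] * inv end
    -- 4. EWC penalty gradient would be added here (omitted in this sketch)
    clip_by_norm(grad, cfg.grad_clip)                    -- 5. clip
    for i = 1, n do                                      -- 6. SGD update
        params[i] = params[i] - cfg.lr * grad[i]
    end
    return grad
end

local params = { 0.0, 0.0 }
local grad = batch_step(params, { { 3, 0 }, { 3, 8 } }, { lr = 0.1, grad_clip = 1.0 })
```

Here the averaged gradient is (3, 4) with norm 5, so clipping to norm 1 rescales it to (0.6, 0.8) before the update.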

Training from scratch:

local calm = require("calm")

local model = calm.new_model({
    d_model = 128, n_layers = 6, ffn_expand = 4,
    expand = 2, d_state = 16, d_conv = 4,
    l_max = 768, seed = 42,
})

local trainer = calm.trainer(model, {
    lr = 1e-3, optimizer = "adam",
    weight_decay = 0.01, grad_clip = 1.0,
})

local ds = calm.load_dataset("/path/to/train.ctds")

local num_epochs = 10
local batch_size = 32
local batches_per_epoch = math.ceil(ds:count() / batch_size)

for epoch = 1, num_epochs do
    ds:shuffle()
    local epoch_loss, num_batches = 0, 0
    for b = 0, batches_per_epoch - 1 do
        local seqs, cmds = ds:batch(b, batch_size)
        if #seqs == 0 then break end
        local loss = trainer:step(seqs, cmds)
        epoch_loss = epoch_loss + loss
        num_batches = num_batches + 1
    end
    -- Average over the batches actually processed (the loop may break early)
    print(string.format("epoch %d  avg_loss=%.4f", epoch, epoch_loss / math.max(num_batches, 1)))
end

trainer:save("/path/to/trained.cwgt")
trainer:close()
model:close()
ds:close()

Fine-tuning existing weights:

Same as above, but start from calm.load_model() instead of calm.new_model(), and use a lower learning rate:

local model = calm.load_model("/path/to/base_model.cwgt")
local trainer = calm.trainer(model, {
    lr = 5e-5, optimizer = "adam",
    weight_decay = 0.01, grad_clip = 1.0,
})
-- ... training loop ...

EWC Regularization

Elastic Weight Consolidation prevents catastrophic forgetting during fine-tuning. When the model fine-tunes on new user history, EWC penalizes large changes to parameters that were important for previously learned patterns.

loss_ewc = (λ / 2) × Σ_i F_i × (θ_i - θ*_i)²

Where F_i is the Fisher information (diagonal approximation) for parameter i, θ*_i is the parameter value after previous training, and λ controls the regularization strength.

Workflow:

-- 1. Train on initial data
local model = calm.load_model("base.cwgt")
local trainer = calm.trainer(model, { lr = 1e-4, optimizer = "adam" })

for epoch = 1, 5 do
    -- ... training loop on dataset A ...
end

-- 2. Compute Fisher information and save anchor
trainer:compute_fisher(validation_seqs, validation_cmd_pos)

-- 3. Save model (includes EWC data: Fisher diagonal + anchor weights)
trainer:save("model_with_ewc.cwgt")
trainer:close()
model:close()

-- 4. Later: fine-tune on new data with EWC protection
local model2 = calm.load_model("model_with_ewc.cwgt")
local trainer2 = calm.trainer(model2, {
    lr = 1e-4, optimizer = "adam",
    ewc_lambda = 10.0,  -- regularization strength
})

for epoch = 1, 5 do
    -- ... training loop on dataset B ...
    -- EWC penalty automatically prevents drift from anchor
end

trainer2:save("model_updated.cwgt")
trainer2:close()
model2:close()

After compute_fisher(), the Fisher information diagonal F[i] is stored per parameter, approximating how important each parameter is for the current task. The current weights are saved as the "anchor" theta*. Both are written into the weight file when saved.
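One common way to estimate a diagonal Fisher is the mean of squared per-sequence gradients (the "empirical Fisher"). The sketch below assumes that estimator; the source does not state which one compute_fisher() uses. The gradients are passed in as plain tables, and fisher_diagonal and snapshot are hypothetical names.

```lua
-- Empirical diagonal Fisher: F[i] = mean over sequences of grad[i]^2.
-- Assumed estimator; CALM's compute_fisher() may differ.
local function fisher_diagonal(per_seq_grads)
    local n = #per_seq_grads[1]
    local fisher = {}
    for i = 1, n do fisher[i] = 0 end
    for _, g in ipairs(per_seq_grads) do
        for i = 1, n do fisher[i] = fisher[i] + g[i] * g[i] end
    end
    for i = 1, n do fisher[i] = fisher[i] / #per_seq_grads end
    return fisher
end

-- Copy the current weights so later updates do not alter the anchor.
local function snapshot(params)
    local anchor = {}
    for i = 1, #params do anchor[i] = params[i] end
    return anchor
end

local params = { 0.25, -1.5 }
local fisher = fisher_diagonal({ { 1, 0 }, { 3, 0 } })   -- parameter 2 never moves
local anchor = snapshot(params)
params[1] = 9.0                                          -- anchor is unaffected
```

A parameter whose gradient is consistently near zero on the validation sequences gets a Fisher value near zero, so EWC leaves it free to change.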

During subsequent training with ewc_lambda > 0, each optimizer step adds a penalty gradient:

grad[i] += ewc_lambda * F[i] * (theta[i] - theta*[i])
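The penalty gradient line above translates directly into a loop over the parameter array. In this minimal sketch, add_ewc_grad is a hypothetical helper name and the arrays are plain Lua tables rather than CALM's internal buffers.

```lua
-- Add the EWC pull-back term to an existing gradient, in place.
-- grad, params, anchor, and fisher are arrays of equal length.
local function add_ewc_grad(grad, params, anchor, fisher, ewc_lambda)
    if ewc_lambda <= 0 then return end   -- 0 disables EWC
    for i = 1, #grad do
        grad[i] = grad[i] + ewc_lambda * fisher[i] * (params[i] - anchor[i])
    end
end

-- Both parameters drifted by +0.5 from the anchor, but only the first
-- one carries Fisher weight.
local grad   = { 0.0, 0.0 }
local params = { 1.5, 1.5 }
local anchor = { 1.0, 1.0 }
local fisher = { 2.0, 0.0 }
add_ewc_grad(grad, params, anchor, fisher, 10.0)
```

The first parameter receives a gradient of 10 pushing it back toward the anchor, while the second receives none, which is how EWC protects important weights and leaves the rest free to adapt.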

Choosing ewc_lambda:

| Value | Behavior |
|---|---|
| 0 | No EWC (free adaptation) |
| 0.1 - 1.0 | Mild: allows substantial adaptation |
| 1.0 - 10.0 | Moderate: balances old and new |
| 10.0 - 100.0 | Strong: heavily preserves old behavior |
| > 100.0 | Very strong: new data barely changes the model |

Start with ewc_lambda = 10.0 and adjust based on whether the model retains enough of its original capability.

Learning Rate Schedule

For base model training: cosine schedule with warmup.

For fine-tuning: constant low learning rate (e.g., 1e-4 to 5e-5). No warmup needed for fine-tuning.

| Scenario | Optimizer | LR | Notes |
|---|---|---|---|
| From scratch (Nano/Micro) | Adam | 1e-3 | Small models tolerate higher LR |
| From scratch (Mini/Small) | Adam | 3e-4 | Larger models need lower LR |
| Fine-tuning | Adam | 1e-4 to 5e-5 | Lower to avoid forgetting |
| Quick adaptation | SGD | 0.01 | Few steps, coarse updates |
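A cosine schedule with linear warmup can be written as a small pure function of the step index. The exact shape CALM uses is not specified here, so treat this as one common formulation; lr_at and its config fields (peak_lr, min_lr, warmup_steps, total_steps) are hypothetical names.

```lua
-- Learning rate for a 0-indexed step: linear warmup to peak_lr,
-- then cosine decay down to min_lr at total_steps.
local function lr_at(step, cfg)
    if step < cfg.warmup_steps then
        return cfg.peak_lr * (step + 1) / cfg.warmup_steps
    end
    local span = math.max(1, cfg.total_steps - cfg.warmup_steps)
    local progress = math.min(1, (step - cfg.warmup_steps) / span)
    local cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return cfg.min_lr + (cfg.peak_lr - cfg.min_lr) * cosine
end

local cfg = { peak_lr = 1e-3, min_lr = 1e-5, warmup_steps = 100, total_steps = 1100 }
for _, step in ipairs({ 0, 99, 600, 1100 }) do
    print(string.format("step %4d  lr=%.6f", step, lr_at(step, cfg)))
end
```

In a training loop, the returned value could be applied before each batch with trainer:set_lr(lr_at(step, cfg)), using the set_lr method listed in the Lua API section.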

Gradient Computation

The backward pass requires:

Weight Initialization

Practical Tips

Batch size: Batch size 1 works but is noisy. Batch sizes of 8-32 provide smoother gradients. The training step averages gradients over the batch, so larger batches give more stable updates at the cost of more computation per step.

Monitoring convergence: Watch the loss value returned by trainer:step(). For memorization (overfitting to a small dataset), loss should approach 0. For generalization, track loss on a held-out validation set separately using model:forward().

Memory budget:

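| Component | Nano (d=64) | Mini (d=128) | Small (d=192) |
|---|---|---|---|
| Weights | ~1 MB | ~8 MB | ~28 MB |
| Gradients | ~1 MB | ~8 MB | ~28 MB |
| Adam state | ~2 MB | ~16 MB | ~56 MB |
| Train activations | ~2 MB | ~15 MB | ~40 MB |
| Scratch buffers | ~1 MB | ~3 MB | ~6 MB |
| Total | ~7 MB | ~50 MB | ~158 MB |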

All training memory is freed when trainer:close() is called. The model retains only its weights and inference activation buffers.
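The totals in the memory budget follow from adding the components. The short sketch below reproduces that arithmetic and shows the saving from SGD, which keeps no optimizer state beyond the gradient buffer. The figures are copied from the table; budget_mb and total_mb are hypothetical names.

```lua
-- Approximate per-component memory in MB, copied from the budget table.
local budget_mb = {
    nano  = { weights = 1,  gradients = 1,  adam = 2,  activations = 2,  scratch = 1 },
    mini  = { weights = 8,  gradients = 8,  adam = 16, activations = 15, scratch = 3 },
    small = { weights = 28, gradients = 28, adam = 56, activations = 40, scratch = 6 },
}

-- Sum the components; the Adam moment buffers only count when Adam is used.
local function total_mb(b, optimizer)
    local total = b.weights + b.gradients + b.activations + b.scratch
    if optimizer == "adam" then total = total + b.adam end
    return total
end

for _, size in ipairs({ "nano", "mini", "small" }) do
    local b = budget_mb[size]
    print(string.format("%-5s  adam=%d MB  sgd=%d MB",
        size, total_mb(b, "adam"), total_mb(b, "sgd")))
end
```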

Shell Builtin: calm train

calm train accepts options:

When training a model from scratch (--model new), use a higher learning rate with cosine schedule for best results:

calm train --model new --size nano --lr 1e-3 --warmup-steps 100 --epochs 10 --batch-size 32

calm train runs in the foreground. Use job start calm train ... for background execution.

calm evaluate computes average, min, and max per-sequence loss on a CTDS dataset. Useful for checking model quality on a held-out test set.
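The summary that calm evaluate reports can be illustrated with a few lines of Lua over a list of per-sequence losses. How those losses are obtained is left out, and summarize is a hypothetical helper, not part of the CLI or the calm module.

```lua
-- Reduce a non-empty list of per-sequence losses to average, min, and max.
local function summarize(losses)
    local sum, lo, hi = 0, math.huge, -math.huge
    for _, l in ipairs(losses) do
        sum = sum + l
        if l < lo then lo = l end
        if l > hi then hi = l end
    end
    return { avg = sum / #losses, min = lo, max = hi }
end

local stats = summarize({ 0.5, 1.5, 1.0 })
print(string.format("avg=%.4f  min=%.4f  max=%.4f", stats.avg, stats.min, stats.max))
```

A large gap between the average and the maximum usually points at a few sequences the model handles badly, which the average alone would hide.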

calm benchmark creates temporary models at each size (Nano through Small) and measures forward pass and completion latency.

Lua API

Training operations

local calm = require("calm")

-- Start a training session (from existing or freshly initialized model)
local trainer = calm.trainer(model, {
    lr = 1e-4,              -- learning rate
    optimizer = "adam",      -- "adam" or "sgd"
    weight_decay = 0.01,    -- AdamW weight decay
    ewc_lambda = 0.5,       -- EWC regularization strength (0 = disabled)
    grad_clip = 1.0,        -- gradient clipping norm
})

-- Feed a batch of training sequences
-- sequences is an array of token ID arrays
-- cmd_positions is an array of ATN positions (for loss masking)
local loss = trainer:step(sequences, cmd_positions)

-- Compute and store Fisher information for EWC
-- (call after training, before saving weights)
trainer:compute_fisher(validation_sequences, validation_cmd_positions)

-- Save updated weights (includes EWC data if computed)
trainer:save("/path/to/weights.cwgt")

-- Adjust learning rate dynamically
trainer:set_lr(5e-5)

-- Contrastive training step (for embedding models)
-- queries and positives are arrays of token ID arrays
local loss = trainer:contrastive_step(queries, positives, {
    temperature = 0.07,  -- InfoNCE temperature (default: 0.07)
    pool = "mean",       -- pooling: "mean" or "last" (default: "mean")
})

-- Manual gradient accumulation (forward+backward without optimizer step)
trainer:accumulate(tokens, cmd_pos)

-- Access/clear gradients
local grad = trainer:get_grad(param_index)
trainer:zero_grad()

-- Free training state (optimizer moments, etc.)
trainer:close()
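The contrastive_step comments in the listing above name InfoNCE with a temperature and "mean" pooling. The sketch below shows a common formulation in pure Lua: mean-pool the per-token vectors, L2-normalize, and treat the other positives in the batch as negatives. Using in-batch negatives and cosine similarity is an assumption about CALM's internals, and all function names are hypothetical.

```lua
-- Average a sequence of per-token vectors into one embedding ("mean" pooling).
local function mean_pool(hidden)
    local d, out = #hidden[1], {}
    for j = 1, d do out[j] = 0 end
    for _, h in ipairs(hidden) do
        for j = 1, d do out[j] = out[j] + h[j] end
    end
    for j = 1, d do out[j] = out[j] / #hidden end
    return out
end

-- Scale a non-zero vector to unit L2 length.
local function normalize(v)
    local sq = 0
    for _, x in ipairs(v) do sq = sq + x * x end
    local norm, out = math.sqrt(sq), {}
    for j, x in ipairs(v) do out[j] = x / norm end
    return out
end

local function dot(a, b)
    local s = 0
    for j = 1, #a do s = s + a[j] * b[j] end
    return s
end

-- InfoNCE: for query i, positives[i] is the match and every other
-- positives[j] acts as a negative (assumed in-batch negatives).
local function info_nce(queries, positives, temperature)
    local total = 0
    for i = 1, #queries do
        local logits, max = {}, -math.huge
        for j = 1, #positives do
            logits[j] = dot(queries[i], positives[j]) / temperature
            if logits[j] > max then max = logits[j] end
        end
        local sum = 0
        for j = 1, #positives do sum = sum + math.exp(logits[j] - max) end
        total = total + (math.log(sum) + max - logits[i])
    end
    return total / #queries
end

local q1 = normalize(mean_pool({ { 2, 0 }, { 4, 0 } }))   -- points along x
local q2 = normalize(mean_pool({ { 0, 1 }, { 0, 3 } }))   -- points along y
local aligned = info_nce({ q1, q2 }, { q1, q2 }, 0.07)
local swapped = info_nce({ q1, q2 }, { q2, q1 }, 0.07)
```

When each query matches its own positive the loss is near zero. With the pairs swapped it rises to roughly 1 / temperature, which shows how a low temperature sharpens the penalty for mismatches.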