/*
* lrnnSMDDS - linear RNN/Reservoir hybrid for CPU (C Implementation)
*
* PolyForm Noncommercial License 1.0.0
* https://polyformproject.org/licenses/noncommercial/1.0.0/
*
* ======================================================
*
* Features:
* S. SwiGLU in Channel Mixing (more coherence)
* M. Multi-Scale Token Shift (larger context/"infinite")
* D. Data-Dependent Decay with Low-Rank (speed in large context)
* D. Dynamic State Checkpointing (faster/linear generation)
* S. Slot-memory reservoir (perfect recall, transformer-style, legacy/proven)
*
* Compile on cygwin on Windows (or POSIX linux, gcc is needed!):
*
* gcc -std=c17 -O3 -march=native -ffast-math -Wall -Wextra -o lrnn aismdd.c
*
* Usage:
* Training: ./lrnn --train corpus.txt --save model.bin --epochs 20
* Generate: ./lrnn --load model.bin --seed "Hello world" --tokens 200
*/
#define _POSIX_C_SOURCE 200809L
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <limits.h>
#include <ctype.h>
#include <time.h>
#include <errno.h>
/* ============================================================
* Constants and Configuration
* ============================================================ */
#define MAX_VOCAB_SIZE 65536
#define MAX_LAYERS 32
#define EPSILON 1e-6f
#define GRAD_CLIP 50.0f
/* Model hyperparameters.
 * NOTE: this struct is serialized verbatim with fwrite(cfg, sizeof(lrnnConfig), ...)
 * in save_model()/load_model(), so field order, types, and padding are part of
 * the on-disk model format — do not reorder or resize fields. */
typedef struct {
    int vocab_size;       /* token count; overwritten by the tokenizer build */
    int n_layer;          /* number of stacked blocks */
    int n_embd;           /* embedding width (must be divisible by n_head) */
    int n_head;           /* WKV heads */
    int ctx_len;          /* training context length */
    int decay_lora_rank;  /* rank of the data-dependent decay LoRA */
    float ffn_multiplier; /* FFN hidden width = n_embd * ffn_multiplier */
    int n_mem_slots;      /* slots in the slot-memory reservoir */
} lrnnConfig;
static lrnnConfig default_config(void) {
lrnnConfig cfg = {
.vocab_size = 256, // overwritten by build_vocabulary()
.n_layer = 2,
.n_embd = 64,
.n_head = 1,
.ctx_len = 64,
.decay_lora_rank = 2,
.ffn_multiplier = 2.0f,
.n_mem_slots = 4
};
return cfg;
}
/* Width of the SwiGLU hidden layer, truncated toward zero. */
static inline int ffn_hidden(const lrnnConfig *cfg) {
    const float scaled = cfg->n_embd * cfg->ffn_multiplier;
    return (int)scaled;
}
//forward help
/* ============================================================
* Tensor Structure and Operations
* ============================================================ */
/* Dense row-major 2-D float matrix; vectors are stored as (n, 1). */
typedef struct {
    float *data; /* rows*cols floats, row-major; NULL when size == 0 */
    int rows;
    int cols;
    int size;    /* rows * cols, cached */
} Tensor;
/* Allocate a zero-initialized rows x cols tensor.
 * Policy: abort the process on overflow or allocation failure.
 * Fix: rows*cols is now guarded against signed int overflow (UB) before
 * the multiplication is performed. */
static Tensor tensor_alloc(int rows, int cols) {
    Tensor t;
    t.rows = rows;
    t.cols = cols;
    t.data = NULL;
    if (rows > 0 && cols > 0 && rows > INT_MAX / cols) {
        fprintf(stderr, "Error: tensor dimensions overflow (%d x %d)\n", rows, cols);
        exit(1);
    }
    t.size = rows * cols;
    if (t.size > 0) {
        t.data = (float *)calloc((size_t)t.size, sizeof(float));
        if (!t.data) {
            fprintf(stderr, "Error: tensor allocation failed (%d x %d)\n", rows, cols);
            exit(1);
        }
    }
    return t;
}
/* Convenience wrapper: a vector is represented as a (size x 1) tensor. */
static Tensor tensor_alloc_1d(int size) {
    Tensor vec = tensor_alloc(size, 1);
    return vec;
}
/* Release backing storage and reset metadata; safe to call on NULL. */
static void tensor_free(Tensor *t) {
    if (!t) {
        return;
    }
    free(t->data); /* free(NULL) is a no-op */
    t->data = NULL;
    t->rows = 0;
    t->cols = 0;
    t->size = 0;
}
/* Copy src's contents into dst; both must hold the same element count. */
static void tensor_copy(Tensor *dst, const Tensor *src) {
    if (dst->size != src->size) {
        fprintf(stderr, "Error: tensor_copy size mismatch (%d vs %d)\n", dst->size, src->size);
        exit(1);
    }
    if (src->size <= 0) {
        return; /* nothing to move for empty tensors */
    }
    memcpy(dst->data, src->data, (size_t)src->size * sizeof(float));
}
/* Broadcast a scalar into every element of the tensor. */
static void tensor_fill(Tensor *t, float val) {
    if (t->size <= 0) {
        return;
    }
    float *p = t->data;
    for (int remaining = t->size; remaining > 0; remaining--) {
        *p++ = val;
    }
}
/* Reset all elements to 0.0f (all-bits-zero is float zero). */
static void tensor_zero(Tensor *t) {
    const bool has_payload = (t->data != NULL) && (t->size > 0);
    if (has_payload) {
        memset(t->data, 0, (size_t)t->size * sizeof(float));
    }
}
/* Random initialization */
static float randn(void) {
/* Box-Muller transform */
float u1 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f);
float u2 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f);
return sqrtf(-2.0f * logf(u1)) * cosf(2.0f * 3.14159265f * u2);
}
/* Uniform sample in [lo, hi] (both ends reachable: rand() spans 0..RAND_MAX). */
static float rand_uniform(float lo, float hi) {
    float t = (float)rand() / (float)RAND_MAX;
    return lo + t * (hi - lo);
}
/* Fill t with i.i.d. Gaussian samples scaled by `scale`. */
static void tensor_randn(Tensor *t, float scale) {
    for (int idx = 0; idx < t->size; idx++) {
        float z = randn();
        t->data[idx] = z * scale;
    }
}
/* Fill t with i.i.d. uniform samples drawn from [lo, hi]. */
static void tensor_rand_uniform(Tensor *t, float lo, float hi) {
    for (int idx = 0; idx < t->size; idx++) {
        t->data[idx] = rand_uniform(lo, hi);
    }
}
/* ============================================================
* Activation Functions
* ============================================================ */
static inline float sigmoid_f(float x) {
return 1.0f / (1.0f + expf(-x));
}
/* SiLU / swish activation: x * sigmoid(x). */
static inline float silu_f(float x) {
    return sigmoid_f(x) * x;
}
/* Clamp x into the closed interval [lo, hi]. */
static inline float clamp_f(float x, float lo, float hi) {
    return (x < lo) ? lo : ((x > hi) ? hi : x);
}
/* Element-wise logistic over a length-n buffer (out may alias in). */
static void sigmoid_vec(float *out, const float *in, int n) {
    for (int idx = 0; idx < n; idx++) {
        out[idx] = sigmoid_f(in[idx]);
    }
}
/* Element-wise expf with the argument clamped to [-10, 10] for stability. */
static void exp_vec(float *out, const float *in, int n) {
    for (int idx = 0; idx < n; idx++) {
        const float clamped = clamp_f(in[idx], -10.0f, 10.0f);
        out[idx] = expf(clamped);
    }
}
/* Numerically-stable softmax: subtract the running max before
 * exponentiating, then normalize (EPSILON guards a zero sum). */
static void softmax_vec(float *out, const float *in, int n) {
    float hi = in[0];
    for (int idx = 1; idx < n; idx++) {
        hi = (in[idx] > hi) ? in[idx] : hi;
    }
    float total = 0.0f;
    for (int idx = 0; idx < n; idx++) {
        const float e = expf(in[idx] - hi);
        out[idx] = e;
        total += e;
    }
    const float scale = 1.0f / (total + EPSILON);
    for (int idx = 0; idx < n; idx++) {
        out[idx] *= scale;
    }
}
/* ============================================================
* Vector/Matrix Operations
* ============================================================ */
/* Element-wise sum: out[i] = a[i] + b[i]; buffers may alias. */
static void vec_add(float *out, const float *a, const float *b, int n) {
    for (int idx = 0; idx < n; idx++) {
        const float sum = a[idx] + b[idx];
        out[idx] = sum;
    }
}
/* Element-wise (Hadamard) product: out[i] = a[i] * b[i]; buffers may alias. */
static void vec_mul(float *out, const float *a, const float *b, int n) {
    for (int idx = 0; idx < n; idx++) {
        const float prod = a[idx] * b[idx];
        out[idx] = prod;
    }
}
/* Row-vector / matrix product: out = x @ W.
 * x has W->rows entries; W->cols results are written into out. */
static void matvec(float *out, const float *x, const Tensor *W) {
    const int n_in = W->rows;
    const int n_out = W->cols;
    const float *w = W->data;
    for (int j = 0; j < n_out; j++) {
        float acc = 0.0f;
        for (int i = 0; i < n_in; i++) {
            acc += x[i] * w[i * n_out + j];
        }
        out[j] = acc;
    }
}
/* Matrix product out = X @ W with X: (seq_len, in_dim), W: (in_dim, out_dim).
 * Aborts on a shape mismatch. */
static void matmul(Tensor *out, const Tensor *X, const Tensor *W) {
    const int seq_len = X->rows;
    const int in_dim = X->cols;
    const int out_dim = W->cols;
    if (W->rows != in_dim) {
        fprintf(stderr, "matmul dimension mismatch: X(%d,%d) @ W(%d,%d)\n",
                X->rows, X->cols, W->rows, W->cols);
        exit(1);
    }
    for (int s = 0; s < seq_len; s++) {
        const float *row = X->data + s * in_dim;
        float *dst = out->data + s * out_dim;
        for (int j = 0; j < out_dim; j++) {
            float acc = 0.0f;
            for (int i = 0; i < in_dim; i++) {
                acc += row[i] * W->data[i * out_dim + j];
            }
            dst[j] = acc;
        }
    }
}
/* LayerNorm over one vector:
 * out = weight * (x - mean) / sqrt(var + EPSILON) + bias.
 * out may alias x (stats are gathered before the write loop). */
static void layer_norm(float *out, const float *x, const float *weight,
                       const float *bias, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i];
    }
    const float mean = sum / (float)n;
    float sq = 0.0f;
    for (int i = 0; i < n; i++) {
        const float d = x[i] - mean;
        sq += d * d;
    }
    const float var = sq / (float)n;
    const float inv_std = 1.0f / sqrtf(var + EPSILON);
    for (int i = 0; i < n; i++) {
        out[i] = weight[i] * (x[i] - mean) * inv_std + bias[i];
    }
}
/* Apply layer_norm independently to each row of a (seq_len, n_embd) tensor. */
static void layer_norm_seq(Tensor *out, const Tensor *x, const Tensor *weight,
                           const Tensor *bias) {
    const int n_embd = x->cols;
    for (int s = 0; s < x->rows; s++) {
        const float *src = x->data + s * n_embd;
        float *dst = out->data + s * n_embd;
        layer_norm(dst, src, weight->data, bias->data, n_embd);
    }
}
/* ============================================================
* ALiBi - Attention with Linear Biases
* ============================================================
* Each head h gets slope m_h = 1 / 2^(8h/H) where H = n_head.
* For token at position t attending to position s:
* bias = -m_h * |t - s|
*
* In WKV recurrence, this manifests as an additional
* geometric decay per head that compounds with the
* data-dependent decay.
* ============================================================ */
static void compute_alibi_slopes(float *slopes, int n_head) {
/* slopes[h] = 2^(-8*(h+1)/n_head) */
for (int h = 0; h < n_head; h++) {
float exponent = -8.0f * (float)(h + 1) / (float)n_head;
slopes[h] = powf(2.0f, exponent);
}
}
/* ============================================================
* Layer Parameters Structure
* ============================================================ */
/* Learnable parameters of one block (TimeMix + ChannelMix). */
typedef struct {
    /* Layer norms for the two sub-blocks */
    Tensor ln1_weight, ln1_bias;
    Tensor ln2_weight, ln2_bias;
    /* Multi-scale token shift: per-channel logits weighting the hidden
     * states from 1, 2 and 4 steps back (softmax-normalized in forward) */
    Tensor time_shift_w1, time_shift_w2, time_shift_w4;
    /* Token mixing ratios: sigmoid-gated blend of current vs. shifted input */
    Tensor time_mix_r, time_mix_k, time_mix_v;
    /* Data-dependent decay, low-rank: delta = x @ a @ b added to decay_base */
    Tensor decay_lora_a, decay_lora_b; /* (n_embd, rank), (rank, n_embd) */
    Tensor decay_base;                 /* per-channel decay logit bias */
    Tensor time_first;                 /* "bonus" logit for the current token in WKV */
    /* Receptance / key / value / output projections, each (n_embd, n_embd) */
    Tensor Wr, Wk, Wv, Wo;
    /* Channel mix */
    Tensor channel_mix;                /* per-channel gate blending x with ffn_prev */
    Tensor ffn_gate, ffn_up, ffn_down; /* SwiGLU projections */
    Tensor alibi_slopes;               /* fixed per-head ALiBi slopes, not learned,
                                        * not serialized (recomputed on load) */
    Tensor mem_gate_write; /* (n_embd, n_mem_slots) - write gate */
    Tensor mem_gate_read;  /* (n_embd, n_mem_slots) - read gate */
} LayerParams;
/* Full model: embedding, pre-norm, stacked blocks, output norm + head. */
typedef struct {
    Tensor emb;                  /* (vocab_size, n_embd) token embedding table */
    Tensor ln0_weight, ln0_bias; /* layer norm applied right after embedding */
    LayerParams *layers;         /* array of n_layers blocks (heap-owned) */
    Tensor ln_out_weight, ln_out_bias;
    Tensor head;                 /* (n_embd, vocab_size) logits projection */
    int n_layers;
} ModelParams;
/* Recurrent per-layer state carried across tokens during generation. */
typedef struct {
    /* Normalized hidden states from 1..4 steps back (multi-scale token shift) */
    Tensor x_prev_1, x_prev_2, x_prev_3, x_prev_4;
    Tensor wkv_num; /* (n_mem_slots, n_embd) WKV numerator per memory slot */
    Tensor wkv_den; /* (n_mem_slots, n_embd) WKV denominator per memory slot */
    Tensor ffn_prev; /* previous normalized input to the channel-mix block */
} LayerState;
/* Recurrent state for the whole model: one LayerState per block. */
typedef struct {
    LayerState *layers; /* heap-owned array of n_layers entries */
    int n_layers;
} ModelState;
/* Character-level vocabulary: sorted byte list plus byte -> index map. */
typedef struct {
    char chars[MAX_VOCAB_SIZE]; /* distinct bytes, sorted ascending; at most
                                 * 256 entries are ever used (array is oversized) */
    int char_to_idx[256];       /* byte -> token index; unseen bytes map to 0 */
    int size;                   /* number of distinct bytes */
} Vocabulary;
/* ============================================================
* Word-Level Tokenizer Structures
* ============================================================ */
/* Which tokenization scheme a Tokenizer uses. */
typedef enum {
    TOKENIZER_CHAR, /* one token per distinct byte */
    TOKENIZER_WORD, /* word-level vocabulary */
    TOKENIZER_AUTO  /* presumably resolved to CHAR or WORD at build time —
                     * see build_tokenizer (definition not in this chunk) */
} TokenizerType;
#define MAX_WORD_LEN 64
#define MAX_WORDS 32768
#define WORD_HASH_SIZE 65536
/* Word-level vocabulary with a hash index for lookup. */
typedef struct {
    char **words;    /* Array of word strings (heap-owned) */
    int *hash_table; /* Hash table: hash -> word index (-1 if empty) */
    int *hash_keys;  /* Hash table: stored hash for collision detection */
    int size;        /* Number of words */
    int capacity;    /* Allocated capacity of `words` */
    int unk_idx;     /* unknown-word token index */
    int pad_idx;     /* padding token index */
    int space_idx;   /* Space token index */
    int newline_idx; /* Newline token index */
} WordVocabulary;
/* Tagged union of the two tokenizer kinds; `type` selects the live vocab. */
typedef struct {
    TokenizerType type;
    Vocabulary char_vocab;     /* Used if type == TOKENIZER_CHAR */
    WordVocabulary word_vocab; /* Used if type == TOKENIZER_WORD */
} Tokenizer;
/* ============================================================
* Forward Declarations
* ============================================================ */
/* Word vocabulary functions */
static void init_word_vocabulary(WordVocabulary *wv);
static void free_word_vocabulary(WordVocabulary *wv);
static int word_vocab_find(const WordVocabulary *wv, const char *word);
static int word_vocab_add(WordVocabulary *wv, const char *word);
static void build_word_vocabulary(WordVocabulary *wv, const char *text, size_t len);
static int *tokenize_words(const char *text, size_t len, const WordVocabulary *wv, int *out_len);
static const char *decode_word_token(int token, const WordVocabulary *wv);
/* Character vocabulary functions */
static void build_vocabulary(Vocabulary *vocab, const char *text, size_t len);
static int *tokenize(const char *text, size_t len, const Vocabulary *vocab, int *out_len);
/* Tokenizer interface */
static void init_tokenizer(Tokenizer *tok, TokenizerType type);
static void free_tokenizer(Tokenizer *tok);
static void build_tokenizer(Tokenizer *tok, const char *text, size_t len, TokenizerType requested_type);
static int tokenizer_vocab_size(const Tokenizer *tok);
static int *tokenizer_encode(const Tokenizer *tok, const char *text, size_t len, int *out_len);
static void tokenizer_decode_token(const Tokenizer *tok, int token, char *out, int out_size);
/* ============================================================
* Parameter Initialization
* ============================================================ */
/* Allocate and initialize every learnable tensor of one block.
 * layer_idx staggers the per-head decay_base / time_first schedules so
 * different depths start with different time constants.
 * NOTE: the sequence of tensor_randn/tensor_rand_uniform calls fixes the
 * rand() consumption order; reordering changes seeded initializations. */
static void init_layer_params(LayerParams *lp, const lrnnConfig *cfg, int layer_idx) {
    int n_embd = cfg->n_embd;
    int ffn_h = ffn_hidden(cfg);
    int lora_rank = cfg->decay_lora_rank;
    /* Shrink projection init with depth to keep residual sums bounded */
    float proj_scale = 0.02f / sqrtf((float)cfg->n_layer);
    float ffn_scale = proj_scale;
    /* Slot-memory reservoir: read/write gate projections */
    int n_slots = cfg->n_mem_slots;
    lp->mem_gate_write = tensor_alloc(n_embd, n_slots);
    tensor_randn(&lp->mem_gate_write, 0.01f);
    lp->mem_gate_read = tensor_alloc(n_embd, n_slots);
    tensor_randn(&lp->mem_gate_read, 0.01f);
    /* Layer norms: weight 1, bias 0 (tensor_alloc zero-fills) */
    lp->ln1_weight = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->ln1_weight, 1.0f);
    lp->ln1_bias = tensor_alloc_1d(n_embd);
    lp->ln2_weight = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->ln2_weight, 1.0f);
    lp->ln2_bias = tensor_alloc_1d(n_embd);
    /* Multi-scale token shift: logits favor the 1-step-back state most */
    lp->time_shift_w1 = tensor_alloc_1d(n_embd);
    tensor_rand_uniform(&lp->time_shift_w1, 0.3f, 0.7f);
    lp->time_shift_w2 = tensor_alloc_1d(n_embd);
    tensor_rand_uniform(&lp->time_shift_w2, 0.1f, 0.3f);
    lp->time_shift_w4 = tensor_alloc_1d(n_embd);
    tensor_rand_uniform(&lp->time_shift_w4, 0.0f, 0.2f);
    /* Token mixing logits: 0.5 -> sigmoid ~0.62, slightly favoring the
     * current token over the shifted state */
    lp->time_mix_r = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->time_mix_r, 0.5f);
    lp->time_mix_k = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->time_mix_k, 0.5f);
    lp->time_mix_v = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->time_mix_v, 0.5f);
    /* Decay LoRA factors; small init keeps decay near-static at start */
    lp->decay_lora_a = tensor_alloc(n_embd, lora_rank);
    tensor_randn(&lp->decay_lora_a, 0.01f);
    lp->decay_lora_b = tensor_alloc(lora_rank, n_embd);
    tensor_randn(&lp->decay_lora_b, 0.01f);
    /* Per-head decay logit: deeper layers / later heads get smaller logits,
     * i.e. faster forgetting after the sigmoid in forward_single */
    lp->decay_base = tensor_alloc_1d(n_embd);
    {
        int hdim_val = n_embd / cfg->n_head;
        for (int h = 0; h < cfg->n_head; h++) {
            float base_val = 1.5f - 0.1f * layer_idx - 0.2f * h;
            for (int d = 0; d < hdim_val; d++) {
                lp->decay_base.data[h * hdim_val + d] = base_val;
            }
        }
    }
    /* Per-head current-token bonus, staggered across layers and heads */
    lp->time_first = tensor_alloc_1d(n_embd);
    {
        int hdim_val = n_embd / cfg->n_head;
        for (int h = 0; h < cfg->n_head; h++) {
            float tf_val = -3.0f + layer_idx * 0.3f + h * 0.5f;
            for (int d = 0; d < hdim_val; d++) {
                lp->time_first.data[h * hdim_val + d] = tf_val;
            }
        }
    }
    /* Receptance / key / value / output projections */
    lp->Wr = tensor_alloc(n_embd, n_embd);
    tensor_randn(&lp->Wr, proj_scale);
    lp->Wk = tensor_alloc(n_embd, n_embd);
    tensor_randn(&lp->Wk, proj_scale);
    lp->Wv = tensor_alloc(n_embd, n_embd);
    tensor_randn(&lp->Wv, proj_scale);
    lp->Wo = tensor_alloc(n_embd, n_embd);
    tensor_randn(&lp->Wo, proj_scale);
    /* Channel-mix gate and SwiGLU projections */
    lp->channel_mix = tensor_alloc_1d(n_embd);
    tensor_fill(&lp->channel_mix, 0.5f);
    lp->ffn_gate = tensor_alloc(n_embd, ffn_h);
    tensor_randn(&lp->ffn_gate, ffn_scale);
    lp->ffn_up = tensor_alloc(n_embd, ffn_h);
    tensor_randn(&lp->ffn_up, ffn_scale);
    lp->ffn_down = tensor_alloc(ffn_h, n_embd);
    tensor_randn(&lp->ffn_down, ffn_scale);
    /* ALiBi slopes - fixed, not learned */
    lp->alibi_slopes = tensor_alloc_1d(cfg->n_head);
    compute_alibi_slopes(lp->alibi_slopes.data, cfg->n_head);
}
/* Release every tensor owned by one layer. */
static void free_layer_params(LayerParams *lp) {
    Tensor *fields[] = {
        &lp->ln1_weight, &lp->ln1_bias,
        &lp->ln2_weight, &lp->ln2_bias,
        &lp->time_shift_w1, &lp->time_shift_w2, &lp->time_shift_w4,
        &lp->time_mix_r, &lp->time_mix_k, &lp->time_mix_v,
        &lp->decay_lora_a, &lp->decay_lora_b,
        &lp->decay_base, &lp->time_first,
        &lp->Wr, &lp->Wk, &lp->Wv, &lp->Wo,
        &lp->channel_mix,
        &lp->ffn_gate, &lp->ffn_up, &lp->ffn_down,
        &lp->alibi_slopes,
        &lp->mem_gate_write, &lp->mem_gate_read
    };
    for (size_t i = 0; i < sizeof(fields) / sizeof(fields[0]); i++) {
        tensor_free(fields[i]);
    }
}
/* Allocate and initialize all model parameters from the config.
 * Aborts if n_embd is not divisible by n_head (the per-head WKV loop
 * partitions channels evenly across heads). */
static void init_model_params(ModelParams *mp, const lrnnConfig *cfg) {
    int n_embd = cfg->n_embd;
    int vocab_size = cfg->vocab_size;
    if (n_embd % cfg->n_head != 0) {
        fprintf(stderr, "Error: n_embd (%d) must be divisible by n_head (%d)\n",
                n_embd, cfg->n_head);
        exit(1);
    }
    mp->n_layers = cfg->n_layer;
    /* Token embedding table (vocab_size, n_embd) */
    mp->emb = tensor_alloc(vocab_size, n_embd);
    tensor_randn(&mp->emb, 0.02f);
    /* Layer norm applied to embeddings before the first block */
    mp->ln0_weight = tensor_alloc_1d(n_embd);
    tensor_fill(&mp->ln0_weight, 1.0f);
    mp->ln0_bias = tensor_alloc_1d(n_embd);
    /* Stacked blocks; layer index feeds the per-layer init schedules */
    mp->layers = (LayerParams *)calloc((size_t)cfg->n_layer, sizeof(LayerParams));
    if (!mp->layers) {
        fprintf(stderr, "Error: failed to allocate layers\n");
        exit(1);
    }
    for (int i = 0; i < cfg->n_layer; i++) {
        init_layer_params(&mp->layers[i], cfg, i);
    }
    /* Final norm + logits head (n_embd, vocab_size) */
    mp->ln_out_weight = tensor_alloc_1d(n_embd);
    tensor_fill(&mp->ln_out_weight, 1.0f);
    mp->ln_out_bias = tensor_alloc_1d(n_embd);
    mp->head = tensor_alloc(n_embd, vocab_size);
    tensor_randn(&mp->head, 0.02f);
}
/* Release per-layer parameters, the layer array, then top-level tensors. */
static void free_model_params(ModelParams *mp) {
    for (int i = 0; i < mp->n_layers; i++) {
        free_layer_params(&mp->layers[i]);
    }
    free(mp->layers);
    mp->layers = NULL;
    Tensor *top[] = {
        &mp->emb, &mp->ln0_weight, &mp->ln0_bias,
        &mp->ln_out_weight, &mp->ln_out_bias, &mp->head
    };
    for (size_t i = 0; i < sizeof(top) / sizeof(top[0]); i++) {
        tensor_free(top[i]);
    }
}
/* ============================================================
* State Management
* ============================================================ */
/* Allocate zeroed recurrent state for one layer. */
static void init_layer_state(LayerState *ls, int n_embd, int n_mem_slots) {
    Tensor *shift[] = { &ls->x_prev_1, &ls->x_prev_2, &ls->x_prev_3, &ls->x_prev_4 };
    for (size_t i = 0; i < sizeof(shift) / sizeof(shift[0]); i++) {
        *shift[i] = tensor_alloc_1d(n_embd);
    }
    ls->wkv_num = tensor_alloc(n_mem_slots, n_embd);
    ls->wkv_den = tensor_alloc(n_mem_slots, n_embd);
    ls->ffn_prev = tensor_alloc_1d(n_embd);
}
/* Release all recurrent-state tensors of one layer. */
static void free_layer_state(LayerState *ls) {
    Tensor *fields[] = {
        &ls->x_prev_1, &ls->x_prev_2, &ls->x_prev_3, &ls->x_prev_4,
        &ls->wkv_num, &ls->wkv_den, &ls->ffn_prev
    };
    for (size_t i = 0; i < sizeof(fields) / sizeof(fields[0]); i++) {
        tensor_free(fields[i]);
    }
}
/* Allocate recurrent state for every layer; aborts on OOM. */
static void init_model_state(ModelState *ms, const lrnnConfig *cfg) {
    ms->n_layers = cfg->n_layer;
    ms->layers = (LayerState *)calloc((size_t)cfg->n_layer, sizeof(LayerState));
    if (ms->layers == NULL) {
        fprintf(stderr, "Error: failed to allocate state\n");
        exit(1);
    }
    for (int idx = 0; idx < ms->n_layers; idx++) {
        init_layer_state(&ms->layers[idx], cfg->n_embd, cfg->n_mem_slots);
    }
}
/* Tear down all per-layer state, then the layer array itself. */
static void free_model_state(ModelState *ms) {
    for (int idx = ms->n_layers; idx-- > 0; ) {
        free_layer_state(&ms->layers[idx]);
    }
    free(ms->layers);
    ms->layers = NULL;
}
/* Channels per head; init_model_params guarantees exact divisibility. */
static inline int head_dim(const lrnnConfig *cfg) {
    const int per_head = cfg->n_embd / cfg->n_head;
    return per_head;
}
/* ============================================================
* Forward Pass - Single Token (for Generation)
* ============================================================ */
/* Run the full network over ONE input token, updating the recurrent
 * `state` in place and writing next-token scores into `logits`
 * (length cfg->vocab_size). Used during autoregressive generation.
 * NOTE(review): the malloc calls below are unchecked; on OOM this
 * dereferences NULL. Consider routing through a checked allocator. */
static void forward_single(float *logits, int token, const ModelParams *mp,
                           ModelState *state, const lrnnConfig *cfg) {
    int n_embd = cfg->n_embd;
    int n_head = cfg->n_head;
    int hdim = n_embd / n_head; /* channels per head (divisibility checked at init) */
    int ffn_h = ffn_hidden(cfg);
    int lora_rank = cfg->decay_lora_rank;
    /* Scratch buffers for one token's pass through all layers */
    float *x = (float *)malloc((size_t)n_embd * sizeof(float));
    float *x_norm = (float *)malloc((size_t)n_embd * sizeof(float));
    float *x_shifted= (float *)malloc((size_t)n_embd * sizeof(float));
    float *xr = (float *)malloc((size_t)n_embd * sizeof(float));
    float *xk = (float *)malloc((size_t)n_embd * sizeof(float));
    float *xv = (float *)malloc((size_t)n_embd * sizeof(float));
    float *r = (float *)malloc((size_t)n_embd * sizeof(float));
    float *k = (float *)malloc((size_t)n_embd * sizeof(float));
    float *v = (float *)malloc((size_t)n_embd * sizeof(float));
    float *decay_delta = (float *)malloc((size_t)n_embd * sizeof(float));
    float *decay = (float *)malloc((size_t)n_embd * sizeof(float));
    float *k_exp = (float *)malloc((size_t)n_embd * sizeof(float));
    float *time_first_val = (float *)malloc((size_t)n_embd * sizeof(float));
    float *wkv = (float *)malloc((size_t)n_embd * sizeof(float));
    float *tm_out = (float *)malloc((size_t)n_embd * sizeof(float));
    float *xm = (float *)malloc((size_t)n_embd * sizeof(float));
    float *gate = (float *)malloc((size_t)ffn_h * sizeof(float));
    float *up = (float *)malloc((size_t)ffn_h * sizeof(float));
    float *hidden = (float *)malloc((size_t)ffn_h * sizeof(float));
    float *cm_out = (float *)malloc((size_t)n_embd * sizeof(float));
    float *lora_tmp = (float *)malloc((size_t)lora_rank * sizeof(float));
    float *w1_sig = (float *)malloc((size_t)n_embd * sizeof(float));
    float *w2_sig = (float *)malloc((size_t)n_embd * sizeof(float));
    float *w4_sig = (float *)malloc((size_t)n_embd * sizeof(float));
    /* Token embedding lookup: row `token` of the embedding table */
    memcpy(x, mp->emb.data + token * n_embd, (size_t)n_embd * sizeof(float));
    /* Initial layer norm */
    layer_norm(x, x, mp->ln0_weight.data, mp->ln0_bias.data, n_embd);
    for (int layer_idx = 0; layer_idx < mp->n_layers; layer_idx++) {
        const LayerParams *lp = &mp->layers[layer_idx];
        LayerState *ls = &state->layers[layer_idx];
        /* ============ TimeMix ============ */
        layer_norm(x_norm, x, lp->ln1_weight.data, lp->ln1_bias.data, n_embd);
        /* Multi-scale shift: per-channel normalized blend of the states
         * from 1, 2 and 4 tokens back */
        sigmoid_vec(w1_sig, lp->time_shift_w1.data, n_embd);
        sigmoid_vec(w2_sig, lp->time_shift_w2.data, n_embd);
        sigmoid_vec(w4_sig, lp->time_shift_w4.data, n_embd);
        for (int i = 0; i < n_embd; i++) {
            float w_sum = w1_sig[i] + w2_sig[i] + w4_sig[i] + EPSILON;
            float nw1 = w1_sig[i] / w_sum;
            float nw2 = w2_sig[i] / w_sum;
            float nw4 = w4_sig[i] / w_sum;
            x_shifted[i] = nw1 * ls->x_prev_1.data[i] +
                           nw2 * ls->x_prev_2.data[i] +
                           nw4 * ls->x_prev_4.data[i];
        }
        /* Mix current with shifted (independent gates for r, k, v paths) */
        for (int i = 0; i < n_embd; i++) {
            float mr = sigmoid_f(lp->time_mix_r.data[i]);
            float mk = sigmoid_f(lp->time_mix_k.data[i]);
            float mv = sigmoid_f(lp->time_mix_v.data[i]);
            xr[i] = x_norm[i] * mr + x_shifted[i] * (1.0f - mr);
            xk[i] = x_norm[i] * mk + x_shifted[i] * (1.0f - mk);
            xv[i] = x_norm[i] * mv + x_shifted[i] * (1.0f - mv);
        }
        /* R, K, V projections */
        matvec(r, xr, &lp->Wr);
        matvec(k, xk, &lp->Wk);
        matvec(v, xv, &lp->Wv);
        /* Data-dependent decay: low-rank delta added to the base logit,
         * squashed to (0,1) per channel */
        matvec(lora_tmp, x_norm, &lp->decay_lora_a);
        matvec(decay_delta, lora_tmp, &lp->decay_lora_b);
        for (int i = 0; i < n_embd; i++) {
            decay[i] = sigmoid_f(lp->decay_base.data[i] + decay_delta[i]);
        }
        /* Receptance gate (in-place sigmoid) */
        sigmoid_vec(r, r, n_embd);
        /* Current-token bonus weight, exp of clamped logit */
        for (int i = 0; i < n_embd; i++) {
            time_first_val[i] = expf(clamp_f(lp->time_first.data[i], -10.0f, 10.0f));
        }
        exp_vec(k_exp, k, n_embd);
        /* ---- Multi-Head Multi-Slot WKV with ALiBi ---- */
        {
            int n_slots = cfg->n_mem_slots;
            /* Per-token slot routing: softmax gates over slots for both
             * reading and writing, derived from the normalized input */
            float *write_logits = (float *)malloc((size_t)n_slots * sizeof(float));
            float *write_gates = (float *)malloc((size_t)n_slots * sizeof(float));
            float *read_logits = (float *)malloc((size_t)n_slots * sizeof(float));
            float *read_gates = (float *)malloc((size_t)n_slots * sizeof(float));
            matvec(write_logits, x_norm, &lp->mem_gate_write);
            softmax_vec(write_gates, write_logits, n_slots);
            matvec(read_logits, x_norm, &lp->mem_gate_read);
            softmax_vec(read_gates, read_logits, n_slots);
            for (int h = 0; h < n_head; h++) {
                int base = h * hdim;
                /* Per-head ALiBi bias folded into the recurrence as an
                 * extra geometric decay factor exp(-slope_h) */
                float alibi_decay_h = expf(-lp->alibi_slopes.data[h]);
                for (int d = 0; d < hdim; d++) {
                    int i = base + d;
                    float kv = k_exp[i] * v[i];
                    /* Read: gate-weighted sum of slot states */
                    float read_num = 0.0f, read_den = 0.0f;
                    for (int s = 0; s < n_slots; s++) {
                        int si = s * n_embd + i;
                        read_num += read_gates[s] * ls->wkv_num.data[si];
                        read_den += read_gates[s] * ls->wkv_den.data[si];
                    }
                    /* WKV output with time_first boost for the current token */
                    float num = read_num + time_first_val[i] * kv;
                    float den = read_den + time_first_val[i] * k_exp[i] + EPSILON;
                    wkv[i] = num / den;
                    /* Write: decay every slot, then deposit the current
                     * key/value weighted by that slot's write gate */
                    for (int s = 0; s < n_slots; s++) {
                        int si = s * n_embd + i;
                        float wg = write_gates[s];
                        float combined = decay[i] * alibi_decay_h;
                        ls->wkv_num.data[si] = combined * ls->wkv_num.data[si] + wg * kv;
                        ls->wkv_den.data[si] = combined * ls->wkv_den.data[si] + wg * k_exp[i];
                    }
                }
            }
            free(write_logits); free(write_gates);
            free(read_logits); free(read_gates);
        }
        /* Apply receptance gate, project, and add the residual */
        vec_mul(wkv, r, wkv, n_embd);
        matvec(tm_out, wkv, &lp->Wo);
        vec_add(x, x, tm_out, n_embd);
        /* Shift the token-history ring: 3->4, 2->3, 1->2, current->1 */
        tensor_copy(&ls->x_prev_4, &ls->x_prev_3);
        tensor_copy(&ls->x_prev_3, &ls->x_prev_2);
        tensor_copy(&ls->x_prev_2, &ls->x_prev_1);
        memcpy(ls->x_prev_1.data, x_norm, (size_t)n_embd * sizeof(float));
        /* ============ ChannelMix ============ */
        layer_norm(x_norm, x, lp->ln2_weight.data, lp->ln2_bias.data, n_embd);
        /* Blend current normalized input with the previous one per channel */
        for (int i = 0; i < n_embd; i++) {
            float mix = sigmoid_f(lp->channel_mix.data[i]);
            xm[i] = x_norm[i] * mix + ls->ffn_prev.data[i] * (1.0f - mix);
        }
        /* SwiGLU: silu(gate) * up, then project back down */
        matvec(gate, xm, &lp->ffn_gate);
        matvec(up, xm, &lp->ffn_up);
        for (int i = 0; i < ffn_h; i++) {
            hidden[i] = silu_f(gate[i]) * up[i];
        }
        matvec(cm_out, hidden, &lp->ffn_down);
        vec_add(x, x, cm_out, n_embd);
        memcpy(ls->ffn_prev.data, x_norm, (size_t)n_embd * sizeof(float));
    }
    /* Output layer norm and logits projection */
    layer_norm(x, x, mp->ln_out_weight.data, mp->ln_out_bias.data, n_embd);
    matvec(logits, x, &mp->head);
    /* Cleanup */
    free(x); free(x_norm); free(x_shifted);
    free(xr); free(xk); free(xv);
    free(r); free(k); free(v);
    free(decay_delta); free(decay); free(k_exp);
    free(time_first_val); free(wkv); free(tm_out);
    free(xm); free(gate); free(up); free(hidden); free(cm_out);
    free(lora_tmp); free(w1_sig); free(w2_sig); free(w4_sig);
}
/* ============================================================
* Loss Computation
* ============================================================ */
/* Mean token-level cross-entropy over n positions.
 * logits: (n, vocab_size) rows of unnormalized scores; targets: n indices.
 * Fixes: returns 0 for n <= 0 (previously divided by zero) and checks
 * the scratch allocation (previously dereferenced NULL on OOM). */
static float cross_entropy_loss(const Tensor *logits, const int *targets, int n) {
    if (n <= 0) {
        return 0.0f;
    }
    int vocab_size = logits->cols;
    float *probs = (float *)malloc((size_t)vocab_size * sizeof(float));
    if (!probs) {
        fprintf(stderr, "Error: cross_entropy_loss allocation failed\n");
        exit(1);
    }
    float total_loss = 0.0f;
    for (int t = 0; t < n; t++) {
        softmax_vec(probs, logits->data + t * vocab_size, vocab_size);
        float p = probs[targets[t]];
        if (p < EPSILON) p = EPSILON; /* avoid logf(0) */
        total_loss -= logf(p);
    }
    free(probs);
    return total_loss / (float)n;
}
/* ============================================================
* File I/O
* ============================================================ */
/* Serialize one tensor: int rows, int cols, then rows*cols floats.
 * Fix: the data write is skipped for empty tensors — their data pointer
 * is NULL, and fwrite with a NULL source is undefined even for count 0. */
static void write_tensor(FILE *f, const Tensor *t) {
    fwrite(&t->rows, sizeof(int), 1, f);
    fwrite(&t->cols, sizeof(int), 1, f);
    if (t->size > 0) {
        fwrite(t->data, sizeof(float), (size_t)t->size, f);
    }
}
/* Deserialize a tensor written by write_tensor.
 * Fixes: on a short/failed header read, *t is now left as a valid empty
 * tensor (previously it stayed uninitialized, so a later tensor_free on
 * it was undefined behavior); negative dimensions are rejected. */
static void read_tensor(FILE *f, Tensor *t) {
    t->data = NULL;
    t->rows = t->cols = t->size = 0;
    int rows, cols;
    if (fread(&rows, sizeof(int), 1, f) != 1) return;
    if (fread(&cols, sizeof(int), 1, f) != 1) return;
    if (rows < 0 || cols < 0) {
        fprintf(stderr, "Warning: corrupt tensor header (%d x %d)\n", rows, cols);
        return;
    }
    *t = tensor_alloc(rows, cols);
    if (t->size > 0 &&
        fread(t->data, sizeof(float), (size_t)t->size, f) != (size_t)t->size) {
        fprintf(stderr, "Warning: incomplete tensor read\n");
    }
}
static void write_layer_params(FILE *f, const LayerParams *lp) {
write_tensor(f, &lp->ln1_weight);
write_tensor(f, &lp->ln1_bias);
write_tensor(f, &lp->ln2_weight);
write_tensor(f, &lp->ln2_bias);
write_tensor(f, &lp->time_shift_w1);
write_tensor(f, &lp->time_shift_w2);
write_tensor(f, &lp->time_shift_w4);
write_tensor(f, &lp->time_mix_r);
write_tensor(f, &lp->time_mix_k);
write_tensor(f, &lp->time_mix_v);
write_tensor(f, &lp->decay_lora_a);
write_tensor(f, &lp->decay_lora_b);
write_tensor(f, &lp->decay_base);
write_tensor(f, &lp->time_first);
write_tensor(f, &lp->Wr);
write_tensor(f, &lp->Wk);
write_tensor(f, &lp->Wv);
write_tensor(f, &lp->Wo);
write_tensor(f, &lp->channel_mix);
write_tensor(f, &lp->ffn_gate);
write_tensor(f, &lp->ffn_up);
write_tensor(f, &lp->ffn_down);
write_tensor(f, &lp->mem_gate_write);
write_tensor(f, &lp->mem_gate_read);
}
/* Deserialize a layer's tensors; the table order must mirror
 * write_layer_params exactly. alibi_slopes is not stored — load_model
 * recomputes it from the config. */
static void read_layer_params(FILE *f, LayerParams *lp) {
    Tensor *fields[] = {
        &lp->ln1_weight, &lp->ln1_bias,
        &lp->ln2_weight, &lp->ln2_bias,
        &lp->time_shift_w1, &lp->time_shift_w2, &lp->time_shift_w4,
        &lp->time_mix_r, &lp->time_mix_k, &lp->time_mix_v,
        &lp->decay_lora_a, &lp->decay_lora_b,
        &lp->decay_base, &lp->time_first,
        &lp->Wr, &lp->Wk, &lp->Wv, &lp->Wo,
        &lp->channel_mix,
        &lp->ffn_gate, &lp->ffn_up, &lp->ffn_down,
        &lp->mem_gate_write, &lp->mem_gate_read
    };
    for (size_t i = 0; i < sizeof(fields) / sizeof(fields[0]); i++) {
        read_tensor(f, fields[i]);
    }
}
/* ============================================================
* File I/O (Updated for Hybrid Tokenizer)
* ============================================================ */
/* Serialize config, tokenizer and weights to `path`.
 * Returns 0 on success, -1 on any I/O error.
 * Fix: write errors are now detected via ferror() and the fclose()
 * return value — previously a full disk still reported success. */
static int save_model(const char *path, const ModelParams *mp,
                      const lrnnConfig *cfg, const Tokenizer *tok) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        fprintf(stderr, "Error: cannot open %s for writing\n", path);
        return -1;
    }
    /* Magic and version: "lrnnC02" plus its NUL terminator = 8 bytes */
    const char magic[] = "lrnnC02"; /* Version 2 for tokenizer support */
    fwrite(magic, 1, 8, f);
    fwrite(cfg, sizeof(lrnnConfig), 1, f);
    /* Tokenizer type */
    int tok_type = (int)tok->type;
    fwrite(&tok_type, sizeof(int), 1, f);
    /* Save vocabulary based on type */
    if (tok->type == TOKENIZER_CHAR) {
        fwrite(&tok->char_vocab.size, sizeof(int), 1, f);
        fwrite(tok->char_vocab.chars, sizeof(char), (size_t)tok->char_vocab.size, f);
        fwrite(tok->char_vocab.char_to_idx, sizeof(int), 256, f);
    } else {
        /* Word vocabulary: header ints, then length-prefixed words */
        fwrite(&tok->word_vocab.size, sizeof(int), 1, f);
        fwrite(&tok->word_vocab.unk_idx, sizeof(int), 1, f);
        fwrite(&tok->word_vocab.pad_idx, sizeof(int), 1, f);
        fwrite(&tok->word_vocab.space_idx, sizeof(int), 1, f);
        fwrite(&tok->word_vocab.newline_idx, sizeof(int), 1, f);
        for (int i = 0; i < tok->word_vocab.size; i++) {
            int len = (int)strlen(tok->word_vocab.words[i]);
            fwrite(&len, sizeof(int), 1, f);
            fwrite(tok->word_vocab.words[i], sizeof(char), (size_t)len, f);
        }
    }
    /* Model params */
    write_tensor(f, &mp->emb);
    write_tensor(f, &mp->ln0_weight);
    write_tensor(f, &mp->ln0_bias);
    fwrite(&mp->n_layers, sizeof(int), 1, f);
    for (int i = 0; i < mp->n_layers; i++) {
        write_layer_params(f, &mp->layers[i]);
    }
    write_tensor(f, &mp->ln_out_weight);
    write_tensor(f, &mp->ln_out_bias);
    write_tensor(f, &mp->head);
    /* Surface any buffered write failure before declaring success;
     * fclose also flushes and can fail on its own. */
    int had_error = ferror(f);
    if (fclose(f) != 0 || had_error) {
        fprintf(stderr, "Error: write to %s failed\n", path);
        return -1;
    }
    return 0;
}
/* Read exactly `bytes` bytes into dst; 0 on success, -1 on short read. */
static int read_exact(FILE *f, void *dst, size_t bytes) {
    return fread(dst, 1, bytes, f) == bytes ? 0 : -1;
}
/* Load a model written by save_model (v1 "lrnnC01" or v2 "lrnnC02").
 * Returns 0 on success, -1 on error (partially-loaded output on failure).
 * Fixes over the original:
 *  - a negative stored word length is rejected instead of becoming a huge
 *    fread size after the (size_t) cast;
 *  - words longer than MAX_WORD_LEN-1 are truncated AND their tail bytes
 *    are skipped, so the stream no longer desynchronizes;
 *  - the layer-array calloc and the n_layers bounds are checked. */
static int load_model(const char *path, ModelParams *mp,
                      lrnnConfig *cfg, Tokenizer *tok) {
    FILE *f = fopen(path, "rb");
    if (!f) {
        fprintf(stderr, "Error: cannot open %s for reading\n", path);
        return -1;
    }
    char magic[8];
    if (read_exact(f, magic, 8) != 0) goto fail;
    bool is_v1 = (strncmp(magic, "lrnnC01", 7) == 0);
    bool is_v2 = (strncmp(magic, "lrnnC02", 7) == 0);
    if (!is_v1 && !is_v2) {
        fprintf(stderr, "Error: invalid model file format\n");
        goto fail;
    }
    if (read_exact(f, cfg, sizeof(lrnnConfig)) != 0) goto fail;
    memset(tok, 0, sizeof(Tokenizer));
    if (is_v1) {
        /* v1 files store a character vocabulary with no type tag */
        tok->type = TOKENIZER_CHAR;
    } else {
        int tok_type;
        if (read_exact(f, &tok_type, sizeof(int)) != 0) goto fail;
        tok->type = (TokenizerType)tok_type;
    }
    if (tok->type == TOKENIZER_CHAR) {
        if (read_exact(f, &tok->char_vocab.size, sizeof(int)) != 0) goto fail;
        if (tok->char_vocab.size < 0 || tok->char_vocab.size > 256) goto fail;
        if (read_exact(f, tok->char_vocab.chars,
                       (size_t)tok->char_vocab.size) != 0) goto fail;
        if (read_exact(f, tok->char_vocab.char_to_idx, 256 * sizeof(int)) != 0) goto fail;
    } else {
        /* Word vocabulary: header ints, then length-prefixed words */
        init_word_vocabulary(&tok->word_vocab);
        int size;
        if (read_exact(f, &size, sizeof(int)) != 0) goto fail;
        if (size < 0) goto fail;
        if (read_exact(f, &tok->word_vocab.unk_idx, sizeof(int)) != 0) goto fail;
        if (read_exact(f, &tok->word_vocab.pad_idx, sizeof(int)) != 0) goto fail;
        if (read_exact(f, &tok->word_vocab.space_idx, sizeof(int)) != 0) goto fail;
        if (read_exact(f, &tok->word_vocab.newline_idx, sizeof(int)) != 0) goto fail;
        for (int i = 0; i < size; i++) {
            int len;
            if (read_exact(f, &len, sizeof(int)) != 0) goto fail;
            if (len < 0) goto fail; /* corrupt length prefix */
            char word[MAX_WORD_LEN];
            int keep = (len < MAX_WORD_LEN) ? len : MAX_WORD_LEN - 1;
            if (read_exact(f, word, (size_t)keep) != 0) goto fail;
            word[keep] = '\0';
            /* Skip any truncated tail so the stream stays aligned */
            if (len > keep && fseek(f, (long)(len - keep), SEEK_CUR) != 0) goto fail;
            word_vocab_add(&tok->word_vocab, word);
        }
    }
    /* Model params */
    read_tensor(f, &mp->emb);
    read_tensor(f, &mp->ln0_weight);
    read_tensor(f, &mp->ln0_bias);
    if (read_exact(f, &mp->n_layers, sizeof(int)) != 0) goto fail;
    if (mp->n_layers < 0 || mp->n_layers > MAX_LAYERS) goto fail;
    mp->layers = (LayerParams *)calloc((size_t)mp->n_layers, sizeof(LayerParams));
    if (!mp->layers && mp->n_layers > 0) goto fail;
    for (int i = 0; i < mp->n_layers; i++) {
        read_layer_params(f, &mp->layers[i]);
    }
    /* Recompute ALiBi slopes (deterministic from config, never stored) */
    for (int i = 0; i < mp->n_layers; i++) {
        mp->layers[i].alibi_slopes = tensor_alloc_1d(cfg->n_head);
        compute_alibi_slopes(mp->layers[i].alibi_slopes.data, cfg->n_head);
    }
    read_tensor(f, &mp->ln_out_weight);
    read_tensor(f, &mp->ln_out_bias);
    read_tensor(f, &mp->head);
    fclose(f);
    return 0;
fail:
    fclose(f);
    return -1;
}
/* ============================================================
* Vocabulary Building
* ============================================================ */
static void build_vocabulary(Vocabulary *vocab, const char *text, size_t len) {
    /* Build a sorted byte-level vocabulary from the corpus.
     * Walking byte values 0..255 in order yields the sorted vocabulary
     * directly, so no separate sort pass is required. Unseen bytes map
     * to index 0 in char_to_idx. */
    bool present[256] = {false};
    for (size_t i = 0; i < len; i++) {
        present[(unsigned char)text[i]] = true;
    }
    for (int c = 0; c < 256; c++) {
        vocab->char_to_idx[c] = 0;
    }
    vocab->size = 0;
    for (int c = 0; c < 256; c++) {
        if (present[c]) {
            vocab->chars[vocab->size] = (char)c;
            vocab->char_to_idx[c] = vocab->size;
            vocab->size++;
        }
    }
}
/* ============================================================
* Word-Level Vocabulary Implementation
* ============================================================ */
static unsigned int word_hash(const char *word) {
    /* djb2 hash, XOR variant: h = h * 33 ^ c (33 == (h<<5)+h). */
    unsigned int h = 5381u;
    for (const unsigned char *p = (const unsigned char *)word; *p != '\0'; p++) {
        h = (h * 33u) ^ (unsigned int)*p;
    }
    return h;
}
static void init_word_vocabulary(WordVocabulary *wv) {
    /* Allocate an empty word vocabulary with an open-addressing hash index.
     * Aborts on allocation failure, matching the file's convention
     * (see init_layer_cache); the original left the pointers unchecked. */
    wv->capacity = MAX_WORDS;
    wv->words = (char **)calloc((size_t)wv->capacity, sizeof(char *));
    wv->hash_table = (int *)malloc(WORD_HASH_SIZE * sizeof(int));
    wv->hash_keys = (int *)malloc(WORD_HASH_SIZE * sizeof(int));
    if (!wv->words || !wv->hash_table || !wv->hash_keys) {
        fprintf(stderr, "Error: out of memory in init_word_vocabulary\n");
        exit(1);
    }
    /* -1 marks an empty hash slot / key. */
    for (int i = 0; i < WORD_HASH_SIZE; i++) {
        wv->hash_table[i] = -1;
        wv->hash_keys[i] = -1;
    }
    wv->size = 0;
    /* Special-token indices are assigned later by build_word_vocabulary(). */
    wv->unk_idx = -1;
    wv->pad_idx = -1;
    wv->space_idx = -1;
    wv->newline_idx = -1;
}
static void free_word_vocabulary(WordVocabulary *wv) {
    /* Release all heap memory owned by the word vocabulary.
     * free(NULL) is a no-op, so the hash arrays need no guards;
     * the words array still needs one because the loop reads it. */
    if (wv->words) {
        for (int i = 0; i < wv->size; i++) {
            free(wv->words[i]);
        }
        free(wv->words);
        wv->words = NULL;
    }
    free(wv->hash_table);
    wv->hash_table = NULL;
    free(wv->hash_keys);
    wv->hash_keys = NULL;
    wv->size = 0;
}
static int word_vocab_find(const WordVocabulary *wv, const char *word) {
    /* Open-addressing lookup with linear probing, capped at 1000 probes.
     * Returns the word's index, or -1 when absent. */
    const unsigned int h = word_hash(word);
    unsigned int slot = h % WORD_HASH_SIZE;
    for (int probes = 0; probes < 1000; probes++) {
        const int stored = wv->hash_table[slot];
        if (stored < 0) {
            return -1; /* empty slot terminates the probe chain */
        }
        if (wv->hash_keys[slot] == (int)h &&
            strcmp(wv->words[stored], word) == 0) {
            return stored;
        }
        slot = (slot + 1) % WORD_HASH_SIZE;
    }
    return -1; /* probe limit exhausted */
}
static int word_vocab_add(WordVocabulary *wv, const char *word) {
    /* Return the index of `word`, inserting it first if new.
     * On capacity or allocation problems, fall back to the unk index
     * instead of failing. The original stored an unchecked strdup()
     * result, which could put NULL into wv->words and later crash
     * word_vocab_find's strcmp. */
    int existing = word_vocab_find(wv, word);
    if (existing >= 0) return existing;
    if (wv->size >= wv->capacity - 1) {
        fprintf(stderr, "Warning: word vocabulary full\n");
        return wv->unk_idx;
    }
    char *copy = strdup(word);
    if (!copy) {
        fprintf(stderr, "Warning: out of memory adding word to vocabulary\n");
        return wv->unk_idx;
    }
    int word_idx = wv->size;
    wv->words[word_idx] = copy;
    wv->size++;
    /* Insert into the hash index; probe cap mirrors word_vocab_find so
     * every stored word remains reachable within the same probe budget. */
    unsigned int hash = word_hash(word);
    unsigned int idx = hash % WORD_HASH_SIZE;
    for (int probe = 0; probe < 1000; probe++) {
        unsigned int slot = (idx + probe) % WORD_HASH_SIZE;
        if (wv->hash_table[slot] < 0) {
            wv->hash_table[slot] = word_idx;
            wv->hash_keys[slot] = (int)hash;
            break;
        }
    }
    return word_idx;
}
static inline int is_word_boundary(char c) {
    /* True for whitespace and the ASCII punctuation set that terminates
     * a word token during word-level tokenization. */
    switch (c) {
    case ' ': case '\n': case '\t': case '\r':
    case '.': case ',': case '!': case '?':
    case ':': case ';': case '"': case '\'':
    case '(': case ')': case '[': case ']':
    case '{': case '}': case '-': case '/':
    case '\\': case '@': case '#': case '$':
    case '%': case '&': case '*': case '+':
    case '=': case '<': case '>': case '|':
    case '~': case '`': case '^':
        return 1;
    default:
        return 0;
    }
}
static void build_word_vocabulary(WordVocabulary *wv, const char *text, size_t len) {
    init_word_vocabulary(wv);
    /* Special tokens first. They must be DISTINCT strings: word_vocab_add
     * deduplicates, so the original (which added "" for both) silently
     * aliased pad_idx onto unk_idx. "<" is a word boundary, so these
     * sentinel spellings can never be produced by tokenizing real text. */
    wv->unk_idx = word_vocab_add(wv, "<unk>");
    wv->pad_idx = word_vocab_add(wv, "<pad>");
    wv->space_idx = word_vocab_add(wv, " ");
    wv->newline_idx = word_vocab_add(wv, "\n");
    /* Add common punctuation as separate tokens */
    word_vocab_add(wv, ".");
    word_vocab_add(wv, ",");
    word_vocab_add(wv, "!");
    word_vocab_add(wv, "?");
    word_vocab_add(wv, ":");
    word_vocab_add(wv, ";");
    word_vocab_add(wv, "\"");
    word_vocab_add(wv, "'");
    word_vocab_add(wv, "(");
    word_vocab_add(wv, ")");
    word_vocab_add(wv, "-");
    word_vocab_add(wv, "\t");
    /* Parse text and add words */
    char word[MAX_WORD_LEN];
    int word_len = 0;
    for (size_t i = 0; i < len; i++) {
        char c = text[i];
        if (is_word_boundary(c)) {
            /* End current word */
            if (word_len > 0) {
                word[word_len] = '\0';
                word_vocab_add(wv, word);
                word_len = 0;
            }
            /* Add boundary char as token (except space/tab/CR, handled above) */
            if (c != ' ' && c != '\t' && c != '\r') {
                char punct[2] = {c, '\0'};
                word_vocab_add(wv, punct);
            }
        } else {
            /* Accumulate word, truncating anything past MAX_WORD_LEN-1 */
            if (word_len < MAX_WORD_LEN - 1) {
                word[word_len++] = c;
            }
        }
    }
    /* Handle last word */
    if (word_len > 0) {
        word[word_len] = '\0';
        word_vocab_add(wv, word);
    }
}
static int *tokenize_words(const char *text, size_t len,
                           const WordVocabulary *wv, int *out_len) {
    /* Encode text into word-level token ids. Caller frees the result.
     * Fixes the original's `tokens = realloc(tokens, ...)` pattern, which
     * leaks the buffer and then dereferences NULL if realloc fails, and
     * its unchecked initial malloc. */
    int max_tokens = (int)(len / 2) + 100;
    int *tokens = (int *)malloc((size_t)max_tokens * sizeof(int));
    if (!tokens) {
        fprintf(stderr, "Error: out of memory in tokenize_words\n");
        exit(1);
    }
    int n_tokens = 0;
    char word[MAX_WORD_LEN];
    int word_len = 0;
    for (size_t i = 0; i < len; i++) {
        char c = text[i];
        if (is_word_boundary(c)) {
            /* Flush the accumulated word, mapping unknowns to <unk>. */
            if (word_len > 0) {
                word[word_len] = '\0';
                int idx = word_vocab_find(wv, word);
                tokens[n_tokens++] = (idx >= 0) ? idx : wv->unk_idx;
                word_len = 0;
            }
            /* Emit the boundary itself as a token. */
            if (c == ' ') {
                tokens[n_tokens++] = wv->space_idx;
            } else if (c == '\n') {
                tokens[n_tokens++] = wv->newline_idx;
            } else if (c == '\t') {
                tokens[n_tokens++] = wv->space_idx; /* Treat tab as space */
            } else if (c != '\r') {
                char punct[2] = {c, '\0'};
                int idx = word_vocab_find(wv, punct);
                if (idx >= 0) {
                    tokens[n_tokens++] = idx;
                }
            }
        } else {
            if (word_len < MAX_WORD_LEN - 1) {
                word[word_len++] = c;
            }
        }
    }
        /* Keep >= 10 slots of headroom; one iteration emits at most 2 tokens. */
        if (n_tokens >= max_tokens - 10) {
            max_tokens *= 2;
            int *grown = (int *)realloc(tokens, (size_t)max_tokens * sizeof(int));
            if (!grown) {
                free(tokens);
                fprintf(stderr, "Error: out of memory in tokenize_words\n");
                exit(1);
            }
            tokens = grown;
        }
    /* Handle last word */
    if (word_len > 0) {
        word[word_len] = '\0';
        int idx = word_vocab_find(wv, word);
        tokens[n_tokens++] = (idx >= 0) ? idx : wv->unk_idx;
    }
    *out_len = n_tokens;
    return tokens;
}
static const char *decode_word_token(int token, const WordVocabulary *wv) {
if (token >= 0 && token < wv->size && wv->words[token]) {
return wv->words[token];
}
return "";
}
/* ============================================================
* Unified Tokenizer Interface
* ============================================================ */
static void init_tokenizer(Tokenizer *tok, TokenizerType type) {
    /* Zero out all vocabulary state, then record the requested kind. */
    memset(tok, 0, sizeof *tok);
    tok->type = type;
}
static void free_tokenizer(Tokenizer *tok) {
    /* Only the word vocabulary owns heap memory; the character
     * vocabulary lives entirely in fixed-size arrays. */
    if (tok->type != TOKENIZER_WORD) {
        return;
    }
    free_word_vocabulary(&tok->word_vocab);
}
static void build_tokenizer(Tokenizer *tok, const char *text, size_t len,
                            TokenizerType requested_type) {
    /* Resolve AUTO to char/word based on corpus size, then build the
     * corresponding vocabulary. */
    TokenizerType chosen = requested_type;
    if (requested_type == TOKENIZER_AUTO) {
        chosen = (len < 20000) ? TOKENIZER_CHAR : TOKENIZER_WORD;
        if (chosen == TOKENIZER_CHAR) {
            printf(" Auto-selected: character tokenizer (corpus < 20KB)\n");
        } else {
            printf(" Auto-selected: word tokenizer (corpus >= 20KB)\n");
        }
    }
    tok->type = chosen;
    if (chosen == TOKENIZER_CHAR) {
        build_vocabulary(&tok->char_vocab, text, len);
    } else {
        build_word_vocabulary(&tok->word_vocab, text, len);
    }
}
static int tokenizer_vocab_size(const Tokenizer *tok) {
    /* Number of distinct tokens in whichever vocabulary is active. */
    return (tok->type == TOKENIZER_CHAR) ? tok->char_vocab.size
                                         : tok->word_vocab.size;
}
static int *tokenizer_encode(const Tokenizer *tok, const char *text, size_t len,
                             int *out_len) {
    /* Dispatch to the active tokenizer; caller frees the returned array. */
    return (tok->type == TOKENIZER_CHAR)
               ? tokenize(text, len, &tok->char_vocab, out_len)
               : tokenize_words(text, len, &tok->word_vocab, out_len);
}
static void tokenizer_decode_token(const Tokenizer *tok, int token,
                                   char *out, int out_size) {
    /* Write the text of `token` into out (always NUL-terminated when
     * out_size > 0). Unknown char tokens decode to "?", unknown word
     * tokens to "". Fixes the original's unconditional 2-byte write in
     * the char path (overflow when out_size < 2) and the out[out_size-1]
     * write (underflow when out_size <= 0); strncpy is replaced with
     * snprintf, which always terminates. */
    if (out_size <= 0) {
        return; /* no room even for the terminator */
    }
    if (tok->type == TOKENIZER_CHAR) {
        char c = (token >= 0 && token < tok->char_vocab.size)
                     ? tok->char_vocab.chars[token]
                     : '?';
        if (out_size >= 2) {
            out[0] = c;
            out[1] = '\0';
        } else {
            out[0] = '\0';
        }
    } else {
        const char *word = decode_word_token(token, &tok->word_vocab);
        snprintf(out, (size_t)out_size, "%s", word);
    }
}
/* Auto-configure model based on corpus and tokenizer */
/* Auto-configure model hyperparameters from corpus size and tokenizer. */
static lrnnConfig config_for_corpus(long corpus_bytes, TokenizerType tok_type,
                                    int vocab_size) {
    lrnnConfig cfg = default_config();
    cfg.vocab_size = vocab_size;
    /* Word tokens carry roughly 5x the information of characters, so a
     * word-tokenized corpus behaves like a smaller character corpus. */
    float efficiency = (tok_type == TOKENIZER_WORD) ? 5.0f : 1.0f;
    long effective_size = (long)(corpus_bytes / efficiency);
    /* Size tiers: pick the first whose limit exceeds the effective size. */
    struct tier {
        long limit;
        int n_layer, n_embd, ctx_len, lora_rank, n_mem_slots;
        float ffn_mult;
    };
    static const struct tier tiers[] = {
        {   5000L, 2,  64,  64,  4, 2, 1.5f },
        {  50000L, 4, 128, 128,  8, 4, 2.0f },
        { 500000L, 6, 256, 256, 16, 4, 2.5f },
        {      0L, 8, 384, 512, 32, 8, 3.0f }, /* catch-all */
    };
    const struct tier *pick = &tiers[3];
    for (int i = 0; i < 3; i++) {
        if (effective_size < tiers[i].limit) {
            pick = &tiers[i];
            break;
        }
    }
    cfg.n_layer = pick->n_layer;
    cfg.n_embd = pick->n_embd;
    cfg.ctx_len = pick->ctx_len;
    cfg.decay_lora_rank = pick->lora_rank;
    cfg.ffn_multiplier = pick->ffn_mult;
    cfg.n_mem_slots = pick->n_mem_slots;
    /* One head per 32 embedding channels, minimum 2. */
    cfg.n_head = cfg.n_embd / 32;
    if (cfg.n_head < 2) cfg.n_head = 2;
    return cfg;
}
static int *tokenize(const char *text, size_t len, const Vocabulary *vocab, int *out_len) {
    /* Encode text into character-level token ids. Caller frees the result.
     * Requests at least one element so a zero-length input never hits the
     * implementation-defined malloc(0) case; the allocation is now checked
     * (the original dereferenced an unchecked malloc). */
    int *tokens = (int *)malloc((len > 0 ? len : 1) * sizeof(int));
    if (!tokens) {
        fprintf(stderr, "Error: out of memory in tokenize\n");
        exit(1);
    }
    for (size_t i = 0; i < len; i++) {
        tokens[i] = vocab->char_to_idx[(unsigned char)text[i]];
    }
    *out_len = (int)len;
    return tokens;
}
/* ============================================================
* TRAINING SECTION - Full Backpropagation Implementation
* ============================================================
*
* This implements analytical gradients for all model parameters:
* - Embedding layer
* - Layer normalization (all instances)
* - Multi-scale token shift weights
* - Token mixing parameters (r, k, v)
* - Data-dependent decay (LoRA + base)
* - Projection matrices (Wr, Wk, Wv, Wo)
* - SwiGLU FFN (gate, up, down)
* - Output head
*/
/* ============================================================
* Gradient Structures
* ============================================================ */
typedef struct {
    /* Gradient accumulators for one layer. Each tensor mirrors the shape of
     * the corresponding field in LayerParams; buffers are zeroed per batch
     * (zero_layer_grads) and accumulated during backprop. */
    /* Layer norms */
    Tensor ln1_weight, ln1_bias;
    Tensor ln2_weight, ln2_bias;
    /* Multi-scale token shift */
    Tensor time_shift_w1, time_shift_w2, time_shift_w4;
    /* Token mixing ratios */
    Tensor time_mix_r, time_mix_k, time_mix_v;
    /* Data-dependent decay (low-rank) */
    Tensor decay_lora_a, decay_lora_b;
    Tensor decay_base;
    Tensor time_first;
    /* Projections */
    Tensor Wr, Wk, Wv, Wo;
    /* Channel mix */
    Tensor channel_mix;
    Tensor ffn_gate, ffn_up, ffn_down;
    /* Slot-memory reservoir gates */
    Tensor mem_gate_write; /* (n_embd, n_mem_slots) */
    Tensor mem_gate_read; /* (n_embd, n_mem_slots) */
} LayerGrads;
typedef struct {
    /* Whole-model gradient buffers, mirroring ModelParams field-for-field. */
    Tensor emb;                        /* token embedding table gradients */
    Tensor ln0_weight, ln0_bias;       /* input layer norm */
    LayerGrads *layers;                /* one entry per layer (n_layers) */
    Tensor ln_out_weight, ln_out_bias; /* output layer norm */
    Tensor head;                       /* output projection */
    int n_layers;
} ModelGrads;
/* ============================================================
* Forward Pass Cache (for backpropagation)
* ============================================================ */
typedef struct {
    /* Per-layer forward-pass activations cached for backpropagation.
     * All (seq, ...) tensors hold one row per timestep. */
    Tensor x_in;            /* layer input (seq, n_embd) */
    /* TimeMix forward cache */
    Tensor x_ln1; /* After first layer norm */
    Tensor x_shifted; /* Multi-scale shifted */
    Tensor shift_w1_sig; /* sigmoid(time_shift_w1) */
    Tensor shift_w2_sig; /* sigmoid(time_shift_w2) */
    Tensor shift_w4_sig; /* sigmoid(time_shift_w4) */
    Tensor shift_w_sum; /* w1 + w2 + w4 + eps */
    Tensor xr, xk, xv; /* After mixing */
    Tensor mix_r_sig; /* sigmoid(time_mix_r) */
    Tensor mix_k_sig; /* sigmoid(time_mix_k) */
    Tensor mix_v_sig; /* sigmoid(time_mix_v) */
    Tensor r_pre; /* Before sigmoid */
    Tensor k_pre; /* Before exp */
    Tensor v; /* v values */
    Tensor r; /* After sigmoid (receptance) */
    Tensor k_exp; /* After exp */
    Tensor decay_tmp; /* LoRA intermediate (seq, rank) */
    Tensor decay_delta; /* LoRA output */
    Tensor decay_pre; /* Before sigmoid */
    Tensor decay; /* After sigmoid */
    Tensor time_first_exp; /* exp(time_first) */
    Tensor *num_states; /* (seq+1) tensors, each (n_mem_slots, n_embd) */
    Tensor *den_states; /* (seq+1) tensors, each (n_mem_slots, n_embd) */
    Tensor *write_gates; /* (seq) tensors, each (n_mem_slots,) */
    Tensor *read_gates; /* (seq) tensors, each (n_mem_slots,) */
    Tensor wkv; /* WKV output */
    Tensor wkv_r; /* wkv * r */
    Tensor tm_out; /* After Wo projection */
    Tensor x_after_tm; /* x + tm_out (residual) */
    /* ChannelMix forward cache */
    Tensor x_ln2; /* After second layer norm */
    Tensor xm; /* After channel mixing */
    Tensor cm_mix_sig; /* sigmoid(channel_mix) */
    Tensor gate_pre; /* Before silu */
    Tensor up_val; /* up values */
    Tensor gate_silu; /* After silu */
    Tensor hidden; /* gate * up */
    Tensor cm_out; /* After down projection */
} LayerCache;
typedef struct {
    /* Full-network forward-pass cache for one training sequence; owns one
     * LayerCache per layer plus the pre/post activations around the stack. */
    int seq_len;
    int n_layers;
    Tensor emb_out; /* Token embeddings (seq, n_embd) */
    Tensor x_ln0; /* After initial layer norm */
    LayerCache *layers; /* Per-layer cache */
    Tensor x_final; /* Final hidden states before output ln */
    Tensor x_ln_out; /* After final layer norm */
    Tensor logits; /* Final logits (seq, vocab) */
} ForwardCache;
/* ============================================================
* Gradient Allocation and Deallocation
* ============================================================ */
static void init_layer_grads(LayerGrads *lg, const lrnnConfig *cfg, const LayerParams *lp) {
    /* Allocate one gradient tensor per layer parameter, shape-matched. */
    const int E = cfg->n_embd;
    const int H = ffn_hidden(cfg);
    const int R = cfg->decay_lora_rank;
    const int S = cfg->n_mem_slots;
    /* Layer norms mirror the parameter shapes exactly. */
    lg->ln1_weight = tensor_alloc(lp->ln1_weight.rows, lp->ln1_weight.cols);
    lg->ln1_bias = tensor_alloc(lp->ln1_bias.rows, lp->ln1_bias.cols);
    lg->ln2_weight = tensor_alloc(lp->ln2_weight.rows, lp->ln2_weight.cols);
    lg->ln2_bias = tensor_alloc(lp->ln2_bias.rows, lp->ln2_bias.cols);
    /* Multi-scale token shift and mixing ratios (per-channel vectors). */
    lg->time_shift_w1 = tensor_alloc_1d(E);
    lg->time_shift_w2 = tensor_alloc_1d(E);
    lg->time_shift_w4 = tensor_alloc_1d(E);
    lg->time_mix_r = tensor_alloc_1d(E);
    lg->time_mix_k = tensor_alloc_1d(E);
    lg->time_mix_v = tensor_alloc_1d(E);
    /* Data-dependent decay (low-rank) plus per-channel bias terms. */
    lg->decay_lora_a = tensor_alloc(E, R);
    lg->decay_lora_b = tensor_alloc(R, E);
    lg->decay_base = tensor_alloc_1d(E);
    lg->time_first = tensor_alloc_1d(E);
    /* Square projections. */
    lg->Wr = tensor_alloc(E, E);
    lg->Wk = tensor_alloc(E, E);
    lg->Wv = tensor_alloc(E, E);
    lg->Wo = tensor_alloc(E, E);
    /* SwiGLU channel mix. */
    lg->channel_mix = tensor_alloc_1d(E);
    lg->ffn_gate = tensor_alloc(E, H);
    lg->ffn_up = tensor_alloc(E, H);
    lg->ffn_down = tensor_alloc(H, E);
    /* Slot-memory gates. */
    lg->mem_gate_write = tensor_alloc(E, S);
    lg->mem_gate_read = tensor_alloc(E, S);
}
static void zero_layer_grads(LayerGrads *lg) {
    /* Zero every gradient tensor in the layer with one table-driven sweep. */
    Tensor *all[] = {
        &lg->ln1_weight, &lg->ln1_bias, &lg->ln2_weight, &lg->ln2_bias,
        &lg->time_shift_w1, &lg->time_shift_w2, &lg->time_shift_w4,
        &lg->time_mix_r, &lg->time_mix_k, &lg->time_mix_v,
        &lg->decay_lora_a, &lg->decay_lora_b, &lg->decay_base, &lg->time_first,
        &lg->Wr, &lg->Wk, &lg->Wv, &lg->Wo,
        &lg->channel_mix, &lg->ffn_gate, &lg->ffn_up, &lg->ffn_down,
        &lg->mem_gate_write, &lg->mem_gate_read
    };
    for (size_t i = 0; i < sizeof all / sizeof all[0]; i++) {
        tensor_zero(all[i]);
    }
}
static void free_layer_grads(LayerGrads *lg) {
    /* Release every gradient tensor in the layer (table-driven). */
    Tensor *all[] = {
        &lg->ln1_weight, &lg->ln1_bias, &lg->ln2_weight, &lg->ln2_bias,
        &lg->time_shift_w1, &lg->time_shift_w2, &lg->time_shift_w4,
        &lg->time_mix_r, &lg->time_mix_k, &lg->time_mix_v,
        &lg->decay_lora_a, &lg->decay_lora_b, &lg->decay_base, &lg->time_first,
        &lg->Wr, &lg->Wk, &lg->Wv, &lg->Wo,
        &lg->channel_mix, &lg->ffn_gate, &lg->ffn_up, &lg->ffn_down,
        &lg->mem_gate_write, &lg->mem_gate_read
    };
    for (size_t i = 0; i < sizeof all / sizeof all[0]; i++) {
        tensor_free(all[i]);
    }
}
static void init_model_grads(ModelGrads *mg, const ModelParams *mp, const lrnnConfig *cfg) {
    /* Allocate gradient buffers for the whole model, shape-matched to mp.
     * The per-layer array allocation is now checked, consistent with the
     * check-and-exit pattern used in init_layer_cache. */
    int n_embd = cfg->n_embd;
    int vocab_size = cfg->vocab_size;
    mg->n_layers = cfg->n_layer;
    mg->emb = tensor_alloc(vocab_size, n_embd);
    mg->ln0_weight = tensor_alloc_1d(n_embd);
    mg->ln0_bias = tensor_alloc_1d(n_embd);
    mg->layers = (LayerGrads *)calloc((size_t)cfg->n_layer, sizeof(LayerGrads));
    if (!mg->layers) {
        fprintf(stderr, "Failed to allocate layer gradient array\n");
        exit(1);
    }
    for (int i = 0; i < cfg->n_layer; i++) {
        init_layer_grads(&mg->layers[i], cfg, &mp->layers[i]);
    }
    mg->ln_out_weight = tensor_alloc_1d(n_embd);
    mg->ln_out_bias = tensor_alloc_1d(n_embd);
    mg->head = tensor_alloc(n_embd, vocab_size);
}
static void zero_model_grads(ModelGrads *mg) {
    /* Reset every accumulated gradient before starting a new batch. */
    tensor_zero(&mg->emb);
    tensor_zero(&mg->ln0_weight);
    tensor_zero(&mg->ln0_bias);
    tensor_zero(&mg->ln_out_weight);
    tensor_zero(&mg->ln_out_bias);
    tensor_zero(&mg->head);
    for (int l = 0; l < mg->n_layers; l++) {
        zero_layer_grads(&mg->layers[l]);
    }
}
static void free_model_grads(ModelGrads *mg) {
    /* Release all gradient buffers; the layers pointer is nulled so a
     * second call or later inspection cannot use freed memory. */
    for (int l = 0; l < mg->n_layers; l++) {
        free_layer_grads(&mg->layers[l]);
    }
    free(mg->layers);
    mg->layers = NULL;
    tensor_free(&mg->emb);
    tensor_free(&mg->ln0_weight);
    tensor_free(&mg->ln0_bias);
    tensor_free(&mg->ln_out_weight);
    tensor_free(&mg->ln_out_bias);
    tensor_free(&mg->head);
}
/* ============================================================
* Forward Cache Allocation and Deallocation
* ============================================================ */
/* Allocate the full forward-pass cache for one layer and one sequence of
 * length seq_len. Aborts (exit(1)) on invalid length or failed allocation of
 * the state arrays; individual tensor_alloc calls are assumed to handle
 * their own failures — TODO confirm tensor_alloc's OOM behavior. */
static void init_layer_cache(LayerCache *lc, int seq_len, const lrnnConfig *cfg) {
    if (seq_len <= 0) {
        fprintf(stderr, "FATAL: init_layer_cache called with seq_len=%d\n", seq_len);
        exit(1);
    }
    int n_embd = cfg->n_embd;
    int ffn_h = ffn_hidden(cfg);
    int lora_rank = cfg->decay_lora_rank;
    int n_slots = cfg->n_mem_slots;
    /* --- All regular cache tensors --- */
    lc->x_in = tensor_alloc(seq_len, n_embd);
    lc->x_ln1 = tensor_alloc(seq_len, n_embd);
    lc->x_shifted = tensor_alloc(seq_len, n_embd);
    lc->shift_w1_sig = tensor_alloc_1d(n_embd);
    lc->shift_w2_sig = tensor_alloc_1d(n_embd);
    lc->shift_w4_sig = tensor_alloc_1d(n_embd);
    lc->shift_w_sum = tensor_alloc_1d(n_embd);
    lc->xr = tensor_alloc(seq_len, n_embd);
    lc->xk = tensor_alloc(seq_len, n_embd);
    lc->xv = tensor_alloc(seq_len, n_embd);
    lc->mix_r_sig = tensor_alloc_1d(n_embd);
    lc->mix_k_sig = tensor_alloc_1d(n_embd);
    lc->mix_v_sig = tensor_alloc_1d(n_embd);
    lc->r_pre = tensor_alloc(seq_len, n_embd);
    lc->k_pre = tensor_alloc(seq_len, n_embd);
    lc->v = tensor_alloc(seq_len, n_embd);
    lc->r = tensor_alloc(seq_len, n_embd);
    lc->k_exp = tensor_alloc(seq_len, n_embd);
    lc->decay_tmp = tensor_alloc(seq_len, lora_rank);
    lc->decay_delta = tensor_alloc(seq_len, n_embd);
    lc->decay_pre = tensor_alloc(seq_len, n_embd);
    lc->decay = tensor_alloc(seq_len, n_embd);
    lc->time_first_exp = tensor_alloc_1d(n_embd);
    lc->wkv = tensor_alloc(seq_len, n_embd);
    lc->wkv_r = tensor_alloc(seq_len, n_embd);
    lc->tm_out = tensor_alloc(seq_len, n_embd);
    lc->x_after_tm = tensor_alloc(seq_len, n_embd);
    lc->x_ln2 = tensor_alloc(seq_len, n_embd);
    lc->xm = tensor_alloc(seq_len, n_embd);
    lc->cm_mix_sig = tensor_alloc_1d(n_embd);
    lc->gate_pre = tensor_alloc(seq_len, ffn_h);
    lc->up_val = tensor_alloc(seq_len, ffn_h);
    lc->gate_silu = tensor_alloc(seq_len, ffn_h);
    lc->hidden = tensor_alloc(seq_len, ffn_h);
    lc->cm_out = tensor_alloc(seq_len, n_embd);
    /* --- Multi-slot WKV states: seq_len + 1 entries so state[t] is the
     * state *before* consuming token t, and state[seq_len] the final one. --- */
    lc->num_states = (Tensor *)calloc((size_t)(seq_len + 1), sizeof(Tensor));
    lc->den_states = (Tensor *)calloc((size_t)(seq_len + 1), sizeof(Tensor));
    if (!lc->num_states || !lc->den_states) {
        fprintf(stderr, "Failed to allocate WKV state arrays\n");
        exit(1);
    }
    for (int t = 0; t <= seq_len; t++) {
        lc->num_states[t] = tensor_alloc(n_slots, n_embd);
        lc->den_states[t] = tensor_alloc(n_slots, n_embd);
    }
    /* --- Memory gates: one (n_slots,) vector per timestep --- */
    lc->write_gates = (Tensor *)calloc((size_t)seq_len, sizeof(Tensor));
    lc->read_gates = (Tensor *)calloc((size_t)seq_len, sizeof(Tensor));
    if (!lc->write_gates || !lc->read_gates) {
        fprintf(stderr, "Failed to allocate gate arrays\n");
        exit(1);
    }
    for (int t = 0; t < seq_len; t++) {
        lc->write_gates[t] = tensor_alloc_1d(n_slots);
        lc->read_gates[t] = tensor_alloc_1d(n_slots);
    }
}
static void free_layer_cache(LayerCache *lc, int seq_len) {
    /* Release everything init_layer_cache allocated, in mirrored order.
     * The original interleaved the state-array frees mid-list and carried
     * two commented-out duplicate free() lines; both are cleaned up here. */
    /* Plain cache tensors (same set as init_layer_cache). */
    Tensor *plain[] = {
        &lc->x_in, &lc->x_ln1, &lc->x_shifted,
        &lc->shift_w1_sig, &lc->shift_w2_sig, &lc->shift_w4_sig, &lc->shift_w_sum,
        &lc->xr, &lc->xk, &lc->xv,
        &lc->mix_r_sig, &lc->mix_k_sig, &lc->mix_v_sig,
        &lc->r_pre, &lc->k_pre, &lc->v, &lc->r, &lc->k_exp,
        &lc->decay_tmp, &lc->decay_delta, &lc->decay_pre, &lc->decay,
        &lc->time_first_exp,
        &lc->wkv, &lc->wkv_r, &lc->tm_out, &lc->x_after_tm,
        &lc->x_ln2, &lc->xm, &lc->cm_mix_sig,
        &lc->gate_pre, &lc->up_val, &lc->gate_silu, &lc->hidden, &lc->cm_out
    };
    for (size_t i = 0; i < sizeof plain / sizeof plain[0]; i++) {
        tensor_free(plain[i]);
    }
    /* WKV slot states: seq_len + 1 entries each. */
    for (int t = 0; t <= seq_len; t++) {
        tensor_free(&lc->num_states[t]);
        tensor_free(&lc->den_states[t]);
    }
    free(lc->num_states);
    free(lc->den_states);
    lc->num_states = NULL;
    lc->den_states = NULL;
    /* Memory gates: seq_len entries each. */
    for (int t = 0; t < seq_len; t++) {
        tensor_free(&lc->write_gates[t]);
        tensor_free(&lc->read_gates[t]);
    }
    free(lc->write_gates);
    free(lc->read_gates);
    lc->write_gates = NULL;
    lc->read_gates = NULL;
}
static void init_forward_cache(ForwardCache *fc, int seq_len, const lrnnConfig *cfg) {
    /* Allocate the whole-network forward cache for one sequence. The
     * per-layer array allocation is now checked, consistent with the
     * check-and-exit pattern already used inside init_layer_cache. */
    int n_embd = cfg->n_embd;
    int vocab_size = cfg->vocab_size;
    fc->seq_len = seq_len;
    fc->n_layers = cfg->n_layer;
    fc->emb_out = tensor_alloc(seq_len, n_embd);
    fc->x_ln0 = tensor_alloc(seq_len, n_embd);
    fc->layers = (LayerCache *)calloc((size_t)cfg->n_layer, sizeof(LayerCache));
    if (!fc->layers) {
        fprintf(stderr, "Failed to allocate layer cache array\n");
        exit(1);
    }
    for (int i = 0; i < cfg->n_layer; i++) {
        init_layer_cache(&fc->layers[i], seq_len, cfg);
    }
    fc->x_final = tensor_alloc(seq_len, n_embd);
    fc->x_ln_out = tensor_alloc(seq_len, n_embd);
    fc->logits = tensor_alloc(seq_len, vocab_size);
}
static void free_forward_cache(ForwardCache *fc) {
    /* Release the per-layer caches first, then the top-level tensors. */
    for (int l = 0; l < fc->n_layers; l++) {
        free_layer_cache(&fc->layers[l], fc->seq_len);
    }
    free(fc->layers);
    tensor_free(&fc->emb_out);
    tensor_free(&fc->x_ln0);
    tensor_free(&fc->x_final);
    tensor_free(&fc->x_ln_out);
    tensor_free(&fc->logits);
}
/* ============================================================
* Backward Primitive Operations
* ============================================================ */
/* Sigmoid backward: d_input = d_output * sigmoid(x) * (1 - sigmoid(x))
* Given y = sigmoid(x), d_input = d_output * y * (1 - y) */
static void sigmoid_backward(float *d_input, const float *d_output,
                             const float *y, int n) {
    /* y already holds sigmoid(x), so dsigma/dx = y * (1 - y). */
    for (int i = 0; i < n; i++) {
        const float s = y[i];
        d_input[i] = d_output[i] * s * (1.0f - s);
    }
}
/* SiLU backward: y = x * sigmoid(x)
* dy/dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
* = sigmoid(x) * (1 + x * (1 - sigmoid(x))) */
static void silu_backward(float *d_input, const float *d_output,
                          const float *x, int n) {
    /* SiLU: y = x * sigmoid(x); dy/dx = s * (1 + x * (1 - s)), s = sigmoid(x). */
    for (int i = 0; i < n; i++) {
        const float s = sigmoid_f(x[i]);
        const float local = s * (1.0f + x[i] * (1.0f - s));
        d_input[i] = d_output[i] * local;
    }
}
/* Exp backward: d_input = d_output * exp(x) = d_output * y
* But we clamp, so we need the clamped version */
static void exp_backward_clamped(float *d_input, const float *d_output,
                                 const float *x, int n) {
    /* Backward of y = exp(clamp(x, -10, 10)): zero gradient where the
     * clamp saturates, exp(x) elsewhere (inside the window the clamp is
     * the identity, so no clamping call is needed). */
    for (int i = 0; i < n; i++) {
        const float xi = x[i];
        if (xi < -10.0f || xi > 10.0f) {
            d_input[i] = 0.0f;
        } else {
            d_input[i] = d_output[i] * expf(xi);
        }
    }
}
/* Matrix multiply backward: Y = X @ W
* dL/dX = dL/dY @ W^T
* dL/dW = X^T @ dL/dY */
static void matmul_backward_x(Tensor *d_X, const Tensor *d_Y, const Tensor *W) {
    /* Backward of Y = X @ W w.r.t. X: d_X = d_Y @ W^T (overwrites d_X).
     * Shapes: d_Y (seq, out), W (in, out), d_X (seq, in). */
    const int seq = d_Y->rows;
    const int out = d_Y->cols;
    const int in = W->rows;
    for (int s = 0; s < seq; s++) {
        const float *dy_row = d_Y->data + s * out;
        float *dx_row = d_X->data + s * in;
        for (int i = 0; i < in; i++) {
            const float *w_row = W->data + i * out;
            float acc = 0.0f;
            for (int j = 0; j < out; j++) {
                acc += dy_row[j] * w_row[j];
            }
            dx_row[i] = acc;
        }
    }
}
static void matmul_backward_w(Tensor *d_W, const Tensor *d_Y, const Tensor *X) {
    /* Backward of Y = X @ W w.r.t. W: d_W += X^T @ d_Y (accumulates).
     * Shapes: X (seq, in), d_Y (seq, out), d_W (in, out). The inner sum
     * runs over the sequence in the same order as before so float
     * accumulation is bit-identical. */
    const int seq = X->rows;
    const int in = X->cols;
    const int out = d_Y->cols;
    for (int i = 0; i < in; i++) {
        for (int j = 0; j < out; j++) {
            float acc = 0.0f;
            for (int s = 0; s < seq; s++) {
                acc += X->data[s * in + i] * d_Y->data[s * out + j];
            }
            d_W->data[i * out + j] += acc;
        }
    }
}
/* Layer normalization backward
* y = weight * (x - mean) / sqrt(var + eps) + bias
* This is a bit complex due to the mean and variance dependencies */
/* Backward pass of LayerNorm for a single row of n channels:
 *   y_i = weight_i * (x_i - mean) / sqrt(var + EPSILON) + bias_i
 * Recomputes mean/var from x (they are not cached), accumulates (+=) into
 * d_weight and d_bias, and overwrites d_x with the standard closed-form
 * input gradient:
 *   dx_i = (1/(n*std)) * (n*g_i - sum(g) - x_hat_i * sum(g*x_hat)),
 * where g_i = d_y_i * weight_i and x_hat is the normalized input. */
static void layer_norm_backward_single(float *d_x, float *d_weight, float *d_bias,
                                       const float *d_y, const float *x,
                                       const float *weight, int n) {
    /* Forward stats (recomputed; must match the forward pass exactly) */
    float mean = 0.0f;
    for (int i = 0; i < n; i++) mean += x[i];
    mean /= (float)n;
    float var = 0.0f;
    for (int i = 0; i < n; i++) {
        float a = x[i] - mean;
        var += a * a;
    }
    var /= (float)n;
    float inv_std = 1.0f / sqrtf(var + EPSILON);
    /* Accumulate d_gamma, d_beta and the two reduction terms needed for dx */
    float sum_dy_gamma = 0.0f;
    float sum_dy_gamma_xhat = 0.0f;
    for (int i = 0; i < n; i++) {
        float x_hat = (x[i] - mean) * inv_std;
        float dy = d_y[i];
        d_weight[i] += dy * x_hat;  /* accumulates across rows/timesteps */
        d_bias[i] += dy;
        float dy_gamma = dy * weight[i];
        sum_dy_gamma += dy_gamma;
        sum_dy_gamma_xhat += dy_gamma * x_hat;
    }
    /* dx formula (mean and variance paths folded into the two sums) */
    float inv_n = 1.0f / (float)n;
    for (int i = 0; i < n; i++) {
        float x_hat = (x[i] - mean) * inv_std;
        float dy_gamma = d_y[i] * weight[i];
        d_x[i] = inv_n * inv_std *
            ((float)n * dy_gamma - sum_dy_gamma - x_hat * sum_dy_gamma_xhat);
    }
}
static void layer_norm_backward_seq(Tensor *d_x, Tensor *d_weight, Tensor *d_bias,
                                    const Tensor *d_y, const Tensor *x,
                                    const Tensor *weight) {
    /* Apply the single-row LayerNorm backward to every timestep; the
     * weight/bias gradients accumulate across the whole sequence. */
    const int T = x->rows;
    const int C = x->cols;
    for (int t = 0; t < T; t++) {
        const int off = t * C;
        layer_norm_backward_single(d_x->data + off,
                                   d_weight->data,
                                   d_bias->data,
                                   d_y->data + off,
                                   x->data + off,
                                   weight->data,
                                   C);
    }
}
/* Softmax cross-entropy backward
* For softmax with cross-entropy, the gradient is simply: probs - one_hot(target)
* This is one of the nice properties of this combination */
static void softmax_cross_entropy_backward(Tensor *d_logits, const Tensor *logits,
                                           const int *targets) {
    /* Gradient of label-smoothed softmax cross-entropy, averaged over
     * timesteps: d_logit = softmax(logit) - q, where q places
     * (1 - smoothing) on the target and spreads `smoothing` uniformly
     * over the vocabulary. The scratch allocation is now checked (the
     * original dereferenced an unchecked malloc). */
    const int seq_len = logits->rows;
    const int vocab_size = logits->cols;
    const float smoothing = 0.1f; /* label-smoothing mass */
    const float invV = 1.0f / (float)vocab_size;
    float *probs = (float *)malloc((size_t)vocab_size * sizeof(float));
    if (!probs) {
        fprintf(stderr, "Error: out of memory in softmax_cross_entropy_backward\n");
        exit(1);
    }
    for (int t = 0; t < seq_len; t++) {
        softmax_vec(probs, logits->data + t * vocab_size, vocab_size);
        int target = targets[t];
        for (int v = 0; v < vocab_size; v++) {
            float q = (v == target)
                ? (1.0f - smoothing + smoothing * invV) /* mostly target… */
                : (smoothing * invV);                   /* …small mass elsewhere */
            d_logits->data[t * vocab_size + v] =
                (probs[v] - q) / (float)seq_len;
        }
    }
    free(probs);
}
/* ============================================================
* WKV Backward Pass
* ============================================================
*
* Forward recurrence:
* wkv_t = (num_t + tf * k_t * v_t) / (den_t + tf * k_t + eps)
* num_{t+1} = decay_t * num_t + k_t * v_t
* den_{t+1} = decay_t * den_t + k_t
*
* Where:
* k_t = exp(k_pre_t)
* tf = exp(time_first)
* decay_t = sigmoid(decay_base + decay_delta_t)
*
* We need backward pass through this recurrence.
*/
static void wkv_backward(
Tensor *d_k_exp, Tensor *d_v, Tensor *d_decay, Tensor *d_time_first_exp,
Tensor *d_mem_gate_write_proj,
Tensor *d_mem_gate_read_proj,
const Tensor *d_wkv, const Tensor *k_exp, const Tensor *v,
const Tensor *decay, const Tensor *time_first_exp,
Tensor *num_states, Tensor *den_states,
Tensor *write_gates, Tensor *read_gates,
int seq_len, int n_embd,
int n_head, int n_slots,
const float *alibi_slopes
) {
int hdim = n_embd / n_head;
tensor_zero(d_k_exp);
tensor_zero(d_v);
tensor_zero(d_decay);
tensor_zero(d_time_first_exp);
/* Gradient accumulators for slot states */
/* d_num_next[s * n_embd + i], d_den_next[s * n_embd + i] */
float *d_num_next = (float *)calloc((size_t)(n_slots * n_embd), sizeof(float));
float *d_den_next = (float *)calloc((size_t)(n_slots * n_embd), sizeof(float));
/* Gradient accumulators for gate logits (pre-softmax) */
/* We'll accumulate d_write_gates and d_read_gates, then backprop
through softmax outside this function */
float *d_write_g = (float *)calloc((size_t)n_slots, sizeof(float));
float *d_read_g = (float *)calloc((size_t)n_slots, sizeof(float));
float *alibi_decay_arr = (float *)malloc((size_t)n_head * sizeof(float));
for (int h = 0; h < n_head; h++) {
alibi_decay_arr[h] = expf(-alibi_slopes[h]);
}
for (int t = seq_len - 1; t >= 0; t--) {
memset(d_write_g, 0, (size_t)n_slots * sizeof(float));
memset(d_read_g, 0, (size_t)n_slots * sizeof(float));
for (int h = 0; h < n_head; h++) {
int base_h = h * hdim;
float ad = alibi_decay_arr[h];
for (int d_idx = 0; d_idx < hdim; d_idx++) {
int i = base_h + d_idx;
int idx = t * n_embd + i;
float ki = k_exp->data[idx];
float vi = v->data[idx];
float di = decay->data[idx];
float tfi = time_first_exp->data[i];
float kv = ki * vi;
float combined = di * ad;
/* Recompute read values */
float read_num = 0.0f, read_den = 0.0f;
for (int s = 0; s < n_slots; s++) {
int si = s * n_embd + i;
read_num += read_gates[t].data[s] * num_states[t].data[si];
read_den += read_gates[t].data[s] * den_states[t].data[si];
}
float numerator = read_num + tfi * kv;
float denominator = read_den + tfi * ki + EPSILON;
float inv_den = 1.0f / denominator;
float dw = d_wkv->data[idx];
float d_numerator = dw * inv_den;
float d_denominator = -dw * numerator * inv_den * inv_den;
/* Gradients for read */
float d_read_num = d_numerator;
float d_read_den = d_denominator;
for (int s = 0; s < n_slots; s++) {
int si = s * n_embd + i;
d_read_g[s] += d_read_num * num_states[t].data[si];
d_read_g[s] += d_read_den * den_states[t].data[si];
/* Gradient to states from read */
float d_state_num = d_read_num * read_gates[t].data[s];
float d_state_den = d_read_den * read_gates[t].data[s];
/* Add gradient from future via state update */
d_state_num += d_num_next[si] * combined;
d_state_den += d_den_next[si] * combined;
/* Gradient for decay from state update */
d_decay->data[idx] += (d_num_next[si] * num_states[t].data[si]
+ d_den_next[si] * den_states[t].data[si]) * ad;
/* Gradient for write gate */
d_write_g[s] += d_num_next[si] * kv + d_den_next[si] * ki;
/* Gradient for k, v from write */
float wg = write_gates[t].data[s];
d_k_exp->data[idx] += d_num_next[si] * wg * vi
+ d_den_next[si] * wg;
d_v->data[idx] += d_num_next[si] * wg * ki;
/* Propagate to previous timestep */
d_num_next[si] = d_state_num;
d_den_next[si] = d_state_den;
}
/* Gradients from wkv output for tf, k, v */
d_time_first_exp->data[i] += d_numerator * kv + d_denominator * ki;
d_k_exp->data[idx] += d_numerator * tfi * vi + d_denominator * tfi;
d_v->data[idx] += d_numerator * tfi * ki;
}
}
/* Backprop through softmax for write/read gates at timestep t */
/* d_logits = softmax_backward(d_gates, gates) */
/* For softmax: d_logit_i = sum_j (d_gate_j * gate_j * (delta_ij - gate_i)) */
/* We need to accumulate into d_mem_gate_write_proj / d_mem_gate_read_proj */
{
float *wg = write_gates[t].data;
float *rg = read_gates[t].data;
float d_wl[n_slots], d_rl[n_slots]; /* VLA ok, n_slots is small */
/* Softmax backward for write gates */
float dot_w = 0.0f;
for (int s = 0; s < n_slots; s++) dot_w += d_write_g[s] * wg[s];
for (int s = 0; s < n_slots; s++) {
d_wl[s] = wg[s] * (d_write_g[s] - dot_w);
}
/* Softmax backward for read gates */
float dot_r = 0.0f;
for (int s = 0; s < n_slots; s++) dot_r += d_read_g[s] * rg[s];
for (int s = 0; s < n_slots; s++) {
d_rl[s] = rg[s] * (d_read_g[s] - dot_r);
}
/* These are gradients w.r.t. the gate logits = x_ln1 @ mem_gate_{write,read}
* We accumulate d_x_ln1 contribution and d_W contribution outside */
/* Store in the gradient tensors for later matmul backward */
for (int s = 0; s < n_slots; s++) {
d_mem_gate_write_proj->data[t * n_slots + s] = d_wl[s];
d_mem_gate_read_proj->data[t * n_slots + s] = d_rl[s];
}
}
}
/* Clamp */
for (int i = 0; i < d_k_exp->size; i++) {
d_k_exp->data[i] = clamp_f(d_k_exp->data[i], -GRAD_CLIP, GRAD_CLIP);
d_v->data[i] = clamp_f(d_v->data[i], -GRAD_CLIP, GRAD_CLIP);
}
for (int i = 0; i < d_decay->size; i++) {
d_decay->data[i] = clamp_f(d_decay->data[i], -GRAD_CLIP, GRAD_CLIP);
}
for (int i = 0; i < n_embd; i++) {
d_time_first_exp->data[i] = clamp_f(d_time_first_exp->data[i],
-GRAD_CLIP, GRAD_CLIP);
}
free(d_num_next);
free(d_den_next);
free(d_write_g);
free(d_read_g);
free(alibi_decay_arr);
}
static void multi_scale_shift_backward(
    Tensor *d_x,                /* out: gradient w.r.t. the input sequence x */
    Tensor *d_shift_w1,         /* out: gradient w.r.t. time_shift_w1 (pre-sigmoid) */
    Tensor *d_shift_w2,         /* out: gradient w.r.t. time_shift_w2 (pre-sigmoid) */
    Tensor *d_shift_w4,         /* out: gradient w.r.t. time_shift_w4 (pre-sigmoid) */
    const Tensor *d_out,        /* in: upstream gradient */
    const Tensor *x,            /* in: original input sequence */
    const Tensor *shift_w1_sig, /* cached sigmoid(time_shift_w1) */
    const Tensor *shift_w2_sig, /* cached sigmoid(time_shift_w2) */
    const Tensor *shift_w4_sig, /* cached sigmoid(time_shift_w4) */
    const Tensor *shift_w_sum,  /* cached w1 + w2 + w4 + eps */
    int seq_len,
    int n_embd
) {
    /* Forward pass was, per timestep t and channel c:
     *   out = (w1*x[t-1] + w2*x[t-2] + w4*x[t-4]) / (w1 + w2 + w4 + eps)
     * with each w = sigmoid(parameter) and out-of-range history read as 0.
     * Undo that chain: quotient rule for the weights, then the sigmoid. */
    tensor_zero(d_x);
    tensor_zero(d_shift_w1);
    tensor_zero(d_shift_w2);
    tensor_zero(d_shift_w4);
    for (int t = 0; t < seq_len; t++) {
        for (int c = 0; c < n_embd; c++) {
            const float g1 = shift_w1_sig->data[c];
            const float g2 = shift_w2_sig->data[c];
            const float g4 = shift_w4_sig->data[c];
            const float w_sum = shift_w_sum->data[c];
            const float inv_sum = 1.0f / w_sum;
            /* Lagged inputs; positions before the sequence start are zero. */
            float h1 = 0.0f, h2 = 0.0f, h4 = 0.0f;
            if (t >= 1) h1 = x->data[(t - 1) * n_embd + c];
            if (t >= 2) h2 = x->data[(t - 2) * n_embd + c];
            if (t >= 4) h4 = x->data[(t - 4) * n_embd + c];
            const float go = d_out->data[t * n_embd + c];
            /* d(out)/d(x[t-k]) = w_k / sum */
            const float gh1 = go * g1 * inv_sum;
            const float gh2 = go * g2 * inv_sum;
            const float gh4 = go * g4 * inv_sum;
            if (t >= 1) d_x->data[(t - 1) * n_embd + c] += gh1;
            if (t >= 2) d_x->data[(t - 2) * n_embd + c] += gh2;
            if (t >= 4) d_x->data[(t - 4) * n_embd + c] += gh4;
            /* Quotient rule through the normalization:
             *   d(out)/d(w_k) = x_k/sum - numerator/sum^2 */
            const float numerator = g1 * h1 + g2 * h2 + g4 * h4;
            const float dg1 = go * (h1 * inv_sum - numerator * inv_sum * inv_sum);
            const float dg2 = go * (h2 * inv_sum - numerator * inv_sum * inv_sum);
            const float dg4 = go * (h4 * inv_sum - numerator * inv_sum * inv_sum);
            /* Chain through sigmoid: d(sig)/d(pre) = sig * (1 - sig). */
            d_shift_w1->data[c] += dg1 * g1 * (1.0f - g1);
            d_shift_w2->data[c] += dg2 * g2 * (1.0f - g2);
            d_shift_w4->data[c] += dg4 * g4 * (1.0f - g4);
        }
    }
}
/* ============================================================
* Token Mixing Backward
* ============================================================ */
static void token_mixing_backward(
    Tensor *d_x_ln1,        /* out: gradient w.r.t. layer norm output */
    Tensor *d_x_shifted,    /* out: gradient w.r.t. shifted input */
    Tensor *d_mix_r,        /* out: gradient w.r.t. time_mix_r (pre-sigmoid) */
    Tensor *d_mix_k,
    Tensor *d_mix_v,
    const Tensor *d_xr,     /* in: upstream gradient for xr */
    const Tensor *d_xk,
    const Tensor *d_xv,
    const Tensor *x_ln1,
    const Tensor *x_shifted,
    const Tensor *mix_r_sig, /* cached sigmoid(time_mix_*) */
    const Tensor *mix_k_sig,
    const Tensor *mix_v_sig,
    int seq_len,
    int n_embd
) {
    tensor_zero(d_x_ln1);
    tensor_zero(d_x_shifted);
    tensor_zero(d_mix_r);
    tensor_zero(d_mix_k);
    tensor_zero(d_mix_v);
    /* All three branches share the same lerp form:
     *   x? = x_ln1 * m + x_shifted * (1 - m),   m = sigmoid(time_mix_?)
     * so handle r/k/v uniformly via small lookup tables; iterating them in
     * r, k, v order keeps the accumulation order of the original code. */
    const Tensor *mix_sig[3] = { mix_r_sig, mix_k_sig, mix_v_sig };
    const Tensor *upstream[3] = { d_xr, d_xk, d_xv };
    Tensor *d_mix[3] = { d_mix_r, d_mix_k, d_mix_v };
    for (int t = 0; t < seq_len; t++) {
        for (int i = 0; i < n_embd; i++) {
            const int idx = t * n_embd + i;
            const float a = x_ln1->data[idx];
            const float b = x_shifted->data[idx];
            for (int c = 0; c < 3; c++) {
                const float m = mix_sig[c]->data[i];
                const float g = upstream[c]->data[idx];
                /* d/d(x_ln1) = m, d/d(x_shifted) = 1 - m */
                d_x_ln1->data[idx] += g * m;
                d_x_shifted->data[idx] += g * (1.0f - m);
                /* d/d(m) = x_ln1 - x_shifted, then through the sigmoid. */
                const float dm = g * (a - b);
                d_mix[c]->data[i] += dm * m * (1.0f - m);
            }
        }
    }
}
/* ============================================================
* Channel Mixing Backward
* ============================================================ */
static void channel_mix_shift_backward(
    Tensor *d_x_ln2,          /* out: gradient w.r.t. layer norm output */
    Tensor *d_channel_mix,    /* out: gradient w.r.t. channel_mix (pre-sigmoid) */
    const Tensor *d_xm,       /* in: upstream gradient */
    const Tensor *x_ln2,      /* cached layer norm output */
    const Tensor *cm_mix_sig, /* cached sigmoid(channel_mix) */
    int seq_len,
    int n_embd
) {
    tensor_zero(d_x_ln2);
    tensor_zero(d_channel_mix);
    /* Forward was: xm[t] = x[t] * m + x[t-1] * (1 - m), with x[-1] == 0. */
    for (int t = 0; t < seq_len; t++) {
        const int row = t * n_embd;
        const int prev_row = row - n_embd;
        for (int i = 0; i < n_embd; i++) {
            const float m = cm_mix_sig->data[i];
            const float g = d_xm->data[row + i];
            const float cur = x_ln2->data[row + i];
            float prev = 0.0f;
            if (t > 0) {
                prev = x_ln2->data[prev_row + i];
                /* Gradient flowing back to the previous timestep's activation. */
                d_x_ln2->data[prev_row + i] += g * (1.0f - m);
            }
            d_x_ln2->data[row + i] += g * m;
            /* d/d(m) = cur - prev; chain through the sigmoid. */
            d_channel_mix->data[i] += g * (cur - prev) * m * (1.0f - m);
        }
    }
}
/* ============================================================
* Forward Pass with Caching
* ============================================================ */
/*
 * Run the full model forward over `tokens` (length seq_len), caching every
 * intermediate activation in `cache` for later use by backward_pass().
 *
 * Per layer: LN1 -> multi-scale token shift (lags 1/2/4) -> r/k/v token
 * mixing and projections -> data-dependent decay (base + low-rank delta) ->
 * slot-memory WKV scan with per-head ALiBi decay -> receptance gate ->
 * output projection -> residual; then LN2 -> one-step channel shift ->
 * SwiGLU FFN -> residual.
 *
 * Returns the cross-entropy loss of predicting tokens[1..seq_len-1] from
 * positions 0..seq_len-2 (monitoring only; gradients come from backward).
 *
 * NOTE(review): assumes every tensor in `cache` was pre-allocated for at
 * least seq_len rows, and that lc->num_states/den_states hold seq_len + 1
 * state tensors (the scan writes index t+1) — confirm against cache alloc.
 */
static float forward_with_cache(ForwardCache *cache, const int *tokens, int seq_len,
const ModelParams *mp, const lrnnConfig *cfg) {
int n_embd = cfg->n_embd;
int n_head = cfg->n_head;
//int hdim = n_embd / n_head;
/* Embedding lookup: copy one embedding row per input token */
for (int t = 0; t < seq_len; t++) {
int tok = tokens[t];
memcpy(cache->emb_out.data + t * n_embd,
mp->emb.data + tok * n_embd,
(size_t)n_embd * sizeof(float));
}
/* Initial layer norm */
layer_norm_seq(&cache->x_ln0, &cache->emb_out, &mp->ln0_weight, &mp->ln0_bias);
/* Copy x_ln0 to first layer input - we'll use x_final as working space */
tensor_copy(&cache->x_final, &cache->x_ln0);
/* Process layers */
for (int layer_idx = 0; layer_idx < mp->n_layers; layer_idx++) {
const LayerParams *lp = &mp->layers[layer_idx];
LayerCache *lc = &cache->layers[layer_idx];
/* ============ TimeMix Forward ============ */
/* Keep this layer's input: LN1 backward needs it in backward_pass() */
tensor_copy(&lc->x_in, &cache->x_final);
/* Layer norm 1 */
layer_norm_seq(&lc->x_ln1, &cache->x_final, &lp->ln1_weight, &lp->ln1_bias);
/* Multi-scale shift - compute sigmoid weights (per channel, shared by all t) */
sigmoid_vec(lc->shift_w1_sig.data, lp->time_shift_w1.data, n_embd);
sigmoid_vec(lc->shift_w2_sig.data, lp->time_shift_w2.data, n_embd);
sigmoid_vec(lc->shift_w4_sig.data, lp->time_shift_w4.data, n_embd);
/* Normalizer; EPSILON keeps the division safe if all weights saturate low */
for (int i = 0; i < n_embd; i++) {
lc->shift_w_sum.data[i] = lc->shift_w1_sig.data[i] +
lc->shift_w2_sig.data[i] +
lc->shift_w4_sig.data[i] + EPSILON;
}
/* Apply multi-scale shift: weighted mix of inputs at lags 1, 2 and 4
 * (out-of-range history reads as zero) */
for (int t = 0; t < seq_len; t++) {
for (int i = 0; i < n_embd; i++) {
float w1 = lc->shift_w1_sig.data[i] / lc->shift_w_sum.data[i];
float w2 = lc->shift_w2_sig.data[i] / lc->shift_w_sum.data[i];
float w4 = lc->shift_w4_sig.data[i] / lc->shift_w_sum.data[i];
float x1 = (t >= 1) ? lc->x_ln1.data[(t-1) * n_embd + i] : 0.0f;
float x2 = (t >= 2) ? lc->x_ln1.data[(t-2) * n_embd + i] : 0.0f;
float x4 = (t >= 4) ? lc->x_ln1.data[(t-4) * n_embd + i] : 0.0f;
lc->x_shifted.data[t * n_embd + i] = w1 * x1 + w2 * x2 + w4 * x4;
}
}
/* Token mixing weights: per-channel lerp between current and shifted input */
sigmoid_vec(lc->mix_r_sig.data, lp->time_mix_r.data, n_embd);
sigmoid_vec(lc->mix_k_sig.data, lp->time_mix_k.data, n_embd);
sigmoid_vec(lc->mix_v_sig.data, lp->time_mix_v.data, n_embd);
for (int t = 0; t < seq_len; t++) {
for (int i = 0; i < n_embd; i++) {
int idx = t * n_embd + i;
float mr = lc->mix_r_sig.data[i];
float mk = lc->mix_k_sig.data[i];
float mv = lc->mix_v_sig.data[i];
lc->xr.data[idx] = lc->x_ln1.data[idx] * mr + lc->x_shifted.data[idx] * (1.0f - mr);
lc->xk.data[idx] = lc->x_ln1.data[idx] * mk + lc->x_shifted.data[idx] * (1.0f - mk);
lc->xv.data[idx] = lc->x_ln1.data[idx] * mv + lc->x_shifted.data[idx] * (1.0f - mv);
}
}
/* Projections to receptance / key / value pre-activations */
matmul(&lc->r_pre, &lc->xr, &lp->Wr);
matmul(&lc->k_pre, &lc->xk, &lp->Wk);
matmul(&lc->v, &lc->xv, &lp->Wv);
/* Data-dependent decay: base vector plus a low-rank (LoRA) delta from x_ln1 */
matmul(&lc->decay_tmp, &lc->x_ln1, &lp->decay_lora_a);
matmul(&lc->decay_delta, &lc->decay_tmp, &lp->decay_lora_b);
for (int t = 0; t < seq_len; t++) {
for (int i = 0; i < n_embd; i++) {
int idx = t * n_embd + i;
lc->decay_pre.data[idx] = lp->decay_base.data[i] + lc->decay_delta.data[idx];
lc->decay.data[idx] = sigmoid_f(lc->decay_pre.data[idx]);
}
}
/* Receptance (sigmoid) and k (exp); exp inputs clamped to [-10, 10]
 * for numerical safety (backward treats the clamped region as zero-grad) */
for (int i = 0; i < lc->r_pre.size; i++) {
lc->r.data[i] = sigmoid_f(lc->r_pre.data[i]);
}
for (int i = 0; i < n_embd; i++) {
lc->time_first_exp.data[i] = expf(clamp_f(lp->time_first.data[i], -10.0f, 10.0f));
}
for (int i = 0; i < lc->k_pre.size; i++) {
lc->k_exp.data[i] = expf(clamp_f(lc->k_pre.data[i], -10.0f, 10.0f));
}
/* WKV sequential scan with state caching and ALiBi */
/* WKV scan with multi-slot memory and ALiBi: each slot keeps a
 * numerator/denominator pair; softmax gates decide which slots each
 * timestep reads from and writes to. States for every t are cached
 * (index t+1 holds the post-update state) so backward can replay. */
{
int n_slots = cfg->n_mem_slots;
int hdim_val = n_embd / n_head;
/* Zero initial states */
tensor_zero(&lc->num_states[0]);
tensor_zero(&lc->den_states[0]);
float *wl = (float *)malloc((size_t)n_slots * sizeof(float));
float *rl = (float *)malloc((size_t)n_slots * sizeof(float));
for (int t = 0; t < seq_len; t++) {
/* Compute gates for this timestep: logits = x_ln1[t] @ mem_gate_* */
matvec(wl, lc->x_ln1.data + t * n_embd, &lp->mem_gate_write);
softmax_vec(lc->write_gates[t].data, wl, n_slots);
matvec(rl, lc->x_ln1.data + t * n_embd, &lp->mem_gate_read);
softmax_vec(lc->read_gates[t].data, rl, n_slots);
for (int h = 0; h < n_head; h++) {
int base = h * hdim_val;
/* Per-head constant decay factor derived from the ALiBi slope */
float ad = expf(-lp->alibi_slopes.data[h]);
for (int d = 0; d < hdim_val; d++) {
int i = base + d;
int idx = t * n_embd + i;
float ki = lc->k_exp.data[idx];
float vi = lc->v.data[idx];
float di = lc->decay.data[idx];
float tfi = lc->time_first_exp.data[i];
float kv = ki * vi;
float combined = di * ad;
/* Read from slots: gate-weighted sum of slot states */
float read_num = 0.0f, read_den = 0.0f;
for (int s = 0; s < n_slots; s++) {
int si = s * n_embd + i;
read_num += lc->read_gates[t].data[s] *
lc->num_states[t].data[si];
read_den += lc->read_gates[t].data[s] *
lc->den_states[t].data[si];
}
/* Current token contributes through the time_first bonus */
lc->wkv.data[idx] = (read_num + tfi * kv) /
(read_den + tfi * ki + EPSILON);
/* Write to slots: decay old state, add gate-weighted k/kv */
for (int s = 0; s < n_slots; s++) {
int si = s * n_embd + i;
float wg = lc->write_gates[t].data[s];
lc->num_states[t+1].data[si] =
combined * lc->num_states[t].data[si] + wg * kv;
lc->den_states[t+1].data[si] =
combined * lc->den_states[t].data[si] + wg * ki;
}
}
}
}
free(wl);
free(rl);
}
/* Apply receptance gate and output projection */
for (int i = 0; i < lc->wkv.size; i++) {
lc->wkv_r.data[i] = lc->wkv.data[i] * lc->r.data[i];
}
matmul(&lc->tm_out, &lc->wkv_r, &lp->Wo);
/* Residual connection */
for (int i = 0; i < cache->x_final.size; i++) {
lc->x_after_tm.data[i] = cache->x_final.data[i] + lc->tm_out.data[i];
}
/* ============ ChannelMix Forward ============ */
/* Layer norm 2 */
layer_norm_seq(&lc->x_ln2, &lc->x_after_tm, &lp->ln2_weight, &lp->ln2_bias);
/* Channel mixing: one-step lerp with the previous timestep */
sigmoid_vec(lc->cm_mix_sig.data, lp->channel_mix.data, n_embd);
for (int t = 0; t < seq_len; t++) {
for (int i = 0; i < n_embd; i++) {
int idx = t * n_embd + i;
float mix = lc->cm_mix_sig.data[i];
float x_curr = lc->x_ln2.data[idx];
float x_prev = (t > 0) ? lc->x_ln2.data[(t-1) * n_embd + i] : 0.0f;
lc->xm.data[idx] = x_curr * mix + x_prev * (1.0f - mix);
}
}
/* SwiGLU FFN: down(silu(gate(xm)) * up(xm)) */
matmul(&lc->gate_pre, &lc->xm, &lp->ffn_gate);
matmul(&lc->up_val, &lc->xm, &lp->ffn_up);
for (int i = 0; i < lc->gate_pre.size; i++) {
lc->gate_silu.data[i] = silu_f(lc->gate_pre.data[i]);
lc->hidden.data[i] = lc->gate_silu.data[i] * lc->up_val.data[i];
}
matmul(&lc->cm_out, &lc->hidden, &lp->ffn_down);
/* Residual connection - update x_final for next layer */
for (int i = 0; i < cache->x_final.size; i++) {
cache->x_final.data[i] = lc->x_after_tm.data[i] + lc->cm_out.data[i];
}
}
/* Output layer norm and head */
layer_norm_seq(&cache->x_ln_out, &cache->x_final, &mp->ln_out_weight, &mp->ln_out_bias);
matmul(&cache->logits, &cache->x_ln_out, &mp->head);
/* Compute loss (return for monitoring) */
float loss = cross_entropy_loss(&cache->logits, tokens + 1, seq_len - 1);
return loss;
}
/* ============================================================
* Backward Pass
* ============================================================ */
/*
 * Full reverse-mode pass matching forward_with_cache().
 *
 * Accumulates parameter gradients into `grads` using the activations cached
 * by the forward pass. Only target_len = seq_len - 1 positions carry loss
 * (position t predicts tokens[t+1]); all view tensors below alias the first
 * target_len rows of the cached activations.
 *
 * Fixes vs. the previous revision:
 *  - BUG: the x_ln1 gradients from the memory-gate projections
 *    (d_x_ln1_wgate / d_x_ln1_rgate) were added to d_x_ln1 *before*
 *    token_mixing_backward(), which zero-fills d_x_ln1 first — so they were
 *    silently discarded (and d_x_ln1_decay was read before being computed).
 *    All three bypass paths are now folded in after token mixing.
 *  - Removed the unused d_x_tmp tensor and the unused `layer_input`
 *    reconstruction; LN1 backward already uses the cached lc->x_in.
 */
static void backward_pass(ModelGrads *grads, const ForwardCache *cache,
                          const int *tokens, int seq_len,
                          const ModelParams *mp, const lrnnConfig *cfg) {
    int n_embd = cfg->n_embd;
    int vocab_size = cfg->vocab_size;
    int ffn_h = ffn_hidden(cfg);
    int lora_rank = cfg->decay_lora_rank;
    int target_len = seq_len - 1; /* We predict tokens[1:] from tokens[:-1] */
    /* Working gradient tensors shared by every layer */
    Tensor d_logits = tensor_alloc(target_len, vocab_size);
    Tensor d_x_ln_out = tensor_alloc(target_len, n_embd);
    Tensor d_x_final = tensor_alloc(target_len, n_embd);
    /* Softmax cross-entropy backward over the target_len scored positions */
    Tensor logits_view = {
        .data = cache->logits.data,
        .rows = target_len,
        .cols = vocab_size,
        .size = target_len * vocab_size
    };
    softmax_cross_entropy_backward(&d_logits, &logits_view, tokens + 1);
    /* Head backward: logits = x_ln_out @ head */
    matmul_backward_x(&d_x_ln_out, &d_logits, &mp->head);
    Tensor x_ln_out_view = {
        .data = cache->x_ln_out.data,
        .rows = target_len,
        .cols = n_embd,
        .size = target_len * n_embd
    };
    matmul_backward_w(&grads->head, &d_logits, &x_ln_out_view);
    /* Output layer norm backward */
    Tensor x_final_view = {
        .data = cache->x_final.data,
        .rows = target_len,
        .cols = n_embd,
        .size = target_len * n_embd
    };
    layer_norm_backward_seq(&d_x_final, &grads->ln_out_weight, &grads->ln_out_bias,
                            &d_x_ln_out, &x_final_view, &mp->ln_out_weight);
    /* Walk the layers in reverse; d_x_final carries the running gradient */
    for (int layer_idx = mp->n_layers - 1; layer_idx >= 0; layer_idx--) {
        const LayerParams *lp = &mp->layers[layer_idx];
        const LayerCache *lc = &cache->layers[layer_idx];
        LayerGrads *lg = &grads->layers[layer_idx];
        /* ============ ChannelMix Backward ============ */
        /* Residual: x_final = x_after_tm + cm_out, so both branches
         * receive d_x_final */
        Tensor d_cm_out = tensor_alloc(target_len, n_embd);
        Tensor d_hidden = tensor_alloc(target_len, ffn_h);
        Tensor d_gate_silu = tensor_alloc(target_len, ffn_h);
        Tensor d_up_val = tensor_alloc(target_len, ffn_h);
        Tensor d_gate_pre = tensor_alloc(target_len, ffn_h);
        Tensor d_xm = tensor_alloc(target_len, n_embd);
        Tensor d_x_ln2 = tensor_alloc(target_len, n_embd);
        Tensor d_x_after_tm = tensor_alloc(target_len, n_embd);
        tensor_copy(&d_cm_out, &d_x_final);
        /* cm_out = hidden @ ffn_down backward */
        Tensor hidden_view = {
            .data = lc->hidden.data,
            .rows = target_len,
            .cols = ffn_h,
            .size = target_len * ffn_h
        };
        matmul_backward_x(&d_hidden, &d_cm_out, &lp->ffn_down);
        matmul_backward_w(&lg->ffn_down, &d_cm_out, &hidden_view);
        /* hidden = gate_silu * up_val (SwiGLU product rule) */
        for (int i = 0; i < target_len * ffn_h; i++) {
            d_gate_silu.data[i] = d_hidden.data[i] * lc->up_val.data[i];
            d_up_val.data[i] = d_hidden.data[i] * lc->gate_silu.data[i];
        }
        /* gate_silu = silu(gate_pre) */
        silu_backward(d_gate_pre.data, d_gate_silu.data, lc->gate_pre.data, target_len * ffn_h);
        /* gate_pre = xm @ ffn_gate, up_val = xm @ ffn_up */
        Tensor xm_view = {
            .data = lc->xm.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        Tensor d_xm_gate = tensor_alloc(target_len, n_embd);
        Tensor d_xm_up = tensor_alloc(target_len, n_embd);
        matmul_backward_x(&d_xm_gate, &d_gate_pre, &lp->ffn_gate);
        matmul_backward_w(&lg->ffn_gate, &d_gate_pre, &xm_view);
        matmul_backward_x(&d_xm_up, &d_up_val, &lp->ffn_up);
        matmul_backward_w(&lg->ffn_up, &d_up_val, &xm_view);
        /* Combine both branches of the xm gradient */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_xm.data[i] = d_xm_gate.data[i] + d_xm_up.data[i];
        }
        tensor_free(&d_xm_gate);
        tensor_free(&d_xm_up);
        /* Channel mixing shift backward */
        Tensor x_ln2_view = {
            .data = lc->x_ln2.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        channel_mix_shift_backward(&d_x_ln2, &lg->channel_mix, &d_xm,
                                   &x_ln2_view, &lc->cm_mix_sig, target_len, n_embd);
        /* Layer norm 2 backward */
        Tensor x_after_tm_view = {
            .data = lc->x_after_tm.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        layer_norm_backward_seq(&d_x_after_tm, &lg->ln2_weight, &lg->ln2_bias,
                                &d_x_ln2, &x_after_tm_view, &lp->ln2_weight);
        /* Add the residual branch gradient */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_x_after_tm.data[i] += d_x_final.data[i];
        }
        tensor_free(&d_cm_out);
        tensor_free(&d_hidden);
        tensor_free(&d_gate_silu);
        tensor_free(&d_up_val);
        tensor_free(&d_gate_pre);
        tensor_free(&d_xm);
        tensor_free(&d_x_ln2);
        /* ============ TimeMix Backward ============ */
        Tensor d_tm_out = tensor_alloc(target_len, n_embd);
        Tensor d_wkv_r = tensor_alloc(target_len, n_embd);
        Tensor d_wkv = tensor_alloc(target_len, n_embd);
        Tensor d_r = tensor_alloc(target_len, n_embd);
        Tensor d_k_exp = tensor_alloc(target_len, n_embd);
        Tensor d_v = tensor_alloc(target_len, n_embd);
        Tensor d_decay = tensor_alloc(target_len, n_embd);
        Tensor d_time_first_exp = tensor_alloc_1d(n_embd);
        Tensor d_decay_pre = tensor_alloc(target_len, n_embd);
        Tensor d_decay_delta = tensor_alloc(target_len, n_embd);
        Tensor d_decay_tmp = tensor_alloc(target_len, lora_rank);
        Tensor d_x_ln1_decay = tensor_alloc(target_len, n_embd);
        Tensor d_r_pre = tensor_alloc(target_len, n_embd);
        Tensor d_k_pre = tensor_alloc(target_len, n_embd);
        Tensor d_xr = tensor_alloc(target_len, n_embd);
        Tensor d_xk = tensor_alloc(target_len, n_embd);
        Tensor d_xv = tensor_alloc(target_len, n_embd);
        Tensor d_x_ln1 = tensor_alloc(target_len, n_embd);
        Tensor d_x_shifted = tensor_alloc(target_len, n_embd);
        Tensor d_x_layer_in = tensor_alloc(target_len, n_embd);
        /* Residual: x_after_tm = x + tm_out */
        tensor_copy(&d_tm_out, &d_x_after_tm);
        /* tm_out = wkv_r @ Wo backward */
        Tensor wkv_r_view = {
            .data = lc->wkv_r.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        matmul_backward_x(&d_wkv_r, &d_tm_out, &lp->Wo);
        matmul_backward_w(&lg->Wo, &d_tm_out, &wkv_r_view);
        /* wkv_r = wkv * r (product rule) */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_wkv.data[i] = d_wkv_r.data[i] * lc->r.data[i];
            d_r.data[i] = d_wkv_r.data[i] * lc->wkv.data[i];
        }
        /* Shared view of the cached LN1 output, used by several weight grads */
        Tensor x_ln1_view = {
            .data = lc->x_ln1.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        /* WKV backward with multi-slot memory and per-head ALiBi decay;
         * also produces the gradients of the write/read gate logits. */
        Tensor d_gate_write_logits = tensor_alloc(target_len, cfg->n_mem_slots);
        Tensor d_gate_read_logits = tensor_alloc(target_len, cfg->n_mem_slots);
        wkv_backward(&d_k_exp, &d_v, &d_decay, &d_time_first_exp,
                     &d_gate_write_logits, &d_gate_read_logits,
                     &d_wkv, &lc->k_exp, &lc->v, &lc->decay,
                     &lc->time_first_exp,
                     lc->num_states, lc->den_states,
                     lc->write_gates, lc->read_gates,
                     target_len, n_embd,
                     cfg->n_head, cfg->n_mem_slots,
                     lp->alibi_slopes.data);
        /* gate logits = x_ln1 @ mem_gate_{write,read} backward. The x_ln1
         * contributions are folded into d_x_ln1 only AFTER
         * token_mixing_backward() below, which zero-fills d_x_ln1. */
        Tensor d_x_ln1_wgate = tensor_alloc(target_len, n_embd);
        matmul_backward_x(&d_x_ln1_wgate, &d_gate_write_logits, &lp->mem_gate_write);
        matmul_backward_w(&lg->mem_gate_write, &d_gate_write_logits, &x_ln1_view);
        Tensor d_x_ln1_rgate = tensor_alloc(target_len, n_embd);
        matmul_backward_x(&d_x_ln1_rgate, &d_gate_read_logits, &lp->mem_gate_read);
        matmul_backward_w(&lg->mem_gate_read, &d_gate_read_logits, &x_ln1_view);
        tensor_free(&d_gate_write_logits);
        tensor_free(&d_gate_read_logits);
        /* time_first backward: time_first_exp = exp(clamp(time_first));
         * gradient is zero where the clamp was active in forward */
        for (int i = 0; i < n_embd; i++) {
            float x = lp->time_first.data[i];
            if (x >= -10.0f && x <= 10.0f) {
                lg->time_first.data[i] += d_time_first_exp.data[i] * lc->time_first_exp.data[i];
            }
        }
        /* decay backward: decay = sigmoid(decay_pre) */
        for (int i = 0; i < target_len * n_embd; i++) {
            float y = lc->decay.data[i];
            d_decay_pre.data[i] = d_decay.data[i] * y * (1.0f - y);
        }
        /* decay_pre = decay_base + decay_delta */
        for (int t = 0; t < target_len; t++) {
            for (int i = 0; i < n_embd; i++) {
                int idx = t * n_embd + i;
                lg->decay_base.data[i] += d_decay_pre.data[idx];
                d_decay_delta.data[idx] = d_decay_pre.data[idx];
            }
        }
        /* decay_delta = (x_ln1 @ decay_lora_a) @ decay_lora_b backward */
        Tensor decay_tmp_view = {
            .data = lc->decay_tmp.data,
            .rows = target_len,
            .cols = lora_rank,
            .size = target_len * lora_rank
        };
        matmul_backward_x(&d_decay_tmp, &d_decay_delta, &lp->decay_lora_b);
        matmul_backward_w(&lg->decay_lora_b, &d_decay_delta, &decay_tmp_view);
        matmul_backward_x(&d_x_ln1_decay, &d_decay_tmp, &lp->decay_lora_a);
        matmul_backward_w(&lg->decay_lora_a, &d_decay_tmp, &x_ln1_view);
        /* r backward: r = sigmoid(r_pre) */
        sigmoid_backward(d_r_pre.data, d_r.data, lc->r.data, target_len * n_embd);
        /* k_exp backward: k_exp = exp(k_pre) (clamped) */
        exp_backward_clamped(d_k_pre.data, d_k_exp.data, lc->k_pre.data, target_len * n_embd);
        /* r_pre = xr @ Wr, k_pre = xk @ Wk, v = xv @ Wv backward */
        Tensor xr_view = { .data = lc->xr.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd };
        Tensor xk_view = { .data = lc->xk.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd };
        Tensor xv_view = { .data = lc->xv.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd };
        matmul_backward_x(&d_xr, &d_r_pre, &lp->Wr);
        matmul_backward_w(&lg->Wr, &d_r_pre, &xr_view);
        matmul_backward_x(&d_xk, &d_k_pre, &lp->Wk);
        matmul_backward_w(&lg->Wk, &d_k_pre, &xk_view);
        matmul_backward_x(&d_xv, &d_v, &lp->Wv);
        matmul_backward_w(&lg->Wv, &d_v, &xv_view);
        /* Token mixing backward (zero-initializes d_x_ln1 and d_x_shifted) */
        token_mixing_backward(&d_x_ln1, &d_x_shifted,
                              &lg->time_mix_r, &lg->time_mix_k, &lg->time_mix_v,
                              &d_xr, &d_xk, &d_xv,
                              &lc->x_ln1, &lc->x_shifted,
                              &lc->mix_r_sig, &lc->mix_k_sig, &lc->mix_v_sig, target_len, n_embd);
        /* Fold in the x_ln1 gradient paths that bypass token mixing:
         * the decay LoRA and both memory-gate projections. (Previously the
         * gate terms were added before token_mixing_backward and wiped.) */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_x_ln1.data[i] += d_x_ln1_decay.data[i]
                             + d_x_ln1_wgate.data[i]
                             + d_x_ln1_rgate.data[i];
        }
        tensor_free(&d_x_ln1_wgate);
        tensor_free(&d_x_ln1_rgate);
        /* Multi-scale shift backward */
        Tensor d_x_ln1_shift = tensor_alloc(target_len, n_embd);
        multi_scale_shift_backward(&d_x_ln1_shift,
                                   &lg->time_shift_w1, &lg->time_shift_w2, &lg->time_shift_w4,
                                   &d_x_shifted, &lc->x_ln1,
                                   &lc->shift_w1_sig, &lc->shift_w2_sig, &lc->shift_w4_sig,
                                   &lc->shift_w_sum, target_len, n_embd);
        /* Add shift gradients to x_ln1 */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_x_ln1.data[i] += d_x_ln1_shift.data[i];
        }
        tensor_free(&d_x_ln1_shift);
        /* Layer norm 1 backward against the cached layer input lc->x_in */
        Tensor layer_input_view = {
            .data = lc->x_in.data,
            .rows = target_len,
            .cols = n_embd,
            .size = target_len * n_embd
        };
        layer_norm_backward_seq(&d_x_layer_in, &lg->ln1_weight, &lg->ln1_bias,
                                &d_x_ln1, &layer_input_view, &lp->ln1_weight);
        /* Add residual gradient from the TM block */
        for (int i = 0; i < target_len * n_embd; i++) {
            d_x_layer_in.data[i] += d_x_after_tm.data[i];
        }
        /* This gradient flows to the previous layer or initial embedding */
        tensor_copy(&d_x_final, &d_x_layer_in);
        /* Free layer-specific gradients */
        tensor_free(&d_tm_out);
        tensor_free(&d_wkv_r);
        tensor_free(&d_wkv);
        tensor_free(&d_r);
        tensor_free(&d_k_exp);
        tensor_free(&d_v);
        tensor_free(&d_decay);
        tensor_free(&d_time_first_exp);
        tensor_free(&d_decay_pre);
        tensor_free(&d_decay_delta);
        tensor_free(&d_decay_tmp);
        tensor_free(&d_x_ln1_decay);
        tensor_free(&d_r_pre);
        tensor_free(&d_k_pre);
        tensor_free(&d_xr);
        tensor_free(&d_xk);
        tensor_free(&d_xv);
        tensor_free(&d_x_ln1);
        tensor_free(&d_x_shifted);
        tensor_free(&d_x_layer_in);
        tensor_free(&d_x_after_tm);
    }
    /* Initial layer norm backward */
    Tensor d_emb_out = tensor_alloc(target_len, n_embd);
    Tensor emb_out_view = {
        .data = cache->emb_out.data,
        .rows = target_len,
        .cols = n_embd,
        .size = target_len * n_embd
    };
    layer_norm_backward_seq(&d_emb_out, &grads->ln0_weight, &grads->ln0_bias,
                            &d_x_final, &emb_out_view, &mp->ln0_weight);
    /* Embedding backward: scatter-add rows by token id */
    for (int t = 0; t < target_len; t++) {
        int tok = tokens[t];
        for (int i = 0; i < n_embd; i++) {
            grads->emb.data[tok * n_embd + i] += d_emb_out.data[t * n_embd + i];
        }
    }
    tensor_free(&d_logits);
    tensor_free(&d_x_ln_out);
    tensor_free(&d_x_final);
    tensor_free(&d_emb_out);
}
/* ============================================================
* Adam Optimizer
* ============================================================ */
/* Adam/AdamW hyper-parameters. `t` is the shared step counter and must be
 * advanced to >= 1 before calling adam_update() — bias correction divides
 * by 1 - beta^t, which is zero at t == 0. */
typedef struct {
float beta1; /* EMA coefficient for the first moment */
float beta2; /* EMA coefficient for the second moment */
float epsilon; /* denominator fuzz in the update */
float weight_decay; /* decoupled (AdamW-style) weight decay */
int t; /* timestep */
} AdamConfig;
/* Per-parameter optimizer state: moving averages with the same shape as
 * the parameter tensor they track. */
typedef struct {
Tensor m; /* First moment */
Tensor v; /* Second moment */
} AdamState;
/* One AdamState per trainable tensor; the layout mirrors the model's
 * parameter structure. `layers` is heap-allocated with n_layers entries
 * (see init_adam_states / free_adam_states). */
typedef struct {
AdamState emb;
AdamState ln0_weight, ln0_bias;
struct {
AdamState ln1_weight, ln1_bias;
AdamState ln2_weight, ln2_bias;
AdamState time_shift_w1, time_shift_w2, time_shift_w4;
AdamState time_mix_r, time_mix_k, time_mix_v;
AdamState decay_lora_a, decay_lora_b;
AdamState decay_base;
AdamState time_first;
AdamState Wr, Wk, Wv, Wo;
AdamState channel_mix;
AdamState ffn_gate, ffn_up, ffn_down;
AdamState mem_gate_write;
AdamState mem_gate_read;
} *layers;
AdamState ln_out_weight, ln_out_bias;
AdamState head;
int n_layers;
} AdamStates;
/* Allocate the two Adam moment buffers for one rows x cols parameter. */
static void init_adam_state(AdamState *state, int rows, int cols) {
    state->v = tensor_alloc(rows, cols); /* second raw moment */
    state->m = tensor_alloc(rows, cols); /* first moment */
}
/* Release both moment buffers of one optimizer slot. */
static void free_adam_state(AdamState *state) {
    tensor_free(&state->m);
    tensor_free(&state->v);
}
/* Allocate Adam moment buffers for every trainable tensor in the model.
 * Shapes mirror the corresponding parameter tensors. */
static void init_adam_states(AdamStates *as, const lrnnConfig *cfg) {
    const int n_embd = cfg->n_embd;
    const int vocab_size = cfg->vocab_size;
    const int ffn_h = ffn_hidden(cfg);
    const int lora_rank = cfg->decay_lora_rank;
    as->n_layers = cfg->n_layer;
    /* Model-wide parameters */
    init_adam_state(&as->emb, vocab_size, n_embd);
    init_adam_state(&as->ln0_weight, n_embd, 1);
    init_adam_state(&as->ln0_bias, n_embd, 1);
    /* NOTE(review): calloc result is unchecked here, as elsewhere in this file. */
    as->layers = calloc((size_t)cfg->n_layer, sizeof(*as->layers));
    for (int l = 0; l < cfg->n_layer; l++) {
        /* Layer norms */
        init_adam_state(&as->layers[l].ln1_weight, n_embd, 1);
        init_adam_state(&as->layers[l].ln1_bias, n_embd, 1);
        init_adam_state(&as->layers[l].ln2_weight, n_embd, 1);
        init_adam_state(&as->layers[l].ln2_bias, n_embd, 1);
        /* Multi-scale shift and token-mix vectors */
        init_adam_state(&as->layers[l].time_shift_w1, n_embd, 1);
        init_adam_state(&as->layers[l].time_shift_w2, n_embd, 1);
        init_adam_state(&as->layers[l].time_shift_w4, n_embd, 1);
        init_adam_state(&as->layers[l].time_mix_r, n_embd, 1);
        init_adam_state(&as->layers[l].time_mix_k, n_embd, 1);
        init_adam_state(&as->layers[l].time_mix_v, n_embd, 1);
        /* Data-dependent decay (low-rank) */
        init_adam_state(&as->layers[l].decay_lora_a, n_embd, lora_rank);
        init_adam_state(&as->layers[l].decay_lora_b, lora_rank, n_embd);
        init_adam_state(&as->layers[l].decay_base, n_embd, 1);
        init_adam_state(&as->layers[l].time_first, n_embd, 1);
        /* Time-mix projections */
        init_adam_state(&as->layers[l].Wr, n_embd, n_embd);
        init_adam_state(&as->layers[l].Wk, n_embd, n_embd);
        init_adam_state(&as->layers[l].Wv, n_embd, n_embd);
        init_adam_state(&as->layers[l].Wo, n_embd, n_embd);
        /* Channel mix / SwiGLU FFN */
        init_adam_state(&as->layers[l].channel_mix, n_embd, 1);
        init_adam_state(&as->layers[l].ffn_gate, n_embd, ffn_h);
        init_adam_state(&as->layers[l].ffn_up, n_embd, ffn_h);
        init_adam_state(&as->layers[l].ffn_down, ffn_h, n_embd);
        /* Slot-memory gates */
        init_adam_state(&as->layers[l].mem_gate_write, n_embd, cfg->n_mem_slots);
        init_adam_state(&as->layers[l].mem_gate_read, n_embd, cfg->n_mem_slots);
    }
    init_adam_state(&as->ln_out_weight, n_embd, 1);
    init_adam_state(&as->ln_out_bias, n_embd, 1);
    init_adam_state(&as->head, n_embd, vocab_size);
}
/* Release every buffer allocated by init_adam_states(). */
static void free_adam_states(AdamStates *as) {
    free_adam_state(&as->emb);
    free_adam_state(&as->ln0_weight);
    free_adam_state(&as->ln0_bias);
    for (int l = 0; l < as->n_layers; l++) {
        free_adam_state(&as->layers[l].ln1_weight);
        free_adam_state(&as->layers[l].ln1_bias);
        free_adam_state(&as->layers[l].ln2_weight);
        free_adam_state(&as->layers[l].ln2_bias);
        free_adam_state(&as->layers[l].time_shift_w1);
        free_adam_state(&as->layers[l].time_shift_w2);
        free_adam_state(&as->layers[l].time_shift_w4);
        free_adam_state(&as->layers[l].time_mix_r);
        free_adam_state(&as->layers[l].time_mix_k);
        free_adam_state(&as->layers[l].time_mix_v);
        free_adam_state(&as->layers[l].decay_lora_a);
        free_adam_state(&as->layers[l].decay_lora_b);
        free_adam_state(&as->layers[l].decay_base);
        free_adam_state(&as->layers[l].time_first);
        free_adam_state(&as->layers[l].Wr);
        free_adam_state(&as->layers[l].Wk);
        free_adam_state(&as->layers[l].Wv);
        free_adam_state(&as->layers[l].Wo);
        free_adam_state(&as->layers[l].channel_mix);
        free_adam_state(&as->layers[l].ffn_gate);
        free_adam_state(&as->layers[l].ffn_up);
        free_adam_state(&as->layers[l].ffn_down);
        free_adam_state(&as->layers[l].mem_gate_write);
        free_adam_state(&as->layers[l].mem_gate_read);
    }
    free(as->layers);
    as->layers = NULL; /* guard against accidental reuse after free */
    free_adam_state(&as->ln_out_weight);
    free_adam_state(&as->ln_out_bias);
    free_adam_state(&as->head);
}
/* ============================================================
* Adam Optimizer Update Step
* ============================================================ */
/*
 * AdamW update step for a single parameter tensor.
 *
 * param  - parameter tensor updated in place
 * grad   - gradient of the same shape (not modified)
 * state  - per-parameter first/second moment EMAs (updated in place)
 * config - shared hyper-parameters; config->t is the current step count
 * lr     - learning rate for this step
 *
 * Weight decay is decoupled (AdamW): applied directly to the parameter,
 * not folded into the gradient moments.
 *
 * Precondition: config->t >= 1. Bias correction divides by 1 - beta^t,
 * which is zero at t == 0. (Per-element gradient clipping was removed as
 * dead code; global-norm clipping is handled separately by
 * clip_gradients_by_global_norm().)
 */
static void adam_update(Tensor *param, Tensor *grad, AdamState *state,
                        AdamConfig *config, float lr) {
    const float beta1 = config->beta1;
    const float beta2 = config->beta2;
    const float eps = config->epsilon;
    const float wd = config->weight_decay;
    const int t = config->t;
    /* Bias correction factors (hoisted: identical for every element) */
    const float bias_correction1 = 1.0f - powf(beta1, (float)t);
    const float bias_correction2 = 1.0f - powf(beta2, (float)t);
    for (int i = 0; i < param->size; i++) {
        const float g = grad->data[i];
        /* Update biased first and second raw moment estimates */
        state->m.data[i] = beta1 * state->m.data[i] + (1.0f - beta1) * g;
        state->v.data[i] = beta2 * state->v.data[i] + (1.0f - beta2) * g * g;
        /* Bias-corrected estimates */
        const float m_hat = state->m.data[i] / bias_correction1;
        const float v_hat = state->v.data[i] / bias_correction2;
        /* Parameter step with decoupled weight decay */
        param->data[i] -= lr * (m_hat / (sqrtf(v_hat) + eps) + wd * param->data[i]);
    }
}
/* ============================================================
* Gradient Norm Clipping
* ============================================================ */
/* Returns the squared L2 norm (sum of element squares) of a tensor.
 * Accumulates in forward element order. */
static float compute_tensor_norm_sq(const Tensor *t) {
    const float *p = t->data;
    const int n = t->size;
    float acc = 0.0f;
    for (int k = 0; k < n; k++) {
        float x = p[k];
        acc += x * x;
    }
    return acc;
}
/* Multiplies every element of the tensor by a scalar, in place. */
static void scale_tensor(Tensor *t, float scale) {
    float *p = t->data;
    const int n = t->size;
    for (int k = 0; k < n; k++) {
        p[k] *= scale;
    }
}
static void clip_gradients_by_global_norm(ModelGrads *grads, float max_norm) {
/* Compute global L2 norm of all gradients */
float total_norm_sq = 0.0f;
total_norm_sq += compute_tensor_norm_sq(&grads->emb);
total_norm_sq += compute_tensor_norm_sq(&grads->ln0_weight);
total_norm_sq += compute_tensor_norm_sq(&grads->ln0_bias);
for (int i = 0; i < grads->n_layers; i++) {
LayerGrads *lg = &grads->layers[i];
total_norm_sq += compute_tensor_norm_sq(&lg->ln1_weight);
total_norm_sq += compute_tensor_norm_sq(&lg->ln1_bias);
total_norm_sq += compute_tensor_norm_sq(&lg->ln2_weight);
total_norm_sq += compute_tensor_norm_sq(&lg->ln2_bias);
total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w1);
total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w2);
total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w4);
total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_r);
total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_k);
total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_v);
total_norm_sq += compute_tensor_norm_sq(&lg->decay_lora_a);
total_norm_sq += compute_tensor_norm_sq(&lg->decay_lora_b);
total_norm_sq += compute_tensor_norm_sq(&lg->decay_base);
total_norm_sq += compute_tensor_norm_sq(&lg->time_first);
total_norm_sq += compute_tensor_norm_sq(&lg->Wr);
total_norm_sq += compute_tensor_norm_sq(&lg->Wk);
total_norm_sq += compute_tensor_norm_sq(&lg->Wv);
total_norm_sq += compute_tensor_norm_sq(&lg->Wo);
total_norm_sq += compute_tensor_norm_sq(&lg->channel_mix);
total_norm_sq += compute_tensor_norm_sq(&lg->ffn_gate);
total_norm_sq += compute_tensor_norm_sq(&lg->ffn_up);
total_norm_sq += compute_tensor_norm_sq(&lg->ffn_down);
total_norm_sq += compute_tensor_norm_sq(&lg->mem_gate_write);
total_norm_sq += compute_tensor_norm_sq(&lg->mem_gate_read);
}
total_norm_sq += compute_tensor_norm_sq(&grads->ln_out_weight);
total_norm_sq += compute_tensor_norm_sq(&grads->ln_out_bias);
total_norm_sq += compute_tensor_norm_sq(&grads->head);
float total_norm = sqrtf(total_norm_sq);
/* Scale gradients if norm exceeds max */
if (total_norm > max_norm) {
float scale = max_norm / (total_norm + 1e-8f);
scale_tensor(&grads->emb, scale);
scale_tensor(&grads->ln0_weight, scale);
scale_tensor(&grads->ln0_bias, scale);
for (int i = 0; i < grads->n_layers; i++) {
LayerGrads *lg = &grads->layers[i];
scale_tensor(&lg->ln1_weight, scale);
scale_tensor(&lg->ln1_bias, scale);
scale_tensor(&lg->ln2_weight, scale);
scale_tensor(&lg->ln2_bias, scale);
scale_tensor(&lg->time_shift_w1, scale);
scale_tensor(&lg->time_shift_w2, scale);
scale_tensor(&lg->time_shift_w4, scale);
scale_tensor(&lg->time_mix_r, scale);
scale_tensor(&lg->time_mix_k, scale);
scale_tensor(&lg->time_mix_v, scale);
scale_tensor(&lg->decay_lora_a, scale);
scale_tensor(&lg->decay_lora_b, scale);
scale_tensor(&lg->decay_base, scale);
scale_tensor(&lg->time_first, scale);
scale_tensor(&lg->Wr, scale);
scale_tensor(&lg->Wk, scale);
scale_tensor(&lg->Wv, scale);
scale_tensor(&lg->Wo, scale);
scale_tensor(&lg->channel_mix, scale);
scale_tensor(&lg->ffn_gate, scale);
scale_tensor(&lg->ffn_up, scale);
scale_tensor(&lg->ffn_down, scale);
scale_tensor(&lg->mem_gate_write, scale);
scale_tensor(&lg->mem_gate_read, scale);
}
scale_tensor(&grads->ln_out_weight, scale);
scale_tensor(&grads->ln_out_bias, scale);
scale_tensor(&grads->head, scale);
}
}
/* ============================================================
* Apply Adam Updates to All Parameters
* ============================================================ */
/*
 * apply_adam_updates: one optimizer step over every parameter tensor.
 * Advances the shared Adam timestep once, then runs adam_update on each
 * (param, grad, state) triple.
 */
static void apply_adam_updates(ModelParams *mp, ModelGrads *grads,
AdamStates *adam, AdamConfig *config, float lr) {
    /* Single shared timestep for all tensors (bias correction uses it). */
    config->t++;
    /* Embedding and input layer norm. */
    adam_update(&mp->emb, &grads->emb, &adam->emb, config, lr);
    adam_update(&mp->ln0_weight, &grads->ln0_weight, &adam->ln0_weight, config, lr);
    adam_update(&mp->ln0_bias, &grads->ln0_bias, &adam->ln0_bias, config, lr);
    /* Per-layer tensors; UPD collapses the long repeated access paths. */
#define UPD(field) \
    adam_update(&lp->field, &lg->field, &adam->layers[layer].field, config, lr)
    for (int layer = 0; layer < mp->n_layers; layer++) {
        LayerParams *lp = &mp->layers[layer];
        LayerGrads *lg = &grads->layers[layer];
        /* Layer norms */
        UPD(ln1_weight);
        UPD(ln1_bias);
        UPD(ln2_weight);
        UPD(ln2_bias);
        /* Multi-scale token shift */
        UPD(time_shift_w1);
        UPD(time_shift_w2);
        UPD(time_shift_w4);
        /* Token mixing ratios */
        UPD(time_mix_r);
        UPD(time_mix_k);
        UPD(time_mix_v);
        /* Data-dependent decay (LoRA factors, base, first-token bonus) */
        UPD(decay_lora_a);
        UPD(decay_lora_b);
        UPD(decay_base);
        UPD(time_first);
        /* Attention-style projections */
        UPD(Wr);
        UPD(Wk);
        UPD(Wv);
        UPD(Wo);
        /* Channel mix and SwiGLU FFN */
        UPD(channel_mix);
        UPD(ffn_gate);
        UPD(ffn_up);
        UPD(ffn_down);
        /* Slot-memory gates */
        UPD(mem_gate_write);
        UPD(mem_gate_read);
    }
#undef UPD
    /* Output layer norm and LM head. */
    adam_update(&mp->ln_out_weight, &grads->ln_out_weight, &adam->ln_out_weight, config, lr);
    adam_update(&mp->ln_out_bias, &grads->ln_out_bias, &adam->ln_out_bias, config, lr);
    adam_update(&mp->head, &grads->head, &adam->head, config, lr);
}
/* ============================================================
* Training Function
* ============================================================ */
/*
 * train_model: end-to-end training driver.
 *
 * Pipeline: read corpus into memory -> build tokenizer -> (optionally)
 * auto-size the model from corpus size -> tokenize -> init params, grads
 * and Adam state -> epoch loop of forward / backward / global-norm clip /
 * AdamW update -> checkpoint every 5 epochs -> final save.
 *
 * Owns every resource it creates (corpus buffer, token array, model,
 * grads, optimizer state, cache, tokenizer) and frees them all before
 * returning. On any setup error it prints to stderr and returns early.
 */
static void train_model(const char *corpus_path, const char *save_path,
int epochs, lrnnConfig *cfg, float lr,
TokenizerType tok_type, bool auto_config) {
printf("======================================================================\n");
printf(" lrnn-like Model - Training (C Implementation)\n");
printf("======================================================================\n\n");
/* Load corpus: read the whole file into one NUL-terminated buffer. */
printf("Loading corpus: %s\n", corpus_path);
FILE *f = fopen(corpus_path, "rb");
if (!f) {
fprintf(stderr, "Error: cannot open corpus file: %s\n", corpus_path);
return;
}
/* Determine file size via seek/tell (binary mode, so this is reliable). */
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
if (file_size <= 0) {
fprintf(stderr, "Error: empty or invalid corpus file\n");
fclose(f);
return;
}
char *text = (char *)malloc((size_t)file_size + 1);
if (!text) {
fprintf(stderr, "Error: cannot allocate memory for corpus\n");
fclose(f);
return;
}
/* read_size may be < file_size on a short read; the code trusts read_size
 * from here on, so a short read only shrinks the corpus. */
size_t read_size = fread(text, 1, (size_t)file_size, f);
text[read_size] = '\0';
fclose(f);
printf(" Loaded %zu bytes\n", read_size);
/* Build tokenizer (char- or word-level per tok_type; AUTO picks one). */
printf("\nBuilding tokenizer...\n");
Tokenizer tok;
init_tokenizer(&tok, tok_type);
build_tokenizer(&tok, text, read_size, tok_type);
int vocab_size = tokenizer_vocab_size(&tok);
printf(" Vocabulary size: %d %s\n", vocab_size,
tok.type == TOKENIZER_CHAR ? "characters" : "words");
/* Auto-configure model if requested; otherwise only adopt the vocab size
 * into the caller-supplied config. */
if (auto_config) {
printf("\nAuto-configuring model for corpus size...\n");
*cfg = config_for_corpus(file_size, tok.type, vocab_size);
} else {
cfg->vocab_size = vocab_size;
}
/* Tokenize the whole corpus; the raw text is no longer needed after. */
int token_count;
int *tokens = tokenizer_encode(&tok, text, read_size, &token_count);
free(text);
printf(" Token count: %d\n", token_count);
if (tok.type == TOKENIZER_WORD) {
printf(" Compression ratio: %.2fx\n", (float)read_size / (float)token_count);
}
/* Initialize model parameters. */
printf("\nInitializing model...\n");
printf(" Layers: %d\n", cfg->n_layer);
printf(" Embedding dim: %d\n", cfg->n_embd);
printf(" FFN hidden: %d\n", ffn_hidden(cfg));
printf(" Context length: %d\n", cfg->ctx_len);
printf(" LoRA rank: %d\n", cfg->decay_lora_rank);
ModelParams mp;
memset(&mp, 0, sizeof(mp));
init_model_params(&mp, cfg);
/* Count parameters for the report below.
 * NOTE(review): mem_gate_write/mem_gate_read are not included in this
 * count even though they are trained — the total printed is slightly low.
 * Confirm whether that is intentional. */
long total_params = 0;
total_params += mp.emb.size;
total_params += mp.ln0_weight.size + mp.ln0_bias.size;
for (int i = 0; i < mp.n_layers; i++) {
LayerParams *lp = &mp.layers[i];
total_params += lp->ln1_weight.size + lp->ln1_bias.size;
total_params += lp->ln2_weight.size + lp->ln2_bias.size;
total_params += lp->time_shift_w1.size + lp->time_shift_w2.size + lp->time_shift_w4.size;
total_params += lp->time_mix_r.size + lp->time_mix_k.size + lp->time_mix_v.size;
total_params += lp->decay_lora_a.size + lp->decay_lora_b.size;
total_params += lp->decay_base.size + lp->time_first.size;
total_params += lp->Wr.size + lp->Wk.size + lp->Wv.size + lp->Wo.size;
total_params += lp->channel_mix.size;
total_params += lp->ffn_gate.size + lp->ffn_up.size + lp->ffn_down.size;
}
total_params += mp.ln_out_weight.size + mp.ln_out_bias.size;
total_params += mp.head.size;
printf(" Total parameters: %ld (%.2f MB)\n", total_params,
(float)total_params * sizeof(float) / (1024.0f * 1024.0f));
printf(" Params per byte: %.2f\n", (float)total_params / (float)read_size);
/* Gradient buffers (same shapes as the parameters). */
ModelGrads grads;
memset(&grads, 0, sizeof(grads));
init_model_grads(&grads, &mp, cfg);
/* Adam optimizer state (m/v moments per tensor) and hyperparameters. */
AdamStates adam;
memset(&adam, 0, sizeof(adam));
init_adam_states(&adam, cfg);
AdamConfig adam_cfg = {
.beta1 = 0.9f,
.beta2 = 0.999f,
.epsilon = 1e-8f,
.weight_decay = 0.0f,
.t = 0
};
/* Training configuration: one "batch" is a contiguous window of ctx_len
 * tokens (clamped so at least one target token remains). */
int batch_size = cfg->ctx_len;
if (batch_size > token_count - 1) {
batch_size = token_count - 1;
}
int n_batches = (token_count - 1) / batch_size;
if (n_batches < 1) n_batches = 1;
/* Forward cache holds per-step activations for the backward pass. */
ForwardCache cache;
memset(&cache, 0, sizeof(cache));
init_forward_cache(&cache, batch_size, cfg);
printf("\n======================================================================\n");
printf("Starting training...\n");
printf(" Tokenizer: %s\n", tok.type == TOKENIZER_CHAR ? "character" : "word");
printf(" Batch size: %d tokens\n", batch_size);
printf(" Batches per epoch: %d\n", n_batches);
printf("======================================================================\n\n");
time_t start_time = time(NULL);
float best_loss = FLT_MAX;
for (int epoch = 0; epoch < epochs; epoch++) {
float epoch_loss = 0.0f;
int batch_count = 0;
/* "Shuffle": jitter the window start by a small random offset (0..9)
 * each epoch so batch boundaries shift between epochs.
 * NOTE(review): rand() is never seeded on the training path, so the
 * offset sequence repeats across runs — presumably intentional for
 * reproducibility; confirm. */
int offset = rand() % (batch_size > 10 ? 10 : 1);
for (int batch = 0; batch < n_batches; batch++) {
int start = offset + batch * batch_size;
/* Skip windows whose target token (index start+batch_size) would
 * fall outside the corpus. */
if (start + batch_size >= token_count) continue;
int *batch_tokens = tokens + start;
int seq_len = batch_size;
/* Zero gradients before accumulating this batch. */
zero_model_grads(&grads);
/* Forward pass with activation caching for backprop. */
float loss = forward_with_cache(&cache, batch_tokens, seq_len, &mp, cfg);
/* Skip the update entirely if the loss diverged. */
if (!isfinite(loss)) {
printf("Warning: NaN/Inf loss detected at epoch %d, batch %d. Skipping.\n",
epoch + 1, batch);
continue;
}
/* Backward pass accumulates into grads. */
backward_pass(&grads, &cache, batch_tokens, seq_len, &mp, cfg);
/* Global gradient clipping (max L2 norm 5.0). */
clip_gradients_by_global_norm(&grads, 5.0f);
/* AdamW parameter update. */
apply_adam_updates(&mp, &grads, &adam, &adam_cfg, lr);
epoch_loss += loss;
batch_count++;
/* Progress indicator every 10 batches (and at epoch end). */
if ((batch + 1) % 10 == 0 || batch == n_batches - 1) {
printf("\r Epoch %d/%d - Batch %d/%d - Loss: %.4f",
epoch + 1, epochs, batch + 1, n_batches,
batch_count > 0 ? epoch_loss / (float)batch_count : 0.0f);
fflush(stdout);
}
}
/* Epoch statistics: mean loss over the batches that actually ran. */
float avg_loss = (batch_count > 0) ? epoch_loss / (float)batch_count : 0.0f;
float perplexity = expf(avg_loss);
time_t elapsed = time(NULL) - start_time;
int hours = (int)(elapsed / 3600);
int mins = (int)((elapsed % 3600) / 60);
int secs = (int)(elapsed % 60);
printf("\n Epoch %d/%d complete - Loss: %.4f - Perplexity: %.2f - Time: %02d:%02d:%02d\n",
epoch + 1, epochs, avg_loss, perplexity, hours, mins, secs);
/* Track best loss (reporting only; no early stopping / best-model save). */
if (avg_loss < best_loss) {
best_loss = avg_loss;
printf(" ** New best loss! **\n");
}
/* Checkpoint every 5 epochs and on the final epoch. */
if ((epoch + 1) % 5 == 0 || epoch == epochs - 1) {
printf(" Saving checkpoint to: %s\n", save_path);
if (save_model(save_path, &mp, cfg, &tok) != 0) {
fprintf(stderr, " Warning: failed to save checkpoint\n");
}
}
}
/* Final save (redundant with the last-epoch checkpoint above, but cheap;
 * return value intentionally unchecked here). */
printf("\nSaving final model to: %s\n", save_path);
save_model(save_path, &mp, cfg, &tok);
/* Cleanup: release everything allocated above, in reverse-ish order. */
free(tokens);
free_forward_cache(&cache);
free_model_grads(&grads);
free_adam_states(&adam);
free_model_params(&mp);
free_tokenizer(&tok);
}
/* ============================================================
* Generation with Dynamic State Checkpointing
* ============================================================ */
/* Comparison function for qsort - descending by probability */
/* (probability, vocabulary index) pair used for nucleus sampling. */
typedef struct {
    float prob;
    int index;
} ProbIndex;

/* qsort comparator: orders ProbIndex entries by descending probability. */
static int prob_index_cmp_desc(const void *a, const void *b) {
    const ProbIndex *x = (const ProbIndex *)a;
    const ProbIndex *y = (const ProbIndex *)b;
    /* -1 when x has the larger prob, +1 when smaller, 0 otherwise. */
    return (x->prob < y->prob) - (x->prob > y->prob);
}

/*
 * sample_top_p: nucleus (top-p) sampling over a probability vector.
 * Sorts tokens by probability, keeps the smallest prefix whose cumulative
 * mass reaches top_p, and samples proportionally from that prefix.
 * Falls back to greedy argmax if the scratch allocation fails.
 */
static int sample_top_p(const float *probs, int vocab_size, float top_p) {
    ProbIndex *sorted = malloc((size_t)vocab_size * sizeof *sorted);
    if (sorted == NULL) {
        /* Allocation failed: degrade gracefully to argmax. */
        int best = 0;
        for (int i = 1; i < vocab_size; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }
    for (int i = 0; i < vocab_size; i++) {
        sorted[i] = (ProbIndex){ .prob = probs[i], .index = i };
    }
    /* O(V log V) descending sort. */
    qsort(sorted, (size_t)vocab_size, sizeof *sorted, prob_index_cmp_desc);
    /* Nucleus cutoff: smallest prefix whose cumulative mass >= top_p. */
    int kept = vocab_size; /* default: keep everything */
    float mass = 0.0f;
    for (int i = 0; i < vocab_size; i++) {
        mass += sorted[i].prob;
        if (mass >= top_p) {
            kept = i + 1;
            break;
        }
    }
    if (kept < 1) kept = 1;
    /* Total mass of the kept prefix (renormalization denominator). */
    float kept_mass = 0.0f;
    for (int i = 0; i < kept; i++) {
        kept_mass += sorted[i].prob;
    }
    /* Draw uniformly in [0, kept_mass] and walk the prefix. */
    float r = ((float)rand() / (float)RAND_MAX) * kept_mass;
    float acc = 0.0f;
    int choice = sorted[0].index;
    for (int i = 0; i < kept; i++) {
        acc += sorted[i].prob;
        if (acc >= r) {
            choice = sorted[i].index;
            break;
        }
    }
    free(sorted);
    return choice;
}
/*
 * generate_text: load a saved model and stream a sampled continuation of
 * seed_text to stdout.
 *
 * FIXES vs. original: the logits/probs buffers were dereferenced without
 * checking malloc's result, and seed_tokens was used without a NULL check;
 * both are now guarded, with goto-based cleanup so every exit path frees
 * the model, state, tokenizer and buffers.
 */
static void generate_text(const char *model_path, const char *seed_text,
int n_tokens, float temperature, float top_p) {
    printf("======================================================================\n");
    printf(" Text Generation\n");
    printf("======================================================================\n\n");
    /* Load model, config and tokenizer from the checkpoint. */
    printf("Loading model: %s\n", model_path);
    ModelParams mp;
    lrnnConfig cfg;
    Tokenizer tok;
    memset(&mp, 0, sizeof(mp));
    memset(&tok, 0, sizeof(tok));
    if (load_model(model_path, &mp, &cfg, &tok) != 0) {
        fprintf(stderr, "Failed to load model\n");
        return;
    }
    printf(" Model loaded!\n");
    printf(" Tokenizer: %s\n", tok.type == TOKENIZER_CHAR ? "character" : "word");
    printf(" Vocab size: %d\n", cfg.vocab_size);
    printf(" Layers: %d, Dim: %d\n", cfg.n_layer, cfg.n_embd);
    /* Fresh recurrent state; the seed pass below warms it up. */
    ModelState state;
    init_model_state(&state, &cfg);
    printf("\nSeed: \"%s\"\n", seed_text);
    printf("Generating %d tokens (temp=%.2f, top_p=%.2f)\n\n", n_tokens, temperature, top_p);
    printf("======================================================================\n");
    printf("%s", seed_text);
    fflush(stdout);
    /* Sampling buffers — the original dereferenced these unchecked. */
    float *logits = malloc((size_t)cfg.vocab_size * sizeof *logits);
    float *probs = malloc((size_t)cfg.vocab_size * sizeof *probs);
    if (!logits || !probs) {
        fprintf(stderr, "Error: cannot allocate sampling buffers\n");
        goto cleanup;
    }
    /* Feed the seed through the model to build up recurrent state. */
    int seed_token_count = 0;
    int *seed_tokens = tokenizer_encode(&tok, seed_text, strlen(seed_text), &seed_token_count);
    int last_token = 0;
    if (seed_tokens) {
        for (int i = 0; i < seed_token_count; i++) {
            forward_single(logits, seed_tokens[i], &mp, &state, &cfg);
            last_token = seed_tokens[i];
        }
        free(seed_tokens);
    }
    /* Autoregressive sampling loop. */
    srand((unsigned int)time(NULL));
    char decode_buf[MAX_WORD_LEN];
    for (int i = 0; i < n_tokens; i++) {
        forward_single(logits, last_token, &mp, &state, &cfg);
        /* Temperature scaling (skipped at exactly 1.0; callers are expected
         * to pass a positive temperature — 0 would divide by zero). */
        if (temperature != 1.0f) {
            for (int j = 0; j < cfg.vocab_size; j++) {
                logits[j] /= temperature;
            }
        }
        softmax_vec(probs, logits, cfg.vocab_size);
        /* Nucleus sampling, then decode and emit immediately. */
        int next_token = sample_top_p(probs, cfg.vocab_size, top_p);
        tokenizer_decode_token(&tok, next_token, decode_buf, sizeof(decode_buf));
        printf("%s", decode_buf);
        fflush(stdout);
        last_token = next_token;
    }
    printf("\n======================================================================\n");
cleanup:
    /* free(NULL) is a no-op, so partial-allocation paths are safe here. */
    free(logits);
    free(probs);
    free_model_state(&state);
    free_model_params(&mp);
    free_tokenizer(&tok);
}
/* ============================================================
* Main Entry Point
* ============================================================ */
/* Prints the command-line usage summary to stdout. */
static void print_usage(const char *prog) {
    /* Fixed help lines (everything that does not embed the program name). */
    static const char *const option_help[] = {
        " --train FILE Path to training corpus\n",
        " --save FILE Path to save model\n",
        " --load FILE Path to load model\n",
        " --seed TEXT Seed text for generation\n",
        " --epochs N Training epochs (default: 20)\n",
        " --tokens N Tokens to generate (default: 200)\n",
        " --layers N Number of layers (default: auto)\n",
        " --dim N Embedding dimension (default: auto)\n",
        " --heads N Number of heads (default: auto)\n",
        " --ctx N Max context length (default: auto)\n",
        " --lr FLOAT Learning rate (default: 0.0003)\n",
        " --temp FLOAT Temperature (default: 0.8)\n",
        " --top_p FLOAT Top-p sampling (default: 0.9)\n",
        " --tokenizer TYPE char, word, or auto (default: auto)\n",
        " --auto-config Auto-configure model size (default: on)\n",
        " --no-auto-config Disable auto-configuration\n",
        " --help Show this help\n",
    };
    printf("Usage:\n");
    printf(" Training:\n");
    printf(" %s --train corpus.txt --save model.bin [options]\n\n", prog);
    printf(" Generation:\n");
    printf(" %s --load model.bin --seed \"text\" [options]\n\n", prog);
    printf("Options:\n");
    for (size_t i = 0; i < sizeof option_help / sizeof option_help[0]; i++) {
        fputs(option_help[i], stdout);
    }
}
/*
 * main: parse command-line options with getopt_long and dispatch to
 * training (--train) or generation (--load). Returns 0 on success,
 * 1 on usage errors (and 0 for an explicit --help).
 */
int main(int argc, char *argv[]) {
/* Long-option table. NOTE: short 'h' is --heads, NOT help; help is 'H'. */
static struct option long_options[] = {
{"train", required_argument, 0, 't'},
{"save", required_argument, 0, 's'},
{"load", required_argument, 0, 'l'},
{"seed", required_argument, 0, 'S'},
{"epochs", required_argument, 0, 'e'},
{"tokens", required_argument, 0, 'n'},
{"layers", required_argument, 0, 'L'},
{"dim", required_argument, 0, 'd'},
{"heads", required_argument, 0, 'h'},
{"ctx", required_argument, 0, 'c'},
{"lr", required_argument, 0, 'r'},
{"temp", required_argument, 0, 'T'},
{"top_p", required_argument, 0, 'p'},
{"tokenizer", required_argument, 0, 'k'},
{"auto-config", no_argument, 0, 'A'},
{"no-auto-config", no_argument, 0, 'N'},
{"help", no_argument, 0, 'H'},
{0, 0, 0, 0}
};
/* Option values with their defaults. The char* values alias argv (optarg)
 * and must not be freed. */
char *train_path = NULL;
char *save_path = NULL;
char *load_path = NULL;
char *seed_text = NULL;
int epochs = 20;
int n_tokens = 200;
float lr = 0.0003f;
float temperature = 0.8f;
float top_p = 0.9f;
TokenizerType tok_type = TOKENIZER_AUTO;
bool auto_config = true;
bool manual_config = false; /* set when any of -L/-d/-h/-c is given */
lrnnConfig cfg = default_config();
int opt;
int option_index = 0;
/* NOTE(review): atoi/atof silently yield 0 on malformed input (e.g.
 * "--epochs abc" -> 0 epochs); strtol/strtod would allow validation. */
while ((opt = getopt_long(argc, argv, "t:s:l:S:e:n:L:d:h:c:r:T:p:k:ANH",
long_options, &option_index)) != -1) {
switch (opt) {
case 't': train_path = optarg; break;
case 's': save_path = optarg; break;
case 'l': load_path = optarg; break;
case 'S': seed_text = optarg; break;
case 'e': epochs = atoi(optarg); break;
case 'n': n_tokens = atoi(optarg); break;
case 'L': cfg.n_layer = atoi(optarg); manual_config = true; break;
case 'd': cfg.n_embd = atoi(optarg); manual_config = true; break;
case 'h': cfg.n_head = atoi(optarg); manual_config = true; break;
case 'c': cfg.ctx_len = atoi(optarg); manual_config = true; break;
case 'r': lr = (float)atof(optarg); break;
case 'T': temperature = (float)atof(optarg); break;
case 'p': top_p = (float)atof(optarg); break;
case 'k':
/* Unknown tokenizer names silently fall back to auto. */
if (strcmp(optarg, "char") == 0) {
tok_type = TOKENIZER_CHAR;
} else if (strcmp(optarg, "word") == 0) {
tok_type = TOKENIZER_WORD;
} else {
tok_type = TOKENIZER_AUTO;
}
break;
case 'A': auto_config = true; break;
case 'N': auto_config = false; break;
case 'H':
default:
/* --help exits 0; any unrecognized option exits 1. */
print_usage(argv[0]);
return (opt == 'H') ? 0 : 1;
}
}
/* Manually-set model dimensions override auto-config. */
if (manual_config) {
auto_config = false;
}
/* Training mode: --train requires --save. */
if (train_path) {
if (!save_path) {
fprintf(stderr, "Error: --save is required when training\n");
return 1;
}
train_model(train_path, save_path, epochs, &cfg, lr, tok_type, auto_config);
}
/* Generation mode: --load requires a non-empty --seed. */
else if (load_path) {
if (!seed_text || strlen(seed_text) == 0) {
fprintf(stderr, "Error: --seed is required for generation\n");
return 1;
}
generate_text(load_path, seed_text, n_tokens, temperature, top_p);
}
/* Neither mode selected: show usage and fail. */
else {
print_usage(argv[0]);
return 1;
}
return 0;
}