2026-03-19

from scratch: small language model

Toy language model.

#ai#long
model: gpt-oss-20b human: nmcgi

I wanted to see if I could train a small transformer from scratch — one that could generate prose in the style of my own blog posts and serve as a local knowledge base to feed into a larger LLM.

This is a complete walkthrough. You'll need Python 3.10+, a CUDA-enabled GPU with at least 8 GB VRAM, and some patience.

Project layout

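my-slm/
├─ data/
│   └─ posts/*.md          # your blog posts
├─ src/
│   ├─ preprocess.py       # markdown cleaning
│   ├─ tokenizer.py        # wraps the SentencePiece model for Hugging Face
│   ├─ dataset.py          # Hugging Face Dataset wrapper
│   ├─ model.py            # transformer definition
│   ├─ train.py            # training script
│   ├─ eval.py             # perplexity evaluation
│   └─ export.py           # save in Hugging Face format
└─ requirements.txt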

Install dependencies:

python -m venv .venv && source .venv/bin/activate

pip install torch==2.3.0 \
            transformers==4.41.0 \
            datasets==2.20.0 \
            sentencepiece==0.1.99 \
            accelerate==0.30.0 \
            tqdm \
            einops

For a CUDA-enabled PyTorch wheel, get the right command from pytorch.org.


Data ingestion

Strip front-matter, code blocks, and extra whitespace from the markdown files down to plain text.

# src/preprocess.py
import glob
import re
from pathlib import Path

def read_markdown_files(root: str = "data/posts") -> list[str]:
    """Return a list of raw markdown strings."""
    files = sorted(glob.glob(f"{root}/*.md"))
    return [Path(f).read_text(encoding="utf-8") for f in files]

def strip_front_matter(md: str) -> str:
    """Remove YAML front-matter (--- ... ---)."""
    return re.sub(r"^---.*?---\n", "", md, flags=re.S)

def remove_code_blocks(md: str) -> str:
    """Remove fenced code blocks ```...```"""
    return re.sub(r"```[\s\S]*?```", "", md)

def clean_markdown(raw_md: list[str]) -> list[str]:
    cleaned = []
    for txt in raw_md:
        txt = strip_front_matter(txt)
        txt = remove_code_blocks(txt)
        txt = re.sub(r"\n{2,}", "\n\n", txt.strip())
        cleaned.append(txt)
    return cleaned

if __name__ == "__main__":
    raws = read_markdown_files()
    cleaned = clean_markdown(raws)
    Path("data/cleaned.txt").write_text("\n\n".join(cleaned), encoding="utf-8")

Run it:

python src/preprocess.py
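To see what the cleaning steps do, here is a small self-contained sketch that applies the same regexes to a made-up post. The sample text is hypothetical, and the fence marker is built from a variable only so the example nests cleanly inside a code block.

```python
import re

FENCE = "`" * 3  # three backticks

def strip_front_matter(md: str) -> str:
    # Same pattern as preprocess.py: drop a leading "--- ... ---" block
    return re.sub(r"^---.*?---\n", "", md, flags=re.S)

def remove_code_blocks(md: str) -> str:
    # Same idea as preprocess.py: drop everything between two fences
    return re.sub(FENCE + r"[\s\S]*?" + FENCE, "", md)

sample = (
    "---\ntitle: Mutexes\n---\n"
    "Intro paragraph.\n\n\n\n"
    f"{FENCE}cpp\nstd::mutex m;\n{FENCE}\n\n"
    "Closing paragraph.\n"
)

text = remove_code_blocks(strip_front_matter(sample))
text = re.sub(r"\n{2,}", "\n\n", text.strip())  # collapse runs of blank lines
print(repr(text))
```

Only the two prose paragraphs survive, separated by a single blank line, which is the paragraph boundary the dataset step later splits on.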

Tokenizer

Train a BPE tokenizer using SentencePiece. 8k vocab is enough for technical tokens without making the embedding matrix unwieldy.

spm_train \
    --input=data/cleaned.txt \
    --model_prefix=spm \
    --vocab_size=8000 \
    --character_coverage=1.0 \
    --model_type=bpe \
    --pad_id=0 \
    --unk_id=1 \
    --bos_id=2 \
    --eos_id=3
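If the spm_train binary is not on your PATH (the pip package may only install the Python module), the same run can be sketched through the Python API. The helper names spm_train_kwargs and train_tokenizer are made up for this example; the keyword arguments mirror the CLI flags above.

```python
def spm_train_kwargs(corpus: str = "data/cleaned.txt", prefix: str = "spm") -> dict:
    """Mirror the spm_train flags as keyword arguments."""
    return dict(
        input=corpus,
        model_prefix=prefix,
        vocab_size=8000,
        character_coverage=1.0,
        model_type="bpe",
        pad_id=0,
        unk_id=1,
        bos_id=2,
        eos_id=3,
    )

def train_tokenizer() -> None:
    # Imported lazily so the kwargs can be inspected without the package installed
    import sentencepiece as spm
    spm.SentencePieceTrainer.train(**spm_train_kwargs())

# Call train_tokenizer() from the project root to write the model and vocab files.
```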

This produces spm.model and spm.vocab. Wrap it for Hugging Face:

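# src/tokenizer.py
from transformers import LlamaTokenizer

# spm.model is a SentencePiece protobuf, not a tokenizers JSON file, so
# PreTrainedTokenizerFast cannot read it directly. LlamaTokenizer is a
# SentencePiece-backed tokenizer that loads the .model file as-is.
# Loading the saved folder with AutoTokenizer may trigger a fast-tokenizer
# conversion that needs the protobuf package; use_fast=False avoids it.
tokenizer = LlamaTokenizer(
    vocab_file="spm.model",
    unk_token="<unk>", bos_token="<s>", eos_token="</s>",
    pad_token="<pad>",  # id 0, matching --pad_id=0
)
tokenizer.save_pretrained("tokenizer")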

Dataset

Tokenize the corpus one paragraph per example, truncating or padding each to a fixed length of 512 tokens.

# src/dataset.py
from pathlib import Path
from datasets import Dataset
from transformers import AutoTokenizer

TOKENIZER_PATH = "tokenizer"
SEQ_LEN = 512

def build_tokenized_dataset():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=SEQ_LEN,
            padding="max_length",
            return_attention_mask=False,
        )

    raw_text = Path("data/cleaned.txt").read_text(encoding="utf-8")
    texts = [p.strip() for p in raw_text.split("\n\n") if p.strip()]

    ds = Dataset.from_dict({"text": texts})
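    tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=["text"])
    # Return tensors so the default DataLoader collate yields a (batch, seq_len) tensor
    tokenized_ds.set_format(type="torch", columns=["input_ids"])
    return tokenized_ds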

if __name__ == "__main__":
    ds = build_tokenized_dataset()
    print(ds)

Run it:

python src/dataset.py
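Padding every paragraph to 512 tokens spends most of each batch on pad ids when paragraphs are short. A common alternative, which is not what dataset.py above does, is to concatenate the tokenized paragraphs and cut the stream into full blocks. A minimal sketch on plain lists, with pack_sequences as a hypothetical helper and made-up token ids:

```python
from typing import Iterable

def pack_sequences(token_lists: Iterable[list[int]],
                   seq_len: int = 512,
                   eos_id: int = 3) -> list[list[int]]:
    """Concatenate tokenized paragraphs and cut the stream into seq_len blocks."""
    stream: list[int] = []
    for ids in token_lists:
        stream.extend(ids)
        stream.append(eos_id)  # mark the paragraph boundary
    n_full = len(stream) // seq_len  # the trailing partial block is dropped
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]

paragraphs = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
blocks = pack_sequences(paragraphs, seq_len=4)
print(blocks)
```

With packing there are no pad tokens to ignore in the loss, at the cost of blocks that can start mid-paragraph.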

Model architecture

A GPT-style transformer with roughly 98 M parameters at this vocabulary size. The causal mask is the key detail: it's passed as an additive attention bias (-inf for future positions, 0 elsewhere) so each position can only attend to itself and tokens before it.
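To make the additive mask concrete, here is a plain-Python sketch without torch. The equal attention scores are stand-ins, not real values; the point is only what the softmax does once -inf is added to the future positions.

```python
import math

def causal_mask(sz: int) -> list[list[float]]:
    """0.0 on and below the diagonal, -inf above it."""
    return [[0.0 if j <= i else float("-inf") for j in range(sz)] for i in range(sz)]

def softmax(row: list[float]) -> list[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

scores = [[1.0] * 3 for _ in range(3)]  # stand-in scores, all equal
mask = causal_mask(3)
weights = [
    softmax([s + b for s, b in zip(score_row, mask_row)])
    for score_row, mask_row in zip(scores, mask)
]
for row in weights:
    print([round(w, 2) for w in row])
```

Position 0 puts all its weight on itself, position 1 splits it over the first two tokens, and only the last position sees the whole sequence.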

# src/model.py
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

class TinyTransformerConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size=8000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 max_position_embeddings=512,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings

class TinyTransformer(PreTrainedModel):
    config_class = TinyTransformerConfig

    def __init__(self, config: TinyTransformerConfig):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size,
                                       config.hidden_size,
                                       padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=0.1,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def _causal_mask(self, sz: int, device) -> torch.Tensor:
        # Additive mask: -inf blocks attention to future positions, 0 allows it
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        embeddings = self.embeddings(input_ids) + self.position_embeddings(position_ids)
        mask = self._causal_mask(seq_len, input_ids.device)
        # TransformerEncoder expects (S, N, E)
        encoder_output = self.encoder(embeddings.transpose(0, 1), mask=mask)
        logits = self.lm_head(encoder_output.transpose(0, 1))
        return logits
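The parameter count can be tallied by hand from the layer shapes. The sketch below assumes the default TinyTransformerConfig values and PyTorch's standard nn.TransformerEncoderLayer layout (fused Q/K/V projection, two LayerNorms per layer, no final norm); count_params is a name made up for this example.

```python
def count_params(vocab=8000, hidden=768, layers=12, ffn=3072, max_pos=512) -> dict:
    attn = 3 * hidden * hidden + 3 * hidden           # fused Q/K/V projection + bias
    attn += hidden * hidden + hidden                  # output projection + bias
    mlp = hidden * ffn + ffn + ffn * hidden + hidden  # two feed-forward linears
    norms = 2 * 2 * hidden                            # two LayerNorms, weight + bias
    per_layer = attn + mlp + norms
    return {
        "token_emb": vocab * hidden,
        "pos_emb": max_pos * hidden,
        "layers": layers * per_layer,
        "lm_head": hidden * vocab,                    # untied, no bias
    }

parts = count_params()
total = sum(parts.values())
print(parts)
print(f"total = {total / 1e6:.1f} M")
```

That comes to about 98 M. The 125 M figure usually quoted for a 12-layer, 768-wide model assumes GPT-2's roughly 50k-token vocabulary; with 8k tokens the embedding and output matrices are much smaller.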

Training loop

Uses accelerate so the same script works on one GPU or many.

# src/train.py
from pathlib import Path
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm.auto import tqdm

from dataset import build_tokenized_dataset
from model import TinyTransformer, TinyTransformerConfig

OUTPUT_DIR = Path("checkpoints")
OUTPUT_DIR.mkdir(exist_ok=True)

def main():
    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained("tokenizer")

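    # Train on 90%; the fixed seed keeps the 10% held-out split reproducible
    ds = build_tokenized_dataset().train_test_split(test_size=0.1, seed=42)["train"]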
    train_loader = DataLoader(ds, batch_size=8, shuffle=True)

    config = TinyTransformerConfig(vocab_size=len(tokenizer))
    model = TinyTransformer(config)

    optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
    total_steps = len(train_loader) * 3  # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    for epoch in range(3):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch in pbar:
            input_ids = batch["input_ids"].to(accelerator.device)
            logits = model(input_ids)

            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = torch.nn.CrossEntropyLoss(ignore_index=0)(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        if accelerator.is_main_process:
            ckpt_path = OUTPUT_DIR / f"epoch{epoch+1}.pth"
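            torch.save({
                # Unwrap first so multi-GPU (DDP) runs save keys without a "module." prefix
                "model_state_dict": accelerator.unwrap_model(model).state_dict(),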
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "config": config.to_dict(),
            }, ckpt_path)
            print(f"\nCheckpoint saved to {ckpt_path}")

if __name__ == "__main__":
    main()
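The label shift in the loop is easier to see on a single short sequence. This sketch uses plain lists and made-up token ids (2 is the BOS id and 0 the pad id configured for the tokenizer); shift_for_next_token is a hypothetical helper, not part of train.py.

```python
def shift_for_next_token(ids: list[int], pad_id: int = 0) -> list[tuple[int, int]]:
    """Pair each input position with the token it should predict, skipping pads."""
    inputs, targets = ids[:-1], ids[1:]
    # ignore_index=0 in the loss has the same effect as dropping pad targets here
    return [(x, y) for x, y in zip(inputs, targets) if y != pad_id]

ids = [2, 57, 91, 14, 0, 0, 0]  # <s>, three made-up tokens, then padding
pairs = shift_for_next_token(ids)
print(pairs)
```

Each pair reads "given this token (and everything before it), predict that one"; padded positions contribute nothing to the loss.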

Run it:

accelerate launch src/train.py

If you have never configured accelerate, accelerate launch falls back to default settings. To set things up explicitly, run accelerate config once beforehand and choose the single-GPU setup unless you have more.
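The schedule set up in train.py ramps the learning rate linearly from zero over the first 10% of steps, then decays it linearly back to zero. A sketch of that multiplier in plain Python; the 1000-step run length is a made-up number for illustration.

```python
def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear ramp 0 -> 1 over warmup, then linear decay 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

total, warmup, base_lr = 1000, 100, 5e-4
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * lr_multiplier(step, warmup, total))
```

The peak rate of 5e-4 is only reached at the end of warmup, which helps keep the first updates from blowing up while the weights are still random.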


Evaluation

Compute perplexity on a held-out split after training.

# src/eval.py
import torch
from tqdm.auto import tqdm
from model import TinyTransformer, TinyTransformerConfig

def compute_perplexity(model, dataloader):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(model.device)
            logits = model(input_ids)

            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # loss is a mean over non-pad targets, so weight it by that count
            n_tokens = (shift_labels != 0).sum().item()
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

    avg_loss = total_loss / total_tokens
    return torch.exp(torch.tensor(avg_loss)).item()

if __name__ == "__main__":
    ckpt = torch.load("checkpoints/epoch3.pth", map_location="cpu")
    config = TinyTransformerConfig(**ckpt["config"])
    model = TinyTransformer(config)
    model.load_state_dict(ckpt["model_state_dict"])

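    from dataset import build_tokenized_dataset
    # 10% split with a fixed seed. It is only truly held out if train.py
    # trained on the matching "train" half of the same split.
    ds_val = build_tokenized_dataset().train_test_split(test_size=0.1, seed=42)["test"]
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=8, shuffle=False)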

    ppl = compute_perplexity(model, val_loader)
    print(f"Perplexity: {ppl:.2f}")

Perplexity around 20-30 on a small technical corpus is a reasonable target for a model this size.
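Perplexity is the exponential of the average per-token negative log-likelihood, so batch losses have to be weighted by how many tokens each batch actually scored. A small sketch with hypothetical loss values:

```python
import math

def perplexity(batches: list[tuple[float, int]]) -> float:
    """batches holds (mean loss in nats, number of scored tokens) per batch."""
    total_nll = sum(loss * n for loss, n in batches)
    total_tokens = sum(n for _, n in batches)
    return math.exp(total_nll / total_tokens)

# Hypothetical numbers: a large batch with loss 3.2 and a small one with loss 2.8
batches = [(3.2, 900), (2.8, 100)]
print(f"{perplexity(batches):.2f}")
```

Averaging the two losses without weights would give 3.0 instead of 3.16 nats per token and understate the perplexity.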


Exporting for inference

Save the checkpoint in Hugging Face format so you can load it later with from_pretrained.

# src/export.py
import torch
from pathlib import Path
from transformers import AutoTokenizer
from model import TinyTransformer, TinyTransformerConfig

OUTPUT_DIR = Path("slm")
OUTPUT_DIR.mkdir(exist_ok=True)

def export():
    ckpt = torch.load("checkpoints/epoch3.pth", map_location="cpu")
    config = TinyTransformerConfig(**ckpt["config"])
    model = TinyTransformer(config)
    model.load_state_dict(ckpt["model_state_dict"])
    model.save_pretrained(str(OUTPUT_DIR))
    AutoTokenizer.from_pretrained("tokenizer").save_pretrained(str(OUTPUT_DIR))

if __name__ == "__main__":
    export()

Load it back:

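import torch
from transformers import AutoTokenizer
from model import TinyTransformer

model = TinyTransformer.from_pretrained("slm").eval()
tokenizer = AutoTokenizer.from_pretrained("slm")

prompt = "Explain the difference between a mutex and a semaphore in C++."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# TinyTransformer returns raw logits and does not implement the hooks that
# generate() relies on, so decode greedily by hand
with torch.no_grad():
    for _ in range(200):
        logits = model(input_ids[:, -512:])  # stay within max_position_embeddings
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))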
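Always picking the most likely token tends to make small models repeat themselves. Temperature sampling is a common alternative for the token-selection step; the sketch below works on a plain list of logits, and sample_next and the four scores are made up for illustration.

```python
import math
import random
from typing import Optional

def sample_next(logits: list[float],
                temperature: float = 0.8,
                rng: Optional[random.Random] = None) -> int:
    """Sample a token id from softmax(logits / temperature)."""
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]  # unnormalized probabilities
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -3.0]  # made-up scores for a 4-token vocabulary
draws = [sample_next(logits, rng=rng) for _ in range(1000)]
print({i: draws.count(i) for i in range(4)})
```

Lower temperatures concentrate the draws on the top token; higher ones spread them out.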

Using it as an advisor to a larger LLM

The idea: use the SLM as a local knowledge retrieval step, then inject its output into a larger model's prompt as context.

import torch
from transformers import AutoTokenizer
from model import TinyTransformer

slm = TinyTransformer.from_pretrained("slm").eval()
slm_tokenizer = AutoTokenizer.from_pretrained("slm")

@torch.no_grad()
def get_slm_summary(topic: str, max_tokens=200):
    prompt = f"Summarize the key technical details about {topic}."
    ids = slm_tokenizer(prompt, return_tensors="pt").input_ids
    # Greedy decoding by hand; TinyTransformer does not support generate()
    for _ in range(max_tokens):
        next_id = slm(ids[:, -512:])[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return slm_tokenizer.decode(ids[0], skip_special_tokens=True)

def ask_big_lm(question: str, big_lm_fn):
    summary = get_slm_summary(question.split()[0])  # crude topic extraction
    full_prompt = f"""You are an expert technical advisor.

Context (from internal knowledge base):
{summary}

Question:
{question}

Answer:"""
    return big_lm_fn(full_prompt)

big_lm_fn is whatever API call or local model you're routing to — Anthropic, OpenAI, a local Llama, whatever. For production, replace the naive topic extraction with something better (spaCy, keyword list, a small retrieval model).
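As one example of less naive topic extraction, here is a keyword-list sketch. KNOWN_TOPICS and extract_topic are hypothetical; in practice the list would come from your own post tags or headings.

```python
import re

# Hypothetical keyword list for illustration
KNOWN_TOPICS = ["mutex", "semaphore", "transformer", "tokenizer"]

def extract_topic(question: str, topics: list[str] = KNOWN_TOPICS) -> str:
    """Return the first known topic mentioned, else the longest word as a fallback."""
    words = re.findall(r"[a-z0-9+#]+", question.lower())
    for topic in topics:
        if topic in words:
            return topic
    return max(words, key=len) if words else ""

print(extract_topic("What is the difference between a mutex and a semaphore?"))
```

Swapping this in for question.split()[0] means a question that opens with "What" or "How" no longer sends that word to the SLM as the topic.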


Optional: quantization

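If memory is tight, you can try loading in 8-bit with bitsandbytes. It mainly shrinks the weights rather than guaranteeing faster inference, and it targets nn.Linear layers, so treat it as an experiment with a custom nn.TransformerEncoder model like this one.

pip install bitsandbytes

from transformers import BitsAndBytesConfig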
from model import TinyTransformer

bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = TinyTransformer.from_pretrained("slm", quantization_config=bnb_config)
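For a sense of scale, a back-of-the-envelope estimate of weight memory at different precisions. It assumes every parameter is stored at the given width, which overstates the 8-bit saving because embeddings and norms are typically left unquantized; the parameter count comes from the default layer shapes.

```python
N_PARAMS = 97_735_680  # tallied from the default TinyTransformerConfig shapes

def weight_megabytes(n_params: int, bytes_per_param: float) -> float:
    """Storage for the weights alone, ignoring activations and optimizer state."""
    return n_params * bytes_per_param / 1e6

for name, width in [("fp32", 4), ("fp16", 2), ("int8", 1)]:
    print(f"{name}: {weight_megabytes(N_PARAMS, width):.0f} MB")
```

At this size the fp32 weights already fit comfortably in 8 GB of VRAM, so quantization matters more for CPU or shared-GPU deployments than for the training box.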

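FlashAttention is also worth a look if your GPU supports it. PyTorch 2.x already ships fused attention kernels behind torch.nn.functional.scaled_dot_product_attention, but the separate flash-attn package is not picked up automatically by nn.TransformerEncoder; using it means swapping in your own attention layer.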