YatNMN-Softplus d=22 Chinchilla (1.08B): PyTorch / HuggingFace Transformers

A 1.08B-parameter nanochat-architecture GPT with a YatNMN-Softplus MLP, trained on English C4 to a Chinchilla-optimal token budget (20× params ≈ 21.5B tokens) on a single TPU v6e-8 using FSDP and gradient accumulation.

This is the 1B-scale version of the d=12 (261M) YatNMN-Softplus model, which reached a loss of 2.98. It tests whether the YatNMN advantage holds at larger scale.
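The token budgets quoted in this card follow from the 20-tokens-per-parameter rule of thumb. The sketch below only reproduces that arithmetic; the function name is illustrative, and 261M is the rounded d=12 parameter count from this card, not an exact figure.

```python
# Rule of thumb used in this card: roughly 20 training tokens per parameter.
TOKENS_PER_PARAM = 20


def chinchilla_tokens(n_params: int) -> int:
    """Return the approximate Chinchilla-optimal token budget."""
    return TOKENS_PER_PARAM * n_params


# Parameter counts as reported in this card (261M is a rounded value).
models = {"d=12": 261_000_000, "d=22": 1_077_145_546}

for name, n_params in models.items():
    budget = chinchilla_tokens(n_params)
    print(f"{name}: {n_params / 1e6:.0f}M params -> {budget / 1e9:.2f}B tokens")
```

Both results round to the token counts reported in the results table (5.22B and 21.5B).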

Results

| Scale | Params | Tokens | C4 smooth loss | Wall time | Throughput |
|---|---|---|---|---|---|
| d=12 | 261M | 5.22B | 2.98 | 2.2 h | 660K tok/s |
| d=22 | 1.08B | 21.5B | 2.83 | 47.9 h | 125K tok/s |

The loss improved by 0.15 nats going from 261M → 1.08B at Chinchilla-optimal compute, consistent with standard scaling-law predictions.
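The reported figures can be cross-checked against each other with a few lines of arithmetic. This is a sketch using only the numbers in the results table. The perplexity conversion assumes the smooth loss is a mean per-token cross-entropy in nats, which the card implies but does not state outright.

```python
import math

# Values copied from the results table in this card.
runs = {
    "d=12": {"tokens": 5.22e9, "hours": 2.2, "loss": 2.98},
    "d=22": {"tokens": 21.5e9, "hours": 47.9, "loss": 2.83},
}


def tokens_per_second(tokens: float, hours: float) -> float:
    """Average throughput implied by total tokens and wall time."""
    return tokens / (hours * 3600.0)


def perplexity(loss_nats: float) -> float:
    """Per-token perplexity for a cross-entropy loss measured in nats."""
    return math.exp(loss_nats)


for name, run in runs.items():
    tps = tokens_per_second(run["tokens"], run["hours"])
    ppl = perplexity(run["loss"])
    print(f"{name}: {tps / 1e3:.0f}K tok/s, perplexity {ppl:.1f}")

# Loss improvement between the two scales, in nats.
loss_delta = runs["d=12"]["loss"] - runs["d=22"]["loss"]
print(f"loss delta: {loss_delta:.2f} nats")
```

The implied throughputs land within about 1% of the reported 660K and 125K tok/s.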

Quick start

pip install torch transformers safetensors

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True is needed because the model class is defined in the
# repo's modeling_yatnmn_gpt.py (referenced via auto_map in config.json).
model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/yatnmn-softplus-d22-chinchilla-1B-pytorch",
    trust_remote_code=True,
    dtype=torch.float32,  # older transformers releases name this argument torch_dtype
).eval()

# The model uses the Mistral tokenizer.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    # Nucleus sampling with the KV cache enabled.
    out = model.generate(
        ids, max_new_tokens=50,
        do_sample=True, temperature=0.8, top_p=0.9,
        use_cache=True, pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))

Model details

| | |
|---|---|
| Parameters | 1,077,145,546 (~1.08B) |
| Architecture | Nanochat GPT with YatNMN-Softplus MLP (d=22, n_embd=1408, n_head=22) |
| Config | seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 21.5B tokens (Chinchilla 20×) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.03, warmup-cosine, grad_accum=8 (effective batch 512) |
| Hardware | TPU v6e-8 (TRC), europe-west4-a, FSDP + remat |
| Final loss (smooth) | 2.8325 |
| Wall time | 47.85 h |
| Throughput | 125K tok/s |
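For readers unfamiliar with the "warmup-cosine" schedule named in the optimizer row, here is a minimal sketch of one common form: linear warmup to the peak learning rate, then cosine decay. Only the peak LR of 0.03 comes from this card. The warmup length, the final learning rate and the rounded step count are assumptions chosen for illustration, and the actual training schedule may differ.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps, final_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to final_lr."""
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (peak_lr - final_lr) * cosine


PEAK_LR = 0.03         # from this card
TOTAL_STEPS = 41_000   # assumption: about 21.5B tokens / 524,288 tokens per step
WARMUP_STEPS = 1_000   # hypothetical value, not stated in the card

for step in (0, WARMUP_STEPS, TOTAL_STEPS // 2, TOTAL_STEPS):
    lr = warmup_cosine_lr(step, PEAK_LR, WARMUP_STEPS, TOTAL_STEPS)
    print(f"step {step:>6}: lr = {lr:.5f}")
```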

YatNMN-Softplus MLP

y = α · (x · W + softplus(b))² / (||x − W||² + softplus(ε))

The bias b is per-neuron (shape (ff,)); ε and α are learnable scalars. The configuration matches the 261M model: the architecture is identical except for depth (22 vs 12) and width (1408 vs 768).
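To make the formula concrete, here is a dependency-free reference sketch for a single input vector. It assumes that x · W is the dot product of x with each neuron's weight vector and that ||x − W||² is the squared Euclidean distance between x and that same vector. The function names are illustrative. The actual implementation is the PyTorch layer in yatnmn_gpt.py, which may differ in layout and detail.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable softplus: log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def yat_nmn_softplus(x, W, b, eps_raw, alpha):
    """Apply the formula above to one input vector.

    x:       list of d floats (input)
    W:       list of ff weight vectors, each a list of d floats
    b:       list of ff raw biases (softplus is applied inside)
    eps_raw: raw scalar epsilon (softplus keeps the denominator positive)
    alpha:   scalar output scale
    """
    eps = softplus(eps_raw)
    out = []
    for w_j, b_j in zip(W, b):
        dot = sum(xi * wi for xi, wi in zip(x, w_j))
        dist_sq = sum((xi - wi) ** 2 for xi, wi in zip(x, w_j))
        out.append(alpha * (dot + softplus(b_j)) ** 2 / (dist_sq + eps))
    return out


# Toy example: two neurons, one aligned with x and one orthogonal to it.
x = [1.0, 0.0]
W = [[1.0, 0.0], [0.0, 1.0]]
y = yat_nmn_softplus(x, W, b=[0.0, 0.0], eps_raw=0.0, alpha=1.0)
print(y)
```

In the toy example the aligned neuron responds much more strongly than the orthogonal one, since it has both a larger dot product and a smaller distance. Outputs are non-negative whenever α is non-negative.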

Training setup (1B on single v6e-8)

Training a 1B model on a single v6e-8 required:

  • FSDP: model params sharded 8-way across chips (first-dim partitioning on the fsdp mesh axis)
  • Gradient checkpointing (remat): dots_saveable policy on all blocks
  • Gradient accumulation: 8 micro-steps per optimizer apply (effective batch 512 samples = 524K tokens)
  • Batch per device: 8 (reduced from 32 at d=12 for HBM headroom)
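The batch figures in the list above fit together as follows. This is a sketch of the arithmetic only. It assumes "batch per device" means sequences per chip per micro-step and that all 8 chips contribute to each micro-step. The step count is derived here and is not reported in the card.

```python
SEQ_LEN = 1024          # from the model config
N_CHIPS = 8             # TPU v6e-8
BATCH_PER_DEVICE = 8    # sequences per chip per micro-step
GRAD_ACCUM = 8          # micro-steps per optimizer update
TOTAL_TOKENS = 21.5e9   # training budget from this card

micro_batch = BATCH_PER_DEVICE * N_CHIPS       # sequences per micro-step
effective_batch = micro_batch * GRAD_ACCUM     # sequences per optimizer update
tokens_per_step = effective_batch * SEQ_LEN    # tokens per optimizer update
optimizer_steps = TOTAL_TOKENS / tokens_per_step

print(f"micro-batch:     {micro_batch} sequences")
print(f"effective batch: {effective_batch} sequences")
print(f"tokens per step: {tokens_per_step:,}")
print(f"optimizer steps: about {optimizer_steps:,.0f}")
```

Under these assumptions, the effective batch and tokens-per-step match the 512 samples and 524K tokens stated in the list.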

Files

├── config.json                       # HF config with auto_map
├── model.safetensors                 # ~4.3 GB, fp32
├── yatnmn_gpt.py                     # pure PyTorch Yat_GPT + YatNMN layer
├── torch_gpt.py                      # shared building blocks
├── configuration_yatnmn_gpt.py       # PretrainedConfig subclass
├── modeling_yatnmn_gpt.py            # PreTrainedModel + KV cache + GenerationMixin
└── README.md
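As a rough sanity check, the size of model.safetensors follows from the parameter count at 4 bytes per fp32 value. The estimate below ignores the small safetensors header and assumes every parameter is stored exactly once, which is consistent with tied embeddings.

```python
N_PARAMS = 1_077_145_546   # from the model details in this card
BYTES_PER_FP32 = 4

size_bytes = N_PARAMS * BYTES_PER_FP32
size_gb = size_bytes / 1e9       # decimal gigabytes
size_gib = size_bytes / 2**30    # binary gibibytes

print(f"{size_gb:.2f} GB ({size_gib:.2f} GiB)")
```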

License

Apache 2.0.
