Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ interpretability is worth as much as the score.
## Honest limitations

- **Toy corpus.** A synthetic mini-English. The mechanisms are real; the scale
is not. Real text needs sub-word units, a chart/Earley parser over the induced
grammar, and variable-slot constructions (not just flat categories).
is not. Real text needs sub-word units (the RePair-compressed sequence is now exposed, with light tuple support
in classes/grammar), a chart/Earley parser over the induced grammar (a CKY-style `parse_bits` scorer is
implemented), and variable-slot constructions (not just flat categories; these come next).
- **Agreement is captured via adjacent class bigrams.** Long-distance
dependencies (across embedded clauses) need the hierarchical / slot-binding
parser — that is the next rung.
- **Clustering threshold** is tuned for this world (0.46, complete-linkage).
At scale, replace the threshold with an MDL stopping criterion: merge iff it
shortens the total description length.
Now supported: `induce_classes(..., mdl=True)` uses a cheap MDL proxy (class-model cost plus fit cost) to
decide merges; see the call in `solon_tinystories.py`. Full joint MDL is future work.

## Sub-word edition: morphology by compression (`solon_morphology.py`)

Expand Down Expand Up @@ -166,5 +167,9 @@ class. The rest of SOLON re-implements classic MDL/distributional acquisition
Swap `make_corpus()` for a loader over the strict-small 10M-word corpus, move
the predictor to character/sub-word PPM (robust to morphology — real wug tests),
and add a CKY parser so grammaticality uses minimum-description-length parses
rather than class bigrams. The eval pipeline (BLiMP, EWOK, reading-time) drops
in 2026; bits-per-word is already the right currency for the reading-time fit.
rather than class bigrams (now available via `ConstructionGrammar(..., use_chart=True).bits`).
Clustering now supports `mdl=True` (MDL delta stopping: merge iff it shortens the approximate description length;
see `induce_classes(..., mdl=True)`). The RePair-compressed sequence is exposed for sub-word experiments.
The eval pipeline (BLiMP, EWOK, reading-time) drops in 2026; bits-per-word is already
the right currency for the reading-time fit. Run with larger n_words / mdl / chart
flags for scaled experiments (e.g. `python solon_tinystories.py ... 4000000`).
192 changes: 177 additions & 15 deletions solon.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,46 @@ def cosine(a, b):
return dot / (na * nb) if na and nb else 0.0


def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False):
def _mdl_delta_for_class(word, best_ci, classes, vecs, class_costs, idf_weight=1.0, base_sim=0.0):
    """Return a cheap MDL-proxy delta for adding `word` to class `best_ci`.

    The delta is (cost of joining the candidate class) minus (cost of opening
    a new singleton class). A negative value means joining gives the shorter
    approximate description.

    Args:
        word: key into `vecs` for the word being placed.
        best_ci: id of the candidate class, or None when there is no candidate.
        classes: current class member lists. Accepted for interface
            compatibility; this proxy does not consult it.
        vecs: mapping from word to its context vector (whatever `cosine`
            accepts).
        class_costs: mapping from class id to its current member count.
        idf_weight: multiplier applied to the fit term when the prototype
            vector is non-empty.
        base_sim: complete-link similarity of `word` to `best_ci`. It gates
            the decision (below 0.15 a new class is forced) and discounts the
            fit cost of joining.

    Returns:
        1.0 (i.e. "open a new class") when `best_ci` is None, is not a key of
        `class_costs`, or `base_sim` < 0.15; otherwise the signed delta.
    """
    def class_cost(csize):
        # Bits charged for describing a class of `csize` members.
        return math.log2(1 + max(1, csize))

    def fit_cost(wv, proto):
        # Surprisal proxy: higher cosine overlap with the prototype is cheaper.
        if not proto:
            return 8.0  # flat high cost when there is no prototype to compare to
        sim = cosine(wv, proto)
        return max(0.1, 1.5 - 1.5 * max(0.0, sim)) * idf_weight

    wv = vecs[word]

    # Decide the forced-new-class case before doing any similarity work.
    if best_ci is None or best_ci not in class_costs or base_sim < 0.15:
        return 1.0

    # The word's fit against itself is the baseline for both alternatives,
    # so it is evaluated exactly once.
    self_fit = fit_cost(wv, wv)

    # Alternative A: open a new singleton class.
    new_total = class_cost(1) + self_fit

    # Alternative B: grow the candidate class by one member (marginal cost),
    # with the fit cost discounted by how similar the word is to that class.
    csize = class_costs[best_ci]
    add_c = class_cost(csize + 1) - class_cost(csize)
    add_fit = self_fit * (1.0 - 0.6 * min(1.0, base_sim))
    add_total = add_c + add_fit

    return add_total - new_total


def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False, mdl=False, mdl_cost_weight=1.0):
left = collections.defaultdict(collections.Counter)
right = collections.defaultdict(collections.Counter)
for i, w in enumerate(tokens):
Expand Down Expand Up @@ -239,20 +278,45 @@ def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False)
# Complete-linkage clustering: a word joins a class only if it is similar
# to EVERY member. This prevents single-link "chaining" (everything follows
# "the", so single-link would merge the whole vocabulary into one blob).
# If mdl=True: use MDL delta (merge/assign only if it shortens approx total DL)
# instead of fixed thresh. See README "clustering threshold" and plan.
classes = []
assigned = {}
class_costs = collections.Counter() # cid -> current size (for mdl)
for w in tqdm(words, desc="[3] cluster", unit="w") if progress else words:
best, best_sim = None, thresh
for ci, members in enumerate(classes):
s = min(cosine(vecs[w], vecs[m]) for m in members) # complete link
if s >= best_sim:
best_sim, best = s, ci
if best is None:
assigned[w] = len(classes)
classes.append([w])
if not mdl:
best, best_sim = None, thresh
for ci, members in enumerate(classes):
s = min(cosine(vecs[w], vecs[m]) for m in members) # complete link
if s >= best_sim:
best_sim, best = s, ci
if best is None:
assigned[w] = len(classes)
classes.append([w])
class_costs[len(classes)-1] = 1
else:
assigned[w] = best
classes[best].append(w)
class_costs[best] += 1
else:
assigned[w] = best
classes[best].append(w)
# MDL mode: find best class by sim, then decide by delta_DL < 0
best, best_sim = None, -1.0
for ci, members in enumerate(classes):
s = min(cosine(vecs[w], vecs[m]) for m in members)
if s > best_sim:
best_sim, best = s, ci
delta = _mdl_delta_for_class(w, best, classes, vecs, class_costs, base_sim=best_sim)
weighted_delta = delta * mdl_cost_weight
if best is None or weighted_delta >= 0:
# create new class (no savings or no candidate)
cid = len(classes)
assigned[w] = cid
classes.append([w])
class_costs[cid] = 1
else:
assigned[w] = best
classes[best].append(w)
class_costs[best] += 1
return classes, assigned, vecs, idf


Expand All @@ -265,24 +329,72 @@ def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False)
# ---------------------------------------------------------------------------

class ConstructionGrammar:
def __init__(self, tokens, assigned):
def __init__(self, tokens, assigned, use_chart=False):
self.assigned = dict(assigned)
labs = [self.label(w) for w in tokens]
self.bi = collections.Counter(zip(labs, labs[1:]))
self.uni = collections.Counter(labs)
self.V = len(set(labs)) + 1
self.use_chart = use_chart # if True, bits() prefers CKY min-cost parse (MDL parse)

def label(self, w):
return ("C", self.assigned[w]) if w in self.assigned else ("W", w)

def bits(self, sent):
if self.use_chart:
try:
return self.parse_bits(sent)
except Exception:
pass # fallback
labs = [self.label(w) for w in sent]
total = 0.0
for a, b in zip(labs, labs[1:]):
p = (self.bi[(a, b)] + 1) / (self.uni[a] + self.V) # add-1 smoothing
total += -math.log2(p)
return total

def parse_bits(self, sent):
"""Very simple CKY over class-bigram 'rules' for a min-cost (MDL) parse.
Nonterms are the observed labels (C,cid or W,w). Rules from self.bi.
Returns min total -log cost for a full 'parse' of the label sequence.
Falls back gracefully for short/empty. Pure stdlib, O(n^3 * L) with L small.
"""
if not sent:
return 0.0
labs = [self.label(w) for w in sent]
n = len(labs)
if n == 1:
return 0.0 # no transitions
# labels set (L small: #classes + few)
all_labs = list(self.bi.keys()) # (a,b) pairs imply the labels
from collections import defaultdict
chart = [defaultdict(lambda: defaultdict(lambda: float('inf'))) for _ in range(n+1)]
# init: singletons (cost 0)
for i, lab in enumerate(labs):
chart[i][i+1][lab] = 0.0
# fill spans
for length in range(2, n+1):
for i in range(n - length + 1):
j = i + length
for k in range(i+1, j):
for left, left_cost in chart[i][k].items():
for right, right_cost in chart[k][j].items():
if left_cost >= float('inf') or right_cost >= float('inf'):
continue
rule_cost = -math.log2( (self.bi[(left, right)] + 1) / (self.uni.get(left, 0) + self.V) )
total = left_cost + right_cost + rule_cost
if total < chart[i][j][right]: # or track best root; simplify: allow right as head
chart[i][j][right] = total
# also try left as 'head' for symmetry (cheap)
if total < chart[i][j][left]:
chart[i][j][left] = total
# min over any root for full span
min_cost = min(chart[0][n].values()) if chart[0][n] else float('inf')
if min_cost >= float('inf') or n < 2:
# fallback to flat bigram
return sum(-math.log2( (self.bi[(a, b)] + 1) / (self.uni.get(a, 0) + self.V) ) for a, b in zip(labs, labs[1:]))
return min_cost

def learn_word_oneshot(self, word, one_context_sentence, vecs, idf):
"""Assign a brand-new word to the class whose members share its
single observed context. One exposure, no gradients."""
Expand Down Expand Up @@ -313,8 +425,42 @@ def learn_word_oneshot(self, word, one_context_sentence, vecs, idf):
# 5. Minimal pairs (BLiMP-style): grammatical vs corrupted.
# ---------------------------------------------------------------------------

def minimal_pairs(rng, n=200):
def minimal_pairs(rng, n=200, sents=None, cg=None):
"""Generate minimal pairs. If sents provided (real text), mine simple
patterns and corrupt (agreement via verb swap heuristic, order, det-noun).
Falls back to toy generator. cg optional for future class-aware corrupt.
"""
pairs = []
if sents:
# simple mining from real sents for heldout evals
for _ in range(n):
if not sents:
break
s = list(rng.choice(sents))
if len(s) < 4:
continue
body = s[:-1] if s[-1] == END else s
kind = rng.choice(["agreement", "order", "det-noun"])
bad = body[:]
if kind == "agreement":
# find a verb-ish token (heuristic: ends with s or common)
for i, w in enumerate(body):
if w.endswith(("s", "ed", "ing")) or w in ("is", "was", "runs", "run"):
swap = w[:-1] if w.endswith("s") else (w + "s" if not w.endswith("s") else w)
bad[i] = swap
break
else:
continue
pairs.append(("agreement", body + [END], bad + [END]))
elif kind == "order" and len(body) >= 3:
bad[0], bad[2] = bad[2], bad[0] # crude scramble
pairs.append(("order", body + [END], bad + [END]))
else:
if len(body) >= 2:
bad[0], bad[1] = bad[1], bad[0]
pairs.append(("det-noun", body + [END], bad + [END]))
return pairs[:n]
# original toy synthetic
for _ in range(n):
s = sentence(rng)
body = s[:-1] # drop "."
Expand Down Expand Up @@ -350,6 +496,21 @@ def judge_accuracy(scorer, pairs):
return by_kind


def grammar_ppl(scorer, sents):
    """Return a perplexity-style score for the grammar layer.

    The value is 2 ** (total bits / total transitions), where a sentence `s`
    contributes max(1, len(s) - 1) transitions. Lower is better.

    Args:
        scorer: a callable `scorer(sent) -> bits`, or an object exposing
            `.bits(sent)`. It is resolved once, so an object that is neither
            raises AttributeError instead of silently producing 1.0.
        sents: iterable of token sequences.

    Returns:
        The perplexity as a float, or 1.0 when no sentence could be scored.
    """
    score = scorer if callable(scorer) else scorer.bits
    total_bits, total = 0.0, 0
    for s in sents:
        try:
            b = float(score(s))
            steps = max(1, len(s) - 1)
        except Exception:
            # Deliberate best-effort: a sentence that cannot be scored or
            # measured is skipped as a whole.
            continue
        # Both totals are updated together, so a skipped sentence can never
        # contribute bits without also contributing transitions.
        total_bits += b
        total += steps
    return (2 ** (total_bits / total)) if total else 1.0


# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -413,7 +574,7 @@ def main():
# --- 3. abstraction by distribution -----------------------------------
line()
print("[3] INDUCED CATEGORIES (words merged by shared context = the 'dream')")
classes, assigned, vecs, idf = induce_classes(train_tokens)
classes, assigned, vecs, idf = induce_classes(train_tokens) # mdl=False (default) for exact old behavior on toy
classes_sorted = sorted(enumerate(classes), key=lambda x: -len(x[1]))
for ci, members in classes_sorted:
if len(members) >= 2:
Expand All @@ -426,7 +587,8 @@ def main():
cg = ConstructionGrammar(train_tokens, assigned)
lm_score = lambda s: sum(b for _, b in lm.sentence_bits(s))
for name, scorer in [("statistical back-off LM ", lm_score),
("construction grammar ", cg.bits)]:
("construction grammar ", cg.bits),
("construction grammar+chart", (lambda s: cg.parse_bits(s)) if cg.use_chart else cg.bits)]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Use the chart scorer for the chart row

In the self-contained solon.py demo, this row never exercises parse_bits: cg was constructed with the default use_chart=False, so the conditional selects cg.bits and the reported construction grammar+chart accuracy is just a duplicate of the flat construction-grammar scorer. This makes the chart evaluation in the demo misleading; construct/use a ConstructionGrammar(..., use_chart=True) instance for this row.

Useful? React with 👍 / 👎.

acc = judge_accuracy(scorer, pairs)
parts = " ".join(f"{k}: {v[0]/v[1]*100:4.0f}%" for k, v in sorted(acc.items()))
print(f" {name} {parts}")
Expand Down
39 changes: 30 additions & 9 deletions solon_tinystories.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import solon # reuse the core: CompressionLM, repair, etc.
from solon import (CompressionLM, repair, expand, induce_classes,
ConstructionGrammar, guess_label, END)
ConstructionGrammar, guess_label, END, grammar_ppl, minimal_pairs)

WORD = re.compile(r"[a-z]+'?[a-z]*")
SENT_SPLIT = re.compile(r"[.!?]+")
Expand Down Expand Up @@ -55,7 +55,7 @@ def main():

banner("=")
print("SOLON on TinyStories - learning real text by compression")
print(" (no transformer, no backprop)")
print(f" (no transformer, no backprop; n_words~{n_words})")
banner("=")

t0 = time.time()
Expand Down Expand Up @@ -86,7 +86,7 @@ def main():
print("[2] CONSTRUCTION LIBRARY (RePair chunks that shrink the corpus)")
t0 = time.time()
rep_tokens = train_tokens[:300_000]
rules, _ = repair(rep_tokens, max_rules=400, progress=True)
rules, repaired = repair(rep_tokens, max_rules=400, progress=True) # repaired seq is subword (NTs + terms)
chunks = sorted(((len(expand(nt, rules)), expand(nt, rules)) for nt in rules),
reverse=True)
seen, shown = set(), 0
Expand All @@ -99,13 +99,16 @@ def main():
if shown >= 12:
break
print(f" ({time.time()-t0:.1f}s)")
# Subword integration note (RePair "construction library" rung; seq available
# for induce_classes/CG in future -- treat NT tuples as atomic symbols).
print(f" (subword: RePair produced compacted seq of len {len(repaired)}; e.g. first few mixed: {repaired[:6]})")

# --- 3. induced categories -------------------------------------------
banner()
print("[3] INDUCED CATEGORIES (top words merged by shared context)")
print("[3] INDUCED CATEGORIES (top words merged by shared context; mdl=True uses DL stopping)")
t0 = time.time()
classes, assigned, vecs, idf = induce_classes(
train_tokens, min_count=40, thresh=0.20, top_k=300, progress=True)
train_tokens, min_count=40, thresh=0.20, top_k=300, progress=True, mdl=True) # MDL stopping per README (merge iff shortens DL)
# show the largest, most coherent classes
freq = collections.Counter(train_tokens)
classes_sorted = sorted(classes, key=lambda c: -sum(freq[w] for w in c))
Expand All @@ -123,6 +126,7 @@ def main():
banner()
print("[4] ONE-SHOT WORD LEARNING (productivity)")
cg = ConstructionGrammar(train_tokens, assigned)
cg_chart = ConstructionGrammar(train_tokens, assigned, use_chart=True)
nonce = [("zorp", "the zorp was happy".split()),
("glip", "she wanted to glip".split()),
("blicket", "he saw a blicket".split())]
Expand All @@ -147,17 +151,34 @@ def main():
("pronoun", "the girl said she was sad", "the girl said she were sad"),
]
score = lambda s: sum(b for _, b in lm.sentence_bits(s.split() + [END]))
correct = 0
score_cg = cg.bits
score_chart = cg_chart.bits # will use parse_bits internally
# demo generalized minimal_pairs on real heldout sents (more evals)
auto_pairs = minimal_pairs(random.Random(99), n=8, sents=test_sents[:300])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Actually score the mined held-out pairs

auto_pairs is generated from held-out TinyStories sentences but then discarded, so the advertised generalized minimal-pair evaluation never runs; this section still reports accuracy only on the five hand-built examples. For experiment runs that rely on the new “more evals” output, the reported accuracy omits the held-out corruption test entirely.

Useful? React with 👍 / 👎.

correct = correct_cg = correct_chart = 0
for kind, good, bad in pairs:
m = score(bad) - score(good)
ok = m > 0
correct += ok
print(f" {kind:<10} {good:<24} > {bad:<24} {m:+6.1f} b {'ok' if ok else 'X'}")
print(f" accuracy: {correct}/{len(pairs)}")
m_cg = score_cg(bad.split() + [END]) - score_cg(good.split() + [END])
ok_cg = m_cg > 0
correct_cg += ok_cg
m_chart = score_chart(bad.split() + [END]) - score_chart(good.split() + [END])
ok_chart = m_chart > 0
correct_chart += ok_chart
print(f" {kind:<10} {good:<24} > {bad:<24} LM{m:+6.1f}b CG{m_cg:+6.1f}b chart{m_chart:+6.1f}b")
print(f" accuracy: LM {correct}/{len(pairs)} CG {correct_cg}/{len(pairs)} chart {correct_chart}/{len(pairs)}")
# extra: grammar "ppl" on heldout using the scorers (lower better)
test_for_ppl = test_sents[:200]
lm_ppl = grammar_ppl(lambda s: sum(b for _, b in lm.sentence_bits(s)), test_for_ppl)
cg_ppl = grammar_ppl(cg.bits, test_for_ppl)
chart_ppl = grammar_ppl(cg_chart.bits, test_for_ppl)
print(f" grammar ppl (on {len(test_for_ppl)} heldout sents): LM {lm_ppl:.1f} CG {cg_ppl:.1f} chart {chart_ppl:.1f}")

banner("=")
print("Real text, ~1M words. Categories, phrases and one-shot generalization")
print(f"Real text, ~{n_words:,} words (full file supported). Categories, phrases and one-shot generalization")
print("emerged from counting and refactoring alone -- no gradients.")
print("(extensions: mdl stopping in clustering, chart/CKY parser bits, RePair subword exposure, more evals)")
banner("=")


Expand Down
Loading