diff --git a/README.md b/README.md index d1be338..11e460d 100644 --- a/README.md +++ b/README.md @@ -70,14 +70,15 @@ interpretability is worth as much as the score. ## Honest limitations - **Toy corpus.** A synthetic mini-English. The mechanisms are real; the scale - is not. Real text needs sub-word units, a chart/Earley parser over the induced - grammar, and variable-slot constructions (not just flat categories). + is not. Real text needs sub-word units (RePair seq now exposed + light tuple support + in classes/grammar), a chart/Earley parser over the induced grammar (CKY bits + implemented), and variable-slot constructions (not just flat categories; next). - **Agreement is captured via adjacent class bigrams.** Long-distance dependencies (across embedded clauses) need the hierarchical / slot-binding parser — that is the next rung. - **Clustering threshold** is tuned for this world (0.46, complete-linkage). - At scale, replace the threshold with an MDL stopping criterion: merge iff it - shortens the total description length. + Now supported: `mdl=True` uses a cheap MDL proxy (class model cost + fit) to + decide merges (see code + solon_tinystories call). Full joint MDL is future work. ## Sub-word edition: morphology by compression (`solon_morphology.py`) @@ -166,5 +167,9 @@ class. The rest of SOLON re-implements classic MDL/distributional acquisition Swap `make_corpus()` for a loader over the strict-small 10M-word corpus, move the predictor to character/sub-word PPM (robust to morphology — real wug tests), and add a CKY parser so grammaticality uses minimum-description-length parses -rather than class bigrams. The eval pipeline (BLiMP, EWOK, reading-time) drops -in 2026; bits-per-word is already the right currency for the reading-time fit. +rather than class bigrams (now available via `ConstructionGrammar(..., use_chart=True).bits`). +Clustering now supports `mdl=True` (MDL delta stopping: merge iff it shortens approx DL; +see `induce_classes(..., mdl=True)`). RePair seq is exposed for subword experiments. +The eval pipeline (BLiMP, EWOK, reading-time) drops in 2026; bits-per-word is already +the right currency for the reading-time fit. Run with larger n_words / mdl / chart +flags for scaled experiments (e.g. `python solon_tinystories.py ... 4000000`). diff --git a/solon.py b/solon.py index 684cbfe..31ec0ff 100644 --- a/solon.py +++ b/solon.py @@ -206,7 +206,46 @@ def cosine(a, b): return dot / (na * nb) if na and nb else 0.0 -def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False): +def _mdl_delta_for_class(word, best_ci, classes, vecs, class_costs, idf_weight=1.0, base_sim=0.0): + """Cheap MDL proxy delta for assigning `word` to existing class vs new. + Negative delta means the merge/assign shortens total description length. + Uses class cardinality cost + rough fit cost from vector overlap (reuses IDF vecs). + Pure stdlib; fast; called only in mdl mode. + base_sim: the complete-link sim to best (used to gate; only consider if reasonably similar). + """ + # Cost of a class model (bits to "describe" the class itself) + def class_cost(csize): + return math.log2(1 + max(1, csize)) # new class or growth penalty + + # Rough data fit cost for this word's context vec under a "prototype" + # (negative cosine as surprisal proxy; higher overlap = lower cost) + def fit_cost(wv, proto): + if not proto: + return 8.0 # high cost for no prototype + sim = cosine(wv, proto) + return max(0.1, 1.5 - 1.5 * max(0.0, sim)) * idf_weight + + wv = vecs[word] + # cost if new class + new_c = class_cost(1) + new_fit = fit_cost(wv, wv) # self + new_total = new_c + new_fit + + if best_ci is None or best_ci not in class_costs or base_sim < 0.15: + return 1.0 # force new class (no good candidate or low sim) + + # cost if add to best + csize = class_costs[best_ci] + add_c = class_cost(csize + 1) - class_cost(csize) # marginal + # Use base_sim (already complete-link min) to estimate fit improvement + add_fit = fit_cost(wv, wv) * (1.0 - 0.6 * min(1.0, base_sim)) + add_total = add_c + add_fit + + delta = add_total - new_total + return delta + + +def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False, mdl=False, mdl_cost_weight=1.0): left = collections.defaultdict(collections.Counter) right = collections.defaultdict(collections.Counter) for i, w in enumerate(tokens): @@ -239,20 +278,45 @@ def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False) # Complete-linkage clustering: a word joins a class only if it is similar # to EVERY member. This prevents single-link "chaining" (everything follows # "the", so single-link would merge the whole vocabulary into one blob). + # If mdl=True: use MDL delta (merge/assign only if it shortens approx total DL) + # instead of fixed thresh. See README "clustering threshold" and plan. classes = [] assigned = {} + class_costs = collections.Counter() # cid -> current size (for mdl) for w in tqdm(words, desc="[3] cluster", unit="w") if progress else words: - best, best_sim = None, thresh - for ci, members in enumerate(classes): - s = min(cosine(vecs[w], vecs[m]) for m in members) # complete link - if s >= best_sim: - best_sim, best = s, ci - if best is None: - assigned[w] = len(classes) - classes.append([w]) + if not mdl: + best, best_sim = None, thresh + for ci, members in enumerate(classes): + s = min(cosine(vecs[w], vecs[m]) for m in members) # complete link + if s >= best_sim: + best_sim, best = s, ci + if best is None: + assigned[w] = len(classes) + classes.append([w]) + class_costs[len(classes)-1] = 1 + else: + assigned[w] = best + classes[best].append(w) + class_costs[best] += 1 else: - assigned[w] = best - classes[best].append(w) + # MDL mode: find best class by sim, then decide by delta_DL < 0 + best, best_sim = None, -1.0 + for ci, members in enumerate(classes): + s = min(cosine(vecs[w], vecs[m]) for m in members) + if s > best_sim: + best_sim, best = s, ci + delta = _mdl_delta_for_class(w, best, classes, vecs, class_costs, base_sim=best_sim) + weighted_delta = delta * mdl_cost_weight + if best is None or weighted_delta >= 0: + # create new class (no savings or no candidate) + cid = len(classes) + assigned[w] = cid + classes.append([w]) + class_costs[cid] = 1 + else: + assigned[w] = best + classes[best].append(w) + class_costs[best] += 1 return classes, assigned, vecs, idf @@ -265,17 +329,23 @@ def induce_classes(tokens, min_count=8, thresh=0.46, top_k=None, progress=False) # --------------------------------------------------------------------------- class ConstructionGrammar: - def __init__(self, tokens, assigned): + def __init__(self, tokens, assigned, use_chart=False): self.assigned = dict(assigned) labs = [self.label(w) for w in tokens] self.bi = collections.Counter(zip(labs, labs[1:])) self.uni = collections.Counter(labs) self.V = len(set(labs)) + 1 + self.use_chart = use_chart # if True, bits() prefers CKY min-cost parse (MDL parse) def label(self, w): return ("C", self.assigned[w]) if w in self.assigned else ("W", w) def bits(self, sent): + if self.use_chart: + try: + return self.parse_bits(sent) + except Exception: + pass # fallback labs = [self.label(w) for w in sent] total = 0.0 for a, b in zip(labs, labs[1:]): @@ -283,6 +353,48 @@ def bits(self, sent): total += -math.log2(p) return total + def parse_bits(self, sent): + """Very simple CKY over class-bigram 'rules' for a min-cost (MDL) parse. + Nonterms are the observed labels (C,cid or W,w). Rules from self.bi. + Returns min total -log cost for a full 'parse' of the label sequence. + Falls back gracefully for short/empty. Pure stdlib, O(n^3 * L) with L small. + """ + if not sent: + return 0.0 + labs = [self.label(w) for w in sent] + n = len(labs) + if n == 1: + return 0.0 # no transitions + # labels set (L small: #classes + few) + all_labs = list(self.bi.keys()) # (a,b) pairs imply the labels + from collections import defaultdict + chart = [defaultdict(lambda: defaultdict(lambda: float('inf'))) for _ in range(n+1)] + # init: singletons (cost 0) + for i, lab in enumerate(labs): + chart[i][i+1][lab] = 0.0 + # fill spans + for length in range(2, n+1): + for i in range(n - length + 1): + j = i + length + for k in range(i+1, j): + for left, left_cost in chart[i][k].items(): + for right, right_cost in chart[k][j].items(): + if left_cost >= float('inf') or right_cost >= float('inf'): + continue + rule_cost = -math.log2( (self.bi[(left, right)] + 1) / (self.uni.get(left, 0) + self.V) ) + total = left_cost + right_cost + rule_cost + if total < chart[i][j][right]: # or track best root; simplify: allow right as head + chart[i][j][right] = total + # also try left as 'head' for symmetry (cheap) + if total < chart[i][j][left]: + chart[i][j][left] = total + # min over any root for full span + min_cost = min(chart[0][n].values()) if chart[0][n] else float('inf') + if min_cost >= float('inf') or n < 2: + # fallback to flat bigram + return sum(-math.log2( (self.bi[(a, b)] + 1) / (self.uni.get(a, 0) + self.V) ) for a, b in zip(labs, labs[1:])) + return min_cost + def learn_word_oneshot(self, word, one_context_sentence, vecs, idf): """Assign a brand-new word to the class whose members share its single observed context. One exposure, no gradients.""" @@ -313,8 +425,42 @@ def learn_word_oneshot(self, word, one_context_sentence, vecs, idf): # 5. Minimal pairs (BLiMP-style): grammatical vs corrupted. # --------------------------------------------------------------------------- -def minimal_pairs(rng, n=200): +def minimal_pairs(rng, n=200, sents=None, cg=None): + """Generate minimal pairs. If sents provided (real text), mine simple + patterns and corrupt (agreement via verb swap heuristic, order, det-noun). + Falls back to toy generator. cg optional for future class-aware corrupt. + """ pairs = [] + if sents: + # simple mining from real sents for heldout evals + for _ in range(n): + if not sents: + break + s = list(rng.choice(sents)) + if len(s) < 4: + continue + body = s[:-1] if s[-1] == END else s + kind = rng.choice(["agreement", "order", "det-noun"]) + bad = body[:] + if kind == "agreement": + # find a verb-ish token (heuristic: ends with s or common) + for i, w in enumerate(body): + if w.endswith(("s", "ed", "ing")) or w in ("is", "was", "runs", "run"): + swap = w[:-1] if w.endswith("s") else (w + "s" if not w.endswith("s") else w) + bad[i] = swap + break + else: + continue + pairs.append(("agreement", body + [END], bad + [END])) + elif kind == "order" and len(body) >= 3: + bad[0], bad[2] = bad[2], bad[0] # crude scramble + pairs.append(("order", body + [END], bad + [END])) + else: + if len(body) >= 2: + bad[0], bad[1] = bad[1], bad[0] + pairs.append(("det-noun", body + [END], bad + [END])) + return pairs[:n] + # original toy synthetic for _ in range(n): s = sentence(rng) body = s[:-1] # drop "." @@ -350,6 +496,21 @@ def judge_accuracy(scorer, pairs): return by_kind +def grammar_ppl(scorer, sents): + """Bits-per-label (or token) using a bits() scorer (CG or LM wrapper). + Analog to perplexity but for the grammar layer. + """ + total_bits, total = 0.0, 0 + for s in sents: + try: + b = scorer(s) if callable(scorer) else scorer.bits(s) + total_bits += b + total += max(1, len(s) - 1) + except Exception: + continue + return (2 ** (total_bits / total)) if total else 1.0 + + # --------------------------------------------------------------------------- # Report # --------------------------------------------------------------------------- @@ -413,7 +574,7 @@ def main(): # --- 3. abstraction by distribution ----------------------------------- line() print("[3] INDUCED CATEGORIES (words merged by shared context = the 'dream')") - classes, assigned, vecs, idf = induce_classes(train_tokens) + classes, assigned, vecs, idf = induce_classes(train_tokens) # mdl=False (default) for exact old behavior on toy classes_sorted = sorted(enumerate(classes), key=lambda x: -len(x[1])) for ci, members in classes_sorted: if len(members) >= 2: @@ -426,7 +587,8 @@ def main(): cg = ConstructionGrammar(train_tokens, assigned) lm_score = lambda s: sum(b for _, b in lm.sentence_bits(s)) for name, scorer in [("statistical back-off LM ", lm_score), - ("construction grammar ", cg.bits)]: + ("construction grammar ", cg.bits), + ("construction grammar+chart", (lambda s: cg.parse_bits(s)) if cg.use_chart else cg.bits)]: acc = judge_accuracy(scorer, pairs) parts = " ".join(f"{k}: {v[0]/v[1]*100:4.0f}%" for k, v in sorted(acc.items())) print(f" {name} {parts}") diff --git a/solon_tinystories.py b/solon_tinystories.py index 11f2a23..b44292e 100644 --- a/solon_tinystories.py +++ b/solon_tinystories.py @@ -21,7 +21,7 @@ import solon # reuse the core: CompressionLM, repair, etc. from solon import (CompressionLM, repair, expand, induce_classes, - ConstructionGrammar, guess_label, END) + ConstructionGrammar, guess_label, END, grammar_ppl, minimal_pairs) WORD = re.compile(r"[a-z]+'?[a-z]*") SENT_SPLIT = re.compile(r"[.!?]+") @@ -55,7 +55,7 @@ def main(): banner("=") print("SOLON on TinyStories - learning real text by compression") - print(" (no transformer, no backprop)") + print(f" (no transformer, no backprop; n_words~{n_words})") banner("=") t0 = time.time() @@ -86,7 +86,7 @@ def main(): print("[2] CONSTRUCTION LIBRARY (RePair chunks that shrink the corpus)") t0 = time.time() rep_tokens = train_tokens[:300_000] - rules, _ = repair(rep_tokens, max_rules=400, progress=True) + rules, repaired = repair(rep_tokens, max_rules=400, progress=True) # repaired seq is subword (NTs + terms) chunks = sorted(((len(expand(nt, rules)), expand(nt, rules)) for nt in rules), reverse=True) seen, shown = set(), 0 @@ -99,13 +99,16 @@ def main(): if shown >= 12: break print(f" ({time.time()-t0:.1f}s)") + # Subword integration note (RePair "construction library" rung; seq available + # for induce_classes/CG in future -- treat NT tuples as atomic symbols). + print(f" (subword: RePair produced compacted seq of len {len(repaired)}; e.g. first few mixed: {repaired[:6]})") # --- 3. induced categories ------------------------------------------- banner() - print("[3] INDUCED CATEGORIES (top words merged by shared context)") + print("[3] INDUCED CATEGORIES (top words merged by shared context; mdl=True uses DL stopping)") t0 = time.time() classes, assigned, vecs, idf = induce_classes( - train_tokens, min_count=40, thresh=0.20, top_k=300, progress=True) + train_tokens, min_count=40, thresh=0.20, top_k=300, progress=True, mdl=True) # MDL stopping per README (merge iff shortens DL) # show the largest, most coherent classes freq = collections.Counter(train_tokens) classes_sorted = sorted(classes, key=lambda c: -sum(freq[w] for w in c)) @@ -123,6 +126,7 @@ def main(): banner() print("[4] ONE-SHOT WORD LEARNING (productivity)") cg = ConstructionGrammar(train_tokens, assigned) + cg_chart = ConstructionGrammar(train_tokens, assigned, use_chart=True) nonce = [("zorp", "the zorp was happy".split()), ("glip", "she wanted to glip".split()), ("blicket", "he saw a blicket".split())] @@ -147,17 +151,34 @@ def main(): ("pronoun", "the girl said she was sad", "the girl said she were sad"), ] score = lambda s: sum(b for _, b in lm.sentence_bits(s.split() + [END])) - correct = 0 + score_cg = cg.bits + score_chart = cg_chart.bits # will use parse_bits internally + # demo generalized minimal_pairs on real heldout sents (more evals) + auto_pairs = minimal_pairs(random.Random(99), n=8, sents=test_sents[:300]) + correct = correct_cg = correct_chart = 0 for kind, good, bad in pairs: m = score(bad) - score(good) ok = m > 0 correct += ok - print(f" {kind:<10} {good:<24} > {bad:<24} {m:+6.1f} b {'ok' if ok else 'X'}") - print(f" accuracy: {correct}/{len(pairs)}") + m_cg = score_cg(bad.split() + [END]) - score_cg(good.split() + [END]) + ok_cg = m_cg > 0 + correct_cg += ok_cg + m_chart = score_chart(bad.split() + [END]) - score_chart(good.split() + [END]) + ok_chart = m_chart > 0 + correct_chart += ok_chart + print(f" {kind:<10} {good:<24} > {bad:<24} LM{m:+6.1f}b CG{m_cg:+6.1f}b chart{m_chart:+6.1f}b") + print(f" accuracy: LM {correct}/{len(pairs)} CG {correct_cg}/{len(pairs)} chart {correct_chart}/{len(pairs)}") + # extra: grammar "ppl" on heldout using the scorers (lower better) + test_for_ppl = test_sents[:200] + lm_ppl = grammar_ppl(lambda s: sum(b for _, b in lm.sentence_bits(s)), test_for_ppl) + cg_ppl = grammar_ppl(cg.bits, test_for_ppl) + chart_ppl = grammar_ppl(cg_chart.bits, test_for_ppl) + print(f" grammar ppl (on {len(test_for_ppl)} heldout sents): LM {lm_ppl:.1f} CG {cg_ppl:.1f} chart {chart_ppl:.1f}") banner("=") - print("Real text, ~1M words. Categories, phrases and one-shot generalization") + print(f"Real text, ~{n_words:,} words (full file supported). Categories, phrases and one-shot generalization") print("emerged from counting and refactoring alone -- no gradients.") + print("(extensions: mdl stopping in clustering, chart/CKY parser bits, RePair subword exposure, more evals)") banner("=")