/* global React */
// Shared utilities + canned distributions for The Loaded Die.

const { useState, useEffect, useRef, useMemo, useCallback } = React;

// ---- canned next-token distributions for the loaded-die demo ----
// Each row is written compactly as [id, prompt, [[token, p], ...]] and is
// expanded below into { id, prompt, dist: [{ token, p }, ...] }.
// The listed probabilities sum to a little under 1 on purpose: the
// remainder is the implicit "tail" shared by ~100k other tokens.
const PRESETS = [
  ["paris", "The capital of France is", [
    ["Paris", 0.940],
    ["London", 0.025],
    ["beautiful", 0.008],
    ["banana", 0.0001],
  ]],
  ["violets", "Roses are red, violets are", [
    ["blue", 0.860],
    ["purple", 0.060],
    ["beautiful", 0.025],
    ["banana", 0.0001],
  ]],
  ["math", "Two plus two equals", [
    ["four", 0.920],
    ["4", 0.040],
    ["twenty-two", 0.008],
    ["banana", 0.0001],
  ]],
  ["upon", "Once upon a", [
    ["time", 0.910],
    ["midnight", 0.030],
    ["summer", 0.015],
    ["banana", 0.0001],
  ]],
].map(([id, prompt, rows]) => ({
  id,
  prompt,
  dist: rows.map(([token, p]) => ({ token, p })),
}));

// Probability mass left over for the unnamed "tail": 1 minus the sum of the
// listed `p` values. Floored at 0 so a list that sums past 1 never yields a
// negative tail.
function tailProb(dist) {
  let named = 0;
  dist.forEach(({ p }) => {
    named += p;
  });
  const leftover = 1 - named;
  return Math.max(0, leftover);
}

// Draw one outcome from `dist` by walking its cumulative probabilities
// against a single Math.random() roll. Returns the index of the named token
// that was hit, or -1 when the roll lands past every named token (i.e. in
// the long tail of other tokens).
function sample(dist) {
  const roll = Math.random();
  let cumulative = 0;
  for (const [index, { p }] of dist.entries()) {
    cumulative += p;
    if (roll < cumulative) {
      return index;
    }
  }
  return -1; // tail
}

// Temperature reweighting of a named-token distribution: q_i ∝ p_i^(1/T).
// T → 0 collapses to argmax; T = 1 leaves the distribution alone;
// T → ∞ flattens toward uniform.
//
// Returns a new array in the same order as `dist`:
//  - T <= 0.02: a copy of every entry with p set to 1 for the largest-p
//    entry and 0 for all others (other fields on the entries are kept).
//  - otherwise: { token, p } objects only. The unnamed tail (see tailProb)
//    is treated as one extra pseudo-token whose weight joins the
//    normaliser, so the returned p values sum to less than 1 by exactly the
//    warped tail mass.
function applyTemperature(dist, T) {
  if (T <= 0.02) {
    // deterministic: pure argmax (identity comparison picks the one winner)
    const [winner] = [...dist].sort((x, y) => y.p - x.p);
    return dist.map(entry => ({ ...entry, p: entry === winner ? 1 : 0 }));
  }

  const power = 1 / T;
  const weights = dist.map(entry => Math.pow(entry.p, power));
  // Floor the tail at 1e-9 before warping so it never vanishes entirely.
  const tailWeight = Math.pow(Math.max(tailProb(dist), 1e-9), power);
  const total = weights.reduce((sum, w) => sum + w, 0) + tailWeight;

  return dist.map((entry, i) => ({ token: entry.token, p: weights[i] / total }));
}

// Share of probability that ends up in the unnamed tail once temperature T
// has been applied — i.e. how often the model goes completely off-script.
// Uses the same normaliser as applyTemperature: each named weight is
// p^(1/T), and the tail is one pseudo-token with weight
// max(tailProb, 1e-9)^(1/T). At T <= 0.02 (argmax) the tail gets nothing.
function tailAfterTemperature(dist, T) {
  if (T <= 0.02) {
    return 0;
  }
  const power = 1 / T;
  const tailWeight = Math.pow(Math.max(tailProb(dist), 1e-9), power);
  const namedWeight = dist.reduce((sum, { p }) => sum + Math.pow(p, power), 0);
  return tailWeight / (namedWeight + tailWeight);
}

// Draw one token from `dist` at temperature `T`.
//
// The distribution is warped with applyTemperature on every call, then a
// single Math.random() roll is walked along the cumulative warped
// probabilities. Because applyTemperature normalises against the tail as
// well, the named probabilities sum to less than 1; any roll that falls past
// them is, by construction, the tail's share — no separate tail probability
// needs to be computed here.
//
// Returns either
//   { kind: "known", index, token } — `index` is the position in `dist`, or
//   { kind: "tail", token }         — `token` comes from pickTailToken().
function sampleWithTemp(dist, T) {
  const warped = applyTemperature(dist, T);
  const r = Math.random();
  let acc = 0;
  for (let i = 0; i < warped.length; i++) {
    acc += warped[i].p;
    if (r < acc) return { kind: "known", index: i, token: warped[i].token };
  }
  // landed in the tail — invent a plausibly-weird off-script token
  return { kind: "tail", token: pickTailToken() };
}

// Stand-in vocabulary for samples that land outside the named tokens.
const TAIL_TOKENS = [
  "banana",
  "spoon",
  "midnight",
  "verdant",
  "umbrella",
  "lavender",
  "Wednesday",
  "concrete",
  "fjord",
  "telephone",
  "lighthouse",
  "marmalade",
  "sphinx",
  "porcelain",
  "ricochet",
  "saffron",
  "fossil",
  "noon",
];

// Return one entry of TAIL_TOKENS, chosen uniformly via Math.random().
function pickTailToken() {
  const slot = Math.floor(Math.random() * TAIL_TOKENS.length);
  return TAIL_TOKENS[slot];
}

// Pretty-print a probability (a fraction, not a percentage) for the legend.
//   p >= 0.01           one decimal place,  e.g. 0.94   -> "94.0%"
//   0.0001 <= p < 0.01  two decimal places, e.g. 0.0001 -> "0.01%"
//   0 < p < 0.0001      "<0.01%" — too small to show at two decimals
//   anything else       "0%" (exact zero, e.g. non-winners at T ≈ 0)
// The small-value label is expressed in percent to match the other
// branches: a fraction of 0.0001 is 0.01%, so "<0.01%" is the honest bound.
function fmtPct(p) {
  if (p >= 0.01) return `${(p * 100).toFixed(1)}%`;
  if (p >= 0.0001) return `${(p * 100).toFixed(2)}%`;
  if (p > 0) return "<0.01%";
  return "0%";
}

// Publish this file's presets and helpers as properties on `window` so the
// other Babel scripts can use them as globals.
// Every name listed here is defined above in this file.
Object.assign(window, {
  PRESETS, tailProb, sample, applyTemperature, tailAfterTemperature,
  sampleWithTemp, pickTailToken, fmtPct, TAIL_TOKENS,
});
