Price grid: [50, 500, 1000, 5000, 10000]. Prior probabilities over prices: [0.8070, 0.1070, 0.0434, 0.0223, 0.0203]. Valence (a boolean representing whether the speaker feels positively about the item) has a price-dependent prior: at price 50 the probability of positive valence is 0.3173; at 500 it is 0.7920; at 1000 it is 0.8933; at 5000 it is 0.9524; at 10000 it is 0.9864. Utterances are "expensive" and "notExpensive". Cost of "notExpensive" is 1; cost of "expensive" is 0. The threshold theta is drawn uniformly from the price grid. The QUD (question under discussion) is drawn with equal probability from three options: report only the price, report only the valence, or report both price and valence. Speaker optimality parameter alpha=1. Meaning: "expensive" is true when price >= theta; "notExpensive" is true when price <= theta.
A Rational Speech Act model combining vague-adjective interpretation with QUD uncertainty. A literal listener hears an utterance, samples a price uniformly from the price grid (not the categorical price prior) and a valence from the valence prior, conditions on the meaning of the utterance given the threshold theta, and returns the QUD-relevant part of the full state {price, valence}. The pragmatic listener uses the categorical price prior (the weights given above) when sampling the full state. A pragmatic speaker chooses utterances with probability proportional to exp(alpha * (literal-listener score on the QUD answer minus utterance cost)). A pragmatic listener reasons about the full state {price, valence} by marginalizing over the QUD and threshold, observing the speaker's choice of utterance.
The posterior joint distribution over {price, valence} for a pragmatic listener who hears the utterance "expensive".
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"price": "int",
"valence": "bool"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// ADJECTIVES + QUD MODEL2// frankie + shane RSA project34// code adapted from the Kao et al. hyperbole model +5// gradable adjectives & vagueness resolution model678var utterances = ["expensive", "notExpensive"]910var utterancePrior = function() {11 return uniformDraw(utterances)12}1314var thetaPrior = function() {15 return uniformDraw(prices)16}1718// theta moderates interpretation of utterances19var meaning = function(utterance, price, theta) {20return utterance == "expensive" ? price >= theta :21 utterance == "notExpensive" ? price <= theta :22 true23}2425// more words = higher cost26var cost = function(utterance) {27 return utterance== 'notExpensive'? 1 :28 029};3031var prices = [32 50,33 500,34 1000,35 5000,36 1000037]3839var pricePrior = function() {40 return categorical({41 vs: prices,42 ps: [43 0.8070,44 0.1070,45 0.0434,46 0.0223,47 0.020348 ]49 })50}5152var valencePrior = function(state) {53 var probs = {54 50 : 0.3173,55 500 : 0.7920,56 1000 : 0.8933,57 5000 : 0.9524,58 10000 : 0.986459 }60 var tf = flip(probs[state])61 return tf62}6364var qudFns = {65 price : function(state) {return { price: state.price } },66 valence : function(state) {return { valence: state.valence } },67 priceValence : function(state) {68 return { price: state.price, valence: state.valence }69 }70 }7172var qudPrior = function() {73 categorical({74 vs: ["price", "valence", "priceValence"],75 ps: [1, 1, 1]76 })77}7879var literalListener = cache(function(utterance, qud, theta) {80 return Infer({model: function(){81 var price = uniformDraw(prices)82 var valence = valencePrior(price)83 var fullState = {price, valence}84 var qudFn = qudFns[qud]85 var qudAnswer = qudFn(fullState)86 condition( meaning(utterance, price, theta) )87 return qudAnswer88 }89 })})9091// speaker optimality92var alpha = 19394var speaker = cache(function(fullState, qud, theta) {95 return Infer({model: function(){96 var utterance = utterancePrior()97 var qudFn = qudFns[qud]98 var qudAnswer = qudFn(fullState)99 factor(alpha*(literalListener(utterance,qud,theta).score(qudAnswer)100 - cost(utterance)))101 return utterance102 }})103})104105var pragmaticListener = cache(function(utterance) {106 return Infer({model: function(){107 //////// priors ////////108 var price = pricePrior()109 var valence = valencePrior(price)110 var qud = qudPrior()111 var theta = thetaPrior()112 ////////////////////////113 var fullState = {price, valence}114 observe(speaker(fullState, qud, theta), utterance)115 return {price, valence}116117 }})118})119120var listenerPosterior1 = pragmaticListener("expensive")121var listenerPosterior2 = pragmaticListener('notExpensive')122123print('Pragmatic listener hears "expensive":')124viz(listenerPosterior1)125print('Pragmatic listener hears "not expensive":')126viz(listenerPosterior2)127128var ANSWER = (pragmaticListener("expensive"));
1# RSA adjectives + QUD model (frankie + shane). Pragmatic listener hears "expensive";2# query is the joint posterior over {price, valence}.3# Every RSA level is a genuine Pyro enumerated model (config_enumerate +4# TraceEnum_ELBO.compute_marginals); each level's normalized posterior comes from Pyro's5# engine, and lower-level log-scores feed the next level only through pyro.factor.6#7# State = (price, valence) with valence|price ~ Bernoulli(valence_p[price]).8# meaning depends only on price, so under L0 valence is conditionally independent of the9# evidence given price: the L0 joint factors as L0(price)*prior(valence|price), where the10# L0 price posterior is produced by Pyro inference (compute_marginals) and valence|price is11# its (un-inferred) prior. The speaker scores L0's score of the QUD projection of its12# fullState; the pragmatic listener enumerates (price, valence, qud, theta) and factors in13# the speaker's log-score of "expensive".1415utterances = ["expensive", "notExpensive"]16prices = [50, 500, 1000, 5000, 10000]17price_t = torch.tensor([float(p) for p in prices])18price_prior = torch.tensor([0.8070, 0.1070, 0.0434, 0.0223, 0.0203])19valence_p = torch.tensor([0.3173, 0.7920, 0.8933, 0.9524, 0.9864]) # P(valence=True | price)20quds = ["price", "valence", "priceValence"]21alpha = 1.022cost = {"expensive": 0.0, "notExpensive": 1.0}23NEG = torch.tensor(float("-inf"))24ZERO = torch.tensor(0.0)2526def meaning_mask(utterance, theta):27 if utterance == "expensive":28 return price_t >= theta29 return price_t <= theta # notExpensive3031# Literal listener L0: a Pyro enumerated model over price (uniformDraw over prices) with the32# meaning factor. compute_marginals returns the normalized price posterior (log-probs).33def literal_price_logprobs(utterance, theta):34 mask = meaning_mask(utterance, theta)35 @pyro.infer.config_enumerate36 def m():37 pi = pyro.sample("price", dist.Categorical(torch.ones(len(prices)) / len(prices)))38 pyro.factor("meaning", torch.where(mask[pi], ZERO, NEG))39 return None40 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)41 pm = marg["price"]42 sup = pm.enumerate_support()43 lp = pm.log_prob(sup)44 out = torch.full((len(prices),), float("-inf"))45 for j in range(len(sup)):46 out[int(sup[j].item())] = lp[j]47 return out4849lit_cache = {}50def literal_price_logprobs_cached(utterance, theta):51 key = (utterance, round(float(theta), 6))52 if key not in lit_cache:53 lit_cache[key] = literal_price_logprobs(utterance, theta)54 return lit_cache[key]5556# L0 log-score of a QUD answer for the speaker's known fullState (price index pidx, valence v).57# qud == price : score = log L0price[pidx]58# qud == valence : score = log sum_p L0price[p] * P(v | p)59# qud == priceValence: score = log( L0price[pidx] * P(v | pidx) )60def l0_qud_logscore(utterance, theta, qud, pidx, v):61 lpr = literal_price_logprobs_cached(utterance, theta) # log P(price) under L062 pr = lpr.exp()63 vp = valence_p if v else (1.0 - valence_p)64 if qud == "price":65 s = float(pr[pidx].item())66 elif qud == "valence":67 s = float((pr * vp).sum().item())68 else: # priceValence69 s = float((pr[pidx] * vp[pidx]).item())70 return math.log(s) if s > 0.0 else float("-inf")7172# Speaker S1(utterance | fullState, qud, theta): Pyro enumerated model over utterances with73# factor alpha*(logL0(qudAnswer|utt) - cost(utt)). compute_marginals normalizes.74def speaker_logprobs(theta, qud, pidx, v):75 score = torch.tensor([76 alpha * (l0_qud_logscore(u, theta, qud, pidx, v) - cost[u]) for u in utterances77 ])78 @pyro.infer.config_enumerate79 def m():80 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))81 pyro.factor("score", score[ui])82 return None83 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)84 um = marg["utt"]85 sup = um.enumerate_support()86 lp = um.log_prob(sup)87 out = torch.full((len(utterances),), float("-inf"))88 for j in range(len(sup)):89 out[int(sup[j].item())] = lp[j]90 return out9192target_idx = utterances.index("expensive")93thetas = list(prices) # thetaPrior = uniformDraw(prices)9495# Speaker log-score of "expensive" indexed by (price index, valence in {0,1}, qud, theta).96# Used as the pragmatic listener's factor.97spk = {}98for pidx in range(len(prices)):99 for v in (True, False):100 for qi, qud in enumerate(quds):101 for ti, theta in enumerate(thetas):102 spk[(pidx, int(v), qi, ti)] = speaker_logprobs(theta, qud, pidx, v)[target_idx]103104spk_table = torch.zeros(len(prices), 2, len(quds), len(thetas))105for (pidx, vi, qi, ti), val in spk.items():106 spk_table[pidx, vi, qi, ti] = val107108# Pragmatic listener L1: Pyro enumerated model over price, valence, qud, theta; factor in the109# speaker's log-score of "expensive". infer_discrete draws joint (price, valence) posterior110# samples (the query is a joint over two latents); aggregate the tuples into the distribution.111@pyro.infer.config_enumerate112def pragmatic_model():113 pi = pyro.sample("price", dist.Categorical(price_prior))114 vi = pyro.sample("valence", dist.Bernoulli(valence_p[pi]))115 qi = pyro.sample("qud", dist.Categorical(torch.ones(len(quds)) / len(quds)))116 ti = pyro.sample("theta", dist.Categorical(torch.ones(len(thetas)) / len(thetas)))117 pyro.factor("speaker", spk_table[pi, vi.long(), qi, ti])118 return pi, vi119120serving = pyro.infer.infer_discrete(121 pyro.infer.config_enumerate(pragmatic_model), first_available_dim=-1122)123124counts = Counter()125N = 8000126for _ in range(N):127 pi, vi = serving()128 p = prices[int(pi.item())]129 v = bool(vi.item() > 0.5)130 counts[(p, v)] += 1131132_json = __import__("json")133ANSWER = {134 _json.dumps({"price": p, "valence": v}, sort_keys=True): c / N135 for (p, v), c in counts.items()136}137
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0080 ≤ tol 0.0270 · floors 0.0135/0.0000 |
Intelligence states: {1, 2, 3, 4} (1=least, 4=most intelligent). Utterances: {"", "dumb as rocks", "dumb", "f*cking idiot"}. Literal semantics: each utterance has a probability of describing each state. For state indices 1–4 (in order): "" → [0.25, 0.25, 0.25, 0.25]; "dumb as rocks" → [0.45, 0.85, 0.20, 0.02]; "dumb" → [0.85, 0.95, 0.02, 0.02]; "f*cking idiot" → [0.95, 0.55, 0.02, 0.02]. Valence ("good" or "bad") prior per state: state 1 → P(good)=0.01; state 2 → P(good)=0.33; state 3 → P(good)=0.66; state 4 → P(good)=0.99. Arousal ("low" or "high") prior per state: state 1 → P(low)=0.1; state 2 → P(low)=0.7; state 3 → P(low)=0.7; state 4 → P(low)=0.1. Goals: {"goalState", "goalValence", "goalArousal"}, equal prior weights. phi is drawn uniformly from {0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90} (18 equally spaced values). Speaker optimality alpha=10. Antisocial scaling lambda=-1.25. The state prior is uniform over {1, 2, 3, 4}.
A multi-goal RSA model for teasing language. A literal listener hears an utterance and a goal; the utterance is interpreted stochastically (with the probability that the utterance describes the state given its literal semantics weight); the listener returns both the goal-relevant QUD answer and the state. A pragmatic speaker has a mixed utility: the epistemic component is the literal listener's log-probability of recovering the correct QUD answer; the antisocial component is lambda times the expected state under the literal listener's posterior state distribution given the utterance (the value function is the state value itself — no rescaling or normalization). The speaker's total utility is phi times the epistemic component plus (1-phi) times the antisocial component, exponentiated by alpha. A pragmatic listener samples a state, valence, arousal, goal, and phi, and conditions on the speaker's utterance choice.
The posterior marginal distribution over the intelligence state (integer values 1–4) for a pragmatic listener who hears "dumb as rocks".
answer spec
{
"kind": "dist",
"domain": "int"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Level of intelligence2var states = [1,2,3,4]34// What can be said5var utterances = ["", "dumb as rocks", "dumb", "f*cking idiot"]67// Correspondence of utterances to states8var literalSemantics = {9 "" : [.25,.25,.25,.25],10 "dumb as rocks" : [0.45,0.85,0.20,.02],11 "dumb" : [.85,.95,.02,.02],12 "f*cking idiot" : [.95,.55,.02,.02]13}1415// Determine whether the utterance describes the state16// Flip a coin with the literalSemantics weight17// *state - 1 because of 0-indexing*18var meaning = function(utterance, state){19 return flip(literalSemantics[utterance][state - 1]);20}2122// Whether the speaker feels good or bad about the listener's action.23// Speakers' attitudes toward the true state of the world (e.g., the valence of their affect), which is modeled simply as a binary positive/negative variable (representing whether or not the speaker is repulsed by the listener's behavior).24var valencePrior = function(state) {25 state === 1 ? flip(0.01) ? "good" : "bad" :26 state === 2 ? flip(0.33) ? "good" : "bad" :27 state === 3 ? flip(0.66) ? "good" : "bad" :28 state === 4 ? flip(0.99) ? "good" : "bad" :29 true}3031// How amplified the speaker feels about the listener being dumb to different degrees.32var arousals = ["low", "high"]3334//How passionate/aroused the listener feels about how intelligent the person is35var arousalPrior = function(state) {36 state === 1 ? categorical([0.1, 0.9], arousals) :37 state === 2 ? categorical([0.7, 0.3], arousals) :38 state === 3 ? categorical([0.7, 0.3], arousals) :39 state === 4 ? categorical([0.1, 0.9], arousals) :40 true41}4243// A list of strings of QUD choices44var goals = ["goalState", "goalValence", "goalArousal"]4546// There are 3 possible goals with a flat prior47var goalPrior = function() {48 categorical([1, 1, 1], goals)49}5051// A speaker's goal is satisfied if the listener infers the correct52// and relevant information.53var goalState = function(goal, state, valence, arousal) {54 goal === "goalState" ? state :55 goal === "goalValence" ? valence :56 goal === "goalArousal" ? arousal :57 true58}5960// literal listener61var literalListener = function(utterance, goal) {62 Infer({model: function(){63 var state = uniformDraw(states)64 var valence = valencePrior(state)65 var arousal = arousalPrior(state)66 var m = meaning(utterance, state);67 condition(m);68 return {QUDanswer: goalState(goal,state,valence,arousal), state: state}69 }})70}7172// value function scales social utility by a parameter lambda73var lambda = -1.2574var valueFunction = function(s){75 return lambda * s76};7778var alpha = 1079var speaker1 = function(state, phi, goal, valence, arousal) {80 Infer({model: function(){81 var utterance = uniformDraw(utterances)82 var QUDanswer = goalState(goal,state,valence,arousal)83 var L0_posterior = literalListener(utterance,goal)84 var L0_stateMarginal = marginalize(L0_posterior,"state")85 var L0_qudAnswerMarginal = marginalize(L0_posterior,"QUDanswer")86 var utility = {87 epistemic: L0_qudAnswerMarginal.score(QUDanswer),88 antisocial: expectation(L0_stateMarginal, valueFunction)89 }90 var speakerUtility = phi * utility.epistemic +91 (1 - phi) * utility.antisocial92 factor(alpha * speakerUtility)93 return utterance94 }})95}9697//pragmatic listener98var pragmaticListener = function(utterance) {99 Infer({model: function(){100 var state = uniformDraw(states)101 var valence = valencePrior(state)102 var arousal = arousalPrior(state)103 var goal = goalPrior()104//0.05 corresponds to the starting point of phi while 0.95 is the endpoint.105//The last 0.05 corresponds to the interval length106 var phi = uniformDraw(_.range(0.05, 0.95, 0.05))107 var S1 = speaker1(state, phi, goal, valence, arousal)108 observe(S1, utterance)109 return { state, phi, goal, valence, arousal }110 }})111}112viz.marginals(pragmaticListener(""))113viz.marginals(pragmaticListener("dumb as rocks"))114viz.marginals(pragmaticListener("dumb"))115viz.marginals(pragmaticListener("f*cking idiot"))116117var ANSWER = marginalize(pragmaticListener("dumb as rocks"), "state");
12states = [1, 2, 3, 4]3utterances = ["", "dumb as rocks", "dumb", "f*cking idiot"]4literalSemantics = {5 "": [0.25, 0.25, 0.25, 0.25],6 "dumb as rocks": [0.45, 0.85, 0.20, 0.02],7 "dumb": [0.85, 0.95, 0.02, 0.02],8 "f*cking idiot": [0.95, 0.55, 0.02, 0.02],9}10valence_good = {1: 0.01, 2: 0.33, 3: 0.66, 4: 0.99}11arousal_low = {1: 0.1, 2: 0.7, 3: 0.7, 4: 0.1}12goals = ["goalState", "goalValence", "goalArousal"]13phis = [round(0.05 + 0.05 * i, 10) for i in range(18)] # 0.05..0.90 step 0.0514alpha = 10.015lam = -1.2516valence_labels = ["good", "bad"]17arousal_labels = ["low", "high"]1819def _elbo():20 return pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)2122_LL_CACHE = {}23def literal_listener_marginals(utterance, goal):24 key = (utterance, goal)25 if key in _LL_CACHE:26 return _LL_CACHE[key]27 sem = torch.tensor(literalSemantics[utterance])2829 @pyro.infer.config_enumerate30 def model():31 si = pyro.sample("state", dist.Categorical(torch.ones(4) / 4))32 pg = torch.tensor([valence_good[s] for s in states])[si]33 pyro.sample("valence", dist.Categorical(torch.stack([pg, 1 - pg], -1)))34 pl = torch.tensor([arousal_low[s] for s in states])[si]35 pyro.sample("arousal", dist.Categorical(torch.stack([pl, 1 - pl], -1)))36 pyro.factor("meaning", torch.log(sem[si])) # condition on literal meaning37 return si3839 marg = _elbo().compute_marginals(model, lambda: None)40 _LL_CACHE[key] = marg41 return marg4243def qud_answer(goal, state, valence, arousal):44 if goal == "goalState":45 return ("state", state)46 if goal == "goalValence":47 return ("valence", valence)48 return ("arousal", arousal)4950def ll_logp_qud(utterance, goal, ans):51 marg = literal_listener_marginals(utterance, goal)52 kind, val = ans53 if kind == "state":54 d = marg["state"]; idx = states.index(val)55 elif kind == "valence":56 d = marg["valence"]; idx = valence_labels.index(val)57 else:58 d = marg["arousal"]; idx = arousal_labels.index(val)59 return float(d.log_prob(torch.tensor(idx)))6061def ll_state_expectation(utterance, goal):62 d = literal_listener_marginals(utterance, goal)["state"]63 sup = d.enumerate_support()64 p = torch.exp(d.log_prob(sup))65 return float(sum(lam * states[int(i)] * float(pp) for i, pp in zip(sup, p)))6667_S1_CACHE = {}68def speaker1(state, phi, goal, valence, arousal):69 key = (state, phi, goal, valence, arousal)70 if key in _S1_CACHE:71 return _S1_CACHE[key]72 ans = qud_answer(goal, state, valence, arousal)73 utils = torch.tensor([74 phi * ll_logp_qud(u, goal, ans) + (1 - phi) * ll_state_expectation(u, goal)75 for u in utterances76 ])7778 @pyro.infer.config_enumerate79 def model():80 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))81 pyro.factor("util", alpha * utils[ui])82 return ui8384 marg = _elbo().compute_marginals(model, lambda: None)["utt"]85 _S1_CACHE[key] = marg86 return marg8788target = "dumb as rocks"89target_idx = utterances.index(target)9091# enumerated joint config space for the pragmatic listener92cfgs = []93for st in states:94 for val in valence_labels:95 for ar in arousal_labels:96 for goal in goals:97 for phi in phis:98 cfgs.append((st, val, ar, goal, phi))99100# prior over each config (valence/arousal priors depend on state)101prior = torch.tensor([102 (1.0 / len(states))103 * (valence_good[c[0]] if c[1] == "good" else 1 - valence_good[c[0]])104 * (arousal_low[c[0]] if c[2] == "low" else 1 - arousal_low[c[0]])105 * (1.0 / len(goals))106 * (1.0 / len(phis))107 for c in cfgs108])109110# speaker likelihood of the target utterance for each config (from S1 inference)111lik = torch.tensor([112 float(torch.exp(speaker1(c[0], c[4], c[3], c[1], c[2]).log_prob(torch.tensor(target_idx))))113 for c in cfgs114])115116@pyro.infer.config_enumerate117def pragmatic_model():118 ci = pyro.sample("cfg", dist.Categorical(prior / prior.sum()))119 pyro.factor("speaker", torch.log(lik[ci] + 1e-300)) # observe(S1, utterance)120 return ci121122cfg_marg = _elbo().compute_marginals(pragmatic_model, lambda: None)["cfg"]123sup = cfg_marg.enumerate_support()124p = torch.exp(cfg_marg.log_prob(sup))125state_post = {s: 0.0 for s in states}126for i, pp in zip(sup, p):127 state_post[cfgs[int(i)][0]] += float(pp)128129ANSWER = state_post130
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (w1) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Ice cream price grid: [1, 4, 6, 10, 14, 18, 22, 30, 34, 38] dollars. Prior probabilities over prices (proportional, summing to 1 after normalization): [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]. Utterances: {"expensive", "null", "cheap"}. Utterance costs: expensive=1, cheap=2, null=0. The utterance prior is proportional to exp(-cost). The threshold theta is drawn uniformly from the price grid. The QUD is fixed to "what is the price?" (the QUD always returns the raw price value). Speaker optimality parameter alpha=1. Meaning: "expensive" is true when price >= theta; "cheap" is true when price <= theta; "null" is always true.
A Rational Speech Act model for vague adjectives. A literal listener hears an utterance and a threshold theta, conditions on the utterance meaning holding for the sampled price, and returns the QUD-relevant value (here: the raw price). A pragmatic speaker weighs utterances by exp(alpha * literal-listener-score) discounted by utterance cost. A pragmatic listener samples a price from the price prior and a threshold from the uniform grid prior, observes the speaker's utterance choice, and returns the joint {price, theta, qud}.
The posterior expected price (in dollars) for a pragmatic listener who hears "expensive".
answer spec
{
"kind": "value",
"domain": "real"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var marginalize = function(dist, key){2 return Infer( {model: function(){3 return sample(dist)[key];4 }})5};6///78var icecream = {9 "prices": [1, 4, 6, 10, 14, 18, 22, 30, 34, 38],10"probabilities": [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]11};1213var statePrior = function() {14 return categorical(icecream.probabilities, icecream.prices);15};1617var thetaPrior = function() {18 return uniformDraw(icecream.prices);19};2021var alpha = 1; // optimality parameter2223var utterances = ["expensive", "null", "cheap"];24var cost = {25 "expensive": 1,26 "cheap": 2,27 "null": 028};29var utterancePrior = function() {30 var uttProbs = map(function(u) {return Math.exp(-cost[u]) }, utterances);31 return categorical(uttProbs, utterances);32};3334var meaning = function(utterance, price, theta) {35 utterance == "expensive" ? price >= theta :36 utterance == "cheap" ? price <= theta :37 true38};39// QUDs40var QUDs = ["expensive?","less than 15?","what is the price?"]41var QUDPrior = function() {42 //uniformDraw(QUDs)43// categorical([1,10,1],QUDs)// this is equivalent to uniformDraw44categorical([0,0,1],QUDs)// this is your baseline version45}46var QUDFun = function(QUD,state) {47 QUD == "expensive?" ? state >= 15 :48 QUD == "less than 15?" ? state <= 15 :49 state;50};5152var literalListener = cache(function(utterance, theta, QUD) {53 return Infer({model:function() {54 var price = statePrior()55 var qPrice = QUDFun(QUD,price)56 condition(meaning(utterance, price, theta))57 return qPrice;58}})59});6061var speaker = cache(function(price, theta,QUD) {62 return Infer( {model: function() {63 var utterance = utterancePrior()64 var qPrice= QUDFun(QUD, price)65 factor( alpha * literalListener(utterance, theta, QUD).score(qPrice) );66 return utterance;67 }});68});6970var pragmaticListener = function(utterance) {71 return Infer({model: function() {72 var price = statePrior()73 var theta = thetaPrior()74 var QUD =QUDPrior()75 factor(speaker(price, theta, QUD).score(utterance));76 return { price: price, theta: theta, qud: QUD};77 }})78}79808182print ('cheap icecream')83var expensiveIcecream = pragmaticListener('expensive');84print(expectation(marginalize(expensiveIcecream, "price")))85viz.hist(marginalize(expensiveIcecream, "price"));86viz.hist(marginalize(expensiveIcecream, "theta"));87viz.hist(marginalize(expensiveIcecream, "qud"));8889var ANSWER = (expectation(marginalize(pragmaticListener('expensive'), 'price')));90
1# RSA icecream adjectives + QUD model. Pragmatic listener hears "expensive".2# QUDPrior puts all mass on "what is the price?" (categorical([0,0,1], QUDs)) so QUDFun is the3# identity on the price -> the qud answer is the price itself.4# Query: posterior expected price for the pragmatic listener.5# Every RSA level is a genuine Pyro model run under exact enumeration6# (config_enumerate + TraceEnum_ELBO.compute_marginals). Each level's normalized posterior7# comes from Pyro's engine; lower-level log-scores enter the next level only as pyro.factor8# terms. No level's distribution is normalized by hand.910prices = [1, 4, 6, 10, 14, 18, 22, 30, 34, 38]11price_probs = [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]12state_prior = torch.tensor(price_probs) # dist.Categorical normalizes its probs internally13price_t = torch.tensor([float(p) for p in prices])14alpha = 1.015utterances = ["expensive", "null", "cheap"]16cost = {"expensive": 1.0, "cheap": 2.0, "null": 0.0}17utt_prior = torch.tensor([math.exp(-cost[u]) for u in utterances]) # categorical normalizes18NEG = torch.tensor(float("-inf"))19ZERO = torch.tensor(0.0)2021# meaning(utterance, price, theta) as a boolean tensor over the enumerated price index.22def meaning_mask(utterance, theta):23 if utterance == "expensive":24 return price_t >= theta25 if utterance == "cheap":26 return price_t <= theta27 return torch.ones(len(prices), dtype=torch.bool)2829# Literal listener L0(price | utterance, theta): a Pyro enumerated model.30# Returns the full log P(price) vector indexed by price index (from compute_marginals).31def literal_logprobs(utterance, theta):32 mask = meaning_mask(utterance, theta)33 @pyro.infer.config_enumerate34 def m():35 si = pyro.sample("price", dist.Categorical(state_prior))36 pyro.factor("meaning", torch.where(mask[si], ZERO, NEG))37 return None38 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)39 pm = marg["price"]40 sup = pm.enumerate_support()41 lp = pm.log_prob(sup)42 out = torch.full((len(prices),), float("-inf"))43 for j in range(len(sup)):44 out[int(sup[j].item())] = lp[j]45 return out4647lit_cache = {}48def literal_logprobs_cached(utterance, theta):49 key = (utterance, round(float(theta), 6))50 if key not in lit_cache:51 lit_cache[key] = literal_logprobs(utterance, theta)52 return lit_cache[key]5354# Speaker S1(utterance | price, theta): a Pyro enumerated model over utterances.55# Factor = alpha * log L0(price | utterance, theta). compute_marginals normalizes.56# Returns log P(utterance) vector indexed by utterance index.57def speaker_logprobs(price_index, theta):58 # log L0 of THIS price under each utterance, as a tensor over the utterance index.59 score = torch.tensor(60 [float(literal_logprobs_cached(u, theta)[price_index].item()) for u in utterances]61 )62 @pyro.infer.config_enumerate63 def m():64 ui = pyro.sample("utt", dist.Categorical(utt_prior))65 pyro.factor("score", alpha * score[ui])66 return None67 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)68 um = marg["utt"]69 sup = um.enumerate_support()70 lp = um.log_prob(sup)71 out = torch.full((len(utterances),), float("-inf"))72 for j in range(len(sup)):73 out[int(sup[j].item())] = lp[j]74 return out7576spk_cache = {}77def speaker_logprobs_cached(price_index, theta):78 key = (price_index, round(float(theta), 6))79 if key not in spk_cache:80 spk_cache[key] = speaker_logprobs(price_index, theta)81 return spk_cache[key]8283# Pragmatic listener L1(price | "expensive"): a Pyro enumerated model over price and theta.84# Factor = log S1("expensive" | price, theta). compute_marginals gives the price posterior.85target_idx = utterances.index("expensive")86thetas = list(prices) # thetaPrior = uniformDraw(prices)8788# Build the speaker log-score table S1("expensive" | price_index, theta) for the factor.89spk_table = torch.zeros(len(prices), len(thetas))90for pi in range(len(prices)):91 for ti, theta in enumerate(thetas):92 spk_table[pi, ti] = speaker_logprobs_cached(pi, theta)[target_idx]9394@pyro.infer.config_enumerate95def pragmatic_model():96 pi = pyro.sample("price", dist.Categorical(state_prior))97 ti = pyro.sample("theta", dist.Categorical(torch.ones(len(thetas)) / len(thetas)))98 pyro.factor("speaker", spk_table[pi, ti])99 return None100101marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(pragmatic_model, lambda: None)102pm = marg["price"]103sup = pm.enumerate_support()104probs = pm.log_prob(sup).exp()105ANSWER = float(sum(prices[int(sup[j].item())] * float(probs[j].item()) for j in range(len(sup))))106
12.9027
12.9027
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (absdiff) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Two possible categories for John: "whale" and "person". Prior probability of whale is 0.01; prior probability of person is 0.99. Utterances: {"whale", "person"}, with equal prior probability. Features of John: "large", "graceful", "majestic", each binary (0 or 1). Feature sets are the 8 combinations of these three binary features in the order: {1,1,1}, {1,1,0}, {1,0,1}, {1,0,0}, {0,1,1}, {0,1,0}, {0,0,1}, {0,0,0} (where fields are large, graceful, majestic). Feature-set prior probabilities given category whale (in the same order): [0.30592786494628, 0.138078454222818, 0.179114768847673, 0.13098781834847, 0.0947267162507846, 0.0531420411185539, 0.0601520520596695, 0.0378702842057509]. Feature-set prior probabilities given category person (in the same order): [0.11687632453038, 0.105787535267869, 0.11568145784997, 0.130847056136141, 0.15288225956497, 0.128098151176801, 0.114694702836614, 0.135132512637255]. Speaker goals: {"large", "graceful", "majestic"}, with equal prior probability. The goal selects the single feature the speaker aims to communicate. Speaker optimality parameter alpha=3.
A Rational Speech Act model for metaphor. A literal listener hears an utterance and a goal; it conditions on the category matching the utterance literally, then returns the value of the goal feature. A pragmatic speaker observes the true category and feature set, selects a goal, and chooses an utterance with probability proportional to exp(alpha * literal-listener-score on the goal feature value). A pragmatic listener samples a category from the prior, samples a feature set from the category-conditional prior, samples a goal from the goal prior, and conditions on the speaker's utterance choice. The listener returns the joint {category, large, graceful, majestic}.
The posterior joint distribution over {category, large, graceful, majestic} for a pragmatic listener who hears the utterance "whale".
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"category": "string",
"large": "int",
"graceful": "int",
"majestic": "int"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// John could either be a whale or a person.2var categories = ["whale", "person"]34// It is extremely unlikely that John is actually a whale.5var categoriesPrior = function() {6 categorical([0.01, 0.99], categories)7}89// The speaker could either say "John is a whale" or "John is a person."10var utterances = ["whale", "person"]1112// The utterances are equally costly.13var utterancePrior = function() {14 categorical([1,1], utterances)15}1617// The features of John being considered are "large", "graceful",18// "majestic." Features are binary.19var featureSets = [20 {large : 1, graceful : 1, majestic : 1},21 {large : 1, graceful : 1, majestic : 0},22 {large : 1, graceful : 0, majestic : 1},23 {large : 1, graceful : 0, majestic : 0},24 {large : 0, graceful : 1, majestic : 1},25 {large : 0, graceful : 1, majestic : 0},26 {large : 0, graceful : 0, majestic : 1},27 {large : 0, graceful : 0, majestic : 0}28]2930// information about feature priors (probabilistic world knowledge)31// obtained by an experimental study (see paper)32var featureSetPrior = function(category) {33 category === "whale" ? categorical([0.30592786494628, 0.138078454222818,34 0.179114768847673, 0.13098781834847,35 0.0947267162507846, 0.0531420411185539,36 0.0601520520596695, 0.0378702842057509],37 featureSets) :38 category === "person" ? categorical([0.11687632453038, 0.105787535267869,39 0.11568145784997, 0.130847056136141,40 0.15288225956497, 0.128098151176801,41 0.114694702836614, 0.135132512637255],42 featureSets) :43 true44}4546// Speaker's possible goals are to communicate feature 1, 2, or 347var goals = ["large", "graceful", "majestic"]4849// Prior probability of speaker's goal is set to uniform but can50// change with context/QUD.51var goalPrior = function() {52 categorical([1,1,1], goals)53}5455// Speaker optimality parameter56var alpha = 35758// Check if interpreted category is identical to utterance59var literalInterpretation = function(utterance, category) {60 utterance === category61}6263// Check if goal is satisfied64var goalState = function(goal, featureSet) {65 goal === "large" ? featureSet.large :66 goal === "graceful" ? featureSet.graceful :67 goal === "majestic" ? featureSet.majestic :68 true69}7071// Define a literal listener72var literalListener = function(utterance, goal) {73 Infer({model: function() {74 var category = uniformDraw(categories)75 var featureSet = featureSetPrior(category)76 condition(literalInterpretation(utterance, category))77 return goalState(goal, featureSet)78 }})79}8081// Speaker model82var speaker = function(large, graceful, majestic, goal) {83 Infer({model: function() {84 var utterance = utterancePrior()85 factor(alpha *86 literalListener(utterance,goal).score(goalState(goal, {large : large, graceful : graceful, majestic : majestic})))87 return utterance88 }})89}9091// Define a pragmatic listener92var pragmaticListener = function(utterance) {93 Infer({model: function() {94 var category = categoriesPrior()95 var featureSet = featureSetPrior(category)96 var large = featureSet.large97 var graceful = featureSet.graceful98 var majestic = featureSet.majestic99 var goal = goalPrior()100 observe(speaker(large, graceful, majestic, goal), utterance)101 return {category, large, graceful, majestic}102 }})103}104105display("The pragmatic listener's interpretation when the speaker says whale")106viz.table(pragmaticListener("whale"))107108display("The pragmatic listener's interpretation when the speaker says person")109viz.table(pragmaticListener("person"))110111var ANSWER = (pragmaticListener("whale"));
12categories = ["whale", "person"]3cat_prior = torch.tensor([0.01, 0.99])4featureSets = [5 {"large": 1, "graceful": 1, "majestic": 1},6 {"large": 1, "graceful": 1, "majestic": 0},7 {"large": 1, "graceful": 0, "majestic": 1},8 {"large": 1, "graceful": 0, "majestic": 0},9 {"large": 0, "graceful": 1, "majestic": 1},10 {"large": 0, "graceful": 1, "majestic": 0},11 {"large": 0, "graceful": 0, "majestic": 1},12 {"large": 0, "graceful": 0, "majestic": 0},13]14fs_probs = {15 "whale": torch.tensor([0.30592786494628, 0.138078454222818, 0.179114768847673,16 0.13098781834847, 0.0947267162507846, 0.0531420411185539,17 0.0601520520596695, 0.0378702842057509]),18 "person": torch.tensor([0.11687632453038, 0.105787535267869, 0.11568145784997,19 0.130847056136141, 0.15288225956497, 0.128098151176801,20 0.114694702836614, 0.135132512637255]),21}22fs_stack = torch.stack([fs_probs["whale"], fs_probs["person"]]) # (2, 8)23goals = ["large", "graceful", "majestic"]24goal_prior = torch.tensor([1.0, 1.0, 1.0])25utterances = ["whale", "person"]26alpha = 3.027NEG = torch.tensor(-1e30)28ZERO = torch.tensor(0.0)293031def goal_state(goal, fs_idx):32 return featureSets[fs_idx][goal]333435# Literal listener: inferred distribution over the goal-relevant feature value.36def literal_listener(utterance, goal):37 forced = categories.index(utterance)3839 @pyro.infer.config_enumerate40 def model():41 cat = pyro.sample("cat", dist.Categorical(torch.ones(2)))42 pyro.sample("fs", dist.Categorical(fs_stack[cat]))43 # literalInterpretation: utterance === category44 pyro.factor("cond", torch.where(cat == forced, ZERO, NEG))45 return None4647 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)48 fsm = marg["fs"]49 sup = fsm.enumerate_support()50 pr = fsm.log_prob(sup).exp()51 out = {0: 0.0, 1: 0.0}52 for i, p in zip(sup.tolist(), pr.tolist()):53 out[goal_state(goal, int(i))] += p54 return out555657# Speaker: inferred distribution over utterances given the goal feature value.58def speaker_logprobs(goal_value, goal):59 l0 = {u: literal_listener(u, goal) for u in utterances}6061 @pyro.infer.config_enumerate62 def model():63 u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))64 sc = torch.tensor([math.log(l0[uu][goal_value]) if l0[uu][goal_value] > 0 else -1e3065 for uu in utterances])66 pyro.factor("sc", alpha * sc[u])67 return None6869 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)70 um = marg["u"]71 sup = um.enumerate_support()72 pr = um.log_prob(sup).exp()73 return {utterances[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}747576# Pragmatic listener: joint posterior over (category, featureSet) via exact77# enumeration of a single combined latent; the engine marginalizes it.78def pragmatic_listener(utterance):79 spk_cache = {}8081 def spk_logp(fs_idx, goal_idx):82 goal = goals[goal_idx]83 gv = goal_state(goal, fs_idx)84 key = (goal, gv)85 if key not in spk_cache:86 spk_cache[key] = speaker_logprobs(gv, goal)87 p = spk_cache[key].get(utterance, 0.0)88 return math.log(p) if p > 0 else -1e308990 combos = [(c, f) for c in range(2) for f in range(8)]91 joint_prior = torch.tensor([cat_prior[c].item() * fs_stack[c][f].item() for (c, f) in combos])92 score_t = torch.zeros(len(combos), 3)93 for j, (c, f) in enumerate(combos):94 for g in range(3):95 score_t[j, g] = spk_logp(f, g)9697 @pyro.infer.config_enumerate98 def model():99 cf = pyro.sample("cf", dist.Categorical(joint_prior))100 goal = pyro.sample("goal", dist.Categorical(goal_prior))101 pyro.factor("obs", score_t[cf, goal])102 return None103104 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)105 cfm = marg["cf"]106 sup = cfm.enumerate_support()107 pr = cfm.log_prob(sup).exp()108 out = {}109 for j, p in zip(sup.tolist(), pr.tolist()):110 c, f = combos[int(j)]111 rec = {"category": categories[c], "large": featureSets[f]["large"],112 "graceful": featureSets[f]["graceful"], "majestic": featureSets[f]["majestic"]}113 out[json.dumps(rec, sort_keys=True)] = p114 return out115116117ANSWER = pragmatic_listener("whale")118
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
State space: four combinations of two binary dimensions black and white — {black:true, white:true}, {black:true, white:false}, {black:false, white:true}, {black:false, white:false}. Prior over states is uniform (probability 0.25 each). Utterances: {"blm", "nblm"}. Literal meanings: "blm" is true when black is true; "nblm" is true when black is false. Speaker optimality parameter alpha=1.
A two-level RSA model for a political speech act. A literal listener hears an utterance, samples a state from the uniform prior, conditions on the utterance's literal meaning, and returns the state. A pragmatic speaker observes the true state and selects utterances with probability proportional to exp(alpha * literal-listener-score on the state). A pragmatic listener samples a state from the uniform prior and conditions on the pragmatic speaker's utterance choice.
The posterior distribution over states {black, white} for a pragmatic listener who hears "blm".
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"black": "bool",
"white": "bool"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var alpha = 123var statePrior = function() {4 return Categorical({5 // ps: [.9,.04,.04,.02], // make it a priori likely that all lives matter6 ps: [.25,.25,.25,.25], // uniform prior7 // ps: [.04,.04,.9,.02], // make it a priori likely that only white lives matter8 vs: [{black:true, white:true},{black:true, white:false},{black:false, white:true},{black:false, white:false}]})9};1011// possible utterances12var utterancePrior = function() {13 return uniformDraw(['blm', 'nblm'])14// return uniformDraw(['blm', 'wlm', 'alm', 'nlm'])15};1617// meaning funtion to interpret the utterances18var literalMeanings = {19 blm: function(state) { return state["black"] },20 nblm: function(state) { return !state["black"] },21 wlm: function(state) { return state["white"]},22 alm: function(state) { return state["black"] && state["white"] },23 nlm: function(state) { return !state["black"] && !state["white"] }24};2526// literal listener27var literalListener = cache(function(utt) {28 return Infer({method:"enumerate"},29 function(){30 var state = sample(statePrior())31 var meaning = literalMeanings[utt]32 condition(meaning(state))33 return state34 })35});3637// pragmatic speaker38var speaker = cache(function(state) {39 return Infer({method:"enumerate"},40 function(){41 var utt = utterancePrior()42 factor(alpha * literalListener(utt).score(state))43 return utt44 })45});4647// pragmatic listener48var pragmaticListener = cache(function(utt) {49 return Infer({method:"enumerate"},50 function(){51 var state = sample(statePrior())52 observe(speaker(state),utt)53 return state54 })55});5657pragmaticListener("blm")5859var ANSWER = (pragmaticListener('blm'));
1# Two-level RSA. Every level is genuine Pyro exact enumeration inference.23alpha = 1.04states = [5 {"black": True, "white": True},6 {"black": True, "white": False},7 {"black": False, "white": True},8 {"black": False, "white": False},9]10utterances = ["blm", "nblm"]111213def literal_meaning(utt, state):14 if utt == "blm":15 return state["black"]16 return not state["black"]171819def enum_marginal(model, n):20 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(21 model, lambda: None22 )["x"]23 sup = marg.enumerate_support()24 probs = marg.log_prob(sup).exp()25 out = {int(s): 0.0 for s in range(n)}26 for s, p in zip(sup.tolist(), probs.tolist()):27 out[int(s)] = p28 return out293031def literal_listener(utt):32 # uniform prior over states, condition on meaning33 truths = torch.tensor(34 [1.0 if literal_meaning(utt, s) else 0.0 for s in states]35 )36 logt = torch.log(truths)3738 @pyro.infer.config_enumerate39 def model():40 x = pyro.sample("x", dist.Categorical(torch.ones(len(states))))41 pyro.factor("ev", logt[x])42 return x4344 return enum_marginal(model, len(states))454647def speaker(state_idx):48 # factor alpha * literalListener(utt).score(state)49 scores = []50 for u in utterances:51 d = literal_listener(u)52 p = d.get(state_idx, 0.0)53 scores.append(math.log(p) if p > 0 else float("-inf"))54 scores = torch.tensor(scores)5556 @pyro.infer.config_enumerate57 def model():58 x = pyro.sample("x", dist.Categorical(torch.ones(len(utterances))))59 pyro.factor("ev", alpha * scores[x])60 return x6162 return enum_marginal(model, len(utterances))636465def pragmatic_listener(utt):66 u_idx = utterances.index(utt)67 scores = []68 for i in range(len(states)):69 d = speaker(i)70 p = d.get(u_idx, 0.0)71 scores.append(math.log(p) if p > 0 else float("-inf"))72 scores = torch.tensor(scores)7374 @pyro.infer.config_enumerate75 def model():76 x = pyro.sample("x", dist.Categorical(torch.ones(len(states))))77 pyro.factor("ev", scores[x])78 return x7980 return enum_marginal(model, len(states))818283idx_dist = pragmatic_listener("blm")84ANSWER = {85 json.dumps(states[i], sort_keys=True): idx_dist.get(i, 0.0)86 for i in range(len(states))87}88
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Two possible categories: whale and person. The prior probability that someone is actually a whale is 0.01 (person: 0.99). Each entity can possess three binary features: large, graceful, majestic. Feature sets are all 8 binary combinations. Empirical feature-set priors by category: for whales, probabilities are [0.30592786494628, 0.138078454222818, 0.179114768847673, 0.13098781834847, 0.0947267162507846, 0.0531420411185539, 0.0601520520596695, 0.0378702842057509] over the feature sets ordered (large=1,graceful=1,majestic=1), (1,1,0), (1,0,1), (1,0,0), (0,1,1), (0,1,0), (0,0,1), (0,0,0); for persons, the corresponding probabilities are [0.11687632453038, 0.105787535267869, 0.11568145784997, 0.130847056136141, 0.15288225956497, 0.128098151176801, 0.114694702836614, 0.135132512637255]. The speaker has a goal prior that weights communicating 'large' with probability proportional to 5 and each of 'graceful' and 'majestic' with proportional weight 1. Two possible utterances — 'whale' and 'person' — with equal prior probability. Speaker optimality alpha = 3. Literal interpretation: utterance is true of a category if and only if the utterance label matches the category name. A goal is satisfied when the feature corresponding to that goal equals 1.
A literal listener conditions on the utterance matching the entity's category and returns whether the goal's corresponding feature is satisfied. A speaker who knows the entity's actual category and all three feature values chooses an utterance in proportion to exp(alpha times the literal listener's log-probability of the entity's actual goal-feature value — the log-probability that the feature equals 1 when the entity's goal feature is 1, and the log-probability that the feature equals 0 when it is 0). A pragmatic listener infers category and all three feature values jointly, marginalizing over the speaker's latent goal, by inverting the speaker model.
The posterior joint distribution over category and the three binary features (large, graceful, majestic) given that the pragmatic listener hears the utterance 'whale'.
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"category": "string",
"large": "int",
"graceful": "int",
"majestic": "int"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// John could either be a whale or a person.2var categories = ["whale", "person"]34// It is extremely unlikely that John is actually a whale.5var categoriesPrior = function() {6 categorical([0.01, 0.99], categories)7}89// The speaker could either say "John is a whale" or "John is a person."10var utterances = ["whale", "person"]1112// The utterances are equally costly.13var utterancePrior = function() {14 categorical([1,1], utterances)15}1617// The features of John being considered are "large", "graceful",18// "majestic." Features are binary.19var featureSets = [20 {large : 1, graceful : 1, majestic : 1},21 {large : 1, graceful : 1, majestic : 0},22 {large : 1, graceful : 0, majestic : 1},23 {large : 1, graceful : 0, majestic : 0},24 {large : 0, graceful : 1, majestic : 1},25 {large : 0, graceful : 1, majestic : 0},26 {large : 0, graceful : 0, majestic : 1},27 {large : 0, graceful : 0, majestic : 0}28]2930// information about feature priors (probabilistic world knowledge)31// obtained by an experimental study (see paper)32var featureSetPrior = function(category) {33 category === "whale" ? categorical([0.30592786494628, 0.138078454222818,34 0.179114768847673, 0.13098781834847,35 0.0947267162507846, 0.0531420411185539,36 0.0601520520596695, 0.0378702842057509],37 featureSets) :38 category === "person" ? categorical([0.11687632453038, 0.105787535267869,39 0.11568145784997, 0.130847056136141,40 0.15288225956497, 0.128098151176801,41 0.114694702836614, 0.135132512637255],42 featureSets) :43 true44}4546// Speaker's possible goals are to communicate feature 1, 2, or 347var goals = ["large", "graceful", "majestic"]4849//// Prior probability of speaker's goal is set to uniform but can50//// change with context/QUD.51var goalPrior = function() {52 categorical([5,1,1], goals)53}54// var goalPrior = function() {55// categorical([1,5,1], goals)56// }57// var goalPrior = function() {58// categorical([1,1,5], goals)59// }6061// Speaker optimality parameter62var alpha = 36364// Check if interpreted category is identical to utterance65var literalInterpretation = function(utterance, category) {66 utterance === category67}6869// Check if goal is satisfied70var goalState = function(goal, featureSet) {71 goal === "large" ? featureSet.large :72 goal === "graceful" ? featureSet.graceful :73 goal === "majestic" ? featureSet.majestic :74 true75}7677// Define a literal listener78var literalListener = function(utterance, goal) {79 Infer({model: function() {80 var category = uniformDraw(categories)81 var featureSet = featureSetPrior(category)82 condition(literalInterpretation(utterance, category))83 return goalState(goal, featureSet)84 }})85}8687// Speaker model88var speaker = function(large, graceful, majestic, goal) {89 Infer({model: function() {90 var utterance = utterancePrior()91 factor(alpha *92 literalListener(utterance,goal).score(goalState(goal, {large : large, graceful : graceful, majestic : majestic})))93 return utterance94 }})95}9697// Define a pragmatic listener98var pragmaticListener = function(utterance) {99 Infer({model: function() {100 var category = categoriesPrior()101 var featureSet = featureSetPrior(category)102 var large = featureSet.large103 var graceful = featureSet.graceful104 var majestic = featureSet.majestic105 var goal = goalPrior()106 observe(speaker(large, graceful, majestic, goal), utterance)107 return {category, large, graceful, majestic}108 }})109}110111viz.table(pragmaticListener("whale"))112113var ANSWER = (pragmaticListener("whale"));
12categories = ["whale", "person"]3cat_prior = torch.tensor([0.01, 0.99])4featureSets = [5 {"large": 1, "graceful": 1, "majestic": 1},6 {"large": 1, "graceful": 1, "majestic": 0},7 {"large": 1, "graceful": 0, "majestic": 1},8 {"large": 1, "graceful": 0, "majestic": 0},9 {"large": 0, "graceful": 1, "majestic": 1},10 {"large": 0, "graceful": 1, "majestic": 0},11 {"large": 0, "graceful": 0, "majestic": 1},12 {"large": 0, "graceful": 0, "majestic": 0},13]14fs_probs = {15 "whale": torch.tensor([0.30592786494628, 0.138078454222818, 0.179114768847673,16 0.13098781834847, 0.0947267162507846, 0.0531420411185539,17 0.0601520520596695, 0.0378702842057509]),18 "person": torch.tensor([0.11687632453038, 0.105787535267869, 0.11568145784997,19 0.130847056136141, 0.15288225956497, 0.128098151176801,20 0.114694702836614, 0.135132512637255]),21}22fs_stack = torch.stack([fs_probs["whale"], fs_probs["person"]]) # (2, 8)23goals = ["large", "graceful", "majestic"]24goal_prior = torch.tensor([5.0, 1.0, 1.0])25utterances = ["whale", "person"]26alpha = 3.027NEG = torch.tensor(-1e30)28ZERO = torch.tensor(0.0)293031def goal_state(goal, fs_idx):32 return featureSets[fs_idx][goal]333435# Literal listener: inferred distribution over the goal-relevant feature value.36def literal_listener(utterance, goal):37 forced = categories.index(utterance)3839 @pyro.infer.config_enumerate40 def model():41 cat = pyro.sample("cat", dist.Categorical(torch.ones(2)))42 pyro.sample("fs", dist.Categorical(fs_stack[cat]))43 # literalInterpretation: utterance === category44 pyro.factor("cond", torch.where(cat == forced, ZERO, NEG))45 return None4647 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)48 fsm = marg["fs"]49 sup = fsm.enumerate_support()50 pr = fsm.log_prob(sup).exp()51 out = {0: 0.0, 1: 0.0}52 for i, p in zip(sup.tolist(), pr.tolist()):53 out[goal_state(goal, int(i))] += p54 return out555657# Speaker: inferred distribution over utterances given the goal feature value.58def speaker_logprobs(goal_value, goal):59 l0 = {u: literal_listener(u, goal) for u in utterances}6061 @pyro.infer.config_enumerate62 def model():63 u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))64 sc = torch.tensor([math.log(l0[uu][goal_value]) if l0[uu][goal_value] > 0 else -1e3065 for uu in utterances])66 pyro.factor("sc", alpha * sc[u])67 return None6869 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)70 um = marg["u"]71 sup = um.enumerate_support()72 pr = um.log_prob(sup).exp()73 return {utterances[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}747576# Pragmatic listener: joint posterior over (category, featureSet) via exact77# enumeration of a single combined latent; the engine marginalizes it.78def pragmatic_listener(utterance):79 spk_cache = {}8081 def spk_logp(fs_idx, goal_idx):82 goal = goals[goal_idx]83 gv = goal_state(goal, fs_idx)84 key = (goal, gv)85 if key not in spk_cache:86 spk_cache[key] = speaker_logprobs(gv, goal)87 p = spk_cache[key].get(utterance, 0.0)88 return math.log(p) if p > 0 else -1e308990 combos = [(c, f) for c in range(2) for f in range(8)]91 joint_prior = torch.tensor([cat_prior[c].item() * fs_stack[c][f].item() for (c, f) in combos])92 score_t = torch.zeros(len(combos), 3)93 for j, (c, f) in enumerate(combos):94 for g in range(3):95 score_t[j, g] = spk_logp(f, g)9697 @pyro.infer.config_enumerate98 def model():99 cf = pyro.sample("cf", dist.Categorical(joint_prior))100 goal = pyro.sample("goal", dist.Categorical(goal_prior))101 pyro.factor("obs", score_t[cf, goal])102 return None103104 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)105 cfm = marg["cf"]106 sup = cfm.enumerate_support()107 pr = cfm.log_prob(sup).exp()108 out = {}109 for j, p in zip(sup.tolist(), pr.tolist()):110 c, f = combos[int(j)]111 rec = {"category": categories[c], "large": featureSets[f]["large"],112 "graceful": featureSets[f]["graceful"], "majestic": featureSets[f]["majestic"]}113 out[json.dumps(rec, sort_keys=True)] = p114 return out115116117ANSWER = pragmatic_listener("whale")118
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 1/2 solvers · d=[0.042, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Height distributions are discretized using binParam = 3. The superordinate category (all people) has height distribution Gaussian(mu=0, sigma=1). The basketball player subcategory has Gaussian(mu=1, sigma=0.5). State values are 18 evenly spaced points from -3 to 3 (exclusive) in steps of 1/3. State probabilities are proportional to the Gaussian PDF evaluated at each state value. Thresholds for 'tall' are each state value minus 1/(2*binParam); thresholds for 'short' are each state value plus 1/(2*binParam); each threshold set is drawn uniformly. Comparison class is drawn uniformly from {superordinate, subordinate}. Speaker optimality alpha = 5. Three utterances: 'tall', 'short', 'silence' (silence is always true).
A literal listener conditions on the utterance being true (tall: state exceeds the tall threshold; short: state is below the short threshold; silence: always true) given the state distribution for the current comparison class. A speaker chooses utterances proportional to exp(alpha times the literal listener's log-probability of the state). A pragmatic listener jointly infers the entity's state, the comparison class, and the threshold pair (one tall-threshold drawn uniformly from the tall set, one short-threshold drawn uniformly from the short set) by inverting the speaker, conditioning on the utterance. The same threshold pair is passed to the speaker and through to the literal listener.
The marginal posterior distribution over the comparison class (superordinate vs. subordinate) for a pragmatic listener who hears 'tall' about a basketball player.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"superordinate",
"subordinate"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold:2// helper function3var exp = function(x){return Math.exp(x)}45// for discretization6var binParam = 3;78// information about the superordinate category prior9// e.g., the height distribution for all people10var superordinate_params = {mu: 0, sigma: 1};1112// calculate the range in pre-defined steps;13// these values correspond to possible heights14var stateVals = _.range(superordinate_params.mu - 3 * superordinate_params.sigma,15 superordinate_params.mu + 3 * superordinate_params.sigma,16 superordinate_params.sigma/binParam)1718// for each possible height, calculate its probability of occurrence19var stateProbs = cache(function(stateParams){20 return map(function(s){21 exp(Gaussian(stateParams).score(s))22 }, stateVals)23});2425// generate a statePrior using the possible heights and their probabilities26var generateStatePrior = cache(function(stateParams) {27 return Infer({28 model: function(){29 return categorical({vs: stateVals, ps: stateProbs(stateParams)})30 }31 })32});3334// generate the uniform threshold prior35var thresholdBins ={36 positive: map(function(x){37 return x - (1/(binParam*2));38 }, sort(stateVals)),39 negative: map(function(x){40 return x + (1/(binParam*2));41 }, sort(stateVals))42};4344var thresholdPrior = cache(function(form){45 return Infer({46 model: function() { return uniformDraw(thresholdBins[form]) }47 });48});4950// information about the superordinate category priors51var subParams = {52 gymnasts: {mu: -1, sigma: 0.5}, // gymnast heights53 soccerPlayers: {mu: 0, sigma: 0.5}, // soccer player heights54 basketballPlayers: {mu: 1, sigma: 0.5} // basketball player heights55}5657// possible utterances can be either positive (tall) or negative (short) or a null utterance58var utterances = ["tall", "short", "silence"]5960// meaning function for utterances61var meaning = function(utterance, state, thresholds) {62 utterance == "tall" ? state > thresholds.tall :63 utterance == "short" ? state < thresholds.short :64 true65}6667// assume a uniform prior over comparison classes68var classPrior = Infer({69 model: function(){return uniformDraw(["subordinate", "superordinate"])}70});7172// set speaker optimality73var alpha = 5;7475var literalListener = cache(76 function(utterance, thresholds, comparisonClass) {77 Infer({model: function(){78 var StatePrior = generateStatePrior(comparisonClass)79 var state = sample(StatePrior);80 var m = meaning(utterance, state, thresholds);81 condition(m);82 return state;83 }})84 }, 10000 // limit cache size85)8687var speaker1 = cache(88 function(state, thresholds, comparisonClass) {89 Infer({model: function(){90 var utterance = uniformDraw(utterances);91 var L0 = literalListener(utterance, thresholds, comparisonClass);92 factor( alpha * L0.score(state) );93 return utterance;94 }})95 }, 10000 // limit cache size96)97///9899var pragmaticListener = cache(function(utterance, subordinate_params) {100 Infer({model: function(){101102 var statePrior = generateStatePrior(subordinate_params);103 var state = sample(statePrior);104 // separate thresholds for positive adjective and negative adjective105 var thresholds = {106 tall: sample(thresholdPrior("positive")),107 short: sample(thresholdPrior("negative"))108 }109110 // uncertainty about the comparison class (superordinate vs. subordinate)111 var c = sample(classPrior)112 var comparisonClass = c == "subordinate" ? subordinate_params : superordinate_params113114 var S1 = speaker1(state, thresholds, comparisonClass);115 observe(S1, utterance);116117 return { comparisonClass: c, state : state }118 }})119}, 10000 // limit cache size120 )121122var ANSWER = (marginalize(pragmaticListener("tall", subParams["basketballPlayers"]), "comparisonClass"));
1exp = math.exp2binParam = 33super_mu, super_sigma = 0.0, 1.045# 18 evenly spaced state values from -3 to 3 (exclusive) in steps of 1/36step = super_sigma / binParam7stateVals = []8x = super_mu - 3 * super_sigma9while x < super_mu + 3 * super_sigma - 1e-9:10 stateVals.append(x)11 x += step1213subParams = {"mu": 1.0, "sigma": 0.5} # basketball players14sorted_states = sorted(stateVals)15threshold_positive = [s - 1.0 / (binParam * 2) for s in sorted_states]16threshold_negative = [s + 1.0 / (binParam * 2) for s in sorted_states]17utterances = ["tall", "short", "silence"]18alpha = 5.01920def meaning(utt, state, t_tall, t_short):21 if utt == "tall":22 return state > t_tall23 if utt == "short":24 return state < t_short25 return True2627# Normalized state prior (categorical over stateVals weighted by the Gaussian pdf).28_state_dist_cache = {}29def state_dist(mu, sigma):30 key = (mu, sigma)31 if key in _state_dist_cache:32 return _state_dist_cache[key]33 d = dist.Normal(torch.tensor(float(mu)), torch.tensor(float(sigma)))34 ps = [exp(d.log_prob(torch.tensor(s)).item()) for s in stateVals]35 z = sum(ps)36 out = (stateVals, torch.tensor([p / z for p in ps]))37 _state_dist_cache[key] = out38 return out3940# Literal listener: condition on the utterance being true under the comparison-class41# state prior. Exact enumeration via Pyro's compute_marginals.42_L0_cache = {}43def literalListener(utt, t_tall, t_short, cc_mu, cc_sigma):44 key = (utt, t_tall, t_short, cc_mu, cc_sigma)45 if key in _L0_cache:46 return _L0_cache[key]47 svals, pri = state_dist(cc_mu, cc_sigma)48 _cond = torch.tensor([49 0.0 if meaning(utt, svals[i], t_tall, t_short) else float("-inf")50 for i in range(len(svals))51 ])52 @pyro.infer.config_enumerate53 def m():54 idx = pyro.sample("state", dist.Categorical(pri))55 pyro.factor("cond", _cond[idx])56 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)57 _L0_cache[key] = (marg["state"], svals)58 return _L0_cache[key]5960# Speaker: utterance ~ exp(alpha * L0.score(state)), inverting the literal listener.61_S1_cache = {}62def speaker1(state, t_tall, t_short, cc_mu, cc_sigma):63 key = (state, t_tall, t_short, cc_mu, cc_sigma)64 if key in _S1_cache:65 return _S1_cache[key]66 _sc = []67 for utt in utterances:68 marg, svals = literalListener(utt, t_tall, t_short, cc_mu, cc_sigma)69 if state in svals:70 lp = marg.log_prob(torch.tensor(svals.index(state))).item()71 else:72 lp = float("-inf")73 _sc.append(alpha * lp)74 _sc = torch.tensor(_sc)75 @pyro.infer.config_enumerate76 def m():77 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances))))78 pyro.factor("util", _sc[uidx])79 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)80 _S1_cache[key] = marg["utt"]81 return _S1_cache[key]8283# Pragmatic listener: jointly sample (class, state, tall-threshold, short-threshold),84# observe the heard utterance by factoring the speaker's log-score, and marginalize.85# The entity's state prior is the subordinate (basketball) prior; the comparison86# class selects which prior the speaker/literal-listener use internally. Inference is87# exact enumeration over the full joint via Pyro's compute_marginals.88utterance_heard = "tall"89_uidx = utterances.index(utterance_heard)90classes = ["subordinate", "superordinate"]91_svals, _state_pri = state_dist(subParams["mu"], subParams["sigma"])9293# Speaker log-score for the heard utterance across every joint assignment.94_obs = torch.full(95 (len(classes), len(_svals), len(threshold_positive), len(threshold_negative)),96 float("-inf"),97)98for ci, c in enumerate(classes):99 cc_mu, cc_sigma = (subParams["mu"], subParams["sigma"]) if c == "subordinate" else (super_mu, super_sigma)100 for si, state in enumerate(_svals):101 for ti, t_tall in enumerate(threshold_positive):102 for tj, t_short in enumerate(threshold_negative):103 S1 = speaker1(state, t_tall, t_short, cc_mu, cc_sigma)104 _obs[ci, si, ti, tj] = S1.log_prob(torch.tensor(_uidx)).item()105106@pyro.infer.config_enumerate107def pragmaticListener():108 c = pyro.sample("c", dist.Categorical(torch.ones(len(classes))))109 s = pyro.sample("state", dist.Categorical(_state_pri))110 tt = pyro.sample("tt", dist.Categorical(torch.ones(len(threshold_positive))))111 ts = pyro.sample("ts", dist.Categorical(torch.ones(len(threshold_negative))))112 pyro.factor("obs", _obs[c, s, tt, ts])113114_marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(pragmaticListener, lambda: None)115_cm = _marg["c"]116ANSWER = {classes[i]: _cm.probs[i].item() for i in range(len(classes))}117
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 1/2 solvers · d=[0.000, 0.115] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Six words and their 25-dimensional GloVe vectors: eagle: [-0.8186906894583743, -0.8443627918594182, -0.04304780086785447, -0.8257634263841377, -0.7607218950809542, 0.47786735164930183, 0.36942709316422206, 0.18560148725224498, 0.38625176009619944, 0.24384273963053932, 1.0355862286322068, -0.14170089242313555, -0.17017960843359828, 0.27636172471279313, -0.49477465481497807, -1.199206930890509, 1.0531720839078256, -0.5154875303291531, 0.30704269353337016, 1.5382356443196483, -0.13215425501400774, 1.2222507503066664, 1.3819617662995949, -1.1579407453927437, 0.9439311306043343] pig: [-0.9027808771549458, -1.4539105978263833, 0.5743098399154295, 1.3052815987119957, -0.038556210348244704, -0.22144102997326148, 1.222050088622139, -0.027526643946408857, -0.13265827668708097, 1.4799207507145387, 0.02371336629548181, -0.9405402658175948, 0.06556493358788004, -1.6556208133885402, -0.44306373689318584, -0.475710035110901, 1.2435716830499404, -1.0677780309283533, -0.03344465447168945, 0.16184568683816827, 0.8718035460475897, 2.082956682688621, 0.47430271385843953, -0.4479993650378608, 1.5192928553678355] chicken: [-0.9555633717916903, -0.23467550895948608, 0.9081102168032618, 1.7681919864431317, -0.3888166871286516, 1.2292398003323308, 1.0624961440319318, -0.3558803892040966, -0.17024423658814317, 0.7046776782991592, 1.624196256505183, -1.1423231844008523, -0.9490267652945451, -1.8004114037674281, -0.026086280368055388, -2.089757256612839, 3.5660566372328693, -2.4178611093952225, -0.7077960662621875, 0.9418434965990246, 0.438927172322575, 1.0891725023940724, -0.1237861204326181, 0.7602054634506068, 3.0515580696224083] farm: [-1.3191469349030631, -0.34747873883058705, 0.09525267994894762, 0.08014872654330456, 0.1179814806966339, -0.26926061020753783, 0.709033965954239, -0.6521777385143812, 1.0195239553589313, 0.7192612109870958, 1.1711460976059695, -1.0779079866249233, -0.5443503049555966, 0.08523251153754875, 0.1455530206584687, -1.501097375488643, 1.1151234505440395, 0.0581591541412683, -0.1102242123027589, 0.5253857581277014, 0.21780949510893402, 0.026030837039037417, 0.07282095318396448, -0.6093002665622598, 2.0466066458317336] animal: [-0.08132595025854601, -1.8280616716238214, -0.4241550049238374, 1.3405833261217683, 1.3635302219051426, 0.19656106954281044, 1.0553637657141577, 0.8640316722860499, -0.34682275265131135, 0.27196141799987644, 0.9785603157742483, -3.1767493003780873, -0.7566904249011203, -1.1935303767007424, 0.2523177522167622, 0.33414675815038736, 0.21147820292953767, 0.2089073521353749, 0.36413859545070626, -0.3145854077725169, 0.8470589609352164, 0.8914477422324714, 0.06602846837066885, -1.1974184866543685, 1.6807019645814638] bird: [0.9317828273959833, -1.142927450658389, 1.1249556341704339, 0.7533022372085103, 0.039221572709652965, 0.5302815428039684, 1.1525754405638204, 0.5707370610821617, 0.01803607760778035, 0.9527229145321762, 1.0851468114908822, -0.4626041548552341, -0.5371489443168416, -0.8343285842461913, -0.09713481034287788, 0.8070233789520264, 0.21755780815430825, -0.6588132708557186, -0.7963193188039507, 0.12395864485237663, -0.18545774404118467, 1.311026289715281, 0.7764007851264465, -0.5776179488468618, 0.5640559901962993] The semantic similarity between a clue and a word is measured by the Euclidean distance between their vectors; a word is judged as intended by the clue with probability sigmoid(1/distance). The clue applies to a subset of words only if all words in the subset independently pass this probabilistic check. Possible target word subsets (pairs) to guess: [chicken,eagle], [eagle,pig], [chicken,pig]. Possible clue words: farm, animal, bird; prior over clues is uniform. Prior over target subsets is uniform. Speaker optimality alpha = 1.
A literal listener conditions on the clue correctly applying (all words in the subset independently pass the similarity check) and returns the target subset. A speaker chooses clues proportional to exp(alpha times the literal listener's log-probability of the target subset). A pragmatic listener inverts the speaker model to infer the intended target subset.
The posterior distribution over target word subsets given that the pragmatic listener hears the clue 'farm'. Represent each candidate pair as a two-element list of the two words, in alphabetical order.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
[
"chicken",
"eagle"
],
[
"chicken",
"pig"
],
[
"eagle",
"pig"
]
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold: vectors2var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,3 -0.04304780086785447, -0.8257634263841377,4 -0.7607218950809542, 0.47786735164930183,5 0.36942709316422206, 0.18560148725224498,6 0.38625176009619944, 0.24384273963053932,7 1.0355862286322068, -0.14170089242313555,8 -0.17017960843359828, 0.27636172471279313,9 -0.49477465481497807, -1.199206930890509,10 1.0531720839078256, -0.5154875303291531,11 0.30704269353337016, 1.5382356443196483,12 -0.13215425501400774, 1.2222507503066664,13 1.3819617662995949, -1.1579407453927437,14 0.9439311306043343]),15 pig: Vector([-0.9027808771549458, -1.4539105978263833,16 0.5743098399154295, 1.3052815987119957,17 -0.038556210348244704, -0.22144102997326148,18 1.222050088622139, -0.027526643946408857,19 -0.13265827668708097, 1.4799207507145387,20 0.02371336629548181, -0.9405402658175948,21 0.06556493358788004, -1.6556208133885402,22 -0.44306373689318584, -0.475710035110901,23 1.2435716830499404, -1.0677780309283533,24 -0.03344465447168945, 0.16184568683816827,25 0.8718035460475897, 2.082956682688621,26 0.47430271385843953, -0.4479993650378608,27 1.5192928553678355]),28 chicken: Vector([-0.9555633717916903, -0.23467550895948608,29 0.9081102168032618, 1.7681919864431317,30 -0.3888166871286516, 1.2292398003323308,31 1.0624961440319318, -0.3558803892040966,32 -0.17024423658814317, 0.7046776782991592,33 1.624196256505183, -1.1423231844008523,34 -0.9490267652945451, -1.8004114037674281,35 -0.026086280368055388, -2.089757256612839,36 3.5660566372328693, -2.4178611093952225,37 -0.7077960662621875, 0.9418434965990246,38 0.438927172322575, 1.0891725023940724,39 -0.1237861204326181, 0.7602054634506068,40 3.0515580696224083]),41 farm: Vector([-1.3191469349030631, -0.34747873883058705,42 0.09525267994894762, 0.08014872654330456,43 0.1179814806966339, -0.26926061020753783,44 0.709033965954239, -0.6521777385143812,45 1.0195239553589313, 0.7192612109870958,46 1.1711460976059695, -1.0779079866249233,47 -0.5443503049555966, 0.08523251153754875,48 0.1455530206584687, -1.501097375488643,49 1.1151234505440395, 0.0581591541412683,50 -0.1102242123027589, 0.5253857581277014,51 0.21780949510893402, 0.026030837039037417,52 0.07282095318396448, -0.6093002665622598,53 2.0466066458317336]),54 animal: Vector([-0.08132595025854601, -1.8280616716238214,55 -0.4241550049238374, 1.3405833261217683,56 1.3635302219051426, 0.19656106954281044,57 1.0553637657141577, 0.8640316722860499,58 -0.34682275265131135, 0.27196141799987644,59 0.9785603157742483, -3.1767493003780873,60 -0.7566904249011203, -1.1935303767007424,61 0.2523177522167622, 0.33414675815038736,62 0.21147820292953767, 0.2089073521353749,63 0.36413859545070626, -0.3145854077725169,64 0.8470589609352164, 0.8914477422324714,65 0.06602846837066885, -1.1974184866543685,66 1.6807019645814638]),67 bird: Vector([0.9317828273959833, -1.142927450658389,68 1.1249556341704339, 0.7533022372085103,69 0.039221572709652965, 0.5302815428039684,70 1.1525754405638204, 0.5707370610821617,71 0.01803607760778035, 0.9527229145321762,72 1.0851468114908822, -0.4626041548552341,73 -0.5371489443168416, -0.8343285842461913,74 -0.09713481034287788, 0.8070233789520264,75 0.21755780815430825, -0.6588132708557186,76 -0.7963193188039507, 0.12395864485237663,77 -0.18545774404118467, 1.311026289715281,78 0.7764007851264465, -0.5776179488468618,79 0.5640559901962993])80};81///8283var meaning = function(clue, words) {8485 var distance = function(vector1, vector2)86 {87 var squared = map(function(tuple)88 {89 return (tuple[0] - tuple[1])*(tuple[0] - tuple[1]);90 }91 , zip(ad.tensor.toScalars(vector1), ad.tensor.toScalars(vector2))92 );939495 var answer = Math.sqrt(sum(squared));96 return answer;97 };9899 var sigmoid = function(num)100 {101 return 1/(1 + Math.exp(-1*num));102 };103104 var trueFalse = function(clue, word)105 {106 var dist = distance(vectors[clue], vectors[word]);107 var prob = sigmoid(1/dist);108 return flip(prob);109 };110111 var wordsVectors = map(function(word) {return trueFalse(clue, word);}, words);112113114 return all(function(s) {return s;}, wordsVectors);115};116117var wordsPrior = function()118{119 var pairs = [["chicken", "eagle"], ["eagle", "pig"], ["chicken", "pig"]];120 return uniformDraw(pairs);121};122123var cluePrior = function()124{125 return uniformDraw(["farm", "animal", "bird"]);126};127128129var literalListener = function(clue)130{131 Infer(function()132 {133 var randomSubset = wordsPrior();134 var uttTruthVal = meaning(clue, randomSubset);135 condition(uttTruthVal);136 return randomSubset;137 }138 )139};140141var alpha = 1;142143var speaker = function(subset)144{145 Infer(function()146 {147 var clue = cluePrior();148 factor(alpha*literalListener(clue).score(subset));149 return clue;150 }151 )152};153154var pragmaticListener = function(clue)155{156 Infer(function()157 {158 var randomSubset = wordsPrior();159 var s1 = speaker(randomSubset);160 observe(s1, clue);161 return randomSubset;162 }163 )164};165166viz.table(pragmaticListener("farm"));167168var ANSWER = (pragmaticListener("farm"));
12vectors = {3 "eagle": [-0.8186906894583743, -0.8443627918594182, -0.04304780086785447, -0.8257634263841377, -0.7607218950809542, 0.47786735164930183, 0.36942709316422206, 0.18560148725224498, 0.38625176009619944, 0.24384273963053932, 1.0355862286322068, -0.14170089242313555, -0.17017960843359828, 0.27636172471279313, -0.49477465481497807, -1.199206930890509, 1.0531720839078256, -0.5154875303291531, 0.30704269353337016, 1.5382356443196483, -0.13215425501400774, 1.2222507503066664, 1.3819617662995949, -1.1579407453927437, 0.9439311306043343],4 "pig": [-0.9027808771549458, -1.4539105978263833, 0.5743098399154295, 1.3052815987119957, -0.038556210348244704, -0.22144102997326148, 1.222050088622139, -0.027526643946408857, -0.13265827668708097, 1.4799207507145387, 0.02371336629548181, -0.9405402658175948, 0.06556493358788004, -1.6556208133885402, -0.44306373689318584, -0.475710035110901, 1.2435716830499404, -1.0677780309283533, -0.03344465447168945, 0.16184568683816827, 0.8718035460475897, 2.082956682688621, 0.47430271385843953, -0.4479993650378608, 1.5192928553678355],5 "chicken": [-0.9555633717916903, -0.23467550895948608, 0.9081102168032618, 1.7681919864431317, -0.3888166871286516, 1.2292398003323308, 1.0624961440319318, -0.3558803892040966, -0.17024423658814317, 0.7046776782991592, 1.624196256505183, -1.1423231844008523, -0.9490267652945451, -1.8004114037674281, -0.026086280368055388, -2.089757256612839, 3.5660566372328693, -2.4178611093952225, -0.7077960662621875, 0.9418434965990246, 0.438927172322575, 1.0891725023940724, -0.1237861204326181, 0.7602054634506068, 3.0515580696224083],6 "farm": [-1.3191469349030631, -0.34747873883058705, 0.09525267994894762, 0.08014872654330456, 0.1179814806966339, -0.26926061020753783, 0.709033965954239, -0.6521777385143812, 1.0195239553589313, 0.7192612109870958, 1.1711460976059695, -1.0779079866249233, -0.5443503049555966, 0.08523251153754875, 0.1455530206584687, -1.501097375488643, 1.1151234505440395, 0.0581591541412683, -0.1102242123027589, 0.5253857581277014, 0.21780949510893402, 0.026030837039037417, 0.07282095318396448, -0.6093002665622598, 2.0466066458317336],7 "animal": [-0.08132595025854601, -1.8280616716238214, -0.4241550049238374, 1.3405833261217683, 1.3635302219051426, 0.19656106954281044, 1.0553637657141577, 0.8640316722860499, -0.34682275265131135, 0.27196141799987644, 0.9785603157742483, -3.1767493003780873, -0.7566904249011203, -1.1935303767007424, 0.2523177522167622, 0.33414675815038736, 0.21147820292953767, 0.2089073521353749, 0.36413859545070626, -0.3145854077725169, 0.8470589609352164, 0.8914477422324714, 0.06602846837066885, -1.1974184866543685, 1.6807019645814638],8 "bird": [0.9317828273959833, -1.142927450658389, 1.1249556341704339, 0.7533022372085103, 0.039221572709652965, 0.5302815428039684, 1.1525754405638204, 0.5707370610821617, 0.01803607760778035, 0.9527229145321762, 1.0851468114908822, -0.4626041548552341, -0.5371489443168416, -0.8343285842461913, -0.09713481034287788, 0.8070233789520264, 0.21755780815430825, -0.6588132708557186, -0.7963193188039507, 0.12395864485237663, -0.18545774404118467, 1.311026289715281, 0.7764007851264465, -0.5776179488468618, 0.5640559901962993],9}10pairs = [["chicken", "eagle"], ["eagle", "pig"], ["chicken", "pig"]]11clues = ["farm", "animal", "bird"]12alpha = 1.0131415def distance(a, b):16 return math.sqrt(sum((x - y) ** 2 for x, y in zip(vectors[a], vectors[b])))171819def sigmoid(x):20 return 1.0 / (1.0 + math.exp(-x))212223def prob_true(clue, word):24 return sigmoid(1.0 / distance(clue, word))252627def meaning_prob(clue, words):28 # meaning(clue, words) is true iff every word is independently true under29 # its Bernoulli; P(true) is the product of the per-word probabilities.30 p = 1.031 for w in words:32 p *= prob_true(clue, w)33 return p343536# Literal listener: pair posterior given the clue is true.37def literal_listener(clue):38 logp = torch.log(torch.tensor([meaning_prob(clue, pairs[i]) for i in range(3)]))3940 @pyro.infer.config_enumerate41 def model():42 sub = pyro.sample("sub", dist.Categorical(torch.ones(3)))43 pyro.factor("cond", logp[sub]) # condition(uttTruthVal): weight by P(true)44 return None4546 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)47 m = marg["sub"]48 sup = m.enumerate_support()49 pr = m.log_prob(sup).exp()50 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}515253# Speaker: clue posterior given the intended subset.54def speaker_logprobs(subset_idx):55 ll = {cl: literal_listener(cl) for cl in clues}5657 @pyro.infer.config_enumerate58 def model():59 c = pyro.sample("clue", dist.Categorical(torch.ones(3)))60 sc = torch.tensor([math.log(ll[clues[i]][subset_idx]) if ll[clues[i]][subset_idx] > 0 else -1e3061 for i in range(3)])62 pyro.factor("sc", alpha * sc[c])63 return None6465 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)66 m = marg["clue"]67 sup = m.enumerate_support()68 pr = m.log_prob(sup).exp()69 return {clues[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}707172# Pragmatic listener: pair posterior given the heard clue.73def pragmatic_listener(clue):74 spk = {i: speaker_logprobs(i) for i in range(3)}75 sc = torch.tensor([math.log(spk[i].get(clue, 0.0)) if spk[i].get(clue, 0.0) > 0 else -1e3076 for i in range(3)])7778 @pyro.infer.config_enumerate79 def model():80 sub = pyro.sample("sub", dist.Categorical(torch.ones(3)))81 pyro.factor("obs", sc[sub])82 return None8384 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)85 m = marg["sub"]86 sup = m.enumerate_support()87 pr = m.log_prob(sup).exp()88 out = {}89 for i, p in zip(sup.tolist(), pr.tolist()):90 out[json.dumps(sorted(pairs[int(i)]))] = p91 return out929394ANSWER = pragmatic_listener("farm")95
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Three objects: 'blue square', 'blue circle', 'green square' (each described by color and shape; object string is the concatenation 'color shape'). The object prior is uniform over these three objects. Seven possible utterances: 'blue', 'green', 'square', 'circle', 'blue square', 'blue circle', 'green square'. An utterance applies to an object if and only if the utterance string is a substring of the object's string. Utterance cost equals the cost parameter times the number of words in the utterance. The cost parameter is drawn uniformly from the 10 values: 0.05, 0.55, 1.05, 1.55, 2.05, 2.55, 3.05, 3.55, 4.05, 4.55. Speaker optimality alpha = 1.
A literal listener conditions on the utterance applying to the object and returns the object. A speaker chooses utterances proportional to exp(alpha times (literal listener log-probability of the object minus the utterance cost)). A pragmatic listener infers the intended object and the speaker's cost parameter jointly by inverting the speaker.
The joint posterior distribution over the intended object and the speaker's cost parameter given that the pragmatic listener hears 'blue'.
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"obj": "string",
"costParameter": "real"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// set of states2var objects = [{color: "blue", shape: "square", string: "blue square"},3 {color: "blue", shape: "circle", string: "blue circle"},4 {color: "green", shape: "square", string: "green square"}]56// prior over world states7var objectPrior = function() {8 var obj = uniformDraw(objects)9 return obj.string10}1112// set of utterances13var utterances = ["blue", "green", "square", "circle",14 "blue square", "blue circle", "green square"]1516// utterance cost function17var cost = function(utterance, costParameter) {18 var numWords = function(utterance) {19 var split = utterance.split(" ")20 return _.size(split)21 }22 return costParameter*numWords(utterance)23};2425var costParameterPrior = function() {26 return uniformDraw(_.range(0.05, 5, 0.5))27}2829// meaning function to interpret the utterances30var meaning = function(utterance, obj){31 _.includes(obj, utterance)32}3334// literal listener35var literalListener = function(utterance){36 Infer({model: function(){37 var obj = objectPrior();38 condition(meaning(utterance, obj))39 return obj40 }})41}4243// set speaker optimality44var alpha = 14546// pragmatic speaker47var speaker = function(obj,costParameter){48 Infer({model: function(){49 var utterance = uniformDraw(utterances)50 factor(alpha * (literalListener(utterance).score(obj) -51 cost(utterance,costParameter)))52 return utterance53 }})54}5556// pragmatic listener57var pragmaticListener = function(utterance){58 Infer({model: function(){59 var obj = objectPrior()60 var costParameter = costParameterPrior()61 observe(speaker(obj,costParameter),utterance)62 return {obj, costParameter}63 }})64}656667display("cost parameter prior")68viz(Infer(costParameterPrior))6970var listenerPosteriorBlue = pragmaticListener("blue")71display("pragmatic listener hears \"blue\"")72viz.table(marginalize(listenerPosteriorBlue, "obj"))73viz(marginalize(listenerPosteriorBlue, "costParameter"))7475var listenerPosteriorBlueSquare = pragmaticListener("blue square")76display("pragmatic listener hears \"blue square\"")77viz.table(marginalize(listenerPosteriorBlueSquare, "obj"))78viz(marginalize(listenerPosteriorBlueSquare, "costParameter"))7980var ANSWER = (pragmaticListener('blue'));
12objects = [3 {"color": "blue", "shape": "square", "string": "blue square"},4 {"color": "blue", "shape": "circle", "string": "blue circle"},5 {"color": "green", "shape": "square", "string": "green square"},6]7obj_strings = [o["string"] for o in objects]8utterances = ["blue", "green", "square", "circle", "blue square", "blue circle", "green square"]9# costParameterPrior: uniformDraw(_.range(0.05, 5, 0.5))10costParams = []11v = 0.0512while v < 5 - 1e-12:13 costParams.append(round(v, 10))14 v += 0.515alpha = 1.016NEG = -1e30171819def num_words(u):20 return len(u.split(" "))212223def cost(u, cp):24 return cp * num_words(u)252627def meaning(u, objstr):28 return u in objstr # _.includes(obj, utterance) over the object string293031# Literal listener: object posterior given a literally-true utterance.32def literal_listener(utterance):33 mask = torch.tensor([0.0 if meaning(utterance, s) else NEG for s in obj_strings])3435 @pyro.infer.config_enumerate36 def model():37 o = pyro.sample("o", dist.Categorical(torch.ones(len(obj_strings))))38 pyro.factor("cond", mask[o])39 return None4041 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)42 m = marg["o"]43 sup = m.enumerate_support()44 pr = m.log_prob(sup).exp()45 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}464748# Speaker: utterance posterior given the object and cost parameter.49def speaker_logprobs(obj_idx, cp):50 ll = {u: literal_listener(u) for u in utterances}5152 @pyro.infer.config_enumerate53 def model():54 u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))55 sc = []56 for uu in utterances:57 p = ll[uu].get(obj_idx, 0.0)58 lp = math.log(p) if p > 0 else NEG59 sc.append(alpha * (lp - cost(uu, cp)))60 scores = torch.tensor(sc)61 pyro.factor("sc", scores[u])62 return None6364 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)65 m = marg["u"]66 sup = m.enumerate_support()67 pr = m.log_prob(sup).exp()68 return {utterances[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}697071# Pragmatic listener: joint posterior over (obj, costParameter) via exact72# enumeration of a single combined latent; the engine marginalizes.73def pragmatic_listener(utterance):74 nO = len(obj_strings)75 nC = len(costParams)76 combos = [(o, c) for o in range(nO) for c in range(nC)]77 joint_prior = torch.ones(len(combos)) # uniform object x uniform cost78 score_t = torch.zeros(len(combos))79 for j, (o, c) in enumerate(combos):80 spk = speaker_logprobs(o, costParams[c])81 p = spk.get(utterance, 0.0)82 score_t[j] = math.log(p) if p > 0 else NEG8384 @pyro.infer.config_enumerate85 def model():86 oc = pyro.sample("oc", dist.Categorical(joint_prior))87 pyro.factor("obs", score_t[oc])88 return None8990 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)91 m = marg["oc"]92 sup = m.enumerate_support()93 pr = m.log_prob(sup).exp()94 out = {}95 for j, p in zip(sup.tolist(), pr.tolist()):96 o, c = combos[int(j)]97 rec = {"obj": obj_strings[o], "costParameter": costParams[c]}98 out[json.dumps(rec, sort_keys=True)] = p99 return out100101102ANSWER = pragmatic_listener("blue")103
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
The prevalence space is a grid of 40 values running from 0.01 to 0.985 in steps of 0.025: 0.01, 0.035, 0.060, 0.085, 0.110, 0.135, 0.160, 0.185, 0.210, 0.235, 0.260, 0.285, 0.310, 0.335, 0.360, 0.385, 0.410, 0.435, 0.460, 0.485, 0.510, 0.535, 0.560, 0.585, 0.610, 0.635, 0.660, 0.685, 0.710, 0.735, 0.760, 0.785, 0.810, 0.835, 0.860, 0.885, 0.910, 0.935, 0.960, 0.985. A discretized Beta distribution over this grid with mean parameter g and concentration parameter d assigns each grid value x a weight proportional to x^(g*d - 1) * (1-x)^((1-g)*d - 1). The prior over prevalence is a mixture of two such discretized Beta components: with probability 0.3, the component with g=0.5 and d=10; otherwise the component with g=0.01 and d=100.
Prevalence is drawn from a two-component mixture over the grid. One component concentrates mass near zero (a near-absent component with very low mean and high concentration). The other places mass according to a discretized Beta with the specified mean and concentration (the present component). The mixture weight is the prior probability that the property is present at all.
The marginal prior distribution over the numeric prevalence values.
answer spec
{
"kind": "dist",
"domain": "real"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold:2// discretized range between 0 - 13var bins = _.range(0.01, 1, 0.025);45// function returns a discretized Beta distribution6var DiscreteBeta = cache(function(g, d){7 var a = g * d, b = (1-g) * d;8 var betaPDF = function(x){9 return Math.pow(x, a-1)*Math.pow((1-x), b-1)10 }11 var probs = map(betaPDF, bins);12 return Categorical({vs: bins, ps: probs})13})14///15var priorModel = function(params){16 Infer({model: function(){1718 var potential = params["potential"]19 var g = params["prevalenceWhenPresent"]20 var d = params["concentrationWhenPresent"]2122 var StableDistribution = DiscreteBeta(g, d)23 var UnstableDistribution = DiscreteBeta(0.01, 100)2425 var prevalence = flip(potential) ?26 sample(StableDistribution) :27 sample(UnstableDistribution)2829 return {prevalence}3031 }})32}3334var d = priorModel({potential: 0.3, prevalenceWhenPresent: 0.5, concentrationWhenPresent: 10});35var ANSWER = marginalize(d, 'prevalence');36
12# grid: 0.01 to 0.985 step 0.025 (40 values), matching _.range(0.01, 1, 0.025)3bins = [round(0.01 + 0.025 * i, 10) for i in range(40)]45def discrete_beta_probs(g, d):6 a = g * d7 b = (1 - g) * d8 raw = [(x ** (a - 1)) * ((1 - x) ** (b - 1)) for x in bins]9 Z = sum(raw)10 return torch.tensor([r / Z for r in raw])1112stable = discrete_beta_probs(0.5, 10) # present component13unstable = discrete_beta_probs(0.01, 100) # near-absent component14potential = 0.31516@pyro.infer.config_enumerate17def model():18 present = pyro.sample("present", dist.Bernoulli(potential)).long()19 probs = torch.where(present.bool().unsqueeze(-1), stable, unstable)20 idx = pyro.sample("idx", dist.Categorical(probs))21 return idx2223marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(24 model, lambda: None25)["idx"]26sup = marg.enumerate_support()27probs = torch.exp(marg.log_prob(sup))28ANSWER = {bins[int(i)]: float(p) for i, p in zip(sup, probs)}29
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (w1) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Prevalence is discretized into 50 bins: values 0.01, 0.03, 0.05, ..., 0.99 (i.e., starting at 0.01 with step 0.02, rounded to 2 decimal places). The threshold for the generic is drawn uniformly from the 49 midpoints between consecutive bins: 0.02, 0.04, ..., 0.98. A Beta-mixture prior over prevalence: with probability 0.3 (the 'potential' parameter), the prevalence is drawn from a discretized Beta with mean 0.5 and concentration 10 (g=0.5, d=10, so a=g*d=5, b=(1-g)*d=5); otherwise, it is drawn from a discretized Beta with g=0.01 and d=100 (a very low-prevalence unstable distribution with a=1, b=99). The discrete Beta weight for bin x is proportional to x^(a-1) * (1-x)^(b-1).
The generic statement is true if the prevalence exceeds the threshold. A literal listener draws a prevalence from the prior and a threshold uniformly, conditions on the generic being true, and returns the prevalence.
The posterior distribution over prevalence given that the literal listener hears a generic statement.
answer spec
{
"kind": "dist",
"domain": "real"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold:2// discretized range between 0 - 13var bins = map(function(x){4 _.round(x, 2);5}, _.range(0.01, 1, 0.02));67var thresholdBins = map2(function(x,y){8 var d = (y - x)/ 2;9 return x + d10}, bins.slice(0, bins.length - 1), bins.slice(1, bins.length))1112// function returns a discretized Beta distribution13var DiscreteBeta = cache(function(g, d){14 var a = g * d, b = (1-g) * d;15 var betaPDF = function(x){16 return Math.pow(x, a-1)*Math.pow((1-x), b-1)17 }18 var probs = map(betaPDF, bins);19 return Categorical({vs: bins, ps: probs})20})2122var priorModel = function(params){23 Infer({model: function(){2425 var potential = params["potential"]26 var g = params["prevalenceWhenPresent"]27 var d = params["concentrationWhenPresent"]2829 var StableDistribution = DiscreteBeta(g, d)30 var UnstableDistribution = DiscreteBeta(0.01, 100)3132 var prevalence = flip(potential) ?33 sample(StableDistribution) :34 sample(UnstableDistribution)3536 return prevalence3738 }})39}40///41var meaning = function(utterance, prevalence, threshold) {42 return (utterance == 'generic') ? prevalence > threshold : true43}44var thresholdPrior = function() { return uniformDraw(thresholdBins) };4546var statePrior = priorModel({47 potential: 0.3,48 prevalenceWhenPresent: 0.5, // how prevalent under the stable cause49 concentrationWhenPresent: 10 // the inverse-variance of the stable cause50})5152display("prevalence prior")53viz(statePrior)5455var listener = cache(function(utterance) {56 Infer({model: function(){57 var prevalence = sample(statePrior)58 var threshold = thresholdPrior()59 var m = meaning(utterance, prevalence, threshold)60 condition(m)61 return prevalence62 }})63})6465display("listener posterior")66listener("generic")6768var ANSWER = (listener('generic'));
1# discretized prevalence bins: 0.01, 0.03, ..., 0.992bins = [round(0.01 + 0.02 * k, 2) for k in range(50)]3bins_t = torch.tensor(bins)4# threshold midpoints between consecutive bins5threshold_bins = [round((bins[i] + bins[i + 1]) / 2.0, 10) for i in range(len(bins) - 1)]6threshold_t = torch.tensor(threshold_bins)78def discrete_beta_probs(g, d):9 a = g * d10 b = (1.0 - g) * d11 w = torch.pow(bins_t, a - 1.0) * torch.pow(1.0 - bins_t, b - 1.0)12 return w / w.sum()1314potential = 0.315stable = discrete_beta_probs(0.5, 10.0)16unstable = discrete_beta_probs(0.01, 100.0)17# mixture prior over prevalence bins18prior_probs = potential * stable + (1.0 - potential) * unstable1920@pyro.infer.config_enumerate21def model():22 prev_i = pyro.sample("prevalence", dist.Categorical(prior_probs))23 thr_i = pyro.sample("threshold", dist.Categorical(torch.ones(len(threshold_bins))))24 prevalence = bins_t[prev_i]25 threshold = threshold_t[thr_i]26 # generic is true iff prevalence > threshold27 on = prevalence > threshold28 pyro.factor("meaning", torch.where(on, torch.tensor(0.0), torch.tensor(float("-inf"))))29 return prev_i3031marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)32m = marg["prevalence"]33sup = m.enumerate_support()34probs = m.log_prob(sup).exp()35ANSWER = {float(bins_t[int(sup[i].item())]): float(probs[i]) for i in range(len(sup)) if float(probs[i]) > 0}36
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (w1) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Height distributions are discretized using binParam = 3. The superordinate category (all people) has height distribution Gaussian(mu=0, sigma=1). The basketball player subcategory has Gaussian(mu=1, sigma=0.5). State values are 18 evenly spaced points from -3 to 3 (exclusive) in steps of 1/3. State probabilities are proportional to the Gaussian PDF evaluated at each state value. Thresholds for 'tall' are each state value minus 1/(2*binParam); thresholds for 'short' are each state value plus 1/(2*binParam); each threshold set is drawn uniformly. Comparison class is drawn uniformly from {superordinate, subordinate}. Speaker optimality alpha = 5. Three utterances: 'tall', 'short', 'silence' (silence is always true).
A literal listener conditions on the utterance being true (tall: state exceeds the tall threshold; short: state is below the short threshold; silence: always true) given the state distribution for the current comparison class. A speaker chooses utterances proportional to exp(alpha times the literal listener's log-probability of the state). A pragmatic listener infers the entity's state, the tall threshold, the short threshold, and the comparison class jointly by inverting the speaker, conditioning on the utterance. The thresholds are drawn from the same uniform priors as in the given; the same threshold pair is passed to the speaker and through to the literal listener.
The marginal posterior distribution over the comparison class (superordinate vs. subordinate) for a pragmatic listener who hears 'tall' about a basketball player.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"superordinate",
"subordinate"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold:2// helper function3var exp = function(x){return Math.exp(x)}45// for discretization6var binParam = 3;78// information about the superordinate category prior9// e.g., the height distribution for all people10var superordinate_params = {mu: 0, sigma: 1};1112// calculate the range in pre-defined steps;13// these values correspond to possible heights14var stateVals = _.range(superordinate_params.mu - 3 * superordinate_params.sigma,15 superordinate_params.mu + 3 * superordinate_params.sigma,16 superordinate_params.sigma/binParam)1718// for each possible height, calculate its probability of occurrence19var stateProbs = cache(function(stateParams){20 return map(function(s){21 exp(Gaussian(stateParams).score(s))22 }, stateVals)23});2425// generate a statePrior using the possible heights and their probabilities26var generateStatePrior = cache(function(stateParams) {27 return Infer({28 model: function(){29 return categorical({vs: stateVals, ps: stateProbs(stateParams)})30 }31 })32});3334// generate the uniform threshold prior35var thresholdBins ={36 positive: map(function(x){37 return x - (1/(binParam*2));38 }, sort(stateVals)),39 negative: map(function(x){40 return x + (1/(binParam*2));41 }, sort(stateVals))42};4344var thresholdPrior = cache(function(form){45 return Infer({46 model: function() { return uniformDraw(thresholdBins[form]) }47 });48});49///5051// information about the superordinate category priors52var subParams = {53 gymnasts: {mu: -1, sigma: 0.5}, // gymnast heights54 soccerPlayers: {mu: 0, sigma: 0.5}, // soccer player heights55 basketballPlayers: {mu: 1, sigma: 0.5} // basketball player heights56}5758// possible utterances can be either positive (tall) or negative (short) or a null utterance59var utterances = ["tall", "short", "silence"]6061// meaning function for utterances62var meaning = function(utterance, state, thresholds) {63 utterance == "tall" ? state > thresholds.tall :64 utterance == "short" ? state < thresholds.short :65 true66}6768// assume a uniform prior over comparison classes69var classPrior = Infer({70 model: function(){return uniformDraw(["subordinate", "superordinate"])}71});7273// set speaker optimality74var alpha = 5;7576var literalListener = cache(77 function(utterance, thresholds, comparisonClass) {78 Infer({model: function(){79 var StatePrior = generateStatePrior(comparisonClass)80 var state = sample(StatePrior);81 var m = meaning(utterance, state, thresholds);82 condition(m);83 return state;84 }})85 }, 10000 // limit cache size86)8788var speaker1 = cache(89 function(state, thresholds, comparisonClass) {90 Infer({model: function(){91 var utterance = uniformDraw(utterances);92 var L0 = literalListener(utterance, thresholds, comparisonClass);93 factor( alpha * L0.score(state) );94 return utterance;95 }})96 }, 10000 // limit cache size97)9899var pragmaticListener = cache(function(utterance, subordinate_params) {100 Infer({model: function(){101102 var statePrior = generateStatePrior(subordinate_params);103 var state = sample(statePrior);104 // separate thresholds for positive adjective and negative adjective105 var thresholds = {106 tall: sample(thresholdPrior("positive")),107 short: sample(thresholdPrior("negative"))108 }109110 // uncertainty about the comparison class (superordinate vs. subordinate)111 var c = sample(classPrior)112 var comparisonClass = c == "subordinate" ? subordinate_params : superordinate_params113114 var S1 = speaker1(state, thresholds, comparisonClass);115 observe(S1, utterance);116117 return { comparisonClass: c, state : state }118 }})119}, 10000 // limit cache size120 )121122// the possible experiment conditions:123// you hear that someone is a member of a subordinate category124// then you are told that they are tall/short;125// the task is to figure out the implicit comparison class126var exptConditions = [127 {utt: "tall", sub: "basketballPlayers"},128 {utt: "short", sub: "basketballPlayers"},129 {utt: "tall", sub: "soccerPlayers"},130 {utt: "short", sub: "soccerPlayers"},131 {utt: "tall", sub: "gymnasts"},132 {utt: "short", sub: "gymnasts"}133];134135// generate structure predictions by mapping through the experiment conditions136var L1predictions = map(function(stim){137 var L1posterior = pragmaticListener(stim.utt, subParams[stim.sub])138 return {139 utterance: stim.utt,140 "P(superordinate comparison class)": exp(marginalize(L1posterior, "comparisonClass").score("superordinate")),141 "subordinate category": stim.sub,142 model: "L1"143 }144}, exptConditions)145146display("the basketball player is tall")147display("--> height = " + expectation(marginalize(pragmaticListener("tall",{mu: 1, sigma: 0.5}), "state")))148display("the basketball player is short")149display("--> height = " + expectation(marginalize(pragmaticListener("short",{mu: 1, sigma: 0.5}), "state")))150151display("probability of superordinate comparison class (i.e., tall for all people)")152viz.bar(L1predictions, {groupBy: "subordinate category"})153154var ANSWER = (marginalize(pragmaticListener("tall", subParams["basketballPlayers"]), "comparisonClass"));
12binParam = 33sup_mu, sup_sigma = 0.0, 1.04sub_mu, sub_sigma = 1.0, 0.5 # basketballPlayers5step = sup_sigma / binParam6stateVals = []7v = sup_mu - 3 * sup_sigma8while v < sup_mu + 3 * sup_sigma - 1e-10:9 stateVals.append(round(v, 10))10 v += step11nS = len(stateVals)121314def gaussian_pdf(x, mu, sig):15 return math.exp(-0.5 * ((x - mu) / sig) ** 2) / (sig * math.sqrt(2 * math.pi))161718def state_probs(mu, sig):19 ps = [gaussian_pdf(s, mu, sig) for s in stateVals]20 t = sum(ps)21 return torch.tensor([p / t for p in ps])222324sorted_sv = sorted(stateVals)25tall_thr = [s - 1.0 / (binParam * 2) for s in sorted_sv]26short_thr = [s + 1.0 / (binParam * 2) for s in sorted_sv]27nT = len(tall_thr)28utterances = ["tall", "short", "silence"]29alpha = 5.030classes = ["subordinate", "superordinate"]31class_params = [(sub_mu, sub_sigma), (sup_mu, sup_sigma)]32class_sp = [state_probs(*class_params[ci]) for ci in range(2)]333435def meaning(utt, state, tall_t, short_t):36 if utt == "tall":37 return state > tall_t38 if utt == "short":39 return state < short_t40 return True414243# Literal listener: state posterior given utterance, thresholds, comparison class.44def literal_listener(ui, ti, shi, ci):45 sp = class_sp[ci]46 mask = torch.tensor([0.0 if meaning(utterances[ui], stateVals[s], tall_thr[ti], short_thr[shi])47 else -1e30 for s in range(nS)])4849 @pyro.infer.config_enumerate50 def model():51 st = pyro.sample("st", dist.Categorical(sp))52 pyro.factor("cond", mask[st])53 return None5455 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)56 m = marg["st"]57 sup = m.enumerate_support()58 pr = m.log_prob(sup).exp()59 out = [0.0] * nS60 for i, p in zip(sup.tolist(), pr.tolist()):61 out[int(i)] = p62 return out636465def pragmatic_listener(ui):66 sub_sp = state_probs(sub_mu, sub_sigma)67 # L0 table, cached per (class, thresholds): each level genuinely inferred.68 ll = {}69 for ci in range(2):70 for ti in range(nT):71 for shi in range(nT):72 ll[(ci, ti, shi)] = [literal_listener(u, ti, shi, ci) for u in range(3)]7374 # Speaker (S1): utterance posterior given (state, thresholds, class), inferred.75 score = torch.full((2, nS, nT, nT), -1e30)76 for ci in range(2):77 for ti in range(nT):78 for shi in range(nT):79 l0 = ll[(ci, ti, shi)]80 for si in range(nS):81 sc = torch.tensor([alpha * (math.log(l0[u][si]) if l0[u][si] > 0 else -1e30)82 for u in range(3)])8384 @pyro.infer.config_enumerate85 def smodel():86 u = pyro.sample("u", dist.Categorical(torch.ones(3)))87 pyro.factor("sc", sc[u])88 return None8990 sm = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(smodel, lambda: None)91 mm = sm["u"]92 ssup = mm.enumerate_support()93 spr = mm.log_prob(ssup).exp()94 for i, p in zip(ssup.tolist(), spr.tolist()):95 if int(i) == ui:96 score[ci, si, ti, shi] = math.log(p) if p > 0 else -1e309798 # Pragmatic listener (L1): comparison-class marginal, inferred by enumeration.99 @pyro.infer.config_enumerate100 def model():101 st = pyro.sample("st", dist.Categorical(sub_sp))102 tt = pyro.sample("tt", dist.Categorical(torch.ones(nT)))103 sh = pyro.sample("sh", dist.Categorical(torch.ones(nT)))104 cl = pyro.sample("cl", dist.Categorical(torch.ones(2)))105 pyro.factor("obs", score[cl, st, tt, sh])106 return None107108 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)109 m = marg["cl"]110 sup = m.enumerate_support()111 pr = m.log_prob(sup).exp()112 return {classes[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}113114115ANSWER = pragmatic_listener(0) # "tall"116
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 1/2 solvers · d=[0.000, 0.115] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Weather states: ["terrible", "bad", "ok", "good", "amazing"]. State prior: categorical with unnormalized weights [1, 5, 40, 40, 40] over those states (California context — benign weather is far more likely). Utterances are the same five state labels, drawn uniformly. Valence prior given state: P(valence = -1 | terrible) = 0.99, P(valence = -1 | bad) = 0.90, P(valence = -1 | ok) = 0.50, P(valence = -1 | good) = 0.09, P(valence = -1 | amazing) = 0.01; valence is +1 otherwise. Arousal values: [0.1, 0.3, 0.5, 0.7, 0.9]. Arousal prior given state (categorical with unnormalized weights over those values): terrible → [1,10,30,45,50], bad → [1,5,25,40,45], ok → [50,45,30,10,1], good → [1,5,25,40,45], amazing → [1,10,30,45,50]. Goals: ["goalState", "goalValence", "goalArousal"], drawn with equal probability. A speaker's goal is satisfied by the listener inferring the state (for goalState), the valence (for goalValence), or the arousal (for goalArousal). The literal interpretation of an utterance is true iff the utterance label equals the world state. Speaker rationality weight: 1.
A pragmatic listener hears an utterance and jointly infers the world state, the speaker's valence, and the speaker's arousal. The listener's prior over (state, valence, arousal, goal) is the product of the independent marginal priors. A literal listener, given an utterance and a goal, conditions on the utterance being literally true of the state and returns the goal-relevant quantity (state, valence, or arousal). A speaker, given (state, valence, arousal, goal), chooses an utterance with probability proportional to how well a literal listener with that goal would recover the goal-relevant quantity. The pragmatic listener observes the speaker's chosen utterance and marginalizes over goals, returning the joint (state, valence, arousal).
The posterior joint distribution over (state, valence, arousal) for a pragmatic listener who hears the utterance "terrible".
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"state": "string",
"valence": "int",
"arousal": "real"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// There are three possible states the weather could be in:2// terrible, ok, or amazing3var states = ['terrible','bad','ok','good','amazing']45// Since we are in California, the prior over these states6// are the following. Once could also imagine this being7// the prior in a certain context, e.g. when it's clearly8// sunny and nice out.9var statePrior = function() {10 categorical([1,5,40,40,40], states)11}1213// Valence prior defined in terms of negative valence.14// If the current state is terrible, it's extremely likely15// that the valence associated is negative. If it's ok, then16// the valence could be negative or positive with equal17// probability.18var valencePrior = function(state) {19 state === "terrible" ? flip(0.99) ? -1 : 1 :20 state === "bad" ? flip(0.90) ? -1 : 1 :21 state === "ok" ? flip(0.5) ? -1 : 1 :22 state === "good" ? flip(0.09) ? -1 : 1 :23 state === "amazing" ? flip(0.01) ? -1 : 1 :24 true25}2627// Define binary arousals (could model as continuous).28// var arousals = ["low", "high"]29var arousals = [.1,.3,.5,.7,.9]3031// Define goals and goal priors. Could want to communicate state of the world,32// valence about it, or arousal (intensity of feeling) about it.33var goals = ["goalState", "goalValence", "goalArousal"]3435var goalPrior = function() {36 categorical([1, 1, 1], goals)37}3839// Assume possible utterances are identical to possible states40var utterances = states4142// Assume cost of utterances is uniform.43var utterancePrior = function() {44 uniformDraw(utterances)45}4647// Sample arousal given a state.48var arousalPrior = function(state) {49 state === "terrible" ? categorical([1,10,30,45,50], arousals) :50 state === "bad" ? categorical([1,5,25,40,45], arousals) :51 state === "ok" ? categorical([50,45,30,10,1], arousals) :52 state === "good" ? categorical([1,5,25,40,45], arousals) :53 state === "amazing" ? categorical([1,10,30,45,50], arousals) :54 true55}5657// Literal interpretation is just whether utterance equals state58var literalInterpretation = function(utterance, state) {59 utterance === state60}6162// A speaker's goal is satisfied if the listener infers the correct63// and relevant information.64var goalState = function(goal, state, valence, arousal) {65 goal === "goalState" ? state :66 goal === "goalValence" ? valence :67 goal === "goalArousal" ? arousal :68 true69}7071// Define a literal listener72var literalListener = function(utterance, goal) {73 Infer({model: function(){74 var state = uniformDraw(states)75 var valence = valencePrior(state)76 var arousal = arousalPrior(state)77 condition(literalInterpretation(utterance,state))78 return goalState(goal, state, valence, arousal)79 }})80}8182//The speaker takes in a state, valence, arousal, and a goal and returns an utterance83//based on the probability of the literalListener arriving at the correct84//state given a goalState85var speaker = function(state, valence, arousal, goal) {86 Infer({model: function(){87 var utterance = utterancePrior()88 factor(1 * literalListener(utterance,89 goal).score(goalState(goal,90 state,91 valence,92 arousal)))93 return utterance94 }})95}9697// Define a pragmatic listener98var pragmaticListener = function(utterance) {99 Infer({model: function(){100 var state = statePrior()101 var valence = valencePrior(state)102 var arousal = arousalPrior(state)103 var goal = goalPrior()104 observe(speaker(state, valence, arousal, goal),utterance)105 return {state,valence, arousal}106 }})107}108109viz.table(pragmaticListener('terrible'))110111var ANSWER = (pragmaticListener("terrible"));
12NEG = -1e303states = ["terrible", "bad", "ok", "good", "amazing"]4state_prior = torch.tensor([1.0, 5.0, 40.0, 40.0, 40.0])5val_neg_p = {"terrible": 0.99, "bad": 0.90, "ok": 0.5, "good": 0.09, "amazing": 0.01}6arousals = [.1, .3, .5, .7, .9]7arousal_w = {"terrible": [1, 10, 30, 45, 50], "bad": [1, 5, 25, 40, 45],8 "ok": [50, 45, 30, 10, 1], "good": [1, 5, 25, 40, 45],9 "amazing": [1, 10, 30, 45, 50]}10goals = ["goalState", "goalValence", "goalArousal"]11goal_prior = torch.tensor([1.0, 1.0, 1.0])12utterances = states[:]13valences = [-1, 1]141516def goal_value(goal, st_idx, val, ar):17 if goal == "goalState":18 return ("state", states[st_idx])19 if goal == "goalValence":20 return ("val", val)21 return ("ar", ar)222324# Literal listener: distribution over the goal-relevant quantity, inferred.25def literal_listener(utt_idx, goal):26 @pyro.infer.config_enumerate27 def model():28 st = pyro.sample("st", dist.Categorical(torch.ones(len(states))))29 vp = torch.stack([torch.tensor([val_neg_p[s], 1 - val_neg_p[s]]) for s in states])30 val = pyro.sample("val", dist.Categorical(vp[st]))31 ap = torch.stack([torch.tensor([float(x) for x in arousal_w[s]]) / sum(arousal_w[s])32 for s in states])33 ar = pyro.sample("ar", dist.Categorical(ap[st]))34 # literalInterpretation: utterance === state35 pyro.factor("cond", torch.where(st == utt_idx, torch.tensor(0.0), torch.tensor(NEG)))36 return None3738 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)39 if goal == "goalState":40 m = marg["st"]41 sup = m.enumerate_support()42 pr = m.log_prob(sup).exp()43 return {("state", states[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}44 if goal == "goalValence":45 m = marg["val"]46 sup = m.enumerate_support()47 pr = m.log_prob(sup).exp()48 return {("val", valences[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}49 m = marg["ar"]50 sup = m.enumerate_support()51 pr = m.log_prob(sup).exp()52 return {("ar", arousals[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}535455spk_cache = {}565758# Speaker: utterance posterior given (state, valence, arousal, goal), inferred.59def speaker_logprobs(st_idx, val, ar, goal):60 gv = goal_value(goal, st_idx, val, ar)61 key = (gv, goal)62 if key in spk_cache:63 return spk_cache[key]64 l0 = {u: literal_listener(u, goal) for u in range(len(utterances))}6566 @pyro.infer.config_enumerate67 def model():68 u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))69 sc = torch.tensor([math.log(l0[uu].get(gv, 0.0)) if l0[uu].get(gv, 0.0) > 0 else NEG70 for uu in range(len(utterances))])71 pyro.factor("sc", 1.0 * sc[u])72 return None7374 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)75 m = marg["u"]76 sup = m.enumerate_support()77 pr = m.log_prob(sup).exp()78 res = {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}79 spk_cache[key] = res80 return res818283# Pragmatic listener: joint posterior over (state, valence, arousal) via exact84# enumeration of a single combined latent; the engine marginalizes.85def pragmatic_listener(utt_idx):86 combos = []87 for s in range(len(states)):88 for vi in range(2):89 for ai in range(len(arousals)):90 combos.append((s, vi, ai))91 sp_norm = (state_prior / state_prior.sum()).tolist()92 jp = []93 for (s, vi, ai) in combos:94 st = states[s]95 pv = val_neg_p[st] if vi == 0 else (1 - val_neg_p[st])96 aw = arousal_w[st]97 pa = aw[ai] / sum(aw)98 jp.append(sp_norm[s] * pv * pa)99 jp = torch.tensor(jp)100 score = torch.full((len(combos), 3), NEG)101 for j, (s, vi, ai) in enumerate(combos):102 val = valences[vi]103 ar = arousals[ai]104 for g in range(3):105 spk = speaker_logprobs(s, val, ar, goals[g])106 p = spk.get(utt_idx, 0.0)107 score[j, g] = math.log(p) if p > 0 else NEG108109 @pyro.infer.config_enumerate110 def model():111 cf = pyro.sample("cf", dist.Categorical(jp))112 goal = pyro.sample("goal", dist.Categorical(goal_prior))113 pyro.factor("obs", score[cf, goal])114 return None115116 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)117 m = marg["cf"]118 sup = m.enumerate_support()119 pr = m.log_prob(sup).exp()120 out = {}121 for j, p in zip(sup.tolist(), pr.tolist()):122 s, vi, ai = combos[int(j)]123 rec = {"state": states[s], "valence": valences[vi], "arousal": arousals[ai]}124 out[json.dumps(rec, sort_keys=True)] = p125 return out126127128ANSWER = pragmatic_listener(0) # "terrible"129
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Discretization: 30 box-weight values evenly spaced from 0.0 to 5.8 in steps of 0.2 (i.e., 0.0, 0.2, 0.4, …, 5.8). The state distribution for a given comparison class is obtained by computing the unnormalized Gaussian density at each of those 30 values and treating the results as categorical weights. Speaker-class Gaussian parameters: child (mu=0.5, sigma=1), adult (mu=2, sigma=3), bodybuilder (mu=5, sigma=3). The superordinate category has mu=3, sigma=1. Utterances: ["heavy", "light"]. Threshold sets: the "heavy" threshold set is obtained by subtracting 0.1 from each of the 30 state values (giving −0.1, 0.1, …, 5.7); the "light" threshold set by adding 0.1 to each (giving 0.1, 0.3, …, 5.9). Meaning: "heavy" is true with probability 0.9999 when state > the heavy threshold and 0.0001 otherwise; "light" is true with probability 0.9999 when state < the light threshold and 0.0001 otherwise. Speaker optimality alpha = 5. Comparison-class prior given the identity of the speaker: child → categorical([0.75, 0.25, 0.15], [child, adult, bodybuilder]); adult → categorical([0.01, 0.70, 0.50]); bodybuilder → categorical([0.0001, 0.20, 0.99]).
A pragmatic listener hears an adjective uttered by a speaker of known identity and infers the box's weight. The listener jointly samples a comparison class from the prior conditioned on the speaker's identity, then samples a weight from that class's Gaussian-discretized distribution, and independently samples a "heavy" threshold uniformly from the heavy threshold set and a "light" threshold uniformly from the light threshold set. The same two thresholds are passed to the speaker and through to the literal listener. A literal listener conditions on the soft meaning function. A speaker with a known weight, both thresholds, and comparison class chooses an utterance with probability proportional to exp(alpha × literal-listener score). The pragmatic listener observes the speaker's utterance and returns the marginal distribution over box weight.
The posterior distribution over box weight (one of the 30 values from 0.0 to 5.8) for a pragmatic listener who hears "heavy" uttered by a child.
answer spec
{
"kind": "dist",
"domain": "real"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// helper function2var exp = function(x){return Math.exp(x)}34// helper function5var marginalize = function(dist, key){6 return Infer({model: function(){sample(dist)[key]}})7}8// for discretization9var binParam = 5;1011//my model:12var superordinate = {mu: 3, sigma: 1};1314//a list of possible box weights (state values)15var stateVals = _.range(superordinate.mu - 3 * superordinate.sigma,16 superordinate.mu + 3 * superordinate.sigma,17 superordinate.sigma/binParam)181920// for each possible weight, calculate its probability of occurrence21var stateProbs = function(stateParams){22 return map(function(s){23 exp(Gaussian(stateParams).score(s))24 }, stateVals)25};262728// generate a statePrior using the possible weights and their probabilities29var generateStatePrior = cache(function(stateParams) {30 return Infer({31 model: function(){32 return categorical({vs: stateVals, ps: stateProbs(stateParams)})33 }34 })35});3637// information about the category priors38var speakerParams = {39 child: {mu: 0.5, sigma: 1}, // child experience with weights40 adult: {mu: 2, sigma: 3}, // adult experiance with weights41 bodybuilder: {mu:5, sigma:3}, // body builder experience with weights42}434445// generate the uniform threshold prior4647var thresholdBins = cache(function(utterance, stateSupport){48 return map(function(x){49 return utterance == "heavy" ? x - (1/(binParam*2)) : x + (1/(binParam*2));50 }, sort(stateSupport))51},1000052)5354var thresholdPrior = cache(function(utterance, stateSupport){55 return Infer({56 model: function() { return uniformDraw(thresholdBins(utterance, stateSupport)) }57 });58},1000059);60616263var utterances = ["heavy","light"]646566// meaning function for utterances67var meaning = function(utterance, state, thresholdHeavy, thresholdLight) {68 utterance == "heavy" ? state > thresholdHeavy ? flip(0.9999) : flip(0.0001) :69 utterance == "light" ? state < thresholdLight ? flip(0.9999) : flip(0.0001) :70 true71}7273747576// set sepeaker optimality77var alpha = 5;787980var literalListener =cache(81 function(utterance, thresholdHeavy, thresholdLight,comparisonClass){82 Infer({model: function(){83 var state = sample(generateStatePrior(speakerParams[comparisonClass]))84 var m = meaning(utterance, state, thresholdHeavy, thresholdLight);85 condition(m);86 return state;87 }})88 },1000089)909192//literalListener("light", 4, 2, "adult")939495//my model:96var speaker1 = cache(97 function(state, thresholdHeavy, thresholdLight,comparisonClass){98 Infer({model: function(){99 var utterance = uniformDraw(utterances);100 var L0 = literalListener(utterance, thresholdHeavy, thresholdLight, comparisonClass);101 factor( alpha * L0.score(state) );102 return utterance;103 }})104 },10000105)106107108// generateStatePrior(speakerParams["child"]) .support()109110var comparisonClasses = ["child","adult","bodybuilder"]111var comparisonClassPrior = function(whoSaidIt) {112 whoSaidIt == "child" ? categorical([0.75, 0.25,0.15],comparisonClasses):113 whoSaidIt == "adult" ? categorical([0.01,0.7,0.5],comparisonClasses):114 categorical([0.0001,0.2, 0.99],comparisonClasses)115116117}118119var pragmaticListener = cache(function(utterance,whoSaidIt){120 Infer({model: function(){121 var CC = comparisonClassPrior(whoSaidIt);122 var statePrior = generateStatePrior(speakerParams[CC]);123 var state = sample(statePrior);124 var thresholdHeavy = sample(thresholdPrior("heavy", statePrior.support()))125 var thresholdLight = sample(thresholdPrior("light", statePrior.support()))126 var S1 = speaker1(state, thresholdHeavy,thresholdLight,CC);127 observe(S1, utterance);128 return (state)129}, method:"enumerate"})130},10000131)132133134135// pragmaticListener("heavy","adult")136display("Listener's interpretation after hearing a child saying 'this box is heavy'")137viz.density(pragmaticListener("heavy","child"))138display("Listener's interpretation after hearing an adult saying 'this box is heavy'")139viz.density(pragmaticListener("heavy","adult"))140display("Listener's interpretation after hearing a bodybuilder saying 'this box is heavy'")141viz.density(pragmaticListener("heavy","bodybuilder"))142display("Listener's interpretation after hearing a child saying 'this box is light'")143viz.density(pragmaticListener("light","child"))144display("Listener's interpretation after hearing an adult saying 'this box is light'")145viz.density(pragmaticListener("light","adult"))146display("Listener's interpretation after hearing a bodybuilder saying 'this box is light'")147viz.density(pragmaticListener("light","bodybuilder"))148149var ANSWER = (pragmaticListener("heavy", "child"));
1# RSA comparison-class model (kachakeche), pragmatic listener hearing a CHILD say2# 'heavy'. Every RSA level is run through Pyro's exact discrete enumeration3# (config_enumerate + TraceEnum_ELBO.compute_marginals). The literal listener and4# the speaker are inferred per configuration but vectorized across all configs with a5# pyro.plate, so each level is a genuine Pyro inference whose normalized posterior6# comes from Pyro's engine; lower-level log-scores enter the next level only as7# pyro.factor terms. No level's distribution is normalized by hand.89binParam = 510super_mu, super_sigma = 3.0, 1.011# state values: range(mu-3sigma, mu+3sigma, sigma/binParam) = range(0, 6, 0.2) -> 30 values12NS = 3013stateVals = [round(super_mu - 3 * super_sigma + i * (super_sigma / binParam), 10)14 for i in range(NS)]15stateVals_t = torch.tensor(stateVals)1617comparisonClasses = ["child", "adult", "bodybuilder"]18speakerParams = {"child": (0.5, 1.0), "adult": (2.0, 3.0), "bodybuilder": (5.0, 3.0)}19NC = len(comparisonClasses)20alpha = 5.021half = 1.0 / (binParam * 2) # 0.12223# threshold sets (stateVals are already sorted ascending):24# heavy = state - 0.1 ; light = state + 0.125heavy_thr = stateVals_t - half # [NS]26light_thr = stateVals_t + half # [NS]2728LOG_HI = math.log(0.9999)29LOG_LO = math.log(0.0001)303132def state_prior_probs(mu, sigma):33 w = dist.Normal(mu, sigma).log_prob(stateVals_t).exp()34 return w / w.sum()353637# Per-class state prior table [NC, NS].38state_prior_table = torch.stack(39 [state_prior_probs(*speakerParams[cc]) for cc in comparisonClasses]40)414243# ---- Literal listener L0, run through Pyro exact enumeration. For each (comparison44# class, threshold) the model enumerates the state under that class's prior and45# conditions on the soft meaning (flip(0.9999) if the utterance is literally true of46# the state else flip(0.0001)); compute_marginals returns the exact state posterior.47# All NC*NS configs are inferred together via a single plated enumeration. ----48def run_L0(thr_vec, is_heavy):49 K = NC * NS # configs over (cc, threshold)50 prior = state_prior_table.unsqueeze(1).expand(NC, NS, NS).reshape(K, NS) # [K, NS]51 thr_b = thr_vec.unsqueeze(0).expand(NC, NS).reshape(K) # [K]52 state_b = stateVals_t.unsqueeze(0) # [1, NS]53 if is_heavy:54 holds = state_b > thr_b.unsqueeze(1) # heavy: state > thresholdHeavy55 else:56 holds = state_b < thr_b.unsqueeze(1) # light: state < thresholdLight57 fac = torch.where(holds, torch.tensor(LOG_HI), torch.tensor(LOG_LO)) # [K, NS]5859 @pyro.infer.config_enumerate60 def model():61 with pyro.plate("cfg", K):62 s = pyro.sample("state", dist.Categorical(prior))63 sel = pyro.ops.indexing.Vindex(fac)[torch.arange(K), s]64 pyro.factor("meaning", sel)6566 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model, lambda: None)67 m = marg["state"]68 sup = m.enumerate_support() # [NS, K] state indices per config69 lp = m.log_prob(sup) # [NS, K]70 out = torch.full((K, NS), float("-inf"))71 cfg_idx = torch.arange(K).unsqueeze(0).expand_as(sup)72 out[cfg_idx, sup] = lp73 return out.reshape(NC, NS, NS) # [cc, thr, state]747576logL0H = run_L0(heavy_thr, True) # [cc, thH, state] log L0(state | 'heavy', thH, cc)77logL0L = run_L0(light_thr, False) # [cc, thL, state] log L0(state | 'light', thL, cc)787980# ---- Speaker S1, run through Pyro exact enumeration over the two utterances. For a81# fixed (cc, state, thH, thL) the speaker chooses utterance ~ exp(alpha * log L0(state82# | utterance, thresholds, cc)); 'heavy' scores against L0H[cc, thH], 'light' against83# L0L[cc, thL]. compute_marginals normalizes the two-utterance softmax. All84# NC*NS*NS*NS configs are inferred together via one plated enumeration. ----85scoreH = alpha * logL0H.permute(0, 2, 1) # [cc, state, thH]86scoreL = alpha * logL0L.permute(0, 2, 1) # [cc, state, thL]87sH = scoreH.unsqueeze(3).expand(NC, NS, NS, NS) # [cc, state, thH, thL]88sL = scoreL.unsqueeze(2).expand(NC, NS, NS, NS) # [cc, state, thH, thL]89spk_scores = torch.stack([sH, sL], dim=-1) # [cc, state, thH, thL, 2] (heavy, light)90Kf = NC * NS * NS * NS91spk_scores_flat = spk_scores.reshape(Kf, 2)929394@pyro.infer.config_enumerate95def speaker_model():96 with pyro.plate("scfg", Kf):97 u = pyro.sample("utt", dist.Categorical(torch.ones(Kf, 2)))98 sel = pyro.ops.indexing.Vindex(spk_scores_flat)[torch.arange(Kf), u]99 pyro.factor("util", sel)100101102smarg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(speaker_model, lambda: None)103sm = smarg["utt"]104ssup = sm.enumerate_support() # [2, Kf] utterance indices per config105slp = sm.log_prob(ssup) # [2, Kf]106spk_lp = torch.full((Kf, 2), float("-inf"))107scfg_idx = torch.arange(Kf).unsqueeze(0).expand_as(ssup)108spk_lp[scfg_idx, ssup] = slp109# log P_speaker('heavy' | cc, state, thH, thL) (utterance index 0 == 'heavy')110logS_heavy = spk_lp[:, 0].reshape(NC, NS, NS, NS) # [cc, state, thH, thL]111112113# ---- Pragmatic listener (child speaker), exact enumeration over the joint discrete114# latents (comparison class, state, heavy-threshold, light-threshold). The speaker115# log-score for the heard utterance 'heavy' enters as the observation factor;116# compute_marginals contracts the nuisance latents to the exact marginal over state. ----117ccp_vec = torch.tensor([0.75, 0.25, 0.15])118ccp_vec = ccp_vec / ccp_vec.sum() # child speaker comparison-class prior119120121@pyro.infer.config_enumerate122def prag_model():123 cc_i = pyro.sample("cc", dist.Categorical(ccp_vec))124 sp = pyro.ops.indexing.Vindex(state_prior_table)[cc_i]125 state_i = pyro.sample("state", dist.Categorical(sp))126 th_h_i = pyro.sample("thH", dist.Categorical(torch.ones(NS) / NS))127 th_l_i = pyro.sample("thL", dist.Categorical(torch.ones(NS) / NS))128 score = pyro.ops.indexing.Vindex(logS_heavy)[cc_i, state_i, th_h_i, th_l_i]129 pyro.factor("obs", score)130131132marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(prag_model, lambda: None)133m = marg["state"]134sup = m.enumerate_support()135probs = m.log_prob(sup).exp()136ANSWER = {stateVals[int(s)]: float(pr) for s, pr in zip(sup, probs)}137
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (w1) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Objects: all 9 combinations of type ∈ {apple, fish, cup} × color ∈ {red, blue, green}. Utterances: all strings of the form "[color] type" or just "type", giving 12 utterances total (3 color-qualified per type + 3 bare types). Utterance cost: the number of words in the utterance (1 for bare, 2 for color-qualified). Utterance fitness: for a two-word utterance, 0 if the object matches both color and type, −100 otherwise; for a one-word utterance, 0 if the object matches the type, −100 otherwise. Speaker optimality alpha = 3. Context: shared objects are {red apple, blue fish, green cup}; occluded object is {red fish}.
A baseline reference game where the speaker has no uncertainty about the environment. A literal listener draws uniformly from a perceived context and weights by utterance fitness. A speaker, knowing the target and the shared context, chooses an utterance with probability proportional to exp(alpha × literal-listener score − utterance cost). A pragmatic listener (the L2 model) assumes the speaker only knows and considers the shared context. The pragmatic listener draws uniformly from the shared context (ignoring the occluded object) and updates on the speaker's distribution.
The posterior distribution over object identity (expressed as "color type" strings) for the pragmatic listener who hears the bare utterance "fish" in the given context.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"red apple",
"blue fish",
"green cup"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var possibleUtterances = _.flatten(map(function(modifier) {2 map(function(type) {3 return [modifier, type].join(' ').trim();4 }, ['apple', 'fish', 'cup']);5}, ['red', 'blue', 'green', '']));67var possibleObjects = [8 {type: 'apple', color: 'red'}, {type: 'apple', color: 'blue'}, {type: 'apple', color: 'green'},9 {type: 'fish', color: 'red'}, {type: 'fish', color: 'blue'}, {type: 'fish', color: 'green'},10 {type: 'cup', color: 'red'}, {type: 'cup', color: 'blue'}, {type: 'cup', color: 'green'}]1112var exampleContext = {13 shared: [14 {type: 'apple', color: 'red'},15 {type: 'fish', color: 'blue'},16 {type: 'cup', color: 'green'}17 ],18 occluded: [19 {type: 'fish', color: 'red'}20 ]21};2223var alpha = 3;2425var uttCost = function(utt) {26 return utt.split(' ').length;27}2829var uttFitness = function(utt, object) {30 var descriptors = utt.split(' ');31 if(descriptors.length > 1) {32 return (object.color === descriptors[0] &&33 object.type === descriptors[1]) ? 0 : -100;34 } else {35 return object.type === descriptors[0] ? 0 : -100;36 }37};3839var L0 = cache(function(utt, perceivedContext) {40 return Infer({method: 'enumerate'}, function() {41 var object = uniformDraw(perceivedContext);42 factor(uttFitness(utt, object));43 return object;44 });45});4647var S1 = cache(function(target, knownContext) {48 return Infer({method: 'enumerate'}, function() {49 var utt = uniformDraw(possibleUtterances);50 factor(alpha * L0(utt, knownContext).score(target) - uttCost(utt));51 return utt;52 });53});5455// Listener only considers objects speaker can see (model Keysar is arguing against)56var L2 = cache(function(utt, perceivedContext) {57 var sharedContext = perceivedContext.shared;58 var fullObjSet = sharedContext.concat(perceivedContext.occluded);59 return Infer({method: 'enumerate'}, function() {60 var object = uniformDraw(sharedContext);61 observe(S1(object, sharedContext), utt);62 return object.color + " " + object.type;63 });64});6566console.log("speaker utterance to refer to blue fish");67viz.table(S1({type: 'fish', color: 'blue'}, exampleContext.shared));6869console.log("listener response after hearing (underinformative) 'fish'");70viz.table(L2('fish', exampleContext));7172var ANSWER = (L2('fish', exampleContext));
1possibleUtterances = []2for modifier in ["red", "blue", "green", ""]:3 for typ in ["apple", "fish", "cup"]:4 possibleUtterances.append((modifier + " " + typ).strip())56shared = [7 {"type": "apple", "color": "red"},8 {"type": "fish", "color": "blue"},9 {"type": "cup", "color": "green"},10]11occluded = [{"type": "fish", "color": "red"}]12alpha = 3.01314def uttCost(utt):15 return len(utt.split(" "))1617def uttFitness(utt, obj):18 d = utt.split(" ")19 if len(d) > 1:20 return 0.0 if (obj["color"] == d[0] and obj["type"] == d[1]) else -100.021 return 0.0 if obj["type"] == d[0] else -100.02223def _ctxkey(ctx):24 return tuple((o["color"], o["type"]) for o in ctx)2526# Literal listener: uniform over the perceived context, factor by utterance fitness.27# Exact enumeration via Pyro's compute_marginals.28_L0_cache = {}29def L0(utt, perceivedContext):30 key = (utt, _ctxkey(perceivedContext))31 if key in _L0_cache:32 return _L0_cache[key]33 _fit = torch.tensor([uttFitness(utt, o) for o in perceivedContext])34 @pyro.infer.config_enumerate35 def m():36 idx = pyro.sample("obj", dist.Categorical(torch.ones(len(perceivedContext))))37 pyro.factor("fit", _fit[idx])38 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)39 _L0_cache[key] = marg["obj"]40 return _L0_cache[key]4142# Speaker: utterance ~ exp(alpha * L0.score(target) - uttCost), inverting L0.43_S1_cache = {}44def S1(target, knownContext):45 key = ((target["color"], target["type"]), _ctxkey(knownContext))46 if key in _S1_cache:47 return _S1_cache[key]48 tidx = knownContext.index(target)49 L0s = [L0(u, knownContext) for u in possibleUtterances]50 _sc = torch.tensor([51 alpha * L0s[i].log_prob(torch.tensor(tidx)).item() - uttCost(possibleUtterances[i])52 for i in range(len(possibleUtterances))53 ])54 @pyro.infer.config_enumerate55 def m():56 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(possibleUtterances))))57 pyro.factor("util", _sc[uidx])58 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)59 _S1_cache[key] = marg["utt"]60 return _S1_cache[key]6162# Pragmatic listener (L2): uniform over the shared context, observe the speaker's63# utterance where the speaker only knows the shared context.64heard = "fish"65_uidx = possibleUtterances.index(heard)66_S1s = [S1(o, shared) for o in shared]67_obs_scores = torch.tensor([68 _S1s[i].log_prob(torch.tensor(_uidx)).item() for i in range(len(shared))69])7071@pyro.infer.config_enumerate72def _L2():73 oidx = pyro.sample("obj", dist.Categorical(torch.ones(len(shared))))74 pyro.factor("obs", _obs_scores[oidx])7576_marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(_L2, lambda: None)77_post = _marg["obj"]78ANSWER = {79 shared[i]["color"] + " " + shared[i]["type"]: _post.probs[i].item()80 for i in range(len(shared))81}82
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Objects: all 9 combinations of type ∈ {apple, fish, cup} × color ∈ {red, blue, green}. Utterances: all strings of the form "[color] type" or just "type", giving 12 utterances. Utterance cost: number of words divided by 4. Utterance fitness: for two-word utterances, 0 if both color and type match, −100 otherwise; for one-word utterances, 0 if type matches, −100 otherwise. Speaker optimality alpha = 4. Context: shared objects are {red apple, blue fish, green cup}; occluded object is {red fish}.
A context-uncertainty reference game where the speaker is uncertain about additional objects the listener may perceive. A literal listener draws uniformly from a perceived context and weights by utterance fitness. A speaker, given a target and the shared context, marginalizes over uncertainty about one hidden additional object by considering all possible contexts formed by augmenting the shared context with each of the 9 possible objects uniformly, scoring the literal listener on each, then weights utterances by exp(alpha × marginalized-listener score − utterance cost). A pragmatic listener (L2) considers both shared and occluded objects as potential referents, draws uniformly from the full context, and updates on the speaker's distribution (who is assumed to know only the shared context).
The posterior distribution over object identity ("color type" strings) for the pragmatic listener who hears the bare utterance "fish" in the given context.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"red apple",
"blue fish",
"green cup",
"red fish"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var possibleUtterances = _.flatten(map(function(modifier) {2 return map(function(type) {3 return [modifier, type].join(' ').trim();4 }, ['apple', 'fish', 'cup']);5}, ['red', 'blue', 'green', '']));67var possibleObjects = [8 {type: 'apple', color: 'red'}, {type: 'apple', color: 'blue'}, {type: 'apple', color: 'green'},9 {type: 'fish', color: 'red'}, {type: 'fish', color: 'blue'}, {type: 'fish', color: 'green'},10 {type: 'cup', color: 'red'}, {type: 'cup', color: 'blue'}, {type: 'cup', color: 'green'}]1112var exampleContext = {13 shared: [14 {type: 'apple', color: 'red'},15 {type: 'fish', color: 'blue'},16 {type: 'cup', color: 'green'}17 ],18 occluded: [19 {type: 'fish', color: 'red'}20 ]21};2223var alpha = 4;2425var uttCost = function(utt) {26 return utt.split(' ').length/4;27}2829var uttFitness = function(utt, object) {30 var descriptors = utt.split(' ');31 if(descriptors.length > 1) {32 return (object.color === descriptors[0] &&33 object.type === descriptors[1]) ? 0 : -100;34 } else {35 return object.type === descriptors[0] ? 0 : -100;36 }37};3839var L0 = cache(function(utt, perceivedContext) {40 return Infer({method: 'enumerate'}, function() {41 var object = uniformDraw(perceivedContext);42 factor(uttFitness(utt, object))43 return object;44 })45})4647// speaker has uncertainty over what's behind occluded square48// marginalizes over all possibilities49var S1 = cache(function(target, perceivedContext) {50 return Infer({method: 'enumerate'}, function() {51 var utt = uniformDraw(possibleUtterances);52 var listener = Infer({method: 'enumerate', model: function() {53 var context = perceivedContext.concat(uniformDraw(possibleObjects));54 return sample(L0(utt, context));55 }})56 factor(alpha * listener.score(target) - uttCost(utt))57 return utt;58 });59});6061// Listener reasons about S1; assumes they only see what's in shared context62// but could be trying to refer to any of the objects63var L2 = cache(function(utt, perceivedContext) {64 var sharedContext = perceivedContext.shared;65 var fullContext = sharedContext.concat(perceivedContext.occluded);66 return Infer({method: 'enumerate', model: function() {67 var object = uniformDraw(fullContext);68 observe(S1(object, sharedContext), utt);69 return object.color + " " + object.type;70 }});71})7273console.log("speaker utterance to refer to blue fish")74viz.table(S1({type: 'fish', color: 'blue'}, exampleContext.shared))7576console.log("listener response after hearing (underinformative) 'fish'")77viz.table(L2('fish', exampleContext));7879var ANSWER = (L2('fish', exampleContext));
1# Keysar context-uncertainty reference game (RSA).2# Each level (L0 literal listener, S1 speaker, L2 pragmatic listener) is a3# discrete enumerable model; we run Pyro's own exact enumeration inference4# (config_enumerate + TraceEnum_ELBO.compute_marginals) at every level and let5# the higher levels reason about the lower-level posteriors via their scores,6# exactly as the WebPPL recursion does.78import json910possibleObjects = [11 {"type": "apple", "color": "red"}, {"type": "apple", "color": "blue"}, {"type": "apple", "color": "green"},12 {"type": "fish", "color": "red"}, {"type": "fish", "color": "blue"}, {"type": "fish", "color": "green"},13 {"type": "cup", "color": "red"}, {"type": "cup", "color": "blue"}, {"type": "cup", "color": "green"},14]1516possibleUtterances = []17for modifier in ["red", "blue", "green", ""]:18 for t in ["apple", "fish", "cup"]:19 possibleUtterances.append((modifier + " " + t).strip())2021shared = [22 {"type": "apple", "color": "red"},23 {"type": "fish", "color": "blue"},24 {"type": "cup", "color": "green"},25]26occluded = [{"type": "fish", "color": "red"}]27fullContext = shared + occluded2829alpha = 4.03031def obj_key(o):32 return o["color"] + " " + o["type"]3334def uttCost(utt):35 return len(utt.split(" ")) / 4.03637def uttFitness(utt, obj):38 parts = utt.split(" ")39 if len(parts) > 1:40 return 0.0 if (obj["color"] == parts[0] and obj["type"] == parts[1]) else -100.041 else:42 return 0.0 if obj["type"] == parts[0] else -100.04344def enum_marginal(log_weights):45 # Run Pyro exact enumeration over a uniform Categorical with per-outcome46 # log-weights supplied via pyro.factor; return the normalized marginal as a47 # 1-D tensor of probabilities (one per outcome) computed by Pyro's inference.48 n = len(log_weights)49 lw = torch.tensor([float(x) for x in log_weights])50 base = torch.ones(n) / n5152 @pyro.infer.config_enumerate53 def model():54 idx = pyro.sample("idx", dist.Categorical(base))55 pyro.factor("w", lw[idx])56 return idx5758 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)["idx"]59 sup = marg.enumerate_support()60 probs = marg.log_prob(sup).exp()61 out = torch.zeros(n)62 for s, p in zip(sup, probs):63 out[int(s.item())] = p64 return out6566_L0 = {}67def L0(utt, context):68 # Literal listener: uniform over objects in context, weighted by fitness.69 key = (utt, tuple(obj_key(o) for o in context))70 if key in _L0:71 return _L0[key]72 probs = enum_marginal([uttFitness(utt, o) for o in context])73 _L0[key] = probs74 return probs7576_S1 = {}77def S1(target, perceivedContext):78 # Speaker: softmax over utterances by alpha * (marginalized literal-listener79 # log-score of the target) - cost. The listener marginalizes over one hidden80 # extra object drawn uniformly from possibleObjects.81 tkey = obj_key(target)82 key = (tkey, tuple(obj_key(o) for o in perceivedContext))83 if key in _S1:84 return _S1[key]85 log_weights = []86 for utt in possibleUtterances:87 # marginal P(listener picks target) over the 9 augmented contexts88 marg_p = 0.089 for added in possibleObjects:90 context = perceivedContext + [added]91 l0 = L0(utt, context)92 for i, o in enumerate(context):93 if obj_key(o) == tkey:94 marg_p += float(l0[i].item())95 marg_p = marg_p / len(possibleObjects)96 score = math.log(marg_p) if marg_p > 0 else float("-inf")97 log_weights.append(alpha * score - uttCost(utt))98 probs = enum_marginal(log_weights)99 _S1[key] = probs100 return probs101102def L2(utt, perceivedContext):103 # Pragmatic listener: uniform over fullContext objects, observe S1 emitting104 # utt (assuming the speaker sees only the shared context).105 sharedCtx = perceivedContext["shared"]106 full = perceivedContext["shared"] + perceivedContext["occluded"]107 utt_idx = possibleUtterances.index(utt)108 log_weights = []109 for obj in full:110 s1 = S1(obj, sharedCtx)111 p = float(s1[utt_idx].item())112 log_weights.append(math.log(p) if p > 0 else float("-inf"))113 probs = enum_marginal(log_weights)114 return full, probs115116context = {"shared": shared, "occluded": occluded}117full, probs = L2("fish", context)118119support = ["red apple", "blue fish", "green cup", "red fish"]120result = {}121for o, p in zip(full, probs):122 result[obj_key(o)] = float(p.item())123ANSWER = {lab: result.get(lab, 0.0) for lab in support}124
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Two horses; the world state is the count of red (true) horses: 0, 1, or 2. Each horse is independently red with probability 0.5, so the state prior assigns probability 0.25 to state=0, 0.5 to state=1, and 0.25 to state=2. Utterances: ["null", "every-not"]. Utterance cost: 1 for both. Scopes: ["surface", "inverse"], each with prior probability 0.5. Meaning of "every-not": under surface scope, true iff state = 0 (not any horse is red); under inverse scope, true iff state < 2 (not all horses are red). "null" is always true. QUDs: ["how many?", "all red?", "none red?"], drawn with equal probability. The QUD function maps a state to: state = (numHorses=2) for "all red?"; state = 0 for "none red?"; the raw state count for "how many?". Speaker rationality alpha = 1.
A four-level RSA model over scope ambiguity. A literal listener draws the world state uniformly from {0, 1, 2} (equal weight 1/3 each — not the binomial state prior), conditions on the utterance being true under a given scope and QUD, and returns the QUD-relevant quantity. The pragmatic listener and pragmatic speaker use the binomial state prior (0.25/0.5/0.25) described above. A speaker (S1) given a scope, state, and QUD weights utterances by exp(alpha × (literal-listener score on the QUD-relevant quantity − cost)). A pragmatic listener (L1) samples state, scope, and QUD from their priors and updates on the speaker. A pragmatic speaker (S2) weights utterances by the pragmatic listener's posterior probability of the true state.
The distribution over utterances ["null", "every-not"] for the pragmatic speaker at world state 1 (one of two horses is red).
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"null",
"every-not"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// helper function to tally up the state2var numTrue = function(state) {3 var fun = function(x) {4 x ? 1 : 05 }6 return sum(map(fun,state))7}89// possible utterances10var utterances = ["null","every-not"];11var utterancePrior = function() {12 uniformDraw(utterances)13}14// uniform utterance cost15var cost = function(utterance) {16 return 117}1819// possible world states20var numHorses = 221var states = [0,1,2]22var baserate = 0.5 // change to manipulate prior on world states23var stateMaker = function(numHorses,stateSoFar) {24 if (numHorses == 0) {25 return stateSoFar26 } else {27 var newHorse = flip(baserate)28 var newState = stateSoFar.concat([newHorse])29 return stateMaker(numHorses - 1, newState)30 }31}32var statePrior = function() {33 return numTrue(stateMaker(numHorses,[]))34}3536// possible scope interpretations37var scopes = ["surface", "inverse"]38var scopePrior = function(){39 return categorical([.5,.5],scopes) // change to manipulate prior on scope interpretations40}414243// meaning function44var meaning = function(utterance, state, scope) {45 return utterance == "every-not" ?46 scope == "surface" ? state == 0 :47 state < numHorses :48 true;49};5051// possible QUDs52var QUDs = ["how many?","all red?","none red?"];53var QUDPrior = function() {54 uniformDraw(QUDs);55 // categorical([.05,.05,.9],QUDs) // change to manipulate prior on QUDs56}57var QUDFun = function(QUD,state) {58 QUD == "all red?" ? state == numHorses :59 QUD == "none red?" ? state == 0 :60 state;61};6263// Literal listener (L0)64var literalListener = cache(function(utterance,scope,QUD) {65 Infer({model: function(){66 var state = uniformDraw(states);67 var qState = QUDFun(QUD,state)68 condition(meaning(utterance,state,scope));69 return qState;70 }});71});7273var alpha = 17475// Speaker (S1)76var speaker = cache(function(scope, state, QUD) {77 return Infer({model: function(){78 var utterance = utterancePrior()79 var qState = QUDFun(QUD, state)80 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)81 - cost(utterance)))82 return utterance83 }})84})8586// Pragmatic listener (L1)87var pragmaticListener = cache(function(utterance) {88 Infer({model: function(){89 var state = statePrior();90 var scope = scopePrior();91 var QUD = QUDPrior();92 observe(speaker(scope,state,QUD),utterance);93 return state94 }});95});9697// Pragmatic speaker (S2)98var pragmaticSpeaker = cache(function(state) {99 Infer({model: function(){100 var utterance = utterancePrior();101 factor(pragmaticListener(utterance).score(state))102 return utterance103 }})104})105106// A speaker decides whether to endorse the ambiguous utterance as a107// description of the not-all world state108display(pragmaticSpeaker(1))109110var ANSWER = (pragmaticSpeaker(1));
12NEG = -1e303utterances = ["null", "every-not"]4numHorses = 25states = [0, 1, 2]6# statePrior = numTrue of two flip(0.5) -> Binomial(2, 0.5)7state_prior = torch.tensor([0.25, 0.5, 0.25])8scopes = ["surface", "inverse"]9scope_prior = torch.tensor([0.5, 0.5])10QUDs = ["how many?", "all red?", "none red?"]11qud_prior = torch.ones(3)12alpha = 1.0131415def meaning(utt, state, scope):16 if utt == "every-not":17 if scope == "surface":18 return state == 019 return state < numHorses20 return True212223def qud_fun(qud, state):24 if qud == "all red?":25 return ("b", state == numHorses)26 if qud == "none red?":27 return ("b", state == 0)28 return ("n", state)293031# Literal listener: posterior over the QUD-projected state, inferred.32def literal_listener(ui, scope_i, qi):33 utt = utterances[ui]34 scope = scopes[scope_i]35 qud = QUDs[qi]36 mask = torch.tensor([0.0 if meaning(utt, states[s], scope) else NEG for s in range(3)])3738 @pyro.infer.config_enumerate39 def model():40 st = pyro.sample("st", dist.Categorical(torch.ones(3)))41 pyro.factor("cond", mask[st])42 return None4344 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)45 m = marg["st"]46 sup = m.enumerate_support()47 pr = m.log_prob(sup).exp()48 out = {}49 for i, p in zip(sup.tolist(), pr.tolist()):50 q = qud_fun(qud, states[int(i)])51 out[q] = out.get(q, 0.0) + p52 return out535455# Speaker S1: utterance posterior given (scope, state, QUD), inferred.56def speaker1_logprobs(scope_i, st_val, qi):57 qud = QUDs[qi]58 q_state = qud_fun(qud, st_val)59 ll = [literal_listener(u, scope_i, qi) for u in range(2)]6061 @pyro.infer.config_enumerate62 def model():63 u = pyro.sample("u", dist.Categorical(torch.ones(2)))64 sc = torch.tensor([alpha * ((math.log(ll[uu].get(q_state, 0.0))65 if ll[uu].get(q_state, 0.0) > 0 else NEG) - 1.0)66 for uu in range(2)]) # cost(utterance) == 167 pyro.factor("sc", sc[u])68 return None6970 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)71 m = marg["u"]72 sup = m.enumerate_support()73 pr = m.log_prob(sup).exp()74 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}757677# Pragmatic listener L1: state posterior given utterance, inferred.78def pragmatic_listener(ui):79 score = torch.full((3, 2, 3), NEG)80 for st in range(3):81 for sc in range(2):82 for qi in range(3):83 spk = speaker1_logprobs(sc, states[st], qi)84 p = spk.get(ui, 0.0)85 score[st, sc, qi] = math.log(p) if p > 0 else NEG8687 @pyro.infer.config_enumerate88 def model():89 st = pyro.sample("st", dist.Categorical(state_prior))90 sc = pyro.sample("sc", dist.Categorical(scope_prior))91 qd = pyro.sample("qd", dist.Categorical(qud_prior))92 pyro.factor("obs", score[st, sc, qd])93 return None9495 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)96 m = marg["st"]97 sup = m.enumerate_support()98 pr = m.log_prob(sup).exp()99 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}100101102# Pragmatic speaker S2: utterance posterior given the state, inferred.103def pragmatic_speaker(st_val):104 l1 = [pragmatic_listener(u) for u in range(2)]105 sc = torch.tensor([math.log(l1[u].get(st_val, 0.0)) if l1[u].get(st_val, 0.0) > 0 else NEG106 for u in range(2)])107108 @pyro.infer.config_enumerate109 def model():110 u = pyro.sample("u", dist.Categorical(torch.ones(2)))111 pyro.factor("sc", sc[u])112 return None113114 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)115 m = marg["u"]116 sup = m.enumerate_support()117 pr = m.log_prob(sup).exp()118 return {utterances[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}119120121ANSWER = pragmatic_speaker(1)122
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Weather states are 'terrible', 'ok', and 'amazing'. The prior over states is categorical with unnormalized weights [1, 50, 50] — ok and amazing are each 50 times more likely than terrible (California prior). Valence (positive or negative affect) depends on state: given 'terrible', negative valence has probability 0.99; given 'ok', each valence is equally likely (0.5); given 'amazing', positive valence has probability 0.99. Arousal is binary: 'low' or 'high'. Given 'terrible', the arousal distribution is categorical([0.1, 0.9]); given 'ok', categorical([0.9, 0.1]); given 'amazing', categorical([0.1, 0.9]). A speaker has one of three communicative goals — conveying the weather state, conveying valence, or conveying arousal — each equally likely a priori. Utterances are the state labels themselves, each equally likely a priori. The speaker uses a utility-weighted pragmatic reasoning chain: a literal listener updates on whether the utterance matches the state and returns the goal-relevant quantity; a speaker chooses utterances proportional to the literal listener's probability of recovering the goal-relevant truth (soft-max with weight 1); a pragmatic listener inverts the speaker to jointly infer state, valence, arousal, and goal.
A pragmatic listener hears an utterance and infers the joint distribution over (weather state, valence, arousal, communicative goal) by inverting a speaker who chose the utterance to satisfy a sampled communicative goal. The literal listener conditions only on whether the utterance names the state literally.
The posterior joint distribution over (state, valence, arousal, goal) given the utterance 'terrible', as a distribution over the four-field record {state, valence, arousal, goal}. Goal labels are the strings 'goalState', 'goalValence', and 'goalArousal'.
answer spec
{
"kind": "dist",
"domain": "finite",
"labels": {
"record": {
"state": "string",
"valence": "int",
"arousal": "string",
"goal": "string"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// There are three possible states the weather could be in:2// terrible, ok, or amazing3var states = ['terrible', 'ok', 'amazing']45// Since we are in California, the prior over these states6// are the following. Once could also imagine this being7// the prior in a certain context, e.g. when it's clearly8// sunny and nice out.9var statePrior = function() {10 categorical([1, 50, 50], states)11}1213// Valence prior defined in terms of negative valence.14// If the current state is terrible, it's extremely likely15// that the valence associated is negative. If it's ok, then16// the valence could be negative or positive with equal17// probability.18var valencePrior = function(state) {19 state === "terrible" ? flip(0.99) ? -1 : 1 :20 state === "ok" ? flip(0.5) ? -1 : 1 :21 state === "amazing" ? flip(0.01) ? -1 : 1 :22 true23}2425// Define binary arousals (could model as continuous).26var arousals = ["low", "high"]2728// Define goals and goal priors. Could want to communicate state of the world,29// valence about it, or arousal (intensity of feeling) about it.30var goals = ["goalState", "goalValence", "goalArousal"]3132var goalPrior = function() {33 categorical([1, 1, 1], goals)34}3536// Assume possible utterances are identical to possible states37var utterances = states3839// Assume cost of utterances is uniform.40var utterancePrior = function() {41 uniformDraw(utterances)42}4344// Sample arousal given a state.45var arousalPrior = function(state) {46 state === "terrible" ? categorical([0.1, 0.9], arousals) :47 state === "ok" ? categorical([0.9, 0.1], arousals) :48 state === "amazing" ? categorical([0.1, 0.9], arousals) :49 true50}5152// Literal interpretation is just whether utterance equals state53var literalInterpretation = function(utterance, state) {54 utterance === state55}5657// A speaker's goal is satisfied if the listener infers the correct58// and relevant information.59var goalState = function(goal, state, valence, arousal) {60 goal === "goalState" ? state :61 goal === "goalValence" ? valence :62 goal === "goalArousal" ? arousal :63 true64}6566// Define a literal listener67var literalListener = function(utterance, goal) {68 Infer({model: function(){69 var state = uniformDraw(states)70 var valence = valencePrior(state)71 var arousal = arousalPrior(state)72 condition(literalInterpretation(utterance,state))73 return goalState(goal, state, valence, arousal)74 }})75}7677// Define a speaker78var speaker = function(state, valence, arousal, goal) {79 Infer({model: function(){80 var utterance = utterancePrior()81 factor(1 * literalListener(utterance,82 goal).score(goalState(goal,83 state,84 valence,85 arousal)))86 return utterance87 }})88}8990// Define a pragmatic listener91var pragmaticListener = function(utterance) {92 Infer({model: function(){93 var state = statePrior()94 var valence = valencePrior(state)95 var arousal = arousalPrior(state)96 var goal = goalPrior()97 observe(speaker(state, valence, arousal, goal),utterance)98 return {state, valence, arousal, goal}99 }})100}101102viz.table(literalListener("terrible", "goalState"))103viz.table(speaker("terrible", -1, "high", "goalValence"))104viz.table(pragmaticListener("terrible"))105106var ANSWER = (pragmaticListener('terrible'));
12states = ["terrible", "ok", "amazing"]3state_prior_w = [1.0, 50.0, 50.0]4utterances = states[:]5goals = ["goalState", "goalValence", "goalArousal"]6valence_labels = [-1, 1]7arousal_labels = ["low", "high"]89def valence_neg_prob(state): # P(valence == -1)10 return 0.99 if state == "terrible" else 0.5 if state == "ok" else 0.011112def arousal_low_prob(state): # P(arousal == 'low')13 return 0.1 if state == "terrible" else 0.9 if state == "ok" else 0.11415def _elbo():16 return pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)1718_LL_CACHE = {}19def literal_listener_marginals(utterance, goal):20 key = (utterance, goal)21 if key in _LL_CACHE:22 return _LL_CACHE[key]23 u_idx = states.index(utterance)2425 @pyro.infer.config_enumerate26 def model():27 si = pyro.sample("state", dist.Categorical(torch.ones(3) / 3)) # uniformDraw28 pneg = torch.tensor([valence_neg_prob(s) for s in states])[si]29 # valence idx0 -> -1, idx1 -> +130 pyro.sample("valence", dist.Categorical(torch.stack([pneg, 1 - pneg], -1)))31 pl = torch.tensor([arousal_low_prob(s) for s in states])[si]32 # arousal idx0 -> low, idx1 -> high33 pyro.sample("arousal", dist.Categorical(torch.stack([pl, 1 - pl], -1)))34 # condition literalInterpretation: utterance === state35 pyro.factor(36 "lit",37 torch.where(si == u_idx, torch.tensor(0.0), torch.tensor(float("-inf"))),38 )39 return si4041 marg = _elbo().compute_marginals(model, lambda: None)42 _LL_CACHE[key] = marg43 return marg4445def qud_answer(goal, state, valence, arousal):46 if goal == "goalState":47 return ("state", state)48 if goal == "goalValence":49 return ("valence", valence)50 return ("arousal", arousal)5152def ll_logp(utterance, goal, ans):53 marg = literal_listener_marginals(utterance, goal)54 kind, val = ans55 if kind == "state":56 d = marg["state"]; idx = states.index(val)57 elif kind == "valence":58 d = marg["valence"]; idx = valence_labels.index(val)59 else:60 d = marg["arousal"]; idx = arousal_labels.index(val)61 return float(d.log_prob(torch.tensor(idx)))6263_SP_CACHE = {}64def speaker(state, valence, arousal, goal):65 key = (state, valence, arousal, goal)66 if key in _SP_CACHE:67 return _SP_CACHE[key]68 ans = qud_answer(goal, state, valence, arousal)69 utils = torch.tensor([1.0 * ll_logp(u, goal, ans) for u in utterances])7071 @pyro.infer.config_enumerate72 def model():73 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))74 pyro.factor("util", utils[ui]) # factor(1 * literalListener.score(...))75 return ui7677 marg = _elbo().compute_marginals(model, lambda: None)["utt"]78 _SP_CACHE[key] = marg79 return marg8081target = "terrible"82target_idx = utterances.index(target)8384cfgs = []85for st in states:86 for val in valence_labels:87 for ar in arousal_labels:88 for goal in goals:89 cfgs.append((st, val, ar, goal))9091Zw = sum(state_prior_w)92prior = torch.tensor([93 (state_prior_w[states.index(c[0])] / Zw)94 * (valence_neg_prob(c[0]) if c[1] == -1 else 1 - valence_neg_prob(c[0]))95 * (arousal_low_prob(c[0]) if c[2] == "low" else 1 - arousal_low_prob(c[0]))96 * (1.0 / len(goals))97 for c in cfgs98])99100lik = torch.tensor([101 float(torch.exp(speaker(c[0], c[1], c[2], c[3]).log_prob(torch.tensor(target_idx))))102 for c in cfgs103])104105@pyro.infer.config_enumerate106def pragmatic_model():107 ci = pyro.sample("cfg", dist.Categorical(prior / prior.sum()))108 pyro.factor("speaker", torch.log(lik[ci] + 1e-300)) # observe(speaker, utterance)109 return ci110111cfg_marg = _elbo().compute_marginals(pragmatic_model, lambda: None)["cfg"]112sup = cfg_marg.enumerate_support()113p = torch.exp(cfg_marg.log_prob(sup))114115support_records = []116probs = []117for i, pp in zip(sup, p):118 st, val, ar, goal = cfgs[int(i)]119 support_records.append({"state": st, "valence": val, "arousal": ar, "goal": goal})120 probs.append(float(pp))121122ANSWER = {"support": support_records, "probs": probs}123
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 1/2 solvers · d=[0.469, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
World states are 0, 1, 2 (number of rabbits fed), each equally likely. Utterances are 'null' and 'not-two', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse', each equally likely. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means no rabbits were fed (state = 0). 'null' is always true. The QUD function maps 'all red?' to whether state = 2, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener conditions on the utterance's truth under the given scope and QUD, then returns the QUD projection of the state. A pragmatic speaker at level 1 weights utterances by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.
A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.
The distribution over utterances ('null' or 'not-two') chosen by the second-level pragmatic speaker for world state 1 — one rabbit was fed.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"null",
"not-two"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Here is the code for English Expt 1 (surface scope)2//different qud prior, different state prior, access to alternative utterances34// Here is the code for the quantifier scope model56// possible utterances7var utterances = ["null","not-two"];89var utterancePrior = function() {10 categorical({vs:["null","not-two"],ps:[1,1]})11}1213var cost = function(utterance) {14 return 115}1617// possible world states18var states = [0,1,2];19var statePrior = function() {20 categorical({vs:[0,1,2],ps:[1,1,1]})21}2223// possible scopes24var scopePrior = function(){25 return categorical({vs:["surface", "inverse"],ps:[1,1]})26}2728var meaning = function(utterance, state, scope) {29 //if utterance == none:30 //return state==031 //else:32 //elif utternace == nottwo:33 //if scope == surface:34 //return state == 0 / state==135 //else:36 //return state == 037 //else:38 //return true3940 return utterance == "not-two" ?41 scope == "surface" ? (state == 0 | state ==1):42 state == 0 :43 true;44};454647// QUDs48var QUDs = ["how many?","all red?","none red?"];49var QUDPrior = function() {50 uniformDraw(QUDs);51}52var QUDFun = function(QUD,state) {53 QUD == "all red?" ? state == 2 :54 QUD == "none red?" ? state == 0 :55 state;56};5758// Literal listener (L0)59var literalListener = cache(function(utterance,scope,QUD) {60 Infer({model: function(){61 var state = uniformDraw(states);62 var qState = QUDFun(QUD,state)63 condition(meaning(utterance,state,scope));64 return qState;65 }});66});6768var alpha = 16970// Speaker (S)71var speaker = cache(function(scope, state, QUD) {72 return Infer({model: function(){73 var utterance = utterancePrior()74 var qState = QUDFun(QUD, state)75 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)76 - cost(utterance)))77 return utterance78 }})79})8081// Pragmatic listener (L1)82var pragmaticListener = cache(function(utterance) {83 Infer({model: function(){84 var state = statePrior();85 var scope = scopePrior();86 var QUD = QUDPrior();87 observe(speaker(scope,state,QUD),utterance);88 return state89 }});90});9192// Pragmatic speaker (S2)93var pragmaticSpeaker = cache(function(state) {94 Infer({model: function(){95 var utterance = utterancePrior();96 factor(pragmaticListener(utterance).score(state))97 return utterance98 }})99})100101102103// A speaker decides whether to endorse the ambiguous utterance as a104// description of the not-all world state105viz.table(pragmaticSpeaker(1))106viz(pragmaticSpeaker(1))107//literalListener("surface", 2, "all red?")108109var ANSWER = (pragmaticSpeaker(1));
1# RSA quantifier-scope model (English Expt 1, surface scope).2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration.3# A level's only latent sample site is enumerated and its marginal is read with4# TraceEnum_ELBO.compute_marginals; the finished distribution is then fed into5# the next level as a fixed pyro.factor score table. No inference is ever run6# inside another level's active trace, and every sample site has a name unique7# to its level, so there is no site collision.89utterances = ["null", "not-two"]10states = [0, 1, 2]11scopes = ["surface", "inverse"]12QUDs = ["how many?", "all red?", "none red?"]13alpha = 1.01415state_logits = torch.zeros(len(states)) # ps:[1,1,1]16utt_logits = torch.zeros(len(utterances)) # ps:[1,1]17scope_logits = torch.zeros(len(scopes)) # ps:[1,1]18qud_logits = torch.zeros(len(QUDs)) # uniform192021def cost(_u):22 return 1.0232425def meaning(utterance, state, scope):26 if utterance == "not-two":27 if scope == "surface":28 return (state == 0) or (state == 1)29 return state == 030 return True313233def qud_fun(qud, state):34 if qud == "all red?":35 return ("bool", state == 2)36 if qud == "none red?":37 return ("bool", state == 0)38 return ("state", state)394041def enum_marginal(model, site):42 # Run exact discrete enumeration over a single-latent model and return its43 # marginal as {support_index: prob}.44 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(45 model, lambda: None46 )[site]47 sup = marg.enumerate_support()48 probs = marg.log_prob(sup).exp()49 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}505152# ---- Literal listener L0: posterior over qState given (utterance, scope, QUD) ----53L0_cache = {}545556def literal_listener(utterance, scope, qud):57 key = (utterance, scope, qud)58 if key in L0_cache:59 return L0_cache[key]60 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)6162 @pyro.infer.config_enumerate63 def model():64 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))65 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))66 return s6768 state_post = enum_marginal(model, "L0_state")69 # push the inferred state posterior forward through the QUD projection70 d = defaultdict(float)71 for si, p in state_post.items():72 d[qud_fun(qud, states[si])] += p73 L0_cache[key] = dict(d)74 return L0_cache[key]757677def l0_score(utterance, scope, qud, qstate):78 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)79 return math.log(p) if p > 0 else float("-inf")808182# ---- Speaker S: distribution over utterances given (scope, state, QUD) ----83S_cache = {}848586def speaker(scope, state, qud):87 key = (scope, state, qud)88 if key in S_cache:89 return S_cache[key]90 qstate = qud_fun(qud, state)91 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])9293 @pyro.infer.config_enumerate94 def model():95 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))96 pyro.factor("S_f", sc[u])97 return u9899 post = enum_marginal(model, "S_utt")100 S_cache[key] = {utterances[i]: p for i, p in post.items()}101 return S_cache[key]102103104def speaker_score(scope, state, qud, utterance):105 p = speaker(scope, state, qud).get(utterance, 0.0)106 return math.log(p) if p > 0 else float("-inf")107108109# ---- Pragmatic listener L1: posterior over state given utterance ----110L1_cache = {}111112113def pragmatic_listener(utterance):114 if utterance in L1_cache:115 return L1_cache[utterance]116 # finished speaker scores for every (state, scope, qud), evaluated up front117 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))118 for si, st in enumerate(states):119 for sci, scp in enumerate(scopes):120 for qi, qd in enumerate(QUDs):121 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)122123 @pyro.infer.config_enumerate124 def model():125 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))126 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))127 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))128 pyro.factor("L1_obs", tbl[s, sc, q])129 return s130131 post = enum_marginal(model, "L1_state")132 L1_cache[utterance] = {states[i]: p for i, p in post.items()}133 return L1_cache[utterance]134135136def pragmatic_listener_score(utterance, state):137 p = pragmatic_listener(utterance).get(state, 0.0)138 return math.log(p) if p > 0 else float("-inf")139140141# ---- Pragmatic speaker S2: distribution over utterances given state ----142def pragmatic_speaker(state):143 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])144145 @pyro.infer.config_enumerate146 def model():147 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))148 pyro.factor("S2_f", sc[u])149 return u150151 post = enum_marginal(model, "S2_utt")152 return {utterances[i]: p for i, p in post.items()}153154155ANSWER = pragmatic_speaker(1)156
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
World states are 0, 1, 2, 3, 4 (number of rabbits fed out of 4), each equally likely. Utterances are 'null' and 'not-two'; the utterance prior is categorical with weights [1, 10] — 'not-two' is ten times more likely than 'null'. Both utterances have uniform cost 1. Scopes are 'surface' and 'inverse', each equally likely. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means fewer than 3 rabbits were fed (state < 3). 'null' is always true. The QUD function maps 'all red?' to whether state = 4, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener conditions on the utterance's truth under the given scope and QUD, then returns the QUD projection of the state. A pragmatic speaker at level 1 weights utterances by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.
A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.
The distribution over utterances ('null' or 'not-two') chosen by the second-level pragmatic speaker for world state 2 — two of four rabbits were fed.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"null",
"not-two"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Here is the code for English Expt 2 (inverse scope)2//different qud prior, different state prior, access to alternative utterances34// Here is the code for the quantifier scope model56// possible utterances7var utterances = ["null","not-two"];89var utterancePrior = function() {10 categorical({vs:["null","not-two"],ps:[1,10]})11}1213var cost = function(utterance) {14 return utterance == "not-two" ? 1 :15 116}1718// possible world states19var states = [0,1,2,3,4];20var statePrior = function() {21 categorical({vs:[0,1,2,3,4],ps:[1,1,1,1,1]})22}2324// possible scopes25var scopePrior = function(){26 return categorical({vs:["surface", "inverse"],ps:[1,1]})27}2829// meaning function30var meaning = function(utterance, state, scope) {31 return utterance == "not-two" ?32 scope == "surface" ? (state < 2):33 (state < 3) :34 true;35};3637// QUDs38var QUDs = ["how many?","all red?","none red?"];39var QUDPrior = function() {40 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})41 //uniformDraw(QUDs);42}4344var QUDFun = function(QUD,state) {45 QUD == "all red?" ? state == 4 :46 QUD == "none red?" ? state == 0 :47 state;48};4950// Literal listener (L0)51var literalListener = cache(function(utterance,scope,QUD) {52 Infer({model: function(){53 var state = uniformDraw(states);54 var qState = QUDFun(QUD,state)55 condition(meaning(utterance,state,scope));56 return qState;57 }});58});5960var alpha = 16162// Speaker (S)63var speaker = cache(function(scope, state, QUD) {64 return Infer({model: function(){65 var utterance = utterancePrior()66 var qState = QUDFun(QUD, state)67 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)68 - cost(utterance)))69 return utterance70 }})71})7273// Pragmatic listener (L1)74var pragmaticListener = cache(function(utterance) {75 Infer({model: function(){76 var state = statePrior();77 var scope = scopePrior();78 var QUD = QUDPrior();79 observe(speaker(scope,state,QUD),utterance);80 return state81 }});82});8384// Pragmatic speaker (S2)85var pragmaticSpeaker = cache(function(state) {86 Infer({model: function(){87 var utterance = utterancePrior();88 factor(pragmaticListener(utterance).score(state))89 return utterance90 }})91})9293// A speaker decides whether to endorse the ambiguous utterance as a94// description of the not-all world state95//viz.table(pragmaticSpeaker(0))96viz.table(pragmaticSpeaker(2))97//viz.table(pragmaticSpeaker(2))98//literalListener("surface", 2, "all red?")99100var ANSWER = (pragmaticSpeaker(2));
1# RSA quantifier-scope model (English Expt 2, inverse scope).2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration3# with level-unique site names; a finished level's marginal is fed into the next4# as a fixed pyro.factor score table. No inference inside an active trace.56utterances = ["null", "not-two"]7states = [0, 1, 2, 3, 4]8scopes = ["surface", "inverse"]9QUDs = ["how many?", "all red?", "none red?"]10alpha = 1.01112state_logits = torch.zeros(len(states)) # ps:[1,1,1,1,1]13utt_logits = torch.log(torch.tensor([1.0, 10.0])) # ps:[1,10]14scope_logits = torch.zeros(len(scopes)) # ps:[1,1]15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]161718def cost(_u):19 return 1.0202122def meaning(utterance, state, scope):23 if utterance == "not-two":24 if scope == "surface":25 return state < 226 return state < 327 return True282930def qud_fun(qud, state):31 if qud == "all red?":32 return ("bool", state == 4)33 if qud == "none red?":34 return ("bool", state == 0)35 return ("state", state)363738def enum_marginal(model, site):39 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(40 model, lambda: None41 )[site]42 sup = marg.enumerate_support()43 probs = marg.log_prob(sup).exp()44 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}454647L0_cache = {}484950def literal_listener(utterance, scope, qud):51 key = (utterance, scope, qud)52 if key in L0_cache:53 return L0_cache[key]54 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)5556 @pyro.infer.config_enumerate57 def model():58 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))59 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))60 return s6162 state_post = enum_marginal(model, "L0_state")63 d = defaultdict(float)64 for si, p in state_post.items():65 d[qud_fun(qud, states[si])] += p66 L0_cache[key] = dict(d)67 return L0_cache[key]686970def l0_score(utterance, scope, qud, qstate):71 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)72 return math.log(p) if p > 0 else float("-inf")737475S_cache = {}767778def speaker(scope, state, qud):79 key = (scope, state, qud)80 if key in S_cache:81 return S_cache[key]82 qstate = qud_fun(qud, state)83 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])8485 @pyro.infer.config_enumerate86 def model():87 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))88 pyro.factor("S_f", sc[u])89 return u9091 post = enum_marginal(model, "S_utt")92 S_cache[key] = {utterances[i]: p for i, p in post.items()}93 return S_cache[key]949596def speaker_score(scope, state, qud, utterance):97 p = speaker(scope, state, qud).get(utterance, 0.0)98 return math.log(p) if p > 0 else float("-inf")99100101L1_cache = {}102103104def pragmatic_listener(utterance):105 if utterance in L1_cache:106 return L1_cache[utterance]107 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))108 for si, st in enumerate(states):109 for sci, scp in enumerate(scopes):110 for qi, qd in enumerate(QUDs):111 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)112113 @pyro.infer.config_enumerate114 def model():115 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))116 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))117 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))118 pyro.factor("L1_obs", tbl[s, sc, q])119 return s120121 post = enum_marginal(model, "L1_state")122 L1_cache[utterance] = {states[i]: p for i, p in post.items()}123 return L1_cache[utterance]124125126def pragmatic_listener_score(utterance, state):127 p = pragmatic_listener(utterance).get(state, 0.0)128 return math.log(p) if p > 0 else float("-inf")129130131def pragmatic_speaker(state):132 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])133134 @pyro.infer.config_enumerate135 def model():136 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))137 pyro.factor("S2_f", sc[u])138 return u139140 post = enum_marginal(model, "S2_utt")141 return {utterances[i]: p for i, p in post.items()}142143144ANSWER = pragmatic_speaker(2)145
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
World states are 0, 1, 2 (number of rabbits fed), each equally likely. Utterances are 'null', 'not-two', and 'none', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse'; the scope prior strongly favors surface scope: categorical weights [100, 1]. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'none' is that no rabbits were fed (state = 0). The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means no rabbits were fed (state = 0). 'null' is always true. The QUD function maps 'all red?' to whether state = 2, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener conditions on the utterance's truth under the given scope and QUD, then returns the QUD projection of the state. A pragmatic speaker at level 1 weights utterances by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.
A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.
The distribution over utterances ('null', 'not-two', or 'none') chosen by the second-level pragmatic speaker for world state 1 — one rabbit was fed.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"null",
"not-two",
"none"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Here is the code of Chinese model for Expt 1.2//different qud prior, different state prior, access to alternative utterances34// Here is the code for the quantifier scope model56// possible utterances7var utterances = ["null","not-two","none"];89var utterancePrior = function() {10 categorical({vs:["null","not-two","none"],ps:[1,1,1]})11}1213var cost = function(utterance) {14 return utterance == "not-two" ? 1 :15 utterance == 'none'? 1 :16 117}1819// possible world states20var states = [0,1,2];21var statePrior = function() {22 categorical({vs:[0,1,2],ps:[1,1,1]})23}2425// possible scopes26var scopePrior = function(){27 return categorical({vs:["surface", "inverse"],ps:[100,1]})28}2930// meaning function31var meaning = function(utterance, state, scope) {32 return utterance == "none"? state == 0:33 utterance == "not-two" ?34 scope == "surface" ? (state < 2):35 (state == 0) :36 true;37};3839// QUDs40var QUDs = ["how many?","all red?","none red?"];41var QUDPrior = function() {42 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})43 //uniformDraw(QUDs);44}45var QUDFun = function(QUD,state) {46 QUD == "all red?" ? state == 2 :47 QUD == "none red?" ? state == 0 :48 state;49};5051// Literal listener (L0)52var literalListener = cache(function(utterance,scope,QUD) {53 Infer({model: function(){54 var state = uniformDraw(states);55 var qState = QUDFun(QUD,state)56 condition(meaning(utterance,state,scope));57 return qState;58 }});59});6061var alpha = 16263// Speaker (S)64var speaker = cache(function(scope, state, QUD) {65 return Infer({model: function(){66 var utterance = utterancePrior()67 var qState = QUDFun(QUD, state)68 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)69 - cost(utterance)))70 return utterance71 }})72})7374// Pragmatic listener (L1)75var pragmaticListener = cache(function(utterance) {76 Infer({model: function(){77 var state = statePrior();78 var scope = scopePrior();79 var QUD = QUDPrior();80 observe(speaker(scope,state,QUD),utterance);81 return state82 }});83});8485// Pragmatic speaker (S2)86var pragmaticSpeaker = cache(function(state) {87 Infer({model: function(){88 var utterance = utterancePrior();89 factor(pragmaticListener(utterance).score(state))90 return utterance91 }})92})9394// A speaker decides whether to endorse the ambiguous utterance as a95// description of the not-all world state96//viz.table(pragmaticSpeaker(0))97//viz.table(pragmaticSpeaker(1))98viz.table(pragmaticSpeaker(1))99//literalListener("surface", 2, "all red?")100101var ANSWER = (pragmaticSpeaker(1));
1# RSA quantifier-scope model (Chinese Expt 1).2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration3# with level-unique site names; a finished level's marginal is fed into the next4# as a fixed pyro.factor score table. No inference inside an active trace.56utterances = ["null", "not-two", "none"]7states = [0, 1, 2]8scopes = ["surface", "inverse"]9QUDs = ["how many?", "all red?", "none red?"]10alpha = 1.01112state_logits = torch.zeros(len(states)) # ps:[1,1,1]13utt_logits = torch.zeros(len(utterances)) # ps:[1,1,1]14scope_logits = torch.log(torch.tensor([100.0, 1.0])) # ps:[100,1]15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]161718def cost(_u):19 return 1.0202122def meaning(utterance, state, scope):23 if utterance == "none":24 return state == 025 if utterance == "not-two":26 if scope == "surface":27 return state < 228 return state == 029 return True303132def qud_fun(qud, state):33 if qud == "all red?":34 return ("bool", state == 2)35 if qud == "none red?":36 return ("bool", state == 0)37 return ("state", state)383940def enum_marginal(model, site):41 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(42 model, lambda: None43 )[site]44 sup = marg.enumerate_support()45 probs = marg.log_prob(sup).exp()46 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}474849L0_cache = {}505152def literal_listener(utterance, scope, qud):53 key = (utterance, scope, qud)54 if key in L0_cache:55 return L0_cache[key]56 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)5758 @pyro.infer.config_enumerate59 def model():60 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))61 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))62 return s6364 state_post = enum_marginal(model, "L0_state")65 d = defaultdict(float)66 for si, p in state_post.items():67 d[qud_fun(qud, states[si])] += p68 L0_cache[key] = dict(d)69 return L0_cache[key]707172def l0_score(utterance, scope, qud, qstate):73 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)74 return math.log(p) if p > 0 else float("-inf")757677S_cache = {}787980def speaker(scope, state, qud):81 key = (scope, state, qud)82 if key in S_cache:83 return S_cache[key]84 qstate = qud_fun(qud, state)85 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])8687 @pyro.infer.config_enumerate88 def model():89 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))90 pyro.factor("S_f", sc[u])91 return u9293 post = enum_marginal(model, "S_utt")94 S_cache[key] = {utterances[i]: p for i, p in post.items()}95 return S_cache[key]969798def speaker_score(scope, state, qud, utterance):99 p = speaker(scope, state, qud).get(utterance, 0.0)100 return math.log(p) if p > 0 else float("-inf")101102103L1_cache = {}104105106def pragmatic_listener(utterance):107 if utterance in L1_cache:108 return L1_cache[utterance]109 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))110 for si, st in enumerate(states):111 for sci, scp in enumerate(scopes):112 for qi, qd in enumerate(QUDs):113 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)114115 @pyro.infer.config_enumerate116 def model():117 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))118 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))119 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))120 pyro.factor("L1_obs", tbl[s, sc, q])121 return s122123 post = enum_marginal(model, "L1_state")124 L1_cache[utterance] = {states[i]: p for i, p in post.items()}125 return L1_cache[utterance]126127128def pragmatic_listener_score(utterance, state):129 p = pragmatic_listener(utterance).get(state, 0.0)130 return math.log(p) if p > 0 else float("-inf")131132133def pragmatic_speaker(state):134 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])135136 @pyro.infer.config_enumerate137 def model():138 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))139 pyro.factor("S2_f", sc[u])140 return u141142 post = enum_marginal(model, "S2_utt")143 return {utterances[i]: p for i, p in post.items()}144145146ANSWER = pragmatic_speaker(1)147
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
World states are 0, 1, 2, 3, 4 (number of rabbits fed out of 4), each equally likely. Utterances are 'null', 'not-two', and 'none', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse'; the scope prior strongly favors surface scope: categorical weights [100, 1]. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'none' is that no rabbits were fed (state = 0). The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means fewer than 3 rabbits were fed (state < 3). 'null' is always true. The QUD function maps 'all red?' to whether state = 4, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener conditions on the utterance's truth under the given scope and QUD, then returns the QUD projection of the state. A pragmatic speaker at level 1 weights utterances by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.
A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.
The distribution over utterances ('null', 'not-two', or 'none') chosen by the second-level pragmatic speaker for world state 2 — two of four rabbits were fed.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"null",
"not-two",
"none"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Here is the code of Chinese model for Expt 2.2//different qud prior, different state prior, access to alternative utterances34// Here is the code for the quantifier scope model56// possible utterances7var utterances = ["null","not-two","none"];89var utterancePrior = function() {10 categorical({vs:["null","not-two","none"],ps:[1,1,1]})11}1213var cost = function(utterance) {14 return utterance == "not-two" ? 1 :15 utterance == 'none'? 1 :16 117}1819// possible world states20var states = [0,1,2,3,4];21var statePrior = function() {22 categorical({vs:[0,1,2,3,4],ps:[1,1,1,1,1]})23}2425// possible scopes26var scopePrior = function(){27 return categorical({vs:["surface", "inverse"],ps:[100,1]})28}2930// meaning function31var meaning = function(utterance, state, scope) {32 return utterance == "none"? state == 0:33 utterance == "not-two" ?34 scope == "surface" ? (state < 2):35 (state < 3) :36 true;37};3839// QUDs40var QUDs = ["how many?","all red?","none red?"];41var QUDPrior = function() {42 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})43 //uniformDraw(QUDs);44}45var QUDFun = function(QUD,state) {46 QUD == "all red?" ? state == 4 :47 QUD == "none red?" ? state == 0 :48 state;49};5051// Literal listener (L0)52var literalListener = cache(function(utterance,scope,QUD) {53 Infer({model: function(){54 var state = uniformDraw(states);55 var qState = QUDFun(QUD,state)56 condition(meaning(utterance,state,scope));57 return qState;58 }});59});6061var alpha = 16263// Speaker (S)64var speaker = cache(function(scope, state, QUD) {65 return Infer({model: function(){66 var utterance = utterancePrior()67 var qState = QUDFun(QUD, state)68 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)69 - cost(utterance)))70 return utterance71 }})72})7374// Pragmatic listener (L1)75var pragmaticListener = cache(function(utterance) {76 Infer({model: function(){77 var state = statePrior();78 var scope = scopePrior();79 var QUD = QUDPrior();80 observe(speaker(scope,state,QUD),utterance);81 return state82 }});83});8485// Pragmatic speaker (S2)86var pragmaticSpeaker = cache(function(state) {87 Infer({model: function(){88 var utterance = utterancePrior();89 factor(pragmaticListener(utterance).score(state))90 return utterance91 }})92})9394// A speaker decides whether to endorse the ambiguous utterance as a95// description of the not-all world state96//viz.table(pragmaticSpeaker(0))97viz.table(pragmaticSpeaker(2))98//viz.table(pragmaticSpeaker(2))99//literalListener("surface", 2, "all red?")100101var ANSWER = (pragmaticSpeaker(2));
1# RSA quantifier-scope model (Chinese Expt 2).2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration3# with level-unique site names; a finished level's marginal is fed into the next4# as a fixed pyro.factor score table. No inference inside an active trace.56utterances = ["null", "not-two", "none"]7states = [0, 1, 2, 3, 4]8scopes = ["surface", "inverse"]9QUDs = ["how many?", "all red?", "none red?"]10alpha = 1.01112state_logits = torch.zeros(len(states)) # ps:[1,1,1,1,1]13utt_logits = torch.zeros(len(utterances)) # ps:[1,1,1]14scope_logits = torch.log(torch.tensor([100.0, 1.0])) # ps:[100,1]15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]161718def cost(_u):19 return 1.0202122def meaning(utterance, state, scope):23 if utterance == "none":24 return state == 025 if utterance == "not-two":26 if scope == "surface":27 return state < 228 return state < 329 return True303132def qud_fun(qud, state):33 if qud == "all red?":34 return ("bool", state == 4)35 if qud == "none red?":36 return ("bool", state == 0)37 return ("state", state)383940def enum_marginal(model, site):41 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(42 model, lambda: None43 )[site]44 sup = marg.enumerate_support()45 probs = marg.log_prob(sup).exp()46 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}474849L0_cache = {}505152def literal_listener(utterance, scope, qud):53 key = (utterance, scope, qud)54 if key in L0_cache:55 return L0_cache[key]56 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)5758 @pyro.infer.config_enumerate59 def model():60 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))61 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))62 return s6364 state_post = enum_marginal(model, "L0_state")65 d = defaultdict(float)66 for si, p in state_post.items():67 d[qud_fun(qud, states[si])] += p68 L0_cache[key] = dict(d)69 return L0_cache[key]707172def l0_score(utterance, scope, qud, qstate):73 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)74 return math.log(p) if p > 0 else float("-inf")757677S_cache = {}787980def speaker(scope, state, qud):81 key = (scope, state, qud)82 if key in S_cache:83 return S_cache[key]84 qstate = qud_fun(qud, state)85 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])8687 @pyro.infer.config_enumerate88 def model():89 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))90 pyro.factor("S_f", sc[u])91 return u9293 post = enum_marginal(model, "S_utt")94 S_cache[key] = {utterances[i]: p for i, p in post.items()}95 return S_cache[key]969798def speaker_score(scope, state, qud, utterance):99 p = speaker(scope, state, qud).get(utterance, 0.0)100 return math.log(p) if p > 0 else float("-inf")101102103L1_cache = {}104105106def pragmatic_listener(utterance):107 if utterance in L1_cache:108 return L1_cache[utterance]109 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))110 for si, st in enumerate(states):111 for sci, scp in enumerate(scopes):112 for qi, qd in enumerate(QUDs):113 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)114115 @pyro.infer.config_enumerate116 def model():117 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))118 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))119 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))120 pyro.factor("L1_obs", tbl[s, sc, q])121 return s122123 post = enum_marginal(model, "L1_state")124 L1_cache[utterance] = {states[i]: p for i, p in post.items()}125 return L1_cache[utterance]126127128def pragmatic_listener_score(utterance, state):129 p = pragmatic_listener(utterance).get(state, 0.0)130 return math.log(p) if p > 0 else float("-inf")131132133def pragmatic_speaker(state):134 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])135136 @pyro.infer.config_enumerate137 def model():138 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))139 pyro.factor("S2_f", sc[u])140 return u141142 post = enum_marginal(model, "S2_utt")143 return {utterances[i]: p for i, p in post.items()}144145146ANSWER = pragmatic_speaker(2)147
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
There are 3 objects in the scene: one big blue object, one small blue object, and one big red object. Utterances are 'big', 'small', 'blue', 'red', 'big_blue', 'small_blue', and 'big_red', each equally likely a priori. The meaning function uses soft semantics parameterized by size_semvalue = 0.8 and color_semvalue = 0.99: a single size word applied to an object returns 0.8 if the size matches, else 0.2; a single color word applied to an object returns 0.99 if the color matches, else 0.01. A two-word utterance of the form SIZE_COLOR returns the product of the corresponding size and color semantic values. Cost is defined by two parameters: size_cost = 0 and color_cost = 0, giving all single-word utterances cost 0 and all two-word utterances cost 0. The literal listener scores each state by adding the semantic value directly as a log-weight (i.e., state s receives unnormalized log-score += meaning(utt, s), so its unnormalized probability is proportional to exp(meaning(utt, s)), not to meaning(utt, s) itself). The pragmatic speaker uses alpha = 30 and costWeight = 1: it factors by alpha times the literal listener's log-probability of the target state minus costWeight times the utterance cost.
A two-level RSA reference game model with relaxed (graded) semantics. The literal listener interprets utterances as soft evidence about the object and infers which object is being described. The pragmatic speaker selects utterances by soft-max over the utility of each utterance, where utility is the listener's log-probability of the intended object minus the utterance cost.
The distribution over utterances chosen by the pragmatic speaker when communicating the small blue object — {size: 'small', color: 'blue'}.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"big",
"small",
"blue",
"red",
"big_blue",
"small_blue",
"big_red"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var alpha = 302var costWeight = 13var size_semvalue = 0.84var color_semvalue = 0.995var size_cost = 06var color_cost = 078var states = [9 {size: "big", color: "blue"},10 {size: "small", color: "blue"},11 {size: "big", color: "red"}]1213var utterances = ["big", "small", "blue", "red", "big_blue", "small_blue", "big_red"]1415var colors = ["red", "blue"]16var sizes = ["big", "small"]1718var statePrior = function() {19 return uniformDraw(states)20};2122var utterancePrior = function() {23 return uniformDraw(utterances)24};2526// assumes that 2-word utterances consist of SIZE_COLOR, in that order27var meaning = function(utt, obj) {28 var splitWords = utt.split('_')29 if (splitWords.length == 1) {30 var word = splitWords[0]31 if(_.includes(colors, word))32 return word == obj.color ? color_semvalue : 1-color_semvalue;33 else if (_.includes(sizes, word))34 return word == obj.size ? size_semvalue : 1-size_semvalue;35 } else if (splitWords.length == 2) {36 var size_value = splitWords[0] == obj.size ? size_semvalue : 1-size_semvalue;37 var color_value = splitWords[1] == obj.color ? color_semvalue : 1-color_semvalue;38 return size_value*color_value39 } else40 console.error("bad utterance length: "+splitWords.length)41};4243var cost = {44 big: size_cost,45 small: size_cost,46 blue: color_cost,47 red: color_cost,48 big_blue: size_cost+color_cost,49 small_blue: size_cost+color_cost,50 big_red: size_cost+color_cost51}5253// literal listener54var literalListener = cache(function(utt) {55 return Infer({method:"enumerate"},56 function(){57 var state = statePrior()58 factor(meaning(utt,state))59 return state60 })61});6263// pragmatic speaker64var speaker = cache(function(state) {65 return Infer({method:"enumerate"},66 function(){67 var utt = utterancePrior()68 factor(alpha * literalListener(utt).score(state) - costWeight * cost[utt])69 return utt70 })71});727374display("speaker who wants to communicate big blue object:")75viz.table(speaker({size: "big", color: "blue"}))7677display("speaker who wants to communicate big red object:")78viz.table(speaker({size: "big", color: "red"}))7980display("speaker who wants to communicate small blue object:")81viz.table(speaker({size: "small", color: "blue"}))8283display("literal listener who observes 'big':")84viz.table(literalListener("big"))8586display("literal listener who observes 'small':")87viz.table(literalListener("small"))8889display("literal listener who observes 'blue':")90viz.table(literalListener("blue"))9192display("literal listener who observes 'red':")93viz.table(literalListener("red"))9495display("literal listener who observes 'big blue':")96viz.table(literalListener("big_blue"))9798display("literal listener who observes 'big red':")99viz.table(literalListener("big_red"))100101display("literal listener who observes 'small blue':")102viz.table(literalListener("small_blue"))103104var ANSWER = (speaker({size: 'small', color: 'blue'}));
1alpha = 30.02costWeight = 1.03size_semvalue = 0.84color_semvalue = 0.995size_cost = 0.06color_cost = 0.078states = [9 {"size": "big", "color": "blue"},10 {"size": "small", "color": "blue"},11 {"size": "big", "color": "red"},12]13utterances = ["big", "small", "blue", "red", "big_blue", "small_blue", "big_red"]14colors = ["red", "blue"]15sizes = ["big", "small"]1617def meaning(utt, obj):18 parts = utt.split("_")19 if len(parts) == 1:20 word = parts[0]21 if word in colors:22 return color_semvalue if word == obj["color"] else 1 - color_semvalue23 else:24 return size_semvalue if word == obj["size"] else 1 - size_semvalue25 else:26 size_value = size_semvalue if parts[0] == obj["size"] else 1 - size_semvalue27 color_value = color_semvalue if parts[1] == obj["color"] else 1 - color_semvalue28 return size_value * color_value2930cost = {31 "big": size_cost, "small": size_cost, "blue": color_cost, "red": color_cost,32 "big_blue": size_cost + color_cost, "small_blue": size_cost + color_cost,33 "big_red": size_cost + color_cost,34}3536# Literal listener: uniform prior over states, factor(meaning(utt, state)) added as37# a log-weight. Inference is exact enumeration via Pyro's compute_marginals.38_L0_cache = {}39def literalListener(utt):40 if utt in _L0_cache:41 return _L0_cache[utt]42 _mvals = torch.tensor([meaning(utt, s) for s in states])43 @pyro.infer.config_enumerate44 def m():45 idx = pyro.sample("state", dist.Categorical(torch.ones(len(states))))46 pyro.factor("sem", _mvals[idx])47 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)48 _L0_cache[utt] = marg["state"]49 return _L0_cache[utt]5051# Pragmatic speaker: utterance ~ exp(alpha * L0.score(state) - costWeight * cost),52# again by exact enumeration over utterances inverting the literal listener.53def speaker(state):54 sidx = states.index(state)55 L0s = [literalListener(u) for u in utterances]56 _scores = torch.tensor([57 alpha * L0s[i].log_prob(torch.tensor(sidx)).item() - costWeight * cost[utterances[i]]58 for i in range(len(utterances))59 ])60 @pyro.infer.config_enumerate61 def m():62 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances))))63 pyro.factor("util", _scores[uidx])64 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)65 return marg["utt"]6667_sp = speaker({"size": "small", "color": "blue"})68ANSWER = {utterances[i]: _sp.probs[i].item() for i in range(len(utterances))}69
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Three objects of reference: a blue square, a blue circle, and a green square. Four utterances: 'blue', 'green', 'square', 'circle'. Five preference contexts: 'blue_things', 'green_things', 'squares', 'circles', 'none'. Each preference context assigns categorical salience weights over the three objects (in the order blue square, blue circle, green square): blue_things [4, 4, 2], green_things [1, 1, 8], squares [4, 2, 4], circles [1, 8, 1], none [1, 1, 1]. Speaker optimality alpha = 1.
An object prior samples an object proportional to the salience weights for the given preference context and returns its string label. An utterance is literally true of an object if the utterance string is contained in the object's string label. A literal listener (L0) draws uniformly from the three objects (ignoring preference), conditions on the utterance being literally true, and returns the object string. A speaker draws uniformly from utterances and weights each by the literal listener's log-probability of the object under that utterance and preference context, scaled by alpha. A pragmatic listener samples an object from the preference-weighted object prior and updates on the speaker choosing the heard utterance.
The posterior distribution over object labels for a pragmatic listener who hears the utterance 'square' under the preference context 'blue_things'.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"blue square",
"blue circle",
"green square"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Frank and Goodman (2012) RSA model from problang.org23// set of states (here: objects of reference)4// we represent objects as JavaScript objects to demarcate them from utterances5// internally we treat objects as strings nonetheless6var objects = [{color: "blue", shape: "square", string: "blue square"},7 {color: "blue", shape: "circle", string: "blue circle"},8 {color: "green", shape: "square", string: "green square"}]910// set of utterances11var utterances = ["blue", "green", "square", "circle"]1213var preferences = ["blue_things", "green_things", "squares", "circles","none"]1415var preferenceTable = {16 blue_things : [4,4,2],17 green_things : [1,1,8],18 squares : [4,2,4],19 circles : [1,8,1],20 none : [1,1,1]21}2223var preferencePrior = function() {24 return uniformDraw(preferences)25}2627// prior over world states28var objectPrior = function(preference) {29 var obj = categorical(preferenceTable[preference],objects)30 return obj.string31}3233// meaning function to interpret the utterances34var meaning = function(utterance, obj){35 _.includes(obj, utterance)36}3738// literal listener39var literalListener = function(utterance,preference){40 Infer({model: function(){41 var obj = uniformDraw(objects).string // L0 has no preference42 condition(meaning(utterance, obj))43 return obj44 }})45}4647// set speaker optimality48var alpha = 14950// pragmatic speaker51var speaker = function(obj,preference){52 Infer({model: function(){53 var utterance = uniformDraw(utterances)54 factor(alpha * literalListener(utterance,preference).score(obj))55 return utterance56 }})57}5859// pragmatic listener60var pragmaticListener = function(utterance,preference){61 Infer({model: function(){62 var obj = objectPrior(preference)63 observe(speaker(obj,preference),utterance)64 return obj65 }})66}6768print("the listener hears 'square' and has a preference for blue things")69viz(pragmaticListener("square","blue_things"))7071print("the listener hears 'square' and has a preference for green things")72viz(pragmaticListener("square","green_things"))7374print("the listener hears 'square' and has a preference for squares")75viz(pragmaticListener("square","squares"))7677var ANSWER = (pragmaticListener('square', 'blue_things'));78
1# Frank & Goodman (2012) RSA with a preference-dependent object prior.2# Each RSA level is a separate, completely-finished Pyro discrete enumeration3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully4# computed and memoized BEFORE any higher level runs, so no inference ever runs5# inside another inference's active enumeration. Each level's single latent has6# a site name unique across levels (o0 / u1 / o2), avoiding site collisions.78objects = ["blue square", "blue circle", "green square"]9object_tokens = [["blue", "square"], ["blue", "circle"], ["green", "square"]]10utterances = ["blue", "green", "square", "circle"]11preferences = ["blue_things", "green_things", "squares", "circles", "none"]12preference_table = {13 "blue_things": [4, 4, 2],14 "green_things": [1, 1, 8],15 "squares": [4, 2, 4],16 "circles": [1, 8, 1],17 "none": [1, 1, 1],18}19alpha = 1.02021utt_logits = torch.zeros(len(utterances)) # speaker draws utterances uniformly22l0_state_logits = torch.zeros(len(objects)) # L0 object prior is uniform232425def meaning(utterance, obj_idx):26 return utterance in object_tokens[obj_idx]272829def marginal_dict(model, site, support):30 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(31 model, lambda: None32 )[site]33 sup = marg.enumerate_support()34 probs = marg.log_prob(sup).exp()35 out = {}36 for i in range(sup.shape[0]):37 out[support[int(sup[i].item())]] = float(probs[i].item())38 return out394041def logscore(d, key):42 p = d.get(key, 0.0)43 return math.log(p) if p > 0 else float("-inf")444546# ---- Literal listener L0: infer object given utterance (uniform object prior) ----47L0_cache = {}484950def literal_listener(utterance):51 if utterance in L0_cache:52 return L0_cache[utterance]5354 @pyro.infer.config_enumerate55 def model():56 o = pyro.sample("o0", dist.Categorical(logits=l0_state_logits))57 ok = torch.tensor([meaning(utterance, i) for i in range(len(objects))])58 pyro.factor("ev", torch.where(ok[o], torch.tensor(0.0), torch.tensor(float("-inf"))))5960 L0_cache[utterance] = marginal_dict(model, "o0", objects)61 return L0_cache[utterance]626364# ---- Speaker: infer utterance given object ----65S_cache = {}666768def speaker(obj):69 if obj in S_cache:70 return S_cache[obj]71 # speaker scores by L0.score(obj); pre-warm all L0 marginals first.72 sc = torch.tensor([alpha * logscore(literal_listener(u), obj) for u in utterances])7374 @pyro.infer.config_enumerate75 def model():76 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))77 pyro.factor("f", sc[u])7879 S_cache[obj] = marginal_dict(model, "u1", utterances)80 return S_cache[obj]818283# ---- Pragmatic listener: infer object given (utterance, preference) ----84def pragmatic_listener(utterance, preference):85 obj_logits = torch.log(torch.tensor([float(w) for w in preference_table[preference]]))86 # Fully compute & memoize every speaker marginal we will need BEFORE the87 # outer enumeration runs.88 sc = torch.tensor([logscore(speaker(objects[i]), utterance) for i in range(len(objects))])8990 @pyro.infer.config_enumerate91 def model():92 o = pyro.sample("o2", dist.Categorical(logits=obj_logits))93 pyro.factor("obs", sc[o])9495 return marginal_dict(model, "o2", objects)969798ANSWER = pragmatic_listener("square", "blue_things")99
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
The world state is the number of red apples on a table, drawn uniformly from {0, 1, 2, 3}. Three utterances drawn uniformly: 'all', 'some', 'none'. Literal meanings: 'all' is true iff state = 3; 'some' is true iff state > 0; 'none' is true iff state = 0. Two possible QUDs: 'all?' (true iff state = 3) and 'any?' (true iff state > 0). Speaker optimality alpha = 1.
A literal listener for a given utterance and QUD samples a world state uniformly, conditions on the utterance's literal meaning holding, and returns the QUD value (a boolean). A speaker for a given state and QUD draws an utterance uniformly and weights it by the literal listener's log-probability of the QUD value at that state, scaled by alpha. A pragmatic listener for a given utterance and QUD samples a world state uniformly and updates on the speaker choosing the heard utterance.
The posterior distribution over world states (number of red apples) for a pragmatic listener who hears the utterance 'some' under the QUD 'any?'.
answer spec
{
"kind": "dist",
"domain": "int"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// possible states of the world2var statePrior = function() {3 return uniformDraw([0, 1, 2, 3])4};56// possible utterances7var utterancePrior = function() {8 return uniformDraw(['all', 'some', 'none']);9};1011// possible quds12var quds = ['all?','any?']1314// prior over quds (only relevant for qud inference)15var qudPrior = function() {16 return uniformDraw(quds);17};1819// meaning funtion to interpret the utterances20var literalMeanings = {21 all: function(state) { return state === 3; },22 some: function(state) { return state > 0; },23 none: function(state) { return state === 0; }24};2526// projection function27var qudFn = function(qud, state) {28 var qudAdressed = qud === "all?" ? state === 3 : state > 029 return qudAdressed30}3132// literal listener33var literalListener = cache(function(utt,qud) {34 return Infer({model: function(){35 var state = statePrior()36 var meaning = literalMeanings[utt]37 condition(meaning(state))38 return qudFn(qud,state)39 }})40});4142// set speaker optimality43var alpha = 14445// pragmatic speaker46var speaker = cache(function(state,qud) {47 return Infer({model: function(){48 var utt = utterancePrior()49 factor(alpha * literalListener(utt,qud).score(qudFn(qud,state)))50 return utt51 }})52});5354// pragmatic listener55var pragmaticListener = cache(function(utt,qud) {56 return Infer({model: function(){57 var state = statePrior()58 observe(speaker(state,qud),utt)59 return state60 }})61});6263// print("pragmatic listener's interpretation of 'some':")64viz(pragmaticListener('some','any?'));6566var ANSWER = (pragmaticListener('some','any?'));67
1# Scalar-implicature RSA with QUDs (Goodman & Stuhlmuller style).2# Each RSA level is a separate, completely-finished Pyro discrete enumeration3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully4# computed and memoized BEFORE any higher level runs, so no inference ever runs5# inside another inference's active enumeration. Latent site names are unique6# across levels (s0 / u1 / s2), avoiding site collisions.78states = [0, 1, 2, 3]9utterances = ["all", "some", "none"]10quds = ["all?", "any?"]11alpha = 1.01213state_logits = torch.zeros(len(states)) # uniformDraw([0,1,2,3])14utt_logits = torch.zeros(len(utterances)) # uniformDraw(['all','some','none'])151617def literal_meaning(utt, state):18 if utt == "all":19 return state == 320 if utt == "some":21 return state > 022 return state == 0 # none232425def qud_fn(qud, state):26 if qud == "all?":27 return state == 328 return state > 0 # any?293031def marginal_dict(model, site, support):32 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(33 model, lambda: None34 )[site]35 sup = marg.enumerate_support()36 probs = marg.log_prob(sup).exp()37 out = {}38 for i in range(sup.shape[0]):39 out[support[int(sup[i].item())]] = float(probs[i].item())40 return out414243def logscore(d, key):44 p = d.get(key, 0.0)45 return math.log(p) if p > 0 else float("-inf")464748# ---- Literal listener L0: infer qudFn(qud,state) given (utt, qud) ----49L0_cache = {}505152def literal_listener(utt, qud):53 key = (utt, qud)54 if key in L0_cache:55 return L0_cache[key]56 qspace = [False, True] # support over the QUD projection value5758 @pyro.infer.config_enumerate59 def model():60 s = pyro.sample("s0", dist.Categorical(logits=state_logits))61 ok = torch.tensor([literal_meaning(utt, st) for st in states])62 pyro.factor("ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))63 qidx = torch.tensor([1 if qud_fn(qud, st) else 0 for st in states])64 # deterministic projection of the latent state to its QUD value,65 # realized as an enumerated site pinned to qudFn(qud,state).66 x = pyro.sample("qval", dist.Categorical(logits=torch.zeros(2)))67 pyro.factor("proj", torch.where(x == qidx[s], torch.tensor(0.0), torch.tensor(float("-inf"))))6869 L0_cache[key] = marginal_dict(model, "qval", qspace)70 return L0_cache[key]717273# ---- Speaker: infer utterance given (state, qud) ----74S_cache = {}757677def speaker(state, qud):78 key = (state, qud)79 if key in S_cache:80 return S_cache[key]81 qval = qud_fn(qud, state)82 # speaker scores by L0.score(qudFn(qud,state)); pre-warm L0 first.83 sc = torch.tensor([alpha * logscore(literal_listener(u, qud), qval) for u in utterances])8485 @pyro.infer.config_enumerate86 def model():87 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))88 pyro.factor("f", sc[u])8990 S_cache[key] = marginal_dict(model, "u1", utterances)91 return S_cache[key]929394# ---- Pragmatic listener: infer state given (utt, qud) ----95def pragmatic_listener(utt, qud):96 # Fully compute & memoize all speaker marginals we need BEFORE outer run.97 sc = torch.tensor([logscore(speaker(st, qud), utt) for st in states])9899 @pyro.infer.config_enumerate100 def model():101 s = pyro.sample("s2", dist.Categorical(logits=state_logits))102 pyro.factor("obs", sc[s])103104 return marginal_dict(model, "s2", states)105106107ANSWER = pragmatic_listener("some", "any?")108
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (w1) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Each trial draws nMarbles = 8 marbles. Confidence threshold = 0.6. The participant drew selfData = 4 red marbles. Four other agents reported: [{prediction: 'red', confidence: 'high'}, {prediction: 'red', confidence: 'high'}, {prediction: 'blue', confidence: 'low'}, {prediction: 'blue', confidence: 'low'}]. The prior over the true proportion of red marbles in the urn pRed is a discrete uniform over {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}.
Each other agent infers pRed from the same prior by updating on a Binomial(n=8, p=pRed) draw equal to their (latent) red count, then predicts 'red' with probability pRed and 'blue' with probability 1 - pRed. An agent's confidence is 'high' if (prediction = 'red' and pRed >= 0.6) or (prediction = 'blue' and pRed <= 0.4), and 'low' otherwise. The participant infers pRed by first conditioning on a Binomial(n=8, p=pRed) draw equal to selfData = 4, then incorporating each other agent's report by computing the expected log-likelihood of that report over the distribution of latent red counts the agent could have seen (Binomial(n=8, p=pRed)), where each latent count is evaluated under the agent's inference model.
The posterior distribution over pRed (the urn's true proportion of red marbles), computed by exact enumeration.
answer spec
{
"kind": "dist",
"domain": "real"
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// total number of marbles drawn from urn every time2var nMarbles = 8;3var threshold = .6;45// example data point for self6var selfData = 4;78// example data point for others9var otherData = [{prediction: 'red', confidence: 'high'},10 {prediction: 'red', confidence: 'high'},11 {prediction: 'blue', confidence: 'low'},12 {prediction: 'blue', confidence: 'low'}];1314// (discretized) uniform distribution over actual proportion of red in urn15var rednessPrior = Categorical({vs: [0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]});1617/*18 Generative model of other agents1920 Assumes they are *also* doing inference about actual proportion21 based on their data and responding according to their best guess...22 */23var otherOutput = cache(function(kRed, threshold) {24 return Infer({method: 'enumerate', model: function() {25 var pRed = sample(rednessPrior);26 var prediction = flip(pRed) ? 'red' : 'blue';27 var highConf = prediction === 'red' ? pRed >= threshold : pRed <= 1 - threshold;2829 observe(Binomial({p: pRed, n: nMarbles}), kRed);30 return {prediction: prediction, confidence: highConf ? 'high' : 'low'}31 }});32})3334/*35 Model of participant's inference on a given trial36*/37var trialModel = function() {38 // participant is trying to infer latent distribution in urn39 var pRed = sample(rednessPrior);4041 // first, take into account own data (i.e. a draw of balls from urn)42 observe(Binomial({p: pRed, n: nMarbles}), selfData);4344 // next, take into account social information45 // assume their sample was drawn from sample population but not sure of exact data46 mapData({data: otherData}, function(datum) {47 var kSeenPrior = Binomial({p: pRed, n : nMarbles});48 var likelihood = expectation(kSeenPrior, function(k) {49 return otherOutput(k, threshold).score(datum)50 })51 factor(likelihood);52 })5354 // ugly js convert to string for pretty plots55 return pRed;56}575859var ANSWER = (Infer({method: 'enumerate', model: trialModel}))
12# forestdb-schizophrenia-urns/atom-13# Participant infers the urn's red-proportion pRed from own draw + social reports.4# Both the participant's inference and each other agent's inference are exact5# enumeration over the discretized pRed prior.67nMarbles = 88threshold = 0.69selfData = 410otherData = [11 {"prediction": "red", "confidence": "high"},12 {"prediction": "red", "confidence": "high"},13 {"prediction": "blue", "confidence": "low"},14 {"prediction": "blue", "confidence": "low"},15]1617pRed_values = torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])18n_pRed = pRed_values.shape[0]192021def binom_logprob(k, n, p):22 # log P(X = k) for X ~ Binomial(n, p); a single scalar tensor23 return dist.Binomial(total_count=n, probs=p).log_prob(torch.tensor(float(k)))242526# --- Other agent's generative model (exact enumeration over pRed) ----------27# Given the agent saw kRed reds, the agent's posterior over pRed is obtained by28# enumerating pRed and observing the agent's own draw. The agent then reports a29# prediction (red w.p. pRed) with a deterministic confidence. We read the30# induced report distribution off the enumerated pRed posterior.31def other_output(kRed):32 @pyro.infer.config_enumerate33 def agent_model():34 idx = pyro.sample("pRed", dist.Categorical(torch.ones(n_pRed) / n_pRed))35 p = pRed_values[idx]36 pyro.factor("agent_data", binom_logprob(kRed, nMarbles, p))37 return idx3839 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(40 agent_model, lambda: None41 )42 probs_p = marg["pRed"].probs # agent's posterior over pRed indices4344 out = {}45 for i in range(n_pRed):46 p = pRed_values[i].item()47 wp = probs_p[i].item()48 for pred_name, pred_p in (("red", p), ("blue", 1.0 - p)):49 if pred_name == "red":50 high = p >= threshold51 else:52 high = p <= 1.0 - threshold53 conf = "high" if high else "low"54 key = (pred_name, conf)55 out[key] = out.get(key, 0.0) + wp * pred_p56 return out575859# Cache agent outputs for every possible kRed (0..nMarbles)60_other_cache = {k: other_output(k) for k in range(nMarbles + 1)}616263# --- Participant's inference (exact enumeration over pRed) -----------------64@pyro.infer.config_enumerate65def trial_model():66 idx = pyro.sample("pRed", dist.Categorical(torch.ones(n_pRed) / n_pRed))67 p = pRed_values[idx]6869 # own data70 pyro.factor("self_data", binom_logprob(selfData, nMarbles, p))7172 # social data: expected log-likelihood of each report over the agent's73 # latent red count k ~ Binomial(nMarbles, p)74 for j, datum in enumerate(otherData):75 key = (datum["prediction"], datum["confidence"])76 exp_log_lik = torch.tensor(0.0)77 for k in range(nMarbles + 1):78 p_k = binom_logprob(k, nMarbles, p).exp()79 p_datum = _other_cache[k].get(key, 1e-300)80 exp_log_lik = exp_log_lik + p_k * torch.log(torch.tensor(p_datum))81 pyro.factor(f"social_{j}", exp_log_lik)8283 return idx848586marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(87 trial_model, lambda: None88)89pRed_post = marg["pRed"].probs9091ANSWER = {round(pRed_values[i].item(), 1): pRed_post[i].item() for i in range(n_pRed)}92
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 1/2 solvers · d=[0.000, 0.228] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
The sentence template is 'John hit Fred and Ellen hit ___.' Possible utterances to fill the blank are: 'him' (ambiguous pronoun), 'Fred' (unambiguous), and 'John' (unambiguous), with a categorical prior over utterances weighted [2, 1, 1] respectively. The referent space is {John, Fred}, drawn with equal probability. Possible resolution strategies are {Subject, Parallel}, drawn with equal probability. The meaning of an utterance given a referent and a strategy is: if the utterance is 'him', the Subject strategy resolves it to John and the Parallel strategy resolves it to Fred; if the utterance is a proper name, it is true only when the referent matches the name exactly.
A literal listener interprets an utterance relative to a fixed strategy: starting from the uniform referent prior, it conditions on the utterance being true of the referent under that strategy. A speaker who intends a referent under a given strategy chooses an utterance with probability proportional to the utterance prior times the literal listener's probability of assigning the intended referent to that utterance under that strategy. A pragmatic listener who hears an utterance weights each (referent, strategy) pair by its prior probability times the speaker's probability of producing that utterance for that referent under that strategy, and marginalizes over strategies.
The posterior distribution over referents (John or Fred) for a pragmatic listener who hears 'him'.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"John",
"Fred"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1///fold:2// possible utterances: Ambiguous, Unambiguous Fred, Unambiguous John3var utterances = ["him", "Fred", "John"]45// samples the utterances that fill in the following template:6// "John hit Fred and Ellen hit (the utterance)"78var utterancePrior = function() {9 categorical([2,1,1], utterances)10}1112//possible world states of who Ellen hit13var states = ["John", "Fred"]14var statePrior = function() {15 return uniformDraw(states)16}1718//possible strategies19var strategyPrior = function(){20 return uniformDraw(["Subject", "Parallel"])21}2223//meaning function24var meaning = function(utterance, state, strategy){25 return utterance == "him" ?26 strategy == "Subject" ? state == "John" :27 state == "Fred" :28 utterance == state29}3031// Literal Listener (L0)32var literalListener = cache(function(utterence, strategy){33 return Infer({model: function(){34 var state = statePrior()35 condition(meaning(utterence, state, strategy))36 return state37 }})38})3940// Speaker (S)41var speaker = cache(function(strategy, state){42 return Infer({model: function() {43 var utterance = utterancePrior()44 observe(literalListener(utterance, strategy), state)45 return utterance46 }})47})48///49// Pragmatic listener (L1)50var pragmaticListener = cache(function(utterance) {51 return Infer({model: function(){52 var state = statePrior()53 var strategy = strategyPrior()54 observe(speaker(strategy,state),utterance)55 return state56 }})57})58pragmaticListener("him")5960var ANSWER = (pragmaticListener('him'));61
1# Singh & Uyeda pronoun-resolution RSA.2# Each RSA level is a separate, completely-finished Pyro discrete enumeration3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully4# computed and memoized BEFORE any higher level runs, so no inference runs inside5# another's active enumeration. Latent site names are unique across levels6# (s0 / u1 / s2,st2), avoiding site collisions.7# Note: the Speaker here does observe(literalListener(utterance,strategy), state),8# i.e. it factors by L0.score(state) directly, matching the WebPPL reference.910utterances = ["him", "Fred", "John"]11states = ["John", "Fred"]12strategies = ["Subject", "Parallel"]1314utt_logits = torch.log(torch.tensor([2.0, 1.0, 1.0])) # categorical([2,1,1], utterances)15state_logits = torch.zeros(len(states)) # uniformDraw16strategy_logits = torch.zeros(len(strategies)) # uniformDraw171819def meaning(utterance, state, strategy):20 if utterance == "him":21 if strategy == "Subject":22 return state == "John"23 return state == "Fred"24 return utterance == state252627def marginal_dict(model, site, support):28 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(29 model, lambda: None30 )[site]31 sup = marg.enumerate_support()32 probs = marg.log_prob(sup).exp()33 out = {}34 for i in range(sup.shape[0]):35 out[support[int(sup[i].item())]] = float(probs[i].item())36 return out373839def logscore(d, key):40 p = d.get(key, 0.0)41 return math.log(p) if p > 0 else float("-inf")424344# ---- Literal listener L0: infer state given (utterance, strategy) ----45L0_cache = {}464748def literal_listener(utterance, strategy):49 key = (utterance, strategy)50 if key in L0_cache:51 return L0_cache[key]5253 @pyro.infer.config_enumerate54 def model():55 s = pyro.sample("s0", dist.Categorical(logits=state_logits))56 ok = torch.tensor([meaning(utterance, st, strategy) for st in states])57 pyro.factor("ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))5859 L0_cache[key] = marginal_dict(model, "s0", states)60 return L0_cache[key]616263# ---- Speaker S: infer utterance given (strategy, state); observes L0 at state ----64S_cache = {}656667def speaker(strategy, state):68 key = (strategy, state)69 if key in S_cache:70 return S_cache[key]71 # observe(literalListener(utterance,strategy), state) -> factor L0.score(state)72 sc = torch.tensor([logscore(literal_listener(u, strategy), state) for u in utterances])7374 @pyro.infer.config_enumerate75 def model():76 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))77 pyro.factor("obs", sc[u])7879 S_cache[key] = marginal_dict(model, "u1", utterances)80 return S_cache[key]818283# ---- Pragmatic listener L1: infer state given utterance ----84def pragmatic_listener(utterance):85 # Fully compute & memoize the whole speaker table BEFORE outer run.86 tbl = torch.full((len(states), len(strategies)), float("-inf"))87 for si, sstate in enumerate(states):88 for sti, strat in enumerate(strategies):89 tbl[si, sti] = logscore(speaker(strat, sstate), utterance)9091 @pyro.infer.config_enumerate92 def model():93 s = pyro.sample("s2", dist.Categorical(logits=state_logits))94 st = pyro.sample("st2", dist.Categorical(logits=strategy_logits))95 pyro.factor("obs", tbl[s, st])9697 return marginal_dict(model, "s2", states)9899100ANSWER = pragmatic_listener("him")101
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Four personae: {name: 'stern', competence: true, friendliness: false}, {name: 'cool', competence: true, friendliness: true}, {name: 'asshole', competence: false, friendliness: false}, {name: 'doofus', competence: false, friendliness: true}. The prior over personae for a voter audience assigns probabilities [0.3, 0.2, 0.3, 0.2] to [stern, cool, asshole, doofus] respectively. Two morphological variants: 'n' and 'ng', both with production cost 0. Speaker optimality alpha = 6. Eckert-Montague semantics: 'ng' is consistent with a persona if that persona has competence = true or friendliness = false; 'n' is consistent with a persona if that persona has competence = false or friendliness = true.
A literal listener for a variant samples from the voter persona prior, conditions on the variant being semantically consistent with the persona, and returns the persona. A speaker for a given persona draws uniformly from the two variants and weights each by alpha times (log-probability the literal listener assigns to that persona given the variant, minus the variant's cost). A naive pragmatic listener for a variant samples a persona from the voter prior and updates on the speaker choosing that variant.
The posterior distribution over persona names for a naive pragmatic listener who hears the variant 'ng'.
answer spec
{
"kind": "dist",
"domain": "finite",
"support": [
"stern",
"cool",
"asshole",
"doofus"
]
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1// Personae in (6)2var personae = [{name: "stern", competence : true, friendliness : false},3 {name: "cool", competence : true, friendliness : true},4 {name: "asshole", competence : false, friendliness : false},5 {name: "doofus", competence : false, friendliness: true}]67// Table 28var voterPrior = Categorical({ps: [0.3,0.2,0.3,0.2], vs: personae})910// Table 511var journalistPrior = Categorical({ps: [0.2,0.2,0.3,0.3], vs: personae})1213var personaePrior = voterPrior14// var personaePrior = journalistPrior1516var variants = ["n","ng"]1718var cost = {19 n : 0,20 ng : 021}2223// Eckert-Montague Semantics (Table 1)24var semantics = {25 ng: function(persona) { return (persona.competence == true | persona.friendliness == false) ; },26 n: function(persona) { return (persona.competence == false | persona.friendliness == true); },27}2829// Definition in (11) - the 'literal listener'3031var conditionalization = function(variant) {32 return Infer({model: function(){33 var persona = sample(personaePrior)34 var meaning = semantics[variant]35 condition(meaning(persona))36 return persona37 }}38)}3940// Definition in (12)4142var utility = function(persona, variant) {4344 var informativity = conditionalization(variant).score(persona)45 return(informativity - cost[variant])4647}4849var alpha = 65051// Definition in (13) - soft-max choice rule5253var speaker = function(persona) {54 return Infer(function() {55 var variant = uniformDraw(variants)56 factor(alpha * utility(persona,variant))57 return(variant)58 })59}6061// Table 6: persona selection function (the value system)6263var mu = function(persona) {6465 persona.name == "cool" ? 2 :66 persona.name == "stern" ? 1 :67 persona.name == "doofus" ? 1 :68 persona.name == "asshole" ? 0 :69 07071}7273// Definition in (14): probability distribution over personae7475var alphaprime = 67677var personaDistribution = Infer(78 function() {79 var persona = sample(personaePrior)80 factor(alphaprime * mu(persona))81 return persona82 })8384// Definition in (15): speaker with a value system8586var valueSpeaker = function(variant) {8788 // Array of utilities of variant for each persona, times probability of the persona89 var variantUtility = map(function(persona) {90 return Math.exp(personaDistribution.score(persona)) * Math.exp(speaker(persona).score(variant))91 }, personae)9293 return sum(variantUtility)9495}9697// Definition in (17): Listening with certainty about speaker's values9899var valueInformedListener = function(variant) {100 return Infer(function(){101102 var persona = sample(personaePrior)103 factor(Math.exp(personaDistribution.score(persona)) * Math.exp(speaker(persona).score(variant)))104 return persona.name105106 })107}108109// Definition in (18): Naive listening110111var naiveListener = function(variant) {112 return Infer(function(){113114 var persona = sample(personaePrior)115 factor(speaker(persona).score(variant))116 return persona.name117 })118}119120print("Literal L's beliefs immediately after hearing -n at the barbecue")121viz.table(conditionalization('n'))122123print("Literal L’s beliefs immediately after hearing -ng at the barbecue")124viz.table(conditionalization('ng'))125126print("Obama wants to be the cool guy")127viz.table(speaker(personae[1]))128129// set prior to journalistPrior above130print("Obama's overall probability of using -ng with the journalist")131132print(valueSpeaker('ng'))133134print("Obama's overall probability of using -n with the journalist")135136print(valueSpeaker('n'))137138print("Hearing Obama use -n (and you have a sense of Obama's values)")139140viz(valueInformedListener('n'))141142print("Hearing Obama use -n (and you're naive as to Obama's values)")143144viz(naiveListener('n'))145146var ANSWER = naiveListener('ng');147
1# RSA-style social-meaning model (Eckert-Montague semantics).2# Literal listener -> speaker (softmax) -> naive pragmatic listener, all by exact3# enumeration. Each level's posterior is obtained through Pyro inference and its4# log-probabilities (WebPPL's .score) feed the next level.56personae = [7 {"name": "stern", "competence": True, "friendliness": False},8 {"name": "cool", "competence": True, "friendliness": True},9 {"name": "asshole", "competence": False, "friendliness": False},10 {"name": "doofus", "competence": False, "friendliness": True},11]12prior_ps = torch.tensor([0.3, 0.2, 0.3, 0.2])13variants = ["n", "ng"]14cost = {"n": 0.0, "ng": 0.0}15alpha = 6.0161718def semantics(variant, persona):19 if variant == "ng":20 return persona["competence"] is True or persona["friendliness"] is False21 else: # "n"22 return persona["competence"] is False or persona["friendliness"] is True232425# --- Literal listener (definition 11): condition the persona prior on the26# --- variant being semantically consistent; returns a distribution over personae.27def literal_listener(variant):28 @pyro.infer.config_enumerate29 def model():30 idx = pyro.sample("persona", dist.Categorical(prior_ps))31 consistent = torch.tensor(32 [1.0 if semantics(variant, p) else 0.0 for p in personae]33 )[idx]34 logw = torch.where(35 consistent > 0, torch.tensor(0.0), torch.tensor(float("-inf"))36 )37 pyro.factor("sem", logw)38 return idx3940 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(41 model, lambda: None42 )43 return marg["persona"]444546LL = {v: literal_listener(v) for v in variants}47idx_support = torch.arange(len(personae))484950def literal_score(variant, persona_idx):51 return LL[variant].log_prob(torch.tensor(persona_idx))525354def utility(persona_idx, variant):55 return literal_score(variant, persona_idx) - cost[variant]565758# --- Speaker (definition 13): uniform over variants, factor by alpha*utility;59# --- returns a distribution over variants for a given persona.60def speaker(persona_idx):61 @pyro.infer.config_enumerate62 def model():63 v = pyro.sample("variant", dist.Categorical(torch.tensor([0.5, 0.5])))64 u = torch.stack(65 [alpha * utility(persona_idx, variants[j]) for j in range(len(variants))]66 )[v]67 pyro.factor("u", u)68 return v6970 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(71 model, lambda: None72 )73 return marg["variant"]747576SPK = [speaker(i) for i in range(len(personae))]777879def speaker_score(persona_idx, variant):80 v_idx = variants.index(variant)81 return SPK[persona_idx].log_prob(torch.tensor(v_idx))828384# --- Naive pragmatic listener (definition 18): sample persona from prior, factor85# --- by speaker(persona).score(variant), return the persona name distribution.86def naive_listener(variant):87 @pyro.infer.config_enumerate88 def model():89 idx = pyro.sample("persona", dist.Categorical(prior_ps))90 s = torch.stack(91 [speaker_score(i, variant) for i in range(len(personae))]92 )[idx]93 pyro.factor("spk", s)94 return idx9596 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(97 model, lambda: None98 )99 return marg["persona"]100101102final = naive_listener("ng")103probs = torch.exp(final.log_prob(idx_support))104ANSWER = {personae[i]["name"]: probs[i].item() for i in range(len(personae))}105
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (tv) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |
Two items, watch and sweater, each with a prior distribution over prices and a theta prior (uniform over the item's price bins). Sweater price bins and probabilities: | price | probability | |---|---| | 1.5 | 0.00482838499944466 | | 4.5 | 0.00832934578733181 | | 7.5 | 0.0112952500492109 | | 10.5 | 0.0173774790108894 | | 13.5 | 0.0232006658974883 | | 16.5 | 0.0258422772579257 | Watch price bins: 60 bins from 25 to 2975 in steps of 50, with probabilities [0.040844560268751, 0.0587099798246933, 0.0656194599591356, 0.0667642412698035, 0.0615953803048016, 0.0510809063784378, 0.0467203673419258, 0.0446735950187136, 0.040047421916613, 0.0350583957334483, 0.0297508215717606, 0.0256829651118227, 0.024135920250668, 0.0228891907259206, 0.021706684520276, 0.0186449440066946, 0.0187249266247728, 0.0179250744798993, 0.0173698811746238, 0.0165581725818319, 0.0160745066032247, 0.0127927305129066, 0.0113730680265067, 0.0109485307623827, 0.00923468422650943, 0.00899007751887508, 0.00880520147998275, 0.00838023585866885, 0.00841052411004918, 0.00828830635037619, 0.00834008093757411, 0.00750681534099784, 0.00724072133740109, 0.00717291664158004, 0.00682823777708754, 0.00646995193940331, 0.00697139732982518, 0.00711846547272734, 0.00698781312802354, 0.00732316558583701, 0.00594973158122097, 0.00557461443747403, 0.00541637601910211, 0.00518850469148531, 0.00572025848989677, 0.0051443557601358, 0.00510282169734075, 0.00493720252580643, 0.00560198932991028, 0.00519158715054485, 0.00473398797752786, 0.00540907722833213, 0.00494653421540979, 0.00495500420164643, 0.00494083025189895, 0.00481566268206312, 0.00442965937328148, 0.00441189688100535, 0.00415116538135834, 0.00361842012002631]. Two utterances: 'expensive' and 'not-inexpensive'. Utterance costs: 'expensive' costs 1, 'not-inexpensive' costs 2. Speaker optimality alpha = 2. Soft semantics: 'expensive' is true of a price with probability 0.9999 if price > theta_expensive, and 0.0001 otherwise; 'not-inexpensive' is true with probability 0.9999 if price >= theta_inexpensive, and 0.0001 otherwise. The threshold theta_expensive is drawn uniformly from the item's price bins. With probability 0.2, theta_inexpensive is drawn independently from the item's price bins; with probability 0.8, theta_inexpensive equals theta_expensive.
A literal listener for a given utterance and pair of thresholds draws a price uniformly from the item's bins, conditions on the soft semantic being true, and returns the price. A speaker for a given price and threshold pair draws an utterance uniformly and weights it by alpha times (the literal listener's log-probability of the price minus the utterance's cost). A pragmatic listener for a given utterance jointly samples a price from the item's categorical prior, a theta_expensive threshold, and a theta_inexpensive threshold (correlated with theta_expensive with probability 0.8), then updates on the speaker choosing the heard utterance.
Two scalar expectations: the expected price inferred by a pragmatic listener hearing 'expensive' for the sweater item, and the expected price inferred hearing 'not-inexpensive' for the sweater item. Return as a record with fields expensivePrice and notInexpensivePrice.
answer spec
{
"kind": "record",
"fields": {
"expensivePrice": {
"kind": "value",
"domain": "real"
},
"notInexpensivePrice": {
"kind": "value",
"domain": "real"
}
}
}system prompt
(system prompt loads here)
webppl primer
(primer loads here)
1var watch = {2 "prices": [25, 75, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575, 625, 675, 725, 775, 825, 875, 925, 975, 1025, 1075, 1125, 1175, 1225, 1275, 1325, 1375, 1425, 1475, 1525, 1575, 1625, 1675, 1725, 1775, 1825, 1875, 1925, 1975, 2025, 2075, 2125, 2175, 2225, 2275, 2325, 2375, 2425, 2475, 2525, 2575, 2625, 2675, 2725, 2775, 2825, 2875, 2925, 2975],3 "probabilities": [0.040844560268751, 0.0587099798246933, 0.0656194599591356, 0.0667642412698035, 0.0615953803048016, 0.0510809063784378, 0.0467203673419258, 0.0446735950187136, 0.040047421916613, 0.0350583957334483, 0.0297508215717606, 0.0256829651118227, 0.024135920250668, 0.0228891907259206, 0.021706684520276, 0.0186449440066946, 0.0187249266247728, 0.0179250744798993, 0.0173698811746238, 0.0165581725818319, 0.0160745066032247, 0.0127927305129066, 0.0113730680265067, 0.0109485307623827, 0.00923468422650943, 0.00899007751887508, 0.00880520147998275, 0.00838023585866885, 0.00841052411004918, 0.00828830635037619, 0.00834008093757411, 0.00750681534099784, 0.00724072133740109, 0.00717291664158004, 0.00682823777708754, 0.00646995193940331, 0.00697139732982518, 0.00711846547272734, 0.00698781312802354, 0.00732316558583701, 0.00594973158122097, 0.00557461443747403, 0.00541637601910211, 0.00518850469148531, 0.00572025848989677, 0.0051443557601358, 0.00510282169734075, 0.00493720252580643, 0.00560198932991028, 0.00519158715054485, 0.00473398797752786, 0.00540907722833213, 0.00494653421540979, 0.00495500420164643, 0.00494083025189895, 0.00481566268206312, 0.00442965937328148, 0.00441189688100535, 0.00415116538135834, 0.00361842012002631]4};5var sweater = {6 "prices": [1.5, 4.5, 7.5, 10.5, 13.5, 16.5],7 "probabilities": [0.00482838499944466, 0.00832934578733181, 0.0112952500492109, 0.0173774790108894, 0.0232006658974883, 0.0258422772579257]8};9var data = {10 "watch": watch,11 "sweater": sweater12};1314var prior = function(item) {15 var prices = data[item].prices;16 var probabilities = data[item].probabilities;17 return function() {18 return categorical(probabilities, prices);19 };20};2122var theta_prior = function(item) {23 var thetas = data[item].prices;24 return function() {25 return uniformDraw(thetas) ;26 };27};2829var alpha = 2; // optimality parameter303132var utterances = ["expensive","not-inexpensive"];3334var cost = {35 "not-inexpensive":2,36 "expensive": 1,37};38var utterancePrior = function() {39 return uniformDraw(utterances);40};4142var meaning = function(utterance, price, theta) {43 utterance == "expensive" ? price > theta.expensive ? flip(0.9999) : flip(0.0001) :44 utterance == "not-inexpensive" ? !(price < theta.inexpensive)? flip(0.9999) : flip(0.0001):45 true;46};4748var literalListener = cache(function(utterance, theta, item) {49 return Infer({method: "enumerate"}, function() {50 var price = uniformDraw(data[item].prices)51 condition(meaning(utterance, price, theta))52 return price;53 });54});5556var speaker = cache(function(price, theta, item) {57 return Infer({method: "enumerate"}, function() {58 var utterance = utterancePrior();59 factor( alpha * (literalListener(utterance, theta, item).score(price)60 - cost[utterance]));61 return utterance;62 });63});6465var pragmaticListener = function(utterance, item) {66 // first identify the relevant priors67 var pricePrior = prior(item);68 var thetaPrior = theta_prior(item);69 // then run inference70 return Infer({method: "enumerate"},71 function() {72 var an_neg_thre = flip(0.2)73 var expensive_theta= thetaPrior()7475 var inexpensive_threshold = an_neg_thre ?76 thetaPrior() :77 expensive_theta;78 var price = pricePrior();79 var theta = {80 expensive: expensive_theta,81 inexpensive: inexpensive_threshold82 }83 var Posexp = theta.expensive84 var Posneg = theta.inexpensive85 factor( speaker(price, theta, item).score(utterance) );86 return { price: price, Posexp: Posexp , Posneg: Posneg };87 });88};899091var expensiveSweater= pragmaticListener("expensive", "sweater");92print("Expensive:Prices")93viz.density(marginalize(expensiveSweater, "price"));94display(expectation(marginalize(expensiveSweater, "price")))95print("Expensive:Thresholds:")96viz.density(marginalize(expensiveSweater, "Posexp"));97display(expectation(marginalize(expensiveSweater, "Posexp")))98var notinexpensiveSweater= pragmaticListener("not-inexpensive", "sweater");99print("NOT-Inexpensive : Prices")100viz.density(marginalize(notinexpensiveSweater, "price"));101display(expectation(marginalize(notinexpensiveSweater, "price")))102print("NOT-Inexpensive:Thresholds:")103viz.density(marginalize(notinexpensiveSweater, "Posneg"));104display(expectation(marginalize(notinexpensiveSweater, "Posneg")))105106var ANSWER = ({ expensivePrice: expectation(marginalize(pragmaticListener("expensive","sweater"), "price")), notInexpensivePrice: expectation(marginalize(pragmaticListener("not-inexpensive","sweater"), "price")) });107
1# RSA antonyms model (Zhu et al.) for the 'sweater' item.2# Every level of inference (literal listener, speaker, pragmatic listener) is run3# through Pyro's exact discrete enumeration (config_enumerate + TraceEnum_ELBO4# .compute_marginals). The answer is the expectation over the price posterior of5# the pragmatic listener for two utterances.67sweater_prices = [1.5, 4.5, 7.5, 10.5, 13.5, 16.5]8sweater_probs = [0.00482838499944466, 0.00832934578733181, 0.0112952500492109,9 0.0173774790108894, 0.0232006658974883, 0.0258422772579257]10# normalize the (sub-normalized) price prior, as WebPPL's categorical does11_z = sum(sweater_probs)12price_prior_probs = torch.tensor([p / _z for p in sweater_probs])13prices = torch.tensor(sweater_prices)14n_prices = len(sweater_prices)1516alpha = 2.017utterances = ["expensive", "not-inexpensive"]18cost = {"expensive": 1.0, "not-inexpensive": 2.0}1920NEG_INF = torch.tensor(float("-inf"))21ZERO = torch.tensor(0.0)222324def meaning_holds(utterance, price_idx, theta_exp_idx, theta_inexp_idx):25 # returns a boolean tensor: does the utterance's literal meaning hold for this26 # price under the given thresholds? prices are indexed; theta thresholds are27 # also one of the prices (uniformDraw over prices).28 price_val = prices[price_idx]29 if utterance == "expensive":30 return price_val > prices[theta_exp_idx]31 else: # not-inexpensive: NOT (price < inexpensive threshold)32 return ~(price_val < prices[theta_inexp_idx])333435def literal_listener_logprobs(utterance, theta_exp_idx, theta_inexp_idx):36 # exact enumeration over price (uniform prior over prices, conditioned on37 # meaning) via a Pyro enumerated model. Returns log-probs over price support.38 @pyro.infer.config_enumerate39 def model():40 price_idx = pyro.sample("price", dist.Categorical(torch.ones(n_prices) / n_prices))41 holds = meaning_holds(utterance, price_idx, theta_exp_idx, theta_inexp_idx)42 # soft meaning: flip(0.9999) if holds else flip(0.0001), then condition true43 logp_true = torch.where(holds, torch.log(torch.tensor(0.9999)),44 torch.log(torch.tensor(0.0001)))45 pyro.factor("meaning", logp_true)46 return price_idx47 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)48 m = marg["price"]49 sup = m.enumerate_support()50 logps = m.log_prob(sup)51 # index aligned to 0..n_prices-152 out = torch.full((n_prices,), float("-inf"))53 for s, lp in zip(sup.tolist(), logps.tolist()):54 out[int(s)] = lp55 return out565758def speaker_logprobs(price_idx, theta_exp_idx, theta_inexp_idx):59 # exact enumeration over utterance choice via Pyro. Returns log-probs over the60 # two utterances.61 ll_scores = {}62 for u in utterances:63 ll_lp = literal_listener_logprobs(u, theta_exp_idx, theta_inexp_idx)64 ll_scores[u] = ll_lp[price_idx]6566 @pyro.infer.config_enumerate67 def model():68 u_idx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))69 # factor( alpha * (literalListener.score(price) - cost[utterance]) )70 score_vals = torch.tensor([alpha * (ll_scores[u] - cost[u]) for u in utterances])71 pyro.factor("util", score_vals[u_idx])72 return u_idx73 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)74 m = marg["utt"]75 sup = m.enumerate_support()76 logps = m.log_prob(sup)77 out = torch.full((len(utterances),), float("-inf"))78 for s, lp in zip(sup.tolist(), logps.tolist()):79 out[int(s)] = lp80 return out818283# Memoize the speaker log-score (per utterance) over the unique (price, exp-theta,84# inexp-theta) configurations it depends on; each entry comes from a genuine Pyro85# enumeration of the speaker (which itself drives the literal-listener enumeration).86_spk_cache = {}87def speaker_logprobs_cached(price_idx, theta_exp_idx, theta_inexp_idx):88 key = (price_idx, theta_exp_idx, theta_inexp_idx)89 if key not in _spk_cache:90 _spk_cache[key] = speaker_logprobs(price_idx, theta_exp_idx, theta_inexp_idx)91 return _spk_cache[key]929394def pragmatic_listener_price_expectation(utterance):95 u_target = utterances.index(utterance)96 # The pragmatic listener jointly samples: an_neg ~ flip(0.2), expensive_theta ~97 # uniformDraw(prices), a fresh threshold ~ uniformDraw(prices), and price ~98 # pricePrior; inexpensive_threshold = fresh if an_neg else expensive_theta. The99 # heard utterance is observed by factoring the speaker's log-score. This whole100 # joint is run through Pyro's exact discrete enumeration (config_enumerate +101 # compute_marginals); the speaker score enters via a precomputed per-config102 # table (each entry itself the result of a Pyro speaker inference).103 #104 # Build the speaker-score table over the joint (an_neg, exp_theta, fresh, price).105 # inexp_theta = fresh if an_neg==1 else exp_theta. Index it inside the model.106 spk_table = torch.full((2, n_prices, n_prices, n_prices), float("-inf"))107 for an_neg in (0, 1):108 for exp_theta in range(n_prices):109 for fresh in range(n_prices):110 inexp_theta = fresh if an_neg == 1 else exp_theta111 for price in range(n_prices):112 sp_lp = speaker_logprobs_cached(price, exp_theta, inexp_theta)113 spk_table[an_neg, exp_theta, fresh, price] = sp_lp[u_target]114115 @pyro.infer.config_enumerate116 def prag_model():117 an_neg = pyro.sample("an_neg", dist.Categorical(torch.tensor([0.8, 0.2])))118 exp_theta = pyro.sample("exp_theta", dist.Categorical(torch.ones(n_prices) / n_prices))119 fresh = pyro.sample("fresh", dist.Categorical(torch.ones(n_prices) / n_prices))120 price = pyro.sample("price", dist.Categorical(price_prior_probs))121 score = pyro.ops.indexing.Vindex(spk_table)[an_neg, exp_theta, fresh, price]122 pyro.factor("obs", score)123 return price124125 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(prag_model, lambda: None)126 m = marg["price"]127 sup = m.enumerate_support()128 probs = m.log_prob(sup).exp()129 return float(sum(prices[int(s)].item() * float(pr) for s, pr in zip(sup, probs)))130131132ANSWER = {133 "expensivePrice": pragmatic_listener_price_expectation("expensive"),134 "notInexpensivePrice": pragmatic_listener_price_expectation("not-inexpensive"),135}136
| check | status | evidence |
|---|---|---|
| GT self-consistency | ok | floor 0.0000 (record) |
| solver re-derivation | accept | 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6 |
| cross-language (pyro vs webppl) | pass | d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000 |