problems 29 / 29 columns webppl · pyro
methodology ↑↓jk nav · / search
Select a problem from the list →
forestdb-2025-problang-adjectives-qud / atom-1
answer dist/finite solver accept pyro pass 0.0080
00 statement source: data/sources/forestdb.org/models/2025-problang-adjectives-qud.md
given

Price grid: [50, 500, 1000, 5000, 10000]. Prior probabilities over prices: [0.8070, 0.1070, 0.0434, 0.0223, 0.0203]. Valence (a boolean representing whether the speaker feels positively about the item) has a price-dependent prior: at price 50 the probability of positive valence is 0.3173; at 500 it is 0.7920; at 1000 it is 0.8933; at 5000 it is 0.9524; at 10000 it is 0.9864. Utterances are "expensive" and "notExpensive". Cost of "notExpensive" is 1; cost of "expensive" is 0. The threshold theta is drawn uniformly from the price grid. The QUD (question under discussion) is drawn with equal probability from three options: report only the price, report only the valence, or report both price and valence. Speaker optimality parameter alpha=1. Meaning: "expensive" is true when price >= theta; "notExpensive" is true when price <= theta.

model

A Rational Speech Act model combining vague-adjective interpretation with QUD uncertainty. A literal listener hears an utterance, samples a price uniformly from the price grid (not the categorical price prior) and a valence from the valence prior, conditions on the meaning of the utterance given the threshold theta, and returns the QUD-relevant part of the full state {price, valence}. The pragmatic listener uses the categorical price prior (the weights given above) when sampling the full state. A pragmatic speaker chooses utterances with probability proportional to exp(alpha * (literal-listener score on the QUD answer minus utterance cost)). A pragmatic listener reasons about the full state {price, valence} by marginalizing over the QUD and threshold, observing the speaker's choice of utterance.

query

The posterior joint distribution over {price, valence} for a pragmatic listener who hears the utterance "expensive".

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "price": "int",
      "valence": "bool"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// ADJECTIVES + QUD MODEL
2// frankie + shane RSA project
3
4// code adapted from the Kao et al. hyperbole model +
5// gradable adjectives & vagueness resolution model
6
7
8var utterances = ["expensive", "notExpensive"]
9
10var utterancePrior = function() { // uniform; utterance cost enters through the speaker's factor instead
11 return uniformDraw(utterances)
12}
13
14var thetaPrior = function() { // uniform over the price grid; `prices` is declared below but only read at call time
15 return uniformDraw(prices)
16}
17
18// theta moderates interpretation of utterances
19var meaning = function(utterance, price, theta) {
20return utterance == "expensive" ? price >= theta :
21 utterance == "notExpensive" ? price <= theta :
22 true // any other utterance is always true
23}
24
25// more words = higher cost
26var cost = function(utterance) {
27 return utterance== 'notExpensive'? 1 :
28 0
29};
30
31var prices = [
32 50,
33 500,
34 1000,
35 5000,
36 10000
37]
38
39var pricePrior = function() { // used by the pragmatic listener only; L0 samples prices uniformly
40 return categorical({
41 vs: prices,
42 ps: [
43 0.8070,
44 0.1070,
45 0.0434,
46 0.0223,
47 0.0203
48 ]
49 })
50}
51
52var valencePrior = function(state) { // `state` is a price from `prices`; returns a boolean
53 var probs = {
54 50 : 0.3173,
55 500 : 0.7920,
56 1000 : 0.8933,
57 5000 : 0.9524,
58 10000 : 0.9864
59 }
60 var tf = flip(probs[state])
61 return tf
62}
63
64var qudFns = { // each QUD projects the full state onto the part the speaker wants conveyed
65 price : function(state) {return { price: state.price } },
66 valence : function(state) {return { valence: state.valence } },
67 priceValence : function(state) {
68 return { price: state.price, valence: state.valence }
69 }
70 }
71
72var qudPrior = function() { // uniform over the three QUDs
73 categorical({ // NOTE(review): no explicit `return`; appears to rely on the function yielding its last expression — confirm, or add `return`
74 vs: ["price", "valence", "priceValence"],
75 ps: [1, 1, 1]
76 })
77}
78
79var literalListener = cache(function(utterance, qud, theta) { // L0: uniform price (not pricePrior), memoized per (utterance, qud, theta)
80 return Infer({model: function(){
81 var price = uniformDraw(prices)
82 var valence = valencePrior(price)
83 var fullState = {price, valence}
84 var qudFn = qudFns[qud]
85 var qudAnswer = qudFn(fullState)
86 condition( meaning(utterance, price, theta) )
87 return qudAnswer // L0 is a distribution over the QUD projection, not the full state
88 }
89 })})
90
91// speaker optimality
92var alpha = 1
93
94var speaker = cache(function(fullState, qud, theta) { // S1: P(u) ∝ exp(alpha * (log L0(qudAnswer | u) - cost(u)))
95 return Infer({model: function(){
96 var utterance = utterancePrior()
97 var qudFn = qudFns[qud]
98 var qudAnswer = qudFn(fullState)
99 factor(alpha*(literalListener(utterance,qud,theta).score(qudAnswer)
100 - cost(utterance)))
101 return utterance
102 }})
103})
104
105var pragmaticListener = cache(function(utterance) { // L1: posterior over {price, valence}; qud and theta are marginalized out
106 return Infer({model: function(){
107 //////// priors ////////
108 var price = pricePrior()
109 var valence = valencePrior(price)
110 var qud = qudPrior()
111 var theta = thetaPrior()
112 ////////////////////////
113 var fullState = {price, valence}
114 observe(speaker(fullState, qud, theta), utterance) // weights by S1(utterance | fullState, qud, theta)
115 return {price, valence}
116
117 }})
118})
119
120var listenerPosterior1 = pragmaticListener("expensive")
121var listenerPosterior2 = pragmaticListener('notExpensive')
122
123print('Pragmatic listener hears "expensive":')
124viz(listenerPosterior1)
125print('Pragmatic listener hears "not expensive":')
126viz(listenerPosterior2)
127
128var ANSWER = (pragmaticListener("expensive")); // pragmaticListener is cached, so this reuses listenerPosterior1
realization0.008
python
1# RSA adjectives + QUD model (frankie + shane). Pragmatic listener hears "expensive";
2# query is the joint posterior over {price, valence}.
3# Every RSA level is a genuine Pyro enumerated model (config_enumerate +
4# TraceEnum_ELBO.compute_marginals); each level's normalized posterior comes from Pyro's
5# engine, and lower-level log-scores feed the next level only through pyro.factor.
6#
7# State = (price, valence) with valence|price ~ Bernoulli(valence_p[price]).
8# meaning depends only on price, so under L0 valence is conditionally independent of the
9# evidence given price: the L0 joint factors as L0(price)*prior(valence|price), where the
10# L0 price posterior is produced by Pyro inference (compute_marginals) and valence|price is
11# its (un-inferred) prior. The speaker scores L0's score of the QUD projection of its
12# fullState; the pragmatic listener enumerates (price, valence, qud, theta) and factors in
13# the speaker's log-score of "expensive".
14
15utterances = ["expensive", "notExpensive"]
16prices = [50, 500, 1000, 5000, 10000]
17price_t = torch.tensor([float(p) for p in prices])
18price_prior = torch.tensor([0.8070, 0.1070, 0.0434, 0.0223, 0.0203])
19valence_p = torch.tensor([0.3173, 0.7920, 0.8933, 0.9524, 0.9864]) # P(valence=True | price)
20quds = ["price", "valence", "priceValence"]
21alpha = 1.0
22cost = {"expensive": 0.0, "notExpensive": 1.0}
23NEG = torch.tensor(float("-inf"))
24ZERO = torch.tensor(0.0)
25
26def meaning_mask(utterance, theta):  # boolean mask over the price grid
27 if utterance == "expensive":
28 return price_t >= theta
29 return price_t <= theta # notExpensive
30
31# Literal listener L0: a Pyro enumerated model over price (uniformDraw over prices) with the
32# meaning factor. compute_marginals returns the normalized price posterior (log-probs).
33def literal_price_logprobs(utterance, theta):  # length-len(prices) log-prob tensor, -inf where the meaning is false
34 mask = meaning_mask(utterance, theta)
35 @pyro.infer.config_enumerate
36 def m():
37 pi = pyro.sample("price", dist.Categorical(torch.ones(len(prices)) / len(prices)))
38 pyro.factor("meaning", torch.where(mask[pi], ZERO, NEG))
39 return None
40 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
41 pm = marg["price"]
42 sup = pm.enumerate_support()
43 lp = pm.log_prob(sup)
44 out = torch.full((len(prices),), float("-inf"))
45 for j in range(len(sup)):
46 out[int(sup[j].item())] = lp[j]
47 return out
48
49lit_cache = {}  # (utterance, rounded theta) -> L0 price log-probs
50def literal_price_logprobs_cached(utterance, theta):
51 key = (utterance, round(float(theta), 6))
52 if key not in lit_cache:
53 lit_cache[key] = literal_price_logprobs(utterance, theta)
54 return lit_cache[key]
55
56# L0 log-score of a QUD answer for the speaker's known fullState (price index pidx, valence v).
57# qud == price : score = log L0price[pidx]
58# qud == valence : score = log sum_p L0price[p] * P(v | p)
59# qud == priceValence: score = log( L0price[pidx] * P(v | pidx) )
60def l0_qud_logscore(utterance, theta, qud, pidx, v):  # returns a Python float; -inf when the answer is impossible under L0
61 lpr = literal_price_logprobs_cached(utterance, theta) # log P(price) under L0
62 pr = lpr.exp()
63 vp = valence_p if v else (1.0 - valence_p)
64 if qud == "price":
65 s = float(pr[pidx].item())
66 elif qud == "valence":
67 s = float((pr * vp).sum().item())
68 else: # priceValence
69 s = float((pr[pidx] * vp[pidx]).item())
70 return math.log(s) if s > 0.0 else float("-inf")
71
72# Speaker S1(utterance | fullState, qud, theta): Pyro enumerated model over utterances with
73# factor alpha*(logL0(qudAnswer|utt) - cost(utt)). compute_marginals normalizes.
74def speaker_logprobs(theta, qud, pidx, v):  # log P(utterance) vector indexed like `utterances`
75 score = torch.tensor([
76 alpha * (l0_qud_logscore(u, theta, qud, pidx, v) - cost[u]) for u in utterances
77 ])
78 @pyro.infer.config_enumerate
79 def m():
80 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))  # uniform prior; cost is inside `score`
81 pyro.factor("score", score[ui])
82 return None
83 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
84 um = marg["utt"]
85 sup = um.enumerate_support()
86 lp = um.log_prob(sup)
87 out = torch.full((len(utterances),), float("-inf"))
88 for j in range(len(sup)):
89 out[int(sup[j].item())] = lp[j]
90 return out
91
92target_idx = utterances.index("expensive")
93thetas = list(prices) # thetaPrior = uniformDraw(prices)
94
95# Speaker log-score of "expensive" indexed by (price index, valence in {0,1}, qud, theta).
96# Used as the pragmatic listener's factor.
97spk = {}  # (pidx, valence as 0/1, qud index, theta index) -> log S1("expensive" | ...)
98for pidx in range(len(prices)):
99 for v in (True, False):
100 for qi, qud in enumerate(quds):
101 for ti, theta in enumerate(thetas):
102 spk[(pidx, int(v), qi, ti)] = speaker_logprobs(theta, qud, pidx, v)[target_idx]
103
104spk_table = torch.zeros(len(prices), 2, len(quds), len(thetas))  # dense copy of spk, indexable by enumerated sample values
105for (pidx, vi, qi, ti), val in spk.items():
106 spk_table[pidx, vi, qi, ti] = val
107
108# Pragmatic listener L1: Pyro enumerated model over price, valence, qud, theta; factor in the
109# speaker's log-score of "expensive". infer_discrete draws joint (price, valence) posterior
110# samples (the query is a joint over two latents); aggregate the tuples into the distribution.
111@pyro.infer.config_enumerate
112def pragmatic_model():
113 pi = pyro.sample("price", dist.Categorical(price_prior))
114 vi = pyro.sample("valence", dist.Bernoulli(valence_p[pi]))
115 qi = pyro.sample("qud", dist.Categorical(torch.ones(len(quds)) / len(quds)))
116 ti = pyro.sample("theta", dist.Categorical(torch.ones(len(thetas)) / len(thetas)))
117 pyro.factor("speaker", spk_table[pi, vi.long(), qi, ti])
118 return pi, vi
119
120serving = pyro.infer.infer_discrete(
121 pyro.infer.config_enumerate(pragmatic_model), first_available_dim=-1
122)
123
124counts = Counter()
125N = 8000
126for _ in range(N):
127 pi, vi = serving()
128 p = prices[int(pi.item())]
129 v = bool(vi.item() > 0.5)
130 counts[(p, v)] += 1
131
132_json = __import__("json")
133ANSWER = {
134 _json.dumps({"price": p, "valence": v}, sort_keys=True): c / N
135 for (p, v), c in counts.items()
136}
137
02answer overlay — webppl vs pyrodist/finite
webppl pyro10 bins
Posterior P(price, valence | "expensive"), A / B = webppl / pyro (legend order), axis 0–0.30:
{"price":1000,"valence":false} A = 0.008 B = 0.008
{"price":1000,"valence":true} A = 0.090 B = 0.089
{"price":10000,"valence":false} A = 0.001 B = 0.001
{"price":10000,"valence":true} A = 0.068 B = 0.071
{"price":50,"valence":false} A = 0.300 B = 0.302
{"price":50,"valence":true} A = 0.296 B = 0.293
{"price":500,"valence":false} A = 0.025 B = 0.026
{"price":500,"valence":true} A = 0.149 B = 0.148
{"price":5000,"valence":false} A = 0.002 B = 0.003
{"price":5000,"valence":true} A = 0.061 B = 0.059
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0080 ≤ tol 0.0270 · floors 0.0135/0.0000
forestdb-2025-problang-teasing / atom-1
answer dist/int solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/2025-problang-teasing.md
given

Intelligence states: {1, 2, 3, 4} (1=least, 4=most intelligent). Utterances: {"", "dumb as rocks", "dumb", "f*cking idiot"}. Literal semantics: each utterance has a probability of describing each state. For state indices 1–4 (in order): "" → [0.25, 0.25, 0.25, 0.25]; "dumb as rocks" → [0.45, 0.85, 0.20, 0.02]; "dumb" → [0.85, 0.95, 0.02, 0.02]; "f*cking idiot" → [0.95, 0.55, 0.02, 0.02]. Valence ("good" or "bad") prior per state: state 1 → P(good)=0.01; state 2 → P(good)=0.33; state 3 → P(good)=0.66; state 4 → P(good)=0.99. Arousal ("low" or "high") prior per state: state 1 → P(low)=0.1; state 2 → P(low)=0.7; state 3 → P(low)=0.7; state 4 → P(low)=0.1. Goals: {"goalState", "goalValence", "goalArousal"}, equal prior weights. phi is drawn uniformly from {0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90} (18 equally spaced values). Speaker optimality alpha=10. Antisocial scaling lambda=-1.25. The state prior is uniform over {1, 2, 3, 4}.

model

A multi-goal RSA model for teasing language. A literal listener hears an utterance and a goal. The utterance is interpreted stochastically: it is true of a state with probability equal to its literal-semantics weight for that state. The listener returns both the goal-relevant QUD answer and the state. A pragmatic speaker has a mixed utility. The epistemic component is the literal listener's log-probability of recovering the correct QUD answer. The antisocial component is lambda times the expected state under the literal listener's posterior state distribution given the utterance (the value function is the state value itself — no rescaling or normalization). The speaker's total utility is phi times the epistemic component plus (1-phi) times the antisocial component. The speaker draws an utterance from a uniform utterance prior and chooses it with probability proportional to exp(alpha * total utility). A pragmatic listener samples a state, valence, arousal, goal, and phi, and conditions on the speaker's utterance choice.

query

The posterior marginal distribution over the intelligence state (integer values 1–4) for a pragmatic listener who hears "dumb as rocks".

answer spec dist/int
{
  "kind": "dist",
  "domain": "int"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Level of intelligence
2var states = [1,2,3,4]
3
4// What can be said
5var utterances = ["", "dumb as rocks", "dumb", "f*cking idiot"]
6
7// Correspondence of utterances to states
8var literalSemantics = {
9 "" : [.25,.25,.25,.25],
10 "dumb as rocks" : [0.45,0.85,0.20,.02],
11 "dumb" : [.85,.95,.02,.02],
12 "f*cking idiot" : [.95,.55,.02,.02]
13}
14
15// Determine whether the utterance describes the state
16// Flip a coin with the literalSemantics weight
17// *state - 1 because of 0-indexing*
18var meaning = function(utterance, state){
19 return flip(literalSemantics[utterance][state - 1]);
20}
21
22// Whether the speaker feels good or bad about the listener's action.
23// Speakers' attitudes toward the true state of the world (e.g., the valence of their affect), which is modeled simply as a binary positive/negative variable (representing whether or not the speaker is repulsed by the listener's behavior).
24var valencePrior = function(state) { // "good"/"bad"; NOTE(review): no explicit `return`, appears to rely on the last expression being returned
25 state === 1 ? flip(0.01) ? "good" : "bad" :
26 state === 2 ? flip(0.33) ? "good" : "bad" :
27 state === 3 ? flip(0.66) ? "good" : "bad" :
28 state === 4 ? flip(0.99) ? "good" : "bad" :
29 true} // fallback is unreachable for states 1-4
30
31// How amplified the speaker feels about the listener being dumb to different degrees.
32var arousals = ["low", "high"]
33
34// How passionate/aroused the speaker feels about how intelligent the listener is
35var arousalPrior = function(state) { // categorical(ps, vs): ps = [P("low"), P("high")]
36 state === 1 ? categorical([0.1, 0.9], arousals) :
37 state === 2 ? categorical([0.7, 0.3], arousals) :
38 state === 3 ? categorical([0.7, 0.3], arousals) :
39 state === 4 ? categorical([0.1, 0.9], arousals) :
40 true // unreachable for states 1-4
41}
42
43// A list of strings of QUD choices
44var goals = ["goalState", "goalValence", "goalArousal"]
45
46// There are 3 possible goals with a flat prior
47var goalPrior = function() {
48 categorical([1, 1, 1], goals)
49}
50
51// A speaker's goal is satisfied if the listener infers the correct
52// and relevant information.
53var goalState = function(goal, state, valence, arousal) { // returns the goal-relevant QUD answer
54 goal === "goalState" ? state :
55 goal === "goalValence" ? valence :
56 goal === "goalArousal" ? arousal :
57 true
58}
59
60// literal listener
61var literalListener = function(utterance, goal) { // not wrapped in cache: speaker1 re-runs this inference on every call
62 Infer({model: function(){
63 var state = uniformDraw(states)
64 var valence = valencePrior(state)
65 var arousal = arousalPrior(state)
66 var m = meaning(utterance, state);
67 condition(m);
68 return {QUDanswer: goalState(goal,state,valence,arousal), state: state} // state is also returned for the antisocial expectation
69 }})
70}
71
72// value function scales social utility by a parameter lambda
73var lambda = -1.25
74var valueFunction = function(s){ // lambda < 0: a lower expected state yields higher antisocial utility
75 return lambda * s
76};
77
78var alpha = 10
79var speaker1 = function(state, phi, goal, valence, arousal) { // S1: P(u) ∝ exp(alpha * (phi * epistemic + (1 - phi) * antisocial))
80 Infer({model: function(){
81 var utterance = uniformDraw(utterances)
82 var QUDanswer = goalState(goal,state,valence,arousal)
83 var L0_posterior = literalListener(utterance,goal)
84 var L0_stateMarginal = marginalize(L0_posterior,"state")
85 var L0_qudAnswerMarginal = marginalize(L0_posterior,"QUDanswer")
86 var utility = {
87 epistemic: L0_qudAnswerMarginal.score(QUDanswer),
88 antisocial: expectation(L0_stateMarginal, valueFunction)
89 }
90 var speakerUtility = phi * utility.epistemic +
91 (1 - phi) * utility.antisocial
92 factor(alpha * speakerUtility)
93 return utterance
94 }})
95}
96
97//pragmatic listener
98var pragmaticListener = function(utterance) {
99 Infer({model: function(){
100 var state = uniformDraw(states)
101 var valence = valencePrior(state)
102 var arousal = arousalPrior(state)
103 var goal = goalPrior()
104//0.05 corresponds to the starting point of phi while 0.95 is the (exclusive) endpoint, so the last value is 0.90.
105//The last 0.05 corresponds to the interval length
106 var phi = uniformDraw(_.range(0.05, 0.95, 0.05)) // 18 values, approximately 0.05, 0.10, ..., 0.90
107 var S1 = speaker1(state, phi, goal, valence, arousal)
108 observe(S1, utterance)
109 return { state, phi, goal, valence, arousal }
110 }})
111}
112viz.marginals(pragmaticListener(""))
113viz.marginals(pragmaticListener("dumb as rocks"))
114viz.marginals(pragmaticListener("dumb"))
115viz.marginals(pragmaticListener("f*cking idiot"))
116
117var ANSWER = marginalize(pragmaticListener("dumb as rocks"), "state"); // pragmaticListener is not cached, so this re-runs inference
realization0.000
python
1
2states = [1, 2, 3, 4]
3utterances = ["", "dumb as rocks", "dumb", "f*cking idiot"]
4literalSemantics = {
5 "": [0.25, 0.25, 0.25, 0.25],
6 "dumb as rocks": [0.45, 0.85, 0.20, 0.02],
7 "dumb": [0.85, 0.95, 0.02, 0.02],
8 "f*cking idiot": [0.95, 0.55, 0.02, 0.02],
9}
10valence_good = {1: 0.01, 2: 0.33, 3: 0.66, 4: 0.99}
11arousal_low = {1: 0.1, 2: 0.7, 3: 0.7, 4: 0.1}
12goals = ["goalState", "goalValence", "goalArousal"]
13phis = [round(0.05 + 0.05 * i, 10) for i in range(18)] # 0.05..0.90 step 0.05
14alpha = 10.0
15lam = -1.25
16valence_labels = ["good", "bad"]
17arousal_labels = ["low", "high"]
18
19def _elbo():  # fresh exact-enumeration ELBO (no plates)
20 return pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
21
22_LL_CACHE = {}  # (utterance, goal) -> compute_marginals result
23def literal_listener_marginals(utterance, goal):  # goal does not enter the model below; it only keys the cache
24 key = (utterance, goal)
25 if key in _LL_CACHE:
26 return _LL_CACHE[key]
27 sem = torch.tensor(literalSemantics[utterance])
28
29 @pyro.infer.config_enumerate
30 def model():
31 si = pyro.sample("state", dist.Categorical(torch.ones(4) / 4))
32 pg = torch.tensor([valence_good[s] for s in states])[si]
33 pyro.sample("valence", dist.Categorical(torch.stack([pg, 1 - pg], -1)))
34 pl = torch.tensor([arousal_low[s] for s in states])[si]
35 pyro.sample("arousal", dist.Categorical(torch.stack([pl, 1 - pl], -1)))
36 pyro.factor("meaning", torch.log(sem[si])) # condition on literal meaning
37 return si
38
39 marg = _elbo().compute_marginals(model, lambda: None)
40 _LL_CACHE[key] = marg
41 return marg
42
43def qud_answer(goal, state, valence, arousal):  # tagged answer so ll_logp_qud knows which L0 marginal to read
44 if goal == "goalState":
45 return ("state", state)
46 if goal == "goalValence":
47 return ("valence", valence)
48 return ("arousal", arousal)
49
50def ll_logp_qud(utterance, goal, ans):  # epistemic utility: L0 log-prob of the QUD answer
51 marg = literal_listener_marginals(utterance, goal)
52 kind, val = ans
53 if kind == "state":
54 d = marg["state"]; idx = states.index(val)
55 elif kind == "valence":
56 d = marg["valence"]; idx = valence_labels.index(val)
57 else:
58 d = marg["arousal"]; idx = arousal_labels.index(val)
59 return float(d.log_prob(torch.tensor(idx)))
60
61def ll_state_expectation(utterance, goal):  # antisocial utility: lam * E_L0[state] (lam is already applied here)
62 d = literal_listener_marginals(utterance, goal)["state"]
63 sup = d.enumerate_support()
64 p = torch.exp(d.log_prob(sup))
65 return float(sum(lam * states[int(i)] * float(pp) for i, pp in zip(sup, p)))
66
67_S1_CACHE = {}
68def speaker1(state, phi, goal, valence, arousal):  # S1 utterance marginal, memoized on the full argument tuple
69 key = (state, phi, goal, valence, arousal)
70 if key in _S1_CACHE:
71 return _S1_CACHE[key]
72 ans = qud_answer(goal, state, valence, arousal)
73 utils = torch.tensor([
74 phi * ll_logp_qud(u, goal, ans) + (1 - phi) * ll_state_expectation(u, goal)
75 for u in utterances
76 ])
77
78 @pyro.infer.config_enumerate
79 def model():
80 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))
81 pyro.factor("util", alpha * utils[ui])
82 return ui
83
84 marg = _elbo().compute_marginals(model, lambda: None)["utt"]
85 _S1_CACHE[key] = marg
86 return marg
87
88target = "dumb as rocks"
89target_idx = utterances.index(target)
90
91# enumerated joint config space for the pragmatic listener
92cfgs = []
93for st in states:
94 for val in valence_labels:
95 for ar in arousal_labels:
96 for goal in goals:
97 for phi in phis:
98 cfgs.append((st, val, ar, goal, phi))
99
100# prior over each config (valence/arousal priors depend on state)
101prior = torch.tensor([
102 (1.0 / len(states))
103 * (valence_good[c[0]] if c[1] == "good" else 1 - valence_good[c[0]])
104 * (arousal_low[c[0]] if c[2] == "low" else 1 - arousal_low[c[0]])
105 * (1.0 / len(goals))
106 * (1.0 / len(phis))
107 for c in cfgs
108])
109
110# speaker likelihood of the target utterance for each config (from S1 inference)
111lik = torch.tensor([
112 float(torch.exp(speaker1(c[0], c[4], c[3], c[1], c[2]).log_prob(torch.tensor(target_idx))))  # speaker1 args: (state, phi, goal, valence, arousal)
113 for c in cfgs
114])
115
116@pyro.infer.config_enumerate
117def pragmatic_model():
118 ci = pyro.sample("cfg", dist.Categorical(prior / prior.sum()))
119 pyro.factor("speaker", torch.log(lik[ci] + 1e-300)) # observe(S1, utterance); NOTE(review): lik is presumably float32 (torch default dtype), where 1e-300 rounds to 0, so this epsilon likely does not guard log(0)
120 return ci
121
122cfg_marg = _elbo().compute_marginals(pragmatic_model, lambda: None)["cfg"]  # exact posterior over config indices
123sup = cfg_marg.enumerate_support()
124p = torch.exp(cfg_marg.log_prob(sup))
125state_post = {s: 0.0 for s in states}
126for i, pp in zip(sup, p):
127 state_post[cfgs[int(i)][0]] += float(pp)
128
129ANSWER = state_post  # {state: posterior probability}
130
02answer overlay — webppl vs pyrodist/int
webppl pyro4 bins · 1 … 4
Posterior P(state | "dumb as rocks"), A / B = webppl / pyro (legend order), axis 0–0.49:
1 A = 0.046 B = 0.046
2 A = 0.325 B = 0.325
3 A = 0.495 B = 0.495
4 A = 0.134 B = 0.134
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (w1)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-adjectives-qud / atom-1
answer value/real solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/adjectives-qud.md
given

Ice cream price grid: [1, 4, 6, 10, 14, 18, 22, 30, 34, 38] dollars. Unnormalized prior weights over prices (normalized to sum to 1 before use): [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]. Utterances: {"expensive", "null", "cheap"}. Utterance costs: expensive=1, cheap=2, null=0. The utterance prior is proportional to exp(-cost). The threshold theta is drawn uniformly from the price grid. The QUD is fixed to "what is the price?" (the QUD always returns the raw price value). Speaker optimality parameter alpha=1. Meaning: "expensive" is true when price >= theta; "cheap" is true when price <= theta; "null" is always true.

model

A Rational Speech Act model for vague adjectives. A literal listener hears an utterance and a threshold theta. It samples a price from the price prior, conditions on the utterance meaning holding for that price, and returns the QUD-relevant value (here: the raw price). A pragmatic speaker draws an utterance from the cost-based utterance prior (proportional to exp(-cost)). It weights that utterance by exp(alpha * literal-listener log-score of the QUD answer), so each utterance's probability is proportional to exp(alpha * score - cost). A pragmatic listener samples a price from the price prior, a threshold from the uniform grid prior, and a QUD from the QUD prior (all mass on the price QUD). It observes the speaker's utterance choice and returns the joint {price, theta, qud}.

query

The posterior expected price (in dollars) for a pragmatic listener who hears "expensive".

answer spec value/real
{
  "kind": "value",
  "domain": "real"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1var marginalize = function(dist, key){ // local helper: marginal distribution of one field of a record-valued distribution
2 return Infer( {model: function(){
3 return sample(dist)[key];
4 }})
5};
6///
7
8var icecream = {
9 "prices": [1, 4, 6, 10, 14, 18, 22, 30, 34, 38],
10"probabilities": [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]
11};
12
13var statePrior = function() { // categorical(ps, vs) over the price grid
14 return categorical(icecream.probabilities, icecream.prices);
15};
16
17var thetaPrior = function() {
18 return uniformDraw(icecream.prices);
19};
20
21var alpha = 1; // optimality parameter
22
23var utterances = ["expensive", "null", "cheap"];
24var cost = {
25 "expensive": 1,
26 "cheap": 2,
27 "null": 0
28};
29var utterancePrior = function() { // P(u) ∝ exp(-cost[u])
30 var uttProbs = map(function(u) {return Math.exp(-cost[u]) }, utterances);
31 return categorical(uttProbs, utterances);
32};
33
34var meaning = function(utterance, price, theta) { // NOTE(review): no explicit `return`; appears to rely on the last expression being returned
35 utterance == "expensive" ? price >= theta :
36 utterance == "cheap" ? price <= theta :
37 true // "null" is always true
38};
39// QUDs
40var QUDs = ["expensive?","less than 15?","what is the price?"]
41var QUDPrior = function() {
42 //uniformDraw(QUDs)
43// categorical([1,10,1],QUDs) // NOTE: weights [1,10,1] are NOT uniform; uniformDraw(QUDs) corresponds to categorical([1,1,1],QUDs)
44categorical([0,0,1],QUDs)// this is your baseline version: all mass on "what is the price?"
45}
46var QUDFun = function(QUD,state) { // "what is the price?" falls through to the identity; also relies on the last expression being returned
47 QUD == "expensive?" ? state >= 15 :
48 QUD == "less than 15?" ? state <= 15 :
49 state;
50};
51
52var literalListener = cache(function(utterance, theta, QUD) { // L0: samples price from statePrior (not uniform)
53 return Infer({model:function() {
54 var price = statePrior()
55 var qPrice = QUDFun(QUD,price)
56 condition(meaning(utterance, price, theta))
57 return qPrice;
58}})
59});
60
61var speaker = cache(function(price, theta,QUD) { // S1: P(u) ∝ exp(-cost[u]) * exp(alpha * log L0(qPrice | u))
62 return Infer( {model: function() {
63 var utterance = utterancePrior()
64 var qPrice= QUDFun(QUD, price)
65 factor( alpha * literalListener(utterance, theta, QUD).score(qPrice) );
66 return utterance;
67 }});
68});
69
70var pragmaticListener = function(utterance) { // L1: joint posterior over {price, theta, qud}; not wrapped in cache
71 return Infer({model: function() {
72 var price = statePrior()
73 var theta = thetaPrior()
74 var QUD =QUDPrior()
75 factor(speaker(price, theta, QUD).score(utterance));
76 return { price: price, theta: theta, qud: QUD};
77 }})
78}
79
80
81
82print('expensive icecream') // label matches the utterance actually queried below
83var expensiveIcecream = pragmaticListener('expensive'); // L1 posterior over {price, theta, qud}
84print(expectation(marginalize(expensiveIcecream, "price"))) // posterior expected price
85viz.hist(marginalize(expensiveIcecream, "price"));
86viz.hist(marginalize(expensiveIcecream, "theta"));
87viz.hist(marginalize(expensiveIcecream, "qud"));
88
89var ANSWER = expectation(marginalize(expensiveIcecream, 'price')); // reuse the posterior: pragmaticListener is not cached, so calling it again would re-run Infer
90
realization0.000
python
1# RSA icecream adjectives + QUD model. Pragmatic listener hears "expensive".
2# QUDPrior puts all mass on "what is the price?" (categorical([0,0,1], QUDs)) so QUDFun is the
3# identity on the price -> the qud answer is the price itself.
4# Query: posterior expected price for the pragmatic listener.
5# Every RSA level is a genuine Pyro model run under exact enumeration
6# (config_enumerate + TraceEnum_ELBO.compute_marginals). Each level's normalized posterior
7# comes from Pyro's engine; lower-level log-scores enter the next level only as pyro.factor
8# terms. No level's distribution is normalized by hand.
9
10prices = [1, 4, 6, 10, 14, 18, 22, 30, 34, 38]
11price_probs = [0.01, 0.50, 0.85, 0.63, 0.35, 0.10, 0.04, 0.02, 0.02, 0.01]
12state_prior = torch.tensor(price_probs) # dist.Categorical normalizes its probs internally
13price_t = torch.tensor([float(p) for p in prices])
14alpha = 1.0
15utterances = ["expensive", "null", "cheap"]
16cost = {"expensive": 1.0, "cheap": 2.0, "null": 0.0}
17utt_prior = torch.tensor([math.exp(-cost[u]) for u in utterances]) # categorical normalizes
18NEG = torch.tensor(float("-inf"))
19ZERO = torch.tensor(0.0)
20
21# meaning(utterance, price, theta) as a boolean tensor over the enumerated price index.
22def meaning_mask(utterance, theta):
23 if utterance == "expensive":
24 return price_t >= theta
25 if utterance == "cheap":
26 return price_t <= theta
27 return torch.ones(len(prices), dtype=torch.bool)
28
29# Literal listener L0(price | utterance, theta): a Pyro enumerated model.
30# Returns the full log P(price) vector indexed by price index (from compute_marginals).
31def literal_logprobs(utterance, theta):
32 mask = meaning_mask(utterance, theta)
33 @pyro.infer.config_enumerate
34 def m():
35 si = pyro.sample("price", dist.Categorical(state_prior))
36 pyro.factor("meaning", torch.where(mask[si], ZERO, NEG))
37 return None
38 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
39 pm = marg["price"]
40 sup = pm.enumerate_support()
41 lp = pm.log_prob(sup)
42 out = torch.full((len(prices),), float("-inf"))
43 for j in range(len(sup)):
44 out[int(sup[j].item())] = lp[j]
45 return out
46
47lit_cache = {}
48def literal_logprobs_cached(utterance, theta):
49 key = (utterance, round(float(theta), 6))
50 if key not in lit_cache:
51 lit_cache[key] = literal_logprobs(utterance, theta)
52 return lit_cache[key]
53
def speaker_logprobs(price_index, theta):
    """Speaker S1(utterance | price, theta) by Pyro exact enumeration.

    Each utterance drawn from ``utt_prior`` is weighted by
    ``alpha * log L0(price_index | utterance, theta)``; no cost term is applied.

    Returns a tensor of length ``len(utterances)`` with log S1 per utterance
    index; indices absent from the marginal's support stay at -inf.
    """
    l0_at_price = []
    for utt in utterances:
        l0_at_price.append(float(literal_logprobs_cached(utt, theta)[price_index].item()))
    score = torch.tensor(l0_at_price)

    @pyro.infer.config_enumerate
    def model():
        choice = pyro.sample("utt", dist.Categorical(utt_prior))
        pyro.factor("score", alpha * score[choice])

    elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
    posterior = elbo.compute_marginals(model, lambda: None)["utt"]
    support = posterior.enumerate_support()
    log_p = posterior.log_prob(support)

    result = torch.full((len(utterances),), float("-inf"))
    for value, lp in zip(support, log_p):
        result[int(value.item())] = lp
    return result


spk_cache = {}


def speaker_logprobs_cached(price_index, theta):
    """Memoized ``speaker_logprobs`` keyed on (price_index, theta rounded to 6 dp)."""
    key = (price_index, round(float(theta), 6))
    if key in spk_cache:
        return spk_cache[key]
    spk_cache[key] = fresh = speaker_logprobs(price_index, theta)
    return fresh
82
# Pragmatic listener L1(price | "expensive"): enumerate price and threshold
# jointly, weight each pair by log S1("expensive" | price, theta), and report
# the posterior mean price as ANSWER.
target_idx = utterances.index("expensive")
thetas = list(prices)  # thetaPrior = uniformDraw(prices)

# spk_table[p, t] = log S1("expensive" | price index p, threshold thetas[t]).
spk_table = torch.zeros(len(prices), len(thetas))
for ti, theta in enumerate(thetas):
    for pi in range(len(prices)):
        spk_table[pi, ti] = speaker_logprobs_cached(pi, theta)[target_idx]


@pyro.infer.config_enumerate
def pragmatic_model():
    """Price index from ``state_prior``, uniform threshold index, speaker factor."""
    uniform_theta = torch.ones(len(thetas)) / len(thetas)
    pi = pyro.sample("price", dist.Categorical(state_prior))
    ti = pyro.sample("theta", dist.Categorical(uniform_theta))
    pyro.factor("speaker", spk_table[pi, ti])


marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
    pragmatic_model, lambda: None
)
pm = marg["price"]
sup = pm.enumerate_support()
probs = pm.log_prob(sup).exp()
ANSWER = float(
    sum(prices[int(idx.item())] * float(p.item()) for idx, p in zip(sup, probs))
)
106
02 answer · value/real
webppl
12.9027
pyro
12.9027
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (absdiff)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-astt-metaphor / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/astt-metaphor.md
given

Two possible categories for John: "whale" and "person". Prior probability of whale is 0.01; prior probability of person is 0.99. Utterances: {"whale", "person"}, with equal prior probability. Features of John: "large", "graceful", "majestic", each binary (0 or 1). Feature sets are the 8 combinations of these three binary features in the order: {1,1,1}, {1,1,0}, {1,0,1}, {1,0,0}, {0,1,1}, {0,1,0}, {0,0,1}, {0,0,0} (where fields are large, graceful, majestic). Feature-set prior probabilities given category whale (in the same order): [0.30592786494628, 0.138078454222818, 0.179114768847673, 0.13098781834847, 0.0947267162507846, 0.0531420411185539, 0.0601520520596695, 0.0378702842057509]. Feature-set prior probabilities given category person (in the same order): [0.11687632453038, 0.105787535267869, 0.11568145784997, 0.130847056136141, 0.15288225956497, 0.128098151176801, 0.114694702836614, 0.135132512637255]. Speaker goals: {"large", "graceful", "majestic"}, with equal prior probability. The goal selects the single feature the speaker aims to communicate. Speaker optimality parameter alpha=3.

model

A Rational Speech Act model for metaphor. A literal listener hears an utterance and a goal; it draws a category and a feature set from that category's prior, conditions on the category literally matching the utterance, and returns the value of the goal feature. A pragmatic speaker, given the three feature values and a goal, chooses an utterance with probability proportional to exp(alpha * the log-probability the literal listener assigns to the actual value of the goal feature). A pragmatic listener samples a category from the prior, samples a feature set from the category-conditional prior, samples a goal from the goal prior, and conditions on the speaker's utterance choice. The listener returns the joint {category, large, graceful, majestic}.

query

The posterior joint distribution over {category, large, graceful, majestic} for a pragmatic listener who hears the utterance "whale".

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "category": "string",
      "large": "int",
      "graceful": "int",
      "majestic": "int"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
// John is either a whale or a person; a priori he is almost surely a person.
var categories = ["whale", "person"]

var categoriesPrior = function() {
  return categorical([0.01, 0.99], categories)
}

// "John is a whale" vs. "John is a person", equally likely (equal cost).
var utterances = ["whale", "person"]

var utterancePrior = function() {
  return categorical([1,1], utterances)
}

// The eight combinations of the binary features large / graceful / majestic.
var featureSets = [
  {large : 1, graceful : 1, majestic : 1},
  {large : 1, graceful : 1, majestic : 0},
  {large : 1, graceful : 0, majestic : 1},
  {large : 1, graceful : 0, majestic : 0},
  {large : 0, graceful : 1, majestic : 1},
  {large : 0, graceful : 1, majestic : 0},
  {large : 0, graceful : 0, majestic : 1},
  {large : 0, graceful : 0, majestic : 0}
]

// Feature-set priors per category, elicited experimentally (see paper),
// listed in the same order as featureSets.
var whaleFeaturePrior = [0.30592786494628, 0.138078454222818,
                         0.179114768847673, 0.13098781834847,
                         0.0947267162507846, 0.0531420411185539,
                         0.0601520520596695, 0.0378702842057509]
var personFeaturePrior = [0.11687632453038, 0.105787535267869,
                          0.11568145784997, 0.130847056136141,
                          0.15288225956497, 0.128098151176801,
                          0.114694702836614, 0.135132512637255]

var featureSetPrior = function(category) {
  return (category === "whale") ? categorical(whaleFeaturePrior, featureSets) :
         (category === "person") ? categorical(personFeaturePrior, featureSets) :
         true
}

// The speaker wants to convey exactly one feature; goals are equally likely
// here (a context/QUD could reweight them).
var goals = ["large", "graceful", "majestic"]

var goalPrior = function() {
  return categorical([1,1,1], goals)
}

// Speaker optimality parameter.
var alpha = 3
57
// An utterance is literally true of a category only when the labels coincide.
var literalInterpretation = function(utterance, category) {
  return utterance === category
}

// Read off the feature the goal is about (0/1); unrecognized goals give true.
var goalState = function(goal, featureSet) {
  return goal === "large" ? featureSet.large :
    goal === "graceful" ? featureSet.graceful :
    goal === "majestic" ? featureSet.majestic :
    true
}

// Literal listener: distribution over the goal feature's value when the
// utterance is taken literally. Wrapped in `cache` because the pragmatic
// listener re-queries the same (utterance, goal) pairs in every branch; the
// returned distribution depends only on the arguments.
var literalListener = cache(function(utterance, goal) {
  return Infer({model: function() {
    var category = uniformDraw(categories)
    var featureSet = featureSetPrior(category)
    condition(literalInterpretation(utterance, category))
    return goalState(goal, featureSet)
  }})
})

// Speaker: picks an utterance in proportion to
// exp(alpha * log L0(actual goal-feature value | utterance, goal)).
// Cached for the same reason as the literal listener.
var speaker = cache(function(large, graceful, majestic, goal) {
  return Infer({model: function() {
    var utterance = utterancePrior()
    var target = goalState(goal, {large : large, graceful : graceful, majestic : majestic})
    factor(alpha * literalListener(utterance, goal).score(target))
    return utterance
  }})
})
90
// Pragmatic listener: infers John's category and features jointly,
// averaging over which goal the speaker had.
var pragmaticListener = function(utterance) {
  return Infer({model: function() {
    var category = categoriesPrior()
    var fs = featureSetPrior(category)
    var goal = goalPrior()
    observe(speaker(fs.large, fs.graceful, fs.majestic, goal), utterance)
    return {category: category, large: fs.large, graceful: fs.graceful, majestic: fs.majestic}
  }})
}

display("The pragmatic listener's interpretation when the speaker says whale")
viz.table(pragmaticListener("whale"))

display("The pragmatic listener's interpretation when the speaker says person")
viz.table(pragmaticListener("person"))

var ANSWER = (pragmaticListener("whale"));
realization0.000
python
1
categories = ["whale", "person"]
cat_prior = torch.tensor([0.01, 0.99])
# All 8 binary feature sets over (large, graceful, majestic), in the WebPPL
# order (1,1,1), (1,1,0), (1,0,1), ..., (0,0,0).
featureSets = [
    {"large": large, "graceful": graceful, "majestic": majestic}
    for large in (1, 0)
    for graceful in (1, 0)
    for majestic in (1, 0)
]
# Empirical feature-set priors per category, aligned with featureSets.
fs_probs = {
    "whale": torch.tensor([0.30592786494628, 0.138078454222818, 0.179114768847673,
                           0.13098781834847, 0.0947267162507846, 0.0531420411185539,
                           0.0601520520596695, 0.0378702842057509]),
    "person": torch.tensor([0.11687632453038, 0.105787535267869, 0.11568145784997,
                            0.130847056136141, 0.15288225956497, 0.128098151176801,
                            0.114694702836614, 0.135132512637255]),
}
fs_stack = torch.stack([fs_probs[name] for name in categories])  # (2, 8): whale, person
goals = ["large", "graceful", "majestic"]
goal_prior = torch.tensor([1.0, 1.0, 1.0])  # unnormalized; Categorical normalizes
utterances = ["whale", "person"]
alpha = 3.0
# Log-factors for hard conditioning; NEG stands in for log 0.
NEG = torch.tensor(-1e30)
ZERO = torch.tensor(0.0)


def goal_state(goal, fs_idx):
    """Value (0 or 1) of feature ``goal`` in ``featureSets[fs_idx]``."""
    chosen = featureSets[fs_idx]
    return chosen[goal]
33
34
def literal_listener(utterance, goal):
    """Literal listener: distribution over the value of feature ``goal``.

    Enumerates a uniformly drawn category and a feature set from that
    category's prior, keeps only the category named by ``utterance`` (the
    other gets log-factor NEG), then folds the feature-set marginal onto the
    goal feature's value. Returns ``{0: p0, 1: p1}``.
    """
    forced = categories.index(utterance)

    @pyro.infer.config_enumerate
    def model():
        cat = pyro.sample("cat", dist.Categorical(torch.ones(2)))
        pyro.sample("fs", dist.Categorical(fs_stack[cat]))
        # literalInterpretation: utterance === category
        pyro.factor("cond", torch.where(cat == forced, ZERO, NEG))

    fs_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["fs"]
    support = fs_marginal.enumerate_support()
    probs = fs_marginal.log_prob(support).exp()

    by_value = {0: 0.0, 1: 0.0}
    for fs_idx, p in zip(support.tolist(), probs.tolist()):
        by_value[goal_state(goal, int(fs_idx))] += p
    return by_value


def speaker_logprobs(goal_value, goal):
    """Speaker distribution over utterances for a given goal-feature value.

    Despite the name, returns *probabilities* ``{utterance: P(utterance)}``.
    Each utterance (uniform prior) is weighted by
    ``alpha * log L0(goal_value | utterance, goal)``, using -1e30 for log 0.
    """
    log_l0 = []
    for utt in utterances:
        p = literal_listener(utt, goal)[goal_value]
        log_l0.append(math.log(p) if p > 0 else -1e30)
    scores = torch.tensor(log_l0)

    @pyro.infer.config_enumerate
    def model():
        u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))
        pyro.factor("sc", alpha * scores[u])

    utt_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["u"]
    support = utt_marginal.enumerate_support()
    probs = utt_marginal.log_prob(support).exp()
    return {utterances[int(i)]: p for i, p in zip(support.tolist(), probs.tolist())}
74
75
def pragmatic_listener(utterance):
    """Pragmatic listener: posterior over (category, features) given ``utterance``.

    (category, feature set) is enumerated as one joint latent with prior
    ``cat_prior[c] * fs_stack[c][f]``; the goal is enumerated from
    ``goal_prior`` and marginalized out by the engine. Each (joint, goal) pair
    is scored by the speaker's log-probability of ``utterance`` (-1e30 for 0).
    Returns ``{json.dumps(record, sort_keys=True): probability}``.
    """
    speaker_by_goal_value = {}

    def speaker_log_score(fs_idx, goal_idx):
        # The speaker depends only on the goal and that goal's feature value.
        goal = goals[goal_idx]
        value = goal_state(goal, fs_idx)
        if (goal, value) not in speaker_by_goal_value:
            speaker_by_goal_value[(goal, value)] = speaker_logprobs(value, goal)
        p = speaker_by_goal_value[(goal, value)].get(utterance, 0.0)
        return math.log(p) if p > 0 else -1e30

    combos = [(c, f) for c in range(2) for f in range(8)]
    joint_prior = torch.tensor(
        [cat_prior[c].item() * fs_stack[c][f].item() for c, f in combos]
    )
    score_t = torch.tensor(
        [[speaker_log_score(f, g) for g in range(3)] for _, f in combos]
    )

    @pyro.infer.config_enumerate
    def model():
        cf = pyro.sample("cf", dist.Categorical(joint_prior))
        goal = pyro.sample("goal", dist.Categorical(goal_prior))
        pyro.factor("obs", score_t[cf, goal])

    cf_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["cf"]
    support = cf_marginal.enumerate_support()
    probs = cf_marginal.log_prob(support).exp()

    posterior = {}
    for j, p in zip(support.tolist(), probs.tolist()):
        c, f = combos[int(j)]
        features = featureSets[f]
        record = {"category": categories[c], "large": features["large"],
                  "graceful": features["graceful"], "majestic": features["majestic"]}
        posterior[json.dumps(record, sort_keys=True)] = p
    return posterior


ANSWER = pragmatic_listener("whale")
118
02 answer overlay — webppl vs pyro · dist/finite
webppl (A) · pyro (B) · 16 bins
{"category":"person","graceful":0,"large":0,"majestic":0}  A = 0.069  B = 0.069
{"category":"person","graceful":0,"large":0,"majestic":1}  A = 0.092  B = 0.092
{"category":"person","graceful":0,"large":1,"majestic":0}  A = 0.134  B = 0.134
{"category":"person","graceful":0,"large":1,"majestic":1}  A = 0.151  B = 0.151
{"category":"person","graceful":1,"large":0,"majestic":0}  A = 0.089  B = 0.089
{"category":"person","graceful":1,"large":0,"majestic":1}  A = 0.151  B = 0.151
{"category":"person","graceful":1,"large":1,"majestic":0}  A = 0.128  B = 0.128
{"category":"person","graceful":1,"large":1,"majestic":1}  A = 0.175  B = 0.175
{"category":"whale","graceful":0,"large":0,"majestic":0}  A = 0.000  B = 0.000
{"category":"whale","graceful":0,"large":0,"majestic":1}  A = 0.000  B = 0.000
{"category":"whale","graceful":0,"large":1,"majestic":0}  A = 0.001  B = 0.001
{"category":"whale","graceful":0,"large":1,"majestic":1}  A = 0.002  B = 0.002
{"category":"whale","graceful":1,"large":0,"majestic":0}  A = 0.000  B = 0.000
{"category":"whale","graceful":1,"large":0,"majestic":1}  A = 0.001  B = 0.001
{"category":"whale","graceful":1,"large":1,"majestic":0}  A = 0.002  B = 0.002
{"category":"whale","graceful":1,"large":1,"majestic":1}  A = 0.005  B = 0.005
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-blm / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/blm.md
given

State space: four combinations of two binary dimensions black and white — {black:true, white:true}, {black:true, white:false}, {black:false, white:true}, {black:false, white:false}. Prior over states is uniform (probability 0.25 each). Utterances: {"blm", "nblm"}, with equal prior probability. Literal meanings: "blm" is true when black is true; "nblm" is true when black is false. Speaker optimality parameter alpha=1.

model

A two-level RSA model for a political speech act. A literal listener hears an utterance, samples a state from the uniform prior, conditions on the utterance's literal meaning, and returns the state. A pragmatic speaker observes the true state and, starting from a uniform utterance prior, selects utterances with probability proportional to exp(alpha * the log-probability the literal listener assigns to the true state). A pragmatic listener samples a state from the uniform prior and conditions on the pragmatic speaker's utterance choice.

query

The posterior distribution over states {black, white} for a pragmatic listener who hears "blm".

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "black": "bool",
      "white": "bool"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
// Speaker optimality parameter.
var alpha = 1

// Uniform prior over the four (black, white) states. Alternatives explored
// by the author: ps [.9,.04,.04,.02] (all lives matter likely a priori) and
// ps [.04,.04,.9,.02] (only white lives matter likely a priori).
var statePrior = function() {
  var stateValues = [{black:true, white:true},{black:true, white:false},{black:false, white:true},{black:false, white:false}]
  return Categorical({ps: [.25,.25,.25,.25], vs: stateValues})
};

// Only "blm" and "nblm" are spoken (wlm/alm/nlm meanings exist below but are
// not in the utterance set).
var utterancePrior = function() {
  return uniformDraw(['blm', 'nblm'])
};

// Meaning function: each utterance is a predicate on a state.
var literalMeanings = {
  blm: function(state) { return state.black },
  nblm: function(state) { return !state.black },
  wlm: function(state) { return state.white },
  alm: function(state) { return state.black && state.white },
  nlm: function(state) { return !state.black && !state.white }
};
25
// L0: uniform-prior states consistent with the utterance's literal meaning.
var literalListener = cache(function(utt) {
  var meaning = literalMeanings[utt]
  return Infer({method:"enumerate"}, function(){
    var state = sample(statePrior())
    condition(meaning(state))
    return state
  })
});

// S1: soft-max (alpha) choice of utterance by how well L0 recovers the state.
var speaker = cache(function(state) {
  return Infer({method:"enumerate"}, function(){
    var utt = utterancePrior()
    var fit = literalListener(utt).score(state)
    factor(alpha * fit)
    return utt
  })
});

// L1: invert the speaker under the state prior.
var pragmaticListener = cache(function(utt) {
  return Infer({method:"enumerate"}, function(){
    var state = sample(statePrior())
    observe(speaker(state), utt)
    return state
  })
});

pragmaticListener("blm")

var ANSWER = (pragmaticListener('blm'));
realization0.000
python
# Two-level RSA. Every level is genuine Pyro exact enumeration inference.

alpha = 1.0
states = [
    {"black": True, "white": True},
    {"black": True, "white": False},
    {"black": False, "white": True},
    {"black": False, "white": False},
]
utterances = ["blm", "nblm"]

# Literal semantics, mirroring the WebPPL `literalMeanings` table. wlm/alm/nlm
# are defined for completeness; only blm/nblm appear in `utterances`.
_LITERAL_MEANINGS = {
    "blm": lambda state: state["black"],
    "nblm": lambda state: not state["black"],
    "wlm": lambda state: state["white"],
    "alm": lambda state: state["black"] and state["white"],
    "nlm": lambda state: not state["black"] and not state["white"],
}


def literal_meaning(utt, state):
    """Truth value of utterance ``utt`` in ``state`` (a dict with "black"/"white").

    Raises:
        KeyError: if ``utt`` has no defined meaning, instead of silently
            falling through to the "nblm" reading.
    """
    return _LITERAL_MEANINGS[utt](state)
18
def enum_marginal(model, n):
    """Exact marginal of site "x" in ``model`` as ``{index: prob}`` for 0..n-1.

    Indices outside the enumerated support keep probability 0.0.
    """
    posterior = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["x"]
    support = posterior.enumerate_support()
    probs = posterior.log_prob(support).exp()
    out = dict.fromkeys(range(n), 0.0)
    out.update(zip((int(s) for s in support.tolist()), probs.tolist()))
    return out


def literal_listener(utt):
    """L0 over state indices: uniform prior conditioned on ``literal_meaning``.

    States where the utterance is false get the log-factor log(0) = -inf.
    """
    indicator = torch.tensor([1.0 if literal_meaning(utt, s) else 0.0 for s in states])
    log_indicator = torch.log(indicator)

    @pyro.infer.config_enumerate
    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(len(states))))
        pyro.factor("ev", log_indicator[x])
        return x

    return enum_marginal(model, len(states))


def speaker(state_idx):
    """S1 over utterance indices: uniform prior times exp(alpha * log L0(state | utt)).

    A zero literal-listener probability becomes a -inf score.
    """
    log_fit = []
    for u in utterances:
        p = literal_listener(u).get(state_idx, 0.0)
        log_fit.append(math.log(p) if p > 0 else float("-inf"))
    log_fit = torch.tensor(log_fit)

    @pyro.infer.config_enumerate
    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(len(utterances))))
        pyro.factor("ev", alpha * log_fit[x])
        return x

    return enum_marginal(model, len(utterances))
63
64
def pragmatic_listener(utt):
    """L1 over state indices: uniform state prior times S1(utt | state)."""
    u_idx = utterances.index(utt)

    def log_speaker(i):
        p = speaker(i).get(u_idx, 0.0)
        return math.log(p) if p > 0 else float("-inf")

    scores = torch.tensor([log_speaker(i) for i in range(len(states))])

    @pyro.infer.config_enumerate
    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(len(states))))
        pyro.factor("ev", scores[x])
        return x

    return enum_marginal(model, len(states))


idx_dist = pragmatic_listener("blm")
ANSWER = {
    json.dumps(state, sort_keys=True): idx_dist.get(i, 0.0)
    for i, state in enumerate(states)
}
88
02 answer overlay — webppl vs pyro · dist/finite
webppl (A) · pyro (B) · 2 bins
{"black":true,"white":false}  A = 0.500  B = 0.500
{"black":true,"white":true}  A = 0.500  B = 0.500
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-ccgn-metaphor / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/ccgn-metaphor.md
given

Two possible categories: whale and person. The prior probability that someone is actually a whale is 0.01 (person: 0.99). Each entity can possess three binary features: large, graceful, majestic. Feature sets are all 8 binary combinations. Empirical feature-set priors by category: for whales, probabilities are [0.30592786494628, 0.138078454222818, 0.179114768847673, 0.13098781834847, 0.0947267162507846, 0.0531420411185539, 0.0601520520596695, 0.0378702842057509] over the feature sets ordered (large=1,graceful=1,majestic=1), (1,1,0), (1,0,1), (1,0,0), (0,1,1), (0,1,0), (0,0,1), (0,0,0); for persons, the corresponding probabilities are [0.11687632453038, 0.105787535267869, 0.11568145784997, 0.130847056136141, 0.15288225956497, 0.128098151176801, 0.114694702836614, 0.135132512637255]. The speaker has a goal prior that weights communicating 'large' with probability proportional to 5 and each of 'graceful' and 'majestic' with proportional weight 1. Two possible utterances — 'whale' and 'person' — with equal prior probability. Speaker optimality alpha = 3. Literal interpretation: utterance is true of a category if and only if the utterance label matches the category name. A goal is satisfied when the feature corresponding to that goal equals 1.

model

A literal listener conditions on the utterance matching the entity's category and returns whether the goal's corresponding feature is satisfied. A speaker who knows the entity's actual category and all three feature values chooses an utterance in proportion to exp(alpha times the literal listener's log-probability of the entity's actual goal-feature value — the log-probability that the feature equals 1 when the entity's goal feature is 1, and the log-probability that the feature equals 0 when it is 0). A pragmatic listener infers category and all three feature values jointly, marginalizing over the speaker's latent goal, by inverting the speaker model.

query

The posterior joint distribution over category and the three binary features (large, graceful, majestic) given that the pragmatic listener hears the utterance 'whale'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "category": "string",
      "large": "int",
      "graceful": "int",
      "majestic": "int"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
// John is either a whale or a person; a priori he is almost surely a person.
var categories = ["whale", "person"]

var categoriesPrior = function() {
  return categorical([0.01, 0.99], categories)
}

// "John is a whale" vs. "John is a person", equally likely (equal cost).
var utterances = ["whale", "person"]

var utterancePrior = function() {
  return categorical([1,1], utterances)
}

// The eight combinations of the binary features large / graceful / majestic.
var featureSets = [
  {large : 1, graceful : 1, majestic : 1},
  {large : 1, graceful : 1, majestic : 0},
  {large : 1, graceful : 0, majestic : 1},
  {large : 1, graceful : 0, majestic : 0},
  {large : 0, graceful : 1, majestic : 1},
  {large : 0, graceful : 1, majestic : 0},
  {large : 0, graceful : 0, majestic : 1},
  {large : 0, graceful : 0, majestic : 0}
]

// Feature-set priors per category, elicited experimentally (see paper),
// listed in the same order as featureSets.
var whaleFeaturePrior = [0.30592786494628, 0.138078454222818,
                         0.179114768847673, 0.13098781834847,
                         0.0947267162507846, 0.0531420411185539,
                         0.0601520520596695, 0.0378702842057509]
var personFeaturePrior = [0.11687632453038, 0.105787535267869,
                          0.11568145784997, 0.130847056136141,
                          0.15288225956497, 0.128098151176801,
                          0.114694702836614, 0.135132512637255]

var featureSetPrior = function(category) {
  return (category === "whale") ? categorical(whaleFeaturePrior, featureSets) :
         (category === "person") ? categorical(personFeaturePrior, featureSets) :
         true
}

// The speaker wants to convey exactly one feature.
var goals = ["large", "graceful", "majestic"]

// Goal prior biased toward communicating "large" (weights 5:1:1). Variants
// favouring "graceful" ([1,5,1]) or "majestic" ([1,1,5]) swap the weights.
var goalPrior = function() {
  return categorical([5,1,1], goals)
}

// Speaker optimality parameter.
var alpha = 3
63
// An utterance is literally true of a category only when the labels coincide.
var literalInterpretation = function(utterance, category) {
  return utterance === category
}

// Read off the feature the goal is about (0/1); unrecognized goals give true.
var goalState = function(goal, featureSet) {
  return goal === "large" ? featureSet.large :
    goal === "graceful" ? featureSet.graceful :
    goal === "majestic" ? featureSet.majestic :
    true
}

// Literal listener: distribution over the goal feature's value when the
// utterance is taken literally. Wrapped in `cache` because the pragmatic
// listener re-queries the same (utterance, goal) pairs in every branch; the
// returned distribution depends only on the arguments.
var literalListener = cache(function(utterance, goal) {
  return Infer({model: function() {
    var category = uniformDraw(categories)
    var featureSet = featureSetPrior(category)
    condition(literalInterpretation(utterance, category))
    return goalState(goal, featureSet)
  }})
})

// Speaker: picks an utterance in proportion to
// exp(alpha * log L0(actual goal-feature value | utterance, goal)).
// Cached for the same reason as the literal listener.
var speaker = cache(function(large, graceful, majestic, goal) {
  return Infer({model: function() {
    var utterance = utterancePrior()
    var target = goalState(goal, {large : large, graceful : graceful, majestic : majestic})
    factor(alpha * literalListener(utterance, goal).score(target))
    return utterance
  }})
})
96
// Pragmatic listener: infers John's category and features jointly,
// averaging over which goal the speaker had.
var pragmaticListener = function(utterance) {
  return Infer({model: function() {
    var category = categoriesPrior()
    var fs = featureSetPrior(category)
    var goal = goalPrior()
    observe(speaker(fs.large, fs.graceful, fs.majestic, goal), utterance)
    return {category: category, large: fs.large, graceful: fs.graceful, majestic: fs.majestic}
  }})
}

viz.table(pragmaticListener("whale"))

var ANSWER = (pragmaticListener("whale"));
realization0.000
python
1
categories = ["whale", "person"]
cat_prior = torch.tensor([0.01, 0.99])
# All 8 binary feature sets over (large, graceful, majestic), in the WebPPL
# order (1,1,1), (1,1,0), (1,0,1), ..., (0,0,0).
featureSets = [
    {"large": large, "graceful": graceful, "majestic": majestic}
    for large in (1, 0)
    for graceful in (1, 0)
    for majestic in (1, 0)
]
# Empirical feature-set priors per category, aligned with featureSets.
fs_probs = {
    "whale": torch.tensor([0.30592786494628, 0.138078454222818, 0.179114768847673,
                           0.13098781834847, 0.0947267162507846, 0.0531420411185539,
                           0.0601520520596695, 0.0378702842057509]),
    "person": torch.tensor([0.11687632453038, 0.105787535267869, 0.11568145784997,
                            0.130847056136141, 0.15288225956497, 0.128098151176801,
                            0.114694702836614, 0.135132512637255]),
}
fs_stack = torch.stack([fs_probs[name] for name in categories])  # (2, 8): whale, person
goals = ["large", "graceful", "majestic"]
goal_prior = torch.tensor([5.0, 1.0, 1.0])  # 5:1:1 toward "large"; Categorical normalizes
utterances = ["whale", "person"]
alpha = 3.0
# Log-factors for hard conditioning; NEG stands in for log 0.
NEG = torch.tensor(-1e30)
ZERO = torch.tensor(0.0)


def goal_state(goal, fs_idx):
    """Value (0 or 1) of feature ``goal`` in ``featureSets[fs_idx]``."""
    chosen = featureSets[fs_idx]
    return chosen[goal]
33
34
def literal_listener(utterance, goal):
    """Literal listener: distribution over the value of feature ``goal``.

    Enumerates a uniformly drawn category and a feature set from that
    category's prior, keeps only the category named by ``utterance`` (the
    other gets log-factor NEG), then folds the feature-set marginal onto the
    goal feature's value. Returns ``{0: p0, 1: p1}``.
    """
    forced = categories.index(utterance)

    @pyro.infer.config_enumerate
    def model():
        cat = pyro.sample("cat", dist.Categorical(torch.ones(2)))
        pyro.sample("fs", dist.Categorical(fs_stack[cat]))
        # literalInterpretation: utterance === category
        pyro.factor("cond", torch.where(cat == forced, ZERO, NEG))

    fs_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["fs"]
    support = fs_marginal.enumerate_support()
    probs = fs_marginal.log_prob(support).exp()

    by_value = {0: 0.0, 1: 0.0}
    for fs_idx, p in zip(support.tolist(), probs.tolist()):
        by_value[goal_state(goal, int(fs_idx))] += p
    return by_value


def speaker_logprobs(goal_value, goal):
    """Speaker distribution over utterances for a given goal-feature value.

    Despite the name, returns *probabilities* ``{utterance: P(utterance)}``.
    Each utterance (uniform prior) is weighted by
    ``alpha * log L0(goal_value | utterance, goal)``, using -1e30 for log 0.
    """
    log_l0 = []
    for utt in utterances:
        p = literal_listener(utt, goal)[goal_value]
        log_l0.append(math.log(p) if p > 0 else -1e30)
    scores = torch.tensor(log_l0)

    @pyro.infer.config_enumerate
    def model():
        u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))
        pyro.factor("sc", alpha * scores[u])

    utt_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["u"]
    support = utt_marginal.enumerate_support()
    probs = utt_marginal.log_prob(support).exp()
    return {utterances[int(i)]: p for i, p in zip(support.tolist(), probs.tolist())}
74
75
def pragmatic_listener(utterance):
    """Pragmatic listener: posterior over (category, features) given ``utterance``.

    (category, feature set) is enumerated as one joint latent with prior
    ``cat_prior[c] * fs_stack[c][f]``; the goal is enumerated from
    ``goal_prior`` and marginalized out by the engine. Each (joint, goal) pair
    is scored by the speaker's log-probability of ``utterance`` (-1e30 for 0).
    Returns ``{json.dumps(record, sort_keys=True): probability}``.
    """
    speaker_by_goal_value = {}

    def speaker_log_score(fs_idx, goal_idx):
        # The speaker depends only on the goal and that goal's feature value.
        goal = goals[goal_idx]
        value = goal_state(goal, fs_idx)
        if (goal, value) not in speaker_by_goal_value:
            speaker_by_goal_value[(goal, value)] = speaker_logprobs(value, goal)
        p = speaker_by_goal_value[(goal, value)].get(utterance, 0.0)
        return math.log(p) if p > 0 else -1e30

    combos = [(c, f) for c in range(2) for f in range(8)]
    joint_prior = torch.tensor(
        [cat_prior[c].item() * fs_stack[c][f].item() for c, f in combos]
    )
    score_t = torch.tensor(
        [[speaker_log_score(f, g) for g in range(3)] for _, f in combos]
    )

    @pyro.infer.config_enumerate
    def model():
        cf = pyro.sample("cf", dist.Categorical(joint_prior))
        goal = pyro.sample("goal", dist.Categorical(goal_prior))
        pyro.factor("obs", score_t[cf, goal])

    cf_marginal = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
        model, lambda: None
    )["cf"]
    support = cf_marginal.enumerate_support()
    probs = cf_marginal.log_prob(support).exp()

    posterior = {}
    for j, p in zip(support.tolist(), probs.tolist()):
        c, f = combos[int(j)]
        features = featureSets[f]
        record = {"category": categories[c], "large": features["large"],
                  "graceful": features["graceful"], "majestic": features["majestic"]}
        posterior[json.dumps(record, sort_keys=True)] = p
    return posterior


ANSWER = pragmatic_listener("whale")
118
02answer overlay — webppl vs pyrodist/finite
webppl pyro16 bins
00.100.100.200.20{"category":"person","graceful":0,"large":0,"majestic":0} A = 0.047 B = 0.047{"category":"person","graceful":0,"large":0,"majestic":0} A = 0.047 B = 0.047{"category":"person","graceful":0,"large":0,"majestic":0}{"category":"person","graceful":0,"large":0,"majestic":1} A = 0.054 B = 0.054{"category":"person","graceful":0,"large":0,"majestic":1} A = 0.054 B = 0.054{"category":"person","graceful":0,"large":1,"majestic":0} A = 0.194 B = 0.194{"category":"person","graceful":0,"large":1,"majestic":0} A = 0.194 B = 0.194{"category":"person","graceful":0,"large":1,"majestic":0}{"category":"person","graceful":0,"large":1,"majestic":1} A = 0.187 B = 0.187{"category":"person","graceful":0,"large":1,"majestic":1} A = 0.187 B = 0.187{"category":"person","graceful":1,"large":0,"majestic":0} A = 0.055 B = 0.055{"category":"person","graceful":1,"large":0,"majestic":0} A = 0.055 B = 0.055{"category":"person","graceful":1,"large":0,"majestic":0}{"category":"person","graceful":1,"large":0,"majestic":1} A = 0.085 B = 0.085{"category":"person","graceful":1,"large":0,"majestic":1} A = 0.085 B = 0.085{"category":"person","graceful":1,"large":1,"majestic":0} A = 0.166 B = 0.166{"category":"person","graceful":1,"large":1,"majestic":0} A = 0.166 B = 0.166{"category":"person","graceful":1,"large":1,"majestic":0}{"category":"person","graceful":1,"large":1,"majestic":1} A = 0.198 B = 0.198{"category":"person","graceful":1,"large":1,"majestic":1} A = 0.198 B = 0.198{"category":"whale","graceful":0,"large":0,"majestic":0} A = 0.000 B = 0.000{"category":"whale","graceful":0,"large":0,"majestic":0} A = 0.000 B = 0.000{"category":"whale","graceful":0,"large":0,"majestic":0}{"category":"whale","graceful":0,"large":0,"majestic":1} A = 0.000 B = 0.000{"category":"whale","graceful":0,"large":0,"majestic":1} A = 0.000 B = 0.000{"category":"whale","graceful":0,"large":1,"majestic":0} A = 0.002 B = 0.002{"category":"whale","graceful":0,"large":1,"majestic":0} A = 0.002 B = 
0.002{"category":"whale","graceful":0,"large":1,"majestic":0}{"category":"whale","graceful":0,"large":1,"majestic":1} A = 0.003 B = 0.003{"category":"whale","graceful":0,"large":1,"majestic":1} A = 0.003 B = 0.003{"category":"whale","graceful":1,"large":0,"majestic":0} A = 0.000 B = 0.000{"category":"whale","graceful":1,"large":0,"majestic":0} A = 0.000 B = 0.000{"category":"whale","graceful":1,"large":0,"majestic":0}{"category":"whale","graceful":1,"large":0,"majestic":1} A = 0.001 B = 0.001{"category":"whale","graceful":1,"large":0,"majestic":1} A = 0.001 B = 0.001{"category":"whale","graceful":1,"large":1,"majestic":0} A = 0.002 B = 0.002{"category":"whale","graceful":1,"large":1,"majestic":0} A = 0.002 B = 0.002{"category":"whale","graceful":1,"large":1,"majestic":0}{"category":"whale","graceful":1,"large":1,"majestic":1} A = 0.005 B = 0.005{"category":"whale","graceful":1,"large":1,"majestic":1} A = 0.005 B = 0.005
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 1/2 solvers · d=[0.042, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-cnqr-comparison-class / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/cnqr-comparison-class.md
given

Height distributions are discretized using binParam = 3. The superordinate category (all people) has height distribution Gaussian(mu=0, sigma=1). The basketball player subcategory has Gaussian(mu=1, sigma=0.5). State values are 18 evenly spaced points from -3 to 3 (exclusive) in steps of 1/3. State probabilities are proportional to the Gaussian PDF evaluated at each state value. Thresholds for 'tall' are each state value minus 1/(2*binParam); thresholds for 'short' are each state value plus 1/(2*binParam); each threshold set is drawn uniformly. Comparison class is drawn uniformly from {superordinate, subordinate}. Speaker optimality alpha = 5. Three utterances: 'tall', 'short', 'silence' (silence is always true).

model

A literal listener conditions on the utterance being true (tall: state exceeds the tall threshold; short: state is below the short threshold; silence: always true) given the state distribution for the current comparison class. A speaker chooses utterances proportional to exp(alpha times the literal listener's log-probability of the state). A pragmatic listener jointly infers the entity's state, the comparison class, and the threshold pair (one tall-threshold drawn uniformly from the tall set, one short-threshold drawn uniformly from the short set) by inverting the speaker, conditioning on the utterance. The same threshold pair is passed to the speaker and through to the literal listener.

query

The marginal posterior distribution over the comparison class (superordinate vs. subordinate) for a pragmatic listener who hears 'tall' about a basketball player.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "superordinate",
    "subordinate"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1///fold:
2// helper function
3var exp = function(x){return Math.exp(x)}
4
5// for discretization
6var binParam = 3;
7
8// information about the superordinate category prior
9// e.g., the height distribution for all people
10var superordinate_params = {mu: 0, sigma: 1};
11
12// calculate the range in pre-defined steps;
13// these values correspond to possible heights
14var stateVals = _.range(superordinate_params.mu - 3 * superordinate_params.sigma,
15 superordinate_params.mu + 3 * superordinate_params.sigma,
16 superordinate_params.sigma/binParam)
17
18// for each possible height, calculate its probability of occurrence
19var stateProbs = cache(function(stateParams){
20 return map(function(s){
21 exp(Gaussian(stateParams).score(s))
22 }, stateVals)
23});
24
25// generate a statePrior using the possible heights and their probabilities
26var generateStatePrior = cache(function(stateParams) {
27 return Infer({
28 model: function(){
29 return categorical({vs: stateVals, ps: stateProbs(stateParams)})
30 }
31 })
32});
33
34// generate the uniform threshold prior
35var thresholdBins ={
36 positive: map(function(x){
37 return x - (1/(binParam*2));
38 }, sort(stateVals)),
39 negative: map(function(x){
40 return x + (1/(binParam*2));
41 }, sort(stateVals))
42};
43
44var thresholdPrior = cache(function(form){
45 return Infer({
46 model: function() { return uniformDraw(thresholdBins[form]) }
47 });
48});
49
50// information about the superordinate category priors
51var subParams = {
52 gymnasts: {mu: -1, sigma: 0.5}, // gymnast heights
53 soccerPlayers: {mu: 0, sigma: 0.5}, // soccer player heights
54 basketballPlayers: {mu: 1, sigma: 0.5} // basketball player heights
55}
56
57// possible utterances can be either positive (tall) or negative (short) or a null utterance
58var utterances = ["tall", "short", "silence"]
59
60// meaning function for utterances
61var meaning = function(utterance, state, thresholds) {
62 utterance == "tall" ? state > thresholds.tall :
63 utterance == "short" ? state < thresholds.short :
64 true
65}
66
67// assume a uniform prior over comparison classes
68var classPrior = Infer({
69 model: function(){return uniformDraw(["subordinate", "superordinate"])}
70});
71
72// set speaker optimality
73var alpha = 5;
74
75var literalListener = cache(
76 function(utterance, thresholds, comparisonClass) {
77 Infer({model: function(){
78 var StatePrior = generateStatePrior(comparisonClass)
79 var state = sample(StatePrior);
80 var m = meaning(utterance, state, thresholds);
81 condition(m);
82 return state;
83 }})
84 }, 10000 // limit cache size
85)
86
87var speaker1 = cache(
88 function(state, thresholds, comparisonClass) {
89 Infer({model: function(){
90 var utterance = uniformDraw(utterances);
91 var L0 = literalListener(utterance, thresholds, comparisonClass);
92 factor( alpha * L0.score(state) );
93 return utterance;
94 }})
95 }, 10000 // limit cache size
96)
97///
98
99var pragmaticListener = cache(function(utterance, subordinate_params) {
100 Infer({model: function(){
101
102 var statePrior = generateStatePrior(subordinate_params);
103 var state = sample(statePrior);
104 // separate thresholds for positive adjective and negative adjective
105 var thresholds = {
106 tall: sample(thresholdPrior("positive")),
107 short: sample(thresholdPrior("negative"))
108 }
109
110 // uncertainty about the comparison class (superordinate vs. subordinate)
111 var c = sample(classPrior)
112 var comparisonClass = c == "subordinate" ? subordinate_params : superordinate_params
113
114 var S1 = speaker1(state, thresholds, comparisonClass);
115 observe(S1, utterance);
116
117 return { comparisonClass: c, state : state }
118 }})
119}, 10000 // limit cache size
120 )
121
122var ANSWER = (marginalize(pragmaticListener("tall", subParams["basketballPlayers"]), "comparisonClass"));
realization0.000
python
1exp = math.exp
2binParam = 3
3super_mu, super_sigma = 0.0, 1.0
4
5# 18 evenly spaced state values from -3 to 3 (exclusive) in steps of 1/3
6step = super_sigma / binParam
7stateVals = []
8x = super_mu - 3 * super_sigma
9while x < super_mu + 3 * super_sigma - 1e-9:
10 stateVals.append(x)
11 x += step
12
13subParams = {"mu": 1.0, "sigma": 0.5} # basketball players
14sorted_states = sorted(stateVals)
15threshold_positive = [s - 1.0 / (binParam * 2) for s in sorted_states]
16threshold_negative = [s + 1.0 / (binParam * 2) for s in sorted_states]
17utterances = ["tall", "short", "silence"]
18alpha = 5.0
19
20def meaning(utt, state, t_tall, t_short):
21 if utt == "tall":
22 return state > t_tall
23 if utt == "short":
24 return state < t_short
25 return True
26
27# Normalized state prior (categorical over stateVals weighted by the Gaussian pdf).
28_state_dist_cache = {}
29def state_dist(mu, sigma):
30 key = (mu, sigma)
31 if key in _state_dist_cache:
32 return _state_dist_cache[key]
33 d = dist.Normal(torch.tensor(float(mu)), torch.tensor(float(sigma)))
34 ps = [exp(d.log_prob(torch.tensor(s)).item()) for s in stateVals]
35 z = sum(ps)
36 out = (stateVals, torch.tensor([p / z for p in ps]))
37 _state_dist_cache[key] = out
38 return out
39
40# Literal listener: condition on the utterance being true under the comparison-class
41# state prior. Exact enumeration via Pyro's compute_marginals.
42_L0_cache = {}
43def literalListener(utt, t_tall, t_short, cc_mu, cc_sigma):
44 key = (utt, t_tall, t_short, cc_mu, cc_sigma)
45 if key in _L0_cache:
46 return _L0_cache[key]
47 svals, pri = state_dist(cc_mu, cc_sigma)
48 _cond = torch.tensor([
49 0.0 if meaning(utt, svals[i], t_tall, t_short) else float("-inf")
50 for i in range(len(svals))
51 ])
52 @pyro.infer.config_enumerate
53 def m():
54 idx = pyro.sample("state", dist.Categorical(pri))
55 pyro.factor("cond", _cond[idx])
56 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
57 _L0_cache[key] = (marg["state"], svals)
58 return _L0_cache[key]
59
60# Speaker: utterance ~ exp(alpha * L0.score(state)), inverting the literal listener.
61_S1_cache = {}
62def speaker1(state, t_tall, t_short, cc_mu, cc_sigma):
63 key = (state, t_tall, t_short, cc_mu, cc_sigma)
64 if key in _S1_cache:
65 return _S1_cache[key]
66 _sc = []
67 for utt in utterances:
68 marg, svals = literalListener(utt, t_tall, t_short, cc_mu, cc_sigma)
69 if state in svals:
70 lp = marg.log_prob(torch.tensor(svals.index(state))).item()
71 else:
72 lp = float("-inf")
73 _sc.append(alpha * lp)
74 _sc = torch.tensor(_sc)
75 @pyro.infer.config_enumerate
76 def m():
77 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances))))
78 pyro.factor("util", _sc[uidx])
79 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
80 _S1_cache[key] = marg["utt"]
81 return _S1_cache[key]
82
83# Pragmatic listener: jointly sample (class, state, tall-threshold, short-threshold),
84# observe the heard utterance by factoring the speaker's log-score, and marginalize.
85# The entity's state prior is the subordinate (basketball) prior; the comparison
86# class selects which prior the speaker/literal-listener use internally. Inference is
87# exact enumeration over the full joint via Pyro's compute_marginals.
88utterance_heard = "tall"
89_uidx = utterances.index(utterance_heard)
90classes = ["subordinate", "superordinate"]
91_svals, _state_pri = state_dist(subParams["mu"], subParams["sigma"])
92
93# Speaker log-score for the heard utterance across every joint assignment.
94_obs = torch.full(
95 (len(classes), len(_svals), len(threshold_positive), len(threshold_negative)),
96 float("-inf"),
97)
98for ci, c in enumerate(classes):
99 cc_mu, cc_sigma = (subParams["mu"], subParams["sigma"]) if c == "subordinate" else (super_mu, super_sigma)
100 for si, state in enumerate(_svals):
101 for ti, t_tall in enumerate(threshold_positive):
102 for tj, t_short in enumerate(threshold_negative):
103 S1 = speaker1(state, t_tall, t_short, cc_mu, cc_sigma)
104 _obs[ci, si, ti, tj] = S1.log_prob(torch.tensor(_uidx)).item()
105
106@pyro.infer.config_enumerate
107def pragmaticListener():
108 c = pyro.sample("c", dist.Categorical(torch.ones(len(classes))))
109 s = pyro.sample("state", dist.Categorical(_state_pri))
110 tt = pyro.sample("tt", dist.Categorical(torch.ones(len(threshold_positive))))
111 ts = pyro.sample("ts", dist.Categorical(torch.ones(len(threshold_negative))))
112 pyro.factor("obs", _obs[c, s, tt, ts])
113
114_marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(pragmaticListener, lambda: None)
115_cm = _marg["c"]
116ANSWER = {classes[i]: _cm.probs[i].item() for i in range(len(classes))}
117
02answer overlay — webppl vs pyrodist/finite
webppl pyro2 bins
00.290.290.580.58subordinate A = 0.419 B = 0.419subordinate A = 0.419 B = 0.4190.420.42subordinatesuperordinate A = 0.581 B = 0.581superordinate A = 0.581 B = 0.5810.580.58superordinate
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 1/2 solvers · d=[0.000, 0.115] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-codenames / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/codenames.md
given

Six words and their 25-dimensional GloVe vectors: eagle: [-0.8186906894583743, -0.8443627918594182, -0.04304780086785447, -0.8257634263841377, -0.7607218950809542, 0.47786735164930183, 0.36942709316422206, 0.18560148725224498, 0.38625176009619944, 0.24384273963053932, 1.0355862286322068, -0.14170089242313555, -0.17017960843359828, 0.27636172471279313, -0.49477465481497807, -1.199206930890509, 1.0531720839078256, -0.5154875303291531, 0.30704269353337016, 1.5382356443196483, -0.13215425501400774, 1.2222507503066664, 1.3819617662995949, -1.1579407453927437, 0.9439311306043343] pig: [-0.9027808771549458, -1.4539105978263833, 0.5743098399154295, 1.3052815987119957, -0.038556210348244704, -0.22144102997326148, 1.222050088622139, -0.027526643946408857, -0.13265827668708097, 1.4799207507145387, 0.02371336629548181, -0.9405402658175948, 0.06556493358788004, -1.6556208133885402, -0.44306373689318584, -0.475710035110901, 1.2435716830499404, -1.0677780309283533, -0.03344465447168945, 0.16184568683816827, 0.8718035460475897, 2.082956682688621, 0.47430271385843953, -0.4479993650378608, 1.5192928553678355] chicken: [-0.9555633717916903, -0.23467550895948608, 0.9081102168032618, 1.7681919864431317, -0.3888166871286516, 1.2292398003323308, 1.0624961440319318, -0.3558803892040966, -0.17024423658814317, 0.7046776782991592, 1.624196256505183, -1.1423231844008523, -0.9490267652945451, -1.8004114037674281, -0.026086280368055388, -2.089757256612839, 3.5660566372328693, -2.4178611093952225, -0.7077960662621875, 0.9418434965990246, 0.438927172322575, 1.0891725023940724, -0.1237861204326181, 0.7602054634506068, 3.0515580696224083] farm: [-1.3191469349030631, -0.34747873883058705, 0.09525267994894762, 0.08014872654330456, 0.1179814806966339, -0.26926061020753783, 0.709033965954239, -0.6521777385143812, 1.0195239553589313, 0.7192612109870958, 1.1711460976059695, -1.0779079866249233, -0.5443503049555966, 0.08523251153754875, 0.1455530206584687, -1.501097375488643, 1.1151234505440395, 
0.0581591541412683, -0.1102242123027589, 0.5253857581277014, 0.21780949510893402, 0.026030837039037417, 0.07282095318396448, -0.6093002665622598, 2.0466066458317336] animal: [-0.08132595025854601, -1.8280616716238214, -0.4241550049238374, 1.3405833261217683, 1.3635302219051426, 0.19656106954281044, 1.0553637657141577, 0.8640316722860499, -0.34682275265131135, 0.27196141799987644, 0.9785603157742483, -3.1767493003780873, -0.7566904249011203, -1.1935303767007424, 0.2523177522167622, 0.33414675815038736, 0.21147820292953767, 0.2089073521353749, 0.36413859545070626, -0.3145854077725169, 0.8470589609352164, 0.8914477422324714, 0.06602846837066885, -1.1974184866543685, 1.6807019645814638] bird: [0.9317828273959833, -1.142927450658389, 1.1249556341704339, 0.7533022372085103, 0.039221572709652965, 0.5302815428039684, 1.1525754405638204, 0.5707370610821617, 0.01803607760778035, 0.9527229145321762, 1.0851468114908822, -0.4626041548552341, -0.5371489443168416, -0.8343285842461913, -0.09713481034287788, 0.8070233789520264, 0.21755780815430825, -0.6588132708557186, -0.7963193188039507, 0.12395864485237663, -0.18545774404118467, 1.311026289715281, 0.7764007851264465, -0.5776179488468618, 0.5640559901962993] The semantic similarity between a clue and a word is measured by the Euclidean distance between their vectors; a word is judged as intended by the clue with probability sigmoid(1/distance). The clue applies to a subset of words only if all words in the subset independently pass this probabilistic check. Possible target word subsets (pairs) to guess: [chicken,eagle], [eagle,pig], [chicken,pig]. Possible clue words: farm, animal, bird; prior over clues is uniform. Prior over target subsets is uniform. Speaker optimality alpha = 1.

model

A literal listener conditions on the clue correctly applying (all words in the subset independently pass the similarity check) and returns the target subset. A speaker chooses clues proportional to exp(alpha times the literal listener's log-probability of the target subset). A pragmatic listener inverts the speaker model to infer the intended target subset.

query

The posterior distribution over target word subsets given that the pragmatic listener hears the clue 'farm'. Represent each candidate pair as a two-element list of the two words, in alphabetical order.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    [
      "chicken",
      "eagle"
    ],
    [
      "chicken",
      "pig"
    ],
    [
      "eagle",
      "pig"
    ]
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1///fold: vectors
2var vectors = {eagle : Vector([-0.8186906894583743, -0.8443627918594182,
3 -0.04304780086785447, -0.8257634263841377,
4 -0.7607218950809542, 0.47786735164930183,
5 0.36942709316422206, 0.18560148725224498,
6 0.38625176009619944, 0.24384273963053932,
7 1.0355862286322068, -0.14170089242313555,
8 -0.17017960843359828, 0.27636172471279313,
9 -0.49477465481497807, -1.199206930890509,
10 1.0531720839078256, -0.5154875303291531,
11 0.30704269353337016, 1.5382356443196483,
12 -0.13215425501400774, 1.2222507503066664,
13 1.3819617662995949, -1.1579407453927437,
14 0.9439311306043343]),
15 pig: Vector([-0.9027808771549458, -1.4539105978263833,
16 0.5743098399154295, 1.3052815987119957,
17 -0.038556210348244704, -0.22144102997326148,
18 1.222050088622139, -0.027526643946408857,
19 -0.13265827668708097, 1.4799207507145387,
20 0.02371336629548181, -0.9405402658175948,
21 0.06556493358788004, -1.6556208133885402,
22 -0.44306373689318584, -0.475710035110901,
23 1.2435716830499404, -1.0677780309283533,
24 -0.03344465447168945, 0.16184568683816827,
25 0.8718035460475897, 2.082956682688621,
26 0.47430271385843953, -0.4479993650378608,
27 1.5192928553678355]),
28 chicken: Vector([-0.9555633717916903, -0.23467550895948608,
29 0.9081102168032618, 1.7681919864431317,
30 -0.3888166871286516, 1.2292398003323308,
31 1.0624961440319318, -0.3558803892040966,
32 -0.17024423658814317, 0.7046776782991592,
33 1.624196256505183, -1.1423231844008523,
34 -0.9490267652945451, -1.8004114037674281,
35 -0.026086280368055388, -2.089757256612839,
36 3.5660566372328693, -2.4178611093952225,
37 -0.7077960662621875, 0.9418434965990246,
38 0.438927172322575, 1.0891725023940724,
39 -0.1237861204326181, 0.7602054634506068,
40 3.0515580696224083]),
41 farm: Vector([-1.3191469349030631, -0.34747873883058705,
42 0.09525267994894762, 0.08014872654330456,
43 0.1179814806966339, -0.26926061020753783,
44 0.709033965954239, -0.6521777385143812,
45 1.0195239553589313, 0.7192612109870958,
46 1.1711460976059695, -1.0779079866249233,
47 -0.5443503049555966, 0.08523251153754875,
48 0.1455530206584687, -1.501097375488643,
49 1.1151234505440395, 0.0581591541412683,
50 -0.1102242123027589, 0.5253857581277014,
51 0.21780949510893402, 0.026030837039037417,
52 0.07282095318396448, -0.6093002665622598,
53 2.0466066458317336]),
54 animal: Vector([-0.08132595025854601, -1.8280616716238214,
55 -0.4241550049238374, 1.3405833261217683,
56 1.3635302219051426, 0.19656106954281044,
57 1.0553637657141577, 0.8640316722860499,
58 -0.34682275265131135, 0.27196141799987644,
59 0.9785603157742483, -3.1767493003780873,
60 -0.7566904249011203, -1.1935303767007424,
61 0.2523177522167622, 0.33414675815038736,
62 0.21147820292953767, 0.2089073521353749,
63 0.36413859545070626, -0.3145854077725169,
64 0.8470589609352164, 0.8914477422324714,
65 0.06602846837066885, -1.1974184866543685,
66 1.6807019645814638]),
67 bird: Vector([0.9317828273959833, -1.142927450658389,
68 1.1249556341704339, 0.7533022372085103,
69 0.039221572709652965, 0.5302815428039684,
70 1.1525754405638204, 0.5707370610821617,
71 0.01803607760778035, 0.9527229145321762,
72 1.0851468114908822, -0.4626041548552341,
73 -0.5371489443168416, -0.8343285842461913,
74 -0.09713481034287788, 0.8070233789520264,
75 0.21755780815430825, -0.6588132708557186,
76 -0.7963193188039507, 0.12395864485237663,
77 -0.18545774404118467, 1.311026289715281,
78 0.7764007851264465, -0.5776179488468618,
79 0.5640559901962993])
80};
81///
82
83var meaning = function(clue, words) {
84
85 var distance = function(vector1, vector2)
86 {
87 var squared = map(function(tuple)
88 {
89 return (tuple[0] - tuple[1])*(tuple[0] - tuple[1]);
90 }
91 , zip(ad.tensor.toScalars(vector1), ad.tensor.toScalars(vector2))
92 );
93
94
95 var answer = Math.sqrt(sum(squared));
96 return answer;
97 };
98
99 var sigmoid = function(num)
100 {
101 return 1/(1 + Math.exp(-1*num));
102 };
103
104 var trueFalse = function(clue, word)
105 {
106 var dist = distance(vectors[clue], vectors[word]);
107 var prob = sigmoid(1/dist);
108 return flip(prob);
109 };
110
111 var wordsVectors = map(function(word) {return trueFalse(clue, word);}, words);
112
113
114 return all(function(s) {return s;}, wordsVectors);
115};
116
117var wordsPrior = function()
118{
119 var pairs = [["chicken", "eagle"], ["eagle", "pig"], ["chicken", "pig"]];
120 return uniformDraw(pairs);
121};
122
123var cluePrior = function()
124{
125 return uniformDraw(["farm", "animal", "bird"]);
126};
127
128
129var literalListener = function(clue)
130{
131 Infer(function()
132 {
133 var randomSubset = wordsPrior();
134 var uttTruthVal = meaning(clue, randomSubset);
135 condition(uttTruthVal);
136 return randomSubset;
137 }
138 )
139};
140
141var alpha = 1;
142
143var speaker = function(subset)
144{
145 Infer(function()
146 {
147 var clue = cluePrior();
148 factor(alpha*literalListener(clue).score(subset));
149 return clue;
150 }
151 )
152};
153
154var pragmaticListener = function(clue)
155{
156 Infer(function()
157 {
158 var randomSubset = wordsPrior();
159 var s1 = speaker(randomSubset);
160 observe(s1, clue);
161 return randomSubset;
162 }
163 )
164};
165
166viz.table(pragmaticListener("farm"));
167
168var ANSWER = (pragmaticListener("farm"));
realization0.000
python
1
2vectors = {
3 "eagle": [-0.8186906894583743, -0.8443627918594182, -0.04304780086785447, -0.8257634263841377, -0.7607218950809542, 0.47786735164930183, 0.36942709316422206, 0.18560148725224498, 0.38625176009619944, 0.24384273963053932, 1.0355862286322068, -0.14170089242313555, -0.17017960843359828, 0.27636172471279313, -0.49477465481497807, -1.199206930890509, 1.0531720839078256, -0.5154875303291531, 0.30704269353337016, 1.5382356443196483, -0.13215425501400774, 1.2222507503066664, 1.3819617662995949, -1.1579407453927437, 0.9439311306043343],
4 "pig": [-0.9027808771549458, -1.4539105978263833, 0.5743098399154295, 1.3052815987119957, -0.038556210348244704, -0.22144102997326148, 1.222050088622139, -0.027526643946408857, -0.13265827668708097, 1.4799207507145387, 0.02371336629548181, -0.9405402658175948, 0.06556493358788004, -1.6556208133885402, -0.44306373689318584, -0.475710035110901, 1.2435716830499404, -1.0677780309283533, -0.03344465447168945, 0.16184568683816827, 0.8718035460475897, 2.082956682688621, 0.47430271385843953, -0.4479993650378608, 1.5192928553678355],
5 "chicken": [-0.9555633717916903, -0.23467550895948608, 0.9081102168032618, 1.7681919864431317, -0.3888166871286516, 1.2292398003323308, 1.0624961440319318, -0.3558803892040966, -0.17024423658814317, 0.7046776782991592, 1.624196256505183, -1.1423231844008523, -0.9490267652945451, -1.8004114037674281, -0.026086280368055388, -2.089757256612839, 3.5660566372328693, -2.4178611093952225, -0.7077960662621875, 0.9418434965990246, 0.438927172322575, 1.0891725023940724, -0.1237861204326181, 0.7602054634506068, 3.0515580696224083],
6 "farm": [-1.3191469349030631, -0.34747873883058705, 0.09525267994894762, 0.08014872654330456, 0.1179814806966339, -0.26926061020753783, 0.709033965954239, -0.6521777385143812, 1.0195239553589313, 0.7192612109870958, 1.1711460976059695, -1.0779079866249233, -0.5443503049555966, 0.08523251153754875, 0.1455530206584687, -1.501097375488643, 1.1151234505440395, 0.0581591541412683, -0.1102242123027589, 0.5253857581277014, 0.21780949510893402, 0.026030837039037417, 0.07282095318396448, -0.6093002665622598, 2.0466066458317336],
7 "animal": [-0.08132595025854601, -1.8280616716238214, -0.4241550049238374, 1.3405833261217683, 1.3635302219051426, 0.19656106954281044, 1.0553637657141577, 0.8640316722860499, -0.34682275265131135, 0.27196141799987644, 0.9785603157742483, -3.1767493003780873, -0.7566904249011203, -1.1935303767007424, 0.2523177522167622, 0.33414675815038736, 0.21147820292953767, 0.2089073521353749, 0.36413859545070626, -0.3145854077725169, 0.8470589609352164, 0.8914477422324714, 0.06602846837066885, -1.1974184866543685, 1.6807019645814638],
8 "bird": [0.9317828273959833, -1.142927450658389, 1.1249556341704339, 0.7533022372085103, 0.039221572709652965, 0.5302815428039684, 1.1525754405638204, 0.5707370610821617, 0.01803607760778035, 0.9527229145321762, 1.0851468114908822, -0.4626041548552341, -0.5371489443168416, -0.8343285842461913, -0.09713481034287788, 0.8070233789520264, 0.21755780815430825, -0.6588132708557186, -0.7963193188039507, 0.12395864485237663, -0.18545774404118467, 1.311026289715281, 0.7764007851264465, -0.5776179488468618, 0.5640559901962993],
9}
10pairs = [["chicken", "eagle"], ["eagle", "pig"], ["chicken", "pig"]]
11clues = ["farm", "animal", "bird"]
12alpha = 1.0
13
14
15def distance(a, b):
16 return math.sqrt(sum((x - y) ** 2 for x, y in zip(vectors[a], vectors[b])))
17
18
19def sigmoid(x):
20 return 1.0 / (1.0 + math.exp(-x))
21
22
23def prob_true(clue, word):
24 return sigmoid(1.0 / distance(clue, word))
25
26
27def meaning_prob(clue, words):
28 # meaning(clue, words) is true iff every word is independently true under
29 # its Bernoulli; P(true) is the product of the per-word probabilities.
30 p = 1.0
31 for w in words:
32 p *= prob_true(clue, w)
33 return p
34
35
36# Literal listener: pair posterior given the clue is true.
37def literal_listener(clue):
38 logp = torch.log(torch.tensor([meaning_prob(clue, pairs[i]) for i in range(3)]))
39
40 @pyro.infer.config_enumerate
41 def model():
42 sub = pyro.sample("sub", dist.Categorical(torch.ones(3)))
43 pyro.factor("cond", logp[sub]) # condition(uttTruthVal): weight by P(true)
44 return None
45
46 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
47 m = marg["sub"]
48 sup = m.enumerate_support()
49 pr = m.log_prob(sup).exp()
50 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}
51
52
53# Speaker: clue posterior given the intended subset.
54def speaker_logprobs(subset_idx):
55 ll = {cl: literal_listener(cl) for cl in clues}
56
57 @pyro.infer.config_enumerate
58 def model():
59 c = pyro.sample("clue", dist.Categorical(torch.ones(3)))
60 sc = torch.tensor([math.log(ll[clues[i]][subset_idx]) if ll[clues[i]][subset_idx] > 0 else -1e30
61 for i in range(3)])
62 pyro.factor("sc", alpha * sc[c])
63 return None
64
65 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
66 m = marg["clue"]
67 sup = m.enumerate_support()
68 pr = m.log_prob(sup).exp()
69 return {clues[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}
70
71
72# Pragmatic listener: pair posterior given the heard clue.
73def pragmatic_listener(clue):
74 spk = {i: speaker_logprobs(i) for i in range(3)}
75 sc = torch.tensor([math.log(spk[i].get(clue, 0.0)) if spk[i].get(clue, 0.0) > 0 else -1e30
76 for i in range(3)])
77
78 @pyro.infer.config_enumerate
79 def model():
80 sub = pyro.sample("sub", dist.Categorical(torch.ones(3)))
81 pyro.factor("obs", sc[sub])
82 return None
83
84 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
85 m = marg["sub"]
86 sup = m.enumerate_support()
87 pr = m.log_prob(sup).exp()
88 out = {}
89 for i, p in zip(sup.tolist(), pr.tolist()):
90 out[json.dumps(sorted(pairs[int(i)]))] = p
91 return out
92
93
94ANSWER = pragmatic_listener("farm")
95
02answer overlay — webppl vs pyrodist/finite
webppl pyro3 bins
00.170.170.340.34["chicken","eagle"] A = 0.338 B = 0.338["chicken","eagle"] A = 0.338 B = 0.3380.340.34["chicken","eagle"]["chicken","pig"] A = 0.328 B = 0.328["chicken","pig"] A = 0.328 B = 0.3280.330.33["chicken","pig"]["eagle","pig"] A = 0.333 B = 0.333["eagle","pig"] A = 0.333 B = 0.3330.330.33["eagle","pig"]
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-dickson-speaker-cost / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/dickson-speaker-cost.md
given

Three objects: 'blue square', 'blue circle', 'green square' (each described by color and shape; object string is the concatenation 'color shape'). The object prior is uniform over these three objects. Seven possible utterances: 'blue', 'green', 'square', 'circle', 'blue square', 'blue circle', 'green square'. An utterance applies to an object if and only if the utterance string is a substring of the object's string. Utterance cost equals the cost parameter times the number of words in the utterance. The cost parameter is drawn uniformly from the 10 values: 0.05, 0.55, 1.05, 1.55, 2.05, 2.55, 3.05, 3.55, 4.05, 4.55. Speaker optimality alpha = 1.

model

A literal listener conditions on the utterance applying to the object and returns the object. A speaker chooses utterances proportional to exp(alpha times (literal listener log-probability of the object minus the utterance cost)). A pragmatic listener infers the intended object and the speaker's cost parameter jointly by inverting the speaker.

query

The joint posterior distribution over the intended object and the speaker's cost parameter given that the pragmatic listener hears 'blue'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "obj": "string",
      "costParameter": "real"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
// Reference objects. Each has a color, a shape, and the "color shape"
// string that utterances are matched against.
var objects = [
  {color: "blue", shape: "square", string: "blue square"},
  {color: "blue", shape: "circle", string: "blue circle"},
  {color: "green", shape: "square", string: "green square"}
];

// Uniform prior over objects; a draw is represented by its string label.
var objectPrior = function() {
  return uniformDraw(objects).string;
};

// Candidate utterances: single words plus the full two-word descriptions.
var utterances = [
  "blue", "green", "square", "circle",
  "blue square", "blue circle", "green square"
];
15
// Utterance cost: costParameter charged once per space-separated word.
var cost = function(utterance, costParameter) {
  var wordCount = utterance.split(" ").length;
  return costParameter * wordCount;
};

// Cost-parameter prior: uniform over the grid _.range(0.05, 5, 0.5).
var costParameterPrior = function() {
  var grid = _.range(0.05, 5, 0.5);
  return uniformDraw(grid);
};
28
// Literal semantics: an utterance is true of an object exactly when the
// utterance string occurs inside the object's "color shape" string.
var meaning = function(utterance, obj) {
  return _.includes(obj, utterance);
};

// Literal listener: the uniform object prior restricted to objects the
// utterance is literally true of.
var literalListener = function(utterance) {
  return Infer({model: function() {
    var obj = objectPrior();
    condition(meaning(utterance, obj));
    return obj;
  }});
};
42
// Speaker optimality (rationality) parameter.
var alpha = 1;

// Pragmatic speaker: weights each utterance by
// exp(alpha * (log L0(obj | utterance) - cost(utterance, costParameter))).
var speaker = function(obj, costParameter) {
  return Infer({model: function() {
    var utterance = uniformDraw(utterances);
    var informativity = literalListener(utterance).score(obj);
    var utility = informativity - cost(utterance, costParameter);
    factor(alpha * utility);
    return utterance;
  }});
};
55
// Pragmatic listener: joint posterior over the intended object and the
// speaker's cost parameter, obtained by inverting the speaker.
var pragmaticListener = function(utterance) {
  return Infer({model: function() {
    var obj = objectPrior();
    var costParameter = costParameterPrior();
    observe(speaker(obj, costParameter), utterance);
    return {obj: obj, costParameter: costParameter};
  }});
};
65
66
display("cost parameter prior");
viz(Infer(costParameterPrior));

// Print a heading, then the object marginal (as a table) and the
// cost-parameter marginal of a pragmatic-listener posterior.
var showPosterior = function(heading, posterior) {
  display(heading);
  viz.table(marginalize(posterior, "obj"));
  viz(marginalize(posterior, "costParameter"));
};

var listenerPosteriorBlue = pragmaticListener("blue");
showPosterior("pragmatic listener hears \"blue\"", listenerPosteriorBlue);

var listenerPosteriorBlueSquare = pragmaticListener("blue square");
showPosterior("pragmatic listener hears \"blue square\"", listenerPosteriorBlueSquare);

// Query: joint posterior over {obj, costParameter} after hearing "blue".
var ANSWER = pragmaticListener('blue');
realization0.000
python
1
# Objects in the reference game; utterances are matched against "string".
objects = [
    {"color": "blue", "shape": "square", "string": "blue square"},
    {"color": "blue", "shape": "circle", "string": "blue circle"},
    {"color": "green", "shape": "square", "string": "green square"},
]
obj_strings = list(map(lambda entry: entry["string"], objects))
utterances = [
    "blue", "green", "square", "circle",
    "blue square", "blue circle", "green square",
]

# costParameterPrior mirrors uniformDraw(_.range(0.05, 5, 0.5)): the step is
# accumulated and each value is rounded to 10 places to drop float noise.
costParams = []
v = 0.05
while not v >= 5 - 1e-12:
    costParams.append(round(v, 10))
    v = v + 0.5

alpha = 1.0  # speaker optimality
NEG = -1e30  # finite stand-in for log(0) in factor scores
17
18
def num_words(u):
    """Number of pieces in ``u`` when split on single spaces."""
    return u.count(" ") + 1


def cost(u, cp):
    """Utterance cost: ``cp`` per word of ``u``."""
    return num_words(u) * cp


def meaning(u, objstr):
    """True iff ``u`` occurs as a substring of ``objstr`` (mirrors ``_.includes``)."""
    return objstr.find(u) != -1
29
30
# Literal listener: object posterior given a literally-true utterance.
def literal_listener(utterance):
    """Return {object index: probability} for a uniform object prior that is
    restricted, via a NEG log-factor, to objects ``utterance`` is true of."""
    log_mask = torch.tensor(
        [0.0 if meaning(utterance, label) else NEG for label in obj_strings]
    )
    n_objects = len(obj_strings)

    @pyro.infer.config_enumerate
    def model():
        idx = pyro.sample("o", dist.Categorical(torch.ones(n_objects)))
        pyro.factor("cond", log_mask[idx])

    elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
    posterior = elbo.compute_marginals(model, lambda: None)["o"]
    support = posterior.enumerate_support()
    probs = posterior.log_prob(support).exp()
    return {int(k): pk for k, pk in zip(support.tolist(), probs.tolist())}
46
47
# Speaker: utterance posterior given the object and cost parameter.
def speaker_logprobs(obj_idx, cp):
    """Return {utterance: probability}, with each utterance weighted by
    exp(alpha * (log L0(obj_idx | utterance) - cost(utterance, cp)))."""
    listeners = {u: literal_listener(u) for u in utterances}

    def utility(u):
        p = listeners[u].get(obj_idx, 0.0)
        log_p = math.log(p) if p > 0 else NEG
        return alpha * (log_p - cost(u, cp))

    scores = torch.tensor([utility(u) for u in utterances])

    @pyro.infer.config_enumerate
    def model():
        k = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))
        pyro.factor("sc", scores[k])

    elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
    posterior = elbo.compute_marginals(model, lambda: None)["u"]
    support = posterior.enumerate_support()
    probs = posterior.log_prob(support).exp()
    return {utterances[int(k)]: pk for k, pk in zip(support.tolist(), probs.tolist())}
69
70
# Pragmatic listener: exact joint posterior over (obj, costParameter). The pair
# is flattened into one categorical latent in object-major order; its log-factor
# is the speaker's log-probability of the heard utterance.
def pragmatic_listener(utterance):
    """Return {json record {"costParameter", "obj"}: posterior probability}."""
    n_cost = len(costParams)
    n_joint = len(obj_strings) * n_cost
    obs_score = torch.zeros(n_joint)
    for j in range(n_joint):
        o, c = divmod(j, n_cost)
        p = speaker_logprobs(o, costParams[c]).get(utterance, 0.0)
        obs_score[j] = math.log(p) if p > 0 else NEG

    @pyro.infer.config_enumerate
    def model():
        # Uniform object x uniform cost parameter.
        k = pyro.sample("oc", dist.Categorical(torch.ones(n_joint)))
        pyro.factor("obs", obs_score[k])

    elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
    posterior = elbo.compute_marginals(model, lambda: None)["oc"]
    support = posterior.enumerate_support()
    probs = posterior.log_prob(support).exp()
    result = {}
    for k, pk in zip(support.tolist(), probs.tolist()):
        o, c = divmod(int(k), n_cost)
        record = {"obj": obj_strings[o], "costParameter": costParams[c]}
        result[json.dumps(record, sort_keys=True)] = pk
    return result


ANSWER = pragmatic_listener("blue")
103
02 answer overlay — webppl vs pyro · dist/finite
webppl pyro20 bins
00.0350.0350.0690.069{"costParameter":0.05,"obj":"blue circle"} A = 0.029 B = 0.029{"costParameter":0.05,"obj":"blue circle"} A = 0.029 B = 0.029{"costParameter":0.05,"obj":"blue circle"}{"costParameter":0.05,"obj":"blue square"} A = 0.036 B = 0.036{"costParameter":0.05,"obj":"blue square"} A = 0.036 B = 0.036{"costParameter":0.55,"obj":"blue circle"} A = 0.034 B = 0.034{"costParameter":0.55,"obj":"blue circle"} A = 0.034 B = 0.034{"costParameter":0.55,"obj":"blue circle"}{"costParameter":0.55,"obj":"blue square"} A = 0.044 B = 0.044{"costParameter":0.55,"obj":"blue square"} A = 0.044 B = 0.044{"costParameter":1.05,"obj":"blue circle"} A = 0.038 B = 0.038{"costParameter":1.05,"obj":"blue circle"} A = 0.038 B = 0.038{"costParameter":1.05,"obj":"blue circle"}{"costParameter":1.05,"obj":"blue square"} A = 0.052 B = 0.052{"costParameter":1.05,"obj":"blue square"} A = 0.052 B = 0.052{"costParameter":1.55,"obj":"blue circle"} A = 0.041 B = 0.041{"costParameter":1.55,"obj":"blue circle"} A = 0.041 B = 0.041{"costParameter":1.55,"obj":"blue circle"}{"costParameter":1.55,"obj":"blue square"} A = 0.058 B = 0.058{"costParameter":1.55,"obj":"blue square"} A = 0.058 B = 0.058{"costParameter":2.05,"obj":"blue circle"} A = 0.043 B = 0.043{"costParameter":2.05,"obj":"blue circle"} A = 0.043 B = 0.043{"costParameter":2.05,"obj":"blue circle"}{"costParameter":2.05,"obj":"blue square"} A = 0.062 B = 0.062{"costParameter":2.05,"obj":"blue square"} A = 0.062 B = 0.062{"costParameter":2.55,"obj":"blue circle"} A = 0.044 B = 0.044{"costParameter":2.55,"obj":"blue circle"} A = 0.044 B = 0.044{"costParameter":2.55,"obj":"blue circle"}{"costParameter":2.55,"obj":"blue square"} A = 0.065 B = 0.065{"costParameter":2.55,"obj":"blue square"} A = 0.065 B = 0.065{"costParameter":3.05,"obj":"blue circle"} A = 0.045 B = 0.045{"costParameter":3.05,"obj":"blue circle"} A = 0.045 B = 0.045{"costParameter":3.05,"obj":"blue circle"}{"costParameter":3.05,"obj":"blue square"} A = 0.067 B = 
0.067{"costParameter":3.05,"obj":"blue square"} A = 0.067 B = 0.067{"costParameter":3.55,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":3.55,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":3.55,"obj":"blue circle"}{"costParameter":3.55,"obj":"blue square"} A = 0.068 B = 0.068{"costParameter":3.55,"obj":"blue square"} A = 0.068 B = 0.068{"costParameter":4.05,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":4.05,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":4.05,"obj":"blue circle"}{"costParameter":4.05,"obj":"blue square"} A = 0.069 B = 0.069{"costParameter":4.05,"obj":"blue square"} A = 0.069 B = 0.069{"costParameter":4.55,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":4.55,"obj":"blue circle"} A = 0.046 B = 0.046{"costParameter":4.55,"obj":"blue circle"}{"costParameter":4.55,"obj":"blue square"} A = 0.069 B = 0.069{"costParameter":4.55,"obj":"blue square"} A = 0.069 B = 0.069
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-generics / atom-1
answer dist/real solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/generics.md
given

The prevalence space is a grid of 40 values running from 0.01 to 0.985 in steps of 0.025: 0.01, 0.035, 0.060, 0.085, 0.110, 0.135, 0.160, 0.185, 0.210, 0.235, 0.260, 0.285, 0.310, 0.335, 0.360, 0.385, 0.410, 0.435, 0.460, 0.485, 0.510, 0.535, 0.560, 0.585, 0.610, 0.635, 0.660, 0.685, 0.710, 0.735, 0.760, 0.785, 0.810, 0.835, 0.860, 0.885, 0.910, 0.935, 0.960, 0.985. A discretized Beta distribution over this grid with mean parameter g and concentration parameter d assigns each grid value x a weight proportional to x^(g*d - 1) * (1-x)^((1-g)*d - 1). The prior over prevalence is a mixture of two such discretized Beta components: with probability 0.3, the component with g=0.5 and d=10; otherwise the component with g=0.01 and d=100.

model

Prevalence is drawn from a two-component mixture over the grid. One component concentrates mass near zero (a near-absent component with very low mean and high concentration). The other places mass according to a discretized Beta with the specified mean and concentration (the present component). The mixture weight (0.3) is the prior probability that the property is present at all; with the remaining probability (0.7), prevalence is drawn from the near-absent component.

query

The marginal prior distribution over the numeric prevalence values.

answer spec dist/real
{
  "kind": "dist",
  "domain": "real"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
///fold:
// Prevalence grid 0.01, 0.035, ..., 0.985 (40 values).
var bins = _.range(0.01, 1, 0.025);

// Discretized Beta with mean g and concentration d. Each bin x gets weight
// x^(g*d - 1) * (1 - x)^((1 - g)*d - 1); the weights are passed to
// Categorical unnormalized.
var DiscreteBeta = cache(function(g, d) {
  var a = g * d;
  var b = (1 - g) * d;
  var weights = map(function(x) {
    return Math.pow(x, a - 1) * Math.pow(1 - x, b - 1);
  }, bins);
  return Categorical({vs: bins, ps: weights});
});
///
// Mixture prior over prevalence. With probability params.potential, draw from
// the present component DiscreteBeta(prevalenceWhenPresent,
// concentrationWhenPresent); otherwise draw from the near-absent component
// DiscreteBeta(0.01, 100). Returns a distribution over {prevalence} records.
var priorModel = function(params) {
  return Infer({model: function() {
    var stable = DiscreteBeta(params["prevalenceWhenPresent"],
                              params["concentrationWhenPresent"]);
    var unstable = DiscreteBeta(0.01, 100);
    var isPresent = flip(params["potential"]);
    var prevalence = sample(isPresent ? stable : unstable);
    return {prevalence: prevalence};
  }});
};

var d = priorModel({
  potential: 0.3,
  prevalenceWhenPresent: 0.5,
  concentrationWhenPresent: 10
});
// Query: marginal prior over the numeric prevalence values.
var ANSWER = marginalize(d, 'prevalence');
36
realization0.000
python
1
# Grid 0.01, 0.035, ..., 0.985 (40 values), matching _.range(0.01, 1, 0.025),
# with each value rounded to 10 places.
bins = list(map(lambda k: round(0.01 + 0.025 * k, 10), range(40)))


def discrete_beta_probs(g, d):
    """Discretized Beta(mean g, concentration d) on ``bins``, normalized (tensor)."""
    a, b = g * d, (1 - g) * d
    raw = [x ** (a - 1) * (1 - x) ** (b - 1) for x in bins]
    total = sum(raw)
    return torch.tensor([w / total for w in raw])


stable = discrete_beta_probs(0.5, 10)  # present component
unstable = discrete_beta_probs(0.01, 100)  # near-absent component
potential = 0.3  # probability of drawing from the present component
15
@pyro.infer.config_enumerate
def model():
    """Two-component mixture: pick a component, then a prevalence bin index."""
    is_present = pyro.sample("present", dist.Bernoulli(potential)).bool()
    mixture_probs = torch.where(is_present.unsqueeze(-1), stable, unstable)
    return pyro.sample("idx", dist.Categorical(mixture_probs))


marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
    model, lambda: None
)["idx"]
sup = marg.enumerate_support()
probs = marg.log_prob(sup).exp()
# Map each bin index back to its numeric prevalence value.
ANSWER = dict((bins[int(i)], float(p)) for i, p in zip(sup, probs))
29
02 answer overlay — webppl vs pyro · dist/real
webppl pyro71 bins · 0.01 … 0.99
00.320.320.640.640.200.400.600.800.01 · 0.6430.01 · 0.643x = 0.01 A = 0.6432 B = 0.6432 Δ = 0.0000x = 0.04 A = 0.0525 B = 0.0525 Δ = 0.0000x = 0.06 A = 0.0000 B = 0.0041 Δ = -0.0041x = 0.06 A = 0.0041 B = 0.0000 Δ = 0.0041x = 0.09 A = 0.0005 B = 0.0005 Δ = 0.0000x = 0.11 A = 0.0000 B = 0.0005 Δ = -0.0005x = 0.11 A = 0.0005 B = 0.0000 Δ = 0.0005x = 0.14 A = 0.0009 B = 0.0009 Δ = 0.0000x = 0.16 A = 0.0015 B = 0.0015 Δ = 0.0000x = 0.18 A = 0.0024 B = 0.0024 Δ = 0.0000x = 0.21 A = 0.0036 B = 0.0036 Δ = 0.0000x = 0.23 A = 0.0049 B = 0.0049 Δ = 0.0000x = 0.26 A = 0.0065 B = 0.0065 Δ = 0.0000x = 0.28 A = 0.0000 B = 0.0081 Δ = -0.0081x = 0.29 A = 0.0081 B = 0.0000 Δ = 0.0081x = 0.31 A = 0.0000 B = 0.0099 Δ = -0.0099x = 0.31 A = 0.0099 B = 0.0000 Δ = 0.0099x = 0.34 A = 0.0000 B = 0.0116 Δ = -0.0116x = 0.34 A = 0.0116 B = 0.0000 Δ = 0.0116x = 0.36 A = 0.0000 B = 0.0133 Δ = -0.0133x = 0.36 A = 0.0133 B = 0.0000 Δ = 0.0133x = 0.39 A = 0.0000 B = 0.0149 Δ = -0.0149x = 0.39 A = 0.0149 B = 0.0000 Δ = 0.0149x = 0.41 A = 0.0000 B = 0.0162 Δ = -0.0162x = 0.41 A = 0.0162 B = 0.0000 Δ = 0.0162x = 0.43 A = 0.0000 B = 0.0172 Δ = -0.0172x = 0.44 A = 0.0172 B = 0.0000 Δ = 0.0172x = 0.46 A = 0.0000 B = 0.0180 Δ = -0.0180x = 0.46 A = 0.0180 B = 0.0000 Δ = 0.0180x = 0.48 A = 0.0000 B = 0.0184 Δ = -0.0184x = 0.49 A = 0.0184 B = 0.0000 Δ = 0.0184x = 0.51 A = 0.0000 B = 0.0184 Δ = -0.0184x = 0.51 A = 0.0184 B = 0.0000 Δ = 0.0184x = 0.54 A = 0.0000 B = 0.0181 Δ = -0.0181x = 0.54 A = 0.0181 B = 0.0000 Δ = 0.0181x = 0.56 A = 0.0000 B = 0.0174 Δ = -0.0174x = 0.56 A = 0.0174 B = 0.0000 Δ = 0.0174x = 0.58 A = 0.0000 B = 0.0164 Δ = -0.0164x = 0.59 A = 0.0164 B = 0.0000 Δ = 0.0164x = 0.61 A = 0.0000 B = 0.0151 Δ = -0.0151x = 0.61 A = 0.0151 B = 0.0000 Δ = 0.0151x = 0.64 A = 0.0000 B = 0.0136 Δ = -0.0136x = 0.64 A = 0.0136 B = 0.0000 Δ = 0.0136x = 0.66 A = 0.0000 B = 0.0120 Δ = -0.0120x = 0.66 A = 0.0120 B = 0.0000 Δ = 0.0120x = 0.69 A = 0.0000 B = 0.0102 Δ = -0.0102x = 0.69 A = 0.0102 B = 0.0000 Δ = 
0.0102x = 0.71 A = 0.0000 B = 0.0085 Δ = -0.0085x = 0.71 A = 0.0085 B = 0.0000 Δ = 0.0085x = 0.73 A = 0.0000 B = 0.0068 Δ = -0.0068x = 0.74 A = 0.0068 B = 0.0000 Δ = 0.0068x = 0.76 A = 0.0000 B = 0.0052 Δ = -0.0052x = 0.76 A = 0.0052 B = 0.0000 Δ = 0.0052x = 0.79 A = 0.0000 B = 0.0038 Δ = -0.0038x = 0.79 A = 0.0038 B = 0.0000 Δ = 0.0038x = 0.81 A = 0.0000 B = 0.0027 Δ = -0.0027x = 0.81 A = 0.0027 B = 0.0000 Δ = 0.0027x = 0.83 A = 0.0000 B = 0.0017 Δ = -0.0017x = 0.84 A = 0.0017 B = 0.0000 Δ = 0.0017x = 0.86 A = 0.0000 B = 0.0010 Δ = -0.0010x = 0.86 A = 0.0010 B = 0.0000 Δ = 0.0010x = 0.89 A = 0.0000 B = 0.0005 Δ = -0.0005x = 0.89 A = 0.0005 B = 0.0000 Δ = 0.0005x = 0.91 A = 0.0000 B = 0.0002 Δ = -0.0002x = 0.91 A = 0.0002 B = 0.0000 Δ = 0.0002x = 0.94 A = 0.0000 B = 0.0001 Δ = -0.0001x = 0.94 A = 0.0001 B = 0.0000 Δ = 0.0001x = 0.96 A = 0.0000 B = 0.0000 Δ = -0.0000x = 0.96 A = 0.0000 B = 0.0000 Δ = 0.0000x = 0.98 A = 0.0000 B = 0.0000 Δ = -0.0000x = 0.99 A = 0.0000 B = 0.0000 Δ = 0.0000
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (w1)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-generics / atom-2
answer dist/real solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/generics.md
given

Prevalence is discretized into 50 bins: values 0.01, 0.03, 0.05, ..., 0.99 (i.e., starting at 0.01 with step 0.02, rounded to 2 decimal places). The threshold for the generic is drawn uniformly from the 49 midpoints between consecutive bins: 0.02, 0.04, ..., 0.98. A Beta-mixture prior over prevalence: with probability 0.3 (the 'potential' parameter), the prevalence is drawn from a discretized Beta with mean 0.5 and concentration 10 (g=0.5, d=10, so a=g*d=5, b=(1-g)*d=5); otherwise, it is drawn from a discretized Beta with g=0.01 and d=100 (a very low-prevalence unstable distribution with a=1, b=99). The discrete Beta weight for bin x is proportional to x^(a-1) * (1-x)^(b-1).

model

The generic statement is true if the prevalence exceeds the threshold. A literal listener draws a prevalence from the prior and a threshold uniformly, conditions on the generic being true, and returns the prevalence.

query

The posterior distribution over prevalence given that the literal listener hears a generic statement.

answer spec dist/real
{
  "kind": "dist",
  "domain": "real"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
///fold:
// Prevalence bins 0.01, 0.03, ..., 0.99, rounded to two decimals to remove
// the floating-point drift of _.range.
var bins = map(function(x) {
  return _.round(x, 2);
}, _.range(0.01, 1, 0.02));

// Candidate thresholds: the midpoint of each pair of adjacent bins.
var thresholdBins = map2(function(lo, hi) {
  return lo + (hi - lo) / 2;
}, bins.slice(0, bins.length - 1), bins.slice(1));
11
// Discretized Beta with mean g and concentration d over `bins`; the
// per-bin weights are passed to Categorical unnormalized.
var DiscreteBeta = cache(function(g, d) {
  var a = g * d;
  var b = (1 - g) * d;
  var weights = map(function(x) {
    return Math.pow(x, a - 1) * Math.pow(1 - x, b - 1);
  }, bins);
  return Categorical({vs: bins, ps: weights});
});

// Mixture prior over the numeric prevalence. With probability
// params.potential, draw from the stable (present) component; otherwise
// draw from the near-absent DiscreteBeta(0.01, 100).
var priorModel = function(params) {
  return Infer({model: function() {
    var stable = DiscreteBeta(params["prevalenceWhenPresent"],
                              params["concentrationWhenPresent"]);
    var unstable = DiscreteBeta(0.01, 100);
    return flip(params["potential"]) ? sample(stable) : sample(unstable);
  }});
};
///
// "generic" is true when prevalence exceeds the threshold; any other
// utterance is vacuously true.
var meaning = function(utterance, prevalence, threshold) {
  if (utterance == 'generic') {
    return prevalence > threshold;
  } else {
    return true;
  }
};

// Uniform prior over the candidate thresholds.
var thresholdPrior = function() {
  return uniformDraw(thresholdBins);
};

var statePrior = priorModel({
  potential: 0.3,
  prevalenceWhenPresent: 0.5,   // mean prevalence under the stable cause
  concentrationWhenPresent: 10  // inverse variance of the stable cause
});
51
display("prevalence prior");
viz(statePrior);

// Literal listener: samples a prevalence and a threshold, keeps only the
// samples where the utterance is true, and returns the prevalence.
var listener = cache(function(utterance) {
  return Infer({model: function() {
    var prevalence = sample(statePrior);
    var threshold = thresholdPrior();
    condition(meaning(utterance, prevalence, threshold));
    return prevalence;
  }});
});

display("listener posterior");
listener("generic");

// Query: posterior over prevalence after hearing the generic.
var ANSWER = listener('generic');
realization0.000
python
# Prevalence bins 0.01, 0.03, ..., 0.99 (50 values).
bins = [round(0.01 + 0.02 * k, 2) for k in range(50)]
bins_t = torch.tensor(bins)
# Candidate thresholds: midpoints of adjacent bins (0.02, ..., 0.98).
threshold_bins = [round((lo + hi) / 2.0, 10) for lo, hi in zip(bins, bins[1:])]
threshold_t = torch.tensor(threshold_bins)


def discrete_beta_probs(g, d):
    """Discretized Beta(mean g, concentration d) weights on ``bins_t``, normalized."""
    a, b = g * d, (1.0 - g) * d
    weights = bins_t.pow(a - 1.0) * (1.0 - bins_t).pow(b - 1.0)
    return weights / weights.sum()


potential = 0.3
stable = discrete_beta_probs(0.5, 10.0)
unstable = discrete_beta_probs(0.01, 100.0)
# Mixture prior over prevalence bins.
prior_probs = potential * stable + (1.0 - potential) * unstable
19
@pyro.infer.config_enumerate
def model():
    """Literal listener for "generic": prevalence from the mixture prior,
    threshold uniform over midpoints, hard-conditioned on prevalence > threshold."""
    prev_i = pyro.sample("prevalence", dist.Categorical(prior_probs))
    thr_i = pyro.sample("threshold", dist.Categorical(torch.ones(len(threshold_bins))))
    is_true = bins_t[prev_i] > threshold_t[thr_i]
    # Log-weight 0 where the generic holds, -inf where it does not.
    pyro.factor("meaning", torch.where(is_true, torch.tensor(0.0), torch.tensor(float("-inf"))))
    return prev_i


marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
m = marg["prevalence"]
sup = m.enumerate_support()
probs = m.log_prob(sup).exp()
# Keep only bins with nonzero posterior mass, keyed by numeric prevalence.
ANSWER = {
    float(bins_t[int(k)]): float(pk)
    for k, pk in zip(sup.tolist(), probs)
    if float(pk) > 0
}
36
02 answer overlay — webppl vs pyro · dist/real
webppl pyro49 bins · 0.03 … 0.99
00.0260.0260.0510.0510.200.400.600.800.55 · 0.0510.55 · 0.051x = 0.03 A = 0.0110 B = 0.0110 Δ = 0.0000x = 0.05 A = 0.0029 B = 0.0029 Δ = 0.0000x = 0.07 A = 0.0006 B = 0.0006 Δ = 0.0000x = 0.09 A = 0.0002 B = 0.0002 Δ = 0.0000x = 0.11 A = 0.0002 B = 0.0002 Δ = 0.0000x = 0.13 A = 0.0005 B = 0.0005 Δ = 0.0000x = 0.15 A = 0.0009 B = 0.0009 Δ = 0.0000x = 0.17 A = 0.0016 B = 0.0016 Δ = 0.0000x = 0.19 A = 0.0026 B = 0.0026 Δ = 0.0000x = 0.21 A = 0.0038 B = 0.0038 Δ = 0.0000x = 0.23 A = 0.0055 B = 0.0055 Δ = 0.0000x = 0.25 A = 0.0075 B = 0.0075 Δ = 0.0000x = 0.27 A = 0.0099 B = 0.0099 Δ = 0.0000x = 0.29 A = 0.0128 B = 0.0128 Δ = -0.0000x = 0.31 A = 0.0159 B = 0.0159 Δ = 0.0000x = 0.33 A = 0.0194 B = 0.0194 Δ = 0.0000x = 0.35 A = 0.0231 B = 0.0231 Δ = -0.0000x = 0.37 A = 0.0269 B = 0.0269 Δ = -0.0000x = 0.39 A = 0.0308 B = 0.0308 Δ = -0.0000x = 0.41 A = 0.0347 B = 0.0347 Δ = 0.0000x = 0.43 A = 0.0384 B = 0.0384 Δ = 0.0000x = 0.45 A = 0.0418 B = 0.0418 Δ = 0.0000x = 0.47 A = 0.0449 B = 0.0449 Δ = -0.0000x = 0.49 A = 0.0474 B = 0.0474 Δ = -0.0000x = 0.51 A = 0.0494 B = 0.0494 Δ = -0.0000x = 0.53 A = 0.0507 B = 0.0507 Δ = -0.0000x = 0.55 A = 0.0513 B = 0.0513 Δ = 0.0000x = 0.57 A = 0.0512 B = 0.0512 Δ = 0.0000x = 0.59 A = 0.0503 B = 0.0503 Δ = -0.0000x = 0.61 A = 0.0487 B = 0.0487 Δ = -0.0000x = 0.63 A = 0.0464 B = 0.0464 Δ = -0.0000x = 0.65 A = 0.0434 B = 0.0434 Δ = 0.0000x = 0.67 A = 0.0400 B = 0.0400 Δ = 0.0000x = 0.69 A = 0.0361 B = 0.0361 Δ = 0.0000x = 0.71 A = 0.0319 B = 0.0319 Δ = 0.0000x = 0.73 A = 0.0275 B = 0.0275 Δ = 0.0000x = 0.75 A = 0.0232 B = 0.0232 Δ = 0.0000x = 0.77 A = 0.0189 B = 0.0189 Δ = -0.0000x = 0.79 A = 0.0150 B = 0.0150 Δ = 0.0000x = 0.81 A = 0.0114 B = 0.0114 Δ = 0.0000x = 0.83 A = 0.0082 B = 0.0082 Δ = -0.0000x = 0.85 A = 0.0056 B = 0.0056 Δ = 0.0000x = 0.87 A = 0.0036 B = 0.0036 Δ = 0.0000x = 0.89 A = 0.0020 B = 0.0020 Δ = 0.0000x = 0.91 A = 0.0010 B = 0.0010 Δ = 0.0000x = 0.93 A = 0.0004 B = 0.0004 Δ = 0.0000x = 0.95 A = 0.0001 B = 0.0001 Δ = 
0.0000x = 0.97 A = 0.0000 B = 0.0000 Δ = 0.0000x = 0.99 A = 0.0000 B = 0.0000 Δ = -0.0000
03 verification
check · status · evidence
GT self-consistency ok floor 0.0000 (w1)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-hlms-comparison-class / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/hlms-comparison-class.md
given

Height distributions are discretized using binParam = 3. The superordinate category (all people) has height distribution Gaussian(mu=0, sigma=1). The basketball player subcategory has Gaussian(mu=1, sigma=0.5). State values are 18 evenly spaced points from -3 to 3 (exclusive) in steps of 1/3. State probabilities are proportional to the Gaussian PDF evaluated at each state value. Thresholds for 'tall' are each state value minus 1/(2*binParam); thresholds for 'short' are each state value plus 1/(2*binParam); each threshold set is drawn uniformly. Comparison class is drawn uniformly from {superordinate, subordinate}. Speaker optimality alpha = 5. Three utterances: 'tall', 'short', 'silence' (silence is always true).

model

A literal listener conditions on the utterance being true (tall: state exceeds the tall threshold; short: state is below the short threshold; silence: always true) given the state distribution for the current comparison class. A speaker chooses utterances proportional to exp(alpha times the literal listener's log-probability of the state). A pragmatic listener infers the entity's state, the tall threshold, the short threshold, and the comparison class jointly by inverting the speaker, conditioning on the utterance; its prior over the state is the subordinate (basketball player) height distribution. The thresholds are drawn from the same uniform priors specified in the given section, and the same threshold pair is passed to the speaker and through to the literal listener.

query

The marginal posterior distribution over the comparison class (superordinate vs. subordinate) for a pragmatic listener who hears 'tall' about a basketball player.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "superordinate",
    "subordinate"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
///fold:
// Shorthand for Math.exp.
var exp = function(x) {
  return Math.exp(x);
};

// Number of grid points per superordinate standard deviation.
var binParam = 3;

// Superordinate category (all people): height ~ Gaussian(0, 1).
var superordinate_params = {mu: 0, sigma: 1};

// Height grid from mu - 3*sigma up to (but excluding) mu + 3*sigma, in
// steps of sigma / binParam.
var stateVals = _.range(
  superordinate_params.mu - 3 * superordinate_params.sigma,
  superordinate_params.mu + 3 * superordinate_params.sigma,
  superordinate_params.sigma / binParam
);

// Gaussian density (unnormalized over the grid) of every grid height.
var stateProbs = cache(function(stateParams) {
  var density = Gaussian(stateParams);
  return map(function(s) {
    return exp(density.score(s));
  }, stateVals);
});

// Discrete prior over the height grid for the given Gaussian parameters.
var generateStatePrior = cache(function(stateParams) {
  var weights = stateProbs(stateParams);
  return Infer({model: function() {
    return categorical({vs: stateVals, ps: weights});
  }});
});
33
// Threshold candidates: half a bin below each sorted height ("positive",
// used by "tall") and half a bin above it ("negative", used by "short").
var thresholdBins = {
  positive: map(function(s) {
    return s - 1 / (binParam * 2);
  }, sort(stateVals)),
  negative: map(function(s) {
    return s + 1 / (binParam * 2);
  }, sort(stateVals))
};

// Uniform prior over the candidates for the given form.
var thresholdPrior = cache(function(form) {
  var candidates = thresholdBins[form];
  return Infer({model: function() {
    return uniformDraw(candidates);
  }});
});
///
50
// Subordinate category height priors.
var subParams = {
  gymnasts: {mu: -1, sigma: 0.5},
  soccerPlayers: {mu: 0, sigma: 0.5},
  basketballPlayers: {mu: 1, sigma: 0.5}
};

// Positive adjective, negative adjective, or the null utterance.
var utterances = ["tall", "short", "silence"];

// Literal semantics: "tall" holds above thresholds.tall, "short" holds
// below thresholds.short, and "silence" is always true.
var meaning = function(utterance, state, thresholds) {
  if (utterance == "tall") {
    return state > thresholds.tall;
  } else if (utterance == "short") {
    return state < thresholds.short;
  } else {
    return true;
  }
};

// Uniform prior over which comparison class is intended.
var classPrior = Infer({model: function() {
  return uniformDraw(["subordinate", "superordinate"]);
}});

// Speaker optimality.
var alpha = 5;
75
// Literal listener: the comparison class's state prior, conditioned on the
// utterance being literally true under the given thresholds.
var literalListener = cache(function(utterance, thresholds, comparisonClass) {
  return Infer({model: function() {
    var state = sample(generateStatePrior(comparisonClass));
    condition(meaning(utterance, state, thresholds));
    return state;
  }});
}, 10000); // cache at most 10000 entries

// Speaker: weights each utterance by exp(alpha * log L0(state | utterance)).
var speaker1 = cache(function(state, thresholds, comparisonClass) {
  return Infer({model: function() {
    var utterance = uniformDraw(utterances);
    var listenerBelief = literalListener(utterance, thresholds, comparisonClass);
    factor(alpha * listenerBelief.score(state));
    return utterance;
  }});
}, 10000); // cache at most 10000 entries
98
// Pragmatic listener: joint posterior over the comparison class and the
// state, for an utterance about a member of the subordinate category.
var pragmaticListener = cache(function(utterance, subordinate_params) {
  return Infer({model: function() {
    var state = sample(generateStatePrior(subordinate_params));
    // Independent thresholds for the positive and the negative adjective.
    var thresholds = {
      tall: sample(thresholdPrior("positive")),
      short: sample(thresholdPrior("negative"))
    };
    // Latent comparison class the adjective is evaluated against.
    var c = sample(classPrior);
    var comparisonClass = (c == "subordinate") ? subordinate_params : superordinate_params;
    observe(speaker1(state, thresholds, comparisonClass), utterance);
    return {comparisonClass: c, state: state};
  }});
}, 10000); // cache at most 10000 entries
121
// Experimental conditions: a subordinate category paired with "tall" or
// "short"; the listener must infer the implicit comparison class.
var exptConditions = [
  {utt: "tall", sub: "basketballPlayers"},
  {utt: "short", sub: "basketballPlayers"},
  {utt: "tall", sub: "soccerPlayers"},
  {utt: "short", sub: "soccerPlayers"},
  {utt: "tall", sub: "gymnasts"},
  {utt: "short", sub: "gymnasts"}
];

// L1's predicted P(superordinate comparison class) for each condition.
var L1predictions = map(function(stim) {
  var posterior = pragmaticListener(stim.utt, subParams[stim.sub]);
  var classMarginal = marginalize(posterior, "comparisonClass");
  return {
    utterance: stim.utt,
    "P(superordinate comparison class)": exp(classMarginal.score("superordinate")),
    "subordinate category": stim.sub,
    model: "L1"
  };
}, exptConditions);

// Posterior expected height of a basketball player described by `utt`.
var basketballHeight = function(utt) {
  var posterior = pragmaticListener(utt, {mu: 1, sigma: 0.5});
  return expectation(marginalize(posterior, "state"));
};

display("the basketball player is tall");
display("--> height = " + basketballHeight("tall"));
display("the basketball player is short");
display("--> height = " + basketballHeight("short"));

display("probability of superordinate comparison class (i.e., tall for all people)");
viz.bar(L1predictions, {groupBy: "subordinate category"});

// Query: comparison-class marginal after hearing "tall" about a basketball player.
var ANSWER = marginalize(pragmaticListener("tall", subParams["basketballPlayers"]), "comparisonClass");
realization0.000
python
1
binParam = 3
sup_mu, sup_sigma = 0.0, 1.0  # superordinate (all people)
sub_mu, sub_sigma = 1.0, 0.5  # basketballPlayers
step = sup_sigma / binParam

# Height grid mirroring _.range(mu - 3*sigma, mu + 3*sigma, sigma / binParam):
# the step is accumulated and each value is rounded to 10 places.
stateVals = []
v = sup_mu - 3 * sup_sigma
while not v >= sup_mu + 3 * sup_sigma - 1e-10:
    stateVals.append(round(v, 10))
    v = v + step
nS = len(stateVals)


def gaussian_pdf(x, mu, sig):
    """Normal(mu, sig) probability density at ``x``."""
    z = (x - mu) / sig
    return math.exp(-0.5 * z ** 2) / (sig * math.sqrt(2 * math.pi))


def state_probs(mu, sig):
    """Gaussian density at each of ``stateVals``, normalized to sum to 1 (tensor)."""
    dens = [gaussian_pdf(s, mu, sig) for s in stateVals]
    total = sum(dens)
    return torch.tensor([w / total for w in dens])
22
23
# Threshold candidates: half a bin below (tall) or above (short) each sorted height.
sorted_sv = sorted(stateVals)
tall_thr = [x - 1.0 / (binParam * 2) for x in sorted_sv]
short_thr = [x + 1.0 / (binParam * 2) for x in sorted_sv]
nT = len(tall_thr)
utterances = ["tall", "short", "silence"]
alpha = 5.0  # speaker optimality
# Comparison classes, index-aligned with their Gaussian parameters and
# discretized state priors.
classes = ["subordinate", "superordinate"]
class_params = [(sub_mu, sub_sigma), (sup_mu, sup_sigma)]
class_sp = [state_probs(mu, sig) for mu, sig in class_params]


def meaning(utt, state, tall_t, short_t):
    """Literal truth: tall above ``tall_t``, short below ``short_t``, anything else true."""
    if utt == "tall":
        return state > tall_t
    elif utt == "short":
        return state < short_t
    else:
        return True
41
42
# Literal listener: state posterior given utterance, thresholds, comparison class.
def literal_listener(ui, ti, shi, ci):
    """Return P_L0(state) as a list aligned with ``stateVals`` for utterance
    index ``ui``, threshold indices (``ti``, ``shi``) and class index ``ci``."""
    utt = utterances[ui]
    tall_t, short_t = tall_thr[ti], short_thr[shi]
    log_mask = torch.tensor([
        0.0 if meaning(utt, stateVals[s], tall_t, short_t) else -1e30
        for s in range(nS)
    ])
    prior = class_sp[ci]

    @pyro.infer.config_enumerate
    def model():
        st = pyro.sample("st", dist.Categorical(prior))
        pyro.factor("cond", log_mask[st])

    elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
    posterior = elbo.compute_marginals(model, lambda: None)["st"]
    support = posterior.enumerate_support()
    probs = posterior.log_prob(support).exp()
    result = [0.0] * nS
    for k, pk in zip(support.tolist(), probs.tolist()):
        result[int(k)] = pk
    return result
63
64
65def pragmatic_listener(ui):  # posterior over comparison class ({class name: prob}) after hearing utterances[ui]
66 sub_sp = state_probs(sub_mu, sub_sigma)  # L1's state prior is always the subordinate (basketball-player) prior
67 # L0 table, cached per (class, thresholds): each level genuinely inferred.
68 ll = {}
69 for ci in range(2):
70 for ti in range(nT):
71 for shi in range(nT):
72 ll[(ci, ti, shi)] = [literal_listener(u, ti, shi, ci) for u in range(3)]  # NOTE(review): "tall" ignores shi and "short" ignores ti, so most of these inferences repeat earlier ones
73
74 # Speaker (S1): utterance posterior given (state, thresholds, class), inferred.
75 score = torch.full((2, nS, nT, nT), -1e30)  # log S1(utterances[ui] | ...), indexed [class, state, tall thr, short thr]
76 for ci in range(2):
77 for ti in range(nT):
78 for shi in range(nT):
79 l0 = ll[(ci, ti, shi)]
80 for si in range(nS):
81 sc = torch.tensor([alpha * (math.log(l0[u][si]) if l0[u][si] > 0 else -1e30)
82 for u in range(3)])  # alpha * log L0(state | u); there is no utterance-cost term
83
84 @pyro.infer.config_enumerate
85 def smodel():
86 u = pyro.sample("u", dist.Categorical(torch.ones(3)))  # uniform utterance prior
87 pyro.factor("sc", sc[u])
88 return None
89
90 sm = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(smodel, lambda: None)
91 mm = sm["u"]
92 ssup = mm.enumerate_support()
93 spr = mm.log_prob(ssup).exp()
94 for i, p in zip(ssup.tolist(), spr.tolist()):
95 if int(i) == ui:  # keep only the probability of the heard utterance
96 score[ci, si, ti, shi] = math.log(p) if p > 0 else -1e30
97
98 # Pragmatic listener (L1): comparison-class marginal, inferred by enumeration.
99 @pyro.infer.config_enumerate
100 def model():
101 st = pyro.sample("st", dist.Categorical(sub_sp))
102 tt = pyro.sample("tt", dist.Categorical(torch.ones(nT)))  # uniform 'tall' threshold index
103 sh = pyro.sample("sh", dist.Categorical(torch.ones(nT)))  # uniform 'short' threshold index
104 cl = pyro.sample("cl", dist.Categorical(torch.ones(2)))  # uniform comparison-class prior
105 pyro.factor("obs", score[cl, st, tt, sh])  # observe the speaker producing utterances[ui]
106 return None
107
108 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
109 m = marg["cl"]
110 sup = m.enumerate_support()
111 pr = m.log_prob(sup).exp()
112 return {classes[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())}
113
114
115ANSWER = pragmatic_listener(0) # "tall"
116
02answer overlay — webppl vs pyrodist/finite
webppl pyro2 bins
00.290.290.580.58subordinate A = 0.419 B = 0.419subordinate A = 0.419 B = 0.4190.420.42subordinatesuperordinate A = 0.581 B = 0.581superordinate A = 0.581 B = 0.5810.580.58superordinate
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 1/2 solvers · d=[0.000, 0.115] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-jmr-irony-extension / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/jmr-irony-extension.md
given

Weather states: ["terrible", "bad", "ok", "good", "amazing"]. State prior: categorical with unnormalized weights [1, 5, 40, 40, 40] over those states (California context — benign weather is far more likely). Utterances are the same five state labels, drawn uniformly. Valence prior given state: P(valence = -1 | terrible) = 0.99, P(valence = -1 | bad) = 0.90, P(valence = -1 | ok) = 0.50, P(valence = -1 | good) = 0.09, P(valence = -1 | amazing) = 0.01; valence is +1 otherwise. Arousal values: [0.1, 0.3, 0.5, 0.7, 0.9]. Arousal prior given state (categorical with unnormalized weights over those values): terrible → [1,10,30,45,50], bad → [1,5,25,40,45], ok → [50,45,30,10,1], good → [1,5,25,40,45], amazing → [1,10,30,45,50]. Goals: ["goalState", "goalValence", "goalArousal"], drawn with equal probability. A speaker's goal is satisfied by the listener inferring the state (for goalState), the valence (for goalValence), or the arousal (for goalArousal). The literal interpretation of an utterance is true iff the utterance label equals the world state. Speaker rationality weight: 1.

model

A pragmatic listener hears an utterance and jointly infers the world state, the speaker's valence, and the speaker's arousal. The listener's prior over (state, valence, arousal, goal) is the state prior, times the state-conditional valence and arousal priors, times an independent uniform goal prior. A literal listener, given an utterance and a goal, draws the state uniformly from the five states (not from the California state prior). It then samples valence and arousal from their state-conditional priors, conditions on the utterance being literally true of the state, and returns the goal-relevant quantity (state, valence, or arousal). A speaker, given (state, valence, arousal, goal), draws an utterance uniformly a priori and chooses it with probability proportional to the literal listener's probability of the goal-relevant quantity (rationality weight 1). The pragmatic listener observes the speaker's chosen utterance and marginalizes over goals, returning the joint (state, valence, arousal).

query

The posterior joint distribution over (state, valence, arousal) for a pragmatic listener who hears the utterance "terrible".

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "state": "string",
      "valence": "int",
      "arousal": "real"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// There are five possible states the weather could be in:
2// terrible, bad, ok, good, or amazing
3var states = ['terrible','bad','ok','good','amazing']
4
5// Since we are in California, the prior over these states
6// is the following. One could also imagine this being
7// the prior in a certain context, e.g. when it's clearly
8// sunny and nice out.
9var statePrior = function() {
10 categorical([1,5,40,40,40], states) // unnormalized weights, aligned with `states`
11}
12
13// Valence prior defined in terms of negative valence: returns -1 with the listed probability, otherwise +1.
14// If the current state is terrible, it's extremely likely
15// that the valence associated is negative. If it's ok, then
16// the valence could be negative or positive with equal
17// probability.
18var valencePrior = function(state) {
19 state === "terrible" ? flip(0.99) ? -1 : 1 :
20 state === "bad" ? flip(0.90) ? -1 : 1 :
21 state === "ok" ? flip(0.5) ? -1 : 1 :
22 state === "good" ? flip(0.09) ? -1 : 1 :
23 state === "amazing" ? flip(0.01) ? -1 : 1 :
24 true // fallback for a state not listed above
25}
26
27// Define five discrete arousal levels (intensity of feeling); could model as continuous.
28// var arousals = ["low", "high"] (binary alternative, commented out)
29var arousals = [.1,.3,.5,.7,.9]
30
31// Define goals and goal priors. Could want to communicate state of the world,
32// valence about it, or arousal (intensity of feeling) about it.
33var goals = ["goalState", "goalValence", "goalArousal"]
34
35var goalPrior = function() {
36 categorical([1, 1, 1], goals) // uniform over the three goals
37}
38
39// Assume possible utterances are identical to possible states
40var utterances = states
41
42// Assume cost of utterances is uniform (i.e. a uniform utterance prior).
43var utterancePrior = function() {
44 uniformDraw(utterances)
45}
46
47// Sample arousal given a state (unnormalized weights over `arousals`).
48var arousalPrior = function(state) {
49 state === "terrible" ? categorical([1,10,30,45,50], arousals) :
50 state === "bad" ? categorical([1,5,25,40,45], arousals) :
51 state === "ok" ? categorical([50,45,30,10,1], arousals) :
52 state === "good" ? categorical([1,5,25,40,45], arousals) :
53 state === "amazing" ? categorical([1,10,30,45,50], arousals) :
54 true // fallback for a state not listed above
55}
56
57// Literal interpretation is just whether utterance equals state
58var literalInterpretation = function(utterance, state) {
59 utterance === state
60}
61
62// A speaker's goal is satisfied if the listener infers the correct
63// and relevant information. goalState projects (state, valence, arousal) onto the quantity the goal cares about.
64var goalState = function(goal, state, valence, arousal) {
65 goal === "goalState" ? state :
66 goal === "goalValence" ? valence :
67 goal === "goalArousal" ? arousal :
68 true // fallback for a goal not listed above
69}
70
71// Define a literal listener. Note: the state is drawn uniformly here, not from statePrior.
72var literalListener = function(utterance, goal) {
73 Infer({model: function(){
74 var state = uniformDraw(states)
75 var valence = valencePrior(state)
76 var arousal = arousalPrior(state)
77 condition(literalInterpretation(utterance,state))
78 return goalState(goal, state, valence, arousal)
79 }})
80}
81
82//The speaker takes in a state, valence, arousal, and a goal and returns an utterance
83//based on the probability of the literalListener arriving at the correct
84//goal-relevant value (state, valence, or arousal) given the goal
85var speaker = function(state, valence, arousal, goal) {
86 Infer({model: function(){
87 var utterance = utterancePrior()
88 factor(1 * literalListener(utterance, // 1 = speaker rationality; .score() is a log-probability
89 goal).score(goalState(goal,
90 state,
91 valence,
92 arousal)))
93 return utterance
94 }})
95}
96
97// Define a pragmatic listener: prior = statePrior x state-conditional valence/arousal x uniform goal
98var pragmaticListener = function(utterance) {
99 Infer({model: function(){
100 var state = statePrior()
101 var valence = valencePrior(state)
102 var arousal = arousalPrior(state)
103 var goal = goalPrior()
104 observe(speaker(state, valence, arousal, goal),utterance)
105 return {state,valence, arousal} // goal is marginalized out
106 }})
107}
108
109viz.table(pragmaticListener('terrible'))
110
111var ANSWER = (pragmaticListener("terrible")); // NOTE(review): pragmaticListener is not cached, so this repeats the inference run for viz.table above
realization0.000
python
1  # Irony/affect RSA (jmr-irony-extension): constants, goal projection, and literal listener.
2NEG = -1e30  # stands in for log(0) in factors and score tables
3states = ["terrible", "bad", "ok", "good", "amazing"]
4state_prior = torch.tensor([1.0, 5.0, 40.0, 40.0, 40.0])  # unnormalized; the literal listener below uses a uniform state instead
5val_neg_p = {"terrible": 0.99, "bad": 0.90, "ok": 0.5, "good": 0.09, "amazing": 0.01}  # P(valence = -1 | state)
6arousals = [.1, .3, .5, .7, .9]
7arousal_w = {"terrible": [1, 10, 30, 45, 50], "bad": [1, 5, 25, 40, 45],  # unnormalized arousal weights per state
8 "ok": [50, 45, 30, 10, 1], "good": [1, 5, 25, 40, 45],
9 "amazing": [1, 10, 30, 45, 50]}
10goals = ["goalState", "goalValence", "goalArousal"]
11goal_prior = torch.tensor([1.0, 1.0, 1.0])  # uniform over goals
12utterances = states[:]  # utterances are the state labels, in the same order
13valences = [-1, 1]  # valence index 0 = negative, 1 = positive (matches the column order of vp in literal_listener)
14
15
16def goal_value(goal, st_idx, val, ar):  # tagged goal-relevant value, keyed the same way as literal_listener's output
17 if goal == "goalState":
18 return ("state", states[st_idx])
19 if goal == "goalValence":
20 return ("val", val)
21 return ("ar", ar)  # any other goal is treated as goalArousal
22
23
24# Literal listener: distribution over the goal-relevant quantity, inferred.
25def literal_listener(utt_idx, goal):  # returns {tagged goal-relevant value: probability}
26 @pyro.infer.config_enumerate
27 def model():
28 st = pyro.sample("st", dist.Categorical(torch.ones(len(states))))  # uniform state (not state_prior)
29 vp = torch.stack([torch.tensor([val_neg_p[s], 1 - val_neg_p[s]]) for s in states])  # [5, 2]: P(neg), P(pos) per state
30 val = pyro.sample("val", dist.Categorical(vp[st]))
31 ap = torch.stack([torch.tensor([float(x) for x in arousal_w[s]]) / sum(arousal_w[s])
32 for s in states])  # [5, 5]: normalized arousal probabilities per state
33 ar = pyro.sample("ar", dist.Categorical(ap[st]))
34 # literalInterpretation: utterance === state
35 pyro.factor("cond", torch.where(st == utt_idx, torch.tensor(0.0), torch.tensor(NEG)))
36 return None
37
38 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
39 if goal == "goalState":
40 m = marg["st"]
41 sup = m.enumerate_support()
42 pr = m.log_prob(sup).exp()
43 return {("state", states[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}
44 if goal == "goalValence":
45 m = marg["val"]
46 sup = m.enumerate_support()
47 pr = m.log_prob(sup).exp()
48 return {("val", valences[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}
49 m = marg["ar"]
50 sup = m.enumerate_support()
51 pr = m.log_prob(sup).exp()
52 return {("ar", arousals[int(i)]): p for i, p in zip(sup.tolist(), pr.tolist())}
53
54
55spk_cache = {}  # (goal value, goal) -> speaker distribution; the speaker reads the state only through the goal-relevant value
56
57
58# Speaker: utterance posterior given (state, valence, arousal, goal), inferred.
59def speaker_logprobs(st_idx, val, ar, goal):  # NOTE(review): despite the name, returns probabilities {utterance index: prob}, not log-probs
60 gv = goal_value(goal, st_idx, val, ar)
61 key = (gv, goal)
62 if key in spk_cache:
63 return spk_cache[key]
64 l0 = {u: literal_listener(u, goal) for u in range(len(utterances))}
65
66 @pyro.infer.config_enumerate
67 def model():
68 u = pyro.sample("u", dist.Categorical(torch.ones(len(utterances))))  # uniform utterance prior
69 sc = torch.tensor([math.log(l0[uu].get(gv, 0.0)) if l0[uu].get(gv, 0.0) > 0 else NEG
70 for uu in range(len(utterances))])  # log L0(gv | utterance); NEG when gv is impossible under that utterance
71 pyro.factor("sc", 1.0 * sc[u])  # rationality weight 1.0
72 return None
73
74 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
75 m = marg["u"]
76 sup = m.enumerate_support()
77 pr = m.log_prob(sup).exp()
78 res = {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}
79 spk_cache[key] = res
80 return res
81
82
83# Pragmatic listener: joint posterior over (state, valence, arousal) via exact
84# enumeration of a single combined latent; the engine marginalizes.
85def pragmatic_listener(utt_idx):  # returns {JSON record with keys arousal/state/valence: probability}
86 combos = []
87 for s in range(len(states)):
88 for vi in range(2):
89 for ai in range(len(arousals)):
90 combos.append((s, vi, ai))
91 sp_norm = (state_prior / state_prior.sum()).tolist()
92 jp = []  # prior of each combo: P(state) * P(valence | state) * P(arousal | state)
93 for (s, vi, ai) in combos:
94 st = states[s]
95 pv = val_neg_p[st] if vi == 0 else (1 - val_neg_p[st])
96 aw = arousal_w[st]
97 pa = aw[ai] / sum(aw)
98 jp.append(sp_norm[s] * pv * pa)
99 jp = torch.tensor(jp)
100 score = torch.full((len(combos), 3), NEG)  # log speaker probability of utterances[utt_idx], indexed [combo, goal]
101 for j, (s, vi, ai) in enumerate(combos):
102 val = valences[vi]
103 ar = arousals[ai]
104 for g in range(3):
105 spk = speaker_logprobs(s, val, ar, goals[g])
106 p = spk.get(utt_idx, 0.0)
107 score[j, g] = math.log(p) if p > 0 else NEG
108
109 @pyro.infer.config_enumerate
110 def model():
111 cf = pyro.sample("cf", dist.Categorical(jp))
112 goal = pyro.sample("goal", dist.Categorical(goal_prior))
113 pyro.factor("obs", score[cf, goal])  # observe the speaker producing utterances[utt_idx]
114 return None
115
116 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
117 m = marg["cf"]  # goal is marginalized out by the enumeration engine
118 sup = m.enumerate_support()
119 pr = m.log_prob(sup).exp()
120 out = {}
121 for j, p in zip(sup.tolist(), pr.tolist()):
122 s, vi, ai = combos[int(j)]
123 rec = {"state": states[s], "valence": valences[vi], "arousal": arousals[ai]}
124 out[json.dumps(rec, sort_keys=True)] = p  # sorted keys give one canonical string per record
125 return out
126
127
128ANSWER = pragmatic_listener(0) # "terrible"
129
02answer overlay — webppl vs pyrodist/finite
webppl pyro48 bins · top 48 of 50
00.0490.0490.0980.098{"arousal":0.9,"state":"amazing","valence":1} A = 0.098 B = 0.098{"arousal":0.9,"state":"amazing","valence":1} A = 0.098 B = 0.098{"arousal":0.9,"state":"amazing","valence":1}{"arousal":0.3,"state":"ok","valence":-1} A = 0.095 B = 0.095{"arousal":0.3,"state":"ok","valence":-1} A = 0.095 B = 0.095{"arousal":0.9,"state":"good","valence":1} A = 0.095 B = 0.095{"arousal":0.9,"state":"good","valence":1} A = 0.095 B = 0.095{"arousal":0.7,"state":"amazing","valence":1} A = 0.084 B = 0.084{"arousal":0.7,"state":"amazing","valence":1} A = 0.084 B = 0.084{"arousal":0.1,"state":"ok","valence":-1} A = 0.083 B = 0.083{"arousal":0.1,"state":"ok","valence":-1} A = 0.083 B = 0.083{"arousal":0.1,"state":"ok","valence":-1}{"arousal":0.7,"state":"good","valence":1} A = 0.081 B = 0.081{"arousal":0.7,"state":"good","valence":1} A = 0.081 B = 0.081{"arousal":0.5,"state":"ok","valence":-1} A = 0.072 B = 0.072{"arousal":0.5,"state":"ok","valence":-1} A = 0.072 B = 0.072{"arousal":0.5,"state":"amazing","valence":1} A = 0.049 B = 0.049{"arousal":0.5,"state":"amazing","valence":1} A = 0.049 B = 0.049{"arousal":0.5,"state":"good","valence":1} A = 0.044 B = 0.044{"arousal":0.5,"state":"good","valence":1} A = 0.044 B = 0.044{"arousal":0.5,"state":"good","valence":1}{"arousal":0.9,"state":"bad","valence":-1} A = 0.030 B = 0.030{"arousal":0.9,"state":"bad","valence":-1} A = 0.030 B = 0.030{"arousal":0.7,"state":"bad","valence":-1} A = 0.027 B = 0.027{"arousal":0.7,"state":"bad","valence":-1} A = 0.027 B = 0.027{"arousal":0.7,"state":"ok","valence":-1} A = 0.025 B = 0.025{"arousal":0.7,"state":"ok","valence":-1} A = 0.025 B = 0.025{"arousal":0.5,"state":"ok","valence":1} A = 0.025 B = 0.025{"arousal":0.5,"state":"ok","valence":1} A = 0.025 B = 0.025{"arousal":0.5,"state":"ok","valence":1}{"arousal":0.9,"state":"good","valence":-1} A = 0.024 B = 0.024{"arousal":0.9,"state":"good","valence":-1} A = 0.024 B = 0.024{"arousal":0.3,"state":"ok","valence":1} A = 0.024 B = 
0.024{"arousal":0.3,"state":"ok","valence":1} A = 0.024 B = 0.024{"arousal":0.7,"state":"good","valence":-1} A = 0.021 B = 0.021{"arousal":0.7,"state":"good","valence":-1} A = 0.021 B = 0.021{"arousal":0.9,"state":"terrible","valence":-1} A = 0.016 B = 0.016{"arousal":0.9,"state":"terrible","valence":-1} A = 0.016 B = 0.016{"arousal":0.9,"state":"terrible","valence":-1}{"arousal":0.5,"state":"bad","valence":-1} A = 0.016 B = 0.016{"arousal":0.5,"state":"bad","valence":-1} A = 0.016 B = 0.016{"arousal":0.7,"state":"terrible","valence":-1} A = 0.015 B = 0.015{"arousal":0.7,"state":"terrible","valence":-1} A = 0.015 B = 0.015{"arousal":0.5,"state":"good","valence":-1} A = 0.013 B = 0.013{"arousal":0.5,"state":"good","valence":-1} A = 0.013 B = 0.013{"arousal":0.3,"state":"amazing","valence":1} A = 0.011 B = 0.011{"arousal":0.3,"state":"amazing","valence":1} A = 0.011 B = 0.011{"arousal":0.3,"state":"amazing","valence":1}{"arousal":0.5,"state":"terrible","valence":-1} A = 0.010 B = 0.010{"arousal":0.5,"state":"terrible","valence":-1} A = 0.010 B = 0.010{"arousal":0.7,"state":"ok","valence":1} A = 0.009 B = 0.009{"arousal":0.7,"state":"ok","valence":1} A = 0.009 B = 0.009{"arousal":0.3,"state":"good","valence":1} A = 0.006 B = 0.006{"arousal":0.3,"state":"good","valence":1} A = 0.006 B = 0.006{"arousal":0.1,"state":"ok","valence":1} A = 0.004 B = 0.004{"arousal":0.1,"state":"ok","valence":1} A = 0.004 B = 0.004{"arousal":0.1,"state":"ok","valence":1}{"arousal":0.3,"state":"terrible","valence":-1} A = 0.003 B = 0.003{"arousal":0.3,"state":"terrible","valence":-1} A = 0.003 B = 0.003{"arousal":0.3,"state":"bad","valence":-1} A = 0.003 B = 0.003{"arousal":0.3,"state":"bad","valence":-1} A = 0.003 B = 0.003{"arousal":0.9,"state":"amazing","valence":-1} A = 0.003 B = 0.003{"arousal":0.9,"state":"amazing","valence":-1} A = 0.003 B = 0.003{"arousal":0.9,"state":"ok","valence":-1} A = 0.003 B = 0.003{"arousal":0.9,"state":"ok","valence":-1} A = 0.003 B = 
0.003{"arousal":0.9,"state":"ok","valence":-1}{"arousal":0.7,"state":"amazing","valence":-1} A = 0.002 B = 0.002{"arousal":0.7,"state":"amazing","valence":-1} A = 0.002 B = 0.002{"arousal":0.3,"state":"good","valence":-1} A = 0.002 B = 0.002{"arousal":0.3,"state":"good","valence":-1} A = 0.002 B = 0.002{"arousal":0.5,"state":"amazing","valence":-1} A = 0.001 B = 0.001{"arousal":0.5,"state":"amazing","valence":-1} A = 0.001 B = 0.001{"arousal":0.9,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.9,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.9,"state":"bad","valence":1}{"arousal":0.7,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.7,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.9,"state":"ok","valence":1} A = 0.001 B = 0.001{"arousal":0.9,"state":"ok","valence":1} A = 0.001 B = 0.001{"arousal":0.5,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.5,"state":"bad","valence":1} A = 0.001 B = 0.001{"arousal":0.1,"state":"bad","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"bad","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"bad","valence":-1}{"arousal":0.3,"state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":0.3,"state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"good","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"good","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"terrible","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"terrible","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"good","valence":1} A = 0.000 B = 0.000{"arousal":0.1,"state":"good","valence":1} A = 0.000 B = 0.000{"arousal":0.1,"state":"good","valence":1}{"arousal":0.1,"state":"amazing","valence":1} A = 0.000 B = 0.000{"arousal":0.1,"state":"amazing","valence":1} A = 0.000 B = 0.000{"arousal":0.9,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.9,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.7,"state":"terrible","valence":1} A = 
0.000 B = 0.000{"arousal":0.7,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.3,"state":"bad","valence":1} A = 0.000 B = 0.000{"arousal":0.3,"state":"bad","valence":1} A = 0.000 B = 0.000{"arousal":0.3,"state":"bad","valence":1}{"arousal":0.5,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.5,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.1,"state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":0.1,"state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":0.3,"state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":0.3,"state":"terrible","valence":1} A = 0.000 B = 0.000
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-kachakeche-comparison-class / atom-1
answer dist/real solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/kachakeche-comparison-class.md
given

Discretization: 30 box-weight values evenly spaced from 0.0 to 5.8 in steps of 0.2 (i.e., 0.0, 0.2, 0.4, …, 5.8). The state distribution for a given comparison class is obtained by evaluating the Gaussian density at each of those 30 values and treating the results as categorical weights (renormalized over the grid). Speaker-class Gaussian parameters: child (mu=0.5, sigma=1), adult (mu=2, sigma=3), bodybuilder (mu=5, sigma=3). The superordinate category has mu=3, sigma=1. Utterances: ["heavy", "light"]. Threshold sets: the "heavy" threshold set is obtained by subtracting 0.1 from each of the 30 state values (giving −0.1, 0.1, …, 5.7); the "light" threshold set by adding 0.1 to each (giving 0.1, 0.3, …, 5.9). Meaning: "heavy" is true with probability 0.9999 when state > the heavy threshold and 0.0001 otherwise; "light" is true with probability 0.9999 when state < the light threshold and 0.0001 otherwise. Speaker optimality alpha = 5. Comparison-class prior given the identity of the speaker (unnormalized weights over [child, adult, bodybuilder]): child → categorical([0.75, 0.25, 0.15]); adult → categorical([0.01, 0.70, 0.50]); bodybuilder → categorical([0.0001, 0.20, 0.99]).

model

A pragmatic listener hears an adjective uttered by a speaker of known identity and infers the box's weight. The listener jointly samples a comparison class from the prior conditioned on the speaker's identity, then samples a weight from that class's Gaussian-discretized distribution. It also independently samples a "heavy" threshold uniformly from the heavy threshold set and a "light" threshold uniformly from the light threshold set. The same two thresholds and the sampled comparison class are passed to the speaker and through to the literal listener. A literal listener samples a weight from the comparison class's Gaussian-discretized distribution and conditions on the soft meaning function. A speaker with a known weight, both thresholds, and comparison class draws an utterance uniformly a priori. It chooses that utterance with probability proportional to exp(alpha × literal-listener log-probability of the true weight). The pragmatic listener observes the speaker's utterance and returns the marginal distribution over box weight.

query

The posterior distribution over box weight (one of the 30 values from 0.0 to 5.8) for a pragmatic listener who hears "heavy" uttered by a child.

answer spec dist/real
{
  "kind": "dist",
  "domain": "real"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// helper function: exp shorthand used by stateProbs
2var exp = function(x){return Math.exp(x)}
3
4// helper function: marginal distribution of one field of a distribution over objects (not called in this listing)
5var marginalize = function(dist, key){
6 return Infer({model: function(){sample(dist)[key]}})
7}
8// for discretization: number of grid points per superordinate standard deviation
9var binParam = 5;
10
11// superordinate category; its mean and sd fix the range of the state grid
12var superordinate = {mu: 3, sigma: 1};
13
14//a list of possible box weights (state values): 30 values from 0 to 5.8 in steps of 0.2 (_.range excludes the upper end)
15var stateVals = _.range(superordinate.mu - 3 * superordinate.sigma,
16 superordinate.mu + 3 * superordinate.sigma,
17 superordinate.sigma/binParam)
18
19
20// for each possible weight, its Gaussian density (a weight, not a probability; categorical renormalizes)
21var stateProbs = function(stateParams){
22 return map(function(s){
23 exp(Gaussian(stateParams).score(s))
24 }, stateVals)
25};
26
27
28// generate a statePrior using the possible weights and their probabilities (cached per stateParams)
29var generateStatePrior = cache(function(stateParams) {
30 return Infer({
31 model: function(){
32 return categorical({vs: stateVals, ps: stateProbs(stateParams)})
33 }
34 })
35});
36
37// information about the category priors: Gaussian {mu, sigma} over box weight per comparison class
38var speakerParams = {
39 child: {mu: 0.5, sigma: 1}, // child experience with weights
40 adult: {mu: 2, sigma: 3}, // adult experience with weights
41 bodybuilder: {mu:5, sigma:3}, // body builder experience with weights
42}
43
44
45// generate the uniform threshold prior
46// thresholds sit 1/(2*binParam) below ("heavy") or above (any other utterance) each sorted state value
47var thresholdBins = cache(function(utterance, stateSupport){
48 return map(function(x){
49 return utterance == "heavy" ? x - (1/(binParam*2)) : x + (1/(binParam*2));
50 }, sort(stateSupport))
51},10000
52)
53
54var thresholdPrior = cache(function(utterance, stateSupport){
55 return Infer({
56 model: function() { return uniformDraw(thresholdBins(utterance, stateSupport)) }
57 });
58},10000
59);
60
61
62
63var utterances = ["heavy","light"]
64
65
66// meaning function for utterances: soft semantics, true with prob 0.9999 when the threshold test passes and 0.0001 otherwise
67var meaning = function(utterance, state, thresholdHeavy, thresholdLight) {
68 utterance == "heavy" ? state > thresholdHeavy ? flip(0.9999) : flip(0.0001) :
69 utterance == "light" ? state < thresholdLight ? flip(0.9999) : flip(0.0001) :
70 true
71}
72
73
74
75
76// set speaker optimality
77var alpha = 5;
78
79
80var literalListener =cache(
81 function(utterance, thresholdHeavy, thresholdLight,comparisonClass){
82 Infer({model: function(){
83 var state = sample(generateStatePrior(speakerParams[comparisonClass])) // weight prior of the given comparison class
84 var m = meaning(utterance, state, thresholdHeavy, thresholdLight);
85 condition(m);
86 return state;
87 }})
88 },10000
89)
90
91
92//literalListener("light", 4, 2, "adult")
93
94
95//speaker: uniform utterance prior, weighted by exp(alpha * log L0(state | utterance, thresholds, class))
96var speaker1 = cache(
97 function(state, thresholdHeavy, thresholdLight,comparisonClass){
98 Infer({model: function(){
99 var utterance = uniformDraw(utterances);
100 var L0 = literalListener(utterance, thresholdHeavy, thresholdLight, comparisonClass);
101 factor( alpha * L0.score(state) );
102 return utterance;
103 }})
104 },10000
105)
106
107
108// generateStatePrior(speakerParams["child"]) .support()
109
110var comparisonClasses = ["child","adult","bodybuilder"]
111var comparisonClassPrior = function(whoSaidIt) { // unnormalized weights; any speaker other than "child"/"adult" gets the last row
112 whoSaidIt == "child" ? categorical([0.75, 0.25,0.15],comparisonClasses):
113 whoSaidIt == "adult" ? categorical([0.01,0.7,0.5],comparisonClasses):
114 categorical([0.0001,0.2, 0.99],comparisonClasses)
115
116
117}
118
119var pragmaticListener = cache(function(utterance,whoSaidIt){
120 Infer({model: function(){
121 var CC = comparisonClassPrior(whoSaidIt);
122 var statePrior = generateStatePrior(speakerParams[CC]);
123 var state = sample(statePrior);
124 var thresholdHeavy = sample(thresholdPrior("heavy", statePrior.support())) // thresholds derived from the sampled class's support
125 var thresholdLight = sample(thresholdPrior("light", statePrior.support()))
126 var S1 = speaker1(state, thresholdHeavy,thresholdLight,CC);
127 observe(S1, utterance);
128 return (state)
129}, method:"enumerate"})
130},10000
131)
132
133
134
135// pragmaticListener("heavy","adult")
136display("Listener's interpretation after hearing a child saying 'this box is heavy'")
137viz.density(pragmaticListener("heavy","child"))
138display("Listener's interpretation after hearing an adult saying 'this box is heavy'")
139viz.density(pragmaticListener("heavy","adult"))
140display("Listener's interpretation after hearing a bodybuilder saying 'this box is heavy'")
141viz.density(pragmaticListener("heavy","bodybuilder"))
142display("Listener's interpretation after hearing a child saying 'this box is light'")
143viz.density(pragmaticListener("light","child"))
144display("Listener's interpretation after hearing an adult saying 'this box is light'")
145viz.density(pragmaticListener("light","adult"))
146display("Listener's interpretation after hearing a bodybuilder saying 'this box is light'")
147viz.density(pragmaticListener("light","bodybuilder"))
148
149var ANSWER = (pragmaticListener("heavy", "child")); // pragmaticListener is cached, so this reuses the result computed for the first viz.density call
realization0.000
python
1# RSA comparison-class model (kachakeche), pragmatic listener hearing a CHILD say
2# 'heavy'. Every RSA level is run through Pyro's exact discrete enumeration
3# (config_enumerate + TraceEnum_ELBO.compute_marginals). The literal listener and
4# the speaker are inferred per configuration but vectorized across all configs with a
5# pyro.plate, so each level is a genuine Pyro inference whose normalized posterior
6# comes from Pyro's engine; lower-level log-scores enter the next level only as
7# pyro.factor terms. No level's posterior is normalized by hand (only the priors are, e.g. state_prior_probs).
8
9binParam = 5
10super_mu, super_sigma = 3.0, 1.0
11# state values: range(mu-3sigma, mu+3sigma, sigma/binParam) = range(0, 6, 0.2) -> 30 values
12NS = 30  # hard-coded; must be changed together with binParam / super_mu / super_sigma
13stateVals = [round(super_mu - 3 * super_sigma + i * (super_sigma / binParam), 10)
14 for i in range(NS)]
15stateVals_t = torch.tensor(stateVals)
16
17comparisonClasses = ["child", "adult", "bodybuilder"]
18speakerParams = {"child": (0.5, 1.0), "adult": (2.0, 3.0), "bodybuilder": (5.0, 3.0)}  # (mu, sigma) per class
19NC = len(comparisonClasses)
20alpha = 5.0  # speaker optimality
21half = 1.0 / (binParam * 2) # 0.1
22
23# threshold sets (stateVals are already sorted ascending):
24# heavy = state - 0.1 ; light = state + 0.1
25heavy_thr = stateVals_t - half # [NS]
26light_thr = stateVals_t + half # [NS]
27
28LOG_HI = math.log(0.9999)  # log-factor when the utterance's threshold test holds
29LOG_LO = math.log(0.0001)  # log-factor when it fails (soft meaning)
30
31
32def state_prior_probs(mu, sigma):  # discretized Normal(mu, sigma) over stateVals, normalized to sum to 1 -> [NS]
33 w = dist.Normal(mu, sigma).log_prob(stateVals_t).exp()
34 return w / w.sum()
35
36
37# Per-class state prior table [NC, NS].
38state_prior_table = torch.stack(
39 [state_prior_probs(*speakerParams[cc]) for cc in comparisonClasses]
40)
41
42
43# ---- Literal listener L0, run through Pyro exact enumeration. For each (comparison
44# class, threshold) the model enumerates the state under that class's prior and
45# conditions on the soft meaning (flip(0.9999) if the utterance is literally true of
46# the state else flip(0.0001)); compute_marginals returns the exact state posterior.
47# All NC*NS configs are inferred together via a single plated enumeration. ----
48def run_L0(thr_vec, is_heavy):
49 K = NC * NS # configs over (cc, threshold)
50 prior = state_prior_table.unsqueeze(1).expand(NC, NS, NS).reshape(K, NS) # [K, NS]
51 thr_b = thr_vec.unsqueeze(0).expand(NC, NS).reshape(K) # [K]
52 state_b = stateVals_t.unsqueeze(0) # [1, NS]
53 if is_heavy:
54 holds = state_b > thr_b.unsqueeze(1) # heavy: state > thresholdHeavy
55 else:
56 holds = state_b < thr_b.unsqueeze(1) # light: state < thresholdLight
57 fac = torch.where(holds, torch.tensor(LOG_HI), torch.tensor(LOG_LO)) # [K, NS]
58
59 @pyro.infer.config_enumerate
60 def model():
61 with pyro.plate("cfg", K):
62 s = pyro.sample("state", dist.Categorical(prior))
63 sel = pyro.ops.indexing.Vindex(fac)[torch.arange(K), s]
64 pyro.factor("meaning", sel)
65
66 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model, lambda: None)
67 m = marg["state"]
68 sup = m.enumerate_support() # [NS, K] state indices per config
69 lp = m.log_prob(sup) # [NS, K]
70 out = torch.full((K, NS), float("-inf"))
71 cfg_idx = torch.arange(K).unsqueeze(0).expand_as(sup)
72 out[cfg_idx, sup] = lp
73 return out.reshape(NC, NS, NS) # [cc, thr, state]
74
75
76logL0H = run_L0(heavy_thr, True) # [cc, thH, state] log L0(state | 'heavy', thH, cc)
77logL0L = run_L0(light_thr, False) # [cc, thL, state] log L0(state | 'light', thL, cc)
78
79
80# ---- Speaker S1, run through Pyro exact enumeration over the two utterances. For a
81# fixed (cc, state, thH, thL) the speaker chooses utterance ~ exp(alpha * log L0(state
82# | utterance, thresholds, cc)); 'heavy' scores against L0H[cc, thH], 'light' against
83# L0L[cc, thL]. compute_marginals normalizes the two-utterance softmax. All
84# NC*NS*NS*NS configs are inferred together via one plated enumeration. ----
85scoreH = alpha * logL0H.permute(0, 2, 1) # [cc, state, thH]; no utterance cost is subtracted
86scoreL = alpha * logL0L.permute(0, 2, 1) # [cc, state, thL]
87sH = scoreH.unsqueeze(3).expand(NC, NS, NS, NS) # [cc, state, thH, thL] (heavy score repeated along thL)
88sL = scoreL.unsqueeze(2).expand(NC, NS, NS, NS) # [cc, state, thH, thL] (light score repeated along thH)
89spk_scores = torch.stack([sH, sL], dim=-1) # [cc, state, thH, thL, 2] (heavy, light)
90Kf = NC * NS * NS * NS # one speaker config per (cc, state, thH, thL)
91spk_scores_flat = spk_scores.reshape(Kf, 2) # row = config, column 0 = 'heavy', column 1 = 'light'
92
93
94@pyro.infer.config_enumerate
95def speaker_model(): # S1 for all Kf configs at once: uniform utterance prior reweighted by exp(score)
96 with pyro.plate("scfg", Kf):
97 u = pyro.sample("utt", dist.Categorical(torch.ones(Kf, 2))) # uniform prior over the 2 utterances
98 sel = pyro.ops.indexing.Vindex(spk_scores_flat)[torch.arange(Kf), u] # score of the enumerated utterance per config
99 pyro.factor("util", sel)
100
101
102smarg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(speaker_model, lambda: None) # one plate -> nesting 1
103sm = smarg["utt"]
104ssup = sm.enumerate_support() # [2, Kf] utterance indices per config
105slp = sm.log_prob(ssup) # [2, Kf]
106spk_lp = torch.full((Kf, 2), float("-inf")) # [Kf, 2] normalized log P(utt | config), filled below
107scfg_idx = torch.arange(Kf).unsqueeze(0).expand_as(ssup)
108spk_lp[scfg_idx, ssup] = slp # scatter: row = config, column = utterance index
109# log P_speaker('heavy' | cc, state, thH, thL) (utterance index 0 == 'heavy')
110logS_heavy = spk_lp[:, 0].reshape(NC, NS, NS, NS) # [cc, state, thH, thL]
111
112
113# ---- Pragmatic listener (child speaker), exact enumeration over the joint discrete
114# latents (comparison class, state, heavy-threshold, light-threshold). The speaker
115# log-score for the heard utterance 'heavy' enters as the observation factor;
116# compute_marginals contracts the nuisance latents to the exact marginal over state. ----
117ccp_vec = torch.tensor([0.75, 0.25, 0.15]) # unnormalized weights (sum 1.15)
118ccp_vec = ccp_vec / ccp_vec.sum() # child speaker comparison-class prior
119
120
121@pyro.infer.config_enumerate
122def prag_model():
123 cc_i = pyro.sample("cc", dist.Categorical(ccp_vec))
124 sp = pyro.ops.indexing.Vindex(state_prior_table)[cc_i] # state prior for the sampled comparison class
125 state_i = pyro.sample("state", dist.Categorical(sp))
126 th_h_i = pyro.sample("thH", dist.Categorical(torch.ones(NS) / NS)) # heavy threshold, uniform over NS values
127 th_l_i = pyro.sample("thL", dist.Categorical(torch.ones(NS) / NS)) # light threshold, uniform over NS values
128 score = pyro.ops.indexing.Vindex(logS_heavy)[cc_i, state_i, th_h_i, th_l_i] # log S1('heavy' | cc, state, thH, thL)
129 pyro.factor("obs", score)
130
131
132marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(prag_model, lambda: None) # no plates in prag_model
133m = marg["state"]
134sup = m.enumerate_support()
135probs = m.log_prob(sup).exp()
136ANSWER = {stateVals[int(s)]: float(pr) for s, pr in zip(sup, probs)} # state value -> posterior probability
137
02answer overlay — webppl vs pyrodist/real
webppl pyro51 bins · 0 … 5.80
00.0310.0310.0620.0620123451.20 · 0.0621.20 · 0.062x = 0 A = 0.0132 B = 0.0132 Δ = 0.0000x = 0.20 A = 0.0214 B = 0.0214 Δ = 0.0000x = 0.40 A = 0.0312 B = 0.0312 Δ = 0.0000x = 0.60 A = 0.0000 B = 0.0417 Δ = -0.0417x = 0.60 A = 0.0417 B = 0.0000 Δ = 0.0417x = 0.80 A = 0.0514 B = 0.0514 Δ = 0.0000x = 1 A = 0.0585 B = 0.0585 Δ = 0.0000x = 1.20 A = 0.0618 B = 0.0618 Δ = 0.0000x = 1.40 A = 0.0611 B = 0.0611 Δ = 0.0000x = 1.60 A = 0.0573 B = 0.0000 Δ = 0.0573x = 1.60 A = 0.0000 B = 0.0573 Δ = -0.0573x = 1.80 A = 0.0516 B = 0.0000 Δ = 0.0516x = 1.80 A = 0.0000 B = 0.0516 Δ = -0.0516x = 2.00 A = 0.0449 B = 0.0000 Δ = 0.0449x = 2 A = 0.0000 B = 0.0449 Δ = -0.0449x = 2.20 A = 0.0385 B = 0.0000 Δ = 0.0385x = 2.20 A = 0.0000 B = 0.0385 Δ = -0.0385x = 2.40 A = 0.0329 B = 0.0329 Δ = -0.0000x = 2.60 A = 0.0286 B = 0.0286 Δ = 0.0000x = 2.80 A = 0.0000 B = 0.0256 Δ = -0.0256x = 2.80 A = 0.0256 B = 0.0000 Δ = 0.0256x = 3 A = 0.0000 B = 0.0239 Δ = -0.0239x = 3.00 A = 0.0239 B = 0.0000 Δ = 0.0239x = 3.20 A = 0.0000 B = 0.0231 Δ = -0.0231x = 3.20 A = 0.0231 B = 0.0000 Δ = 0.0231x = 3.40 A = 0.0000 B = 0.0231 Δ = -0.0231x = 3.40 A = 0.0231 B = 0.0000 Δ = 0.0231x = 3.60 A = 0.0000 B = 0.0236 Δ = -0.0236x = 3.60 A = 0.0236 B = 0.0000 Δ = 0.0236x = 3.80 A = 0.0000 B = 0.0243 Δ = -0.0243x = 3.80 A = 0.0243 B = 0.0000 Δ = 0.0243x = 4 A = 0.0000 B = 0.0251 Δ = -0.0251x = 4.00 A = 0.0251 B = 0.0000 Δ = 0.0251x = 4.20 A = 0.0000 B = 0.0258 Δ = -0.0258x = 4.20 A = 0.0258 B = 0.0000 Δ = 0.0258x = 4.40 A = 0.0000 B = 0.0264 Δ = -0.0264x = 4.40 A = 0.0264 B = 0.0000 Δ = 0.0264x = 4.60 A = 0.0000 B = 0.0269 Δ = -0.0269x = 4.60 A = 0.0269 B = 0.0000 Δ = 0.0269x = 4.80 A = 0.0000 B = 0.0271 Δ = -0.0271x = 4.80 A = 0.0271 B = 0.0000 Δ = 0.0271x = 5 A = 0.0000 B = 0.0271 Δ = -0.0271x = 5.00 A = 0.0271 B = 0.0000 Δ = 0.0271x = 5.20 A = 0.0000 B = 0.0268 Δ = -0.0268x = 5.20 A = 0.0268 B = 0.0000 Δ = 0.0268x = 5.40 A = 0.0000 B = 0.0264 Δ = -0.0264x = 5.40 A = 0.0264 B = 0.0000 Δ = 0.0264x = 5.60 A = 0.0000 
B = 0.0258 Δ = -0.0258x = 5.60 A = 0.0258 B = 0.0000 Δ = 0.0258x = 5.80 A = 0.0000 B = 0.0249 Δ = -0.0249x = 5.80 A = 0.0249 B = 0.0000 Δ = 0.0249
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (w1)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-keysar / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/keysar.md
given

Objects: all 9 combinations of type ∈ {apple, fish, cup} × color ∈ {red, blue, green}. Utterances: all strings of the form "[color] type" or just "type", giving 12 utterances total (9 color-qualified, i.e. 3 colors × 3 types, plus 3 bare types). Utterance cost: the number of words in the utterance (1 for bare, 2 for color-qualified). Utterance fitness: for a two-word utterance, 0 if the object matches both color and type, −100 otherwise; for a one-word utterance, 0 if the object matches the type, −100 otherwise. Speaker optimality alpha = 3. Context: shared objects are {red apple, blue fish, green cup}; occluded object is {red fish}.

model

A baseline reference game where the speaker has no uncertainty about the environment. A literal listener draws uniformly from a perceived context and weights by utterance fitness. A speaker, knowing the target and the shared context, chooses an utterance with probability proportional to exp(alpha × literal-listener score − utterance cost). A pragmatic listener (the L2 model) assumes the speaker only knows and considers the shared context. The pragmatic listener draws uniformly from the shared context (ignoring the occluded object) and updates on the speaker's distribution.

query

The posterior distribution over object identity (expressed as "color type" strings) for the pragmatic listener who hears the bare utterance "fish" in the given context.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "red apple",
    "blue fish",
    "green cup"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1var possibleUtterances = _.flatten(map(function(modifier) { // 12 utterances: 9 "color type" plus 3 bare types
2 map(function(type) { // NOTE(review): no `return`; relies on WebPPL implicitly returning the last expression (atom-2 writes `return map(...)`)
3 return [modifier, type].join(' ').trim(); // '' modifier + trim() yields the bare type
4 }, ['apple', 'fish', 'cup']);
5}, ['red', 'blue', 'green', '']));
6
7var possibleObjects = [ // all 3 x 3 type/color combinations; not referenced elsewhere in this listing
8 {type: 'apple', color: 'red'}, {type: 'apple', color: 'blue'}, {type: 'apple', color: 'green'},
9 {type: 'fish', color: 'red'}, {type: 'fish', color: 'blue'}, {type: 'fish', color: 'green'},
10 {type: 'cup', color: 'red'}, {type: 'cup', color: 'blue'}, {type: 'cup', color: 'green'}]
11
12var exampleContext = { // shared objects plus one occluded object
13 shared: [
14 {type: 'apple', color: 'red'},
15 {type: 'fish', color: 'blue'},
16 {type: 'cup', color: 'green'}
17 ],
18 occluded: [
19 {type: 'fish', color: 'red'}
20 ]
21};
22
23var alpha = 3; // speaker optimality
24
25var uttCost = function(utt) { // cost = number of words (1 or 2)
26 return utt.split(' ').length;
27}
28
29var uttFitness = function(utt, object) { // log-weight: 0 if utt describes object, else -100
30 var descriptors = utt.split(' ');
31 if(descriptors.length > 1) { // "color type": both must match
32 return (object.color === descriptors[0] &&
33 object.type === descriptors[1]) ? 0 : -100;
34 } else { // bare type: only the type must match
35 return object.type === descriptors[0] ? 0 : -100;
36 }
37};
38
39var L0 = cache(function(utt, perceivedContext) { // literal listener
40 return Infer({method: 'enumerate'}, function() {
41 var object = uniformDraw(perceivedContext); // uniform over the perceived context
42 factor(uttFitness(utt, object)); // soft truth condition
43 return object;
44 });
45});
46
47var S1 = cache(function(target, knownContext) { // speaker who knows the context
48 return Infer({method: 'enumerate'}, function() {
49 var utt = uniformDraw(possibleUtterances);
50 factor(alpha * L0(utt, knownContext).score(target) - uttCost(utt)); // alpha * log L0(target | utt) - cost
51 return utt;
52 });
53});
54
55// Listener only considers objects speaker can see (model Keysar is arguing against)
56var L2 = cache(function(utt, perceivedContext) {
57 var sharedContext = perceivedContext.shared;
58 var fullObjSet = sharedContext.concat(perceivedContext.occluded); // NOTE(review): unused; L2 draws only from sharedContext
59 return Infer({method: 'enumerate'}, function() {
60 var object = uniformDraw(sharedContext); // occluded object is not a candidate referent
61 observe(S1(object, sharedContext), utt); // speaker assumed to know only the shared context
62 return object.color + " " + object.type; // "color type" label
63 });
64});
65
66console.log("speaker utterance to refer to blue fish");
67viz.table(S1({type: 'fish', color: 'blue'}, exampleContext.shared));
68
69console.log("listener response after hearing (underinformative) 'fish'");
70viz.table(L2('fish', exampleContext));
71
72var ANSWER = (L2('fish', exampleContext)); // queried posterior over the shared objects
realization0.000
python
1possibleUtterances = [] # same order as the WebPPL list: color-qualified first, bare types last
2for modifier in ["red", "blue", "green", ""]:
3 for typ in ["apple", "fish", "cup"]:
4 possibleUtterances.append((modifier + " " + typ).strip()) # "" modifier -> bare type
5
6shared = [
7 {"type": "apple", "color": "red"},
8 {"type": "fish", "color": "blue"},
9 {"type": "cup", "color": "green"},
10]
11occluded = [{"type": "fish", "color": "red"}] # not used by this L2 variant
12alpha = 3.0
13
14def uttCost(utt):
15 return len(utt.split(" ")) # word count
16
17def uttFitness(utt, obj):
18 d = utt.split(" ")
19 if len(d) > 1: # "color type" utterance
20 return 0.0 if (obj["color"] == d[0] and obj["type"] == d[1]) else -100.0
21 return 0.0 if obj["type"] == d[0] else -100.0 # bare type utterance
22
23def _ctxkey(ctx):
24 return tuple((o["color"], o["type"]) for o in ctx) # hashable cache key for a context
25
26# Literal listener: uniform over the perceived context, factor by utterance fitness.
27# Exact enumeration via Pyro's compute_marginals.
28_L0_cache = {}
29def L0(utt, perceivedContext):
30 key = (utt, _ctxkey(perceivedContext))
31 if key in _L0_cache: # memoized, like WebPPL cache()
32 return _L0_cache[key]
33 _fit = torch.tensor([uttFitness(utt, o) for o in perceivedContext])
34 @pyro.infer.config_enumerate
35 def m():
36 idx = pyro.sample("obj", dist.Categorical(torch.ones(len(perceivedContext)))) # uniform over context positions
37 pyro.factor("fit", _fit[idx])
38 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
39 _L0_cache[key] = marg["obj"] # marginal over object index in perceivedContext
40 return _L0_cache[key]
41
42# Speaker: utterance ~ exp(alpha * L0.score(target) - uttCost), inverting L0.
43_S1_cache = {}
44def S1(target, knownContext):
45 key = ((target["color"], target["type"]), _ctxkey(knownContext))
46 if key in _S1_cache:
47 return _S1_cache[key]
48 tidx = knownContext.index(target) # list.index compares the dicts by value
49 L0s = [L0(u, knownContext) for u in possibleUtterances]
50 _sc = torch.tensor([
51 alpha * L0s[i].log_prob(torch.tensor(tidx)).item() - uttCost(possibleUtterances[i]) # alpha * log L0(target | u) - cost(u)
52 for i in range(len(possibleUtterances))
53 ])
54 @pyro.infer.config_enumerate
55 def m():
56 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(possibleUtterances)))) # uniform over utterance indices
57 pyro.factor("util", _sc[uidx])
58 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
59 _S1_cache[key] = marg["utt"] # marginal over utterance index
60 return _S1_cache[key]
61
62# Pragmatic listener (L2): uniform over the shared context, observe the speaker's
63# utterance where the speaker only knows the shared context.
64heard = "fish"
65_uidx = possibleUtterances.index(heard) # index of the heard utterance
66_S1s = [S1(o, shared) for o in shared]
67_obs_scores = torch.tensor([ # log S1(heard | object), one per shared object
68 _S1s[i].log_prob(torch.tensor(_uidx)).item() for i in range(len(shared))
69])
70
71@pyro.infer.config_enumerate
72def _L2():
73 oidx = pyro.sample("obj", dist.Categorical(torch.ones(len(shared))))
74 pyro.factor("obs", _obs_scores[oidx]) # observe the heard utterance
75
76_marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(_L2, lambda: None)
77_post = _marg["obj"]
78ANSWER = {
79 shared[i]["color"] + " " + shared[i]["type"]: _post.probs[i].item() # probs indexed by shared-context position
80 for i in range(len(shared))
81}
82
02answer overlay — webppl vs pyrodist/finite
webppl pyro3 bins
00.500.501.001.00blue fish A = 1.000 B = 1.000blue fish A = 1.000 B = 1.0001.001.00blue fishgreen cup A = 0.000 B = 0.000green cup A = 0.000 B = 0.000green cupred apple A = 0.000 B = 0.000red apple A = 0.000 B = 0.000red apple
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-keysar / atom-2
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/keysar.md
given

Objects: all 9 combinations of type ∈ {apple, fish, cup} × color ∈ {red, blue, green}. Utterances: all strings of the form "[color] type" or just "type", giving 12 utterances. Utterance cost: number of words divided by 4. Utterance fitness: for two-word utterances, 0 if both color and type match, −100 otherwise; for one-word utterances, 0 if type matches, −100 otherwise. Speaker optimality alpha = 4. Context: shared objects are {red apple, blue fish, green cup}; occluded object is {red fish}.

model

A context-uncertainty reference game where the speaker is uncertain about additional objects the listener may perceive. A literal listener draws uniformly from a perceived context and weights by utterance fitness. A speaker, given a target and the shared context, marginalizes over uncertainty about one hidden additional object: it forms 9 possible contexts by augmenting the shared context with each of the 9 possible objects (each equally likely), averages the literal listener's probability of the target across these contexts, and then weights utterances by exp(alpha × log(averaged probability) − utterance cost). A pragmatic listener (L2) considers both shared and occluded objects as potential referents, draws uniformly from the full context, and updates on the distribution of a speaker who is assumed to know only the shared context.

query

The posterior distribution over object identity ("color type" strings) for the pragmatic listener who hears the bare utterance "fish" in the given context.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "red apple",
    "blue fish",
    "green cup",
    "red fish"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1var possibleUtterances = _.flatten(map(function(modifier) { // 12 utterances: 9 "color type" plus 3 bare types
2 return map(function(type) {
3 return [modifier, type].join(' ').trim(); // '' modifier + trim() yields the bare type
4 }, ['apple', 'fish', 'cup']);
5}, ['red', 'blue', 'green', '']));
6
7var possibleObjects = [ // candidates for the speaker's hidden extra object (used in S1)
8 {type: 'apple', color: 'red'}, {type: 'apple', color: 'blue'}, {type: 'apple', color: 'green'},
9 {type: 'fish', color: 'red'}, {type: 'fish', color: 'blue'}, {type: 'fish', color: 'green'},
10 {type: 'cup', color: 'red'}, {type: 'cup', color: 'blue'}, {type: 'cup', color: 'green'}]
11
12var exampleContext = {
13 shared: [
14 {type: 'apple', color: 'red'},
15 {type: 'fish', color: 'blue'},
16 {type: 'cup', color: 'green'}
17 ],
18 occluded: [
19 {type: 'fish', color: 'red'}
20 ]
21};
22
23var alpha = 4; // speaker optimality
24
25var uttCost = function(utt) {
26 return utt.split(' ').length/4; // 0.25 for bare, 0.5 for color-qualified
27}
28
29var uttFitness = function(utt, object) { // log-weight: 0 if utt describes object, else -100
30 var descriptors = utt.split(' ');
31 if(descriptors.length > 1) {
32 return (object.color === descriptors[0] &&
33 object.type === descriptors[1]) ? 0 : -100;
34 } else {
35 return object.type === descriptors[0] ? 0 : -100;
36 }
37};
38
39var L0 = cache(function(utt, perceivedContext) { // literal listener
40 return Infer({method: 'enumerate'}, function() {
41 var object = uniformDraw(perceivedContext);
42 factor(uttFitness(utt, object))
43 return object;
44 })
45})
46
47// speaker has uncertainty over what's behind occluded square
48// marginalizes over all possibilities
49var S1 = cache(function(target, perceivedContext) {
50 return Infer({method: 'enumerate'}, function() {
51 var utt = uniformDraw(possibleUtterances);
52 var listener = Infer({method: 'enumerate', model: function() {
53 var context = perceivedContext.concat(uniformDraw(possibleObjects)); // known context + 1 of 9 objects, uniformly
54 return sample(L0(utt, context)); // mixture of L0 posteriors over the 9 augmented contexts
55 }})
56 factor(alpha * listener.score(target) - uttCost(utt)) // alpha * log(averaged L0 prob of target) - cost
57 return utt;
58 });
59});
60
61// Listener reasons about S1; assumes they only see what's in shared context
62// but could be trying to refer to any of the objects
63var L2 = cache(function(utt, perceivedContext) {
64 var sharedContext = perceivedContext.shared;
65 var fullContext = sharedContext.concat(perceivedContext.occluded); // shared + occluded candidate referents
66 return Infer({method: 'enumerate', model: function() {
67 var object = uniformDraw(fullContext);
68 observe(S1(object, sharedContext), utt); // speaker modeled with the shared context only
69 return object.color + " " + object.type;
70 }});
71})
72
73console.log("speaker utterance to refer to blue fish")
74viz.table(S1({type: 'fish', color: 'blue'}, exampleContext.shared))
75
76console.log("listener response after hearing (underinformative) 'fish'")
77viz.table(L2('fish', exampleContext));
78
79var ANSWER = (L2('fish', exampleContext)); // queried posterior over shared + occluded objects
realization0.000
python
1# Keysar context-uncertainty reference game (RSA).
2# Each level (L0 literal listener, S1 speaker, L2 pragmatic listener) is a
3# discrete enumerable model; we run Pyro's own exact enumeration inference
4# (config_enumerate + TraceEnum_ELBO.compute_marginals) at every level and let
5# the higher levels reason about the lower-level posteriors via their scores,
6# exactly as the WebPPL recursion does.
7
8import json # NOTE(review): unused here; math/torch/pyro/dist are used but not imported in this listing (presumably supplied by the harness)
9
10possibleObjects = [
11 {"type": "apple", "color": "red"}, {"type": "apple", "color": "blue"}, {"type": "apple", "color": "green"},
12 {"type": "fish", "color": "red"}, {"type": "fish", "color": "blue"}, {"type": "fish", "color": "green"},
13 {"type": "cup", "color": "red"}, {"type": "cup", "color": "blue"}, {"type": "cup", "color": "green"},
14]
15
16possibleUtterances = []
17for modifier in ["red", "blue", "green", ""]:
18 for t in ["apple", "fish", "cup"]:
19 possibleUtterances.append((modifier + " " + t).strip()) # "" modifier -> bare type
20
21shared = [
22 {"type": "apple", "color": "red"},
23 {"type": "fish", "color": "blue"},
24 {"type": "cup", "color": "green"},
25]
26occluded = [{"type": "fish", "color": "red"}]
27fullContext = shared + occluded # not referenced below; L2 rebuilds it as `full`
28
29alpha = 4.0
30
31def obj_key(o):
32 return o["color"] + " " + o["type"] # "color type" label, also used as the object identity key
33
34def uttCost(utt):
35 return len(utt.split(" ")) / 4.0 # 0.25 for bare, 0.5 for color-qualified
36
37def uttFitness(utt, obj):
38 parts = utt.split(" ")
39 if len(parts) > 1:
40 return 0.0 if (obj["color"] == parts[0] and obj["type"] == parts[1]) else -100.0
41 else:
42 return 0.0 if obj["type"] == parts[0] else -100.0
43
44def enum_marginal(log_weights):
45 # Run Pyro exact enumeration over a uniform Categorical with per-outcome
46 # log-weights supplied via pyro.factor; return the normalized marginal as a
47 # 1-D tensor of probabilities (one per outcome) computed by Pyro's inference.
48 n = len(log_weights)
49 lw = torch.tensor([float(x) for x in log_weights])
50 base = torch.ones(n) / n
51
52 @pyro.infer.config_enumerate
53 def model():
54 idx = pyro.sample("idx", dist.Categorical(base))
55 pyro.factor("w", lw[idx])
56 return idx
57
58 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)["idx"]
59 sup = marg.enumerate_support()
60 probs = marg.log_prob(sup).exp()
61 out = torch.zeros(n)
62 for s, p in zip(sup, probs):
63 out[int(s.item())] = p # scatter back into outcome order
64 return out
65
66_L0 = {}
67def L0(utt, context):
68 # Literal listener: uniform over objects in context, weighted by fitness.
69 key = (utt, tuple(obj_key(o) for o in context))
70 if key in _L0:
71 return _L0[key]
72 probs = enum_marginal([uttFitness(utt, o) for o in context]) # probabilities indexed by context position
73 _L0[key] = probs
74 return probs
75
76_S1 = {}
77def S1(target, perceivedContext):
78 # Speaker: softmax over utterances by alpha * (marginalized literal-listener
79 # log-score of the target) - cost. The listener marginalizes over one hidden
80 # extra object drawn uniformly from possibleObjects.
81 tkey = obj_key(target)
82 key = (tkey, tuple(obj_key(o) for o in perceivedContext))
83 if key in _S1:
84 return _S1[key]
85 log_weights = []
86 for utt in possibleUtterances:
87 # marginal P(listener picks target) over the 9 augmented contexts
88 marg_p = 0.0
89 for added in possibleObjects:
90 context = perceivedContext + [added] # the added object may duplicate a shared one
91 l0 = L0(utt, context)
92 for i, o in enumerate(context):
93 if obj_key(o) == tkey: # sum every position matching the target (duplicates included)
94 marg_p += float(l0[i].item())
95 marg_p = marg_p / len(possibleObjects) # uniform average over the 9 added objects
96 score = math.log(marg_p) if marg_p > 0 else float("-inf") # log of the averaged probability
97 log_weights.append(alpha * score - uttCost(utt))
98 probs = enum_marginal(log_weights) # probabilities indexed like possibleUtterances
99 _S1[key] = probs
100 return probs
101
102def L2(utt, perceivedContext):
103 # Pragmatic listener: uniform over fullContext objects, observe S1 emitting
104 # utt (assuming the speaker sees only the shared context).
105 sharedCtx = perceivedContext["shared"]
106 full = perceivedContext["shared"] + perceivedContext["occluded"]
107 utt_idx = possibleUtterances.index(utt)
108 log_weights = []
109 for obj in full:
110 s1 = S1(obj, sharedCtx)
111 p = float(s1[utt_idx].item()) # S1 probability of the heard utterance
112 log_weights.append(math.log(p) if p > 0 else float("-inf"))
113 probs = enum_marginal(log_weights)
114 return full, probs # objects and their posterior probabilities, aligned by position
115
116context = {"shared": shared, "occluded": occluded}
117full, probs = L2("fish", context)
118
119support = ["red apple", "blue fish", "green cup", "red fish"]
120result = {}
121for o, p in zip(full, probs):
122 result[obj_key(o)] = float(p.item())
123ANSWER = {lab: result.get(lab, 0.0) for lab in support} # answer-spec order; missing labels get 0.0
124
02answer overlay — webppl vs pyrodist/finite
webppl pyro4 bins
00.430.430.860.86blue fish A = 0.858 B = 0.858blue fish A = 0.858 B = 0.8580.860.86blue fishgreen cup A = 0.000 B = 0.000green cup A = 0.000 B = 0.000green cupred apple A = 0.000 B = 0.000red apple A = 0.000 B = 0.000red applered fish A = 0.142 B = 0.142red fish A = 0.142 B = 0.1420.140.14red fish
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-kids-scope / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/kids-scope.md
given

Two horses; the world state is the count of red (true) horses: 0, 1, or 2. Each horse is independently red with probability 0.5, so the state prior assigns probability 0.25 to state=0, 0.5 to state=1, and 0.25 to state=2. Utterances: ["null", "every-not"]. Utterance cost: 1 for both. Scopes: ["surface", "inverse"], each with prior probability 0.5. Meaning of "every-not": under surface scope, true iff state = 0 (no horse is red); under inverse scope, true iff state < 2 (not all horses are red). "null" is always true. QUDs: ["how many?", "all red?", "none red?"], drawn with equal probability. The QUD function maps a state to: the boolean (state = numHorses = 2) for "all red?"; the boolean (state = 0) for "none red?"; and the raw state count for "how many?". Speaker rationality alpha = 1.

model

A four-level RSA model over scope ambiguity. A literal listener draws the world state uniformly from {0, 1, 2} (equal weight 1/3 each — not the binomial state prior), conditions on the utterance being true under a given scope (the truth condition does not depend on the QUD), and returns the QUD-relevant quantity. The pragmatic listener samples the state from the binomial state prior (0.25/0.5/0.25) described above; the pragmatic speaker is given the true state. A speaker (S1) given a scope, state, and QUD weights utterances by exp(alpha × (literal-listener score on the QUD-relevant quantity − cost)). A pragmatic listener (L1) samples state, scope, and QUD from their priors and updates on the speaker. A pragmatic speaker (S2) weights utterances by the pragmatic listener's posterior probability of the true state.

query

The distribution over utterances ["null", "every-not"] for the pragmatic speaker at world state 1 (one of two horses is red).

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "null",
    "every-not"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// helper function to tally up the state
2var numTrue = function(state) { // number of true (red) horses in the state array
3 var fun = function(x) {
4 x ? 1 : 0 // no `return`: relies on WebPPL implicitly returning the last expression
5 }
6 return sum(map(fun,state))
7}
8
9// possible utterances
10var utterances = ["null","every-not"];
11var utterancePrior = function() {
12 uniformDraw(utterances) // implicit return (WebPPL)
13}
14// uniform utterance cost
15var cost = function(utterance) {
16 return 1
17}
18
19// possible world states
20var numHorses = 2
21var states = [0,1,2]
22var baserate = 0.5 // change to manipulate prior on world states
23var stateMaker = function(numHorses,stateSoFar) {
24 if (numHorses == 0) { // parameter shadows the global numHorses; counts horses left to draw
25 return stateSoFar
26 } else {
27 var newHorse = flip(baserate) // each horse is red with probability baserate
28 var newState = stateSoFar.concat([newHorse])
29 return stateMaker(numHorses - 1, newState)
30 }
31}
32var statePrior = function() {
33 return numTrue(stateMaker(numHorses,[])) // 0, 1, or 2 red horses
34}
35
36// possible scope interpretations
37var scopes = ["surface", "inverse"]
38var scopePrior = function(){
39 return categorical([.5,.5],scopes) // change to manipulate prior on scope interpretations
40}
41
42
43// meaning function
44var meaning = function(utterance, state, scope) {
45 return utterance == "every-not" ?
46 scope == "surface" ? state == 0 : // surface: no horse is red
47 state < numHorses : // inverse: not all horses are red
48 true; // "null" is always true
49};
50
51// possible QUDs
52var QUDs = ["how many?","all red?","none red?"];
53var QUDPrior = function() {
54 uniformDraw(QUDs); // uniform over QUDs; implicit return
55 // categorical([.05,.05,.9],QUDs) // change to manipulate prior on QUDs
56}
57var QUDFun = function(QUD,state) {
58 QUD == "all red?" ? state == numHorses : // yes/no QUDs return a boolean
59 QUD == "none red?" ? state == 0 :
60 state; // "how many?" returns the raw count
61};
62
63// Literal listener (L0)
64var literalListener = cache(function(utterance,scope,QUD) {
65 Infer({model: function(){
66 var state = uniformDraw(states); // uniform over states, not statePrior
67 var qState = QUDFun(QUD,state)
68 condition(meaning(utterance,state,scope)); // truth depends only on utterance, state, scope
69 return qState;
70 }});
71});
72
73var alpha = 1
74
75// Speaker (S1)
76var speaker = cache(function(scope, state, QUD) {
77 return Infer({model: function(){
78 var utterance = utterancePrior()
79 var qState = QUDFun(QUD, state)
80 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)
81 - cost(utterance))) // cost sits inside the alpha-scaled term
82 return utterance
83 }})
84})
85
86// Pragmatic listener (L1)
87var pragmaticListener = cache(function(utterance) {
88 Infer({model: function(){
89 var state = statePrior(); // binomial prior over red-horse count
90 var scope = scopePrior();
91 var QUD = QUDPrior();
92 observe(speaker(scope,state,QUD),utterance);
93 return state
94 }});
95});
96
97// Pragmatic speaker (S2)
98var pragmaticSpeaker = cache(function(state) {
99 Infer({model: function(){
100 var utterance = utterancePrior();
101 factor(pragmaticListener(utterance).score(state)) // log L1(state | utterance); no alpha here
102 return utterance
103 }})
104})
105
106// A speaker decides whether to endorse the ambiguous utterance as a
107// description of the not-all world state
108display(pragmaticSpeaker(1))
109
110var ANSWER = (pragmaticSpeaker(1));
realization0.000
python
1
2NEG = -1e30 # finite stand-in for log(0) in masks and scores
3utterances = ["null", "every-not"]
4numHorses = 2
5states = [0, 1, 2]
6# statePrior = numTrue of two flip(0.5) -> Binomial(2, 0.5)
7state_prior = torch.tensor([0.25, 0.5, 0.25])
8scopes = ["surface", "inverse"]
9scope_prior = torch.tensor([0.5, 0.5])
10QUDs = ["how many?", "all red?", "none red?"]
11qud_prior = torch.ones(3) # unnormalized; Categorical normalizes it to uniform
12alpha = 1.0
13
14
15def meaning(utt, state, scope): # truth of utt in state under scope; "null" is always true
16 if utt == "every-not":
17 if scope == "surface":
18 return state == 0
19 return state < numHorses
20 return True
21
22
23def qud_fun(qud, state): # tagged QUD answer: ("b", bool) or ("n", count)
24 if qud == "all red?":
25 return ("b", state == numHorses) # tag keeps booleans from sharing dict keys with counts (True == 1 in Python)
26 if qud == "none red?":
27 return ("b", state == 0)
28 return ("n", state)
30
31# Literal listener: posterior over the QUD-projected state, inferred.
32def literal_listener(ui, scope_i, qi):
33 utt = utterances[ui]
34 scope = scopes[scope_i]
35 qud = QUDs[qi]
36 mask = torch.tensor([0.0 if meaning(utt, states[s], scope) else NEG for s in range(3)]) # 0 where true, NEG where false: a hard condition
37
38 @pyro.infer.config_enumerate
39 def model():
40 st = pyro.sample("st", dist.Categorical(torch.ones(3))) # uniform over states (not state_prior)
41 pyro.factor("cond", mask[st])
42 return None
43
44 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
45 m = marg["st"]
46 sup = m.enumerate_support()
47 pr = m.log_prob(sup).exp()
48 out = {}
49 for i, p in zip(sup.tolist(), pr.tolist()):
50 q = qud_fun(qud, states[int(i)])
51 out[q] = out.get(q, 0.0) + p # aggregate state probabilities into QUD answers
52 return out # {tagged QUD answer: probability}
53
54
55# Speaker S1: utterance posterior given (scope, state, QUD), inferred.
56def speaker1_logprobs(scope_i, st_val, qi): # NOTE: despite the name, returns probabilities keyed by utterance index
57 qud = QUDs[qi]
58 q_state = qud_fun(qud, st_val)
59 ll = [literal_listener(u, scope_i, qi) for u in range(2)] # L0 for both utterances
60
61 @pyro.infer.config_enumerate
62 def model():
63 u = pyro.sample("u", dist.Categorical(torch.ones(2)))
64 sc = torch.tensor([alpha * ((math.log(ll[uu].get(q_state, 0.0))
65 if ll[uu].get(q_state, 0.0) > 0 else NEG) - 1.0)
66 for uu in range(2)]) # cost(utterance) == 1
67 pyro.factor("sc", sc[u])
68 return None
69
70 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
71 m = marg["u"]
72 sup = m.enumerate_support()
73 pr = m.log_prob(sup).exp()
74 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())}
75
76
77# Pragmatic listener L1: state posterior given utterance, inferred.
78def pragmatic_listener(ui):
79 score = torch.full((3, 2, 3), NEG) # [state, scope, QUD] table of log S1(ui | ...)
80 for st in range(3):
81 for sc in range(2):
82 for qi in range(3):
83 spk = speaker1_logprobs(sc, states[st], qi) # probabilities keyed by utterance index
84 p = spk.get(ui, 0.0)
85 score[st, sc, qi] = math.log(p) if p > 0 else NEG
86
87 @pyro.infer.config_enumerate
88 def model():
89 st = pyro.sample("st", dist.Categorical(state_prior))
90 sc = pyro.sample("sc", dist.Categorical(scope_prior))
91 qd = pyro.sample("qd", dist.Categorical(qud_prior))
92 pyro.factor("obs", score[st, sc, qd])
93 return None
94
95 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
96 m = marg["st"]
97 sup = m.enumerate_support()
98 pr = m.log_prob(sup).exp()
99 return {int(i): p for i, p in zip(sup.tolist(), pr.tolist())} # {state index: probability}; index == value since states = [0, 1, 2]
100
101
102# Pragmatic speaker S2: utterance posterior given the state, inferred.
103def pragmatic_speaker(st_val):
104 l1 = [pragmatic_listener(u) for u in range(2)] # L1 posterior for each utterance
105 sc = torch.tensor([math.log(l1[u].get(st_val, 0.0)) if l1[u].get(st_val, 0.0) > 0 else NEG # log L1(st_val | u); no alpha
106 for u in range(2)])
107
108 @pyro.infer.config_enumerate
109 def model():
110 u = pyro.sample("u", dist.Categorical(torch.ones(2)))
111 pyro.factor("sc", sc[u])
112 return None
113
114 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
115 m = marg["u"]
116 sup = m.enumerate_support()
117 pr = m.log_prob(sup).exp()
118 return {utterances[int(i)]: p for i, p in zip(sup.tolist(), pr.tolist())} # {utterance: probability}
119
120
121ANSWER = pragmatic_speaker(1)
122
02answer overlay — webppl vs pyrodist/finite
webppl pyro2 bins
00.250.250.510.51every-not A = 0.506 B = 0.506every-not A = 0.506 B = 0.5060.510.51every-notnull A = 0.494 B = 0.494null A = 0.494 B = 0.4940.490.49null
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-lai-irony / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/lai-irony.md
given

Weather states are 'terrible', 'ok', and 'amazing'. The prior over states is categorical with unnormalized weights [1, 50, 50] — ok and amazing are each 50 times more likely than terrible (California prior). Valence (positive or negative affect) depends on state: given 'terrible', negative valence has probability 0.99; given 'ok', each valence is equally likely (0.5); given 'amazing', positive valence has probability 0.99. Arousal is binary: 'low' or 'high'. Given 'terrible', the arousal distribution is categorical([0.1, 0.9]); given 'ok', categorical([0.9, 0.1]); given 'amazing', categorical([0.1, 0.9]). A speaker has one of three communicative goals — conveying the weather state, conveying valence, or conveying arousal — each equally likely a priori. Utterances are the state labels themselves, each equally likely a priori. The speaker uses a utility-weighted pragmatic reasoning chain: a literal listener updates on whether the utterance matches the state and returns the goal-relevant quantity; a speaker chooses utterances proportional to the literal listener's probability of recovering the goal-relevant truth (soft-max with weight 1); a pragmatic listener inverts the speaker to jointly infer state, valence, arousal, and goal.

model

A pragmatic listener hears an utterance and infers the joint distribution over (weather state, valence, arousal, communicative goal) by inverting a speaker who chose the utterance to satisfy a sampled communicative goal. The literal listener conditions only on whether the utterance names the state literally.

query

The posterior joint distribution over (state, valence, arousal, goal) given the utterance 'terrible', as a distribution over the four-field record {state, valence, arousal, goal}. Valence is encoded as the integer -1 (negative) or 1 (positive); arousal is the string 'low' or 'high'; goal labels are the strings 'goalState', 'goalValence', and 'goalArousal'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "labels": {
    "record": {
      "state": "string",
      "valence": "int",
      "arousal": "string",
      "goal": "string"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// There are three possible states the weather could be in:
2// terrible, ok, or amazing
3var states = ['terrible', 'ok', 'amazing']
4
5// Since we are in California, the prior over these states
6// are the following. Once could also imagine this being
7// the prior in a certain context, e.g. when it's clearly
8// sunny and nice out.
9var statePrior = function() {
10 categorical([1, 50, 50], states)
11}
12
13// Valence prior defined in terms of negative valence.
14// If the current state is terrible, it's extremely likely
15// that the valence associated is negative. If it's ok, then
16// the valence could be negative or positive with equal
17// probability.
18var valencePrior = function(state) {
19 state === "terrible" ? flip(0.99) ? -1 : 1 :
20 state === "ok" ? flip(0.5) ? -1 : 1 :
21 state === "amazing" ? flip(0.01) ? -1 : 1 :
22 true
23}
24
25// Define binary arousals (could model as continuous).
26var arousals = ["low", "high"]
27
28// Define goals and goal priors. Could want to communicate state of the world,
29// valence about it, or arousal (intensity of feeling) about it.
30var goals = ["goalState", "goalValence", "goalArousal"]
31
32var goalPrior = function() {
33 categorical([1, 1, 1], goals)
34}
35
36// Assume possible utterances are identical to possible states
37var utterances = states
38
39// Assume cost of utterances is uniform.
40var utterancePrior = function() {
41 uniformDraw(utterances)
42}
43
44// Sample arousal given a state.
45var arousalPrior = function(state) {
46 state === "terrible" ? categorical([0.1, 0.9], arousals) :
47 state === "ok" ? categorical([0.9, 0.1], arousals) :
48 state === "amazing" ? categorical([0.1, 0.9], arousals) :
49 true
50}
51
52// Literal interpretation is just whether utterance equals state
53var literalInterpretation = function(utterance, state) {
54 utterance === state
55}
56
57// A speaker's goal is satisfied if the listener infers the correct
58// and relevant information.
59var goalState = function(goal, state, valence, arousal) {
60 goal === "goalState" ? state :
61 goal === "goalValence" ? valence :
62 goal === "goalArousal" ? arousal :
63 true
64}
65
66// Define a literal listener
67var literalListener = function(utterance, goal) {
68 Infer({model: function(){
69 var state = uniformDraw(states)
70 var valence = valencePrior(state)
71 var arousal = arousalPrior(state)
72 condition(literalInterpretation(utterance,state))
73 return goalState(goal, state, valence, arousal)
74 }})
75}
76
77// Define a speaker
78var speaker = function(state, valence, arousal, goal) {
79 Infer({model: function(){
80 var utterance = utterancePrior()
81 factor(1 * literalListener(utterance,
82 goal).score(goalState(goal,
83 state,
84 valence,
85 arousal)))
86 return utterance
87 }})
88}
89
90// Define a pragmatic listener
91var pragmaticListener = function(utterance) {
92 Infer({model: function(){
93 var state = statePrior()
94 var valence = valencePrior(state)
95 var arousal = arousalPrior(state)
96 var goal = goalPrior()
97 observe(speaker(state, valence, arousal, goal),utterance)
98 return {state, valence, arousal, goal}
99 }})
100}
101
102viz.table(literalListener("terrible", "goalState"))
103viz.table(speaker("terrible", -1, "high", "goalValence"))
104viz.table(pragmaticListener("terrible"))
105
106var ANSWER = (pragmaticListener('terrible'));
realization0.000
python
1
2states = ["terrible", "ok", "amazing"]
3state_prior_w = [1.0, 50.0, 50.0]
4utterances = states[:]
5goals = ["goalState", "goalValence", "goalArousal"]
6valence_labels = [-1, 1]
7arousal_labels = ["low", "high"]
8
9def valence_neg_prob(state): # P(valence == -1)
10 return 0.99 if state == "terrible" else 0.5 if state == "ok" else 0.01
11
12def arousal_low_prob(state): # P(arousal == 'low')
13 return 0.1 if state == "terrible" else 0.9 if state == "ok" else 0.1
14
15def _elbo():
16 return pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
17
18_LL_CACHE = {}
19def literal_listener_marginals(utterance, goal):
20 key = (utterance, goal)
21 if key in _LL_CACHE:
22 return _LL_CACHE[key]
23 u_idx = states.index(utterance)
24
25 @pyro.infer.config_enumerate
26 def model():
27 si = pyro.sample("state", dist.Categorical(torch.ones(3) / 3)) # uniformDraw
28 pneg = torch.tensor([valence_neg_prob(s) for s in states])[si]
29 # valence idx0 -> -1, idx1 -> +1
30 pyro.sample("valence", dist.Categorical(torch.stack([pneg, 1 - pneg], -1)))
31 pl = torch.tensor([arousal_low_prob(s) for s in states])[si]
32 # arousal idx0 -> low, idx1 -> high
33 pyro.sample("arousal", dist.Categorical(torch.stack([pl, 1 - pl], -1)))
34 # condition literalInterpretation: utterance === state
35 pyro.factor(
36 "lit",
37 torch.where(si == u_idx, torch.tensor(0.0), torch.tensor(float("-inf"))),
38 )
39 return si
40
41 marg = _elbo().compute_marginals(model, lambda: None)
42 _LL_CACHE[key] = marg
43 return marg
44
45def qud_answer(goal, state, valence, arousal):
46 if goal == "goalState":
47 return ("state", state)
48 if goal == "goalValence":
49 return ("valence", valence)
50 return ("arousal", arousal)
51
52def ll_logp(utterance, goal, ans):
53 marg = literal_listener_marginals(utterance, goal)
54 kind, val = ans
55 if kind == "state":
56 d = marg["state"]; idx = states.index(val)
57 elif kind == "valence":
58 d = marg["valence"]; idx = valence_labels.index(val)
59 else:
60 d = marg["arousal"]; idx = arousal_labels.index(val)
61 return float(d.log_prob(torch.tensor(idx)))
62
63_SP_CACHE = {}
64def speaker(state, valence, arousal, goal):
65 key = (state, valence, arousal, goal)
66 if key in _SP_CACHE:
67 return _SP_CACHE[key]
68 ans = qud_answer(goal, state, valence, arousal)
69 utils = torch.tensor([1.0 * ll_logp(u, goal, ans) for u in utterances])
70
71 @pyro.infer.config_enumerate
72 def model():
73 ui = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))
74 pyro.factor("util", utils[ui]) # factor(1 * literalListener.score(...))
75 return ui
76
77 marg = _elbo().compute_marginals(model, lambda: None)["utt"]
78 _SP_CACHE[key] = marg
79 return marg
80
81target = "terrible"
82target_idx = utterances.index(target)
83
84cfgs = []
85for st in states:
86 for val in valence_labels:
87 for ar in arousal_labels:
88 for goal in goals:
89 cfgs.append((st, val, ar, goal))
90
91Zw = sum(state_prior_w)
92prior = torch.tensor([
93 (state_prior_w[states.index(c[0])] / Zw)
94 * (valence_neg_prob(c[0]) if c[1] == -1 else 1 - valence_neg_prob(c[0]))
95 * (arousal_low_prob(c[0]) if c[2] == "low" else 1 - arousal_low_prob(c[0]))
96 * (1.0 / len(goals))
97 for c in cfgs
98])
99
100lik = torch.tensor([
101 float(torch.exp(speaker(c[0], c[1], c[2], c[3]).log_prob(torch.tensor(target_idx))))
102 for c in cfgs
103])
104
105@pyro.infer.config_enumerate
106def pragmatic_model():
107 ci = pyro.sample("cfg", dist.Categorical(prior / prior.sum()))
108 pyro.factor("speaker", torch.log(lik[ci] + 1e-300)) # observe(speaker, utterance)
109 return ci
110
111cfg_marg = _elbo().compute_marginals(pragmatic_model, lambda: None)["cfg"]
112sup = cfg_marg.enumerate_support()
113p = torch.exp(cfg_marg.log_prob(sup))
114
115support_records = []
116probs = []
117for i, pp in zip(sup, p):
118 st, val, ar, goal = cfgs[int(i)]
119 support_records.append({"state": st, "valence": val, "arousal": ar, "goal": goal})
120 probs.append(float(pp))
121
122ANSWER = {"support": support_records, "probs": probs}
123
02answer overlay — webppl vs pyrodist/finite
webppl pyro36 bins
00.220.220.440.44{"arousal":"high","goal":"goalArousal","state":"amazing","valence":-1} A = 0.004 B = 0.004{"arousal":"high","goal":"goalArousal","state":"amazing","valence":-1} A = 0.004 B = 0.004{"arousal":"high","goal":"goalArousal","state":"amazing","valence":-1}{"arousal":"high","goal":"goalArousal","state":"amazing","valence":1} A = 0.443 B = 0.443{"arousal":"high","goal":"goalArousal","state":"amazing","valence":1} A = 0.443 B = 0.443{"arousal":"high","goal":"goalArousal","state":"ok","valence":-1} A = 0.025 B = 0.025{"arousal":"high","goal":"goalArousal","state":"ok","valence":-1} A = 0.025 B = 0.025{"arousal":"high","goal":"goalArousal","state":"ok","valence":1} A = 0.025 B = 0.025{"arousal":"high","goal":"goalArousal","state":"ok","valence":1} A = 0.025 B = 0.025{"arousal":"high","goal":"goalArousal","state":"ok","valence":1}{"arousal":"high","goal":"goalArousal","state":"terrible","valence":-1} A = 0.009 B = 0.009{"arousal":"high","goal":"goalArousal","state":"terrible","valence":-1} A = 0.009 B = 0.009{"arousal":"high","goal":"goalArousal","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalArousal","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"terrible","valence":-1} A = 0.019 B = 0.019{"arousal":"high","goal":"goalState","state":"terrible","valence":-1} A = 0.019 B = 0.019{"arousal":"high","goal":"goalState","state":"terrible","valence":-1}{"arousal":"high","goal":"goalState","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalValence","state":"amazing","valence":-1} A = 0.006 B = 0.006{"arousal":"high","goal":"goalValence","state":"amazing","valence":-1} A = 0.006 B = 0.006{"arousal":"high","goal":"goalValence","state":"amazing","valence":1} A = 0.006 B = 0.006{"arousal":"high","goal":"goalValence","state":"amazing","valence":1} A = 0.006 B = 
0.006{"arousal":"high","goal":"goalValence","state":"amazing","valence":1}{"arousal":"high","goal":"goalValence","state":"ok","valence":-1} A = 0.035 B = 0.035{"arousal":"high","goal":"goalValence","state":"ok","valence":-1} A = 0.035 B = 0.035{"arousal":"high","goal":"goalValence","state":"ok","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalValence","state":"ok","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalValence","state":"terrible","valence":-1} A = 0.012 B = 0.012{"arousal":"high","goal":"goalValence","state":"terrible","valence":-1} A = 0.012 B = 0.012{"arousal":"high","goal":"goalValence","state":"terrible","valence":-1}{"arousal":"high","goal":"goalValence","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalValence","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalArousal","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalArousal","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalArousal","state":"amazing","valence":1} A = 0.009 B = 0.009{"arousal":"low","goal":"goalArousal","state":"amazing","valence":1} A = 0.009 B = 0.009{"arousal":"low","goal":"goalArousal","state":"amazing","valence":1}{"arousal":"low","goal":"goalArousal","state":"ok","valence":-1} A = 0.043 B = 0.043{"arousal":"low","goal":"goalArousal","state":"ok","valence":-1} A = 0.043 B = 0.043{"arousal":"low","goal":"goalArousal","state":"ok","valence":1} A = 0.043 B = 0.043{"arousal":"low","goal":"goalArousal","state":"ok","valence":1} A = 0.043 B = 0.043{"arousal":"low","goal":"goalArousal","state":"terrible","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalArousal","state":"terrible","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalArousal","state":"terrible","valence":-1}{"arousal":"low","goal":"goalArousal","state":"terrible","valence":1} A = 0.000 B = 
0.000{"arousal":"low","goal":"goalArousal","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"terrible","valence":-1} A = 0.002 B = 0.002{"arousal":"low","goal":"goalState","state":"terrible","valence":-1} A = 0.002 B = 0.002{"arousal":"low","goal":"goalState","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"terrible","valence":1}{"arousal":"low","goal":"goalValence","state":"amazing","valence":-1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"amazing","valence":-1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"amazing","valence":1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"amazing","valence":1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"ok","valence":-1} A = 0.312 B = 0.312{"arousal":"low","goal":"goalValence","state":"ok","valence":-1} A = 0.312 B = 0.312{"arousal":"low","goal":"goalValence","state":"ok","valence":-1}{"arousal":"low","goal":"goalValence","state":"ok","valence":1} A = 0.003 B = 0.003{"arousal":"low","goal":"goalValence","state":"ok","valence":1} A = 0.003 B = 0.003{"arousal":"low","goal":"goalValence","state":"terrible","valence":-1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"terrible","valence":-1} A = 0.001 B = 0.001{"arousal":"low","goal":"goalValence","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalValence","state":"terrible","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalValence","state":"terrible","valence":1}{"arousal":"high","goal":"goalState","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"amazing","valence":1} A = 0.000 B = 
0.000{"arousal":"high","goal":"goalState","state":"amazing","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"ok","valence":-1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"ok","valence":-1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"ok","valence":-1}{"arousal":"high","goal":"goalState","state":"ok","valence":1} A = 0.000 B = 0.000{"arousal":"high","goal":"goalState","state":"ok","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"amazing","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"amazing","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"amazing","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"amazing","valence":1}{"arousal":"low","goal":"goalState","state":"ok","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"ok","valence":-1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"ok","valence":1} A = 0.000 B = 0.000{"arousal":"low","goal":"goalState","state":"ok","valence":1} A = 0.000 B = 0.000
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 1/2 solvers · d=[0.469, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-lxz-chinese-scope / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/lxz-chinese-scope.md
given

World states are 0, 1, 2 (number of rabbits fed), each equally likely. Utterances are 'null' and 'not-two', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse', each equally likely. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means no rabbits were fed (state = 0). 'null' is always true. The QUD function maps 'all red?' to whether state = 2, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener samples a state uniformly, conditions on the utterance being true of that state under the given scope, then returns the projection of the state under the given QUD. A pragmatic speaker at level 1 weights utterances by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 samples utterances from the utterance prior and weights them by the log-probability the pragmatic listener assigns to the world state, with no cost term.

model

A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.

query

The distribution over utterances ('null' or 'not-two') chosen by the second-level pragmatic speaker for world state 1 — one rabbit was fed.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "null",
    "not-two"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Here is the code for English Expt 1 (surface scope)
2//different qud prior, different state prior, access to alternative utterances
3
4// Here is the code for the quantifier scope model
5
6// possible utterances
7var utterances = ["null","not-two"];
8
9var utterancePrior = function() {
10 categorical({vs:["null","not-two"],ps:[1,1]})
11}
12
13var cost = function(utterance) {
14 return 1
15}
16
17// possible world states
18var states = [0,1,2];
19var statePrior = function() {
20 categorical({vs:[0,1,2],ps:[1,1,1]})
21}
22
23// possible scopes
24var scopePrior = function(){
25 return categorical({vs:["surface", "inverse"],ps:[1,1]})
26}
27
28var meaning = function(utterance, state, scope) {
29 //if utterance == none:
30 //return state==0
31 //else:
32 //elif utternace == nottwo:
33 //if scope == surface:
34 //return state == 0 / state==1
35 //else:
36 //return state == 0
37 //else:
38 //return true
39
40 return utterance == "not-two" ?
41 scope == "surface" ? (state == 0 | state ==1):
42 state == 0 :
43 true;
44};
45
46
47// QUDs
48var QUDs = ["how many?","all red?","none red?"];
49var QUDPrior = function() {
50 uniformDraw(QUDs);
51}
52var QUDFun = function(QUD,state) {
53 QUD == "all red?" ? state == 2 :
54 QUD == "none red?" ? state == 0 :
55 state;
56};
57
58// Literal listener (L0)
59var literalListener = cache(function(utterance,scope,QUD) {
60 Infer({model: function(){
61 var state = uniformDraw(states);
62 var qState = QUDFun(QUD,state)
63 condition(meaning(utterance,state,scope));
64 return qState;
65 }});
66});
67
68var alpha = 1
69
70// Speaker (S)
71var speaker = cache(function(scope, state, QUD) {
72 return Infer({model: function(){
73 var utterance = utterancePrior()
74 var qState = QUDFun(QUD, state)
75 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)
76 - cost(utterance)))
77 return utterance
78 }})
79})
80
81// Pragmatic listener (L1)
82var pragmaticListener = cache(function(utterance) {
83 Infer({model: function(){
84 var state = statePrior();
85 var scope = scopePrior();
86 var QUD = QUDPrior();
87 observe(speaker(scope,state,QUD),utterance);
88 return state
89 }});
90});
91
92// Pragmatic speaker (S2)
93var pragmaticSpeaker = cache(function(state) {
94 Infer({model: function(){
95 var utterance = utterancePrior();
96 factor(pragmaticListener(utterance).score(state))
97 return utterance
98 }})
99})
100
101
102
103// A speaker decides whether to endorse the ambiguous utterance as a
104// description of the not-all world state
105viz.table(pragmaticSpeaker(1))
106viz(pragmaticSpeaker(1))
107//literalListener("surface", 2, "all red?")
108
109var ANSWER = (pragmaticSpeaker(1));
realization0.000
python
1# RSA quantifier-scope model (English Expt 1, surface scope).
2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration.
3# A level's only latent sample site is enumerated and its marginal is read with
4# TraceEnum_ELBO.compute_marginals; the finished distribution is then fed into
5# the next level as a fixed pyro.factor score table. No inference is ever run
6# inside another level's active trace, and every sample site has a name unique
7# to its level, so there is no site collision.
8
9utterances = ["null", "not-two"]
10states = [0, 1, 2]
11scopes = ["surface", "inverse"]
12QUDs = ["how many?", "all red?", "none red?"]
13alpha = 1.0
14
15state_logits = torch.zeros(len(states)) # ps:[1,1,1]
16utt_logits = torch.zeros(len(utterances)) # ps:[1,1]
17scope_logits = torch.zeros(len(scopes)) # ps:[1,1]
18qud_logits = torch.zeros(len(QUDs)) # uniform
19
20
21def cost(_u):
22 return 1.0
23
24
25def meaning(utterance, state, scope):
26 if utterance == "not-two":
27 if scope == "surface":
28 return (state == 0) or (state == 1)
29 return state == 0
30 return True
31
32
33def qud_fun(qud, state):
34 if qud == "all red?":
35 return ("bool", state == 2)
36 if qud == "none red?":
37 return ("bool", state == 0)
38 return ("state", state)
39
40
41def enum_marginal(model, site):
42 # Run exact discrete enumeration over a single-latent model and return its
43 # marginal as {support_index: prob}.
44 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
45 model, lambda: None
46 )[site]
47 sup = marg.enumerate_support()
48 probs = marg.log_prob(sup).exp()
49 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}
50
51
52# ---- Literal listener L0: posterior over qState given (utterance, scope, QUD) ----
53L0_cache = {}
54
55
56def literal_listener(utterance, scope, qud):
57 key = (utterance, scope, qud)
58 if key in L0_cache:
59 return L0_cache[key]
60 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)
61
62 @pyro.infer.config_enumerate
63 def model():
64 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))
65 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
66 return s
67
68 state_post = enum_marginal(model, "L0_state")
69 # push the inferred state posterior forward through the QUD projection
70 d = defaultdict(float)
71 for si, p in state_post.items():
72 d[qud_fun(qud, states[si])] += p
73 L0_cache[key] = dict(d)
74 return L0_cache[key]
75
76
77def l0_score(utterance, scope, qud, qstate):
78 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)
79 return math.log(p) if p > 0 else float("-inf")
80
81
82# ---- Speaker S: distribution over utterances given (scope, state, QUD) ----
83S_cache = {}
84
85
86def speaker(scope, state, qud):
87 key = (scope, state, qud)
88 if key in S_cache:
89 return S_cache[key]
90 qstate = qud_fun(qud, state)
91 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])
92
93 @pyro.infer.config_enumerate
94 def model():
95 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))
96 pyro.factor("S_f", sc[u])
97 return u
98
99 post = enum_marginal(model, "S_utt")
100 S_cache[key] = {utterances[i]: p for i, p in post.items()}
101 return S_cache[key]
102
103
104def speaker_score(scope, state, qud, utterance):
105 p = speaker(scope, state, qud).get(utterance, 0.0)
106 return math.log(p) if p > 0 else float("-inf")
107
108
109# ---- Pragmatic listener L1: posterior over state given utterance ----
110L1_cache = {}
111
112
113def pragmatic_listener(utterance):
114 if utterance in L1_cache:
115 return L1_cache[utterance]
116 # finished speaker scores for every (state, scope, qud), evaluated up front
117 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))
118 for si, st in enumerate(states):
119 for sci, scp in enumerate(scopes):
120 for qi, qd in enumerate(QUDs):
121 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)
122
123 @pyro.infer.config_enumerate
124 def model():
125 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))
126 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))
127 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))
128 pyro.factor("L1_obs", tbl[s, sc, q])
129 return s
130
131 post = enum_marginal(model, "L1_state")
132 L1_cache[utterance] = {states[i]: p for i, p in post.items()}
133 return L1_cache[utterance]
134
135
136def pragmatic_listener_score(utterance, state):
137 p = pragmatic_listener(utterance).get(state, 0.0)
138 return math.log(p) if p > 0 else float("-inf")
139
140
141# ---- Pragmatic speaker S2: distribution over utterances given state ----
142def pragmatic_speaker(state):
143 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])
144
145 @pyro.infer.config_enumerate
146 def model():
147 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))
148 pyro.factor("S2_f", sc[u])
149 return u
150
151 post = enum_marginal(model, "S2_utt")
152 return {utterances[i]: p for i, p in post.items()}
153
154
155ANSWER = pragmatic_speaker(1)
156
02answer overlay — webppl vs pyrodist/finite
webppl pyro2 bins
00.250.250.510.51not-two A = 0.508 B = 0.508not-two A = 0.508 B = 0.5080.510.51not-twonull A = 0.492 B = 0.492null A = 0.492 B = 0.4920.490.49null
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-lxz-chinese-scope / atom-2
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/lxz-chinese-scope.md
given

World states are 0, 1, 2, 3, 4 (number of rabbits fed out of 4), each equally likely. Utterances are 'null' and 'not-two'; the utterance prior is categorical with weights [1, 10] — 'not-two' is ten times more likely than 'null'. Both utterances have uniform cost 1. Scopes are 'surface' and 'inverse', each equally likely. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means fewer than 3 rabbits were fed (state < 3). 'null' is always true. The QUD function maps 'all red?' to whether state = 4, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener samples a state uniformly, conditions on the utterance being true of that state under the given scope, then returns the projection of the state under the given QUD. A pragmatic speaker at level 1 samples utterances from the [1, 10] utterance prior and weights them by alpha times the log-probability the literal listener assigns to the QUD value minus the utterance cost. A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 also samples utterances from the [1, 10] utterance prior and weights them by the log-probability the pragmatic listener assigns to the world state, with no cost term.

model

A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.

query

The distribution over utterances ('null' or 'not-two') chosen by the second-level pragmatic speaker for world state 2 — two of four rabbits were fed.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "null",
    "not-two"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Here is the code for English Expt 2 (inverse scope)
2//different qud prior, different state prior, access to alternative utterances
3
4// Here is the code for the quantifier scope model
5
6// possible utterances
7var utterances = ["null","not-two"];
8
9var utterancePrior = function() {
10 categorical({vs:["null","not-two"],ps:[1,10]})
11}
12
13var cost = function(utterance) {
14 return utterance == "not-two" ? 1 :
15 1
16}
17
18// possible world states
19var states = [0,1,2,3,4];
20var statePrior = function() {
21 categorical({vs:[0,1,2,3,4],ps:[1,1,1,1,1]})
22}
23
24// possible scopes
25var scopePrior = function(){
26 return categorical({vs:["surface", "inverse"],ps:[1,1]})
27}
28
29// meaning function
30var meaning = function(utterance, state, scope) {
31 return utterance == "not-two" ?
32 scope == "surface" ? (state < 2):
33 (state < 3) :
34 true;
35};
36
37// QUDs
38var QUDs = ["how many?","all red?","none red?"];
39var QUDPrior = function() {
40 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})
41 //uniformDraw(QUDs);
42}
43
44var QUDFun = function(QUD,state) {
45 QUD == "all red?" ? state == 4 :
46 QUD == "none red?" ? state == 0 :
47 state;
48};
49
50// Literal listener (L0)
51var literalListener = cache(function(utterance,scope,QUD) {
52 Infer({model: function(){
53 var state = uniformDraw(states);
54 var qState = QUDFun(QUD,state)
55 condition(meaning(utterance,state,scope));
56 return qState;
57 }});
58});
59
60var alpha = 1
61
62// Speaker (S)
63var speaker = cache(function(scope, state, QUD) {
64 return Infer({model: function(){
65 var utterance = utterancePrior()
66 var qState = QUDFun(QUD, state)
67 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)
68 - cost(utterance)))
69 return utterance
70 }})
71})
72
73// Pragmatic listener (L1)
74var pragmaticListener = cache(function(utterance) {
75 Infer({model: function(){
76 var state = statePrior();
77 var scope = scopePrior();
78 var QUD = QUDPrior();
79 observe(speaker(scope,state,QUD),utterance);
80 return state
81 }});
82});
83
84// Pragmatic speaker (S2)
85var pragmaticSpeaker = cache(function(state) {
86 Infer({model: function(){
87 var utterance = utterancePrior();
88 factor(pragmaticListener(utterance).score(state))
89 return utterance
90 }})
91})
92
93// A speaker decides whether to endorse the ambiguous utterance as a
94// description of the not-all world state
95//viz.table(pragmaticSpeaker(0))
96viz.table(pragmaticSpeaker(2))
97//viz.table(pragmaticSpeaker(2))
98//literalListener("surface", 2, "all red?")
99
100var ANSWER = (pragmaticSpeaker(2));
realization0.000
python
1# RSA quantifier-scope model (English Expt 2, inverse scope).
2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration
3# with level-unique site names; a finished level's marginal is fed into the next
4# as a fixed pyro.factor score table. No inference inside an active trace.
5
6utterances = ["null", "not-two"]
7states = [0, 1, 2, 3, 4]
8scopes = ["surface", "inverse"]
9QUDs = ["how many?", "all red?", "none red?"]
10alpha = 1.0
11
12state_logits = torch.zeros(len(states)) # ps:[1,1,1,1,1]
13utt_logits = torch.log(torch.tensor([1.0, 10.0])) # ps:[1,10]
14scope_logits = torch.zeros(len(scopes)) # ps:[1,1]
15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]
16
17
18def cost(_u):
19 return 1.0
20
21
22def meaning(utterance, state, scope):
23 if utterance == "not-two":
24 if scope == "surface":
25 return state < 2
26 return state < 3
27 return True
28
29
30def qud_fun(qud, state):
31 if qud == "all red?":
32 return ("bool", state == 4)
33 if qud == "none red?":
34 return ("bool", state == 0)
35 return ("state", state)
36
37
38def enum_marginal(model, site):
39 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
40 model, lambda: None
41 )[site]
42 sup = marg.enumerate_support()
43 probs = marg.log_prob(sup).exp()
44 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}
45
46
47L0_cache = {}
48
49
50def literal_listener(utterance, scope, qud):
51 key = (utterance, scope, qud)
52 if key in L0_cache:
53 return L0_cache[key]
54 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)
55
56 @pyro.infer.config_enumerate
57 def model():
58 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))
59 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
60 return s
61
62 state_post = enum_marginal(model, "L0_state")
63 d = defaultdict(float)
64 for si, p in state_post.items():
65 d[qud_fun(qud, states[si])] += p
66 L0_cache[key] = dict(d)
67 return L0_cache[key]
68
69
70def l0_score(utterance, scope, qud, qstate):
71 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)
72 return math.log(p) if p > 0 else float("-inf")
73
74
75S_cache = {}
76
77
78def speaker(scope, state, qud):
79 key = (scope, state, qud)
80 if key in S_cache:
81 return S_cache[key]
82 qstate = qud_fun(qud, state)
83 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])
84
85 @pyro.infer.config_enumerate
86 def model():
87 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))
88 pyro.factor("S_f", sc[u])
89 return u
90
91 post = enum_marginal(model, "S_utt")
92 S_cache[key] = {utterances[i]: p for i, p in post.items()}
93 return S_cache[key]
94
95
96def speaker_score(scope, state, qud, utterance):
97 p = speaker(scope, state, qud).get(utterance, 0.0)
98 return math.log(p) if p > 0 else float("-inf")
99
100
101L1_cache = {}
102
103
104def pragmatic_listener(utterance):
105 if utterance in L1_cache:
106 return L1_cache[utterance]
107 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))
108 for si, st in enumerate(states):
109 for sci, scp in enumerate(scopes):
110 for qi, qd in enumerate(QUDs):
111 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)
112
113 @pyro.infer.config_enumerate
114 def model():
115 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))
116 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))
117 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))
118 pyro.factor("L1_obs", tbl[s, sc, q])
119 return s
120
121 post = enum_marginal(model, "L1_state")
122 L1_cache[utterance] = {states[i]: p for i, p in post.items()}
123 return L1_cache[utterance]
124
125
126def pragmatic_listener_score(utterance, state):
127 p = pragmatic_listener(utterance).get(state, 0.0)
128 return math.log(p) if p > 0 else float("-inf")
129
130
131def pragmatic_speaker(state):
132 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])
133
134 @pyro.infer.config_enumerate
135 def model():
136 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))
137 pyro.factor("S2_f", sc[u])
138 return u
139
140 post = enum_marginal(model, "S2_utt")
141 return {utterances[i]: p for i, p in post.items()}
142
143
144ANSWER = pragmatic_speaker(2)
145
02answer overlay — webppl vs pyrodist/finite
webppl pyro2 bins
00.460.460.930.93not-two A = 0.930 B = 0.930not-two A = 0.930 B = 0.9300.930.93not-twonull A = 0.070 B = 0.070null A = 0.070 B = 0.0700.070.07null
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-lxz-chinese-scope / atom-3
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/lxz-chinese-scope.md
given

World states are 0, 1, 2 (number of rabbits fed, out of 2), each equally likely. Utterances are 'null', 'not-two', and 'none', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse'; the scope prior strongly favors surface scope: categorical weights [100, 1]. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'none' is that no rabbits were fed (state = 0). The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means no rabbits were fed (state = 0). 'null' is always true. The QUD function maps 'all red?' to whether state = 2, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener draws a state uniformly, conditions on the utterance being true of that state under the given scope, and returns the QUD projection of the state (the QUD affects only what is returned, not the condition). A pragmatic speaker at level 1 weights utterances by alpha times (the log-probability the literal listener assigns to the QUD value minus the utterance cost). A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.

model

A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.

query

The distribution over utterances ('null', 'not-two', or 'none') chosen by the second-level pragmatic speaker for world state 1 — one rabbit was fed.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "null",
    "not-two",
    "none"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Here is the code of Chinese model for Expt 1.
2//different qud prior, different state prior, access to alternative utterances
3
4// Here is the code for the quantifier scope model
5
6// possible utterances
7var utterances = ["null","not-two","none"];
8
9var utterancePrior = function() {
10 categorical({vs:["null","not-two","none"],ps:[1,1,1]})
11}
12
13var cost = function(utterance) {
14 return utterance == "not-two" ? 1 :
15 utterance == 'none'? 1 :
16 1
17}
18
19// possible world states
20var states = [0,1,2];
21var statePrior = function() {
22 categorical({vs:[0,1,2],ps:[1,1,1]})
23}
24
25// possible scopes
26var scopePrior = function(){
27 return categorical({vs:["surface", "inverse"],ps:[100,1]})
28}
29
30// meaning function
31var meaning = function(utterance, state, scope) {
32 return utterance == "none"? state == 0:
33 utterance == "not-two" ?
34 scope == "surface" ? (state < 2):
35 (state == 0) :
36 true;
37};
38
39// QUDs
40var QUDs = ["how many?","all red?","none red?"];
41var QUDPrior = function() {
42 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})
43 //uniformDraw(QUDs);
44}
45var QUDFun = function(QUD,state) {
46 QUD == "all red?" ? state == 2 :
47 QUD == "none red?" ? state == 0 :
48 state;
49};
50
51// Literal listener (L0)
52var literalListener = cache(function(utterance,scope,QUD) {
53 Infer({model: function(){
54 var state = uniformDraw(states);
55 var qState = QUDFun(QUD,state)
56 condition(meaning(utterance,state,scope));
57 return qState;
58 }});
59});
60
61var alpha = 1
62
63// Speaker (S)
64var speaker = cache(function(scope, state, QUD) {
65 return Infer({model: function(){
66 var utterance = utterancePrior()
67 var qState = QUDFun(QUD, state)
68 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)
69 - cost(utterance)))
70 return utterance
71 }})
72})
73
74// Pragmatic listener (L1)
75var pragmaticListener = cache(function(utterance) {
76 Infer({model: function(){
77 var state = statePrior();
78 var scope = scopePrior();
79 var QUD = QUDPrior();
80 observe(speaker(scope,state,QUD),utterance);
81 return state
82 }});
83});
84
85// Pragmatic speaker (S2)
86var pragmaticSpeaker = cache(function(state) {
87 Infer({model: function(){
88 var utterance = utterancePrior();
89 factor(pragmaticListener(utterance).score(state))
90 return utterance
91 }})
92})
93
94// A speaker decides whether to endorse the ambiguous utterance as a
95// description of the not-all world state
96//viz.table(pragmaticSpeaker(0))
97//viz.table(pragmaticSpeaker(1))
98viz.table(pragmaticSpeaker(1))
99//literalListener("surface", 2, "all red?")
100
101var ANSWER = (pragmaticSpeaker(1));
realization0.000
python
1# RSA quantifier-scope model (Chinese Expt 1).
2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration
3# with level-unique site names; a finished level's marginal is fed into the next
4# as a fixed pyro.factor score table. No inference inside an active trace.
5
6utterances = ["null", "not-two", "none"]
7states = [0, 1, 2]
8scopes = ["surface", "inverse"]
9QUDs = ["how many?", "all red?", "none red?"]
10alpha = 1.0
11
12state_logits = torch.zeros(len(states)) # ps:[1,1,1]
13utt_logits = torch.zeros(len(utterances)) # ps:[1,1,1]
14scope_logits = torch.log(torch.tensor([100.0, 1.0])) # ps:[100,1]
15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]
16
17
18def cost(_u):
19 return 1.0
20
21
22def meaning(utterance, state, scope):
23 if utterance == "none":
24 return state == 0
25 if utterance == "not-two":
26 if scope == "surface":
27 return state < 2
28 return state == 0
29 return True
30
31
32def qud_fun(qud, state):
33 if qud == "all red?":
34 return ("bool", state == 2)
35 if qud == "none red?":
36 return ("bool", state == 0)
37 return ("state", state)
38
39
40def enum_marginal(model, site):
41 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
42 model, lambda: None
43 )[site]
44 sup = marg.enumerate_support()
45 probs = marg.log_prob(sup).exp()
46 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}
47
48
49L0_cache = {}
50
51
52def literal_listener(utterance, scope, qud):
53 key = (utterance, scope, qud)
54 if key in L0_cache:
55 return L0_cache[key]
56 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)
57
58 @pyro.infer.config_enumerate
59 def model():
60 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))
61 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
62 return s
63
64 state_post = enum_marginal(model, "L0_state")
65 d = defaultdict(float)
66 for si, p in state_post.items():
67 d[qud_fun(qud, states[si])] += p
68 L0_cache[key] = dict(d)
69 return L0_cache[key]
70
71
72def l0_score(utterance, scope, qud, qstate):
73 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)
74 return math.log(p) if p > 0 else float("-inf")
75
76
77S_cache = {}
78
79
80def speaker(scope, state, qud):
81 key = (scope, state, qud)
82 if key in S_cache:
83 return S_cache[key]
84 qstate = qud_fun(qud, state)
85 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])
86
87 @pyro.infer.config_enumerate
88 def model():
89 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))
90 pyro.factor("S_f", sc[u])
91 return u
92
93 post = enum_marginal(model, "S_utt")
94 S_cache[key] = {utterances[i]: p for i, p in post.items()}
95 return S_cache[key]
96
97
98def speaker_score(scope, state, qud, utterance):
99 p = speaker(scope, state, qud).get(utterance, 0.0)
100 return math.log(p) if p > 0 else float("-inf")
101
102
103L1_cache = {}
104
105
106def pragmatic_listener(utterance):
107 if utterance in L1_cache:
108 return L1_cache[utterance]
109 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))
110 for si, st in enumerate(states):
111 for sci, scp in enumerate(scopes):
112 for qi, qd in enumerate(QUDs):
113 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)
114
115 @pyro.infer.config_enumerate
116 def model():
117 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))
118 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))
119 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))
120 pyro.factor("L1_obs", tbl[s, sc, q])
121 return s
122
123 post = enum_marginal(model, "L1_state")
124 L1_cache[utterance] = {states[i]: p for i, p in post.items()}
125 return L1_cache[utterance]
126
127
128def pragmatic_listener_score(utterance, state):
129 p = pragmatic_listener(utterance).get(state, 0.0)
130 return math.log(p) if p > 0 else float("-inf")
131
132
133def pragmatic_speaker(state):
134 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])
135
136 @pyro.infer.config_enumerate
137 def model():
138 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))
139 pyro.factor("S2_f", sc[u])
140 return u
141
142 post = enum_marginal(model, "S2_utt")
143 return {utterances[i]: p for i, p in post.items()}
144
145
146ANSWER = pragmatic_speaker(1)
147
02answer overlay — webppl vs pyrodist/finite
webppl pyro3 bins
00.260.260.510.51none A = 0.206 B = 0.206none A = 0.206 B = 0.2060.210.21nonenot-two A = 0.513 B = 0.513not-two A = 0.513 B = 0.5130.510.51not-twonull A = 0.281 B = 0.281null A = 0.281 B = 0.2810.280.28null
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-lxz-chinese-scope / atom-4
answer dist/finite solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/lxz-chinese-scope.md
given

World states are 0, 1, 2, 3, 4 (number of rabbits fed out of 4), each equally likely. Utterances are 'null', 'not-two', and 'none', each equally likely a priori with uniform cost 1. Scopes are 'surface' and 'inverse'; the scope prior strongly favors surface scope: categorical weights [100, 1]. QUDs are 'how many?', 'all red?', and 'none red?', each equally likely. Semantic rationality parameter alpha = 1. The literal meaning of 'none' is that no rabbits were fed (state = 0). The literal meaning of 'not-two' under surface scope is that fewer than 2 rabbits were fed (state < 2); under inverse scope it means fewer than 3 rabbits were fed (state < 3). 'null' is always true. The QUD function maps 'all red?' to whether state = 4, 'none red?' to whether state = 0, and 'how many?' to the state itself. A literal listener draws a state uniformly, conditions on the utterance being true of that state under the given scope, and returns the QUD projection of the state (the QUD affects only what is returned, not the condition). A pragmatic speaker at level 1 weights utterances by alpha times (the log-probability the literal listener assigns to the QUD value minus the utterance cost). A pragmatic listener at level 1 inverts the speaker to infer the world state, marginalizing over scope and QUD. A pragmatic speaker at level 2 weights utterances by the log-probability the pragmatic listener assigns to the world state.

model

A four-level RSA hierarchy where the literal listener interprets utterances relative to a scope and a QUD, the first-level speaker communicates through a utility-weighted choice, the first-level pragmatic listener inverts the speaker marginalizing over scope and QUD, and the second-level pragmatic speaker further weighs utterances by how informative they are to the pragmatic listener.

query

The distribution over utterances ('null', 'not-two', or 'none') chosen by the second-level pragmatic speaker for world state 2 — two of four rabbits were fed.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "null",
    "not-two",
    "none"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Here is the code of Chinese model for Expt 2.
2//different qud prior, different state prior, access to alternative utterances
3
4// Here is the code for the quantifier scope model
5
6// possible utterances
7var utterances = ["null","not-two","none"];
8
9var utterancePrior = function() {
10 categorical({vs:["null","not-two","none"],ps:[1,1,1]})
11}
12
13var cost = function(utterance) {
14 return utterance == "not-two" ? 1 :
15 utterance == 'none'? 1 :
16 1
17}
18
19// possible world states
20var states = [0,1,2,3,4];
21var statePrior = function() {
22 categorical({vs:[0,1,2,3,4],ps:[1,1,1,1,1]})
23}
24
25// possible scopes
26var scopePrior = function(){
27 return categorical({vs:["surface", "inverse"],ps:[100,1]})
28}
29
30// meaning function
31var meaning = function(utterance, state, scope) {
32 return utterance == "none"? state == 0:
33 utterance == "not-two" ?
34 scope == "surface" ? (state < 2):
35 (state < 3) :
36 true;
37};
38
39// QUDs
40var QUDs = ["how many?","all red?","none red?"];
41var QUDPrior = function() {
42 categorical({vs:["how many?","all red?","none red?"],ps:[1,1,1]})
43 //uniformDraw(QUDs);
44}
45var QUDFun = function(QUD,state) {
46 QUD == "all red?" ? state == 4 :
47 QUD == "none red?" ? state == 0 :
48 state;
49};
50
51// Literal listener (L0)
52var literalListener = cache(function(utterance,scope,QUD) {
53 Infer({model: function(){
54 var state = uniformDraw(states);
55 var qState = QUDFun(QUD,state)
56 condition(meaning(utterance,state,scope));
57 return qState;
58 }});
59});
60
61var alpha = 1
62
63// Speaker (S)
64var speaker = cache(function(scope, state, QUD) {
65 return Infer({model: function(){
66 var utterance = utterancePrior()
67 var qState = QUDFun(QUD, state)
68 factor(alpha*(literalListener(utterance,scope,QUD).score(qState)
69 - cost(utterance)))
70 return utterance
71 }})
72})
73
74// Pragmatic listener (L1)
75var pragmaticListener = cache(function(utterance) {
76 Infer({model: function(){
77 var state = statePrior();
78 var scope = scopePrior();
79 var QUD = QUDPrior();
80 observe(speaker(scope,state,QUD),utterance);
81 return state
82 }});
83});
84
85// Pragmatic speaker (S2)
86var pragmaticSpeaker = cache(function(state) {
87 Infer({model: function(){
88 var utterance = utterancePrior();
89 factor(pragmaticListener(utterance).score(state))
90 return utterance
91 }})
92})
93
94// A speaker decides whether to endorse the ambiguous utterance as a
95// description of the not-all world state
96//viz.table(pragmaticSpeaker(0))
97viz.table(pragmaticSpeaker(2))
98//viz.table(pragmaticSpeaker(2))
99//literalListener("surface", 2, "all red?")
100
101var ANSWER = (pragmaticSpeaker(2));
realization0.000
python
1# RSA quantifier-scope model (Chinese Expt 2).
2# Each RSA level is a SEPARATE, fully-computed and memoized Pyro enumeration
3# with level-unique site names; a finished level's marginal is fed into the next
4# as a fixed pyro.factor score table. No inference inside an active trace.
5
6utterances = ["null", "not-two", "none"]
7states = [0, 1, 2, 3, 4]
8scopes = ["surface", "inverse"]
9QUDs = ["how many?", "all red?", "none red?"]
10alpha = 1.0
11
12state_logits = torch.zeros(len(states)) # ps:[1,1,1,1,1]
13utt_logits = torch.zeros(len(utterances)) # ps:[1,1,1]
14scope_logits = torch.log(torch.tensor([100.0, 1.0])) # ps:[100,1]
15qud_logits = torch.zeros(len(QUDs)) # ps:[1,1,1]
16
17
18def cost(_u):
19 return 1.0
20
21
22def meaning(utterance, state, scope):
23 if utterance == "none":
24 return state == 0
25 if utterance == "not-two":
26 if scope == "surface":
27 return state < 2
28 return state < 3
29 return True
30
31
32def qud_fun(qud, state):
33 if qud == "all red?":
34 return ("bool", state == 4)
35 if qud == "none red?":
36 return ("bool", state == 0)
37 return ("state", state)
38
39
40def enum_marginal(model, site):
41 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
42 model, lambda: None
43 )[site]
44 sup = marg.enumerate_support()
45 probs = marg.log_prob(sup).exp()
46 return {int(sup[i].item()): float(probs[i].item()) for i in range(sup.shape[0])}
47
48
49L0_cache = {}
50
51
52def literal_listener(utterance, scope, qud):
53 key = (utterance, scope, qud)
54 if key in L0_cache:
55 return L0_cache[key]
56 ok = torch.tensor([meaning(utterance, st, scope) for st in states], dtype=torch.bool)
57
58 @pyro.infer.config_enumerate
59 def model():
60 s = pyro.sample("L0_state", dist.Categorical(logits=state_logits))
61 pyro.factor("L0_ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
62 return s
63
64 state_post = enum_marginal(model, "L0_state")
65 d = defaultdict(float)
66 for si, p in state_post.items():
67 d[qud_fun(qud, states[si])] += p
68 L0_cache[key] = dict(d)
69 return L0_cache[key]
70
71
72def l0_score(utterance, scope, qud, qstate):
73 p = literal_listener(utterance, scope, qud).get(qstate, 0.0)
74 return math.log(p) if p > 0 else float("-inf")
75
76
77S_cache = {}
78
79
80def speaker(scope, state, qud):
81 key = (scope, state, qud)
82 if key in S_cache:
83 return S_cache[key]
84 qstate = qud_fun(qud, state)
85 sc = torch.tensor([alpha * (l0_score(u, scope, qud, qstate) - cost(u)) for u in utterances])
86
87 @pyro.infer.config_enumerate
88 def model():
89 u = pyro.sample("S_utt", dist.Categorical(logits=utt_logits))
90 pyro.factor("S_f", sc[u])
91 return u
92
93 post = enum_marginal(model, "S_utt")
94 S_cache[key] = {utterances[i]: p for i, p in post.items()}
95 return S_cache[key]
96
97
98def speaker_score(scope, state, qud, utterance):
99 p = speaker(scope, state, qud).get(utterance, 0.0)
100 return math.log(p) if p > 0 else float("-inf")
101
102
103L1_cache = {}
104
105
106def pragmatic_listener(utterance):
107 if utterance in L1_cache:
108 return L1_cache[utterance]
109 tbl = torch.full((len(states), len(scopes), len(QUDs)), float("-inf"))
110 for si, st in enumerate(states):
111 for sci, scp in enumerate(scopes):
112 for qi, qd in enumerate(QUDs):
113 tbl[si, sci, qi] = speaker_score(scp, st, qd, utterance)
114
115 @pyro.infer.config_enumerate
116 def model():
117 s = pyro.sample("L1_state", dist.Categorical(logits=state_logits))
118 sc = pyro.sample("L1_scope", dist.Categorical(logits=scope_logits))
119 q = pyro.sample("L1_qud", dist.Categorical(logits=qud_logits))
120 pyro.factor("L1_obs", tbl[s, sc, q])
121 return s
122
123 post = enum_marginal(model, "L1_state")
124 L1_cache[utterance] = {states[i]: p for i, p in post.items()}
125 return L1_cache[utterance]
126
127
128def pragmatic_listener_score(utterance, state):
129 p = pragmatic_listener(utterance).get(state, 0.0)
130 return math.log(p) if p > 0 else float("-inf")
131
132
133def pragmatic_speaker(state):
134 sc = torch.tensor([pragmatic_listener_score(u, state) for u in utterances])
135
136 @pyro.infer.config_enumerate
137 def model():
138 u = pyro.sample("S2_utt", dist.Categorical(logits=utt_logits))
139 pyro.factor("S2_f", sc[u])
140 return u
141
142 post = enum_marginal(model, "S2_utt")
143 return {utterances[i]: p for i, p in post.items()}
144
145
146ANSWER = pragmatic_speaker(2)
147
02answer overlay — webppl vs pyrodist/finite
webppl pyro3 bins
00.210.210.430.43none A = 0.251 B = 0.251none A = 0.251 B = 0.2510.250.25nonenot-two A = 0.321 B = 0.321not-two A = 0.321 B = 0.3210.320.32not-twonull A = 0.428 B = 0.428null A = 0.428 B = 0.4280.430.43null
03 verification
checkstatusevidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-overinf / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/overinf.md
given

There are 3 objects in the scene: one big blue object, one small blue object, and one big red object. Utterances are 'big', 'small', 'blue', 'red', 'big_blue', 'small_blue', and 'big_red', each equally likely a priori. The meaning function uses soft semantics parameterized by size_semvalue = 0.8 and color_semvalue = 0.99: a single size word applied to an object returns 0.8 if the size matches, else 0.2; a single color word applied to an object returns 0.99 if the color matches, else 0.01. A two-word utterance of the form SIZE_COLOR returns the product of the corresponding size and color semantic values. Cost is defined by two parameters: size_cost = 0 and color_cost = 0, giving all single-word utterances cost 0 and all two-word utterances cost 0. The literal listener scores each state by adding the semantic value directly as a log-weight (i.e., state s receives unnormalized log-score += meaning(utt, s), so its unnormalized probability is proportional to exp(meaning(utt, s)), not to meaning(utt, s) itself). The pragmatic speaker uses alpha = 30 and costWeight = 1: it factors by alpha times the literal listener's log-probability of the target state minus costWeight times the utterance cost.

model

A two-level RSA reference game model with relaxed (graded) semantics. The literal listener interprets utterances as soft evidence about the object and infers which object is being described. The pragmatic speaker selects utterances by soft-max over the utility of each utterance, where utility is the listener's log-probability of the intended object minus the utterance cost.

query

The distribution over utterances chosen by the pragmatic speaker when communicating the small blue object — {size: 'small', color: 'blue'}.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "big",
    "small",
    "blue",
    "red",
    "big_blue",
    "small_blue",
    "big_red"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1var alpha = 30
2var costWeight = 1
3var size_semvalue = 0.8
4var color_semvalue = 0.99
5var size_cost = 0
6var color_cost = 0
7
8var states = [
9 {size: "big", color: "blue"},
10 {size: "small", color: "blue"},
11 {size: "big", color: "red"}]
12
13var utterances = ["big", "small", "blue", "red", "big_blue", "small_blue", "big_red"]
14
15var colors = ["red", "blue"]
16var sizes = ["big", "small"]
17
18var statePrior = function() {
19 return uniformDraw(states)
20};
21
22var utterancePrior = function() {
23 return uniformDraw(utterances)
24};
25
26// assumes that 2-word utterances consist of SIZE_COLOR, in that order
27var meaning = function(utt, obj) {
28 var splitWords = utt.split('_')
29 if (splitWords.length == 1) {
30 var word = splitWords[0]
31 if(_.includes(colors, word))
32 return word == obj.color ? color_semvalue : 1-color_semvalue;
33 else if (_.includes(sizes, word))
34 return word == obj.size ? size_semvalue : 1-size_semvalue;
35 } else if (splitWords.length == 2) {
36 var size_value = splitWords[0] == obj.size ? size_semvalue : 1-size_semvalue;
37 var color_value = splitWords[1] == obj.color ? color_semvalue : 1-color_semvalue;
38 return size_value*color_value
39 } else
40 console.error("bad utterance length: "+splitWords.length)
41};
42
43var cost = {
44 big: size_cost,
45 small: size_cost,
46 blue: color_cost,
47 red: color_cost,
48 big_blue: size_cost+color_cost,
49 small_blue: size_cost+color_cost,
50 big_red: size_cost+color_cost
51}
52
53// literal listener
54var literalListener = cache(function(utt) {
55 return Infer({method:"enumerate"},
56 function(){
57 var state = statePrior()
58 factor(meaning(utt,state))
59 return state
60 })
61});
62
63// pragmatic speaker
64var speaker = cache(function(state) {
65 return Infer({method:"enumerate"},
66 function(){
67 var utt = utterancePrior()
68 factor(alpha * literalListener(utt).score(state) - costWeight * cost[utt])
69 return utt
70 })
71});
72
73
74display("speaker who wants to communicate big blue object:")
75viz.table(speaker({size: "big", color: "blue"}))
76
77display("speaker who wants to communicate big red object:")
78viz.table(speaker({size: "big", color: "red"}))
79
80display("speaker who wants to communicate small blue object:")
81viz.table(speaker({size: "small", color: "blue"}))
82
83display("literal listener who observes 'big':")
84viz.table(literalListener("big"))
85
86display("literal listener who observes 'small':")
87viz.table(literalListener("small"))
88
89display("literal listener who observes 'blue':")
90viz.table(literalListener("blue"))
91
92display("literal listener who observes 'red':")
93viz.table(literalListener("red"))
94
95display("literal listener who observes 'big blue':")
96viz.table(literalListener("big_blue"))
97
98display("literal listener who observes 'big red':")
99viz.table(literalListener("big_red"))
100
101display("literal listener who observes 'small blue':")
102viz.table(literalListener("small_blue"))
103
104var ANSWER = (speaker({size: 'small', color: 'blue'}));
realization0.000
python
1alpha = 30.0
2costWeight = 1.0
3size_semvalue = 0.8
4color_semvalue = 0.99
5size_cost = 0.0
6color_cost = 0.0
7
8states = [
9 {"size": "big", "color": "blue"},
10 {"size": "small", "color": "blue"},
11 {"size": "big", "color": "red"},
12]
13utterances = ["big", "small", "blue", "red", "big_blue", "small_blue", "big_red"]
14colors = ["red", "blue"]
15sizes = ["big", "small"]
16
17def meaning(utt, obj):
18 parts = utt.split("_")
19 if len(parts) == 1:
20 word = parts[0]
21 if word in colors:
22 return color_semvalue if word == obj["color"] else 1 - color_semvalue
23 else:
24 return size_semvalue if word == obj["size"] else 1 - size_semvalue
25 else:
26 size_value = size_semvalue if parts[0] == obj["size"] else 1 - size_semvalue
27 color_value = color_semvalue if parts[1] == obj["color"] else 1 - color_semvalue
28 return size_value * color_value
29
30cost = {
31 "big": size_cost, "small": size_cost, "blue": color_cost, "red": color_cost,
32 "big_blue": size_cost + color_cost, "small_blue": size_cost + color_cost,
33 "big_red": size_cost + color_cost,
34}
35
36# Literal listener: uniform prior over states, factor(meaning(utt, state)) added as
37# a log-weight. Inference is exact enumeration via Pyro's compute_marginals.
38_L0_cache = {}
39def literalListener(utt):
40 if utt in _L0_cache:
41 return _L0_cache[utt]
42 _mvals = torch.tensor([meaning(utt, s) for s in states])
43 @pyro.infer.config_enumerate
44 def m():
45 idx = pyro.sample("state", dist.Categorical(torch.ones(len(states))))
46 pyro.factor("sem", _mvals[idx])
47 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
48 _L0_cache[utt] = marg["state"]
49 return _L0_cache[utt]
50
51# Pragmatic speaker: utterance ~ exp(alpha * L0.score(state) - costWeight * cost),
52# again by exact enumeration over utterances inverting the literal listener.
53def speaker(state):
54 sidx = states.index(state)
55 L0s = [literalListener(u) for u in utterances]
56 _scores = torch.tensor([
57 alpha * L0s[i].log_prob(torch.tensor(sidx)).item() - costWeight * cost[utterances[i]]
58 for i in range(len(utterances))
59 ])
60 @pyro.infer.config_enumerate
61 def m():
62 uidx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances))))
63 pyro.factor("util", _scores[uidx])
64 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(m, lambda: None)
65 return marg["utt"]
66
67_sp = speaker({"size": "small", "color": "blue"})
68ANSWER = {utterances[i]: _sp.probs[i].item() for i in range(len(utterances))}
69
02answer overlay — webppl vs pyrodist/finite
webppl pyro7 bins
00.390.390.790.79big A = 0.000 B = 0.000big A = 0.000 B = 0.000bigbig_blue A = 0.000 B = 0.000big_blue A = 0.000 B = 0.000big_bluebig_red A = 0.000 B = 0.000big_red A = 0.000 B = 0.000big_redblue A = 0.005 B = 0.005blue A = 0.005 B = 0.005bluered A = 0.000 B = 0.000red A = 0.000 B = 0.000redsmall A = 0.206 B = 0.206small A = 0.206 B = 0.2060.210.21smallsmall_blue A = 0.789 B = 0.789small_blue A = 0.789 B = 0.7890.790.79small_blue
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-prior-inference / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/prior-inference.md
given

Three objects of reference: a blue square, a blue circle, and a green square. Four utterances: 'blue', 'green', 'square', 'circle'. Five preference contexts: 'blue_things', 'green_things', 'squares', 'circles', 'none'. Each preference context assigns categorical salience weights over the three objects (in the order blue square, blue circle, green square): blue_things [4, 4, 2], green_things [1, 1, 8], squares [4, 2, 4], circles [1, 8, 1], none [1, 1, 1]. Speaker optimality alpha = 1.

model

An object prior samples an object proportional to the salience weights for the given preference context and returns its string label. An utterance is literally true of an object if the utterance string is contained in the object's string label. A literal listener (L0) draws uniformly from the three objects (ignoring preference), conditions on the utterance being literally true, and returns the object string. A speaker draws uniformly from utterances and weights each by the literal listener's log-probability of the object under that utterance and preference context, scaled by alpha. A pragmatic listener samples an object from the preference-weighted object prior and updates on the speaker choosing the heard utterance.

query

The posterior distribution over object labels for a pragmatic listener who hears the utterance 'square' under the preference context 'blue_things'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "blue square",
    "blue circle",
    "green square"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Frank and Goodman (2012) RSA model from problang.org
2
3// set of states (here: objects of reference)
4// we represent objects as JavaScript objects to demarcate them from utterances
5// internally we treat objects as strings nonetheless
6var objects = [{color: "blue", shape: "square", string: "blue square"},
7 {color: "blue", shape: "circle", string: "blue circle"},
8 {color: "green", shape: "square", string: "green square"}]
9
10// set of utterances
11var utterances = ["blue", "green", "square", "circle"]
12
13var preferences = ["blue_things", "green_things", "squares", "circles","none"]
14
15var preferenceTable = {
16 blue_things : [4,4,2],
17 green_things : [1,1,8],
18 squares : [4,2,4],
19 circles : [1,8,1],
20 none : [1,1,1]
21}
22
23var preferencePrior = function() {
24 return uniformDraw(preferences)
25}
26
27// prior over world states
28var objectPrior = function(preference) {
29 var obj = categorical(preferenceTable[preference],objects)
30 return obj.string
31}
32
33// meaning function to interpret the utterances
34var meaning = function(utterance, obj){
35 _.includes(obj, utterance)
36}
37
38// literal listener
39var literalListener = function(utterance,preference){
40 Infer({model: function(){
41 var obj = uniformDraw(objects).string // L0 has no preference
42 condition(meaning(utterance, obj))
43 return obj
44 }})
45}
46
47// set speaker optimality
48var alpha = 1
49
50// pragmatic speaker
51var speaker = function(obj,preference){
52 Infer({model: function(){
53 var utterance = uniformDraw(utterances)
54 factor(alpha * literalListener(utterance,preference).score(obj))
55 return utterance
56 }})
57}
58
59// pragmatic listener
60var pragmaticListener = function(utterance,preference){
61 Infer({model: function(){
62 var obj = objectPrior(preference)
63 observe(speaker(obj,preference),utterance)
64 return obj
65 }})
66}
67
68print("the listener hears 'square' and has a preference for blue things")
69viz(pragmaticListener("square","blue_things"))
70
71print("the listener hears 'square' and has a preference for green things")
72viz(pragmaticListener("square","green_things"))
73
74print("the listener hears 'square' and has a preference for squares")
75viz(pragmaticListener("square","squares"))
76
77var ANSWER = (pragmaticListener('square', 'blue_things'));
78
realization0.000
python
1# Frank & Goodman (2012) RSA with a preference-dependent object prior.
2# Each RSA level is a separate, completely-finished Pyro discrete enumeration
3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully
4# computed and memoized BEFORE any higher level runs, so no inference ever runs
5# inside another inference's active enumeration. Each level's single latent has
6# a site name unique across levels (o0 / u1 / o2), avoiding site collisions.
7
8objects = ["blue square", "blue circle", "green square"]
9object_tokens = [["blue", "square"], ["blue", "circle"], ["green", "square"]]
10utterances = ["blue", "green", "square", "circle"]
11preferences = ["blue_things", "green_things", "squares", "circles", "none"]
12preference_table = {
13 "blue_things": [4, 4, 2],
14 "green_things": [1, 1, 8],
15 "squares": [4, 2, 4],
16 "circles": [1, 8, 1],
17 "none": [1, 1, 1],
18}
19alpha = 1.0
20
21utt_logits = torch.zeros(len(utterances)) # speaker draws utterances uniformly
22l0_state_logits = torch.zeros(len(objects)) # L0 object prior is uniform
23
24
25def meaning(utterance, obj_idx):
26 return utterance in object_tokens[obj_idx]
27
28
29def marginal_dict(model, site, support):
30 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
31 model, lambda: None
32 )[site]
33 sup = marg.enumerate_support()
34 probs = marg.log_prob(sup).exp()
35 out = {}
36 for i in range(sup.shape[0]):
37 out[support[int(sup[i].item())]] = float(probs[i].item())
38 return out
39
40
41def logscore(d, key):
42 p = d.get(key, 0.0)
43 return math.log(p) if p > 0 else float("-inf")
44
45
46# ---- Literal listener L0: infer object given utterance (uniform object prior) ----
47L0_cache = {}
48
49
50def literal_listener(utterance):
51 if utterance in L0_cache:
52 return L0_cache[utterance]
53
54 @pyro.infer.config_enumerate
55 def model():
56 o = pyro.sample("o0", dist.Categorical(logits=l0_state_logits))
57 ok = torch.tensor([meaning(utterance, i) for i in range(len(objects))])
58 pyro.factor("ev", torch.where(ok[o], torch.tensor(0.0), torch.tensor(float("-inf"))))
59
60 L0_cache[utterance] = marginal_dict(model, "o0", objects)
61 return L0_cache[utterance]
62
63
64# ---- Speaker: infer utterance given object ----
65S_cache = {}
66
67
68def speaker(obj):
69 if obj in S_cache:
70 return S_cache[obj]
71 # speaker scores by L0.score(obj); pre-warm all L0 marginals first.
72 sc = torch.tensor([alpha * logscore(literal_listener(u), obj) for u in utterances])
73
74 @pyro.infer.config_enumerate
75 def model():
76 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))
77 pyro.factor("f", sc[u])
78
79 S_cache[obj] = marginal_dict(model, "u1", utterances)
80 return S_cache[obj]
81
82
83# ---- Pragmatic listener: infer object given (utterance, preference) ----
84def pragmatic_listener(utterance, preference):
85 obj_logits = torch.log(torch.tensor([float(w) for w in preference_table[preference]]))
86 # Fully compute & memoize every speaker marginal we will need BEFORE the
87 # outer enumeration runs.
88 sc = torch.tensor([logscore(speaker(objects[i]), utterance) for i in range(len(objects))])
89
90 @pyro.infer.config_enumerate
91 def model():
92 o = pyro.sample("o2", dist.Categorical(logits=obj_logits))
93 pyro.factor("obs", sc[o])
94
95 return marginal_dict(model, "o2", objects)
96
97
98ANSWER = pragmatic_listener("square", "blue_things")
99
02 answer overlay — webppl vs pyro · dist/finite
legend: webppl, pyro · 2 bins
axis ticks: 0, 0.38, 0.75 · blue square: A = 0.750, B = 0.750 · green square: A = 0.250, B = 0.250
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-scalar-implicature-qud / atom-1
answer dist/int solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/scalar-implicature-qud.md
given

The world state is the number of red apples on a table, drawn uniformly from {0, 1, 2, 3}. Three utterances drawn uniformly: 'all', 'some', 'none'. Literal meanings: 'all' is true iff state = 3; 'some' is true iff state > 0; 'none' is true iff state = 0. Two possible QUDs: 'all?' (true iff state = 3) and 'any?' (true iff state > 0). Speaker optimality alpha = 1.

model

A literal listener for a given utterance and QUD samples a world state uniformly, conditions on the utterance's literal meaning holding, and returns the QUD value (a boolean). A speaker for a given state and QUD draws an utterance uniformly and weights it by the literal listener's log-probability of the QUD value at that state, scaled by alpha. A pragmatic listener for a given utterance and QUD samples a world state uniformly and updates on the speaker choosing the heard utterance.

query

The posterior distribution over world states (number of red apples) for a pragmatic listener who hears the utterance 'some' under the QUD 'any?'.

answer spec dist/int
{
  "kind": "dist",
  "domain": "int"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// possible states of the world
2var statePrior = function() {
3 return uniformDraw([0, 1, 2, 3])
4};
5
6// possible utterances
7var utterancePrior = function() {
8 return uniformDraw(['all', 'some', 'none']);
9};
10
11// possible quds
12var quds = ['all?','any?']
13
14// prior over quds (only relevant for qud inference)
15var qudPrior = function() {
16 return uniformDraw(quds);
17};
18
19// meaning funtion to interpret the utterances
20var literalMeanings = {
21 all: function(state) { return state === 3; },
22 some: function(state) { return state > 0; },
23 none: function(state) { return state === 0; }
24};
25
26// projection function
27var qudFn = function(qud, state) {
28 var qudAdressed = qud === "all?" ? state === 3 : state > 0
29 return qudAdressed
30}
31
32// literal listener
33var literalListener = cache(function(utt,qud) {
34 return Infer({model: function(){
35 var state = statePrior()
36 var meaning = literalMeanings[utt]
37 condition(meaning(state))
38 return qudFn(qud,state)
39 }})
40});
41
42// set speaker optimality
43var alpha = 1
44
45// pragmatic speaker
46var speaker = cache(function(state,qud) {
47 return Infer({model: function(){
48 var utt = utterancePrior()
49 factor(alpha * literalListener(utt,qud).score(qudFn(qud,state)))
50 return utt
51 }})
52});
53
54// pragmatic listener
55var pragmaticListener = cache(function(utt,qud) {
56 return Infer({model: function(){
57 var state = statePrior()
58 observe(speaker(state,qud),utt)
59 return state
60 }})
61});
62
63// print("pragmatic listener's interpretation of 'some':")
64viz(pragmaticListener('some','any?'));
65
66var ANSWER = (pragmaticListener('some','any?'));
67
realization0.000
python
1# Scalar-implicature RSA with QUDs (Goodman & Stuhlmuller style).
2# Each RSA level is a separate, completely-finished Pyro discrete enumeration
3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully
4# computed and memoized BEFORE any higher level runs, so no inference ever runs
5# inside another inference's active enumeration. Latent site names are unique
6# across levels (s0 / u1 / s2), avoiding site collisions.
7
8states = [0, 1, 2, 3]
9utterances = ["all", "some", "none"]
10quds = ["all?", "any?"]
11alpha = 1.0
12
13state_logits = torch.zeros(len(states)) # uniformDraw([0,1,2,3])
14utt_logits = torch.zeros(len(utterances)) # uniformDraw(['all','some','none'])
15
16
17def literal_meaning(utt, state):
18 if utt == "all":
19 return state == 3
20 if utt == "some":
21 return state > 0
22 return state == 0 # none
23
24
25def qud_fn(qud, state):
26 if qud == "all?":
27 return state == 3
28 return state > 0 # any?
29
30
31def marginal_dict(model, site, support):
32 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
33 model, lambda: None
34 )[site]
35 sup = marg.enumerate_support()
36 probs = marg.log_prob(sup).exp()
37 out = {}
38 for i in range(sup.shape[0]):
39 out[support[int(sup[i].item())]] = float(probs[i].item())
40 return out
41
42
43def logscore(d, key):
44 p = d.get(key, 0.0)
45 return math.log(p) if p > 0 else float("-inf")
46
47
48# ---- Literal listener L0: infer qudFn(qud,state) given (utt, qud) ----
49L0_cache = {}
50
51
52def literal_listener(utt, qud):
53 key = (utt, qud)
54 if key in L0_cache:
55 return L0_cache[key]
56 qspace = [False, True] # support over the QUD projection value
57
58 @pyro.infer.config_enumerate
59 def model():
60 s = pyro.sample("s0", dist.Categorical(logits=state_logits))
61 ok = torch.tensor([literal_meaning(utt, st) for st in states])
62 pyro.factor("ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
63 qidx = torch.tensor([1 if qud_fn(qud, st) else 0 for st in states])
64 # deterministic projection of the latent state to its QUD value,
65 # realized as an enumerated site pinned to qudFn(qud,state).
66 x = pyro.sample("qval", dist.Categorical(logits=torch.zeros(2)))
67 pyro.factor("proj", torch.where(x == qidx[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
68
69 L0_cache[key] = marginal_dict(model, "qval", qspace)
70 return L0_cache[key]
71
72
73# ---- Speaker: infer utterance given (state, qud) ----
74S_cache = {}
75
76
77def speaker(state, qud):
78 key = (state, qud)
79 if key in S_cache:
80 return S_cache[key]
81 qval = qud_fn(qud, state)
82 # speaker scores by L0.score(qudFn(qud,state)); pre-warm L0 first.
83 sc = torch.tensor([alpha * logscore(literal_listener(u, qud), qval) for u in utterances])
84
85 @pyro.infer.config_enumerate
86 def model():
87 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))
88 pyro.factor("f", sc[u])
89
90 S_cache[key] = marginal_dict(model, "u1", utterances)
91 return S_cache[key]
92
93
94# ---- Pragmatic listener: infer state given (utt, qud) ----
95def pragmatic_listener(utt, qud):
96 # Fully compute & memoize all speaker marginals we need BEFORE outer run.
97 sc = torch.tensor([logscore(speaker(st, qud), utt) for st in states])
98
99 @pyro.infer.config_enumerate
100 def model():
101 s = pyro.sample("s2", dist.Categorical(logits=state_logits))
102 pyro.factor("obs", sc[s])
103
104 return marginal_dict(model, "s2", states)
105
106
107ANSWER = pragmatic_listener("some", "any?")
108
02 answer overlay — webppl vs pyro · dist/int
legend: webppl, pyro · 3 bins · 1 … 3
axis ticks: 0, 0.17, 0.33 · 1: A = 0.333, B = 0.333 · 2: A = 0.333, B = 0.333 · 3: A = 0.333, B = 0.333
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (w1)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-schizophrenia-urns / atom-1
answer dist/real solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/schizophrenia-urns.md
given

Each trial draws nMarbles = 8 marbles. Confidence threshold = 0.6. The participant drew selfData = 4 red marbles. Four other agents reported: [{prediction: 'red', confidence: 'high'}, {prediction: 'red', confidence: 'high'}, {prediction: 'blue', confidence: 'low'}, {prediction: 'blue', confidence: 'low'}]. The prior over the true proportion of red marbles in the urn pRed is a discrete uniform over {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}.

model

Each other agent infers pRed from the same prior by updating on a Binomial(n=8, p=pRed) draw equal to their (latent) red count, then predicts 'red' with probability pRed and 'blue' with probability 1 - pRed. An agent's confidence is 'high' if (prediction = 'red' and pRed >= 0.6) or (prediction = 'blue' and pRed <= 0.4), and 'low' otherwise. The participant infers pRed by first conditioning on a Binomial(n=8, p=pRed) draw equal to selfData = 4, then incorporating each other agent's report by computing the expected log-likelihood of that report over the distribution of latent red counts the agent could have seen (Binomial(n=8, p=pRed)), where each latent count is evaluated under the agent's inference model.

query

The posterior distribution over pRed (the urn's true proportion of red marbles), computed by exact enumeration.

answer spec dist/real
{
  "kind": "dist",
  "domain": "real"
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// total number of marbles drawn from urn every time
2var nMarbles = 8;
3var threshold = .6;
4
5// example data point for self
6var selfData = 4;
7
8// example data point for others
9var otherData = [{prediction: 'red', confidence: 'high'},
10 {prediction: 'red', confidence: 'high'},
11 {prediction: 'blue', confidence: 'low'},
12 {prediction: 'blue', confidence: 'low'}];
13
14// (discretized) uniform distribution over actual proportion of red in urn
15var rednessPrior = Categorical({vs: [0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1]});
16
17/*
18 Generative model of other agents
19
20 Assumes they are *also* doing inference about actual proportion
21 based on their data and responding according to their best guess...
22 */
23var otherOutput = cache(function(kRed, threshold) {
24 return Infer({method: 'enumerate', model: function() {
25 var pRed = sample(rednessPrior);
26 var prediction = flip(pRed) ? 'red' : 'blue';
27 var highConf = prediction === 'red' ? pRed >= threshold : pRed <= 1 - threshold;
28
29 observe(Binomial({p: pRed, n: nMarbles}), kRed);
30 return {prediction: prediction, confidence: highConf ? 'high' : 'low'}
31 }});
32})
33
34/*
35 Model of participant's inference on a given trial
36*/
37var trialModel = function() {
38 // participant is trying to infer latent distribution in urn
39 var pRed = sample(rednessPrior);
40
41 // first, take into account own data (i.e. a draw of balls from urn)
42 observe(Binomial({p: pRed, n: nMarbles}), selfData);
43
44 // next, take into account social information
45 // assume their sample was drawn from sample population but not sure of exact data
46 mapData({data: otherData}, function(datum) {
47 var kSeenPrior = Binomial({p: pRed, n : nMarbles});
48 var likelihood = expectation(kSeenPrior, function(k) {
49 return otherOutput(k, threshold).score(datum)
50 })
51 factor(likelihood);
52 })
53
54 // ugly js convert to string for pretty plots
55 return pRed;
56}
57
58
59var ANSWER = (Infer({method: 'enumerate', model: trialModel}))
realization0.000
python
1
2# forestdb-schizophrenia-urns/atom-1
3# Participant infers the urn's red-proportion pRed from own draw + social reports.
4# Both the participant's inference and each other agent's inference are exact
5# enumeration over the discretized pRed prior.
6
7nMarbles = 8
8threshold = 0.6
9selfData = 4
10otherData = [
11 {"prediction": "red", "confidence": "high"},
12 {"prediction": "red", "confidence": "high"},
13 {"prediction": "blue", "confidence": "low"},
14 {"prediction": "blue", "confidence": "low"},
15]
16
17pRed_values = torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
18n_pRed = pRed_values.shape[0]
19
20
21def binom_logprob(k, n, p):
22 # log P(X = k) for X ~ Binomial(n, p); a single scalar tensor
23 return dist.Binomial(total_count=n, probs=p).log_prob(torch.tensor(float(k)))
24
25
26# --- Other agent's generative model (exact enumeration over pRed) ----------
27# Given the agent saw kRed reds, the agent's posterior over pRed is obtained by
28# enumerating pRed and observing the agent's own draw. The agent then reports a
29# prediction (red w.p. pRed) with a deterministic confidence. We read the
30# induced report distribution off the enumerated pRed posterior.
31def other_output(kRed):
32 @pyro.infer.config_enumerate
33 def agent_model():
34 idx = pyro.sample("pRed", dist.Categorical(torch.ones(n_pRed) / n_pRed))
35 p = pRed_values[idx]
36 pyro.factor("agent_data", binom_logprob(kRed, nMarbles, p))
37 return idx
38
39 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
40 agent_model, lambda: None
41 )
42 probs_p = marg["pRed"].probs # agent's posterior over pRed indices
43
44 out = {}
45 for i in range(n_pRed):
46 p = pRed_values[i].item()
47 wp = probs_p[i].item()
48 for pred_name, pred_p in (("red", p), ("blue", 1.0 - p)):
49 if pred_name == "red":
50 high = p >= threshold
51 else:
52 high = p <= 1.0 - threshold
53 conf = "high" if high else "low"
54 key = (pred_name, conf)
55 out[key] = out.get(key, 0.0) + wp * pred_p
56 return out
57
58
59# Cache agent outputs for every possible kRed (0..nMarbles)
60_other_cache = {k: other_output(k) for k in range(nMarbles + 1)}
61
62
63# --- Participant's inference (exact enumeration over pRed) -----------------
64@pyro.infer.config_enumerate
65def trial_model():
66 idx = pyro.sample("pRed", dist.Categorical(torch.ones(n_pRed) / n_pRed))
67 p = pRed_values[idx]
68
69 # own data
70 pyro.factor("self_data", binom_logprob(selfData, nMarbles, p))
71
72 # social data: expected log-likelihood of each report over the agent's
73 # latent red count k ~ Binomial(nMarbles, p)
74 for j, datum in enumerate(otherData):
75 key = (datum["prediction"], datum["confidence"])
76 exp_log_lik = torch.tensor(0.0)
77 for k in range(nMarbles + 1):
78 p_k = binom_logprob(k, nMarbles, p).exp()
79 p_datum = _other_cache[k].get(key, 1e-300)
80 exp_log_lik = exp_log_lik + p_k * torch.log(torch.tensor(p_datum))
81 pyro.factor(f"social_{j}", exp_log_lik)
82
83 return idx
84
85
86marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
87 trial_model, lambda: None
88)
89pRed_post = marg["pRed"].probs
90
91ANSWER = {round(pRed_values[i].item(), 1): pRed_post[i].item() for i in range(n_pRed)}
92
02 answer overlay — webppl vs pyro · dist/real
legend: webppl, pyro · 11 bins · 0 … 1
axis ticks: 0, 0.21, 0.42 · 0: A = 0.000, B = 0.000 · 0.1: A = 0.000, B = 0.000 · 0.2: A = 0.000, B = 0.000 · 0.3: A = 0.000, B = 0.000 · 0.4: A = 0.007, B = 0.007 · 0.5: A = 0.079, B = 0.079 · 0.6: A = 0.302, B = 0.302 · 0.7: A = 0.418, B = 0.418 · 0.8: A = 0.181, B = 0.181 · 0.9: A = 0.012, B = 0.012 · 1: A = 0.000, B = 0.000
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 1/2 solvers · d=[0.000, 0.228] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-singh-uyeda-pronouns / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/singh-uyeda-pronouns.md
given

The sentence template is 'John hit Fred and Ellen hit ___.' Possible utterances to fill the blank are: 'him' (ambiguous pronoun), 'Fred' (unambiguous), and 'John' (unambiguous), with a categorical prior over utterances weighted [2, 1, 1] respectively. The referent space is {John, Fred}, drawn with equal probability. Possible resolution strategies are {Subject, Parallel}, drawn with equal probability. The meaning of an utterance given a referent and a strategy is: if the utterance is 'him', the Subject strategy resolves it to John and the Parallel strategy resolves it to Fred; if the utterance is a proper name, it is true only when the referent matches the name exactly.

model

A literal listener interprets an utterance relative to a fixed strategy: starting from the uniform referent prior, it conditions on the utterance being true of the referent under that strategy. A speaker who intends a referent under a given strategy chooses an utterance with probability proportional to the utterance prior times the literal listener's probability of assigning the intended referent to that utterance under that strategy. A pragmatic listener who hears an utterance weights each (referent, strategy) pair by its prior probability times the speaker's probability of producing that utterance for that referent under that strategy, and marginalizes over strategies.

query

The posterior distribution over referents (John or Fred) for a pragmatic listener who hears 'him'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "John",
    "Fred"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1///fold:
2// possible utterances: Ambiguous, Unambiguous Fred, Unambiguous John
3var utterances = ["him", "Fred", "John"]
4
5// samples the utterances that fill in the following template:
6// "John hit Fred and Ellen hit (the utterance)"
7
8var utterancePrior = function() {
9 categorical([2,1,1], utterances)
10}
11
12//possible world states of who Ellen hit
13var states = ["John", "Fred"]
14var statePrior = function() {
15 return uniformDraw(states)
16}
17
18//possible strategies
19var strategyPrior = function(){
20 return uniformDraw(["Subject", "Parallel"])
21}
22
23//meaning function
24var meaning = function(utterance, state, strategy){
25 return utterance == "him" ?
26 strategy == "Subject" ? state == "John" :
27 state == "Fred" :
28 utterance == state
29}
30
31// Literal Listener (L0)
32var literalListener = cache(function(utterence, strategy){
33 return Infer({model: function(){
34 var state = statePrior()
35 condition(meaning(utterence, state, strategy))
36 return state
37 }})
38})
39
40// Speaker (S)
41var speaker = cache(function(strategy, state){
42 return Infer({model: function() {
43 var utterance = utterancePrior()
44 observe(literalListener(utterance, strategy), state)
45 return utterance
46 }})
47})
48///
49// Pragmatic listener (L1)
50var pragmaticListener = cache(function(utterance) {
51 return Infer({model: function(){
52 var state = statePrior()
53 var strategy = strategyPrior()
54 observe(speaker(strategy,state),utterance)
55 return state
56 }})
57})
58pragmaticListener("him")
59
60var ANSWER = (pragmaticListener('him'));
61
realization0.000
python
1# Singh & Uyeda pronoun-resolution RSA.
2# Each RSA level is a separate, completely-finished Pyro discrete enumeration
3# (config_enumerate + TraceEnum_ELBO.compute_marginals). Lower levels are fully
4# computed and memoized BEFORE any higher level runs, so no inference runs inside
5# another's active enumeration. Latent site names are unique across levels
6# (s0 / u1 / s2,st2), avoiding site collisions.
7# Note: the Speaker here does observe(literalListener(utterance,strategy), state),
8# i.e. it factors by L0.score(state) directly, matching the WebPPL reference.
9
10utterances = ["him", "Fred", "John"]
11states = ["John", "Fred"]
12strategies = ["Subject", "Parallel"]
13
14utt_logits = torch.log(torch.tensor([2.0, 1.0, 1.0])) # categorical([2,1,1], utterances)
15state_logits = torch.zeros(len(states)) # uniformDraw
16strategy_logits = torch.zeros(len(strategies)) # uniformDraw
17
18
19def meaning(utterance, state, strategy):
20 if utterance == "him":
21 if strategy == "Subject":
22 return state == "John"
23 return state == "Fred"
24 return utterance == state
25
26
27def marginal_dict(model, site, support):
28 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
29 model, lambda: None
30 )[site]
31 sup = marg.enumerate_support()
32 probs = marg.log_prob(sup).exp()
33 out = {}
34 for i in range(sup.shape[0]):
35 out[support[int(sup[i].item())]] = float(probs[i].item())
36 return out
37
38
39def logscore(d, key):
40 p = d.get(key, 0.0)
41 return math.log(p) if p > 0 else float("-inf")
42
43
44# ---- Literal listener L0: infer state given (utterance, strategy) ----
45L0_cache = {}
46
47
48def literal_listener(utterance, strategy):
49 key = (utterance, strategy)
50 if key in L0_cache:
51 return L0_cache[key]
52
53 @pyro.infer.config_enumerate
54 def model():
55 s = pyro.sample("s0", dist.Categorical(logits=state_logits))
56 ok = torch.tensor([meaning(utterance, st, strategy) for st in states])
57 pyro.factor("ev", torch.where(ok[s], torch.tensor(0.0), torch.tensor(float("-inf"))))
58
59 L0_cache[key] = marginal_dict(model, "s0", states)
60 return L0_cache[key]
61
62
63# ---- Speaker S: infer utterance given (strategy, state); observes L0 at state ----
64S_cache = {}
65
66
67def speaker(strategy, state):
68 key = (strategy, state)
69 if key in S_cache:
70 return S_cache[key]
71 # observe(literalListener(utterance,strategy), state) -> factor L0.score(state)
72 sc = torch.tensor([logscore(literal_listener(u, strategy), state) for u in utterances])
73
74 @pyro.infer.config_enumerate
75 def model():
76 u = pyro.sample("u1", dist.Categorical(logits=utt_logits))
77 pyro.factor("obs", sc[u])
78
79 S_cache[key] = marginal_dict(model, "u1", utterances)
80 return S_cache[key]
81
82
83# ---- Pragmatic listener L1: infer state given utterance ----
84def pragmatic_listener(utterance):
85 # Fully compute & memoize the whole speaker table BEFORE outer run.
86 tbl = torch.full((len(states), len(strategies)), float("-inf"))
87 for si, sstate in enumerate(states):
88 for sti, strat in enumerate(strategies):
89 tbl[si, sti] = logscore(speaker(strat, sstate), utterance)
90
91 @pyro.infer.config_enumerate
92 def model():
93 s = pyro.sample("s2", dist.Categorical(logits=state_logits))
94 st = pyro.sample("st2", dist.Categorical(logits=strategy_logits))
95 pyro.factor("obs", tbl[s, st])
96
97 return marginal_dict(model, "s2", states)
98
99
100ANSWER = pragmatic_listener("him")
101
02 answer overlay — webppl vs pyro · dist/finite
legend: webppl, pyro · 2 bins
axis ticks: 0, 0.25, 0.50 · Fred: A = 0.500, B = 0.500 · John: A = 0.500, B = 0.500
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-social-meaning / atom-1
answer dist/finite solver accept pyro pass 0.0000
00 statement source: forestdb.org/models/social-meaning.md
given

Four personae: {name: 'stern', competence: true, friendliness: false}, {name: 'cool', competence: true, friendliness: true}, {name: 'asshole', competence: false, friendliness: false}, {name: 'doofus', competence: false, friendliness: true}. The prior over personae for a voter audience assigns probabilities [0.3, 0.2, 0.3, 0.2] to [stern, cool, asshole, doofus] respectively. Two morphological variants: 'n' and 'ng', both with production cost 0. Speaker optimality alpha = 6. Eckert-Montague semantics: 'ng' is consistent with a persona if that persona has competence = true or friendliness = false; 'n' is consistent with a persona if that persona has competence = false or friendliness = true.

model

A literal listener for a variant samples from the voter persona prior, conditions on the variant being semantically consistent with the persona, and returns the persona. A speaker for a given persona draws uniformly from the two variants and weights each by alpha times (log-probability the literal listener assigns to that persona given the variant, minus the variant's cost). A naive pragmatic listener for a variant samples a persona from the voter prior and updates on the speaker choosing that variant.

query

The posterior distribution over persona names for a naive pragmatic listener who hears the variant 'ng'.

answer spec dist/finite
{
  "kind": "dist",
  "domain": "finite",
  "support": [
    "stern",
    "cool",
    "asshole",
    "doofus"
  ]
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1// Personae in (6)
2var personae = [{name: "stern", competence : true, friendliness : false},
3 {name: "cool", competence : true, friendliness : true},
4 {name: "asshole", competence : false, friendliness : false},
5 {name: "doofus", competence : false, friendliness: true}]
6
7// Table 2
8var voterPrior = Categorical({ps: [0.3,0.2,0.3,0.2], vs: personae})
9
10// Table 5
11var journalistPrior = Categorical({ps: [0.2,0.2,0.3,0.3], vs: personae})
12
13var personaePrior = voterPrior
14// var personaePrior = journalistPrior
15
16var variants = ["n","ng"]
17
18var cost = {
19 n : 0,
20 ng : 0
21}
22
23// Eckert-Montague Semantics (Table 1)
24var semantics = {
25 ng: function(persona) { return (persona.competence == true | persona.friendliness == false) ; },
26 n: function(persona) { return (persona.competence == false | persona.friendliness == true); },
27}
28
29// Definition in (11) - the 'literal listener'
30
31var conditionalization = function(variant) {
32 return Infer({model: function(){
33 var persona = sample(personaePrior)
34 var meaning = semantics[variant]
35 condition(meaning(persona))
36 return persona
37 }}
38)}
39
40// Definition in (12)
41
42var utility = function(persona, variant) {
43
44 var informativity = conditionalization(variant).score(persona)
45 return(informativity - cost[variant])
46
47}
48
49var alpha = 6
50
51// Definition in (13) - soft-max choice rule
52
53var speaker = function(persona) {
54 return Infer(function() {
55 var variant = uniformDraw(variants)
56 factor(alpha * utility(persona,variant))
57 return(variant)
58 })
59}
60
61// Table 6: persona selection function (the value system)
62
63var mu = function(persona) {
64
65 persona.name == "cool" ? 2 :
66 persona.name == "stern" ? 1 :
67 persona.name == "doofus" ? 1 :
68 persona.name == "asshole" ? 0 :
69 0
70
71}
72
73// Definition in (14): probability distribution over personae
74
75var alphaprime = 6
76
77var personaDistribution = Infer(
78 function() {
79 var persona = sample(personaePrior)
80 factor(alphaprime * mu(persona))
81 return persona
82 })
83
84// Definition in (15): speaker with a value system
85
86var valueSpeaker = function(variant) {
87
88 // Array of utilities of variant for each persona, times probability of the persona
89 var variantUtility = map(function(persona) {
90 return Math.exp(personaDistribution.score(persona)) * Math.exp(speaker(persona).score(variant))
91 }, personae)
92
93 return sum(variantUtility)
94
95}
96
97// Definition in (17): Listening with certainty about speaker's values
98
99var valueInformedListener = function(variant) {
100 return Infer(function(){
101
102 var persona = sample(personaePrior)
103 factor(Math.exp(personaDistribution.score(persona)) * Math.exp(speaker(persona).score(variant)))
104 return persona.name
105
106 })
107}
108
109// Definition in (18): Naive listening
110
111var naiveListener = function(variant) {
112 return Infer(function(){
113
114 var persona = sample(personaePrior)
115 factor(speaker(persona).score(variant))
116 return persona.name
117 })
118}
119
120print("Literal L's beliefs immediately after hearing -n at the barbecue")
121viz.table(conditionalization('n'))
122
123print("Literal L’s beliefs immediately after hearing -ng at the barbecue")
124viz.table(conditionalization('ng'))
125
126print("Obama wants to be the cool guy")
127viz.table(speaker(personae[1]))
128
129// set prior to journalistPrior above
130print("Obama's overall probability of using -ng with the journalist")
131
132print(valueSpeaker('ng'))
133
134print("Obama's overall probability of using -n with the journalist")
135
136print(valueSpeaker('n'))
137
138print("Hearing Obama use -n (and you have a sense of Obama's values)")
139
140viz(valueInformedListener('n'))
141
142print("Hearing Obama use -n (and you're naive as to Obama's values)")
143
144viz(naiveListener('n'))
145
146var ANSWER = naiveListener('ng');
147
realization0.000
python
1# RSA-style social-meaning model (Eckert-Montague semantics).
2# Literal listener -> speaker (softmax) -> naive pragmatic listener, all by exact
3# enumeration. Each level's posterior is obtained through Pyro inference and its
4# log-probabilities (WebPPL's .score) feed the next level.
5
6personae = [
7 {"name": "stern", "competence": True, "friendliness": False},
8 {"name": "cool", "competence": True, "friendliness": True},
9 {"name": "asshole", "competence": False, "friendliness": False},
10 {"name": "doofus", "competence": False, "friendliness": True},
11]
12prior_ps = torch.tensor([0.3, 0.2, 0.3, 0.2])
13variants = ["n", "ng"]
14cost = {"n": 0.0, "ng": 0.0}
15alpha = 6.0
16
17
18def semantics(variant, persona):
19 if variant == "ng":
20 return persona["competence"] is True or persona["friendliness"] is False
21 else: # "n"
22 return persona["competence"] is False or persona["friendliness"] is True
23
24
25# --- Literal listener (definition 11): condition the persona prior on the
26# --- variant being semantically consistent; returns a distribution over personae.
27def literal_listener(variant):
28 @pyro.infer.config_enumerate
29 def model():
30 idx = pyro.sample("persona", dist.Categorical(prior_ps))
31 consistent = torch.tensor(
32 [1.0 if semantics(variant, p) else 0.0 for p in personae]
33 )[idx]
34 logw = torch.where(
35 consistent > 0, torch.tensor(0.0), torch.tensor(float("-inf"))
36 )
37 pyro.factor("sem", logw)
38 return idx
39
40 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
41 model, lambda: None
42 )
43 return marg["persona"]
44
45
46LL = {v: literal_listener(v) for v in variants}
47idx_support = torch.arange(len(personae))
48
49
50def literal_score(variant, persona_idx):
51 return LL[variant].log_prob(torch.tensor(persona_idx))
52
53
54def utility(persona_idx, variant):
55 return literal_score(variant, persona_idx) - cost[variant]
56
57
58# --- Speaker (definition 13): uniform over variants, factor by alpha*utility;
59# --- returns a distribution over variants for a given persona.
60def speaker(persona_idx):
61 @pyro.infer.config_enumerate
62 def model():
63 v = pyro.sample("variant", dist.Categorical(torch.tensor([0.5, 0.5])))
64 u = torch.stack(
65 [alpha * utility(persona_idx, variants[j]) for j in range(len(variants))]
66 )[v]
67 pyro.factor("u", u)
68 return v
69
70 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
71 model, lambda: None
72 )
73 return marg["variant"]
74
75
76SPK = [speaker(i) for i in range(len(personae))]
77
78
79def speaker_score(persona_idx, variant):
80 v_idx = variants.index(variant)
81 return SPK[persona_idx].log_prob(torch.tensor(v_idx))
82
83
84# --- Naive pragmatic listener (definition 18): sample persona from prior, factor
85# --- by speaker(persona).score(variant), return the persona name distribution.
86def naive_listener(variant):
87 @pyro.infer.config_enumerate
88 def model():
89 idx = pyro.sample("persona", dist.Categorical(prior_ps))
90 s = torch.stack(
91 [speaker_score(i, variant) for i in range(len(personae))]
92 )[idx]
93 pyro.factor("spk", s)
94 return idx
95
96 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(
97 model, lambda: None
98 )
99 return marg["persona"]
100
101
102final = naive_listener("ng")
103probs = torch.exp(final.log_prob(idx_support))
104ANSWER = {personae[i]["name"]: probs[i].item() for i in range(len(personae))}
105
02 answer overlay — webppl vs pyro · dist/finite
webppl · pyro · 3 bins

| persona | A | B |
|---|---|---|
| asshole | 0.204 | 0.204 |
| cool | 0.136 | 0.136 |
| stern | 0.660 | 0.660 |
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (tv)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000
forestdb-zhu-antonyms / atom-1
answer record(expensivePrice, notInexpensivePrice) solver accept pyro pass 0.0000
00 statement source: data/sources/forestdb.org/models/zhu-antonyms.md
given

Two items, watch and sweater, each with a prior distribution over prices and a theta prior (uniform over the item's price bins). Sweater price bins and probabilities:

| price | probability |
|---|---|
| 1.5 | 0.00482838499944466 |
| 4.5 | 0.00832934578733181 |
| 7.5 | 0.0112952500492109 |
| 10.5 | 0.0173774790108894 |
| 13.5 | 0.0232006658974883 |
| 16.5 | 0.0258422772579257 |

The sweater probabilities are unnormalized weights (they sum to about 0.091); they are normalized when used as the categorical price prior. Watch price bins: 60 bins from 25 to 2975 in steps of 50, with probabilities [0.040844560268751, 0.0587099798246933, 0.0656194599591356, 0.0667642412698035, 0.0615953803048016, 0.0510809063784378, 0.0467203673419258, 0.0446735950187136, 0.040047421916613, 0.0350583957334483, 0.0297508215717606, 0.0256829651118227, 0.024135920250668, 0.0228891907259206, 0.021706684520276, 0.0186449440066946, 0.0187249266247728, 0.0179250744798993, 0.0173698811746238, 0.0165581725818319, 0.0160745066032247, 0.0127927305129066, 0.0113730680265067, 0.0109485307623827, 0.00923468422650943, 0.00899007751887508, 0.00880520147998275, 0.00838023585866885, 0.00841052411004918, 0.00828830635037619, 0.00834008093757411, 0.00750681534099784, 0.00724072133740109, 0.00717291664158004, 0.00682823777708754, 0.00646995193940331, 0.00697139732982518, 0.00711846547272734, 0.00698781312802354, 0.00732316558583701, 0.00594973158122097, 0.00557461443747403, 0.00541637601910211, 0.00518850469148531, 0.00572025848989677, 0.0051443557601358, 0.00510282169734075, 0.00493720252580643, 0.00560198932991028, 0.00519158715054485, 0.00473398797752786, 0.00540907722833213, 0.00494653421540979, 0.00495500420164643, 0.00494083025189895, 0.00481566268206312, 0.00442965937328148, 0.00441189688100535, 0.00415116538135834, 0.00361842012002631]. Two utterances: 'expensive' and 'not-inexpensive'. Utterance costs: 'expensive' costs 1, 'not-inexpensive' costs 2. Speaker optimality alpha = 2.
Soft semantics: 'expensive' is true of a price with probability 0.9999 if price > theta_expensive, and 0.0001 otherwise; 'not-inexpensive' is true with probability 0.9999 if price >= theta_inexpensive, and 0.0001 otherwise. The threshold theta_expensive is drawn uniformly from the item's price bins. With probability 0.2, theta_inexpensive is drawn independently from the item's price bins; with probability 0.8, theta_inexpensive equals theta_expensive.

model

A literal listener for a given utterance and pair of thresholds draws a price uniformly from the item's bins, conditions on the soft semantics being true, and returns the price. A speaker for a given price and threshold pair draws an utterance uniformly and adds a log-weight (factor) of alpha times (the literal listener's log-probability of the price minus the utterance's cost), so each utterance is chosen with probability proportional to exp(alpha × (that log-probability − cost)). A pragmatic listener for a given utterance jointly samples a price from the item's categorical prior, a theta_expensive threshold, and a theta_inexpensive threshold (equal to theta_expensive with probability 0.8, otherwise drawn independently from the item's price bins), then updates on the speaker choosing the heard utterance.

query

Two scalar expectations: the expected price inferred by a pragmatic listener hearing 'expensive' for the sweater item, and the expected price inferred hearing 'not-inexpensive' for the sweater item. Return as a record with fields expensivePrice and notInexpensivePrice.

answer spec record(expensivePrice, notInexpensivePrice)
{
  "kind": "record",
  "fields": {
    "expensivePrice": {
      "kind": "value",
      "domain": "real"
    },
    "notInexpensivePrice": {
      "kind": "value",
      "domain": "real"
    }
  }
}
system prompt constant across problems
(system prompt loads here)
webppl primer solver context
(primer loads here)
01 realizations comparing webppl vs pyro
ground truth
webppl
1var watch = {
2 "prices": [25, 75, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575, 625, 675, 725, 775, 825, 875, 925, 975, 1025, 1075, 1125, 1175, 1225, 1275, 1325, 1375, 1425, 1475, 1525, 1575, 1625, 1675, 1725, 1775, 1825, 1875, 1925, 1975, 2025, 2075, 2125, 2175, 2225, 2275, 2325, 2375, 2425, 2475, 2525, 2575, 2625, 2675, 2725, 2775, 2825, 2875, 2925, 2975],
3 "probabilities": [0.040844560268751, 0.0587099798246933, 0.0656194599591356, 0.0667642412698035, 0.0615953803048016, 0.0510809063784378, 0.0467203673419258, 0.0446735950187136, 0.040047421916613, 0.0350583957334483, 0.0297508215717606, 0.0256829651118227, 0.024135920250668, 0.0228891907259206, 0.021706684520276, 0.0186449440066946, 0.0187249266247728, 0.0179250744798993, 0.0173698811746238, 0.0165581725818319, 0.0160745066032247, 0.0127927305129066, 0.0113730680265067, 0.0109485307623827, 0.00923468422650943, 0.00899007751887508, 0.00880520147998275, 0.00838023585866885, 0.00841052411004918, 0.00828830635037619, 0.00834008093757411, 0.00750681534099784, 0.00724072133740109, 0.00717291664158004, 0.00682823777708754, 0.00646995193940331, 0.00697139732982518, 0.00711846547272734, 0.00698781312802354, 0.00732316558583701, 0.00594973158122097, 0.00557461443747403, 0.00541637601910211, 0.00518850469148531, 0.00572025848989677, 0.0051443557601358, 0.00510282169734075, 0.00493720252580643, 0.00560198932991028, 0.00519158715054485, 0.00473398797752786, 0.00540907722833213, 0.00494653421540979, 0.00495500420164643, 0.00494083025189895, 0.00481566268206312, 0.00442965937328148, 0.00441189688100535, 0.00415116538135834, 0.00361842012002631]
4};
5var sweater = {
6 "prices": [1.5, 4.5, 7.5, 10.5, 13.5, 16.5],
7 "probabilities": [0.00482838499944466, 0.00832934578733181, 0.0112952500492109, 0.0173774790108894, 0.0232006658974883, 0.0258422772579257]
8};
9var data = {
10 "watch": watch,
11 "sweater": sweater
12};
13
14var prior = function(item) {
15 var prices = data[item].prices;
16 var probabilities = data[item].probabilities;
17 return function() {
18 return categorical(probabilities, prices);
19 };
20};
21
22var theta_prior = function(item) {
23 var thetas = data[item].prices;
24 return function() {
25 return uniformDraw(thetas) ;
26 };
27};
28
29var alpha = 2; // optimality parameter
30
31
32var utterances = ["expensive","not-inexpensive"];
33
34var cost = {
35 "not-inexpensive":2,
36 "expensive": 1,
37};
38var utterancePrior = function() {
39 return uniformDraw(utterances);
40};
41
42var meaning = function(utterance, price, theta) {
43 utterance == "expensive" ? price > theta.expensive ? flip(0.9999) : flip(0.0001) :
44 utterance == "not-inexpensive" ? !(price < theta.inexpensive)? flip(0.9999) : flip(0.0001):
45 true;
46};
47
48var literalListener = cache(function(utterance, theta, item) {
49 return Infer({method: "enumerate"}, function() {
50 var price = uniformDraw(data[item].prices)
51 condition(meaning(utterance, price, theta))
52 return price;
53 });
54});
55
56var speaker = cache(function(price, theta, item) {
57 return Infer({method: "enumerate"}, function() {
58 var utterance = utterancePrior();
59 factor( alpha * (literalListener(utterance, theta, item).score(price)
60 - cost[utterance]));
61 return utterance;
62 });
63});
64
65var pragmaticListener = function(utterance, item) {
66 // first identify the relevant priors
67 var pricePrior = prior(item);
68 var thetaPrior = theta_prior(item);
69 // then run inference
70 return Infer({method: "enumerate"},
71 function() {
72 var an_neg_thre = flip(0.2)
73 var expensive_theta= thetaPrior()
74
75 var inexpensive_threshold = an_neg_thre ?
76 thetaPrior() :
77 expensive_theta;
78 var price = pricePrior();
79 var theta = {
80 expensive: expensive_theta,
81 inexpensive: inexpensive_threshold
82 }
83 var Posexp = theta.expensive
84 var Posneg = theta.inexpensive
85 factor( speaker(price, theta, item).score(utterance) );
86 return { price: price, Posexp: Posexp , Posneg: Posneg };
87 });
88};
89
90
91var expensiveSweater= pragmaticListener("expensive", "sweater");
92print("Expensive:Prices")
93viz.density(marginalize(expensiveSweater, "price"));
94display(expectation(marginalize(expensiveSweater, "price")))
95print("Expensive:Thresholds:")
96viz.density(marginalize(expensiveSweater, "Posexp"));
97display(expectation(marginalize(expensiveSweater, "Posexp")))
98var notinexpensiveSweater= pragmaticListener("not-inexpensive", "sweater");
99print("NOT-Inexpensive : Prices")
100viz.density(marginalize(notinexpensiveSweater, "price"));
101display(expectation(marginalize(notinexpensiveSweater, "price")))
102print("NOT-Inexpensive:Thresholds:")
103viz.density(marginalize(notinexpensiveSweater, "Posneg"));
104display(expectation(marginalize(notinexpensiveSweater, "Posneg")))
105
106var ANSWER = ({ expensivePrice: expectation(marginalize(pragmaticListener("expensive","sweater"), "price")), notInexpensivePrice: expectation(marginalize(pragmaticListener("not-inexpensive","sweater"), "price")) });
107
realization0.000
python
1# RSA antonyms model (Zhu et al.) for the 'sweater' item.
2# Every level of inference (literal listener, speaker, pragmatic listener) is run
3# through Pyro's exact discrete enumeration (config_enumerate + TraceEnum_ELBO
4# .compute_marginals). The answer is the expectation over the price posterior of
5# the pragmatic listener for two utterances.
6
7sweater_prices = [1.5, 4.5, 7.5, 10.5, 13.5, 16.5]
8sweater_probs = [0.00482838499944466, 0.00832934578733181, 0.0112952500492109,
9 0.0173774790108894, 0.0232006658974883, 0.0258422772579257]
10# normalize the (sub-normalized) price prior, as WebPPL's categorical does
11_z = sum(sweater_probs)
12price_prior_probs = torch.tensor([p / _z for p in sweater_probs])
13prices = torch.tensor(sweater_prices)
14n_prices = len(sweater_prices)
15
16alpha = 2.0
17utterances = ["expensive", "not-inexpensive"]
18cost = {"expensive": 1.0, "not-inexpensive": 2.0}
19
20NEG_INF = torch.tensor(float("-inf"))
21ZERO = torch.tensor(0.0)
22
23
24def meaning_holds(utterance, price_idx, theta_exp_idx, theta_inexp_idx):
25 # returns a boolean tensor: does the utterance's literal meaning hold for this
26 # price under the given thresholds? prices are indexed; theta thresholds are
27 # also one of the prices (uniformDraw over prices).
28 price_val = prices[price_idx]
29 if utterance == "expensive":
30 return price_val > prices[theta_exp_idx]
31 else: # not-inexpensive: NOT (price < inexpensive threshold)
32 return ~(price_val < prices[theta_inexp_idx])
33
34
35def literal_listener_logprobs(utterance, theta_exp_idx, theta_inexp_idx):
36 # exact enumeration over price (uniform prior over prices, conditioned on
37 # meaning) via a Pyro enumerated model. Returns log-probs over price support.
38 @pyro.infer.config_enumerate
39 def model():
40 price_idx = pyro.sample("price", dist.Categorical(torch.ones(n_prices) / n_prices))
41 holds = meaning_holds(utterance, price_idx, theta_exp_idx, theta_inexp_idx)
42 # soft meaning: flip(0.9999) if holds else flip(0.0001), then condition true
43 logp_true = torch.where(holds, torch.log(torch.tensor(0.9999)),
44 torch.log(torch.tensor(0.0001)))
45 pyro.factor("meaning", logp_true)
46 return price_idx
47 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
48 m = marg["price"]
49 sup = m.enumerate_support()
50 logps = m.log_prob(sup)
51 # index aligned to 0..n_prices-1
52 out = torch.full((n_prices,), float("-inf"))
53 for s, lp in zip(sup.tolist(), logps.tolist()):
54 out[int(s)] = lp
55 return out
56
57
58def speaker_logprobs(price_idx, theta_exp_idx, theta_inexp_idx):
59 # exact enumeration over utterance choice via Pyro. Returns log-probs over the
60 # two utterances.
61 ll_scores = {}
62 for u in utterances:
63 ll_lp = literal_listener_logprobs(u, theta_exp_idx, theta_inexp_idx)
64 ll_scores[u] = ll_lp[price_idx]
65
66 @pyro.infer.config_enumerate
67 def model():
68 u_idx = pyro.sample("utt", dist.Categorical(torch.ones(len(utterances)) / len(utterances)))
69 # factor( alpha * (literalListener.score(price) - cost[utterance]) )
70 score_vals = torch.tensor([alpha * (ll_scores[u] - cost[u]) for u in utterances])
71 pyro.factor("util", score_vals[u_idx])
72 return u_idx
73 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(model, lambda: None)
74 m = marg["utt"]
75 sup = m.enumerate_support()
76 logps = m.log_prob(sup)
77 out = torch.full((len(utterances),), float("-inf"))
78 for s, lp in zip(sup.tolist(), logps.tolist()):
79 out[int(s)] = lp
80 return out
81
82
83# Memoize the speaker log-score (per utterance) over the unique (price, exp-theta,
84# inexp-theta) configurations it depends on; each entry comes from a genuine Pyro
85# enumeration of the speaker (which itself drives the literal-listener enumeration).
86_spk_cache = {}
87def speaker_logprobs_cached(price_idx, theta_exp_idx, theta_inexp_idx):
88 key = (price_idx, theta_exp_idx, theta_inexp_idx)
89 if key not in _spk_cache:
90 _spk_cache[key] = speaker_logprobs(price_idx, theta_exp_idx, theta_inexp_idx)
91 return _spk_cache[key]
92
93
94def pragmatic_listener_price_expectation(utterance):
95 u_target = utterances.index(utterance)
96 # The pragmatic listener jointly samples: an_neg ~ flip(0.2), expensive_theta ~
97 # uniformDraw(prices), a fresh threshold ~ uniformDraw(prices), and price ~
98 # pricePrior; inexpensive_threshold = fresh if an_neg else expensive_theta. The
99 # heard utterance is observed by factoring the speaker's log-score. This whole
100 # joint is run through Pyro's exact discrete enumeration (config_enumerate +
101 # compute_marginals); the speaker score enters via a precomputed per-config
102 # table (each entry itself the result of a Pyro speaker inference).
103 #
104 # Build the speaker-score table over the joint (an_neg, exp_theta, fresh, price).
105 # inexp_theta = fresh if an_neg==1 else exp_theta. Index it inside the model.
106 spk_table = torch.full((2, n_prices, n_prices, n_prices), float("-inf"))
107 for an_neg in (0, 1):
108 for exp_theta in range(n_prices):
109 for fresh in range(n_prices):
110 inexp_theta = fresh if an_neg == 1 else exp_theta
111 for price in range(n_prices):
112 sp_lp = speaker_logprobs_cached(price, exp_theta, inexp_theta)
113 spk_table[an_neg, exp_theta, fresh, price] = sp_lp[u_target]
114
115 @pyro.infer.config_enumerate
116 def prag_model():
117 an_neg = pyro.sample("an_neg", dist.Categorical(torch.tensor([0.8, 0.2])))
118 exp_theta = pyro.sample("exp_theta", dist.Categorical(torch.ones(n_prices) / n_prices))
119 fresh = pyro.sample("fresh", dist.Categorical(torch.ones(n_prices) / n_prices))
120 price = pyro.sample("price", dist.Categorical(price_prior_probs))
121 score = pyro.ops.indexing.Vindex(spk_table)[an_neg, exp_theta, fresh, price]
122 pyro.factor("obs", score)
123 return price
124
125 marg = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0).compute_marginals(prag_model, lambda: None)
126 m = marg["price"]
127 sup = m.enumerate_support()
128 probs = m.log_prob(sup).exp()
129 return float(sum(prices[int(s)].item() * float(pr) for s, pr in zip(sup, probs)))
130
131
132ANSWER = {
133 "expensivePrice": pragmatic_listener_price_expectation("expensive"),
134 "notInexpensivePrice": pragmatic_listener_price_expectation("not-inexpensive"),
135}
136
02 answer overlay — webppl vs pyro · record(expensivePrice, notInexpensivePrice)
expensivePrice
11.6333
notInexpensivePrice
11.3486
03 verification
check | status | evidence
GT self-consistency ok floor 0.0000 (record)
solver re-derivation accept 2/2 solvers · d=[0.000, 0.000] · claude-sonnet-4-6
cross-language (pyro vs webppl) pass d=0.0000 ≤ tol 0.0000 · floors 0.0000/0.0000