Prior-Data Fitted Network

Revised June 11, 2026

The Bayes’ Theorem for Probability Distributions note is optional but recommended background reading.

Given a dataset \(d = \{(\mathbf{x}_i, y_i)\}_{i=1}^n\) and a prior \(p(t)\) over a latent task variable \(t\), the posterior predictive distribution (PPD) gives the Bayesian prediction — the probability of each label \(y\) — for a new input \(\mathbf{x}\):

\[ \begin{equation}\label{eq:ppd} p(y \mid \mathbf{x}, d) = \int p(y \mid \mathbf{x}, t)\, p(t \mid d)\, dt \;. \end{equation} \]

For most non-conjugate prior-likelihood pairs or nonlinear likelihood models, neither the posterior \(p(t \mid d)\) nor the integral in Equation \(\eqref{eq:ppd}\) admits a closed form. In practice, this intractability is usually handled in two steps: first, the posterior \(p(t \mid d)\) is approximated, then the PPD is estimated from that approximation.

This common posterior-first strategy has two structural limitations. First, posterior approximation remains dataset-specific, since each new dataset \(d\) requires a new approximation to \(p(t \mid d)\). Second, the two most common methods — Markov Chain Monte Carlo (MCMC) and variational inference (VI) — require density access. Specifically, MCMC methods such as NUTS construct a Markov chain with a stationary distribution equal to the posterior \(p(t \mid d)\), then draw samples from that chain, thus requiring an evaluable, and for NUTS differentiable, unnormalized posterior \(\tilde{p}(t \mid d) \propto p(d \mid t)p(t)\). VI replaces the posterior with a tractable family \(q_\psi(t)\), optimizing \(\psi\) to minimize \(\mathrm{KL}(q_\psi(t) \,\|\, p(t \mid d))\), thus requiring the joint density \(p(t, d)\). Consequently, neither method can be used directly with a prior available only through simulation, which provides samples but no evaluable density.
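The density requirement can be seen in the sketch below. It runs random-walk Metropolis, which is a simpler MCMC method than NUTS. It needs the same evaluable unnormalized posterior but no gradients. The target is the one-dimensional logistic task used as the running example in the next section. For illustration only, it assumes independent Gaussian priors \(t_0 \sim \mathcal{N}(0,1)\) and \(t_1 \sim \mathcal{N}(0,3^2)\) and a small hand-picked context. The function names are hypothetical. Every accept/reject decision evaluates \(\log p(t) + \log p(d \mid t)\). The PPD is then estimated by averaging \(p(y=1 \mid \mathbf{x}, t)\) over the retained samples.

```python
import numpy as np

PRIOR_SCALE = np.array([1.0, 3.0])  # assumed std of intercept t0 and slope t1

def log_prior(t):
    # Unnormalized log density of the assumed Gaussian prior p(t).
    return -0.5 * np.sum((t / PRIOR_SCALE) ** 2)

def log_likelihood(t, x, y):
    # log p(d | t) with p(y=1 | x, t) = sigmoid(t1 * x + t0), computed stably.
    logits = t[1] * x + t[0]
    return -np.sum(np.logaddexp(0.0, logits) - y * logits)

def metropolis(x, y, num_samples=5000, step=0.5, seed=0):
    # Random-walk Metropolis targeting p(t | d), known only up to a constant.
    rng = np.random.default_rng(seed)
    t = np.zeros(2)
    log_post = log_prior(t) + log_likelihood(t, x, y)
    samples = np.empty((num_samples, 2))
    for i in range(num_samples):
        proposal = t + step * rng.standard_normal(2)
        proposal_log_post = log_prior(proposal) + log_likelihood(proposal, x, y)
        # Symmetric proposal: accept with probability min(1, ratio of unnormalized posteriors).
        if np.log(rng.uniform()) < proposal_log_post - log_post:
            t, log_post = proposal, proposal_log_post
        samples[i] = t
    return samples

def ppd_from_samples(samples, x_query):
    # Monte Carlo estimate of p(y=1 | x, d) = E_{t ~ p(t | d)}[sigmoid(t1 * x + t0)].
    return np.mean(1.0 / (1.0 + np.exp(-(samples[:, 1] * x_query + samples[:, 0]))))

x_ctx = np.array([-2.0, -1.0, 0.5, 1.5, 2.0])
y_ctx = np.array([0, 0, 1, 1, 1])
posterior_samples = metropolis(x_ctx, y_ctx)[1000:]  # drop burn-in
print(ppd_from_samples(posterior_samples, 2.0), ppd_from_samples(posterior_samples, -2.0))
```

Suppose `log_prior` were replaced with a simulator that only returns draws of \(t\). The acceptance step would then have nothing to evaluate, which is the limitation described above.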

Prior-Data Fitted Networks (PFNs) replace posterior approximation with amortized prediction, removing both limitations. During a single training phase, a neural network \(Q_\theta\) is trained on datasets sampled from the prior predictive distribution \(p(d)\) to map a dataset \(d\) and query input \(\mathbf{x}\) directly to a predictive distribution over \(y\). At inference time, \(Q_\theta\) produces this distribution without explicitly representing \(t\) or approximating \(p(t \mid d)\). Thus, PFNs avoid per-dataset posterior inference and require only the ability to sample datasets from \(p(d)\), rather than evaluate any density.

The Learning Objective

Let \(Q_\theta(y \mid \mathbf{x}, d)\) denote a parameterized model mapping a context dataset \(d\) and query input \(\mathbf{x}\) to a predictive distribution over \(y\). The prior predictive distribution over datasets is obtained by first sampling a task \(t \sim p(t)\) and then sampling labeled examples conditional on that task. For a dataset of size \(n\):

\[ \begin{equation*}\label{eq:prior-predictive} p(d) = \int p(t)\prod_{i=1}^n p(\mathbf{x}_i,y_i \mid t)\,dt \;, \end{equation*} \]

so the per-task likelihood is \(p(d \mid t) = \prod_{i=1}^n p(\mathbf{x}_i,y_i \mid t)\). The inputs are assumed not to carry task information, so \(p(\mathbf{x},y \mid t)=p(\mathbf{x})\,p(y \mid \mathbf{x},t)\). Equivalently, under this assumption, sampling \(d \sim p(d)\) means sampling \(t \sim p(t)\), then drawing inputs \(\mathbf{x}_i \sim p(\mathbf{x})\) and labels \(y_i \sim p(y \mid \mathbf{x}_i,t)\) for \(i=1,\ldots,n\).

A training example is formed by sampling a dataset \(d^+ \sim p(d^+)\) from the same prior predictive process with \(n+1\) examples, then holding out one labeled pair \((\mathbf{x}, y) \in d^+\) as the query; the remaining \(n\) examples form the context \(d\). As a simple example, suppose each task \(t=(t_0,t_1)\) specifies the intercept and slope of a one-dimensional binary classification rule. Drawing \(t \sim p(t)\) fixes a particular decision boundary. Conditional on this task, inputs \(x_i\) are sampled independently of \(t\), and labels \(y_i \in \{0,1\}\) are generated according to \(p(y_i=1 \mid x_i,t) = \sigma(t_1 x_i + t_0)\), where \(\sigma\) is the logistic sigmoid. A sampled dataset \(d^+\) might contain \(50\) such labeled points. Holding out one point \((x,y)\) as the query leaves the remaining \(49\) points as the context \(d\), and \(Q_\theta\) is trained to predict the held-out label \(y\) from the query input \(x\) and the context \(d\). This sampling induces a joint distribution \(p(d,\mathbf{x},y)\) over contexts, query inputs, and labels. Under the assumption that inputs carry no task information, the conditional \(p(y \mid \mathbf{x}, d)\) of this joint equals the PPD in Equation \(\eqref{eq:ppd}\).
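The sampling procedure in this example can be sketched as follows. The note does not fix the distributions involved. For illustration only, the sketch assumes the priors \(t_0 \sim \mathcal{N}(0,1)\) and \(t_1 \sim \mathcal{N}(0,3^2)\) and the input distribution \(p(x) = \mathcal{N}(0,1)\). The helper names are hypothetical. The sampled task is discarded once the labels are generated, because the network only ever sees \((d, x)\) and is scored on \(y\).

```python
import numpy as np

PRIOR_SCALE = np.array([1.0, 3.0])  # assumed std of intercept t0 and slope t1

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sample_dataset(rng, num_points):
    # 1) Draw a task t = (t0, t1) from the assumed prior p(t).
    t = PRIOR_SCALE * rng.standard_normal(2)
    # 2) Draw inputs independently of t (assumed p(x) = N(0, 1)).
    x = rng.standard_normal(num_points)
    # 3) Draw labels from p(y = 1 | x, t) = sigmoid(t1 * x + t0).
    y = (rng.uniform(size=num_points) < sigmoid(t[1] * x + t[0])).astype(int)
    return t, x, y

def make_training_example(rng, n=49):
    # Sample d+ with n + 1 points, then hold out one of them as the query.
    _, x, y = sample_dataset(rng, n + 1)   # the task itself is never shown to Q_theta
    q = rng.integers(n + 1)
    keep = np.arange(n + 1) != q
    return (x[keep], y[keep]), x[q], y[q]

rng = np.random.default_rng(0)
(context_x, context_y), query_x, query_y = make_training_example(rng)
print(len(context_x), query_x, query_y)  # 49 context points and one held-out pair
```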

Why the Induced Conditional Equals the PPD

The data-generating process used to form training examples — sample a full dataset from the prior predictive, then designate one of its pairs as the query — defines a joint over the task variable \(t\), the context \(d\), and the query pair \((\mathbf{x}, y)\). A full dataset \(d^+\) is drawn by sampling \(t \sim p(t)\) and then drawing its \(n+1\) labeled pairs i.i.d. from the task distribution \(p(\mathbf{x}, y \mid t)\). Holding out one pair as the query and assigning the rest to \(d\) leaves the context \(d\) and the query pair \((\mathbf{x}, y)\) conditionally independent given \(t\), thus:

\[ \begin{equation*} p(t, d, \mathbf{x}, y) = p(t)\, p(d \mid t)\, p(\mathbf{x}, y \mid t) \;. \end{equation*} \]

Marginalizing the latent \(t\) and applying the chain rule to the query term, \(p(\mathbf{x}, y \mid t) = p(\mathbf{x} \mid t)\, p(y \mid \mathbf{x}, t)\), gives the observable joint:

\[ \begin{equation*} p(d, \mathbf{x}, y) = \int p(t)\, p(d \mid t)\, p(\mathbf{x} \mid t)\, p(y \mid \mathbf{x}, t)\, dt \;. \end{equation*} \]

Dividing this joint by its marginal \(p(d, \mathbf{x})\), obtained by summing out \(y\), gives the target \(p(y \mid \mathbf{x}, d)\):

\[ \begin{align} p(y \mid \mathbf{x}, d) &= \frac{p(d, \mathbf{x}, y)}{p(d, \mathbf{x})} \nonumber \\ &= \frac{\int p(t)\, p(d \mid t)\, p(\mathbf{x} \mid t)\, p(y \mid \mathbf{x}, t)\, dt}{\int p(t)\, p(d \mid t)\, p(\mathbf{x} \mid t)\, dt} && \text{since } \sum_{y'} p(y' \mid \mathbf{x}, t) = 1 \nonumber \\ &= \int p(y \mid \mathbf{x}, t)\, p(t \mid \mathbf{x}, d)\, dt \;, \label{eq:joint-y-marg} \end{align} \]

where \(p(t \mid \mathbf{x}, d) \propto p(t)\, p(d \mid t)\, p(\mathbf{x} \mid t)\) is the posterior over tasks given both the context and the query input.

Equation \(\eqref{eq:joint-y-marg}\) matches Equation \(\eqref{eq:ppd}\) once the query input is read as a conditioning argument rather than as evidence about the task, i.e. \(p(t \mid \mathbf{x}, d) = p(t \mid d)\). The reduction is exact whenever the input marginal carries no task information, \(p(\mathbf{x} \mid t) = p(\mathbf{x})\), since the factor \(p(\mathbf{x} \mid t)\) is then constant in \(t\) and cancels between numerator and denominator:

\[ \begin{equation}\label{eq:induced-ratio} p(y \mid \mathbf{x}, d) = \frac{\int p(t)\, p(d \mid t)\, p(y \mid \mathbf{x}, t)\, dt}{\int p(t)\, p(d \mid t)\, dt} \;. \end{equation} \]

The denominator in Equation \(\eqref{eq:induced-ratio}\) normalizes the task-dependent weights \(p(t)\,p(d \mid t)\) into the posterior:

\[ \begin{equation*} p(t \mid d) = \frac{p(t)\,p(d \mid t)}{\int p(t')\,p(d \mid t')\,dt'} \;. \end{equation*} \]

Therefore:

\[ \begin{equation*}\label{eq:induced-ppd} p(y \mid \mathbf{x}, d) = \int p(y \mid \mathbf{x}, t)\, p(t \mid d)\, dt \;, \end{equation*} \]

which is exactly the posterior predictive of Equation \(\eqref{eq:ppd}\). The held-out construction therefore makes the conditional of the induced joint coincide with the PPD the network is trained to approximate.
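The equality can be checked numerically on a hypothetical toy prior that is small enough to enumerate. The toy has three tasks with prior masses 0.4, 0.4, and 0.2. Inputs lie in \(\{0, 1, 2\}\) and are drawn uniformly and independently of the task. The context holds two points. The sketch first computes the PPD from Equation \(\eqref{eq:induced-ratio}\); here it equals \(0.201/0.346 \approx 0.581\). It then estimates the conditional of the held-out construction. To do this, it simulates \((t, d^+)\), treats the last pair as the query, and keeps only draws whose context and query input match the observed ones.

```python
import numpy as np

# Hypothetical toy prior: three tasks, inputs in {0, 1, 2} drawn uniformly and independently of t.
prior = np.array([0.4, 0.4, 0.2])                  # p(t)
p_y1 = np.array([[0.2, 0.6, 0.9],                  # p(y=1 | x, t); row = task, column = x
                 [0.8, 0.4, 0.1],
                 [0.5, 0.5, 0.5]])
context_x, context_y = np.array([0, 2]), np.array([0, 1])   # observed context d
query_x = 1

def exact_ppd(context_x, context_y, query_x):
    # Equation (induced-ratio): weight each task by p(t) p(d | t), then average p(y=1 | x, t).
    probs = p_y1[:, context_x]
    lik = np.prod(np.where(context_y == 1, probs, 1.0 - probs), axis=1)   # p(d | t)
    weights = prior * lik
    return np.sum(weights * p_y1[:, query_x]) / np.sum(weights)

def simulated_conditional(num_draws=1_000_000, seed=0):
    # Simulate (t, d+), treat the last pair as the held-out query, and keep only draws
    # whose context and query input match the observed ones (exact rejection).
    rng = np.random.default_rng(seed)
    n_plus = len(context_x) + 1
    t = rng.choice(len(prior), size=num_draws, p=prior)
    x = rng.integers(0, 3, size=(num_draws, n_plus))
    y = (rng.uniform(size=(num_draws, n_plus)) < p_y1[t[:, None], x]).astype(int)
    match = (np.all(x[:, :-1] == context_x, axis=1)
             & np.all(y[:, :-1] == context_y, axis=1)
             & (x[:, -1] == query_x))
    return y[match, -1].mean(), int(match.sum())

ppd = exact_ppd(context_x, context_y, query_x)
estimate, kept = simulated_conditional()
print(f"PPD {ppd:.4f}, simulated conditional {estimate:.4f} from {kept} kept draws")
```

Exact-match rejection is feasible here only because the toy input space is tiny. The point of the sketch is the agreement between the two numbers, not the estimator.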

The target \(p(y \mid \mathbf{x}, d)\) fixes a symmetry the model class should respect. Specifically, because the context pairs are drawn i.i.d. given \(t\), the context \(d\) is exchangeable, and permuting its elements leaves the PPD unchanged. PFNs therefore typically realize \(Q_\theta\) as a Transformer that encodes each context pair \((\mathbf{x}_i, y_i)\) as one token and the query \(\mathbf{x}\) as an additional token, with no positional encoding over the context, so that \(Q_\theta(y \mid \mathbf{x}, d)\) is permutation-invariant in \(d\) by construction. The output at the query position parameterizes the predictive distribution over \(y\) (e.g., a softmax over classes for classification).
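The sketch below shows why dropping positional encodings yields permutation invariance. Instead of a full PFN Transformer, it uses one hypothetical, untrained attention readout. Each context pair \((x_i, y_i)\) becomes a token through a random linear embedding. The query token attends over the context tokens, and a linear head gives class probabilities. Shuffling the context leaves the output unchanged up to floating-point rounding. Adding a fixed per-position vector to each token breaks the invariance. All weights and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_classes = 8, 2
# Hypothetical random (untrained) weights; scaled by 1/sqrt(dim) to keep logits moderate.
W_ctx = rng.standard_normal((2, dim))
W_qry = rng.standard_normal((1, dim))
W_k = rng.standard_normal((dim, dim)) / np.sqrt(dim)
W_v = rng.standard_normal((dim, dim)) / np.sqrt(dim)
W_out = rng.standard_normal((dim, num_classes)) / np.sqrt(dim)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def predict(context_x, context_y, query_x, positional=None):
    # One token per context pair (x_i, y_i); no positional encoding by default.
    tokens = np.stack([context_x, context_y], axis=1) @ W_ctx        # (n, dim)
    if positional is not None:
        tokens = tokens + positional                                   # position-dependent tokens
    query = np.array([[query_x]]) @ W_qry                              # (1, dim)
    weights = softmax(query @ (tokens @ W_k).T / np.sqrt(dim))         # (1, n) attention over context
    pooled = weights @ (tokens @ W_v)                                  # (1, dim)
    return softmax(pooled @ W_out)[0]                                  # class probabilities

context_x = rng.standard_normal(5)
context_y = rng.integers(0, 2, size=5)
perm = np.array([4, 2, 0, 3, 1])                                       # a fixed non-identity shuffle
probs = predict(context_x, context_y, 0.3)
probs_shuffled = predict(context_x[perm], context_y[perm], 0.3)
positions = rng.standard_normal((5, dim))
probs_pos = predict(context_x, context_y, 0.3, positions)
probs_pos_shuffled = predict(context_x[perm], context_y[perm], 0.3, positions)
print(np.allclose(probs, probs_shuffled), np.allclose(probs_pos, probs_pos_shuffled))
```

In a full Transformer, the context tokens also attend to each other. Without positional encodings, that step is permutation-equivariant, so the output at the query position remains permutation-invariant.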

PFNs are trained to minimize the Prior-Data Negative Log-Likelihood (Prior-Data NLL), the expected negative log-probability \(Q_\theta\) assigns to the held-out label:

\[ \begin{equation}\label{eq:prior-data-nll} \ell_\theta = \mathbb{E}_{(d,\mathbf{x},y) \sim p(d,\mathbf{x},y)} \left[ -\log Q_\theta(y \mid \mathbf{x}, d) \right] \;. \end{equation} \]

Expanding the expectation gives:

\[ \begin{align} \ell_\theta &= -\int_{d,\mathbf{x}} \sum_y p(d,\mathbf{x},y) \log Q_\theta(y \mid \mathbf{x},d) \nonumber \\ &= -\int_{d,\mathbf{x}} p(d,\mathbf{x}) \sum_y p(y \mid \mathbf{x},d) \log Q_\theta(y \mid \mathbf{x},d) && \text{by the chain rule} \nonumber \\ &= \int_{d,\mathbf{x}} p(d,\mathbf{x})\, H\!\left(p(\cdot \mid \mathbf{x},d),\, Q_\theta(\cdot \mid \mathbf{x},d)\right) && \text{by definition of cross-entropy} \nonumber \\ &= \mathbb{E}_{d,\mathbf{x} \sim p(d,\mathbf{x})} \left[ H\!\left(p(\cdot \mid \mathbf{x},d),\, Q_\theta(\cdot \mid \mathbf{x},d)\right) \right] \label{eq:cross-entropy} \;. \end{align} \]

Thus, minimizing the Prior-Data NLL is equivalent to minimizing the expected cross-entropy between the true PPD \(p(\cdot \mid \mathbf{x},d)\) and the PFN prediction \(Q_\theta(\cdot \mid \mathbf{x},d)\).

The cross-entropy decomposes as:

\[ \begin{equation*} H\!\left(p(\cdot \mid \mathbf{x},d),\, Q_\theta(\cdot \mid \mathbf{x},d)\right) = H\!\left(p(\cdot \mid \mathbf{x},d)\right) + \mathrm{KL}\!\left(p(\cdot \mid \mathbf{x},d) \,\|\, Q_\theta(\cdot \mid \mathbf{x},d)\right) \;. \end{equation*} \]

Substituting this identity into Equation \(\eqref{eq:cross-entropy}\) yields:

\[ \begin{align} \ell_\theta &= \mathbb{E}_{d,\mathbf{x}}\left[ H\!\left(p(\cdot \mid \mathbf{x},d)\right) \right] + \mathbb{E}_{d,\mathbf{x}}\left[ \mathrm{KL}\!\left(p(\cdot \mid \mathbf{x},d) \,\|\, Q_\theta(\cdot \mid \mathbf{x},d)\right) \right] \nonumber \\ &= C + \mathbb{E}_{d,\mathbf{x}}\left[ \mathrm{KL}\!\left(p(\cdot \mid \mathbf{x},d) \,\|\, Q_\theta(\cdot \mid \mathbf{x},d)\right) \right] \label{eq:kl-decomp} \;, \end{align} \]

where \(C = \mathbb{E}_{d,\mathbf{x}}\left[ H\!\left(p(\cdot \mid \mathbf{x},d)\right) \right]\) is independent of \(\theta\).
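For categorical distributions, the cross-entropy decomposition used in this step can be checked directly. In the sketch below, `p` and `q` are arbitrary illustrative vectors. They stand in for \(p(\cdot \mid \mathbf{x},d)\) and \(Q_\theta(\cdot \mid \mathbf{x},d)\) at a single \((\mathbf{x},d)\). All entries are assumed strictly positive so that every logarithm is finite.

```python
import numpy as np

# All functions assume strictly positive probability vectors.
def entropy(p):
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    return -np.sum(p * np.log(q))

def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.7, 0.2, 0.1])   # stands in for the PPD p(. | x, d)
q = np.array([0.5, 0.3, 0.2])   # stands in for the PFN prediction Q_theta(. | x, d)
print(cross_entropy(p, q), entropy(p) + kl(p, q))
```

Only the KL term depends on `q`. The entropy term plays the role of the constant \(C\).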

Minimizing the Prior-Data NLL therefore minimizes the expected KL divergence from the true PPD to the PFN prediction:

\[ \begin{equation*} \arg\min_\theta \ell_\theta = \arg\min_\theta \mathbb{E}_{d,\mathbf{x}}\left[ \mathrm{KL}\!\left(p(\cdot \mid \mathbf{x},d) \,\|\, Q_\theta(\cdot \mid \mathbf{x},d)\right) \right] \;. \end{equation*} \]

If the model class can represent the true PPD and the global minimum is attained, then any global optimum \(\theta^*\) satisfies \(Q_{\theta^*}(\cdot \mid \mathbf{x},d) = p(\cdot \mid \mathbf{x},d)\) for \(p(d,\mathbf{x})\)-almost every \((\mathbf{x},d)\).
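The sketch below illustrates this statement numerically for one fixed \((\mathbf{x},d)\). It assumes a model flexible enough to output any categorical distribution, here free softmax logits. Gradient descent on the cross-entropy \(H(p, Q)\) then drives \(Q\) to \(p\). The target vector, step size, and number of steps are illustrative choices. The update uses the standard gradient of the cross-entropy with respect to the logits, \(\mathrm{softmax}(z) - p\).

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

p = np.array([0.7, 0.2, 0.1])    # target PPD at one (x, d)
logits = np.zeros(3)             # Q starts as the uniform distribution
for _ in range(2000):
    q = softmax(logits)
    # Gradient of H(p, softmax(logits)) with respect to the logits is q - p.
    logits -= 0.5 * (q - p)
q = softmax(logits)
print(np.round(q, 4))
```

The loss is convex in the logits, so the only stationary point is \(Q = p\). A restricted model class would instead stop at the member closest to \(p\) in KL divergence.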

Minimizing \(\ell_\theta\) with stochastic gradient descent is feasible because unbiased estimates of the objective and its gradient require only samples from the prior predictive data-generating process. In each training step, draw a dataset, hold out one labeled point \((\mathbf{x},y)\), and take a gradient step on the negative log-likelihood that \(Q_\theta\) assigns to \(y\) given the remaining examples. No closed-form posterior, evaluable likelihood, or evaluable prior density is required.
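The loop can be sketched end to end on the one-dimensional logistic example. To stay short and dependency-free, the sketch swaps the Transformer for a hypothetical hand-built model, \(Q_\theta(y=1 \mid x, d) = \sigma(\theta^\top \phi(x, d))\). Here \(\phi\) stacks a bias, the query input, and two permutation-invariant context summaries. Each step samples fresh datasets from an assumed prior: \(t_0 \sim \mathcal{N}(0,1)\), \(t_1 \sim \mathcal{N}(0,3^2)\), and \(x \sim \mathcal{N}(0,1)\). It holds out one point per dataset as the query and takes a gradient step on the average negative log-likelihood. At no point does the sketch evaluate a prior density or likelihood; it only calls the sampler.

```python
import numpy as np

PRIOR_SCALE = np.array([1.0, 3.0])   # assumed std of intercept t0 and slope t1

def sample_batch(rng, batch_size, n):
    # Draw batch_size datasets d+ with n + 1 points each from the assumed prior predictive.
    t = PRIOR_SCALE * rng.standard_normal((batch_size, 2))
    x = rng.standard_normal((batch_size, n + 1))
    prob = 1.0 / (1.0 + np.exp(-(t[:, 1:2] * x + t[:, 0:1])))
    y = (rng.uniform(size=x.shape) < prob).astype(float)
    # Points are i.i.d. given t, so holding out the last one is as good as a random one.
    return x[:, :n], y[:, :n], x[:, n], y[:, n]

def features(ctx_x, ctx_y, qx):
    # Hypothetical permutation-invariant summary of (x, d); a PFN learns this with a Transformer.
    signed = 2.0 * ctx_y - 1.0
    s0 = signed.mean(axis=1)                 # label balance, informative about t0
    s1 = (signed * ctx_x).mean(axis=1)       # label-input correlation, informative about t1
    return np.stack([np.ones_like(qx), qx, s0, qx * s1], axis=1)

def nll_and_grad(theta, phi, qy):
    # Average -log Q_theta(y | x, d) for Q(y=1 | x, d) = sigmoid(phi @ theta), and its gradient.
    logits = phi @ theta
    nll = np.mean(np.logaddexp(0.0, logits) - qy * logits)
    grad = phi.T @ (1.0 / (1.0 + np.exp(-logits)) - qy) / len(qy)
    return nll, grad

def train(num_steps=500, batch_size=256, n=20, lr=0.5, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros(4)
    for _ in range(num_steps):
        ctx_x, ctx_y, qx, qy = sample_batch(rng, batch_size, n)   # fresh prior samples every step
        _, grad = nll_and_grad(theta, features(ctx_x, ctx_y, qx), qy)
        theta -= lr * grad
    return theta

def estimate_prior_data_nll(theta, num_datasets=20_000, n=20, seed=1):
    # Monte Carlo estimate of the Prior-Data NLL on fresh prior samples.
    rng = np.random.default_rng(seed)
    ctx_x, ctx_y, qx, qy = sample_batch(rng, num_datasets, n)
    return nll_and_grad(theta, features(ctx_x, ctx_y, qx), qy)[0]

theta = train()
print(estimate_prior_data_nll(np.zeros(4)), estimate_prior_data_nll(theta))
```

The untrained model predicts \(0.5\) everywhere and scores exactly \(\log 2\). Training lowers the estimated Prior-Data NLL. Any permutation-invariant \(Q_\theta\) could replace `features`; the loss itself only ever touches sampled data.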

Two practical refinements generalize this loop. First, each sampled dataset is split into a context and many held-out queries rather than a single pair. Attention is masked so that each query token attends only to the context. One forward pass then yields a loss term for every query, and because the queries cannot see one another, each term has the same form as the single-query loss. Each sampled dataset therefore contributes many terms to the Monte Carlo estimate of Equation \(\eqref{eq:prior-data-nll}\). Second, the context size \(n\) is sampled over a range rather than fixed, so the expectation in Equation \(\eqref{eq:prior-data-nll}\) runs over datasets of varying size. This allows a single trained network to serve contexts of any size within the trained range.
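Both refinements can be sketched with a boolean attention mask and a random split point. None of the following details come from the note; they are assumptions of the sketch:

- a hypothetical token layout with the \(n\) context tokens first and the \(m\) query tokens after them,
- the convention that `mask[i, j]` is True when token \(i\) may attend to token \(j\),
- a single attention head with identity projections.

Under this mask, context tokens attend to each other and each query token attends only to the context. Perturbing one query token therefore leaves every other output unchanged, and each query yields its own loss term.

```python
import numpy as np

def pfn_attention_mask(n_context, n_query):
    # mask[i, j] is True when token i may attend to token j.
    size = n_context + n_query
    mask = np.zeros((size, size), dtype=bool)
    mask[:, :n_context] = True   # context and query tokens all attend to the context only
    return mask

def sample_split(rng, total_points, n_min=1):
    # Sample the context size n over a range; the remaining points become held-out queries.
    n = int(rng.integers(n_min, total_points))   # n in [n_min, total_points - 1]
    return n, total_points - n

def masked_attention(tokens, mask):
    # Single-head attention with identity projections, for illustration only.
    scores = tokens @ tokens.T / np.sqrt(tokens.shape[1])
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ tokens

rng = np.random.default_rng(0)
n, m = sample_split(rng, total_points=10)
mask = pfn_attention_mask(n, m)
tokens = rng.standard_normal((n + m, 4))
out = masked_attention(tokens, mask)
perturbed = tokens.copy()
perturbed[-1] += 1.0                                   # change only the last query token
out_perturbed = masked_attention(perturbed, mask)
print(n, m, np.allclose(out[:-1], out_perturbed[:-1]))
```

The score matrix still has \((n+m)^2\) entries, which is where the quadratic cost of a forward pass comes from.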

Inference as In-Context Bayesian Prediction

At inference, a trained PFN performs in-context prediction. Given a context dataset \(d = \{(\mathbf{x}_i, y_i)\}_{i=1}^n\) and a query input \(\mathbf{x}\), the network returns a predictive distribution over the unknown label \(y\):

\[ \begin{equation*} Q_{\theta^*}(y \mid \mathbf{x}, d) \approx p(y \mid \mathbf{x}, d) \;. \end{equation*} \]

The context \(d\) enters only as input to the network. The parameters \(\theta^*\) stay fixed: no posterior is fit, no model is trained, and no gradient step is taken for the new dataset. Changing \(d\) changes the forward-pass activations, not the weights. By Equation \(\eqref{eq:kl-decomp}\), suppose the model class can represent the true PPD and the global minimum is attained. Then \(Q_{\theta^*}(\cdot \mid \mathbf{x},d)\) matches \(p(\cdot \mid \mathbf{x},d)\), except possibly on a set of context-query pairs \((d,\mathbf{x})\) with probability zero under the PFN training distribution. In practice, Bayesian prediction for a new dataset is therefore approximated by a single forward pass of \(Q_{\theta^*}\).

The single-forward-pass approximation is only as good as the map learned during PFN training. The learned map depends on the expressiveness of \(Q_\theta\), the success of optimization, and the prior predictive distribution from which training datasets were sampled. Thus, a PFN is not a prior-free inference machine — it amortizes Bayesian prediction for the data-generating assumptions encoded by its sampler. The amortization is also bounded computationally. A forward pass through the Transformer scales quadratically in the number of input tokens, and predictions are reliable only for context sizes within the range sampled during training. Specializing PFNs to a domain therefore means choosing both a sampler and an operating range for \(n\); TabPFN makes both choices for small tabular datasets.

References

Transformers Can Do Bayesian Inference, ICLR (2022)

Samuel Müller, Noah Hollmann, Sebastian Pineda Arango, Josif Grabocka, and Frank Hutter

TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second, ICLR (2023)

Noah Hollmann, Samuel Müller, Katharina Eggensperger, and Frank Hutter

Accurate Predictions on Small Data with a Tabular Foundation Model, Nature (2025)

Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin Hoo, Robin Tibor Schirrmeister, and Frank Hutter