RetICL: Sequential Retrieval of In-Context Examples with Reinforcement Learning (2024)

Alexander Scarlatos  Andrew Lan
University of Massachusetts Amherst
{ajscarlatos,andrewlan}@cs.umass.edu

Abstract

Recent developments in large pre-trained language models have enabled unprecedented performance on a variety of downstream tasks. Achieving best performance with these models often leverages in-context learning, where a model performs a (possibly new) task given one or more examples. However, recent work has shown that the choice of examples can have a large impact on task performance and that finding an optimal set of examples is non-trivial. While there are many existing methods for selecting in-context examples, they generally score examples independently, ignoring the dependency between them and the order in which they are provided to the model. In this work, we propose Retrieval for In-Context Learning (RetICL), a learnable method for modeling and optimally selecting examples sequentially for in-context learning. We frame the problem of sequential example selection as a Markov decision process and train an example retriever using reinforcement learning. We evaluate RetICL on math word problem solving and scientific question answering tasks and show that it consistently outperforms or matches heuristic and learnable baselines. We also use case studies to show that RetICL implicitly learns representations of problem solving strategies.


1 Introduction

With the rising prominence of large pre-trained language models (LLMs), prior work has focused on how to best utilize them for various natural language tasks. One of the most popular methods for doing so is prompt tuning, which deals with carefully selecting the natural language prompt that maximizes model performance Liu et al. (2021b). While there are many approaches to prompt tuning, a very successful one is in-context learning (ICL) Brown et al. (2020). In ICL, examples of a new task that the LLM may not have been trained on before are included in the prompt, enabling it to leverage patterns in these examples in a few-shot way. However, the choice of which examples the LLM sees for a particular task can significantly affect the model’s performance Zhao et al. (2021).

The primary goal of ICL example selection is to find examples that, when used in the prompt, elicit a desired response from an LLM. Common practice is to define a function $\phi(x, e_1, \ldots, e_T)$ that measures the quality of a set of examples, where $x$ is the input for the current task and $e_1, \ldots, e_T$ is a list of $T$ examples drawn from a corpus $\mathcal{C}$, and use $\phi$ to rank candidate examples. Most existing works assume that examples work independently of each other, i.e., $\phi(x, e_1, \ldots, e_T) = \prod_{t=1}^{T} \phi'(x, e_t)$. Thus, one can find the best set of examples by selecting the top $T$ examples in the corpus with the highest values of $\phi'(x, e_t)$, which is often the semantic similarity between $x$ and $e_t$. However, there is significant interplay between the roles of different examples in deciding the output of LLMs. Some tasks benefit from example diversity Su et al. (2022), while others benefit from combining specific information across examples Levy et al. (2023). In these cases, simply selecting the top-$T$ ranked examples may neglect ones that are ranked lower on their own but are useful in conjunction with other examples. Additionally, top-$T$ selection ignores the order in which examples are provided as LLM input, which also has an impact on its output Lu et al. (2021). See Section 2 for a more detailed discussion of related work on ICL example selection.
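
To make the independence assumption concrete, the sketch below shows the standard top-$T$ selection baseline (not our method): each corpus example is scored against the input on its own, and the $T$ highest-scoring examples are taken. This is a minimal illustration assuming generic precomputed embeddings; the function name top_t_selection is ours.

```python
import numpy as np


def top_t_selection(x_emb: np.ndarray, corpus_embs: np.ndarray, t: int) -> list:
    """Score each corpus example independently against the input and take the
    top T. Interplay between examples (diversity, complementary content,
    ordering) is ignored by construction."""
    sims = corpus_embs @ x_emb / (
        np.linalg.norm(corpus_embs, axis=1) * np.linalg.norm(x_emb) + 1e-9
    )
    return list(np.argsort(-sims)[:t])
```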

1.1 Contributions

In this work, we propose RetICL (Retrieval for In-Context Learning), a fully learnable method that sequentially retrieves ICL examples by conditioning on both the current problem and the examples that have already been selected. We frame the problem of sequential example selection as a Markov decision process (MDP) and train an example retriever model using reinforcement learning (RL). We construct the model using a recurrent architecture where hidden states act as latent representations of MDP states, and model the example ranking function using a bilinear transformation between the latent and corpus spaces, enabling efficient inference-time maximization of $\phi(x, e_1, \ldots, e_T)$. We also propose a novel confidence reward function, which uses the perplexity of the generated solution to help guide training. We validate RetICL on the math word problem (MWP) solving datasets TabMWP and GSM8K, where it outperforms or matches both heuristic and learnable baselines. Additionally, to test RetICL's ability to generalize across domains, we validate it on the scientific question answering dataset QASC, where RetICL outperforms all baselines. Finally, we qualitatively analyze RetICL's learned policies and find that RetICL is able to implicitly infer problem solving strategies while learning its ICL example selection policy. We intend to make our implementation publicly available.

2 Related Work

In ICL, it is common to either randomly select in-context examples Brown et al. (2020); Lewkowycz et al. (2022) or use a hand-crafted set of examples Hendrycks et al. (2021); Wei et al. (2023). However, it is now well known that example selection and ordering can have a large impact on downstream text generation performance Gao et al. (2020); Liu et al. (2021a); Zhao et al. (2021); Lu et al. (2021). There are many existing methods for in-context example selection that focus on different aspects of the problem. Several seek to maximize how much coverage a set of examples provides over the dataset and use diversity to measure coverage Levy et al. (2023); Su et al. (2022); Pitis et al. (2023). Other methods consider semantic features of the examples Liu et al. (2021a); Fu et al. (2022) or compare these features with outputs of the target LLM Qin et al. (2023). Chang and Jia (2023) use random prompting but with a curated corpus based on how well examples perform on a validation set, and Rubin et al. (2021) use contrastive learning to retrieve examples that are likely to have similar labels to the target. While each of these methods tends to tackle a single aspect of example selection, we distinguish our method by combining several of these aspects, particularly considering groups of examples (including ordering) and using the LLM's outputs for a training signal. There are other works that use RL for ICL example selection Lu et al. (2022); Zhang et al. (2022), although they either do not include previously selected examples in the state or only use high-level features of the examples, while RetICL uses their exact textual content.

3 Methodology

In this section, we detail how we frame ICL example selection as an MDP, how our example retriever model works, and how we train it using RL. We show an overview of our methodology in Figure 1.

[Figure 1: Overview of the RetICL methodology.]

3.1 MDP Formulation and Reward Function

We can view ICL example selection as a sequential decision making problem, where we select examples one at a time in such a way that we maximize our chances of achieving some goal when the examples are used as context. Our goal is to maximize $r(\mathcal{M}(x, e_1, \ldots, e_T), y)$, where $\mathcal{M}(\cdot)$ returns the generated output of an LLM given a prompt, $y$ is the label corresponding to $x$, and $r$ is a task-specific function that returns how good the generated output is. We note that in this setup, the order in which examples are selected matters since the order in which they are provided to $\mathcal{M}$ must be defined. We also note that while in this work we set $T$ to a constant, it can also be dynamically set during the decision-making process, which we leave for future work. With this framing, we can naturally define an MDP where the state at time step $t$ corresponds to both $x$ and the first $t$ examples that have been selected, and the action space is the set of potential candidates to be the next example. Formally,

$$S_0 = x, \quad S_t = x, e_1, \ldots, e_t, \quad A_t = e_{t+1} \in \mathcal{C}.$$

We now define the reward function for the MDP, which we break into two parts: a task-specific goal reward and a supplementary confidence reward. We define the goal reward, $R^G$, simply as the output of $r$, as long as it can be formulated to return a scalar value. In settings with definitive correct and incorrect answers, it is natural for $r$ to be binary, where it returns 1 when the generated solution results in a correct answer and -1 when the generated solution results in an incorrect answer. However, with chain-of-thought (CoT) prompting Wei et al. (2023), this reward function does not account for the reasoning strategy in the solution process that led to the final answer and thus cannot distinguish between sound and flawed logic. To address this issue, we introduce the confidence reward, $R^C$, which we define as the inverse perplexity of the generated solution assigned by the LLM, normalized to the range $[-1, 1]$. We hypothesize that when an LLM generates a correct solution with high probability (low perplexity), it is likely that the model “knew” how to solve the problem, rather than getting it correct by guessing or using unsound reasoning to arrive at a final answer. Additionally, we hypothesize that when an LLM generates an incorrect solution with high probability, it may have sound reasoning overall but contain a small error, such as an incorrect calculation. Recent works have also found that model confidence is predictive of downstream performance Tian et al. (2023); Kadavath et al. (2022). We define the final reward function to be the average of $R^G$ and $R^C$ at the final time step and 0 at all prior time steps. We formally define our reward function as

$$\hat{y} = \mathcal{M}(x, e_1, \ldots, e_T),$$
$$R^G = r(\hat{y}, y) = 2 \cdot \mathbb{I}[g(\hat{y}, y)] - 1,$$
$$R^C = 2 \cdot p_{\mathcal{M}}(\hat{y} \mid x, e_1, \ldots, e_T)^{\frac{1}{|\hat{y}|}} - 1,$$
$$R_t = \begin{cases} 0.5 R^G + 0.5 R^C & \text{if } t = T \\ 0 & \text{otherwise,} \end{cases}$$

where $\hat{y}$ is the generated solution, $g$ is a function that checks if two solutions have the same final answer, $\mathbb{I}$ is the indicator function, and $p_{\mathcal{M}}$ returns the probability assigned by the LLM.
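
As a concrete illustration of the reward above, the following is a minimal sketch of how the terminal reward could be computed from per-token log-probabilities returned by the LLM alongside its generation. The function names terminal_reward and answers_match are illustrative placeholders, not part of our released code.

```python
import math


def terminal_reward(token_logprobs: list, generated: str, label: str,
                    answers_match) -> float:
    """Compute R_T = 0.5 * R^G + 0.5 * R^C for a completed episode.

    token_logprobs: per-token log-probabilities of the generated solution.
    answers_match: task-specific function g(y_hat, y) comparing final answers.
    """
    # Goal reward: +1 if the final answer is correct, -1 otherwise.
    r_goal = 1.0 if answers_match(generated, label) else -1.0

    # Confidence reward: inverse perplexity p(y_hat | prompt)^(1/|y_hat|),
    # rescaled from [0, 1] to [-1, 1].
    avg_logprob = sum(token_logprobs) / max(len(token_logprobs), 1)
    r_conf = 2.0 * math.exp(avg_logprob) - 1.0

    return 0.5 * r_goal + 0.5 * r_conf
```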

3.2 Retriever Model

We now detail our model for example retrieval. At a high level, the model constructs a latent representation for each state $S_t$ in the MDP and uses this representation to construct the policy $\pi(S_t, e)$, which represents the probability of selecting $e$ to be the next example. After using the policy to select an example, we add it to the current sequence of examples, giving us the state $S_{t+1}$, and the process continues sequentially.

We use a long short-term memory (LSTM) model Hochreiter and Schmidhuber (1997) as the base model, where the hidden state $\mathbf{h}_t$ acts as the latent representation for $S_t$. We construct the initial hidden state of the LSTM, $\mathbf{h}_0$, using a vectorized embedding of the input $x$, and set the input of the LSTM at time step $t$ to be a vectorized embedding of the example $e_t$. In this work, we construct these vectorized embeddings using a pre-trained S-BERT model Reimers and Gurevych (2019) and additionally provide learnable soft prompts Lester et al. (2021) to S-BERT to help align the embeddings with the current task. In our experiments, we found that fine-tuning the S-BERT parameters directly did not improve performance.

We produce the policy $\pi(S_t, e)$ by first producing an unnormalized activation value for each example in the corpus, $\phi(S_t, e)$, and then using the softmax function to convert these activations into a probability distribution. We construct each $\phi(S_t, e)$ by performing a learnable bilinear transformation between $\mathbf{h}_t$ and the vectorized embedding of $e$. We choose to model $\phi$ using a bilinear transformation for two reasons. First, the bilinear transformation learns a mapping between the model's latent space and the example embedding space, enabling generalization to examples not seen during training and also adding some model interpretability, as we will show later in this paper. Second, the bilinear transformation enables efficient computation of the policy over a large corpus at inference time, which we describe in detail in Supplementary Material C. We additionally use $\mathbf{h}_t$ to produce an estimate of the value function, $\hat{v}(S_t)$, which is required for variance reduction techniques when training policy gradient methods. Concretely, our model architecture is defined as

$$\mathbf{x} = \operatorname{S-BERT}(\mathbf{P}_x, x), \quad \mathbf{e} = \operatorname{S-BERT}(\mathbf{P}_e, e),$$
$$\mathbf{h}_0 = \tanh(\mathbf{W}_x \mathbf{x} + \mathbf{b}_x),$$
$$\mathbf{h}_{t>0} = \operatorname{LSTM}(\mathbf{h}_0; \mathbf{e}_1, \ldots, \mathbf{e}_t),$$
$$\hat{v}(S_t) = \mathbf{h}_t^T \mathbf{w}_v + b_v,$$
$$\phi(S_t, e) = \begin{cases} \mathbf{h}_t^T \mathbf{W}_a \mathbf{e} & \text{if } e \notin \{e_1, \ldots, e_t\} \\ -\infty & \text{otherwise,} \end{cases}$$
$$\pi(S_t, e) = \exp(\phi(S_t, e)) \,/\, \textstyle\sum_{e' \in \mathcal{C}} \exp(\phi(S_t, e')),$$

where $\mathbf{P}_x \in \mathbb{R}^{m \times d_e}$ and $\mathbf{P}_e \in \mathbb{R}^{m \times d_e}$ are learnable soft prompts, $\mathbf{W}_x \in \mathbb{R}^{d_h \times d_e}$ and $\mathbf{b}_x \in \mathbb{R}^{d_h}$ transform the input embedding space into the latent space, $\mathbf{w}_v \in \mathbb{R}^{d_h}$ and $b_v \in \mathbb{R}$ produce the value function estimate from the latent space, $\mathbf{W}_a \in \mathbb{R}^{d_h \times d_e}$ performs the bilinear transformation between the latent space and example embedding space, $m$ is the soft prompt length, $d_e$ is the dimension of the S-BERT text embedding vector, and $d_h$ is the size of the LSTM's hidden states. We set $\phi(S_t, e)$ to $-\infty$ when $e$ has already been selected to avoid selecting the same example multiple times, which is in line with existing methods.
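
The following is a condensed PyTorch sketch of the retriever described above, assuming pre-computed S-BERT embeddings. Soft prompts and some initialization details are omitted, and the class name RetICLRetriever is ours; this is a sketch rather than our exact implementation.

```python
import torch
import torch.nn as nn


class RetICLRetriever(nn.Module):
    """Sketch of the retriever: an LSTM over selected-example embeddings with a
    bilinear policy head and a linear value head (soft prompts omitted)."""

    def __init__(self, d_emb: int, d_hidden: int):
        super().__init__()
        self.init_proj = nn.Linear(d_emb, d_hidden)                 # W_x, b_x
        self.lstm = nn.LSTM(d_emb, d_hidden, batch_first=True)
        self.bilinear = nn.Parameter(torch.empty(d_hidden, d_emb))  # W_a
        nn.init.orthogonal_(self.bilinear)
        self.value_head = nn.Linear(d_hidden, 1)                    # w_v, b_v

    def forward(self, x_emb, selected_embs, corpus_embs, selected_mask):
        """
        x_emb:         (B, d_emb)    S-BERT embedding of the current problem
        selected_embs: (B, t, d_emb) embeddings of examples chosen so far (t may be 0)
        corpus_embs:   (C, d_emb)    embeddings of all candidate examples
        selected_mask: (B, C) bool   True where an example was already selected
        """
        h0 = torch.tanh(self.init_proj(x_emb))          # latent state for S_0
        if selected_embs.size(1) > 0:
            # Run the LSTM over the selected examples, starting from h_0.
            _, (h_t, _) = self.lstm(
                selected_embs,
                (h0.unsqueeze(0), torch.zeros_like(h0).unsqueeze(0)),
            )
            h_t = h_t.squeeze(0)                        # latent state for S_t
        else:
            h_t = h0

        # Bilinear activations phi(S_t, e) = h_t^T W_a e for every candidate,
        # with already-selected examples masked out.
        activations = h_t @ self.bilinear @ corpus_embs.T             # (B, C)
        activations = activations.masked_fill(selected_mask, float("-inf"))
        policy = torch.softmax(activations, dim=-1)                   # pi(S_t, .)
        value = self.value_head(h_t).squeeze(-1)                      # v_hat(S_t)
        return policy, value
```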

3.3 Training and Inference

We train the retriever model using proximal policy optimization (PPO) Schulman et al. (2017) with generalized advantage estimation (GAE) Schulman et al. (2015) as our advantage function. We use a reward discount of $\gamma = 1$ since all episodes have fixed length and the reward is assigned only at the final time step. We train the value function estimator using mean squared error (MSE) with $R_T$ as the target at each time step and weigh the value function loss with a hyperparameter $c_\text{VF}$. We also encourage exploration by adding the negative entropy of the policy at each time step to the loss Ahmed et al. (2019), where we additionally weigh the entropy by a hyperparameter $c_\text{E}$ and normalize by a factor of $\frac{1}{\log(|\mathcal{C}|)}$ to account for training with different corpus sizes.
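
A sketch of the combined loss is shown below, assuming that advantages and returns have already been computed with GAE and that per-step log-probabilities and entropies have been gathered from rollouts. The function name reticl_loss and its argument names are ours; the default coefficient values follow Appendix A.

```python
import math

import torch
import torch.nn.functional as F


def reticl_loss(new_logprobs, old_logprobs, advantages, values, returns,
                entropy, corpus_size, eps=0.1, c_vf=0.5, c_e=0.05):
    """Combined PPO-clip, value-function, and entropy loss for one batch.

    new_logprobs / old_logprobs: log pi(e_{t+1} | S_t) under the current and
        behavior policies for each selection step in the batch.
    advantages / returns: precomputed with GAE (gamma = 1, terminal-only reward).
    entropy: policy entropy at each selection step.
    """
    # PPO clipped surrogate objective.
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()

    # Value head regresses toward the (terminal) return at every time step.
    value_loss = F.mse_loss(values, returns)

    # Negative entropy encourages exploration; dividing by log|C| keeps the
    # coefficient comparable across different corpus sizes.
    entropy_loss = -entropy.mean() / math.log(corpus_size)

    return policy_loss + c_vf * value_loss + c_e * entropy_loss
```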

At training time, we select a batch of problems from the dataset, encode their inputs with S-BERT, encode all examples in the corpus with S-BERT, and then construct a sequence of examples for each problem by sequentially sampling from the policy, i.e., $e_{t+1} \sim \pi(S_t, \cdot)$. When $T$ examples have been selected for each problem, we prompt the LLM with the examples and the current problem's input, calculate the reward from the LLM's generations, average the PPO loss, value function loss, and entropy loss over the batch, and backpropagate through our model. At inference time, we greedily select examples from the policy as $e_{t+1} = \operatorname{argmax}_{e \in \mathcal{C}} \pi(S_t, e)$, which empirically outperforms sampling in our experiments.
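
Greedy inference can then be sketched as a simple rollout over the policy, building on the RetICLRetriever sketch above; again, this is an illustrative sketch rather than our exact implementation.

```python
import torch


@torch.no_grad()
def greedy_select(retriever, x_emb, corpus_embs, t_max=2):
    """Greedy rollout at inference time: pick argmax_e pi(S_t, e) at each step."""
    d_emb = corpus_embs.size(-1)
    selected_idx = []
    selected_embs = torch.empty(1, 0, d_emb)                       # no examples yet
    mask = torch.zeros(1, corpus_embs.size(0), dtype=torch.bool)   # nothing selected
    for _ in range(t_max):
        policy, _ = retriever(x_emb.unsqueeze(0), selected_embs, corpus_embs, mask)
        idx = int(policy.argmax(dim=-1))
        selected_idx.append(idx)
        mask[0, idx] = True  # never select the same example twice
        selected_embs = torch.cat(
            [selected_embs, corpus_embs[idx].view(1, 1, -1)], dim=1
        )
    return selected_idx  # corpus indices in the order they appear in the prompt
```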

4 Experiments

In this section, we validate RetICL on MWP solving and scientific question answering tasks and quantitatively compare its performance against several baselines. We also perform an ablation study and examine the effects of adjusting several parameters in order to determine which aspects of the methodology work well and which ones need improvement in future work.

4.1 Datasets

We validate RetICL on two MWP datasets that contain detailed solution steps: TabMWP Lu et al. (2022), where solving each problem requires extracting and reasoning with information from tables, and GSM8K Cobbe et al. (2021), where solving each problem requires multi-step mathematical reasoning and applying various arithmetic operations. In order to test whether RetICL generalizes beyond MWP solving, we also validate on QASC Khot et al. (2019), which contains multiple choice science questions along with explanations for each answer, where answering each question requires logical deduction and leveraging world knowledge.

For all datasets, we use CoT prompting and evaluate correctness based on the final answer of the generated solution. We note that the official TabMWP evaluation code has some issues with regular expressions that cause both false positives and false negatives when evaluating correctness on multiple choice problems. We instead use our own, fixed code to evaluate correctness on TabMWP. We provide additional information on each dataset and the prompts we use in Supplementary Material B.

4.2 Experimental Settings

For each dataset, we randomly select 5,000 problems for training and an additional 200 problems for the corpus, both from the original training set. While it is possible to use a larger corpus, e.g., all remaining problems in the training set, we find that training on a smaller corpus results in higher accuracy. We randomly select 500 problems from the validation set to evaluate on after each epoch and save the model with the highest accuracy on this set. For validation and testing, we use the entire training set as the corpus. We use OpenAI's code-davinci-002 Codex model Chen et al. (2021) as the LLM for all experiments since it is free and works well on our task, and we leave experimentation on larger, more expensive models for future work. We set the number of in-context examples to $T = 2$ for all experiments to enable a fair comparison to other methods Lu et al. (2022). We note that while increasing $T$ tends to increase performance for heuristic methods, we find that RetICL's performance does not tend to increase with $T$ as expected. We provide details on the impact of $T$ in Supplementary Material E and leave further exploration of increasing $T$ for future work. We provide additional hyperparameter settings and implementation details in Supplementary Material A.

4.3 Baselines

We compare RetICL to the following baselines for in-context example selection, comprising both heuristic-based and learnable methods. We also estimate a performance upper bound by exhaustively checking all possible example combinations.

Random: With random selection, for each problem, we randomly sample $T$ unique examples from the corpus for the ICL prompt. We evaluate random selection on 3 random seeds and report the average accuracy across all 3 runs.

kNN: With kNN selection Liu et al. (2021a), for each problem, we select the $T$ examples with the most similar problem inputs from the corpus and use those for the ICL prompt, putting more similar examples later in the prompt due to recency bias Zhao et al. (2021). We evaluate similarity according to the Euclidean distance between the S-BERT embeddings of the problem inputs, using the same pre-trained S-BERT model as RetICL.
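
A minimal sketch of this baseline is shown below, assuming the sentence-transformers package and the all-distilroberta-v1 model listed in Supplementary Material A; the function name knn_select is ours, and re-encoding the corpus on every call is for brevity only.

```python
import numpy as np
from sentence_transformers import SentenceTransformer


def knn_select(problem: str, corpus_inputs: list, t: int = 2) -> list:
    """Select the t nearest corpus examples by S-BERT Euclidean distance,
    ordering them so the most similar example appears last in the prompt."""
    encoder = SentenceTransformer("all-distilroberta-v1")
    x = encoder.encode([problem])[0]
    corpus = encoder.encode(corpus_inputs)          # in practice, precompute this
    dists = np.linalg.norm(corpus - x, axis=1)
    nearest = np.argsort(dists)[:t]                 # most similar first
    return [int(i) for i in nearest[::-1]]          # reverse: most similar last
```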

Complexity: With complexity-based selection Fu et al. (2022), for each problem, we randomly select $T$ examples that have the most complex reasoning in the label for the ICL prompt. For TabMWP and GSM8K, we define complexity as the number of steps in the solution, and for QASC we define it as the number of words in the label.
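
A rough sketch of this heuristic is shown below; splitting the label on newlines to count solution steps and sampling from a fixed-size pool of the most complex examples are our illustrative assumptions, since the exact step-counting and tie-breaking rules are implementation details.

```python
import random


def complexity_select(corpus_labels: list, t: int = 2, pool_size: int = 50) -> list:
    """Pick t examples at random from among the most 'complex' labels,
    measuring complexity as the number of solution steps (lines)."""
    steps = [len(label.strip().splitlines()) for label in corpus_labels]
    ranked = sorted(range(len(corpus_labels)), key=lambda i: -steps[i])
    pool = ranked[:pool_size]
    return random.sample(pool, min(t, len(pool)))
```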

Table 1: Problem solving accuracy (Acc.) and number of unique in-context examples used (Ex.) for each method on each dataset.

Method                       | TabMWP         | GSM8K         | QASC
                             | Acc.    Ex.    | Acc.    Ex.   | Acc.    Ex.
Exhaustive                   | 98.30   37     | 97.95   47    | 98.49   36
Random                       | 72.04   11,203 | 57.19   2,153 | 70.41   1,635
kNN Liu et al. (2021a)       | 88.95   10,003 | 59.74   1,883 | 61.99   964
Complexity Fu et al. (2022)  | 63.80   281    | 54.66   13    | 74.19   3
PromptPG Lu et al. (2022)    | 73.43   7      | 56.94   8     | 73.65   2
LSTM Classifier              | 77.21   13     | 64.82   4     | 69.65   8
RetICL                       | 88.58   407    | 66.11   97    | 76.13   135

PromptPG: With PromptPG Lu et al. (2022), for each problem, a learned scoring function is evaluated on each individual example in the corpus, and the top $T$ scoring examples are selected for the ICL prompt. While PromptPG also uses RL to learn its example selection policy, there are many key differences between their method and ours: they do not include previously selected examples in the state, they do not use a confidence-based reward, and they use a much simpler RL framework. We evaluate PromptPG's performance by running their code with modifications to match our prompting style and use our fixed evaluation code.

LSTM Classifier: In order to determine the effectiveness of RetICL's training pipeline, we train a model with a similar architecture but in a supervised manner. Specifically, using the same input format as RetICL, we train an LSTM classifier to predict, at each time step, if the prompt will result in a correct or incorrect response from an LLM. We train on 20 randomly sampled prompts for each sample in the training set, with the same training set and corpus as RetICL. At inference time, we greedily select examples that result in the highest predicted likelihood of getting a correct response. We note that this setup is equivalent to estimating the $Q$-function of a random policy, and provide further details in Supplementary Material D.

Exhaustive: With exhaustive evaluation, for each problem, we construct a one-shot ICL prompt for each example in the corpus, and consider the current problem to be solved if a correct solution is generated from any of the prompts. We use one-shot prompts instead of few-shot prompts to reduce the search space. Additionally, we restrict the corpus size to 100 and only evaluate on the pre-defined 1,000-sample subset of the test set for TabMWP to reduce computation time.
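
A sketch of this upper-bound evaluation is shown below; generate and is_correct are placeholders for the LLM call and the task-specific answer check, and the prompt format is simplified.

```python
def exhaustive_solved(problem: str, corpus: list, generate, is_correct) -> bool:
    """Upper-bound check: the problem counts as solved if any one-shot prompt
    built from a single corpus example yields a correct generation."""
    for example in corpus:
        prompt = f"{example}\n\n{problem}"
        if is_correct(generate(prompt)):
            return True
    return False
```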

4.4 Results

Table 1 shows the performance of RetICL and baselines on all datasets, with Acc. being problem solving accuracy and Ex. being the number of unique examples used for each method. We see that RetICL performs the best among non-exhaustive methods on all datasets, beating most baselines by a large margin, with the exception of TabMWP, where kNN performs slightly better. We note that while the relative performance of baselines seems to depend on the dataset, RetICL performs well across datasets, showing that our methodology is more generalizable across tasks.

We note that kNN likely performs well on TabMWP due to the presence of many problems in the dataset with very high similarity. Many problems have exactly the same question text other than a few numbers or names changed, making it easy for the LLM to generate a correct solution given a highly similar example. In contrast, GSM8K does not tend to contain problems that are almost identical, which makes kNN ineffective since problems with high textual similarity may not have similar solution strategies. Additionally, kNN performs worse than Random on QASC since we find that using examples that are highly similar to the current question can cause the LLM to copy answers from the examples and ignore the details of the current question. Furthermore, we see that Complexity is a poor heuristic for TabMWP since it does not account for problem similarity, and also for GSM8K since it uses some examples with overly abstract reasoning that appear to confuse the model.

Perhaps surprisingly, we see that PromptPG is only slightly better than Random on TabMWP and performs on par with Random on GSM8K. And while PromptPG is better than Random and kNN on QASC, we observe that it selects the same examples for each question, thus lacking example diversity. While the results on TabMWP contradict the trends reported in Lu et al. (2022), we believe the discrepancy is mostly due to using the fixed evaluation code. We believe that PromptPG's relatively low performance also highlights the challenges of solving the ICL example selection problem using RL; in practice, many training tricks are necessary to achieve high performance.

We see that the LSTM Classifier performs relatively well on GSM8K but not on other datasets. We believe the reason is that TabMWP and QASC require more targeted example selection strategies, which are difficult to extract from random example prompts due to sparse coverage of the exponentially large action space. This result highlights the benefit of on-policy learning for example selection. We also see that the Exhaustive method achieves almost perfect accuracy on all datasets. We find this surprising, especially due to the fact that Exhaustive only uses a single ICL example and has access to a smaller corpus. This result implies that there is significant room for growth in ICL example selection methods and also implies that one-shot ICL has the potential to be extremely powerful as long as the example corpus is informative, even for challenging text generation tasks.

In order to determine how well RetICL can generalize to low-resource settings, we examine the effect of reducing the number of available examples at test time. We evaluate RetICL and kNN on TabMWP and GSM8K, where we use 0.1%, 1%, 10%, and 100% of all examples as the corpus, and show our results in Figure 2. For a clearer visualization, we show the relative accuracy of each method compared to RetICL's accuracy when the full corpus is available. For TabMWP, we evaluate on the pre-defined 1,000-sample subset of the test set. We see that while performance tends to decrease with fewer available examples, RetICL still retains at least 90% of its performance with only 0.1% of examples available, suggesting that it is still effective in low-resource settings. Additionally, we see that while RetICL and kNN perform similarly at most corpus sizes, RetICL performs better at 0.1%, likely because there are not enough highly similar examples for kNN to leverage. Finally, we note that the irregular trend for GSM8K with kNN implies that semantic similarity of examples is not correlated with problem solving accuracy on this dataset, suggesting that while kNN is very effective for some tasks, it is not as robust as RetICL.

[Figure 2: Relative accuracy of RetICL and kNN on TabMWP and GSM8K as the size of the test-time example corpus varies.]

4.5 Ablation study

We now examine the impact of various modeling and algorithmic choices via an ablation study. We train on 1,000 problems instead of 5,000 ($\text{T}_{1\text{k}}$); we no longer use the confidence reward, $R^C$, and instead only use the goal reward, $R^G$ (Conf. Rew.); we no longer condition on previously selected examples by removing the LSTM architecture and instead set the latent state for all time steps to be $\mathbf{h}_0$ (LSTM); we no longer include an entropy term in the loss function (Ent.); we no longer train with PPO and instead use REINFORCE (R) or REINFORCE with Baseline (RwB); we no longer provide learnable soft prompts to the S-BERT encoder (SP); and we vary the size of the corpus at train time, using a corpus with 20 problems ($\text{TC}_{20}$) from the training set or all remaining problems ($\text{TC}_{\text{all}}$) from the training set. Additionally, we use $\text{T}_{1\text{k}}$ for all ablations for fast experimentation, and apply the SP ablation for the TC ablations since otherwise the partial gradients over S-BERT parameters would cause memory issues for $\text{TC}_{\text{all}}$.

Table 2 shows the results of the ablation study, which we run on both TabMWP and GSM8K. For TabMWP, we evaluate on the pre-defined 1,000-sample subset of the test set. We see that removing the confidence reward and the LSTM architecture, which are our key contributions, both negatively impact accuracy, implying that these techniques have a positive impact on ICL example selection. We also see that removing the entropy term has a significant negative impact on accuracy and example diversity, implying that this term is necessary for training the model. Furthermore, we see that REINFORCE with Baseline is similar to or slightly worse than PPO, while REINFORCE is significantly worse. Surprisingly, we see that removing soft prompts slightly increases accuracy, although it significantly hurts example diversity on GSM8K due to training instability. We hypothesize that soft prompts may make it harder to find an optimal policy due to increased model complexity, and plan on finding better ways to fine-tune S-BERT in future work. Finally, we see that using much smaller and much larger example corpora at train time both negatively impact accuracy, showing that corpus size is an important hyperparameter and confirming similar results in Lu et al. (2022).

Table 2: Ablation study results, reporting accuracy (Acc.) and number of unique examples used (Ex.) on TabMWP and GSM8K.

Ablation                        | TabMWP        | GSM8K
                                | Acc.    Ex.   | Acc.    Ex.
None                            | 88.30   226   | 66.11   97
T_1k                            | 87.30   159   | 65.96   34
T_1k, no Conf. Rew.             | 86.00   84    | 64.67   20
T_1k, no LSTM                   | 85.10   120   | 63.91   38
T_1k, no Ent.                   | 79.90   15    | 62.77   6
T_1k, R                         | 74.90   6     | 61.94   14
T_1k, RwB                       | 85.30   104   | 66.19   5
T_1k, no SP                     | 88.40   163   | 66.26   3
T_1k, no SP, TC_20              | 84.50   76    | 61.87   58
T_1k, no SP, TC_all             | 85.50   122   | 65.88   58

5 Qualitative Analysis

We now present several qualitative analyses to interpret RetICL’s learned example selection policy. Our goal is to determine what features RetICL focuses on in individual examples and what strategy RetICL uses to select examples sequentially. We investigate these strategies by first visualizing learned latent example embeddings and then analyzing trends in per-problem example selections.

5.1 Latent Space Analysis

In order to identify features in the selected examples that are being emphasized by RetICL, we perform a visual analysis of the example embeddings in the model's latent space. Specifically, we transform each example embedding $\mathbf{e}$ into the model's latent space using the right half of the bilinear term in $\phi$, i.e., $\mathbf{W}_a \mathbf{e}$. We note that since maximizing $\phi(S_t, e)$ is equivalent to maximizing the inner product $\langle \mathbf{h}_t, \mathbf{W}_a \mathbf{e} \rangle$, the most likely example to be selected is the one where $\mathbf{W}_a \mathbf{e}$ is closest to $\mathbf{h}_t$ in the latent space (maximum inner product and minimum distance are equivalent in our case since $\mathbf{e}$ is normalized). Therefore, analyzing local regions in the example embedding space reveals how RetICL's learned policy ranks examples.

For both TabMWP and GSM8K, we randomly select 1,000 examples from the corpus and then apply t-SNE Van der Maaten and Hinton (2008) to reduce their embeddings to 2 dimensions for visualization. Additionally, for the same sets of examples, we also visualize their pre-trained S-BERT embeddings in the same way in order to demonstrate how inter-example similarities change after RetICL training. We summarize our findings here and show the visualizations with a more detailed analysis in Supplementary Material F.
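
A sketch of this analysis is shown below, assuming the trained bilinear weight $\mathbf{W}_a$ and pre-computed example embeddings are available as NumPy arrays; scikit-learn's TSNE is used as noted in Supplementary Material A, and the function names are ours.

```python
import numpy as np
from sklearn.manifold import TSNE


def latent_example_points(example_embs: np.ndarray, w_a: np.ndarray) -> np.ndarray:
    """Map example embeddings e into the retriever's latent space via W_a e,
    so that proximity to the state representation h_t corresponds to a high
    policy activation phi(S_t, e)."""
    return example_embs @ w_a.T        # (N, d_e) x (d_e, d_h) -> (N, d_h)


def tsne_2d(points: np.ndarray, seed: int = 0) -> np.ndarray:
    """Reduce latent points to 2D for visualization."""
    return TSNE(n_components=2, random_state=seed).fit_transform(points)
```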

For TabMWP, the S-BERT embeddings are clustered based on problem template, since problems tend to fall into distinct structural and semantic categories. While RetICL mostly retains these clusters in the latent space, it also merges together some clusters of problems that can be solved with similar reasoning strategies, such as finding the largest or smallest value in a set. For GSM8K, the clusters are much less clear since problems cannot be easily placed into categories. However, while the S-BERT embeddings are generally grouped by problem topic, RetICL also groups them by the number of steps in a solution. This result implies that RetICL has learned that solution length is an important feature of ICL examples, and confirms findings from prior work that solution complexity impacts problem solving in LLMs Fu et al. (2022). These findings suggest that RetICL can identify meaningful features, often related to solution strategy, that indicate an example's utility in ICL.

5.2 Per-Problem Example Selection

We now examine example selections at the per-problem level in order to gain further insight into RetICL’s learned example selection policy.

Table 5 in Supplementary Material G shows the in-context examples selected to help solve a representative problem from the GSM8K dataset. We see that RetICL tends to select examples that share some unique high-level features with the current problem, such as subtracting from a total value, adding up values over some period of time, or defining variables to be proportional to other variables. We note that each problem in the dataset exhibits several such features, so RetICL has to implicitly decide which features are important to the current problem and identify examples with those features. We also see that RetICL tends to select examples with solutions that are relatively long and have numerous reasoning steps. Problem solving errors can be divided into several categories. First, the LLM can exhibit misconceptions when it lacks an example to provide context, such as misinterpreting the meaning of a “discount” when not explicitly instructed. Second, the LLM can try to follow the examples too closely and use reasoning that does not necessarily apply to the current problem. These errors indicate that RetICL's policy can be improved by selecting based on a broader and more targeted set of features. However, many incorrect solutions are caused by simple arithmetic errors or switching the roles of variables in the problem. We believe these errors are due to limitations of the LLM and are a likely source of noise in the training signal, making it harder to find an optimal policy. Such errors can be fixed by using self-consistency Wang et al. (2023) or external computation engines Wolfram (2023), which we leave for future work.

Table 6 shows in-context examples selected for a problem from the TabMWP dataset. RetICL's selections tend to follow a surprising pattern: the first example is seemingly unrelated to the current problem, while the second example has similar reasoning steps to the current problem. This strategy has several implications. First, it suggests that RetICL can infer reasoning steps from the current problem and select examples based on this information. Second, it suggests that diversity of examples may help prevent overfitting to particular features in the examples. Problem solving errors tend to be caused by either RetICL selecting an unrelated second example or the second example being similar to the current problem but requiring a slightly different solution strategy. Therefore, RetICL could benefit from policy consistency regularization and improvements to the retriever model to learn more accurate solution strategy representations.

6 Conclusions and Future Work

In this work, we proposed RetICL, a learnable method for sequential in-context example selection that, unlike existing methods that select all examples independently, takes previously selected examples into account. We framed the problem of sequential example selection as a Markov decision process and developed a novel reward function and example retriever model. We demonstrated that RetICL learns effective strategies for example selection and consistently performs well across math word problem and question answering tasks. There are many avenues for future work. First, we can explore RetICL’s effectiveness on a wider array of tasks, including ones with open-ended goals. Second, we can explore other architectural modifications that could further improve the retriever model, such as using a Transformer instead of an LSTM. Third, since we used a fixed number of examples, we can extend RetICL to let it learn how many examples are needed. Fourth, we can explore whether RetICL can be applied to real-world educational settings, e.g., selecting worked examples to help students solve practice problems.

Limitations

We note that there are several practical limitations to our method. First, we note that RetICL can be expensive and time-consuming to train, with each of our main training runs requiring up to 250,000 LLM inferences. This high number of inferences makes training on paid models prohibitively expensive; for example, it could cost up to approximately $2,500 to train on OpenAI’s text-davinci models. Additionally, newer OpenAI models, such as gpt-3.5-turbo and gpt-4, do not return likelihood information on generated text, making the confidence reward impossible to calculate for these models. We finally note that RetICL’s performance not increasing with the number of examples in the prompt may limit its use in practical settings.

Ethical Considerations

We first note that the high number of inferences required to train RetICL gives the method an outsized cost in terms of energy usage; however, we note that the method has a relatively low cost at inference time given its relatively low number of parameters and potential for optimization with maximum inner product search (MIPS). Additionally, we note that because RetICL uses a black-box LLM reward signal, its example selections are not guaranteed to be interpretable by humans. Finally, because we only experiment with question answering settings, we did not perform any analysis of bias in RetICL's selections. However, it is possible that RetICL could reflect biases in the LLM it is being trained on. As such, we recommend an analysis of bias in future works that use RetICL in sensitive settings such as student-facing educational tools.

References

  • Ahmed et al. (2019) Zafarali Ahmed, Nicolas Le Roux, Mohammad Norouzi, and Dale Schuurmans. 2019. Understanding the impact of entropy on policy optimization. In International Conference on Machine Learning, pages 151–160. PMLR.
  • Brown et al. (2020) Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. 2020. Language models are few-shot learners.
  • Chang and Jia (2023) Ting-Yun Chang and Robin Jia. 2023. Data curation alone can stabilize in-context learning. In Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 8123–8144, Toronto, Canada. Association for Computational Linguistics.
  • Chen et al. (2021) Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, Alex Ray, Raul Puri, Gretchen Krueger, Michael Petrov, Heidy Khlaaf, Girish Sastry, Pamela Mishkin, Brooke Chan, Scott Gray, Nick Ryder, Mikhail Pavlov, Alethea Power, Lukasz Kaiser, Mohammad Bavarian, Clemens Winter, Philippe Tillet, Felipe Petroski Such, Dave Cummings, Matthias Plappert, Fotios Chantzis, Elizabeth Barnes, Ariel Herbert-Voss, William Hebgen Guss, Alex Nichol, Alex Paino, Nikolas Tezak, Jie Tang, Igor Babuschkin, Suchir Balaji, Shantanu Jain, William Saunders, Christopher Hesse, Andrew N. Carr, Jan Leike, Josh Achiam, Vedant Misra, Evan Morikawa, Alec Radford, Matthew Knight, Miles Brundage, Mira Murati, Katie Mayer, Peter Welinder, Bob McGrew, Dario Amodei, Sam McCandlish, Ilya Sutskever, and Wojciech Zaremba. 2021. Evaluating large language models trained on code.
  • Cobbe et al. (2021) Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano, Christopher Hesse, and John Schulman. 2021. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168.
  • Fu et al. (2022) Yao Fu, Hao Peng, Ashish Sabharwal, Peter Clark, and Tushar Khot. 2022. Complexity-based prompting for multi-step reasoning. arXiv preprint arXiv:2210.00720.
  • Gao et al. (2020) Tianyu Gao, Adam Fisch, and Danqi Chen. 2020. Making pre-trained language models better few-shot learners. CoRR, abs/2012.15723.
  • Hendrycks et al. (2021) Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt. 2021. Measuring massive multitask language understanding.
  • Hochreiter and Schmidhuber (1997) Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long short-term memory. Neural Computation, 9(8):1735–1780.
  • Johnson et al. (2019) Jeff Johnson, Matthijs Douze, and Hervé Jégou. 2019. Billion-scale similarity search with GPUs. IEEE Transactions on Big Data, 7(3):535–547.
  • Kadavath et al. (2022) Saurav Kadavath, Tom Conerly, Amanda Askell, Tom Henighan, Dawn Drain, Ethan Perez, Nicholas Schiefer, Zac Hatfield-Dodds, Nova DasSarma, Eli Tran-Johnson, et al. 2022. Language models (mostly) know what they know. arXiv preprint arXiv:2207.05221.
  • Khot et al. (2019) Tushar Khot, Peter Clark, Michal Guerquin, Peter Alexander Jansen, and Ashish Sabharwal. 2019. QASC: A dataset for question answering via sentence composition. In AAAI Conference on Artificial Intelligence.
  • Lester et al. (2021) Brian Lester, Rami Al-Rfou, and Noah Constant. 2021. The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691.
  • Levy et al. (2023) Itay Levy, Ben Bogin, and Jonathan Berant. 2023. Diverse demonstrations improve in-context compositional generalization. In Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1401–1422, Toronto, Canada. Association for Computational Linguistics.
  • Lewkowycz et al. (2022) Aitor Lewkowycz, Anders Andreassen, David Dohan, Ethan Dyer, Henryk Michalewski, Vinay Ramasesh, Ambrose Slone, Cem Anil, Imanol Schlag, Theo Gutman-Solo, Yuhuai Wu, Behnam Neyshabur, Guy Gur-Ari, and Vedant Misra. 2022. Solving quantitative reasoning problems with language models.
  • Liu et al. (2021a) Jiachang Liu, Dinghan Shen, Yizhe Zhang, Bill Dolan, Lawrence Carin, and Weizhu Chen. 2021a. What makes good in-context examples for GPT-3?
  • Liu et al. (2021b) Pengfei Liu, Weizhe Yuan, Jinlan Fu, Zhengbao Jiang, Hiroaki Hayashi, and Graham Neubig. 2021b. Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing.
  • Lu et al. (2022) Pan Lu, Liang Qiu, Kai-Wei Chang, Ying Nian Wu, Song-Chun Zhu, Tanmay Rajpurohit, Peter Clark, and Ashwin Kalyan. 2022. Dynamic prompt learning via policy gradient for semi-structured mathematical reasoning.
  • Lu et al. (2021) Yao Lu, Max Bartolo, Alastair Moore, Sebastian Riedel, and Pontus Stenetorp. 2021. Fantastically ordered prompts and where to find them: Overcoming few-shot prompt order sensitivity. arXiv preprint arXiv:2104.08786.
  • Pitis et al. (2023) Silviu Pitis, Michael R. Zhang, Andrew Wang, and Jimmy Ba. 2023. Boosted prompt ensembles for large language models.
  • Qin et al. (2023) Chengwei Qin, Aston Zhang, Anirudh Dagar, and Wenming Ye. 2023. In-context learning with iterative demonstration selection. arXiv preprint arXiv:2310.09881.
  • Reimers and Gurevych (2019) Nils Reimers and Iryna Gurevych. 2019. Sentence-BERT: Sentence embeddings using Siamese BERT-networks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing. Association for Computational Linguistics.
  • Rubin et al. (2021) Ohad Rubin, Jonathan Herzig, and Jonathan Berant. 2021. Learning to retrieve prompts for in-context learning. arXiv preprint arXiv:2112.08633.
  • Saxe et al. (2013) Andrew M. Saxe, James L. McClelland, and Surya Ganguli. 2013. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120.
  • Schulman et al. (2015) John Schulman, Philipp Moritz, Sergey Levine, Michael Jordan, and Pieter Abbeel. 2015. High-dimensional continuous control using generalized advantage estimation. arXiv preprint arXiv:1506.02438.
  • Schulman et al. (2017) John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. 2017. Proximal policy optimization algorithms.
  • Su et al. (2022) Hongjin Su, Jungo Kasai, Chen Henry Wu, Weijia Shi, Tianlu Wang, Jiayi Xin, Rui Zhang, Mari Ostendorf, Luke Zettlemoyer, Noah A. Smith, et al. 2022. Selective annotation makes language models better few-shot learners. arXiv preprint arXiv:2209.01975.
  • Tian et al. (2023) Katherine Tian, Eric Mitchell, Allan Zhou, Archit Sharma, Rafael Rafailov, Huaxiu Yao, Chelsea Finn, and Christopher D. Manning. 2023. Just ask for calibration: Strategies for eliciting calibrated confidence scores from language models fine-tuned with human feedback. arXiv preprint arXiv:2305.14975.
  • Van der Maaten and Hinton (2008) Laurens van der Maaten and Geoffrey Hinton. 2008. Visualizing data using t-SNE. Journal of Machine Learning Research, 9(11).
  • Wang et al. (2023) Xuezhi Wang, Jason Wei, Dale Schuurmans, Quoc V. Le, Ed H. Chi, Sharan Narang, Aakanksha Chowdhery, and Denny Zhou. 2023. Self-consistency improves chain of thought reasoning in language models. In The Eleventh International Conference on Learning Representations.
  • Wei et al. (2023) Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Brian Ichter, Fei Xia, Ed Chi, Quoc Le, and Denny Zhou. 2023. Chain-of-thought prompting elicits reasoning in large language models.
  • Wolfram (2023) Stephen Wolfram. 2023. ChatGPT gets its 'Wolfram superpowers'! Stephen Wolfram Writings.
  • Zhang et al. (2022) Yiming Zhang, Shi Feng, and Chenhao Tan. 2022. Active example selection for in-context learning.
  • Zhao et al. (2021) Tony Z. Zhao, Eric Wallace, Shi Feng, Dan Klein, and Sameer Singh. 2021. Calibrate before use: Improving few-shot performance of language models.

Appendix A Hyperparameters and Implementation Details

We implement the retriever model and RL algorithms in PyTorch. Our PPO implementation is slightly simpler than the standard one in that it does not use an inner training loop; instead, it takes a single training step on each batch sampled from the current policy (a sketch is given below). We also note that we used GitHub Copilot to assist with minimal code-writing. We encode problem inputs and examples using the all-distilroberta-v1 pre-trained S-BERT model from the sentence-transformers library Reimers and Gurevych (2019), take the normalized mean-pooled final-layer outputs as the embeddings, and use a soft prompt length of 20. With Codex, we use greedy decoding and set the maximum number of generated tokens to 450/400/150 for TabMWP/GSM8K/QASC, respectively.

We set the LSTM's hidden size to 800, PPO's clipping parameter ε to 0.1, GAE's λ to 0.9, and c_VF to 0.5. We set c_E to 0.05 for TabMWP and 0.1 for GSM8K; different values are necessary because c_E has a large impact on training stability and the best-performing value differs across datasets. We use orthogonal initialization Saxe et al. (2013) for all weight parameters, initialize all bias parameters to 0, and initialize soft prompts from a standard normal distribution. We train using the AdamW optimizer for 50 epochs with a learning rate of 0.001, a weight decay of 0.01, and a batch size of 20. We additionally apply gradient norm clipping on all parameters with a threshold of 2, which we find is critical for avoiding spikes in the training losses. We list the hyperparameter values we experimented with in Table 3; final values were selected based on preliminary experiments aiming to optimize both accuracy and wall-clock time.

All final experiments were run on a Lambda Vector workstation with NVIDIA RTX A6000 GPUs. With a single thread, training a RetICL model for TabMWP with 5,000 training samples and 500 validation samples for 50 epochs takes approximately 44 hours, and inference on 1,000 TabMWP samples takes approximately 10 minutes. We reduce runtime by caching LLM outputs to avoid repeated inferences at both train and test time, batching requests to OpenAI, and using an adaptive backoff between OpenAI API calls to maximize throughput under the API rate limit. Finally, a RetICL model with the above hyperparameters has approximately 5.5 million parameters.
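
Below is a minimal sketch of this single-step PPO update, with the hyperparameter values above as defaults. It assumes GAE advantages, returns, and the sampling policy's log-probabilities have already been computed for a batch of sampled episodes; the `policy.evaluate` interface and batch fields are illustrative placeholders rather than our actual implementation.

```python
# Minimal sketch of a single-step PPO update: one gradient step per batch sampled
# from the current policy, with clipping, a value-function term, an entropy bonus,
# and gradient norm clipping. The `policy` module and `batch` fields are
# illustrative placeholders.
import torch


def ppo_step(policy, optimizer, batch, eps=0.1, c_vf=0.5, c_e=0.05, max_grad_norm=2.0):
    # batch: dict of tensors with the visited states, chosen example indices,
    # log-probs under the sampling policy, GAE advantages, and empirical returns.
    new_log_probs, values, entropy = policy.evaluate(batch["states"], batch["actions"])
    ratio = torch.exp(new_log_probs - batch["old_log_probs"])

    # Clipped surrogate objective.
    surr1 = ratio * batch["advantages"]
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * batch["advantages"]
    policy_loss = -torch.min(surr1, surr2).mean()

    value_loss = (values - batch["returns"]).pow(2).mean()
    loss = policy_loss + c_vf * value_loss - c_e * entropy.mean()

    optimizer.zero_grad()
    loss.backward()
    # Gradient norm clipping, which we find critical for training stability.
    torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()
```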

We use scikit-learn for t-SNE and matplotlib for visualizing data. All software used in this work is either open source or does not specify a license. To the best of our knowledge, we are consistent with the terms and intended use of all software and services, particularly OpenAI.

Table 3: Hyperparameter values tried in preliminary experiments.

Hyperparameter        Values tried
Training size         1,000, 5,000
Corpus size           20, 200, 500, 1,000, all
Soft prompt length    5, 10, 20, 40
LSTM hidden size      100, 200, 400, 800, 1,600
ε (PPO clipping)      0.1, 0.2
λ (GAE)               0.9
c_VF                  0.5, 1.0
c_E                   0.01, 0.05, 0.1, 0.2, 0.5
Epochs                50, 100, 250
Learning rate         0.0001, 0.0005, 0.001
Weight decay          0.001, 0.01
Batch size            5, 10, 20, 40, 64
Gradient clipping     0.5, 1.0, 2.0

Appendix B Dataset Details

TabMWP has a pre-defined train/validation/test split of 23,059/7,686/7,686, GSM8K has a pre-defined train/test split of 7,473/1,319, and QASC has a pre-defined train/validation/test split of 8,134/926/920. We reserve 1,000 random problems from GSM8K’s train set for validation. Because QASC’s test set does not have labels, we instead use the validation set for testing and reserve 1,000 random problems from the train set for validation. All datasets are exclusively in English. TabMWP uses the CC BY-NC-SA license, GSM8K uses the MIT license, and QASC uses the Apache 2.0 license. We show our prompt templates, which we use for both S-BERT and the LLM, in Table 4.

Table 4: Prompt templates and example prompts for each dataset.

TabMWP

Template:
Table: [table]
Problem: [problem statement]
Solution: [solution steps]
Final Answer: [final answer]

Example:

Table: [TITLE]: Siblings
Number of siblings | Frequency
0 | 14
1 | 3
2 | 6
3 | 8
4 | 0
Problem: The students in Mr. Boyer’s class recorded the number of siblings that each has. How many students have at least 2 siblings?
Solution: Find the rows for 2, 3, and 4 siblings. Add the frequencies for these rows.\n\nAdd:\n\n6 + 8 + 0 = 14\n\n14 students have at least 2 siblings.
Final Answer: 14

GSM8K

Template:
Problem: [problem statement]
Solution: [solution steps]
Final Answer: [final answer]

Example:

Problem: Marcos has to get across a 5 mile lake in his speedboat in 10 minutes so he can make it to work on time. How fast does he need to go in miles per hour to make it?
Solution: 10 minutes for 5 miles means 10 minutes / 5 miles = 2 minutes/mile
1 hour is 60 minutes so 60 minutes/hour / 2 minutes/mile = 30 miles/hour
Final Answer: 30

QASC

Template:
Question: [question] [options]
Solution: [fact 1] [fact 2] [combined fact]
Final Answer: [final answer]

Example:

Question: Mussels have what? (A) seaweed (B) arms (C) Energy (D) a shell (E) warmth (F) bacteria (G) Length (H) legs
Solution: Most mollusks have shells. Mussels are bivalve mollusks. Mussels have shells.
Final Answer: a shell
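
To make the templates above concrete, the following is a minimal sketch of how a prompt could be assembled from retrieved examples under the GSM8K template in Table 4; the dictionary fields and helper functions are illustrative assumptions rather than our actual code.

```python
# Minimal sketch of assembling a GSM8K-style prompt from the template in Table 4.
# The example dictionaries and helper names are illustrative assumptions.
def format_example(example: dict) -> str:
    return (f"Problem: {example['problem']}\n"
            f"Solution: {example['solution']}\n"
            f"Final Answer: {example['answer']}")


def build_prompt(examples: list, target_problem: str) -> str:
    # In-context examples come first, in the order they were retrieved;
    # the LLM then completes the solution for the target problem.
    shots = "\n\n".join(format_example(ex) for ex in examples)
    return f"{shots}\n\nProblem: {target_problem}\nSolution:"


example = {
    "problem": "Marcos has to get across a 5 mile lake in his speedboat in 10 minutes "
               "so he can make it to work on time. How fast does he need to go in miles "
               "per hour to make it?",
    "solution": "10 minutes for 5 miles means 10 minutes / 5 miles = 2 minutes/mile\n"
                "1 hour is 60 minutes so 60 minutes/hour / 2 minutes/mile = 30 miles/hour",
    "answer": "30",
}
print(build_prompt([example], "Marcus is half of Leo's age and five years younger than Deanna. "
                              "Deanna is 26. How old is Leo?"))
```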

Appendix C Inference-Time Optimization

We note that our formulation for ϕ enables efficient retrieval of the top-ranking example at each time step via maximum inner-product search (MIPS). Specifically, ϕ(S_t, e) is maximized by finding the example e that maximizes the inner product ⟨h_t, W_a e⟩. We can exploit this by pre-computing W_a e for each example in the corpus and constructing a MIPS index over these vectors using a library such as faiss Johnson et al. (2019). At inference time, we can then use algorithms that perform approximate MIPS in sublinear time, i.e., maximize ϕ(S_t, e) without evaluating h_t^T W_a e for every example in the corpus. We do not use MIPS in this work since the corpora we experiment on are small enough that evaluating h_t^T W_a e for each example is relatively inexpensive; however, we expect MIPS to save significant computation when evaluating on corpora at much larger scales.
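
As an illustration, here is a minimal sketch of this pre-computation and retrieval with faiss. It uses an exact inner-product index (an approximate index, e.g., an IVF or HNSW variant, would give the sublinear behavior described above); the array shapes and random placeholder data are assumptions, not our actual pipeline.

```python
# Minimal sketch of MIPS-based retrieval over pre-projected example embeddings.
# W_a (the learned bilinear projection) and the S-BERT example embeddings are
# random placeholders here.
import numpy as np
import faiss  # pip install faiss-cpu

hidden_size, embed_dim, corpus_size = 800, 768, 20000

rng = np.random.default_rng(0)
W_a = rng.standard_normal((hidden_size, embed_dim), dtype=np.float32)
example_embeddings = rng.standard_normal((corpus_size, embed_dim), dtype=np.float32)

# Pre-compute W_a e for every corpus example and build an inner-product index.
projected = example_embeddings @ W_a.T   # (corpus_size, hidden_size)
index = faiss.IndexFlatIP(hidden_size)   # exact MIPS; swap in an ANN index for sublinear search
index.add(projected)

# At inference time, score examples against the current LSTM hidden state h_t.
h_t = rng.standard_normal((1, hidden_size), dtype=np.float32)
scores, example_ids = index.search(h_t, 4)   # top-4 candidate examples
print(example_ids[0])
```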

Appendix D LSTM Classifier Details

For our LSTM Classifier baseline, we use a similar architecture to RetICL but train with a supervised objective rather than RL. Specifically, we collect a dataset 𝒟 of 20 prompts and the resulting LLM outputs for each sample in the training and validation sets, where the examples in the prompts are randomly selected from the corpus. We feed samples (x, e_1, …, e_T) ∈ 𝒟 to an LSTM in the same way they are fed to RetICL. However, the LSTM classifier does not use a bilinear projection to define the policy π(S_t, e); instead, it uses a linear projection of the LSTM hidden states to define an approximation of the Q-function, q̂(S_t, e). We formalize this objective as

    q̂(S_{t-1}, e_t) = h_t^T w_q + b_q,
    ℓ = Σ_{t=1}^{T} ( q̂(S_{t-1}, e_t) − R^G )²,

where w_q ∈ ℝ^{d_h} and b_q ∈ ℝ are learnable parameters, and ℓ is the loss for a single sample in 𝒟. At inference time, we use the approximated Q-function to guide an example selection policy via greedy decoding, i.e., e_{t+1} = argmax_{e ∈ 𝒞} q̂(S_t, e).
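
The following is a minimal sketch of the classifier's value head and loss under this objective. It assumes the S-BERT embeddings of the input and examples are already stacked into a sequence and, for illustration, that R^G is a ±1 correctness reward; the module, dimensions, and reward values are illustrative assumptions.

```python
# Minimal sketch of the LSTM classifier baseline's Q-value head and squared-error loss.
# Dimensions and the ±1 reward are illustrative assumptions.
import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    def __init__(self, embed_dim: int = 768, hidden_size: int = 800):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        # Linear projection of hidden states to a scalar Q-value estimate q_hat.
        self.q_head = nn.Linear(hidden_size, 1)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        # seq: (batch, T + 1, embed_dim) -- input embedding followed by example embeddings.
        h, _ = self.lstm(seq)                       # (batch, T + 1, hidden_size)
        return self.q_head(h[:, 1:]).squeeze(-1)    # q_hat(S_{t-1}, e_t) for t = 1..T


model = LSTMClassifier()
seq = torch.randn(4, 5, 768)                # batch of 4 prompts, each with T = 4 examples
reward = torch.tensor([1., -1., 1., -1.])   # assumed R^G per sample (correct / incorrect)
q_hat = model(seq)                          # (4, 4)
# Per-sample squared error summed over time steps, averaged over the batch.
loss = ((q_hat - reward.unsqueeze(1)) ** 2).sum(dim=1).mean()
loss.backward()
```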

We perform batched training on 𝒟 and evaluate classification accuracy on the validation set after each epoch. We define this accuracy as the percentage of samples for which the estimated value of the final example in the prompt correctly predicts whether ŷ will be correct, i.e., sign(q̂(S_{T-1}, e_T)) = sign(R^G). We train using the AdamW optimizer for 20 epochs with a learning rate of 1e-4, a weight decay of 0.01, and a batch size of 256, and perform early stopping based on validation classification accuracy. We use the same training set, validation set, and training corpus as RetICL. We do not use soft prompts for S-BERT in this setting since we found that they hurt performance. We achieve classification accuracies of 85.54, 61.38, and 61.97 on TabMWP, GSM8K, and QASC, respectively. We note that these accuracies do not correlate with downstream performance when the classifiers guide example selection at inference time. This is likely because q̂ is trained on a random policy but used to guide a greedy policy at inference time, so there is an inherent mismatch between the two settings. Given these results, one could imagine a different inference-time setting where the classifier ranks a list of randomly constructed prompts rather than guiding a policy, but we leave this investigation for future work. Regardless, we believe these results further highlight the benefits of using RL algorithms for the example selection task.

Appendix E Scaling Number of In-Context Examples

We perform preliminary experiments on scaling T, the number of in-context examples used in each prompt. Specifically, for T ∈ {2, 3, 4, 5}, we compute test set accuracy on GSM8K using Random, kNN, and RetICL (trained on 1,000 samples), and show the results in Figure 3. We observe that while RetICL still outperforms the baselines as T increases, the gap between the methods narrows, largely because RetICL performs roughly the same at T = 2 as it does at higher values. While this result is unexpected, we note that finding an effective policy at higher T is a more complex task due to the increased degrees of freedom. We therefore hypothesize that RetICL's inability to capitalize on more examples is due to training difficulties, possibly stemming from our simpler PPO implementation. We leave further investigation of this phenomenon and potential training improvements for future work.

[Figure 3: GSM8K test set accuracy for Random, kNN, and RetICL as the number of in-context examples T increases from 2 to 5.]

Appendix F Latent Space Analysis Details

[Figure 4: t-SNE visualizations of example embeddings in the pre-trained S-BERT and RetICL latent spaces for GSM8K and TabMWP.]

We show the visualized example embeddings in latent space in Figure 4. For GSM8K, we see that RetICL groups examples based on the number of solution steps, whereas the pre-trained S-BERT embeddings do not. We also see that clusters in the RetICL embeddings have been somewhat merged together relative to the pre-trained embeddings. One interpretation is that the pre-trained embeddings are primarily clustered by topic, e.g., problems about money and problems about time form separate clusters, since S-BERT embeddings reflect the semantic content of the examples. While local neighbors in the RetICL embedding space also tend to share topics, the clusters are less well separated, which suggests that RetICL selects examples based on both topic and solution strategy, the latter being partly reflected in the number of solution steps.

For TabMWP, we see that the latent space looks very different from GSM8K's, with many separate clusters present in both the RetICL and pre-trained spaces, primarily based on the problem's template. For example, there is one cluster for yes/no questions about schedules, one for asking whether someone has enough money to buy something, and one for computing the mean of a set of numbers. Since RetICL retains these clusters, we can infer that an example's template is key to example selection. This observation is also supported by kNN's high performance on this dataset, since the most semantically similar problems are always from the same cluster. While there are not many differences between the RetICL and pre-trained S-BERT embedding spaces, we observe that RetICL has pulled several clusters closer together. For example, it partially merges problems that require finding the largest value in a set with problems that require finding the smallest value. This suggests that problems across the merged template clusters can be used interchangeably as examples, since they tend to have similar reasoning strategies.
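
For reference, here is a minimal sketch of how such a visualization can be produced with scikit-learn's t-SNE and matplotlib, assuming the pre-trained and RetICL example embeddings and a per-example solution-step count are available as arrays; the synthetic placeholder data and variable names are assumptions.

```python
# Minimal sketch of the latent-space visualization: t-SNE projections of example
# embeddings, colored by the number of solution steps. The random arrays below
# stand in for the real pre-trained S-BERT and RetICL embeddings.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
pretrained_emb = rng.normal(size=(500, 768))   # placeholder S-BERT embeddings
reticl_emb = rng.normal(size=(500, 800))       # placeholder RetICL embeddings
num_steps = rng.integers(1, 8, size=500)       # placeholder solution-step counts

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, emb, title in [(axes[0], pretrained_emb, "Pre-trained S-BERT"),
                       (axes[1], reticl_emb, "RetICL")]:
    points = TSNE(n_components=2, random_state=221).fit_transform(emb)
    scatter = ax.scatter(points[:, 0], points[:, 1], c=num_steps, s=5, cmap="viridis")
    ax.set_title(title)
fig.colorbar(scatter, ax=axes, label="Number of solution steps")
plt.savefig("latent_space.png", dpi=200)
```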

Appendix G Representative Example Selections

Problem

Marcus is half of Leo’s age and five years younger than Deanna. Deanna is 26. How old is Leo?

Gold Solution

Marcus is 26 - 5 = 21 years old. Thus, Leo is 21 * 2 = 42 years old. Final Answer: 42

RetICL vs. kNN
Selected Examples

Problem: Katy, Wendi, and Carrie went to a bread-making party. Katy brought three 5-pound bags of flour. Wendi brought twice as much flour as Katy, but Carrie brought 5 pounds less than the amount of flour Wendi brought. How much more flour, in ounces, did Carrie bring than Katy? Solution: Katy’s three 5-pound bags of flour weigh 3*5=15 pounds. Wendi brought twice as much flour as Katy, or 15*2=30 pounds of flour. Carrie brought 5 pounds less flour than Wendi, or 30-5=25 pounds of flour. Thus, Carrie brought 25-15=10 pounds of flour more than Katy. In ounces, the 10 pound difference is 10*16=160 ounces. Final Answer: 160

Problem: Vivian plays 10 Spotify songs every day. Her best friend Clara plays 2 fewer songs each day. If in June they didn’t play any song during the weekends only, and there were 8 weekend days in June, what’s the total number of songs they both listened to in that month? Solution: June has 30 days, so if they didn’t play any songs on weekends, then they played 30 days - 8 days = 22 days. If Vivian played 10 songs each day, the total number of songs she played for the month is 10 songs/day * 22 days = 220 songs. Clara played 2 fewer songs, which is 10 songs/day - 2 songs/day = 8 songs/day. The total number of songs played by Clara for the whole month is 8 songs/day * 22 days = 176 songs. The total played by both is 220 songs + 176 songs = 396 songs Final Answer: 396

Problem: Leonard is 4 years younger than Nina who is half as old as Jerome. If the sum of their ages is 36, what is Leonard’s age? Solution: Let x be the age of Nina. So, Leonard’s age is x - 4. And Jerome’s age is 2x. Since the sum of their ages is 36, then the equation is x + x - 4 + 2x = 36. By combining like terms, the equation becomes 4x = 40. So the value of x which represents the age of Nina is x = 40/4 = 10. Therefore, Leonard is 10 - 4 = 6 years old. Final Answer: 6

Problem: In 6 years, Carla will be 30 years old. The sum of the current ages of Carla and Louis is 55. How old is Louis now? Solution: Carla is currently 30-6=24 years old. Louis is 55-24=31 years old at present. Final Answer: 31

Generated Solutions

Leo is twice Marcus’ age, so Leo is 2*Marcus. Deanna is 5 years younger than Marcus, so Marcus is Deanna - 5. Deanna is 26, so Marcus is 26 - 5 = 21. Leo is 2*Marcus, so Leo is 2*21 = 42. Final Answer: 42

Leo is 26+5=31 years old. Marcus is 31/2=15.5 years old. Final Answer: 31

Problem

Table: [TITLE]: Pairs of shoes per store
Stem | Leaf
1 | 9
2 | 3, 9
3 | 2, 8
4 | 5
5 | 2
6 | 2, 3
Problem: Kristen counted the number of pairs of shoes for sale at each of the shoe stores in the mall. How many stores have at least 30 pairs of shoes but fewer than 40 pairs of shoes? (Unit: stores)

Gold Solution

Count all the leaves in the row with stem 3. You counted 2 leaves, which are blue in the stem-and-leaf plot above. 2 stores have at least 30 pairs of shoes but fewer than 40 pairs of shoes. Final Answer: 2

RetICL vs. kNN
Selected Examples

Table: barrette | $0.88
bottle of hand lotion | $0.96
sewing kit | $0.94
box of bandages | $0.94
box of breath mints | $0.80
Problem: How much money does Eve need to buy 6 bottles of hand lotion and a barrette? (Unit: $) Solution: Find the cost of 6 bottles of hand lotion. $0.96 × 6 = $5.76 Now find the total cost. $5.76 + $0.88 = $6.64 Eve needs $6.64. Final Answer: 6.64

Table: [TITLE]: Rotten tomatoes per barrel
Stem | Leaf
2 | 0, 2, 6, 7
3 | 5, 6, 9
4 | 1, 5
5 | 4
6 | 1, 2, 9
7 | 2, 3, 5
Problem: The Clarksville Soup Company recorded the number of rotten tomatoes in each barrel it received. How many barrels had at least 70 rotten tomatoes but less than 80 rotten tomatoes? (Unit: barrels) Solution: Count all the leaves in the row with stem 7. You counted 3 leaves, which are blue in the stem-and-leaf plot above. 3 barrels had at least 70 rotten tomatoes but less than 80 rotten tomatoes. Final Answer: 3

Table: [TITLE]: Pairs of shoes per store
Stem | Leaf
3 | 5
4 | 2, 6
5 | 7
6 | 0, 4
7 |
8 | 4
9 | 0
Problem: Ron counted the number of pairs of shoes for sale at each of the shoe stores in the mall. How many stores have at least 50 pairs of shoes but fewer than 100 pairs of shoes? (Unit: stores) Solution: Count all the leaves in the rows with stems 5, 6, 7, 8, and 9. You counted 5 leaves, which are blue in the stem-and-leaf plot above. 5 stores have at least 50 pairs of shoes but fewer than 100 pairs of shoes. Final Answer: 5

Table: [TITLE]: Pairs of shoes per store
Stem | Leaf
1 | 9
2 | 4
3 | 0, 0, 1
4 | 3
5 | 9
Problem: Emma counted the number of pairs of shoes for sale at each of the shoe stores in the mall. How many stores have at least 24 pairs of shoes but fewer than 46 pairs of shoes? (Unit: stores) Solution: Find the row with stem 2. Count all the leaves greater than or equal to 4. Count all the leaves in the row with stem 3. In the row with stem 4, count all the leaves less than 6. You counted 5 leaves, which are blue in the stem-and-leaf plots above. 5 stores have at least 24 pairs of shoes but fewer than 46 pairs of shoes. Final Answer: 5

Generated Solutions

Count all the leaves in the row with stem 3. You counted 2 leaves, which are blue in the stem-and-leaf plot above. 2 stores have at least 30 pairs of shoes but fewer than 40 pairs of shoes. Final Answer: 2

Count all the leaves in the rows with stems 3 and 4. You counted 5 leaves, which are blue in the stem-and-leaf plots above. 5 stores have at least 30 pairs of shoes but fewer than 40 pairs of shoes. Final Answer: 5
