Why do Nearest Neighbor Language Models Work?

Frank F. Xu, Uri Alon, Graham Neubig

Research output: Contribution to journal › Conference article › peer-review

Abstract

Language models (LMs) compute the probability of a text by sequentially computing a representation of the previously seen context and using this representation to predict the next word. Currently, most LMs calculate these representations through a neural network that consumes the immediately preceding context. Recently, however, retrieval-augmented LMs have been shown to improve over standard neural LMs by accessing information retrieved from a large datastore, in addition to their standard, parametric, next-word prediction. In this paper, we set out to understand why retrieval-augmented language models, and specifically k-nearest neighbor language models (kNN-LMs), perform better than standard parametric LMs, even when the k-nearest neighbor component retrieves examples from the same training set that the LM was originally trained on. To this end, we analyze the various dimensions along which kNN-LM diverges from standard LMs, and investigate these dimensions one by one. Empirically, we identify three main reasons why kNN-LM performs better than standard LMs: using a different input representation for predicting the next tokens, approximate kNN search, and the importance of softmax temperature for the kNN distribution. Further, we incorporate some of these insights into the standard parametric LM, improving its performance without the need for an explicit retrieval component. The code is available at https://github.com/frankxu2004/knnlm-why.
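For readers unfamiliar with the mechanism the abstract refers to: a kNN-LM interpolates the parametric LM's next-token distribution with a distribution built from retrieved neighbors, where a softmax temperature over neighbor distances controls how sharply the kNN distribution concentrates. The sketch below assumes the standard kNN-LM formulation of Khandelwal et al. (2020); it is an illustrative reconstruction, not the authors' code, and parameter names such as `temperature` and `lmbda` are assumptions for clarity.

```python
# Minimal sketch of kNN-LM interpolation (assumed standard formulation,
# not the authors' implementation from the linked repository).
import numpy as np

def knn_lm_probs(p_lm, neighbor_dists, neighbor_tokens, vocab_size,
                 temperature=1.0, lmbda=0.25):
    """Interpolate a parametric LM distribution with a kNN distribution.

    p_lm:            (vocab_size,) next-token probabilities from the LM.
    neighbor_dists:  (k,) distances from the current context representation
                     to the retrieved datastore entries.
    neighbor_tokens: (k,) the next token stored with each retrieved entry.
    temperature:     softmax temperature over negative distances; the paper
                     highlights this knob as important for the kNN component.
    lmbda:           interpolation weight on the kNN distribution.
    """
    # Softmax over negative distances: closer neighbors receive more mass.
    logits = -np.asarray(neighbor_dists, dtype=float) / temperature
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()

    # Aggregate neighbor weights by the next token each neighbor stores.
    p_knn = np.zeros(vocab_size)
    np.add.at(p_knn, np.asarray(neighbor_tokens), weights)

    # Final distribution: linear interpolation of the two components.
    return lmbda * p_knn + (1.0 - lmbda) * p_lm
```

In practice the neighbor search is approximate (e.g., via a quantized index), which the paper identifies as one of the sources of kNN-LM's gains rather than a mere efficiency compromise.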

Original language: English
Pages (from-to): 38325-38341
Number of pages: 17
Journal: Proceedings of Machine Learning Research
Volume: 202
State: Published - 2023
Externally published: Yes
Event: 40th International Conference on Machine Learning, ICML 2023 - Honolulu, United States
Duration: 23 Jul 2023 – 29 Jul 2023

Bibliographical note

Publisher Copyright:
© 2023 Proceedings of Machine Learning Research. All rights reserved.

ASJC Scopus subject areas

  • Artificial Intelligence
  • Software
  • Control and Systems Engineering
  • Statistics and Probability
