The subset parity learning problem: much more than you wanted to know
Imagine that you’re looking for buried treasure worth a billion dollars on a large desert island. You don’t have a map, but a mysterious hermit offers you a box with a button to help find the treasure. Each time you press the button, it will tell you either “warmer” or “colder”. But there’s a catch. With probability $2^{-100}$ the box will tell you the truth about whether you’re closer than you were last time you pressed. But with the remaining probability of .9999999999999999999999999999992, the box will make a random guess between “warmer” and “colder”. Should you pay $1 for this box?
Keep this in mind as we discuss the closely related problem of parity learning.
In my experience of interacting with the ML and interpretability communities, the majority of people don’t know about the impossibility result for the parity learning problem. The ones who do will often assume that this is a baroque, complicated result that surely doesn’t have a simple proof (another surprising opinion I’ve heard comes from people who know about the result but say that “there’s some new architecture that seems to solve it, actually”, which is somewhat indicative of people’s trust in the concept of “proof” in the ML community).
Recently I was pleasantly surprised to realize that the impossibility of solving this problem (via a gradient-based learning algorithm in polynomial time) actually admits a pretty nice and understandable proof. The gist of it boils down to the silly “pirate treasure” story above: the answer, of course, is that you shouldn’t buy the box (at least if you’re trying to maximize your expected income), and for the same reason you can’t build a cool new architecture that solves the parity learning problem.
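To make the box’s uselessness quantitative, here is a back-of-the-envelope calculation. It treats each press as an independent trial and writes $p = 2^{-100}$ for the probability of a truthful answer; both are simplifying assumptions about the story rather than part of it.

$$\Pr[\text{“warmer”} \mid \text{you did get warmer}] = p + \frac{1-p}{2} = \frac{1}{2} + \frac{p}{2}.$$

After $m$ presses, the observed fraction of “warmer” answers fluctuates by about $\frac{1}{2\sqrt{m}}$ around its mean, so noticing the bias requires roughly

$$\frac{1}{2\sqrt{m}} \lesssim \frac{p}{2} \quad\Longleftrightarrow\quad m \gtrsim p^{-2} = 2^{200} \approx 1.6 \times 10^{60}$$

presses for a single yes/no decision about one step, far more than anyone could ever make.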
In this post I’ll briefly explain the problem, how it’s different from some other impossibility results, and why I think it’s important. This will later tie in to a series of posts about the insufficiency of Bayesian methods in (realistic) ML contexts.
What is the parity learning problem?
This is a post where a bit of math can go a long way, but I’ll try to make it approachable to anyone with either a bit of math or a bit of ML background. First, the XOR parity function is a function $f_S: \{0,1\}^n \to \{0,1\}$ from a length-n boolean input to a single boolean output. This function depends on a “secret” variable S, which is a subset $S \subseteq [n]$ (here $[n] = \{1, \dots, n\}$ is the standard set on n elements). On an input vector $v = (v_1, \dots, v_n) \in \{0,1\}^n$ the function outputs $$f_S(v) = \sum_{i \in S} v_i \bmod 2.$$ In other words, you look at the bits of $v$ at the indices in $S$, take their sum, and take its parity. Note (for people with some math background) that the value $f_S(v)$ can be written more nicely as $f_S(v) = \langle s, v \rangle$, where we replace the set $S$ by the vector $s \in \{0,1\}^n$ with coefficient 1 in position $i$ for indices $i \in S$ and 0 otherwise. The vectors $s, v$ can then be interpreted as having coefficients in the field $\mathbb{F}_2$ of two elements, and $\langle s, v \rangle$ denotes their dot product in this field (i.e., the sum of products of coordinates, mod 2).
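A minimal Python sketch of the definition, checking that the “sum of selected bits mod 2” description and the $\mathbb{F}_2$ dot-product description agree. The function names are hypothetical, and indices are 0-based (so $[n]$ becomes `range(n)`), unlike the 1-based convention in the text.

```python
import itertools

def parity_subset(S, v):
    """f_S(v): add up the bits of v at the indices in S and take the result mod 2."""
    return sum(v[i] for i in S) % 2

def parity_dot(s, v):
    """The same function written as the F_2 dot product <s, v> (s = indicator vector of S)."""
    return sum(si * vi for si, vi in zip(s, v)) % 2

n = 4
S = {0, 2}                                  # hypothetical secret subset (0-based indices)
s = [1 if i in S else 0 for i in range(n)]  # its indicator vector: [1, 0, 1, 0]

# The two descriptions agree on all 2^n inputs.
for v in itertools.product([0, 1], repeat=n):
    assert parity_subset(S, v) == parity_dot(s, v)

print(parity_subset(S, (1, 1, 1, 0)))  # bits 0 and 2 are both 1, so the parity is 0
```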
The impossibility result says that the parity problem cannot be solved as a learning problem in polynomial time. This statement should be interpreted carefully. First, note that we haven’t defined what a learning algorithm is. A special case of a learning algorithm is any (polynomial-complexity) weight-based ML architecture that learns via SGD on cross-entropy loss (together with any choice of batch size, update step, initialization protocol, etc.). We will take this as our definition of “learnability” for the sake of this post, though later I’ll point out that our proof also shows that a much larger class of methods is incapable of solving parity. (On the other hand, as we’ll see, an undergraduate with a few hours to spare can solve the parity problem in polynomial time, and with a bit more time can even hand-select weights in an ML architecture to execute their solution.) The second thing to be careful of is that, for any fixed choice of “hidden” subset S, it is possible to design an algorithm that learns the parity problem. Indeed, you can simply initialize the architecture to the “right solution”. So it’s important to conceptualize S here as secret or random.
More concretely, the problem can be conceptualized as a game between two players A and B. Player A randomly chooses a secret subset $S \subseteq [n]$, where $[n] = \{1, \dots, n\}$ is the standard set on n elements (there are $2^n$ subsets, including all of $[n]$ and the empty set, so each is chosen with probability $2^{-n}$). Player B commits to a (polynomial-sized) learning algorithm M, which for us means an architecture, initialization scheme, and class of hyperparameters for a gradient-based learning scheme like SGD. Player A then randomly generates a number of sample boolean vectors $v_1, \dots, v_N \in \{0,1\}^n$ and hands B the labeled pairs $(v_i, f_S(v_i))$, with N some agreed-upon number that depends at most polynomially on n, and player B trains the learning algorithm for N’ steps, where N’ is another (large) number that is nevertheless polynomially bounded in n.
The theorem then says that no matter which learning algorithm player B chooses, the probability that the setup will learn an algorithm with >51% accuracy is effectively zero (i.e., it’s exponentially small in n).
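A minimal simulation of player A’s side of the game, plus the accuracy measure in the theorem, might look like the following. The names `player_a` and `accuracy`, the uniform input distribution, and the 0-based indexing are illustrative assumptions; player B’s learning algorithm is deliberately left out, since that is the part the theorem quantifies over.

```python
import random

def player_a(n, N, rng):
    """Player A: draw a uniformly random secret S ⊆ {0, ..., n-1} and N labeled samples (v, f_S(v))."""
    # Including each index independently with probability 1/2 makes all 2^n subsets equally likely.
    S = {i for i in range(n) if rng.randrange(2) == 1}
    samples = []
    for _ in range(N):
        v = [rng.randrange(2) for _ in range(n)]        # uniform random boolean vector
        samples.append((v, sum(v[i] for i in S) % 2))  # label f_S(v)
    return S, samples

def accuracy(hypothesis, S, n, trials, rng):
    """Fraction of fresh uniform inputs on which a hypothesis v -> {0, 1} agrees with f_S."""
    hits = 0
    for _ in range(trials):
        v = [rng.randrange(2) for _ in range(n)]
        hits += hypothesis(v) == sum(v[i] for i in S) % 2
    return hits / trials

rng = random.Random(0)
S, samples = player_a(n=16, N=32, rng=rng)
# A data-independent guess such as "always 0" is right about half the time
# (in expectation exactly 1/2, unless S happens to be empty).
print(accuracy(lambda v: 0, S, 16, 10_000, rng))
```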
This “hidden guessing” game is about P vs. NP, isn’t it?
This is another common confusion, and the answer is no. Most theoretical computer scientists believe that P ≠ NP. And assuming this[1] gives another source of unsolvable learning problems. Indeed, if you were to give player A in the above game more freedom, and allowed them to write down any (suitably) randomly chosen circuit (or suitable random ML algorithm) for their “target” function, you would get another impossibility result. Here the assumption NP ⊄ BPP implies that there is no way, in polynomial time, to get reliable information about A’s secret circuit C (beyond some statistical regularities) from looking at polynomially many input-output samples of C. This in particular implies that there is no way to guarantee sufficiently accurate behavior of the result of a learning algorithm, since a learning algorithm is a special case of a (probabilistic) polynomial-time algorithm. But the XOR impossibility is in fact a much more satisfying result. It doesn’t require any assumptions about P vs. NP (it is true mathematically and unconditionally), and, even more nicely, there actually does exist a (probabilistic) polynomial-time algorithm to solve it. In other words, we have the following containments (note that I’m being sloppy about exactly what an “algorithm” is):
$$\text{BPP-invertible} \subset \text{NP-invertible},$$
$$\text{Learning algorithms} \subset \text{BPP-invertible}.$$
And in the case of the “XOR parity” problem, it shows that (without making any assumptions on the first containment, i.e., about P vs. NP), the second containment is proper: i.e., there are polynomially invertible algorithms which are impossible to execute as learning algorithms.
To convince you that this isn’t some deep hidden knowledge, I’ll explain in the following section how an undergrad with a semester of abstract algebra can solve the XOR parity problem. Since knowing the specific solution isn’t critical, I’ll assume a bit of abstract algebra here, and people without the necessary context can safely skip the following section. To be clear about its result, before going on I’ll write the upshot:
Upshot of the following section:
In the setup as above, with player A having a hidden subset $S$ and player B receiving $N$ samples of the input-output behavior of the parity function $f_S(v)$ for boolean inputs $v \in \{0,1\}^n$, it is possible for B to recover the subset $S$ (and thus the function $f_S$) in polynomial time N’, with overwhelming probability.
In fact, it’s sufficient to look at N = 2n samples (and the running time of the solution algorithm is cubic in n, or better).
Polynomial-time non “learning-algorithmic” solution
For this section I’m assuming some linear algebra over finite fields; if this isn’t your jam, skip to the next section.
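The basic idea is to replace the function $f_S(v)$ with the dot product $\langle s, v \rangle$. The random samples can then be understood as an overdetermined system of linear equations over the field $\mathbb{F}_2$ with two elements. Namely, given our random samples $v_1, \dots, v_N$, we write down an $N \times n$ matrix $V$ whose length-$n$ rows are $v_1, \dots, v_N$. Because the samples are chosen randomly, these are random boolean vectors. It is now a standard fact that $N \geq n$ random vectors in $\mathbb{F}_2^n$ fail to span $\mathbb{F}_2^n$ (i.e., $V$ has rank less than $n$) with probability at most $2^{n-N}$: they fail to span exactly when some nonzero $u \in \mathbb{F}_2^n$ is orthogonal to all of them, each fixed nonzero $u$ is orthogonal to all $N$ samples with probability $2^{-N}$, and there are fewer than $2^n$ choices of $u$. As soon as $N = 2n$, the probability that they fail to span is at most $2^{-n}$, i.e., negligible (in fact, in the formal sense of going to zero faster than any inverse polynomial).
Thus we can safely assume that the sample input vectors span $\mathbb{F}_2^n$, i.e., that some $n$ of them are linearly independent. Now in the assumption of the problem, A gave us both the vectors $v_i$ and also the values $f_S(v_i)$. We can convert this to a system of linear equations (over $\mathbb{F}_2$) on the secret vector $s$. Namely, we have
$$V s = p, \qquad p = \big(f_S(v_1), \dots, f_S(v_N)\big)^{\top},$$
where $p$ is the vector of parities. Now the fact that the $v_i$ span $\mathbb{F}_2^n$ implies that $s$ is reconstructible in polynomial time. For example, if $N \geq n$ and already the first $n$ sample inputs are linearly independent, then we can write $s = V_n^{-1} p_n$, where $V_n$ is the $n \times n$ matrix formed by the first $n$ rows of $V$ and $p_n$ consists of the first $n$ entries of $p$, and inverting an $n \times n$ (boolean) matrix is doable in cubic time (e.g., by Gaussian elimination).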
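The reconstruction step can be sketched in Python as Gaussian elimination over $\mathbb{F}_2$, packing each equation into an integer bitmask so that adding two equations is a single XOR. This is one straightforward way to carry out the “invert a boolean matrix” step, not necessarily the author’s; `recover_secret` is a hypothetical name, and indices are 0-based.

```python
import random

def recover_secret(samples, n):
    """Recover s in F_2^n from pairs (v, <s, v> mod 2) by Gauss-Jordan elimination over F_2.

    Each equation is packed into an int: bits 0..n-1 hold v, bit n holds the label.
    Returns the indicator vector of S, or None if the v's do not span F_2^n.
    """
    rows = [sum(b << i for i, b in enumerate(v)) | (y << n) for v, y in samples]
    pivot_rows = []
    for col in range(n):
        # Find a not-yet-used equation with a 1 in this column.
        idx = next((k for k, r in enumerate(rows) if (r >> col) & 1), None)
        if idx is None:
            return None  # rank < n: s is not uniquely determined
        pivot = rows.pop(idx)
        # Clear this column from every other equation (adding equations = XOR in F_2).
        rows = [r ^ pivot if (r >> col) & 1 else r for r in rows]
        pivot_rows = [r ^ pivot if (r >> col) & 1 else r for r in pivot_rows]
        pivot_rows.append(pivot)
    # Now pivot_rows[col] reads "s_col = label", so the label bit is the answer.
    return [(pivot_rows[col] >> n) & 1 for col in range(n)]

rng = random.Random(0)
n = 20
secret = [rng.randrange(2) for _ in range(n)]  # hypothetical secret indicator vector s
samples = []
for _ in range(2 * n):                         # N = 2n samples, as in the text
    v = [rng.randrange(2) for _ in range(n)]
    samples.append((v, sum(a * b for a, b in zip(secret, v)) % 2))
print(recover_secret(samples, n) == secret)  # True unless the samples fail to span F_2^n (prob. <= 2^-n)
```

Each elimination pass touches at most $N$ rows of $n+1$ bits, so the total work is $O(N n^2)$ bit operations, i.e., cubic in $n$ when $N = 2n$.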
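Before going on, note that once we’ve correctly guessed the secret subset $S$ we can write down a straightforward MLP that executes the parity XOR function $f_S$. Indeed, assume $S$ has $k$ elements (with $k \leq n$ since $S$ is a subset of an $n$-element set) and let $i_1 < i_2 < \dots < i_k$ be the elements of $S$, in order. Then we can recursively write $f_S$ as
$$f_{\{i_1, \dots, i_k\}}(v) = \mathrm{XOR}\big(f_{\{i_1, \dots, i_{k-1}\}}(v),\ v_{i_k}\big), \qquad f_{\{i_1\}}(v) = v_{i_1}.$$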
Now the XOR of two boolean values is straightforward to write down as a small MLP with a single hidden layer (using ReLU or most other standard nonlinear activation functions), and appropriately stacking $k - 1$ of them together gives a polynomial-size neural net that executes our hidden function (in fact, utilizing some parallelization allows this to be done in $O(\log k)$ layers).
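One concrete choice of weights for the XOR block, assuming ReLU activations: with inputs $a, b \in \{0,1\}$, the hidden units $h_1 = \mathrm{ReLU}(a+b)$ and $h_2 = \mathrm{ReLU}(a+b-1)$ followed by the output $h_1 - 2h_2$ compute $\mathrm{XOR}(a,b)$. The sketch below (plain Python rather than a real NN framework; the function names are hypothetical) combines these blocks in a balanced tree, which is one way to get the $O(\log k)$-layer parallelization mentioned above.

```python
def relu(x):
    return max(0.0, x)

def xor_mlp(a, b):
    """XOR of two bits via one hidden ReLU layer: out = relu(a + b) - 2 * relu(a + b - 1)."""
    h1 = relu(a + b)
    h2 = relu(a + b - 1)
    return h1 - 2 * h2

def parity_net(S, v):
    """Evaluate f_S(v) with XOR blocks arranged in a balanced binary tree.

    Returns (value, depth), where depth is the number of XOR layers, ceil(log2 k) for |S| = k >= 1.
    """
    layer = [float(v[i]) for i in sorted(S)]
    if not layer:
        return 0.0, 0                      # empty S: f_S is identically 0
    depth = 0
    while len(layer) > 1:
        nxt = [xor_mlp(layer[j], layer[j + 1]) for j in range(0, len(layer) - 1, 2)]
        if len(layer) % 2 == 1:
            # An unpaired value passes through; in a real net relu(x) = x works since x is 0 or 1.
            nxt.append(layer[-1])
        layer = nxt
        depth += 1
    return layer[0], depth

print(parity_net({0, 2, 3, 5}, [1, 0, 1, 1, 0, 1]))  # -> (0.0, 2): four 1-bits, even parity
```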
Handwavy proof of non-learnability
Welcome back to non-math people. This section is also slightly technical and can be skipped by people who don’t care about understanding the proof, but it doesn’t require any abstract algebra background.
At the end of the day, we have a function that we’re claiming can be represented by a (polynomial-sized) neural net, but cannot be learned in polynomial time as such. How can one go about showing this? The important bits of information to collect here are the following:
The functions $f_S$ for different subsets $S$ form a basis[2] of all (real-valued) functions on boolean inputs. (Up to rescaling and subtracting a constant, this is also called the “Fourier basis” of functions on boolean inputs.)
Any two functions $f_S$ and $f_T$ for two different subsets $S \neq T$ are uncorrelated on the set of all inputs $v \in \{0,1\}^n$[3]. This is key: even if $S$ and $T$ differ in only one element $i$, the half of input vectors that have $v_i = 1$ will have different parities on $S$ and $T$.
The randomness in the choice of N samples leads to noise in the updates, on the order of 1/polynomial(N).
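These facts are easy to check numerically for small $n$. The sketch below uses the $\pm 1$-valued characters $\chi_S(v) = (-1)^{f_S(v)}$ from footnote 2 (for which “uncorrelated” simply means the average of $\chi_S \chi_T$ is zero) and also shows the noise term: on $N$ random training inputs, the empirical correlation of two different characters has expected square exactly $1/N$, i.e., typical size $1/\sqrt{N}$. The subsets and sizes are arbitrary illustrative choices.

```python
import itertools
import random

def chi(S, v):
    """±1 Fourier character chi_S(v) = (-1)^{f_S(v)}."""
    return -1 if sum(v[i] for i in S) % 2 else 1

def population_corr(S, T, n):
    """Exact average of chi_S * chi_T over all 2^n inputs."""
    total = sum(chi(S, v) * chi(T, v) for v in itertools.product([0, 1], repeat=n))
    return total / 2 ** n

def sample_corr(S, T, vs):
    """Average of chi_S * chi_T over a finite training set vs."""
    return sum(chi(S, v) * chi(T, v) for v in vs) / len(vs)

n, N = 10, 100
S, T = {0, 3, 7}, {0, 3}           # hypothetical subsets differing in one index
print(population_corr(S, T, n))    # 0.0: uncorrelated over all inputs
rng = random.Random(0)
trials = [sample_corr(S, T, [[rng.randrange(2) for _ in range(n)] for _ in range(N)])
          for _ in range(1000)]
mean_sq = sum(c * c for c in trials) / len(trials)
print(mean_sq, 1 / N)              # squared correlation on N samples averages about 1/N
```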
From these facts we see that:
The gradient update can be decomposed into components associated to each subset S.
The coefficient of the gradient update along each incorrect direction $T$ can be modeled as a random variable, and is comparable in size (up to a polynomial multiple) to the update in the “true” direction $S$.
Of course, a priori the argument above assumes that the space of possible functions $f_w(v)$ associated to possible weight parameter values $w$ coincides with the $2^n$-dimensional space of all possible functions on boolean inputs. Since we’ve assumed that the number of parameters is polynomial in $n$, this isn’t the case: rather, the vector space of possible gradient updates is constrained to lie in some $\mathrm{poly}(n)$-dimensional subspace of the $2^n$-dimensional space of functions on boolean inputs. Equivalently, all the above possible update directions are projected to some low-dimensional subspace (and suitably normalized).
At the end of the day, we can model the gradient update as a noise vector of some fixed size (inverse-polynomial in n, and associated to the randomness of drawing random inputs), plus a projection of a “signal” vector associated to $f_S$ onto some $\mathrm{poly}(n)$-dimensional subspace. Now standard considerations of high-dimensional projections imply that the “signal” vector might have significant size for some small (polynomial in n) number of “special” subsets S, but for the vast majority of choices of S it will be suppressed by a massive factor proportional to the square root of the dimension, roughly $\sqrt{2^n / \mathrm{poly}(n)}$. Such a suppressed signal is completely drowned out by the noise, even after polynomially many update steps. Thus the problem of gradient updating to the correct parity algorithm boils down (more or less) to the problem of the pirate-treasure hunter with the very unreliable box.
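The square-root suppression can be illustrated with a quick Monte Carlo experiment: project a uniformly random unit vector in $\mathbb{R}^D$ onto a fixed $d$-dimensional subspace and measure what is left. Here $D$ stands in for $2^n$ and $d$ for the $\mathrm{poly}(n)$-dimensional space of updates; modeling the signal direction as “random” relative to that subspace is the same heuristic the argument above makes, and the specific numbers are arbitrary.

```python
import math
import random

def projected_norm(D, d, rng):
    """Length of the projection of a uniformly random unit vector in R^D onto a fixed d-dim subspace.

    By rotational symmetry we may take the subspace to be the first d coordinates.
    """
    g = [rng.gauss(0.0, 1.0) for _ in range(D)]  # normalized Gaussian = uniform direction
    total = sum(x * x for x in g)
    return math.sqrt(sum(x * x for x in g[:d]) / total)

rng = random.Random(0)
D, d = 2 ** 12, 64       # stand-ins for 2^n and poly(n), chosen only for illustration
norms = [projected_norm(D, d, rng) for _ in range(200)]
mean_sq = sum(r * r for r in norms) / len(norms)
print(mean_sq, d / D)    # the expected squared length is exactly d / D
print(math.sqrt(D / d))  # suppression factor sqrt(D / d), which is 8.0 here
```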
The discerning reader will see that I swept a significant chunk not only of the proof, but even of the logical flow of the argument under the rug: this is perhaps better described as the “intuition” behind the proof than as a sketch. However, importantly, this “intuition” applies to absolutely any (polynomial-sized) architecture, and to a much more general context than SGD: any learning algorithm of this kind, including SGD, Adam, and even more sophisticated local Bayesian learning setups, will fail for the same reasons.
What we really used about SGD-style “learning algorithms” was that they have some noise and that their updating process only uses information averaged over input samples. There is a general result that any learning algorithm that only uses this kind of information (such algorithms are known as statistical query, or SQ, algorithms) cannot learn parity in polynomial time. The definition of this class of algorithms and its relationship with various learnability and complexity results is a classical topic in computational learning theory. For a nice compressed introduction which in particular formalizes the proof discussed here, see this paper.
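The counting step behind the SQ lower bound can be sketched as follows (a standard textbook argument, not the full proof). With labels $y = \chi_S(v) \in \{\pm 1\}$, any query $\psi(v, y)$ with values in $[-1, 1]$ splits as $\psi_0(v) + y\,\psi_1(v)$ with $|\psi_1| \le 1$, so its true answer is $\mathbb{E}[\psi_0] + \mathbb{E}[\psi_1 \chi_S]$. The first term does not depend on $S$, and by Parseval’s identity $\sum_S \mathbb{E}[\psi_1 \chi_S]^2 = \mathbb{E}[\psi_1^2] \le 1$, so at most $1/\tau^2$ secrets can shift the answer by the tolerance $\tau$ or more. Each query therefore rules out only a tiny fraction of the $2^n$ candidates. The code computes all $2^n$ correlations $\mathbb{E}[\psi_1 \chi_S]$ with a fast Walsh–Hadamard transform for a randomly chosen (hypothetical) $\psi_1$.

```python
import random

def fourier_coefficients(g):
    """Return [mean over v of g[v] * chi_S(v) for each S], with S and v encoded as n-bit masks.

    g is a list of length 2^n indexed by the input bitmask v; chi_S(v) = (-1)^popcount(S & v).
    Uses the in-place fast Walsh-Hadamard transform (O(n 2^n) additions).
    """
    a = list(g)
    h = 1
    while h < len(a):
        for i in range(0, len(a), 2 * h):
            for j in range(i, i + h):
                a[j], a[j + h] = a[j] + a[j + h], a[j] - a[j + h]
        h *= 2
    return [c / len(a) for c in a]

n, tau = 12, 0.05
rng = random.Random(1)
# Hypothetical bounded query part psi_1 : {0,1}^n -> [-1, 1]; any such function would do.
psi1 = [rng.uniform(-1.0, 1.0) for _ in range(2 ** n)]
coeffs = fourier_coefficients(psi1)
energy = sum(c * c for c in coeffs)  # Parseval: equals the mean of psi1^2, hence <= 1
informative = [S for S, c in enumerate(coeffs) if abs(c) >= tau]
# At most energy / tau^2 <= 1 / tau^2 of the 2^n secrets can move this query's answer by >= tau.
print(len(informative), "of", 2 ** n, "secrets; bound:", energy / tau ** 2)
```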
Alternative point of view: lack of incremental pathways
An alternative point of view on the failure of learnability in this case is that there is no good way for an algorithm to incrementally learn parity. There is no story of learning parity that starts from simple algorithms (in some quantitative or even qualitative, Occam’s-razor-esque sense) and recursively learns added epicycles of complexity which improve classification accuracy. For example, if we were to try to approximate the parity function by parity functions of smaller subsets, we would totally fail (as parity functions associated to different subsets are uncorrelated); a stronger version of the “lack of incremental pathways” result can be made following a similar intuition to the proof sketch above. This supports the idea that in order to be learnable, an algorithm must in some sense be combinable (at least in a local sense) out of simpler pieces, each of which is “findable” (i.e., doesn’t require exponential luck to get right; in later posts we will identify this with notions of effective dimension) and each of which reduces loss. This is closely related to the “low-hanging fruit prior” point of view, and will later serve as a lead-in to a discussion of “learning stories”.
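The claim about approximating by smaller subsets can be checked exhaustively for small $n$: every proper subset $T \subsetneq S$, including the empty set, gives a parity function that agrees with $f_S$ on exactly half of all inputs, i.e., no better than a coin flip. (The reason is that $f_S \oplus f_T = f_{S \setminus T}$ with $S \setminus T \neq \emptyset$.) The sizes below are arbitrary.

```python
import itertools

def agreement(S, T, n):
    """Fraction of all 2^n inputs on which the parity functions f_S and f_T agree."""
    agree = sum(
        (sum(v[i] for i in S) % 2) == (sum(v[i] for i in T) % 2)
        for v in itertools.product([0, 1], repeat=n)
    )
    return agree / 2 ** n

n = 6
S = (0, 1, 2, 3, 4)                       # hypothetical secret subset
for r in range(len(S)):                   # every proper subset T of S, including the empty set
    for T in itertools.combinations(S, r):
        assert agreement(S, T, n) == 0.5  # exactly a coin flip
print(agreement(S, S, n))                 # 1.0
```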
Does this mean that neural nets are weak?
Now that we’ve seen that neural nets trained on examples of the parity problem are provably incapable of learning it in polynomial time, it is reasonable to ask whether this is a hard limitation on the computational capabilities of neural nets. Indeed, I just explained that it is provably impossible to use a learning algorithm (such as an LLM) to solve a problem that can be easily solved by an undergraduate, at least in an amount of time that is shorter than the age of the universe. Does this negate the possibility that modern LLMs can solve hard-core math problems? Can we stop worrying about human-level AI?
Unfortunately (if you’re worried about AI risk), we can’t. The impossibility of XOR learning does not imply any limitation on the mathematical ability of LLMs. The issue here is with the notion of “learnability”. In the setup of our XOR problem, we assumed that the LLM is executing SGD learning (or another learning algorithm) on the single learning signal of “what is the parity function applied to the vector $v$”. If we were to give the parity problem to an advanced LLM, it might be able to solve it, but this would not be from gradient updates on seeing a bunch of examples. Rather, our LLM has seen many mathematical texts, and may be able to use the knowledge in these mathematical texts and a basic understanding of logic to reconstruct the hidden subset $S$ and the parity function $f_S$. Abstracting away the high-level “mathematical understanding” of LLMs, what this is saying is that it is in fact possible to learn the parity problem if the direct learning problem is replaced by a suitably sophisticated curriculum-learning-style problem with an enriched class of examples and a more sophisticated loss function. Trying to write out a simple “mathematical” ML algorithm that learns to solve the parity problem is an interesting exercise that might constitute a nice ML theory paper; I won’t try to do this here.
Not weak, but also not optimal
The main reason I want people in AI safety/ interpretability to know about and understand parity is related to a long-standing question in machine learning of Bayesian learning vs. SGD, where the conventional wisdom has been wrong, but (in my limited understanding) is finally starting to converge in the correct direction (as will be typical of these posts, this is the “it’s complicated” direction). Namely, you can ask whether SGD (and related learning algorithms) can be well understood as finding the optimal solution – or more precisely, as sampling a suitable “Bayesian” distribution of near-optimal solutions[4]. It is easy to see that learning sufficiently general algorithms cannot converge to anything like Bayesian learning for P vs. NP reasons. But a standard counterargument, supported by a standard collection of faulty papers (that I’ll complain about later), was that “real life” problems where deep learning is applied do converge to the Bayesian prior.
One (soft) takeaway from the discussion here is that if training “real-life” modern LLMs involves reasoning in the same reference class as parity, then it is likely that the algorithm they learn is not globally optimal (in a Bayesian sense). Indeed, we see from parity that optimal algorithms in this reference class lack the incremental pathways necessary to be learnable via SGD, and the way that LLMs solve complex problems is probably mediated by curriculum-learning-style “training wheels” that learn general solutions, just not of the most efficient type. [5]
Acknowledgments
I’ve talked to lots of people about this, but particularly important for this post have been a number of conversations with Kaarel Hänni on related topics. I also want to thank Sam Eisenstat for first telling me about parity and the notion of learnability, and thanks to Jake Mendel and Lucius Bushnaq for related discussions.
- ^
More precisely, the probabilistic version (NP ⊄ BPP), and I might be assuming some cryptographic hardness in other places in this section (but not in the rest of the post!)
- ^
Technically, there are $2^n$ functions $f_S$ but they only span a $(2^n - 1)$-dimensional subspace, since the “empty-set parity” function $f_\emptyset$ is zero; to make this statement precise, one needs to replace $f_\emptyset$ by a constant function. A more commonly used related basis is the “Boolean Fourier mode” basis with basis elements $\chi_S(v) = (-1)^{f_S(v)}$, which replaces the $\{0,1\}$-valued parity function by a $\{\pm 1\}$-valued analog. Working with this basis is generally nicer, and in particular makes “uncorrelatedness” arguments cleaner.
- ^
They are uncorrelated on the set of all inputs (i.e., the expected value of $f_S(v)$ doesn’t change even if you condition on a specific value of $f_T(v)$), but they are correlated (“in a random way”) on some fixed polynomial-sized set of “training” inputs. In the latter training-set context, they are “not very” correlated, and the correlations can be proved to be suitably “unbiased” when viewed as a noise term.
- ^
This is normally defined as the Boltzmann distribution associated to loss, an object particularly important to Singular Learning Theory.
- ^
Note that this isn’t even an “intuition-level” proof: it’s not obvious that modern ML methods require knowing how to solve problems in the reference class of “parity for n-bit inputs such that  is very large”. And even if this were the case, it could just happen that the “training wheels” for learning parity-like problems are also needed to produce Bayes-optimal algorithms for some important simpler problems. Later when we discuss connections between physics and ML, we’ll see other, more rigorous reasons to dismiss strong versions of the SGD = Bayesian hypothesis. But at the same time, it’s important to note that in many contexts, namely when looking locally in a basin or looking at simple “circuit-level” behaviors that don’t have enough accumulated complexity to break out of a low-dimensional paradigm, it is reasonable and productive to make little distinction between the two types of learning.
The analogue of P≠NP for randomized polynomial-time algorithms is actually NP≠RP (or equivalently NP⊈BPP), assuming that the intended meaning here is “NP-complete problems cannot be solved in polynomial time, even with randomized algorithms”.[1] More information about these classes is available here.
Also, a bit of a nitpick, but it looks like the claim that these widely-believed complexity-theoretic assumptions imply
is implicitly claiming that a “learning problem” similar to the following is NP-hard:
It seems likely that this learning problem is indeed computationally difficult, given that the naive approach requires brute-force search through circuits. However, to actually prove NP-hardness, there needs to be a reduction from 3SAT (say) to that problem. I don’t see any straightforward approach to construct such a reduction, even though it is not entirely implausible that one exists (possibly with minor modifications to the problem statement).
I do agree with the conclusion though that the problem being discussed here is very different from what is being captured by statements like P≠NP.
The statement P≠BPP actually means “randomized polynomial-time algorithms cannot be ‘derandomized’ by replacing them with equivalent deterministic polynomial-time algorithms”.
Thank you for the nit, of course you’re correct about the NP analogue of BPP. I’ll edit.
And also thanks for operationalizing the generalized circuit “inversion problem” that’s analogous to the parity game described. This is exactly what I meant, but you put it better than I could have.
I think that if you restrict your class of circuits C to have some special form (i.e., “choose your architecture”) and also choose some specific collection of test inputs x_1, .., x_N, then you can make the code of C depend on a secret boolean string, and have C(x_1), .., C(x_N) return specific SAT instances of this boolean string; so a sufficiently “fiddly” version of this circuit reconstruction problem is equivalent to P vs. NP. However I think you’re right that hardness of the most natural version (in particular, where the inputs x_i are chosen at random) is probably closer to the existence of one-way functions than to P vs. NP.
Confirming that efficiently finding a small circuit (you don’t actually need further restrictions than size) based on its values on a fixed collection of test inputs is known to imply NP⊆BPP --- see this paper.
This is a fact worth knowing and a lucid explanation—thanks for writing this!
I know it’s not the main point of the post, but I found myself a little lost when you talked about complexity theory; I would be interested to hear more details. In particular:
When you say BPP-invertible ⊂ NP-invertible and Learning algorithms ⊂ BPP-invertible,
what are the definitions of these classes? Are these decision problems in the usual sense, or is this a statement about learning theory? I haven’t been able to either track down a definition of e.g. “NP-invertible” or invent one that would make sense—if this is a renamed-version of something people have studied I would be curious to know the original name.
You claim “the assumption P≠BPP implies that there is no way, in polynomial time, to get reliable information about A’s secret circuit C (beyond some statistical regularities)”—I suspect this is not quite what you mean. I’m not sure exactly the formal statement you’re making here, but it would be pretty unusual for P≠BPP to be a relevant assumption (in fact most people in TCS strongly believe P=BPP). You call this “the probablistic version [of P≠NP]”, so I’m guessing this was a typo and you mean NP ⊄ BPP? But I would also be impressed if you could get something like this statement from only NP ⊄ BPP. My best guess would be that you do need those cryptographic assumptions you mentioned for this part (that is, if you want to say “P/poly is not PAC learnable”—in other words, no polynomial-time algorithm can, given random input-output samples from an arbitrary poly-size circuit, find a circuit computing a nearby function except with tiny success probability—I’m pretty sure this is only known conditional on the existence of one-way functions).
Again though, I know these are quibbles and the main point of this section is “you don’t need any unproven complexity assumptions to get lower bounds against statistical query learning”.
(Separately, a minor nit with the presentation: I found the decision to use “learning algorithm” to refer specifically to “weight-based ML architecture that learns via SGD” to be slightly confusing. The takeaway being “there exists an algorithm that learns parities, but there doesn’t exist a learning algorithm that learns parities” is a little slippery—I think it would be worth replacing “learning algorithm” with e.g. “gradient-based learning algorithm” in the writeup.)
Strong upvoted. I think the idea in this post could (if interpreted very generously) turn out to be pretty important for making progress at the more ambitious forms of interpretability. If we/the ais are able to pin down more details about what constitutes a valid learning story or a learnable curriculum, and tie that to the way gradient updates can be decomposed into signal on some circuit and noise on the rest of the network, then it seems like we should be able to understand each circuit as it corresponds to the endpoint of a training story, and each part of the training story should correspond to a simple modification of the circuit to add some more complexity. this is potentially better for interpretability than if it were easy for networks to learn huge chunks of structure all at once. How optimistic are you about there being general insights to be had about the structures of learnable curricula and their relation to networks’ internal structure?
Thanks! I definitely believe this, and I think we have a lot of evidence for this in both toy models and LLMs (I’m planning a couple of posts on this idea of “training stories”), and also theoretical reasons in some contexts. I’m not sure how easy it is to extend the specific approach used in the proof for parity to a general context. I think it inherently uses the fact of orthogonality of Fourier functions on boolean inputs, and understanding other ML algorithms in terms of nice orthogonal functions seems hard to do rigorously, unless you either make some kind of simplifying “presumption of independence” model on learnable algorithms or work in a toy context. In the toy case, there is a nice paper that does exactly this (explains how NNs will tend to find “incrementally learnable” algorithms), by using a similar idea to the parity proof I outlined. This is the leap complexity paper (that Kaarel and I have looked into; I think you’ve also looked into related things).
I started reading, but I can’t understand what the parity problem is, in the section that ought to define it.
I guess, the parity problem is finding the set S given black-box access to the function, is it?
Parity in computing is whether the count of 1s in a binary string is even or odd, e.g. ‘101’ has two 1s ⇒ even parity (to output 0 for even parity, XOR all bits like
1^0^1
.. to output 1 for this, XOR that result with 1). The parity problem (if I understand it correctly) sounds like trying to find out the minimum amount of data samples per input length a learning algorithm ought to need to figure out that a mapping between a binary input and a single bit output is equal to computing XOR parity and not something else (e.g. whether an integer is even/odd, or if there is a pattern in wannabe-random mapping, …), and the conclusion seems to be that you need exponentially more samples for linearly longer input .. unless you can figure out from other clues that you need to calculate parity in which case you just implement parity for any input size and you don’t need any additional sample data.
(FTR: I don’t understand the math here, I am just pattern matching to the usual way this kind of problem goes)
Why doesn’t it work to train on all the 1-hot input vectors using an architecture that suitably encodes Z_2 dot product and the only variable weights are those for the vector representing S? Does B not get to choose the inputs they will train with? Edit: Mentally swapped A with B in one place while reading.
I believe that instead of “Should you pay $1 for this box?”, the question should be “Should you pay $1 for each use of this box?”. Otherwise, I believe I can just cut the island in half, go back and forth between the sides a billion billion times and keep pressing the button and look for bias in the responses. Then throw away the worse half of the island and repeat.
I think this is a crux for me. I don’t have a good guess how common this phenomenon is. The parity problem feels pathological in some sense, but I wouldn’t be surprised if there are other classes of problems that would fall into the same category + are represented in some training data.
Yes I agree this is an important crux. I’m not sure which way I lean here. On the one hand, most specific things we can discover about human thinking are highly parallel. On the other hand, it seems very plausible that there are some complicated sequential things going on in the brain that don’t return partial outputs, which are in the same reference class as parity; if this is the case then, insofar as an LLM is reconstructing human brain function, it would need Bayesian-suboptimal “training wheels” to capture these processes.
Has anybody tried actual humans or smart LLMs? It would be interesting to know what methods people actually use.
There are in general simple algorithms for determining S in polynomial time, since it’s just a system of linear equations as in the post. Humans came up with those algorithms, and smart LLMs may be able to recognize the problem type and apply a suitable algorithm in chain-of-thought (with some probability of success).
However, average humans don’t know any linear algebra and almost certainly won’t be able to solve more than a trivial-sized problem instance. Most struggle with the very much simpler “Lights Out” puzzle.
Isn’t this exactly what Minsky showed to be impossible? You need an additional hidden layer.
Just making sure: if instead the box tells you the truth with probability .9999999999999999999999999999992, and gives a random answer for “warmer” or “colder” with the remaining 2^-100, then for a billion dollar prize it’s worth paying $1 for the box?
yes :)