2024-07-01

Fast Dice

Don't take any chances without knowing the odds

I got nerd sniped by a fun JavaScript performance puzzle recently, having to do with efficiently calculating the probability of a dice-based Bernoulli trial, for the purpose of a game I’m working on. It goes like this:

  • The player works to collect sets of up to 4 dice, which can be d4s, d6s, or d10s.

  • The player then chooses up to 10 sets of dice from their collection and rolls them.

  • If at least 3 of the sets rolled have a total of 10 or more, the player advances to the next round. All rolled sets are removed from the player’s collection, regardless of whether they advance or not.

It’s helpful to be able to calculate the exact probabilities involved when balancing the game, in addition to playtesting normally. Furthermore, having the exact probabilities allows game design elements to reflect the chance of success (e.g. showing things in a different color) in a way that gives the player some useful information while still leaving a bit of uncertainty. However, with the way this process is designed, calculating the probabilities presents a few challenges:

  • Sets can have any combination of dice, e.g. 4d6, 2d4+d10, 3d6+d10, etc.

  • Sets of dice are tested based on their sum.

  • The player advances if at least 3 sets score 10 or more, but has the option to play more than 3 if they think it would help their chances enough to be worth it.

So, to get this problem out of the way and avoid coming up with any tricky equations, I went with the most general solution I know: convolutions of probability distributions. (This formulation is equivalent to multiplying probability generating functions, since we are dealing with discrete random variables.)

Background: Too convoluted?

The calculation works in two passes:

  1. For each set of dice, generate a CDF of the dice total by reducing the PMFs of the dice with convolutions and doing a cumulative sum. The CDF at 9 represents the Bernoulli parameter for whether the set fails to reach a score of 10.

  2. For each Bernoulli parameter p calculated in the previous step, generate a PMF [p, 1-p] to represent a random variable that has a value of 1 if the test succeeds and 0 otherwise. Reduce these by convolution and do a cumulative sum to get the CDF for the sum of these scores. The CDF at 2 represents the Bernoulli parameter for whether fewer than 3 sets succeeded.

If you’re not familiar with convolution, the 3blue1brown video about it is an excellent introduction. However, if you’re in a hurry, you can think of it as a multiplication of two polynomials, where the i-th element of each input is the coefficient of x^i. The fact that the probability distribution of a sum of random variables is a convolution of their individual distributions is extremely useful for numerically calculating probabilities where simpler analytic solutions are out of reach, and is the foundation of the approach outlined above.
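As a minimal sketch of the polynomial view, the snippet below multiplies two coefficient arrays term by term. The helper name polyMultiply and the two-sided "d2" die are illustrative assumptions, not part of the game's code.

```javascript
// Multiply two polynomials given as coefficient arrays,
// where index i holds the coefficient of x^i.
function polyMultiply(a, b) {
  const out = new Array(a.length + b.length - 1).fill(0);
  for (let i = 0; i < a.length; i++) {
    for (let j = 0; j < b.length; j++) {
      // x^i * x^j = x^(i + j), so the product lands at index i + j.
      out[i + j] += a[i] * b[j];
    }
  }
  return out;
}

// PMF of a hypothetical fair two-sided die: faces 1 and 2, each with probability 1/2.
const d2 = [0, 0.5, 0.5];

// PMF of the total of two such dice: totals 2, 3, 4 with probabilities 1/4, 1/2, 1/4.
console.log(polyMultiply(d2, d2)); // [ 0, 0, 0.25, 0.5, 0.25 ]
```

Reading the result by index gives the distribution of the sum directly: index 3 holds the probability of rolling a total of 3.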

Attempt 1: Just write the code

These days JavaScript VMs are very fast, so I generally try not to overthink things unless there is a clear need for it. Naturally I wrote some code that looked like this:

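// [0, 1, ..., n - 1]
function range(n) {
  return [...Array(n).keys()];
}
function sum(a) {
  return a.reduce((x, y) => x + y, 0);
}
// Convolution straight from the definition: out[i] = sum over j of a[i - j] * b[j].
// Out-of-range reads of `a` yield undefined, which `??` turns into 0.
function convolve(a, b) {
  return range(a.length + b.length - 1).map((i) =>
    sum(range(b.length).map((j) => (a[i - j] ?? 0) * b[j]))
  );
}
// Running total, which turns a PMF into a CDF.
function cumsum(a) {
  let total = 0;
  return a.map((x) => (total += x));
}
// PMF of a fair n-sided die: index i holds P(roll = i), so index 0 is always 0.
function dicepmf(n) {
  return range(n + 1).map((i) => (1 <= i && i <= n ? 1 / n : 0));
}

// Probability that one set (a list of die sizes) totals 9 or less.
// If the set cannot even reach 9, the CDF has no index 9 and the set always fails.
function pSetFail(ds) {
  const cdf = cumsum(ds.map((n) => dicepmf(n)).reduce(convolve, [1]));
  return cdf[9] ?? 1;
}
// Probability that fewer than 3 of the given sets reach 10.
function pSetsFail(sets) {
  const pdf = sets
    .map(pSetFail)
    .map((p) => [p, 1 - p]) // index 0: set fails, index 1: set succeeds
    .reduce(convolve, [1]);
  const cdf = cumsum(pdf);
  return cdf[2] ?? 1;
}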

const sets = [
  [4, 6, 6],
  [4, 6, 6],
  [4, 6, 6],
];
console.log(1 - pSetsFail(sets));

If we run the above, we get 0.125, which is what we would expect, since the probability of d4+2d6 totaling 10 or more is exactly 50%. In other words, we’re flipping a coin 3 times and trying to get HHH.
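The 50% figure for d4+2d6 can be double-checked by brute force, since there are only 4 × 6 × 6 = 144 outcomes. The following is a sketch for validation only; the helper name pAtLeast is an assumption introduced here.

```javascript
// Enumerate every outcome of the given dice and return the fraction
// whose total is at least `target`.
function pAtLeast(dice, target) {
  let hits = 0;
  let total = 0;
  const roll = (i, sum) => {
    if (i === dice.length) {
      total++;
      if (sum >= target) hits++;
      return;
    }
    for (let face = 1; face <= dice[i]; face++) {
      roll(i + 1, sum + face);
    }
  };
  roll(0, 0);
  return hits / total;
}

const pSet = pAtLeast([4, 6, 6], 10);
console.log(pSet); // 0.5 (72 of 144 outcomes)
console.log(pSet ** 3); // 0.125
```

The even split is no coincidence: totals of d4+2d6 range from 3 to 16 and are symmetric around 9.5, so exactly half of the outcomes land at 10 or above.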

So this solution works and has the benefit of being concise and readable, but how fast is it? Naive convolution is O(n^2), so let’s time its worst case in our context, 10 sets of 4d10:

const worstcase = range(10).map(() => [10, 10, 10, 10]);

const start = performance.now();
for (let i = 0; i < 1000; i++) pSetsFail(worstcase);
const end = performance.now();

console.log((end - start) / 1000); // average milliseconds per call

This comes out to about 1.4ms per call on my machine in Node.js, which is definitely usable if you only need it occasionally. Okay, problem solved. Let’s move on to the next thing; we’ve got a game to build.

…But it does seem a bit slow, doesn’t it? We’re only dealing with 40 dice here, and only up to d10. What if we want to offer d20s or d100s? We’re just doing multiplications and additions, so surely we can do better, right? We should be able to call this 100 times a frame if we want!

Attempt 2: Loops

The convolve function we wrote is definitely a big part of the problem, and you don’t need a profiler to figure that out. It’s two loops, where the inner loop is generating an array just to calculate its sum. We’re also deliberately looking up keys that don’t exist and using the nullish coalescing operator to convert them to 0s, instead of explicitly checking the index. Furthermore, we’re calculating up to 41 elements of each PMF when we only need the first 10. We can omit those without changing the result. There’s a lot of room for some cheap improvements and we owe it to ourselves to at least try, don’t we?

In particular, we suspect the following things might help:

  1. Replace most calls to range(), map(), and reduce() with explicit loops and mutation.

  2. Replace the nullish coalescing operator with an explicit bounds check.

  3. Truncate the PMFs.
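The truncation idea rests on one observation: entry i of a convolution only depends on input entries at indices up to i, so anything past index 9 can never influence the first 10 outputs. A small sketch to check this numerically follows; the helper name convolveUpTo is an assumption made for this illustration.

```javascript
// Convolution that only computes the first `limit` output entries.
// Passing Infinity computes the full result.
function convolveUpTo(a, b, limit) {
  const n = Math.min(limit, a.length + b.length - 1);
  const out = new Array(n).fill(0);
  for (let i = 0; i < a.length; i++) {
    for (let j = 0; j < b.length && i + j < n; j++) {
      out[i + j] += a[i] * b[j];
    }
  }
  return out;
}

// PMF of a d10: index 0 is 0, indices 1..10 are 1/10 each.
const d10 = [0, ...new Array(10).fill(0.1)];

// 3d10 computed in full, and again truncating to 10 entries after every step.
const full = convolveUpTo(convolveUpTo(d10, d10, Infinity), d10, Infinity);
const truncated = convolveUpTo(convolveUpTo(d10, d10, 10), d10, 10);

console.log(full.length, truncated.length); // 31 10
console.log(truncated.every((x, i) => Math.abs(x - full[i]) < 1e-15)); // true
```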

When we make the above changes, the code looks like this:

function convolve(a, b, limit) {
  const n = Math.min(limit, a.length + b.length - 1);
  const out = [];
  for (let i = 0; i < n; i++) {
    out[i] = 0;
    for (let j = 0; j < b.length; j++) {
      const k = i - j;
      if (0 <= k && k < a.length) {
        out[i] += a[k] * b[j];
      }
    }
  }
  return out;
}
function cumsum(a) {
  for (let i = 1; i < a.length; i++) {
    a[i] += a[i - 1];
  }
  return a;
}
function dicepmf(n) {
  const out = [0];
  for (let i = 1; i <= n; i++) {
    out[i] = 1 / n;
  }
  return out;
}

function pSetFail(ds) {
  let pdf = [1];
  for (const n of ds) {
    pdf = convolve(pdf, dicepmf(n), 10);
  }
  const cdf = cumsum(pdf);
  return cdf.length < 10 ? 1 : cdf[9];
}
function pSetsFail(sets) {
  let pdf = [1];
  for (const ds of sets) {
    const p = pSetFail(ds);
    pdf = convolve(pdf, [p, 1 - p], 3);
  }
  const cdf = cumsum(pdf);
  return cdf.length < 3 ? 1 : cdf[2];
}

We’ve sacrificed a bit of conciseness, perhaps, but this is still quite readable. After quickly double-checking that our first example still prints 0.125, we time it and find that each calculation of the worst-case input takes only around 15 microseconds. That’s roughly a 90x improvement!

Of course, we should probably identify how much of an impact each of the above changes had, so here are rough timing figures for each one individually:

  • Loops and mutation: 340µs, 4.0x improvement

  • Bounds checks: 1.0ms, 1.3x improvement

  • PMF truncation: 670µs, 2.0x improvement

Loops definitely seem to have the biggest impact by themselves… but where is the 90x improvement coming from when taken together? Microbenchmarking a language like JavaScript in the particular way I’m doing it isn’t an exact science, but there is definitely something fishy going on. Let’s try the other direction and remove each optimization from the fastest solution to see which one results in the biggest slowdown:

  • No loops and mutation: 410µs, 28x slowdown

  • No bounds checks: 290µs, 20x slowdown

  • No PMF truncation: 39µs, 2.6x slowdown

Bizarre! Loops once again seem to be responsible for the biggest improvement, but the bounds checks are an impressive factor as well. I double-checked the code to make sure I didn’t get something wrong here, but using explicit bounds checks does seem to be responsible for a noticeable improvement. I don’t know enough about V8 to know why this would be the case, but it’s an interesting thing to keep in mind when trying to write JIT-friendly code, I suppose. Maybe I’ll do a deep dive someday to figure out why this happens.
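To poke at the bounds-check effect in isolation, one option is to time two loop-based variants that differ only in how out-of-range indices are handled. The sketch below is an assumption about what such a comparison could look like, not the exact code behind the figures above; the names convolveCoalesce, convolveChecked, and timeIt are made up for this example, and performance.now() is assumed to be available as a global (as it is in browsers and recent Node.js versions).

```javascript
// Variant A: out-of-range reads of `a` are patched up with `??`.
function convolveCoalesce(a, b, limit) {
  const n = Math.min(limit, a.length + b.length - 1);
  const out = [];
  for (let i = 0; i < n; i++) {
    out[i] = 0;
    for (let j = 0; j < b.length; j++) {
      out[i] += (a[i - j] ?? 0) * b[j];
    }
  }
  return out;
}

// Variant B: the index is checked before reading.
function convolveChecked(a, b, limit) {
  const n = Math.min(limit, a.length + b.length - 1);
  const out = [];
  for (let i = 0; i < n; i++) {
    out[i] = 0;
    for (let j = 0; j < b.length; j++) {
      const k = i - j;
      if (0 <= k && k < a.length) out[i] += a[k] * b[j];
    }
  }
  return out;
}

// Average milliseconds per call, measured after a warm-up phase so the JIT
// has had a chance to optimize. The iteration counts are arbitrary.
function timeIt(fn, iterations = 10000, warmup = 1000) {
  for (let i = 0; i < warmup; i++) fn();
  const start = performance.now();
  for (let i = 0; i < iterations; i++) fn();
  return (performance.now() - start) / iterations;
}

const d10 = [0, ...new Array(10).fill(0.1)];
const twoDice = convolveChecked(d10, d10, 10);

console.log("coalesce:", timeIt(() => convolveCoalesce(twoDice, d10, 10)));
console.log("checked: ", timeIt(() => convolveChecked(twoDice, d10, 10)));
```

Both variants return the same numbers, so any timing gap comes from how the engine handles the two access patterns. The absolute figures will vary by machine and engine version and should be treated as indicative only.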

But okay, 15 microseconds is pretty dang fast, and that’s a worst case! If we use the 3 sets of d4+2d6 input we’ve been using for validation, we get speeds closer to 2.5 microseconds. So we’re done, right?

Right??

Attempt 3: Insanity

Our calculation has a very predictable shape: ten times do 10-element convolutions of four sequences and sum their entries, then do 3-element convolutions of ten sequences and sum those. What if we just… hard-coded this? No loops, minimal branches. How fast would it be?

I won’t bore you with the details and will just show you the code I came up with. I’m not going to use this code, of course. It just felt like a fun puzzle. Here’s the function that calculates the same result as pSetFail() above, with slightly different parameters:

function pSetFail(d0, d1, d2, d3) {
  let w0, w1, w2, w3, w4, w5, w6, w7, w8, w9;
  let x0, x1, x2, x3, x4, x5, x6, x7, x8, x9;
  let y0, y1, y2, y3, y4, y5, y6, y7, y8, y9;
  let z0, z1, z2, z3, z4, z5, z6, z7, z8, z9;
  let m;

  w0 = d0 === 0 ? 1 : 0;
  w1 = d0  >= 1 ? 1 : 0;
  w2 = d0  >= 2 ? 1 : 0;
  w3 = d0  >= 3 ? 1 : 0;
  w4 = d0  >= 4 ? 1 : 0;
  w5 = d0  >= 5 ? 1 : 0;
  w6 = d0  >= 6 ? 1 : 0;
  w7 = d0  >= 7 ? 1 : 0;
  w8 = d0  >= 8 ? 1 : 0;
  w9 = d0  >= 9 ? 1 : 0;

  m = 1;
  m *= d0 === 0 ? 1 : d0;
  m *= d1 === 0 ? 1 : d1;
  m *= d2 === 0 ? 1 : d2;
  m *= d3 === 0 ? 1 : d3;

  switch (d1) {
    case 0:
      x0=w0;       x1=w1;       x2=w2;       x3=w3;       x4=w4;
      x5=w5;       x6=w6;       x7=w7;       x8=w8;       x9=w9;       break;
    case 4:
      x0=0;        x1=w0+x0;    x2=w1+x1;    x3=w2+x2;    x4=w3+x3;
      x5=w4+x4-w0; x6=w5+x5-w1; x7=w6+x6-w2; x8=w7+x7-w3; x9=w8+x8-w4; break;
    case 6:
      x0=0;        x1=w0+x0;    x2=w1+x1;    x3=w2+x2;    x4=w3+x3;
      x5=w4+x4;    x6=w5+x5;    x7=w6+x6-w0; x8=w7+x7-w1; x9=w8+x8-w2; break;
    case 10:
      x0=0;        x1=w0+x0;    x2=w1+x1;    x3=w2+x2;    x4=w3+x3;
      x5=w4+x4;    x6=w5+x5;    x7=w6+x6;    x8=w7+x7;    x9=w8+x8;    break;
  }

  switch (d2) {
    case 0:
      y0=x0;       y1=x1;       y2=x2;       y3=x3;       y4=x4;
      y5=x5;       y6=x6;       y7=x7;       y8=x8;       y9=x9;       break;
    case 4:
      y0=0;        y1=x0+y0;    y2=x1+y1;    y3=x2+y2;    y4=x3+y3;
      y5=x4+y4-x0; y6=x5+y5-x1; y7=x6+y6-x2; y8=x7+y7-x3; y9=x8+y8-x4; break;
    case 6:
      y0=0;        y1=x0+y0;    y2=x1+y1;    y3=x2+y2;    y4=x3+y3;
      y5=x4+y4;    y6=x5+y5;    y7=x6+y6-x0; y8=x7+y7-x1; y9=x8+y8-x2; break;
    case 10:
      y0=0;        y1=x0+y0;    y2=x1+y1;    y3=x2+y2;    y4=x3+y3;
      y5=x4+y4;    y6=x5+y5;    y7=x6+y6;    y8=x7+y7;    y9=x8+y8;    break;
  }

  switch (d3) {
    case 0:
      z0=y0;       z1=y1;       z2=y2;       z3=y3;       z4=y4;
      z5=y5;       z6=y6;       z7=y7;       z8=y8;       z9=y9;       break;
    case 4:
      z0=0;        z1=y0+z0;    z2=y1+z1;    z3=y2+z2;    z4=y3+z3;
      z5=y4+z4-y0; z6=y5+z5-y1; z7=y6+z6-y2; z8=y7+z7-y3; z9=y8+z8-y4; break;
    case 6:
      z0=0;        z1=y0+z0;    z2=y1+z1;    z3=y2+z2;    z4=y3+z3;
      z5=y4+z4;    z6=y5+z5;    z7=y6+z6-y0; z8=y7+z7-y1; z9=y8+z8-y2; break;
    case 10:
      z0=0;        z1=y0+z0;    z2=y1+z1;    z3=y2+z2;    z4=y3+z3;
      z5=y4+z4;    z6=y5+z5;    z7=y6+z6;    z8=y7+z7;    z9=y8+z8;    break;
  }

  return (z0+z1+z2+z3+z4+z5+z6+z7+z8+z9)/m;
}

And here’s the function that computes the same result as pSetsFail() above, again with slightly different parameters:

function pSetsFail(ds) {
  const p0 = pSetFail(ds[ 0], ds[ 1], ds[ 2], ds[ 3]);
  const p1 = pSetFail(ds[ 4], ds[ 5], ds[ 6], ds[ 7]);
  const p2 = pSetFail(ds[ 8], ds[ 9], ds[10], ds[11]);
  const p3 = pSetFail(ds[12], ds[13], ds[14], ds[15]);
  const p4 = pSetFail(ds[16], ds[17], ds[18], ds[19]);
  const p5 = pSetFail(ds[20], ds[21], ds[22], ds[23]);
  const p6 = pSetFail(ds[24], ds[25], ds[26], ds[27]);
  const p7 = pSetFail(ds[28], ds[29], ds[30], ds[31]);
  const p8 = pSetFail(ds[32], ds[33], ds[34], ds[35]);
  const p9 = pSetFail(ds[36], ds[37], ds[38], ds[39]);

  const a1 =          1-p0;

  const b0 = p1*p0;
  const b1 = p1*a1 + (1-p1)*p0;
  const b2 =         (1-p1)*a1;

  const c0 = p2*b0;
  const c1 = p2*b1 + (1-p2)*b0;
  const c2 = p2*b2 + (1-p2)*b1;

  const d0 = p3*c0;
  const d1 = p3*c1 + (1-p3)*c0;
  const d2 = p3*c2 + (1-p3)*c1;

  const e0 = p4*d0;
  const e1 = p4*d1 + (1-p4)*d0;
  const e2 = p4*d2 + (1-p4)*d1;

  const f0 = p5*e0;
  const f1 = p5*e1 + (1-p5)*e0;
  const f2 = p5*e2 + (1-p5)*e1;

  const g0 = p6*f0;
  const g1 = p6*f1 + (1-p6)*f0;
  const g2 = p6*f2 + (1-p6)*f1;

  const h0 = p7*g0;
  const h1 = p7*g1 + (1-p7)*g0;
  const h2 = p7*g2 + (1-p7)*g1;

  const i0 = p8*h0;
  const i1 = p8*h1 + (1-p8)*h0;
  const i2 = p8*h2 + (1-p8)*h1;

  const j0 = p9*i0;
  const j1 = p9*i1 + (1-p9)*i0;
  const j2 = p9*i2 + (1-p9)*i1;

  return j0+j1+j2;
}

When running this on the worst case input with the same benchmarking strategy as before, I get times of 400 nanoseconds.

Why is it so much faster? Well, I can’t be 100% sure without digging deeper, but my guess is this is very JIT-friendly code, and that the machine code it generates is very cache-friendly. pSetFail is also all integer sums and products until the return statement, which probably helps a bit. Maybe with a bit more digging it could be made even faster. But I think 400ns is pretty fast. That’s roughly a 3000x improvement from where we started. Not bad!

By the way, if you look very closely at pSetFail, you might notice that it looks a lot like the ultimate brute force solution: counting up the possible ways to get different sums.
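The counting view is easier to see in loop form. The sketch below tracks integer counts of ways to reach each total and uses the same sliding-window recurrence that appears unrolled in the switch statements, where each new entry is the previous entry plus one count entering the window minus one leaving it. The names countWays and pSetFailByCounting are assumptions for this illustration, and dice are passed as a plain list rather than four fixed slots with 0 for an empty slot.

```javascript
// For each total 0..limit-1, count the number of ways the dice can produce it.
function countWays(dice, limit) {
  let counts = new Array(limit).fill(0);
  counts[0] = 1; // exactly one way to total 0 with no dice
  for (const n of dice) {
    const next = new Array(limit).fill(0);
    let windowSum = 0; // sum of counts[i - n] through counts[i - 1]
    for (let i = 1; i < limit; i++) {
      windowSum += counts[i - 1];
      if (i - 1 - n >= 0) windowSum -= counts[i - 1 - n];
      next[i] = windowSum;
    }
    counts = next;
  }
  return counts;
}

// Probability that the set totals 9 or less: favorable outcomes over all outcomes.
function pSetFailByCounting(dice) {
  const ways = countWays(dice, 10).reduce((x, y) => x + y, 0);
  const outcomes = dice.reduce((x, n) => x * n, 1);
  return ways / outcomes;
}

console.log(pSetFailByCounting([4, 6, 6])); // 0.5 (72 of 144 outcomes)
console.log(pSetFailByCounting([10, 10, 10, 10])); // 0.0126 (126 of 10000 outcomes)
```

Everything stays in integer arithmetic until the single division at the end, which mirrors the structure of the hard-coded function.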

The future?

We’re still using a naive convolution algorithm, one essentially based directly on the definition. For small sequences this works well enough, but for larger sequences it becomes prohibitively slow. This type of convolution comes up in signal processing a lot, where you might want to compute the convolution of two sequences that have tens of thousands of elements. In those scenarios you can take advantage of the convolution theorem, which lets you multiply the Fourier transforms of the inputs (zero-padded to at least the output length) point-wise and do an inverse Fourier transform to get the same result. This is also the basis of some fast integer multiplication algorithms.
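For the curious, here is a minimal sketch of the convolution theorem in action, assuming a textbook recursive radix-2 FFT with inputs zero-padded to a power-of-two length. The function names fft and fftConvolve are made up for this example, and the code is written for clarity rather than speed.

```javascript
// Recursive radix-2 FFT on separate real and imaginary arrays.
// The length must be a power of two. With invert = true this computes
// the inverse transform without the final 1/n scaling.
function fft(re, im, invert) {
  const n = re.length;
  if (n === 1) return [re, im];
  const evenRe = [], evenIm = [], oddRe = [], oddIm = [];
  for (let i = 0; i < n; i += 2) {
    evenRe.push(re[i]);
    evenIm.push(im[i]);
    oddRe.push(re[i + 1]);
    oddIm.push(im[i + 1]);
  }
  const [eRe, eIm] = fft(evenRe, evenIm, invert);
  const [oRe, oIm] = fft(oddRe, oddIm, invert);
  const outRe = new Array(n);
  const outIm = new Array(n);
  const sign = invert ? 1 : -1;
  for (let k = 0; k < n / 2; k++) {
    // Twiddle factor e^(sign * 2*pi*i*k/n), multiplied into the odd half.
    const angle = (sign * 2 * Math.PI * k) / n;
    const wRe = Math.cos(angle);
    const wIm = Math.sin(angle);
    const tRe = wRe * oRe[k] - wIm * oIm[k];
    const tIm = wRe * oIm[k] + wIm * oRe[k];
    outRe[k] = eRe[k] + tRe;
    outIm[k] = eIm[k] + tIm;
    outRe[k + n / 2] = eRe[k] - tRe;
    outIm[k + n / 2] = eIm[k] - tIm;
  }
  return [outRe, outIm];
}

// Linear convolution of two real sequences via the convolution theorem.
function fftConvolve(a, b) {
  const len = a.length + b.length - 1;
  let n = 1;
  while (n < len) n *= 2; // pad to a power of two that fits the full result
  const pad = (x) => x.concat(new Array(n - x.length).fill(0));
  const zeros = () => new Array(n).fill(0);
  const [aRe, aIm] = fft(pad(a), zeros(), false);
  const [bRe, bIm] = fft(pad(b), zeros(), false);
  const cRe = new Array(n);
  const cIm = new Array(n);
  for (let i = 0; i < n; i++) {
    // Point-wise complex multiplication of the two spectra.
    cRe[i] = aRe[i] * bRe[i] - aIm[i] * bIm[i];
    cIm[i] = aRe[i] * bIm[i] + aIm[i] * bRe[i];
  }
  const [outRe] = fft(cRe, cIm, true);
  return outRe.slice(0, len).map((x) => x / n); // apply the 1/n scaling
}

// Two fair two-sided dice; the result is approximately [0, 0, 0.25, 0.5, 0.25].
console.log(fftConvolve([0, 0.5, 0.5], [0, 0.5, 0.5]));
```

The output matches direct convolution only up to floating-point rounding, and the trigonometry and allocations carry a fixed overhead, so for ten-element inputs a direct loop is very likely still the better choice.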

Are our inputs too small to benefit from this knowledge? Probably. But I can’t help but be a little curious. If I ever decide to give it a shot, I’ll be sure to write about it.