

Netpicking Part 3: Prediction Sets

It’s been a while since I looked at mnistk. Recently, I came across this interesting paper, which describes Regularized Adaptive Prediction Sets (RAPS), a technique for uncertainty estimation. RAPS wraps a model’s prediction scores to output a set of classes that contains the true label with high probability (at least $1 - \alpha$ for a user-chosen error rate $\alpha$). Basically, instead of getting a single class output (“this image is a squirrel”), we now get a set. See the example from the paper:

[Image: example of prediction sets, from the paper]
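To make the mechanics concrete, here is a minimal, simplified sketch of how such a set can be built from softmax scores. It skips the randomized tie-breaking used in the paper, and `tau`, `lam`, and `k_reg` are placeholder values here (`tau` would normally come from calibration), so treat it as an illustration rather than the exact algorithm:

```python
import numpy as np

def raps_set(softmax_probs, tau, lam=0.01, k_reg=5):
    """Sketch of a (deterministic) regularized adaptive prediction set.

    Classes are added in order of decreasing probability until the
    regularized cumulative score reaches the calibrated threshold tau.
    """
    order = np.argsort(softmax_probs)[::-1]            # most likely class first
    cumulative = np.cumsum(softmax_probs[order])
    ranks = np.arange(1, len(order) + 1)
    scores = cumulative + lam * np.maximum(0, ranks - k_reg)  # RAPS penalty beyond rank k_reg
    cutoff = int(np.searchsorted(scores, tau)) + 1      # smallest prefix reaching tau
    return order[:min(cutoff, len(order))]

probs = np.array([0.55, 0.30, 0.10, 0.05])
print(raps_set(probs, tau=0.9))                         # -> [0 1 2]
```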

My original mnistk post used the analogy of an exam to select the best network. Let’s continue with that here. If a single output is like filling in a bubble in a multiple-choice test, then a prediction set output is like … shoddily filling in multiple bubbles.

would’ve filled more bubbles if I knew about prediction sets back then

This post is about using RAPS to help me choose the best network from mnistk.

Getting the data

I implemented the RAPS algorithm for mnistk a while back, but didn’t get around to testing it. RAPS requires a calibration dataset when initializing the wrapper logic, so I used the first 10000 samples from the QMNIST dataset for calibration and the remaining 50000 samples as test data to obtain prediction sets. The original mnistk experiment had 1001 networks and 12 snapshots per network, which gives 1001 x 12 x 50000 = 600.6 million prediction sets to look at.
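A minimal sketch of that calibration/test split, assuming torchvision’s QMNIST dataset (the actual loading code in mnistk may differ):

```python
import torch
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

# QMNIST's extended 60k test set; exact `what` value is an assumption here.
qmnist = datasets.QMNIST(root="./data", what="test", download=True,
                         transform=transforms.ToTensor())

calib_data = Subset(qmnist, range(0, 10_000))       # first 10k: calibration
test_data = Subset(qmnist, range(10_000, 60_000))   # remaining 50k: prediction sets

calib_loader = DataLoader(calib_data, batch_size=256)
test_loader = DataLoader(test_data, batch_size=256)
```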

It took around 20 hours (on an AWS EC2 instance running CPU-only torch) to get this data. My implementation of the RAPS algorithm got a 2x speedup when I did all the torch neural-network work first, and then used numpy vectorization to process the scores for calibration and testing.
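Here is a rough sketch of that two-stage idea: stage 1 does the torch forward passes and dumps plain numpy arrays, stage 2 does the vectorized calibration. The function names, the deterministic scoring (no randomization), and the plain quantile are my simplifications, not the exact mnistk or RAPS-paper code:

```python
import numpy as np
import torch

@torch.no_grad()
def collect_scores(model, loader):
    """Stage 1: run all the torch work up front, return plain numpy arrays."""
    model.eval()
    probs, labels = [], []
    for x, y in loader:
        probs.append(torch.softmax(model(x), dim=1).cpu().numpy())
        labels.append(y.numpy())
    return np.concatenate(probs), np.concatenate(labels)

def calibrate_tau(probs, labels, alpha=0.1, lam=0.01, k_reg=5):
    """Stage 2: numpy-vectorized calibration (simplified).

    Compute a RAPS-style conformity score for the true label of every
    calibration sample, then return the (1 - alpha) quantile as the
    threshold tau used to build prediction sets on the test data.
    """
    n, n_classes = probs.shape
    order = np.argsort(-probs, axis=1)                      # classes by decreasing prob
    sorted_probs = np.take_along_axis(probs, order, axis=1)
    cumulative = np.cumsum(sorted_probs, axis=1)
    penalty = lam * np.maximum(0, np.arange(1, n_classes + 1) - k_reg)
    scores = cumulative + penalty                           # regularized cumulative mass
    true_pos = np.argmax(order == labels[:, None], axis=1)  # rank of the true label
    true_scores = scores[np.arange(n), true_pos]
    return np.quantile(true_scores, 1 - alpha)
```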

Plotting helps … but not much

Simple things first, let’s look at how prediction set size is related to accuracy:

[Plot: prediction set size vs. accuracy]

Okay, prediction set size seems to follow a roughly linear relationship with the accuracy of the network. I made a bunch of other plots similar to the ones in Part 1, and they showed similar group-level characteristics, but they don’t give me what I want: a metric to help me pick the best network.

Looking at the top 10

Let’s look at the top 10 networks in terms of accuracy, using num_c, the number of correct answers (out of 50000). For simplicity, I picked the best-performing snapshot for each network, leaving 1001 entries (a sketch of this selection follows the table):

| name | accuracy_rank | num_c |
|---|---|---|
| ResNetStyle_85 | 1 | 49636 |
| ResNetStyle_71 | 2 | 49610 |
| ResNetStyle_56 | 3 | 49599 |
| ResNetStyle_58 | 4 | 49597 |
| ResNetStyle_69 | 5 | 49582 |
| ResNetStyle_81 | 6 | 49572 |
| ResNetStyle_89 | 7 | 49569 |
| Conv2dReLU_14 | 8 | 49558 |
| ResNetStyle_66 | 9 | 49552 |
| ResNetStyle_83 | 10 | 49544 |
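Here is the selection sketch mentioned above, assuming the per-snapshot results live in a pandas dataframe; the column names (name, num_c) are hypothetical, not the actual mnistk schema:

```python
import pandas as pd

# one row per (network, snapshot)
df = pd.read_csv("prediction_set_summary.csv")

# keep the best-performing snapshot per network, then rank by correct answers
best = df.loc[df.groupby("name")["num_c"].idxmax()].copy()
best["accuracy_rank"] = best["num_c"].rank(method="first", ascending=False).astype(int)
top10 = best.sort_values("accuracy_rank").head(10)
print(top10[["name", "accuracy_rank", "num_c"]])
```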

As expected, ResNetStyle models dominate. But the scatterplot above said prediction set sizes are closely related to accuracy. Let’s call a prediction whose prediction set contains only one element a sure prediction, and look at the top 10 networks in terms of num_s, the number of predictions they are sure about:

| name | accuracy_rank | num_s | num_c |
|---|---|---|---|
| Conv2dSELU_6 | 136 | 49568 | 49212 |
| ResNetStyle_58 | 4 | 49498 | 49597 |
| ResNetStyle_71 | 2 | 49442 | 49610 |
| ResNetStyle_56 | 3 | 49440 | 49599 |
| ResNetStyle_81 | 6 | 49425 | 49572 |
| Conv2dReLU_14 | 8 | 49383 | 49558 |
| ResNetStyle_85 | 1 | 49379 | 49636 |
| ResNetStyle_47 | 11 | 49337 | 49541 |
| ResNetStyle_89 | 7 | 49321 | 49569 |
| ResNetStyle_88 | 33 | 49319 | 49497 |

Where did Conv2dSELU_6 come from!? Somehow it dominates in terms of minimal prediction set size. If I deployed this model in the real world, I would rarely see any uncertainty reported along with its predictions. But it is not in the top 10 for accuracy; in fact its rank is 136, nowhere near the top. So where is it getting all that confidence?

Alright, best of both worlds: I want a model that not only gets a lot of answers correct, but is also sure (highly confident) about the correct answers it gives. Let’s look at the top 10 networks in terms of num_cs, the number of answers that are both correct and sure:

| name | accuracy_rank | num_c | num_s | num_cs |
|---|---|---|---|---|
| ResNetStyle_58 | 4 | 49597 | 49498 | 49306 |
| ResNetStyle_56 | 3 | 49599 | 49440 | 49260 |
| ResNetStyle_71 | 2 | 49610 | 49442 | 49258 |
| ResNetStyle_85 | 1 | 49636 | 49379 | 49237 |
| ResNetStyle_81 | 6 | 49572 | 49425 | 49211 |
| Conv2dReLU_14 | 8 | 49558 | 49383 | 49181 |
| ResNetStyle_47 | 11 | 49541 | 49337 | 49128 |
| ResNetStyle_89 | 7 | 49569 | 49321 | 49127 |
| ResNetStyle_92 | 20 | 49526 | 49272 | 49090 |
| ResNetStyle_88 | 33 | 49497 | 49319 | 49087 |

Now the top 10 is seeing some shakeups! The original “best” network is sometimes not sure even when it is correct, so it falls down the rankings. What happens when the model’s prediction is wrong (doesn’t match the ground truth), or the model is unsure about it (set size greater than 1)?

Each prediction made by the model can fall into one of four classes:

|  | sure | unsure | total |
|---|---|---|---|
| correct | num_cs | num_cu | num_c |
| wrong | num_ws | num_wu | num_w |
| total | num_s | num_u | 50000 |

Hmm, num_wu hides a fifth class of predictions: when the model is wrong and unsure, the prediction set may still contain the true label among its backups (num_wu), or it may miss the true label entirely (num_we).
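To pin down the bookkeeping, here is a sketch of how these five counts can be computed; it assumes top-1 predictions, prediction sets, and true labels as numpy arrays, and my reading of num_we (true label missing from the set entirely) is inferred from the tables below:

```python
import numpy as np

def count_five_classes(top1, pred_sets, labels):
    """Sketch: split the 50000 predictions into the five buckets used below.

    top1      : (n,) argmax class per sample
    pred_sets : list of n arrays, the RAPS prediction set per sample
    labels    : (n,) ground-truth classes
    """
    counts = dict(num_cs=0, num_cu=0, num_ws=0, num_wu=0, num_we=0)
    for pred, pset, truth in zip(top1, pred_sets, labels):
        sure = len(pset) == 1
        if pred == truth:
            counts["num_cs" if sure else "num_cu"] += 1
        elif sure:
            counts["num_ws"] += 1      # wrong, and the set holds only the wrong class
        elif truth in pset:
            counts["num_wu"] += 1      # wrong top answer, but a backup in the set is right
        else:
            counts["num_we"] += 1      # wrong, and the set misses the true label entirely
    return counts
```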

Let’s look at the top 10 again with these five quantities:

| name | accuracy_rank | num_cs | num_cu | num_ws | num_wu | num_we | num_sum5 |
|---|---|---|---|---|---|---|---|
| ResNetStyle_58 | 4 | 49306 | 291 | 192 | 29 | 182 | 50000 |
| ResNetStyle_56 | 3 | 49260 | 339 | 180 | 21 | 200 | 50000 |
| ResNetStyle_71 | 2 | 49258 | 352 | 184 | 25 | 181 | 50000 |
| ResNetStyle_85 | 1 | 49237 | 399 | 142 | 25 | 197 | 50000 |
| ResNetStyle_81 | 6 | 49211 | 361 | 214 | 23 | 191 | 50000 |
| Conv2dReLU_14 | 8 | 49181 | 377 | 202 | 29 | 211 | 50000 |
| ResNetStyle_47 | 11 | 49128 | 413 | 209 | 32 | 218 | 50000 |
| ResNetStyle_89 | 7 | 49127 | 442 | 194 | 23 | 214 | 50000 |
| ResNetStyle_92 | 20 | 49090 | 436 | 182 | 37 | 255 | 50000 |
| ResNetStyle_88 | 33 | 49087 | 410 | 232 | 36 | 235 | 50000 |

How can we use the information about the models’ less-than-ideal predictions (everything except num_cs) to find the best network? Keeping with the exam analogy, we can have negative marking, that nasty troll from competitive exams. First, let’s subtract a mark for each unsure or incorrect answer:

| name | accuracy_rank | cs - all | num_cs | num_cu | num_ws | num_wu | num_we |
|---|---|---|---|---|---|---|---|
| ResNetStyle_58 | 4 | 48612 | 49306 | 291 | 192 | 29 | 182 |
| ResNetStyle_56 | 3 | 48520 | 49260 | 339 | 180 | 21 | 200 |
| ResNetStyle_71 | 2 | 48516 | 49258 | 352 | 184 | 25 | 181 |
| ResNetStyle_85 | 1 | 48474 | 49237 | 399 | 142 | 25 | 197 |
| ResNetStyle_81 | 6 | 48422 | 49211 | 361 | 214 | 23 | 191 |
| Conv2dReLU_14 | 8 | 48362 | 49181 | 377 | 202 | 29 | 211 |
| ResNetStyle_47 | 11 | 48256 | 49128 | 413 | 209 | 32 | 218 |
| ResNetStyle_89 | 7 | 48254 | 49127 | 442 | 194 | 23 | 214 |
| ResNetStyle_92 | 20 | 48180 | 49090 | 436 | 182 | 37 | 255 |
| ResNetStyle_88 | 33 | 48174 | 49087 | 410 | 232 | 36 | 235 |

That didn’t change the order at all. No surprise in hindsight: subtracting one mark for every non-num_cs prediction gives num_cs - (50000 - num_cs) = 2 * num_cs - 50000, which preserves the num_cs ordering. So let’s have weighted negative marking instead.
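Something along these lines, continuing from the earlier dataframe sketch; the specific weights below are made up for illustration and are not the (arbitrary) weights behind the table that follows:

```python
import pandas as pd

# Illustrative penalties: heavier marks for worse failure modes.
WEIGHTS = {"num_cu": 5, "num_ws": 10, "num_wu": 15, "num_we": 20}

def weighted_ranking(best: pd.DataFrame) -> pd.DataFrame:
    """Rank networks by num_cs minus weighted penalties for every other bucket."""
    penalty = sum(w * best[col] for col, w in WEIGHTS.items())
    best = best.assign(**{"cs - weighted": best["num_cs"] - penalty})
    return best.sort_values("cs - weighted", ascending=False)
```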

Now let’s see if the top 10 shift:

| name | accuracy_rank | cs - weighted | num_cs | num_cu | num_ws | num_wu | num_we |
|---|---|---|---|---|---|---|---|
| ResNetStyle_58 | 4 | 43539 | 49306 | 291 | 192 | 29 | 182 |
| ResNetStyle_85 | 1 | 43338 | 49237 | 399 | 142 | 25 | 197 |
| ResNetStyle_56 | 3 | 43277 | 49260 | 339 | 180 | 21 | 200 |
| ResNetStyle_71 | 2 | 43265 | 49258 | 352 | 184 | 25 | 181 |
| ResNetStyle_81 | 6 | 42742 | 49211 | 361 | 214 | 23 | 191 |
| Conv2dReLU_14 | 8 | 42546 | 49181 | 377 | 202 | 29 | 211 |
| ResNetStyle_89 | 7 | 42228 | 49127 | 442 | 194 | 23 | 214 |
| ResNetStyle_47 | 11 | 42110 | 49128 | 413 | 209 | 32 | 218 |
| ResNetStyle_69 | 5 | 41998 | 49075 | 507 | 167 | 23 | 228 |
| ResNetStyle_92 | 20 | 41917 | 49090 | 436 | 182 | 37 | 255 |

Indeed they did: the overall rankings shifted, and a network from the original top 10 (ResNetStyle_69) sneaked back in under the new scoring scale. Of course, the weights I picked were arbitrary, but the point is that the incorrect/unsure predictions also need to be considered when trying to choose the best network.

Closing Notes

Regularized Adaptive Prediction Sets (RAPS) provide additional context when I need to choose the best network. Earlier, I could examine networks in terms of accuracy/weights/memory use/training time. Now, with only surface-level information about the prediction sets (just their sizes), I can break a raw accuracy score down into five components, and choose the best network for my use case by assigning weights to the types of errors I don’t want. This can be useful in deployment: I am fine with networks on embedded devices having confidence issues about their predictions (especially the incorrect ones), but complex networks in the cloud should provide more confidence, because they are usually the last resort. I am not certain (!) what should be done when the initial prediction is unsure and the backup gives a different answer (the num_wu case), because that seems ripe for confusion.

There are other interesting angles to examine. I want to see why Conv2dSELU_6 has so much confidence, and what that means for the individual predictions it makes. The RAPS algorithm has two hyperparameters, $k$ and $\lambda$, in addition to $\alpha$, and provides a wrapper to find the optimal $k$ and $\lambda$ for a given network during calibration. I used the same $k$ and $\lambda$ values throughout, which leaves me wondering how much the above tables would change with the optimal values. I also haven’t used the calibrated scores or the generalized quantile values anywhere in this analysis. That’s probably worth a separate round on AWS and another post.