Deep learning for tabular data 2 - Debunking the myth of the black box
Not knowing how to apply deep neural networks (DNNs) to the sort of tabular data commonly found in industry is a common reason for not adopting DNNs in production, and it was the topic of my last post. Another common reason for hesitating to use DNNs is what I’ll call the “myth of the black box”: the idea that, unlike simpler models such as linear regression or decision trees, which come with “interpretable” linear coefficients or feature importances, DNNs are hard to understand.
There is a certain sense in which this sentiment is true, and any inherent property of a model that decreases our confidence in it carries important implications, especially when the model’s predictions come with high stakes, such as decisions about a patient’s health care, whether someone is denied a home mortgage, or large monetary investments. Such cases come with a large responsibility.
With that context in mind, this post argues that the idea of DNNs being a complete black box is false, despite their being far more computationally complex than models such as linear regression. To debunk this myth we’ll look at a variety of methods for inspecting a trained DNN, with concrete examples from a model trained on a non-trivial data set. Lastly, we’ll consider how these methods compare with attempting to interpret linear coefficients and feature importances.
Preamble
For this post, MLB’s Statcast data was used to train several models to predict whether a given plate appearance resulted in a walk or hit. The data set contains every pitch from every plate appearance in the 2018 season, minus a few PAs where something unusual happened, such as the batter or pitcher being injured mid-PA or the PA ending with a runner caught stealing.
To keep this post readable, it includes only a few illustrative code sketches rather than the full modeling code. If you are curious about the data or how the models were trained, the notebook can be found on my GitHub account; just keep in mind that the models were trained for illustrative purposes for this post.
Lastly, keep in mind that the term “deep neural network (DNN)” is used rather loosely in this post to mean any neural network, especially one with many hidden layers or one that uses architectures typically applied to artificial intelligence tasks (computer vision, natural language processing, etc.).
Using embeddings
Embeddings are a great way to represent categorical variables in tabular data, both for model performance and for introducing “transparency” into what a model is learning.
Embeddings are probably best known in NLP in the context of word embeddings. A classic example of evaluating the quality of a learned word embedding was introduced with Word2Vec. Experimental examples from the paper include “$vector(“King”) - vector(“Man”) + vector(“Woman”)$ results in a vector that is closest to the vector representation of the word Queen” which is amazing.
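For the curious, here is roughly what that analogy test looks like in code, using gensim as an example library and a hypothetical path to pretrained vectors (this snippet is illustrative and not part of the notebook):

```python
# Illustrative only: the classic Word2Vec analogy test with gensim,
# assuming pretrained vectors live at the (hypothetical) path below.
from gensim.models import KeyedVectors

vectors = KeyedVectors.load_word2vec_format("pretrained_word2vec.bin", binary=True)

# vector("King") - vector("Man") + vector("Woman") should land near "Queen"
print(vectors.most_similar(positive=["king", "woman"], negative=["man"], topn=1))
```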
Let’s take a look at the benefit of using embeddings in the context of our problem.
An example
The model
The following results are generated from a recurrent neural network (specifically an LSTM) trained on the MLB Statcast data to predict if a plate appearance (PA) resulted in the batter reaching base safely (a walk or hit).
A given PA is represented as a sequence of vectors, one for each pitch of the PA. Each vector is made up of around 30 features, such as who was batting, who was pitching, the pitch speed, location, and movement, the count, etc. Batch normalization is applied to the raw numeric features, and all categorical features are represented as 1-dimensional embeddings, except for the batter and pitcher, which are encoded as 2-dimensional embeddings. As mentioned above, I’m glossing over the specific details of the model and data, but they can be found on GitHub (linked above).
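To make that description concrete, here is a rough Keras-style sketch of this kind of architecture. It is not the exact model from the notebook; the shapes, vocabulary sizes, and layer names are assumptions made for illustration.

```python
# Rough sketch of the architecture described above (illustrative, not the
# notebook's exact model). Other categorical features and their 1-D
# embeddings are omitted for brevity.
import tensorflow as tf
from tensorflow.keras import layers

MAX_PITCHES = 20                      # assumed max pitches per plate appearance
N_NUMERIC = 30                        # assumed count of numeric per-pitch features
N_BATTERS, N_PITCHERS = 1000, 800     # assumed vocabulary sizes

numeric_in = layers.Input(shape=(MAX_PITCHES, N_NUMERIC), name="numeric")
batter_in = layers.Input(shape=(MAX_PITCHES,), dtype="int32", name="batter")
pitcher_in = layers.Input(shape=(MAX_PITCHES,), dtype="int32", name="pitcher")

numeric = layers.BatchNormalization()(numeric_in)
batter = layers.Embedding(N_BATTERS, 2, name="batter_embedding")(batter_in)
pitcher = layers.Embedding(N_PITCHERS, 2, name="pitcher_embedding")(pitcher_in)

x = layers.Concatenate()([numeric, batter, pitcher])   # per-pitch feature vectors
x = layers.LSTM(64)(x)                                 # summarize the PA
out = layers.Dense(1, activation="sigmoid")(x)         # P(walk or hit)

model = tf.keras.Model([numeric_in, batter_in, pitcher_in], out)
model.compile(optimizer="adam", loss="binary_crossentropy")
```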
Pitcher embeddings
So once the model is trained how do we go about inspecting the embeddings, say specifically the pitcher embeddings?
It’s unlikely that we could inspect these embeddings using any sort of vector arithmetic like the Word2Vec examples since pitchers don’t have “syntactic and semantic” relationships like words do. However, a basic understanding of how this model “works” suggests that if we take a look we should find some sort of relationship between pitchers that are embedded “close” to one another.
Since we’ve embedded the pitchers into two dimensions, a convenient way to look at all of these relationships at once is to plot the 2D embeddings. We could look at each point one by one, but an easier (and perhaps more reliable) way is to shade each point by the pitcher’s WHIP, the well-known baseball metric for how many walks plus hits a pitcher gives up per inning on average (literally, how many positive cases of our target variable are given up per inning), and look for general trends in the embeddings. With our trained model this plot looks as follows.
[Interactive plot: 2D pitcher embeddings, shaded by 2018 WHIP]
Remarkably, we see almost exactly what our intuition tells us we should find: a general trend in the embeddings, specifically a gradient from top-left to bottom-right, with the pitchers with the highest WHIPs in the top-left corner and pitchers with some of the lowest in the bottom-right. You can mouse over the points to see which pitcher is embedded where, but be careful trying to reach any conclusions about exact positions from this plot alone.
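For reference, a minimal sketch of how a plot like this might be produced, assuming the Keras-style sketch above (with its named “pitcher_embedding” layer) and a hypothetical `whip_by_id` dict mapping each pitcher’s integer id to their 2018 WHIP:

```python
# Minimal sketch: pull the learned 2D pitcher embeddings out of the model
# and shade each point by WHIP. `model` and `whip_by_id` are assumptions
# carried over from the sketch above, not the notebook's actual code.
import matplotlib.pyplot as plt
import numpy as np

embeddings = model.get_layer("pitcher_embedding").get_weights()[0]  # (N_PITCHERS, 2)

ids = np.array(list(whip_by_id.keys()))
whips = np.array([whip_by_id[i] for i in ids])

plt.scatter(embeddings[ids, 0], embeddings[ids, 1], c=whips, cmap="viridis", s=15)
plt.colorbar(label="2018 WHIP")
plt.title("2D pitcher embeddings shaded by WHIP")
plt.show()
```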
Batter embeddings
Let’s take a similar approach to inspecting the learned batter embeddings. Since we are predicting whether a plate appearance results in a walk or hit, a convenient metric to shade the points by this time is OPS, a weighted average of sorts that favors batters who reach base often, especially with extra-base hits.
[Interactive plot: 2D batter embeddings, shaded by 2018 OPS]
As with the pitcher embeddings we see a general trend, coincidentally from top-left to bottom-right again. Once again, although it’s difficult to draw conclusions about the location of each individual batter embedding, overall this plot instills confidence that the model is learning meaningful trends in the data and not random noise.
Using Attention
Another great way to introduce visibility into a model is attention. Like embeddings, attention has been incredibly successful in the domain of natural language processing.
Attention is a subject worthy of its own post (or posts), and there is no shortage of great posts covering this topic. For this reason we will omit any attempts to describe what attention is or how it works in any detail here.
For the purpose of this post it suffices to know that attention is a technique for sequence-based tasks (this talk gives a great overview of sequence-to-sequence learning). In most cases the sequences are sentences, i.e., sequences of words or word vectors, and an “attention mechanism” allows the model to learn which words are “most important”. The attention mechanism does this by assigning a weight (a real-valued scalar, usually the output of a softmax) to each vector in the sequence.
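As a concrete (and heavily simplified) illustration, here is a sketch of an attention layer that produces one softmax weight per pitch over the LSTM’s per-timestep outputs. It ignores padding/masking and is not the exact mechanism used in the notebook.

```python
# Simplified attention over per-timestep LSTM outputs: one softmax weight
# per pitch, plus a weighted-sum "context" vector. Masking of padded
# pitches is omitted for brevity.
import tensorflow as tf
from tensorflow.keras import layers


class PitchAttention(layers.Layer):
    def build(self, input_shape):
        # A single learned scoring vector applied to each hidden state.
        self.w = self.add_weight(
            name="score", shape=(input_shape[-1], 1), initializer="glorot_uniform"
        )

    def call(self, hidden_states):                        # (batch, pitches, units)
        scores = tf.matmul(hidden_states, self.w)         # (batch, pitches, 1)
        weights = tf.nn.softmax(scores, axis=1)           # attention per pitch
        context = tf.reduce_sum(weights * hidden_states, axis=1)  # (batch, units)
        return context, tf.squeeze(weights, axis=-1)


# Usage with the earlier sketch: replace `layers.LSTM(64)(x)` with
#   lstm_out = layers.LSTM(64, return_sequences=True)(x)
#   context, attention_weights = PitchAttention()(lstm_out)
```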
This means that attention will only be useful for certain kinds of problems, unlike embeddings which can be applied to any tabular data set with categorical variables.
A few examples
Let’s take a look at an example: predicting a walk or hit from the sequence of pitches thrown in a given plate appearance. The results below are from the same LSTM described above, with the attention mechanism applied to the LSTM’s output.
Segura walks and Mookie homers
Our first example is Jean Segura walking. The model predicted Segura getting on base with probability 73%. The attention assigned to each pitch in this at-bat is included in the table below. Notice the model pays the most attention to the final two pitches, both of which came on a full count, with the most attention assigned to the final pitch. This is great since it matches up with how we would view this PA intuitively.
|   | des | attention | pitch_name | zone | balls | strikes | total_pitch_number | fld_score | bat_score |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NaN | 0.000008 | Split Finger | 7.0 | 0.0 | 0.0 | 65 | 4.0 | 2.0 |
| 1 | NaN | 0.000004 | Knuckle Curve | 14.0 | 0.0 | 1.0 | 66 | 4.0 | 2.0 |
| 2 | NaN | 0.000011 | 4-Seam Fastball | 2.0 | 1.0 | 1.0 | 67 | 4.0 | 2.0 |
| 3 | NaN | 0.000017 | 4-Seam Fastball | 12.0 | 1.0 | 2.0 | 68 | 4.0 | 2.0 |
| 4 | NaN | 0.000235 | Split Finger | 14.0 | 2.0 | 2.0 | 69 | 4.0 | 2.0 |
| 5 | NaN | 0.073793 | 4-Seam Fastball | 6.0 | 3.0 | 2.0 | 70 | 4.0 | 2.0 |
| 6 | Jean Segura walks. | 0.925931 | 4-Seam Fastball | 14.0 | 3.0 | 2.0 | 71 | 4.0 | 2.0 |
Similarly, when Mookie Betts hits a home run in this example, we find the model attends most closely to fastballs thrown down the middle of the plate (which are arguably easier pitches to hit).
|   | des | attention | pitch_name | zone | balls | strikes | total_pitch_number | fld_score | bat_score |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NaN | 0.245446 | 2-Seam Fastball | 14.0 | 0.0 | 0.0 | 1 | 0.0 | 0.0 |
| 1 | NaN | 0.293591 | 2-Seam Fastball | 5.0 | 1.0 | 0.0 | 2 | 0.0 | 0.0 |
| 2 | Mookie Betts homers (22) on a fly ball to left... | 0.460964 | 2-Seam Fastball | 2.0 | 1.0 | 1.0 | 3 | 0.0 | 0.0 |
J. T. Realmuto grounds out
Let’s look at one last example of how our model applies attention.
Similar to Mookie’s plate appearance, J. T. Realmuto is predicted to get on base with probability 43% as the model attends to a 2-seam fastball thrown down the middle. Even though the model didn’t predict that J. T. would get on base with an extremely high probability, it still suggests that Realmuto was 55% more likely to get on base in this plate appearance than he would be in an average at-bat (considering his batting average was .277 last year).
This makes the fact that J. T. grounded out instead of reaching base more interesting. Perhaps the model “got this one wrong”?
However, we can see the model attending heavily to the final pitch, a 2-seam fastball thrown right down the middle of the plate in a favorable count, which, if we are honest with ourselves, looks like a good opportunity for J. T. to get on base. In that light the elevated prediction seems reasonable even though the outcome was an out.
|   | des | attention | pitch_name | zone | balls | strikes | total_pitch_number | fld_score | bat_score |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NaN | 0.044510 | Cutter | 3.0 | 0.0 | 0.0 | 1 | 5.0 | 4.0 |
| 1 | NaN | 0.027510 | 2-Seam Fastball | 14.0 | 0.0 | 1.0 | 2 | 5.0 | 4.0 |
| 2 | NaN | 0.056896 | Cutter | 14.0 | 1.0 | 1.0 | 3 | 5.0 | 4.0 |
| 3 | J. T. Realmuto grounds out, shortstop Asdru... | 0.871084 | 2-Seam Fastball | 5.0 | 2.0 | 1.0 | 4 | 5.0 | 4.0 |
Other ideas
Let’s briefly cover a few more ideas that are likely less well known.
Probing the model
“Probing the model” is a good way to inspect what a model is learning about numeric features. By probing the model, we mean fixing all input features but one and generating a batch of predictions by varying that one feature’s value.
For example, applying this technique to Segura’s walk mentioned above, we can generate predictions by tweaking the feature representing the total number of pitches the pitcher had thrown at the beginning of the at-bat from 1 to 90.
Conventional wisdom, that the probability of Segura walking increases as the pitcher has thrown more pitches, is realized in the plot below.
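A minimal sketch of what such a probe might look like, assuming a pandas DataFrame holding the PA’s pitch rows and a hypothetical `predict_on_base_prob` helper that scores one plate appearance (neither is the notebook’s actual API):

```python
# Illustrative probe: hold every feature fixed except one and sweep it,
# recording the predicted probability at each value. `predict_on_base_prob`
# is a hypothetical helper returning P(walk or hit) for one PA.
import pandas as pd


def probe(pa: pd.DataFrame, feature: str, values) -> pd.Series:
    probs = []
    for v in values:
        modified = pa.copy()
        modified[feature] = v               # everything else stays fixed
        probs.append(predict_on_base_prob(modified))
    return pd.Series(probs, index=list(values), name="P(walk or hit)")


# e.g. sweep the pitcher's pitch count at the start of Segura's PA:
# curve = probe(segura_pa, "total_pitch_number", range(1, 91))
# curve.plot()
```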
Counterfactuals
Another interesting idea that has come up is finding counterfactuals. It’s similar to probing the model but defined in a more rigorous, automated way, for example as an optimization problem that finds the “smallest” feature change resulting in a different outcome. This paper, though lengthy, describes the idea well.
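To give a flavor of the idea, here is a toy, brute-force sketch that searches for the smallest change to a single numeric feature that flips the prediction, assuming a scikit-learn-style classifier over a flat feature vector. Real formulations (such as the one in the paper above) treat this as a joint optimization over all features; the names here are illustrative.

```python
# Toy counterfactual search: scan candidate values for one feature, closest
# first, and stop at the first value that flips the predicted class.
import numpy as np


def single_feature_counterfactual(model, x: np.ndarray, idx: int, candidates):
    original = model.predict_proba(x.reshape(1, -1))[0, 1] >= 0.5
    for v in sorted(candidates, key=lambda c: abs(c - x[idx])):
        x_cf = x.copy()
        x_cf[idx] = v
        flipped = model.predict_proba(x_cf.reshape(1, -1))[0, 1] >= 0.5
        if flipped != original:
            return v, abs(v - x[idx])    # new value and how far we had to move
    return None                          # no single-feature change flips the call
```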
A “hybrid” model
The last approach we will mention is to use a “hybrid” model. That is, a model where a DNN “encodes” some of the features, and this encoding is then concatenated with the other features held out from the DNN and fed into a simpler model such as a linear regression or decision tree. An example of this sort of model is in the notebook linked at the top of this post.
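A rough sketch of the idea, reusing the (assumed) Keras model from earlier as a sequence encoder and feeding its output, alongside held-out tabular features, into a logistic regression. The variable names are placeholders, not the notebook’s code.

```python
# Hybrid model sketch: the DNN encodes the pitch sequence, a simple linear
# model makes the final call. `model`, the *_X arrays, `held_out_features`
# and `y` are assumed to exist from earlier (illustrative) steps.
import numpy as np
import tensorflow as tf
from sklearn.linear_model import LogisticRegression

# Reuse the trained network up to its penultimate layer as an encoder.
encoder = tf.keras.Model(model.inputs, model.layers[-2].output)

sequence_encoding = encoder.predict([numeric_X, batter_X, pitcher_X])   # (n, 64)
X_hybrid = np.hstack([sequence_encoding, held_out_features])            # add extras

hybrid_clf = LogisticRegression(max_iter=1000).fit(X_hybrid, y)
```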
Discussion
Comparison to feature coefficients and feature importances
Before reaching any conclusions about whether or not the “myth of the black box” has been debunked by this post, let’s see how the methods discussed here compare with linear regression. After all, this post is mainly geared towards data scientists who are already using simpler models but are hesitant to take the leap to methods like DNNs.
With a linear model, and its variants, we get a set of weights $\beta_{i}, \ i>0$ and an intercept $\beta_{0}$ and our prediction is calculated as
\[y = \sum_{i=1}^{n}{\beta_{i}x_{i}} + \beta_{0}\]

All of this simply means that we can make certain inferences about features based on their coefficients. For example, it is common to infer that $|\beta_{i}| > |\beta_{j}|$ implies $x_{i}$ is in some sense more important than $x_{j}$ (given certain assumptions, e.g. feature scaling).
For example, using a logistic regression model trained on the same Statcast data to rank pitchers by coefficient, we find a list topped by the 2018 Cy Young winners Blake Snell and Jacob deGrom, along with Cy Young candidate (and Philadelphia Phillie!) Aaron Nola, 2018 World Series pitcher David Price, and other elite pitchers of 2018.
|   | player_name | coef |
|---|---|---|
| 1 | Aaron Nola | -0.094848 |
| 2 | Trevor Bauer | -0.073879 |
| 3 | German Marquez | -0.073038 |
| 4 | Tanner Roark | -0.069197 |
| 5 | Blake Snell | -0.068322 |
| 6 | J.A. Happ | -0.065919 |
| 7 | Kyle Hendricks | -0.061301 |
| 8 | David Price | -0.060981 |
| 9 | Jacob deGrom | -0.060951 |
| 10 | Noah Syndergaard | -0.055206 |
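For reference, a ranking like this might be pulled out of a fitted scikit-learn model roughly as follows, assuming the pitcher ids were one-hot encoded and that a matching list of feature names was kept (both are assumptions for illustration, not the notebook’s code):

```python
# Illustrative only: rank pitchers by their fitted logistic regression
# coefficients. `logreg` and `feature_names` are assumed to exist, and the
# "pitcher_" column prefix is a made-up naming convention.
import pandas as pd

coefs = pd.Series(logreg.coef_[0], index=feature_names)
pitcher_coefs = coefs[coefs.index.str.startswith("pitcher_")]

# The most negative coefficients belong to pitchers who most reduce
# the predicted probability of a walk or hit.
print(pitcher_coefs.sort_values().head(10))
```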
As with the DNNs, inspecting the logistic regression in this way is a great sanity check.
At this point some might argue that feature coefficients still hold an advantage over the methods we looked at above because they are “interpretable”. But is this actually true? For example, what is the interpretation of Nola’s -0.094848 coefficient? If this were linear regression we could attempt to interpret it in terms of its linear relationship with the target variable, but the logistic function does us no such favors. Or, why aren’t the rankings above led by Blake Snell, who objectively put up better stats than some of the pitchers ranked ahead of him? Interpreting the coefficients on features such as the number of balls and strikes in the count is even murkier. As a sanity check, inspecting linear coefficients can be indispensably useful, but in the context of any machine learning task with more than a modest number of features (such as ours, which is still only a small feature set in reality) such interpretations fail to hold up to scrutiny.
Lastly, it’s also worth mentioning that the “complexity” of the LSTM actually gives us more methods of inspecting what the model is learning than linear coefficients do.
The myth of the black box
In light of the comparison above, and having looked at a number of concrete examples in which we were able to inspect an LSTM trained on a non-trivial data set, it’s safe to conclude that neural networks are not a complete “black box.”