Neural Additive Models: Interpretable Machine Learning with Neural Nets

2022. 3. 1. 23:26Neural Networks/Interpretable AI

 

 

논문에 대한 오역, 의역등이 다수 포함되어 있습니다. 댓글로 많은 의견 부탁드립니다


Author: Rishabh Agarwal, Levi Melnick, Nicholas Frosst, Xuezhou Zhang, Ben Lengerich, Rich Caruana, Geoffrey Hinton

 

Neural Additive Models, 줄여서 NAM은 심층신경망의 해석가능성을 제시한 논문입니다. 

NAM을 읽고 난 후, 이게 될까 싶었는데 실제로 해보니 어느정도 괜찮은 성능을 보였습니다. 

Official code는 TF1.x로 구현되었으며, 제 깃허브를 참조하시면 TF2.x도 사용할 수 있습니다. 

 

https://github.com/merchen911/NAM

 

GitHub - merchen911/NAM: Neural Additive Model with tensorflow 2.x

Neural Additive Model with tensorflow 2.x. Contribute to merchen911/NAM development by creating an account on GitHub.

github.com


1. Introduction

  인공신경망을 활용한 모델들은 Computer vision (CV) 이나 language process (NLP) 에서 엄청난 성능을 보이고 있는 반면, 그 결과에 대한 해석 가능성은 여전히 분분한 상황입니다. CV 모델들은 꽤 많은 부분들이 convolution network에 기반한 만큼 convolutional activation map을 개선한 다양한 방법들을 활용할 수 있습니다. 한편, 최근 CV trend는 attention 기반 방법들도 좋은 성능을 보이고 있습니다. NLP를 위한 모델로도 transformer를 활용한 방법부터, electra까지 다양한 학습방법이 제안되었습니다. 하지만 이 방법들은 표 형식의 데이터를 위해 고안된 방법들이 아닐뿐 더러, 결과 해석에 대한 분분한 의견차가 발생합니다. 

  해석을 위한 논문으로써, DeepLift, Layer-wise relevance propagation (LRP), LIME과 더불어 최근 주목받는 SHAP이 있습니다. 이 방법들은 임의의 모델 $f_x$를 학습한 이후 post-hoc 방식으로 새로운 모델 $g_x$를 활용하여 $f_x$를 모방합니다. 그리고, $g_x$에 기반하여 $f_x$를 해석합니다. NAM이 고안된 목적은 인공신경망의 예측 결과에 대한 해석 가능성 도입입니다. NAM은 general additive models (GAM)에 기반을 두었으며, 이것은 LIME이나 SHAP과 일부 배경을 공유합니다. GAM은 아주 단순하며, 때문에 강력합니다. 

$$ g(\mathbb{E}[y])=\beta + f_1(x_1) + f_2(x_2) + \cdots + f_K(x_K) $$

위 식은 일반적인 GAM 식을 보여줍니다. 그리고 위 식은 간단한 선형 조합과 같다는 것을 알 수 있습니다. 총 $K$개의 변수가 존재할 때, $K$개를 각각 전담하는 함수 $f_i$를 활용하여 $x_i$로부터 특징을 추출해내고 bias역할을 하는 $\beta$와 더하는 방식으로 결과를 예측합니다. NAM은 이 식에서 함수 $f_i$를 간단한 모듈로 구성했습니다. 그리고 학습가능한 모듈들을 활용하여 변수 $x_i$로부터 가중치를 얻은 후, 가중치의 합으로서 결과를 예측합니다. NAM의 주요 기여는 다음과 같습니다. 

  • NAM은 딥러닝에 기반한 함수를 제시하며, 이것은 기존 트리기반의 GAM 모델들보다 훨씬 큰 규모의 커뮤니티입니다. 
  • NAM은 다른 딥러닝 모델들과 결합될 수 있으며, 결합된 방법에 해석력을 추가할 수 있다는 장점이 있습니다. 
  • NAM에 의해 학습된 $f_i$들은 그 모든것을 설명할 순 없지만, 적어도 NAM이 계산한 결과에 대해 정확한 기여도 계산이 가능합니다. 
  • 트리 기반 GAM은 여전히 GPU/TPU 사용에 있어서 자유로운 상황이 아니지만, NAM은 오픈 라이브러리를 활용하여 GPU/TPU를 충분히 사용할 수 있습니다. 

2. Modeling jagged function

Linear (left) & ExU (right)를 활용하여 간단한 이진 분류를 시행한 결과

  NAM은 Linear-ReLU가 3번 가량 반복되는 구조로 구성됩니다. 저자들은 실험 과정 중 한가지 문제점을 발견합니다. 위 그림의 왼쪽 그래프를 따르면, 독립 및 종속 변수간의 관계가 급격하게 변하는 경우 기존 linear combination은 그 경향성을 잡아내지 못합니다. Linear combination은 경향의 중간 그 어딘가에 위치했으며, 이것은 NAM의 목표를 달성하는데 부족하다고 말합니다. 저자들은 문제 해결을 위해 exp-centered hidden unit (ExU)를 제안합니다. ExU의 연산은 아래와 같은 수식을 따릅니다. 아래 식의 $w$와 $\beta$는 learnable weight와 bias를 의미합니다. 

$$ Linear(x) = w \cdot x^\top+ \beta $$

$$ ExU(x) = e^w \cdot ( x - \beta )^\top $$

식에 따르면 ExU는 linear보다 $w$에 더 민감한 변화를 보입니다. 그리고 실제 실험에서도 우측과 같이 linear combination보다 급격한 변화도 따라갈 수 있는 경향을 보여줍니다.

3. Neural Additive Models

Multitask를 위한 NAM 구조
NAM을 구성하는 모듈, feature net들의 구조

  NAM의 기본 모듈들은 ExU를 활용하여 위와 같은 구조를 생성합니다. 각각의 작은 모듈들은 feature net이라고 불리며 특정 변수를 전담하게 됩니다. 각 feature net들의 ExU의 $w$는 특별히 1000개의 node로 구성하며, 이후 linear는 100개 이하의 node들을 사용합니다. 각 모듈들 $f_i$는 $i$번째 변수만 전담하며, scalar를 입력받고, scalar를 출력합니다. 그리고 GAM의 수식을 따라 $K$가지의 변수들을 합하고, bias를 더하여 예측을 수행합니다. 

  한편, 저자들은 ExU를 활용한 NAM의 학습과정에서 특정 변수의 의존도가 매우 커지는 현상을 확인했습니다. 때문에, 효과적인 NAM의 학습을 위해 regularization을 위한 4가지 방법이 도입됩니다. (1) dropout(2) L2 weight decay를 활용하여 특정 node의 가중치가 커지는 현상을 방지합니다. (3) Output penalty는 각 feature net의 output들에 대해서만 L2 norm을 적용하여, 특정 feature net의 output이 과도하게 커지는 현상을 방지하는 것이며, (4) Feature dropout도 비슷한 역할로써 학습 과정 중 특정 feature net의 output을 0으로 바꾸는 역할을 합니다. NAM의 학습을 위한 손실함수는 목적에 맞게 cross-entropy, mean squared error 등을 다양하게 사용할 수 있으며, 4가지 제약조건이 걸린 식은 아래와 같이 쓸 수 있습니다. 

$$ \mathcal{L}(\theta) = \mathbb{E}_{x,y\sim D} [ l(x,y;\theta) + \lambda_1\eta(x;\theta) ] + \lambda_2\gamma(\theta) $$

$$ \eta(x;\theta) = \frac{1}{K} \sum_x \sum_k (f_k^\theta(x_k))^2 $$

전체 손실 함수 식 $\mathcal{L}$ 중 $l$은 학습 목표에 따른 손실 함수이며, $\eta$는 output penalty를 가리키고, $\gamma$는 weight decay를 의미합니다. 

3. Conclusion

  앞서 말씀드린대로, 논문에 수록된 결과들도 좋은편인지라 논문을 읽으면서 이게 될까라는 의문이 강하게 들었었습니다. pytorch로 된 코드들로 파일럿 실험을 했을 때 실제로도 생각보다 결과가 좋아, TF2로 새로 구성했으며, 제 깃헙에 올라와있는 주피터 노트북 파일로 NAM을 활용해보실 수 있습니다.