Accurate Intelligible Models with Pairwise Interactions (2013)

2022. 12. 23. 17:04Neural Networks/Interpretable AI

논문에 대한 오역, 의역등이 다수 포함되어 있습니다. 댓글로 많은 의견 부탁드립니다


Author: Yin Lou, Rich Caruana, Johannes Gehrke, Giles Hooker

최근 $\mathrm{GA^2M}$이라는 흥미로운 키워드를 발견했습니다.

2013년 KDD 19th에서 발표된 논문이며 최근 eXAI에서 종종 활용되고 있습니다.


1. Introduction

기계학습 분야의 모델들은 다양한 형태의 데이터셋에서 좋은 예측 성능을 보여줌으로써 주목받고 있습니다. 하지만, 꽤 많은 모델의 경우, 주어진 임의의 데이터 샘플 $x\in \mathbb{R}^{n}$의 각 변수(이미지의 경우 픽셀)로부터 목표 $y$를 예측하기 위해 얼마나 많은 기여도를 전달받았는지 직접적으로 알 수 없습니다. Generalized Additive Model (GAM)은 이러한 문제를 해결하고자 사용될 수 있는 방법 중 하나이며, 아래와 같은 식으로 $y$를 예측합니다.

$$ g(y) = \sum f_i(x_i) + \beta$$

함수 $g(\cdot)$은 풀고자 하는 문제에 따라 출력의 형태를 일치시키기 위한 함수이며, 분류 문제의 경우 sigmoid를 쓸 수 있습니다. 함수 $f_i(\cdot)$은 $i$번째 변수를 전담하여, 해당 변수로부터 예측에 의미 있는 값을 추출해 내는 역할을 합니다. 정리하면 GAM은 각 변수만 전담하는 함수들로부터 기여도를 얻고, 기여도를 합친 뒤 bias를 더하여 예측을 시도하는 방법입니다. 함수 $f_i$의 경우 단순한 선형대수연산을 포함하여 decision tree 및 support vector machine, neural network 등 다양한 모델로 대체가능합니다.

한편 GAM을 요약하는 동안 변수들을 독립적으로 처리한다고만 언급드렸습니다. 따라서, 변수들의 상관관계에 의해 발생하는 영향력은 식에서도 찾아볼 수 없으며, 무시되고 있는 상황입니다. 본 논문은 이러한 부분을 보충하기 위해 Generalized Additive Models plus Interactions ($\mathrm{GA^2M}$)을 제안합니다. $\mathrm{GA^2M}$의 식은 아래와 같이 씁니다.

$$ g(y) = \sum f_{ij}(x_i,x_j) + \sum f_i(x_i) + \beta $$

위 식은 GAM에서 항 하나가 더 늘어났다는 것을 알 수 있습니다. 추가된 함수 $f_{ij}$는 입력 변수를 하나만 받는 것이 아니라 두 개를 받게 됩니다. 때문에, 두 입력 변수로부터 기여도를 추출하는 역할을 맡습니다. 추가된 항은 interaction의 기여도를 생산하기 위한 특별한 장치입니다. 다음 장에서 interaction을 어떤 방식으로 계산하는지 살펴보겠습니다.

2. Concept of $\mathrm{GA^2M}$

$\mathrm{GA^2M}$ algorithm

먼저 기호에 대한 설명입니다. 총 $N$개의 샘플이 존재하는 데이터셋 $\mathcal{D}={(x_i, y_i)}^N$는 총 $n$개의 변수들로 이뤄진 $x_i = (x_{i1}, \cdots, x_{in})$와 $x_i$에 대응되는 $y_i$로 이뤄져 있습니다. 그리고 $x_u$는 집합 $u\subseteq{1,\cdots,n}$로부터 선택된 변수들을 가리킵니다. 그리고 $\mathrm{GA^2M}$을 위해, 변수들로부터 각각 독립적인 함수들을 통하여 얻은 기여도와 interaction으로부터 기여도를 계산할 수 있는 함수를 각각 고려해야합니다. $\mathcal{U^1, U^2}$는 변수들을 선택하는 집합을 가리키며, 각각 아래와 같이 쓸 수 있습니다.

$$\mathcal{U}^1 = \{\{i\}\}|1 \leq i \leq n\}$$

$$\mathcal{U}^2 = \{\{i,j\}|1 \leq i < j \leq n\}$$

$$\mathcal{U} = \mathcal{U}^1 \cup \mathcal{U}^2$$

$\mathrm{GA^2M}$은 두 단계에 걸쳐 학습이 진행됩니다. 먼저, $\mathcal{U}^1$에 속하는 변수 집합으로부터 GAM으로 모델을 학습시키고, 실제 예측과 모델의 예측 결과에 대한 잔차를 계산합니다. 이어서 $\mathcal{U}^2$의 interaction 중 잔차를 가장 잘 예측하는 함수를 찾아 $\mathrm{GA^2M}$을 완성합니다. 때문에, 알고리즘은 매우 복잡한 형태로 구성됩니다. 알고리즘에 따르면, 매 반복문마다 해당 반복문에 한하여 가장 좋은 interaction 집합 $\mathcal{U}^2$는 $\mathcal{Z}$로 편입되고, 편입된 $\mathcal{Z}$를 활용하여 다음 반복문에서 다시 한번 잔차에 대한 평가를 반복해야 합니다. 저자들은 연산 비용을 저렴하게 만들기 위한 간단한 $\mathrm{GA^2M}$ 탐색 방법을 제안합니다.

3. Fast interaction detection

The simplest method to search the contribution of interactions

$\mathrm{GA^2M}$는 interaction들로부터 얻어낸 기여도를 활용하여 $\mathrm{GAM}$이 예측하지 못한 잔차를 예측해야합니다. 임의의 변수 $x_i$에 존재하는 모든 고윳값들이 발생시키는 기여도를 각각 계산하는 것은 $\mathrm{GAM}$에서도 종종 활용되던 방법입니다. 저자들은 임의의 두 변수 $x_i$와 $x_j$를 선택했을 때, 두 변수를 활용한 간단한 tree를 만드는 것으로 기여도를 계산하는 방법을 제안합니다. 각 변수의 고윳값들을 정렬 시킨 후, 각 변수마다 한번의 분기를 만든다면, 총 4개의 leaf를 갖는 tree를 만들 수 있습니다. 그리고 각 변수 별 분기 기준인 cut $c_i$와 $c_j$로부터 추가적으로 8가지 정보를 얻을 수 있습니다. $x_i$를 $c_i$라는 분기로 나눈 경우를 가정했을 때, $c_i$보다 작은 고윳 값에 속하는 target $y$의 총합을 알 수 있습니다. 그리고 전체 데이터 셋 크기 $N$에서 $c_i$보다 작은 고윳 값이 차지하는 빈도 비율도 알 수 있습니다. 또한 $c_i$보다 큰 고윳값에 속하는 경우에 대해서도 마찬가지로 두 개의 값도 알 수 있습니다. 이것을 확장하여 $x_i, x_j$에 대해 동시에 고려하면 위쪽 그림의 우측 tree와 같이 표현할 수 있습니다. 그리고 이렇게 만들어진 tree는 interaction을 활용한 가장 간단한 함수 $f_{ij}$입니다.

분기 별 고윳 값 중 작은 쪽부터 큰 쪽으로 탐색해가며 계산하는 tree의 leaf 별 두 가지 값은 히스토그램으로 다룰 수 있습니다. $dom(x_i)=\{v_i^1,\cdots, v_i^{d_i}\}$는 변수 $x_i$의 가능한 모든 값을 정렬한 집합이고, $d_i = |dom(x_i)|$일 때, 히스토그램 $H_i^t(v)$와 $H_i^w(v)$는 각각 $x_i=v$일 때의 target의 총합과 빈도수의 비율입니다. 이러한 점을 활용하여 분기 별 작은 값과 큰 값 별 기여도를 비교하는 것이 수월해집니다. 누적 분포(cumulative histogram)를 고려하면, 분기에 대한 평가를 반복하지 않고 선형 시간 내로 수행할 수 있습니다. target 합과 빈도 비율에 관하여, cut보다 작은 쪽의 누적 분포 $CH_i^{\{t,w\}}(v)$는 다음 식과 같습니다. 그리고 누적 분포의 전체 합은 항상 일정하기 때문에 분기보다 큰 쪽의 합은 전체 합에서 분기보다 작은 쪽의 누적 합을 제하는 것으로 쉽게 계산가능합니다.

$$ CH_i^{\{t,w\}}(v) = \sum_{u \leq v} H_i^{\{t,w\}}(u) $$

$$ \bar{CH}_i^{\{t,w\}})(v) = \sum_{u>v} H_i{\{t,w\}}(u) = CH_i^{\{t,w\}}(v_i^{d_i}) - CH_i^{\{t,w\}}(v) $$

Algorithm to compute the simplest contribution tree of $\mathrm{GA^2M}$ under cost $\mathcal{O}(d_i \times d_j)$

위 알고리즘은 leaf 노드를 계산하는 과정을 보여줍니다. $x_j$에 존재하는 가장 작은 값을 고정한 후, 모든 $v_i$에 대해 계산하고, 이어서 $j$도 점차 증가시켜가면서 모든 경우에 대해 계산하는 과정입니다. 각 계산된 값들은 lookuptable $L$에 어떤 고윳 값에 대한 계산 결과로써 기록됩니다. 마지막으로 제곱 잔차 합(residual sum of square, RSS)을 활용하여, tree로 나타낸 $\mathrm{GA^2M}$의 기여도와 target 간의 차를 계산하여 효율적인 interaction들을 찾습니다.

위 그림은 분기로 나뉜 각 영역에 대해, 계산하는 실제 interaction function을 보여줍니다. 두 변수 조합에 의한 interaction $f_{ij}$으로부터 얻는 기여도 $T_{ij}$는 $[a,b,c,d]$의 평균입니다. 그리고 $a,b,c,d$는 각 분기로 나뉜 변수 별 영역의 기여도입니다. Decision tree와 마찬가지로 각 node 분기에 대한 평가로써, target에 대한 값들뿐만 아니라 target이 차지하는 비율을 동시에 고려하는 것을 확인할 수 있습니다. 최적의 기여도 $T_{ij}$를 계산하기 위한 RSS는 아래와 같은 식으로 계산합니다.
\begin{align}
\ RSS(y, T_{ij}) &= \sum^N_{k=1} \left(y_k-T_{ij}(x_k) \right)^2 \\
&= \left( \sum_{k=1}^N y_k^2 - 2 \sum_{k=1}^N y_k T_{ij}(x_k) + \sum_{k=1}^N T_{ij}^2(x_k) \right)^2 \\
&= \left( \sum_{k=1}^N y_k^2 - 2 \sum_{r\in{a,b,c,d}} T_{ij}.r L^t.r + \sum_{r\in{a,b,c,d}} (T_{ij}.r)^2 L^w.r \right)
\end{align}
$y_k$의 모든 $k$에 대한 총합은 lookuptable $L^t$의 합과 같습니다. 따라서, Summation 기호의 합연산을 $k\rightarrow{1,\cdots,N}$에서 table $L.r$을 활용한 다음과 같은 식으로 바꿔 쓸 수 있습니다. 마지막항인 모든 $x_k$로부터 계산한 값 $T_{ij}(x_k)^2$의 총합에 대해서는 각 영역 별 $T_{ij}$의 모든 제곱값에 각 영역의 빈도 비율을 곱하는 것으로 쉽게 대신 계산할 수 있습니다.

4. Conclusion

본 논문이 나온 이후 꽤 많은 다양한 ML 모델들이 등장했습니다. 최신 모델의 성능은 본 논문과 비교하기 어려울 정도로 앞서가기 시작했습니다. 따라서 실험부분과 본 논문의 디테일은 생략하였습니다. 한편, 최신 몇몇 연구들은 본 논문이 제안한 $\mathrm{GA^2M}$을 활용하여 설명력을 높이고, 성능 또한 SOTA를 달성할 수 있는 가능성을 보였습니다. 이후 리뷰할 논문들 중 일부는 $\mathrm{GA^2M}$이 등장하기 때문에, 다른 글로부터 들어오시는 분들께 $\mathrm{GA^2M}$을 이해하시는데 도움이 되었으면 좋겠습니다.