Loading [MathJax]/jax/output/CommonHTML/jax.js

Accurate Intelligible Models with Pairwise Interactions (2013)

2022. 12. 23. 17:04Neural Networks/Interpretable AI

논문에 대한 오역, 의역등이 다수 포함되어 있습니다. 댓글로 많은 의견 부탁드립니다


Author: Yin Lou, Rich Caruana, Johannes Gehrke, Giles Hooker

최근 GA2M이라는 흥미로운 키워드를 발견했습니다.

2013년 KDD 19th에서 발표된 논문이며 최근 eXAI에서 종종 활용되고 있습니다.


1. Introduction

기계학습 분야의 모델들은 다양한 형태의 데이터셋에서 좋은 예측 성능을 보여줌으로써 주목받고 있습니다. 하지만, 꽤 많은 모델의 경우, 주어진 임의의 데이터 샘플 xRn의 각 변수(이미지의 경우 픽셀)로부터 목표 y를 예측하기 위해 얼마나 많은 기여도를 전달받았는지 직접적으로 알 수 없습니다. Generalized Additive Model (GAM)은 이러한 문제를 해결하고자 사용될 수 있는 방법 중 하나이며, 아래와 같은 식으로 y를 예측합니다.

g(y)=fi(xi)+β

함수 g()은 풀고자 하는 문제에 따라 출력의 형태를 일치시키기 위한 함수이며, 분류 문제의 경우 sigmoid를 쓸 수 있습니다. 함수 fi()i번째 변수를 전담하여, 해당 변수로부터 예측에 의미 있는 값을 추출해 내는 역할을 합니다. 정리하면 GAM은 각 변수만 전담하는 함수들로부터 기여도를 얻고, 기여도를 합친 뒤 bias를 더하여 예측을 시도하는 방법입니다. 함수 fi의 경우 단순한 선형대수연산을 포함하여 decision tree 및 support vector machine, neural network 등 다양한 모델로 대체가능합니다.

한편 GAM을 요약하는 동안 변수들을 독립적으로 처리한다고만 언급드렸습니다. 따라서, 변수들의 상관관계에 의해 발생하는 영향력은 식에서도 찾아볼 수 없으며, 무시되고 있는 상황입니다. 본 논문은 이러한 부분을 보충하기 위해 Generalized Additive Models plus Interactions (GA2M)을 제안합니다. GA2M의 식은 아래와 같이 씁니다.

g(y)=fij(xi,xj)+fi(xi)+β

위 식은 GAM에서 항 하나가 더 늘어났다는 것을 알 수 있습니다. 추가된 함수 fij는 입력 변수를 하나만 받는 것이 아니라 두 개를 받게 됩니다. 때문에, 두 입력 변수로부터 기여도를 추출하는 역할을 맡습니다. 추가된 항은 interaction의 기여도를 생산하기 위한 특별한 장치입니다. 다음 장에서 interaction을 어떤 방식으로 계산하는지 살펴보겠습니다.

2. Concept of GA2M

GA2M algorithm

먼저 기호에 대한 설명입니다. 총 N개의 샘플이 존재하는 데이터셋 D=(xi,yi)N는 총 n개의 변수들로 이뤄진 xi=(xi1,,xin)xi에 대응되는 yi로 이뤄져 있습니다. 그리고 xu는 집합 u1,,n로부터 선택된 변수들을 가리킵니다. 그리고 GA2M을 위해, 변수들로부터 각각 독립적인 함수들을 통하여 얻은 기여도와 interaction으로부터 기여도를 계산할 수 있는 함수를 각각 고려해야합니다. U1,U2는 변수들을 선택하는 집합을 가리키며, 각각 아래와 같이 쓸 수 있습니다.

U1={{i}}|1in}

U2={{i,j}|1i<jn}

U=U1U2

GA2M은 두 단계에 걸쳐 학습이 진행됩니다. 먼저, U1에 속하는 변수 집합으로부터 GAM으로 모델을 학습시키고, 실제 예측과 모델의 예측 결과에 대한 잔차를 계산합니다. 이어서 U2의 interaction 중 잔차를 가장 잘 예측하는 함수를 찾아 GA2M을 완성합니다. 때문에, 알고리즘은 매우 복잡한 형태로 구성됩니다. 알고리즘에 따르면, 매 반복문마다 해당 반복문에 한하여 가장 좋은 interaction 집합 U2Z로 편입되고, 편입된 Z를 활용하여 다음 반복문에서 다시 한번 잔차에 대한 평가를 반복해야 합니다. 저자들은 연산 비용을 저렴하게 만들기 위한 간단한 GA2M 탐색 방법을 제안합니다.

3. Fast interaction detection

The simplest method to search the contribution of interactions

GA2M는 interaction들로부터 얻어낸 기여도를 활용하여 GAM이 예측하지 못한 잔차를 예측해야합니다. 임의의 변수 xi에 존재하는 모든 고윳값들이 발생시키는 기여도를 각각 계산하는 것은 GAM에서도 종종 활용되던 방법입니다. 저자들은 임의의 두 변수 xixj를 선택했을 때, 두 변수를 활용한 간단한 tree를 만드는 것으로 기여도를 계산하는 방법을 제안합니다. 각 변수의 고윳값들을 정렬 시킨 후, 각 변수마다 한번의 분기를 만든다면, 총 4개의 leaf를 갖는 tree를 만들 수 있습니다. 그리고 각 변수 별 분기 기준인 cut cicj로부터 추가적으로 8가지 정보를 얻을 수 있습니다. xici라는 분기로 나눈 경우를 가정했을 때, ci보다 작은 고윳 값에 속하는 target y의 총합을 알 수 있습니다. 그리고 전체 데이터 셋 크기 N에서 ci보다 작은 고윳 값이 차지하는 빈도 비율도 알 수 있습니다. 또한 ci보다 큰 고윳값에 속하는 경우에 대해서도 마찬가지로 두 개의 값도 알 수 있습니다. 이것을 확장하여 xi,xj에 대해 동시에 고려하면 위쪽 그림의 우측 tree와 같이 표현할 수 있습니다. 그리고 이렇게 만들어진 tree는 interaction을 활용한 가장 간단한 함수 fij입니다.

분기 별 고윳 값 중 작은 쪽부터 큰 쪽으로 탐색해가며 계산하는 tree의 leaf 별 두 가지 값은 히스토그램으로 다룰 수 있습니다. dom(xi)={v1i,,vdii}는 변수 xi의 가능한 모든 값을 정렬한 집합이고, di=|dom(xi)|일 때, 히스토그램 Hti(v)Hwi(v)는 각각 xi=v일 때의 target의 총합과 빈도수의 비율입니다. 이러한 점을 활용하여 분기 별 작은 값과 큰 값 별 기여도를 비교하는 것이 수월해집니다. 누적 분포(cumulative histogram)를 고려하면, 분기에 대한 평가를 반복하지 않고 선형 시간 내로 수행할 수 있습니다. target 합과 빈도 비율에 관하여, cut보다 작은 쪽의 누적 분포 CH{t,w}i(v)는 다음 식과 같습니다. 그리고 누적 분포의 전체 합은 항상 일정하기 때문에 분기보다 큰 쪽의 합은 전체 합에서 분기보다 작은 쪽의 누적 합을 제하는 것으로 쉽게 계산가능합니다.

CH{t,w}i(v)=uvH{t,w}i(u)

¯CH{t,w}i)(v)=u>vHi{t,w}(u)=CH{t,w}i(vdii)CH{t,w}i(v)

Algorithm to compute the simplest contribution tree of GA2M under cost O(di×dj)

위 알고리즘은 leaf 노드를 계산하는 과정을 보여줍니다. xj에 존재하는 가장 작은 값을 고정한 후, 모든 vi에 대해 계산하고, 이어서 j도 점차 증가시켜가면서 모든 경우에 대해 계산하는 과정입니다. 각 계산된 값들은 lookuptable L에 어떤 고윳 값에 대한 계산 결과로써 기록됩니다. 마지막으로 제곱 잔차 합(residual sum of square, RSS)을 활용하여, tree로 나타낸 GA2M의 기여도와 target 간의 차를 계산하여 효율적인 interaction들을 찾습니다.

위 그림은 분기로 나뉜 각 영역에 대해, 계산하는 실제 interaction function을 보여줍니다. 두 변수 조합에 의한 interaction fij으로부터 얻는 기여도 Tij[a,b,c,d]의 평균입니다. 그리고 a,b,c,d는 각 분기로 나뉜 변수 별 영역의 기여도입니다. Decision tree와 마찬가지로 각 node 분기에 대한 평가로써, target에 대한 값들뿐만 아니라 target이 차지하는 비율을 동시에 고려하는 것을 확인할 수 있습니다. 최적의 기여도 Tij를 계산하기 위한 RSS는 아래와 같은 식으로 계산합니다.
 RSS(y,Tij)=Nk=1(ykTij(xk))2=(Nk=1y2k2Nk=1ykTij(xk)+Nk=1T2ij(xk))2=(Nk=1y2k2ra,b,c,dTij.rLt.r+ra,b,c,d(Tij.r)2Lw.r)
yk의 모든 k에 대한 총합은 lookuptable Lt의 합과 같습니다. 따라서, Summation 기호의 합연산을 k1,,N에서 table L.r을 활용한 다음과 같은 식으로 바꿔 쓸 수 있습니다. 마지막항인 모든 xk로부터 계산한 값 Tij(xk)2의 총합에 대해서는 각 영역 별 Tij의 모든 제곱값에 각 영역의 빈도 비율을 곱하는 것으로 쉽게 대신 계산할 수 있습니다.

4. Conclusion

본 논문이 나온 이후 꽤 많은 다양한 ML 모델들이 등장했습니다. 최신 모델의 성능은 본 논문과 비교하기 어려울 정도로 앞서가기 시작했습니다. 따라서 실험부분과 본 논문의 디테일은 생략하였습니다. 한편, 최신 몇몇 연구들은 본 논문이 제안한 GA2M을 활용하여 설명력을 높이고, 성능 또한 SOTA를 달성할 수 있는 가능성을 보였습니다. 이후 리뷰할 논문들 중 일부는 GA2M이 등장하기 때문에, 다른 글로부터 들어오시는 분들께 GA2M을 이해하시는데 도움이 되었으면 좋겠습니다.