Interpretable Machine Learning - newlife-js/Wiki GitHub Wiki
by Christoph Molnar (๋ฒ์ญ : TooTouch)
- ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ ๋ขฐํ์ง ๋ชปํ๋ ์ด์ : accuracy์ ๊ฐ์ ๋จ์ผ ํ๊ฐ ์งํ๋ ํ์ค๋ฌธ์ ์ ์ฌ์ฉํ๊ธฐ์ ๋ถ์์ ํ ์งํ์ด๊ธฐ ๋๋ฌธ
- ์์ธก๋ง์ผ๋ก ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ ๋ฌธ์ ์ ๋ ์ค์
Intrinsic: shallow decision tree์ ๊ฐ์ sparse ์ ํ ๋ชจ๋ธ๊ณผ ๊ฐ์ ๋จ์ํ ๊ตฌ์กฐ๋ก ์ธํด ํด์ ๊ฐ๋ฅ
Post Hoc: ๋ชจ๋ธ ํ์ต ํ ํด์ ๋ฐฉ๋ฒ ์ ์ฉ(Permutation Feature Importance)
์ ์ฒด๋ก ์ ํด์๊ฐ๋ฅ์ฑ: Hyperplane๊ณผ ๊ฐ์ ๊ฒฐ๊ณผ๋ก ํด์.. (3์ฐจ์ ์ด์์ ์ธ๊ฐ์ ์์๋ฒ์ ๋ฐ)
๋ชจ๋ ์์ค์์ ์ ์ฒด๋ก ์ ํด์๊ฐ๋ฅ์ฑ: ๋จ์ผ๊ฐ์ค์น๋ก ํด์(์ ํ๋ชจ๋ธ์ ๊ฐ์ค์น)
๋จ์ผ ์์ธก์น์ ๋ํ ์ง์ญ์ ํด์๊ฐ๋ฅ์ฑ: ํ๋์ x๋ก ์์ธก
์ดํ๋ฆฌ์ผ์ด์
์์ค: ์ ๋ฌธ์ ์ธ ์ค์ ์ฌ์ฉ์์ ์ํด ํ๊ฐ๋ฐ๋ ๊ฒ
์ธ๊ฐ ์์ค: ์ดํ๋ฆฌ์ผ์ด์
์์ค์ ๋จ์ํํ์ฌ ์ ๋ฌธ๊ฐ๊ฐ ์๋ ์ฌ์ฉ์๊ฐ ํ๊ฐํ๋ ๊ฒ
๊ธฐ๋ฅ ์์ค: ๋ชจ๋ธ์ ํด๋์ค๊ฐ ์ด๋ฏธ ์ธ๊ฐ ์์ค์ ํ๊ฐ์์ ํ๊ฐ๋ ๊ฒฝ์ฐ(์งง์ ํธ๋ฆฌ์ผ์๋ก ์ค๋ช
๋ ฅ์ด ๋๋ค ๋ฑ)
- ๊ฐ์ค์น๋ค๊ณผ feature๊ฐ์ ๊ณฑ์ผ๋ก ์์ธก๊ฐ์ ๋ํ ๊ธฐ์ฌ๋๋ฅผ ์ค๋ช
- ๋น์ ํ์ฑ์ด๋ ๊ตํธ์์ฉ์ด ๋ง์ ๊ฒฝ์ฐ์๋ ์ ์ ์น ๋ชปํจ
- ์ ํ์ฑ์ ์ค๋ช ์ ๋ ์ผ๋ฐ์ ์ด๊ณ ๊ฐ๋จํ๊ฒ ๋ง๋ฆ
โ ๋๋ฌด ๋ง์ feature๋ค์ด ์กด์ฌํ ๊ฒฝ์ฐ์๋ ์ ํ ๋ชจ๋ธ๋ก ํ์ตํ ์ ์์. -> sparsity๋ฅผ ์ ์ฉ
1) Lasso: ํฐ ๊ฐ์ค์น ํฉ์ ํ๋ํฐ๋ฅผ ๋ถ๊ณผํด, ๋ง์ feature๋ค์ ๊ฐ์ค์น๋ค์ 0์ผ๋ก ๋ง๋๋ ์ ๊ทํ ๋ฐฉ๋ฒ
- ์ ๋ฌธ์ ์ธ ์ง์์ ์ด์ฉํด ์๋์ ์ผ๋ก feature ์ ํ
- feature๊ณผ ๋ชฉํ๊ฐ ๊ฐ์ ์๊ด๊ด๊ณ๊ฐ ์๊ณ๊ฐ์ ๋๋ ๊ฒฝ์ฐ ์ ํ(feature์ด ์๋ก ๋ ๋ฆฝ์ ์ด๋ผ๊ณ ๊ฐ์ )
- Forward Selection: feature ํ๋๋ก๋ถํฐ ์์ํด์, ๊ฐ์ฅ ์ข์ ๋ชจ๋ธ์ ๋ง๋๋ feature๋ค์ ํ๋์ฉ ์ถ๊ฐ
- Backward Selection: ๋ชจ๋ feature๋ฅผ ๋ฃ์ ๋ชจ๋ธ๋ก๋ถํฐ ์์ํด์, ๊ฐ์ฅ ์ข์ ๋ชจ๋ธ์ ๋ง๋ค๋๋ก feature๋ฅผ ํ๋์ฉ ์ ๊ฑฐ
์ฅ์ : ๋ง์ ๋ถ์ผ์์ ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์ ๋์ ์์ค์ ๊ฒฝํ๊ณผ ์ ๋ฌธ์ง์์ด ์๊ณ , ์ต์ ์ ๊ฐ์ค์น๋ฅผ ํ์คํ ์ ์ ์์.
- feature๋ค์ ์ ํ๊ฒฐํฉ์ ๋ก์ง์คํฑํจ์๋ฅผ ์ ์ฉ
- ๊ฐ์ค์น ๋์ odds ratio๋ก ํด์ํจ
๋จ์ : ๊ตํธ์์ฉ x, ํ๋์ feature๊ฐ ๋ ํด๋์ค๋ก ์์ ํ ๋ถ๋ฆฌํ๋ค๋ฉด ํ์ต๋์ง ์์
GLM: ๋ชจ๋ ๊ฒฐ๊ณผ๊ฐ์ ์ ํ(๋น๊ฐ์ฐ์์ ๋ถํฌ, ์์๊ฐ ์๋ ์ ํ ๋ฑ)์ ๋ชจ๋ธ๋งํ๊ธฐ ์ํด ํ์ฅ์ํจ ์ ํ ๋ชจ๋ธ
์ ํ๋ธ๊ณผ ๊ธฐ๋๊ฐ์ ๋น์ ํ ํจ์๋ฅผ ํตํด์ ์ฐ๊ฒฐ
GAM: ๊ฐ์ค์น ํฉ์ด ์๋, ๊ฐ๊ฐ์ feature๋ณ๋ก ์์์ ํจ์๋ฅผ ์ ์ฉํ ๊ฐ์ ํฉ์ผ๋ก ๋ชจ๋ธ๋ง(๋น์ ํ ํด๊ฒฐ ์ํด)
์ฅ์ : ๋ง์ ๊ณณ์์ ์ฌ์ฉ๋๊ณ ์์, ํด์์ฑ์ ์ผ๋ถ ์ ์งํ๋ฉด์ ์ ์ฐํ ๋ชจ๋ธ๋ก ์ํํ๊ฒ ์ ํ ๊ฐ๋ฅ
- impurity๋ฅผ ๊ฐ์์ํค๋ ๋ฐฉํฅ์ผ๋ก feature์ ๊ตฌ๋ถํ๋ ํธ๋ฆฌ๋ฅผ ๊ตฌ์ฑ
- ๋ถํ ํ ๊ฐ์ํ impurity์ ์ ๋๋ฅผ feature ์ค์๋๋ก ์ฌ์ฉ
๋จ์ : ์ ํ ๊ด๊ณ๋ฅผ ๋ค๋ฃฐ ์ ์์. ํธ๋ฆฌ๊ฐ ์์ ์ ์ด์ง ์์(๋ฐ์ดํฐ์ , ์ฒซ ๋ถํ feature์ ๋ฐ๋ผ ๋ฌ๋ผ์ง). ํธ๋ฆฌ๊ฐ ๊น์์๋ก ํด์ํ๊ธฐ๊ฐ ์ด๋ ต๋ค.
- If-Then ๊ตฌ์กฐ๋ก ์์ธกํ๋ ๋ฐฉ์ โ support : ๊ท์น ์กฐ๊ฑด์ด ์ ์ฉ๋๋ ๊ด์ธก์ง์ ๋ฐฑ๋ถ์จ(์ง์ํ๋ ๋ฒ์) โ accuracy: ๊ท์น์ด ์ฌ๋ฐ๋ฅธ ํด๋์ค๋ฅผ ์์ธกํ๋ ๋น์จ
- ์ ์ ํ ๊ฐ๊ฒฉ์ ์ ํํ์ฌ ์ฐ์ํ feature๋ฅผ ๋ฒ์ฃผํ
- feature์ ๊ฒฐ๊ณผ ์ฌ์ด์ ๊ต์ฐจ ํ ์ด๋ธ์ ๋ง๋ค์ด ๊ฐ์ฅ ์ค๋ฅ๊ฐ ์ ์ feature๋ฅผ ์ ํ
2) Sequential covering: ์ ์ฒด ๋ฐ์ดํฐ ์งํฉ ๊ท์น์ ํฌํจํ๋ ์์ฌ๊ฒฐ์ ๋ชฉ๋ก์ ๋ง๋ค๊ธฐ ์ํด ๋จ์ผ ๊ท์น์ ๋ฐ๋ณต์ ์ผ๋ก ํ์ต
- ๊ท์น 1๋ก ํ์ตํ๊ณ , ๊ท์น 1์ ํด๋นํ๋ ๋ฐ์ดํฐ ์ง์ ์ ์ ๊ฑฐํ ํ, ๋๋จธ์ง ๋ฐ์ดํฐ๋ก ๊ทธ ๋ค์ ๊ท์น 2๋ฅผ ํ์ตํ๋ค.
- ๋ฐ์ดํฐ์์ ๋น๋ฒํ ํจํฑ์ ๋ฏธ๋ฆฌ ํ์
- ๋ฏธ๋ฆฌ ํ์ธ๋ ๊ท์น์ ์ ํ ํญ๋ชฉ์์ ์์ฌ๊ฒฐ์ ๋ชฉ๋ก์ ํ์ต
๋จ์ : ํ๊ท๋ฅผ ์์ ํ ๋ฌด์.feature์ ๋ฒ์ฃผํํด์ผ ํจ. feature๊ณผ ์ถ๋ ฅ ๊ฐ์ ์ ํ๊ด๊ณ ์ค๋ช ํ๊ธฐ ์ด๋ ค์.
- ์์ฌ๊ฒฐ์ ๊ท์น์ ํํ๋ก ์๋์ผ๋ก ํ์ง๋ ์ํธ์์ฉ ํจ๊ณผ๋ฅผ ํฌํจํ๋ ํฌ์ ์ ํ ๋ชจ๋ธ์ ํ์ต
- ์์ฌ ๊ฒฐ์ ํธ๋ฆฌ์์ ์ํธ์์ฉ์ ๊ณ ๋ คํ ์ feature์ ์๋์ผ๋ก ์์ฑ
- ์์๋ธ ๋ฑ์ ์ด์ฉํด์ ์ต๋ํ ๋ง์ ๊ท์น์ ๋ง๋ฆ(๋ ธ๋์ ์์ธก๊ฐ์ ๋ฒ๋ฆฌ๊ณ , ๋ถํ ์กฐ๊ฑด๋ง ์ฌ์ฉ)
- ๋ง๋ค์ด์ง ๊ท์น๋ค๊ณผ ๊ธฐ์กด feature์ ์ฌ์ฉํด ํฌ์ ์ ํ ๋ชจ๋ธ์ ๋ง๋ค์ด ๊ฐ์ค์น ์ถ์ ์น๋ฅผ ์ป์
- Lasso ๋ชจ๋ธ์ ๊ฐ์ค์น์ ์ ํ ํญ์ ํ์คํธ์ฐจ๋ฅผ ๊ณฑํด์ feature ์ค์๋๋ฅผ ์ป์
์ฅ์ : feature ์ํธ์์ฉ์ ์๋์ผ๋ก ์ถ๊ฐ(๋น์ ํ ๊ด๊ณ ๋ชจ๋ธ๋ง), ๋ถ๋ฅ ๋ฐ ํ๊ท ๋ชจ๋ ์ปค๋ฒ ๊ฐ๋ฅ, ๋ก์ปฌ ํด์์ฑ ํฅ์(๊ฐ๋ณ ๊ด์ธก์น์๋ ์์์ ๊ท์น๋ง ์ ์ฉ)
- ๋ชจ๋ธ์ ํ์ต๊ณผ ์ค๋ช ์ ๋ถ๋ฆฌ์์ผ, ํ์ต์ ์ข ๋ฅ์ ์ ํ๋์ง ์์ ์ค๋ช ์ ์ ๊ณตํ๋ ๋ฐฉ๋ฒ
- ํด์ ๊ฐ๋ฅํ ๋ชจ๋ธ๋ง์ ์ฌ์ฉํ๊ธฐ์๋ ์ฑ๋ฅ์ด ๋จ์ด์ ธ์...
- Model flexibility(์ด๋ ๋ชจ๋ธ์ด๋ ์ ์ฉ ๊ฐ๋ฅ), Explanation flexibility(ํน์ form์ ์ค๋ช
์ ๊ตญํ๋์ง ์์), Representation flexibility(์ค๋ช
ํ๋ ๋ชจ๋ธ ๋ณ๋ก ๋ค๋ฅธ feature representation์ ์ฌ์ฉ)
โ Example-based Explanation: ๋ชจ๋ธ์ ์ค๋ช ํ๊ธฐ ์ํด ํน์ dataset์ ์ ํ(model-agnostic์์๋ feature์ summary๋ฅผ create)
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ average behavior๋ฅผ describeํ๋ ๋ฐฉ๋ฒ
1,2๊ฐ์ง feature๊ฐ ์์ธก ๊ฒฐ๊ณผ์ ๋ฏธ์น๋ marginal effect๋ฅผ ๋ณด์ฌ์ฃผ๋ ๋ฐฉ๋ฒ
S๋ ๊ด์ฌ ์๋ feature์ ์งํฉ(1~2๊ฐ)์ด๋ฉฐ, C๋ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ฌ์ฉ๋ ๋ค๋ฅธ feature๋ค(S์ feature๋ค๊ณผ๋ ์๊ด๊ด๊ณ ์๋ค๋ ๊ฐ์ )
PDP๋ training set์ ๋ํ์ฌ S์ ๊ฐ์ ๋ฐ๋ผ ์์ฑ๋๋ ๊ฒฐ๊ณผ์ ํ๊ท ์ ๊ทธ๋ฆผ
PDP์ average curve๋ก๋ถํฐ์ deviation์ด ํด์๋ก ์ค์๋๊ฐ ๋์
๋จ์ : ํ์ค์ ์ธ ์ต๋ feature ์๊ฐ 2๊ฐ(3D๊น์ง๋ฐ์ ํํ์ด ์๋๋ฏ๋ก), ๋ ๋ฆฝ์ฑ ๊ฐ์ ์ด ํ์, ํ๊ท ์ผ๋ก ์ธํ ํน์ด๊ฐ ํจ๊ณผ ์จ๊ฒจ์ง
feature๋ค์ด ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ ์์ธก์ ํ๊ท ์ ์ผ๋ก ์ํฅ์ ๋ฏธ์น๋์ง๋ฅผ ๋ํ๋.
PDP๋ณด๋ค ๋น ๋ฅด๊ณ unbiased(๋ณ์ ๊ฐ ์๊ด๊ด๊ณ ๊ณ ๋ ค)๋์ด ์๋ค.
- Grid๋ฅผ Window๋ก ๋๋์ด์, Window ๋ด์ ์์ธก๊ฐ์ ์ฐจ์ด๋ฅผ ํ๊ท ๋ด์ด์ grid์ ๋ฐ๋ผ accumulate ํ๋ค.
Feature๊ฐ ์๊ด๊ด๊ณ๋ฅผ ์ธก์ ํ๊ธฐ ์ํจ
โ H-statistics: ๋ feature๊ฐ or ํ feature๊ณผ ๋๋จธ์ง feature๋ค์ interaction์ partial dependence๋ฅผ ์ฌ์ฉํ์ฌ ์ธก์ ํ ํต๊ณ๋
์ฅ์ : PD๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํจ, ์๋ฏธ์๋ ํด์๋ ฅ์ ๊ฐ๋๋ค, ์ฐจ์์ด ์์ด์ feature๊ฐ/๋ชจ๋ธ๊ฐ ๋น๊ต ๊ฐ๋ฅ, ๋ชจ๋ ์ข ๋ฅ์ interaction ํ์ง
๊ณ ์ฐจ์ ํจ์๋ฅผ ๊ฐ๊ฐ์ feature effect์ interaction effect์ ํฉ์ผ๋ก ๋ํ๋ด๋ ๊ฒ.
โ (Generalized) Functional ANOVA โ ALE โ Statistical Regression Models
Feature์ ๊ฐ์ permuteํจ์ ๋ฐ๋ผ ๋ณํ๋ prediction error์ ์ฆ๊ฐ๋ฅผ ์ฌ์ฉ
- Training data๋ก ๋ชจ๋ธ์ ํ์ตํ ๋ค, test data์ permutation ์ ์ฉํ์ฌ feature importance ๋์ถ
์)
์ฅ์ : ํด์๋ ฅ ์ข์, error ratio๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ค๋ฅธ ๋ฌธ์ ๋ค๋ผ๋ฆฌ ๋น๊ต ๊ฐ๋ฅ, feature๊ฐ interaction ๊ณ ๋ ค
๋จ์ : unlabeled data์๋ ์ ์ฉ ๋ถ๊ฐ, correlated๋ feature์ด ์์ผ๋ฉด unrealistic data instance์ ์ํด biased๋ ์ ์์, correlated feature์ ์ถ๊ฐํ๋ฉด ๊ด๋ จ๋ feature์ importance๊ฐ ์ค์ด๋ค ์ ์์
black box model์ ์์ธก์ ๊ทผ์ฌํ๋ ์์ธก๊ฐ๋ฅํ ๋ชจ๋ธ
- black box model์ ์ฌ์ฉํ dataset์ X๋ก, black box model์ ์์ธก์ y๋ก ํด์ linear model์ด๋ decision tree ๊ฐ์ ํด์๊ฐ๋ฅํ ๋ชจ๋ธ์ ํ์ต
- r^2๋ก black box model๊ณผ surrogate model์ ์์ธก์ ์ ์ฌ์ฑ์ ์ธก์
์)
์ฅ์ : flexible(์ด๋ค ํด์๊ฐ๋ฅํ ๋ชจ๋ธ์ด๋ ์ฌ์ฉํ ์ ์์), ์ง๊ด์ , r^2๋ก surrogate ๋ชจ๋ธ์ด ์ผ๋ง๋ ์ ๊ทผ์ฌํ๋์ง๋ฅผ ์ธก์ ํ ์ ์์
๋จ์ : ๋ฐ์ดํฐ๊ฐ ์๋ model์ ๋ํ ๊ฒฐ๋ก ๋ง ๋ด๋ฆด ์ ์์, ์ด๋ ์ ๋์ r^2๊ฐ ์ข์ ๊ฑด์ง๊ฐ ๋ถ๋ช ํ
Prototype: ๋ชจ๋ data๋ฅผ ์ ๋ํํ๋ data instance
Criticism: prototype์ ์ํด ๋ํ๋์ง ๋ชปํ๋ data instance
๋ฐ์ดํฐ์ ๋ถํฌ์ ๋ํ ์ดํด๋ฅผ ๋๊ณ , ํด์๊ฐ๋ฅํ ๋ชจ๋ธ์ ๋ง๋ค๊ฑฐ๋, black box model์ ํด์ํ๋ ๋ฐ ๋์์ด ๋จ
โ MMD-critic: prototype๊ณผ ์ค๋ฐ์ดํฐ์ ๋ถํฌ๋ฅผ ๋น๊ตํ์ฌ, ๊ดด๋ฆฌ๋ฅผ ์ต์ํํ๋ prototype์ ์ ํํ๋ ๋ฐฉ๋ฒ
- prototype๊ณผ criticism์ ๊ฐฏ์๋ฅผ ์ ์
- greedy search๋ก prototype ์ฐพ๊ธฐ
- greedy search๋ก criticism ์ฐพ๊ธฐ
- data density๋ฅผ ์ถ์ ํ๊ธฐ ์ํ kernel function์ ์ด์ฉํ์ฌ ๋ ๋ถํฌ์ ์ฐจ์ด๋ฅผ ๊ณ์ฐํ๋ witness function์ ์ฌ์ฉ
Individual Prediction์ ๋ํ ์ค๋ช ์ ํ๋ ๋ฐฉ๋ฒ
ํ feature์ ๋ณํ์ ๋ฐ๋ผ ๊ฐ๊ฐ์ instance์ ์์ธก์ด ์ด๋ป๊ฒ ๋ณํํ๋์ง๋ฅผ line์ผ๋ก ๋ํ๋
PDP๋ ICE๋ค์ ํ๊ท ์ด๋ผ๊ณ ๋ณด๋ฉด ๋จ.
-
centered ICE Plot: ๊ฐ๊ฐ์ prediction์ ์ฐจ์ด๊ฐ ์์์ ์ด ๋ค๋ฅธ ๊ฒ ๋๋ฌธ์ผ ์ ์์ผ๋ฏ๋ก, ์์์ ์ ์ผ์น์ํด.
-
Derivative ICE Plot: ๋ณํ์ ๋ฐฉํฅ๊ณผ feature์ range ํ์ ์ด ์ฌ์.
๋จ์ : ํ๋์ feature๋ง ๋ํ๋ผ ์ ์์, correlated๋ feature์ผ ๊ฒฝ์ฐ ์ด๋ ค์, ๋๋ฌด ๋ง์ line์ด ์์ ์ ์์, ํ๊ท ์ ๋ณด๊ธฐ ์ด๋ ค์.
Individual prediction์ ์ค๋ช ํ๊ธฐ ์ํ local surrogate model์ ํ์ต
- ์๋ก์ด dataset(๊ด์ฌ instance์ perturbed sample ํฌํจ) ๋ง๋ค์ด black box model์ ์์ธกํ๋๋ก ํ์ต
- ๊ด์ฌ instance์ ๊ฐ๊น์ด new sample์ ๋ ํฐ ๊ฐ์ค์น๋ฅผ ์ค
- Local approximation์๋ง ์ด์ ์ ๋๊ณ , global approximation์ ์ํ ํ์๋ ์์
- Loss(black box model์ prediction๊ณผ์ ์ฐจ์ด)๋ฅผ ์ต์๋ก ํ๋ ๋ฐ ์ด์ ์ ๋๋ฉด์ model complexity์ ์ ํ์ ๋
- neighborhood์ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ ํ๋ kernel width ์กฐ์ ์ํด์ผ ํจ
์๋ฌธ) ๊ธฐ์กด ๋ชจ๋ธ์ด perturbed sample๋ ์ ์์ธกํ๋ค๋ ๊ฐ์ ์ด ์์ด์ผ ํ๋ ๊ฒ์ธ๊ฐ?
๊ด์ฌ instance๋ ์ด๋ป๊ฒ ์ ์ ํ๋๊ฐ?
์ฅ์ : black box model์ ์ข ๋ฅ์ ์๊ด์์ด ์ ์ฉ ๊ฐ๋ฅ, ๊ฐ๊ฒฐํ๊ณ ์ง๊ด์ ์ธ ์ค๋ช ๊ฐ๋ฅ, tablular, text, image ๋ฐ์ดํฐ ๋ชจ๋ ์ ์ฉ ๊ฐ๋ฅ, fidelity measure ๊ฐ๋ฅ, ํจํค์ง ์ ๋์ด ์์,
๋จ์ : neighborhood์ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ ํ๋ kernel width ์ ์ ์ด ์ด๋ ค์, Gaussian distribution์ ์ํ sampling์ ํ๊ณ(ex: correlated feature), model complexity ๋ฏธ๋ฆฌ ์ ํด์ผ ํจ, instability of the explanations
predefined output์ ๋ํ ์์ธก์ ๋ณํ(ex: Y->N / 90->100)์ํค๋ ๊ฐ์ฅ ์์ ๋ณํ๊ฐ ํ์ํ feature๋ฅผ ์ค๋ช
- feature value ๋ณํ์ ํฌ๊ธฐ๋ฅผ ์๊ฒ ํ๋ ๊ฒ๋ ์ค์ํ์ง๋ง, ๋ณํํ๋ feature ๊ฐฏ์๋ ์์์ผ ํจ
- ๋ค์์ counterfactual instance๋ฅผ generateํ๋ ๊ฒ ๋ฐ๋์งํ ๋๋ ์์
- ๊ทธ๋ด ๋ฏํ(ํ์ค์ ์ธ) counterfactual instance๋ฅผ generateํด์ผ ํจ
โ Generating Counterfactual Explanations
(1) Minimizing Loss by Watchter:
desired outcome๊ณผ couterfactual์ prediction์ ์ฐจ์ด
- ์ค์ point์ counterfactual point ์ฌ์ด์ ๊ฑฐ๋ฆฌ
- ๋จ์ : ์ ์ ์์ feature๋ง ๊ตฌํจ, categorial feature๋ ๋ค๋ฃจ๊ธฐ ์ด๋ ค์
(2) Minimizing 4 Loss by Dandl:
desired outcome๊ณผ couterfactual์ prediction์ ์ฐจ์ด
- ์ค์ point์ counterfactual point ์ฌ์ด์ ๊ฑฐ๋ฆฌ(Gower distance)
- ๋ณํ๋ feature์ ๊ฐฏ์ + likely feature values/combinations๋ฅผ ๊ฐ์ง counterfactual
์ฅ์ : ํด์์ด ๋ช ํ, ์๋ก์ด counterfactual๋ฅผ ๋ง๋ค๊ฑฐ๋ / ๊ธฐ์กด dataset ์์์ outcome์ด ๋ณํ๊ฒ ๋ง๋ feature๋ฅผ ๋ฝ๊ฑฐ๋ ๋ ๋ค ๊ฐ๋ฅ, data๋ model์ ์๊ด์์ด prediction function์๋ง ์ ๊ทผ์ ์๊ตฌ(๋ฌด์จ ๋ง์ธ์ง ๋ชจ๋ฅด๊ฒ ์...), implementation ์ฌ์
๋ค๋ฅธ feature๊ฐ individual prediction์ ์ํฅ์ ๋ฏธ์น์ง ์๋๋ก ํ๋ ํน์ feature๋ค์ decision rule์ ์ฐพ๋ ๋ฐฉ๋ฒ
- ํน์ coverage ์ด์์ input space์์ ํน์ threshold์ precision์ ๋ง์กฑํ๋ ์กฐ๊ฑด์ ์ฐพ์
โ Finding Anchors
- Candidate Generation
- Best Candidate Identification
- Candidate Precision Validation
- Modified Beam Search
game theory์์ ์ฐฉ์ํ์ฌ, single instance(game)์ ๋ํด feature values(player)๊ฐ gain(single prediction - average prediction)์ ๊ธฐ์ฌํ๋ ์ ๋๋ฅผ ๋ํ๋ด๋ ๋ฐฉ๋ฒ
Sharpley Value : average of marginal contributions to all possible coalitions
์ฅ์ : contrastive explanation ๊ฐ๋ฅ(์ ์ฒด dataset/subset/single data point์๋ ๋น๊ต ๊ฐ๋ฅ), solid theory ์์
๋จ์ : ์ฐ์ฐ๋ โ, ์๋ชป๋ ํด์์ ๊ฐ๋ฅ์ฑ(ํด๋น feature๊ฐ ์ ๊ฑฐ๋์์ ๋์ contribution์ด ์๋, feature value๊ฐ average์์ ์ฐจ์ด์ ๊ธฐ์ฌํ๋ ์ ๋๋ฅผ ๋ํ๋), prediction model์ด ์๋, feature๊ฐ correlated ๋์ด์์ผ๋ฉด, unrealistic data๋ฅผ ํฌํจํ ์ ์์
kernel-based estimation approach for Shapley values
โ KernelSHAP
estimates for an instance x the contributions of each feature value to the prediction
โ TreeSHAP
Tree-based model์ ์ํ SHAP
- marginal expectation ๋์ conditional expectation์ ์ฌ์ฉ
- KernelSHAP๋ณด๋ค ์ฐ์ฐ complexity๊ฐ ๋ฎ์(TLD^2 < TL2^M, T: # of trees, L: max # of leaves, D: max depth)
์) Cervical Cancer
์ฐธ๊ณ ) SHAP graph ํด์
์ฅ์ : solid theoretical foundation, constrastive explanations, LIME๊ณผ Shapley values๋ฅผ ์ฐ๊ฒฐ, fast implementation for tree-based models(global model interpretations๋ฅผ ์ํ ์ฐ์ฐ์ ์ ๋ฆฌ)
๋จ์ : KernelSHAP์ ๋๋ฆผ, KernelSHAP์ feature dependence๋ฅผ ๋ฌด์, TreeSHAP์ unintuitive feature attribution์ ๋ง๋ค์ด๋ผ ์ ์์, ์๋ชป ํด์๋ ์ ์์
์์ interpretation method๋ค๊ณผ์ ์ฐจ๋ณ์
- hidden layers์ feature๋ค์ uncoverํ ์ ์๋๋ก
- gradient๋ฅผ interpretation์ ์ด์ฉํ ์ ์์
โ Feature Visualization
finding the input that maximizes the activation of that unit(individual neurons, channel, entire layers)
- unit์ activation์ ์ต๋ํํ๋ image๋ฅผ ์ฐพ๋ optimization problem์ด๋ผ๊ณ ์๊ฐํ๋ฉด ๋จ
- ๊ธฐ์กด ๋ฐ์ดํฐ์์ ์ฐพ์ ์๋ ์๊ณ , ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์๋ ์์
- tabular data์ ๋ํด์๋ unit์ activation์ ์ต๋ํํ๋ feature์ ์กฐํฉ์ ์ฐพ๋ ๋ฌธ์
โ Network Dissection
์ฐธ๊ณ : Network Dissection
CNN unit์ interpretability๋ฅผ ์ ๋ํํ๋ ๋ฐฉ๋ฒ
๊ฐ์ : Units of a neural network (like convolutional channels) learn disentangled concept
- Broden dataset(Broadly and densely labeled data)๊ฐ ํ์(pixel level์ concepts์ labeling ํด์ค์ผ ํจ)
- image์์ top activated area๋ฅผ ์ฐพ์๋ด activation mask๋ฅผ ๋ง๋ ๋ค.
- ํด๋น activation mask์ ๊ฐ์ฅ ๋ง์ด ์ผ์นํ๋ concept๋ฅผ ์ฐพ๋๋ค.
์ฅ์ : unique insight๋ฅผ ์ค, unit์ concept๊ณผ ์๋์ผ๋ก ์ฐ๊ฒฐํด์ค, non-technical way๋ก ์ํต ๊ฐ๋ฅ, class๋ฅผ ๋์ด์ concept๊น์ง detect ๊ฐ๋ฅ
๋จ์ : feature visualization image๋ ํด์ ๋ถ๊ฐ๋ฅํ ๊ฒฝ์ฐ๊ฐ ๋ง์, unit์ด ๋๋ฌด ๋ง์, pixel level labeled data ํ์
classification๊ณผ ๊ด๋ จ ์๋ pixel์ highlightํ๋ ๋ฐฉ๋ฒ(sensitivity map, saliency map, pixel attribution map ๋ฑ)
- Vanilla Gradient
- DeconvNet
- Grad-CAM(Gradient-weighted Class Activation Map)
- Guided Grad-CAM
- SmoothGrad
์ฅ์ : explanations are visual, faster to compute than mode-agnostic methods, many methods to choose from
Neural network์ ์ํด ํ์ต๋ latent space์ ๋ด์ฌ๋ concept์ detect
โ TCAV(Testing with Concept Activation Vectors)
concept๊ณผ class ๊ฐ์ ๊ด๊ณ๋ฅผ ๋ฌ์ฌ(์: ์ค๋ฌด๋ฌ๊ฐ ์ผ๋ฃฉ๋ง class์ ์ด๋ป๊ฒ ์ํฅ์ ๋ฏธ์น๋์ง)
CAV: numerical representation that generalizes a concept in the activation space of a NN layer
๋จ์ : shallow NN์์๋ ์ฑ๋ฅ โ, concept labeling ํ์, text๋ tabular data์๋ ์ ์ฉ ์ด๋ ต๋ค.
โ ๋ค๋ฅธ ๋ฐฉ๋ฒ๋ค: ACE, CBM, CW
small perturbation์ ๊ฐํด์ model์ deceiveํ๋ samples
(model์ ์ทจ์ฝ์ ์ ์ฐพ๊ธฐ ์ํ ๊ฒ์ธ ๋ฏ)
- Fast gradient sign method
- 1-pixel attack
- Adversarial patch
- Black box attack
model์ parameter๋ prediction์ ๋ณํ์ํค๋ instance๋ฅผ ์ฐพ๋ ๊ฒ
(model์ debugํ๊ฑฐ๋ ์ค๋ช
ํ๋ ๋ฐ ๋์์ด ๋จ, problematic instance๊ฐ ์๊ฑฐ๋, measurement error๊ฐ ์๊ฑฐ๋ ๋ฑ)
- ํด๋น instance๋ฅผ ์ ๊ฑฐ(deletion Diagnostics)ํ๊ฑฐ๋ loss์ ๊ฐ์ค์น๋ฅผ ์กฐ๊ธ ํฌ๊ฒ(influence functions) ํ์ ๋ ์ผ๋ง๋ ๋ณํ๋์ง
- outlier์๋ ๋ค๋ฆ(dataset๊ณผ์ ๊ฑฐ๋ฆฌ๊ฐ ๋จผ ๊ฒ), ํ์ง๋ง outlier๊ฐ influential instance๊ฐ ๋ ์๋ ์์
โ Deletion Diagnostics: instance๋ฅผ ์ ๊ฑฐํ์ฌ parameter๋ prediction์ ๋ณํ๋ฅผ ๋ณด๋ ๊ฒ
- DFBETA: parameter์ ๋ณํ ์ ๋ํ
- Cook's distance: prediction์ ๋ณํ ์ ๋ํ
โ Influence functions: loss์ ๊ฐ์ค์น(e)๋ฅผ ์กฐ๊ธ ํฌ๊ฒ ํ์ ๋์ ๋ณํ๋ฅผ ๋ณด๋ ๊ฒ
retraining์ ํ์ง ์๊ณ , ๋ณํ๋ฅผ approximateํ๋ ๋ฐฉ๋ฒ