본문 바로가기
공부/통계학

VIF (분산확장요인, python)

by signature95 2022. 1. 11.
728x90
반응형

Feature selection 방법은 크게 3가지로 나뉜다.

  1. Filter Method (Feature간 상관성 기반)
  2. Wrapper Method (Feature를 조정하며 모형을 형성하고 예측 성능을 참고하여 Feature 선택)
  3. Embedded Method (예측 모형 최적화, 회귀계수 추정 과정에서 각 Feature가 선택되는 방식)

 

이번에 살펴볼 것은 Filter Method 방법인 VIF(Variance Inflation Fector, 분산확장요인)이다.

 

 

먼저 VIF의 식을 보면 다음과 같다.

VIF는 다중 공선성(Multicollinearity)을 측정하는데 필요한 방법으로
다중공선성이란, 독립변수(feature)간 상관관계가 있는 것을 의미한다.

 

일반적으로 OLS회귀 가정이나, 여러 회귀와 분류 모형의 가정은 변수간 독립성을 가정한다. 즉 Feature간 상관관계가 없다는 것을 기본 가정으로 가져간다. 

 

(참고)

2022.02.09 - [공부/통계학] - 기초통계 (상관계수) python

 

기초통계 (상관계수) python

이전 포스트에 이어서 작성하는 내용입니다. 2022.02.08 - [공부/통계학] - 기초 통계 (분산) python 기초 통계 (분산) python 이전 포스트에 이어서 작성하는 내용입니다. 2022.02.04 - [공부/통계학] - 기초

signature95.tistory.com

 

VIF가 10이 넘으면 다중공선성이 있는 것으로 판단하며 5가 넘더라도 주의를 요한다. 만약 특정 Feature a와 b가 서로 상관 관계가 있다고 했을 때 두 변수 모두 VIF가 높다고 판단한다. 어느 하나만 VIF가 높은 경우는 없다고 생각하면된다.

 

 

그렇다면, 보스턴 집값 데이터로 한번 확인해보자

import pandas as pd
import numpy as np
import seaborn as sns
from statsmodels.stats.outliers_influence import variance_inflation_factor
import matplotlib.pyplot as plt
from matplotlib import rc

rc('font', family='AppleGothic')
plt.rcParams['axes.unicode_minus'] = False

# 깃허브에 업로드된 데이터 불러오기 (보스턴 집값 데이터)
data = pd.read_csv("https://raw.githubusercontent.com/signature95/tistory/main/dataset/boston.csv")
data

출력 결과

	CRIM	ZN	INDUS	CHAS	NOX	RM	AGE	DIS	RAD	TAX	PTRATIO	B	LSTAT	MEDV	CAT. MEDV
0	0.00632	18.0	2.31	0	0.538	6.575	65.2	4.0900	1	296	15.3	396.90	4.98	24.0	0
1	0.02731	0.0	7.07	0	0.469	6.421	78.9	4.9671	2	242	17.8	396.90	9.14	21.6	0
2	0.02729	0.0	7.07	0	0.469	7.185	61.1	4.9671	2	242	17.8	392.83	4.03	34.7	1
3	0.03237	0.0	2.18	0	0.458	6.998	45.8	6.0622	3	222	18.7	394.63	2.94	33.4	1
4	0.06905	0.0	2.18	0	0.458	7.147	54.2	6.0622	3	222	18.7	396.90	5.33	36.2	1
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
501	0.06263	0.0	11.93	0	0.573	6.593	69.1	2.4786	1	273	21.0	391.99	9.67	22.4	0
502	0.04527	0.0	11.93	0	0.573	6.120	76.7	2.2875	1	273	21.0	396.90	9.08	20.6	0
503	0.06076	0.0	11.93	0	0.573	6.976	91.0	2.1675	1	273	21.0	396.90	5.64	23.9	0
504	0.10959	0.0	11.93	0	0.573	6.794	89.3	2.3889	1	273	21.0	393.45	6.48	22.0	0
505	0.04741	0.0	11.93	0	0.573	6.030	80.8	2.5050	1	273	21.0	396.90	7.88	11.9	0

 

먼저 다중공선성을 확인하기 앞서, Correlation matrix를 출력해본다.

 

data.corr()
		CRIM		ZN		INDUS		CHAS		NOX		RM		AGE		DIS		RAD		TAX		PTRATIO		B		LSTAT		MEDV		CAT. MEDV
CRIM		1.000000	-0.200469	0.406583	-0.055892	0.420972	-0.219247	0.352734	-0.379670	0.625505	0.582764	0.289946	-0.385064	0.455621	-0.388305	-0.151987
ZN		-0.200469	1.000000	-0.533828	-0.042697	-0.516604	0.311991	-0.569537	0.664408	-0.311948	-0.314563	-0.391679	0.175520	-0.412995	0.360445	0.365296
INDUS		0.406583	-0.533828	1.000000	0.062938	0.763651	-0.391676	0.644779	-0.708027	0.595129	0.720760	0.383248	-0.356977	0.603800	-0.483725	-0.366276
CHAS		-0.055892	-0.042697	0.062938	1.000000	0.091203	0.091251	0.086518	-0.099176	-0.007368	-0.035587	-0.121515	0.048788	-0.053929	0.175260	0.108631
NOX		0.420972	-0.516604	0.763651	0.091203	1.000000	-0.302188	0.731470	-0.769230	0.611441	0.668023	0.188933	-0.380051	0.590879	-0.427321	-0.232502
RM		-0.219247	0.311991	-0.391676	0.091251	-0.302188	1.000000	-0.240265	0.205246	-0.209847	-0.292048	-0.355501	0.128069	-0.613808	0.695360	0.641265
AGE		0.352734	-0.569537	0.644779	0.086518	0.731470	-0.240265	1.000000	-0.747881	0.456022	0.506456	0.261515	-0.273534	0.602339	-0.376955	-0.191196
DIS		-0.379670	0.664408	-0.708027	-0.099176	-0.769230	0.205246	-0.747881	1.000000	-0.494588	-0.534432	-0.232471	0.291512	-0.496996	0.249929	0.118887
RAD		0.625505	-0.311948	0.595129	-0.007368	0.611441	-0.209847	0.456022	-0.494588	1.000000	0.910228	0.464741	-0.444413	0.488676	-0.381626	-0.197924
TAX		0.582764	-0.314563	0.720760	-0.035587	0.668023	-0.292048	0.506456	-0.534432	0.910228	1.000000	0.460853	-0.441808	0.543993	-0.468536	-0.273687
PTRATIO		0.289946	-0.391679	0.383248	-0.121515	0.188933	-0.355501	0.261515	-0.232471	0.464741	0.460853	1.000000	-0.177383	0.374044	-0.507787	-0.443425
B		-0.385064	0.175520	-0.356977	0.048788	-0.380051	0.128069	-0.273534	0.291512	-0.444413	-0.441808	-0.177383	1.000000	-0.366087	0.333461	0.155137
LSTAT		0.455621	-0.412995	0.603800	-0.053929	0.590879	-0.613808	0.602339	-0.496996	0.488676	0.543993	0.374044	-0.366087	1.000000	-0.737663	-0.469911
MEDV		-0.388305	0.360445	-0.483725	0.175260	-0.427321	0.695360	-0.376955	0.249929	-0.381626	-0.468536	-0.507787	0.333461	-0.737663	1.000000	0.789789
CAT. MEDV	-0.151987	0.365296	-0.366276	0.108631	-0.232502	0.641265	-0.191196	0.118887	-0.197924	-0.273687	-0.443425	0.155137	-0.469911	0.789789	1.000000

이렇게 간단하게 살펴보아도 일부 feature 간 상관성이 높게 도출되는 것을 확인할 수 있다.

 

 

이번에는 간단하게 산점도를 그려본다.

일부 feature만 선정하여 (MEDV, TAX, NOX) 컬럼에 대한 pairplot으로 산점도를 확인할수 있다.

- MEDV(본인의 집 중위값), NOX(10ppm 당 농축 이산화질소), TAX(10000달러 당 재산세율)

# 간단한 산점도
sns.pairplot(data[['MEDV', 'TAX', 'NOX']])

MEDV, NOX의 산점도를 보면 음의 상관성이 일부 보이는것을 알 수 있다. 이처럼 산점도를 통해서도 다중공선성을 확인할 수도 있다.

 

하지만, 통계적으로 확인하기 위해서는 VIF를 도출할 필요가 있다.

# VIF 출력을 위한 데이터 프레임 형성
vif = pd.DataFrame()

# VIF 값과 각 Feature 이름에 대해 설정
vif["VIF Factor"] = [variance_inflation_factor(data.values, i) for i in range(data.shape[1])]
vif["features"] = data.columns 

# VIF 값이 높은 순으로 정렬
vif = vif.sort_values(by="VIF Factor", ascending=False)
vif = vif.reset_index().drop(columns='index')
vif

출력 결과

VIF 	Factor		features
0	136.875365	RM
1	91.819346	PTRATIO
2	74.549360	NOX
3	61.939733	TAX
4	37.854383	MEDV
5	21.669504	B
6	21.541039	AGE
7	16.044949	DIS
8	15.404871	RAD
9	14.755787	INDUS
10	12.824787	LSTAT
11	3.962213	CAT. MEDV
12	3.043697	ZN
13	2.160156	CRIM
14	1.180552	CHAS

 

VIF가 10을 초과하는 Feature가 무려 11개나 되는 것을 확인할 수 있다. 이런 상황에서 OLS 회귀를 시행한다면, 다중공선성의 문제로 인해 제대로된 결과를 도출하기 어렵다. 

하지만, 무조건 10을 초과한다해서 한꺼번에 feature를 drop하는 법보다는 VIF가 가장 높은 상위 몇 개의 Feature를 제거하면서 확인하는 것이 더 바람직하다. (한번에 drop하면 정보손실이 너무 크게 발생할 수 있기 때문)

 

 

따라서 이번에는 VIF가 높은 상위 1개 Feature에 대해 제거한다.

결국 모든 feature가 10을 초과하지 않도록 Feature를 선택하도록 코드를 만들고 시행해 볼 것이다.

def vif(x):
    # vif 10 초과시 drop을 위한 임계값 설정
    thresh = 10
    # Filter method로 feature selection 진행 후 최종 도출 될 데이터 프레임 형성
    output = pd.DataFrame()
    # 데이터의 컬럼 개수 설정
    k = x.shape[1]
    # VIF 측정
    vif = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
    for i in range(1,k):
        print(f'{i}번째 VIF 측정')
        # VIF 최대 값 선정
        a = np.argmax(vif)
        print(f'Max VIF feature & value : {x.columns[a]}, {vif[a]}')
        # VIF 최대 값이 임계치를 넘지 않는 경우 break
        if (vif[a] <= thresh):
            print('\n')
            for q in range(output.shape[1]):
                print(f'{output.columns[q]}의 vif는 {np.round(vif[q],2)}입니다.')
            break
        # VIF 최대 값이 임계치를 넘는 경우, + 1번째 시도인 경우 : if 문으로 해당 feature 제거 후 다시 vif 측정
        if (i == 1):
            output = x.drop(x.columns[a], axis = 1)
            vif = [variance_inflation_factor(output.values, j) for j in range(output.shape[1])]
        # VIF 최대 값이 임계치를 넘는 경우, + 1번째 이후 시도인 경우 : if 문으로 해당 feature 제거 후 다시 vif 측정
        elif (i > 1):
            output = output.drop(output.columns[a], axis = 1)
            vif = [variance_inflation_factor(output.values, j) for j in range(output.shape[1])]
    return(output)

출력 결과

vif(data)

>>>
1번째 VIF 측정
Max VIF feature & value : RM, 136.87536508826085
2번째 VIF 측정
Max VIF feature & value : TAX, 72.36814995542797
3번째 VIF 측정
Max VIF feature & value : NOX, 59.92707600382466
4번째 VIF 측정
Max VIF feature & value : DIS, 55.03743648686375
5번째 VIF 측정
Max VIF feature & value : TAX, 23.178115127470832
6번째 VIF 측정
Max VIF feature & value : NOX, 15.327774051240777
7번째 VIF 측정
Max VIF feature & value : AGE, 11.867610188739521
8번째 VIF 측정
Max VIF feature & value : INDUS, 6.951560104851184


CRIM의 vif는 2.07입니다.
ZN의 vif는 2.49입니다.
INDUS의 vif는 6.95입니다.
CHAS의 vif는 1.11입니다.
DIS의 vif는 4.02입니다.
RAD의 vif는 4.72입니다.
LSTAT의 vif는 6.93입니다.
CAT. MEDV의 vif는 1.41입니다.

	CRIM	ZN	INDUS	CHAS	DIS	RAD	LSTAT	CAT. MEDV
0	0.00632	18.0	2.31	0	4.0900	1	4.98	0
1	0.02731	0.0	7.07	0	4.9671	2	9.14	0
2	0.02729	0.0	7.07	0	4.9671	2	4.03	1
3	0.03237	0.0	2.18	0	6.0622	3	2.94	1
4	0.06905	0.0	2.18	0	6.0622	3	5.33	1
...	...	...	...	...	...	...	...	...
501	0.06263	0.0	11.93	0	2.4786	1	9.67	0
502	0.04527	0.0	11.93	0	2.2875	1	9.08	0
503	0.06076	0.0	11.93	0	2.1675	1	5.64	0
504	0.10959	0.0	11.93	0	2.3889	1	6.48	0
505	0.04741	0.0	11.93	0	2.5050	1	7.88	0

확인해보면, 초기에 VIF를 확인했을 때, 15개 중 11개가 VIF 10을 초과하여 4개만 남는 결과가 나왔지만, 상위 1개씩 제거한 결과 8개의 Feature가 선택됨을 알 수 있다.

 

Filter Method를 이용한, 방법으로 Feature를 선택해보았다.

 

다음 글에서 이어집니다.

2022.01.12 - [공부/모델링] - Forward feature selection (전진선택법) python

 

Forward feature selection (전진선택법) python

이전 filter method를 다룬 VIF (분산확장요인, python)에 이어서 작성하는 포스트입니다. 2022.01.11 - [공부/모델링] - VIF (분산확장요인, python) VIF (분산확장요인, python) Feature selection 방법은 크게..

signature95.tistory.com

 

728x90

댓글