Feature selection 방법은 크게 3가지로 나뉜다.
- Filter Method (Feature간 상관성 기반)
- Wrapper Method (Feature를 조정하며 모형을 형성하고 예측 성능을 참고하여 Feature 선택)
- Embedded Method (예측 모형 최적화, 회귀계수 추정 과정에서 각 Feature가 선택되는 방식)
이번에 살펴볼 것은 Filter Method 방법인 VIF(Variance Inflation Fector, 분산확장요인)이다.
먼저 VIF의 식을 보면 다음과 같다.
VIF는 다중 공선성(Multicollinearity)을 측정하는데 필요한 방법으로
다중공선성이란, 독립변수(feature)간 상관관계가 있는 것을 의미한다.
일반적으로 OLS회귀 가정이나, 여러 회귀와 분류 모형의 가정은 변수간 독립성을 가정한다. 즉 Feature간 상관관계가 없다는 것을 기본 가정으로 가져간다.
(참고)
2022.02.09 - [공부/통계학] - 기초통계 (상관계수) python
VIF가 10이 넘으면 다중공선성이 있는 것으로 판단하며 5가 넘더라도 주의를 요한다. 만약 특정 Feature a와 b가 서로 상관 관계가 있다고 했을 때 두 변수 모두 VIF가 높다고 판단한다. 어느 하나만 VIF가 높은 경우는 없다고 생각하면된다.
그렇다면, 보스턴 집값 데이터로 한번 확인해보자
import pandas as pd
import numpy as np
import seaborn as sns
from statsmodels.stats.outliers_influence import variance_inflation_factor
import matplotlib.pyplot as plt
from matplotlib import rc
rc('font', family='AppleGothic')
plt.rcParams['axes.unicode_minus'] = False
# 깃허브에 업로드된 데이터 불러오기 (보스턴 집값 데이터)
data = pd.read_csv("https://raw.githubusercontent.com/signature95/tistory/main/dataset/boston.csv")
data
출력 결과
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT MEDV CAT. MEDV
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0 0
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6 0
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7 1
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4 1
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
501 0.06263 0.0 11.93 0 0.573 6.593 69.1 2.4786 1 273 21.0 391.99 9.67 22.4 0
502 0.04527 0.0 11.93 0 0.573 6.120 76.7 2.2875 1 273 21.0 396.90 9.08 20.6 0
503 0.06076 0.0 11.93 0 0.573 6.976 91.0 2.1675 1 273 21.0 396.90 5.64 23.9 0
504 0.10959 0.0 11.93 0 0.573 6.794 89.3 2.3889 1 273 21.0 393.45 6.48 22.0 0
505 0.04741 0.0 11.93 0 0.573 6.030 80.8 2.5050 1 273 21.0 396.90 7.88 11.9 0
먼저 다중공선성을 확인하기 앞서, Correlation matrix를 출력해본다.
data.corr()
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT MEDV CAT. MEDV
CRIM 1.000000 -0.200469 0.406583 -0.055892 0.420972 -0.219247 0.352734 -0.379670 0.625505 0.582764 0.289946 -0.385064 0.455621 -0.388305 -0.151987
ZN -0.200469 1.000000 -0.533828 -0.042697 -0.516604 0.311991 -0.569537 0.664408 -0.311948 -0.314563 -0.391679 0.175520 -0.412995 0.360445 0.365296
INDUS 0.406583 -0.533828 1.000000 0.062938 0.763651 -0.391676 0.644779 -0.708027 0.595129 0.720760 0.383248 -0.356977 0.603800 -0.483725 -0.366276
CHAS -0.055892 -0.042697 0.062938 1.000000 0.091203 0.091251 0.086518 -0.099176 -0.007368 -0.035587 -0.121515 0.048788 -0.053929 0.175260 0.108631
NOX 0.420972 -0.516604 0.763651 0.091203 1.000000 -0.302188 0.731470 -0.769230 0.611441 0.668023 0.188933 -0.380051 0.590879 -0.427321 -0.232502
RM -0.219247 0.311991 -0.391676 0.091251 -0.302188 1.000000 -0.240265 0.205246 -0.209847 -0.292048 -0.355501 0.128069 -0.613808 0.695360 0.641265
AGE 0.352734 -0.569537 0.644779 0.086518 0.731470 -0.240265 1.000000 -0.747881 0.456022 0.506456 0.261515 -0.273534 0.602339 -0.376955 -0.191196
DIS -0.379670 0.664408 -0.708027 -0.099176 -0.769230 0.205246 -0.747881 1.000000 -0.494588 -0.534432 -0.232471 0.291512 -0.496996 0.249929 0.118887
RAD 0.625505 -0.311948 0.595129 -0.007368 0.611441 -0.209847 0.456022 -0.494588 1.000000 0.910228 0.464741 -0.444413 0.488676 -0.381626 -0.197924
TAX 0.582764 -0.314563 0.720760 -0.035587 0.668023 -0.292048 0.506456 -0.534432 0.910228 1.000000 0.460853 -0.441808 0.543993 -0.468536 -0.273687
PTRATIO 0.289946 -0.391679 0.383248 -0.121515 0.188933 -0.355501 0.261515 -0.232471 0.464741 0.460853 1.000000 -0.177383 0.374044 -0.507787 -0.443425
B -0.385064 0.175520 -0.356977 0.048788 -0.380051 0.128069 -0.273534 0.291512 -0.444413 -0.441808 -0.177383 1.000000 -0.366087 0.333461 0.155137
LSTAT 0.455621 -0.412995 0.603800 -0.053929 0.590879 -0.613808 0.602339 -0.496996 0.488676 0.543993 0.374044 -0.366087 1.000000 -0.737663 -0.469911
MEDV -0.388305 0.360445 -0.483725 0.175260 -0.427321 0.695360 -0.376955 0.249929 -0.381626 -0.468536 -0.507787 0.333461 -0.737663 1.000000 0.789789
CAT. MEDV -0.151987 0.365296 -0.366276 0.108631 -0.232502 0.641265 -0.191196 0.118887 -0.197924 -0.273687 -0.443425 0.155137 -0.469911 0.789789 1.000000
이렇게 간단하게 살펴보아도 일부 feature 간 상관성이 높게 도출되는 것을 확인할 수 있다.
이번에는 간단하게 산점도를 그려본다.
일부 feature만 선정하여 (MEDV, TAX, NOX) 컬럼에 대한 pairplot으로 산점도를 확인할수 있다.
- MEDV(본인의 집 중위값), NOX(10ppm 당 농축 이산화질소), TAX(10000달러 당 재산세율)
# 간단한 산점도
sns.pairplot(data[['MEDV', 'TAX', 'NOX']])
MEDV, NOX의 산점도를 보면 음의 상관성이 일부 보이는것을 알 수 있다. 이처럼 산점도를 통해서도 다중공선성을 확인할 수도 있다.
하지만, 통계적으로 확인하기 위해서는 VIF를 도출할 필요가 있다.
# VIF 출력을 위한 데이터 프레임 형성
vif = pd.DataFrame()
# VIF 값과 각 Feature 이름에 대해 설정
vif["VIF Factor"] = [variance_inflation_factor(data.values, i) for i in range(data.shape[1])]
vif["features"] = data.columns
# VIF 값이 높은 순으로 정렬
vif = vif.sort_values(by="VIF Factor", ascending=False)
vif = vif.reset_index().drop(columns='index')
vif
출력 결과
VIF Factor features
0 136.875365 RM
1 91.819346 PTRATIO
2 74.549360 NOX
3 61.939733 TAX
4 37.854383 MEDV
5 21.669504 B
6 21.541039 AGE
7 16.044949 DIS
8 15.404871 RAD
9 14.755787 INDUS
10 12.824787 LSTAT
11 3.962213 CAT. MEDV
12 3.043697 ZN
13 2.160156 CRIM
14 1.180552 CHAS
VIF가 10을 초과하는 Feature가 무려 11개나 되는 것을 확인할 수 있다. 이런 상황에서 OLS 회귀를 시행한다면, 다중공선성의 문제로 인해 제대로된 결과를 도출하기 어렵다.
하지만, 무조건 10을 초과한다해서 한꺼번에 feature를 drop하는 법보다는 VIF가 가장 높은 상위 몇 개의 Feature를 제거하면서 확인하는 것이 더 바람직하다. (한번에 drop하면 정보손실이 너무 크게 발생할 수 있기 때문)
따라서 이번에는 VIF가 높은 상위 1개 Feature에 대해 제거한다.
결국 모든 feature가 10을 초과하지 않도록 Feature를 선택하도록 코드를 만들고 시행해 볼 것이다.
def vif(x):
# vif 10 초과시 drop을 위한 임계값 설정
thresh = 10
# Filter method로 feature selection 진행 후 최종 도출 될 데이터 프레임 형성
output = pd.DataFrame()
# 데이터의 컬럼 개수 설정
k = x.shape[1]
# VIF 측정
vif = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
for i in range(1,k):
print(f'{i}번째 VIF 측정')
# VIF 최대 값 선정
a = np.argmax(vif)
print(f'Max VIF feature & value : {x.columns[a]}, {vif[a]}')
# VIF 최대 값이 임계치를 넘지 않는 경우 break
if (vif[a] <= thresh):
print('\n')
for q in range(output.shape[1]):
print(f'{output.columns[q]}의 vif는 {np.round(vif[q],2)}입니다.')
break
# VIF 최대 값이 임계치를 넘는 경우, + 1번째 시도인 경우 : if 문으로 해당 feature 제거 후 다시 vif 측정
if (i == 1):
output = x.drop(x.columns[a], axis = 1)
vif = [variance_inflation_factor(output.values, j) for j in range(output.shape[1])]
# VIF 최대 값이 임계치를 넘는 경우, + 1번째 이후 시도인 경우 : if 문으로 해당 feature 제거 후 다시 vif 측정
elif (i > 1):
output = output.drop(output.columns[a], axis = 1)
vif = [variance_inflation_factor(output.values, j) for j in range(output.shape[1])]
return(output)
출력 결과
vif(data)
>>>
1번째 VIF 측정
Max VIF feature & value : RM, 136.87536508826085
2번째 VIF 측정
Max VIF feature & value : TAX, 72.36814995542797
3번째 VIF 측정
Max VIF feature & value : NOX, 59.92707600382466
4번째 VIF 측정
Max VIF feature & value : DIS, 55.03743648686375
5번째 VIF 측정
Max VIF feature & value : TAX, 23.178115127470832
6번째 VIF 측정
Max VIF feature & value : NOX, 15.327774051240777
7번째 VIF 측정
Max VIF feature & value : AGE, 11.867610188739521
8번째 VIF 측정
Max VIF feature & value : INDUS, 6.951560104851184
CRIM의 vif는 2.07입니다.
ZN의 vif는 2.49입니다.
INDUS의 vif는 6.95입니다.
CHAS의 vif는 1.11입니다.
DIS의 vif는 4.02입니다.
RAD의 vif는 4.72입니다.
LSTAT의 vif는 6.93입니다.
CAT. MEDV의 vif는 1.41입니다.
CRIM ZN INDUS CHAS DIS RAD LSTAT CAT. MEDV
0 0.00632 18.0 2.31 0 4.0900 1 4.98 0
1 0.02731 0.0 7.07 0 4.9671 2 9.14 0
2 0.02729 0.0 7.07 0 4.9671 2 4.03 1
3 0.03237 0.0 2.18 0 6.0622 3 2.94 1
4 0.06905 0.0 2.18 0 6.0622 3 5.33 1
... ... ... ... ... ... ... ... ...
501 0.06263 0.0 11.93 0 2.4786 1 9.67 0
502 0.04527 0.0 11.93 0 2.2875 1 9.08 0
503 0.06076 0.0 11.93 0 2.1675 1 5.64 0
504 0.10959 0.0 11.93 0 2.3889 1 6.48 0
505 0.04741 0.0 11.93 0 2.5050 1 7.88 0
확인해보면, 초기에 VIF를 확인했을 때, 15개 중 11개가 VIF 10을 초과하여 4개만 남는 결과가 나왔지만, 상위 1개씩 제거한 결과 8개의 Feature가 선택됨을 알 수 있다.
Filter Method를 이용한, 방법으로 Feature를 선택해보았다.
다음 글에서 이어집니다.
2022.01.12 - [공부/모델링] - Forward feature selection (전진선택법) python
'공부 > 통계학' 카테고리의 다른 글
Backward Feature Selection (후진제거법) python (0) | 2022.01.13 |
---|---|
Forward feature selection (전진선택법) python (0) | 2022.01.12 |
비모수 검정 : Mann-Witney U-test (만 - 위트니 U 검정, python) (0) | 2022.01.11 |
웰치의 t 검정 : welch's t test (파이썬) (2) | 2022.01.11 |
등분산 검정 (파이썬) (0) | 2022.01.11 |
댓글