[ML] Random Forest
Random Forest
import sklearn.emsemble from RandomForestClassifier
model = RandomForestClassifier()
model.fit(X_train, y_train)
Random Forest 머신러닝 알고리즘의 하나이다.
Random Forest는 여러 개의 Decision tree로 숲을 이룬다.
몇몇의 나무들이 overfitting을 보일 수 있지만, 나무들이 많아 그 영향력이 적어 일반화 성능이 향상된다.
Bagging
Bagging은 Bootstrap Aggregating의 준말이다.
학습 데이터에서 임의의 개수로 데이터를 선택하여 나무를 만드는 것을 Bagging이라고 한다.
임의의 개수는 전체 학습 데이터 개수(n)의 sqrt(n) 만큼 선택한다.
- 복원추출 :
Random subspace
Bagging을 통해 얻어진 데이터를 가지고 나무의 가지를 만들어야 합니다.
이때, 모든 데이터가 아닌 일부 데이터를 가지고 가지치기를 한다.
OOB(Out-of-Bag)
bagging으로 한 번도 선택되지 않을 확률은 36.8% 정도이다.
bagging에서 선택되지 않은 데이터를 검증 데이터로 활용한다.
model = RandomForestClassifier(oob_score=True)
model.fit(train_data, train_label)
acc = model.score(train_data, train_label)
print(f"Accuracy : {acc}")
print(f"OOB Score : {model.oob_score_}")
Accuracy : 0.9820426487093153
OOB Score : 0.8114478114478114
Kaggle 타이타닉 생존자 제출했을 때 점수가 낮은 이유가 있었다.
파라미터
파라미터 | 설명 | Default |
---|---|---|
n_estimators | 나무 개수 지정 | 100 |
max_features | 특성 개수 | auto |
n_jobs | CPU 코어 수 | None |
oob_score | 모델 검증을 위해 OOB 데이터 활용 여부 | False |
-
n_estimators : 클수록 좋다
– 나무가 많아 과적합을 줄일 수 있다.
– 나무가 많으면.. 숲이 커진다. 모든 나무를 보려면 시간과 메모리가 필요하다. -
max_features : 작을수록 좋다
– 일반적으로 기본 값을 쓰는 것이 좋다.
– ‘Default : auto’의 경우 분류에서는 sqrt(n_features)를 회귀에서는 n_features를 의미한다.
– 각 노드에서 가지를 칠 때 bootstrap을 통해 얻어진 데이터 중에서 몇 개의 특성만을 가지고 가지치기를 할 것인가를 의미? -
n_jobs : 그냥 모든 core를 다 쓰고 싶으면 ‘-1’
– ‘Default : None’은 1을 의미한다.
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
X, y = datasets.make_classification(n_samples=1000)
def train_data(model, X=X, y=y) :
clf = model
clf.fit(X, y)
model = RandomForestClassifier()
time = %timeit -o train_data(model)
model = RandomForestClassifier(n_jobs=-1)
time = %timeit -o train_data(model)
model = RandomForestClassifier(n_jobs=-1, n_estimators=200)
time = %timeit -o train_data(model)
196 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
102 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
197 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)