[AutoML] Multivariate Bayesian Optimization
2021. 8. 4. 13:21ㆍ스터디/AutoML
HPO(Hyper Parameter Optimization) 문제는 다수의 하이퍼 파라미터를 다룬다. 때문에, 단일 변수가 아닌 다변량을 다룰수 있는 코드가 필요하다.
이번 글은 다변량 베이지안 최적화를 수행하는 Python 코드를 공유하고자 한다.
이 때 만들어진 코드는 다음글을 토대로 작성되었다.
* 참고: 다변량이 되니, 결과값의 변차가 꽤 크게 나타남.
from scipy.stats import norm
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
# 1. Acquisition function
def expected_improvement(mean, std, max):
z = (mean - max ) / std
return (mean-max) * norm.cdf(z) + std * norm.pdf(z)
# 2. Objective function
def f(X):
x1 = X[:,0]
x2 = X[:,1]
return x1 * np.sin(x2)
# 3. Hyper-parameter space
min_x1, max_x1 = -2, 10
min_x2, max_x2 = 6, 7
# 4. Observation Data
X1 = np.random.uniform(min_x1, max_x1, 3)
X2 = np.random.uniform(min_x2, max_x2, 3)
X = np.array([X1,X2]).transpose()
y = f(X).ravel()
# 5. Instantiate Gaussian Process model
model = GaussianProcessRegressor(kernel=RBF(1.0))
for i in np.arange(20):
# 6. Fit to data
model.fit(X, y)
# 7. Acquisition Function
x1s = np.random.uniform(min_x1, max_x1, 10000)
x2s = np.random.uniform(min_x2, max_x2, 10000)
xs = np.array([x1s, x2s]).transpose()
mean, std = model.predict(xs, return_std=True)
acq = expected_improvement(mean, std, y.max())
# 8. Query Objective Function
x_new = xs[acq.argmax()].reshape(1,-1)
y_new = f(x_new)
# 9. Augment Data
X = np.append(X, x_new, axis=0)
y = np.append(y, np.array([y_new]))
# 9. Get optimal result
print("Optimal")
print("X: ",X[y.argmax()])
print("y: ",y.max())
수정된 부분은
1. 하이퍼 파라미터(X)의 종류를 1개에서 2개로 늘리고,
2. 이에 따른 값의 처리 방식을 일부 수정하였다.
하이퍼 파라미터의 종류는 x1, x2로 구분된다.
목적함수는 x1 * sin(x2) 로 설정되었으며,
x1은 -2~10, x2는 6~7 사이의 값에서 하이퍼 파라미터가 결정되도록 만들었다.
분포를 얻기 위해 가장 중요한 함수인 GaussianProcessRegressor는 입력값으로 X와 y를 갖는다.
이 때, X의 shape는 (sample #, hyper parameter #)으로
y의 shape는 (sample #)으로 구성된다.
이 같이 차원이 변경된 데이터를 처리하기 위해, 데이터 선택 방식을 일부 수정하였다.
'스터디 > AutoML' 카테고리의 다른 글
[AutoML] Searching algorithm for HPO: Bayesian Optimization (0) | 2021.08.02 |
---|---|
Blackbox HPO(1): Grid search와 Random search (0) | 2020.04.18 |