# Source code for gloss.strategies.unexplored
import warnings
import numpy as np
from gloss.utils import compute_min_distances, is_duplicate
# [docs]
def find_unexplored(surrogate, space, n_points, explored, excluded, direction,
                    tolerance=0.0, unexplored_threshold=None, n_random_samples=10000,
                    distance_metric="euclidean"):
    """Find points in unexplored regions with best predicted values.

    Candidates come from ``space.get_candidates_excluding(excluded, tolerance)``
    in discrete mode, or from ``space.sample(n_random_samples)`` otherwise.
    Each candidate's distance to its nearest explored point is computed.
    Only candidates farther than ``unexplored_threshold`` are kept; if none
    qualify, all candidates are kept and a warning is issued. The kept
    candidates are ranked by ``surrogate.predict``. The best ``n_points`` are
    returned, skipping any that ``is_duplicate`` flags against ``excluded``
    or against points already selected in this call.

    Parameters
    ----------
    surrogate
        Model exposing ``predict(X)``; its output is flattened to one score
        per candidate.
    space
        Search space providing ``mode``, ``get_candidates_excluding``,
        ``sample``, ``diagonal`` and ``ndim``.
    n_points : int
        Maximum number of points to return.
    explored : numpy.ndarray
        Already-observed points, one per row.
    excluded : numpy.ndarray
        Points that must not be returned, one per row.
    direction : str
        ``"maximize"`` ranks high predictions first; any other value ranks
        low predictions first.
    tolerance : float, optional
        Forwarded to ``space.get_candidates_excluding`` and ``is_duplicate``.
    unexplored_threshold : float, optional
        Minimum nearest-explored distance for a candidate to count as
        unexplored. By default this is half the mean pairwise distance
        between explored points, measured with the same metric as the
        candidate distances. When fewer than two points have been explored,
        the default is ``0.1 * space.diagonal``.
    n_random_samples : int, optional
        Number of random candidates drawn in non-discrete mode.
    distance_metric : str, optional
        ``"jaccard"`` uses a Jaccard-metric BallTree (requires scikit-learn);
        any other value uses ``compute_min_distances``.

    Returns
    -------
    list of dict
        Dicts with keys ``"point"`` (list), ``"strategy"`` (always
        ``"unexplored"``) and ``"predicted_value"`` (float), ordered from best
        to worst prediction. The list may be shorter than ``n_points``, in
        which case a warning is issued.
    """
    sign = 1.0 if direction == "maximize" else -1.0

    if space.mode == "discrete":
        candidates = space.get_candidates_excluding(excluded, tolerance)
    else:
        candidates = space.sample(n_random_samples)
    if len(candidates) == 0:
        warnings.warn("unexplored: no candidates available.")
        return []

    min_dists = _nearest_explored_distances(candidates, explored, distance_metric)
    if unexplored_threshold is None:
        unexplored_threshold = _default_unexplored_threshold(
            explored, space, distance_metric
        )

    far_candidates = candidates[min_dists > unexplored_threshold]
    if len(far_candidates) == 0:
        warnings.warn(
            f"unexplored: no candidates above threshold {unexplored_threshold:.4f}. "
            "Returning best among all remaining candidates."
        )
        far_candidates = candidates

    # Flatten so a surrogate returning an (n, 1) column still yields exactly
    # one score per candidate.
    preds = np.asarray(surrogate.predict(far_candidates), dtype=float).ravel()
    # Sort by sign * preds descending by sorting its negation ascending.
    # np.argsort places NaN last, so NaN predictions end up ranked worst;
    # reversing an ascending sort would instead rank them best. The stable
    # sort also makes tie-breaking deterministic (lowest index first).
    order = np.argsort(-sign * preds, kind="stable")

    all_excluded = excluded.copy() if excluded.shape[0] > 0 else np.empty((0, space.ndim))
    results = []
    for idx in order:
        if len(results) >= n_points:
            break
        point = far_candidates[idx]
        if is_duplicate(point, all_excluded, tolerance):
            continue
        results.append({
            "point": point.tolist(),
            "strategy": "unexplored",
            "predicted_value": float(preds[idx]),
        })
        # Exclude the accepted point so later picks cannot duplicate it.
        all_excluded = np.vstack([all_excluded, point.reshape(1, -1)])

    if len(results) < n_points:
        warnings.warn(f"unexplored: only found {len(results)}/{n_points} points.")
    return results


def _nearest_explored_distances(candidates, explored, distance_metric):
    """Return, per candidate, the distance to its nearest explored point."""
    if distance_metric != "jaccard":
        return compute_min_distances(candidates, explored)
    if len(explored) == 0:
        # No observed points -> treat everything as unexplored.
        return np.full(len(candidates), np.inf)
    # Imported lazily so scikit-learn is only needed when a tree is built.
    from sklearn.neighbors import BallTree
    tree = BallTree(explored.astype(float), metric="jaccard")
    dists, _ = tree.query(candidates.astype(float), k=1)
    return dists.ravel()


def _default_unexplored_threshold(explored, space, distance_metric):
    """Return half the mean pairwise explored distance, or 10% of the space diagonal."""
    if explored.shape[0] >= 2:
        from scipy.spatial.distance import pdist
        # Measure the spread with the same metric as the candidate distances
        # it is compared against. A Euclidean spread compared with Jaccard
        # distances (which never exceed 1) gives a meaningless cut-off.
        # NOTE(review): assumes compute_min_distances is Euclidean, matching
        # the original pdist default — confirm.
        metric = "jaccard" if distance_metric == "jaccard" else "euclidean"
        return np.mean(pdist(explored.astype(float), metric=metric)) * 0.5
    # NOTE(review): this fallback is in the space's own units even for the
    # Jaccard metric — confirm intended.
    return space.diagonal * 0.1