Balanced clustering with lambda param
jacketsj committed Oct 18, 2024
1 parent 631e9bf commit 83a30a4
Showing 4 changed files with 93 additions and 16 deletions.
python/python/lance/cuvs/kmeans.py (15 changes: 13 additions & 2 deletions)
@@ -56,6 +56,8 @@ def __init__(
         seed: Optional[int] = None,
         device: Optional[str] = None,
         itopk_size: int = 10,
+        balance_factor: Optional[float] = None,
+        cluster_counts: Optional[torch.Tensor] = None,
     ):
         if metric == "dot":
             raise ValueError(
@@ -70,6 +72,8 @@
             centroids=centroids,
             seed=seed,
             device=device,
+            balance_factor=balance_factor,
+            cluster_counts=cluster_counts,
         )

         if self.device.type != "cuda" or not torch.cuda.is_available():
@@ -95,9 +99,13 @@ def fit(
         logging.info("Total rebuild time: %s", self.time_rebuild)

     def rebuild_index(self):
+        centroids = self.centroids
+        if self.balance_factor is not None:
+            self.pad_centroids()
+
         rebuild_time_start = time.time()
         cagra_metric = "sqeuclidean"
-        dim = self.centroids.shape[1]
+        dim = centroids.shape[1]
         graph_degree = max(dim // 4, 32)
         nn_descent_degree = graph_degree * 2
         index_params = cagra.IndexParams(
@@ -107,7 +115,7 @@ def rebuild_index(self):
             build_algo="nn_descent",
             compression=None,
         )
-        self.index = cagra.build(index_params, self.centroids)
+        self.index = cagra.build(index_params, centroids)
         rebuild_time_end = time.time()
         self.time_rebuild += rebuild_time_end - rebuild_time_start
@@ -121,6 +129,9 @@ def _transform(
         if self.metric == "cosine":
             data = torch.nn.functional.normalize(data)

+        if self.padded_centroids is not None:
+            data = self.pad_data(data)
+
         search_time_start = time.time()
         device = torch.device("cuda")
         out_idx = raft_common.device_ndarray.empty((data.shape[0], 1), dtype="uint32")
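As an aside on what balance_factor does: below is a toy example (not part of this commit) showing how a size penalty can flip an assignment from a crowded cluster to an emptier one. Values are arbitrary.

import torch

balance_factor = 1.0
centroids = torch.tensor([[0.0, 0.0], [2.0, 0.0]])
counts = torch.tensor([100.0, 1.0])  # cluster 0 is already crowded

x = torch.tensor([0.9, 0.0])  # slightly nearer to centroid 0
plain = ((x - centroids) ** 2).sum(dim=1)
penalized = plain + balance_factor * counts

print(plain.argmin().item())      # 0: the nearest centroid wins
print(penalized.argmin().item())  # 1: the penalty diverts it to the small cluster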
python/python/lance/dataset.py (6 changes: 6 additions & 0 deletions)
@@ -1472,6 +1472,7 @@ def create_index(
         storage_options: Optional[Dict[str, str]] = None,
         filter_nan: bool = True,
         one_pass_ivfpq: bool = False,
+        balance_factor: Optional[float] = None,
         **kwargs,
     ) -> LanceDataset:
         """Create index on column.
@@ -1534,6 +1535,9 @@
             for nullable columns. Obtains a small speed boost.
         one_pass_ivfpq: bool
             Defaults to False. If enabled, index type must be "IVF_PQ". Reduces disk IO.
+        balance_factor: float, optional
+            A factor used to balance clusters. No balancing by default. 1 is often a
+            good value. Only enabled if using an accelerator.
         kwargs :
             Parameters passed to the index building process.
@@ -1683,6 +1687,7 @@
                     num_sub_vectors=num_sub_vectors,
                     batch_size=20480,
                     filter_nan=filter_nan,
+                    balance_factor=balance_factor,
                 )
             )
             timers["ivf+pq_train:end"] = time.time()
@@ -1783,6 +1788,7 @@ def create_index(
                 metric,
                 accelerator,
                 filter_nan=filter_nan,
+                balance_factor=balance_factor,
             )
             timers["ivf_train:end"] = time.time()
             ivf_train_time = timers["ivf_train:end"] - timers["ivf_train:start"]
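A minimal usage sketch for the new create_index parameter (the dataset URI and sizing values are illustrative, not from this commit):

import lance

ds = lance.dataset("./vectors.lance")
ds.create_index(
    "vector",
    index_type="IVF_PQ",
    num_partitions=256,
    num_sub_vectors=16,
    accelerator="cuda",  # balancing only takes effect with an accelerator
    balance_factor=1.0,  # "1 is often a good value" per the docstring above
)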
python/python/lance/torch/kmeans.py (77 changes: 63 additions & 14 deletions)
@@ -60,6 +60,8 @@ def __init__(
         centroids: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
         device: Optional[str] = None,
+        balance_factor: Optional[float] = None,
+        cluster_counts: Optional[torch.Tensor] = None,
     ):
         self.k = k
         self.max_iters = max_iters
@@ -82,6 +84,13 @@
         self.device = preferred_device(device)
         self.tolerance = tolerance
         self.seed = seed
+        self.balance_factor = balance_factor
+        self.padded_centroids = None
+
+        if cluster_counts is None:
+            self.counts = torch.zeros(k, device=self.device)
+        else:
+            self.counts = cluster_counts

         self.y2 = None

@@ -169,14 +178,12 @@ def fit(
             logging.debug("Total distance: %s, iter: %s", self.total_distance, i)
         logging.info("Finish KMean training in %s", time.time() - start)

-    def _updated_centroids(
-        self, centroids: torch.Tensor, counts: torch.Tensor
-    ) -> torch.Tensor:
-        centroids = centroids / counts[:, None]
-        zero_counts = counts == 0
+    def _updated_centroids(self, centroids: torch.Tensor) -> torch.Tensor:
+        centroids = centroids / self.counts[:, None]
+        zero_counts = self.counts == 0
         for idx in zero_counts.nonzero(as_tuple=False):
             # split the largest cluster and remove empty cluster
-            max_idx = torch.argmax(counts).item()
+            max_idx = torch.argmax(self.counts).item()
             # add 1% gaussian noise to the largest centroid
             # do this twice so we effectively split the largest cluster into 2
             # rand_like returns on [0, 1) so we need to shift it to [-0.5, 0.5)
Expand Down Expand Up @@ -229,9 +236,9 @@ def _fit_once(
new_centroids = torch.zeros_like(
self.centroids, device=self.device, dtype=torch.float32
)
counts_per_part = torch.zeros(self.centroids.shape[0], device=self.device)
ones = torch.ones(1024 * 16, device=self.device)
self.rebuild_index()
self.counts = torch.zeros(self.k, device=self.device)
ones = torch.ones(1024 * 16, device=self.device)
for idx, chunk in enumerate(data):
if idx % 50 == 0:
logging.info("Kmeans::train: epoch %s, chunk %s", epoch, idx)
Expand All @@ -253,7 +260,7 @@ def _fit_once(
ones = torch.ones(len(ids), out=ones, device=self.device)

new_centroids.index_add_(0, ids, chunk.type(torch.float32))
counts_per_part.index_add_(0, ids, ones[: ids.shape[0]])
self.counts.index_add_(0, ids, ones[: ids.shape[0]])
del ids
del dists
del chunk
@@ -274,13 +281,50 @@
             raise StopIteration("kmeans: converged")

         # cast to the type we get the data in
-        self.centroids = self._updated_centroids(new_centroids, counts_per_part).type(
-            dtype
-        )
+        self.centroids = self._updated_centroids(new_centroids).type(dtype)
         return total_dist

+    def pad_centroids(self):
+        if self.metric == "dot":
+            self.padded_centroids = torch.cat(
+                [
+                    self.centroids,
+                    -(self.balance_factor * self.counts).unsqueeze(1),
+                ],
+                dim=1,
+            )
+        else:
+            self.padded_centroids = torch.cat(
+                [
+                    self.centroids,
+                    torch.sqrt(self.balance_factor * self.counts).unsqueeze(1),
+                ],
+                dim=1,
+            )
+        self.y2 = (self.padded_centroids * self.padded_centroids).sum(dim=1)
+
     def rebuild_index(self):
         self.y2 = (self.centroids * self.centroids).sum(dim=1)
+        if self.balance_factor is not None:
+            self.pad_centroids()
+
+    def pad_data(self, data):
+        if self.metric == "dot":
+            return torch.cat(
+                [
+                    data,
+                    torch.ones(data.size(0), 1, device=data.device, dtype=data.dtype),
+                ],
+                dim=1,
+            )
+        else:
+            return torch.cat(
+                [
+                    data,
+                    torch.zeros(data.size(0), 1, device=data.device, dtype=data.dtype),
+                ],
+                dim=1,
+            )
+
     def _transform(
         self,
@@ -290,10 +334,15 @@ def _transform(
         if self.metric == "cosine":
             data = torch.nn.functional.normalize(data)

+        centroids = self.centroids
+        if self.padded_centroids is not None:
+            centroids = self.padded_centroids
+            data = self.pad_data(data)
+
         if self.metric in ["l2", "cosine"]:
-            return self.dist_func(data, self.centroids, y2=y2)
+            return self.dist_func(data, centroids, y2=y2)
         else:
-            return self.dist_func(data, self.centroids)
+            return self.dist_func(data, centroids)

     def transform(
         self, data: Union[pa.Array, np.ndarray, torch.Tensor]
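The pad_centroids/pad_data pair encodes the balancing penalty in the geometry itself: for the euclidean-style metrics each centroid gains an extra coordinate sqrt(balance_factor * count) while each query gains a 0, so the padded squared distance equals ||x - c||^2 + balance_factor * count; for dot, the extra coordinates are -(balance_factor * count) on the centroid and 1 on the query, shifting the score by the same penalty. A standalone sanity check of the L2 identity (arbitrary values, not part of the commit):

import torch

balance_factor, count = 1.0, 37.0
x, c = torch.randn(8), torch.randn(8)

# Pad the query with 0 and the centroid with sqrt(balance_factor * count),
# as pad_data/pad_centroids do for the non-dot metrics.
x_pad = torch.cat([x, torch.zeros(1)])
c_pad = torch.cat([c, torch.tensor([balance_factor * count]).sqrt()])

lhs = ((x_pad - c_pad) ** 2).sum()
rhs = ((x - c) ** 2).sum() + balance_factor * count
assert torch.allclose(lhs, rhs)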
python/python/lance/vector.py (11 changes: 11 additions & 0 deletions)
@@ -210,6 +210,7 @@ def train_ivf_centroids_on_accelerator(
     sample_rate: int = 256,
     max_iters: int = 50,
     filter_nan: bool = True,
+    balance_factor: Optional[float] = None,
 ) -> (np.ndarray, Any):
     """Use accelerator (GPU or MPS) to train kmeans."""

@@ -277,6 +278,7 @@
             metric=metric_type,
             device="cuda",
             centroids=init_centroids,
+            balance_factor=balance_factor,
         )
     else:
         logging.info("Training IVF partitions using GPU(%s)", accelerator)
@@ -286,15 +288,22 @@
             metric=metric_type,
             device=accelerator,
             centroids=init_centroids,
+            balance_factor=balance_factor,
         )
     kmeans.fit(ds)

     centroids = kmeans.centroids.cpu().numpy()
+    counts = kmeans.counts.cpu().numpy()

     with tempfile.NamedTemporaryFile(delete=False) as f:
         np.save(f, centroids)
     logging.info("Saved centroids to %s", f.name)

+    if balance_factor is not None:
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            np.save(f, counts)
+        logging.info("Saved cluster counts to %s", f.name)
+
     return centroids, kmeans

@@ -598,6 +607,7 @@ def one_pass_train_ivf_pq_on_accelerator(
     sample_rate: int = 256,
     max_iters: int = 50,
     filter_nan: bool = True,
+    balance_factor: Optional[float] = None,
 ):
     centroids, kmeans = train_ivf_centroids_on_accelerator(
         dataset,
@@ -609,6 +619,7 @@
         sample_rate=sample_rate,
         max_iters=max_iters,
         filter_nan=filter_nan,
+        balance_factor=balance_factor,
     )
     dataset_residuals = compute_partitions(
         dataset,
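Because the run saves both centroids and, when balancing is enabled, cluster counts, a later process could resume a balanced KMeans from those artifacts via the new cluster_counts argument. A sketch under assumed file names (the paths are hypothetical, since the real code writes to NamedTemporaryFile; the import path follows the file layout above):

import numpy as np
import torch
from lance.torch.kmeans import KMeans

centroids = torch.from_numpy(np.load("/tmp/centroids.npy"))  # hypothetical path
counts = torch.from_numpy(np.load("/tmp/counts.npy"))        # hypothetical path

kmeans = KMeans(
    k=centroids.shape[0],
    metric="l2",
    centroids=centroids,
    balance_factor=1.0,
    cluster_counts=counts,  # seed balancing with previously observed cluster sizes
)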
