Skip to content

Commit

Permalink
Fix size mismatch error when CatToNumTransform sees only a subset o…
Browse files Browse the repository at this point in the history
…f labels at test time (#446)

Fixes this:

```console
$ python benchmark/data_frame_benchmark.py --scale large --idx 1 --model ExcelFormer --task_type multiclass_classification
Downloading https://archive.ics.uci.edu/static/public/158/poker+hand.zip
Traceback (most recent call last):
  File "<console>", line 1, in <module>
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/transforms/fittable_base_transform.py", line 25, in __call__
    return self.forward(copy.copy(tf))
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/transforms/fittable_base_transform.py", line 88, in forward
    transformed_tf = self._forward(tf)
  File "/home/aki/work/github.com/pyg-team/pytorch-frame/torch_frame/transforms/cat_to_num_transform.py", line 133, in _forward
    (num_classes - 1)] = ((v + target_mean) /
RuntimeError: The size of tensor a (7) must match the size of tensor b (9) at non-singleton dimension 1
```

The benchmark result will be added in a follow-up.
  • Loading branch information
akihironitta authored Sep 6, 2024
1 parent 63cafb7 commit 2285c6a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed size mismatch `RuntimeError` in `transforms.CatToNumTransform` ([#446](https://github.com/pyg-team/pytorch-frame/pull/446))
- Removed CUDA synchronizations from `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432))
- Removed CUDA synchronizations from N/A imputation logic in `nn.StypeEncoder` ([#433](https://github.com/pyg-team/pytorch-frame/pull/433), [#434](https://github.com/pyg-team/pytorch-frame/pull/434))

Expand Down
10 changes: 8 additions & 2 deletions test/transforms/test_cat_to_num_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,16 @@ def test_cat_to_num_transform_on_categorical_only_dataset(with_nan):
# Raise informative error when input tensor frame contains new category
out = transform(tensor_frame)

# ensure different max value of y at test time works
tensor_frame.feat_dict[stype.categorical] = torch.zeros_like(
tensor_frame.feat_dict[stype.categorical])
transform(tensor_frame)


@pytest.mark.parametrize('task_type', [
TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION,
TaskType.BINARY_CLASSIFICATION
TaskType.MULTICLASS_CLASSIFICATION,
TaskType.REGRESSION,
TaskType.BINARY_CLASSIFICATION,
])
def test_cat_to_num_transform_with_loading(task_type):
num_rows = 10
Expand Down
91 changes: 53 additions & 38 deletions torch_frame/transforms/cat_to_num_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@


class CatToNumTransform(FittableBaseTransform):
r"""A transform that encodes the categorical features of
the :class:`TensorFrame` object using target statistics.
The original transform is explained in
https://dl.acm.org/doi/10.1145/507533.507538
Specifically, each categorical feature is transformed
into numerical feature using m-probability estimate,
defined by (n_c + p * m)/ (n + m), where n_c is the
total count of the category, n is the total count,
p is the prior probability and m is a smoothing factor.
r"""Transforms categorical features in :class:`TensorFrame` using target
statistics. The original transform is explained in
`A preprocessing scheme for high-cardinality categorical attributes in
classification and prediction problems
<https://dl.acm.org/doi/10.1145/507533.507538>`_ paper.
Specifically, each categorical feature is transformed into numerical
feature using m-probability estimate, defined by
.. math::
\frac{n_c + p \cdot m}{n + m}
where :math:`n_c` is the count of the category, :math:`n` is the total
count, :math:`p` is the prior probability and :math:`m` is a smoothing
factor.
"""
def _fit(
self,
tf_train: TensorFrame,
col_stats: dict[str, dict[StatType, Any]],
):
) -> None:
if tf_train.y is None:
raise RuntimeError(
"'{self.__class__.__name__}' cannot be used when target column"
Expand All @@ -39,6 +45,7 @@ def _fit(
"columns. No fitting will be performed.")
self._transformed_stats = col_stats
return

tensor = self._replace_nans(tf_train.feat_dict[stype.categorical],
NAStrategy.MOST_FREQUENT)
self.col_stats = col_stats
Expand All @@ -50,16 +57,16 @@ def _fit(
# the number of columns to (num_target_classes - 1). More details can
# be found in https://dl.acm.org/doi/10.1145/507533.507538
if not torch.is_floating_point(tf_train.y) and tf_train.y.max() > 1:
num_classes = tf_train.y.max() + 1
target = F.one_hot(tf_train.y, num_classes)[:, :-1]
self.num_classes = tf_train.y.max() + 1
target = F.one_hot(tf_train.y, self.num_classes)[:, :-1]
self.target_mean = target.float().mean(dim=0)
shape = tf_train.feat_dict[stype.categorical].shape
transformed_tensor = torch.zeros(shape[0],
shape[1] * (num_classes - 1),
num_rows, num_cols = tf_train.feat_dict[stype.categorical].shape
transformed_tensor = torch.zeros(num_rows,
num_cols * (self.num_classes - 1),
dtype=torch.float32,
device=tf_train.device)
else:
num_classes = 2
self.num_classes = 2
target = tf_train.y.unsqueeze(1)
mask = ~torch.isnan(target)
if (~mask).any():
Expand All @@ -76,11 +83,12 @@ def _fit(
device=tf_train.device)
feat = tensor[:, i]
v = torch.index_select(count, 0, feat).unsqueeze(1).repeat(
1, num_classes - 1)
transformed_tensor[:, i * (num_classes - 1):(i + 1) *
(num_classes - 1)] = ((v + self.target_mean) /
(self.data_size + 1))
columns += [col_name + f"_{i}" for i in range(num_classes - 1)]
1, self.num_classes - 1)
start = i * (self.num_classes - 1)
end = (i + 1) * (self.num_classes - 1)
transformed_tensor[:, start:end] = ((v + self.target_mean) /
(self.data_size + 1))
columns += [f"{col_name}_{i}" for i in range(self.num_classes - 1)]

self.new_columns = columns
transformed_df = pd.DataFrame(transformed_tensor.cpu().numpy(),
Expand All @@ -104,34 +112,41 @@ def _forward(self, tf: TensorFrame) -> TensorFrame:
"The input TensorFrame does not contain any categorical "
"columns. The original TensorFrame will be returned.")
return tf
tensor = self._replace_nans(tf.feat_dict[stype.categorical],
NAStrategy.MOST_FREQUENT)
tensor = self._replace_nans(
tf.feat_dict[stype.categorical],
NAStrategy.MOST_FREQUENT,
)
if not torch.is_floating_point(tf.y) and tf.y.max() > 1:
num_classes = tf.y.max() + 1
shape = tf.feat_dict[stype.categorical].shape
transformed_tensor = torch.zeros(shape[0],
shape[1] * (num_classes - 1),
dtype=torch.float32,
device=tf.device)
num_rows, num_cols = tf.feat_dict[stype.categorical].shape
transformed_tensor = torch.zeros(
num_rows,
num_cols * (self.num_classes - 1),
dtype=torch.float32,
device=tf.device,
)
else:
num_classes = 2
transformed_tensor = torch.zeros_like(
tf.feat_dict[stype.categorical], dtype=torch.float32)
tf.feat_dict[stype.categorical],
dtype=torch.float32,
)
target_mean = self.target_mean.to(tf.device)
for i in range(len(tf.col_names_dict[stype.categorical])):
col_name = tf.col_names_dict[stype.categorical][i]
count = torch.tensor(self.col_stats[col_name][StatType.COUNT][1],
device=tf.device)
count = torch.tensor(
self.col_stats[col_name][StatType.COUNT][1],
device=tf.device,
)
feat = tensor[:, i]
max_cat = feat.max()
if max_cat >= len(count):
raise RuntimeError(
f"{col_name} contains new category {max_cat} not seen "
f"'{col_name}' contains new category '{max_cat}' not seen "
f"during fit stage.")
v = count[feat].unsqueeze(1).repeat(1, num_classes - 1)
transformed_tensor[:, i * (num_classes - 1):(i + 1) *
(num_classes - 1)] = ((v + target_mean) /
(self.data_size + 1))
v = count[feat].unsqueeze(1).repeat(1, self.num_classes - 1)
start = i * (self.num_classes - 1)
end = (i + 1) * (self.num_classes - 1)
transformed_tensor[:, start:end] = ((v + target_mean) /
(self.data_size + 1))

# turn the categorical features into numerical features
if stype.numerical in tf.feat_dict:
Expand Down

0 comments on commit 2285c6a

Please sign in to comment.