Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Add augly transformation support (#442)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #442

Add support for augly transformations.

Similar to apex, I made augly install optional for users and didn't add it to requirements.txt -- let me know what you think about this.

Reviewed By: prigoyal, QuentinDuval

Differential Revision: D31462923

fbshipit-source-id: ce793f1adc432b3f1ea08acf4b3f66daa88215a8
  • Loading branch information
iseessel authored and facebook-github-bot committed Oct 8, 2021
1 parent 83d859f commit dd9971a
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 9 deletions.
26 changes: 19 additions & 7 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cpu: &cpu
environment:
TERM: xterm
machine:
image: default
image: ubuntu-1604:201903-01
resource_class: medium

gpu: &gpu
Expand All @@ -37,8 +37,8 @@ install_python: &install_python
working_directory: ~/
command: |
pyenv versions
pyenv install 3.6.2
pyenv global 3.6.2
pyenv install -f 3.7.0
pyenv global 3.7.0
update_gcc7: &update_gcc7
- run:
Expand Down Expand Up @@ -107,6 +107,17 @@ install_vissl_dep: &install_vissl_dep
# Update this since classy_vision seems to need it.
pip install --progress-bar off --upgrade iopath
# Must install python3-magic as per documentation:
# https://github.com/facebookresearch/AugLy#installation
install_augly: &install_augly
- run:
name: Install augly
working_directory: ~/vissl
command: |
pip install augly
sudo apt-get update
sudo apt-get install python3-magic
install_apex_gpu: &install_apex_gpu
- run:
name: Install Apex
Expand Down Expand Up @@ -153,17 +164,18 @@ jobs:
# Cache the vissl_venv directory that contains dependencies
- restore_cache:
keys:
- v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
- v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}

- <<: *install_vissl_dep
- <<: *install_augly
- <<: *install_classy_vision
- <<: *install_apex_cpu
- <<: *pip_list

- save_cache:
paths:
- ~/vissl_venv
key: v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
key: v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}

- <<: *install_vissl

Expand Down Expand Up @@ -196,7 +208,7 @@ jobs:
# Download and cache dependencies
- restore_cache:
keys:
- v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
- v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}

- <<: *install_vissl_dep
- <<: *install_classy_vision
Expand All @@ -211,7 +223,7 @@ jobs:
- save_cache:
paths:
- ~/vissl_venv
key: v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
key: v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}

- <<: *install_vissl

Expand Down
25 changes: 25 additions & 0 deletions configs/config/test/transforms/augly_transforms_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# @package _global_
config:
DATA:
TRAIN:
TRANSFORMS:
- name: ImgReplicatePil
num_times: 2
- name: RandomResizedCrop
size: 224
- name: RandomHorizontalFlip
p: 0.5
- name: ImgPilColorDistortion
strength: 1.0
- name: ImgPilGaussianBlur
p: 0.5
radius_min: 0.1
radius_max: 2.0
- name: Blur
transform_type: "augly"
radius: 2.0
p: 1.0
- name: ToTensor
- name: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
18 changes: 18 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from vissl.data.ssl_transforms.img_pil_to_multicrop import ImgPilToMultiCrop
from vissl.data.ssl_transforms.img_pil_to_tensor import ImgToTensor
from vissl.data.ssl_transforms.mnist_img_pil_to_rgb_mode import MNISTImgPil2RGB
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
from vissl.utils.test_utils import (
in_temporary_directory,
run_integration_test,
)


RAND_TENSOR = (torch.rand((224, 224, 3)) * 255).to(dtype=torch.uint8)
Expand Down Expand Up @@ -77,3 +82,16 @@ def test_img_pil_to_multicrop(self):
self.assertEqual((224, 224), crop.size)
for crop in crops[2:]:
self.assertEqual((96, 96), crop.size)

def test_augly_transforms(self):
cfg = compose_hydra_configuration(
[
"config=test/cpu_test/test_cpu_resnet_simclr.yaml",
"+config/test/transforms=augly_transforms_example",
],
)
args, config = convert_to_attrdict(cfg)

with in_temporary_directory() as _:
# Test that the training runs with an augly transformation.
run_integration_test(config)
42 changes: 40 additions & 2 deletions vissl/data/ssl_transforms/ssl_transforms_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@

from typing import Any, Dict

from classy_vision.dataset.transforms import build_transform, register_transform
from classy_vision.dataset.transforms import (
build_transform as build_classy_transform,
register_transform,
)
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
from vissl.utils.misc import is_augly_available

if is_augly_available():
import augly.image as imaugs # NOQA

# Below the transforms that require passing the labels as well. This is specifc
# to SSL only where we automatically generate the labels for training. All other
Expand Down Expand Up @@ -108,7 +114,7 @@ def __init__(
"""
self.indices = set(indices)
self.name = args["name"]
self.transform = build_transform(args)
self.transform = self._build_transform(args)
self.transform_receives_entire_batch = transform_receives_entire_batch
self.transforms_with_labels = transform_types["TRANSFORMS_WITH_LABELS"]
self.transforms_with_copies = transform_types["TRANSFORMS_WITH_COPIES"]
Expand All @@ -117,6 +123,38 @@ def __init__(
]
self.transforms_with_grouping = transform_types["TRANSFORMS_WITH_GROUPING"]

def _build_transform(self, args):
if "transform_type" not in args:
# Default to classy transform.
return build_classy_transform(args)
elif args["transform_type"] == "augly":
# Build augly transform.
return self._build_augly_transform(args)
else:
raise RuntimeError(
f"Transform type: { args.transform_type } is not supported"
)

def _build_augly_transform(self, args):
assert is_augly_available(), "Please pip install augly."

# the name should be available in augly.image
# if users specify the transform name in snake case,
# we need to convert it to title case.
name = args["name"]

if not hasattr(imaugs, name):
# Try converting name to title case.
name = name.title().replace("_", "")

assert hasattr(imaugs, name), f"{name} isn't a registered tranform for augly."

# Delete superfluous keys.
del args["name"]
del args["transform_type"]

return getattr(imaugs, name)(**args)

def _is_transform_with_labels(self):
"""
_TRANSFORMS_WITH_LABELS = ["ImgRotatePil", "ShuffleImgPatches"]
Expand Down
12 changes: 12 additions & 0 deletions vissl/utils/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from omegaconf import DictConfig, OmegaConf
from vissl.config import AttrDict, check_cfg_version
from vissl.utils.io import save_file
from vissl.utils.misc import is_augly_available


def save_attrdict_to_disk(cfg: AttrDict):
Expand Down Expand Up @@ -462,6 +463,16 @@ def infer_losses_config(cfg):
return cfg


def assert_transforms(cfg):
for transforms in [cfg.DATA.TRAIN.TRANSFORMS, cfg.DATA.TEST.TRANSFORMS]:
for transform in transforms:
if "transform_type" in transform:
assert transform["transform_type"] in [None, "augly"]

if transform["transform_type"] == "augly":
assert is_augly_available(), "Please pip install augly."


def infer_and_assert_hydra_config(cfg):
"""
Infer values of few parameters in the config file using the value of other config parameters
Expand All @@ -480,6 +491,7 @@ def infer_and_assert_hydra_config(cfg):
"""
cfg = infer_losses_config(cfg)
cfg = infer_learning_rate(cfg)
assert_transforms(cfg)

# pass the seed to cfg["MODEL"] so that model init on different nodes can
# use the same seed.
Expand Down
20 changes: 20 additions & 0 deletions vissl/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import random
import sys
import tempfile
import time
from functools import partial, wraps
Expand Down Expand Up @@ -80,6 +81,25 @@ def is_apex_available():
return apex_available


def is_augly_available():
"""
Check if apex is available with simple python imports.
"""
try:
assert sys.version_info >= (
3,
7,
0,
), "Please upgrade your python version to 3.7 or higher to use Augly."

import augly.image # NOQA

augly_available = True
except ImportError:
augly_available = False
return augly_available


def find_free_tcp_port():
"""
Find the free port that can be used for Rendezvous on the local machine.
Expand Down

0 comments on commit dd9971a

Please sign in to comment.