Skip to content

Commit

Permalink
Merge pull request #131 from FrancescoSaverioZuppichini/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
FrancescoSaverioZuppichini authored Oct 26, 2020
2 parents 518575c + 03604ff commit e3b58b9
Show file tree
Hide file tree
Showing 8 changed files with 904 additions and 793 deletions.
69 changes: 0 additions & 69 deletions benchmark.ipynb

This file was deleted.

150 changes: 94 additions & 56 deletions glasses/utils/PretrainedWeightsProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Tuple
from typing import Callable
from functools import wraps
logging.basicConfig( level=logging.INFO)
logging.basicConfig(level=logging.INFO)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
Expand Down Expand Up @@ -50,6 +50,7 @@ def transform(self):

def pretrained(name: str = None) -> Callable:
_name = name

def decorator(func: Callable) -> Callable:
"""Decorator to fetch the pretrained model.
Expand All @@ -70,6 +71,52 @@ def wrapper(*args, pretrained: bool = False, **kwargs) -> Callable:
return wrapper
return decorator


@dataclass
class BasicUrlHandler:
url: str

def get_response(self) -> requests.Request:
r = requests.get(self.url, stream=True)
return r

def __call__(self, save_path: Path, chunk_size: int = 1024) -> Path:
r = self.get_response()

with open(save_path, 'wb') as f:
total_length = sys.getsizeof(r.content)
bar = tqdm(r.iter_content(chunk_size=chunk_size),
total=total_length // chunk_size)
for chunk in bar:
if chunk:
f.write(chunk)
f.flush()


class GoogleDriveUrlHandler(BasicUrlHandler):

def __init__(self, url: str, file_id: str):
super().__init__(url)
self.file_id = file_id

def get_confirm_token(self, response: requests.Request) -> object:
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None

def get_response(self) -> requests.Request:
session = requests.Session()

response = session.get(self.url, params = { 'id' : self.file_id }, stream = True)
token = self.get_confirm_token(response)

if token:
params = { 'id' : self.file_id, 'confirm' : token }
response = session.get(self.url, params = params, stream = True)

return response

@dataclass
class PretrainedWeightsProvider:
"""
Expand All @@ -83,74 +130,65 @@ class PretrainedWeightsProvider:
>>> provider = PretrainedWeightsProvider(override=True) # override model even if already downloaded
"""

# zoo = {
# 'resnet18':ResNet.resnet18,
# 'resnet34':ResNet.resnet34,
# 'resnet50':ResNet.resnet50,
# 'resnet101': ResNet.resnet101,
# 'resnet152': ResNet.resnet152,

# 'resnext50_32x4d': ResNetXt.resnext50_32x4d,
# 'resnext101_32x8d': ResNetXt.resnext101_32x8d,
# 'wide_resnet50_2': WideResNet.wide_resnet50_2,
# 'wide_resnet101_2': WideResNet.wide_resnet101_2,

# 'densenet121': DenseNet.densenet121,
# 'densenet169': DenseNet.densenet169,
# 'densenet201': DenseNet.densenet201,
# 'densenet161': DenseNet.densenet161,
# 'vgg11': VGG.vgg11,
# 'vgg13': VGG.vgg13,
# 'vgg16': VGG.vgg16,
# 'vgg19': VGG.vgg19,

# 'mobilenet_v2': MobileNetV2,

# 'efficientnet-b0': EfficientNet.b0,
# 'efficientnet-b1': EfficientNet.b1,
# 'efficientnet-b2': EfficientNet.b2,
# 'efficientnet-b3': EfficientNet.b3,
# 'efficientnet-b4': EfficientNet.b4,
# 'efficientnet-b5': EfficientNet.b5,
# 'efficientnet-b6': EfficientNet.b6,
# 'efficientnet-b7': EfficientNet.b7,

# }

BASE_URL = 'https://cv-glasses.s3.eu-central-1.amazonaws.com'
BASE_DIR = Path(f"{os.environ['HOME']}/models_weights")
BASE_URL: str = 'https://cv-glasses.s3.eu-central-1.amazonaws.com'
BASE_DIR: Path = Path(torch.hub.get_dir()) / Path('glasses')
save_dir: Path = BASE_DIR
chunk_size: int = 1024
chunk_size: int = 1024 * 1
verbose: int = 0
override: bool = False

def __post_init__(self):
self.save_dir.mkdir(exist_ok=True)

def download_weight(self, url: str, save_path: Path) -> Path:
r = requests.get(url, stream=True)
weights_zoo = {
'resnet18': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/resnet18.pth?raw=true'),
'resnet26': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/resnet26.pth?raw=true'),
'resnet34': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/resnet34.pth?raw=true'),
'mobilenet_v2': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/mobilenet_v2.pth?raw=true'),
'efficientnet_b0': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/efficientnet_b0.pth?raw=true'),
'efficientnet_b1': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/efficientnet_b1.pth?raw=true'),
'efficientnet_b2': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/efficientnet_b2.pth?raw=true'),
'efficientnet_b3': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/efficientnet_b3.pth?raw=true'),
'densenet121': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/densenet121.pth?raw=true'),
'densenet169': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/densenet169.pth?raw=true'),
'densenet201': BasicUrlHandler('https://github.com/FrancescoSaverioZuppichini/glasses/blob/feature/weights/weights/densenet201.pth?raw=true'),
# from google drive
'resnet50': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1DYXJ12tLb-W687Wa9MWfvyarlz52cyD3'),
'cse_resnet50': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1CMyib_ACsWUIbXa7KjX3NXKkfAQyNnLd'),
'resnet101': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '14q5m53eYqQOPb1ZQYHButFW_g9Ec5pmR'),
'resnet152': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1d-EGQi-HGFNXEdQE7cVzXvmFl9iZAd-F'),
'resnext50_32x4d': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1lvV5v-WT0YBLSB9j3beGs8cV7Qc3ecEg'),
'resnext101_32x8d': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1y4GfcknrznFhMdMsbZwdZBYPN6UUNP2H'),
'wide_resnet50_2': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1or9L8aO7QDU0haP1pdbwGPrSLiTRQkqa'),
'wide_resnet101_2': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1VUvWd6MF7ySDx7kQH3siJjtpmxxw5LE8'),
'vgg11': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1dnlUB4ew8EdLMTVa9xS0CpyXVICpEls2'),
'vgg13': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1X87UaYvENTuLRD94TP8PJE0h-7su3lJL'),
'vgg16': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1yER36sIvoYZXRY_QHgk9sX6ecMP6t2h7'),
'vgg19': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1VWBABqCyrlqNXlacS5lHjDCWGkbwSIYZ'),
'vgg11_bn': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1HCqOnxN2RCyRvUy8pQ_5XnAqH3zI1bTp'),
'vgg13_bn': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1YlttLo-9VDgXq03gdnkJ8NIBEMpYwejN'),
'vgg16_bn': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1X6dvcZYPQcwTGlQ1S87pOUuCmTUhP1zj'),
'vgg19_bn': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '1rHNKV8MgES-7PXYdzarI23MklRMpxaOm'),
'densenet161': GoogleDriveUrlHandler('https://docs.google.com/uc?export=download', file_id = '153fMUorCUGSl4pKSA4tzduaI6BFG7hu5'),

}

with open(save_path, 'wb') as f:
total_length = sys.getsizeof(r.content)
bar = tqdm(r.iter_content(chunk_size=self.chunk_size),
total=total_length // self.chunk_size)
for chunk in bar:
if chunk:
f.write(chunk)
f.flush()
def __post_init__(self):
try:
self.save_dir.mkdir(exist_ok=True)
except FileNotFoundError:
self.save_dir = Path(os.environ['HOME']) / Path('.glasses/')
self.save_dir.mkdir(exist_ok=True)

def __getitem__(self, key: str) -> dict:
# if key not in self.zoo:
# raise KeyError(
# f'No weights for model "{key}". Available models are {",".join(list(self.zoo_models_mapping.keys()))}')
if key not in self.weights_zoo:
raise KeyError(
f'No weights for model "{key}". Available models are {",".join(list(self.weights_zoo.keys()))}')

save_path = self.save_dir / f'{key}.pth'

should_download = not save_path.exists()

if should_download or self.override:
url = f'{self.BASE_URL}/{key}.pth'
self.download_weight(url, save_path)
handler = self.weights_zoo[key]
handler(save_path, self.chunk_size)

weights = torch.load(save_path)
logging.info(f'Loaded {key} pretrained weights.')
Expand Down
Loading

0 comments on commit e3b58b9

Please sign in to comment.