* Init Commit
* New wandb integration
* Use data_dict in test
* Update: scope of log_img
* Update: fix logging conditions
* Add tqdm bar, support for .txt dataset format
* Improve result table logger
* Add dataset creation in training script
* Change scope: self.wandb_run
* Add wandb-artifact:// natively; you can now use --resume with wandb run links
* Add support for logging dataset while training
* Cleanup
* Fix: merge conflict
* Fix: CI tests
* Automatically use wandb config
* Fix: resume
* Fix: CI
* Enhance: using val_table
* More resume enhancement
* Add alias
* Get useful opt config data
* Clean up train.py
* Cleanup | CI fix
* Reformat using PEP8
* Rebase; remove unnecessary changes
* Remove unnecessary change from test.py
* Fix: resume from local checkpoint
* Performance improvement
* Fix local resume
* Fix: CI
* Improve image logging
* Redo CI tests
* Remember epochs when resuming
* Update DDP location (potential fix for #2405)
* 0.25 confidence threshold
* Reset train.py plots syntax to previous
* Reset epochs-completed syntax to previous
* Reset spacing and comments to previous; remove brackets
* Update: is_coco check, remove unused code
* Remove redundant print statement
* Remove wandb imports
* Remove dsviz logger from test.py
* Remove redundant changes from test.py and train.py
* Reformat and improvements
* Fix typo
* Add tqdm progress when scanning files; naming improvements

Co-authored-by: Glenn Jocher <[email protected]>
commit e8fc97a · 1 parent ed2c742 · Showing 5 changed files with 282 additions and 168 deletions.
test.py
@@ -35,8 +35,9 @@ def test(data,
          save_hybrid=False,  # for hybrid auto-labelling
          save_conf=False,  # save auto-label confidences
          plots=True,
-         log_imgs=0,  # number of logged images
-         compute_loss=None):
+         wandb_logger=None,
+         compute_loss=None,
+         is_coco=False):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
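The call-site contract changes here: instead of a bare log_imgs count, test() now receives the whole logger object plus an explicit is_coco flag. A minimal sketch of a standalone call under the new signature (argument values are illustrative, not defaults); with wandb_logger left at None, log_imgs stays 0 below and every W&B branch is skipped:

    # Hypothetical standalone invocation of the new signature (values illustrative)
    import test  # yolov5 test.py, as train.py itself imports it

    results, maps, times = test.test('data/coco128.yaml',
                                     weights='yolov5s.pt',
                                     batch_size=32,
                                     imgsz=640,
                                     wandb_logger=None,  # default: W&B logging disabled
                                     is_coco=False)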
@@ -66,21 +67,19 @@ def test(data,

     # Configure
     model.eval()
-    is_coco = data.endswith('coco.yaml')  # is COCO dataset
-    with open(data) as f:
-        data = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
+    if isinstance(data, str):
+        is_coco = data.endswith('coco.yaml')
+        with open(data) as f:
+            data = yaml.load(f, Loader=yaml.SafeLoader)
     check_dataset(data)  # check
     nc = 1 if single_cls else int(data['nc'])  # number of classes
     iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
     niou = iouv.numel()

     # Logging
-    log_imgs, wandb = min(log_imgs, 100), None  # ceil
-    try:
-        import wandb  # Weights & Biases
-    except ImportError:
-        log_imgs = 0
+    log_imgs = 0
+    if wandb_logger and wandb_logger.wandb:
+        log_imgs = min(wandb_logger.log_imgs, 100)
     # Dataloader
     if not training:
         if device.type != 'cpu':
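The isinstance guard exists because test() now has two kinds of callers: the CLI still passes a YAML path, while train.py passes the already-parsed (and possibly W&B-updated) data_dict, which must not be re-opened. A small self-contained illustration of the two call styles the guard supports; load_data here is a hypothetical helper mirroring the logic, not a function from the repo:

    # Hypothetical helper mirroring the guard above
    import yaml

    def load_data(data):
        # Strings are treated as YAML paths; dicts pass through untouched
        if isinstance(data, str):
            is_coco = data.endswith('coco.yaml')
            with open(data) as f:
                data = yaml.safe_load(f)
        else:
            is_coco = False  # in the real code the caller supplies this flag
        return data, is_coco

    cfg, is_coco = load_data({'train': 'images/train', 'val': 'images/val', 'nc': 2})
    print(cfg['nc'], is_coco)  # 2 False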
@@ -147,15 +146,17 @@ def test(data,
                     with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
                         f.write(('%g ' * len(line)).rstrip() % line + '\n')

-            # W&B logging
-            if plots and len(wandb_images) < log_imgs:
-                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
-                             "class_id": int(cls),
-                             "box_caption": "%s %.3f" % (names[cls], conf),
-                             "scores": {"class_score": conf},
-                             "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
-                boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-                wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
+            # W&B logging - Media Panel Plots
+            if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0:  # Check for test operation
+                if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
+                    box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                                 "class_id": int(cls),
+                                 "box_caption": "%s %.3f" % (names[cls], conf),
+                                 "scores": {"class_score": conf},
+                                 "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
+                    boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
+                    wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
+            wandb_logger.log_training_progress(predn, path, names)  # logs dsviz tables

             # Append to pycocotools JSON dictionary
             if save_json:
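The per-image dicts built here follow W&B's documented bounding-box overlay format. A standalone example of the same structure, independent of YOLOv5 (the project name, image, and box values are all illustrative):

    # Minimal standalone demo of the W&B bbox overlay format used above
    import numpy as np
    import wandb

    run = wandb.init(project='bbox-demo')  # illustrative project name
    img = np.zeros((640, 640, 3), dtype=np.uint8)  # placeholder image
    box_data = [{"position": {"minX": 100, "minY": 120, "maxX": 300, "maxY": 360},
                 "class_id": 0,
                 "box_caption": "person 0.913",
                 "scores": {"class_score": 0.913},
                 "domain": "pixel"}]  # "pixel" = absolute coords, not 0-1 normalized
    boxes = {"predictions": {"box_data": box_data, "class_labels": {0: "person"}}}
    run.log({"examples": wandb.Image(img, boxes=boxes, caption="demo.jpg")})
    run.finish()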
@@ -239,9 +240,11 @@ def test(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-        if wandb and wandb.run:
-            val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
-            wandb.log({"Images": wandb_images, "Validation": val_batches}, commit=False)
+        if wandb_logger and wandb_logger.wandb:
+            val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
+            wandb_logger.log({"Validation": val_batches})
+            if wandb_images:
+                wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

     # Save JSON
     if save_json and len(jdict):
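Across this file the new wandb_logger parameter is assumed to expose a small surface: a .wandb handle (the imported module, or None), .wandb_run, .log_imgs, .current_epoch, .bbox_interval, and .log()/.log_training_progress() methods. A hedged stand-in sketch of that interface, inferred purely from these call sites and not the repo's actual utils/wandb_logging/wandb_utils.py:

    # Minimal stand-in for the interface test.py relies on; inferred from the
    # call sites in this diff, NOT the actual WandbLogger implementation.
    class WandbLoggerStub:
        def __init__(self, log_imgs=16, bbox_interval=1):
            try:
                import wandb
                self.wandb = wandb              # module handle, or None if absent
            except ImportError:
                self.wandb = None
            self.wandb_run = None               # set by wandb.init() in the real class
            self.log_imgs = log_imgs            # requested media-panel image count
            self.current_epoch = 0              # train.py sets this before each test()
            self.bbox_interval = bbox_interval  # epochs between bbox image logs
            self.log_dict = {}                  # metrics buffered within an epoch

        def log(self, d):
            self.log_dict.update(d)             # buffer; flushed once per epoch

        def log_training_progress(self, predn, path, names):
            pass                                # real class fills a W&B dsviz table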
train.py

@@ -1,3 +1,4 @@
 import argparse
 import logging
 import math
@@ -33,11 +34,12 @@
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
+from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file

 logger = logging.getLogger(__name__)


-def train(hyp, opt, device, tb_writer=None, wandb=None):
+def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
@@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     init_seeds(2 + rank)
     with open(opt.data) as f:
         data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
-    with torch_distributed_zero_first(rank):
-        check_dataset(data_dict)  # check
-    train_path = data_dict['train']
-    test_path = data_dict['val']
+    is_coco = opt.data.endswith('coco.yaml')

+    # Logging- Doing this before checking the dataset. Might update data_dict
+    if rank in [-1, 0]:
+        opt.hyp = hyp  # add hyperparameters
+        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
+        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
+        data_dict = wandb_logger.data_dict
+        if wandb_logger.wandb:
+            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
+        loggers = {'wandb': wandb_logger.wandb}  # loggers dict
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
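The run_id lookup above is one half of a handshake: checkpoints saved later in this file carry the active W&B run id, so a restart from last.pt can re-attach to the same run. A hedged sketch of the read side (the path is illustrative):

    # Each checkpoint carries the W&B run id of the run that produced it
    import torch

    ckpt = torch.load('runs/train/exp/weights/last.pt', map_location='cpu')
    run_id = ckpt.get('wandb_id')  # None when the run was trained without W&B
    # WandbLogger(opt, 'exp', run_id, data_dict) then hands this id to wandb.init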
@@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
         model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    with torch_distributed_zero_first(rank):
+        check_dataset(data_dict)  # check
+    train_path = data_dict['train']
+    test_path = data_dict['val']

     # Freeze
     freeze = []  # parameter names to freeze (full or partial)
@@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
     # plot_lr_scheduler(optimizer, scheduler, epochs)

-    # Logging
-    if rank in [-1, 0] and wandb and wandb.run is None:
-        opt.hyp = hyp  # add hyperparameters
-        wandb_run = wandb.init(config=opt, resume="allow",
-                               project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
-                               name=save_dir.stem,
-                               entity=opt.entity,
-                               id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
-        loggers = {'wandb': wandb}  # loggers dict
-
     # EMA
     ema = ModelEMA(model) if rank in [-1, 0] else None
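The inline wandb.init removed here (with its fragile 'ckpt' in locals() check) moves into WandbLogger, but the underlying resume mechanism is unchanged wandb API. A short illustration of those semantics (project, name, and id values are examples):

    # wandb.init resume semantics relied on by both old and new code
    import wandb

    run = wandb.init(project='YOLOv5', name='exp', resume='allow', id='1a2b3c4d')
    # With resume='allow' and a previously used id, W&B re-opens that run and
    # appends new history rows instead of starting a fresh run.
    run.finish()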
@@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 # if tb_writer:
                 #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                 #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
-            elif plots and ni == 10 and wandb:
-                wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
-                                       if x.exists()]}, commit=False)
+            elif plots and ni == 10 and wandb_logger.wandb:
+                wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+                                              save_dir.glob('train*.jpg') if x.exists()]})

             # end batch ------------------------------------------------------------------------------------------------
         # end epoch ----------------------------------------------------------------------------------------------------
@@ -343,17 +346,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
-                results, maps, times = test.test(opt.data,
-                                                 batch_size=batch_size * 2,
+                wandb_logger.current_epoch = epoch + 1
+                results, maps, times = test.test(data_dict,
+                                                 batch_size=total_batch_size,
                                                  imgsz=imgsz_test,
                                                  model=ema.ema,
                                                  single_cls=opt.single_cls,
                                                  dataloader=testloader,
                                                  save_dir=save_dir,
                                                  verbose=nc < 50 and final_epoch,
                                                  plots=plots and final_epoch,
-                                                 log_imgs=opt.log_imgs if wandb else 0,
-                                                 compute_loss=compute_loss)
+                                                 wandb_logger=wandb_logger,
+                                                 compute_loss=compute_loss,
+                                                 is_coco=is_coco)

             # Write
             with open(results_file, 'a') as f:
@@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if tb_writer:
                     tb_writer.add_scalar(tag, x, epoch)  # tensorboard
-                if wandb:
-                    wandb.log({tag: x}, step=epoch, commit=tag == tags[-1])  # W&B
+                if wandb_logger.wandb:
+                    wandb_logger.log({tag: x})  # W&B

             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
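The removed call leaned on wandb.log's commit flag: commit=False buffers values into the current step, and the final tag commits the row. The new wandb_logger.log instead buffers metrics in a dict and flushes once per epoch via end_epoch. A short illustration of the underlying wandb semantics (real API; project and values are examples):

    import wandb

    run = wandb.init(project='demo')                 # illustrative project name
    run.log({'train/box_loss': 0.05}, commit=False)  # buffered: no history row yet
    run.log({'metrics/mAP_0.5': 0.31})               # commits both values as one step
    run.finish()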
@@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_run.id if wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
+                if wandb_logger.wandb:
+                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
+                        wandb_logger.log_model(
+                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt

+            wandb_logger.end_epoch(best_result=best_fitness == fi)
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training

     if rank in [-1, 0]:
-        # Strip optimizers
-        final = best if best.exists() else last  # final model
-        for f in last, best:
-            if f.exists():
-                strip_optimizer(f)
-        if opt.bucket:
-            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
-
         # Plots
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb:
+            if wandb_logger.wandb:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
-                                       if (save_dir / f).exists()]})
-            if opt.log_artifacts:
-                wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
+                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                                              if (save_dir / f).exists()]})

         # Test best.pt
         logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
         if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
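log_model(last.parent, opt, epoch, fi, best_model=...) is expected to upload the checkpoint directory as a versioned W&B artifact every --save_period epochs. A hedged sketch of roughly what such a helper does (names, aliases, and metadata keys are illustrative, not the repo's exact code; assumes an active wandb run):

    # Hypothetical periodic checkpoint upload via W&B artifacts
    import wandb

    def log_model(weights_dir, epoch, fitness, best_model=False):
        # weights_dir: pathlib.Path to the directory containing last.pt
        artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model',
                                  metadata={'epoch': epoch, 'fitness': float(fitness)})
        artifact.add_file(str(weights_dir / 'last.pt'), name='last.pt')
        aliases = ['latest', 'epoch ' + str(epoch)] + (['best'] if best_model else [])
        wandb.log_artifact(artifact, aliases=aliases)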
@@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                           dataloader=testloader,
                                           save_dir=save_dir,
                                           save_json=True,
-                                          plots=False)
+                                          plots=False,
+                                          is_coco=is_coco)

+        # Strip optimizers
+        final = best if best.exists() else last  # final model
+        for f in last, best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+        if opt.bucket:
+            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
+        if wandb_logger.wandb:  # Log the stripped model
+            wandb_logger.wandb.log_artifact(str(final), type='model',
+                                            name='run_' + wandb_logger.wandb_run.id + '_model',
+                                            aliases=['last', 'best', 'stripped'])
     else:
         dist.destroy_process_group()

-    wandb.run.finish() if wandb and wandb.run else None
     torch.cuda.empty_cache()
+    wandb_logger.finish_run()
     return results
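Because the stripped model is logged with the aliases 'last', 'best', and 'stripped', it can later be fetched by alias rather than by file path. A hedged usage sketch (the artifact path components are illustrative):

    # Pulling the stripped model back down by alias in a later session
    import wandb

    run = wandb.init(project='YOLOv5', job_type='download')  # illustrative
    artifact = run.use_artifact('run_1a2b3c4d_model:stripped', type='model')
    weights_dir = artifact.download()  # local dir containing the stripped weights
    run.finish()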
@@ -464,15 +473,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
-    parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
     parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--quad', action='store_true', help='quad dataloader')
     parser.add_argument('--linear-lr', action='store_true', help='linear LR')
+    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
+    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
+    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
+    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     opt = parser.parse_args()

     # Set DDP variables
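The new flags map straight onto the logger: --upload_dataset versions the dataset as a W&B artifact table, --bbox_interval feeds wandb_logger.bbox_interval for the media-panel logging in test.py, --save_period drives the periodic log_model calls, and --artifact_alias selects which version of a dataset artifact to train on. An illustrative invocation (all values are examples, not defaults): python train.py --data coco128.yaml --weights yolov5s.pt --upload_dataset --bbox_interval 1 --save_period 5 --artifact_alias latest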
@@ -484,7 +495,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     check_requirements()

     # Resume
-    if opt.resume:  # resume an interrupted run
+    wandb_run = resume_and_get_id(opt)
+    if opt.resume and not wandb_run:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         apriori = opt.global_rank, opt.local_rank
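resume_and_get_id short-circuits the local-checkpoint resume path when --resume points at a W&B run rather than a file, matching the commit message's "you can now use --resume with wandb run links". A heavily hedged sketch of what it appears to do; the prefix constant and parsing below are assumptions inferred from that message, not the repo's actual code:

    # Hypothetical shape of resume_and_get_id (assumptions, not the real helper)
    WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'

    def resume_and_get_id(opt):
        if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            # e.g. wandb-artifact://entity/project/run_id
            entity, project, run_id = opt.resume[len(WANDB_ARTIFACT_PREFIX):].split('/')[:3]
            opt.project = project  # keep logging in the original project
            return run_id          # WandbLogger re-attaches to this run
        return None                # fall back to the local-checkpoint path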
@@ -517,18 +529,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     # Train
     logger.info(opt)
-    try:
-        import wandb
-    except ImportError:
-        wandb = None
-        prefix = colorstr('wandb: ')
-        logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
     if not opt.evolve:
         tb_writer = None  # init loggers
         if opt.global_rank in [-1, 0]:
             logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
             tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(hyp, opt, device, tb_writer, wandb)
+        train(hyp, opt, device, tb_writer)

     # Evolve hyperparameters (optional)
     else:
@@ -602,7 +608,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 hyp[k] = round(hyp[k], 5)  # significant digits

             # Train mutation
-            results = train(hyp.copy(), opt, device, wandb=wandb)
+            results = train(hyp.copy(), opt, device)

             # Write mutation results
             print_mutation(hyp.copy(), results, yaml_file, opt.bucket)