-
-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update evolution to CSV format * Update * Update * Update * Update * Update * reset args * reset args * reset args * plot_results() fix * Cleanup * Cleanup2
- Loading branch information
1 parent
4103ce9
commit e78aeac
Showing
6 changed files
with
75 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ | |
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods | ||
from utils.downloads import attempt_download | ||
from utils.loss import ComputeLoss | ||
from utils.plots import plot_labels, plot_evolution | ||
from utils.plots import plot_labels, plot_evolve | ||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel | ||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | ||
from utils.metrics import fitness | ||
|
@@ -367,7 +367,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]] | ||
if fi > best_fitness: | ||
best_fitness = fi | ||
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi) | ||
log_vals = list(mloss) + list(results) + lr | ||
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi) | ||
|
||
# Save model | ||
if (not nosave) or (final_epoch and not evolve): # if save | ||
|
@@ -464,7 +465,7 @@ def main(opt): | |
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop']) | ||
|
||
# Resume | ||
if opt.resume and not check_wandb_resume(opt): # resume an interrupted run | ||
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run | ||
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path | ||
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' | ||
with open(Path(ckpt).parent.parent / 'opt.yaml') as f: | ||
|
@@ -474,8 +475,10 @@ def main(opt): | |
else: | ||
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files | ||
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' | ||
opt.name = 'evolve' if opt.evolve else opt.name | ||
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)) | ||
if opt.evolve: | ||
opt.project = 'runs/evolve' | ||
opt.exist_ok = opt.resume | ||
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) | ||
|
||
# DDP mode | ||
device = select_device(opt.device, batch_size=opt.batch_size) | ||
|
@@ -533,17 +536,17 @@ def main(opt): | |
hyp = yaml.safe_load(f) # load hyps dict | ||
if 'anchors' not in hyp: # anchors commented in hyp.yaml | ||
hyp['anchors'] = 3 | ||
opt.noval, opt.nosave = True, True # only val/save final epoch | ||
opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch | ||
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices | ||
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here | ||
evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv' | ||
if opt.bucket: | ||
os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .') # download evolve.txt if exists | ||
os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists | ||
|
||
for _ in range(opt.evolve): # generations to evolve | ||
if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate | ||
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate | ||
# Select parent(s) | ||
parent = 'single' # parent selection method: 'single' or 'weighted' | ||
x = np.loadtxt('evolve.txt', ndmin=2) | ||
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1) | ||
n = min(5, len(x)) # number of previous results to consider | ||
x = x[np.argsort(-fitness(x))][:n] # top n mutations | ||
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0) | ||
|
@@ -575,12 +578,13 @@ def main(opt): | |
results = train(hyp.copy(), opt, device) | ||
|
||
# Write mutation results | ||
print_mutation(hyp.copy(), results, yaml_file, opt.bucket) | ||
print_mutation(results, hyp.copy(), save_dir, opt.bucket) | ||
|
||
# Plot results | ||
plot_evolution(yaml_file) | ||
print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n' | ||
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}') | ||
plot_evolve(evolve_csv) | ||
print(f'Hyperparameter evolution finished\n' | ||
f"Results saved to {colorstr('bold', save_dir)}" | ||
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}') | ||
|
||
|
||
def run(**kwargs): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -615,35 +615,43 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op | |
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB") | ||
|
||
|
||
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): | ||
# Print mutation results to evolve.txt (for use with train.py --evolve) | ||
a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys | ||
b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values | ||
c = '%10.4g' * len(results) % results # results (P, R, [email protected], [email protected]:0.95, val_losses x 3) | ||
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c)) | ||
def print_mutation(results, hyp, save_dir, bucket): | ||
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml' | ||
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', | ||
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] | ||
keys = tuple(x.strip() for x in keys) | ||
vals = results + tuple(hyp.values()) | ||
n = len(keys) | ||
|
||
# Download (optional) | ||
if bucket: | ||
url = 'gs://%s/evolve.txt' % bucket | ||
if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0): | ||
os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local | ||
url = f'gs://{bucket}/evolve.csv' | ||
if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0): | ||
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local | ||
|
||
# Log to evolve.csv | ||
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header | ||
with open(evolve_csv, 'a') as f: | ||
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n') | ||
|
||
with open('evolve.txt', 'a') as f: # append result | ||
f.write(c + b + '\n') | ||
x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows | ||
x = x[np.argsort(-fitness(x))] # sort | ||
np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness | ||
# Print to screen | ||
print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys)) | ||
print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n') | ||
|
||
# Save yaml | ||
for i, k in enumerate(hyp.keys()): | ||
hyp[k] = float(x[0, i + 7]) | ||
with open(yaml_file, 'w') as f: | ||
results = tuple(x[0, :7]) | ||
c = '%10.4g' * len(results) % results # results (P, R, [email protected], [email protected]:0.95, val_losses x 3) | ||
f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n') | ||
with open(evolve_yaml, 'w') as f: | ||
data = pd.read_csv(evolve_csv) | ||
data = data.rename(columns=lambda x: x.strip()) # strip keys | ||
i = np.argmax(fitness(data.values[:, :7])) # | ||
f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' + | ||
f'# Best generation: {i}\n' + | ||
f'# Last generation: {len(data)}\n' + | ||
f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' + | ||
f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n') | ||
yaml.safe_dump(hyp, f, sort_keys=False) | ||
|
||
if bucket: | ||
os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload | ||
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload | ||
|
||
|
||
def apply_classifier(x, model, img, im0): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters