#!/usr/bin/env python
"""Train a radiance field with nerfstudio.
For real captures, we recommend using the [bright_yellow]nerfacto[/bright_yellow] model.
Nerfstudio allows for customizing your training and eval configs from the CLI in a powerful way, but there are some
things to understand.
The most demonstrative and helpful example of the CLI structure is the difference in output between the following
commands:
ns-train -h
ns-train nerfacto -h nerfstudio-data
ns-train nerfacto nerfstudio-data -h
In each of these examples, the -h applies to the previous subcommand (ns-train, nerfacto, and nerfstudio-data).
In the first example, we get the help menu for the ns-train script.
In the second example, we get the help menu for the nerfacto model.
In the third example, we get the help menu for the nerfstudio-data dataparser.
With our scripts, your arguments will apply to the preceding subcommand in your command, and thus where you put your
arguments matters! Any optional arguments you discover from running
ns-train nerfacto -h nerfstudio-data
need to come directly after the nerfacto subcommand, since these optional arguments only belong to the nerfacto
subcommand:
ns-train nerfacto {nerfacto optional args} nerfstudio-data
"""
from __future__ import annotations

import random
import socket
import traceback
from datetime import timedelta
from typing import Any, Callable, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import tyro
import yaml
from rich.console import Console

from nerfstudio.configs.config_utils import convert_markup_to_ansi
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils import comms, profiler

import nerfstudio

# Debug: show which nerfstudio module (and install path) is actually being imported.
print(nerfstudio)

CONSOLE = Console(width=120)
DEFAULT_TIMEOUT = timedelta(minutes=30)

# speedup for when input size to model doesn't change (much)
torch.backends.cudnn.benchmark = True  # type: ignore


def _find_free_port() -> str:
    """Finds a free port."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return str(port)  # return as a string so it can be dropped into a tcp:// URL
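# Used by launch() below when dist_url == "auto" to build an init URL on localhost,
# e.g. f"tcp://127.0.0.1:{_find_free_port()}".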


def _set_random_seed(seed) -> None:
    """Set randomness seed in torch and numpy."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def train_loop(local_rank: int, world_size: int, config: TrainerConfig, global_rank: int = 0):
    """Main training function that sets up and runs the trainer per process.

    Args:
        local_rank: current rank of process
        world_size: total number of gpus available
        config: config file specifying training regimen
        global_rank: rank of this process across all machines (0 for single-process runs)
    """
    _set_random_seed(config.machine.seed + global_rank)
    trainer = config.setup(local_rank=local_rank, world_size=world_size)
    trainer.setup()
    trainer.train()
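# Minimal single-process sketch (assumes `cfg` is a populated TrainerConfig, e.g. the result
# of tyro.cli as in entrypoint() below):
#   train_loop(local_rank=0, world_size=1, config=cfg)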


def _distributed_worker(
    local_rank: int,
    main_func: Callable,
    world_size: int,
    num_gpus_per_machine: int,
    machine_rank: int,
    dist_url: str,
    config: TrainerConfig,
    timeout: timedelta = DEFAULT_TIMEOUT,
) -> Any:
    """Spawned distributed worker that initializes the process group and runs the training
    process on each of multiple processes.

    Args:
        local_rank: Current rank of process.
        main_func: Function that will be called by the distributed workers.
        world_size: Total number of gpus available.
        num_gpus_per_machine: Number of GPUs per machine.
        machine_rank: Rank of this machine.
        dist_url: URL to connect to for distributed jobs, including protocol
            E.g., "tcp://127.0.0.1:8686".
            It can be set to "auto" to automatically select a free port on localhost.
        config: TrainerConfig specifying training regimen.
        timeout: Timeout of the distributed workers.

    Raises:
        e: Exception in initializing the process group

    Returns:
        Any: TODO: determine the return type
    """
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=global_rank,
        timeout=timeout,
    )
    assert comms.LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        # Create one process group per machine; remember the group for this machine's ranks.
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comms.LOCAL_PROCESS_GROUP = pg
    assert num_gpus_per_machine <= torch.cuda.device_count()
    output = main_func(local_rank, world_size, config, global_rank)
    comms.synchronize()
    dist.destroy_process_group()
    return output
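# Rank bookkeeping example: with num_machines=2 and num_gpus_per_machine=4 (world_size=8),
# local_rank=2 on machine_rank=1 gets global_rank = 1 * 4 + 2 = 6, and its
# comms.LOCAL_PROCESS_GROUP contains global ranks [4, 5, 6, 7].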


def launch(
    main_func: Callable,
    num_gpus_per_machine: int,
    num_machines: int = 1,
    machine_rank: int = 0,
    dist_url: str = "auto",
    config: Optional[TrainerConfig] = None,
    timeout: timedelta = DEFAULT_TIMEOUT,
) -> None:
    """Function that spawns multiple processes to call on main_func.

    Args:
        main_func (Callable): function that will be called by the distributed workers
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int, optional): total number of machines
        machine_rank (int, optional): rank of this machine
        dist_url (str, optional): url to connect to for distributed jobs
        config (TrainerConfig, optional): config file specifying training regimen
        timeout (timedelta, optional): timeout of the distributed workers
    """
    assert config is not None
    world_size = num_machines * num_gpus_per_machine
    if world_size <= 1:
        # world_size=0 uses one CPU in one process.
        # world_size=1 uses one GPU in one process.
        try:
            main_func(local_rank=0, world_size=world_size, config=config)
        except KeyboardInterrupt:
            # print the stack trace
            CONSOLE.print(traceback.format_exc())
        finally:
            profiler.flush_profiler(config.logging)
    elif world_size > 1:
        # Using multiple gpus with multiple processes.
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto is not supported for multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            CONSOLE.log("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
        process_context = mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            join=False,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                config,
                timeout,
            ),
        )
        # process_context won't be None because join=False, so it's okay to assert this
        # for Pylance reasons
        assert process_context is not None
        try:
            process_context.join()
        except KeyboardInterrupt:
            for i, process in enumerate(process_context.processes):
                if process.is_alive():
                    CONSOLE.log(f"Terminating process {i}...")
                    process.terminate()
                process.join()
                CONSOLE.log(f"Process {i} finished.")
        finally:
            profiler.flush_profiler(config.logging)
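# Sketch of a direct call (assumes `cfg` is a fully populated TrainerConfig); this mirrors
# what main() does below:
#   launch(
#       main_func=train_loop,
#       num_gpus_per_machine=cfg.machine.num_gpus,
#       num_machines=cfg.machine.num_machines,
#       machine_rank=cfg.machine.machine_rank,
#       dist_url=cfg.machine.dist_url,
#       config=cfg,
#   )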


def main(config: TrainerConfig) -> None:
    """Main function."""
    config.set_timestamp()
    if config.data:
        CONSOLE.log("Using --data alias for --pipeline.datamanager.data")
        config.pipeline.datamanager.data = config.data
    if config.load_config:
        CONSOLE.log(f"Loading pre-set config from: {config.load_config}")
        config = yaml.load(config.load_config.read_text(), Loader=yaml.Loader)
    # print and save config
    config.print_to_terminal()
    config.save_config()
    launch(
        main_func=train_loop,
        num_gpus_per_machine=config.machine.num_gpus,
        num_machines=config.machine.num_machines,
        machine_rank=config.machine.machine_rank,
        dist_url=config.machine.dist_url,
        config=config,
    )
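# The --data alias handled above means that (path is hypothetical):
#   ns-train nerfacto --data data/my-capture
# behaves like
#   ns-train nerfacto --pipeline.datamanager.data data/my-capture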


def entrypoint():
    """Entrypoint for use with pyproject scripts."""
    # Choose a base configuration and override values.
    tyro.extras.set_accent_color("bright_yellow")
    main(
        tyro.cli(
            AnnotatedBaseConfigUnion,
            description=convert_markup_to_ansi(__doc__),
        )
    )


if __name__ == "__main__":
    entrypoint()
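# Typical invocations (dataset path is an example; exact flags depend on the installed
# nerfstudio version):
#   ns-train nerfacto --data data/nerfstudio/poster         # via the ns-train entrypoint
#   python train.py nerfacto --data data/nerfstudio/poster  # running this file directly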