딥러닝이 한창 열풍인 요즘, 딥러닝 관련 프레임워크로는
파이토치, 텐서플로, 케라스 등등이 많이 쓰이고 있는 것으로 보입니다
최근 ai 관련 교육을 들으며 파이토치 library 사용하는 것을 배웠는데, 그 때부터 파이토치에 대해 호기심이 생기더라구요
따라서 이번에는 파이토치에서 근간이 되는 torchrun을 실행할 때 어떤 일이 일어나는 지 한 번 알아보고자 합니다
torchrun (Elastic Launch) — PyTorch 2.5 documentation
torchrun (Elastic Launch) — PyTorch 2.5 documentation
torchrun (Elastic Launch) Superset of torch.distributed.launch. torchrun provides a superset of the functionality as torch.distributed.launch with the following additional functionalities: Worker failures are handled gracefully by restarting all workers. W
pytorch.org
위의 pytorch 사이트에서 보면 torchrun api에 대한 설명이 있습니다
Note
torchrun is a python console script to the main module torch.distributed.run declared in the entry_points configuration in setup.py. It is equivalent to invoking python -m torch.distributed.run.
이 부분을 실제 코드에서 어떻게 연결되는지 한 번 보려고 합니다
torchrun의 위치
pytorch/setup.py at main · pytorch/pytorch
pytorch/setup.py at main · pytorch/pytorch
Tensors and Dynamic neural networks in Python with strong GPU acceleration - pytorch/pytorch
github.com
entry_points = {
"console_scripts": [
"torchrun = torch.distributed.run:main",
],
torchrun은 결국 torch.distributed.run의 main함수와 연결 됩니다
torch.distributed.run의 main함수는 아래와 같이 구성 되어 있습니다
def run(args):
torch.multiprocessing._set_thread_name("pt_elastic")
if args.standalone:
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:0"
args.rdzv_id = str(uuid.uuid4())
logger.info(
"\n**************************************\n"
"Rendezvous info:\n"
"--rdzv-backend=%s "
"--rdzv-endpoint=%s "
"--rdzv-id=%s\n"
"**************************************\n",
args.rdzv_backend,
args.rdzv_endpoint,
args.rdzv_id,
)
config, cmd, cmd_args = config_from_args(args)
elastic_launch(
config=config,
entrypoint=cmd,
)(*cmd_args)
@record
def main(args=None):
args = parse_args(args)
run(args)
if __name__ == "__main__":
main()
config_from_args 내부에서 --nnode = 1 --nproc-per-node=$NUM_TRAINERS 등과 같은 arg로 넘겨준 값들을
parsing을 하고 elastic_launch에 인자로 넘겨줍니다
class elastic_launch:
...
def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
):
self._config = config
self._entrypoint = entrypoint
def __call__(self, *args):
return launch_agent(self._config, self._entrypoint, list(args))
elastic_launch 객체를 실제 call 해줄 때, launch_agent가 실행되며 코드는 아래와 같습니다
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
...
agent = LocalElasticAgent(
spec=spec,
logs_specs=config.logs_specs, # type: ignore[arg-type]
start_method=config.start_method,
log_line_prefix_template=config.log_line_prefix_template,
)
shutdown_rdzv = True
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
result = agent.run()
# records that agent.run() has succeeded NOT that workers have succeeded
events.record(agent.get_event_succeeded())
launch_agent에서는 LocalElasticAgent를 run을 해줍니다
class LocalElasticAgent(SimpleElasticAgent):
LocalElasticAgent는 simpleElasticAgent를 상속받고 있고 거기에 run 메서드가 있습니다
@prof
def run(self, role: str = DEFAULT_ROLE) -> RunResult:
start_time = time.monotonic()
shutdown_called: bool = False
try:
result = self._invoke_run(role)
self._total_execution_time = int(time.monotonic() - start_time)
self._record_metrics(result)
self._record_worker_events(result)
return result
run 안에서는 _invoke_run을 실행해주는게 핵심인데,
def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
# NOTE: currently only works for a single role
spec = self._worker_group.spec
role = spec.role
logger.info(
"[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
)
self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state
put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
logger.info(
"[%s] worker group successfully finished."
" Waiting %s seconds for other agents to finish.",
role,
self._exit_barrier_timeout,
)
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
if self._remaining_restarts > 0:
logger.info(
"[%s] Worker group %s. "
"%s/%s attempts left;"
" will restart worker group",
role,
state.name,
self._remaining_restarts,
spec.max_restarts,
)
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
logger.info(
"[%s] Detected %s "
"new nodes from group_rank=%s; "
"will restart worker group",
role,
num_nodes_waiting,
group_rank,
)
self._restart_workers(self._worker_group)
else:
raise Exception( # noqa: TRY002
f"[{role}] Worker group in {state.name} state"
)
_monitor_workers를 실행하고 그로 부터 받은 state에 따라 처리를 해주는 식이네요
_monitor_workers는 abstractmethod 입니다 구현부가 따로 어딘가에 있겠죠 ㅎㅎ
결국, 아래와 같이 명령어를 입력해주면
tochrun뒤에 붙은 argument 들은 config_from_args을 통해 parsing 되어
Elastic_Launch 객체를 실행하고, 내부적인 동작은 LocalElasticAgent의 _invoke_run을 실행하는 것입니다
torchrun
--standalone
--nnodes=1
--nproc-per-node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
읽어 주셔서 감사합니다
다음 번엔, llama에서 torchrun을 통해 example script 파이썬 파일을 실행 했을 때 어떤 동작이 일어나는지 한 번 알아보도록 하겠습니다