AI/Llama

torchrun을 실행 했을 때 일어나는 동작들

냥냥냥냥냥냥 2024. 12. 29. 15:29

딥러닝이 한창 열풍인 요즘, 딥러닝 관련 프레임워크로는 

파이토치, 텐서플로, 케라스 등등이 많이 쓰이고 있는 것으로 보입니다

 

최근 ai 관련 교육을 들으며 파이토치 library 사용하는 것을 배웠는데, 그 때부터 파이토치에 대해 호기심이 생기더라구요

따라서 이번에는 파이토치에서 근간이 되는 torchrun을 실행할 때 어떤 일이 일어나는 지 한 번 알아보고자 합니다

torchrun (Elastic Launch) — PyTorch 2.5 documentation

 

torchrun (Elastic Launch) — PyTorch 2.5 documentation

torchrun (Elastic Launch) Superset of torch.distributed.launch. torchrun provides a superset of the functionality as torch.distributed.launch with the following additional functionalities: Worker failures are handled gracefully by restarting all workers. W

pytorch.org

위의 pytorch 사이트에서 보면 torchrun api에 대한 설명이 있습니다

Note

torchrun is a python console script to the main module torch.distributed.run declared in the entry_points configuration in setup.py. It is equivalent to invoking python -m torch.distributed.run.

 

이 부분을 실제 코드에서 어떻게 연결되는지 한 번 보려고 합니다

 

torchrun의 위치

pytorch/setup.py at main · pytorch/pytorch

 

pytorch/setup.py at main · pytorch/pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration - pytorch/pytorch

github.com

    entry_points = {
        "console_scripts": [
            "torchrun = torch.distributed.run:main",
        ],

torchrun은 결국 torch.distributed.run의 main함수와 연결 됩니다

torch.distributed.run의 main함수는 아래와 같이 구성 되어 있습니다

def run(args):
    torch.multiprocessing._set_thread_name("pt_elastic")

    if args.standalone:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:0"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            "\n**************************************\n"
            "Rendezvous info:\n"
            "--rdzv-backend=%s "
            "--rdzv-endpoint=%s "
            "--rdzv-id=%s\n"
            "**************************************\n",
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)


@record
def main(args=None):
    args = parse_args(args)
    run(args)


if __name__ == "__main__":
    main()

config_from_args 내부에서 --nnode = 1 --nproc-per-node=$NUM_TRAINERS 등과 같은 arg로 넘겨준 값들을

parsing을 하고 elastic_launch에 인자로 넘겨줍니다

 

class elastic_launch:
...

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        self._config = config
        self._entrypoint = entrypoint

    def __call__(self, *args):
        return launch_agent(self._config, self._entrypoint, list(args))

elastic_launch 객체를 실제 call 해줄 때,  launch_agent가 실행되며 코드는 아래와 같습니다

 

def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
...

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

launch_agent에서는 LocalElasticAgent를 run을 해줍니다

class LocalElasticAgent(SimpleElasticAgent):

LocalElasticAgent는 simpleElasticAgent를 상속받고 있고 거기에 run 메서드가 있습니다

    @prof
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        start_time = time.monotonic()
        shutdown_called: bool = False
        try:
            result = self._invoke_run(role)
            self._total_execution_time = int(time.monotonic() - start_time)
            self._record_metrics(result)
            self._record_worker_events(result)
            return result

run 안에서는 _invoke_run을 실행해주는게 핵심인데,

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        logger.info(
            "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
        )

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                logger.info(
                    "[%s] worker group successfully finished."
                    " Waiting %s seconds for other agents to finish.",
                    role,
                    self._exit_barrier_timeout,
                )
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                if self._remaining_restarts > 0:
                    logger.info(
                        "[%s] Worker group %s. "
                        "%s/%s attempts left;"
                        " will restart worker group",
                        role,
                        state.name,
                        self._remaining_restarts,
                        spec.max_restarts,
                    )
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    logger.info(
                        "[%s] Detected %s "
                        "new nodes from group_rank=%s; "
                        "will restart worker group",
                        role,
                        num_nodes_waiting,
                        group_rank,
                    )
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(  # noqa: TRY002
                    f"[{role}] Worker group in {state.name} state"
                )

_monitor_workers를 실행하고 그로 부터 받은 state에 따라 처리를 해주는 식이네요

_monitor_workers는 abstractmethod 입니다 구현부가 따로 어딘가에 있겠죠 ㅎㅎ

 

 

결국, 아래와 같이 명령어를 입력해주면 

tochrun뒤에 붙은 argument 들은 config_from_args을 통해 parsing 되어

Elastic_Launch 객체를 실행하고, 내부적인 동작은 LocalElasticAgent의 _invoke_run을 실행하는 것입니다

torchrun
    --standalone
    --nnodes=1
    --nproc-per-node=$NUM_TRAINERS
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

 

읽어 주셔서 감사합니다

다음 번엔, llama에서 torchrun을 통해 example script 파이썬 파일을 실행 했을 때 어떤 동작이 일어나는지 한 번 알아보도록 하겠습니다