ding.entry.cli_parsers.slurm_parser
Full Source Code
../ding/entry/cli_parsers/slurm_parser.py
import os
import re
from time import sleep
from typing import Any, Dict, List, Optional

import numpy as np


class SlurmParser:
    """
    Overview:
        Build the per-process task config (address, ports, node ids, attach_to links) for a job launched by \
        Slurm, from the ``SLURM_*`` environment variables and an optional platform spec.
    """

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties. Reads ``SLURM_NTASKS``, ``SLURM_NTASKS_PER_NODE`` and \
            ``SLURM_NODELIST`` from the environment.
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``tasks`` list (indexed by procid) may \
                override ``ports``, ``address``, ``node_ids`` and ``attach_to`` of each task.
            - kwargs: Extra arguments. ``ports`` (default 15151), ``parallel_workers`` (default 1) and \
                ``topology`` (default ``"alone"``) are read here; all kwargs are merged into ``parse()``'s result.
        """
        self.kwargs = kwargs
        self.ntasks = int(os.environ["SLURM_NTASKS"])
        self.platform_spec = platform_spec
        self.tasks: Dict[int, Dict[str, Any]] = {}  # procid -> task, filled lazily by _get_task
        self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
        self.nodelist = self._parse_node_list()  # needs ntasks_per_node, so it is set after it
        self.ports = int(kwargs.get("ports") or 15151)
        # Cast like ``ports``: CLI values may arrive as strings, and this value is used in range().
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``SLURM_PROCID``) and merge it over the kwargs.
        Returns:
            - config (:obj:`dict`): The kwargs updated with the task fields (task fields win on conflicts).
        """
        procid = int(os.environ["SLURM_PROCID"])
        task = self._get_task(procid)
        # Validation: the resolved address must be the node this process is actually running on.
        assert task["address"] == os.environ["SLURMD_NODENAME"], \
            "Task address {} does not match SLURMD_NODENAME {}".format(task["address"], os.environ["SLURMD_NODENAME"])
        return {**self.kwargs, **task}

    def _get_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Resolve and cache the task of ``procid``: fields from ``platform_spec`` are kept, missing \
            ``ports``/``address``/``node_ids`` are derived, then ``attach_to``, ``topology`` and \
            ``parallel_workers`` are set.
        Arguments:
            - procid (:obj:`int`): The Slurm process id of the task.
        Returns:
            - task (:obj:`Dict[str, Any]`): The resolved task (the same cached dict on later calls).
        """
        if procid in self.tasks:
            return self.tasks[procid]
        if self.platform_spec:
            # Copy so the caller's platform_spec is not modified in place.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        if "ports" not in task:
            task["ports"] = self._get_ports(procid)
        if "address" not in task:
            task["address"] = self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self._get_node_id(procid)
        # Cache before resolving attach_to. Resolving it may look up other tasks (mesh/star topology or
        # ``$node.<id>`` references), and those lookups may come back to this procid. They only read
        # ports/address/node_ids, so the partially built task is enough, and recursion terminates.
        self.tasks[procid] = task

        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        return task

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Expand the Slurm hostlist in ``SLURM_NODELIST`` into one hostname per task. Supports plain names, \
            comma-separated expressions and bracketed ranges with zero padding, e.g. \
            ``"a[08-10],b[3,5]"`` -> ``["a08", "a09", "a10", "b3", "b5"]``.
        Returns:
            - nodelist (:obj:`List[str]`): Hostnames, each repeated ``ntasks_per_node`` times.
        """
        nodelist = []
        for expr in self._split_hostlist(os.environ["SLURM_NODELIST"]):
            nodelist.extend(self._expand_hostlist_expr(expr))
        if self.ntasks_per_node > 1:
            # Expand node for each task: _get_address indexes this list by procid.
            nodelist = [node for node in nodelist for _ in range(self.ntasks_per_node)]
        return nodelist

    @staticmethod
    def _split_hostlist(hostlist: str) -> List[str]:
        """
        Overview:
            Split a hostlist on the commas outside ``[...]``, e.g. ``"a[1-2],b[3,5]"`` -> ``["a[1-2]", "b[3,5]"]``.
        """
        parts, current, depth = [], [], 0
        for char in hostlist:
            if char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
            if char == "," and depth == 0:
                parts.append("".join(current))
                current = []
            else:
                current.append(char)
        parts.append("".join(current))
        return [part for part in parts if part]

    @classmethod
    def _expand_hostlist_expr(cls, expr: str) -> List[str]:
        """
        Overview:
            Expand one hostlist expression (no top-level commas), e.g. ``"node[08-10]"`` -> \
            ``["node08", "node09", "node10"]``. Text after the first bracket group is expanded recursively.
        """
        match = re.match(r"([^\[]*)\[([^\]]*)\](.*)$", expr)
        if match is None:
            return [expr]
        prefix, ranges, rest = match.groups()
        heads = []
        for item in ranges.split(","):
            if "-" in item:
                start, stop = item.split("-", 1)
                width = len(start)  # keep zero padding, e.g. "08-10" -> "08", "09", "10"
                heads.extend(prefix + str(number).zfill(width) for number in range(int(start), int(stop) + 1))
            else:
                heads.append(prefix + item)
        tails = cls._expand_hostlist_expr(rest)
        return [head + tail for head in heads for tail in tails]

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Compute the comma-separated ``attach_to`` links of a task. An explicit ``attach_to`` wins; otherwise \
            procid 0 attaches to nothing, ``mesh`` attaches to every worker of all lower procids, ``star`` \
            attaches to the workers of procid 0, and any other topology attaches to nothing.
        Arguments:
            - procid (:obj:`int`): The Slurm process id of the task.
            - attach_to (:obj:`Optional[str]`): Explicit comma-separated links, may contain ``$node.<id>`` parts.
        Returns:
            - attach_to (:obj:`str`): Comma-separated links, e.g. ``"tcp://SH-0:50000,tcp://SH-0:50001"``.
        """
        if attach_to:
            links = [self._get_attach_to_part(part) for part in attach_to.split(",")]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            prev_tasks = [self._get_task(i) for i in range(procid)]
            links = [link for task in prev_tasks for link in self._get_attach_to_from_task(task)]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to. Parts not starting with ``$node.`` are returned unchanged \
            (whitespace-stripped).
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
        """
        node_prefix = "$node."
        attach_part = attach_part.strip()
        if not attach_part.startswith(node_prefix):
            return attach_part
        attach_node_id = int(attach_part[len(node_prefix):])
        attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task. Note that parallel_workers affects the connected processes: one \
            link is produced per worker, on consecutive ports starting at the task's ``ports``.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. ["tcp://SH-0:50000", "tcp://SH-0:50001"]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the first procid whose task has ``node_ids == nodeid``.
        Raises:
            - Exception: If no task among ``range(ntasks)`` matches.
        """
        for procid in range(self.ntasks):
            if self._get_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self, procid: int) -> int:
        # Tasks sharing a node get disjoint port ranges of width parallel_workers
        # (assumes tasks of one node have consecutive procids, as the nodelist expansion implies).
        return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers

    def _get_address(self, procid: int) -> str:
        return self.nodelist[procid]

    def _get_node_id(self, procid: int) -> int:
        # Each process reserves parallel_workers consecutive node ids.
        return procid * self.parallel_workers

    def _get_tcp_link(self, address: str, port: int) -> str:
        return "tcp://{}:{}".format(address, port)


def slurm_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Build the config of the current Slurm process, see ``SlurmParser.parse``.
    """
    return SlurmParser(platform_spec, **kwargs).parse()