Skip to content

ding.entry.cli_parsers.k8s_parser

ding.entry.cli_parsers.k8s_parser

K8SParser

__init__(platform_spec=None, **kwargs)

Overview

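Sets only cluster-wide properties (node list, platform spec, parallel workers, topology and base port); per-task properties are resolved later.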

Full Source Code

../ding/entry/cli_parsers/k8s_parser.py

import os
import numpy as np
from time import sleep
from typing import Dict, List, Optional


class K8SParser():

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Should only set global cluster properties: the node list (read from ``DI_NODES``),
            the optional platform spec, parallel workers, topology and the default base port.
            Per-task properties are resolved lazily by ``_get_task``.
        """
        self.kwargs = kwargs
        self.nodelist = self._parse_node_list()
        self.ntasks = len(self.nodelist)
        self.platform_spec = platform_spec
        self.parallel_workers = kwargs.get("parallel_workers") or 1
        self.topology = kwargs.get("topology") or "alone"
        self.ports = int(kwargs.get("ports") or 50515)
        # Fully resolved tasks (including attach_to), keyed by procid.
        self.tasks = {}
        # Tasks with only ports/address/node_ids resolved, keyed by procid. Looking up another
        # task's address must not require resolving that task's attach_to; otherwise forward or
        # cyclic ``$node.<id>`` references would recurse forever.
        self._base_tasks = {}

    def parse(self) -> dict:
        """
        Overview:
            Build the arguments for the current process (selected by ``DI_RANK``).
            If ``mq_type`` is not ``nng``, kwargs are returned unchanged.
        Returns:
            - args (:obj:`dict`): kwargs merged with the resolved task properties of this process.
        Raises:
            - AssertionError: if the resolved address differs from this process's node name.
        """
        if self.kwargs.get("mq_type", "nng") != "nng":
            return self.kwargs
        procid = int(os.environ["DI_RANK"])
        nodename = self.nodelist[procid]
        task = self._get_task(procid)
        # Validation. Raise explicitly so the check survives ``python -O``.
        if task["address"] != nodename:
            raise AssertionError(
                "Task address {} does not match node name {} for procid {}".format(
                    task["address"], nodename, procid
                )
            )
        return {**self.kwargs, **task}

    def _parse_node_list(self) -> List[str]:
        return os.environ["DI_NODES"].split(",")

    def _get_base_task(self, procid: int) -> dict:
        """
        Overview:
            Resolve ports, address and node_ids of a task, without attach_to. The result is cached.
            Values come from the platform_spec first, then from kwargs, then from defaults.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0.
        Returns:
            - task (:obj:`dict`): A copy of the platform_spec task (if any) with the base fields filled.
        """
        if procid in self._base_tasks:
            return self._base_tasks[procid]

        if self.platform_spec:
            # Copy so the caller's platform_spec is not mutated.
            task = dict(self.platform_spec["tasks"][procid])
        else:
            task = {}
        if "ports" not in task:
            task["ports"] = self.kwargs.get("ports") or self._get_ports()
        if "address" not in task:
            task["address"] = self.kwargs.get("address") or self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)

        self._base_tasks[procid] = task
        return task

    def _get_task(self, procid: int) -> dict:
        """
        Overview:
            Complete node properties. Derive them from the environment-provided lists (e.g. DI_NODES)
            rather than from the current node, because this is also called for other procids.
            For example, if you want to set nodename in this function, please derive it from DI_NODES.
        Arguments:
            - procid (:obj:`int`): Proc order, starting from 0, must be set automatically by dijob.
                Note that it is different from node_id.
        Returns:
            - task (:obj:`dict`): Task properties including ports, address, node_ids, attach_to,
                topology and parallel_workers.
        """
        if procid in self.tasks:
            return self.tasks[procid]

        task = dict(self._get_base_task(procid))
        task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers

        self.tasks[procid] = task
        return task

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Parse the attach_to pattern. If attach_to is specified in the platform_spec,
            it is formatted as real addresses based on the specified address.
            If not, the real addresses are generated from the globally specified topology.
        Arguments:
            - procid (:obj:`int`): Proc order.
            - attach_to (:obj:`str`): The attach_to field in platform_spec for the task with current procid.
        Returns:
            - attach_to (:obj:`str`): Comma-separated real addresses for attach_to.
        """
        if attach_to:
            # Tolerate whitespace around the commas and skip empty parts.
            parts = [part.strip() for part in attach_to.split(",")]
            links = [self._get_attach_to_part(part) for part in parts if part]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            # Connect to every worker of every previous task.
            links = [
                link for i in range(procid)
                for link in self._get_attach_to_from_task(self._get_base_task(i))
            ]
        elif self.topology == "star":
            # Connect only to the workers of the head task.
            links = self._get_attach_to_from_task(self._get_base_task(0))
        else:
            links = []

        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000. Parts that do not
                start with ``$node.`` are returned unchanged.
        """
        prefix = "$node."
        if not attach_part.startswith(prefix):
            return attach_part
        attach_node_id = int(attach_part[len(prefix):])
        attach_task = self._get_base_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task. Note that parallel_workers affects the connected processes:
            one link is produced per worker, on consecutive ports starting at the task's port.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. [tcp://SH-0:50000, tcp://SH-0:50001]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the procid of the task whose node_ids equals ``nodeid``.
        Raises:
            - Exception: if no task has this node id.
        """
        for procid in range(self.ntasks):
            if self._get_base_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self) -> int:
        return self.ports

    def _get_address(self, procid: int) -> str:
        address = self.nodelist[procid]
        return address

    def _get_tcp_link(self, address: str, port: int) -> str:
        return "tcp://{}:{}".format(address, port)

    def _get_node_id(self, procid: int) -> int:
        # Each task reserves ``parallel_workers`` consecutive node ids.
        return procid * self.parallel_workers


def k8s_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    return K8SParser(platform_spec, **kwargs).parse()