Skip to content

ding.interaction.base.network

ding.interaction.base.network

Full Source Code

../ding/interaction/base/network.py

import json
import socket
import time
from typing import Optional, Any, Mapping, Callable, Type, Tuple

import requests
from requests import HTTPError
from urlobject import URLObject
from urlobject.path import URLPath

from .common import translate_dict_func


def get_host_ip() -> Optional[str]:
    """Return the local IPv4 address the OS would use to reach ``8.8.8.8:80``.

    A UDP socket is "connected" to a public address only so that the OS picks the
    outbound interface; ``getsockname`` then reports that interface's address.

    Returns:
        The local IPv4 address as a string.

    Raises:
        OSError: Propagated from ``socket`` when no route is available.
    """
    # The context manager guarantees the socket is closed on success and on error.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


_DEFAULT_HTTP_PORT = 80
_DEFAULT_HTTPS_PORT = 443


def split_http_address(address: str, default_port: Optional[int] = None) -> Tuple[str, int, bool, str]:
    """Split an HTTP(S) address into ``(host, port, https, path)``.

    Args:
        address: Address such as ``'https://example.com:8443/api'``.
        default_port: Port to use when ``address`` has no explicit port. If this is
            also not given, 443 is used for ``https`` and 80 otherwise.

    Returns:
        ``(host, port, https, path)`` where ``https`` is ``True`` iff the scheme is
        ``https`` (case-insensitive) and ``path`` is ``''`` when absent.
    """
    _url = URLObject(address)

    _host = _url.hostname
    _https = (_url.scheme.lower()) == 'https'
    _port = _url.port or default_port or (_DEFAULT_HTTPS_PORT if _https else _DEFAULT_HTTP_PORT)
    _path = str(_url.path) or ''

    return _host, _port, _https, _path


DEFAULT_RETRIES = 5
DEFAULT_RETRY_WAITING = 1.0


class HttpEngine:
    """HTTP client bound to a single base URL, built on a ``requests`` session.

    Subclasses customize behaviour through three hooks:
    ``_data_process`` (payload transform), ``_base_headers`` (headers sent with
    every request) and ``_error_handler`` (what to do with a failed request).
    """

    def __init__(self, host: str, port: int, https: bool = False, path: Optional[str] = None):
        """
        Args:
            host: Server host name or IP address.
            port: Server port.
            https: Use the ``https`` scheme instead of ``http``.
            path: Optional path prefix prepended to every request path.
        """
        self.__base_url = URLObject().with_scheme('https' if https else 'http') \
            .with_hostname(host).with_port(port).add_path(path or '')
        self.__session = requests.session()
        # Do not pick up proxy/auth settings from the environment; talk to the target directly.
        self.__session.trust_env = False

    # noinspection PyMethodMayBeStatic
    def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
        """Hook: transform the payload before JSON encoding. Default maps falsy data to ``{}``."""
        return data or {}

    # noinspection PyMethodMayBeStatic
    def _base_headers(self) -> Mapping[str, Any]:
        """Hook: headers sent with every request. Per-call headers override these. Default: none."""
        return {}

    def _error_handler(self, err: Exception):
        """Hook: handle an exception raised while performing a request. Default: re-raise it."""
        raise err

    def get_url(self, path: Optional[str] = None):
        """Return the absolute URL for ``path`` appended to the engine's base path.

        Args:
            path: Request path relative to the base path; ``None`` means the base URL itself.

        Returns:
            The full URL as a string.
        """
        original_segments = self.__base_url.path.segments
        path_segments = URLPath().add(path or '').segments
        return str(self.__base_url.with_path(URLPath.join_segments(original_segments + path_segments)))

    def __single_request(
            self,
            method: str,
            path: str,
            data: Optional[Mapping[str, Any]] = None,
            headers: Optional[Mapping[str, Any]] = None,
            params: Optional[Mapping[str, Any]] = None,
            raise_for_status: bool = True
    ):
        """Send exactly one request (no retries); the body is the JSON-encoded processed ``data``."""
        response = self.__session.request(
            method=method,
            url=self.get_url(path),
            data=json.dumps(self._data_process(data) or {}),
            headers=headers,
            params=params or {},
        )
        if raise_for_status:
            response.raise_for_status()

        return response

    def request(
            self,
            method: str,
            path: str,
            data: Optional[Mapping[str, Any]] = None,
            headers: Optional[Mapping[str, Any]] = None,
            params: Optional[Mapping[str, Any]] = None,
            raise_for_status: bool = True,
            retries: Optional[int] = None,
            retry_waiting: Optional[float] = None,
    ) -> requests.Response:
        """Send a request, retrying transport-level failures.

        Args:
            method: HTTP method, e.g. ``'GET'`` or ``'POST'``.
            path: Path relative to the engine's base URL.
            data: Payload; passed through ``_data_process`` and sent as a JSON body.
            headers: Extra headers, merged over ``_base_headers()``.
            params: Query-string parameters.
            raise_for_status: Raise ``requests.HTTPError`` for error status codes.
            retries: Number of retries after the first attempt for non-HTTP
                ``RequestException`` errors (e.g. connection failures).
                ``None`` means ``DEFAULT_RETRIES``; ``0`` disables retrying.
            retry_waiting: Seconds to sleep between attempts.
                ``None`` means ``DEFAULT_RETRY_WAITING``; ``0`` retries immediately.

        Returns:
            The response. Returns ``None`` only if an overridden ``_error_handler``
            returns instead of raising.

        Raises:
            Whatever ``_error_handler`` raises. The default re-raises the original
            ``requests.HTTPError`` / ``requests.RequestException``.
        """
        _headers = dict(self._base_headers())
        _headers.update(headers or {})

        # Only ``None`` selects a default, so an explicit 0 is honoured
        # rather than being replaced by the default (``0 or 5 == 5``).
        retries = DEFAULT_RETRIES if retries is None else retries
        retry_waiting = DEFAULT_RETRY_WAITING if retry_waiting is None else retry_waiting

        try:
            _current_retries = 0
            while True:
                try:
                    return self.__single_request(method, path, data, _headers, params, raise_for_status)
                except requests.exceptions.HTTPError:
                    # HTTPError subclasses RequestException, so it must be caught first:
                    # a status error is a definitive server answer and is never retried.
                    raise
                except requests.exceptions.RequestException:
                    _current_retries += 1
                    if _current_retries > retries:
                        raise
                    time.sleep(retry_waiting)
        except Exception as e:
            self._error_handler(e)


def get_http_engine_class(
        headers: Mapping[str, Callable[..., Any]],
        data_processor: Optional[Callable[[Mapping[str, Any]], Mapping[str, Any]]] = None,
        http_error_gene: Optional[Callable[[HTTPError], Exception]] = None,
) -> Callable[..., Type[HttpEngine]]:
    """Build a factory producing customized ``HttpEngine`` subclasses.

    Args:
        headers: Header-name -> callable mapping. On every request the generated class calls
            ``translate_dict_func(headers)(*args, **kwargs)`` with the factory's arguments to
            build its base headers (presumably evaluating each callable; see ``.common``).
        data_processor: Optional payload transform; receives ``data or {}``. If omitted,
            the payload is sent as ``data or {}``.
        http_error_gene: Optional converter from ``requests.HTTPError`` to a domain-specific
            exception; other exceptions are re-raised unchanged.

    Returns:
        A callable ``(*args, **kwargs) -> Type[HttpEngine]``. The arguments are captured and
        later used to compute the base headers.
    """

    def _func(*args, **kwargs) -> Type[HttpEngine]:

        class _HttpEngine(HttpEngine):

            def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
                return (data_processor or (lambda d: d or {}))(data or {})

            def _base_headers(self) -> Mapping[str, Any]:
                return translate_dict_func(headers)(*args, **kwargs)

            def _error_handler(self, err: Exception):
                if http_error_gene is not None and isinstance(err, HTTPError):
                    # Chain the original HTTPError so its traceback and response stay inspectable.
                    raise http_error_gene(err) from err
                else:
                    raise err

        return _HttpEngine

    return _func