From b9866ea374d8834ff4d26f3a8dc4aa729f266bb0 Mon Sep 17 00:00:00 2001
From: Jiale Zhi
Date: Thu, 3 Sep 2020 16:23:47 -0700
Subject: [PATCH 01/23] [type-hints] Add type hints to fiber/config.py

---
 fiber/config.py | 68 +++++++++++++++++++++++++++++++++----------------
 1 file changed, 46 insertions(+), 22 deletions(-)

diff --git a/fiber/config.py b/fiber/config.py
index 75cf391..9d115ba 100644
--- a/fiber/config.py
+++ b/fiber/config.py
@@ -67,6 +67,11 @@ def main():
 import os
 import logging
 import configparser
+from typing import Any, Dict, List, Type, TypeVar, Optional
+from typing import TYPE_CHECKING
+
+_current_config: Optional["Config"]
+_TConfig = TypeVar('_TConfig', bound="Config")
 
 _current_config = None
 
@@ -84,7 +89,7 @@ def main():
 DEFAULT_IMAGE = "fiber-test:latest"
 
 
-def str2bool(text):
+def str2bool(text: str) -> bool:
     """Simple function to convert a range of values to True/False."""
     return text.lower() in ["true", "yes", "1"]
@@ -106,34 +111,34 @@ class Config(object):
 
     """
 
-    def __init__(self, conf_file=None):
+    def __init__(self, conf_file: str = None) -> None:
         # Not documented, people should not use this
-        self.merge_output = False
-        self.debug = False
-        self.image = None
-        self.default_image = DEFAULT_IMAGE
-        self.backend = None
-        self.default_backend = "local"
+        self.merge_output: bool = False
+        self.debug: bool = False
+        self.image: Optional[str] = None
+        self.default_image: str = DEFAULT_IMAGE
+        self.backend: Optional[str] = None
+        self.default_backend: str = "local"
         # Not documented, this should be removed because it's not used for now
-        self.use_bash = False
-        self.log_level = logging.INFO
-        self.log_file = "/tmp/fiber.log"
+        self.use_bash: bool = False
+        self.log_level: int = logging.INFO
+        self.log_file: str = "/tmp/fiber.log"
         # If ipc_active is True, Fiber worker processes will connect
         # to the master process. Otherwise, the master process will connect
         # to worker processes.
         # Not documented, should only be used internally
-        self.ipc_active = True
+        self.ipc_active: bool = True
         # if ipc_active is True, this can be 0, otherwise, it can only be a
         # valid TCP port number. Default 0.
-        self.ipc_admin_master_port = 0
+        self.ipc_admin_master_port: int = 0
         # Not documented, this is only used when `ipc_active` is False
-        self.ipc_admin_worker_port = 8000
+        self.ipc_admin_worker_port: int = 8000
         # Not documented, need to fine tune this
-        self.cpu_per_job = 1
+        self.cpu_per_job: int = 1
        # Not documented, need to fine tune this
-        self.mem_per_job = None
-        self.use_push_queue = True
-        self.kubernetes_namespace = "default"
+        self.mem_per_job: Optional[int] = None
+        self.use_push_queue: bool = True
+        self.kubernetes_namespace: str = "default"
 
         if conf_file is None:
             conf_file = ".fiberconfig"
@@ -185,7 +190,7 @@ def __repr__(self):
         return repr(self.__dict__)
 
     @classmethod
-    def from_dict(cls, kv):
+    def from_dict(cls: Type[_TConfig], kv: Dict[str, Any]) -> _TConfig:
         obj = cls()
         for k in kv:
             setattr(obj, k, kv[k])
@@ -193,7 +198,7 @@ def from_dict(cls, kv):
         return obj
 
 
-def get_object():
+def get_object() -> Config:
     """
     Get a Config object representing current Fiber config
 
@@ -207,7 +212,7 @@ def get_object():
     return Config.from_dict(get_dict())
 
 
-def get_dict():
+def get_dict() -> Dict[str, Any]:
     """
     Get current Fiber config in a dictionary
 
@@ -218,7 +223,7 @@ def get_dict():
     return {k: global_vars[k] for k in vars(_current_config)}
 
 
-def init(**kwargs):
+def init(**kwargs) -> List[str]:
     """
     Init Fiber system and set config values.
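 
     Example (illustrative; valid keyword names are the attributes of
     `Config`):
 
         fiber.config.init(backend="local", cpu_per_job=2)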
@@ -247,3 +252,22 @@ def init(**kwargs):
     logger.debug("Inited fiber with config: %s", vars(_config))
 
     return updates
+
+
+if TYPE_CHECKING:
+    merge_output = Config.merge_output
+    debug = Config.debug
+    image = Config.image
+    default_image = Config.default_image
+    backend = Config.backend
+    default_backend = Config.default_backend
+    use_bash = Config.use_bash
+    log_level = Config.log_level
+    log_file = Config.log_file
+    ipc_active = Config.ipc_active
+    ipc_admin_master_port = Config.ipc_admin_master_port
+    ipc_admin_worker_port = Config.ipc_admin_worker_port
+    cpu_per_job = Config.cpu_per_job
+    mem_per_job = Config.mem_per_job
+    use_push_queue = Config.use_push_queue
+    kubernetes_namespace = Config.kubernetes_namespace

From 968b887229f8762e95c564292b98706b0e567473 Mon Sep 17 00:00:00 2001
From: Jiale Zhi
Date: Fri, 4 Sep 2020 15:30:54 -0700
Subject: [PATCH 02/23] [type-hints] Add type hints to fiber/core.py

---
 fiber/core.py | 51 ++++++++++++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 19 deletions(-)

diff --git a/fiber/core.py b/fiber/core.py
index 0333bf0..a8a9abe 100644
--- a/fiber/core.py
+++ b/fiber/core.py
@@ -13,9 +13,7 @@
 # limitations under the License.
 
 import enum
-
-
-MEM_CPU_RATIO = 2  # 2G per cpu
+from typing import Dict, List, NoReturn, Optional, Any, Union
 
 
 class ProcessStatus(enum.Enum):
@@ -26,8 +24,24 @@ class ProcessStatus(enum.Enum):
 
 
 class JobSpec(object):
-    def __init__(self, image=None, command=None, name=None, cpu=None, mem=None,
-                 volumes=None, gpu=None):
+    image: Optional[str]
+    command: List[str]
+    name: str
+    cpu: Optional[int]
+    mem: Optional[int]
+    volumes: Optional[Dict[str, Dict]]
+    gpu: Optional[int]
+
+    def __init__(
+        self,
+        image: str = None,
+        command: List[str] = [],
+        name: str = "",
+        cpu: int = None,
+        mem: int = None,
+        volumes: Dict[str, Dict] = None,
+        gpu: int = None,
+    ) -> None:
         # Docker image used to launch this job
         self.image = image
         # Command to run in this job container, this should be a sequence
@@ -40,8 +54,6 @@ def __init__(self, image=None, command=None, name=None, cpu=None, mem=None,
         # Maximum number of cpu cores this job can use
         self.gpu = gpu
         # Maximum memory size in MB that this job can use
-        #if mem is None:
-        #    mem = cpu * MEM_CPU_RATIO
         self.mem = mem
         # volume name to be mounted, currently only used by k8s backend
         # For example:
         # {
         #     "<pvc-name>": {"bind": "/persistent", "mode": "rw"}
         # }
         self.volumes = volumes
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return self.__dict__ == other.__dict__
 
-    def __repr__(self):
-        return '<JobSpec: {}>'.format(vars(self))
+    def __repr__(self) -> str:
+        return "<JobSpec: {}>".format(vars(self))
 
 
 class Job(object):
     # Data is used to hold backend specific data associated with this job
-    data = None
+    data: Any
     # Job id. This is set by backend and should only be used by Fiber backend
-    jid = None
+    jid: Union[str, int]
     # (Optional) The hostname/IP address for this job, this is used to
     # communicate with the master process. It is only used when
     # `ipc_admin_passive` is enabled.
- host = None + host: str - def __init__(self, data, jid): + def __init__(self, data: Any, jid: Union[str, int]) -> None: + assert data is not None, "Job data is None" self.data = data self.jid = jid @@ -81,28 +94,28 @@ class Backend(object): def name(self): raise NotImplementedError - def create_job(self, job_spec): + def create_job(self, job_spec: JobSpec): """This function is called when Fiber wants to create a new Process.""" raise NotImplementedError - def get_job_status(self, job): + def get_job_status(self, job: Job): """This function is called when Fiber wants to to get job status.""" raise NotImplementedError - def get_job_logs(self, job): + def get_job_logs(self, job: Job) -> str: """ This function is called when Fiber wants to to get logs of this job """ return "" - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: Job, timeout: float): """Wait for a specific job until timeout. If timeout is None, wait until job is done. Returns `None` if timed out or `exitcode` if job is finished. """ raise NotImplementedError - def terminate_job(self, job): + def terminate_job(self, job: Job): """Terminate a job described by `job`.""" raise NotImplementedError From 779834a06d4fb8e5616e27374a9e18e7b5299407 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Fri, 4 Sep 2020 21:58:49 -0700 Subject: [PATCH 03/23] [type-hints] Add type hints to fiber/backend.py --- fiber/backend.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/fiber/backend.py b/fiber/backend.py index 54ae58d..8265377 100644 --- a/fiber/backend.py +++ b/fiber/backend.py @@ -18,19 +18,22 @@ import multiprocessing as mp import fiber.config as config +from fiber.core import Backend + +_backends: dict _backends = {} available_backend = ['kubernetes', 'docker', 'local'] -def is_inside_kubenetes_job(): +def is_inside_kubenetes_job() -> bool: if os.environ.get("KUBERNETES_SERVICE_HOST", None): return True return False -def is_inside_docker_job(): +def is_inside_docker_job() -> bool: if os.environ.get("FIBER_BACKEND", "") == "docker": return True return False @@ -42,7 +45,7 @@ def is_inside_docker_job(): } -def auto_select_backend(): +def auto_select_backend() -> str: for backend_name, test in BACKEND_TESTS.items(): if test(): name = backend_name @@ -53,7 +56,7 @@ def auto_select_backend(): return name -def get_backend(name=None, **kwargs): +def get_backend(name=None, **kwargs) -> Backend: """ Returns a working Fiber backend. If `name` is specified, returns a backend specified by `name`. 
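 
     Example (a sketch; "local" is one of the built-in backends):
 
         backend = get_backend("local")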
@@ -70,7 +73,7 @@ def get_backend(name=None, **kwargs): _backend = _backends.get(name, None) if _backend is None: - _backend = importlib.import_module("fiber.{}_backend".format( - name)).Backend(**kwargs) + _backend = importlib.import_module("fiber.{}_backend".format( # type: ignore + name)).Backend(**kwargs) # type: Backend _backends[name] = _backend return _backend From a84834cb174423ef2f2368086a13ee303b8b0b7f Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Fri, 4 Sep 2020 22:06:21 -0700 Subject: [PATCH 04/23] [type-hints] Add type hints to fiber/init.py --- fiber/init.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fiber/init.py b/fiber/init.py index 84145a1..b6af48c 100644 --- a/fiber/init.py +++ b/fiber/init.py @@ -22,7 +22,7 @@ import fiber.backend as fiber_backend -def init_logger(config, proc_name=None): +def init_logger(config: fiber_config.Config, proc_name: str = "") -> None: logger = logging.getLogger("fiber") if config.log_file.lower() == "stdout": handler = logging.StreamHandler() @@ -31,7 +31,7 @@ def init_logger(config, proc_name=None): if not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) - if not proc_name: + if proc_name == "": p = fiber.process.current_process() proc_name = p.name @@ -49,7 +49,7 @@ def init_logger(config, proc_name=None): logger.propagate = False -def init_fiber(proc_name=None, **kwargs): +def init_fiber(proc_name: str = "", **kwargs) -> None: """ Initialize Fiber. This function is called when you want to re-initialize Fiber with new config values and also re-init loggers. From 1d6d68f7ed81fefabbf387dd9b656492a5a295c5 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Fri, 4 Sep 2020 22:25:07 -0700 Subject: [PATCH 05/23] [type-hints] Add type hints to fiber/util.py --- fiber/util.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/fiber/util.py b/fiber/util.py index a9378d5..1b15ad2 100644 --- a/fiber/util.py +++ b/fiber/util.py @@ -10,7 +10,7 @@ import itertools import logging -import multiprocessing.util as mpu +import multiprocessing.util as mpu # type: ignore import os import re import sys @@ -19,9 +19,14 @@ import weakref import psutil +from typing import Any, Dict, Iterator, Tuple, Optional, Callable, Sequence +_afterfork_registry: weakref.WeakValueDictionary +_finalizer_counter: Iterator +_finalizer_registry: Dict[Tuple[Any, Any], "Finalize"] -logger = logging.getLogger('fiber') + +logger = logging.getLogger("fiber") _afterfork_registry = weakref.WeakValueDictionary() _afterfork_counter = itertools.count() @@ -30,27 +35,35 @@ _finalizer_counter = itertools.count() -def register_after_fork(obj, func): +def register_after_fork(obj, func) -> None: _afterfork_registry[(next(_afterfork_counter), id(obj), func)] = obj -def _run_after_forkers(): - logging.debug('_fun_after_forkers called') +def _run_after_forkers() -> None: + logging.debug("_fun_after_forkers called") items = list(_afterfork_registry.items()) items.sort() for (index, ident, func), obj in items: try: - logging.debug('run after forker %s(%s)', func, obj) + logging.debug("run after forker %s(%s)", func, obj) func(obj) except Exception as e: - logging.info('after forker raised exception %s', e) + logging.info("after forker raised exception %s", e) class Finalize(mpu.Finalize): """Basically this is the same as multiprocessing's Finalize class except this one uses it's own _finalizer_registry. 
""" - def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): + + def __init__( + self, + obj: Any, + callback: Callable, + args: Sequence = (), + kwargs: Dict[str, Any] = None, + exitpriority: int = None, + ) -> None: assert exitpriority is None or type(exitpriority) is int if obj is not None: @@ -67,7 +80,7 @@ def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): _finalizer_registry[self._key] = self -def find_ip_by_net_interface(target_interface): +def find_ip_by_net_interface(target_interface: str) -> Optional[str]: """Returns ip, debug_info.""" ifces = psutil.net_if_addrs() ip = None @@ -84,11 +97,11 @@ def find_ip_by_net_interface(target_interface): class ForkAwareThreadLock(object): - def __init__(self): + def __init__(self) -> None: self._reset() register_after_fork(self, ForkAwareThreadLock._reset) - def _reset(self): + def _reset(self) -> None: self._lock = threading.Lock() self.acquire = self._lock.acquire self.release = self._lock.release @@ -101,14 +114,14 @@ def __exit__(self, *args): class ForkAwareLocal(threading.local): - def __init__(self): + def __init__(self) -> None: register_after_fork(self, lambda obj: obj.__dict__.clear()) def __reduce__(self): return type(self), () -def find_listen_address(): +def find_listen_address() -> Tuple[Optional[str], Optional[str]]: """Find an IP address for Fiber to use.""" ip = None ifce = None @@ -124,8 +137,8 @@ def find_listen_address(): return ip, ifce -def is_in_interactive_console(): - if hasattr(sys, 'ps1'): +def is_in_interactive_console() -> bool: + if hasattr(sys, "ps1"): return True return False From fc55faaea69d07f75f974cb61b5bee4a64bc2e0f Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Fri, 4 Sep 2020 22:41:31 -0700 Subject: [PATCH 06/23] [type-hints] Add type hints to fiber/spawn.py --- fiber/spawn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fiber/spawn.py b/fiber/spawn.py index cb48f07..859d2f9 100644 --- a/fiber/spawn.py +++ b/fiber/spawn.py @@ -25,12 +25,13 @@ from fiber.init import init_fiber +from typing import Any, NoReturn logger = logging.getLogger('fiber') -def exit_by_signal(): +def exit_by_signal() -> NoReturn: logger.info("Exiting, sending SIGTERM to current process") os.kill(os.getpid(), signal.SIGTERM) @@ -41,7 +42,7 @@ def exit_by_signal(): os._exit(1) -def exit_on_fd_close(fd): +def exit_on_fd_close(fd) -> None: while True: rl, _, _ = select.select([fd], [], [], 0) if fd in rl: @@ -51,7 +52,7 @@ def exit_on_fd_close(fd): time.sleep(1) -def spawn_prepare(fd): +def spawn_prepare(fd) -> int: from_parent_r = os.fdopen(fd, "rb", closefd=False) preparation_data = reduction.pickle.load(from_parent_r) From d593fd775183139c88f73c5db2f3a15a65843819 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 12:15:57 -0700 Subject: [PATCH 07/23] Make `Backend` an abstract base class --- fiber/core.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/fiber/core.py b/fiber/core.py index a8a9abe..42e1340 100644 --- a/fiber/core.py +++ b/fiber/core.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from abc import ABC, abstractmethod import enum -from typing import Dict, List, NoReturn, Optional, Any, Union +from typing import Dict, List, NoReturn, Optional, Any, Union, Tuple class ProcessStatus(enum.Enum): @@ -89,18 +90,21 @@ def update(self): raise NotImplementedError -class Backend(object): +class Backend(ABC): @property + @abstractmethod def name(self): - raise NotImplementedError + pass - def create_job(self, job_spec: JobSpec): + @abstractmethod + def create_job(self, job_spec: JobSpec) -> Job: """This function is called when Fiber wants to create a new Process.""" - raise NotImplementedError + pass - def get_job_status(self, job: Job): + @abstractmethod + def get_job_status(self, job: Job) -> ProcessStatus: """This function is called when Fiber wants to to get job status.""" - raise NotImplementedError + pass def get_job_logs(self, job: Job) -> str: """ @@ -108,19 +112,22 @@ def get_job_logs(self, job: Job) -> str: """ return "" - def wait_for_job(self, job: Job, timeout: float): + @abstractmethod + def wait_for_job(self, job: Job, timeout: float) -> Optional[int]: """Wait for a specific job until timeout. If timeout is None, wait until job is done. Returns `None` if timed out or `exitcode` if job is finished. """ - raise NotImplementedError + pass - def terminate_job(self, job: Job): + @abstractmethod + def terminate_job(self, job: Job) -> None: """Terminate a job described by `job`.""" - raise NotImplementedError + pass - def get_listen_addr(self): + @abstractmethod + def get_listen_addr(self) -> Tuple[str, int, str]: """This function is called when Fiber wants to listen on a local address for incoming connection. It is currently used by Popen and Queue.""" - raise NotImplementedError + pass From 433926c7a98694a634e87d0de3b5d77effd3b3a1 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 12:17:45 -0700 Subject: [PATCH 08/23] [type-hints] Add type hints to fiber/local_backend.py --- fiber/local_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/fiber/local_backend.py b/fiber/local_backend.py index 48e8a8d..bc473f8 100644 --- a/fiber/local_backend.py +++ b/fiber/local_backend.py @@ -21,6 +21,7 @@ import fiber.core as core from fiber.core import ProcessStatus +from typing import Any, Tuple, Optional class Backend(core.Backend): @@ -31,17 +32,17 @@ class Backend(core.Backend): """ name = "local" - def __init__(self): + def __init__(self) -> None: pass - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> core.Job: proc = subprocess.Popen(job_spec.command) job = core.Job(proc, proc.pid) job.host = 'localhost' return job - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: proc = job.data if proc.poll() is not None: @@ -50,7 +51,7 @@ def get_job_status(self, job): return ProcessStatus.STARTED - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: proc = job.data if timeout == 0: @@ -63,10 +64,10 @@ def wait_for_job(self, job, timeout): return proc.returncode - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: proc = job.data proc.terminate() - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: return "127.0.0.1", 0, "lo" From e69b013e9ab5b82838e1bb43788965a228b0e90a Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 13:04:57 -0700 Subject: [PATCH 09/23] [type-hints] Add type hints to fiber/kubernetes_backend.py --- 
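
A rough sketch of how the annotated backend API fits together when driven
from outside the cluster ("demo" and the sleep command are placeholder
values, not code from this diff):

    from fiber.core import JobSpec
    from fiber.kubernetes_backend import Backend

    backend = Backend(incluster=False)
    job = backend.create_job(JobSpec(command=["sleep", "30"], name="demo"))
    status = backend.get_job_status(job)   # a ProcessStatus member
    code = backend.wait_for_job(job, 30)   # Optional[int] exit code
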
fiber/kubernetes_backend.py | 95 +++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/fiber/kubernetes_backend.py b/fiber/kubernetes_backend.py index d531d46..0726f5c 100644 --- a/fiber/kubernetes_backend.py +++ b/fiber/kubernetes_backend.py @@ -32,6 +32,7 @@ import fiber.config as fiber_config from fiber.core import ProcessStatus from fiber.util import find_ip_by_net_interface, find_listen_address +from typing import Any, Tuple, Optional logger = logging.getLogger("fiber") @@ -49,7 +50,7 @@ class Backend(core.Backend): name = "kubernetes" - def __init__(self, incluster=True): + def __init__(self, incluster: bool = True) -> None: if incluster: config.load_incluster_config() else: @@ -61,24 +62,23 @@ def __init__(self, incluster=True): if incluster: podname = socket.gethostname() - pod = self.core_api.read_namespaced_pod(podname, - self.default_namespace) + pod = self.core_api.read_namespaced_pod(podname, self.default_namespace) # Current model assume that Fiber only lauches 1 container per pod self.current_image = pod.spec.containers[0].image self.volumes = pod.spec.volumes self.mounts = pod.spec.containers[0].volume_mounts - #if pod.spec.volumes: + # if pod.spec.volumes: # self.current_mount = pod.spec.volumes[0].persistent_volume_claim.claim_name - #else: + # else: # self.current_mount = None else: self.current_image = None self.mounts = None self.volumes = None - def _get_resource_requirements(self, job_spec): - #requests = {} + def _get_resource_requirements(self, job_spec: core.JobSpec) -> Any: + # requests = {} limits = {} if job_spec.cpu: @@ -100,17 +100,13 @@ def _get_resource_requirements(self, job_spec): return None - - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> core.Job: logger.debug("[k8s]create_job: %s", job_spec) body = client.V1Pod() name = "{}-{}".format( - job_spec.name.replace("_", "-").lower(), - str(uuid.uuid4())[:8] - ) - body.metadata = client.V1ObjectMeta( - namespace=self.default_namespace, name=name + job_spec.name.replace("_", "-").lower(), str(uuid.uuid4())[:8] ) + body.metadata = client.V1ObjectMeta(namespace=self.default_namespace, name=name) # set environment varialbes # TODO(jiale) add environment variables @@ -118,22 +114,22 @@ def create_job(self, job_spec): image = job_spec.image if job_spec.image else self.current_image container = client.V1Container( - name=name, image=image, command=job_spec.command, env=[], - stdin=True, tty=True, + name=name, + image=image, + command=job_spec.command, + env=[], + stdin=True, + tty=True, ) rr = self._get_resource_requirements(job_spec) if rr: logger.debug( - "[k8s]create_job, container resource requirements: %s", - job_spec + "[k8s]create_job, container resource requirements: %s", job_spec ) container.resources = rr - body.spec = client.V1PodSpec( - containers=[container], - restart_policy="Never" - ) + body.spec = client.V1PodSpec(containers=[container], restart_policy="Never") # propagate mount points to new containers if necesary if job_spec.volumes: @@ -141,18 +137,14 @@ def create_job(self, job_spec): volume_mounts = [] for pd_name, mount_info in job_spec.volumes.items(): - #volume_name = job_spec.volume if job_spec.volume else self.current_mount - pvc = client.V1PersistentVolumeClaimVolumeSource( - claim_name=pd_name - ) + # volume_name = job_spec.volume if job_spec.volume else self.current_mount + pvc = client.V1PersistentVolumeClaimVolumeSource(claim_name=pd_name) volume = client.V1Volume( - persistent_volume_claim=pvc, - 
name="volume-" + pd_name, + persistent_volume_claim=pvc, name="volume-" + pd_name, ) volumes.append(volume) mount = client.V1VolumeMount( - mount_path=mount_info["bind"], - name=volume.name, + mount_path=mount_info["bind"], name=volume.name, ) if mount_info["mode"] == "r": mount.read_only = True @@ -165,15 +157,13 @@ def create_job(self, job_spec): logger.debug("[k8s]calling create_namespaced_pod: %s", body.metadata.name) try: - v1pod = self.core_api.create_namespaced_pod( - self.default_namespace, body - ) + v1pod = self.core_api.create_namespaced_pod(self.default_namespace, body) except ApiException as e: raise e return core.Job(v1pod, v1pod.metadata.uid) - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: v1pod = job.data name = v1pod.metadata.name namespace = v1pod.metadata.namespace @@ -197,7 +187,7 @@ def get_job_status(self, job): pod_status = v1pod.status return PHASE_STATUS_MAP[pod_status.phase] - def get_job_logs(self, job): + def get_job_logs(self, job: core.Job) -> str: v1job = job.data name = v1job.metadata.name namespace = v1job.metadata.namespace @@ -214,7 +204,7 @@ def get_job_logs(self, job): return logs - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: logger.debug("[k8s]wait_for_job timeout=%s", timeout) total = 0 @@ -249,11 +239,13 @@ def wait_for_job(self, job, timeout): logger.debug("[k8s]wait_for_job done: container is not terminated") return None - logger.debug("[k8s]wait_for_job done: container terminated with " - "code: {}".format(terminated.exit_code)) + logger.debug( + "[k8s]wait_for_job done: container terminated with " + "code: {}".format(terminated.exit_code) + ) return terminated.exit_code - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: v1job = job.data name = v1job.metadata.name namespace = v1job.metadata.namespace @@ -263,21 +255,22 @@ def terminate_job(self, job): try: logger.debug( "calling delete_namespaced_pod(%s, %s, grace_period_seconds=%s)", - name, namespace, grace_period_seconds, + name, + namespace, + grace_period_seconds, ) self.core_api.delete_namespaced_pod( - name, namespace, grace_period_seconds=grace_period_seconds, - body=body, + name, namespace, grace_period_seconds=grace_period_seconds, body=body, ) except ApiException as e: logger.debug( - "[k8s] Exception when calling " "delete_namespaced_pod: %s", - str(e), + "[k8s] Exception when calling " "delete_namespaced_pod: %s", str(e), ) raise e - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: ip = None + ifce: Optional[str] = None # if fiber.current_process() is multiprocessing.current_process(): if not isinstance(fiber.current_process(), fiber.Process): @@ -297,5 +290,15 @@ def get_listen_addr(self): "Can't find a usable IPv4 address to listen. ifce_name: {}, " "ifces: {}".format(ifce, psutil.net_if_addrs()) ) + + if ifce is None: + raise mp.ProcessError( + "Can't find a usable network interface to listen." 
+ "ifces: {}".format(psutil.net_if_addrs()) + ) + + ip_ret: str = ip + ifce_ret: str = ifce + # use 0 to bind to a random free port number - return ip, 0, ifce + return ip_ret, 0, ifce_ret From 54044d662999162c1d153c50ce29c7377c9f41de Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 13:07:38 -0700 Subject: [PATCH 10/23] [type-hints] Add type hints to fiber/docker_backend.py --- fiber/core.py | 1 + fiber/docker_backend.py | 44 ++++++++++++++++++++++++++++++----------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/fiber/core.py b/fiber/core.py index 42e1340..c6a57f1 100644 --- a/fiber/core.py +++ b/fiber/core.py @@ -84,6 +84,7 @@ def __init__(self, data: Any, jid: Union[str, int]) -> None: assert data is not None, "Job data is None" self.data = data self.jid = jid + self.host = "" def update(self): # update/refresh job attributes diff --git a/fiber/docker_backend.py b/fiber/docker_backend.py index 0059b06..84fbab9 100644 --- a/fiber/docker_backend.py +++ b/fiber/docker_backend.py @@ -31,6 +31,7 @@ import fiber.config as config from fiber.core import ProcessStatus from fiber.util import find_ip_by_net_interface, find_listen_address +from typing import Any, Optional, Tuple, TypeVar, Union logger = logging.getLogger('fiber') @@ -47,7 +48,11 @@ class DockerJob(core.Job): - def update(self): + def __init__(self, data: Any, jid: Union[str, int]) -> None: + super().__init__(data, jid) + self.update() + + def update(self) -> None: container = self.data self.host = container.attrs['NetworkSettings']['IPAddress'] @@ -55,12 +60,12 @@ def update(self): class Backend(core.Backend): name = "docker" - def __init__(self): + def __init__(self) -> None: # Based on this link, no lock is needed accessing self.client # https://github.com/docker/docker-py/issues/619 self.client = docker.from_env() - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> DockerJob: logger.debug("[docker]create_job: %s", job_spec) cwd = os.getcwd() volumes = {cwd: {'bind': cwd, 'mode': 'rw'}, @@ -101,7 +106,7 @@ def create_job(self, job_spec): container._fiber_backend_reloading = False return job - def _reload(self, container): + def _reload(self, container) -> None: container._fiber_backend_reloading = True logger.debug("container reloading %s", container.name) container.reload() @@ -112,12 +117,14 @@ def _reload(self, container): container._fiber_backend_reloading = False - def get_job_logs(self, job): + def get_job_logs(self, job: core.Job) -> str: container = job.data return container.logs(stream=False).decode('utf-8') - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: container = job.data + if container is None: + return ProcessStatus.UNKNOWN if config.merge_output: print(container.logs(stream=False).decode('utf-8')) @@ -131,8 +138,12 @@ def get_job_status(self, job): logger.debug("start container reloading thread %s", container.name) return status - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: container = job.data + if container is None: + # Job not started + return None + logger.debug("wait_for_job: %s", container.name) if config.merge_output: @@ -165,7 +176,7 @@ def wait_for_job(self, job, timeout): return res['StatusCode'] - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: logging.debug("terminate_job") container = job.data @@ -184,15 +195,16 @@ def terminate_job(self, job): raise e logger.debug("terminate job finished, %s", 
container.status) - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: ip = None + ifce: Optional[str] = None if sys.platform == "darwin": # use the same hostname for both master and non master process # because docker.for.mac.localhost resolves to different inside # and outside docker container. "docker.for.mac.localhost" is # the name that doesn't change in and outside the container. - return "docker.for.mac.localhost", 0 + return "docker.for.mac.localhost", 0, "eth0" if not isinstance(fiber.current_process(), fiber.Process): # not inside docker @@ -206,5 +218,15 @@ def get_listen_addr(self): raise mp.ProcessError( "Can't find a usable IPv4 address to listen. ifce_name: {}, " "ifces: {}".format(ifce, psutil.net_if_addrs())) + + if ifce is None: + raise mp.ProcessError( + "Can't find a usable network interface to listen." + "ifces: {}".format(psutil.net_if_addrs()) + ) + + ip_ret: str = ip + ifce_ret: str = ifce + # use 0 to bind to a random free port number - return ip, 0, ifce + return ip_ret, 0, ifce_ret From d934799c5daf92829724941b6a5d7890a116125e Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 15:30:53 -0700 Subject: [PATCH 11/23] [type-hints] Add type hints to fiber/__init__.py --- fiber/__init__.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/fiber/__init__.py b/fiber/__init__.py index 8901d3b..191f6f6 100644 --- a/fiber/__init__.py +++ b/fiber/__init__.py @@ -20,6 +20,12 @@ from fiber import context from fiber.init import init_fiber from fiber.meta import meta +from typing import List +from typing import TYPE_CHECKING + +__version__: str +_in_interactive_console: bool +_names: List[str] __version__ = "0.2.1" @@ -47,11 +53,11 @@ _in_interactive_console = False -def reset(): +def reset() -> None: init_fiber() -def init(**kwargs): +def init(**kwargs) -> None: """ Initialize Fiber. This function is called when you want to re-initialize Fiber with new config values and also re-init loggers. 
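 
     Example (illustrative; keyword names mirror the attributes of
     `fiber.config.Config`):
 
         fiber.init(backend="docker", image="fiber-test:latest")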
@@ -66,3 +72,15 @@ def init(**kwargs): globals().update((name, getattr(context._default_context, name)) for name in _names) __all__ = _names + [] + + +if TYPE_CHECKING: + current_process = context.FiberContext.current_process + active_children = context.FiberContext.active_children + Process = context.FiberContext.Process + Manager = context.FiberContext.Manager + Pool = context.FiberContext.Pool + SimpleQueue = context.FiberContext.SimpleQueue + Pipe = context.FiberContext.Pipe + cpu_count = context.FiberContext.cpu_count + get_context = context.FiberContext.get_context From 6d8f0b5a9cef812cdf523f521339600b39f4f0ea Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 15:34:26 -0700 Subject: [PATCH 12/23] [type-hints] Add type hints to fiber/backend.py --- fiber/backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fiber/backend.py b/fiber/backend.py index 8265377..a4d9c84 100644 --- a/fiber/backend.py +++ b/fiber/backend.py @@ -73,7 +73,8 @@ def get_backend(name=None, **kwargs) -> Backend: _backend = _backends.get(name, None) if _backend is None: - _backend = importlib.import_module("fiber.{}_backend".format( # type: ignore - name)).Backend(**kwargs) # type: Backend + backend_name = "fiber.{}_backend".format(name) + backend_module = importlib.import_module(backend_name) + _backend = backend_module.Backend(**kwargs) # type: ignore _backends[name] = _backend return _backend From 896ec502b6e8a79fb0ef31acc34905ab0ef117be Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 15:39:12 -0700 Subject: [PATCH 13/23] [type-hints] Add type hints to fiber/cli.py --- fiber/cli.py | 82 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/fiber/cli.py b/fiber/cli.py index fceb84d..687feb2 100644 --- a/fiber/cli.py +++ b/fiber/cli.py @@ -34,18 +34,27 @@ import fiber import fiber.core as core from fiber.core import ProcessStatus +import pathlib +from typing import Any, List, Tuple, TypeVar, Optional, Union, Dict + +_T0 = TypeVar("_T0") +_TDockerImageBuilder = TypeVar( + "_TDockerImageBuilder", bound="DockerImageBuilder" +) +CONFIG: Dict[str, str] CONFIG = {} -def get_backend(platform): +def get_backend(platform: str) -> fiber.kubernetes_backend.Backend: from fiber.kubernetes_backend import Backend as K8sBackend + backend = K8sBackend(incluster=False) return backend -def find_docker_files(): +def find_docker_files() -> List[pathlib.Path]: """Find all possible docker files on current directory.""" p = Path(".") q = p / "Dockerfile" @@ -58,7 +67,7 @@ def find_docker_files(): return files -def select_docker_file(files): +def select_docker_file(files: List[pathlib.Path]) -> pathlib.Path: """Ask user which docker file to use and return a PurePath object.""" num = 0 n = len(files) @@ -89,7 +98,7 @@ def select_docker_file(files): return files[num] -def get_default_project_gcp(): +def get_default_project_gcp() -> str: """Get default GCP project name.""" name = sp.check_output( "gcloud config list --format 'value(core.project)' 2>/dev/null", @@ -98,7 +107,7 @@ def get_default_project_gcp(): return name.decode("utf-8").strip() -def parse_file_path(path): +def parse_file_path(path: str) -> Tuple[Optional[str], str]: parts = path.split(":") if len(parts) == 1: return (None, path) @@ -112,7 +121,7 @@ def parse_file_path(path): @click.command() @click.argument("src") @click.argument("dst") -def cp(src, dst): +def cp(src: str, dst: str) -> None: """Copy file from a persistent storage""" platform = 
CONFIG["platform"] @@ -125,12 +134,14 @@ def cp(src, dst): ) if parts_src[0]: - volume = parts_src[0] - elif parts_dst[0]: - volume = parts_dst[0] + parsed_volume = parts_src[0] else: + parsed_volume = parts_dst[0] + + if parsed_volume is None: raise ValueError("Must copy/to from a persistent volume") + volume:str = parsed_volume k8s_backend = get_backend(platform) job_spec = core.JobSpec( @@ -140,6 +151,9 @@ def cp(src, dst): volumes={volume: {"mode": "rw", "bind": "/persistent"}}, ) job = k8s_backend.create_job(job_spec) + if job.data is None: + raise RuntimeError("Failed to create a new job for data copying") + pod_name = job.data.metadata.name print("launched pod: {}".format(pod_name)) @@ -170,7 +184,7 @@ def cp(src, dst): # k8s_backend.terminate_job(job) -def detect_platforms(): +def detect_platforms() -> List[str]: commands = ["gcloud", "aws"] platforms = ["gcp", "aws"] found_platforms = [] @@ -186,7 +200,7 @@ def detect_platforms(): return found_platforms -def prompt_choices(choices, prompt): +def prompt_choices(choices: List[_T0], prompt: str) -> _T0: num = 0 n = len(choices) @@ -216,13 +230,13 @@ def prompt_choices(choices, prompt): class DockerImageBuilder: - def __init__(self, registry=""): + def __init__(self, registry: str = "") -> None: self.registry = registry - def get_docker_registry_image_name(image_base_name): + def get_docker_registry_image_name(self, image_base_name: str) -> str: return image_base_name - def build(self): + def build(self) -> str: files = find_docker_files() n = len(files) if n == 0: @@ -248,25 +262,26 @@ def build(self): return self.full_image_name - def tag(self): + def tag(self) -> str: self.full_image_name = self.image_name + return self.full_image_name - def push(self): + def push(self) -> None: sp.check_call( "docker push {}".format(self.full_image_name), shell=True, ) - def docker_tag(self, in_name, out_name): + def docker_tag(self, in_name: str, out_name: str) -> None: sp.check_call("docker tag {} {}".format(in_name, out_name), shell=True) class AWSImageBuilder(DockerImageBuilder): - def __init__(self, registry): + def __init__(self, registry: str) -> None: self.registry = registry parts = registry.split(".") self.region = parts[-3] - def tag(self): + def tag(self) -> str: image_name = self.image_name full_image_name = "{}/{}".format(self.registry, self.image_name) @@ -275,7 +290,7 @@ def tag(self): self.full_image_name = full_image_name return full_image_name - def need_new_repo(self): + def need_new_repo(self) -> bool: output = sp.check_output( "aws ecr describe-repositories --region {}".format(self.region), shell=True, @@ -293,7 +308,7 @@ def need_new_repo(self): return True - def create_repo_if_needed(self): + def create_repo_if_needed(self) -> None: if self.need_new_repo(): sp.check_call( "aws ecr create-repository --region {} --repository-name {}".format( @@ -304,7 +319,7 @@ def create_repo_if_needed(self): return - def push(self): + def push(self) -> None: self.create_repo_if_needed() try: @@ -320,10 +335,10 @@ def push(self): class GCPImageBuilder(DockerImageBuilder): - def __init__(self, registry="gcr.io"): + def __init__(self, registry: str = "gcr.io") -> None: self.registry = registry - def tag(self): + def tag(self) -> str: image_name = self.image_name proj = get_default_project_gcp() @@ -343,7 +358,15 @@ def tag(self): @click.option("--memory") @click.option("-v", "--volume") @click.argument("args", nargs=-1) -def run(attach, image, gpu, cpu, memory, volume, args): +def run( + attach: bool, + image: str, + gpu: int, + cpu: int, + 
memory: int, + volume: str, + args: List[str], +) -> int: """Run a command on a kubernetes cluster with fiber.""" platform = CONFIG["platform"] print( @@ -352,6 +375,8 @@ def run(attach, image, gpu, cpu, memory, volume, args): ) ) + builder: DockerImageBuilder + if image: full_image_name = image else: @@ -393,6 +418,9 @@ def run(attach, image, gpu, cpu, memory, volume, args): job_spec.volumes = volumes job = k8s_backend.create_job(job_spec) + if job.data is None: + raise RuntimeError("Failed to create a new job") + pod_name = job.data.metadata.name exitcode = 0 @@ -414,7 +442,7 @@ def run(attach, image, gpu, cpu, memory, volume, args): return 0 -def auto_select_platform(): +def auto_select_platform() -> str: platforms = detect_platforms() if len(platforms) > 1: choice = prompt_choices( @@ -438,7 +466,7 @@ def auto_select_platform(): "--gcp", is_flag=True, help="Run commands on Google Cloud Platform" ) @click.version_option(version=fiber.__version__, prog_name="fiber") -def main(docker_registry, aws, gcp): +def main(docker_registry: str, aws: bool, gcp: bool) -> None: """fiber command line tool that helps to manage workflow of distributed fiber applications. """ From 4b10f7b0a8cda115164a37e2c7d10c9394756bc1 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 15:47:07 -0700 Subject: [PATCH 14/23] [type-hints] Add type hints to fiber/context.py --- fiber/context.py | 48 ++++++++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/fiber/context.py b/fiber/context.py index 181e585..b6cb70e 100644 --- a/fiber/context.py +++ b/fiber/context.py @@ -15,53 +15,67 @@ import os import fiber.config as config from fiber import process +from typing import Optional, Tuple, Callable, Sequence +from fiber.managers import SyncManager +from fiber.pool import ZPool, ResilientZPool +from fiber.queues import SimpleQueuePush, LazyZConnection +from fiber.queues import Pipe -class FiberContext(): - _name = '' +_default_context: "FiberContext" + + +class FiberContext: + _name = "" Process = process.Process current_process = staticmethod(process.current_process) active_children = staticmethod(process.active_children) - def Manager(self): + def Manager(self) -> SyncManager: """Returns a manager associated with a running server process The managers methods such as `Lock()`, `Condition()` and `Queue()` can be used to create shared objects. 
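 
         Example (a sketch of typical use):
 
             manager = fiber.Manager()
             queue = manager.Queue()
             queue.put("ready")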
""" - from fiber.managers import SyncManager m = SyncManager() m.start() return m - def Pool(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, error_handling=False): + def Pool( + self, + processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + error_handling: bool = False, + ) -> ZPool: """Returns a process pool object""" - from .pool import ZPool, ResilientZPool if error_handling: - return ResilientZPool(processes, initializer, initargs, maxtasksperchild) + return ResilientZPool( + processes, initializer, initargs, maxtasksperchild + ) else: return ZPool(processes, initializer, initargs, maxtasksperchild) - def SimpleQueue(self): + def SimpleQueue(self) -> SimpleQueuePush: """Returns a queue object""" if config.use_push_queue: - from .queues import SimpleQueuePush return SimpleQueuePush() # PullQueue is not supported anymore raise NotImplementedError - def Pipe(self, duplex=True): + def Pipe( + self, duplex: bool = True + ) -> Tuple[LazyZConnection, LazyZConnection]: """Returns two connection object connected by a pipe""" - from .queues import Pipe return Pipe(duplex) - def cpu_count(self): + def cpu_count(self) -> Optional[int]: return os.cpu_count() - def get_context(self, method=None): + def get_context(self, method: str = None) -> "FiberContext": if method is None: return self if method != "spawn": @@ -69,8 +83,6 @@ def get_context(self, method=None): return _concrete_contexts[method] -_concrete_contexts = { - 'spawn': FiberContext() -} +_concrete_contexts = {"spawn": FiberContext()} -_default_context = _concrete_contexts['spawn'] +_default_context = _concrete_contexts["spawn"] From ad30a73a6eb5edfb1887a4e87bbbc633086140a9 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sat, 5 Sep 2020 15:50:43 -0700 Subject: [PATCH 15/23] [type-hints] Add type hints to fiber/meta.py --- fiber/meta.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fiber/meta.py b/fiber/meta.py index 6385b85..7002963 100644 --- a/fiber/meta.py +++ b/fiber/meta.py @@ -13,10 +13,13 @@ # limitations under the License. +from typing import Any, Callable, Dict + + VALID_META_KEYS = ["cpu", "memory", "gpu"] -def post_process(metadata): +def post_process(metadata: Dict) -> Dict: # memory should be in MB if "memory" in metadata: memory = metadata.pop("memory") @@ -25,7 +28,7 @@ def post_process(metadata): return metadata -def meta(**kwargs): +def meta(**kwargs) -> Callable[[Any], Any]: """ fiber.meta API allows you to decorate your function and provide some hints to Fiber. 
     Currently this is mainly used to specify the resource usage of user

From 6d633a4797e4928929a0f6b1e00e8db2cd72f191 Mon Sep 17 00:00:00 2001
From: Jiale Zhi
Date: Sat, 5 Sep 2020 17:43:38 -0700
Subject: [PATCH 16/23] [type-hints] Add type hints to fiber/popen_fiber_spawn.py

---
 fiber/popen_fiber_spawn.py | 196 ++++++++++++++++++++++---------------
 1 file changed, 118 insertions(+), 78 deletions(-)

diff --git a/fiber/popen_fiber_spawn.py b/fiber/popen_fiber_spawn.py
index f4726a3..3f6a837 100644
--- a/fiber/popen_fiber_spawn.py
+++ b/fiber/popen_fiber_spawn.py
@@ -36,6 +36,21 @@
 from fiber.backend import get_backend
 from fiber.core import JobSpec
 from fiber.core import ProcessStatus
+from typing import (
+    Any,
+    Dict,
+    List,
+    Iterator,
+    NoReturn,
+    Optional,
+    BinaryIO,
+    Tuple,
+)
+
+_event_counter: Iterator[int]
+_event_dict: Dict[int, "EventConn"]
+_fiber_background_thread: Optional[threading.Thread]
+_fiber_background_thread_lock: threading.Lock
 
 logger = logging.getLogger("fiber")
@@ -84,7 +99,17 @@
 _event_counter = itertools.count(1)
 
 
-def get_fiber_init():
+class EventConn:
+    def __init__(
+        self,
+        conn: Optional[socket.socket] = None
+    ):
+        self.event = threading.Event()
+        self.event.clear()
+        self.conn = conn
+
+
+def get_fiber_init() -> str:
     if config.ipc_active:
         fiber_init = fiber_init_start + fiber_init_net_active + fiber_init_end
     else:
@@ -94,7 +119,9 @@
     return fiber_init
 
 
-def fiber_background(listen_addr, event_dict):
+def fiber_background(
+    listen_addr: Tuple[str, int], event_dict: Dict[int, "EventConn"]
+) -> None:
     global admin_host, admin_port
 
     # Background thread for handling inter fiber process admin traffic
@@ -117,7 +144,7 @@ def fiber_background(listen_addr, event_dict):
     sentinel = event_dict[-1]
 
     # notify master thread that background thread is ready
-    sentinel.set()
+    sentinel.event.set()
     logger.debug("fiber_background thread ready")
     while True:
         conn, addr = sock.accept()
@@ -127,30 +154,31 @@ def fiber_background(listen_addr, event_dict):
             # struct.unpack returns a tuple even if the result only has
             # one element
             ident = struct.unpack("<I", data)[0]
-            event = event_dict[ident]
-            event_dict[ident] = conn
-            event.set()
+            ec = event_dict[ident]
+            ec.conn = conn
+            ec.event.set()
 
 
-def _get_python_exe(backend_name):
+def _get_python_exe(backend_name) -> str:
     if backend_name == "docker":
         # TODO(jiale) fix python path
         python_exe = "/usr/local/bin/python"
     else:
         python_exe = sys.executable
-    logger.debug("backend is \"%s\", use python exe \"%s\"",
-                 backend_name, python_exe)
+    logger.debug(
+        'backend is "%s", use python exe "%s"', backend_name, python_exe
+    )
 
     return python_exe
 
 
-def get_pid_from_jid(jid):
+def get_pid_from_jid(jid: Any) -> int:
     # Some Linux systems have 32768 as max pid number. 32749 is a prime number
     # close to 32768.
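     # `jid` may be a string (for example a Kubernetes pod uid), so it is
     # hashed down to a stable pid-sized integer.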
return hash(jid) % 32749 @@ -159,20 +187,26 @@ def get_pid_from_jid(jid): class Popen(object): method = "spawn" - def __del__(self): - if getattr(self, "ident", None): + def __del__(self) -> None: + ident = getattr(self, "ident", None) + if ident is not None: # clean up entry in event_dict global _event_dict - #logger.debug("cleanup entry _event_dict[%s]", self.ident) - _event_dict.pop(self.ident, None) + if ident in _event_dict: + del _event_dict[ident] - def __repr__(self): + def __repr__(self) -> str: return "<{}({})>".format( type(self).__name__, getattr(self, "process_obj", None) ) - def __init__(self, process_obj, backend=None, launch=False): - self.returncode = None + def __init__( + self, + process_obj: "fiber.process.Process", + backend=None, + launch: bool = False, + ) -> None: + self.returncode: Optional[int] = None self.backend = get_backend() ip, _, _ = self.backend.get_listen_addr() @@ -181,20 +215,20 @@ def __init__(self, process_obj, backend=None, launch=False): self.master_port = config.ipc_admin_master_port self.worker_port = config.ipc_admin_worker_port - self.sock = None - self.host = "" + self.sock: Optional[socket.socket] = None + self.host: str = "" - self.job = None - self.pid = None + self.job: Optional[fiber.core.Job] = None + self.pid: Optional[int] = None self.process_obj = process_obj - self._exiting = None - self.sentinel = None - self.ident = None + self._exiting: bool = False + self.sentinel: Optional[socket.socket] = None + self.ident: int = -1 if launch: self._launch(process_obj) - def launch_fiber_background_thread_if_needed(self): + def launch_fiber_background_thread_if_needed(self) -> None: global _fiber_background_thread_lock _fiber_background_thread_lock.acquire() @@ -203,34 +237,32 @@ def launch_fiber_background_thread_if_needed(self): _fiber_background_thread_lock.release() return - try: - logger.debug( - "_fiber_background_thread is None, creating " - "background thread" - ) - # Create a background thread to handle incoming connections - # from fiber child processes - event = threading.Event() - event.clear() - _event_dict[-1] = event - td = threading.Thread( - target=fiber_background, - args=((self.master_host, self.master_port), _event_dict), - daemon=True, - ) - td.start() - except Exception as e: - raise e - finally: - logger.debug("waiting for background thread") - event.wait() - logger.debug( - "master received message that fiber_background thread is ready" - ) - _fiber_background_thread = td - _fiber_background_thread_lock.release() + ec = EventConn() - def get_command_line(self, **kwds): + logger.debug( + "_fiber_background_thread is None, creating " + "background thread" + ) + # Create a background thread to handle incoming connections + # from fiber child processes + _event_dict[-1] = ec + + td = threading.Thread( + target=fiber_background, + args=((self.master_host, self.master_port), _event_dict), + daemon=True, + ) + td.start() + + logger.debug("waiting for background thread") + ec.event.wait() + logger.debug( + "master received message that fiber_background thread is ready" + ) + _fiber_background_thread = td + _fiber_background_thread_lock.release() + + def get_command_line(self, **kwds) -> List[str]: """Returns prefix of command line used for spawning a child process.""" prog = get_fiber_init() prog = prog.format(**kwds) @@ -248,13 +280,7 @@ def get_command_line(self, **kwds): + ["-c", prog, "--multiprocessing-fork"] ) - def _accept(self): - conn, addr = self.sock.accept() - logger.debug("successfully accept") - # TODO verify if 
it's the same client - return conn - - def _get_job(self, cmd): + def _get_job(self, cmd: List[str]) -> fiber.core.JobSpec: spec = JobSpec( command=cmd, image=config.image, @@ -274,24 +300,29 @@ def _get_job(self, cmd): return spec - def _run_job(self, job): + def _run_job(self, job_spec: fiber.core.JobSpec) -> fiber.core.Job: + try: - job = self.backend.create_job(job) + job = self.backend.create_job(job_spec) except requests.exceptions.ReadTimeout as e: raise mp.TimeoutError(str(e)) + self.job = job return job - def _sentinel_readable(self, timeout=0): + def _sentinel_readable(self, timeout: int = 0) -> int: # Use fcntl(fd, F_GETFD) instead of select.* becuase: # * select.select() can't work with fd > 1024 # * select.poll() is not thread safe # * select.epoll() is Linux only # Also, fcntl(fd, F_GETFD) is cheaper than the above calls. + if self.sentinel is None: + return False + return fcntl.fcntl(self.sentinel, fcntl.F_GETFD) - def poll(self, flag=os.WNOHANG): + def poll(self, flag: int = os.WNOHANG) -> Optional[int]: # returns None if the process is not stopped yet. Otherwise, returns # process exit code. @@ -327,7 +358,7 @@ def poll(self, flag=os.WNOHANG): return None return self.wait(timeout=0) - def wait(self, timeout=None): + def wait(self, timeout: int = None) -> Optional[int]: if self.job is None: # self.job is None meaning this process hasn't been fully started # yet. @@ -345,7 +376,7 @@ def wait(self, timeout=None): self.returncode = code return self.returncode - def _pickle_data(self, data, fp): + def _pickle_data(self, data, fp: BinaryIO) -> None: if fiber.util.is_in_interactive_console(): logger.debug("in interactive shell, use cloudpickle") cloudpickle.dump(data, fp) @@ -353,7 +384,7 @@ def _pickle_data(self, data, fp): logger.debug("not in interactive shell, use reduction") reduction.dump(data, fp) - def _launch(self, process_obj): + def _launch(self, process_obj) -> None: logger.debug("%s %s _launch called", process_obj, self) if config.ipc_active: @@ -389,15 +420,15 @@ def _launch(self, process_obj): cwd=os.getcwd(), host=admin_host, port=port, id=ident ) - job = self._get_job(cmd) + job_spec = self._get_job(cmd) + + ec = EventConn() - event = threading.Event() - event.clear() - _event_dict[ident] = event + _event_dict[ident] = ec logger.debug( "%s popen_fiber_spawn created event %s and set _event_dict[%s]", self, - event, + ec.event, ident, ) @@ -427,7 +458,7 @@ def _launch(self, process_obj): process_obj._popen = self # launch job - job = self._run_job(job) + job = self._run_job(job_spec) self.pid = get_pid_from_jid(job.jid) # Fix process obj's pid process_obj.ident = self.pid @@ -447,16 +478,22 @@ def _launch(self, process_obj): "connect back" ) return - done = event.wait(0.5) + done = ec.event.wait(0.5) status = self.check_status() if status == ProcessStatus.STOPPED: return logger.debug( "popen_fiber_spawn is waiting for accept event %s to finish", - event, + ec.event, ) - conn = _event_dict[ident] + conn = ec.conn + if conn is None: + raise mp.ProcessError( + "ec.conn should be set at this moment but it is None. " + "Please report this bug to Fiber developers." 
+ ) + logger.debug("got conn from _event_counter[%s]", ident) del _event_dict[ident] logger.debug("remove entry _event_counter[%s]", ident) @@ -511,7 +548,10 @@ def _launch(self, process_obj): self.sentinel = conn logger.debug("_launch finished") - def check_status(self): + def check_status(self) -> fiber.core.ProcessStatus: + if self.job is None: + return ProcessStatus.UNKNOWN + status = self.backend.get_job_status(self.job) if status == ProcessStatus.STOPPED: # something happened that caused Fiber process to hit an early stop @@ -525,7 +565,7 @@ def check_status(self): return status - def terminate(self): + def terminate(self) -> None: logger.debug("[Popen]terminate() called") self._exiting = True From 0e7cf77e588ee32771d171dae2219ccf55c7de1f Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sun, 6 Sep 2020 11:48:00 -0700 Subject: [PATCH 17/23] `Backend.wait_for_job` should accept `None` as `timeout` argument --- fiber/core.py | 2 +- fiber/docker_backend.py | 2 +- fiber/kubernetes_backend.py | 2 +- fiber/local_backend.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fiber/core.py b/fiber/core.py index c6a57f1..03067dc 100644 --- a/fiber/core.py +++ b/fiber/core.py @@ -114,7 +114,7 @@ def get_job_logs(self, job: Job) -> str: return "" @abstractmethod - def wait_for_job(self, job: Job, timeout: float) -> Optional[int]: + def wait_for_job(self, job: Job, timeout: Optional[float]) -> Optional[int]: """Wait for a specific job until timeout. If timeout is None, wait until job is done. Returns `None` if timed out or `exitcode` if job is finished. diff --git a/fiber/docker_backend.py b/fiber/docker_backend.py index 84fbab9..d64910a 100644 --- a/fiber/docker_backend.py +++ b/fiber/docker_backend.py @@ -138,7 +138,7 @@ def get_job_status(self, job: core.Job) -> ProcessStatus: logger.debug("start container reloading thread %s", container.name) return status - def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: container = job.data if container is None: # Job not started diff --git a/fiber/kubernetes_backend.py b/fiber/kubernetes_backend.py index 0726f5c..4f6d353 100644 --- a/fiber/kubernetes_backend.py +++ b/fiber/kubernetes_backend.py @@ -204,7 +204,7 @@ def get_job_logs(self, job: core.Job) -> str: return logs - def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: logger.debug("[k8s]wait_for_job timeout=%s", timeout) total = 0 diff --git a/fiber/local_backend.py b/fiber/local_backend.py index bc473f8..38049f9 100644 --- a/fiber/local_backend.py +++ b/fiber/local_backend.py @@ -51,7 +51,7 @@ def get_job_status(self, job: core.Job) -> ProcessStatus: return ProcessStatus.STARTED - def wait_for_job(self, job: core.Job, timeout: float) -> Optional[int]: + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: proc = job.data if timeout == 0: From 0496d3858528de93f1f630770e775ac6de37bcf8 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sun, 6 Sep 2020 12:13:22 -0700 Subject: [PATCH 18/23] [type-hints] Add type hints to fiber/process.py Also, create a new `Process.target` property so that `Popen` doesn't need to access Process's private attribute. 
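
For reference, a sketch of the call path this property serves (`train` is
an illustrative name, not code from this diff):

    @fiber.meta(cpu=4)
    def train():
        ...

    p = fiber.Process(target=train)
    p.start()  # Popen reads train.__fiber_meta__ through p.target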
--- fiber/popen_fiber_spawn.py | 6 ++--- fiber/process.py | 55 +++++++++++++++++++++++++------------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/fiber/popen_fiber_spawn.py b/fiber/popen_fiber_spawn.py index 3f6a837..1388aae 100644 --- a/fiber/popen_fiber_spawn.py +++ b/fiber/popen_fiber_spawn.py @@ -288,12 +288,12 @@ def _get_job(self, cmd: List[str]) -> fiber.core.JobSpec: cpu=config.cpu_per_job, mem=config.mem_per_job, ) - if hasattr(self.process_obj._target, "__self__"): + if hasattr(self.process_obj.target, "__self__"): metadata = getattr( - self.process_obj._target.__self__, "__fiber_meta__", None + self.process_obj.target.__self__, "__fiber_meta__", None ) else: - metadata = getattr(self.process_obj._target, "__fiber_meta__", None) + metadata = getattr(self.process_obj.target, "__fiber_meta__", None) if metadata: for k, v in metadata.items(): setattr(spec, k, v) diff --git a/fiber/process.py b/fiber/process.py index e411834..57ecf74 100644 --- a/fiber/process.py +++ b/fiber/process.py @@ -36,23 +36,29 @@ import logging from multiprocessing.process import BaseProcess +from .popen_fiber_spawn import Popen +from typing import Callable, Sequence, Tuple, Optional, List, Set, Iterator logger = logging.getLogger('fiber') +_children: Set["Process"] +_current_process: "BaseProcess" +_process_counter: Iterator[int] + _children = set() _current_process = mp.current_process() -def _cleanup(): +def _cleanup() -> None: # check for processes which have finished for p in list(_children): if p._popen.poll() is not None: _children.discard(p) -def active_children(): +def active_children() -> List["Process"]: """ Get a list of children processes of the current process. @@ -69,7 +75,7 @@ def active_children(): return list(_children) -def current_process(): +def current_process() -> "BaseProcess": """Return a Process object representing the current process. Example: @@ -158,23 +164,27 @@ class Process(BaseProcess): can call `select` and other eligible functions that works on fds on this file descriptor. """ + _popen: fiber.popen_fiber_spawn.Popen + _name: str + _pid: Optional[int] + _start_method = None _pid = None @staticmethod - def _Popen(process_obj): - from .popen_fiber_spawn import Popen + def _Popen(process_obj: "Process") -> Popen: return Popen(process_obj) - def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, - *, daemon=None): + def __init__(self, group: None = None, target: Callable = None, + name: str = None, args: Tuple = (), kwargs={}, + *, daemon: bool = None): super(Process, self).__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) self._parent_pid = current_process().pid # set when Process.start() failed self._start_failed = False - def __repr__(self): + def __repr__(self) -> str: return "{}({}, {})>".format(type(self).__name__, self._name, self.ident) def run(self): @@ -184,7 +194,7 @@ def run(self): """ return super().run() - def start(self): + def start(self) -> None: """Start this process. Under the hood, Fiber calls the API on the computer cluster to start a @@ -214,7 +224,7 @@ def start(self): del self._target, self._args, self._kwargs _children.add(self) - def terminate(self): + def terminate(self) -> None: """Terminate current process. When running locally, Fiber sends an SIGTERM signal to the child @@ -227,7 +237,7 @@ def terminate(self): return self._popen.terminate() - def join(self, timeout=None): + def join(self, timeout=None) -> None: """Wait for this process to terminate. 
:param timeout: The maximum duration of time in seconds that this call @@ -236,11 +246,11 @@ def join(self, timeout=None): `timeout` is `0`, it will check if the process has exited and return immediately. - :returns: The exit code of this process + :returns: None if process terminates or the method times out """ return super().join(timeout=timeout) - def is_alive(self): + def is_alive(self) -> bool: """Check if current process is still alive :returns: `True` if current process is still alive. Returns `False` if @@ -249,19 +259,24 @@ def is_alive(self): return super().is_alive() @property - def ident(self): + def ident(self) -> Optional[int]: if self._pid is None: - self._pid = self._popen and self._popen.pid + self._pid = self._popen.pid if self._popen else None return self._pid @ident.setter - def ident(self, pid): + def ident(self, pid: int): self._pid = pid pid = ident - def _bootstrap(self): + @property + def target(self): + return self._target + + + def _bootstrap(self) -> Tuple[int, Optional[str]]: from multiprocessing import util, context global _current_process, _process_counter, _children err = None @@ -279,7 +294,8 @@ def _bootstrap(self): fiber.util._finalizer_registry.clear() fiber.util._run_after_forkers() except Exception as e: - err = e + import traceback + err = traceback.format_exc() finally: # delay finalization of the old process object until after # _run_after_forkers() is executed @@ -304,7 +320,8 @@ def _bootstrap(self): sys.stderr.write(str(e.args[0]) + '\n') exitcode = 1 except Exception as e: # noqa E722 - err = e + import traceback + err = traceback.format_exc() exitcode = 1 import traceback msg = traceback.format_exc() From 75a544aa44ece111e24f7f3f339ff2930c7a2be7 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sun, 6 Sep 2020 12:25:11 -0700 Subject: [PATCH 19/23] [type-hints] Add type hints to fiber/socket.py --- fiber/socket.py | 143 +++++++++++++++++++++++++++--------------------- 1 file changed, 82 insertions(+), 61 deletions(-) diff --git a/fiber/socket.py b/fiber/socket.py index c277c2a..af63a12 100644 --- a/fiber/socket.py +++ b/fiber/socket.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from abc import ABC, abstractmethod import random import logging import multiprocessing as mp import threading from fiber.backend import get_backend +from typing import Any, NoReturn, Optional, Tuple MIN_PORT = 40000 @@ -25,7 +27,7 @@ logger = logging.getLogger("fiber") socket_lib = "nanomsg" -default_socket_ctx = None +default_socket_ctx: "SockContext" if socket_lib == "nanomsg": @@ -41,11 +43,11 @@ raise ValueError("bad socket_lib value: {}".format(socket_lib)) -def get_ctx(): +def get_ctx() -> "SockContext": return default_socket_ctx -def bind_to_random_port(sock, addr_base, max_tries=100, nng=True): +def bind_to_random_port(sock, addr_base: str, max_tries:int = 100, nng: bool = True) -> Optional[int]: num_tries = 0 while num_tries < max_tries: try: @@ -60,32 +62,36 @@ def bind_to_random_port(sock, addr_base, max_tries=100, nng=True): num_tries += 1 continue - return + return None -class SockContext: - default_addr = None +class SockContext(ABC): - def new(self, mode): - raise NotImplementedError + default_addr:str = "" + + @abstractmethod + def new(self, mode: str) -> "Socket": + pass @staticmethod - def bind_random(sock, addr): - raise NotImplementedError + @abstractmethod + def bind_random(sock, addr: str) -> int: + pass @staticmethod - def connect(sock, addr): - raise NotImplementedError + @abstractmethod + def connect(sock, addr: str) -> None: + pass @staticmethod - def close(sock): + def close(sock) -> None: sock.close() class ZMQContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> None: self._mode_to_type = { "r": zmq.DEALER, @@ -96,24 +102,28 @@ def __init__(self): } self.context = zmq.Context.instance() - def new(self, mode): + def new(self, mode: str) -> "zmq.Socket": sock_type = self._mode_to_type[mode] if sock_type is None: return None return self.context.socket(sock_type) @staticmethod - def bind_random(sock, addr): + def bind_random(sock, addr: str) -> int: assert type(addr) == str - return sock.bind_to_random_port( + port = sock.bind_to_random_port( addr, min_port=MIN_PORT, max_port=MAX_PORT, max_tries=100 ) + if port is None: + raise RuntimeError("ZMQContext Failed to bind to a random port") + + return port @staticmethod - def connect(sock, addr): + def connect(sock, addr) -> Any: return sock.connect(addr) - def device(self, s1_mode, s2_mode): + def device(self, s1_mode: str, s2_mode: str) -> Tuple["zmq.devices.ThreadDevice", str, str]: backend = get_backend() ip_ext, _, _ = backend.get_listen_addr() @@ -140,7 +150,7 @@ class NNGDevice: STATE_AFTER_INIT = 1 STATE_FINISHED = 2 - def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): + def __init__(self, ctx: "NNGContext", s1_mode: str, s2_mode: str, default_addr: str = "tcp://0.0.0.0") -> None: self._mode_to_opener = { "r": pynng.lib.nng_pull0_open_raw, @@ -153,7 +163,7 @@ def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): self.default_addr = default_addr self._start_process() - def _start_process(self): + def _start_process(self) -> None: parent_conn, child_conn = mp.Pipe() self._proc = threading.Thread( target=self._run, args=(child_conn,), daemon=True @@ -168,15 +178,7 @@ def _start_process(self): #child_conn.close() self.conn = parent_conn - def _mode_to_opener(self, mode): - opener = self._mode_to_opener[mode] - if opener is None: - raise ValueError( - "Mode {} not supported by {}", mode, self.__class__.__name__ - ) - return opener - - def _create_socks(self): + def _create_socks(self) -> Tuple["pynng.Socket", "pynng.Socket"]: 
opener1 = self._mode_to_opener[self.s1_mode] opener2 = self._mode_to_opener[self.s2_mode] @@ -186,18 +188,18 @@ def _create_socks(self): return s1, s2 - def _bind_socks(self, s1, s2): + def _bind_socks(self, s1: "pynng.Socket", s2: "pynng.Socket") -> Tuple[Optional[int], Optional[int]]: port1 = bind_to_random_port(s1, self.default_addr) port2 = bind_to_random_port(s2, self.default_addr) return port1, port2 - def _run_device(self, s1, s2): + def _run_device(self, s1: "pynng.Socket", s2: "pynng.Socket") -> None: ret = pynng.lib.nng_device(s1.socket, s2.socket) check_err(ret) - def _run(self, conn): + def _run(self, conn: "mp.connection.Connection") -> None: s1, s2 = self._create_socks() state = NNGDevice.STATE_INIT @@ -233,7 +235,7 @@ def _run(self, conn): else: break - def bind(self): + def bind(self) -> Tuple[str, str]: self.conn.send("#bind") in_addr, out_addr = self.conn.recv() if in_addr is None: @@ -242,7 +244,7 @@ def bind(self): self.out_addr = out_addr return in_addr, out_addr - def start(self): + def start(self) -> None: self.conn.send("#start") code = self.conn.recv() if code is not None: @@ -252,7 +254,7 @@ def start(self): class NNGContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> None: self._mode_to_creator = { "r": pynng.Pull0, @@ -262,21 +264,25 @@ def __init__(self): "rep": pynng.Rep0, } - def new(self, mode): + def new(self, mode: str) -> "pynng.Socket": func = self._mode_to_creator[mode] if func is None: return None return func() @staticmethod - def bind_random(sock, addr): - return bind_to_random_port(sock, addr) + def bind_random(sock, addr: str) -> int: + port = bind_to_random_port(sock, addr) + if port is None: + raise RuntimeError("NNGContext Failed to bind to a random port") + + return port @staticmethod - def connect(sock, addr): + def connect(sock, addr: str) -> "pynng.Dialer": return sock.dial(addr) - def device(self, s1_mode, s2_mode): + def device(self, s1_mode: str, s2_mode: str) -> Tuple[NNGDevice, str, str]: self.s1_mode = s1_mode self.s2_mode = s2_mode @@ -295,7 +301,7 @@ def device(self, s1_mode, s2_mode): class NanomsgDevice(NNGDevice): - def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): + def __init__(self, ctx: "NanomsgContext", s1_mode: str, s2_mode: str, default_addr: str ="tcp://0.0.0.0") -> None: self.s1_mode = s1_mode self.s2_mode = s2_mode self.default_addr = default_addr @@ -303,18 +309,25 @@ def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): self._start_process() - def _create_socks(self): + def _create_socks(self) -> Tuple[nnpy.Socket, nnpy.Socket]: s1 = nnpy.Socket(nnpy.AF_SP_RAW, self.ctx._mode_to_type[self.s1_mode]) s2 = nnpy.Socket(nnpy.AF_SP_RAW, self.ctx._mode_to_type[self.s2_mode]) return s1, s2 - def _bind_socks(self, s1, s2): + def _bind_socks(self, s1: nnpy.Socket, s2: nnpy.Socket) -> Tuple[int, int]: port1 = bind_to_random_port(s1, self.default_addr, nng=False) port2 = bind_to_random_port(s2, self.default_addr, nng=False) + + if port1 is None: + raise RuntimeError("NanomsgDevice failed to bind port1") + + if port2 is None: + raise RuntimeError("NanomsgDevice failed to bind port2") + return port1, port2 - def _run_device(self, s1, s2): + def _run_device(self, s1: nnpy.Socket, s2: nnpy.Socket) -> Any: rc = nnpy.nanomsg.nn_device(s1.sock, s2.sock) return nnpy.errors.convert(rc, rc) @@ -323,7 +336,7 @@ def _run_device(self, s1, s2): class NanomsgContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> 
None: self._mode_to_type = { "r": nnpy.PULL, @@ -333,22 +346,30 @@ def __init__(self): "req": nnpy.REQ, }
- def new(self, mode): + def new(self, mode) -> nnpy.Socket: sock_type = self._mode_to_type[mode] if sock_type is None: - return None + raise RuntimeError( + "NanomsgContext got invalid mode: {}".format(mode) + ) + return nnpy.Socket(nnpy.AF_SP, sock_type)
@staticmethod - def bind_random(sock, addr): - return bind_to_random_port(sock, addr, nng=False) + def bind_random(sock, addr) -> int: + port = bind_to_random_port(sock, addr, nng=False) + + if port is None: + raise RuntimeError("NanomsgContext Failed to bind to a random port") + + return port
@staticmethod - def connect(sock, addr): + def connect(sock, addr) -> Any: return sock.connect(addr)
- def device(self, s1_mode, s2_mode): + def device(self, s1_mode, s2_mode) -> Tuple[NanomsgDevice, str, str]: self.s1_mode = s1_mode self.s2_mode = s2_mode @@ -377,13 +398,13 @@ def device(self, s1_mode, s2_mode):
class Socket: - def __repr__(self): + def __repr__(self) -> str: return "{}<{},{}>".format( self.__class__.__name__, self._ctx.__class__.__name__, self._mode)
- def __init__(self, ctx=get_ctx(), mode="rw"): + def __init__(self, ctx=get_ctx(), mode="rw") -> None: self._mode = mode self._ctx = ctx self._sock = ctx.new(mode) @@ -391,35 +412,35 @@ def __init__(self, ctx=get_ctx(), mode="rw"): raise ValueError("Socket mode \"{}\" not supported by {}".format( mode, ctx.__class__.__name__))
- def send(self, data): + def send(self, data) -> None: self._sock.send(data)
- def recv(self): + def recv(self) -> bytes: return self._sock.recv()
- def bind(self): + def bind(self) -> int: addr = self._ctx.default_addr bind_random = self._ctx.bind_random port = bind_random(self._sock, addr) return port
- def connect(self, addr): + def connect(self, addr) -> None: _connect = self._ctx.connect _connect(self._sock, addr)
- def close(self): + def close(self) -> None: _close = self._ctx.close _close(self._sock)
class ProcessDevice: - def __init__(self, s1_mode, s2_mode, ctx=get_ctx()): + def __init__(self, s1_mode: Any, s2_mode: Any, ctx=get_ctx()) -> None: device, in_addr, out_addr = ctx.device(s1_mode, s2_mode) self.device = device self.in_addr = in_addr self.out_addr = out_addr
- def start(self): + def start(self) -> None: self.device.start() logger.debug("started device in:%s out:%s", self.in_addr, self.out_addr)
From 7cdde55e0abdea7a1e08061474bc8e9e9d3418a4 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sun, 6 Sep 2020 15:11:16 -0700 Subject: [PATCH 20/23] [type-hints] Add type hints to fiber/queues.py
--- fiber/queues.py | 204 +++++++++++++++++++++++++++++------------------- 1 file changed, 125 insertions(+), 79 deletions(-)
diff --git a/fiber/queues.py b/fiber/queues.py index 627f856..abf4e1d 100644 --- a/fiber/queues.py +++ b/fiber/queues.py @@ -63,18 +63,22 @@ from fiber.process import current_process from fiber.backend import get_backend from fiber.socket import Socket, ProcessDevice +from typing import Any, Dict, List, NoReturn, Tuple, Union, Optional + +_io_threads: List[Any] +_poller_threads: Dict[Any, Any]
# define the port range that fiber can use to listen for incoming connections MIN_PORT = 40000 MAX_PORT = 65535
-logger = logging.getLogger('fiber') +logger = logging.getLogger("fiber")
_io_threads = [] _poller_threads = {}
-def _clean_up(): +def _clean_up() -> None: for t in _io_threads: logger.debug("cleanup thread %s", t) t.stop() @@ -98,8 +102,14 @@ class ZConnection(multiprocessing.connection._ConnectionBase): > note: ZConnection's fileno
method returns a Fiber socket. """ - def __init__(self, handle, readable=True, writable=True): - self._name = None + def __init__( + self, + handle: Union[Socket, Tuple[str, str]], + readable: bool = True, + writable: bool = True, + ) -> None: + self._handle: Optional[Socket] + self._name: Optional[str] = None if handle is None: raise ValueError("invalid socket") @@ -112,84 +122,111 @@ def __init__(self, handle, readable=True, writable=True): if not readable and not writable: raise ValueError( - "at least one of `readable` and `writable` must be True") + "at least one of `readable` and `writable` must be True" ) self._readable = readable self._writable = writable mp_register_after_fork(self, ZConnection._create_handle) atexit.register(self._close)
- def __getstate__(self): - return {"sock_type": self.sock_type, - "dest_addr": self.dest_addr, - "_readable": self._readable, - "_writable": self._writable, - "_name": self._name} - - def __setstate__(self, state): - self.sock_type = state['sock_type'] - self.dest_addr = state['dest_addr'] - self._readable = state['_readable'] - self._writable = state['_writable'] - self._name = state['_name'] + def __getstate__(self) -> Dict: + return { + "sock_type": self.sock_type, + "dest_addr": self.dest_addr, + "_readable": self._readable, + "_writable": self._writable, + "_name": self._name, + } + + def __setstate__(self, state) -> None: + self.sock_type = state["sock_type"] + self.dest_addr = state["dest_addr"] + self._readable = state["_readable"] + self._writable = state["_writable"] + self._name = state["_name"] self._create_handle() mp_register_after_fork(self, ZConnection._create_handle) atexit.register(self._close)
- def __del__(self): + def __del__(self) -> None: # __del__ is not reliable, we use atexit to clean up things return
- def __repr__(self): + def __repr__(self) -> str: name = ( self._name if getattr(self, "_name", None) is not None else self.dest_addr ) - return '<ZConnection({}, {})>'.format(name, getattr(self, "_handle", None)) + return "<ZConnection({}, {})>".format( + name, getattr(self, "_handle", None) + )
- def _create_handle(self): + def _create_handle(self) -> None: logger.debug("%s _create_handle called", self) - #self._handle = context.socket(self.sock_type) + # self._handle = context.socket(self.sock_type) self._handle = Socket(mode=self.sock_type) self._handle.connect(self.dest_addr) logger.debug("connect to %s", self.dest_addr)
- def _close(self): + def _close(self) -> None: if self._handle: self._handle.close() self._handle = None
- def _send_bytes(self, buf): - self._handle.send(buf) + def _check_readable(self) -> None: + if not self._readable: + raise RuntimeError("ZConnection is not readable") + + def _check_closed(self) -> None: + if self._handle is None: + raise RuntimeError("ZConnection handle is closed") + + def _send_bytes(self, buf: bytes) -> None: + self._handle.send(buf) # type: ignore
# TODO(jiale) _ConnectionBase's send uses _ForkingPickler. Switch to # send_pyobj instead?
- def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int = None) -> bytes: # TODO(jiale) support maxsize - return self._handle.recv() + return self._handle.recv() # type: ignore + + def send(self, obj): + """Send a (picklable) object""" + self._check_closed() + self._check_writable() + self._send_bytes(reduction.ForkingPickler.dumps(obj)) - def recv(self): + def recv(self) -> Any: """Receive a (picklable) object""" - #logger.debug("recv called") + # logger.debug("recv called") self._check_closed() self._check_readable() buf = self._recv_bytes() return reduction.ForkingPickler.loads(buf) - def _poll(self, timeout): - return bool(self._handle.poll(timeout=timeout)) + def _poll(self, timeout: float) -> bool: + # return bool(self._handle.poll(timeout=timeout)) + # ZConnection doesn't support poll yet + raise NotImplementedError - def set_name(self, name): + def set_name(self, name: str) -> None: self._name = name class LazyZConnection(ZConnection): - def __init__(self, handle, readable=True, writable=True, name=None): - self._name = None + def __init__( + self, + handle: Union[Socket, Tuple[str, str]], + readable: bool = True, + writable: bool = True, + name: Optional[str] = None, + ) -> None: + self._name: Optional[str] = None if handle is None: raise ValueError("invalid socket") @@ -206,60 +243,63 @@ def __init__(self, handle, readable=True, writable=True, name=None): if not readable and not writable: raise ValueError( - "at least one of `readable` and `writable` must be True") + "at least one of `readable` and `writable` must be True" + ) self._readable = readable self._writable = writable - def _check_closed(self): + def _check_closed(self) -> None: if self._inited is False: self._create_handle() self._inited = True if self._handle is None: raise OSError("handle is closed") - def _close(self): + def _close(self) -> None: if not self._inited: return - self._handle.close() + self._handle.close() # type: ignore - def _send_bytes(self, buf): - #self._handle.send_multipart([b"", buf]) - self._handle.send(buf) + def _send_bytes(self, buf: bytes) -> None: + # self._handle.send_multipart([b"", buf]) + self._handle.send(buf) # type: ignore - def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int = None) -> bytes: # TODO(jiale) support maxsize - #msg = self._handle.recv_multipart() + # msg = self._handle.recv_multipart() # msg -> b'' b'message data' (because of ROUTER) - #data = msg[1] - data = self._handle.recv() + # data = msg[1] + data = self._handle.recv() # type: ignore return data - def __getstate__(self): - return {"sock_type": self.sock_type, - "dest_addr": self.dest_addr, - "_readable": self._readable, - "_writable": self._writable} - - def __setstate__(self, state): - self.sock_type = state['sock_type'] - self.dest_addr = state['dest_addr'] - self._readable = state['_readable'] - self._writable = state['_writable'] + def __getstate__(self) -> Dict: + return { + "sock_type": self.sock_type, + "dest_addr": self.dest_addr, + "_readable": self._readable, + "_writable": self._writable, + } + + def __setstate__(self, state) -> None: + self.sock_type = state["sock_type"] + self.dest_addr = state["dest_addr"] + self._readable = state["_readable"] + self._writable = state["_writable"] self._inited = False class LazyZConnectionPipe(LazyZConnection): - def _send_bytes(self, buf): - self._handle.send(buf) + def _send_bytes(self, buf: bytes) -> None: + self._handle.send(buf) # type: ignore - def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int 
= None) -> Any: # TODO(jiale) support maxsize - return self._handle.recv() + return self._handle.recv() # type: ignore
-def Pipe(duplex=True): +def Pipe(duplex: bool = True) -> Tuple[LazyZConnection, LazyZConnection]: """Return a pair of connected ZConnection objects.
:param duplex: if duplex, then both read and write are allowed on each @@ -275,22 +315,28 @@ def Pipe(duplex=True): d.start()
if duplex: - return (LazyZConnection(("rw", d.out_addr,)), - LazyZConnection(("rw", d.in_addr,))) - return (LazyZConnection(("r", d.out_addr,)), - LazyZConnection(("w", d.in_addr,))) + return ( + LazyZConnection(("rw", d.out_addr,)), + LazyZConnection(("rw", d.in_addr,)), + ) + return ( + LazyZConnection(("r", d.out_addr,)), + LazyZConnection(("w", d.in_addr,)), + )
-class SimpleQueuePush(): +class SimpleQueuePush: """A queue built on top of Fiber socket. It uses "w" - ("r" - "w") - "r" socket combination. Messages are pushed from one end of the queue to the other end without explicitly pulling. """ + - def __repr__(self): + def __repr__(self) -> str: return "<SimpleQueuePush({}, {})>".format( - self._reader_addr, self._writer_addr) + self._reader_addr, self._writer_addr + )
- def __init__(self): + def __init__(self) -> None: self.done = False backend = get_backend() ip, _, _ = backend.get_listen_addr() @@ -307,11 +353,11 @@ def __init__(self): # set reader to None because if reader is connected, Fiber socket will # fairly queue messages to all readers even if this reader is # not reading. - #self.reader = None + # self.reader = None self.reader = LazyZConnection(("r", self._reader_addr,)) self.writer = LazyZConnection(("w", self._writer_addr,))
- def get(self): + def get(self) -> Any: """Get an element from this Queue.
:returns: An element from this queue. If there is no element in the @@ -326,22 +372,22 @@ def get(self): self.reader = LazyZConnection(("r", self._reader_addr,)) """ - #data = self.reader._handle.recv() + # data = self.reader._handle.recv() msg = self.reader.recv() - #msg = reduction.ForkingPickler.loads(data) + # msg = reduction.ForkingPickler.loads(data) logger.debug("%s got %s", self, msg) return msg
- def put(self, obj): + def put(self, obj: Any) -> None: """Put an element into the Queue.
:param obj: Any picklable Python object. """ logger.debug("%s put %s", self, obj) - #data = reduction.ForkingPickler.dumps(obj) + # data = reduction.ForkingPickler.dumps(obj) self.writer.send(obj)
- ''' + """ def __getstate__(self): d = self.__dict__.copy() # set reader to None so that de-serialized Queue doesn't connect to # the writer socket. Fiber socket will send messages # to all the readers. We want to prevent this behavior.
d["reader"] = None return d - ''' + """ -#Pipe = ClassicPipe +# Pipe = ClassicPipe SimpleQueue = SimpleQueuePush From f976d8c9a90b38771cceb6ee4702f9092a93b557 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Sun, 6 Sep 2020 16:52:18 -0700 Subject: [PATCH 21/23] [type-hints] Add type hints to fiber/pool.py --- fiber/pool.py | 357 +++++++++++++++++++++++++++-------------------- fiber/process.py | 4 +- 2 files changed, 207 insertions(+), 154 deletions(-) diff --git a/fiber/pool.py b/fiber/pool.py index 6bb3618..8f67874 100644 --- a/fiber/pool.py +++ b/fiber/pool.py @@ -44,9 +44,13 @@ import time import secrets import traceback -from multiprocessing.pool import (CLOSE, RUN, TERMINATE, - ExceptionWithTraceback, MaybeEncodingError, - ThreadPool, _helper_reraises_exception) +from multiprocessing.pool import ( # type: ignore + CLOSE, RUN, TERMINATE, + ExceptionWithTraceback, + MaybeEncodingError, + ThreadPool, + _helper_reraises_exception +) import fiber.queues import fiber.config as config @@ -55,6 +59,8 @@ from fiber.socket import Socket from fiber.process import current_process import signal +from typing import (Any, Generator, Iterator, NoReturn, Callable, + Sequence, List, Dict, Union, Tuple, Optional) if fiber.util.is_in_interactive_console(): @@ -68,7 +74,7 @@ MAX_PORT = 65535 -def safe_join_worker(proc): +def safe_join_worker(proc: fiber.process.Process) -> None: p = proc if p.is_alive(): # worker has not yet exited @@ -77,7 +83,7 @@ def safe_join_worker(proc): p.join(5) -def safe_terminate_worker(proc): +def safe_terminate_worker(proc: fiber.process.Process) -> None: delay = random.random() # Randomize start time to prevent overloading the server @@ -93,7 +99,7 @@ def safe_terminate_worker(proc): logger.debug("safe_terminate_worker() finished") -def safe_start(proc): +def safe_start(proc: fiber.process.Process) -> None: try: proc.start() proc._start_failed = False @@ -104,7 +110,7 @@ def safe_start(proc): proc._start_failed = True -def mp_worker_core(inqueue, outqueue, maxtasks=None, wrap_exception=False): +def mp_worker_core(inqueue: fiber.queues.SimpleQueue, outqueue: fiber.queues.SimpleQueue, maxtasks: int = None, wrap_exception: bool = False) -> None: logger.debug('mp_worker_core running') put = outqueue.put get = inqueue.get @@ -136,13 +142,13 @@ def mp_worker_core(inqueue, outqueue, maxtasks=None, wrap_exception=False): wrapped)) put((job, i, (False, wrapped))) - task = job = result = func = args = kwds = None + task = job = result = func = args = kwds = None # type: ignore completed += 1 logger.debug('worker exiting after %s tasks' % completed) -def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, - wrap_exception=False, num_workers=1): +def mp_worker(inqueue: fiber.queues.SimpleQueue, outqueue: fiber.queues.SimpleQueue, initializer: Callable = None, initargs: Sequence = (), maxtasks: int = None, + wrap_exception: bool = False, num_workers: int = 1) -> None: """This is mostly the same as multiprocessing.pool.worker, the difference is that it will start multiple workers (specified by `num_workers` argument) via multiproccessing and allow the Fiber pool worker to take multiple CPU @@ -150,9 +156,9 @@ def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, """ assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) - if hasattr(inqueue, '_writer'): - inqueue._writer.close() - outqueue._reader.close() + if hasattr(inqueue, 'writer'): + inqueue.writer.close() + outqueue.reader.close() if initializer is not None: 
initializer(*initargs) @@ -174,18 +180,20 @@ def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, class ClassicPool(mp_pool.Pool): + _wrap_exception = True + @staticmethod - def Process(ctx, *args, **kwds): + def Process(ctx, *args, **kwds) -> fiber.process.Process: return fiber.process.Process(*args, **kwds) - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None): + def __init__(self, processes: int = None, initializer: Callable = None, initargs: Sequence = (), + maxtasksperchild: int = None, cluster=None) -> None: self._ctx = None self._setup_queues() - self._taskqueue = queue.Queue() - self._cache = {} - self._state = RUN + self._taskqueue: "queue.Queue[Tuple[Iterator[Tuple], Optional[Callable[[int], None]]]]" = queue.Queue() + self._cache: Dict = {} + self._state: int = RUN self._maxtasksperchild = maxtasksperchild self._initializer = initializer self._initargs = initargs @@ -199,8 +207,8 @@ def __init__(self, processes=None, initializer=None, initargs=(), raise TypeError("initializer must be a callable") self._processes = processes - self._pool = [] - self._threads = [] + self._pool: List[fiber.process.Process] = [] + self._threads: List[threading.Thread] = [] self._repopulate_pool() # Worker handler @@ -209,10 +217,10 @@ def __init__(self, processes=None, initializer=None, initargs=(), args=(self._cache, self._taskqueue, self._ctx, self.Process, self._processes, self._pool, self._threads, self._inqueue, self._outqueue, self._initializer, self._initargs, - self._maxtasksperchild, self._wrap_exception) + self._maxtasksperchild, self._wrap_exception) # type: ignore ) self._worker_handler.daemon = True - self._worker_handler._state = RUN + self._worker_handler._state = RUN # type: ignore self._worker_handler.start() logger.debug( "Pool: started _handle_workers thread(%s:%s)", @@ -231,7 +239,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), self._pool, self._cache) ) self._task_handler.daemon = True - self._task_handler._state = RUN + self._task_handler._state = RUN # type: ignore self._task_handler.start() logger.debug( "Pool: started _handle_tasks thread(%s:%s)", @@ -244,7 +252,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), args=(self._outqueue, self._quick_get, self._cache) ) self._result_handler.daemon = True - self._result_handler._state = RUN + self._result_handler._state = RUN # type: ignore self._result_handler.start() logger.debug( "Pool: started _handle_results thread(%s:%s)", @@ -261,7 +269,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), ) logger.debug("Pool: registered _terminate_pool finalizer") - def _setup_queues(self): + def _setup_queues(self) -> None: self._inqueue = fiber.queues.SimpleQueue() logger.debug("Pool|created Pool._inqueue: %s", self._inqueue) self._outqueue = fiber.queues.SimpleQueue() @@ -273,8 +281,11 @@ def _setup_queues(self): # is a REQ socket. It can't be called consecutively. self._quick_get = self._outqueue.get - def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, - error_callback=None): + def _map_async( + self, func: Callable, iterable: Sequence, mapper, + chunksize: int = None, callback: Callable = None, + error_callback: Callable = None + ) -> mp.pool.MapResult: """ Helper function to implement map, starmap and their async counterparts. 
""" @@ -293,12 +304,12 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, if len(iterable) == 0: chunksize = 0 - task_batches = ClassicPool._get_tasks(func, iterable, chunksize) - result = mp_pool.MapResult(self._cache, chunksize, len(iterable), + task_batches = ClassicPool._get_tasks(func, iterable, chunksize) # type: ignore + result: mp_pool.MapResult = mp_pool.MapResult(self._cache, chunksize, len(iterable), callback, error_callback=error_callback) self._taskqueue.put( ( - self._guarded_task_generation(result._job, + self._guarded_task_generation(result._job, # type: ignore mapper, task_batches), None @@ -309,12 +320,12 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, @staticmethod def _handle_workers(cache, taskqueue, ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, - maxtasksperchild, wrap_exception): + maxtasksperchild, wrap_exception) -> None: thread = threading.current_thread() # Keep maintaining workers until the cache gets drained, unless the # pool is terminated. - while thread._state == RUN or (cache and thread._state != TERMINATE): + while thread._state == RUN or (cache and thread._state != TERMINATE): # type: ignore ClassicPool._maintain_pool(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, maxtasksperchild, wrap_exception) @@ -323,7 +334,7 @@ def _handle_workers(cache, taskqueue, ctx, Process, processes, pool, logger.debug("_handle_workers exits") @staticmethod - def _join_exited_workers(pool): + def _join_exited_workers(pool) -> bool: """Cleanup after any worker processes which have exited due to reaching their specified lifetime. Returns True if any workers were cleaned up. """ @@ -353,7 +364,7 @@ def _join_exited_workers(pool): del pool[i] return cleaned - def _repopulate_pool(self): + def _repopulate_pool(self) -> Any: return self._repopulate_pool_static(self._ctx, self.Process, self._processes, self._pool, self._threads, @@ -366,7 +377,7 @@ def _repopulate_pool(self): @staticmethod def _repopulate_pool_static(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, - maxtasksperchild, wrap_exception): + maxtasksperchild, wrap_exception) -> None: """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ @@ -413,7 +424,7 @@ def _repopulate_pool_static(ctx, Process, processes, pool, threads, @staticmethod def _maintain_pool(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, maxtasksperchild, - wrap_exception): + wrap_exception) -> None: """Clean up any exited workers and start replacements for them. 
""" if ClassicPool._join_exited_workers(pool): @@ -424,15 +435,19 @@ def _maintain_pool(ctx, Process, processes, pool, threads, inqueue, wrap_exception) @staticmethod - def _handle_tasks(taskqueue, put, outqueue, pool, cache): + def _handle_tasks( + taskqueue: "queue.Queue[Tuple[Iterator[Tuple], Optional[Callable[[int], None]]]]", + put, outqueue, pool, cache + ) -> None: thread = threading.current_thread() for taskseq, set_length in iter(taskqueue.get, None): + task = None try: # iterating taskseq cannot fail for task in taskseq: - if thread._state: + if thread._state: # type: ignore break try: put(task) @@ -449,7 +464,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache): continue break finally: - task = taskseq = job = None + task = taskseq = job = None # type: ignore else: logger.debug('_handle_tasks: task handler got sentinel') @@ -471,7 +486,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache): logger.debug('_handle_tasks: task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_results(outqueue, get, cache) -> None: thread = threading.current_thread() while 1: @@ -481,8 +496,8 @@ def _handle_results(outqueue, get, cache): # logger.debug('result handler got EOFError/OSError: exiting') return - if thread._state: - assert thread._state == TERMINATE + if thread._state: # type: ignore + assert thread._state == TERMINATE # type: ignore # logger.debug('result handler found thread._state=TERMINATE') break @@ -497,7 +512,7 @@ def _handle_results(outqueue, get, cache): pass task = job = obj = None - while cache and thread._state != TERMINATE: + while cache and thread._state != TERMINATE: # type: ignore try: task = get() except (OSError, EOFError): @@ -528,11 +543,11 @@ def _handle_results(outqueue, get, cache): pass logger.debug('result handler exiting: len(cache)=%s, thread._state=%s', - len(cache), thread._state) + len(cache), thread._state) # type: ignore @classmethod def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, threads, - worker_handler, task_handler, result_handler, cache): + worker_handler, task_handler, result_handler, cache) -> None: # this is guaranteed to only be called once logger.debug('finalizing pool') start = time.time() @@ -649,21 +664,21 @@ class Inventory(): can be called with corresponding `seq`. This inventory will handle waiting, managing results from different map calls. 
""" - def __init__(self, queue_get): + def __init__(self, queue_get: Callable[[], Any]) -> None: self._seq = 0 self._queue_get = queue_get - self._inventory = {} - self._spec = {} - self._idx_cur = {} + self._inventory: Dict[int, List[Any]] = {} + self._spec: Dict[int, int] = {} + self._idx_cur: Dict[int, int] = {} - def add(self, ntasks): + def add(self, ntasks: int) -> int: self._seq += 1 self._inventory[self._seq] = [None] * ntasks self._spec[self._seq] = ntasks self._idx_cur[self._seq] = 0 return self._seq - def get(self, job_seq): + def get(self, job_seq: int) -> Any: n = self._spec[job_seq] while n != 0: @@ -675,10 +690,10 @@ def get(self, job_seq): n = self._spec[seq] ret = self._inventory[job_seq] - self._inventory[job_seq] = None + self._inventory[job_seq] = [] return ret - def iget_unordered(self, job_seq): + def iget_unordered(self, job_seq: int) -> Iterator[Any]: n = self._spec[job_seq] while n != 0: @@ -694,7 +709,7 @@ def iget_unordered(self, job_seq): return - def iget_ordered(self, job_seq): + def iget_ordered(self, job_seq: int) -> Iterator[Any]: idx = self._idx_cur[job_seq] total = len(self._inventory[job_seq]) @@ -729,17 +744,17 @@ def iget_ordered(self, job_seq): class MapResult(): - def __init__(self, seq, inventory): + def __init__(self, seq: int, inventory: Inventory) -> None: self._seq = seq self._inventory = inventory - def get(self): + def get(self) -> Any: return self._inventory.get(self._seq) - def iget_ordered(self): + def iget_ordered(self) -> Any: return self._inventory.iget_ordered(self._seq) - def iget_unordered(self): + def iget_unordered(self) -> Any: return self._inventory.iget_unordered(self._seq) @@ -748,7 +763,7 @@ class ApplyResult(MapResult): represents an handle that can be used to get the actual result. """ - def get(self): + def get(self) -> Any: """Get the actual result represented by this object :returns: Actual result. This method will block if the actual result @@ -757,8 +772,14 @@ def get(self): return self._inventory.get(self._seq)[0] -def zpool_worker_core(master_conn, result_conn, maxtasksperchild, - wrap_exception, rank=-1, req=False): +def zpool_worker_core( + master_conn: LazyZConnection, + result_conn: LazyZConnection, + maxtasksperchild: Optional[int], + wrap_exception: bool, + rank: int = -1, + req: bool = False +) -> None: """ The actual function that processes tasks. @@ -775,13 +796,11 @@ def zpool_worker_core(master_conn, result_conn, maxtasksperchild, """ logger.debug("zpool_worker_core started %s", rank) - proc = None ident = secrets.token_bytes(4) - if req: - proc = current_process() while True: if req: + proc = current_process() # master_conn is a REQ type socket, need to send id (rank) to # master to request a task. 
id is packed in type unsigned short (H) master_conn.send_bytes(struct.pack("4si", ident, proc.pid)) @@ -813,24 +832,26 @@ def zpool_worker_core(master_conn, result_conn, maxtasksperchild, data = (seq, batch, batch + i, res) if req: - data += (ident,) - result_conn.send(data) + result_conn.send((seq, batch, batch + i, res, ident)) + else: + result_conn.send((seq, batch, batch + i, res)) else: for i, args in enumerate(arg_list): res = func(args) - data = (seq, batch, batch + i, res) if req: - data += (ident,) - result_conn.send(data) + result_conn.send((seq, batch, batch + i, res, ident)) + else: + result_conn.send((seq, batch, batch + i, res)) + #print("worker_core exit, ", rank, proc.pid) -def handle_signal(signal, frame): +def handle_signal(signal, frame) -> None: # run sys.exit() so that atexit handlers can run sys.exit() -def zpool_worker(master_conn, result_conn, initializer=None, initargs=(), - maxtasks=None, wrap_exception=False, num_workers=1, req=False): +def zpool_worker(master_conn: LazyZConnection, result_conn: LazyZConnection, initializer: Callable = None, initargs: Sequence = (), + maxtasks: int = None, wrap_exception: bool = False, num_workers: int = 1, req: bool = False) -> None: """ The entry point of Pool worker function. @@ -885,23 +906,29 @@ class ZPool(): results handling. This makes it faster. """ - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None, - master_sock_type="w"): - - self._pool = [] + def __init__( + self, + processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + cluster=None, + master_sock_type: str = "w" + ) -> None: + + self._pool: List[fiber.process.Process] = [] # Set default processes to 1 self._processes = processes if processes is not None else 1 self._initializer = initializer self._initargs = initargs self._maxtasksperchild = maxtasksperchild self._cluster = cluster - self._seq = 0 - self._state = RUN - self.taskq = queue.Queue() - self.sent_tasks = 0 - self.recv_tasks = 0 - self.max_processing_tasks = 20000 + self._seq: int = 0 + self._state: int = RUN + self.taskq: queue.Queue = queue.Queue() + self.sent_tasks: int = 0 + self.recv_tasks: int = 0 + self.max_processing_tasks: int = 20000 # networking related backend = get_backend() @@ -928,7 +955,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), target=self.__class__._handle_workers, args=(self,) ) td.daemon = True - td._state = RUN + td._state = RUN # type: ignore # `td` will be started later by `lazy_start_workers` later self._worker_handler = td self._worker_handler_started = False @@ -938,18 +965,18 @@ def __init__(self, processes=None, initializer=None, initargs=(), target=self._handle_tasks, ) td.daemon = True - td._state = RUN + td._state = RUN # type: ignore td.start() self._task_handler = td - def __repr__(self): + def __repr__(self) -> str: return "<{}({}, {})>".format( type(self).__name__, getattr(self, "_processes", None), getattr(self, "_master_addr", None), ) - def _handle_tasks(self): + def _handle_tasks(self) -> None: taskq = self.taskq master_sock = self._master_sock @@ -962,10 +989,10 @@ def _handle_tasks(self): master_sock.send(data) self.sent_tasks += 1 - def _task_put(self, task): + def _task_put(self, task) -> None: self.taskq.put(task) - def _res_get(self): + def _res_get(self) -> Any: payload = self._result_sock.recv() self.recv_tasks += 1 data = pickle.loads(payload) @@ -973,16 +1000,16 @@ def _res_get(self): return data @staticmethod - def 
_join_exited_workers(workers): + def _join_exited_workers(workers: List[fiber.process.Process]) -> List[fiber.process.Process]: thread = threading.current_thread() logger.debug("ZPool _join_exited_workers running, workers %s, " - "thread._state %s", workers, thread._state) + "thread._state %s", workers, thread._state) # type: ignore exited_workers = [] for i in reversed(range(len(workers))): - if thread._state != RUN: + if thread._state != RUN: # type: ignore break worker = workers[i] @@ -1007,8 +1034,15 @@ def _join_exited_workers(workers): return exited_workers @staticmethod - def _maintain_workers(processes, workers, master_addr, result_addr, initializer, - initargs, maxtasksperchild): + def _maintain_workers( + processes: int, + workers: List[fiber.process.Process], + master_addr: str, + result_addr: str, + initializer: Optional[Callable], + initargs: Optional[Sequence], + maxtasksperchild: Optional[int] + ) -> None: thread = threading.current_thread() workers_per_fp = config.cpu_per_job @@ -1017,7 +1051,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, logger.debug("ZPool _maintain_workers running, workers %s", workers) threads = [] - while left > 0 and thread._state == RUN: + while left > 0 and thread._state == RUN: # type: ignore if left > workers_per_fp: n = workers_per_fp @@ -1057,7 +1091,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, logger.debug("ZPool _maintain_workers finished, workers %s", workers) @staticmethod - def _handle_workers(pool): + def _handle_workers(pool: "ZPool") -> None: logger.debug("%s _handle_workers running", pool) td = threading.current_thread() @@ -1067,7 +1101,7 @@ def _handle_workers(pool): pool._initargs, pool._maxtasksperchild ) - while td._state == RUN: + while td._state == RUN: # type: ignore if len(ZPool._join_exited_workers(pool._pool)) > 0: # create new workers when old workers exited ZPool._maintain_workers( @@ -1082,12 +1116,12 @@ def _handle_workers(pool): pool) @staticmethod - def _chunks(iterable, size): + def _chunks(iterable: Sequence, size: int) -> Iterator[Any]: for i in range(0, len(iterable), size): yield iterable[i:i + size] - def apply_async(self, func, args=(), kwds={}, callback=None, - error_callback=None): + def apply_async(self, func: Callable, args: Sequence = (), kwds: Dict = {}, callback: Callable = None, + error_callback: Callable = None) -> ApplyResult: """ Run function `func` with arguments `args` and keyword arguments `kwds` on a remote Pool worker. This is an asynchronous version of `apply`. @@ -1115,15 +1149,15 @@ def apply_async(self, func, args=(), kwds={}, callback=None, return res - def start_workers(self): + def start_workers(self) -> None: self._worker_handler.start() self._worker_handler_started = True - def lazy_start_workers(self, func): + def lazy_start_workers(self, func: Callable) -> None: if hasattr(func, "__fiber_meta__"): if ( not hasattr(zpool_worker, "__fiber_meta__") - or zpool_worker.__fiber_meta__ != func.__fiber_meta__ + or zpool_worker.__fiber_meta__ != func.__fiber_meta__ # type: ignore ): if self._worker_handler_started: raise RuntimeError( @@ -1131,13 +1165,13 @@ def lazy_start_workers(self, func): "requirements acceptable by this pool. Try creating a " "different pool for it." 
) - zpool_worker.__fiber_meta__ = func.__fiber_meta__ + zpool_worker.__fiber_meta__ = func.__fiber_meta__ # type: ignore
if not self._worker_handler_started: self.start_workers()
- def map_async(self, func, iterable, chunksize=None, callback=None, - error_callback=None): + def map_async(self, func: Callable, iterable: Sequence, chunksize: int = None, callback: Callable = None, + error_callback: Callable = None) -> MapResult: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This is an asynchronous @@ -1183,7 +1217,7 @@ def map_async(self, func, iterable, chunksize=None, callback=None,
return res
- def apply(self, func, args=(), kwds={}): + def apply(self, func: Callable, args: Sequence = (), kwds: Dict = {}) -> Any: """ Run function `func` with arguments `args` and keyword arguments `kwds` on a remote Pool worker. @@ -1196,7 +1230,7 @@ def apply(self, func, args=(), kwds={}): """ return self.apply_async(func, args, kwds).get()
- def map(self, func, iterable, chunksize=None): + def map(self, func: Callable, iterable: Sequence, chunksize: int = None) -> List[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. @@ -1215,7 +1249,7 @@ def map(self, func, iterable, chunksize=None): logger.debug('%s map func=%s', self, func) return self.map_async(func, iterable, chunksize).get()
- def imap(self, func, iterable, chunksize=1): + def imap(self, func: Callable, iterable: Sequence, chunksize: int = 1) -> Iterator[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This function returns an @@ -1234,7 +1268,7 @@ def imap(self, func, iterable, chunksize=1): res = self.map_async(func, iterable, chunksize) return res.iget_ordered()
- def imap_unordered(self, func, iterable, chunksize=1): + def imap_unordered(self, func: Callable, iterable: Sequence, chunksize: int = 1) -> Iterator[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This function returns an @@ -1255,8 +1289,8 @@ def imap_unordered(self, func, iterable, chunksize=1): res = self.map_async(func, iterable, chunksize) return res.iget_unordered()
- def starmap_async(self, func, iterable, chunksize=None, callback=None, - error_callback=None): + def starmap_async(self, func: Callable, iterable: Sequence, chunksize: int = None, callback: Callable = None, + error_callback: Callable = None) -> MapResult: """ For each element `args` in `iterable`, run `func(*args)`. The workload is distributed between all the Pool workers. This is an asynchronous @@ -1304,7 +1338,7 @@ def starmap_async(self, func, iterable, chunksize=None, callback=None,
return res
- def starmap(self, func, iterable, chunksize=None): + def starmap(self, func: Callable, iterable: Sequence, chunksize: int = None) -> List[Any]: """ For each element `args` in `iterable`, run `func(*args)`. The workload is distributed between all the Pool workers. @@ -1329,12 +1363,12 @@ def starmap(self, func, iterable, chunksize=None): """ return self.starmap_async(func, iterable, chunksize).get()
- def _send_sentinels_to_workers(self): + def _send_sentinels_to_workers(self) -> None: logger.debug('send sentinels(None) to workers %s', self) for i in range(self._processes): self._task_put(None)
- def close(self): + def close(self) -> None: """ Close this Pool.
This means the current pool will be put in to a closing state and it will not accept new tasks. Existing workers will @@ -1344,15 +1378,15 @@ def close(self): logger.debug('closing pool %s', self) if self._state == RUN: self._state = CLOSE - self._worker_handler._state = CLOSE + self._worker_handler._state = CLOSE # type: ignore for p in self._pool: if hasattr(p, '_sentinel'): - p._state = CLOSE + p._state = CLOSE # type: ignore self._send_sentinels_to_workers() - def terminate(self): + def terminate(self) -> None: """ Terminate this pool. This means that this pool will be terminated and all its pool workers will also be terminated. Task that have been @@ -1361,11 +1395,11 @@ def terminate(self): logger.debug('terminating pool %s', self) logger.debug('set pool._worker_handler.status = TERMINATE') - self._worker_handler._state = TERMINATE - self._state = TERMINATE + self._worker_handler._state = TERMINATE # type: ignore + self._state = TERMINATE # type: ignore for p in self._pool: - p._state = TERMINATE + p._state = TERMINATE # type: ignore pool = self._pool N = min(100, len(pool)) @@ -1387,7 +1421,7 @@ def terminate(self): logger.debug("joining pool._worker_handler") self._worker_handler.join() - def join(self): + def join(self) -> None: """ Wait for all the pool workers of this pool to exit. This should be used after `terminate()` or `close()` are called on this pool. @@ -1396,13 +1430,13 @@ def join(self): assert self._state in (TERMINATE, CLOSE) for p in self._pool: - if p._state not in (TERMINATE, CLOSE): + if p._state not in (TERMINATE, CLOSE): # type: ignore logger.debug("%s.join() ignore newly connected Process %s", self, p) continue p.join() - def wait_until_workers_up(self): + def wait_until_workers_up(self) -> None: logger.debug('%s begin wait_until_workers_up', self) workers_per_fp = config.cpu_per_job n = math.ceil(float(self._processes) / workers_per_fp) @@ -1414,7 +1448,7 @@ def wait_until_workers_up(self): for p in self._pool: logger.debug('%s waiting for _sentinel %s', self, p) - while not hasattr(p, '_sentinel') or p._sentinel is None: + while not hasattr(p, '_sentinel') or p._sentinel is None: # type: ignore time.sleep(0.5) # now all the worker has connected to the master, wait # for some additional time to be sure. @@ -1434,14 +1468,19 @@ class ResilientZPool(ZPool): The API of `ResilientZPool` is the same as `ZPool`. One difference is that if `processes` argument is not set, its default value is 1. 
""" - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None): - - self.active_peer_dict = {} - self.active_peer_list = [] + def __init__( + self, processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + cluster=None + ) -> None: + + self.active_peer_dict: Dict[str, bool] = {} + self.active_peer_list: List[str] = [] self.peer_lock = threading.Lock() self.taskq = queue.Queue() - self._pending_table = {} + self._pending_table: Dict[str, Dict] = {} super(ResilientZPool, self).__init__( processes=processes, @@ -1467,9 +1506,9 @@ def __init__(self, processes=None, initializer=None, initargs=(), self._task_handler = td ''' - self._pid_to_rid = {} + self._pid_to_rid: Dict[int, str] = {} - def _add_peer(self, ident): + def _add_peer(self, ident: str) -> None: self.peer_lock.acquire() self.active_peer_dict[ident] = True @@ -1478,7 +1517,7 @@ def _add_peer(self, ident): self.peer_lock.release() - def _remove_peer(self, ident): + def _remove_peer(self, ident: str) -> None: # _pendint_table will be cleared later in error handling phase self.peer_lock.acquire() @@ -1487,7 +1526,7 @@ def _remove_peer(self, ident): self.peer_lock.release() - def _res_get(self): + def _res_get(self) -> Any: # check for system messages payload = self._result_sock.recv() data = pickle.loads(payload) @@ -1507,16 +1546,16 @@ def _res_get(self): # skip ident return data[:-1] - def _task_put(self, task): + def _task_put(self, task: Union[None, Tuple[int, int, Callable, Sequence, bool]]) -> None: self.taskq.put(task) - def _handle_tasks(self): + def _handle_tasks(self) -> None: thread = threading.current_thread() taskq = self.taskq master_sock = self._master_sock pending_table = self._pending_table - while thread._state == RUN: + while thread._state == RUN: # type: ignore task = taskq.get() if task is None: @@ -1555,8 +1594,15 @@ def _handle_tasks(self): logger.debug('ResilientZPool _handle_tasks exited') @staticmethod - def _maintain_workers(processes, workers, master_addr, result_addr, initializer, - initargs, maxtasksperchild): + def _maintain_workers( + processes: int, + workers: List[fiber.process.Process], + master_addr: str, + result_addr: str, + initializer: Optional[Callable], + initargs: Optional[Sequence], + maxtasksperchild: Optional[int] + ) -> None: thread = threading.current_thread() workers_per_fp = config.cpu_per_job @@ -1566,7 +1612,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, workers) threads = [] - while left > 0 and thread._state == RUN: + while left > 0 and thread._state == RUN: # type: ignore if left > workers_per_fp: n = workers_per_fp @@ -1610,7 +1656,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, workers) @staticmethod - def _handle_workers(pool): + def _handle_workers(pool: "ResilientZPool") -> None: # type: ignore[override] logger.debug("%s _handle_workers running", pool) td = threading.current_thread() @@ -1620,7 +1666,7 @@ def _handle_workers(pool): pool._initargs, pool._maxtasksperchild ) - while td._state == RUN: + while td._state == RUN: # type: ignore exited_workers = ResilientZPool._join_exited_workers(pool._pool) if len(exited_workers) > 0: # create new workers when old workers exited @@ -1636,6 +1682,12 @@ def _handle_workers(pool): logger.debug("Resubmitting tasks from failed workers") for worker in exited_workers: + if worker.pid is None: + logger.warn( + "Can't get pid for failed worker: 
{}".format(worker) + ) + continue + rid = pool._pid_to_rid[worker.pid] # remove rid from active peers pool._remove_peer(rid) @@ -1658,34 +1710,33 @@ def _handle_workers(pool): logger.debug("%s _handle_workers finished. Status is not RUN", pool) - def terminate(self): - self._task_handler._state = TERMINATE + def terminate(self) -> None: + self._task_handler._state = TERMINATE # type: ignore super(ResilientZPool, self).terminate() - def close(self): + def close(self) -> None: logger.debug('closing pool %s', self) if self._state == RUN: self._state = CLOSE - self._worker_handler._state = CLOSE + self._worker_handler._state = CLOSE # type: ignore for p in self._pool: if hasattr(p, '_sentinel'): - p._state = CLOSE + p._state = CLOSE # type: ignore #self._send_sentinels_to_workers() #logger.debug("ResilientZPool _send_sentinels_to_workers: " # "send to task handler") self._task_put(None) - self._task_handler._state = CLOSE + self._task_handler._state = CLOSE # type: ignore - def _send_sentinels_to_workers(self): + def _send_sentinels_to_workers(self) -> None: logger.debug("ResilientZPool _send_sentinels_to_workers: " "send to workers") - data = pickle.dumps(None) for ident in self.active_peer_list: - self._master_sock.send_multipart([ident, b"", data]) + self._master_sock.send(None) #Pool = ZPool diff --git a/fiber/process.py b/fiber/process.py index 57ecf74..af7584f 100644 --- a/fiber/process.py +++ b/fiber/process.py @@ -269,7 +269,9 @@ def ident(self) -> Optional[int]: def ident(self, pid: int): self._pid = pid - pid = ident + @property + def pid(self) -> Optional[int]: + return self.ident @property def target(self): From aec2094f264bd5af1ff879c3f15a45b98fa8fd52 Mon Sep 17 00:00:00 2001 From: Jiale Zhi Date: Mon, 7 Sep 2020 12:47:11 -0700 Subject: [PATCH 22/23] [type-hints] Add type hints to fiber/managers.py --- fiber/managers.py | 191 +++++++++++++++++++++++++++++++--------------- 1 file changed, 130 insertions(+), 61 deletions(-) diff --git a/fiber/managers.py b/fiber/managers.py index c5d8cfb..c72e244 100644 --- a/fiber/managers.py +++ b/fiber/managers.py @@ -23,17 +23,29 @@ import multiprocessing.util as mp_util import queue import threading -from multiprocessing.connection import (SocketListener, _validate_family, - address_type) +from multiprocessing.connection import ( # type: ignore + SocketListener, _validate_family, address_type +) from multiprocessing.context import get_spawning_popen -from multiprocessing.managers import (Array, DictProxy, ListProxy, Namespace, - RebuildProxy, State, Value, dispatch) -from multiprocessing.process import AuthenticationString +from multiprocessing.managers import ( # type: ignore + Array, DictProxy, ListProxy, Namespace, + RebuildProxy, State, Value, dispatch +) +from multiprocessing.process import AuthenticationString # type: ignore import fiber.util import fiber.queues from fiber import process from fiber.backend import get_backend +from typing import ( + Any, TypeVar, Callable, Sequence, Dict, Union, Tuple, Optional, + Iterable, Mapping +) + +_T0 = TypeVar('_T0') +_T1 = TypeVar('_T1') +_TAsyncListProxy = TypeVar('_TAsyncListProxy', bound="AsyncListProxy") +Address = Union[str, Tuple[Union[str, bytes], int], None] logger = logging.getLogger('fiber') @@ -42,7 +54,12 @@ class Listener(multiprocessing.connection.Listener): - def __init__(self, address=None, family=None, backlog=500, authkey=None): + def __init__( + self, address: Address = None, + family: str = None, + backlog: int = 500, + authkey: bytes = None + ) -> None: family = family or 
(address and address_type(address)) \ or default_family if family != 'AF_INET': @@ -51,7 +68,10 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): backend = get_backend() # TODO(jiale) Add support for other address family for # backend.get_listen_addr - address = address or backend.get_listen_addr() + if address is None: + listen_addr = backend.get_listen_addr() + address = (listen_addr[0], listen_addr[1]) + self._address = address _validate_family(family) @@ -72,7 +92,7 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): # Use IP from `backend.get_listen_addr` and port from # self._listener._address. # This represents an address that other client can connect to - address = property(lambda self: (self._address[0], + address = property(lambda self: (self._address[0], # type: ignore self._listener._address[1])) @@ -85,7 +105,8 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): class Server(multiprocessing.managers.Server): - def __init__(self, registry, address, authkey, serializer): + def __init__(self, registry: Dict, address: Address, + authkey: bytes, serializer: str) -> None: assert isinstance(authkey, bytes) self.registry = registry self.authkey = AuthenticationString(authkey) @@ -93,11 +114,11 @@ def __init__(self, registry, address, authkey, serializer): # do authentication later self.listener = Listener(address=address, backlog=500) - self.address = self.listener.address + self.address = self.listener.address # type: ignore self.id_to_obj = {'0': (None, ())} - self.id_to_refcount = {} - self.id_to_local_proxy_obj = {} + self.id_to_refcount: Dict = {} + self.id_to_local_proxy_obj: Dict = {} self.mutex = threading.Lock() @@ -111,11 +132,11 @@ class BaseManager(multiprocessing.managers.BaseManager): The API of this class is the same as [multiprocessing.managers.BaseManager](https://docs.python.org/3.6/library/multiprocessing.html#multiprocessing.managers.BaseManager) """ - _registry = {} + _registry: Dict = {} _Server = Server - def __init__(self, address=None, authkey=None, serializer='pickle', - ctx=None): + def __init__(self, address: str = None, authkey: bytes = None, + serializer: str = 'pickle', ctx=None) -> None: if authkey is None: authkey = process.current_process().authkey self._address = address # XXX not final address if eg ('', 0) @@ -125,7 +146,7 @@ def __init__(self, address=None, authkey=None, serializer='pickle', self._serializer = serializer self._Listener, self._Client = listener_client[serializer] - def get_server(self): + def get_server(self) -> Server: """ Return server object with serve_forever() method and address attribute """ @@ -134,8 +155,11 @@ def get_server(self): self._authkey, self._serializer) @classmethod - def _run_server(cls, registry, address, authkey, serializer, writer, - initializer=None, initargs=()): + def _run_server( + cls, registry: Dict, address: str, authkey: bytes, serializer: str, + writer: fiber.queues.LazyZConnection, + initializer: Callable = None, initargs: Iterable[Any] = () + ) -> None: """Create a server, report its address and run it.""" if initializer is not None: initializer(*initargs) @@ -151,7 +175,11 @@ def _run_server(cls, registry, address, authkey, serializer, writer, logger.info('manager serving at %r', server.address) server.serve_forever() - def start(self, initializer=None, initargs=()): + def start( + self, + initializer: Callable = None, + initargs: Iterable[Any] = () + ) -> None: """Spawn a server process for this manager object.""" assert 
@@ -168,7 +196,7 @@ def start(self, initializer=None, initargs=()):
             args=(self._registry, self._address, self._authkey,
                   self._serializer, writer, initializer, initargs),
         )
-        ident = ':'.join(str(i) for i in self._process._identity)
+        ident = ':'.join(str(i) for i in self._process._identity)  # type: ignore[attr-defined]
         self._process.name = type(self).__name__ + '-' + ident
         self._process.start()
@@ -179,16 +207,22 @@ def start(self, initializer=None, initargs=()):
         # register a finalizer
         self._state.value = State.STARTED
-        self.shutdown = mp_util.Finalize(
-            self, type(self)._finalize_manager,
+        self.shutdown = mp_util.Finalize(  # type: ignore[assignment]
+            self, type(self)._finalize_manager,  # type: ignore[attr-defined]
             args=(self._process, self._address, self._authkey,
                   self._state, self._Client),
             exitpriority=0
         )

     @classmethod
-    def register(cls, typeid, callable=None, proxytype=None, exposed=None,
-                 method_to_typeid=None, create_method=True):
+    def register(
+        cls, typeid: str,
+        callable: Optional[Callable] = None,
+        proxytype: Any = None,
+        exposed: Optional[Sequence] = None,
+        method_to_typeid: Optional[Mapping[str, str]] = None,
+        create_method: bool = True
+    ) -> None:
         """Register a typeid with the manager type."""
         if '_registry' not in cls.__dict__:
             cls._registry = cls._registry.copy()
@@ -227,7 +261,7 @@ def temp(self, *args, **kwds):


 class ProcessLocalSet(set):
-    def __init__(self):
+    def __init__(self) -> None:
         fiber.util.register_after_fork(self, lambda obj: obj.clear())

     def __reduce__(self):
@@ -236,11 +270,18 @@ def __reduce__(self):

 class BaseProxy(multiprocessing.managers.BaseProxy):
     """A base for proxies of shared objects."""
-    _address_to_local = {}
+    _address_to_local: Dict = {}
     _mutex = fiber.util.ForkAwareThreadLock()

-    def __init__(self, token, serializer, manager=None,
-                 authkey=None, exposed=None, incref=True, manager_owned=False):
+    def __init__(
+        self, token: multiprocessing.managers.Token,
+        serializer: str,
+        manager: Optional[BaseManager] = None,
+        authkey: Optional[bytes] = None,
+        exposed: Optional[Sequence] = None,
+        incref: bool = True,
+        manager_owned: bool = False
+    ) -> None:
         with BaseProxy._mutex:
             tls_idset = BaseProxy._address_to_local.get(token.address, None)
             if tls_idset is None:
@@ -276,16 +317,16 @@ def __init__(self, token, serializer, manager=None,
             self._authkey = process.current_process().authkey

         if incref:
-            self._incref()
+            self._incref()  # type: ignore[attr-defined]

-        fiber.util.register_after_fork(self, BaseProxy._after_fork)
+        fiber.util.register_after_fork(self, BaseProxy._after_fork)  # type: ignore[attr-defined]

-    def connect(self):
+    def connect(self) -> None:
         """Connect manager object to the server process."""
         Listener, Client = listener_client[self._serializer]
-        conn = Client(self._address, authkey=self._authkey)
+        conn = Client(self._address, authkey=self._authkey)  # type: ignore[attr-defined]
         dispatch(conn, None, 'dummy')
-        self._state.value = State.STARTED
+        self._state.value = State.STARTED  # type: ignore[attr-defined]

     def __reduce__(self):
         kwds = {}
@@ -301,13 +342,20 @@ def __reduce__(self):
                 (type(self), self._token, self._serializer, kwds))
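For readers unfamiliar with the `register` API typed above, a usage sketch in the style of `multiprocessing.managers` (the `MyManager` and `get_counter` names are illustrative, not part of Fiber):

    from collections import Counter

    class MyManager(BaseManager):
        pass

    # create_method=True (the default) generates MyManager.get_counter()
    MyManager.register('get_counter', callable=Counter)

    m = MyManager()
    m.start()
    counter = m.get_counter()        # an AutoProxy wrapping a Counter
    counter.update(['a', 'a', 'b'])  # forwarded to the server process
    m.shutdown()
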
-def AutoProxy(token, serializer, manager=None, authkey=None,
-              exposed=None, incref=True):
+def AutoProxy(
+    token: multiprocessing.managers.Token,
+    serializer: str,
+    manager: Optional[BaseManager] = None,
+    authkey: Optional[bytes] = None,
+    exposed: Optional[Sequence] = None,
+    incref: bool = True
+) -> BaseProxy:
     """Return an auto-proxy for `token`."""
     _Client = listener_client[serializer][1]

     if exposed is None:
-        conn = _Client(token.address, authkey=authkey)
+        _address = (str(token.address[0]), token.address[1])
+        conn = _Client(_address, authkey=authkey)
         try:
             exposed = dispatch(conn, None, 'get_methods', (token,))
         finally:
@@ -318,14 +366,14 @@
     if authkey is None:
         authkey = process.current_process().authkey

-    ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed)
+    ProxyType = MakeProxyType('AutoProxy[{}]'.format(token.typeid), exposed)
     proxy = ProxyType(token, serializer, manager=manager, authkey=authkey,
                       incref=incref)
     proxy._isauto = True
     return proxy


-def MakeProxyType(name, exposed, _cache={}):
+def MakeProxyType(name: str, exposed: Sequence, _cache: Dict = {}) -> type:
     """Return a proxy type whose methods are given by `exposed`."""
     exposed = tuple(exposed)
     try:
@@ -333,14 +381,14 @@ def MakeProxyType(name, exposed, _cache={}):
     except KeyError:
         pass

-    dic = {}
+    dic: Dict = {}

     for meth in exposed:
         exec("""def %s(self, *args, **kwds):
        return self._callmethod(%r, args, kwds)""" % (meth, meth), dic)

     ProxyType = type(name, (BaseProxy,), dic)
-    ProxyType._exposed_ = exposed
+    ProxyType._exposed_ = exposed  # type: ignore[attr-defined]
     _cache[(name, exposed)] = ProxyType
     return ProxyType
@@ -409,7 +457,7 @@ def __setattr__(self, key, value):
         callmethod = object.__getattribute__(self, '_callmethod')
         return callmethod('__setattr__', (key, value))

-    def __delattr__(self, key):
+    def __delattr__(self, key) -> None:
         if key[0] == '_':
             return object.__delattr__(self, key)
         callmethod = object.__getattribute__(self, '_callmethod')
@@ -419,10 +467,10 @@ def __delattr__(self, key):
 class ValueProxy(BaseProxy):
     _exposed_ = ('get', 'set')

-    def get(self):
+    def get(self) -> Any:
         return self._callmethod('get')

-    def set(self, value):
-        return self._callmethod('set', (value,))
+    def set(self, value) -> None:
+        self._callmethod('set', (value,))

     value = property(get, set)
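The `# type: ignore[override]` attached to `AsyncBaseProxy._callmethod` below is needed because a subclass that changes a method's return type to something incompatible with the base class violates the Liskov substitution principle, which mypy checks on every override. A minimal self-contained illustration (hypothetical names):

    class Base:
        def call(self) -> None:
            return None

    class Derived(Base):
        def call(self) -> str:  # mypy: incompatible with supertype "Base"
            return "pending"
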
@@ -431,11 +479,11 @@ def set(self, value):


 # Define AsyncProxy and AsyncManager
 class AsyncProxyResult():
-    def __init__(self, conn, proxy):
+    def __init__(self, conn, proxy) -> None:
         self._conn = conn
         self._proxy = proxy

-    def get(self):
+    def get(self) -> Any:
         kind, result = self._conn.recv()

         if kind == '#RETURN':
@@ -446,11 +494,18 @@ def get(self):


 class AsyncBaseProxy(BaseProxy):
+
+    # BaseProxy's _callmethod doesn't return anything. But async base proxy
+    # needs to return something
-    def _callmethod(self, methodname, args=(), kwds={}):
+    def _callmethod(  # type: ignore[override]
+        self, methodname: str,
+        args: Sequence = (),
+        kwds: Dict = {}
+    ) -> AsyncProxyResult:
         try:
             conn = self._tls.connection
         except AttributeError:
-            self._connect()
+            self._connect()  # type: ignore[attr-defined]
             conn = self._tls.connection

         conn.send((self._id, methodname, args, kwds))
@@ -458,7 +513,7 @@ def _callmethod(self, methodname, args=(), kwds={}):
         return AsyncProxyResult(conn, self)


-def MakeAsyncProxyType(name, exposed, _cache={}):
+def MakeAsyncProxyType(name: str, exposed: Sequence, _cache: Dict = {}) -> type:
     """Return a proxy type whose methods are given by `exposed`."""
     exposed = tuple(exposed)
     try:
@@ -466,25 +521,32 @@
     except KeyError:
         pass

-    dic = {}
+    dic: Dict = {}

     for meth in exposed:
         exec("""def %s(self, *args, **kwds):
        return self._callmethod(%r, args, kwds)""" % (meth, meth), dic)

     ProxyType = type(name, (AsyncBaseProxy,), dic)
-    ProxyType._exposed_ = exposed
+    ProxyType._exposed_ = exposed  # type: ignore[attr-defined]
     _cache[(name, exposed)] = ProxyType
     return ProxyType


-def AsyncAutoProxy(token, serializer, manager=None, authkey=None,
-                   exposed=None, incref=True):
+def AsyncAutoProxy(
+    token: multiprocessing.managers.Token,
+    serializer: str,
+    manager: Optional[BaseManager] = None,
+    authkey: Optional[bytes] = None,
+    exposed: Optional[Sequence] = None,
+    incref: bool = True
+) -> AsyncBaseProxy:
     """Return an auto-proxy for `token`."""
     _Client = listener_client[serializer][1]

     if exposed is None:
-        conn = _Client(token.address, authkey=authkey)
+        _address = (str(token.address[0]), token.address[1])
+        conn = _Client(_address, authkey=authkey)
         try:
             exposed = dispatch(conn, None, 'get_methods', (token,))
         finally:
@@ -495,7 +557,7 @@
     if authkey is None:
         authkey = process.current_process().authkey

-    ProxyType = MakeAsyncProxyType('AutoProxy[%s]' % token.typeid, exposed)
+    ProxyType = MakeAsyncProxyType('AutoProxy[{}]'.format(token.typeid), exposed)
     proxy = ProxyType(token, serializer, manager=manager, authkey=authkey,
                       incref=incref)
     proxy._isauto = True
@@ -505,11 +567,11 @@ def AsyncAutoProxy(token, serializer, manager=None, authkey=None,
 class AsyncValueProxy(AsyncBaseProxy):
     _exposed_ = ('get', 'set')

-    def get(self):
+    def get(self) -> AsyncProxyResult:
         return self._callmethod('get')

-    def set(self, value):
-        return self._callmethod('set', (value,))
+    def set(self, value) -> None:
+        self._callmethod('set', (value,))

     value = property(get, set)
@@ -547,8 +609,15 @@ class MyManager(AsyncManager):
     ```
     """
     @classmethod
-    def register(cls, typeid, callable=None, proxytype=None, exposed=None,
-                 method_to_typeid=None, create_method=True):
+    def register(
+        cls,
+        typeid: str,
+        callable: Optional[Callable] = None,
+        proxytype: Any = None,
+        exposed: Optional[Sequence] = None,
+        method_to_typeid: Optional[Mapping[str, str]] = None,
+        create_method: bool = True
+    ) -> None:
         """Register a typeid with the manager type."""
         if '_registry' not in cls.__dict__:
             cls._registry = cls._registry.copy()
@@ -593,7 +662,7 @@ def temp(self, *args, **kwds):
 ))


-class AsyncListProxy(AsyncBaseListProxy):
+class AsyncListProxy(AsyncBaseListProxy):  # type: ignore
     def __iadd__(self, value):
         self._callmethod('extend', (value,))
         return self
@@ -608,7 +677,7 @@ def __imul__(self, value):
     '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items',
     'keys', 'pop', 'popitem', 'setdefault', 'update', 'values'
 ))
-AsyncDictProxy._method_to_typeid_ = {
+AsyncDictProxy._method_to_typeid_ = {  # type: ignore[attr-defined]
     '__iter__': 'Iterator',
 }
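A usage sketch for the async proxies typed above (hypothetical: it assumes registering a plain `dict` against `AsyncDictProxy`, which the patch itself does not do):

    class MyAsyncManager(AsyncManager):
        pass

    MyAsyncManager.register('dict', callable=dict, proxytype=AsyncDictProxy)

    m = MyAsyncManager()
    m.start()
    d = m.dict()
    # Every proxy call returns an AsyncProxyResult immediately;
    # .get() blocks until the server's reply arrives.
    d.__setitem__('k', 1).get()
    print(d.get('k').get())  # -> 1
    m.shutdown()
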
From 6c9bf59003c499261e70c20c352b976b15c7e134 Mon Sep 17 00:00:00 2001
From: Jiale Zhi
Date: Mon, 7 Sep 2020 13:09:20 -0700
Subject: [PATCH 23/23] [type-hints] Add type hints to fiber/experimental/ring.py

---
 fiber/experimental/ring.py | 39 +++++++++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/fiber/experimental/ring.py b/fiber/experimental/ring.py
index d8522cb..0bcc6cc 100644
--- a/fiber/experimental/ring.py
+++ b/fiber/experimental/ring.py
@@ -21,6 +21,9 @@
 import fiber
 from fiber.backend import get_backend

+from typing import Any, NoReturn, Callable, Sequence, Dict, Iterator, Optional, List
+
+_manager: Any

 __all__ = [
@@ -32,10 +35,10 @@
 _manager = None


-def _get_manager():
+def _get_manager() -> fiber.managers.BaseManager:
     global _manager
     if _manager is None:
-        _manager = fiber.Manager()
+        _manager = fiber.Manager()  # type: ignore

     return _manager
@@ -48,7 +51,7 @@ class RingNode:
     :param rank: The id assigned to this node. Each node will be assigned a
         unique id called `rank`. Rank 0 is the control node of the `Ring`.
     """
-    def __init__(self, rank):
+    def __init__(self, rank: int) -> None:
         self.rank = rank
         self.connected = False
         self.ip = None
@@ -68,7 +71,17 @@ class Ring:
     :param initargs: positional arguments that are passed to initializer.
         Currently this is not used.
     """
+
+    __fiber_meta__: Dict
+
-    def __init__(self, processes, func, initializer, initargs=None):
+    def __init__(
+        self,
+        processes: int,
+        func: Callable[[int, int], Any],
+        initializer: Callable[[Optional[Iterator[Any]]], None],
+        initargs: Optional[Iterator[Any]] = None
+    ) -> None:
+
         self.size = processes
         self.initializer = initializer
         self.initargs = initargs
@@ -79,12 +92,12 @@
         # Propogate meta info
         # We can't set attributes to bound/unbound methods (PEP 232),
         # so we set it to Ring object
-        self.__fiber_meta__ = func.__fiber_meta__
+        self.__fiber_meta__ = func.__fiber_meta__  # type: ignore

         manager = _get_manager()
-        self.members = manager.list([RingNode(i) for i in range(self.size)])
+        self.members = manager.list([RingNode(i) for i in range(self.size)])  # type: ignore

-    def _target(self):
+    def _target(self) -> None:
         rank = self.rank
         node = self.members[rank]
@@ -97,10 +110,10 @@
         node.port = port
         self.members[rank] = node

-        self.initializer(self)
+        self.initializer(self.initargs)
         self.func(rank, self.size)

-    def run(self):
+    def run(self) -> None:
         """
         Start this Ring. This will start the ring 0 process on the same
         machine and start all the other ring nodes with Fiber processes.
@@ -113,15 +126,15 @@
         # Start process rank 0
         self.rank = 0
         ctx = mp.get_context("spawn")
-        p = ctx.Process(target=self._target)
-        p.start()
-        procs.append(p)
+        p0 = ctx.Process(target=self._target)
+        p0.start()
+        procs.append(p0)

         for i in range(1, self.size):
             self.rank = i
             p = fiber.Process(target=self._target)
             p.start()
-            procs.append(p)
+            procs.append(p)  # type: ignore[arg-type]

         self.rank = rank

         # wait for all processes to finish