diff --git a/fiber/__init__.py b/fiber/__init__.py index 8901d3b..191f6f6 100644 --- a/fiber/__init__.py +++ b/fiber/__init__.py @@ -20,6 +20,12 @@ from fiber import context from fiber.init import init_fiber from fiber.meta import meta +from typing import List +from typing import TYPE_CHECKING + +__version__: str +_in_interactive_console: bool +_names: List[str] __version__ = "0.2.1" @@ -47,11 +53,11 @@ _in_interactive_console = False -def reset(): +def reset() -> None: init_fiber() -def init(**kwargs): +def init(**kwargs) -> None: """ Initialize Fiber. This function is called when you want to re-initialize Fiber with new config values and also re-init loggers. @@ -66,3 +72,15 @@ def init(**kwargs): globals().update((name, getattr(context._default_context, name)) for name in _names) __all__ = _names + [] + + +if TYPE_CHECKING: + current_process = context.FiberContext.current_process + active_children = context.FiberContext.active_children + Process = context.FiberContext.Process + Manager = context.FiberContext.Manager + Pool = context.FiberContext.Pool + SimpleQueue = context.FiberContext.SimpleQueue + Pipe = context.FiberContext.Pipe + cpu_count = context.FiberContext.cpu_count + get_context = context.FiberContext.get_context diff --git a/fiber/backend.py b/fiber/backend.py index 54ae58d..a4d9c84 100644 --- a/fiber/backend.py +++ b/fiber/backend.py @@ -18,19 +18,22 @@ import multiprocessing as mp import fiber.config as config +from fiber.core import Backend + +_backends: dict _backends = {} available_backend = ['kubernetes', 'docker', 'local'] -def is_inside_kubenetes_job(): +def is_inside_kubenetes_job() -> bool: if os.environ.get("KUBERNETES_SERVICE_HOST", None): return True return False -def is_inside_docker_job(): +def is_inside_docker_job() -> bool: if os.environ.get("FIBER_BACKEND", "") == "docker": return True return False @@ -42,7 +45,7 @@ def is_inside_docker_job(): } -def auto_select_backend(): +def auto_select_backend() -> str: for backend_name, test in BACKEND_TESTS.items(): if test(): name = backend_name @@ -53,7 +56,7 @@ def auto_select_backend(): return name -def get_backend(name=None, **kwargs): +def get_backend(name=None, **kwargs) -> Backend: """ Returns a working Fiber backend. If `name` is specified, returns a backend specified by `name`. @@ -70,7 +73,8 @@ def get_backend(name=None, **kwargs): _backend = _backends.get(name, None) if _backend is None: - _backend = importlib.import_module("fiber.{}_backend".format( - name)).Backend(**kwargs) + backend_name = "fiber.{}_backend".format(name) + backend_module = importlib.import_module(backend_name) + _backend = backend_module.Backend(**kwargs) # type: ignore _backends[name] = _backend return _backend diff --git a/fiber/cli.py b/fiber/cli.py index fceb84d..687feb2 100644 --- a/fiber/cli.py +++ b/fiber/cli.py @@ -34,18 +34,27 @@ import fiber import fiber.core as core from fiber.core import ProcessStatus +import pathlib +from typing import Any, List, Tuple, TypeVar, Optional, Union, Dict + +_T0 = TypeVar("_T0") +_TDockerImageBuilder = TypeVar( + "_TDockerImageBuilder", bound="DockerImageBuilder" +) +CONFIG: Dict[str, str] CONFIG = {} -def get_backend(platform): +def get_backend(platform: str) -> fiber.kubernetes_backend.Backend: from fiber.kubernetes_backend import Backend as K8sBackend + backend = K8sBackend(incluster=False) return backend -def find_docker_files(): +def find_docker_files() -> List[pathlib.Path]: """Find all possible docker files on current directory.""" p = Path(".") q = p / "Dockerfile" @@ -58,7 +67,7 @@ def find_docker_files(): return files -def select_docker_file(files): +def select_docker_file(files: List[pathlib.Path]) -> pathlib.Path: """Ask user which docker file to use and return a PurePath object.""" num = 0 n = len(files) @@ -89,7 +98,7 @@ def select_docker_file(files): return files[num] -def get_default_project_gcp(): +def get_default_project_gcp() -> str: """Get default GCP project name.""" name = sp.check_output( "gcloud config list --format 'value(core.project)' 2>/dev/null", @@ -98,7 +107,7 @@ def get_default_project_gcp(): return name.decode("utf-8").strip() -def parse_file_path(path): +def parse_file_path(path: str) -> Tuple[Optional[str], str]: parts = path.split(":") if len(parts) == 1: return (None, path) @@ -112,7 +121,7 @@ def parse_file_path(path): @click.command() @click.argument("src") @click.argument("dst") -def cp(src, dst): +def cp(src: str, dst: str) -> None: """Copy file from a persistent storage""" platform = CONFIG["platform"] @@ -125,12 +134,14 @@ def cp(src, dst): ) if parts_src[0]: - volume = parts_src[0] - elif parts_dst[0]: - volume = parts_dst[0] + parsed_volume = parts_src[0] else: + parsed_volume = parts_dst[0] + + if parsed_volume is None: raise ValueError("Must copy/to from a persistent volume") + volume:str = parsed_volume k8s_backend = get_backend(platform) job_spec = core.JobSpec( @@ -140,6 +151,9 @@ def cp(src, dst): volumes={volume: {"mode": "rw", "bind": "/persistent"}}, ) job = k8s_backend.create_job(job_spec) + if job.data is None: + raise RuntimeError("Failed to create a new job for data copying") + pod_name = job.data.metadata.name print("launched pod: {}".format(pod_name)) @@ -170,7 +184,7 @@ def cp(src, dst): # k8s_backend.terminate_job(job) -def detect_platforms(): +def detect_platforms() -> List[str]: commands = ["gcloud", "aws"] platforms = ["gcp", "aws"] found_platforms = [] @@ -186,7 +200,7 @@ def detect_platforms(): return found_platforms -def prompt_choices(choices, prompt): +def prompt_choices(choices: List[_T0], prompt: str) -> _T0: num = 0 n = len(choices) @@ -216,13 +230,13 @@ def prompt_choices(choices, prompt): class DockerImageBuilder: - def __init__(self, registry=""): + def __init__(self, registry: str = "") -> None: self.registry = registry - def get_docker_registry_image_name(image_base_name): + def get_docker_registry_image_name(self, image_base_name: str) -> str: return image_base_name - def build(self): + def build(self) -> str: files = find_docker_files() n = len(files) if n == 0: @@ -248,25 +262,26 @@ def build(self): return self.full_image_name - def tag(self): + def tag(self) -> str: self.full_image_name = self.image_name + return self.full_image_name - def push(self): + def push(self) -> None: sp.check_call( "docker push {}".format(self.full_image_name), shell=True, ) - def docker_tag(self, in_name, out_name): + def docker_tag(self, in_name: str, out_name: str) -> None: sp.check_call("docker tag {} {}".format(in_name, out_name), shell=True) class AWSImageBuilder(DockerImageBuilder): - def __init__(self, registry): + def __init__(self, registry: str) -> None: self.registry = registry parts = registry.split(".") self.region = parts[-3] - def tag(self): + def tag(self) -> str: image_name = self.image_name full_image_name = "{}/{}".format(self.registry, self.image_name) @@ -275,7 +290,7 @@ def tag(self): self.full_image_name = full_image_name return full_image_name - def need_new_repo(self): + def need_new_repo(self) -> bool: output = sp.check_output( "aws ecr describe-repositories --region {}".format(self.region), shell=True, @@ -293,7 +308,7 @@ def need_new_repo(self): return True - def create_repo_if_needed(self): + def create_repo_if_needed(self) -> None: if self.need_new_repo(): sp.check_call( "aws ecr create-repository --region {} --repository-name {}".format( @@ -304,7 +319,7 @@ def create_repo_if_needed(self): return - def push(self): + def push(self) -> None: self.create_repo_if_needed() try: @@ -320,10 +335,10 @@ def push(self): class GCPImageBuilder(DockerImageBuilder): - def __init__(self, registry="gcr.io"): + def __init__(self, registry: str = "gcr.io") -> None: self.registry = registry - def tag(self): + def tag(self) -> str: image_name = self.image_name proj = get_default_project_gcp() @@ -343,7 +358,15 @@ def tag(self): @click.option("--memory") @click.option("-v", "--volume") @click.argument("args", nargs=-1) -def run(attach, image, gpu, cpu, memory, volume, args): +def run( + attach: bool, + image: str, + gpu: int, + cpu: int, + memory: int, + volume: str, + args: List[str], +) -> int: """Run a command on a kubernetes cluster with fiber.""" platform = CONFIG["platform"] print( @@ -352,6 +375,8 @@ def run(attach, image, gpu, cpu, memory, volume, args): ) ) + builder: DockerImageBuilder + if image: full_image_name = image else: @@ -393,6 +418,9 @@ def run(attach, image, gpu, cpu, memory, volume, args): job_spec.volumes = volumes job = k8s_backend.create_job(job_spec) + if job.data is None: + raise RuntimeError("Failed to create a new job") + pod_name = job.data.metadata.name exitcode = 0 @@ -414,7 +442,7 @@ def run(attach, image, gpu, cpu, memory, volume, args): return 0 -def auto_select_platform(): +def auto_select_platform() -> str: platforms = detect_platforms() if len(platforms) > 1: choice = prompt_choices( @@ -438,7 +466,7 @@ def auto_select_platform(): "--gcp", is_flag=True, help="Run commands on Google Cloud Platform" ) @click.version_option(version=fiber.__version__, prog_name="fiber") -def main(docker_registry, aws, gcp): +def main(docker_registry: str, aws: bool, gcp: bool) -> None: """fiber command line tool that helps to manage workflow of distributed fiber applications. """ diff --git a/fiber/config.py b/fiber/config.py index 75cf391..9d115ba 100644 --- a/fiber/config.py +++ b/fiber/config.py @@ -67,6 +67,11 @@ def main(): import os import logging import configparser +from typing import Any, Dict, List, Type, TypeVar, Optional +from typing import TYPE_CHECKING + +_current_config: None +_TConfig = TypeVar('_TConfig', bound="Config") _current_config = None @@ -84,7 +89,7 @@ def main(): DEFAULT_IMAGE = "fiber-test:latest" -def str2bool(text): +def str2bool(text: str) -> bool: """Simple function to convert a range of values to True/False.""" return text.lower() in ["true", "yes", "1"] @@ -106,34 +111,34 @@ class Config(object): """ - def __init__(self, conf_file=None): + def __init__(self, conf_file: str = None) -> None: # Not documented, people should not use this - self.merge_output = False - self.debug = False - self.image = None - self.default_image = DEFAULT_IMAGE - self.backend = None - self.default_backend = "local" + self.merge_output: bool = False + self.debug: bool = False + self.image: Optional[str] = None + self.default_image: str = DEFAULT_IMAGE + self.backend: Optional[str] = None + self.default_backend: str = "local" # Not documented, this should be removed because it's not used for now - self.use_bash = False - self.log_level = logging.INFO - self.log_file = "/tmp/fiber.log" + self.use_bash: bool = False + self.log_level: int = logging.INFO + self.log_file: str = "/tmp/fiber.log" # If ipc_active is True, Fiber worker processes will connect # to the master process. Otherwise, the master process will connect # to worker processes. # Not documented, should only be used internally - self.ipc_active = True + self.ipc_active: bool = True # if ipc_active is True, this can be 0, otherwise, it can only be a # valid TCP port number. Default 0. - self.ipc_admin_master_port = 0 + self.ipc_admin_master_port: int = 0 # Not documented, this is only used when `ipc_active` is False - self.ipc_admin_worker_port = 8000 + self.ipc_admin_worker_port: int = 8000 # Not documented, need to fine tune this - self.cpu_per_job = 1 + self.cpu_per_job: int = 1 # Not documented, need to fine tune this - self.mem_per_job = None - self.use_push_queue = True - self.kubernetes_namespace = "default" + self.mem_per_job: Optional[int] = None + self.use_push_queue: bool = True + self.kubernetes_namespace: str = "default" if conf_file is None: conf_file = ".fiberconfig" @@ -185,7 +190,7 @@ def __repr__(self): return repr(self.__dict__) @classmethod - def from_dict(cls, kv): + def from_dict(cls: Type[_TConfig], kv: Dict[str, Any]) -> _TConfig: obj = cls() for k in kv: setattr(obj, k, kv[k]) @@ -193,7 +198,7 @@ def from_dict(cls, kv): return obj -def get_object(): +def get_object() -> Config: """ Get a Config object representing current Fiber config @@ -207,7 +212,7 @@ def get_object(): return Config.from_dict(get_dict()) -def get_dict(): +def get_dict() -> Dict[str, Any]: """ Get current Fiber config in a dictionary @@ -218,7 +223,7 @@ def get_dict(): return {k: global_vars[k] for k in vars(_current_config)} -def init(**kwargs): +def init(**kwargs) -> List[str]: """ Init Fiber system and set config values. @@ -247,3 +252,22 @@ def init(**kwargs): logger.debug("Inited fiber with config: %s", vars(_config)) return updates + + +if TYPE_CHECKING: + merge_output = Config.merge_output + debug = Config.debug + image = Config.image + default_image = Config.default_image + backend = Config.backend + default_backend = Config.default_backend + use_bash = Config.use_bash + log_level = Config.log_level + log_file = Config.log_file + ipc_active = Config.ipc_active + ipc_admin_master_port = Config.ipc_admin_master_port + ipc_admin_worker_port = Config.ipc_admin_worker_port + cpu_per_job = Config.cpu_per_job + mem_per_job = Config.mem_per_job + use_push_queue = Config.use_push_queue + kubernetes_namespace = Config.kubernetes_namespace diff --git a/fiber/context.py b/fiber/context.py index 181e585..b6cb70e 100644 --- a/fiber/context.py +++ b/fiber/context.py @@ -15,53 +15,67 @@ import os import fiber.config as config from fiber import process +from typing import Optional, Tuple, Callable, Sequence +from fiber.managers import SyncManager +from fiber.pool import ZPool, ResilientZPool +from fiber.queues import SimpleQueuePush, LazyZConnection +from fiber.queues import Pipe -class FiberContext(): - _name = '' +_default_context: "FiberContext" + + +class FiberContext: + _name = "" Process = process.Process current_process = staticmethod(process.current_process) active_children = staticmethod(process.active_children) - def Manager(self): + def Manager(self) -> SyncManager: """Returns a manager associated with a running server process The managers methods such as `Lock()`, `Condition()` and `Queue()` can be used to create shared objects. """ - from fiber.managers import SyncManager m = SyncManager() m.start() return m - def Pool(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, error_handling=False): + def Pool( + self, + processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + error_handling: bool = False, + ) -> ZPool: """Returns a process pool object""" - from .pool import ZPool, ResilientZPool if error_handling: - return ResilientZPool(processes, initializer, initargs, maxtasksperchild) + return ResilientZPool( + processes, initializer, initargs, maxtasksperchild + ) else: return ZPool(processes, initializer, initargs, maxtasksperchild) - def SimpleQueue(self): + def SimpleQueue(self) -> SimpleQueuePush: """Returns a queue object""" if config.use_push_queue: - from .queues import SimpleQueuePush return SimpleQueuePush() # PullQueue is not supported anymore raise NotImplementedError - def Pipe(self, duplex=True): + def Pipe( + self, duplex: bool = True + ) -> Tuple[LazyZConnection, LazyZConnection]: """Returns two connection object connected by a pipe""" - from .queues import Pipe return Pipe(duplex) - def cpu_count(self): + def cpu_count(self) -> Optional[int]: return os.cpu_count() - def get_context(self, method=None): + def get_context(self, method: str = None) -> "FiberContext": if method is None: return self if method != "spawn": @@ -69,8 +83,6 @@ def get_context(self, method=None): return _concrete_contexts[method] -_concrete_contexts = { - 'spawn': FiberContext() -} +_concrete_contexts = {"spawn": FiberContext()} -_default_context = _concrete_contexts['spawn'] +_default_context = _concrete_contexts["spawn"] diff --git a/fiber/core.py b/fiber/core.py index 0333bf0..03067dc 100644 --- a/fiber/core.py +++ b/fiber/core.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod import enum - - -MEM_CPU_RATIO = 2 # 2G per cpu +from typing import Dict, List, NoReturn, Optional, Any, Union, Tuple class ProcessStatus(enum.Enum): @@ -26,8 +25,24 @@ class ProcessStatus(enum.Enum): class JobSpec(object): - def __init__(self, image=None, command=None, name=None, cpu=None, mem=None, - volumes=None, gpu=None): + image: Optional[str] + command: List[str] + name: str + cpu: Optional[int] + mem: Optional[int] + volumes: Optional[Dict[str, Dict]] + gpu: Optional[int] + + def __init__( + self, + image: str = None, + command: List[str] = [], + name: str = "", + cpu: int = None, + mem: int = None, + volumes: Dict[str, Dict] = None, + gpu: int = None, + ) -> None: # Docker image used to launch this job self.image = image # Command to run in this job container, this should be a sequence @@ -40,8 +55,6 @@ def __init__(self, image=None, command=None, name=None, cpu=None, mem=None, # Maximum number of cpu cores this job can use self.gpu = gpu # Maximum memory size in MB that this job can use - #if mem is None: - # mem = cpu * MEM_CPU_RATIO self.mem = mem # volume name to be mounted, currently only used by k8s backend # For example: @@ -50,64 +63,72 @@ def __init__(self, image=None, command=None, name=None, cpu=None, mem=None, # } self.volumes = volumes - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.__dict__ == other.__dict__ - def __repr__(self): - return ''.format(vars(self)) + def __repr__(self) -> str: + return "".format(vars(self)) class Job(object): # Data is used to hold backend specific data associated with this job - data = None + data: Any # Job id. This is set by backend and should only be used by Fiber backend - jid = None + jid: Union[str, int] # (Optional) The hostname/IP address for this job, this is used to # communicate with the master process. It is only used when # `ipc_admin_passive` is enabled. - host = None + host: str - def __init__(self, data, jid): + def __init__(self, data: Any, jid: Union[str, int]) -> None: + assert data is not None, "Job data is None" self.data = data self.jid = jid + self.host = "" def update(self): # update/refresh job attributes raise NotImplementedError -class Backend(object): +class Backend(ABC): @property + @abstractmethod def name(self): - raise NotImplementedError + pass - def create_job(self, job_spec): + @abstractmethod + def create_job(self, job_spec: JobSpec) -> Job: """This function is called when Fiber wants to create a new Process.""" - raise NotImplementedError + pass - def get_job_status(self, job): + @abstractmethod + def get_job_status(self, job: Job) -> ProcessStatus: """This function is called when Fiber wants to to get job status.""" - raise NotImplementedError + pass - def get_job_logs(self, job): + def get_job_logs(self, job: Job) -> str: """ This function is called when Fiber wants to to get logs of this job """ return "" - def wait_for_job(self, job, timeout): + @abstractmethod + def wait_for_job(self, job: Job, timeout: Optional[float]) -> Optional[int]: """Wait for a specific job until timeout. If timeout is None, wait until job is done. Returns `None` if timed out or `exitcode` if job is finished. """ - raise NotImplementedError + pass - def terminate_job(self, job): + @abstractmethod + def terminate_job(self, job: Job) -> None: """Terminate a job described by `job`.""" - raise NotImplementedError + pass - def get_listen_addr(self): + @abstractmethod + def get_listen_addr(self) -> Tuple[str, int, str]: """This function is called when Fiber wants to listen on a local address for incoming connection. It is currently used by Popen and Queue.""" - raise NotImplementedError + pass diff --git a/fiber/docker_backend.py b/fiber/docker_backend.py index 0059b06..d64910a 100644 --- a/fiber/docker_backend.py +++ b/fiber/docker_backend.py @@ -31,6 +31,7 @@ import fiber.config as config from fiber.core import ProcessStatus from fiber.util import find_ip_by_net_interface, find_listen_address +from typing import Any, Optional, Tuple, TypeVar, Union logger = logging.getLogger('fiber') @@ -47,7 +48,11 @@ class DockerJob(core.Job): - def update(self): + def __init__(self, data: Any, jid: Union[str, int]) -> None: + super().__init__(data, jid) + self.update() + + def update(self) -> None: container = self.data self.host = container.attrs['NetworkSettings']['IPAddress'] @@ -55,12 +60,12 @@ def update(self): class Backend(core.Backend): name = "docker" - def __init__(self): + def __init__(self) -> None: # Based on this link, no lock is needed accessing self.client # https://github.com/docker/docker-py/issues/619 self.client = docker.from_env() - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> DockerJob: logger.debug("[docker]create_job: %s", job_spec) cwd = os.getcwd() volumes = {cwd: {'bind': cwd, 'mode': 'rw'}, @@ -101,7 +106,7 @@ def create_job(self, job_spec): container._fiber_backend_reloading = False return job - def _reload(self, container): + def _reload(self, container) -> None: container._fiber_backend_reloading = True logger.debug("container reloading %s", container.name) container.reload() @@ -112,12 +117,14 @@ def _reload(self, container): container._fiber_backend_reloading = False - def get_job_logs(self, job): + def get_job_logs(self, job: core.Job) -> str: container = job.data return container.logs(stream=False).decode('utf-8') - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: container = job.data + if container is None: + return ProcessStatus.UNKNOWN if config.merge_output: print(container.logs(stream=False).decode('utf-8')) @@ -131,8 +138,12 @@ def get_job_status(self, job): logger.debug("start container reloading thread %s", container.name) return status - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: container = job.data + if container is None: + # Job not started + return None + logger.debug("wait_for_job: %s", container.name) if config.merge_output: @@ -165,7 +176,7 @@ def wait_for_job(self, job, timeout): return res['StatusCode'] - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: logging.debug("terminate_job") container = job.data @@ -184,15 +195,16 @@ def terminate_job(self, job): raise e logger.debug("terminate job finished, %s", container.status) - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: ip = None + ifce: Optional[str] = None if sys.platform == "darwin": # use the same hostname for both master and non master process # because docker.for.mac.localhost resolves to different inside # and outside docker container. "docker.for.mac.localhost" is # the name that doesn't change in and outside the container. - return "docker.for.mac.localhost", 0 + return "docker.for.mac.localhost", 0, "eth0" if not isinstance(fiber.current_process(), fiber.Process): # not inside docker @@ -206,5 +218,15 @@ def get_listen_addr(self): raise mp.ProcessError( "Can't find a usable IPv4 address to listen. ifce_name: {}, " "ifces: {}".format(ifce, psutil.net_if_addrs())) + + if ifce is None: + raise mp.ProcessError( + "Can't find a usable network interface to listen." + "ifces: {}".format(psutil.net_if_addrs()) + ) + + ip_ret: str = ip + ifce_ret: str = ifce + # use 0 to bind to a random free port number - return ip, 0, ifce + return ip_ret, 0, ifce_ret diff --git a/fiber/experimental/ring.py b/fiber/experimental/ring.py index d8522cb..0bcc6cc 100644 --- a/fiber/experimental/ring.py +++ b/fiber/experimental/ring.py @@ -21,6 +21,9 @@ import fiber from fiber.backend import get_backend +from typing import Any, NoReturn, Callable, Sequence, Dict, Iterator, Optional, List + +_manager: Any __all__ = [ @@ -32,10 +35,10 @@ _manager = None -def _get_manager(): +def _get_manager() -> fiber.managers.BaseManager: global _manager if _manager is None: - _manager = fiber.Manager() + _manager = fiber.Manager() # type: ignore return _manager @@ -48,7 +51,7 @@ class RingNode: :param rank: The id assigned to this node. Each node will be assigned a unique id called `rank`. Rank 0 is the control node of the `Ring`. """ - def __init__(self, rank): + def __init__(self, rank) -> None: self.rank = rank self.connected = False self.ip = None @@ -68,7 +71,17 @@ class Ring: :param initargs: positional arguments that are passed to initializer. Currently this is not used. """ - def __init__(self, processes, func, initializer, initargs=None): + + __fiber_meta__: Dict + + def __init__( + self, + processes: int, + func: Callable[[int, int], Any], + initializer: Callable[[Optional[Iterator[Any]]], None], + initargs: Iterator[Any] = None + ) -> None: + self.size = processes self.initializer = initializer self.initargs = initargs @@ -79,12 +92,12 @@ def __init__(self, processes, func, initializer, initargs=None): # Propogate meta info # We can't set attributes to bound/unbound methods (PEP 232), # so we set it to Ring object - self.__fiber_meta__ = func.__fiber_meta__ + self.__fiber_meta__ = func.__fiber_meta__ # type: ignore manager = _get_manager() - self.members = manager.list([RingNode(i) for i in range(self.size)]) + self.members = manager.list([RingNode(i) for i in range(self.size)]) # type: ignore - def _target(self): + def _target(self) -> None: rank = self.rank node = self.members[rank] @@ -97,10 +110,10 @@ def _target(self): node.port = port self.members[rank] = node - self.initializer(self) + self.initializer(self.initargs) self.func(rank, self.size) - def run(self): + def run(self) -> None: """ Start this Ring. This will start the ring 0 process on the same machine and start all the other ring nodes with Fiber processes. @@ -113,15 +126,15 @@ def run(self): # Start process rank 0 self.rank = 0 ctx = mp.get_context("spawn") - p = ctx.Process(target=self._target) - p.start() - procs.append(p) + p0 = ctx.Process(target=self._target) + p0.start() + procs.append(p0) for i in range(1, self.size): self.rank = i p = fiber.Process(target=self._target) p.start() - procs.append(p) + procs.append(p) # type: ignore[arg-type] self.rank = rank # wait for all processes to finish diff --git a/fiber/init.py b/fiber/init.py index 84145a1..b6af48c 100644 --- a/fiber/init.py +++ b/fiber/init.py @@ -22,7 +22,7 @@ import fiber.backend as fiber_backend -def init_logger(config, proc_name=None): +def init_logger(config: fiber_config.Config, proc_name: str = "") -> None: logger = logging.getLogger("fiber") if config.log_file.lower() == "stdout": handler = logging.StreamHandler() @@ -31,7 +31,7 @@ def init_logger(config, proc_name=None): if not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) - if not proc_name: + if proc_name == "": p = fiber.process.current_process() proc_name = p.name @@ -49,7 +49,7 @@ def init_logger(config, proc_name=None): logger.propagate = False -def init_fiber(proc_name=None, **kwargs): +def init_fiber(proc_name: str = "", **kwargs) -> None: """ Initialize Fiber. This function is called when you want to re-initialize Fiber with new config values and also re-init loggers. diff --git a/fiber/kubernetes_backend.py b/fiber/kubernetes_backend.py index d531d46..4f6d353 100644 --- a/fiber/kubernetes_backend.py +++ b/fiber/kubernetes_backend.py @@ -32,6 +32,7 @@ import fiber.config as fiber_config from fiber.core import ProcessStatus from fiber.util import find_ip_by_net_interface, find_listen_address +from typing import Any, Tuple, Optional logger = logging.getLogger("fiber") @@ -49,7 +50,7 @@ class Backend(core.Backend): name = "kubernetes" - def __init__(self, incluster=True): + def __init__(self, incluster: bool = True) -> None: if incluster: config.load_incluster_config() else: @@ -61,24 +62,23 @@ def __init__(self, incluster=True): if incluster: podname = socket.gethostname() - pod = self.core_api.read_namespaced_pod(podname, - self.default_namespace) + pod = self.core_api.read_namespaced_pod(podname, self.default_namespace) # Current model assume that Fiber only lauches 1 container per pod self.current_image = pod.spec.containers[0].image self.volumes = pod.spec.volumes self.mounts = pod.spec.containers[0].volume_mounts - #if pod.spec.volumes: + # if pod.spec.volumes: # self.current_mount = pod.spec.volumes[0].persistent_volume_claim.claim_name - #else: + # else: # self.current_mount = None else: self.current_image = None self.mounts = None self.volumes = None - def _get_resource_requirements(self, job_spec): - #requests = {} + def _get_resource_requirements(self, job_spec: core.JobSpec) -> Any: + # requests = {} limits = {} if job_spec.cpu: @@ -100,17 +100,13 @@ def _get_resource_requirements(self, job_spec): return None - - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> core.Job: logger.debug("[k8s]create_job: %s", job_spec) body = client.V1Pod() name = "{}-{}".format( - job_spec.name.replace("_", "-").lower(), - str(uuid.uuid4())[:8] - ) - body.metadata = client.V1ObjectMeta( - namespace=self.default_namespace, name=name + job_spec.name.replace("_", "-").lower(), str(uuid.uuid4())[:8] ) + body.metadata = client.V1ObjectMeta(namespace=self.default_namespace, name=name) # set environment varialbes # TODO(jiale) add environment variables @@ -118,22 +114,22 @@ def create_job(self, job_spec): image = job_spec.image if job_spec.image else self.current_image container = client.V1Container( - name=name, image=image, command=job_spec.command, env=[], - stdin=True, tty=True, + name=name, + image=image, + command=job_spec.command, + env=[], + stdin=True, + tty=True, ) rr = self._get_resource_requirements(job_spec) if rr: logger.debug( - "[k8s]create_job, container resource requirements: %s", - job_spec + "[k8s]create_job, container resource requirements: %s", job_spec ) container.resources = rr - body.spec = client.V1PodSpec( - containers=[container], - restart_policy="Never" - ) + body.spec = client.V1PodSpec(containers=[container], restart_policy="Never") # propagate mount points to new containers if necesary if job_spec.volumes: @@ -141,18 +137,14 @@ def create_job(self, job_spec): volume_mounts = [] for pd_name, mount_info in job_spec.volumes.items(): - #volume_name = job_spec.volume if job_spec.volume else self.current_mount - pvc = client.V1PersistentVolumeClaimVolumeSource( - claim_name=pd_name - ) + # volume_name = job_spec.volume if job_spec.volume else self.current_mount + pvc = client.V1PersistentVolumeClaimVolumeSource(claim_name=pd_name) volume = client.V1Volume( - persistent_volume_claim=pvc, - name="volume-" + pd_name, + persistent_volume_claim=pvc, name="volume-" + pd_name, ) volumes.append(volume) mount = client.V1VolumeMount( - mount_path=mount_info["bind"], - name=volume.name, + mount_path=mount_info["bind"], name=volume.name, ) if mount_info["mode"] == "r": mount.read_only = True @@ -165,15 +157,13 @@ def create_job(self, job_spec): logger.debug("[k8s]calling create_namespaced_pod: %s", body.metadata.name) try: - v1pod = self.core_api.create_namespaced_pod( - self.default_namespace, body - ) + v1pod = self.core_api.create_namespaced_pod(self.default_namespace, body) except ApiException as e: raise e return core.Job(v1pod, v1pod.metadata.uid) - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: v1pod = job.data name = v1pod.metadata.name namespace = v1pod.metadata.namespace @@ -197,7 +187,7 @@ def get_job_status(self, job): pod_status = v1pod.status return PHASE_STATUS_MAP[pod_status.phase] - def get_job_logs(self, job): + def get_job_logs(self, job: core.Job) -> str: v1job = job.data name = v1job.metadata.name namespace = v1job.metadata.namespace @@ -214,7 +204,7 @@ def get_job_logs(self, job): return logs - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: logger.debug("[k8s]wait_for_job timeout=%s", timeout) total = 0 @@ -249,11 +239,13 @@ def wait_for_job(self, job, timeout): logger.debug("[k8s]wait_for_job done: container is not terminated") return None - logger.debug("[k8s]wait_for_job done: container terminated with " - "code: {}".format(terminated.exit_code)) + logger.debug( + "[k8s]wait_for_job done: container terminated with " + "code: {}".format(terminated.exit_code) + ) return terminated.exit_code - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: v1job = job.data name = v1job.metadata.name namespace = v1job.metadata.namespace @@ -263,21 +255,22 @@ def terminate_job(self, job): try: logger.debug( "calling delete_namespaced_pod(%s, %s, grace_period_seconds=%s)", - name, namespace, grace_period_seconds, + name, + namespace, + grace_period_seconds, ) self.core_api.delete_namespaced_pod( - name, namespace, grace_period_seconds=grace_period_seconds, - body=body, + name, namespace, grace_period_seconds=grace_period_seconds, body=body, ) except ApiException as e: logger.debug( - "[k8s] Exception when calling " "delete_namespaced_pod: %s", - str(e), + "[k8s] Exception when calling " "delete_namespaced_pod: %s", str(e), ) raise e - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: ip = None + ifce: Optional[str] = None # if fiber.current_process() is multiprocessing.current_process(): if not isinstance(fiber.current_process(), fiber.Process): @@ -297,5 +290,15 @@ def get_listen_addr(self): "Can't find a usable IPv4 address to listen. ifce_name: {}, " "ifces: {}".format(ifce, psutil.net_if_addrs()) ) + + if ifce is None: + raise mp.ProcessError( + "Can't find a usable network interface to listen." + "ifces: {}".format(psutil.net_if_addrs()) + ) + + ip_ret: str = ip + ifce_ret: str = ifce + # use 0 to bind to a random free port number - return ip, 0, ifce + return ip_ret, 0, ifce_ret diff --git a/fiber/local_backend.py b/fiber/local_backend.py index 48e8a8d..38049f9 100644 --- a/fiber/local_backend.py +++ b/fiber/local_backend.py @@ -21,6 +21,7 @@ import fiber.core as core from fiber.core import ProcessStatus +from typing import Any, Tuple, Optional class Backend(core.Backend): @@ -31,17 +32,17 @@ class Backend(core.Backend): """ name = "local" - def __init__(self): + def __init__(self) -> None: pass - def create_job(self, job_spec): + def create_job(self, job_spec: core.JobSpec) -> core.Job: proc = subprocess.Popen(job_spec.command) job = core.Job(proc, proc.pid) job.host = 'localhost' return job - def get_job_status(self, job): + def get_job_status(self, job: core.Job) -> ProcessStatus: proc = job.data if proc.poll() is not None: @@ -50,7 +51,7 @@ def get_job_status(self, job): return ProcessStatus.STARTED - def wait_for_job(self, job, timeout): + def wait_for_job(self, job: core.Job, timeout: Optional[float]) -> Optional[int]: proc = job.data if timeout == 0: @@ -63,10 +64,10 @@ def wait_for_job(self, job, timeout): return proc.returncode - def terminate_job(self, job): + def terminate_job(self, job: core.Job) -> None: proc = job.data proc.terminate() - def get_listen_addr(self): + def get_listen_addr(self) -> Tuple[str, int, str]: return "127.0.0.1", 0, "lo" diff --git a/fiber/managers.py b/fiber/managers.py index c5d8cfb..c72e244 100644 --- a/fiber/managers.py +++ b/fiber/managers.py @@ -23,17 +23,29 @@ import multiprocessing.util as mp_util import queue import threading -from multiprocessing.connection import (SocketListener, _validate_family, - address_type) +from multiprocessing.connection import ( # type: ignore + SocketListener, _validate_family, address_type +) from multiprocessing.context import get_spawning_popen -from multiprocessing.managers import (Array, DictProxy, ListProxy, Namespace, - RebuildProxy, State, Value, dispatch) -from multiprocessing.process import AuthenticationString +from multiprocessing.managers import ( # type: ignore + Array, DictProxy, ListProxy, Namespace, + RebuildProxy, State, Value, dispatch +) +from multiprocessing.process import AuthenticationString # type: ignore import fiber.util import fiber.queues from fiber import process from fiber.backend import get_backend +from typing import ( + Any, TypeVar, Callable, Sequence, Dict, Union, Tuple, Optional, + Iterable, Mapping +) + +_T0 = TypeVar('_T0') +_T1 = TypeVar('_T1') +_TAsyncListProxy = TypeVar('_TAsyncListProxy', bound="AsyncListProxy") +Address = Union[str, Tuple[Union[str, bytes], int], None] logger = logging.getLogger('fiber') @@ -42,7 +54,12 @@ class Listener(multiprocessing.connection.Listener): - def __init__(self, address=None, family=None, backlog=500, authkey=None): + def __init__( + self, address: Address = None, + family: str = None, + backlog: int = 500, + authkey: bytes = None + ) -> None: family = family or (address and address_type(address)) \ or default_family if family != 'AF_INET': @@ -51,7 +68,10 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): backend = get_backend() # TODO(jiale) Add support for other address family for # backend.get_listen_addr - address = address or backend.get_listen_addr() + if address is None: + listen_addr = backend.get_listen_addr() + address = (listen_addr[0], listen_addr[1]) + self._address = address _validate_family(family) @@ -72,7 +92,7 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): # Use IP from `backend.get_listen_addr` and port from # self._listener._address. # This represents an address that other client can connect to - address = property(lambda self: (self._address[0], + address = property(lambda self: (self._address[0], # type: ignore self._listener._address[1])) @@ -85,7 +105,8 @@ def __init__(self, address=None, family=None, backlog=500, authkey=None): class Server(multiprocessing.managers.Server): - def __init__(self, registry, address, authkey, serializer): + def __init__(self, registry: Dict, address: Address, + authkey: bytes, serializer: str) -> None: assert isinstance(authkey, bytes) self.registry = registry self.authkey = AuthenticationString(authkey) @@ -93,11 +114,11 @@ def __init__(self, registry, address, authkey, serializer): # do authentication later self.listener = Listener(address=address, backlog=500) - self.address = self.listener.address + self.address = self.listener.address # type: ignore self.id_to_obj = {'0': (None, ())} - self.id_to_refcount = {} - self.id_to_local_proxy_obj = {} + self.id_to_refcount: Dict = {} + self.id_to_local_proxy_obj: Dict = {} self.mutex = threading.Lock() @@ -111,11 +132,11 @@ class BaseManager(multiprocessing.managers.BaseManager): The API of this class is the same as [multiprocessing.managers.BaseManager](https://docs.python.org/3.6/library/multiprocessing.html#multiprocessing.managers.BaseManager) """ - _registry = {} + _registry: Dict = {} _Server = Server - def __init__(self, address=None, authkey=None, serializer='pickle', - ctx=None): + def __init__(self, address: str = None, authkey: bytes = None, + serializer: str = 'pickle', ctx=None) -> None: if authkey is None: authkey = process.current_process().authkey self._address = address # XXX not final address if eg ('', 0) @@ -125,7 +146,7 @@ def __init__(self, address=None, authkey=None, serializer='pickle', self._serializer = serializer self._Listener, self._Client = listener_client[serializer] - def get_server(self): + def get_server(self) -> Server: """ Return server object with serve_forever() method and address attribute """ @@ -134,8 +155,11 @@ def get_server(self): self._authkey, self._serializer) @classmethod - def _run_server(cls, registry, address, authkey, serializer, writer, - initializer=None, initargs=()): + def _run_server( + cls, registry: Dict, address: str, authkey: bytes, serializer: str, + writer: fiber.queues.LazyZConnection, + initializer: Callable = None, initargs: Iterable[Any] = () + ) -> None: """Create a server, report its address and run it.""" if initializer is not None: initializer(*initargs) @@ -151,7 +175,11 @@ def _run_server(cls, registry, address, authkey, serializer, writer, logger.info('manager serving at %r', server.address) server.serve_forever() - def start(self, initializer=None, initargs=()): + def start( + self, + initializer: Callable = None, + initargs: Iterable[Any] = () + ) -> None: """Spawn a server process for this manager object.""" assert self._state.value == State.INITIAL logger.debug("start manager %s", self) @@ -168,7 +196,7 @@ def start(self, initializer=None, initargs=()): args=(self._registry, self._address, self._authkey, self._serializer, writer, initializer, initargs), ) - ident = ':'.join(str(i) for i in self._process._identity) + ident = ':'.join(str(i) for i in self._process._identity) # type: ignore[attr-defined] self._process.name = type(self).__name__ + '-' + ident self._process.start() @@ -179,16 +207,22 @@ def start(self, initializer=None, initargs=()): # register a finalizer self._state.value = State.STARTED - self.shutdown = mp_util.Finalize( - self, type(self)._finalize_manager, + self.shutdown = mp_util.Finalize( # type: ignore[assignment] + self, type(self)._finalize_manager, # type: ignore[attr-defined] args=(self._process, self._address, self._authkey, self._state, self._Client), exitpriority=0 ) @classmethod - def register(cls, typeid, callable=None, proxytype=None, exposed=None, - method_to_typeid=None, create_method=True): + def register( + cls, typeid: str, + callable: Callable = None, + proxytype: Any = None, + exposed: Sequence = None, + method_to_typeid: Optional[Mapping[str, str]] = None, + create_method: bool = True + ) -> None: """Register a typeid with the manager type.""" if '_registry' not in cls.__dict__: cls._registry = cls._registry.copy() @@ -227,7 +261,7 @@ def temp(self, *args, **kwds): class ProcessLocalSet(set): - def __init__(self): + def __init__(self) -> None: fiber.util.register_after_fork(self, lambda obj: obj.clear()) def __reduce__(self): @@ -236,11 +270,18 @@ def __reduce__(self): class BaseProxy(multiprocessing.managers.BaseProxy): """A base for proxies of shared objects.""" - _address_to_local = {} + _address_to_local: Dict = {} _mutex = fiber.util.ForkAwareThreadLock() - def __init__(self, token, serializer, manager=None, - authkey=None, exposed=None, incref=True, manager_owned=False): + def __init__( + self, token: multiprocessing.managers.Token, + serializer: str, + manager: BaseManager = None, + authkey: bytes = None, + exposed: Sequence = None, + incref: bool = True, + manager_owned: bool = False + ) -> None: with BaseProxy._mutex: tls_idset = BaseProxy._address_to_local.get(token.address, None) if tls_idset is None: @@ -276,16 +317,16 @@ def __init__(self, token, serializer, manager=None, self._authkey = process.current_process().authkey if incref: - self._incref() + self._incref() # type: ignore[attr-defined] - fiber.util.register_after_fork(self, BaseProxy._after_fork) + fiber.util.register_after_fork(self, BaseProxy._after_fork) # type: ignore[attr-defined] - def connect(self): + def connect(self) -> None: """Connect manager object to the server process.""" Listener, Client = listener_client[self._serializer] - conn = Client(self._address, authkey=self._authkey) + conn = Client(self._address, authkey=self._authkey) # type: ignore[attr-defined] dispatch(conn, None, 'dummy') - self._state.value = State.STARTED + self._state.value = State.STARTED # type: ignore[attr-defined] def __reduce__(self): kwds = {} @@ -301,13 +342,20 @@ def __reduce__(self): (type(self), self._token, self._serializer, kwds)) -def AutoProxy(token, serializer, manager=None, authkey=None, - exposed=None, incref=True): +def AutoProxy( + token: multiprocessing.managers.Token, + serializer: str, + manager: BaseManager = None, + authkey: bytes = None, + exposed: Sequence = None, + incref: bool = True +) -> BaseProxy: """Return an auto-proxy for `token`.""" _Client = listener_client[serializer][1] if exposed is None: - conn = _Client(token.address, authkey=authkey) + _address = (str(token.address[0]), token.address[1]) + conn = _Client(_address, authkey=authkey) try: exposed = dispatch(conn, None, 'get_methods', (token,)) finally: @@ -318,14 +366,14 @@ def AutoProxy(token, serializer, manager=None, authkey=None, if authkey is None: authkey = process.current_process().authkey - ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed) + ProxyType = MakeProxyType('AutoProxy[{}]'.format(str(token.typeid)), exposed) proxy = ProxyType(token, serializer, manager=manager, authkey=authkey, incref=incref) proxy._isauto = True return proxy -def MakeProxyType(name, exposed, _cache={}): +def MakeProxyType(name: str, exposed: Sequence, _cache: Dict = {}) -> Any: """Return a proxy type whose methods are given by `exposed`.""" exposed = tuple(exposed) try: @@ -333,14 +381,14 @@ def MakeProxyType(name, exposed, _cache={}): except KeyError: pass - dic = {} + dic: Dict = {} for meth in exposed: exec("""def %s(self, *args, **kwds): return self._callmethod(%r, args, kwds)""" % (meth, meth), dic) ProxyType = type(name, (BaseProxy,), dic) - ProxyType._exposed_ = exposed + ProxyType._exposed_ = exposed # type: ignore[attr-defined] _cache[(name, exposed)] = ProxyType return ProxyType @@ -409,7 +457,7 @@ def __setattr__(self, key, value): callmethod = object.__getattribute__(self, '_callmethod') return callmethod('__setattr__', (key, value)) - def __delattr__(self, key): + def __delattr__(self, key) -> None: if key[0] == '_': return object.__delattr__(self, key) callmethod = object.__getattribute__(self, '_callmethod') @@ -419,10 +467,10 @@ def __delattr__(self, key): class ValueProxy(BaseProxy): _exposed_ = ('get', 'set') - def get(self): + def get(self) -> Any: return self._callmethod('get') - def set(self, value): + def set(self, value) -> None: return self._callmethod('set', (value,)) value = property(get, set) @@ -431,11 +479,11 @@ def set(self, value): # Define AsyncProxy and AsyncManager class AsyncProxyResult(): - def __init__(self, conn, proxy): + def __init__(self, conn, proxy) -> None: self._conn = conn self._proxy = proxy - def get(self): + def get(self) -> Any: kind, result = self._conn.recv() if kind == '#RETURN': @@ -446,11 +494,18 @@ def get(self): class AsyncBaseProxy(BaseProxy): - def _callmethod(self, methodname, args=(), kwds={}): + + # BaseProxy's _callmethod doesn't return anything. But async base proxy + # needs to return something + def _callmethod( # type: ignore[override] + self, methodname: str, + args: Sequence = (), + kwds: Dict = {} + ) -> AsyncProxyResult: try: conn = self._tls.connection except AttributeError: - self._connect() + self._connect() # type: ignore[attr-defined] conn = self._tls.connection conn.send((self._id, methodname, args, kwds)) @@ -458,7 +513,7 @@ def _callmethod(self, methodname, args=(), kwds={}): return AsyncProxyResult(conn, self) -def MakeAsyncProxyType(name, exposed, _cache={}): +def MakeAsyncProxyType(name: str, exposed: Sequence, _cache: Dict = {}) -> type: """Return a proxy type whose methods are given by `exposed`.""" exposed = tuple(exposed) try: @@ -466,25 +521,32 @@ def MakeAsyncProxyType(name, exposed, _cache={}): except KeyError: pass - dic = {} + dic: Dict = {} for meth in exposed: exec("""def %s(self, *args, **kwds): return self._callmethod(%r, args, kwds)""" % (meth, meth), dic) ProxyType = type(name, (AsyncBaseProxy,), dic) - ProxyType._exposed_ = exposed + ProxyType._exposed_ = exposed # type: ignore[attr-defined] _cache[(name, exposed)] = ProxyType return ProxyType -def AsyncAutoProxy(token, serializer, manager=None, authkey=None, - exposed=None, incref=True): +def AsyncAutoProxy( + token: multiprocessing.managers.Token, + serializer: str, + manager: BaseManager = None, + authkey: bytes = None, + exposed: Sequence = None, + incref: bool = True +) -> Any: """Return an auto-proxy for `token`.""" _Client = listener_client[serializer][1] if exposed is None: - conn = _Client(token.address, authkey=authkey) + _address = (str(token.address[0]), token.address[1]) + conn = _Client(_address, authkey=authkey) try: exposed = dispatch(conn, None, 'get_methods', (token,)) finally: @@ -495,7 +557,7 @@ def AsyncAutoProxy(token, serializer, manager=None, authkey=None, if authkey is None: authkey = process.current_process().authkey - ProxyType = MakeAsyncProxyType('AutoProxy[%s]' % token.typeid, exposed) + ProxyType = MakeAsyncProxyType('AutoProxy[{}]'.format(str(token.typeid)), exposed) proxy = ProxyType(token, serializer, manager=manager, authkey=authkey, incref=incref) proxy._isauto = True @@ -505,11 +567,11 @@ def AsyncAutoProxy(token, serializer, manager=None, authkey=None, class AsyncValueProxy(AsyncBaseProxy): _exposed_ = ('get', 'set') - def get(self): + def get(self) -> AsyncProxyResult: return self._callmethod('get') - def set(self, value): - return self._callmethod('set', (value,)) + def set(self, value) -> None: + self._callmethod('set', (value,)) value = property(get, set) @@ -547,8 +609,15 @@ class MyManager(AsyncManager): ``` """ @classmethod - def register(cls, typeid, callable=None, proxytype=None, exposed=None, - method_to_typeid=None, create_method=True): + def register( + cls, + typeid: str, + callable: Callable = None, + proxytype: Any = None, + exposed: Sequence = None, + method_to_typeid: Optional[Mapping[str, str]] = None, + create_method: bool = True + ) -> None: """Register a typeid with the manager type.""" if '_registry' not in cls.__dict__: cls._registry = cls._registry.copy() @@ -593,7 +662,7 @@ def temp(self, *args, **kwds): )) -class AsyncListProxy(AsyncBaseListProxy): +class AsyncListProxy(AsyncBaseListProxy): # type: ignore def __iadd__(self, value): self._callmethod('extend', (value,)) return self @@ -608,7 +677,7 @@ def __imul__(self, value): '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' )) -AsyncDictProxy._method_to_typeid_ = { +AsyncDictProxy._method_to_typeid_ = { # type: ignore[attr-defined] '__iter__': 'Iterator', } diff --git a/fiber/meta.py b/fiber/meta.py index 6385b85..7002963 100644 --- a/fiber/meta.py +++ b/fiber/meta.py @@ -13,10 +13,13 @@ # limitations under the License. +from typing import Any, Callable, Dict + + VALID_META_KEYS = ["cpu", "memory", "gpu"] -def post_process(metadata): +def post_process(metadata: Dict) -> Dict: # memory should be in MB if "memory" in metadata: memory = metadata.pop("memory") @@ -25,7 +28,7 @@ def post_process(metadata): return metadata -def meta(**kwargs): +def meta(**kwargs) -> Callable[[Any], Any]: """ fiber.meta API allows you to decorate your function and provide some hints to Fiber. Currently this is mainly used for specify the resource usage of user diff --git a/fiber/pool.py b/fiber/pool.py index 6bb3618..8f67874 100644 --- a/fiber/pool.py +++ b/fiber/pool.py @@ -44,9 +44,13 @@ import time import secrets import traceback -from multiprocessing.pool import (CLOSE, RUN, TERMINATE, - ExceptionWithTraceback, MaybeEncodingError, - ThreadPool, _helper_reraises_exception) +from multiprocessing.pool import ( # type: ignore + CLOSE, RUN, TERMINATE, + ExceptionWithTraceback, + MaybeEncodingError, + ThreadPool, + _helper_reraises_exception +) import fiber.queues import fiber.config as config @@ -55,6 +59,8 @@ from fiber.socket import Socket from fiber.process import current_process import signal +from typing import (Any, Generator, Iterator, NoReturn, Callable, + Sequence, List, Dict, Union, Tuple, Optional) if fiber.util.is_in_interactive_console(): @@ -68,7 +74,7 @@ MAX_PORT = 65535 -def safe_join_worker(proc): +def safe_join_worker(proc: fiber.process.Process) -> None: p = proc if p.is_alive(): # worker has not yet exited @@ -77,7 +83,7 @@ def safe_join_worker(proc): p.join(5) -def safe_terminate_worker(proc): +def safe_terminate_worker(proc: fiber.process.Process) -> None: delay = random.random() # Randomize start time to prevent overloading the server @@ -93,7 +99,7 @@ def safe_terminate_worker(proc): logger.debug("safe_terminate_worker() finished") -def safe_start(proc): +def safe_start(proc: fiber.process.Process) -> None: try: proc.start() proc._start_failed = False @@ -104,7 +110,7 @@ def safe_start(proc): proc._start_failed = True -def mp_worker_core(inqueue, outqueue, maxtasks=None, wrap_exception=False): +def mp_worker_core(inqueue: fiber.queues.SimpleQueue, outqueue: fiber.queues.SimpleQueue, maxtasks: int = None, wrap_exception: bool = False) -> None: logger.debug('mp_worker_core running') put = outqueue.put get = inqueue.get @@ -136,13 +142,13 @@ def mp_worker_core(inqueue, outqueue, maxtasks=None, wrap_exception=False): wrapped)) put((job, i, (False, wrapped))) - task = job = result = func = args = kwds = None + task = job = result = func = args = kwds = None # type: ignore completed += 1 logger.debug('worker exiting after %s tasks' % completed) -def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, - wrap_exception=False, num_workers=1): +def mp_worker(inqueue: fiber.queues.SimpleQueue, outqueue: fiber.queues.SimpleQueue, initializer: Callable = None, initargs: Sequence = (), maxtasks: int = None, + wrap_exception: bool = False, num_workers: int = 1) -> None: """This is mostly the same as multiprocessing.pool.worker, the difference is that it will start multiple workers (specified by `num_workers` argument) via multiproccessing and allow the Fiber pool worker to take multiple CPU @@ -150,9 +156,9 @@ def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, """ assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) - if hasattr(inqueue, '_writer'): - inqueue._writer.close() - outqueue._reader.close() + if hasattr(inqueue, 'writer'): + inqueue.writer.close() + outqueue.reader.close() if initializer is not None: initializer(*initargs) @@ -174,18 +180,20 @@ def mp_worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, class ClassicPool(mp_pool.Pool): + _wrap_exception = True + @staticmethod - def Process(ctx, *args, **kwds): + def Process(ctx, *args, **kwds) -> fiber.process.Process: return fiber.process.Process(*args, **kwds) - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None): + def __init__(self, processes: int = None, initializer: Callable = None, initargs: Sequence = (), + maxtasksperchild: int = None, cluster=None) -> None: self._ctx = None self._setup_queues() - self._taskqueue = queue.Queue() - self._cache = {} - self._state = RUN + self._taskqueue: "queue.Queue[Tuple[Iterator[Tuple], Optional[Callable[[int], None]]]]" = queue.Queue() + self._cache: Dict = {} + self._state: int = RUN self._maxtasksperchild = maxtasksperchild self._initializer = initializer self._initargs = initargs @@ -199,8 +207,8 @@ def __init__(self, processes=None, initializer=None, initargs=(), raise TypeError("initializer must be a callable") self._processes = processes - self._pool = [] - self._threads = [] + self._pool: List[fiber.process.Process] = [] + self._threads: List[threading.Thread] = [] self._repopulate_pool() # Worker handler @@ -209,10 +217,10 @@ def __init__(self, processes=None, initializer=None, initargs=(), args=(self._cache, self._taskqueue, self._ctx, self.Process, self._processes, self._pool, self._threads, self._inqueue, self._outqueue, self._initializer, self._initargs, - self._maxtasksperchild, self._wrap_exception) + self._maxtasksperchild, self._wrap_exception) # type: ignore ) self._worker_handler.daemon = True - self._worker_handler._state = RUN + self._worker_handler._state = RUN # type: ignore self._worker_handler.start() logger.debug( "Pool: started _handle_workers thread(%s:%s)", @@ -231,7 +239,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), self._pool, self._cache) ) self._task_handler.daemon = True - self._task_handler._state = RUN + self._task_handler._state = RUN # type: ignore self._task_handler.start() logger.debug( "Pool: started _handle_tasks thread(%s:%s)", @@ -244,7 +252,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), args=(self._outqueue, self._quick_get, self._cache) ) self._result_handler.daemon = True - self._result_handler._state = RUN + self._result_handler._state = RUN # type: ignore self._result_handler.start() logger.debug( "Pool: started _handle_results thread(%s:%s)", @@ -261,7 +269,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), ) logger.debug("Pool: registered _terminate_pool finalizer") - def _setup_queues(self): + def _setup_queues(self) -> None: self._inqueue = fiber.queues.SimpleQueue() logger.debug("Pool|created Pool._inqueue: %s", self._inqueue) self._outqueue = fiber.queues.SimpleQueue() @@ -273,8 +281,11 @@ def _setup_queues(self): # is a REQ socket. It can't be called consecutively. self._quick_get = self._outqueue.get - def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, - error_callback=None): + def _map_async( + self, func: Callable, iterable: Sequence, mapper, + chunksize: int = None, callback: Callable = None, + error_callback: Callable = None + ) -> mp.pool.MapResult: """ Helper function to implement map, starmap and their async counterparts. """ @@ -293,12 +304,12 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, if len(iterable) == 0: chunksize = 0 - task_batches = ClassicPool._get_tasks(func, iterable, chunksize) - result = mp_pool.MapResult(self._cache, chunksize, len(iterable), + task_batches = ClassicPool._get_tasks(func, iterable, chunksize) # type: ignore + result: mp_pool.MapResult = mp_pool.MapResult(self._cache, chunksize, len(iterable), callback, error_callback=error_callback) self._taskqueue.put( ( - self._guarded_task_generation(result._job, + self._guarded_task_generation(result._job, # type: ignore mapper, task_batches), None @@ -309,12 +320,12 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, @staticmethod def _handle_workers(cache, taskqueue, ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, - maxtasksperchild, wrap_exception): + maxtasksperchild, wrap_exception) -> None: thread = threading.current_thread() # Keep maintaining workers until the cache gets drained, unless the # pool is terminated. - while thread._state == RUN or (cache and thread._state != TERMINATE): + while thread._state == RUN or (cache and thread._state != TERMINATE): # type: ignore ClassicPool._maintain_pool(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, maxtasksperchild, wrap_exception) @@ -323,7 +334,7 @@ def _handle_workers(cache, taskqueue, ctx, Process, processes, pool, logger.debug("_handle_workers exits") @staticmethod - def _join_exited_workers(pool): + def _join_exited_workers(pool) -> bool: """Cleanup after any worker processes which have exited due to reaching their specified lifetime. Returns True if any workers were cleaned up. """ @@ -353,7 +364,7 @@ def _join_exited_workers(pool): del pool[i] return cleaned - def _repopulate_pool(self): + def _repopulate_pool(self) -> Any: return self._repopulate_pool_static(self._ctx, self.Process, self._processes, self._pool, self._threads, @@ -366,7 +377,7 @@ def _repopulate_pool(self): @staticmethod def _repopulate_pool_static(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, - maxtasksperchild, wrap_exception): + maxtasksperchild, wrap_exception) -> None: """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ @@ -413,7 +424,7 @@ def _repopulate_pool_static(ctx, Process, processes, pool, threads, @staticmethod def _maintain_pool(ctx, Process, processes, pool, threads, inqueue, outqueue, initializer, initargs, maxtasksperchild, - wrap_exception): + wrap_exception) -> None: """Clean up any exited workers and start replacements for them. """ if ClassicPool._join_exited_workers(pool): @@ -424,15 +435,19 @@ def _maintain_pool(ctx, Process, processes, pool, threads, inqueue, wrap_exception) @staticmethod - def _handle_tasks(taskqueue, put, outqueue, pool, cache): + def _handle_tasks( + taskqueue: "queue.Queue[Tuple[Iterator[Tuple], Optional[Callable[[int], None]]]]", + put, outqueue, pool, cache + ) -> None: thread = threading.current_thread() for taskseq, set_length in iter(taskqueue.get, None): + task = None try: # iterating taskseq cannot fail for task in taskseq: - if thread._state: + if thread._state: # type: ignore break try: put(task) @@ -449,7 +464,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache): continue break finally: - task = taskseq = job = None + task = taskseq = job = None # type: ignore else: logger.debug('_handle_tasks: task handler got sentinel') @@ -471,7 +486,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache): logger.debug('_handle_tasks: task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_results(outqueue, get, cache) -> None: thread = threading.current_thread() while 1: @@ -481,8 +496,8 @@ def _handle_results(outqueue, get, cache): # logger.debug('result handler got EOFError/OSError: exiting') return - if thread._state: - assert thread._state == TERMINATE + if thread._state: # type: ignore + assert thread._state == TERMINATE # type: ignore # logger.debug('result handler found thread._state=TERMINATE') break @@ -497,7 +512,7 @@ def _handle_results(outqueue, get, cache): pass task = job = obj = None - while cache and thread._state != TERMINATE: + while cache and thread._state != TERMINATE: # type: ignore try: task = get() except (OSError, EOFError): @@ -528,11 +543,11 @@ def _handle_results(outqueue, get, cache): pass logger.debug('result handler exiting: len(cache)=%s, thread._state=%s', - len(cache), thread._state) + len(cache), thread._state) # type: ignore @classmethod def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, threads, - worker_handler, task_handler, result_handler, cache): + worker_handler, task_handler, result_handler, cache) -> None: # this is guaranteed to only be called once logger.debug('finalizing pool') start = time.time() @@ -649,21 +664,21 @@ class Inventory(): can be called with corresponding `seq`. This inventory will handle waiting, managing results from different map calls. """ - def __init__(self, queue_get): + def __init__(self, queue_get: Callable[[], Any]) -> None: self._seq = 0 self._queue_get = queue_get - self._inventory = {} - self._spec = {} - self._idx_cur = {} + self._inventory: Dict[int, List[Any]] = {} + self._spec: Dict[int, int] = {} + self._idx_cur: Dict[int, int] = {} - def add(self, ntasks): + def add(self, ntasks: int) -> int: self._seq += 1 self._inventory[self._seq] = [None] * ntasks self._spec[self._seq] = ntasks self._idx_cur[self._seq] = 0 return self._seq - def get(self, job_seq): + def get(self, job_seq: int) -> Any: n = self._spec[job_seq] while n != 0: @@ -675,10 +690,10 @@ def get(self, job_seq): n = self._spec[seq] ret = self._inventory[job_seq] - self._inventory[job_seq] = None + self._inventory[job_seq] = [] return ret - def iget_unordered(self, job_seq): + def iget_unordered(self, job_seq: int) -> Iterator[Any]: n = self._spec[job_seq] while n != 0: @@ -694,7 +709,7 @@ def iget_unordered(self, job_seq): return - def iget_ordered(self, job_seq): + def iget_ordered(self, job_seq: int) -> Iterator[Any]: idx = self._idx_cur[job_seq] total = len(self._inventory[job_seq]) @@ -729,17 +744,17 @@ def iget_ordered(self, job_seq): class MapResult(): - def __init__(self, seq, inventory): + def __init__(self, seq: int, inventory: Inventory) -> None: self._seq = seq self._inventory = inventory - def get(self): + def get(self) -> Any: return self._inventory.get(self._seq) - def iget_ordered(self): + def iget_ordered(self) -> Any: return self._inventory.iget_ordered(self._seq) - def iget_unordered(self): + def iget_unordered(self) -> Any: return self._inventory.iget_unordered(self._seq) @@ -748,7 +763,7 @@ class ApplyResult(MapResult): represents an handle that can be used to get the actual result. """ - def get(self): + def get(self) -> Any: """Get the actual result represented by this object :returns: Actual result. This method will block if the actual result @@ -757,8 +772,14 @@ def get(self): return self._inventory.get(self._seq)[0] -def zpool_worker_core(master_conn, result_conn, maxtasksperchild, - wrap_exception, rank=-1, req=False): +def zpool_worker_core( + master_conn: LazyZConnection, + result_conn: LazyZConnection, + maxtasksperchild: Optional[int], + wrap_exception: bool, + rank: int = -1, + req: bool = False +) -> None: """ The actual function that processes tasks. @@ -775,13 +796,11 @@ def zpool_worker_core(master_conn, result_conn, maxtasksperchild, """ logger.debug("zpool_worker_core started %s", rank) - proc = None ident = secrets.token_bytes(4) - if req: - proc = current_process() while True: if req: + proc = current_process() # master_conn is a REQ type socket, need to send id (rank) to # master to request a task. id is packed in type unsigned short (H) master_conn.send_bytes(struct.pack("4si", ident, proc.pid)) @@ -813,24 +832,26 @@ def zpool_worker_core(master_conn, result_conn, maxtasksperchild, data = (seq, batch, batch + i, res) if req: - data += (ident,) - result_conn.send(data) + result_conn.send((seq, batch, batch + i, res, ident)) + else: + result_conn.send((seq, batch, batch + i, res)) else: for i, args in enumerate(arg_list): res = func(args) - data = (seq, batch, batch + i, res) if req: - data += (ident,) - result_conn.send(data) + result_conn.send((seq, batch, batch + i, res, ident)) + else: + result_conn.send((seq, batch, batch + i, res)) + #print("worker_core exit, ", rank, proc.pid) -def handle_signal(signal, frame): +def handle_signal(signal, frame) -> None: # run sys.exit() so that atexit handlers can run sys.exit() -def zpool_worker(master_conn, result_conn, initializer=None, initargs=(), - maxtasks=None, wrap_exception=False, num_workers=1, req=False): +def zpool_worker(master_conn: LazyZConnection, result_conn: LazyZConnection, initializer: Callable = None, initargs: Sequence = (), + maxtasks: int = None, wrap_exception: bool = False, num_workers: int = 1, req: bool = False) -> None: """ The entry point of Pool worker function. @@ -885,23 +906,29 @@ class ZPool(): results handling. This makes it faster. """ - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None, - master_sock_type="w"): - - self._pool = [] + def __init__( + self, + processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + cluster=None, + master_sock_type: str = "w" + ) -> None: + + self._pool: List[fiber.process.Process] = [] # Set default processes to 1 self._processes = processes if processes is not None else 1 self._initializer = initializer self._initargs = initargs self._maxtasksperchild = maxtasksperchild self._cluster = cluster - self._seq = 0 - self._state = RUN - self.taskq = queue.Queue() - self.sent_tasks = 0 - self.recv_tasks = 0 - self.max_processing_tasks = 20000 + self._seq: int = 0 + self._state: int = RUN + self.taskq: queue.Queue = queue.Queue() + self.sent_tasks: int = 0 + self.recv_tasks: int = 0 + self.max_processing_tasks: int = 20000 # networking related backend = get_backend() @@ -928,7 +955,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), target=self.__class__._handle_workers, args=(self,) ) td.daemon = True - td._state = RUN + td._state = RUN # type: ignore # `td` will be started later by `lazy_start_workers` later self._worker_handler = td self._worker_handler_started = False @@ -938,18 +965,18 @@ def __init__(self, processes=None, initializer=None, initargs=(), target=self._handle_tasks, ) td.daemon = True - td._state = RUN + td._state = RUN # type: ignore td.start() self._task_handler = td - def __repr__(self): + def __repr__(self) -> str: return "<{}({}, {})>".format( type(self).__name__, getattr(self, "_processes", None), getattr(self, "_master_addr", None), ) - def _handle_tasks(self): + def _handle_tasks(self) -> None: taskq = self.taskq master_sock = self._master_sock @@ -962,10 +989,10 @@ def _handle_tasks(self): master_sock.send(data) self.sent_tasks += 1 - def _task_put(self, task): + def _task_put(self, task) -> None: self.taskq.put(task) - def _res_get(self): + def _res_get(self) -> Any: payload = self._result_sock.recv() self.recv_tasks += 1 data = pickle.loads(payload) @@ -973,16 +1000,16 @@ def _res_get(self): return data @staticmethod - def _join_exited_workers(workers): + def _join_exited_workers(workers: List[fiber.process.Process]) -> List[fiber.process.Process]: thread = threading.current_thread() logger.debug("ZPool _join_exited_workers running, workers %s, " - "thread._state %s", workers, thread._state) + "thread._state %s", workers, thread._state) # type: ignore exited_workers = [] for i in reversed(range(len(workers))): - if thread._state != RUN: + if thread._state != RUN: # type: ignore break worker = workers[i] @@ -1007,8 +1034,15 @@ def _join_exited_workers(workers): return exited_workers @staticmethod - def _maintain_workers(processes, workers, master_addr, result_addr, initializer, - initargs, maxtasksperchild): + def _maintain_workers( + processes: int, + workers: List[fiber.process.Process], + master_addr: str, + result_addr: str, + initializer: Optional[Callable], + initargs: Optional[Sequence], + maxtasksperchild: Optional[int] + ) -> None: thread = threading.current_thread() workers_per_fp = config.cpu_per_job @@ -1017,7 +1051,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, logger.debug("ZPool _maintain_workers running, workers %s", workers) threads = [] - while left > 0 and thread._state == RUN: + while left > 0 and thread._state == RUN: # type: ignore if left > workers_per_fp: n = workers_per_fp @@ -1057,7 +1091,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, logger.debug("ZPool _maintain_workers finished, workers %s", workers) @staticmethod - def _handle_workers(pool): + def _handle_workers(pool: "ZPool") -> None: logger.debug("%s _handle_workers running", pool) td = threading.current_thread() @@ -1067,7 +1101,7 @@ def _handle_workers(pool): pool._initargs, pool._maxtasksperchild ) - while td._state == RUN: + while td._state == RUN: # type: ignore if len(ZPool._join_exited_workers(pool._pool)) > 0: # create new workers when old workers exited ZPool._maintain_workers( @@ -1082,12 +1116,12 @@ def _handle_workers(pool): pool) @staticmethod - def _chunks(iterable, size): + def _chunks(iterable: Sequence, size: int) -> Iterator[Any]: for i in range(0, len(iterable), size): yield iterable[i:i + size] - def apply_async(self, func, args=(), kwds={}, callback=None, - error_callback=None): + def apply_async(self, func: Callable, args: Sequence = (), kwds: Dict = {}, callback: Callable = None, + error_callback: Callable = None) -> ApplyResult: """ Run function `func` with arguments `args` and keyword arguments `kwds` on a remote Pool worker. This is an asynchronous version of `apply`. @@ -1115,15 +1149,15 @@ def apply_async(self, func, args=(), kwds={}, callback=None, return res - def start_workers(self): + def start_workers(self) -> None: self._worker_handler.start() self._worker_handler_started = True - def lazy_start_workers(self, func): + def lazy_start_workers(self, func: Callable) -> None: if hasattr(func, "__fiber_meta__"): if ( not hasattr(zpool_worker, "__fiber_meta__") - or zpool_worker.__fiber_meta__ != func.__fiber_meta__ + or zpool_worker.__fiber_meta__ != func.__fiber_meta__ # type: ignore ): if self._worker_handler_started: raise RuntimeError( @@ -1131,13 +1165,13 @@ def lazy_start_workers(self, func): "requirements acceptable by this pool. Try creating a " "different pool for it." ) - zpool_worker.__fiber_meta__ = func.__fiber_meta__ + zpool_worker.__fiber_meta__ = func.__fiber_meta__ # type: ignore if not self._worker_handler_started: self.start_workers() - def map_async(self, func, iterable, chunksize=None, callback=None, - error_callback=None): + def map_async(self, func: Callable, iterable: Sequence, chunksize: int = None, callback: Callable = None, + error_callback: Callable = None) -> MapResult: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This is an asynchronous @@ -1183,7 +1217,7 @@ def map_async(self, func, iterable, chunksize=None, callback=None, return res - def apply(self, func, args=(), kwds={}): + def apply(self, func: Callable, args: Sequence = (), kwds: Dict ={}) -> Any: """ Run function `func` with arguments `args` and keyword arguments `kwds` on a remote Pool worker. @@ -1196,7 +1230,7 @@ def apply(self, func, args=(), kwds={}): """ return self.apply_async(func, args, kwds).get() - def map(self, func, iterable, chunksize=None): + def map(self, func: Callable, iterable: Sequence, chunksize: int = None) -> List[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. @@ -1215,7 +1249,7 @@ def map(self, func, iterable, chunksize=None): logger.debug('%s map func=%s', self, func) return self.map_async(func, iterable, chunksize).get() - def imap(self, func, iterable, chunksize=1): + def imap(self, func: Callable, iterable: Sequence, chunksize: int = 1) -> Iterator[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This function returns an @@ -1234,7 +1268,7 @@ def imap(self, func, iterable, chunksize=1): res = self.map_async(func, iterable, chunksize) return res.iget_ordered() - def imap_unordered(self, func, iterable, chunksize=1): + def imap_unordered(self, func: Callable, iterable: Sequence, chunksize: int = 1) -> Iterator[Any]: """ For each element `e` in `iterable`, run `func(e)`. The workload is distributed between all the Pool workers. This function returns an @@ -1255,8 +1289,8 @@ def imap_unordered(self, func, iterable, chunksize=1): res = self.map_async(func, iterable, chunksize) return res.iget_unordered() - def starmap_async(self, func, iterable, chunksize=None, callback=None, - error_callback=None): + def starmap_async(self, func: Callable, iterable: Sequence, chunksize: int = None, callback: Callable= None, + error_callback: Callable = None) -> MapResult: """ For each element `args` in `iterable`, run `func(*args)`. The workload is distributed between all the Pool workers. This is an asynchronous @@ -1304,7 +1338,7 @@ def starmap_async(self, func, iterable, chunksize=None, callback=None, return res - def starmap(self, func, iterable, chunksize=None): + def starmap(self, func: Callable, iterable: Sequence, chunksize: int = None) -> MapResult: """ For each element `args` in `iterable`, run `func(*args)`. The workload is distributed between all the Pool workers. @@ -1329,12 +1363,12 @@ def starmap(self, func, iterable, chunksize=None): """ return self.starmap_async(func, iterable, chunksize).get() - def _send_sentinels_to_workers(self): + def _send_sentinels_to_workers(self) -> None: logger.debug('send sentinels(None) to workers %s', self) for i in range(self._processes): self._task_put(None) - def close(self): + def close(self) -> None: """ Close this Pool. This means the current pool will be put in to a closing state and it will not accept new tasks. Existing workers will @@ -1344,15 +1378,15 @@ def close(self): logger.debug('closing pool %s', self) if self._state == RUN: self._state = CLOSE - self._worker_handler._state = CLOSE + self._worker_handler._state = CLOSE # type: ignore for p in self._pool: if hasattr(p, '_sentinel'): - p._state = CLOSE + p._state = CLOSE # type: ignore self._send_sentinels_to_workers() - def terminate(self): + def terminate(self) -> None: """ Terminate this pool. This means that this pool will be terminated and all its pool workers will also be terminated. Task that have been @@ -1361,11 +1395,11 @@ def terminate(self): logger.debug('terminating pool %s', self) logger.debug('set pool._worker_handler.status = TERMINATE') - self._worker_handler._state = TERMINATE - self._state = TERMINATE + self._worker_handler._state = TERMINATE # type: ignore + self._state = TERMINATE # type: ignore for p in self._pool: - p._state = TERMINATE + p._state = TERMINATE # type: ignore pool = self._pool N = min(100, len(pool)) @@ -1387,7 +1421,7 @@ def terminate(self): logger.debug("joining pool._worker_handler") self._worker_handler.join() - def join(self): + def join(self) -> None: """ Wait for all the pool workers of this pool to exit. This should be used after `terminate()` or `close()` are called on this pool. @@ -1396,13 +1430,13 @@ def join(self): assert self._state in (TERMINATE, CLOSE) for p in self._pool: - if p._state not in (TERMINATE, CLOSE): + if p._state not in (TERMINATE, CLOSE): # type: ignore logger.debug("%s.join() ignore newly connected Process %s", self, p) continue p.join() - def wait_until_workers_up(self): + def wait_until_workers_up(self) -> None: logger.debug('%s begin wait_until_workers_up', self) workers_per_fp = config.cpu_per_job n = math.ceil(float(self._processes) / workers_per_fp) @@ -1414,7 +1448,7 @@ def wait_until_workers_up(self): for p in self._pool: logger.debug('%s waiting for _sentinel %s', self, p) - while not hasattr(p, '_sentinel') or p._sentinel is None: + while not hasattr(p, '_sentinel') or p._sentinel is None: # type: ignore time.sleep(0.5) # now all the worker has connected to the master, wait # for some additional time to be sure. @@ -1434,14 +1468,19 @@ class ResilientZPool(ZPool): The API of `ResilientZPool` is the same as `ZPool`. One difference is that if `processes` argument is not set, its default value is 1. """ - def __init__(self, processes=None, initializer=None, initargs=(), - maxtasksperchild=None, cluster=None): - - self.active_peer_dict = {} - self.active_peer_list = [] + def __init__( + self, processes: int = None, + initializer: Callable = None, + initargs: Sequence = (), + maxtasksperchild: int = None, + cluster=None + ) -> None: + + self.active_peer_dict: Dict[str, bool] = {} + self.active_peer_list: List[str] = [] self.peer_lock = threading.Lock() self.taskq = queue.Queue() - self._pending_table = {} + self._pending_table: Dict[str, Dict] = {} super(ResilientZPool, self).__init__( processes=processes, @@ -1467,9 +1506,9 @@ def __init__(self, processes=None, initializer=None, initargs=(), self._task_handler = td ''' - self._pid_to_rid = {} + self._pid_to_rid: Dict[int, str] = {} - def _add_peer(self, ident): + def _add_peer(self, ident: str) -> None: self.peer_lock.acquire() self.active_peer_dict[ident] = True @@ -1478,7 +1517,7 @@ def _add_peer(self, ident): self.peer_lock.release() - def _remove_peer(self, ident): + def _remove_peer(self, ident: str) -> None: # _pendint_table will be cleared later in error handling phase self.peer_lock.acquire() @@ -1487,7 +1526,7 @@ def _remove_peer(self, ident): self.peer_lock.release() - def _res_get(self): + def _res_get(self) -> Any: # check for system messages payload = self._result_sock.recv() data = pickle.loads(payload) @@ -1507,16 +1546,16 @@ def _res_get(self): # skip ident return data[:-1] - def _task_put(self, task): + def _task_put(self, task: Union[None, Tuple[int, int, Callable, Sequence, bool]]) -> None: self.taskq.put(task) - def _handle_tasks(self): + def _handle_tasks(self) -> None: thread = threading.current_thread() taskq = self.taskq master_sock = self._master_sock pending_table = self._pending_table - while thread._state == RUN: + while thread._state == RUN: # type: ignore task = taskq.get() if task is None: @@ -1555,8 +1594,15 @@ def _handle_tasks(self): logger.debug('ResilientZPool _handle_tasks exited') @staticmethod - def _maintain_workers(processes, workers, master_addr, result_addr, initializer, - initargs, maxtasksperchild): + def _maintain_workers( + processes: int, + workers: List[fiber.process.Process], + master_addr: str, + result_addr: str, + initializer: Optional[Callable], + initargs: Optional[Sequence], + maxtasksperchild: Optional[int] + ) -> None: thread = threading.current_thread() workers_per_fp = config.cpu_per_job @@ -1566,7 +1612,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, workers) threads = [] - while left > 0 and thread._state == RUN: + while left > 0 and thread._state == RUN: # type: ignore if left > workers_per_fp: n = workers_per_fp @@ -1610,7 +1656,7 @@ def _maintain_workers(processes, workers, master_addr, result_addr, initializer, workers) @staticmethod - def _handle_workers(pool): + def _handle_workers(pool: "ResilientZPool") -> None: # type: ignore[override] logger.debug("%s _handle_workers running", pool) td = threading.current_thread() @@ -1620,7 +1666,7 @@ def _handle_workers(pool): pool._initargs, pool._maxtasksperchild ) - while td._state == RUN: + while td._state == RUN: # type: ignore exited_workers = ResilientZPool._join_exited_workers(pool._pool) if len(exited_workers) > 0: # create new workers when old workers exited @@ -1636,6 +1682,12 @@ def _handle_workers(pool): logger.debug("Resubmitting tasks from failed workers") for worker in exited_workers: + if worker.pid is None: + logger.warn( + "Can't get pid for failed worker: {}".format(worker) + ) + continue + rid = pool._pid_to_rid[worker.pid] # remove rid from active peers pool._remove_peer(rid) @@ -1658,34 +1710,33 @@ def _handle_workers(pool): logger.debug("%s _handle_workers finished. Status is not RUN", pool) - def terminate(self): - self._task_handler._state = TERMINATE + def terminate(self) -> None: + self._task_handler._state = TERMINATE # type: ignore super(ResilientZPool, self).terminate() - def close(self): + def close(self) -> None: logger.debug('closing pool %s', self) if self._state == RUN: self._state = CLOSE - self._worker_handler._state = CLOSE + self._worker_handler._state = CLOSE # type: ignore for p in self._pool: if hasattr(p, '_sentinel'): - p._state = CLOSE + p._state = CLOSE # type: ignore #self._send_sentinels_to_workers() #logger.debug("ResilientZPool _send_sentinels_to_workers: " # "send to task handler") self._task_put(None) - self._task_handler._state = CLOSE + self._task_handler._state = CLOSE # type: ignore - def _send_sentinels_to_workers(self): + def _send_sentinels_to_workers(self) -> None: logger.debug("ResilientZPool _send_sentinels_to_workers: " "send to workers") - data = pickle.dumps(None) for ident in self.active_peer_list: - self._master_sock.send_multipart([ident, b"", data]) + self._master_sock.send(None) #Pool = ZPool diff --git a/fiber/popen_fiber_spawn.py b/fiber/popen_fiber_spawn.py index f4726a3..1388aae 100644 --- a/fiber/popen_fiber_spawn.py +++ b/fiber/popen_fiber_spawn.py @@ -36,6 +36,21 @@ from fiber.backend import get_backend from fiber.core import JobSpec from fiber.core import ProcessStatus +from typing import ( + Any, + Dict, + List, + Iterator, + NoReturn, + Optional, + BinaryIO, + Tuple, +) + +_event_counter: Iterator[int] +_event_dict: Dict[int, "EventConn"] +_fiber_background_thread: Optional[threading.Thread] +_fiber_background_thread_lock: threading.Lock logger = logging.getLogger("fiber") @@ -84,7 +99,17 @@ _event_counter = itertools.count(1) -def get_fiber_init(): +class EventConn(): + def __init__( + self, + conn: socket.socket = None + ): + self.event = threading.Event() + self.event.clear() + self.conn = conn + + +def get_fiber_init() -> str: if config.ipc_active: fiber_init = fiber_init_start + fiber_init_net_active + fiber_init_end else: @@ -94,7 +119,9 @@ def get_fiber_init(): return fiber_init -def fiber_background(listen_addr, event_dict): +def fiber_background( + listen_addr: Tuple[str, int], event_dict: Dict[int, "EventConn"] +) -> None: global admin_host, admin_port # Background thread for handling inter fiber process admin traffic @@ -117,7 +144,7 @@ def fiber_background(listen_addr, event_dict): sentinel = event_dict[-1] # notifi master thread that background thread is ready - sentinel.set() + sentinel.event.set() logger.debug("fiber_background thread ready") while True: conn, addr = sock.accept() @@ -127,30 +154,31 @@ def fiber_background(listen_addr, event_dict): # struct.unpack returns a tuple event if the result only has # one element ident = struct.unpack(" str: if backend_name == "docker": # TODO(jiale) fix python path python_exe = "/usr/local/bin/python" else: python_exe = sys.executable - logger.debug("backend is \"%s\", use python exe \"%s\"", - backend_name, python_exe) + logger.debug( + 'backend is "%s", use python exe "%s"', backend_name, python_exe + ) return python_exe -def get_pid_from_jid(jid): +def get_pid_from_jid(jid: Any) -> int: # Some Linux system has 32768 as max pid number. 32749 is a prime number # close to 32768. return hash(jid) % 32749 @@ -159,20 +187,26 @@ def get_pid_from_jid(jid): class Popen(object): method = "spawn" - def __del__(self): - if getattr(self, "ident", None): + def __del__(self) -> None: + ident = getattr(self, "ident", None) + if ident is not None: # clean up entry in event_dict global _event_dict - #logger.debug("cleanup entry _event_dict[%s]", self.ident) - _event_dict.pop(self.ident, None) + if ident in _event_dict: + del _event_dict[ident] - def __repr__(self): + def __repr__(self) -> str: return "<{}({})>".format( type(self).__name__, getattr(self, "process_obj", None) ) - def __init__(self, process_obj, backend=None, launch=False): - self.returncode = None + def __init__( + self, + process_obj: "fiber.process.Process", + backend=None, + launch: bool = False, + ) -> None: + self.returncode: Optional[int] = None self.backend = get_backend() ip, _, _ = self.backend.get_listen_addr() @@ -181,20 +215,20 @@ def __init__(self, process_obj, backend=None, launch=False): self.master_port = config.ipc_admin_master_port self.worker_port = config.ipc_admin_worker_port - self.sock = None - self.host = "" + self.sock: Optional[socket.socket] = None + self.host: str = "" - self.job = None - self.pid = None + self.job: Optional[fiber.core.Job] = None + self.pid: Optional[int] = None self.process_obj = process_obj - self._exiting = None - self.sentinel = None - self.ident = None + self._exiting: bool = False + self.sentinel: Optional[socket.socket] = None + self.ident: int = -1 if launch: self._launch(process_obj) - def launch_fiber_background_thread_if_needed(self): + def launch_fiber_background_thread_if_needed(self) -> None: global _fiber_background_thread_lock _fiber_background_thread_lock.acquire() @@ -203,34 +237,32 @@ def launch_fiber_background_thread_if_needed(self): _fiber_background_thread_lock.release() return - try: - logger.debug( - "_fiber_background_thread is None, creating " - "background thread" - ) - # Create a background thread to handle incoming connections - # from fiber child processes - event = threading.Event() - event.clear() - _event_dict[-1] = event - td = threading.Thread( - target=fiber_background, - args=((self.master_host, self.master_port), _event_dict), - daemon=True, - ) - td.start() - except Exception as e: - raise e - finally: - logger.debug("waiting for background thread") - event.wait() - logger.debug( - "master received message that fiber_background thread is ready" - ) - _fiber_background_thread = td - _fiber_background_thread_lock.release() + ec = EventConn() - def get_command_line(self, **kwds): + logger.debug( + "_fiber_background_thread is None, creating " + "background thread" + ) + # Create a background thread to handle incoming connections + # from fiber child processes + _event_dict[-1] = ec + + td = threading.Thread( + target=fiber_background, + args=((self.master_host, self.master_port), _event_dict), + daemon=True, + ) + td.start() + + logger.debug("waiting for background thread") + ec.event.wait() + logger.debug( + "master received message that fiber_background thread is ready" + ) + _fiber_background_thread = td + _fiber_background_thread_lock.release() + + def get_command_line(self, **kwds) -> List[str]: """Returns prefix of command line used for spawning a child process.""" prog = get_fiber_init() prog = prog.format(**kwds) @@ -248,13 +280,7 @@ def get_command_line(self, **kwds): + ["-c", prog, "--multiprocessing-fork"] ) - def _accept(self): - conn, addr = self.sock.accept() - logger.debug("successfully accept") - # TODO verify if it's the same client - return conn - - def _get_job(self, cmd): + def _get_job(self, cmd: List[str]) -> fiber.core.JobSpec: spec = JobSpec( command=cmd, image=config.image, @@ -262,36 +288,41 @@ def _get_job(self, cmd): cpu=config.cpu_per_job, mem=config.mem_per_job, ) - if hasattr(self.process_obj._target, "__self__"): + if hasattr(self.process_obj.target, "__self__"): metadata = getattr( - self.process_obj._target.__self__, "__fiber_meta__", None + self.process_obj.target.__self__, "__fiber_meta__", None ) else: - metadata = getattr(self.process_obj._target, "__fiber_meta__", None) + metadata = getattr(self.process_obj.target, "__fiber_meta__", None) if metadata: for k, v in metadata.items(): setattr(spec, k, v) return spec - def _run_job(self, job): + def _run_job(self, job_spec: fiber.core.JobSpec) -> fiber.core.Job: + try: - job = self.backend.create_job(job) + job = self.backend.create_job(job_spec) except requests.exceptions.ReadTimeout as e: raise mp.TimeoutError(str(e)) + self.job = job return job - def _sentinel_readable(self, timeout=0): + def _sentinel_readable(self, timeout: int = 0) -> int: # Use fcntl(fd, F_GETFD) instead of select.* becuase: # * select.select() can't work with fd > 1024 # * select.poll() is not thread safe # * select.epoll() is Linux only # Also, fcntl(fd, F_GETFD) is cheaper than the above calls. + if self.sentinel is None: + return False + return fcntl.fcntl(self.sentinel, fcntl.F_GETFD) - def poll(self, flag=os.WNOHANG): + def poll(self, flag: int = os.WNOHANG) -> Optional[int]: # returns None if the process is not stopped yet. Otherwise, returns # process exit code. @@ -327,7 +358,7 @@ def poll(self, flag=os.WNOHANG): return None return self.wait(timeout=0) - def wait(self, timeout=None): + def wait(self, timeout: int = None) -> Optional[int]: if self.job is None: # self.job is None meaning this process hasn't been fully started # yet. @@ -345,7 +376,7 @@ def wait(self, timeout=None): self.returncode = code return self.returncode - def _pickle_data(self, data, fp): + def _pickle_data(self, data, fp: BinaryIO) -> None: if fiber.util.is_in_interactive_console(): logger.debug("in interactive shell, use cloudpickle") cloudpickle.dump(data, fp) @@ -353,7 +384,7 @@ def _pickle_data(self, data, fp): logger.debug("not in interactive shell, use reduction") reduction.dump(data, fp) - def _launch(self, process_obj): + def _launch(self, process_obj) -> None: logger.debug("%s %s _launch called", process_obj, self) if config.ipc_active: @@ -389,15 +420,15 @@ def _launch(self, process_obj): cwd=os.getcwd(), host=admin_host, port=port, id=ident ) - job = self._get_job(cmd) + job_spec = self._get_job(cmd) + + ec = EventConn() - event = threading.Event() - event.clear() - _event_dict[ident] = event + _event_dict[ident] = ec logger.debug( "%s popen_fiber_spawn created event %s and set _event_dict[%s]", self, - event, + ec.event, ident, ) @@ -427,7 +458,7 @@ def _launch(self, process_obj): process_obj._popen = self # launch job - job = self._run_job(job) + job = self._run_job(job_spec) self.pid = get_pid_from_jid(job.jid) # Fix process obj's pid process_obj.ident = self.pid @@ -447,16 +478,22 @@ def _launch(self, process_obj): "connect back" ) return - done = event.wait(0.5) + done = ec.event.wait(0.5) status = self.check_status() if status == ProcessStatus.STOPPED: return logger.debug( "popen_fiber_spawn is waiting for accept event %s to finish", - event, + ec.event, ) - conn = _event_dict[ident] + conn = ec.conn + if conn is None: + raise mp.ProcessError( + "ec.conn should be set at this moment but it is None. " + "Please report this bug to Fiber developers." + ) + logger.debug("got conn from _event_counter[%s]", ident) del _event_dict[ident] logger.debug("remove entry _event_counter[%s]", ident) @@ -511,7 +548,10 @@ def _launch(self, process_obj): self.sentinel = conn logger.debug("_launch finished") - def check_status(self): + def check_status(self) -> fiber.core.ProcessStatus: + if self.job is None: + return ProcessStatus.UNKNOWN + status = self.backend.get_job_status(self.job) if status == ProcessStatus.STOPPED: # something happened that caused Fiber process to hit an early stop @@ -525,7 +565,7 @@ def check_status(self): return status - def terminate(self): + def terminate(self) -> None: logger.debug("[Popen]terminate() called") self._exiting = True diff --git a/fiber/process.py b/fiber/process.py index e411834..af7584f 100644 --- a/fiber/process.py +++ b/fiber/process.py @@ -36,23 +36,29 @@ import logging from multiprocessing.process import BaseProcess +from .popen_fiber_spawn import Popen +from typing import Callable, Sequence, Tuple, Optional, List, Set, Iterator logger = logging.getLogger('fiber') +_children: Set["Process"] +_current_process: "BaseProcess" +_process_counter: Iterator[int] + _children = set() _current_process = mp.current_process() -def _cleanup(): +def _cleanup() -> None: # check for processes which have finished for p in list(_children): if p._popen.poll() is not None: _children.discard(p) -def active_children(): +def active_children() -> List["Process"]: """ Get a list of children processes of the current process. @@ -69,7 +75,7 @@ def active_children(): return list(_children) -def current_process(): +def current_process() -> "BaseProcess": """Return a Process object representing the current process. Example: @@ -158,23 +164,27 @@ class Process(BaseProcess): can call `select` and other eligible functions that works on fds on this file descriptor. """ + _popen: fiber.popen_fiber_spawn.Popen + _name: str + _pid: Optional[int] + _start_method = None _pid = None @staticmethod - def _Popen(process_obj): - from .popen_fiber_spawn import Popen + def _Popen(process_obj: "Process") -> Popen: return Popen(process_obj) - def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, - *, daemon=None): + def __init__(self, group: None = None, target: Callable = None, + name: str = None, args: Tuple = (), kwargs={}, + *, daemon: bool = None): super(Process, self).__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) self._parent_pid = current_process().pid # set when Process.start() failed self._start_failed = False - def __repr__(self): + def __repr__(self) -> str: return "{}({}, {})>".format(type(self).__name__, self._name, self.ident) def run(self): @@ -184,7 +194,7 @@ def run(self): """ return super().run() - def start(self): + def start(self) -> None: """Start this process. Under the hood, Fiber calls the API on the computer cluster to start a @@ -214,7 +224,7 @@ def start(self): del self._target, self._args, self._kwargs _children.add(self) - def terminate(self): + def terminate(self) -> None: """Terminate current process. When running locally, Fiber sends an SIGTERM signal to the child @@ -227,7 +237,7 @@ def terminate(self): return self._popen.terminate() - def join(self, timeout=None): + def join(self, timeout=None) -> None: """Wait for this process to terminate. :param timeout: The maximum duration of time in seconds that this call @@ -236,11 +246,11 @@ def join(self, timeout=None): `timeout` is `0`, it will check if the process has exited and return immediately. - :returns: The exit code of this process + :returns: None if process terminates or the method times out """ return super().join(timeout=timeout) - def is_alive(self): + def is_alive(self) -> bool: """Check if current process is still alive :returns: `True` if current process is still alive. Returns `False` if @@ -249,19 +259,26 @@ def is_alive(self): return super().is_alive() @property - def ident(self): + def ident(self) -> Optional[int]: if self._pid is None: - self._pid = self._popen and self._popen.pid + self._pid = self._popen.pid if self._popen else None return self._pid @ident.setter - def ident(self, pid): + def ident(self, pid: int): self._pid = pid - pid = ident + @property + def pid(self) -> Optional[int]: + return self.ident + + @property + def target(self): + return self._target - def _bootstrap(self): + + def _bootstrap(self) -> Tuple[int, Optional[str]]: from multiprocessing import util, context global _current_process, _process_counter, _children err = None @@ -279,7 +296,8 @@ def _bootstrap(self): fiber.util._finalizer_registry.clear() fiber.util._run_after_forkers() except Exception as e: - err = e + import traceback + err = traceback.format_exc() finally: # delay finalization of the old process object until after # _run_after_forkers() is executed @@ -304,7 +322,8 @@ def _bootstrap(self): sys.stderr.write(str(e.args[0]) + '\n') exitcode = 1 except Exception as e: # noqa E722 - err = e + import traceback + err = traceback.format_exc() exitcode = 1 import traceback msg = traceback.format_exc() diff --git a/fiber/queues.py b/fiber/queues.py index 627f856..abf4e1d 100644 --- a/fiber/queues.py +++ b/fiber/queues.py @@ -63,18 +63,22 @@ from fiber.process import current_process from fiber.backend import get_backend from fiber.socket import Socket, ProcessDevice +from typing import Any, Dict, List, NoReturn, Tuple, Union, Optional + +_io_threads: List[Any] +_poller_threads: Dict[Any, Any] # define the port range that fiber can use to listen for incoming connections MIN_PORT = 40000 MAX_PORT = 65535 -logger = logging.getLogger('fiber') +logger = logging.getLogger("fiber") _io_threads = [] _poller_threads = {} -def _clean_up(): +def _clean_up() -> None: for t in _io_threads: logger.debug("cleanup thread %s", t) t.stop() @@ -98,8 +102,14 @@ class ZConnection(multiprocessing.connection._ConnectionBase): > note: ZConnection's fileno method returns a Fiber socket. """ - def __init__(self, handle, readable=True, writable=True): - self._name = None + def __init__( + self, + handle: Union[Socket, Tuple[str, str]], + readable: bool = True, + writable: bool = True, + ) -> None: + self._handle: Optional[Socket] + self._name: Optional[str] = None if handle is None: raise ValueError("invalid socket") @@ -112,84 +122,111 @@ def __init__(self, handle, readable=True, writable=True): if not readable and not writable: raise ValueError( - "at least one of `readable` and `writable` must be True") + "at least one of `readable` and `writable` must be True" + ) self._readable = readable self._writable = writable mp_register_after_fork(self, ZConnection._create_handle) atexit.register(self._close) - def __getstate__(self): - return {"sock_type": self.sock_type, - "dest_addr": self.dest_addr, - "_readable": self._readable, - "_writable": self._writable, - "_name": self._name} - - def __setstate__(self, state): - self.sock_type = state['sock_type'] - self.dest_addr = state['dest_addr'] - self._readable = state['_readable'] - self._writable = state['_writable'] - self._name = state['_name'] + def __getstate__(self) -> Dict: + return { + "sock_type": self.sock_type, + "dest_addr": self.dest_addr, + "_readable": self._readable, + "_writable": self._writable, + "_name": self._name, + } + + def __setstate__(self, state) -> None: + self.sock_type = state["sock_type"] + self.dest_addr = state["dest_addr"] + self._readable = state["_readable"] + self._writable = state["_writable"] + self._name = state["_name"] self._create_handle() mp_register_after_fork(self, ZConnection._create_handle) atexit.register(self._close) - def __del__(self): + def __del__(self) -> None: # __del__ is not reliable, we use atexit to clean up things return - def __repr__(self): + def __repr__(self) -> str: name = ( self._name if getattr(self, "_name", None) is not None else self.dest_addr ) - return ''.format(name, getattr(self, "_handle", None)) + return "".format( + name, getattr(self, "_handle", None) + ) - def _create_handle(self): + def _create_handle(self) -> None: logger.debug("%s _create_handle called", self) - #self._handle = context.socket(self.sock_type) + # self._handle = context.socket(self.sock_type) self._handle = Socket(mode=self.sock_type) self._handle.connect(self.dest_addr) logger.debug("connect to %s", self.dest_addr) - def _close(self): + def _close(self) -> None: if self._handle: self._handle.close() self._handle = None - def _send_bytes(self, buf): - self._handle.send(buf) + def _check_readable(self): + if not self._readable: + raise RuntimeError("ZConnection is readonly") + + def _check_closed(self): + if self._handle is None: + raise RuntimeError("ZConnection handle is closed") + + def _send_bytes(self, buf: bytes) -> None: + self._handle.send(buf) # type: ignore # TODO(jiale) _ConnectionBase's send use _ForkingPickler. Switch to # send_pyobj instead? - def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int = None) -> bytes: # TODO(jiale) support maxsize - return self._handle.recv() + return self._handle.recv() # type: ignore + + def send(self, obj): + """Send a (picklable) object""" + self._check_closed() + self._check_writable() + self._send_bytes(reduction.ForkingPickler.dumps(obj)) - def recv(self): + def recv(self) -> Any: """Receive a (picklable) object""" - #logger.debug("recv called") + # logger.debug("recv called") self._check_closed() self._check_readable() buf = self._recv_bytes() return reduction.ForkingPickler.loads(buf) - def _poll(self, timeout): - return bool(self._handle.poll(timeout=timeout)) + def _poll(self, timeout: float) -> bool: + # return bool(self._handle.poll(timeout=timeout)) + # ZConnection doesn't support poll yet + raise NotImplementedError - def set_name(self, name): + def set_name(self, name: str) -> None: self._name = name class LazyZConnection(ZConnection): - def __init__(self, handle, readable=True, writable=True, name=None): - self._name = None + def __init__( + self, + handle: Union[Socket, Tuple[str, str]], + readable: bool = True, + writable: bool = True, + name: Optional[str] = None, + ) -> None: + self._name: Optional[str] = None if handle is None: raise ValueError("invalid socket") @@ -206,60 +243,63 @@ def __init__(self, handle, readable=True, writable=True, name=None): if not readable and not writable: raise ValueError( - "at least one of `readable` and `writable` must be True") + "at least one of `readable` and `writable` must be True" + ) self._readable = readable self._writable = writable - def _check_closed(self): + def _check_closed(self) -> None: if self._inited is False: self._create_handle() self._inited = True if self._handle is None: raise OSError("handle is closed") - def _close(self): + def _close(self) -> None: if not self._inited: return - self._handle.close() + self._handle.close() # type: ignore - def _send_bytes(self, buf): - #self._handle.send_multipart([b"", buf]) - self._handle.send(buf) + def _send_bytes(self, buf: bytes) -> None: + # self._handle.send_multipart([b"", buf]) + self._handle.send(buf) # type: ignore - def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int = None) -> bytes: # TODO(jiale) support maxsize - #msg = self._handle.recv_multipart() + # msg = self._handle.recv_multipart() # msg -> b'' b'message data' (because of ROUTER) - #data = msg[1] - data = self._handle.recv() + # data = msg[1] + data = self._handle.recv() # type: ignore return data - def __getstate__(self): - return {"sock_type": self.sock_type, - "dest_addr": self.dest_addr, - "_readable": self._readable, - "_writable": self._writable} - - def __setstate__(self, state): - self.sock_type = state['sock_type'] - self.dest_addr = state['dest_addr'] - self._readable = state['_readable'] - self._writable = state['_writable'] + def __getstate__(self) -> Dict: + return { + "sock_type": self.sock_type, + "dest_addr": self.dest_addr, + "_readable": self._readable, + "_writable": self._writable, + } + + def __setstate__(self, state) -> None: + self.sock_type = state["sock_type"] + self.dest_addr = state["dest_addr"] + self._readable = state["_readable"] + self._writable = state["_writable"] self._inited = False class LazyZConnectionPipe(LazyZConnection): - def _send_bytes(self, buf): - self._handle.send(buf) + def _send_bytes(self, buf: bytes) -> None: + self._handle.send(buf) # type: ignore - def _recv_bytes(self, maxsize=None): + def _recv_bytes(self, maxsize: int = None) -> Any: # TODO(jiale) support maxsize - return self._handle.recv() + return self._handle.recv() # type: ignore -def Pipe(duplex=True): +def Pipe(duplex: bool = True) -> Tuple[LazyZConnection, LazyZConnection]: """Return a pair of connected ZConnection objects. :param duplex: if duplex, then both read and write are allowed on each @@ -275,22 +315,28 @@ def Pipe(duplex=True): d.start() if duplex: - return (LazyZConnection(("rw", d.out_addr,)), - LazyZConnection(("rw", d.in_addr,))) - return (LazyZConnection(("r", d.out_addr,)), - LazyZConnection(("w", d.in_addr,))) + return ( + LazyZConnection(("rw", d.out_addr,)), + LazyZConnection(("rw", d.in_addr,)), + ) + return ( + LazyZConnection(("r", d.out_addr,)), + LazyZConnection(("w", d.in_addr,)), + ) -class SimpleQueuePush(): +class SimpleQueuePush: """A queue build on top of Fiber socket. It uses "w" - ("r" - "w") - "r" socket combination. Messages are pushed from one end of the queue to the other end without explicitly pulling. """ - def __repr__(self): + + def __repr__(self) -> str: return "SimpleQueuePush".format( - self._reader_addr, self._writer_addr) + self._reader_addr, self._writer_addr + ) - def __init__(self): + def __init__(self) -> None: self.done = False backend = get_backend() ip, _, _ = backend.get_listen_addr() @@ -307,11 +353,11 @@ def __init__(self): # set reader to None because if reader is connected, Fiber socket will # fairly queue messages to all readers even if this reader is # not reading. - #self.reader = None + # self.reader = None self.reader = LazyZConnection(("r", self._reader_addr,)) self.writer = LazyZConnection(("w", self._writer_addr,)) - def get(self): + def get(self) -> Any: """Get an element from this Queue. :returns: An element from this queue. If there is no element in the @@ -326,22 +372,22 @@ def get(self): self.reader = LazyZConnection(("r", self._reader_addr,)) """ - #data = self.reader._handle.recv() + # data = self.reader._handle.recv() msg = self.reader.recv() - #msg = reduction.ForkingPickler.loads(data) + # msg = reduction.ForkingPickler.loads(data) logger.debug("%s got %s", self, msg) return msg - def put(self, obj): + def put(self, obj: Any) -> None: """Put an element into the Queue. :param obj: Any picklable Python object. """ logger.debug("%s put %s", self, obj) - #data = reduction.ForkingPickler.dumps(obj) + # data = reduction.ForkingPickler.dumps(obj) self.writer.send(obj) - ''' + """ def __getstate__(self): d = self.__dict__.copy() # set reader to None so that de-serialized Queue doesn't connect to @@ -349,8 +395,8 @@ def __getstate__(self): # to all the readers. We want to prevent this behavior. d["reader"] = None return d - ''' + """ -#Pipe = ClassicPipe +# Pipe = ClassicPipe SimpleQueue = SimpleQueuePush diff --git a/fiber/socket.py b/fiber/socket.py index c277c2a..af63a12 100644 --- a/fiber/socket.py +++ b/fiber/socket.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod import random import logging import multiprocessing as mp import threading from fiber.backend import get_backend +from typing import Any, NoReturn, Optional, Tuple MIN_PORT = 40000 @@ -25,7 +27,7 @@ logger = logging.getLogger("fiber") socket_lib = "nanomsg" -default_socket_ctx = None +default_socket_ctx: "SockContext" if socket_lib == "nanomsg": @@ -41,11 +43,11 @@ raise ValueError("bad socket_lib value: {}".format(socket_lib)) -def get_ctx(): +def get_ctx() -> "SockContext": return default_socket_ctx -def bind_to_random_port(sock, addr_base, max_tries=100, nng=True): +def bind_to_random_port(sock, addr_base: str, max_tries:int = 100, nng: bool = True) -> Optional[int]: num_tries = 0 while num_tries < max_tries: try: @@ -60,32 +62,36 @@ def bind_to_random_port(sock, addr_base, max_tries=100, nng=True): num_tries += 1 continue - return + return None -class SockContext: - default_addr = None +class SockContext(ABC): - def new(self, mode): - raise NotImplementedError + default_addr:str = "" + + @abstractmethod + def new(self, mode: str) -> "Socket": + pass @staticmethod - def bind_random(sock, addr): - raise NotImplementedError + @abstractmethod + def bind_random(sock, addr: str) -> int: + pass @staticmethod - def connect(sock, addr): - raise NotImplementedError + @abstractmethod + def connect(sock, addr: str) -> None: + pass @staticmethod - def close(sock): + def close(sock) -> None: sock.close() class ZMQContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> None: self._mode_to_type = { "r": zmq.DEALER, @@ -96,24 +102,28 @@ def __init__(self): } self.context = zmq.Context.instance() - def new(self, mode): + def new(self, mode: str) -> "zmq.Socket": sock_type = self._mode_to_type[mode] if sock_type is None: return None return self.context.socket(sock_type) @staticmethod - def bind_random(sock, addr): + def bind_random(sock, addr: str) -> int: assert type(addr) == str - return sock.bind_to_random_port( + port = sock.bind_to_random_port( addr, min_port=MIN_PORT, max_port=MAX_PORT, max_tries=100 ) + if port is None: + raise RuntimeError("ZMQContext Failed to bind to a random port") + + return port @staticmethod - def connect(sock, addr): + def connect(sock, addr) -> Any: return sock.connect(addr) - def device(self, s1_mode, s2_mode): + def device(self, s1_mode: str, s2_mode: str) -> Tuple["zmq.devices.ThreadDevice", str, str]: backend = get_backend() ip_ext, _, _ = backend.get_listen_addr() @@ -140,7 +150,7 @@ class NNGDevice: STATE_AFTER_INIT = 1 STATE_FINISHED = 2 - def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): + def __init__(self, ctx: "NNGContext", s1_mode: str, s2_mode: str, default_addr: str = "tcp://0.0.0.0") -> None: self._mode_to_opener = { "r": pynng.lib.nng_pull0_open_raw, @@ -153,7 +163,7 @@ def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): self.default_addr = default_addr self._start_process() - def _start_process(self): + def _start_process(self) -> None: parent_conn, child_conn = mp.Pipe() self._proc = threading.Thread( target=self._run, args=(child_conn,), daemon=True @@ -168,15 +178,7 @@ def _start_process(self): #child_conn.close() self.conn = parent_conn - def _mode_to_opener(self, mode): - opener = self._mode_to_opener[mode] - if opener is None: - raise ValueError( - "Mode {} not supported by {}", mode, self.__class__.__name__ - ) - return opener - - def _create_socks(self): + def _create_socks(self) -> Tuple["pynng.Socket", "pynng.Socket"]: opener1 = self._mode_to_opener[self.s1_mode] opener2 = self._mode_to_opener[self.s2_mode] @@ -186,18 +188,18 @@ def _create_socks(self): return s1, s2 - def _bind_socks(self, s1, s2): + def _bind_socks(self, s1: "pynng.Socket", s2: "pynng.Socket") -> Tuple[Optional[int], Optional[int]]: port1 = bind_to_random_port(s1, self.default_addr) port2 = bind_to_random_port(s2, self.default_addr) return port1, port2 - def _run_device(self, s1, s2): + def _run_device(self, s1: "pynng.Socket", s2: "pynng.Socket") -> None: ret = pynng.lib.nng_device(s1.socket, s2.socket) check_err(ret) - def _run(self, conn): + def _run(self, conn: "mp.connection.Connection") -> None: s1, s2 = self._create_socks() state = NNGDevice.STATE_INIT @@ -233,7 +235,7 @@ def _run(self, conn): else: break - def bind(self): + def bind(self) -> Tuple[str, str]: self.conn.send("#bind") in_addr, out_addr = self.conn.recv() if in_addr is None: @@ -242,7 +244,7 @@ def bind(self): self.out_addr = out_addr return in_addr, out_addr - def start(self): + def start(self) -> None: self.conn.send("#start") code = self.conn.recv() if code is not None: @@ -252,7 +254,7 @@ def start(self): class NNGContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> None: self._mode_to_creator = { "r": pynng.Pull0, @@ -262,21 +264,25 @@ def __init__(self): "rep": pynng.Rep0, } - def new(self, mode): + def new(self, mode: str) -> "pynng.Socket": func = self._mode_to_creator[mode] if func is None: return None return func() @staticmethod - def bind_random(sock, addr): - return bind_to_random_port(sock, addr) + def bind_random(sock, addr: str) -> int: + port = bind_to_random_port(sock, addr) + if port is None: + raise RuntimeError("NNGContext Failed to bind to a random port") + + return port @staticmethod - def connect(sock, addr): + def connect(sock, addr: str) -> "pynng.Dialer": return sock.dial(addr) - def device(self, s1_mode, s2_mode): + def device(self, s1_mode: str, s2_mode: str) -> Tuple[NNGDevice, str, str]: self.s1_mode = s1_mode self.s2_mode = s2_mode @@ -295,7 +301,7 @@ def device(self, s1_mode, s2_mode): class NanomsgDevice(NNGDevice): - def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): + def __init__(self, ctx: "NanomsgContext", s1_mode: str, s2_mode: str, default_addr: str ="tcp://0.0.0.0") -> None: self.s1_mode = s1_mode self.s2_mode = s2_mode self.default_addr = default_addr @@ -303,18 +309,25 @@ def __init__(self, ctx, s1_mode, s2_mode, default_addr="tcp://0.0.0.0"): self._start_process() - def _create_socks(self): + def _create_socks(self) -> Tuple[nnpy.Socket, nnpy.Socket]: s1 = nnpy.Socket(nnpy.AF_SP_RAW, self.ctx._mode_to_type[self.s1_mode]) s2 = nnpy.Socket(nnpy.AF_SP_RAW, self.ctx._mode_to_type[self.s2_mode]) return s1, s2 - def _bind_socks(self, s1, s2): + def _bind_socks(self, s1: nnpy.Socket, s2: nnpy.Socket) -> Tuple[int, int]: port1 = bind_to_random_port(s1, self.default_addr, nng=False) port2 = bind_to_random_port(s2, self.default_addr, nng=False) + + if port1 is None: + raise RuntimeError("NanomsgDevice failed to bind port1") + + if port2 is None: + raise RuntimeError("NanomsgDevice failed to bind port2") + return port1, port2 - def _run_device(self, s1, s2): + def _run_device(self, s1: nnpy.Socket, s2: nnpy.Socket) -> Any: rc = nnpy.nanomsg.nn_device(s1.sock, s2.sock) return nnpy.errors.convert(rc, rc) @@ -323,7 +336,7 @@ def _run_device(self, s1, s2): class NanomsgContext(SockContext): default_addr = "tcp://0.0.0.0" - def __init__(self): + def __init__(self) -> None: self._mode_to_type = { "r": nnpy.PULL, @@ -333,22 +346,30 @@ def __init__(self): "req": nnpy.REQ, } - def new(self, mode): + def new(self, mode) -> nnpy.Socket: sock_type = self._mode_to_type[mode] if sock_type is None: - return None + raise RuntimeError( + "NangmsgContext got Invalid mode: {}".format(mode) + ) + return nnpy.Socket(nnpy.AF_SP, sock_type) @staticmethod - def bind_random(sock, addr): - return bind_to_random_port(sock, addr, nng=False) + def bind_random(sock, addr) -> int: + port = bind_to_random_port(sock, addr, nng=False) + + if port is None: + raise RuntimeError("NanomsgContext Failed to bind to a random port") + + return port @staticmethod - def connect(sock, addr): + def connect(sock, addr) -> Any: return sock.connect(addr) - def device(self, s1_mode, s2_mode): + def device(self, s1_mode, s2_mode) -> Tuple[NanomsgDevice, str, str]: self.s1_mode = s1_mode self.s2_mode = s2_mode @@ -377,13 +398,13 @@ def device(self, s1_mode, s2_mode): class Socket: - def __repr__(self): + def __repr__(self) -> str: return "{}<{},{}>".format( self.__class__.__name__, self._ctx.__class__.__name__, self._mode) - def __init__(self, ctx=get_ctx(), mode="rw"): + def __init__(self, ctx=get_ctx(), mode="rw") -> None: self._mode = mode self._ctx = ctx self._sock = ctx.new(mode) @@ -391,35 +412,35 @@ def __init__(self, ctx=get_ctx(), mode="rw"): raise ValueError("Socket mode \"{}\" not supported by {}".format( mode, ctx.__class__.__name__)) - def send(self, data): + def send(self, data) -> None: self._sock.send(data) - def recv(self): + def recv(self) -> bytes: return self._sock.recv() - def bind(self): + def bind(self) -> int: addr = self._ctx.default_addr bind_random = self._ctx.bind_random port = bind_random(self._sock, addr) return port - def connect(self, addr): + def connect(self, addr) -> None: _connect = self._ctx.connect _connect(self._sock, addr) - def close(self): + def close(self) -> None: _close = self._ctx.close _close(self._sock) class ProcessDevice: - def __init__(self, s1_mode, s2_mode, ctx=get_ctx()): + def __init__(self, s1_mode: Any, s2_mode: Any, ctx=get_ctx()) -> None: device, in_addr, out_addr = ctx.device(s1_mode, s2_mode) self.device = device self.in_addr = in_addr self.out_addr = out_addr - def start(self): + def start(self) -> None: self.device.start() logger.debug("started device in:%s out:%s", self.in_addr, self.out_addr) diff --git a/fiber/spawn.py b/fiber/spawn.py index cb48f07..859d2f9 100644 --- a/fiber/spawn.py +++ b/fiber/spawn.py @@ -25,12 +25,13 @@ from fiber.init import init_fiber +from typing import Any, NoReturn logger = logging.getLogger('fiber') -def exit_by_signal(): +def exit_by_signal() -> NoReturn: logger.info("Exiting, sending SIGTERM to current process") os.kill(os.getpid(), signal.SIGTERM) @@ -41,7 +42,7 @@ def exit_by_signal(): os._exit(1) -def exit_on_fd_close(fd): +def exit_on_fd_close(fd) -> None: while True: rl, _, _ = select.select([fd], [], [], 0) if fd in rl: @@ -51,7 +52,7 @@ def exit_on_fd_close(fd): time.sleep(1) -def spawn_prepare(fd): +def spawn_prepare(fd) -> int: from_parent_r = os.fdopen(fd, "rb", closefd=False) preparation_data = reduction.pickle.load(from_parent_r) diff --git a/fiber/util.py b/fiber/util.py index a9378d5..1b15ad2 100644 --- a/fiber/util.py +++ b/fiber/util.py @@ -10,7 +10,7 @@ import itertools import logging -import multiprocessing.util as mpu +import multiprocessing.util as mpu # type: ignore import os import re import sys @@ -19,9 +19,14 @@ import weakref import psutil +from typing import Any, Dict, Iterator, Tuple, Optional, Callable, Sequence +_afterfork_registry: weakref.WeakValueDictionary +_finalizer_counter: Iterator +_finalizer_registry: Dict[Tuple[Any, Any], "Finalize"] -logger = logging.getLogger('fiber') + +logger = logging.getLogger("fiber") _afterfork_registry = weakref.WeakValueDictionary() _afterfork_counter = itertools.count() @@ -30,27 +35,35 @@ _finalizer_counter = itertools.count() -def register_after_fork(obj, func): +def register_after_fork(obj, func) -> None: _afterfork_registry[(next(_afterfork_counter), id(obj), func)] = obj -def _run_after_forkers(): - logging.debug('_fun_after_forkers called') +def _run_after_forkers() -> None: + logging.debug("_fun_after_forkers called") items = list(_afterfork_registry.items()) items.sort() for (index, ident, func), obj in items: try: - logging.debug('run after forker %s(%s)', func, obj) + logging.debug("run after forker %s(%s)", func, obj) func(obj) except Exception as e: - logging.info('after forker raised exception %s', e) + logging.info("after forker raised exception %s", e) class Finalize(mpu.Finalize): """Basically this is the same as multiprocessing's Finalize class except this one uses it's own _finalizer_registry. """ - def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): + + def __init__( + self, + obj: Any, + callback: Callable, + args: Sequence = (), + kwargs: Dict[str, Any] = None, + exitpriority: int = None, + ) -> None: assert exitpriority is None or type(exitpriority) is int if obj is not None: @@ -67,7 +80,7 @@ def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): _finalizer_registry[self._key] = self -def find_ip_by_net_interface(target_interface): +def find_ip_by_net_interface(target_interface: str) -> Optional[str]: """Returns ip, debug_info.""" ifces = psutil.net_if_addrs() ip = None @@ -84,11 +97,11 @@ def find_ip_by_net_interface(target_interface): class ForkAwareThreadLock(object): - def __init__(self): + def __init__(self) -> None: self._reset() register_after_fork(self, ForkAwareThreadLock._reset) - def _reset(self): + def _reset(self) -> None: self._lock = threading.Lock() self.acquire = self._lock.acquire self.release = self._lock.release @@ -101,14 +114,14 @@ def __exit__(self, *args): class ForkAwareLocal(threading.local): - def __init__(self): + def __init__(self) -> None: register_after_fork(self, lambda obj: obj.__dict__.clear()) def __reduce__(self): return type(self), () -def find_listen_address(): +def find_listen_address() -> Tuple[Optional[str], Optional[str]]: """Find an IP address for Fiber to use.""" ip = None ifce = None @@ -124,8 +137,8 @@ def find_listen_address(): return ip, ifce -def is_in_interactive_console(): - if hasattr(sys, 'ps1'): +def is_in_interactive_console() -> bool: + if hasattr(sys, "ps1"): return True return False