Decouple taskloop from task
authorJannis Harder <me@jix.one>
Thu, 21 Apr 2022 14:22:32 +0000 (16:22 +0200)
committerJannis Harder <me@jix.one>
Wed, 15 Jun 2022 14:28:09 +0000 (16:28 +0200)
sbysrc/sby.py
sbysrc/sby_core.py

index d9e0a5c9df58c818140d3bd8b26b2977b30c2ac9..f3eca9b5b732986189b1182330aba50cc742cff2 100644 (file)
@@ -458,13 +458,11 @@ def run_task(taskname):
     for k, v in exe_paths.items():
         task.exe_paths[k] = v
 
-    if throw_err:
+    try:
         task.run(setupmode)
-    else:
-        try:
-            task.run(setupmode)
-        except SbyAbort:
-            pass
+    except SbyAbort:
+        if throw_err:
+            raise
 
     if my_opt_tmpdir:
         task.log(f"Removing directory '{my_workdir}'.")
index 2e092c42626f374d313748023c434d730a054ca5..d133786ce969ac81d8ba59983fdb4470caa8c65d 100644 (file)
@@ -51,6 +51,7 @@ class SbyProc:
         self.running = False
         self.finished = False
         self.terminated = False
+        self.exited = False
         self.checkretcode = False
         self.retcodes = [0]
         self.task = task
@@ -81,7 +82,7 @@ class SbyProc:
         self.logstderr = logstderr
         self.silent = silent
 
-        self.task.procs_pending.append(self)
+        self.task.update_proc_pending(self)
 
         for dep in self.deps:
             dep.register_dep(self)
@@ -90,6 +91,9 @@ class SbyProc:
         self.exit_callback = None
         self.error_callback = None
 
+        if self.task.timeout_reached:
+            self.terminate(True)
+
     def register_dep(self, next_proc):
         if self.finished:
             next_proc.poll()
@@ -137,12 +141,19 @@ class SbyProc:
                 except PermissionError:
                     pass
             self.p.terminate()
-            self.task.procs_running.remove(self)
-            all_procs_running.remove(self)
+            self.task.update_proc_stopped(self)
+        elif not self.finished and not self.terminated and not self.exited:
+            self.task.update_proc_canceled(self)
         self.terminated = True
 
-    def poll(self):
-        if self.finished or self.terminated:
+    def poll(self, force_unchecked=False):
+        if self.task.task_local_abort and not force_unchecked:
+            try:
+                self.poll(True)
+            except SbyAbort:
+                self.task.terminate(True)
+            return
+        if self.finished or self.terminated or self.exited:
             return
 
         if not self.running:
@@ -168,9 +179,7 @@ class SbyProc:
                 self.p = subprocess.Popen(self.cmdline, shell=True, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
                         stderr=(subprocess.STDOUT if self.logstderr else None))
 
-            self.task.procs_pending.remove(self)
-            self.task.procs_running.append(self)
-            all_procs_running.append(self)
+            self.task.update_proc_running(self)
             self.running = True
             return
 
@@ -187,26 +196,24 @@ class SbyProc:
         if self.p.poll() is not None:
             if not self.silent:
                 self.task.log(f"{self.info}: finished (returncode={self.p.returncode})")
-            self.task.procs_running.remove(self)
-            all_procs_running.remove(self)
+            self.task.update_proc_stopped(self)
             self.running = False
+            self.exited = True
 
             if self.p.returncode == 127:
-                self.task.status = "ERROR"
                 if not self.silent:
                     self.task.log(f"{self.info}: COMMAND NOT FOUND. ERROR.")
                 self.handle_error(self.p.returncode)
                 self.terminated = True
-                self.task.terminate()
+                self.task.proc_failed(self)
                 return
 
             if self.checkretcode and self.p.returncode not in self.retcodes:
-                self.task.status = "ERROR"
                 if not self.silent:
                     self.task.log(f"{self.info}: task failed. ERROR.")
                 self.handle_error(self.p.returncode)
                 self.terminated = True
-                self.task.terminate()
+                self.task.proc_failed(self)
                 return
 
             self.handle_exit(self.p.returncode)
@@ -321,8 +328,55 @@ class SbyConfig:
     def error(self, logmessage):
         raise SbyAbort(logmessage)
 
+
+class SbyTaskloop:
+    def __init__(self):
+        self.procs_pending = []
+        self.procs_running = []
+        self.tasks = []
+        self.poll_now = False
+
+    def run(self):
+        for proc in self.procs_pending:
+            proc.poll()
+
+        while len(self.procs_running) or self.poll_now:
+            fds = []
+            for proc in self.procs_running:
+                if proc.running:
+                    fds.append(proc.p.stdout)
+
+            if not self.poll_now:
+                if os.name == "posix":
+                    try:
+                        select(fds, [], [], 1.0) == ([], [], [])
+                    except InterruptedError:
+                        pass
+                else:
+                    sleep(0.1)
+            self.poll_now = False
+
+            for proc in self.procs_running:
+                proc.poll()
+
+            for proc in self.procs_pending:
+                proc.poll()
+
+            tasks = self.tasks
+            self.tasks = []
+            for task in tasks:
+                task.check_timeout()
+                if task.procs_pending or task.procs_running:
+                    self.tasks.append(task)
+                else:
+                    task.exit_callback()
+
+        for task in self.tasks:
+            task.exit_callback()
+
+
 class SbyTask(SbyConfig):
-    def __init__(self, sbyconfig, workdir, early_logs, reusedir):
+    def __init__(self, sbyconfig, workdir, early_logs, reusedir, taskloop=None):
         super().__init__()
         self.used_options = set()
         self.models = dict()
@@ -333,6 +387,8 @@ class SbyTask(SbyConfig):
         self.expect = list()
         self.design_hierarchy = None
         self.precise_prop_status = False
+        self.timeout_reached = False
+        self.task_local_abort = False
 
         yosys_program_prefix = "" ##yosys-program-prefix##
         self.exe_paths = {
@@ -346,6 +402,9 @@ class SbyTask(SbyConfig):
             "pono": os.getenv("PONO", "pono"),
         }
 
+        self.taskloop = taskloop or SbyTaskloop()
+        self.taskloop.tasks.append(self)
+
         self.procs_running = []
         self.procs_pending = []
 
@@ -367,36 +426,34 @@ class SbyTask(SbyConfig):
                 for line in sbyconfig:
                     print(line, file=f)
 
-    def taskloop(self):
-        for proc in self.procs_pending:
-            proc.poll()
+    def check_timeout(self):
+        if self.opt_timeout is not None:
+            total_clock_time = int(monotonic() - self.start_clock_time)
+            if total_clock_time > self.opt_timeout:
+                self.log(f"Reached TIMEOUT ({self.opt_timeout} seconds). Terminating all subprocesses.")
+                self.status = "TIMEOUT"
+                self.terminate(timeout=True)
 
-        while len(self.procs_running):
-            fds = []
-            for proc in self.procs_running:
-                if proc.running:
-                    fds.append(proc.p.stdout)
+    def update_proc_pending(self, proc):
+        self.procs_pending.append(proc)
+        self.taskloop.procs_pending.append(proc)
 
-            if os.name == "posix":
-                try:
-                    select(fds, [], [], 1.0) == ([], [], [])
-                except InterruptedError:
-                    pass
-            else:
-                sleep(0.1)
+    def update_proc_running(self, proc):
+        self.procs_pending.remove(proc)
+        self.taskloop.procs_pending.remove(proc)
 
-            for proc in self.procs_running:
-                proc.poll()
+        self.procs_running.append(proc)
+        self.taskloop.procs_running.append(proc)
+        all_procs_running.append(proc)
 
-            for proc in self.procs_pending:
-                proc.poll()
+    def update_proc_stopped(self, proc):
+        self.procs_running.remove(proc)
+        self.taskloop.procs_running.remove(proc)
+        all_procs_running.remove(proc)
 
-            if self.opt_timeout is not None:
-                total_clock_time = int(monotonic() - self.start_clock_time)
-                if total_clock_time > self.opt_timeout:
-                    self.log(f"Reached TIMEOUT ({self.opt_timeout} seconds). Terminating all subprocesses.")
-                    self.status = "TIMEOUT"
-                    self.terminate(timeout=True)
+    def update_proc_canceled(self, proc):
+        self.procs_pending.remove(proc)
+        self.taskloop.procs_pending.remove(proc)
 
     def log(self, logmessage):
         tm = localtime()
@@ -632,8 +689,17 @@ class SbyTask(SbyConfig):
         return self.models[model_name]
 
     def terminate(self, timeout=False):
+        if timeout:
+            self.timeout_reached = True
         for proc in list(self.procs_running):
             proc.terminate(timeout=timeout)
+        for proc in list(self.procs_pending):
+            proc.terminate(timeout=timeout)
+
+    def proc_failed(self, proc):
+        # proc parameter used by autotune override
+        self.status = "ERROR"
+        self.terminate()
 
     def update_status(self, new_status):
         assert new_status in ["PASS", "FAIL", "UNKNOWN", "ERROR"]
@@ -659,6 +725,11 @@ class SbyTask(SbyConfig):
             assert 0
 
     def run(self, setupmode):
+        self.setup_procs(setupmode)
+        if not setupmode:
+            self.taskloop.run()
+
+    def setup_procs(self, setupmode):
         with open(f"{self.workdir}/config.sby", "r") as f:
             self.parse_config(f)
 
@@ -732,8 +803,7 @@ class SbyTask(SbyConfig):
             if opt not in self.used_options:
                 self.error(f"Unused option: {opt}")
 
-        self.taskloop()
-
+    def summarize(self):
         total_clock_time = int(monotonic() - self.start_clock_time)
 
         if os.name == "posix":
@@ -772,6 +842,9 @@ class SbyTask(SbyConfig):
             for line in self.summary:
                 print(line, file=f)
 
+    def exit_callback(self):
+        self.summarize()
+
     def print_junit_result(self, f, junit_ts_name, junit_tc_name, junit_format_strict=False):
         junit_time = strftime('%Y-%m-%dT%H:%M:%S')
         if self.precise_prop_status: