diff --git a/wetb/utils/cluster_tools/cluster_resource.py b/wetb/utils/cluster_tools/cluster_resource.py index 162036fd95edafba7bd1e7e45a1d961b9fbefca1..ebdf6f4aece80b05aa114b39ab516ea5bb76e989 100644 --- a/wetb/utils/cluster_tools/cluster_resource.py +++ b/wetb/utils/cluster_tools/cluster_resource.py @@ -9,7 +9,9 @@ import threading import psutil from wetb.utils.cluster_tools import pbswrap -from wetb.utils.cluster_tools.ssh_client import SSHClient +from wetb.utils.cluster_tools.ssh_client import SSHClient, SharedSSHClient +from _collections import deque +import time class Resource(object): @@ -44,9 +46,11 @@ class Resource(object): class SSHPBSClusterResource(Resource, SSHClient): def __init__(self, host, username, password, port, min_cpu, min_free): Resource.__init__(self, min_cpu, min_free) + self.shared_ssh = SharedSSHClient(host, username, password, port) SSHClient.__init__(self, host, username, password, port=port) self.lock = threading.Lock() + def new_ssh_connection(self): return SSHClient(self.host, self.username, self.password, self.port) @@ -87,6 +91,8 @@ class SSHPBSClusterResource(Resource, SSHClient): + + class LocalResource(Resource): def __init__(self, process_name): N = max(1, multiprocessing.cpu_count() / 2) @@ -97,11 +103,11 @@ class LocalResource(Resource): def check_resources(self): def name(i): try: - return psutil.Process(i).name - except (psutil._error.AccessDenied, psutil._error.NoSuchProcess): + return psutil.Process(i).name() + except (psutil.AccessDenied, psutil.NoSuchProcess): return "" no_cpu = multiprocessing.cpu_count() - cpu_free = no_cpu - self.acquired #(1 - psutil.cpu_percent(.5) / 100) * no_cpu - no_current_process = len([i for i in psutil.get_pid_list() if name(i).lower().startswith(self.process_name.lower())]) - return no_cpu, cpu_free, no_current_process + cpu_free = (1 - psutil.cpu_percent(.5) / 100) * no_cpu + no_current_process = len([i for i in psutil.pids() if name(i).lower().startswith(self.process_name.lower())]) + return no_cpu, cpu_free, self.acquired diff --git a/wetb/utils/cluster_tools/pbsjob.py b/wetb/utils/cluster_tools/pbsjob.py index 7b6d6f8e8827eecad29c19fab91f6e3ca280577c..901c00bc62eea095b078c05d4356c269cbf4a63b 100644 --- a/wetb/utils/cluster_tools/pbsjob.py +++ b/wetb/utils/cluster_tools/pbsjob.py @@ -12,14 +12,15 @@ RUNNING = "Running" DONE = "Done" -class SSHPBSJob(SSHClient): +class SSHPBSJob(object): _status = NOT_SUBMITTED nodeid = None jobid = None - def __init__(self, host, username, password, port=22): - SSHClient.__init__(self, host, username, password, port=port) + def __init__(self, sshClient): + self.ssh = sshClient + def submit(self, job, cwd, pbs_out_file): self.cwd = cwd @@ -31,25 +32,27 @@ class SSHPBSJob(SSHClient): if cwd != "": cmds.append("cd %s" % cwd) cmds.append("qsub %s" % job) - _, out, _ = self.execute(";".join(cmds)) + ssh = SSHClient(self.ssh.host, self.ssh.username, self.ssh.password, self.ssh.port) + _, out, _ = ssh.execute(";".join(cmds)) self.jobid = out.split(".")[0] self._status = PENDING @property def status(self): + if self._status in [NOT_SUBMITTED, DONE]: return self._status - with self: + with self.ssh: if self.is_executing(): self._status = RUNNING - elif self.file_exists(self.pbs_out_file): + elif self.ssh.file_exists(self.pbs_out_file): self._status = DONE self.jobid = None return self._status def get_nodeid(self): try: - _, out, _ = self.execute("qstat -f %s | grep exec_host" % self.jobid) + _, out, _ = self.ssh.execute("qstat -f %s | grep exec_host" % self.jobid) return out.strip().replace("exec_host = ", "").split(".")[0] except Warning as e: if 'qstat: Unknown Job Id' in str(e): @@ -59,7 +62,7 @@ class SSHPBSJob(SSHClient): def stop(self): if self.jobid: try: - self.execute("qdel %s" % self.jobid) + self.ssh.execute("qdel %s" % self.jobid) except Warning as e: if 'qdel: Unknown Job Id' in str(e): return @@ -68,7 +71,7 @@ class SSHPBSJob(SSHClient): def is_executing(self): try: - self.execute("qstat %s" % self.jobid) + self.ssh.execute("qstat %s" % self.jobid) return True except Warning as e: if 'qstat: Unknown Job Id' in str(e): diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py index cdda5b515e55fb9a73279d7dc122f93a091ed0c5..97d6ed8d8e5e968c5205c9eb0d3d849cfe412046 100644 --- a/wetb/utils/cluster_tools/ssh_client.py +++ b/wetb/utils/cluster_tools/ssh_client.py @@ -8,6 +8,10 @@ from io import StringIO import paramiko import os import sys +import threading +from _collections import deque +import time +import traceback class SSHClient(object): "A wrapper of paramiko.SSHClient" @@ -31,6 +35,7 @@ class SSHClient(object): self.disconnect += 1 if self.client is None: self.connect() + return self.client def connect(self): if self.password is None: @@ -50,14 +55,20 @@ class SSHClient(object): self.close() - def download(self, remotefilepath, localfile, verbose=False): + def download(self, remotefilepath, localfile, verbose=False, retry=1): if verbose: print ("Download %s > %s" % (remotefilepath, str(localfile))) with self: - if isinstance(localfile, (str, bytes, int)): - ret = self.sftp.get(remotefilepath, localfile) - elif hasattr(localfile, 'write'): - ret = self.sftp.putfo(remotefilepath, localfile) + for i in range(retry): + try: + if isinstance(localfile, (str, bytes, int)): + ret = self.sftp.get(remotefilepath, localfile) + elif hasattr(localfile, 'write'): + ret = self.sftp.putfo(remotefilepath, localfile) + break + except: + pass + print ("retry", i) if verbose: print (ret) @@ -96,8 +107,12 @@ class SSHClient(object): if verbose: print (">>> " + command) - with self: - stdin, stdout, stderr = self.client.exec_command(command) + with self as ssh: + if ssh is None: + exc_info = sys.exc_info() + traceback.print_exception(*exc_info) + raise Exception("ssh_client exe ssh is NOne") + stdin, stdout, stderr = ssh.exec_command(command) if feed_password: stdin.write(self.password + "\n") stdin.flush() @@ -143,6 +158,39 @@ class SSHClient(object): return files +class SharedSSHClient(SSHClient): + def __init__(self, host, username, password=None, port=22, key=None, passphrase=None): + SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase) + self.shared_ssh_lock = threading.RLock() + self.shared_ssh_queue = deque() + self.next = None + + + def execute(self, command, sudo=False, verbose=False): + res = SSHClient.execute(self, command, sudo=sudo, verbose=verbose) + return res + + def __enter__(self): + with self.shared_ssh_lock: + if self.next == threading.currentThread(): + return self.client + self.shared_ssh_queue.append(threading.current_thread()) + if self.next is None: + self.next = self.shared_ssh_queue.popleft() + + while self.next != threading.currentThread(): + time.sleep(1) + return self.client + + def __exit__(self, *args): + with self.shared_ssh_lock: + if next != threading.current_thread(): + with self.shared_ssh_lock: + if len(self.shared_ssh_queue) > 0: + self.next = self.shared_ssh_queue.popleft() + else: + self.next = None + if __name__ == "__main__": from mmpe.ui.qt_ui import QtInputUI q = QtInputUI(None)