Skip to content
Snippets Groups Projects
Commit d495e728 authored by mads's avatar mads
Browse files

shared ssh client + retry

parent 79a10277
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,9 @@ import threading
import psutil
from wetb.utils.cluster_tools import pbswrap
from wetb.utils.cluster_tools.ssh_client import SSHClient
from wetb.utils.cluster_tools.ssh_client import SSHClient, SharedSSHClient
from _collections import deque
import time
class Resource(object):
......@@ -44,9 +46,11 @@ class Resource(object):
class SSHPBSClusterResource(Resource, SSHClient):
def __init__(self, host, username, password, port, min_cpu, min_free):
Resource.__init__(self, min_cpu, min_free)
self.shared_ssh = SharedSSHClient(host, username, password, port)
SSHClient.__init__(self, host, username, password, port=port)
self.lock = threading.Lock()
def new_ssh_connection(self):
return SSHClient(self.host, self.username, self.password, self.port)
......@@ -87,6 +91,8 @@ class SSHPBSClusterResource(Resource, SSHClient):
class LocalResource(Resource):
def __init__(self, process_name):
N = max(1, multiprocessing.cpu_count() / 2)
......@@ -97,11 +103,11 @@ class LocalResource(Resource):
def check_resources(self):
def name(i):
try:
return psutil.Process(i).name
except (psutil._error.AccessDenied, psutil._error.NoSuchProcess):
return psutil.Process(i).name()
except (psutil.AccessDenied, psutil.NoSuchProcess):
return ""
no_cpu = multiprocessing.cpu_count()
cpu_free = no_cpu - self.acquired #(1 - psutil.cpu_percent(.5) / 100) * no_cpu
no_current_process = len([i for i in psutil.get_pid_list() if name(i).lower().startswith(self.process_name.lower())])
return no_cpu, cpu_free, no_current_process
cpu_free = (1 - psutil.cpu_percent(.5) / 100) * no_cpu
no_current_process = len([i for i in psutil.pids() if name(i).lower().startswith(self.process_name.lower())])
return no_cpu, cpu_free, self.acquired
......@@ -12,14 +12,15 @@ RUNNING = "Running"
DONE = "Done"
class SSHPBSJob(SSHClient):
class SSHPBSJob(object):
_status = NOT_SUBMITTED
nodeid = None
jobid = None
def __init__(self, host, username, password, port=22):
SSHClient.__init__(self, host, username, password, port=port)
def __init__(self, sshClient):
self.ssh = sshClient
def submit(self, job, cwd, pbs_out_file):
self.cwd = cwd
......@@ -31,25 +32,27 @@ class SSHPBSJob(SSHClient):
if cwd != "":
cmds.append("cd %s" % cwd)
cmds.append("qsub %s" % job)
_, out, _ = self.execute(";".join(cmds))
ssh = SSHClient(self.ssh.host, self.ssh.username, self.ssh.password, self.ssh.port)
_, out, _ = ssh.execute(";".join(cmds))
self.jobid = out.split(".")[0]
self._status = PENDING
@property
def status(self):
if self._status in [NOT_SUBMITTED, DONE]:
return self._status
with self:
with self.ssh:
if self.is_executing():
self._status = RUNNING
elif self.file_exists(self.pbs_out_file):
elif self.ssh.file_exists(self.pbs_out_file):
self._status = DONE
self.jobid = None
return self._status
def get_nodeid(self):
try:
_, out, _ = self.execute("qstat -f %s | grep exec_host" % self.jobid)
_, out, _ = self.ssh.execute("qstat -f %s | grep exec_host" % self.jobid)
return out.strip().replace("exec_host = ", "").split(".")[0]
except Warning as e:
if 'qstat: Unknown Job Id' in str(e):
......@@ -59,7 +62,7 @@ class SSHPBSJob(SSHClient):
def stop(self):
if self.jobid:
try:
self.execute("qdel %s" % self.jobid)
self.ssh.execute("qdel %s" % self.jobid)
except Warning as e:
if 'qdel: Unknown Job Id' in str(e):
return
......@@ -68,7 +71,7 @@ class SSHPBSJob(SSHClient):
def is_executing(self):
try:
self.execute("qstat %s" % self.jobid)
self.ssh.execute("qstat %s" % self.jobid)
return True
except Warning as e:
if 'qstat: Unknown Job Id' in str(e):
......
......@@ -8,6 +8,10 @@ from io import StringIO
import paramiko
import os
import sys
import threading
from _collections import deque
import time
import traceback
class SSHClient(object):
"A wrapper of paramiko.SSHClient"
......@@ -31,6 +35,7 @@ class SSHClient(object):
self.disconnect += 1
if self.client is None:
self.connect()
return self.client
def connect(self):
if self.password is None:
......@@ -50,14 +55,20 @@ class SSHClient(object):
self.close()
def download(self, remotefilepath, localfile, verbose=False):
def download(self, remotefilepath, localfile, verbose=False, retry=1):
if verbose:
print ("Download %s > %s" % (remotefilepath, str(localfile)))
with self:
if isinstance(localfile, (str, bytes, int)):
ret = self.sftp.get(remotefilepath, localfile)
elif hasattr(localfile, 'write'):
ret = self.sftp.putfo(remotefilepath, localfile)
for i in range(retry):
try:
if isinstance(localfile, (str, bytes, int)):
ret = self.sftp.get(remotefilepath, localfile)
elif hasattr(localfile, 'write'):
ret = self.sftp.putfo(remotefilepath, localfile)
break
except:
pass
print ("retry", i)
if verbose:
print (ret)
......@@ -96,8 +107,12 @@ class SSHClient(object):
if verbose:
print (">>> " + command)
with self:
stdin, stdout, stderr = self.client.exec_command(command)
with self as ssh:
if ssh is None:
exc_info = sys.exc_info()
traceback.print_exception(*exc_info)
raise Exception("ssh_client exe ssh is NOne")
stdin, stdout, stderr = ssh.exec_command(command)
if feed_password:
stdin.write(self.password + "\n")
stdin.flush()
......@@ -143,6 +158,39 @@ class SSHClient(object):
return files
class SharedSSHClient(SSHClient):
def __init__(self, host, username, password=None, port=22, key=None, passphrase=None):
SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase)
self.shared_ssh_lock = threading.RLock()
self.shared_ssh_queue = deque()
self.next = None
def execute(self, command, sudo=False, verbose=False):
res = SSHClient.execute(self, command, sudo=sudo, verbose=verbose)
return res
def __enter__(self):
with self.shared_ssh_lock:
if self.next == threading.currentThread():
return self.client
self.shared_ssh_queue.append(threading.current_thread())
if self.next is None:
self.next = self.shared_ssh_queue.popleft()
while self.next != threading.currentThread():
time.sleep(1)
return self.client
def __exit__(self, *args):
with self.shared_ssh_lock:
if next != threading.current_thread():
with self.shared_ssh_lock:
if len(self.shared_ssh_queue) > 0:
self.next = self.shared_ssh_queue.popleft()
else:
self.next = None
if __name__ == "__main__":
from mmpe.ui.qt_ui import QtInputUI
q = QtInputUI(None)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment