diff --git a/wetb/hawc2/simulation.py b/wetb/hawc2/simulation.py index eca748c5f7f73c1acf79ec881f64675ea92a1b22..7fba699262d62d2b3f2b5652ded0f0e046eabd88 100755 --- a/wetb/hawc2/simulation.py +++ b/wetb/hawc2/simulation.py @@ -20,6 +20,7 @@ from future import standard_library from wetb.hawc2 import log_file from wetb.hawc2.htc_file import HTCFile, fmt_path from wetb.hawc2.log_file import LogFile +from wetb.utils import threadnames @@ -144,19 +145,26 @@ class Simulation(object): self.update_status() def abort(self, update_status=True): - if self.status != QUEUED: - self.host.stop() - for _ in range(50): - if self.is_simulating is False: - break - time.sleep(0.1) - if self.logFile.status not in [log_file.DONE]: - self.status = ABORTED + self.status = ABORTED self.is_simulating = False self.is_done = True + self.host.stop() if update_status: self.update_status() +# if self.status != QUEUED: +# self.host.stop() +# for _ in range(50): +# if self.is_simulating is False: +# break +# time.sleep(0.1) +# if self.logFile.status not in [log_file.DONE]: +# self.status = ABORTED +# self.is_simulating = False +# self.is_done = True +# if update_status: +# self.update_status() + def show_status(self): #print ("log status:", self.logFile.status) if self.logFile.status == log_file.SIMULATING: @@ -228,6 +236,8 @@ class Simulation(object): self.host._simulate() self.returncode, self.stdout = self.host.returncode, self.host.stdout if self.host.returncode or 'error' in self.host.stdout.lower(): + if self.status==ABORTED: + return if "error" in self.host.stdout.lower(): self.errors = (list(set([l for l in self.host.stdout.split("\n") if 'error' in l.lower()]))) self.status = ERROR @@ -263,15 +273,12 @@ class Simulation(object): return dst turb_files = [f for f in self.htcFile.turbulence_files() if self.copy_turbulence and not os.path.isfile(os.path.join(self.exepath, f))] if self.ios: - output_patterns = [fmt(dst) for dst in (["../output/*", "../output/"] + - turb_files + - [os.path.join(self.exepath, self.stdout_filename)])] - output_files = set([fmt_path(f) for pattern in output_patterns for f in self.host.glob(fmt_path(os.path.join(self.tmp_exepath, pattern)), recursive=True)]) + output_patterns = ["../output/*", "../output/"] + turb_files + [os.path.join(self.exepath, self.stdout_filename)] else: - output_patterns = [fmt(dst) for dst in (self.htcFile.output_files() + - turb_files + - [os.path.join(self.exepath, self.stdout_filename)])] - output_files = set([fmt_path(f) for pattern in output_patterns for f in self.host.glob(fmt_path(os.path.join(self.tmp_exepath, pattern)))]) + output_patterns = self.htcFile.output_files() + turb_files + [os.path.join(self.exepath, self.stdout_filename)] + output_files = set(self.host.glob([fmt_path(os.path.join(self.tmp_exepath,fmt(p))) for p in output_patterns], recursive=self.ios)) + + try: self.host._finish_simulation(output_files) if self.status != ERROR: @@ -325,6 +332,7 @@ class Simulation(object): def simulate_distributed(self): + threadnames.register("Simulation %s"%self.simulation_id) try: self.prepare_simulation() try: @@ -335,7 +343,8 @@ class Simulation(object): raise finally: try: - self.finish_simulation() + if self.status!=ABORTED: + self.finish_simulation() except: print ("finish_simulation failed", str(self)) raise diff --git a/wetb/hawc2/simulation_resources.py b/wetb/hawc2/simulation_resources.py index dfc386a3da9c2144eb2abda6a7f37e73430c6bf7..b9650f8db3bf1e517565dbdb18fe9c9cd91dd4be 100644 --- a/wetb/hawc2/simulation_resources.py +++ b/wetb/hawc2/simulation_resources.py @@ -23,7 +23,7 @@ from wetb.utils.cluster_tools.pbsjob import SSHPBSJob, NOT_SUBMITTED, DONE from wetb.utils.cluster_tools.ssh_client import SSHClient from wetb.utils.timing import print_time from wetb.hawc2.htc_file import fmt_path - +import numpy as np class SimulationHost(object): def __init__(self, simulation): @@ -60,6 +60,8 @@ class LocalSimulationHost(SimulationHost): return datetime.now() def glob(self, path, recursive=False): + if isinstance(path, list): + return [f for p in path for f in self.glob(p, recursive)] if recursive: return [os.path.join(root, f) for root, _, files in os.walk(path) for f in files] else: @@ -168,7 +170,7 @@ class SimulationThread(Thread): self.process = subprocess.Popen(" ".join([wine, hawc2exe, htcfile]), stdout=stdout, stderr=STDOUT, shell=True, cwd=exepath) #shell must be True to inwoke wine else: self.process = subprocess.Popen([hawc2exe, htcfile], stdout=stdout, stderr=STDOUT, shell=False, cwd=exepath, creationflags=CREATE_NO_WINDOW) - self.process.communicate() + #self.process.communicate() import psutil try: @@ -222,25 +224,25 @@ class PBSClusterSimulationResource(SSHPBSClusterResource): except: pass - def update_status(self): + def update_resource_status(self): try: - _, out, _ = self.execute("find .hawc2launcher/ -name '*.out'") + _, out, _ = self.ssh.execute("find .hawc2launcher/ -name '*.out'") self.finished = set([f.split("/")[1] for f in out.split("\n") if "/" in f]) - except Exception: - #print ("resource_manager.update_status, out", str(e)) + except Exception as e: + print ("resource_manager.update_status, out", str(e)) pass try: - _, out, _ = self.execute("find .hawc2launcher -name 'status*' -exec cat {} \;") + _, out, _ = self.ssh.execute("find .hawc2launcher -name 'status*' -exec cat {} \;") self.loglines = {l.split(";")[0] : l.split(";")[1:] for l in out.split("\n") if ";" in l} - except Exception: - #print ("resource_manager.update_status, status file", str(e)) + except Exception as e: + print ("resource_manager.update_status, status file", str(e)) pass try: - _, out, _ = self.execute("qstat -u %s" % self.username) + _, out, _ = self.ssh.execute("qstat -u %s" % self.username) self.is_executing = set([j.split(".")[0] for j in out.split("\n")[5:] if "." in j]) - except Exception: - #print ("resource_manager.update_status, qstat", str(e)) + except Exception as e: + print ("resource_manager.update_status, qstat", str(e)) pass class GormSimulationResource(PBSClusterSimulationResource): @@ -250,49 +252,46 @@ source activate wetb_py3""" PBSClusterSimulationResource.__init__(self, "gorm.risoe.dk", username, password, 22, 25, 100, init_cmd, wine_cmd, "python") -class PBSClusterSimulationHost(SimulationHost, SSHClient): +class PBSClusterSimulationHost(SimulationHost): def __init__(self, simulation, resource): SimulationHost.__init__(self, simulation) - SSHClient.__init__(self, resource.host, resource.username, resource.password, resource.port) - self.pbsjob = SSHPBSJob(resource.shared_ssh) + self.ssh = resource.new_ssh_connection() + self.pbsjob = SSHPBSJob(resource.new_ssh_connection()) self.resource = resource hawc2exe = property(lambda self : os.path.basename(self.sim.hawc2exe)) - + def glob(self, *args,**kwargs): + return self.ssh.glob(*args,**kwargs) def get_datetime(self): - v, out, err = self.execute('date "+%Y,%m,%d,%H,%M,%S"') + v, out, err = self.ssh.execute('date "+%Y,%m,%d,%H,%M,%S"') if v == 0: return datetime.strptime(out.strip(), "%Y,%m,%d,%H,%M,%S") #@print_time def _prepare_simulation(self, input_files): - with self: - self.execute(["mkdir -p .hawc2launcher/%s" % self.simulation_id], verbose=False) - self.execute("mkdir -p %s%s" % (self.tmp_exepath, os.path.dirname(self.log_filename))) - - self.upload_files(self.modelpath, self.tmp_modelpath, file_lst = [os.path.relpath(f, self.modelpath) for f in input_files]) + with self.ssh: + self.ssh.execute(["mkdir -p .hawc2launcher/%s" % self.simulation_id], verbose=False) + self.ssh.execute("mkdir -p %s%s" % (self.tmp_exepath, os.path.dirname(self.log_filename))) + self.ssh.upload_files(self.modelpath, self.tmp_modelpath, file_lst = [os.path.relpath(f, self.modelpath) for f in input_files], callback=self.sim.progress_callback("Copy to host")) # for src_file in input_files: # dst = unix_path(self.tmp_modelpath + os.path.relpath(src_file, self.modelpath)) # self.execute("mkdir -p %s" % os.path.dirname(dst), verbose=False) # self.upload(src_file, dst, verbose=False) # ##assert self.ssh.file_exists(dst) - f = io.StringIO(self.pbsjobfile(self.sim.ios)) f.seek(0) - self.upload(f, self.tmp_exepath + "%s.in" % self.simulation_id) - self.execute("mkdir -p %s%s" % (self.tmp_exepath, os.path.dirname(self.stdout_filename))) + self.ssh.upload(f, self.tmp_exepath + "%s.in" % self.simulation_id) + self.ssh.execute("mkdir -p %s%s" % (self.tmp_exepath, os.path.dirname(self.stdout_filename))) remote_log_filename = "%s%s" % (self.tmp_exepath, self.log_filename) - self.execute("rm -f %s" % remote_log_filename) - - - + self.ssh.execute("rm -f %s" % remote_log_filename) + #@print_time def _finish_simulation(self, output_files): - with self: + with self.ssh: download_failed = [] try: - self.download_files(self.tmp_modelpath, self.modelpath, file_lst = [os.path.relpath(f, self.tmp_modelpath) for f in output_files] ) + self.ssh.download_files(self.tmp_modelpath, self.modelpath, file_lst = [os.path.relpath(f, self.tmp_modelpath) for f in output_files], callback=self.sim.progress_callback("Copy from host") ) except: # # for src_file in output_files: @@ -303,15 +302,15 @@ class PBSClusterSimulationHost(SimulationHost, SSHClient): # except Exception as e: # download_failed.append(dst_file) # if download_failed: - raise Warning("Failed to download %s from %s"%(",".join(download_failed), self.host)) + raise Warning("Failed to download %s from %s"%(",".join(download_failed), self.ssh.host)) else: try: - self.execute('rm -r .hawc2launcher/%s' % self.simulation_id) + self.ssh.execute('rm -r .hawc2launcher/%s' % self.simulation_id) finally: try: - self.execute('rm .hawc2launcher/status_%s' % self.simulation_id) + self.ssh.execute('rm .hawc2launcher/status_%s' % self.simulation_id) except: - raise Warning("Fail to remove temporary files and folders on %s"%self.host) + raise Warning("Fail to remove temporary files and folders on %s"%self.ssh.host) def _simulate(self): @@ -324,9 +323,9 @@ class PBSClusterSimulationHost(SimulationHost, SSHClient): time.sleep(sleeptime) local_out_file = self.exepath + self.stdout_filename - with self: + with self.ssh: try: - self.download(self.tmp_exepath + self.stdout_filename, local_out_file) + self.ssh.download(self.tmp_exepath + self.stdout_filename, local_out_file) with open(local_out_file) as fid: _, self.stdout, returncode_str, _ = fid.read().split("---------------------") self.returncode = returncode_str.strip() != "0" @@ -334,7 +333,7 @@ class PBSClusterSimulationHost(SimulationHost, SSHClient): self.returncode = 1 self.stdout = "error: Could not download and read stdout file" try: - self.download(self.tmp_exepath + self.log_filename, self.exepath + self.log_filename) + self.ssh.download(self.tmp_exepath + self.log_filename, self.exepath + self.log_filename) except Exception: raise Warning ("Logfile not found", self.tmp_modelpath + self.log_filename) self.sim.logFile = LogFile.from_htcfile(self.htcFile, self.exepath) @@ -400,7 +399,11 @@ class PBSClusterSimulationHost(SimulationHost, SSHClient): cp_back += "mkdir -p $PBS_O_WORKDIR/%s/. \n" % folder cp_back += "cp -R -f %s/. $PBS_O_WORKDIR/%s/.\n" % (folder, folder) rel_htcfilename = fmt_path(os.path.relpath(self.htcFile.filename, self.exepath)) - + try: + steps = self.htcFile.simulation.time_stop[0] / self.htcFile.simulation.newmark.deltat[0] + walltime = "%02d:00:00"%np.ceil(steps/500/60) + except: + walltime = "04:00:00" init=""" ### Standard Output #PBS -N h2l_%s @@ -408,14 +411,14 @@ class PBSClusterSimulationHost(SimulationHost, SSHClient): #PBS -j oe #PBS -o %s ### Maximum wallclock time format HOURS:MINUTES:SECONDS -#PBS -l walltime=04:00:00 +#PBS -l walltime=%s ###PBS -a 201547.53 #PBS -lnodes=1:ppn=1 ### Queue name #PBS -q workq ### Create scratch directory and copy data to it cd $PBS_O_WORKDIR -pwd"""% (self.simulation_id, self.stdout_filename) +pwd"""% (self.simulation_id, self.stdout_filename, walltime) copy_to=""" cp -R %s /scratch/$USER/$PBS_JOBID ### Execute commands on scratch nodes diff --git a/wetb/utils/cluster_tools/cluster_resource.py b/wetb/utils/cluster_tools/cluster_resource.py index d12e3e02d4b0e816c11f75de32379898d77df5f9..55206763141b49d1ce2d99cca0953728b8cf6c98 100644 --- a/wetb/utils/cluster_tools/cluster_resource.py +++ b/wetb/utils/cluster_tools/cluster_resource.py @@ -66,13 +66,18 @@ class Resource(object): def __init__(self, min_cpu, min_free): self.min_cpu = min_cpu self.min_free = min_free + self.cpu_free=0 self.acquired = 0 + self.no_cpu="?" + self.used_by_user=0 self.resource_lock = threading.Lock() def ok2submit(self): """Always ok to have min_cpu cpus and ok to have more if there are min_free free cpus""" try: + print ("ok2submit") total, free, user = self.check_resources() + print ("ok2submit", total, free, user) except: return False if user < self.min_cpu: @@ -91,7 +96,7 @@ class Resource(object): self.acquired -= 1 - def update_status(self): + def update_resource_status(self): try: self.no_cpu, self.cpu_free, self.used_by_user = self.check_resources() except Exception: @@ -114,7 +119,8 @@ class SSHPBSClusterResource(Resource): def new_ssh_connection(self): - return SSHClient(self.host, self.username, self.password, self.port) + return SSHClient(self.host, self.ssh.username, self.ssh.password, self.ssh.port) + #return self.ssh def check_resources(self): with self.resource_lock: diff --git a/wetb/utils/cluster_tools/pbsjob.py b/wetb/utils/cluster_tools/pbsjob.py index 901c00bc62eea095b078c05d4356c269cbf4a63b..d1c1c99ec54246d467c42c94f17af921d6297dfb 100644 --- a/wetb/utils/cluster_tools/pbsjob.py +++ b/wetb/utils/cluster_tools/pbsjob.py @@ -32,8 +32,7 @@ class SSHPBSJob(object): if cwd != "": cmds.append("cd %s" % cwd) cmds.append("qsub %s" % job) - ssh = SSHClient(self.ssh.host, self.ssh.username, self.ssh.password, self.ssh.port) - _, out, _ = ssh.execute(";".join(cmds)) + _, out, _ = self.ssh.execute(";".join(cmds)) self.jobid = out.split(".")[0] self._status = PENDING diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py index c089dbc8b65747cda451f5fd94555c707b8e1de6..5f2759342182b0163541e6a4e5c09b4315e379ab 100644 --- a/wetb/utils/cluster_tools/ssh_client.py +++ b/wetb/utils/cluster_tools/ssh_client.py @@ -17,6 +17,9 @@ from wetb.utils.timing import print_time import glob import getpass from sshtunnel import SSHTunnelForwarder, SSH_CONFIG_FILE +from wetb.utils.ui import UI +from wetb.utils import threadnames + @@ -100,7 +103,7 @@ class SSHClient(object): "A wrapper of paramiko.SSHClient" TIMEOUT = 4 - def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None): + def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None, ui=UI()): self.host = host self.username = username self.password = password @@ -108,9 +111,11 @@ class SSHClient(object): self.key = key self.gateway=gateway self.interactive_auth_handler = interactive_auth_handler + self.ui = ui self.disconnect = 0 self.client = None - self.sftp = None + self.ssh_lock = threading.RLock() + #self.sftp = None self.transport = None if key is not None: self.key = paramiko.RSAKey.from_private_key(StringIO(key), password=passphrase) @@ -119,21 +124,22 @@ class SSHClient(object): return self.host, self.username, self.password, self.port def __enter__(self): - self.disconnect += 1 - if self.client is None or self.client._transport is None or self.client._transport.is_active() is False: - self.close() - try: - self.connect() - self.disconnect = 1 - except Exception as e: + with self.ssh_lock: + self.disconnect += 1 + if self.client is None or self.client._transport is None or self.client._transport.is_active() is False: self.close() - self.disconnect = 0 - raise e - return self.client + try: + self.connect() + self.disconnect = 1 + except Exception as e: + self.close() + self.disconnect = 0 + raise e + return self.client def connect(self): - print ("connect", self.host) + print ("connect", self.host, threadnames.name()) #print (traceback.print_stack()) if self.gateway: if self.gateway.interactive_auth_handler: @@ -152,10 +158,15 @@ class SSHClient(object): local_bind_address=('0.0.0.0', 10022) ) + print ("start tunnel") self.tunnel.start() + print ("self.client = paramiko.SSHClient()") self.client = paramiko.SSHClient() + self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + print ('self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password)') self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password) + print ("done") elif self.password is None or self.password == "": raise IOError("Password not set for %s"%self.host) @@ -171,48 +182,80 @@ class SSHClient(object): assert self.client is not None - self.sftp = paramiko.SFTPClient.from_transport(self.client._transport) + #self.sftp = paramiko.SFTPClient.from_transport(self.client._transport) return self + + def __del__(self): + self.close() + + @property + def sftp(self): + return paramiko.SFTPClient.from_transport(self.client._transport) + +# @sftp.setter +# def sftp(self, values): +# pass + def __exit__(self, *args): self.disconnect -= 1 if self.disconnect == 0: self.close() - def download(self, remotefilepath, localfile, verbose=False, retry=1): + def download(self, remotefilepath, localfile, verbose=False, retry=1, callback=None): if verbose: ret = None print ("Download %s > %s" % (remotefilepath, str(localfile))) - with self: - for i in range(retry): - if i>0: - print ("Retry download %s, #%d"%(remotefilepath, i)) - try: - if isinstance(localfile, (str, bytes, int)): - ret = self.sftp.get(remotefilepath, localfile) - elif hasattr(localfile, 'write'): - ret = self.sftp.putfo(remotefilepath, localfile) - break - except: - pass - print ("Download %s failed from %s"%(remotefilepath, self.host)) + if callback is None: + callback = self.ui.progress_callback() + + for i in range(retry): + if i>0: + print ("Retry download %s, #%d"%(remotefilepath, i)) + + try: + #print ("start download enter", threadnames.name()) + SSHClient.__enter__(self) + #print ("start download", threadnames.name()) + if isinstance(localfile, (str, bytes, int)): + ret = self.sftp.get(remotefilepath, localfile, callback=callback) + elif hasattr(localfile, 'write'): + ret = self.sftp.putfo(remotefilepath, localfile, callback=callback) + break + except: + pass + finally: + #print ("End download", threadnames.name()) + SSHClient.__exit__(self) + #print ("End download exit", threadnames.name()) + + print ("Download %s failed from %s"%(remotefilepath, self.host)) if verbose: print (ret) - def upload(self, localfile, filepath, verbose=False): + def upload(self, localfile, filepath, verbose=False, callback=None): if verbose: print ("Upload %s > %s" % (localfile, filepath)) - with self: + if callback is None: + callback = self.ui.progress_callback() + try: + SSHClient.__enter__(self) if isinstance(localfile, (str, bytes, int)): - ret = self.sftp.put(localfile, filepath) + ret = self.sftp.put(localfile, filepath, callback=callback) elif hasattr(localfile, 'read'): - ret = self.sftp.putfo(localfile, filepath) + size = len(localfile.read()) + localfile.seek(0) + ret = self.sftp.putfo(localfile, filepath, file_size=size, callback=callback) + finally: + #print ("End upload", threadnames.name()) + SSHClient.__exit__(self) + #print ("End upload exit", threadnames.name()) if verbose: print (ret) - def upload_files(self, localpath, remotepath, file_lst=["."], compression_level=1): + def upload_files(self, localpath, remotepath, file_lst=["."], compression_level=1, callback=None): assert os.path.isdir(localpath) if not isinstance(file_lst, (tuple, list)): file_lst = [file_lst] @@ -228,27 +271,30 @@ class SSHClient(object): zipf.write(f, os.path.relpath(f, localpath)) zipf.close() remote_zn = os.path.join(remotepath, zn).replace("\\","/") - self.execute("mkdir -p %s"%(remotepath)) - - self.upload(zn, remote_zn) - self.execute("unzip %s -d %s && rm %s"%(remote_zn, remotepath, remote_zn)) + with self: + self.execute("mkdir -p %s"%(remotepath)) + + self.upload(zn, remote_zn, callback=callback) + self.execute("unzip %s -d %s && rm %s"%(remote_zn, remotepath, remote_zn)) except: raise finally: os.remove(zn) - def download_files(self, remote_path, localpath, file_lst=["."], compression_level=1): + def download_files(self, remote_path, localpath, file_lst=["."], compression_level=1, callback=None): if not isinstance(file_lst, (tuple, list)): file_lst = [file_lst] file_lst = [f.replace("\\","/") for f in file_lst] - remote_zip = os.path.join(remote_path, "tmp.zip").replace("\\","/") - self.execute("cd %s && zip -r tmp.zip %s"%(remote_path, " ".join(file_lst))) + zn = 'tmp_%s_%s.zip'%(id(self),time.time()) - local_zip = os.path.join(localpath, "tmp.zip") + remote_zip = os.path.join(remote_path, zn).replace("\\","/") + self.execute("cd %s && zip -r %s %s"%(remote_path, zn, " ".join(file_lst))) + + local_zip = os.path.join(localpath, zn) if not os.path.isdir(localpath): os.makedirs(localpath) - self.download(remote_zip, local_zip) + self.download(remote_zip, local_zip, callback=callback) self.execute("rm -f %s" % remote_zip) with zipfile.ZipFile(local_zip, "r") as z: z.extractall(localpath) @@ -256,7 +302,7 @@ class SSHClient(object): def close(self): - for x in ["sftp", "client", 'tunnel' ]: + for x in ["client", 'tunnel' ]: try: getattr(self, x).close() setattr(self, x, None) @@ -319,6 +365,9 @@ class SSHClient(object): ret = self.execute('wine regedit tmp.reg') def glob(self, filepattern, cwd="", recursive=False): + if isinstance(filepattern, list): + with self: + return [f for fp in filepattern for f in self.glob(fp, cwd, recursive)] cwd = os.path.join(cwd, os.path.split(filepattern)[0]).replace("\\", "/") filepattern = os.path.split(filepattern)[1] if recursive: @@ -333,37 +382,49 @@ class SSHClient(object): class SharedSSHClient(SSHClient): def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None): SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase, interactive_auth_handler=interactive_auth_handler, gateway=gateway) - self.shared_ssh_lock = threading.RLock() + self.shared_ssh_queue = deque() self.next = None + + + def execute(self, command, sudo=False, verbose=False): res = SSHClient.execute(self, command, sudo=sudo, verbose=verbose) return res + + + def __enter__(self): - with self.shared_ssh_lock: - if self.next == threading.currentThread(): - return self.client - self.shared_ssh_queue.append(threading.current_thread()) - if self.next is None: - self.next = self.shared_ssh_queue.popleft() - - while self.next != threading.currentThread(): - time.sleep(1) - SSHClient.__enter__(self) + with self.ssh_lock: + SSHClient.__enter__(self) + #print ("request SSH", threading.currentThread()) +# if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): +# # SSH already allocated to this thread ( multiple use-statements in "with ssh:" block +# self.shared_ssh_queue.appendleft(threading.get_ident()) +# else: +# self.shared_ssh_queue.append(threading.get_ident()) + + if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): + # SSH already allocated to this thread ( multiple use-statements in "with ssh:" block + self.shared_ssh_queue.popleft() + + self.shared_ssh_queue.append(threading.get_ident()) + + while self.shared_ssh_queue[0] != threading.get_ident(): + #print ("Waiting for ssh", threadnames.name(), [threadnames.name(id) for id in self.shared_ssh_queue]) + time.sleep(2) + #print ("Got SSH", threadnames.name()) + return self.client def __exit__(self, *args): - with self.shared_ssh_lock: - if next != threading.current_thread(): - with self.shared_ssh_lock: - if len(self.shared_ssh_queue) > 0: - self.next = self.shared_ssh_queue.popleft() - else: - self.next = None - + with self.ssh_lock: + if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): + self.shared_ssh_queue.popleft() + if __name__ == "__main__": diff --git a/wetb/utils/tests/test_ssh_client.py b/wetb/utils/tests/test_ssh_client.py index 567e763689e752a769ea0efbf0cd051429ac5aee..022caddbd287c5b4814b53455472ddcf7c5fc6ac 100644 --- a/wetb/utils/tests/test_ssh_client.py +++ b/wetb/utils/tests/test_ssh_client.py @@ -6,6 +6,8 @@ Created on 23. dec. 2016 import unittest from wetb.utils.cluster_tools.ssh_client import SSHClient import os +from wetb.utils.text_ui import TextUI + try: import sys @@ -37,7 +39,7 @@ class TestSSHClient(unittest.TestCase): def setUp(self): if x: - self.ssh = SSHClient('gorm', 'mmpe',x.mmpe ) + self.ssh = SSHClient('gorm', 'mmpe',x.mmpe) def test_execute(self): @@ -54,10 +56,12 @@ class TestSSHClient(unittest.TestCase): self.ssh.execute("rm -f tmp.txt") io.StringIO() - txt = "Hello world" + txt = "Hello world"*1000000 f = io.StringIO(txt) f.seek(0) - self.ssh.upload(f, "tmp.txt") + print ("start upload") + self.ssh.upload(f, "tmp.txt", callback = TextUI().progress_callback("Uploading")) + print ("endupload") _,out,_ = self.ssh.execute("cat tmp.txt") self.assertEqual(out, txt) fn = tfp + "tmp.txt" @@ -114,7 +118,7 @@ class TestSSHClient(unittest.TestCase): self.assertEqual(out.strip(), "g-000.risoe.dk") def test_ssh_g047(self): - if 1 or all: + if 0 or all: if x: gateway = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe) ssh = SSHClient('g-047', "mmpe", x.mmpe, gateway=gateway) @@ -128,10 +132,10 @@ class TestSSHClient(unittest.TestCase): self.assertEqual(out.strip(), "ssh-03.risoe.dk") def test_ssh_risoe_gorm(self): - if 0 or all: + if 1 or all: if x: - gateway = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) + gateway = SSHClient('ssh.risoe.dk', 'mmpe', password="xxx", interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) ssh = SSHClient('10.40.23.49', 'mmpe', x.mmpe, gateway = gateway) _,out,_ = ssh.execute("hostname") self.assertEqual(out.strip(), "g-000.risoe.dk")