''' Created on 27/11/2015 @author: MMPE ''' from io import StringIO import sys import paramiko import os import threading from _collections import deque import time import traceback import zipfile import glob from sshtunnel import SSHTunnelForwarder, SSH_CONFIG_FILE from wetb.utils.ui import UI class SSHInteractiveAuthTunnelForwarder(SSHTunnelForwarder): def __init__( self, interactive_auth_handler, ssh_address_or_host=None, ssh_config_file=SSH_CONFIG_FILE, ssh_host_key=None, ssh_password=None, ssh_pkey=None, ssh_private_key_password=None, ssh_proxy=None, ssh_proxy_enabled=True, ssh_username=None, local_bind_address=None, local_bind_addresses=None, logger=None, mute_exceptions=False, remote_bind_address=None, remote_bind_addresses=None, set_keepalive=0.0, threaded=True, compression=None, allow_agent=True, * args, ** kwargs): self.interactive_auth_handler = interactive_auth_handler SSHTunnelForwarder.__init__(self, ssh_address_or_host=ssh_address_or_host, ssh_config_file=ssh_config_file, ssh_host_key=ssh_host_key, ssh_password=ssh_password, ssh_pkey=ssh_pkey, ssh_private_key_password=ssh_private_key_password, ssh_proxy=ssh_proxy, ssh_proxy_enabled=ssh_proxy_enabled, ssh_username=ssh_username, local_bind_address=local_bind_address, local_bind_addresses=local_bind_addresses, logger=logger, mute_exceptions=mute_exceptions, remote_bind_address=remote_bind_address, remote_bind_addresses=remote_bind_addresses, set_keepalive=set_keepalive, threaded=threaded, compression=compression, allow_agent=allow_agent, *args, **kwargs) def _connect_to_gateway(self): """ Open connection to SSH gateway - First try with all keys loaded from an SSH agent (if allowed) - Then with those passed directly or read from ~/.ssh/config - As last resort, try with a provided password """ try: self._transport = self._get_transport() self._transport.start_client() self._transport.auth_interactive(self.ssh_username, self.interactive_auth_handler) if self._transport.is_alive: return except paramiko.AuthenticationException: self.logger.debug('Authentication error') self._stop_transport() self.logger.error('Could not open connection to gateway') def _connect_to_gateway_old(self): """ Open connection to SSH gateway - First try with all keys loaded from an SSH agent (if allowed) - Then with those passed directly or read from ~/.ssh/config - As last resort, try with a provided password """ if self.ssh_password: # avoid conflict using both pass and pkey self.logger.debug('Trying to log in with password: {0}' .format('*' * len(self.ssh_password))) try: self._transport = self._get_transport() if self.interactive_auth_handler: self._transport.start_client() self._transport.auth_interactive(self.ssh_username, self.interactive_auth_handler) else: self._transport.connect(hostkey=self.ssh_host_key, username=self.ssh_username, password=self.ssh_password) if self._transport.is_alive: return except paramiko.AuthenticationException: self.logger.debug('Authentication error') self._stop_transport() self.logger.error('Could not open connection to gateway') class SSHClient(object): "A wrapper of paramiko.SSHClient" TIMEOUT = 4 def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None, ui=UI()): self.host = host self.username = username self.password = password self.port = port self.key = key self.gateway=gateway self.interactive_auth_handler = interactive_auth_handler self.ui = ui self.disconnect = 0 self.client = None self.ssh_lock = threading.RLock() #self.sftp = None self.transport = None self.counter_lock = threading.RLock() self.counter=0 if key is not None: self.key = paramiko.RSAKey.from_private_key(StringIO(key), password=passphrase) def info(self): return self.host, self.username, self.password, self.port def __enter__(self): with self.ssh_lock: self.disconnect += 1 if self.client is None or self.client._transport is None or self.client._transport.is_active() is False: self.close() try: self.connect() self.disconnect = 1 except Exception as e: self.close() self.disconnect = 0 raise e return self.client def connect(self): if self.gateway: if self.gateway.interactive_auth_handler: self.tunnel = SSHInteractiveAuthTunnelForwarder(self.gateway.interactive_auth_handler, (self.gateway.host, self.gateway.port), ssh_username=self.gateway.username, ssh_password=self.gateway.password, remote_bind_address=(self.host, self.port), local_bind_address=('0.0.0.0', 10022) ) else: self.tunnel = SSHTunnelForwarder((self.gateway.host, self.gateway.port), ssh_username=self.gateway.username, ssh_password=self.gateway.password, remote_bind_address=(self.host, self.port), local_bind_address=('0.0.0.0', 10022) ) print ("start tunnel") self.tunnel.start() print ("self.client = paramiko.SSHClient()") self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) print ('self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password)') self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password) print ("done") elif self.password is None or self.password == "": raise IOError("Password not set for %s"%self.host) else: self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) try: self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) except paramiko.ssh_exception.SSHException as e: transport = self.client.get_transport() transport.auth_interactive(self.username, self.interactive_auth_handler) assert self.client is not None #self.sftp = paramiko.SFTPClient.from_transport(self.client._transport) return self def __del__(self): self.close() @property def sftp(self): return paramiko.SFTPClient.from_transport(self.client._transport) # @sftp.setter # def sftp(self, values): # pass def __exit__(self, *args): self.disconnect -= 1 if self.disconnect == 0: self.close() def download(self, remotefilepath, localfile, verbose=False, retry=1, callback=None): if verbose: ret = None print ("Download %s > %s" % (remotefilepath, str(localfile))) if callback is None: callback = self.ui.progress_callback() for i in range(retry): if i>0: print ("Retry download %s, #%d"%(remotefilepath, i)) try: SSHClient.__enter__(self) if isinstance(localfile, (str, bytes, int)): ret = self.sftp.get(remotefilepath, localfile, callback=callback) elif hasattr(localfile, 'write'): ret = self.sftp.putfo(remotefilepath, localfile, callback=callback) break except: pass finally: SSHClient.__exit__(self) print ("Download %s failed from %s"%(remotefilepath, self.host)) if verbose: print (ret) def upload(self, localfile, filepath, verbose=False, callback=None): if verbose: print ("Upload %s > %s" % (localfile, filepath)) if callback is None: callback = self.ui.progress_callback() try: SSHClient.__enter__(self) if isinstance(localfile, (str, bytes, int)): ret = self.sftp.put(localfile, filepath, callback=callback) elif hasattr(localfile, 'read'): size = len(localfile.read()) localfile.seek(0) ret = self.sftp.putfo(localfile, filepath, file_size=size, callback=callback) except Exception as e: print ("upload failed ", str(e)) raise e finally: SSHClient.__exit__(self) if verbose: print (ret) def upload_files(self, localpath, remotepath, file_lst=["."], compression_level=1, callback=None): assert os.path.isdir(localpath) if not isinstance(file_lst, (tuple, list)): file_lst = [file_lst] files = ([os.path.join(root, f) for fp in file_lst for root,_,files in os.walk(os.path.join(localpath, fp )) for f in files] + [f for fp in file_lst for f in glob.glob(os.path.join(localpath, fp)) ]) files = set([os.path.abspath(f) for f in files]) compression_levels = {0:zipfile.ZIP_STORED, 1:zipfile.ZIP_DEFLATED, 2:zipfile.ZIP_BZIP2, 3:zipfile.ZIP_LZMA} with self.counter_lock: self.counter+=1 zn = 'tmp_%s_%04d.zip'%(id(self),self.counter) zipf = zipfile.ZipFile(zn, 'w', compression_levels[compression_level]) try: for f in files: zipf.write(f, os.path.relpath(f, localpath)) zipf.close() remote_zn = os.path.join(remotepath, zn).replace("\\","/") with self: self.execute("mkdir -p %s"%(remotepath)) self.upload(zn, remote_zn, callback=callback) self.execute("unzip %s -d %s && rm %s"%(remote_zn, remotepath, remote_zn)) except: print ("upload files failed", ) traceback.print_exc() raise finally: os.remove(zn) def download_files(self, remote_path, localpath, file_lst=["."], compression_level=1, callback=None): if not isinstance(file_lst, (tuple, list)): file_lst = [file_lst] file_lst = [f.replace("\\","/") for f in file_lst] with self.counter_lock: self.counter+=1 zn = 'tmp_%s_%04d.zip'%(id(self),self.counter) remote_zip = os.path.join(remote_path, zn).replace("\\","/") self.execute("cd %s && zip -r %s %s"%(remote_path, zn, " ".join(file_lst))) local_zip = os.path.join(localpath, zn) if not os.path.isdir(localpath): os.makedirs(localpath) self.download(remote_zip, local_zip, callback=callback) self.execute("rm -f %s" % remote_zip) with zipfile.ZipFile(local_zip, "r") as z: z.extractall(localpath) os.remove(local_zip) def close(self): for x in ["client", 'tunnel' ]: try: getattr(self, x).close() setattr(self, x, None) except: pass self.disconnect = False def file_exists(self, filename): _, out, _ = (self.execute('[ -f %s ] && echo "File exists" || echo "File does not exists"' % filename.replace("\\", "/"))) return out.strip() == "File exists" def execute(self, command, sudo=False, verbose=False): feed_password = False if sudo and self.username != "root": command = "sudo -S -p '' %s" % command feed_password = self.password is not None and len(self.password) > 0 if isinstance(command, (list, tuple)): command = "\n".join(command) if verbose: print (">>> " + command) with self as ssh: if ssh is None: exc_info = sys.exc_info() traceback.print_exception(*exc_info) raise Exception("ssh_client exe ssh is None") stdin, stdout, stderr = ssh.exec_command(command) if feed_password: stdin.write(self.password + "\n") stdin.flush() v, out, err = stdout.channel.recv_exit_status(), stdout.read().decode(), stderr.read().decode() if v: raise Warning ("out:\n%s\n----------\nerr:\n%s" % (out, err)) elif verbose: if out: sys.stdout.write(out) if err: sys.stderr.write(err) return v, out, err def append_wine_path(self, path): ret = self.execute('wine regedit /E tmp.reg "HKEY_LOCAL_MACHINE\System\CurrentControlSet\Control\Session Manager\Environment"') self.download('tmp.reg', 'tmp.reg') with open('tmp.reg') as fid: lines = fid.readlines() path_line = [l for l in lines if l.startswith('"PATH"=')][0] for p in path_line[8:-1].split(";"): if os.path.abspath(p) == os.path.abspath(p): return if path not in path_line: path_line = path_line.strip()[:-1] + ";" + path + '"' with open('tmp.reg', 'w') as fid: fid.write("".join(lines[:3] + [path_line])) self.upload('tmp.reg', 'tmp.reg') ret = self.execute('wine regedit tmp.reg') def glob(self, filepattern, cwd="", recursive=False): if isinstance(filepattern, list): with self: return [f for fp in filepattern for f in self.glob(fp, cwd, recursive)] cwd = os.path.join(cwd, os.path.split(filepattern)[0]).replace("\\", "/") filepattern = os.path.split(filepattern)[1] if recursive: _, out, _ = self.execute(r'find %s -type f -name "%s"' % (cwd, filepattern)) else: _, out, _ = self.execute(r'find %s -maxdepth 1 -type f -name "%s"' % (cwd, filepattern)) return [file for file in out.strip().split("\n") if file != ""] class SharedSSHClient(SSHClient): def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None): SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase, interactive_auth_handler=interactive_auth_handler, gateway=gateway) self.shared_ssh_queue = deque() self.next = None def execute(self, command, sudo=False, verbose=False): res = SSHClient.execute(self, command, sudo=sudo, verbose=verbose) return res def __enter__(self): with self.ssh_lock: SSHClient.__enter__(self) #print ("request SSH", threading.currentThread()) # if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): # # SSH already allocated to this thread ( multiple use-statements in "with ssh:" block # self.shared_ssh_queue.appendleft(threading.get_ident()) # else: # self.shared_ssh_queue.append(threading.get_ident()) if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): # SSH already allocated to this thread ( multiple use-statements in "with ssh:" block self.shared_ssh_queue.popleft() self.shared_ssh_queue.append(threading.get_ident()) while self.shared_ssh_queue[0] != threading.get_ident(): time.sleep(2) return self.client def __exit__(self, *args): with self.ssh_lock: if len(self.shared_ssh_queue)>0 and self.shared_ssh_queue[0] == threading.get_ident(): self.shared_ssh_queue.popleft() if __name__ == "__main__": from mmpe.ui.qt_ui import QtInputUI q = QtInputUI(None) x = None username, password = "mmpe", x.password #q.get_login("mmpe") client = SSHClient(host='gorm', port=22, username=username, password=password) print (client.glob("*.*", ".hawc2launcher/medium1__1__")) # ssh.upload('../News.txt', 'news.txt')