diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py index 39d2f7c81fc5ebcc0dcb14f3b5b60f4873c79cf0..7ba39fc4fa76c0a13018a902cc04ec3d4f083191 100644 --- a/wetb/utils/cluster_tools/ssh_client.py +++ b/wetb/utils/cluster_tools/ssh_client.py @@ -15,17 +15,79 @@ import traceback import zipfile from wetb.utils.timing import print_time import glob +import getpass +from sshtunnel import SSHTunnelForwarder + + + +class SSHInteractiveAuthTunnelForwarder(SSHTunnelForwarder): + + def _connect_to_gateway(self): + """ + Open connection to SSH gateway + - First try with all keys loaded from an SSH agent (if allowed) + - Then with those passed directly or read from ~/.ssh/config + - As last resort, try with a provided password + """ + if self.ssh_password: # avoid conflict using both pass and pkey + self.logger.debug('Trying to log in with password: {0}' + .format('*' * len(self.ssh_password))) + try: + self._transport = self._get_transport() + if self.interactive_auth_gateway: + self._transport.start_client() + def interactive_handler(title, instructions, prompt_list): + if prompt_list: + if prompt_list[0][0]=="AD Password: ": + import x + return [x.mmpe] + return [getpass.getpass(prompt_list[0][0])] + print ("here") + return [] + self._transport.auth_interactive("mmpe", interactive_handler) + else: + self._transport.connect(hostkey=self.ssh_host_key, + username=self.ssh_username, + password=self.ssh_password) + + if self._transport.is_alive: + return + except paramiko.AuthenticationException: + self.logger.debug('Authentication error') + self._stop_transport() +# +# try: +# self._transport = self._get_transport() +# self._transport.start_client() +# def interactive_handler(title, instructions, prompt_list): +# if prompt_list: +# if prompt_list[0][0]=="AD Password: ": +# import x +# return [x.mmpe] +# return [getpass.getpass(prompt_list[0][0])] +# print ("here") +# return [] +# self._transport.auth_interactive("mmpe", interactive_handler) +# if self._transport.is_alive: +# return +# except paramiko.AuthenticationException: +# self.logger.debug('Authentication error') +# self._stop_transport() + + self.logger.error('Could not open connection to gateway') class SSHClient(object): "A wrapper of paramiko.SSHClient" TIMEOUT = 4 - def __init__(self, host, username, password=None, port=22, key=None, passphrase=None): + def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, gateway=None, interactive_auth_gateway=False): self.host = host self.username = username self.password = password self.port = port self.key = key + self.gateway=gateway + self.interactive_auth_gateway = interactive_auth_gateway self.disconnect = 0 self.client = None self.sftp = None @@ -50,11 +112,43 @@ class SSHClient(object): return self.client def connect(self): - if self.password is None or self.password == "": - raise IOError("Password not set for %s"%self.host) - self.client = paramiko.SSHClient() - self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) +# if self.password is None or self.password == "": +# raise IOError("Password not set for %s"%self.host) + if self.gateway: + + self.tunnel = SSHInteractiveAuthTunnelForwarder( + (self.gateway, self.port), + ssh_username=self.username, + ssh_password=self.password, + remote_bind_address=(self.host, self.port), + local_bind_address=('0.0.0.0', 10022) + ) + self.tunnel.interactive_auth_gateway = self.interactive_auth_gateway + self.tunnel.start() + self.client = paramiko.SSHClient() + self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password) + + + else: + self.client = paramiko.SSHClient() + self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) + except paramiko.ssh_exception.SSHException as e: + transport = self.client.get_transport() + def interactive_handler(title, instructions, prompt_list): + if prompt_list: + if prompt_list[0][0]=="AD Password: ": + import x + return [x.mmpe] + return [getpass.getpass(prompt_list[0][0])] + print ("here") + return [] + transport.auth_interactive(self.username, interactive_handler) + + + assert self.client is not None self.sftp = paramiko.SFTPClient.from_transport(self.client._transport) return self @@ -141,7 +235,7 @@ class SSHClient(object): def close(self): - for x in ["sftp", "client" ]: + for x in ["sftp", "client", 'tunnel' ]: try: getattr(self, x).close() setattr(self, x, None)