From 30eef00cb168e8e8b39483c2961f945f716a6d06 Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Fri, 20 Jan 2017 08:51:40 +0100
Subject: [PATCH] gateway and interactive_auth in ssh_client

---
 wetb/utils/cluster_tools/ssh_client.py | 108 +++++++++++++++++++++++--
 1 file changed, 101 insertions(+), 7 deletions(-)

diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py
index 39d2f7c8..7ba39fc4 100644
--- a/wetb/utils/cluster_tools/ssh_client.py
+++ b/wetb/utils/cluster_tools/ssh_client.py
@@ -15,17 +15,79 @@ import traceback
 import zipfile
 from wetb.utils.timing import print_time
 import glob
+import getpass
+from sshtunnel import SSHTunnelForwarder
+
+
+
+class SSHInteractiveAuthTunnelForwarder(SSHTunnelForwarder):
+    
+    def _connect_to_gateway(self):
+        """
+        Open connection to SSH gateway
+         - First try with all keys loaded from an SSH agent (if allowed)
+         - Then with those passed directly or read from ~/.ssh/config
+         - As last resort, try with a provided password
+        """
+        if self.ssh_password:  # avoid conflict using both pass and pkey
+            self.logger.debug('Trying to log in with password: {0}'
+                              .format('*' * len(self.ssh_password)))
+            try:
+                self._transport = self._get_transport()
+                if self.interactive_auth_gateway:
+                    self._transport.start_client()
+                    def interactive_handler(title, instructions, prompt_list):
+                        if prompt_list:
+                            if prompt_list[0][0]=="AD Password: ":
+                                import x
+                                return [x.mmpe]
+                            return [getpass.getpass(prompt_list[0][0])]
+                        print ("here")
+                        return []
+                    self._transport.auth_interactive("mmpe", interactive_handler)
+                else:
+                    self._transport.connect(hostkey=self.ssh_host_key,
+                                            username=self.ssh_username,
+                                            password=self.ssh_password)
+                
+                if self._transport.is_alive:
+                    return
+            except paramiko.AuthenticationException:
+                self.logger.debug('Authentication error')
+                self._stop_transport()
+#         
+#         try:
+#             self._transport = self._get_transport()
+#             self._transport.start_client()
+#             def interactive_handler(title, instructions, prompt_list):
+#                 if prompt_list:
+#                     if prompt_list[0][0]=="AD Password: ":
+#                         import x
+#                         return [x.mmpe]
+#                     return [getpass.getpass(prompt_list[0][0])]
+#                 print ("here")
+#                 return []
+#             self._transport.auth_interactive("mmpe", interactive_handler)
+#             if self._transport.is_alive:
+#                 return
+#         except paramiko.AuthenticationException:
+#             self.logger.debug('Authentication error')
+#             self._stop_transport()
+ 
+        self.logger.error('Could not open connection to gateway')
 
 class SSHClient(object):
     "A wrapper of paramiko.SSHClient"
     TIMEOUT = 4
 
-    def __init__(self, host, username, password=None, port=22, key=None, passphrase=None):
+    def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, gateway=None, interactive_auth_gateway=False):
         self.host = host
         self.username = username
         self.password = password
         self.port = port
         self.key = key
+        self.gateway=gateway
+        self.interactive_auth_gateway = interactive_auth_gateway
         self.disconnect = 0
         self.client = None
         self.sftp = None
@@ -50,11 +112,43 @@ class SSHClient(object):
         return self.client
 
     def connect(self):
-        if self.password is None or self.password == "":
-            raise IOError("Password not set for %s"%self.host)
-        self.client = paramiko.SSHClient()
-        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT)
+#        if self.password is None or self.password == "":
+#             raise IOError("Password not set for %s"%self.host)
+        if self.gateway:
+                          
+            self.tunnel = SSHInteractiveAuthTunnelForwarder(
+                (self.gateway, self.port),
+                ssh_username=self.username,
+                ssh_password=self.password,
+                remote_bind_address=(self.host, self.port),
+                local_bind_address=('0.0.0.0', 10022)
+            )
+            self.tunnel.interactive_auth_gateway = self.interactive_auth_gateway
+            self.tunnel.start()
+            self.client = paramiko.SSHClient()
+            self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            self.client.connect("127.0.0.1", 10022, username=self.username, password=self.password)
+
+                 
+        else:
+            self.client = paramiko.SSHClient()
+            self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            try:
+                self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT)
+            except paramiko.ssh_exception.SSHException as e:
+                transport = self.client.get_transport()
+                def interactive_handler(title, instructions, prompt_list):
+                    if prompt_list:
+                        if prompt_list[0][0]=="AD Password: ":
+                            import x
+                            return [x.mmpe]
+                        return [getpass.getpass(prompt_list[0][0])]
+                    print ("here")
+                    return []
+                transport.auth_interactive(self.username, interactive_handler)
+                
+        
+        
         assert self.client is not None
         self.sftp = paramiko.SFTPClient.from_transport(self.client._transport)
         return self
@@ -141,7 +235,7 @@ class SSHClient(object):
         
 
     def close(self):
-        for x in ["sftp", "client" ]:
+        for x in ["sftp", "client", 'tunnel' ]:
             try:
                 getattr(self, x).close()
                 setattr(self, x, None)
-- 
GitLab