From 8ef952ee475748ee77b6f4168ac6190a716751be Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Thu, 26 Jan 2017 08:43:07 +0100
Subject: [PATCH] extracted class interactive_auth_handler in
 test_ssh_client.py

---
 wetb/utils/cluster_tools/ssh_client.py |  9 +++---
 wetb/utils/tests/test_ssh_client.py    | 45 +++++++++++---------------
 2 files changed, 22 insertions(+), 32 deletions(-)

diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py
index d349c0f..42124a2 100644
--- a/wetb/utils/cluster_tools/ssh_client.py
+++ b/wetb/utils/cluster_tools/ssh_client.py
@@ -142,12 +142,11 @@ class SSHClient(object):
         else:
             self.client = paramiko.SSHClient()
             self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-            if self.interactive_auth_handler:
-                transport = self.client.get_transport()
-                transport.auth_interactive(self.username, self.interactive_handler)
-            else:
+            try:
                 self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT)
-            
+            except paramiko.ssh_exception.SSHException as e:
+                transport = self.client.get_transport()
+                transport.auth_interactive(self.username, self.interactive_auth_handler)
                 
         
         
diff --git a/wetb/utils/tests/test_ssh_client.py b/wetb/utils/tests/test_ssh_client.py
index d2591e5..6899d9f 100644
--- a/wetb/utils/tests/test_ssh_client.py
+++ b/wetb/utils/tests/test_ssh_client.py
@@ -23,7 +23,17 @@ import logging
 
 import getpass
 
-
+class sshrisoe_interactive_auth_handler(object):
+    def __init__(self, password):
+        self.password = password
+              
+    def __call__(self, title, instructions, prompt_list):
+        if prompt_list:
+            if prompt_list[0][0]=="AD Password: ":
+                return [self.password]
+            return [getpass.getpass(prompt_list[0][0])]
+        return []
+    
 tfp = os.path.join(os.path.dirname(__file__), 'test_files/')
 all = 0
 class TestSSHClient(unittest.TestCase):
@@ -111,35 +121,16 @@ class TestSSHClient(unittest.TestCase):
 #             ssh = SSHClient('g-047', "mmpe", x.mmpe, gateway=gateway)
 #             self.assertEqual(ssh.execute('hostname')[1].strip(), "g-047")
 
-#     def test_ssh_risoe(self):
-#         if x:
-#             class sshrisoe_interactive_auth_handler(object):
-#                 def __init__(self, password):
-#                     self.password = password
-#                         
-#                 def __call__(self, title, instructions, prompt_list):
-#                     if prompt_list:
-#                         if prompt_list[0][0]=="AD Password: ":
-#                             return [self.password]
-#                         return [getpass.getpass(prompt_list[0][0])]
-#                     return []
-#             
-#             ssh = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe))
-#             _,out,_ = ssh.execute("hostname")
-#             self.assertEqual(out.strip(), "ssh-03.risoe.dk")
+    def test_ssh_risoe(self):
+        if x:
+            
+            ssh = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe))
+            _,out,_ = ssh.execute("hostname")
+            self.assertEqual(out.strip(), "ssh-03.risoe.dk")
 
     def test_ssh_risoe_gorm(self):
         if x:
-            class sshrisoe_interactive_auth_handler(object):
-                def __init__(self, password):
-                    self.password = password
-                         
-                def __call__(self, title, instructions, prompt_list):
-                    if prompt_list:
-                        if prompt_list[0][0]=="AD Password: ":
-                            return [self.password]
-                        return [getpass.getpass(prompt_list[0][0])]
-                    return []
+
             gateway = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe))
             ssh = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe, gateway = gateway)
             _,out,_ = ssh.execute("hostname")
-- 
GitLab