diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py index d349c0fcd8166b56965d99cd796dae92c3f4dc2b..42124a2f0677db640f4d530b9c232c91a13b963b 100644 --- a/wetb/utils/cluster_tools/ssh_client.py +++ b/wetb/utils/cluster_tools/ssh_client.py @@ -142,12 +142,11 @@ class SSHClient(object): else: self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - if self.interactive_auth_handler: - transport = self.client.get_transport() - transport.auth_interactive(self.username, self.interactive_handler) - else: + try: self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) - + except paramiko.ssh_exception.SSHException as e: + transport = self.client.get_transport() + transport.auth_interactive(self.username, self.interactive_auth_handler) diff --git a/wetb/utils/tests/test_ssh_client.py b/wetb/utils/tests/test_ssh_client.py index d2591e5acf5f0197b5fc35cbcd8ad1632bcf58eb..6899d9f2ba4f8801cdd3cf6f41fd28a8725d2ea0 100644 --- a/wetb/utils/tests/test_ssh_client.py +++ b/wetb/utils/tests/test_ssh_client.py @@ -23,7 +23,17 @@ import logging import getpass - +class sshrisoe_interactive_auth_handler(object): + def __init__(self, password): + self.password = password + + def __call__(self, title, instructions, prompt_list): + if prompt_list: + if prompt_list[0][0]=="AD Password: ": + return [self.password] + return [getpass.getpass(prompt_list[0][0])] + return [] + tfp = os.path.join(os.path.dirname(__file__), 'test_files/') all = 0 class TestSSHClient(unittest.TestCase): @@ -111,35 +121,16 @@ class TestSSHClient(unittest.TestCase): # ssh = SSHClient('g-047', "mmpe", x.mmpe, gateway=gateway) # self.assertEqual(ssh.execute('hostname')[1].strip(), "g-047") -# def test_ssh_risoe(self): -# if x: -# class sshrisoe_interactive_auth_handler(object): -# def __init__(self, password): -# self.password = password -# -# def __call__(self, title, instructions, prompt_list): -# if prompt_list: -# if prompt_list[0][0]=="AD Password: ": -# return [self.password] -# return [getpass.getpass(prompt_list[0][0])] -# return [] -# -# ssh = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) -# _,out,_ = ssh.execute("hostname") -# self.assertEqual(out.strip(), "ssh-03.risoe.dk") + def test_ssh_risoe(self): + if x: + + ssh = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) + _,out,_ = ssh.execute("hostname") + self.assertEqual(out.strip(), "ssh-03.risoe.dk") def test_ssh_risoe_gorm(self): if x: - class sshrisoe_interactive_auth_handler(object): - def __init__(self, password): - self.password = password - - def __call__(self, title, instructions, prompt_list): - if prompt_list: - if prompt_list[0][0]=="AD Password: ": - return [self.password] - return [getpass.getpass(prompt_list[0][0])] - return [] + gateway = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) ssh = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe, gateway = gateway) _,out,_ = ssh.execute("hostname")