diff --git a/wetb/utils/cluster_tools/ssh_client.py b/wetb/utils/cluster_tools/ssh_client.py index c3db5c428da8f78d674bbdc151dae75e62099712..d349c0fcd8166b56965d99cd796dae92c3f4dc2b 100644 --- a/wetb/utils/cluster_tools/ssh_client.py +++ b/wetb/utils/cluster_tools/ssh_client.py @@ -16,12 +16,38 @@ import zipfile from wetb.utils.timing import print_time import glob import getpass -from sshtunnel import SSHTunnelForwarder +from sshtunnel import SSHTunnelForwarder, SSH_CONFIG_FILE class SSHInteractiveAuthTunnelForwarder(SSHTunnelForwarder): - pass + def __init__( + self, + interactive_auth_handler, + ssh_address_or_host=None, + ssh_config_file=SSH_CONFIG_FILE, + ssh_host_key=None, + ssh_password=None, + ssh_pkey=None, + ssh_private_key_password=None, + ssh_proxy=None, + ssh_proxy_enabled=True, + ssh_username=None, + local_bind_address=None, + local_bind_addresses=None, + logger=None, + mute_exceptions=False, + remote_bind_address=None, + remote_bind_addresses=None, + set_keepalive=0.0, + threaded=True, + compression=None, + allow_agent=True, * + args, ** + kwargs): + self.interactive_auth_handler = interactive_auth_handler + SSHTunnelForwarder.__init__(self, ssh_address_or_host=ssh_address_or_host, ssh_config_file=ssh_config_file, ssh_host_key=ssh_host_key, ssh_password=ssh_password, ssh_pkey=ssh_pkey, ssh_private_key_password=ssh_private_key_password, ssh_proxy=ssh_proxy, ssh_proxy_enabled=ssh_proxy_enabled, ssh_username=ssh_username, local_bind_address=local_bind_address, local_bind_addresses=local_bind_addresses, logger=logger, mute_exceptions=mute_exceptions, remote_bind_address=remote_bind_address, remote_bind_addresses=remote_bind_addresses, set_keepalive=set_keepalive, threaded=threaded, compression=compression, allow_agent=allow_agent, *args, **kwargs) + def _connect_to_gateway(self): """ @@ -37,15 +63,7 @@ class SSHInteractiveAuthTunnelForwarder(SSHTunnelForwarder): self._transport = self._get_transport() if self.interactive_auth_handler: self._transport.start_client() - def interactive_handler(title, instructions, prompt_list): - if prompt_list: - if prompt_list[0][0]=="AD Password: ": - import x - return [x.mmpe] - return [getpass.getpass(prompt_list[0][0])] - print ("here") - return [] - self._transport.auth_interactive("mmpe", interactive_handler) + self._transport.auth_interactive(self.ssh_username, self.interactive_auth_handler) else: self._transport.connect(hostkey=self.ssh_host_key, username=self.ssh_username, @@ -64,7 +82,7 @@ class SSHClient(object): "A wrapper of paramiko.SSHClient" TIMEOUT = 4 - def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, gateway=None, interactive_auth_handler=None): + def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None): self.host = host self.username = username self.password = password @@ -99,15 +117,22 @@ class SSHClient(object): # if self.password is None or self.password == "": # raise IOError("Password not set for %s"%self.host) if self.gateway: - - self.tunnel = SSHInteractiveAuthTunnelForwarder( - (self.gateway.host, self.gateway.port), - ssh_username=self.gateway.username, - ssh_password=self.gateway.password, - remote_bind_address=(self.host, self.port), - local_bind_address=('0.0.0.0', 10022) - ) - self.tunnel.interactive_auth_handler = self.gateway.interactive_auth_handler + if self.gateway.interactive_auth_handler: + self.tunnel = SSHInteractiveAuthTunnelForwarder(self.gateway.interactive_auth_handler, + (self.gateway.host, self.gateway.port), + ssh_username=self.gateway.username, + ssh_password=self.gateway.password, + remote_bind_address=(self.host, self.port), + local_bind_address=('0.0.0.0', 10022) + ) + else: + self.tunnel = SSHTunnelForwarder((self.gateway.host, self.gateway.port), + ssh_username=self.gateway.username, + ssh_password=self.gateway.password, + remote_bind_address=(self.host, self.port), + local_bind_address=('0.0.0.0', 10022) + ) + self.tunnel.start() self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -117,19 +142,12 @@ class SSHClient(object): else: self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - try: - self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) - except paramiko.ssh_exception.SSHException as e: + if self.interactive_auth_handler: transport = self.client.get_transport() - def interactive_handler(title, instructions, prompt_list): - if prompt_list: - if prompt_list[0][0]=="AD Password: ": - import x - return [x.mmpe] - return [getpass.getpass(prompt_list[0][0])] - print ("here") - return [] - transport.auth_interactive(self.username, interactive_handler) + transport.auth_interactive(self.username, self.interactive_handler) + else: + self.client.connect(self.host, self.port, username=self.username, password=self.password, pkey=self.key, timeout=self.TIMEOUT) + @@ -294,8 +312,8 @@ class SSHClient(object): class SharedSSHClient(SSHClient): - def __init__(self, host, username, password=None, port=22, key=None, passphrase=None): - SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase) + def __init__(self, host, username, password=None, port=22, key=None, passphrase=None, interactive_auth_handler=None, gateway=None): + SSHClient.__init__(self, host, username, password=password, port=port, key=key, passphrase=passphrase, interactive_auth_handler=interactive_auth_handler, gateway=gateway) self.shared_ssh_lock = threading.RLock() self.shared_ssh_queue = deque() self.next = None @@ -327,6 +345,8 @@ class SharedSSHClient(SSHClient): else: self.next = None + + if __name__ == "__main__": from mmpe.ui.qt_ui import QtInputUI q = QtInputUI(None) diff --git a/wetb/utils/tests/test_ssh_client.py b/wetb/utils/tests/test_ssh_client.py new file mode 100644 index 0000000000000000000000000000000000000000..90ad55e2e65393d63813a3f4ea979448679dbba6 --- /dev/null +++ b/wetb/utils/tests/test_ssh_client.py @@ -0,0 +1,206 @@ +''' +Created on 23. dec. 2016 + +@author: mmpe +''' +import unittest +from wetb.utils.cluster_tools.ssh_client import SSHClient +import sys +import os +try: + import x +except: + x=None + +import io +from wetb.utils.timing import print_time +import shutil +import paramiko +from paramiko.message import Message +from paramiko.common import cMSG_SERVICE_REQUEST +import getpass +import logging + +import getpass + + +tfp = os.path.join(os.path.dirname(__file__), 'test_files/') +all = 0 +class TestSSHClient(unittest.TestCase): + + def setUp(self): + if x: + self.ssh = SSHClient('gorm', 'mmpe',x.mmpe ) + + + def test_execute(self): + if 0 or all: + if x: + _,out,_ = self.ssh.execute("ls -a") + ssh_ls = ";".join(sorted(out.split("\n"))[3:]) #Exclude ['', '.', '..'] + win_ls = ";".join(sorted(os.listdir(r"z:"))) + self.assertEqual(ssh_ls, win_ls) + + def test_file_transfer(self): + if 0 or all: + if x: + self.ssh.execute("rm -f tmp.txt") + io.StringIO() + + txt = "Hello world" + f = io.StringIO(txt) + f.seek(0) + self.ssh.upload(f, "tmp.txt") + _,out,_ = self.ssh.execute("cat tmp.txt") + self.assertEqual(out, txt) + fn = tfp + "tmp.txt" + if os.path.isfile (fn): + os.remove(fn) + self.assertFalse(os.path.isfile(fn)) + self.ssh.download("tmp.txt", fn) + with open(fn) as fid: + self.assertEqual(fid.read(), txt) + + + def test_folder_transfer(self): + if 0 or all: + if x: + p = r"C:\mmpe\HAWC2\models\version_12.3beta/" + p = r'C:\mmpe\programming\python\WindEnergyToolbox\wetb\hawc2\tests\test_files\simulation_setup\DTU10MWRef6.0_IOS/' + self.ssh.execute("rm -r -f ./tmp_test") + self.ssh.upload_files(p, "./tmp_test", ["input/"]) + shutil.rmtree("./test/input", ignore_errors=True) + self.ssh.download_files("./tmp_test", tfp, "input/" ) + os.path.isfile(tfp + "/input/data/DTU_10MW_RWT_Blade_st.dat") + shutil.rmtree("./test/input", ignore_errors=True) + + + def test_folder_transfer_specific_files_uppercase(self): + if 0 or all: + if x: + p = tfp + files = [os.path.join(tfp, "TEST.txt")] + self.ssh.execute("rm -r -f ./tmp_test") + self.ssh.upload_files(p, "./tmp_test", file_lst=files) + self.assertFalse(self.ssh.file_exists("./tmp_test/test.txt")) + self.assertTrue(self.ssh.file_exists("./tmp_test/TEST.txt")) + + + + def test_folder_transfer_specific_files(self): + if 0 or all: + if x: + p = r"C:\mmpe\HAWC2\models\version_12.3beta/" + p = r'C:\mmpe\programming\python\WindEnergyToolbox\wetb\hawc2\tests\test_files\simulation_setup\DTU10MWRef6.0_IOS/' + files = [os.path.join(os.path.relpath(root, p), f) for root,_,files in os.walk(p+"input/") for f in files] + self.ssh.execute("rm -r -f ./tmp_test") + self.ssh.upload_files(p, "./tmp_test", file_lst=files[:5]) + self.ssh.download_files("./tmp_test", tfp + "tmp/", file_lst = files[:3]) + self.assertEqual(len(os.listdir(tfp+"tmp/input/data/")),2) + shutil.rmtree(tfp + "tmp/") + +# def test_ssh_gorm(self): +# if x: +# ssh = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe) +# _,out,_ = ssh.execute("hostname") +# self.assertEqual(out.strip(), "g-000.risoe.dk") + +# def test_ssh_g047(self): +# if x: +# gateway = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe) +# ssh = SSHClient('g-047', "mmpe", x.mmpe, gateway=gateway) +# self.assertEqual(ssh.execute('hostname')[1].strip(), "g-047") + +# def test_ssh_risoe(self): +# if x: +# class sshrisoe_interactive_auth_handler(object): +# def __init__(self, password): +# self.password = password +# +# def __call__(self, title, instructions, prompt_list): +# if prompt_list: +# if prompt_list[0][0]=="AD Password: ": +# return [self.password] +# return [getpass.getpass(prompt_list[0][0])] +# return [] +# +# ssh = SSHClient('ssh.risoe.dk', 'mmpe', interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) +# _,out,_ = ssh.execute("hostname") +# self.assertEqual(out.strip(), "ssh-03.risoe.dk") + + def test_ssh_risoe_gorm(self): + if x: + class sshrisoe_interactive_auth_handler(object): + def __init__(self, password): + self.password = password + + def __call__(self, title, instructions, prompt_list): + if prompt_list: + if prompt_list[0][0]=="AD Password: ": + return [self.password] + return [getpass.getpass(prompt_list[0][0])] + return [] + gateway = SSHClient('ssh.risoe.dk', 'mmpe', x.mmpe, interactive_auth_handler = sshrisoe_interactive_auth_handler(x.mmpe)) + ssh = SSHClient('gorm.risoe.dk', 'mmpe', x.mmpe, gateway = gateway) + _,out,_ = ssh.execute("hostname") + self.assertEqual(out.strip(), "g-000.risoe.dk") + + + +# def test_ssh_risoe(self): +# #logger = logging.getLogger("paramiko") +# #logger.setLevel(logging.DEBUG) # for example +# #ch = logging.StreamHandler(sys.stdout) +# #ch.setLevel(logging.DEBUG) +# #formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +# #ch.setFormatter(formatter) +# #logger.addHandler(ch) +# ssh = SSHClient('ssh.risoe.dk', 'mmpe') +# print (ssh.connect()) +# return +# username = "mmpe" +# +# client = paramiko.client.SSHClient() +# +# # Any means of getting the PKey will do. This code assumes you've only got one key loaded in your active ssh-agent. +# # See also: +# # - http://docs.paramiko.org/en/1.17/api/keys.html#paramiko.pkey.PKey +# # - http://docs.paramiko.org/en/1.17/api/client.html#paramiko.client.SSHClient.connect +# my_pkey = None #paramiko.agent.Agent().get_keys()[0] +# +# try: +# client.connect( +# hostname="ssh.risoe.dk", +# port=22, +# username=username, +# look_for_keys=False, +# pkey=my_pkey +# ) +# except paramiko.ssh_exception.SSHException as e: +# pass +# +# transport = client.get_transport() +# +# # Sometimes sshd is configured to use 'keyboard-interactive' instead of 'password' to implement the YubiKey challenge. +# # In that case, you can use something like this. +# # The code below assumes the server will only ask one question and expect the YubiKey OTP as an answer. +# # If there's more questions to answer, you should handle those per the docs at: +# # http://docs.paramiko.org/en/1.17/api/transport.html#paramiko.transport.Transport.auth_interactive +# # +# def interactive_handler(title, instructions, prompt_list): +# if prompt_list: +# if prompt_list[0][0]=="AD Password: ": +# return [x.mmpe] +# return [getpass.getpass(prompt_list[0][0])] +# print ("here") +# return [] +# transport.auth_interactive(username, interactive_handler) +# +# #transport.auth_password(username, x.mmpe) +# +# # You should now be able to use client as the authenticated user. +# client.exec_command("echo hej") +# +if __name__ == "__main__": + #import sys;sys.argv = ['', 'Test.testName'] + unittest.main() \ No newline at end of file