#!/usr/bin/env python
# coding: utf-8
from dpdispatcher.base_context import BaseContext
import os
import paramiko
import tarfile
import time
import uuid
import shutil
from functools import lru_cache
from glob import glob
from dpdispatcher import dlog
from dargs.dargs import Argument
from typing import List
import pathlib
import socket
# from dpdispatcher.submission import Machine
from dpdispatcher.utils import (
get_sha256, generate_totp, rsync,
retry, RetrySignal,
)
class SSHSession(object):
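    """Maintain a paramiko SSH connection to a remote machine.

    A minimal usage sketch (the hostname, username and key path below are
    placeholders, not defaults of this class)::

        session = SSHSession(hostname="login.example.com", username="alice",
                             key_filename="/home/alice/.ssh/id_rsa")
        session.ensure_alive()
        stdin, stdout, stderr = session.exec_command("echo hello")
        session.close()
    """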
def __init__(self,
hostname,
username,
password=None,
port=22,
key_filename=None,
passphrase=None,
timeout=10,
totp_secret=None,
tar_compress=True
):
self.hostname = hostname
self.username = username
self.password = password
self.port = port
self.key_filename = key_filename
self.passphrase = passphrase
self.timeout = timeout
self.totp_secret = totp_secret
self.ssh = None
self.tar_compress = tar_compress
self._setup_ssh()
    def ensure_alive(self,
max_check = 10,
sleep_time = 10):
count = 1
while not self._check_alive():
if count == max_check:
                raise RuntimeError('cannot connect to SSH server after %d failures at an interval of %d s' %
                                   (max_check, sleep_time))
            dlog.info('connection check failed, try to reconnect to ' + self.hostname)
self._setup_ssh()
count += 1
time.sleep(sleep_time)
def _check_alive(self):
if self.ssh is None:
return False
        try:
transport = self.ssh.get_transport()
transport.send_ignore()
return True
except EOFError:
return False
    def _setup_ssh(self):
        self.ssh = paramiko.SSHClient()
        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect((self.hostname, self.port))
#Make a Paramiko Transport object using the socket
ts = paramiko.Transport(sock)
ts.banner_timeout = 60
ts.use_compression(compress=True)
#Tell Paramiko that the Transport is going to be used as a client
ts.start_client(timeout=self.timeout)
#Begin authentication; note that the username and callback are passed
if self.totp_secret:
ts.auth_interactive(self.username, self.inter_handler)
else:
key_path = os.path.join(os.path.expanduser("~"), ".ssh", "id_rsa")
if self.key_filename:
key_path = os.path.abspath(self.key_filename)
key = None
if os.path.exists(key_path):
try:
key = paramiko.RSAKey.from_private_key_file(key_path)
except paramiko.PasswordRequiredException:
key = paramiko.RSAKey.from_private_key_file(key_path, self.passphrase)
if key:
try:
ts.auth_publickey(self.username, key)
except paramiko.ssh_exception.AuthenticationException:
if self.password:
ts.auth_password(self.username, self.password)
else:
raise RuntimeError("Authentication failed, try to provide password")
elif self.password:
ts.auth_password(self.username, self.password)
else:
raise RuntimeError("Please provide at least one form of authentication")
assert(ts.is_active())
#Opening a session creates a channel along the socket to the server
ts.open_session(timeout=self.timeout)
ts.set_keepalive(60)
self.ssh._transport = ts
# reset sftp
self._sftp = None
    def inter_handler(self, title, instructions, prompt_list):
        """
        inter_handler: the callback for paramiko.transport.auth_interactive

        The prototype for this function is defined by Paramiko, so all of the
        arguments need to be there, even though we don't use 'title' or
        'instructions'.

        The function is expected to return a tuple of data containing the
        responses to the provided prompts. Experimental results suggest that
        there will be one call of this function per prompt, but the mechanism
        allows for multiple prompts to be sent at once, so it's best to assume
        that this can happen.

        Since tuples can't really be built on the fly, the responses are
        collected in a list which is then converted to a tuple when it's time
        to return a value.

        Experiments suggest that the username prompt never happens. This makes
        sense, but the username prompt is included here just in case.
        """
resp = [] #Initialize the response container
#Walk the list of prompts that the server sent that we need to answer
for pr in prompt_list:
            # str() is used to make sure that we're dealing with a plain string rather than a unicode string
            # strip() is used to get rid of any padding spaces sent by the server
pr_str = str(pr[0]).strip().lower()
if "username" in pr_str:
resp.append(self.username)
elif "password" in pr_str:
resp.append(self.password)
elif "verification" in pr_str and self.totp_secret:
resp.append(generate_totp(self.totp_secret))
return tuple(resp) #Convert the response list to a tuple and return it
    def get_ssh_client(self):
return self.ssh
    def close(self):
self.ssh.close()
@retry(sleep=1)
def exec_command(self, cmd):
"""Calling self.ssh.exec_command but has an exception check."""
try:
return self.ssh.exec_command(cmd)
except (paramiko.ssh_exception.SSHException, socket.timeout) as e:
# SSH session not active
# retry for up to 3 times
# ensure alive
self.ensure_alive()
raise RetrySignal("SSH session not active in calling %s" % cmd) from e
@property
def sftp(self):
"""Returns sftp. Open a new one if not existing."""
if self._sftp is None:
self.ensure_alive()
self._sftp = self.ssh.open_sftp()
return self._sftp
    @staticmethod
def arginfo():
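        """Generate the argument specification of the ssh_session remote profile.

        A dict matching this specification looks roughly like the following
        sketch (the hostname and username are placeholders, not defaults)::

            {
                "hostname": "login.example.com",
                "username": "alice",
                "port": 22,
                "key_filename": None,
                "timeout": 10,
                "tar_compress": True,
            }

        Returns
        -------
        Argument
            argument specification of the ssh_session dict
        """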
doc_hostname = 'hostname or ip of ssh connection.'
doc_username = 'username of target linux system'
doc_password = ('(deprecated) password of linux system. Please use '
'`SSH keys <https://www.ssh.com/academy/ssh/key>`_ instead to improve security.')
doc_port = 'ssh connection port.'
doc_key_filename = 'key filename used by ssh connection. If left None, find key in ~/.ssh or ' \
'use password for login'
doc_passphrase = 'passphrase of key used by ssh connection'
doc_timeout = 'timeout of ssh connection'
        doc_totp_secret = 'Time-based one-time password secret. It should be a base32-encoded string' \
                          ' extracted from the 2D (QR) code.'
doc_tar_compress = 'The archive will be compressed in upload and download if it is True. If not, compression will be skipped.'
ssh_remote_profile_args = [
Argument("hostname", str, optional=False, doc=doc_hostname),
Argument("username", str, optional=False, doc=doc_username),
Argument("password", str, optional=True, doc=doc_password),
Argument("port", int, optional=True, default=22, doc=doc_port),
Argument("key_filename", [str, None], optional=True, default=None, doc=doc_key_filename),
Argument("passphrase", [str, None], optional=True, default=None, doc=doc_passphrase),
Argument("timeout", int, optional=True, default=10, doc=doc_timeout),
Argument("totp_secret", str, optional=True, default=None, doc=doc_totp_secret),
Argument("tar_compress", bool, optional=True, default=True, doc = doc_tar_compress),
]
ssh_remote_profile_format = Argument("ssh_session", dict, ssh_remote_profile_args)
return ssh_remote_profile_format
    def put(self, from_f, to_f):
if self.rsync_available:
return rsync(from_f, self.remote + ":" + to_f)
return self.sftp.put(from_f, to_f)
    def get(self, from_f, to_f):
if self.rsync_available:
return rsync(self.remote + ":" + from_f, to_f)
return self.sftp.get(from_f, to_f)
@property
@lru_cache(maxsize=None)
def rsync_available(self) -> bool:
return (shutil.which("rsync") is not None and self.password is None
and self.port == 22 and self.key_filename is None
and self.passphrase is None)
@property
def remote(self) -> str:
return "%s@%s" % (self.username, self.hostname)
class SSHContext(BaseContext):
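    """Execution context that uploads, runs and downloads jobs over SSH.

    A minimal construction sketch (the paths and remote profile below are
    placeholders)::

        context = SSHContext(
            local_root="/home/alice/work",
            remote_root="/scratch/alice/dpdispatcher",
            remote_profile={"hostname": "login.example.com", "username": "alice"},
        )
    """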
    def __init__(self,
local_root,
remote_root,
remote_profile,
clean_asynchronously=False,
*args,
**kwargs,
):
        assert isinstance(local_root, str)
self.init_local_root = local_root
self.init_remote_root = remote_root
self.temp_local_root = os.path.abspath(local_root)
        assert os.path.isabs(remote_root), "remote_root must be an absolute path"
self.temp_remote_root = remote_root
self.remote_profile = remote_profile
        self.clean_asynchronously = clean_asynchronously
self.ssh_session = SSHSession(**remote_profile)
self.ssh_session.ensure_alive()
try:
self.sftp.mkdir(self.temp_remote_root)
except OSError:
pass
    @classmethod
def load_from_dict(cls, context_dict):
local_root = context_dict['local_root']
remote_root = context_dict['remote_root']
remote_profile = context_dict['remote_profile']
clean_asynchronously = context_dict.get('clean_asynchronously', False)
ssh_context = cls(
local_root=local_root,
remote_root=remote_root,
remote_profile=remote_profile,
clean_asynchronously=clean_asynchronously
)
return ssh_context
@property
def ssh(self):
return self.ssh_session.get_ssh_client()
@property
def sftp(self):
return self.ssh_session.sftp
    def close(self):
self.ssh_session.close()
    def get_job_root(self):
return self.remote_root
    def bind_submission(self, submission):
self.submission = submission
self.local_root = pathlib.PurePath(os.path.join(self.temp_local_root, submission.work_base)).as_posix()
self.remote_root = pathlib.PurePath(os.path.join(self.temp_remote_root, self.submission.submission_hash)).as_posix()
        try:
            self.sftp.mkdir(self.remote_root)
        except OSError:
            pass
def _walk_directory(self, files, work_path, file_list, directory_list):
"""Convert input path to list of files and directories."""
for jj in files :
file_name = os.path.join(work_path, jj)
if os.path.isfile(file_name):
file_list.append(file_name)
elif os.path.isdir(file_name):
for root, dirs, files in os.walk(file_name, topdown=False, followlinks=True):
if not files:
directory_list.append(root)
for name in files:
file_list.append(os.path.join(root, name))
elif glob(file_name):
# If the file name contains a wildcard, os.path functions will fail to identify it. Use glob to get the complete list of filenames which match the wildcard.
abs_file_list = glob(file_name)
rel_file_list = [os.path.relpath(ii, start=work_path) for ii in abs_file_list]
self._walk_directory(rel_file_list, work_path, file_list, directory_list)
else:
raise RuntimeError(f'cannot find upload file {work_path} {jj}')
    def upload(self,
               submission,
               dereference=True):
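        """Upload the forward files of a submission to the remote root.

        The forward files of every task and the forward common files are packed
        into a single tar archive, transferred with sftp (or rsync when
        available) and extracted on the remote side. If the remote directory
        already contains files, sha256 checksums are compared so that only
        changed files are re-uploaded.
        """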
dlog.info(f'remote path: {self.remote_root}')
self.ssh_session.sftp.chdir(self.temp_remote_root)
recover = False
try:
self.ssh_session.sftp.mkdir(os.path.basename(self.remote_root))
except OSError:
            # mkdir failed, meaning the directory already exists
if len(self.ssh_session.sftp.listdir(os.path.basename(self.remote_root))):
recover = True
self.ssh_session.sftp.chdir(None)
cwd = os.getcwd()
os.chdir(self.local_root)
file_list = []
directory_list = []
for task in submission.belonging_tasks:
directory_list.append(task.task_work_path)
# file_list.append(ii)
self._walk_directory(task.forward_files, task.task_work_path, file_list, directory_list)
self._walk_directory(submission.forward_common_files, self.local_root, file_list, directory_list)
# check if the same file exists on the remote file
# only check sha256 when the job is recovered
if recover:
# generate local sha256 file
sha256_list = []
for jj in file_list:
sha256 = get_sha256(jj)
jj_rel = pathlib.PurePath(os.path.relpath(jj, self.local_root)).as_posix()
sha256_list.append(f"{sha256} {jj_rel}")
# write to remote
sha256_file = os.path.join(self.remote_root, ".tmp.sha256." + str(uuid.uuid4()))
self.write_file(sha256_file, "\n".join(sha256_list))
# check sha256
# `:` means pass: https://stackoverflow.com/a/2421592/9567349
_, stdout, _ = self.block_checkcall("sha256sum -c %s --quiet >.sha256sum_stdout 2>/dev/null || :" % sha256_file)
self.sftp.remove(sha256_file)
# regenerate file list
file_list = []
for ii in self.read_file(".sha256sum_stdout").split("\n"):
if ii:
file_list.append(ii.split(":")[0])
else:
# convert to relative path to local_root
file_list = [os.path.relpath(jj, self.local_root) for jj in file_list]
self._put_files(file_list, dereference = dereference, directories=directory_list, tar_compress = self.remote_profile.get('tar_compress', None))
os.chdir(cwd)
    def download(self,
                 submission,
                 check_exists=False,
                 mark_failure=True,
                 back_error=False):
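        """Download the backward files of a submission from the remote root.

        The backward files of every task and the backward common files are
        packed into a tar archive on the remote side, transferred back and
        extracted under the local root. When ``check_exists`` is True, missing
        files are skipped or, if ``mark_failure`` is also True, marked locally
        with an empty ``tag_failure_download_*`` file.
        """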
self.ssh_session.ensure_alive()
cwd = os.getcwd()
os.chdir(self.local_root)
file_list = []
# for ii in job_dirs :
for task in submission.belonging_tasks :
for jj in task.backward_files:
file_name = pathlib.PurePath(os.path.join(task.task_work_path, jj)).as_posix()
if check_exists:
if self.check_file_exists(file_name):
file_list.append(file_name)
elif mark_failure :
with open(os.path.join(self.local_root, task.task_work_path, 'tag_failure_download_%s' % jj), 'w') as fp: pass
else:
pass
else:
file_list.append(file_name)
if back_error:
errors=glob(os.path.join(task.task_work_path, 'error*'))
file_list.extend(errors)
file_list.extend(submission.backward_common_files)
if len(file_list) > 0:
self._get_files(file_list, tar_compress = self.remote_profile.get('tar_compress', None))
os.chdir(cwd)
    def block_checkcall(self,
                        cmd,
                        asynchronously=False,
                        stderr_whitelist=None):
"""Run command with arguments. Wait for command to complete. If the return code
was zero then return, otherwise raise RuntimeError.
Parameters
----------
cmd: str
The command to run.
asynchronously: bool, optional, default=False
Run command asynchronously. If True, `nohup` will be used to run the command.
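
        Examples
        --------
        A minimal sketch (assuming ``context`` is an SSHContext already bound
        to a submission, so that the remote root exists)::

            stdin, stdout, stderr = context.block_checkcall("ls -l")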
"""
self.ssh_session.ensure_alive()
if asynchronously:
cmd = "nohup %s >/dev/null &" % cmd
stdin, stdout, stderr = self.ssh_session.exec_command(('cd %s ;' % self.remote_root) + cmd)
exit_status = stdout.channel.recv_exit_status()
if exit_status != 0:
raise RuntimeError("Get error code %d in calling %s through ssh with job: %s . message: %s" %
(exit_status, cmd, self.submission.submission_hash, stderr.read().decode('utf-8')))
return stdin, stdout, stderr
    def block_call(self,
                   cmd):
self.ssh_session.ensure_alive()
stdin, stdout, stderr = self.ssh_session.exec_command(('cd %s ;' % self.remote_root) + cmd)
exit_status = stdout.channel.recv_exit_status()
return exit_status, stdin, stdout, stderr
    def clean(self):
self.ssh_session.ensure_alive()
self._rmtree(self.remote_root)
    def write_file(self, fname, write_str):
self.ssh_session.ensure_alive()
fname = pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix()
# to prevent old file from being overwritten but cancelled, create a temporary file first
# when it is fully written, rename it to the original file name
with self.sftp.open(fname + "~", 'w') as fp :
fp.write(write_str)
# sftp.rename may throw OSError
self.block_checkcall("mv %s %s" % (fname + "~", fname))
    def read_file(self, fname):
self.ssh_session.ensure_alive()
with self.sftp.open(pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix(), 'r') as fp:
ret = fp.read().decode('utf-8')
return ret
    def check_file_exists(self, fname):
self.ssh_session.ensure_alive()
try:
self.sftp.stat(pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix())
ret = True
except IOError:
ret = False
return ret
    def call(self, cmd):
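        """Run a command on the remote machine without waiting for it to finish.

        Returns a dict of the stdin/stdout/stderr pipes, which can later be
        polled with ``check_finish`` and ``get_return``. A minimal usage
        sketch::

            pipes = context.call("sleep 10")
            if context.check_finish(pipes):
                retcode, stdout, stderr = context.get_return(pipes)
        """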
stdin, stdout, stderr = self.ssh_session.exec_command(cmd)
return {'stdin':stdin, 'stdout':stdout, 'stderr':stderr}
    def check_finish(self, cmd_pipes):
return cmd_pipes['stdout'].channel.exit_status_ready()
    def get_return(self, cmd_pipes):
if not self.check_finish(cmd_pipes):
return None, None, None
else :
retcode = cmd_pipes['stdout'].channel.recv_exit_status()
return retcode, cmd_pipes['stdout'], cmd_pipes['stderr']
    def kill(self, cmd_pipes):
        raise RuntimeError('does not work! we do not know how to kill a process through paramiko.SSHClient')
#self.block_checkcall('kill -15 %s' % cmd_pipes['pid'])
def _rmtree(self, remotepath, verbose = False):
"""Remove the remote path."""
# The original implementation method removes files one by one using sftp.
# If the latency of the remote server is high, it is very slow.
# Thus, it's better to use system's `rm` to remove a directory, which may
# save a lot of time.
if verbose:
dlog.info('removing %s' % remotepath)
# In some supercomputers, it's very slow to remove large numbers of files
# (e.g. directory containing trajectory) due to bad I/O performance.
        # So an asynchronous option is provided.
self.block_checkcall('rm -rf %s' % remotepath, asynchronously=self.clean_asynchronously)
def _put_files(self,
files,
dereference = True,
directories = None,
tar_compress = True,
) :
"""Upload files to server.
Parameters
----------
files: list
uploaded files
dereference: bool, default: True
If dereference is False, add symbolic and hard links to the archive.
If it is True, add the content of the target files to the archive.
This has no effect on systems that do not support symbolic links.
directories: list, default: None
uploaded directories non-recursively. Use `files` for uploading
recursively
        tar_compress: bool, default: True
            If tar_compress is True, compress the archive using gzip.
            If it is False, the archive is left uncompressed.
"""
of_suffix = '.tgz'
tarfile_mode = "w:gz"
kwargs = {'compresslevel': 6}
if not tar_compress :
of_suffix = '.tar'
tarfile_mode = "w"
kwargs = {}
of = self.submission.submission_hash + of_suffix
# local tar
cwd = os.getcwd()
os.chdir(self.local_root)
if os.path.isfile(of) :
os.remove(of)
with tarfile.open(of, tarfile_mode, dereference = dereference, **kwargs) as tar:
for ii in files :
tar.add(ii)
if directories is not None:
for ii in directories:
tar.add(ii, recursive=False)
os.chdir(cwd)
self.ssh_session.ensure_alive()
try:
self.sftp.mkdir(self.remote_root)
except OSError:
pass
# trans
from_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix()
to_f = pathlib.PurePath(os.path.join(self.remote_root, of)).as_posix()
try:
self.ssh_session.put(from_f, to_f)
except FileNotFoundError:
raise FileNotFoundError("from %s to %s @ %s : %s Error!"%(from_f, self.ssh_session.username, self.ssh_session.hostname, to_f))
# remote extract
self.block_checkcall('tar xf %s' % of)
# clean up
os.remove(from_f)
self.sftp.remove(to_f)
def _get_files(self,
files,
tar_compress = True) :
of_suffix = '.tar.gz'
tarfile_mode = "r:gz"
tar_command = 'czfh'
if not tar_compress :
of_suffix = '.tar'
tarfile_mode = "r"
tar_command = 'cfh'
of = self.submission.submission_hash + of_suffix
# remote tar
        # If the number of files is large, we may get an "Argument list too long" error.
        # Thus, "-T" accepts a file containing the list of files.
per_nfile = 100
ntar = len(files) // per_nfile + 1
if ntar <= 1:
self.block_checkcall('tar %s %s %s' % (tar_command, of, " ".join(files)))
else:
file_list_file = os.path.join(self.remote_root, ".tmp.tar." + str(uuid.uuid4()))
self.write_file(file_list_file, "\n".join(files))
self.block_checkcall('tar %s %s -T %s' % (tar_command, of, file_list_file))
# trans
from_f = pathlib.PurePath(os.path.join(self.remote_root, of)).as_posix()
to_f = pathlib.PurePath(os.path.join(self.local_root, of)).as_posix()
if os.path.isfile(to_f) :
os.remove(to_f)
self.ssh_session.get(from_f, to_f)
# extract
cwd = os.getcwd()
os.chdir(self.local_root)
with tarfile.open(of, mode = tarfile_mode) as tar:
tar.extractall()
os.chdir(cwd)
# cleanup
os.remove(to_f)
self.sftp.remove(from_f)
    @classmethod
def machine_subfields(cls) -> List[Argument]:
"""Generate the machine subfields.
Returns
-------
list[Argument]
machine subfields
"""
doc_remote_profile = 'The information used to maintain the connection with remote machine.'
remote_profile_format = SSHSession.arginfo()
remote_profile_format.name = "remote_profile"
remote_profile_format.doc = doc_remote_profile
return [remote_profile_format]