# Source code for dpti.lib.RemoteJob

#!/usr/bin/env python3

import os
import stat
import tarfile
import uuid
from enum import Enum

import paramiko


class JobStatus(Enum):
    """Lifecycle states returned by the ``check_status`` methods of the job classes."""

    unsubmitted = 1
    waiting = 2
    running = 3
    terminated = 4
    finished = 5
    # Canonical name; SlurmJob/PBSJob.check_status return ``JobStatus.unknown``.
    unknown = 100
    # Backward-compatible alias for the original misspelled member. An Enum
    # member that repeats a value becomes an alias, so
    # ``JobStatus.unknow is JobStatus.unknown``.
    unknow = 100
def _default_item(resources, key, value): if key not in resources: resources[key] = value def _set_default_resource(res): if res is None: res = {} _default_item(res, "numb_node", 1) _default_item(res, "task_per_node", 1) _default_item(res, "numb_gpu", 0) _default_item(res, "time_limit", "1:0:0") _default_item(res, "mem_limit", -1) _default_item(res, "partition", "") _default_item(res, "account", "") _default_item(res, "qos", "") _default_item(res, "constraint_list", []) _default_item(res, "license_list", []) _default_item(res, "exclude_list", []) _default_item(res, "module_unload_list", []) _default_item(res, "module_list", []) _default_item(res, "source_list", []) _default_item(res, "envs", None) _default_item(res, "with_mpi", False)
class SSHSession:
    """An SSH connection to a remote host, configured from a profile dict.

    The profile (``jdata``) is read for ``hostname``, ``port``, ``username``
    and ``work_path``; ``port`` falls back to 22 when absent.
    """

    def __init__(self, jdata):
        self.remote_profile = jdata
        self.remote_host = self.remote_profile["hostname"]
        self.remote_port = self.remote_profile.get("port", 22)
        self.remote_uname = self.remote_profile["username"]
        # Remote directory under which each RemoteJob creates its own folder.
        self.remote_workpath = self.remote_profile["work_path"]
        self.ssh = self._setup_ssh(
            self.remote_host, self.remote_port, username=self.remote_uname
        )

    def _setup_ssh(self, hostname, port, username=None, password=None):
        """Open and return a connected ``paramiko.SSHClient``.

        System known-host keys are loaded; unknown host keys are accepted
        with a warning (``paramiko.WarningPolicy``).

        Raises:
            RuntimeError: if the transport is not active after connecting.
        """
        ssh_client = paramiko.SSHClient()
        ssh_client.load_system_host_keys()
        ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
        ssh_client.connect(hostname, port=port, username=username, password=password)
        transport = ssh_client.get_transport()
        # Explicit check rather than ``assert``: asserts are stripped under -O.
        if transport is None or not transport.is_active():
            ssh_client.close()
            raise RuntimeError(f"SSH transport to {hostname}:{port} is not active")
        return ssh_client

    def get_ssh_client(self):
        """Return the underlying ``paramiko.SSHClient``."""
        return self.ssh

    def get_session_root(self):
        """Return the remote work path taken from the profile."""
        return self.remote_workpath

    def close(self):
        """Close the SSH connection."""
        self.ssh.close()
class RemoteJob:
    """A per-job working directory on a remote host, reached over SSH.

    The constructor creates ``<session root>/<random uuid>`` remotely.
    Files move between the local root and that directory as gzip'd tar
    archives, and remote commands run with that directory as cwd.
    """

    def __init__(self, ssh_session, local_root):
        self.local_root = os.path.abspath(local_root)
        self.job_uuid = str(uuid.uuid4())
        self.remote_root = os.path.join(ssh_session.get_session_root(), self.job_uuid)
        print("local_root is ", local_root)
        print("remote_root is", self.remote_root)
        self.ssh = ssh_session.get_ssh_client()
        sftp = self.ssh.open_sftp()
        try:
            sftp.mkdir(self.remote_root)
        finally:
            sftp.close()

    def get_job_root(self):
        """Return the remote directory of this job."""
        return self.remote_root

    def upload(self, job_dirs, local_up_files, dereference=True):
        """Copy ``<job_dir>/<file>`` for every pair from the local root to the job root.

        Args:
            job_dirs: directories relative to the local root.
            local_up_files: file or directory names inside each job dir.
            dereference: archive the targets of symlinks instead of the links.
        """
        file_list = [os.path.join(ii, jj) for ii in job_dirs for jj in local_up_files]
        self._put_files(file_list, dereference=dereference)

    def download(self, job_dirs, remote_down_files):
        """Copy ``<job_dir>/<file>`` for every pair from the job root into the local root."""
        file_list = [os.path.join(ii, jj) for ii in job_dirs for jj in remote_down_files]
        self._get_files(file_list)

    def block_checkcall(self, cmd):
        """Run *cmd* in the job root, wait for it, and raise on a non-zero exit.

        Returns:
            The ``(stdin, stdout, stderr)`` triple from ``exec_command``.

        Raises:
            RuntimeError: if the remote command exits with a non-zero status.
        """
        # ``&&``: if the cd fails, do not run the command in the wrong directory.
        stdin, stdout, stderr = self.ssh.exec_command(f"cd {self.remote_root} && {cmd}")
        # NOTE(review): paramiko documents that recv_exit_status() can hang when
        # the command's output exceeds the channel window before it is read.
        exit_status = stdout.channel.recv_exit_status()
        if exit_status != 0:
            raise RuntimeError(
                f"Get error code {exit_status} in calling through ssh "
                f"with job: {self.job_uuid}"
            )
        return stdin, stdout, stderr

    def block_call(self, cmd):
        """Run *cmd* in the job root and wait for it without checking the status.

        Returns:
            ``(exit_status, stdin, stdout, stderr)``.
        """
        stdin, stdout, stderr = self.ssh.exec_command(f"cd {self.remote_root} && {cmd}")
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdin, stdout, stderr

    def clean(self):
        """Recursively delete the remote job directory."""
        sftp = self.ssh.open_sftp()
        try:
            self._rmtree(sftp, self.remote_root)
        finally:
            sftp.close()

    def _rmtree(self, sftp, remotepath, level=0, verbose=False):
        """Delete *remotepath* and everything below it through *sftp*, depth first."""
        for entry in sftp.listdir_attr(remotepath):
            rpath = os.path.join(remotepath, entry.filename)
            if stat.S_ISDIR(entry.st_mode):
                self._rmtree(sftp, rpath, level=level + 1, verbose=verbose)
            else:
                if verbose:
                    print("removing {}{}".format(" " * level, rpath))
                sftp.remove(rpath)
        if verbose:
            print("removing {}{}".format(" " * level, remotepath))
        sftp.rmdir(remotepath)

    def _put_files(self, files, dereference=True):
        """Tar *files* (relative to the local root), upload and unpack them remotely."""
        of = self.job_uuid + ".tgz"
        from_f = os.path.join(self.local_root, of)
        to_f = os.path.join(self.remote_root, of)
        if os.path.isfile(from_f):
            os.remove(from_f)
        try:
            # arcname keeps the member paths relative without os.chdir,
            # which is process-wide and was not restored on error.
            with tarfile.open(from_f, "w:gz", dereference=dereference) as tar:
                for ii in files:
                    tar.add(os.path.join(self.local_root, ii), arcname=ii)
            sftp = self.ssh.open_sftp()
            try:
                sftp.put(from_f, to_f)
                self.block_checkcall(f"tar xf {of}")
                sftp.remove(to_f)
            finally:
                sftp.close()
        finally:
            if os.path.isfile(from_f):
                os.remove(from_f)

    def _get_files(self, files):
        """Tar *files* (relative to the job root) remotely, download and unpack locally."""
        if not files:
            # ``tar czf`` refuses to create an empty archive.
            return
        of = self.job_uuid + ".tgz"
        # Names are left unquoted so remote shell globs keep working.
        self.block_checkcall(f"tar czf {of} {' '.join(files)}")
        from_f = os.path.join(self.remote_root, of)
        to_f = os.path.join(self.local_root, of)
        if os.path.isfile(to_f):
            os.remove(to_f)
        sftp = self.ssh.open_sftp()
        try:
            sftp.get(from_f, to_f)
            with tarfile.open(to_f, "r:gz") as tar:
                # NOTE(review): the archive comes from the remote host; on
                # Python 3.12+ consider ``filter="data"`` to reject unsafe members.
                tar.extractall(path=self.local_root)
            os.remove(to_f)
            sftp.remove(from_f)
        finally:
            sftp.close()
class CloudMachineJob(RemoteJob):
    """Run jobs directly on the remote host, without a batch scheduler.

    ``submit`` writes a bash driver script into the job root and starts it
    over SSH without waiting; ``check_status`` polls that SSH channel.
    """

    def submit(self, job_dirs, cmd, args=None, resources=None):
        """Write ``run.sh`` into the job root and start it without waiting.

        Args:
            job_dirs: directories, relative to the job root, to run *cmd* in.
            cmd: shell command run in each job dir.
            args: per-directory argument strings appended to *cmd*;
                ``None`` means no extra arguments.
            resources: resource dict (missing keys are defaulted by
                ``_set_default_resource``); ``None`` means all defaults.
        """
        script_name = self._make_script(job_dirs, cmd, args, resources)
        # Kept on the instance: check_status() polls this channel.
        self.stdin, self.stdout, self.stderr = self.ssh.exec_command(
            f"cd {self.remote_root}; bash {script_name}"
        )

    def check_status(self):
        """Map the state of the submitted script's SSH channel to a JobStatus.

        Returns ``running`` until the channel has an exit status, then
        ``finished`` for status 0 and ``terminated`` otherwise.
        """
        if not self._check_finish(self.stdout):
            return JobStatus.running
        if self._get_exit_status(self.stdout) == 0:
            return JobStatus.finished
        return JobStatus.terminated

    def _check_finish(self, stdout):
        """Return True once the remote command behind *stdout* has exited."""
        return stdout.channel.exit_status_ready()

    def _get_exit_status(self, stdout):
        """Return the exit status of the remote command behind *stdout*."""
        return stdout.channel.recv_exit_status()

    def _render_script(self, job_dirs, cmd, args, resources):
        """Return the text of the bash driver script; *resources* must be complete."""
        # A bare ``exit`` returns the status of ``test`` itself (0), which made
        # failed runs look finished to check_status(); exit non-zero instead.
        abort_on_error = "test $? -ne 0 && exit 1\n"
        parts = ["#!/bin/bash\n\n"]
        envs = resources["envs"]
        if envs is not None:
            parts.extend(f"export {key}={value}\n" for key, value in envs.items())
            parts.append("\n")
        if resources["module_unload_list"] is not None:
            parts.extend(f"module unload {name}\n" for name in resources["module_unload_list"])
            parts.append("\n")
        if resources["module_list"] is not None:
            parts.extend(f"module load {name}\n" for name in resources["module_list"])
            parts.append("\n")
        # Honour source_list like the Slurm and PBS backends do.
        if resources["source_list"]:
            parts.extend(f"source {path}\n" for path in resources["source_list"])
            parts.append("\n")
        for job_dir, job_args in zip(job_dirs, args):
            parts.append(f"cd {job_dir}\n")
            parts.append(abort_on_error)
            if resources["with_mpi"] is True:
                parts.append(
                    "mpirun -n %d %s %s\n" % (resources["task_per_node"], cmd, job_args)
                )
            else:
                parts.append(f"{cmd} {job_args}\n")
            parts.append(abort_on_error)
            parts.append(f"cd {self.remote_root}\n")
            parts.append(abort_on_error)
        parts.append("\ntouch tag_finished\n")
        return "".join(parts)

    def _make_script(self, job_dirs, cmd, args=None, resources=None):
        """Write the driver script ``run.sh`` into the job root and return its name."""
        # Copy so defaults are not written back into the caller's dict; the
        # original crashed on ``resources["envs"]`` when resources was None.
        resources = {} if resources is None else dict(resources)
        _set_default_resource(resources)
        job_dirs = list(job_dirs)
        if args is None:
            args = [""] * len(job_dirs)
        script_name = "run.sh"
        text = self._render_script(job_dirs, cmd, args, resources)
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, script_name), "w") as fp:
                fp.write(text)
        finally:
            sftp.close()
        return script_name
class SlurmJob(RemoteJob):
    """Submit jobs to Slurm with ``sbatch`` and poll them with ``squeue``.

    Success is judged by the ``tag_finished`` file that the generated script
    touches only after every job directory has run without error.
    """

    def submit(self, job_dirs, cmd, args=None, resources=None):
        """Write ``run.sub``, submit it with ``sbatch`` and record the job id.

        The id is stored in ``<job root>/job_id`` for ``check_status``.

        Raises:
            RuntimeError: if ``sbatch`` fails or prints nothing.
        """
        script_name = self._make_script(job_dirs, cmd, args, res=resources)
        stdin, stdout, stderr = self.block_checkcall(
            f"cd {self.remote_root}; sbatch {script_name}"
        )
        subret = stdout.readlines()
        if not subret:
            raise RuntimeError(f"sbatch printed no job id for job {self.remote_root}")
        # The job id is the last token of sbatch's first output line.
        job_id = subret[0].split()[-1]
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, "job_id"), "w") as fp:
                fp.write(job_id)
        finally:
            sftp.close()

    def check_status(self):
        """Return the JobStatus of the submitted Slurm job.

        When ``squeue`` no longer knows the id, or reports a terminal state
        code, the ``tag_finished`` file decides between finished/terminated.

        Raises:
            RuntimeError: if the job was never submitted, or ``squeue`` fails
                for a reason other than an unknown job id.
        """
        job_id = self._get_job_id()
        if job_id == "":
            raise RuntimeError(f"job {self.remote_root} is has not been submitted")
        ret, stdin, stdout, stderr = self.block_call("squeue --job " + job_id)
        err_str = stderr.read().decode("utf-8")
        if ret != 0:
            if "Invalid job id specified" in err_str:
                return self._finished_or_terminated()
            raise RuntimeError(
                "status command squeue fails to execute\nerror message:%s\nreturn code %d\n"
                % (err_str, ret)
            )
        lines = [ln for ln in stdout.read().decode("utf-8").splitlines() if ln.strip()]
        fields = lines[-1].split() if lines else []
        if len(fields) < 4:
            return JobStatus.unknown
        # State code is taken as the 4th field from the end of the last line.
        # NOTE(review): if squeue prints only its header, this is presumably
        # the column title "ST", which is listed below as a terminal code, so
        # the tag file decides — same as the original behaviour.
        status_word = fields[-4]
        if status_word in ("PD", "CF", "S"):
            return JobStatus.waiting
        if status_word in ("R", "CG"):
            return JobStatus.running
        if status_word in (
            "C", "E", "K", "BF", "CA", "CD", "F", "NF", "PR", "SE", "ST", "TO",
        ):
            return self._finished_or_terminated()
        return JobStatus.unknown

    def _finished_or_terminated(self):
        """Return ``finished`` if ``tag_finished`` exists, else ``terminated``."""
        if self._check_finish_tag():
            return JobStatus.finished
        return JobStatus.terminated

    def _get_job_id(self):
        """Return the job id stored by ``submit``, or "" if none was stored."""
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, "job_id"), "r") as fp:
                return fp.read().decode("utf-8").strip()
        except FileNotFoundError:
            # No job_id file: submit() has not run, reported by check_status().
            return ""
        finally:
            sftp.close()

    def _check_finish_tag(self):
        """Return True if ``tag_finished`` exists in the job root."""
        sftp = self.ssh.open_sftp()
        try:
            sftp.stat(os.path.join(self.remote_root, "tag_finished"))
            return True
        except OSError:
            return False
        finally:
            sftp.close()

    def _render_script(self, job_dirs, cmd, args, res):
        """Return the text of the Slurm batch script; *res* must be complete."""
        # A bare ``exit`` returns the status of ``test`` itself (0); exit 1 so
        # a failing step is reported as a failed job.
        abort_on_error = "test $? -ne 0 && exit 1\n"
        ret = ["#!/bin/bash -l\n"]
        ret.append("#SBATCH -N %d\n" % res["numb_node"])
        ret.append("#SBATCH --ntasks-per-node %d\n" % res["task_per_node"])
        ret.append("#SBATCH -t {}\n".format(res["time_limit"]))
        if res["mem_limit"] > 0:
            ret.append("#SBATCH --mem %dG \n" % res["mem_limit"])
        if res["account"]:
            ret.append("#SBATCH --account {} \n".format(res["account"]))
        if res["partition"]:
            ret.append("#SBATCH --partition {} \n".format(res["partition"]))
        if res["qos"]:
            ret.append("#SBATCH --qos {} \n".format(res["qos"]))
        if res["numb_gpu"] > 0:
            ret.append("#SBATCH --gres=gpu:%d\n" % res["numb_gpu"])
        ret.extend(f"#SBATCH -C {ii} \n" for ii in res["constraint_list"])
        ret.extend(f"#SBATCH -L {ii} \n" for ii in res["license_list"])
        ret.extend(f"#SBATCH --exclude {ii} \n" for ii in res["exclude_list"])
        ret.append("\n")
        ret.extend(f"module unload {ii}\n" for ii in res["module_unload_list"])
        ret.extend(f"module load {ii}\n" for ii in res["module_list"])
        ret.append("\n")
        ret.extend(f"source {ii}\n" for ii in res["source_list"])
        ret.append("\n")
        envs = res["envs"]
        if envs is not None:
            ret.extend(f"export {key}={value}\n" for key, value in envs.items())
            ret.append("\n")
        for job_dir, job_args in zip(job_dirs, args):
            ret.append(f"cd {job_dir}\n")
            ret.append(abort_on_error)
            if res["with_mpi"]:
                ret.append(f"srun {cmd} {job_args}\n")
            else:
                ret.append(f"{cmd} {job_args}\n")
            ret.append(abort_on_error)
            ret.append(f"cd {self.remote_root}\n")
            ret.append(abort_on_error)
        ret.append("\ntouch tag_finished\n")
        return "".join(ret)

    def _make_script(self, job_dirs, cmd, args=None, res=None):
        """Write the batch script ``run.sub`` into the job root and return its name."""
        # Copy so defaults are not written back into the caller's dict; the
        # original crashed on ``res["numb_node"]`` when res was None.
        res = {} if res is None else dict(res)
        _set_default_resource(res)
        job_dirs = list(job_dirs)
        if args is None:
            args = [""] * len(job_dirs)
        script_name = "run.sub"
        text = self._render_script(job_dirs, cmd, args, res)
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, script_name), "w") as fp:
                fp.write(text)
        finally:
            sftp.close()
        return script_name
class PBSJob(RemoteJob):
    """Submit jobs to PBS with ``qsub`` and poll them with ``qstat -x``.

    Success is judged by the ``tag_finished`` file that the generated script
    touches only after every job directory has run without error.
    """

    def submit(self, job_dirs, cmd, args=None, resources=None):
        """Write ``run.sub``, submit it with ``qsub`` and record the job id.

        The id is stored in ``<job root>/job_id`` for ``check_status``.

        Raises:
            RuntimeError: if ``qsub`` fails or prints nothing.
        """
        script_name = self._make_script(job_dirs, cmd, args, res=resources)
        stdin, stdout, stderr = self.block_checkcall(
            f"cd {self.remote_root}; qsub {script_name}"
        )
        subret = stdout.readlines()
        if not subret:
            raise RuntimeError(f"qsub printed no job id for job {self.remote_root}")
        # The job id is the first token of qsub's first output line.
        job_id = subret[0].split()[0]
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, "job_id"), "w") as fp:
                fp.write(job_id)
        finally:
            sftp.close()

    def check_status(self):
        """Return the JobStatus of the submitted PBS job.

        When ``qstat`` no longer knows the id, or reports a terminal state
        code, the ``tag_finished`` file decides between finished/terminated.

        Raises:
            RuntimeError: if the job was never submitted, or ``qstat`` fails
                for a reason other than an unknown job id.
        """
        job_id = self._get_job_id()
        if job_id == "":
            raise RuntimeError(f"job {self.remote_root} is has not been submitted")
        ret, stdin, stdout, stderr = self.block_call("qstat -x " + job_id)
        err_str = stderr.read().decode("utf-8")
        if ret != 0:
            if "qstat: Unknown Job Id" in err_str:
                return self._finished_or_terminated()
            raise RuntimeError(
                "status command qstat fails to execute. erro info: %s return code %d"
                % (err_str, ret)
            )
        lines = [ln for ln in stdout.read().decode("utf-8").splitlines() if ln.strip()]
        fields = lines[-1].split() if lines else []
        if len(fields) < 2:
            return JobStatus.unknown
        # State code is taken as the second-to-last field of the last line
        # (presumably followed by the queue name in qstat's default layout).
        status_word = fields[-2]
        if status_word in ("Q", "H"):
            return JobStatus.waiting
        if status_word == "R":
            return JobStatus.running
        if status_word in ("C", "E", "K", "F"):
            return self._finished_or_terminated()
        return JobStatus.unknown

    def _finished_or_terminated(self):
        """Return ``finished`` if ``tag_finished`` exists, else ``terminated``."""
        if self._check_finish_tag():
            return JobStatus.finished
        return JobStatus.terminated

    def _get_job_id(self):
        """Return the job id stored by ``submit``, or "" if none was stored."""
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, "job_id"), "r") as fp:
                return fp.read().decode("utf-8").strip()
        except FileNotFoundError:
            # No job_id file: submit() has not run, reported by check_status().
            return ""
        finally:
            sftp.close()

    def _check_finish_tag(self):
        """Return True if ``tag_finished`` exists in the job root."""
        sftp = self.ssh.open_sftp()
        try:
            sftp.stat(os.path.join(self.remote_root, "tag_finished"))
            return True
        except OSError:
            return False
        finally:
            sftp.close()

    def _render_script(self, job_dirs, cmd, args, res):
        """Return the text of the PBS batch script; *res* must be complete."""
        # A bare ``exit`` returns the status of ``test`` itself (0); exit 1 so
        # a failing step is reported as a failed job.
        abort_on_error = "test $? -ne 0 && exit 1\n"
        ret = ["#!/bin/bash -l\n"]
        if res.get("hpc_job_name", None):
            ret.append("#PBS -N {}\n".format(res["hpc_job_name"]))
        if res["numb_gpu"] == 0:
            ret.append(
                "#PBS -l select=%d:ncpus=%d\n" % (res["numb_node"], res["task_per_node"])
            )
        else:
            ret.append(
                "#PBS -l select=%d:ncpus=%d:ngpus=%d\n"
                % (res["numb_node"], res["task_per_node"], res["numb_gpu"])
            )
        ret.append("#PBS -l walltime={}\n".format(res["time_limit"]))
        # mem_limit is not translated into a PBS directive.
        ret.append("#PBS -j oe\n")
        if res["partition"]:
            ret.append("#PBS -q {}\n".format(res["partition"]))
        ret.append("\n")
        ret.extend(f"module unload {ii}\n" for ii in res["module_unload_list"])
        ret.extend(f"module load {ii}\n" for ii in res["module_list"])
        ret.append("\n")
        ret.extend(f"source {ii}\n" for ii in res["source_list"])
        ret.append("\n")
        envs = res["envs"]
        if envs is not None:
            ret.extend(f"export {key}={value}\n" for key, value in envs.items())
            ret.append("\n")
        ret.append("cd $PBS_O_WORKDIR\n\n")
        for job_dir, job_args in zip(job_dirs, args):
            ret.append(f"cd {job_dir}\n")
            ret.append(abort_on_error)
            if res["with_mpi"]:
                ret.append(
                    "mpirun -machinefile $PBS_NODEFILE -n %d %s %s\n"
                    % (res["numb_node"] * res["task_per_node"], cmd, job_args)
                )
            else:
                ret.append(f"{cmd} {job_args}\n")
            ret.append(abort_on_error)
            ret.append(f"cd {self.remote_root}\n")
            ret.append(abort_on_error)
        ret.append("\ntouch tag_finished\n")
        return "".join(ret)

    def _make_script(self, job_dirs, cmd, args=None, res=None):
        """Write the batch script ``run.sub`` into the job root and return its name."""
        # Copy so defaults are not written back into the caller's dict; the
        # original crashed on ``res.get``/``res[...]`` when res was None.
        res = {} if res is None else dict(res)
        _set_default_resource(res)
        job_dirs = list(job_dirs)
        if args is None:
            args = [""] * len(job_dirs)
        script_name = "run.sub"
        text = self._render_script(job_dirs, cmd, args, res)
        sftp = self.ssh.open_sftp()
        try:
            with sftp.open(os.path.join(self.remote_root, script_name), "w") as fp:
                fp.write(text)
        finally:
            sftp.close()
        return script_name
# Example usage:
#
# NOTE: SSHSession takes the already-parsed profile dict (keys hostname,
# port, username, work_path), so load 'localhost.json' (e.g. with
# json.load) before passing it in.
#
# ssh_session = SSHSession('localhost.json')
# rjob = CloudMachineJob(ssh_session, '.')
# # can upload dirs and normal files
# rjob.upload(['job0', 'job1'], ['batch_exec.py', 'test'])
# rjob.submit(['job0', 'job1'], 'touch a; sleep 2')
# while rjob.check_status() == JobStatus.running :
#     print('checked')
#     time.sleep(2)
# print(rjob.check_status())
# # can download dirs and normal files
# rjob.download(['job0', 'job1'], ['a'])
# # rjob.clean()