#!/usr/bin/env python3
import os
import stat
import tarfile
import uuid
from enum import Enum
import paramiko
class JobStatus(Enum):
unsubmitted = 1
waiting = 2
running = 3
terminated = 4
finished = 5
    unknown = 100
def _default_item(resources, key, value):
if key not in resources:
resources[key] = value
def _set_default_resource(res):
    # Fill in any missing resource keys with defaults and return the dict,
    # so callers that pass None still get a usable resource dict back.
    if res is None:
        res = {}
    _default_item(res, "numb_node", 1)
    _default_item(res, "task_per_node", 1)
    _default_item(res, "numb_gpu", 0)
    _default_item(res, "time_limit", "1:0:0")
    _default_item(res, "mem_limit", -1)
    _default_item(res, "partition", "")
    _default_item(res, "account", "")
    _default_item(res, "qos", "")
    _default_item(res, "constraint_list", [])
    _default_item(res, "license_list", [])
    _default_item(res, "exclude_list", [])
    _default_item(res, "module_unload_list", [])
    _default_item(res, "module_list", [])
    _default_item(res, "source_list", [])
    _default_item(res, "envs", None)
    _default_item(res, "with_mpi", False)
    return res
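# A minimal sketch of how the defaults above behave; the partial dict below
# is hypothetical, not part of this module:
#     res = _set_default_resource({"numb_gpu": 1})
#     # res["numb_gpu"] == 1      caller's value is kept
#     # res["numb_node"] == 1     default filled in
#     # res["with_mpi"] is False  default filled in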
class SSHSession:
def __init__(self, jdata):
self.remote_profile = jdata
# with open(remote_profile) as fp :
# self.remote_profile = json.load(fp)
self.remote_host = self.remote_profile["hostname"]
self.remote_port = self.remote_profile["port"]
self.remote_uname = self.remote_profile["username"]
self.remote_workpath = self.remote_profile["work_path"]
self.ssh = self._setup_ssh(
self.remote_host, self.remote_port, username=self.remote_uname
)
    def _setup_ssh(self, hostname, port, username=None, password=None):
        ssh_client = paramiko.SSHClient()
        ssh_client.load_system_host_keys()
        # Warn, rather than fail, on hosts that are missing from known_hosts.
        ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
        ssh_client.connect(hostname, port=port, username=username, password=password)
        assert ssh_client.get_transport().is_active()
        return ssh_client
def get_ssh_client(self):
return self.ssh
def get_session_root(self):
return self.remote_workpath
def close(self):
self.ssh.close()
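    # A hypothetical usage sketch; the host and path values below are
    # placeholders, not defaults shipped with this module:
    #     jdata = {
    #         "hostname": "cluster.example.com",
    #         "port": 22,
    #         "username": "alice",
    #         "work_path": "/scratch/alice/jobs",
    #     }
    #     session = SSHSession(jdata)
    #     ...
    #     session.close()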
class RemoteJob:
def __init__(self, ssh_session, local_root):
self.local_root = os.path.abspath(local_root)
self.job_uuid = str(uuid.uuid4())
# self.job_uuid = 'a21d0017-c9f1-4d29-9a03-97df06965cef'
self.remote_root = os.path.join(ssh_session.get_session_root(), self.job_uuid)
print("local_root is ", local_root)
print("remote_root is", self.remote_root)
self.ssh = ssh_session.get_ssh_client()
sftp = self.ssh.open_sftp()
sftp.mkdir(self.remote_root)
sftp.close()
# open('job_uuid', 'w').write(self.job_uuid)
def get_job_root(self):
return self.remote_root
def upload(self, job_dirs, local_up_files, dereference=True):
cwd = os.getcwd()
os.chdir(self.local_root)
file_list = []
for ii in job_dirs:
for jj in local_up_files:
file_list.append(os.path.join(ii, jj))
self._put_files(file_list, dereference=dereference)
os.chdir(cwd)
def download(self, job_dirs, remote_down_files):
cwd = os.getcwd()
os.chdir(self.local_root)
file_list = []
for ii in job_dirs:
for jj in remote_down_files:
file_list.append(os.path.join(ii, jj))
self._get_files(file_list)
os.chdir(cwd)
    def block_checkcall(self, cmd):
        # Run cmd in the remote job root, blocking until it exits;
        # raise if the command returns a non-zero exit status.
        stdin, stdout, stderr = self.ssh.exec_command(
            f"cd {self.remote_root}; " + cmd
        )
        exit_status = stdout.channel.recv_exit_status()
        if exit_status != 0:
            raise RuntimeError(
                "Got error code %d in calling through ssh with job: %s"
                % (exit_status, self.job_uuid)
            )
        return stdin, stdout, stderr
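    # Example (hypothetical command): runs inside the remote job root and
    # raises RuntimeError on a non-zero exit status:
    #     stdin, stdout, stderr = rjob.block_checkcall("ls job0")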
    def block_call(self, cmd):
        # Like block_checkcall, but return the exit status instead of raising.
        stdin, stdout, stderr = self.ssh.exec_command(
            f"cd {self.remote_root}; " + cmd
        )
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdin, stdout, stderr
def clean(self):
sftp = self.ssh.open_sftp()
self._rmtree(sftp, self.remote_root)
sftp.close()
    def _rmtree(self, sftp, remotepath, level=0, verbose=False):
        # Recursively delete a remote directory tree over SFTP.
        for f in sftp.listdir_attr(remotepath):
            rpath = os.path.join(remotepath, f.filename)
            if stat.S_ISDIR(f.st_mode):
                self._rmtree(sftp, rpath, level=(level + 1))
            else:
                if verbose:
                    print("removing {}{}".format(" " * level, rpath))
                sftp.remove(rpath)
        if verbose:
            print("removing {}{}".format(" " * level, remotepath))
        sftp.rmdir(remotepath)
def _put_files(self, files, dereference=True):
of = self.job_uuid + ".tgz"
# local tar
cwd = os.getcwd()
os.chdir(self.local_root)
if os.path.isfile(of):
os.remove(of)
with tarfile.open(of, "w:gz", dereference=dereference) as tar:
for ii in files:
tar.add(ii)
os.chdir(cwd)
# trans
from_f = os.path.join(self.local_root, of)
to_f = os.path.join(self.remote_root, of)
sftp = self.ssh.open_sftp()
sftp.put(from_f, to_f)
# remote extract
self.block_checkcall(f"tar xf {of}")
# clean up
os.remove(from_f)
sftp.remove(to_f)
sftp.close()
def _get_files(self, files):
of = self.job_uuid + ".tgz"
flist = ""
for ii in files:
flist += " " + ii
# remote tar
self.block_checkcall(f"tar czf {of} {flist}")
# trans
from_f = os.path.join(self.remote_root, of)
to_f = os.path.join(self.local_root, of)
if os.path.isfile(to_f):
os.remove(to_f)
sftp = self.ssh.open_sftp()
sftp.get(from_f, to_f)
# extract
cwd = os.getcwd()
os.chdir(self.local_root)
with tarfile.open(of, "r:gz") as tar:
tar.extractall()
os.chdir(cwd)
# cleanup
os.remove(to_f)
sftp.remove(from_f)
class CloudMachineJob(RemoteJob):
def submit(self, job_dirs, cmd, args=None, resources=None):
# print("Current path is",os.getcwd())
# for ii in job_dirs :
# if not os.path.isdir(ii) :
# raise RuntimeError("cannot find dir %s" % ii)
# print(self.remote_root)
script_name = self._make_script(job_dirs, cmd, args, resources)
self.stdin, self.stdout, self.stderr = self.ssh.exec_command(
f"cd {self.remote_root}; bash {script_name}"
)
# print(self.stderr.read().decode('utf-8'))
# print(self.stdout.read().decode('utf-8'))
def check_status(self):
if not self._check_finish(self.stdout):
return JobStatus.running
elif self._get_exit_status(self.stdout) == 0:
return JobStatus.finished
else:
return JobStatus.terminated
def _check_finish(self, stdout):
return stdout.channel.exit_status_ready()
def _get_exit_status(self, stdout):
return stdout.channel.recv_exit_status()
    def _make_script(self, job_dirs, cmd, args=None, resources=None):
        resources = _set_default_resource(resources)
envs = resources["envs"]
module_list = resources["module_list"]
module_unload_list = resources["module_unload_list"]
task_per_node = resources["task_per_node"]
script_name = "run.sh"
if args is None:
args = []
for ii in job_dirs:
args.append("")
script = os.path.join(self.remote_root, script_name)
sftp = self.ssh.open_sftp()
with sftp.open(script, "w") as fp:
fp.write("#!/bin/bash\n\n")
# fp.write('set -euo pipefail\n')
if envs is not None:
for key in envs.keys():
fp.write(f"export {key}={envs[key]}\n")
fp.write("\n")
if module_unload_list is not None:
for ii in module_unload_list:
fp.write(f"module unload {ii}\n")
fp.write("\n")
if module_list is not None:
for ii in module_list:
fp.write(f"module load {ii}\n")
fp.write("\n")
for ii, jj in zip(job_dirs, args):
fp.write(f"cd {ii}\n")
fp.write("test $? -ne 0 && exit\n")
if resources["with_mpi"] is True:
fp.write("mpirun -n %d %s %s\n" % (task_per_node, cmd, jj))
else:
fp.write(f"{cmd} {jj}\n")
fp.write("test $? -ne 0 && exit\n")
fp.write(f"cd {self.remote_root}\n")
fp.write("test $? -ne 0 && exit\n")
fp.write("\ntouch tag_finished\n")
sftp.close()
return script_name
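# A hypothetical end-to-end call; directory names, files, and the command
# below are placeholders:
#     rjob = CloudMachineJob(ssh_session, '.')
#     rjob.upload(['job0', 'job1'], ['input.txt'])
#     rjob.submit(['job0', 'job1'], 'python run.py',
#                 resources={'task_per_node': 4, 'with_mpi': True})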
class SlurmJob(RemoteJob):
def submit(self, job_dirs, cmd, args=None, resources=None):
script_name = self._make_script(job_dirs, cmd, args, res=resources)
stdin, stdout, stderr = self.block_checkcall(
f"cd {self.remote_root}; sbatch {script_name}"
)
subret = stdout.readlines()
job_id = subret[0].split()[-1]
sftp = self.ssh.open_sftp()
with sftp.open(os.path.join(self.remote_root, "job_id"), "w") as fp:
fp.write(job_id)
sftp.close()
def check_status(self):
job_id = self._get_job_id()
if job_id == "":
raise RuntimeError(f"job {self.remote_root} is has not been submitted")
ret, stdin, stdout, stderr = self.block_call("squeue --job " + job_id)
err_str = stderr.read().decode("utf-8")
if ret != 0:
if "Invalid job id specified" in err_str:
if self._check_finish_tag():
return JobStatus.finished
else:
return JobStatus.terminated
else:
                raise RuntimeError(
                    "status command squeue fails to execute\nerror message: %s\nreturn code: %d\n"
                    % (err_str, ret)
                )
status_line = stdout.read().decode("utf-8").split("\n")[-2]
status_word = status_line.split()[-4]
if status_word in ["PD", "CF", "S"]:
return JobStatus.waiting
elif status_word in ["R", "CG"]:
return JobStatus.running
elif status_word in [
"C",
"E",
"K",
"BF",
"CA",
"CD",
"F",
"NF",
"PR",
"SE",
"ST",
"TO",
]:
if self._check_finish_tag():
return JobStatus.finished
else:
return JobStatus.terminated
else:
return JobStatus.unknown
def _get_job_id(self):
sftp = self.ssh.open_sftp()
with sftp.open(os.path.join(self.remote_root, "job_id"), "r") as fp:
ret = fp.read().decode("utf-8")
sftp.close()
return ret
def _check_finish_tag(self):
sftp = self.ssh.open_sftp()
try:
sftp.stat(os.path.join(self.remote_root, "tag_finished"))
ret = True
except OSError:
ret = False
sftp.close()
return ret
    def _make_script(self, job_dirs, cmd, args=None, res=None):
        res = _set_default_resource(res)
ret = ""
ret += "#!/bin/bash -l\n"
ret += "#SBATCH -N %d\n" % res["numb_node"]
ret += "#SBATCH --ntasks-per-node %d\n" % res["task_per_node"]
ret += "#SBATCH -t {}\n".format(res["time_limit"])
if res["mem_limit"] > 0:
ret += "#SBATCH --mem %dG \n" % res["mem_limit"]
if len(res["account"]) > 0:
ret += "#SBATCH --account {} \n".format(res["account"])
if len(res["partition"]) > 0:
ret += "#SBATCH --partition {} \n".format(res["partition"])
if len(res["qos"]) > 0:
ret += "#SBATCH --qos {} \n".format(res["qos"])
if res["numb_gpu"] > 0:
ret += "#SBATCH --gres=gpu:%d\n" % res["numb_gpu"]
for ii in res["constraint_list"]:
ret += f"#SBATCH -C {ii} \n"
for ii in res["license_list"]:
ret += f"#SBATCH -L {ii} \n"
for ii in res["exclude_list"]:
ret += f"#SBATCH --exclude {ii} \n"
ret += "\n"
# ret += 'set -euo pipefail\n\n'
for ii in res["module_unload_list"]:
ret += f"module unload {ii}\n"
for ii in res["module_list"]:
ret += f"module load {ii}\n"
ret += "\n"
for ii in res["source_list"]:
ret += f"source {ii}\n"
ret += "\n"
envs = res["envs"]
if envs is not None:
for key in envs.keys():
ret += f"export {key}={envs[key]}\n"
ret += "\n"
if args is None:
args = []
for ii in job_dirs:
args.append("")
for ii, jj in zip(job_dirs, args):
ret += f"cd {ii}\n"
ret += "test $? -ne 0 && exit\n"
if res["with_mpi"]:
ret += f"srun {cmd} {jj}\n"
else:
ret += f"{cmd} {jj}\n"
ret += "test $? -ne 0 && exit\n"
ret += f"cd {self.remote_root}\n"
ret += "test $? -ne 0 && exit\n"
ret += "\ntouch tag_finished\n"
script_name = "run.sub"
script = os.path.join(self.remote_root, script_name)
sftp = self.ssh.open_sftp()
with sftp.open(script, "w") as fp:
fp.write(ret)
sftp.close()
return script_name
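# For reference, the method above would emit a header like the following
# for a hypothetical resources dict {'numb_node': 2, 'task_per_node': 8,
# 'numb_gpu': 1, 'time_limit': '2:0:0', 'partition': 'gpu'}:
#     #!/bin/bash -l
#     #SBATCH -N 2
#     #SBATCH --ntasks-per-node 8
#     #SBATCH -t 2:0:0
#     #SBATCH --partition gpu
#     #SBATCH --gres=gpu:1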
class PBSJob(RemoteJob):
def submit(self, job_dirs, cmd, args=None, resources=None):
script_name = self._make_script(job_dirs, cmd, args, res=resources)
stdin, stdout, stderr = self.block_checkcall(
f"cd {self.remote_root}; qsub {script_name}"
)
subret = stdout.readlines()
job_id = subret[0].split()[0]
sftp = self.ssh.open_sftp()
with sftp.open(os.path.join(self.remote_root, "job_id"), "w") as fp:
fp.write(job_id)
sftp.close()
def check_status(self):
job_id = self._get_job_id()
if job_id == "":
raise RuntimeError(f"job {self.remote_root} is has not been submitted")
ret, stdin, stdout, stderr = self.block_call("qstat -x " + job_id)
err_str = stderr.read().decode("utf-8")
if ret != 0:
if "qstat: Unknown Job Id" in err_str:
if self._check_finish_tag():
return JobStatus.finished
else:
return JobStatus.terminated
else:
                raise RuntimeError(
                    "status command qstat fails to execute\nerror message: %s\nreturn code: %d\n"
                    % (err_str, ret)
                )
status_line = stdout.read().decode("utf-8").split("\n")[-2]
status_word = status_line.split()[-2]
# print (status_word)
if status_word in ["Q", "H"]:
return JobStatus.waiting
elif status_word in ["R"]:
return JobStatus.running
elif status_word in ["C", "E", "K", "F"]:
if self._check_finish_tag():
return JobStatus.finished
else:
return JobStatus.terminated
else:
return JobStatus.unknown
def _get_job_id(self):
sftp = self.ssh.open_sftp()
with sftp.open(os.path.join(self.remote_root, "job_id"), "r") as fp:
ret = fp.read().decode("utf-8")
sftp.close()
return ret
def _check_finish_tag(self):
sftp = self.ssh.open_sftp()
try:
sftp.stat(os.path.join(self.remote_root, "tag_finished"))
ret = True
except OSError:
ret = False
sftp.close()
return ret
    def _make_script(self, job_dirs, cmd, args=None, res=None):
        res = _set_default_resource(res)
ret = ""
ret += "#!/bin/bash -l\n"
if res.get("hpc_job_name", None):
ret += "#PBS -N {}\n".format(res["hpc_job_name"])
if res["numb_gpu"] == 0:
ret += "#PBS -l select=%d:ncpus=%d\n" % (
res["numb_node"],
res["task_per_node"],
)
else:
ret += "#PBS -l select=%d:ncpus=%d:ngpus=%d\n" % (
res["numb_node"],
res["task_per_node"],
res["numb_gpu"],
)
ret += "#PBS -l walltime={}\n".format(res["time_limit"])
# if res['mem_limit'] > 0 :
# ret += "#PBS -l mem=%dG \n" % res['mem_limit']
ret += "#PBS -j oe\n"
if len(res["partition"]) > 0:
ret += "#PBS -q {}\n".format(res["partition"])
ret += "\n"
for ii in res["module_unload_list"]:
ret += f"module unload {ii}\n"
for ii in res["module_list"]:
ret += f"module load {ii}\n"
ret += "\n"
for ii in res["source_list"]:
ret += f"source {ii}\n"
ret += "\n"
envs = res["envs"]
if envs is not None:
for key in envs.keys():
ret += f"export {key}={envs[key]}\n"
ret += "\n"
ret += "cd $PBS_O_WORKDIR\n\n"
if args is None:
args = []
for ii in job_dirs:
args.append("")
for ii, jj in zip(job_dirs, args):
ret += f"cd {ii}\n"
ret += "test $? -ne 0 && exit\n"
if res["with_mpi"]:
ret += "mpirun -machinefile $PBS_NODEFILE -n %d %s %s\n" % (
res["numb_node"] * res["task_per_node"],
cmd,
jj,
)
else:
ret += f"{cmd} {jj}\n"
ret += "test $? -ne 0 && exit\n"
ret += f"cd {self.remote_root}\n"
ret += "test $? -ne 0 && exit\n"
ret += "\ntouch tag_finished\n"
script_name = "run.sub"
script = os.path.join(self.remote_root, script_name)
sftp = self.ssh.open_sftp()
with sftp.open(script, "w") as fp:
fp.write(ret)
sftp.close()
return script_name
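# For reference, a hypothetical resources dict {'numb_node': 1,
# 'task_per_node': 8, 'numb_gpu': 0, 'time_limit': '1:0:0'} would yield a
# header beginning:
#     #!/bin/bash -l
#     #PBS -l select=1:ncpus=8
#     #PBS -l walltime=1:0:0
#     #PBS -j oe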
# ssh_session = SSHSession('localhost.json')
# rjob = CloudMachineJob(ssh_session, '.')
# # can upload dirs and normal files
# rjob.upload(['job0', 'job1'], ['batch_exec.py', 'test'])
# rjob.submit(['job0', 'job1'], 'touch a; sleep 2')
# while rjob.check_status() == JobStatus.running :
# print('checked')
# time.sleep(2)
# print(rjob.check_status())
# # can download dirs and normal files
# rjob.download(['job0', 'job1'], ['a'])
# # rjob.clean()
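# A parallel sketch for the Slurm scheduler (paths and command are
# hypothetical):
# slurm_job = SlurmJob(ssh_session, '.')
# slurm_job.upload(['job0'], ['run.py'])
# slurm_job.submit(['job0'], 'python run.py',
#                  resources={'numb_node': 1, 'task_per_node': 4})
# while slurm_job.check_status() in (JobStatus.waiting, JobStatus.running):
#     time.sleep(10)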