import os
import re
import time
import urllib.parse
from urllib.parse import urljoin
import requests
from dpdispatcher.dlog import dlog
from .config import API_HOST, API_LOGGER_STACK_INFO, HTTP_TIME_OUT
from .retcode import RETCODE
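# oss2 is an optional dependency: it is only needed by the OSS
# upload/download helpers below, so a missing package is tolerated here
# and will only surface if those helpers are actually used.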
try:
import oss2
from oss2 import SizedFileAdapter, determine_part_size
from oss2.models import PartInfo
except ImportError:
pass
ENABLE_STACK = bool(API_LOGGER_STACK_INFO)
class RequestInfoException(Exception):
pass
class Client:
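    """Authenticated HTTP client for the Bohrium (Lebesgue) job API.

    Handles JWT/ticket authentication, request retries, OSS file transfer,
    and the job lifecycle endpoints used by dpdispatcher.
    """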
def __init__(
self, email=None, password=None, debug=False, ticket=None, base_url=API_HOST
):
        self.debug = os.getenv("LBG_CLI_DEBUG_PRINT", debug)
self.config = {}
self.token = ""
self.user_id = None
self.config["email"] = email
self.config["password"] = password
self.base_url = base_url
self.last_log_offset = 0
self.ticket = ticket
def post(self, url, data=None, header=None, params=None, retry=5):
return self._req(
"POST", url, data=data, header=header, params=params, retry=retry
)
def get(self, url, header=None, params=None, retry=5):
return self._req("GET", url, header=header, params=params, retry=retry)
def _req(self, method, url, data=None, header=None, params=None, retry=5):
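        """Send an authenticated request and unwrap the JSON envelope.

        Retries up to ``retry`` times with an increasing delay; on an
        API-level error the token is cleared and refreshed before the next
        attempt. Returns the ``data`` field of the response on success and
        raises RequestInfoException once all attempts are exhausted.
        """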
short_url = url
url = urllib.parse.urljoin(self.base_url, url)
if header is None:
header = {}
if not self.token:
self.refresh_token()
self.ticket = os.environ.get("BOHR_TICKET", "")
header["Authorization"] = f"jwt {self.token}"
header["Brm-Ticket"] = self.ticket
resp_code = None
err = None
for i in range(retry):
resp = None
try:
if method == "GET":
resp = requests.get(url, params=params, headers=header)
else:
if self.debug:
print(data)
resp = requests.post(url, json=data, params=params, headers=header)
except Exception as e:
dlog.error(f"request({i}) error {e}", i, stack_info=ENABLE_STACK)
err = e
time.sleep(1 * i)
continue
resp_code = resp.status_code
if not resp.ok:
if self.debug:
print(f"retry: {i},statusCode: {resp.status_code}")
try:
result = resp.json()
err = result.get("error")
except Exception:
pass
time.sleep(0.1 * i)
continue
result = resp.json()
if result["code"] == "0000" or result["code"] == 0:
return result.get("data", {})
            else:
                # API-level failure: assume the token went stale, refresh it,
                # and retry with updated credentials.
                self.token = ""
                self.refresh_token()
                header["Authorization"] = f"jwt {self.token}"
                err = result.get("message") or result.get("error")
raise RequestInfoException(resp_code, short_url, err)
def _login(self):
if self.config["email"] is None or self.config["password"] is None:
raise RequestInfoException(
"can not find login information, please check your config"
)
post_data = {"email": self.config["email"], "password": self.config["password"]}
resp = self.post("/account/login", post_data)
self.token = resp["token"]
self.user_id = resp["user_id"]
def refresh_token(self, retry=3):
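        """Obtain a fresh JWT token via /account/login.

        If the BOHR_TICKET environment variable is set, ticket-based
        authentication is assumed and no login request is made.
        """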
self.ticket = os.environ.get("BOHR_TICKET", "")
if self.ticket:
return
url = "/account/login"
post_data = {"email": self.config["email"], "password": self.config["password"]}
resp_code = None
err = None
        for i in range(retry):
            try:
                resp = requests.post(
                    urljoin(API_HOST, url),
                    json=post_data,
                    timeout=HTTP_TIME_OUT,
                )
            except Exception as e:
                dlog.error(f"login request({i}) error {e}", stack_info=ENABLE_STACK)
                err = e
                time.sleep(1 * i)
                continue
resp_code = resp.status_code
if not resp.ok:
if self.debug:
print(f"login retry: {i},statusCode: {resp.status_code}")
try:
result = resp.json()
err = result.get("error")
except Exception:
pass
time.sleep(1 * i)
continue
result = resp.json()
if result["code"] == RETCODE.OK or result["code"] == 0:
self.token = result["data"]["token"]
return
raise RequestInfoException(resp_code, url, err)
def _get_oss_bucket(self, endpoint, bucket_name):
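        """Build an oss2.Bucket from temporary STS credentials fetched
        from /data/get_sts_token.
        """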
        res = self.get("/data/get_sts_token", {})
        dlog.debug(f"debug: _get_oss_bucket: res:{res}")
auth = oss2.StsAuth(
res["AccessKeyId"], res["AccessKeySecret"], res["SecurityToken"]
)
return oss2.Bucket(auth, endpoint, bucket_name)
def download(self, oss_file, save_file, endpoint, bucket_name):
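        """Download ``oss_file`` from the given OSS bucket into ``save_file``."""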
bucket = self._get_oss_bucket(endpoint, bucket_name)
dlog.debug(f"debug: download: oss_file:{oss_file}; save_file:{save_file}")
bucket.get_object_to_file(oss_file, save_file)
return save_file
def download_from_url(self, url, save_file):
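        """Stream the file at ``url`` into ``save_file``.

        Retries up to three times; if every attempt fails, the error is
        logged and the function returns without writing anything.
        """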
ret = None
for retry_count in range(3):
try:
                ret = requests.get(
                    url,
                    headers={"Authorization": "jwt " + self.token},
                    stream=True,
                    timeout=HTTP_TIME_OUT,
                )
except Exception as e:
dlog.error(f"request error {e}", stack_info=ENABLE_STACK)
continue
if ret.ok:
break
else:
dlog.error(
f"request error status_code:{ret.status_code} reason: {ret.reason} body: \n{ret.text}"
)
time.sleep(retry_count)
ret = None
if ret is not None:
ret.raise_for_status()
with open(save_file, "wb") as f:
for chunk in ret.iter_content(chunk_size=8192):
f.write(chunk)
ret.close()
def upload(self, oss_task_zip, zip_task_file, endpoint, bucket_name):
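        """Upload ``zip_task_file`` to OSS as ``oss_task_zip``.

        Uses a multipart upload: the file is split into parts of roughly
        1000 KiB, each part is uploaded sequentially, and the upload is
        completed with the collected part ETags.
        """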
dlog.debug(
f"debug: upload: oss_task_zip:{oss_task_zip}; zip_task_file:{zip_task_file}"
)
bucket = self._get_oss_bucket(endpoint, bucket_name)
total_size = os.path.getsize(zip_task_file)
part_size = determine_part_size(total_size, preferred_size=1000 * 1024)
upload_id = bucket.init_multipart_upload(oss_task_zip).upload_id
parts = []
with open(zip_task_file, "rb") as fileobj:
part_number = 1
offset = 0
while offset < total_size:
num_to_upload = min(part_size, total_size - offset)
result = bucket.upload_part(
oss_task_zip,
upload_id,
part_number,
SizedFileAdapter(fileobj, num_to_upload),
)
parts.append(PartInfo(part_number, result.etag))
offset += num_to_upload
part_number += 1
        result = bucket.complete_multipart_upload(oss_task_zip, upload_id, parts)
return result
def job_create(
self, job_type, oss_path, input_data, program_id=None, group_id=None
):
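        """Submit a job and return ``(job_id, job_group_id)``.

        Entries from ``input_data`` are merged into the payload, a few
        dpdispatcher field names are mapped to their API equivalents
        (backward_files -> out_files, command -> cmd, machine_type ->
        scass_type), and all keys are camelized before posting to
        /brm/v2/job/add.
        """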
post_data = {
"job_type": job_type,
"oss_path": oss_path,
}
if program_id is not None:
post_data["project_id"] = program_id
if group_id is not None:
post_data["job_group_id"] = group_id
for k, v in input_data.items():
post_data[k] = v
if input_data.get("backward_files"):
post_data["out_files"] = input_data.get("backward_files")
if input_data.get("command"):
post_data["cmd"] = input_data.get("command")
if input_data.get("machine_type"):
post_data["scass_type"] = input_data.get("machine_type")
log = input_data.get(
"logFiles", input_data.get("log_files", input_data.get("log_file"))
)
if log:
if isinstance(log, str):
post_data["log_files"] = [log]
if (
"checkpoint_files" in post_data
and post_data["checkpoint_files"] == "sync_files"
):
post_data["checkpoint_files"] = ["*"]
camel_data = {self._camelize(k): v for k, v in post_data.items()}
ret = self.post("/brm/v2/job/add", camel_data)
group_id = ret.get("jobGroupId")
return ret["jobId"], group_id
def _camelize(self, str_or_iter):
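        """Convert a snake_case (or dash/space separated) key to camelCase."""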
# code reference from https://pypi.org/project/pyhumps/
regex = re.compile(r"(?<=[^\-_\s])[\-_\s]+[^\-_\s]")
def _is_none(_in):
return "" if _in is None else _in
s = str(_is_none(str_or_iter))
if s.isupper() or s.isnumeric():
return str_or_iter
if len(s) != 0 and not s[:2].isupper():
s = s[0].lower() + s[1:]
return regex.sub(lambda m: m.group(0)[-1].upper(), s)
def get_job_detail(self, job_id):
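        """Return the job detail dict, or None when the API reports an
        application-level error on an otherwise successful (HTTP 200) call.
        """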
try:
ret = self.get(
f"brm/v1/job/{job_id}",
)
except RequestInfoException as e:
if e.args[0] != 200:
raise e
dlog.error(f"get job detail error {e}", stack_info=ENABLE_STACK)
return None
return ret
def get_log(self, job_id):
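        """Return log text appended since the previous call.

        Uses an HTTP Range request starting at ``self.last_log_offset`` so
        repeated polling only transfers the new tail of the log.
        """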
url, size = self._get_job_log(job_id)
if not url:
return ""
if self.last_log_offset >= size:
return ""
        resp = requests.get(
            url,
            headers={"Range": f"bytes={self.last_log_offset}-"},
            timeout=HTTP_TIME_OUT,
        )
self.last_log_offset += len(resp.content)
return resp.content.decode("utf-8")
def _get_job_log(self, job_id):
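        """Return ``(url, size)`` of the job's first log file, or
        ``(None, 0)`` when no log file is available yet.
        """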
ret = self.get(
f"/brm/v1/job/{job_id}/log",
params={
"pageSize": 1,
},
)
d = ret.get("logFiles")
if d and len(d) != 0:
return d[0]["url"], d[0]["size"]
return None, 0
def get_tasks_list(self, group_id, per_page=30):
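        """Collect all jobs of a job group by paging through /brm/v1/job/list."""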
result = []
page = 0
while True:
ret = self.get(
"/brm/v1/job/list",
params={
"groupId": group_id,
"page": page,
"pageSize": per_page,
},
)
if len(ret["items"]) == 0:
break
for each in ret["items"]:
result.append(each)
page += 1
return result
def get_job_result_url(self, job_id):
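        """Return the result download URL of a job, or None.

        ``job_id`` may be a plain id or a composite
        ``"<id>:job_group_id:<group_id>"`` string, from which the job id is
        extracted.
        """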
try:
if not job_id:
return None
if "job_group_id" in job_id:
ids = job_id.split(":job_group_id:")
job_id, _ = int(ids[0]), int(ids[1])
ret = self.get(f"/brm/v1/job/{job_id}", {})
if "resultUrl" in ret and len(ret["resultUrl"]) != 0:
return ret.get("resultUrl")
else:
return None
except ValueError as e:
dlog.error(e, stack_info=ENABLE_STACK)
return None
def kill(self, job_id):
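        """Request termination of a job; accepts the same plain or composite
        id format as get_job_result_url.
        """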
try:
if not job_id:
return None
if "job_group_id" in job_id:
ids = job_id.split(":job_group_id:")
job_id, _ = int(ids[0]), int(ids[1])
ret = self.post(f"/brm/v1/job/kill/{job_id}", {})
return ret
except ValueError as e:
dlog.error(e, stack_info=ENABLE_STACK)
return None