Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions flask_ades_wpst/ades_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
from elasticsearch import Elasticsearch
import boto3
from flask import Response
from jinja2 import Template
import logging
Expand Down Expand Up @@ -39,7 +42,69 @@ def __init__(self, app_config):
# flask_wpst.py.
raise ValueError("Platform {} not implemented.".format(self._platform))
self._ades = ADES_Platform()
self._job_publisher = SnsJobPublisher(app_config["JOB_NOTIFICATION_TOPIC_ARN"])

def get_sts_and_sns_clients(self, aws_auth_method):
    """Create STS and SNS boto3 clients using the requested auth method.

    BUG FIX: the original signature omitted ``self`` even though the method
    is invoked as ``self.get_sts_and_sns_clients(aws_auth_method="iam")``,
    which raised ``TypeError: got multiple values for argument``.

    :param aws_auth_method: "keys" to authenticate with the ACCESS_KEY /
        SECRET_KEY / SESSION_TOKEN environment variables, or "iam" to use
        the default boto3 credential chain (instance profile, etc.).
    :return: tuple of ``(sts_client, sns_client)``
    :raises SystemExit: if ``aws_auth_method`` is not "iam" or "keys"
    """
    if aws_auth_method == "keys":
        # Explicit credentials pulled from the environment.
        sts_client = boto3.client(
            "sts",
            region_name="us-west-2",
            aws_access_key_id=os.getenv("ACCESS_KEY"),
            aws_secret_access_key=os.getenv("SECRET_KEY"),
            aws_session_token=os.getenv("SESSION_TOKEN"),
        )
        print(sts_client.get_caller_identity())
        client = boto3.client(
            "sns",
            region_name="us-west-2",
            aws_access_key_id=os.getenv("ACCESS_KEY"),
            aws_secret_access_key=os.getenv("SECRET_KEY"),
            aws_session_token=os.getenv("SESSION_TOKEN"),
        )
    elif aws_auth_method == "iam":
        # Rely on the ambient IAM role / default credential chain.
        sts_client = boto3.client("sts", region_name="us-west-2")
        print(sts_client.get_caller_identity())
        client = boto3.client("sns", region_name="us-west-2")
    else:
        print(f"Invalid aws_auth_method: {aws_auth_method}")
        print("Supported methods: iam, keys")
        # Exit with a nonzero status on misconfiguration (the original
        # bare exit() terminated with status 0, hiding the failure).
        raise SystemExit(1)

    return sts_client, client

def _update_jobs_database(self, job_id, proc_id, status, job_inputs=None, job_tags=None):
    """Publish a job-state record to the jobs-data SNS topic.

    :param job_id: unique job identifier; also used as the SNS
        ``MessageGroupId`` (implies the topic is a FIFO topic)
    :param proc_id: identifier of the process the job belongs to
    :param status: current job status string
    :param job_inputs: dict of job inputs (defaults to an empty dict)
    :param job_tags: list of job tags (defaults to an empty list)
    """
    # BUG FIX: mutable default arguments ({} / []) are shared across calls
    # in Python; use None sentinels and allocate fresh containers instead.
    if job_inputs is None:
        job_inputs = {}
    if job_tags is None:
        job_tags = []
    sts_client, sns_client = self.get_sts_and_sns_clients(aws_auth_method="iam")
    job_data = {
        "id": job_id,
        "process": proc_id,
        "status": status,
        "inputs": job_inputs,
        "tags": job_tags,
    }
    topic_arn = os.environ["JOBS_DATA_SNS_TOPIC_ARN"]
    print(
        sns_client.publish(
            TopicArn=topic_arn, Message=json.dumps(job_data), MessageGroupId=job_id
        )
    )

def _get_jobs_doc(self, job_id):
    """Retrieve the Elasticsearch document for a job from the Jobs DB.

    :param job_id: job ID, used as the ES document ID
    :return: the document's ``_source`` dict, or an empty dict when the
        lookup fails, so callers can safely use ``in`` / ``.get``
    """
    # TODO: take the ES host/port (and index name) from configuration
    # rather than hard-coding port 9200 here.
    es = Elasticsearch([{'host': os.environ["ES_URL"], 'port': 9200}])
    # NOTE(review): the index name is empty — confirm which index the jobs
    # documents live in; an empty index will make every lookup fail.
    index_name = ""
    document_id = job_id

    # Query the document; this is best-effort, so failures are logged.
    try:
        result = es.get(index=index_name, id=document_id)
        document = result.get('_source', {})
        print(f"Retrieved Document:\n {document}")
        return document
    except Exception as e:
        print("An error occurred:", e)
        # BUG FIX: the original fell off the end and returned None, which
        # made callers crash on membership tests like `"outputs" in None`.
        return {}

def proc_dict(self, proc):
return {
Expand Down Expand Up @@ -169,9 +234,14 @@ def exec_job(self, proc_id, job_params):
"job_publisher": self._job_publisher
}
ades_resp = self._ades.exec_job(job_spec)
# ades_resp will return platform specific information that should be
job_id = ades_resp.get("job_id")
inputs = ades_resp.get("inputs")
job_status = ades_resp.get("status")
# Update jobs database
self._update_jobs_database(job_id, proc_id, job_status, inputs)
# ades_resp will return platform specific information that should be
# kept in the database with the job ID record
sqlite_exec_job(proc_id, ades_resp["job_id"], ades_resp["inputs"], ades_resp)
sqlite_exec_job(proc_id, job_id, inputs, ades_resp)
return {"code": 201, "location": "{}/processes/{}/jobs/{}".format(self.host, proc_id, ades_resp["job_id"])}

def dismiss_job(self, proc_id, job_id):
Expand All @@ -189,19 +259,24 @@ def dismiss_job(self, proc_id, job_id):
return job_spec

def get_job_results(self, proc_id, job_id):
    """Build the job-results payload from the job's Jobs-DB document.

    :param proc_id: process ID (currently not verified against the job)
    :param job_id: job ID whose outputs are collected
    :return: dict with an "outputs" list of {mimeType, href, id} entries
    """
    # TODO: Add verification that job_id corresponds to a job of process
    # type proc_id.
    job_doc = self._get_jobs_doc(job_id=job_id)
    job_result = dict()
    outputs = list()
    # Guard against a failed lookup: job_doc may be falsy/None.
    if job_doc and "outputs" in job_doc:
        # BUG FIX: the original used job_doc.get["outputs"], which
        # subscripts the bound method and raises TypeError at runtime.
        job_outputs = job_doc["outputs"]
        # BUG FIX: log the field that was actually retrieved (the original
        # dumped the still-empty `outputs` accumulator).
        print(f"Retrieved Output Field: {json.dumps(job_outputs)}")
        for product in job_outputs:
            prod_id = product
            prod_location = job_outputs.get(product).get("location")
            file_type = job_outputs.get(product).get("class")
            output = {
                "mimeType": file_type,
                "href": prod_location,
                "id": prod_id,
            }
            outputs.append(output)
    else:
        print("Output field not found in the document.")
    job_result["outputs"] = outputs
    return job_result
54 changes: 2 additions & 52 deletions flask_ades_wpst/ades_hysds.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,6 @@ def __init__(
def _generate_job_id_stub(self, qsub_stdout):
return ".".join(qsub_stdout.strip().split(".")[:2])

def _pbs_job_state_to_status_str(self, work_dir, job_state):
pbs_job_state_to_status = {
"Q": "accepted",
"R": "running",
"E": "running",
}
if job_state in pbs_job_state_to_status:
status = pbs_job_state_to_status[job_state]
elif job_state == "F":
# Job finished; need to check cwl-runner exit-code to determine
# if the job succeeded or failed. In the auto-generated, PBS job
# submission script, the exit code is saved to a file.
exit_code_fname = os.path.join(work_dir, self._exit_code_fname)
try:
with open(exit_code_fname, "r") as f:
d = json.loads(f.read())
exit_code = d["exit_code"]
if exit_code == 0:
status = "successful"
else:
status = "failed"
except:
status = "unknown-not-qref"
else:
# Encountered a PBS job state that is not supported.
status = "unknown-no-exit-code"
return status

def _construct_job_spec(self, cwl_wfl, wfl_inputs):
"""
create the job spec for a process to deploy
Expand Down Expand Up @@ -375,15 +347,6 @@ def exec_job(self, job_spec):
try:
# Publish job to JobPublisher passed in the job_spec
hysds_job = job.submit_job(queue="verdi-job_worker", priority=0, tag="test")
job = Job(
id=hysds_job.job_id,
status="submitted",
inputs=params,
outputs={},
labels=labels,
)

job_spec["job_publisher"].publish_job_change(job)

print(f"Submitted job with id {hysds_job.job_id}")

Expand All @@ -395,21 +358,8 @@ def exec_job(self, job_spec):
"error": None,
}
except Exception as ex:
# Publish job to JobPublisher passed in the job_spec
try:
job = Job(
id=hysds_job.job_id,
status="failed",
inputs=params,
outputs={},
labels=labels,
)
job_spec["job_publisher"].publish_job_change(job)
except (AttributeError, UnboundLocalError) as e:
print(f"Failed to publish job, no hysds job id:\n{e}")

error = ex
return {"job_id": hysds_job.job_id, "status": "failed", "inputs": params, "error": str(error)}
error = str(ex)
return {"job_id": hysds_job.job_id, "status": "failed", "inputs": params, "error": error}

def dismiss_job(self, proc_id, job_id):
# We can only dismiss jobs that were last in accepted or running state.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cwl-runner==1.0
docker==6.0.0
jsonschema==4.5.1
GitPython==3.1.29
elasticsearch==7.17.9
pydantic==1.10.7
boto3==1.26.118
backoff==2.2.1