diff --git a/bioblend/galaxy/__init__.py b/bioblend/galaxy/__init__.py index 937b23d28..5d535cd9b 100644 --- a/bioblend/galaxy/__init__.py +++ b/bioblend/galaxy/__init__.py @@ -3,6 +3,8 @@ """ from typing import Optional +import requests + from bioblend.galaxy import ( config, container_resolution, @@ -39,6 +41,7 @@ def __init__( email: Optional[str] = None, password: Optional[str] = None, verify: bool = True, + session: Optional[requests.Session] = None, ) -> None: """ A base representation of a connection to a Galaxy instance, identified @@ -81,7 +84,7 @@ def __init__( :param verify: Whether to verify the server's TLS certificate :type verify: bool """ - super().__init__(url, key, email, password, verify=verify) + super().__init__(url, key, email, password, verify=verify, session=session) self.libraries = libraries.LibraryClient(self) self.histories = histories.HistoryClient(self) self.workflows = workflows.WorkflowClient(self) diff --git a/bioblend/galaxy/objects/galaxy_instance.py b/bioblend/galaxy/objects/galaxy_instance.py index ddde94831..08e51526f 100644 --- a/bioblend/galaxy/objects/galaxy_instance.py +++ b/bioblend/galaxy/objects/galaxy_instance.py @@ -9,6 +9,8 @@ Optional, ) +import requests + import bioblend import bioblend.galaxy from bioblend.galaxy.datasets import TERMINAL_STATES @@ -57,8 +59,9 @@ def __init__( email: Optional[str] = None, password: Optional[str] = None, verify: bool = True, + session: Optional[requests.Session] = None, ) -> None: - self.gi = bioblend.galaxy.GalaxyInstance(url, api_key, email, password, verify) + self.gi = bioblend.galaxy.GalaxyInstance(url, api_key, email, password, verify, session=session) self.log = bioblend.log self.datasets = client.ObjDatasetClient(self) self.dataset_collections = client.ObjDatasetCollectionClient(self) diff --git a/bioblend/galaxyclient.py b/bioblend/galaxyclient.py index 3a6c168e0..3c81f0773 100644 --- a/bioblend/galaxyclient.py +++ b/bioblend/galaxyclient.py @@ -38,6 +38,7 @@ def __init__( password: Optional[str] = None, verify: bool = True, timeout: Optional[float] = None, + session: Optional[requests.Session] = None, ) -> None: """ :param verify: Whether to verify the server's TLS certificate @@ -47,6 +48,11 @@ def __init__( """ self.verify = verify self.timeout = timeout + + if session is None: + session = requests.Session() + self.session = session + # Make sure the URL scheme is defined (otherwise requests will not work) if not url.lower().startswith("http"): found_scheme = None @@ -54,7 +60,7 @@ def __init__( for scheme in ("https://", "http://"): log.warning(f"Missing scheme in url, trying with {scheme}") with contextlib.suppress(requests.RequestException): - r = requests.get( + r = session.get( scheme + url, timeout=self.timeout, verify=self.verify, @@ -133,7 +139,7 @@ def make_get_request(self, url: str, **kwargs: Any) -> requests.Response: headers = self.json_headers kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("verify", self.verify) - r = requests.get(url, headers=headers, **kwargs) + r = self.session.get(url, headers=headers, **kwargs) return r def make_post_request( @@ -173,8 +179,7 @@ def my_dumps(d: dict) -> dict: data = json.dumps(payload) if payload is not None else None headers = self.json_headers post_params = params - - r = requests.post( + r = self.session.post( url, params=post_params, data=data, @@ -214,7 +219,7 @@ def make_delete_request( """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.delete( + r = self.session.delete( url, params=params, data=data, @@ -236,7 +241,7 @@ def make_put_request(self, url: str, payload: Optional[dict] = None, params: Opt """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.put( + r = self.session.put( url, params=params, data=data, @@ -272,7 +277,7 @@ def make_patch_request(self, url: str, payload: Optional[dict] = None, params: O """ data = json.dumps(payload) if payload is not None else None headers = self.json_headers - r = requests.patch( + r = self.session.patch( url, params=params, data=data, @@ -355,7 +360,7 @@ def key(self) -> Optional[str]: auth_url = f"{self.url}/authenticate/baseauth" # Use lower level method instead of make_get_request() because we # need the additional Authorization header. - r = requests.get( + r = self.session.get( auth_url, headers=headers, timeout=self.timeout, diff --git a/docs/examples/session_handling/cookie_handler.py b/docs/examples/session_handling/cookie_handler.py new file mode 100644 index 000000000..7b21065ce --- /dev/null +++ b/docs/examples/session_handling/cookie_handler.py @@ -0,0 +1,96 @@ +""" +Cookie handler for Authelia authentication using the Galaxy API +""" + +import getpass +import sys +from http.cookiejar import LWPCookieJar +from pathlib import Path +from pprint import pprint + +import requests +from galaxy_api import ( + get_inputs, + get_workflows, +) + +AUTH_HOSTNAME = "auth.service.org" +API_HOSTNAME = "galaxy.service.org" +cookie_path = Path(".galaxy_auth.txt") +cookie_jar = LWPCookieJar(cookie_path) + + +class ExpiredCookies(Exception): + pass + + +class NoCookies(Exception): + pass + + +def main(): + try: + cookie_jar.load() # raises OSError + if not cookie_jar: # if empty due to expirations + raise ExpiredCookies() + except OSError: + print("No cached session found, please authenticate") + prompt_authentication() + except ExpiredCookies: + print("Session has expired, please authenticate") + prompt_authentication() + run_examples() + + +def prompt_authentication(): + # -------------------------------------------------------------------------- + # Prompt for username and password + + username = input("Please enter username: ") + password = getpass.getpass(f"Please enter password for {username}: ") + + # -------------------------------------------------------------------------- + # Prepare authentication packet and authenticate session using Authelia + + login_body = { + "username": username, + "password": password, + "requestMethod": "GET", + "keepMeLoggedIn": True, + "targetURL": API_HOSTNAME, + } + + with requests.Session() as session: + session.cookies = cookie_jar + session.verify = True + + session.post(f"https://{AUTH_HOSTNAME}/api/firstfactor", json=login_body) + + response = session.get(f"https://{AUTH_HOSTNAME}/api/user/info") + if response.status_code != 200: + print("Authentication failed") + sys.exit() + else: + pprint(response.json()) + session.cookies.save() + + +def run_examples(): + GALAXY_KEY = "user_api_key" + WORKFLOW_NAME = "workflow_name" + with requests.Session() as session: + session.cookies = cookie_jar + + print("Running demo to demonstrate how to use the Galaxy API with Authelia") + + print("Getting workflows from Galaxy") + response = get_workflows(f"https://{API_HOSTNAME}", GALAXY_KEY, session=session) + print(response) + + print("Getting inputs for a workflow") + response = get_inputs(f"https://{API_HOSTNAME}", GALAXY_KEY, WORKFLOW_NAME, session=session) + print(response) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/session_handling/galaxy_api.py b/docs/examples/session_handling/galaxy_api.py new file mode 100644 index 000000000..c052fc609 --- /dev/null +++ b/docs/examples/session_handling/galaxy_api.py @@ -0,0 +1,58 @@ +from bioblend.galaxy import GalaxyInstance + + +def get_inputs(server, api_key, workflow_name, session=None): + """ + Function to get an array of inputs for a given galaxy workflow + + Usage: + get_inputs( + server = "galaxy.server.org", + api_key = "user_api_key", + workflow_name = "workflow_name", + ) + + Args: + server (string): Galaxy server address + api_key (string): User generated string from galaxy instance + to create: User > Preferences > Manage API Key > Create a new key + workflow_name (string): Target workflow name + Returns: + inputs (array of strings): Input files expected by the workflow, these will be in the same order as they should be given in the main API call + """ + + gi = GalaxyInstance(url=server, key=api_key, session=session) + api_workflow = gi.workflows.get_workflows(name=workflow_name) + steps = gi.workflows.export_workflow_dict(api_workflow[0]["id"])["steps"] + inputs = [] + for step in steps: + # Some of the steps don't take inputs so have to skip these + if len(steps[step]["inputs"]) > 0: + inputs.append(steps[step]["inputs"][0]["name"]) + + return inputs + + +def get_workflows(server, api_key, session=None): + """ + Function to get an array of workflows available on a given galaxy instance + + Usage: + get_workflows( + server = "galaxy.server.org", + api_key = "user_api_key", + ) + + Args: + server (string): Galaxy server address + api_key (string): User generated string from galaxy instance + to create: User > Preferences > Manage API Key > Create a new key + Returns: + workflows (array of strings): Workflows available to be run on the galaxy instance provided + """ + gi = GalaxyInstance(url=server, key=api_key, session=session) + workflows_dict = gi.workflows.get_workflows() + workflows = [] + for item in workflows_dict: + workflows.append(item["name"]) + return workflows