diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 38532ae97..b4ca03f7f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -54,7 +54,9 @@ jobs: dir: cpp env: {} - lang: Python - cmd: pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report + cmd: pytest tests -m "not worker" --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report + worker-cmd: pytest tests -m "worker" --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report + secure-worker: true dir: python env: {} - lang: Rust @@ -198,7 +200,12 @@ jobs: export Http__Endpoint=http://localhost:4999 export GrpcClient__Endpoint=http://localhost:5000 export GrpcClient__AllowUnsafeConnection=true + export ComputePlane__AgentChannel__Address=http://localhost:5000 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: TLS Insecure if: ${{ contains(fromJson('["C#", "Rust"]'), matrix.language.lang) }} @@ -214,7 +221,12 @@ jobs: export Grpc__AllowUnsafeConnection=true export GrpcClient__Endpoint=https://localhost:5001 export GrpcClient__AllowUnsafeConnection=true + export ComputePlane__AgentChannel__Address=https://localhost:5001 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: TLS secure working-directory: packages/csharp/ @@ -231,7 +243,12 @@ jobs: export GrpcClient__Endpoint=https://localhost:5001 export GrpcClient__AllowUnsafeConnection=false export GrpcClient__CaCert=$CertFolder/server1-ca.pem + export ComputePlane__AgentChannel__Address=https://localhost:5001 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: TLS store working-directory: packages/csharp/ @@ -245,7 +262,12 @@ jobs: export Http__Endpoint=https://localhost:5002 export GrpcClient__Endpoint=https://localhost:5002 export GrpcClient__AllowUnsafeConnection=false + export ComputePlane__AgentChannel__Address=https://localhost:5002 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: mTLS Insecure if: ${{ contains(fromJson('["C#", "Rust"]'), matrix.language.lang) }} @@ -265,7 +287,12 @@ jobs: export GrpcClient__AllowUnsafeConnection=true export GrpcClient__CertPem=$CertFolder/client.pem export GrpcClient__KeyPem=$CertFolder/client.key + export ComputePlane__AgentChannel__Address=https://localhost:5003 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: mTLS secure working-directory: packages/csharp/ @@ -287,7 +314,12 @@ jobs: export GrpcClient__CaCert=$CertFolder/server1-ca.pem export GrpcClient__CertPem=$CertFolder/client.pem export GrpcClient__KeyPem=$CertFolder/client.key + export ComputePlane__AgentChannel__Address=https://localhost:5003 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: mTLS store working-directory: packages/csharp/ @@ -307,7 +339,12 @@ jobs: export GrpcClient__AllowUnsafeConnection=false export GrpcClient__CertPem=$CertFolder/client.pem export GrpcClient__KeyPem=$CertFolder/client.key + export ComputePlane__AgentChannel__Address=https://localhost:5004 + export ComputePlane__AgentChannel__SocketType=tcp + export ComputePlane__WorkerChannel__Address=http://localhost:5010 + export ComputePlane__WorkerChannel__SocketType=tcp ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}' + [ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}' - name: Test Report uses: dorny/test-reporter@v1 diff --git a/packages/csharp/ArmoniK.Api.Mock/Program.cs b/packages/csharp/ArmoniK.Api.Mock/Program.cs index 2f5fd31c4..c192421f9 100644 --- a/packages/csharp/ArmoniK.Api.Mock/Program.cs +++ b/packages/csharp/ArmoniK.Api.Mock/Program.cs @@ -14,15 +14,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using System.Collections.Generic; using System.IO; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; +using ArmoniK.Api.Common.Channel.Utils; +using ArmoniK.Api.Common.Options; +using ArmoniK.Api.gRPC.V1.Worker; using ArmoniK.Api.Mock; using ArmoniK.Api.Mock.Services; +using Google.Protobuf; + using Microsoft.AspNetCore.Authentication.Certificate; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; @@ -52,8 +58,7 @@ .GetSection("Port") .Value ?? "5001"); -X509Certificate2? serverCert = null; -string? clientCertPath = null; +X509Certificate2? serverCert = null; var serverCertPath = builder.Configuration.GetSection("Http") .GetSection("Cert") @@ -62,9 +67,9 @@ var serverKeyPath = builder.Configuration.GetSection("Http") .GetSection("Key") .Value ?? ""; -clientCertPath = builder.Configuration.GetSection("Http") - .GetSection("ClientCert") - .Value; +var clientCertPath = builder.Configuration.GetSection("Http") + .GetSection("ClientCert") + .Value; if (string.IsNullOrEmpty(clientCertPath)) { clientCertPath = null; @@ -91,6 +96,17 @@ builder.Services.AddSingleton(service); } +var rawComputePlaneOptions = builder.Configuration.GetSection(ComputePlane.SettingSection); +var workerChannelOptions = rawComputePlaneOptions?.GetSection(ComputePlane.WorkerChannelSection); + +builder.Services.AddSingleton(_ => new GrpcChannel + { + Address = workerChannelOptions?.GetValue("Address") ?? "/cache/armonik_worker.sock", + SocketType = workerChannelOptions?.GetValue("SocketType") ?? GrpcSocketType.UnixDomainSocket, + }); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); + if (clientCertPath is not null) { builder.Services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme) @@ -192,6 +208,10 @@ CountingService.ResetCounters(); return Task.CompletedTask; }); +app.MapPost("/worker/process", + SendProcessRequest); +app.MapPost("/worker/healthcheck", + SendHealthCheck); app.Run(); @@ -212,3 +232,36 @@ async Task Calls(HttpContext context) context.Response.ContentType = "application/json"; await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(body)); } + +async Task SendProcessRequest(HttpContext context) +{ + var requestBody = Encoding.ASCII.GetString(await ReadAll(context.Request.Body)); + var requestInput = JsonConvert.DeserializeObject(requestBody); + var request = new JsonParser(JsonParser.Settings.Default).Parse(requestInput.Request); + foreach (var result in requestInput.Results) + { + await using var file = File.OpenWrite(Path.Join(request.DataFolder, + result.Key)); + if (requestInput.ResultsEncoding is ResultsEncoding.Base64) + { + await file.WriteAsync(Convert.FromBase64String(result.Value)); + } + else + { + await file.WriteAsync(Convert.FromHexString(result.Value)); + } + } + + var reply = await app.Services.GetRequiredService() + .ProcessRequest(request); + context.Response.ContentType = "application/json"; + await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(new JsonFormatter(JsonFormatter.Settings.Default).Format(reply))); +} + +async Task SendHealthCheck(HttpContext context) +{ + var reply = await app.Services.GetRequiredService() + .HealthCheckRequest(); + context.Response.ContentType = "application/json"; + await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(new JsonFormatter(JsonFormatter.Settings.Default).Format(reply))); +} diff --git a/packages/csharp/ArmoniK.Api.Mock/WorkerCallService.cs b/packages/csharp/ArmoniK.Api.Mock/WorkerCallService.cs new file mode 100644 index 000000000..9319673a2 --- /dev/null +++ b/packages/csharp/ArmoniK.Api.Mock/WorkerCallService.cs @@ -0,0 +1,63 @@ +// This file is part of the ArmoniK project +// +// Copyright (C) ANEO, 2021-2025.All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +using ArmoniK.Api.Common.Channel.Utils; +using ArmoniK.Api.gRPC.V1; +using ArmoniK.Api.gRPC.V1.Worker; + +namespace ArmoniK.Api.Mock; + +public class WorkerCallService(GrpcChannelProvider provider) +{ + public async Task HealthCheckRequest() + { + var channel = provider.Get(); + var workerClient = new Worker.WorkerClient(channel); + var reply = await workerClient.HealthCheckAsync(new Empty()); + await channel.ShutdownAsync() + .WaitAsync(CancellationToken.None) + .ConfigureAwait(false); + return reply; + } + + public async Task ProcessRequest(ProcessRequest request) + { + var channel = provider.Get(); + var workerClient = new Worker.WorkerClient(channel); + var reply = await workerClient.ProcessAsync(request); + await channel.ShutdownAsync() + .WaitAsync(CancellationToken.None) + .ConfigureAwait(false); + return reply; + } +} + +public enum ResultsEncoding +{ + Base64, + Hex, +} + +public struct WorkerCallServiceInputModel +{ + public string Request; + public Dictionary Results; + public ResultsEncoding ResultsEncoding; +} diff --git a/packages/python/pyproject.toml b/packages/python/pyproject.toml index 48e69c775..4f7f1f655 100644 --- a/packages/python/pyproject.toml +++ b/packages/python/pyproject.toml @@ -52,4 +52,7 @@ dev = [ [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", +] +markers = [ + "worker: worker tests (deselect with '-m \"not worker\"')" ] \ No newline at end of file diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index 4c4c14b88..21bf0350e 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -1,4 +1,6 @@ +import base64 import os + import pytest import requests @@ -12,8 +14,11 @@ ArmoniKVersions, ) from armonik.common.channel import create_channel, _find_bundle_path, _load_certificates +from armonik.protogen.common.worker_common_pb2 import ProcessRequest from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from typing import List, Union +from typing import List, Union, Dict, Any + +from google.protobuf.json_format import MessageToJson ca_cert = os.getenv("Grpc__CaCert") or os.getenv("GrpcClient__CaCert") or None client_cert = os.getenv("Grpc__ClientCert") or os.getenv("GrpcClient__CertPem") or None @@ -24,6 +29,8 @@ http_endpoint = os.getenv("Http__Endpoint", scheme + "://localhost:5000") calls_endpoint = http_endpoint + "/calls.json" reset_endpoint = http_endpoint + "/reset" +healthcheck_endpoint = http_endpoint + "/worker/healthcheck" +process_endpoint = http_endpoint + "/worker/process" data_folder = os.getcwd() request_ca = ca_cert if ca_cert is not None else _find_bundle_path() @@ -211,3 +218,50 @@ def all_rpc_called( print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.") return False return True + + +def call_me_with_healthcheck( + endpoint: str = healthcheck_endpoint, +) -> Union[str, Dict[str, Any]]: + """ + Call the worker for a health check. + Args: + endpoint: endpoint to call. + + Returns: + The result of the call. + """ + response = requests.post(endpoint, verify=request_ca, cert=request_certs) + response.raise_for_status() + if "json" in response.headers["content-type"]: + return response.json() + return response.text + + +def call_me_with_process( + request: ProcessRequest, results: Dict[str, bytes], endpoint: str = process_endpoint +) -> Union[str, Dict[str, Any]]: + """ + Call the worker for Process call. + Args: + request: Process request to send to the worker. + results: Task results used as data dependencies and payload + endpoint: endpoint to call. + + Returns: + The result of the call. + """ + response = requests.post( + endpoint, + verify=request_ca, + cert=request_certs, + json={ + "Request": MessageToJson(request), + "Results": {k: base64.b64encode(v).decode("ascii") for k, v in results.items()}, + "ResultsEncoding": "Base64", + }, + ) + response.raise_for_status() + if "json" in response.headers["content-type"]: + return response.json() + return response.text diff --git a/packages/python/tests/test_full_worker.py b/packages/python/tests/test_full_worker.py new file mode 100644 index 000000000..62024338a --- /dev/null +++ b/packages/python/tests/test_full_worker.py @@ -0,0 +1,140 @@ +import os +from datetime import timedelta + +import pytest +from armonik.common import TaskOptions, Output +from armonik.protogen.common.objects_pb2 import Configuration +from armonik.protogen.common.worker_common_pb2 import ProcessRequest +from armonik.worker import armonik_worker, TaskHandler + +from .conftest import ca_cert, client_cert, client_key + +from .conftest import ( + call_me_with_healthcheck, + call_me_with_process, + data_folder, + rpc_called, +) + +payload_id = "payload_id" +data_dependencies = ["dd_0_id", "dd_1_id"] +expected_output_keys = ["eok_0", "eok_1"] +token = "comm_token" +session_id = "session_id" +task_id = "task_id" + +data_chunk_max_size = 84000 +results = { + payload_id: b"payload", + data_dependencies[0]: b"dd_0", + data_dependencies[1]: b"dd_1", +} +task_options = TaskOptions( + max_duration=timedelta(minutes=5), + priority=3, + max_retries=10, + partition_id="partition_id", + application_name="application_name", + application_version="application_version", + application_namespace="application_namespace", + application_service="application_service", + options={"option1_key": "option1", "option2_key": "option2"}, +) + + +@armonik_worker( + agent_client_certificate=client_cert, + agent_client_key=client_key, + agent_certificate_authority=ca_cert, +) +def worker(task_handler: TaskHandler): + assert task_handler.payload == results[payload_id] + assert len(task_handler.data_dependencies) == len(data_dependencies) + assert task_handler.data_dependencies[data_dependencies[0]] == results[data_dependencies[0]] + assert ( + len( + set(task_handler.data_dependencies.values()).symmetric_difference( + (results[data_dependencies[0]], results[data_dependencies[1]]) + ) + ) + == 0 + ) + assert task_handler.task_options.max_duration == task_options.max_duration + assert task_handler.task_options.priority == task_options.priority + assert task_handler.task_options.max_retries == task_options.max_retries + assert task_handler.task_options.partition_id == task_options.partition_id + assert task_handler.task_options.application_name == task_options.application_name + assert task_handler.task_options.application_version == task_options.application_version + assert task_handler.task_options.application_namespace == task_options.application_namespace + assert task_handler.task_options.application_service == task_options.application_service + assert ( + len( + set(task_handler.task_options.options.keys()).symmetric_difference( + task_options.options.keys() + ) + ) + == 0 + ) + assert ( + len( + set(task_handler.task_options.options.values()).symmetric_difference( + task_options.options.values() + ) + ) + == 0 + ) + assert task_handler.configuration.data_chunk_max_size == data_chunk_max_size + assert ( + len(set(task_handler.expected_results).symmetric_difference(set(expected_output_keys))) == 0 + ) + assert task_handler.token == token + assert task_handler.session_id == session_id + assert task_handler.task_id == task_id + task_handler.send_results({k: k.encode() for k in expected_output_keys}) + return Output() + + +@pytest.fixture +def worker_server(clean_up): + worker.run(wait=False) + yield + worker.stop() + + +@pytest.fixture +def clean_up_data_folder(): + yield + for k in data_dependencies + expected_output_keys + [payload_id]: + p = os.path.join(data_folder, k) + if os.path.exists(p): + os.remove(p) + + +@pytest.mark.worker +class TestFullWorker: + def test_worker_healthcheck(self, worker_server): + _ = worker_server + reply = call_me_with_healthcheck() + assert isinstance(reply, dict), str(reply) + assert reply["status"] == "SERVING", str(reply) + + def test_worker_process(self, worker_server): + _ = worker_server + reply = call_me_with_process( + request=ProcessRequest( + communication_token=token, + session_id=session_id, + task_id=task_id, + task_options=task_options.to_message(), + expected_output_keys=expected_output_keys, + payload_id=payload_id, + data_dependencies=data_dependencies, + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=data_chunk_max_size), + ), + results=results, + ) + assert isinstance(reply, dict), str(reply) + assert reply.get("output", {}).get("ok") is not None, str(reply) + assert reply.get("output", {}).get("error") is None, str(reply) + assert rpc_called("Agent", "NotifyResultData") diff --git a/scripts/mock_test.sh b/scripts/mock_test.sh index 06e32817c..6708ff8ea 100755 --- a/scripts/mock_test.sh +++ b/scripts/mock_test.sh @@ -29,7 +29,7 @@ sleep 5 cd "$working_dir/$TEST_DIR" -$TEST_COMMAND || ret=$? +eval $TEST_COMMAND || ret=$? echo $server_pid kill $server_pid