Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ jobs:
dir: cpp
env: {}
- lang: Python
cmd: pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report
cmd: pytest tests -m "not worker" --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report
worker-cmd: pytest tests -m "worker" --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-append --cov-report xml:coverage.xml --cov-report html:coverage_report
secure-worker: true
dir: python
env: {}
- lang: Rust
Expand Down Expand Up @@ -198,7 +200,12 @@ jobs:
export Http__Endpoint=http://localhost:4999
export GrpcClient__Endpoint=http://localhost:5000
export GrpcClient__AllowUnsafeConnection=true
export ComputePlane__AgentChannel__Address=http://localhost:5000
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: TLS Insecure
if: ${{ contains(fromJson('["C#", "Rust"]'), matrix.language.lang) }}
Expand All @@ -214,7 +221,12 @@ jobs:
export Grpc__AllowUnsafeConnection=true
export GrpcClient__Endpoint=https://localhost:5001
export GrpcClient__AllowUnsafeConnection=true
export ComputePlane__AgentChannel__Address=https://localhost:5001
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: TLS secure
working-directory: packages/csharp/
Expand All @@ -231,7 +243,12 @@ jobs:
export GrpcClient__Endpoint=https://localhost:5001
export GrpcClient__AllowUnsafeConnection=false
export GrpcClient__CaCert=$CertFolder/server1-ca.pem
export ComputePlane__AgentChannel__Address=https://localhost:5001
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: TLS store
working-directory: packages/csharp/
Expand All @@ -245,7 +262,12 @@ jobs:
export Http__Endpoint=https://localhost:5002
export GrpcClient__Endpoint=https://localhost:5002
export GrpcClient__AllowUnsafeConnection=false
export ComputePlane__AgentChannel__Address=https://localhost:5002
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: mTLS Insecure
if: ${{ contains(fromJson('["C#", "Rust"]'), matrix.language.lang) }}
Expand All @@ -265,7 +287,12 @@ jobs:
export GrpcClient__AllowUnsafeConnection=true
export GrpcClient__CertPem=$CertFolder/client.pem
export GrpcClient__KeyPem=$CertFolder/client.key
export ComputePlane__AgentChannel__Address=https://localhost:5003
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: mTLS secure
working-directory: packages/csharp/
Expand All @@ -287,7 +314,12 @@ jobs:
export GrpcClient__CaCert=$CertFolder/server1-ca.pem
export GrpcClient__CertPem=$CertFolder/client.pem
export GrpcClient__KeyPem=$CertFolder/client.key
export ComputePlane__AgentChannel__Address=https://localhost:5003
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: mTLS store
working-directory: packages/csharp/
Expand All @@ -307,7 +339,12 @@ jobs:
export GrpcClient__AllowUnsafeConnection=false
export GrpcClient__CertPem=$CertFolder/client.pem
export GrpcClient__KeyPem=$CertFolder/client.key
export ComputePlane__AgentChannel__Address=https://localhost:5004
export ComputePlane__AgentChannel__SocketType=tcp
export ComputePlane__WorkerChannel__Address=http://localhost:5010
export ComputePlane__WorkerChannel__SocketType=tcp
../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.cmd }}'
[ -z '${{ matrix.language.worker-cmd }}' ] || [ '${{ matrix.language.secure-worker }}' = false ] || ../../scripts/mock_test.sh ${{ matrix.language.dir }} '${{ matrix.language.worker-cmd }}'

- name: Test Report
uses: dorny/test-reporter@v1
Expand Down
63 changes: 58 additions & 5 deletions packages/csharp/ArmoniK.Api.Mock/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;

using ArmoniK.Api.Common.Channel.Utils;
using ArmoniK.Api.Common.Options;
using ArmoniK.Api.gRPC.V1.Worker;
using ArmoniK.Api.Mock;
using ArmoniK.Api.Mock.Services;

using Google.Protobuf;

using Microsoft.AspNetCore.Authentication.Certificate;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
Expand Down Expand Up @@ -52,8 +58,7 @@
.GetSection("Port")
.Value ?? "5001");

X509Certificate2? serverCert = null;
string? clientCertPath = null;
X509Certificate2? serverCert = null;

var serverCertPath = builder.Configuration.GetSection("Http")
.GetSection("Cert")
Expand All @@ -62,9 +67,9 @@
var serverKeyPath = builder.Configuration.GetSection("Http")
.GetSection("Key")
.Value ?? "";
clientCertPath = builder.Configuration.GetSection("Http")
.GetSection("ClientCert")
.Value;
var clientCertPath = builder.Configuration.GetSection("Http")
.GetSection("ClientCert")
.Value;
if (string.IsNullOrEmpty(clientCertPath))
{
clientCertPath = null;
Expand All @@ -91,6 +96,17 @@
builder.Services.AddSingleton(service);
}

var rawComputePlaneOptions = builder.Configuration.GetSection(ComputePlane.SettingSection);
var workerChannelOptions = rawComputePlaneOptions?.GetSection(ComputePlane.WorkerChannelSection);

builder.Services.AddSingleton(_ => new GrpcChannel
{
Address = workerChannelOptions?.GetValue<string>("Address") ?? "/cache/armonik_worker.sock",
SocketType = workerChannelOptions?.GetValue<GrpcSocketType>("SocketType") ?? GrpcSocketType.UnixDomainSocket,
});
builder.Services.AddSingleton<GrpcChannelProvider>();
builder.Services.AddSingleton<WorkerCallService>();

if (clientCertPath is not null)
{
builder.Services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme)
Expand Down Expand Up @@ -192,6 +208,10 @@
CountingService.ResetCounters();
return Task.CompletedTask;
});
app.MapPost("/worker/process",
SendProcessRequest);
app.MapPost("/worker/healthcheck",
SendHealthCheck);

app.Run();

Expand All @@ -212,3 +232,36 @@ async Task Calls(HttpContext context)
context.Response.ContentType = "application/json";
await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(body));
}

async Task SendProcessRequest(HttpContext context)
{
var requestBody = Encoding.ASCII.GetString(await ReadAll(context.Request.Body));
var requestInput = JsonConvert.DeserializeObject<WorkerCallServiceInputModel>(requestBody);
var request = new JsonParser(JsonParser.Settings.Default).Parse<ProcessRequest>(requestInput.Request);
foreach (var result in requestInput.Results)
{
await using var file = File.OpenWrite(Path.Join(request.DataFolder,
result.Key));
if (requestInput.ResultsEncoding is ResultsEncoding.Base64)
{
await file.WriteAsync(Convert.FromBase64String(result.Value));
}
else
{
await file.WriteAsync(Convert.FromHexString(result.Value));
}
}

var reply = await app.Services.GetRequiredService<WorkerCallService>()
.ProcessRequest(request);
context.Response.ContentType = "application/json";
await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(new JsonFormatter(JsonFormatter.Settings.Default).Format(reply)));
}

async Task SendHealthCheck(HttpContext context)
{
var reply = await app.Services.GetRequiredService<WorkerCallService>()
.HealthCheckRequest();
context.Response.ContentType = "application/json";
await context.Response.Body.WriteAsync(Encoding.ASCII.GetBytes(new JsonFormatter(JsonFormatter.Settings.Default).Format(reply)));
}
63 changes: 63 additions & 0 deletions packages/csharp/ArmoniK.Api.Mock/WorkerCallService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// This file is part of the ArmoniK project
//
// Copyright (C) ANEO, 2021-2025.All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

using ArmoniK.Api.Common.Channel.Utils;
using ArmoniK.Api.gRPC.V1;
using ArmoniK.Api.gRPC.V1.Worker;

namespace ArmoniK.Api.Mock;

public class WorkerCallService(GrpcChannelProvider provider)
{
public async Task<HealthCheckReply> HealthCheckRequest()
{
var channel = provider.Get();
var workerClient = new Worker.WorkerClient(channel);
var reply = await workerClient.HealthCheckAsync(new Empty());
await channel.ShutdownAsync()
.WaitAsync(CancellationToken.None)
.ConfigureAwait(false);
return reply;
}

public async Task<ProcessReply> ProcessRequest(ProcessRequest request)
{
var channel = provider.Get();
var workerClient = new Worker.WorkerClient(channel);
var reply = await workerClient.ProcessAsync(request);
await channel.ShutdownAsync()
.WaitAsync(CancellationToken.None)
.ConfigureAwait(false);
return reply;
}
}

public enum ResultsEncoding
{
Base64,
Hex,
}

public struct WorkerCallServiceInputModel
{
public string Request;
public Dictionary<string, string> Results;
public ResultsEncoding ResultsEncoding;
}
3 changes: 3 additions & 0 deletions packages/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ dev = [
[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
]
markers = [
"worker: worker tests (deselect with '-m \"not worker\"')"
]
56 changes: 55 additions & 1 deletion packages/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import os

import pytest
import requests

Expand All @@ -12,8 +14,11 @@
ArmoniKVersions,
)
from armonik.common.channel import create_channel, _find_bundle_path, _load_certificates
from armonik.protogen.common.worker_common_pb2 import ProcessRequest
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from typing import List, Union
from typing import List, Union, Dict, Any

from google.protobuf.json_format import MessageToJson

ca_cert = os.getenv("Grpc__CaCert") or os.getenv("GrpcClient__CaCert") or None
client_cert = os.getenv("Grpc__ClientCert") or os.getenv("GrpcClient__CertPem") or None
Expand All @@ -24,6 +29,8 @@
http_endpoint = os.getenv("Http__Endpoint", scheme + "://localhost:5000")
calls_endpoint = http_endpoint + "/calls.json"
reset_endpoint = http_endpoint + "/reset"
healthcheck_endpoint = http_endpoint + "/worker/healthcheck"
process_endpoint = http_endpoint + "/worker/process"
data_folder = os.getcwd()

request_ca = ca_cert if ca_cert is not None else _find_bundle_path()
Expand Down Expand Up @@ -211,3 +218,50 @@ def all_rpc_called(
print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.")
return False
return True


def call_me_with_healthcheck(
endpoint: str = healthcheck_endpoint,
) -> Union[str, Dict[str, Any]]:
"""
Call the worker for a health check.
Args:
endpoint: endpoint to call.

Returns:
The result of the call.
"""
response = requests.post(endpoint, verify=request_ca, cert=request_certs)
response.raise_for_status()
if "json" in response.headers["content-type"]:
return response.json()
return response.text


def call_me_with_process(
request: ProcessRequest, results: Dict[str, bytes], endpoint: str = process_endpoint
) -> Union[str, Dict[str, Any]]:
"""
Call the worker for Process call.
Args:
request: Process request to send to the worker.
results: Task results used as data dependencies and payload
endpoint: endpoint to call.

Returns:
The result of the call.
"""
response = requests.post(
endpoint,
verify=request_ca,
cert=request_certs,
json={
"Request": MessageToJson(request),
"Results": {k: base64.b64encode(v).decode("ascii") for k, v in results.items()},
"ResultsEncoding": "Base64",
},
)
response.raise_for_status()
if "json" in response.headers["content-type"]:
return response.json()
return response.text
Loading