From c00846bd322f33e383e41c87cb6641e71523d240 Mon Sep 17 00:00:00 2001 From: tooyosi Date: Tue, 7 Apr 2026 17:11:59 +0100 Subject: [PATCH] Add metadata-driven BaJoR runtime options for contexts --- app/services/batch/context_runtime_options.rb | 41 ++++++ app/services/batch/prediction/create_job.rb | 9 +- app/services/batch/training/create_job.rb | 14 +- lib/bajor/client.rb | 12 +- spec/fixtures/contexts.yml | 14 +- spec/lib/bajor/client_spec.rb | 126 +++++++++++++++++- .../batch/prediction/create_job_spec.rb | 27 +++- .../batch/training/create_job_spec.rb | 27 +++- 8 files changed, 233 insertions(+), 37 deletions(-) create mode 100644 app/services/batch/context_runtime_options.rb diff --git a/app/services/batch/context_runtime_options.rb b/app/services/batch/context_runtime_options.rb new file mode 100644 index 0000000..14f576d --- /dev/null +++ b/app/services/batch/context_runtime_options.rb @@ -0,0 +1,41 @@ +# frozen_string_literal: true + +module Batch + class ContextRuntimeOptions + class << self + def for_training(context) + build(context, mode: :training) + end + + def for_prediction(context) + build(context, mode: :prediction) + end + + private + + def build(context, mode:) + metadata = context.metadata.is_a?(Hash) ? context.metadata : {} + batch_config = metadata['batch'].is_a?(Hash) ? metadata['batch'] : {} + + { + workflow_name: context.extractor_name, + fixed_crop: batch_config['fixed_crop'] || metadata['fixed_crop'], + n_blocks: batch_config['n_blocks'] || metadata['n_blocks'], + container_image_name: batch_config['container_image_name'], + training_script_path: batch_config['training_script_path'], + prediction_script_path: batch_config['prediction_script_path'], + promote_script_path: batch_config['promote_script_path'], + pretrained_checkpoint_url: batch_config['pretrained_checkpoint_url'] + }.compact.tap do |opts| + if mode == :training + opts.delete(:prediction_script_path) + else + opts.delete(:training_script_path) + opts.delete(:promote_script_path) + opts.delete(:n_blocks) + end + end + end + end + end +end diff --git a/app/services/batch/prediction/create_job.rb b/app/services/batch/prediction/create_job.rb index aa52975..820b24c 100644 --- a/app/services/batch/prediction/create_job.rb +++ b/app/services/batch/prediction/create_job.rb @@ -30,14 +30,7 @@ def prediction_options context = Context.find_by(active_subject_set_id: prediction_job.subject_set_id) return {} unless context - fixed_crop = if context.metadata.is_a?(Hash) && context.metadata['fixed_crop'].is_a?(Hash) - context.metadata['fixed_crop'] - end - - { - workflow_name: context.extractor_name, - fixed_crop: fixed_crop, - }.compact + Batch::ContextRuntimeOptions.for_prediction(context) end end end diff --git a/app/services/batch/training/create_job.rb b/app/services/batch/training/create_job.rb index c3d07b7..02a1496 100644 --- a/app/services/batch/training/create_job.rb +++ b/app/services/batch/training/create_job.rb @@ -29,19 +29,7 @@ def training_options context = Context.find_by(workflow_id: training_job.workflow_id) return {} unless context - fixed_crop = if context.metadata.is_a?(Hash) && context.metadata['fixed_crop'].is_a?(Hash) - context.metadata['fixed_crop'] - end - - n_blocks = if context.metadata.is_a?(Hash) && context.metadata['n_blocks'] - context.metadata['n_blocks'] - end - - { - workflow_name: context.extractor_name, - fixed_crop: fixed_crop, - n_blocks: n_blocks - }.compact + Batch::ContextRuntimeOptions.for_training(context) end end end diff --git a/lib/bajor/client.rb b/lib/bajor/client.rb index c734378..087aa1f 100644 --- a/lib/bajor/client.rb +++ b/lib/bajor/client.rb @@ -10,6 +10,16 @@ class TrainingJobTaskError < StandardError; end JSON_HEADERS = { 'Content-Type' => 'application/json', 'Accept' => 'application/json' }.freeze DEFAULT_OPTIONS = { workflow_name: 'cosmic_dawn' }.freeze + BATCH_OPTION_KEYS = %i[ + workflow_name + fixed_crop + n_blocks + container_image_name + training_script_path + prediction_script_path + promote_script_path + pretrained_checkpoint_url + ].freeze include HTTParty @@ -138,7 +148,7 @@ def bajor_service_host def build_opts(options, include_schema=true) raw = options.with_indifferent_access - overrides = raw.symbolize_keys.slice(:workflow_name, :fixed_crop, :n_blocks) + overrides = raw.symbolize_keys.slice(*BATCH_OPTION_KEYS) DEFAULT_OPTIONS .merge(overrides) diff --git a/spec/fixtures/contexts.yml b/spec/fixtures/contexts.yml index 8b74e77..414e648 100644 --- a/spec/fixtures/contexts.yml +++ b/spec/fixtures/contexts.yml @@ -20,10 +20,10 @@ galaxy_zoo_euclid_active_learning_project: galaxy_zoo_cosmos_active_learning_project: id: 3 - workflow_id: 133 + workflow_id: 134 project_id: 41 - active_subject_set_id: 56 - pool_subject_set_id: 67 + active_subject_set_id: 57 + pool_subject_set_id: 68 module_name: 'galaxy_zoo' extractor_name: 'jwst_cosmos' metadata: { @@ -32,5 +32,13 @@ galaxy_zoo_cosmos_active_learning_project: 'lower_left_y': 30, 'upper_right_x': 750, 'upper_right_y': 750 + }, + 'batch': { + 'container_image_name': 'zoobot.azurecr.io/pytorch:custom-jwst', + 'training_script_path': 'jwst/train_model_finetune_on_catalog.py', + 'prediction_script_path': 'jwst/predict_catalog_with_model.py', + 'promote_script_path': 'jwst/promote_best_checkpoint_to_model.sh', + 'pretrained_checkpoint_url': 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt', + 'n_blocks': 2 } } diff --git a/spec/lib/bajor/client_spec.rb b/spec/lib/bajor/client_spec.rb index 2fd6881..9b4787d 100644 --- a/spec/lib/bajor/client_spec.rb +++ b/spec/lib/bajor/client_spec.rb @@ -3,11 +3,17 @@ require 'bajor/client' require 'rails_helper' -def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, fixed_crop: nil, n_blocks: nil) +def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, fixed_crop: nil, n_blocks: nil, container_image_name: nil, training_script_path: nil, prediction_script_path: nil, promote_script_path: nil, pretrained_checkpoint_url: nil) opts = { workflow_name: workflow_name } + opts[:container_image_name] = container_image_name if container_image_name + opts[:training_script_path] = training_script_path if training_script_path + opts[:prediction_script_path] = prediction_script_path if prediction_script_path + opts[:promote_script_path] = promote_script_path if promote_script_path + opts[:pretrained_checkpoint_url] = pretrained_checkpoint_url if pretrained_checkpoint_url + run_opts = [] run_opts << "--schema #{workflow_name}" if manifest_path run_opts << "--fixed-crop '#{fixed_crop.to_json}'" if fixed_crop @@ -155,6 +161,65 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f end end + context 'with metadata-driven batch config' do + let(:workflow_name) { 'jwst_cosmos' } + let(:fixed_crop) do + { + lower_left_x: 30, + lower_left_y: 30, + upper_right_x: 750, + upper_right_y: 750 + } + end + let(:n_blocks) { 2 } + let(:container_image_name) { 'zoobot.azurecr.io/pytorch:custom-jwst' } + let(:training_script_path) { 'jwst/train_model_finetune_on_catalog.py' } + let(:promote_script_path) { 'jwst/promote_best_checkpoint_to_model.sh' } + let(:pretrained_checkpoint_url) { 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt' } + let(:expected_body) do + build_expected_body( + manifest_path: catalogue_manifest_path, + workflow_name: workflow_name, + fixed_crop: fixed_crop, + n_blocks: n_blocks, + container_image_name: container_image_name, + training_script_path: training_script_path, + promote_script_path: promote_script_path, + pretrained_checkpoint_url: pretrained_checkpoint_url + ) + end + let(:request) do + stub_request(:post, request_url) + .with( + body: expected_body.to_json, + headers: request_headers + ) + end + + before do + request.to_return(status: 201, body: bajor_response_body.to_json, headers: { content_type: 'application/json' }) + end + + it 'serializes metadata-driven training config into the request body' do + bajor_client.create_training_job( + catalogue_manifest_path, + { + workflow_name: workflow_name, + fixed_crop: fixed_crop, + n_blocks: n_blocks, + container_image_name: container_image_name, + training_script_path: training_script_path, + promote_script_path: promote_script_path, + pretrained_checkpoint_url: pretrained_checkpoint_url + } + ) + + expect( + a_request(:post, request_url).with(body: expected_body, headers: request_headers) + ).to have_been_made.once + end + end + context 'with jswt_cosmos workflow and n_blocks' do let(:workflow_name) { 'jswt_cosmos' } let(:n_blocks) { 2 } @@ -285,13 +350,13 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f let(:workflow_name) { 'euclid' } let(:request_body) do - { manifest_url: manifest_url, opts: { workflow_name:} } + { manifest_url: manifest_url, opts: { workflow_name: workflow_name } } end let(:request) do stub_request(:post, request_url) .with( - body: { manifest_url: manifest_url, opts: { workflow_name:} }, + body: { manifest_url: manifest_url, opts: { workflow_name: workflow_name } }, headers: request_headers ) end @@ -301,7 +366,7 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f end it 'sends the right workflow name' do - bajor_client.create_prediction_job(manifest_url, { workflow_name: }) + bajor_client.create_prediction_job(manifest_url, { workflow_name: workflow_name }) expect( a_request(:post, request_url).with(body: request_body, headers: request_headers) ).to have_been_made.once @@ -345,6 +410,59 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f end end + context 'with metadata-driven prediction config' do + let(:workflow_name) { 'jwst_cosmos' } + let(:fixed_crop) do + { + lower_left_x: 30, + lower_left_y: 30, + upper_right_x: 750, + upper_right_y: 750 + } + end + let(:container_image_name) { 'zoobot.azurecr.io/pytorch:custom-jwst' } + let(:prediction_script_path) { 'jwst/predict_catalog_with_model.py' } + let(:pretrained_checkpoint_url) { 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt' } + let(:request_body) do + build_expected_body( + manifest_url: manifest_url, + workflow_name: workflow_name, + fixed_crop: fixed_crop, + container_image_name: container_image_name, + prediction_script_path: prediction_script_path, + pretrained_checkpoint_url: pretrained_checkpoint_url + ) + end + let(:request) do + stub_request(:post, request_url) + .with( + body: request_body, + headers: request_headers + ) + end + + before do + request.to_return(status: 201, body: bajor_response_body.to_json, headers: { content_type: 'application/json' }) + end + + it 'serializes metadata-driven prediction config into the request body' do + bajor_client.create_prediction_job( + manifest_url, + { + workflow_name: workflow_name, + fixed_crop: fixed_crop, + container_image_name: container_image_name, + prediction_script_path: prediction_script_path, + pretrained_checkpoint_url: pretrained_checkpoint_url + } + ) + + expect( + a_request(:post, request_url).with(body: request_body, headers: request_headers) + ).to have_been_made.once + end + end + context 'with jswt_cosmos workflow and n_blocks' do let(:workflow_name) { 'jswt_cosmos' } let(:n_blocks) { 2 } diff --git a/spec/services/batch/prediction/create_job_spec.rb b/spec/services/batch/prediction/create_job_spec.rb index ffd00b8..0ec7700 100644 --- a/spec/services/batch/prediction/create_job_spec.rb +++ b/spec/services/batch/prediction/create_job_spec.rb @@ -21,10 +21,7 @@ let(:prediction_create_job) { described_class.new(prediction_job, bajor_client_double) } let(:job_service_url) { 'https://bajor-host/prediction/job/123' } let(:prediction_options){ - { - workflow_name: context.extractor_name, - fixed_crop: context.metadata['fixed_crop'], - }.compact + Batch::ContextRuntimeOptions.for_prediction(context) } context 'when bajor submission succeeds' do @@ -63,6 +60,28 @@ end end + describe 'prediction_job with nested batch metadata' do + let(:context){ contexts(:galaxy_zoo_cosmos_active_learning_project) } + let(:prediction_job) do + PredictionJob.new( + manifest_url: manifest_url, + state: :pending, + subject_set_id: context.active_subject_set_id, + probability_threshold: 0.5, + randomisation_factor: 0.5 + ) + end + + it 'passes metadata-driven prediction config through to bajor' do + prediction_create_job.run + + expect(bajor_client_double).to have_received(:create_prediction_job).with( + manifest_url, + Batch::ContextRuntimeOptions.for_prediction(context) + ).once + end + end + it 'updates the state tracking info on the prediction job resource' do expect { prediction_create_job.run diff --git a/spec/services/batch/training/create_job_spec.rb b/spec/services/batch/training/create_job_spec.rb index f7efb0a..8fc2795 100644 --- a/spec/services/batch/training/create_job_spec.rb +++ b/spec/services/batch/training/create_job_spec.rb @@ -5,6 +5,8 @@ RSpec.describe Batch::Training::CreateJob do describe '#run' do + fixtures :contexts + let(:manifest_path) { '/a/shared/blob/storage/path.csv' } let(:manifest_url) { "https://a.shared.blob.storage#{manifest_path}"} let(:training_job) { TrainingJob.new(manifest_url: manifest_url, workflow_id: '123', state: :pending) } @@ -22,10 +24,7 @@ it 'calls the bajor client service with the correct manifest_path' do parent_context = Context.find_by(workflow_id: training_job.workflow_id) - training_opts = { - workflow_name: parent_context.extractor_name, - fixed_crop: parent_context.metadata['fixed_crop'], - }.compact + training_opts = Batch::ContextRuntimeOptions.for_training(parent_context) training_create_job.run expect(bajor_client_double).to have_received(:create_training_job).with(manifest_path, training_opts).once @@ -52,5 +51,25 @@ .and change(training_job, :message).from('').to(error_message) end end + + context 'with metadata-driven batch config' do + let(:metadata_context) { contexts(:galaxy_zoo_cosmos_active_learning_project) } + let(:training_job) do + TrainingJob.new( + manifest_url: manifest_url, + workflow_id: metadata_context.workflow_id, + state: :pending + ) + end + + it 'passes nested batch metadata through to bajor' do + training_create_job.run + + expect(bajor_client_double).to have_received(:create_training_job).with( + manifest_path, + Batch::ContextRuntimeOptions.for_training(metadata_context) + ).once + end + end end end