Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions app/services/batch/context_runtime_options.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# frozen_string_literal: true

module Batch
  # Derives the runtime options for a batch training or prediction job from a
  # context record.
  #
  # The context is expected to respond to +extractor_name+ and +metadata+.
  # Options are read from the nested metadata['batch'] hash; +fixed_crop+ and
  # +n_blocks+ additionally fall back to the top-level metadata keys. Options
  # whose value is nil are omitted from the result.
  class ContextRuntimeOptions
    # Option keys that are never included in the result for each job mode.
    TRAINING_EXCLUDED_KEYS = %i[prediction_script_path].freeze
    PREDICTION_EXCLUDED_KEYS = %i[training_script_path promote_script_path n_blocks].freeze
    private_constant :TRAINING_EXCLUDED_KEYS, :PREDICTION_EXCLUDED_KEYS

    class << self
      # Options for a training job (never includes :prediction_script_path).
      #
      # @param context [#extractor_name, #metadata]
      # @return [Hash{Symbol => Object}]
      def for_training(context)
        build(context, mode: :training)
      end

      # Options for a prediction job (never includes :training_script_path,
      # :promote_script_path or :n_blocks).
      #
      # @param context [#extractor_name, #metadata]
      # @return [Hash{Symbol => Object}]
      def for_prediction(context)
        build(context, mode: :prediction)
      end

      private

      # Assembles the option hash and strips the keys that do not apply to
      # +mode+. Any mode other than :training is treated as prediction.
      def build(context, mode:)
        metadata = hash_or_empty(context.metadata)
        batch_config = hash_or_empty(metadata['batch'])
        excluded = mode == :training ? TRAINING_EXCLUDED_KEYS : PREDICTION_EXCLUDED_KEYS

        {
          workflow_name: context.extractor_name,
          fixed_crop: fixed_crop_from(batch_config, metadata),
          n_blocks: batch_config['n_blocks'] || metadata['n_blocks'],
          container_image_name: batch_config['container_image_name'],
          training_script_path: batch_config['training_script_path'],
          prediction_script_path: batch_config['prediction_script_path'],
          promote_script_path: batch_config['promote_script_path'],
          pretrained_checkpoint_url: batch_config['pretrained_checkpoint_url']
        }.compact.reject { |key, _value| excluded.include?(key) }
      end

      # Returns +value+ when it is a Hash, otherwise an empty Hash, so that
      # missing or malformed metadata never raises on key lookup.
      def hash_or_empty(value)
        value.is_a?(Hash) ? value : {}
      end

      # Picks the crop definition, preferring the nested batch config over the
      # top-level metadata. Only Hash values are accepted; a crop of any other
      # type is ignored rather than forwarded to the job.
      def fixed_crop_from(batch_config, metadata)
        [batch_config['fixed_crop'], metadata['fixed_crop']].find { |candidate| candidate.is_a?(Hash) }
      end
    end
  end
end
9 changes: 1 addition & 8 deletions app/services/batch/prediction/create_job.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ def prediction_options
context = Context.find_by(active_subject_set_id: prediction_job.subject_set_id)
return {} unless context

fixed_crop = if context.metadata.is_a?(Hash) && context.metadata['fixed_crop'].is_a?(Hash)
context.metadata['fixed_crop']
end

{
workflow_name: context.extractor_name,
fixed_crop: fixed_crop,
}.compact
Batch::ContextRuntimeOptions.for_prediction(context)
end
end
end
Expand Down
14 changes: 1 addition & 13 deletions app/services/batch/training/create_job.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,7 @@ def training_options
context = Context.find_by(workflow_id: training_job.workflow_id)
return {} unless context

fixed_crop = if context.metadata.is_a?(Hash) && context.metadata['fixed_crop'].is_a?(Hash)
context.metadata['fixed_crop']
end

n_blocks = if context.metadata.is_a?(Hash) && context.metadata['n_blocks']
context.metadata['n_blocks']
end

{
workflow_name: context.extractor_name,
fixed_crop: fixed_crop,
n_blocks: n_blocks
}.compact
Batch::ContextRuntimeOptions.for_training(context)
end
end
end
Expand Down
12 changes: 11 additions & 1 deletion lib/bajor/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ class TrainingJobTaskError < StandardError; end

JSON_HEADERS = { 'Content-Type' => 'application/json', 'Accept' => 'application/json' }.freeze
DEFAULT_OPTIONS = { workflow_name: 'cosmic_dawn' }.freeze
BATCH_OPTION_KEYS = %i[
workflow_name
fixed_crop
n_blocks
container_image_name
training_script_path
prediction_script_path
promote_script_path
pretrained_checkpoint_url
].freeze

include HTTParty

Expand Down Expand Up @@ -138,7 +148,7 @@ def bajor_service_host

def build_opts(options, include_schema=true)
raw = options.with_indifferent_access
overrides = raw.symbolize_keys.slice(:workflow_name, :fixed_crop, :n_blocks)
overrides = raw.symbolize_keys.slice(*BATCH_OPTION_KEYS)

DEFAULT_OPTIONS
.merge(overrides)
Expand Down
14 changes: 11 additions & 3 deletions spec/fixtures/contexts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ galaxy_zoo_euclid_active_learning_project:

galaxy_zoo_cosmos_active_learning_project:
id: 3
workflow_id: 133
workflow_id: 134
project_id: 41
active_subject_set_id: 56
pool_subject_set_id: 67
active_subject_set_id: 57
pool_subject_set_id: 68
module_name: 'galaxy_zoo'
extractor_name: 'jwst_cosmos'
metadata: {
Expand All @@ -32,5 +32,13 @@ galaxy_zoo_cosmos_active_learning_project:
'lower_left_y': 30,
'upper_right_x': 750,
'upper_right_y': 750
},
'batch': {
'container_image_name': 'zoobot.azurecr.io/pytorch:custom-jwst',
'training_script_path': 'jwst/train_model_finetune_on_catalog.py',
'prediction_script_path': 'jwst/predict_catalog_with_model.py',
'promote_script_path': 'jwst/promote_best_checkpoint_to_model.sh',
'pretrained_checkpoint_url': 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt',
'n_blocks': 2
}
}
126 changes: 122 additions & 4 deletions spec/lib/bajor/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
require 'bajor/client'
require 'rails_helper'

def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, fixed_crop: nil, n_blocks: nil)
def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, fixed_crop: nil, n_blocks: nil, container_image_name: nil, training_script_path: nil, prediction_script_path: nil, promote_script_path: nil, pretrained_checkpoint_url: nil)
opts = {
workflow_name: workflow_name
}

opts[:container_image_name] = container_image_name if container_image_name
opts[:training_script_path] = training_script_path if training_script_path
opts[:prediction_script_path] = prediction_script_path if prediction_script_path
opts[:promote_script_path] = promote_script_path if promote_script_path
opts[:pretrained_checkpoint_url] = pretrained_checkpoint_url if pretrained_checkpoint_url

run_opts = []
run_opts << "--schema #{workflow_name}" if manifest_path
run_opts << "--fixed-crop '#{fixed_crop.to_json}'" if fixed_crop
Expand Down Expand Up @@ -155,6 +161,65 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f
end
end

context 'with metadata-driven batch config' do
let(:workflow_name) { 'jwst_cosmos' }
let(:fixed_crop) do
{
lower_left_x: 30,
lower_left_y: 30,
upper_right_x: 750,
upper_right_y: 750
}
end
let(:n_blocks) { 2 }
let(:container_image_name) { 'zoobot.azurecr.io/pytorch:custom-jwst' }
let(:training_script_path) { 'jwst/train_model_finetune_on_catalog.py' }
let(:promote_script_path) { 'jwst/promote_best_checkpoint_to_model.sh' }
let(:pretrained_checkpoint_url) { 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt' }
let(:expected_body) do
build_expected_body(
manifest_path: catalogue_manifest_path,
workflow_name: workflow_name,
fixed_crop: fixed_crop,
n_blocks: n_blocks,
container_image_name: container_image_name,
training_script_path: training_script_path,
promote_script_path: promote_script_path,
pretrained_checkpoint_url: pretrained_checkpoint_url
)
end
let(:request) do
stub_request(:post, request_url)
.with(
body: expected_body.to_json,
headers: request_headers
)
end

before do
request.to_return(status: 201, body: bajor_response_body.to_json, headers: { content_type: 'application/json' })
end

it 'serializes metadata-driven training config into the request body' do
bajor_client.create_training_job(
catalogue_manifest_path,
{
workflow_name: workflow_name,
fixed_crop: fixed_crop,
n_blocks: n_blocks,
container_image_name: container_image_name,
training_script_path: training_script_path,
promote_script_path: promote_script_path,
pretrained_checkpoint_url: pretrained_checkpoint_url
}
)

expect(
a_request(:post, request_url).with(body: expected_body, headers: request_headers)
).to have_been_made.once
end
end

context 'with jswt_cosmos workflow and n_blocks' do
let(:workflow_name) { 'jswt_cosmos' }
let(:n_blocks) { 2 }
Expand Down Expand Up @@ -285,13 +350,13 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f
let(:workflow_name) { 'euclid' }

let(:request_body) do
{ manifest_url: manifest_url, opts: { workflow_name:} }
{ manifest_url: manifest_url, opts: { workflow_name: workflow_name } }
end

let(:request) do
stub_request(:post, request_url)
.with(
body: { manifest_url: manifest_url, opts: { workflow_name:} },
body: { manifest_url: manifest_url, opts: { workflow_name: workflow_name } },
headers: request_headers
)
end
Expand All @@ -301,7 +366,7 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f
end

it 'sends the right workflow name' do
bajor_client.create_prediction_job(manifest_url, { workflow_name: })
bajor_client.create_prediction_job(manifest_url, { workflow_name: workflow_name })
expect(
a_request(:post, request_url).with(body: request_body, headers: request_headers)
).to have_been_made.once
Expand Down Expand Up @@ -345,6 +410,59 @@ def build_expected_body(manifest_url: nil, manifest_path: nil, workflow_name:, f
end
end

context 'with metadata-driven prediction config' do
let(:workflow_name) { 'jwst_cosmos' }
let(:fixed_crop) do
{
lower_left_x: 30,
lower_left_y: 30,
upper_right_x: 750,
upper_right_y: 750
}
end
let(:container_image_name) { 'zoobot.azurecr.io/pytorch:custom-jwst' }
let(:prediction_script_path) { 'jwst/predict_catalog_with_model.py' }
let(:pretrained_checkpoint_url) { 'https://kadeactivelearning.blob.core.windows.net/models/jwst/jwst-pretrained.ckpt' }
let(:request_body) do
build_expected_body(
manifest_url: manifest_url,
workflow_name: workflow_name,
fixed_crop: fixed_crop,
container_image_name: container_image_name,
prediction_script_path: prediction_script_path,
pretrained_checkpoint_url: pretrained_checkpoint_url
)
end
let(:request) do
stub_request(:post, request_url)
.with(
body: request_body,
headers: request_headers
)
end

before do
request.to_return(status: 201, body: bajor_response_body.to_json, headers: { content_type: 'application/json' })
end

it 'serializes metadata-driven prediction config into the request body' do
bajor_client.create_prediction_job(
manifest_url,
{
workflow_name: workflow_name,
fixed_crop: fixed_crop,
container_image_name: container_image_name,
prediction_script_path: prediction_script_path,
pretrained_checkpoint_url: pretrained_checkpoint_url
}
)

expect(
a_request(:post, request_url).with(body: request_body, headers: request_headers)
).to have_been_made.once
end
end

context 'with jswt_cosmos workflow and n_blocks' do
let(:workflow_name) { 'jswt_cosmos' }
let(:n_blocks) { 2 }
Expand Down
27 changes: 23 additions & 4 deletions spec/services/batch/prediction/create_job_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
let(:prediction_create_job) { described_class.new(prediction_job, bajor_client_double) }
let(:job_service_url) { 'https://bajor-host/prediction/job/123' }
let(:prediction_options){
{
workflow_name: context.extractor_name,
fixed_crop: context.metadata['fixed_crop'],
}.compact
Batch::ContextRuntimeOptions.for_prediction(context)
}

context 'when bajor submission succeeds' do
Expand Down Expand Up @@ -63,6 +60,28 @@
end
end

describe 'prediction_job with nested batch metadata' do
let(:context){ contexts(:galaxy_zoo_cosmos_active_learning_project) }
let(:prediction_job) do
PredictionJob.new(
manifest_url: manifest_url,
state: :pending,
subject_set_id: context.active_subject_set_id,
probability_threshold: 0.5,
randomisation_factor: 0.5
)
end

it 'passes metadata-driven prediction config through to bajor' do
prediction_create_job.run

expect(bajor_client_double).to have_received(:create_prediction_job).with(
manifest_url,
Batch::ContextRuntimeOptions.for_prediction(context)
).once
end
end

it 'updates the state tracking info on the prediction job resource' do
expect {
prediction_create_job.run
Expand Down
27 changes: 23 additions & 4 deletions spec/services/batch/training/create_job_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

RSpec.describe Batch::Training::CreateJob do
describe '#run' do
fixtures :contexts

let(:manifest_path) { '/a/shared/blob/storage/path.csv' }
let(:manifest_url) { "https://a.shared.blob.storage#{manifest_path}"}
let(:training_job) { TrainingJob.new(manifest_url: manifest_url, workflow_id: '123', state: :pending) }
Expand All @@ -22,10 +24,7 @@

it 'calls the bajor client service with the correct manifest_path' do
parent_context = Context.find_by(workflow_id: training_job.workflow_id)
training_opts = {
workflow_name: parent_context.extractor_name,
fixed_crop: parent_context.metadata['fixed_crop'],
}.compact
training_opts = Batch::ContextRuntimeOptions.for_training(parent_context)

training_create_job.run
expect(bajor_client_double).to have_received(:create_training_job).with(manifest_path, training_opts).once
Expand All @@ -52,5 +51,25 @@
.and change(training_job, :message).from('').to(error_message)
end
end

context 'with metadata-driven batch config' do
let(:metadata_context) { contexts(:galaxy_zoo_cosmos_active_learning_project) }
let(:training_job) do
TrainingJob.new(
manifest_url: manifest_url,
workflow_id: metadata_context.workflow_id,
state: :pending
)
end

it 'passes nested batch metadata through to bajor' do
training_create_job.run

expect(bajor_client_double).to have_received(:create_training_job).with(
manifest_path,
Batch::ContextRuntimeOptions.for_training(metadata_context)
).once
end
end
end
end
Loading