diff --git a/.dockerignore b/.dockerignore index c3265dd..31d11c5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,4 +3,9 @@ bin/ examples/ README.MD LICENCE -.github/ \ No newline at end of file +.github/ +.git/ +docs +tests +.dockerignore +.gitignore \ No newline at end of file diff --git a/.gitignore b/.gitignore index d00a2fb..77e7742 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ bin /stuff/ .DS_Store +.claude \ No newline at end of file diff --git a/examples/Dockerfile-cattery-tiny b/examples/Dockerfile-cattery-tiny index e6a54eb..ab62069 100644 --- a/examples/Dockerfile-cattery-tiny +++ b/examples/Dockerfile-cattery-tiny @@ -5,14 +5,20 @@ RUN apt-get update && apt-get install -y \ WORKDIR /action-runner -RUN curl -sL -o actions-runner-linux-x64-2.323.0.tar.gz https://github.com/actions/runner/releases/download/v2.323.0/actions-runner-linux-x64-2.323.0.tar.gz +ARG RUNNER_VERISON='2.333' +ENV RUNNER_VERISON=$RUNNER_VERISON + +RUN curl -sL -o actions-runner-linux-x64-${RUNNER_VERISON}.0.tar.gz https://github.com/actions/runner/releases/download/v${RUNNER_VERISON}.0/actions-runner-linux-x64-${RUNNER_VERISON}.0.tar.gz RUN ls -al -RUN tar xzf ./actions-runner-linux-x64-2.323.0.tar.gz +RUN tar xzf ./actions-runner-linux-x64-${RUNNER_VERISON}.0.tar.gz WORKDIR /cattery-agent COPY . . -RUN go build -o /action-runner/cattery/cattery +RUN ls -al /action-runner +RUN ls -al /cattery-agent +RUN ls -al . + -#ENTRYPOINT ["/action-runner/cattery/cattery", "agent", "-r","/action-runner", "-s", "http://10.10.10.116:5137"] \ No newline at end of file +RUN cd src && go build -o /action-runner/cattery/cattery \ No newline at end of file diff --git a/examples/docker-compose.yaml b/examples/docker-compose.yaml new file mode 100644 index 0000000..a8bb3c2 --- /dev/null +++ b/examples/docker-compose.yaml @@ -0,0 +1,39 @@ +version: '3.8' + +services: + mongo1: + image: mongo:latest + container_name: mongo1 + command: mongod --replSet rs0 --bind_ip_all + ports: + - "27017:27017" + volumes: + - mongo1-data:/data/db + networks: + - mongo-cluster + + mongo-config: + image: mongo:latest + container_name: mongo-config + depends_on: + - mongo1 + networks: + - mongo-cluster + # This command runs a script to initiate the replica set + command: > + bash -c "sleep 10 && mongosh --host mongo1:27017 < ttl { + logger.Warnf("Restart request for workflow %d expired (age: %v), deleting", req.WorkflowRunId, time.Since(req.CreatedAt)) + _ = wr.repository.DeleteRestartRequest(ctx, req.WorkflowRunId) + continue + } + + wr.handleRestartRequest(ctx, logger, req) } - log.Debugf("Restarting failed jobs for workflow run id %d", workflowRunId) - err = ghClient.RestartFailedJobs(repoName, workflowRunId) +} + +func (wr *WorkflowRestarter) handleRestartRequest(ctx context.Context, logger *log.Entry, req repositories.RestartRequest) { + ghClient, err := githubClient.NewGithubClientWithOrgName(req.OrgName) if err != nil { - log.Errorf("Failed to restart workflow run id %d: %v", workflowRunId, err) - return err + logger.Errorf("Failed to get GitHub client for org %s: %v", req.OrgName, err) + return } - log.Debugf("Successfully restarted failed jobs for workflow run id %d, removing restart request from DB", workflowRunId) - err = wr.repository.DeleteRestartRequest(workflowRunId) + + status, conclusion, err := ghClient.GetWorkflowRunStatus(req.RepoName, req.WorkflowRunId) if err != nil { - log.Errorf("Failed to delete restart request for workflow run id %d: %v", workflowRunId, err) - return err + logger.Errorf("Failed to get workflow run status for %d: %v", req.WorkflowRunId, err) + return } - log.Debugf("Finished restart request for workflow run id %d", workflowRunId) - return nil -} -// Cleanup clean db on cancelled or completed workflow runs -func (wr *WorkflowRestarter) Cleanup(workflowRunId int64, ghOrg string, repoName string) error { - log.Debugf("Cleanup for workflow run id %d", workflowRunId) - log.Debugf("Checking restart request for workflow run id %d", workflowRunId) - exists, err := wr.repository.CheckRestartRequest(workflowRunId) - if err != nil { - log.Errorf("Failed to check restart request: %s", err.Error()) - return err + if status != "completed" { + return } - if !exists { - log.Debugf("No restart request found for workflow run id %d", workflowRunId) - return nil + + switch conclusion { + case "failure": + logger.Infof("Restarting failed jobs for workflow run %d (%s/%s)", req.WorkflowRunId, req.OrgName, req.RepoName) + err = ghClient.RestartFailedJobs(req.RepoName, req.WorkflowRunId) + if err != nil { + logger.Errorf("Failed to restart workflow run %d: %v", req.WorkflowRunId, err) + return + } + logger.Infof("Successfully restarted failed jobs for workflow run %d", req.WorkflowRunId) + default: + logger.Debugf("Workflow run %d completed with conclusion '%s', cleaning up restart request", req.WorkflowRunId, conclusion) } - log.Debugf("Successfully cleaned up restart request for workflow run id %d, removing restart request from DB", workflowRunId) - err = wr.repository.DeleteRestartRequest(workflowRunId) - if err != nil { - log.Errorf("Failed to delete restart request for workflow run id %d: %v", workflowRunId, err) - return err + + if err := wr.repository.DeleteRestartRequest(ctx, req.WorkflowRunId); err != nil { + logger.Errorf("Failed to delete restart request for workflow %d: %v", req.WorkflowRunId, err) } - log.Debugf("Finished cleanup restart request for workflow run id %d", workflowRunId) - return nil } diff --git a/src/lib/scaleSetClient/scaleSetClient.go b/src/lib/scaleSetClient/scaleSetClient.go new file mode 100644 index 0000000..90602fb --- /dev/null +++ b/src/lib/scaleSetClient/scaleSetClient.go @@ -0,0 +1,145 @@ +package scaleSetClient + +import ( + "cattery/lib/config" + "context" + "fmt" + "os" + "strings" + "time" + + "github.com/actions/scaleset" + log "github.com/sirupsen/logrus" +) + +type ScaleSetClient struct { + client *scaleset.Client + session *scaleset.MessageSessionClient + scaleSet *scaleset.RunnerScaleSet + org *config.GitHubOrganization + trayType *config.TrayType + logger *log.Entry +} + +func NewScaleSetClient(org *config.GitHubOrganization, trayType *config.TrayType) (*ScaleSetClient, error) { + privateKey, err := os.ReadFile(org.PrivateKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %w", err) + } + + client, err := scaleset.NewClientWithGitHubApp(scaleset.ClientWithGitHubAppConfig{ + GitHubConfigURL: fmt.Sprintf("https://github.com/%s", org.Name), + GitHubAppAuth: scaleset.GitHubAppAuth{ + ClientID: org.AppClientId, + InstallationID: org.InstallationId, + PrivateKey: string(privateKey), + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create scale set client: %w", err) + } + + return &ScaleSetClient{ + client: client, + org: org, + trayType: trayType, + logger: log.WithFields(log.Fields{ + "component": "scaleSetClient", + "trayType": trayType.Name, + "org": org.Name, + }), + }, nil +} + +func (sc *ScaleSetClient) EnsureScaleSet(ctx context.Context) error { + existing, err := sc.client.GetRunnerScaleSet(ctx, int(sc.trayType.RunnerGroupId), sc.trayType.Name) + if err != nil { + return fmt.Errorf("failed to get scale set: %w", err) + } + if existing != nil { + sc.scaleSet = existing + sc.logger.Infof("Found existing scale set: %s (ID: %d)", existing.Name, existing.ID) + return nil + } + + sc.logger.Infof("Creating new scale set: %s", sc.trayType.Name) + created, err := sc.client.CreateRunnerScaleSet(ctx, &scaleset.RunnerScaleSet{ + Name: sc.trayType.Name, + RunnerGroupID: int(sc.trayType.RunnerGroupId), + Labels: []scaleset.Label{ + {Name: sc.trayType.Name, Type: "User"}, + }, + }) + if err != nil { + return fmt.Errorf("failed to create scale set: %w", err) + } + + sc.scaleSet = created + sc.logger.Infof("Created scale set: %s (ID: %d)", created.Name, created.ID) + return nil +} + +func (sc *ScaleSetClient) CreateSession(ctx context.Context) error { + hostname, _ := os.Hostname() + + const maxRetries = 5 + const retryDelay = 30 * time.Second + + for attempt := range maxRetries { + session, err := sc.client.MessageSessionClient(ctx, sc.scaleSet.ID, hostname) + if err == nil { + sc.session = session + sc.logger.Info("Message session created") + return nil + } + + if !strings.Contains(err.Error(), "409 Conflict") || attempt == maxRetries-1 { + return fmt.Errorf("failed to create message session: %w", err) + } + + sc.logger.Warnf("Session conflict (attempt %d/%d), stale session likely exists — retrying in %v", attempt+1, maxRetries, retryDelay) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryDelay): + } + } + + return fmt.Errorf("unreachable") +} + +func (sc *ScaleSetClient) Poll(ctx context.Context, lastMessageID int, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error) { + return sc.session.GetMessage(ctx, lastMessageID, maxCapacity) +} + +func (sc *ScaleSetClient) Ack(ctx context.Context, messageID int) error { + return sc.session.DeleteMessage(ctx, messageID) +} + +func (sc *ScaleSetClient) GenerateJitRunnerConfig(ctx context.Context, runnerName string) (*scaleset.RunnerScaleSetJitRunnerConfig, error) { + return sc.client.GenerateJitRunnerConfig(ctx, &scaleset.RunnerScaleSetJitRunnerSetting{ + Name: runnerName, + WorkFolder: "_work", + }, sc.scaleSet.ID) +} + +func (sc *ScaleSetClient) Close(ctx context.Context) error { + if sc.session != nil { + return sc.session.Close(ctx) + } + return nil +} + +func (sc *ScaleSetClient) GetScaleSetID() int { + if sc.scaleSet != nil { + return sc.scaleSet.ID + } + return 0 +} + +func (sc *ScaleSetClient) Session() scaleset.RunnerScaleSetSession { + if sc.session != nil { + return sc.session.Session() + } + return scaleset.RunnerScaleSetSession{} +} diff --git a/src/lib/scaleSetPoller/manager.go b/src/lib/scaleSetPoller/manager.go new file mode 100644 index 0000000..37e0d8e --- /dev/null +++ b/src/lib/scaleSetPoller/manager.go @@ -0,0 +1,27 @@ +package scaleSetPoller + +import "sync" + +type Manager struct { + mu sync.RWMutex + pollers map[string]*Poller + Wg sync.WaitGroup +} + +func NewManager() *Manager { + return &Manager{ + pollers: make(map[string]*Poller), + } +} + +func (m *Manager) Register(trayTypeName string, poller *Poller) { + m.mu.Lock() + defer m.mu.Unlock() + m.pollers[trayTypeName] = poller +} + +func (m *Manager) GetPoller(trayTypeName string) *Poller { + m.mu.RLock() + defer m.mu.RUnlock() + return m.pollers[trayTypeName] +} diff --git a/src/lib/scaleSetPoller/poller.go b/src/lib/scaleSetPoller/poller.go new file mode 100644 index 0000000..1b7dd2c --- /dev/null +++ b/src/lib/scaleSetPoller/poller.go @@ -0,0 +1,161 @@ +package scaleSetPoller + +import ( + "cattery/lib/metrics" + "cattery/lib/scaleSetClient" + "cattery/lib/trayManager" + "context" + "fmt" + "strconv" + "time" + + "cattery/lib/config" + + "github.com/actions/scaleset" + "github.com/actions/scaleset/listener" + log "github.com/sirupsen/logrus" +) + +type Poller struct { + client *scaleSetClient.ScaleSetClient + trayType *config.TrayType + trayManager *trayManager.TrayManager + logger *log.Entry +} + +func NewPoller( + client *scaleSetClient.ScaleSetClient, + trayType *config.TrayType, + tm *trayManager.TrayManager, +) *Poller { + return &Poller{ + client: client, + trayType: trayType, + trayManager: tm, + logger: log.WithFields(log.Fields{ + "component": "scaleSetPoller", + "trayType": trayType.Name, + }), + } +} + +func (p *Poller) Client() *scaleSetClient.ScaleSetClient { + return p.client +} + +func (p *Poller) Run(ctx context.Context) error { + p.logger.Info("Starting scale set poller") + + if err := p.client.EnsureScaleSet(ctx); err != nil { + return fmt.Errorf("failed to ensure scale set: %w", err) + } + + if err := p.client.CreateSession(ctx); err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + defer func() { + closeCtx, closeCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer closeCancel() + if err := p.client.Close(closeCtx); err != nil { + p.logger.Errorf("Failed to close session: %v", err) + } + }() + + scaleSetID := p.client.GetScaleSetID() + + scaler := &catteryScaler{poller: p} + + l, err := listener.New( + &sessionAdapter{client: p.client}, + listener.Config{ + ScaleSetID: scaleSetID, + MaxRunners: p.trayType.MaxTrays, + }, + listener.WithMetricsRecorder(scaler), + ) + if err != nil { + return fmt.Errorf("failed to create listener: %w", err) + } + + p.logger.Info("Entering listener loop") + return l.Run(ctx, scaler) +} + +// sessionAdapter adapts our ScaleSetClient to the listener.Client interface. +type sessionAdapter struct { + client *scaleSetClient.ScaleSetClient +} + +func (s *sessionAdapter) GetMessage(ctx context.Context, lastMessageID, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error) { + return s.client.Poll(ctx, lastMessageID, maxCapacity) +} + +func (s *sessionAdapter) DeleteMessage(ctx context.Context, messageID int) error { + return s.client.Ack(ctx, messageID) +} + +func (s *sessionAdapter) Session() scaleset.RunnerScaleSetSession { + return s.client.Session() +} + +// catteryScaler implements the listener.Scaler and listener.MetricsRecorder interfaces. +type catteryScaler struct { + poller *Poller +} + +// MetricsRecorder implementation. + +func (cs *catteryScaler) RecordStatistics(statistics *scaleset.RunnerScaleSetStatistic) {} +func (cs *catteryScaler) RecordJobStarted(msg *scaleset.JobStarted) {} +func (cs *catteryScaler) RecordJobCompleted(msg *scaleset.JobCompleted) {} +func (cs *catteryScaler) RecordDesiredRunners(count int) {} + +func (cs *catteryScaler) HandleDesiredRunnerCount(ctx context.Context, count int) (int, error) { + err := cs.poller.trayManager.ScaleForDemand(ctx, cs.poller.trayType, count) + if err != nil { + cs.poller.logger.Errorf("Failed to scale for demand (%d): %v", count, err) + return 0, err + } + + return cs.poller.trayManager.CountTrays(ctx, cs.poller.trayType.Name) +} + +func (cs *catteryScaler) HandleJobStarted(ctx context.Context, jobInfo *scaleset.JobStarted) error { + cs.poller.logger.Infof("Job started: %s on runner %s (workflow run %d)", + jobInfo.JobDisplayName, jobInfo.RunnerName, jobInfo.WorkflowRunID) + + jobID, _ := strconv.ParseInt(jobInfo.JobID, 10, 64) + + tray, err := cs.poller.trayManager.SetJob(ctx, jobInfo.RunnerName, jobID, jobInfo.WorkflowRunID, jobInfo.RepositoryName) + if err != nil { + cs.poller.logger.Errorf("Failed to set job on tray %s: %v", jobInfo.RunnerName, err) + return err + } + + if tray == nil { + cs.poller.logger.Warnf("Tray %s not found for job %s (workflow run %d) — tray already removed", + jobInfo.RunnerName, jobInfo.JobDisplayName, jobInfo.WorkflowRunID) + } + + return nil +} + +func (cs *catteryScaler) HandleJobCompleted(ctx context.Context, jobInfo *scaleset.JobCompleted) error { + cs.poller.logger.Infof("Job completed: %s on runner %s (result: %s)", + jobInfo.JobDisplayName, jobInfo.RunnerName, jobInfo.Result) + + if jobInfo.RunnerName == "" { + cs.poller.logger.Warnf("Job completed with empty runner name (result: %s, job: %s) — skipping tray deletion", + jobInfo.Result, jobInfo.JobDisplayName) + return nil + } + + _, err := cs.poller.trayManager.DeleteTray(ctx, jobInfo.RunnerName) + if err != nil { + cs.poller.logger.Errorf("Failed to delete tray %s: %v", jobInfo.RunnerName, err) + return err + } + + metrics.RegisteredTraysAdd(cs.poller.trayType.GitHubOrg, cs.poller.trayType.Name, -1) + return nil +} diff --git a/src/lib/trayManager/trayManager.go b/src/lib/trayManager/trayManager.go index 80dd02e..75e7f20 100644 --- a/src/lib/trayManager/trayManager.go +++ b/src/lib/trayManager/trayManager.go @@ -2,15 +2,13 @@ package trayManager import ( "cattery/lib/config" - "cattery/lib/githubClient" - "cattery/lib/jobQueue" "cattery/lib/metrics" "cattery/lib/trays" "cattery/lib/trays/providers" "cattery/lib/trays/repositories" "context" - "errors" "fmt" + "sync" "time" log "github.com/sirupsen/logrus" @@ -18,58 +16,99 @@ import ( type TrayManager struct { trayRepository repositories.ITrayRepository - - isStaleTraysFound bool } func NewTrayManager(trayRepository repositories.ITrayRepository) *TrayManager { return &TrayManager{ - trayRepository: trayRepository, - isStaleTraysFound: false, + trayRepository: trayRepository, + } +} + +func (tm *TrayManager) createTrays(ctx context.Context, trayType *config.TrayType, count int) error { + maxParallel := trayType.MaxParallelCreation + if maxParallel <= 0 { + maxParallel = config.DefaultMaxParallelCreation + } + + results := tm.createTraysParallel(ctx, trayType, count, maxParallel) + return tm.logCreationResults(trayType.Name, results) +} + +// createTraysParallel creates trays concurrently, limited to maxParallel at a time. +// Returns a slice of errors, one per tray (nil means success). +func (tm *TrayManager) createTraysParallel(ctx context.Context, trayType *config.TrayType, count int, maxParallel int) []error { + var wg sync.WaitGroup + semaphore := make(chan struct{}, maxParallel) + errors := make([]error, count) + + for i := 0; i < count; i++ { + semaphore <- struct{}{} // block if maxParallel goroutines are already running + wg.Add(1) + + go func(index int) { + defer wg.Done() + defer func() { <-semaphore }() + + log.Infof("Creating tray %d/%d for type: %s", index+1, count, trayType.Name) + errors[index] = tm.CreateTray(ctx, trayType) + }(i) } + + wg.Wait() + return errors } -func (tm *TrayManager) createTrays(trayType *config.TrayType, n int) error { - for i := 0; i < n; i++ { - log.Infof("Creating tray %d for type: %s", i+1, trayType.Name) - err := tm.CreateTray(trayType) +func (tm *TrayManager) logCreationResults(trayTypeName string, results []error) error { + total := len(results) + failed := 0 + + for _, err := range results { if err != nil { - return err + log.Errorf("Failed to create tray for type %s: %v", trayTypeName, err) + failed++ } } + + if failed == total { + return fmt.Errorf("all %d tray creations failed for type %s", total, trayTypeName) + } + if failed > 0 { + log.Warnf("%d out of %d tray creations failed for type %s", failed, total, trayTypeName) + } + return nil } -func (tm *TrayManager) CreateTray(trayType *config.TrayType) error { - +func (tm *TrayManager) CreateTray(ctx context.Context, trayType *config.TrayType) error { provider, err := providers.GetProvider(trayType.Provider) if err != nil { - var errMsg = fmt.Sprintf("Failed to get provider for type %s: %v", trayType.Name, err) - log.Error(errMsg) - return errors.New(errMsg) + return fmt.Errorf("failed to get provider for type %s: %w", trayType.Name, err) } tray := trays.NewTray(*trayType) err = provider.RunTray(tray) if err != nil { - log.Errorf("Failed to run tray for provider '%s', tray '%s': %v", trayType.Provider, tray.GetId(), err) + log.Errorf("Failed to run tray for provider '%s', tray '%s': %v", trayType.Provider, tray.Id, err) metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "create") return err } - err = tm.trayRepository.Save(tray) + err = tm.trayRepository.Save(ctx, tray) if err != nil { - var errMsg = fmt.Sprintf("Failed to save tray %s: %v", trayType.Name, err) - log.Error(errMsg) - return errors.New(errMsg) + log.Errorf("Failed to save tray %s: %v — cleaning up provider resource", trayType.Name, err) + if cleanErr := provider.CleanTray(tray); cleanErr != nil { + log.Errorf("Failed to clean up tray %s after save failure: %v", tray.Id, cleanErr) + metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "delete") + } + return fmt.Errorf("failed to save tray %s: %w", trayType.Name, err) } return nil } -func (tm *TrayManager) GetTrayById(trayId string) (*trays.Tray, error) { - tray, err := tm.trayRepository.GetById(trayId) +func (tm *TrayManager) GetTrayById(ctx context.Context, trayId string) (*trays.Tray, error) { + tray, err := tm.trayRepository.GetById(ctx, trayId) if err != nil { return nil, err } @@ -80,63 +119,43 @@ func (tm *TrayManager) GetTrayById(trayId string) (*trays.Tray, error) { return tray, nil } -func (tm *TrayManager) Registering(trayId string) (*trays.Tray, error) { - tray, err := tm.trayRepository.UpdateStatus(trayId, trays.TrayStatusRegistering, 0, 0, 0) +func (tm *TrayManager) Registering(ctx context.Context, trayId string) (*trays.Tray, error) { + tray, err := tm.trayRepository.UpdateStatus(ctx, trayId, trays.TrayStatusRegistering, 0, 0, 0, "") if err != nil { return nil, err } if tray == nil { - var errorMsg = fmt.Sprintf("Failed to update tray status for tray '%s'", trayId) - return nil, errors.New(errorMsg) + return nil, fmt.Errorf("failed to update tray status for tray '%s'", trayId) } - return tray, nil } -func (tm *TrayManager) Registered(trayId string, ghRunnerId int64) (*trays.Tray, error) { - tray, err := tm.trayRepository.UpdateStatus(trayId, trays.TrayStatusRegistered, 0, 0, ghRunnerId) +func (tm *TrayManager) Registered(ctx context.Context, trayId string, ghRunnerId int64) (*trays.Tray, error) { + tray, err := tm.trayRepository.UpdateStatus(ctx, trayId, trays.TrayStatusRegistered, 0, 0, ghRunnerId, "") if err != nil { return nil, err } if tray == nil { - var errorMsg = fmt.Sprintf("Failed to update tray status for tray '%s'", trayId) - return nil, errors.New(errorMsg) + return nil, fmt.Errorf("failed to update tray status for tray '%s'", trayId) } - return tray, nil } -func (tm *TrayManager) SetJob(trayId string, jobRunId int64, workflowRunId int64) (*trays.Tray, error) { - tray, err := tm.trayRepository.UpdateStatus(trayId, trays.TrayStatusRunning, jobRunId, workflowRunId, 0) +func (tm *TrayManager) SetJob(ctx context.Context, trayId string, jobRunId int64, workflowRunId int64, repository string) (*trays.Tray, error) { + tray, err := tm.trayRepository.UpdateStatus(ctx, trayId, trays.TrayStatusRunning, jobRunId, workflowRunId, 0, repository) if err != nil { return nil, err } - if tray == nil { - var errorMsg = fmt.Sprintf("Failed to update tray status for tray '%s'", trayId) - return nil, errors.New(errorMsg) - } - return tray, nil } -func (tm *TrayManager) DeleteTray(trayId string) (*trays.Tray, error) { - - var tray, err = tm.trayRepository.UpdateStatus(trayId, trays.TrayStatusDeleting, 0, 0, 0) +func (tm *TrayManager) DeleteTray(ctx context.Context, trayId string) (*trays.Tray, error) { + tray, err := tm.trayRepository.UpdateStatus(ctx, trayId, trays.TrayStatusDeleting, 0, 0, 0, "") if err != nil { return nil, err } if tray == nil { - return nil, nil // Tray not found, nothing to delete - } - - ghClient, err := githubClient.NewGithubClientWithOrgName(tray.GetGitHubOrgName()) - if err != nil { - return nil, err - } - - err = ghClient.RemoveRunner(tray.GitHubRunnerId) - if err != nil { - return nil, err + return nil, nil } provider, err := providers.GetProviderForTray(tray) @@ -146,12 +165,12 @@ func (tm *TrayManager) DeleteTray(trayId string) (*trays.Tray, error) { err = provider.CleanTray(tray) if err != nil { - log.Errorf("Failed to delete tray for provider %s, tray %s: %v", provider.GetProviderName(), tray.GetId(), err) + log.Errorf("Failed to delete tray for provider %s, tray %s: %v", provider.GetProviderName(), tray.Id, err) metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "delete") return nil, err } - err = tm.trayRepository.Delete(trayId) + err = tm.trayRepository.Delete(ctx, trayId) if err != nil { return nil, err } @@ -160,8 +179,7 @@ func (tm *TrayManager) DeleteTray(trayId string) (*trays.Tray, error) { } func (tm *TrayManager) HandleStale(ctx context.Context) { - - var interval = time.Minute * 2 + interval := time.Minute * 2 go func() { for { @@ -169,10 +187,9 @@ func (tm *TrayManager) HandleStale(ctx context.Context) { case <-ctx.Done(): return default: - time.Sleep(interval / 2) - stale, err := tm.trayRepository.GetStale(interval, interval*2) + stale, err := tm.trayRepository.GetStale(ctx, interval) if err != nil { log.Errorf("Failed to get stale trays: %v", err) continue @@ -180,17 +197,13 @@ func (tm *TrayManager) HandleStale(ctx context.Context) { if len(stale) > 0 { log.Infof("Found %d stale trays: %v", len(stale), stale) - tm.isStaleTraysFound = true } for _, tray := range stale { - log.Debugf("Deleting stale tray: %s", tray.GetId()) - - _, err := tm.DeleteTray(tray.GetId()) - if err != nil { - log.Errorf("Failed to delete tray %s: %v", tray.GetId(), err) + log.Debugf("Deleting stale tray: %s", tray.Id) + if _, err := tm.DeleteTray(ctx, tray.Id); err != nil { + log.Errorf("Failed to delete tray %s: %v", tray.Id, err) } - metrics.StaleTraysInc(tray.GitHubOrgName, tray.TrayTypeName) } } @@ -198,84 +211,27 @@ func (tm *TrayManager) HandleStale(ctx context.Context) { }() } -func (tm *TrayManager) HandleJobsQueue(ctx context.Context, manager *jobQueue.QueueManager) { - go func() { - for { - select { - case <-ctx.Done(): - return - default: - - if tm.isStaleTraysFound { - err := manager.CleanupCompletedJobs() - if err != nil { - log.Errorf("Failed to cleanup completed jobs: %v", err) - } - tm.isStaleTraysFound = false - } - - var groups = manager.GetJobsCount() - for typeName, jobsCount := range groups { - err := tm.handleType(typeName, jobsCount) - if err != nil { - log.Error(err) - } - } - - time.Sleep(10 * time.Second) - } - } - }() -} - -func (tm *TrayManager) handleType(trayTypeName string, jobsInQueue int) error { - // log.Debugf("Handling tray type %s with %d jobs in queue", trayTypeName, jobsInQueue) - countByStatus, total, err := tm.trayRepository.CountByTrayType(trayTypeName) +// ScaleForDemand scales trays for a given tray type based on the desired runner count. +// Follows ARC's pattern: scale up when needed, let HandleJobCompleted and the stale +// handler take care of scale-down. No ghost detection — trust local tray state. +func (tm *TrayManager) ScaleForDemand(ctx context.Context, trayType *config.TrayType, desiredCount int) error { + activeCount, err := tm.CountTrays(ctx, trayType.Name) if err != nil { - log.Errorf("Failed to count trays for type %s: %v", trayTypeName, err) return err } - var traysWithNoJob = countByStatus[trays.TrayStatusCreating] + countByStatus[trays.TrayStatusRegistering] + countByStatus[trays.TrayStatusRegistered] - // log.Debugf("Tray type %s has %d trays, %d with no job", trayTypeName, total, traysWithNoJob) - if jobsInQueue > traysWithNoJob { - var trayType = getTrayType(trayTypeName) - if trayType == nil { - log.Warnf("Tray type '%s' not found in config; skipping creation", trayTypeName) - return nil - } - - var remainingTrays = trayType.MaxTrays - total - var traysToCreate = jobsInQueue - traysWithNoJob - if traysToCreate > remainingTrays { - traysToCreate = remainingTrays - } - - err := tm.createTrays(trayType, traysToCreate) - if err != nil { - return err - } + if desiredCount <= activeCount { + return nil } - if jobsInQueue < traysWithNoJob { - var traysToDelete = traysWithNoJob - jobsInQueue - redundant, err := tm.trayRepository.MarkRedundant(trayTypeName, traysToDelete) - if err != nil { - return err - } - - for _, tray := range redundant { - if _, delErr := tm.DeleteTray(tray.Id); delErr != nil { - log.Errorf("Failed to delete redundant tray %s: %v", tray.Id, delErr) - } - } - + traysToCreate := min(desiredCount-activeCount, trayType.MaxTrays-activeCount) + if traysToCreate > 0 { + return tm.createTrays(ctx, trayType, traysToCreate) } - return nil } -func getTrayType(trayTypeName string) *config.TrayType { - var trayType = config.AppConfig.GetTrayType(trayTypeName) - return trayType +// CountTrays returns the number of active (non-deleting) trays for a given tray type. +func (tm *TrayManager) CountTrays(ctx context.Context, trayTypeName string) (int, error) { + return tm.trayRepository.CountActive(ctx, trayTypeName) } diff --git a/src/lib/trays/providers/dockerProvider.go b/src/lib/trays/providers/dockerProvider.go index 102b8a3..93d2a3d 100644 --- a/src/lib/trays/providers/dockerProvider.go +++ b/src/lib/trays/providers/dockerProvider.go @@ -36,29 +36,21 @@ func (d *DockerProvider) GetProviderName() string { return d.name } -func (d *DockerProvider) GetTray(id string) (*trays.Tray, error) { - //TODO implement me - panic("implement me") -} - -func (d *DockerProvider) ListTrays() ([]*trays.Tray, error) { - //TODO implement me - panic("implement me") -} - func (d *DockerProvider) RunTray(tray *trays.Tray) error { - var containerName = tray.GetId() + var containerName = tray.Id - var trayConfig = tray.GetTrayConfig().(config.DockerTrayConfig) + var trayConfig = tray.TrayConfig().(config.DockerTrayConfig) var image = trayConfig.Image + var serverUrl = config.AppConfig.Server.AdvertiseUrl + var dockerCommand = exec.Command("docker", "run", "-d", "--rm", "--add-host=host.docker.internal:host-gateway", "--name", containerName, image, - "/action-runner/cattery/cattery", "agent", "-i", tray.GetId(), "-s", "http://host.docker.internal:5137", "--runner-folder", "/action-runner") + "/action-runner/cattery/cattery", "agent", "-i", tray.Id, "-s", serverUrl, "--runner-folder", "/action-runner") d.logger.Info("Running docker command: ", dockerCommand.String()) err := dockerCommand.Run() @@ -72,14 +64,14 @@ func (d *DockerProvider) RunTray(tray *trays.Tray) error { } func (d *DockerProvider) CleanTray(tray *trays.Tray) error { - var dockerCommand = exec.Command("docker", "container", "stop", tray.GetId()) + var dockerCommand = exec.Command("docker", "container", "stop", tray.Id) dockerCommandOutput, err := dockerCommand.CombinedOutput() if err != nil { output := string(dockerCommandOutput) d.logger.Trace(output) if strings.Contains(strings.ToLower(output), "no such container") { - d.logger.Trace("No such container: ", tray.GetId()) + d.logger.Trace("No such container: ", tray.Id) return nil } return err diff --git a/src/lib/trays/providers/gceProvider.go b/src/lib/trays/providers/gceProvider.go index 8076371..e8c2785 100644 --- a/src/lib/trays/providers/gceProvider.go +++ b/src/lib/trays/providers/gceProvider.go @@ -47,20 +47,10 @@ func (g *GceProvider) GetProviderName() string { return g.Name } -func (g *GceProvider) GetTray(id string) (*trays.Tray, error) { - //TODO implement me - panic("implement me") -} - -func (g *GceProvider) ListTrays() ([]*trays.Tray, error) { - //TODO implement me - panic("implement me") -} - func (g *GceProvider) RunTray(tray *trays.Tray) error { ctx := context.Background() - var trayConfig = tray.GetTrayConfig().(config.GoogleTrayConfig) + var trayConfig = tray.TrayConfig().(config.GoogleTrayConfig) var ( project = g.providerConfig.Get("project") @@ -69,23 +59,28 @@ func (g *GceProvider) RunTray(tray *trays.Tray) error { machineType = trayConfig.MachineType ) + var extraMetadata config.TrayExtraMetadata + if tt := tray.TrayType(); tt != nil { + extraMetadata = tt.ExtraMetadata + } + var metadata = createGcpMetadata( map[string]string{ "cattery-url": config.AppConfig.Server.AdvertiseUrl, - "cattery-agent-id": tray.GetId(), + "cattery-agent-id": tray.Id, }, - tray.GetTrayType().ExtraMetadata, + extraMetadata, ) var zone = zones[rand.Intn(len(zones))] - _, err := g.instanceClient.Insert(ctx, &computepb.InsertInstanceRequest{ + op, err := g.instanceClient.Insert(ctx, &computepb.InsertInstanceRequest{ Project: project, Zone: zone, SourceInstanceTemplate: &instanceTemplate, InstanceResource: &computepb.Instance{ MachineType: proto.String(fmt.Sprintf("zones/%s/machineTypes/%s", zone, machineType)), - Name: proto.String(tray.GetId()), + Name: proto.String(tray.Id), Metadata: metadata, }, }) @@ -94,6 +89,11 @@ func (g *GceProvider) RunTray(tray *trays.Tray) error { return err } + if err := op.Wait(ctx); err != nil { + g.logger.Errorf("Failed waiting for tray creation to complete: %v", err) + return err + } + tray.ProviderData["zone"] = zone return nil @@ -111,7 +111,7 @@ func (g *GceProvider) CleanTray(tray *trays.Tray) error { ) _, err = client.Delete(context.Background(), &computepb.DeleteInstanceRequest{ - Instance: tray.GetId(), + Instance: tray.Id, Project: project, Zone: zone, }) @@ -121,7 +121,7 @@ func (g *GceProvider) CleanTray(tray *trays.Tray) error { if e.Code != 404 { return err } else { - g.logger.Tracef("Tray not found during deletion; skipping: %v (tray %s)", err, tray.GetId()) + g.logger.Tracef("Tray not found during deletion; skipping: %v (tray %s)", err, tray.Id) return nil } } diff --git a/src/lib/trays/providers/iTrayProvider.go b/src/lib/trays/providers/iTrayProvider.go index f4b3d98..f1c1018 100644 --- a/src/lib/trays/providers/iTrayProvider.go +++ b/src/lib/trays/providers/iTrayProvider.go @@ -7,12 +7,6 @@ import ( type ITrayProvider interface { GetProviderName() string - // GetTray returns the tray with the given ID. - GetTray(id string) (*trays.Tray, error) - - // ListTrays returns all trays. - ListTrays() ([]*trays.Tray, error) - // RunTray spawns a new tray. RunTray(tray *trays.Tray) error diff --git a/src/lib/trays/providers/trayProviderFactory.go b/src/lib/trays/providers/trayProviderFactory.go index 0ee887f..9e89b2d 100644 --- a/src/lib/trays/providers/trayProviderFactory.go +++ b/src/lib/trays/providers/trayProviderFactory.go @@ -4,10 +4,15 @@ import ( "cattery/lib/config" "cattery/lib/trays" "errors" + "sync" + log "github.com/sirupsen/logrus" ) -var providers = make(map[string]ITrayProvider) +var ( + providersMu sync.RWMutex + providers = make(map[string]ITrayProvider) +) var logger = log.WithFields(log.Fields{ "name": "trayProviderFactory", @@ -28,35 +33,43 @@ func GetProviderByTrayTypeName(trayTypeName string) (ITrayProvider, error) { } func GetProvider(providerName string) (ITrayProvider, error) { - + providersMu.RLock() if existingProvider, ok := providers[providerName]; ok { + providersMu.RUnlock() return existingProvider, nil } + providersMu.RUnlock() var result ITrayProvider var p = config.AppConfig.GetProvider(providerName) if p == nil { - var err = errors.New("No provider found for " + providerName) - logger.Error(err.Error()) - return nil, err + return nil, errors.New("no provider found for " + providerName) } var provider = *p switch provider["type"] { case "docker": - result = NewDockerProvider(providerName, provider) + if p := NewDockerProvider(providerName, provider); p != nil { + result = p + } case "google": - result = NewGceProvider(providerName, provider) + if p := NewGceProvider(providerName, provider); p != nil { + result = p + } default: - var errMsg = "Unknown provider: " + providerName - logger.Error(errMsg) - return nil, errors.New(errMsg) + return nil, errors.New("unknown provider type: " + provider["type"]) + } + + if result == nil { + return nil, errors.New("failed to initialize provider: " + providerName) } + providersMu.Lock() providers[providerName] = result + providersMu.Unlock() return result, nil } diff --git a/src/lib/trays/repositories/iTrayRepository.go b/src/lib/trays/repositories/iTrayRepository.go index fd22ed7..1ddbdf6 100644 --- a/src/lib/trays/repositories/iTrayRepository.go +++ b/src/lib/trays/repositories/iTrayRepository.go @@ -2,15 +2,15 @@ package repositories import ( "cattery/lib/trays" + "context" "time" ) type ITrayRepository interface { - GetById(trayId string) (*trays.Tray, error) - Save(tray *trays.Tray) error - Delete(trayId string) error - UpdateStatus(trayId string, status trays.TrayStatus, jobRunId int64, workflowRunId int64, ghRunnerId int64) (*trays.Tray, error) - CountByTrayType(trayType string) (map[trays.TrayStatus]int, int, error) - MarkRedundant(trayType string, limit int) ([]*trays.Tray, error) - GetStale(d time.Duration, rd time.Duration) ([]*trays.Tray, error) + GetById(ctx context.Context, trayId string) (*trays.Tray, error) + Save(ctx context.Context, tray *trays.Tray) error + Delete(ctx context.Context, trayId string) error + UpdateStatus(ctx context.Context, trayId string, status trays.TrayStatus, jobRunId int64, workflowRunId int64, ghRunnerId int64, repository string) (*trays.Tray, error) + CountActive(ctx context.Context, trayType string) (int, error) + GetStale(ctx context.Context, d time.Duration) ([]*trays.Tray, error) } diff --git a/src/lib/trays/repositories/mongodbTrayRepository.go b/src/lib/trays/repositories/mongodbTrayRepository.go index 9164820..b022693 100644 --- a/src/lib/trays/repositories/mongodbTrayRepository.go +++ b/src/lib/trays/repositories/mongodbTrayRepository.go @@ -23,14 +23,13 @@ func (m *MongodbTrayRepository) Connect(collection *mongo.Collection) { m.collection = collection } -func (m *MongodbTrayRepository) GetById(trayId string) (*trays.Tray, error) { - dbResult := m.collection.FindOne(context.Background(), bson.M{"id": trayId}) +func (m *MongodbTrayRepository) GetById(ctx context.Context, trayId string) (*trays.Tray, error) { + dbResult := m.collection.FindOne(ctx, bson.M{"id": trayId}) var result trays.Tray err := dbResult.Decode(&result) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { - // Handle the "not found" case implicitly return nil, nil } return nil, err @@ -39,98 +38,58 @@ func (m *MongodbTrayRepository) GetById(trayId string) (*trays.Tray, error) { return &result, nil } -func (m *MongodbTrayRepository) GetStale(d time.Duration, rd time.Duration) ([]*trays.Tray, error) { - dbResult, err := m.collection.Find(context.Background(), - bson.M{"$or": []bson.M{ - { - "status": bson.M{"$ne": trays.TrayStatusRunning}, - "statusChanged": bson.M{"$lte": time.Now().UTC().Add(-d)}, - }, - }, +func (m *MongodbTrayRepository) GetStale(ctx context.Context, d time.Duration) ([]*trays.Tray, error) { + dbResult, err := m.collection.Find(ctx, + bson.M{ + "status": bson.M{"$ne": trays.TrayStatusRunning}, + "statusChanged": bson.M{"$lte": time.Now().UTC().Add(-d)}, }) if err != nil { return nil, err } var traysArr []*trays.Tray - if err := dbResult.All(context.Background(), &traysArr); err != nil { + if err := dbResult.All(ctx, &traysArr); err != nil { return nil, err } return traysArr, nil - -} - -func (m *MongodbTrayRepository) MarkRedundant(trayType string, limit int) ([]*trays.Tray, error) { - - var resultTrays = make([]*trays.Tray, 0) - var ids = make([]string, 0) - - for i := 0; i < limit; i++ { - dbResult := m.collection.FindOneAndUpdate( - context.Background(), - bson.M{"status": trays.TrayStatusCreating, "trayTypeName": trayType}, - bson.M{"$set": bson.M{"status": trays.TrayStatusDeleting, "statusChanged": time.Now().UTC(), "jobRunId": 0}}, - options.FindOneAndUpdate().SetReturnDocument(options.After)) - - var result trays.Tray - err := dbResult.Decode(&result) - if err != nil { - if errors.Is(err, mongo.ErrNoDocuments) { - break - } - return nil, err - } - - resultTrays = append(resultTrays, &result) - ids = append(ids, result.Id) - } - - return resultTrays, nil } -func (m *MongodbTrayRepository) GetByJobRunId(jobRunId int64) (*trays.Tray, error) { - dbResult := m.collection.FindOne(context.Background(), bson.M{"jobRunId": jobRunId}) - - var result trays.Tray - err := dbResult.Decode(&result) +func (m *MongodbTrayRepository) CountActive(ctx context.Context, trayType string) (int, error) { + count, err := m.collection.CountDocuments(ctx, bson.M{ + "trayTypeName": trayType, + "status": bson.M{"$ne": trays.TrayStatusDeleting}, + }) if err != nil { - if errors.Is(err, mongo.ErrNoDocuments) { - return nil, nil - } - return nil, err + return 0, err } - - return &result, nil + return int(count), nil } -func (m *MongodbTrayRepository) Save(tray *trays.Tray) error { +func (m *MongodbTrayRepository) Save(ctx context.Context, tray *trays.Tray) error { tray.StatusChanged = time.Now().UTC() - _, err := m.collection.InsertOne(context.Background(), tray) - if err != nil { - return err - } - - return nil + _, err := m.collection.InsertOne(ctx, tray) + return err } -func (m *MongodbTrayRepository) UpdateStatus(trayId string, status trays.TrayStatus, jobRunId int64, workflowRunId int64, ghRunnerId int64) (*trays.Tray, error) { - - var setQuery = bson.M{"status": status, "statusChanged": time.Now().UTC()} +func (m *MongodbTrayRepository) UpdateStatus(ctx context.Context, trayId string, status trays.TrayStatus, jobRunId int64, workflowRunId int64, ghRunnerId int64, repository string) (*trays.Tray, error) { + setQuery := bson.M{"status": status, "statusChanged": time.Now().UTC()} if jobRunId != 0 { setQuery["jobRunId"] = jobRunId } - if ghRunnerId != 0 { setQuery["gitHubRunnerId"] = ghRunnerId } - if workflowRunId != 0 { setQuery["workflowRunId"] = workflowRunId } + if repository != "" { + setQuery["repository"] = repository + } dbResult := m.collection.FindOneAndUpdate( - context.Background(), + ctx, bson.M{"id": trayId}, bson.M{"$set": setQuery}, options.FindOneAndUpdate().SetReturnDocument(options.After)) @@ -147,53 +106,8 @@ func (m *MongodbTrayRepository) UpdateStatus(trayId string, status trays.TraySta return &result, nil } -func (m *MongodbTrayRepository) Delete(trayId string) error { - _, err := m.collection.DeleteOne(context.Background(), bson.M{"id": trayId}) - if err != nil { - return err - } - - return nil +func (m *MongodbTrayRepository) Delete(ctx context.Context, trayId string) error { + _, err := m.collection.DeleteOne(ctx, bson.M{"id": trayId}) + return err } -func (m *MongodbTrayRepository) CountByTrayType(trayType string) (map[trays.TrayStatus]int, int, error) { - - var matchStage = bson.D{ - {"$match", bson.D{{"trayTypeName", trayType}}}, - } - var groupStage = bson.D{ - {"$group", bson.D{ - {"_id", "$status"}, - {"count", bson.D{{"$sum", 1}}}, - }}} - - cursor, err := m.collection.Aggregate(context.Background(), mongo.Pipeline{matchStage, groupStage}) - if err != nil { - return nil, 0, err - } - - var dbResults []bson.M - if err = cursor.All(context.TODO(), &dbResults); err != nil { - return nil, 0, err - } - - var result = make(map[trays.TrayStatus]int) - result[trays.TrayStatusCreating] = 0 - result[trays.TrayStatusRegistering] = 0 - result[trays.TrayStatusDeleting] = 0 - result[trays.TrayStatusRegistered] = 0 - result[trays.TrayStatusRunning] = 0 - - var total = 0 - - for _, res := range dbResults { - var int32Status = res["_id"].(int32) - - status := int32Status - cnt, _ := res["count"].(int32) - result[trays.TrayStatus(status)] = int(cnt) - total += int(cnt) - } - return result, total, nil - -} diff --git a/src/lib/trays/repositories/mongodbTrayRepository_test.go b/src/lib/trays/repositories/mongodbTrayRepository_test.go index 91f671d..fa6128a 100644 --- a/src/lib/trays/repositories/mongodbTrayRepository_test.go +++ b/src/lib/trays/repositories/mongodbTrayRepository_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -92,7 +91,7 @@ func TestGetById(t *testing.T) { insertTestTrays(t, collection, []*TestTray{testTray}) // Test GetById - tray, err := repo.GetById("test-tray-1") + tray, err := repo.GetById(context.Background(),"test-tray-1") if err != nil { t.Fatalf("GetById failed: %v", err) } @@ -114,7 +113,7 @@ func TestGetById(t *testing.T) { } // Test GetById with non-existent ID - tray, err = repo.GetById("non-existent") + tray, err = repo.GetById(context.Background(),"non-existent") if err != nil { t.Error("Expected no error for non-existent tray, got: ", err) } @@ -148,13 +147,13 @@ func TestSave(t *testing.T) { tray.ProviderData["something"] = "worker-1" // Test Save - err := repo.Save(tray) + err := repo.Save(context.Background(),tray) if err != nil { t.Fatalf("Save failed: %v", err) } // Verify the tray was saved - savedTray, err := repo.GetById(tray.Id) + savedTray, err := repo.GetById(context.Background(),tray.Id) if err != nil { t.Fatalf("Failed to get saved tray: %v", err) } @@ -201,7 +200,7 @@ func TestUpdateStatus(t *testing.T) { insertTestTrays(t, collection, []*TestTray{testTray}) // Test UpdateStatus with jobRunId only - updatedTray, err := repo.UpdateStatus("test-tray-1", trays.TrayStatusRegistered, 123, 0, 0) + updatedTray, err := repo.UpdateStatus(context.Background(),"test-tray-1", trays.TrayStatusRegistered, 123, 0, 0, "") if err != nil { t.Fatalf("UpdateStatus failed: %v", err) } @@ -219,7 +218,7 @@ func TestUpdateStatus(t *testing.T) { } // Test UpdateStatus with ghRunnerId - updatedTray, err = repo.UpdateStatus("test-tray-1", trays.TrayStatusRunning, 456, 333, 789) + updatedTray, err = repo.UpdateStatus(context.Background(),"test-tray-1", trays.TrayStatusRunning, 456, 333, 789, "") if err != nil { t.Fatalf("UpdateStatus with ghRunnerId failed: %v", err) } @@ -241,7 +240,7 @@ func TestUpdateStatus(t *testing.T) { } // Test UpdateStatus with non-existent ID - updatedTray, err = repo.UpdateStatus("non-existent", trays.TrayStatusRegistered, 123, 0, 0) + updatedTray, err = repo.UpdateStatus(context.Background(),"non-existent", trays.TrayStatusRegistered, 123, 0, 0, "") if err != nil { t.Fatalf("UpdateStatus with non-existent ID failed: %v", err) } @@ -265,13 +264,13 @@ func TestDelete(t *testing.T) { insertTestTrays(t, collection, []*TestTray{testTray}) // Test Delete - err := repo.Delete("test-tray-1") + err := repo.Delete(context.Background(),"test-tray-1") if err != nil { t.Fatalf("Delete failed: %v", err) } // Verify the tray was deleted - deletedTray, err := repo.GetById("test-tray-1") + deletedTray, err := repo.GetById(context.Background(),"test-tray-1") if err != nil { t.Error("Expected no error for deleted tray, got: ", err) } @@ -281,194 +280,57 @@ func TestDelete(t *testing.T) { } // Test Delete with non-existent ID - err = repo.Delete("non-existent") + err = repo.Delete(context.Background(),"non-existent") if err != nil { t.Fatalf("Delete with non-existent ID failed: %v", err) } } -// TestGetByJobRunId tests the GetByJobRunId method -func TestGetByJobRunId(t *testing.T) { +// TestCountActive tests the CountActive method +func TestCountActive(t *testing.T) { client, collection := setupTestCollection(t) defer client.Disconnect(context.Background()) - // Create test repository - repo := NewMongodbTrayRepository() - repo.Connect(collection) - - // Insert test data - testTray1 := createTestTray("test-tray-1", "test-type", trays.TrayStatusRunning, 123) - testTray2 := createTestTray("test-tray-2", "test-type", trays.TrayStatusCreating, 0) - insertTestTrays(t, collection, []*TestTray{testTray1, testTray2}) - - // Test GetByJobRunId - tray, err := repo.GetByJobRunId(123) - if err != nil { - t.Fatalf("GetByJobRunId failed: %v", err) - } - - if tray == nil { - t.Fatal("GetByJobRunId returned nil tray") - } - - if tray.Id != "test-tray-1" { - t.Errorf("Expected tray ID 'test-tray-1', got '%s'", tray.Id) - } - - if tray.JobRunId != 123 { - t.Errorf("Expected JobRunId 123, got %d", tray.JobRunId) - } - - // Test GetByJobRunId with non-existent JobRunId - tray, err = repo.GetByJobRunId(999) - if err != nil { - t.Fatalf("GetByJobRunId with non-existent JobRunId failed: %v", err) - } - - if tray != nil { - t.Error("Expected nil tray for non-existent JobRunId, got non-nil") - } -} - -// TestMarkRedundant tests the MarkRedundant method -func TestMarkRedundant(t *testing.T) { - client, collection := setupTestCollection(t) - defer client.Disconnect(context.Background()) - - // Create test repository repo := NewMongodbTrayRepository() repo.Connect(collection) - // Insert test data - testTray1 := createTestTray("test-tray-1", "test-type", trays.TrayStatusCreating, 0) - testTray2 := createTestTray("test-tray-2", "test-type", trays.TrayStatusCreating, 0) - testTray3 := createTestTray("test-tray-3", "test-type", trays.TrayStatusRegistered, 0) - testTray4 := createTestTray("test-tray-4", "other-type", trays.TrayStatusCreating, 0) - insertTestTrays(t, collection, []*TestTray{testTray1, testTray2, testTray3, testTray4}) - - // Test MarkRedundant - redundantTrays, err := repo.MarkRedundant("test-type", 2) - if err != nil { - t.Fatalf("MarkRedundant failed: %v", err) - } - - // Verify that the correct number of trays were marked as redundant - if len(redundantTrays) != 2 { - t.Errorf("Expected 2 redundant trays, got %d", len(redundantTrays)) - } - - // Verify that the trays were actually marked as deleting in the database - // by querying the database directly - cursor, err := collection.Find(context.Background(), bson.M{"trayTypeName": "test-type", "status": trays.TrayStatusDeleting}) - if err != nil { - t.Fatalf("Failed to query database: %v", err) - } - - var deletingTrays []TestTray - err = cursor.All(context.Background(), &deletingTrays) - if err != nil { - t.Fatalf("Failed to decode cursor: %v", err) - } - - if len(deletingTrays) != 2 { - t.Errorf("Expected 2 trays marked as deleting in the database, got %d", len(deletingTrays)) - } - - // Verify that the correct trays were marked as deleting - deletingTrayIds := make(map[string]bool) - for _, tray := range deletingTrays { - deletingTrayIds[tray.Id] = true - - // Verify the status and jobRunId were updated correctly - if tray.Status != trays.TrayStatusDeleting { - t.Errorf("Expected tray status %v, got %v", trays.TrayStatusDeleting, tray.Status) - } - - if tray.JobRunId != 0 { - t.Errorf("Expected JobRunId 0, got %d", tray.JobRunId) - } - } - - // Check that the correct trays were marked as deleting - if !deletingTrayIds["test-tray-1"] { - t.Error("Expected test-tray-1 to be marked as deleting") - } - - if !deletingTrayIds["test-tray-2"] { - t.Error("Expected test-tray-2 to be marked as deleting") - } - - // Verify that trays with different status or type were not affected - unchangedTray, err := repo.GetById("test-tray-3") - if err != nil { - t.Fatalf("Failed to get test-tray-3: %v", err) - } - - if unchangedTray.Status != trays.TrayStatusRegistered { - t.Errorf("Expected test-tray-3 status to remain %v, got %v", trays.TrayStatusRegistered, unchangedTray.Status) - } - - unchangedTray, err = repo.GetById("test-tray-4") - if err != nil { - t.Fatalf("Failed to get test-tray-4: %v", err) - } - - if unchangedTray.Status != trays.TrayStatusCreating { - t.Errorf("Expected test-tray-4 status to remain %v, got %v", trays.TrayStatusCreating, unchangedTray.Status) - } - - // Test MarkRedundant with limit - // Add more test trays - testTray5 := createTestTray("test-tray-5", "test-type", trays.TrayStatusCreating, 0) - testTray6 := createTestTray("test-tray-6", "test-type", trays.TrayStatusCreating, 0) - insertTestTrays(t, collection, []*TestTray{testTray5, testTray6}) - - // Mark only 1 tray as redundant - redundantTrays, err = repo.MarkRedundant("test-type", 1) - if err != nil { - t.Fatalf("MarkRedundant with limit failed: %v", err) - } - - // Verify that only 1 more tray was marked as deleting - cursor, err = collection.Find(context.Background(), bson.M{"trayTypeName": "test-type", "status": trays.TrayStatusDeleting}) - if err != nil { - t.Fatalf("Failed to query database: %v", err) + // Insert test data: 2 Creating, 1 Registered, 1 Running, 2 Deleting for test-type + testTrays := []*TestTray{ + createTestTray("test-tray-1", "test-type", trays.TrayStatusCreating, 0), + createTestTray("test-tray-2", "test-type", trays.TrayStatusCreating, 0), + createTestTray("test-tray-3", "test-type", trays.TrayStatusRegistered, 0), + createTestTray("test-tray-4", "test-type", trays.TrayStatusRunning, 0), + createTestTray("test-tray-5", "test-type", trays.TrayStatusDeleting, 0), + createTestTray("test-tray-6", "test-type", trays.TrayStatusDeleting, 0), + createTestTray("other-tray-1", "other-type", trays.TrayStatusCreating, 0), } + insertTestTrays(t, collection, testTrays) - err = cursor.All(context.Background(), &deletingTrays) + // Active = all non-deleting = 2 + 1 + 1 = 4 + count, err := repo.CountActive(context.Background(), "test-type") if err != nil { - t.Fatalf("Failed to decode cursor: %v", err) + t.Fatalf("CountActive failed: %v", err) } - - if len(deletingTrays) != 3 { - t.Errorf("Expected 3 trays marked as deleting in the database, got %d", len(deletingTrays)) + if count != 4 { + t.Errorf("Expected 4 active trays, got %d", count) } - // Test MarkRedundant with non-existent tray type - redundantTrays, err = repo.MarkRedundant("non-existent", 2) + // other-type: 1 active + count, err = repo.CountActive(context.Background(), "other-type") if err != nil { - t.Fatalf("MarkRedundant with non-existent tray type failed: %v", err) + t.Fatalf("CountActive for other-type failed: %v", err) } - - if len(redundantTrays) != 0 { - t.Errorf("Expected 0 redundant trays for non-existent type, got %d", len(redundantTrays)) - } - - // Test MarkRedundant with empty collection - // Clear the collection - err = collection.Drop(context.Background()) - if err != nil { - t.Fatalf("Failed to drop collection: %v", err) + if count != 1 { + t.Errorf("Expected 1 active tray for other-type, got %d", count) } - // Try to mark redundant trays in an empty collection - redundantTrays, err = repo.MarkRedundant("test-type", 2) + // non-existent type: 0 + count, err = repo.CountActive(context.Background(), "non-existent") if err != nil { - t.Fatalf("MarkRedundant with empty collection failed: %v", err) + t.Fatalf("CountActive for non-existent type failed: %v", err) } - - if len(redundantTrays) != 0 { - t.Errorf("Expected 0 redundant trays for empty collection, got %d", len(redundantTrays)) + if count != 0 { + t.Errorf("Expected 0 active trays for non-existent type, got %d", count) } } @@ -500,7 +362,7 @@ func TestGetStale(t *testing.T) { insertTestTrays(t, collection, []*TestTray{staleTray1, staleTray2, freshTray1, freshTray2}) // Test GetStale with 5 minute duration - staleTrays, err := repo.GetStale(5*time.Minute, 5*time.Minute) + staleTrays, err := repo.GetStale(context.Background(),5*time.Minute) if err != nil { t.Fatalf("GetStale failed: %v", err) } @@ -545,7 +407,7 @@ func TestGetStale(t *testing.T) { insertTestTrays(t, collection, []*TestTray{freshTray1, freshTray2}) // Test GetStale again with 5 minute duration - staleTrays, err = repo.GetStale(5*time.Minute, 5*time.Minute) + staleTrays, err = repo.GetStale(context.Background(),5*time.Minute) if err != nil { t.Fatalf("GetStale failed: %v", err) } @@ -605,7 +467,7 @@ func TestConnect(t *testing.T) { insertTestTrays(t, collection, []*TestTray{testTray}) // Try to get the tray using the repository - tray, err := repo.GetById("test-connect") + tray, err := repo.GetById(context.Background(),"test-connect") if err != nil { t.Fatalf("GetById failed after Connect: %v", err) } @@ -619,101 +481,3 @@ func TestConnect(t *testing.T) { } } -// TestCountByTrayType tests the CountByTrayType method -func TestCountByTrayType(t *testing.T) { - client, collection := setupTestCollection(t) - defer client.Disconnect(context.Background()) - - // Create test repository - repo := NewMongodbTrayRepository() - repo.Connect(collection) - - // Insert test data with specific counts for each status - // 2 Creating, 3 Registered, 1 Running, 2 Deleting for test-type - testTrays := []*TestTray{ - createTestTray("test-tray-1", "test-type", trays.TrayStatusCreating, 0), - createTestTray("test-tray-2", "test-type", trays.TrayStatusCreating, 0), - createTestTray("test-tray-3", "test-type", trays.TrayStatusRegistered, 0), - createTestTray("test-tray-4", "test-type", trays.TrayStatusRegistered, 0), - createTestTray("test-tray-5", "test-type", trays.TrayStatusRegistered, 0), - createTestTray("test-tray-6", "test-type", trays.TrayStatusRunning, 0), - createTestTray("test-tray-7", "test-type", trays.TrayStatusDeleting, 0), - createTestTray("test-tray-8", "test-type", trays.TrayStatusDeleting, 0), - // Different tray type - createTestTray("other-tray-1", "other-type", trays.TrayStatusCreating, 0), - createTestTray("other-tray-2", "other-type", trays.TrayStatusRegistered, 0), - } - insertTestTrays(t, collection, testTrays) - - // Test CountByTrayType for test-type - counts, total, err := repo.CountByTrayType("test-type") - if err != nil { - t.Fatalf("CountByTrayType failed: %v", err) - } - - // Verify the total count - expectedTotal := 8 // Total number of test-type trays - if total != expectedTotal { - t.Errorf("Expected total count %d, got %d", expectedTotal, total) - } - - // Verify counts for each status - expectedCounts := map[trays.TrayStatus]int{ - trays.TrayStatusCreating: 2, - trays.TrayStatusRegistered: 3, - trays.TrayStatusRunning: 1, - trays.TrayStatusDeleting: 2, - trays.TrayStatusRegistering: 0, // No trays with this status - } - - for status, expectedCount := range expectedCounts { - if counts[status] != expectedCount { - t.Errorf("Expected count %d for status %v, got %d", expectedCount, status, counts[status]) - } - } - - // Test CountByTrayType for other-type - counts, total, err = repo.CountByTrayType("other-type") - if err != nil { - t.Fatalf("CountByTrayType for other-type failed: %v", err) - } - - // Verify the total count for other-type - expectedTotal = 2 // Total number of other-type trays - if total != expectedTotal { - t.Errorf("Expected total count %d for other-type, got %d", expectedTotal, total) - } - - // Verify counts for each status for other-type - expectedCounts = map[trays.TrayStatus]int{ - trays.TrayStatusCreating: 1, - trays.TrayStatusRegistered: 1, - trays.TrayStatusRunning: 0, - trays.TrayStatusDeleting: 0, - trays.TrayStatusRegistering: 0, - } - - for status, expectedCount := range expectedCounts { - if counts[status] != expectedCount { - t.Errorf("Expected count %d for status %v in other-type, got %d", expectedCount, status, counts[status]) - } - } - - // Test CountByTrayType with non-existent tray type - counts, total, err = repo.CountByTrayType("non-existent") - if err != nil { - t.Fatalf("CountByTrayType with non-existent tray type failed: %v", err) - } - - // Verify the total count for non-existent type - if total != 0 { - t.Errorf("Expected total count 0 for non-existent type, got %d", total) - } - - // Verify that all status counts are 0 for non-existent type - for status, count := range counts { - if count != 0 { - t.Errorf("Expected count 0 for status %v in non-existent type, got %d", status, count) - } - } -} diff --git a/src/lib/trays/repositories/traysRepository.go b/src/lib/trays/repositories/traysRepository.go deleted file mode 100644 index cd40362..0000000 --- a/src/lib/trays/repositories/traysRepository.go +++ /dev/null @@ -1,54 +0,0 @@ -package repositories - -import ( - "cattery/lib/trays" - "sync" -) - -type MemTrayRepository struct { - ITrayRepository - trays map[string]*trays.Tray - mutex sync.RWMutex -} - -func NewMemTrayRepository() *MemTrayRepository { - return &MemTrayRepository{ - trays: make(map[string]*trays.Tray), - mutex: sync.RWMutex{}, - } -} - -func (r *MemTrayRepository) GetById(trayId string) (*trays.Tray, error) { - r.mutex.RLock() - defer r.mutex.RUnlock() - - tray, exists := r.trays[trayId] - if !exists { - return nil, nil - } - - return tray, nil -} - -func (r *MemTrayRepository) Save(tray *trays.Tray) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - r.trays[tray.GetId()] = tray - return nil -} - -func (r *MemTrayRepository) Delete(trayId string) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - delete(r.trays, trayId) - return nil -} - -func (r *MemTrayRepository) Len() int { - r.mutex.RLock() - defer r.mutex.RUnlock() - - return len(r.trays) -} diff --git a/src/lib/trays/tray.go b/src/lib/trays/tray.go index 1aa2c48..2cac404 100644 --- a/src/lib/trays/tray.go +++ b/src/lib/trays/tray.go @@ -11,13 +11,13 @@ import ( type Tray struct { Id string `bson:"id"` TrayTypeName string `bson:"trayTypeName"` - trayType config.TrayType ProviderName string `bson:"providerName"` GitHubOrgName string `bson:"gitHubOrgName"` GitHubRunnerId int64 `bson:"gitHubRunnerId"` JobRunId int64 `bson:"jobRunId"` WorkflowRunId int64 `bson:"workflowRunId"` + Repository string `bson:"repository"` Status TrayStatus `bson:"status"` StatusChanged time.Time `bson:"statusChanged"` @@ -25,7 +25,6 @@ type Tray struct { } func NewTray(trayType config.TrayType) *Tray { - b := make([]byte, 8) _, err := rand.Read(b) if err != nil { @@ -33,44 +32,34 @@ func NewTray(trayType config.TrayType) *Tray { } id := hex.EncodeToString(b) - var trayId = fmt.Sprintf("%s-%s", trayType.Name, id) - var tray = &Tray{ - Id: trayId, + return &Tray{ + Id: fmt.Sprintf("%s-%s", trayType.Name, id), TrayTypeName: trayType.Name, - trayType: trayType, ProviderName: trayType.Provider, Status: TrayStatusCreating, GitHubOrgName: trayType.GitHubOrg, - JobRunId: 0, - WorkflowRunId: 0, ProviderData: make(map[string]string), } - - return tray } -func (tray *Tray) GetId() string { - return tray.Id +// TrayType returns the configuration for this tray's type from the current config. +// Returns nil if the tray type no longer exists in config. +func (tray *Tray) TrayType() *config.TrayType { + return config.AppConfig.GetTrayType(tray.TrayTypeName) } -func (tray *Tray) GetGitHubOrgName() string { - return tray.GitHubOrgName -} - -func (tray *Tray) GetTrayTypeName() string { - return tray.TrayTypeName -} - -func (tray *Tray) GetTrayType() config.TrayType { - return tray.trayType -} - -func (tray *Tray) GetTrayConfig() config.TrayConfig { - return config.AppConfig.GetTrayType(tray.TrayTypeName).Config +// TrayConfig returns the provider-specific config (DockerTrayConfig, GoogleTrayConfig, etc.). +// Returns nil if the tray type no longer exists in config. +func (tray *Tray) TrayConfig() config.TrayConfig { + tt := tray.TrayType() + if tt == nil { + return nil + } + return tt.Config } func (tray *Tray) String() string { - return fmt.Sprintf("id: %s, trayTypeName: %s, status: %s, gitHubOrgName: %s, statusChanged: %s", + return fmt.Sprintf("id: %s, trayTypeName: %s, status: %s, gitHubOrgName: %s, statusChanged: %s", tray.Id, tray.TrayTypeName, tray.Status, tray.GitHubOrgName, tray.StatusChanged.Format(time.RFC3339)) } diff --git a/src/server/handlers/agentHandler.go b/src/server/handlers/agentHandler.go index 589b5cf..27b05a0 100644 --- a/src/server/handlers/agentHandler.go +++ b/src/server/handlers/agentHandler.go @@ -3,7 +3,6 @@ package handlers import ( "cattery/lib/agents" "cattery/lib/config" - "cattery/lib/githubClient" "cattery/lib/messages" "cattery/lib/metrics" "cattery/lib/trays" @@ -18,7 +17,7 @@ import ( ) // AgentRegister is a handler for agent registration requests -func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { +func (h *Handlers) AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { var logger = log.WithFields(log.Fields{ "handler": "agent", @@ -27,11 +26,6 @@ func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { logger.Tracef("AgentRegister: %v", r) - if r.Method != http.MethodGet { - http.Error(responseWriter, "Method not allowed", http.StatusMethodNotAllowed) - return - } - var id = r.PathValue("id") var agentId = validateAgentId(id) @@ -41,7 +35,7 @@ func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { logger.Debug("Agent registration request") - var tray, err = TrayManager.Registering(agentId) + var tray, err = h.TrayManager.Registering(r.Context(), agentId) if err != nil { var errMsg = fmt.Sprintf("Failed to update tray status for agent '%s': %v", agentId, err) logger.Error(errMsg) @@ -49,44 +43,37 @@ func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { return } - var trayType = config.AppConfig.GetTrayType(tray.GetTrayTypeName()) + var trayType = config.AppConfig.GetTrayType(tray.TrayTypeName) if trayType == nil { - var errMsg = fmt.Sprintf("Tray type '%s' not found", tray.GetTrayTypeName()) + var errMsg = fmt.Sprintf("Tray type '%s' not found", tray.TrayTypeName) logger.Error(errMsg) http.Error(responseWriter, errMsg, http.StatusInternalServerError) return } logger = logger.WithFields(log.Fields{"trayType": trayType.Name}) - logger.Debugf("Found tray %s for agent %s, with organization %s", tray.GetId(), agentId, tray.GetGitHubOrgName()) + logger.Debugf("Found tray %s for agent %s, with organization %s", tray.Id, agentId, tray.GitHubOrgName) - // TODO handle - client, err := githubClient.NewGithubClientWithOrgName(tray.GetGitHubOrgName()) - if err != nil { - var errMsg = fmt.Sprintf("Organization '%s' is invalid: %v", tray.GetGitHubOrgName(), err) + poller := h.ScaleSetManager.GetPoller(trayType.Name) + if poller == nil { + var errMsg = fmt.Sprintf("No scale set poller found for tray type '%s'", trayType.Name) logger.Error(errMsg) http.Error(responseWriter, errMsg, http.StatusInternalServerError) return } - logger = logger.WithFields(log.Fields{"githubOrg": tray.GetGitHubOrgName()}) - - jitRunnerConfig, err := client.CreateJITConfig( - tray.GetId(), - trayType.RunnerGroupId, - []string{trayType.Name}, - ) + jitRunnerConfig, err := poller.Client().GenerateJitRunnerConfig(r.Context(), tray.Id) if err != nil { logger.Errorf("Failed to generate jitRunnerConfig: %v", err) http.Error(responseWriter, "Failed to generate jitRunnerConfig", http.StatusInternalServerError) return } - var jitConfig = jitRunnerConfig.GetEncodedJITConfig() + var jitConfig = jitRunnerConfig.EncodedJITConfig var newAgent = agents.Agent{ AgentId: agentId, - RunnerId: jitRunnerConfig.GetRunner().GetID(), + RunnerId: int64(jitRunnerConfig.Runner.ID), Shutdown: trayType.Shutdown, } @@ -95,6 +82,7 @@ func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { JitConfig: jitConfig, } + responseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(responseWriter).Encode(registerResponse) if err != nil { logger.Errorf("Failed to encode response: %v", err) @@ -102,7 +90,7 @@ func AgentRegister(responseWriter http.ResponseWriter, r *http.Request) { return } - _, err = TrayManager.Registered(agentId, jitRunnerConfig.GetRunner().GetID()) + _, err = h.TrayManager.Registered(r.Context(), agentId, int64(jitRunnerConfig.Runner.ID)) if err != nil { logger.Errorf("%v", err) } @@ -118,7 +106,7 @@ func validateAgentId(agentId string) string { } // AgentUnregister is a handler for agent unregister requests -func AgentUnregister(responseWriter http.ResponseWriter, r *http.Request) { +func (h *Handlers) AgentUnregister(responseWriter http.ResponseWriter, r *http.Request) { var logger = log.WithFields(log.Fields{ "handler": "agent", "call": "AgentUnregister", @@ -126,14 +114,9 @@ func AgentUnregister(responseWriter http.ResponseWriter, r *http.Request) { logger.Tracef("AgentUnregister: %v", r) - if r.Method != http.MethodPost { - http.Error(responseWriter, "Method not allowed", http.StatusMethodNotAllowed) - return - } - var trayId = r.PathValue("id") - var tray, err = TrayManager.GetTrayById(trayId) + var tray, err = h.TrayManager.GetTrayById(r.Context(), trayId) if err != nil { var errMsg = fmt.Sprintf("Failed to get tray for agent '%s': %v", trayId, err) logger.Error(errMsg) @@ -160,10 +143,12 @@ func AgentUnregister(responseWriter http.ResponseWriter, r *http.Request) { logger.Tracef("Agent unregister request") - _, err = TrayManager.DeleteTray(tray.Id) + _, err = h.TrayManager.DeleteTray(r.Context(), tray.Id) if err != nil { logger.Errorf("Failed to delete tray: %v", err) + http.Error(responseWriter, "Failed to delete tray", http.StatusInternalServerError) + return } logger.Infof("Agent %s unregistered, reason: %d", unregisterRequest.Agent.AgentId, unregisterRequest.Reason) @@ -176,54 +161,17 @@ func AgentUnregister(responseWriter http.ResponseWriter, r *http.Request) { } func AgentDownloadBinary(responseWriter http.ResponseWriter, r *http.Request) { - var logger = log.WithFields(log.Fields{ - "handler": "agent", - "call": "AgentDownloadBinary", - }) - logger.Tracef("AgentDownloadBinary: %v", r) - - if r.Method != http.MethodGet { - http.Error(responseWriter, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Get the current executable path execPath, err := os.Executable() if err != nil { - logger.Errorf("Failed to get executable path: %v", err) http.Error(responseWriter, "Failed to get binary path", http.StatusInternalServerError) return } - // Open the binary file - file, err := os.Open(execPath) - if err != nil { - logger.Errorf("Failed to open binary file: %v", err) - http.Error(responseWriter, "Failed to open binary file", http.StatusInternalServerError) - return - } - defer file.Close() - - // Get file info for size and name - fileInfo, err := file.Stat() - if err != nil { - logger.Errorf("Failed to get file info: %v", err) - http.Error(responseWriter, "Failed to get file info", http.StatusInternalServerError) - return - } - - // Set appropriate headers - responseWriter.Header().Set("Content-Type", "application/octet-stream") responseWriter.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filepath.Base(execPath))) - responseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) - - // Serve the file - http.ServeContent(responseWriter, r, filepath.Base(execPath), fileInfo.ModTime(), file) - - logger.Infof("Binary file served: %s (%d bytes)", execPath, fileInfo.Size()) + http.ServeFile(responseWriter, r, execPath) } -func AgentPing(responseWriter http.ResponseWriter, r *http.Request) { +func (h *Handlers) AgentPing(responseWriter http.ResponseWriter, r *http.Request) { var logger = log.WithFields(log.Fields{ "handler": "agent", "call": "AgentPing", @@ -239,13 +187,12 @@ func AgentPing(responseWriter http.ResponseWriter, r *http.Request) { Message: "", } - tray, err := TrayManager.GetTrayById(agentId) + tray, err := h.TrayManager.GetTrayById(r.Context(), agentId) if err != nil { var errMsg = fmt.Sprintf("Failed to get tray by id '%s': %v", agentId, err) logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusInternalServerError) - pingResponse.Message = "Failed to get tray by id: " + errMsg + pingResponse.Message = errMsg pingResponse.Terminate = true writeResponse(responseWriter, pingResponse, logger) @@ -254,9 +201,8 @@ func AgentPing(responseWriter http.ResponseWriter, r *http.Request) { if tray == nil { var errMsg = fmt.Sprintf("Tray with id '%s' not found", agentId) logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusGone) - pingResponse.Message = "Failed to get tray by id: " + errMsg + pingResponse.Message = errMsg pingResponse.Terminate = true writeResponse(responseWriter, pingResponse, logger) @@ -290,7 +236,7 @@ func writeResponse(responseWriter http.ResponseWriter, pingResponse any, logger } } -func AgentInterrupt(responseWriter http.ResponseWriter, r *http.Request) { +func (h *Handlers) AgentInterrupt(responseWriter http.ResponseWriter, r *http.Request) { var logger = log.WithFields(log.Fields{ "handler": "agent", "call": "AgentRestart", @@ -298,11 +244,6 @@ func AgentInterrupt(responseWriter http.ResponseWriter, r *http.Request) { logger.Tracef("AgentRestart: %v", r) - if r.Method != http.MethodPost { - http.Error(responseWriter, "Method not allowed", http.StatusMethodNotAllowed) - return - } - var id = r.PathValue("id") var agentId = validateAgentId(id) @@ -312,7 +253,7 @@ func AgentInterrupt(responseWriter http.ResponseWriter, r *http.Request) { logger.Debug("Agent restart request with id " + agentId) - tray, err := TrayManager.GetTrayById(agentId) + tray, err := h.TrayManager.GetTrayById(r.Context(), agentId) if err != nil { var errMsg = fmt.Sprintf("Failed to get tray by id '%s': %v", agentId, err) logger.Error(errMsg) @@ -326,5 +267,9 @@ func AgentInterrupt(responseWriter http.ResponseWriter, r *http.Request) { return } workflowRunId := tray.WorkflowRunId - RestartManager.RequestRestart(workflowRunId) + if err := h.RestartManager.RequestRestart(workflowRunId, tray.GitHubOrgName, tray.Repository); err != nil { + logger.Errorf("Failed to request restart for workflow %d: %v", workflowRunId, err) + http.Error(responseWriter, "Failed to request restart", http.StatusInternalServerError) + return + } } diff --git a/src/server/handlers/rootHandler.go b/src/server/handlers/rootHandler.go index ff1c54e..965f936 100644 --- a/src/server/handlers/rootHandler.go +++ b/src/server/handlers/rootHandler.go @@ -1,16 +1,18 @@ package handlers import ( - "cattery/lib/jobQueue" "cattery/lib/restarter" + "cattery/lib/scaleSetPoller" "cattery/lib/trayManager" "net/http" ) -var QueueManager *jobQueue.QueueManager -var TrayManager *trayManager.TrayManager -var RestartManager *restarter.WorkflowRestarter +type Handlers struct { + TrayManager *trayManager.TrayManager + RestartManager *restarter.WorkflowRestarter + ScaleSetManager *scaleSetPoller.Manager +} -func Index(responseWriter http.ResponseWriter, r *http.Request) { +func (h *Handlers) Index(responseWriter http.ResponseWriter, r *http.Request) { return } diff --git a/src/server/handlers/webhookHandler.go b/src/server/handlers/webhookHandler.go deleted file mode 100644 index d701ad7..0000000 --- a/src/server/handlers/webhookHandler.go +++ /dev/null @@ -1,243 +0,0 @@ -package handlers - -import ( - "cattery/lib/config" - "cattery/lib/jobs" - "fmt" - "net/http" - - "github.com/google/go-github/v70/github" - log "github.com/sirupsen/logrus" -) - -func Webhook(responseWriter http.ResponseWriter, r *http.Request) { - - var logger = log.WithFields( - log.Fields{ - "handler": "webhook", - "call": "Webhook", - }, - ) - - logger.Tracef("Webhook received") - - if r.Method != http.MethodPost { - http.Error(responseWriter, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - event := r.Header.Get("X-GitHub-Event") - - switch event { - case "workflow_job": - handleWorkflowJobWebhook(responseWriter, r, logger) - case "workflow_run": - handleWorkflowRunWebhook(responseWriter, r, logger) - default: - logger.Debugf("Ignoring webhook request: X-GitHub-Event is not 'workflow_job' or 'workflow_run', got '%s'", event) - return - } -} - -func handleWorkflowJobWebhook(responseWriter http.ResponseWriter, r *http.Request, logger *log.Entry) { - var webhookData *github.WorkflowJobEvent - - var organizationName = r.PathValue("org") - var org = config.AppConfig.GetGitHubOrg(organizationName) - if org == nil { - var errMsg = fmt.Sprintf("Organization '%s' not found in config", organizationName) - logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusBadRequest) - return - } - logger = logger.WithField("githubOrg", organizationName) - logger = logger.WithField("type", "workflow_job") - - payload, err := github.ValidatePayload(r, []byte(org.WebhookSecret)) - if err != nil { - logger.Errorf("Failed to validate payload: %v", err) - http.Error(responseWriter, "Failed to validate payload", http.StatusBadRequest) - return - } - - hook, err := github.ParseWebHook(r.Header.Get("X-GitHub-Event"), payload) - if err != nil { - logger.Errorf("Failed to parse webhook: %v", err) - return - } - webhookData, ok := hook.(*github.WorkflowJobEvent) - if !ok { - logger.Errorf("Webhook payload is not WorkflowJobEvent") - return - } - - logger.Tracef("Event payload: %v", payload) - - trayType := getTrayType(webhookData) - if trayType == nil { - logger.Tracef("Ignoring action: '%s', for job '%s', no tray type found for labels: %v", webhookData.GetAction(), *webhookData.WorkflowJob.Name, webhookData.WorkflowJob.Labels) - return - } - logger = logger.WithField("jobRunId", webhookData.WorkflowJob.GetID()) - - logger.Debugf("Action: %s", webhookData.GetAction()) - - job := jobs.FromGithubModel(webhookData) - job.TrayType = trayType.Name - - logger = logger.WithField("trayType", trayType.Name) - - switch webhookData.GetAction() { - case "queued": - handleQueuedWorkflowJob(responseWriter, logger, job) - case "in_progress": - handleInProgressWorkflowJob(responseWriter, logger, job) - case "completed": - handleCompletedWorkflowJob(responseWriter, logger, job) - default: - logger.Debugf("Ignoring action: '%s', for job '%s'", webhookData.GetAction(), *webhookData.WorkflowJob.Name) - return - } -} - -func handleWorkflowRunWebhook(responseWriter http.ResponseWriter, r *http.Request, logger *log.Entry) { - log.Debugf("Received workflow_run webhook") - logger = logger.WithField("type", "workflow_run") - var webhookData *github.WorkflowRunEvent - organizationName := r.PathValue("org") - org := config.AppConfig.GetGitHubOrg(organizationName) - if org == nil { - errMsg := fmt.Sprintf("Organization '%s' not found in config", organizationName) - logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusBadRequest) - return - } - payload, err := github.ValidatePayload(r, []byte(org.WebhookSecret)) - if err != nil { - logger.Errorf("Error validating payload: %v", err) - http.Error(responseWriter, "Error validating payload", http.StatusBadRequest) - return - } - hook, err := github.ParseWebHook(r.Header.Get("X-GitHub-Event"), payload) - if err != nil { - logger.Errorf("Error parsing webhook: %v", err) - http.Error(responseWriter, "Error parsing webhook", http.StatusBadRequest) - return - } - webhookData, ok := hook.(*github.WorkflowRunEvent) - if !ok { - logger.Errorf("Webhook payload is not WorkflowRunEvent") - http.Error(responseWriter, "Webhook payload is not WorkflowRunEvent", http.StatusBadRequest) - return - } - conclusion := webhookData.GetWorkflowRun().GetConclusion() - repoName := webhookData.GetRepo().GetName() - orgName := webhookData.GetOrg().GetLogin() - logger.Debugf("Action: %s, Org: %s, Repo: %s, Workflow run ID: %d, conclusion: %s", webhookData.GetAction(), orgName, repoName, webhookData.GetWorkflowRun().GetID(), conclusion) - - // On "completed" action and "failure" conclusion trigger restart - if webhookData.GetAction() == "completed" && conclusion == "failure" { - logger.Infof("Requesting restart for failed jobs in workflow run ID: %d", webhookData.GetWorkflowRun().GetID()) - err := RestartManager.Restart(*webhookData.WorkflowRun.ID, orgName, repoName) - if err != nil { - logger.Errorf("Failed to request restart: %v", err) - http.Error(responseWriter, "Failed to request restart", http.StatusInternalServerError) - } - return - } - // On "completed" action and "cancelled" or "success" conclusion trigger cleanup - if webhookData.GetAction() == "completed" && (conclusion == "cancelled" || conclusion == "success") { - if conclusion == "cancelled" { - logger.Infof("Cleaning up jobs for workflow run ID: %d", webhookData.GetWorkflowRun().GetID()) - err := QueueManager.CleanupByWorkflowRun(*webhookData.WorkflowRun.ID) - if err != nil { - logger.Errorf("Failed to cleanup jobs: %v", err) - http.Error(responseWriter, "Failed to cleanup jobs", http.StatusInternalServerError) - } - } - logger.Infof("Cleaning up restart requests for workflow run ID: %d", webhookData.GetWorkflowRun().GetID()) - err = RestartManager.Cleanup(*webhookData.WorkflowRun.ID, orgName, repoName) - if err != nil { - logger.Errorf("Failed to cleanup restart requests: %v", err) - http.Error(responseWriter, "Failed to cleanup restart requests", http.StatusInternalServerError) - } - return - } - -} - -// handleCompletedWorkflowJob -// handles the 'completed' action of the workflow job event -func handleCompletedWorkflowJob(responseWriter http.ResponseWriter, logger *log.Entry, job *jobs.Job) { - - err := QueueManager.UpdateJobStatus(job.Id, jobs.JobStatusFinished) - if err != nil { - logger.Errorf("Failed to update job status: %v", err) - } - - _, err = TrayManager.DeleteTray(job.RunnerName) - if err != nil { - logger.Errorf("Failed to delete tray: %v", err) - } -} - -// handleInProgressWorkflowJob -// handles the 'in_progress' action of the workflow job event -func handleInProgressWorkflowJob(responseWriter http.ResponseWriter, logger *log.Entry, job *jobs.Job) { - - err := QueueManager.JobInProgress(job.Id) - if err != nil { - var errMsg = fmt.Sprintf("Failed to mark job '%s/%s' as in progress: %v", job.WorkflowName, job.Name, err) - logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusInternalServerError) - } - - tray, err := TrayManager.SetJob(job.RunnerName, job.Id, job.WorkflowId) - if tray == nil { - logger.Errorf("Failed to set job '%s/%s' as in progress to tray, tray not found: %v", job.WorkflowName, job.Name, err) - } - if err != nil { - logger.Errorf("Failed to set job '%s/%s' as in progress to tray: %v", job.WorkflowName, job.Name, err) - } - - logger.Infof("Tray '%s' is running '%s/%s/%s/%s'", - job.RunnerName, - job.Organization, job.Repository, job.WorkflowName, job.Name, - ) -} - -// handleQueuedWorkflowJob -// handles the 'handleQueuedWorkflowJob' action of the workflow job event -func handleQueuedWorkflowJob(responseWriter http.ResponseWriter, logger *log.Entry, job *jobs.Job) { - err := QueueManager.AddJob(job) - if err != nil { - var errMsg = fmt.Sprintf("Failed to enqueue job '%s/%s/%s': %v", job.Repository, job.WorkflowName, job.Name, err) - logger.Error(errMsg) - http.Error(responseWriter, errMsg, http.StatusInternalServerError) - return - } - - logger.Infof("Enqueued job %s/%s/%s/%s ", job.Organization, job.Repository, job.WorkflowName, job.Name) -} - -func getTrayType(webhookData *github.WorkflowJobEvent) *config.TrayType { - - if len(webhookData.WorkflowJob.Labels) != 1 { - // Cattery only support one label for now - return nil - } - - // find tray type based on labels (runs_on) - var label = webhookData.WorkflowJob.Labels[0] - var trayType = config.AppConfig.GetTrayType(label) - - if trayType == nil { - return nil - } - - if trayType.GitHubOrg != webhookData.GetOrg().GetLogin() { - return nil - } - - return trayType -} diff --git a/src/server/server.go b/src/server/server.go index 6583f98..b99622a 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -2,9 +2,10 @@ package server import ( "cattery/lib/config" - "cattery/lib/jobQueue" "cattery/lib/restarter" restarterRepo "cattery/lib/restarter/repositories" + "cattery/lib/scaleSetClient" + "cattery/lib/scaleSetPoller" "cattery/lib/trayManager" "cattery/lib/trays/repositories" "cattery/server/handlers" @@ -13,6 +14,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" @@ -24,43 +26,32 @@ func Start() { var logger = log.New() - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT) - signal.Notify(sigs, syscall.SIGTERM) - signal.Notify(sigs, syscall.SIGKILL) - - var webhookMux = http.NewServeMux() - webhookMux.HandleFunc("/{$}", handlers.Index) - webhookMux.HandleFunc("GET /agent/register/{id}", handlers.AgentRegister) - webhookMux.HandleFunc("POST /agent/unregister/{id}", handlers.AgentUnregister) - webhookMux.HandleFunc("GET /agent/download", handlers.AgentDownloadBinary) - webhookMux.HandleFunc("POST /agent/interrupt/{id}", handlers.AgentInterrupt) - webhookMux.HandleFunc("POST /agent/ping/{id}", handlers.AgentPing) - - webhookMux.HandleFunc("POST /github/{org}", handlers.Webhook) - - webhookMux.HandleFunc("/metrics", promhttp.Handler().ServeHTTP) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - var webhookServer = &http.Server{ - Addr: config.AppConfig.Server.ListenAddress, - Handler: webhookMux, - } + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) // Db connection serverAPI := options.ServerAPI(options.ServerAPIVersion1) opts := options.Client(). ApplyURI(config.AppConfig.Database.Uri). - SetServerAPIOptions(serverAPI) //.SetTimeout(3 * time.Second) + SetServerAPIOptions(serverAPI) client, err := mongo.Connect(opts) if err != nil { logger.Fatal(err) } - err = client.Ping(context.Background(), nil) - if err != nil { - logger.Errorf("Failed to connect to MongoDB: %v", err) - os.Exit(1) + { + timeoutCtx, cf := context.WithTimeout(context.Background(), 3*time.Second) + defer cf() + + err = client.Ping(timeoutCtx, nil) + if err != nil { + logger.Errorf("Failed to connect to MongoDB: %v", err) + os.Exit(1) + } } var database = client.Database(config.AppConfig.Database.Database) @@ -68,31 +59,81 @@ func Start() { // Initialize tray manager and repository var trayRepository = repositories.NewMongodbTrayRepository() trayRepository.Connect(database.Collection("trays")) + tm := trayManager.NewTrayManager(trayRepository) - handlers.TrayManager = trayManager.NewTrayManager(trayRepository) - - //QueueManager initialization - handlers.QueueManager = jobQueue.NewQueueManager() - handlers.QueueManager.Connect(database.Collection("jobs")) - - // Initialize restarter repository + // Initialize restarter var restartManagerRepository = restarterRepo.NewMongodbRestarterRepository() restartManagerRepository.Connect(database.Collection("restarters")) + rm := restarter.NewWorkflowRestarter(restartManagerRepository) + + // Initialize scale set pollers — one per TrayType + ssm := scaleSetPoller.NewManager() + for _, trayType := range config.AppConfig.TrayTypes { + org := config.AppConfig.GetGitHubOrg(trayType.GitHubOrg) + if org == nil { + logger.Fatalf("GitHub organization '%s' not found for tray type '%s'", trayType.GitHubOrg, trayType.Name) + } + + ssClient, err := scaleSetClient.NewScaleSetClient(org, trayType) + if err != nil { + logger.Fatalf("Failed to create scale set client for tray type '%s': %v", trayType.Name, err) + } - handlers.RestartManager = restarter.NewWorkflowRestarter(restartManagerRepository) + poller := scaleSetPoller.NewPoller(ssClient, trayType, tm) + ssm.Register(trayType.Name, poller) + + ssm.Wg.Add(1) + go func(p *scaleSetPoller.Poller, name string) { + defer ssm.Wg.Done() + for { + if err := p.Run(ctx); err != nil { + if ctx.Err() != nil { + logger.Infof("Scale set poller for '%s' stopped: %v", name, err) + return + } + logger.Errorf("Scale set poller for '%s' exited with error: %v — restarting in 30s", name, err) + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + } + continue + } + return + } + }(poller, trayType.Name) + } - err = handlers.QueueManager.Load() - if err != nil { - logger.Errorf("Failed to load queue manager: %v", err) + // Start restart poller (replaces workflow_run webhook) + rm.StartPoller(ctx) + + // Start stale tray cleanup + tm.HandleStale(ctx) + + h := &handlers.Handlers{ + TrayManager: tm, + RestartManager: rm, + ScaleSetManager: ssm, } - handlers.TrayManager.HandleJobsQueue(context.Background(), handlers.QueueManager) - handlers.TrayManager.HandleStale(context.Background()) + var mux = http.NewServeMux() + mux.HandleFunc("/{$}", h.Index) + mux.HandleFunc("GET /agent/register/{id}", h.AgentRegister) + mux.HandleFunc("POST /agent/unregister/{id}", h.AgentUnregister) + mux.HandleFunc("GET /agent/download", handlers.AgentDownloadBinary) + mux.HandleFunc("POST /agent/interrupt/{id}", h.AgentInterrupt) + mux.HandleFunc("POST /agent/ping/{id}", h.AgentPing) + mux.HandleFunc("/metrics", promhttp.Handler().ServeHTTP) + + var httpServer = &http.Server{ + Addr: config.AppConfig.Server.ListenAddress, + Handler: mux, + } - // Start the server + // Start HTTP server go func() { - logger.Infof("Starting webhook server on %s", config.AppConfig.Server.ListenAddress) - err := webhookServer.ListenAndServe() + logger.Infof("Starting server on %s", config.AppConfig.Server.ListenAddress) + err := httpServer.ListenAndServe() if err != nil { logger.Fatal(err) return @@ -101,4 +142,9 @@ func Start() { sig := <-sigs logger.Info("Got signal ", sig) + cancel() + + logger.Info("Waiting for pollers to shut down...") + ssm.Wg.Wait() + logger.Info("All pollers stopped") }