Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions internal/opensearchml/create_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,35 +57,41 @@ type CreateConnectorRequestAction struct {
PostProcessFunction string `json:"post_process_function"`
}

type CreateConnectorResponse struct {
type CreateOrUpdateConnectorResponse struct {
ConnectorID string `json:"connector_id"`
}

func (c *Client) CreateConnector(ctx context.Context, req CreateConnectorRequest) (CreateConnectorResponse, error) {
func (c *Client) CreateOrUpdateConnector(ctx context.Context, req CreateConnectorRequest) (CreateOrUpdateConnectorResponse, error) {
if req.Name == "" {
return CreateConnectorResponse{}, fmt.Errorf("connector name is required")
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("connector name is required")
}

bodyBytes, err := json.Marshal(req)
if err != nil {
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("marshal connector payload: %w", err)
}

// If a connector with this name already exists, update it via PUT.
if id, ok, err := c.FindConnectorIDByName(ctx, req.Name); err != nil {
return CreateConnectorResponse{}, fmt.Errorf("find connector by name: %w", err)
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("find connector by name: %w", err)
} else if ok {
return CreateConnectorResponse{ConnectorID: id}, nil
}
if err := c.updateConnector(ctx, id, bodyBytes); err != nil {
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("update connector: %w", err)
}

bodyBytes, err := json.Marshal(req)
if err != nil {
return CreateConnectorResponse{}, fmt.Errorf("marshal create connector payload: %w", err)
return CreateOrUpdateConnectorResponse{ConnectorID: id}, nil
}

// Connector does not exist yet, create it via POST.
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, "/_plugins/_ml/connectors/_create", bytes.NewReader(bodyBytes))
if err != nil {
return CreateConnectorResponse{}, fmt.Errorf("new request: %w", err)
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("new request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")

httpResp, err := c.opensearch.Client.Perform(httpReq)
if err != nil {
return CreateConnectorResponse{}, fmt.Errorf("perform create request: %w", err)
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("perform create request: %w", err)
}
defer func() {
if err := httpResp.Body.Close(); err != nil {
Expand All @@ -95,28 +101,58 @@ func (c *Client) CreateConnector(ctx context.Context, req CreateConnectorRequest

respBytes, _ := io.ReadAll(httpResp.Body)

// If another caller created it between our check and create, re-check and return.
// If another caller created it between our check and create, update the existing connector.
if httpResp.StatusCode == http.StatusConflict {
if id, ok, err := c.FindConnectorIDByName(ctx, req.Name); err != nil {
return CreateConnectorResponse{}, fmt.Errorf("create conflict; re-find connector: %w", err)
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("create conflict; re-find connector: %w", err)
} else if ok {
return CreateConnectorResponse{ConnectorID: id}, nil
if err := c.updateConnector(ctx, id, bodyBytes); err != nil {
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("create conflict; update connector: %w", err)
}

return CreateOrUpdateConnectorResponse{ConnectorID: id}, nil
}
}

if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
return CreateConnectorResponse{}, fmt.Errorf("create connector failed: status=%d body=%s", httpResp.StatusCode, string(respBytes))
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("create connector failed: status=%d body=%s", httpResp.StatusCode, string(respBytes))
}

var out CreateConnectorResponse
var out CreateOrUpdateConnectorResponse

if err := json.Unmarshal(respBytes, &out); err != nil {
return CreateConnectorResponse{}, fmt.Errorf("unmarshal create response: %w (body=%s)", err, string(respBytes))
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("unmarshal create response: %w (body=%s)", err, string(respBytes))
}

if out.ConnectorID == "" {
return CreateConnectorResponse{}, fmt.Errorf("create response missing connector_id (body=%s)", string(respBytes))
return CreateOrUpdateConnectorResponse{}, fmt.Errorf("create response missing connector_id (body=%s)", string(respBytes))
}

return out, nil
}

// updateConnector sends a PUT request to update an existing connector.
func (c *Client) updateConnector(ctx context.Context, connectorID string, bodyBytes []byte) error {
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("/_plugins/_ml/connectors/%s", connectorID), bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("new request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")

httpResp, err := c.opensearch.Client.Perform(httpReq)
if err != nil {
return fmt.Errorf("perform update request: %w", err)
}
defer func() {
if err := httpResp.Body.Close(); err != nil {
fmt.Printf("error closing response body: %v\n", err)
}
}()

if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
respBytes, _ := io.ReadAll(httpResp.Body)
return fmt.Errorf("update connector failed: status=%d body=%s", httpResp.StatusCode, string(respBytes))
}

return nil
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func handler(ctx context.Context) error {

logger.SetAttr("model_group_id", groupResp.ModelGroupID)

connectorResp, err := client.CreateConnector(ctx, connector)
connectorResp, err := client.CreateOrUpdateConnector(ctx, connector)
if err != nil {
return logger.WrapError(err)
}
Expand Down