diff --git a/cmd/server/main.go b/cmd/server/main.go index c4f915e..1c30e1e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,20 +2,23 @@ package main import ( "fmt" + "log" + "net/http" "github.com/benchttp/engine/server" ) -const port = "8080" +const ( + port = "8080" + // token is a dummy token used for development only. + token = "6db67fafc4f5bf965a5a" //nolint:gosec +) func main() { - if err := run(); err != nil { - fmt.Println(err) - } -} - -func run() error { addr := ":" + port fmt.Println("http://localhost" + addr) - return server.ListenAndServe(addr) + + handler := server.NewHandler(false, token) + + log.Fatal(http.ListenAndServe(addr, handler)) } diff --git a/go.mod b/go.mod index 4de6a8f..30119a4 100644 --- a/go.mod +++ b/go.mod @@ -8,4 +8,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) -require github.com/drykit-go/cond v0.1.0 // indirect +require ( + github.com/drykit-go/cond v0.1.0 // indirect + github.com/gorilla/websocket v1.5.0 +) diff --git a/go.sum b/go.sum index 12452b0..9e7eaf9 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/drykit-go/strcase v0.2.0/go.mod h1:cWK0/az2f09UPIbJ42Sb8Iqdv01uENrFX+ github.com/drykit-go/testx v0.1.0/go.mod h1:qGXb49a8CzQ82crBeCVW8R3kGU1KRgWHnI+Q6CNVbz8= github.com/drykit-go/testx v1.2.0 h1:UsH+tFd24z3Xu+mwvwPY+9eBEg9CUyMsUeMYyUprG0o= github.com/drykit-go/testx v1.2.0/go.mod h1:qTzXJgnAg8n31woklBzNTaWzLMJrnFk93x/aeaIpc20= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/configparse/json.go b/internal/configparse/json.go index 37a5bbb..736c520 100644 --- a/internal/configparse/json.go +++ b/internal/configparse/json.go @@ -8,7 +8,7 @@ import ( func JSON(in []byte) (runner.Config, error) { parser := jsonParser{} - var uconf unmarshaledConfig + var uconf UnmarshaledConfig if err := parser.parse(in, &uconf); err != nil { return runner.Config{}, err } diff --git a/internal/configparse/parse.go b/internal/configparse/parse.go index d8c8c27..9bcb084 100644 --- a/internal/configparse/parse.go +++ b/internal/configparse/parse.go @@ -11,11 +11,11 @@ import ( "github.com/benchttp/engine/runner" ) -// unmarshaledConfig is a raw data model for runner config files. +// UnmarshaledConfig is a raw data model for runner config files. // It serves as a receiver for unmarshaling processes and for that reason // its types are kept simple (certain types are incompatible with certain // unmarshalers). -type unmarshaledConfig struct { +type UnmarshaledConfig struct { Extends *string `yaml:"extends" json:"extends"` Request struct { @@ -47,7 +47,7 @@ type unmarshaledConfig struct { // and returns it or the first non-nil error occurring in the process, // which can be any of the values declared in the package. func Parse(filename string) (cfg runner.Config, err error) { - uconfs, err := parseFileRecursive(filename, []unmarshaledConfig{}, set{}) + uconfs, err := parseFileRecursive(filename, []UnmarshaledConfig{}, set{}) if err != nil { return } @@ -73,9 +73,9 @@ func (s set) add(v string) error { // occurring in the process. func parseFileRecursive( filename string, - uconfs []unmarshaledConfig, + uconfs []UnmarshaledConfig, seen set, -) ([]unmarshaledConfig, error) { +) ([]UnmarshaledConfig, error) { // avoid infinite recursion caused by circular reference if err := seen.add(filename); err != nil { return uconfs, ErrCircularExtends @@ -100,7 +100,7 @@ func parseFileRecursive( // parseFile parses a single config file and returns the result as an // unmarshaledConfig and an appropriate error predeclared in the package. -func parseFile(filename string) (uconf unmarshaledConfig, err error) { +func parseFile(filename string) (uconf UnmarshaledConfig, err error) { b, err := os.ReadFile(filename) switch { case err == nil: @@ -127,7 +127,7 @@ func parseFile(filename string) (uconf unmarshaledConfig, err error) { // as runner.ConfigGlobal and merging them into a single one. // It returns the merged result or the first non-nil error occurring in the // process. -func parseAndMergeConfigs(uconfs []unmarshaledConfig) (cfg runner.Config, err error) { +func parseAndMergeConfigs(uconfs []UnmarshaledConfig) (cfg runner.Config, err error) { if len(uconfs) == 0 { // supposedly catched upstream, should not occur return cfg, errors.New( "an unacceptable error occurred parsing the config file, " + @@ -164,7 +164,7 @@ func (pconf *parsedConfig) add(field string) { // newParsedConfig parses an input raw config as a runner.ConfigGlobal and returns // a parsedConfig or the first non-nil error occurring in the process. -func newParsedConfig(uconf unmarshaledConfig) (parsedConfig, error) { //nolint:gocognit // acceptable complexity for a parsing func +func newParsedConfig(uconf UnmarshaledConfig) (parsedConfig, error) { //nolint:gocognit // acceptable complexity for a parsing func const numField = 12 // should match the number of config Fields (not critical) pconf := parsedConfig{ diff --git a/internal/configparse/parser.go b/internal/configparse/parser.go index eaa036f..c647f88 100644 --- a/internal/configparse/parser.go +++ b/internal/configparse/parser.go @@ -22,7 +22,7 @@ const ( type configParser interface { // parse parses a raw bytes input as a raw config and stores // the resulting value into dst. - parse(in []byte, dst *unmarshaledConfig) error + parse(in []byte, dst *UnmarshaledConfig) error } // newParser returns an appropriate parser according to ext, or a non-nil @@ -43,7 +43,7 @@ type yamlParser struct{} // parse decodes a raw yaml input in strict mode (unknown fields disallowed) // and stores the resulting value into dst. -func (p yamlParser) parse(in []byte, dst *unmarshaledConfig) error { +func (p yamlParser) parse(in []byte, dst *UnmarshaledConfig) error { decoder := yaml.NewDecoder(bytes.NewReader(in)) decoder.KnownFields(true) return p.handleError(decoder.Decode(dst)) @@ -130,7 +130,7 @@ type jsonParser struct{} // parse decodes a raw JSON input in strict mode (unknown fields disallowed) // and stores the resulting value into dst. -func (p jsonParser) parse(in []byte, dst *unmarshaledConfig) error { +func (p jsonParser) parse(in []byte, dst *UnmarshaledConfig) error { decoder := json.NewDecoder(bytes.NewReader(in)) decoder.DisallowUnknownFields() return p.handleError(decoder.Decode(dst)) diff --git a/internal/configparse/parser_internal_test.go b/internal/configparse/parser_internal_test.go index 9a0f407..29aeccd 100644 --- a/internal/configparse/parser_internal_test.go +++ b/internal/configparse/parser_internal_test.go @@ -64,7 +64,7 @@ func TestYAMLParser(t *testing.T) { t.Run(tc.label, func(t *testing.T) { var ( parser yamlParser - rawcfg unmarshaledConfig + rawcfg UnmarshaledConfig yamlErr *yaml.TypeError ) @@ -122,7 +122,7 @@ func TestJSONParser(t *testing.T) { t.Run(tc.label, func(t *testing.T) { var ( parser jsonParser - rawcfg unmarshaledConfig + rawcfg UnmarshaledConfig ) gotErr := parser.parse(tc.in, &rawcfg) diff --git a/internal/websocketio/io.go b/internal/websocketio/io.go new file mode 100644 index 0000000..2c02a62 --- /dev/null +++ b/internal/websocketio/io.go @@ -0,0 +1,92 @@ +package websocketio + +import ( + "fmt" + "log" + + "github.com/gorilla/websocket" +) + +type Reader interface { + ReadTextMessage() (string, error) + ReadJSON(v interface{}) error +} + +type Writer interface { + WriteTextMessage(m string) error + WriteJSON(v interface{}) error +} + +type ReadWriter interface { + Reader + Writer +} + +type readWriter struct { + ws *websocket.Conn + silent bool +} + +// NewReadWriter returns a concrete type ReadWriter +// reading from and writing to the websocket connection. +func NewReadWriter(ws *websocket.Conn, silent bool) ReadWriter { + return &readWriter{ws, silent} +} + +func (rw *readWriter) ReadTextMessage() (string, error) { + messageType, p, err := rw.ws.ReadMessage() + if err != nil { + return "", fmt.Errorf("cannot read message: %s", err) + } + + if messageType != websocket.TextMessage { + return "", fmt.Errorf("message type is not TextMessage") + } + + m := string(p) + + if !rw.silent { + log.Printf("<- %s", m) + } + + return m, nil +} + +func (rw *readWriter) ReadJSON(v interface{}) error { + err := rw.ws.ReadJSON(v) + if err != nil { + return fmt.Errorf("cannot read message: %s", err) + } + + if !rw.silent { + log.Printf("<- %v", v) + } + + return nil +} + +func (rw *readWriter) WriteTextMessage(m string) error { + err := rw.ws.WriteMessage(websocket.TextMessage, []byte(m)) + if err != nil { + return fmt.Errorf("cannot write message: %s", err) + } + + if !rw.silent { + log.Printf("-> %s", m) + } + + return nil +} + +func (rw *readWriter) WriteJSON(v interface{}) error { + err := rw.ws.WriteJSON(v) + if err != nil { + return fmt.Errorf("cannot write message: %s", err) + } + + if !rw.silent { + log.Printf("-> %v", v) + } + + return nil +} diff --git a/server/handler.go b/server/handler.go new file mode 100644 index 0000000..3195fb2 --- /dev/null +++ b/server/handler.go @@ -0,0 +1,102 @@ +package server + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + + "github.com/benchttp/engine/internal/configparse" + "github.com/benchttp/engine/internal/websocketio" + "github.com/benchttp/engine/runner" +) + +// Handler implements http.Handler. +// It serves a websocket server allowing +// remote manipulation of runner.Runner. +type Handler struct { + Silent bool + Token string + service *service +} + +func NewHandler(silent bool, token string) *Handler { + return &Handler{ + Silent: silent, + Token: token, + service: &service{}, + } +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/run": + h.handle(w, r) + default: + http.NotFound(w, r) + } +} + +func (h *Handler) handle(w http.ResponseWriter, r *http.Request) { + upgrader := secureUpgrader(h.Token) + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + defer func() { + ws.Close() + // The client is gone, flush all the state. + // TODO Handle reconnect? + h.service.flush() + }() + + log.Println("connected with client via websocket") + + rw := websocketio.NewReadWriter(ws, h.Silent) + + for { + m := clientMessage{} + if err := rw.ReadJSON(&m); err != nil { + log.Println(err) + break + } + + switch m.Action { + case "run": + cfg, err := parseConfig(m.Data) + if err != nil { + log.Println(err) + break + } + + go h.service.doRun(rw, cfg) + + case "cancel": + if ok := h.service.cancelRun(); !ok { + rw.WriteJSON(errorMessage{Event: "error", Error: "not running"}) //nolint:errcheck + } + + default: + rw.WriteTextMessage(fmt.Sprintf("unknown procedure: %s", m.Action)) //nolint:errcheck + } + } +} + +// TODO Update package configparse for this purpose. + +func parseConfig(data configparse.UnmarshaledConfig) (runner.Config, error) { + p, err := json.Marshal(data) + if err != nil { + return runner.Config{}, err + } + + cfg, err := configparse.JSON(p) + if err != nil { + log.Println(err) + return runner.Config{}, err + } + + return cfg, nil +} diff --git a/server/message.go b/server/message.go new file mode 100644 index 0000000..15a4ac9 --- /dev/null +++ b/server/message.go @@ -0,0 +1,27 @@ +package server + +import ( + "github.com/benchttp/engine/internal/configparse" + "github.com/benchttp/engine/runner" +) + +type clientMessage struct { + Action string `json:"action"` + // Data is non-empty if field Action is "run". + Data configparse.UnmarshaledConfig `json:"data"` +} + +type progressMessage struct { + Event string `json:"state"` + Data string `json:"data"` // TODO runner.RecordingProgress +} + +type doneMessage struct { + Event string `json:"state"` + Data runner.Report `json:"data"` +} + +type errorMessage struct { + Event string `json:"state"` + Error string `json:"error"` +} diff --git a/server/run.go b/server/run.go deleted file mode 100644 index 1b11007..0000000 --- a/server/run.go +++ /dev/null @@ -1,44 +0,0 @@ -package server - -import ( - "io" - "net/http" - - "github.com/benchttp/engine/internal/configparse" -) - -func (s *server) handleRun(w http.ResponseWriter, r *http.Request) { - // Allow single run at a time - if s.isRequesterRunning() { - http.Error(w, "already running", http.StatusConflict) - return - } - defer s.flush() - - // Read input config - readBody, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Parse json config - cfg, err := configparse.JSON(readBody) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Start run - out, err := s.doRun(cfg) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Respond with run output - if _, err := out.WriteJSON(w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } -} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index 0e7405d..0000000 --- a/server/server.go +++ /dev/null @@ -1,94 +0,0 @@ -package server - -import ( - "context" - "net/http" - "sync" - - "github.com/benchttp/engine/runner" -) - -func ListenAndServe(addr string) error { - return http.ListenAndServe(addr, &server{}) -} - -type server struct { - mu sync.RWMutex - runner *runner.Runner - stopRun context.CancelFunc -} - -func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/run": - s.handleRun(w, r) - case "/progress": - s.handleProgress(w, r) - case "/stop": - s.handleStop(w, r) - default: - http.NotFound(w, r) - } -} - -func (s *server) doRun(cfg runner.Config) (*runner.Report, error) { - ctx, cancel := context.WithCancel(context.Background()) - - s.setRunner(runner.New(nil)) - s.setStopRun(cancel) - - // Run benchmark - return s.runner.Run(ctx, silentConfig(cfg)) -} - -func (s *server) setRunner(r *runner.Runner) { - s.mu.Lock() - defer s.mu.Unlock() - s.runner = r -} - -func (s *server) setStopRun(cancelFunc context.CancelFunc) { - s.mu.Lock() - defer s.mu.Unlock() - s.stopRun = cancelFunc -} - -func (s *server) flush() { - s.mu.Lock() - defer s.mu.Unlock() - s.runner = nil - s.stopRun = nil -} - -func (s *server) isRequesterRunning() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.runner != nil -} - -func (s *server) recordingProgress() (progress runner.RecordingProgress, ok bool) { - s.mu.RLock() - defer s.mu.RUnlock() - if s.runner == nil { - return runner.RecordingProgress{}, false - } - return s.runner.Progress(), true -} - -func (s *server) stopRequester() bool { - s.mu.Lock() - defer s.mu.Unlock() - if s.runner == nil { - return false - } - s.stopRun() - return true -} - -func silentConfig(cfg runner.Config) runner.Config { - cfg.Output = runner.OutputConfig{ - Silent: true, - Template: "", - } - return cfg -} diff --git a/server/service.go b/server/service.go new file mode 100644 index 0000000..c56ebb7 --- /dev/null +++ b/server/service.go @@ -0,0 +1,74 @@ +package server + +import ( + "context" + "fmt" + "sync" + + "github.com/benchttp/engine/internal/websocketio" + "github.com/benchttp/engine/runner" +) + +type service struct { + mu sync.RWMutex + runner *runner.Runner + cancel context.CancelFunc +} + +// doRun calls runner.Runner.Run. The service state is overwritten. +// The return value of runner.Runner.Run is send to the client via +// w. The run progress is streamed through w. +func (s *service) doRun(w websocketio.Writer, cfg runner.Config) { + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + s.runner = runner.New(s.sendRecordingProgess(w)) + + out, err := s.runner.Run(ctx, cfg) + if err != nil { + _ = w.WriteJSON(errorMessage{Event: "done", Error: err.Error()}) + return + } + + _ = w.WriteJSON(doneMessage{Event: "done", Data: *out}) +} + +// cancelRun cancels the run of the current runner. +// If the runner is nil, cancelRun is noop. +// cancelRun panics if cancelRun is invoked while +// service.runner is non-nil yet service.cancel is nil. +func (s *service) cancelRun() (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.runner == nil { + return false + } + s.cancel() + return true +} + +// sendRecordingProgess returns a callback +// to send the current runner progress via w. +func (s *service) sendRecordingProgess(w websocketio.Writer) func(runner.RecordingProgress) { + // The callback is invoked from a goroutine spawned by Recorder.Record. + // Protect w from concurrent write with a lock. + return func(rp runner.RecordingProgress) { + s.mu.Lock() + defer s.mu.Unlock() + + m := progressMessage{ + Event: "progress", + Data: fmt.Sprintf("%s: %d/%d %d", rp.Status(), rp.DoneCount, rp.MaxCount, rp.Percent()), + } + _ = w.WriteJSON(m) + } +} + +// flush clears the service state. +// Calling service.flush locks it for writing. +func (s *service) flush() { + s.mu.Lock() + defer s.mu.Unlock() + s.runner = nil + s.cancel = nil +} diff --git a/server/state.go b/server/state.go deleted file mode 100644 index 8d153c6..0000000 --- a/server/state.go +++ /dev/null @@ -1,24 +0,0 @@ -package server - -import "net/http" - -func (s *server) handleProgress(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - progress, ok := s.recordingProgress() - if !ok { - http.Error(w, "not running", http.StatusConflict) - return - } - - jsonProgress, err := progress.JSON() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write(jsonProgress) -} diff --git a/server/stop.go b/server/stop.go deleted file mode 100644 index 76158fc..0000000 --- a/server/stop.go +++ /dev/null @@ -1,17 +0,0 @@ -package server - -import "net/http" - -func (s *server) handleStop(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - defer s.flush() - - if ok := s.stopRequester(); !ok { - http.Error(w, "not running", http.StatusConflict) - return - } -} diff --git a/server/upgrader.go b/server/upgrader.go new file mode 100644 index 0000000..033eb65 --- /dev/null +++ b/server/upgrader.go @@ -0,0 +1,15 @@ +package server + +import ( + "net/http" + + "github.com/gorilla/websocket" +) + +func secureUpgrader(token string) websocket.Upgrader { + return websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return r.URL.Query().Get("access_token") == token + }, + } +}