diff --git a/readme.md b/readme.md index 4af6554..d8c4cfc 100644 --- a/readme.md +++ b/readme.md @@ -20,6 +20,7 @@ import "github.com/firstrow/tcp_server" func main() { server := tcp_server.New("localhost:9999") + server.MessageTerminator='\n' // Optional end of message byte, default to newline. server.OnNewClient(func(c *tcp_server.Client) { // new client connected diff --git a/tcp_server.go b/tcp_server.go index 588931e..1a1127d 100644 --- a/tcp_server.go +++ b/tcp_server.go @@ -18,13 +18,15 @@ type server struct { onNewClientCallback func(c *Client) onClientConnectionClosed func(c *Client, err error) onNewMessage func(c *Client, message string) + close chan bool + MessageTerminator rune } // Read client data from channel func (c *Client) listen() { reader := bufio.NewReader(c.conn) for { - message, err := reader.ReadString('\n') + message, err := reader.ReadString(byte(c.Server.MessageTerminator)) if err != nil { c.conn.Close() c.Server.onClientConnectionClosed(c, err) @@ -77,8 +79,16 @@ func (s *server) Listen() { } defer listener.Close() + go func() { + <-s.close + listener.Close() + }() + for { - conn, _ := listener.Accept() + conn, lErr := listener.Accept() + if lErr != nil { + return + } client := &Client{ conn: conn, Server: s, @@ -88,11 +98,17 @@ func (s *server) Listen() { } } +func (s *server) Close() { + s.close <- true +} + // Creates new tcp server instance func New(address string) *server { log.Println("Creating server with address", address) server := &server{ - address: address, + address: address, + close: make(chan bool, 1), + MessageTerminator: '\n', } server.OnNewClient(func(c *Client) {}) diff --git a/tcp_server_test.go b/tcp_server_test.go index 6f84ab1..30108ad 100644 --- a/tcp_server_test.go +++ b/tcp_server_test.go @@ -40,11 +40,16 @@ func Test_accepting_new_client_callback(t *testing.T) { t.Fatal("Failed to connect to test server") } conn.Write([]byte("Test message\n")) + + time.Sleep(100 * time.Millisecond) + conn.Close() // Wait for server time.Sleep(10 * time.Millisecond) + server.Close() + Convey("Messages should be equal", t, func() { So(messageText, ShouldEqual, "Test message\n") }) @@ -58,3 +63,57 @@ func Test_accepting_new_client_callback(t *testing.T) { So(connectinClosed, ShouldEqual, true) }) } + +func Test_accepting_new_client_callback_different_terminator(t *testing.T) { + server := buildTestServer() + + var messageReceived bool + var messageText string + var newClient bool + var connectinClosed bool + + server.OnNewClient(func(c *Client) { + newClient = true + }) + server.OnNewMessage(func(c *Client, message string) { + messageReceived = true + messageText = message + }) + server.OnClientConnectionClosed(func(c *Client, err error) { + connectinClosed = true + }) + server.MessageTerminator = '\u0000' + go server.Listen() + + // Wait for server + // If test fails - increase this value + time.Sleep(10 * time.Millisecond) + + conn, err := net.Dial("tcp", "localhost:9999") + if err != nil { + t.Fatal("Failed to connect to test server") + } + conn.Write([]byte("Test message\u0000")) + + time.Sleep(100 * time.Millisecond) + + conn.Close() + + // Wait for server + time.Sleep(10 * time.Millisecond) + + server.Close() + + Convey("Messages should be equal", t, func() { + So(messageText, ShouldEqual, "Test message\u0000") + }) + Convey("It should receive new client callback", t, func() { + So(newClient, ShouldEqual, true) + }) + Convey("It should receive message callback", t, func() { + So(messageReceived, ShouldEqual, true) + }) + Convey("It should receive connection closed callback", t, func() { + So(connectinClosed, ShouldEqual, true) + }) +}