diff --git a/src/rtp/packet.zig b/src/rtp/packet.zig index fa366cb..4c622a2 100644 --- a/src/rtp/packet.zig +++ b/src/rtp/packet.zig @@ -41,6 +41,10 @@ pub const Extension = struct { try writer.writeInt(u16, @intCast(@divExact(ext.data.len, 4)), .big); try writer.writeAll(ext.data); } + + inline fn size(ext: *const Extension) usize { + return ext.data.len + 4; + } }; header: Header, @@ -111,6 +115,14 @@ pub fn format(self: Self, writer: *std.Io.Writer) !void { try writer.print("{d} bytes\n", .{self.payload.len}); } +const header_size = @divExact(@bitSizeOf(Header), 8); + +pub fn size(packet: *const Self) usize { + const ext_size = if (packet.extension) |ext| ext.size() else 0; + const padding_size = if (packet.header.padding) 4 - @rem(packet.payload.len + ext_size, 4) else 0; + return header_size + packet.csrc_list.len * 4 + ext_size + packet.payload.len + padding_size; +} + test "parse packet" { const rtp_packet: [16]u8 = [_]u8{ 0x80, 0xE0, 0x51, 0xA4, 0x00, 0x0D, 0xDF, @@ -130,6 +142,8 @@ test "parse packet" { try std.testing.expect(packet.header.timestamp == 0x000DDF22); try std.testing.expect(packet.header.ssrc == 0x54A7D4F3); try std.testing.expectEqualSlices(u8, &[_]u8{ 0x01, 0x02, 0x03, 0x04 }, packet.payload); + + try std.testing.expectEqual(16, packet.size()); } test "packet too short" { @@ -157,6 +171,8 @@ test "packet with csrc" { for (csrc_list, parsed_packet.csrc_list) |csrc, parsed_csrc| { try std.testing.expect(csrc == parsed_csrc); } + + try std.testing.expectEqual(29, parsed_packet.size()); } test "packet with extension" { @@ -183,12 +199,14 @@ test "packet with padding" { 0xB8, 0x30, 0x73, 0xBD, 0xDE, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x05, 0x00, 0x09, 0x00, 0x00, 0x00, 0x04, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x04, }; const parsed_packet = try Self.parse(packet[0..]); try std.testing.expect(parsed_packet.header.padding); try std.testing.expect(parsed_packet.padding_size == 4); + + try std.testing.expectEqual(48, parsed_packet.size()); } test "write packet" { diff --git a/src/rtsp/rtsp.zig b/src/rtsp/rtsp.zig index 5d454e5..d155f0e 100644 --- a/src/rtsp/rtsp.zig +++ b/src/rtsp/rtsp.zig @@ -1,76 +1,84 @@ +pub const Server = @import("server.zig"); + const std = @import("std"); const rtp = @import("rtp"); const Reader = std.Io.Reader; -const methods = std.StaticStringMap(Method).initComptime(&.{ - .{ "OPTIONS", Method.options }, - .{ "DESCRIBE", Method.describe }, - .{ "ANNOUNCE", Method.announce }, - .{ "SETUP", Method.setup }, - .{ "PLAY", Method.play }, - .{ "PAUSE", Method.pause }, - .{ "TEARDOWN", Method.teardown }, - .{ "GET_PARAMETER", Method.get_parameter }, - .{ "SET_PARAMETER", Method.set_parameter }, - .{ "REDIRECT", Method.redirect }, - .{ "RECORD", Method.record }, -}); - -pub const uri_flags: std.Uri.Format.Flags = .{ - .authentication = false, - .scheme = true, - .authority = true, - .path = true, - .query = true, - .fragment = true, -}; - pub const Error = error{ ParseError, } || std.mem.Allocator.Error || Reader.Error; pub const Method = enum { - options, - describe, - announce, - setup, - play, - pause, - teardown, - get_parameter, - set_parameter, - redirect, - record, - - pub fn toString(self: *const Method) []const u8 { - return switch (self.*) { - .options => "OPTIONS", - .describe => "DESCRIBE", - .announce => "ANNOUNCE", - .setup => "SETUP", - .play => "PLAY", - .pause => "PAUSE", - .teardown => "TEARDOWN", - .get_parameter => "GET_PARAMETER", - .set_parameter => "SET_PARAMETER", - .redirect => "REDIRECT", - .record => "RECORD", + OPTIONS, + DESCRIBE, + ANNOUNCE, + SETUP, + PLAY, + PAUSE, + TEARDOWN, + GET_PARAMETER, + SET_PARAMETER, + REDIRECT, + RECORD, + + pub fn expectBody(self: Method) bool { + return switch (self) { + .ANNOUNCE, .SET_PARAMETER => true, + else => false, }; } - test "toString" { - try std.testing.expectEqualStrings("OPTIONS", Method.options.toString()); - try std.testing.expectEqualStrings("DESCRIBE", Method.describe.toString()); - try std.testing.expectEqualStrings("ANNOUNCE", Method.announce.toString()); - try std.testing.expectEqualStrings("SETUP", Method.setup.toString()); - try std.testing.expectEqualStrings("PLAY", Method.play.toString()); - try std.testing.expectEqualStrings("PAUSE", Method.pause.toString()); - try std.testing.expectEqualStrings("TEARDOWN", Method.teardown.toString()); - try std.testing.expectEqualStrings("GET_PARAMETER", Method.get_parameter.toString()); - try std.testing.expectEqualStrings("SET_PARAMETER", Method.set_parameter.toString()); - try std.testing.expectEqualStrings("REDIRECT", Method.redirect.toString()); - try std.testing.expectEqualStrings("RECORD", Method.record.toString()); + pub fn responseExpectBody(self: Method) bool { + return switch (self) { + .DESCRIBE, .GET_PARAMETER => true, + else => false, + }; + } +}; + +pub const Status = enum(u10) { + success = 200, + low_on_storage = 250, + + method_not_allowed = 405, + parameter_not_understood = 451, + conference_not_found = 452, + not_enough_bandwidth = 453, + session_not_found = 454, + invalid_method = 455, + invalid_header = 456, + invalid_range = 457, + parameter_readonly = 458, + aggregate_not_allowed = 459, + only_aggregate = 460, + unsupported_transport = 461, + destination_unreachable = 462, + + option_not_supported = 551, + + _, + + pub fn phrase(self: Status) ?[]const u8 { + return switch (self) { + .success => "SUCCESS", + .low_on_storage => "Low on Storage Space", + .method_not_allowed => "Method Not Allowed", + .parameter_not_understood => "Parameter Not Understood", + .conference_not_found => "Parameter Not Understood", + .not_enough_bandwidth => "Not Enough Bandwidth", + .session_not_found => "Session Not Found", + .invalid_method => "Method Not Valid in This State", + .invalid_header => "Header Field Not Valid for Resource", + .invalid_range => "Invalid Range", + .parameter_readonly => "Parameter Is Read-Only", + .aggregate_not_allowed => "Aggregate Operation Not Allowed", + .only_aggregate => "Only Aggregate Operation Allowed", + .unsupported_transport => "Unsupported Transport", + .destination_unreachable => "Destination Unreachable", + .option_not_supported => "Option not supported", + else => null, + }; } }; @@ -121,9 +129,46 @@ pub const Header = struct { pub const TransportHeader = struct { proto: enum { tcp, udp } = .udp, + /// False means multicast unicast: bool = true, interleaved: ?struct { u8, u8 } = null, + pub const TransportError = error{InvalidTransportHeader}; + + pub fn parse(header_value: []const u8) TransportError!TransportHeader { + var it = std.mem.splitScalar(u8, header_value, ';'); + var transport: TransportHeader = .{}; + + const protocol = it.next().?; + if (std.mem.eql(u8, protocol, "RTP/AVP")) { + transport.proto = .udp; + } else if (std.mem.eql(u8, protocol, "RTP/AVP/UDP")) { + transport.proto = .udp; + } else if (std.mem.eql(u8, protocol, "RTP/AVP/TCP")) { + transport.proto = .tcp; + } else { + return error.InvalidTransportHeader; + } + + while (it.next()) |parameter| { + if (std.mem.eql(u8, parameter, "unicast")) { + transport.unicast = true; + } else if (std.mem.eql(u8, parameter, "multicast")) { + transport.unicast = false; + } else if (std.mem.startsWith(u8, parameter, "interleaved=")) { + var interleaved_it = std.mem.splitScalar(u8, parameter[12..], '-'); + transport.interleaved = .{ + std.fmt.parseInt(u8, interleaved_it.next().?, 10) catch return error.InvalidTransportHeader, + std.fmt.parseInt(u8, interleaved_it.rest(), 10) catch return error.InvalidTransportHeader, + }; + } + } + + if (transport.proto == .tcp and transport.interleaved == null) return error.InvalidTransportHeader; + + return transport; + } + pub fn write(self: *const TransportHeader, writer: *std.Io.Writer) std.Io.Writer.Error!void { try writer.writeAll(if (self.proto == .tcp) "RTP/AVP/TCP" else "RTP/AVP"); if (self.unicast) { @@ -135,6 +180,22 @@ pub const TransportHeader = struct { try writer.print(";interleaved={}-{}", .{ interleaved.@"0", interleaved.@"1" }); } } + + test "parse transport header" { + const transport = try parse("RTP/AVP/TCP;unicast;interleaved=0-1"); + try std.testing.expect(transport.unicast); + try std.testing.expectEqual(.tcp, transport.proto); + try std.testing.expectEqual(.{ 0, 1 }, transport.interleaved); + } +}; + +pub const uri_flags: std.Uri.Format.Flags = .{ + .authentication = false, + .scheme = true, + .authority = true, + .path = true, + .query = true, + .fragment = true, }; pub const StatusLine = struct { @@ -165,44 +226,6 @@ pub const StatusLine = struct { } }; -pub const RequestLine = struct { - method: Method, - uri: std.Uri, - - pub fn parse(line: []const u8) !RequestLine { - var iterator = std.mem.tokenizeScalar(u8, line, ' '); - const method = blk: { - if (iterator.next()) |str| { - if (methods.get(str)) |method| break :blk method else return error.ParseError; - } else return error.ParseError; - }; - const uri = iterator.next() orelse return error.ParseError; - if (!std.mem.eql(u8, iterator.rest(), "RTSP/1.0")) return error.ParseError; - - return .{ .method = method, .uri = std.Uri.parse(uri) catch return error.ParseError }; - } - - pub fn write(self: *const RequestLine, path: ?[]const u8, writer: *std.Io.Writer) std.Io.Writer.Error!void { - _ = try writer.write(self.method.toString()); - _ = try writer.writeByte(' '); - - const absolute_path = if (path) |p| std.mem.startsWith(u8, p, "rtsp") else false; - - if (!absolute_path) { - try std.Uri.writeToStream(&self.uri, writer, uri_flags); - } - - if (path) |p| { - if (!std.mem.startsWith(u8, p, "/")) { - _ = try writer.writeByte('/'); - } - _ = try writer.write(p); - } - - _ = try writer.write(" RTSP/1.0\r\n"); - } -}; - /// A lazy parser for RTSP messages. pub const Parser = struct { reader: *Reader, @@ -215,15 +238,6 @@ pub const Parser = struct { return Parser{ .reader = reader }; } - pub fn getRequestLine(parser: *Parser) Error!RequestLine { - if (parser.parse_state != .first_line) return error.ParseError; - const line = try readLine(parser.reader); - - const result = try RequestLine.parse(line); - parser.parse_state = .header; - return result; - } - pub fn getResponseStatus(parser: *Parser) Error!StatusLine { if (parser.parse_state != .first_line) return error.ParseError; const line = parser.reader.takeDelimiterInclusive('\n') catch return error.ParseError; @@ -288,10 +302,6 @@ pub const Writer = struct { return Writer{ .writer = writer }; } - pub fn writeRequestLine(self: *Writer, path: ?[]const u8, request_line: RequestLine) std.Io.Writer.Error!void { - try request_line.write(path, self.writer); - } - pub fn writeStatusLine(self: *Writer, status_line: StatusLine) std.Io.Writer.Error!void { try status_line.write(self.writer); } @@ -587,12 +597,6 @@ test "DigestAuthParams: parse" { try std.testing.expectEqualStrings("abc123", auth_params.nonce); } -test "request line: invalid request" { - try std.testing.expectError(error.ParseError, RequestLine.parse("METHOD /url RTSP/1.0")); - try std.testing.expectError(error.ParseError, RequestLine.parse("DESCRIBE /hello RTSP/1.0")); - try std.testing.expectError(error.ParseError, RequestLine.parse("DESCRIBE rtsp://example.com/hello RTSP/1.1")); -} - test "response parser" { const response_text = "RTSP/1.0 200 OK\r\nCSeq: 2\r\nSession: 12345678\r\nContent-Length: 13\r\n\r\nHello, World!"; var reader = Reader.fixed(response_text); @@ -625,39 +629,6 @@ test "response parser" { try std.testing.expectEqualStrings("Hello, World!", body.?); } -test "request parser" { - const response_text = "ANNOUNCE rtsp://example.com/my/stream RTSP/1.0\nCSeq: 2\r\nSession: 12345678\r\nContent-Length: 13\r\n\r\nHello, World!"; - var reader = Reader.fixed(response_text); - var parser = Parser.init(&reader); - - const request_line = try parser.getRequestLine(); - - try std.testing.expectEqual(.announce, request_line.method); - try std.testing.expectEqualStrings("/my/stream", request_line.uri.path.percent_encoded); - - var header = try parser.nextHeader(); - try std.testing.expect(header != null); - try std.testing.expectEqualStrings("CSeq", header.?.name); - try std.testing.expectEqualStrings("2", header.?.value); - - header = try parser.nextHeader(); - try std.testing.expect(header != null); - try std.testing.expectEqualStrings("Session", header.?.name); - try std.testing.expectEqualStrings("12345678", header.?.value); - - header = try parser.nextHeader(); - try std.testing.expect(header != null); - try std.testing.expectEqualStrings("Content-Length", header.?.name); - try std.testing.expectEqualStrings("13", header.?.value); - - header = try parser.nextHeader(); - try std.testing.expect(header == null); - - const body = try parser.getBody(); - try std.testing.expect(body != null); - try std.testing.expectEqualStrings("Hello, World!", body.?); -} - test { std.testing.refAllDecls(@This()); } diff --git a/src/rtsp/server.zig b/src/rtsp/server.zig new file mode 100644 index 0000000..664d78c --- /dev/null +++ b/src/rtsp/server.zig @@ -0,0 +1,197 @@ +//! Handles a single client session +const std = @import("std"); +const rtsp = @import("rtsp.zig"); +const rtp = @import("rtp"); + +const Reader = std.Io.Reader; +const Writer = std.Io.Writer; + +const Server = @This(); + +reader: *Reader, +writer: *Writer, + +pub const Request = struct { + server: *Server, + head: Head, + + pub const Head = struct { + method: rtsp.Method, + uri: []const u8, + cseq: u32, + session: ?[]const u8, + authenticate: ?[]const u8, + transport: ?rtsp.TransportHeader, + content_length: u32, + + pub const Error = error{ + UnknownRtspMethod, + RtspHeadersInvalid, + RtspVersionInvalid, + MissingSequenceHeader, + /// A request body is not expected for this METHOD. + BodyUnexpected, + } || rtsp.TransportHeader.TransportError; + + pub fn parse(buffer: []const u8) !Head { + var it = std.mem.splitSequence(u8, buffer, "\r\n"); + + const first_line = it.next().?; + var it2 = std.mem.splitScalar(u8, first_line, ' '); + + const method_str = it2.next() orelse return error.RtspHeadersInvalid; + const method = std.meta.stringToEnum(rtsp.Method, method_str) orelse return error.UnknownRtspMethod; + + const uri = it2.next() orelse return error.RtspHeadersInvalid; + const version = std.mem.trim(u8, it2.rest(), " \t"); + + if (!std.ascii.eqlIgnoreCase(version, "rtsp/1.0")) return error.RtspVersionInvalid; + + var head = Head{ + .method = method, + .uri = std.mem.trim(u8, uri, " \t"), + .cseq = std.math.maxInt(u32), + .session = null, + .authenticate = null, + .transport = null, + .content_length = 0, + }; + + // Parse headers + while (it.next()) |line| { + if (line.len == 0) { + if (head.cseq == std.math.maxInt(u32)) return error.MissingSequenceHeader; + return head; + } + + var line_it = std.mem.splitScalar(u8, line, ':'); + const header_name = line_it.next().?; + const header_value = std.mem.trim(u8, line_it.rest(), " \t"); + if (header_name.len == 0) return error.RtspHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "cseq")) { + head.cseq = std.fmt.parseInt(u32, header_value, 10) catch return error.RtspHeadersInvalid; + } else if (std.ascii.eqlIgnoreCase(header_name, "session")) { + head.session = header_name; + } else if (std.ascii.eqlIgnoreCase(header_name, "www-authenticate")) { + head.authenticate = header_name; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + head.content_length = std.fmt.parseInt(u32, header_name, 10) catch return error.RtspHeadersInvalid; + if (head.content_length > 0 and !head.method.expectBody()) { + return error.ContentLengthUnexpected; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "transport")) { + head.transport = try .parse(header_value); + } + } + + return error.MissingFinalNewLine; + } + + test "parse" { + const request_bytes = + \\ANNOUNCE rtsp://localhost/ISAPI/Streaming/Channels/101 RTSP/1.0 + \\CSeq: 5 + \\Accept: application/sdp + \\Content-Length: 140 + \\Session: 34F4545A + \\ + ; + + const head = try parse(request_bytes); + try std.testing.expectEqual(.ANNOUNCE, head.method); + try std.testing.expectEqual(5, head.cseq); + try std.testing.expectEqual(140, head.content_length); + try std.testing.expectEqual(null, head.transport); + try std.testing.expectEqual(null, head.authenticate); + try std.testing.expectEqualStrings("rtsp://localhost/ISAPI/Streaming/Channels/101", head.uri); + try std.testing.expectEqualStrings("34F454A", head.session.?); + } + }; + + pub const RespondOptions = struct { + status: rtsp.Status = .success, + reason: ?[]const u8 = null, + extra_headers: []const rtsp.Header = &.{}, + }; + + /// Send a entire rtsp response to the client. + /// + /// If the METHOD does not expect body, the `body` is ignored. + pub fn respond( + request: *Request, + body: []const u8, + options: RespondOptions, + ) Writer.Error!void { + var out = request.server.writer; + try out.print("RTSP/1.0 {} {s}\r\n", .{ + @intFromEnum(options.status), + options.reason orelse options.status.phrase() orelse "", + }); + + try out.print("CSeq: {}\r\n", .{request.head.cseq}); + try out.writeAll("Server: Zig RTSP/0.1.0\r\n"); + + if (request.head.method == .OPTIONS) { + try out.writeAll("Public: DESCRIBE, SETUP, PLAY, ANNOUNCE, RECORD, GET_PARAMETER, TEARDOWN\r\n"); + } + + for (options.extra_headers) |header| { + try out.print("{s}: {s}\r\n", .{ header.name, header.value }); + } + + if (request.head.method.responseExpectBody()) { + try out.print("Content-Length: {}\r\n\r\n", .{body.len}); + try out.writeAll(body); + } else { + try out.writeAll("\r\n"); + } + + try out.flush(); + } +}; + +/// Initialize a server that handle a single client session. +pub fn init(r: *Reader, w: *Writer) Server { + return .{ .reader = r, .writer = w }; +} + +pub fn receiveHead(s: *Server) !Request { + const head_buffer = try receiveHeadFromReader(s.reader); + return .{ + .head = Request.Head.parse(head_buffer) catch return error.RtspHeadersInvalid, + .server = s, + }; +} + +/// Writes rtp packet interleaved with RTSP/RTCP packets. +pub fn writeRtpPacket(s: *Server, channel: u8, packet: rtp.Packet) !void { + try s.writer.writeByte('$'); + try s.writer.writeInt(u8, channel, .big); + try s.writer.writeInt(u16, @intCast(packet.size()), .big); + try packet.write(s.writer); +} + +fn receiveHeadFromReader(r: *Reader) ![]const u8 { + const max_head_size = r.buffer.len; + var head_len: usize = 0; + var hp = std.http.HeadParser{}; + while (true) { + if (head_len >= max_head_size) return error.RtspHeadersOversize; + const remaining = r.buffered()[head_len..]; + if (remaining.len == 0) { + r.fillMore() catch |err| switch (err) { + error.EndOfStream => return error.RtspRequestTruncated, + error.ReadFailed => return err, + }; + continue; + } + + head_len += hp.feed(remaining); + if (hp.state == .finished) { + const result = r.buffered()[0..head_len]; + r.toss(head_len); + return result; + } + } +}