diff --git a/lib/body.ml b/lib/body.ml index c6f2bab..de74ea2 100644 --- a/lib/body.ml +++ b/lib/body.ml @@ -151,27 +151,41 @@ module Writer = struct t let write_char t c = - Faraday.write_char t.faraday c + if not (Faraday.is_closed t.faraday) then + Faraday.write_char t.faraday c let write_string t ?off ?len s = - Faraday.write_string ?off ?len t.faraday s + if not (Faraday.is_closed t.faraday) then + Faraday.write_string ?off ?len t.faraday s let write_bigstring t ?off ?len b = - Faraday.write_bigstring ?off ?len t.faraday b + if not (Faraday.is_closed t.faraday) then + Faraday.write_bigstring ?off ?len t.faraday b let schedule_bigstring t ?off ?len (b:Bigstringaf.t) = - Faraday.schedule_bigstring ?off ?len t.faraday b + if not (Faraday.is_closed t.faraday) then + Faraday.schedule_bigstring ?off ?len t.faraday b let ready_to_write t = Serialize.Writer.wakeup t.writer let flush t kontinue = - Faraday.flush t.faraday (fun () -> - Serialize.Writer.flush t.writer kontinue); - ready_to_write t + if Serialize.Writer.is_closed t.writer then + kontinue `Closed + else begin + Faraday.flush_with_reason t.faraday (function + | Drain -> kontinue `Closed + | Nothing_pending | Shift -> Serialize.Writer.flush t.writer kontinue); + ready_to_write t + end let is_closed t = Faraday.is_closed t.faraday + let close_and_drain t = + Faraday.close t.faraday; + (* Resolve all pending flushes *) + ignore (Faraday.drain t.faraday : int) + let close t = Serialize.Writer.unyield t.writer; Faraday.close t.faraday; @@ -202,6 +216,9 @@ module Writer = struct let transfer_to_writer t = let faraday = t.faraday in + if Serialize.Writer.is_closed t.writer then + close_and_drain t + else begin match Faraday.operation faraday with | `Yield -> () | `Close -> @@ -222,9 +239,11 @@ module Writer = struct | Identity -> Serialize.Writer.schedule_fixed t.writer iovecs | Chunked _ -> Serialize.Writer.schedule_chunk t.writer iovecs end; - Serialize.Writer.flush t.writer (fun () -> - Faraday.shift faraday lengthv; - t.buffered_bytes <- t.buffered_bytes - lengthv) + Serialize.Writer.flush t.writer (function + | `Closed -> close_and_drain t + | `Written -> + Faraday.shift faraday lengthv; + t.buffered_bytes <- t.buffered_bytes - lengthv) end end end diff --git a/lib/httpun.mli b/lib/httpun.mli index 4ed24df..628c8cf 100644 --- a/lib/httpun.mli +++ b/lib/httpun.mli @@ -97,9 +97,11 @@ module Body : sig modified until a subsequent call to {!flush} has successfully completed. *) - val flush : t -> (unit -> unit) -> unit - (** [flush t f] makes all bytes in [t] available for writing to the awaiting - output channel. Once those bytes have reached that output channel, [f] + val flush : t -> ([ `Written | `Closed ] -> unit) -> unit + (** [flush t f] makes all bytes in [t] available for writing to the + awaiting output channel. Once those bytes have reached that output + channel, [f `Written] will be called. If instead, the output channel is + closed before all of those bytes are successfully written, [f `Closed] will be called. The type of the output channel is runtime-dependent, as are guarantees @@ -112,8 +114,9 @@ module Body : sig to the output channel. *) val is_closed : t -> bool - (** [is_closed t] is [true] if {!close} has been called on [t] and [false] - otherwise. A closed [t] may still have pending output. *) + (** [is_closed t] is [true] if {!close} has been called on [t], or if the + attached output channel is closed (e.g. because [report_write_result + `Closed] has been called). A closed [t] may still have pending output. *) end end diff --git a/lib/reqd.ml b/lib/reqd.ml index 8ef0d86..4749269 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -178,7 +178,9 @@ let unsafe_respond_with_upgrade t headers upgrade_handler = if t.persistent then t.persistent <- Response.persistent_connection response; t.response_state <- Upgrade (response, upgrade_handler); - Writer.flush t.writer upgrade_handler; + Writer.flush t.writer (fun _reason -> + (* TODO(anmonteiro): probably need to check `Closed here? *) + upgrade_handler ()); Body.Reader.close t.request_body; Writer.wakeup t.writer | Streaming _ | Upgrade _ -> diff --git a/lib/serialize.ml b/lib/serialize.ml index 116e7fa..ce220df 100644 --- a/lib/serialize.ml +++ b/lib/serialize.ml @@ -158,13 +158,19 @@ module Writer = struct ;; let flush t f = - flush t.encoder f + flush_with_reason t.encoder (fun reason -> + let result = + match reason with + | Nothing_pending | Shift -> `Written + | Drain -> `Closed + in + f result) let unyield t = (* This would be better implemented by a function that just takes the encoder out of a yielded state if it's in that state. Requires a change to the faraday library. *) - flush t (fun () -> ()) + flush t (fun _reason -> ()) let yield t = Faraday.yield t.encoder diff --git a/lib_test/test_client_connection.ml b/lib_test/test_client_connection.ml index e434dfd..85a6ead 100644 --- a/lib_test/test_client_connection.ml +++ b/lib_test/test_client_connection.ml @@ -1775,8 +1775,7 @@ let test_flush_response_before_shutdown () = true (Body.Writer.is_closed body); - raises_writer_closed (fun () -> - write_string t "b\r\nhello world\r\n"); + writer_closed t; connection_is_shutdown t ;; diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index 2eb026d..b182788 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -326,8 +326,9 @@ let echo_handler response reqd = let response_body = Reqd.respond_with_streaming reqd response in let rec on_read buffer ~off ~len = Body.Writer.write_string response_body (Bigstringaf.substring ~off ~len buffer); - Body.Writer.flush response_body (fun () -> - Body.Reader.schedule_read request_body ~on_eof ~on_read) + Body.Writer.flush response_body (function + | `Closed -> assert false + | `Written -> Body.Reader.schedule_read request_body ~on_eof ~on_read) and on_eof () = print_endline "echo handler eof"; Body.Writer.close response_body @@ -340,17 +341,20 @@ let streaming_handler ?(flush=false) ?(error=false) response writes reqd = let request_body = Reqd.request_body reqd in Body.Reader.close request_body; let body = Reqd.respond_with_streaming ~flush_headers_immediately:flush reqd response in - let rec write () = - match !writes with - | [] -> (match error with - | false -> Body.Writer.close body - | true -> Reqd.report_exn reqd (Failure "exn")) - | w :: ws -> - Body.Writer.write_string body w; - writes := ws; - Body.Writer.flush body write - in - write (); + let rec write reason = + match reason with + | `Closed -> assert false + | `Written -> + match !writes with + | [] -> (match error with + | false -> Body.Writer.close body + | true -> Reqd.report_exn reqd (Failure "exn")) + | w :: ws -> + Body.Writer.write_string body w; + writes := ws; + Body.Writer.flush body write + in + write `Written; ;; let synchronous_raise reqd = @@ -875,9 +879,11 @@ let test_chunked_encoding () = let response = Response.create `OK ~headers:Headers.encoding_chunked in let resp_body = Reqd.respond_with_streaming reqd response in Body.Writer.write_string resp_body "First chunk"; - Body.Writer.flush resp_body (fun () -> - Body.Writer.write_string resp_body "Second chunk"; - Body.Writer.close resp_body); + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + Body.Writer.write_string resp_body "Second chunk"; + Body.Writer.close resp_body); in let t = create ~error_handler request_handler in writer_yielded t; @@ -903,9 +909,11 @@ let test_chunked_encoding_for_error () = `Bad_request error; let body = start_response Headers.encoding_chunked in Body.Writer.write_string body "Bad"; - Body.Writer.flush body (fun () -> - Body.Writer.write_string body " request"; - Body.Writer.close body); + Body.Writer.flush body (function + | `Closed -> assert false + | `Written -> + Body.Writer.write_string body " request"; + Body.Writer.close body); in let t = create ~error_handler (fun _ -> assert false) in let c = feed_string t " X\r\n\r\n" in @@ -1079,10 +1087,12 @@ let streaming_error_handler let resp_body = start_response headers in continue_error := (fun () -> Body.Writer.write_string resp_body "got an error\n"; - Body.Writer.flush resp_body (fun () -> - continue_error := (fun () -> - Body.Writer.write_string resp_body "more output"; - Body.Writer.close resp_body))) + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + continue_error := (fun () -> + Body.Writer.write_string resp_body "more output"; + Body.Writer.close resp_body))) ;; let test_malformed_request_streaming_error_response () = @@ -1119,13 +1129,17 @@ let chunked_error_handler continue_error ?request:_ _error start_response = start_response (Headers.of_list ["transfer-encoding", "chunked"]) in Body.Writer.write_string resp_body "chunk 1\n"; - Body.Writer.flush resp_body (fun () -> - continue_error := (fun () -> - Body.Writer.write_string resp_body "chunk 2\n"; - Body.Writer.flush resp_body (fun () -> - continue_error := (fun () -> - Body.Writer.write_string resp_body "chunk 3\n"; - Body.Writer.close resp_body)))) + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + continue_error := (fun () -> + Body.Writer.write_string resp_body "chunk 2\n"; + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + continue_error := (fun () -> + Body.Writer.write_string resp_body "chunk 3\n"; + Body.Writer.close resp_body)))) ;; let test_malformed_request_chunked_error_response () = @@ -1475,9 +1489,11 @@ let test_streaming_response_before_reading_entire_body_no_error () = let resp_body = Reqd.respond_with_streaming reqd response in continue_response := (fun () -> Body.Writer.write_string resp_body "hello"; - Body.Writer.flush resp_body (fun () -> - continue_response := (fun () -> - Body.Writer.close resp_body)))) + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + continue_response := (fun () -> + Body.Writer.close resp_body)))) in let error_handler ?request:_ _error _start_response = assert false in let t = create ~error_handler request_handler in @@ -1748,9 +1764,11 @@ let test_race_condition_writer_issues_yield_after_reader_eof () = ~on_eof:(fun () -> let resp_body = Reqd.respond_with_streaming reqd response in Body.Writer.write_string resp_body (String.make 10 'a'); - Body.Writer.flush resp_body (fun () -> - continue_response := (fun () -> - Body.Writer.close resp_body)))) + Body.Writer.flush resp_body (function + | `Closed -> assert false + | `Written -> + continue_response := (fun () -> + Body.Writer.close resp_body)))) in let t = create ~error_handler response_handler in let request = @@ -1870,9 +1888,11 @@ let test_errored_chunked_streaming_response_async () = Body.Reader.close request_body; let body = Reqd.respond_with_streaming reqd response in Body.Writer.write_string body "hello"; - Body.Writer.flush body (fun () -> - continue := (fun () -> - Reqd.report_exn reqd (Failure "heh"))) + Body.Writer.flush body (function + | `Closed -> assert false + | `Written -> + continue := (fun () -> + Reqd.report_exn reqd (Failure "heh"))) in let t = create request_handler in @@ -2060,9 +2080,8 @@ let test_flush_response_before_shutdown () = write_response t response; !continue (); shutdown t; - raises_writer_closed (fun () -> - write_string t "b\r\nhello world\r\n"; - connection_is_shutdown t); + write_string t "b\r\nhello world\r\n"; + connection_is_shutdown t ;; let test_schedule_read_with_data_available () = @@ -2286,7 +2305,10 @@ let test_body_flush_after_bytes_in_the_wire () = Response.create ~headers:(Headers.of_list ["content-length", "5"]) `OK in let callback_called = ref false in - let callback () = callback_called := true in + let callback = function + | `Closed -> assert false + | `Written -> callback_called := true + in let request_handler ~flush_headers_immediately reqd = let response_body = Reqd.respond_with_streaming