Skip to content

Commit

Permalink
surface write errors through flush
Browse files Browse the repository at this point in the history
Co-Authored-By: David House <[email protected]>
Co-Authored-By: Doug Patti <[email protected]>
  • Loading branch information
3 people committed Aug 27, 2024
1 parent 02d006c commit 947293b
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 63 deletions.
39 changes: 29 additions & 10 deletions lib/body.ml
Original file line number Diff line number Diff line change
Expand Up @@ -151,27 +151,41 @@ module Writer = struct
t

let write_char t c =
Faraday.write_char t.faraday c
if not (Faraday.is_closed t.faraday) then
Faraday.write_char t.faraday c

let write_string t ?off ?len s =
Faraday.write_string ?off ?len t.faraday s
if not (Faraday.is_closed t.faraday) then
Faraday.write_string ?off ?len t.faraday s

let write_bigstring t ?off ?len b =
Faraday.write_bigstring ?off ?len t.faraday b
if not (Faraday.is_closed t.faraday) then
Faraday.write_bigstring ?off ?len t.faraday b

let schedule_bigstring t ?off ?len (b:Bigstringaf.t) =
Faraday.schedule_bigstring ?off ?len t.faraday b
if not (Faraday.is_closed t.faraday) then
Faraday.schedule_bigstring ?off ?len t.faraday b

let ready_to_write t = Serialize.Writer.wakeup t.writer

let flush t kontinue =
Faraday.flush t.faraday (fun () ->
Serialize.Writer.flush t.writer kontinue);
ready_to_write t
if Serialize.Writer.is_closed t.writer then
kontinue `Closed
else begin
Faraday.flush_with_reason t.faraday (function
| Drain -> kontinue `Closed
| Nothing_pending | Shift -> Serialize.Writer.flush t.writer kontinue);
ready_to_write t
end

let is_closed t =
Faraday.is_closed t.faraday

let close_and_drain t =
Faraday.close t.faraday;
(* Resolve all pending flushes *)
ignore (Faraday.drain t.faraday : int)

let close t =
Serialize.Writer.unyield t.writer;
Faraday.close t.faraday;
Expand Down Expand Up @@ -202,6 +216,9 @@ module Writer = struct

let transfer_to_writer t =
let faraday = t.faraday in
if Serialize.Writer.is_closed t.writer then
close_and_drain t
else
begin match Faraday.operation faraday with
| `Yield -> ()
| `Close ->
Expand All @@ -222,9 +239,11 @@ module Writer = struct
| Identity -> Serialize.Writer.schedule_fixed t.writer iovecs
| Chunked _ -> Serialize.Writer.schedule_chunk t.writer iovecs
end;
Serialize.Writer.flush t.writer (fun () ->
Faraday.shift faraday lengthv;
t.buffered_bytes <- t.buffered_bytes - lengthv)
Serialize.Writer.flush t.writer (function
| `Closed -> close_and_drain t
| `Written ->
Faraday.shift faraday lengthv;
t.buffered_bytes <- t.buffered_bytes - lengthv)
end
end
end
13 changes: 8 additions & 5 deletions lib/httpun.mli
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ module Body : sig
modified until a subsequent call to {!flush} has successfully
completed. *)

val flush : t -> (unit -> unit) -> unit
(** [flush t f] makes all bytes in [t] available for writing to the awaiting
output channel. Once those bytes have reached that output channel, [f]
val flush : t -> ([ `Written | `Closed ] -> unit) -> unit
(** [flush t f] makes all bytes in [t] available for writing to the
awaiting output channel. Once those bytes have reached that output
channel, [f `Written] will be called. If instead, the output channel is
closed before all of those bytes are successfully written, [f `Closed]
will be called.
The type of the output channel is runtime-dependent, as are guarantees
Expand All @@ -112,8 +114,9 @@ module Body : sig
to the output channel. *)

val is_closed : t -> bool
(** [is_closed t] is [true] if {!close} has been called on [t] and [false]
otherwise. A closed [t] may still have pending output. *)
(** [is_closed t] is [true] if {!close} has been called on [t], or if the
attached output channel is closed (e.g. because [report_write_result
`Closed] has been called). A closed [t] may still have pending output. *)
end

end
Expand Down
4 changes: 3 additions & 1 deletion lib/reqd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ let unsafe_respond_with_upgrade t headers upgrade_handler =
if t.persistent then
t.persistent <- Response.persistent_connection response;
t.response_state <- Upgrade (response, upgrade_handler);
Writer.flush t.writer upgrade_handler;
Writer.flush t.writer (fun _reason ->
(* TODO(anmonteiro): probably need to check `Closed here? *)
upgrade_handler ());
Body.Reader.close t.request_body;
Writer.wakeup t.writer
| Streaming _ | Upgrade _ ->
Expand Down
10 changes: 8 additions & 2 deletions lib/serialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,19 @@ module Writer = struct
;;

let flush t f =
flush t.encoder f
flush_with_reason t.encoder (fun reason ->
let result =
match reason with
| Nothing_pending | Shift -> `Written
| Drain -> `Closed
in
f result)

let unyield t =
(* This would be better implemented by a function that just takes the
encoder out of a yielded state if it's in that state. Requires a change
to the faraday library. *)
flush t (fun () -> ())
flush t (fun _reason -> ())

let yield t =
Faraday.yield t.encoder
Expand Down
3 changes: 1 addition & 2 deletions lib_test/test_client_connection.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1775,8 +1775,7 @@ let test_flush_response_before_shutdown () =
true
(Body.Writer.is_closed body);

raises_writer_closed (fun () ->
write_string t "b\r\nhello world\r\n");
writer_closed t;
connection_is_shutdown t
;;

Expand Down
108 changes: 65 additions & 43 deletions lib_test/test_server_connection.ml
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,9 @@ let echo_handler response reqd =
let response_body = Reqd.respond_with_streaming reqd response in
let rec on_read buffer ~off ~len =
Body.Writer.write_string response_body (Bigstringaf.substring ~off ~len buffer);
Body.Writer.flush response_body (fun () ->
Body.Reader.schedule_read request_body ~on_eof ~on_read)
Body.Writer.flush response_body (function
| `Closed -> assert false
| `Written -> Body.Reader.schedule_read request_body ~on_eof ~on_read)
and on_eof () =
print_endline "echo handler eof";
Body.Writer.close response_body
Expand All @@ -340,17 +341,20 @@ let streaming_handler ?(flush=false) ?(error=false) response writes reqd =
let request_body = Reqd.request_body reqd in
Body.Reader.close request_body;
let body = Reqd.respond_with_streaming ~flush_headers_immediately:flush reqd response in
let rec write () =
match !writes with
| [] -> (match error with
| false -> Body.Writer.close body
| true -> Reqd.report_exn reqd (Failure "exn"))
| w :: ws ->
Body.Writer.write_string body w;
writes := ws;
Body.Writer.flush body write
in
write ();
let rec write reason =
match reason with
| `Closed -> assert false
| `Written ->
match !writes with
| [] -> (match error with
| false -> Body.Writer.close body
| true -> Reqd.report_exn reqd (Failure "exn"))
| w :: ws ->
Body.Writer.write_string body w;
writes := ws;
Body.Writer.flush body write
in
write `Written;
;;

let synchronous_raise reqd =
Expand Down Expand Up @@ -875,9 +879,11 @@ let test_chunked_encoding () =
let response = Response.create `OK ~headers:Headers.encoding_chunked in
let resp_body = Reqd.respond_with_streaming reqd response in
Body.Writer.write_string resp_body "First chunk";
Body.Writer.flush resp_body (fun () ->
Body.Writer.write_string resp_body "Second chunk";
Body.Writer.close resp_body);
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
Body.Writer.write_string resp_body "Second chunk";
Body.Writer.close resp_body);
in
let t = create ~error_handler request_handler in
writer_yielded t;
Expand All @@ -903,9 +909,11 @@ let test_chunked_encoding_for_error () =
`Bad_request error;
let body = start_response Headers.encoding_chunked in
Body.Writer.write_string body "Bad";
Body.Writer.flush body (fun () ->
Body.Writer.write_string body " request";
Body.Writer.close body);
Body.Writer.flush body (function
| `Closed -> assert false
| `Written ->
Body.Writer.write_string body " request";
Body.Writer.close body);
in
let t = create ~error_handler (fun _ -> assert false) in
let c = feed_string t " X\r\n\r\n" in
Expand Down Expand Up @@ -1079,10 +1087,12 @@ let streaming_error_handler
let resp_body = start_response headers in
continue_error := (fun () ->
Body.Writer.write_string resp_body "got an error\n";
Body.Writer.flush resp_body (fun () ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "more output";
Body.Writer.close resp_body)))
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "more output";
Body.Writer.close resp_body)))
;;

let test_malformed_request_streaming_error_response () =
Expand Down Expand Up @@ -1119,13 +1129,17 @@ let chunked_error_handler continue_error ?request:_ _error start_response =
start_response (Headers.of_list ["transfer-encoding", "chunked"])
in
Body.Writer.write_string resp_body "chunk 1\n";
Body.Writer.flush resp_body (fun () ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "chunk 2\n";
Body.Writer.flush resp_body (fun () ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "chunk 3\n";
Body.Writer.close resp_body))))
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "chunk 2\n";
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
continue_error := (fun () ->
Body.Writer.write_string resp_body "chunk 3\n";
Body.Writer.close resp_body))))
;;

let test_malformed_request_chunked_error_response () =
Expand Down Expand Up @@ -1475,9 +1489,11 @@ let test_streaming_response_before_reading_entire_body_no_error () =
let resp_body = Reqd.respond_with_streaming reqd response in
continue_response := (fun () ->
Body.Writer.write_string resp_body "hello";
Body.Writer.flush resp_body (fun () ->
continue_response := (fun () ->
Body.Writer.close resp_body))))
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
continue_response := (fun () ->
Body.Writer.close resp_body))))
in
let error_handler ?request:_ _error _start_response = assert false in
let t = create ~error_handler request_handler in
Expand Down Expand Up @@ -1748,9 +1764,11 @@ let test_race_condition_writer_issues_yield_after_reader_eof () =
~on_eof:(fun () ->
let resp_body = Reqd.respond_with_streaming reqd response in
Body.Writer.write_string resp_body (String.make 10 'a');
Body.Writer.flush resp_body (fun () ->
continue_response := (fun () ->
Body.Writer.close resp_body))))
Body.Writer.flush resp_body (function
| `Closed -> assert false
| `Written ->
continue_response := (fun () ->
Body.Writer.close resp_body))))
in
let t = create ~error_handler response_handler in
let request =
Expand Down Expand Up @@ -1870,9 +1888,11 @@ let test_errored_chunked_streaming_response_async () =
Body.Reader.close request_body;
let body = Reqd.respond_with_streaming reqd response in
Body.Writer.write_string body "hello";
Body.Writer.flush body (fun () ->
continue := (fun () ->
Reqd.report_exn reqd (Failure "heh")))
Body.Writer.flush body (function
| `Closed -> assert false
| `Written ->
continue := (fun () ->
Reqd.report_exn reqd (Failure "heh")))
in

let t = create request_handler in
Expand Down Expand Up @@ -2060,9 +2080,8 @@ let test_flush_response_before_shutdown () =
write_response t response;
!continue ();
shutdown t;
raises_writer_closed (fun () ->
write_string t "b\r\nhello world\r\n";
connection_is_shutdown t);
write_string t "b\r\nhello world\r\n";
connection_is_shutdown t
;;

let test_schedule_read_with_data_available () =
Expand Down Expand Up @@ -2286,7 +2305,10 @@ let test_body_flush_after_bytes_in_the_wire () =
Response.create ~headers:(Headers.of_list ["content-length", "5"]) `OK
in
let callback_called = ref false in
let callback () = callback_called := true in
let callback = function
| `Closed -> assert false
| `Written -> callback_called := true
in
let request_handler ~flush_headers_immediately reqd =
let response_body =
Reqd.respond_with_streaming
Expand Down

0 comments on commit 947293b

Please sign in to comment.