Skip to content

Commit

Permalink
Merge pull request xapi-project#8 from mseri/CA-260671
Browse files Browse the repository at this point in the history
CA-260671: do not leak file descriptors
  • Loading branch information
gaborigloi authored Aug 2, 2017
2 parents 41cfb49 + 1b635dd commit 8413eea
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 77 deletions.
16 changes: 8 additions & 8 deletions src/helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@

let split str n =
let l = String.length str in
if n>l
then (str,"")
if n>l
then (str,"")
else (String.sub str 0 n, String.sub str n (l-n))

let break pred str =
let l = String.length str in
let rec inner = function
let rec inner = function
| 0 -> (str,"")
| n ->
if pred str.[l-n]
| n ->
if pred str.[l-n]
then split str (l-n)
else inner (n-1)
in inner l

let str_drop_while pred str =
let l = String.length str in
let rec inner = function
let str_drop_while pred str =
let l = String.length str in
let rec inner = function
| 0 -> ""
| n ->
if pred str.[l-n]
Expand Down
6 changes: 3 additions & 3 deletions src/iteratees.ml
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ module Iteratee (IO : Monad) = struct
let rec step st = match st with
| Chunk s ->
let news = str_drop_while pred s in
if news=""
if news=""
then ie_contM step (Chunk "")
else ie_doneM () (Chunk news)
| Eof _ ->
Expand All @@ -185,13 +185,13 @@ module Iteratee (IO : Monad) = struct

let apply f =
let rec step st = match st with
| Chunk s ->
| Chunk s ->
f s;
ie_contM step (Chunk "")
| Eof _ -> ie_doneM () st
in IE_cont (None, step)

let liftI m =
let liftI m =
let step st i =
match i with
| IE_cont (None, k) -> k st
Expand Down
12 changes: 6 additions & 6 deletions src/iteratees.mli
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ module Iteratee :
sig
(** The type t describes the current state of the iteratee.
It's either 'Done', in which case it's got some sort of
value, or it's in the 'Cont' state, which mean's it
hasn't finished processing - in this case it may be in
an error state, or it may be awaiting more input.
value, or it's in the 'Cont' state, which mean's it
hasn't finished processing - in this case it may be in
an error state, or it may be awaiting more input.
*)

type 'a t =
Expand Down Expand Up @@ -96,7 +96,7 @@ module Iteratee :
(** read_int32 - reads an int32 from the stream (bigendian byte order) *)
val read_int32 : int32 t

(** drop_while - iteratee that drops characters from the stream while they
(** drop_while - iteratee that drops characters from the stream while they
satisfy the supplied predicate *)
val drop_while : (char -> bool) -> unit t

Expand Down Expand Up @@ -134,11 +134,11 @@ module Iteratee :
iteratee *)
type 'a enumeratee = 'a t -> 'a t t

(** take - takes exactly n characters from the input stream and applies them to
(** take - takes exactly n characters from the input stream and applies them to
the inner stream *)
val take : int -> 'a t -> 'a t t

(** stream_printer - given a name and an iteratee i, returns an iteratee that
(** stream_printer - given a name and an iteratee i, returns an iteratee that
will print the chunks supplied before handing them off to the iteratee *)
val stream_printer : string -> 'a t -> 'a t t

Expand Down
16 changes: 8 additions & 8 deletions src/lwt_support.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* GNU Lesser General Public License for more details.
*)

open Iteratees
open Iteratees

type 'a t = 'a Iteratee(Lwt).t =
| IE_done of 'a
Expand All @@ -31,22 +31,22 @@ let lwt_fd_enumerator fd =
let blocksize = 1024 in
let str = Bytes.create blocksize in
let get_str n =
if n=0
then (Eof None)
if n=0
then (Eof None)
else (Chunk (String.sub str 0 n))
in
let rec go = function
| IE_cont (None,x) ->
| IE_cont (None,x) ->
Lwt_unix.read fd str 0 blocksize >>= fun n ->
x (get_str n) >>= fun x ->
x (get_str n) >>= fun x ->
Lwt.return (fst x) >>= fun x ->
go x
| x -> Lwt.return x
in go
| x -> Lwt.return x
in go

let lwt_enumerator file iter =
let (>>=) = Lwt.bind in
Lwt_unix.openfile file [Lwt_unix.O_RDONLY] 0o777 >>= fun fd ->
Lwt_unix.openfile file [Lwt_unix.O_RDONLY] 0o777 >>= fun fd ->
lwt_fd_enumerator fd iter

exception Host_not_found of string
Expand Down
4 changes: 2 additions & 2 deletions src/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ module NoOpMonad = struct
end

module StringMonad = struct
type 'a t =
type 'a t =
{ data : 'a;
str : string }
let return a = { data=a; str=""; }
let bind x f =
let newstr = f x.data in
{newstr with str = x.str ^ newstr.str}

let strwr x =
let strwr x =
{ data=(); str=x }
let getstr x = x.str
let getdata x = x.data
Expand Down
52 changes: 26 additions & 26 deletions src/websockets.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module Wsprotocol (IO : Iteratees.Monad) = struct
module I = Iteratees.Iteratee(IO)
open I

type 'a t = 'a I.t
type 'a t = 'a I.t

let sanitize s =
(* ignore control characters: see RFC4648.1 and RFC4648.3
Expand All @@ -28,7 +28,7 @@ module Wsprotocol (IO : Iteratees.Monad) = struct
let result = Buffer.create (String.length s) in
for i = 0 to String.length s - 1 do
if (String.unsafe_get s i >= '\000' && String.unsafe_get s i <= '\032')
|| String.unsafe_get s i = '\127'
|| String.unsafe_get s i = '\127'
then ()
else Buffer.add_char result (String.unsafe_get s i)
done;
Expand All @@ -43,10 +43,10 @@ module Wsprotocol (IO : Iteratees.Monad) = struct

let wsframe s = modify (fun s ->
let l = String.length s in
if l < 126
then
if l < 126
then
Printf.sprintf "%c%c%s" (char_of_int 0x82) (char_of_int l) s
else if l < 65535
else if l < 65535
then
Printf.sprintf "%c%c%s%s" (char_of_int 0x82) (char_of_int 126)
(Helpers.marshal_int16 l) s
Expand All @@ -57,62 +57,62 @@ module Wsprotocol (IO : Iteratees.Monad) = struct
let wsframe_old s = modify (fun s ->
Printf.printf "frame: got %s\n" s; Printf.sprintf "\x00%s\xff" s) s

let rec wsunframe x =
let rec wsunframe x =
let read_sz =
read_int8 >>= fun sz ->
return (sz >= 128, sz land 0x7f)
in
let read_size sz =
if sz < 126
let read_size sz =
if sz < 126
then return sz
else if sz = 126 then
read_int16
else (* sz = 127 *)
read_int32 >>= fun x -> return (Int32.to_int x)
in
in
let read_mask has_mask =
if has_mask
if has_mask
then readn 4
else return "\x00\x00\x00\x00"
else return "\x00\x00\x00\x00"
in
let rec inner acc s =
match s with
let rec inner acc s =
match s with
| IE_cont (None, k) ->
begin
read_int8 >>= fun op ->
read_sz >>= fun (has_mask, sz) ->
read_size sz >>= fun size ->
read_mask has_mask >>= fun mask ->
read_mask has_mask >>= fun mask ->
readn size >>= fun str ->
let real_str = Helpers.unmask mask str in
if op land 0x0f = 0x08
if op land 0x0f = 0x08
then (* close frame *)
return s
else
else
if not (op land 0x80 = 0x80)
then begin
inner (acc ^ real_str) s
inner (acc ^ real_str) s
end else begin
liftI (IO.bind (k (Iteratees.Chunk (acc ^ real_str))) (fun (i, _) ->
IO.return (wsunframe i)))
end
end
end
| _ -> return s
in inner "" x

let rec wsunframe_old s =
match s with
match s with
| IE_cont (None, k) ->
begin
begin
heads "\x00" >>= fun _ ->
break ((=) '\xff') >>= fun str ->
drop 1 >>= fun () ->
break ((=) '\xff') >>= fun str ->
drop 1 >>= fun () ->
liftI (IO.bind (k (Iteratees.Chunk str)) (fun (i,_) ->
IO.return (wsunframe_old i)))
end
| _ -> return s

end
end

module TestWsIteratee = Wsprotocol(Test.StringMonad)

Expand All @@ -124,9 +124,9 @@ let test5 = test1 ^ "\x88\x00"

let testold1 = "\x00Hello\xff\x00There\xff"

let runtest () =
let runtest () =
let open TestWsIteratee in
let open I in
let open I in

let it = wsunframe (writer Test.StringMonad.strwr "foo") in
let itold = wsunframe_old (writer Test.StringMonad.strwr "bar") in
Expand All @@ -135,7 +135,7 @@ let runtest () =
let (>>=) x f = Test.StringMonad.bind x f in
let (=<<) f x = Test.StringMonad.bind x f in

let dump x =
let dump x =
let str = Test.StringMonad.getstr x in
let data = Test.StringMonad.getdata x in
Printf.printf "str='%s' state=%s\n" str (state data)
Expand Down
2 changes: 1 addition & 1 deletion src/websockets.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

module Wsprotocol :
functor (IO : Iteratees.Monad) ->
sig
sig
type 'a t = 'a Iteratees.Iteratee(IO).t

(** Exposing the writer from the IO Iteratee *)
Expand Down
Loading

0 comments on commit 8413eea

Please # to comment.