diff --git a/lib/postgresql.ml b/lib/postgresql.ml index 8a21818..e4bf45f 100644 --- a/lib/postgresql.ml +++ b/lib/postgresql.ml @@ -338,7 +338,6 @@ module Stub = struct type connection type result - external conn_isnull : connection -> bool = "PQconn_isnull" [@@noalloc] external connect : string -> bool -> connection = "PQconnectdb_stub" external finish : connection -> unit = "PQfinish_stub" external reset : connection -> unit = "PQreset_stub" @@ -797,497 +796,631 @@ external conndefaults : unit -> conninfo_option array = "PQconndefaults_stub" exception Finally of exn * exn -let protectx ~f x ~(finally : 'a -> unit) = +let protectx ~f ~(finally : unit -> unit) = let res = - try f x + try f () with exn -> - (try finally x with final_exn -> raise (Finally (exn, final_exn))); + (try finally () with final_exn -> raise (Finally (exn, final_exn))); raise exn in - finally x; + finally (); res -class connection ?host ?hostaddr ?port ?dbname ?user ?password ?options ?tty - ?requiressl ?conninfo ?(startonly = false) = - let conn_info = - match conninfo with - | Some conn_info -> conn_info - | None -> - let b = Buffer.create 512 in - let field name = function - | None -> () - | Some x -> - Printf.bprintf b "%s='" name; - for i = 0 to String.length x - 1 do - if x.[i] = '\'' then Buffer.add_string b "\\'" - else Buffer.add_char b x.[i] - done; - Buffer.add_string b "' " - in - field "host" host; - field "hostaddr" hostaddr; - field "port" port; - field "dbname" dbname; - field "user" user; - field "password" password; - field "options" options; - field "tty" tty; - field "requiressl" requiressl; - Buffer.contents b - in +module type Mutex = sig + type t + + val create : unit -> t + val lock : t -> unit + val unlock : t -> unit +end + +class type connection_class = object + method finish : unit + method try_reset : unit + method reset : unit + method notifies : Notification.t option + method set_notice_processor : (string -> unit) -> unit + method set_notice_processing : [ `Stderr | `Quiet ] -> unit + method db : string + method user : string + method pass : string + method host : string + method port : string + method tty : string + method options : string + method status : connection_status + method error_message : string + method backend_pid : int + method server_version : int * int * int + method empty_result : result_status -> result + + method exec : + ?expect:result_status list -> + ?param_types:oid array -> + ?params:string array -> + ?binary_params:bool array -> + ?binary_result:bool -> + string -> + result + + method prepare : ?param_types:oid array -> string -> string -> result + + method exec_prepared : + ?expect:result_status list -> + ?params:string array -> + ?binary_params:bool array -> + string -> + result + + method describe_prepared : string -> result + + method send_query : + ?param_types:oid array -> + ?params:string array -> + ?binary_params:bool array -> + ?binary_result:bool -> + string -> + unit + + method send_prepare : ?param_types:oid array -> string -> string -> unit - fun () -> - let my_conn = Stub.connect conn_info startonly in - let () = - if Stub.connection_status my_conn = Bad then ( - let s = Stub.error_message my_conn in - Stub.finish my_conn; - raise (Error (Connection_failure s))) - else Gc.finalise Stub.finish my_conn - in - let conn_mtx = Mutex.create () in - let conn_cnd = Condition.create () in - let conn_state = ref `Free in - let check_null () = - if Stub.conn_isnull my_conn then - failwith "Postgresql.check_null: connection already finished" - in - let wrap_mtx f = - Mutex.lock conn_mtx; - protectx conn_mtx - ~f:(fun _ -> - check_null (); - (* Check now to avoid blocking *) - f ()) - ~finally:Mutex.unlock - in - let wrap_conn ?(state = `Used) f = - wrap_mtx (fun () -> - while !conn_state <> `Free do - Condition.wait conn_cnd conn_mtx - done; - conn_state := state); - protectx conn_state - ~f:(fun _ -> - check_null (); - (* Check again in case the world has changed *) - f my_conn) - ~finally:(fun _ -> - Mutex.lock conn_mtx; - conn_state := `Free; - Condition.signal conn_cnd; - Mutex.unlock conn_mtx) - in - let signal_error conn = - raise (Error (Connection_failure (Stub.error_message conn))) - in - let request_cancel () = - wrap_mtx (fun _ -> - match !conn_state with - | `Finishing | `Free -> () - | `Used -> ( - match Stub.request_cancel my_conn with - | None -> () - | Some err -> raise (Error (Cancel_failure err)))) - in - let get_str_pos_len ~loc ?pos ?len str = - let str_len = String.length str in - match (pos, len) with - | None, None -> (0, str_len) - | Some pos, _ when pos < 0 -> - invalid_arg (sprintf "Postgresql.%s: pos < 0" loc) - | _, Some len when len < 0 -> - invalid_arg (sprintf "Postgresql.%s: len < 0" loc) - | Some pos, None when pos > str_len -> - invalid_arg (sprintf "Postgresql.%s: pos > length(str)" loc) - | Some pos, None -> (pos, str_len - pos) - | None, Some len when len > str_len -> - invalid_arg (sprintf "Postgresql.%s: len > length(str)" loc) - | None, Some len -> (0, len) - | Some pos, Some len when pos + len > str_len -> - invalid_arg (sprintf "Postgresql.%s: pos + len > length(str)" loc) - | Some pos, Some len -> (pos, len) - in - - object (self (* Main routines *)) - method finish = wrap_conn ~state:`Finishing Stub.finish - - method try_reset = - wrap_conn (fun conn -> - if Stub.connection_status conn = Bad then ( - Stub.reset conn; - if Stub.connection_status conn <> Ok then signal_error conn)) - - method reset = wrap_conn Stub.reset - - (* Asynchronous Notification *) - - method notifies = wrap_conn Stub.notifies - - (* Control Functions *) - - method set_notice_processor f = - wrap_conn (fun conn -> Stub.set_notice_processor conn f) - - method set_notice_processing (h : [ `Stderr | `Quiet ]) = - let i = match h with `Stderr -> 0 | `Quiet -> 1 in - wrap_conn (fun conn -> Stub.set_notice_processor_num conn i) - - (* Accessors *) - - method db = wrap_conn Stub.db - method user = wrap_conn Stub.user - method pass = wrap_conn Stub.pass - method host = wrap_conn Stub.host - method port = wrap_conn Stub.port - method tty = wrap_conn Stub.tty - method options = wrap_conn Stub.options - method status = wrap_conn Stub.connection_status - method error_message = wrap_conn Stub.error_message - method backend_pid = wrap_conn Stub.backend_pid - - method server_version = - let version = + method send_query_prepared : + ?params:string array -> + ?binary_params:bool array -> + ?binary_result:bool -> + string -> + unit + + method send_describe_prepared : string -> unit + method send_describe_portal : string -> unit + method set_single_row_mode : unit + method get_result : result option + method put_copy_data : ?pos:int -> ?len:int -> string -> put_copy_result + method put_copy_end : ?error_msg:string -> unit -> put_copy_result + method get_copy_data : ?async:bool -> unit -> get_copy_result + method getline : ?pos:int -> ?len:int -> Bytes.t -> getline_result + method getline_async : ?pos:int -> ?len:int -> Bytes.t -> getline_async_result + method putline : string -> unit + method putnbytes : ?pos:int -> ?len:int -> string -> unit + method endcopy : unit + method copy_out : (string -> unit) -> unit + method copy_out_channel : out_channel -> unit + method copy_in_channel : in_channel -> unit + method connect_poll : polling_status + method reset_start : bool + method reset_poll : polling_status + method set_nonblocking : bool -> unit + method is_nonblocking : bool + method consume_input : unit + method is_busy : bool + method flush : flush_status + method socket : int + method request_cancel : unit + method lo_creat : oid + method lo_import : string -> oid + method lo_export : oid -> string -> unit + method lo_open : oid -> large_object + method lo_write : ?pos:int -> ?len:int -> string -> large_object -> unit + + method lo_write_ba : + ?pos:int -> + ?len:int -> + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t -> + large_object -> + unit + + method lo_read : large_object -> ?pos:int -> ?len:int -> Bytes.t -> int + + method lo_read_ba : + large_object -> + ?pos:int -> + ?len:int -> + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t -> + int + + method lo_seek : ?pos:int -> ?whence:seek_cmd -> large_object -> unit + method lo_tell : large_object -> int + method lo_close : large_object -> unit + method lo_unlink : oid -> unit + method escape_string : ?pos:int -> ?len:int -> string -> string + method escape_bytea : ?pos:int -> ?len:int -> string -> string +end + +module Connection (Mutex : Mutex) = struct + class connection ?host ?hostaddr ?port ?dbname ?user ?password ?options ?tty + ?requiressl ?conninfo ?(startonly = false) = + let conn_info = + match conninfo with + | Some conn_info -> conn_info + | None -> + let b = Buffer.create 512 in + let field name = function + | None -> () + | Some x -> + Printf.bprintf b "%s='" name; + for i = 0 to String.length x - 1 do + if x.[i] = '\'' then Buffer.add_string b "\\'" + else Buffer.add_char b x.[i] + done; + Buffer.add_string b "' " + in + field "host" host; + field "hostaddr" hostaddr; + field "port" port; + field "dbname" dbname; + field "user" user; + field "password" password; + field "options" options; + field "tty" tty; + field "requiressl" requiressl; + Buffer.contents b + in + + fun () -> + let my_conn = Stub.connect conn_info startonly in + let () = + if Stub.connection_status my_conn = Bad then ( + let s = Stub.error_message my_conn in + Stub.finish my_conn; + raise (Error (Connection_failure s))) + else Gc.finalise Stub.finish my_conn + in + let conn_mtx = Mutex.create () in + let cancel_mtx = Mutex.create () in + let finished = ref false in + (* bool becomes true after deallocation *) + let check_null () = + if !finished then + failwith "Postgresql.check_null: connection already finished" + in + let wrap_conn f = + protectx + ~f:(fun _ -> + Mutex.lock conn_mtx; + check_null (); + (* Check again in case the world has changed *) + f my_conn) + ~finally:(fun _ -> Mutex.unlock conn_mtx) + in + let wrap_cancel f = + protectx + ~f:(fun _ -> + Mutex.lock cancel_mtx; + check_null (); + (* Check again in case the world has changed *) + f my_conn) + ~finally:(fun _ -> Mutex.unlock cancel_mtx) + in + let wrap_both f = + protectx + ~f:(fun _ -> + Mutex.lock conn_mtx; + Mutex.lock cancel_mtx; + check_null (); + (* Check again in case the world has changed *) + f my_conn) + ~finally:(fun _ -> + Mutex.unlock cancel_mtx; + Mutex.unlock conn_mtx) + in + let signal_error conn = + raise (Error (Connection_failure (Stub.error_message conn))) + in + let request_cancel () = + wrap_cancel (fun _ -> + match Stub.request_cancel my_conn with + | None -> () + | Some err -> raise (Error (Cancel_failure err))) + in + let get_str_pos_len ~loc ?pos ?len str = + let str_len = String.length str in + match (pos, len) with + | None, None -> (0, str_len) + | Some pos, _ when pos < 0 -> + invalid_arg (sprintf "Postgresql.%s: pos < 0" loc) + | _, Some len when len < 0 -> + invalid_arg (sprintf "Postgresql.%s: len < 0" loc) + | Some pos, None when pos > str_len -> + invalid_arg (sprintf "Postgresql.%s: pos > length(str)" loc) + | Some pos, None -> (pos, str_len - pos) + | None, Some len when len > str_len -> + invalid_arg (sprintf "Postgresql.%s: len > length(str)" loc) + | None, Some len -> (0, len) + | Some pos, Some len when pos + len > str_len -> + invalid_arg (sprintf "Postgresql.%s: pos + len > length(str)" loc) + | Some pos, Some len -> (pos, len) + in + + object (self (* Main routines *)) + method finish = + wrap_both (fun c -> + Stub.finish c; + finished := true) + + method try_reset = wrap_conn (fun conn -> - let version = Stub.server_version conn in - if version <> 0 then version - else - let msg = - if Stub.connection_status conn = Bad then - "server_version failed because the connection was bad" - else "server_version failed for an unknown reason" + if Stub.connection_status conn = Bad then ( + Stub.reset conn; + if Stub.connection_status conn <> Ok then signal_error conn)) + + method reset = wrap_conn Stub.reset + + (* Asynchronous Notification *) + + method notifies = wrap_conn Stub.notifies + + (* Control Functions *) + + method set_notice_processor f = + wrap_conn (fun conn -> Stub.set_notice_processor conn f) + + method set_notice_processing (h : [ `Stderr | `Quiet ]) = + let i = match h with `Stderr -> 0 | `Quiet -> 1 in + wrap_conn (fun conn -> Stub.set_notice_processor_num conn i) + + (* Accessors *) + + method db = wrap_conn Stub.db + method user = wrap_conn Stub.user + method pass = wrap_conn Stub.pass + method host = wrap_conn Stub.host + method port = wrap_conn Stub.port + method tty = wrap_conn Stub.tty + method options = wrap_conn Stub.options + method status = wrap_conn Stub.connection_status + method error_message = wrap_conn Stub.error_message + method backend_pid = wrap_conn Stub.backend_pid + + method server_version = + let version = + wrap_conn (fun conn -> + let version = Stub.server_version conn in + if version <> 0 then version + else + let msg = + if Stub.connection_status conn = Bad then + "server_version failed because the connection was bad" + else "server_version failed for an unknown reason" + in + raise (Error (Connection_failure msg))) + in + let major = version / (100 * 100) in + let minor = version / 100 mod 100 in + let revision = version mod 100 in + (major, minor, revision) + + (* Commands and Queries *) + + method empty_result status = + new result (wrap_conn (fun conn -> Stub.make_empty_res conn status)) + + method exec ?(expect = []) ?(param_types = [||]) ?(params = [||]) + ?(binary_params = [||]) ?(binary_result = false) query = + let r = + wrap_conn (fun conn -> + let r = + Stub.exec_params conn query param_types params binary_params + binary_result in - raise (Error (Connection_failure msg))) - in - let major = version / (100 * 100) in - let minor = version / 100 mod 100 in - let revision = version mod 100 in - (major, minor, revision) - - (* Commands and Queries *) + if Stub.result_isnull r then signal_error conn else r) + in + let res = new result r in + let stat = res#status in + if (not (expect = [])) && not (List.mem stat expect) then + raise (Error (Unexpected_status (stat, res#error, expect))) + else res + + method prepare ?(param_types = [||]) stm_name query = + new result + (wrap_conn (fun conn -> + let r = Stub.prepare conn stm_name query param_types in + if Stub.result_isnull r then signal_error conn else r)) + + method exec_prepared ?(expect = []) ?(params = [||]) + ?(binary_params = [||]) stm_name = + let r = + wrap_conn (fun conn -> + let r = + Stub.exec_prepared conn stm_name params binary_params + in + if Stub.result_isnull r then signal_error conn else r) + in + let res = new result r in + let stat = res#status in + if (not (expect = [])) && not (List.mem stat expect) then + raise (Error (Unexpected_status (stat, res#error, expect))) + else res + + method describe_prepared query = + new result + (wrap_conn (fun conn -> + let r = Stub.describe_prepared conn query in + if Stub.result_isnull r then signal_error conn else r)) + + method send_query ?(param_types = [||]) ?(params = [||]) + ?(binary_params = [||]) ?(binary_result = false) query = + wrap_conn (fun conn -> + if + Stub.send_query_params conn query param_types params + binary_params binary_result + <> 1 + then signal_error conn) - method empty_result status = - new result (wrap_conn (fun conn -> Stub.make_empty_res conn status)) + method send_prepare ?(param_types = [||]) stm_name query = + wrap_conn (fun conn -> + if Stub.send_prepare conn stm_name query param_types <> 1 then + signal_error conn) - method exec ?(expect = []) ?(param_types = [||]) ?(params = [||]) - ?(binary_params = [||]) ?(binary_result = false) query = - let r = + method send_query_prepared ?(params = [||]) ?(binary_params = [||]) + ?(binary_result = false) stm_name = wrap_conn (fun conn -> - let r = - Stub.exec_params conn query param_types params binary_params + if + Stub.send_query_prepared conn stm_name params binary_params binary_result - in - if Stub.result_isnull r then signal_error conn else r) - in - let res = new result r in - let stat = res#status in - if (not (expect = [])) && not (List.mem stat expect) then - raise (Error (Unexpected_status (stat, res#error, expect))) - else res - - method prepare ?(param_types = [||]) stm_name query = - new result - (wrap_conn (fun conn -> - let r = Stub.prepare conn stm_name query param_types in - if Stub.result_isnull r then signal_error conn else r)) - - method exec_prepared ?(expect = []) ?(params = [||]) - ?(binary_params = [||]) stm_name = - let r = + <> 1 + then signal_error conn) + + method send_describe_prepared stm_name = + wrap_conn (fun conn -> + if Stub.send_describe_prepared conn stm_name <> 1 then + signal_error conn) + + method send_describe_portal portal_name = + wrap_conn (fun conn -> + if Stub.send_describe_portal conn portal_name <> 1 then + signal_error conn) + + method set_single_row_mode = + wrap_conn (fun conn -> + if Stub.set_single_row_mode conn <> 1 then signal_error conn) + + method get_result = + let res = wrap_conn Stub.get_result in + if Stub.result_isnull res then None else Some (new result res) + + (* Copy operations *) + + (* Low level *) + + method put_copy_data ?(pos = 0) ?len buf = + let buf_len = String.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#put_copy_data"; + wrap_conn (fun conn -> + match Stub.put_copy_data conn buf pos len with + | -1 -> Put_copy_error + | 0 -> Put_copy_not_queued + | 1 -> Put_copy_queued + | _ -> assert false) + + method put_copy_end ?error_msg () = + wrap_conn (fun conn -> + match Stub.put_copy_end conn error_msg with + | -1 -> Put_copy_error + | 0 -> Put_copy_not_queued + | 1 -> Put_copy_queued + | _ -> assert false) + + method get_copy_data ?(async = false) () = + wrap_conn (fun conn -> + Stub.get_copy_data conn (if async then 1 else 0)) + + method getline ?(pos = 0) ?len buf = + let buf_len = Bytes.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#getline"; + wrap_conn (fun conn -> + match Stub.getline conn buf pos len with + | -1 -> EOF + | 0 -> LineRead + | 1 -> BufFull + | _ -> assert false) + + method getline_async ?(pos = 0) ?len buf = + let buf_len = Bytes.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#getline_async"; + wrap_conn (fun conn -> + match Stub.getline_async conn buf pos len with + | -1 -> + if Stub.endcopy conn <> 0 then signal_error conn + else EndOfData + | 0 -> NoData + | n when n > 0 -> + if Bytes.get buf (pos + n - 1) = '\n' then DataRead n + else PartDataRead n + | _ -> assert false) + + method putline buf = + wrap_conn (fun conn -> + if Stub.putline conn buf <> 0 && not (Stub.is_nonblocking conn) + then signal_error conn) + + method putnbytes ?(pos = 0) ?len buf = + let buf_len = String.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#putnbytes"; + wrap_conn (fun conn -> + if + Stub.putnbytes conn buf pos len <> 0 + && not (Stub.is_nonblocking conn) + then signal_error conn) + + method endcopy = wrap_conn (fun conn -> - let r = Stub.exec_prepared conn stm_name params binary_params in - if Stub.result_isnull r then signal_error conn else r) - in - let res = new result r in - let stat = res#status in - if (not (expect = [])) && not (List.mem stat expect) then - raise (Error (Unexpected_status (stat, res#error, expect))) - else res - - method describe_prepared query = - new result - (wrap_conn (fun conn -> - let r = Stub.describe_prepared conn query in - if Stub.result_isnull r then signal_error conn else r)) - - method send_query ?(param_types = [||]) ?(params = [||]) - ?(binary_params = [||]) ?(binary_result = false) query = - wrap_conn (fun conn -> - if - Stub.send_query_params conn query param_types params - binary_params binary_result - <> 1 - then signal_error conn) - - method send_prepare ?(param_types = [||]) stm_name query = - wrap_conn (fun conn -> - if Stub.send_prepare conn stm_name query param_types <> 1 then - signal_error conn) - - method send_query_prepared ?(params = [||]) ?(binary_params = [||]) - ?(binary_result = false) stm_name = - wrap_conn (fun conn -> - if - Stub.send_query_prepared conn stm_name params binary_params - binary_result - <> 1 - then signal_error conn) - - method send_describe_prepared stm_name = - wrap_conn (fun conn -> - if Stub.send_describe_prepared conn stm_name <> 1 then - signal_error conn) - - method send_describe_portal portal_name = - wrap_conn (fun conn -> - if Stub.send_describe_portal conn portal_name <> 1 then - signal_error conn) - - method set_single_row_mode = - wrap_conn (fun conn -> - if Stub.set_single_row_mode conn <> 1 then signal_error conn) - - method get_result = - let res = wrap_conn Stub.get_result in - if Stub.result_isnull res then None else Some (new result res) - - (* Copy operations *) - - (* Low level *) - - method put_copy_data ?(pos = 0) ?len buf = - let buf_len = String.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#put_copy_data"; - wrap_conn (fun conn -> - match Stub.put_copy_data conn buf pos len with - | -1 -> Put_copy_error - | 0 -> Put_copy_not_queued - | 1 -> Put_copy_queued - | _ -> assert false) - - method put_copy_end ?error_msg () = - wrap_conn (fun conn -> - match Stub.put_copy_end conn error_msg with - | -1 -> Put_copy_error - | 0 -> Put_copy_not_queued - | 1 -> Put_copy_queued - | _ -> assert false) - - method get_copy_data ?(async = false) () = - wrap_conn (fun conn -> - Stub.get_copy_data conn (if async then 1 else 0)) - - method getline ?(pos = 0) ?len buf = - let buf_len = Bytes.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#getline"; - wrap_conn (fun conn -> - match Stub.getline conn buf pos len with - | -1 -> EOF - | 0 -> LineRead - | 1 -> BufFull - | _ -> assert false) - - method getline_async ?(pos = 0) ?len buf = - let buf_len = Bytes.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#getline_async"; - wrap_conn (fun conn -> - match Stub.getline_async conn buf pos len with - | -1 -> - if Stub.endcopy conn <> 0 then signal_error conn else EndOfData - | 0 -> NoData - | n when n > 0 -> - if Bytes.get buf (pos + n - 1) = '\n' then DataRead n - else PartDataRead n - | _ -> assert false) - - method putline buf = - wrap_conn (fun conn -> - if Stub.putline conn buf <> 0 && not (Stub.is_nonblocking conn) - then signal_error conn) - - method putnbytes ?(pos = 0) ?len buf = - let buf_len = String.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#putnbytes"; - wrap_conn (fun conn -> - if - Stub.putnbytes conn buf pos len <> 0 - && not (Stub.is_nonblocking conn) - then signal_error conn) - - method endcopy = - wrap_conn (fun conn -> - if Stub.endcopy conn <> 0 && not (Stub.is_nonblocking conn) then - signal_error conn) - - (* High level *) - - method copy_out f = - let buf = Buffer.create 1024 in - let len = 512 in - let bts = Bytes.create len in - wrap_conn (fun conn -> - let rec loop () = - let r = Stub.getline conn bts 0 len in - if r = 1 then ( - (* Buffer full *) - Buffer.add_subbytes buf bts 0 len; - loop ()) - else if r = 0 then ( - (* Line read *) - let zero = Bytes.index bts '\000' in - Buffer.add_subbytes buf bts 0 zero; - match Buffer.contents buf with - | "\\." -> () - | line -> - Buffer.clear buf; - f line; - loop ()) - else if r = -1 then raise End_of_file - else assert false (* impossible *) - in - loop ()); - self#endcopy - - method copy_out_channel oc = - self#copy_out (fun s -> output_string oc (s ^ "\n")) - - method copy_in_channel ic = - try - while true do - self#putline (input_line ic ^ "\n") - done - with End_of_file -> - self#putline "\\.\n"; + if Stub.endcopy conn <> 0 && not (Stub.is_nonblocking conn) then + signal_error conn) + + (* High level *) + + method copy_out f = + let buf = Buffer.create 1024 in + let len = 512 in + let bts = Bytes.create len in + wrap_conn (fun conn -> + let rec loop () = + let r = Stub.getline conn bts 0 len in + if r = 1 then ( + (* Buffer full *) + Buffer.add_subbytes buf bts 0 len; + loop ()) + else if r = 0 then ( + (* Line read *) + let zero = Bytes.index bts '\000' in + Buffer.add_subbytes buf bts 0 zero; + match Buffer.contents buf with + | "\\." -> () + | line -> + Buffer.clear buf; + f line; + loop ()) + else if r = -1 then raise End_of_file + else assert false (* impossible *) + in + loop ()); self#endcopy - (* Asynchronous operations and non blocking mode *) - - method connect_poll = wrap_conn Stub.connect_poll - method reset_start = wrap_conn Stub.reset_start - method reset_poll = wrap_conn Stub.reset_poll - - method set_nonblocking b = - wrap_conn (fun conn -> - if Stub.set_nonblocking conn b <> 0 then signal_error conn) - - method is_nonblocking = wrap_conn Stub.is_nonblocking - - method consume_input = - wrap_conn (fun conn -> - if Stub.consume_input conn <> 1 then signal_error conn) - - method is_busy = wrap_conn Stub.is_busy - - method flush = - wrap_conn (fun conn -> - match Stub.flush conn with - | 0 -> Successful - | 1 -> Data_left_to_send - | _ -> signal_error conn) - - method socket = - wrap_conn (fun conn -> - let s = Stub.socket conn in - if s = -1 then signal_error conn else s) - - method request_cancel = request_cancel () - - (* Large objects *) - - method lo_creat = - wrap_conn (fun conn -> - let lo = Stub.lo_creat conn in - if lo <= 0 then signal_error conn; - lo) - - method lo_import filename = - wrap_conn (fun conn -> - let oid = Stub.lo_import conn filename in - if oid = 0 then signal_error conn; - oid) - - method lo_export oid filename = - wrap_conn (fun conn -> - if Stub.lo_export conn oid filename <= 0 then signal_error conn) - - method lo_open oid = - wrap_conn (fun conn -> - let lo = Stub.lo_open conn oid in - if lo = -1 then signal_error conn; - lo) - - method lo_write ?(pos = 0) ?len buf lo = - let buf_len = String.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > String.length buf then - invalid_arg "Postgresql.connection#lo_write"; - wrap_conn (fun conn -> - let w = Stub.lo_write conn lo buf pos len in - if w < len then signal_error conn) - - method lo_write_ba ?(pos = 0) ?len buf lo = - let buf_len = Bigarray.Array1.dim buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#lo_write_ba"; - wrap_conn (fun conn -> - let w = Stub.lo_write_ba conn lo buf pos len in - if w < len then signal_error conn) - - method lo_read lo ?(pos = 0) ?len buf = - let buf_len = Bytes.length buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#lo_read"; - wrap_conn (fun conn -> - let read = Stub.lo_read conn lo buf pos len in - if read = -1 then signal_error conn; - read) - - method lo_read_ba lo ?(pos = 0) ?len buf = - let buf_len = Bigarray.Array1.dim buf in - let len = match len with Some len -> len | None -> buf_len - pos in - if len < 0 || pos < 0 || pos + len > buf_len then - invalid_arg "Postgresql.connection#lo_read_ba"; - wrap_conn (fun conn -> - let read = Stub.lo_read_ba conn lo buf pos len in - if read = -1 then signal_error conn; - read) - - method lo_seek ?(pos = 0) ?(whence = SEEK_SET) lo = - wrap_conn (fun conn -> - if Stub.lo_seek conn lo pos whence < 0 then signal_error conn) - - method lo_tell lo = - wrap_conn (fun conn -> - let pos = Stub.lo_tell conn lo in - if pos = -1 then signal_error conn; - pos) - - method lo_close oid = - wrap_conn (fun conn -> - if Stub.lo_close conn oid = -1 then signal_error conn) - - method lo_unlink oid = - wrap_conn (fun conn -> - let oid = Stub.lo_unlink conn oid in - if oid = -1 then signal_error conn) - - (* Escaping *) - - method escape_string ?pos ?len str = - let pos, len = get_str_pos_len ~loc:"escape_string" ?pos ?len str in - wrap_conn (fun conn -> Stub.escape_string_conn conn str ~pos ~len) - - method escape_bytea ?pos ?len str = - let pos, len = get_str_pos_len ~loc:"escape_bytea" ?pos ?len str in - wrap_conn (fun conn -> Stub.escape_bytea_conn conn str ~pos ~len) - end + method copy_out_channel oc = + self#copy_out (fun s -> output_string oc (s ^ "\n")) + + method copy_in_channel ic = + try + while true do + self#putline (input_line ic ^ "\n") + done + with End_of_file -> + self#putline "\\.\n"; + self#endcopy + + (* Asynchronous operations and non blocking mode *) + + method connect_poll = wrap_conn Stub.connect_poll + method reset_start = wrap_conn Stub.reset_start + method reset_poll = wrap_conn Stub.reset_poll + + method set_nonblocking b = + wrap_conn (fun conn -> + if Stub.set_nonblocking conn b <> 0 then signal_error conn) + + method is_nonblocking = wrap_conn Stub.is_nonblocking + + method consume_input = + wrap_conn (fun conn -> + if Stub.consume_input conn <> 1 then signal_error conn) + + method is_busy = wrap_conn Stub.is_busy + + method flush = + wrap_conn (fun conn -> + match Stub.flush conn with + | 0 -> Successful + | 1 -> Data_left_to_send + | _ -> signal_error conn) + + method socket = + wrap_conn (fun conn -> + let s = Stub.socket conn in + if s = -1 then signal_error conn else s) + + method request_cancel = request_cancel () + + (* Large objects *) + + method lo_creat = + wrap_conn (fun conn -> + let lo = Stub.lo_creat conn in + if lo <= 0 then signal_error conn; + lo) + + method lo_import filename = + wrap_conn (fun conn -> + let oid = Stub.lo_import conn filename in + if oid = 0 then signal_error conn; + oid) + + method lo_export oid filename = + wrap_conn (fun conn -> + if Stub.lo_export conn oid filename <= 0 then signal_error conn) + + method lo_open oid = + wrap_conn (fun conn -> + let lo = Stub.lo_open conn oid in + if lo = -1 then signal_error conn; + lo) + + method lo_write ?(pos = 0) ?len buf lo = + let buf_len = String.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > String.length buf then + invalid_arg "Postgresql.connection#lo_write"; + wrap_conn (fun conn -> + let w = Stub.lo_write conn lo buf pos len in + if w < len then signal_error conn) + + method lo_write_ba ?(pos = 0) ?len buf lo = + let buf_len = Bigarray.Array1.dim buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#lo_write_ba"; + wrap_conn (fun conn -> + let w = Stub.lo_write_ba conn lo buf pos len in + if w < len then signal_error conn) + + method lo_read lo ?(pos = 0) ?len buf = + let buf_len = Bytes.length buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#lo_read"; + wrap_conn (fun conn -> + let read = Stub.lo_read conn lo buf pos len in + if read = -1 then signal_error conn; + read) + + method lo_read_ba lo ?(pos = 0) ?len buf = + let buf_len = Bigarray.Array1.dim buf in + let len = match len with Some len -> len | None -> buf_len - pos in + if len < 0 || pos < 0 || pos + len > buf_len then + invalid_arg "Postgresql.connection#lo_read_ba"; + wrap_conn (fun conn -> + let read = Stub.lo_read_ba conn lo buf pos len in + if read = -1 then signal_error conn; + read) + + method lo_seek ?(pos = 0) ?(whence = SEEK_SET) lo = + wrap_conn (fun conn -> + if Stub.lo_seek conn lo pos whence < 0 then signal_error conn) + + method lo_tell lo = + wrap_conn (fun conn -> + let pos = Stub.lo_tell conn lo in + if pos = -1 then signal_error conn; + pos) + + method lo_close oid = + wrap_conn (fun conn -> + if Stub.lo_close conn oid = -1 then signal_error conn) + + method lo_unlink oid = + wrap_conn (fun conn -> + let oid = Stub.lo_unlink conn oid in + if oid = -1 then signal_error conn) + + (* Escaping *) + + method escape_string ?pos ?len str = + let pos, len = get_str_pos_len ~loc:"escape_string" ?pos ?len str in + wrap_conn (fun conn -> Stub.escape_string_conn conn str ~pos ~len) + + method escape_bytea ?pos ?len str = + let pos, len = get_str_pos_len ~loc:"escape_bytea" ?pos ?len str in + wrap_conn (fun conn -> Stub.escape_bytea_conn conn str ~pos ~len) + end +end + +module DefaultConnection = Connection (Stdlib.Mutex) + +class connection = DefaultConnection.connection diff --git a/lib/postgresql.mli b/lib/postgresql.mli index cef6ab0..0fc475a 100644 --- a/lib/postgresql.mli +++ b/lib/postgresql.mli @@ -487,31 +487,8 @@ val conndefaults : unit -> conninfo_option array usable. @raise Error if there is a connection failure. *) -class connection : - ?host:string -> - (* Default: none *) - ?hostaddr:string -> - (* Default: none *) - ?port:string -> - (* Default: none *) - ?dbname:string -> - (* Default: none *) - ?user:string -> - (* Default: none *) - ?password:string -> - (* Default: none *) - ?options:string -> - (* Default: none *) - ?tty:string -> - (* Default: none *) - ?requiressl:string -> - (* Default: none *) - ?conninfo:string -> - (* Default: none *) - ?startonly:bool -> - (* Default: false *) - unit -> -object + +class type connection_class = object (* Main routines *) method finish : unit @@ -1079,3 +1056,92 @@ object @param pos default = 0 @param len default = String.length str - pos *) end + +class connection : + ?host:string -> + (* Default: none *) + ?hostaddr:string -> + (* Default: none *) + ?port:string -> + (* Default: none *) + ?dbname:string -> + (* Default: none *) + ?user:string -> + (* Default: none *) + ?password:string -> + (* Default: none *) + ?options:string -> + (* Default: none *) + ?tty:string -> + (* Default: none *) + ?requiressl:string -> + (* Default: none *) + ?conninfo:string -> + (* Default: none *) + ?startonly:bool -> + (* Default: false *) + unit -> + connection_class + +(** Type of a mutex module *) +module type Mutex = sig + type t + + val create : unit -> t + val lock : t -> unit + val unlock : t -> unit +end + +(** Connection parametrized by the type of mutex used. + + If you are using your own wrapper around connection, you could for instance + use this kind of code: + {[ + module Check = struct + type t = bool Atomic.t + + let create () = Atomic.make false + + let lock m = + if not (Atomic.compare_and_set m false true) then + failwith "Concurrent use of a Postgres connection (at lock)" + + let unlock m = + if not (Atomic.compare_and_set m true false) then + failwith + "Concurrent use of a Postgres connection (at unlock, impossible \ + ?)" + end + + module Postgresql = struct + include Postgresql + include Connection (Check) + end + ]} *) +module Connection (_ : Mutex) : sig + class connection : + ?host:string -> + (* Default: none *) + ?hostaddr:string -> + (* Default: none *) + ?port:string -> + (* Default: none *) + ?dbname:string -> + (* Default: none *) + ?user:string -> + (* Default: none *) + ?password:string -> + (* Default: none *) + ?options:string -> + (* Default: none *) + ?tty:string -> + (* Default: none *) + ?requiressl:string -> + (* Default: none *) + ?conninfo:string -> + (* Default: none *) + ?startonly:bool -> + (* Default: false *) + unit -> + connection_class +end diff --git a/lib/postgresql_stubs.c b/lib/postgresql_stubs.c index ebbecd3..df581b9 100644 --- a/lib/postgresql_stubs.c +++ b/lib/postgresql_stubs.c @@ -266,10 +266,6 @@ static inline void np_decr_refcount(np_callback *c) { #define get_cancel_obj(v) ((PGcancel *)Field(v, 2)) #define set_cancel_obj(v, cancel) (Field(v, 2) = (value)cancel) -CAMLprim value PQconn_isnull(value v_conn) { - return Val_bool((get_conn(v_conn)) ? 0 : 1); -} - static inline void free_conn(value v_conn) { PGconn *conn = get_conn(v_conn); if (conn) {