Skip to content

Commit

Permalink
Get the crate to compile and the tests to pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 15, 2024
1 parent c326289 commit 3df9ece
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 0 deletions.
1 change: 1 addition & 0 deletions gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ module Func = struct
| "at::tensoroptions" -> Some TensorOptions
| "at::intarrayref" -> Some (if is_nullable then IntListOption else IntList)
| "at::arrayref<double>" -> Some DoubleList
| "const c10::list<::std::optional<at::tensor>> &"
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
Expand Down
177 changes: 177 additions & 0 deletions src/wrappers/tensor_fallible_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3450,6 +3450,68 @@ impl Tensor {
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_index_put_impl<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__index_put_impl(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 },
if unsafe_ { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_index_put_impl_<T: Borrow<Tensor>>(
&mut self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__index_put_impl_(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 },
if unsafe_ { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_index_put_impl_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__index_put_impl_out(
c_tensors.as_mut_ptr(),
out.c_tensor,
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 },
if unsafe_ { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_indices(&self) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__indices(c_tensors.as_mut_ptr(), self.c_tensor));
Expand Down Expand Up @@ -7852,6 +7914,38 @@ impl Tensor {
Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
}

pub fn f_internal_unsafe_index<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__unsafe_index(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_unsafe_index_put<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__unsafe_index_put(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_internal_unsafe_view(&self, size: impl IntList) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg__unsafe_view(
Expand Down Expand Up @@ -19334,6 +19428,17 @@ impl Tensor {
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_index(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_add(
&self,
dim: i64,
Expand Down Expand Up @@ -19546,6 +19651,62 @@ impl Tensor {
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_put<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_index_put(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_put_<T: Borrow<Tensor>>(
&mut self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_index_put_(
c_tensors.as_mut_ptr(),
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_put_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_index_put_out(
c_tensors.as_mut_ptr(),
out.c_tensor,
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32,
values.c_tensor,
if accumulate { 1 } else { 0 }
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_reduce(
&self,
dim: i64,
Expand Down Expand Up @@ -19660,6 +19821,22 @@ impl Tensor {
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_index_tensor_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_index_tensor_out(
c_tensors.as_mut_ptr(),
out.c_tensor,
self.c_tensor,
ptr_list_opt(indices).as_ptr(),
indices.len() as i32
));
Ok(Tensor { c_tensor: c_tensors[0] })
}

pub fn f_indices(&self) -> Result<Tensor, TchError> {
let mut c_tensors = [std::ptr::null_mut(); 1];
unsafe_torch_err!(atg_indices(c_tensors.as_mut_ptr(), self.c_tensor));
Expand Down
84 changes: 84 additions & 0 deletions src/wrappers/tensor_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2183,6 +2183,37 @@ impl Tensor {
self.f_internal_histogramdd_from_bin_tensors_out(out, bins, weight, density).unwrap()
}

pub fn internal_index_put_impl<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Tensor {
self.f_internal_index_put_impl(indices, values, accumulate, unsafe_).unwrap()
}

pub fn internal_index_put_impl_<T: Borrow<Tensor>>(
&mut self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Tensor {
self.f_internal_index_put_impl_(indices, values, accumulate, unsafe_).unwrap()
}

pub fn internal_index_put_impl_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
unsafe_: bool,
) -> Tensor {
self.f_internal_index_put_impl_out(out, indices, values, accumulate, unsafe_).unwrap()
}

pub fn internal_indices(&self) -> Tensor {
self.f_internal_indices().unwrap()
}
Expand Down Expand Up @@ -4538,6 +4569,19 @@ impl Tensor {
Tensor::f_internal_unpack_dual(dual, level).unwrap()
}

pub fn internal_unsafe_index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Tensor {
self.f_internal_unsafe_index(indices).unwrap()
}

pub fn internal_unsafe_index_put<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Tensor {
self.f_internal_unsafe_index_put(indices, values, accumulate).unwrap()
}

pub fn internal_unsafe_view(&self, size: impl IntList) -> Tensor {
self.f_internal_unsafe_view(size).unwrap()
}
Expand Down Expand Up @@ -10094,6 +10138,10 @@ impl Tensor {
self.f_imag().unwrap()
}

pub fn index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Tensor {
self.f_index(indices).unwrap()
}

pub fn index_add(&self, dim: i64, index: &Tensor, source: &Tensor) -> Tensor {
self.f_index_add(dim, index, source).unwrap()
}
Expand Down Expand Up @@ -10160,6 +10208,34 @@ impl Tensor {
self.f_index_fill_int_tensor_out(out, dim, index, value).unwrap()
}

pub fn index_put<T: Borrow<Tensor>>(
&self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Tensor {
self.f_index_put(indices, values, accumulate).unwrap()
}

pub fn index_put_<T: Borrow<Tensor>>(
&mut self,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Tensor {
self.f_index_put_(indices, values, accumulate).unwrap()
}

pub fn index_put_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
values: &Tensor,
accumulate: bool,
) -> Tensor {
self.f_index_put_out(out, indices, values, accumulate).unwrap()
}

pub fn index_reduce(
&self,
dim: i64,
Expand Down Expand Up @@ -10211,6 +10287,14 @@ impl Tensor {
self.f_index_select_out(out, dim, index).unwrap()
}

pub fn index_tensor_out<T: Borrow<Tensor>>(
&self,
out: &Tensor,
indices: &[Option<T>],
) -> Tensor {
self.f_index_tensor_out(out, indices).unwrap()
}

pub fn indices(&self) -> Tensor {
self.f_indices().unwrap()
}
Expand Down
Loading

0 comments on commit 3df9ece

Please sign in to comment.