@@ -11,7 +11,7 @@ namespace ffi = xla::ffi;
11
11
XLA_FFI_REGISTER_ENUM_ATTR_DECODING (UVtMode);
12
12
13
13
template <ffi::DataType dtype>
14
- static ffi::Error SvdOnlyVtImpl (
14
+ static ffi::Error SvdOnlyUVtImpl (
15
15
ffi::Buffer<dtype> x,
16
16
ffi::ResultBuffer<dtype> x_out,
17
17
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -189,7 +189,7 @@ static ffi::Error SvdOnlyVtImpl(
189
189
}
190
190
191
191
template <ffi::DataType dtype>
192
- static ffi::Error SvdOnlyVtQRImpl (
192
+ static ffi::Error SvdOnlyUVtQRImpl (
193
193
ffi::Buffer<dtype> x,
194
194
ffi::ResultBuffer<dtype> x_out,
195
195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -382,7 +382,7 @@ static ffi::Error SvdOnlyVtQRImpl(
382
382
383
383
#define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
384
384
XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
385
- fname, SvdOnlyVtImpl <dtype>, \
385
+ fname, SvdOnlyUVtImpl <dtype>, \
386
386
ffi::Ffi::Bind () \
387
387
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
388
388
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -393,7 +393,7 @@ static ffi::Error SvdOnlyVtQRImpl(
393
393
394
394
#define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
395
395
XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
396
- fname, SvdOnlyVtImpl <dtype>, \
396
+ fname, SvdOnlyUVtImpl <dtype>, \
397
397
ffi::Ffi::Bind () \
398
398
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
399
399
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -402,14 +402,14 @@ static ffi::Error SvdOnlyVtQRImpl(
402
402
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
403
403
.Attr<UVtMode>(" mode" ))
404
404
405
- DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32 , ffi::DataType::F32);
406
- DEFINE_REAL_SVD_ONLY_VT (svd_only_vt_f64 , ffi::DataType::F64);
407
- DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c64 , ffi::DataType::C64);
408
- DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c128 , ffi::DataType::C128);
405
+ DEFINE_REAL_SVD_ONLY_VT(svd_only_u_vt_f32 , ffi::DataType::F32);
406
+ DEFINE_REAL_SVD_ONLY_VT (svd_only_u_vt_f64 , ffi::DataType::F64);
407
+ DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_u_vt_c64 , ffi::DataType::C64);
408
+ DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_u_vt_c128 , ffi::DataType::C128);
409
409
410
410
#define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
411
411
XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
412
- fname, SvdOnlyVtQRImpl <dtype>, \
412
+ fname, SvdOnlyUVtQRImpl <dtype>, \
413
413
ffi::Ffi::Bind () \
414
414
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
415
415
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -420,7 +420,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
420
420
421
421
#define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
422
422
XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
423
- fname, SvdOnlyVtQRImpl <dtype>, \
423
+ fname, SvdOnlyUVtQRImpl <dtype>, \
424
424
ffi::Ffi::Bind () \
425
425
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
426
426
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -429,10 +429,10 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
429
429
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
430
430
.Attr<UVtMode>(" mode" ))
431
431
432
- DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32 , ffi::DataType::F32);
433
- DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_vt_qr_f64 , ffi::DataType::F64);
434
- DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c64 , ffi::DataType::C64);
435
- DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c128 , ffi::DataType::C128);
432
+ DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_u_vt_qr_f32 , ffi::DataType::F32);
433
+ DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_u_vt_qr_f64 , ffi::DataType::F64);
434
+ DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_u_vt_qr_c64 , ffi::DataType::C64);
435
+ DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_u_vt_qr_c128 , ffi::DataType::C128);
436
436
437
437
template <typename T>
438
438
static nb::capsule EncapsulateFfiCall (T *fn) {
@@ -441,13 +441,13 @@ static nb::capsule EncapsulateFfiCall(T *fn) {
441
441
return nb::capsule (reinterpret_cast <void *>(fn));
442
442
}
443
443
444
- NB_MODULE (_svd_only_vt , m) {
445
- m.def (" svd_only_vt_f32 " , []() { return EncapsulateFfiCall (svd_only_vt_f32 ); });
446
- m.def (" svd_only_vt_f64 " , []() { return EncapsulateFfiCall (svd_only_vt_f64 ); });
447
- m.def (" svd_only_vt_c64 " , []() { return EncapsulateFfiCall (svd_only_vt_c64 ); });
448
- m.def (" svd_only_vt_c128 " , []() { return EncapsulateFfiCall (svd_only_vt_c128 ); });
449
- m.def (" svd_only_vt_qr_f32 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32 ); });
450
- m.def (" svd_only_vt_qr_f64 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64 ); });
451
- m.def (" svd_only_vt_qr_c64 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64 ); });
452
- m.def (" svd_only_vt_qr_c128 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128 ); });
444
+ NB_MODULE (_svd_only_u_vt , m) {
445
+ m.def (" svd_only_u_vt_f32 " , []() { return EncapsulateFfiCall (svd_only_u_vt_f32 ); });
446
+ m.def (" svd_only_u_vt_f64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_f64 ); });
447
+ m.def (" svd_only_u_vt_c64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_c64 ); });
448
+ m.def (" svd_only_u_vt_c128 " , []() { return EncapsulateFfiCall (svd_only_u_vt_c128 ); });
449
+ m.def (" svd_only_u_vt_qr_f32 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_f32 ); });
450
+ m.def (" svd_only_u_vt_qr_f64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_f64 ); });
451
+ m.def (" svd_only_u_vt_qr_c64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_c64 ); });
452
+ m.def (" svd_only_u_vt_qr_c128 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_c128 ); });
453
453
}
0 commit comments