Skip to content

Commit dc86911

Browse files
committed
Make name of svd extension consistent with recent changes
1 parent 2a26413 commit dc86911

File tree

4 files changed

+47
-47
lines changed

4 files changed

+47
-47
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ENDIF()
3333
# We are now ready to compile the actual extension module
3434
nanobind_add_module(
3535
# Name of the extension
36-
_svd_only_vt
36+
_svd_only_u_vt
3737

3838
# Target the stable ABI for Python 3.12+, which reduces
3939
# the number of binary wheels that must be built. This
@@ -52,9 +52,9 @@ nanobind_add_module(
5252
varipeps/utils/extensions/svd_ffi.cpp
5353
)
5454

55-
target_include_directories(_svd_only_vt PRIVATE "${jaxlib_INCLUDE_DIR}")
55+
target_include_directories(_svd_only_u_vt PRIVATE "${jaxlib_INCLUDE_DIR}")
5656

5757
# target_link_libraries(_svd_only_vt PRIVATE lapack)
5858

5959
# Install directive for scikit-build-core
60-
install(TARGETS _svd_only_vt LIBRARY DESTINATION varipeps/utils/extensions)
60+
install(TARGETS _svd_only_u_vt LIBRARY DESTINATION varipeps/utils/extensions)

varipeps/utils/extensions/svd_ffi.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace ffi = xla::ffi;
1111
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(UVtMode);
1212

1313
template <ffi::DataType dtype>
14-
static ffi::Error SvdOnlyVtImpl(
14+
static ffi::Error SvdOnlyUVtImpl(
1515
ffi::Buffer<dtype> x,
1616
ffi::ResultBuffer<dtype> x_out,
1717
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -189,7 +189,7 @@ static ffi::Error SvdOnlyVtImpl(
189189
}
190190

191191
template <ffi::DataType dtype>
192-
static ffi::Error SvdOnlyVtQRImpl(
192+
static ffi::Error SvdOnlyUVtQRImpl(
193193
ffi::Buffer<dtype> x,
194194
ffi::ResultBuffer<dtype> x_out,
195195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -382,7 +382,7 @@ static ffi::Error SvdOnlyVtQRImpl(
382382

383383
#define DEFINE_REAL_SVD_ONLY_VT(fname, dtype) \
384384
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
385-
fname, SvdOnlyVtImpl<dtype>, \
385+
fname, SvdOnlyUVtImpl<dtype>, \
386386
ffi::Ffi::Bind() \
387387
.Arg<ffi::Buffer<dtype>>(/*x*/) \
388388
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
@@ -393,7 +393,7 @@ static ffi::Error SvdOnlyVtQRImpl(
393393

394394
#define DEFINE_COMPLEX_SVD_ONLY_VT(fname, dtype) \
395395
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
396-
fname, SvdOnlyVtImpl<dtype>, \
396+
fname, SvdOnlyUVtImpl<dtype>, \
397397
ffi::Ffi::Bind() \
398398
.Arg<ffi::Buffer<dtype>>(/*x*/) \
399399
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
@@ -402,14 +402,14 @@ static ffi::Error SvdOnlyVtQRImpl(
402402
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
403403
.Attr<UVtMode>("mode"))
404404

405-
DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32, ffi::DataType::F32);
406-
DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f64, ffi::DataType::F64);
407-
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c64, ffi::DataType::C64);
408-
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
405+
DEFINE_REAL_SVD_ONLY_VT(svd_only_u_vt_f32, ffi::DataType::F32);
406+
DEFINE_REAL_SVD_ONLY_VT(svd_only_u_vt_f64, ffi::DataType::F64);
407+
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_u_vt_c64, ffi::DataType::C64);
408+
DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_u_vt_c128, ffi::DataType::C128);
409409

410410
#define DEFINE_REAL_SVD_ONLY_VT_QR(fname, dtype) \
411411
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
412-
fname, SvdOnlyVtQRImpl<dtype>, \
412+
fname, SvdOnlyUVtQRImpl<dtype>, \
413413
ffi::Ffi::Bind() \
414414
.Arg<ffi::Buffer<dtype>>(/*x*/) \
415415
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
@@ -420,7 +420,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
420420

421421
#define DEFINE_COMPLEX_SVD_ONLY_VT_QR(fname, dtype) \
422422
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
423-
fname, SvdOnlyVtQRImpl<dtype>, \
423+
fname, SvdOnlyUVtQRImpl<dtype>, \
424424
ffi::Ffi::Bind() \
425425
.Arg<ffi::Buffer<dtype>>(/*x*/) \
426426
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
@@ -429,10 +429,10 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
429429
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
430430
.Attr<UVtMode>("mode"))
431431

432-
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32, ffi::DataType::F32);
433-
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f64, ffi::DataType::F64);
434-
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_vt_qr_c64, ffi::DataType::C64);
435-
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_vt_qr_c128, ffi::DataType::C128);
432+
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_u_vt_qr_f32, ffi::DataType::F32);
433+
DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_u_vt_qr_f64, ffi::DataType::F64);
434+
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_u_vt_qr_c64, ffi::DataType::C64);
435+
DEFINE_COMPLEX_SVD_ONLY_VT_QR(svd_only_u_vt_qr_c128, ffi::DataType::C128);
436436

437437
template <typename T>
438438
static nb::capsule EncapsulateFfiCall(T *fn) {
@@ -441,13 +441,13 @@ static nb::capsule EncapsulateFfiCall(T *fn) {
441441
return nb::capsule(reinterpret_cast<void *>(fn));
442442
}
443443

444-
NB_MODULE(_svd_only_vt, m) {
445-
m.def("svd_only_vt_f32", []() { return EncapsulateFfiCall(svd_only_vt_f32); });
446-
m.def("svd_only_vt_f64", []() { return EncapsulateFfiCall(svd_only_vt_f64); });
447-
m.def("svd_only_vt_c64", []() { return EncapsulateFfiCall(svd_only_vt_c64); });
448-
m.def("svd_only_vt_c128", []() { return EncapsulateFfiCall(svd_only_vt_c128); });
449-
m.def("svd_only_vt_qr_f32", []() { return EncapsulateFfiCall(svd_only_vt_qr_f32); });
450-
m.def("svd_only_vt_qr_f64", []() { return EncapsulateFfiCall(svd_only_vt_qr_f64); });
451-
m.def("svd_only_vt_qr_c64", []() { return EncapsulateFfiCall(svd_only_vt_qr_c64); });
452-
m.def("svd_only_vt_qr_c128", []() { return EncapsulateFfiCall(svd_only_vt_qr_c128); });
444+
NB_MODULE(_svd_only_u_vt, m) {
445+
m.def("svd_only_u_vt_f32", []() { return EncapsulateFfiCall(svd_only_u_vt_f32); });
446+
m.def("svd_only_u_vt_f64", []() { return EncapsulateFfiCall(svd_only_u_vt_f64); });
447+
m.def("svd_only_u_vt_c64", []() { return EncapsulateFfiCall(svd_only_u_vt_c64); });
448+
m.def("svd_only_u_vt_c128", []() { return EncapsulateFfiCall(svd_only_u_vt_c128); });
449+
m.def("svd_only_u_vt_qr_f32", []() { return EncapsulateFfiCall(svd_only_u_vt_qr_f32); });
450+
m.def("svd_only_u_vt_qr_f64", []() { return EncapsulateFfiCall(svd_only_u_vt_qr_f64); });
451+
m.def("svd_only_u_vt_qr_c64", []() { return EncapsulateFfiCall(svd_only_u_vt_qr_c64); });
452+
m.def("svd_only_u_vt_qr_c128", []() { return EncapsulateFfiCall(svd_only_u_vt_qr_c128); });
453453
}

varipeps/utils/extensions/svd_ffi.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ enum class UVtMode : int8_t {
99
computePartialUandVt = 2, // Compute only Vt
1010
};
1111

12-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_f32);
13-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_f64);
14-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_c64);
15-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_c128);
12+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_f32);
13+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_f64);
14+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_c64);
15+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_c128);
1616

17-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_f32);
18-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_f64);
19-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_c64);
20-
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_qr_c128);
17+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_qr_f32);
18+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_qr_f64);
19+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_qr_c64);
20+
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_u_vt_qr_c128);
2121

2222
#endif // VARIPEPS_SVD_FFI_H

varipeps/utils/svd.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from varipeps import varipeps_config
1313

14-
from .extensions import _svd_only_vt as _svd_only_vt_lib
14+
from .extensions import _svd_only_u_vt as _svd_only_u_vt_lib
1515

1616
from typing import Tuple
1717

@@ -111,45 +111,45 @@ def _svd_jvp_rule(primals, tangents):
111111

112112

113113
jax.ffi.register_ffi_target(
114-
"svd_only_vt_f32", _svd_only_vt_lib.svd_only_vt_f32(), platform="cpu"
114+
"svd_only_u_vt_f32", _svd_only_u_vt_lib.svd_only_u_vt_f32(), platform="cpu"
115115
)
116116
jax.ffi.register_ffi_target(
117-
"svd_only_vt_f64", _svd_only_vt_lib.svd_only_vt_f64(), platform="cpu"
117+
"svd_only_u_vt_f64", _svd_only_u_vt_lib.svd_only_u_vt_f64(), platform="cpu"
118118
)
119119
jax.ffi.register_ffi_target(
120-
"svd_only_vt_c64", _svd_only_vt_lib.svd_only_vt_c64(), platform="cpu"
120+
"svd_only_u_vt_c64", _svd_only_u_vt_lib.svd_only_u_vt_c64(), platform="cpu"
121121
)
122122
jax.ffi.register_ffi_target(
123-
"svd_only_vt_c128", _svd_only_vt_lib.svd_only_vt_c128(), platform="cpu"
123+
"svd_only_u_vt_c128", _svd_only_u_vt_lib.svd_only_u_vt_c128(), platform="cpu"
124124
)
125125
jax.ffi.register_ffi_target(
126-
"svd_only_vt_qr_f32", _svd_only_vt_lib.svd_only_vt_qr_f32(), platform="cpu"
126+
"svd_only_u_vt_qr_f32", _svd_only_u_vt_lib.svd_only_u_vt_qr_f32(), platform="cpu"
127127
)
128128
jax.ffi.register_ffi_target(
129-
"svd_only_vt_qr_f64", _svd_only_vt_lib.svd_only_vt_qr_f64(), platform="cpu"
129+
"svd_only_u_vt_qr_f64", _svd_only_u_vt_lib.svd_only_u_vt_qr_f64(), platform="cpu"
130130
)
131131
jax.ffi.register_ffi_target(
132-
"svd_only_vt_qr_c64", _svd_only_vt_lib.svd_only_vt_qr_c64(), platform="cpu"
132+
"svd_only_u_vt_qr_c64", _svd_only_u_vt_lib.svd_only_u_vt_qr_c64(), platform="cpu"
133133
)
134134
jax.ffi.register_ffi_target(
135-
"svd_only_vt_qr_c128", _svd_only_vt_lib.svd_only_vt_qr_c128(), platform="cpu"
135+
"svd_only_u_vt_qr_c128", _svd_only_u_vt_lib.svd_only_u_vt_qr_c128(), platform="cpu"
136136
)
137137

138138

139139
def _svd_only_u_vt_impl(a, u_or_vt, use_qr=True):
140140
suffix = "_qr" if use_qr else ""
141141

142142
if a.dtype == jnp.float32:
143-
fn = f"svd_only_vt{suffix}_f32"
143+
fn = f"svd_only_u_vt{suffix}_f32"
144144
real_dtype = jnp.float32
145145
elif a.dtype == jnp.float64:
146-
fn = f"svd_only_vt{suffix}_f64"
146+
fn = f"svd_only_u_vt{suffix}_f64"
147147
real_dtype = jnp.float64
148148
elif a.dtype == jnp.complex64:
149-
fn = f"svd_only_vt{suffix}_c64"
149+
fn = f"svd_only_u_vt{suffix}_c64"
150150
real_dtype = jnp.float32
151151
elif a.dtype == jnp.complex128:
152-
fn = f"svd_only_vt{suffix}_c128"
152+
fn = f"svd_only_u_vt{suffix}_c128"
153153
real_dtype = jnp.float64
154154
else:
155155
raise ValueError("Unsupported dtype")

0 commit comments

Comments
 (0)