Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding gauge shift #1348

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
29 changes: 29 additions & 0 deletions include/gauge_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ namespace quda {
*/
virtual void copy(const GaugeField &src) = 0;

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
virtual void shift(const GaugeField &src, const array<int, 4> &dx) = 0;

/**
@brief Compute the L1 norm of the field
@param[in] dim Which dimension we are taking the norm of (dim=-1 mean all dimensions)
Expand Down Expand Up @@ -543,6 +550,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void shift(const GaugeField &src, const array<int, 4> &dx);

/**
@brief Download into this field from a CPU field
@param[in] cpu The CPU field source
Expand Down Expand Up @@ -680,6 +694,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void shift(const GaugeField &src, const array<int, 4> &dx);

void* Gauge_p() { return gauge; }
const void* Gauge_p() const { return gauge; }

Expand Down Expand Up @@ -872,4 +893,12 @@ namespace quda {

#define checkReconstruct(...) Reconstruct_(__func__, __FILE__, __LINE__, __VA_ARGS__)

/**
* @brief Generic gauge field shift
* @param[out] dst Gauge field to store output
* @param[in] srd Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void gaugeShift(GaugeField &dst, const GaugeField &src, const array<int, 4> &dx);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be moved to e.g. gauge_tools.h?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that seems like a consistent place for it.


} // namespace quda
61 changes: 61 additions & 0 deletions include/kernels/gauge_shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include <gauge_field_order.h>
#include <quda_matrix.h>
#include <index_helper.cuh>
#include <kernel.h>

namespace quda
{

template <typename Float_, int nColor_, QudaReconstructType recon_u> struct GaugeShiftArg : kernel_param<> {
using Float = Float_;
static constexpr int nColor = nColor_;
static_assert(nColor == 3, "Only nColor=3 enabled at this time");
typedef typename gauge_mapper<Float, recon_u>::type Gauge;

Gauge out;
const Gauge in;

int S[4]; // the regular volume parameters
int X[4]; // the regular volume parameters
int E[4]; // the extended volume parameters
int border[4]; // radius of border
int P; // change of parity

GaugeShiftArg(GaugeField &out, const GaugeField &in, const array<int, 4> &dx) :
kernel_param(dim3(in.VolumeCB(), 2, in.Geometry())), out(out), in(in)
{
P = 0;
for (int i = 0; i < 4; i++) {
S[i] = dx[i];
X[i] = out.X()[i];
E[i] = in.X()[i];
border[i] = (E[i] - X[i]) / 2;
P += dx[i];
}
P = std::abs(P) % 2;
}
};

template <typename Arg> struct GaugeShift {
const Arg &arg;
constexpr GaugeShift(const Arg &arg) : arg(arg) { }
static constexpr const char *filename() { return KERNEL_FILE; }

__device__ __host__ void operator()(int x_cb, int parity, int dir)
{
using real = typename Arg::Float;
typedef Matrix<complex<real>, Arg::nColor> Link;

int x[4] = {0, 0, 0, 0};
getCoords(x, x_cb, arg.X, parity);
for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates
int nbr_oddbit = arg.P == 1 ? (parity ^ 1) : parity;

Link link = arg.in(dir, linkIndexShift(x, arg.S, arg.E), nbr_oddbit);
arg.out(dir, x_cb, parity) = link;
}
};

} // namespace quda
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ set (QUDA_OBJS
copy_gauge_half.cu copy_gauge_quarter.cu
copy_gauge.cpp copy_gauge_mg.cu copy_clover.cu
copy_gauge_offset.cu copy_color_spinor_offset.cu copy_clover_offset.cu
gauge_shift.cu
staggered_oprod.cu clover_trace_quda.cu
hisq_paths_force_quda.cu
unitarize_force_quda.cu unitarize_links_quda.cu milc_interface.cpp
Expand Down
20 changes: 20 additions & 0 deletions lib/cpu_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,26 @@ namespace quda {
}
}

void cpuGaugeField::shift(const GaugeField &src, const array<int, 4> &dx)
{
for (int i = 0; i < this->nDim; i++) {
if (dx[i] != 0) break;
// if zero shift, we simply copy
if (i == this->nDim - 1) return this->copy(src);
}
if (this == &src) errorQuda("Cannot copy in itself");

checkField(src);

// TODO: check src extension (needs to be enough for shifting)

if (typeid(src) == typeid(cudaGaugeField)) {
errorQuda("Not Implemented");
} else {
errorQuda("Not compatible type");
}
}

void cpuGaugeField::setGauge(void **gauge_)
{
if(create != QUDA_REFERENCE_FIELD_CREATE) {
Expand Down
19 changes: 19 additions & 0 deletions lib/cuda_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,25 @@ namespace quda {
qudaDeviceSynchronize(); // include sync here for accurate host-device profiling
}

void cudaGaugeField::shift(const GaugeField &src, const array<int, 4> &dx)
{
for (int i = 0; i < this->nDim; i++) {
if (dx[i] != 0) break;
if (i == this->nDim - 1) return this->copy(src);
}
if (this == &src) errorQuda("Cannot copy in itself");

checkField(src);

// TODO: check src extension (needs to be enough for shifting)

if (typeid(src) == typeid(cudaGaugeField)) {
gaugeShift(*this, src, dx);
} else {
errorQuda("Not compatible type");
}
}

void cudaGaugeField::loadCPUField(const cpuGaugeField &cpu) {
copy(cpu);
qudaDeviceSynchronize();
Expand Down
53 changes: 53 additions & 0 deletions lib/gauge_shift.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <tunable_nd.h>
#include <instantiate.h>
#include <gauge_field.h>
#include <kernels/gauge_shift.cuh>

namespace quda
{

template <typename Float, int nColor, QudaReconstructType recon_u> class ShiftGauge : public TunableKernel3D
{
GaugeField &out;
const GaugeField &in;
const array<int, 4> &dx;
unsigned int minThreads() const { return in.VolumeCB(); }

public:
ShiftGauge(GaugeField &out, const GaugeField &in, const array<int, 4> &dx) :
TunableKernel3D(in, 2, in.Geometry()), out(out), in(in), dx(dx)
{
strcat(aux, ",shift=");
for (int i = 0; i < in.Ndim(); i++) { strcat(aux, std::to_string(dx[i]).c_str()); }
strcat(aux, comm_dim_partitioned_string());
apply(device::get_default_stream());
}

void apply(const qudaStream_t &stream)
{
TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
launch<GaugeShift>(tp, stream, GaugeShiftArg<Float, nColor, recon_u>(out, in, dx));
}

void preTune() { }
void postTune() { }

long long flops() const { return in.Volume() * 4; }
long long bytes() const { return in.Bytes(); }
};

void gaugeShift(GaugeField &out, const GaugeField &in, const array<int, 4> &dx)
{
checkPrecision(in, out);
checkLocation(in, out);
checkReconstruct(in, out);

if (out.Geometry() != in.Geometry()) {
errorQuda("Field geometries %d %d do not match", out.Geometry(), in.Geometry());
}

// gauge field must be passed as first argument so we peel off its reconstruct type
instantiate<ShiftGauge>(out, in, dx);
}

} // namespace quda