Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding gauge shift #1348

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
29 changes: 29 additions & 0 deletions include/gauge_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ namespace quda {
*/
virtual void copy(const GaugeField &src) = 0;

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
virtual void shift(const GaugeField &src, const int *dx) = 0;
maddyscientist marked this conversation as resolved.
Show resolved Hide resolved

/**
@brief Compute the L1 norm of the field
@param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions)
Expand Down Expand Up @@ -543,6 +550,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void shift(const GaugeField &src, const int *dx);
maddyscientist marked this conversation as resolved.
Show resolved Hide resolved

/**
@brief Download into this field from a CPU field
@param[in] cpu The CPU field source
Expand Down Expand Up @@ -680,6 +694,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
* @brief Generic gauge field shift
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void shift(const GaugeField &src, const int *dx);
maddyscientist marked this conversation as resolved.
Show resolved Hide resolved

// Untyped pointer to this field's `gauge` member (mutable access).
void* Gauge_p() { return gauge; }
// Const overload: same `gauge` pointer, exposed as read-only.
const void* Gauge_p() const { return gauge; }

Expand Down Expand Up @@ -872,4 +893,12 @@ namespace quda {

// Convenience wrapper: forwards the calling function name, file and line,
// followed by the caller's arguments, to Reconstruct_.
#define checkReconstruct(...) Reconstruct_(__func__, __FILE__, __LINE__, __VA_ARGS__)

/**
* @brief Generic gauge field shift
* @param[out] dst Gauge field to store output
* @param[in] src Source from which we are shifting (extended field in case of MPI)
* @param[in] dx Host array of shifts to apply to the field
*/
void gaugeShift(GaugeField &dst, const GaugeField &src, const int *dx);

} // namespace quda
78 changes: 78 additions & 0 deletions include/kernels/gauge_shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#pragma once

#include <gauge_field_order.h>
#include <quda_matrix.h>
#include <index_helper.cuh>
#include <kernel.h>

namespace quda {

/**
   @brief Parameter struct consumed by the gauge-shift kernel.
   @tparam Float_ Real type of the field accessors
   @tparam nColor_ Number of colors (only 3 is enabled)
   @tparam recon_u Reconstruction type used to select the accessor
 */
template <typename Float_, int nColor_, QudaReconstructType recon_u>
struct GaugeShiftArg : kernel_param<> {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  typedef typename gauge_mapper<Float,recon_u>::type Gauge;

  Gauge out;      // accessor for the destination field
  const Gauge in; // accessor for the source field
  int geometry;   // value of in.Geometry(); upper bound on the direction index

  int S[4]; // the shift applied in each dimension (copy of dx)
  int X[4]; // dimensions of the output field (out.X())
  int E[4]; // dimensions of the input field (in.X()), possibly extended
  int border[4]; // radius of border, (E - X) / 2 in each dimension
  int P; // change of parity: 1 if the summed shift is odd, else 0

  /**
     @param[out] out Destination gauge field
     @param[in] in Source gauge field
     @param[in] dx Host array of four shifts, one per dimension
   */
  GaugeShiftArg(GaugeField &out, const GaugeField &in, const int* dx) :
    // The kernel decodes its site index with X (the output dimensions) and
    // writes to `out`, so the thread count must be the output checkerboard
    // volume. Using the input volume would let threads run past the end of
    // `out` whenever `in` is larger (extended) than `out`.
    kernel_param(dim3(out.VolumeCB(), 2, in.Geometry())),
    out(out),
    in(in),
    geometry(in.Geometry())
  {
    P = 0;
    for (int i=0; i<4; i++) {
      S[i] = dx[i];
      X[i] = out.X()[i];
      E[i] = in.X()[i];
      border[i] = (E[i] - X[i])/2;
      P += dx[i];
    }
    // only the parity (odd/even) of the total shift matters
    P = std::abs(P)%2;
  }
};

template <typename Arg, int dir>
maddyscientist marked this conversation as resolved.
Show resolved Hide resolved
__device__ __host__ inline void GaugeShiftKernel(const Arg &arg, int idx, int parity)
{
using real = typename Arg::Float;
typedef Matrix<complex<real>,Arg::nColor> Link;

int x[4] = {0, 0, 0, 0};
getCoords(x, idx, arg.X, parity);
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates
int nbr_oddbit = arg.P==1 ? (parity^1) : parity;

Link link = arg.in(dir, linkIndexShift(x,arg.S,arg.E), nbr_oddbit);
arg.out(dir, idx, parity) = link;
}

/**
   @brief Functor that dispatches a (site, parity, direction) triple to
   the matching GaugeShiftKernel instantiation.
 */
template <typename Arg> struct GaugeShift
{
  const Arg &arg;
  constexpr GaugeShift(const Arg &arg) : arg(arg) {}
  static constexpr const char *filename() { return KERNEL_FILE; }

  /**
     @param[in] x_cb Checkerboard site index
     @param[in] parity Site parity
     @param[in] dir Link direction; values at or above arg.geometry are ignored
   */
  __device__ __host__ void operator()(int x_cb, int parity, int dir)
  {
    if (dir < arg.geometry) {
      // only directions 0..3 have an instantiation; anything else is a no-op
      if (dir == 0) {
        GaugeShiftKernel<Arg,0>(arg, x_cb, parity);
      } else if (dir == 1) {
        GaugeShiftKernel<Arg,1>(arg, x_cb, parity);
      } else if (dir == 2) {
        GaugeShiftKernel<Arg,2>(arg, x_cb, parity);
      } else if (dir == 3) {
        GaugeShiftKernel<Arg,3>(arg, x_cb, parity);
      }
    }
  }
};

}
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ set (QUDA_OBJS
copy_gauge_half.cu copy_gauge_quarter.cu
copy_gauge.cpp copy_gauge_mg.cu copy_clover.cu
copy_gauge_offset.cu copy_color_spinor_offset.cu copy_clover_offset.cu
gauge_shift.cu
staggered_oprod.cu clover_trace_quda.cu
hisq_paths_force_quda.cu
unitarize_force_quda.cu unitarize_links_quda.cu milc_interface.cpp
Expand Down
19 changes: 19 additions & 0 deletions lib/cpu_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,25 @@ namespace quda {
}
}

/**
   Host-field shift. A shift that is zero in every dimension reduces to
   copy(src); every other request currently ends in an error.
 */
void cpuGaugeField::shift(const GaugeField &src, const int *dx) {
  // count how many leading components of dx are zero
  int n_zero = 0;
  while (n_zero < this->nDim && dx[n_zero] == 0) ++n_zero;

  // all components zero: nothing to shift, so just copy
  if (this->nDim > 0 && n_zero == this->nDim) {
    this->copy(src);
    return;
  }

  if (&src == this) errorQuda("Cannot copy in itself");

  checkField(src);

  // TODO: check src extension (needs to be enough for shifting)

  if (typeid(src) != typeid(cudaGaugeField)) {
    errorQuda("Not compatible type");
  } else {
    errorQuda("Not Implemented");
  }
}

void cpuGaugeField::setGauge(void **gauge_)
{
if(create != QUDA_REFERENCE_FIELD_CREATE) {
Expand Down
18 changes: 18 additions & 0 deletions lib/cuda_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,24 @@ namespace quda {
qudaDeviceSynchronize(); // include sync here for accurate host-device profiling
}

/**
   Device-field shift. A shift that is zero in every dimension reduces to
   copy(src); otherwise the work is handed to gaugeShift when src is a
   cudaGaugeField.
 */
void cudaGaugeField::shift(const GaugeField &src, const int *dx) {
  // count how many leading components of dx are zero
  int n_zero = 0;
  while (n_zero < this->nDim && dx[n_zero] == 0) ++n_zero;

  // all components zero: nothing to shift, so just copy
  if (this->nDim > 0 && n_zero == this->nDim) {
    this->copy(src);
    return;
  }

  if (&src == this) errorQuda("Cannot copy in itself");

  checkField(src);

  // TODO: check src extension (needs to be enough for shifting)

  if (typeid(src) != typeid(cudaGaugeField)) {
    errorQuda("Not compatible type");
  } else {
    gaugeShift(*this, src, dx);
  }
}
}

void cudaGaugeField::loadCPUField(const cpuGaugeField &cpu) {
copy(cpu);
qudaDeviceSynchronize();
Expand Down
57 changes: 57 additions & 0 deletions lib/gauge_shift.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include <tunable_nd.h>
#include <instantiate.h>
#include <gauge_field.h>
#include <kernels/gauge_shift.cuh>

namespace quda {

/**
   @brief Tunable launcher for the gauge-shift kernel. Constructing an
   instance builds the tuning key and immediately launches the kernel on
   the default stream.
   @tparam Float Real type of the fields
   @tparam nColor Number of colors
   @tparam recon_u Reconstruction type of the fields
 */
template <typename Float, int nColor, QudaReconstructType recon_u> class ShiftGauge : public TunableKernel3D
{
  GaugeField &out;
  const GaugeField &in;
  const int* dx;
  // The kernel is indexed by output sites, so the minimum thread count is
  // the output checkerboard volume (equal to the input's unless the input
  // is larger than the output).
  unsigned int minThreads() const { return out.VolumeCB(); }

public:
  /**
     @param[out] out Destination gauge field
     @param[in] in Source gauge field
     @param[in] dx Host array of shifts, one entry per dimension of in
   */
  ShiftGauge(GaugeField &out, const GaugeField &in, const int * dx) :
    TunableKernel3D(in, 2, in.Geometry()),
    out(out),
    in(in),
    dx(dx)
  {
    strcat(aux, ",shift=");
    for(int i=0; i<in.Ndim(); i++) {
      // separate the components so that distinct shifts cannot yield the
      // same key (e.g. {1,12,0,0} and {11,2,0,0})
      if (i > 0) strcat(aux, "_");
      strcat(aux, std::to_string(dx[i]).c_str());
    }
    strcat(aux, comm_dim_partitioned_string());
    apply(device::get_default_stream());
  }

  /**
     @brief Look up the launch parameters and launch the kernel
     @param[in] stream Stream to launch on
   */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    launch<GaugeShift>(tp, stream, GaugeShiftArg<Float, nColor, recon_u>(out, in, dx));
  }

  // Nothing is saved or restored around tuning.
  // NOTE(review): presumably safe because each launch fully rewrites `out`
  // from `in` — confirm `out` and `in` never alias.
  void preTune() { }
  void postTune() { }

  long long flops() const { return in.Volume() * 4; }
  // each output link is read once from the source and written once to the destination
  long long bytes() const { return 2 * out.Bytes(); }
};

/**
   Validate that the two fields are compatible and dispatch the shift
   kernel for the matching precision / reconstruction combination.
 */
void gaugeShift(GaugeField& out, const GaugeField& in, const int *dx)
{
  // source and destination must agree in precision, location and reconstruction
  checkPrecision(in, out);
  checkLocation(in, out);
  checkReconstruct(in, out);

  // ... and in geometry
  const auto geom_out = out.Geometry();
  const auto geom_in = in.Geometry();
  if (geom_out != geom_in) {
    errorQuda("Field geometries %d %d do not match", geom_out, geom_in);
  }

  // instantiate peels the reconstruct type off its first argument, so the
  // gauge field has to come first
  instantiate<ShiftGauge>(out, in, dx);
}

} // namespace quda