Skip to content

Commit

Permalink
Merge pull request #1 from DIG-Kaust/mpi
Browse files Browse the repository at this point in the history
feature: added mpi support for modelling and loss_grad
  • Loading branch information
mrava87 authored Jun 3, 2024
2 parents 74ec2a8 + 6f020f4 commit 3c0e1eb
Show file tree
Hide file tree
Showing 4 changed files with 561 additions and 6 deletions.
58 changes: 52 additions & 6 deletions devitofwi/waveengine/acoustic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
__all__ = ["AcousticWave2D"]

from typing import Optional, Type, Tuple
from typing import Any, Optional, NewType, Type, Tuple

import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray, SamplingLike
from tqdm.notebook import tqdm
from tqdm.autonotebook import tqdm

from devito import Function
from examples.seismic import AcquisitionGeometry, Model, Receiver
from examples.seismic.acoustic import AcousticWaveSolver
from devitofwi.devito.source import CustomSource

# mpi4py is an optional dependency: when it cannot be imported, the
# communicator type falls back to ``Any`` so that the module can still be
# imported and used without MPI. Only ImportError is caught, so that
# unrelated failures (and KeyboardInterrupt/SystemExit) are not hidden.
try:
    from mpi4py import MPI
    mpitype = MPI.Comm
except ImportError:
    mpitype = Any

MPIType = NewType("MPIType", mpitype)


class AcousticWave2D():
"""Devito Acoustic propagator.
Expand Down Expand Up @@ -69,6 +77,8 @@ class AcousticWave2D():
Loss object.
dtype : :obj:`str`, optional
Type of elements in input array.
base_comm : :obj:`mpi4py.MPI.Comm`, optional
Base MPI Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
"""

Expand All @@ -95,6 +105,7 @@ def __init__(
checkpointing: Optional[bool] = False,
loss: Optional[Type] = None,
dtype: Optional[DTypeLike] = "float32",
base_comm: Optional[MPIType] = None,
) -> None:

# Create vp if not provided and vprange is available
Expand All @@ -108,7 +119,7 @@ def __init__(
#if vpinit is not None and loss is None:
# raise ValueError("Must provide a loss to be able to run inversion...")

# Mmodelling parameters
# Modelling parameters
self.space_order = space_order
self.nbl = nbl
self.checkpointing = checkpointing
Expand All @@ -118,6 +129,9 @@ def __init__(
self.loss = loss
self.losshistory = []

# MPI parameters
self.base_comm = base_comm

# Create model
self.modelexists = True if vp is not None else False

Expand All @@ -128,7 +142,7 @@ def __init__(
# else:
# self.model = self._create_model(shape, origin, spacing, vpinit, space_order, nbl)

# create geometry
# Create geometry
self.geometry = self._create_geometry(self.model if vp is not None else self.initmodel,
src_x, src_z, rec_x, rec_z, t0, tn, src_type,
f0=f0, dt=dt)
Expand Down Expand Up @@ -305,7 +319,7 @@ def mod_allshots(self, dt=None) -> NDArray:
Returns
-------
d : :obj:`np.ndarray`
dtot : :obj:`np.ndarray`
Data for all shots
"""
Expand All @@ -316,6 +330,28 @@ def mod_allshots(self, dt=None) -> NDArray:
d = self._mod_oneshot(isrc, dt)
dtot.append(d)
dtot = np.array(dtot).reshape(nsrc, d.shape[0], d.shape[1])

return dtot

def mod_allshots_mpi(self, dt=None) -> NDArray:
    """FD modelling for all shots with MPI gathering

    Model the shots of this rank with ``mod_allshots`` and concatenate
    the shots modelled by every rank of ``base_comm`` along the first
    (shot) axis, so that each rank ends up with the same array.

    Parameters
    ----------
    dt : :obj:`float`, optional
        Time sampling used to resample modelled data

    Returns
    -------
    dtot : :obj:`np.ndarray`
        Data for all shots of all ranks (or for the shots of this
        engine only when ``base_comm`` is ``None``)

    """
    dtotrank = self.mod_allshots(dt)

    # No communicator was provided (``base_comm`` defaults to None):
    # there is nothing to gather, the local shots are all the shots
    if self.base_comm is None:
        return dtotrank

    # Gather shots from all ranks: allgather returns one array per rank,
    # which are stacked along the shot axis
    dtot = np.concatenate(self.base_comm.allgather(dtotrank), axis=0)

    return dtot

def _adjoint_source(self, d_syn, isrc):
Expand Down Expand Up @@ -423,8 +459,18 @@ def _loss_grad(self, vp, isrcs=None, postprocess=None, computeloss=True, compute
elif computeloss:
loss += lossgrad

if computegrad:
grad = grad.data[:]

# Gather gradients
if self.base_comm is not None:
if computeloss:
loss = self.base_comm.allreduce(loss, op=MPI.SUM)
if computegrad:
grad = self.base_comm.allreduce(grad, op=MPI.SUM)

# Postprocess loss and gradient
grad = self._crop_model(grad.data[:], self.nbl)
grad = self._crop_model(grad, self.nbl)
vp = self._crop_model(vp.data[:], self.nbl)
if postprocess is not None:
loss, grad = postprocess(vp, loss, grad)
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ dependencies = [
]
dynamic = ["version"]

[project.optional-dependencies]
mpi = [
"mpi4py",
]

[tool.setuptools.packages]
find = {}
249 changes: 249 additions & 0 deletions scripts/acoustic/AcousticVel_L2_1stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
r"""
Acoustic FWI(VP) with entire data
This example shows how to perform acoustic FWI in a distributed manner using
mpi4py, with the sources split across the MPI ranks.
Run as: export DEVITO_LANGUAGE=openmp; export DEVITO_MPI=0; export OMP_NUM_THREADS=6; export MKL_NUM_THREADS=6; export NUMBA_NUM_THREADS=6; mpiexec -n 8 python AcousticVel_L2_1stage.py
"""

import numpy as np

from matplotlib import pyplot as plt
from mpi4py import MPI
from pylops.basicoperators import Identity
from pylops_mpi.DistributedArray import local_split, Partition

from scipy.ndimage import gaussian_filter
from scipy.optimize import minimize
from devito import configuration
from examples.seismic import AcquisitionGeometry, Model, Receiver
from examples.seismic import plot_velocity, plot_perturbation
from examples.seismic.acoustic import AcousticWaveSolver
from examples.seismic import plot_shotrecord

from devitofwi.devito.utils import clear_devito_cache
from devitofwi.waveengine.acoustic import AcousticWave2D
from devitofwi.preproc.masking import TimeSpaceMasking
from devitofwi.loss.l2 import L2
from devitofwi.postproc.acoustic import create_mask, PostProcessVP

# World communicator (also handed to the modelling engines), index of
# this process and total number of processes
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Restrict devito logging to errors, then run the project's
# cache-clearing helper
configuration['log-level'] = 'ERROR'
clear_devito_cache()

# Callback to track model error
def fwi_callback(xk, vp, vp_error):
    """Track the relative model error of the current FWI iterate.

    Parameters
    ----------
    xk : :obj:`np.ndarray`
        Current model estimate (any shape with as many elements as ``vp``)
    vp : :obj:`np.ndarray`
        Reference model
    vp_error : :obj:`list`
        Model error history; the norm of the element-wise relative
        difference between ``xk`` and ``vp`` is appended to it in place

    """
    # Flatten both models, so that an iterate passed in its gridded shape
    # is compared sample by sample rather than broadcast against the
    # flattened reference
    xk_flat = np.ravel(xk)
    vp_flat = np.ravel(vp)
    vp_error.append(np.linalg.norm((xk_flat - vp_flat) / vp_flat))


if rank == 0:
    # Printed by the first rank only, so the banner appears once
    print(f'Distributed FWI ({size} ranks)')


##################################################################
# Parameters
##################################################################

# Model and acquisition parameters: number of samples (n*), sampling (d*)
# and origin (o*) of the spatial axes (x, z), sources (s), receivers (r)
# and time axis (t), plus source/receiver depths (sz, rz) and frequency
par = {
    'nx': 601, 'dx': 15, 'ox': 0,
    'nz': 221, 'dz': 15, 'oz': 0,
    'ns': 20, 'ds': 300, 'os': 1000, 'sz': 0,
    'nr': 300, 'dr': 30, 'or': 0, 'rz': 0,
    'nt': 3000, 'dt': 0.002, 'ot': 0,
    'freq': 15,
}

# Modelling parameters
shape, spacing, origin = (
    (par['nx'], par['nz']),
    (par['dx'], par['dz']),
    (par['ox'], par['oz']),
)
space_order = 4
nbl = 20

# Velocity model
path = '../../data/'
velocity_file = f'{path}Marm.bin'

# Time-space mask parameters
vwater = 1500
toff = 0.45

##################################################################
# Acquisition set-up
##################################################################

# Sampling frequency (inverse of the time sampling)
fs = 1.0 / par['dt']

# Spatial and time axes, each built as origin + sampling * sample index
x, z, t = (
    par[f'o{ax}'] + par[f'd{ax}'] * np.arange(par[f'n{ax}'])
    for ax in 'xzt'
)
tmax = 1e3 * t[-1]  # in ms

# Sources (column 0: x, column 1: z)
x_s = np.zeros((par['ns'], 2))
x_s[:, 0] = par['os'] + par['ds'] * np.arange(par['ns'])
x_s[:, 1] = par['sz']

# Receivers (column 0: x, column 1: z)
x_r = np.zeros((par['nr'], 2))
x_r[:, 0] = par['or'] + par['dr'] * np.arange(par['nr'])
x_r[:, 1] = par['rz']

##################################################################
# Velocity model
##################################################################

def _save_model_figure(model, title, figname):
    """Display a velocity model with receivers (white) and sources (red)
    overlaid, and save the figure to ``figname``."""
    plt.figure(figsize=(14, 5))
    plt.imshow(model.T, vmin=m_vmin, vmax=m_vmax, cmap='jet',
               extent=(x[0], x[-1], z[-1], z[0]))
    plt.colorbar()
    plt.scatter(x_r[:, 0], x_r[:, 1], c='w')
    plt.scatter(x_s[:, 0], x_s[:, 1], c='r')
    plt.title(title)
    plt.axis('tight')
    plt.savefig(figname)


# True model: the file is read as float32 samples arranged as (nz, nx)
# and transposed to (nx, nz)
vp_true = np.fromfile(velocity_file, dtype=np.float32)
vp_true = vp_true.reshape(par['nz'], par['nx']).T
msk = create_mask(vp_true, 1.52)  # mask for the water layer

if rank == 0:
    # Colour range reused by the velocity figures (2nd-98th percentiles)
    m_vmin, m_vmax = np.percentile(vp_true, [2, 98])
    _save_model_figure(vp_true, 'True VP', 'figs/TrueVel.png')

# Initial model for FWI: smoothed true model, masked to preserve the
# water layer; samples zeroed by the mask are reset to 1.5
# (presumably the water velocity - confirm against create_mask)
vp_init = gaussian_filter(vp_true, sigma=[15, 10]) * msk
vp_init[vp_init == 0] = 1.5

if rank == 0:
    _save_model_figure(vp_init, 'Initial VP', 'figs/InitialVel.png')

##################################################################
# Data
##################################################################

# Split sources across ranks: number of sources of this rank, number of
# sources of every rank, and first/last(+1) source index of this rank
ns_rank = local_split((par['ns'], ), comm, Partition.SCATTER, 0)
ns_ranks = np.concatenate(comm.allgather(ns_rank))
ns_cumsum = np.cumsum(ns_ranks)
isend_rank = ns_cumsum[rank]
isin_rank = isend_rank - ns_ranks[rank]
print(f'Rank: {rank}, ns: {ns_rank}, isin: {isin_rank}, isend: {isend_rank}')

# Modelling engine for the sources of this rank (true model, with
# velocities scaled by 1e3)
src_rank = slice(isin_rank, isend_rank)
amod = AcousticWave2D(
    shape, origin, spacing,
    x_s[src_rank, 0], x_s[src_rank, 1],
    x_r[:, 0], x_r[:, 1],
    0., tmax,
    vp=vp_true * 1e3,
    src_type="Ricker", f0=par['freq'],
    space_order=space_order, nbl=nbl,
    base_comm=comm,
)

# Model data (each rank models its own shots)
if rank == 0:
    print('Model data (and gather)...')
    print('Model data...')
dobs = amod.mod_allshots()

##################################################################
# Gradient
##################################################################

# L2 loss on the shots of this rank, with each shot flattened into a row
nsamples_shot = int(np.prod(dobs.shape[1:]))
l2loss = L2(Identity(nsamples_shot), dobs.reshape(ns_rank[0], -1))

# Inversion engine for the sources of this rank, started from the
# smoothed model
src_inv = slice(isin_rank, isend_rank)
ainv = AcousticWave2D(
    shape, origin, spacing,
    x_s[src_inv, 0], x_s[src_inv, 1],
    x_r[:, 0], x_r[:, 1],
    0., tmax,
    vprange=(vp_true.min() * 1e3, vp_true.max() * 1e3),
    vpinit=vp_init * 1e3,
    src_type="Ricker", f0=par['freq'],
    space_order=space_order, nbl=nbl,
    loss=l2loss,
    base_comm=comm,
)

# First gradient (computed with unit scaling), whose maximum is later
# used as scaling
postproc = PostProcessVP(scaling=1, mask=msk)

if rank == 0:
    print('Compute gradient...')

loss, direction = ainv._loss_grad(ainv.initmodel.vp,
                                  postprocess=postproc.apply)
scaling = direction.max()

if rank == 0:
    # Normalized gradient with the acquisition geometry overlaid
    plt.figure(figsize=(14, 5))
    plt.imshow(direction.T / scaling, cmap='seismic', vmin=-1e-1, vmax=1e-1,
               extent=(x[0], x[-1], z[-1], z[0]))
    plt.colorbar()
    plt.scatter(x_r[:, 0], x_r[:, 1], c='w')
    plt.scatter(x_s[:, 0], x_s[:, 1], c='r')
    plt.title('L2 Gradient')
    plt.axis('tight')
    plt.savefig('figs/Gradient.png')


##################################################################
# FWI
##################################################################

# L-BFGS parameters
ftol, maxiter, maxfun = 1e-10, 30, 5000
vp_error = []

# Run FWI, scaling the gradient by the maximum of the first gradient
convertvp = None
postproc = PostProcessVP(scaling=scaling, mask=msk)

if rank == 0:
    print('Run FWI...')

# The solver reports progress on the first rank only
lbfgs_options = {'ftol': ftol, 'maxiter': maxiter, 'maxfun': maxfun,
                 'disp': rank == 0}
nl = minimize(ainv.loss_grad, vp_init.ravel(), method='L-BFGS-B', jac=True,
              args=(convertvp, postproc.apply),
              callback=lambda x: fwi_callback(x, vp=vp_true, vp_error=vp_error),
              options=lbfgs_options)

if rank == 0:
    print(nl)

    # Convergence curves
    for curve, title, figname in (
        (ainv.losshistory, 'Loss history', 'figs/Loss.png'),
        (vp_error, 'Model error history', 'figs/ModelError.png'),
    ):
        plt.figure(figsize=(14, 5))
        plt.plot(curve, 'k')
        plt.title(title)
        plt.savefig(figname)

    # Inverted model on the modelling grid, with acquisition geometry
    vp_inv = nl.x.reshape(shape)

    plt.figure(figsize=(14, 5))
    plt.imshow(vp_inv.T, vmin=m_vmin, vmax=m_vmax, cmap='jet',
               extent=(x[0], x[-1], z[-1], z[0]))
    plt.colorbar()
    plt.scatter(x_r[:, 0], x_r[:, 1], c='w')
    plt.scatter(x_s[:, 0], x_s[:, 1], c='r')
    plt.title('Inverted VP')
    plt.axis('tight')
    plt.savefig('figs/InvertedVP.png')
Loading

0 comments on commit 3c0e1eb

Please sign in to comment.