// Stand-alone reproducer built on Tpetra::Map / Teuchos::RCP / MPI.
//
// NOTE(review): this text looks like it went through a lossy extraction step:
//   * the whole translation unit is collapsed onto a few physical lines, so
//     each `//` comment in the code below ("// // TPETRA //", "// shallow copy",
//     "// // TPETRA-BLOCK //") now swallows the remainder of its physical line;
//   * many angle-bracket spans appear to be missing (four bare `#include`
//     directives, `template class ...` with no parameter list, `static_cast (0)`,
//     `Teuchos::RCP comm_`, `Tpetra::Map` used as a type with no arguments).
//   Restore it from the original source before trying to build; the code tokens
//   below are left exactly as found.
//
// Contents, in order of appearance:
//   namespace test
//     DistObject       - stores its map BY VALUE (`map_(*map)`); getMap() returns
//                        an RCP to a freshly allocated copy of that map on every
//                        call.
//     MultiVector      - derives from DistObject. The (map, numVecs) constructor
//                        ignores numVecs; the protected (X, j) constructor ignores
//                        j and only forwards X.getMap(). getVectorNonConst(j)
//                        returns a newly allocated Vector built from (*this, j).
//                        NOTE(review): MultiVector() calls base_type(), but no
//                        default constructor is declared on DistObject here -
//                        presumably this only compiles while it is never
//                        instantiated; confirm.
//                        NOTE(review): getLocalLength() calls
//                        getNodeNumElements() while makePointMap() calls
//                        getLocalNumElements(); check both exist in the Tpetra
//                        version in use.
//     Vector           - thin subclass of MultiVector with defaulted copy/move.
//     BlockMultiVector - holds meshMap_, pointMap_, mv_ and blockSize_. Its main
//                        constructor builds pointMap_ via makePointMap() and
//                        initialises mv_ from Teuchos::rcpFromRef(pointMap_),
//                        i.e. from an RCP that refers to this object's own
//                        member. makePointMap() multiplies the mesh map's global
//                        and local element counts by blockSize and keeps the mesh
//                        map's index base and communicator.
//                        NOTE(review): the default constructor passes
//                        Teuchos::null to the DistObject constructor, which
//                        evaluates `*map` - looks like a null dereference if it
//                        is ever called; verify.
//     BlockVector      - BlockMultiVector with numVecs fixed to 1;
//                        getVectorView() returns a copy of the Vector obtained
//                        from mv_.getVectorNonConst(0).
//   MyObject           - its constructor builds a map with 15 * (communicator
//                        size) global entries and index base 0 on
//                        MPI_COMM_WORLD; createState() returns a BlockVector over
//                        that map with block size 5.
//   Registry           - aggregate of two members with reference accessors
//                        getA() / getB().
//   print              - takes the vector BY VALUE, then writes the global element
//                        count of the block vector's map and of its vector view's
//                        map to std::cout. The `rank` argument is not used in the
//                        body.
//   Foo                - move-constructs its registry, stores the MPI rank, and
//                        calls print() for both entries in the constructor and
//                        again in dummy().
//   createFoo          - builds a Registry from two createState() calls and
//                        returns a Foo constructed from it (the Foo is returned by
//                        value out of the helper).
//   createRegistry     - builds the same Registry and returns it.
//   main               - runs two scenarios inside a Tpetra::ScopeGuard, using the
//                        labels printed by the program itself:
//                        "starting case that works"        -> createRegistry(),
//                          then Foo is constructed directly in main;
//                        "starting case that does NOT work" -> createFoo().
//                        Both then call f.dummy().
#include #include #include #include namespace test{ template class BlockMultiVector; template class BlockVector; template class MultiVector; template class Vector; template class DistObject; template class DistObject{ public: using local_ordinal_type = LocalOrdinal; using global_ordinal_type = GlobalOrdinal; using node_type = Node; using map_type = Tpetra::Map; explicit DistObject (const Teuchos::RCP& map) : map_(*map){} DistObject (const DistObject&) = default; DistObject& operator= (const DistObject< LocalOrdinal, GlobalOrdinal, Node>&) = default; DistObject (DistObject< LocalOrdinal, GlobalOrdinal, Node>&&) = default; DistObject& operator= (DistObject< LocalOrdinal, GlobalOrdinal, Node>&&) = default; virtual ~DistObject () = default; virtual Teuchos::RCP getMap () const { return Teuchos::rcp(new map_type(map_)); } public: map_type map_; }; // // TPETRA // template class MultiVector : public DistObject { public: using map_type = Tpetra::Map; using local_ordinal_type = typename map_type::local_ordinal_type; using global_ordinal_type = typename map_type::global_ordinal_type; using node_type = typename map_type::node_type; using base_type = DistObject; MultiVector() : base_type(){} MultiVector (const Teuchos::RCP& map, const size_t numVecs) : base_type (map){} protected: MultiVector (const MultiVector& X, const size_t j) : base_type (X.getMap ()){} public: size_t getLocalLength () const{ if (this->getMap ().is_null ()) { return static_cast (0); } else { return this->getMap()->getNodeNumElements(); } } Teuchos::RCP > getVectorNonConst (const size_t j){ typedef Vector V; return Teuchos::rcp (new V (*this, j)); } }; template class Vector : public MultiVector { private: using base_type = MultiVector; public: typedef LocalOrdinal local_ordinal_type; typedef GlobalOrdinal global_ordinal_type; typedef Node node_type; typedef typename base_type::map_type map_type; Vector () : base_type(){} Vector (const MultiVector& X, const size_t j) : base_type (X, j){} Vector (const 
// (continuation) remaining defaulted special members of Vector, then BlockMultiVector and the start of BlockVector.
Vector&) = default; Vector (Vector&&) = default; Vector& operator= (const Vector&) = default; Vector& operator=(Vector&&) = default; virtual ~Vector () = default; }; // // TPETRA-BLOCK // template class BlockMultiVector : public DistObject { public: using dist_object_type = DistObject; using map_type = Tpetra::Map; using mv_type = MultiVector; using scalar_type = Scalar; BlockMultiVector (const BlockMultiVector&) = default; BlockMultiVector (BlockMultiVector&&) = default; BlockMultiVector& operator= (const BlockMultiVector&) = default; BlockMultiVector& operator= (BlockMultiVector&&) = default; BlockMultiVector(): dist_object_type(Teuchos::null), blockSize_(0){} BlockMultiVector (const map_type& meshMap, const LO blockSize, const LO numVecs) : dist_object_type (Teuchos::rcp (new map_type (meshMap))), // shallow copy meshMap_(meshMap), pointMap_(makePointMap (meshMap, blockSize)), mv_ (Teuchos::rcpFromRef (pointMap_), numVecs), blockSize_(blockSize){} map_type makePointMap (const map_type& meshMap, const LO blockSize) { typedef Tpetra::global_size_t GST; const GST gblNumMeshMapInds = static_cast (meshMap.getGlobalNumElements ()); const size_t lclNumMeshMapIndices = static_cast (meshMap.getLocalNumElements ()); const GST gblNumPointMapInds = gblNumMeshMapInds * static_cast (blockSize); const size_t lclNumPointMapInds = lclNumMeshMapIndices * static_cast (blockSize); const GO indexBase = meshMap.getIndexBase (); return map_type (gblNumPointMapInds, lclNumPointMapInds, indexBase, meshMap.getComm ()); } protected: map_type meshMap_; map_type pointMap_; mv_type mv_; LO blockSize_; }; template class BlockVector : public BlockMultiVector { private: typedef BlockMultiVector base_type; typedef Vector vec_type; typedef Tpetra::Map map_type; public: BlockVector() : base_type(){} BlockVector (const BlockVector&) = default; BlockVector (BlockVector&&) = default; BlockVector& operator= (const BlockVector&) = default; BlockVector& operator= (BlockVector&&) = default; BlockVector 
// (continuation) rest of BlockVector, the file-level aliases, MyObject, Registry, print, Foo, createFoo, createRegistry and the start of main.
(const map_type& meshMap, const LO blockSize) : base_type (meshMap, blockSize, 1){} vec_type getVectorView () { Teuchos::RCP vPtr = this->mv_.getVectorNonConst(0); return *vPtr; } }; } using tcomm = Teuchos::Comm; using map_t = Tpetra::Map<>; using vec_type = test::BlockVector<>; class MyObject{ private: map_t map_; public: using state_type = vec_type; MyObject(){ Teuchos::RCP comm_ = Teuchos::rcp (new Teuchos::MpiComm(MPI_COMM_WORLD)); const int numGlobalEntries = 15*comm_->getSize(); map_ = map_t(numGlobalEntries, 0, comm_); } state_type createState() const{ state_type v(map_, 5); return v; } }; template struct Registry{ DT1 d1_; DT2 d2_; DT1 & getA(){ return d1_; } DT2 & getB(){ return d2_; } }; void print(int rank, const std::string & s1, const std::string & s2, vec_type r){ std::cout << "\n"; auto tpmv = r.getVectorView(); std::cout << s1 << "ext = " << r.getMap()->getGlobalNumElements() << "\n"; std::cout << s2 << "ext = " << tpmv.getMap()->getGlobalNumElements() << "\n"; } template class Foo{ int rank_; RegistryType reg_; public: Foo(RegistryType && reg) : reg_(std::move(reg)) { MPI_Comm_rank(MPI_COMM_WORLD, &rank_); auto & a = reg_.getA(); auto & b = reg_.getB(); print(rank_, "CONSTR-block-A ", "CONSTR-tpetr-A ", a); print(rank_, "CONSTR-block-B ", "CONSTR-tpetr-B ", b); } void dummy() { auto & a = reg_.getA(); auto & b = reg_.getB(); print(rank_, "DUMMY-block-A ", "DUMMY-tpetr-A ", a); print(rank_, "DUMMY-block-B ", "DUMMY-tpetr-B ", b); } }; template auto createFoo(const T& system){ using state_t = typename T::state_type; using registry_t = Registry; registry_t reg{system.createState(), system.createState()}; return Foo(std::move(reg)); } template auto createRegistry(const T & system){ using state_t = typename T::state_type; using registry_t = Registry; registry_t reg{system.createState(), system.createState()}; return reg; } int main(int argc, char **argv) { Tpetra::ScopeGuard tpetraScope (&argc, &argv); { { std::cout << "starting case that works\n"; 
// (continuation) remainder of main: finishes the first scenario, then runs the second one through createFoo().
MyObject object; auto reg = createRegistry(object); Foo f(std::move(reg)); f.dummy(); } { std::cout << "\n"; std::cout << "\n"; std::cout << "starting case that does NOT work\n"; MyObject object; auto f = createFoo(object); f.dummy(); } } return 0; }