Skip to content

Commit

Permalink
Add IO utilities.
Browse files Browse the repository at this point in the history
* Add fixed size stream for reading model stream.
* Add file extension.
  • Loading branch information
trivialfis committed Dec 4, 2019
1 parent e3c34c7 commit 83b9106
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 31 deletions.
80 changes: 79 additions & 1 deletion src/common/io.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,97 @@
/*!
* Copyright (c) by Contributors 2019
* Copyright (c) by XGBoost Contributors 2019
*/
#if defined(__unix__)
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#endif // defined(__unix__)
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>

#include "xgboost/logging.h"
#include "io.h"

namespace xgboost {
namespace common {

size_t PeekableInStream::Read(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer == 0) return strm_->Read(dptr, size);
if (nbuffer < size) {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
buffer_ptr_ += nbuffer;
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
size - nbuffer);
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
buffer_ptr_ += size;
return size;
}
}

size_t PeekableInStream::PeekRead(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer < size) {
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
buffer_ptr_ = 0;
buffer_.resize(size);
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
buffer_.resize(nbuffer + nadd);
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
return buffer_.size();
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
return size;
}
}

FixedSizeStream::FixedSizeStream(PeekableInStream* stream) : PeekableInStream(stream), pointer_{0} {
size_t constexpr kInitialSize = 4096;
size_t size {kInitialSize}, total {0};
buffer_.clear();
while (true) {
buffer_.resize(size);
size_t read = stream->PeekRead(&buffer_[0], size);
total = read;
if (read < size) {
break;
}
size *= 2;
}
buffer_.resize(total);
}

size_t FixedSizeStream::Read(void* dptr, size_t size) {
auto read = this->PeekRead(dptr, size);
pointer_ += read;
return read;
}

size_t FixedSizeStream::PeekRead(void* dptr, size_t size) {
if (size >= buffer_.size() - pointer_) {
std::copy(buffer_.cbegin() + pointer_, buffer_.cend(), reinterpret_cast<char*>(dptr));
return std::distance(buffer_.cbegin() + pointer_, buffer_.cend());
} else {
auto const beg = buffer_.cbegin() + pointer_;
auto const end = beg + size;
std::copy(beg, end, reinterpret_cast<char*>(dptr));
return std::distance(beg, end);
}
}

void FixedSizeStream::Seek(size_t pos) {
pointer_ = pos;
CHECK_LE(pointer_, buffer_.size());
}

void FixedSizeStream::Take(std::string* out) {
CHECK(out);
*out = std::move(buffer_);
}

std::string LoadSequentialFile(std::string fname) {
auto OpenErr = [&fname]() {
std::string msg;
Expand Down Expand Up @@ -59,6 +136,7 @@ std::string LoadSequentialFile(std::string fname) {

buffer.resize(fsize + 1);
fread(&buffer[0], 1, fsize, f);
buffer.back() = '\0';
fclose(f);
#endif // defined(__unix__)
return buffer;
Expand Down
65 changes: 35 additions & 30 deletions src/common/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <string>
#include <cstring>

#include "common.h"

namespace xgboost {
namespace common {
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
Expand All @@ -27,36 +29,8 @@ class PeekableInStream : public dmlc::Stream {
explicit PeekableInStream(dmlc::Stream* strm)
: strm_(strm), buffer_ptr_(0) {}

size_t Read(void* dptr, size_t size) override {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer == 0) return strm_->Read(dptr, size);
if (nbuffer < size) {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
buffer_ptr_ += nbuffer;
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
size - nbuffer);
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
buffer_ptr_ += size;
return size;
}
}

size_t PeekRead(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer < size) {
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
buffer_ptr_ = 0;
buffer_.resize(size);
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
buffer_.resize(nbuffer + nadd);
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
return buffer_.length();
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
return size;
}
}
size_t Read(void* dptr, size_t size) override;
virtual size_t PeekRead(void* dptr, size_t size);

void Write(const void* dptr, size_t size) override {
LOG(FATAL) << "Not implemented";
Expand All @@ -71,9 +45,40 @@ class PeekableInStream : public dmlc::Stream {
std::string buffer_;
};

class FixedSizeStream : public PeekableInStream {
public:
explicit FixedSizeStream(PeekableInStream* stream);
~FixedSizeStream() = default;

size_t Read(void* dptr, size_t size) override;
size_t PeekRead(void* dptr, size_t size) override;
size_t Size() const { return buffer_.size(); }
size_t Tell() const { return pointer_; }
void Seek(size_t pos);

void Write(const void* dptr, size_t size) override {
LOG(FATAL) << "Not implemented";
}

void Take(std::string* out);

private:
size_t pointer_;
std::string buffer_;
};

// Optimized for consecutive file loading in unix like systime.
std::string LoadSequentialFile(std::string fname);

inline std::string FileExtension(std::string const& fname) {
auto splited = Split(fname, '.');
if (splited.size() > 1) {
return splited.back();
} else {
return "";
}
}

} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_
43 changes: 43 additions & 0 deletions tests/cpp/common/test_io.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright (c) by XGBoost Contributors 2019
*/
#include <gtest/gtest.h>
#include "../../../src/common/io.h"

namespace xgboost {
namespace common {
TEST(IO, FileExtension) {
std::string filename {u8"model.json"};
auto ext = FileExtension(filename);
ASSERT_EQ(ext, u8"json");
}

TEST(IO, FixedSizeStream) {
std::string buffer {"This is the content of stream"};
{
MemoryFixSizeBuffer stream((void *)buffer.c_str(), buffer.size());
PeekableInStream peekable(&stream);
FixedSizeStream fixed(&peekable);

std::string out_buffer;
fixed.Take(&out_buffer);
ASSERT_EQ(buffer, out_buffer);
}

{
std::string huge_buffer;
for (size_t i = 0; i < 512; i++) {
huge_buffer += buffer;
}

MemoryFixSizeBuffer stream((void *)huge_buffer.c_str(), huge_buffer.size());
PeekableInStream peekable(&stream);
FixedSizeStream fixed(&peekable);

std::string out_buffer;
fixed.Take(&out_buffer);
ASSERT_EQ(huge_buffer, out_buffer);
}
}
}
} // namespace xgboost

0 comments on commit 83b9106

Please sign in to comment.