Skip to content

Commit

Permalink
porting to windows (#24)
Browse files Browse the repository at this point in the history
1. `getopt` is supported by a submodule;
1. mapped file ported to windows (like `llama.cpp`)
1. UTF-8 input on windows (like `llama.cpp`)

Co-authored-by: Judd <foldl@boxvest.com>
  • Loading branch information
foldl and Judd committed Jul 5, 2023
1 parent 9d68879 commit b93e235
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
[submodule "third_party/ggml"]
path = third_party/ggml
url = https://github.com/ggerganov/ggml.git
[submodule "third_party/getopt"]
path = third_party/getopt
url = https://github.com/jingyu/getopt.git
12 changes: 11 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib CACHE STRING "")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib CACHE STRING "")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE STRING "")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall")

if (NOT CMAKE_BUILD_TYPE)
Expand Down Expand Up @@ -83,3 +84,12 @@ add_custom_target(lint
COMMAND clang-format -i ${CPP_SOURCES}
COMMAND isort ${PY_SOURCES}
COMMAND black ${PY_SOURCES} --line-length 120)

if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
add_definitions("/wd4267 /wd4244 /wd4305 /Zc:strictStrings /utf-8")
target_sources(main
PRIVATE third_party/getopt/getopt_long.c)
target_include_directories(main
PUBLIC third_party/getopt/)
endif()
54 changes: 52 additions & 2 deletions chatglm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,32 @@
#include <random>
#include <regex>
#include <string>
#include <sys/mman.h>
#include <functional>

#include <sys/stat.h>
#include <thread>
#include <unistd.h>

#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h>
#endif

#ifdef GGML_USE_CUBLAS
#include <ggml-cuda.h>
Expand Down Expand Up @@ -129,6 +151,7 @@ void TextStreamer::end() {
print_len_ = 0;
}

#ifdef _POSIX_MAPPED_FILES
MappedFile::MappedFile(const std::string &path) {
int fd = open(path.c_str(), O_RDONLY);
CHATGLM_CHECK(fd > 0) << "cannot open file " << path << ": " << strerror(errno);
Expand All @@ -144,6 +167,33 @@ MappedFile::MappedFile(const std::string &path) {
}

MappedFile::~MappedFile() { CHATGLM_CHECK(munmap(data, size) == 0) << strerror(errno); }
#elif defined(_WIN32)
MappedFile::MappedFile(const std::string &path) {

int fd = open(path.c_str(), O_RDONLY);
CHATGLM_CHECK(fd > 0) << "cannot open file " << path << ": " << strerror(errno);

struct _stat64 sb;
CHATGLM_CHECK(_fstat64(fd, &sb) == 0) << strerror(errno);
size = sb.st_size;

HANDLE hFile = (HANDLE) _get_osfhandle(fd);

HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
CHATGLM_CHECK(hMapping != NULL) << strerror(errno);

data = (char *)MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
CloseHandle(hMapping);

CHATGLM_CHECK(data != NULL) << strerror(errno);

CHATGLM_CHECK(close(fd) == 0) << strerror(errno);
}

MappedFile::~MappedFile() {
CHATGLM_CHECK(UnmapViewOfFile(data)) << strerror(errno);
}
#endif

void ModelLoader::seek(int64_t offset, int whence) {
if (whence == SEEK_SET) {
Expand Down
46 changes: 45 additions & 1 deletion main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
#include <iomanip>
#include <iostream>

#if defined(_WIN32)
#include <fcntl.h>
#include <io.h>
#include <windows.h>
#endif

struct Args {
std::string model_path = "chatglm-ggml.bin";
std::string prompt = "你好";
Expand Down Expand Up @@ -102,6 +108,40 @@ static Args parse_args(int argc, char **argv) {
return args;
}

#if defined(_WIN32)
static void append_utf8(char32_t ch, std::string &out) {
if (ch <= 0x7F) {
out.push_back(static_cast<unsigned char>(ch));
} else if (ch <= 0x7FF) {
out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
} else if (ch <= 0xFFFF) {
out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
} else if (ch <= 0x10FFFF) {
out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
} else {
// Invalid Unicode code point
}
}

static bool get_utf8_line(std::string& line) {
std::wstring prompt;
std::wcin >> prompt;
for (auto wc : prompt)
append_utf8(wc, line);
return true;
}
#else
static bool get_utf8_line(std::string& line) {
return !!std::getline(std::cin, line);
}
#endif

void chat(const Args &args) {
chatglm::Pipeline pipeline(args.model_path);
std::string model_name = pipeline.model->type_name();
Expand All @@ -110,6 +150,10 @@ void chat(const Args &args) {
chatglm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.top_k,
args.top_p, args.temp, args.num_threads);

#if defined(_WIN32)
_setmode(_fileno(stdin), _O_WTEXT);
#endif

if (args.interactive) {
std::cout << R"( ________ __ ________ __ ___ )" << '\n'
<< R"( / ____/ /_ ____ _/ /_/ ____/ / / |/ /_________ ____ )" << '\n'
Expand All @@ -123,7 +167,7 @@ void chat(const Args &args) {
std::cout << std::setw(model_name.size()) << std::left << "Prompt"
<< " > " << std::flush;
std::string prompt;
if (!std::getline(std::cin, prompt)) {
if (!get_utf8_line(prompt)) {
break;
}
if (prompt.empty()) {
Expand Down

0 comments on commit b93e235

Please sign in to comment.