From b93e2351062a76b86b6d281777b4b4f01e36992d Mon Sep 17 00:00:00 2001
From: Judd
Date: Wed, 5 Jul 2023 20:49:44 +0800
Subject: [PATCH] porting to windows (#24)

1. `getopt` support comes from a submodule;
2. memory-mapped files are ported to Windows (like `llama.cpp`);
3. UTF-8 input is handled on Windows (like `llama.cpp`).

Co-authored-by: Judd
---
 .gitmodules    |  3 +++
 CMakeLists.txt | 12 ++++++++++-
 chatglm.cpp    | 54 ++++++++++++++++++++++++++++++++++++++++++++++++--
 main.cpp       | 46 +++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 111 insertions(+), 4 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 8843620..0ce8438 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,3 +8,6 @@
 [submodule "third_party/ggml"]
 	path = third_party/ggml
 	url = https://github.com/ggerganov/ggml.git
+[submodule "third_party/getopt"]
+	path = third_party/getopt
+	url = https://github.com/jingyu/getopt.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcfb671..0303118 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,7 +5,8 @@
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib CACHE STRING "")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib CACHE STRING "")
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE STRING "")
-set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)
+
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall")
 
 if (NOT CMAKE_BUILD_TYPE)
@@ -83,3 +84,12 @@ add_custom_target(lint
     COMMAND clang-format -i ${CPP_SOURCES}
     COMMAND isort ${PY_SOURCES}
     COMMAND black ${PY_SOURCES} --line-length 120)
+
+if (MSVC)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall")
+    add_definitions("/wd4267 /wd4244 /wd4305 /Zc:strictStrings /utf-8")
+    target_sources(main
+        PRIVATE third_party/getopt/getopt_long.c)
+    target_include_directories(main
+        PUBLIC third_party/getopt/)
+endif()
\ No newline at end of file
diff --git a/chatglm.cpp b/chatglm.cpp
index 95769bc..a4420fa 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -11,10 +11,32 @@
 #include 
 #include 
 #include 
-#include 
+#include 
+
 #include 
 #include 
-#include 
+
+#ifdef __has_include
+    #if __has_include(<unistd.h>)
+        #include <unistd.h>
+        #if defined(_POSIX_MAPPED_FILES)
+            #include <sys/mman.h>
+        #endif
+        #if defined(_POSIX_MEMLOCK_RANGE)
+            #include <sys/resource.h>
+        #endif
+    #endif
+#endif
+
+#if defined(_WIN32)
+    #define WIN32_LEAN_AND_MEAN
+    #ifndef NOMINMAX
+        #define NOMINMAX
+    #endif
+    #include <windows.h>
+    #include <io.h>
+    #include <stdio.h>
+#endif
 
 #ifdef GGML_USE_CUBLAS
 #include <ggml-cuda.h>
@@ -129,6 +151,7 @@ void TextStreamer::end() {
     print_len_ = 0;
 }
 
+#ifdef _POSIX_MAPPED_FILES
 MappedFile::MappedFile(const std::string &path) {
     int fd = open(path.c_str(), O_RDONLY);
     CHATGLM_CHECK(fd > 0) << "cannot open file " << path << ": " << strerror(errno);
@@ -144,6 +167,33 @@ MappedFile::MappedFile(const std::string &path) {
 }
 
 MappedFile::~MappedFile() { CHATGLM_CHECK(munmap(data, size) == 0) << strerror(errno); }
+#elif defined(_WIN32)
+MappedFile::MappedFile(const std::string &path) {
+
+    int fd = open(path.c_str(), O_RDONLY);
+    CHATGLM_CHECK(fd > 0) << "cannot open file " << path << ": " << strerror(errno);
+
+    struct _stat64 sb;
+    CHATGLM_CHECK(_fstat64(fd, &sb) == 0) << strerror(errno);
+    size = sb.st_size;
+
+    HANDLE hFile = (HANDLE) _get_osfhandle(fd);
+
+    HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
+    CHATGLM_CHECK(hMapping != NULL) << strerror(errno);
+
+    data = (char *)MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
+    CloseHandle(hMapping);
+
+    CHATGLM_CHECK(data != NULL) << strerror(errno);
+
+    CHATGLM_CHECK(close(fd) == 0) << strerror(errno);
+}
+
+MappedFile::~MappedFile() {
+    CHATGLM_CHECK(UnmapViewOfFile(data)) << strerror(errno);
+}
+#endif
 
 void ModelLoader::seek(int64_t offset, int whence) {
     if (whence == SEEK_SET) {
diff --git a/main.cpp b/main.cpp
index 268ae95..ab47cba 100644
--- a/main.cpp
+++ b/main.cpp
@@ -3,6 +3,12 @@
 #include 
 #include 
 
+#if defined(_WIN32)
+    #include <fcntl.h>
+    #include <io.h>
+    #include <windows.h>
+#endif
+
 struct Args {
     std::string model_path = "chatglm-ggml.bin";
     std::string prompt = "你好";
@@ -102,6 +108,40 @@ static Args parse_args(int argc, char **argv) {
     return args;
 }
 
+#if defined(_WIN32)
+static void append_utf8(char32_t ch, std::string &out) {
+    if (ch <= 0x7F) {
+        out.push_back(static_cast<char>(ch));
+    } else if (ch <= 0x7FF) {
+        out.push_back(static_cast<char>(0xC0 | ((ch >> 6) & 0x1F)));
+        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
+    } else if (ch <= 0xFFFF) {
+        out.push_back(static_cast<char>(0xE0 | ((ch >> 12) & 0x0F)));
+        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
+        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
+    } else if (ch <= 0x10FFFF) {
+        out.push_back(static_cast<char>(0xF0 | ((ch >> 18) & 0x07)));
+        out.push_back(static_cast<char>(0x80 | ((ch >> 12) & 0x3F)));
+        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
+        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
+    } else {
+        // Invalid Unicode code point
+    }
+}
+
+static bool get_utf8_line(std::string& line) {
+    std::wstring prompt;
+    std::wcin >> prompt;
+    for (auto wc : prompt)
+        append_utf8(wc, line);
+    return true;
+}
+#else
+static bool get_utf8_line(std::string& line) {
+    return !!std::getline(std::cin, line);
+}
+#endif
+
 void chat(const Args &args) {
     chatglm::Pipeline pipeline(args.model_path);
     std::string model_name = pipeline.model->type_name();
@@ -110,6 +150,10 @@
     chatglm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.top_k,
                                          args.top_p, args.temp, args.num_threads);
 
+#if defined(_WIN32)
+    _setmode(_fileno(stdin), _O_WTEXT);
+#endif
+
     if (args.interactive) {
         std::cout << R"(    ________          __  ________    __  ___ )" << '\n'
                   << R"(   / ____/ /_  ____ _/ /_/ ____/ /   /  |/  /_________  ____  )" << '\n'
@@ -123,7 +167,7 @@
             std::cout << std::setw(model_name.size()) << std::left << "Prompt"
                       << " > " << std::flush;
             std::string prompt;
-            if (!std::getline(std::cin, prompt)) {
+            if (!get_utf8_line(prompt)) {
                 break;
             }
             if (prompt.empty()) {
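
Note on the UTF-8 input path: `append_utf8` in `main.cpp` applies the standard UTF-8 encoding rules, using 1 byte for code points up to U+007F, 2 bytes up to U+07FF, 3 bytes up to U+FFFF, and 4 bytes up to U+10FFFF. The sketch below is not part of the patch; it mirrors that helper so the encoding can be compiled and checked in isolation, and the sample code points and the `main` driver are illustrative additions.

```cpp
// Standalone sketch mirroring the append_utf8 helper added in main.cpp above.
// The test values are illustrative; they are not taken from the patch.
#include <cstdio>
#include <string>

static void append_utf8(char32_t ch, std::string &out) {
    if (ch <= 0x7F) {
        out.push_back(static_cast<char>(ch));                        // 1-byte sequence (ASCII)
    } else if (ch <= 0x7FF) {
        out.push_back(static_cast<char>(0xC0 | ((ch >> 6) & 0x1F))); // 2-byte sequence
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0xFFFF) {
        out.push_back(static_cast<char>(0xE0 | ((ch >> 12) & 0x0F))); // 3-byte sequence
        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    } else if (ch <= 0x10FFFF) {
        out.push_back(static_cast<char>(0xF0 | ((ch >> 18) & 0x07))); // 4-byte sequence
        out.push_back(static_cast<char>(0x80 | ((ch >> 12) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F)));
        out.push_back(static_cast<char>(0x80 | (ch & 0x3F)));
    }
    // Code points above U+10FFFF are invalid and are dropped, as in the patch.
}

int main() {
    // U+4F60 and U+597D ("你好", the default prompt) plus U+1F600 (outside the BMP).
    const char32_t cps[] = {0x4F60, 0x597D, 0x1F600};
    std::string out;
    for (char32_t cp : cps) {
        append_utf8(cp, out);
    }
    for (unsigned char b : out) {
        std::printf("%02X ", static_cast<unsigned>(b));
    }
    std::printf("\n");
    return 0;
}
```

Expected output is `E4 BD A0 E5 A5 BD F0 9F 98 80`: the UTF-8 bytes of `你好` followed by those of U+1F600.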