Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Phi-3-mini model #119

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,17 @@ else ()
target_link_libraries(demo_opt MLLM_CPU)
endif ()

add_executable(demo_phi3 ${PROJECT_SOURCE_DIR}/examples/demo_phi3.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
src/tokenizers/Tokenizer.cpp
src/tokenizers/BPE/Bpe.cpp
)
if (MLLM_OPENMP_STATIC)
target_compile_options(demo_phi3 PRIVATE -fopenmp)
target_link_libraries(demo_phi3 PUBLIC MLLM_CPU -fopenmp -static-openmp)
else ()
target_link_libraries(demo_phi3 MLLM_CPU)
endif ()

# add_executable(demo_deepseek ${PROJECT_SOURCE_DIR}/examples/demo_deepseek.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
# src/tokenizers/Tokenizer.cpp
# src/tokenizers/BPE/Bpe.cpp
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Wait.. why on-device multimodal LLM? - It's a key building block for [intelligen
| [Yi 6B](https://huggingface.co/01-ai/Yi-1.5-6B) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) | |
| [StableLM 1.6B](https://github.com/Stability-AI/StableLM) | [✔️](https://huggingface.co/mllmTeam/stablelm-2-1.6b-chat-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/stablelm-2-1.6b-chat-mllm/tree/main) | |
| [OPT 1.3B](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) | [✔️](https://huggingface.co/mllmTeam/opt-1.3b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/opt-1.3b-mllm/tree/main) | |
| [Phi-3-mini 3.8B](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | [✔️](https://huggingface.co/mllmTeam/phi-3-mini-instruct-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/phi-3-mini-instruct-mllm/tree/main) | |

## Quick Start

Expand Down
59 changes: 59 additions & 0 deletions examples/demo_phi3.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include <iostream>
#include "cmdline.h"
#include "models/phi3/modeling_phi3.hpp"
#include "models/phi3/tokenization_phi3.hpp"
#include "processor/PostProcess.hpp"

using namespace mllm;

int main(int argc, char **argv) {
cmdline::parser cmdParser;
cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/phi3_vocab.mllm");
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/phi-3-mini-instruct-q4_k.mllm");
cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
cmdParser.add<int>("thread", 't', "num of threads", false, 4);
cmdParser.parse_check(argc, argv);

string vocab_path = cmdParser.get<string>("vocab");
string model_path = cmdParser.get<string>("model");
int tokens_limit = cmdParser.get<int>("limits");
CPUBackend::cpu_threads = cmdParser.get<int>("thread");

auto tokenizer = Phi3Tokenizer(vocab_path);

Phi3Config config(tokens_limit, "3.8B", HFHUBROPE);
auto model = Phi3Model(config);
model.load(model_path);

string system_prompt_start = "<|user|>\n";
string system_prompt_end = " <|end|>\n<|assistant|>";

vector<string> in_strs = {
"who are you?",
"What can you do?",
"Please introduce Beijing University of Posts and Telecommunications."};

for (int i = 0; i < in_strs.size(); ++i) {
auto in_str_origin = in_strs[i];
auto in_str = system_prompt_start + in_str_origin + system_prompt_end;
auto input_tensor = tokenizer.tokenize(in_str, i);
std::cout << "[Q] " << in_str << std::endl;
std::cout << "[A] " << std::flush;
for (int step = 0; step < 100; step++) {
auto result = model({input_tensor});
auto outputs = tokenizer.detokenize(result[0]);
auto out_string = outputs.first;
auto out_token = outputs.second;
if (out_token == tokenizer.end_id && step != 0) {
break;
}
std::cout << out_string << std::flush;
chatPostProcessing(out_token, input_tensor, {});
}
printf("\n");
model.clear_kvcache();
model.profiling();
}

return 0;
}
8 changes: 4 additions & 4 deletions src/backends/cpu/compute/Matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
auto src0_type_size = type_size(src0->dtype());
auto src0_blck_size = blck_size(src0->dtype());
#ifdef LLAMAFILE_SGEMM
if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype())) {
if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && dst->aggregated_tensors().empty()) {
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
Expand Down Expand Up @@ -260,7 +260,7 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
}

#ifdef LLAMAFILE_SGEMM
if (check_llamafile_sgemm(N, M, K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && dst->dtypeAt(0, 0, 0, 0) == MLLM_TYPE_F32 && dst->ctype()==BSHD) {
if (check_llamafile_sgemm(N, M, K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && dst->dtypeAt(0, 0, 0, 0) == MLLM_TYPE_F32 && dst->ctype()==BSHD&& dst->aggregated_tensors().empty()) {
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
Expand Down Expand Up @@ -697,7 +697,7 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
int use_N = (activate_output_dim == -1) ? N : activate_output_dim;
int use_K = (activate_input_dim == -1) ? K : activate_input_dim;

if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype())) {
if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype())&& dst->aggregated_tensors().empty()) {
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
Expand Down Expand Up @@ -764,7 +764,7 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
}

#ifdef LLAMAFILE_SGEMM
if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && !support_bias && dst->ctype()==BSHD) {
if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), dst->dtype()) && !support_bias && dst->ctype()==BSHD&& dst->aggregated_tensors().empty()) {
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
Expand Down
75 changes: 75 additions & 0 deletions src/models/phi3/configuration_phi3.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//
// Created by Guo Xiaoqiang on 2024/8/12 .
//
#ifndef CONFIG_PHI3_HPP
#define CONFIG_PHI3_HPP
#include "models/transformer/configuration_transformer.hpp"

using namespace mllm;

class Phi3NameConfig : public TransformerNameConfig {
public:
std::string blk_name;
std::string token_embd_name;
std::string post_norm_name;
std::string lm_head_name;
std::string _gate_up_proj_name;

void init(RoPEType type = HFHUBROPE) {
switch (type) {
case HFHUBROPE: {
blk_name = "model.layers.";
_attn_base_name = "self_attn.";
_ffn_base_name = "mlp.";
_qkv_proj_name = "qkv_proj";
_o_proj_name = "o_proj";
_gate_up_proj_name = "gate_up_proj";
_down_proj_name = "down_proj";
_attn_norm_name = "input_layernorm";
_ffn_norm_name = "post_attention_layernorm";
token_embd_name = "model.embed_tokens";
post_norm_name = "model.norm";
lm_head_name = "lm_head";
break;
}
default: {
throw std::runtime_error("Unsupported phi3 type");
}
}
}
};

class Phi3Config {
public:
int vocab_size{};
int hidden_dim{};
int head_size{};
int num_key_value_heads{};
int ffn_hidden{};
int block_num{};
RoPEType RoPE_type;
int cache_limit{};
Phi3NameConfig names_config;
float rope_theta;
int max_position_embeddings;

explicit Phi3Config(int token_limit, string billions = "3.8B", RoPEType type = HFHUBROPE, int vocab = 32064) {
names_config.init(type);
vocab_size = vocab;
if (billions == "3.8B" || billions == "3.8b") {
hidden_dim = 3072;
head_size = 32;
num_key_value_heads = 32;
ffn_hidden = 8192;
block_num = 32;
max_position_embeddings = 4096;
rope_theta = 10000.0;
} else {
throw std::runtime_error("Unsupported model size");
}
RoPE_type = type;
cache_limit = token_limit;
}
};

#endif // CONFIG_PHI3_HPP
108 changes: 108 additions & 0 deletions src/models/phi3/modeling_phi3.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
//
// Created by Guo Xiaoqiang on 2024/8/12.
//
#ifndef MODELING_PHI3_HPP
#define MODELING_PHI3_HPP

#include "Layer.hpp"
#include "Module.hpp"
#include "Tensor.hpp"
#include "configuration_phi3.hpp"
#include "models/transformer/modeling_transformer.hpp"

using namespace mllm;

class Phi3MLP final : public Module {
Layer gate_up_proj;
Layer silu;
Layer down_proj;
int ffn_hidden_;

public:
Phi3MLP() = default;
Phi3MLP(int hidden_dim, int ffn_hidden, const Phi3NameConfig &names, const string &base_name) {
ffn_hidden_ = ffn_hidden;
gate_up_proj = Linear(hidden_dim, 2 * ffn_hidden, false, base_name + names._gate_up_proj_name);
silu = SiLU(base_name + "act");
down_proj = Linear(ffn_hidden, hidden_dim, false, base_name + names._down_proj_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = gate_up_proj(inputs[0]);
auto split_tensors = Tensor::split(x, {ffn_hidden_, ffn_hidden_}, DIMENSION);
Tensor hidden = split_tensors[1];
x = hidden * silu(split_tensors[0]);
x = down_proj(x);
return {x};
}
};

class Phi3Block final : public Module {
MultiHeadAttention attention;
Phi3MLP mlp;
Layer norm1;
Layer norm2;

public:
Phi3Block() = default;
Phi3Block(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const Phi3NameConfig &names, const string &base_name) {
attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_HD, false, false,
RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name);
mlp = Phi3MLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name);
norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name);
norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = norm1(inputs[0]);
x = attention({x, x, x})[0];
auto tmp = x + inputs[0];
x = norm2(tmp);
x = mlp({x})[0];
x = x + tmp;
return {x};
}

MultiHeadAttention &get_attention() {
return attention;
}
};

class Phi3Model final : public Module {
Layer embedding;
vector<Phi3Block> blocks;
Layer norm;
Layer lm_head;

public:
explicit Phi3Model(const Phi3Config &config) :
Phi3Model(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num,
config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit,
config.names_config, config.names_config.blk_name) {
}
Phi3Model(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit,
const Phi3NameConfig &names, const string &base_name) {
embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name);
blocks = List<Phi3Block>(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name);
norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name);
lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = embedding(inputs[0]);
for (auto &block : blocks) {
x = block({x})[0];
}
x = norm(x);
x = lm_head(x);
return {x};
}

void clear_kvcache() {
for (auto &block : blocks) {
auto kvcahce = block.get_attention().get_cache();
for (auto &cache : kvcahce) {
cache->clearCache();
}
}
}
};

#endif // MODELING_PHI3_HPP
74 changes: 74 additions & 0 deletions src/models/phi3/tokenization_phi3.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//
// Created by Guo Xiaoqiang on 2024/8/12.
//
#ifndef TOKENIZATION_PHI3_HPP
#define TOKENIZATION_PHI3_HPP

#include "tokenizers/BPE/Bpe.hpp"
#include <algorithm>
#include <regex>

using namespace mllm;

class Phi3Tokenizer final {
public:
explicit Phi3Tokenizer(const std::string &vocab_file) {
Module::initBackend(MLLM_CPU);
tokenizer = new BPETokenizer(vocab_file);
}

~Phi3Tokenizer() {
delete tokenizer;
}

Tensor tokenize(std::string &text, int str_i = 0) const {
// replace all blanck to '_'
std::string new_text = BPETokenizer::replaceString(text, ' ', "▁");

auto tokens_id = vector<token_id_t>();
tokenizer->tokenize(new_text, tokens_id, false);

// chat template is as follows: <|user|>\n Question <|end|>\n <|assistant|>
tokens_id.insert(tokens_id.begin(), user_id);
tokens_id.insert(tokens_id.begin() + 1, 13);
tokens_id.insert(tokens_id.end(), end_id);
tokens_id.insert(tokens_id.end(), 13);
tokens_id.insert(tokens_id.end(), assistant_id);

return BPETokenizer::tokens2Input(tokens_id);
}

std::string detokenize(const std::vector<token_id_t> &tokens) {
return tokenizer->detokenize(tokens);
}

std::pair<std::string, unsigned> detokenize(Tensor &result) {
assert(result.batch() == 1 && "Batch size of result is not 1. Which is not supported for now.");
assert(result.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
std::vector<float> scores;
int _dims = result.dimension();
int _seq = result.sequence() - 1;
for (int i = 0; i < _dims; ++i) {
auto value = result.dataAt<float>(0, 0, _seq, i);
scores.push_back(value);
}
auto token_idx = this->argmax(scores);
auto text = tokenizer->detokenize({token_idx});
text = std::regex_replace(text, std::regex("▁"), " ");
return make_pair(text, token_idx);
}

private:
unsigned int argmax(const std::vector<float> &scores) {
if (scores.empty()) {
throw std::invalid_argument("Input vector is empty");
}
return std::max_element(scores.begin(), scores.end()) - scores.begin();
}
BPETokenizer *tokenizer;

public:
token_id_t pad_id = 32000, eos_id = 32000, bos_id = 1, user_id = 32010, assistant_id = 32001, end_id = 32007;
};

#endif //! TOKENIZATION_PHI3_HPP
4 changes: 4 additions & 0 deletions src/tokenizers/Tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ string Tokenizer::detokenize(const vector<token_id_t> &tokens) {
result += " ";
}
}
if (token_id == TokenNl) {
result += "\n";
continue;
}
result += this->id_token_[token_id].token;
}
return result;
Expand Down
Loading
Loading