
Use prefill for term consistency (octoml#196)
This PR changes the occurrences of the encode step in the LLM runtime
to prefill, which is the more standard terminology.
tqchen committed May 20, 2023
1 parent 305865c commit fac3201
Showing 15 changed files with 74 additions and 71 deletions.
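
For orientation, the renamed chat-module API is a two-phase generation loop: one `prefill` call runs the whole prompt through the model and fills the KV cache, then `decode` is called once per generated token until `stopped` reports completion. Below is a minimal C++ sketch of that loop, modeled on `Converse()` in `cpp/cli_main.cc` further down. The packed-function names ("prefill", "decode", "stopped", "get_message") are taken from this diff; the `GenerateOnce` helper, the includes, and how the chat module is obtained are illustrative assumptions, not mlc-llm code.

```cpp
// Minimal sketch of the prefill/decode loop (assumed helper, not mlc-llm code).
// The function names "prefill", "decode", "stopped", "get_message" come from
// this diff; everything else here is illustrative.
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <iostream>
#include <string>

void GenerateOnce(tvm::runtime::Module chat_mod, const std::string& prompt) {
  tvm::runtime::PackedFunc prefill = chat_mod->GetFunction("prefill");  // was "encode"
  tvm::runtime::PackedFunc decode = chat_mod->GetFunction("decode");
  tvm::runtime::PackedFunc stopped = chat_mod->GetFunction("stopped");
  tvm::runtime::PackedFunc get_message = chat_mod->GetFunction("get_message");

  prefill(prompt);  // run the full prompt through the model once
  while (true) {
    bool is_stopped = stopped();
    if (is_stopped) break;
    decode();  // emit one token per call
  }
  std::string reply = get_message();
  std::cout << reply << std::endl;
}
```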
@@ -91,12 +91,12 @@ void Dummy(String text, Handler handler) {
e.printStackTrace();
}
}
Utils.sendEnd("encode: 100.0 tok/s, decode: 100.0 tok/s", handler);
Utils.sendEnd("prefill: 100.0 tok/s, decode: 100.0 tok/s", handler);
}

void Generate(String prompt, Handler handler) {
// System.err.println("Start generating");
backend.Encode(prompt);
backend.Prefill(prompt);
// System.err.println("Encoding " + prompt);
while (!backend.Stopped()) {
backend.Decode();
10 changes: 5 additions & 5 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java
@@ -8,7 +8,7 @@
import org.apache.tvm.Module;

public class LLMChat {
private Function encode_func_;
private Function prefill_func_;
private Function decode_func_;
private Function get_message_;
private Function stopped_func_;
@@ -36,7 +36,7 @@ public void Init() {
System.err.println("[INFO] Before LLM Chat create");
llm_chat_ = fcreate.pushArg(lib).pushArg(tokenizer_path).pushArg(param_path).pushArg(Device.opencl().deviceType).pushArg(0).invoke().asModule();
System.err.println("[INFO] LLM Chat created!");
encode_func_ = llm_chat_.getFunction("encode");
prefill_func_ = llm_chat_.getFunction("prefill");
decode_func_ = llm_chat_.getFunction("decode");
get_message_ = llm_chat_.getFunction("get_message");

@@ -45,7 +45,7 @@ public void Init() {

runtime_stats_text_func_ = llm_chat_.getFunction("runtime_stats_text");

assert encode_func_ != null;
assert prefill_func_ != null;
assert decode_func_ != null;
assert stopped_func_ != null;
assert runtime_stats_text_func_ != null;
@@ -71,8 +71,8 @@ public String GetMessage() {
return get_message_.invoke().asString();
}

public void Encode(String prompt) {
encode_func_.pushArg(prompt).invoke();
public void Prefill(String prompt) {
prefill_func_.pushArg(prompt).invoke();
}

public boolean Stopped() {
4 changes: 2 additions & 2 deletions build.py
@@ -202,8 +202,8 @@ def mod_transform_before_build(
) -> tvm.IRModule:
"""First-stage: Legalize ops and trace"""
model_names = [
"encoding",
"decoding",
"prefill",
"decode",
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
29 changes: 16 additions & 13 deletions cpp/cli_main.cc
@@ -171,7 +171,7 @@ struct LLMChatModule {
public:
explicit LLMChatModule(const DLDevice& device) {
this->chat_mod_ = mlc::llm::CreateChatModule(device);
this->encode_ = this->chat_mod_->GetFunction("encode");
this->prefill_ = this->chat_mod_->GetFunction("prefill");
this->decode_ = this->chat_mod_->GetFunction("decode");
this->stopped_ = this->chat_mod_->GetFunction("stopped");
this->get_message_ = this->chat_mod_->GetFunction("get_message");
@@ -180,7 +180,7 @@ struct LLMChatModule {
this->get_role1_ = this->chat_mod_->GetFunction("get_role1");
this->runtime_stats_text_ = this->chat_mod_->GetFunction("runtime_stats_text");
this->reset_chat_ = this->chat_mod_->GetFunction("reset_chat");
ICHECK(encode_ != nullptr);
ICHECK(prefill_ != nullptr);
ICHECK(decode_ != nullptr);
ICHECK(stopped_ != nullptr);
ICHECK(get_message_ != nullptr);
@@ -206,7 +206,7 @@ struct LLMChatModule {
void Reset() { reset_chat_(); }

void Converse(const std::string& input, int stream_interval, std::ostream& os) {
this->Encode(input);
this->Prefill(input);

std::string cur_msg = "";
std::vector<std::string> cur_utf8_chars = CountUTF8(cur_msg);
@@ -241,7 +241,7 @@ struct LLMChatModule {

protected:
// Low-level APIs
void Encode(const std::string& input) { encode_(input); }
void Prefill(const std::string& input) { prefill_(input); }

void Decode() { decode_(); }

@@ -251,7 +251,7 @@

// TVM Modules and functions with TVM's calling convention
tvm::runtime::Module chat_mod_;
tvm::runtime::PackedFunc encode_;
tvm::runtime::PackedFunc prefill_;
tvm::runtime::PackedFunc decode_;
tvm::runtime::PackedFunc stopped_;
tvm::runtime::PackedFunc get_message_;
@@ -264,10 +264,14 @@ struct LLMChatModule {

std::optional<std::filesystem::path> TryInferMLCChatConfig(const std::string& artifact_path,
const std::string& local_id) {
return FindFile({artifact_path + "/prebuilt/" + local_id, //
artifact_path + "/" + local_id + "/params"}, //
{"mlc-chat-config"}, //
{".json"});
return FindFile(
{
//
artifact_path + "/" + local_id + "/params", //
artifact_path + "/prebuilt/" + local_id, //
}, //
{"mlc-chat-config"}, //
{".json"});
}

std::string ReadStringFromJSONFile(const std::filesystem::path& config_path,
@@ -317,10 +317,9 @@ ModelPaths ModelPaths::Find(const std::string& artifact_path, const std::string&
std::filesystem::path lib_path;
if (auto path = FindFile(
{
artifact_path + "/prebuilt/lib/", // prebuild lib
artifact_path + "/prebuilt/" + lib_local_id, // For prebuilts
artifact_path + "/" + lib_local_id, // Usually this is the candidate
artifact_path + "/" + lib_local_id + "/lib/",
artifact_path + "/" + lib_local_id, // Usually this is the candidate
artifact_path + "/prebuilt/lib/", // prebuild lib
artifact_path + "/prebuilt/" + lib_local_id // For prebuilts
},
{
lib_name,
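
Note that the two `FindFile` hunks above change only the order of the candidate directories, but since the search returns the first match, the order is the behavior: a locally built `artifact_path/<local_id>/...` artifact now shadows the `prebuilt/` copy instead of the other way around. The diff does not show `FindFile`'s body; the sketch below is an assumed first-match-wins implementation that roughly matches its call shape, for illustration only.

```cpp
// Illustrative sketch of a first-match-wins file search (assumed body; only the
// call shape is taken from cpp/cli_main.cc). Earlier search paths shadow later
// ones, which is why reordering the candidates changes which
// mlc-chat-config.json or model library gets picked up.
#include <filesystem>
#include <optional>
#include <string>
#include <vector>

std::optional<std::filesystem::path> FindFile(const std::vector<std::string>& search_paths,
                                              const std::vector<std::string>& names,
                                              const std::vector<std::string>& suffixes) {
  for (const std::string& prefix : search_paths) {
    for (const std::string& name : names) {
      for (const std::string& suffix : suffixes) {
        std::filesystem::path candidate(prefix + "/" + name + suffix);
        if (std::filesystem::exists(candidate)) {
          return candidate;  // first hit wins
        }
      }
    }
  }
  return std::nullopt;
}
```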
32 changes: 16 additions & 16 deletions cpp/llm_chat.cc
@@ -469,7 +469,7 @@ class LLMChat {
std::string RuntimeStatsText() {
std::ostringstream os;
os << "prefill: " << std::setprecision(1) << std::fixed
<< this->encode_total_tokens / this->encode_total_time << " tok/s"
<< this->prefill_total_tokens / this->prefill_total_time << " tok/s"
<< ", decode: " << std::setprecision(1) << std::fixed
<< this->decode_total_tokens / this->decode_total_time << " tok/s";
// os << ", sample-cost: " << std::setprecision(1) << std::fixed
@@ -492,8 +492,8 @@ class LLMChat {
static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));

encoding_func_ = vm_->GetFunction("encoding");
decoding_func_ = vm_->GetFunction("decoding");
encoding_func_ = vm_->GetFunction("prefill");
decoding_func_ = vm_->GetFunction("decode");
encoding_without_cache_func_ = vm_->GetFunction("encoding_without_cache");
softmax_func_ = vm_->GetFunction("softmax_with_temperature");
get_metadata_func_ = vm_->GetFunction("get_metadata");
@@ -629,9 +629,9 @@ class LLMChat {

/*! \brief reset the runtime stats. */
void ResetRuntimeStats() {
this->encode_total_tokens = 0;
this->prefill_total_tokens = 0;
this->decode_total_tokens = 0;
this->encode_total_time = 0;
this->prefill_total_time = 0;
this->decode_total_time = 0;
this->sample_total_time = 0;
}
@@ -725,8 +725,8 @@ class LLMChat {
/*!
* \brief Generate the next token given a prompt.
*/
void EncodeStep(std::string inp) {
if (reset_stats_per_encode_) {
void PrefillStep(std::string inp) {
if (reset_stats_per_prefill_) {
this->ResetRuntimeStats();
}
output_ids_.clear();
@@ -755,8 +755,8 @@ class LLMChat {
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
auto tend = std::chrono::high_resolution_clock::now();

this->encode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->encode_total_tokens += token_len;
this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
if (temperature_ < 1e-6f) {
next_token_ = this->SampleFromLogitsOnCPU();
} else {
@@ -1024,12 +1024,12 @@ class LLMChat {
//----------------------------
// Statistics
//----------------------------
bool reset_stats_per_encode_ = true;
bool reset_stats_per_prefill_ = true;
double decode_total_time = 0;
double sample_total_time = 0;
double encode_total_time = 0;
double prefill_total_time = 0;
int64_t decode_total_tokens = 0;
int64_t encode_total_tokens = 0;
int64_t prefill_total_tokens = 0;
//----------------------------
// Conversation
//----------------------------
@@ -1141,10 +1141,10 @@ class LLMChatModule : public ModuleNode {
} else if (name == "try_tokenizer") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->TryTokenizer(); });
} else if (name == "encode") {
} else if (name == "prefill") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 1);
GetChat()->EncodeStep(args[0]);
GetChat()->PrefillStep(args[0]);
});
} else if (name == "decode") {
return PackedFunc(
@@ -1219,8 +1219,8 @@
static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));

chat_->encoding_func_ = chat_->vm_->GetFunction("encoding");
chat_->decoding_func_ = chat_->vm_->GetFunction("decoding");
chat_->encoding_func_ = chat_->vm_->GetFunction("prefill");
chat_->decoding_func_ = chat_->vm_->GetFunction("decode");
chat_->encoding_without_cache_func_ = chat_->vm_->GetFunction("encoding_without_cache");
chat_->softmax_func_ = chat_->vm_->GetFunction("softmax_with_temperature");
chat_->get_metadata_func_ = chat_->vm_->GetFunction("get_metadata");
2 changes: 1 addition & 1 deletion ios/MLCChat/ChatState.swift
@@ -126,7 +126,7 @@ class ChatState : ObservableObject {
threadWorker.push {[self] in
self.appendMessage(role: MessageRole.user, message: prompt)

backend.encode(prompt);
backend.prefill(prompt);
while (!backend.stopped()) {
assert(self.inProgress);
backend.decode();
16 changes: 8 additions & 8 deletions ios/MLCChat/LLMChat.mm
@@ -27,7 +27,7 @@

reload_func_ = llm_chat_->GetFunction("reload");
unload_func_ = llm_chat_->GetFunction("unload");
encode_func_ = llm_chat_->GetFunction("encode");
prefill_func_ = llm_chat_->GetFunction("prefill");
decode_func_ = llm_chat_->GetFunction("decode");
get_message_ = llm_chat_->GetFunction("get_message");
stopped_func_ = llm_chat_->GetFunction("stopped");
@@ -36,7 +36,7 @@

ICHECK(reload_func_ != nullptr);
ICHECK(unload_func_ != nullptr);
ICHECK(encode_func_ != nullptr);
ICHECK(prefill_func_ != nullptr);
ICHECK(decode_func_ != nullptr);
ICHECK(get_message_ != nullptr);
ICHECK(stopped_func_ != nullptr);
@@ -64,9 +64,9 @@ void Evaluate() {
return get_message_();
}

void Encode(std::string prompt) {
ICHECK(encode_func_ != nullptr);
encode_func_(prompt);
void Prefill(std::string prompt) {
ICHECK(prefill_func_ != nullptr);
prefill_func_(prompt);
}

bool Stopped() { return stopped_func_(); }
@@ -86,7 +86,7 @@ void Encode(std::string prompt) {
Module llm_chat_;
PackedFunc unload_func_;
PackedFunc reload_func_;
PackedFunc encode_func_;
PackedFunc prefill_func_;
PackedFunc decode_func_;
PackedFunc get_message_;
PackedFunc stopped_func_;
@@ -112,8 +112,8 @@ - (void)evaluate {
LLMChatModuleWrapper::Global()->Evaluate();
}

- (void)encode:(NSString*)prompt {
LLMChatModuleWrapper::Global()->Encode(prompt.UTF8String);
- (void)prefill:(NSString*)prompt {
LLMChatModuleWrapper::Global()->Prefill(prompt.UTF8String);
}

- (void)decode {
2 changes: 1 addition & 1 deletion ios/MLCChat/MLCChat-Bridging-Header.h
@@ -10,7 +10,7 @@
- (void)evaluate;
- (void)unload;
- (void)reload:(NSString*)model_lib modelPath:(NSString*)modelPath;
- (void)encode:(NSString*)prompt;
- (void)prefill:(NSString*)prompt;
- (void)decode;
- (void)reset;
- (NSString*)getMessage;
8 changes: 4 additions & 4 deletions mlc_llm/relax_model/gpt_neox.py
@@ -469,7 +469,7 @@ def create_encoding_func(
batch_size = tvm.tir.IntImm("int64", 1)
seq_len = tvm.tir.Var("n", "int64")
all_seq_len = tvm.tir.Var("m", "int64")
with bb.function("encoding"):
with bb.function("prefill"):
model = GPTNeoXForCausalLM(config)
input_ids = nn.Placeholder(
(batch_size, seq_len), dtype="int32", name="input_ids"
@@ -500,7 +500,7 @@ def create_encoding_func(
gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
bb.emit_func_output(gv, params)
mod = bb.get()
gv = mod.get_global_var("encoding")
gv = mod.get_global_var("prefill")
bb.update_func(gv, mod[gv].with_attr("num_input", 3))


@@ -512,7 +512,7 @@ def create_decoding_func(
batch_size = tvm.tir.IntImm("int64", 1)
seq_len = tvm.tir.IntImm("int64", 1)
all_seq_len = tvm.tir.Var("n", "int64")
with bb.function("decoding"):
with bb.function("decode"):
model = GPTNeoXForCausalLM(config)
input_ids = nn.Placeholder(
(batch_size, seq_len), dtype="int32", name="input_ids"
@@ -544,7 +544,7 @@ def create_decoding_func(
gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
bb.emit_func_output(gv, params)
mod = bb.get()
gv = mod.get_global_var("decoding")
gv = mod.get_global_var("decode")
bb.update_func(gv, mod[gv].with_attr("num_input", 3))


8 changes: 4 additions & 4 deletions mlc_llm/relax_model/llama.py
@@ -559,7 +559,7 @@ def create_encoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
bsz = 1
seq_len = tvm.tir.Var("n", "int64")
all_seq_len = tvm.tir.Var("m", "int64")
with bb.function("encoding"):
with bb.function("prefill"):
model = LlamaForCausalLM(config)
input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
all_seq_len_shape = relax.Var(
@@ -584,15 +584,15 @@ def create_encoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
bb.emit_func_output(gv, params)

mod = bb.get()
gv = mod.get_global_var("encoding")
gv = mod.get_global_var("prefill")
bb.update_func(gv, mod[gv].with_attr("num_input", 3))


def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
bsz = 1
all_seq_len = tvm.tir.Var("n", "int64")

with bb.function("decoding"):
with bb.function("decode"):
model = LlamaForCausalLM(config)
input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids")
all_seq_len_shape = relax.Var(
@@ -617,7 +617,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
bb.emit_func_output(gv, params)

mod = bb.get()
gv = mod.get_global_var("decoding")
gv = mod.get_global_var("decode")
bb.update_func(gv, mod[gv].with_attr("num_input", 3))


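
The `gpt_neox.py` and `llama.py` hunks rename the Relax functions emitted at build time, while the `cpp/llm_chat.cc` hunks rename the corresponding lookups (`vm_->GetFunction("prefill")`, `vm_->GetFunction("decode")`), so both sides of this string-based contract move together in one PR. A mismatch, for example loading a stale artifact that still exports `encoding`/`decoding`, surfaces as a null function handle at load time rather than a compile error. A small, purely illustrative guard for that case (an assumed helper, not mlc-llm code):

```cpp
// Illustrative guard (assumed helper, not part of mlc-llm): the compiled model
// module must export functions under the names the runtime looks up, which
// after this PR are "prefill" and "decode".
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <stdexcept>
#include <string>

tvm::runtime::PackedFunc GetRequiredFunction(tvm::runtime::Module mod, const std::string& name) {
  tvm::runtime::PackedFunc f = mod->GetFunction(name);
  if (f == nullptr) {
    // A stale build still exporting "encoding"/"decoding" would land here.
    throw std::runtime_error("compiled module does not export '" + name +
                             "'; rebuild the model so its function names match the runtime");
  }
  return f;
}
```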