diff --git a/common/common.cpp b/common/common.cpp index 6b937f729ec2c..d748e69242126 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -424,6 +424,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); } else if (arg == "--hellaswag") { params.hellaswag = true; } else if (arg == "--hellaswag-tasks") { diff --git a/common/common.h b/common/common.h index 184092f064683..6e22177a6b49f 100644 --- a/common/common.h +++ b/common/common.h @@ -65,7 +65,9 @@ struct gpt_params { std::string lora_base = ""; // base model path for the lora adapter int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - // + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 22425e14fdf9f..e89725efc3db6 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -125,7 +125,11 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) { ++count; } // perplexity is e^(average negative log-likelihood) - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + if (params.ppl_output_type == 0) { + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + } else { + printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count)); + } fflush(stdout); } printf("\n"); @@ -226,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) { ++count; } // perplexity is e^(average negative log-likelihood) - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + if (params.ppl_output_type == 0) { + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + } else { + printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count)); + } fflush(stdout); } printf("\n");