diff --git a/examples/common.cpp b/examples/common.cpp index ad7b0bba32f1f..d8c77fcbdcbbc 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -15,18 +15,12 @@ #endif #if defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); #define CP_UTF8 65001 #endif @@ -60,7 +54,71 @@ int32_t get_num_physical_cores() { return num_physical_cores; } #elif defined(_WIN32) - //TODO: Implement + // Call GetLogicalProcessorInformationEx with a nullptr buffer to get the required buffer length + DWORD length = 0; + GetLogicalProcessorInformationEx(RelationAll, nullptr, &length); + + // Allocate memory for the buffer + std::unique_ptr buffer_ptr(new char[length]); + char* buffer = buffer_ptr.get(); + + // Things to count + unsigned int physical_cores = 0; + unsigned int physical_performance_cores = 0; + unsigned int physical_efficiency_cores = 0; + unsigned int logical_cores = 0; + unsigned int logical_performance_cores = 0; + unsigned int logical_efficiency_cores = 0; + + // Call GetLogicalProcessorInformationEx again with the allocated buffer + if (GetLogicalProcessorInformationEx( + RelationAll, + reinterpret_cast(buffer), + &length)) { + DWORD offset = 0; + + while (offset < length) { + auto info = reinterpret_cast(buffer + offset); + + if (info->Relationship == RelationProcessorCore) { + physical_cores += info->Processor.GroupCount; + + for (WORD i = 0; i < info->Processor.GroupCount; ++i) { + int core_count = static_cast(__popcnt64(info->Processor.GroupMask[i].Mask)); + logical_cores += core_count; + + // Assuming EfficiencyClass 0 represents performance cores, and others represent efficiency cores + if (info->Processor.EfficiencyClass == 0) { + physical_performance_cores++; + logical_performance_cores += core_count; + } else { + physical_efficiency_cores++; + logical_efficiency_cores += core_count; + } + } + } + offset += info->Size; + } + + // TODO: Remove this once we've verified it's working + fprintf(stderr, + "Physical Cores: %u\n" + " - Performance Cores: %u\n" + " - Efficiency Cores: %u\n" + "Logical Cores: %u\n" + " - Performance Cores: %u\n" + " - Efficiency Cores: %u\n", + physical_cores, physical_performance_cores, physical_efficiency_cores, + logical_cores, logical_performance_cores, logical_efficiency_cores); + } else { + fprintf(stderr, "Failed to get processor information. Error: %u\n", GetLastError()); + } + + if (physical_performance_cores > 0) { + return static_cast(physical_performance_cores); + } else if (physical_cores > 0) { + return static_cast(physical_cores); + } #endif unsigned int n_threads = std::thread::hardware_concurrency(); return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; @@ -69,7 +127,6 @@ int32_t get_num_physical_cores() { bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; - gpt_params default_params; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -287,7 +344,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } params.n_parts = std::stoi(argv[i]); } else if (arg == "-h" || arg == "--help") { - gpt_print_usage(argc, argv, default_params); + gpt_print_usage(argc, argv); exit(0); } else if (arg == "--random-prompt") { params.random_prompt = true; @@ -299,20 +356,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.input_prefix = argv[i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - gpt_print_usage(argc, argv, default_params); + gpt_print_usage(argc, argv); exit(1); } } if (invalid_param) { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); - gpt_print_usage(argc, argv, default_params); + gpt_print_usage(argc, argv); exit(1); } return true; } -void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { +void gpt_print_usage(int /*argc*/, char ** argv) { + gpt_params params; + fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); diff --git a/examples/common.h b/examples/common.h index 627696e30a4f6..49385711d3875 100644 --- a/examples/common.h +++ b/examples/common.h @@ -67,7 +67,7 @@ struct gpt_params { bool gpt_params_parse(int argc, char ** argv, gpt_params & params); -void gpt_print_usage(int argc, char ** argv, const gpt_params & params); +void gpt_print_usage(int argc, char ** argv); std::string gpt_random_prompt(std::mt19937 & rng);