Skip to content

Commit

Permalink
Get vm id from wslhost.exe command line
Browse files Browse the repository at this point in the history
  • Loading branch information
MrVinkel authored and Biswa96 committed Oct 12, 2024
1 parent 5b2b652 commit 837825b
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 3 deletions.
126 changes: 126 additions & 0 deletions src/GetVmIdWsl2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include <combaseapi.h>
#include <Windows.h>
#include <TlHelp32.h>
#include <winternl.h>
#include <ntstatus.h>
#include <psapi.h>
#include <vector>
#include <string>

#include "common.hpp"
#include "GetVmIdWsl2.hpp"

bool ExtractGUID(const std::wstring key, const std::wstring& commandLine, std::wstring& guid) {
size_t pos = commandLine.find(key);
if (pos != std::wstring::npos)
{
size_t start = commandLine.find(L'{', pos);
size_t end = commandLine.find(L'}', start);
if (start != std::wstring::npos && end != std::wstring::npos)
{
guid = commandLine.substr(start, end - start + 1);
return true;
}
}
return false;
}

bool GetCommandLineForPID(DWORD pid, std::wstring& commandLine)
{
HMODULE hNtdll = GetModuleHandle(L"ntdll.dll");
using NtQueryInformationProcessFunc = NTSTATUS(NTAPI*)(HANDLE, PROCESSINFOCLASS, PVOID, ULONG, PULONG);
NtQueryInformationProcessFunc NtQueryInformationProcess = (NtQueryInformationProcessFunc)GetProcAddress(hNtdll, "NtQueryInformationProcess");

if (!NtQueryInformationProcess)
return false;

// Open a handle to the process
HANDLE process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid);
if (process == NULL)
{
DWORD err = GetLastError();
fatal("failed to open the process, error: %d", err);
return false;
}
// Get the address of the PEB
PROCESS_BASIC_INFORMATION pbi = {};
NTSTATUS status = NtQueryInformationProcess(process, ProcessBasicInformation, &pbi, sizeof(pbi), NULL);
if (status != STATUS_SUCCESS)
{
CloseHandle(process);
fatal("failed to query the process, error: %d", status);
return false;
}
// Get the address of the process parameters in the PEB
PEB peb = {};
if (!ReadProcessMemory(process, pbi.PebBaseAddress, &peb, sizeof(peb), NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process PEB, error: %d", err);
return false;
}
// Get the command line arguments from the process parameters
RTL_USER_PROCESS_PARAMETERS params = {};
if (!ReadProcessMemory(process, peb.ProcessParameters, &params, sizeof(params), NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process params, error: %d", err);
return false;
}
UNICODE_STRING &commandLineArgs = params.CommandLine;
std::vector<WCHAR> buffer(commandLineArgs.Length / sizeof(WCHAR));
if (!ReadProcessMemory(process, commandLineArgs.Buffer, buffer.data(), commandLineArgs.Length, NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process command line, error: %d", err);
return false;
}

CloseHandle(process);
commandLine.assign(buffer.data(), buffer.size());
return true;
}

std::vector<DWORD> GetProcessIDsByName(const std::wstring& processName) {
std::vector<DWORD> processIDs;
HANDLE hProcessSnap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
PROCESSENTRY32 pe32;
pe32.dwSize = sizeof(PROCESSENTRY32);

if (Process32First(hProcessSnap, &pe32))
{
do
{
if (pe32.szExeFile == processName)
{
processIDs.push_back(pe32.th32ProcessID);
}
} while (Process32Next(hProcessSnap, &pe32));
}
CloseHandle(hProcessSnap);
return processIDs;
}

// Extract GUID from wslHost.exe command line
// Example commandline:
// wslhost.exe --vm-id {f6446e02-236e-4b24-9916-2d4ad9a1096f} --handle 1664
bool GetVmIdWsl2(GUID& vmId) {
std::vector<DWORD> pids = GetProcessIDsByName(L"wslhost.exe");
for (DWORD pid : pids) {
std::wstring cmdLine;
if (!GetCommandLineForPID(pid, cmdLine))
continue;

std::wstring cmdVmId;
if(!ExtractGUID(L"--vm-id", cmdLine, cmdVmId))
continue;

if (IIDFromString(cmdVmId.c_str(), &vmId) == S_OK)
return true;
}
return false;
}

8 changes: 8 additions & 0 deletions src/GetVmIdWsl2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <Windows.h>

#ifndef VMIDWSL2_HPP
#define VMIDWSL2_HPP

bool GetVmIdWsl2(GUID& vmId);

#endif /* VMIDWSL2_HPP */
4 changes: 4 additions & 0 deletions src/Makefile.frontend
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ LIBS = -lole32 -lws2_32
OBJS = \
$(BINDIR)/common.obj \
$(BINDIR)/GetVmId.obj \
$(BINDIR)/GetVmIdWsl2.obj \
$(BINDIR)/Helpers.obj \
$(BINDIR)/TerminalState.obj \
$(BINDIR)/windows-sock.obj \
Expand Down Expand Up @@ -53,6 +54,9 @@ $(BINDIR)/windows-sock.obj : windows-sock.c
$(BINDIR)/wslbridge2.obj : wslbridge2.cpp
$(CXX) -c $(CXXFLAGS) $(CCOPT) $< -o $@

$(BINDIR)/GetVmIdWsl2.obj : GetVmIdWsl2.cpp
$(CXX) -c $(CXXFLAGS) $(CCOPT) $< -o $@

$(BINDIR) :
mkdir -p $(BINDIR)

Expand Down
7 changes: 4 additions & 3 deletions src/wslbridge2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Environment.hpp"
#include "TerminalState.hpp"
#include "windows-sock.h"
#include "GetVmIdWsl2.hpp"

union IoSockets
{
Expand Down Expand Up @@ -384,12 +385,12 @@ int main(int argc, char *argv[])
{
// wsltty#302: Start dummy process after ComInit, otherwise RPC_E_TOO_LATE.
// wslbridge2#38: Do this only for WSL2 as WSL1 does not need the VM context.
// wslbridge2#42: Required for WSL2 to get the VM ID.
if (LiftedWSLVersion)
start_dummy(wslPath, wslCmdLine, distroName, debugMode);

const HRESULT hRes = GetVmId(&DistroId, &VmId, LiftedWSLVersion);
if (hRes != 0)
fatal("GetVmId: %s\n", GetErrorMessage(hRes).c_str());
if (!GetVmIdWsl2(VmId))
fatal("Failed to get VM ID");

inputSock = win_vsock_create();
outputSock = win_vsock_create();
Expand Down

0 comments on commit 837825b

Please # to comment.