Skip to content

Commit

Permalink
Port apache/cloudstack#1533 that converts pacthviasocket to python
Browse files Browse the repository at this point in the history
Original commits:
 * 0acd3c1: Convert patchviasocket to python (removes perl dependency for KVM agent)
 * 751d355: patchviasocket improve error handling
  • Loading branch information
Sverrir A. Berg authored and miguelaferreira committed Aug 17, 2016
1 parent de64250 commit f47d4b9
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ public static class Constants {
public static final String DEFAULT_VM_RNG_PATH = "/dev/random";
public static final String SCRIPT_MODIFY_VLAN = "modifyvlan.sh";
public static final String SCRIPT_VERSIONS = "versions.sh";
public static final String SCRIPT_PATCH_VIA_SOCKET = "patchviasocket.pl";
public static final String SCRIPT_PATCH_VIA_SOCKET = "patchviasocket.py";
public static final String SCRIPT_KVM_HEART_BEAT = "kvmheartbeat.sh";
public static final String SCRIPT_CREATE_VM = "createvm.sh";
public static final String SCRIPT_MANAGE_SNAPSHOT = "managesnapshot.sh";
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#
# This script connects to the system vm socket and writes the
# authorized_keys and cmdline data to it. The system VM then
# reads it from /dev/vport0p1 in cloud_early_config
#

import argparse
import os
import socket

SOCK_FILE = "/var/lib/libvirt/qemu/{name}.agent"
PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud"
MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n"


def send_to_socket(sock_file, key_file, cmdline):
if not os.path.exists(key_file):
print("ERROR: ssh public key not found on host at {0}".format(key_file))
return 1

try:
with open(key_file, "r") as f:
pub_key = f.read()
except IOError as e:
print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror))
return 1

# Keep old substitution from perl code:
cmdline = cmdline.replace("%", " ")

msg = MESSAGE.format(key=pub_key, cmdline=cmdline)

if not os.path.exists(sock_file):
print("ERROR: {0} socket not found".format(sock_file))
return 1

try:
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.connect(sock_file)
s.sendall(msg)
s.close()
except IOError as e:
print("ERROR: unable to connect to {0} - {1}".format(sock_file, e.strerror))
return 1

return 0 # Success


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Send configuration to system VM socket")
parser.add_argument("-n", "--name", required=True, help="Name of VM")
parser.add_argument("-p", "--cmdline", required=True, help="Command line")

arguments = parser.parse_args()

socket_file = SOCK_FILE.format(name=arguments.name)

exit(send_to_socket(socket_file, PUB_KEY_FILE, arguments.cmdline))
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import patchviasocket

import getpass
import os
import socket
import tempfile
import time
import threading
import unittest

KEY_DATA = "I luv\nCloudStack\n"
CMD_DATA = "/run/this-for-me --please=TRUE! very%quickly"
NON_EXISTING_FILE = "must-not-exist"


def write_key_file():
_, tmpfile = tempfile.mkstemp(".sck")
with open(tmpfile, "w") as f:
f.write(KEY_DATA)
return tmpfile


class SocketThread(threading.Thread):
def __init__(self):
super(SocketThread, self).__init__()
self._data = ""
self._folder = tempfile.mkdtemp(".sck")
self._file = os.path.join(self._folder, "socket")
self._ready = False

def data(self):
return self._data

def file(self):
return self._file

def wait_until_ready(self):
while not self._ready:
time.sleep(0.050)

def run(self):
TIMEOUT = 0.314 # Very short time for tests that don't write to socket.
MAX_SIZE = 10 * 1024

s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
s.bind(self._file)
s.listen(1)
s.settimeout(TIMEOUT)
try:
self._ready = True
client, address = s.accept()
self._data = client.recv(MAX_SIZE)
client.close()
except socket.timeout:
pass
finally:
s.close()
os.remove(self._file)
os.rmdir(self._folder)


class TestPatchViaSocket(unittest.TestCase):
def setUp(self):
self._key_file = write_key_file()

self._unreadable = write_key_file()
os.chmod(self._unreadable, 0)

self.assertFalse(os.path.exists(NON_EXISTING_FILE))
self.assertNotEqual("root", getpass.getuser(), "must be non-root user (to test access denied errors)")

def tearDown(self):
os.remove(self._key_file)
os.remove(self._unreadable)

def test_write_to_socket(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(0, patchviasocket.send_to_socket(reader.file(), self._key_file, CMD_DATA))
reader.join()
data = reader.data()
self.assertIn(KEY_DATA, data)
self.assertIn(CMD_DATA.replace("%", " "), data)
self.assertNotIn("LUV", data)
self.assertNotIn("very%quickly", data) # Testing substitution

def test_host_key_error(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA))
reader.join() # timeout

def test_host_key_access_denied(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA))
reader.join() # timeout

def test_nonexistant_socket_error(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(NON_EXISTING_FILE, self._key_file, CMD_DATA))
reader.join() # timeout

def test_invalid_socket_error(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(self._key_file, self._key_file, CMD_DATA))
reader.join() # timeout

def test_access_denied_socket_error(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(self._unreadable, self._key_file, CMD_DATA))
reader.join() # timeout


if __name__ == '__main__':
unittest.main()

0 comments on commit f47d4b9

Please # to comment.