#!/usr/bin/python3
#
# bolt integration test suite
#
# Copyright © 2017 Red Hat, Inc
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
# Authors:
#       Christian J. Kellner <christian@kellner.me>

import binascii
import os
import shutil
import sys
import subprocess
import unittest
import uuid
import tempfile
import time

from collections import namedtuple
from itertools import chain

try:
    import gi
    from gi.repository import GLib
    from gi.repository import Gio
    gi.require_version('UMockdev', '1.0')
    from gi.repository import UMockdev

    import dbus
    import dbusmock
except ImportError as e:
    sys.stderr.write('Skipping integration test due to missing depdendencies: %s\n' % str(e))
    sys.exit(0)


try:
    from subprocess import DEVNULL
except ImportError:
    DEVNULL = open(os.devnull, 'wb')

# Well-known bus name and object path used when creating proxies for the daemon.
DBUS_NAME = 'org.freedesktop.bolt'
DBUS_PATH = '/org/freedesktop/bolt'
# Interface names are built from a common, versioned prefix.
DBUS_IFACE_PREFIX = 'org.freedesktop.bolt1.'
DBUS_IFACE_MANAGER = DBUS_IFACE_PREFIX + 'Manager'
DBUS_IFACE_DEVICE = DBUS_IFACE_PREFIX + 'Device'
# D-Bus service file; its "Exec=" line is used to locate the installed daemon.
SERVICE_FILE = '/usr/share/dbus-1/system-services/org.freedesktop.bolt.service'


def get_timeout(topic='default'):
    """Return the timeout, in seconds, to use for ``topic``.

    Known topics are ``'default'`` and ``'daemon_start'``. A table with
    larger values is used whenever the ``VALGRIND`` environment variable
    is set (to any value, including the empty string).

    Raises:
        ValueError: if ``topic`` is not one of the known topics.
    """
    if os.getenv('VALGRIND') is None:
        table = {'default': 3, 'daemon_start': 5}
    else:
        table = {'default': 20, 'daemon_start': 60}

    if topic in table:
        return table[topic]
    raise ValueError('invalid topic')


class Signal(object):
    """A minimal signal/slot helper that can be bridged to another object.

    Callbacks are kept in a set and invoked by ``emit``. A signal can also
    be "bridged" to a signal of another object, i.e. anything offering
    ``connect(name, handler)`` returning a handler id and
    ``disconnect(handler_id)``. The upstream handler is installed when the
    first callback is connected and removed again when the last one is
    disconnected.
    """

    def __init__(self, name):
        self.name = name
        self.callbacks = set()
        # optional hook, called as notify(signal, 'connect'|'disconnect', count)
        self.notify = None
        # bridge description: 'object', 'name', optionally 'filter', and
        # 'signal_id' while the upstream handler is installed
        self._bridge = None

    def connect(self, callback):
        """Register ``callback``; installs the bridge for the first one."""
        self.callbacks.add(callback)
        if self.notify is not None:
            self.notify(self, 'connect', len(self.callbacks))
        if len(self.callbacks) == 1:
            self._bridge_build()

    def disconnect(self, callback):
        """Remove ``callback`` (KeyError if unknown); drops the bridge
        handler once no callbacks are left."""
        self.callbacks.remove(callback)
        if self.notify is not None:
            self.notify(self, 'disconnect', len(self.callbacks))
        if len(self.callbacks) == 0:
            self._bridge_destory()

    def disconnect_all(self):
        """Remove every callback and the upstream bridge handler, if any."""
        self.callbacks = set()
        if self.notify is not None:
            self.notify(self, 'disconnect', 0)
        self._bridge_destory()

    def emit(self, *args, **kwargs):
        """Call all callbacks; return True if any returned a truthy value."""
        # Iterate over a snapshot so that a callback may connect or
        # disconnect handlers (including itself) while being invoked.
        res = [cb(*args, **kwargs) for cb in list(self.callbacks)]
        return any(res)

    def bridge(self, obj, name, callback):
        """Describe the upstream signal ``name`` of ``obj`` to forward.

        ``callback``, if not None, is a filter called as
        ``callback(args, kwargs)`` and must return ``(ok, args, kwargs)``;
        the emission is dropped when ``ok`` is falsy.

        Raises:
            ValueError: if the signal is already bridged.
        """
        if self._bridge is not None:
            raise ValueError('already bridged')
        self._bridge = {'object': obj,
                        'name': name}
        if callback is not None:
            self._bridge['filter'] = callback

    def bridge_destroy(self):
        """Forget the bridge, removing an installed upstream handler first."""
        self._bridge_destory()
        self._bridge = None

    # historical (misspelled) name, kept so existing callers keep working
    birdge_destroy = bridge_destroy

    def _bridge_build(self):
        """Install the upstream handler unless there is no bridge or it is
        already installed."""
        b = self._bridge
        if b is None or 'signal_id' in b:
            return
        b['signal_id'] = b['object'].connect(b['name'], self._bridge_signal)

    def _bridge_destory(self):
        """Remove the upstream handler; a no-op if it is not installed."""
        b = self._bridge
        if b is None or 'signal_id' not in b:
            return
        b['object'].disconnect(b.pop('signal_id'))

    def _bridge_signal(self, *args, **kwargs):
        """Upstream handler: apply the optional filter, then emit."""
        b = self._bridge
        if b is not None and 'filter' in b:
            res, args, kwargs = b['filter'](args, kwargs)
            if not res:
                return
        return self.emit(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.emit(*args, **kwargs)

    def __iadd__(self, callback):
        self.connect(callback)
        return self

    def __isub__(self, callback):
        self.disconnect(callback)
        return self

    @staticmethod
    def enable(klass):
        """Class decorator: install a property for each name in
        ``klass.signals``.

        Each property lazily creates one ``Signal`` per instance and name.
        Afterwards ``klass.signals`` is the union of the signals of the
        direct bases and the class's own (order is not preserved).
        """
        lst = getattr(klass, 'signals', [])
        methods = [m for m in dir(klass) if not m.startswith('__')]

        def install(signame):
            if signame is None:
                return
            if signame in methods:
                print('WARNING: signal "%s" will overwrite method' % signame, file=sys.stderr)

            def get_signals(self):
                signals = getattr(self, '__signals', None)
                if signals is None:
                    signals = {}
                    setattr(self, '__signals', signals)
                return signals

            def get_signal(self):
                signals = get_signals(self)
                if signame not in signals:
                    signals[signame] = Signal(signame)
                return signals[signame]

            def getter(self):
                return get_signal(self)

            def setter(self, value):
                # The value is ignored on purpose: "obj.sig += cb" assigns
                # the Signal returned by __iadd__ back to the property.
                return get_signal(self)

            p = property(getter, setter)
            setattr(klass, signame, p)
            return signame

        bases = klass.__bases__
        ps = {s for b in bases for s in getattr(b, 'signals', [])}
        klass.signals = list(ps.union({install(s) for s in lst}))
        return klass


@Signal.enable
class Recorder(object):
    """Record property changes and D-Bus signals seen on a proxy wrapper.

    ``target`` must offer the ``g_properties_changed`` and ``g_signal``
    signals. Every observation is appended to the ``events`` list as an
    ``Event`` and re-emitted through the ``event`` signal. The recorder
    can be used as a context manager; leaving the block calls ``close``.
    """

    Event = namedtuple('Event', ['what', 'name', 'details', 'time'])

    signals = ['event']

    def __init__(self, target):
        self.recording = True
        self.events = []
        self.target = target
        self.target.g_properties_changed += self._on_props_changed
        self.target.g_signal += self._on_signal

    def close(self):
        """Stop recording (idempotent) and return the recorded events."""
        if self.recording:
            self.target.g_properties_changed -= self._on_props_changed
            self.target.g_signal -= self._on_signal
            self.recording = False
        return self.events

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _on_props_changed(self, props):
        # one 'property' event per changed property, sharing one timestamp
        now = time.time()
        for k, v in props.items():
            e = self.Event('property', k, v, now)
            self._add_event(e)

    def _on_signal(self, proxy, sender, signal, params):
        now = time.time()
        e = self.Event('signal', signal, None, now)
        self._add_event(e)

    def _add_event(self, event):
        self.events.append(event)
        self.event.emit(event)

    @staticmethod
    def event_match(e, target):
        """Return True if event ``e`` matches the template ``target``.

        ``what`` and ``name`` must be equal; ``details`` is only compared
        when the template's ``details`` is not None.
        """
        if e.what != target.what or e.name != target.name:
            return False
        if target.details is None:
            return True
        return e.details == target.details

    def find_events(self, what, name, details=None):
        """Return the recorded events matching ``what``/``name``/``details``.

        This replaces the former ``events()`` method, which could never be
        called on an instance because the ``events`` list shadows it.
        """
        t = self.Event(what, name, details, None)
        return [e for e in self.events if self.event_match(e, t)]

    def have_event(self, what, name, details=None):
        """Return True if at least one matching event was recorded."""
        return len(self.find_events(what, name, details=details)) > 0

    def wait_for_events(self, lst, timeout=None):
        """Run a main loop until every template in ``lst`` was matched.

        Matched templates are removed from ``lst`` in place. Only events
        arriving while the loop runs are considered. ``timeout`` is in
        seconds and defaults to ``get_timeout()``.

        Returns True if all templates were matched, False on timeout.
        """
        loop = GLib.MainLoop()
        state = {'timed_out': False}

        def got_event(event):
            for idx, e in enumerate(lst):
                # the incoming event is matched against the template
                if self.event_match(event, e):
                    del lst[idx]
                    break
            if len(lst) == 0:
                loop.quit()

        def got_timeout():
            print('WARNING: timeout reached! Waiting list: ',
                  str(lst), file=sys.stderr)
            state['timed_out'] = True
            loop.quit()
            return False  # one-shot: GLib drops the source

        self.event += got_event
        timeout = timeout or get_timeout()
        timeout_id = GLib.timeout_add(int(timeout * 1000), got_timeout)
        try:
            loop.run()
        finally:
            self.event -= got_event
            # Remove the pending timeout so it cannot fire, and print a
            # bogus warning, while some later main loop is running.
            if not state['timed_out']:
                GLib.source_remove(timeout_id)
        return len(lst) == 0

    def wait_for_event(self, what, name, details=None):
        """Wait for a single event; see ``wait_for_events``."""
        t = self.Event(what, name, details, None)
        return self.wait_for_events([t])


@Signal.enable
class ProxyWrapper(object):
    """Thin convenience wrapper around a ``Gio.DBusProxy`` for ``DBUS_NAME``.

    Attribute access is translated to D-Bus naming (``sysfs_path`` becomes
    ``SysfsPath``, ``uid`` becomes ``Uid``). Cached properties are returned
    unpacked; everything else is forwarded to the underlying proxy.
    """

    signals = ['g_properties_changed', 'g_signal']

    def __init__(self, bus, iname, path):
        flags = Gio.DBusProxyFlags.DO_NOT_AUTO_START
        self._proxy = Gio.DBusProxy.new_sync(bus, flags, None,
                                             DBUS_NAME, path, iname, None)

        def unpack_changed(args, kwargs):
            # hand the unpacked "changed properties" argument to listeners
            changed = args[1].unpack()
            return True, [changed], {}

        self.g_properties_changed.bridge(self._proxy,
                                         'g-properties-changed',
                                         unpack_changed)
        self.g_signal.bridge(self._proxy, 'g-signal', None)

    def __getattr__(self, name):
        # private names are never forwarded to the proxy
        if name.startswith('_'):
            raise AttributeError

        if '_' in name:
            name = "".join(part.title() for part in name.split('_'))
        else:
            name = name[0].upper() + name[1:]

        proxy = self._proxy
        if name not in proxy.get_cached_property_names():
            return getattr(proxy, name)

        value = proxy.get_cached_property(name)
        return value if value is None else value.unpack()

    def record(self):
        """Return a new ``Recorder`` attached to this wrapper."""
        return Recorder(self)

    @property
    def object_path(self):
        """The D-Bus object path of the wrapped proxy."""
        return self._proxy.get_object_path()


class BoltDevice(ProxyWrapper):
    """Client-side view of one device object exported by the daemon."""

    # numeric device states, ordered so that they can be compared
    UNKNOWN = -1
    DISCONNECTED = 0
    CONNECTED = 1
    AUTHORIZING = 2
    AUTH_ERROR = 3
    AUTHORIZED = 4
    AUTHORIZED_SECURE = 5
    AUTHORIZED_NEWKEY = 6

    # numeric key states
    KEY_MISSING = 0
    KEY_HAVE = 1
    KEY_NEW = 2

    def __init__(self, bus, path):
        super(BoltDevice, self).__init__(bus,
                                         DBUS_IFACE_DEVICE,
                                         path)

    @property
    def is_connected(self):
        """True if the status is above ``DISCONNECTED``."""
        return self.status > self.DISCONNECTED

    @property
    def is_authorized(self):
        """True if the status is ``AUTHORIZED`` or above."""
        return self.status >= self.AUTHORIZED

    def authorize(self, flags=""):
        """Call the ``Authorize`` D-Bus method; returns True on success
        (errors surface as exceptions raised by the proxy call)."""
        self.Authorize('(s)', flags)
        return True

    @property
    def status(self):
        """The ``Status`` property mapped to one of the numeric states;
        unrecognised strings map to ``UNKNOWN``."""
        res = getattr(self, 'Status')
        mapping = {'unknown': self.UNKNOWN,
                   'disconnected': self.DISCONNECTED,
                   'connected': self.CONNECTED,
                   'authorizing': self.AUTHORIZING,
                   'auth-error': self.AUTH_ERROR,
                   'authorized': self.AUTHORIZED,
                   'authorized-secure': self.AUTHORIZED_SECURE,
                   'authorized-newkey': self.AUTHORIZED_NEWKEY}

        return mapping.get(res, self.UNKNOWN)

    @property
    def key(self):
        """The ``Key`` property mapped to a ``KEY_*`` constant;
        unrecognised strings map to ``KEY_MISSING``."""
        res = getattr(self, 'Key')
        mapping = {'missing': self.KEY_MISSING,
                   'have': self.KEY_HAVE,
                   'new': self.KEY_NEW}

        # The fallback used to be the nonexistent ``self.MISSING``, which
        # __getattr__ forwarded to the D-Bus proxy on every access.
        return mapping.get(res, self.KEY_MISSING)

    @property
    def label(self):
        """The ``Label`` property; an empty label is reported as None."""
        res = getattr(self, 'Label')
        if res is not None and len(res) < 1:
            return None
        return res

    @label.setter
    def label(self, value):
        # plain strings are wrapped; a ready-made GLib.Variant passes through
        if isinstance(value, str):
            value = GLib.Variant("s", value)

        res = self._proxy.call_sync('org.freedesktop.DBus.Properties.Set',
                                    GLib.Variant("(ssv)",
                                                 (DBUS_IFACE_DEVICE,
                                                  "Label",
                                                  value)),
                                    0,
                                    -1,
                                    None)
        return res is not None


@Signal.enable
class BoltClient(ProxyWrapper):
    """Client-side view of the daemon's manager object.

    Translates the ``DeviceAdded`` / ``DeviceRemoved`` D-Bus signals into
    the ``device_added`` / ``device_removed`` signals of this object.
    """

    signals = ['device_added', 'device_removed']

    POLICY_DEFAULT = 'default'
    POLICY_MANUAL = 'manual'
    POLICY_AUTO = 'auto'

    def __init__(self, bus):
        super(BoltClient, self).__init__(bus,
                                         DBUS_IFACE_MANAGER,
                                         DBUS_PATH)
        self._proxy.connect('g-signal', self._on_dbus_signal)

    def _on_dbus_signal(self, proxy, sender, signal, params):
        """Dispatch a raw D-Bus signal; returns True if it was handled."""
        connection = self._proxy.get_connection()
        if signal == 'DeviceAdded':
            added = BoltDevice(connection, params[0])
            self.device_added.emit(added)
            return True
        if signal == 'DeviceRemoved':
            self.device_removed.emit(params[0])
            return True
        return False

    def _wrap(self, object_path):
        """Create a ``BoltDevice`` for ``object_path`` on our connection."""
        return BoltDevice(self._proxy.get_connection(), object_path)

    def list_devices(self):
        """Return a ``BoltDevice`` per device known to the daemon, or None
        if the ``ListDevices`` call returned None."""
        paths = self.ListDevices()
        if paths is None:
            return None
        connection = self._proxy.get_connection()
        return [BoltDevice(connection, path) for path in paths]

    def device_by_uid(self, uid):
        """Look up a device by its unique id via ``DeviceByUid``."""
        path = self.DeviceByUid("(s)", uid)
        return None if path is None else self._wrap(path)

    def enroll(self, uid, policy=POLICY_DEFAULT, flags=""):
        """Enroll the device ``uid`` via ``EnrollDevice`` and wrap the
        returned object path."""
        path = self.EnrollDevice("(sss)", uid, policy, flags)
        return None if path is None else self._wrap(path)

    def forget(self, uid):
        """Call ``ForgetDevice`` for ``uid``; returns True on success."""
        self.ForgetDevice("(s)", uid)
        return True


# Mock Device Tree
@Signal.enable
class Device(object):
    """Node of the mock device tree that is mirrored into a umockdev testbed.

    Subclasses set ``subsystem`` and list, in ``udev_attrs`` and
    ``udev_props``, the names whose values are read from the instance
    attribute of the same name in lower case. The ``device_connected`` and
    ``device_disconnected`` signals are always emitted on the root of the
    tree, with the affected device as argument.
    """

    subsystem = "unknown"
    udev_attrs = []
    udev_props = []

    signals = ['device_connected',
               'device_disconnected']

    def __init__(self, name, children):
        self._parent = None
        self.children = [self._adopt(c) for c in children]
        self.udev = None
        self.name = name
        # sysfs path inside the testbed; None while not connected
        self.syspath = None
        # testbed the device is connected to; None while not connected
        self.testbed = None

    def _adopt(self, device):
        device.parent = self
        return device

    def _get_own(self, items):
        """Flatten ``items`` into ``[name, str(value), name, str(value), ...]``
        with values taken from ``getattr(self, name.lower())``."""
        i = chain.from_iterable([a, str(getattr(self, a.lower()))] for a in items)
        return list(i)

    def collect(self, predicate):
        """Return all devices of this subtree (self first, depth-first)
        for which ``predicate`` is true."""
        head = [self] if predicate(self) else []
        # children already return filtered results, no need to filter again
        tail = chain.from_iterable(c.collect(predicate) for c in self.children)
        return head + list(tail)

    def first(self, predicate):
        """Return the first device (depth-first, self included) matching
        ``predicate``, or None."""
        if predicate(self):
            return self
        for c in self.children:
            found = c.first(predicate)
            if found:
                return found
        return None

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, value):
        self._parent = value

    @property
    def root(self):
        """The topmost ancestor (self if there is no parent)."""
        return self if self.parent is None else self.parent.root

    def connect_tree(self, bed):
        """Connect this device and then, recursively, all its children."""
        self.connect(bed)
        for c in self.children:
            c.connect_tree(bed)

    def connect(self, bed):
        """Add this device to the testbed ``bed`` and emit
        ``device_connected`` on the root. Must not be connected already."""
        print('connecting ' + self.name, file=sys.stderr)
        assert self.syspath is None
        attributes = self._get_own(self.udev_attrs)
        properties = self._get_own(self.udev_props)
        sysparent = self.parent and self.parent.syspath
        self.syspath = bed.add_device(self.subsystem,
                                      self.name,
                                      sysparent,
                                      attributes,
                                      properties)
        # record the testbed first so signal handlers see a complete device
        self.testbed = bed
        self.root.device_connected(self)

    def disconnect(self, bed):
        """Remove the children and then this device from ``bed``, reset
        the authorization state and emit ``device_disconnected``."""
        print('disconnecting ' + self.name, file=sys.stderr)
        for c in self.children:
            c.disconnect(bed)
        bed.uevent(self.syspath, "remove")
        bed.remove_device(self.syspath)
        self.authorized = 0
        self.key = ""
        self.root.device_disconnected(self)
        self.syspath = None
        self.testbed = None


class TbDevice(Device):
    """Mock thunderbolt device."""

    subsystem = "thunderbolt"
    devtype = "thunderbolt_device"

    udev_attrs = ['authorized',
                  'device',
                  'device_name',
                  'key',
                  'unique_id',
                  'vendor',
                  'vendor_name']

    udev_props = ['DEVTYPE']

    def __init__(self, name, authorized=0, vendor=None, uid=None, children=None):
        super(TbDevice, self).__init__(name, children or [])
        # a random uuid and the 'GNOME.org' vendor are used when not given
        self.unique_id = uid or str(uuid.uuid4())
        self.device_name = 'Thunderbolt ' + name
        self.device = self._make_id(self.device_name)
        self.vendor_name = vendor or 'GNOME.org'
        self.vendor = self._make_id(self.vendor_name)
        self.authorized = authorized
        self.key = ""

    def _make_id(self, name):
        """Derive a hexadecimal id string from the CRC32 of ``name``."""
        return '0x%X' % binascii.crc32(name.encode('utf-8'))

    @property
    def authorized_file(self):
        """Path of the ``authorized`` sysfs attribute, None if not connected."""
        if self.syspath is None:
            return None
        return os.path.join(self.syspath, 'authorized')

    @property
    def bolt_status(self):
        """The ``BoltDevice`` status expected for the current local state.

        Returns None for ``authorized`` values other than 0, 1 and 2.
        """
        if self.syspath is None:
            return BoltDevice.DISCONNECTED
        elif self.authorized == 0:
            return BoltDevice.CONNECTED
        elif self.authorized == 1:
            if self.key == "":
                return BoltDevice.AUTHORIZED
            else:
                return BoltDevice.AUTHORIZED_NEWKEY
        elif self.authorized == 2:
            return BoltDevice.AUTHORIZED_SECURE

    @property
    def domain(self):
        return self.parent.domain

    @staticmethod
    def is_unauthorized(d):
        """True if ``d`` is a ``TbDevice`` whose ``authorized`` is 0."""
        return isinstance(d, TbDevice) and d.authorized == 0

    def reload_auth(self):
        """Re-read ``authorized`` and ``key`` from sysfs.

        A 'change' uevent is sent through the testbed if either value
        differs from the previous one. The device must be connected: both
        files are opened relative to ``syspath``.
        """
        authorized = self.authorized
        key = self.key
        with open(self.authorized_file, 'r') as f:
            self.authorized = int(f.read())
        with open(os.path.join(self.syspath, 'key'), 'r') as f:
            self.key = f.read().strip()
        # syspath is known to be set here, the files above were just read
        if self.authorized != authorized or self.key != key:
            self.testbed.uevent(self.syspath, 'change')


class TbHost(TbDevice):
    """The host controller: a ``TbDevice`` named 'Laptop' with a fixed
    unique id that is (re-)authorized every time it is connected."""

    def __init__(self, children):
        host_args = {
            'authorized': 1,
            'uid': '3b7d4bad-4fdf-44ff-8730-ffffdeadbabe',
            'children': children,
        }
        super(TbHost, self).__init__('Laptop', **host_args)

    def connect(self, bed):
        # a disconnect resets "authorized"; restore it before connecting
        self.authorized = 1
        super(TbHost, self).connect(bed)


class TbDomain(Device):
    """Mock thunderbolt domain: the root of a tree, with the host as its
    only child."""

    subsystem = "thunderbolt"
    devtype = "thunderbolt_domain"

    udev_attrs = ['security']
    udev_props = ['DEVTYPE']

    SECURITY_NONE = 'none'
    SECURITY_USER = 'user'
    SECURITY_SECURE = 'secure'

    def __init__(self, security=SECURITY_SECURE, index=0, host=None):
        # the host is mandatory although it has to be passed by keyword
        assert host
        assert isinstance(host, TbHost)
        super(TbDomain, self).__init__('domain%d' % index, children=[host])
        self.security = security

    @property
    def devices(self):
        """All ``TbDevice`` instances of the tree, host included."""
        def is_tb_device(node):
            return isinstance(node, TbDevice)

        return self.collect(is_tb_device)

    @property
    def domain(self):
        return self

    @staticmethod
    def checkattr(d, k, v):
        """True if ``d`` has an attribute ``k`` that equals ``v``."""
        if not hasattr(d, k):
            return False
        return getattr(d, k) == v

    def find(self, **kwargs):
        """Return the first node whose attributes equal all of the given
        keyword arguments, or None."""
        def matches(node):
            checks = [self.checkattr(node, key, want)
                      for key, want in kwargs.items()]
            return all(checks)

        return self.first(matches)


class TreeChecker(object):
    """Match local (udev) connect/disconnect actions against the
    ``device_added`` / ``device_removed`` signals of the client.

    Every local action on a ``TbDevice`` is noted in ``actions`` (keyed by
    unique id) and removed again once the corresponding remote signal
    arrives. ``sync`` runs a main loop until no action is pending or the
    timeout expires.
    """

    def __init__(self, client, tree):
        self.client = client
        self.tree = tree
        self.remote_devices = {}
        self.actions = {}
        self.loop = None
        # id of the pending GLib timeout source while sync() is running
        self._timeout_id = None
        # list_devices() may return None
        for dev in client.list_devices() or []:
            self._register_device(dev)
        client.device_added += self._on_device_added
        client.device_removed += self._on_device_removed
        tree.device_connected += self._on_udev_connected
        tree.device_disconnected += self._on_udev_disconnected

    # signal handler for local (udev) devices
    def _on_udev_connected(self, dev):
        if not isinstance(dev, TbDevice):
            return

        uid = dev.unique_id
        self.actions[uid] = 'connected'

    def _on_udev_disconnected(self, dev):
        if not isinstance(dev, TbDevice):
            return
        uid = dev.unique_id
        # a still pending action for the device is cancelled out
        if uid in self.actions:
            del self.actions[uid]
        else:
            self.actions[uid] = 'disconnected'

    # signal handler for remote devices
    def _on_device_added(self, dev):
        self._register_device(dev)
        self._check_action(dev.uid, 'connected')

    def _on_device_removed(self, object_path):
        dev = self.remote_devices.get(object_path, None)
        if dev is None:
            return
        self._check_action(dev.uid, 'disconnected')
        self._unregister_device(dev)

    # book keeping of remote devices
    def _register_device(self, dev):
        self.remote_devices[dev.object_path] = dev
        return dev

    def _unregister_device(self, dev):
        del self.remote_devices[dev.object_path]

    def _check_action(self, uid, action):
        """Tick off the pending ``action`` for ``uid``; returns False if
        that is not the action currently pending for the device."""
        if self.actions.get(uid, None) != action:
            return False
        del self.actions[uid]
        self._check_sync()
        return True

    def _check_sync(self):
        """Stop the loop once nothing is pending; returns True while
        there still are pending actions."""
        keep_going = len(self.actions) > 0
        if not keep_going:
            self._stop_loop()
        return keep_going

    def _stop_loop(self):
        if self.loop is None:
            return
        self.loop.quit()
        self.loop = None

    def _on_timeout(self):
        print('WARNING, timeout reached', file=sys.stderr)
        # the source is gone once we return False, nothing to remove later
        self._timeout_id = None
        self._stop_loop()
        return False

    def sync(self, timeout=None):
        """Run a main loop until all pending actions were confirmed or
        ``timeout`` seconds (default: ``get_timeout()``) have passed."""
        timeout = timeout or get_timeout()
        loop = GLib.MainLoop()
        self.loop = loop
        self._timeout_id = GLib.timeout_add(int(timeout * 1000),
                                            self._on_timeout)
        try:
            loop.run()
        finally:
            self.loop = None
            # Drop a timeout that has not fired yet: otherwise it would
            # fire during, and prematurely stop, a later sync().
            if self._timeout_id is not None:
                GLib.source_remove(self._timeout_id)
                self._timeout_id = None

    def close(self):
        """Disconnect all handlers from the four observed signals."""
        self.client.device_added.disconnect_all()
        self.client.device_removed.disconnect_all()
        self.tree.device_connected.disconnect_all()
        self.tree.device_disconnected.disconnect_all()


# Test Suite
class BoltTest(dbusmock.DBusTestCase):
    @staticmethod
    def path_from_service_file(sf):
        """Return the value of the first ``Exec=`` line of the service
        file ``sf``, or None if the file has no such line."""
        # read the file that was asked for, not unconditionally SERVICE_FILE
        with open(sf) as f:
            for line in f:
                if line.startswith('Exec='):
                    return line.split('=', 1)[1].strip()
        return None

    @classmethod
    def setUpClass(cls):
        """Locate the daemon binary and bring up a private test bus that
        is then used as the system bus."""
        env = os.environ
        if 'BOLT_BUILD_DIR' in env:
            print('Testing local build')
            path = os.path.join(env['BOLT_BUILD_DIR'], 'boltd')
        elif 'UNDER_JHBUILD' in env:
            print('Testing JHBuild version')
            path = os.path.join(env['JHBUILD_PREFIX'], 'libexec', 'boltd')
        else:
            print('Testing installed system binaries')
            path = BoltTest.path_from_service_file(SERVICE_FILE)

        assert path is not None, 'failed to find daemon'
        cls.paths = {'daemon': path}

        bus = Gio.TestDBus.new(Gio.TestDBusFlags.NONE)
        cls.test_bus = bus
        bus.up()
        # only the system bus address is exported to child processes
        env.pop('DBUS_SESSION_BUS_ADDRESS', None)
        env['DBUS_SYSTEM_BUS_ADDRESS'] = bus.get_bus_address()
        cls.dbus = Gio.bus_get_sync(Gio.BusType.SYSTEM, None)

    @classmethod
    def tearDownClass(cls):
        """Shut down the private test bus, then run the dbusmock teardown."""
        bus = cls.test_bus
        bus.down()
        dbusmock.DBusTestCase.tearDownClass()

    def setUp(self):
        """Create a fresh umockdev testbed and database directory and
        reset the per-test state."""
        self.testbed = UMockdev.Testbed.new()
        self.dbpath = tempfile.mkdtemp()

        # nothing is running yet
        for attr in ('client', 'log', 'daemon', 'polkitd'):
            setattr(self, attr, None)
        self.valgrind = False

    def tearDown(self):
        """Stop the helper processes, then remove the per-test resources.

        The daemon is stopped first so that its database directory and
        the testbed are not removed while it is still running. The clean
        up of the files happens even if stopping a process fails.
        """
        try:
            self.daemon_stop()
            self.polkitd_stop()
        finally:
            shutil.rmtree(self.dbpath)
            del self.testbed

    # dbus helper methods
    def get_dbus_property(self, name, interface=DBUS_IFACE_MANAGER):
        """Fetch property ``name`` of ``interface`` from the daemon's
        manager object through ``org.freedesktop.DBus.Properties.Get``."""
        flags = Gio.DBusProxyFlags.DO_NOT_AUTO_START
        props = Gio.DBusProxy.new_sync(self.dbus, flags, None,
                                       DBUS_NAME, DBUS_PATH,
                                       'org.freedesktop.DBus.Properties',
                                       None)
        return props.Get('(ss)', interface, name)

    # daemon helper
    def daemon_start(self):
        """Spawn the daemon and wait until it answers on the bus.

        The daemon runs against the umockdev testbed and the per-test
        database directory. If the ``VALGRIND`` environment variable is
        set the daemon is run under valgrind; if its value names an
        existing file, that file is passed as suppressions file.

        Fails the test if the ``Version`` property cannot be read within
        the 'daemon_start' timeout. On success ``self.client`` is set.
        """
        timeout = get_timeout('daemon_start')  # seconds
        env = os.environ.copy()
        env['G_DEBUG'] = 'fatal-criticals'
        env['UMOCKDEV_DIR'] = self.testbed.get_root_dir()
        env['BOLT_DBPATH'] = self.dbpath
        argv = [self.paths['daemon'], '-v']
        valgrind = os.getenv('VALGRIND')
        if valgrind is not None:
            argv.insert(0, 'valgrind')
            argv.insert(1, '--leak-check=full')
            if os.path.exists(valgrind):
                argv.insert(2, '--suppressions=%s' % valgrind)
            self.valgrind = True
        self.daemon = subprocess.Popen(argv,
                                       env=env,
                                       stdout=self.log,
                                       stderr=subprocess.STDOUT)

        # poll ten times per second until the daemon answers
        timeout_sleep = 0.1
        for _ in range(timeout * 10):
            time.sleep(timeout_sleep)
            try:
                self.get_dbus_property('Version')
                break
            except GLib.GError:
                pass
        else:
            # report the configured timeout; the exhausted loop counter
            # used previously always produced "0 seconds"
            self.fail('daemon did not start in %d seconds' % timeout)

        self.client = BoltClient(self.dbus)
        self.assertEqual(self.daemon.poll(), None, 'daemon crashed')

    def daemon_stop(self):
        """Terminate the daemon, if one is running, and wait for it to
        exit; the client is dropped as well."""
        proc = self.daemon
        if proc:
            try:
                proc.terminate()
            except OSError:
                pass
            proc.wait()

        self.daemon = self.client = None

    def polkitd_start(self):
        """Spawn the dbusmock 'polkitd' template and keep an interface to
        its mock controls in ``self.polkitd``."""
        spawned = self.spawn_server_template('polkitd', {}, stdout=DEVNULL)
        self._polkitd, self._polkitd_obj = spawned
        self.polkitd = dbus.Interface(self._polkitd_obj, dbusmock.MOCK_IFACE)

    def polkitd_stop(self):
        """Terminate the mock polkitd, if it was started, and wait for it."""
        if self.polkitd is not None:
            proc = self._polkitd
            proc.terminate()
            proc.wait()
            self.polkitd = None

    def user_config(self, **kwargs):
        """Write ``kwargs`` as the ``[config]`` section of ``boltd.conf``
        in the database directory and print the resulting file."""
        import configparser

        parser = configparser.ConfigParser()
        # keep option names as given instead of lower-casing them
        parser.optionxform = lambda option: option

        parser['config'] = {}
        section = parser['config']
        for key, value in kwargs.items():
            section[key] = value

        conf_file = os.path.join(self.dbpath, 'boltd.conf')
        with open(conf_file, 'w') as out:
            parser.write(out)

        with open(conf_file, 'r') as written:
            print(written.read())

    # mock tree stuff
    def default_mock_tree(self):
        """Build the default mock tree: one domain whose host has three
        unauthorized devices attached."""
        peripherals = [TbDevice(name) for name in ('Cable1', 'Cable2', 'SSD1')]
        return TbDomain(host=TbHost(peripherals))

    def simple_mock_tree(self):
        """Build a minimal mock tree: one domain, the host and one dock."""
        dock = TbDevice('Dock')
        host = TbHost([dock])
        return TbDomain(host=host)

    def assertDeviceEqual(self, local, remote):
        """Assert that the mock device ``local`` and the D-Bus device
        ``remote`` describe the same device; returns True."""
        self.assertTrue(local and remote)
        self.assertEqual(local.unique_id, remote.uid)
        self.assertEqual(local.device_name, remote.Name)
        self.assertEqual(local.vendor_name, remote.Vendor)

        connected = local.syspath is not None
        if connected:
            self.assertEqual(local.syspath, remote.sysfs_path)
            self.assertTrue(remote.is_connected)

            # remote.parent is also only valid if we are connected
            parent = local.parent
            if parent is not None and isinstance(parent, TbDevice):
                self.assertEqual(parent.unique_id, remote.parent)

        self.assertEqual(local.bolt_status, remote.status)
        return True

    def add_device(self, parent, devid, name, vendor, domain=0, authorized=1):
        """Add a thunderbolt device with a random unique id directly to
        the testbed; returns ``(syspath, uid)``."""
        uid = str(uuid.uuid4())
        sysname = "%d-%d" % (domain, devid)
        attributes = ['device_name', name,
                      'device', '0x23',
                      'vendor_name', vendor,
                      'vendor', '0x23',
                      'authorized', '%d' % authorized,
                      'key', '',
                      'unique_id', uid]
        properties = ['DEVTYPE', 'thunderbolt_device']
        syspath = self.testbed.add_device('thunderbolt', sysname, parent,
                                          attributes, properties)
        return syspath, uid

    def find_device_by_uid(self, lst, uid):
        """Return the one device of ``lst`` whose uid is ``uid``; the
        test fails unless there is exactly one such device."""
        matches = [dev for dev in lst if dev.uid == uid]
        self.assertEqual(len(matches), 1)
        return matches[0]

    # the actual tests
    def test_basic(self):
        """Connect and disconnect the default tree and check that the
        daemon's device list follows."""
        self.daemon_start()
        version = self.client.version
        # use the unittest assertion like the rest of the suite; a bare
        # assert would be stripped when running with -O
        self.assertIsNotNone(version)
        d = self.client.list_devices()
        self.assertEqual(len(d), 0)
        policy = self.client.default_policy
        self.assertIn(policy, [self.client.POLICY_AUTO,
                               self.client.POLICY_MANUAL])
        # connect all device
        tree = self.default_mock_tree()
        chk = TreeChecker(self.client, tree)
        try:
            tree.connect_tree(self.testbed)
            chk.sync()  # check for the signals

            devices = self.client.list_devices()
            self.assertEqual(len(devices), len(tree.devices))
            for remote in devices:
                local = tree.find(unique_id=remote.uid)
                self.assertDeviceEqual(local, remote)

            # disconnect all devices again
            tree.disconnect(self.testbed)
            chk.sync()  # check for the signals
        finally:
            # detach the checker's handlers even if an assertion failed
            chk.close()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 0)
        self.daemon_stop()

    def test_signals_on_start(self):
        """A DeviceAdded signal must be seen when the daemon starts."""
        # Check that we get DeviceAdded signals for un-authorized
        # devices that are not in the database

        # the client is created and the devices are connected before the
        # daemon is running
        client = BoltClient(self.dbus)
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        with client.record() as tape:
            self.daemon_start()
            got_signal = tape.wait_for_event('signal', 'DeviceAdded', None)
            self.assertTrue(got_signal)
        self.daemon_stop()

    def test_basic_device_name(self):
        """Check the Name, Vendor and label the daemon reports for
        devices added directly to the testbed."""
        domain = 0
        security = 'secure'
        device_props = ['DEVTYPE', 'thunderbolt_device']

        dc = self.testbed.add_device('thunderbolt', 'domain%d' % domain, None,
                                     ['security', security],
                                     ['DEVTYPE', 'thunderbolt_domain'])

        host_attrs = ['device_name', 'Host',
                      'device', '0x23',
                      'vendor_name', 'GNOME.org',
                      'vendor', '0x23',
                      'authorized', '1',
                      'unique_id', str(uuid.uuid4())]
        host = self.testbed.add_device('thunderbolt', "%d-0" % domain, dc,
                                       host_attrs, device_props)
        print(host)

        # duplicated vendor in name
        d1_vendor = "GNOME.org"
        d1_name = "Cable"
        d2_name = '⍾ Laptop'
        d2_vendor = 'Evil Corp. ☢'
        _, d1_uid = self.add_device(host, 1, d1_vendor + d1_name, d1_vendor)
        _, d2_uid = self.add_device(host, 2, d2_name, d2_vendor)

        self.daemon_start()
        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        first = self.find_device_by_uid(devices, d1_uid)
        self.assertEqual(first.Name, d1_name)
        self.assertEqual(first.Vendor, d1_vendor)

        second = self.find_device_by_uid(devices, d2_uid)
        self.assertEqual(second.Name, d2_name)
        self.assertEqual(second.Vendor, d2_vendor)

        self.assertIsNone(second.label)

        self.daemon_stop()

    def test_basic_user_config(self):
        """Verify that a DefaultPolicy of 'manual' in the user config is reported.

        The user configuration is written before the daemon is started, and
        the client's default policy is then compared to POLICY_MANUAL.
        """
        self.user_config(DefaultPolicy='manual')
        self.daemon_start()

        self.assertEqual(self.client.default_policy,
                         self.client.POLICY_MANUAL)

        self.daemon_stop()

    def test_device_by_uid(self):
        """Look devices up by unique id, for both invalid and valid ids.

        Before any device is connected, an empty id and an unknown id must
        each raise GLib.GError. After the default mock tree is connected,
        every device in the tree must be found and match its local mock.
        """
        self.daemon_start()

        # neither of these ids refers to a device
        for bogus_uid in ("", "nonexistant"):
            with self.assertRaises(GLib.GError):
                self.client.device_by_uid(bogus_uid)

        mock_tree = self.default_mock_tree()
        mock_tree.connect_tree(self.testbed)

        for local in mock_tree.devices:
            found = self.client.device_by_uid(local.unique_id)
            self.assertIsNotNone(found)
            self.assertDeviceEqual(local, found)

        self.daemon_stop()

    def test_device_authorize(self):
        """Authorize devices, checking the polkit gate and the status change.

        Three phases are run over every unauthorized device of the default
        mock tree:

        1. Without any allowed polkit action, authorize() must fail with a
           D-Bus ACCESS_DENIED error.
        2. With 'org.freedesktop.bolt.authorize' allowed, authorize() must
           succeed and a 'Status' property change to 'authorized' must be
           recorded.
        3. Authorizing the same devices a second time must raise
           GLib.GError.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        self.polkitd_start()

        to_authorize = tree.collect(TbDevice.is_unauthorized)

        # check that we are not allowed to authorize devices
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError) as cm:
                remote.authorize()
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize'])
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            tape = remote.record()
            # close the recording even if one of the assertions fails
            try:
                remote.authorize()
                d.reload_auth()  # will emit the uevent, so the daemon can update
                res = tape.wait_for_event('property', 'Status', 'authorized')
                self.assertTrue(res)
                self.assertDeviceEqual(d, remote)
            finally:
                tape.close()

        # the devices are authorized now, doing it again must fail
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError):
                remote.authorize()

        self.daemon_stop()

    def test_device_enroll(self):
        """Enroll devices and check they are stored and re-authorized.

        The test proceeds as follows:

        1. Without any allowed polkit action, enrolling must fail with a
           D-Bus ACCESS_DENIED error.
        2. With 'org.freedesktop.bolt.enroll' allowed, enrolling an unknown
           uid must raise GLib.GError.
        3. Every unauthorized device is enrolled with POLICY_AUTO and must
           be reported as stored, with key state KEY_NEW.
        4. After the tree is disconnected, the daemon must still list all
           devices except the host; the list is polled a few times to give
           the daemon time to process the removals.
        5. Each device is connected again and is expected to change its
           'Status' to 'authorized-secure', with key state KEY_HAVE.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)
        self.polkitd_start()

        client = self.client

        to_enroll = tree.collect(TbDevice.is_unauthorized)

        # check that we are not allowed to enroll devices, i.e. the correct
        # policykit action is called.

        for d in to_enroll:
            with self.assertRaises(GLib.GError) as cm:
                client.enroll(d.unique_id)
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        # check we get a proper error for a unknown device; this has to go
        # through enroll(), since that is the only action allowed right now
        # and any other call would fail on the permission check instead
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.enroll("884c6edd-7118-4b21-b186-b02d396ecca0")

        policy = BoltClient.POLICY_AUTO
        for d in to_enroll:
            remote = client.enroll(d.unique_id, policy)
            d.reload_auth()  # will emit the uevent, so the daemon can update
            # the security level for the domain is SECURE, which means we should
            # have authorized via a new key:
            #  status should be AUTHORIZED_NEWKEY
            #  stored should be True
            #  key(state) should be KEY_NEW
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)

        # we disconnect the tree, but since the devices were enrolled
        # the daemon should have them in its database now
        tree.disconnect(self.testbed)

        # the host itself is not stored in the daemon
        expected_number = len(tree.devices) - 1
        devices = self.client.list_devices()
        tries = 0
        # poll a few times before giving up; the final count is asserted
        # once, after the loop, so that the retries actually get a chance
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_devices()
        self.assertEqual(len(devices), expected_number)

        tree.connect(self.testbed)              # we connect the domain again
        tree.children[0].connect(self.testbed)  # and the host too

        for remote in devices:
            local = tree.find(unique_id=remote.uid)
            self.assertDeviceEqual(local, remote)
            self.assertTrue(remote.stored)
            # key status should have changed to HAVE from NEW
            self.assertEqual(remote.key, BoltDevice.KEY_HAVE)
            self.assertEqual(remote.policy, policy)

            # now we connect that specific device and wait for
            # the property changes
            with remote.record() as tape:
                local.connect(self.testbed)
                res = tape.wait_for_event('property',
                                          'Status',
                                          'authorized-secure')
                local.reload_auth()  # will emit the uevent, so the daemon can update
                self.assertTrue(res)
                self.assertDeviceEqual(local, remote)

        self.daemon_stop()

    def test_device_forget(self):
        """Forget enrolled devices, checking the polkit gate first.

        The test proceeds as follows:

        1. With 'org.freedesktop.bolt.enroll' allowed, every unauthorized
           device is enrolled with POLICY_MANUAL.
        2. After the tree is disconnected, the daemon must still list all
           devices except the host; the list is polled a few times to give
           the daemon time to process the removals.
        3. Forgetting any of the listed devices must fail with a D-Bus
           ACCESS_DENIED error, because only the enroll action is allowed.
        4. With 'org.freedesktop.bolt.manage' allowed, forgetting an
           unknown uid must raise GLib.GError, while forgetting the listed
           devices must succeed and leave the device list empty.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        to_enroll = tree.collect(TbDevice.is_unauthorized)
        policy = BoltClient.POLICY_MANUAL
        for d in to_enroll:
            remote = client.enroll(d.unique_id, policy)
            d.reload_auth()
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)

        tree.disconnect(self.testbed)
        expected_number = len(tree.devices) - 1  # host is not stored
        devices = self.client.list_devices()
        tries = 0
        # poll a few times before giving up on the expected device count
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_devices()
        self.assertEqual(len(devices), expected_number)

        # check that we are not allowed to forget any of the stored
        # devices; each device is addressed by its own uid
        for remote in devices:
            with self.assertRaises(GLib.GError) as cm:
                client.forget(remote.uid)
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])

        # check we get a proper error for a unknown device
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.forget("884c6edd-7118-4b21-b186-b02d396ecca0")

        for remote in devices:
            client.forget(remote.uid)

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 0)

        self.daemon_stop()

    def test_device_label(self):
        """Set a device label, checking permissions and validation.

        The first unauthorized device of the simple mock tree is enrolled
        with POLICY_AUTO. Then:

        1. With only the enroll action allowed, setting a label must fail
           with a D-Bus ACCESS_DENIED error.
        2. With 'org.freedesktop.bolt.manage' allowed, empty and
           whitespace-only labels must fail with INVALID_ARGS and leave the
           label unset.
        3. A valid label must be accepted, a 'Label' property change with
           that value must be recorded, and the label must read back.
        """
        self.daemon_start()
        tree = self.simple_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        local = tree.collect(TbDevice.is_unauthorized)[0]
        policy = BoltClient.POLICY_AUTO

        remote = client.enroll(local.unique_id, policy)
        local.reload_auth()

        # a freshly enrolled device has no label
        self.assertIsNone(remote.label)

        self.assertDeviceEqual(local, remote)
        self.assertTrue(remote.stored)
        self.assertEqual(remote.key, BoltDevice.KEY_NEW)
        self.assertEqual(remote.policy, policy)

        # only the enroll action is allowed so far
        with self.assertRaises(GLib.GError) as cm:
            remote.label = 'not authorized'
        err = cm.exception
        self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
        self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])
        # empty and whitespace-only labels are rejected
        for val in ['', ' ', '     ']:
            with self.assertRaises(GLib.GError) as cm:
                remote.label = val
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.INVALID_ARGS))

        # none of the rejected values may have been stored
        self.assertIsNone(remote.label)

        val = 'A valid label'

        with remote.record() as tape:
            remote.label = val
            res = tape.wait_for_event('property',
                                      'Label',
                                      val)
            self.assertTrue(res)
        self.assertEqual(remote.label, val)

        self.daemon_stop()


if __name__ == '__main__':
    # When invoked with the single argument "list-tests", print the ids of
    # all BoltTest cases, separated by spaces, and exit without running any
    # of them. Otherwise hand over to the regular unittest runner.
    wants_listing = len(sys.argv) == 2 and sys.argv[1] == "list-tests"
    if wants_listing:
        # The first 9 characters of every id are dropped; presumably this
        # is the "__main__." module prefix -- TODO confirm.
        skip = 9
        for case in unittest.defaultTestLoader.loadTestsFromTestCase(BoltTest):
            print(case.id()[skip:], end=" ")
        sys.exit(0)
    unittest.main(verbosity=2)
