#!/usr/bin/env python3
# group: rw auto quick
#
# Test cases for NBD multi-conn advertisement
#
# Copyright (C) 2022 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
from contextlib import contextmanager
import iotests
from iotests import qemu_img_create, qemu_io


# Virtual size of the scratch image; setUp writes one pattern per 2M half.
size = '4M'
# Scratch image that gets attached to the VM and exported over NBD.
disk = os.path.join(iotests.test_dir, 'disk')
# Unix socket path the NBD server is told to listen on.
nbd_sock = os.path.join(iotests.sock_dir, 'nbd_sock')
# URI template; the '{}' placeholder is filled in with the export name
# via nbd_uri.format(name).
nbd_uri = 'nbd+unix:///{}?socket=%s' % nbd_sock


@contextmanager
def open_nbd(export_name):
    """Yield a libnbd handle connected to export_name via nbd_sock.

    The connection is shut down when the with-block exits, whether the
    block finished normally or raised.  If the connection itself cannot
    be established, the connect error propagates unchanged.
    """
    h = nbd.NBD()
    # Connect before entering the try block: shutdown() is only meaningful
    # on a connected handle, and calling it after a failed connect would
    # raise a second error on top of the real one.
    h.connect_uri(nbd_uri.format(export_name))
    try:
        yield h
    finally:
        h.shutdown()

class TestNbdMulticonn(iotests.QMPTestCase):
    """Check when the VM's NBD server advertises multi-conn to clients.

    Every test launches a fresh VM with the scratch image attached as
    node 'n', starts the NBD server on a Unix socket, adds exports of
    that node and inspects them through libnbd client handles.
    """

    def setUp(self):
        """Create the scratch image and launch a VM with it as node 'n'."""
        qemu_img_create('-f', iotests.imgfmt, disk, size)
        # First 2M reads back as 0x01 bytes, second 2M as 0x02 bytes.
        qemu_io('-c', 'w -P 1 0 2M', '-c', 'w -P 2 2M 2M', disk)

        self.vm = iotests.VM()
        self.vm.launch()
        result = self.vm.qmp('blockdev-add', {
            'driver': 'qcow2',
            'node-name': 'n',
            'file': {'driver': 'file', 'filename': disk}
        })
        self.assert_qmp(result, 'return', {})

    def tearDown(self):
        """Shut the VM down and delete the image and any leftover socket."""
        try:
            self.vm.shutdown()
        finally:
            # Remove the scratch files even when the VM did not shut down
            # cleanly, so they cannot leak into later test cases.
            os.remove(disk)
            try:
                os.remove(nbd_sock)
            except OSError:
                # The socket may already be gone; nothing to clean up.
                pass

    @contextmanager
    def run_server(self, max_connections=None):
        """Run the VM's NBD server on nbd_sock for the span of a with-block.

        max_connections, when not None, is passed to 'nbd-server-start'
        as 'max-connections'; otherwise the argument is omitted.  The
        server is stopped on exit from the block, including when the
        block raised.
        """
        args = {
            'addr': {
                'type': 'unix',
                'data': {'path': nbd_sock}
            }
        }
        if max_connections is not None:
            args['max-connections'] = max_connections

        result = self.vm.qmp('nbd-server-start', args)
        self.assert_qmp(result, 'return', {})
        try:
            yield
        finally:
            # Always stop the server so a failing test body does not
            # leave it running behind the caller's back.
            result = self.vm.qmp('nbd-server-stop')
            self.assert_qmp(result, 'return', {})

    def add_export(self, name, writable=None):
        """Add an NBD export of node 'n' using name as both id and name.

        writable, when not None, is passed through as the export's
        'writable' option; otherwise the option is omitted.
        """
        args = {
            'type': 'nbd',
            'id': name,
            'node-name': 'n',
            'name': name,
        }
        if writable is not None:
            args['writable'] = writable

        result = self.vm.qmp('block-export-add', args)
        self.assert_qmp(result, 'return', {})

    # NOTE: the test methods use comments rather than docstrings so that
    # unittest keeps reporting them by method name only.

    def test_default_settings(self):
        # Without a connection limit, both the read-only and the writable
        # export are expected to advertise multi-conn.
        with self.run_server():
            self.add_export('r')
            self.add_export('w', writable=True)
            with open_nbd('r') as h:
                self.assertTrue(h.can_multi_conn())
            with open_nbd('w') as h:
                self.assertTrue(h.can_multi_conn())

    def test_limited_connections(self):
        # With the server limited to a single connection, neither export
        # is expected to advertise multi-conn.
        with self.run_server(max_connections=1):
            self.add_export('r')
            self.add_export('w', writable=True)
            with open_nbd('r') as h:
                self.assertFalse(h.can_multi_conn())
            with open_nbd('w') as h:
                self.assertFalse(h.can_multi_conn())

    def test_parallel_writes(self):
        # A write on one connection followed by a flush on a second one
        # must be visible to a read on a third connection.
        length = 1024 * 1024
        with self.run_server():
            self.add_export('w', writable=True)

            # open_nbd shuts every handle down again, including when one
            # of the assertions below fails.
            with open_nbd('w') as reader, \
                    open_nbd('w') as writer, \
                    open_nbd('w') as flusher:
                for client in (reader, writer, flusher):
                    self.assertTrue(client.can_multi_conn())

                initial_data = reader.pread(length, 0)
                self.assertEqual(initial_data, b'\x01' * length)

                updated_data = b'\x03' * length
                writer.pwrite(updated_data, 0)
                flusher.flush()
                current_data = reader.pread(length, 0)

                self.assertEqual(updated_data, current_data)


if __name__ == '__main__':
    # Easier to use libnbd than to try and set up parallel
    # 'qemu-nbd --list' or 'qemu-io' processes, but not all systems
    # have libnbd installed.
    #
    # Only the import itself is guarded: an ImportError raised while the
    # tests run must not be misreported as "libnbd not installed".
    try:
        import nbd  # type: ignore
    except ImportError:
        iotests.notrun('libnbd not installed')
    else:
        iotests.main(supported_fmts=['qcow2'])
