irq: inform device of IRQ mask & unmask via callback (#694)

Client masks or unmasks a device IRQ using the
VFIO_USER_DEVICE_SET_IRQS message. Inform the device of such changes to
the IRQ state.

Signed-off-by: Jagannathan Raman <jag.raman@oracle.com>
Reviewed-by: John Levon <john.levon@nutanix.com>
diff --git a/include/libvfio-user.h b/include/libvfio-user.h
index ba599b5..84eb2d8 100644
--- a/include/libvfio-user.h
+++ b/include/libvfio-user.h
@@ -549,6 +549,31 @@
 vfu_setup_device_nr_irqs(vfu_ctx_t *vfu_ctx, enum vfu_dev_irq_type type,
                          uint32_t count);
 
+/*
+ * Function that is called when the guest masks or unmasks an IRQ vector.
+ *
+ * @vfu_ctx: the libvfio-user context
+ * @start: starting IRQ vector
+ * @count: number of vectors
+ * @mask: indicates if the IRQ is masked or unmasked
+ */
+typedef void (vfu_dev_irq_state_cb_t)(vfu_ctx_t *vfu_ctx, uint32_t start,
+                                      uint32_t count, bool mask);
+
+/**
+ * Set up IRQ state change callback. When libvfio-user is notified of a
+ * change to IRQ state, whether masked or unmasked, it invokes
+ * this callback.
+ *
+ * @vfu_ctx: the libvfio-user context
+ * @type: IRQ type such as VFU_DEV_MSIX_IRQ - defined by vfu_dev_irq_type
+ * @cb: IRQ state change callback
+ *
+ * @returns 0 on success, -1 on error, sets errno.
+ */
+int
+vfu_setup_irq_state_callback(vfu_ctx_t *vfu_ctx, enum vfu_dev_irq_type type,
+                             vfu_dev_irq_state_cb_t *cb);
 
 typedef enum {
     VFU_MIGR_STATE_STOP,
diff --git a/lib/irq.c b/lib/irq.c
index c7820aa..eadad7b 100644
--- a/lib/irq.c
+++ b/lib/irq.c
@@ -154,6 +154,32 @@
     }
 }
 
+static void
+irqs_set_state(vfu_ctx_t *vfu_ctx, struct vfio_irq_set *irq_set)
+{
+    vfu_dev_irq_state_cb_t *cb = NULL;
+    uint32_t irq_action;
+    bool mask = false;
+
+    assert(irq_set->index < VFU_DEV_NUM_IRQS);
+    cb = vfu_ctx->irq_state_cbs[irq_set->index];
+    if (cb == NULL) {
+        return;
+    }
+
+    assert((irq_set->start + irq_set->count) <=
+            vfu_ctx->irq_count[irq_set->index]);
+
+    irq_action = irq_set->flags & VFIO_IRQ_SET_ACTION_TYPE_MASK;
+
+    assert((irq_action & VFIO_IRQ_SET_ACTION_MASK) ||
+           (irq_action & VFIO_IRQ_SET_ACTION_UNMASK));
+
+    mask = (irq_action & VFIO_IRQ_SET_ACTION_MASK) ? true : false;
+
+    cb(vfu_ctx, irq_set->start, irq_set->count, mask);
+}
+
 static int
 irqs_set_data_none(vfu_ctx_t *vfu_ctx, struct vfio_irq_set *irq_set)
 {
@@ -345,8 +371,7 @@
     switch (irq_set->flags & VFIO_IRQ_SET_ACTION_TYPE_MASK) {
     case VFIO_IRQ_SET_ACTION_MASK:
     case VFIO_IRQ_SET_ACTION_UNMASK:
-        // We're always edge-triggered without un/mask support.
-        // FIXME: return an error? We don't report MASKABLE
+        irqs_set_state(vfu_ctx, irq_set);
         return 0;
     case VFIO_IRQ_SET_ACTION_TRIGGER:
         break;
diff --git a/lib/libvfio-user.c b/lib/libvfio-user.c
index 47e3572..c45ceeb 100644
--- a/lib/libvfio-user.c
+++ b/lib/libvfio-user.c
@@ -1968,6 +1968,22 @@
 }
 
 EXPORT int
+vfu_setup_irq_state_callback(vfu_ctx_t *vfu_ctx, enum vfu_dev_irq_type type,
+                             vfu_dev_irq_state_cb_t *cb)
+{
+    assert(vfu_ctx != NULL);
+
+    if (type >= VFU_DEV_NUM_IRQS) {
+        vfu_log(vfu_ctx, LOG_ERR, "Invalid IRQ type index %u", type);
+        return ERROR_INT(EINVAL);
+    }
+
+    vfu_ctx->irq_state_cbs[type] = cb;
+
+    return 0;
+}
+
+EXPORT int
 vfu_setup_device_migration_callbacks(vfu_ctx_t *vfu_ctx,
                                      const vfu_migration_callbacks_t *callbacks,
                                      uint64_t data_offset)
diff --git a/lib/private.h b/lib/private.h
index 4c483f2..7ffd6be 100644
--- a/lib/private.h
+++ b/lib/private.h
@@ -172,6 +172,7 @@
     struct migration        *migration;
 
     uint32_t                irq_count[VFU_DEV_NUM_IRQS];
+    vfu_dev_irq_state_cb_t  *irq_state_cbs[VFU_DEV_NUM_IRQS];
     vfu_irqs_t              *irqs;
     bool                    realized;
     vfu_dev_type_t          dev_type;
diff --git a/test/py/libvfio_user.py b/test/py/libvfio_user.py
index 4bdb761..76d9315 100644
--- a/test/py/libvfio_user.py
+++ b/test/py/libvfio_user.py
@@ -624,6 +624,11 @@
 
 lib.vfu_device_quiesced.argtypes = (c.c_void_p, c.c_int)
 
+vfu_dev_irq_state_cb_t = c.CFUNCTYPE(None, c.c_void_p, c.c_uint32,
+                                     c.c_bool, use_errno=True)
+lib.vfu_setup_irq_state_callback.argtypes = (c.c_void_p, c.c_int,
+                                             vfu_dev_irq_state_cb_t)
+
 
 def to_byte(val):
     """Cast an int to a byte value."""
@@ -1030,6 +1035,20 @@
     return lib.vfu_setup_device_nr_irqs(ctx, irqtype, count)
 
 
+def irq_state(ctx, vector, mask):
+    pass
+
+
+@vfu_dev_irq_state_cb_t
+def __irq_state(ctx, vector, mask):
+    irq_state(ctx, vector, mask)
+
+
+def vfu_setup_irq_state_callback(ctx, irqtype, cb=__irq_state):
+    assert ctx is not None
+    return lib.vfu_setup_irq_state_callback(ctx, irqtype, cb)
+
+
 def vfu_pci_init(ctx, pci_type=VFU_PCI_TYPE_EXPRESS,
                  hdr_type=PCI_HEADER_TYPE_NORMAL):
     assert ctx is not None
diff --git a/test/py/test_device_set_irqs.py b/test/py/test_device_set_irqs.py
index 382804a..7525b30 100644
--- a/test/py/test_device_set_irqs.py
+++ b/test/py/test_device_set_irqs.py
@@ -27,6 +27,8 @@
 #  DAMAGE.
 #
 
+from unittest.mock import patch
+
 from libvfio_user import *
 import errno
 import os
@@ -53,6 +55,9 @@
     ret = vfu_setup_device_nr_irqs(ctx, VFU_DEV_MSIX_IRQ, 2048)
     assert ret == 0
 
+    vfu_setup_irq_state_callback(ctx, VFU_DEV_MSIX_IRQ)
+    assert ret == 0
+
     ret = vfu_realize_ctx(ctx)
     assert ret == 0
 
@@ -308,6 +313,20 @@
     assert struct.unpack("Q", os.read(fd2, 8))[0] == 9
 
 
+@patch('libvfio_user.irq_state')
+def test_irq_state(mock_irq_state):
+    assert mock_irq_state.call_count == 0
+
+    payload = vfio_irq_set(argsz=argsz, flags=VFIO_IRQ_SET_DATA_NONE |
+                           VFIO_IRQ_SET_ACTION_MASK,
+                           index=VFU_DEV_MSIX_IRQ,
+                           start=0, count=1)
+
+    msg(ctx, sock, VFIO_USER_DEVICE_SET_IRQS, payload)
+
+    assert mock_irq_state.call_count == 1
+
+
 def test_device_set_irqs_cleanup():
     vfu_destroy_ctx(ctx)