fix: server sample not marking dirty pages (#748)

The server sample is supposed to demonstrate dirty page logging, but it was not marking dirty pages. This commit adds both client-side dirty page tracking for pages dirtied via `vfu_sgl_write` messages (the client performs those writes, so it is the one that knows about them) and server-side dirty page tracking for pages the server dirties directly through mappings obtained with `vfu_sgl_get`/`vfu_sgl_put`.

Signed-off-by: William Henderson <william.henderson@nutanix.com>
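
The bookkeeping being added, reduced to a standalone sketch (plain C, no libvfio-user dependency; `struct dma_region` and `mark_write` are illustrative names, not part of the samples): every DMA region in the samples is a single page, so per-region dirty tracking degenerates to one bit, set whenever a write lands in the region while tracking is enabled.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define REGION_SIZE 4096 /* one page, as in the samples */

struct dma_region {
    uint64_t addr;     /* IOVA where the region starts */
    bool     tracking; /* is dirty page logging enabled? */
    bool     dirty;    /* the single per-page dirty bit */
};

/* Record a write landing at IOVA `addr`, marking the region it hits. */
static void
mark_write(struct dma_region *regions, int nr, uint64_t addr)
{
    for (int i = 0; i < nr; i++) {
        if (addr >= regions[i].addr && addr < regions[i].addr + REGION_SIZE) {
            if (regions[i].tracking) {
                regions[i].dirty = true;
            }
            return;
        }
    }
}

int
main(void)
{
    struct dma_region regions[2] = {
        { .addr = 0 * REGION_SIZE, .tracking = true },
        { .addr = 1 * REGION_SIZE, .tracking = true },
    };

    mark_write(regions, 2, REGION_SIZE + 128); /* dirties region 1 only */

    assert(!regions[0].dirty);
    assert(regions[1].dirty);
    printf("region 0 dirty=%d, region 1 dirty=%d\n",
           (int)regions[0].dirty, (int)regions[1].dirty);
    return 0;
}
```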
diff --git a/samples/client.c b/samples/client.c
index 1e06162..0086fd6 100644
--- a/samples/client.c
+++ b/samples/client.c
@@ -62,6 +63,18 @@
     [VFU_DEV_REQ_IRQ] = "REQ"
 };
 
+struct client_dma_region {
+/*
+ * Our DMA regions are one page in size, so we only need one bit to mark
+ * each of them dirty.
+ */
+#define CLIENT_DIRTY_PAGE_TRACKING_ENABLED (1 << 0)
+#define CLIENT_DIRTY_DMA_REGION (1 << 1)
+    uint32_t flags;
+    struct vfio_user_dma_map map;
+    int fd;
+};
+
 void
 vfu_log(UNUSED vfu_ctx_t *vfu_ctx, UNUSED int level,
         const char *fmt, ...)
@@ -560,8 +573,8 @@
 }
 
 static void
-handle_dma_write(int sock, struct vfio_user_dma_map *dma_regions,
-                 int nr_dma_regions, int *dma_region_fds)
+handle_dma_write(int sock, struct client_dma_region *dma_regions,
+                 int nr_dma_regions)
 {
     struct vfio_user_dma_region_access dma_access;
     struct vfio_user_header hdr;
@@ -588,20 +601,30 @@
         off_t offset;
         ssize_t c;
 
-        if (dma_access.addr < dma_regions[i].addr ||
-            dma_access.addr >= dma_regions[i].addr + dma_regions[i].size) {
+        if (dma_access.addr < dma_regions[i].map.addr ||
+            dma_access.addr >= dma_regions[i].map.addr + dma_regions[i].map.size) {
             continue;
         }
 
-        offset = dma_regions[i].offset + dma_access.addr;
+        offset = dma_regions[i].map.offset + dma_access.addr;
 
-        c = pwrite(dma_region_fds[i], data, dma_access.count, offset);
+        c = pwrite(dma_regions[i].fd, data, dma_access.count, offset);
 
         if (c != (ssize_t)dma_access.count) {
             err(EXIT_FAILURE, "failed to write to fd=%d at [%#llx-%#llx)",
-                    dma_region_fds[i], (ull_t)offset,
+                    dma_regions[i].fd, (ull_t)offset,
                     (ull_t)(offset + dma_access.count));
         }
+
+        /*
+         * DMA regions in this example are one page in size, so a single bit
+         * is enough to mark the whole region dirty.
+         */
+        if (dma_regions[i].flags & CLIENT_DIRTY_PAGE_TRACKING_ENABLED) {
+            assert(dma_regions[i].map.size == PAGE_SIZE);
+            dma_regions[i].flags |= CLIENT_DIRTY_DMA_REGION;
+        }
+
         break;
     }
 
@@ -616,8 +639,8 @@
 }
 
 static void
-handle_dma_read(int sock, struct vfio_user_dma_map *dma_regions,
-                int nr_dma_regions, int *dma_region_fds)
+handle_dma_read(int sock, struct client_dma_region *dma_regions,
+                int nr_dma_regions)
 {
     struct vfio_user_dma_region_access dma_access, *response;
     struct vfio_user_header hdr;
@@ -644,18 +667,18 @@
         off_t offset;
         ssize_t c;
 
-        if (dma_access.addr < dma_regions[i].addr ||
-            dma_access.addr >= dma_regions[i].addr + dma_regions[i].size) {
+        if (dma_access.addr < dma_regions[i].map.addr ||
+            dma_access.addr >= dma_regions[i].map.addr + dma_regions[i].map.size) {
             continue;
         }
 
-        offset = dma_regions[i].offset + dma_access.addr;
+        offset = dma_regions[i].map.offset + dma_access.addr;
 
-        c = pread(dma_region_fds[i], data, dma_access.count, offset);
+        c = pread(dma_regions[i].fd, data, dma_access.count, offset);
 
         if (c != (ssize_t)dma_access.count) {
             err(EXIT_FAILURE, "failed to read from fd=%d at [%#llx-%#llx)",
-                    dma_region_fds[i], (ull_t)offset,
+                    dma_regions[i].fd, (ull_t)offset,
                     (ull_t)offset + dma_access.count);
         }
         break;
@@ -672,23 +695,24 @@
 }
 
 static void
-handle_dma_io(int sock, struct vfio_user_dma_map *dma_regions,
-              int nr_dma_regions, int *dma_region_fds)
+handle_dma_io(int sock, struct client_dma_region *dma_regions,
+              int nr_dma_regions)
 {
     size_t i;
 
     for (i = 0; i < 4096 / CLIENT_MAX_DATA_XFER_SIZE; i++) {
-        handle_dma_write(sock, dma_regions, nr_dma_regions, dma_region_fds);
+        handle_dma_write(sock, dma_regions, nr_dma_regions);
     }
     for (i = 0; i < 4096 / CLIENT_MAX_DATA_XFER_SIZE; i++) {
-        handle_dma_read(sock, dma_regions, nr_dma_regions, dma_region_fds);
+        handle_dma_read(sock, dma_regions, nr_dma_regions);
     }
 }
 
 static void
-get_dirty_bitmap(int sock, struct vfio_user_dma_map *dma_region)
+get_dirty_bitmap(int sock, struct client_dma_region *dma_region,
+                 bool expect_dirty)
 {
-    uint64_t bitmap_size = _get_bitmap_size(dma_region->size,
+    uint64_t bitmap_size = _get_bitmap_size(dma_region->map.size,
                                             sysconf(_SC_PAGESIZE));
     struct vfio_user_dirty_pages *dirty_pages;
     struct vfio_user_bitmap_range *range;
@@ -707,8 +731,8 @@
     dirty_pages->argsz = sizeof(*dirty_pages) + sizeof(*range) + bitmap_size;
 
     range = data + sizeof(*dirty_pages);
-    range->iova = dma_region->addr;
-    range->size = dma_region->size;
+    range->iova = dma_region->map.addr;
+    range->size = dma_region->map.size;
     range->bitmap.size = bitmap_size;
     range->bitmap.pgsize = sysconf(_SC_PAGESIZE);
 
@@ -721,9 +745,17 @@
         err(EXIT_FAILURE, "failed to get dirty page bitmap");
     }
 
+    char dirtied_by_server = bitmap[0];
+    char dirtied_by_client = (dma_region->flags & CLIENT_DIRTY_DMA_REGION) != 0;
+    char dirtied = dirtied_by_server | dirtied_by_client;
+
     printf("client: %s: %#llx-%#llx\t%#x\n", __func__,
            (ull_t)range->iova,
-           (ull_t)(range->iova + range->size - 1), bitmap[0]);
+           (ull_t)(range->iova + range->size - 1), dirtied);
+
+    if (expect_dirty) {
+        assert(dirtied);
+    }
 
     free(data);
 }
@@ -1058,8 +1090,8 @@
 }
 
 static void
-map_dma_regions(int sock, struct vfio_user_dma_map *dma_regions,
-                int *dma_region_fds, int nr_dma_regions)
+map_dma_regions(int sock, struct client_dma_region *dma_regions,
+                int nr_dma_regions)
 {
     int i, ret;
 
@@ -1067,13 +1099,13 @@
         struct iovec iovecs[2] = {
             /* [0] is for the header. */
             [1] = {
-                .iov_base = &dma_regions[i],
-                .iov_len = sizeof(*dma_regions)
+                .iov_base = &dma_regions[i].map,
+                .iov_len = sizeof(struct vfio_user_dma_map)
             }
         };
         ret = tran_sock_msg_iovec(sock, 0x1234 + i, VFIO_USER_DMA_MAP,
                                   iovecs, ARRAY_SIZE(iovecs),
-                                  &dma_region_fds[i], 1,
+                                  &dma_regions[i].fd, 1,
                                   NULL, NULL, 0, NULL, 0);
         if (ret < 0) {
             err(EXIT_FAILURE, "failed to map DMA regions");
@@ -1085,9 +1117,8 @@
 {
     char template[] = "/tmp/libvfio-user.XXXXXX";
     int ret, sock, irq_fd;
-    struct vfio_user_dma_map *dma_regions;
+    struct client_dma_region *dma_regions;
     struct vfio_user_device_info client_dev_info = {0};
-    int *dma_region_fds;
     int i;
     int tmpfd;
     int server_max_fds;
@@ -1176,21 +1207,21 @@
     unlink(template);
 
     dma_regions = calloc(nr_dma_regions, sizeof(*dma_regions));
-    dma_region_fds = calloc(nr_dma_regions, sizeof(*dma_region_fds));
-    if (dma_regions == NULL || dma_region_fds == NULL) {
+    if (dma_regions == NULL) {
         err(EXIT_FAILURE, "%m\n");
     }
 
     for (i = 0; i < nr_dma_regions; i++) {
-        dma_regions[i].argsz = sizeof(struct vfio_user_dma_map);
-        dma_regions[i].addr = i * sysconf(_SC_PAGESIZE);
-        dma_regions[i].size = sysconf(_SC_PAGESIZE);
-        dma_regions[i].offset = dma_regions[i].addr;
-        dma_regions[i].flags = VFIO_USER_F_DMA_REGION_READ | VFIO_USER_F_DMA_REGION_WRITE;
-        dma_region_fds[i] = tmpfd;
+        dma_regions[i].map.argsz = sizeof(struct vfio_user_dma_map);
+        dma_regions[i].map.addr = i * sysconf(_SC_PAGESIZE);
+        dma_regions[i].map.size = sysconf(_SC_PAGESIZE);
+        dma_regions[i].map.offset = dma_regions[i].map.addr;
+        dma_regions[i].map.flags = VFIO_USER_F_DMA_REGION_READ |
+                                   VFIO_USER_F_DMA_REGION_WRITE;
+        dma_regions[i].fd = tmpfd;
     }
 
-    map_dma_regions(sock, dma_regions, dma_region_fds, nr_dma_regions);
+    map_dma_regions(sock, dma_regions, nr_dma_regions);
 
     /*
      * XXX VFIO_USER_DEVICE_GET_IRQ_INFO and VFIO_IRQ_SET_ACTION_TRIGGER
@@ -1208,6 +1239,14 @@
     }
 
     /*
+     * Start client-side dirty page tracking. The actual marking happens in
+     * handle_dma_write() whenever a write succeeds.
+     */
+    for (i = 0; i < nr_dma_regions; i++) {
+        dma_regions[i].flags |= CLIENT_DIRTY_PAGE_TRACKING_ENABLED;
+    }
+
+    /*
      * XXX VFIO_USER_REGION_READ and VFIO_USER_REGION_WRITE
      *
      * BAR0 in the server does not support memory mapping so it must be accessed
@@ -1220,10 +1259,15 @@
 
     /* FIXME check that above took at least 1s */
 
-    handle_dma_io(sock, dma_regions, nr_dma_regions, dma_region_fds);
+    handle_dma_io(sock, dma_regions, nr_dma_regions);
 
     for (i = 0; i < nr_dma_regions; i++) {
-        get_dirty_bitmap(sock, &dma_regions[i]);
+        /*
+         * We expect regions 0 and 1 to be dirtied: region 0 via messages (so
+         * marked by the client) and region 1 directly (so marked by the
+         * server). See the end of main() in server.c.
+         */
+        get_dirty_bitmap(sock, &dma_regions[i], i < 2);
     }
 
     dirty_pages.argsz = sizeof(dirty_pages);
@@ -1235,6 +1279,11 @@
         err(EXIT_FAILURE, "failed to stop dirty page logging");
     }
 
+    /* Stop client-side dirty page tracking */
+    for (i = 0; i < nr_dma_regions; i++) {
+        dma_regions[i].flags &= ~CLIENT_DIRTY_PAGE_TRACKING_ENABLED;
+    }
+
     /* BAR1 can be memory mapped and read directly */
 
     /*
@@ -1245,8 +1294,8 @@
     for (i = 0; i < server_max_fds; i++) {
         struct vfio_user_dma_unmap r = {
             .argsz = sizeof(r),
-            .addr = dma_regions[i].addr,
-            .size = dma_regions[i].size
+            .addr = dma_regions[i].map.addr,
+            .size = dma_regions[i].map.size
         };
         ret = tran_sock_msg(sock, 7, VFIO_USER_DMA_UNMAP, &r, sizeof(r),
                             NULL, &r, sizeof(r));
@@ -1297,7 +1346,6 @@
      * unmapped.
      */
     map_dma_regions(sock, dma_regions + server_max_fds,
-                    dma_region_fds + server_max_fds,
                     nr_dma_regions - server_max_fds);
 
     /*
@@ -1311,8 +1359,7 @@
     wait_for_irq(irq_fd);
 
     handle_dma_io(sock, dma_regions + server_max_fds,
-                  nr_dma_regions - server_max_fds,
-                  dma_region_fds + server_max_fds);
+                  nr_dma_regions - server_max_fds);
 
     struct vfio_user_dma_unmap r = {
         .argsz = sizeof(r),
@@ -1327,7 +1374,6 @@
     }
 
     free(dma_regions);
-    free(dma_region_fds);
 
     return 0;
 }
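
Before moving on to the server side: here is the client-side dirty-state merge that the updated get_dirty_bitmap() performs, as a standalone sketch. A region counts as dirty if the server reported it in the bitmap or if the client flagged it while servicing a DMA write message. (The sample prints the raw bitmap byte; this sketch tests bit 0 explicitly, which is equivalent for a one-page region. `region_is_dirty` is an illustrative name.)

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

#define CLIENT_DIRTY_DMA_REGION (1 << 1) /* same bit as in client.c above */

/*
 * A one-page region is dirty if the server's bitmap says so (bit 0 of the
 * first bitmap byte) or if the client set its own flag when it serviced a
 * DMA write message for that region.
 */
static bool
region_is_dirty(const uint8_t *bitmap, uint32_t client_flags)
{
    bool dirtied_by_server = (bitmap[0] & 1) != 0;
    bool dirtied_by_client = (client_flags & CLIENT_DIRTY_DMA_REGION) != 0;
    return dirtied_by_server || dirtied_by_client;
}

int
main(void)
{
    uint8_t clean[1] = { 0 };
    uint8_t marked[1] = { 1 };

    assert(!region_is_dirty(clean, 0));
    assert(region_is_dirty(marked, 0));                      /* server-marked */
    assert(region_is_dirty(clean, CLIENT_DIRTY_DMA_REGION)); /* client-marked */
    return 0;
}
```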
diff --git a/samples/server.c b/samples/server.c
index 11f4074..565974d 100644
--- a/samples/server.c
+++ b/samples/server.c
@@ -192,7 +192,8 @@
  * sparsely memory mappable. We should also have a test where the server does
  * DMA directly on the client memory.
  */
-static void do_dma_io(vfu_ctx_t *vfu_ctx, struct server_data *server_data)
+static void do_dma_io(vfu_ctx_t *vfu_ctx, struct server_data *server_data,
+                      int region, bool use_messages)
 {
     const int size = 1024;
     const int count = 4;
@@ -206,21 +207,54 @@
 
     assert(vfu_ctx != NULL);
 
+    struct iovec iov = {0};
+
     /* Write some data, chunked into multiple calls to exercise offsets. */
     for (int i = 0; i < count; ++i) {
-        addr = server_data->regions[0].iova.iov_base + i * size;
+        addr = server_data->regions[region].iova.iov_base + i * size;
         ret = vfu_addr_to_sgl(vfu_ctx, (vfu_dma_addr_t)addr, size, sg, 1,
                               PROT_WRITE);
         if (ret < 0) {
             err(EXIT_FAILURE, "failed to map %p-%p", addr, addr + size - 1);
         }
 
         memset(&buf[i * size], 'A' + i, size);
-        vfu_log(vfu_ctx, LOG_DEBUG, "%s: WRITE addr %p size %d", __func__, addr,
-                size);
-        ret = vfu_sgl_write(vfu_ctx, sg, 1, &buf[i * size]);
-        if (ret < 0) {
-            err(EXIT_FAILURE, "vfu_sgl_write failed");
+
+        if (use_messages) {
+            vfu_log(vfu_ctx, LOG_DEBUG, "%s: MESSAGE WRITE addr %p size %d",
+                    __func__, addr, size);
+            ret = vfu_sgl_write(vfu_ctx, sg, 1, &buf[i * size]);
+            if (ret < 0) {
+                err(EXIT_FAILURE, "vfu_sgl_write failed");
+            }
+        } else {
+            vfu_log(vfu_ctx, LOG_DEBUG, "%s: DIRECT WRITE  addr %p size %d",
+                    __func__, addr, size);
+            ret = vfu_sgl_get(vfu_ctx, sg, &iov, 1, 0);
+            if (ret < 0) {
+                err(EXIT_FAILURE, "vfu_sgl_get failed");
+            }
+            assert(iov.iov_len == (size_t)size);
+            memcpy(iov.iov_base, &buf[i * size], size);
+
+            /*
+             * When writing directly to client memory, the server is
+             * responsible for tracking dirty pages. We assert that every
+             * write lands within region 1, which, like every region in this
+             * sample, is a single page.
+             *
+             * Note: this is not strictly necessary in this example, since we
+             * later call `vfu_sgl_put`, which marks pages dirty if the SGL was
+             * acquired with `PROT_WRITE`. However, `vfu_sgl_mark_dirty` is
+             * useful in cases where the server needs to mark guest memory dirty
+             * without releasing the memory with `vfu_sgl_put`.
+             */
+            vfu_sgl_mark_dirty(vfu_ctx, sg, 1);
+            assert(region == 1);
+            assert(i * size < (int)PAGE_SIZE);
+
+            vfu_sgl_put(vfu_ctx, sg, &iov, 1);
         }
     }
 
@@ -229,17 +263,30 @@
     /* Read the data back at double the chunk size. */
     memset(buf, 0, sizeof(buf));
     for (int i = 0; i < count; i += 2) {
-        addr = server_data->regions[0].iova.iov_base + i * size;
+        addr = server_data->regions[region].iova.iov_base + i * size;
         ret = vfu_addr_to_sgl(vfu_ctx, (vfu_dma_addr_t)addr, size * 2, sg, 1,
                               PROT_READ);
         if (ret < 0) {
             err(EXIT_FAILURE, "failed to map %p-%p", addr, addr + 2 * size - 1);
         }
-        vfu_log(vfu_ctx, LOG_DEBUG, "%s: READ  addr %p size %d", __func__, addr,
-                2 * size);
-        ret = vfu_sgl_read(vfu_ctx, sg, 1, &buf[i * size]);
-        if (ret < 0) {
-            err(EXIT_FAILURE, "vfu_sgl_read failed");
+
+        if (use_messages) {
+            vfu_log(vfu_ctx, LOG_DEBUG, "%s: MESSAGE READ  addr %p size %d",
+                    __func__, addr, 2 * size);
+            ret = vfu_sgl_read(vfu_ctx, sg, 1, &buf[i * size]);
+            if (ret < 0) {
+                err(EXIT_FAILURE, "vfu_sgl_read failed");
+            }
+        } else {
+            vfu_log(vfu_ctx, LOG_DEBUG, "%s: DIRECT READ   addr %p size %d",
+                    __func__, addr, 2 * size);
+            ret = vfu_sgl_get(vfu_ctx, sg, &iov, 1, 0);
+            if (ret < 0) {
+                err(EXIT_FAILURE, "vfu_sgl_get failed");
+            }
+            assert(iov.iov_len == 2 * (size_t)size);
+            memcpy(&buf[i * size], iov.iov_base, 2 * size);
+            vfu_sgl_put(vfu_ctx, sg, &iov, 1);
         }
     }
 
@@ -247,6 +294,9 @@
 
     if (crc1 != crc2) {
         errx(EXIT_FAILURE, "DMA write and DMA read mismatch");
+    } else {
+        vfu_log(vfu_ctx, LOG_DEBUG, "%s: %s success", __func__,
+                use_messages ? "MESSAGE" : "DIRECT");
     }
 }
 
@@ -603,14 +653,25 @@
                     err(EXIT_FAILURE, "vfu_irq_trigger() failed");
                 }
 
                 /*
-                 * We also initiate some dummy DMA via an explicit message,
-                 * again to show how DMA is done. This is used if the client's
-                 * RAM isn't mappable or the server implementation prefers it
-                 * this way.  Again, the client expects the server to send DMA
-                 * messages right after it has triggered the IRQs.
+                 * We initiate some dummy DMA by directly accessing the client's
+                 * memory. In this case, we keep track of dirty pages ourselves,
+                 * as the client has no way of knowing what we wrote to its
+                 * memory, or when.
                  */
-                do_dma_io(vfu_ctx, &server_data);
+                do_dma_io(vfu_ctx, &server_data, 1, false);
+
+                /*
+                 * We also do some dummy DMA via explicit messages to show how
+                 * DMA is done if the client's RAM isn't mappable or the server
+                 * implementation prefers it this way. In this case, the client
+                 * is responsible for tracking pages that are dirtied, as it is
+                 * the one actually performing the writes.
+                 */
+                do_dma_io(vfu_ctx, &server_data, 0, true);
+
                 ret = 0;
             }
         }
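
For reference, the server's direct-access path above condenses to the following helper. This is a sketch assuming a libvfio-user build environment: the calls and their argument order match their use in server.c, but `direct_dma_write` itself is an illustrative name, and real code may prefer heap allocation over alloca() for the SGL.

```c
#include <alloca.h>
#include <err.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/uio.h>

#include "libvfio-user.h"

static void
direct_dma_write(vfu_ctx_t *vfu_ctx, vfu_dma_addr_t addr,
                 const void *data, size_t len)
{
    dma_sg_t *sg = alloca(dma_sg_size());
    struct iovec iov = { 0 };

    /* Translate the guest address into a single-entry SGL. */
    if (vfu_addr_to_sgl(vfu_ctx, addr, len, sg, 1, PROT_WRITE) < 0) {
        err(EXIT_FAILURE, "failed to map %p", addr);
    }

    /* Obtain a local mapping of the guest memory... */
    if (vfu_sgl_get(vfu_ctx, sg, &iov, 1, 0) < 0) {
        err(EXIT_FAILURE, "vfu_sgl_get failed");
    }

    /* ...write through the mapping directly... */
    memcpy(iov.iov_base, data, len);

    /*
     * ...and log the dirtied pages. vfu_sgl_put() below also marks a
     * PROT_WRITE SGL dirty, so the explicit call is redundant here; it is
     * shown because it is what a server uses to mark pages dirty without
     * releasing the mapping.
     */
    vfu_sgl_mark_dirty(vfu_ctx, sg, 1);
    vfu_sgl_put(vfu_ctx, sg, &iov, 1);
}
```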