[PATCH 2/2] Add APIs to send file descriptors over local (Unix-domain) connections

Keith Packard keithp at keithp.com
Thu Oct 31 21:10:25 CET 2013


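Expose new TRANS(SendFd)/TRANS(RecvFd) APIs. Descriptors queued with
TRANS(SendFd) are sent as SCM_RIGHTS ancillary data with the next write on
a UNIX or local socket connection, and descriptors received alongside data
are handed out one at a time by TRANS(RecvFd). The TCP/INET transports
reject both calls with EINVAL.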

Signed-off-by: Keith Packard <keithp at keithp.com>
---
 Xtrans.c     |  12 +++
 Xtrans.h     |   4 +
 Xtransint.h  |  24 ++++++
 Xtranssock.c | 233 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
 4 files changed, 256 insertions(+), 17 deletions(-)

diff --git a/Xtrans.c b/Xtrans.c
index ac8b13d..7c7967f 100644
--- a/Xtrans.c
+++ b/Xtrans.c
@@ -873,6 +873,18 @@ TRANS(Writev) (XtransConnInfo ciptr, struct iovec *buf, int size)
 }
 
 int
+TRANS(SendFd) (XtransConnInfo ciptr, int fd, int do_close)
+{
+    return ciptr->transptr->SendFd(ciptr, fd, do_close);	/* NOTE(review): every Xtransport table must now fill in SendFd/RecvFd; this patch only updates the socket tables */
+}
+
+int
+TRANS(RecvFd) (XtransConnInfo ciptr)
+{
+    return ciptr->transptr->RecvFd(ciptr);	/* transport-specific; the socket transports return -1 when no fd is queued */
+}
+
+int
 TRANS(Disconnect) (XtransConnInfo ciptr)
 
 {
diff --git a/Xtrans.h b/Xtrans.h
index 1b0af45..53b8b62 100644
--- a/Xtrans.h
+++ b/Xtrans.h
@@ -350,6 +350,10 @@ int TRANS(Writev)(
     int			/* size */
 );
 
+int TRANS(SendFd) (XtransConnInfo ciptr, int fd, int do_close);	/* queue fd to go out with the next write; if do_close, close it once sent (-1/EINVAL where unsupported) */
+
+int TRANS(RecvFd) (XtransConnInfo ciptr);	/* next received fd (caller now owns it), or -1 if none/unsupported */
+
 int TRANS(Disconnect)(
     XtransConnInfo	/* ciptr */
 );
diff --git a/Xtransint.h b/Xtransint.h
index 3bce8dc..dd886db 100644
--- a/Xtransint.h
+++ b/Xtransint.h
@@ -72,6 +72,8 @@ from The Open Group.
 #  define XTRANSDEBUG 1
 #endif
 
+#define XTRANS_SEND_FDS       1	/* NOTE(review): unconditional, yet the SCM_RIGHTS/cmsghdr code it enables in Xtranssock.c has no WIN32 guard; confirm WIN32 builds */
+
 #ifdef WIN32
 # define _WILLWINSOCK_
 #endif
@@ -123,6 +125,16 @@ from The Open Group.
 #define X_TCP_PORT	6000
 #endif
 
+#if XTRANS_SEND_FDS
+/* One descriptor queued on a connection's send_fds or recv_fds FIFO. */
+struct _XtransConnFd {
+    struct _XtransConnFd   *next;	/* next entry toward the tail; NULL at the end */
+    int                    fd;
+    int                    do_close;	/* close fd when the entry is discarded (e.g. after it is sent) */
+};
+
+#endif
+
 struct _XtransConnInfo {
     struct _Xtransport     *transptr;
     int		index;
@@ -135,6 +147,8 @@ struct _XtransConnInfo {
     int		addrlen;
     char	*peeraddr;
     int		peeraddrlen;
+    struct _XtransConnFd        *recv_fds;
+    struct _XtransConnFd        *send_fds;
 };
 
 #define XTRANS_OPEN_COTS_CLIENT       1
@@ -275,6 +289,16 @@ typedef struct _Xtransport {
 	int			/* size */
     );
 
+    int (*SendFd)(
+	XtransConnInfo,		/* connection */
+        int,                    /* fd */
+        int                     /* do_close */
+    );
+
+    int (*RecvFd)(
+	XtransConnInfo		/* connection */
+    );
+
     int	(*Disconnect)(
 	XtransConnInfo		/* connection */
     );
diff --git a/Xtranssock.c b/Xtranssock.c
index 24269b2..23150b2 100644
--- a/Xtranssock.c
+++ b/Xtranssock.c
@@ -2097,47 +2097,176 @@ TRANS(SocketBytesReadable) (XtransConnInfo ciptr, BytesReadable_t *pend)
 #endif /* WIN32 */
 }
 
+#if XTRANS_SEND_FDS
+/* Append fd (with its do_close flag) at the tail of the list *prev, keeping FIFO order. */
+static void
+appendFd(struct _XtransConnFd **prev, int fd, int do_close)
+{
+    struct _XtransConnFd *cf, *new;
+
+    new = malloc (sizeof (struct _XtransConnFd));
+    if (!new) {
+        /* XXX mark connection as broken.  NOTE(review): closing here is wrong for SendFd callers passing do_close == 0, who still own fd */
+        close(fd);
+        return;
+    }
+    new->next = 0;
+    new->fd = fd;
+    new->do_close = do_close;
+    /* search to end of list (walks every entry on each append) */
+    for (; (cf = *prev); prev = &(cf->next));
+    *prev = new;
+}
 
 static int
-TRANS(SocketRead) (XtransConnInfo ciptr, char *buf, int size)
+removeFd(struct _XtransConnFd **prev)	/* pop the head entry: returns its fd (ownership passes to the caller) or -1 if empty */
+{
+    struct _XtransConnFd *cf;
+    int fd;
+
+    if ((cf = *prev)) {
+        *prev = cf->next;
+        fd = cf->fd;
+        free(cf);
+    } else
+        fd = -1;
+    return fd;
+}
 
+static void
+discardFd(struct _XtransConnFd **prev, struct _XtransConnFd *upto, int do_close)
 {
-    prmsg (2,"SocketRead(%d,%p,%d)\n", ciptr->fd, buf, size);
+    struct _XtransConnFd *cf, *next;
 
-#if defined(WIN32)
-    {
-	int ret = recv ((SOCKET)ciptr->fd, buf, size, 0);
-#ifdef WIN32
-	if (ret == SOCKET_ERROR) errno = WSAGetLastError();
-#endif
-	return ret;
+    for (cf = *prev; cf != upto; cf = next) {	/* free entries from the head up to, not including, upto */
+        next = cf->next;
+        if (do_close || cf->do_close)	/* do_close forces a close; otherwise honor each entry's flag */
+            close(cf->fd);
+        free(cf);
     }
-#else
-    return read (ciptr->fd, buf, size);
-#endif /* WIN32 */
+    *prev = upto;	/* upto (possibly NULL) becomes the new head */
 }
 
+static void
+cleanupFds(XtransConnInfo ciptr)
+{
+    /* Drop unsent fds; only those queued with do_close set are closed */
+    discardFd(&ciptr->send_fds, NULL, 0);
+    /* Drop the recv list and close every fd RecvFd has not handed out */
+    discardFd(&ciptr->recv_fds, NULL, 1);
+}
 
 static int
-TRANS(SocketWrite) (XtransConnInfo ciptr, char *buf, int size)
+nFd(struct _XtransConnFd **prev)	/* number of entries in the list */
+{
+    struct _XtransConnFd *cf;
+    int n = 0;
+
+    for (cf = *prev; cf; cf = cf->next)
+        n++;
+    return n;
+}
+
+static int
+TRANS(SocketRecvFd) (XtransConnInfo ciptr)
+{
+    prmsg (2, "SocketRecvFd(%d)\n", ciptr->fd);
+    return removeFd(&ciptr->recv_fds);	/* oldest received fd, or -1 (errno untouched) if none is queued */
+}
 
+static int
+TRANS(SocketSendFd) (XtransConnInfo ciptr, int fd, int do_close)
 {
-    prmsg (2,"SocketWrite(%d,%p,%d)\n", ciptr->fd, buf, size);
+    appendFd(&ciptr->send_fds, fd, do_close);	/* goes out with the next SocketWrite/SocketWritev */
+    return 0;	/* NOTE(review): 0 even if appendFd failed to allocate (and closed fd) */
+}
+
+static int
+TRANS(SocketRecvFdInvalid)(XtransConnInfo ciptr)
+{
+    errno = EINVAL;	/* only the UNIX and local socket transports support fd passing */
+    return -1;
+}
+
+static int
+TRANS(SocketSendFdInvalid)(XtransConnInfo ciptr, int fd, int do_close)
+{
+    errno = EINVAL;	/* only the UNIX and local socket transports support fd passing */
+    return -1;
+}
+
+#define MAX_FDS		128
+
+union fd_pass {		/* cmsghdr-aligned control buffer; the fds live at CMSG_DATA(), whose offset is platform dependent */
+	struct cmsghdr	cmsghdr;
+	char		buf[CMSG_SPACE(MAX_FDS * sizeof(int))];
+};
+
+static inline void init_msg_recv(struct msghdr *msg, struct iovec *iov, int niov, union fd_pass *pass, int nfd) {
+    msg->msg_name = NULL;
+    msg->msg_namelen = 0;
+    msg->msg_iov = iov;
+    msg->msg_iovlen = niov;
+    msg->msg_control = pass->buf;
+    msg->msg_controllen = CMSG_SPACE(nfd * sizeof (int));	/* room for nfd fds */
+}
+
+static inline void init_msg_send(struct msghdr *msg, struct iovec *iov, int niov, union fd_pass *pass, int nfd) {
+    init_msg_recv(msg, iov, niov, pass, nfd);
+    pass->cmsghdr.cmsg_len = CMSG_LEN(nfd * sizeof (int));	/* header + exactly nfd fds */
+    pass->cmsghdr.cmsg_level = SOL_SOCKET;
+    pass->cmsghdr.cmsg_type = SCM_RIGHTS;
+}
+#endif /* XTRANS_SEND_FDS */
+
+static int
+TRANS(SocketRead) (XtransConnInfo ciptr, char *buf, int size)
+
+{
+    prmsg (2,"SocketRead(%d,%p,%d)\n", ciptr->fd, buf, size);
 
 #if defined(WIN32)
     {
-	int ret = send ((SOCKET)ciptr->fd, buf, size, 0);
+	int ret = recv ((SOCKET)ciptr->fd, buf, size, 0);
 #ifdef WIN32
 	if (ret == SOCKET_ERROR) errno = WSAGetLastError();
 #endif
 	return ret;
     }
 #else
-    return write (ciptr->fd, buf, size);
+#if XTRANS_SEND_FDS
+    {
+        struct msghdr   msg;
+        struct iovec    iov;
+        union fd_pass   pass;
+        struct cmsghdr  *hdr;
+
+        iov.iov_base = buf;
+        iov.iov_len = size;
+        init_msg_recv(&msg, &iov, 1, &pass, MAX_FDS);
+        size = recvmsg(ciptr->fd, &msg, 0);
+        for (hdr = size >= 0 ? CMSG_FIRSTHDR(&msg) : NULL; hdr; hdr = CMSG_NXTHDR(&msg, hdr)) {
+            int *fd = (int *) CMSG_DATA(hdr);
+            int i, nfd;
+            if (hdr->cmsg_level != SOL_SOCKET || hdr->cmsg_type != SCM_RIGHTS)
+                continue;
+            /* Count from cmsg_len: msg_controllen may include trailing padding */
+            nfd = (hdr->cmsg_len - CMSG_LEN(0)) / sizeof (int);
+            for (i = 0; i < nfd; i++) {
+                if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))
+                    close(fd[i]);	/* incomplete set: close rather than leak */
+                else
+                    appendFd(&ciptr->recv_fds, fd[i], 0);
+            }
+        }
+        return size;
+    }
+#else
+    return read(ciptr->fd, buf, size);
+#endif /* XTRANS_SEND_FDS */
 #endif /* WIN32 */
 }
 
-
 static int
 TRANS(SocketReadv) (XtransConnInfo ciptr, struct iovec *buf, int size)
 
@@ -2154,11 +2283,65 @@ TRANS(SocketWritev) (XtransConnInfo ciptr, struct iovec *buf, int size)
 {
     prmsg (2,"SocketWritev(%d,%p,%d)\n", ciptr->fd, buf, size);
 
+#if XTRANS_SEND_FDS
+    if (ciptr->send_fds)
+    {
+        struct msghdr           msg;
+        union fd_pass           pass;
+        int                     nfd = nFd(&ciptr->send_fds);
+        struct _XtransConnFd    *cf = ciptr->send_fds;
+        int                     *fds = (int *) CMSG_DATA(&pass.cmsghdr);
+        int                     i;
+
+        /* pass only has room for MAX_FDS; the rest stay queued for the next write */
+        if (nfd > MAX_FDS)
+            nfd = MAX_FDS;
+        for (i = 0; i < nfd; i++) {
+            fds[i] = cf->fd;
+            cf = cf->next;
+        }
+        init_msg_send(&msg, buf, size, &pass, nfd);
+        i = sendmsg(ciptr->fd, &msg, 0);
+        /* cf is the first unsent entry, so only the fds just sent are dropped */
+        if (i > 0)
+            discardFd(&ciptr->send_fds, cf, 0);
+        return i;
+    }
+#endif
     return WRITEV (ciptr, buf, size);
 }
 
 
 static int
+TRANS(SocketWrite) (XtransConnInfo ciptr, char *buf, int size)
+
+{
+    prmsg (2,"SocketWrite(%d,%p,%d)\n", ciptr->fd, buf, size);
+
+#if defined(WIN32)
+    {
+	int ret = send ((SOCKET)ciptr->fd, buf, size, 0);
+#ifdef WIN32
+	if (ret == SOCKET_ERROR) errno = WSAGetLastError();
+#endif
+	return ret;
+    }
+#else
+#if XTRANS_SEND_FDS
+    if (ciptr->send_fds)	/* fds are pending: they can only travel via sendmsg() */
+    {
+        struct iovec            iov;
+
+        iov.iov_base = buf;
+        iov.iov_len = size;
+        return TRANS(SocketWritev)(ciptr, &iov, 1);	/* SocketWritev attaches the queued fds to this data */
+    }
+#endif /* XTRANS_SEND_FDS */
+    return write (ciptr->fd, buf, size);
+#endif /* WIN32 */
+}
+
+static int
 TRANS(SocketDisconnect) (XtransConnInfo ciptr)
 
 {
@@ -2211,6 +2394,9 @@ TRANS(SocketUNIXClose) (XtransConnInfo ciptr)
 
     prmsg (2,"SocketUNIXClose(%p,%d)\n", ciptr, ciptr->fd);
 
+#if XTRANS_SEND_FDS
+    cleanupFds(ciptr);
+#endif
     ret = close(ciptr->fd);
 
     if (ciptr->flags
@@ -2239,6 +2425,9 @@ TRANS(SocketUNIXCloseForCloning) (XtransConnInfo ciptr)
     prmsg (2,"SocketUNIXCloseForCloning(%p,%d)\n",
 	ciptr, ciptr->fd);
 
+#if XTRANS_SEND_FDS
+    cleanupFds(ciptr);
+#endif
     ret = close(ciptr->fd);
 
     return ret;
@@ -2293,6 +2482,8 @@ Xtransport	TRANS(SocketTCPFuncs) = {
 	TRANS(SocketWrite),
 	TRANS(SocketReadv),
 	TRANS(SocketWritev),
+        TRANS(SocketSendFdInvalid),
+        TRANS(SocketRecvFdInvalid),
 	TRANS(SocketDisconnect),
 	TRANS(SocketINETClose),
 	TRANS(SocketINETClose),
@@ -2333,6 +2524,8 @@ Xtransport	TRANS(SocketINETFuncs) = {
 	TRANS(SocketWrite),
 	TRANS(SocketReadv),
 	TRANS(SocketWritev),
+        TRANS(SocketSendFdInvalid),
+        TRANS(SocketRecvFdInvalid),
 	TRANS(SocketDisconnect),
 	TRANS(SocketINETClose),
 	TRANS(SocketINETClose),
@@ -2374,6 +2567,8 @@ Xtransport     TRANS(SocketINET6Funcs) = {
 	TRANS(SocketWrite),
 	TRANS(SocketReadv),
 	TRANS(SocketWritev),
+        TRANS(SocketSendFdInvalid),
+        TRANS(SocketRecvFdInvalid),
 	TRANS(SocketDisconnect),
 	TRANS(SocketINETClose),
 	TRANS(SocketINETClose),
@@ -2422,6 +2617,8 @@ Xtransport	TRANS(SocketLocalFuncs) = {
 	TRANS(SocketWrite),
 	TRANS(SocketReadv),
 	TRANS(SocketWritev),
+        TRANS(SocketSendFd),
+        TRANS(SocketRecvFd),
 	TRANS(SocketDisconnect),
 	TRANS(SocketUNIXClose),
 	TRANS(SocketUNIXCloseForCloning),
@@ -2476,6 +2673,8 @@ Xtransport	TRANS(SocketUNIXFuncs) = {
 	TRANS(SocketWrite),
 	TRANS(SocketReadv),
 	TRANS(SocketWritev),
+        TRANS(SocketSendFd),
+        TRANS(SocketRecvFd),
 	TRANS(SocketDisconnect),
 	TRANS(SocketUNIXClose),
 	TRANS(SocketUNIXCloseForCloning),
-- 
1.8.4.rc3



More information about the xorg-devel mailing list