blob: d0663d696cace97f64561dc8557bc9e53443b32d [file] [log] [blame]
Huai-Cheng Kuobc419a12024-07-03 19:20:26 +10001/* SPDX-License-Identifier: BSD-3-Clause */
2/*
3 * QEMU SPDM socket support
4 *
5 * This is based on:
6 * https://github.com/DMTF/spdm-emu/blob/07c0a838bcc1c6207c656ac75885c0603e344b6f/spdm_emu/spdm_emu_common/command.c
7 * but has been re-written to match QEMU style
8 *
9 * Copyright (c) 2021, DMTF. All rights reserved.
10 * Copyright (c) 2023. Western Digital Corporation or its affiliates.
11 */
12
13#include "qemu/osdep.h"
14#include "sysemu/spdm-socket.h"
15#include "qapi/error.h"
16
/*
 * Read exactly @number_of_bytes bytes from @socket into @buffer.
 *
 * recv() may return fewer bytes than requested, so keep reading until the
 * whole request has been satisfied.  A recv() interrupted by a signal
 * (EINTR) is retried rather than treated as a fatal error.
 *
 * Returns true on success; false on a socket error or if the peer closed
 * the connection (recv() returned 0) before enough data arrived.  On
 * failure the contents of @buffer are unspecified.
 */
static bool read_bytes(const int socket, uint8_t *buffer,
                       size_t number_of_bytes)
{
    size_t number_received = 0;

    while (number_received < number_of_bytes) {
        ssize_t result = recv(socket, buffer + number_received,
                              number_of_bytes - number_received, 0);

        if (result < 0 && errno == EINTR) {
            /* Interrupted before any data was transferred: try again. */
            continue;
        }
        if (result <= 0) {
            /* Hard error, or orderly shutdown before the request completed. */
            return false;
        }
        number_received += (size_t)result;
    }
    return true;
}
33
/*
 * Read one 32-bit value transmitted in network (big-endian) byte order from
 * @socket and store it in host byte order in *@data.
 *
 * Returns false if the underlying read fails; *@data is then unspecified.
 */
static bool read_data32(const int socket, uint32_t *data)
{
    if (!read_bytes(socket, (uint8_t *)data, sizeof(*data))) {
        return false;
    }

    *data = ntohl(*data);
    return true;
}
45
/*
 * Read a length-prefixed payload from @socket: a 32-bit big-endian byte
 * count followed by that many bytes, which are stored into @buffer.
 *
 * Once the announced count has passed the @max_buffer_length check, it is
 * reported through @bytes_received (when non-NULL).  This happens before
 * the payload itself is read.
 *
 * Returns false if the length cannot be read, exceeds @max_buffer_length,
 * or the payload read fails.  A zero-length payload succeeds without
 * touching @buffer.
 */
static bool read_multiple_bytes(const int socket, uint8_t *buffer,
                                uint32_t *bytes_received,
                                uint32_t max_buffer_length)
{
    uint32_t payload_len;

    if (!read_data32(socket, &payload_len) ||
        payload_len > max_buffer_length) {
        return false;
    }

    if (bytes_received != NULL) {
        *bytes_received = payload_len;
    }

    return payload_len == 0 || read_bytes(socket, buffer, payload_len);
}
72
/*
 * Receive one platform message from the SPDM responder.
 *
 * Wire format (every integer is 32-bit big-endian):
 *   command | transport type | payload length | payload bytes
 *
 * @transport_type: transport the request was sent with; the reply must
 *                  carry the same value.
 * @command: set to the reply's command field once it has been read.
 * @receive_buffer: destination for the payload.
 * @bytes_to_receive: in: capacity of @receive_buffer in bytes;
 *                    out (on success): number of payload bytes received.
 *
 * Returns false on an I/O error, on a transport type mismatch, or if the
 * announced payload does not fit in @receive_buffer.
 */
static bool receive_platform_data(const int socket,
                                  uint32_t transport_type,
                                  uint32_t *command,
                                  uint8_t *receive_buffer,
                                  uint32_t *bytes_to_receive)
{
    uint32_t response;
    uint32_t reply_transport_type;
    uint32_t bytes_received = 0;

    if (!read_data32(socket, &response)) {
        return false;
    }
    *command = response;

    /*
     * Read the reply's transport type into its own variable so it can be
     * checked against the caller's.  The DMTF spdm-emu code this file is
     * based on also rejects a mismatch here.
     */
    if (!read_data32(socket, &reply_transport_type)) {
        return false;
    }
    if (reply_transport_type != transport_type) {
        return false;
    }

    if (!read_multiple_bytes(socket, receive_buffer, &bytes_received,
                             *bytes_to_receive)) {
        return false;
    }
    *bytes_to_receive = bytes_received;

    return true;
}
104
/*
 * Write exactly @number_of_bytes bytes from @buffer to @socket.
 *
 * send() may accept only part of the data, so keep sending until
 * everything has been written.  A send() interrupted by a signal (EINTR)
 * is retried.  A return of 0 is treated as failure; otherwise the loop
 * could spin forever without making progress.
 *
 * Returns true if every byte was sent, false otherwise.
 */
static bool write_bytes(const int socket, const uint8_t *buffer,
                        uint32_t number_of_bytes)
{
    size_t number_sent = 0;

    while (number_sent < number_of_bytes) {
        ssize_t result = send(socket, buffer + number_sent,
                              number_of_bytes - number_sent, 0);

        if (result < 0 && errno == EINTR) {
            /* Interrupted before any data was transferred: try again. */
            continue;
        }
        if (result <= 0) {
            return false;
        }
        number_sent += (size_t)result;
    }
    return true;
}
121
/*
 * Send @data to @socket as a 32-bit value in network (big-endian) byte
 * order.  Returns true if all four bytes were written.
 */
static bool write_data32(const int socket, uint32_t data)
{
    const uint32_t wire_value = htonl(data);

    return write_bytes(socket, (const uint8_t *)&wire_value,
                       sizeof(wire_value));
}
127
/*
 * Send a length-prefixed payload to @socket.  The 32-bit big-endian byte
 * count @bytes_to_send goes first, followed by that many bytes from
 * @buffer.  The payload is only sent if the length prefix was sent
 * successfully.
 */
static bool write_multiple_bytes(const int socket, const uint8_t *buffer,
                                 uint32_t bytes_to_send)
{
    return write_data32(socket, bytes_to_send) &&
           write_bytes(socket, buffer, bytes_to_send);
}
140
/*
 * Send one platform message to the SPDM responder.
 *
 * Wire format (every integer is 32-bit big-endian):
 *   command | transport type | payload length | payload bytes
 *
 * Stops and returns false as soon as any field fails to send.
 * NOTE(review): @bytes_to_send is narrowed to the 32-bit wire length field.
 * All callers in this file pass 32-bit lengths.
 */
static bool send_platform_data(const int socket,
                               uint32_t transport_type, uint32_t command,
                               const uint8_t *send_buffer, size_t bytes_to_send)
{
    if (!write_data32(socket, command) ||
        !write_data32(socket, transport_type)) {
        return false;
    }

    return write_multiple_bytes(socket, send_buffer, bytes_to_send);
}
159
160int spdm_socket_connect(uint16_t port, Error **errp)
161{
162 int client_socket;
163 struct sockaddr_in server_addr;
164
165 client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
166 if (client_socket < 0) {
167 error_setg(errp, "cannot create socket: %s", strerror(errno));
168 return -1;
169 }
170
171 memset((char *)&server_addr, 0, sizeof(server_addr));
172 server_addr.sin_family = AF_INET;
173 server_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
174 server_addr.sin_port = htons(port);
175
176
177 if (connect(client_socket, (struct sockaddr *)&server_addr,
178 sizeof(server_addr)) < 0) {
179 error_setg(errp, "cannot connect: %s", strerror(errno));
180 close(client_socket);
181 return -1;
182 }
183
184 return client_socket;
185}
186
187uint32_t spdm_socket_rsp(const int socket, uint32_t transport_type,
188 void *req, uint32_t req_len,
189 void *rsp, uint32_t rsp_len)
190{
191 uint32_t command;
192 bool result;
193
194 result = send_platform_data(socket, transport_type,
195 SPDM_SOCKET_COMMAND_NORMAL,
196 req, req_len);
197 if (!result) {
198 return 0;
199 }
200
201 result = receive_platform_data(socket, transport_type, &command,
202 (uint8_t *)rsp, &rsp_len);
203 if (!result) {
204 return 0;
205 }
206
207 assert(command != 0);
208
209 return rsp_len;
210}
211
212void spdm_socket_close(const int socket, uint32_t transport_type)
213{
214 send_platform_data(socket, transport_type,
215 SPDM_SOCKET_COMMAND_SHUTDOWN, NULL, 0);
216}