Signed-off-by: Hans Leidekker <hans(a)codeweavers.com>
---
dlls/webservices/channel.c | 167 ++++++++++++++++++++++++++-------------------
1 file changed, 95 insertions(+), 72 deletions(-)
diff --git a/dlls/webservices/channel.c b/dlls/webservices/channel.c
index 5394aa6c28..139e625a36 100644
--- a/dlls/webservices/channel.c
+++ b/dlls/webservices/channel.c
@@ -1144,33 +1144,19 @@ done:
return hr;
}
-static void set_blocking( SOCKET socket, BOOL blocking )
-{
- ULONG state = !blocking;
- ioctlsocket( socket, FIONBIO, &state );
-}
-
-static int sock_recv( SOCKET socket, char *buf, int len )
-{
- int count, ret;
-
- if ((ret = recv( socket, buf, len, 0 )) <= 0) return ret;
- len -= ret;
-
- set_blocking( socket, FALSE );
- for (;;)
- {
- if ((count = recv( socket, buf + ret, len, 0 )) <= 0) break;
- ret += count;
- len -= count;
- }
- set_blocking( socket, TRUE );
- return ret;
-}
-
static HRESULT receive_bytes( struct channel *channel, unsigned char *bytes, int len )
{
- int count = sock_recv( channel->u.tcp.socket, (char *)bytes, len );
- if (count < 0) return HRESULT_FROM_WIN32( WSAGetLastError() );
- if (count != len) return WS_E_INVALID_FORMAT;
+    int count;
+
+    /* recv() on a stream socket may return fewer bytes than requested, so keep
+     * reading until all len bytes have arrived; if the peer closes the connection
+     * first, the data is incomplete and WS_E_INVALID_FORMAT is returned */
+    while (len > 0)
+    {
+        if ((count = recv( channel->u.tcp.socket, (char *)bytes, len, 0 )) < 0)
+            return HRESULT_FROM_WIN32( WSAGetLastError() );
+        if (!count) return WS_E_INVALID_FORMAT;
+        bytes += count;
+        len -= count;
+    }
return S_OK;
@@ -1214,8 +1190,6 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
HRESULT hr;
channel->msg = msg;
- if ((hr = connect_channel( channel )) != S_OK) return hr;
-
WsGetMessageProperty( channel->msg, WS_MESSAGE_PROPERTY_BODY_WRITER, &writer, sizeof(writer), NULL );
WsGetWriterProperty( writer, WS_XML_WRITER_PROPERTY_BYTES, &buf, sizeof(buf), NULL );
@@ -1232,7 +1206,7 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
return send_message_http( channel->u.http.request, buf.bytes, buf.length );
case WS_TCP_CHANNEL_BINDING:
- if (channel->encoding == WS_ENCODING_XML_BINARY_SESSION_1)
+ if (channel->type & WS_CHANNEL_TYPE_SESSION) /* session handling is keyed on the channel type, not the encoding */
{
switch (channel->session_state)
{
@@ -1249,10 +1223,10 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
return WS_E_OTHER;
}
}
- return send_bytes( channel->u.tcp.socket, buf.bytes, buf.length );
+ /* fall through: TCP channels without a session are flushed like UDP */
case WS_UDP_CHANNEL_BINDING:
- return send_bytes( channel->u.udp.socket, buf.bytes, buf.length );
+ return WsFlushWriter( writer, 0, NULL, NULL ); /* UTF-8 writers here use stream output (see init_writer), so the flush sends via write_callback; NOTE(review): binary writers get buffer output, confirm nothing is left unsent */
default:
ERR( "unhandled binding %u\n", channel->binding );
@@ -1273,7 +1247,7 @@ HRESULT channel_send_message( WS_CHANNEL *handle, WS_MESSAGE *msg )
return E_INVALIDARG;
}
- hr = send_message( channel, msg );
+ if ((hr = connect_channel( channel )) == S_OK) hr = send_message( channel, msg ); /* send_message no longer connects on its own */
LeaveCriticalSection( &channel->cs );
return hr;
@@ -1305,11 +1279,45 @@ static HRESULT CALLBACK dict_cb( void *state, const WS_XML_STRING *str, BOOL *fo
return hr;
}
+/* Stream output callback installed by init_writer for UTF-8 UDP and non-session TCP
+ * channels.  state points to the channel's SOCKET; buf is an array of count buffers.
+ * send() may accept only part of a buffer, so keep sending until every byte is out,
+ * and return the Winsock error as an HRESULT instead of swallowing it. */
+static HRESULT CALLBACK write_callback( void *state, const WS_BYTES *buf, ULONG count,
+                                        const WS_ASYNC_CONTEXT *ctx, WS_ERROR *error )
+{
+    SOCKET socket = *(SOCKET *)state;
+    ULONG i;
+
+    for (i = 0; i < count; i++)
+    {
+        const char *ptr = (const char *)buf[i].bytes;
+        ULONG len = buf[i].length;
+
+        while (len)
+        {
+            int sent = send( socket, ptr, len, 0 );
+            if (sent < 0)
+            {
+                int err = WSAGetLastError();
+                TRACE( "send failed %u\n", err );
+                return HRESULT_FROM_WIN32( err );
+            }
+            ptr += sent;
+            len -= sent;
+        }
+    }
+    return S_OK;
+}
+
static HRESULT init_writer( struct channel *channel )
{
WS_XML_WRITER_BUFFER_OUTPUT buf = {{WS_XML_WRITER_OUTPUT_TYPE_BUFFER}};
+ WS_XML_WRITER_STREAM_OUTPUT stream = {{WS_XML_WRITER_OUTPUT_TYPE_STREAM}};
WS_XML_WRITER_TEXT_ENCODING text = {{WS_XML_WRITER_ENCODING_TYPE_TEXT}, WS_CHARSET_UTF8};
WS_XML_WRITER_BINARY_ENCODING bin = {{WS_XML_WRITER_ENCODING_TYPE_BINARY}};
+ const WS_XML_WRITER_ENCODING *encoding;
+ const WS_XML_WRITER_OUTPUT *output;
WS_XML_WRITER_PROPERTY prop;
ULONG max_size = (1 << 17);
HRESULT hr;
@@ -1322,19 +1310,33 @@ static HRESULT init_writer( struct channel *channel )
switch (channel->encoding)
{
case WS_ENCODING_XML_UTF8:
- return WsSetOutput( channel->writer, &text.encoding, &buf.output, NULL, 0, NULL );
+ encoding = &text.encoding;
+ if (channel->binding == WS_UDP_CHANNEL_BINDING || /* UDP and non-session TCP: stream output to the socket */
+ (channel->binding == WS_TCP_CHANNEL_BINDING && !(channel->type & WS_CHANNEL_TYPE_SESSION)))
+ {
+ stream.writeCallback = write_callback;
+ stream.writeCallbackState = (channel->binding == WS_UDP_CHANNEL_BINDING) ? /* write_callback dereferences this as a SOCKET * */
+ &channel->u.udp.socket : &channel->u.tcp.socket;
+ output = &stream.output;
+ }
+ else output = &buf.output; /* buffered; send_message reads it back via WS_XML_WRITER_PROPERTY_BYTES */
+ break;
case WS_ENCODING_XML_BINARY_SESSION_1:
bin.staticDictionary = (WS_XML_DICTIONARY *)&dict_builtin_static.dict;
/* fall through */
case WS_ENCODING_XML_BINARY_1:
- return WsSetOutput( channel->writer, &bin.encoding, &buf.output, NULL, 0, NULL );
+ encoding = &bin.encoding;
+ output = &buf.output; /* binary encodings are always buffered, whatever the binding */
+ break;
default:
FIXME( "unhandled encoding %u\n", channel->encoding );
return WS_E_NOT_SUPPORTED;
}
+
+ return WsSetOutput( channel->writer, encoding, output, NULL, 0, NULL );
}
static HRESULT write_message( struct channel *channel, WS_MESSAGE *msg, const WS_ELEMENT_DESCRIPTION *desc,
@@ -1378,6 +1380,7 @@ HRESULT WINAPI WsSendMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_MESS
if ((hr = WsAddressMessage( msg, &channel->addr, NULL )) != S_OK) goto done;
if ((hr = message_set_action( msg, desc->action )) != S_OK) goto done;
+ if ((hr = connect_channel( channel )) != S_OK) goto done; /* connect before writing: with stream output, data may be sent while the message is written */
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = write_message( channel, msg, desc->bodyElementDescription, option, body, size )) != S_OK) goto done;
hr = send_message( channel, msg );
@@ -1419,6 +1422,7 @@ HRESULT WINAPI WsSendReplyMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
if ((hr = message_get_id( request, &req_id )) != S_OK) goto done;
if ((hr = message_set_request_id( msg, &req_id )) != S_OK) goto done;
+ if ((hr = connect_channel( channel )) != S_OK) goto done; /* connect before writing: with stream output, data may be sent while the message is written */
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = write_message( channel, msg, desc->bodyElementDescription, option, body, size )) != S_OK) goto done;
hr = send_message( channel, msg );
@@ -1448,12 +1452,36 @@ static HRESULT resize_read_buffer( struct channel *channel, ULONG size )
return S_OK;
}
+/* Stream input callback installed by init_reader for UTF-8 UDP and non-session TCP
+ * channels.  state points to the channel's SOCKET.  A recv() result of 0 is passed on
+ * as retlen 0; a recv() failure is returned as an HRESULT rather than being turned
+ * into an apparent end of input. */
+static HRESULT CALLBACK read_callback( void *state, void *buf, ULONG buflen, ULONG *retlen,
+                                       const WS_ASYNC_CONTEXT *ctx, WS_ERROR *error )
+{
+    SOCKET socket = *(SOCKET *)state;
+    int ret, err;
+
+    *retlen = 0;
+    if ((ret = recv( socket, buf, buflen, 0 )) >= 0)
+    {
+        *retlen = ret;
+        return S_OK;
+    }
+
+    err = WSAGetLastError();
+    TRACE( "recv failed %u\n", err );
+    return HRESULT_FROM_WIN32( err );
+}
+
static HRESULT init_reader( struct channel *channel )
{
WS_XML_READER_BUFFER_INPUT buf = {{WS_XML_READER_INPUT_TYPE_BUFFER}};
+ WS_XML_READER_STREAM_INPUT stream = {{WS_XML_READER_INPUT_TYPE_STREAM}};
WS_XML_READER_TEXT_ENCODING text = {{WS_XML_READER_ENCODING_TYPE_TEXT}};
WS_XML_READER_BINARY_ENCODING bin = {{WS_XML_READER_ENCODING_TYPE_BINARY}};
- WS_XML_READER_ENCODING *encoding;
+ const WS_XML_READER_ENCODING *encoding;
+ const WS_XML_READER_INPUT *input;
HRESULT hr;
if (!channel->reader && (hr = WsCreateReader( NULL, 0, &channel->reader, NULL )) != S_OK) return hr;
@@ -1463,6 +1484,21 @@ static HRESULT init_reader( struct channel *channel )
case WS_ENCODING_XML_UTF8:
text.charSet = WS_CHARSET_UTF8;
encoding = &text.encoding;
+
+ if (channel->binding == WS_UDP_CHANNEL_BINDING || /* UDP and non-session TCP: stream input from the socket */
+ (channel->binding == WS_TCP_CHANNEL_BINDING && !(channel->type & WS_CHANNEL_TYPE_SESSION)))
+ {
+ stream.readCallback = read_callback;
+ stream.readCallbackState = (channel->binding == WS_UDP_CHANNEL_BINDING) ? /* read_callback dereferences this as a SOCKET * */
+ &channel->u.udp.socket : &channel->u.tcp.socket;
+ input = &stream.input;
+ }
+ else
+ {
+ buf.encodedData = channel->read_buf; /* filled by receive_message_bytes, which callers run before init_reader */
+ buf.encodedDataSize = channel->read_size;
+ input = &buf.input;
+ }
break;
case WS_ENCODING_XML_BINARY_SESSION_1:
@@ -1472,6 +1508,10 @@ static HRESULT init_reader( struct channel *channel )
case WS_ENCODING_XML_BINARY_1:
encoding = &bin.encoding;
+
+ buf.encodedData = channel->read_buf; /* NOTE(review): for UDP/non-session TCP receive_message_bytes no longer fills read_buf; confirm binary encodings cannot reach here on those bindings */
+ buf.encodedDataSize = channel->read_size;
+ input = &buf.input;
break;
default:
@@ -1479,9 +1519,7 @@ static HRESULT init_reader( struct channel *channel )
return WS_E_NOT_SUPPORTED;
}
- buf.encodedData = channel->read_buf;
- buf.encodedDataSize = channel->read_size;
- return WsSetInput( channel->reader, encoding, &buf.input, NULL, 0, NULL );
+ return WsSetInput( channel->reader, encoding, input, NULL, 0, NULL );
}
#define INITIAL_READ_BUFFER_SIZE 4096
@@ -1515,26 +1553,6 @@ static HRESULT receive_message_http( struct channel *channel )
offset += bytes_read;
}
- return init_reader( channel );
-}
-
-static HRESULT receive_message_unsized( struct channel *channel, SOCKET socket )
-{
- int bytes_read;
- ULONG max_len;
- HRESULT hr;
-
- prop_get( channel->prop, channel->prop_count, WS_CHANNEL_PROPERTY_MAX_BUFFERED_MESSAGE_SIZE,
- &max_len, sizeof(max_len) );
-
- if ((hr = resize_read_buffer( channel, max_len )) != S_OK) return hr;
-
- channel->read_size = 0;
- if ((bytes_read = sock_recv( socket, channel->read_buf, max_len )) < 0)
- {
- return HRESULT_FROM_WIN32( WSAGetLastError() );
- }
- channel->read_size = bytes_read;
return S_OK;
}
@@ -1549,7 +1567,7 @@ static HRESULT receive_message_sized( struct channel *channel, unsigned int size
channel->read_size = 0;
while (channel->read_size < size)
{
- if ((bytes_read = sock_recv( channel->u.tcp.socket, channel->read_buf + offset, to_read )) < 0)
+ if ((bytes_read = recv( channel->u.tcp.socket, channel->read_buf + offset, to_read, 0 )) < 0) /* the enclosing loop keeps reading until size bytes are in */
{
return HRESULT_FROM_WIN32( WSAGetLastError() );
}
@@ -1798,14 +1816,7 @@ static HRESULT receive_message_session( struct channel *channel )
memmove( channel->read_buf, channel->read_buf + size, channel->read_size );
}
- return init_reader( channel );
-}
-
-static HRESULT receive_message_sock( struct channel *channel, SOCKET socket )
-{
- HRESULT hr;
- if ((hr = receive_message_unsized( channel, socket )) != S_OK) return hr;
- return init_reader( channel );
+ return S_OK; /* callers now run init_reader after receive_message_bytes */
}
static HRESULT receive_message_bytes( struct channel *channel )
@@ -1819,7 +1830,7 @@ static HRESULT receive_message_bytes( struct channel *channel )
return receive_message_http( channel );
case WS_TCP_CHANNEL_BINDING:
- if (channel->encoding == WS_ENCODING_XML_BINARY_SESSION_1)
+ if (channel->type & WS_CHANNEL_TYPE_SESSION) /* session handling is keyed on the channel type, not the encoding */
{
switch (channel->session_state)
{
@@ -1836,10 +1847,10 @@ static HRESULT receive_message_bytes( struct channel *channel )
return WS_E_OTHER;
}
}
- return receive_message_sock( channel, channel->u.tcp.socket );
+ return S_OK; /* nothing to do, data is read through stream callback */
case WS_UDP_CHANNEL_BINDING:
- return receive_message_sock( channel, channel->u.udp.socket );
+ return S_OK; /* nothing to do for UTF-8: init_reader sets up read_callback on the socket */
default:
ERR( "unhandled binding %u\n", channel->binding );
@@ -1860,7 +1871,7 @@ HRESULT channel_receive_message( WS_CHANNEL *handle )
return E_INVALIDARG;
}
- hr = receive_message_bytes( channel );
+ if ((hr = receive_message_bytes( channel )) == S_OK) hr = init_reader( channel ); /* set up the reader only after a successful receive */
LeaveCriticalSection( &channel->cs );
return hr;
@@ -1901,6 +1912,8 @@ static HRESULT receive_message( struct channel *channel, WS_MESSAGE *msg, const
ULONG i;
if ((hr = receive_message_bytes( channel )) != S_OK) return hr;
+ if ((hr = init_reader( channel )) != S_OK) return hr; /* receive_message_bytes no longer sets up the reader */
+
for (i = 0; i < count; i++)
{
const WS_ELEMENT_DESCRIPTION *body = desc[i]->bodyElementDescription;
@@ -2012,6 +2025,7 @@ static HRESULT request_reply( struct channel *channel, WS_MESSAGE *request,
if ((hr = WsAddressMessage( request, &channel->addr, NULL )) != S_OK) return hr;
if ((hr = message_set_action( request, request_desc->action )) != S_OK) return hr;
+ if ((hr = connect_channel( channel )) != S_OK) return hr; /* connect before writing: with stream output, data may be sent while the request is written */
if ((hr = init_writer( channel )) != S_OK) return hr;
if ((hr = write_message( channel, request, request_desc->bodyElementDescription, write_option, request_body,
request_size )) != S_OK) return hr;
@@ -2141,7 +2155,8 @@ HRESULT WINAPI WsReadMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
if ((hr = receive_message_bytes( channel )) == S_OK)
{
- hr = WsReadEnvelopeStart( msg, channel->reader, NULL, NULL, NULL );
+ if ((hr = init_reader( channel )) == S_OK) /* receive_message_bytes no longer sets up the reader */
+ hr = WsReadEnvelopeStart( msg, channel->reader, NULL, NULL, NULL );
}
LeaveCriticalSection( &channel->cs );
@@ -2202,6 +2217,7 @@ HRESULT WINAPI WsWriteMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const W
return E_INVALIDARG;
}
+ if ((hr = connect_channel( channel )) != S_OK) goto done; /* connect before writing: with stream output, data may be sent while the envelope is written */
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = WsAddressMessage( msg, &channel->addr, NULL )) != S_OK) goto done;
hr = WsWriteEnvelopeStart( msg, channel->writer, NULL, NULL, NULL );
@@ -2235,13 +2251,20 @@ HRESULT WINAPI WsWriteMessageEnd( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_
return E_INVALIDARG;
}
- if ((hr = WsWriteEnvelopeEnd( msg, NULL )) == S_OK) hr = send_message( channel, msg );
+ if ((hr = WsWriteEnvelopeEnd( msg, NULL )) == S_OK && (hr = connect_channel( channel )) == S_OK)
+ hr = send_message( channel, msg ); /* otherwise hr holds the first failing HRESULT */
LeaveCriticalSection( &channel->cs );
TRACE( "returning %08x\n", hr );
return hr;
}
+static void set_blocking( SOCKET socket, BOOL blocking )
+{
+ ULONG state = !blocking; /* ioctlsocket FIONBIO: non-zero enables non-blocking mode */
+ ioctlsocket( socket, FIONBIO, &state );
+}
+
static HRESULT sock_accept( SOCKET socket, HANDLE wait, HANDLE cancel, SOCKET *ret )
{
HANDLE handles[] = { wait, cancel };
--
2.11.0