From: Connor McAdams <cmcadams@codeweavers.com>
v3: Use a helper to check channel state.
Signed-off-by: Connor McAdams <cmcadams@codeweavers.com> Signed-off-by: Hans Leidekker <hans@codeweavers.com> --- dlls/webservices/channel.c | 66 ++++++++++++++++++++------------ dlls/webservices/tests/channel.c | 17 +++++++- 2 files changed, 57 insertions(+), 26 deletions(-)
diff --git a/dlls/webservices/channel.c b/dlls/webservices/channel.c index 7dd78335e11..56ad5692a6b 100644 --- a/dlls/webservices/channel.c +++ b/dlls/webservices/channel.c @@ -328,7 +328,7 @@ static void reset_channel( struct channel *channel ) channel->session_state = SESSION_STATE_UNINITIALIZED; clear_addr( &channel->addr ); init_dict( &channel->dict_send, channel->dict_size ); - init_dict( &channel->dict_recv, 0 ); + init_dict( &channel->dict_recv, channel->dict_size ); channel->msg = NULL; channel->read_size = 0; channel->send_size = 0; @@ -486,6 +486,7 @@ static HRESULT create_channel( WS_CHANNEL_TYPE type, WS_CHANNEL_BINDING binding, channel->encoding = WS_ENCODING_XML_BINARY_SESSION_1; channel->dict_size = 2048; channel->dict_send.str_bytes_max = channel->dict_size; + channel->dict_recv.str_bytes_max = channel->dict_size; break;
case WS_UDP_CHANNEL_BINDING: @@ -546,6 +547,7 @@ static HRESULT create_channel( WS_CHANNEL_TYPE type, WS_CHANNEL_BINDING binding,
channel->dict_size = *(ULONG *)prop->value; channel->dict_send.str_bytes_max = channel->dict_size; + channel->dict_recv.str_bytes_max = channel->dict_size; break;
default: @@ -875,6 +877,13 @@ static HRESULT queue_shutdown_session( struct channel *channel, const WS_ASYNC_C return queue_task( &channel->send_q, &s->task ); }
+static HRESULT check_state( struct channel *channel, WS_CHANNEL_STATE state_expected ) +{ + if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED; + if (channel->state != state_expected) return WS_E_INVALID_OPERATION; + return S_OK; +} + HRESULT WINAPI WsShutdownSessionChannel( WS_CHANNEL *handle, const WS_ASYNC_CONTEXT *ctx, WS_ERROR *error ) { struct channel *channel = (struct channel *)handle; @@ -894,10 +903,10 @@ HRESULT WINAPI WsShutdownSessionChannel( WS_CHANNEL *handle, const WS_ASYNC_CONT LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); @@ -1258,10 +1267,10 @@ HRESULT WINAPI WsOpenChannel( WS_CHANNEL *handle, const WS_ENDPOINT_ADDRESS *end LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_CREATED) + if ((hr = check_state( channel, WS_CHANNEL_STATE_CREATED )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); @@ -1592,10 +1601,10 @@ HRESULT channel_send_message( WS_CHANNEL *handle, WS_MESSAGE *msg ) LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
hr = send_message_bytes( channel, msg ); @@ -1786,10 +1795,10 @@ HRESULT WINAPI WsSendMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_MESS LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
WsInitializeMessage( msg, WS_BLANK_MESSAGE, NULL, NULL ); @@ -1832,10 +1841,10 @@ HRESULT WINAPI WsSendReplyMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
WsInitializeMessage( msg, WS_REPLY_MESSAGE, NULL, NULL ); @@ -2201,6 +2210,11 @@ static HRESULT build_dict( const BYTE *buf, ULONG buflen, struct dictionary *dic init_dict( dict, 0 ); return WS_E_INVALID_FORMAT; } + if (size + dict->str_bytes + 1 > dict->str_bytes_max) + { + hr = WS_E_QUOTA_EXCEEDED; + goto error; + } buflen -= size; if (!(bytes = malloc( size ))) { @@ -2246,7 +2260,11 @@ static HRESULT receive_message_bytes_session( struct channel *channel ) { ULONG size; if ((hr = build_dict( (const BYTE *)channel->read_buf, channel->read_size, &channel->dict_recv, - &size )) != S_OK) return hr; + &size )) != S_OK) + { + if (hr == WS_E_QUOTA_EXCEEDED) channel->state = WS_CHANNEL_STATE_FAULTED; + return hr; + } channel->read_size -= size; memmove( channel->read_buf, channel->read_buf + size, channel->read_size ); } @@ -2303,10 +2321,10 @@ HRESULT channel_receive_message( WS_CHANNEL *handle, WS_MESSAGE *msg ) LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if ((hr = receive_message_bytes( channel, msg )) == S_OK) hr = init_reader( channel ); @@ -2443,10 +2461,10 @@ HRESULT WINAPI WsReceiveMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_M LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); @@ -2561,10 +2579,10 @@ HRESULT WINAPI WsRequestReply( WS_CHANNEL *handle, WS_MESSAGE *request, const WS LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
WsInitializeMessage( request, WS_REQUEST_MESSAGE, NULL, NULL ); @@ -2645,10 +2663,10 @@ HRESULT WINAPI WsReadMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); @@ -2797,10 +2815,10 @@ HRESULT WINAPI WsWriteMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const W LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); @@ -2877,10 +2895,10 @@ HRESULT WINAPI WsWriteMessageEnd( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_ LeaveCriticalSection( &channel->cs ); return E_INVALIDARG; } - if (channel->state != WS_CHANNEL_STATE_OPEN) + if ((hr = check_state( channel, WS_CHANNEL_STATE_OPEN )) != S_OK) { LeaveCriticalSection( &channel->cs ); - return WS_E_INVALID_OPERATION; + return hr; }
if (!ctx) async_init( &async, &ctx_local ); diff --git a/dlls/webservices/tests/channel.c b/dlls/webservices/tests/channel.c index d5d35904c1c..7e0a9becbc5 100644 --- a/dlls/webservices/tests/channel.c +++ b/dlls/webservices/tests/channel.c @@ -822,6 +822,7 @@ static void client_duplex_session_dict( const struct listener_info *info ) WS_MESSAGE_DESCRIPTION desc; WS_ENDPOINT_ADDRESS addr; WS_CHANNEL_PROPERTY prop; + WS_CHANNEL_STATE state; int dict_str_cnt = 0; char elem_name[128]; WS_CHANNEL *channel; @@ -893,12 +894,24 @@ static void client_duplex_session_dict( const struct listener_info *info ) local_name.bytes = (BYTE *)short_dict_str; hr = WsReceiveMessage( channel, msg, descs, 1, WS_RECEIVE_REQUIRED_MESSAGE, WS_READ_REQUIRED_VALUE, NULL, &val, sizeof(val), NULL, NULL, NULL ); - todo_wine ok( hr == WS_E_QUOTA_EXCEEDED, "got %#lx\n", hr); + ok( hr == WS_E_QUOTA_EXCEEDED, "got %#lx\n", hr); + + state = 0xdeadbeef; + hr = WsGetChannelProperty( channel, WS_CHANNEL_PROPERTY_STATE, &state, sizeof(state), NULL ); + ok( hr == S_OK, "got %#lx\n", hr ); + ok( state == WS_CHANNEL_STATE_FAULTED, "got %u\n", state ); + + hr = WsReceiveMessage( channel, msg, descs, 1, WS_RECEIVE_REQUIRED_MESSAGE, WS_READ_REQUIRED_VALUE, + NULL, &val, sizeof(val), NULL, NULL, NULL ); + ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr ); + + hr = WsSendMessage( channel, msg, &desc, WS_WRITE_REQUIRED_VALUE, &val, sizeof(val), NULL, NULL ); + ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
WsFreeMessage( msg );
hr = WsShutdownSessionChannel( channel, NULL, NULL ); - todo_wine ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr ); + ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
hr = WsCloseChannel( channel, NULL, NULL ); ok( hr == S_OK, "got %#lx\n", hr );