Pattern for calling WCF service using async/await
I think a feasible solution might be to use a custom awaiter to flow the new operation context via OperationContext.Current. The implementation of OperationContext itself doesn't appear to require thread affinity. Here is the pattern:
// Usage pattern: create the client, open a FlowingOperationContextScope over its
// inner channel, and route every awaited service call through ContinueOnScope
// with that same scope.
async Task TestAsync()
{
    using (var serviceClient = new WcfAPM.ServiceClient())
    {
        using (var flowScope = new FlowingOperationContextScope(serviceClient.InnerChannel))
        {
            // First service call, awaited through the scope-aware awaiter.
            await serviceClient.SomeMethodAsync(1).ContinueOnScope(flowScope);

            // Second service call, awaited against the same scope.
            await serviceClient.AnotherMethodAsync(2).ContinueOnScope(flowScope);
        }
    }
}
Here is the implementation of FlowingOperationContextScope and ContinueOnScope (only slightly tested):
/// <summary>
/// Installs an <see cref="OperationContext"/> as <c>OperationContext.Current</c> for the
/// lifetime of the scope and re-installs it after every await performed through
/// <c>ContinueOnScope</c>. The context that was current at construction time is put
/// back by <see cref="Dispose"/>.
/// </summary>
public sealed class FlowingOperationContextScope : IDisposable
{
    // True between BeforeAwait() and the matching AfterAwait().
    bool _inflight = false;
    bool _disposed;
    // The context this scope installs and keeps re-installing after awaits.
    OperationContext _thisContext = null;
    // The context that was current when the scope was created.
    OperationContext _originalContext = null;

    /// <summary>Creates a scope around a new <see cref="OperationContext"/> built from <paramref name="channel"/>.</summary>
    public FlowingOperationContextScope(IContextChannel channel)
        : this(new OperationContext(channel))
    {
    }

    /// <summary>
    /// Remembers the currently installed context and makes <paramref name="context"/> current.
    /// </summary>
    public FlowingOperationContextScope(OperationContext context)
    {
        _originalContext = OperationContext.Current;
        OperationContext.Current = _thisContext = context;
    }

    /// <summary>
    /// Restores the context that was current when the scope was created.
    /// Subsequent calls are no-ops.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// An await started through this scope has not finished yet, or
    /// <c>OperationContext.Current</c> is not the context this scope installed.
    /// </exception>
    public void Dispose()
    {
        if (!_disposed)
        {
            if (_inflight || OperationContext.Current != _thisContext)
            {
                throw new InvalidOperationException();
            }

            _disposed = true;
            OperationContext.Current = _originalContext;
            _thisContext = null;
            _originalContext = null;
        }
    }

    /// <summary>Marks an await as in flight; does nothing if one is already marked.</summary>
    internal void BeforeAwait()
    {
        if (_inflight)
        {
            return;
        }

        _inflight = true;
        // leave _thisContext as the current context
    }

    /// <summary>Clears the in-flight mark and makes this scope's context current again.</summary>
    /// <exception cref="InvalidOperationException">No await was marked as in flight.</exception>
    internal void AfterAwait()
    {
        if (!_inflight)
        {
            throw new InvalidOperationException();
        }

        _inflight = false;
        // ignore the current context, restore _thisContext
        OperationContext.Current = _thisContext;
    }
}

// ContinueOnScope extension
public static class TaskExt
{
    /// <summary>
    /// Wraps <paramref name="this"/> in an awaitable that calls the scope's
    /// <c>BeforeAwait</c> when the await has to suspend and its <c>AfterAwait</c>
    /// right before the continuation runs.
    /// </summary>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="this"/> or <paramref name="scope"/> is null.
    /// </exception>
    public static SimpleAwaiter<TResult> ContinueOnScope<TResult>(this Task<TResult> @this, FlowingOperationContextScope scope)
    {
        if (@this == null)
        {
            throw new ArgumentNullException("this");
        }

        if (scope == null)
        {
            throw new ArgumentNullException("scope");
        }

        return new SimpleAwaiter<TResult>(@this, scope.BeforeAwait, scope.AfterAwait);
    }

    // awaiter
    /// <summary>
    /// Awaitable/awaiter over a <see cref="Task{TResult}"/> that invokes one callback
    /// before suspending and another before resuming the continuation.
    /// </summary>
    public class SimpleAwaiter<TResult> :
        System.Runtime.CompilerServices.INotifyCompletion
    {
        readonly Task<TResult> _task;
        readonly Action _beforeAwait;
        readonly Action _afterAwait;

        public SimpleAwaiter(Task<TResult> task, Action beforeAwait, Action afterAwait)
        {
            _task = task;
            _beforeAwait = beforeAwait;
            _afterAwait = afterAwait;
        }

        /// <summary>Returns this instance; the type acts as both awaitable and awaiter.</summary>
        public SimpleAwaiter<TResult> GetAwaiter()
        {
            return this;
        }

        /// <summary>
        /// True when the task has already completed. When it has not, the
        /// before-await callback is invoked as a side effect before returning false.
        /// </summary>
        public bool IsCompleted
        {
            get
            {
                // don't do anything if the task completed synchronously
                // (we're on the same thread)
                if (_task.IsCompleted)
                {
                    return true;
                }

                _beforeAwait();
                return false;
            }
        }

        /// <summary>Returns the task's result, rethrowing the task's own exception on failure.</summary>
        public TResult GetResult()
        {
            // Task.Result would wrap a failure in AggregateException; going through the
            // task's awaiter rethrows the original exception, as a plain `await task` does.
            return _task.GetAwaiter().GetResult();
        }

        // INotifyCompletion
        /// <summary>
        /// Schedules the after-await callback followed by <paramref name="continuation"/>
        /// to run when the task completes.
        /// </summary>
        public void OnCompleted(Action continuation)
        {
            // Use the current SynchronizationContext's scheduler when there is one,
            // otherwise the current TaskScheduler.
            _task.ContinueWith(task =>
            {
                _afterAwait();
                continuation();
            },
            CancellationToken.None,
            TaskContinuationOptions.ExecuteSynchronously,
            SynchronizationContext.Current != null
                ? TaskScheduler.FromCurrentSynchronizationContext()
                : TaskScheduler.Current);
        }
    }
}
A simple way is to move the await outside the using block:
/// <summary>
/// Starts the document request inside an <see cref="OperationContextScope"/> and
/// awaits it only after the scope has been disposed.
/// </summary>
/// <param name="docId">Identifier passed through to the client's GetDocumentAsync call.</param>
/// <returns>The result of the awaited client call.</returns>
public async Task<Document> GetDocumentAsync(string docId)
{
    var docClient = CreateDocumentServiceClient();

    // Declared outside the using block so it is still in scope for the await below.
    Task<Document> task;
    using (new OperationContextScope(docClient.InnerChannel))
    {
        // Start the call while the scope is active, but do not await here:
        // the scope is disposed synchronously at the end of this block.
        task = docClient.GetDocumentAsync(docId);
    }

    return await task;
}
I decided to write my own code to help with this, and I am posting it in case it helps anyone. There seems to be a little less that can go wrong (unforeseen races, etc.) compared with the SimpleAwaiter implementation above, but you be the judge:
/// <summary>
/// Extension methods that wrap a task in an awaiter which re-installs the
/// <c>OperationContext.Current</c> captured at suspension time when the await resumes.
/// </summary>
public static class WithOperationContextTaskExtensions
{
    /// <summary>Wraps a result-returning task in an operation-context-preserving awaitable.</summary>
    /// <param name="this">The task to await.</param>
    /// <param name="configureAwait">Forwarded to <c>Task.ConfigureAwait</c> as continueOnCapturedContext.</param>
    public static ContinueOnOperationContextAwaiter<TResult> WithOperationContext<TResult>(this Task<TResult> @this, bool configureAwait = true)
    {
        return new ContinueOnOperationContextAwaiter<TResult>(@this, configureAwait);
    }

    /// <summary>Wraps a task in an operation-context-preserving awaitable.</summary>
    /// <param name="this">The task to await.</param>
    /// <param name="configureAwait">Forwarded to <c>Task.ConfigureAwait</c> as continueOnCapturedContext.</param>
    public static ContinueOnOperationContextAwaiter WithOperationContext(this Task @this, bool configureAwait = true)
    {
        return new ContinueOnOperationContextAwaiter(@this, configureAwait);
    }

    /// <summary>Awaitable/awaiter for <see cref="Task"/> that restores the captured operation context on resume.</summary>
    public class ContinueOnOperationContextAwaiter : INotifyCompletion
    {
        private readonly ConfiguredTaskAwaitable.ConfiguredTaskAwaiter _awaiter;
        private OperationContext _operationContext;
        // Set only when OnCompleted ran, i.e. when the await actually suspended.
        private bool _contextCaptured;

        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public ContinueOnOperationContextAwaiter(Task task, bool continueOnCapturedContext = true)
        {
            if (task == null)
            {
                throw new ArgumentNullException("task");
            }

            _awaiter = task.ConfigureAwait(continueOnCapturedContext).GetAwaiter();
        }

        /// <summary>Returns this instance; the type acts as both awaitable and awaiter.</summary>
        public ContinueOnOperationContextAwaiter GetAwaiter()
        {
            return this;
        }

        public bool IsCompleted
        {
            get { return _awaiter.IsCompleted; }
        }

        /// <summary>Captures the current operation context, then registers the continuation.</summary>
        public void OnCompleted(Action continuation)
        {
            _operationContext = OperationContext.Current;
            _contextCaptured = true;
            _awaiter.OnCompleted(continuation);
        }

        /// <summary>Restores the captured context (if any) and observes the task's outcome.</summary>
        public void GetResult()
        {
            // When the task was already complete, OnCompleted never ran and nothing was
            // captured; leave OperationContext.Current alone instead of overwriting it with null.
            if (_contextCaptured)
            {
                OperationContext.Current = _operationContext;
            }

            _awaiter.GetResult();
        }
    }

    /// <summary>Awaitable/awaiter for <see cref="Task{TResult}"/> that restores the captured operation context on resume.</summary>
    public class ContinueOnOperationContextAwaiter<TResult> : INotifyCompletion
    {
        private readonly ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter _awaiter;
        private OperationContext _operationContext;
        // Set only when OnCompleted ran, i.e. when the await actually suspended.
        private bool _contextCaptured;

        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public ContinueOnOperationContextAwaiter(Task<TResult> task, bool continueOnCapturedContext = true)
        {
            if (task == null)
            {
                throw new ArgumentNullException("task");
            }

            _awaiter = task.ConfigureAwait(continueOnCapturedContext).GetAwaiter();
        }

        /// <summary>Returns this instance; the type acts as both awaitable and awaiter.</summary>
        public ContinueOnOperationContextAwaiter<TResult> GetAwaiter()
        {
            return this;
        }

        public bool IsCompleted
        {
            get { return _awaiter.IsCompleted; }
        }

        /// <summary>Captures the current operation context, then registers the continuation.</summary>
        public void OnCompleted(Action continuation)
        {
            _operationContext = OperationContext.Current;
            _contextCaptured = true;
            _awaiter.OnCompleted(continuation);
        }

        /// <summary>Restores the captured context (if any) and returns the task's result.</summary>
        public TResult GetResult()
        {
            // When the task was already complete, OnCompleted never ran and nothing was
            // captured; leave OperationContext.Current alone instead of overwriting it with null.
            if (_contextCaptured)
            {
                OperationContext.Current = _operationContext;
            }

            return _awaiter.GetResult();
        }
    }
}
Usage (a little manual and nesting is untested...):
/// <summary>
/// Make a call to the service and return its result together with the
/// headers of the incoming HTTP response.
/// </summary>
/// <typeparam name="TResult">Type of the value produced by <paramref name="action"/>.</typeparam>
/// <param name="action">Delegate that performs the service call on the constructed channel.</param>
/// <param name="endpoint">Endpoint address passed to <c>ConstructChannel</c>.</param>
/// <returns>
/// A wrapper built from the call's result and a read-only dictionary of the response headers.
/// </returns>
public async Task<ResultCallWrapper<TResult>> CallAsync<TResult>(Func<T, Task<TResult>> action, EndpointAddress endpoint)
{
    using (ChannelLifetime<T> channelLifetime = new ChannelLifetime<T>(ConstructChannel(endpoint)))
    {
        // OperationContextScope doesn't work with async/await
        var oldContext = OperationContext.Current;
        OperationContext.Current = new OperationContext((IContextChannel)channelLifetime.Channel);

        try
        {
            var result = await action(channelLifetime.Channel)
                .WithOperationContext(configureAwait: false);

            // Read the response headers from the context installed above.
            HttpResponseMessageProperty incomingMessageProperty = (HttpResponseMessageProperty)OperationContext.Current.IncomingMessageProperties[HttpResponseMessageProperty.Name];

            string[] keys = incomingMessageProperty.Headers.AllKeys;
            var headersOrig = keys.ToDictionary(t => t, t => incomingMessageProperty.Headers[t]);

            return new ResultCallWrapper<TResult>(result, new ReadOnlyDictionary<string, string>(headersOrig));
        }
        finally
        {
            // Put the previous context back even when the call or the header
            // extraction throws, so the temporary context is never left installed.
            OperationContext.Current = oldContext;
        }
    }
}