1 person likes it.

A variant of Async.AwaitTask that honors the ambient cancellation semantics of the enclosing async workflow.

Related to https://github.com/Microsoft/visualfsharp/issues/2127

 1: 
 2: 
 3: 
 4: 
 5: 
 6: 
 7: 
 8: 
 9: 
10: 
11: 
12: 
13: 
14: 
15: 
16: 
17: 
18: 
19: 
20: 
21: 
22: 
23: 
24: 
25: 
26: 
27: 
28: 
29: 
30: 
31: 
32: 
33: 
34: 
35: 
36: 
37: 
38: 
39: 
40: 
41: 
42: 
43: 
44: 
45: 
46: 
47: 
48: 
49: 
50: 
51: 
52: 
53: 
54: 
55: 
56: 
57: 
58: 
59: 
60: 
61: 
62: 
63: 
64: 
open System
open System.Threading
open System.Threading.Tasks

/// One-shot gate used to pick a single winner among racing callbacks:
/// the first call to `Enter` yields `true`, every later call yields `false`.
type private Latch() =
    // Count of Enter() attempts so far; only ever modified through Interlocked.
    let mutable entries = 0

    /// Atomically records one more attempt and reports whether this caller
    /// was the very first one to arrive.
    member inline this.Enter() =
        let attempt = Interlocked.Increment(&entries)
        attempt = 1

/// Shared implementation behind both `Async.AwaitTask2` overloads.
///
/// Builds an async computation that resumes when either `task` finishes or the
/// workflow's ambient cancellation token fires, whichever happens first.
///
/// `task`      - the task to wait for.
/// `getResult` - produces the value handed to the success continuation; it is
///               only invoked after `task` has completed without faulting or
///               being cancelled.
///
/// Outcomes:
///   * ambient token cancelled first -> cancellation continuation, with an
///     `OperationCanceledException` carrying that token;
///   * task faulted                  -> exception continuation, with the single
///     inner exception when there is exactly one, otherwise the whole
///     `AggregateException`;
///   * task cancelled                -> exception continuation, with a
///     `TaskCanceledException` referencing the task;
///   * task succeeded                -> success continuation with `getResult ()`.
let private awaitTaskHonoringCancellation (task : Task) (getResult : unit -> 'T) : Async<'T> = async {
    let! ct = Async.CancellationToken
    return! Async.FromContinuations(fun (sc, ec, cc) ->
        // The cancellation callback and the task continuation race each other;
        // the latch guarantees that only one of them resumes the workflow.
        let latch = Latch()

        let registration =
            ct.Register(Action(fun () ->
                if latch.Enter() then cc (OperationCanceledException(ct))))

        let onCompleted (completed : Task) =
            if latch.Enter() then
                // The task won the race: the cancellation callback is no longer needed.
                registration.Dispose()
                if completed.IsFaulted then
                    let aggregate = completed.Exception
                    // Unwrap the common single-failure case so callers see the real exception.
                    if aggregate.InnerExceptions.Count = 1 then ec aggregate.InnerExceptions.[0]
                    else ec (aggregate :> exn)
                elif completed.IsCanceled then
                    // A cancelled task is reported as an error of the awaited task,
                    // not as cancellation of the awaiting workflow.
                    ec (TaskCanceledException(completed) :> exn)
                else
                    sc (getResult ())

        // Explicit delegate type keeps ContinueWith overload resolution unambiguous.
        task.ContinueWith(Action<Task>(onCompleted)) |> ignore)
}

type Async with
    /// Awaits `task` and returns its result, while also honoring the ambient
    /// cancellation token of the enclosing async workflow: if that token is
    /// cancelled before the task completes, the workflow is cancelled instead
    /// of waiting for the task.
    static member AwaitTask2(task : Task<'T>) : Async<'T> =
        awaitTaskHonoringCancellation (task :> Task) (fun () -> task.Result)

    /// Awaits a non-generic `task`, while also honoring the ambient
    /// cancellation token of the enclosing async workflow: if that token is
    /// cancelled before the task completes, the workflow is cancelled instead
    /// of waiting for the task.
    static member AwaitTask2(task : Task) : Async<unit> =
        awaitTaskHonoringCancellation task ignore

// example

/// Example harness: runs ten workers in parallel and blocks until they finish.
/// Workers 0..8 await (via `taskAwaiter`) a task that this function never
/// completes; worker 9 fails with "error". The outcome therefore shows how
/// the supplied awaiter behaves when a sibling computation fails.
let test taskAwaiter =
    let scenario = async {
        // Completion source whose task is never completed by this function.
        let pending = new TaskCompletionSource<unit>()

        let worker index =
            if index < 9 then
                async { do! taskAwaiter pending.Task }
            else
                async { do failwith "error" }

        let workers = Seq.init 10 worker
        return! Async.Parallel workers
    }
    Async.RunSynchronously scenario

// Runs the example with the stock awaiter. Per the author's note this call never
// returns (nothing in `test` ever completes the awaited task), so when evaluating
// the file top to bottom the next line is presumably only reached if this one is
// commented out.
test Async.AwaitTask  // hangs forever
// Runs the same example with the cancellation-aware awaiter defined in this file.
// Per the author's note this call fails with exceptions instead of hanging.
test Async.AwaitTask2 // fails with exceptions
namespace System
namespace System.Threading
namespace System.Threading.Tasks
Multiple items
type private Latch =
  new : unit -> Latch
  member Enter : unit -> bool

Full name: Script.Latch

--------------------
private new : unit -> Latch
val mutable counter : int
member private Latch.Enter : unit -> bool

Full name: Script.Latch.Enter
type Interlocked =
  static member Add : location1:int * value:int -> int + 1 overload
  static member CompareExchange : location1:int * value:int * comparand:int -> int + 6 overloads
  static member Decrement : location:int -> int + 1 overload
  static member Exchange : location1:int * value:int -> int + 6 overloads
  static member Increment : location:int -> int + 1 overload
  static member Read : location:int64 -> int64

Full name: System.Threading.Interlocked
Interlocked.Increment(location: byref<int64>) : int64
Interlocked.Increment(location: byref<int>) : int
Multiple items
type Async
static member AsBeginEnd : computation:('Arg -> Async<'T>) -> ('Arg * AsyncCallback * obj -> IAsyncResult) * (IAsyncResult -> 'T) * (IAsyncResult -> unit)
static member AwaitEvent : event:IEvent<'Del,'T> * ?cancelAction:(unit -> unit) -> Async<'T> (requires delegate and 'Del :> Delegate)
static member AwaitIAsyncResult : iar:IAsyncResult * ?millisecondsTimeout:int -> Async<bool>
static member AwaitTask : task:Task -> Async<unit>
static member AwaitTask : task:Task<'T> -> Async<'T>
static member AwaitWaitHandle : waitHandle:WaitHandle * ?millisecondsTimeout:int -> Async<bool>
static member CancelDefaultToken : unit -> unit
static member Catch : computation:Async<'T> -> Async<Choice<'T,exn>>
static member FromBeginEnd : beginAction:(AsyncCallback * obj -> IAsyncResult) * endAction:(IAsyncResult -> 'T) * ?cancelAction:(unit -> unit) -> Async<'T>
static member FromBeginEnd : arg:'Arg1 * beginAction:('Arg1 * AsyncCallback * obj -> IAsyncResult) * endAction:(IAsyncResult -> 'T) * ?cancelAction:(unit -> unit) -> Async<'T>
static member FromBeginEnd : arg1:'Arg1 * arg2:'Arg2 * beginAction:('Arg1 * 'Arg2 * AsyncCallback * obj -> IAsyncResult) * endAction:(IAsyncResult -> 'T) * ?cancelAction:(unit -> unit) -> Async<'T>
static member FromBeginEnd : arg1:'Arg1 * arg2:'Arg2 * arg3:'Arg3 * beginAction:('Arg1 * 'Arg2 * 'Arg3 * AsyncCallback * obj -> IAsyncResult) * endAction:(IAsyncResult -> 'T) * ?cancelAction:(unit -> unit) -> Async<'T>
static member FromContinuations : callback:(('T -> unit) * (exn -> unit) * (OperationCanceledException -> unit) -> unit) -> Async<'T>
static member Ignore : computation:Async<'T> -> Async<unit>
static member OnCancel : interruption:(unit -> unit) -> Async<IDisposable>
static member Parallel : computations:seq<Async<'T>> -> Async<'T []>
static member RunSynchronously : computation:Async<'T> * ?timeout:int * ?cancellationToken:CancellationToken -> 'T
static member Sleep : millisecondsDueTime:int -> Async<unit>
static member Start : computation:Async<unit> * ?cancellationToken:CancellationToken -> unit
static member StartAsTask : computation:Async<'T> * ?taskCreationOptions:TaskCreationOptions * ?cancellationToken:CancellationToken -> Task<'T>
static member StartChild : computation:Async<'T> * ?millisecondsTimeout:int -> Async<Async<'T>>
static member StartChildAsTask : computation:Async<'T> * ?taskCreationOptions:TaskCreationOptions -> Async<Task<'T>>
static member StartImmediate : computation:Async<unit> * ?cancellationToken:CancellationToken -> unit
static member StartWithContinuations : computation:Async<'T> * continuation:('T -> unit) * exceptionContinuation:(exn -> unit) * cancellationContinuation:(OperationCanceledException -> unit) * ?cancellationToken:CancellationToken -> unit
static member SwitchToContext : syncContext:SynchronizationContext -> Async<unit>
static member SwitchToNewThread : unit -> Async<unit>
static member SwitchToThreadPool : unit -> Async<unit>
static member TryCancelled : computation:Async<'T> * compensation:(OperationCanceledException -> unit) -> Async<'T>
static member CancellationToken : Async<CancellationToken>
static member DefaultCancellationToken : CancellationToken

Full name: Microsoft.FSharp.Control.Async

--------------------
type Async<'T>

Full name: Microsoft.FSharp.Control.Async<_>
static member Async.AwaitTask2 : task:Task<'T> -> Async<'T>

Full name: Script.AwaitTask2
val task : Task<'T>
Multiple items
type Task =
  new : action:Action -> Task + 7 overloads
  member AsyncState : obj
  member ContinueWith : continuationAction:Action<Task> -> Task + 9 overloads
  member CreationOptions : TaskCreationOptions
  member Dispose : unit -> unit
  member Exception : AggregateException
  member Id : int
  member IsCanceled : bool
  member IsCompleted : bool
  member IsFaulted : bool
  ...

Full name: System.Threading.Tasks.Task

--------------------
type Task<'TResult> =
  inherit Task
  new : function:Func<'TResult> -> Task<'TResult> + 7 overloads
  member ContinueWith : continuationAction:Action<Task<'TResult>> -> Task + 9 overloads
  member Result : 'TResult with get, set
  static member Factory : TaskFactory<'TResult>

Full name: System.Threading.Tasks.Task<_>

--------------------
Task(action: Action) : unit
Task(action: Action, cancellationToken: CancellationToken) : unit
Task(action: Action, creationOptions: TaskCreationOptions) : unit
Task(action: Action<obj>, state: obj) : unit
Task(action: Action, cancellationToken: CancellationToken, creationOptions: TaskCreationOptions) : unit
Task(action: Action<obj>, state: obj, cancellationToken: CancellationToken) : unit
Task(action: Action<obj>, state: obj, creationOptions: TaskCreationOptions) : unit
Task(action: Action<obj>, state: obj, cancellationToken: CancellationToken, creationOptions: TaskCreationOptions) : unit

--------------------
Task(function: Func<'TResult>) : unit
Task(function: Func<'TResult>, cancellationToken: CancellationToken) : unit
Task(function: Func<'TResult>, creationOptions: TaskCreationOptions) : unit
Task(function: Func<obj,'TResult>, state: obj) : unit
Task(function: Func<'TResult>, cancellationToken: CancellationToken, creationOptions: TaskCreationOptions) : unit
Task(function: Func<obj,'TResult>, state: obj, cancellationToken: CancellationToken) : unit
Task(function: Func<obj,'TResult>, state: obj, creationOptions: TaskCreationOptions) : unit
Task(function: Func<obj,'TResult>, state: obj, cancellationToken: CancellationToken, creationOptions: TaskCreationOptions) : unit
val async : AsyncBuilder

Full name: Microsoft.FSharp.Core.ExtraTopLevelOperators.async
val ct : CancellationToken
property Async.CancellationToken: Async<CancellationToken>
static member Async.FromContinuations : callback:(('T -> unit) * (exn -> unit) * (OperationCanceledException -> unit) -> unit) -> Async<'T>
val sc : ('T -> unit)
val ec : (exn -> unit)
val cc : (OperationCanceledException -> unit)
val l : Latch
val ctrDisposer : CancellationTokenRegistration
CancellationToken.Register(callback: Action) : CancellationTokenRegistration
CancellationToken.Register(callback: Action<obj>, state: obj) : CancellationTokenRegistration
CancellationToken.Register(callback: Action, useSynchronizationContext: bool) : CancellationTokenRegistration
CancellationToken.Register(callback: Action<obj>, state: obj, useSynchronizationContext: bool) : CancellationTokenRegistration
member private Latch.Enter : unit -> bool
Multiple items
type OperationCanceledException =
  inherit SystemException
  new : unit -> OperationCanceledException + 5 overloads
  member CancellationToken : CancellationToken with get, set

Full name: System.OperationCanceledException

--------------------
OperationCanceledException() : unit
OperationCanceledException(message: string) : unit
OperationCanceledException(token: CancellationToken) : unit
OperationCanceledException(message: string, innerException: exn) : unit
OperationCanceledException(message: string, token: CancellationToken) : unit
OperationCanceledException(message: string, innerException: exn, token: CancellationToken) : unit
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>) : Task<'TResult>
   (+0 other overloads)
Task.ContinueWith(continuationAction: Action<Task>) : Task
   (+0 other overloads)
Task.ContinueWith<'TNewResult>(continuationFunction: Func<Task<'T>,'TNewResult>) : Task<'TNewResult>
   (+0 other overloads)
Task.ContinueWith(continuationAction: Action<Task<'T>>) : Task
   (+0 other overloads)
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, continuationOptions: TaskContinuationOptions) : Task<'TResult>
   (+0 other overloads)
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, scheduler: TaskScheduler) : Task<'TResult>
   (+0 other overloads)
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, cancellationToken: CancellationToken) : Task<'TResult>
   (+0 other overloads)
Task.ContinueWith(continuationAction: Action<Task>, continuationOptions: TaskContinuationOptions) : Task
   (+0 other overloads)
Task.ContinueWith(continuationAction: Action<Task>, scheduler: TaskScheduler) : Task
   (+0 other overloads)
Task.ContinueWith(continuationAction: Action<Task>, cancellationToken: CancellationToken) : Task
   (+0 other overloads)
CancellationTokenRegistration.Dispose() : unit
property Task.IsFaulted: bool
val e : AggregateException
property Task.Exception: AggregateException
property AggregateException.InnerExceptions: Collections.ObjectModel.ReadOnlyCollection<exn>
property Collections.ObjectModel.ReadOnlyCollection.Count: int
property Task.IsCanceled: bool
Multiple items
type TaskCanceledException =
  inherit OperationCanceledException
  new : unit -> TaskCanceledException + 3 overloads
  member Task : Task

Full name: System.Threading.Tasks.TaskCanceledException

--------------------
TaskCanceledException() : unit
TaskCanceledException(message: string) : unit
TaskCanceledException(task: Task) : unit
TaskCanceledException(message: string, innerException: exn) : unit
property Task.Result: 'T
val ignore : value:'T -> unit

Full name: Microsoft.FSharp.Core.Operators.ignore
static member Async.AwaitTask2 : task:Task -> Async<unit>

Full name: Script.AwaitTask2
val task : Task
type unit = Unit

Full name: Microsoft.FSharp.Core.unit
val sc : (unit -> unit)
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>) : Task<'TResult>
Task.ContinueWith(continuationAction: Action<Task>) : Task
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, continuationOptions: TaskContinuationOptions) : Task<'TResult>
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, scheduler: TaskScheduler) : Task<'TResult>
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, cancellationToken: CancellationToken) : Task<'TResult>
Task.ContinueWith(continuationAction: Action<Task>, continuationOptions: TaskContinuationOptions) : Task
Task.ContinueWith(continuationAction: Action<Task>, scheduler: TaskScheduler) : Task
Task.ContinueWith(continuationAction: Action<Task>, cancellationToken: CancellationToken) : Task
Task.ContinueWith<'TResult>(continuationFunction: Func<Task,'TResult>, cancellationToken: CancellationToken, continuationOptions: TaskContinuationOptions, scheduler: TaskScheduler) : Task<'TResult>
Task.ContinueWith(continuationAction: Action<Task>, cancellationToken: CancellationToken, continuationOptions: TaskContinuationOptions, scheduler: TaskScheduler) : Task
val test : taskAwaiter:(Task<unit> -> Async<unit>) -> unit []

Full name: Script.test
val taskAwaiter : (Task<unit> -> Async<unit>)
val tcs : TaskCompletionSource<unit>
Multiple items
type TaskCompletionSource<'TResult> =
  new : unit -> TaskCompletionSource<'TResult> + 3 overloads
  member SetCanceled : unit -> unit
  member SetException : exception:Exception -> unit + 1 overload
  member SetResult : result:'TResult -> unit
  member Task : Task<'TResult>
  member TrySetCanceled : unit -> bool
  member TrySetException : exception:Exception -> bool + 1 overload
  member TrySetResult : result:'TResult -> bool

Full name: System.Threading.Tasks.TaskCompletionSource<_>

--------------------
TaskCompletionSource() : unit
TaskCompletionSource(creationOptions: TaskCreationOptions) : unit
TaskCompletionSource(state: obj) : unit
TaskCompletionSource(state: obj, creationOptions: TaskCreationOptions) : unit
val worker : (int -> Async<unit>)
val i : int
property TaskCompletionSource.Task: Task<unit>
val failwith : message:string -> 'T

Full name: Microsoft.FSharp.Core.Operators.failwith
module Seq

from Microsoft.FSharp.Collections
val init : count:int -> initializer:(int -> 'T) -> seq<'T>

Full name: Microsoft.FSharp.Collections.Seq.init
static member Async.Parallel : computations:seq<Async<'T>> -> Async<'T []>
static member Async.RunSynchronously : computation:Async<'T> * ?timeout:int * ?cancellationToken:CancellationToken -> 'T
static member Async.AwaitTask : task:Task -> Async<unit>
static member Async.AwaitTask : task:Task<'T> -> Async<'T>
static member Async.AwaitTask2 : task:Task -> Async<unit>
static member Async.AwaitTask2 : task:Task<'T> -> Async<'T>

More information

Link:http://fssnip.net/7Rw
Posted:7 years ago
Author:Eirik Tsarpalis
Tags: async