// Atom.fs


//        MIT License
//
//        Copyright (c) 2020 Stefan von Stein
//        Permission is hereby granted, free of charge, to any person obtaining a copy
//        of this software and associated documentation files (the "Software"), to deal
//        in the Software without restriction, including without limitation the rights
//        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
//        copies of the Software, and to permit persons to whom the Software is
//        furnished to do so, subject to the following conditions:
//
//        The above copyright notice and this permission notice shall be included in all
//        copies or substantial portions of the Software.
//
//        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
//        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
//        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
//        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
//        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
//        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
//        SOFTWARE.

module Atom 
open System
open System.Threading
      
/// Immutable reference-type box around a single value.
/// Because a Cell is a class, two cells can be told apart by reference,
/// which is what the atomic exchange operations in this module rely on,
/// whatever 'T happens to be.
type Cell<'T>(v: 'T) =
    /// The value captured when the cell was constructed.
    member this.Value : 'T = v

/// A mutable reference to an immutable value.
/// The current value lives in a Cell; updating the atom means replacing
/// the cell reference held in `cell`.
type Atom<'T> =
    /// The cell holding the current value. Mutable so it can be replaced in place.
    val mutable cell: Cell<'T>

    /// Optional change callback. The update functions of this module call it
    /// with the previous value first and the new value second.
    val listener: ('T -> 'T -> unit) option

    /// Creates an atom holding `initial`, without a listener.
    new (initial: 'T) =
        {
            cell = Cell<'T>(initial)
            listener = None
        }

    /// Creates an atom holding `initial` and registers `callback` as its listener.
    new (initial: 'T, callback: 'T -> 'T -> unit) =
        {
            cell = Cell<'T>(initial)
            listener = Some callback
        }

/// Creates an atom whose starting value is `initial`.
/// The resulting atom has no listener attached.
let atom (initial: 'T) : Atom<'T> =
    new Atom<'T>(initial)

/// Creates an atom whose starting value is `initial` and attaches `listener` to it.
/// The listener receives the old value followed by the new value.
let atomWithListener (initial: 'T) (listener : 'T -> 'T -> unit) =
    new Atom<'T>(initial, listener)

/// Returns the value currently held by the atom.
/// The cell reference is fetched with a volatile read before it is unwrapped.
let deref (a: Atom<'T>) : 'T =
    let current = Volatile.Read(&a.cell)
    current.Value

/// Atomically replaces the atom's value with `f` applied to the current value
/// and returns the value that ended up being stored.
///
/// `f` may run several times: whenever another thread changes the atom between
/// the read and the compare-and-exchange, the whole attempt is repeated.
/// If `f` yields a value equal to the one it was given, nothing is written,
/// the listener is not called, and the unchanged value is returned.
/// After a successful exchange the listener (if any) is called on the calling
/// thread with the old and the new value.
let swap (a: Atom<'T>) (f: 'T -> 'T) : 'T =
    let rec attempt () =
        let snapshot = Volatile.Read(&a.cell)
        let before = snapshot.Value
        let after = f before
        if after = before then
            // No change requested: skip the write and the notification.
            before
        else
            let witnessed = Interlocked.CompareExchange(&a.cell, Cell after, snapshot)
            if witnessed = snapshot then
                // We won the race; tell the listener about the transition.
                a.listener |> Option.iter (fun notify -> notify before after)
                after
            else
                // Somebody else replaced the cell first; start over.
                attempt ()
    attempt ()

/// Unconditionally stores `value` in the atom, ignoring what was there before.
/// The listener (if any) is always called with the replaced value and `value`,
/// including when the two are equal.
let reset (a: Atom<'T>) (value: 'T) : unit =
    let replaced = Interlocked.Exchange(&a.cell, Cell value)
    a.listener |> Option.iter (fun notify -> notify replaced.Value value)

/// A classic compare-and-swap: stores `next` only if the atom currently holds a
/// value equal to `expected`. Returns true when `next` was stored, false otherwise.
///
/// Values are compared with `<>` (structural equality), while the underlying
/// exchange compares cell references. A concurrent writer can therefore replace
/// the cell with another one holding an equal value, which makes the reference
/// exchange fail although the atom still matches `expected`. In that case the
/// operation is retried, so false is only returned after a value different from
/// `expected` has actually been observed.
///
/// On success the listener (if any) is called once with the old and new value.
let compareAndSwap (a: Atom<'T>) (expected: 'T) (next: 'T) : bool =
    // The replacement cell is not visible to anyone until the exchange succeeds,
    // so it is safe to reuse it across retries.
    let newCell = Cell next
    let rec attempt () =
        let oldCell = Volatile.Read(&a.cell)
        let oldVal = oldCell.Value
        if oldVal <> expected then
            false
        elif obj.ReferenceEquals(Interlocked.CompareExchange(&a.cell, newCell, oldCell), oldCell) then
            match a.listener with
            | Some cb -> cb oldVal next
            | None -> ()
            true
        else
            // The cell changed between the read and the exchange; look again
            // instead of reporting a failure that may be spurious.
            attempt ()
    attempt ()

///
///  Provocative testcase
///

/// State record used by the stress test in `main`: a single integer counter.
type State = { Value: int }

// Stress test: start a lot of threads, half of them incrementing a shared value
// while the other half decrement it the same number of times, so the final
// value must be back at 0. Afterwards a small listener / compareAndSwap demo runs.
[<EntryPoint>]
let main argv =
    let totalThreads = 40
    let perThread = 100000
    // Check the configuration before anything is allocated.
    if (totalThreads < 2) || (totalThreads % 2 = 1) then failwith "Has to be even number of threads"

    // Named `counter` so the module-level `atom` function is not shadowed.
    let counter = atom ({ Value = 0 }: State)

    // Rendezvous: all threads wait here before starting, so that every worker is
    // up and running before the contention begins. +1 is for the main thread.
    // Barrier is IDisposable; `use` releases it when main returns, which is
    // after every worker has been joined.
    use barrier = new Barrier(totalThreads + 1)

    let makeWorker delta =
        Thread(fun () ->
            // Signal the rendezvous that this worker is ready
            barrier.SignalAndWait()
            for i in 1 .. perThread do
                swap counter (fun v ->
                    // Yield control to other threads frequently, to provoke retries
                    if i % 10 = 0 then Thread.Yield() |> ignore
                    // Update the value
                    { v with Value = v.Value + delta }) |> ignore)

    let half = totalThreads / 2
    let upThreads = [ for _ in 1 .. half -> makeWorker 1 ]
    let downThreads = [ for _ in 1 .. half -> makeWorker (-1) ]
    let allThreads = upThreads @ downThreads

    for t in allThreads do t.Start()

    // Main thread signals and lets them go
    barrier.SignalAndWait()

    // Wait for all threads to finish
    for t in allThreads do t.Join()

    printfn "Final value should be 0: %A" (deref counter).Value

    // Listener demo: the callback prints every transition of the atom.
    let ia = atomWithListener 0 (fun old n -> printfn $"%i{old} -> %i{n}")
    let swapped = compareAndSwap ia 0 3
    // Report both the CAS result and the stored value; the old message labelled
    // the stored value as the (boolean) CAS result.
    printfn "compareAndSwap should succeed: %b" swapped
    printfn "value should be 3: %i" (deref ia)
    0