Heaps.scala [raw]
/* Copyright 2009-2016 EPFL, Lausanne */
import stainless.annotation._
import stainless.lang._
object Heaps {
  /* ==================================================================== */
  /* Representation                                                       */
  /* ==================================================================== */

  /** One tree of the heap: its `rank`, the element `elem` held at its root,
    * and its child trees `nodes` (themselves stored as a [[Heap]]).
    * `link` below always puts the smaller of two root elements on top.
    */
  case class Node(rank : BigInt, elem : Int, nodes : Heap)

  /** A heap is a list of [[Node]] trees: either `Empty`, or a `Nodes` cell
    * holding a leading tree (`head`) and the remaining trees (`tail`).
    * `Nodes` and `Empty` are private, so code outside this object cannot
    * construct them directly.
    */
  sealed abstract class Heap
  private case class Nodes(head : Node, tail : Heap) extends Heap
  private case object Empty extends Heap

  /** Result type of [[findMin]]: `Some(value)` or `None`.
    * Inside this object these names shadow any imported `Some`/`None`.
    */
  sealed abstract class OptInt
  case class Some(value : Int) extends OptInt
  case object None extends OptInt

  /* ==================================================================== */
  /* Abstraction functions                                                */
  /* The contracts below reason only about these element *sets*, so       */
  /* duplicate elements are not distinguished.                            */
  /* ==================================================================== */

  /** Every element stored in any tree of `h`. */
  def heapContent(h : Heap) : Set[Int] = h match {
    case Nodes(first, rest) => nodeContent(first) ++ heapContent(rest)
    case Empty              => Set.empty[Int]
  }

  /** The root element of `n` together with everything in its child trees. */
  def nodeContent(n : Node) : Set[Int] = Set(n.elem) ++ heapContent(n.nodes)

  /* ==================================================================== */
  /* Internal helpers                                                     */
  /* ==================================================================== */

  /** The trees of `h` in the opposite order. */
  private def reverse(h : Heap) : Heap = reverseOnto(h, Empty)

  /** Moves the trees of `h`, one by one, onto the front of `acc`
    * (so they end up reversed in front of `acc`'s trees).
    */
  private def reverseOnto(h : Heap, acc : Heap) : Heap = {
    h match {
      case Nodes(first, rest) => reverseOnto(rest, Nodes(first, acc))
      case Empty              => acc
    }
  } ensuring(res => heapContent(res) == heapContent(h) ++ heapContent(acc))

  /** Joins two trees into one of rank `t1.rank + 1`. The tree whose root
    * holds the smaller element becomes the new root (ties go to `t1`), and
    * the other tree is pushed onto the front of its children.
    */
  private def link(t1 : Node, t2 : Node) : Node = {
    val newRank = t1.rank + 1
    if (t2.elem < t1.elem) Node(newRank, t2.elem, Nodes(t1, t2.nodes))
    else                   Node(newRank, t1.elem, Nodes(t2, t1.nodes))
  }

  /** Adds tree `t` to `h`. While the leading tree of `h` has a rank no
    * larger than `t`'s, the two are linked and insertion continues with
    * the combined tree.
    */
  private def insertNode(t : Node, h : Heap) : Heap = {
    h match {
      case Nodes(t2, h2) if t2.rank <= t.rank => insertNode(link(t, t2), h2)
      // `h` is empty, or its leading tree has a strictly larger rank.
      case _                                  => Nodes(t, h)
    }
  } ensuring(res => heapContent(res) == nodeContent(t) ++ heapContent(h))

  /** Splits a non-empty heap into the tree with the smallest root element
    * and the list of all other trees. When two roots are equal, the tree
    * found later in the list is the one returned.
    */
  private def getMin(h : Heap) : (Node, Heap) = {
    require(h != Empty)
    h match {
      case Nodes(t, Empty) => (t, Empty)
      case Nodes(t, ts) =>
        getMin(ts) match {
          case (best, others) =>
            if (best.elem <= t.elem) (best, Nodes(t, others))
            else (t, ts)
        }
    }
  } ensuring(res => nodeContent(res._1) ++ heapContent(res._2) == heapContent(h))

  /* ==================================================================== */
  /* Public heap interface                                                */
  /* ==================================================================== */

  /** A heap with no elements. */
  def empty() : Heap = {
    Empty
  } ensuring { res => heapContent(res) == Set.empty[Int] }

  /** True exactly when `h` holds no trees (and hence no elements). */
  def isEmpty(h : Heap) : Boolean = {
    h match {
      case Empty        => true
      case Nodes(_, _)  => false
    }
  } ensuring(res => res == (heapContent(h) == Set.empty[Int]))

  /** `h` with `e` added, as a new rank-0 tree. */
  def insert(e : Int, h : Heap) : Heap = {
    val singleton = Node(0, e, Empty)
    insertNode(singleton, h)
  } ensuring(res => heapContent(res) == heapContent(h) ++ Set(e))

  /** Combines two heaps. Trees are emitted in increasing rank order, and
    * trees of equal rank are linked and re-inserted.
    */
  def merge(h1 : Heap, h2 : Heap) : Heap = {
    (h1, h2) match {
      case (_, Empty) => h1
      case (Empty, _) => h2
      case (Nodes(t1, ts1), Nodes(t2, ts2)) =>
        if (t1.rank == t2.rank)     insertNode(link(t1, t2), merge(ts1, ts2))
        else if (t1.rank < t2.rank) Nodes(t1, merge(ts1, h2))
        else                        Nodes(t2, merge(h1, ts2))
    }
  } ensuring(res => heapContent(res) == heapContent(h1) ++ heapContent(h2))

  /** The smallest root element over all trees of `h`, or `None` if `h` is
    * empty. Only roots are inspected; this relies on `link` keeping the
    * smaller element at each root.
    */
  def findMin(h : Heap) : OptInt = {
    h match {
      case Empty => None
      case Nodes(first, rest) =>
        val e = first.elem
        findMin(rest) match {
          case Some(e2) if e2 <= e => Some(e2)
          case _                   => Some(e)
        }
    }
  } ensuring(res => res match {
    case None    => isEmpty(h)
    case Some(v) => heapContent(h).contains(v)
  })

  /** Removes the tree with the smallest root and merges its children back
    * in. The contract only promises that no new elements appear.
    */
  def deleteMin(h : Heap) : Heap = {
    h match {
      case Empty => Empty
      case nonEmpty : Nodes =>
        getMin(nonEmpty) match {
          case (minTree, others) => merge(reverse(minTree.nodes), others)
        }
    }
  } ensuring(res => heapContent(res).subsetOf(heapContent(h)))

  /* ==================================================================== */
  /* Sanity checks (proved by the verifier via `.holds`)                  */
  /* ==================================================================== */

  def sanity0() : Boolean = {
    val start : Heap = Empty
    val one   = insert(42, start)
    val two   = insert(72, one)
    val three = insert(0, two)
    val minimaAsExpected =
      findMin(start) == None &&
      findMin(one)   == Some(42) &&
      findMin(two)   == Some(42) &&
      findMin(three) == Some(0)
    minimaAsExpected
  }.holds

  def sanity1() : Boolean = {
    val merged = merge(insert(42, Empty), insert(0, Empty))
    findMin(merged) == Some(0)
  }.holds

  def sanity3() : Boolean = {
    val threeElems = insert(42, insert(0, insert(12, Empty)))
    findMin(deleteMin(threeElems)) == Some(12)
  }.holds
}
back