MergeSort.scala [raw]


/* Copyright 2009-2016 EPFL, Lausanne */

import stainless.annotation._
import stainless.lang._

object MergeSort {
  // Data types
  sealed abstract class List
  case class Cons(head : Int, tail : List) extends List
  case class Nil() extends List

  sealed abstract class LList
  case class LCons(head : List, tail : LList) extends LList
  case class LNil() extends LList

  def content(list : List) : Set[Int] = list match {
    case Nil() => Set.empty[Int]
    case Cons(x, xs) => Set(x) ++ content(xs)
  }

  def lContent(llist : LList) : Set[Int] = llist match {
    case LNil() => Set.empty[Int]
    case LCons(x, xs) => content(x) ++ lContent(xs)
  }

  def size(list : List) : BigInt = (list match {
    case Nil() => BigInt(0)
    case Cons(_, xs) => 1 + size(xs)
  }) ensuring(_ >= 0)

  def isSorted(list : List) : Boolean = list match {
    case Nil() => true
    case Cons(_, Nil()) => true
    case Cons(x1, Cons(x2, _)) if(x1 > x2) => false
    case Cons(_, xs) => isSorted(xs)
  }

  def lIsSorted(llist : LList) : Boolean = llist match {
    case LNil() => true
    case LCons(x, xs) => isSorted(x) && lIsSorted(xs)
  }

  def abs(i : BigInt) : BigInt = {
    if(i < 0) -i else i
  } ensuring(_ >= 0)

  def mergeSpec(list1 : List, list2 : List, res : List) : Boolean = {
    isSorted(res) && content(res) == content(list1) ++ content(list2)
  }

  def mergeFast(list1 : List, list2 : List) : List = {
    require(isSorted(list1) && isSorted(list2))

    (list1, list2) match {
      case (_, Nil()) => list1
      case (Nil(), _) => list2
      case (Cons(x, xs), Cons(y, ys)) =>
        if(x <= y) {
          Cons(x, mergeFast(xs, list2)) 
        } else {
          Cons(y, mergeFast(list1, ys))
        }
    }
  } ensuring(res => mergeSpec(list1, list2, res))

  def splitSpec(list : List, res : (List,List)) : Boolean = {
    val s1 = size(res._1)
    val s2 = size(res._2)
    abs(s1 - s2) <= 1 && s1 + s2 == size(list) &&
    content(res._1) ++ content(res._2) == content(list) 
  }

  def split(list : List) : (List,List) = (list match {
    case Nil() => (Nil(), Nil())
    case Cons(x, Nil()) => (Cons(x, Nil()), Nil())
    case Cons(x1, Cons(x2, xs)) =>
      val (s1,s2) = split(xs)
      (Cons(x1, s1), Cons(x2, s2))
  }) ensuring(res => splitSpec(list, res))

  def sortSpec(in : List, out : List) : Boolean = {
    content(out) == content(in) && isSorted(out)
  }

  // Not really quicksort, neither mergesort.
  // Note: the `s` argument is just a witness for termination (always decreases),
  // and not needed for functionality. Any decent optimizer will remove it ;-)
  def weirdSort(s: BigInt, in : List) : List = {
    require(s == size(in))
    in match {
      case Nil() => Nil()
      case Cons(x, Nil()) => Cons(x, Nil())
      case _ =>
        val (s1,s2) = split(in)
        mergeFast(weirdSort(size(s1), s1), weirdSort(size(s2), s2))
    }
  } ensuring(res => sortSpec(in, res))

  def toLList(list : List) : LList = (list match {
    case Nil() => LNil()
    case Cons(x, xs) => LCons(Cons(x, Nil()), toLList(xs))
  }) ensuring(res => lContent(res) == content(list) && lIsSorted(res))

  def mergeMap(llist : LList) : LList = {
    require(lIsSorted(llist))

    llist match {
      case LNil() => LNil()
      case o @ LCons(x, LNil()) => o
      case LCons(x, LCons(y, ys)) => LCons(mergeFast(x, y), mergeMap(ys))
    }
  } ensuring(res => lContent(res) == lContent(llist) && lIsSorted(res))

  def mergeReduce(llist : LList) : List = {
    require(lIsSorted(llist))

    llist match {
      case LNil() => Nil()
      case LCons(x, LNil()) => x
      case _ => mergeReduce(mergeMap(llist))
    }
  } ensuring(res => content(res) == lContent(llist) && isSorted(res))

  def mergeSort(in : List) : List = {
    mergeReduce(toLList(in))
  } ensuring(res => sortSpec(in, res))
}

back