algos/src/heap.rs

128 lines
3.0 KiB
Rust

/// Binary min-heap stored implicitly in a vector.
///
/// The element at position `i` has its children at positions `2 * i + 1` and
/// `2 * i + 2`, and its parent at `(i + 1) / 2 - 1`; the smallest element is
/// kept at position 0.
#[derive(Debug)]
struct MinHeap<T> {
// Heap elements laid out level by level, root first.
nodes: Vec<T>,
}
/// Returns the position of the parent of the node stored at `index`.
///
/// `index` must be greater than zero: the root has no parent, and asking for
/// one makes the subtraction below underflow.
fn parent_index(index: usize) -> usize {
    // In one-based numbering the parent of node `k` is simply `k / 2`, so
    // shift to one-based, halve, and shift back.
    let one_based = index + 1;
    let parent_one_based = one_based / 2;
    parent_one_based - 1
}
/// Returns the positions `(left, right)` of the two children of the node
/// stored at `index`.
///
/// The positions are computed arithmetically and are not checked against the
/// heap's length; callers must do that themselves.
fn child_indices(index: usize) -> (usize, usize) {
    // In one-based numbering the children of node `k` are `2k` and `2k + 1`.
    let one_based = index + 1;
    let left = one_based * 2 - 1;
    let right = left + 1;
    (left, right)
}
impl<T: std::cmp::PartialOrd> MinHeap<T> {
fn new() -> Self {
Self { nodes: Vec::with_capacity(128) }
}
fn smaller_than_parent(&self, index: usize) -> bool {
if index == 0 {
false
} else {
self.nodes[index] < self.nodes[parent_index(index)]
}
}
fn bubble_up(&mut self, index: usize) -> usize {
let parent_index = parent_index(index);
self.nodes.swap(index, parent_index);
parent_index
}
fn insert(&mut self, e: T) -> () {
self.nodes.push(e);
let last = self.nodes.len() - 1;
let mut bubble_index = last;
while self.smaller_than_parent(bubble_index) {
bubble_index = self.bubble_up(bubble_index);
}
}
fn greater_than_children(&self, index: usize) -> bool {
let length = self.nodes.len();
let (left_index, right_index) = child_indices(index);
if left_index < length && self.nodes[index] > self.nodes[left_index] {
true
} else if right_index < length && self.nodes[index] > self.nodes[right_index] {
true
} else {
false
}
}
fn bubble_down(&mut self, index: usize) -> usize {
let length = self.nodes.len();
let (left_index, right_index) = child_indices(index);
let mut target_index = index;
if left_index < length && self.nodes[target_index] > self.nodes[left_index] {
target_index = left_index;
}
if right_index < length && self.nodes[right_index] < self.nodes[target_index] {
target_index = right_index;
}
self.nodes.swap(index, target_index);
target_index
}
fn extract_min(&mut self) -> T {
if self.nodes.len() == 0 {
panic!("Cannot extract key from empty heap");
}
let mut bubble_index = 0;
let last = self.nodes.len() - 1;
self.nodes.swap(bubble_index, last);
let result = self.nodes.pop().unwrap();
while self.greater_than_children(bubble_index) {
bubble_index = self.bubble_down(bubble_index);
}
result
}
fn size(&self) -> usize {
self.nodes.len()
}
}
/// Returns the sum of the running medians of `v`, reduced modulo 10000.
///
/// For every prefix `v[..k]` the median is taken to be the `ceil(k / 2)`-th
/// smallest element (the lower of the two middle values when `k` is even).
/// The medians of all prefixes are added up, applying `% 10000` after every
/// step; `%` is Rust's remainder, so the result can be negative when the
/// input contains negative values.
///
/// Returns 0 for an empty slice.
pub fn heap(v: &[i64]) -> i64 {
    const MODULUS: i64 = 10000;

    let mut iter = v.iter().copied();
    let mut current_median = match iter.next() {
        Some(first) => first,
        // No prefixes, hence no medians to add up.
        None => return 0,
    };

    // `hl` holds the lower half of the values seen so far. Values are stored
    // negated so that the min-heap behaves as a max-heap and its "minimum" is
    // the largest value of the lower half, i.e. the median.
    // `hh` holds the upper half as an ordinary min-heap.
    let mut hl: MinHeap<i64> = MinHeap::new();
    let mut hh: MinHeap<i64> = MinHeap::new();
    hl.insert(-current_median);
    // Reduce the first median as well, so that a one-element input is treated
    // the same way as every longer one.
    let mut median_sum = current_median % MODULUS;

    for e in iter {
        if e < current_median {
            hl.insert(-e);
        } else {
            hh.insert(e);
        }

        // Rebalance so that `hl` has either as many elements as `hh` or
        // exactly one more.
        while hl.size() < hh.size() {
            let moved = hh.extract_min();
            hl.insert(-moved);
        }
        while hl.size() > hh.size() + 1 {
            let moved = -hl.extract_min();
            hh.insert(moved);
        }

        // The heap has no way to look at its root without removing it, so take
        // the median out and put it straight back.
        current_median = -hl.extract_min();
        hl.insert(-current_median);

        median_sum = (median_sum + current_median) % MODULUS;
    }
    median_sum
}