import argparse
import itertools
import sys
import json
from collections import defaultdict
from operator import attrgetter
import msgpack
import numpy as np
from homcloud.version import __version__
from homcloud.diagram import PD
import homcloud.utils as utils
from homcloud.index_map import IndexMap
from homcloud.visualize_3d import boundary_of_connected_simplices, ParaViewSimplexDrawer
from homcloud.compat import INFINITY
def main(args=None):
    """Command-line entry point: build PH trees from a diagram and save them.

    ``args`` is an argparse namespace; when it is falsy the command line is
    parsed with ``argument_parser()``.  The trees are written to
    ``args.output`` in msgpack form and, if ``args.dump_json`` is given, a
    JSON dump of the queryable trees is written there as well.
    """
    if not args:
        args = argument_parser().parse_args()
    degree = args.degree

    diagram = PD.load_from_indexed_diphafile(args.input, degree)
    check_degree(degree, diagram)

    trees = PHTrees(diagram, degree)
    trees.construct_tree()
    with open(args.output, "wb") as out:
        write_phtree(out, trees, diagram)

    if args.dump_json is None:
        return
    query_trees = PHTreesForQuery.from_dict(trees.to_dict(), diagram.index_map)
    with open(args.dump_json, "w") as out:
        json.dump(query_trees.to_jsondict(), out)
def check_degree(degree, diagram):
    """Terminate the program unless ``degree`` is the index map's dimension - 1.

    PH trees are only built for that degree.  On mismatch an error message
    is written to stderr and the process exits with status 1.

    Raises:
        SystemExit: with code 1 when the degree does not match.
    """
    if degree != diagram.index_map.dimension - 1:
        sys.stderr.write("degree should equal dimension - 1\n")
        # sys.exit, not the site-module ``exit`` helper, which is missing
        # under ``python -S`` and in frozen executables.
        sys.exit(1)
def write_phtree(output, phtree, diagram):
    """Serialize ``phtree`` and the diagram's index map to ``output`` as msgpack.

    The payload is a map with keys ``"trees"`` (from ``phtree.to_dict()``)
    and ``"index-map"`` (from ``diagram.index_map.to_dict()``).
    """
    payload = {}
    payload["trees"] = phtree.to_dict()
    payload["index-map"] = diagram.index_map.to_dict()
    msgpack.pack(payload, output, use_bin_type=True)
def argument_parser():
    """Return the argparse parser for the full-PH-tree command-line tool."""
    parser = argparse.ArgumentParser(description="compute a full PH tree")
    # (flags, keyword arguments) in the order they should appear in --help.
    option_specs = [
        (("-V", "--version"), {"action": "version", "version": __version__}),
        (("-d", "--degree"), {"type": int, "required": True, "help": "degree of PH"}),
        (("-j", "--dump-json"), {"help": "json dump output"}),
        (("input",), {"help": "input file"}),
        (("output",), {"help": "output file"}),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser
class PHTrees(object):
    """Build PH trees from the filtration stored in a diagram's index map.

    Every (degree + 1)-simplex becomes a ``Node``.  The filtration is then
    scanned from the last simplex to the first; each degree-simplex joins
    the nodes on both of its sides (one node plus ``outer_node`` when it
    bounds only a single (degree + 1)-simplex).  Roots are found with a
    union-find style ``root()`` that compresses ``shortcut`` links.
    """

    def __init__(self, diagram, degree):
        self.diagram = diagram
        # A list rather than a one-shot ``map`` iterator, so the simplices
        # remain available after ``construct_tree`` has run.
        self.simplices = [frozenset(simplex) for simplex in diagram.index_map.simplices]
        self.degree = degree
        self.nodes = []
        # degree-simplex -> nodes that have it on their boundary
        self.adjacent_nodes = defaultdict(list)
        self.outer_node = PHTrees.OuterNode()

    def construct_tree(self):
        """Build the trees, filling ``self.nodes`` and their parent links.

        Filtration indices act as times: a node's ``death_time`` is the index
        of its own simplex, and its ``birth_time`` is the index of the
        degree-simplex that merged it into another tree.
        """
        for (time, simplex) in reversed(list(enumerate(self.simplices))):
            dim_simplex = len(simplex) - 1
            if dim_simplex == self.degree + 1:
                node = PHTrees.Node(simplex, time)
                self.nodes.append(node)
                for boundary_simplex in node.boundary():
                    self.adjacent_nodes[boundary_simplex].append(node)
            elif dim_simplex == self.degree:
                n1, n2 = self.get_adjacent_nodes(simplex)
                r1 = n1.root()
                r2 = n2.root()
                if r1 is r2:
                    continue
                # The root with the larger death time absorbs the other one.
                if r1.death_time > r2.death_time:
                    self.merge(r1, r2, time)
                else:
                    self.merge(r2, r1, time)

    def get_adjacent_nodes(self, simplex):
        """Return the two nodes separated by ``simplex``.

        A simplex on only one node's boundary is paired with ``outer_node``.

        Raises:
            RuntimeError: if the simplex is on the boundary of zero or more
                than two nodes.
        """
        # .get avoids inserting empty lists into the defaultdict on lookups.
        nodes = self.adjacent_nodes.get(simplex, [])
        if len(nodes) == 2:
            return nodes
        if len(nodes) == 1:
            return [nodes[0], self.outer_node]
        raise RuntimeError("Algorithm Error or Simplex Data error")

    @staticmethod
    def merge(parent, child, time):
        """Attach ``child`` under ``parent``; ``time`` becomes the child's birth."""
        child.parent = child.shortcut = parent
        child.birth_time = time

    def to_dict(self):
        """Return ``{node index: node.to_dict()}`` for every node."""
        return {node.index(): node.to_dict() for node in self.nodes}

    class Node(object):
        """One (degree + 1)-simplex of the filtration, i.e. one tree node."""

        def __init__(self, simplex, death_time):
            self.simplex = simplex
            self.death_time = death_time
            self.birth_time = None  # assigned by PHTrees.merge
            # ``parent`` is the real tree edge; ``shortcut`` is a
            # path-compressed link used only to find roots quickly.
            self.parent = self.shortcut = None

        def boundary(self):
            """Return the boundary simplices of this node's simplex."""
            return utils.boundary_of_simplex(self.simplex)

        def root(self):
            """Return the root of the tree containing this node.

            Iterative (not recursive) so that long merge chains cannot exceed
            the interpreter's recursion limit.  Every node visited on the way
            gets its ``shortcut`` redirected straight to the root.
            """
            path = []
            node = self
            while node.parent is not None:
                path.append(node)
                node = node.shortcut
            for visited in path:
                visited.shortcut = node
            return node

        @staticmethod
        def outer():
            return False

        def index(self):
            return self.death_time

        def to_dict(self):
            """Serialize as ``{"index", "birth-index", "parent"}``.

            ``"parent"`` is None for an unmerged node or a child of the
            outer node.
            """
            return {
                "index": self.index(), "birth-index": self.birth_time,
                "parent": None if self.parent is None else self.parent.index(),
            }

    class OuterNode(object):
        """Sentinel for the region outside every (degree + 1)-simplex."""

        # Makes ``Node.root`` stop its walk when it reaches the outer node.
        parent = None

        def __init__(self):
            self.death_time = INFINITY

        def root(self):
            return self

        @staticmethod
        def outer():
            return True

        @staticmethod
        def index():
            return None
class PHTreesForQuery(object):
    """Queryable PH trees loaded from the dict produced by ``PHTrees.to_dict``.

    ``nodes`` maps a node index to a node object.  Parent/child links are
    wired up here from each node's ``parent_index``.
    """

    def __init__(self, nodes, index_map):
        self.nodes = nodes
        self.index_map = index_map
        for node in nodes.values():
            if node.parent_index is not None:
                node.parent = self.nodes[node.parent_index]
                node.parent.children.append(node)

    class Node(object):
        """A tree node that resolves its geometry through an index map."""

        def __init__(self, index, birth_index, parent_index, index_map):
            self.index = index  # also the death index (see death_time)
            self.birth_index = birth_index
            self.parent_index = parent_index
            self.index_map = index_map
            self.parent = None
            self.children = []
            self.volume_cache = None
            self.geom_resolver = index_map.geometry_resolver(None)

        @classmethod
        def from_dict(cls, dic, index_map):
            return cls(dic["index"], dic["birth-index"], dic["parent"],
                       index_map)

        def birth_time(self):
            return self.index_map.levels[self.birth_index]

        def death_time(self):
            return self.index_map.levels[self.index]

        def birth_simplex(self):
            return self.index_map.simplices[self.birth_index]

        def death_simplex(self):
            return self.index_map.simplices[self.index]

        def volume(self):
            """Return this node followed by all its descendants (DFS pre-order).

            The list always contains self and is cached.  An explicit stack
            is used so that deep trees cannot hit the recursion limit.
            """
            if self.volume_cache is None:
                result = []
                stack = [self]
                while stack:
                    node = stack.pop()
                    result.append(node)
                    # Reversed so the first child is visited first.
                    stack.extend(reversed(node.children))
                self.volume_cache = result
            return self.volume_cache

        def volume_simplices(self):
            return [frozenset(self.index_map.simplices[node.index])
                    for node in self.volume()]

        def birth_death_pair(self):
            return (self.birth_time(), self.death_time())

        def living(self):
            return self.birth_time() < self.death_time()

        def count_living_descendants(self):
            return len(self.living_descendants())

        def living_descendants(self):
            """Return every node in ``volume()`` whose birth < death.

            The list contains self if self is living.
            """
            return [node for node in self.volume() if node.living()]

        def depth_to(self, root):
            """Return the number of parent steps from this node up to ``root``.

            Raises:
                ValueError: if ``root`` is not an ancestor of this node.
            """
            depth = 0
            node = self
            while node != root:
                if node.parent is None:
                    raise ValueError("root is not an ancestor of this node")
                node = node.parent
                depth += 1
            return depth

        def draw_descendants_volumes(self, points, path, draw_birth, draw_death):
            living_descendants = self.living_descendants()
            drawer = ParaViewSimplexDrawer(
                len(living_descendants), points,
                {"depth": None, "birth": None, "death": None}
            )
            for (node, color) in zip(living_descendants, drawer.various_colors()):
                node.draw_volume(drawer, color, {
                    "depth": node.depth_to(self),
                    "birth": node.birth_time(),
                    "death": node.death_time()
                }, draw_birth, draw_death)
            drawer.output(path)

        def draw_volume(self, drawer, color, values, draw_birth=False, draw_death=False):
            for b in self.boundary_simplices():
                drawer.draw_simplex(b, color, **values)
            if draw_birth:
                drawer.draw_simplex(self.birth_simplex(), drawer.birth_color(), **values)
            if draw_death:
                drawer.draw_simplex(self.death_simplex(), drawer.death_color(), **values)

        def points_in_volume(self):
            return utils.deep_tolist(
                self.geom_resolver.unique_vertices_coords(self.volume_indices())
            )

        points = points_in_volume

        def points_symbols(self):
            return self.geom_resolver.unique_vertices_symbols(self.volume_indices())

        def volume_indices(self):
            """Return the indices of the nodes in ``volume()`` as a list.

            A list (not a one-shot ``map``) so callees may iterate it twice.
            """
            return list(map(attrgetter("index"), self.volume()))

        def simplices_in_volume(self):
            return utils.deep_tolist(
                self.geom_resolver.cells_coords(self.volume_indices())
            )

        simplices = simplices_in_volume

        def volume_simplices_symbols(self):
            return self.geom_resolver.cells_symbols(self.volume_indices())

        def birth_position(self):
            return self.resolve_simplex(self.birth_simplex())

        birth_pos = birth_position

        def death_position(self):
            return self.resolve_simplex(self.death_simplex())

        death_pos = death_position

        def resolve_simplex(self, simplex):
            return utils.deep_tolist(
                self.geom_resolver.vertex_indices_to_coords(simplex)
            )

        def boundary(self):
            return [self.resolve_simplex(s) for s in self.boundary_simplices()]

        def boundary_symbols(self):
            return [self.geom_resolver.vertices_to_symbols(s)
                    for s in self.boundary_simplices()]

        def boundary_simplices(self):
            return boundary_of_connected_simplices(self.volume_simplices())

        def boundary_points(self):
            return self.resolve_simplex(self.boundary_points_indices())

        def boundary_points_symbols(self):
            return self.geom_resolver.vertices_to_symbols(self.boundary_points_indices())

        def boundary_points_indices(self):
            return set(itertools.chain.from_iterable(self.boundary_simplices()))

        def descendant_pairs_dicts(self):
            return [
                node.birth_death_time_dict() for node in self.living_descendants()
            ]

        def ancestor_pairs_dicts(self):
            return [node.birth_death_time_dict() for node in self.ancestors()]

        def ancestors(self):
            """Return [self, parent, grandparent, ...] up to the root."""
            ret = []
            node = self
            while node is not None:
                ret.append(node)
                node = node.parent
            return ret

        def birth_death_time_dict(self):
            return {"birth-time": self.birth_time(), "death-time": self.death_time()}

        def children_dicts(self):
            return [child.to_jsondict(True, True, True)
                    for child in self.children if child.living()]

        def lifetime(self):
            return self.death_time() - self.birth_time()

        def to_jsondict(self, descendant_pairs=False, ancestor_pairs=False, children=False):
            return {
                "id": str(self.index),
                "parent": None if self.parent_index is None else str(self.parent_index),
                "birth-time": self.birth_time(), "death-time": self.death_time(),
                "points": self.points_in_volume(), "simplices": self.simplices_in_volume(),
                "boundary": self.boundary(), "boundary-points": self.boundary_points(),
                "points-symbols": self.points_symbols(),
                "simplices-symbols": self.volume_simplices_symbols(),
                "boundary-symbols": self.boundary_symbols(),
                "boundary-points-symbols": self.boundary_points_symbols(),
                "birth-simplex": self.birth_position(),
                "death-simplex": self.death_position(),
                "ancestors": self.ancestor_pairs_dicts() if ancestor_pairs else None,
                "descendants": self.descendant_pairs_dicts() if descendant_pairs else None,
                "children": self.children_dicts() if children else None,
            }

        def __repr__(self):
            # __qualname__ keeps "PHTreesForQuery.Node" for this class and
            # reports the real name for subclasses.
            return "{}(birth={}, death={})".format(
                type(self).__qualname__, self.birth_time(), self.death_time()
            )

    @classmethod
    def from_dict(cls, dic, index_map, nodeclass=Node):
        """Build trees from ``{index: node dict}``; subclass-aware via ``cls``."""
        nodes = {index: nodeclass.from_dict(nodedict, index_map)
                 for (index, nodedict) in dic.items()}
        return cls(nodes, index_map)

    @classmethod
    def load_from_file(cls, f, nodeclass=Node):
        """Load trees and index map from a msgpack file written by ``write_phtree``.

        Returns:
            (trees, index_map)
        """
        options = {"raw": False}
        # Node dicts are keyed by int; msgpack >= 1.0 rejects non-str map
        # keys unless strict_map_key=False (an option added in 0.6.1).
        if tuple(msgpack.version) >= (0, 6, 1):
            options["strict_map_key"] = False
        data = msgpack.unpack(f, **options)
        index_map = IndexMap.load_from_dict(data["index-map"])
        phtree_q = cls.from_dict(data["trees"], index_map, nodeclass)
        return (phtree_q, index_map)

    def index_pairs(self):
        return [(node.birth_index, node.index) for node in self.nodes.values()]

    def birth_death_indices(self):
        """Return (birth indices, death indices) as int arrays; empty if no nodes."""
        pairs = self.index_pairs()
        if not pairs:
            return np.array([], dtype=int), np.array([], dtype=int)
        births, deaths = zip(*pairs)
        return np.array(births, dtype=int), np.array(deaths, dtype=int)

    def to_jsondict(self):
        living_nodes = [node for node in self.nodes.values() if node.living()]
        return {
            "dim": self.index_map.dimension,
            "num-nodes": len(living_nodes),
            "nodes": {str(node.index): node.to_jsondict() for node in living_nodes}
        }
# TODO: Use KDTree or something
class SpatialSearcher(object):
    """Look up items (pairs or nodes) by their (birth, death) coordinates.

    ``pairs[k]`` is located at ``(births[k], deaths[k])``.  ``births`` and
    ``deaths`` may be NumPy arrays or plain sequences of numbers.
    """

    def __init__(self, pairs, births, deaths):
        self.pairs = pairs
        self.births = births
        self.deaths = deaths

    def nearest_pair(self, birth, death):
        """Return the item closest (Euclidean) to the point (birth, death)."""
        # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact
        # while also accepting plain lists.
        births = np.asanyarray(self.births)
        deaths = np.asanyarray(self.deaths)
        distances = (births - birth)**2 + (deaths - death)**2
        return self.pairs[np.argmin(distances)]

    def in_rectangle(self, xmin, xmax, ymin, ymax):
        """Return items with xmin<=birth<=xmax, ymin<=death<=ymax, birth != death.

        Items are returned in their original order.
        """
        births = np.asanyarray(self.births)
        deaths = np.asanyarray(self.deaths)
        # One vectorized mask instead of a Python-level loop over every item.
        inside = ((xmin <= births) & (births <= xmax) &
                  (ymin <= deaths) & (deaths <= ymax) &
                  (births != deaths))
        return [self.pairs[int(k)] for k in np.nonzero(inside)[0]]

    @staticmethod
    def from_diagram(diagram):
        """Build a searcher whose items are (birth simplex, death simplex) pairs."""
        return SpatialSearcher(
            [(diagram.index_map.simplex(b), diagram.index_map.simplex(d))
             for (b, d) in zip(diagram.masked_birth_indices, diagram.masked_death_indices)],
            diagram.births, diagram.deaths
        )
class TreeResolver(object):
    """Answer (birth, death) queries against loaded PH trees and draw volumes."""

    def __init__(self, phtree, spatial_searcher, index_map):
        self.phtree = phtree
        self.spatial_searcher = spatial_searcher
        self.index_map = index_map

    @staticmethod
    def load(f, nodeclass=PHTreesForQuery.Node):
        """Load trees from the msgpack file ``f`` and index their nodes by time."""
        phtree_q, index_map = PHTreesForQuery.load_from_file(f, nodeclass)
        all_nodes = list(phtree_q.nodes.values())
        birth_times = np.array([node.birth_time() for node in all_nodes])
        death_times = np.array([node.death_time() for node in all_nodes])
        searcher = SpatialSearcher(all_nodes, birth_times, death_times)
        return TreeResolver(phtree_q, searcher, index_map)

    def query_node(self, birth, death):
        """Return the node nearest to the point (birth, death)."""
        searcher = self.spatial_searcher
        return searcher.nearest_pair(birth, death)

    def query_nodes_in_rectangle(self, xmin, xmax, ymin, ymax):
        """Return the nodes whose (birth, death) lies inside the rectangle."""
        searcher = self.spatial_searcher
        return searcher.in_rectangle(xmin, xmax, ymin, ymax)

    def draw_volumes_of_nodes(self, nodes, path, draw_birth, draw_death):
        """Draw the volume of each node in its own color and write to ``path``."""
        drawer = ParaViewSimplexDrawer(len(nodes), self.index_map.points, {})
        palette = drawer.various_colors()
        for node, color in zip(nodes, palette):
            node.draw_volume(drawer, color, {}, draw_birth, draw_death)
        drawer.output(path)
if __name__ == "__main__":
    # Run the command-line tool only when this module is executed directly.
    main()