Skip to content

Commit

Permalink
Merge pull request #79 from marl/hier-eval
Browse files Browse the repository at this point in the history
Hierarchical segment evaluation bindings
  • Loading branch information
bmcfee committed Nov 12, 2015
2 parents 30c46a9 + 12f92e8 commit e98ee48
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 3 deletions.
1 change: 0 additions & 1 deletion .travis_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ if [ ! -d "$src" ]; then
pip install jsonschema
pip install python-coveralls
pip install numpydoc
pip install git+https://github.com/craffel/mir_eval.git@develop

source deactivate
popd
Expand Down
80 changes: 79 additions & 1 deletion jams/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
segment
tempo
pattern
hierarchy
'''

from collections import defaultdict
Expand All @@ -24,7 +25,7 @@

from .exceptions import NamespaceError

__all__ = ['beat', 'chord', 'melody', 'onset', 'segment', 'tempo', 'pattern']
__all__ = ['beat', 'chord', 'melody', 'onset', 'segment', 'hierarchy', 'tempo', 'pattern']


def validate_annotation(ann, namespace):
Expand Down Expand Up @@ -230,6 +231,83 @@ def segment(ref, est, **kwargs):
est_interval, est_value, **kwargs)


def hierarchy_flatten(annotation):
'''Flatten a multi_segment annotation into mir_eval style.
Parameters
----------
annotation : jams.Annotation
An annotation in the `multi_segment` namespace
Returns
-------
hier_intervalss : list
A list of lists of intervals, ordered by increasing specificity.
hier_labels : list
A list of lists of labels, ordered by increasing specificity.
'''

intervals, values = annotation.data.to_interval_values()

ordering = dict()

for interval, value in zip(intervals, values):
level = value['level']
if level not in ordering:
ordering[level] = dict(intervals=list(), labels=list())

ordering[level]['intervals'].append(interval)
ordering[level]['labels'].append(value['label'])

levels = sorted(list(ordering.keys()))
hier_intervals = [ordering[level]['intervals'] for level in levels]
hier_labels = [ordering[level]['labels'] for level in levels]

return hier_intervals, hier_labels


def hierarchy(ref, est, **kwargs):
r'''Multi-level segmentation evaluation
Parameters
----------
ref : jams.Annotation
Reference annotation object
est : jams.Annotation
Estimated annotation object
kwargs
Additional keyword arguments
Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
See Also
--------
mir_eval.hierarchy.evaluate
Examples
--------
>>> # Load in the JAMS objects
>>> ref_jam = jams.load('reference.jams')
>>> est_jam = jams.load('estimated.jams')
>>> # Select the first relevant annotations
>>> ref_ann = ref_jam.search(namespace='multi_segment')[0]
>>> est_ann = est_jam.search(namespace='multi_segment')[0]
>>> scores = jams.eval.hierarchy(ref_ann, est_ann)
'''
namespace = 'multi_segment'
validate_annotation(ref, namespace)
validate_annotation(est, namespace)
ref_hier, _ = hierarchy_flatten(ref)
est_hier, _ = hierarchy_flatten(est)

return mir_eval.hierarchy.evaluate(ref_hier, est_hier, **kwargs)


def tempo(ref, est, **kwargs):
r'''Tempo evaluation
Expand Down
41 changes: 41 additions & 0 deletions jams/tests/eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,44 @@ def test_pattern_invalid():
yield raises(jams.SchemaError)(jams.eval.pattern), est_ann, ref_ann



# Hierarchical segmentation
def create_hierarchy(values, offset=0.0, duration=20):
ann = jams.Annotation(namespace='multi_segment')

for level, labels in enumerate(values):
times = np.linspace(offset, offset + duration, num=len(labels), endpoint=False)

durations = list(np.diff(times))
durations.append(duration + offset - times[-1])

for t, d, v in zip(times, durations, labels):
ann.append(time=t, duration=d, value=dict(label=v, level=level))

return ann

def test_hierarchy_valid():

ref_ann = create_hierarchy(values=['AB', 'abac'])
est_ann = create_hierarchy(values=['ABCD', 'abacbcbd'])

jams.eval.hierarchy(ref_ann, est_ann)


def test_hierarchy_invalid():

ref_ann = create_hierarchy(values=['AB', 'abac'])
est_ann = create_hierarchy(values=['ABCD', 'abacbcbd'])

est_ann.namespace = 'segment_open'

yield raises(jams.NamespaceError)(jams.eval.hierarchy), ref_ann, est_ann
yield raises(jams.NamespaceError)(jams.eval.hierarchy), est_ann, ref_ann

est_ann = create_annotation(values=['E', 'B', 'E', 'B'],
namespace='segment_tut')
est_ann.namespace = 'multi_segment'

yield raises(jams.SchemaError)(jams.eval.hierarchy), ref_ann, est_ann
yield raises(jams.SchemaError)(jams.eval.hierarchy), est_ann, ref_ann

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'numpy>=1.8.0',
'six',
'decorator',
'mir_eval',
'mir_eval>=0.2',
],
scripts=['scripts/jams_to_lab.py']
)

0 comments on commit e98ee48

Please sign in to comment.