用于元素成对比较的高效算法

2022-09-04 23:29:46

给定一个包含一些键值对的数组:

[
  {'a': 1, 'b': 1},
  {'a': 2, 'b': 1},
  {'a': 2, 'b': 2},
  {'a': 1, 'b': 1, 'c': 1},
  {'a': 1, 'b': 1, 'c': 2},
  {'a': 2, 'b': 1, 'c': 1},
  {'a': 2, 'b': 1, 'c': 2}
]

我想找到这些对的交集交叉意味着只留下那些可以被其他人覆盖的元素,或者是唯一的。例如,和 完全覆盖 ,而 是唯一的。所以,在{'a': 1, 'b': 1, 'c': 1}{'a': 1, 'b': 1, 'c': 2}{'a': 1, 'b': 1}{'a': 2, 'b': 2}

[
  {'a': 1, 'b': 1},
  {'a': 2, 'b': 1},
  {'a': 2, 'b': 2},
  {'a': 1, 'b': 1, 'c': 1},
  {'a': 1, 'b': 1, 'c': 2},
  {'a': 2, 'b': 1, 'c': 1},
  {'a': 2, 'b': 1, 'c': 2}
]

找到交叉路口后应保留

[
  {'a': 2, 'b': 2},
  {'a': 1, 'b': 1, 'c': 1},
  {'a': 1, 'b': 1, 'c': 2},
  {'a': 2, 'b': 1, 'c': 1},
  {'a': 2, 'b': 1, 'c': 2}
]

我试图迭代所有对,并找到相互比较的覆盖对,但时间复杂度等于。是否有可能在线性时间中找到所有覆盖或唯一对?O(n^2)

这是我的代码示例 ():O(n^2)

public Set<Map<String, Integer>> find(Set<Map<String, Integer>> allPairs) {
  var results = new HashSet<Map<String, Integer>>();
  for (Map<String, Integer> stringToValue: allPairs) {
    results.add(stringToValue);
    var mapsToAdd = new HashSet<Map<String, Integer>>();
    var mapsToDelete = new HashSet<Map<String, Integer>>();
    for (Map<String, Integer> result : results) {
      var comparison = new MapComparison(stringToValue, result);
      if (comparison.isIntersected()) {
        mapsToAdd.add(comparison.max());
        mapsToDelete.add(comparison.min());
      }
    }
    results.removeAll(mapsToDelete);
    results.addAll(mapsToAdd);
  }
  return results;
}

其中 MapComparison 是:

public class MapComparison {

    private final Map<String, Integer> left;
    private final Map<String, Integer> right;
    private final ComparisonDecision decision;

    public MapComparison(Map<String, Integer> left, Map<String, Integer> right) {
        this.left = left;
        this.right = right;
        this.decision = makeDecision();
    }

    private ComparisonDecision makeDecision() {
        var inLeftOnly = new HashSet<>(left.entrySet());
        var inRightOnly = new HashSet<>(right.entrySet());

        inLeftOnly.removeAll(right.entrySet());
        inRightOnly.removeAll(left.entrySet());

        if (inLeftOnly.isEmpty() && inRightOnly.isEmpty()) {
            return EQUALS;
        } else if (inLeftOnly.isEmpty()) {
            return RIGHT_GREATER;
        } else if (inRightOnly.isEmpty()) {
            return LEFT_GREATER;
        } else {
            return NOT_COMPARABLE;
        }
    }

    public boolean isIntersected() {
        return Set.of(LEFT_GREATER, RIGHT_GREATER).contains(decision);
    }

    public boolean isEquals() {
        return Objects.equals(EQUALS, decision);
    }

    public Map<String, Integer> max() {
        if (!isIntersected()) {
            throw new IllegalStateException();
        }
        return LEFT_GREATER.equals(decision) ? left : right;
    }

    public Map<String, Integer> min() {
        if (!isIntersected()) {
            throw new IllegalStateException();
        }
        return LEFT_GREATER.equals(decision) ? right : left;
    }

    public enum ComparisonDecision {
        EQUALS,
        LEFT_GREATER,
        RIGHT_GREATER,
        NOT_COMPARABLE,

        ;
    }
}

答案 1

这是一种算法,根据数据的形状,它可能更好或更糟。让我们通过将输入行表示为集合而不是映射来简化问题,因为从本质上讲,您只是将这些映射视为对/条目的集合。如果集合是等的,则问题是等价的。目标是创建一个线性时间算法,假设输入行的长度很短。设 n 为输入行数,k 为行的最大长度;我们的假设是 k 比 n 小得多。[a1, b1]

  • 使用计数排序按长度对行进行排序。
  • 初始化结果的空值,其中集合的成员将是行(您将需要一个不可变的可哈希类来表示行)。HashSet
  • 对于每一行:
    • 从结果中删除行的幂集中的每个子集(如果存在)。
    • 将该行添加到结果中。

由于行是按长度排序的,因此可以保证,如果行是行的子集,则行将被添加到行之前,因此稍后将从结果集中正确删除。一旦算法终止,结果集将恰好包含那些不是任何其他输入行的子集的输入行。ijij

计数排序的时间复杂度为 O(n + k)。每个功率集的大小最多为 2k,并且功率集的每个成员的长度最多为 k,因此每个操作都是 O(k) 时间。因此,算法其余部分的时间复杂度为O(2k·kn),这主导了计数排序。HashSet

因此,如果我们将 k 视为常量,则总体时间复杂度为 O(n)。如果不是,那么当 k < log2 n 时,此算法仍将渐近优于朴素 O(n2·k) 算法*。

*请注意,朴素算法是 O(n2·k) 而不是 O(n2),因为两行之间的每次比较都需要 O(k) 时间。


答案 2

假定列表中的每个元素都是唯一的。(元素是具有键值对的对象。对于每个唯一的键值对,存储包含该键值对的列表元素集。按大小增加的顺序循环访问元素。对于每个元素,通过查找包含它们的元素集并将该集与当前交集相交来搜索其键值对。如果交叉点大小小于 2(假设交叉点至少包含一个元素,即我们正在研究的元素),请尽早退出。根据数据,我们可以对这些集合使用位集(每个位将表示排序列表中map元素的索引),这可以执行并行比较的交集。此外,根据数据,交叉点可以显著减少搜索空间。

Python 代码:

import collections

def f(lst):
  pairs_to_elements = collections.defaultdict(set)

  for i, element in enumerate(lst):
    for k, v in element.items():
      pairs_to_elements[(k, v)].add(i)

  lst_sorted_by_size = sorted(lst, key=lambda x: len(x))

  result = []

  for element in lst_sorted_by_size:
    pairs = list(element.items())
    intersection = pairs_to_elements[pairs[0]]
    is_contained = True

    for i in range(1, len(pairs)):
      intersection = intersection.intersection(pairs_to_elements[pairs[i]])
      if len(intersection) < 2:
        is_contained = False
        break

    if not is_contained:
      result.append(element)

  return result

输出:

lst = [
  {'a': 1, 'b': 1},
  {'a': 2, 'b': 1},
  {'a': 2, 'b': 2},
  {'a': 1, 'b': 1, 'c': 1},
  {'a': 1, 'b': 1, 'c': 2},
  {'a': 2, 'b': 1, 'c': 1},
  {'a': 2, 'b': 1, 'c': 2}
]

for element in f(lst):
  print(element)

"""
{'a': 2, 'b': 2}
{'a': 1, 'b': 1, 'c': 1}
{'a': 1, 'b': 1, 'c': 2}
{'a': 2, 'b': 1, 'c': 1}
{'a': 2, 'b': 1, 'c': 2}
"""