"""Utilities for constructing PyTrees of PartitionSpecs."""

# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py

import re

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict

try:
    # Current public location of PartitionSpec.
    from jax.sharding import PartitionSpec as P
except ImportError:  # pragma: no cover - older JAX releases
    # Older JAX releases only expose it under ``jax.experimental``; that alias
    # was deprecated and later removed, so it is used only as a fallback.
    from jax.experimental import PartitionSpec as P


# Sentinels
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if the regexes in ``qs`` match some contiguous window of ``ks``.

    Each regex must match its corresponding string *completely*. ``re.fullmatch``
    is used instead of appending ``"$"`` because, with a top-level alternation
    such as ``"a|b"``, the pattern ``"a|b$"`` would only anchor the last branch.

    Args:
        qs: Sequence of regex pattern strings.
        ks: Tuple of strings (a flattened parameter path).

    Returns:
        True if some window ``ks[i:i + len(qs)]`` is matched element-wise by
        ``qs``; False otherwise, including when ``qs`` is empty or longer
        than ``ks``.
    """
    if not qs:
        return False
    width = len(qs)
    for start in range(len(ks) - width + 1):
        window = ks[start : start + width]
        if all(re.fullmatch(q, k) for q, k in zip(qs, window)):
            return True
    return False


def _replacement_rules(rules):
    """Build a lookup that maps a parameter path to the spec of its first matching rule.

    Args:
        rules: Sequence of ``(patterns, replacement)`` pairs, where ``patterns``
            is a tuple of regexes understood by ``_match``.

    Returns:
        A function ``replace(key, val)`` that returns the ``replacement`` of the
        first rule whose patterns match ``key`` (so rule order matters), or
        ``val`` unchanged when no rule matches.
    """

    def replace(key, val):
        for rule, replacement in rules:
            if _match(rule, key):
                return replacement
        return val

    return replace


# PartitionSpec for GPTNeo
# replicate the hidden dim and shard feed-forward and head dim
def _get_partition_rules():
    """Return the ordered ``(patterns, PartitionSpec)`` rules for GPT-Neo parameters.

    Specs refer to a mesh axis named ``"mp"`` (presumably the model-parallel
    axis; confirm against the mesh set up by the caller). A replacement of
    ``None`` leaves that parameter unpartitioned.
    """
    return [
        # embeddings: shard along the first axis
        (("transformer", "wpe", "embedding"), P("mp", None)),
        (("transformer", "wte", "embedding"), P("mp", None)),
        # attention
        (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
        (("attention", "out_proj", "kernel"), P("mp", None)),
        (("attention", "out_proj", "bias"), None),
        # mlp
        (("mlp", "c_fc", "kernel"), P(None, "mp")),
        (("mlp", "c_fc", "bias"), P("mp")),
        (("mlp", "c_proj", "kernel"), P("mp", None)),
        (("mlp", "c_proj", "bias"), None),
        # layer norms
        ((r"ln_\d+", "bias"), None),
        ((r"\d+", r"ln_\d+", "scale"), None),
        (("ln_f", "bias"), None),
        (("ln_f", "scale"), None),
    ]


def set_partitions(in_dict):
    """Build a frozen tree of PartitionSpecs mirroring the structure of ``in_dict``.

    Every leaf path of ``in_dict`` is matched against ``_get_partition_rules()``;
    the first matching rule supplies that leaf's spec (a ``PartitionSpec`` or
    ``None``).

    Args:
        in_dict: Nested (optionally frozen) dict of parameters.

    Returns:
        A ``FrozenDict`` with the same nesting as ``in_dict`` whose leaves are
        the matched partition specs.

    Raises:
        AssertionError: If any leaf path matches no rule. It is raised
            explicitly, not via ``assert``, so the check still runs under
            ``python -O``.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {k: replace(k, _unmatched) for k in flatten_dict(in_dict)}
    unmatched = [k for k, v in result.items() if v is _unmatched]
    if unmatched:
        paths = ", ".join("/".join(map(str, k)) for k in unmatched)
        raise AssertionError(
            f"Incomplete partition spec. Unmatched parameters: {paths}"
        )
    return freeze(unflatten_dict(result))