PyYAML で array を merge する


yaml in データ分析

身近なところでは、機械学習・データ分析周りの設定を yaml で書くのが流行っています(主にKedroを使っています)。
なるべくDRY(don't repeat yourself)にすべく共通の設定はアンカー(&) を使っているのですが、そこで問題になるのが yamlの仕様で、mapping は merge できるが array は merge できないというものです。
これは yaml の仕様としてはサポートしない、というのが yaml チームの見解のようです(https://github.com/yaml/yaml/issues/35 がIssue として立ち上がり、度々Openされその度 Closeされているのが見て取れます)。

yaml で困る具体例

具体的には、以下のような場面で困ります。

common_features: &common
  - member_reward_program_status
  - member_is_subscribing

transaction_features: &transaction
  - num_transactions
  - average_transaction_amount
  - time_since_last_transaction

next_product_to_buy:
  model_to_use: xgboost
  feature_whitelist:
    - *common
    - *transaction
    - last_product_bought
    - applied_to_campaign
  target: propensity

複数のfeatureの塊があったとして、それを組み合わせてモデルを作る場合を考えます。
欲しいものとしては feature_whitelist の中身が

[
  'member_reward_program_status', 
  'member_is_subscribing', 
  'num_transactions', 
  'average_transaction_amount', 
  'time_since_last_transaction', 
  'last_product_bought', 
  'applied_to_campaign'
]

になることなんですが、上の設定だと下のようなネストしたリストになってしまいます。

[
  [
    'member_reward_program_status', 
    'member_is_subscribing', 
  ],
  [
    'num_transactions', 
    'average_transaction_amount', 
    'time_since_last_transaction', 
  ],
  'last_product_bought', 
  'applied_to_campaign'
]

その他の解決法

上の問題を解決するだけであればなんでもいいので、例えばネストしたリストをフラットにする とか、リストじゃなくて辞書型として定義してmergeする、などがあります。

# 辞書型の例
feature_a: &feature_a
  age: 
feature_b: &feature_b
  price:
use_features:
  <<: *feature_a
  <<: *feature_b

使い方は下のようになります。

# > params['use_features'].keys()
dictkeys(['age', 'price'])

また同じ yaml 側で解決する場合も、package を選べる場合は PyYAML の fork である ruamel.yamlを使っても実現できます。

yaml の tag を定義する

今回は Kedro の機能を拡張するために使いたいという背景がありました。
Kedro は TemplatedConfig を読み込む際にanyconfig を使っており、anyconfig 自体は PyYAML にも ruamel.yaml にも対応しているようですが、Kedro サイドで PyYAML を requirements として指定しているので、PyYAML で実現する方法を考えます。

公式のDocs にも自前タグの実装についてある程度の解説はあるので、それを参考にしつつ、タグ用の constructor を定義します。

import yaml

yaml.add_constructor("!flatten", construct_flat_list)

def construct_flat_list(loader: yaml.Loader, node: yaml.Node) -> List[str]:
    """Make a flat list, should be used with '!flatten'

    Args:
        loader: Unused, but necessary to pass to `yaml.add_constructor`
        node: The passed node to flatten
    """
    return list(flatten_sequence(node))

def flatten_sequence(sequence: yaml.Node) -> Iterator[str]:
    """Flatten a nested sequence to a list of strings
        A nested structure is always a SequenceNode
    """
    if isinstance(sequence, yaml.ScalarNode):
        yield sequence.value
        return
    if not isinstance(sequence, yaml.SequenceNode):
        raise TypeError(f"'!flatten' can only flatten sequence nodes, not {sequence}")
    for el in sequence.value:
        if isinstance(el, yaml.SequenceNode):
            yield from flatten_sequence(el)
        elif isinstance(el, yaml.ScalarNode):
            yield el.value
        else:
            raise TypeError(f"'!flatten' can only take scalar nodes, not {el}")

PyYAML は Python のオブジェクトを作成する手前で yaml をPyYAML のオブジェクトにパースした document を作るのですが、その document では array は全て yaml.SequenceNode として、値は yaml.ScalarNode として保存されているので、上のコードで再起的に値だけを取り出すことができます。
機能を確認するためのテストコードは以下のようになります。!flatten の tag をつけることで、ネストされた array をフラットな array に変換できます。

import pytest
def test_flatten_yaml():
    # single nest
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      - *bread
    midnight_meal: !flatten
      - *chicken
      - *bread
    """
    params = yaml.load(param_string)
    assert sorted(params["midnight_meal"]) == sorted(
        ["toast", "loafs", "toast", "loafs"]
    )

    # double nested
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      - *bread
    dinner: &dinner
      - *chicken
      - *bread
    midnight_meal_long:
      - *chicken
      - *bread
      - *dinner
    midnight_meal: !flatten
      - *chicken
      - *bread
      - *dinner
    """
    params = yaml.load(param_string)
    assert sorted(params["midnight_meal"]) == sorted(
        ["toast", "loafs", "toast", "loafs", "toast", "loafs", "toast", "loafs"]
    )

    # doesn't work with mappings
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      meat: breast
    midnight_meal: !flatten
      - *chicken
      - *bread
    """
    with pytest.raises(TypeError):
        yaml.load(param_string)

参考になれば幸いです。