AFM 《Attentional Factorization Machines》


Learning the Weight of Feature Interactions via Attention Networks

一. 论文解读

FM estimates the target by modelling all interactions between each pair of features:
\hat{y}_{FM}(x) = w_0 + \sum_{i=1}^{n}w_ix_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n}\hat{w}_{ij}x_ix_j

where w_0 is the global bias, w_i denotes the weight of the i-th feature, and \hat{w}_{ij} denotes the weight of the cross feature x_ix_j, which is factorized as \hat{w}_{ij} = v_i^Tv_j, where v_i \in R^k denotes the embedding vector for feature i, and k denotes the size of the embedding vector.

It is worth noting that FM models all feature interactions in the same way: first, a latent vector v_i is shared in estimating all feature interactions that the i-th feature involves; second, all estimated feature interactions \hat{w}_{ij} have a uniform weight of 1. In practice, it is common that not all features are relevant to prediction. However, FM models all possible feature interactions with the same weight, which may adversely deteriorate its generalization performance.

To address this, the paper proposes the Attentional Factorization Machine (AFM), which learns the importance of each feature interaction from data via a neural attention network.

1. 模型结构:


2. 计算流程:
2.1 Pair-wise Interaction Layer

f_{PI}(\varepsilon ) = \{(v_i\odot v_j)x_ix_j\}_{(i, j)\in R_x}

Formally, let the set of non-zero features in the feature vector x be X, and the output of the embedding layer be \varepsilon = \{v_ix_i\}_{i \in X}; the set of interacting index pairs is R_x = \{(i, j)\}_{i \in X, j \in X, j > i}.

2.2 Attention-based Pooling Layer

f_{ATT}(f_{PI}(\varepsilon )) = \sum _{(i, j) \in R_x}a_{ij}(v_i \odot v_j)x_ix_j
where a_{ij} is the attention score for feature interaction \hat{w}_{ij}, which can be interpreted as the importance of \hat{w}_{ij} in predicting the target.

Formally, the attention network is defined as:
{a}'_{ij} = h^TRelu(W(v_i \odot v_j)x_ix_j + b)
a_{ij} = \frac{exp({a}'_{ij})}{\sum _{(i, j) \in R_x} exp({a}'_{ij})}
where W \in R^{t \times k}, b \in R^t, h \in R^t are model parameters, and t denotes the hidden layer size of the attention network, which we call the attention factor.

2.3 Prediction Layer

\hat{y}_{AFM} (x) = w_0 + \sum _{i=1}^{n}w_ix_i + p^T\sum _{i=1}^{n}\sum _{j=i+1}^{n}a_{ij}(v_i\odot v_j)x_ix_j

3. 损失函数:

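L = \sum_{x \in \mathcal{T}} (\hat{y}_{AFM}(x) - y(x))^2 + \lambda ||W||^2, where \mathcal{T} denotes the set of training instances and \lambda controls the strength of the L2 regularization on the attention weight matrix W.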

二. 代码实现

1. 系统环境
  • tensorflow 2.0
  • python 3.6.8
2. Pair-wise Interaction Layer
class PairWiseLayer(tf.keras.layers.Layer):
    """Embedding lookup followed by the AFM pair-wise interaction layer.

    For every pair of fields (i, j) with j > i the layer emits
    ``(v_i * v_j) * x_i * x_j`` (element-wise product of the two embeddings,
    scaled by the product of the two feature values of the same sample).
    """

    def __init__(self, feature_size, field_size, embedding_size, l2_reg=0.01, **kwargs):
        """Store the layer configuration.

        :param feature_size: size of the embedding vocabulary.
        :param field_size: number of fields (columns) per sample.
        :param embedding_size: dimension of each embedding vector.
        :param l2_reg: L2 regularization factor.
        :param kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.feature_size = feature_size
        self.field_size = field_size
        # Attribute name (including its spelling) is kept for compatibility.
        self.embeding_size = embedding_size
        self.l2_reg = l2_reg

    def build(self, input_shape):
        """Create the embedding table and the (i, j) pair index tensors."""
        # NOTE(review): the trailing arguments of this call were lost from the
        # listing; applying l2_reg as an embeddings regularizer is an
        # assumption based on the otherwise unused `l2_reg` attribute - confirm.
        self.embeddings = tf.keras.layers.Embedding(
            self.feature_size, self.embeding_size,
            embeddings_regularizer=tf.keras.regularizers.l2(self.l2_reg))
        # Same pair order as a nested `for i ... for j in range(i + 1, ...)`.
        pairs = [(i, j)
                 for i in range(self.field_size)
                 for j in range(i + 1, self.field_size)]
        # Explicit dtype so that an empty pair list still yields int indices.
        self._left_idx = tf.constant([i for i, _ in pairs], dtype=tf.int32)
        self._right_idx = tf.constant([j for _, j in pairs], dtype=tf.int32)
        super().build(input_shape)

    def call(self, inputs):
        """Compute all pair-wise interactions.

        :param inputs: dict with 'feature_ids' [batch_size, field_size] and
            'feature_vals' [batch_size, field_size].
        :return: [batch_size, interaction_dim, embedding_size]
        """
        feature_ids = inputs['feature_ids']  # [batch_size, field_size]
        feature_vals = inputs['feature_vals']  # [batch_size, field_size]
        embeddings = self.embeddings(feature_ids)  # [batch_size, field_size, embedding_size]
        # v_i * x_i, with x_i broadcast over the embedding axis.
        scaled = embeddings * tf.expand_dims(feature_vals, axis=-1)
        # (v_i x_i) * (v_j x_j) == (v_i * v_j) x_i x_j, kept per sample: the
        # feature-value product must not be summed over the batch.
        left = tf.gather(scaled, self._left_idx, axis=1)
        right = tf.gather(scaled, self._right_idx, axis=1)
        return left * right  # [batch_size, interaction_dim, embedding_size]
2. Attention-based Pooling Layer
class AttentionLayer(tf.keras.layers.Layer):
    """Compute the attention scores a_ij with a one-hidden-layer network."""

    def __init__(self, hidden_units=64, embedding_size=32, l2_reg=0.01, **kwargs):
        """Store the layer configuration.

        :param hidden_units: hidden size of the attention network.
        :param embedding_size: dimension of the interaction vectors.
        :param l2_reg: L2 regularization factor.
        :param kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.embedding_size = embedding_size
        self.l2_reg = l2_reg

    def build(self, input_shape):
        """Create the projection vector and the hidden dense layer.

        :param input_shape: [batch_size, interaction_dim, embedding_size]
        """
        # NOTE(review): the shape/initializer arguments of `h` and the
        # arguments of the Dense layer after `units` were lost from the
        # listing. The shape follows from the matmul in `call`; the ReLU
        # activation and the L2 kernel regularizer are assumptions - confirm.
        self.h = self.add_weight(name='h',
                                 shape=[self.hidden_units],
                                 initializer='glorot_uniform')
        # The Dense layer owns the hidden-layer kernel and bias, so no
        # separate `w` / `b` weights are created here.
        self.dense = tf.keras.layers.Dense(
            units=self.hidden_units,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(self.l2_reg))
        super().build(input_shape)

    def call(self, inputs):
        """Return the normalized attention score of every interaction.

        :param inputs: [batch_size, interaction_dim, embedding_size]
        :return: [batch_size, interaction_dim, 1]
        """
        inner = self.dense(inputs)  # [batch_size, interaction_dim, hidden_units]
        outer = tf.matmul(inner, tf.expand_dims(self.h, axis=-1))
        # Softmax over axis 1, i.e. across the interactions of one sample.
        outputs = tf.nn.softmax(outer, axis=1)
        return outputs
3. Predict Layer
class PredictLayer(tf.keras.layers.Layer):
    """Attention-weighted sum of the interactions, projected by vector p."""

    def __init__(self, embedding_size=32, **kwargs):
        """Store the layer configuration.

        :param embedding_size: dimension of the interaction vectors.
        :param kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.embedding_size = embedding_size

    def build(self, input_shape):
        """Create the projection vector p."""
        # NOTE(review): the shape/initializer arguments were lost from the
        # listing. The shape follows from the matmul in `call`; the
        # initializer is an assumption - confirm.
        self.p = self.add_weight(name='p',
                                 shape=[self.embedding_size],
                                 initializer='glorot_uniform')
        super().build(input_shape)

    def call(self, pi_out, att_out):
        """Pool the interactions and project the result.

        :param pi_out: pair-wise interaction output.
        :param att_out: attention scores, multiplied element-wise (with
            broadcasting) into pi_out.
        :return: result of projecting the pooled vector with p.
        """
        # [batch_size, embedding_size]
        inner = tf.reduce_sum(att_out * pi_out, axis=1)
        outputs = tf.matmul(inner, tf.expand_dims(self.p, axis=-1))
        return outputs
4. model
class AFM(tf.keras.Model):
    """Attentional Factorization Machine assembled from the layers above."""

    def __init__(self, feature_size, field_size, embedding_size=32, l2_reg=0.01, attention_units=64, **kwargs):
        """Build the sub-layers.

        :param feature_size: size of the embedding vocabulary.
        :param field_size: number of fields per sample.
        :param embedding_size: dimension of each embedding vector.
        :param l2_reg: L2 regularization factor.
        :param attention_units: hidden size of the attention network.
        :param kwargs: forwarded to the pair-wise interaction layer.
        """
        # tf.keras.Model must be initialized before sub-layers are assigned.
        super().__init__()
        self.pi_layer = PairWiseLayer(feature_size=feature_size, field_size=field_size, embedding_size=embedding_size,
                                      l2_reg=l2_reg, **kwargs)
        self.attention_layer = AttentionLayer(hidden_units=attention_units, embedding_size=embedding_size, l2_reg=l2_reg)
        # Linear part: a single-unit dense layer over the raw feature values.
        self.dense = tf.keras.layers.Dense(units=1, activation=None)
        self.predict_layer = PredictLayer(embedding_size=embedding_size)

    def call(self, inputs):
        """Return the prediction for a batch.

        :param inputs: dict with 'feature_ids' and 'feature_vals'.
        :return: linear term plus attention-pooled interaction term; no
            output activation is applied here.
        """
        pi_out = self.pi_layer(inputs)
        att_out = self.attention_layer(pi_out)
        den_out = self.dense(inputs['feature_vals'])
        preds = den_out + self.predict_layer(pi_out, att_out)
        return preds
5. train
import tensorflow as tf
from afm import AFM
import argparse
import shutil
import numpy as np
import json
import requests
def _str2bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to bool."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


# Every option declares its type: without `type=`, a value given on the
# command line arrives as a string and only the default has the right type.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default=r'', help='Data dir.')
parser.add_argument('--feature_size', type=int, default=454, help='Number of features.')
parser.add_argument('--field_size', type=int, default=196, help='Number of field_size.')
parser.add_argument('--embedding_size', type=int, default=16, help='Embedding size.')
parser.add_argument('--attention_units', type=int, default=64, help='Hidden units of MLP.')
parser.add_argument('--l2_reg', type=float, default=0.01, help='L2 regularizer for trainable variables.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs.')
parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate.')
parser.add_argument('--task_type', type=str, default='train', help='Task type {train, eval, export, predict}.')
parser.add_argument('--model_dir', type=str, default=r'',
                    help='Model check point dir.')
parser.add_argument('--servable_model_dir', type=str, default=r'',
                    help='Model for tensorflow serving dir.')
# A plain string 'False' would be truthy, so parse the flag explicitly.
parser.add_argument('--clear_existing_model', type=_str2bool, default=True,
                    help='Weather to clearing the old model.')
def input_fn(filename, batch_size, num_epochs=1, shuffle=False):
    """Build a batched dataset of (features, label) pairs from a libsvm file.

    :param filename: path of the libsvm text file.
    :param batch_size: number of samples per batch.
    :param num_epochs: how many times the data is repeated.
    :param shuffle: whether to shuffle (buffer of 256 samples).
    :return: dataset yielding ({'feature_ids', 'feature_vals'}, labels).
    """
    print('Parsing: ', filename)

    def decode_libsvm(line):
        """Parse one 'label id:val id:val ...' line."""
        columns = tf.strings.split(line, sep=' ')
        labels = tf.strings.to_number(columns[0], out_type=tf.int32)
        id_vals = tf.strings.split(columns[1:], sep=':').to_tensor()
        feature_ids, feature_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        # Squeeze only the split axis, so that a row with a single feature
        # keeps its field dimension instead of collapsing to a scalar.
        feature_ids = tf.squeeze(tf.strings.to_number(feature_ids, out_type=tf.int32), axis=1)
        feature_vals = tf.squeeze(tf.strings.to_number(feature_vals, out_type=tf.float32), axis=1)
        return {'feature_ids': feature_ids, 'feature_vals': feature_vals}, labels

    # NOTE(review): the beginning of this statement was lost from the listing;
    # reading the file line by line and mapping `decode_libsvm` over it is
    # reconstructed from the surviving tail and the helper above - confirm.
    dataset = tf.data.TextLineDataset(filename).map(
        decode_libsvm, num_parallel_calls=10).prefetch(500000)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset
def model_train(args):
    """Train an AFM model on `tr.libsvm` and report AUC on `va.libsvm`.

    :param args: parsed command-line namespace; reads feature_size,
        field_size, embedding_size, attention_units, l2_reg, batch_size,
        num_epochs, learning_rate, data_dir, model_dir and
        clear_existing_model.
    """
    feature_size = args.feature_size
    field_size = args.field_size
    embedding_size = args.embedding_size
    attention_units = args.attention_units
    l2_reg = args.l2_reg
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    data_dir = args.data_dir
    model_dir = args.model_dir

    # NOTE(review): the `try` part of this statement was lost from the
    # listing; removing `model_dir` with shutil.rmtree is reconstructed from
    # the surviving messages and the file's `import shutil` - confirm.
    # The empty default model_dir is skipped instead of being passed to rmtree.
    if args.clear_existing_model and model_dir:
        try:
            shutil.rmtree(model_dir)
        except OSError as e:
            print(e, ' at clear existing model.')
        else:
            print('Existing model cleared at ', model_dir)

    train_file = data_dir + 'tr.libsvm'
    valid_file = data_dir + 'va.libsvm'
    train_data = input_fn(train_file, batch_size=batch_size, num_epochs=num_epochs, shuffle=True)
    valid_data = input_fn(valid_file, batch_size=64, num_epochs=1, shuffle=False)

    model = AFM(feature_size, field_size, embedding_size, l2_reg, attention_units)
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    checkpoint = tf.train.Checkpoint(model=model)

    for batch_index, (inputs, y) in enumerate(train_data):
        # Labels arrive as int32 [batch]; the model emits float [batch, 1].
        # Align shape and dtype so the loss is computed per sample.
        y_true = tf.cast(tf.reshape(y, shape=[-1, 1]), tf.float32)
        with tf.GradientTape() as tape:
            y_pred = model(inputs, training=True)
            # The model output has no activation, so it is treated as logits.
            loss = tf.keras.losses.binary_crossentropy(y_true=y_true, y_pred=y_pred, from_logits=True)
            loss = tf.reduce_mean(loss)
            # Regularization losses are not added automatically in a custom loop.
            if model.losses:
                loss += tf.add_n(model.losses)
        grads = tape.gradient(target=loss, sources=model.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
        if batch_index % 500 == 0:
            print('batch: %d,  loss: %f' % (batch_index, loss))
        if batch_index % 1000 == 0:
            # NOTE(review): the head of this statement was lost from the
            # listing; saving through `checkpoint` is reconstructed - confirm.
            checkpoint.save(file_prefix=model_dir + 'model.ckpt')

    # Export for TensorFlow Serving was commented out in the original listing
    # (it referenced args.servable_model_dir) and is left disabled.

    # validation
    auc = tf.keras.metrics.AUC()
    for inputs, y in valid_data:
        y = tf.reshape(y, shape=[-1, 1])
        # AUC thresholds lie in [0, 1], so convert the logits to probabilities.
        y_pred = tf.sigmoid(model(inputs, training=False))
        auc.update_state(y_true=y, y_pred=y_pred)
    print('*' * 100)
    print('Test auc: %f\n' % auc.result())
if __name__ == '__main__':
    args = parser.parse_args()
    task_type = args.task_type
    if task_type == 'train':
        # NOTE(review): the body of this branch was lost from the listing;
        # calling model_train is reconstructed from the only training entry
        # point defined in the file - confirm. Other task types named in the
        # --task_type help text have no visible handler.
        model_train(args)





0 条回复 A 作者 M 管理员
欢迎您,新朋友,感谢参与互动!欢迎您 {{author}},您在本站有{{commentsCount}}条评论