Ming, Y., et al. (2019). Interpretable and steerable sequence learning via prototypes. Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining.
Methodology
We aim to learn representative prototype sequences (which need not exist in the training data) that can be used as classification references and analogical explanations.
$\mathcal D=\{((x^{(t)})_{t=1}^T,\,y)\}$: labeled sequence dataset
$x^{(t)}\in\mathbb R^n$, $y\in\{1,\dots,C\}$
Architecture
$r\to p\to f$
$r$: sequence encoder
$p$: prototype layer
$f$: fully connected (FC) layer
sequence encoder
LSTM or GRU
$(x^{(t)})_{t=1}^T\in\mathbb R^{n\times T}$: input
$e=h^{(T)}\in\mathbb R^m$: output, the hidden state at the last time step
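A minimal PyTorch sketch of the encoder $r$, assuming a single-layer GRU with batch-first inputs; the class and argument names are illustrative, not taken from the paper:

```python
import torch
import torch.nn as nn

class SequenceEncoder(nn.Module):
    """Sketch of the encoder r: a GRU whose last hidden state is the embedding e."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, T, n)  ->  e = h^(T): (batch, m)
        _, h_last = self.rnn(x)   # h_last: (num_layers, batch, m)
        return h_last[-1]         # hidden state of the last layer at the last time step
```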
prototype layer
$p_i\in\mathbb R^m$: prototype vectors, $k$ in total
$a\in\mathbb R^k$: output, with $a_i=\exp(-\|e-p_i\|_2^2)$
The exponential converts the squared distance into a similarity score in $(0,1]$ (an improvement over using raw distances).
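A sketch of the prototype layer under the same assumptions, storing the $k$ prototypes as a learnable parameter matrix (names are illustrative):

```python
import torch
import torch.nn as nn

class PrototypeLayer(nn.Module):
    """Sketch of the prototype layer p: similarity a_i = exp(-||e - p_i||_2^2)."""

    def __init__(self, k: int, m: int):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(k, m))  # k prototype vectors in R^m

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        # e: (batch, m); squared Euclidean distance to each prototype -> (batch, k)
        d2 = torch.cdist(e, self.prototypes, p=2) ** 2
        return torch.exp(-d2)   # similarity scores in (0, 1]
```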
FC layer
dimensions: $k\to C$
weights $W$ are constrained to be non-negative
+ softmax layer
$\hat y$: output, the predicted class distribution
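A sketch of the output layer; the paper only requires $W\ge 0$, so the ReLU re-parameterisation used here is just one possible way to enforce the constraint (clamping $W$ after each optimizer step would also work):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonNegativeOutput(nn.Module):
    """Sketch of the FC layer f (k -> C) with non-negative weights, followed by softmax."""

    def __init__(self, k: int, num_classes: int):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(num_classes, k))  # unconstrained parameters

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        W = F.relu(self.weight)            # effective weights are non-negative
        logits = a @ W.t()                 # (batch, k) @ (k, C) -> (batch, C)
        return F.softmax(logits, dim=-1)   # predicted class probabilities ŷ
```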
Loss func.
$Loss(\Theta,\mathcal D)=CE(\Theta,\mathcal D)+\lambda_c R_c(\Theta,\mathcal D)+\lambda_e R_e(\Theta,\mathcal D)+\lambda_d R_d(\Theta)+\lambda_{l_1}\|W\|_1$
$CE(\Theta,\mathcal D)$: cross-entropy loss between $y$ and $\hat y$
$R_c(\Theta,\mathcal D)=\sum\limits_{(x^{(t)})_{t=1}^T\in\mathcal D}\min\limits_i\|r((x^{(t)})_{t=1}^T)-p_i\|_2^2$: each item is close to some prototype
$R_e(\Theta,\mathcal D)=\sum\limits_{i=1}^k\min\limits_{(x^{(t)})_{t=1}^T\in\mathcal D}\|r((x^{(t)})_{t=1}^T)-p_i\|_2^2$: each prototype is close to some item
$R_d(\Theta)=\sum\limits_{i=1}^k\sum\limits_{j=i+1}^k\max(0,\,d_{min}-\|p_i-p_j\|_2)^2$: prototypes are not too close to each other
$d_{min}$ is set to $1.0$ or $2.0$ in this model
the $L_1$ penalty on $W$ helps to learn sequence prototypes that have more unitary and additive semantics for classification
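The three regularizers can be computed from pairwise distances. The sketch below assumes a batch of encoder outputs and the prototype matrix as plain tensors; the function name and signature are illustrative:

```python
import torch

def prosenet_regularizers(embeddings: torch.Tensor, prototypes: torch.Tensor, d_min: float = 2.0):
    """Sketch of R_c, R_e, R_d.
    embeddings: (N, m) encoder outputs r(x) for a batch; prototypes: (k, m)."""
    d2 = torch.cdist(embeddings, prototypes, p=2) ** 2   # (N, k) squared distances

    r_c = d2.min(dim=1).values.sum()   # each item close to its nearest prototype
    r_e = d2.min(dim=0).values.sum()   # each prototype close to its nearest item

    # pairwise prototype distances; penalize pairs closer than d_min
    pd = torch.cdist(prototypes, prototypes, p=2)
    iu = torch.triu_indices(pd.size(0), pd.size(1), offset=1)
    r_d = torch.clamp(d_min - pd[iu[0], iu[1]], min=0.0).pow(2).sum()

    return r_c, r_e, r_d
```

The total loss then adds these terms to the cross-entropy and the $L_1$ norm of $W$, weighted by $\lambda_c,\lambda_e,\lambda_d,\lambda_{l_1}$.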
prototype projection
$p_i\gets r(seq_i)$
$seq_i\gets\arg\min\limits_{seq}\|r(seq)-p_i\|_2$, where $seq$ ranges over (sub)sequences from the training data $\mathcal X$
Simplification: instead of enumerating all possible subsequences of the training data, beam search is used to find $seq_i$.
The projection step is only performed every few training epochs (every 4 epochs in the paper's experiments) to reduce the computational cost.
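A simplified sketch of the projection step that only considers full training sequences as candidates (the paper instead searches over subsequences with beam search); `encoder` and `prototype_layer` refer to the sketches above:

```python
import torch

@torch.no_grad()
def project_prototypes(encoder, prototype_layer, train_sequences):
    """Snap each prototype to the embedding of its closest full training sequence.
    This only illustrates the assignment p_i <- r(seq_i), not the beam search."""
    # Encode every candidate sequence; each seq is a (T, n) tensor.
    embeddings = torch.stack(
        [encoder(seq.unsqueeze(0)).squeeze(0) for seq in train_sequences]
    )                                                              # (N, m)

    d = torch.cdist(prototype_layer.prototypes, embeddings, p=2)   # (k, N)
    nearest = d.argmin(dim=1)                                      # index of seq_i per prototype
    prototype_layer.prototypes.copy_(embeddings[nearest])          # p_i <- r(seq_i)
    return nearest
```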
Refining ProSeNet with User Knowledge
Prototypes can be set manually (from user-specified sequences) and frozen so that they are not updated during further training.
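One possible way to implement this steering step in the sketched model: overwrite a prototype with a user-provided embedding (e.g. $r$ of a user-given sequence) and zero out its gradient so it stays fixed. The gradient-hook mechanism is an assumed implementation detail, not something specified by the paper:

```python
import torch

def freeze_prototype(prototype_layer, index: int, new_value: torch.Tensor):
    """Overwrite prototype `index` with `new_value` and block its gradient updates."""
    with torch.no_grad():
        prototype_layer.prototypes[index] = new_value

    # Mask out this prototype's gradient during backprop so it is never updated.
    mask = torch.ones_like(prototype_layer.prototypes)
    mask[index] = 0.0
    prototype_layer.prototypes.register_hook(lambda grad: grad * mask)
```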