
     NgL                         d dl mZ d dlmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZmZ d dlmZ  ee          Z G d d	e          ZdS )
    )	getLogger)TupleUnionN)Fusion)NumpyHelper)	NodeProtoTensorProtohelper)	OnnxModelc                   *    e Zd ZdZdedededededef fdZd%d
ededefdZ	d Z
	 d%d
edededeeef         fdZdedededededededeedf         fdZdedededededededeedf         fdZd Zd Zd Zd Zd  Zd!efd"Zd# Zd$ Z xZS )&FusionAttentionUnetzB
    Fuse Attention subgraph of UNet into one Attention node.
    modelhidden_size	num_headsis_cross_attentionenable_packed_qkvenable_packed_kvc                     t                                          ||r|rdnddg           || _        || _        || _        || _        || _        d| _        d| _        d S )N	AttentionMultiHeadAttentionLayerNormalizationT)	super__init__r   r   r   r   r   num_heads_warninghidden_size_warning)selfr   r   r   r   r   r   	__class__s          j/var/www/html/ai-engine/env/lib/python3.11/site-packages/onnxruntime/transformers/fusion_attention_unet.pyr   zFusionAttentionUnet.__init__   s     	-]2C]KKI]!"	
 	
 	

 '""4 "3 0 "&#'       F	reshape_q	is_torch2returnc                 r   d}|r| j                             |d          }|r|j        dk    rt          |j                  dk    rg| j                             |j        d                   }t          |t          j                  r(t          |j
                  dgk    rt          |          }nm| j                             |j        d                   }t          |t          j                  r.t          |j
                  dgk    rt          |d                   }t          |t                    r|dk    r|S dS )zDetect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            int: num_heads, or 0 if not found
        r      Concat      )r   
get_parentop_typeleninputget_constant_value
isinstancenpndarraylistshapeint)r   r    r!   r   reshape_parentq_shape_values         r   get_num_headsz!FusionAttentionUnet.get_num_heads3   s*    	 	2!Z229a@@N /."8H"D"D^MaIbIbfgIgIg J99.:Nq:QRR	i44 /io9N9NSTRU9U9U #II !J99)/!:LMMM-44 2m>Q9R9RWXVY9Y9Ya 011	i%% 	)a--qr   c                     | j                             |j        d                   }|rt          j        |          j        d         S dS )zDetect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        r'   r   )r   get_initializerr+   r   to_arrayr1   )r   layernorm_nodelayernorm_biass      r   get_hidden_sizez#FusionAttentionUnet.get_hidden_sizeO   sH     33N4H4KLL 	A'77=a@@qr   r9   c                    |                      ||          }|dk    r| j        }| j        dk    r?|| j        k    r4| j        r-t                              d| j         d| d           d| _        |                     |          }|dk    r| j        }| j        dk    r?|| j        k    r4| j        r-t                              d| j         d| d           d| _        ||fS )aF  Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        r   z--num_heads is z. Detected value is z. Using detected value.Fz--hidden_size is )r5   r   r   loggerwarningr;   r   r   )r   r    r9   r!   r   r   s         r   get_num_heads_and_hidden_sizez1FusionAttentionUnet.get_num_heads_and_hidden_size\   s    &&y)<<	>>I>A)t~"="=% /wwwU^wwwxxx).&**>::!*KaK43C$C$C' 1r(8rrkrrr   ,1(+%%r   q_matmulk_matmulv_matmulr+   outputNc           
      R   | j          }|rt|j        d         |k    s"|j        d         |k    s|j        d         |k    r@t                              d|j        d         |j        d         |j        d                    dS n~|j        d         |k    s-|j        d         |j        d         k    s|j        d         |k    r@t                              d|j        d         |j        d         |j        d                    dS |dk    r+||z  dk    r"t                              d| d|            dS | j                            |j        d                   }	| j                            |j        d                   }
| j                            |j        d                   }|	r|
r|sdS |	j        }t          j        |	          }t          j        |
          }t          j        |          }t                              d|j	         d	|j	         d
|j	         d|            |r|j	        |j	        k    s|j	        |j	        k    rdS |j	        d         }|dk    r||k    rt          d| d| d          t          t          j        |j	        dd                             }| j        r| j                            d          }|}|}||z  }t          j        |                    |||          |                    |||          |                    |||          g                              ||dz  |z            }| j                            dd          }|                     |dz   ||j	        d         |j	        d         g|           t'          j        d|j        d         |dz   g|dz   g|          }| j        | j        |j        <   |                     |dz   t0          j        dgdd|d|gd           t'          j        d|dz   |dz   g|dz   g|dz             }| j        | j        |j        <   | j                            ||g           | j                            |||g           nvt          j        |||fd          }d|z  }| j                            d           }|                     |d!z   |||g|           n| j                            d          }| j        r|j	        |j	        k    rdS |j	        d         }|j	        d         }||k    sJ |j	        d         }|j	        d         }|j	        d         }||k    r||k    sJ |}|}||z  }t          j        |                    |||          |                    |||          g                              ||d"z  |z            }| j                            dd#          }|                     |dz   ||j	        d         |j	        d         g|           t'          j        d|j        d         |dz   g|dz   g|          }| j        | j        |j        <   |                     |dz   t0          j        dgdd|d"|gd           t'          j        d|dz   |dz   g|d$z   g|dz             }| j        | j        |j        <   | j                            ||g           | j                            ||g           t          j        d|gt          j         %          } d|z  }!|                     |d&z   ||!g|            |r| j        s||d!z   |d&z   g}"nK|dz   g}"nD| j        s+|j!        d         |j!        d         |j!        d         |d&z   g}"n|j!        d         |d$z   g}"t'          j        |r	| j        sd nd|"|g|          }#d'|#_"        |#j#                            t'          j$        d(|          g           |r	| j        sd)n&d*%                    | j        rd+n
| j        rd,nd-          }$| &                    |$           |#S ).  Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        r   RFor self attention, input hidden state for q and k/v shall be same. Got %s, %s, %sNXFor cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %sinput hidden size # is not a multiple of num of heads r$   qw= kw= vw= hidden_size=Input hidden size (,) is not same as weight dimension of q,k,v (:). Please provide a correct input hidden size or pass in 0r      MatMul
MatMul_QKVname_prefix_weightname	data_typedimsvals_outinputsoutputsrX   _reshape_shape   FrX   rY   rZ   r[   rawReshape
_qkv_input_reshape)axisr   _qkv_weightr'   	MatMul_KV	_kv_inputdtype	_qkv_biascom.microsoftr   Attention (self attention)MultiHeadAttention ({})self attention with packed qkvcross attention with packed kvcross attention)'r   r+   r=   debugr   r7   rY   r   r8   r1   
ValueErrorr2   r.   prodr   create_node_namedstackreshapeadd_initializerr
   	make_nodethis_graph_namenode_name_to_graph_namerX   r	   INT64nodes_to_addextendnodes_to_removestackr   zerosfloat32rC   domain	attributemake_attributeformatincrease_counter)%r   r@   rA   rB   r   r   r+   rC   is_self_attentionq_weightk_weightv_weight
float_typeqwkwvw
qw_in_sizeqw_out_sizeattention_node_namecnh
qkv_weightmatmul_node_namematmul_nodereshape_nodeqkv_weight_dim
kw_in_size
vw_in_sizekw_out_sizevw_out_size	kv_weightqkv_biasqkv_bias_dimattention_inputsattention_nodecounter_names%                                        r   create_attention_nodez)FusionAttentionUnet.create_attention_node~   s	   0 !% 77 	~a E))X^A->%-G-G8>Z[K\`eKeKehN1%N1%N1%	   t Lf ~a E))hnQ.?8>RSCT.T.TZbZhijZkotZtZtnN1%N1%N1%	   t??i 7A==LLikii^giijjj4:--hnQ.?@@:--hnQ.?@@:--hnQ.?@@ 	X 	( 	4 '
!(++!(++!(++[28[[[[rx[[k[[\\\  G	Bx28##rx28';';t!JQ;*#<#< N+ N Ncm N N N   bgbhqrrl3344K% <&*j&A&ABV&W&W#9,  Y

1a(;(;RZZ1a=P=PRTR\R\]^`acdReRe'fggooq1uqy 
 $(:#>#>xUa#>#b#b $$)I5($*1-z/?/BC#	 %    %.$N1-/?)/KL-67)	   BFAU,[-=>$$),<<)/Q1a %     &/(61(+;; 1<?@)J6      CGBV,\->?!((+|)DEEE$++Xx,JKKKK  Xr2rl;;;
!"[&*j&A&A+&N&N#$$,}<($n5#	 %     #'*"="=>R"S"S$ 7B8rx''4Xa[
Xa[
!Z//// hqk hqk hqk"k11k[6P6P6P6P9, Irzz!Q':':BJJq!Q<O<O&PQQYYZ[]^ab]bef]fgg	#':#>#>xU`#>#a#a $$)I5(#/!,ioa.@A"	 %    %.$N1-/?)/KL-67)	   BFAU,[-=>$$),<<)/Q1a %     &/(61(+;; 1;>?)J6      CGBV,\->?!((+|)DEEE$++Xx,@AAA 8Q,BJ???;${2 	 	 	
 	
 	
  	) H'-7'+5$   %8,$F#G  ( OA&OA&OA&'+5	$   OA&'+5$ 
  )-gd6LgKKSg#H$	
 
 
 !0 '')>{I)V)V(WXXX !)-)?((*11)f009=9Ne55Te  	 	l+++r   q_matmul_addk_matmul_addv_matmul_addc           
         | j          }| j                            |dd          }	| j                            |dd          }
| j                            |dd          }|                     |          }|dS |\  }}|                     |          }|dS |\  }}|                     |          }|dS |\  }}|r|	j        d         |k    s"|
j        d         |k    s|j        d         |k    r@t
                              d|	j        d         |
j        d         |j        d                    dS |j        d         |k    s"|j        d         |k    s|j        d         |k    r@t
                              d|j        d         |j        d         |j        d                    dS n|	j        d         |k    s-|
j        d         |j        d         k    s|
j        d         |k    r@t
                              d|	j        d         |
j        d         |j        d                    dS |j        d         |k    s-|j        d         |j        d         k    s|
j        d         |k    r@t
                              d|j        d         |j        d         |j        d                    dS |dk    r+||z  dk    r"t
                              d| d	|            dS | j                            |	j        d
                   }| j                            |
j        d
                   }| j                            |j        d
                   }|r|r|sdS |j        dk    rt
                              d           dS t          j
        |          }t          j
        |          }t          j
        |          }t
                              d|j         d|j         d|j         d|            |r:|j        |j        k    s|j        |j        k    rdS |j        d         }|dk    r||k    rt          d| d| d          t          t          j        |j        d
d                             }| j        r| j                            d          }|}|}||z  } t          j        |                    |||           |                    |||           |                    |||           g                              ||dz  | z            }!| j                            dd          }"|                     |"dz   t*          j        |!j        d         |!j        d
         g|!           t/          j        d|
j        d         |"dz   g|"dz   g|"          }#| j        | j        |#j        <   |j        dz   }$|                     |$t*          j        dgdd|| gd           | j                            d d!          }%t/          j        d |j        d         |$g|%dz   g|%          }&| j        | j        |&j        <   | j                            d d"          }'t/          j        d |j        d         |$g|'dz   g|'          }(| j        | j        |(j        <   | j                            d d#          })t/          j        d |j        d         |$g|)dz   g|)          }*| j        | j        |*j        <   | j                            d$d%          }+t/          j        d$|&j        d         |(j        d         |*j        d         g|+dz   g|+          },|,j                            t/          j         d&d          g           | j        | j        |,j        <   |,j        dz   }-|                     |-t*          j        dgdd|dz  | z  gd           | j                            d d'          }.t/          j        d |,j        d         |-g|.dz   g|.          }/| j        | j        |/j        <   | j                            d(d)          }0t/          j        d(|/j        d         |#j        d         g|0dz   g|0          }1| j        | j        |1j        <   |0dz   }2|                     |2t*          j        d*gdd|d| gd           t/          j        d |1j        d         |2g|d+z   g|0d,z             }3| j        | j        |3j        <   | j!                            |#|&|(|*|,|/|1|3g           | j"                            |	|
||||g           ndS | j                            d          }| j#        r|j        |j        k    rdS |j        d         }4|j        d         }5|4|5k    sJ |j        d
         }|j        d
         }6|j        d
         }7||7k    r|6|7k    sJ |4}|}|6|z  } t          j        |                    |||           |                    |||           g                              ||d-z  | z            }8| j                            dd.          }"|                     |"dz   t*          j        |8j        d         |8j        d
         g|8           t/          j        d|
j        d         |"dz   g|"dz   g|"          }#| j        | j        |#j        <   |j        dz   }9|                     |9t*          j        dgdd|| gd           | j                            d d"          }'t/          j        d |j        d         |9g|'dz   g|'          }(| j        | j        |(j        <   | j                            d d#          })t/          j        d |j        d         |9g|)dz   g|)          }*| j        | j        |*j        <   | j                            d$d/          }:t/          j        d$|(j        d         |*j        d         g|:dz   g|:          };|;j                            t/          j         d&d          g           | j        | j        |;j        <   |;j        dz   }<|                     |<t*          j        dgdd|d-z  | z  gd           | j                            d d0          }=t/          j        d |;j        d         |<g|=dz   g|=          }>| j        | j        |>j        <   | j                            d(d1          }?t/          j        d(|>j        d         |#j        d         g|?dz   g|?          }@| j        | j        |@j        <   |?dz   }2|                     |2t*          j        d*gdd|d-| gd           t/          j        d |@j        d         |2g|d2z   g|?d,z             }3| j        | j        |3j        <   | j!                            |#|(|*|;|>|@|3g           | j"                            |
|||g           ndS t          j$        d|gt          j%        3          }Ad|z  }B|                     |d4z   t*          j        |Bg|A           |r| j        sdS |d+z   g}Cn| j#        sdS |j        d         |d2z   g}Ct/          j        |r	| j        sd5nd|C|g|          }Dd6|D_&        |Dj                            t/          j         d7|          g           |r	| j        sd8n&d9'                    | j        rd:n
| j#        rd;nd<          }E| (                    |E           |DS )=rE   rR   r   NrF   z_For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %srG   zeFor cross attention, input hidden state for LoRA q and k/v weights shall be different. Got %s, %s, %srH   rI   r$   
   zBweights are in fp16. Please run fp16 conversion after optimizationrJ   rK   rL   rM   rN   rO   rP   r   rQ   rS   rT   rV   rW   r\   r]   r`   r&   Frb   rd   Reshape_LoRA_QReshape_LoRA_KReshape_LoRA_Vr%   Concat_LoRA_QKVrg   Reshape_LoRA_QKVAddAdd_Weights_QKVra   re   rf   r'   ri   Concat_LoRA_KVReshape_LoRA_KVAdd_Weights_KVrj   rk   rm   r   rn   r   ro   rp   rq   rr   rs   ))r   r   match_parentmatch_lora_pathr+   r=   rt   r7   rY   r   r8   r1   ru   r2   r.   rv   r   rw   rx   ry   rz   r	   FLOATr
   r{   r|   r}   rX   r~   rC   r   r   r   r   r   r   r   r   r   r   r   )Fr   r   r   r   r   r   r+   rC   r   r@   rA   rB   q_lora_nodesq_lora_last_nodeq_lora_matmul_1k_lora_nodesk_lora_last_nodek_lora_matmul_1v_lora_nodesv_lora_last_nodev_lora_matmul_1r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   lora_weight_shape_tensor_nameq_lora_reshape_node_nameq_lora_reshape_nodek_lora_reshape_node_namek_lora_reshape_nodev_lora_reshape_node_namev_lora_reshape_nodeqkv_lora_concat_node_nameqkv_lora_concat_node'reshaped_lora_weights_shape_tensor_nameqkv_lora_reshaped_node_nameqkv_lora_reshaped_nodeadd_weights_node_nameadd_weights_nodeshape_tensor_namer   r   r   r   r   r    kv_lora_weight_shape_tensor_namekv_lora_concat_node_namekv_lora_concat_node*reshaped_kv_lora_weights_shape_tensor_namekv_lora_reshaped_node_namekv_lora_reshaped_nodeadd_kv_weights_node_nameadd_kv_weights_noder   r   r   r   r   sF                                                                         r   create_attention_node_loraz.FusionAttentionUnet.create_attention_node_lora~  s   0 !% 77:**<1EE:**<1EE:**<1EE++L994.:+	?++L994.:+	?++L994.:+	? .	~a E))X^A->%-G-G8>Z[K\`eKeKehN1%N1%N1%	   t  %a(E11"(+u44"(+u44u#)!,#)!,#)!,	   t 5 ~a E))hnQ.?8>RSCT.T.TZbZhijZkotZtZtnN1%N1%N1%	   t  %a(E11#)!,0Ea0HHHN1%..) $)!,#)!,#)!,   t??i 7A==LLikii^giijjj4:--hnQ.?@@:--hnQ.?@@:--hnQ.?@@ 	X 	( 	4 ##LL]^^^4!(++!(++!(++[28[[[[rx[[k[[\\\  q	x28##rx28';';t!JQ;*#<#< N+ N Ncm N N N   bgbhqrrl3344K% S&*j&A&ABV&W&W#9,  Y

1a(;(;RZZ1a=P=PRTR\R\]^`acdReRe'fggooq1uqy 
 $(:#>#>xUa#>#b#b $$)I5)/$*1-z/?/BC#	 %    %.$N1-/?)/KL-67)	   BFAU,[-=> 1A0EHX0X-$$6)/Q1 %    ,0:+F+Fy^n+F+o+o(&,&6,3A68UV5>?1	' ' '# JNI],-@-EF ,0:+F+Fy^n+F+o+o(&,&6,3A68UV5>?1	' ' '# JNI],-@-EF ,0:+F+Fy^n+F+o+o(&,&6,3A68UV5>?1	' ' '# JNI],-@-EF -1J,G,G^o,G,p,p)'-'7+215+215+215
 7?@2	( 	( 	($ %.55v7LVUV7W7W6XYYYJNJ^,-A-FG ;O:SVf:f7$$@)/QA	* %    /3j.I.I)as.I.t.t+)/)907:<cd86AB4	* * *& MQL`,-C-HI )-
(C(CEWh(C(i(i%#)#329!<k>PQR>ST2V;<.	$ $ $  GKFZ,-=-BC %:<L$L!$$*)/Q1a %     &/,3A68IJ0<?@.;	      CGBV,\->?!((#+++,.($	   $++Xx<Yegs,tuuuu t"&*"="=>R"S"S$ J8rx''4Xa[
Xa[
!Z//// hqk hqk hqk"k11k[6P6P6P6P9, Irzz!Q':':BJJq!Q<O<O&PQQYYZ[]^ab]bef]fgg	#':#>#>xU`#>#a#a $$)I5)/#/!,ioa.@A"	 %    %.$N1-/?)/KL-67)	   BFAU,[-=> 4D3HK[3[0$$9)/Q1 %    ,0:+F+Fy^n+F+o+o(&,&6,3A68XY5>?1	' ' '# JNI],-@-EF ,0:+F+Fy^n+F+o+o(&,&6,3A68XY5>?1	' ' '# JNI],-@-EF ,0:+F+Fx]m+F+n+n(&,&6/6q9;N;UVW;XY5>?1	' ' '# $-44f6KFTU6V6V5WXXXIMI],-@-EF >Q=UXh=h:$$C)/QA	* %    .2Z-H-H`q-H-r-r*(.(8/6q9;ef7&@A3	) ) )% LPK_,-B-GH ,0:+F+FuZj+F+k+k(&,&618;[=OPQ=RS5>?1	' ' '# JNI],-@-EF %=?O$O!$$*)/Q1a %     &//6q9;LM0;>?1J>	      CGBV,\->?!((#+++-+$
 
 
 $++Xx|,\]]]] t 8Q,BJ???;${2!'	 	 	
 	
 	
  	) Ht$7,$F#G  ( t !'*'+5$ 
  )-gd6LgKKSg#H$	
 
 
 !0 '')>{I)V)V(WXXX !)-)?((*11)f009=9Ne55Te  	 	l+++r   c           
         |                      |||          rd S | j                            |dd          }|#| j        s| j                            |dd          }|d S |j        d         }||         }d }|D ]}|j        dk    r|} n|d S |                     ||          p|                     ||          }	|	|	\  }
}}}}}}|}|                     |||
          \  }}|dk    rt          
                    d           d S |                     ||||||j        d         |j        d                   }|d S n|                     ||          p|                     ||          }	|	d S |	\  }
}}}}}}|}|                     |||
          \  }}|dk    rt          
                    d           d S |                     ||||||j        d         |j        d                   }|d S |                     |||
          \  }}|dk    rt          
                    d           d S | j                            |           | j        | j        |j        <   | j                            ||g           d| _        d S )Nr   r   rd   *fuse_attention: failed to detect num_headsr+   rC   T)fuse_a1111_fp16r   r   r   rC   r)   match_qkv_torch1match_qkv_torch2r?   r=   rt   r   match_qkv_torch1_loramatch_qkv_torch2_lorar   r   appendr|   r}   rX   r   r   prune_graph)r   normalize_nodeinput_name_to_nodesoutput_name_to_nodenode_before_layernorm
root_inputchildren_nodesskip_addnode	match_qkvr!   reshape_qkvtranspose_qkvr    matmul_qmatmul_kmatmul_vattention_last_nodeq_num_headsq_hidden_sizenew_nodematmul_add_qmatmul_add_kmatmul_add_vs                           r   fusezFusionAttentionUnet.fuseR  s+   0CEXYY 	F $
 7 7q Q Q !(1H($(J$;$;NIWX$Y$Y! (F*1!4
,Z8" 	 	D|u$$ % F))*h??n4CXCXYcemCnCn	 ]fZI{M9hRZ"-)-)K)KIWegp)q)q&KaIJJJ 11$+A.*1!4 2  H    22:xHH DLfLfHM MI  irfI{M9lLZf"-)-)K)KIWegp)q)q&KaIJJJ 66$+A.*1!4 7  H )-)K)KIWegp)q)q&KaIJJJ  ***6:6J$X]3##%8-$HIII  r   c           
         |j         d         |k    rdnd}| j                            |g d|dddddg          }|dS |\  }}}}}}| j                            |g dg d          }	|	t                              d           dS |	\  }}}}
| j                            |g dg d	          }||\  }}}nF| j                            |g d
g d          }||\  }}}}nt                              d           dS | j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }|t                              d           dS |\  }}}}}d||||||
fS )z.Match Q, K and V paths exported by PyTorch 1.*r   r$   )r   rR   rd   	Transposerd   rR   Nrd   r   rd   rR   r$   r   r   r   &fuse_attention: failed to match v pathSoftmaxMulrR   r   r   r   r  r   r  rR   r   r   r   r   'fuse_attention: failed to match qk path&fuse_attention: failed to match q path)r   rd   r   rd   rR   r$   r   r   r   r   &fuse_attention: failed to match k pathFr+   r   match_parent_pathr=   rt   )r   r   r   another_input	qkv_nodes_r   r   
matmul_qkvv_nodesr   qk_nodes_softmax_qk_mul_qk	matmul_qk	_add_zeroq_nodes_transpose_qr    r   k_nodesr   s                         r   r   z$FusionAttentionUnet.match_qkv_torch1  s   %^A.*<<!J00JJJD$1a0
 
	 4<E9A{M1j *..z;h;h;hjvjvjvww?LLABBB4%Aq(://
<X<X<XZcZcZcdd08-['99z33J@c@c@ceqeqeqrrH#?G<i))FGGGt*..y:g:g:giuiuiuvv?LLABBB418.L)X*..QQQSbSbSb
 
 ?LLABBB4!(Aq!Xk=)XxQYYYr   c           	      T   |j         d         |k    rdnd}| j                            |g d|ddddg          }|dS |\  }}}}}| j                            |g dg d          }	|	t                              d           dS |	\  }}}
| j                            |dd	gddg          }||\  }}nt                              d
           dS | j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }||d         |k    rt                              d           dS d||||||
fS )z.Match Q, K and V paths exported by PyTorch 2.*r   r$   )r   rR   rd   r   rR   N)r   rd   rR   r$   r   r   r   r  rR   r  )r  r   rd   rR   r   Nr   r   r  r$   Nr   r   r	  SqrtDivr  CastSliceShaper   rd   Nr   r$   r   r   r   r   r   z*fuse_attention: failed to match mul_q pathTr
  )r   r   r   r  r  r  r   r   r  r  r   r  r  r  r  mul_qr  r    r   r  _mul_kr   mul_q_nodess                          r   r   z$FusionAttentionUnet.match_qkv_torch2  s   %^A.*<<!J00???D$1-
 
	 49B6A{M:*..z;];];]_h_h_hii?LLABBB4"Ax://
Y<QTUWXSYZZ'/$[))LLBCCC4*..y:c:c:cetetetuu?LLABBB45<2i*..y:c:c:cetetetuu?LLABBB4#* Ax j22UUU'''
 

 +b/Y">">LLEFFF4[-HhPXXXr   c                 "   |j         d         |k    rdnd}| j                            |g d|ddddddg          }|dS |\  }}}}}}}| j                            |g dg d          }	|	t                              d           dS |	\  }}}}
| j                            |g dg d	          }||\  }}}nF| j                            |g d
g d          }||\  }}}}nt                              d           dS | j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }|t                              d           dS |\  }}}}}d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*r   r$   )r   r   rR   rd   r   rd   rR   N)rd   r   rd   r   r   +fuse_attention: failed to match LoRA v pathr   r  r  r  ,fuse_attention: failed to match LoRA qk path+fuse_attention: failed to match LoRA q path)r   rd   r   rd   r   r  +fuse_attention: failed to match LoRA k pathFr
  )r   r   r   r  r  r  r   r   r  r  r   r  r  r  r  r  r  r  r    r   r  r   s                         r   r   z)FusionAttentionUnet.match_qkv_torch1_lora  s   %^A.*<<!J00QQQAtT1a3
 
	
 4?H<Aq+}a *..z;e;e;egsgsgstt?LLFGGG4")Aq,://
<X<X<XZcZcZcdd08-['99z33J@c@c@ceqeqeqrrH#?G<i))KLLLt*..y:d:d:dfrfrfrss?LLFGGG45<2L)\*..NNNP_P_P_
 
 ?LLFGGG4%,"Aq!\k=)\<Yeeer   c           
      X   |j         d         |k    rdnd}| j                            |g d|dddddg          }|dS |\  }}}}}}| j                            |g dg d          }	|	t                              d           dS |	\  }}}
| j                            |dd	gddg          }||\  }}nt                              d
           dS | j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }||d         |k    rt                              d           dS d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*r   r$   )r   r   rR   rd   r   rR   N)r   rd   r   r  r)  r  rR   r*  )r  r   rd   r   r  r+  r  r,  r  r#  r$  z/fuse_attention: failed to match LoRA mul_q pathTr
  )r   r   r   r  r  r  r   r   r  r  r   r  r  r  r  r%  r  r    r   r  r&  r   r'  s                          r   r   z)FusionAttentionUnet.match_qkv_torch2_loraA  s#   %^A.*<<!J00FFFAtT1a0
 
	
 4<E9Aq+}j*..z;Z;Z;Z\e\e\eff?LLFGGG4&A|://
Y<QTUWXSYZZ'/$[))LLGHHH4*..y:`:`:`bqbqbqrr?LLFGGG49@6i*..y:`:`:`bqbqbqrr?LLFGGG4'.$A| j22UUU'''
 

 +b/Y">">LLJKKK4[-L,Xdddr   add_nodec                    | j                             |ddgddg          }|	|\  }}||fS | j                             |g dg d          }|
|\  }}}||fS | j                             |g dg d          }||\  }}}}||fS d S )NrR   r$   r   )r  rR   rR   r  )r  r  rR   rR   r   )r   r  )r   r.  
lora_nodeslora_matmul_2_nodelora_matmul_1_nodelora_mul_noder  s          r   r   z#FusionAttentionUnet.match_lora_patht  s     Z11x F
 

 !7A4!3&(:;; Z11'''II
 

 !5?2]A1!#566 Z11...LL
 

 !8B5]Aq"4!#566tr   c           
      <   | j                             |ddgddg          }|$| j                             |ddgddg          }|dS |\  }}|j        d         }||         }d}	|D ]}
|
j        dk    r|
}	 n|	dS |                     ||	          }|dS |\  }}}}}}| j                             |dd          }| j                             |dd          }| j                             |dd          }||| j        s||k    rn||k    r||k    sdS |j        d         |j        d         k    rdS |}|                     |d          p|                     |d          }|dk    rt          
                    d           dS |                     |          }|                     ||||||j        d         |j        d         	          }|dS | j                            |           | j        | j        |j        <   | j                            ||g           d| _        dS )
zPFuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extensionr   r   r   Nrd   FTr   r   )r   r  rC   r)   match_qkv_a1111r   r   r+   r5   r=   rt   r;   r   r   r   r|   r}   rX   r   r   r   )r   r   r   r   
entry_path_castr   r   r   r   r   r   r   r   r    r   r   r   cast_qcast_kcast_vr   r   r   r   s                            r   r   z#FusionAttentionUnet.fuse_a1111_fp16  s   Z11.65/TUWXSYZZ
55nvyFY\]_`[abbJ!u'1$$*1!4
,Z8" 	 	D|u$$ % 5((X>>	5 	
 ((61==((61==((61==")-)@ #6!!!fPVFVFV&  5<?n3A6665)((D99aT=O=OPY[`=a=a!LLEFFF5,,^<< --.#&-a0 . 
 
 5  ***6:6J$X]3##%8-$HIII  tr   c           
         |j         d         |k    rdnd}| j                            |g d|dddddg          }|dS |\  }}}}}}	| j                            |	g dg d          }
|
t                              d           dS |
\  }}}}| j                            |	g dg d	          }|	|\  }}}}}nt                              d
           dS | j                            |g dg d          }|t                              d           dS |\  }}}}| j                            |g dg d          }|t                              d           dS |\  }}}}||||||fS )zKMatch Q, K and V paths exported by A1111 (stable diffusion webui) extensionr   r$   )r   rR   rd   r   rd   EinsumNr   r   r   )r   r   r  r  r<  )r   r   r   r   Nr  r  r  r	  r
  )r   r   r   r  r  r  r   r   reshape_einsum
einsum_qkvr  r   r  r  	einsum_qkr  r  r    r   r  r   s                        r   r5  z#FusionAttentionUnet.match_qkv_a1111  s   %^A.*<<!J00JJJD$1a0
 
	 4IRFA{M>:*..z;h;h;hjvjvjvww?LLABBB4%Aq(://DDDFXFXFX
 
 08-Q;99LLBCCC4*..y:g:g:giuiuiuvv?LLABBB418.L)X*..y:g:g:giuiuiuvv?LLABBB4%Aq(M9h(RRr   )F)__name__
__module____qualname____doc__r   r2   boolr   r   r5   r;   r   r?   strr   r   r   r   r   r   r   r   r   r   r5  __classcell__)r   s   @r   r   r      s        (( ( 	(
 !(  ( ( ( ( ( ( (: y T c    8   RW &  &" &4= &JN &	sCx &  &  &  &D~~ ~ 	~
 ~ ~ ~ ~ 
y$	~ ~ ~ ~@RR  R  	R
 R R R R 
y$	R R R RhX  X  X t/Z /Z /Zb2Y 2Y 2Yh.f .f .f`1e 1e 1ef** * * *XL L L\*S *S *S *S *S *S *Sr   r   )loggingr   typingr   r   numpyr.   fusion_baser   fusion_utilsr   onnxr   r	   r
   
onnx_modelr   r@  r=   r    r   r   <module>rO     s   
                         $ $ $ $ $ $ / / / / / / / / / /            	8		GS GS GS GS GS& GS GS GS GS GSr   