
    קgD                     f   d dl Z d dlmZ d dlmZmZmZmZ d dlZd dl	Zd dl
mZ d dlmZ d dlmZmZ d dlmZ d dlmZ 	 d dlZd	Zn# e$ r d
ZdZY nw xY wdgZddddddZg dZddddZer! ed
           G d d                      ZdS es! ed
           G d d                      ZdS dS )    N)chain)AnyDictOptionalTYPE_CHECKING)compatibility)_parse_stack_trace)_format_arg_get_qualified_name)normalize_function)TensorMetadataTFFxGraphDrawerz"AliceBlue"LemonChiffon1Yellow2	LightGrey
PowderBlue)placeholdercall_module	get_paramget_attroutput)
CadetBlue1CoralDarkOliveGreen1DarkSeaGreen1
GhostWhiteKhaki1LavenderBlush1LightSkyBlue
MistyRose1
MistyRose2PaleTurquoise2
PeachPuff1SalmonThistle1Thistle3Wheat1r$   "filled,rounded"#000000)	fillcolorstyle	fontcolor)is_backward_compatiblec                   6   e Zd ZdZ	 	 	 	 	 	 d$dej        j        dededed	ed
ede	e         defdZ
d%dej        fdZdej        fdZdej        fdZdeeej        f         fdZdej        j        deeef         fdZdej        j        dej        j        dej        j        fdZdedefdZ	 d&dedefdZdej        j        dej        j        d	ed
edef
dZdefdZdedefd Zd!ej        defd"Z dej        j        dededed	ed
edej        fd#Z!dS )'r   z
        Visualize a torch.fx.Graph with graphviz
        Basic usage:
            g = FxGraphDrawer(symbolic_traced, "resnet18")
            g.get_dot_graph().write_svg("a.svg")
        FTNgraph_modulenameignore_getattrignore_parameters_and_buffersskip_node_names_in_argsparse_stack_tracedot_graph_shapenormalize_argsc	           	         || _         ||nd| _        || _        | j        t          d<   ||                     ||||||          i| _        |j        j        D ]z}	|	j        dk    r| 	                    ||	          }
t          |
t          j        j                  sD|                     |
| d|	j         ||||          | j        | d|	j         <   {d S )Nrecordshaper   _)_namer5   r6   _WEIGHT_TEMPLATE_to_dot_dot_graphsgraphnodesop_get_leaf_node
isinstancetorchfxGraphModuletarget)selfr/   r0   r1   r2   r3   r4   r5   r6   node	leaf_nodes              X/var/www/html/ai-engine/env/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py__init__zFxGraphDrawer.__init__F   s    DJ#2#>H   #1D(,(<W% dll $8UWn  qB  D %*0  7m++ //dCC	!)UX-ABB <@LL++dk++"1+%= = D!8!84;!8!899     returnc                 X    ||                                  S |                     |          S )aA  
            Visualize a torch.fx.Graph with graphviz
            Example:
                >>> # xdoctest: +REQUIRES(module:pydot)
                >>> # xdoctest: +REQUIRES(module:ubelt)
                >>> # define module
                >>> class MyModule(torch.nn.Module):
                >>>     def __init__(self) -> None:
                >>>         super().__init__()
                >>>         self.linear = torch.nn.Linear(4, 5)
                >>>     def forward(self, x):
                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
                >>> module = MyModule()
                >>> # trace the module
                >>> symbolic_traced = torch.fx.symbolic_trace(module)
                >>> # setup output file
                >>> import ubelt as ub
                >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
                >>> fpath = dpath / 'linear.svg'
                >>> # draw the graph
                >>> g = FxGraphDrawer(symbolic_traced, "linear")
                >>> g.get_dot_graph().write_svg(fpath)
            )get_main_dot_graphget_submod_dot_graphrH   submod_names     rK   get_dot_graphzFxGraphDrawer.get_dot_graphp   s0    0 "..00000===rM   c                 &    | j         | j                 S Nr>   r;   rH   s    rK   rP   z FxGraphDrawer.get_main_dot_graph   s    #DJ//rM   c                 0    | j         | j         d|          S )Nr:   rW   rR   s     rK   rQ   z"FxGraphDrawer.get_submod_dot_graph   s     #tz$A$AK$A$ABBrM   c                     | j         S rV   )r>   rX   s    rK   get_all_dot_graphsz FxGraphDrawer.get_all_dot_graphs   s    ##rM   rI   c                    | j         dddd}|j        t          v rt          |j                 |d<   n|                    |j                  }t          t          j        |                                          	                                d d         d          }t          |t          t                    z           |d<   |S )Nz#CAFFE3r(   r)   )r9   r*   r+   r,   r*         )r5   rA   
_COLOR_MAP_pretty_print_targetrG   inthashlibmd5encode	hexdigest_HASH_COLOR_MAPlen)rH   rI   templatetarget_nametarget_hashs        rK   _get_node_stylezFxGraphDrawer._get_node_style   s     -&+&	 H w*$$(247(;%% #77DD!'+k.@.@.B.B"C"C"M"M"O"OPRQRPR"SUWXX(7c/FZFZ8Z([%OrM   modulec                 
   |}t          |j        t                    sJ |j                            d          }|D ]G}t	          ||          s%t          t          |          dz   |z   dz             t          ||          }H|S )N.z does not have attribute !)rC   rG   strsplithasattrRuntimeErrorgetattr)rH   rl   rI   py_objatomsatoms         rK   rB   zFxGraphDrawer._get_leaf_node   s     Fdk3/////K%%c**E / /vt,, &F&AADH3N   !..MrM   rG   c                    t          |t          j        j                  rt          j        |          }n't          |t
                    r|}nt          |          }|                    dd                              dd          S )N{\{}\})rC   rD   nnModuletypenamerp   r   replace)rH   rG   rets      rK   	_typenamezFxGraphDrawer._typename   su    &%(/22 2nV,,FC(( 2)&11
 ;;sE**223>>>rM      full_file_nametruncate_to_last_nc                     |                     d          }t          |          |k    rd                    || d                    S |S )N/)rq   rg   join)rH   r   r   splitss       rK   _shorten_file_namez FxGraphDrawer._shorten_file_name   sP    
 $))#..F6{{000xx(:':';'; <===!!rM   c                    fd}dd|j          d|j         dz   }|j        dk    ru|                     ||          |d|                               z   dz   z  }d	}t	          d
          r&d                    fdj        D                       }||dz   z  }n|d|                     |j                   dz   z  }| j        rE	 t          |j        |j
        |j        d          \  }}	n,# t          $ r |j
        |j        }	}Y nw xY w|j
        |j        }	}t          |          dk    r| ||          z  }t          |	          dk    r| ||	          z  }|dt          |j                   dz   z  }|j                            d          }
||                     |
          z  }|j                            dd           }| |d|j          dz   z  }|d|j         dz   z  }|rP|j        It)          |j                  }|                     |j                  }|d| d|j         d|j         dz   z  }|dz   S )Nc                    t          | t                    rd\  }}d | D             }n;t          | t                    r$d\  }}d |                                 D             }ndS rd |D             }t	          |          dk    rdS |d                    |          z   |z   }t	          |          d	k    r*|                    d
d                              dd          }|                    dd                              dd          S )N)z	|args=(\lz,\n)\lc                 0    g | ]}t          |d           S )r]   max_list_lenr
   .0as     rK   
<listcomp>zSFxGraphDrawer._get_node_label.<locals>._get_str_for_args_kwargs.<locals>.<listcomp>   s%    $Q$Q$Q[%C%C%C$Q$Q$QrM   )z|kwargs={\lz,\n}\lc                 @    g | ]\  }}| d t          |d           S ): r]   r   r   )r   kvs      rK   r   zSFxGraphDrawer._get_node_label.<locals>._get_str_for_args_kwargs.<locals>.<listcomp>   sG     % % % Aq @@AA > > >@@% % %rM    c                     g | ]}d |v|	S )% r   s     rK   r   zSFxGraphDrawer._get_node_label.<locals>._get_str_for_args_kwargs.<locals>.<listcomp>   s    $N$N$N1AQrM   r   z,\n   z\l\nry   rz   r{   r|   )rC   tupledictitemsrg   r   r   )argprefixsuffixarg_strs_listarg_strsr3   s        rK   _get_str_for_args_kwargsz?FxGraphDrawer._get_node_label.<locals>._get_str_for_args_kwargs   s,   c5)) 
%<NFF$Q$QS$Q$Q$QMMT** %>NFF% %$'IIKK% % %MM
 2 + O$N$N$N$N$NM}%%**2!FKK$>$>>G}%%**'//r::BB5"MMH''U33;;CGGGrM   ry   zname=%z	|op_code=
r   r   z\n|r   __constants__c                 :    g | ]}| d t          |           S )r   )rt   )r   cleaf_modules     rK   r   z1FxGraphDrawer._get_node_label.<locals>.<listcomp>   s1    ^^^qA::a!8!8::^^^rM   z|target=T)normalize_to_only_use_kwargsr   z|num_users=tensor_metabuf_metaz|buf=z
|n_origin=z|file=: r{   )r0   rA   rB   r   rr   r   r   rG   r6   r   argskwargs	Exceptionrg   usersmetaget_tensor_meta_to_labeln_originstack_tracer	   r   filelinenocode)rH   rl   rI   r3   r4   r   labelextrar   r   r   r   parsed_stack_tracefnamer   s      `          @rK   _get_node_labelzFxGraphDrawer._get_node_label   s   H H H H H0 B49BBtwBBBBEw-''"11&$??!<!<<vEE;88 !JJ^^^^KD]^^^ E &ADNN4;$?$?AAEII& 
:>'9 KDK^b( ( (ff % > > > (,y$+f>
 $(9dk&Dt99q==55d;;;Ev;;??55f===E8s4:885@@)--66KT//<<<E y}}Z66H#0005889h&799EAA ! hT%5%A%78H%I%I"//0B0GHH_%__*<*C__FXF]__bggg 3;s   	%C/ /D
	D
c                    |dS t          |t                    r|                     |          S t          |t                    r!d}|D ]}||                     |          z  }|S t          |t
                    r3d}|                                D ]}||                     |          z  }|S t          |t                    r!d}|D ]}||                     |          z  }|S t          dt          |                     )Nr   zUnsupported tensor meta type )
rC   r   _stringify_tensor_metalistr   r   valuesr   rs   type)rH   tmresultitemr   s        rK   r   z#FxGraphDrawer._tensor_meta_to_label!  s*   zrB// O222666B%% O ? ?Dd88>>>FFB%% O < <Ad88;;;FFB&& O ? ?Dd88>>>FF"#M488#M#MNNNrM   r   c                    d}t          |d          st          d|           |dt          |j                  z   dz   z  }|dt          t	          |j                            z   dz   z  }|dt          |j                  z   dz   z  }|dt          |j                  z   dz   z  }|j        rF|j	        J d	|j	        v sJ |j	        d	         }|t          j        t          j        hv rG|d
t          |j	        d                   z   dz   z  }|dt          |j	        d                   z   dz   z  }n|t          j        t          j        t          j        hv rj|dt          |j	        d                   z   dz   z  }|dt          |j	        d                   z   dz   z  }|dt          |j	        d                   z   dz   z  }nt!          d|           |dt          |j	        d	                   z   dz   z  }|S )Nr   dtyper   z|dtype=r   z|shape=z|requires_grad=z|stride=qschemez	|q_scale=scalez|q_zero_point=
zero_pointz|q_per_channel_scale=z|q_per_channel_zero_point=z|q_per_channel_axis=axiszUnsupported qscheme: z	|qscheme=)rr   printrp   r   r   r9   requires_gradstrideis_quantizedqparamsrD   per_tensor_affineper_tensor_symmetricper_channel_affineper_channel_symmetric per_channel_affine_float_qparamsrs   )rH   r   r   r   s       rK   r   z$FxGraphDrawer._stringify_tensor_meta8  s#   F2w''  dB)CMM9EAAF)Cbh,@,@@5HHF1C8H4I4IIEQQF*S^^;eCCF Uz--- BJ....*Y//2   3c"*W:M6N6NNQVVVF83rz,?W;X;XX[```FF03>!  
 ?#bjQXFYBZBZZ]bbbFDs2:VbKcGdGddglllF>RZPVEWAXAXX[```FF&'Hw'H'HIII/#bj6K2L2LLuTTMrM   tc                 t    t          |j                  t          t          |j                            z   dz   S )Nr   )rp   r   r   r9   )rH   r   s     rK   _get_tensor_labelzFxGraphDrawer._get_tensor_labelW  s*    qw<<#d17mm"4"44u<<rM   c                     t          j        |d          i }|j        j        D ]!|rj        dk    r                               }t          j        j        fd                     |||          i|}	}
j	        
                    dd          }|D|j        dk    r9|j        }||vrt          j        ||          ||<   |
                    |          }
|
                    |	            fd	}j        d
k    rA                     |          |s)t          t           j        j                  s
 |             #|                                D ]C}|                    dd           |                    dd                               |           D|j        j        D ]L|rj        dk    rj        D ]4}                    t          j        j        |j                             5MS )a  
            Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
            If ignore_parameters_and_buffers is True, the parameters and buffers
            created with the module will not be added as nodes and edges.
            TB)rankdirr   r   r   Nr   )r   c            	         t                                                                                    D ]\  } }j        dz   | z   }t	          |t
          j        j                  r|dz   dz   nd}t          j	        |fdd|z   
                    |          z   dz   it          }                    |                               t          j        |j                             d S )Nrn   z|op_code=get_	parameterzbuffer\lr   ry   r{   )r   named_parametersnamed_buffersr0   rC   rD   r}   	ParameterpydotNoder   r<   add_nodeadd_edgeEdge)	pnameptensorpname1label1
dot_w_node	dot_graphr   rI   rH   s	        rK   get_module_params_or_buffersz;FxGraphDrawer._to_dot.<locals>.get_module_params_or_buffers  s   */#44668Q8Q8S8S+ + J Jw "&S5!8  *'583EFF2F_4{BB!1 
 &+Z"& &"%,1G1G1P1P"PSV"V& /& &

 "**:666!**5:fdi+H+HIIIIJ JrM   r   color	royalbluepenwidth2)r   Dotr?   r@   rA   rk   r   r0   r   r   r   r   Clusterr   rB   rC   rD   rE   rF   r   setadd_subgraphr   r   r   )rH   r/   r0   r1   r2   r3   r4   buf_name_to_subgraphr+   dot_nodecurrent_graphr   buf_namer   subgraphuserr   r   rI   s   `               @@@rK   r=   zFxGraphDrawer._to_dot\  s     	$555I $& $*0 *7 *7! dg&;&;,,T22 :I %)%9%9,Negx%y%y ~C  !*9==T::'H,=,A,A'}H';;;9>xW_9`9`9`,X6$8$<$<X$F$FM&&x000J J J J J J J J$ 7m++"&"5"5lD"I"IK8 7KY^YaYmAnAn 74466607799 1 1Wk222Z---&&x0000$*0 I I! dg&;&; J I ID&&uz$)TY'G'GHHHHI rM   FFTFNFrV   )r   )"__name__
__module____qualname____doc__rD   rE   rF   rp   boolr   rL   r   r   rT   rP   rQ   r   r[   r   rk   r}   r~   rB   r   r   ra   r   r   r   r   r   Tensorr   r=   r   rM   rK   r   r   =   s       	 	 $)27,0&+-1#((	 (	(.(	 (	 !	(	
 ,0(	 &*(	  $(	 &c](	 !(	 (	 (	 (	T	> 	>UY 	> 	> 	> 	>:	0	 	0 	0 	0 	0	Cuy 	C 	C 	C 	C	$S%)^(< 	$ 	$ 	$ 	$	 	$sCx. 	 	 	 	"	(/	16	X_	 	 	 		?C 	?C 	? 	? 	? 	?& '(	" 	"	" !$	" 	" 	" 	"O	H(O	 (-O	 &*	O	
  $O	 O	 O	 O	 O	b	Os 	O 	O 	O 	O.	^ 	 	 	 	 	>	=u| 	= 	= 	= 	= 	=
M	(.M	 M	 !	M	
 ,0M	 &*M	  $M	 YM	 M	 M	 M	 M	 M	rM   c                   d    e Zd Z	 	 	 	 	 	 ddej        j        dedededed	ed
ee         defdZ	dS )r   FTNr/   r0   r1   r2   r3   r4   r5   r6   c	                      t          d          )Nz|FXGraphDrawer requires the pydot package to be installed. Please install pydot through your favorite Python package manager.)rs   )	rH   r/   r0   r1   r2   r3   r4   r5   r6   s	            rK   rL   zFxGraphDrawer.__init__  s     # $Y Z Z ZrM   r   )
r  r  r  rD   rE   rF   rp   r  r   rL   r   rM   rK   r   r     s         (-6;04*/15',Z Z#h2Z Z !%	Z
 04Z *.Z $(Z "*#Z !%Z Z Z Z Z ZrM   )rb   	itertoolsr   typingr   r   r   r   rD   torch.fxtorch.fx._compatibilityr   torch.fx.graphr	   torch.fx.noder
   r   torch.fx.operator_schemasr   torch.fx.passes.shape_propr   r   	HAS_PYDOTModuleNotFoundError__all__r_   rf   r<   r   r   rM   rK   <module>r     s          5 5 5 5 5 5 5 5 5 5 5 5   1 1 1 1 1 1 - - - - - - : : : : : : : : 8 8 8 8 8 8 5 5 5 5 5 5LLLII   IEEE
 
 !" 
  (     Z]%000k k k k k k k 10k k k\  Z	e	4	4	4	Z 	Z 	Z 	Z 	Z 	Z 	Z 
5	4	Z 	Z 	ZZ Zs    A 	AA