changeset 8:f71ade2784c4

gad: add string representation function and improve gradify
author Lewin Bormann <lbo@spheniscida.de>
date Mon, 27 Dec 2021 23:31:40 +0100
parents 9ffcd727ea5d
children 84d09d282bdb
files gad.py
diffstat 1 files changed, 53 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/gad.py	Mon Dec 27 15:58:27 2021 +0100
+++ b/gad.py	Mon Dec 27 23:31:40 2021 +0100
@@ -55,6 +55,9 @@
         self.g[i] = 1
         self.ade = ade
 
+    def __str__(self):
+        return f"<{self.i}>"
+
     def fw(self, v):
         return v[self.i]
 
@@ -66,6 +69,9 @@
         self.v = v
         self.n = n
 
+    def __str__(self):
+        return str(self.v)
+
     def fw(self, v):
         return self.v
 
@@ -73,6 +79,10 @@
         return np.zeros(self.n)[:,None]
 
 class OpPlus(Expression):
+
+    def __str__(self):
+        return f'({str(self.l)}) + ({str(self.r)})'
+
     def fw(self, v):
         return self.l.fw(v) + self.r.fw(v)
 
@@ -80,6 +90,10 @@
         return self.l.bw() + self.r.bw()
 
 class OpMinus(Expression):
+
+    def __str__(self):
+        return f'({str(self.l)}) - ({str(self.r)})'
+
     def fw(self, v):
         return self.l.fw(v) - self.r.fw(v)
 
@@ -87,6 +101,9 @@
         return self.l.bw() - self.r.bw()
 
 class OpMult(Expression):
+    def __str__(self):
+        return f'({str(self.l)}) * ({str(self.r)})'
+
     def fw(self, v):
         self.eval_l = self.l.fw(v)
         self.eval_r = self.r.fw(v)
@@ -97,6 +114,9 @@
         return self.l.bw()*self.eval_r + self.r.bw()*self.eval_l
 
 class OpDiv(Expression):
+    def __str__(self):
+        return f'({str(self.l)}) / ({str(self.r)})'
+
     def fw(self, v):
         self.eval_l = self.l.fw(v)
         self.eval_r = self.r.fw(v)
@@ -111,6 +131,9 @@
         return J @ g
 
 class OpPow(Expression):
+    def __str__(self):
+        return f'({str(self.l)})^({str(self.r)})'
+
     def fw(self, v):
         self.eval_l = self.l.fw(v)
         self.eval_r = self.r.fw(v)
@@ -127,6 +150,10 @@
         return np.hstack((gl, gr)) @ thisgrad
 
 class UnaryExpression(Expression):
+    
+    def __str__(self):
+        return f'{str(self.op)}({str(self.l)})'
+
     def __init__(self, op, dop, e):
         self.l = e
         self.op = op
@@ -192,23 +219,38 @@
     def __exit__(self):
         pass
 
-def gradify(f):
+def gradify(f, debug_expr=False):
     """Decorate a function in order to automatically obtain its Jacobian.
 
     The wrapped function will return a tuple (value, jacobian).
 
     Additionally, the computational graph is cached automatically, accelerating repeated invocations.
     """
-    c = {}
-    def newf(*xs):
-        if 'ade' not in c:
-            ade = ADE(len(xs))
-            newxs = ade.vars()
-            expr = f(*newxs)
-            c['expr'] = expr
-            c['ade'] = ade
-        return c['ade'].grad(c['expr'], xs)
-    return newf
+    class Call:
+        def __init__(self, f, debug_expr):
+            self.f = f
+            self.c = {}
+            self.debug_expr = debug_expr
+
+        def orig(self):
+            return self.f
+
+        def __call__(self, *xs):
+            c = self.c
+            if 'ade' not in c:
+                ade = ADE(len(xs))
+                newxs = ade.vars()
+                expr = self.f(*newxs)
+                if self.debug_expr:
+                    if type(expr) in (list, tuple, np.ndarray):
+                        print([str(e) for e in expr])
+                    else:
+                        print(str(expr))
+                c['expr'] = expr
+                c['ade'] = ade
+            return c['ade'].grad(c['expr'], xs)
+
+    return Call(f, debug_expr)
 
 ade = ADE(3)
 [x,y,z] = ade.vars()