changeset 5:f508e566dc78

autodiff: add autoconv and const
author Lewin Bormann <lbo@spheniscida.de>
date Thu, 23 Dec 2021 14:28:30 +0100
parents f88dda95d735
children 828857591bb6
files autodiff.py
diffstat 1 files changed, 32 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/autodiff.py	Thu Dec 23 08:17:55 2021 +0100
+++ b/autodiff.py	Thu Dec 23 14:28:30 2021 +0100
@@ -27,16 +27,33 @@
     def bw(self, grad):
         pass
 
+    def _autoconv(self, e):
+        if isinstance(e, Expression):
+            return e
+        if type(e) in (int, float):
+            return Const(e)
+        return e
+
     def __add__(self, other):
-        return OpPlus(self, other)
+        return OpPlus(self, self._autoconv(other))
     def __sub__(self, other):
-        return OpMinus(self, other)
-    def __neg__(self, other):
-        return Num(name=self.name, id=self.id)
+        return OpMinus(self, self._autoconv(other))
+    def __neg__(self):
+        return (self._autoconv(0)-self)
     def __mul__(self, other):
-        return OpMult(self, other)
+        return OpMult(self, self._autoconv(other))
     def __truediv__(self, other):
-        return OpDiv(self, other)
+        return OpDiv(self, self._autoconv(other))
+
+class Const(Expression):
+    def __init__(self, val):
+        self.v = val
+
+    def fw(self):
+        return self.v
+
+    def bw(self,grad):
+        pass
 
 class OpPlus(Expression):
     def fw(self):
@@ -120,6 +137,10 @@
 
 def jacobian(f, at):
     """Returns function value and jacobian."""
+    if type(at) not in (tuple, list, np.ndarray):
+        at = [at]
+    if type(f) not in (tuple, list, np.ndarray):
+        f = [f]
     j = np.zeros((len(f), len(at)))
     val = np.zeros(len(f))
     for i, ff in enumerate(f):
@@ -177,8 +198,12 @@
     z = np.array([sqrt(log(e)) for e in y])
     return z
 
+@gradify
+def pres_calculation(x1, x2, x3):
+    return x1*x2 + exp(x1*x3)*cos(x2)
+
 before = time.time_ns()
-print(complex_calculation(1,4,5))
+print(pres_calculation(1,4,5))
 after = time.time_ns()
 print((after-before)/1e9)