Mercurial > lbo > hg > autodiff
changeset 5:f508e566dc78
autodiff: add autoconv and const
author | Lewin Bormann <lbo@spheniscida.de> |
---|---|
date | Thu, 23 Dec 2021 14:28:30 +0100 |
parents | f88dda95d735 |
children | 828857591bb6 |
files | autodiff.py |
diffstat | 1 files changed, 32 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/autodiff.py Thu Dec 23 08:17:55 2021 +0100 +++ b/autodiff.py Thu Dec 23 14:28:30 2021 +0100 @@ -27,16 +27,33 @@ def bw(self, grad): pass + def _autoconv(self, e): + if isinstance(e, Expression): + return e + if type(e) in (int, float): + return Const(e) + return e + def __add__(self, other): - return OpPlus(self, other) + return OpPlus(self, self._autoconv(other)) def __sub__(self, other): - return OpMinus(self, other) - def __neg__(self, other): - return Num(name=self.name, id=self.id) + return OpMinus(self, self._autoconv(other)) + def __neg__(self): + return (self._autoconv(0)-self) def __mul__(self, other): - return OpMult(self, other) + return OpMult(self, self._autoconv(other)) def __truediv__(self, other): - return OpDiv(self, other) + return OpDiv(self, self._autoconv(other)) + +class Const(Expression): + def __init__(self, val): + self.v = val + + def fw(self): + return self.v + + def bw(self,grad): + pass class OpPlus(Expression): def fw(self): @@ -120,6 +137,10 @@ def jacobian(f, at): """Returns function value and jacobian.""" + if type(at) not in (tuple, list, np.ndarray): + at = [at] + if type(f) not in (tuple, list, np.ndarray): + f = [f] j = np.zeros((len(f), len(at))) val = np.zeros(len(f)) for i, ff in enumerate(f): @@ -177,8 +198,12 @@ z = np.array([sqrt(log(e)) for e in y]) return z +@gradify +def pres_calculation(x1, x2, x3): + return x1*x2 + exp(x1*x3)*cos(x2) + before = time.time_ns() -print(complex_calculation(1,4,5)) +print(pres_calculation(1,4,5)) after = time.time_ns() print((after-before)/1e9)