{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dolfin as df\n",
    "import matplotlib.pyplot as plt\n",
    "import mshr as ms\n",
    "import numpy as np\n",
    "import time\n",
    "df.set_log_level(40)\n",
    "\n",
    "# domain = ms.Sphere(df.Point(0, 0, 0), 1.0)\n",
    "# mesh = ms.generate_mesh(domain, 50)\n",
    "mesh = df.UnitIntervalMesh(10000)\n",
    "F = df.FunctionSpace(mesh, 'CG', 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_sim(c0, c_tot, Ga0, dt, n_t, sym):\n",
    "    tc = df.TestFunction(F)\n",
    "    c = df.Function(F)\n",
    "    X = df.SpatialCoordinate(mesh)\n",
    "#     X.interpolate(df.Expression('x[0]', degree=1))\n",
    "    if sym == 0:\n",
    "    # Weak form 1D:\n",
    "        form = ((df.inner((c-c0)/dt, tc) +\n",
    "             df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*tc))) -\n",
    "             df.inner(df.grad(c_tot), df.grad((1-c_tot)*Ga0/c_tot*c*tc))-\n",
    "             tc*df.inner(df.grad(c), df.grad((1-c_tot)*Ga0))+\n",
    "             tc*df.inner(df.grad(c_tot), df.grad((1-c_tot)/c_tot*c*Ga0))) * df.dx\n",
    "    elif sym == 2:\n",
    "    # Weak form radial symmetry:\n",
    "        form = ((df.inner((c-c0)/dt, tc*X[0]*X[0]) +\n",
    "             df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*tc*X[0]*X[0]))) -\n",
    "             df.inner(df.grad(c_tot), df.grad((1-c_tot)*Ga0/c_tot*c*tc*X[0]*X[0]))-\n",
    "             tc*df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*X[0]*X[0]))+\n",
    "             tc*df.inner(df.grad(c_tot), df.grad((1-c_tot)/c_tot*c*Ga0*X[0]*X[0]))-\n",
    "             (1-c_tot)*Ga0*2*X[0]*c.dx(0)*tc+\n",
    "             (1-c_tot)*Ga0/c_tot*c*2*X[0]*c_tot.dx(0)*tc) * df.dx\n",
    "  # Weak form radial symmetry:\n",
    "#     form = ((df.inner((c-c0)/dt, tc*X[0]*X[0]) +\n",
    "#              df.inner(df.grad(c), df.grad((1-c_tot+Ga0*c_tot*c_tot)*tc*X[0]*X[0]))) -\n",
    "#              df.inner(df.grad(c_tot), df.grad((1-c_tot+Ga0*c_tot*c_tot)/c_tot*c*tc*X[0]*X[0]))-\n",
    "#              tc*df.inner(df.grad(c), df.grad((1-c_tot+Ga0*c_tot*c_tot)*X[0]*X[0]))+\n",
    "#              tc*df.inner(df.grad(c_tot), df.grad((1-c_tot+Ga0*c_tot*c_tot)/c_tot*c*X[0]*X[0]))-\n",
    "#              (1-c_tot+Ga0*c_tot*c_tot)*2*X[0]*c.dx(0)*tc+\n",
    "#              (1-c_tot+Ga0*c_tot*c_tot)/c_tot*c*2*X[0]*c_tot.dx(0)*tc) * df.dx\n",
    "    \n",
    "    t = 0\n",
    "    # Solve in time\n",
    "    ti = time.time()\n",
    "    for i in range(n_t):\n",
    "        df.solve(form == 0, c)\n",
    "        df.assign(c0, c)\n",
    "        t += dt\n",
    "    print(time.time() - ti)\n",
    "    return c0\n",
    "\n",
    "def p_tot(p_i, p_o):\n",
    "    return str(p_i-p_o)+'*(-0.5*tanh(35000*(x[0]-0.1))+0.5)+'+str(p_o)\n",
    "#     return '(x[0]<0.1 ? '+str(p_i)+':'+ str(p_o)+')'\n",
    "\n",
    "def create_func(f_space, expr_str, deg):\n",
    "    f = df.Function(f_space)\n",
    "    f.interpolate(df.Expression(expr_str, degree=deg))\n",
    "    return f\n",
    "\n",
    "def eval_func(func, x):\n",
    "    return np.array([func([x]) for x in x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_tot_1 = create_func(F, p_tot(0.9, 0.9), 1)\n",
    "c0_1 = create_func(F, 'x[0]<0.1 ? 0 :' + p_tot(0.9, 0.9), 1)\n",
    "Ga0_1 = create_func(F, p_tot(1, 1/9), 1)\n",
    "\n",
    "c_tot_9 = create_func(F, p_tot(0.9, 0.1), 1)\n",
    "c0_9 = create_func(F, 'x[0]<0.1 ? 0 :'+p_tot(0.9, 0.1), 1)\n",
    "Ga0_9 = create_func(F,p_tot(1, 1), 1)\n",
    "\n",
    "c0_1 = calc_sim(c0_1, c_tot_1, Ga0_1, 0.001, 10, 0)\n",
    "c0_9 = calc_sim(c0_9, c_tot_9, Ga0_9, 0.001, 10, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 1, 10000)\n",
    "plt.plot(x, eval_func(c0_1, x))\n",
    "plt.plot(x, eval_func(c0_9, x))\n",
    "plt.xlim(0, 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set B=1\n",
    "pi = 0.8\n",
    "po = 0.5 - np.sqrt(0.25 + (pi**2-pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p1_i = 0.9; p1_o = 0.1\n",
    "p2_i = 0.8; p2_o = 0.2\n",
    "p3_i = 0.7; p3_o = 0.3\n",
    "p4_i = 0.9; p4_o = 0.9\n",
    "\n",
    "ct_1 = create_func(F, p_tot(p1_i, p1_o), 1)\n",
    "ct_2 = create_func(F, p_tot(p2_i, p2_o), 1)\n",
    "ct_3 = create_func(F, p_tot(p3_i, p3_o), 1)\n",
    "ct_4 = create_func(F, p_tot(p4_i, p4_o), 1)\n",
    "\n",
    "g_1= create_func(F, '1', 1)\n",
    "g_2= create_func(F, '0.5', 1)\n",
    "g_3= create_func(F, '0.33333333333333333333', 1)\n",
    "g_4= create_func(F, '1', 1)\n",
    "\n",
    "c0_1 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p1_i, p1_o), 1)\n",
    "c0_2 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p2_i, p2_o), 1)\n",
    "c0_3 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p3_i, p3_o), 1)\n",
    "c0_4 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p4_i, p4_o), 1)\n",
    "for i in range(10):\n",
    "    c0_1 = calc_sim(c0_1, ct_1, g_1, 0.001, 2, 2)\n",
    "#     c0_2 = calc_sim(c0_2, ct_2, g_2, 0.01, 2, 2)\n",
    "#     c0_3 = calc_sim(c0_3, ct_3, g_3, 0.01, 2, 2)\n",
    "    c0_4 = calc_sim(c0_4, ct_4, g_4, 0.001, 2, 2)\n",
    "    plt.plot(x, eval_func(c0_1, x), 'r')#/eval_func(c0_1, [0.099])\n",
    "#     plt.plot(x, eval_func(c0_2, x), 'g')\n",
    "#     plt.plot(x, eval_func(c0_3, x)/eval_func(c0_3, [0.099]), 'b')\n",
    "    plt.plot(x, eval_func(c0_4, x), 'k')\n",
    "    plt.xlim(0.0, 0.625)\n",
    "    plt.ylim(0, 1.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(x, eval_func(c0_1, x)/0.9)\n",
    "plt.plot(x, eval_func(c0_2, x)/0.8)\n",
    "plt.plot(x, eval_func(c0_3, x)/0.7)\n",
    "plt.xlim(0.0, 0.125)\n",
    "plt.ylim(0, 1.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Radial diffusion equation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = df.UnitIntervalMesh(1000)\n",
    "dt = 0.001\n",
    "F = df.FunctionSpace(mesh, 'CG', 1)\n",
    "c0 = df.Function(F)\n",
    "c0.interpolate(df.Expression('x[0]<0.5 && x[0]>0.2 ? 1:0', degree=1))\n",
    "q = df.TestFunction(F)\n",
    "c = df.Function(F)\n",
    "X = df.SpatialCoordinate(mesh)\n",
    "g = df.Expression('.00', degree=1)\n",
    "u_D = df.Expression('1', degree=1)\n",
    "def boundary(x, on_boundary):\n",
    "    return on_boundary\n",
    "bc = df.DirichletBC(F, u_D, boundary)\n",
    "# Weak form spherical symmetry\n",
    "form = (df.inner((c-c0)/dt, q*X[0]*X[0]) +\n",
    "        df.inner(df.grad(c), df.grad(X[0]*X[0]*q))-\n",
    "        c.dx(0)*2*X[0]*q) * df.dx\n",
    "# Weak form 1D\n",
    "# form = (df.inner((c-c0)/dt, q) +\n",
    "#         df.inner(df.grad(c), df.grad(q))) * df.dx\n",
    "t = 0\n",
    "\n",
    "# Solve in time\n",
    "for i in range(60):\n",
    "    print(np.sum([x*x*c0([x]) for x in np.linspace(0, 1, 1000)]))\n",
    "    df.solve(form == 0, c)\n",
    "    df.assign(c0, c)\n",
    "    t += dt\n",
    "    plt.plot(np.linspace(0, 1, 1000), [c0([x]) for x in np.linspace(0, 1, 1000)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}