{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dolfin as df\n",
    "import matplotlib.pyplot as plt\n",
    "import mshr as ms\n",
    "import numpy as np\n",
    "import time\n",
    "df.set_log_level(40)\n",
    "\n",
    "# domain = ms.Sphere(df.Point(0, 0, 0), 1.0)\n",
    "# mesh = ms.generate_mesh(domain, 50)\n",
    "mesh = df.UnitIntervalMesh(10000)\n",
    "F = df.FunctionSpace(mesh, 'CG', 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_sim(c0, c_tot, Ga0, dt, n_t, sym):\n",
    "    tc = df.TestFunction(F)\n",
    "    c = df.Function(F)\n",
    "    X = df.SpatialCoordinate(mesh)\n",
    "#     X.interpolate(df.Expression('x[0]', degree=1))\n",
    "    if sym == 0:\n",
    "    # Weak form 1D:\n",
    "#         form = (df.inner((c-c0)/dt, tc) +\n",
    "#              df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*tc)) -\n",
    "#              df.inner(df.grad(c_tot), df.grad((1-c_tot)*Ga0/c_tot*c*tc))-\n",
    "#              tc*df.inner(df.grad(c), df.grad((1-c_tot)*Ga0))+\n",
    "#              tc*df.inner(df.grad(c_tot), df.grad((1-c_tot)/c_tot*c*Ga0))) * df.dx\n",
    "#     # Weak form 1D short:\n",
    "        form = ((c-c0)/dt*tc + (1-c_tot)*Ga0*\n",
    "                df.inner((df.grad(c)-c/c_tot*df.grad(c_tot)),\n",
    "                          df.grad(tc))) * df.dx\n",
    "    elif sym == 2:\n",
    "    # Weak form radial symmetry:\n",
    "#         form = (df.inner((c-c0)/dt, tc*X[0]*X[0]) +\n",
    "#              df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*tc*X[0]*X[0])) -\n",
    "#              df.inner(df.grad(c_tot), df.grad((1-c_tot)*Ga0/c_tot*c*tc*X[0]*X[0]))-\n",
    "#              tc*df.inner(df.grad(c), df.grad((1-c_tot)*Ga0*X[0]*X[0]))+\n",
    "#              tc*df.inner(df.grad(c_tot), df.grad((1-c_tot)/c_tot*c*Ga0*X[0]*X[0]))-\n",
    "#              (1-c_tot)*Ga0*2*X[0]*c.dx(0)*tc+\n",
    "#              (1-c_tot)*Ga0/c_tot*c*2*X[0]*c_tot.dx(0)*tc) * df.dx\n",
    "#     # Weak form radial symmetry:\n",
    "#         form = ((c-c0)/dt*tc*X[0]**2 +\n",
    "#                 (1-c_tot)*Ga0*(c.dx(0)-c/c_tot*c_tot.dx(0))*\n",
    "#                          tc.dx(0)*X[0]**2) * df.dx\n",
    "        form = ((c-c0)/dt*tc + (1-c_tot)*Ga0*\n",
    "                df.inner((df.grad(c)-c/c_tot*df.grad(c_tot)),\n",
    "                          df.grad(tc)))*X[0]*X[0]*df.dx\n",
    "    t = 0\n",
    "    # Solve in time\n",
    "    ti = time.time()\n",
    "    for i in range(n_t):\n",
    "#         print(np.sum([x*x*c0([x]) for x in np.linspace(0, 1, 1000)]))\n",
    "        df.solve(form == 0, c)\n",
    "        df.assign(c0, c)\n",
    "        t += dt\n",
    "    print(time.time() - ti)\n",
    "    return c0\n",
    "\n",
    "def p_tot(p_i, p_o):\n",
    "    return str(p_i-p_o)+'*(-0.5*tanh(3500*(x[0]-0.1))+0.5)+'+str(p_o)\n",
    "#     return '(x[0]<0.1 ? '+str(p_i)+':'+ str(p_o)+')'\n",
    "\n",
    "def create_func(f_space, expr_str, deg):\n",
    "    f = df.Function(f_space)\n",
    "    f.interpolate(df.Expression(expr_str, degree=deg))\n",
    "    return f\n",
    "\n",
    "def eval_func(func, x):\n",
    "    return np.array([func([x]) for x in x])\n",
    "\n",
    "def eval_P(func, x):\n",
    "    return func(x[0])/func(x[1])\n",
    "\n",
    "def eval_D(func_ga, func_tot, x):\n",
    "    return(func_ga(x)*(1-func_tot(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bp = 0.1 # Boundary position at IC\n",
    "sym = 0 # Symmetry of the problem\n",
    "dt = 0.001\n",
    "nt = 10\n",
    "\n",
    "c_tot_1 = create_func(F, p_tot(0.9, 0.9), 1)\n",
    "c0_1 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' + p_tot(0.9, 0.9), 1)\n",
    "Ga0_1 = create_func(F, p_tot(1, 1/9), 1)\n",
    "\n",
    "c_tot_2 = create_func(F, p_tot(0.9, 0.45), 1)\n",
    "c0_2 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' +p_tot(0.9, 0.45), 1)\n",
    "Ga0_2 = create_func(F,p_tot(1, 0.08), 1)\n",
    "\n",
    "c_tot_3 = create_func(F, p_tot(0.9, 0.3), 1)\n",
    "c0_3 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' +p_tot(0.9, 0.3), 1)\n",
    "Ga0_3 = create_func(F,p_tot(1, 0.145), 1)\n",
    "\n",
    "c_tot_5 = create_func(F, p_tot(0.9, 0.18), 1)\n",
    "c0_5 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' +p_tot(0.9, 0.18), 1)\n",
    "Ga0_5 = create_func(F,p_tot(1, 0.34), 1)\n",
    "\n",
    "c_tot_8 = create_func(F, p_tot(0.9, 0.1125), 1)\n",
    "c0_8 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' +p_tot(0.9, 0.1125), 1)\n",
    "Ga0_8 = create_func(F,p_tot(1, 0.8), 1)\n",
    "\n",
    "c_tot_9 = create_func(F, p_tot(0.9, 0.1), 1)\n",
    "c0_9 = create_func(F, 'x[0]<'+str(bp)+'? 0 :' +p_tot(0.9, 0.1), 1)\n",
    "Ga0_9 = create_func(F,p_tot(1, 1), 1)\n",
    "\n",
    "c0_1 = calc_sim(c0_1, c_tot_1, Ga0_1, dt, nt, sym)\n",
    "# c0_2 = calc_sim(c0_2, c_tot_2, Ga0_2, dt, nt, sym)\n",
    "# c0_3 = calc_sim(c0_3, c_tot_3, Ga0_3, dt, nt, sym)\n",
    "# c0_5 = calc_sim(c0_5, c_tot_5, Ga0_5, dt, nt, sym)\n",
    "# c0_8 = calc_sim(c0_8, c_tot_8, Ga0_8, dt, nt, sym)\n",
    "c0_9 = calc_sim(c0_9, c_tot_9, Ga0_9, dt, nt, sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 1, 10000)\n",
    "plt.plot(x, eval_func(c0_1, x))\n",
    "plt.plot(x, eval_func(c0_2, x))\n",
    "plt.plot(x, eval_func(c0_3, x))\n",
    "plt.plot(x, eval_func(c0_5, x))\n",
    "plt.plot(x, eval_func(c0_8, x))\n",
    "plt.plot(x, eval_func(c0_9, x))\n",
    "# plt.xlim(0.08, 0.125125)\n",
    "# plt.ylim(0., 0.325125)\n",
    "plt.show()\n",
    "# Diffusion versus partitioning\n",
    "list_Ga = [Ga0_1, Ga0_2, Ga0_3, Ga0_5, Ga0_8, Ga0_9]\n",
    "list_Tot = [c_tot_1, c_tot_2, c_tot_3, c_tot_5, c_tot_8, c_tot_9]\n",
    "D = [eval_D(Ga, Tot, 1) for (Ga, Tot) in zip (list_Ga, list_Tot)]\n",
    "P = [eval_P(Tot, [0, 1]) for Tot in list_Tot]\n",
    "plt.plot(D, P)\n",
    "plt.ylim(0, 11); plt.xlim(0, 1)\n",
    "plt.ylabel('P'); plt.xlabel('Diffusion coefficient')\n",
    "plt.plot(np.linspace(0, 1, 100), 9.5*(np.linspace(0, 1, 100))**(1/2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot outside, check invariance\n",
    "x = np.linspace(bp, 1, 1000)\n",
    "y1 = eval_func(c0_1, x)\n",
    "y2 = eval_func(c0_2, x)\n",
    "y3 = eval_func(c0_3, x)\n",
    "y9 = eval_func(c0_9, x)\n",
    "plt.plot((x-bp), y1)\n",
    "plt.plot((x-bp)/2, 2*y2)\n",
    "plt.plot((x-bp)/3, 3*y3)\n",
    "plt.plot((x-bp)/9, 9*y9)\n",
    "plt.xlim(0.0, 0.02)\n",
    "plt.ylim(0.2, 1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot inside, check invariance\n",
    "x = np.linspace(0, 0.1, 1000)\n",
    "y1 = eval_func(c0_1, x)-np.min(eval_func(c0_1, x))\n",
    "y2 = eval_func(c0_2, x)-np.min(eval_func(c0_2, x))\n",
    "y3 = eval_func(c0_3, x)-np.min(eval_func(c0_3, x))\n",
    "y9 = eval_func(c0_9, x)-np.min(eval_func(c0_9, x))\n",
    "plt.plot((x-np.min(x)), y1/eval_func(c0_1, [0.08]))\n",
    "plt.plot((x-np.min(x)), y2/eval_func(c0_2, [0.08]))\n",
    "plt.plot((x-np.min(x)), y3/eval_func(c0_3, [0.08]))\n",
    "plt.plot((x-np.min(x)), y9/eval_func(c0_9, [0.08]))\n",
    "plt.xlim(0.08, bp)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check partitioning\n",
    "cp = eval_func(c0_9, x)\n",
    "print('Partitioning: ' + str(np.max(cp[1:1050])/np.min(cp[999:1050])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set B=1\n",
    "pi = 0.8\n",
    "po = 0.5 - np.sqrt(0.25 + (pi**2-pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p1_i = 0.9; p1_o = 0.1\n",
    "p2_i = 0.8; p2_o = 0.2\n",
    "p3_i = 0.7; p3_o = 0.3\n",
    "p4_i = 0.9; p4_o = 0.9\n",
    "\n",
    "ct_1 = create_func(F, p_tot(p1_i, p1_o), 1)\n",
    "ct_2 = create_func(F, p_tot(p2_i, p2_o), 1)\n",
    "ct_3 = create_func(F, p_tot(p3_i, p3_o), 1)\n",
    "ct_4 = create_func(F, p_tot(p4_i, p4_o), 1)\n",
    "\n",
    "g_1= create_func(F, '1', 1)\n",
    "g_2= create_func(F, '0.5', 1)\n",
    "g_3= create_func(F, '0.33333333333333333333', 1)\n",
    "g_4= create_func(F, '1', 1)\n",
    "\n",
    "c0_1 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p1_i, p1_o), 1)\n",
    "c0_2 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p2_i, p2_o), 1)\n",
    "c0_3 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p3_i, p3_o), 1)\n",
    "c0_4 = create_func(F, 'x[0]<0.081 ? 0 :'+p_tot(p4_i, p4_o), 1)\n",
    "for i in range(10):\n",
    "    c0_1 = calc_sim(c0_1, ct_1, g_1, 0.001, 2, 2)\n",
    "#     c0_2 = calc_sim(c0_2, ct_2, g_2, 0.01, 2, 2)\n",
    "#     c0_3 = calc_sim(c0_3, ct_3, g_3, 0.01, 2, 2)\n",
    "    c0_4 = calc_sim(c0_4, ct_4, g_4, 0.001, 2, 2)\n",
    "    plt.plot(x, eval_func(c0_1, x), 'r')#/eval_func(c0_1, [0.099])\n",
    "#     plt.plot(x, eval_func(c0_2, x), 'g')\n",
    "#     plt.plot(x, eval_func(c0_3, x)/eval_func(c0_3, [0.099]), 'b')\n",
    "    plt.plot(x, eval_func(c0_4, x), 'k')\n",
    "    plt.xlim(0.0, 0.625)\n",
    "    plt.ylim(0, 1.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(x, eval_func(c0_1, x)/0.9)\n",
    "plt.plot(x, eval_func(c0_2, x)/0.8)\n",
    "plt.plot(x, eval_func(c0_3, x)/0.7)\n",
    "plt.xlim(0.0, 0.125)\n",
    "plt.ylim(0, 1.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison with Matlab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ML_9 = np.genfromtxt('Matlab_workspaces/ML_9_1.csv', delimiter=',')\n",
    "ML_1 = np.genfromtxt('Matlab_workspaces/ML_1_9.csv', delimiter=',')\n",
    "plt.plot(ML_9[0, :], ML_9[-1, :])\n",
    "plt.plot(x, eval_func(c0_9, x))\n",
    "plt.xlim(0, 0.5)\n",
    "# plt.show()\n",
    "\n",
    "plt.plot(ML_1[0, :], ML_1[-1, :])\n",
    "plt.plot(x, eval_func(c0_1, x))\n",
    "plt.xlim(0, 0.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Radial diffusion equation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = df.UnitIntervalMesh(1000)\n",
    "dt = 0.001\n",
    "F = df.FunctionSpace(mesh, 'CG', 1)\n",
    "c0 = df.Function(F)\n",
    "c0.interpolate(df.Expression('x[0]<0.5 && x[0]>0.2 ? 1:0', degree=1))\n",
    "q = df.TestFunction(F)\n",
    "c = df.Function(F)\n",
    "X = df.SpatialCoordinate(mesh)\n",
    "g = df.Expression('.00', degree=1)\n",
    "u_D = df.Expression('1', degree=1)\n",
    "def boundary(x, on_boundary):\n",
    "    return on_boundary\n",
    "bc = df.DirichletBC(F, u_D, boundary)\n",
    "# Weak form spherical symmetry\n",
    "form = ((c-c0)/dt*q*X[0]*X[0] +\n",
    "        X[0]*X[0]*df.inner(df.grad(c), df.grad(q))-\n",
    "        c.dx(0)*2*X[0]*q) * df.dx\n",
    "# Weak form 1D\n",
    "# form = ((c-c0)/dt* q + df.inner(df.grad(c), df.grad(q))) * df.dx\n",
    "# Weak form 1D with .dx(0) notation for derivative in 1st direction.\n",
    "# form = ((c-c0)/dt*q + c.dx(0)*q.dx(0)) * df.dx\n",
    "t = 0\n",
    "\n",
    "# Solve in time\n",
    "for i in range(60):\n",
    "    print(np.sum([x*x*c0([x]) for x in np.linspace(0, 1, 1000)]))\n",
    "    df.solve(form == 0, c)\n",
    "    df.assign(c0, c)\n",
    "    t += dt\n",
    "    plt.plot(np.linspace(0, 1, 1000), [c0([x]) for x in np.linspace(0, 1, 1000)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}