{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing AMReX Outputs From FLEKS\n",
    "\n",
    "In this demo, we show how to analyze native AMReX field and particle outputs. Example data can be downloaded as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import flekspy\n",
    "from flekspy.util import download_testfile\n",
    "\n",
    "url = \"https://raw.githubusercontent.com/henry2004y/batsrus_data/master/fleks_particle_small.tar.gz\"\n",
    "download_testfile(url, \"data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inheriting from the IDL procedures, we can pass strings to limit the plotting range (experimental, may be changed in the future):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = flekspy.load(\"data/fleks_particle_small/3d*amrex\")\n",
    "dc = ds.get_slice(\"z\", 0.001)\n",
    "\n",
    "f, axes = dc.plot(\n",
    "    \"Bx<(1.5e4)>(-1.5e4) ({rhos0}+{rhos1})<(1.3e21) pS0 uxs0\", figsize=(12, 8)\n",
    ")\n",
    "dc.add_stream(axes[0], \"Bx\", \"By\", color=\"w\")\n",
    "dc.add_contour(axes[1], \"Bx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Velocity Space Distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flekspy.util import unit_one\n",
    "\n",
    "filename = \"data/fleks_particle_small/cut_*amrex\"\n",
    "ds = flekspy.load(filename)\n",
    "\n",
    "# Add a user defined field. See yt document for more information about derived field.\n",
    "ds.add_field(\n",
    "    ds.pvar(\"unit_one\"), function=unit_one, sampling_type=\"particle\", units=\"code_mass\"\n",
    ")\n",
    "\n",
    "x_field = \"p_uy\"\n",
    "y_field = \"p_uz\"\n",
    "z_field = \"unit_one\"\n",
    "# z_field = \"p_w\"\n",
    "xleft = [-0.001, -0.001, -0.001]\n",
    "xright = [0.001, 0.001, 0.001]\n",
    "\n",
    "### Select and plot the particles inside a box defined by xleft and xright\n",
    "region = ds.box(xleft, xright)\n",
    "pp = ds.plot_phase(\n",
    "    x_field,\n",
    "    y_field,\n",
    "    z_field,\n",
    "    region=region,\n",
    "    unit_type=\"si\",\n",
    "    x_bins=64,\n",
    "    y_bins=64,\n",
    "    domain_size=(-0.0002, 0.0002, -0.0002, 0.0002),\n",
    ")\n",
    "\n",
    "pp.set_cmap(pp.fields[0], \"turbo\")\n",
    "\n",
    "# plot.set_zlim(plot.fields[0], zmin, zmax)\n",
    "pp.set_xlabel(r\"$V_y$\")\n",
    "pp.set_ylabel(r\"$V_z$\")\n",
    "# pp.set_colorbar_label(pp.fields[0], \"pw\")\n",
    "pp.set_title(pp.fields[0], \"Number density\")\n",
    "pp.set_font(\n",
    "    {\n",
    "        \"size\": 34,\n",
    "        \"family\": \"DejaVu Sans\",\n",
    "    }\n",
    ")\n",
    "pp.set_log(pp.fields[0], False)\n",
    "pp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you need the direct phase space distributions together with the axis,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y, w = ds.get_phase(\n",
    "    x_field,\n",
    "    y_field,\n",
    "    z_field,\n",
    "    region=region,\n",
    "    x_bins=64,\n",
    "    y_bins=64,\n",
    "    domain_size=(-0.0002, 0.0002, -0.0002, 0.0002),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check the newly added field,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ad = ds.all_data()\n",
    "ad[(\"particles\", \"unit_one\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the location of particles that are inside a sphere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "center = [0, 0, 0]\n",
    "radius = 0.001\n",
    "z_field = \"unit_one\"\n",
    "# Object sphere is defined in yt/data_objects/selection_objects/spheroids.py\n",
    "sp = ds.sphere(center, radius)\n",
    "pp = ds.plot_particles(\n",
    "    \"p_x\", \"p_y\", z_field, region=sp, unit_type=\"planet\", x_bins=64, y_bins=64\n",
    ")\n",
    "pp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the phase space of particles that are inside a sphere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pp = ds.plot_phase(\n",
    "    \"p_uy\", \"p_uz\", z_field, region=sp, unit_type=\"planet\", x_bins=64, y_bins=64\n",
    ")\n",
    "pp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the location of particles that are inside a disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "center = [0, 0, 0]\n",
    "normal = [1, 1, 0]\n",
    "radius = 0.0005\n",
    "height = 0.0004\n",
    "z_field = \"unit_one\"\n",
    "# Object sphere is defined in yt/data_objects/selection_objects/disk.py\n",
    "disk = ds.disk(center, normal, radius, height)\n",
    "pp = ds.plot_particles(\n",
    "    \"p_x\", \"p_y\", z_field, region=disk, unit_type=\"planet\", x_bins=64, y_bins=64\n",
    ")\n",
    "pp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the phase space of particles that are inside a disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pp = ds.plot_phase(\n",
    "    \"p_uy\", \"p_uz\", z_field, region=disk, unit_type=\"planet\", x_bins=64, y_bins=64\n",
    ")\n",
    "pp.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform the velocity coordinates and visualize the phase space distribution\n",
    "\n",
    "WIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import flekspy, yt\n",
    "\n",
    "l = [1, 0, 0]\n",
    "m = [0, 1, 0]\n",
    "n = [0, 0, 1]\n",
    "\n",
    "\n",
    "def _vel_l(field, data):\n",
    "    res = (\n",
    "        l[0] * data[(\"particles\", \"p_ux\")]\n",
    "        + l[1] * data[(\"particles\", \"p_uy\")]\n",
    "        + l[2] * data[(\"particles\", \"p_uz\")]\n",
    "    )\n",
    "    return res\n",
    "\n",
    "\n",
    "def _vel_m(field, data):\n",
    "    res = (\n",
    "        m[0] * data[(\"particles\", \"p_ux\")]\n",
    "        + m[1] * data[(\"particles\", \"p_uy\")]\n",
    "        + m[2] * data[(\"particles\", \"p_uz\")]\n",
    "    )\n",
    "    return res\n",
    "\n",
    "\n",
    "def _vel_n(field, data):\n",
    "    res = (\n",
    "        n[0] * data[(\"particles\", \"p_ux\")]\n",
    "        + n[1] * data[(\"particles\", \"p_uy\")]\n",
    "        + n[2] * data[(\"particles\", \"p_uz\")]\n",
    "    )\n",
    "    return res\n",
    "\n",
    "\n",
    "filename = \"data/fleks_particle_small/cut_*amrex\"\n",
    "ds = flekspy.load(filename)\n",
    "\n",
    "# Add a user defined field. See yt document for more information about derived field.\n",
    "vl_name = ds.pvar(\"vel_l\")\n",
    "vm_name = ds.pvar(\"vel_m\")\n",
    "vn_name = ds.pvar(\"vel_n\")\n",
    "ds.add_field(vl_name, units=\"code_velocity\", function=_vel_l, sampling_type=\"particle\")\n",
    "ds.add_field(vm_name, units=\"code_velocity\", function=_vel_m, sampling_type=\"particle\")\n",
    "ds.add_field(vn_name, units=\"code_velocity\", function=_vel_n, sampling_type=\"particle\")\n",
    "\n",
    "######## Plot the location of particles that are inside a sphere ###########\n",
    "center = [0, 0, 0]\n",
    "radius = 0.001\n",
    "# Object sphere is defined in yt/data_objects/selection_objects/spheroids.py\n",
    "sp = ds.sphere(center, radius)\n",
    "\n",
    "x_field = vl_name\n",
    "y_field = vm_name\n",
    "z_field = ds.pvar(\"p_w\")\n",
    "\n",
    "logs = {x_field: False, y_field: False}\n",
    "profile = yt.create_profile(\n",
    "    data_source=sp,\n",
    "    bin_fields=[x_field, y_field],\n",
    "    fields=z_field,\n",
    "    n_bins=[64, 64],\n",
    "    weight_field=None,\n",
    "    logs=logs,\n",
    ")\n",
    "\n",
    "pp = yt.PhasePlot.from_profile(profile)\n",
    "\n",
    "pp.set_unit(x_field, \"km/s\")\n",
    "pp.set_unit(y_field, \"km/s\")\n",
    "pp.set_unit(z_field, \"amu\")\n",
    "\n",
    "pp.set_cmap(pp.fields[0], \"turbo\")\n",
    "# pp.set_zlim(pp.fields[0], zmin, zmax)\n",
    "pp.set_xlabel(r\"$V_l$\")\n",
    "pp.set_ylabel(r\"$V_m$\")\n",
    "pp.set_colorbar_label(pp.fields[0], \"colorbar_label\")\n",
    "pp.set_title(pp.fields[0], \"Density\")\n",
    "pp.set_font(\n",
    "    {\n",
    "        \"size\": 34,\n",
    "        \"family\": \"DejaVu Sans\",\n",
    "    }\n",
    ")\n",
    "pp.set_log(pp.fields[0], False)\n",
    "pp.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}