Commit 4b099e10 authored by gonciarz's avatar gonciarz

Next step in refactoring

parent 789ac98d
package mosaic.ia; package mosaic.ia;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
...@@ -12,8 +13,6 @@ import ij.ImagePlus; ...@@ -12,8 +13,6 @@ import ij.ImagePlus;
import mosaic.ia.HypothesisTesting.TestResult; import mosaic.ia.HypothesisTesting.TestResult;
import mosaic.ia.Potentials.Potential; import mosaic.ia.Potentials.Potential;
import mosaic.ia.Potentials.PotentialType; import mosaic.ia.Potentials.PotentialType;
import mosaic.ia.gui.DistributionsPlot;
import mosaic.ia.gui.EstimatedPotentialPlot;
import mosaic.ia.gui.Utils; import mosaic.ia.gui.Utils;
import mosaic.utils.Debug; import mosaic.utils.Debug;
import mosaic.utils.math.StatisticsUtils; import mosaic.utils.math.StatisticsUtils;
...@@ -24,12 +23,17 @@ public class Analysis { ...@@ -24,12 +23,17 @@ public class Analysis {
private Potential iPotential; private Potential iPotential;
private DistanceCalculations iDistanceCalculations; private DistanceCalculations iDistanceCalculations;
public double[] iContextQdDistancesGrid;
public double[] iContextQdPdf; private double[] iContextQdDistancesGrid;
public double[] iNearestNeighborDistancesXtoY; private double[] iContextQdPdf;
public double[] iNearestNeighborDistancesXtoYPdf; private double[] iNearestNeighborDistancesXtoY;
private double[] iNearestNeighborDistancesXtoYPdf;
private double[] iObservedModelFitPdPdf;
private List<CmaResult> iCmaResults;
private double[][] iBestPointsFound; private double[][] iBestPointsFound;
private double[] iBestFunctionValue;
private int iBestPointIndex = -1; private int iBestPointIndex = -1;
public void calcDist(double gridSize, double kernelWeightq, double kernelWeightp, float[][][] genMask, ImagePlus iImageX, ImagePlus iImageY) { public void calcDist(double gridSize, double kernelWeightq, double kernelWeightp, float[][][] genMask, ImagePlus iImageX, ImagePlus iImageY) {
...@@ -69,10 +73,11 @@ public class Analysis { ...@@ -69,10 +73,11 @@ public class Analysis {
} }
} }
public void cmaOptimization(List<CmaResult> aResultsOutput, int cmaReRunTimes, boolean aRepetitiveResults) { public void cmaOptimization(int cmaReRunTimes, boolean aRepetitiveResults) {
final FitFunction fitfun = new FitFunction(iContextQdPdf, iContextQdDistancesGrid, iNearestNeighborDistancesXtoYPdf, iNearestNeighborDistancesXtoY, iPotential); final FitFunction fitfun = new FitFunction(iContextQdPdf, iContextQdDistancesGrid, iNearestNeighborDistancesXtoYPdf, iNearestNeighborDistancesXtoY, iPotential);
iBestPointsFound = new double[cmaReRunTimes][iPotential.numOfDimensions()]; iBestPointsFound = new double[cmaReRunTimes][iPotential.numOfDimensions()];
double[] bestFunctionValue = new double[cmaReRunTimes]; iBestFunctionValue = new double[cmaReRunTimes];
iCmaResults = new ArrayList<CmaResult>();
double bestFitness = Double.MAX_VALUE; double bestFitness = Double.MAX_VALUE;
boolean diffFitness = false; boolean diffFitness = false;
...@@ -101,29 +106,27 @@ public class Analysis { ...@@ -101,29 +106,27 @@ public class Analysis {
cma.setFitnessOfMeanX(fitfun.valueOf(cma.getMeanX())); cma.setFitnessOfMeanX(fitfun.valueOf(cma.getMeanX()));
logCmaResultInfo(cma); logCmaResultInfo(cma);
bestFunctionValue[cmaRunNumber] = cma.getBestFunctionValue(); iBestFunctionValue[cmaRunNumber] = cma.getBestFunctionValue();
if (bestFunctionValue[cmaRunNumber] < bestFitness) { if (iBestFunctionValue[cmaRunNumber] < bestFitness) {
if (cmaRunNumber > 0 && bestFitness - bestFunctionValue[cmaRunNumber] > bestFunctionValue[cmaRunNumber] * 0.00001) { if (cmaRunNumber > 0 && bestFitness - iBestFunctionValue[cmaRunNumber] > iBestFunctionValue[cmaRunNumber] * 0.00001) {
diffFitness = true; diffFitness = true;
} }
bestFitness = bestFunctionValue[cmaRunNumber]; bestFitness = iBestFunctionValue[cmaRunNumber];
iBestPointIndex = cmaRunNumber; iBestPointIndex = cmaRunNumber;
} }
iBestPointsFound[cmaRunNumber] = cma.getBestX(); iBestPointsFound[cmaRunNumber] = cma.getBestX();
addNewOutputResult(aResultsOutput, bestFunctionValue[cmaRunNumber], iBestPointsFound[cmaRunNumber]); addNewOutputResult(iCmaResults, iBestFunctionValue[cmaRunNumber], iBestPointsFound[cmaRunNumber]);
} }
logger.debug("Best Parameters Found:" + Debug.getString(iBestPointsFound[iBestPointIndex]) + " fit function value=" + bestFunctionValue[iBestPointIndex]); logger.debug("Best Parameters Found:" + Debug.getString(iBestPointsFound[iBestPointIndex]) + " fit function value=" + iBestFunctionValue[iBestPointIndex]);
if (diffFitness) { if (diffFitness) {
Utils.messageDialog("IA - CMA optimization", "Warning: Optimization returned different results for reruns. The results may not be accurate. Displaying the parameters and the plots corr. to best fitness."); Utils.messageDialog("IA - CMA optimization", "Warning: Optimization returned different results for reruns. The results may not be accurate. Displaying the parameters and the plots corr. to best fitness.");
} }
fitfun.l2Norm(iBestPointsFound[iBestPointIndex]); // to calc pgrid for best params fitfun.l2Norm(iBestPointsFound[iBestPointIndex]); // to calc pgrid for best params
new EstimatedPotentialPlot(iContextQdDistancesGrid, iPotential, iBestPointsFound[iBestPointIndex], bestFunctionValue[iBestPointIndex]).show(); iObservedModelFitPdPdf = fitfun.getObservedModelFitPdPdf() ;
double[] observedModelFitPdPdf = fitfun.getObservedModelFitPdPdf() ; StatisticsUtils.normalizePdf(iObservedModelFitPdPdf, iContextQdDistancesGrid, false);
StatisticsUtils.normalizePdf(observedModelFitPdPdf, iContextQdDistancesGrid, false);
new DistributionsPlot(iContextQdDistancesGrid, observedModelFitPdPdf, iContextQdPdf, iNearestNeighborDistancesXtoYPdf, iPotential, iBestPointsFound[iBestPointIndex], bestFunctionValue[iBestPointIndex]).show();
} }
private void addNewOutputResult(List<CmaResult> aResultsOutput, double aBestFunctionValue, double[] aBestPointFound) { private void addNewOutputResult(List<CmaResult> aResultsOutput, double aBestFunctionValue, double[] aBestPointFound) {
...@@ -272,10 +275,44 @@ public class Analysis { ...@@ -272,10 +275,44 @@ public class Analysis {
return iDistanceCalculations.getMaxXtoYdistance(); return iDistanceCalculations.getMaxXtoYdistance();
} }
public double[] getDistances() { public double[] getContextQdDistancesGrid() {
return iContextQdDistancesGrid;
}
public double[] getContextQdPdf() {
return iContextQdPdf;
}
public double[] getNearestNeighborDistancesXtoY() {
return iNearestNeighborDistancesXtoY; return iNearestNeighborDistancesXtoY;
} }
public double[] getNearestNeighborDistancesXtoYPdf() {
return iNearestNeighborDistancesXtoYPdf;
}
public double[] getObservedModelFitPdPdf() {
return iObservedModelFitPdPdf;
}
/**
* @return best point coordinates found by CMA optimization
*/
public double[] getBestPointFound() {
return iBestPointsFound[iBestPointIndex];
}
/**
* @return best function value found by CMA optimization
*/
public double getBestFunctionValue() {
return iBestFunctionValue[iBestPointIndex];
}
public List<CmaResult> getCmaResults() {
return iCmaResults;
}
public void setPotentialType(Potential potentialType) { public void setPotentialType(Potential potentialType) {
iPotential = potentialType; iPotential = potentialType;
} }
......
...@@ -6,7 +6,6 @@ import java.util.ArrayList; ...@@ -6,7 +6,6 @@ import java.util.ArrayList;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.scijava.vecmath.Point3d; import org.scijava.vecmath.Point3d;
import ij.IJ;
import mosaic.ia.gui.Utils; import mosaic.ia.gui.Utils;
import mosaic.utils.Debug; import mosaic.utils.Debug;
import mosaic.utils.math.NearestNeighborTree; import mosaic.utils.math.NearestNeighborTree;
......
...@@ -6,7 +6,6 @@ import java.util.Random; ...@@ -6,7 +6,6 @@ import java.util.Random;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import ij.IJ;
import mosaic.ia.Potentials.Potential; import mosaic.ia.Potentials.Potential;
import mosaic.ia.gui.Utils; import mosaic.ia.gui.Utils;
......
...@@ -130,8 +130,8 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase { ...@@ -130,8 +130,8 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
int numReRuns = Integer.parseInt(reRuns.getText()); int numReRuns = Integer.parseInt(reRuns.getText());
Potential potential = Potentials.createPotential(getPotential(), iAnalysis.getMinDistance(), iAnalysis.getMaxDistance(), numOfSupportPointsValue, smoothnessValue); Potential potential = Potentials.createPotential(getPotential(), iAnalysis.getMinDistance(), iAnalysis.getMaxDistance(), numOfSupportPointsValue, smoothnessValue);
iAnalysis.setPotentialType(potential); // for the first time iAnalysis.setPotentialType(potential); // for the first time
List<CmaResult> results = new ArrayList<CmaResult>(); iAnalysis.cmaOptimization(numReRuns, false);
iAnalysis.cmaOptimization(results, numReRuns, false); List<CmaResult> results = iAnalysis.getCmaResults();
mosaic.utils.Debug.print(results); mosaic.utils.Debug.print(results);
if (!Interpreter.batchMode) { if (!Interpreter.batchMode) {
final ResultsTable rt = new ResultsTable(); final ResultsTable rt = new ResultsTable();
...@@ -145,6 +145,9 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase { ...@@ -145,6 +145,9 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
} }
rt.updateResults(); rt.updateResults();
rt.show("Results"); rt.show("Results");
new EstimatedPotentialPlot(iAnalysis.getContextQdDistancesGrid(), potential, iAnalysis.getBestPointFound(), iAnalysis.getBestFunctionValue()).show();
new DistributionsPlot(iAnalysis.getContextQdDistancesGrid(), iAnalysis.getObservedModelFitPdPdf(), iAnalysis.getContextQdPdf(), iAnalysis.getNearestNeighborDistancesXtoYPdf(), potential, iAnalysis.getBestPointFound(), iAnalysis.getBestFunctionValue()).show();
} }
} }
...@@ -220,9 +223,10 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase { ...@@ -220,9 +223,10 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
throw new RuntimeException("Unknown tab chosen in IA GUI"); throw new RuntimeException("Unknown tab chosen in IA GUI");
} }
new DistributionsPlot(iAnalysis.iContextQdDistancesGrid, iAnalysis.iContextQdPdf, iAnalysis.iNearestNeighborDistancesXtoYPdf).show(); // Generate plots / info for a user
Utils.plotHistogram("ObservedDistances", iAnalysis.iNearestNeighborDistancesXtoY, Analysis.getOptimBins(iAnalysis.iNearestNeighborDistancesXtoY, 8, iAnalysis.iNearestNeighborDistancesXtoY.length / 8)); new DistributionsPlot(iAnalysis.getContextQdDistancesGrid(), iAnalysis.getContextQdPdf(), iAnalysis.getNearestNeighborDistancesXtoYPdf()).show();
double suggestedKernel = Analysis.calcWekaWeights(iAnalysis.iNearestNeighborDistancesXtoY); Utils.plotHistogram("ObservedDistances", iAnalysis.getNearestNeighborDistancesXtoY(), Analysis.getOptimBins(iAnalysis.getNearestNeighborDistancesXtoY(), 8, iAnalysis.getNearestNeighborDistancesXtoY().length / 8));
double suggestedKernel = Analysis.calcWekaWeights(iAnalysis.getNearestNeighborDistancesXtoY());
Utils.messageDialog("IA - kernel", "Suggested Kernel wt(p): " + suggestedKernel); Utils.messageDialog("IA - kernel", "Suggested Kernel wt(p): " + suggestedKernel);
logger.debug("Suggested kernel wt(p)=" + suggestedKernel); logger.debug("Suggested kernel wt(p)=" + suggestedKernel);
} }
......
...@@ -43,7 +43,7 @@ public class AnalysisTest extends CommonBase { ...@@ -43,7 +43,7 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-10; double epsilon = 1e-10;
assertEquals(0, analysis.getMinDistance(), epsilon); assertEquals(0, analysis.getMinDistance(), epsilon);
assertEquals(2, analysis.getMaxDistance(), epsilon); assertEquals(2, analysis.getMaxDistance(), epsilon);
assertArrayEquals(new double[] {2, 1, 0, 1, 2}, analysis.getDistances(), epsilon); assertArrayEquals(new double[] {2, 1, 0, 1, 2}, analysis.getNearestNeighborDistancesXtoY(), epsilon);
} }
@Test @Test
...@@ -53,10 +53,10 @@ public class AnalysisTest extends CommonBase { ...@@ -53,10 +53,10 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6; double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon); assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon); assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon); assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.HERNQUIST)); analysis.setPotentialType(Potentials.createPotential(PotentialType.HERNQUIST));
List<CmaResult> results = new ArrayList<CmaResult>(); analysis.cmaOptimization(1, true);
analysis.cmaOptimization(results, 1, true); List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6; epsilon = 1e-6;
assertEquals(36.545593028698356, results.get(0).iStrength, epsilon); assertEquals(36.545593028698356, results.get(0).iStrength, epsilon);
...@@ -94,11 +94,11 @@ public class AnalysisTest extends CommonBase { ...@@ -94,11 +94,11 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6; double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon); assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon); assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon); assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.NONPARAM, analysis.getMinDistance(), analysis.getMaxDistance(), 41, 0.1)); analysis.setPotentialType(Potentials.createPotential(PotentialType.NONPARAM, analysis.getMinDistance(), analysis.getMaxDistance(), 41, 0.1));
List<CmaResult> results = new ArrayList<CmaResult>(); analysis.cmaOptimization(1, true);
analysis.cmaOptimization(results, 1, true); List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6; epsilon = 1e-6;
assertEquals(0.0, results.get(0).iStrength, epsilon); assertEquals(0.0, results.get(0).iStrength, epsilon);
...@@ -116,11 +116,11 @@ public class AnalysisTest extends CommonBase { ...@@ -116,11 +116,11 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6; double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon); assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon); assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon); assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.STEP)); analysis.setPotentialType(Potentials.createPotential(PotentialType.STEP));
List<CmaResult> results = new ArrayList<CmaResult>(); analysis.cmaOptimization(1, true);
analysis.cmaOptimization(results, 1, true); List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6; epsilon = 1e-6;
assertEquals(2.4113236274803262, results.get(0).iStrength, epsilon); assertEquals(2.4113236274803262, results.get(0).iStrength, epsilon);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment