Commit 4b099e10 authored by gonciarz's avatar gonciarz

Next step in refactoring

parent 789ac98d
package mosaic.ia;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
......@@ -12,8 +13,6 @@ import ij.ImagePlus;
import mosaic.ia.HypothesisTesting.TestResult;
import mosaic.ia.Potentials.Potential;
import mosaic.ia.Potentials.PotentialType;
import mosaic.ia.gui.DistributionsPlot;
import mosaic.ia.gui.EstimatedPotentialPlot;
import mosaic.ia.gui.Utils;
import mosaic.utils.Debug;
import mosaic.utils.math.StatisticsUtils;
......@@ -24,12 +23,17 @@ public class Analysis {
private Potential iPotential;
private DistanceCalculations iDistanceCalculations;
public double[] iContextQdDistancesGrid;
public double[] iContextQdPdf;
public double[] iNearestNeighborDistancesXtoY;
public double[] iNearestNeighborDistancesXtoYPdf;
private double[] iContextQdDistancesGrid;
private double[] iContextQdPdf;
private double[] iNearestNeighborDistancesXtoY;
private double[] iNearestNeighborDistancesXtoYPdf;
private double[] iObservedModelFitPdPdf;
private List<CmaResult> iCmaResults;
private double[][] iBestPointsFound;
private double[] iBestFunctionValue;
private int iBestPointIndex = -1;
public void calcDist(double gridSize, double kernelWeightq, double kernelWeightp, float[][][] genMask, ImagePlus iImageX, ImagePlus iImageY) {
......@@ -69,10 +73,11 @@ public class Analysis {
}
}
public void cmaOptimization(List<CmaResult> aResultsOutput, int cmaReRunTimes, boolean aRepetitiveResults) {
public void cmaOptimization(int cmaReRunTimes, boolean aRepetitiveResults) {
final FitFunction fitfun = new FitFunction(iContextQdPdf, iContextQdDistancesGrid, iNearestNeighborDistancesXtoYPdf, iNearestNeighborDistancesXtoY, iPotential);
iBestPointsFound = new double[cmaReRunTimes][iPotential.numOfDimensions()];
double[] bestFunctionValue = new double[cmaReRunTimes];
iBestFunctionValue = new double[cmaReRunTimes];
iCmaResults = new ArrayList<CmaResult>();
double bestFitness = Double.MAX_VALUE;
boolean diffFitness = false;
......@@ -101,29 +106,27 @@ public class Analysis {
cma.setFitnessOfMeanX(fitfun.valueOf(cma.getMeanX()));
logCmaResultInfo(cma);
bestFunctionValue[cmaRunNumber] = cma.getBestFunctionValue();
if (bestFunctionValue[cmaRunNumber] < bestFitness) {
if (cmaRunNumber > 0 && bestFitness - bestFunctionValue[cmaRunNumber] > bestFunctionValue[cmaRunNumber] * 0.00001) {
iBestFunctionValue[cmaRunNumber] = cma.getBestFunctionValue();
if (iBestFunctionValue[cmaRunNumber] < bestFitness) {
if (cmaRunNumber > 0 && bestFitness - iBestFunctionValue[cmaRunNumber] > iBestFunctionValue[cmaRunNumber] * 0.00001) {
diffFitness = true;
}
bestFitness = bestFunctionValue[cmaRunNumber];
bestFitness = iBestFunctionValue[cmaRunNumber];
iBestPointIndex = cmaRunNumber;
}
iBestPointsFound[cmaRunNumber] = cma.getBestX();
addNewOutputResult(aResultsOutput, bestFunctionValue[cmaRunNumber], iBestPointsFound[cmaRunNumber]);
addNewOutputResult(iCmaResults, iBestFunctionValue[cmaRunNumber], iBestPointsFound[cmaRunNumber]);
}
logger.debug("Best Parameters Found:" + Debug.getString(iBestPointsFound[iBestPointIndex]) + " fit function value=" + bestFunctionValue[iBestPointIndex]);
logger.debug("Best Parameters Found:" + Debug.getString(iBestPointsFound[iBestPointIndex]) + " fit function value=" + iBestFunctionValue[iBestPointIndex]);
if (diffFitness) {
Utils.messageDialog("IA - CMA optimization", "Warning: Optimization returned different results for reruns. The results may not be accurate. Displaying the parameters and the plots corr. to best fitness.");
}
fitfun.l2Norm(iBestPointsFound[iBestPointIndex]); // to calc pgrid for best params
new EstimatedPotentialPlot(iContextQdDistancesGrid, iPotential, iBestPointsFound[iBestPointIndex], bestFunctionValue[iBestPointIndex]).show();
double[] observedModelFitPdPdf = fitfun.getObservedModelFitPdPdf() ;
StatisticsUtils.normalizePdf(observedModelFitPdPdf, iContextQdDistancesGrid, false);
new DistributionsPlot(iContextQdDistancesGrid, observedModelFitPdPdf, iContextQdPdf, iNearestNeighborDistancesXtoYPdf, iPotential, iBestPointsFound[iBestPointIndex], bestFunctionValue[iBestPointIndex]).show();
iObservedModelFitPdPdf = fitfun.getObservedModelFitPdPdf() ;
StatisticsUtils.normalizePdf(iObservedModelFitPdPdf, iContextQdDistancesGrid, false);
}
private void addNewOutputResult(List<CmaResult> aResultsOutput, double aBestFunctionValue, double[] aBestPointFound) {
......@@ -272,10 +275,44 @@ public class Analysis {
return iDistanceCalculations.getMaxXtoYdistance();
}
public double[] getDistances() {
public double[] getContextQdDistancesGrid() {
return iContextQdDistancesGrid;
}
public double[] getContextQdPdf() {
return iContextQdPdf;
}
public double[] getNearestNeighborDistancesXtoY() {
return iNearestNeighborDistancesXtoY;
}
public double[] getNearestNeighborDistancesXtoYPdf() {
return iNearestNeighborDistancesXtoYPdf;
}
public double[] getObservedModelFitPdPdf() {
return iObservedModelFitPdPdf;
}
/**
* @return best point coordinates found by CMA optimization
*/
public double[] getBestPointFound() {
return iBestPointsFound[iBestPointIndex];
}
/**
* @return best function value found by CMA optimization
*/
public double getBestFunctionValue() {
return iBestFunctionValue[iBestPointIndex];
}
public List<CmaResult> getCmaResults() {
return iCmaResults;
}
public void setPotentialType(Potential potentialType) {
iPotential = potentialType;
}
......
......@@ -6,7 +6,6 @@ import java.util.ArrayList;
import org.apache.log4j.Logger;
import org.scijava.vecmath.Point3d;
import ij.IJ;
import mosaic.ia.gui.Utils;
import mosaic.utils.Debug;
import mosaic.utils.math.NearestNeighborTree;
......
......@@ -6,7 +6,6 @@ import java.util.Random;
import org.apache.log4j.Logger;
import ij.IJ;
import mosaic.ia.Potentials.Potential;
import mosaic.ia.gui.Utils;
......
......@@ -130,8 +130,8 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
int numReRuns = Integer.parseInt(reRuns.getText());
Potential potential = Potentials.createPotential(getPotential(), iAnalysis.getMinDistance(), iAnalysis.getMaxDistance(), numOfSupportPointsValue, smoothnessValue);
iAnalysis.setPotentialType(potential); // for the first time
List<CmaResult> results = new ArrayList<CmaResult>();
iAnalysis.cmaOptimization(results, numReRuns, false);
iAnalysis.cmaOptimization(numReRuns, false);
List<CmaResult> results = iAnalysis.getCmaResults();
mosaic.utils.Debug.print(results);
if (!Interpreter.batchMode) {
final ResultsTable rt = new ResultsTable();
......@@ -145,6 +145,9 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
}
rt.updateResults();
rt.show("Results");
new EstimatedPotentialPlot(iAnalysis.getContextQdDistancesGrid(), potential, iAnalysis.getBestPointFound(), iAnalysis.getBestFunctionValue()).show();
new DistributionsPlot(iAnalysis.getContextQdDistancesGrid(), iAnalysis.getObservedModelFitPdPdf(), iAnalysis.getContextQdPdf(), iAnalysis.getNearestNeighborDistancesXtoYPdf(), potential, iAnalysis.getBestPointFound(), iAnalysis.getBestFunctionValue()).show();
}
}
......@@ -220,9 +223,10 @@ public class InteractionAnalysisGui extends InteractionAnalysisGuiBase {
throw new RuntimeException("Unknown tab chosen in IA GUI");
}
new DistributionsPlot(iAnalysis.iContextQdDistancesGrid, iAnalysis.iContextQdPdf, iAnalysis.iNearestNeighborDistancesXtoYPdf).show();
Utils.plotHistogram("ObservedDistances", iAnalysis.iNearestNeighborDistancesXtoY, Analysis.getOptimBins(iAnalysis.iNearestNeighborDistancesXtoY, 8, iAnalysis.iNearestNeighborDistancesXtoY.length / 8));
double suggestedKernel = Analysis.calcWekaWeights(iAnalysis.iNearestNeighborDistancesXtoY);
// Generate plots / info for a user
new DistributionsPlot(iAnalysis.getContextQdDistancesGrid(), iAnalysis.getContextQdPdf(), iAnalysis.getNearestNeighborDistancesXtoYPdf()).show();
Utils.plotHistogram("ObservedDistances", iAnalysis.getNearestNeighborDistancesXtoY(), Analysis.getOptimBins(iAnalysis.getNearestNeighborDistancesXtoY(), 8, iAnalysis.getNearestNeighborDistancesXtoY().length / 8));
double suggestedKernel = Analysis.calcWekaWeights(iAnalysis.getNearestNeighborDistancesXtoY());
Utils.messageDialog("IA - kernel", "Suggested Kernel wt(p): " + suggestedKernel);
logger.debug("Suggested kernel wt(p)=" + suggestedKernel);
}
......
......@@ -43,7 +43,7 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-10;
assertEquals(0, analysis.getMinDistance(), epsilon);
assertEquals(2, analysis.getMaxDistance(), epsilon);
assertArrayEquals(new double[] {2, 1, 0, 1, 2}, analysis.getDistances(), epsilon);
assertArrayEquals(new double[] {2, 1, 0, 1, 2}, analysis.getNearestNeighborDistancesXtoY(), epsilon);
}
@Test
......@@ -53,10 +53,10 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.HERNQUIST));
List<CmaResult> results = new ArrayList<CmaResult>();
analysis.cmaOptimization(results, 1, true);
analysis.cmaOptimization(1, true);
List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6;
assertEquals(36.545593028698356, results.get(0).iStrength, epsilon);
......@@ -94,11 +94,11 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.NONPARAM, analysis.getMinDistance(), analysis.getMaxDistance(), 41, 0.1));
List<CmaResult> results = new ArrayList<CmaResult>();
analysis.cmaOptimization(results, 1, true);
analysis.cmaOptimization(1, true);
List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6;
assertEquals(0.0, results.get(0).iStrength, epsilon);
......@@ -116,11 +116,11 @@ public class AnalysisTest extends CommonBase {
double epsilon = 1e-6;
assertEquals(0.418188, analysis.getMinDistance(), epsilon);
assertEquals(112.924864, analysis.getMaxDistance(), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getDistances()), epsilon);
assertEquals(77.722986, Analysis.calcWekaWeights(analysis.getNearestNeighborDistancesXtoY()), epsilon);
analysis.setPotentialType(Potentials.createPotential(PotentialType.STEP));
List<CmaResult> results = new ArrayList<CmaResult>();
analysis.cmaOptimization(results, 1, true);
analysis.cmaOptimization(1, true);
List<CmaResult> results = analysis.getCmaResults();
epsilon = 1e-6;
assertEquals(2.4113236274803262, results.get(0).iStrength, epsilon);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment