C:/cmcintos/defOrgs/source/physical/Phys_Euler.cxx

00001 #ifndef _PHYS_EULER_CXX
00002 #define _PHYS_EULER_CXX
00003 
00004 #include "Phys_Euler.h"
00005 #include "itkImageRegionIterator.h"
00006 namespace mial
00007 {
00008 
00009         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00010         Phys_Euler<DataType, TGradientImage, nDims,MType,VType>
00011                 ::Phys_Euler(int numNodes ,int numSprings,int numPossibleDeformations,int defK)
00012                 :Physics<DataType,nDims,MType,VType>()
00013         {
00014                 nodes.set_size(numNodes,nDims);
00015 
00016                 nodesV.set_size(numNodes,nDims);
00017 
00018                 nodesF.set_size(numNodes,nDims);
00019 
00020                 nodesA.set_size(numNodes,nDims);
00021 
00022                 nodesM.set_size(numNodes);
00023 
00024                 springsRest.set_size(numSprings);
00025 
00026                 springsDamp.set_size(numSprings);
00027 
00028                 springsNodes.set_size(numSprings,2);
00029 
00030                 springLengths.set_size(numSprings);
00031 
00032                 springsK.set_size(numSprings);
00033 
00034                 timeStep = 0.005;
00035 
00036                 defaultK = defK;
00037                 defaultDamp = 1;
00038                 defaultMass = 1;
00039 
00040                 defaultDrag = 1;
00041 
00042                 gradientPointer = GradientImageType::New();
00043 
00044                 imageForces = true;
00045         }
00046 
00047 
00048         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00049         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setRestLengths(int* a, DataType* values, int n)
00050         {
00051                 if(this->geom->didTopologyChange())
00052                         updateSpringsFromGeometric();
00053                 for(int i=0; i<n; i++)
00054                 {
00055                         springsRest(a[i]) = values[i];
00056                 }
00057         }
00058 
00059 
00060 
00061         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00062         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setRestLengths(VectorType a, VectorType values)
00063         {
00064                 if(this->geom->didTopologyChange())
00065                         updateSpringsFromGeometric();
00066                 for(int i=0; i<a.size(); i++)
00067                 {
00068                         springsRest(a(i)) = values(i);
00069                 }
00070         }
00071 
00072 
00073         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00074         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setSpringLengths(int* a, DataType* values, int n)
00075         {// array a is an array of index values, and array values are the values to set them too.
00076 
00077                 if(this->geom->didTopologyChange())
00078                         updateSpringsFromGeometric();
00079                 for(int i=0; i<n; i++)
00080                 {
00081                         springLengths(a[i]) = values[i];
00082                 }
00083         }
00084 
00085 
00086         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00087         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setSpringLengths(VectorType a, VectorType values)
00088         {
00089                 if(this->geom->didTopologyChange())
00090                         updateSpringsFromGeometric();
00091                 for(int i=0; i<a.size(); i++)
00092                 {
00093                         springLengths(a(i)) = values(i);
00094                 }
00095         }
00096 
00097 
00098         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00099         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setSpringsK(int* a, DataType* values, int n)
00100         {// array a is an array of index values, and array values are the values to set them too.
00101 
00102                 if(this->geom->didTopologyChange())
00103                         updateSpringsFromGeometric();
00104                 for(int i=0; i<n; i++)
00105                 {
00106                         springsK(a[i]) = values[i];
00107                 }
00108         }
00109 
00110 
00111         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00112         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::setSpringsK(VectorType a, VectorType values)
00113         {
00114                 if(this->geom->didTopologyChange())
00115                         updateSpringsFromGeometric();
00116                 for(int i=0; i<a.size(); i++)
00117                 {
00118                         springsK(a(i)) = values(i);
00119                 }
00120         }
00121 
00122         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00123         bool Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::runDeformation(const std::string defName,typename DeformationType::deformationIn* const arg ,std::stringstream * const stream)
00124         {
00125                 //TODO make this more efficient--only update if change, only update new springs etc
00126                 nodes = this->geom->getMatrixNodePositions();
00127                 if(this->geom->didTopologyChange())
00128                         updateSpringsFromGeometric();
00129                 if(this->geom->didNodesChange())
00130                 {
00131                         nodesF.set_size(this->geom->getNumNodes(),nDims);
00132                         nodesF.fill(0);
00133                         nodesA.set_size(this->geom->getNumNodes(),nDims);
00134                         nodesA.fill(0);
00135                         nodesV.set_size(this->geom->getNumNodes(),nDims);
00136                         nodesV.fill(0);
00137                         nodesFDef.set_size(this->geom->getNumNodes(),nDims);
00138                         nodesFDef.fill(0);
00139                         nodesM.set_size(this->geom->getNumNodes());
00140                         nodesM.fill(defaultMass);
00141                 }
00142 
00143                 //std::string defName;
00144                 //args >> defName; //read the first string from the stream
00145                 // TODO: replace O(n) search with lg(n) search on sorted list.
00146                 for(int i=0; i< this->numDeformations; i++)
00147                 {
00148                         if( defName.compare(this->deformationsList[i]->getName()) == 0 )
00149                         {
00150                                 try
00151                                 {
00152                                         typename DeformationType::DefArgSet org;
00153                                         org.nodes = &nodes;
00154                                         org.nodesV = &nodesV;
00155                                         org.nodesF = &nodesF;
00156                                         org.nodesFDef = &nodesFDef;
00157                                         org.springsRest = &springsRest;
00158                                         org.springsNodes = &springsNodes;
00159                                         org.springLengths = &springLengths;
00160 
00161                                         this->deformationsList[i]->run(arg,&org,stream);
00162                                         break;
00163                                 }catch( typename DeformationType::Error * de)
00164                                 {
00165                                         std::cerr << "Could not complete deformation" << std::endl;
00166                                         Error e;
00167                                         e.msg = "Could not complete deformation";
00168                                         e.deformationNumber = i;
00169                                         e.deformationError = de;
00170                                         throw & e;
00171                                 }
00172                         }
00173                 }
00174                 return false;
00175 
00176         }
00177         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00178         void Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::updateSpringsFromGeometric()
00179         {
00180                 unsigned int length = springsNodes.rows();
00181                 if(this->geom->getNumConnections() != length) //if there are new springs
00182                 {
00183                         unsigned int newLength = this->geom->getNumConnections();
00184 
00185                         springsRest.set_size(newLength);
00186                         springsDamp.set_size(newLength);
00187                         springLengths.set_size(newLength);
00188                         springsK.set_size(newLength);
00189                         springsNodes.set_size(newLength,2);
00190 
00191                         this->springsNodes = this->geom->getMatrixConnections();
00192                         int from;
00193                         int to;
00194                         for(int i=0; i<newLength; i++)
00195                         {
00196                                 from = springsNodes(i,0);
00197                                 to = springsNodes(i,1);
00198                                 if(DEBUG)
00199                                         std::cout << "from: " << from << " to: " << to << " i " << i << std::endl;
00200                                 springsRest(i) = (nodes.get_row(from)-nodes.get_row(to)).two_norm();
00201                                 springsDamp(i) = defaultDamp;
00202                                 springLengths(i) = (nodes.get_row(from)-nodes.get_row(to)).two_norm();
00203                                 springsK(i) =defaultK;
00204                         }
00205 
00206                         // TODO write method for dynamically changing topology that preserves current settings
00207                         // This will need to search the set of current connections for existing springs and find those
00208                         // still in use
00209                         /*
00210                         //Store current values in a temp variable
00211 
00212                         vnl_Vector tmp(length*4);
00213 
00214                         tmp.update(springsRest,0);
00215                         tmp.update(springsDamp, length);
00216                         tmp.update(springLengths, length*2);
00217                         tmp.update(springsK, length*3);
00218 
00219                         vnl_matrix<int> tmp2(length,2);
00220                         tmp2.update(springsNodes,0,0);
00221 
00222                         //Increase the size
00223                         springsRest.set_size(newLength);
00224                         springsDamp.set_size(newLength);
00225                         springLengths.set_size(newLength);
00226                         springsK.set_size(newLength);
00227                         springsNodes.set_size(newLength,2);
00228 
00229                         //Copy back the old values
00230                         for(int i =0; i< length; i++)
00231                         {
00232                         springsRest(i) = tmp(i);
00233 
00234                         //springsDamp.update(tmp.get_n_rows(nodes.rows(),nodes.rows()) ,0);
00235                         springsDamp(i) = tmp(i+length);
00236 
00237                         //springLengths.update(tmp.get_n_rows(nodes.rows()*2,nodes.rows()) ,0);
00238                         springLengths(i) = tmp(i+2*length);
00239 
00240                         //springsK.update(tmp.get_n_rows(nodes.rows()*3,nodes.rows()) ,0);
00241                         springsK(i) = tmp(i+3*length);
00242 
00243                         }
00244                         springsNodes.update(tmp2,0);
00245 
00246                         //Add the new springs
00247                         int from =0;
00248                         int to =0;
00249                         for(int i=length; i<newLength; i++)
00250                         {
00251                         from = 0;
00252                         springsRest(i) = (nodes.get_row(from)-nodes.get_row(to)).two_norm();
00253                         springsDamp(i) = defaultDamp;
00254 
00255                         springLengths(length) = (nodes.get_row(from)-nodes.get_row(to)).two_norm();
00256 
00257                         springsK(length) =defaultK;
00258 
00259                         springsNodes(length,0) = from;
00260                         springsNodes(length,1) = to;*/
00261                 }
00262         }
00263 
00264         template<class DataType, class TGradientImage, int nDims,class MType, class VType>
00265         bool Phys_Euler<DataType, TGradientImage, nDims,MType,VType>::simulate()
00266         {
00267                 //**************************//
00268                 //Simulate the spring forces
00269                 //**************************//
00270 
00271                 //std::cout << "Begin deform" << std::endl;
00272                 //TODO make this more efficient--only update if change, only update new springs etc
00273                 nodes = this->geom->getMatrixNodePositions();
00274                 if(this->geom->didTopologyChange())
00275                         updateSpringsFromGeometric();
00276                 if(this->geom->didNodesChange())
00277                 {
00278                         nodesF.set_size(this->geom->getNumNodes(),nDims);
00279                         nodesF.fill(0);
00280                         nodesA.set_size(this->geom->getNumNodes(),nDims);
00281                         nodesA.fill(0);
00282                         nodesV.set_size(this->geom->getNumNodes(),nDims);
00283                         nodesV.fill(0);
00284                         nodesFDef.set_size(this->geom->getNumNodes(),nDims);
00285                         nodesFDef.fill(0);
00286                         nodesM.set_size(this->geom->getNumNodes());
00287                         nodesM.fill(defaultMass);
00288                 }
00289 
00290                 const unsigned int length = springsNodes.rows();
00291 
00292                 VectorType nodeDist(nDims);
00293                 DataType distNorm;
00294                 VectorType velDiff(nDims);
00295                 DataType D(length);
00296                 VectorType F(nDims);
00297                 VectorType pointForce(nDims);
00298                 VectorType tmpVel(nDims);
00299                 VectorType tmpDist(nDims);
00300                 VectorType tmpPos(nDims);
00301                 //Can't seem to get working with length
00302                 /*      vnl_matrix_fixed<DataType, 4 ,nDims> nodeDist;
00303                 vnl_vector_fixed<DataType,length> distNorm;
00304                 vnl_matrix_fixed<DataType, 4,nDims> velDiff;
00305                 vnl_vector_fixed<DataType, 4> D;
00306                 vnl_matrix_fixed<DataType, 4, nDims> F;*/
00307                 int a,b;
00308                 DataType rowSum = 0;
00309                 int count = 0;
00310                 int runTime = 25;
00311                 int numNodes = nodesF.rows();
00312                 typename GradientImageType::RegionType gradientImageRegion = gradientPointer->GetLargestPossibleRegion();
00313                 itk::ImageRegionIterator<GradientImageType> gradIT(gradientPointer,gradientImageRegion);
00314                 typename GradientImageType::IndexType gradIndex;
00315 
00316                 typedef itk::VectorLinearInterpolateImageFunction< GradientImageType, DataType > InterpolatorType;
00317                 typename InterpolatorType::Pointer interp = InterpolatorType::New();
00318                 typename InterpolatorType::OutputType imgForce;
00319                 interp->SetInputImage(gradientPointer);
00320 
00321                 //      springsRest(0) = 15;
00322                 bool incTime = true;
00323                 DataType dispLength =0;
00324                 DataType localTimeStep = timeStep;
00325 
00326                 DataType endTime = this->time+0.5;
00327                 //while(count<runTime)
00328                 while(this->time< endTime)
00329                 {
00330                         nodesF.fill(0);
00331                         if(count==0)
00332                         {
00333                                 nodesF = nodesF + nodesFDef; //Add on any deformation forces, then zero them
00334                                 nodesFDef.fill(0);
00335                         }
00336                         //Loop over all springs
00337                         for(int i=0; i<length ;i++)
00338                         {
00339 
00340                                 //TODO: Finalize code, vectorize or loop
00341 
00342                                 //A = nodesXYZ(springsNodes(:,2),:)  -  nodesXYZ(springsNodes(:,1),:);
00343                                 //B = sum(nodeDist.^2,2).^0.5; //%norm(Xa-Xb) for all springs (numSprings*1)
00344                                 //C = nodesV(springsNodes(:,2),:)  -  nodesV(springsNodes(:,1),:);
00345                                 nodeDist = nodes.get_row(springsNodes(i,1)) - nodes.get_row(springsNodes(i,0));
00346                                 distNorm = nodeDist.two_norm();
00347                                 velDiff = nodesV.get_row(springsNodes(i,1)) - nodesV.get_row(springsNodes(i,0));
00348 
00349                                 //D = sum(C.*nodeDist,2) ./ B ;// %[(Va-Vb)'(Xa-Xb)]/norm(Xa-Xb)  for all springs (numSprings*1) 
00350                                 for(int j=0; j<nDims;j++)
00351                                 {
00352                                         rowSum+= velDiff(j)*nodeDist(j);
00353                                 }
00354                                 D = rowSum/distNorm;
00355                                 rowSum =0;
00356 
00357                                 //F = repmat((springsK.*(B-springsRest)+springsDamp.*D)./B , [1 3]).*nodeDist; //%spring forces for all springs (numSprings*2)
00358                                 F = (( (springsK(i)* (distNorm-springsRest(i)) + springsDamp(i)*D)/ distNorm ) )* nodeDist;
00359 
00360                                 /*      for k=1:numSprings,  // %assign spring forces to nodes
00361                                 a=springsNodes(k,1);
00362                                 b=springsNodes(k,2);
00363                                 nodesF(a,:)=nodesF(a,:) + F(k,:);
00364                                 nodesF(b,:)=nodesF(b,:) - F(k,:);
00365                                 end*/
00366 
00367                                 a = springsNodes(i,0);
00368                                 b = springsNodes(i,1);
00369                                 nodesF.set_row(a,nodesF.get_row(a)+ F);
00370                                 nodesF.set_row(b,nodesF.get_row(b)- F);
00371                         }
00372                         //Loop over all nodes
00373                         for(int i=0; i<numNodes; i++)
00374                         {
00375                                 //**************************//
00376                                 //Calculate Image forces
00377                                 //**************************//
00378                                 //nodesF=nodesF+ repmat(nodesExt.* nodesBnd , [1,2]) .* [...
00379                                 //     diag(Gx(round(nodesXY(:,2)),round(nodesXY(:,1)))) , ...
00380                                 //   diag(Gy(round(nodesXY(:,2)),round(nodesXY(:,1)))) ];
00381 
00382                                 if(DEBUG)
00383                                 {       std::cout << "node: " << i << " x: " << nodes(i,0) << " y: " << nodes(i,1) << " z: " << nodes(i,2) << std::endl;
00384                                 std::cout << "Force: " << i << " x: " << nodesF(i,0) << " y: " << nodesF(i,1) << " z: " << nodesF(i,2) << std::endl;
00385                                 }
00386 
00387                                 if(imageForces)
00388                                 {
00389                                         for(int a =0; a<nDims; a++)
00390                                         {gradIndex[a] = nodes(i,a);
00391                                         }
00392 
00393                                         //pointForce(0) = interp->EvaluateAtIndex(gradIndex);
00394                                         imgForce = interp->EvaluateAtIndex(gradIndex);
00395 
00396                                         for(int a =0; a<nDims; a++)
00397                                         {pointForce(a) = imgForce[a];
00398                                         }
00399                                         //pointForce(1) = interp->EvaluateAtIndex(gradIndex);
00400 
00401                                         nodesF.set_row(i, nodesF.get_row(i) +  pointForce);
00402                                 }
00403 
00404                                 if(DEBUG)
00405                                         std::cout << "Force: " << i << " x: " << nodesF(i,0) << " y: " << nodesF(i,1) << " z: " << nodesF(i,2) << std::endl;
00406 
00407                                 //std::cout << "Force: " << i << " x: " << 100*pointForce(0) << " y: " << 100*pointForce(1) << " z: " << std::endl;
00408 
00409                                 //**************************//
00410                                 //Drag Force
00411                                 //**************************//
00412                                 nodesF.set_row(i, nodesF.get_row(i)-nodesV.get_row(i)*defaultDrag); 
00413 
00414                                 //**************************//
00415                                 //Apply forces to model
00416                                 //**************************//
00417                                 //nodesA=nodesF ./ (repmat(nodesM,[1,3]));%accelaration
00418                                 //nodesV=DT * nodesA + nodesV;%new velocity 
00419                                 //nodesXYZ=(DT * nodesV).*repmat(1-nodesFxd,[1,3]) + nodesXYZ;%new position 
00420 
00421                                 nodesA.set_row(i, nodesF.get_row(i)/nodesM(i) );
00422                                 if(DEBUG)
00423                                         std::cout << "Accel: " << i << " x: " << nodesA(i,0) << " y: " << nodesA(i,1) << " z: " << std::endl;
00424 
00425                                 //TODO: This method of time adjustment may actually slow down time for some nodes and not others. Fix that.
00426                                 //TODO: Possible solution: Build a tmp matrix of new velocities, stop and update if time changes.
00427                                 while(incTime)
00428                                 {
00429                                         tmpVel = (localTimeStep*nodesA.get_row(i) + nodesV.get_row(i));
00430                                         tmpDist = (localTimeStep * tmpVel );
00431 
00432                                         if(DEBUG)
00433                                         {       std::cout << "Dist: " << i << " x: " << tmpDist(0) << " y: " << tmpDist(1) << " z: " << tmpDist(2) << std::endl;
00434                                         std::cout << "Vel: " << i << " x: " << nodesV(i,0) << " y: " << nodesV(i,1) << " z: " << nodesV(i,2) << std::endl;
00435                                         }
00436 
00437                                         //Check if velocity passes thresh
00438                                         dispLength = abs(tmpDist.two_norm());
00439                                         if(  dispLength< 1 )
00440                                         {
00441                                                 incTime = false;
00442                                                 //Make sure it doesn't push the node outside of the boundary
00443                                                 if(BOUND_CHECKING)
00444                                                 {
00445                                                         tmpPos = tmpDist+nodes.get_row(i);
00446                                                         //TODO: Is there a way without converting to an itk index?
00447                                                         for(int a =0; a<nDims; a++)
00448                                                                 gradIndex[a] = tmpPos(a);
00449                                                         if(gradientImageRegion.IsInside(gradIndex))
00450                                                         {
00451                                                                 nodesV.set_row(i, tmpVel );             
00452                                                                 nodes.set_row(i,tmpPos);
00453                                                         }
00454                                                 }
00455                                                 else
00456                                                 {
00457                                                         nodesV.set_row(i, tmpVel );                     
00458                                                         nodes.set_row(i,  tmpDist + nodes.get_row(i) );
00459                                                 }
00460                                         }
00461                                         else //shrink the time step and recalc.
00462                                         {
00463                                                 //std::cout << "DT : " << localTimeStep << std::endl;
00464                                                 localTimeStep = localTimeStep/2;
00465                                         }
00466                                 }
00467                                 incTime = true;
00468                         }
00469                         count++;
00470                         //Update the time
00471                         this->time = this->time+localTimeStep;
00472                 }
00473                 this->geom->setMatrixNodePositions(nodes);
00474                 //std::cout << "End deform" << std::endl;
00475                 std::cout << "Clock: " << this->time << std::endl;
00476                 return 1;
00477         }
00478 
00479 } // end namespace mial
00480 
00481 #endif

Generated on Wed Jul 19 13:05:18 2006 for IDO by  doxygen 1.4.7