Skip to content

Commit

Permalink
Bugfix: Import allowed with transition rewards only.
Browse files Browse the repository at this point in the history
  • Loading branch information
davexparker committed Sep 30, 2024
1 parent 43c6311 commit 9e1ebd0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
4 changes: 4 additions & 0 deletions prism-tests/functionality/import/dice.pm.importexport.auto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
-importmodel dice.all -exportmodel dice.all -h
-importmodel dice.all -exportmodel dice.all -s

# Partial
-importmodel dice.tra,lab,srew -exportmodel dice.tra,lab,srew
-importmodel dice.tra,sta,trew -exportmodel dice.tra,sta,trew

# Import model info - exact
-importmodel dice.all -exportmodel dice.exact.all -exact
-importmodel dice.exact.all -exportmodel dice.exact.all -exact
17 changes: 9 additions & 8 deletions prism/src/io/PrismExplicitImporter.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public class PrismExplicitImporter implements ExplicitModelImporter

// File(s) to read in rewards from
private List<PrismExplicitImporter.RewardFile> stateRewardsReaders = new ArrayList<>();
private List<PrismExplicitImporter.RewardFile> transRewardsRewaders = new ArrayList<>();
private List<PrismExplicitImporter.RewardFile> transRewardsReaders = new ArrayList<>();

// Regex for comments
protected static final Pattern COMMENT_PATTERN = Pattern.compile("#.*");
Expand Down Expand Up @@ -137,9 +137,9 @@ public PrismExplicitImporter(File statesFile, File transFile, File labelsFile, L
for (File file : this.stateRewardsFiles) {
this.stateRewardsReaders.add(new PrismExplicitImporter.RewardFile(file));
}
this.transRewardsRewaders = new ArrayList<>(this.transRewardsFiles.size());
this.transRewardsReaders = new ArrayList<>(this.transRewardsFiles.size());
for (File file : this.transRewardsFiles) {
this.transRewardsRewaders.add(new PrismExplicitImporter.RewardFile(file));
this.transRewardsReaders.add(new PrismExplicitImporter.RewardFile(file));
}
}

Expand Down Expand Up @@ -902,7 +902,8 @@ private void buildRewardInfo() throws PrismException
@Override
public List<String> getRewardStructNames()
{
return Reducible.extend(stateRewardsReaders).map(f -> f.getName().orElse("")).collect(new ArrayList<>(stateRewardsReaders.size()));
List<PrismExplicitImporter.RewardFile> rewardsReaders = stateRewardsReaders.size() >= transRewardsFiles.size() ? stateRewardsReaders : transRewardsReaders;
return Reducible.extend(rewardsReaders).map(f -> f.getName().orElse("")).collect(new ArrayList<>(rewardsReaders.size()));
}

@Override
Expand Down Expand Up @@ -931,17 +932,17 @@ public <Value> void extractStateRewards(int rewardIndex, BiConsumer<Integer, Val
@Override
public <Value> void extractMCTransitionRewards(int rewardIndex, IOUtils.TransitionRewardConsumer<Value> storeReward, Evaluator<Value> eval) throws PrismException
{
if (rewardIndex < transRewardsRewaders.size()) {
RewardFile file = transRewardsRewaders.get(rewardIndex);
if (rewardIndex < transRewardsReaders.size()) {
RewardFile file = transRewardsReaders.get(rewardIndex);
file.extractMCTransitionRewards(storeReward, eval, numStates);
}
}

@Override
public <Value> void extractMDPTransitionRewards(int rewardIndex, IOUtils.TransitionStateRewardConsumer<Value> storeReward, Evaluator<Value> eval) throws PrismException
{
if (rewardIndex < transRewardsRewaders.size()) {
RewardFile file = transRewardsRewaders.get(rewardIndex);
if (rewardIndex < transRewardsReaders.size()) {
RewardFile file = transRewardsReaders.get(rewardIndex);
file.extractMDPTransitionRewards(storeReward, eval, numStates);
}
}
Expand Down

0 comments on commit 9e1ebd0

Please # to comment.