from collections import defaultdict
from.solver_base import SolverBase,Solution
from.aggregate import aggregate_estimations
from.helpers import find_most_profitable_heads_impl
class SimpleTreeSolver(SolverBase):
    """Solver that compares host and accelerator estimations for the given
    top rows and selects offload heads via ``find_most_profitable_heads_impl``.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to SolverBase; no extra state is added here.
        super().__init__(*args, **kwargs)

    def solve(self, top_rows, params):
        """Build a ``Solution`` for *top_rows* and post-process it.

        The solution produced by ``_solve`` is handed to
        ``_postprocess_solution`` (whose return value is ignored) and then
        returned to the caller.
        """
        solution = self._solve(top_rows, params)
        self._postprocess_solution(solution)
        return solution

    def _solve(self, top_rows, params):
        """Populate a fresh ``Solution`` with host/accelerator estimations.

        Fills ``loop_nests``, ``host_estimations``, ``total_time_on_host``
        (sum of each top row's ``'total_time'`` host estimate),
        ``accel_estimations``, ``offloads_by_loop_nest``, ``offload_heads``,
        ``non_offload_heads`` and merges extra entries into ``regions``, then
        lets ``_calc_per_region_speed_up`` finish the result.
        """
        platform = self.platform
        result = Solution()
        result.loop_nests = self.loop_nests

        result.host_estimations = platform.host.generate_estimations(top_rows)
        host_times = (result.host_estimations[row]['total_time'] for row in top_rows)
        result.total_time_on_host = sum(host_times)

        heads, result.accel_estimations = self.find_most_profitable_heads(
            top_rows,
            result.host_estimations,
            result.total_time_on_host,
            platform,
            params,
        )
        # The selector's result is consumed positionally: the first three
        # items become attributes, the fourth is merged into ``regions``.
        (result.offloads_by_loop_nest,
         result.offload_heads,
         result.non_offload_heads) = heads[:3]
        result.regions.update(heads[3])

        self._calc_per_region_speed_up(result)
        return result

    def _postprocess_solution(self, solution):
        """Run ``aggregate_estimations`` over *solution* using this solver's
        objective and its ``'reestimate_time_for_aggregated'`` setting
        (``None`` when the setting is absent).
        """
        objective = self.objective_fn
        reestimate = self.settings.get('reestimate_time_for_aggregated')
        aggregate_estimations(objective, solution, reestimate)

    def find_most_profitable_heads(self, top_rows, host_estimations, total_time_on_host, platform, params):
        """Index accelerator estimations per row and delegate head selection.

        Gathers estimations from every accelerator in ``platform.accelerators``
        (passing *params* as ``settings``). Every estimation covering at most
        one row is recorded under that row as an ``(estimation, gain, index)``
        triple, where ``index`` is its position in the combined list and
        ``gain`` is ``objective_fn(estimation)`` minus the summed
        ``objective_fn`` of the host estimations for the same rows, or
        ``None`` when the estimation is not an offload candidate or does not
        fit. Selection itself is delegated to
        ``find_most_profitable_heads_impl``.

        Returns:
            ``(solution, accel_estimations)``, where ``solution`` is whatever
            ``find_most_profitable_heads_impl`` returned.
        """
        accel_estimations = []
        for accelerator in platform.accelerators:
            accel_estimations.extend(accelerator.generate_estimations(top_rows, settings=params))

        row2est_idx = defaultdict(list)
        for position, estimation in enumerate(accel_estimations):
            covered_rows = estimation['rows']
            # Multi-row estimations are not indexed here.
            if len(covered_rows) > 1:
                continue

            gain = None
            if estimation['is_offload_candidate'] and estimation['does_fit']:
                # NOTE(review): whether a positive gain is "better" depends on
                # objective_fn's sign convention -- confirm.
                accel_cost = self.objective_fn(estimation)
                host_cost = sum(self.objective_fn(host_estimations[r]) for r in covered_rows)
                gain = accel_cost - host_cost

            for row in covered_rows:
                row2est_idx[row].append((estimation, gain, position))

        # NOTE(review): per-region options are always read from the first
        # accelerator, even when several exist. The lookup happens lazily,
        # so an empty accelerator list only fails if a callback is invoked.
        def model_children(row):
            return platform.accelerators[0].get_per_region_option('model_children', row.selected_by)

        def check_profitability(row):
            return platform.accelerators[0].get_per_region_option('check_profitability', row.selected_by)

        solution = find_most_profitable_heads_impl(
            top_rows,
            self.nests_indices,
            accel_estimations,
            row2est_idx,
            host_estimations,
            params['min_required_speed_up'],
            params['max_speed_up_limit'],
            params['MDT'],
            params['loop_filter_threshold'],
            params['unroll_functions'],
            total_time_on_host,
            model_children,
            check_profitability,
        )
        return solution, accel_estimations
